Source code for jina.orchestrate.pods

import argparse
import multiprocessing
import os
import time
from abc import ABC, abstractmethod
from typing import Dict, Optional, Type, Union, TYPE_CHECKING

from jina.constants import __ready_msg__, __stop_msg__, __windows__
from jina.enums import PodRoleType
from jina.excepts import RuntimeFailToStart, RuntimeRunForeverEarlyError
from jina.helper import typename
from jina.jaml import JAML
from jina.logging.logger import JinaLogger
from jina.orchestrate.pods.helper import ConditionalEvent, _get_event
from jina.parsers.helper import _update_gateway_args
from jina.serve.runtimes.asyncio import AsyncNewLoopRuntime
from jina.serve.runtimes.gateway import GatewayRuntime

if TYPE_CHECKING:
    import threading

__all__ = ['BasePod', 'Pod']


def run(
        args: 'argparse.Namespace',
        name: str,
        runtime_cls: Type[AsyncNewLoopRuntime],
        envs: Dict[str, str],
        is_started: Union['multiprocessing.Event', 'threading.Event'],
        is_shutdown: Union['multiprocessing.Event', 'threading.Event'],
        is_ready: Union['multiprocessing.Event', 'threading.Event'],
        jaml_classes: Optional[Dict] = None,
):
    """Method representing the :class:`BaseRuntime` activity.

    This method is the target for the Pod's `thread` or `process`

    .. note::
        :meth:`run` is running in subprocess/thread, the exception can not be propagated to the main process.
        Hence, please do not raise any exception here.

    .. note::
        Please note that env variables are process-specific. Subprocess inherits envs from
        the main process. But Subprocess's envs do NOT affect the main process. It does NOT
        mess up user local system envs.

    .. warning::
        If you are using ``thread`` as backend, envs setting will likely be overridden by others

    .. note::
        `jaml_classes` contains all the :class:`JAMLCompatible` classes registered in the main process.
        When using `spawn` as the multiprocessing start method, passing this argument to `run` method re-imports
        & re-registers all `JAMLCompatible` classes.

    :param args: namespace args from the Pod
    :param name: name of the Pod to have proper logging
    :param runtime_cls: the runtime class to instantiate
    :param envs: a dictionary of environment variables to be set in the new Process
    :param is_started: concurrency event to communicate runtime is properly started. Used for better logging
    :param is_shutdown: concurrency event to communicate runtime is terminated
    :param is_ready: concurrency event to communicate runtime is ready to receive messages
    :param jaml_classes: all the `JAMLCompatible` classes imported in main process
    """
    logger = JinaLogger(name, **vars(args))

    def _unset_envs():
        # Remove every key of `envs` from the environment once the runtime is done.
        if envs:
            for k in envs.keys():
                os.environ.pop(k, None)

    def _set_envs():
        # NOTE(review): this is guarded by `args.env` while `_unset_envs` is guarded by `envs`,
        # so entries the pod adds on its own (e.g. JINA_DEPLOYMENT_NAME / JINA_LOG_CONFIG) are
        # only exported when the user also passed `--env`. Confirm this asymmetry is intended.
        if args.env:
            os.environ.update({k: str(v) for k, v in envs.items()})

    try:
        _set_envs()

        runtime = runtime_cls(
            args=args,
        )
    except Exception as ex:
        # Always log what failed; `--quiet-error` only drops the hint and the traceback.
        # (Previously `a + b if quiet else ''` logged an empty message under --quiet-error.)
        message = f'{ex!r} during {runtime_cls!r} initialization'
        if not args.quiet_error:
            message += '\n add "--quiet-error" to suppress the exception details'
        logger.error(message, exc_info=not args.quiet_error)
    else:
        # A shutdown requested before start-up means the runtime must not be run at all.
        if not is_shutdown.is_set():
            is_started.set()
            with runtime:
                is_ready.set()
                runtime.run_forever()
    finally:
        _unset_envs()
        # Set on every exit path so waiters on the parent side are released.
        is_shutdown.set()
        logger.debug('process terminated')


class BasePod(ABC):
    """
    :class:`BasePod` is an interface from which all the classes managing the lifetime of a Runtime inside a local process,
    container must inherit.

    It exposes the required APIs so that the `BasePod` can be handled by the `cli` api as a context manager or by a `Deployment`.

    What makes a BasePod a BasePod is that it manages the lifecycle of a Runtime (gateway or not gateway)
    """

    def __init__(self, args: 'argparse.Namespace'):
        """Store the arguments and prepare the events shared with the runtime.

        :param args: namespace args from the Pod
        """
        self.args = args

        if self.args.pod_role == PodRoleType.GATEWAY:
            _update_gateway_args(self.args)
        self.args.parallel = getattr(self.args, 'shards', 1)
        self.name = self.args.name or self.__class__.__name__
        self.is_forked = False
        self.logger = JinaLogger(self.name, **vars(self.args))

        # Environment variables handed to the `run` target.
        self._envs = {'JINA_DEPLOYMENT_NAME': self.name}
        if self.args.quiet:
            self._envs['JINA_LOG_CONFIG'] = 'QUIET'
        if self.args.env:
            self._envs.update(self.args.env)

        # arguments needed to create `runtime` and communicate with it in the `run` in the stack of the new process
        # or thread.
        # `test_worker` is only passed to `_get_event`; presumably it selects an event type compatible
        # with the worker backend (TODO confirm in jina.orchestrate.pods.helper).
        test_worker = multiprocessing.Process()
        self.is_ready = _get_event(test_worker)
        self.is_shutdown = _get_event(test_worker)
        self.cancel_event = _get_event(test_worker)
        self.is_started = _get_event(test_worker)
        # Combined event covering "ready" or "shut down", used to wait for either outcome.
        self.ready_or_shutdown = ConditionalEvent(
            events_list=[self.is_ready, self.is_shutdown],
        )
        self.runtime_ctrl_address = self._get_control_address()
        self._timeout_ctrl = self.args.timeout_ctrl

    def _get_control_address(self):
        """Build the ``host:port`` address used to reach the runtime.

        A gateway has a list of ports, and only the first one is used.

        :return: the control address as ``host:port``
        """
        if self.args.pod_role == PodRoleType.GATEWAY:
            return f'{self.args.host}:{self.args.port[0]}'
        return f'{self.args.host}:{self.args.port}'

    @staticmethod
    def _compose_error_message(message: str, quiet_error: bool) -> str:
        """Append the ``--quiet-error`` hint to ``message`` unless quiet errors were requested.

        The base message is always kept, so the failure is never logged as an empty string.

        :param message: description of the failure
        :param quiet_error: value of ``args.quiet_error``
        :return: the text to log
        """
        if quiet_error:
            return message
        return message + '\n add "--quiet-error" to suppress the exception details'

    def close(self) -> None:
        """Close the Pod

        This method makes sure that the `Process` is properly finished and its resources properly released.

        If the runtime was started and has not shut down yet, it is terminated and the shutdown
        signal is awaited for ``timeout_ctrl`` seconds (1 second on Windows). A missed signal is
        logged as an error (a warning on Windows). Otherwise the runtime is terminated without
        waiting. In all cases ``is_shutdown`` is set and the logger is closed.
        """
        self.logger.debug('waiting for ready or shutdown signal from runtime')
        if not self.is_shutdown.is_set() and self.is_started.is_set():
            try:
                self.logger.debug('terminate')
                self._terminate()
                if not self.is_shutdown.wait(
                    timeout=self._timeout_ctrl if not __windows__ else 1.0
                ):
                    if not __windows__:
                        raise Exception(
                            f'Shutdown signal was not received for {self._timeout_ctrl} seconds'
                        )
                    else:
                        self.logger.warning(
                            'Pod was forced to close after 1 second. Graceful closing is not available on Windows.'
                        )
            except Exception as ex:
                # `--quiet-error` drops the hint and traceback, never the message itself.
                self.logger.error(
                    self._compose_error_message(
                        f'{ex!r} during {self.close!r}', self.args.quiet_error
                    ),
                    exc_info=not self.args.quiet_error,
                )
        else:
            # here shutdown has been set already (or the runtime never started), therefore `run` will gracefully finish
            self.logger.debug(
                f'{"shutdown is already set" if self.is_shutdown.is_set() else "Runtime was never started"}. Runtime will end gracefully on its own'
            )
            self._terminate()
        self.is_shutdown.set()
        self.logger.debug(__stop_msg__)
        self.logger.close()

    def __enter__(self):
        return self.start()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def _wait_for_ready_or_shutdown(self, timeout: Optional[float]):
        """
        Waits for the process to be ready or to know it has failed.

        Delegates to the runtime-class helper that matches the pod role (gateway vs. other).

        :param timeout: The time to wait before readiness or failure is determined
        :return: the result of ``wait_for_ready_or_shutdown``; presumably truthy when the
            ready-or-shutdown event fired within ``timeout`` (TODO confirm in the runtime)
        """
        if self.args.pod_role == PodRoleType.GATEWAY:
            return GatewayRuntime.wait_for_ready_or_shutdown(
                timeout=timeout,
                ready_or_shutdown_event=self.ready_or_shutdown.event,
                ctrl_address=self.runtime_ctrl_address,
                timeout_ctrl=self._timeout_ctrl,
                protocol=self.args.protocol[0],
            )
        else:
            return AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
                timeout=timeout,
                ready_or_shutdown_event=self.ready_or_shutdown.event,
                ctrl_address=self.runtime_ctrl_address,
                timeout_ctrl=self._timeout_ctrl,
            )

    def _fail_start_timeout(self, timeout):
        """
        Closes the Pod and raises a TimeoutError with the corresponding warning messages

        :param timeout: The time to wait before readiness or failure is determined, in seconds
        :raises TimeoutError: always, after the Pod has been closed
        """
        _timeout = timeout or -1
        self.logger.warning(
            f'{self} timeout after waiting for {self.args.timeout_ready}ms, '
            f'if your executor takes time to load, you may increase --timeout-ready'
        )
        self.close()
        raise TimeoutError(
            f'{typename(self)}:{self.name} can not be initialized after {_timeout * 1e3}ms'
        )

    def _check_failed_to_start(self):
        """
        Raises a corresponding exception if failed to start

        :raises RuntimeFailToStart: shutdown was signalled before the runtime ever started
        :raises RuntimeRunForeverEarlyError: the runtime started but has already shut down
        """
        if self.is_shutdown.is_set():
            # return too early and the shutdown is set, means something fails!!
            if not self.is_started.is_set():
                raise RuntimeFailToStart
            else:
                raise RuntimeRunForeverEarlyError

    def wait_start_success(self):
        """Block until all pods starts successfully.

        If not success, it will raise an error hoping the outer function to catch it.
        ``args.timeout_ready`` is given in milliseconds; a value <= 0 waits without a limit.
        """
        _timeout = self.args.timeout_ready
        if _timeout <= 0:
            _timeout = None
        else:
            _timeout /= 1e3
        if self._wait_for_ready_or_shutdown(_timeout):
            self._check_failed_to_start()
            self.logger.debug(__ready_msg__)
        else:
            self._fail_start_timeout(_timeout)

    async def async_wait_start_success(self):
        """
        Wait for the `Pod` to start successfully in a non-blocking manner

        Polls the ready-or-shutdown event every 0.1 seconds until it is set or
        ``args.timeout_ready`` (milliseconds, <= 0 meaning no limit) elapses.
        """
        import asyncio

        _timeout = self.args.timeout_ready
        if _timeout <= 0:
            _timeout = None
        else:
            _timeout /= 1e3

        timeout_ns = 1e9 * _timeout if _timeout else None
        now = time.time_ns()
        while timeout_ns is None or time.time_ns() - now < timeout_ns:
            if self.ready_or_shutdown.event.is_set():
                self._check_failed_to_start()
                self.logger.debug(__ready_msg__)
                return
            else:
                await asyncio.sleep(0.1)

        self._fail_start_timeout(_timeout)

    @property
    def role(self) -> 'PodRoleType':
        """Get the role of this pod in a deployment

        :return: the ``pod_role`` stored in the pod arguments
        """
        return self.args.pod_role

    @abstractmethod
    def start(self):
        """Start the BasePod.

        This method calls :meth:`start` in :class:`multiprocessing.Process`.

        :return: the started pod
        """
        ...

    @abstractmethod
    def _terminate(self):
        """Stop the underlying runtime worker; called from :meth:`close`."""
        ...

    @abstractmethod
    def join(self, *args, **kwargs):
        """Joins the BasePod. Wait for the BasePod to properly terminate

        :param args: extra positional arguments
        :param kwargs: extra keyword arguments
        """
        ...
class Pod(BasePod):
    """
    :class:`Pod` is a thread/process- container of :class:`BaseRuntime`. It leverages :class:`multiprocessing.Process` to manage the lifecycle of a :class:`BaseRuntime` object in a robust way.

    A :class:`Pod` must be equipped with a proper :class:`Runtime` class to work.
    """

    def __init__(self, args: 'argparse.Namespace'):
        """Resolve the runtime class and build (but do not start) the worker process.

        :param args: namespace args from the Pod
        """
        super().__init__(args)
        self.runtime_cls = self._get_runtime_cls()
        # The worker executes the module-level `run` target with the events created by BasePod,
        # which is how readiness and shutdown are reported back to this object.
        self.worker = multiprocessing.Process(
            target=run,
            kwargs={
                'args': args,
                'name': self.name,
                'envs': self._envs,
                'is_started': self.is_started,
                'is_shutdown': self.is_shutdown,
                'is_ready': self.is_ready,
                'runtime_cls': self.runtime_cls,
                'jaml_classes': JAML.registered_classes(),
            },
            name=self.name,
            daemon=False,
        )

    def start(self):
        """Start the Pod.

        This method calls :meth:`start` in :class:`multiprocessing.Process`. Unless
        ``args.noblock_on_start`` is set, it then blocks until the runtime reports success.

        :return: this Pod
        """
        self.worker.start()
        self.is_forked = multiprocessing.get_start_method().lower() == 'fork'
        if not self.args.noblock_on_start:
            self.wait_start_success()
        return self

    def join(self, *args, **kwargs):
        """Joins the Pod.

        This method calls :meth:`join` in :class:`multiprocessing.Process`.

        :param args: extra positional arguments to pass to join
        :param kwargs: extra keyword arguments to pass to join
        """
        self.logger.debug('joining the process')
        self.worker.join(*args, **kwargs)
        self.logger.debug('successfully joined the process')

    def _terminate(self):
        """Terminate the Pod.

        This method calls :meth:`terminate` in :class:`multiprocessing.Process`.
        """
        self.logger.debug('terminating the runtime process')
        # `Process.terminate()` only sends the termination request and does not wait for exit;
        # `close()` waits on `is_shutdown` for that.
        self.worker.terminate()
        self.logger.debug('runtime process properly terminated')

    def _get_runtime_cls(self) -> Type[AsyncNewLoopRuntime]:
        """Resolve the runtime class to run in the worker.

        :return: the runtime class (not an instance) named by ``args.runtime_cls``; presumably
            ``update_runtime_cls`` normalises that field first (TODO confirm in the helper)
        """
        from jina.orchestrate.pods.helper import update_runtime_cls
        from jina.serve.runtimes import get_runtime

        update_runtime_cls(self.args)
        return get_runtime(self.args.runtime_cls)