import abc
from types import SimpleNamespace
from typing import Dict, Optional
from jina.jaml import JAMLCompatible
from jina.logging.logger import JinaLogger
from jina.serve.helper import store_init_kwargs, wrap_func
# Only `BaseGateway` is exported by `from <module> import *`; `GatewayType` is left out.
__all__ = ['BaseGateway']
class GatewayType(type(JAMLCompatible), type):
    """Metaclass of :class:`BaseGateway`; every class it creates is passed through :meth:`register_class`."""

    def __new__(cls, *args, **kwargs):
        """
        # noqa: DAR101
        # noqa: DAR102

        :return: Gateway class
        """
        created = super().__new__(cls, *args, **kwargs)
        return cls.register_class(created)

    @staticmethod
    def register_class(cls):
        """
        Register a class and, the first time it is seen, wrap its ``__init__`` with ``store_init_kwargs``.

        :param cls: The class.
        :return: The class, after being registered.
        """
        known_ids = getattr(cls, '_registered_class', set())
        class_id = f'{cls.__module__}.{cls.__name__}'

        # Already registered: leave ``__init__`` alone so it is not wrapped twice.
        if class_id in known_ids:
            return cls

        known_ids.add(class_id)
        setattr(cls, '_registered_class', known_ids)
        wrap_func(
            cls,
            ['__init__'],
            store_init_kwargs,
            taboo={'self', 'args', 'kwargs', 'runtime_args'},
        )
        return cls
class BaseGateway(JAMLCompatible, metaclass=GatewayType):
    """
    The base class of all custom Gateways, can be used to build a custom interface to a Jina Flow that supports
    gateway logic

    :class:`jina.Gateway` as an alias for this class.
    """

    def __init__(
        self,
        name: Optional[str] = 'gateway',
        runtime_args: Optional[Dict] = None,
        **kwargs,
    ):
        """
        :param name: Gateway pod name
        :param runtime_args: a dict of arguments injected from :class:`Runtime` during runtime
        :param kwargs: additional extra keyword arguments to avoid failing when extra params are passed that are not expected
        """
        # Must run first: everything below reads from ``self.runtime_args``.
        self._add_runtime_args(runtime_args)
        self.name = name
        self.logger = JinaLogger(self.name)
        self.tracing = self.runtime_args.tracing
        self.tracer_provider = self.runtime_args.tracer_provider
        self.grpc_tracing_server_interceptors = (
            self.runtime_args.grpc_tracing_server_interceptors
        )

        import json

        from jina.serve.streamer import GatewayStreamer, _ExecutorStreamer

        # These runtime args arrive as JSON strings and are decoded here.
        graph_description = json.loads(self.runtime_args.graph_description)
        graph_conditions = json.loads(self.runtime_args.graph_conditions)
        deployments_addresses = json.loads(self.runtime_args.deployments_addresses)
        deployments_metadata = json.loads(self.runtime_args.deployments_metadata)
        deployments_no_reduce = json.loads(self.runtime_args.deployments_no_reduce)

        # Arguments shared by the streamer instance and the env-level streamer
        # configuration; built once so both calls always receive the same values.
        streamer_kwargs = dict(
            graph_representation=graph_description,
            executor_addresses=deployments_addresses,
            graph_conditions=graph_conditions,
            deployments_metadata=deployments_metadata,
            deployments_no_reduce=deployments_no_reduce,
            timeout_send=self.runtime_args.timeout_send,
            retries=self.runtime_args.retries,
            compression=self.runtime_args.compression,
            runtime_name=self.runtime_args.runtime_name,
            prefetch=self.runtime_args.prefetch,
        )

        self.streamer = GatewayStreamer(
            **streamer_kwargs,
            logger=self.logger,
            metrics_registry=self.runtime_args.metrics_registry,
            meter=self.runtime_args.meter,
            aio_tracing_client_interceptors=self.runtime_args.aio_tracing_client_interceptors,
            tracing_client_interceptor=self.runtime_args.tracing_client_interceptor,
        )
        GatewayStreamer._set_env_streamer_args(**streamer_kwargs)

        # One executor streamer per deployment name, all sharing the streamer's connection pool.
        self.executor = {
            executor_name: _ExecutorStreamer(
                self.streamer._connection_pool, executor_name=executor_name
            )
            for executor_name in deployments_addresses
        }

    def _add_runtime_args(self, _runtime_args: Optional[Dict]):
        """Build ``self.runtime_args`` as a :class:`~types.SimpleNamespace`.

        Three layers are merged, later ones overriding earlier ones: the
        hard-coded fallbacks below, the defaults of the gateway runtime args
        parser, and finally the values supplied by the caller.

        :param _runtime_args: runtime arguments supplied by the caller; ``None`` or empty means defaults only
        """
        from jina.parsers import set_gateway_runtime_args_parser

        parser = set_gateway_runtime_args_parser()
        default_args = parser.parse_args([])
        default_args_dict = dict(vars(default_args))
        _runtime_args = _runtime_args or {}
        # Fallback values for attributes that ``__init__`` reads from ``runtime_args``.
        runtime_set_args = {
            'tracer_provider': None,
            'grpc_tracing_server_interceptors': None,
            'runtime_name': 'test',
            'metrics_registry': None,
            'meter': None,
            'aio_tracing_client_interceptors': None,
            'tracing_client_interceptor': None,
        }
        runtime_args_dict = {**runtime_set_args, **default_args_dict, **_runtime_args}
        self.runtime_args = SimpleNamespace(**runtime_args_dict)

    @property
    def port(self):
        """Gets the first port of the port list argument. To be used in the regular case where a Gateway exposes a single port

        :return: The first port to be exposed
        """
        return self.runtime_args.port[0]

    @property
    def ports(self):
        """Gets all the list of ports from the runtime_args as a list.

        :return: The lists of ports to be exposed
        """
        return self.runtime_args.port

    @property
    def protocols(self):
        """Gets all the list of protocols from the runtime_args as a list.

        :return: The lists of protocols to be exposed
        """
        return self.runtime_args.protocol

    @property
    def host(self):
        """Gets the host from the runtime_args

        :return: The host where to bind the gateway
        """
        return self.runtime_args.host

    # NOTE(review): ``abc.abstractmethod`` is only enforced at instantiation when the
    # metaclass derives from ``abc.ABCMeta`` - confirm whether ``GatewayType`` does.
    @abc.abstractmethod
    async def setup_server(self):
        """Setup server"""
        ...

    @abc.abstractmethod
    async def run_server(self):
        """Run server forever"""
        ...

    @abc.abstractmethod
    async def shutdown(self):
        """Shutdown the server and free other allocated resources, e.g, streamer object, health check service, ..."""
        ...

    def __enter__(self):
        """Enter the context.

        :return: the gateway itself
        """
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit the context; nothing is released here and exceptions are not suppressed.

        :param exc_type: type of the exception raised inside the ``with`` block, if any
        :param exc_val: the exception instance, if any
        :param exc_tb: the traceback, if any
        """
        pass