import argparse
import asyncio
import contextlib
import json
import os
from abc import ABC
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import grpc
from grpc_health.v1 import health, health_pb2, health_pb2_grpc
from grpc_reflection.v1alpha import reflection

from jina.enums import PollingType
from jina.excepts import InternalNetworkError
from jina.helper import get_full_version
from jina.importer import ImportExtensions
from jina.proto import jina_pb2, jina_pb2_grpc
from jina.serve.networking import GrpcConnectionPool
from jina.serve.runtimes.asyncio import AsyncNewLoopRuntime
from jina.serve.runtimes.helper import _get_grpc_server_options
from jina.serve.runtimes.request_handlers.data_request_handler import DataRequestHandler
from jina.types.request.data import DataRequest, Response

class HeadRuntime(AsyncNewLoopRuntime, ABC):
    """
    Runtime is used in head pods. It responds to Gateway requests and sends to uses_before/uses_after and its workers
    """

    DEFAULT_POLLING = PollingType.ANY

    def __init__(
        self,
        args: argparse.Namespace,
        **kwargs,
    ):
        """Initialize grpc server for the head runtime.

        :param args: args from CLI
        :param kwargs: keyword args
        """
        # Created before ``super().__init__`` -- presumably because the base
        # initializer may already drive ``async_setup``, which registers this
        # servicer on the gRPC server. NOTE(review): confirm against AsyncNewLoopRuntime.
        self._health_servicer = health.HealthServicer(experimental_non_blocking=True)

        super().__init__(args, **kwargs)
        if args.name is None:
            args.name = ''
        self.name = args.name
        self._deployment_name = os.getenv('JINA_DEPLOYMENT_NAME', 'worker')
        self.connection_pool = GrpcConnectionPool(
            runtime_name=self.name,
            logger=self.logger,
            compression=args.compression,
            metrics_registry=self.metrics_registry,
        )
        self._retries = self.args.retries

        if self.metrics_registry:
            with ImportExtensions(
                required=True,
                help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina',
            ):
                from prometheus_client import Summary

            # Keep the labelled metric and create a fresh Timer per request (see
            # ``_request_timer``): a prometheus ``Timer`` stores its start time on
            # the instance, so one shared Timer yields wrong durations when
            # several requests are in flight concurrently.
            self._summary_metric = Summary(
                'receiving_request_seconds',
                'Time spent processing request',
                registry=self.metrics_registry,
                namespace='jina',
                labelnames=('runtime_name',),
            ).labels(self.name)
            # kept so that existing references to ``_summary`` keep working
            self._summary = self._summary_metric.time()
        else:
            self._summary_metric = None
            self._summary = contextlib.nullcontext()

        polling = getattr(args, 'polling', self.DEFAULT_POLLING.name)
        self._polling = self._parse_polling(polling)

        if getattr(args, 'connection_list', None):
            self._add_worker_connections(json.loads(args.connection_list))

        self.uses_before_address = args.uses_before_address
        self.timeout_send = args.timeout_send
        if self.timeout_send:
            self.timeout_send /= 1e3  # convert ms to seconds

        if self.uses_before_address:
            self.connection_pool.add_connection(
                deployment='uses_before', address=self.uses_before_address
            )
        self.uses_after_address = args.uses_after_address
        if self.uses_after_address:
            self.connection_pool.add_connection(
                deployment='uses_after', address=self.uses_after_address
            )
        self._reduce = not args.disable_reduce

    @staticmethod
    def _to_polling_type(value) -> PollingType:
        """Convert a PollingType, an integer enum value or a polling name to a PollingType.

        :param value: the value to convert
        :returns: the corresponding PollingType
        """
        if isinstance(value, PollingType):
            return value
        if isinstance(value, int):
            return PollingType(value)
        return PollingType.from_string(value)

    def _parse_polling(self, polling) -> defaultdict:
        """Build the endpoint -> PollingType mapping from the ``polling`` argument.

        :param polling: a PollingType, a polling name (e.g. ``'ANY'``), or a JSON object
            mapping endpoints to polling names / integer values; the ``'*'`` key sets
            the default for every endpoint that is not listed explicitly
        :returns: a defaultdict holding the polling type per endpoint
        """
        try:
            # try loading the polling args as json
            endpoint_polling = json.loads(polling)
        except (ValueError, TypeError):
            endpoint_polling = None

        if not isinstance(endpoint_polling, dict):
            # polling args is not a JSON object, interpret it as a single polling enum type
            return self._default_polling_dict(self._to_polling_type(polling))

        # '*' is used as a wildcard and matches all endpoints, except /index, /search
        # and explicitly defined endpoints
        default_polling = (
            self._to_polling_type(endpoint_polling['*'])
            if '*' in endpoint_polling
            else self.DEFAULT_POLLING
        )
        polling_dict = self._default_polling_dict(default_polling)
        for endpoint, value in endpoint_polling.items():
            polling_dict[endpoint] = self._to_polling_type(value)
        return polling_dict

    def _add_worker_connections(self, connection_list: Dict) -> None:
        """Register the worker addresses of every shard with the connection pool.

        :param connection_list: mapping of shard id to a single address or a list of addresses
        """
        for shard_id, shard_connections in connection_list.items():
            if isinstance(shard_connections, str):
                shard_connections = [shard_connections]
            for address in shard_connections:
                self.connection_pool.add_connection(
                    deployment=self._deployment_name,
                    address=address,
                    shard_id=int(shard_id),
                )

    def _default_polling_dict(self, default_polling):
        """Return a defaultdict mapping endpoints to ``default_polling``.

        ``/search`` is pre-set to ALL and ``/index`` to ANY.

        :param default_polling: polling type returned for endpoints not set explicitly
        :returns: the polling defaultdict
        """
        return defaultdict(
            lambda: default_polling,
            {'/search': PollingType.ALL, '/index': PollingType.ANY},
        )

    def _request_timer(self):
        """Return a context manager that times a single request.

        :returns: a fresh prometheus Timer when metrics are enabled, otherwise ``self._summary``
            (a reusable ``nullcontext`` when metrics are disabled)
        """
        metric = getattr(self, '_summary_metric', None)
        if metric is None:
            return self._summary
        return metric.time()

    async def async_setup(self):
        """Create the gRPC server, register the Jina, health and reflection services, and start it."""
        self._grpc_server = grpc.aio.server(
            options=_get_grpc_server_options(self.args.grpc_server_options)
        )

        jina_pb2_grpc.add_JinaSingleDataRequestRPCServicer_to_server(
            self, self._grpc_server
        )
        jina_pb2_grpc.add_JinaDataRequestRPCServicer_to_server(self, self._grpc_server)
        jina_pb2_grpc.add_JinaDiscoverEndpointsRPCServicer_to_server(
            self, self._grpc_server
        )
        jina_pb2_grpc.add_JinaInfoRPCServicer_to_server(self, self._grpc_server)
        service_names = (
            jina_pb2.DESCRIPTOR.services_by_name['JinaSingleDataRequestRPC'].full_name,
            jina_pb2.DESCRIPTOR.services_by_name['JinaDataRequestRPC'].full_name,
            jina_pb2.DESCRIPTOR.services_by_name['JinaDiscoverEndpointsRPC'].full_name,
            jina_pb2.DESCRIPTOR.services_by_name['JinaInfoRPC'].full_name,
            reflection.SERVICE_NAME,
        )
        # Mark all services as healthy.
        health_pb2_grpc.add_HealthServicer_to_server(
            self._health_servicer, self._grpc_server
        )
        for service in service_names:
            self._health_servicer.set(service, health_pb2.HealthCheckResponse.SERVING)
        reflection.enable_server_reflection(service_names, self._grpc_server)

        bind_addr = f'0.0.0.0:{self.args.port}'
        self._grpc_server.add_insecure_port(bind_addr)
        self.logger.debug(f'start listening on {bind_addr}')
        await self._grpc_server.start()

    async def async_run_forever(self):
        """Block until the GRPC server is terminated"""
        await self._grpc_server.wait_for_termination()

    async def async_cancel(self):
        """Stop the GRPC server immediately (no grace period)"""
        self.logger.debug('cancel HeadRuntime')
        await self._grpc_server.stop(0)

    async def async_teardown(self):
        """Mark health as shutting down, stop the server and close the connection pool"""
        self._health_servicer.enter_graceful_shutdown()
        await self.async_cancel()
        await self.connection_pool.close()

    async def process_single_data(self, request: DataRequest, context) -> DataRequest:
        """
        Process the received requests and return the result as a new request

        :param request: the data request to process
        :param context: grpc context
        :returns: the response request
        """
        return await self.process_data([request], context)

    def _handle_internalnetworkerror(self, err, context, response):
        """Translate an InternalNetworkError into gRPC status code/details on ``context``.

        :param err: the network error raised while talking to a worker
        :param context: grpc context
        :param response: the response object to return to the caller
        :returns: ``response``, tagged with the failed request id when known
        """
        err_code = err.code()
        if err_code == grpc.StatusCode.UNAVAILABLE:
            context.set_details(
                f'|Head: Failed to connect to worker (Executor) pod at address {err.dest_addr}. It may be down.'
            )
        elif err_code == grpc.StatusCode.DEADLINE_EXCEEDED:
            context.set_details(
                f'|Head: Connection to worker (Executor) pod at address {err.dest_addr} could be established, but timed out.'
            )
        context.set_code(err.code())
        self.logger.error(f'Error while getting responses from Pods: {err.details()}')
        if err.request_id:
            response.header.request_id = err.request_id
        return response

    async def process_data(self, requests: List[DataRequest], context) -> DataRequest:
        """
        Process the received data request and return the result as a new request

        :param requests: the data requests to process
        :param context: grpc context
        :returns: the response request
        """
        try:
            with self._request_timer():
                endpoint = dict(context.invocation_metadata()).get('endpoint')
                response, metadata = await self._handle_data_request(requests, endpoint)
                context.set_trailing_metadata(metadata.items())
                return response
        except InternalNetworkError as err:
            # can't connect, Flow broken, interrupt the streaming through gRPC error mechanism
            return self._handle_internalnetworkerror(
                err=err, context=context, response=Response()
            )
        except Exception as ex:
            # some other error, keep streaming going just add error info
            quiet = self.args.quiet_error
            # Only the hint is conditional: the exception repr is always logged,
            # --quiet-error just drops the hint and the traceback.
            self.logger.error(
                f'{ex!r}'
                + (
                    ''
                    if quiet
                    else '\n add "--quiet-error" to suppress the exception details'
                ),
                exc_info=not quiet,
            )
            requests[0].add_exception(ex, executor=None)
            context.set_trailing_metadata((('is-error', 'true'),))
            return requests[0]

    async def endpoint_discovery(self, empty, context) -> jina_pb2.EndpointsProto:
        """
        Uses the connection pool to send a discover endpoint call to the workers

        :param empty: The service expects an empty protobuf message
        :param context: grpc context
        :returns: the response request
        """
        response = jina_pb2.EndpointsProto()
        try:
            if self.uses_before_address:
                (
                    uses_before_response,
                    _,
                ) = await self.connection_pool.send_discover_endpoint(
                    deployment='uses_before', head=False
                )
                response.endpoints.extend(uses_before_response.endpoints)
            if self.uses_after_address:
                (
                    uses_after_response,
                    _,
                ) = await self.connection_pool.send_discover_endpoint(
                    deployment='uses_after', head=False
                )
                response.endpoints.extend(uses_after_response.endpoints)

            worker_response, _ = await self.connection_pool.send_discover_endpoint(
                deployment=self._deployment_name, head=False
            )
            response.endpoints.extend(worker_response.endpoints)
        except InternalNetworkError as err:
            # can't connect, Flow broken, interrupt the streaming through gRPC error mechanism
            return self._handle_internalnetworkerror(
                err=err, context=context, response=response
            )

        return response

    async def _handle_data_request(
        self, requests: List[DataRequest], endpoint: Optional[str]
    ) -> Tuple[DataRequest, Dict]:
        """Route requests through uses_before, the workers and uses_after.

        :param requests: incoming data requests (merged into the first via ``merge_routes``)
        :param endpoint: the requested endpoint, used to look up the polling type
        :returns: the response request and the merged trailing metadata
        :raises RuntimeError: if no worker returned a response
        """
        self.logger.debug(f'recv {len(requests)} DataRequest(s)')

        DataRequestHandler.merge_routes(requests)

        uses_before_metadata = None
        if self.uses_before_address:
            (
                response,
                uses_before_metadata,
            ) = await self.connection_pool.send_requests_once(
                requests,
                deployment='uses_before',
                timeout=self.timeout_send,
                retries=self._retries,
            )
            requests = [response]

        worker_send_tasks = self.connection_pool.send_requests(
            requests=requests,
            deployment=self._deployment_name,
            polling_type=self._polling[endpoint],
            timeout=self.timeout_send,
            retries=self._retries,
        )

        worker_results = await asyncio.gather(*worker_send_tasks)

        if not worker_results:
            raise RuntimeError(
                f'Head {self.name} did not receive a response when sending message to worker pods'
            )

        worker_results, metadata = zip(*worker_results)

        response_request = worker_results[0]
        uses_after_metadata = None
        if self.uses_after_address:
            (
                response_request,
                uses_after_metadata,
            ) = await self.connection_pool.send_requests_once(
                worker_results,
                deployment='uses_after',
                timeout=self.timeout_send,
                retries=self._retries,
            )
        elif len(worker_results) > 1 and self._reduce:
            # NOTE(review): the return value is ignored; this assumes reduce_requests
            # merges into worker_results[0] in place -- confirm.
            DataRequestHandler.reduce_requests(worker_results)
        elif len(worker_results) > 1 and not self._reduce:
            # worker returned multiple responses, but the head is configured to skip reduction
            # just concatenate the docs in this case
            # NOTE(review): this concatenates the docs of the *incoming* requests,
            # not of worker_results -- confirm this is intended.
            response_request.data.docs = DataRequestHandler.get_docs_from_request(
                requests, field='docs'
            )

        merged_metadata = self._merge_metadata(
            metadata, uses_after_metadata, uses_before_metadata
        )

        return response_request, merged_metadata

    def _merge_metadata(self, metadata, uses_after_metadata, uses_before_metadata):
        """Merge (key, value) metadata pairs; later sources override earlier ones.

        Precedence (lowest to highest): uses_before, workers (in order), uses_after.

        :param metadata: iterable of per-worker metadata, each an iterable of (key, value) pairs
        :param uses_after_metadata: uses_after metadata pairs, or None
        :param uses_before_metadata: uses_before metadata pairs, or None
        :returns: the merged metadata dict
        """
        merged_metadata = {}
        if uses_before_metadata:
            merged_metadata.update(uses_before_metadata)
        for meta in metadata:
            merged_metadata.update(meta)
        if uses_after_metadata:
            merged_metadata.update(uses_after_metadata)

        return merged_metadata

    async def _status(self, empty, context) -> jina_pb2.JinaInfoProto:
        """
        Process the call requested and return the JinaInfo of the Runtime

        :param empty: The service expects an empty protobuf message
        :param context: grpc context
        :returns: the response request
        """
        info_proto = jina_pb2.JinaInfoProto()
        version, env_info = get_full_version()
        for k, v in version.items():
            info_proto.jina[k] = str(v)
        for k, v in env_info.items():
            info_proto.envs[k] = str(v)
        return info_proto