# Source code for jina.serve.runtimes.head

import argparse
import asyncio
import contextlib
import json
import os
from abc import ABC
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import grpc
from grpc_health.v1 import health, health_pb2, health_pb2_grpc
from grpc_reflection.v1alpha import reflection

from jina.enums import PollingType
from jina.excepts import InternalNetworkError
from jina.helper import get_full_version
from jina.importer import ImportExtensions
from jina.proto import jina_pb2, jina_pb2_grpc
from jina.serve.networking import GrpcConnectionPool
from jina.serve.runtimes.asyncio import AsyncNewLoopRuntime
from jina.serve.runtimes.helper import _get_grpc_server_options
from jina.serve.runtimes.request_handlers.data_request_handler import DataRequestHandler
from jina.types.request.data import DataRequest, Response

class HeadRuntime(AsyncNewLoopRuntime, ABC):
    """
    Runtime used in head pods. It responds to Gateway requests and forwards them to
    uses_before/uses_after executors and to its shard workers.
    """

    # Polling used for endpoints with no explicit per-endpoint override.
    DEFAULT_POLLING = PollingType.ANY

    def __init__(
        self,
        args: argparse.Namespace,
        **kwargs,
    ):
        """Initialize grpc server for the head runtime.

        :param args: args from CLI
        :param kwargs: keyword args
        """
        self._health_servicer = health.HealthServicer(experimental_non_blocking=True)
        super().__init__(args, **kwargs)
        if args.name is None:
            args.name = ''
        self.name = args.name
        self._deployment_name = os.getenv('JINA_DEPLOYMENT_NAME', 'worker')
        self.connection_pool = GrpcConnectionPool(
            runtime_name=self.name,
            logger=self.logger,
            compression=args.compression,
            metrics_registry=self.metrics_registry,
        )
        self._retries = self.args.retries

        if self.metrics_registry:
            with ImportExtensions(
                required=True,
                help_text='You need to install the `prometheus_client` to use the monitoring functionality of jina',
            ):
                from prometheus_client import Summary

            # Context manager that records the time spent inside each request.
            self._summary = (
                Summary(
                    'receiving_request_seconds',
                    'Time spent processing request',
                    registry=self.metrics_registry,
                    namespace='jina',
                    labelnames=('runtime_name',),
                )
                .labels(self.args.name)
                .time()
            )
        else:
            # No metrics: use a no-op context manager so `with self._summary:` still works.
            self._summary = contextlib.nullcontext()

        polling = getattr(args, 'polling', self.DEFAULT_POLLING.name)
        try:
            # try loading the polling args as json
            endpoint_polling = json.loads(polling)
            # '*' is used as a wildcard and will match all endpoints, except /index, /search
            # and explicitly defined endpoints
            default_polling = (
                PollingType.from_string(endpoint_polling['*'])
                if '*' in endpoint_polling
                else self.DEFAULT_POLLING
            )
            self._polling = self._default_polling_dict(default_polling)
            for endpoint in endpoint_polling:
                self._polling[endpoint] = PollingType(
                    endpoint_polling[endpoint]
                    if type(endpoint_polling[endpoint]) == int
                    else PollingType.from_string(endpoint_polling[endpoint])
                )
        except (ValueError, TypeError):
            # polling args is not a valid json, try interpreting as a polling enum type
            default_polling = (
                polling
                if type(polling) == PollingType
                else PollingType.from_string(polling)
            )
            self._polling = self._default_polling_dict(default_polling)

        if hasattr(args, 'connection_list') and args.connection_list:
            connection_list = json.loads(args.connection_list)
            for shard_id in connection_list:
                shard_connections = connection_list[shard_id]
                # connection_list can map a shard either to a single address (str)
                # or to a list of replica addresses
                if isinstance(shard_connections, str):
                    self.connection_pool.add_connection(
                        deployment=self._deployment_name,
                        address=shard_connections,
                        shard_id=int(shard_id),
                    )
                else:
                    for connection in shard_connections:
                        self.connection_pool.add_connection(
                            deployment=self._deployment_name,
                            address=connection,
                            shard_id=int(shard_id),
                        )

        self.uses_before_address = args.uses_before_address
        self.timeout_send = args.timeout_send
        if self.timeout_send:
            self.timeout_send /= 1e3  # convert ms to seconds

        if self.uses_before_address:
            self.connection_pool.add_connection(
                deployment='uses_before', address=self.uses_before_address
            )
        self.uses_after_address = args.uses_after_address
        if self.uses_after_address:
            self.connection_pool.add_connection(
                deployment='uses_after', address=self.uses_after_address
            )
        self._reduce = not args.disable_reduce

    def _default_polling_dict(self, default_polling):
        # /search always polls ALL shards, /index always ANY; everything else
        # falls back to the given default
        return defaultdict(
            lambda: default_polling,
            {'/search': PollingType.ALL, '/index': PollingType.ANY},
        )

    async def async_setup(self):
        """Wait for the GRPC server to start"""
        self._grpc_server = grpc.aio.server(
            options=_get_grpc_server_options(self.args.grpc_server_options)
        )

        jina_pb2_grpc.add_JinaSingleDataRequestRPCServicer_to_server(
            self, self._grpc_server
        )
        jina_pb2_grpc.add_JinaDataRequestRPCServicer_to_server(self, self._grpc_server)
        jina_pb2_grpc.add_JinaDiscoverEndpointsRPCServicer_to_server(
            self, self._grpc_server
        )
        jina_pb2_grpc.add_JinaInfoRPCServicer_to_server(self, self._grpc_server)
        service_names = (
            jina_pb2.DESCRIPTOR.services_by_name['JinaSingleDataRequestRPC'].full_name,
            jina_pb2.DESCRIPTOR.services_by_name['JinaDataRequestRPC'].full_name,
            jina_pb2.DESCRIPTOR.services_by_name['JinaDiscoverEndpointsRPC'].full_name,
            jina_pb2.DESCRIPTOR.services_by_name['JinaInfoRPC'].full_name,
            reflection.SERVICE_NAME,
        )
        # Mark all services as healthy.
        health_pb2_grpc.add_HealthServicer_to_server(
            self._health_servicer, self._grpc_server
        )
        for service in service_names:
            self._health_servicer.set(service, health_pb2.HealthCheckResponse.SERVING)
        reflection.enable_server_reflection(service_names, self._grpc_server)
        bind_addr = f'{self.args.host}:{self.args.port}'
        self._grpc_server.add_insecure_port(bind_addr)
        self.logger.debug(f'start listening on {bind_addr}')
        await self._grpc_server.start()

    async def async_run_forever(self):
        """Block until the GRPC server is terminated"""
        await self._grpc_server.wait_for_termination()

    async def async_cancel(self):
        """Stop the GRPC server"""
        self.logger.debug('cancel HeadRuntime')
        await self._grpc_server.stop(0)

    async def async_teardown(self):
        """Close the connection pool"""
        self._health_servicer.enter_graceful_shutdown()
        await self.async_cancel()
        await self.connection_pool.close()

    async def process_single_data(self, request: DataRequest, context) -> DataRequest:
        """
        Process the received requests and return the result as a new request

        :param request: the data request to process
        :param context: grpc context
        :returns: the response request
        """
        return await self.process_data([request], context)

    def _handle_internalnetworkerror(self, err, context, response):
        # Translate an InternalNetworkError into a gRPC status on the context so the
        # client sees a meaningful error instead of a broken stream.
        err_code = err.code()
        if err_code == grpc.StatusCode.UNAVAILABLE:
            context.set_details(
                f'|Head: Failed to connect to worker (Executor) pod at address {err.dest_addr}. It may be down.'
            )
        elif err_code == grpc.StatusCode.DEADLINE_EXCEEDED:
            context.set_details(
                f'|Head: Connection to worker (Executor) pod at address {err.dest_addr} could be established, but timed out.'
            )
        context.set_code(err.code())
        self.logger.error(f'Error while getting responses from Pods: {err.details()}')
        if err.request_id:
            response.header.request_id = err.request_id
        return response

    async def process_data(self, requests: List[DataRequest], context) -> DataRequest:
        """
        Process the received data request and return the result as a new request

        :param requests: the data requests to process
        :param context: grpc context
        :returns: the response request
        """
        try:
            with self._summary:
                endpoint = dict(context.invocation_metadata()).get('endpoint')
                response, metadata = await self._handle_data_request(requests, endpoint)
                context.set_trailing_metadata(metadata.items())
                return response
        except InternalNetworkError as err:
            # can't connect, Flow broken, interrupt the streaming through gRPC error mechanism
            return self._handle_internalnetworkerror(
                err=err, context=context, response=Response()
            )
        except (
            RuntimeError,
            Exception,
        ) as ex:  # some other error, keep streaming going just add error info
            # NOTE: parentheses matter here — without them the whole message would be
            # replaced by '' under --quiet-error instead of only the hint
            self.logger.error(
                f'{ex!r}'
                + (
                    f'\n add "--quiet-error" to suppress the exception details'
                    if not self.args.quiet_error
                    else ''
                ),
                exc_info=not self.args.quiet_error,
            )
            requests[0].add_exception(ex, executor=None)
            context.set_trailing_metadata((('is-error', 'true'),))
            return requests[0]

    async def endpoint_discovery(self, empty, context) -> jina_pb2.EndpointsProto:
        """
        Uses the connection pool to send a discover endpoint call to the workers

        :param empty: The service expects an empty protobuf message
        :param context: grpc context
        :returns: the response request
        """
        response = jina_pb2.EndpointsProto()
        try:
            if self.uses_before_address:
                (
                    uses_before_response,
                    _,
                ) = await self.connection_pool.send_discover_endpoint(
                    deployment='uses_before', head=False
                )
                response.endpoints.extend(uses_before_response.endpoints)
            if self.uses_after_address:
                (
                    uses_after_response,
                    _,
                ) = await self.connection_pool.send_discover_endpoint(
                    deployment='uses_after', head=False
                )
                response.endpoints.extend(uses_after_response.endpoints)

            worker_response, _ = await self.connection_pool.send_discover_endpoint(
                deployment=self._deployment_name, head=False
            )
            response.endpoints.extend(worker_response.endpoints)
        except InternalNetworkError as err:
            # can't connect, Flow broken, interrupt the streaming through gRPC error mechanism
            return self._handle_internalnetworkerror(
                err=err, context=context, response=response
            )

        return response

    async def _handle_data_request(
        self, requests: List[DataRequest], endpoint: Optional[str]
    ) -> Tuple[DataRequest, Dict]:
        # Core request pipeline: uses_before -> workers (polled) -> uses_after/reduce.
        self.logger.debug(f'recv {len(requests)} DataRequest(s)')

        DataRequestHandler.merge_routes(requests)

        uses_before_metadata = None
        if self.uses_before_address:
            (
                response,
                uses_before_metadata,
            ) = await self.connection_pool.send_requests_once(
                requests,
                deployment='uses_before',
                timeout=self.timeout_send,
                retries=self._retries,
            )
            requests = [response]

        worker_send_tasks = self.connection_pool.send_requests(
            requests=requests,
            deployment=self._deployment_name,
            polling_type=self._polling[endpoint],
            timeout=self.timeout_send,
            retries=self._retries,
        )

        worker_results = await asyncio.gather(*worker_send_tasks)

        if len(worker_results) == 0:
            raise RuntimeError(
                f'Head {self.name} did not receive a response when sending message to worker pods'
            )

        worker_results, metadata = zip(*worker_results)

        response_request = worker_results[0]
        uses_after_metadata = None
        if self.uses_after_address:
            (
                response_request,
                uses_after_metadata,
            ) = await self.connection_pool.send_requests_once(
                worker_results,
                deployment='uses_after',
                timeout=self.timeout_send,
                retries=self._retries,
            )
        elif len(worker_results) > 1 and self._reduce:
            DataRequestHandler.reduce_requests(worker_results)
        elif len(worker_results) > 1 and not self._reduce:
            # worker returned multiple responses, but the head is configured to skip reduction
            # just concatenate the docs in this case
            response_request.data.docs = DataRequestHandler.get_docs_from_request(
                requests, field='docs'
            )

        merged_metadata = self._merge_metadata(
            metadata, uses_after_metadata, uses_before_metadata
        )

        return response_request, merged_metadata

    def _merge_metadata(self, metadata, uses_after_metadata, uses_before_metadata):
        # Later entries overwrite earlier ones: uses_before < workers < uses_after
        merged_metadata = {}
        if uses_before_metadata:
            for key, value in uses_before_metadata:
                merged_metadata[key] = value
        for meta in metadata:
            for key, value in meta:
                merged_metadata[key] = value
        if uses_after_metadata:
            for key, value in uses_after_metadata:
                merged_metadata[key] = value
        return merged_metadata

    async def _status(self, empty, context) -> jina_pb2.JinaInfoProto:
        """
        Process the call requested and return the JinaInfo of the Runtime

        :param empty: The service expects an empty protobuf message
        :param context: grpc context
        :returns: the response request
        """
        infoProto = jina_pb2.JinaInfoProto()
        version, env_info = get_full_version()
        for k, v in version.items():
            infoProto.jina[k] = str(v)
        for k, v in env_info.items():
            infoProto.envs[k] = str(v)
        return infoProto