[docs]classHeadRuntime(AsyncNewLoopRuntime,ABC):""" Runtime is used in head pods. It responds to Gateway requests and sends to uses_before/uses_after and its workers """DEFAULT_POLLING=PollingType.ANYdef__init__(self,args:argparse.Namespace,cancel_event:Optional[Union['asyncio.Event','multiprocessing.Event','threading.Event']]=None,**kwargs,):"""Initialize grpc server for the head runtime. :param args: args from CLI :param cancel_event: the cancel event used to wait for canceling :param kwargs: keyword args """super().__init__(args,cancel_event,**kwargs)ifargs.nameisNone:args.name=''self.name=args.nameself._deployment_name=os.getenv('JINA_DEPLOYMENT_NAME','worker')self.connection_pool=create_connection_pool(logger=self.logger,k8s_connection_pool=args.k8s_connection_pool,k8s_namespace=args.k8s_namespace,)polling=getattr(args,'polling',self.DEFAULT_POLLING.name)try:# try loading the polling args as jsonendpoint_polling=json.loads(polling)# '*' is used a wildcard and will match all endpoints, except /index, /search and explicitly defined endpoinsdefault_polling=(PollingType.from_string(endpoint_polling['*'])if'*'inendpoint_pollingelseself.DEFAULT_POLLING)self._polling=self._default_polling_dict(default_polling)forendpointinendpoint_polling:self._polling[endpoint]=PollingType(endpoint_polling[endpoint]iftype(endpoint_polling[endpoint])==intelsePollingType.from_string(endpoint_polling[endpoint]))except(ValueError,TypeError):# polling args is not a valid json, try interpreting as a polling enum typedefault_polling=(pollingiftype(polling)==PollingTypeelsePollingType.from_string(polling))self._polling=self._default_polling_dict(default_polling)# In K8s the ConnectionPool needs the information about the Jina Deployment its running in# This is stored in the environment variable JINA_DEPLOYMENT_NAME in all Jina K8s default templatesif(type(self.connection_pool)==K8sGrpcConnectionPooland'JINA_DEPLOYMENT_NAME'notinos.environ):raiseValueError('K8s deployments need to specify the environment variable "JINA_DEPLOYMENT_NAME"')ifhasattr(args,'connection_list')andargs.connection_list:connection_list=json.loads(args.connection_list)forshard_idinconnection_list:shard_connections=connection_list[shard_id]ifisinstance(shard_connections,str):self.connection_pool.add_connection(deployment=self._deployment_name,address=shard_connections,shard_id=int(shard_id),)else:forconnectioninshard_connections:self.connection_pool.add_connection(deployment=self._deployment_name,address=connection,shard_id=int(shard_id),)self.uses_before_address=args.uses_before_addressifself.uses_before_address:self.connection_pool.add_connection(deployment='uses_before',address=self.uses_before_address)self.uses_after_address=args.uses_after_addressifself.uses_after_address:self.connection_pool.add_connection(deployment='uses_after',address=self.uses_after_address)self._reduce=notargs.disable_reducedef_default_polling_dict(self,default_polling):returndefaultdict(lambda:default_polling,{'/search':PollingType.ALL,'/index':PollingType.ANY},)
[docs]asyncdefasync_setup(self):""" Wait for the GRPC server to start """self._grpc_server=grpc.aio.server(options=[('grpc.max_send_message_length',-1),('grpc.max_receive_message_length',-1),])jina_pb2_grpc.add_JinaSingleDataRequestRPCServicer_to_server(self,self._grpc_server)jina_pb2_grpc.add_JinaDataRequestRPCServicer_to_server(self,self._grpc_server)jina_pb2_grpc.add_JinaControlRequestRPCServicer_to_server(self,self._grpc_server)bind_addr=f'0.0.0.0:{self.args.port}'self._grpc_server.add_insecure_port(bind_addr)self.logger.debug(f'Start listening on {bind_addr}')awaitself._grpc_server.start()
[docs]asyncdefasync_run_forever(self):"""Block until the GRPC server is terminated """self.connection_pool.start()awaitself._grpc_server.wait_for_termination()
[docs]asyncdefasync_cancel(self):"""Stop the GRPC server"""self.logger.debug('Cancel HeadRuntime')awaitself._grpc_server.stop(0)
[docs]asyncdefasync_teardown(self):"""Close the connection pool"""awaitself.async_cancel()awaitself.connection_pool.close()
[docs]asyncdefprocess_single_data(self,request:DataRequest,context)->DataRequest:""" Process the received requests and return the result as a new request :param request: the data request to process :param context: grpc context :returns: the response request """returnawaitself.process_data([request],context)
[docs]asyncdefprocess_data(self,requests:List[DataRequest],context)->DataRequest:""" Process the received data request and return the result as a new request :param requests: the data requests to process :param context: grpc context :returns: the response request """try:endpoint=dict(context.invocation_metadata()).get('endpoint')response,metadata=awaitself._handle_data_request(requests,endpoint)context.set_trailing_metadata(metadata.items())returnresponseexcept(RuntimeError,Exception)asex:self.logger.error(f'{ex!r}'+f'\n add "--quiet-error" to suppress the exception details'ifnotself.args.quiet_errorelse'',exc_info=notself.args.quiet_error,)raise
[docs]asyncdefprocess_control(self,request:ControlRequest,*args)->ControlRequest:""" Process the received control request and return the input request :param request: the data request to process :param args: additional arguments in the grpc call, ignored :returns: the input request """try:ifself.logger.debug_enabled:self._log_control_request(request)ifrequest.command=='ACTIVATE':forrelatedEntityinrequest.relatedEntities:connection_string=f'{relatedEntity.address}:{relatedEntity.port}'self.connection_pool.add_connection(deployment=self._deployment_name,address=connection_string,shard_id=relatedEntity.shard_idifrelatedEntity.HasField('shard_id')elseNone,)elifrequest.command=='DEACTIVATE':forrelatedEntityinrequest.relatedEntities:connection_string=f'{relatedEntity.address}:{relatedEntity.port}'awaitself.connection_pool.remove_connection(deployment=self._deployment_name,address=connection_string,shard_id=relatedEntity.shard_id,)returnrequestexcept(RuntimeError,Exception)asex:self.logger.error(f'{ex!r}'+f'\n add "--quiet-error" to suppress the exception details'ifnotself.args.quiet_errorelse'',exc_info=notself.args.quiet_error,)raise
asyncdef_handle_data_request(self,requests:List[DataRequest],endpoint:Optional[str])->Tuple[DataRequest,Dict]:self.logger.debug(f'recv {len(requests)} DataRequest(s)')DataRequestHandler.merge_routes(requests)uses_before_metadata=Noneifself.uses_before_address:(response,uses_before_metadata,)=awaitself.connection_pool.send_requests_once(requests,deployment='uses_before')requests=[response]eliflen(requests)>1andself._reduce:requests=[DataRequestHandler.reduce_requests(requests)]worker_send_tasks=self.connection_pool.send_requests(requests=requests,deployment=self._deployment_name,polling_type=self._polling[endpoint],)worker_results=awaitasyncio.gather(*worker_send_tasks)iflen(worker_results)==0:raiseRuntimeError(f'Head {self.name} did not receive a response when sending message to worker pods')worker_results,metadata=zip(*worker_results)response_request=worker_results[0]uses_after_metadata=Noneifself.uses_after_address:(response_request,uses_after_metadata,)=awaitself.connection_pool.send_requests_once(worker_results,deployment='uses_after')eliflen(worker_results)>1andself._reduce:DataRequestHandler.reduce_requests(worker_results)eliflen(worker_results)>1andnotself._reduce:# worker returned multiple responsed, but the head is configured to skip reduction# just concatenate the docs in this caseresponse_request.data.docs=DataRequestHandler.get_docs_from_request(requests,field='docs')merged_metadata=self._merge_metadata(metadata,uses_after_metadata,uses_before_metadata)returnresponse_request,merged_metadatadef_merge_metadata(self,metadata,uses_after_metadata,uses_before_metadata):merged_metadata={}ifuses_before_metadata:forkey,valueinuses_before_metadata:merged_metadata[key]=valueformetainmetadata:forkey,valueinmeta:merged_metadata[key]=valueifuses_after_metadata:forkey,valueinuses_after_metadata:merged_metadata[key]=valuereturnmerged_metadata