# Source code for jina.serve.runtimes.gateway.request_handling

import asyncio
import copy
import time
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple

import grpc.aio
from docarray import DocumentArray

from jina.excepts import InternalNetworkError
from jina.importer import ImportExtensions
from jina.proto import jina_pb2
from jina.serve.networking import GrpcConnectionPool
from jina.serve.runtimes.gateway.graph.topology_graph import TopologyGraph
from jina.serve.runtimes.helper import _is_param_for_specific_executor
from jina.serve.runtimes.request_handlers.data_request_handler import DataRequestHandler

if TYPE_CHECKING:  # pragma: no cover
    from asyncio import Future

    from prometheus_client import CollectorRegistry

    from jina.types.request import Request

class MonitoringRequestMixin:
    """
    Mixin for the request handling monitoring

    :param metrics_registry: optional metrics registry for prometheus used if we need to expose metrics from the executor or from the data request handler
    :param runtime_name: optional runtime_name that will be registered during monitoring
    """

    def __init__(
        self,
        metrics_registry: Optional['CollectorRegistry'] = None,
        runtime_name: Optional[str] = None,
    ):
        # request_id -> time.time() at which the request started; only tracked
        # when a metrics registry is given (None otherwise)
        self._request_init_time = {} if metrics_registry else None

        if metrics_registry:
            with ImportExtensions(
                required=True,
                help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina',
            ):
                from prometheus_client import Counter, Gauge, Summary

            self._receiving_request_metrics = Summary(
                'receiving_request_seconds',
                'Time spent processing successful request',
                registry=metrics_registry,
                namespace='jina',
                labelnames=('runtime_name',),
            ).labels(runtime_name)

            self._pending_requests_metrics = Gauge(
                'number_of_pending_requests',
                'Number of pending requests',
                registry=metrics_registry,
                namespace='jina',
                labelnames=('runtime_name',),
            ).labels(runtime_name)

            self._failed_requests_metrics = Counter(
                'failed_requests',
                'Number of failed requests',
                registry=metrics_registry,
                namespace='jina',
                labelnames=('runtime_name',),
            ).labels(runtime_name)

            self._successful_requests_metrics = Counter(
                'successful_requests',
                'Number of successful requests',
                registry=metrics_registry,
                namespace='jina',
                labelnames=('runtime_name',),
            ).labels(runtime_name)
        else:
            self._receiving_request_metrics = None
            self._pending_requests_metrics = None
            self._failed_requests_metrics = None
            self._successful_requests_metrics = None

    def _update_start_request_metrics(self, request: 'Request'):
        """
        Record that ``request`` has started: remember its start time and bump the pending gauge.

        :param request: the incoming request; its ``request_id`` keys the start time
        """
        if self._receiving_request_metrics is not None:
            self._request_init_time[request.request_id] = time.time()
        if self._pending_requests_metrics is not None:
            self._pending_requests_metrics.inc()

    def _pop_request_init_time(self, request_id) -> Optional[float]:
        """
        Remove and return the start time recorded for ``request_id``.

        Popping (rather than reading) is what keeps ``_request_init_time`` from
        growing forever; it must happen on both the success and failure paths.

        :param request_id: id of the finished request
        :return: the recorded start time, or None if nothing was recorded
        """
        if self._request_init_time is None:
            return None
        return self._request_init_time.pop(request_id, None)

    def _update_end_successful_requests_metrics(self, result: 'Request'):
        """
        Record the successful completion of a request.

        :param result: the finished request
        """
        init_time = self._pop_request_init_time(result.request_id)
        # the latency summary is only observed for successful requests
        if self._receiving_request_metrics is not None and init_time is not None:
            self._receiving_request_metrics.observe(time.time() - init_time)

        if self._pending_requests_metrics is not None:
            self._pending_requests_metrics.dec()

        if self._successful_requests_metrics is not None:
            self._successful_requests_metrics.inc()

    def _update_end_failed_requests_metrics(self, result: 'Request'):
        """
        Record the failed completion of a request.

        :param result: the failed request
        """
        # drop the start time too, otherwise failed requests leak memory
        self._pop_request_init_time(result.request_id)

        if self._pending_requests_metrics is not None:
            self._pending_requests_metrics.dec()

        if self._failed_requests_metrics is not None:
            self._failed_requests_metrics.inc()

    def _update_end_request_metrics(self, result: 'Request'):
        """
        Dispatch to the success or failure metrics update based on the result status.

        :param result: the finished request
        """
        if result.status.code != jina_pb2.StatusProto.ERROR:
            self._update_end_successful_requests_metrics(result)
        else:
            self._update_end_failed_requests_metrics(result)
class RequestHandler(MonitoringRequestMixin):
    """
    Class that handles the requests arriving to the gateway and the result extracted from the requests future.

    :param metrics_registry: optional metrics registry for prometheus used if we need to expose metrics from the executor or from the data request handler
    :param runtime_name: optional runtime_name that will be registered during monitoring
    """

    def __init__(
        self,
        metrics_registry: Optional['CollectorRegistry'] = None,
        runtime_name: Optional[str] = None,
    ):
        super().__init__(metrics_registry, runtime_name)
        # node name -> endpoints reported by that node; None until gathered
        self._executor_endpoint_mapping = None
        # True while a gather_endpoints task is in flight
        self._gathering_endpoints = False
        # strong reference to the background gather task so it cannot be
        # garbage-collected before it finishes (asyncio keeps only weak refs)
        self._gather_endpoints_task = None

    def handle_request(
        self, graph: 'TopologyGraph', connection_pool: 'GrpcConnectionPool'
    ) -> Callable[['Request'], 'Tuple[Future, Optional[Future]]']:
        """
        Function that handles the requests arriving to the gateway. This will be passed to the streamer.

        :param graph: The TopologyGraph of the Flow.
        :param connection_pool: The connection pool to be used to send messages to specific nodes of the graph
        :return: Return a Function that given a Request will return a Future from where to extract the response
        """

        async def gather_endpoints(request_graph):
            """Query every node for its endpoints and store them in ``_executor_endpoint_mapping``."""
            nodes = request_graph.all_nodes
            try:
                tasks_to_get_endpoints = [
                    node.get_endpoints(connection_pool) for node in nodes
                ]
                endpoints = await asyncio.gather(*tasks_to_get_endpoints)
            except InternalNetworkError as err:
                err_code = err.code()
                if err_code == grpc.StatusCode.UNAVAILABLE:
                    err._details = (
                        err.details()
                        + f' |Gateway: Communication error with deployment at address(es) {err.dest_addr}. Head or worker(s) may be down.'
                    )
                    raise err
                else:
                    raise
            finally:
                # Clear the in-flight flag whatever the outcome, so a failed
                # attempt can be retried by a later request instead of blocking
                # endpoint discovery forever. No await happens between this and
                # the mapping assignment below, so no coroutine can observe a
                # half-updated state.
                self._gathering_endpoints = False

            self._executor_endpoint_mapping = {}
            for node, (endp, _) in zip(nodes, endpoints):
                self._executor_endpoint_mapping[node.name] = endp.endpoints

        def _handle_request(request: 'Request') -> 'Tuple[Future, Optional[Future]]':
            self._update_start_request_metrics(request)

            # important that the gateway needs to have an instance of the graph per request
            request_graph = copy.deepcopy(graph)

            if graph.has_filter_conditions:
                # used to maintain order of docs that are filtered by executors
                request_doc_ids = request.data.docs[:, 'id']
            responding_tasks = []
            floating_tasks = []
            endpoint = request.header.exec_endpoint
            r = request.routes.add()
            r.executor = 'gateway'
            r.start_time.GetCurrentTime()
            # If the request is targeting a specific deployment, we can send directly to the deployment instead of
            # querying the graph
            num_outgoing_nodes = len(request_graph.origin_nodes)
            request_input_parameters = request.parameters
            has_specific_params = any(
                _is_param_for_specific_executor(key) for key in request_input_parameters
            )

            target_executor = request.header.target_executor
            # reset it in case we send to an external gateway
            request.header.target_executor = ''

            for origin_node in request_graph.origin_nodes:
                leaf_tasks = origin_node.get_leaf_tasks(
                    connection_pool=connection_pool,
                    request_to_send=request,
                    previous_task=None,
                    endpoint=endpoint,
                    executor_endpoint_mapping=self._executor_endpoint_mapping,
                    target_executor_pattern=target_executor or None,
                    request_input_parameters=request_input_parameters,
                    request_input_has_specific_params=has_specific_params,
                    copy_request_at_send=num_outgoing_nodes > 1 and has_specific_params,
                )
                # Every origin node returns a set of tasks that are the ones corresponding to the leafs of each of their
                # subtrees that unwrap all the previous tasks. It starts like a chain of waiting for tasks from previous
                # nodes
                responding_tasks.extend([task for ret, task in leaf_tasks if ret])
                floating_tasks.extend([task for ret, task in leaf_tasks if not ret])

            def _sort_response_docs(response):
                # sort response docs according to their order in the initial request.
                # A dict gives O(1) lookups instead of list `in` + `index` per doc;
                # setdefault keeps the first occurrence, matching list.index.
                position_by_id = {}
                for position, doc_id in enumerate(request_doc_ids):
                    position_by_id.setdefault(doc_id, position)
                unknown_position = len(request_doc_ids)  # put new/unknown docs at the end

                sorted_docs = sorted(
                    response.data.docs,
                    key=lambda doc: position_by_id.get(doc.id, unknown_position),
                )
                response.data.docs = DocumentArray(sorted_docs)

            async def _process_results_at_end_gateway(
                tasks: List[asyncio.Task], request_graph: 'TopologyGraph'
            ) -> 'Request':
                try:
                    if (
                        self._executor_endpoint_mapping is None
                        and not self._gathering_endpoints
                    ):
                        self._gathering_endpoints = True
                        # keep a reference: a bare create_task result may be GC'd mid-run
                        self._gather_endpoints_task = asyncio.create_task(
                            gather_endpoints(request_graph)
                        )

                    partial_responses = await asyncio.gather(*tasks)
                except Exception:
                    # update here failed request
                    self._update_end_failed_requests_metrics(request)
                    raise
                partial_responses, metadatas = zip(*partial_responses)
                filtered_partial_responses = list(
                    filter(lambda x: x is not None, partial_responses)
                )

                response = filtered_partial_responses[0]
                request_graph.add_routes(response)

                if graph.has_filter_conditions:
                    _sort_response_docs(response)

                collect_results = request_graph.collect_all_results()
                resp_params = response.parameters
                if len(collect_results) > 0:
                    resp_params[DataRequestHandler._KEY_RESULT] = collect_results
                    response.parameters = resp_params
                return response

            # In case of empty topologies
            if not responding_tasks:
                r.end_time.GetCurrentTime()
                future = asyncio.Future()
                future.set_result((request, {}))
                responding_tasks.append(future)

            return (
                asyncio.ensure_future(
                    _process_results_at_end_gateway(responding_tasks, request_graph)
                ),
                asyncio.ensure_future(asyncio.gather(*floating_tasks))
                if len(floating_tasks) > 0
                else None,
            )

        return _handle_request

    def handle_result(self) -> Callable[['Request'], 'Request']:
        """
        Function that handles the result when extracted from the request future

        :return: Return a Function that returns a request to be returned to the client
        """

        def _handle_result(result: 'Request'):
            """
            Function that handles the result when extracted from the request future

            :param result: The result returned to the gateway. It extracts the request to be returned to the client
            :return: Returns a request to be returned to the client
            """
            # close the gateway's own route entry that _handle_request opened
            for route in result.routes:
                if route.executor == 'gateway':
                    route.end_time.GetCurrentTime()

            self._update_end_request_metrics(result)

            return result

        return _handle_result