Source code for jina.serve.runtimes.gateway.websocket.app

import argparse
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, Union

from jina.clients.request import request_generator
from jina.enums import DataInputType, WebsocketSubProtocols
from jina.excepts import InternalNetworkError
from jina.helper import get_full_version
from jina.importer import ImportExtensions
from jina.logging.logger import JinaLogger
from jina.types.request.data import DataRequest
from jina.types.request.status import StatusMessage

if TYPE_CHECKING:
    from prometheus_client import CollectorRegistry


def _fits_ws_close_msg(msg: str):
    # Websocket close messages ('reasons') can't exceed 123 bytes:
    # https://developer.mozilla.org/en-US/docs/Web/API/WebSocket/close
    ws_closing_msg_max_len = 123
    return len(msg.encode('utf-8')) <= ws_closing_msg_max_len


[docs]def get_fastapi_app( args: 'argparse.Namespace', logger: 'JinaLogger', timeout_send: Optional[float] = None, metrics_registry: Optional['CollectorRegistry'] = None, ): """ Get the app from FastAPI as the Websocket interface. :param args: passed arguments. :param logger: Jina logger. :param timeout_send: Timeout to be used when sending to Executors :param metrics_registry: optional metrics registry for prometheus used if we need to expose metrics from the executor or from the data request handler :return: fastapi app """ from jina.serve.runtimes.gateway.http.models import JinaEndpointRequestModel with ImportExtensions(required=True): from fastapi import FastAPI, Response, WebSocket, WebSocketDisconnect, status class ConnectionManager: def __init__(self): self.active_connections: List[WebSocket] = [] self.protocol_dict: Dict[str, WebsocketSubProtocols] = {} def get_client(self, websocket: WebSocket) -> str: return f'{websocket.client.host}:{websocket.client.port}' def get_subprotocol(self, headers: Dict): try: if 'sec-websocket-protocol' in headers: subprotocol = WebsocketSubProtocols( headers['sec-websocket-protocol'] ) elif b'sec-websocket-protocol' in headers: subprotocol = WebsocketSubProtocols( headers[b'sec-websocket-protocol'].decode() ) else: subprotocol = WebsocketSubProtocols.JSON logger.debug( f'no protocol headers passed. Choosing default subprotocol {WebsocketSubProtocols.JSON}' ) except Exception as e: logger.debug( f'got an exception while setting user\'s subprotocol, defaulting to JSON {e}' ) subprotocol = WebsocketSubProtocols.JSON return subprotocol async def connect(self, websocket: WebSocket): await websocket.accept() subprotocol = self.get_subprotocol(dict(websocket.scope['headers'])) logger.info( f'client {websocket.client.host}:{websocket.client.port} connected ' f'with subprotocol {subprotocol}' ) self.active_connections.append(websocket) self.protocol_dict[self.get_client(websocket)] = subprotocol def disconnect(self, websocket: WebSocket): self.protocol_dict.pop(self.get_client(websocket)) self.active_connections.remove(websocket) async def receive(self, websocket: WebSocket) -> Any: subprotocol = self.protocol_dict[self.get_client(websocket)] if subprotocol == WebsocketSubProtocols.JSON: return await websocket.receive_json(mode='text') elif subprotocol == WebsocketSubProtocols.BYTES: return await websocket.receive_bytes() async def iter(self, websocket: WebSocket) -> AsyncIterator[Any]: try: while True: yield await self.receive(websocket) except WebSocketDisconnect: pass async def send( self, websocket: WebSocket, data: Union[DataRequest, StatusMessage] ) -> None: subprotocol = self.protocol_dict[self.get_client(websocket)] if subprotocol == WebsocketSubProtocols.JSON: return await websocket.send_json(data.to_dict(), mode='text') elif subprotocol == WebsocketSubProtocols.BYTES: return await websocket.send_bytes(data.to_bytes()) manager = ConnectionManager() app = FastAPI() from jina.serve.bff import GatewayBFF import json graph_description = json.loads(args.graph_description) graph_conditions = json.loads(args.graph_conditions) deployments_addresses = json.loads(args.deployments_addresses) deployments_disable_reduce = json.loads(args.deployments_disable_reduce) gateway_bff = GatewayBFF(graph_representation=graph_description, executor_addresses=deployments_addresses, graph_conditions=graph_conditions, deployments_disable_reduce=deployments_disable_reduce, timeout_send=timeout_send, retries=args.retries, compression=args.compression, runtime_name=args.name, prefetch=args.prefetch, logger=logger, metrics_registry=metrics_registry) @app.get( path='/', summary='Get the health of Jina service', ) async def _health(): """ Get the health of this Jina service. .. # noqa: DAR201 """ return {} @app.get( path='/status', summary='Get the status of Jina service', ) async def _status(): """ Get the status of this Jina service. This is equivalent to running `jina -vf` from command line. .. # noqa: DAR201 """ version, env_info = get_full_version() for k, v in version.items(): version[k] = str(v) for k, v in env_info.items(): env_info[k] = str(v) return {'jina': version, 'envs': env_info} @app.on_event('shutdown') async def _shutdown(): await gateway_bff.close() @app.websocket('/') async def websocket_endpoint( websocket: WebSocket, response: Response ): # 'response' is a FastAPI response, not a Jina response await manager.connect(websocket) async def req_iter(): async for request in manager.iter(websocket): if isinstance(request, dict): if request == {}: break else: # NOTE: Helps in converting camelCase to snake_case req_generator_input = JinaEndpointRequestModel(**request).dict() req_generator_input['data_type'] = DataInputType.DICT if request['data'] is not None and 'docs' in request['data']: req_generator_input['data'] = req_generator_input['data'][ 'docs' ] # you can't do `yield from` inside an async function for data_request in request_generator(**req_generator_input): yield data_request elif isinstance(request, bytes): if request == bytes(True): break else: yield DataRequest(request) try: async for msg in gateway_bff.stream(request_iterator=req_iter()): await manager.send(websocket, msg) except InternalNetworkError as err: import grpc manager.disconnect(websocket) fallback_msg = ( f'Connection to deployment at {err.dest_addr} timed out. You can adjust `timeout_send` attribute.' if err.code() == grpc.StatusCode.DEADLINE_EXCEEDED else f'Network error while connecting to deployment at {err.dest_addr}. It may be down.' ) msg = ( err.details() if _fits_ws_close_msg( err.details() ) # some messages are too long for ws closing message else fallback_msg ) await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=msg) except WebSocketDisconnect: logger.info('Client successfully disconnected from server') manager.disconnect(websocket) async def _get_singleton_result(request_iterator) -> Dict: """ Streams results from AsyncPrefetchCall as a dict :param request_iterator: request iterator, with length of 1 :return: the first result from the request iterator """ async for k in gateway_bff.stream(request_iterator=request_iterator): request_dict = k.to_dict() return request_dict from docarray import DocumentArray from jina.proto import jina_pb2 from jina.serve.executors import __dry_run_endpoint__ from jina.serve.runtimes.gateway.http.models import PROTO_TO_PYDANTIC_MODELS @app.get( path='/dry_run', summary='Get the readiness of Jina Flow service, sends an empty DocumentArray to the complete Flow to ' 'validate connectivity', response_model=PROTO_TO_PYDANTIC_MODELS.StatusProto, ) async def _dry_run_http(): """ Get the health of the complete Flow service. .. # noqa: DAR201 """ da = DocumentArray() try: _ = await _get_singleton_result( request_generator( exec_endpoint=__dry_run_endpoint__, data=da, data_type=DataInputType.DOCUMENT, ) ) status_message = StatusMessage() status_message.set_code(jina_pb2.StatusProto.SUCCESS) return status_message.to_dict() except Exception as ex: status_message = StatusMessage() status_message.set_exception(ex) return status_message.to_dict(use_integers_for_enums=True) @app.websocket('/dry_run') async def websocket_endpoint( websocket: WebSocket, response: Response ): # 'response' is a FastAPI response, not a Jina response from jina.proto import jina_pb2 from jina.serve.executors import __dry_run_endpoint__ await manager.connect(websocket) da = DocumentArray() try: async for _ in gateway_bff.stream( request_iterator=request_generator( exec_endpoint=__dry_run_endpoint__, data=da, data_type=DataInputType.DOCUMENT, ) ): pass status_message = StatusMessage() status_message.set_code(jina_pb2.StatusProto.SUCCESS) await manager.send(websocket, status_message) except InternalNetworkError as err: manager.disconnect(websocket) msg = ( err.details() if _fits_ws_close_msg(err.details()) # some messages are too long else f'Network error while connecting to deployment at {err.dest_addr}. It may be down.' ) await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=msg) except WebSocketDisconnect: logger.info('Client successfully disconnected from server') manager.disconnect(websocket) except Exception as ex: manager.disconnect(websocket) status_message = StatusMessage() status_message.set_exception(ex) await manager.send(websocket, status_message) return app