# Source code for jina.serve.runtimes.gateway.websocket.app

import argparse
from typing import List, TYPE_CHECKING

from jina.importer import ImportExtensions
from jina.logging.logger import JinaLogger
from jina.types.request.data import DataRequest

if TYPE_CHECKING:
    from jina.serve.runtimes.gateway.graph.topology_graph import TopologyGraph
    from jina.serve.networking import GrpcConnectionPool


def get_fastapi_app(
    args: 'argparse.Namespace',
    topology_graph: 'TopologyGraph',
    connection_pool: 'GrpcConnectionPool',
    logger: 'JinaLogger',
):
    """
    Get the app from FastAPI as the Websocket interface.

    The returned app exposes a single websocket route at ``/``. Every binary
    frame received from the client is wrapped in a :class:`DataRequest` and fed
    to a ``RequestStreamer``; every message produced by the streamer is
    serialized with ``bytes(...)`` and sent back over the same websocket.

    :param args: passed arguments.
    :param topology_graph: topology graph that manages the logic of sending to the proper executors.
    :param connection_pool: Connection Pool to handle multiple replicas and sending to different of them
    :param logger: Jina logger.
    :return: fastapi app
    """

    with ImportExtensions(required=True):
        from fastapi import FastAPI, WebSocket, WebSocketDisconnect

    class ConnectionManager:
        """Book-keeping of the websockets that have been accepted by this app."""

        def __init__(self):
            self.active_connections: List[WebSocket] = []

        async def connect(self, websocket: WebSocket):
            """Accept the websocket handshake and start tracking the connection.

            :param websocket: the incoming websocket to accept
            """
            await websocket.accept()
            logger.debug(
                f'client {websocket.client.host}:{websocket.client.port} connected'
            )
            self.active_connections.append(websocket)

        def disconnect(self, websocket: WebSocket):
            """Stop tracking a websocket.

            :param websocket: a websocket previously passed to :meth:`connect`
            """
            self.active_connections.remove(websocket)

    manager = ConnectionManager()

    app = FastAPI()

    from jina.serve.stream import RequestStreamer
    from jina.serve.runtimes.gateway.request_handling import (
        handle_request,
        handle_result,
    )

    streamer = RequestStreamer(
        args=args,
        request_handler=handle_request(
            graph=topology_graph, connection_pool=connection_pool
        ),
        result_handler=handle_result,
    )
    # Expose `stream` under the name `Call` as well.
    # NOTE(review): presumably for parity with callers that expect a `Call`
    # attribute on the streamer -- confirm against the other gateways.
    streamer.Call = streamer.stream

    @app.on_event('shutdown')
    async def _shutdown():
        await connection_pool.close()

    @app.websocket('/')
    async def websocket_endpoint(websocket: WebSocket):
        await manager.connect(websocket)

        async def req_iter():
            async for request_bytes in websocket.iter_bytes():
                # `bytes(True)` is `bytes(1)`, i.e. the single byte b'\x00';
                # receiving exactly that frame ends the request stream.
                if request_bytes == bytes(True):
                    break
                yield DataRequest(request_bytes)

        try:
            async for msg in streamer.stream(request_iterator=req_iter()):
                await websocket.send_bytes(bytes(msg))
        except WebSocketDisconnect:
            logger.debug('Client successfully disconnected from server')
        finally:
            # Always drop the connection from the manager. Doing this only on
            # `WebSocketDisconnect` would leave the websocket in
            # `active_connections` forever when the stream ends normally
            # (end-of-stream frame) or when another exception propagates.
            manager.disconnect(websocket)

    return app