def get_fastapi_app(
    args: 'argparse.Namespace',
    topology_graph: 'TopologyGraph',
    connection_pool: 'GrpcConnectionPool',
    logger: 'JinaLogger',
    metrics_registry: Optional['CollectorRegistry'] = None,
):
    """
    Get the app from FastAPI as the Websocket interface.

    :param args: passed arguments.
    :param topology_graph: topology graph that manages the logic of sending to the proper executors.
    :param connection_pool: connection pool that handles multiple replicas and routes requests between them.
    :param logger: Jina logger.
    :param metrics_registry: optional metrics registry for prometheus, used if we need to expose metrics
        from the executor or from the data request handler.
    :return: fastapi app
    """
    from jina.serve.runtimes.gateway.http.models import JinaEndpointRequestModel

    with ImportExtensions(required=True):
        from fastapi import FastAPI, WebSocket, WebSocketDisconnect

    class ConnectionManager:
        """Track open websockets and the subprotocol (JSON or BYTES) each client negotiated."""

        def __init__(self):
            self.active_connections: List[WebSocket] = []
            # keyed by the ``host:port`` string produced by ``get_client``
            self.protocol_dict: Dict[str, WebsocketSubProtocols] = {}

        def get_client(self, websocket: WebSocket) -> str:
            return f'{websocket.client.host}:{websocket.client.port}'

        def get_subprotocol(self, headers: Dict):
            """Pick the subprotocol from the ``sec-websocket-protocol`` header.

            Falls back to JSON when the header is missing or its value cannot be
            converted into a ``WebsocketSubProtocols`` member.
            """
            try:
                # header keys may be str or raw bytes depending on how they were collected
                if 'sec-websocket-protocol' in headers:
                    subprotocol = WebsocketSubProtocols(headers['sec-websocket-protocol'])
                elif b'sec-websocket-protocol' in headers:
                    subprotocol = WebsocketSubProtocols(
                        headers[b'sec-websocket-protocol'].decode()
                    )
                else:
                    subprotocol = WebsocketSubProtocols.JSON
                    logger.debug(
                        f'no protocol headers passed. Choosing default subprotocol {WebsocketSubProtocols.JSON}'
                    )
            except Exception as e:
                # deliberate best-effort: an unknown subprotocol degrades to JSON instead of failing
                logger.debug(
                    f'got an exception while setting user\'s subprotocol, defaulting to JSON {e}'
                )
                subprotocol = WebsocketSubProtocols.JSON
            return subprotocol

        async def connect(self, websocket: WebSocket):
            """Accept the websocket and remember which subprotocol the client uses."""
            await websocket.accept()
            subprotocol = self.get_subprotocol(dict(websocket.scope['headers']))
            logger.info(
                f'client {websocket.client.host}:{websocket.client.port} connected '
                f'with subprotocol {subprotocol}'
            )
            self.active_connections.append(websocket)
            self.protocol_dict[self.get_client(websocket)] = subprotocol

        def disconnect(self, websocket: WebSocket):
            """Forget a websocket; safe to call more than once for the same socket."""
            self.protocol_dict.pop(self.get_client(websocket), None)
            try:
                self.active_connections.remove(websocket)
            except ValueError:
                # already removed
                pass

        async def receive(self, websocket: WebSocket) -> Any:
            """Receive one message decoded according to the client's subprotocol."""
            subprotocol = self.protocol_dict[self.get_client(websocket)]
            if subprotocol == WebsocketSubProtocols.JSON:
                return await websocket.receive_json(mode='text')
            elif subprotocol == WebsocketSubProtocols.BYTES:
                return await websocket.receive_bytes()

        async def iter(self, websocket: WebSocket) -> AsyncIterator[Any]:
            """Yield incoming messages until the client disconnects."""
            try:
                while True:
                    yield await self.receive(websocket)
            except WebSocketDisconnect:
                # a disconnect while receiving simply ends the iteration;
                # cleanup is done by the endpoint's ``finally`` clause
                pass

        async def send(self, websocket: WebSocket, data: DataRequest) -> None:
            """Send a DataRequest encoded according to the client's subprotocol."""
            subprotocol = self.protocol_dict[self.get_client(websocket)]
            if subprotocol == WebsocketSubProtocols.JSON:
                return await websocket.send_json(data.to_dict(), mode='text')
            elif subprotocol == WebsocketSubProtocols.BYTES:
                return await websocket.send_bytes(data.to_bytes())

    manager = ConnectionManager()

    app = FastAPI()

    from jina.serve.runtimes.gateway.request_handling import RequestHandler
    from jina.serve.stream import RequestStreamer

    request_handler = RequestHandler(metrics_registry, args.name)

    streamer = RequestStreamer(
        args=args,
        request_handler=request_handler.handle_request(
            graph=topology_graph, connection_pool=connection_pool
        ),
        result_handler=request_handler.handle_result(),
    )
    streamer.Call = streamer.stream

    @app.on_event('shutdown')
    async def _shutdown():
        await connection_pool.close()

    @app.websocket('/')
    async def websocket_endpoint(websocket: WebSocket):
        """Stream client requests through the topology and send results back.

        A client ends its request stream by sending an empty JSON object ``{}``
        (JSON subprotocol) or ``bytes(True)`` (BYTES subprotocol).
        """
        await manager.connect(websocket)

        async def req_iter():
            async for request in manager.iter(websocket):
                if isinstance(request, dict):
                    if request == {}:
                        break
                    # NOTE: Helps in converting camelCase to snake_case
                    req_generator_input = JinaEndpointRequestModel(**request).dict()
                    req_generator_input['data_type'] = DataInputType.DICT
                    # ``data`` is optional in the payload; a missing key behaves like ``None``
                    raw_data = request.get('data')
                    if raw_data is not None and 'docs' in raw_data:
                        req_generator_input['data'] = req_generator_input['data']['docs']

                    # you can't do `yield from` inside an async function
                    for data_request in request_generator(**req_generator_input):
                        yield data_request
                elif isinstance(request, bytes):
                    if request == bytes(True):
                        break
                    yield DataRequest(request)

        try:
            async for msg in streamer.stream(request_iterator=req_iter()):
                await manager.send(websocket, msg)
        except WebSocketDisconnect:
            logger.info('Client successfully disconnected from server')
        finally:
            # always release the bookkeeping: the stream can also end via the
            # end-of-stream marker, a disconnect swallowed by ``manager.iter``,
            # or an unexpected error, none of which reach the except clause above
            manager.disconnect(websocket)

    return app