Source code for jina.peapods.runtimes.asyncio.rest.app

import argparse
import asyncio
import warnings
from typing import Any

from google.protobuf.json_format import MessageToDict, MessageToJson

from ..grpc.async_call import AsyncPrefetchCall
from ....zmq import AsyncZmqlet
from ..... import __version__
from .....clients.request import request_generator
from .....enums import RequestType
from .....helper import get_full_version, random_identity
from .....importer import ImportExtensions
from .....logging import JinaLogger, default_logger
from .....logging.profile import used_memory_readable
from .....types.message import Message
from .....types.request import Request


[docs]def get_fastapi_app(args: 'argparse.Namespace', logger: 'JinaLogger'): """ Get the app from FastAPI as the REST interface. :param args: passed arguments. :param logger: Jina logger. :return: fastapi app """ with ImportExtensions(required=True): from fastapi import FastAPI, WebSocket, Body from fastapi.responses import JSONResponse from fastapi.middleware.cors import CORSMiddleware from starlette.endpoints import WebSocketEndpoint from starlette import status from starlette.types import Receive, Scope, Send from starlette.responses import StreamingResponse from .models import ( JinaStatusModel, JinaIndexRequestModel, JinaDeleteRequestModel, JinaUpdateRequestModel, JinaSearchRequestModel, ) app = FastAPI( title='Jina', description='REST interface for Jina', version=__version__, ) app.add_middleware( CORSMiddleware, allow_origins=['*'], allow_credentials=True, allow_methods=['*'], allow_headers=['*'], ) zmqlet = AsyncZmqlet(args, default_logger) servicer = AsyncPrefetchCall(args, zmqlet) def error(reason, status_code): """ Get the error code. :param reason: content of error :param status_code: status code :return: error in JSON response """ return JSONResponse(content={'reason': reason}, status_code=status_code) @app.on_event('shutdown') def _shutdown(): zmqlet.close() @app.on_event('startup') async def startup(): """Log the host information when start the server.""" default_logger.info( f''' Jina REST interface 💬 Swagger UI:\thttp://localhost:{args.port_expose}/docs 📚 Redoc :\thttp://localhost:{args.port_expose}/redoc ''' ) from jina import __ready_msg__ default_logger.success(__ready_msg__) @app.get( path='/status', summary='Get the status of Jina', response_model=JinaStatusModel, tags=['jina'], ) async def _status(): _info = get_full_version() return { 'jina': _info[0], 'envs': _info[1], 'used_memory': used_memory_readable(), } @app.post(path='/api/{mode}', deprecated=True) async def api(mode: str, body: Any = Body(...)): """ Request mode service and return results in JSON, a deprecated interface. :param mode: INDEX, SEARCH, DELETE, UPDATE, CONTROL, TRAIN. :param body: Request body. :return: Results in JSONresponse. """ warnings.warn('this interface will be retired soon', DeprecationWarning) if mode.upper() not in RequestType.__members__: return error(reason=f'unsupported mode {mode}', status_code=405) if 'data' not in body: return error('"data" field is empty', 406) body['mode'] = RequestType.from_string(mode) from .....clients import BaseClient BaseClient.add_default_kwargs(body) req_iter = request_generator(**body) results = await get_result_in_json(req_iter=req_iter) return JSONResponse(content=results[0], status_code=200) async def get_result_in_json(req_iter): """ Convert message to JSON data. :param req_iter: Request iterator :return: Results in JSON format """ return [ MessageToDict(k) async for k in servicer.Call(request_iterator=req_iter, context=None) ] @app.post(path='/index', summary='Index documents into Jina', tags=['CRUD']) async def index_api(body: JinaIndexRequestModel): """ Index API to index documents. :param body: index request. :return: Response of the results. """ from .....clients import BaseClient bd = body.dict() bd['mode'] = RequestType.INDEX return StreamingResponse( result_in_stream(request_generator(**bd)), media_type='application/json' ) @app.post(path='/search', summary='Search documents from Jina', tags=['CRUD']) async def search_api(body: JinaSearchRequestModel): """ Search API to search documents. :param body: search request. :return: Response of the results. """ from .....clients import BaseClient bd = body.dict() bd['mode'] = RequestType.SEARCH return StreamingResponse( result_in_stream(request_generator(**bd)), media_type='application/json' ) @app.put(path='/update', summary='Update documents in Jina', tags=['CRUD']) async def update_api(body: JinaUpdateRequestModel): """ Update API to update documents. :param body: update request. :return: Response of the results. """ from .....clients import BaseClient bd = body.dict() bd['mode'] = RequestType.UPDATE return StreamingResponse( result_in_stream(request_generator(**bd)), media_type='application/json' ) @app.delete(path='/delete', summary='Delete documents in Jina', tags=['CRUD']) async def delete_api(body: JinaDeleteRequestModel): """ Delete API to delete documents. :param body: delete request. :return: Response of the results. """ from .....clients import BaseClient bd = body.dict() bd['mode'] = RequestType.DELETE return StreamingResponse( result_in_stream(request_generator(**bd)), media_type='application/json' ) async def result_in_stream(req_iter): """ Streams results from AsyncPrefetchCall as json :param req_iter: request iterator :yield: result """ async for k in servicer.Call(request_iterator=req_iter, context=None): yield MessageToJson(k) @app.websocket_route(path='/stream') class StreamingEndpoint(WebSocketEndpoint): """ :meth:`handle_receive()` Await a message on :meth:`websocket.receive()` Send the message to zmqlet via :meth:`zmqlet.send_message()` and await :meth:`handle_send()` Await a message on :meth:`zmqlet.recv_message()` Send the message back to client via :meth:`websocket.send()` and await :meth:`dispatch()` Awaits on concurrent tasks :meth:`handle_receive()` & :meth:`handle_send()` This makes sure gateway is nonblocking Await exit strategy: :meth:`handle_receive()` keeps track of num_requests received :meth:`handle_send()` keeps track of num_responses sent Client sends a final message: `bytes(True)` to indicate request iterator is empty Server exits out of await when `(num_requests == num_responses != 0 and is_req_empty)` """ encoding = None def __init__(self, scope: 'Scope', receive: 'Receive', send: 'Send') -> None: super().__init__(scope, receive, send) self.args = args self.name = args.name or self.__class__.__name__ self._id = random_identity() self.client_encoding = None async def dispatch(self) -> None: """Awaits on concurrent tasks :meth:`handle_receive()` & :meth:`handle_send()`""" websocket = WebSocket(self.scope, receive=self.receive, send=self.send) await self.on_connect(websocket) close_code = status.WS_1000_NORMAL_CLOSURE await asyncio.gather( self.handle_receive(websocket=websocket, close_code=close_code), ) async def on_connect(self, websocket: WebSocket) -> None: """ Await the websocket to accept and log the information. :param websocket: connected websocket """ # TODO(Deepankar): To enable multiple concurrent clients, # Register each client - https://fastapi.tiangolo.com/advanced/websockets/#handling-disconnections-and-multiple-clients # And move class variables to instance variable await websocket.accept() self.client_info = f'{websocket.client.host}:{websocket.client.port}' logger.success( f'Client {self.client_info} connected to stream requests via websockets' ) async def handle_receive(self, websocket: WebSocket, close_code: int) -> None: """ Await a message on :meth:`websocket.receive()` Send the message to zmqlet via :meth:`zmqlet.send_message()` and await :param websocket: WebSocket connection between clinet sand server. :param close_code: close code """ def handle_route(msg: 'Message') -> 'Request': """ Add route information to `message`. :param msg: receive message :return: message response with route information """ msg.add_route(self.name, self._id) return msg.response try: while True: message = await websocket.receive() if message['type'] == 'websocket.receive': data = await self.decode(websocket, message) if data == bytes(True): await asyncio.sleep(0.1) continue await zmqlet.send_message( Message(None, Request(data), 'gateway', **vars(self.args)) ) response = await zmqlet.recv_message(callback=handle_route) if self.client_encoding == 'bytes': await websocket.send_bytes(response.SerializeToString()) else: await websocket.send_json(response.json()) elif message['type'] == 'websocket.disconnect': close_code = int( message.get('code', status.WS_1000_NORMAL_CLOSURE) ) break except Exception as exc: close_code = status.WS_1011_INTERNAL_ERROR logger.error(f'Got an exception in handle_receive: {exc!r}') raise finally: await self.on_disconnect(websocket, close_code) async def decode(self, websocket: WebSocket, message: Message) -> Any: """ Decode the text or bytes format `message` :param websocket: WebSocket connection. :param message: Jina `Message`. :return: decoded message. """ if 'text' in message or 'json' in message: self.client_encoding = 'text' if 'bytes' in message: self.client_encoding = 'bytes' return await super().decode(websocket, message) async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None: """ Log the information when client is disconnected. :param websocket: disconnected websocket :param close_code: close code """ logger.info(f'Client {self.client_info} got disconnected!') return app