Source code for jina.clients.base.websocket

"""A module for the websockets-based Client for Jina."""
import asyncio
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Dict, Optional, Tuple

from starlette import status

from jina.clients.base import BaseClient
from jina.clients.base.helper import WebsocketClientlet
from jina.clients.helper import callback_exec
from jina.helper import get_or_reuse_loop
from jina.importer import ImportExtensions
from jina.logging.profile import ProgressBar
from jina.proto import jina_pb2
from jina.serve.stream import RequestStreamer

if TYPE_CHECKING:
    from jina.clients.base import CallbackFnType, InputType
    from jina.types.request import Request


[docs]class WebSocketBaseClient(BaseClient): """A Websocket Client.""" async def _dry_run(self, **kwargs) -> bool: """Sends a dry run to the Flow to validate if the Flow is ready to receive requests :param kwargs: kwargs coming from the public interface. Includes arguments to be passed to the `WebsocketClientlet` :return: boolean indicating the readiness of the Flow """ async with AsyncExitStack() as stack: try: proto = 'wss' if self.args.tls else 'ws' url = f'{proto}://{self.args.host}:{self.args.port}/dry_run' iolet = await stack.enter_async_context( WebsocketClientlet(url=url, logger=self.logger, **kwargs) ) async def _receive(): try: async for response in iolet.recv_dry_run(): return response except Exception as exc: self.logger.error( f'Error while fetching response from Websocket server {exc!r}' ) raise async def _send(): return await iolet.send_dry_run() receive_task = asyncio.create_task(_receive()) if receive_task.done(): raise RuntimeError( 'receive task not running, can not send messages' ) try: send_task = asyncio.create_task(_send()) _, response_result = await asyncio.gather(send_task, receive_task) if response_result.proto.code == jina_pb2.StatusProto.SUCCESS: return True finally: if iolet.close_code == status.WS_1011_INTERNAL_ERROR: raise ConnectionError(iolet.close_message) await receive_task except Exception as e: self.logger.error( f'Error while getting response from websocket server {e!r}' ) return False async def _get_results( self, inputs: 'InputType', on_done: 'CallbackFnType', on_error: Optional['CallbackFnType'] = None, on_always: Optional['CallbackFnType'] = None, **kwargs, ): """ :param inputs: the callable :param on_done: the callback for on_done :param on_error: the callback for on_error :param on_always: the callback for on_always :param kwargs: kwargs coming from the public interface. Includes arguments to be passed to the `WebsocketClientlet` :yields: generator over results """ with ImportExtensions(required=True): import aiohttp self.inputs = inputs request_iterator = self._get_requests(**kwargs) async with AsyncExitStack() as stack: cm1 = ProgressBar( total_length=self._inputs_length, disable=not (self.show_progress) ) p_bar = stack.enter_context(cm1) proto = 'wss' if self.args.tls else 'ws' url = f'{proto}://{self.args.host}:{self.args.port}/' iolet = await stack.enter_async_context( WebsocketClientlet(url=url, logger=self.logger, **kwargs) ) request_buffer: Dict[ str, asyncio.Future ] = dict() # maps request_ids to futures (tasks) def _result_handler(result): return result async def _receive(): def _response_handler(response): if response.header.request_id in request_buffer: future = request_buffer.pop(response.header.request_id) future.set_result(response) else: self.logger.warning( f'discarding unexpected response with request id {response.header.request_id}' ) """Await messages from WebsocketGateway and process them in the request buffer""" try: async for response in iolet.recv_message(): _response_handler(response) finally: if request_buffer: self.logger.warning( f'{self.__class__.__name__} closed, cancelling all outstanding requests' ) for future in request_buffer.values(): future.cancel() request_buffer.clear() def _handle_end_of_iter(): """Send End of iteration signal to the Gateway""" asyncio.create_task(iolet.send_eoi()) def _request_handler( request: 'Request', ) -> 'Tuple[asyncio.Future, Optional[asyncio.Future]]': """ For each request in the iterator, we send the `Message` using `iolet.send_message()`. For websocket requests from client, for each request in the iterator, we send the request in `bytes` using using `iolet.send_message()`. Then add {<request-id>: <an-empty-future>} to the request buffer. This empty future is used to track the `result` of this request during `receive`. :param request: current request in the iterator :return: asyncio Future for sending message """ future = get_or_reuse_loop().create_future() request_buffer[request.header.request_id] = future asyncio.create_task(iolet.send_message(request)) return future, None streamer = RequestStreamer( request_handler=_request_handler, result_handler=_result_handler, end_of_iter_handler=_handle_end_of_iter, prefetch=getattr(self.args, 'prefetch', 0), logger=self.logger, **vars(self.args) ) receive_task = asyncio.create_task(_receive()) if receive_task.done(): raise RuntimeError('receive task not running, can not send messages') try: async for response in streamer.stream(request_iterator): callback_exec( response=response, on_error=on_error, on_done=on_done, on_always=on_always, continue_on_error=self.continue_on_error, logger=self.logger, ) if self.show_progress: p_bar.update() yield response finally: if iolet.close_code == status.WS_1011_INTERNAL_ERROR: raise ConnectionError(iolet.close_message) await receive_task