"""A module for the websockets-based Client for Jina."""
import asyncio
from contextlib import AsyncExitStack, suppress
from typing import TYPE_CHECKING, Dict, Optional

from jina.clients.base import BaseClient
from jina.clients.base.helper import WebsocketClientlet
from jina.clients.helper import callback_exec, callback_exec_on_error
from jina.helper import get_or_reuse_loop
from jina.importer import ImportExtensions
from jina.logging.profile import ProgressBar
from jina.serve.stream import RequestStreamer

if TYPE_CHECKING:
    from jina.clients.base import CallbackFnType, InputType
    from jina.types.request import Request
class WebSocketBaseClient(BaseClient):
    """A Websocket Client."""

    async def _get_results(
        self,
        inputs: 'InputType',
        on_done: 'CallbackFnType',
        on_error: Optional['CallbackFnType'] = None,
        on_always: Optional['CallbackFnType'] = None,
        **kwargs,
    ):
        """Stream requests to the Gateway over one websocket connection and yield responses.

        Each request is sent as a background task and paired with a future stored
        under its ``header.request_id``; a separate receive task resolves that
        future when a response carrying the same request id arrives.

        :param inputs: the callable
        :param on_done: the callback for on_done
        :param on_error: the callback for on_error
        :param on_always: the callback for on_always
        :param kwargs: kwargs for _get_task_name and _get_requests
        :yields: generator over results
        :raises aiohttp.ClientError: re-raised when neither ``on_error`` nor
            ``on_always`` is given
        """
        with ImportExtensions(required=True):
            import aiohttp

        self.inputs = inputs
        request_iterator = self._get_requests(**kwargs)

        async with AsyncExitStack() as stack:
            try:
                cm1 = ProgressBar(
                    total_length=self._inputs_length, disable=not self.show_progress
                )
                p_bar = stack.enter_context(cm1)

                proto = 'wss' if self.args.tls else 'ws'
                url = f'{proto}://{self.args.host}:{self.args.port}/'
                iolet = await stack.enter_async_context(
                    WebsocketClientlet(url=url, logger=self.logger)
                )

                # request id -> future that receives the matching response
                request_buffer: Dict[str, asyncio.Future] = dict()
                # The event loop keeps only weak references to tasks, so a
                # fire-and-forget task may be garbage-collected before it
                # finishes. Hold strong references until each task is done.
                background_tasks = set()

                def _spawn(coro) -> 'asyncio.Task':
                    """Schedule ``coro`` as a task and keep it referenced until it finishes."""
                    task = asyncio.create_task(coro)
                    background_tasks.add(task)
                    task.add_done_callback(background_tasks.discard)
                    return task

                def _result_handler(result):
                    return result

                async def _receive():
                    """Await messages from WebsocketGateway and process them in the request buffer"""

                    def _response_handler(response):
                        request_id = response.header.request_id
                        if request_id in request_buffer:
                            future = request_buffer.pop(request_id)
                            # If the future was already cancelled elsewhere,
                            # set_result would raise InvalidStateError and
                            # terminate this receive loop.
                            if not future.done():
                                future.set_result(response)
                        else:
                            self.logger.warning(
                                f'discarding unexpected response with request id {response.header.request_id}'
                            )

                    try:
                        async for response in iolet.recv_message():
                            _response_handler(response)
                    finally:
                        # Once the receive loop stops, nothing else resolves the
                        # pending futures, so cancel them instead of leaving
                        # their awaiters hanging.
                        if request_buffer:
                            self.logger.warning(
                                f'{self.__class__.__name__} closed, cancelling all outstanding requests'
                            )
                            for future in request_buffer.values():
                                future.cancel()
                            request_buffer.clear()

                def _handle_end_of_iter():
                    """Send End of iteration signal to the Gateway"""
                    _spawn(iolet.send_eoi())

                def _request_handler(request: 'Request') -> 'asyncio.Future':
                    """Register ``request`` in the request buffer and send it in the background.

                    An empty future is stored in the request buffer under the
                    request id; ``_receive`` resolves it with the response that
                    carries the same id.

                    :param request: current request in the iterator
                    :return: future resolved with the response to ``request``
                    """
                    future = get_or_reuse_loop().create_future()
                    request_buffer[request.header.request_id] = future
                    _spawn(iolet.send_message(request))
                    return future

                streamer = RequestStreamer(
                    args=self.args,
                    request_handler=_request_handler,
                    result_handler=_result_handler,
                    end_of_iter_handler=_handle_end_of_iter,
                )

                receive_task = get_or_reuse_loop().create_task(_receive())
                try:
                    if receive_task.done():
                        raise RuntimeError(
                            'receive task not running, can not send messages'
                        )
                    async for response in streamer.stream(request_iterator):
                        callback_exec(
                            response=response,
                            on_error=on_error,
                            on_done=on_done,
                            on_always=on_always,
                            continue_on_error=self.continue_on_error,
                            logger=self.logger,
                        )
                        if self.show_progress:
                            p_bar.update()
                        yield response
                finally:
                    # Stop the receive loop and retrieve its outcome, so that it
                    # does not outlive this generator and its exception is not
                    # lost. A failure inside the loop (e.g. aiohttp.ClientError)
                    # propagates to the handler below.
                    if not receive_task.done():
                        receive_task.cancel()
                    with suppress(asyncio.CancelledError):
                        await receive_task

            except aiohttp.ClientError as e:
                self.logger.error(
                    f'Error while streaming response from websocket server {e!r}'
                )
                if on_error or on_always:
                    if on_error:
                        callback_exec_on_error(on_error, e, self.logger)
                    if on_always:
                        callback_exec(
                            response=None,
                            on_error=None,
                            on_done=None,
                            on_always=on_always,
                            continue_on_error=self.continue_on_error,
                            logger=self.logger,
                        )
                else:
                    raise