import argparse
import asyncio
from typing import (
TYPE_CHECKING,
AsyncIterator,
Awaitable,
Callable,
Iterator,
Optional,
Union,
)
from jina.excepts import InternalNetworkError
from jina.logging.logger import JinaLogger
from jina.serve.stream.helper import AsyncRequestsIterator, _RequestsCounter
__all__ = ['RequestStreamer']
from jina.types.request.data import Response
if TYPE_CHECKING:
from jina.types.request import Request
[docs]class RequestStreamer:
"""
A base async request/response streamer.
"""
class _EndOfStreaming(Exception):
pass
def __init__(
self,
request_handler: Callable[['Request'], 'Awaitable[Request]'],
result_handler: Callable[['Request'], Optional['Request']],
prefetch: int = 0,
end_of_iter_handler: Optional[Callable[[], None]] = None,
logger: Optional['JinaLogger'] = None,
**logger_kwargs
):
"""
:param request_handler: The callable responsible for handling the request. It should handle a request as input and return a Future to be awaited
:param result_handler: The callable responsible for handling the response.
:param end_of_iter_handler: Optional callable to handle the end of iteration if some special action needs to be taken.
:param prefetch: How many Requests are processed from the Client at the same time.
:param logger: Optional logger that can be used for logging
:param logger_kwargs: Extra keyword arguments that may be passed to the internal logger constructor if none is provided
"""
self.logger = logger or JinaLogger(self.__class__.__name__, **logger_kwargs)
self._prefetch = prefetch
self._request_handler = request_handler
self._result_handler = result_handler
self._end_of_iter_handler = end_of_iter_handler
self.total_num_floating_tasks_alive = 0
[docs] async def stream(
self, request_iterator, context=None, *args
) -> AsyncIterator['Request']:
"""
stream requests from client iterator and stream responses back.
:param request_iterator: iterator of requests
:param context: context of the grpc call
:param args: positional arguments
:yield: responses from Executors
"""
async_iter: AsyncIterator = self._stream_requests(request_iterator)
try:
async for response in async_iter:
yield response
except InternalNetworkError as err:
if (
context is not None
): # inside GrpcGateway we can handle the error directly here through the grpc context
context.set_details(err.details())
context.set_code(err.code())
self.logger.error(
f'Error while getting responses from deployments: {err.details()}'
)
r = Response()
if err.request_id:
r.header.request_id = err.request_id
yield r
else: # HTTP and WS need different treatment further up the stack
raise
async def _stream_requests(
self,
request_iterator: Union[Iterator, AsyncIterator],
) -> AsyncIterator:
"""Implements request and response handling without prefetching
:param request_iterator: requests iterator from Client
:yield: responses
"""
result_queue = asyncio.Queue()
floating_results_queue = asyncio.Queue()
end_of_iter = asyncio.Event()
all_requests_handled = asyncio.Event()
requests_to_handle = _RequestsCounter()
floating_tasks_to_handle = _RequestsCounter()
all_floating_requests_awaited = asyncio.Event()
empty_requests_iterator = asyncio.Event()
def update_all_handled():
if end_of_iter.is_set() and requests_to_handle.count == 0:
all_requests_handled.set()
async def end_future():
raise self._EndOfStreaming
def callback(future: 'asyncio.Future'):
"""callback to be run after future is completed.
1. Put the future in the result queue.
2. Remove the future from futures when future is completed.
..note::
callback cannot be an awaitable, hence we cannot do `await queue.put(...)` here.
We don't add `future.result()` to the queue, as that would consume the exception in the callback,
which is difficult to handle.
:param future: asyncio Future object retured from `handle_response`
"""
result_queue.put_nowait(future)
def hanging_callback(future: 'asyncio.Future'):
floating_results_queue.put_nowait(future)
async def iterate_requests() -> None:
"""
1. Traverse through the request iterator.
2. `add_done_callback` to the future returned by `handle_request`.
This callback adds the completed future to `result_queue`
3. Append future to list of futures.
4. Handle EOI (needed for websocket client)
5. Set `end_of_iter` event
"""
num_reqs = 0
async for request in AsyncRequestsIterator(
iterator=request_iterator,
request_counter=requests_to_handle,
prefetch=self._prefetch,
):
num_reqs += 1
requests_to_handle.count += 1
future_responses, future_hanging = self._request_handler(
request=request
)
future_responses.add_done_callback(callback)
if future_hanging is not None:
floating_tasks_to_handle.count += 1
future_hanging.add_done_callback(hanging_callback)
else:
all_floating_requests_awaited.set()
if num_reqs == 0:
empty_requests_iterator.set()
if self._end_of_iter_handler is not None:
self._end_of_iter_handler()
end_of_iter.set()
update_all_handled()
if all_requests_handled.is_set():
# It will be waiting for something that will never appear
future_cancel = asyncio.ensure_future(end_future())
result_queue.put_nowait(future_cancel)
if (
all_floating_requests_awaited.is_set()
or empty_requests_iterator.is_set()
):
# It will be waiting for something that will never appear
future_cancel = asyncio.ensure_future(end_future())
floating_results_queue.put_nowait(future_cancel)
async def handle_floating_responses():
while (
not all_floating_requests_awaited.is_set()
and not empty_requests_iterator.is_set()
):
hanging_response = await floating_results_queue.get()
try:
hanging_response.result()
floating_tasks_to_handle.count -= 1
if floating_tasks_to_handle.count == 0 and end_of_iter.is_set():
all_floating_requests_awaited.set()
except self._EndOfStreaming:
pass
asyncio.create_task(iterate_requests())
handle_floating_task = asyncio.create_task(handle_floating_responses())
self.total_num_floating_tasks_alive += 1
def floating_task_done(*args):
self.total_num_floating_tasks_alive -= 1
handle_floating_task.add_done_callback(floating_task_done)
while not all_requests_handled.is_set():
future = await result_queue.get()
try:
response = self._result_handler(future.result())
yield response
requests_to_handle.count -= 1
update_all_handled()
except self._EndOfStreaming:
pass
[docs] async def wait_floating_requests_end(self):
"""
Await this coroutine to make sure that all the floating tasks that the request handler may bring are properly consumed
"""
while self.total_num_floating_tasks_alive > 0:
await asyncio.sleep(0)