import asyncio
import json
import threading
from typing import TYPE_CHECKING, Optional, Tuple
import grpc
from grpc import RpcError
from jina.clients.base import BaseClient
from jina.clients.helper import callback_exec
from jina.excepts import BadClientInput, BadServerFlow, InternalNetworkError
from jina.logging.profile import ProgressBar
from jina.proto import jina_pb2, jina_pb2_grpc
from jina.serve.helper import extract_trailing_metadata
from jina.serve.networking import GrpcConnectionPool
from jina.serve.stream import RequestStreamer
from jina.types.request.data import Request
if TYPE_CHECKING: # pragma: no cover
from jina.clients.base import CallbackFnType, InputType
[docs]class GRPCBaseClient(BaseClient):
"""A simple Python client for connecting to the gRPC gateway.
It manages the asyncio event loop internally, so all interfaces are synchronous from the outside.
"""
_lock = threading.RLock()
async def _is_flow_ready(self, **kwargs) -> bool:
"""Sends a dry run to the Flow to validate if the Flow is ready to receive requests
:param kwargs: potential kwargs received passed from the public interface
:return: boolean indicating the health/readiness of the Flow
"""
try:
async with GrpcConnectionPool.get_grpc_channel(
f'{self.args.host}:{self.args.port}',
asyncio=True,
tls=self.args.tls,
) as channel:
stub = jina_pb2_grpc.JinaGatewayDryRunRPCStub(channel)
self.logger.debug(f'connected to {self.args.host}:{self.args.port}')
call_result = stub.dry_run(
jina_pb2.google_dot_protobuf_dot_empty__pb2.Empty(),
metadata=kwargs.get('metadata', None),
credentials=kwargs.get('credentials', None),
timeout=kwargs.get('timeout', None),
)
metadata, response = (
await call_result.trailing_metadata(),
await call_result,
)
if response.code == jina_pb2.StatusProto.SUCCESS:
return True
else:
self.logger.error(
f'Returned code is not expected! Exception: {response.exception}'
)
except RpcError as e:
self.logger.error(f'RpcError: {e.details()}')
except Exception as e:
self.logger.error(f'Error while getting response from grpc server {e!r}')
return False
async def _stream_rpc(
self,
channel,
req_iter,
metadata,
on_error,
on_done,
on_always,
continue_on_error,
p_bar,
**kwargs,
):
stub = jina_pb2_grpc.JinaRPCStub(channel)
async for resp in stub.Call(
req_iter,
compression=self.compression,
metadata=metadata,
credentials=kwargs.get('credentials', None),
timeout=kwargs.get('timeout', None),
):
callback_exec(
response=resp,
on_error=on_error,
on_done=on_done,
on_always=on_always,
continue_on_error=continue_on_error,
logger=self.logger,
)
if self.show_progress:
p_bar.update()
yield resp
async def _unary_rpc(
self,
channel,
req_iter,
metadata,
on_error,
on_done,
on_always,
continue_on_error,
p_bar,
results_in_order,
prefetch,
**kwargs,
):
stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel)
def _request_handler(
request: 'Request',
) -> 'Tuple[asyncio.Future, Optional[asyncio.Future]]':
return (
asyncio.ensure_future(
stub.process_single_data(
request,
compression=self.compression,
metadata=metadata,
credentials=kwargs.get('credentials', None),
timeout=kwargs.get('timeout', None),
)
),
None,
)
def _result_handler(resp):
callback_exec(
response=resp,
on_error=on_error,
on_done=on_done,
on_always=on_always,
continue_on_error=continue_on_error,
logger=self.logger,
)
return resp
streamer_args = vars(self.args)
if prefetch:
streamer_args['prefetch'] = prefetch
streamer = RequestStreamer(
request_handler=_request_handler,
result_handler=_result_handler,
iterate_sync_in_thread=False,
logger=self.logger,
**streamer_args,
)
async for response in streamer.stream(
request_iterator=req_iter, results_in_order=results_in_order
):
if self.show_progress:
p_bar.update()
yield response
async def _get_results(
self,
inputs: 'InputType',
on_done: 'CallbackFnType',
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
compression: Optional[str] = None,
max_attempts: int = 1,
initial_backoff: float = 0.5,
max_backoff: float = 0.1,
backoff_multiplier: float = 1.5,
results_in_order: bool = False,
stream: bool = True,
prefetch: Optional[int] = None,
**kwargs,
):
try:
self.compression = (
getattr(grpc.Compression, compression)
if compression
else grpc.Compression.NoCompression
)
self.inputs = inputs
req_iter = self._get_requests(**kwargs)
continue_on_error = self.continue_on_error
# while loop with retries, check in which state the `iterator` remains after failure
options = GrpcConnectionPool.get_default_grpc_options()
if max_attempts > 1:
service_config_json = json.dumps(
{
"methodConfig": [
{
# To apply retry to all methods, put [{}] in the "name" field
"name": [{}],
"retryPolicy": {
"maxAttempts": max_attempts,
"initialBackoff": f"{initial_backoff}s",
"maxBackoff": f"{max_backoff}s",
"backoffMultiplier": backoff_multiplier,
"retryableStatusCodes": [
"UNAVAILABLE",
"DEADLINE_EXCEEDED",
"INTERNAL",
],
},
}
]
}
)
# NOTE: the retry feature will be enabled by default >=v1.40.0
options.append(("grpc.enable_retries", 1))
options.append(("grpc.service_config", service_config_json))
metadata = kwargs.pop('metadata', ())
if results_in_order:
metadata = metadata + (('__results_in_order__', 'true'),)
if prefetch:
metadata = metadata + (('__prefetch__', str(prefetch)),)
with self._lock:
async with GrpcConnectionPool.get_grpc_channel(
f'{self.args.host}:{self.args.port}',
options=options,
asyncio=True,
tls=self.args.tls,
aio_tracing_client_interceptors=self.aio_tracing_client_interceptors(),
) as channel:
self.logger.debug(f'connected to {self.args.host}:{self.args.port}')
with ProgressBar(
total_length=self._inputs_length, disable=not self.show_progress
) as p_bar:
try:
if stream:
async for resp in self._stream_rpc(
channel=channel,
req_iter=req_iter,
metadata=metadata,
on_error=on_error,
on_done=on_done,
on_always=on_always,
continue_on_error=continue_on_error,
p_bar=p_bar,
**kwargs,
):
yield resp
else:
async for resp in self._unary_rpc(
channel=channel,
req_iter=req_iter,
metadata=metadata,
on_error=on_error,
on_done=on_done,
on_always=on_always,
continue_on_error=continue_on_error,
p_bar=p_bar,
results_in_order=results_in_order,
prefetch=prefetch,
**kwargs,
):
yield resp
except (
grpc.aio._call.AioRpcError,
InternalNetworkError,
) as err:
my_code = err.code()
my_details = err.details()
trailing_metadata = extract_trailing_metadata(err)
msg = f'gRPC error: {my_code} {my_details}'
if trailing_metadata:
msg = f'gRPC error: {my_code} {my_details}\n{trailing_metadata}'
if my_code == grpc.StatusCode.UNAVAILABLE:
self.logger.error(
f'{msg}\nThe ongoing request is terminated as the server is not available or closed already.'
)
raise ConnectionError(my_details)
elif my_code == grpc.StatusCode.DEADLINE_EXCEEDED:
self.logger.error(
f'{msg}\nThe ongoing request is terminated due to a server-side timeout.'
)
raise ConnectionError(my_details)
elif my_code == grpc.StatusCode.INTERNAL:
self.logger.error(
f'{msg}\ninternal error on the server side'
)
raise err
elif (
my_code == grpc.StatusCode.UNKNOWN
and 'asyncio.exceptions.TimeoutError' in my_details
):
raise BadClientInput(
f'{msg}\n'
'often the case is that you define/send a bad input iterator to jina, '
'please double check your input iterator'
) from err
else:
raise BadServerFlow(msg) from err
except KeyboardInterrupt:
self.logger.warning('user cancel the process')
except asyncio.CancelledError as ex:
self.logger.warning(f'process error: {ex!r}')
raise
except:
# Not sure why, adding this line helps in fixing a hanging test
raise