Source code for jina.serve.networking

import asyncio
import contextlib
import ipaddress
import os
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from urllib.parse import urlparse

import grpc
from grpc.aio import AioRpcError
from grpc_reflection.v1alpha.reflection_pb2 import ServerReflectionRequest
from grpc_reflection.v1alpha.reflection_pb2_grpc import ServerReflectionStub

from jina.enums import PollingType
from jina.importer import ImportExtensions
from jina.logging.logger import JinaLogger
from jina.proto import jina_pb2, jina_pb2_grpc
from jina.types.request import Request
from jina.types.request.control import ControlRequest
from jina.types.request.data import DataRequest

TLS_PROTOCOL_SCHEMES = ['grpcs', 'https', 'wss']


if TYPE_CHECKING:
    from prometheus_client import CollectorRegistry


class ReplicaList:
    """
    Maintains a list of connections to replicas and uses round robin for selecting a replica
    """

    def __init__(self, summary):
        self._connections = []
        self._address_to_connection_idx = {}
        self._address_to_channel = {}
        self._rr_counter = 0
        self.summary = summary

    def add_connection(self, address: str):
        """
        Add connection with address to the connection list

        :param address: Target address of this connection
        """
        if address not in self._address_to_connection_idx:
            try:
                parsed_address = urlparse(address)
                address = parsed_address.netloc if parsed_address.netloc else address
                use_tls = parsed_address.scheme in TLS_PROTOCOL_SCHEMES
            except Exception:
                use_tls = False

            self._address_to_connection_idx[address] = len(self._connections)
            stubs, channel = GrpcConnectionPool.create_async_channel_stub(
                address, tls=use_tls, summary=self.summary
            )
            self._address_to_channel[address] = channel
            self._connections.append(stubs)

    async def remove_connection(self, address: str):
        """
        Remove connection with address from the connection list

        :param address: Remove connection for this address
        :returns: The removed connection or None if there was not any for the given address
        """
        if address in self._address_to_connection_idx:
            self._rr_counter = (
                self._rr_counter % (len(self._connections) - 1)
                if (len(self._connections) - 1)
                else 0
            )
            idx_to_delete = self._address_to_connection_idx.pop(address)
            popped_connection = self._connections.pop(idx_to_delete)
            # we should handle graceful termination better, 0.5 is a rather random number here
            await self._address_to_channel[address].close(0.5)
            del self._address_to_channel[address]
            # update the address/idx mapping
            for remaining_address in self._address_to_connection_idx:
                if self._address_to_connection_idx[remaining_address] > idx_to_delete:
                    self._address_to_connection_idx[remaining_address] -= 1

            return popped_connection
        return None

    def get_next_connection(self):
        """
        Returns a connection from the list. Strategy is round robin

        :returns: A connection from the pool
        """
        try:
            connection = self._connections[self._rr_counter]
        except IndexError:
            # This can happen as a race condition while removing connections
            self._rr_counter = 0
            connection = self._connections[self._rr_counter]
        self._rr_counter = (self._rr_counter + 1) % len(self._connections)
        return connection

    def get_all_connections(self):
        """
        Returns all available connections

        :returns: A complete list of all connections from the pool
        """
        return self._connections

    def has_connection(self, address: str) -> bool:
        """
        Checks if a connection for the given address exists in the list

        :param address: The address to check
        :returns: True if a connection for the address exists in the list
        """
        return address in self._address_to_connection_idx

    def has_connections(self) -> bool:
        """
        Checks if this contains any connection

        :returns: True if any connection is managed, False otherwise
        """
        return len(self._address_to_connection_idx) > 0

    async def close(self):
        """
        Close all connections and clean up internal state
        """
        for address in self._address_to_channel:
            await self._address_to_channel[address].close(0.5)
        self._address_to_channel.clear()
        self._address_to_connection_idx.clear()
        self._connections.clear()
        self._rr_counter = 0
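
# Editor's note: a minimal sketch (not part of the original module) of how
# ReplicaList rotates over its connections. The addresses are hypothetical;
# grpc channels connect lazily, so no server needs to be running for this
# to execute.
async def _example_replica_round_robin():
    replicas = ReplicaList(summary=contextlib.nullcontext())
    replicas.add_connection('localhost:12345')
    replicas.add_connection('localhost:12346')
    first = replicas.get_next_connection()   # stubs for localhost:12345
    second = replicas.get_next_connection()  # stubs for localhost:12346
    third = replicas.get_next_connection()   # wraps around to localhost:12345
    assert second is not first
    assert third is first
    await replicas.close()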

class GrpcConnectionPool:
    """
    Manages a list of grpc connections.

    :param logger: the logger to use
    :param compression: The compression algorithm to be used by this GrpcConnectionPool when sending data over gRPC
    """

    K8S_PORT_USES_AFTER = 8082
    K8S_PORT_USES_BEFORE = 8081
    K8S_PORT = 8080

    class ConnectionStubs:
        """
        Maintains a list of grpc stubs available for a particular connection
        """

        STUB_MAPPING = {
            'jina.JinaControlRequestRPC': jina_pb2_grpc.JinaControlRequestRPCStub,
            'jina.JinaDataRequestRPC': jina_pb2_grpc.JinaDataRequestRPCStub,
            'jina.JinaSingleDataRequestRPC': jina_pb2_grpc.JinaSingleDataRequestRPCStub,
            'jina.JinaDiscoverEndpointsRPC': jina_pb2_grpc.JinaDiscoverEndpointsRPCStub,
            'jina.JinaRPC': jina_pb2_grpc.JinaRPCStub,
        }

        def __init__(self, address, channel, summary):
            self.address = address
            self.channel = channel
            self._summary_time = summary
            self._initialized = False

        # This has to be done lazily, because the target endpoint may not be available
        # when a connection is added
        async def _init_stubs(self):
            available_services = await GrpcConnectionPool.get_available_services(
                self.channel
            )
            stubs = defaultdict(lambda: None)
            for service in available_services:
                stubs[service] = self.STUB_MAPPING[service](self.channel)
            self.control_stub = stubs['jina.JinaControlRequestRPC']
            self.data_list_stub = stubs['jina.JinaDataRequestRPC']
            self.single_data_stub = stubs['jina.JinaSingleDataRequestRPC']
            self.stream_stub = stubs['jina.JinaRPC']
            self.endpoints_discovery_stub = stubs['jina.JinaDiscoverEndpointsRPC']
            self._initialized = True

        async def send_discover_endpoint(
            self,
            timeout: Optional[float] = None,
        ) -> Tuple:
            """
            Use the endpoint discovery stub to request the endpoints exposed by an Executor

            :param timeout: defines timeout for sending request
            :returns: Tuple of response and metadata about the response
            """
            if not self._initialized:
                await self._init_stubs()
            call_result = self.endpoints_discovery_stub.endpoint_discovery(
                jina_pb2.google_dot_protobuf_dot_empty__pb2.Empty(),
                timeout=timeout,
            )
            metadata, response = (
                await call_result.trailing_metadata(),
                await call_result,
            )
            return response, metadata

        async def send_requests(
            self,
            requests: List[Request],
            metadata,
            compression,
            timeout: Optional[float] = None,
        ) -> Tuple:
            """
            Send requests using the appropriate grpc stub.
            The stub is chosen based on availability and the type of the requests.

            :param requests: the requests to send
            :param metadata: the metadata to send alongside the requests
            :param compression: the compression algorithm to use when sending
            :param timeout: defines timeout for sending request
            :returns: Tuple of response and metadata about the response
            """
            if not self._initialized:
                await self._init_stubs()
            request_type = type(requests[0])
            if request_type == DataRequest and len(requests) == 1:
                if self.single_data_stub:
                    call_result = self.single_data_stub.process_single_data(
                        requests[0],
                        metadata=metadata,
                        compression=compression,
                        timeout=timeout,
                    )
                    with self._summary_time:
                        metadata, response = (
                            await call_result.trailing_metadata(),
                            await call_result,
                        )
                    return response, metadata
                elif self.stream_stub:
                    with self._summary_time:
                        async for resp in self.stream_stub.Call(
                            iter(requests), compression=compression, timeout=timeout
                        ):
                            return resp, None
            if request_type == DataRequest and len(requests) > 1:
                if self.data_list_stub:
                    call_result = self.data_list_stub.process_data(
                        requests,
                        metadata=metadata,
                        compression=compression,
                        timeout=timeout,
                    )
                    with self._summary_time:
                        metadata, response = (
                            await call_result.trailing_metadata(),
                            await call_result,
                        )
                    return response, metadata
                else:
                    raise ValueError(
                        'Can not send list of DataRequests. gRPC endpoint not available.'
                    )
            elif request_type == ControlRequest:
                if self.control_stub:
                    call_result = self.control_stub.process_control(
                        requests[0], timeout=timeout
                    )
                    metadata, response = (
                        await call_result.trailing_metadata(),
                        await call_result,
                    )
                    return response, metadata
                else:
                    raise ValueError(
                        'Can not send ControlRequest. gRPC endpoint not available.'
                    )
            else:
                raise ValueError(f'Unsupported request type {type(requests[0])}')
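
    # Editor's note: a sketch of driving ConnectionStubs directly (assumes an
    # Executor is reachable at the hypothetical address, since the stubs are
    # initialized lazily via server reflection on first use):
    #
    #   stubs, channel = GrpcConnectionPool.create_async_channel_stub(
    #       'localhost:12345', summary=contextlib.nullcontext()
    #   )
    #   response, metadata = await stubs.send_requests(
    #       [DataRequest()], metadata=None, compression=grpc.Compression.NoCompression
    #   )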

    class _ConnectionPoolMap:
        def __init__(self, logger: Optional[JinaLogger], summary):
            self._logger = logger
            # this maps deployments to shards or heads
            self._deployments: Dict[str, Dict[str, Dict[int, ReplicaList]]] = {}
            # dict stores last entity id used for a particular deployment, used for round robin
            self._access_count: Dict[str, int] = {}
            self.summary = summary
            if os.name != 'nt':
                os.unsetenv('http_proxy')
                os.unsetenv('https_proxy')

        def add_replica(self, deployment: str, shard_id: int, address: str):
            self._add_connection(deployment, shard_id, address, 'shards')

        def add_head(
            self, deployment: str, address: str, head_id: Optional[int] = 0
        ):
            # the head_id is always 0 for now, this will change when scaling the head
            self._add_connection(deployment, head_id, address, 'heads')

        def get_replicas(
            self,
            deployment: str,
            head: bool,
            entity_id: Optional[int] = None,
            increase_access_count: bool = True,
        ) -> ReplicaList:
            if deployment in self._deployments:
                type_ = 'heads' if head else 'shards'
                if entity_id is None and head:
                    entity_id = 0
                return self._get_connection_list(
                    deployment, type_, entity_id, increase_access_count
                )
            else:
                self._logger.debug(
                    f'Unknown deployment {deployment}, no replicas available'
                )
                return None

        def get_replicas_all_shards(self, deployment: str) -> List[ReplicaList]:
            replicas = []
            if deployment in self._deployments:
                for shard_id in self._deployments[deployment]['shards']:
                    replicas.append(
                        self._get_connection_list(deployment, 'shards', shard_id)
                    )
            return replicas

        async def close(self):
            # Close all connections to all replicas
            for deployment in self._deployments:
                for entity_type in self._deployments[deployment]:
                    for entity_id in self._deployments[deployment][entity_type]:
                        await self._deployments[deployment][entity_type][
                            entity_id
                        ].close()
            self._deployments.clear()

        def _get_connection_list(
            self,
            deployment: str,
            type_: str,
            entity_id: Optional[int] = None,
            increase_access_count: bool = True,
        ) -> ReplicaList:
            try:
                if entity_id is None and len(self._deployments[deployment][type_]) > 0:
                    # pick the next entity round robin via the access count
                    if increase_access_count:
                        self._access_count[deployment] += 1
                    return self._deployments[deployment][type_][
                        self._access_count[deployment]
                        % len(self._deployments[deployment][type_])
                    ]
                else:
                    return self._deployments[deployment][type_][entity_id]
            except KeyError:
                if (
                    entity_id is None
                    and deployment in self._deployments
                    and len(self._deployments[deployment][type_])
                ):
                    # This can happen as a race condition when removing connections while accessing them
                    # In this case we don't care about the concrete entity, so retry with the first one
                    return self._get_connection_list(
                        deployment, type_, 0, increase_access_count
                    )
                self._logger.debug(
                    f'did not find a connection for deployment {deployment}, type {type_} '
                    f'and entity_id {entity_id}. There are '
                    f'{len(self._deployments[deployment][type_]) if deployment in self._deployments else 0} '
                    f'available connections for this deployment and type.'
                )
                return None

        def _add_deployment(self, deployment: str):
            if deployment not in self._deployments:
                self._deployments[deployment] = {'shards': {}, 'heads': {}}
                self._access_count[deployment] = 0

        def _add_connection(
            self,
            deployment: str,
            entity_id: int,
            address: str,
            type: str,
        ):
            self._add_deployment(deployment)
            if entity_id not in self._deployments[deployment][type]:
                connection_list = ReplicaList(self.summary)
                self._deployments[deployment][type][entity_id] = connection_list

            if not self._deployments[deployment][type][entity_id].has_connection(
                address
            ):
                self._logger.debug(
                    f'adding connection for deployment {deployment}/{type}/{entity_id} to {address}'
                )
                self._deployments[deployment][type][entity_id].add_connection(address)
            else:
                self._logger.debug(
                    f'ignoring activation of pod, {address} already known'
                )

        async def remove_head(self, deployment, address, head_id: Optional[int] = 0):
            return await self._remove_connection(deployment, head_id, address, 'heads')

        async def remove_replica(self, deployment, address, shard_id: Optional[int] = 0):
            return await self._remove_connection(deployment, shard_id, address, 'shards')

        async def _remove_connection(self, deployment, entity_id, address, type):
            if (
                deployment in self._deployments
                and entity_id in self._deployments[deployment][type]
            ):
                self._logger.debug(
                    f'removing connection for deployment {deployment}/{type}/{entity_id} to {address}'
                )
                connection = await self._deployments[deployment][type][
                    entity_id
                ].remove_connection(address)
                if not self._deployments[deployment][type][entity_id].has_connections():
                    del self._deployments[deployment][type][entity_id]
                return connection
            return None

    def __init__(
        self,
        logger: Optional[JinaLogger] = None,
        compression: str = 'NoCompression',
        metrics_registry: Optional['CollectorRegistry'] = None,
    ):
        self._logger = logger or JinaLogger(self.__class__.__name__)

        GRPC_COMPRESSION_MAP = {
            'nocompression': grpc.Compression.NoCompression,
            'gzip': grpc.Compression.Gzip,
            'deflate': grpc.Compression.Deflate,
        }
        if compression.lower() not in GRPC_COMPRESSION_MAP:
            import warnings

            warnings.warn(
                message=f'Your compression "{compression}" is not supported. Supported '
                f'algorithms are `Gzip`, `Deflate` and `NoCompression`. NoCompression will be used as '
                f'default'
            )
        self.compression = GRPC_COMPRESSION_MAP.get(
            compression.lower(), grpc.Compression.NoCompression
        )

        if metrics_registry:
            with ImportExtensions(
                required=True,
                help_text='You need to install the `prometheus_client` to use the monitoring functionality of jina',
            ):
                from prometheus_client import Summary

            self._summary_time = Summary(
                'sending_request_seconds',
                'Time spent between sending a request to the Pod and receiving the response',
                registry=metrics_registry,
                namespace='jina',
            ).time()
        else:
            self._summary_time = contextlib.nullcontext()

        self._connections = self._ConnectionPoolMap(self._logger, self._summary_time)

    def send_request(
        self,
        request: Request,
        deployment: str,
        head: bool = False,
        shard_id: Optional[int] = None,
        polling_type: PollingType = PollingType.ANY,
        endpoint: Optional[str] = None,
        timeout: Optional[float] = None,
    ) -> List[asyncio.Task]:
        """Send a single message to target via one or all of the pooled connections, depending on polling_type.

        Convenience wrapper around send_requests.

        :param request: a single request to send
        :param deployment: name of the Jina deployment to send the message to
        :param head: If True it is sent to the head, otherwise to the worker pods
        :param shard_id: Send to a specific shard of the deployment, ignored for polling ALL
        :param polling_type: defines if the message should be sent to any or all pooled connections for the target
        :param endpoint: endpoint to target with the request
        :param timeout: timeout for sending the requests
        :return: list of asyncio.Task items for each send call
        """
        return self.send_requests(
            requests=[request],
            deployment=deployment,
            head=head,
            shard_id=shard_id,
            polling_type=polling_type,
            endpoint=endpoint,
            timeout=timeout,
        )

    def send_requests(
        self,
        requests: List[Request],
        deployment: str,
        head: bool = False,
        shard_id: Optional[int] = None,
        polling_type: PollingType = PollingType.ANY,
        endpoint: Optional[str] = None,
        timeout: Optional[float] = None,
    ) -> List[asyncio.Task]:
        """Send requests to target via one or all of the pooled connections, depending on polling_type

        :param requests: requests (DataRequest/ControlRequest) to send
        :param deployment: name of the Jina deployment to send the requests to
        :param head: If True they are sent to the head, otherwise to the worker pods
        :param shard_id: Send to a specific shard of the deployment, ignored for polling ALL
        :param polling_type: defines if the requests should be sent to any or all pooled connections for the target
        :param endpoint: endpoint to target with the requests
        :param timeout: timeout for sending the requests
        :return: list of asyncio.Task items for each send call
        """
        results = []
        connections = []
        if polling_type == PollingType.ANY:
            connection_list = self._connections.get_replicas(deployment, head, shard_id)
            if connection_list:
                connections.append(connection_list.get_next_connection())
        elif polling_type == PollingType.ALL:
            connection_lists = self._connections.get_replicas_all_shards(deployment)
            for connection_list in connection_lists:
                connections.append(connection_list.get_next_connection())
        else:
            raise ValueError(f'Unsupported polling type {polling_type}')

        for connection in connections:
            task = self._send_requests(requests, connection, endpoint, timeout=timeout)
            results.append(task)

        return results
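
    # Editor's note: PollingType.ANY sends to a single replica (chosen round
    # robin), while PollingType.ALL fans out to one replica of every shard.
    # A sketch, assuming a pool that already knows a deployment named
    # 'encoder' (name hypothetical):
    #
    #   tasks = pool.send_requests(
    #       requests=[request], deployment='encoder', polling_type=PollingType.ALL
    #   )
    #   responses = await asyncio.gather(*tasks)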

    def send_discover_endpoint(
        self,
        deployment: str,
        head: bool = True,
        shard_id: Optional[int] = None,
        timeout: Optional[float] = None,
    ) -> asyncio.Task:
        """Sends a discover endpoint call to target.

        :param deployment: name of the Jina deployment to send the request to
        :param head: If True it is sent to the head, otherwise to the worker pods
        :param shard_id: Send to a specific shard of the deployment, ignored for polling ALL
        :param timeout: timeout for sending the requests
        :return: asyncio.Task representing the send call
        """
        connection = None
        connection_list = self._connections.get_replicas(
            deployment, head, shard_id, False
        )
        if connection_list:
            connection = connection_list.get_next_connection()

        return self._send_discover_endpoint(connection, timeout=timeout)

    def send_request_once(
        self,
        request: Request,
        deployment: str,
        head: bool = False,
        shard_id: Optional[int] = None,
        timeout: Optional[float] = None,
    ) -> asyncio.Task:
        """Send msg to target via only one of the pooled connections

        :param request: request to send
        :param deployment: name of the Jina deployment to send the message to
        :param head: If True it is sent to the head, otherwise to the worker pods
        :param shard_id: Send to a specific shard of the deployment, ignored for polling ALL
        :param timeout: timeout for sending the requests
        :return: asyncio.Task representing the send call
        """
        return self.send_requests_once(
            [request],
            deployment=deployment,
            head=head,
            shard_id=shard_id,
            timeout=timeout,
        )

    def send_requests_once(
        self,
        requests: List[Request],
        deployment: str,
        head: bool = False,
        shard_id: Optional[int] = None,
        endpoint: Optional[str] = None,
        timeout: Optional[float] = None,
    ) -> asyncio.Task:
        """Send requests to target via only one of the pooled connections

        :param requests: requests to send
        :param deployment: name of the Jina deployment to send the requests to
        :param head: If True they are sent to the head, otherwise to the worker pods
        :param shard_id: Send to a specific shard of the deployment, ignored for polling ALL
        :param endpoint: endpoint to target with the requests
        :param timeout: timeout for sending the requests
        :return: asyncio.Task representing the send call
        """
        replicas = self._connections.get_replicas(deployment, head, shard_id)
        if replicas:
            connection = replicas.get_next_connection()
            return self._send_requests(requests, connection, endpoint, timeout=timeout)
        else:
            self._logger.debug(
                f'no available connections for deployment {deployment} and shard {shard_id}'
            )
            return None

    def add_connection(
        self,
        deployment: str,
        address: str,
        head: bool = False,
        shard_id: Optional[int] = None,
    ):
        """
        Adds a connection for a deployment to this connection pool

        :param deployment: The deployment the connection belongs to, like 'encoder'
        :param address: Address used for the grpc connection, format is <host>:<port>
        :param head: True if the connection is for a head
        :param shard_id: Optional parameter to indicate this connection belongs to a shard, ignored for heads
        """
        if head:
            self._connections.add_head(deployment, address, 0)
        else:
            if shard_id is None:
                shard_id = 0
            self._connections.add_replica(deployment, shard_id, address)

    async def remove_connection(
        self,
        deployment: str,
        address: str,
        head: bool = False,
        shard_id: Optional[int] = None,
    ):
        """
        Removes a connection to a deployment

        :param deployment: The deployment the connection belongs to, like 'encoder'
        :param address: Address used for the grpc connection, format is <host>:<port>
        :param head: True if the connection is for a head
        :param shard_id: Optional parameter to indicate this connection belongs to a shard, ignored for heads
        :return: The removed connection, None if it did not exist
        """
        if head:
            return await self._connections.remove_head(deployment, address)
        else:
            if shard_id is None:
                shard_id = 0
            return await self._connections.remove_replica(deployment, address, shard_id)

    def start(self):
        """
        Starts the connection pool
        """
        pass

    async def close(self):
        """
        Closes the connection pool
        """
        await self._connections.close()
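
    # Editor's note: a sketch of the full pool lifecycle, assuming an Executor
    # is reachable at the hypothetical address below:
    #
    #   pool = GrpcConnectionPool()
    #   pool.add_connection(deployment='encoder', address='localhost:12345')
    #   response, metadata = await pool.send_request_once(request, deployment='encoder')
    #   await pool.close()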

    def _send_requests(
        self,
        requests: List[Request],
        connection: ConnectionStubs,
        endpoint: Optional[str] = None,
        timeout: Optional[float] = None,
    ) -> asyncio.Task:
        # this wraps the awaitable object from grpc as a coroutine so it can be used as a task
        # the grpc call function is not a coroutine but some _AioCall
        async def task_wrapper():
            metadata = (('endpoint', endpoint),) if endpoint else None
            for i in range(3):
                try:
                    return await connection.send_requests(
                        requests=requests,
                        metadata=metadata,
                        compression=self.compression,
                        timeout=timeout,
                    )
                except AioRpcError as e:
                    if i == 2:
                        self._logger.debug('GRPC call failed, retries exhausted')
                        raise
                    else:
                        self._logger.debug(
                            f'GRPC call failed with code {e.code()}, retry attempt {i + 1}/3'
                        )

        return asyncio.create_task(task_wrapper())

    def _send_discover_endpoint(
        self,
        connection: ConnectionStubs,
        timeout: Optional[float] = None,
    ) -> asyncio.Task:
        # this wraps the awaitable object from grpc as a coroutine so it can be used as a task
        # the grpc call function is not a coroutine but some _AioCall
        async def task_wrapper():
            for i in range(3):
                try:
                    return await connection.send_discover_endpoint(
                        timeout=timeout,
                    )
                except AioRpcError as e:
                    # connection failures and cancelled requests should be retried
                    # all other cases should not be retried and will be raised immediately
                    # connection failures have the code grpc.StatusCode.UNAVAILABLE
                    # cancelled requests have the code grpc.StatusCode.CANCELLED
                    # requests usually get cancelled when the server shuts down
                    # retries for cancelled requests will hit another replica in K8s
                    if (
                        e.code() != grpc.StatusCode.UNAVAILABLE
                        and e.code() != grpc.StatusCode.CANCELLED
                    ):
                        raise
                    elif e.code() == grpc.StatusCode.UNAVAILABLE and i == 2:
                        self._logger.debug('GRPC call failed, retries exhausted')
                        raise
                    else:
                        self._logger.debug(
                            f'GRPC call failed with code {e.code()}, retry attempt {i + 1}/3'
                        )
                except AttributeError:
                    # in gateway-to-gateway communication, the gateway does not expose this
                    # endpoint. So return the default endpoint, which is treated as
                    # "all endpoints are valid"
                    from jina import __default_endpoint__

                    ep = jina_pb2.EndpointsProto()
                    ep.endpoints.extend([__default_endpoint__])
                    return ep, None

        return asyncio.create_task(task_wrapper())

    @staticmethod
    def get_grpc_channel(
        address: str,
        options: Optional[list] = None,
        asyncio: bool = False,
        tls: bool = False,
        root_certificates: Optional[str] = None,
    ) -> grpc.Channel:
        """
        Creates a grpc channel to the given address

        :param address: The address to connect to, format is <host>:<port>
        :param options: A list of options to pass to the grpc channel
        :param asyncio: If True, use the asyncio implementation of the grpc channel
        :param tls: If True, use tls encryption for the grpc channel
        :param root_certificates: The PEM-encoded root certificates for tls, only used if tls is True
        :return: A grpc channel or an asyncio channel
        """
        secure_channel = grpc.secure_channel
        insecure_channel = grpc.insecure_channel
        if asyncio:
            secure_channel = grpc.aio.secure_channel
            insecure_channel = grpc.aio.insecure_channel
        if options is None:
            options = GrpcConnectionPool.get_default_grpc_options()

        if tls:
            credentials = grpc.ssl_channel_credentials(
                root_certificates=root_certificates
            )
            return secure_channel(address, credentials, options)

        return insecure_channel(address, options)
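
    # Editor's note: a sketch of creating a TLS-secured asyncio channel. Note
    # that grpc.ssl_channel_credentials() expects the PEM-encoded certificate
    # bytes, so a certificate file (path hypothetical) has to be read first:
    #
    #   with open('/path/to/ca.pem', 'rb') as f:
    #       ca = f.read()
    #   channel = GrpcConnectionPool.get_grpc_channel(
    #       'example.com:443', asyncio=True, tls=True, root_certificates=ca
    #   )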

    @staticmethod
    def activate_worker_sync(
        worker_host: str,
        worker_port: int,
        target_head: str,
        shard_id: Optional[int] = None,
    ) -> ControlRequest:
        """
        Register a given worker with a head by sending an activate request

        :param worker_host: the host address of the worker
        :param worker_port: the port of the worker
        :param target_head: address of the head to send the activate request to
        :param shard_id: id of the shard the worker belongs to
        :returns: the response request
        """
        activate_request = ControlRequest(command='ACTIVATE')
        activate_request.add_related_entity(
            'worker', worker_host, worker_port, shard_id
        )
        if os.name != 'nt':
            os.unsetenv('http_proxy')
            os.unsetenv('https_proxy')

        return GrpcConnectionPool.send_request_sync(activate_request, target_head)

    @staticmethod
    async def activate_worker(
        worker_host: str,
        worker_port: int,
        target_head: str,
        shard_id: Optional[int] = None,
    ) -> ControlRequest:
        """
        Register a given worker with a head by sending an activate request

        :param worker_host: the host address of the worker
        :param worker_port: the port of the worker
        :param target_head: address of the head to send the activate request to
        :param shard_id: id of the shard the worker belongs to
        :returns: the response request
        """
        activate_request = ControlRequest(command='ACTIVATE')
        activate_request.add_related_entity(
            'worker', worker_host, worker_port, shard_id
        )
        return await GrpcConnectionPool.send_request_async(
            activate_request, target_head
        )

    @staticmethod
    async def deactivate_worker(
        worker_host: str,
        worker_port: int,
        target_head: str,
        shard_id: Optional[int] = None,
    ) -> ControlRequest:
        """
        Remove a given worker from a head by sending a deactivate request

        :param worker_host: the host address of the worker
        :param worker_port: the port of the worker
        :param target_head: address of the head to send the deactivate request to
        :param shard_id: id of the shard the worker belongs to
        :returns: the response request
        """
        deactivate_request = ControlRequest(command='DEACTIVATE')
        deactivate_request.add_related_entity(
            'worker', worker_host, worker_port, shard_id
        )
        return await GrpcConnectionPool.send_request_async(
            deactivate_request, target_head
        )
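
    # Editor's note: a sketch of the activation handshake. Once a worker is up,
    # it is registered with its head so the head starts routing traffic to it
    # (addresses and shard id hypothetical):
    #
    #   await GrpcConnectionPool.activate_worker(
    #       worker_host='10.0.0.2', worker_port=8081,
    #       target_head='10.0.0.1:8080', shard_id=0,
    #   )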

    @staticmethod
    def send_request_sync(
        request: Request,
        target: str,
        timeout=100.0,
        tls=False,
        root_certificates: Optional[str] = None,
        endpoint: Optional[str] = None,
    ) -> Request:
        """
        Sends a request synchronously to the target via grpc

        :param request: the request to send
        :param target: where to send the request to, like 127.0.0.1:8080
        :param timeout: timeout for the send
        :param tls: if True, use tls encryption for the grpc channel
        :param root_certificates: the PEM-encoded root certificates for tls, only used if tls is True
        :param endpoint: endpoint to target with the request
        :returns: the response request
        """
        for i in range(3):
            try:
                with GrpcConnectionPool.get_grpc_channel(
                    target,
                    tls=tls,
                    root_certificates=root_certificates,
                ) as channel:
                    if type(request) == DataRequest:
                        metadata = (('endpoint', endpoint),) if endpoint else None
                        stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel)
                        response, call = stub.process_single_data.with_call(
                            request,
                            timeout=timeout,
                            metadata=metadata,
                        )
                    elif type(request) == ControlRequest:
                        stub = jina_pb2_grpc.JinaControlRequestRPCStub(channel)
                        response = stub.process_control(request, timeout=timeout)
                    return response
            except grpc.RpcError as e:
                if e.code() != grpc.StatusCode.UNAVAILABLE or i == 2:
                    raise
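
    # Editor's note: a sketch of a one-off synchronous call, e.g. probing a pod
    # with a ControlRequest at a hypothetical address (the 'STATUS' command name
    # is an assumption):
    #
    #   response = GrpcConnectionPool.send_request_sync(
    #       ControlRequest(command='STATUS'), target='localhost:12345'
    #   )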

    @staticmethod
    def send_requests_sync(
        requests: List[Request],
        target: str,
        timeout=100.0,
        tls=False,
        root_certificates: Optional[str] = None,
        endpoint: Optional[str] = None,
    ) -> Request:
        """
        Sends a list of requests synchronously to the target via grpc

        :param requests: the requests to send
        :param target: where to send the requests to, like 127.0.0.1:8080
        :param timeout: timeout for the send
        :param tls: if True, use tls for the grpc channel
        :param root_certificates: the PEM-encoded root certificates for tls, only used if tls is True
        :param endpoint: endpoint to target with the requests
        :returns: the response request
        """
        for i in range(3):
            try:
                with GrpcConnectionPool.get_grpc_channel(
                    target,
                    tls=tls,
                    root_certificates=root_certificates,
                ) as channel:
                    metadata = (('endpoint', endpoint),) if endpoint else None
                    stub = jina_pb2_grpc.JinaDataRequestRPCStub(channel)
                    response, call = stub.process_data.with_call(
                        requests,
                        timeout=timeout,
                        metadata=metadata,
                    )
                    return response
            except grpc.RpcError as e:
                if e.code() != grpc.StatusCode.UNAVAILABLE or i == 2:
                    raise

    @staticmethod
    def get_default_grpc_options():
        """
        Returns a list of default options used for creating grpc channels.
        Documentation is here https://github.com/grpc/grpc/blob/master/include/grpc/impl/codegen/grpc_types.h

        :returns: list of tuples defining grpc parameters
        """
        return [
            ('grpc.max_send_message_length', -1),
            ('grpc.max_receive_message_length', -1),
        ]
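
    # Editor's note: -1 lifts gRPC's default cap on message sizes (4 MB for
    # receiving), so arbitrarily large requests and responses can pass through.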

    @staticmethod
    async def send_request_async(
        request: Request,
        target: str,
        timeout: float = 1.0,
        tls: bool = False,
        root_certificates: Optional[str] = None,
    ) -> Request:
        """
        Sends a request asynchronously to the target via grpc

        :param request: the request to send
        :param target: where to send the request to, like 127.0.0.1:8080
        :param timeout: timeout for the send
        :param tls: if True, use tls for the grpc channel
        :param root_certificates: the PEM-encoded root certificates for tls, only used if tls is True
        :returns: the response request
        """
        async with GrpcConnectionPool.get_grpc_channel(
            target,
            asyncio=True,
            tls=tls,
            root_certificates=root_certificates,
        ) as channel:
            if type(request) == DataRequest:
                stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel)
                return await stub.process_single_data(request, timeout=timeout)
            elif type(request) == ControlRequest:
                stub = jina_pb2_grpc.JinaControlRequestRPCStub(channel)
                return await stub.process_control(request, timeout=timeout)

    @staticmethod
    def create_async_channel_stub(
        address, tls=False, root_certificates: Optional[str] = None, summary=None
    ) -> Tuple[ConnectionStubs, grpc.aio.Channel]:
        """
        Creates an async GRPC Channel. This channel has to be closed eventually!

        :param address: the address to create the connection to, like 127.0.0.1:8080
        :param tls: if True, use tls for the grpc channel
        :param root_certificates: the PEM-encoded root certificates for tls, only used if tls is True
        :param summary: Optional Prometheus summary object
        :returns: DataRequest/ControlRequest stubs and an async grpc channel
        """
        channel = GrpcConnectionPool.get_grpc_channel(
            address,
            asyncio=True,
            tls=tls,
            root_certificates=root_certificates,
        )

        return (
            GrpcConnectionPool.ConnectionStubs(address, channel, summary),
            channel,
        )

    @staticmethod
    async def get_available_services(channel) -> List[str]:
        """
        Lists available services by name, exposed at target address

        :param channel: the channel to use
        :returns: List of services offered
        """
        reflection_stub = ServerReflectionStub(channel)
        response = reflection_stub.ServerReflectionInfo(
            iter([ServerReflectionRequest(list_services="")])
        )
        service_names = []
        async for res in response:
            service_names.append(
                [
                    service.name
                    for service in res.list_services_response.service
                    if service.name != 'grpc.reflection.v1alpha.ServerReflection'
                ]
            )
        return service_names[0]
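
# Editor's note: a minimal sketch (not part of the original module) of discovering
# the services an Executor exposes via the reflection helper above. The address is
# hypothetical and must point to a reflection-enabled gRPC server.
async def _example_list_services():
    channel = GrpcConnectionPool.get_grpc_channel('localhost:12345', asyncio=True)
    try:
        services = await GrpcConnectionPool.get_available_services(channel)
        print(services)  # e.g. ['jina.JinaDataRequestRPC', 'jina.JinaRPC', ...]
    finally:
        await channel.close()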

def in_docker():
    """
    Checks if the current process is running inside Docker

    :return: True if the current process is running inside Docker
    """
    path = '/proc/self/cgroup'
    if os.path.exists('/.dockerenv'):
        return True
    if os.path.isfile(path):
        with open(path) as file:
            return any('docker' in line for line in file)
    return False

def host_is_local(hostname):
    """
    Check if hostname points to localhost

    :param hostname: host to check
    :return: True if hostname means localhost, False otherwise
    """
    import socket

    fqn = socket.getfqdn(hostname)
    if fqn in ('localhost', '0.0.0.0') or hostname == '0.0.0.0':
        return True

    try:
        return ipaddress.ip_address(hostname).is_loopback
    except ValueError:
        return False
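
# Editor's note: a sketch (not part of the original module) of the two helpers
# above on representative inputs. Results for real hostnames depend on local
# DNS configuration.
def _example_host_checks():
    print(host_is_local('127.0.0.1'))    # True: loopback address
    print(host_is_local('0.0.0.0'))      # True: treated as a local bind address
    print(host_is_local('example.com'))  # False: neither localhost nor loopback
    print(in_docker())                   # True only inside a Docker container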