__copyright__ = "Copyright (c) 2020 Jina AI Limited. All rights reserved."
__license__ = "Apache-2.0"
import argparse
import asyncio
import os
import tempfile
import time
from typing import List, Callable, Union, Tuple, Optional
import zmq
import zmq.asyncio
from zmq.eventloop.zmqstream import ZMQStream
from zmq.ssh import tunnel_connection
from ... import __default_host__, Request
from ...enums import SocketType
from ...helper import colored, random_identity, get_readable_size, get_or_reuse_loop
from ...importer import ImportExtensions
from ...logging import default_logger, profile_logger, JinaLogger
from ...types.message import Message
from ...types.message.common import ControlMessage
class Zmqlet:
    """A `Zmqlet` object can send/receive data to/from ZeroMQ socket and invoke callback function. It
    has three sockets for input, output and control.

    :param args: the parsed arguments from the CLI
    :param logger: the logger to use; ``default_logger`` is used when it is ``None``
    :param ctrl_addr: control address; when omitted it is derived from ``args`` via :meth:`get_ctrl_address`

    .. warning::
        Starting from v0.3.6, :class:`ZmqStreamlet` replaces :class:`Zmqlet` as one of the key components in :class:`jina.peapods.runtimes.zmq.zed.ZEDRuntime`.
        It requires :mod:`tornado` and :mod:`uvloop` to be installed.
    """

    def __init__(
        self,
        args: 'argparse.Namespace',
        logger: Optional['JinaLogger'] = None,
        ctrl_addr: Optional[str] = None,
    ):
        self.args = args
        self.identity = random_identity()
        self.name = args.name or self.__class__.__name__
        # `_init_sockets` and `print_stats` log unconditionally, so a missing
        # logger must not be left as `None`
        self.logger = logger if logger is not None else default_logger
        # the whole CLI namespace is forwarded as keyword arguments to the send/recv helpers
        self.send_recv_kwargs = vars(args)
        if ctrl_addr:
            self.ctrl_addr = ctrl_addr
            self.ctrl_with_ipc = self.ctrl_addr.startswith('ipc://')
        else:
            self.ctrl_addr, self.ctrl_with_ipc = self.get_ctrl_address(
                args.host, args.port_ctrl, args.ctrl_with_ipc
            )
        self.bytes_sent = 0
        self.bytes_recv = 0
        self.msg_recv = 0
        self.msg_sent = 0
        self.is_closed = False
        self.opened_socks = []  # this must be here for `close()`
        self.ctx, self.in_sock, self.out_sock, self.ctrl_sock = self._init_sockets()
        self._register_pollin()
        self.opened_socks.extend([self.in_sock, self.out_sock, self.ctrl_sock])
        if self.in_sock_type == zmq.DEALER:
            # announce to the upstream router that this dealer can take work
            self._send_idle_to_router()

    def _register_pollin(self):
        """Register :attr:`in_sock`, :attr:`ctrl_sock` and :attr:`out_sock` (if :attr:`out_sock_type` is zmq.ROUTER) in poller."""
        self.poller = zmq.Poller()
        self.poller.register(self.in_sock, zmq.POLLIN)
        self.poller.register(self.ctrl_sock, zmq.POLLIN)
        if self.out_sock_type == zmq.ROUTER:
            self.poller.register(self.out_sock, zmq.POLLIN)

    def pause_pollin(self):
        """Remove :attr:`in_sock` from the poller """
        self.poller.unregister(self.in_sock)

    def resume_pollin(self):
        """Put :attr:`in_sock` back to the poller (for incoming messages only)."""
        # `Poller.register` defaults to POLLIN | POLLOUT; an always-writable socket
        # would then make `poll` return immediately without the socket being readable.
        self.poller.register(self.in_sock, zmq.POLLIN)

    @staticmethod
    def get_ctrl_address(
        host: Optional[str], port_ctrl: Optional[str], ctrl_with_ipc: bool
    ) -> Tuple[str, bool]:
        """Get the address of the control socket

        :param host: the host in the arguments
        :param port_ctrl: the control port
        :param ctrl_with_ipc: a bool of whether using IPC protocol for controlling
        :return: A tuple of two pieces:

            - a string of control address
            - a bool of whether using IPC protocol for controlling
        """
        # IPC is never used on Windows (os.name == 'nt'); TCP is used instead
        ctrl_with_ipc = (os.name != 'nt') and ctrl_with_ipc
        if ctrl_with_ipc:
            return _get_random_ipc(), ctrl_with_ipc
        # a host such as ``user@hostname`` keeps only the part after the last '@';
        # a host without '@' is returned unchanged by the split
        host_out = host.split('@')[-1]
        return f'tcp://{host_out}:{port_ctrl}', ctrl_with_ipc

    def _pull(self, interval: int = 1):
        """Poll once and return the first readable socket, or ``None``.

        Priority is ctrl_sock > out_sock > in_sock; out_sock is only polled when it
        was registered (ROUTER output, which receives the dealers' idle status).

        :param interval: poll timeout passed to :meth:`zmq.Poller.poll` (milliseconds)
        :return: the readable socket, or ``None`` if none became readable
        """
        events = dict(self.poller.poll(interval))
        for sock in (self.ctrl_sock, self.out_sock, self.in_sock):
            # use a bitmask so that a combined POLLIN|POLLOUT event still counts as readable
            if events.get(sock, 0) & zmq.POLLIN:
                return sock
        return None

    def _close_sockets(self):
        """Close input, output and control sockets of this `Zmqlet`. """
        for k in self.opened_socks:
            k.close()

    def _init_sockets(self) -> Tuple:
        """Initialize all sockets and the ZMQ context.

        :return: A tuple of four pieces:

            - ZMQ context
            - the input socket
            - the output socket
            - the control socket
        :raises zmq.error.ZMQError: if any socket cannot be set up; everything
            created so far is released before re-raising
        """
        ctx = self._get_zmq_ctx()
        ctx.setsockopt(zmq.LINGER, 0)
        self.logger.debug('setting up sockets...')
        try:
            if self.ctrl_with_ipc:
                ctrl_sock, ctrl_addr = _init_socket(
                    ctx,
                    self.ctrl_addr,
                    None,
                    SocketType.PAIR_BIND,
                    use_ipc=self.ctrl_with_ipc,
                )
            else:
                ctrl_sock, ctrl_addr = _init_socket(
                    ctx, __default_host__, self.args.port_ctrl, SocketType.PAIR_BIND
                )
            self.logger.debug(f'control over {colored(ctrl_addr, "yellow")}')
            in_sock, in_addr = _init_socket(
                ctx,
                self.args.host_in,
                self.args.port_in,
                self.args.socket_in,
                self.identity,
                ssh_server=self.args.ssh_server,
                ssh_keyfile=self.args.ssh_keyfile,
                ssh_password=self.args.ssh_password,
            )
            self.logger.debug(
                f'input {self.args.host_in}:{colored(self.args.port_in, "yellow")}'
            )
            out_sock, out_addr = _init_socket(
                ctx,
                self.args.host_out,
                self.args.port_out,
                self.args.socket_out,
                self.identity,
                ssh_server=self.args.ssh_server,
                ssh_keyfile=self.args.ssh_keyfile,
                ssh_password=self.args.ssh_password,
            )
            self.logger.debug(
                f'output {self.args.host_out}:{colored(self.args.port_out, "yellow")}'
            )
            self.logger.info(
                f'input {colored(in_addr, "yellow")} ({self.args.socket_in.name}) '
                f'output {colored(out_addr, "yellow")} ({self.args.socket_out.name}) '
                f'control over {colored(ctrl_addr, "yellow")} ({SocketType.PAIR_BIND.name})'
            )
            self.in_sock_type = in_sock.type
            self.out_sock_type = out_sock.type
            self.ctrl_sock_type = ctrl_sock.type
            return ctx, in_sock, out_sock, ctrl_sock
        except zmq.error.ZMQError:
            # `close()` cannot reach these sockets or `ctx` yet (they are attached to
            # `self` only after this method returns). `destroy` closes every socket
            # created from `ctx` and then terminates it, so nothing is leaked.
            ctx.destroy(linger=0)
            self.close()
            raise

    def _get_zmq_ctx(self):
        return zmq.Context()

    def __enter__(self):
        # no handshake delay is needed at the Pod level, it is only required for the gateway
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """Close all sockets and shutdown the ZMQ context associated to this `Zmqlet`.

        .. note::
            This method is idempotent.
        """
        if not self.is_closed:
            self.is_closed = True
            self._close_sockets()
            if hasattr(self, 'ctx'):
                self.ctx.term()
            self.print_stats()

    def print_stats(self):
        """Print out the network stats of itself """
        self.logger.info(
            f'#sent: {self.msg_sent} '
            f'#recv: {self.msg_recv} '
            f'sent_size: {get_readable_size(self.bytes_sent)} '
            f'recv_size: {get_readable_size(self.bytes_recv)}'
        )
        profile_logger.info(
            {
                'msg_sent': self.msg_sent,
                'msg_recv': self.msg_recv,
                'bytes_sent': self.bytes_sent,
                'bytes_recv': self.bytes_recv,
            }
        )

    def send_message(self, msg: 'Message'):
        """Send a message via the output socket

        Data requests go to :attr:`out_sock`, every other message goes to
        :attr:`ctrl_sock`.

        :param msg: the protobuf message to send
        """
        # choose output sock
        if msg.is_data_request:
            o_sock = self.out_sock
        else:
            o_sock = self.ctrl_sock
        self.bytes_sent += send_message(o_sock, msg, **self.send_recv_kwargs)
        self.msg_sent += 1
        if o_sock == self.out_sock and self.in_sock_type == zmq.DEALER:
            # a dealer reports itself idle again after forwarding a data request
            self._send_idle_to_router()

    def _send_control_to_router(self, command, raise_exception=False):
        """Send a control ``command`` with this dealer's identity over :attr:`in_sock`.

        :param command: the control command, e.g. 'IDLE' or 'CANCEL'
        :param raise_exception: forwarded to :func:`send_message`
        """
        msg = ControlMessage(command, pod_name=self.name, identity=self.identity)
        self.bytes_sent += send_message(
            self.in_sock, msg, raise_exception=raise_exception, **self.send_recv_kwargs
        )
        self.msg_sent += 1
        self.logger.debug(
            f'control message {command} with id {self.identity} is sent to the router'
        )

    def _send_idle_to_router(self):
        """Tell the upstream router this dealer is idle """
        self._send_control_to_router('IDLE')

    def _send_cancel_to_router(self, raise_exception=False):
        """
        Tell the upstream router this dealer is canceled

        :param raise_exception: if true: raise an exception which might occur during send, if false: log error
        """
        self._send_control_to_router('CANCEL', raise_exception)

    def recv_message(
        self, callback: Callable[['Message'], 'Message'] = None
    ) -> Optional['Message']:
        """Receive a protobuf message from the first readable socket (see :meth:`_pull`)

        :param callback: optional function applied to the received message; its
            return value is returned instead of the message
        :return: the received (or callback-processed) message, or ``None`` if no
            socket was readable within the poll interval
        """
        i_sock = self._pull()
        if i_sock is None:
            return None
        msg = recv_message(i_sock, **self.send_recv_kwargs)
        self.bytes_recv += msg.size
        self.msg_recv += 1
        return callback(msg) if callback else msg
class AsyncZmqlet(Zmqlet):
    """An async version of :class:`Zmqlet`.

    The :func:`send_message` and :func:`recv_message` work in the async manner.
    """

    def _get_zmq_ctx(self):
        return zmq.asyncio.Context()

    async def send_message(self, msg: 'Message', sleep: float = 0, **kwargs):
        """Send a protobuf message in async via the output socket

        :param msg: the protobuf message to send
        :param sleep: currently unused; kept for API compatibility (it was meant as a
            pause between two sends, in milliseconds)
        :param kwargs: keyword arguments (ignored)
        """
        try:
            num_bytes = await send_message_async(
                self.out_sock, msg, **self.send_recv_kwargs
            )
            # `send_message_async` returns None when the send failed or was cancelled,
            # which makes this addition raise the TypeError handled below
            self.bytes_sent += num_bytes
            self.msg_sent += 1
        except (asyncio.CancelledError, TypeError) as ex:
            self.logger.error(f'sending message error: {ex!r}, gateway cancelled?')

    async def recv_message(
        self, callback: Callable[['Message'], Union['Message', 'Request']] = None
    ) -> Optional[Union['Message', 'Request']]:
        """
        Receive a protobuf message in async manner.

        :param callback: optional function applied to the received message; its
            return value is returned instead of the message
        :return: the received (or callback-processed) message, or None in case of any error
        """
        try:
            msg = await recv_message_async(self.in_sock, **self.send_recv_kwargs)
            if msg is None:
                # nothing was received (the helper logged the cause); do not count it
                self.logger.error('Received message is empty.')
                return None
            self.msg_recv += 1
            self.bytes_recv += msg.size
            return callback(msg) if callback else msg
        except (asyncio.CancelledError, TypeError) as ex:
            self.logger.error(f'receiving message error: {ex!r}, gateway cancelled?')

    def __enter__(self):
        time.sleep(0.2)  # sleep a bit until handshake is done
        return self
class ZmqStreamlet(Zmqlet):
    """A :class:`ZmqStreamlet` object can send/receive data to/from ZeroMQ stream and invoke callback function. It
    has three sockets for input, output and control.

    .. warning::
        Starting from v0.3.6, :class:`ZmqStreamlet` replaces :class:`Zmqlet` as one of the key components in :class:`jina.peapods.runtime.BasePea`.
        It requires :mod:`tornado` and :mod:`uvloop` to be installed.
    """

    def _register_pollin(self):
        """Wrap :attr:`in_sock`, :attr:`out_sock` and :attr:`ctrl_sock` in :class:`ZMQStream` objects on the current tornado IOLoop."""
        with ImportExtensions(required=True):
            import tornado.ioloop
        get_or_reuse_loop()
        self.io_loop = tornado.ioloop.IOLoop.current()
        self.in_sock = ZMQStream(self.in_sock, self.io_loop)
        self.out_sock = ZMQStream(self.out_sock, self.io_loop)
        self.ctrl_sock = ZMQStream(self.ctrl_sock, self.io_loop)
        # input handling is switched on later by `start` / `resume_pollin`
        self.in_sock.stop_on_recv()

    def close(self):
        """Close all sockets and shutdown the ZMQ context associated to this `Zmqlet`.

        A DEALER input first sends CANCEL to its router (only on the first call).

        .. note::
            This method is idempotent.
        """
        if not self.is_closed:
            # `in_sock_type` is missing when `_init_sockets` failed before setting it
            if getattr(self, 'in_sock_type', None) == zmq.DEALER:
                try:
                    self._send_cancel_to_router(raise_exception=True)
                except zmq.error.ZMQError:
                    self.logger.info(
                        f'The dealer {self.name} can not unsubscribe from the router. '
                        f'In case the router is down this is expected.'
                    )
            # wait until the close signal is received
            time.sleep(0.01)
            for s in self.opened_socks:
                s.flush()
        super().close()
        if hasattr(self, 'io_loop'):
            try:
                self.io_loop.stop()
                # Replace handle events function, to skip
                # None event after sockets are closed.
                for stream in (self.in_sock, self.out_sock, self.ctrl_sock):
                    if hasattr(stream, '_handle_events'):
                        stream._handle_events = lambda *args, **kwargs: None
            except AttributeError as e:
                self.logger.error(f'failed to stop. {e!r}')

    def pause_pollin(self):
        """Stop receiving on :attr:`in_sock` """
        self.in_sock.stop_on_recv()

    def resume_pollin(self):
        """Resume receiving on :attr:`in_sock` with the callback installed by :meth:`start` """
        self.in_sock.on_recv(self._in_sock_callback)

    def start(self, callback: Callable[['Message'], 'Message']):
        """
        Install the receive callbacks and run the IOLoop until it is stopped; the loop is closed afterwards.

        :param callback: callback function to receive the protobuf message; a truthy
            return value is sent on via :meth:`send_message`
        """

        def _callback(msg, sock_type):
            msg = _parse_from_frames(sock_type, msg)
            self.bytes_recv += msg.size
            self.msg_recv += 1
            msg = callback(msg)
            if msg:
                self.send_message(msg)

        self._in_sock_callback = lambda x: _callback(x, self.in_sock_type)
        self.in_sock.on_recv(self._in_sock_callback)
        self.ctrl_sock.on_recv(lambda x: _callback(x, self.ctrl_sock_type))
        if self.out_sock_type == zmq.ROUTER:
            self.out_sock.on_recv(lambda x: _callback(x, self.out_sock_type))
        self.io_loop.start()
        self.io_loop.clear_current()
        self.io_loop.close(all_fds=True)
def send_ctrl_message(
    address: str, cmd: Union[str, Message], timeout: int
) -> Optional['Message']:
    """Send a control message to a specific address and wait for the response

    :param address: the socket address to send
    :param cmd: the control command to send; a plain string is wrapped in a :class:`ControlMessage`
    :param timeout: the waiting time (in ms) for sending and for the response
    :return: received message, or ``None`` if sending or receiving timed out
    :raises zmq.error.ZMQError: if sending fails for a reason other than a timeout
    """
    # we assume ControlMessage as default
    msg = ControlMessage(cmd) if isinstance(cmd, str) else cmd
    # control message is short, set a timeout and ask for quick response
    with zmq.Context() as ctx:
        ctx.setsockopt(zmq.LINGER, 0)
        sock, _ = _init_socket(ctx, address, None, SocketType.PAIR_CONNECT)
        try:
            # pass `timeout` by keyword: positionally it would land on `raise_exception`
            send_message(sock, msg, raise_exception=True, timeout=timeout)
            return recv_message(sock, timeout)
        except TimeoutError:
            return None
        finally:
            # close before the context exits, otherwise terminating it would block
            sock.close()
def send_message(
    sock: Union['zmq.Socket', 'ZMQStream'],
    msg: 'Message',
    raise_exception: bool = False,
    timeout: int = -1,
    **kwargs,
) -> int:
    """Send a protobuf message to a socket

    :param sock: the target socket to send
    :param msg: the protobuf message
    :param raise_exception: if true: re-raise a :class:`zmq.error.ZMQError` raised during send,
        if false: log it and return 0
    :param timeout: waiting time in milliseconds for sending (``zmq.SNDTIMEO``);
        values <= 0 block until the message is sent
    :param kwargs: keyword arguments (ignored)
    :return: the size (in bytes) of the sent message, or 0 if sending failed and
        ``raise_exception`` is false
    :raises TimeoutError: if the message could not be sent within ``timeout``
    """
    num_bytes = 0
    try:
        _prep_send_socket(sock, timeout)
        sock.send_multipart(msg.dump())
        num_bytes = msg.size
    except zmq.error.Again as ex:
        # `Again` is a ZMQError subclass, so it must be handled first
        raise TimeoutError(
            f'cannot send message to sock {sock} after timeout={timeout}ms, please check the following:'
            ' is the server still online? is the network broken? are "port" correct?'
        ) from ex
    except zmq.error.ZMQError as ex:
        if raise_exception:
            raise
        default_logger.critical(ex)
    finally:
        # restore blocking sends; the socket may already be closed, hence the guard
        try:
            sock.setsockopt(zmq.SNDTIMEO, -1)
        except zmq.error.ZMQError:
            pass
    return num_bytes
def _prep_send_socket(sock, timeout):
    # SNDTIMEO is expressed in milliseconds; anything non-positive maps to -1,
    # i.e. the send blocks until it succeeds
    snd_timeout = timeout if timeout > 0 else -1
    sock.setsockopt(zmq.SNDTIMEO, snd_timeout)
def _prep_recv_socket(sock, timeout):
    # RCVTIMEO is expressed in milliseconds; anything non-positive maps to -1,
    # i.e. the receive waits forever
    rcv_timeout = timeout if timeout > 0 else -1
    sock.setsockopt(zmq.RCVTIMEO, rcv_timeout)
async def send_message_async(
    sock: 'zmq.Socket', msg: 'Message', timeout: int = -1, **kwargs
) -> Optional[int]:
    """Send a protobuf message to a socket in async manner

    :param sock: the target socket to send
    :param msg: the protobuf message
    :param timeout: waiting time in milliseconds for sending (``zmq.SNDTIMEO``);
        values <= 0 block until the message is sent
    :param kwargs: keyword arguments (ignored)
    :return: the size (in bytes) of the sent message, or ``None`` if a
        :class:`zmq.error.ZMQError` occurred or the send was cancelled (both are logged)
    :raises TimeoutError: if the message could not be sent within ``timeout``
    """
    try:
        _prep_send_socket(sock, timeout)
        await sock.send_multipart(msg.dump())
        return msg.size
    except zmq.error.Again as ex:
        raise TimeoutError(
            f'cannot send message to sock {sock} after timeout={timeout}ms, please check the following:'
            ' is the server still online? is the network broken? are "port" correct? '
        ) from ex
    except zmq.error.ZMQError as ex:
        default_logger.critical(ex)
    except asyncio.CancelledError:
        # deliberately swallowed; the caller sees a `None` result
        default_logger.error('all gateway tasks are cancelled')
    finally:
        # restore blocking sends; the socket may already be closed, hence the guard
        try:
            sock.setsockopt(zmq.SNDTIMEO, -1)
        except zmq.error.ZMQError:
            pass
    return None
def recv_message(sock: 'zmq.Socket', timeout: int = -1, **kwargs) -> 'Message':
    """Receive a protobuf message from a socket

    :param sock: the socket to pull from
    :param timeout: max wait time for pulling in milliseconds (``zmq.RCVTIMEO``);
        values <= 0 mean wait forever
    :param kwargs: keyword arguments (ignored)
    :return: the received protobuf message
    :raises TimeoutError: if no message arrives within ``timeout``
    """
    try:
        _prep_recv_socket(sock, timeout)
        msg_data = sock.recv_multipart()
        return _parse_from_frames(sock.type, msg_data)
    except zmq.error.Again as ex:
        raise TimeoutError(
            f'no response from sock {sock} after timeout={timeout}ms, please check the following:'
            ' is the server still online? is the network broken? are "port" correct? '
        ) from ex
    finally:
        # restore blocking receives; guarded like the send path so that a closed
        # socket cannot replace the original exception with a new ZMQError
        try:
            sock.setsockopt(zmq.RCVTIMEO, -1)
        except zmq.error.ZMQError:
            pass
async def recv_message_async(
    sock: 'zmq.Socket', timeout: int = -1, **kwargs
) -> Optional['Message']:
    """Receive a protobuf message from a socket in async manner

    :param sock: the socket to pull from
    :param timeout: max wait time for pulling in milliseconds (``zmq.RCVTIMEO``);
        values <= 0 mean wait forever
    :param kwargs: keyword arguments (ignored)
    :return: the received protobuf message, or ``None`` if a
        :class:`zmq.error.ZMQError` occurred or the receive was cancelled (both are logged)
    :raises TimeoutError: if no message arrives within ``timeout``
    """
    try:
        _prep_recv_socket(sock, timeout)
        msg_data = await sock.recv_multipart()
        return _parse_from_frames(sock.type, msg_data)
    except zmq.error.Again as ex:
        raise TimeoutError(
            f'no response from sock {sock} after timeout={timeout}ms, please check the following:'
            ' is the server still online? is the network broken? are "port" correct? '
        ) from ex
    except zmq.error.ZMQError as ex:
        default_logger.critical(ex)
    except asyncio.CancelledError:
        # deliberately swallowed; the caller sees a `None` result
        default_logger.error('all gateway tasks are cancelled')
    finally:
        # restore blocking receives; the socket may already be closed, hence the guard
        try:
            sock.setsockopt(zmq.RCVTIMEO, -1)
        except zmq.error.ZMQError:
            pass
    return None
def _parse_from_frames(sock_type, frames: List[bytes]) -> 'Message':
    """
    Build :class:`Message` from a list of frames.

    After normalising for the socket type, the list of frames (length >= 3) has
    the following structure:

    - offset 0: the client id, can be empty
    - offset 1: is the offset 2 frame compressed
    - offset 2: the body of the serialized protobuf message

    The input list is not modified.

    :param sock_type: the recv socket type
    :param frames: list of bytes to parse from
    :return: a :class:`Message` object
    """
    if sock_type == zmq.DEALER:
        # dealer consumes the first part of the message as id, we need to prepend it back
        frames = [b' ', *frames]
    elif sock_type == zmq.ROUTER:
        # the router prepends the dealer id when receiving; drop it without
        # mutating the caller's list
        frames = frames[1:]
    return Message(frames[1], frames[2])
def _get_random_ipc() -> str:
"""
Get a random IPC address for control port
:return: random IPC address
"""
try:
tmp = os.environ['JINA_IPC_SOCK_TMP']
if not os.path.exists(tmp):
raise ValueError(
f'This directory for sockets ({tmp}) does not seems to exist.'
)
tmp = os.path.join(tmp, random_identity())
except KeyError:
tmp = tempfile.NamedTemporaryFile().name
return f'ipc://{tmp}'
def _init_socket(
    ctx: 'zmq.Context',
    host: str,
    port: Optional[int],
    socket_type: 'SocketType',
    identity: Optional[str] = None,
    use_ipc: bool = False,
    ssh_server: Optional[str] = None,
    ssh_keyfile: Optional[str] = None,
    ssh_password: Optional[str] = None,
) -> Tuple['zmq.Socket', str]:
    """Create a ZMQ socket of the given :class:`SocketType` and bind or connect it.

    :param ctx: the ZMQ context that owns the socket
    :param host: host to bind/connect to; a full address when ``port`` is None on
        CONNECT or when ``use_ipc`` is set on BIND
    :param port: TCP port; None binds to a random port (BIND) or uses ``host`` as-is (CONNECT)
    :param socket_type: the Jina socket type, deciding the ZMQ type and bind vs. connect
    :param identity: identity string, set on DEALER_CONNECT sockets
    :param use_ipc: for BIND sockets, bind to ``host`` directly (an ``ipc://`` address)
    :param ssh_server: if given, CONNECT sockets go through an SSH tunnel to this server
    :param ssh_keyfile: key file for the SSH tunnel
    :param ssh_password: password for the SSH tunnel
    :return: the socket and its last bound/connected endpoint
    :raises KeyError: if ``socket_type`` is not supported
    """
    zmq_type = {
        SocketType.PULL_BIND: zmq.PULL,
        SocketType.PULL_CONNECT: zmq.PULL,
        SocketType.SUB_BIND: zmq.SUB,
        SocketType.SUB_CONNECT: zmq.SUB,
        SocketType.PUB_BIND: zmq.PUB,
        SocketType.PUB_CONNECT: zmq.PUB,
        SocketType.PUSH_BIND: zmq.PUSH,
        SocketType.PUSH_CONNECT: zmq.PUSH,
        SocketType.PAIR_BIND: zmq.PAIR,
        SocketType.PAIR_CONNECT: zmq.PAIR,
        SocketType.ROUTER_BIND: zmq.ROUTER,
        SocketType.DEALER_CONNECT: zmq.DEALER,
    }[socket_type]
    sock = ctx.socket(zmq_type)
    try:
        sock.setsockopt(zmq.LINGER, 0)
        if socket_type == SocketType.DEALER_CONNECT:
            sock.set_string(zmq.IDENTITY, identity)
        if socket_type.is_bind:
            if use_ipc:
                sock.bind(host)
            else:
                # JEP2, if it is bind, then always bind to local
                if host != __default_host__:
                    default_logger.warning(
                        f'host is set from {host} to {__default_host__} as the socket is in BIND type'
                    )
                    host = __default_host__
                if port is None:
                    sock.bind_to_random_port(f'tcp://{host}')
                else:
                    try:
                        sock.bind(f'tcp://{host}:{port}')
                    except zmq.error.ZMQError:
                        default_logger.error(
                            f'error when binding port {port} to {host}, this port is occupied. '
                            f'If you are using Linux, try `lsof -i :{port}` to see which process '
                            f'occupies the port.'
                        )
                        raise
        else:
            address = host if port is None else f'tcp://{host}:{port}'
            # note that ssh only takes effect on CONNECT, not BIND
            # that means control socket setup does not need ssh
            if ssh_server:
                tunnel_connection(sock, address, ssh_server, ssh_keyfile, ssh_password)
            else:
                sock.connect(address)
        if socket_type in {SocketType.SUB_CONNECT, SocketType.SUB_BIND}:
            sock.subscribe('')  # An empty topic subscribes to all incoming messages
        return sock, sock.getsockopt_string(zmq.LAST_ENDPOINT)
    except Exception:
        # never hand back (or leak) a half-configured socket: an open socket would
        # also keep `ctx.term()` blocked
        sock.close()
        raise