Source code for jina.helper

__copyright__ = "Copyright (c) 2020 Jina AI Limited. All rights reserved."
__license__ = "Apache-2.0"

import asyncio
import functools
import json
import math
import os
import random
import re
import sys
import threading
import time
import uuid
import warnings
from argparse import ArgumentParser, Namespace
from contextlib import contextmanager
from datetime import datetime
from itertools import islice
from types import SimpleNamespace
from typing import (
    Tuple,
    Optional,
    Iterator,
    Any,
    Union,
    List,
    Dict,
    Set,
    Sequence,
    Iterable,
)
from urllib.request import Request, urlopen

import numpy as np

__all__ = [
    'batch_iterator',
    'parse_arg',
    'random_port',
    'random_identity',
    'random_uuid',
    'expand_env_var',
    'colored',
    'ArgNamespace',
    'is_valid_local_config_source',
    'cached_property',
    'is_url',
    'typename',
    'get_public_ip',
    'get_internal_ip',
    'convert_tuple_to_list',
    'run_async',
    'deprecated_alias',
]

from jina.excepts import NotSupportedError


def deprecated_alias(**aliases):
    """
    Build a decorator that redirects deprecated keyword arguments to their new names.

    Every keyword maps a deprecated argument name to a tuple ``(new_name, deprecate_level)``.
    With level 0 a :class:`DeprecationWarning` is emitted and the value is forwarded under
    the new name; with level 1 a :class:`NotSupportedError` is raised. Any other level leaves
    the keyword arguments untouched.

    For example:
        .. highlight:: python
        .. code-block:: python

            @deprecated_alias(input_fn=('inputs', 0), buffer=('input_fn', 0), callback=('on_done', 1), output_fn=('on_done', 1))

    :param aliases: maps aliases to new arguments
    :return: wrapper
    """

    def _rename_kwargs(func_name: str, kwargs, aliases):
        """
        Raise warnings or exceptions for deprecated arguments.

        :param func_name: Name of the function.
        :param kwargs: key word arguments from the function which is decorated (modified in place).
        :param aliases: kwargs with key as the deprecated arg name and value be a tuple, (new_name, deprecate_level).
        """
        for alias, new_arg in aliases.items():
            if not isinstance(new_arg, tuple):
                raise ValueError(
                    f'{new_arg} must be a tuple, with first element as the new name, '
                    f'second element as the deprecated level: 0 as warning, 1 as exception'
                )
            if alias in kwargs:
                new_name, dep_level = new_arg
                if new_name in kwargs:
                    raise NotSupportedError(
                        f'{func_name} received both {alias} and {new_name}'
                    )

                if dep_level == 0:
                    warnings.warn(
                        f'`{alias}` is renamed to `{new_name}` in `{func_name}()`, the usage of `{alias}` is '
                        f'deprecated and will be removed in the next version.',
                        DeprecationWarning,
                        # frames: _rename_kwargs -> wrapper -> caller of the decorated
                        # function; attribute the warning to that caller.
                        stacklevel=3,
                    )
                    kwargs[new_name] = kwargs.pop(alias)
                elif dep_level == 1:
                    raise NotSupportedError(f'{alias} has been renamed to `{new_name}`')

    def deco(f):
        """
        Set Decorator function.

        :param f: function the decorator is used for
        :return: wrapper
        """

        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            """
            Set wrapper function.

            :param args: wrapper arguments
            :param kwargs: wrapper key word arguments
            :return: result of renamed function.
            """
            _rename_kwargs(f.__name__, kwargs, aliases)
            return f(*args, **kwargs)

        return wrapper

    return deco
def get_readable_size(num_bytes: Union[int, float]) -> str:
    """
    Transform the bytes into readable value with different units (e.g. 1 KB, 20 MB, 30.1 GB).

    :param num_bytes: Number of bytes; truncated to an integer before formatting.
    :return: Human readable string representation.
    """
    num_bytes = int(num_bytes)
    if num_bytes < 1024:
        return f'{num_bytes} Bytes'
    elif num_bytes < 1024 ** 2:
        return f'{num_bytes / 1024:.1f} KB'
    elif num_bytes < 1024 ** 3:
        return f'{num_bytes / (1024 ** 2):.1f} MB'
    else:
        return f'{num_bytes / (1024 ** 3):.1f} GB'


def call_obj_fn(obj, fn: str):
    """
    Call the zero-argument method named ``fn`` on ``obj`` if it exists.

    Nothing happens when ``obj`` is ``None`` or has no attribute called ``fn``;
    the method's return value is discarded.

    :param obj: Target object.
    :param fn: Name of the method to call.
    """
    if obj is not None and hasattr(obj, fn):
        getattr(obj, fn)()


def touch_dir(base_dir: str) -> None:
    """
    Create a directory from given path if it doesn't exist.

    :param base_dir: Path of target path.
    """
    if not os.path.exists(base_dir):
        # exist_ok closes the window in which another process creates the
        # directory between the check above and this call.
        os.makedirs(base_dir, exist_ok=True)
def batch_iterator(
    data: Iterable[Any],
    batch_size: int,
    axis: int = 0,
    yield_slice: bool = False,
    yield_dict: bool = False,
) -> Iterator[Any]:
    """
    Get an iterator of batches of data.

    For example:
    .. highlight:: python
    .. code-block:: python

            for batch in batch_iterator(data, batch_size, split_over_axis, yield_slice=yield_slice):
                # Do something with batch

    :param data: Data source.
    :param batch_size: Size of one batch. A falsy or non-positive value yields ``data`` once, unchanged.
    :param axis: Determine which axis to iterate for np.ndarray data.
    :param yield_slice: For np.ndarray data, yield the tuple of slices selecting each batch
        instead of the sliced array itself.
    :param yield_dict: For plain iterables, build every batch with ``dict`` (items must be
        key/value pairs) instead of ``tuple``.
    :yield: data
    :return: An Iterator of batch data.
    """
    import numpy as np

    if not batch_size or batch_size <= 0:
        yield data
        return
    if isinstance(data, np.ndarray):
        _l = data.shape[axis]
        _d = data.ndim
        sl = [slice(None)] * _d
        if batch_size >= _l:
            if yield_slice:
                yield tuple(sl)
            else:
                yield data
            return
        for start in range(0, _l, batch_size):
            end = min(_l, start + batch_size)
            sl[axis] = slice(start, end)
            if yield_slice:
                yield tuple(sl)
            else:
                yield data[tuple(sl)]
    elif isinstance(data, Sequence):
        if batch_size >= len(data):
            yield data
            return
        for start in range(0, len(data), batch_size):
            yield data[start : start + batch_size]
    elif isinstance(data, Iterable):
        # The length is unknown here. Obtain ONE iterator up front: calling
        # islice() directly on a re-iterable container (set, dict, ...) would
        # restart from the beginning on every pass and never terminate.
        stream = iter(data)
        make_chunk = dict if yield_dict else tuple
        while True:
            chunk = make_chunk(islice(stream, batch_size))
            if not chunk:
                return
            yield chunk
    else:
        raise TypeError(f'unsupported type: {type(data)}')
def parse_arg(v: str) -> Optional[Union[bool, int, str, list, float]]:
    """
    Parse the arguments from string to `Union[bool, int, str, list, float]`.

    Rules, applied in order:

    - a value wrapped in single or double quotes is returned without the quotes;
    - a value wrapped in ``[...]`` is split on commas and each item parsed recursively
      (all brackets are stripped first, so nested lists are flattened);
    - otherwise ``int``, then ``float`` conversion is attempted;
    - an empty string becomes ``None``; ``true``/``false`` (any case) become booleans;
    - anything else is returned unchanged.

    :param v: The string of arguments
    :return: The parsed arguments list.
    """
    quoted = re.match(r'^[\'"](.*)[\'"]$', v)
    if quoted:
        return quoted.group(1)

    if v.startswith('[') and v.endswith(']'):
        inner = v.replace('[', '').replace(']', '').strip()
        if not inner:
            # ''.split(',') is [''], which used to turn '[]' into [None]
            return []
        return [parse_arg(item.strip()) for item in inner.split(',')]

    try:
        return int(v)  # parse int parameter
    except ValueError:
        pass
    try:
        return float(v)  # parse float parameter
    except ValueError:
        pass

    if len(v) == 0:
        # ignore it when the parameter is empty
        return None
    lowered = v.lower()
    if lowered == 'true':  # parse boolean parameter
        return True
    if lowered == 'false':
        return False
    return v
def countdown(t: int, reason: str = 'I am blocking this thread') -> None:
    """
    Print a once-per-second countdown on stdout, rewriting the same console line.

    For example:
        .. highlight:: python
        .. code-block:: python

            countdown(10, reason=colored('re-fetch access token', 'cyan', attrs=['bold', 'reverse']))

    :param t: Countdown time.
    :param reason: A string message of reason for this Countdown.
    """

    def _emit(text: str) -> None:
        sys.stdout.write(text)
        sys.stdout.flush()

    try:
        _emit('\n')
        while t > 0:
            t -= 1
            remaining = colored("%3d" % t, "yellow")
            _emit(f'\r⏳ {remaining}s left: {reason}')
            time.sleep(1)
        _emit('\n')
    except KeyboardInterrupt:
        # Ctrl-C ends the countdown early with a farewell message
        sys.stdout.write('no more patience? good bye!')


# Word pools for `random_name`: index 0 is picked first, index 1 second.
_random_names = (
    (
        'first', 'great', 'local', 'small', 'right',
        'large', 'young', 'early', 'major', 'clear',
        'black', 'whole', 'third', 'white', 'short',
        'human', 'royal', 'wrong', 'legal', 'final',
        'close', 'total', 'prime', 'happy', 'sorry',
        'basic', 'aware', 'ready', 'green', 'heavy',
        'extra', 'civil', 'chief', 'usual', 'front',
        'fresh', 'joint', 'alone', 'rural', 'light',
        'equal', 'quiet', 'quick', 'daily', 'urban',
        'upper', 'moral', 'vital', 'empty', 'brief',
    ),
    (
        'world', 'house', 'place', 'group', 'party',
        'money', 'point', 'state', 'night', 'water',
        'thing', 'order', 'power', 'court', 'level',
        'child', 'south', 'staff', 'woman', 'north',
        'sense', 'death', 'range', 'table', 'trade',
        'study', 'other', 'price', 'class', 'union',
        'value', 'paper', 'right', 'voice', 'stage',
        'light', 'march', 'board', 'month', 'music',
        'field', 'award', 'issue', 'basis', 'front',
        'heart', 'force', 'model', 'space', 'peter',
    ),
)


def random_name() -> str:
    """
    Pick one word from each pool in `_random_names` and join them with an underscore.

    :return: A Random name.
    """
    picks = [random.choice(pool) for pool in _random_names]
    return '_'.join(picks)
def random_port() -> Optional[int]:
    """
    Get a random available port number.

    When the environment variable ``JINA_RANDOM_PORTS`` is set, candidates are drawn in
    random order from ``[JINA_RANDOM_PORT_MIN, JINA_RANDOM_PORT_MAX]`` (defaults '49153'
    and '65535') until one can be bound. Otherwise port 0 is bound and the port number
    reported by the socket is returned.

    The probing socket is closed before returning, so the port is not reserved for the caller.

    :return: A random port.
    :raises OSError: if no port could be bound.
    """
    from contextlib import closing
    import socket

    # The original wrapped the bind in `multiprocessing.Lock()` / `threading.Lock()`
    # objects created on every call. A lock nobody else holds a reference to cannot
    # serialize anything, so they are dropped.
    def _get_port(port=0):
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
            try:
                s.bind(('', port))
                s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                return s.getsockname()[1]
            except OSError:
                return None

    if 'JINA_RANDOM_PORTS' in os.environ:
        min_port = int(os.environ.get('JINA_RANDOM_PORT_MIN', '49153'))
        max_port = int(os.environ.get('JINA_RANDOM_PORT_MAX', '65535'))
        for candidate in np.random.permutation(range(min_port, max_port + 1)):
            _port = int(candidate)  # numpy integer -> plain int
            if _get_port(_port) is not None:
                break
        else:
            raise OSError(
                f'Couldn\'t find an available port in [{min_port}, {max_port}].'
            )
    else:
        _port = _get_port()
        if _port is None:
            # previously fell through to int(None) and surfaced as a TypeError
            raise OSError('Couldn\'t bind to an ephemeral port.')

    return int(_port)
def random_identity(use_uuid1: bool = False) -> str:
    """
    Generate a random UUID and return it as a string.

    .. note::
        A MAC address or time-based ordering (UUID1) can afford increased database
        performance, since it's less work to sort numbers closer-together than those
        distributed randomly (UUID4). A second related issue, is that using UUID1 can
        be useful in debugging, even if origin data is lost or not explicitly stored.

    :param use_uuid1: use UUID1 instead of UUID4. This is the default Document ID generator.
    :return: A random UUID.
    """
    identity = random_uuid(use_uuid1)
    return str(identity)
def random_uuid(use_uuid1: bool = False) -> uuid.UUID:
    """
    Create a new UUID object.

    :param use_uuid1: Use UUID1 if True, else use UUID4.
    :return: A random UUID.
    """
    if use_uuid1:
        return uuid.uuid1()
    return uuid.uuid4()
def expand_env_var(v: str) -> Optional[Union[bool, int, str, list, float]]:
    """
    Substitute environment variables in a string, then parse the result with `parse_arg`.

    Values that are not strings are returned unchanged.

    :param v: String of environment variables.
    :return: Parsed environment variables.
    """
    if not isinstance(v, str):
        return v
    expanded = os.path.expandvars(v)
    return parse_arg(expanded)
def expand_dict(
    d: Dict, expand_fn=expand_env_var, resolve_cycle_ref=True
) -> Dict[str, Any]:
    """
    Expand variables from YAML file.

    String leaves of ``d`` that contain a ``{...}`` placeholder or a ``$NAME`` reference
    are rewritten in place. With ``resolve_cycle_ref`` enabled, the string is first passed
    through ``str.format`` with ``root`` bound to a namespace mirror of the whole document
    and ``this`` bound to the mirror of the enclosing container; the result is then handed
    to ``expand_fn``.

    :param d: Target Dict (modified in place and also returned).
    :param expand_fn: Callable applied to every matching string; defaults to `expand_env_var`.
    :param resolve_cycle_ref: Defines if cyclic references should be resolved.
    :return: Expanded variables.
    """
    expand_map = SimpleNamespace()
    pat = re.compile(r'{.+}|\$[a-zA-Z0-9_]*\b')

    def _scan(sub_d: Union[Dict, List], p):
        # Mirror ``sub_d`` into ``p``: dicts become SimpleNamespace objects and lists stay
        # lists, so that format fields such as ``{root.a.b}`` resolve by attribute access.
        if isinstance(sub_d, dict):
            for k, v in sub_d.items():
                if isinstance(v, dict):
                    p.__dict__[k] = SimpleNamespace()
                    _scan(v, p.__dict__[k])
                elif isinstance(v, list):
                    p.__dict__[k] = list()
                    _scan(v, p.__dict__[k])
                else:
                    p.__dict__[k] = v
        elif isinstance(sub_d, list):
            for idx, v in enumerate(sub_d):
                if isinstance(v, dict):
                    p.append(SimpleNamespace())
                    _scan(v, p[idx])
                elif isinstance(v, list):
                    p.append(list())
                    _scan(v, p[idx])
                else:
                    p.append(v)

    def _replace(sub_d: Union[Dict, List], p):
        # Walk ``sub_d`` alongside its mirror ``p`` and substitute matching string leaves.
        if isinstance(sub_d, Dict):
            for k, v in sub_d.items():
                if isinstance(v, (dict, list)):
                    _replace(v, p.__dict__[k])
                else:
                    if isinstance(v, str) and pat.findall(v):
                        sub_d[k] = _sub(v, p)
        elif isinstance(sub_d, List):
            for idx, v in enumerate(sub_d):
                if isinstance(v, (dict, list)):
                    _replace(v, p[idx])
                else:
                    if isinstance(v, str) and pat.findall(v):
                        sub_d[idx] = _sub(v, p)

    def _sub(v, p):
        if resolve_cycle_ref:
            try:
                v = v.format(root=expand_map, this=p)
            except KeyError:
                # unknown format field: keep the text as is and let expand_fn handle it
                pass
        return expand_fn(v)

    _scan(d, expand_map)
    _replace(d, expand_map)
    return d


# SGR codes used by `colored`
_ATTRIBUTES = {
    'bold': 1,
    'dark': 2,
    'underline': 4,
    'blink': 5,
    'reverse': 7,
    'concealed': 8,
}

_HIGHLIGHTS = {
    'on_grey': 40,
    'on_red': 41,
    'on_green': 42,
    'on_yellow': 43,
    'on_blue': 44,
    'on_magenta': 45,
    'on_cyan': 46,
    'on_white': 47,
}

_COLORS = {
    'grey': 30,
    'red': 31,
    'green': 32,
    'yellow': 33,
    'blue': 34,
    'magenta': 35,
    'cyan': 36,
    'white': 37,
}

_RESET = '\033[0m'


def build_url_regex_pattern():
    """
    Set up the regex pattern of URL.

    The pattern accepts ``scheme://[user[:password]@]host[:port][/path]`` where the host is
    an IPv4 address, a bracketed IPv6 literal, a domain name or ``localhost``.

    :return: Regex pattern.
    """
    ul = '\u00a1-\uffff'  # Unicode letters range (must not be a raw string).

    # IP patterns
    ipv4_re = (
        r'(?:25[0-5]|2[0-4]\d|[0-1]?\d?\d)(?:\.(?:25[0-5]|2[0-4]\d|[0-1]?\d?\d)){3}'
    )
    ipv6_re = r'\[[0-9a-f:.]+\]'  # (simple regex, validated later)

    # Host patterns
    hostname_re = (
        r'[a-z' + ul + r'0-9](?:[a-z' + ul + r'0-9-]{0,61}[a-z' + ul + r'0-9])?'
    )
    # Max length for domain name labels is 63 characters per RFC 1034 sec. 3.1
    domain_re = r'(?:\.(?!-)[a-z' + ul + r'0-9-]{1,63}(?<!-))*'
    tld_re = (
        r'\.'  # dot
        r'(?!-)'  # can't start with a dash
        r'(?:[a-z' + ul + '-]{2,63}'  # domain label
        r'|xn--[a-z0-9]{1,59})'  # or punycode label
        r'(?<!-)'  # can't end with a dash
        r'\.?'  # may have a trailing dot
    )
    host_re = '(' + hostname_re + domain_re + tld_re + '|localhost)'

    return re.compile(
        r'^(?:[a-z0-9.+-]*)://'  # scheme is validated separately
        # The password part is optional ("?") and the credentials end with "@". This
        # fragment had been replaced by an e-mail obfuscation placeholder, which broke
        # matching of every URL that carries credentials.
        r'(?:[^\s:@/]+(?::[^\s:@/]*)?@)?'  # user:pass authentication
        r'(?:' + ipv4_re + '|' + ipv6_re + '|' + host_re + ')'
        r'(?::\d{2,5})?'  # port
        r'(?:[/?#][^\s]*)?'  # resource path
        r'\Z',
        re.IGNORECASE,
    )


url_pat = build_url_regex_pattern()
def is_url(text):
    """
    Tell whether ``text`` matches the module-level URL pattern ``url_pat``.

    :param text: The target text.
    :return: True if text is URL else False.
    """
    matched = url_pat.match(text)
    return matched is not None
# Import-time side effect: on Windows, run the shell built-in `color` once.
# NOTE(review): presumably this makes the console honour the ANSI escape sequences
# emitted by `colored` — confirm.
if os.name == 'nt':
    os.system('color')
def colored(
    text: str,
    color: Optional[str] = None,
    on_color: Optional[str] = None,
    attrs: Optional[Union[str, list]] = None,
) -> str:
    """
    Wrap ``text`` in ANSI escape sequences for colour, background and text attributes.

    When the environment variable ``JINA_LOG_NO_COLOR`` is set, ``text`` is returned
    untouched. Otherwise the requested codes are prepended and a reset sequence is appended.

    :param text: The target text.
    :param color: The color of text, one of the keys of ``_COLORS``:
        grey, red, green, yellow, blue, magenta, cyan, white.
    :param on_color: The background of text, one of the keys of ``_HIGHLIGHTS``:
        on_grey, on_red, on_green, on_yellow, on_blue, on_magenta, on_cyan, on_white.
    :param attrs: A single attribute name or a list of them, taken from the keys of
        ``_ATTRIBUTES``: bold, dark, underline, blink, reverse, concealed.
    :return: Colored text.
    """
    if 'JINA_LOG_NO_COLOR' in os.environ:
        return text

    def _prefix(code: int, body: str) -> str:
        return '\033[%dm%s' % (code, body)

    if color:
        text = _prefix(_COLORS[color], text)
    if on_color:
        text = _prefix(_HIGHLIGHTS[on_color], text)
    if attrs:
        names = [attrs] if isinstance(attrs, str) else attrs
        # only list inputs are applied; other container types are ignored
        if isinstance(names, list):
            for name in names:
                text = _prefix(_ATTRIBUTES[name], text)
    return text + _RESET
class ArgNamespace:
    """Collection of static helpers for working with :class:`argparse.Namespace` objects."""

    @staticmethod
    def kwargs2list(kwargs: Dict) -> List[str]:
        """
        Convert dict to an argparse-friendly list.

        Keys become ``--flags`` with underscores turned into dashes. ``None`` values are
        skipped, booleans become bare flags (only when true), lists are spread over several
        tokens, dicts are JSON-encoded and executor classes are passed by class name.

        :param kwargs: dictionary of key-values to be converted
        :return: argument list
        """
        cli_args = []
        from .executors import BaseExecutor

        for key, value in kwargs.items():
            if value is None:
                continue
            flag = f'--{key.replace("_", "-")}'
            if isinstance(value, bool):
                if value:
                    cli_args.append(flag)
                continue
            if isinstance(value, list):  # for nargs
                tail = [str(item) for item in value]
            elif isinstance(value, dict):
                tail = [json.dumps(value)]
            elif isinstance(value, type) and issubclass(value, BaseExecutor):
                tail = [value.__name__]
            else:
                tail = [str(value)]
            cli_args.append(flag)
            cli_args.extend(tail)
        return cli_args

    @staticmethod
    def kwargs2namespace(
        kwargs: Dict[str, Union[str, int, bool]], parser: ArgumentParser
    ) -> Namespace:
        """
        Convert dict to a namespace.

        :param kwargs: dictionary of key-values to be converted
        :param parser: the parser for building kwargs into a namespace
        :return: the parsed namespace; arguments unknown to the parser are dropped
        """
        args = ArgNamespace.kwargs2list(kwargs)
        try:
            known, _ = parser.parse_known_args(args)
        except SystemExit:
            # argparse exits on malformed input; surface it as an ordinary exception
            raise ValueError(
                f'bad arguments "{args}" with parser {parser}, '
                'you may want to double check your args '
            )
        return known

    @staticmethod
    def get_parsed_args(
        kwargs: Dict[str, Union[str, int, bool]], parser: ArgumentParser
    ) -> Tuple[List[str], Namespace, List[Any]]:
        """
        Get all parsed args info in a dict.

        :param kwargs: dictionary of key-values to be converted
        :param parser: the parser for building kwargs into a namespace
        :return: the argument list, the parsed namespace and the unknown arguments
        """
        args = ArgNamespace.kwargs2list(kwargs)
        try:
            known, leftovers = parser.parse_known_args(args)
            if leftovers:
                from .logging import default_logger

                default_logger.debug(
                    f'parser {typename(parser)} can not '
                    f'recognize the following args: {leftovers}, '
                    f'they are ignored. if you are using them from a global args (e.g. Flow), '
                    f'then please ignore this message'
                )
        except SystemExit:
            raise ValueError(
                f'bad arguments "{args}" with parser {parser}, '
                'you may want to double check your args '
            )
        return args, known, leftovers

    @staticmethod
    def get_non_defaults_args(
        args: Namespace, parser: ArgumentParser, taboo: Optional[Set[str]] = None
    ) -> Dict:
        """
        Get non-default args in a dict.

        :param args: the namespace to parse
        :param parser: the parser for referring the default values
        :param taboo: exclude keys in the final result
        :return: non defaults
        """
        excluded = set() if taboo is None else taboo
        # parsing an empty argument list yields the parser's default for every option
        defaults = vars(parser.parse_args([]))
        return {
            name: value
            for name, value in vars(args).items()
            if name in defaults and name not in excluded and defaults[name] != value
        }

    @staticmethod
    def flatten_to_dict(
        args: Union[Dict[str, 'Namespace'], 'Namespace']
    ) -> Dict[str, Any]:
        """Convert argparse.Namespace to dict to be uploaded via REST.

        :param args: a namespace, or a dict whose values may be namespaces or lists of namespaces.
        :return: pea args; ``None`` when ``args`` is neither a namespace nor a dict
        """
        if isinstance(args, Namespace):
            return vars(args)
        if not isinstance(args, dict):
            return None

        flattened = {}
        for name, value in args.items():
            if isinstance(value, Namespace):
                flattened[name] = vars(value)
            elif isinstance(value, list):
                flattened[name] = [vars(item) for item in value]
            else:
                flattened[name] = value
        return flattened
def is_valid_local_config_source(path: str) -> bool:
    # TODO: this function must be refactored before 1.0 (Han 12.22)
    """
    Check whether ``path`` can be parsed as a config source.

    The path is handed to ``jaml.parse_config_source``; a ``FileNotFoundError`` means the
    source is not valid. Any other exception propagates.

    :param path: Local file path.
    :return: True if the path is valid else False.
    """
    try:
        from .jaml import parse_config_source

        parse_config_source(path)
    except FileNotFoundError:
        return False
    else:
        return True
def get_full_version() -> Optional[Tuple[Dict, Dict]]:
    """
    Get the version of libraries used in Jina and environment variables.

    :return: A tuple ``(version_info, env_info)``, or ``None`` if collecting the
        information failed (the error is logged).
    """
    from . import __version__, __proto_version__, __jina_env__
    from google.protobuf.internal import api_implementation
    import os
    import zmq
    import numpy
    import google.protobuf
    import grpc
    import yaml
    from grpc import _grpcio_metadata
    from pkg_resources import resource_filename
    import platform
    from .logging import default_logger

    try:
        info = {
            'jina': __version__,
            'jina-proto': __proto_version__,
            'jina-vcs-tag': os.environ.get('JINA_VCS_VERSION', '(unset)'),
            'libzmq': zmq.zmq_version(),
            # this entry used to report numpy.__version__ under the 'pyzmq' label
            'pyzmq': zmq.pyzmq_version(),
            'numpy': numpy.__version__,
            'protobuf': google.protobuf.__version__,
            'proto-backend': api_implementation._default_implementation_type,
            'grpcio': getattr(grpc, '__version__', _grpcio_metadata.__version__),
            'pyyaml': yaml.__version__,
            'python': platform.python_version(),
            'platform': platform.system(),
            'platform-release': platform.release(),
            'platform-version': platform.version(),
            'architecture': platform.machine(),
            'processor': platform.processor(),
            'jina-resources': resource_filename('jina', 'resources'),
        }
        env_info = {k: os.getenv(k, '(unset)') for k in __jina_env__}
        full_version = info, env_info
    except Exception as e:
        # best effort: version reporting must not crash the caller
        default_logger.error(str(e))
        full_version = None

    return full_version


def format_full_version_info(info: Dict, env_info: Dict) -> str:
    """
    Format the version information.

    Each entry becomes one line: version entries are prefixed with ``- ``, environment
    entries with ``* ``, and keys are padded to 30 characters.

    :param info: Version information of Jina libraries.
    :param env_info: The Jina environment variables.
    :return: Formatted version information.
    """
    version_info = '\n'.join(f'- {k:30s}{v}' for k, v in info.items())
    env_info = '\n'.join(f'* {k:30s}{v}' for k, v in env_info.items())
    return version_info + '\n' + env_info


def _use_uvloop():
    # Install uvloop's event loop policy. uvloop is optional: ImportExtensions is entered
    # with required=False.
    from .importer import ImportExtensions

    with ImportExtensions(
        required=False,
        help_text='Jina uses uvloop to manage events and sockets, '
        'it often yields better performance than builtin asyncio',
    ):
        import uvloop

        asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


def get_or_reuse_loop():
    """
    Get a new eventloop or reuse the current opened eventloop.

    When no loop is running (or the running loop is closed), a new loop is created and set
    as the current one; uvloop is enabled first unless ``JINA_DISABLE_UVLOOP`` is set.

    :return: A new eventloop or reuse the current opened eventloop.
    """
    try:
        loop = asyncio.get_running_loop()
        if loop.is_closed():
            raise RuntimeError
    except RuntimeError:
        if 'JINA_DISABLE_UVLOOP' not in os.environ:
            _use_uvloop()
        # no running event loop
        # create a new loop
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop
def typename(obj):
    """
    Build the dotted ``module.ClassName`` name of an object's type.

    A class passed in is named directly; for any other object its class is used.

    :param obj: Target object.
    :return: Typename of the obj.
    """
    klass = obj if isinstance(obj, type) else obj.__class__
    try:
        return f'{klass.__module__}.{klass.__name__}'
    except AttributeError:
        # fall back to the plain string form when either attribute is missing
        return str(klass)
class cached_property:
    """The decorator to cache property of a class.

    The computed value is stored in the instance ``__dict__`` under ``CACHED_<name>``.
    A value of ``None`` is treated as "not cached": it is recomputed on the next access.
    Deleting the attribute drops the cached value, calling its ``close()`` first if it has one.
    """

    def __init__(self, func):
        """
        Create the :class:`cached_property`.

        :param func: Cached function.
        """
        self.func = func

    def __get__(self, obj, cls):
        if obj is None:
            # accessed on the class (e.g. by introspection tools): there is no instance
            # to cache on, so hand back the descriptor itself
            return self
        cached_value = obj.__dict__.get(f'CACHED_{self.func.__name__}', None)
        if cached_value is not None:
            return cached_value

        value = obj.__dict__[f'CACHED_{self.func.__name__}'] = self.func(obj)
        return value

    def __delete__(self, obj):
        cached_value = obj.__dict__.get(f'CACHED_{self.func.__name__}', None)
        if cached_value is not None:
            # release the resource held by the cached value before forgetting it
            if hasattr(cached_value, 'close'):
                cached_value.close()
            del obj.__dict__[f'CACHED_{self.func.__name__}']
def get_now_timestamp():
    """
    Return the current local time as whole seconds since the epoch.

    :return: The datetime in int format.
    """
    return int(datetime.now().timestamp())


def get_readable_time(*args, **kwargs):
    """
    Render a duration in words (e.g. 115 days and 17 hours and 46 minutes and 40 seconds).

    The arguments are forwarded to ``datetime.timedelta``. Units that would be zero are
    left out, except seconds, which are always present.

    For example:
        .. highlight:: python
        .. code-block:: python
            get_readable_time(seconds=1000)

    :param args: arguments for datetime.timedelta
    :param kwargs: key word arguments for datetime.timedelta
    :return: Datetime in human readable format.
    """
    import datetime

    remaining = float(datetime.timedelta(*args, **kwargs).total_seconds())
    pieces = []
    for label, size in (('day', 86400), ('hour', 3600), ('minute', 60), ('second', 1)):
        is_seconds = size == 1
        if not (is_seconds or remaining / size >= 1):
            continue
        if is_seconds:
            count = int(remaining)
        else:
            count = int(math.floor(remaining / size))
            remaining -= count * size
        plural = '' if count == 1 else 's'
        pieces.append(f'{count} {label}' + plural)
    return ' and '.join(pieces)
def get_internal_ip():
    """
    Return the private IP address of the gateway for connecting from other machine in the same network.

    The address is read from a UDP socket connected towards ``10.255.255.255``; if
    anything fails, ``127.0.0.1`` is returned instead.

    :return: Private IP address.
    """
    import socket

    address = '127.0.0.1'
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
            # doesn't even have to be reachable
            probe.connect(('10.255.255.255', 1))
            address = probe.getsockname()[0]
    except Exception:
        # best effort: keep the loopback fallback
        pass
    return address
def get_public_ip():
    """
    Return the public IP address of the gateway for connecting from other machine in the public network.

    Several public "what is my IP" services are queried concurrently; the first non-empty
    answer collected is returned.

    :return: Public IP address, or ``None`` if no service answered in time.
    """
    import urllib.request

    timeout = 0.5

    results = []

    def _get_ip(url):
        try:
            with urllib.request.urlopen(url, timeout=timeout) as fp:
                results.append(fp.read().decode('utf8'))
        except Exception:
            # best effort: an unreachable service simply contributes no result
            pass

    ip_server_list = [
        'https://api.ipify.org',
        'https://ident.me',
        'https://ipinfo.io/ip',
    ]

    threads = [
        threading.Thread(target=_get_ip, args=(url,)) for url in ip_server_list
    ]
    for t in threads:
        t.start()

    for t in threads:
        t.join(timeout)

    for r in results:
        if r:
            return r
    return None
def convert_tuple_to_list(d: Dict):
    """
    Replace, in place, every tuple value of a dict by a list.

    Nested dicts are processed recursively; tuples inside lists or inside other tuples
    are left as they are.

    :param d: Dict type of data.
    """
    for key in d:
        value = d[key]
        if isinstance(value, dict):
            convert_tuple_to_list(value)
        elif isinstance(value, tuple):
            d[key] = list(value)
def is_jupyter() -> bool:  # pragma: no cover
    """
    Check if we're running in a Jupyter notebook, using magic command `get_ipython` that only available in Jupyter.

    :return: True if run in a Jupyter notebook else False.
    """
    try:
        get_ipython  # noqa: F821
    except NameError:
        return False

    notebook_shells = (
        'ZMQInteractiveShell',  # Jupyter notebook or qtconsole
        'Shell',  # Google colab
    )
    # everything else, e.g. 'TerminalInteractiveShell' (terminal running IPython),
    # is not treated as a notebook
    shell_name = get_ipython().__class__.__name__  # noqa: F821
    return shell_name in notebook_shells
def run_async(func, *args, **kwargs):
    """Generalized asyncio.run for jupyter notebook.

    Inside Jupyter an event loop is already running in the current thread and cannot be
    stopped, so calling ``asyncio.run`` directly fails. In that situation the coroutine is
    run with ``asyncio.run`` in a helper thread and its result is handed back; without a
    running loop ``asyncio.run`` is called directly.

    .. see_also:
        https://stackoverflow.com/questions/55409641/asyncio-run-cannot-be-called-from-a-running-event-loop

    :param func: function to run
    :param args: parameters
    :param kwargs: key-value parameters
    :return: asyncio.run(func)
    """
    try:
        active_loop = asyncio.get_running_loop()
    except RuntimeError:
        active_loop = None

    if not (active_loop and active_loop.is_running()):
        # no event loop in this thread: the plain path works
        return asyncio.run(func(*args, **kwargs))

    # eventloop already exist
    if not is_jupyter():
        raise RuntimeError(
            'you have an eventloop running but not using Jupyter/ipython, '
            'this may mean you are using Jina with other integration? if so, then you '
            'may want to use AsyncClient/AsyncFlow instead of Client/Flow. If not, then '
            'please report this issue here: https://github.com/jina-ai/jina'
        )

    class _RunThread(threading.Thread):
        """Create a running thread when in Jupyter notebook."""

        def run(self):
            """Run given `func` asynchronously."""
            self.result = asyncio.run(func(*args, **kwargs))

    # running inside Jupyter
    worker = _RunThread()
    worker.start()
    worker.join()
    try:
        return worker.result
    except AttributeError:
        # `result` is only set when the coroutine finished without raising
        from .excepts import BadClient

        raise BadClient(
            'something wrong when running the eventloop, result can not be retrieved'
        )
def slugify(value):
    """
    Turn a value into a string that is safe to use as an identifier or file name.

    The value is converted with ``str``, surrounding whitespace is stripped, spaces become
    underscores and every character that is not a word character, ``-`` or ``.`` is removed.
    Letter case is preserved.

    :param value: Original string.
    :return: Processed string.
    """
    s = str(value).strip().replace(' ', '_')
    return re.sub(r'(?u)[^-\w.]', '', s)


@contextmanager
def change_cwd(path):
    """
    Change the current working dir to ``path`` in a context and set it back to the original one when leaves the context.
    Yields nothing

    :param path: Target path.
    :yields: nothing
    """
    curdir = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(curdir)


@contextmanager
def change_env(key, val):
    """
    Change the environment of ``key`` to ``val`` in a context and set it back to the original one when leaves the context.

    :param key: Name of the environment variable.
    :param val: Value the variable takes inside the context.
    :yields: nothing
    """
    old_var = os.environ.get(key, None)
    os.environ[key] = val
    try:
        yield
    finally:
        if old_var is not None:
            # `is not None` so that a pre-existing empty value is restored too
            os.environ[key] = old_var
        else:
            # tolerate the body having removed the variable itself
            os.environ.pop(key, None)


def is_yaml_filepath(val) -> bool:
    """
    Check if the file is YAML file.

    :param val: Path of target file.
    :return: True if the file is YAML else False.
    """
    # the dot before the extension is escaped; unescaped it matched any character,
    # so e.g. 'flowyml' was accepted
    r = r'^[/\w\-\_\.]+\.ya?ml$'
    return re.match(r, val.strip()) is not None


def download_mermaid_url(mermaid_url, output) -> None:
    """
    Download the jpg image from mermaid_url.

    Failures are logged, not raised.

    :param mermaid_url: The URL of the image.
    :param output: A filename specifying the name of the image to be created, the suffix svg/jpg determines the file type of the output image.
    """
    try:
        req = Request(mermaid_url, headers={'User-Agent': 'Mozilla/5.0'})
        # fetch first, write second: a failed download no longer leaves an empty
        # file behind, and the HTTP response is closed deterministically
        with urlopen(req) as resp:
            content = resp.read()
        with open(output, 'wb') as fp:
            fp.write(content)
    except Exception:
        from jina.logging import default_logger

        default_logger.error(
            'can not download image, please check your graph and the network connections'
        )


def ding(req):
    """Play a ding sound `on_done`, used in 2021 April fools day

    # noqa: DAR101
    """
    import subprocess
    from pkg_resources import resource_filename

    soundfx = resource_filename('jina', '/'.join(('resources', 'soundfx', 'bell.mp3')))

    subprocess.call(f'ffplay -nodisp -autoexit {soundfx} >/dev/null 2>&1', shell=True)


def find_request_binding(target):
    """Find `@request` decorated methods in a class.

    The class source is parsed with :mod:`ast`; both ``@requests`` and
    ``@requests(on=...)`` are recognized, the former being registered as ``default``.

    :param target: the target class to check
    :return: a dictionary with key as request type and value as method name
    :raises ValueError: if two methods are bound to the same request type
    """
    import ast
    import inspect
    import textwrap

    res = {}

    def _literal(node):
        # string constants are ast.Constant on current Pythons; older parsers
        # produced ast.Str, which exposes the text as `.s`
        if isinstance(node, ast.Constant):
            return node.value
        return node.s

    def visit_function_def(node):
        for e in node.decorator_list:
            req_name = ''
            if (
                isinstance(e, ast.Call)
                # `e.func` is an ast.Attribute for decorators such as `@a.b(...)`,
                # which has no `.id`
                and isinstance(e.func, ast.Name)
                and e.func.id == 'requests'
            ):
                req_name = _literal(e.keywords[0].value)
            elif isinstance(e, ast.Name) and e.id == 'requests':
                req_name = 'default'
            if req_name:
                req_name = _canonical_request_name(req_name)
                if req_name in res:
                    raise ValueError(
                        f'you already bind `{res[req_name]}` with `{req_name}` request'
                    )
                else:
                    res[req_name] = node.name

    V = ast.NodeVisitor()
    V.visit_FunctionDef = visit_function_def
    # dedent so that classes defined inside a function or another class can be parsed
    source = textwrap.dedent(inspect.getsource(target))
    V.visit(compile(source, '?', 'exec', ast.PyCF_ONLY_AST))
    return res


def _canonical_request_name(req_name: str):
    """Return the canonical name of a request

    :param req_name: the original request name
    :return: canonical form of the request
    """
    return req_name.lower().replace('request', '')