import collections
import json
import os
from typing import Union, TextIO, Dict, Tuple, Optional

from yaml import MappingNode
from yaml.composer import Composer
from yaml.constructor import FullConstructor, ConstructorError
from yaml.parser import Parser
from yaml.reader import Reader
from yaml.resolver import Resolver
from yaml.scanner import Scanner

from jina.excepts import BadConfigSource
from jina.helper import is_yaml_filepath
from jina.importer import PathImporter

class JinaConstructor(FullConstructor):
    """Convert List into tuple when doing hashing."""

    def get_hashable_key(self, key):
        """
        Return a hashable equivalent of ``key``.

        :param key: the constructed mapping key.
        :return: ``key`` itself when it is hashable; for a list, a tuple whose
            items have been converted the same way.
        :raises ValueError: if ``key`` is unhashable and is not a list, or if
            a list contains such an item.
        """
        try:
            # Ask hash() directly instead of testing against the Hashable ABC:
            # the ABC only looks for a ``__hash__`` attribute, so it accepts a
            # tuple that contains a list even though hashing it fails.
            hash(key)
        except TypeError:
            if isinstance(key, list):
                # Build a new tuple rather than rewriting the list in place.
                return tuple(self.get_hashable_key(item) for item in key)
            raise ValueError(f'unhashable key: {key}')
        return key

    def construct_mapping(self, node, deep=True):
        """
        Build the mapping from node.

        :param node: the node to traverse
        :param deep: required param from YAML constructor
        :return: Mapped data
        """
        if isinstance(node, MappingNode):
            self.flatten_mapping(node)
        return self._construct_mapping(node, deep=deep)

    def _construct_mapping(self, node, deep=True):
        """
        Construct a dict from a mapping node, converting list keys to tuples.

        :param node: the mapping node to construct from
        :param deep: forwarded to ``construct_object`` for the values; keys are
            always constructed with ``deep=True``
        :return: the constructed dict
        :raises ConstructorError: if ``node`` is not a mapping node or a key
            cannot be made hashable.
        """
        if not isinstance(node, MappingNode):
            raise ConstructorError(
                None,
                None,
                'expected a mapping node, but found %s' % node.id,
                node.start_mark,
            )
        mapping = {}
        for key_node, value_node in node.value:
            key = self.construct_object(key_node, deep=True)
            try:
                # Returns hashable keys unchanged, so every key can go through it.
                key = self.get_hashable_key(key)
            except Exception as exc:
                raise ConstructorError(
                    'while constructing a mapping',
                    node.start_mark,
                    'found unacceptable key (%s)' % exc,
                    key_node.start_mark,
                ) from exc
            mapping[key] = self.construct_object(value_node, deep=deep)
        return mapping
class JinaResolver(Resolver):
    """Resolver for Jina YAML in which `on|On|ON` is not meant to resolve to bool."""
class JinaLoader(Reader, Scanner, Parser, Composer, JinaConstructor, JinaResolver):
    """
    YAML loader assembled from the PyYAML reader, scanner, parser and composer
    together with the Jina constructor and resolver.

    :param stream: the stream to load.
    """

    def __init__(self, stream):
        # Only the reader takes the stream; the other stages take no arguments.
        Reader.__init__(self, stream)
        for stage in (Scanner, Parser, Composer, JinaConstructor, JinaResolver):
            stage.__init__(self)
# Give JinaResolver its own resolver table before editing it. The attribute is
# otherwise looked up on the parent ``Resolver`` class, so popping entries in
# place would strip them from every other PyYAML loader in the process.
JinaResolver.yaml_implicit_resolvers = {
    first_char: list(resolvers)
    for first_char, resolvers in JinaResolver.yaml_implicit_resolvers.items()
}

# remove on|On|ON resolver
# NOTE(review): PyYAML indexes implicit resolvers by first character, so this
# presumably also stops off/Off/OFF from resolving to bool - confirm intended.
JinaResolver.yaml_implicit_resolvers.pop('o', None)
JinaResolver.yaml_implicit_resolvers.pop('O', None)
def parse_config_source(
    path: Union[str, TextIO, Dict],
    allow_stream: bool = True,
    allow_yaml_file: bool = True,
    allow_builtin_resource: bool = True,
    allow_raw_yaml_content: bool = True,
    allow_raw_driver_yaml_content: bool = True,
    allow_class_type: bool = True,
    allow_dict: bool = True,
    allow_json: bool = True,
    *args,
    **kwargs,
) -> Tuple[TextIO, Optional[str]]:
    """
    Check if the text or text stream is valid.

    The source kinds are tried in a fixed order and the first match wins:
    dict, readable stream, YAML file path, built-in resource name, raw YAML
    content, raw driver YAML content, class name, JSON string.

    .. # noqa: DAR401
    :param path: the multi-kind source of the configs.
    :param allow_stream: accept an object that has a ``read`` attribute
    :param allow_yaml_file: accept a path to a YAML file
    :param allow_builtin_resource: accept a name starting with ``_`` that maps
        to ``resources/executors.<name>.yml`` in the ``jina`` package
    :param allow_raw_yaml_content: accept a string starting with ``!`` or ``jtype``
    :param allow_raw_driver_yaml_content: accept a string starting with ``- !``
        or ``- jtype``
    :param allow_class_type: accept a string that is a valid identifier
    :param allow_dict: accept a dict
    :param allow_json: accept a JSON string
    :param args: unused
    :param kwargs: unused
    :return: a tuple, the first element is the text stream, the second element
        is the file path associate to it if available.
    :raises BadConfigSource: if ``path`` is empty, matches no enabled kind, or
        is tried as JSON and fails to decode.
    """
    import io

    if not path:
        raise BadConfigSource

    if allow_dict and isinstance(path, dict):
        from . import JAML

        return io.StringIO(JAML.dump(path)), None

    if allow_stream and hasattr(path, 'read'):
        # already a readable stream
        return path, None

    if allow_yaml_file and is_yaml_filepath(path):
        comp_path = complete_path(path)
        return open(comp_path, encoding='utf8'), comp_path

    # Every remaining kind is string-based. Guarding here lets other types
    # reach the BadConfigSource below instead of failing on ``path.lstrip()``.
    if isinstance(path, str):
        stripped = path.lstrip()

        if allow_builtin_resource and stripped.startswith('_'):
            comp_path = _jina_resource_path(f'executors.{path}.yml')
            if os.path.exists(comp_path):
                # opened in text mode, like the YAML file branch above
                return open(comp_path, encoding='utf8'), comp_path

        if allow_raw_yaml_content and stripped.startswith(('!', 'jtype')):
            # possible YAML content
            # NOTE(review): the spaces inside this literal become YAML
            # indentation; confirm the width, since the reviewed copy may have
            # had runs of spaces collapsed.
            return io.StringIO(path.replace('|', '\n with: ')), None

        if allow_raw_driver_yaml_content and stripped.startswith(('- !', '- jtype')):
            # possible driver YAML content, right now it is only used for debugging
            template_name = (
                'executors.base.all.yml'
                if stripped.startswith('- !!')
                else 'executors.base.yml'
            )
            with open(_jina_resource_path(template_name), encoding='utf8') as fp:
                _defaults = fp.read()
            # for indent, I know, its nasty
            # NOTE(review): same caveat as above about the spaces in the literal.
            path = path.replace('- !!', '- !').replace('|', '\n with: ')
            # the template marks the insertion point with '*'
            return io.StringIO(_defaults.replace('*', path)), None

        if allow_class_type and path.isidentifier():
            # possible class name
            return io.StringIO(f'!{path}'), None

        if allow_json:
            try:
                tmp = json.loads(path)
            except json.JSONDecodeError as exc:
                raise BadConfigSource(path) from exc
            from . import JAML

            return io.StringIO(JAML.dump(tmp)), None

    raise BadConfigSource(
        f'{path} can not be resolved, it should be a readable stream,'
        ' or a valid file path, or a supported class name.'
    )


def _jina_resource_path(filename: str) -> str:
    """Return the filesystem path of ``resources/<filename>`` in the ``jina`` package."""
    # Imported here so sources that need no packaged resource skip the import.
    from pkg_resources import resource_filename

    return resource_filename('jina', '/'.join(('resources', filename)))
def complete_path(path: str, extra_search_paths: Optional[Tuple[str]] = None) -> str:
    """
    Resolve ``path`` to an absolute path of an existing file.

    :param path: path of file.
    :param extra_search_paths: extra paths to conduct search
    :return: Completed file path.
    :raises FileNotFoundError: if the path does not exist and the search over
        the fallback locations finds nothing.
    """
    # os.path.exists already covers absolute paths and paths relative to the cwd
    located = (
        path
        if os.path.exists(path)
        else _search_file_in_paths(path, extra_search_paths)
    )
    if not located:
        raise FileNotFoundError(f'can not find {path}')
    return os.path.abspath(located)
def _search_file_in_paths(path, extra_search_paths: Optional[Tuple[str]] = None):
    """
    Search in all dirs of the PATH environment variable and all dirs of files used in the call stack.

    Directories are tried in this order: ``extra_search_paths``, the directory
    of each source file on the current call stack (innermost frame first),
    then the entries of ``PATH``.

    :param path: the path to search for
    :param extra_search_paths: any extra locations to search for
    :return: the first existing ``<dir>/<path>``, or ``None`` if there is none
    """
    import inspect

    search_paths = []
    if extra_search_paths:
        search_paths.extend(extra_search_paths)

    # iterates over the call stack
    frame = inspect.currentframe()
    while frame:
        search_paths.append(os.path.dirname(inspect.getfile(frame)))
        frame = frame.f_back

    # PATH may be absent from the environment; then there is nothing to add.
    env_path = os.environ.get('PATH')
    if env_path is not None:
        search_paths += env_path.split(os.pathsep)

    # return first occurrence of path. If it does not exist, return None.
    for directory in search_paths:
        candidate = os.path.join(directory, path)
        if os.path.exists(candidate):
            return candidate
    return None
def load_py_modules(d: Dict, extra_search_paths: Optional[Tuple[str]] = None) -> None:
    """
    Collect every 'py_modules' entry in a nested dict and load the modules.

    :param d: the dictionary to traverse
    :param extra_search_paths: any extra paths to search
    """

    def _collect(node, key='py_modules'):
        # This level's entry first, then each nested dict in iteration order.
        found = node.get(key, [])
        if isinstance(found, str):
            yield found
        elif isinstance(found, (list, tuple)):
            yield from found
        for child in node.values():
            if isinstance(child, dict):
                yield from _collect(child, key)

    modules = list(_collect(d))
    if modules:
        resolved = [complete_path(m, extra_search_paths) for m in modules]
        PathImporter.add_modules(*resolved)