Source code for jina.executors.encoders

__copyright__ = "Copyright (c) 2020 Jina AI Limited. All rights reserved."
__license__ = "Apache-2.0"

from typing import Any

from .. import BaseExecutor
from ..compound import CompoundExecutor

if False:
    # This branch never runs, so numpy is not imported at runtime. The import
    # is here only so that sphinx and flake do not complain about the
    # 'np.ndarray' string annotations used in this module.
    import numpy as np


class BaseEncoder(BaseExecutor):
    """Base class of all encoders: turns a chunk into its vector representation.

    The method that matters is :func:`encode`; this class only declares it.

    .. seealso::
        :mod:`jina.drivers.handlers.encode`
    """

    def encode(self, data: Any, *args, **kwargs) -> Any:
        """Encode ``data``. Not implemented on this base class.

        :param data: the input to be encoded
        :param args: additional positional arguments (ignored here)
        :param kwargs: additional keyword arguments (ignored here)
        :raises NotImplementedError: always
        """
        raise NotImplementedError()
class BaseNumericEncoder(BaseEncoder):
    """Encoder interface for numeric input.

    Takes an ndarray, potentially of shape `B x ([T] x D)`, and yields an
    ndarray of shape `B x D`.
    """

    def encode(self, data: 'np.ndarray', *args, **kwargs) -> 'np.ndarray':
        """Encode a batch of numeric data. Not implemented on this base class.

        :param data: a numpy ``ndarray`` of shape `B x ([T] x D)`, where `B`
            is the batch size
        :return: a numpy ``ndarray`` of shape `B x D`
        :raises NotImplementedError: always
        """
        raise NotImplementedError()
class BaseImageEncoder(BaseNumericEncoder):
    """Encoder interface for image input.

    Takes an ndarray, potentially of shape `B x (Height x Width)`, and yields
    an ndarray of shape `B x D`.
    """
class BaseVideoEncoder(BaseNumericEncoder):
    """Encoder interface for video input.

    Takes an ndarray, potentially of shape `B x (Time x Height x Width)`, and
    yields an ndarray of shape `B x D`.
    """
class BaseAudioEncoder(BaseNumericEncoder):
    """Encoder interface for audio input.

    Takes an ndarray, potentially of shape `B x (Time x D)`, and yields an
    ndarray of shape `B x D`.
    """
class BaseTextEncoder(BaseEncoder):
    """Encoder interface for text input.

    Takes an array of string type (``data.dtype.kind == 'U'``) of size `B` and
    yields an ndarray of shape `B x D`.
    """

    def encode(self, data: 'np.ndarray', *args, **kwargs) -> 'np.ndarray':
        """Encode a batch of strings. Not implemented on this base class.

        :param data: a 1d array of string type (``data.dtype.kind == 'U'``)
            of size `B`
        :return: an ndarray of shape `B x D`
        :raises NotImplementedError: always
        """
        raise NotImplementedError()
class PipelineEncoder(CompoundExecutor):
    """Compound executor that runs its components as a chain of encoders.

    The output of each component's ``encode`` is the input of the next one.
    """

    def encode(self, data: Any, *args, **kwargs) -> Any:
        """Pass ``data`` through every component's ``encode``, in order.

        :param data: the input handed to the first component
        :param args: positional arguments forwarded to every component
        :param kwargs: keyword arguments forwarded to every component
        :return: whatever the last component's ``encode`` returns
        :raises NotImplementedError: when there are no components
        """
        if not self.components:
            raise NotImplementedError
        encoded = data
        for stage in self.components:
            encoded = stage.encode(encoded, *args, **kwargs)
        return encoded

    def train(self, data: Any, *args, **kwargs) -> None:
        """Train every component that is not trained yet, in order.

        Each component except the last encodes the data after its training
        step, and the next component receives that encoded data.

        :param data: the training data handed to the first component
        :param args: positional arguments forwarded to every component
        :param kwargs: keyword arguments forwarded to every component
        :raises NotImplementedError: when there are no components
        """
        if not self.components:
            raise NotImplementedError
        for position, stage in enumerate(self.components):
            if not stage.is_trained:
                stage.train(data, *args, **kwargs)
            # Nothing consumes the last stage's output, so it is not encoded.
            if position + 1 >= len(self.components):
                continue
            data = stage.encode(data, *args, **kwargs)