Source code for jina.drivers.predict

from typing import List, Any, Union, Tuple, Optional

import numpy as np

from . import BaseExecutableDriver, FlatRecursiveMixin
from ..helper import typename

if False:
    from ..types.sets import DocumentSet


class BasePredictDriver(FlatRecursiveMixin, BaseExecutableDriver):
    """Drivers inherited from :class:`BasePredictDriver` will bind :meth:`predict` by default

    :param executor: optional executor name, forwarded unchanged to the parent initializer
    :param method: name of the executor method to bind, default ``'predict'``
    :param fields: name of the field used as prediction input, default ``'embedding'``.
        Subclasses in this module accept ``'embedding'`` or ``'content'``.
    :param args: additional positional arguments which are just used for the parent initialization
    :param kwargs: additional key value arguments which are just used for the parent initialization
    """

    def __init__(
        self,
        executor: Optional[str] = None,
        method: str = 'predict',
        fields: Union[Tuple, str] = 'embedding',
        *args,
        **kwargs,
    ):
        self.fields = fields
        super().__init__(executor, method, *args, **kwargs)
class BaseLabelPredictDriver(BasePredictDriver):
    """Base class of a Driver for label prediction.

    The executor's prediction for a batch of documents is converted into
    readable labels by :meth:`prediction2label`, and each label is written
    to ``doc.tags[output_tag]``.

    :param output_tag: key in ``doc.tags`` under which the label is written, default ``'prediction'``
    :param args: additional positional arguments which are just used for the parent initialization
    :param kwargs: additional key value arguments which are just used for the parent initialization
    """

    def __init__(self, output_tag: str = 'prediction', *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.output_tag = output_tag

    def _apply_all(
        self,
        docs: 'DocumentSet',
        *args,
        **kwargs,
    ) -> None:
        """Run the executor on ``docs`` and store one label per document in its tags.

        :param docs: the documents to label
        :param args: unused here
        :param kwargs: unused here

        .. # noqa: DAR401
        """
        if self.fields == 'embedding':
            predict_input, docs_pts = docs.all_embeddings
        elif self.fields == 'content':
            predict_input, docs_pts = docs.all_contents
        else:
            # list the exact values accepted above so the message is actionable
            raise ValueError(
                f'{self.fields} is not a valid field name for {self!r}, must be one of "embedding", "content"'
            )

        # skip the executor call entirely when there is nothing to predict on
        if not docs_pts:
            return

        prediction = self.exec_fn(predict_input)
        labels = self.prediction2label(prediction)  # type: List[Union[str, List[str]]]
        # labels are paired positionally with the documents that supplied input;
        # zip stops at the shorter of the two sequences
        for doc, label in zip(docs_pts, labels):
            doc.tags[self.output_tag] = label

    def prediction2label(self, prediction: 'np.ndarray') -> List[Any]:
        """Converting ndarray prediction into list of readable labels

        .. note::
            ``len(output)`` should be the same as ``prediction.shape[0]``

        :param prediction: the float/int numpy ndarray given by :class:`BaseClassifier`
        :return: the readable label to be stored.

        .. # noqa: DAR401
        .. # noqa: DAR202
        """
        raise NotImplementedError
class BinaryPredictDriver(BaseLabelPredictDriver):
    """Converts binary prediction into string label.

    This is often used with binary classifier.

    :param one_label: label when prediction is one (any nonzero value)
    :param zero_label: label when prediction is zero
    :param args: additional positional arguments which are just used for the parent initialization
    :param kwargs: additional key value arguments which are just used for the parent initialization
    """

    def __init__(self, one_label: str = 'yes', zero_label: str = 'no', *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.one_label = one_label
        self.zero_label = zero_label

    def prediction2label(self, prediction: 'np.ndarray') -> List[str]:
        """Map each prediction value to ``self.one_label`` (nonzero) or ``self.zero_label`` (zero).

        :param prediction: a (B,) or (B, 1) zero one array
        :return: the labels as either ``self.one_label`` or ``self.zero_label``, one per sample

        .. # noqa: DAR401
        """
        # ``np.squeeze`` collapses a single-sample batch of shape (1,) or (1, 1)
        # into a 0-d array, which cannot be iterated; ``atleast_1d`` restores
        # the batch axis so a batch of one still yields exactly one label.
        p = np.atleast_1d(np.squeeze(prediction))
        if p.ndim > 1:
            raise ValueError(
                f'{typename(self)} expects prediction has ndim=1, but receiving ndim={p.ndim}'
            )
        return [self.one_label if v else self.zero_label for v in p.astype(bool)]
class OneHotPredictDriver(BaseLabelPredictDriver):
    """Map every row of a one-hot prediction to one of the given labels.

    The prediction is expected to be a 2-dim, zero-one valued array: each row
    is one sample, and column ``j`` corresponds to ``labels[j]``. Each row is
    expected to hold a single 1. This is often used with a multi-class
    classifier.

    :param labels: the label names, one per prediction column
    :param args: additional positional arguments which are just used for the parent initialization
    :param kwargs: additional key value arguments which are just used for the parent initialization
    """

    def __init__(self, labels: List[str], *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.labels = labels

    def validate_labels(self, prediction: 'np.ndarray'):
        """Ensure ``prediction`` is 2-dim and has exactly one column per label.

        :param prediction: the predictions

        .. # noqa: DAR401
        """
        if prediction.ndim != 2:
            raise ValueError(
                f'{typename(self)} expects prediction to have ndim=2, but received {prediction.ndim}'
            )
        expected_columns = len(self.labels)
        if prediction.shape[1] != expected_columns:
            raise ValueError(
                f'{typename(self)} expects prediction.shape[1]==len(self.labels), but received {prediction.shape}'
            )

    def prediction2label(self, prediction: 'np.ndarray') -> List[str]:
        """Pick, for every row, the label whose column holds the largest value.

        :param prediction: a (B, C) array where C is the number of classes, only one element can be one
        :return: the list of labels, one per row
        """
        self.validate_labels(prediction)
        winning_columns = np.argmax(prediction, axis=1)
        chosen = []
        for column in winning_columns:
            chosen.append(self.labels[column])
        return chosen
class MultiLabelPredictDriver(OneHotPredictDriver):
    """Map every row of a prediction to the list of labels it switches on.

    The prediction is expected to be a 2-dim, zero-one valued array: each row
    is one sample and each column corresponds to a label. A row may contain
    several 1s. This is often used with a multi-label classifier, where each
    instance can have multiple labels.
    """

    def prediction2label(self, prediction: 'np.ndarray') -> List[List[str]]:
        """Collect, for every row, the labels of all its nonzero columns.

        :param prediction: the array of predictions
        :return: nested list of labels, one inner list per row
        """
        self.validate_labels(prediction)
        per_row = []
        for row in prediction:
            # a 1-d row's ``nonzero()`` yields a 1-tuple of column indices
            (active_columns,) = row.nonzero()
            per_row.append([self.labels[int(col)] for col in active_columns])
        return per_row
class Prediction2DocBlobDriver(BasePredictDriver):
    """Write the prediction result directly into ``document.blob``.

    Element ``i`` of the prediction is assigned to the ``i``-th document
    that supplied a prediction input.

    .. warning::
        This will erase the content in ``document.text`` and ``document.buffer``.
    """

    def _apply_all(
        self,
        docs: 'DocumentSet',
        *args,
        **kwargs,
    ) -> None:
        """Run the executor on ``docs`` and write each result row into the matching ``doc.blob``.

        :param docs: the documents to process
        :param args: unused here
        :param kwargs: unused here

        .. # noqa: DAR401
        """
        if self.fields == 'embedding':
            predict_input, docs_pts = docs.all_embeddings
        elif self.fields == 'content':
            predict_input, docs_pts = docs.all_contents
        else:
            # list the exact values accepted above so the message is actionable
            raise ValueError(
                f'{self.fields} is not a valid field name for {self!r}, must be one of "embedding", "content"'
            )

        # skip the executor call entirely when there is nothing to predict on
        if not docs_pts:
            return

        prediction = self.exec_fn(predict_input)
        # results are paired positionally with the documents that supplied input
        for doc, pred in zip(docs_pts, prediction):
            doc.blob = pred