Source code for jina.drivers.encode

__copyright__ = "Copyright (c) 2020 Jina AI Limited. All rights reserved."
__license__ = "Apache-2.0"

from typing import Optional

from . import BaseExecutableDriver, FlatRecursiveMixin
from ..types.sets import DocumentSet
from ..excepts import LengthMismatchException


[docs]class BaseEncodeDriver(BaseExecutableDriver): """Drivers inherited from this Driver will bind :meth:`encode` by default """ def __init__( self, executor: Optional[str] = None, method: str = 'encode', *args, **kwargs ): super().__init__(executor, method, *args, **kwargs)
[docs]class EncodeDriver(FlatRecursiveMixin, BaseEncodeDriver): """Extract the content from documents and call executor and do encoding""" def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None: contents, docs_pts = docs.all_contents if docs_pts: embeds = self.exec_fn(contents) if len(docs_pts) != embeds.shape[0]: msg = ( f'mismatched {len(docs_pts)} docs from level {docs_pts[0].granularity} ' f'and a {embeds.shape} shape embedding, the first dimension must be the same' ) self.logger.error(msg) raise LengthMismatchException(msg) for doc, embedding in zip(docs_pts, embeds): doc.embedding = embedding
[docs]class ScipySparseEncodeDriver(FlatRecursiveMixin, BaseEncodeDriver): """Extract the content from documents and call executor and do encoding""" def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None: contents, docs_pts = docs.all_contents if docs_pts: embeds = self.exec_fn(contents) if len(docs_pts) != embeds.shape[0]: msg = ( f'mismatched {len(docs_pts)} docs from level {docs_pts[0].granularity} ' f'and a {embeds.shape} shape embedding, the first dimension must be the same' ) self.logger.error(msg) raise LengthMismatchException(msg) for idx, doc in enumerate(docs_pts): doc.embedding = embeds.getrow(idx)