__copyright__ = "Copyright (c) 2020 Jina AI Limited. All rights reserved."
__license__ = "Apache-2.0"
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Tuple
import numpy as np
from . import FlatRecursiveMixin
from .encode import BaseEncodeDriver
from ..types.document.multimodal import MultimodalDocument
if TYPE_CHECKING:
    from ..types.sets import DocumentSet
class MultiModalDriver(FlatRecursiveMixin, BaseEncodeDriver):
"""Extract multimodal embeddings from different modalities.
Input-Output ::
Input:
document:
|- chunk: {modality: mode1}
|
|- chunk: {modality: mode2}
Output:
document: (embedding: multimodal encoding)
|- chunk: {modality: mode1}
|
|- chunk: {modality: mode2}
.. note::
        - It traverses the ``documents`` for which the ``multimodal`` embedding is computed. This way
          the `batching` capabilities of the `executor` can be used.
.. warning::
- It assumes that every ``chunk`` of a ``document`` belongs to a different modality.
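    Example (a minimal construction sketch; the modality names ``mode1`` / ``mode2``
    and the string contents are illustrative assumptions, not fixed by this driver)::

        from jina.types.document.multimodal import MultimodalDocument

        doc = MultimodalDocument(
            modality_content_map={'mode1': 'hello world', 'mode2': 'goodbye world'}
        )
        # one chunk per modality is created under the hood; this driver later
        # reads ``doc.modality_content_map`` and ``doc[modality]``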
"""
    def __init__(self, traversal_paths: Tuple[str, ...] = ('r',), *args, **kwargs):
super().__init__(traversal_paths=traversal_paths, *args, **kwargs)
@property
def positional_modality(self) -> List[str]:
"""Get position per modality.
:return: the list of strings representing the name and order of the modality.
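        Example (illustrative names): ``['visual', 'textual']`` means the executor
        expects the stacked visual contents as its first positional argument and
        the stacked textual contents as its second.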
"""
if not self._exec.positional_modality:
raise RuntimeError(
                'Could not determine which position of the input ndarrays corresponds to which modality'
)
return self._exec.positional_modality
def _get_executor_input_arguments(
self, content_by_modality: Dict[str, 'np.ndarray']
) -> List['np.ndarray']:
"""From a dictionary ``content_by_modality`` it returns the arguments in the proper order so that they can be
passed to the executor.
:param content_by_modality: a dictionary of `Document content` by modality name
:return: list of input arguments as np arrays
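        Example (assuming ``positional_modality == ['visual', 'textual']``)::

            {'textual': t, 'visual': v}  ->  [v, t]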
"""
return [content_by_modality[modality] for modality in self.positional_modality]
def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None:
"""Apply the driver to each of the Documents in docs.
        :param docs: the docs for which a ``multimodal embedding`` will be computed, whose chunks are of different modalities
:param args: unused
:param kwargs: unused
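        .. note::
            Sketch of the data flow for an assumed two-modality setup (names and
            shapes are illustrative)::

                docs = [d1, d2]  # each with one chunk per modality
                content_by_modality = {'visual': [v1, v2], 'textual': [t1, t2]}
                # after ``np.stack`` each value has first dimension len(valid_docs)
                embeds = exec_fn(stacked_visual, stacked_textual)  # shape (2, embed_dim)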
"""
        content_by_modality = defaultdict(
            list
        )  # one list per modality, each gathering one content entry per valid doc
valid_docs = []
for doc in docs:
# convert to MultimodalDocument
doc = MultimodalDocument(doc)
if doc.modality_content_map:
valid_docs.append(doc)
for modality in self.positional_modality:
content_by_modality[modality].append(doc[modality])
else:
self.logger.warning(
f'Invalid doc {doc.id}. Only one chunk per modality is accepted'
)
if len(valid_docs) > 0:
# Pass a variable length argument (one argument per array)
for modality in self.positional_modality:
content_by_modality[modality] = np.stack(content_by_modality[modality])
# Guarantee that the arguments are provided to the executor in its desired order
input_args = self._get_executor_input_arguments(content_by_modality)
embeds = self.exec_fn(*input_args)
if len(valid_docs) != embeds.shape[0]:
self.logger.error(
                    f'mismatch between {len(valid_docs)} docs at granularity {valid_docs[0].granularity} '
                    f'and embeddings of shape {embeds.shape}; the first dimension must equal the number of docs'
)
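            # NOTE: ``zip`` stops at the shorter sequence, so on a mismatch only the
            # first ``min(len(valid_docs), embeds.shape[0])`` docs receive an embedding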
for doc, embedding in zip(valid_docs, embeds):
doc.embedding = embedding