Source code for jina.helloworld.fashion.helper

import gzip
import os
import random
import urllib.request
import webbrowser
from collections import defaultdict

import numpy as np

from docarray import Document
from jina.logging.predefined import default_logger
from jina.logging.profile import ProgressBar

result_html = []
top_k = 0
num_docs_evaluated = 0
evaluation_value = defaultdict(float)


def get_groundtruths(target, pseudo_match=False):
    """
    Build the groundtruth Documents used for evaluation.

    :param target: Dictionary which stores the loaded data
    :param pseudo_match: if True, each label gets a single match whose id holds the list of all matching doc ids
    :return: dictionary mapping each label to its groundtruth Document
    """
    # group doc_ids by their labels
    a = np.squeeze(target['index-labels']['data'])
    a = np.stack([a, np.arange(len(a))], axis=1)
    a = a[a[:, 0].argsort()]
    lbl_group = np.split(a[:, 1], np.unique(a[:, 0], return_index=True)[1][1:])

    # each label has one groundtruth, i.e. all docs that have the same label are considered as matches
    groundtruths = {lbl: Document(tags={'id': -1}) for lbl in range(10)}
    for lbl, doc_ids in enumerate(lbl_group):
        if not pseudo_match:
            # full-match, each doc has 6K matches
            for doc_id in doc_ids:
                match = Document()
                match.tags['id'] = int(doc_id)
                groundtruths[lbl].matches.append(match)
        else:
            # pseudo-match, each doc has only one match, but this match's id is a list of 6k elements
            match = Document()
            match.tags['id'] = doc_ids.tolist()
            groundtruths[lbl].matches.append(match)

    return groundtruths


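# A minimal sketch of how the returned structure looks (assuming `targets` has
# already been filled by download_data, defined further below; the variable
# names here are illustrative, not part of this module):
#
#   gts = get_groundtruths(targets)                            # full-match: one Document per label, ~6k matches each
#   gts_small = get_groundtruths(targets, pseudo_match=True)   # one match per label, its id is a list of doc ids
#   print(len(gts[0].matches))                                 # number of index docs labelled 0
#   print(gts[0].matches[0].tags['id'])                        # doc id of one groundtruth match

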
def index_generator(num_docs: int, target: dict):
    """
    Generate the index data.

    :param num_docs: Number of documents to be indexed.
    :param target: Dictionary which stores the data paths
    :yields: index data
    """
    for internal_doc_id in range(num_docs):
        # x_blackwhite.shape is (28,28)
        x_blackwhite = 255 - target['index']['data'][internal_doc_id]
        # x_color.shape is (28,28,3)
        x_color = np.stack((x_blackwhite,) * 3, axis=-1)
        d = Document(content=x_color)
        d.tags['id'] = internal_doc_id
        yield d


def query_generator(num_docs: int, target: dict):
    """
    Generate the query data.

    :param num_docs: Number of documents to be queried
    :param target: Dictionary which stores the data paths
    :yields: query data
    """
    for _ in range(num_docs):
        num_data = len(target['query-labels']['data'])
        idx = random.randint(0, num_data - 1)
        # x_blackwhite.shape is (28,28)
        x_blackwhite = 255 - target['query']['data'][idx]
        # x_color.shape is (28,28,3)
        x_color = np.stack((x_blackwhite,) * 3, axis=-1)
        d = Document(
            content=x_color,
            tags={
                'id': -1,
                'query_label': float(target['query-labels']['data'][idx][0]),
            },
        )
        yield d


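# A minimal sketch of consuming the two generators above (illustrative only;
# the real driver lives in jina.helloworld.fashion.app and feeds them into a
# Flow):
#
#   from docarray import DocumentArray
#   index_docs = DocumentArray(index_generator(num_docs=100, target=targets))
#   query_docs = DocumentArray(query_generator(num_docs=5, target=targets))
#   print(index_docs[0].content.shape)   # (28, 28, 3) uint8 image
#   print(query_docs[0].tags)            # {'id': -1, 'query_label': ...}

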
def write_html(html_path):
    """
    Render the evaluation results into an HTML page and open it in the browser.

    :param html_path: path of the written html
    """
    with open(
        os.path.join(os.path.dirname(os.path.realpath(__file__)), 'demo.html')
    ) as fp, open(html_path, 'w') as fw:
        t = fp.read()
        t = t.replace('{% RESULT %}', '\n'.join(result_html))
        t = t.replace(
            '{% PRECISION_EVALUATION %}',
            '{:.2f}%'.format(evaluation_value['precision_at_k'] * 100.0),
        )
        t = t.replace(
            '{% RECALL_EVALUATION %}',
            '{:.2f}%'.format(evaluation_value['recall_at_k'] * 100.0),
        )
        t = t.replace('{% TOP_K %}', str(top_k))
        fw.write(t)

    url_html_path = 'file://' + os.path.abspath(html_path)

    try:
        webbrowser.open(url_html_path, new=2)
    except:
        pass  # intentional pass, browser support isn't cross-platform
    finally:
        default_logger.info(
            f'You should see a "demo.html" opened in your browser, '
            f'if not you may open {url_html_path} manually'
        )

    default_logger.info(
        f'🤩 Intrigued? Play with `jina hello fashion --help` and [link=https://github.com/jina-ai/jina]learn more about Jina[/]'
    )


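# A minimal usage sketch (hypothetical output path and values; the module-level
# globals result_html, top_k and evaluation_value are expected to have been
# populated elsewhere, e.g. by the demo's result callback, before calling this):
#
#   evaluation_value['precision_at_k'] = 0.8   # illustrative values only
#   evaluation_value['recall_at_k'] = 0.9
#   write_html('/tmp/fashion-demo.html')       # writes the page and opens it

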
def download_data(targets, download_proxy=None, task_name='download fashion-mnist'):
    """
    Download data.

    :param targets: target path for data.
    :param download_proxy: proxy address routed through for both HTTP and HTTPS downloads
    :param task_name: name of the task
    """
    opener = urllib.request.build_opener()
    opener.addheaders = [('User-agent', 'Mozilla/5.0')]
    if download_proxy:
        proxy = urllib.request.ProxyHandler(
            {'http': download_proxy, 'https': download_proxy}
        )
        opener.add_handler(proxy)
    urllib.request.install_opener(opener)
    with ProgressBar(total_length=len(targets), description=task_name) as t:
        for k, v in targets.items():
            if not os.path.exists(v['filename']):
                urllib.request.urlretrieve(
                    v['url'],
                    v['filename'],
                )
            t.update()
            if k == 'index-labels' or k == 'query-labels':
                v['data'] = load_labels(v['filename'])
            if k == 'index' or k == 'query':
                v['data'] = load_mnist(v['filename'])


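# A minimal sketch of the `targets` layout this function expects (the real
# dict, including the Fashion-MNIST URLs, is assembled in
# jina.helloworld.fashion.app; the URLs and filenames below are placeholders):
#
#   targets = {
#       'index-labels': {'url': '<labels-url>', 'filename': 'index-labels'},
#       'query-labels': {'url': '<labels-url>', 'filename': 'query-labels'},
#       'index': {'url': '<images-url>', 'filename': 'index'},
#       'query': {'url': '<images-url>', 'filename': 'query'},
#   }
#   download_data(targets)
#   # afterwards every entry also carries a 'data' np.ndarray loaded from disk

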
def load_mnist(path):
    """
    Load MNIST images from a gzipped IDX file.

    :param path: path of the data file
    :return: images as a np.ndarray of shape (num_images, 28, 28)
    """
    with gzip.open(path, 'rb') as fp:
        # skip the 16-byte IDX header before the raw pixel data
        return np.frombuffer(fp.read(), dtype=np.uint8, offset=16).reshape([-1, 28, 28])


def load_labels(path: str):
    """
    Load labels from a gzipped IDX file.

    :param path: path of the label file
    :return: labels as a np.ndarray of shape (num_labels, 1)
    """
    with gzip.open(path, 'rb') as fp:
        # skip the 8-byte IDX header before the raw label bytes
        return np.frombuffer(fp.read(), dtype=np.uint8, offset=8).reshape([-1, 1])
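

# A minimal sketch of using the two loaders directly (assuming the gzipped IDX
# files fetched by download_data are on disk under these illustrative names):
#
#   images = load_mnist('index')           # e.g. shape (60000, 28, 28), dtype uint8
#   labels = load_labels('index-labels')   # e.g. shape (60000, 1), dtype uint8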