import gzip
import os
import random
import urllib.request

import numpy as np
from docarray import Document

# NOTE: reconstructed module-level dependencies. ProgressBar is assumed to be
# jina's progress-bar helper; result_html, top_k and evaluation_value are the
# module-level globals that print_result below reads and writes.
from jina.logging.profile import ProgressBar

result_html = []
evaluation_value = {}
top_k = 0


def get_groundtruths(target, pseudo_match=False):
    # group doc_ids by their labels
    a = np.squeeze(target['index-labels']['data'])
    a = np.stack([a, np.arange(len(a))], axis=1)
    a = a[a[:, 0].argsort()]
    lbl_group = np.split(a[:, 1], np.unique(a[:, 0], return_index=True)[1][1:])

    # each label has one groundtruth, i.e. all docs that have the same label
    # are considered as matches
    groundtruths = {lbl: Document(tags={'id': -1}) for lbl in range(10)}
    for lbl, doc_ids in enumerate(lbl_group):
        if not pseudo_match:
            # full-match, each doc has 6K matches
            for doc_id in doc_ids:
                match = Document()
                match.tags['id'] = int(doc_id)
                groundtruths[lbl].matches.append(match)
        else:
            # pseudo-match, each doc has only one match, but this match's id
            # is a list of 6K elements
            match = Document()
            match.tags['id'] = doc_ids.tolist()
            groundtruths[lbl].matches.append(match)

    return groundtruths
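# --- illustration, not part of the original module --------------------------
# A minimal sketch of the label-grouping trick above, using a hand-made
# stand-in for the dict that download_data fills in. With labels [2, 0, 2, 1],
# the groundtruth for label 2 ends up with doc ids 0 and 2.
def _demo_groundtruths():
    demo_target = {'index-labels': {'data': np.array([[2], [0], [2], [1]])}}
    gts = get_groundtruths(demo_target)
    assert [int(m.tags['id']) for m in gts[2].matches] == [0, 2]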
def index_generator(num_docs: int, target: dict):
    """
    Generate the index data.

    :param num_docs: Number of documents to be indexed.
    :param target: Dictionary which stores the data paths
    :yields: index data
    """
    for internal_doc_id in range(num_docs):
        # x_blackwhite.shape is (28, 28)
        x_blackwhite = 255 - target['index']['data'][internal_doc_id]
        # x_color.shape is (28, 28, 3)
        x_color = np.stack((x_blackwhite,) * 3, axis=-1)
        d = Document(content=x_color)
        d.tags['id'] = internal_doc_id
        yield d
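# --- illustration, not part of the original module --------------------------
# A hedged smoke test for index_generator: random uint8 images shaped like
# fashion-mnist stand in for the downloaded data. It assumes docarray exposes
# an ndarray content as `.tensor`.
def _demo_index_generator():
    demo_target = {
        'index': {'data': np.random.randint(0, 256, (4, 28, 28), dtype=np.uint8)}
    }
    docs = list(index_generator(num_docs=4, target=demo_target))
    assert docs[0].tensor.shape == (28, 28, 3)
    assert [d.tags['id'] for d in docs] == [0, 1, 2, 3]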
def query_generator(num_docs: int, target: dict):
    """
    Generate the query data.

    :param num_docs: Number of documents to be queried
    :param target: Dictionary which stores the data paths
    :yields: query data
    """
    for _ in range(num_docs):
        num_data = len(target['query-labels']['data'])
        idx = random.randint(0, num_data - 1)
        # x_blackwhite.shape is (28, 28)
        x_blackwhite = 255 - target['query']['data'][idx]
        # x_color.shape is (28, 28, 3)
        x_color = np.stack((x_blackwhite,) * 3, axis=-1)
        d = Document(
            content=x_color,
            tags={
                'id': -1,
                'query_label': float(target['query-labels']['data'][idx][0]),
            },
        )
        yield d
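# --- illustration, not part of the original module --------------------------
# The same idea for query_generator: queries are drawn at random, so this
# sketch only checks the tag layout. `query_label` is stored as a float so it
# can later be used to look up the matching groundtruth Document.
def _demo_query_generator():
    demo_target = {
        'query': {'data': np.random.randint(0, 256, (4, 28, 28), dtype=np.uint8)},
        'query-labels': {'data': np.array([[0], [1], [2], [3]])},
    }
    d = next(query_generator(num_docs=1, target=demo_target))
    assert d.tags['id'] == -1
    assert d.tags['query_label'] in {0.0, 1.0, 2.0, 3.0}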
def print_result(groundtruths, resp):
    """
    Callback function to receive results.

    :param groundtruths: dictionary of label to groundtruth Document
    :param resp: returned response with data
    """
    from docarray import DocumentArray

    global top_k
    global evaluation_value

    queries = DocumentArray()
    queries.extend(resp.docs)
    gts = DocumentArray()
    for query in queries:
        gt = groundtruths[query.tags['query_label']]
        gts.append(gt)
    queries.evaluate(
        gts, metric='recall_at_k', hash_fn=lambda d: d.tags['id'], top_k=50
    )
    queries.evaluate(
        gts, metric='precision_at_k', hash_fn=lambda d: d.tags['id'], top_k=50
    )
    for query in queries:
        vi = query.uri
        result_html.append(f'<tr><td><img src="{vi}"/></td><td>')
        top_k = len(query.matches)
        for kk in query.matches:
            kmi = kk.uri
            result_html.append(
                f'<img src="{kmi}" style="opacity:{kk.scores["cosine"].value}"/>'
            )
        result_html.append('</td></tr>\n')
        # update evaluation values; as the evaluator returns a running average,
        # we can simply replace the stored value
        for k, evaluation in query.evaluations.items():
            evaluation_value[k] = evaluation.value
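# --- illustration, not part of the original module --------------------------
# A hedged sketch of how print_result is typically wired up as a client
# callback; `MyEncoder` and `MyIndexer` are hypothetical executors, not part
# of this module, and the Flow layout is an assumption.
def _demo_search_callback(target: dict):
    from functools import partial

    from jina import Flow

    f = Flow().add(uses=MyEncoder).add(uses=MyIndexer)  # hypothetical executors
    with f:
        f.post(
            '/search',
            query_generator(num_docs=128, target=target),
            on_done=partial(print_result, get_groundtruths(target)),
        )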
def download_data(targets, download_proxy=None, task_name='download fashion-mnist'):
    """
    Download data.

    :param targets: target path for data.
    :param download_proxy: download proxy (e.g. 'http', 'https')
    :param task_name: name of the task
    """
    opener = urllib.request.build_opener()
    opener.addheaders = [('User-agent', 'Mozilla/5.0')]
    if download_proxy:
        proxy = urllib.request.ProxyHandler(
            {'http': download_proxy, 'https': download_proxy}
        )
        opener.add_handler(proxy)
    urllib.request.install_opener(opener)
    with ProgressBar(total_length=len(targets), description=task_name) as t:
        for k, v in targets.items():
            if not os.path.exists(v['filename']):
                urllib.request.urlretrieve(v['url'], v['filename'])
            t.update()
            if k == 'index-labels' or k == 'query-labels':
                v['data'] = load_labels(v['filename'])
            if k == 'index' or k == 'query':
                v['data'] = load_mnist(v['filename'])
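# --- illustration, not part of the original module --------------------------
# A hedged sketch of the `targets` dict download_data expects: each entry
# carries a 'url' and a local 'filename', and download_data adds the parsed
# 'data' array. The URLs below point at the usual fashion-mnist mirror; the
# filenames are arbitrary choices for this example.
def _demo_download():
    base = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
    targets = {
        'index-labels': {'url': base + 'train-labels-idx1-ubyte.gz', 'filename': 'index-labels.gz'},
        'query-labels': {'url': base + 't10k-labels-idx1-ubyte.gz', 'filename': 'query-labels.gz'},
        'index': {'url': base + 'train-images-idx3-ubyte.gz', 'filename': 'index.gz'},
        'query': {'url': base + 't10k-images-idx3-ubyte.gz', 'filename': 'query.gz'},
    }
    download_data(targets)
    print(targets['index']['data'].shape)  # expected: (60000, 28, 28)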
def load_mnist(path):
    """
    Load MNIST data

    :param path: path of data
    :return: MNIST data in np.array
    """
    with gzip.open(path, 'rb') as fp:
        return np.frombuffer(fp.read(), dtype=np.uint8, offset=16).reshape(
            [-1, 28, 28]
        )
def load_labels(path: str):
    """
    Load labels from path

    :param path: path of labels
    :return: labels in np.array
    """
    with gzip.open(path, 'rb') as fp:
        return np.frombuffer(fp.read(), dtype=np.uint8, offset=8).reshape([-1, 1])
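# --- illustration, not part of the original module --------------------------
# Why offset=16 and offset=8 above: IDX files open with a big-endian header;
# image files carry magic, count, rows and cols (4 bytes each, 16 in total),
# label files carry only magic and count (8 bytes). A quick header check:
def _demo_idx_header(path: str):
    import struct

    with gzip.open(path, 'rb') as fp:
        magic, num, rows, cols = struct.unpack('>IIII', fp.read(16))
    print(magic, num, rows, cols)  # for fashion-mnist images: 2051 60000 28 28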