1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
|
import numpy as np
import os
def normalized(a, axis=-1, order=2):
l2 = np.atleast_1d(np.linalg.norm(a, order, axis))
l2[l2==0] = 1
return a / np.expand_dims(l2, axis)
N=100000
dummy_data_multiplier=3
N_queries = 1000
d=8
K=5
np.random.seed(1)
print("Generating data...")
batches_dummy= [ normalized(np.float32(np.random.random( (N,d)))) for _ in range(dummy_data_multiplier)]
batch_final = normalized (np.float32(np.random.random( (N,d))))
queries = normalized(np.float32(np.random.random( (N_queries,d))))
print("Computing distances...")
dist=np.dot(queries,batch_final.T)
topk=np.argsort(-dist)[:,:K]
print("Saving...")
try:
os.mkdir("data")
except OSError as e:
pass
for idx, batch_dummy in enumerate(batches_dummy):
batch_dummy.tofile('data/batch_dummy_%02d.bin' % idx)
batch_final.tofile('data/batch_final.bin')
queries.tofile('data/queries.bin')
np.int32(topk).tofile('data/gt.bin')
with open("data/config.txt", "w") as file:
file.write("%d %d %d %d %d" %(N, dummy_data_multiplier, N_queries, d, K))
|