File: wmh_test.go

package info (click to toggle)
golang-github-go-enry-go-license-detector 4.3.0%2Bgit20221007.a3a1cc6-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 13,068 kB
  • sloc: makefile: 25
file content (115 lines) | stat: -rw-r--r-- 2,492 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
package wmh

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

// TestWMHSerialize round-trips a WeightedMinHasher through MarshalBinary and
// UnmarshalBinary and checks that every compared field of the decoded hasher
// equals the corresponding field of the original.
func TestWMHSerialize(t *testing.T) {
	hasher := NewWeightedMinHasher(100, 50, 7)

	data, err := hasher.MarshalBinary()
	if !assert.NoError(t, err) {
		// Without a valid encoding there is nothing meaningful to decode.
		return
	}

	restored := &WeightedMinHasher{}
	if !assert.NoError(t, restored.UnmarshalBinary(data)) {
		// Comparing fields of a hasher that failed to decode would only
		// bury the real failure under follow-on mismatches.
		return
	}

	assert.Equal(t, hasher.Bitness, restored.Bitness)
	assert.Equal(t, hasher.dim, restored.dim)
	assert.Equal(t, hasher.sampleSize, restored.sampleSize)
	assert.Equal(t, hasher.rs, restored.rs)
	assert.Equal(t, hasher.lnCs, restored.lnCs)
	assert.Equal(t, hasher.betas, restored.betas)
}

// TestWMHHash hashes a fixed sparse vector (ten weights at ten indices) with
// Bitness set to 32 and compares the result against a hard-coded list of
// expected hash values.
func TestWMHHash(t *testing.T) {
	minHasher := NewWeightedMinHasher(100, 50, 7)
	assert.NotNil(t, minHasher)
	minHasher.Bitness = 32

	weights := []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
	positions := []int{0, 10, 20, 30, 40, 50, 60, 70, 80, 90}
	got := minHasher.Hash(weights, positions)

	// Reference Python (datasketch) snippet kept alongside the expected
	// values; it appears to describe how they were generated.
	/*
		import numpy, datasketch
		gen = datasketch.WeightedMinHashGenerator(100, 50, 7)
		with open("test_data/wmh.bin", "rb") as fin:
			fin.read(9)
			gen.rs = numpy.frombuffer(fin.read(100*50*4), dtype=numpy.float32).reshape(50, 100)
			gen.ln_cs = numpy.frombuffer(fin.read(100*50*4), dtype=numpy.float32).reshape(50, 100)
			betas = numpy.frombuffer(fin.read(100*50*2), dtype=numpy.uint16)
			gen.betas = (betas / ((1 << 16) - 1)).astype(numpy.float32).reshape(50, 100)
		v = numpy.zeros(100, numpy.float32)
		for i, ii in enumerate([0, 10, 20, 30, 40, 50, 60, 70, 80, 90]):
			v[ii] = i + 1
		mh = gen.minhash(v)
		for h in mh.hashvalues:
			print("%d," % (h[0] | (h[1] << 16)))
	*/
	// Fifty expected values, five per row, in output order.
	want := []uint64{
		65586, 0, 65626, 65616, 65626,
		30, 65616, 90, 40, 65576,
		65596, 65586, 65626, 65626, 589884,
		20, 65616, 65626, 65596, 65626,
		262234, 131152, 65596, 65596, 65556,
		65626, 65576, 65606, 65626, 65606,
		10, 90, 65596, 65586, 65626,
		65606, 65626, 0, 131162, 65626,
		65576, 65626, 65616, 65606, 65606,
		131152, 65566, 65626, 65586, 65626,
	}
	assert.Equal(t, want, got)
}

// TestWMHTrash feeds Hash three malformed inputs and expects each call to
// panic.
func TestWMHTrash(t *testing.T) {
	hasher := NewWeightedMinHasher(100, 50, 7)

	// expectPanic asserts that hashing the given values/indices pair panics.
	expectPanic := func(values []float32, indices []int) {
		assert.Panics(t, func() {
			hasher.Hash(values, indices)
		})
	}

	// Nine values against ten indices.
	expectPanic(
		[]float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
		[]int{0, 10, 20, 30, 40, 50, 60, 70, 80, 90})

	// Ten values against nine indices.
	expectPanic(
		[]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
		[]int{0, 10, 20, 30, 40, 50, 60, 70, 80})

	// Matching lengths, but the last index is 100 — presumably out of range
	// for a hasher constructed with 100 as its first argument.
	expectPanic(
		[]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
		[]int{0, 10, 20, 30, 40, 50, 60, 70, 80, 100})
}