File: test_hash_embed.py

package info (click to toggle)
python-thinc 9.1.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 5,896 kB
  • sloc: python: 17,122; javascript: 1,559; ansic: 342; makefile: 15; sh: 13
file content (19 lines) | stat: -rw-r--r-- 533 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import numpy

from thinc.api import HashEmbed


def test_init():
    """Check the dimensions and parameter shape of an initialized HashEmbed.

    The layer is built with an output width of 64 and 1000 rows; after
    initialization it must report those values as "nO" and "nV", and its
    "E" parameter must be shaped (rows, width).
    """
    width, n_rows = 64, 1000
    embed = HashEmbed(width, n_rows).initialize()
    assert embed.get_dim("nV") == n_rows
    assert embed.get_dim("nO") == width
    assert embed.get_param("E").shape == (n_rows, width)


def test_seed_changes_bucket():
    """Check that HashEmbed output for the same key depends on the seed.

    Two otherwise identical layers are built with seeds 2 and 1, both are
    initialized, and each predicts on the same single-element uint64 array.
    The sums of the two predictions are compared and must differ.
    """
    key = numpy.ones((1,), dtype="uint64")
    # Initialize both layers before predicting with either of them.
    embedders = [HashEmbed(64, 1000, seed=seed).initialize() for seed in (2, 1)]
    first_total, second_total = (layer.predict(key).sum() for layer in embedders)
    assert first_total != second_total