1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
|
# cython: binding=True, infer_types=True, profile=False
from libc.stdint cimport int32_t, int64_t
from preshed.maps cimport PreshMap
from typing import Callable, Dict, Mapping, Optional, Tuple, Union, cast
import numpy
from ..config import registry
from ..model import Model
from ..types import Ints1d, Ints2d
from ..util import to_numpy
InT = Union[Ints1d, Ints2d]
OutT = Ints2d
cdef lookup(PreshMap mapping, int64_t[:] keys, int64_t default):
"""
Faster dict.get(keys, default) for the case when
the "dict" is a Dict[int, int] converted to PreshMap
and the "keys" is a numpy integer vector.
"""
cdef int maxi = len(keys)
result = numpy.empty(maxi, dtype="int64")
cdef int64_t[:] result_view = result
for i in range(maxi):
v = mapping[keys[i]]
if v is None:
result_view[i] = default
else:
result_view[i] = v
return result
@registry.layers("premap_ids.v1")
def premap_ids(
mapping_table: Mapping[int, int],
default: int = 0,
*,
column: Optional[int] = None
):
"""Remap integer inputs to integers a mapping table, usually as a
preprocess before embeddings."""
mapper = PreshMap(initial_size=len(mapping_table))
for k, v in mapping_table.items():
if not (isinstance(k, int) and isinstance(v, int)):
raise ValueError(
"mapping_table has to be of type Mapping[int, int], "
f"but found {k}, {type(k)} and {v}, {type(v)}"
)
mapper[k] = v
return Model(
"premap_ids",
forward,
attrs={
"mapping_table": mapper, "default": default, "column": column
}
)
def forward(
model: Model, inputs: InT, is_train: bool
) -> Tuple[OutT, Callable]:
table = model.attrs["mapping_table"]
default = model.attrs["default"]
column = model.attrs["column"]
# Have to convert to numpy anyways, because
# cupy ints don't work together with Python ints.
if column is None:
idx = to_numpy(inputs)
else:
idx = to_numpy(cast(Ints2d, inputs)[:, column])
result = lookup(table, idx, default)
arr = model.ops.asarray2i(result)
output = model.ops.reshape2i(arr, -1, 1)
inputs_shape = inputs.shape
def backprop(dY: OutT) -> InT:
if len(inputs_shape) == 1:
return model.ops.alloc1i(*inputs_shape)
else:
return model.ops.alloc2i(*inputs_shape)
return output, backprop
|