1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
|
import numpy as np
import pytest
from sklearn.utils._weight_vector import (
WeightVector32,
WeightVector64,
)
@pytest.mark.parametrize(
    "dtype, WeightVector",
    [
        (np.float32, WeightVector32),
        (np.float64, WeightVector64),
    ],
)
def test_type_invariance(dtype, WeightVector):
    """Check the `dtype` consistency of `WeightVector`.

    Build a `WeightVector` from `weights` and `average_weights` arrays of the
    parametrized `dtype`, then check that its `w` and `aw` attributes, once
    converted with `np.asarray`, still have that same dtype.
    """
    # A local seeded generator keeps the test deterministic and avoids
    # mutating the global NumPy random state shared with other tests.
    rng = np.random.RandomState(0)
    weights = rng.rand(100).astype(dtype)
    average_weights = rng.rand(100).astype(dtype)

    weight_vector = WeightVector(weights, average_weights)

    # Compare dtypes by equality: `==` does not depend on NumPy handing back
    # the very same dtype instance, unlike an `is` identity check.
    expected_dtype = np.dtype(dtype)
    assert np.asarray(weight_vector.w).dtype == expected_dtype
    assert np.asarray(weight_vector.aw).dtype == expected_dtype
|