File: test_typedefs.py

package info (click to toggle)
scikit-learn 1.4.2+dfsg-8
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 25,036 kB
  • sloc: python: 201,105; cpp: 5,790; ansic: 854; makefile: 304; sh: 56; javascript: 20
file content (22 lines) | stat: -rw-r--r-- 631 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import numpy as np
import pytest

from sklearn.utils._typedefs import testing_make_array_from_typed_val


@pytest.mark.parametrize(
    "type_t, value, expected_dtype",
    [
        ("uint8_t", 1, np.uint8),
        ("intp_t", 1, np.intp),
        ("float64_t", 1.0, np.float64),
        ("float32_t", 1.0, np.float32),
        ("int32_t", 1, np.int32),
        ("int64_t", 1, np.int64),
    ],
)
def test_types(type_t, value, expected_dtype):
    """Ensure each typedef name maps onto the matching NumPy dtype.

    For every parametrized case, the helper imported from
    ``sklearn.utils._typedefs`` is indexed by the typedef name ``type_t``
    and called with ``value``; the dtype of the array it returns must
    compare equal to ``expected_dtype``.
    """
    # Select the helper variant registered under this typedef name.
    make_array = testing_make_array_from_typed_val[type_t]
    produced = make_array(value)
    assert produced.dtype == expected_dtype