File: conftest.py

import pytest

import torch

# Absolute tolerances for closeness checks, keyed by dtype and stored as tensors.
FLOAT_TOLERANCE = {
    t: torch.as_tensor(v, dtype=t)
    for t, v in {torch.float32: 1e-5, torch.float64: 1e-10}.items()
}


@pytest.fixture(scope="session", autouse=True, params=["float32", "float64"])
def float_tolerance(request):
    """Run all tests with various PyTorch default dtypes.

    This is a session-wide, autouse fixture; request it explicitly only when a
    test needs to know the tolerance for the current default dtype.

    Returns
    -------
        A precision threshold to use for closeness tests.
    """
    # Switch the process-wide default dtype for this parametrization, and
    # restore the previous default when the session ends.
    old_dtype = torch.get_default_dtype()
    dtype = {"float32": torch.float32, "float64": torch.float64}[request.param]
    torch.set_default_dtype(dtype)
    yield FLOAT_TOLERANCE[dtype]
    torch.set_default_dtype(old_dtype)


@pytest.fixture(scope="session")
def allclose(float_tolerance):
    return lambda x, y: torch.allclose(x, y, atol=float_tolerance)
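
As a usage sketch (the test below is hypothetical and not part of this file): because float_tolerance is parametrized and autouse, every test in the suite runs once per default dtype, and a test only has to request allclose (or float_tolerance) when it needs to compare tensors.

import torch


def test_matmul_identity(allclose):
    # Runs under both default dtypes; allclose applies the matching tolerance.
    x = torch.randn(4, 4)
    assert allclose(x @ torch.eye(4), x)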