File: test_transforms_v2_utils.py

package info (click to toggle)
pytorch-vision 0.21.0-3
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 20,228 kB
  • sloc: python: 65,904; cpp: 11,406; ansic: 2,459; java: 550; sh: 265; xml: 79; objc: 56; makefile: 33
file content (92 lines) | stat: -rw-r--r-- 4,104 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import PIL.Image
import pytest

import torch

import torchvision.transforms.v2._utils
from common_utils import DEFAULT_SIZE, make_bounding_boxes, make_detection_masks, make_image

from torchvision import tv_tensors
from torchvision.transforms.v2._utils import has_all, has_any
from torchvision.transforms.v2.functional import to_pil_image


# Sample inputs shared by the parametrized cases below. They are created once,
# at import time, by the factory helpers imported from `common_utils`, all using
# the same DEFAULT_SIZE.
# NOTE(review): the expectation tables below rely on these being instances of
# tv_tensors.Image, tv_tensors.BoundingBoxes and tv_tensors.Mask respectively --
# confirm against the `common_utils` helpers.
IMAGE = make_image(DEFAULT_SIZE, color_space="RGB")
# Bounding boxes are explicitly requested in XYXY format.
BOUNDING_BOX = make_bounding_boxes(DEFAULT_SIZE, format=tv_tensors.BoundingBoxFormat.XYXY)
MASK = make_detection_masks(DEFAULT_SIZE)


# Building blocks reused throughout the `test_has_any` table.
_ANY_FULL_SAMPLE = (IMAGE, BOUNDING_BOX, MASK)
_ANY_IMAGE_LIKE_CHECKS = (
    tv_tensors.Image,
    PIL.Image.Image,
    torchvision.transforms.v2._utils.is_pure_tensor,
)


@pytest.mark.parametrize(
    ("sample", "types", "expected"),
    [
        # A single requested type that is present in the sample.
        (_ANY_FULL_SAMPLE, (tv_tensors.Image,), True),
        (_ANY_FULL_SAMPLE, (tv_tensors.BoundingBoxes,), True),
        (_ANY_FULL_SAMPLE, (tv_tensors.Mask,), True),
        # Two requested types, both present in the sample.
        (_ANY_FULL_SAMPLE, (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
        (_ANY_FULL_SAMPLE, (tv_tensors.Image, tv_tensors.Mask), True),
        (_ANY_FULL_SAMPLE, (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
        # Single-item samples whose type is not among the requested ones.
        ((MASK,), (tv_tensors.Image, tv_tensors.BoundingBoxes), False),
        ((BOUNDING_BOX,), (tv_tensors.Image, tv_tensors.Mask), False),
        ((IMAGE,), (tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
        # All three types requested: full sample vs. empty sample.
        (
            _ANY_FULL_SAMPLE,
            (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
            True,
        ),
        ((), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
        # Callables passed in place of types.
        (_ANY_FULL_SAMPLE, (lambda item: isinstance(item, tv_tensors.Image),), True),
        (_ANY_FULL_SAMPLE, (lambda _: False,), False),
        (_ANY_FULL_SAMPLE, (lambda _: True,), True),
        # Mixed types and a callable, applied to three flavours of image input:
        # the tv_tensors image, the same data wrapped by torch.Tensor(...), and a PIL image.
        ((IMAGE,), _ANY_IMAGE_LIKE_CHECKS, True),
        ((torch.Tensor(IMAGE),), _ANY_IMAGE_LIKE_CHECKS, True),
        ((to_pil_image(IMAGE),), _ANY_IMAGE_LIKE_CHECKS, True),
    ],
)
def test_has_any(sample, types, expected):
    """Check that ``has_any(sample, *types)`` yields exactly ``expected``.

    ``sample`` is a tuple of inputs, ``types`` is a tuple of classes and/or
    callables that is unpacked into positional arguments, and ``expected`` is
    the boolean the call must return (compared by identity, so a merely
    truthy/falsy value does not pass).
    """
    outcome = has_any(sample, *types)
    assert outcome is expected


@pytest.mark.parametrize(
    ("sample", "types", "expected"),
    [
        # Every requested type is present in the sample.
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image,), True),
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes,), True),
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Mask,), True),
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), True),
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
        (
            (IMAGE, BOUNDING_BOX, MASK),
            (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
            True,
        ),
        # Two types requested, one of them absent from the sample.
        ((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), False),
        ((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), False),
        ((IMAGE, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
        # Three types requested, exactly one absent from the sample.
        ((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
        ((IMAGE, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
        ((IMAGE, BOUNDING_BOX), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
        # Callables passed in place of types.
        (
            (IMAGE, BOUNDING_BOX, MASK),
            (lambda obj: isinstance(obj, (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask)),),
            True,
        ),
        ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
        ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
    ],
)
def test_has_all(sample, types, expected):
    """Check that ``has_all(sample, *types)`` yields exactly ``expected``.

    ``sample`` is a tuple of inputs, ``types`` is a tuple of classes and/or
    callables that is unpacked into positional arguments, and ``expected`` is
    the boolean the call must return (compared by identity, so a merely
    truthy/falsy value does not pass).
    """
    assert has_all(sample, *types) is expected