File: test_datasets_samplers.py

package info (click to toggle)
pytorch-vision 0.14.1-2
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 15,188 kB
  • sloc: python: 49,008; cpp: 10,019; sh: 610; java: 550; xml: 79; objc: 56; makefile: 32
file content (86 lines) | stat: -rw-r--r-- 3,785 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import pytest
import torch
from common_utils import assert_equal, get_list_of_videos
from torchvision import io
from torchvision.datasets.samplers import DistributedSampler, RandomClipSampler, UniformClipSampler
from torchvision.datasets.video_utils import VideoClips


@pytest.mark.skipif(not io.video._av_available(), reason="this test requires av")
class TestDatasetsSamplers:
    """Tests for the clip samplers exposed by ``torchvision.datasets.samplers``.

    Each test creates three temporary videos with ``get_list_of_videos`` and
    wraps them as ``VideoClips(paths, 5, 5)``. The assertions below rely on a
    video created with size 25 contributing five consecutive clip indices and
    a video created with size 10 contributing two.

    The whole class is skipped when ``io.video._av_available()`` is falsy.
    """

    @staticmethod
    def _count_clips_per_video(clip_indices, clips_per_video=5):
        """Group flat clip indices by the video they belong to.

        Args:
            clip_indices: tensor of flat clip indices.
            clips_per_video: number of consecutive clip indices owned by each
                video; floor-dividing an index by it gives the video id.

        Returns:
            The ``(unique_video_ids, occurrences)`` pair produced by
            ``torch.unique(..., return_counts=True)``.
        """
        owners = clip_indices.div(clips_per_video, rounding_mode="floor")
        return torch.unique(owners, return_counts=True)

    def test_random_clip_sampler(self, tmpdir):
        """Random sampling draws exactly three clips from each of three equal videos."""
        paths = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25])
        clips = VideoClips(paths, 5, 5)
        sampler = RandomClipSampler(clips, 3)
        assert len(sampler) == 3 * 3

        drawn = torch.tensor(list(sampler))
        video_ids, per_video = self._count_clips_per_video(drawn)
        assert_equal(video_ids, torch.tensor([0, 1, 2]))
        assert_equal(per_video, torch.tensor([3, 3, 3]))

    def test_random_clip_sampler_unequal(self, tmpdir):
        """A video with only two clips contributes both; the others contribute three."""
        paths = get_list_of_videos(tmpdir, num_videos=3, sizes=[10, 25, 25])
        clips = VideoClips(paths, 5, 5)
        sampler = RandomClipSampler(clips, 3)
        assert len(sampler) == 2 + 3 + 3

        drawn = list(sampler)
        # Clips 0 and 1 belong to the short first video. Both must have been
        # sampled; dropping them leaves only the two equally sized videos,
        # which simplifies the remaining checks.
        for short_video_clip in (0, 1):
            assert short_video_clip in drawn
            drawn.remove(short_video_clip)

        # Shift the remaining indices so the second video starts at clip 0.
        shifted = torch.tensor(drawn) - 2
        video_ids, per_video = self._count_clips_per_video(shifted)
        assert_equal(video_ids, torch.tensor([0, 1]))
        assert_equal(per_video, torch.tensor([3, 3]))

    def test_uniform_clip_sampler(self, tmpdir):
        """Uniform sampling picks evenly spaced clips (first, middle, last) per video."""
        paths = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25])
        clips = VideoClips(paths, 5, 5)
        sampler = UniformClipSampler(clips, 3)
        assert len(sampler) == 3 * 3

        picked = torch.tensor(list(sampler))
        video_ids, per_video = self._count_clips_per_video(picked)
        assert_equal(video_ids, torch.tensor([0, 1, 2]))
        assert_equal(per_video, torch.tensor([3, 3, 3]))
        assert_equal(picked, torch.tensor([0, 2, 4, 5, 7, 9, 10, 12, 14]))

    def test_uniform_clip_sampler_insufficient_clips(self, tmpdir):
        """Uniform sampling still yields three indices for a video with two clips.

        The expected sequence starts with ``0, 0, 1``: the short first video
        has one of its clips repeated to reach three samples.
        """
        paths = get_list_of_videos(tmpdir, num_videos=3, sizes=[10, 25, 25])
        clips = VideoClips(paths, 5, 5)
        sampler = UniformClipSampler(clips, 3)
        assert len(sampler) == 3 * 3

        picked = torch.tensor(list(sampler))
        assert_equal(picked, torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11]))

    def test_distributed_sampler_and_uniform_clip_sampler(self, tmpdir):
        """Two replicas each receive two whole groups of three uniform clips.

        The expected output for rank 1 ends with ``0, 2, 4``, the same group
        rank 0 starts with, so each replica reports a length of six even
        though the underlying sampler only yields three groups.
        """
        paths = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25])
        clips = VideoClips(paths, 5, 5)
        clip_sampler = UniformClipSampler(clips, 3)

        expected_by_rank = (
            (0, [0, 2, 4, 10, 12, 14]),
            (1, [5, 7, 9, 0, 2, 4]),
        )
        for rank, expected in expected_by_rank:
            replica_sampler = DistributedSampler(
                clip_sampler,
                num_replicas=2,
                rank=rank,
                group_size=3,
            )
            received = torch.tensor(list(replica_sampler))
            assert len(replica_sampler) == 6
            assert_equal(received, torch.tensor(expected))


if __name__ == "__main__":
    # Executed as a script: hand this file to pytest so its tests are collected and run.
    pytest.main(args=[__file__])