import contextlib
import functools
import importlib
import inspect
import itertools
import os
import pathlib
import platform
import random
import shutil
import string
import struct
import tarfile
import unittest
import unittest.mock
import zipfile
from collections import defaultdict
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
import numpy as np
import PIL
import PIL.Image
import pytest
import torch
import torchvision.datasets
import torchvision.io
from common_utils import disable_console_output, get_tmp_dir
from torch.utils._pytree import tree_any
from torch.utils.data import DataLoader
from torchvision import tv_tensors
from torchvision.datasets import wrap_dataset_for_transforms_v2
from torchvision.transforms.functional import get_dimensions
from torchvision.transforms.v2.functional import get_size
__all__ = [
"UsageError",
"lazy_importer",
"test_all_configs",
"DatasetTestCase",
"ImageDatasetTestCase",
"VideoDatasetTestCase",
"create_image_or_video_tensor",
"create_image_file",
"create_image_folder",
"create_video_file",
"create_video_folder",
"make_tar",
"make_zip",
"create_random_string",
]
class UsageError(Exception):
"""Should be raised in case an error happens in the setup rather than the test."""
class LazyImporter:
r"""Lazy importer for additional dependencies.
    Some datasets require additional packages that are not direct dependencies of torchvision. Instances of this class
provide modules listed in MODULES as attributes. They are only imported when accessed.
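
    For example, accessing the ``av`` attribute (one of the entries in ``MODULES``) only imports PyAV on first
    access and raises a :class:`UsageError` if it is not installed:

    .. code-block::

        av = lazy_importer.av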
"""
MODULES = (
"av",
"lmdb",
"pycocotools",
"requests",
"scipy.io",
"scipy.sparse",
"h5py",
)
def __init__(self):
modules = defaultdict(list)
for module in self.MODULES:
module, *submodules = module.split(".", 1)
if submodules:
modules[module].append(submodules[0])
else:
# This introduces the module so that it is known when we later iterate over the dictionary.
modules.__missing__(module)
for module, submodules in modules.items():
            # We need the quirky 'module=module' and 'submodules=submodules' default arguments to the lambda, since
            # otherwise the lookup would happen at call time rather than at definition time. Without them, every
            # property would try to import the last item in 'modules'.
setattr(
type(self),
module,
property(lambda self, module=module, submodules=submodules: LazyImporter._import(module, submodules)),
)
@staticmethod
def _import(package, subpackages):
try:
module = importlib.import_module(package)
except ImportError as error:
raise UsageError(
f"Failed to import module '{package}'. "
f"This probably means that the current test case needs '{package}' installed, "
f"but it is not a dependency of torchvision. "
f"You need to install it manually, for example 'pip install {package}'."
) from error
for name in subpackages:
importlib.import_module(f".{name}", package=package)
return module
lazy_importer = LazyImporter()
def requires_lazy_imports(*modules):
def outer_wrapper(fn):
@functools.wraps(fn)
def inner_wrapper(*args, **kwargs):
for module in modules:
                # Only the top-level package is exposed as an attribute on the lazy importer; accessing it also
                # imports the registered submodules.
                getattr(lazy_importer, module.split(".", 1)[0])
return fn(*args, **kwargs)
return inner_wrapper
return outer_wrapper
def test_all_configs(test):
"""Decorator to run test against all configurations.
    Add this as a decorator to an arbitrary test to run it against all configurations. This includes
:attr:`DatasetTestCase.DEFAULT_CONFIG` and :attr:`DatasetTestCase.ADDITIONAL_CONFIGS`.
The current configuration is provided as the first parameter for the test:
.. code-block::
        @test_all_configs
def test_foo(self, config):
pass
.. note::
This will try to remove duplicate configurations. During this process it will not preserve a potential
ordering of the configurations or an inner ordering of a configuration.
"""
def maybe_remove_duplicates(configs):
try:
return [dict(config_) for config_ in {tuple(sorted(config.items())) for config in configs}]
except TypeError:
# A TypeError will be raised if a value of any config is not hashable, e.g. a list. In that case duplicate
# removal would be a lot more elaborate, and we simply bail out.
return configs
@functools.wraps(test)
def wrapper(self):
configs = []
if self.DEFAULT_CONFIG is not None:
configs.append(self.DEFAULT_CONFIG)
if self.ADDITIONAL_CONFIGS is not None:
configs.extend(self.ADDITIONAL_CONFIGS)
if not configs:
configs = [self._KWARG_DEFAULTS.copy()]
else:
configs = maybe_remove_duplicates(configs)
for config in configs:
with self.subTest(**config):
test(self, config)
return wrapper
class DatasetTestCase(unittest.TestCase):
"""Abstract base class for all dataset testcases.
You have to overwrite the following class attributes:
- DATASET_CLASS (torchvision.datasets.VisionDataset): Class of dataset to be tested.
- FEATURE_TYPES (Sequence[Any]): Types of the elements returned by index access of the dataset. Instead of
          providing these manually, you can instead subclass ``ImageDatasetTestCase`` or ``VideoDatasetTestCase`` to
          get a reasonable default that should work for most cases. Each entry of the sequence may be a tuple,
to indicate multiple possible values.
Optionally, you can overwrite the following class attributes:
- DEFAULT_CONFIG (Dict[str, Any]): Config that will be used by default. If omitted, this defaults to all
keyword arguments of the dataset minus ``transform``, ``target_transform``, ``transforms``, and
``download``. Overwrite this if you want to use a default value for a parameter for which the dataset does
not provide one.
- ADDITIONAL_CONFIGS (Sequence[Dict[str, Any]]): Additional configs that should be tested. Each dictionary can
contain an arbitrary combination of dataset parameters that are **not** ``transform``, ``target_transform``,
``transforms``, or ``download``.
- REQUIRED_PACKAGES (Iterable[str]): Additional dependencies to use the dataset. If these packages are not
available, the tests are skipped.
Additionally, you need to overwrite the ``inject_fake_data()`` method that provides the data that the tests rely on.
    The fake data should resemble the original data as closely as necessary, while containing only a few examples.
    During the creation of the dataset, the check, download, and extract functions from
    ``torchvision.datasets.utils`` are disabled.
Without further configuration, the testcase will test if
1. the dataset raises a :class:`FileNotFoundError` or a :class:`RuntimeError` if the data files are not found or
corrupted,
2. the dataset inherits from `torchvision.datasets.VisionDataset`,
3. the dataset can be turned into a string,
4. the feature types of a returned example matches ``FEATURE_TYPES``,
5. the number of examples matches the injected fake data, and
6. the dataset calls ``transform``, ``target_transform``, or ``transforms`` if available when accessing data.
    Cases 3. to 6. are tested against all configurations, i.e. ``DEFAULT_CONFIG`` and ``ADDITIONAL_CONFIGS``.
    To add dataset-specific tests, create a new method whose name is prefixed with ``test_`` and that takes no arguments:
.. code-block::
def test_foo(self):
pass
If you want to run the test against all configs, add the ``@test_all_configs`` decorator to the definition and
accept a single argument:
.. code-block::
@test_all_configs
def test_bar(self, config):
pass
Within the test you can use the ``create_dataset()`` method that yields the dataset as well as additional
    information provided by the ``inject_fake_data()`` method:
.. code-block::
def test_baz(self):
with self.create_dataset() as (dataset, info):
pass
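
    A minimal sketch of a complete test case might look like this (``MyDataset`` and its raw data layout are
    hypothetical stand-ins):

    .. code-block::

        class MyDatasetTestCase(DatasetTestCase):
            DATASET_CLASS = MyDataset  # hypothetical
            FEATURE_TYPES = (PIL.Image.Image, int)

            def inject_fake_data(self, tmpdir, config):
                num_examples = 3
                create_image_folder(tmpdir, "images", lambda idx: f"{idx}.png", num_examples)
                return num_examples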
"""
DATASET_CLASS = None
FEATURE_TYPES = None
DEFAULT_CONFIG = None
ADDITIONAL_CONFIGS = None
REQUIRED_PACKAGES = None
# These keyword arguments are checked by test_transforms in case they are available in DATASET_CLASS.
_TRANSFORM_KWARGS = {
"transform",
"target_transform",
"transforms",
}
# These keyword arguments get a 'special' treatment and should not be set in DEFAULT_CONFIG or ADDITIONAL_CONFIGS.
_SPECIAL_KWARGS = {
*_TRANSFORM_KWARGS,
"download",
}
    # These fields are populated during setUpClass() within _populate_private_class_attributes()
# This will be a dictionary containing all keyword arguments with their respective default values extracted from
# the dataset constructor.
_KWARG_DEFAULTS = None
# This will be a set of all _SPECIAL_KWARGS that the dataset constructor takes.
_HAS_SPECIAL_KWARG = None
# These functions are disabled during dataset creation in create_dataset().
_CHECK_FUNCTIONS = {
"check_md5",
"check_integrity",
}
_DOWNLOAD_EXTRACT_FUNCTIONS = {
"download_url",
"download_file_from_google_drive",
"extract_archive",
"download_and_extract_archive",
}
def dataset_args(self, tmpdir: str, config: Dict[str, Any]) -> Sequence[Any]:
"""Define positional arguments passed to the dataset.
.. note::
The default behavior is only valid if the dataset to be tested has ``root`` as the only required parameter.
Otherwise, you need to overwrite this method.
Args:
tmpdir (str): Path to a temporary directory. For most cases this acts as root directory for the dataset
to be created and in turn also for the fake data injected here.
config (Dict[str, Any]): Configuration that will be passed to the dataset constructor. It provides at least
fields for all dataset parameters with default values.
Returns:
(Tuple[str]): ``tmpdir`` which corresponds to ``root`` for most datasets.
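
        For example, a sketch for a hypothetical dataset with the signature ``MyDataset(root, annotation_file)``:

        .. code-block::

            def dataset_args(self, tmpdir, config):
                return tmpdir, os.path.join(tmpdir, "annotations.json")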
"""
return (tmpdir,)
def inject_fake_data(self, tmpdir: str, config: Dict[str, Any]) -> Union[int, Dict[str, Any]]:
"""Inject fake data for dataset into a temporary directory.
During the creation of the dataset the download and extract logic is disabled. Thus, the fake data injected
here needs to resemble the raw data, i.e. the state of the dataset directly after the files are downloaded and
potentially extracted.
Args:
tmpdir (str): Path to a temporary directory. For most cases this acts as root directory for the dataset
to be created and in turn also for the fake data injected here.
config (Dict[str, Any]): Configuration that will be passed to the dataset constructor. It provides at least
fields for all dataset parameters with default values.
Needs to return one of the following:
1. (int): Number of examples in the dataset to be created, or
2. (Dict[str, Any]): Additional information about the injected fake data. Must contain the field
``"num_examples"`` that corresponds to the number of examples in the dataset to be created.
"""
raise NotImplementedError("You need to provide fake data in order for the tests to run.")
@contextlib.contextmanager
def create_dataset(
self,
config: Optional[Dict[str, Any]] = None,
inject_fake_data: bool = True,
patch_checks: Optional[bool] = None,
**kwargs: Any,
) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]:
r"""Create the dataset in a temporary directory.
The configuration passed to the dataset is populated to contain at least all parameters with default values.
For this the following order of precedence is used:
1. Parameters in :attr:`kwargs`.
2. Configuration in :attr:`config`.
3. Configuration in :attr:`~DatasetTestCase.DEFAULT_CONFIG`.
4. Default parameters of the dataset.
Args:
config (Optional[Dict[str, Any]]): Configuration that will be used to create the dataset.
inject_fake_data (bool): If ``True`` (default) inject the fake data with :meth:`.inject_fake_data` before
creating the dataset.
patch_checks (Optional[bool]): If ``True`` disable integrity check logic while creating the dataset. If
omitted defaults to the same value as ``inject_fake_data``.
**kwargs (Any): Additional parameters passed to the dataset. These parameters take precedence in case they
overlap with ``config``.
Yields:
dataset (torchvision.dataset.VisionDataset): Dataset.
info (Dict[str, Any]): Additional information about the injected fake data. See :meth:`.inject_fake_data`
for details.
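
        For example, a sketch of a custom test that overrides a single dataset parameter (``split`` is
        hypothetical here):

        .. code-block::

            def test_custom(self):
                with self.create_dataset(split="train") as (dataset, info):
                    assert len(dataset) == info["num_examples"]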
"""
if patch_checks is None:
patch_checks = inject_fake_data
special_kwargs, other_kwargs = self._split_kwargs(kwargs)
complete_config = self._KWARG_DEFAULTS.copy()
if self.DEFAULT_CONFIG:
complete_config.update(self.DEFAULT_CONFIG)
if config:
complete_config.update(config)
if other_kwargs:
complete_config.update(other_kwargs)
if "download" in self._HAS_SPECIAL_KWARG and special_kwargs.get("download", False):
            # Force 'download' to False if a truthy value was passed, since the download logic is patched out
            # during testing anyway.
special_kwargs["download"] = False
patchers = self._patch_download_extract()
if patch_checks:
patchers.update(self._patch_checks())
with get_tmp_dir() as tmpdir:
args = self.dataset_args(tmpdir, complete_config)
info = self._inject_fake_data(tmpdir, complete_config) if inject_fake_data else None
with self._maybe_apply_patches(patchers), disable_console_output():
dataset = self.DATASET_CLASS(*args, **complete_config, **special_kwargs)
yield dataset, info
@classmethod
def setUpClass(cls):
cls._verify_required_public_class_attributes()
cls._populate_private_class_attributes()
cls._process_optional_public_class_attributes()
super().setUpClass()
@classmethod
def _verify_required_public_class_attributes(cls):
if cls.DATASET_CLASS is None:
raise UsageError(
"The class attribute 'DATASET_CLASS' needs to be overwritten. "
"It should contain the class of the dataset to be tested."
)
if cls.FEATURE_TYPES is None:
raise UsageError(
"The class attribute 'FEATURE_TYPES' needs to be overwritten. "
"It should contain a sequence of types that the dataset returns when accessed by index."
)
@classmethod
def _populate_private_class_attributes(cls):
defaults = []
for cls_ in cls.DATASET_CLASS.__mro__:
if cls_ is torchvision.datasets.VisionDataset:
break
argspec = inspect.getfullargspec(cls_.__init__)
if not argspec.defaults:
continue
defaults.append(
{
kwarg: default
for kwarg, default in zip(argspec.args[-len(argspec.defaults) :], argspec.defaults)
if not kwarg.startswith("_")
}
)
if not argspec.varkw:
break
kwarg_defaults = dict()
for config in reversed(defaults):
kwarg_defaults.update(config)
has_special_kwargs = set()
for name in cls._SPECIAL_KWARGS:
if name not in kwarg_defaults:
continue
del kwarg_defaults[name]
has_special_kwargs.add(name)
cls._KWARG_DEFAULTS = kwarg_defaults
cls._HAS_SPECIAL_KWARG = has_special_kwargs
@classmethod
def _process_optional_public_class_attributes(cls):
def check_config(config, name):
special_kwargs = tuple(f"'{name}'" for name in cls._SPECIAL_KWARGS if name in config)
if special_kwargs:
raise UsageError(
f"{name} contains a value for the parameter(s) {', '.join(special_kwargs)}. "
f"These are handled separately by the test case and should not be set here. "
f"If you need to test some custom behavior regarding these parameters, "
f"you need to write a custom test (*not* test case), e.g. test_custom_transform()."
)
if cls.DEFAULT_CONFIG is not None:
check_config(cls.DEFAULT_CONFIG, "DEFAULT_CONFIG")
if cls.ADDITIONAL_CONFIGS is not None:
for idx, config in enumerate(cls.ADDITIONAL_CONFIGS):
check_config(config, f"CONFIGS[{idx}]")
if cls.REQUIRED_PACKAGES:
missing_pkgs = []
for pkg in cls.REQUIRED_PACKAGES:
try:
importlib.import_module(pkg)
except ImportError:
missing_pkgs.append(f"'{pkg}'")
if missing_pkgs:
raise unittest.SkipTest(
f"The package(s) {', '.join(missing_pkgs)} are required to load the dataset "
f"'{cls.DATASET_CLASS.__name__}', but are not installed."
)
def _split_kwargs(self, kwargs):
special_kwargs = kwargs.copy()
other_kwargs = {key: special_kwargs.pop(key) for key in set(special_kwargs.keys()) - self._SPECIAL_KWARGS}
return special_kwargs, other_kwargs
def _inject_fake_data(self, tmpdir, config):
info = self.inject_fake_data(tmpdir, config)
if info is None:
raise UsageError(
"The method 'inject_fake_data' needs to return at least an integer indicating the number of "
"examples for the current configuration."
)
elif isinstance(info, int):
info = dict(num_examples=info)
elif not isinstance(info, dict):
raise UsageError(
f"The additional information returned by the method 'inject_fake_data' must be either an "
f"integer indicating the number of examples for the current configuration or a dictionary with "
f"the same content. Got {type(info)} instead."
)
elif "num_examples" not in info:
raise UsageError(
"The information dictionary returned by the method 'inject_fake_data' must contain a "
"'num_examples' field that holds the number of examples for the current configuration."
)
return info
def _patch_download_extract(self):
module = inspect.getmodule(self.DATASET_CLASS).__name__
return {unittest.mock.patch(f"{module}.{function}") for function in self._DOWNLOAD_EXTRACT_FUNCTIONS}
def _patch_checks(self):
module = inspect.getmodule(self.DATASET_CLASS).__name__
return {unittest.mock.patch(f"{module}.{function}", return_value=True) for function in self._CHECK_FUNCTIONS}
@contextlib.contextmanager
def _maybe_apply_patches(self, patchers):
with contextlib.ExitStack() as stack:
mocks = {}
for patcher in patchers:
with contextlib.suppress(AttributeError):
mocks[patcher.target] = stack.enter_context(patcher)
yield mocks
def test_not_found_or_corrupted(self):
with pytest.raises((FileNotFoundError, RuntimeError)):
with self.create_dataset(inject_fake_data=False):
pass
def test_smoke(self):
with self.create_dataset() as (dataset, _):
assert isinstance(dataset, torchvision.datasets.VisionDataset)
@test_all_configs
def test_str_smoke(self, config):
with self.create_dataset(config) as (dataset, _):
assert isinstance(str(dataset), str)
@test_all_configs
def test_feature_types(self, config):
with self.create_dataset(config) as (dataset, _):
example = dataset[0]
if len(self.FEATURE_TYPES) > 1:
actual = len(example)
expected = len(self.FEATURE_TYPES)
                assert actual == expected, (
                    "The number of returned features does not match the number of elements in FEATURE_TYPES: "
                    f"{actual} != {expected}"
                )
else:
example = (example,)
for idx, (feature, expected_feature_type) in enumerate(zip(example, self.FEATURE_TYPES)):
with self.subTest(idx=idx):
assert isinstance(feature, expected_feature_type)
@test_all_configs
def test_num_examples(self, config):
with self.create_dataset(config) as (dataset, info):
assert len(list(dataset)) == len(dataset) == info["num_examples"]
@test_all_configs
def test_transforms(self, config):
mock = unittest.mock.Mock(wraps=lambda *args: args[0] if len(args) == 1 else args)
for kwarg in self._TRANSFORM_KWARGS:
if kwarg not in self._HAS_SPECIAL_KWARG:
continue
mock.reset_mock()
with self.subTest(kwarg=kwarg):
with self.create_dataset(config, **{kwarg: mock}) as (dataset, _):
dataset[0]
mock.assert_called()
@test_all_configs
def test_transforms_v2_wrapper(self, config):
try:
with self.create_dataset(config) as (dataset, info):
for target_keys in [None, "all"]:
if target_keys is not None and self.DATASET_CLASS not in {
torchvision.datasets.CocoDetection,
torchvision.datasets.VOCDetection,
torchvision.datasets.Kitti,
torchvision.datasets.WIDERFace,
}:
with self.assertRaisesRegex(ValueError, "`target_keys` is currently only supported for"):
wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
continue
wrapped_dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
assert isinstance(wrapped_dataset, self.DATASET_CLASS)
assert len(wrapped_dataset) == info["num_examples"]
wrapped_sample = wrapped_dataset[0]
assert tree_any(
lambda item: isinstance(item, (tv_tensors.TVTensor, PIL.Image.Image)), wrapped_sample
)
except TypeError as error:
msg = f"No wrapper exists for dataset class {type(dataset).__name__}"
if str(error).startswith(msg):
pytest.skip(msg)
raise error
except RuntimeError as error:
if "currently not supported by this wrapper" in str(error):
pytest.skip("Config is currently not supported by this wrapper")
raise error
class ImageDatasetTestCase(DatasetTestCase):
"""Abstract base class for image dataset testcases.
- Overwrites the FEATURE_TYPES class attribute to expect a :class:`PIL.Image.Image` and an integer label.
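
    Since ``FEATURE_TYPES`` is predefined, a subclass only needs to set ``DATASET_CLASS`` and inject the fake
    data. A sketch, with ``MyImageDataset`` as a hypothetical stand-in:

    .. code-block::

        class MyImageDatasetTestCase(ImageDatasetTestCase):
            DATASET_CLASS = MyImageDataset  # hypothetical

            def inject_fake_data(self, tmpdir, config):
                return len(create_image_folder(tmpdir, "images", lambda idx: f"{idx}.png", num_examples=3))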
"""
FEATURE_TYPES = (PIL.Image.Image, int)
@contextlib.contextmanager
def create_dataset(
self,
config: Optional[Dict[str, Any]] = None,
inject_fake_data: bool = True,
patch_checks: Optional[bool] = None,
**kwargs: Any,
) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]:
with super().create_dataset(
config=config,
inject_fake_data=inject_fake_data,
patch_checks=patch_checks,
**kwargs,
) as (dataset, info):
            # PIL.Image.open() only loads the image metadata upfront and keeps the file open until the pixel data
            # is accessed for the first time. Trying to delete such a file results in a PermissionError on Windows.
            # Thus, we force-load opened images.
            # This problem only occurs during testing, since some tests, e.g. DatasetTestCase.test_feature_types,
            # open an image but never use the underlying data. During normal operation it is reasonable to assume
            # that users want to work with the image they just opened rather than delete the underlying file.
with self._force_load_images():
yield dataset, info
@contextlib.contextmanager
def _force_load_images(self):
open = PIL.Image.open
def new(fp, *args, **kwargs):
image = open(fp, *args, **kwargs)
if isinstance(fp, (str, pathlib.Path)):
image.load()
return image
with unittest.mock.patch("PIL.Image.open", new=new):
yield
class VideoDatasetTestCase(DatasetTestCase):
"""Abstract base class for video dataset testcases.
- Overwrites the 'FEATURE_TYPES' class attribute to expect two :class:`torch.Tensor` s for the video and audio as
well as an integer label.
- Overwrites the 'REQUIRED_PACKAGES' class attribute to require PyAV (``av``).
    - Adds the 'FRAMES_PER_CLIP' class attribute. If no 'frames_per_clip' is provided by 'dataset_args()' and
      it is the last parameter without a default value in the dataset constructor, the value of the
      'FRAMES_PER_CLIP' class attribute is appended to the output.
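
    For example, for a hypothetical constructor ``MyVideoDataset(root, frames_per_clip)``, the default
    ``dataset_args()`` returning only ``(tmpdir,)`` is sufficient, since ``FRAMES_PER_CLIP`` is appended
    automatically:

    .. code-block::

        class MyVideoDatasetTestCase(VideoDatasetTestCase):
            DATASET_CLASS = MyVideoDataset  # hypothetical

            def inject_fake_data(self, tmpdir, config):
                return len(create_video_folder(tmpdir, "videos", lambda idx: f"{idx}.avi", num_examples=2))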
"""
FEATURE_TYPES = (torch.Tensor, torch.Tensor, int)
REQUIRED_PACKAGES = ("av",)
FRAMES_PER_CLIP = 1
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.dataset_args = self._set_default_frames_per_clip(self.dataset_args)
def _set_default_frames_per_clip(self, dataset_args):
argspec = inspect.getfullargspec(self.DATASET_CLASS.__init__)
args_without_default = argspec.args[1 : (-len(argspec.defaults) if argspec.defaults else None)]
frames_per_clip_last = args_without_default[-1] == "frames_per_clip"
@functools.wraps(dataset_args)
def wrapper(tmpdir, config):
args = dataset_args(tmpdir, config)
if frames_per_clip_last and len(args) == len(args_without_default) - 1:
args = (*args, self.FRAMES_PER_CLIP)
return args
return wrapper
def test_output_format(self):
for output_format in ["TCHW", "THWC"]:
with self.create_dataset(output_format=output_format) as (dataset, _):
for video, *_ in dataset:
if output_format == "TCHW":
num_frames, num_channels, *_ = video.shape
else: # output_format == "THWC":
num_frames, *_, num_channels = video.shape
assert num_frames == self.FRAMES_PER_CLIP
assert num_channels == 3
@test_all_configs
def test_transforms_v2_wrapper(self, config):
# `output_format == "THWC"` is not supported by the wrapper. Thus, we skip the `config` if it is set explicitly
# or use the supported `"TCHW"`
if config.setdefault("output_format", "TCHW") == "THWC":
return
super().test_transforms_v2_wrapper.__wrapped__(self, config)
def _no_collate(batch):
return batch
def check_transforms_v2_wrapper_spawn(dataset, expected_size):
# This check ensures that the wrapped datasets can be used with multiprocessing_context="spawn" in the DataLoader.
# We also check that transforms are applied correctly as a non-regression test for
# https://github.com/pytorch/vision/issues/8066
# Implicitly, this also checks that the wrapped datasets are pickleable.
# To save CI/test time, we only check on Windows where "spawn" is the default
if platform.system() != "Windows":
pytest.skip("Multiprocessing spawning is only checked on macOS.")
wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)
dataloader = DataLoader(wrapped_dataset, num_workers=2, multiprocessing_context="spawn", collate_fn=_no_collate)
def resize_was_applied(item):
# Checking the size of the output ensures that the Resize transform was correctly applied
return isinstance(item, (tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)) and get_size(item) == list(
expected_size
)
for wrapped_sample in dataloader:
assert tree_any(resize_was_applied, wrapped_sample)
def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor:
r"""Create a random uint8 tensor.
Args:
size (Sequence[int]): Size of the tensor.
"""
return torch.randint(0, 256, size, dtype=torch.uint8)
def create_image_file(
root: Union[pathlib.Path, str], name: Union[pathlib.Path, str], size: Union[Sequence[int], int] = 10, **kwargs: Any
) -> pathlib.Path:
"""Create an image file from random data.
Args:
root (Union[str, pathlib.Path]): Root directory the image file will be placed in.
name (Union[str, pathlib.Path]): Name of the image file.
        size (Union[Sequence[int], int]): Size of the image as ``(num_channels, height, width)``. If scalar, the
            value is used for both height and width. If only two values are given, three channels are assumed.
kwargs (Any): Additional parameters passed to :meth:`PIL.Image.Image.save`.
Returns:
pathlib.Path: Path to the created image file.
"""
if isinstance(size, int):
size = (size, size)
if len(size) == 2:
size = (3, *size)
if len(size) != 3:
raise UsageError(
f"The 'size' argument should either be an int or a sequence of length 2 or 3. Got {len(size)} instead"
)
image = create_image_or_video_tensor(size)
file = pathlib.Path(root) / name
    # torch (num_channels x height x width) -> numpy/PIL (height x width x num_channels);
    # PIL.Image.fromarray() expects the rows (height) to come first
    image = image.permute(1, 2, 0)
# For grayscale images PIL doesn't use a channel dimension
if image.shape[2] == 1:
image = torch.squeeze(image, 2)
PIL.Image.fromarray(image.numpy()).save(file, **kwargs)
return file
def create_image_folder(
root: Union[pathlib.Path, str],
name: Union[pathlib.Path, str],
file_name_fn: Callable[[int], str],
num_examples: int,
size: Optional[Union[Sequence[int], int, Callable[[int], Union[Sequence[int], int]]]] = None,
**kwargs: Any,
) -> List[pathlib.Path]:
"""Create a folder of random images.
Args:
root (Union[str, pathlib.Path]): Root directory the image folder will be placed in.
name (Union[str, pathlib.Path]): Name of the image folder.
file_name_fn (Callable[[int], str]): Should return a file name if called with the file index.
num_examples (int): Number of images to create.
size (Optional[Union[Sequence[int], int, Callable[[int], Union[Sequence[int], int]]]]): Size of the images. If
callable, will be called with the index of the corresponding file. If omitted, a random height and width
between 3 and 10 pixels is selected on a per-image basis.
kwargs (Any): Additional parameters passed to :func:`create_image_file`.
Returns:
List[pathlib.Path]: Paths to all created image files.
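
    Example (a sketch; folder and file names are arbitrary):

    .. code-block::

        # creates <root>/cats/0.png ... <root>/cats/2.png, each with 3 channels and 10 x 10 pixels
        create_image_folder(root, "cats", lambda idx: f"{idx}.png", num_examples=3, size=(3, 10, 10))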
.. seealso::
- :func:`create_image_file`
"""
if size is None:
def size(idx: int) -> Tuple[int, int, int]:
num_channels = 3
height, width = torch.randint(3, 11, size=(2,), dtype=torch.int).tolist()
return (num_channels, height, width)
root = pathlib.Path(root) / name
os.makedirs(root, exist_ok=True)
return [
create_image_file(root, file_name_fn(idx), size=size(idx) if callable(size) else size, **kwargs)
for idx in range(num_examples)
]
def shape_test_for_stereo(
left: PIL.Image.Image,
right: PIL.Image.Image,
disparity: Optional[np.ndarray] = None,
valid_mask: Optional[np.ndarray] = None,
):
left_dims = get_dimensions(left)
right_dims = get_dimensions(right)
c, h, w = left_dims
# check that left and right are the same size
assert left_dims == right_dims
assert c == 3
# check that the disparity has the same spatial dimensions
# as the input
if disparity is not None:
assert disparity.ndim == 3
assert disparity.shape == (1, h, w)
if valid_mask is not None:
# check that valid mask is the same size as the disparity
_, dh, dw = disparity.shape
mh, mw = valid_mask.shape
assert dh == mh
assert dw == mw
@requires_lazy_imports("av")
def create_video_file(
root: Union[pathlib.Path, str],
name: Union[pathlib.Path, str],
size: Union[Sequence[int], int] = (1, 3, 10, 10),
fps: float = 25,
**kwargs: Any,
) -> pathlib.Path:
"""Create a video file from random data.
Args:
root (Union[str, pathlib.Path]): Root directory the video file will be placed in.
name (Union[str, pathlib.Path]): Name of the video file.
size (Union[Sequence[int], int]): Size of the video that represents the
``(num_frames, num_channels, height, width)``. If scalar, the value is used for the height and width.
If not provided, ``num_frames=1`` and ``num_channels=3`` are assumed.
fps (float): Frame rate in frames per second.
kwargs (Any): Additional parameters passed to :func:`torchvision.io.write_video`.
Returns:
        pathlib.Path: Path to the created video file.
Raises:
UsageError: If PyAV is not available.
"""
if isinstance(size, int):
size = (size, size)
if len(size) == 2:
size = (3, *size)
if len(size) == 3:
size = (1, *size)
if len(size) != 4:
raise UsageError(
f"The 'size' argument should either be an int or a sequence of length 2, 3, or 4. Got {len(size)} instead"
)
video = create_image_or_video_tensor(size)
file = pathlib.Path(root) / name
torchvision.io.write_video(str(file), video.permute(0, 2, 3, 1), fps, **kwargs)
return file
@requires_lazy_imports("av")
def create_video_folder(
root: Union[str, pathlib.Path],
name: Union[str, pathlib.Path],
file_name_fn: Callable[[int], str],
num_examples: int,
size: Optional[Union[Sequence[int], int, Callable[[int], Union[Sequence[int], int]]]] = None,
fps=25,
**kwargs,
) -> List[pathlib.Path]:
"""Create a folder of random videos.
Args:
root (Union[str, pathlib.Path]): Root directory the video folder will be placed in.
name (Union[str, pathlib.Path]): Name of the video folder.
file_name_fn (Callable[[int], str]): Should return a file name if called with the file index.
num_examples (int): Number of videos to create.
size (Optional[Union[Sequence[int], int, Callable[[int], Union[Sequence[int], int]]]]): Size of the videos. If
callable, will be called with the index of the corresponding file. If omitted, a random even height and
width between 4 and 10 pixels is selected on a per-video basis.
fps (float): Frame rate in frames per second.
kwargs (Any): Additional parameters passed to :func:`create_video_file`.
Returns:
List[pathlib.Path]: Paths to all created video files.
Raises:
UsageError: If PyAV is not available.
.. seealso::
- :func:`create_video_file`
"""
if size is None:
def size(idx):
num_frames = 1
num_channels = 3
# The 'libx264' video codec, which is the default of torchvision.io.write_video, requires the height and
# width of the video to be divisible by 2.
height, width = (torch.randint(2, 6, size=(2,), dtype=torch.int) * 2).tolist()
return (num_frames, num_channels, height, width)
root = pathlib.Path(root) / name
os.makedirs(root, exist_ok=True)
return [
        create_video_file(root, file_name_fn(idx), size=size(idx) if callable(size) else size, fps=fps, **kwargs)
for idx in range(num_examples)
]
def _split_files_or_dirs(root, *files_or_dirs):
files = set()
dirs = set()
for file_or_dir in files_or_dirs:
path = pathlib.Path(file_or_dir)
if not path.is_absolute():
path = root / path
if path.is_file():
files.add(path)
else:
dirs.add(path)
for sub_file_or_dir in path.glob("**/*"):
if sub_file_or_dir.is_file():
files.add(sub_file_or_dir)
else:
dirs.add(sub_file_or_dir)
if root in dirs:
dirs.remove(root)
return files, dirs
def _make_archive(root, name, *files_or_dirs, opener, adder, remove=True):
archive = pathlib.Path(root) / name
if not files_or_dirs:
# We need to invoke `Path.with_suffix("")`, since call only applies to the last suffix if multiple suffixes are
# present. For example, `pathlib.Path("foo.tar.gz").with_suffix("")` results in `foo.tar`.
file_or_dir = archive
for _ in range(len(archive.suffixes)):
file_or_dir = file_or_dir.with_suffix("")
if file_or_dir.exists():
files_or_dirs = (file_or_dir,)
else:
raise ValueError("No file or dir provided.")
files, dirs = _split_files_or_dirs(root, *files_or_dirs)
with opener(archive) as fh:
for file in sorted(files):
adder(fh, file, file.relative_to(root))
if remove:
for file in files:
os.remove(file)
for dir in dirs:
shutil.rmtree(dir, ignore_errors=True)
return archive
def make_tar(root, name, *files_or_dirs, remove=True, compression=None):
# TODO: detect compression from name
return _make_archive(
root,
name,
*files_or_dirs,
opener=lambda archive: tarfile.open(archive, f"w:{compression}" if compression else "w"),
adder=lambda fh, file, relative_file: fh.add(file, arcname=relative_file),
remove=remove,
)
def make_zip(root, name, *files_or_dirs, remove=True):
return _make_archive(
root,
name,
*files_or_dirs,
opener=lambda archive: zipfile.ZipFile(archive, "w"),
adder=lambda fh, file, relative_file: fh.write(file, arcname=relative_file),
remove=remove,
)
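# A usage sketch for the archive helpers above (names are placeholders): create files under `root` first, then
# bundle them into an archive. Note that by default the source files are removed afterwards, mimicking the state
# of a dataset root right after the download:
#
#   create_image_file(root, "img.png")
#   make_zip(root, "data.zip", "img.png")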
def create_random_string(length: int, *digits: str) -> str:
"""Create a random string.
Args:
length (int): Number of characters in the generated string.
*digits (str): Characters to sample from. If omitted defaults to :attr:`string.ascii_lowercase`.
"""
if not digits:
digits = string.ascii_lowercase
else:
digits = "".join(itertools.chain(*digits))
return "".join(random.choice(digits) for _ in range(length))
def make_fake_pfm_file(h, w, file_name):
values = list(range(3 * h * w))
    # Note: everything is packed in little endian; the -1.0 scale in the header signals little-endian data,
    # matching the "<" prefix passed to struct.pack.
content = f"PF \n{w} {h} \n-1.0\n".encode() + struct.pack("<" + "f" * len(values), *values)
with open(file_name, "wb") as f:
f.write(content)
def make_fake_flo_file(h, w, file_name):
"""Creates a fake flow file in .flo format."""
# Everything needs to be in little Endian according to
# https://vision.middlebury.edu/flow/code/flow-code/README.txt
values = list(range(2 * h * w))
content = (
struct.pack("<4c", *(c.encode() for c in "PIEH"))
+ struct.pack("<i", w)
+ struct.pack("<i", h)
+ struct.pack("<" + "f" * len(values), *values)
)
with open(file_name, "wb") as f:
f.write(content)
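# A sketch of reading such a .flo file back, e.g. to verify the layout (everything is little-endian):
#
#   with open(file_name, "rb") as f:
#       magic = f.read(4)  # b"PIEH"
#       w, h = struct.unpack("<2i", f.read(8))
#       flow = np.frombuffer(f.read(), dtype="<f4").reshape(h, w, 2)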