File: test_utils_validators.py

package info (click to toggle)
huggingface-hub 1.2.2-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 5,424 kB
  • sloc: python: 45,857; sh: 434; makefile: 33
file content (59 lines) | stat: -rw-r--r-- 2,009 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import unittest
from pathlib import Path
from unittest.mock import Mock, patch

from huggingface_hub.utils import (
    HFValidationError,
    validate_hf_hub_args,
    validate_repo_id,
)


@patch("huggingface_hub.utils._validators.validate_repo_id")
class TestHfHubValidator(unittest.TestCase):
    """Check that `validate_hf_hub_args` dispatches `repo_id` to its default validator.

    The class-level `patch` swaps `validate_repo_id` inside
    `huggingface_hub.utils._validators` for a `Mock`. It hands that mock to every
    `test_*` method as its trailing positional argument.
    """

    @staticmethod
    @validate_hf_hub_args
    def dummy_function(repo_id: str) -> None:
        # Intentionally empty: only the decorator's validation side effect matters.
        pass

    def test_validate_repo_id_as_arg(self, validate_repo_id_mock: Mock) -> None:
        """A positional `repo_id` must be forwarded to `validate_repo_id`."""
        # Non-string sentinel: the validator is mocked, so only identity of the value matters.
        sentinel = 123
        self.dummy_function(sentinel)
        validate_repo_id_mock.assert_called_once_with(sentinel)

    def test_validate_repo_id_as_kwarg(self, validate_repo_id_mock: Mock) -> None:
        """A keyword `repo_id` must be forwarded to `validate_repo_id`."""
        sentinel = 123
        self.dummy_function(repo_id=sentinel)
        validate_repo_id_mock.assert_called_once_with(sentinel)


class TestRepoIdValidator(unittest.TestCase):
    """Test `validate_repo_id` against lists of known-valid and known-invalid repo ids."""

    VALID_VALUES = (
        "123",
        "foo",
        "foo/bar",
        "Foo-BAR_foo.bar123",
    )
    NOT_VALID_VALUES = (
        Path("foo/bar"),  # Must be a string
        "a" * 100,  # Too long
        "datasets/foo/bar",  # Repo_type forbidden in repo_id
        ".repo_id",  # Cannot start with .
        "repo_id.",  # Cannot end with .
        "foo--bar",  # Cannot contain "--"
        "foo..bar",  # Cannot contain ".." (a single "." is allowed, see VALID_VALUES)
        "foo.git",  # Cannot end with ".git"
    )

    def test_valid_repo_ids(self) -> None:
        """Test `repo_id` validation on valid values.

        Each value is checked in its own `subTest`. One failure therefore does not
        hide the remaining cases, and the offending value appears in the report.
        """
        for repo_id in self.VALID_VALUES:
            with self.subTest(repo_id=repo_id):
                validate_repo_id(repo_id)  # must not raise

    def test_not_valid_repo_ids(self) -> None:
        """Test `repo_id` validation on not valid values.

        Every value must raise `HFValidationError`. Each runs in its own `subTest`,
        so all invalid inputs are exercised even if an earlier one is wrongly accepted.
        """
        for repo_id in self.NOT_VALID_VALUES:
            with self.subTest(repo_id=repo_id):
                with self.assertRaises(HFValidationError, msg=f"'{repo_id}' must not be valid"):
                    validate_repo_id(repo_id)