File: testing_utils.py

import inspect
import os
import shutil
import stat
import time
import unittest
import uuid
from contextlib import contextmanager
from enum import Enum
from functools import wraps
from pathlib import Path
from typing import Callable, Optional, TypeVar, Union
from unittest.mock import Mock, patch

import httpx
import pytest

from huggingface_hub.utils import is_package_available, logging
from tests.testing_constants import ENDPOINT_PRODUCTION, ENDPOINT_PRODUCTION_URL_SCHEME


logger = logging.get_logger(__name__)

# Example model ids
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer"

# An actual model hosted on huggingface.co, with more details.
DUMMY_MODEL_ID = "julien-c/dummy-unknown"
DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT = "f2c752cfc5c0ab6f4bdec59acea69eefbee381c2"
# One particular commit (not the top of `main`)
DUMMY_MODEL_ID_REVISION_INVALID = "aaaaaaa"
# This commit does not exist, so we should 404.
DUMMY_MODEL_ID_PINNED_SHA1 = "d9e9f15bc825e4b2c9249e9578f884bbcb5e3684"
# Sha-1 of config.json on the top of `main`, for checking purposes
DUMMY_MODEL_ID_PINNED_SHA256 = "4b243c475af8d0a7754e87d7d096c92e5199ec2fe168a2ee7998e3b8e9bcb1d3"
# Sha-256 of pytorch_model.bin on the top of `main`, for checking purposes

# "hf-internal-testing/dummy-will-be-renamed" has been renamed to "hf-internal-testing/dummy-renamed"
DUMMY_RENAMED_OLD_MODEL_ID = "hf-internal-testing/dummy-will-be-renamed"
DUMMY_RENAMED_NEW_MODEL_ID = "hf-internal-testing/dummy-renamed"

# Example dataset ids
SAMPLE_DATASET_IDENTIFIER = "lhoestq/custom_squad"
DUMMY_DATASET_ID = "gaia-benchmark/GAIA"
DUMMY_DATASET_ID_REVISION_ONE_SPECIFIC_COMMIT = "c603981e170e9e333934a39781d2ae3a2677e81f"  # on branch "test-branch"

YES = ("y", "yes", "t", "true", "on", "1")
NO = ("n", "no", "f", "false", "off", "0")


# Xet testing
DUMMY_XET_MODEL_ID = "celinah/dummy-xet-testing"
DUMMY_XET_FILE = "dummy.safetensors"
DUMMY_XET_REGULAR_FILE = "dummy.txt"

# extra large file for testing on production
DUMMY_EXTRA_LARGE_FILE_MODEL_ID = "brianronan/dummy-xet-edge-case-files"
DUMMY_EXTRA_LARGE_FILE_NAME = "verylargemodel.safetensors"  # > 50GB file
DUMMY_TINY_FILE_NAME = "tiny.safetensors"  # 45 byte file


def repo_name(id: Optional[str] = None, prefix: str = "repo") -> str:
    """
    Return a readable pseudo-unique repository name for tests.

    Example:
    ```py
    >>> repo_name()
    repo-2fe93f-16599646671840

    >>> repo_name("my-space", prefix='space')
    space-my-space-16599481979701
    """
    if id is None:
        id = uuid.uuid4().hex[:6]
    ts = int(time.time() * 10e3)
    return f"{prefix}-{id}-{ts}"


def parse_flag_from_env(key: str, default: bool = False) -> bool:
    try:
        value = os.environ[key]
    except KeyError:
        # KEY isn't set, default to `default`.
        return default

    # KEY is set, convert it to True or False.
    if value.lower() in YES:
        return True
    elif value.lower() in NO:
        return False
    else:
        # More values are supported, but let's keep the message simple.
        raise ValueError(f"If set, '{key}' must be one of {YES + NO}. Got '{value}'.")


def parse_int_from_env(key: str, default: Optional[int] = None) -> Optional[int]:
    try:
        value = os.environ[key]
    except KeyError:
        _value = default
    else:
        try:
            _value = int(value)
        except ValueError:
            raise ValueError(f"If set, '{key}' must be an int. Got '{value}'.")
    return _value
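
# Usage sketch (hypothetical env variable name, for illustration only):
#
#     N_RETRIES = parse_int_from_env("HF_TEST_N_RETRIES", default=3)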


_run_git_lfs_tests = parse_flag_from_env("RUN_GIT_LFS_TESTS", default=False)


def require_git_lfs(test_case):
    """
    Decorator to mark tests that require git-lfs.

    git-lfs requires additional dependencies, and tests are skipped by default. Set the RUN_GIT_LFS_TESTS environment
    variable to a truthy value to run them.
    """
    if not _run_git_lfs_tests:
        return unittest.skip("test of git lfs workflow")(test_case)
    else:
        return test_case
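
# Usage sketch (hypothetical test, not part of this module): the test only runs
# when RUN_GIT_LFS_TESTS is set to a truthy value.
#
#     @require_git_lfs
#     def test_push_large_file_via_git():
#         ...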


def requires(package_name: str):
    """
    Decorator marking a test that requires the given package.
    These tests are skipped when the package isn't installed.
    """

    def _inner(test_case):
        if not is_package_available(package_name):
            return pytest.mark.skip(f"Test requires '{package_name}'")(test_case)
        else:
            return test_case

    return _inner
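
# Usage sketch (hypothetical test, not part of this module): skip a test unless
# an optional dependency is installed.
#
#     @requires("torch")
#     def test_save_pretrained_with_torch():
#         import torch  # safe: only runs when torch is available
#         ...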


class RequestWouldHangIndefinitelyError(Exception):
    pass


class OfflineSimulationMode(Enum):
    CONNECTION_FAILS = 0
    CONNECTION_TIMES_OUT = 1
    HF_HUB_OFFLINE_SET_TO_1 = 2


@contextmanager
def offline(mode=OfflineSimulationMode.CONNECTION_FAILS, timeout=1e-16):
    """
    Simulate offline mode.

    There are three offline simulation modes:

    CONNECTION_FAILS (default mode): a ConnectionError is raised for each network call.
        Connection errors are created by mocking `socket.socket` and the shared HTTP client.
    CONNECTION_TIMES_OUT: the connection hangs until it times out.
        The default timeout value is low (1e-16) to speed up the tests.
        Timeout errors are created by mocking `httpx.request` and the shared HTTP client.
    HF_HUB_OFFLINE_SET_TO_1: the `HF_HUB_OFFLINE` constant is patched to `True`, which
        makes the library's HTTP calls fail instantly with an `OfflineModeEnabled` error.
    """
    import socket

    # Store the original httpx.request to avoid recursion
    original_httpx_request = httpx.request

    def timeout_request(method, url, **kwargs):
        # Change the url to an invalid url so that the connection hangs
        invalid_url = "https://10.255.255.1"
        if kwargs.get("timeout") is None:
            raise RequestWouldHangIndefinitelyError(
                f"Tried a call to {url} in offline mode with no timeout set. Please set a timeout."
            )
        kwargs["timeout"] = timeout
        try:
            return original_httpx_request(method, invalid_url, **kwargs)
        except Exception as e:
            # The following changes in the error are just here to make the offline timeout error prettier
            if hasattr(e, "request"):
                e.request.url = url
            if hasattr(e, "args") and e.args:
                max_retry_error = e.args[0]
                if hasattr(max_retry_error, "args"):
                    max_retry_error.args = (max_retry_error.args[0].replace("10.255.255.1", f"OfflineMock[{url}]"),)
                e.args = (max_retry_error,)
            raise

    def offline_socket(*args, **kwargs):
        raise socket.error("Offline mode is enabled.")

    if mode is OfflineSimulationMode.CONNECTION_FAILS:
        # inspired from https://stackoverflow.com/a/18601897
        with patch("socket.socket", offline_socket):
            with patch("huggingface_hub.utils._http.get_session") as get_session_mock:
                mock_client = Mock()

                # Mock the request method to raise connection error
                def mock_request(*args, **kwargs):
                    raise httpx.ConnectError("Connection failed")

                # Mock the stream method to raise connection error
                def mock_stream(*args, **kwargs):
                    raise httpx.ConnectError("Connection failed")

                mock_client.request = mock_request
                mock_client.stream = mock_stream
                get_session_mock.return_value = mock_client
                yield
    elif mode is OfflineSimulationMode.CONNECTION_TIMES_OUT:
        # inspired from https://stackoverflow.com/a/904609
        with patch("httpx.request", timeout_request):
            with patch("huggingface_hub.utils._http.get_session") as get_session_mock:
                mock_client = Mock()
                mock_client.request = timeout_request

                # Mock the stream method to raise timeout
                def mock_stream(*args, **kwargs):
                    raise httpx.ConnectTimeout("Connection timed out")

                mock_client.stream = mock_stream
                get_session_mock.return_value = mock_client
                yield
    elif mode is OfflineSimulationMode.HF_HUB_OFFLINE_SET_TO_1:
        with patch("huggingface_hub.constants.HF_HUB_OFFLINE", True):
            yield
    else:
        raise ValueError("Please use a value from the OfflineSimulationMode enum.")
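
# Usage sketch (hypothetical test, not part of this module; assumes
# `from huggingface_hub.utils import get_session`): inside the block, network
# calls made through the library's shared client fail according to the chosen mode.
#
#     def test_hub_calls_fail_offline():
#         with offline(OfflineSimulationMode.CONNECTION_FAILS):
#             with pytest.raises(httpx.ConnectError):
#                 get_session().request("GET", "https://huggingface.co")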


def set_write_permission_and_retry(func, path, excinfo):
    os.chmod(path, stat.S_IWRITE)
    func(path)


def rmtree_with_retry(path: Union[str, Path]) -> None:
    shutil.rmtree(path, onerror=set_write_permission_and_retry)
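
# NOTE: on Windows, files inside `.git/` are often created read-only, so a plain
# `shutil.rmtree` fails with a PermissionError. The `onerror` hook above makes the
# offending path writable and retries the failed operation. Usage sketch
# (hypothetical path):
#
#     rmtree_with_retry(tmp_path / "cloned-repo")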


def with_production_testing(func):
    file_download = patch("huggingface_hub.constants.HUGGINGFACE_CO_URL_TEMPLATE", ENDPOINT_PRODUCTION_URL_SCHEME)
    hf_api = patch("huggingface_hub.constants.ENDPOINT", ENDPOINT_PRODUCTION)
    return hf_api(file_download(func))
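
# Usage sketch (hypothetical test class, not part of this module; assumes
# `from huggingface_hub import hf_hub_download`): decorating a test or test class
# repoints the library from the staging endpoint to production for its duration.
#
#     @with_production_testing
#     class ProductionDownloadTest(unittest.TestCase):
#         def test_config_is_downloadable(self):
#             hf_hub_download(DUMMY_MODEL_ID, "config.json")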


def expect_deprecation(function_name: str):
    """
    Decorator to flag tests that we expect to use deprecated arguments.

    Args:
        function_name (`str`):
            Name of the function that we expect to use in a deprecated way.

    NOTE: if a test is expected to raise a FutureWarning but does not, the test will fail.

    Context: over time, some arguments/methods become deprecated. To track
             deprecations in tests, we run pytest with the flag `-Werror::FutureWarning`.
             To keep old tests running during the deprecation phase (before the
             feature is removed completely) without changing them internally, we
             flag them with this decorator.
    See full discussion in https://github.com/huggingface/huggingface_hub/pull/952.

    This decorator works hand-in-hand with the `_deprecate_arguments` and
    `_deprecate_positional_args` decorators.

    Example
    ```py
    # in src/hub_mixins.py
    from .utils._deprecation import _deprecate_arguments

    @_deprecate_arguments(version="0.12", deprecated_args={"repo_url"})
    def push_to_hub(...):
        (...)

    # in tests/test_something.py
    from .testing_utils import expect_deprecation

    class SomethingTest(unittest.TestCase):
        (...)

        @expect_deprecation("push_to_hub"):
        def test_push_to_hub_git_version(self):
            (...)
            push_to_hub(repo_url="something") <- Should warn with FutureWarnings
            (...)
    ```
    """

    def _inner_decorator(test_function: Callable) -> Callable:
        @wraps(test_function)
        def _inner_test_function(*args, **kwargs):
            with pytest.warns(FutureWarning, match=f".*'{function_name}'.*"):
                return test_function(*args, **kwargs)

        return _inner_test_function

    return _inner_decorator


def skip_on_windows(reason: str):
    """
    Decorator to flag tests that we want to skip on Windows.

    Args:
        reason (`str`):
            Reason to skip it.
    """

    def _inner_decorator(test_function: Callable) -> Callable:
        return pytest.mark.skipif(os.name == "nt", reason=reason)(test_function)

    return _inner_decorator
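
# Usage sketch (hypothetical test, not part of this module):
#
#     @skip_on_windows("Relies on POSIX file permissions")
#     def test_chmod_is_applied():
#         ...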


T = TypeVar("T")


def handle_injection(cls: T) -> T:
    """Handle mock injection for each test of a test class.

    When patching variables on a class level, only relevant mocks will be injected to
    the tests. This has 2 advantages:
    1. There is no need to expect all mocks in test arguments when they are not needed.
    2. Default mock injection appends all mocks one by one to the test args. If the
       order of the patch calls or test arguments is changed, it can lead to
       unexpected behavior.

    NOTE: `@handle_injection` has to be defined after the `@patch` calls.

    Example:
    ```py
    @patch("something.foo")
    @patch("something_else.foo.bar") # order doesn't matter
    @handle_injection # after @patch calls
    class TestHelloWorld(unittest.TestCase):

        def test_hello_foo(self, mock_foo: Mock) -> None:
            (...)

        def test_hello_bar(self, mock_bar: Mock) -> None:
            (...)

        def test_hello_both(self, mock_foo: Mock, mock_bar: Mock) -> None:
            (...)
    ```

    There are limitations with the current implementation:
    1. All patched variables must have different names.
       Named injection will not work with both `@patch("something.foo")` and
       `@patch("something_else.foo")` patches.
    2. Tests are expected to take only `self` and mock arguments. If it's not the case,
       this helper will fail.
    3. Test arguments must follow the `mock_{variable_name}` naming.
       Example: `@patch("something._foo")` -> `"mock__foo"`.
    4. Test arguments must be typed as `Mock`.

    If required, we can improve the current implementation in the future to mitigate
    those limitations.

    Based on:
    - https://stackoverflow.com/a/3467879
    - https://stackoverflow.com/a/30764825
    - https://stackoverflow.com/a/57115876

    NOTE: this decorator is inspired from the fixture system from pytest.
    """
    # Iterate over class functions and decorate tests
    # Taken from https://stackoverflow.com/a/3467879
    #        and https://stackoverflow.com/a/30764825
    for name, fn in inspect.getmembers(cls):
        if name.startswith("test_"):
            setattr(cls, name, handle_injection_in_test(fn))

    # Return decorated class
    return cls


def handle_injection_in_test(fn: Callable) -> Callable:
    """
    Handle injections at a test level. See `handle_injection` for more details.

    Example:
    ```py
    class TestHelloWorld(unittest.TestCase):

        @patch("something.foo")
        @patch("something_else.foo.bar") # order doesn't matter
        @handle_injection_in_test # after @patch calls
        def test_hello_foo(self, mock_foo: Mock) -> None:
            (...)
    ```
    """
    signature = inspect.signature(fn)
    parameters = signature.parameters

    @wraps(fn)
    def _inner(*args, **kwargs):
        assert kwargs == {}

        # Initialize new dict at least with `self`.
        assert len(args) > 0
        assert len(parameters) > 0
        new_kwargs = {"self": args[0]}

        # Check which mocks have been injected
        mocks = {}
        for value in args[1:]:
            assert isinstance(value, Mock)
            mock_name = "mock_" + value._extract_mock_name()
            mocks[mock_name] = value

        # Check which mocks are expected
        for name, parameter in parameters.items():
            if name == "self":
                continue
            assert parameter.annotation is Mock
            assert name in mocks, (
                f"Mock `{name}` not found for test `{fn.__name__}`. Available: {', '.join(sorted(mocks.keys()))}"
            )
            new_kwargs[name] = mocks[name]

        # Run test only with a subset of mocks
        return fn(**new_kwargs)

    return _inner


def use_tmp_repo(repo_type: str = "model") -> Callable[[T], T]:
    """
    Test decorator to create a repo for the test and properly delete it afterward.

    TODO: could we make `_api`, `_user` and `_token` cleaner?

    Example:
    ```py
    from huggingface_hub import HfApi, RepoUrl
    from .testing_constants import ENDPOINT_STAGING, TOKEN
    from .testing_utils import use_tmp_repo

    class HfApiCommonTest(unittest.TestCase):
        _api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN)

        @use_tmp_repo()
        def test_create_tag_on_model(self, repo_url: RepoUrl) -> None:
            (...)

        @use_tmp_repo("dataset")
        def test_create_tag_on_dataset(self, repo_url: RepoUrl) -> None:
            (...)
    ```
    """

    def _inner_use_tmp_repo(test_fn: T) -> T:
        @wraps(test_fn)
        def _inner(*args, **kwargs):
            self = args[0]
            assert isinstance(self, unittest.TestCase)
            create_repo_kwargs = {}
            if repo_type == "space":
                create_repo_kwargs["space_sdk"] = "gradio"

            repo_url = self._api.create_repo(
                repo_id=repo_name(prefix=repo_type), repo_type=repo_type, **create_repo_kwargs
            )
            try:
                return test_fn(*args, **kwargs, repo_url=repo_url)
            finally:
                self._api.delete_repo(repo_id=repo_url.repo_id, repo_type=repo_type)

        return _inner

    return _inner_use_tmp_repo


def assert_in_logs(caplog: pytest.LogCaptureFixture, expected_output: str) -> None:
    """Helper to check if a message appears in logs."""
    log_text = "\n".join(record.message for record in caplog.records)
    assert expected_output in log_text, f"Expected '{expected_output}' not found in logs"
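

# Usage sketch (hypothetical test using pytest's built-in `caplog` fixture):
#
#     def test_emits_deprecation_log(caplog):
#         with caplog.at_level("WARNING"):
#             logger.warning("'foo' is deprecated")
#         assert_in_logs(caplog, "deprecated")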