File: utils.py

package info (click to toggle)
huggingface-hub 0.31.1-2
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 5,092 kB
  • sloc: python: 40,321; makefile: 54
file content (56 lines) | stat: -rw-r--r-- 1,615 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import contextlib
from typing import Generator
from unittest.mock import patch


@contextlib.contextmanager
def production_endpoint() -> Generator[None, None, None]:
    """Patch huggingface_hub to connect to production server in a context manager.

    Ugly way to patch all constants at once: the module-level ``ENDPOINT`` and
    ``HUGGINGFACE_CO_URL_TEMPLATE`` attributes of the modules listed below, plus
    the ``endpoint`` attribute of the ``huggingface_hub.hf_api.api`` instance,
    are replaced for the duration of the ``with`` block and restored afterwards.
    Every patch that was started is undone on exit, including when the body of
    the ``with`` block raises or when one of the patches fails to start.

    TODO: refactor when https://github.com/huggingface/huggingface_hub/issues/1172 is fixed.

    Example:
    ```py
    def test_push_to_hub():
        # Pull from production Hub
        with production_endpoint():
            model = ...from_pretrained("modelname")

        # Push to staging Hub
        model.push_to_hub()
    ```
    """
    PROD_ENDPOINT = "https://huggingface.co"
    # Each of these modules exposes its own module-level `ENDPOINT` attribute,
    # so each one is patched individually (presumably patching only
    # `constants` would not reach copies bound at import time elsewhere).
    ENDPOINT_TARGETS = [
        "huggingface_hub.constants",
        "huggingface_hub._commit_api",
        "huggingface_hub.hf_api",
        "huggingface_hub.lfs",
        "huggingface_hub.commands.user",
        "huggingface_hub.utils._git_credential",
    ]

    PROD_URL_TEMPLATE = PROD_ENDPOINT + "/{repo_id}/resolve/{revision}/{filename}"
    URL_TEMPLATE_TARGETS = [
        "huggingface_hub.constants",
        "huggingface_hub.file_download",
    ]

    # Imported at call time rather than at module level, so merely importing
    # this helper module does not import `huggingface_hub.hf_api`.
    from huggingface_hub.hf_api import api

    patchers = (
        [patch(target + ".ENDPOINT", PROD_ENDPOINT) for target in ENDPOINT_TARGETS]
        + [patch(target + ".HUGGINGFACE_CO_URL_TEMPLATE", PROD_URL_TEMPLATE) for target in URL_TEMPLATE_TARGETS]
        # `HfApi.endpoint` is the Hub base URL (same role as `ENDPOINT`), not the
        # file-download URL template; API routes are built by appending to it.
        + [patch.object(api, "endpoint", PROD_ENDPOINT)]
    )

    # ExitStack stops every patcher that was successfully started, in reverse
    # order, whether the `with` body returns normally, raises, or a later
    # patcher fails to start. This prevents patches leaking past the context.
    with contextlib.ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        yield