1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56
|
import contextlib
from typing import Generator
from unittest.mock import patch
@contextlib.contextmanager
def production_endpoint() -> Generator[None, None, None]:
    """Patch huggingface_hub to connect to the production server in a context manager.

    Ugly way to patch all constants at once: the ``ENDPOINT`` attribute of every
    module in ``ENDPOINT_TARGETS``, the ``HUGGINGFACE_CO_URL_TEMPLATE`` attribute
    of every module in ``URL_TEMPLATE_TARGETS``, and the ``endpoint`` attribute of
    the ``huggingface_hub.hf_api.api`` object are all pointed at production.
    Every patch is reverted when the context exits, including when the body of
    the ``with`` statement raises.

    TODO: refactor when https://github.com/huggingface/huggingface_hub/issues/1172 is fixed.

    Example:
    ```py
    def test_push_to_hub():
        # Pull from production Hub
        with production_endpoint():
            model = ...from_pretrained("modelname")

        # Push to staging Hub
        model.push_to_hub()
    ```
    """
    PROD_ENDPOINT = "https://huggingface.co"
    # Each of these modules exposes its own ``ENDPOINT`` attribute, so each one
    # has to be patched separately.
    ENDPOINT_TARGETS = [
        "huggingface_hub.constants",
        "huggingface_hub._commit_api",
        "huggingface_hub.hf_api",
        "huggingface_hub.lfs",
        "huggingface_hub.commands.user",
        "huggingface_hub.utils._git_credential",
    ]

    PROD_URL_TEMPLATE = PROD_ENDPOINT + "/{repo_id}/resolve/{revision}/{filename}"
    URL_TEMPLATE_TARGETS = [
        "huggingface_hub.constants",
        "huggingface_hub.file_download",
    ]

    # Imported here rather than at module level, presumably so this helper's
    # module can be imported without pulling in huggingface_hub.
    from huggingface_hub.hf_api import api

    patchers = (
        [patch(target + ".ENDPOINT", PROD_ENDPOINT) for target in ENDPOINT_TARGETS]
        + [patch(target + ".HUGGINGFACE_CO_URL_TEMPLATE", PROD_URL_TEMPLATE) for target in URL_TEMPLATE_TARGETS]
        # Like the module-level ``ENDPOINT`` constants, ``api.endpoint`` gets the
        # bare endpoint, not the file-resolve URL template.
        + [patch.object(api, "endpoint", PROD_ENDPOINT)]
    )

    # The ExitStack guarantees cleanup: each patcher's ``stop`` is registered
    # only after its ``start`` succeeded, and all registered stops run (in
    # reverse order) when the block exits -- normally, when the ``with`` body
    # raises, or when starting a later patcher fails.
    with contextlib.ExitStack() as stack:
        for patcher in patchers:
            patcher.start()
            stack.callback(patcher.stop)
        yield
|