File: _download_hooks.py

package info (click to toggle)
pytorch-text 0.14.1-2
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 11,560 kB
  • sloc: python: 14,197; cpp: 2,404; sh: 214; makefile: 20
file content (64 lines) | stat: -rw-r--r-- 2,190 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import re

import requests

# This is to allow monkey-patching in fbcode
from torch.hub import load_state_dict_from_url  # noqa
from torchtext._internal.module_utils import is_module_available
from tqdm import tqdm

if is_module_available("torchdata"):
    from torchdata.datapipes.iter import HttpReader, GDriveReader  # noqa F401


def _stream_response(r, chunk_size=16 * 1024):
    """Yield the non-empty chunks of a streamed HTTP response body.

    A tqdm progress bar (byte units, auto-scaled) tracks how many bytes have
    been yielded. Its total comes from the ``Content-length`` header, or 0
    when that header is absent.

    Args:
        r: response object exposing ``headers`` and ``iter_content``
            (presumably a ``requests.Response`` opened with ``stream=True``).
        chunk_size: number of bytes requested per ``iter_content`` read.

    Yields:
        Each non-empty chunk of the body, in order.
    """
    expected_bytes = int(r.headers.get("Content-length", 0))
    progress = tqdm(total=expected_bytes, unit="B", unit_scale=1)
    with progress:
        # ``filter(None, ...)`` drops empty chunks.
        for piece in filter(None, r.iter_content(chunk_size)):
            progress.update(len(piece))
            yield piece


def _get_response_from_google_drive(url):
    """Open a streaming download of a Google Drive file.

    The first request is expected to set a cookie whose name starts with
    ``download_warning``. Its value is sent back as the ``confirm`` query
    parameter on a second request, which returns the file itself.

    Args:
        url: Google Drive download URL. The ``confirm`` parameter is merged
            into whatever query string it already has.

    Returns:
        ``(response, filename)``:
            - ``response`` is the still-open streaming response for the file
              body; the caller must consume and close it.
            - ``filename`` is parsed from its ``Content-Disposition`` header.

    Raises:
        RuntimeError: in any of these cases:
            - no confirm token was found (with a dedicated message when the
              body mentions "Quota exceeded");
            - the file response lacks a ``Content-Disposition`` header;
            - no filename can be parsed from that header.
    """
    session = requests.Session()

    # The first response is only needed for its cookies (and, on failure,
    # its body). Release its connection before issuing the real request.
    probe = session.get(url, stream=True)
    try:
        confirm_token = None
        for name, value in probe.cookies.items():
            if name.startswith("download_warning"):
                # The last matching cookie wins.
                confirm_token = value
        if confirm_token is None:
            if b"Quota exceeded" in probe.content:
                raise RuntimeError(
                    "Google drive link {} is currently unavailable, because the quota was exceeded.".format(url)
                )
            raise RuntimeError("Internal error: confirm_token was not found in Google drive link.")
    finally:
        probe.close()

    # Let requests URL-encode the token and merge it into the existing query
    # string. This also works when ``url`` has no query string.
    response = session.get(url, params={"confirm": confirm_token}, stream=True)

    if "content-disposition" not in response.headers:
        response.close()
        raise RuntimeError("Internal error: headers don't contain content-disposition.")

    # ``re.search`` yields None when nothing matches, so this guard can fire.
    # ``re.findall`` returns [], which made an ``is None`` check dead code.
    match = re.search('filename="(.+)"', response.headers["content-disposition"])
    if match is None:
        response.close()
        raise RuntimeError("Filename could not be autodetected")

    return response, match.group(1)


class DownloadManager:
    """Download a URL to a local file, with special handling for Google Drive links."""

    def get_local_path(self, url, destination):
        """Download ``url`` and write its body to ``destination``.

        Args:
            url: URL to fetch.
                - URLs containing ``drive.google.com`` go through the Google
                  Drive confirm-token flow.
                - Anything else is a plain streamed GET sent with a
                  browser-like ``User-Agent`` header.
            destination: path of the file to create or overwrite with the
                downloaded bytes.

        Raises:
            requests.HTTPError: if the server answers with a 4xx/5xx status.
                ``destination`` is not touched in that case.
            RuntimeError: propagated from the Google Drive flow.
        """
        if "drive.google.com" in url:
            # The detected filename is not needed; the caller picks ``destination``.
            response, _ = _get_response_from_google_drive(url)
        else:
            response = requests.get(url, headers={"User-Agent": "Mozilla/5.0"}, stream=True)

        # ``with`` releases the streamed connection even if writing fails.
        with response:
            # Fail loudly instead of saving an HTTP error page as the download.
            response.raise_for_status()
            with open(destination, "wb") as f:
                for chunk in _stream_response(response):
                    f.write(chunk)


# Shared module-level instance.
_DATASET_DOWNLOAD_MANAGER = DownloadManager()