1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
|
import re
import requests
# This is to allow monkey-patching in fbcode
from torch.hub import load_state_dict_from_url # noqa
from torchtext._internal.module_utils import is_module_available
from tqdm import tqdm
if is_module_available("torchdata"):
from torchdata.datapipes.iter import HttpReader, GDriveReader # noqa F401
def _stream_response(r, chunk_size=16 * 1024):
    """Yield the body of a streamed response chunk by chunk with a progress bar.

    Args:
        r: response object exposing ``headers`` and ``iter_content``.
        chunk_size: value forwarded to ``r.iter_content``.

    Yields:
        Every non-empty chunk produced by ``r.iter_content(chunk_size)``.
    """
    # A missing Content-length header falls back to a total of 0.
    expected_bytes = int(r.headers.get("Content-length", 0))
    with tqdm(total=expected_bytes, unit="B", unit_scale=1) as progress:
        for piece in r.iter_content(chunk_size):
            # Empty chunks are skipped: neither counted nor yielded.
            if not piece:
                continue
            progress.update(len(piece))
            yield piece
def _get_response_from_google_drive(url):
    """Open a streaming download for a Google Drive link.

    A first request is made to obtain the ``download_warning`` cookie; the
    request is then repeated with that cookie's value as the ``confirm``
    query parameter.

    Args:
        url: Google Drive download URL.

    Returns:
        A ``(response, filename)`` tuple: the streaming response of the
        confirmed request and the file name parsed from its
        ``content-disposition`` header.

    Raises:
        RuntimeError: if the quota was exceeded, no confirmation token was
            found, the ``content-disposition`` header is missing, or no file
            name could be parsed from it.
    """
    session = requests.Session()
    response = session.get(url, stream=True)

    confirm_token = None
    for cookie_name, cookie_value in response.cookies.items():
        if cookie_name.startswith("download_warning"):
            confirm_token = cookie_value

    if confirm_token is None:
        quota_exceeded = "Quota exceeded" in str(response.content)
        response.close()
        if quota_exceeded:
            raise RuntimeError(
                "Google drive link {} is currently unavailable, because the quota was exceeded.".format(url)
            )
        raise RuntimeError("Internal error: confirm_token was not found in Google drive link.")

    # The first response was only needed for its cookies; release its
    # connection before issuing the confirmed request.
    response.close()

    # Passing the token through ``params`` lets requests pick the right
    # separator and URL-encode the value, instead of assuming the URL already
    # carries a query string.
    response = session.get(url, params={"confirm": confirm_token}, stream=True)

    if "content-disposition" not in response.headers:
        response.close()
        raise RuntimeError("Internal error: headers don't contain content-disposition.")

    # re.search returns None when nothing matches (re.findall would return an
    # empty list, which made the original ``is None`` check unreachable and
    # surfaced as an IndexError instead of the intended RuntimeError).
    match = re.search(r'filename="(.+)"', response.headers["content-disposition"])
    if match is None:
        response.close()
        raise RuntimeError("Filename could not be autodetected")

    return response, match.group(1)
class DownloadManager:
    """Downloads the content behind a URL into a local file."""

    def get_local_path(self, url, destination):
        """Download ``url`` and write its content to ``destination``.

        URLs containing ``drive.google.com`` are fetched through
        ``_get_response_from_google_drive``; every other URL is fetched
        directly with a ``Mozilla/5.0`` User-Agent header.

        Args:
            url: address of the resource to download.
            destination: path of the file to write; it is opened in ``"wb"``
                mode, so an existing file is overwritten.

        Returns:
            None. Despite the method name, nothing is returned; the caller
            already knows ``destination``.
        """
        if "drive.google.com" not in url:
            response = requests.get(url, headers={"User-Agent": "Mozilla/5.0"}, stream=True)
        else:
            # The file name reported by Google Drive is not used here:
            # ``destination`` alone decides where the content is written.
            response, _ = _get_response_from_google_drive(url)

        try:
            with open(destination, "wb") as f:
                for chunk in _stream_response(response):
                    f.write(chunk)
        finally:
            # Streamed responses keep their connection open until closed;
            # release it even if writing the file fails part-way.
            response.close()
# Module-level DownloadManager instance, created at import time.
# NOTE(review): presumably the shared default used by dataset code — confirm against callers.
_DATASET_DOWNLOAD_MANAGER = DownloadManager()
|