import gzip
import hashlib
import logging
import os
import tarfile
import zipfile

import torch
from torchtext import _CACHE_DIR

from ._download_hooks import _DATASET_DOWNLOAD_MANAGER

logger = logging.getLogger(__name__)


def reporthook(t):
    """Wrap a ``tqdm`` progress bar so it can be used as a ``urllib`` reporthook.

    Adapted from https://github.com/tqdm/tqdm.
    """
    last_b = [0]

    def inner(b=1, bsize=1, tsize=None):
        """
        b: int, optional
            Number of blocks just transferred [default: 1].
        bsize: int, optional
            Size of each block (in tqdm units) [default: 1].
        tsize: int, optional
            Total size (in tqdm units). If [default: None], remains unchanged.
        """
        if tsize is not None:
            t.total = tsize
        t.update((b - last_b[0]) * bsize)
        last_b[0] = b

    return inner
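
# A minimal usage sketch (assumes ``tqdm`` is installed; ``url`` and ``filename``
# are placeholders): ``reporthook(t)`` returns a callback matching the
# ``(blocks, block_size, total_size)`` signature that
# ``urllib.request.urlretrieve`` expects as its ``reporthook`` argument:
#
#     from urllib.request import urlretrieve
#     from tqdm import tqdm
#
#     with tqdm(unit="B", unit_scale=True, miniters=1) as t:
#         urlretrieve(url, filename, reporthook=reporthook(t))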


def validate_file(file_obj, hash_value, hash_type="sha256"):
    """Validate a given file object with its hash.

    Args:
        file_obj: File object to read from.
        hash_value (str): Expected hash of the file.
        hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).

    Returns:
        bool: True if the hash of the file matches ``hash_value``, else False.
    """
    if hash_type == "sha256":
        hash_func = hashlib.sha256()
    elif hash_type == "md5":
        hash_func = hashlib.md5()
    else:
        raise ValueError("hash_type must be 'sha256' or 'md5', got {!r}.".format(hash_type))

    while True:
        # Read by chunk to avoid filling memory
        chunk = file_obj.read(1024 ** 2)
        if not chunk:
            break
        hash_func.update(chunk)
    return hash_func.hexdigest() == hash_value
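
# A short usage sketch (the file name and digest below are placeholders, not
# real values): open the file in binary mode and compare it against a known hash.
#
#     with open("validation.tar.gz", "rb") as f:
#         is_valid = validate_file(f, "0123abcd...", hash_type="sha256")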


def _check_hash(path, hash_value, hash_type):
    logger.info("Validating hash {} matches hash of {}".format(hash_value, path))
    with open(path, "rb") as file_obj:
        if not validate_file(file_obj, hash_value, hash_type):
            raise RuntimeError(
                "The hash of {} does not match. Delete the file manually and retry.".format(os.path.abspath(path))
            )


def download_from_url(url, path=None, root=".data", overwrite=False, hash_value=None, hash_type="sha256"):
    """Download file, with logic (from tensor2tensor) for Google Drive. Returns
    the path to the downloaded file.

    Args:
        url: the url of the file to download.
        path: path where the file will be saved. If None, it is derived from ``root`` and the filename in ``url``.
        root: download folder used to store the file in (Default: ``.data``).
        overwrite: overwrite existing files (Default: ``False``).
        hash_value (str, optional): hash for url (Default: ``None``).
        hash_type (str, optional): hash type, among "sha256" and "md5" (Default: ``"sha256"``).

    Examples:
        >>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
        >>> torchtext.utils.download_from_url(url)
        '.data/validation.tar.gz'
    """
    # figure out filename and root
    if path is None:
        _, filename = os.path.split(url)
        root = os.path.abspath(root)
        path = os.path.join(root, filename)
    else:
        path = os.path.abspath(path)
        root, filename = os.path.split(os.path.abspath(path))

    # skip download if path exists and overwrite is not True
    if os.path.exists(path):
        logger.info("File %s already exists.", path)
        if not overwrite:
            if hash_value:
                _check_hash(path, hash_value, hash_type)
            return path

    # make root dir if it does not exist
    if not os.path.exists(root):
        try:
            os.makedirs(root)
        except OSError:
            raise OSError("Can't create the download directory {}.".format(root))

    # download data and move to path
    _DATASET_DOWNLOAD_MANAGER.get_local_path(url, destination=path)
    logger.info("File {} downloaded.".format(path))

    # validate
    if hash_value:
        _check_hash(path, hash_value, hash_type)

    # all good
    return path
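
# A sketch of downloading with hash verification (the digest below is a
# placeholder, not the real checksum of this archive); a mismatch raises
# ``RuntimeError`` via ``_check_hash``:
#
#     url = "http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz"
#     path = download_from_url(url, hash_value="0123abcd...", hash_type="sha256")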


def extract_archive(from_path, to_path=None, overwrite=False):
    """Extract archive.

    Args:
        from_path: the path of the archive.
        to_path: the root path of the extracted files (Default: the directory of ``from_path``).
        overwrite: overwrite existing files (Default: ``False``).

    Returns:
        List of paths to extracted files, even if they were not overwritten.

    Examples:
        >>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
        >>> from_path = './validation.tar.gz'
        >>> to_path = './'
        >>> torchtext.utils.download_from_url(url, from_path)
        >>> torchtext.utils.extract_archive(from_path, to_path)
        ['./val.de', './val.en']
    """
    if to_path is None:
        to_path = os.path.dirname(from_path)

    if from_path.endswith((".tar.gz", ".tgz")):
        logger.info("Opening tar file {}.".format(from_path))
        with tarfile.open(from_path, "r") as tar:
            files = []
            for file_ in tar:
                file_path = os.path.join(to_path, file_.name)
                if file_.isfile():
                    files.append(file_path)
                    if os.path.exists(file_path):
                        logger.info("{} already extracted.".format(file_path))
                        if not overwrite:
                            continue
                tar.extract(file_, to_path)
            logger.info("Finished extracting tar file {}.".format(from_path))
            return files
elif from_path.endswith(".zip"):
assert zipfile.is_zipfile(from_path), from_path
logger.info("Opening zip file {}.".format(from_path))
with zipfile.ZipFile(from_path, "r") as zfile:
files = []
for file_ in zfile.namelist():
file_path = os.path.join(to_path, file_)
files.append(file_path)
if os.path.exists(file_path):
logger.info("{} already extracted.".format(file_path))
if not overwrite:
continue
zfile.extract(file_, to_path)
files = [f for f in files if os.path.isfile(f)]
logger.info("Finished extracting zip file {}.".format(from_path))
return files
elif from_path.endswith(".gz"):
logger.info("Opening gz file {}.".format(from_path))
default_block_size = 65536
filename = from_path[:-3]
files = [filename]
with gzip.open(from_path, "rb") as gzfile, open(filename, "wb") as d_file:
while True:
block = gzfile.read(default_block_size)
if not block:
break
else:
d_file.write(block)
d_file.write(block)
logger.info("Finished extracting gz file {}.".format(from_path))
return files
    else:
        raise NotImplementedError("We currently only support .tar.gz, .tgz, .gz and .zip archives.")
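
# A sketch for a plain .gz file (the path is a placeholder): ``from_path[:-3]``
# strips the ".gz" suffix, so "corpus.txt.gz" would be decompressed next to the
# archive as "corpus.txt":
#
#     extract_archive("corpus.txt.gz")  # returns ["corpus.txt"]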


def _log_class_usage(klass):
    identifier = "torchtext"
    if klass and hasattr(klass, "__name__"):
        identifier += f".{klass.__name__}"
    torch._C._log_api_usage_once(identifier)


def get_asset_local_path(asset_path: str, overwrite=False) -> str:
    """Get the local path for an asset. Download it if the path does not exist locally.

    Args:
        asset_path: Local path to asset or remote URL.
        overwrite: Indicate whether to overwrite the file when downloading from URL (Default: ``False``).

    Returns:
        str: local path of the asset after downloading or reading from cache.

    Examples:
        >>> url = 'http://<HOST>/file.txt'
        >>> torchtext.utils.get_asset_local_path(url)
        '.data/file.txt'
        >>> torchtext.utils.get_asset_local_path('/home/user/file.txt')
        '/home/user/file.txt'
    """
    if os.path.exists(asset_path):
        local_path = asset_path
    else:
        local_path = download_from_url(url=asset_path, root=_CACHE_DIR, overwrite=overwrite)
    return local_path