File: download.py (dials-data 2.4.0-3)

from __future__ import annotations

import concurrent.futures
import contextlib
import errno
import functools
import hashlib
import os
import tarfile
import warnings
import zipfile
from pathlib import Path
from typing import Any, Optional, Union
from urllib.parse import urlparse

import py.path
import requests

import dials_data.datasets

if os.name == "posix":
    import fcntl

    def _platform_lock(file_handle):
        fcntl.lockf(file_handle, fcntl.LOCK_EX)

    def _platform_unlock(file_handle):
        fcntl.lockf(file_handle, fcntl.LOCK_UN)

elif os.name == "nt":
    import msvcrt

    def _platform_lock(file_handle):
        file_handle.seek(0)
        while True:
            try:
                msvcrt.locking(file_handle.fileno(), msvcrt.LK_LOCK, 1)
                # Call will only block for 10 sec and then raise
                # OSError: [Errno 36] Resource deadlock avoided
                break  # lock obtained
            except OSError as e:
                if e.errno != errno.EDEADLK:
                    raise

    def _platform_unlock(file_handle):
        file_handle.seek(0)
        msvcrt.locking(file_handle.fileno(), msvcrt.LK_UNLCK, 1)

else:

    def _platform_lock(file_handle):
        raise NotImplementedError("File locking not supported on this platform")

    _platform_unlock = _platform_lock


@contextlib.contextmanager
def _file_lock(file_handle):
    """
    Cross-platform file locking. Open a file for writing or appending.
    Then a file lock can be obtained with:

    with open(filename, 'w') as fh:
      with _file_lock(fh):
        (..)
    """
    lock = False
    try:
        _platform_lock(file_handle)
        lock = True
        yield
    finally:
        if lock:
            _platform_unlock(file_handle)


@contextlib.contextmanager
def download_lock(target_dir: Optional[Path]):
    """
    Obtains a (cooperative) lock on a lockfile in a target directory, so only a
    single (cooperative) process can enter this context manager at any one time.
    If the lock is held this will block until the existing lock is released.
    """
    if not target_dir:
        yield
        return
    target_dir.mkdir(parents=True, exist_ok=True)
    with target_dir.joinpath(".lock").open(mode="w") as fh:
        with _file_lock(fh):
            yield
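
# Example (illustrative only; the directory path below is hypothetical): point
# download_lock at a shared directory to serialise a critical section across
# cooperating processes:
#
#     with download_lock(Path("/tmp/dials-data-lock")):
#         ...  # only one cooperating process runs this block at a time
#
# Passing None skips locking entirely.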


def _download_to_file(session: requests.Session, url: str, pyfile: Path):
    """
    Downloads a single URL to a file.
    """
    with session.get(url, stream=True) as r:
        r.raise_for_status()
        pyfile.parent.mkdir(parents=True, exist_ok=True)
        with pyfile.open(mode="wb") as f:
            for chunk in r.iter_content(chunk_size=40960):
                f.write(chunk)


def file_hash(file_to_hash: Path) -> str:
    """Returns the SHA256 digest of a file."""
    sha256_hash = hashlib.sha256()
    with file_to_hash.open("rb") as f:
        for block in iter(lambda: f.read(4096), b""):
            sha256_hash.update(block)
    return sha256_hash.hexdigest()


def fetch_dataset(
    dataset,
    ignore_hashinfo: bool = False,
    verify: bool = False,
    read_only: bool = False,
    verbose: bool = False,
    pre_scan: bool = True,
    download_lockdir: Optional[Path] = None,
) -> Union[bool, Any]:
    """Check for the presence or integrity of the local copy of the specified
    test dataset. If the dataset is not available or out of date then attempt
    to download/update it transparently.

    :param verbose:          Show everything as it happens.
    :param pre_scan:         If all files are present and all file sizes match
                             then skip file integrity check and exit quicker.
    :param read_only:        Only use existing data, never download anything.
                             Implies pre_scan=True.
    :returns:                False if the dataset can not be downloaded/updated
                             for any reason.
                             True if the dataset is present and passes a
                             cursory inspection.
                             A validation dictionary if the dataset is present
                             and was fully verified.
    """
    if dataset not in dials_data.datasets.definition:
        return False
    definition = dials_data.datasets.definition[dataset]

    target_dir: Path = dials_data.datasets.repository_location() / dataset
    if read_only and not target_dir.is_dir():
        return False

    integrity_info = definition.get("hashinfo")
    if not integrity_info or ignore_hashinfo:
        integrity_info = dials_data.datasets.create_integrity_record(dataset)

    if "verify" not in integrity_info:
        integrity_info["verify"] = [{} for _ in definition["data"]]
    filelist: list[dict[str, Any]] = [
        {
            "url": source["url"],
            "file": target_dir / os.path.basename(urlparse(source["url"]).path),
            "files": source.get("files"),
            "verify": hashinfo,
        }
        for source, hashinfo in zip(definition["data"], integrity_info["verify"])
    ]

    if pre_scan or read_only:
        if all(
            item["file"].is_file()
            and item["verify"].get("size")
            and item["verify"]["size"] == item["file"].stat().st_size
            for item in filelist
        ):
            return True
        if read_only:
            return False

    # Acquire lock if required as files may be downloaded/written.
    with download_lock(download_lockdir):
        _fetch_filelist(filelist)

    return integrity_info
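
# Illustrative use of the return value (the dataset name is taken from the
# DataFetcher docstring below; this is a sketch, not part of the API):
#
#     result = fetch_dataset("insulin", pre_scan=True)
#     if result is False:
#         ...  # unknown dataset, or unavailable in read-only mode
#     elif result is True:
#         ...  # all files present with matching sizes (quick pre-scan)
#     else:
#         ...  # hashinfo dictionary after downloading/verifying the files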


def _fetch_filelist(filelist: list[dict[str, Any]]) -> None:
    with requests.Session() as rs:
        with concurrent.futures.ThreadPoolExecutor(max_workers=5) as pool:
            # Consume the iterator so that all downloads complete (and any
            # worker exceptions surface) before the session is closed.
            for _ in pool.map(functools.partial(_fetch_file, rs), filelist):
                pass
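
# Each entry handled by _fetch_file below is a dictionary built in
# fetch_dataset. An illustrative (not real) entry looks like:
#
#     {
#         "url": "https://example.com/data/images.tar.gz",
#         "file": target_dir / "images.tar.gz",
#         "files": ["images/image_0001.cbf"],  # archive members to extract, or None
#         "verify": {"size": 123456, "hash": "<sha256 hex digest>"},
#     }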


def _fetch_file(session: requests.Session, source: dict[str, Any]) -> None:
    valid = False
    if source["file"].is_file():
        # verify
        valid = True
        if source["verify"]:
            if source["verify"]["size"] != source["file"].stat().st_size:
                valid = False
            elif source["verify"]["hash"] != file_hash(source["file"]):
                valid = False

    downloaded = False
    if not valid:
        print(f"Downloading {source['url']}")
        _download_to_file(session, source["url"], source["file"])
        downloaded = True

    # verify
    valid = True
    if source["verify"]:
        if source["verify"]["size"] != source["file"].stat().st_size:
            print(
                f"File size mismatch on {source['file']}: "
                f"{source['file'].stat().st_size}, expected {source['verify']['size']}"
            )
        elif source["verify"]["hash"] != file_hash(source["file"]):
            print(f"File hash mismatch on {source['file']}")
    else:
        source["verify"]["size"] = source["file"].stat().st_size
        source["verify"]["hash"] = file_hash(source["file"])

    # If the file is a tar archive, then decompress
    if source["files"]:
        target_dir = source["file"].parent
        if downloaded or not all((target_dir / f).is_file() for f in source["files"]):
            # If the file has been (re)downloaded, or we don't have all the requested
            # files from the archive, then we need to decompress the archive
            print(f"Decompressing {source['file']}")
            if source["file"].suffix == ".zip":
                with zipfile.ZipFile(source["file"]) as zf:
                    try:
                        for f in source["files"]:
                            zf.extract(f, path=source["file"].parent)
                    except KeyError:
                        print(
                            f"Expected file {f} not present "
                            f"in zip archive {source['file']}"
                        )
            else:
                with tarfile.open(source["file"]) as tar:
                    for f in source["files"]:
                        try:
                            tar.extract(f, path=source["file"].parent)
                        except KeyError:
                            print(
                                f"Expected file {f} not present "
                                f"in tar archive {source['file']}"
                            )


class DataFetcher:
    """A class that offers access to regression datasets.

    To initialize:
        df = DataFetcher()
    Then
        df('insulin')
    returns the location of the insulin data (a py.path.local object by
    default, or a pathlib.Path when called with pathlib=True). If that data
    is not already on disk it is downloaded automatically.

    To disable all downloads:
        df = DataFetcher(read_only=True)

    Do not use this class directly in tests! Use the dials_data fixture.
    """

    def __init__(self, read_only=False):
        self._cache: dict[str, Optional[Path]] = {}
        self._target_dir: Path = dials_data.datasets.repository_location()
        self._read_only: bool = read_only and os.access(self._target_dir, os.W_OK)

    def __repr__(self) -> str:
        return "<{}DataFetcher: {}>".format(
            "R/O " if self._read_only else "",
            self._target_dir,
        )

    def result_filter(self, result, **kwargs):
        """
        An overridable function to post-process lookup results.
        Used in tests to transform negative lookups into test skips.
        Overriding functions should accept **kwargs in their signature
        to remain forward compatible.
        """
        return result

    def __call__(self, test_data: str, pathlib=None, **kwargs):
        """
        Return the location of a dataset, transparently downloading it if
        necessary and possible.
        The return value can be manipulated by overriding the result_filter
        function.
        :param test_data: name of the requested dataset.
        :param pathlib: Whether to return the result as a Python pathlib object.
                        The default for this setting is 'False' for now (leading
                        to a py.path.local object being returned), but the default
                        will change to 'True' in a future dials.data release.
                        Set to 'True' for forward compatibility.
        :return: A pathlib or py.path.local object pointing to the dataset, or False
                 if the dataset is not available.
        """
        if test_data not in self._cache:
            self._cache[test_data] = self._attempt_fetch(test_data)
        if pathlib is None:
            warnings.warn(
                "The DataFetcher currently returns py.path.local() objects. "
                "This will in the future change to pathlib.Path() objects. "
                "You can either add a pathlib=True argument to obtain a pathlib.Path() object, "
                "or pathlib=False to silence this warning for now.",
                DeprecationWarning,
                stacklevel=2,
            )
        if not self._cache[test_data]:
            return self.result_filter(result=False)
        elif not pathlib:
            return self.result_filter(result=py.path.local(self._cache[test_data]))
        return self.result_filter(result=self._cache[test_data])
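
    # Illustrative usage (a sketch; not part of the library API). Passing
    # pathlib=True returns a pathlib.Path and avoids the deprecation warning:
    #
    #     df = DataFetcher()
    #     insulin = df("insulin", pathlib=True)
    #     if not insulin:
    #         ...  # dataset not available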

    def _attempt_fetch(self, test_data: str) -> Optional[Path]:
        if self._read_only:
            data_available = fetch_dataset(test_data, pre_scan=True, read_only=True)
        else:
            data_available = fetch_dataset(
                test_data,
                pre_scan=True,
                read_only=False,
                download_lockdir=self._target_dir,
            )
        if data_available:
            return self._target_dir / test_data
        else:
            return None
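

if __name__ == "__main__":
    # Minimal manual check (illustrative only; not part of the library API):
    # fetch the "insulin" dataset named in the DataFetcher docstring and
    # print its location, downloading it first if necessary.
    df = DataFetcher()
    print(df("insulin", pathlib=True))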