# Copyright 2021, New York University and the TUF contributors
# SPDX-License-Identifier: MIT OR Apache-2.0

"""Provides an interface for network IO abstraction."""

# Imports
import abc
import logging
import tempfile
from collections.abc import Iterator
from contextlib import contextmanager
from typing import IO

from tuf.api import exceptions

logger = logging.getLogger(__name__)


# Classes
class FetcherInterface(metaclass=abc.ABCMeta):
    """Defines an interface for abstract network download.

    By providing a concrete implementation of the abstract interface,
    users of the framework can plug in their preferred/customized
    network stack.

    Implementations of FetcherInterface only need to implement ``_fetch()``.
    The public API of the class is already implemented.
    """

    @abc.abstractmethod
    def _fetch(self, url: str) -> Iterator[bytes]:
        """Fetch the contents of HTTP/HTTPS ``url`` from a remote server.

        Implementations must raise ``DownloadHTTPError`` if they receive
        an HTTP error code.

        Implementations may raise other errors as well; any error that is
        not a ``DownloadError`` is wrapped in a ``DownloadError`` by
        ``fetch()``.

        Args:
            url: URL string that represents a file location.

        Raises:
            exceptions.DownloadHTTPError: HTTP error code was received.

        Returns:
            Bytes iterator
        """
        raise NotImplementedError  # pragma: no cover

    def fetch(self, url: str) -> Iterator[bytes]:
        """Fetch the contents of HTTP/HTTPS ``url`` from a remote server.

        Args:
            url: URL string that represents a file location.

        Raises:
            exceptions.DownloadError: An error occurred during download.
            exceptions.DownloadHTTPError: An HTTP error code was received.

        Returns:
            Bytes iterator
        """
        # Ensure that fetch() only raises DownloadErrors, regardless of the
        # fetcher implementation
        try:
            return self._fetch(url)

        except exceptions.DownloadError as e:
            raise e

        except Exception as e:
            raise exceptions.DownloadError(f"Failed to download {url}") from e

    @contextmanager
    def download_file(self, url: str, max_length: int) -> Iterator[IO]:
        """Download file from given ``url``.

        It is recommended to use ``download_file()`` within a ``with``
        block to guarantee that allocated file resources will always
        be released even if download fails.

        Args:
            url: URL string that represents the location of the file.
            max_length: Upper bound of file size in bytes.

        Raises:
            exceptions.DownloadError: An error occurred during download.
            exceptions.DownloadLengthMismatchError: Downloaded bytes exceed
                ``max_length``.
            exceptions.DownloadHTTPError: An HTTP error code was received.

        Yields:
            ``TemporaryFile`` object that points to the contents of ``url``.
        """
        logger.debug("Downloading: %s", url)

        number_of_bytes_received = 0

        with tempfile.TemporaryFile() as temp_file:
            chunks = self.fetch(url)
            for chunk in chunks:
                number_of_bytes_received += len(chunk)
                if number_of_bytes_received > max_length:
                    raise exceptions.DownloadLengthMismatchError(
                        f"Downloaded {number_of_bytes_received} bytes exceeding"
                        f" the maximum allowed length of {max_length}"
                    )

                temp_file.write(chunk)

            logger.debug(
                "Downloaded %d out of %d bytes",
                number_of_bytes_received,
                max_length,
            )

            temp_file.seek(0)
            yield temp_file

    def download_bytes(self, url: str, max_length: int) -> bytes:
        """Download bytes from given ``url``.

        Returns the downloaded bytes, otherwise like ``download_file()``.

        Args:
            url: URL string that represents the location of the file.
            max_length: Upper bound of data size in bytes.

        Raises:
            exceptions.DownloadError: An error occurred during download.
            exceptions.DownloadLengthMismatchError: Downloaded bytes exceed
                ``max_length``.
            exceptions.DownloadHTTPError: An HTTP error code was received.

        Returns:
            Content of the file in bytes.
        """
        with self.download_file(url, max_length) as dl_file:
            return dl_file.read()
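

# Illustrative sketch, not part of the TUF API. It mirrors the size check
# performed by ``download_file()`` without any network access, so the
# behaviour can be tried in isolation. ``ExampleLengthError`` and
# ``read_capped`` are hypothetical names that stand in for
# ``DownloadLengthMismatchError`` and the loop inside ``download_file()``;
# they are assumptions made for this example only.
import tempfile
from collections.abc import Iterable


class ExampleLengthError(Exception):
    """Hypothetical stand-in for ``exceptions.DownloadLengthMismatchError``."""


def read_capped(chunks: Iterable[bytes], max_length: int) -> bytes:
    """Spool ``chunks`` to a temporary file, failing past ``max_length``."""
    received = 0
    with tempfile.TemporaryFile() as temp_file:
        for chunk in chunks:
            received += len(chunk)
            # Check before writing, so an oversized chunk is never stored.
            if received > max_length:
                raise ExampleLengthError(
                    f"Received {received} bytes, exceeding the maximum"
                    f" allowed length of {max_length}"
                )
            temp_file.write(chunk)

        # Rewind so the spooled content can be read back from the start.
        temp_file.seek(0)
        return temp_file.read()


# Example use: read_capped([b"ab", b"cd"], max_length=4) returns b"abcd",
# while read_capped([b"abc", b"de"], max_length=4) raises ExampleLengthError.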