File: fetcher.py

package info (click to toggle)
python-tuf 6.0.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,300 kB
  • sloc: python: 7,738; makefile: 8
file content (140 lines) | stat: -rw-r--r-- 4,684 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# Copyright 2021, New York University and the TUF contributors
# SPDX-License-Identifier: MIT OR Apache-2.0

"""Provides an interface for network IO abstraction."""

# Imports
import abc
import logging
import tempfile
from collections.abc import Iterator
from contextlib import contextmanager
from typing import IO

from tuf.api import exceptions

# Module-level logger; FetcherInterface.download_file() emits its
# download progress messages on it at DEBUG level.
logger = logging.getLogger(__name__)


# Classes
class FetcherInterface(metaclass=abc.ABCMeta):
    """Defines an interface for abstract network download.

    By providing a concrete implementation of the abstract interface,
    users of the framework can plug-in their preferred/customized
    network stack.

    Implementations of FetcherInterface only need to implement ``_fetch()``.
    The public API of the class is already implemented.
    """

    @abc.abstractmethod
    def _fetch(self, url: str) -> Iterator[bytes]:
        """Fetch the contents of HTTP/HTTPS ``url`` from a remote server.

        Implementations must raise ``DownloadHTTPError`` if they receive
        an HTTP error code.

        Implementations may raise any errors but the ones that are not
        ``DownloadErrors`` will be wrapped in a ``DownloadError`` by
        ``fetch()``. This holds both for errors raised when ``_fetch()`` is
        called and for errors raised while its returned iterator is
        consumed (e.g. when ``_fetch()`` is a generator).

        Args:
            url: URL string that represents a file location.

        Raises:
            exceptions.DownloadHTTPError: HTTP error code was received.

        Returns:
            Bytes iterator
        """
        raise NotImplementedError  # pragma: no cover

    def fetch(self, url: str) -> Iterator[bytes]:
        """Fetch the contents of HTTP/HTTPS ``url`` from a remote server.

        ``_fetch()`` is invoked immediately, so errors it raises before
        returning surface from this call. Errors raised later, while the
        returned iterator is consumed, surface from the iteration. In both
        cases any error that is not already a ``DownloadError`` is wrapped
        in one.

        Args:
            url: URL string that represents a file location.

        Raises:
            exceptions.DownloadError: An error occurred during download.
            exceptions.DownloadHTTPError: An HTTP error code was received.

        Returns:
            Bytes iterator
        """
        # Ensure that fetch() only raises DownloadErrors, regardless of the
        # fetcher implementation. _fetch() is called eagerly so that errors
        # raised before the first chunk keep surfacing from fetch() itself.
        try:
            chunks = self._fetch(url)
        except exceptions.DownloadError:
            raise
        except Exception as e:
            raise exceptions.DownloadError(f"Failed to download {url}") from e

        # Guarding only the call is not enough: a generator-based _fetch()
        # runs its body lazily, so its errors appear during iteration.
        return self._wrap_chunks(url, chunks)

    @staticmethod
    def _wrap_chunks(url: str, chunks: Iterator[bytes]) -> Iterator[bytes]:
        """Yield every chunk of ``chunks``, wrapping foreign errors.

        Args:
            url: URL being downloaded; used in the error message.
            chunks: Iterable of byte chunks returned by ``_fetch()``.

        Raises:
            exceptions.DownloadError: Iterating ``chunks`` raised an error
                that was not already a ``DownloadError`` (the original
                error is chained as the cause).
        """
        try:
            yield from chunks
        except exceptions.DownloadError:
            raise
        except Exception as e:
            raise exceptions.DownloadError(f"Failed to download {url}") from e

    @contextmanager
    def download_file(self, url: str, max_length: int) -> Iterator[IO[bytes]]:
        """Download file from given ``url``.

        It is recommended to use ``download_file()`` within a ``with``
        block to guarantee that allocated file resources will always
        be released even if download fails.

        Args:
            url: URL string that represents the location of the file.
            max_length: Upper bound of file size in bytes.

        Raises:
            exceptions.DownloadError: An error occurred during download.
            exceptions.DownloadLengthMismatchError: Downloaded bytes exceed
                ``max_length``.
            exceptions.DownloadHTTPError: An HTTP error code was received.

        Yields:
            ``TemporaryFile`` object that points to the contents of ``url``.
        """
        logger.debug("Downloading: %s", url)

        number_of_bytes_received = 0

        with tempfile.TemporaryFile() as temp_file:
            chunks = self.fetch(url)
            for chunk in chunks:
                number_of_bytes_received += len(chunk)
                # Check before writing so that no more than max_length
                # bytes ever reach the temporary file.
                if number_of_bytes_received > max_length:
                    raise exceptions.DownloadLengthMismatchError(
                        f"Downloaded {number_of_bytes_received} bytes exceeding"
                        f" the maximum allowed length of {max_length}"
                    )

                temp_file.write(chunk)

            logger.debug(
                "Downloaded %d out of %d bytes",
                number_of_bytes_received,
                max_length,
            )

            # Rewind so the caller reads the content from the beginning.
            temp_file.seek(0)
            yield temp_file

    def download_bytes(self, url: str, max_length: int) -> bytes:
        """Download bytes from given ``url``.

        Returns the downloaded bytes, otherwise like ``download_file()``.

        Args:
            url: URL string that represents the location of the file.
            max_length: Upper bound of data size in bytes.

        Raises:
            exceptions.DownloadError: An error occurred during download.
            exceptions.DownloadLengthMismatchError: Downloaded bytes exceed
                ``max_length``.
            exceptions.DownloadHTTPError: An HTTP error code was received.

        Returns:
            Content of the file in bytes.
        """
        with self.download_file(url, max_length) as dl_file:
            return dl_file.read()