File: test_fetcher_ng.py

package info (click to toggle)
python-tuf 6.0.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,300 kB
  • sloc: python: 7,738; makefile: 8
file content (182 lines) | stat: -rw-r--r-- 6,327 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
# Copyright 2021, New York University and the TUF contributors
# SPDX-License-Identifier: MIT OR Apache-2.0

"""Unit test for Urllib3Fetcher."""

import io
import logging
import math
import os
import sys
import tempfile
import unittest
from typing import ClassVar
from unittest.mock import Mock, patch

import urllib3

from tests import utils
from tuf.api import exceptions
from tuf.ngclient import Urllib3Fetcher

logger = logging.getLogger(__name__)


class TestFetcher(unittest.TestCase):
    """Test Urllib3Fetcher class."""

    server_process_handler: ClassVar[utils.TestServerProcess]

    @classmethod
    def setUpClass(cls) -> None:
        """
        Create a temporary file and launch a simple server in the
        current working directory.

        Each resource is paired with an ``addClassCleanup`` registration as
        soon as it is acquired. The cleanups therefore run even if a later
        step of this method raises, a case in which ``tearDownClass`` would
        never be called. Cleanups run in LIFO order: the server is stopped
        first, then the temporary file is removed.
        """
        cls.file_contents = b"junk data"
        cls.file_length = len(cls.file_contents)
        # delete=False keeps the file on disk after the ``with`` closes it, so
        # the server (launched in the cwd) can serve it. Removal is deferred to
        # the class cleanup registered right after creation.
        with tempfile.NamedTemporaryFile(
            dir=os.getcwd(), delete=False
        ) as cls.target_file:
            cls.addClassCleanup(os.remove, cls.target_file.name)
            cls.target_file.write(cls.file_contents)

        cls.server_process_handler = utils.TestServerProcess(log=logger)
        # Stop server process and perform clean up once the class is done.
        cls.addClassCleanup(cls.server_process_handler.clean)

        cls.url_prefix = (
            f"http://{utils.TEST_HOST_ADDRESS}:"
            f"{cls.server_process_handler.port!s}"
        )
        target_filename = os.path.basename(cls.target_file.name)
        cls.url = f"{cls.url_prefix}/{target_filename}"

    def setUp(self) -> None:
        # Instantiate a concrete instance of FetcherInterface
        self.fetcher = Urllib3Fetcher()

    def test_fetch(self) -> None:
        """Simple fetch: the streamed chunks reassemble the served file."""
        with tempfile.TemporaryFile() as temp_file:
            for chunk in self.fetcher.fetch(self.url):
                temp_file.write(chunk)

            temp_file.seek(0)
            self.assertEqual(self.file_contents, temp_file.read())

    def test_fetch_in_chunks(self) -> None:
        """URL data is downloaded in more than one chunk."""
        # Set a smaller chunk size to ensure that the file will be downloaded
        # in more than one chunk
        self.fetcher.chunk_size = 4

        # With the fixture above this is ceil(9 / 4) == 3. Asserting it guards
        # against the fixture changing and silently weakening this test.
        expected_chunks_count = math.ceil(
            self.file_length / self.fetcher.chunk_size
        )
        self.assertEqual(expected_chunks_count, 3)

        chunks_count = 0
        with tempfile.TemporaryFile() as temp_file:
            for chunk in self.fetcher.fetch(self.url):
                temp_file.write(chunk)
                chunks_count += 1

            temp_file.seek(0)
            self.assertEqual(self.file_contents, temp_file.read())
            # Check that we calculate chunks as expected
            self.assertEqual(chunks_count, expected_chunks_count)

    def test_url_parsing(self) -> None:
        """Fetching from an unusable URL raises DownloadError."""
        with self.assertRaises(exceptions.DownloadError):
            self.fetcher.fetch("http://invalid/")

    def test_http_error(self) -> None:
        """A non-existing path raises DownloadHTTPError with status 404."""
        # Build the URL in a local instead of overwriting ``self.url``, and
        # keep only the call under test inside the assertRaises context.
        missing_url = f"{self.url_prefix}/non-existing-path"
        with self.assertRaises(exceptions.DownloadHTTPError) as cm:
            self.fetcher.fetch(missing_url)
        self.assertEqual(cm.exception.status_code, 404)

    @patch.object(urllib3.PoolManager, "request")
    def test_response_read_timeout(self, mock_request: Mock) -> None:
        """A timeout while streaming the body raises SlowRetrievalError."""
        mock_response = Mock()
        mock_response.status = 200
        attr = {
            "stream.side_effect": urllib3.exceptions.MaxRetryError(
                urllib3.connectionpool.ConnectionPool("localhost"),
                "",
                urllib3.exceptions.TimeoutError(),
            )
        }
        mock_response.configure_mock(**attr)
        mock_request.return_value = mock_response

        # fetch() returns a generator; the error surfaces on the first chunk.
        with self.assertRaises(exceptions.SlowRetrievalError):
            next(self.fetcher.fetch(self.url))
        mock_response.stream.assert_called_once()

    @patch.object(
        urllib3.PoolManager,
        "request",
        side_effect=urllib3.exceptions.MaxRetryError(
            urllib3.connectionpool.ConnectionPool("localhost"),
            "",
            urllib3.exceptions.TimeoutError(),
        ),
    )
    def test_session_get_timeout(self, mock_request: Mock) -> None:
        """A read/connect timeout on the request raises SlowRetrievalError."""
        with self.assertRaises(exceptions.SlowRetrievalError):
            self.fetcher.fetch(self.url)
        mock_request.assert_called_once()

    def test_download_bytes(self) -> None:
        """Simple bytes download with the exact expected length."""
        data = self.fetcher.download_bytes(self.url, self.file_length)
        self.assertEqual(self.file_contents, data)

    def test_download_bytes_upper_length(self) -> None:
        """A file smaller than max_length downloads completely."""
        data = self.fetcher.download_bytes(self.url, self.file_length + 4)
        self.assertEqual(self.file_contents, data)

    def test_download_bytes_length_mismatch(self) -> None:
        """A file bigger than max_length raises a length-mismatch error."""
        with self.assertRaises(exceptions.DownloadLengthMismatchError):
            self.fetcher.download_bytes(self.url, self.file_length - 4)

    def test_download_file(self) -> None:
        """Simple file download with the exact expected length."""
        with self.fetcher.download_file(
            self.url, self.file_length
        ) as temp_file:
            temp_file.seek(0, io.SEEK_END)
            self.assertEqual(self.file_length, temp_file.tell())

    def test_download_file_upper_length(self) -> None:
        """A file smaller than max_length downloads completely."""
        with self.fetcher.download_file(
            self.url, self.file_length + 4
        ) as temp_file:
            temp_file.seek(0, io.SEEK_END)
            self.assertEqual(self.file_length, temp_file.tell())

    def test_download_file_length_mismatch(self) -> None:
        """A file bigger than max_length raises a length-mismatch error."""
        with self.assertRaises(
            exceptions.DownloadLengthMismatchError
        ), self.fetcher.download_file(self.url, self.file_length - 4):
            pass  # we never get here as download_file() raises


# Entry point when this module is executed directly as a script.
if __name__ == "__main__":
    # Configure logging from the command-line arguments, then hand control
    # to the unittest runner.
    cli_args = sys.argv
    utils.configure_test_logging(cli_args)
    unittest.main()