1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182
|
# Copyright 2021, New York University and the TUF contributors
# SPDX-License-Identifier: MIT OR Apache-2.0
"""Unit test for Urllib3Fetcher."""
import io
import logging
import math
import os
import sys
import tempfile
import unittest
from typing import ClassVar
from unittest.mock import Mock, patch
import urllib3
from tests import utils
from tuf.api import exceptions
from tuf.ngclient import Urllib3Fetcher
# Module-level logger; it is passed to the test server helper so that the
# server's output is routed through this module's logging configuration.
logger: logging.Logger = logging.getLogger(name=__name__)
class TestFetcher(unittest.TestCase):
    """Test Urllib3Fetcher class against a local test HTTP server."""

    server_process_handler: ClassVar[utils.TestServerProcess]
    file_contents: ClassVar[bytes]
    file_length: ClassVar[int]
    url_prefix: ClassVar[str]
    url: ClassVar[str]

    @classmethod
    def setUpClass(cls) -> None:
        """
        Create a temporary file and launch a simple server in the
        current working directory.

        Cleanup for each resource is registered with addClassCleanup()
        immediately after the resource is acquired. Class cleanups also run
        when setUpClass() itself fails part-way, whereas tearDownClass() does
        not, so neither the server process nor the temporary file can leak.
        """
        cls.server_process_handler = utils.TestServerProcess(log=logger)
        cls.addClassCleanup(cls.server_process_handler.clean)

        cls.file_contents = b"junk data"
        cls.file_length = len(cls.file_contents)
        # delete=False: the file has to outlive this `with` block so the
        # server can serve it; the cleanup registered below removes it.
        with tempfile.NamedTemporaryFile(
            dir=os.getcwd(), delete=False
        ) as cls.target_file:
            cls.addClassCleanup(os.remove, cls.target_file.name)
            cls.target_file.write(cls.file_contents)

        cls.url_prefix = (
            f"http://{utils.TEST_HOST_ADDRESS}:"
            f"{cls.server_process_handler.port!s}"
        )
        target_filename = os.path.basename(cls.target_file.name)
        cls.url = f"{cls.url_prefix}/{target_filename}"

    def setUp(self) -> None:
        # A fresh fetcher per test, so per-test tweaks such as chunk_size
        # cannot leak into other tests.
        self.fetcher = Urllib3Fetcher()

    def test_fetch(self) -> None:
        """Simple fetch: the concatenated chunks equal the file contents."""
        data = b"".join(self.fetcher.fetch(self.url))
        self.assertEqual(self.file_contents, data)

    def test_fetch_in_chunks(self) -> None:
        """URL data is downloaded in more than one chunk."""
        # A chunk size smaller than the file forces a multi-chunk download.
        self.fetcher.chunk_size = 4
        expected_chunks_count = math.ceil(
            self.file_length / self.fetcher.chunk_size
        )
        # Sanity-check the fixture itself: 9 bytes in 4-byte chunks -> 3.
        self.assertEqual(expected_chunks_count, 3)

        chunks = list(self.fetcher.fetch(self.url))

        self.assertEqual(self.file_contents, b"".join(chunks))
        # Check that we calculate chunks as expected
        self.assertEqual(len(chunks), expected_chunks_count)

    def test_url_parsing(self) -> None:
        """Fetching from a host that cannot be reached raises DownloadError."""
        with self.assertRaises(exceptions.DownloadError):
            self.fetcher.fetch("http://invalid/")

    def test_http_error(self) -> None:
        """A missing file is reported as DownloadHTTPError with status 404."""
        # Use a local rather than rebinding the shared `url` class fixture.
        missing_url = f"{self.url_prefix}/non-existing-path"
        with self.assertRaises(exceptions.DownloadHTTPError) as cm:
            self.fetcher.fetch(missing_url)
        self.assertEqual(cm.exception.status_code, 404)

    @patch.object(urllib3.PoolManager, "request")
    def test_response_read_timeout(self, mock_request: Mock) -> None:
        """A timeout while streaming the body raises SlowRetrievalError."""
        mock_response = Mock()
        mock_response.status = 200
        mock_response.stream.side_effect = urllib3.exceptions.MaxRetryError(
            urllib3.connectionpool.ConnectionPool("localhost"),
            "",
            urllib3.exceptions.TimeoutError(),
        )
        mock_request.return_value = mock_response

        # Advance the returned iterator so the (mocked) body stream is read.
        with self.assertRaises(exceptions.SlowRetrievalError):
            next(self.fetcher.fetch(self.url))
        mock_response.stream.assert_called_once()

    @patch.object(
        urllib3.PoolManager,
        "request",
        side_effect=urllib3.exceptions.MaxRetryError(
            urllib3.connectionpool.ConnectionPool("localhost"),
            "",
            urllib3.exceptions.TimeoutError(),
        ),
    )
    def test_session_get_timeout(self, mock_request: Mock) -> None:
        """A timeout while issuing the request raises SlowRetrievalError."""
        with self.assertRaises(exceptions.SlowRetrievalError):
            self.fetcher.fetch(self.url)
        mock_request.assert_called_once()

    def test_download_bytes(self) -> None:
        """Simple bytes download with max_length equal to the file size."""
        data = self.fetcher.download_bytes(self.url, self.file_length)
        self.assertEqual(self.file_contents, data)

    def test_download_bytes_upper_length(self) -> None:
        """Download a file smaller than the allowed max_length."""
        data = self.fetcher.download_bytes(self.url, self.file_length + 4)
        self.assertEqual(self.file_contents, data)

    def test_download_bytes_length_mismatch(self) -> None:
        """Downloading a file larger than max_length raises an error."""
        with self.assertRaises(exceptions.DownloadLengthMismatchError):
            self.fetcher.download_bytes(self.url, self.file_length - 4)

    def test_download_file(self) -> None:
        """Simple file download with max_length equal to the file size."""
        with self.fetcher.download_file(
            self.url, self.file_length
        ) as temp_file:
            temp_file.seek(0, io.SEEK_END)
            self.assertEqual(self.file_length, temp_file.tell())

    def test_download_file_upper_length(self) -> None:
        """Download a file smaller than the allowed max_length."""
        with self.fetcher.download_file(
            self.url, self.file_length + 4
        ) as temp_file:
            temp_file.seek(0, io.SEEK_END)
            self.assertEqual(self.file_length, temp_file.tell())

    def test_download_file_length_mismatch(self) -> None:
        """Downloading a file larger than max_length raises an error."""
        with self.assertRaises(
            exceptions.DownloadLengthMismatchError
        ), self.fetcher.download_file(self.url, self.file_length - 4):
            pass  # we never get here as download_file() raises
# Entry point: run this test module directly as a script.
if __name__ == "__main__":
    cli_args = sys.argv
    # Let the test utilities set up logging from the command-line arguments
    # before control is handed to the unittest runner.
    utils.configure_test_logging(cli_args)
    unittest.main()
|