1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190
|
from __future__ import annotations
import warnings
from typing import TYPE_CHECKING, Any, cast
import OpenSSL.SSL
import pytest
from pytest_twisted import async_yield_fixture
from twisted.web import server, static
from twisted.web.client import Agent, BrowserLikePolicyForHTTPS, readBody
from twisted.web.client import Response as TxResponse
from scrapy.core.downloader import Downloader, Slot
from scrapy.core.downloader.contextfactory import (
ScrapyClientContextFactory,
load_context_factory_from_settings,
)
from scrapy.core.downloader.handlers.http11 import _RequestBodyProducer
from scrapy.exceptions import ScrapyDeprecationWarning
from scrapy.settings import Settings
from scrapy.utils.defer import deferred_f_from_coro_f, maybe_deferred_to_future
from scrapy.utils.misc import build_from_crawler
from scrapy.utils.python import to_bytes
from scrapy.utils.spider import DefaultSpider
from scrapy.utils.test import get_crawler
from tests.mockserver.http_resources import PayloadResource
from tests.mockserver.utils import ssl_context_factory
if TYPE_CHECKING:
from twisted.internet.defer import Deferred
from twisted.web.iweb import IBodyProducer
class TestSlot:
    """Checks for the downloader ``Slot`` object."""

    def test_repr(self):
        """``repr()`` of a slot lists concurrency, delay and randomize_delay."""
        expected = "Slot(concurrency=8, delay=0.10, randomize_delay=True)"
        downloader_slot = Slot(concurrency=8, delay=0.1, randomize_delay=True)
        assert repr(downloader_slot) == expected
class TestContextFactoryBase:
    """Shared fixture and request helper for tests that use a local HTTPS server."""

    # Server-side TLS context factory; when falsy, ``_listen`` falls back to
    # ``ssl_context_factory()``.
    context_factory = None

    @async_yield_fixture
    async def server_url(self, tmp_path):
        """Serve ``tmp_path`` over TLS on 127.0.0.1 and yield the base URL.

        A ``file`` entry is written into ``tmp_path`` and a ``payload`` child
        resource is registered before listening starts. After the ``yield``
        the listening port is stopped.
        """
        (tmp_path / "file").write_bytes(b"0123456789")
        root = static.File(str(tmp_path))
        root.putChild(b"payload", PayloadResource())
        listening_port = self._listen(server.Site(root, timeout=None))
        host_address = listening_port.getHost()
        yield f"https://127.0.0.1:{host_address.port}/"
        await listening_port.stopListening()

    def _listen(self, site):
        """Start a TLS listener for *site* on 127.0.0.1 (port 0) and return it."""
        from twisted.internet import reactor

        server_context_factory = self.context_factory or ssl_context_factory()
        return reactor.listenSSL(
            0,
            site,
            contextFactory=server_context_factory,
            interface="127.0.0.1",
        )

    @staticmethod
    async def get_page(
        url: str,
        client_context_factory: BrowserLikePolicyForHTTPS,
        body: str | None = None,
    ) -> bytes:
        """Send a GET request to *url* and return the response body bytes.

        The request goes through a Twisted ``Agent`` configured with
        *client_context_factory*. When *body* is a non-empty string it is
        encoded and sent as the request body; otherwise no body producer is
        used.
        """
        from twisted.internet import reactor

        agent = Agent(reactor, contextFactory=client_context_factory)
        if body:
            producer = _RequestBodyProducer(body.encode())
        else:
            producer = None
        pending_response = agent.request(
            b"GET",
            url.encode(),
            bodyProducer=cast("IBodyProducer", producer),
        )
        response: TxResponse = cast(
            "TxResponse", await maybe_deferred_to_future(pending_response)
        )
        with warnings.catch_warnings():
            # Silence the DeprecationWarning about a missing abortConnection
            # method while calling readBody(); see
            # https://github.com/twisted/twisted/issues/8227
            warnings.filterwarnings(
                "ignore",
                category=DeprecationWarning,
                message=r".*does not have an abortConnection method",
            )
            body_d: Deferred[bytes] = readBody(response)  # type: ignore[arg-type]
        return await maybe_deferred_to_future(body_d)
class TestContextFactory(TestContextFactoryBase):
    """Tests for the client context factory loaded from settings."""

    @deferred_f_from_coro_f
    async def testPayload(self, server_url: str) -> None:
        """The ``payload`` response body equals the request body that was sent."""
        payload = "0123456789" * 10
        crawler = get_crawler()
        factory = load_context_factory_from_settings(Settings(), crawler)
        payload_url = server_url + "payload"
        received = await self.get_page(payload_url, factory, body=payload)
        assert received == to_bytes(payload)

    def test_override_getContext(self):
        """Instantiating a subclass that overrides ``getContext()`` warns once."""

        class MyFactory(ScrapyClientContextFactory):
            def getContext(
                self, hostname: Any = None, port: Any = None
            ) -> OpenSSL.SSL.Context:
                ctx: OpenSSL.SSL.Context = super().getContext(hostname, port)
                return ctx

        expected_fragment = (
            "Overriding ScrapyClientContextFactory.getContext() is deprecated"
        )
        with warnings.catch_warnings(record=True) as recorded:
            MyFactory()
            # Exactly one warning is expected, and it must be the deprecation
            # notice about overriding getContext().
            assert len(recorded) == 1
            first_message = str(recorded[0].message)
            assert expected_fragment in first_message
class TestContextFactoryTLSMethod(TestContextFactoryBase):
    """Tests for which TLS method the client context factory ends up with."""

    async def _assert_factory_works(
        self, server_url: str, client_context_factory: ScrapyClientContextFactory
    ) -> None:
        """Send a payload using *client_context_factory* and check the reply body."""
        payload = "0123456789" * 10
        payload_url = server_url + "payload"
        received = await self.get_page(
            payload_url, client_context_factory, body=payload
        )
        assert received == to_bytes(payload)

    @deferred_f_from_coro_f
    async def test_setting_default(self, server_url: str) -> None:
        """With default settings the factory's method is ``SSLv23_METHOD``."""
        crawler = get_crawler()
        factory = load_context_factory_from_settings(Settings(), crawler)
        assert factory._ssl_method == OpenSSL.SSL.SSLv23_METHOD
        await self._assert_factory_works(server_url, factory)

    def test_setting_none(self):
        """Setting ``DOWNLOADER_CLIENT_TLS_METHOD`` to ``None`` raises ``KeyError``."""
        crawler = get_crawler()
        none_method_settings = Settings({"DOWNLOADER_CLIENT_TLS_METHOD": None})
        with pytest.raises(KeyError):
            load_context_factory_from_settings(none_method_settings, crawler)

    def test_setting_bad(self):
        """Setting ``DOWNLOADER_CLIENT_TLS_METHOD`` to ``"bad"`` raises ``KeyError``."""
        crawler = get_crawler()
        bad_method_settings = Settings({"DOWNLOADER_CLIENT_TLS_METHOD": "bad"})
        with pytest.raises(KeyError):
            load_context_factory_from_settings(bad_method_settings, crawler)

    @deferred_f_from_coro_f
    async def test_setting_explicit(self, server_url: str) -> None:
        """``"TLSv1.2"`` in settings gives a factory using ``TLSv1_2_METHOD``."""
        crawler = get_crawler()
        explicit_settings = Settings({"DOWNLOADER_CLIENT_TLS_METHOD": "TLSv1.2"})
        factory = load_context_factory_from_settings(explicit_settings, crawler)
        assert factory._ssl_method == OpenSSL.SSL.TLSv1_2_METHOD
        await self._assert_factory_works(server_url, factory)

    @deferred_f_from_coro_f
    async def test_direct_from_crawler(self, server_url: str) -> None:
        """Building the factory straight from the crawler ignores the setting."""
        # The setting value is "bad", yet the factory is built and ends up
        # with SSLv23_METHOD, so the setting is not consulted on this path.
        crawler = get_crawler(settings_dict={"DOWNLOADER_CLIENT_TLS_METHOD": "bad"})
        factory = build_from_crawler(ScrapyClientContextFactory, crawler)
        assert factory._ssl_method == OpenSSL.SSL.SSLv23_METHOD
        await self._assert_factory_works(server_url, factory)

    @deferred_f_from_coro_f
    async def test_direct_init(self, server_url: str) -> None:
        """A method passed to the constructor is stored on the factory."""
        tls_method = OpenSSL.SSL.TLSv1_2_METHOD
        factory = ScrapyClientContextFactory(tls_method)
        assert factory._ssl_method == tls_method
        await self._assert_factory_works(server_url, factory)
@deferred_f_from_coro_f
async def test_fetch_deprecated_spider_arg():
    """A downloader whose ``fetch()`` takes ``spider`` triggers a deprecation warning.

    The crawl is run with a ``Downloader`` subclass whose ``fetch()`` has a
    ``spider`` parameter; a ``ScrapyDeprecationWarning`` naming that subclass
    is expected while the crawl runs.
    """

    class CustomDownloader(Downloader):
        def fetch(self, request, spider):  # pylint: disable=signature-differs
            return super().fetch(request, spider)

    crawler = get_crawler(DefaultSpider, {"DOWNLOADER": CustomDownloader})
    # The pattern relies on the local class being named CustomDownloader.
    expected_message = (
        r"The fetch\(\) method of .+\.CustomDownloader requires a spider argument"
    )
    with pytest.warns(ScrapyDeprecationWarning, match=expected_message):
        await crawler.crawl_async()
|