File: agent.py

package info (click to toggle)
python-scrapy 2.13.3-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 5,664 kB
  • sloc: python: 52,028; xml: 199; makefile: 25; sh: 7
file content (185 lines) | stat: -rw-r--r-- 6,491 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
from __future__ import annotations

from collections import deque
from typing import TYPE_CHECKING

from twisted.internet import defer
from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
from twisted.web.client import (
    URI,
    BrowserLikePolicyForHTTPS,
    ResponseFailed,
    _StandardEndpointFactory,
)
from twisted.web.error import SchemeNotSupported

from scrapy.core.downloader.contextfactory import AcceptableProtocolsContextFactory
from scrapy.core.http2.protocol import H2ClientFactory, H2ClientProtocol

if TYPE_CHECKING:
    from twisted.internet.base import ReactorBase
    from twisted.internet.endpoints import HostnameEndpoint

    from scrapy.http import Request, Response
    from scrapy.settings import Settings
    from scrapy.spiders import Spider


# Key identifying a pooled connection: (scheme, host, port), as built by
# H2Agent.get_key().
ConnectionKeyT = tuple[bytes, bytes, int]


class H2ConnectionPool:
    """Pool of HTTP/2 connections holding at most one connection per key.

    Callers asking for a connection while one is still being established for
    the same key are queued; every queued caller is resolved with the
    connection once the attempt succeeds, or failed once it does not.
    """

    def __init__(self, reactor: ReactorBase, settings: Settings) -> None:
        self._reactor = reactor
        self.settings = settings

        # Store a dictionary which is used to get the respective
        # H2ClientProtocol instance using the key as
        # Tuple(scheme, hostname, port)
        self._connections: dict[ConnectionKeyT, H2ClientProtocol] = {}

        # Save all requests that arrive before the connection is established.
        # A key is present here only while a connection attempt is in flight.
        self._pending_requests: dict[
            ConnectionKeyT, deque[Deferred[H2ClientProtocol]]
        ] = {}

    def get_connection(
        self, key: ConnectionKeyT, uri: URI, endpoint: HostnameEndpoint
    ) -> Deferred[H2ClientProtocol]:
        """Return a Deferred that fires with the connection for ``key``.

        Arguments:
            key - pool key of the remote
            uri - URI handed to the client factory if a new connection is made
            endpoint - endpoint used to connect if a new connection is made
        """
        pending = self._pending_requests.get(key)
        if pending is not None:
            # Received a request while connecting to remote
            # Create a deferred which will fire with the H2ClientProtocol
            # instance
            d: Deferred[H2ClientProtocol] = Deferred()
            pending.append(d)
            return d

        # Check if we already have a connection to the remote
        conn = self._connections.get(key, None)
        if conn:
            # Return this connection instance wrapped inside a deferred
            return defer.succeed(conn)

        # No connection is established for the given URI
        return self._new_connection(key, uri, endpoint)

    def _new_connection(
        self, key: ConnectionKeyT, uri: URI, endpoint: HostnameEndpoint
    ) -> Deferred[H2ClientProtocol]:
        """Start connecting to ``endpoint`` and return the first waiter.

        The returned Deferred fires with the protocol instance once connected
        and fails with the connect failure otherwise.
        """
        # Queue the waiter *before* connecting: if the connect Deferred has
        # already fired when the callbacks below are attached, the waiter
        # must be in the queue that put_connection/_connection_failed drains.
        d: Deferred[H2ClientProtocol] = Deferred()
        self._pending_requests[key] = deque([d])

        conn_lost_deferred: Deferred[list[BaseException]] = Deferred()
        conn_lost_deferred.addCallback(self._remove_connection, key)

        try:
            factory = H2ClientFactory(uri, self.settings, conn_lost_deferred)
            conn_d = endpoint.connect(factory)
        except Exception:
            # Do not leave the key marked as "connecting", otherwise every
            # later request for it would be queued and never resolved.
            self._pending_requests.pop(key, None)
            raise

        # The errback only sees failures of the connection attempt itself,
        # not errors raised by put_connection.
        conn_d.addCallbacks(
            self.put_connection,
            self._connection_failed,
            callbackArgs=(key,),
            errbackArgs=(key,),
        )
        return d

    def put_connection(
        self, conn: H2ClientProtocol, key: ConnectionKeyT
    ) -> H2ClientProtocol:
        """Store ``conn`` under ``key`` and hand it to every queued waiter."""
        self._connections[key] = conn

        # Now as we have established a proper HTTP/2 connection
        # we fire all the deferred's with the connection instance
        pending_requests = self._pending_requests.pop(key, None)
        while pending_requests:
            d = pending_requests.popleft()
            d.callback(conn)

        return conn

    def _connection_failed(self, failure: Failure, key: ConnectionKeyT) -> None:
        """Fail every waiter queued for ``key`` with the connect ``failure``.

        Returns None so the failure is treated as handled on the connect
        Deferred; it has been delivered to the waiters instead.
        """
        pending_requests = self._pending_requests.pop(key, None)
        while pending_requests:
            d = pending_requests.popleft()
            d.errback(failure)

    def _remove_connection(
        self, errors: list[BaseException], key: ConnectionKeyT
    ) -> None:
        """Forget the connection stored under ``key``.

        Registered as the callback of the connection-lost Deferred that is
        passed to H2ClientFactory.
        """
        # The key may be absent (never stored, or already removed); that must
        # not raise inside the connection-lost callback.
        self._connections.pop(key, None)

        # Call the errback of all the pending requests for this connection
        pending_requests = self._pending_requests.pop(key, None)
        while pending_requests:
            d = pending_requests.popleft()
            d.errback(ResponseFailed(errors))

    def close_connections(self) -> None:
        """Abort the transport of every connection currently in the pool.

        Returns nothing. Entries are dropped from the pool by
        ``_remove_connection`` once the connection-lost Deferred of the
        corresponding connection fires.
        """
        # Iterate over a snapshot: aborting may lead to _remove_connection
        # mutating the dictionary.
        for conn in list(self._connections.values()):
            assert conn.transport is not None  # typing
            conn.transport.abortConnection()


class H2Agent:
    """Sends requests through connections obtained from an H2ConnectionPool."""

    def __init__(
        self,
        reactor: ReactorBase,
        pool: H2ConnectionPool,
        context_factory: BrowserLikePolicyForHTTPS = BrowserLikePolicyForHTTPS(),
        connect_timeout: float | None = None,
        bind_address: bytes | None = None,
    ) -> None:
        self._reactor = reactor
        self._pool = pool

        # Wrap the given policy so that b"h2" is the only acceptable protocol.
        h2_context_factory = AcceptableProtocolsContextFactory(
            context_factory, acceptable_protocols=[b"h2"]
        )
        self._context_factory = h2_context_factory
        self.endpoint_factory = _StandardEndpointFactory(
            reactor, h2_context_factory, connect_timeout, bind_address
        )

    def get_endpoint(self, uri: URI) -> HostnameEndpoint:
        """Build the endpoint used to connect to ``uri``."""
        endpoint = self.endpoint_factory.endpointForURI(uri)
        return endpoint

    def get_key(self, uri: URI) -> ConnectionKeyT:
        """Return the pool key for ``uri``.

        Arguments:
            uri - URI obtained directly from request URL
        """
        return (uri.scheme, uri.host, uri.port)

    def request(self, request: Request, spider: Spider) -> Deferred[Response]:
        """Send ``request`` over the pooled connection for its URL.

        If the endpoint factory raises SchemeNotSupported, an already-failed
        Deferred wrapping that exception is returned instead.
        """
        uri = URI.fromBytes(bytes(request.url, "utf-8"))
        try:
            endpoint = self.get_endpoint(uri)
        except SchemeNotSupported:
            # Failure() with no arguments captures the exception being handled.
            return defer.fail(Failure())

        def send(conn: H2ClientProtocol):
            return conn.request(request, spider)

        conn_d: Deferred[H2ClientProtocol] = self._pool.get_connection(
            self.get_key(uri), uri, endpoint
        )
        response_d: Deferred[Response] = conn_d.addCallback(send)
        return response_d


class ScrapyProxyH2Agent(H2Agent):
    """H2Agent whose endpoint and pool key are derived from a proxy URI."""

    def __init__(
        self,
        reactor: ReactorBase,
        proxy_uri: URI,
        pool: H2ConnectionPool,
        context_factory: BrowserLikePolicyForHTTPS = BrowserLikePolicyForHTTPS(),
        connect_timeout: float | None = None,
        bind_address: bytes | None = None,
    ) -> None:
        super().__init__(
            reactor, pool, context_factory, connect_timeout, bind_address
        )
        self._proxy_uri = proxy_uri

    def get_endpoint(self, uri: URI) -> HostnameEndpoint:
        """Build the endpoint for the proxy URI; ``uri`` is not used."""
        proxy = self._proxy_uri
        return self.endpoint_factory.endpointForURI(proxy)

    def get_key(self, uri: URI) -> ConnectionKeyT:
        """We use the proxy uri instead of uri obtained from request url"""
        proxy = self._proxy_uri
        return (b"http-proxy", proxy.host, proxy.port)