1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203
|
# localslackirc
# Copyright (C) 2020 Salvo "LtWorf" Tomaselli
#
# localslackirc is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# author Salvo "LtWorf" Tomaselli <tiposchi@tiscali.it>
import asyncio
import gzip
import json
from typing import Optional, NamedTuple, Any
from uuid import uuid1
from urllib import parse
def multipart_form(form_fields: dict[str, Any]) -> tuple[str, bytes]:
"""
Convert a dictionary to post data and returns relevant headers.
The dictionary can contain values as open files, or anything else.
None values are skipped.
Anything that is not an open file is cast to str
"""
data = {}
has_files = False
for k, v in form_fields.items():
if v is not None:
data[k] = v
if hasattr(v, 'read') and hasattr(v, 'name'):
has_files = True
if not has_files:
return (
'Content-Type: application/x-www-form-urlencoded\r\n',
parse.urlencode(data).encode('ascii')
)
boundary = str(uuid1()).encode('ascii')
form_data = b''
for k, v in data.items():
form_data += b'--' + boundary + b'\r\n'
if hasattr(v, 'read') and hasattr(v, 'name'):
form_data += f'Content-Disposition: form-data; name="{k}"; filename="{v.name}"\r\n'.encode('ascii')
form_data += b'\r\n' + v.read() + b'\r\n'
else:
strv = str(v)
form_data += f'Content-Disposition: form-data; name="{k}"\r\n'.encode('ascii')
form_data += b'\r\n' + strv.encode('ascii') + b'\r\n'
form_data += b'--' + boundary + b'\r\n'
header = f'Content-Type: multipart/form-data; boundary={boundary.decode("ascii")}\r\n'
return header, form_data
class Response(NamedTuple):
status: int
headers: dict[str, str]
data: bytes
def json(self):
return json.loads(self.data)
class Request:
def __init__(self, base_url: str) -> None:
"""https://slack.com/api/
In my case, base_url is "https://slack.com/api/"
"""
self.base_url = parse.urlsplit(base_url)
if self.base_url.scheme == 'https':
self.ssl = True
self.port = 443
else:
self.ssl = False
self.port = 80
# Override port if explicitly defined
if self.base_url.port:
self.port = self.base_url.port
self.hostname = self.base_url.hostname
self.path = self.base_url.path
self._connections: dict[str, tuple[asyncio.streams.StreamReader, asyncio.streams.StreamWriter]] = {}
def __del__(self):
for i in self._connections.values():
i[1].close()
async def _connect(self) -> tuple[asyncio.streams.StreamReader, asyncio.streams.StreamWriter]:
"""
Get a connection.
It can be an already cached one or a new one.
"""
task = asyncio.tasks.current_task()
assert task is not None # Mypy doesn't notice this is in an async
key = task.get_name()
r = self._connections.get(key)
if r is None:
r = await asyncio.open_connection(self.hostname, self.port, ssl=self.ssl)
self._connections[key] = r
return r
async def post(self, path: str, headers: dict[str, str], data: dict[str, Any], timeout: float=0) -> Response:
"""
post a request.
data will be sent as a form. Fields are converted to str, except for
open files, which are read and sent. Open files must be opened in
binary mode.
Due to the possibility that the cached connection got closed, it will do
one retry before raising the exception
"""
try:
return await self._post(path, headers, data, timeout)
except (BrokenPipeError, ConnectionResetError, asyncio.IncompleteReadError):
# Clear connection from pool
task = asyncio.tasks.current_task()
assert task is not None # Mypy doesn't notice this is in an async
key = task.get_name()
r, w = self._connections.pop(key)
w.close()
return await self._post(path, headers, data, timeout)
async def _post(self, path: str, headers: dict[str, str], data: dict[str, Any], timeout: float=0) -> Response:
# Prepare request
req = f'POST {self.path + path} HTTP/1.1\r\n'
req += f'Host: {self.hostname}\r\n'
req += 'Connection: keep-alive\r\n'
req += 'Accept-Encoding: gzip\r\n'
for k, v in headers.items():
req += f'{k}: {v}\r\n'
header, post_data = multipart_form(data)
req += header
req += f'Content-Length: {len(post_data)}\r\n'
req += '\r\n'
# Send request
# 1 retry in case the keep alive connection was closed
reader, writer = await self._connect()
writer.write(req.encode('ascii'))
writer.write(post_data)
await writer.drain()
# Read response
line = await reader.readline()
if len(line) == 0:
raise BrokenPipeError()
try:
status = int(line.split(b' ')[1])
except Exception as e:
raise Exception(f'Invalid data {line!r} {e}')
# Read headers
headers = {}
while True:
line = await reader.readline()
if line == b'\r\n':
break
elif len(line) == 0:
raise BrokenPipeError()
k, v = line.decode('ascii').split(':', 1)
headers[k.lower()] = v.strip()
# Read data
read_data = b''
if headers.get('transfer-encoding') == 'chunked':
while True:
line = await reader.readline()
if len(line) == 0:
raise BrokenPipeError()
if not line.endswith(b'\r\n'):
raise Exception('Unexpected end of chunked data')
size = int(line, 16)
read_data += (await reader.readexactly(size + 2))[:-2]
if size == 0:
break
elif 'content-length' in headers:
size = int(headers['content-length'])
read_data = await reader.readexactly(size)
else:
raise NotImplementedError('Can only handle chunked or content length' + repr(headers))
# decompress if needed
if headers.get('content-encoding') == 'gzip':
read_data = gzip.decompress(read_data)
return Response(status, headers, read_data)
|