1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223
|
"""
helpers
~~~~~~~
This module contains helpers for the nio tests.
"""
import os
from random import choice
from string import ascii_letters, ascii_uppercase
from faker import Faker
from faker.providers import BaseProvider
from hpack.hpack import Encoder
from hyperframe.frame import (
AltSvcFrame,
ContinuationFrame,
DataFrame,
GoAwayFrame,
HeadersFrame,
PingFrame,
PriorityFrame,
PushPromiseFrame,
RstStreamFrame,
SettingsFrame,
WindowUpdateFrame,
)
from nio.crypto import OlmAccount, OlmDevice
from nio.store import Ed25519Key
# A small, valid SETTINGS payload (setting id -> value) for the frame tests.
SAMPLE_SETTINGS = dict(
    [
        (SettingsFrame.HEADER_TABLE_SIZE, 4096),
        (SettingsFrame.ENABLE_PUSH, 1),
        (SettingsFrame.MAX_CONCURRENT_STREAMS, 2),
    ]
)

# Module-wide Faker instance used by the helpers in this file.
faker = Faker()
class Provider(BaseProvider):
    """Faker provider producing Matrix/nio test fixtures.

    Random ids and URLs are drawn from the owning generator
    (``self.generator``), not from the global :mod:`random` module or a
    module-level Faker instance. Seeding that Faker instance therefore makes
    them reproducible, and the provider works with any Faker it is added to.
    Key material comes from ``OlmAccount`` and is not affected by Faker
    seeding.
    """

    def mx_id(self):
        """Return a random Matrix user id of the form ``@<user>:<host>``."""
        return f"@{self.generator.user_name()}:{self.generator.hostname()}"

    def avatar_url(self):
        """Return a random ``mxc://<host>/<24 letters>#auto`` URL."""
        media_id = self._random_string(ascii_letters, 24)
        return f"mxc://{self.generator.hostname()}/{media_id}#auto"

    def device_id(self):
        """Return a random device id made of 10 upper-case ASCII letters."""
        return self._random_string(ascii_uppercase, 10)

    def olm_key_pair(self):
        """Return the ``identity_keys`` of a freshly created ``OlmAccount``."""
        return OlmAccount().identity_keys

    def olm_device(self):
        """Return an ``OlmDevice`` with a random user id, device id and keys."""
        # Same generation order as before: user id, device id, then keys.
        return OlmDevice(
            self.mx_id(),
            self.device_id(),
            self.olm_key_pair(),
        )

    def ed25519_key(self):
        """Return an ``Ed25519Key`` for a random user/device.

        The key value is the ``"ed25519"`` entry of a fresh key pair.
        """
        return Ed25519Key(
            self.mx_id(), self.device_id(), self.olm_key_pair()["ed25519"]
        )

    def _random_string(self, alphabet, length):
        """Return ``length`` characters drawn from ``alphabet``.

        Uses the generator's RNG so the result honours Faker seeding.
        """
        # Underscore-prefixed, so Faker does not register it as a formatter.
        rng = self.generator.random
        return "".join(rng.choice(alphabet) for _ in range(length))
# Expose Provider's generators (mx_id, device_id, ...) on the shared instance.
faker.add_provider(Provider)

# Encryption test-data directory, resolved against the current working dir.
_ENCRYPTION_DATA_SUBDIR = "tests/data/encryption"
ephemeral_dir = os.path.join(os.curdir, _ENCRYPTION_DATA_SUBDIR)
def ephemeral(func):
    """Decorate ``func`` so the ephemeral store file is deleted afterwards.

    Whether ``func`` returns or raises, the file
    ``@ephemeral:example.org_DEVICEID.db`` inside ``ephemeral_dir`` is
    removed. A missing file is ignored. Otherwise, an exception raised by
    ``func`` before it created the store would be hidden behind a
    ``FileNotFoundError`` from the cleanup.

    Args:
        func: The callable (typically a test) to wrap.

    Returns:
        A wrapper that forwards all arguments to ``func`` and returns its
        result. It carries ``func``'s metadata via :func:`functools.wraps`.
    """
    import contextlib
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        finally:
            # Best-effort cleanup: the store may never have been created.
            with contextlib.suppress(FileNotFoundError):
                os.remove(
                    os.path.join(ephemeral_dir, "@ephemeral:example.org_DEVICEID.db")
                )

    return wrapper
class FrameFactory:
    """
    A class containing lots of helper methods and state to build frames. This
    allows test cases to easily build correct HTTP/2 frames to feed to
    hyper-h2.

    Builders that accept ``flags`` take any iterable of flag names; the
    default ``None`` means "no extra flags". ``None`` replaces the former
    shared mutable ``[]`` default.
    """

    def __init__(self):
        self.encoder = Encoder()

    def refresh_encoder(self):
        """Replace the HPACK encoder with a brand-new ``Encoder`` instance.

        Header blocks encoded afterwards do not reuse the previous encoder.
        """
        self.encoder = Encoder()

    def preamble(self):
        """Return the HTTP/2 client connection preface bytes."""
        return b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"

    def build_headers_frame(self, headers, flags=None, stream_id=1, **priority_kwargs):
        """
        Builds a single valid headers frame out of the contained headers.

        ``END_HEADERS`` is always set and each name in ``flags`` is added on
        top of it. Every ``priority_kwargs`` item is set as an attribute on
        the frame.
        """
        f = HeadersFrame(stream_id)
        f.data = self.encoder.encode(headers)
        f.flags.add("END_HEADERS")
        if flags is not None:
            for flag in flags:
                f.flags.add(flag)
        for k, v in priority_kwargs.items():
            setattr(f, k, v)
        return f

    def build_continuation_frame(self, header_block, flags=None, stream_id=1):
        """
        Builds a single continuation frame out of the binary header block.

        The frame's flags are replaced by exactly ``flags`` (empty if None).
        """
        f = ContinuationFrame(stream_id)
        f.data = header_block
        f.flags = set(flags) if flags is not None else set()
        return f

    def build_data_frame(self, data, flags=None, stream_id=1, padding_len=0):
        """
        Builds a single data frame out of a chunk of data.

        A non-zero ``padding_len`` adds the ``PADDED`` flag and sets the
        frame's ``pad_length``.
        """
        flags = set(flags) if flags is not None else set()
        f = DataFrame(stream_id)
        f.data = data
        # ``f.flags`` and ``flags`` are the same set object, so adding
        # ``PADDED`` below updates the frame too.
        f.flags = flags
        if padding_len:
            flags.add("PADDED")
            f.pad_length = padding_len
        return f

    def build_settings_frame(self, settings, ack=False):
        """
        Builds a single settings frame on stream 0.

        ``ack=True`` sets the ``ACK`` flag.
        """
        f = SettingsFrame(0)
        if ack:
            f.flags.add("ACK")
        f.settings = settings
        return f

    def build_window_update_frame(self, stream_id, increment):
        """
        Builds a single WindowUpdate frame.
        """
        f = WindowUpdateFrame(stream_id)
        f.window_increment = increment
        return f

    def build_ping_frame(self, ping_data, flags=None):
        """
        Builds a single Ping frame on stream 0.

        The frame's flags are replaced only when ``flags`` is non-empty.
        """
        f = PingFrame(0)
        f.opaque_data = ping_data
        if flags:
            f.flags = set(flags)
        return f

    def build_goaway_frame(self, last_stream_id, error_code=0, additional_data=b""):
        """
        Builds a single GOAWAY frame on stream 0.
        """
        f = GoAwayFrame(0)
        f.error_code = error_code
        f.last_stream_id = last_stream_id
        f.additional_data = additional_data
        return f

    def build_rst_stream_frame(self, stream_id, error_code=0):
        """
        Builds a single RST_STREAM frame.
        """
        f = RstStreamFrame(stream_id)
        f.error_code = error_code
        return f

    def build_push_promise_frame(
        self, stream_id, promised_stream_id, headers, flags=None
    ):
        """
        Builds a single PUSH_PROMISE frame.

        The frame's flags are set to ``flags`` (empty if None) plus
        ``END_HEADERS``.
        """
        f = PushPromiseFrame(stream_id)
        f.promised_stream_id = promised_stream_id
        f.data = self.encoder.encode(headers)
        f.flags = set(flags) if flags is not None else set()
        f.flags.add("END_HEADERS")
        return f

    def build_priority_frame(self, stream_id, weight, depends_on=0, exclusive=False):
        """
        Builds a single priority frame.
        """
        f = PriorityFrame(stream_id)
        f.depends_on = depends_on
        f.stream_weight = weight
        f.exclusive = exclusive
        return f

    def build_alt_svc_frame(self, stream_id, origin, field):
        """
        Builds a single ALTSVC frame.
        """
        f = AltSvcFrame(stream_id)
        f.origin = origin
        f.field = field
        return f
|