1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
|
import dataclasses
from websockets.exceptions import NegotiationError
class OpExtension:
    """Test extension parameterized by an arbitrary ``op`` value.

    ``decode`` and ``encode`` pass frames through unchanged. Two instances
    compare equal when their ``op`` values compare equal.
    """

    name = "x-op"

    def __init__(self, op=None):
        self.op = op

    def decode(self, frame, *, max_size=None):
        # Pass-through; ``max_size`` is accepted but not used here.
        return frame  # pragma: no cover

    def encode(self, frame):
        return frame  # pragma: no cover

    def __eq__(self, other):
        # Let the other operand decide for foreign types (e.g. ``mock.ANY``)
        # instead of answering False unconditionally.
        if not isinstance(other, OpExtension):
            return NotImplemented
        return self.op == other.op

    def __hash__(self):
        # Defining __eq__ implicitly sets __hash__ to None; hash on the same
        # field equality uses so equal instances hash equally.
        return hash(self.op)

    def __repr__(self):
        # Readable output when assertions comparing extension lists fail.
        return f"OpExtension(op={self.op!r})"
class ClientOpExtensionFactory:
    """Client-side factory for the ``x-op`` test extension.

    Offers a single ``("op", self.op)`` parameter and accepts a response only
    if it echoes exactly that parameter list.
    """

    name = "x-op"

    def __init__(self, op=None):
        self.op = op

    def get_request_params(self):
        """Return the parameter list offered to the server."""
        offered = ("op", self.op)
        return [offered]

    def process_response_params(self, params, accepted_extensions):
        """Return an OpExtension when the server echoed the offered params.

        Raises NegotiationError for any other ``params``;
        ``accepted_extensions`` is not consulted.
        """
        op = self.op
        expected = [("op", op)]
        if params != expected:
            raise NegotiationError()
        return OpExtension(op)
class ServerOpExtensionFactory:
    """Server-side factory for the ``x-op`` test extension.

    Accepts only a client offer equal to ``[("op", self.op)]`` and echoes
    that same parameter list back in the response.
    """

    name = "x-op"

    def __init__(self, op=None):
        self.op = op

    def process_request_params(self, params, accepted_extensions):
        """Return ``(response_params, OpExtension)`` for a matching offer.

        Raises NegotiationError when ``params`` differs from the single
        expected ``op`` parameter; ``accepted_extensions`` is ignored.
        """
        op = self.op
        agreed = [("op", op)]
        if params != agreed:
            raise NegotiationError()
        return agreed, OpExtension(op)
class NoOpExtension:
    """Test extension named ``x-no-op`` that leaves every frame untouched."""

    name = "x-no-op"

    def encode(self, frame):
        """Return ``frame`` unchanged."""
        return frame

    def decode(self, frame, *, max_size=None):
        """Return ``frame`` unchanged; ``max_size`` is accepted but unused."""
        return frame

    def __repr__(self):
        return "NoOpExtension()"
class ClientNoOpExtensionFactory:
    """Client-side factory for ``x-no-op``: offers no parameters, expects none back."""

    name = "x-no-op"

    def get_request_params(self):
        """Return an empty parameter list for the offer."""
        return list()

    def process_response_params(self, params, accepted_extensions):
        """Return a NoOpExtension if the response carried no parameters.

        Raises NegotiationError when ``params`` is non-empty;
        ``accepted_extensions`` is ignored.
        """
        if not params:
            return NoOpExtension()
        raise NegotiationError()
class ServerNoOpExtensionFactory:
    """Server-side factory for ``x-no-op`` that accepts any client offer.

    The response parameters are fixed when the factory is constructed.
    """

    name = "x-no-op"

    def __init__(self, params=None):
        # Any falsy argument (None, empty list/tuple) becomes a new empty list.
        self.params = params if params else []

    def process_request_params(self, params, accepted_extensions):
        """Accept unconditionally.

        The client's ``params`` and ``accepted_extensions`` are ignored. The
        response is the stored ``self.params`` object itself (not a copy),
        paired with a new NoOpExtension.
        """
        extension = NoOpExtension()
        return (self.params, extension)
class Rsv2Extension:
    """Test extension that sets RSV2 on outgoing frames and clears it on incoming ones.

    Frames are copied with ``dataclasses.replace``, so ``frame`` must be a
    dataclass instance with an ``rsv2`` field. The ``assert`` statements
    raise AssertionError on a frame with an unexpected RSV2 state; note that
    they are stripped when Python runs with ``-O``.
    """

    name = "x-rsv2"

    def decode(self, frame, *, max_size=None):
        """Require RSV2 set on ``frame`` and return a copy with it cleared.

        ``max_size`` is accepted but unused here.
        """
        assert frame.rsv2
        return dataclasses.replace(frame, rsv2=False)

    def encode(self, frame):
        """Require RSV2 clear on ``frame`` and return a copy with it set."""
        assert not frame.rsv2
        return dataclasses.replace(frame, rsv2=True)

    def __eq__(self, other):
        # The extension holds no state, so all instances are equal. For
        # foreign types, defer to the other operand instead of returning False.
        if not isinstance(other, Rsv2Extension):
            return NotImplemented
        return True

    def __hash__(self):
        # Defining __eq__ implicitly sets __hash__ to None. All instances
        # compare equal, so they must share a single constant hash.
        return hash(Rsv2Extension)

    def __repr__(self):
        # Readable output when assertions comparing extension lists fail.
        return "Rsv2Extension()"
class ClientRsv2ExtensionFactory:
    """Client-side factory for ``x-rsv2``: offers no parameters and accepts any response."""

    name = "x-rsv2"

    def get_request_params(self):
        """Return an empty parameter list for the offer."""
        return list()

    def process_response_params(self, params, accepted_extensions):
        # The response is not validated: whatever params arrive, a new
        # extension instance is returned.
        extension = Rsv2Extension()
        return extension
class ServerRsv2ExtensionFactory:
    """Server-side factory for ``x-rsv2``: accepts any offer and responds without parameters."""

    name = "x-rsv2"

    def process_request_params(self, params, accepted_extensions):
        # Both the client's params and the already-accepted extensions are
        # ignored; the response always carries an empty parameter list.
        response_params = list()
        return response_params, Rsv2Extension()
|