1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
|
"""Go2rtc websocket messages."""
from __future__ import annotations

from dataclasses import dataclass, field, replace
from typing import Annotated, Any, ClassVar

from mashumaro import field_options
from mashumaro.config import BaseConfig
from mashumaro.mixins.orjson import DataClassORJSONMixin
from mashumaro.types import Discriminator
from webrtc_models import (
    RTCIceServer,  # noqa: TCH002 # Mashumaro needs the import to generate the correct code
)
@dataclass(frozen=True)
class WsMessage:
    """Root of every go2rtc websocket payload.

    Concrete subclasses assign a ``TYPE`` class attribute; it is written to
    the ``type`` key of the serialized mapping.
    """

    TYPE: ClassVar[str]

    def __post_serialize__(self, d: dict[Any, Any]) -> dict[Any, Any]:
        """Stamp the message ``type`` onto the serialized mapping.

        ``TYPE`` is a ``ClassVar`` rather than a dataclass field, so it is
        never emitted by default; it is injected here instead. The given
        mapping is updated in place and returned.
        """
        d.update({"type": self.TYPE})
        return d
@dataclass(frozen=True)
class BaseMessage(WsMessage, DataClassORJSONMixin):
    """Top-level websocket message with ORJSON (de)serialization support."""

    class Config(BaseConfig):
        """Mashumaro settings shared by all top-level messages."""

        # Write field aliases (e.g. ``value``) instead of attribute names.
        serialize_by_alias = True
        # Subclasses are tagged by their ``TYPE`` under the ``type`` key.
        discriminator = Discriminator(
            field="type",
            include_subtypes=True,
            variant_tagger_fn=lambda message_cls: message_cls.TYPE,
        )
@dataclass(frozen=True)
class WebRTCCandidate(BaseMessage):
    """Websocket message carrying a single WebRTC ICE candidate."""

    TYPE = "webrtc/candidate"
    # Named ``candidate`` in Python, but carried under the ``value`` key on the wire.
    candidate: str = field(metadata=field_options(alias="value"))
@dataclass(frozen=True)
class WebRTC(BaseMessage):
    """Envelope message whose ``value`` holds an SDP offer/answer payload."""

    TYPE = "webrtc"
    # The nested payload is tagged by its own ``type`` key; subclasses of the
    # union members are matched through their ``TYPE`` class attribute.
    value: Annotated[
        WebRTCOffer | WebRTCValue,
        Discriminator(
            field="type",
            include_subtypes=True,
            variant_tagger_fn=lambda variant_cls: variant_cls.TYPE,
        ),
    ]
@dataclass(frozen=True)
class WebRTCValue(WsMessage):
    """Shared payload of the WebRTC offer/answer messages: the ``sdp`` string."""

    sdp: str
@dataclass(frozen=True)
class WebRTCOffer(WebRTCValue):
    """WebRTC offer message.

    Serialized wrapped inside a ``WebRTC`` envelope (see ``to_json``).
    """

    TYPE = "offer"
    ice_servers: list[RTCIceServer]

    def __pre_serialize__(self) -> WebRTCOffer:
        """Normalize ICE server URLs before serialization.

        Go2rtc supports only ice_servers with urls as list of strings, so a
        server whose ``urls`` is a single string is serialized as a
        one-element list.

        The normalization is done on copies (``dataclasses.replace``) and a
        new offer is returned. The caller's ``RTCIceServer`` objects are
        therefore never modified by serialization, and this also works if
        ``RTCIceServer`` is a frozen dataclass. When no server needs
        normalizing, ``self`` is returned unchanged.
        """
        if not any(isinstance(server.urls, str) for server in self.ice_servers):
            return self
        return replace(
            self,
            ice_servers=[
                replace(server, urls=[server.urls])
                if isinstance(server.urls, str)
                else server
                for server in self.ice_servers
            ],
        )

    def to_json(self, **kwargs: Any) -> str:
        """Serialize this offer wrapped in a ``WebRTC`` message.

        Keyword arguments are forwarded to ``WebRTC.to_json``.
        """
        return WebRTC(self).to_json(**kwargs)
@dataclass(frozen=True)
class WebRTCAnswer(WebRTCValue):
    """WebRTC answer payload; carries only the inherited ``sdp`` field."""

    TYPE = "answer"
@dataclass(frozen=True)
class WsError(BaseMessage):
    """Websocket message reporting an error."""

    TYPE = "error"
    # Named ``error`` in Python, but carried under the ``value`` key on the wire.
    error: str = field(metadata=field_options(alias="value"))
# Union of the message types this module models for the receive direction.
ReceiveMessages = WebRTCAnswer | WebRTCCandidate | WsError
# Union of the message types this module models for the send direction;
# WebRTCOffer provides its own ``to_json`` that wraps it in a WebRTC envelope.
SendMessages = WebRTCCandidate | WebRTCOffer
|