File: simpleobsws.py

package info (click to toggle)
python3-simpleobsws 1.4.3-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 112 kB
  • sloc: python: 489; makefile: 3
file content (336 lines) | stat: -rw-r--r-- 14,383 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
import logging
wsLogger = logging.getLogger('websockets')
wsLogger.setLevel(logging.INFO)
log = logging.getLogger(__name__)
import asyncio
import websockets
import base64
import hashlib
import json
import msgpack
import uuid
import time
import inspect
import enum
from dataclasses import dataclass, field
from inspect import signature

# obs-websocket RPC version announced as 'rpcVersion' in the Identify (op 1) message.
RPC_VERSION = 1

class RequestBatchExecutionType(enum.Enum):
    """Execution mode for the requests of a request batch.

    A member's ``value`` is sent as ``executionType`` in RequestBatch (op 8)
    messages. The exact execution semantics are defined by the obs-websocket
    protocol (presumably serial in real time, serial one request per frame,
    and parallel respectively -- confirm against the protocol docs).
    """
    SerialRealtime = 0
    SerialFrame = 1
    Parallel = 2

@dataclass
class IdentificationParameters:
    """Optional fields for the Identify (op 1) message.

    A field left as None is omitted from the Identify message entirely.
    """
    # Sent as 'ignoreNonFatalRequestChecks' when not None.
    ignoreNonFatalRequestChecks: bool = None
    # Sent as 'eventSubscriptions' when not None; presumably an obs-websocket
    # EventSubscription bitmask -- confirm against the protocol docs.
    eventSubscriptions: int = None

@dataclass
class Request:
    """A request to send to obs-websocket, alone or as part of a batch."""
    # Request type name sent as 'requestType'.
    requestType: str
    # Optional request parameters, sent as 'requestData' when provided.
    requestData: dict = None
    # Request batch only: sent as 'inputVariables' when non-empty.
    inputVariables: dict = None
    # Request batch only: sent as 'outputVariables' when non-empty.
    outputVariables: dict = None

@dataclass
class RequestStatus:
    """Status block of a request response (the 'requestStatus' object)."""
    # Whether the request succeeded.
    result: bool = False
    # Numeric status code reported by the server.
    code: int = 0
    # Optional human-readable comment; None when the server sent none.
    comment: str = None

@dataclass
class RequestResponse:
    """Result of a single request as returned to the caller.

    Attributes:
        requestType: Request type echoed back in the response.
        requestStatus: Success flag, status code and optional comment.
        responseData: The response's data dict, or None when the response
            carried no data.
    """
    requestType: str = ''
    requestStatus: RequestStatus = field(default_factory=RequestStatus)
    responseData: dict = None

    def has_data(self):
        """Return True when the response carried a data dict (even an empty one)."""
        # Identity check per PEP 8; `!= None` would defer to responseData.__ne__.
        return self.responseData is not None

    def ok(self):
        """Return the request's success flag (``requestStatus.result``)."""
        return self.requestStatus.result

@dataclass
class _ResponseWaiter:
    event: asyncio.Event = field(default_factory=asyncio.Event)
    response_data: dict = None

class MessageTimeout(Exception):
    """Raised when a request or request batch gets no response within its timeout."""
class EventRegistrationError(Exception):
    """Raised when an event callback cannot be registered (e.g. it is not async)."""
class NotIdentifiedError(Exception):
    """Raised when a request is attempted before the session is identified."""

async def _wait_for_cond(cond, func):
    async with cond:
        await cond.wait_for(func)

class WebSocketClient:
    """Asyncio client for obs-websocket using the 'obswebsocket.msgpack' subprotocol.

    Typical use: ``await connect()``, ``await wait_until_identified()``, then
    ``call``/``emit`` (single requests) or ``call_batch``/``emit_batch``
    (request batches), and finally ``await disconnect()``. After ``connect()``
    a background task reads incoming messages, answers Hello with Identify,
    resolves pending calls and schedules registered event callbacks.
    """

    def __init__(self,
        url: str = "ws://localhost:4444",
        password: str = '',
        identification_parameters: IdentificationParameters = None
    ):
        """Store connection settings; no network I/O happens here.

        Args:
            url: WebSocket URL of the obs-websocket server.
            password: Used to build the Identify authentication string when
                the server's Hello message requests authentication.
            identification_parameters: Optional Identify fields. None (the
                default) creates a fresh ``IdentificationParameters()`` for
                this client, so clients never share one mutable instance.
        """
        self.url = url
        self.password = password
        if identification_parameters is None:
            identification_parameters = IdentificationParameters()
        self.identification_parameters = identification_parameters

        self.http_headers = {}  # extra headers passed to the WebSocket handshake
        self.ws = None
        self.ws_open = False
        self.waiters = {}  # requestId -> _ResponseWaiter for in-flight calls
        self.identified = False
        self.recv_task = None
        self.hello_message = None
        self.event_callbacks = []  # list of (callback, event type or None)
        self.cond = asyncio.Condition()  # notified when Identified (op 2) arrives
        # Strong references to running event-callback tasks: the event loop
        # only keeps weak references, so unreferenced tasks may vanish mid-run.
        self._background_tasks = set()

    # TODO: remove bool return, raise error if already open
    async def connect(self):
        """Open the WebSocket connection and start the receive task.

        Returns:
            True if a new connection was opened, False if one was already open.
        """
        if self.ws and self.ws_open:
            log.debug('WebSocket session is already open. Returning early.')
            return False
        self.answers = {}  # legacy attribute; reset here but never read by this class
        self.recv_task = None
        self.identified = False
        self.hello_message = None
        self.ws = await websockets.connect(self.url, subprotocols = ['obswebsocket.msgpack'], additional_headers = self.http_headers, max_size=2**24)
        self.ws_open = True
        self.recv_task = asyncio.create_task(self._ws_recv_task())
        return True

    async def wait_until_identified(self, timeout: int = 10):
        """Wait until the server has confirmed identification.

        Returns:
            True once identified; False if the connection is not open or
            ``timeout`` seconds elapse first.
        """
        if not self.ws_open:
            log.debug('WebSocket session is not open. Returning early.')
            return False
        try:
            await asyncio.wait_for(_wait_for_cond(self.cond, self.is_identified), timeout=timeout)
            return True
        except asyncio.TimeoutError:
            return False

    # TODO: remove bool return, raise error if already closed
    async def disconnect(self):
        """Stop the receive task, close the connection and reset session state.

        Returns:
            True if a session was torn down, False if there was none.
        """
        if self.recv_task is None:
            log.debug('WebSocket session is not open. Returning early.')
            return False
        self.recv_task.cancel()
        await self.ws.close()
        self.ws = None
        self.ws_open = False
        self.answers = {}
        self.identified = False
        self.recv_task = None
        self.hello_message = None
        return True

    async def call(self, request: Request, timeout: int = 15):
        """Send one request and wait for its response.

        Returns:
            A ``RequestResponse`` built from the server's response.

        Raises:
            NotIdentifiedError: if the session is not identified yet.
            MessageTimeout: if no response arrives within ``timeout`` seconds.
        """
        if not self.identified:
            raise NotIdentifiedError('Calls to requests cannot be made without being identified with obs-websocket.')
        request_id = str(uuid.uuid1())
        request_payload = self._build_request_payload(request, request_id)
        self._log_payload('Sending Request message:\n{}', request_payload)
        response_data = await self._send_and_wait(
            request_id, request_payload, timeout,
            'The request with type {} timed out after {} seconds.'.format(request.requestType, timeout))
        return self._build_request_response(response_data)

    async def emit(self, request: Request):
        """Send one request without waiting for (or keeping) its response.

        The 'emit_' requestId prefix makes the receive loop drop the response.

        Raises:
            NotIdentifiedError: if the session is not identified yet.
        """
        if not self.identified:
            raise NotIdentifiedError('Emits to requests cannot be made without being identified with obs-websocket.')
        request_payload = self._build_request_payload(request, 'emit_{}'.format(uuid.uuid1()))
        self._log_payload('Sending Request message:\n{}', request_payload)
        await self.ws.send(msgpack.packb(request_payload))

    async def call_batch(self, requests: list, timeout: int = 15, halt_on_failure: bool = None, execution_type: RequestBatchExecutionType = None, variables: dict = None):
        """Send a request batch and wait for all of its results.

        Args:
            requests: ``Request`` objects, in execution order.
            timeout: Seconds to wait for the batch response.
            halt_on_failure: Sent as 'haltOnFailure' when not None.
            execution_type: Sent as 'executionType' (its ``value``) when given.
            variables: Sent as 'variables' when non-empty.

        Returns:
            A list of ``RequestResponse``, one per entry of the response's 'results'.

        Raises:
            NotIdentifiedError: if the session is not identified yet.
            MessageTimeout: if no response arrives within ``timeout`` seconds.
        """
        if not self.identified:
            raise NotIdentifiedError('Calls to requests cannot be made without being identified with obs-websocket.')
        request_batch_id = str(uuid.uuid1())
        request_batch_payload = self._build_request_batch_payload(
            request_batch_id, requests, halt_on_failure, execution_type, variables)
        self._log_payload('Sending Request batch message:\n{}', request_batch_payload)
        response_data = await self._send_and_wait(
            request_batch_id, request_batch_payload, timeout,
            'The request batch timed out after {} seconds.'.format(timeout))
        return [self._build_request_response(result) for result in response_data['results']]

    async def emit_batch(self, requests: list, halt_on_failure: bool = None, execution_type: RequestBatchExecutionType = None, variables: dict = None):
        """Send a request batch without waiting for its results.

        Builds the same payload as ``call_batch`` (including each request's
        input/output variables) under an 'emit_'-prefixed requestId.

        Raises:
            NotIdentifiedError: if the session is not identified yet.
        """
        if not self.identified:
            raise NotIdentifiedError('Emits to requests cannot be made without being identified with obs-websocket.')
        request_batch_payload = self._build_request_batch_payload(
            'emit_{}'.format(uuid.uuid1()), requests, halt_on_failure, execution_type, variables)
        self._log_payload('Sending Request batch message:\n{}', request_batch_payload)
        await self.ws.send(msgpack.packb(request_batch_payload))

    def register_event_callback(self, callback, event: str = None):
        """Register an async ``callback`` for incoming events.

        Args:
            callback: Coroutine function. With ``event`` set it receives the
                event's eventData; otherwise its arguments depend on its arity
                (1: full payload; 2: eventType, eventData; 3: eventType,
                eventIntent, eventData).
            event: Event type to filter on, or None for every event.

        Raises:
            EventRegistrationError: if ``callback`` is not a coroutine function.
        """
        if not inspect.iscoroutinefunction(callback):
            raise EventRegistrationError('Registered functions must be async')
        # Rebind to a new list instead of mutating in place, so any iteration
        # over the previous list is unaffected.
        self.event_callbacks = self.event_callbacks + [(callback, event)]

    def deregister_event_callback(self, callback, event: str = None):
        """Remove ``callback`` registrations.

        With ``event`` None, every registration of ``callback`` is removed;
        otherwise only those registered for that event type.
        """
        # Build a new list (rather than removing in place) so any iteration
        # over the previous list is unaffected.
        self.event_callbacks = [
            (c, t) for c, t in self.event_callbacks
            if not (c == callback and (event is None or t == event))
        ]

    def is_identified(self):
        """Return whether the Identified (op 2) message has been received."""
        return self.identified

    def _get_hello_data(self):
        """Return the stored Hello (op 0) payload, or None if none was received."""
        return self.hello_message

    def _build_request_response(self, response: dict):
        """Convert one response payload dict into a ``RequestResponse``."""
        status = response['requestStatus']
        ret = RequestResponse(response['requestType'], responseData = response.get('responseData'))
        ret.requestStatus.result = status['result']
        ret.requestStatus.code = status['code']
        ret.requestStatus.comment = status.get('comment')
        return ret

    @staticmethod
    def _build_request_payload(request, request_id):
        """Build a Request (op 6) message; 'requestData' only when not None."""
        payload = {
            'op': 6,
            'd': {
                'requestType': request.requestType,
                'requestId': request_id
            }
        }
        if request.requestData is not None:
            payload['d']['requestData'] = request.requestData
        return payload

    @staticmethod
    def _build_request_batch_payload(request_id, requests, halt_on_failure, execution_type, variables):
        """Build a RequestBatch (op 8) message shared by call_batch and emit_batch.

        'haltOnFailure' is included when not None; 'executionType', 'variables'
        and the per-request optional fields only when truthy.
        """
        batch_data = {
            'requestId': request_id,
            'requests': []
        }
        if halt_on_failure is not None:
            batch_data['haltOnFailure'] = halt_on_failure
        if execution_type:
            batch_data['executionType'] = execution_type.value
        if variables:
            batch_data['variables'] = variables
        for request in requests:
            request_payload = {
                'requestType': request.requestType
            }
            if request.inputVariables:
                request_payload['inputVariables'] = request.inputVariables
            if request.outputVariables:
                request_payload['outputVariables'] = request.outputVariables
            if request.requestData:
                request_payload['requestData'] = request.requestData
            batch_data['requests'].append(request_payload)
        return {'op': 8, 'd': batch_data}

    @staticmethod
    def _log_payload(message, payload):
        """Log ``payload`` as indented JSON at DEBUG level.

        The dump is skipped when DEBUG is disabled, so it costs nothing per
        message. ``default=repr`` keeps non-JSON values (e.g. bytes) from
        raising inside the logging call.
        """
        if log.isEnabledFor(logging.DEBUG):
            log.debug(message.format(json.dumps(payload, indent=2, default=repr)))

    async def _send_and_wait(self, request_id, payload, timeout, timeout_message):
        """Send ``payload`` and wait for the response routed to ``request_id``.

        Returns:
            The response's 'd' payload as stored by the receive loop.

        Raises:
            MessageTimeout: with ``timeout_message`` if ``timeout`` elapses.
        """
        waiter = _ResponseWaiter()
        self.waiters[request_id] = waiter
        try:
            await self.ws.send(msgpack.packb(payload))
            await asyncio.wait_for(waiter.event.wait(), timeout=timeout)
        except asyncio.TimeoutError:
            raise MessageTimeout(timeout_message) from None
        finally:
            # pop() so cleanup can never mask the exception being propagated.
            self.waiters.pop(request_id, None)
        return waiter.response_data

    async def _send_identify(self, password, identification_parameters):
        """Answer the stored Hello with an Identify (op 1) message.

        Does nothing if no Hello has been received. If the Hello contains an
        'authentication' object, the authentication string is
        base64(sha256(base64(sha256(password + salt)) + challenge)).
        """
        if self.hello_message is None:
            return
        identify_message = {'op': 1, 'd': {}}
        identify_message['d']['rpcVersion'] = RPC_VERSION
        if 'authentication' in self.hello_message:
            auth = self.hello_message['authentication']
            secret = base64.b64encode(hashlib.sha256((password + auth['salt']).encode('utf-8')).digest())
            authentication_string = base64.b64encode(hashlib.sha256(secret + auth['challenge'].encode('utf-8')).digest()).decode('utf-8')
            identify_message['d']['authentication'] = authentication_string
        if identification_parameters.ignoreNonFatalRequestChecks is not None:
            identify_message['d']['ignoreNonFatalRequestChecks'] = identification_parameters.ignoreNonFatalRequestChecks
        if identification_parameters.eventSubscriptions is not None:
            identify_message['d']['eventSubscriptions'] = identification_parameters.eventSubscriptions
        self._log_payload('Sending Identify message:\n{}', identify_message)
        await self.ws.send(msgpack.packb(identify_message))

    def _spawn(self, coro):
        """Run ``coro`` as a task, holding a strong reference until it finishes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    def _dispatch_event(self, data_payload):
        """Schedule every registered callback that matches an Event (op 5) payload.

        Catch-all callbacks (event None) get arguments by arity: 1 -> payload;
        2 -> (eventType, eventData); 3 -> (eventType, eventIntent, eventData);
        other arities are skipped. Filtered callbacks get only eventData.
        """
        for callback, trigger in self.event_callbacks:
            if trigger is None:
                params = len(signature(callback).parameters)
                if params == 1:
                    coro = callback(data_payload)
                elif params == 2:
                    coro = callback(data_payload['eventType'], data_payload.get('eventData'))
                elif params == 3:
                    coro = callback(data_payload['eventType'], data_payload.get('eventIntent'), data_payload.get('eventData'))
                else:
                    continue
            elif trigger == data_payload['eventType']:
                coro = callback(data_payload.get('eventData'))
            else:
                continue
            self._spawn(coro)

    async def _handle_message(self, op_code, data_payload):
        """Route one decoded message by its op code."""
        if op_code == 7 or op_code == 9: # RequestResponse or RequestBatchResponse
            request_id = data_payload['requestId']
            if request_id.startswith('emit_'):
                return  # fire-and-forget request: nobody is waiting for it
            waiter = self.waiters.get(request_id)
            if waiter is None:
                log.warning('Discarding request response {} because there is no waiter for it.'.format(request_id))
                return
            waiter.response_data = data_payload
            waiter.event.set()
        elif op_code == 5: # Event
            self._dispatch_event(data_payload)
        elif op_code == 0: # Hello
            self.hello_message = data_payload
            await self._send_identify(self.password, self.identification_parameters)
        elif op_code == 2: # Identified
            self.identified = True
            async with self.cond:
                self.cond.notify_all()
        else:
            log.warning('Unknown OpCode: {}'.format(op_code))

    async def _ws_recv_task(self):
        """Receive and handle messages while the connection is open.

        Non-binary or undecodable frames are skipped. A message that fails
        handling with KeyError/TypeError (e.g. missing fields) is logged and
        skipped instead of silently killing this task. On a normal exit,
        ``ws_open`` and ``identified`` are reset; cancellation (from
        ``disconnect()``) propagates without reaching that reset.
        """
        while self.ws_open:
            try:
                message = await self.ws.recv()
                # Only non-empty binary frames carry msgpack payloads.
                if not message or not isinstance(message, bytes):
                    continue
                incoming_payload = msgpack.unpackb(message)
                self._log_payload('Received message:\n{}', incoming_payload)
                await self._handle_message(incoming_payload['op'], incoming_payload['d'])
            except (websockets.exceptions.ConnectionClosed, websockets.exceptions.ConnectionClosedError, websockets.exceptions.ConnectionClosedOK):
                log.debug('The WebSocket connection was closed. Code: {} | Reason: {}'.format(self.ws.close_code, self.ws.close_reason))
                self.ws_open = False
                break
            except (ValueError, msgpack.UnpackException):
                continue
            except (KeyError, TypeError) as e:
                log.warning('Discarding incoming message that could not be handled: {!r}'.format(e))
                continue
        self.ws_open = False
        self.identified = False