File: network_utils.py

package info (click to toggle)
visp 3.6.0-5
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 119,296 kB
  • sloc: cpp: 500,914; ansic: 52,904; xml: 22,642; python: 7,365; java: 4,247; sh: 482; makefile: 237; objc: 145
file content (119 lines) | stat: -rw-r--r-- 4,101 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import io
import struct
from typing import Callable, List, Tuple
import numpy as np
import torch
from math import prod
import socket

from megapose_server.server_operations import ServerMessage, SERVER_OPERATION_CODE_LENGTH

def read_string(buffer: io.BytesIO):
    '''
    Read a length-prefixed ASCII string from a buffer.

    The buffer is expected to hold a big-endian unsigned int giving the number
    of characters, followed by that many ASCII bytes. Both the length prefix
    and the characters are consumed from the buffer.

    returns the decoded string

    raises struct.error if the buffer ends before the length prefix or the
    announced number of characters could be read, and UnicodeDecodeError if
    the characters are not valid ASCII
    '''
    # Size the prefix with the same standard-size format that decodes it:
    # a bare 'I' is measured with the platform's native size, which is not
    # guaranteed to match the 4 bytes that '>I' requires.
    length_format = '>I'
    str_count = struct.unpack(length_format, buffer.read(struct.calcsize(length_format)))[0]
    data = buffer.read(str_count)
    if len(data) != str_count:
        # Keep struct.error (what the former unpack call raised on a short
        # read), but say what is actually missing.
        raise struct.error(f'expected {str_count} string bytes, got {len(data)}')
    return data.decode('ascii')
def pack_string(s: str, buffer: bytearray):
    '''
    Append a length-prefixed ASCII string to a buffer.

    The string is written as a big-endian unsigned int holding its length,
    immediately followed by its ASCII-encoded characters.
    '''
    characters = s.encode('ascii')
    length_prefix = struct.pack('>I', len(s))
    buffer.extend(length_prefix + characters)

def read_image(buffer: io.BytesIO):
    '''
    Read a uint8 image from a buffer.

    The buffer must hold 3 big-endian unsigned ints (height, width, channels)
    followed by height * width * channels bytes of pixel data.
    Everything that is read is consumed from the buffer.

    When the image has 4 channels (RGBA), the alpha channel is dropped.

    returns the image as an np.array
    '''
    header_format = '>3I'
    height, width, channels = struct.unpack(header_format, buffer.read(struct.calcsize(header_format)))
    shape = (height, width, channels)
    pixel_count = prod(shape)
    raw_pixels = buffer.read(pixel_count)
    pixels = np.frombuffer(raw_pixels, dtype=np.uint8, count=pixel_count)
    image = pixels.reshape(shape)
    # A fourth channel means RGBA: keep only the first three channels.
    return image[:, :, :3] if channels == 4 else image

def read_uint16_image(buffer: io.BytesIO):
    '''
    Read an uint16 image from a buffer
    First, 2 big-endian unsigned ints are read (height, width), then one
    character giving the endianness of the pixel data ('>' or '<'),
    and then height * width * 2 bytes.

    The elements are consumed from the buffer.

    returns the image as an np.array of shape (height, width)

    raises ValueError if the endianness character is neither '>' nor '<'
    '''
    image_shape = struct.unpack('>2I', buffer.read(struct.calcsize('>2I')))
    endianness = struct.unpack('c', buffer.read(struct.calcsize('c')))[0]
    # The marker comes from the wire, so validate it with a real exception:
    # an assert would be stripped when Python runs with -O, letting other
    # byte-order codes reach newbyteorder unchecked.
    if endianness not in (b'>', b'<'):
        raise ValueError(f'Invalid endianness marker {endianness!r}, expected ">" or "<"')
    elem_count = prod(image_shape)
    img_bytes = buffer.read(elem_count * 2)
    dt = np.dtype(np.uint16).newbyteorder(endianness.decode('ascii'))
    img = np.frombuffer(img_bytes, dtype=dt, count=elem_count).reshape(image_shape)
    return img

def pack_image(image, buffer: bytearray):
    '''
    Append an image to a buffer.

    The image is converted to uint8 and must have exactly three dimensions.
    Its shape is written as 3 big-endian unsigned ints, followed by the
    pixel values laid out in C order.
    '''
    pixels = image.astype(np.uint8)
    assert len(pixels.shape) == 3
    shape_header = struct.pack('>3I', *pixels.shape)
    buffer.extend(shape_header)
    buffer.extend(pixels.tobytes(order='C'))

def create_message(message_code: ServerMessage, fn: Callable[[bytearray], None]):
    '''
    Build a message ready to be sent on the network.

    The message layout is

    MSG_LENGTH | MSG_CODE | DATA

    MSG_LENGTH is a big-endian unsigned int holding the size of DATA in bytes.
    MSG_CODE is the UTF-8 encoded value of message_code, stored in a field of
    SERVER_OPERATION_CODE_LENGTH bytes.
    DATA is whatever fn appends to the bytearray it receives; that bytearray
    already contains the header when fn is called.

    returns the complete message as a bytearray
    '''
    length_format = '>I'

    # The payload size is unknown until fn has run, so reserve the length
    # field with a zero placeholder and patch it afterwards.
    message = bytearray(struct.pack(length_format, 0))
    encoded_code = message_code.value.encode('UTF-8')
    message.extend(struct.pack(f'{SERVER_OPERATION_CODE_LENGTH}s', encoded_code))
    header_length = struct.calcsize(f'>I{SERVER_OPERATION_CODE_LENGTH}s')

    fn(message)

    payload_length = len(message) - header_length
    message[0:struct.calcsize(length_format)] = struct.pack(length_format, payload_length)
    return message

def _recv_exact(s: socket.socket, count: int):
    '''
    Read exactly count bytes from a socket.

    recv may return fewer bytes than requested, so it is called repeatedly
    until count bytes have been accumulated.

    returns the bytes as a bytearray, or None if recv reported a closed
    connection (empty result) before count bytes were received
    '''
    received = bytearray()
    while len(received) < count:
        packet = s.recv(count - len(received))
        if not packet:
            return None
        received.extend(packet)
    return received


def receive_message(s: socket.socket) -> Tuple[str, io.BytesIO]:
    '''
    Read a socket message
    A message has the shape

    MSG_LENGTH | MSG_CODE | DATA

    MSG_LENGTH is a big-endian unsigned int giving the size of DATA in bytes,
    MSG_CODE is a UTF-8 string of SERVER_OPERATION_CODE_LENGTH bytes.

    Every field is read in full, even when the socket delivers it in several
    pieces.

    returns the code as an str value associated to the ServerOperation enum, as well as DATA, as an io.BytesIO.
    returns None instead if the connection is closed before a complete message was received.
    '''
    header = _recv_exact(s, struct.calcsize('>I'))
    if header is None:
        return None
    length = struct.unpack('>I', header)[0]

    raw_code = _recv_exact(s, SERVER_OPERATION_CODE_LENGTH)
    if raw_code is None:
        return None
    code = raw_code.decode('UTF-8')

    # The payload is accumulated as it arrives rather than preallocated, so a
    # bogus MSG_LENGTH cannot trigger a huge allocation up front.
    data = _recv_exact(s, length)
    if data is None:
        return None

    return code, io.BytesIO(data)