1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
|
import io
import socket
import struct
from math import prod
from typing import Callable, List, Optional, Tuple

import numpy as np
import torch

from megapose_server.server_operations import ServerMessage, SERVER_OPERATION_CODE_LENGTH
def read_string(buffer: io.BytesIO):
    '''
    Read a length-prefixed ASCII string from a buffer.

    Wire layout: a big-endian uint32 byte count, followed by that many ASCII
    bytes. Both parts are consumed from the buffer.

    Raises:
        struct.error: if the buffer ends before the length prefix or the
            string payload is complete.
        UnicodeDecodeError: if the payload is not valid ASCII.
    '''
    length_format = '>I'
    # Size the prefix with the same explicit big-endian / standard-size format
    # used to unpack it; native-mode 'I' size and alignment are platform-dependent.
    str_count = struct.unpack(length_format, buffer.read(struct.calcsize(length_format)))[0]
    # Unpacking through '<n>s' (rather than decoding the raw read) makes a
    # truncated payload raise struct.error instead of returning a short string.
    data = struct.unpack(f'{str_count}s', buffer.read(str_count))[0]
    return data.decode('ascii')
def pack_string(s: str, buffer: bytearray):
    '''
    Append a length-prefixed ASCII string to a buffer.

    Layout written: a big-endian uint32 holding the byte count, then the ASCII
    bytes of ``s``. Raises UnicodeEncodeError (leaving ``buffer`` untouched) if
    ``s`` contains non-ASCII characters.
    '''
    # ASCII is one byte per character, so the encoded length equals len(s).
    encoded = s.encode('ascii')
    prefix = struct.pack('>I', len(encoded))
    buffer.extend(prefix + encoded)
def read_image(buffer: io.BytesIO):
    '''
    Read an 8-bit image from a buffer and return it as an np.array.

    Wire layout: three big-endian uint32 values giving the array shape
    (height, width, channels), followed by height * width * channels uint8
    bytes in row-major order. All of these bytes are consumed from the buffer.

    When the channel count is 4 the image is treated as RGBA and the alpha
    plane is dropped, so an (H, W, 3) view is returned. Any other channel
    count is returned unchanged. The array is built with np.frombuffer over
    the bytes read, so it is read-only.
    '''
    shape_format = '>3I'
    shape = struct.unpack(shape_format, buffer.read(struct.calcsize(shape_format)))
    n_bytes = prod(shape)
    pixels = np.frombuffer(buffer.read(n_bytes), dtype=np.uint8, count=n_bytes)
    pixels = pixels.reshape(shape)
    return pixels[:, :, :3] if shape[-1] == 4 else pixels
def read_uint16_image(buffer: io.BytesIO):
    '''
    Read a uint16 image (e.g. a depth map) from a buffer.

    Wire layout: two big-endian uint32 values (height, width), a single ASCII
    byte-order marker ('>' for big-endian or '<' for little-endian), then
    height * width uint16 values in that byte order, row-major.
    All of these bytes are consumed from the buffer.

    Returns:
        An (height, width) np.array whose dtype is uint16 with the byte order
        declared in the message.

    Raises:
        ValueError: if the byte-order marker is not '>' or '<'.
    '''
    shape_format = '>2I'
    image_shape = struct.unpack(shape_format, buffer.read(struct.calcsize(shape_format)))
    endianness = struct.unpack('c', buffer.read(struct.calcsize('c')))[0].decode('ascii')
    # Validate explicitly rather than with `assert`, which is stripped under
    # `python -O`. numpy would otherwise accept other byte-order codes
    # ('=', '|', 'S', ...) and silently misinterpret the pixel data.
    if endianness not in ('>', '<'):
        raise ValueError(f"Invalid byte-order marker {endianness!r}; expected '>' or '<'")
    elem_count = prod(image_shape)
    dt = np.dtype(np.uint16).newbyteorder(endianness)
    img_bytes = buffer.read(elem_count * dt.itemsize)
    img = np.frombuffer(img_bytes, dtype=dt, count=elem_count).reshape(image_shape)
    return img
def pack_image(image, buffer: bytearray):
    '''
    Append an 8-bit image to a buffer.

    Layout written: the three dimensions of ``image`` as big-endian uint32
    values, then the pixel data as uint8 in C (row-major) order. This is the
    layout read_image consumes.

    Args:
        image: a 3-D numpy array, presumably (height, width, channels).
            Values are cast with ``astype(np.uint8)`` without clipping, so
            fractional values are truncated and out-of-range values do not
            saturate. Scale to [0, 255] beforehand.
        buffer: bytearray that the header and pixel bytes are appended to.

    Raises:
        ValueError: if ``image`` is not 3-dimensional. ``buffer`` is left unchanged.
    '''
    # Validate before converting, so no copy is made for bad input. Raise
    # rather than assert so the check survives `python -O`.
    if len(image.shape) != 3:
        raise ValueError(
            f'Expected a 3-D image (height, width, channels), got shape {tuple(image.shape)}'
        )
    # copy=False avoids a redundant copy when the array is already uint8.
    image = image.astype(np.uint8, copy=False)
    buffer.extend(struct.pack('>3I', *image.shape))
    # tobytes('C') serializes in row-major order even for non-contiguous views.
    buffer.extend(image.tobytes('C'))
def create_message(message_code: ServerMessage, fn: Callable[[bytearray], None]):
    '''
    Build a framed message to be sent on the network.

    Frame layout:
        MSG_LENGTH | MSG_CODE | DATA
    MSG_LENGTH is a big-endian uint32 holding the size of DATA in bytes; the
    header itself is not counted. MSG_CODE is ``message_code.value`` encoded
    as UTF-8 and packed with struct's 's' format into exactly
    SERVER_OPERATION_CODE_LENGTH bytes. Per struct semantics, a shorter code
    is null-padded and a longer one is truncated. DATA is whatever ``fn``
    appends to the bytearray it receives.

    Returns the complete frame as a bytearray.
    '''
    length_format = '>I'
    header_size = struct.calcsize(f'>I{SERVER_OPERATION_CODE_LENGTH}s')
    # Zeroed length placeholder; it is patched once DATA has been appended.
    frame = bytearray(struct.calcsize(length_format))
    frame += struct.pack(f'{SERVER_OPERATION_CODE_LENGTH}s', message_code.value.encode('UTF-8'))
    # fn receives the frame with the header already in place and appends DATA.
    fn(frame)
    struct.pack_into(length_format, frame, 0, len(frame) - header_size)
    return frame
def _recv_exactly(s: socket.socket, count: int) -> Optional[bytearray]:
    '''Receive exactly ``count`` bytes from ``s``, or return None if the peer closes first.'''
    buf = bytearray(count)
    received = 0
    # recv may return fewer bytes than requested, so loop until the
    # preallocated buffer is full. recv_into writes in place (no per-packet copies).
    with memoryview(buf) as view:
        while received < count:
            n = s.recv_into(view[received:], count - received)
            if n == 0:  # orderly shutdown by the peer
                return None
            received += n
    return buf
def receive_message(s: socket.socket) -> Optional[Tuple[str, io.BytesIO]]:
    '''
    Read one framed message from a socket.

    Frame layout:
        MSG_LENGTH | MSG_CODE | DATA
    MSG_LENGTH is a big-endian uint32 giving the size of DATA in bytes.
    MSG_CODE is SERVER_OPERATION_CODE_LENGTH bytes, decoded as UTF-8 as-is
    (no padding is stripped).

    Returns:
        ``(code, data)``: ``code`` is the str value associated with the
        ServerOperation enum and ``data`` is an io.BytesIO over DATA.
        Returns None if the connection is closed before a complete message
        (header or body) has been received.
    '''
    raw_length = _recv_exactly(s, struct.calcsize('>I'))
    if raw_length is None:
        return None
    length = struct.unpack('>I', raw_length)[0]
    raw_code = _recv_exactly(s, SERVER_OPERATION_CODE_LENGTH)
    if raw_code is None:
        return None
    code = raw_code.decode('UTF-8')
    data = _recv_exactly(s, length)
    if data is None:
        return None
    return code, io.BytesIO(data)
|