File: controller_loopback.py

package info (click to toggle)
python-bumble 0.0.225-1
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 9,464 kB
  • sloc: python: 75,258; java: 3,782; javascript: 823; xml: 203; sh: 172; makefile: 8
file content (206 lines) | stat: -rw-r--r-- 7,251 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import time

import click

import bumble.logging
from bumble.colors import color
from bumble.hci import (
    HCI_READ_LOOPBACK_MODE_COMMAND,
    HCI_WRITE_LOOPBACK_MODE_COMMAND,
    HCI_Read_Loopback_Mode_Command,
    HCI_Write_Loopback_Mode_Command,
    LoopbackMode,
)
from bumble.host import Host
from bumble.transport import open_transport


class Loopback:
    """Send and receive ACL data packets in local loopback mode.

    The controller is switched to local loopback mode, then `packet_count`
    L2CAP PDUs of `packet_size` zero bytes are sent on the first connection
    reported by the host. The CID field of each PDU is used as an incrementing
    sequence number so that received packets can be checked for ordering.
    """

    def __init__(self, packet_size: int, packet_count: int, transport: str):
        """Store the test parameters and reset the receive-side counters.

        Args:
            packet_size: payload size, in bytes, of each PDU sent.
            packet_count: number of PDUs to send (and expect back).
            transport: transport specification passed to `open_transport`.
        """
        self.transport = transport
        self.packet_size = packet_size
        self.packet_count = packet_count
        self.connection_handle: int | None = None
        self.connection_event = asyncio.Event()
        self.done = asyncio.Event()
        # next sequence number (carried in the CID field) expected on receive
        self.expected_cid = 0
        # payload bytes received after the first packet (which starts the clock)
        self.bytes_received = 0
        self.start_timestamp = 0.0
        self.last_timestamp = 0.0

    @staticmethod
    def _throughput(byte_count: int, elapsed: float) -> float:
        """Return `byte_count / elapsed` without ever raising ZeroDivisionError.

        Consecutive clock readings can be identical when packets arrive faster
        than the clock resolution. In that case the rate is not measurable:
        report infinity when some bytes were transferred, and 0.0 when none were.
        """
        if elapsed > 0:
            return byte_count / elapsed
        return float('inf') if byte_count else 0.0

    def on_connection(self, connection_handle: int, *args):
        """Retrieve connection handle from new connection event"""
        if not self.connection_event.is_set():
            # save first connection handle for ACL
            # subsequent connections are SCO
            self.connection_handle = connection_handle
            self.connection_event.set()

    def on_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes):
        """Check an incoming PDU and report the receive speed.

        The first packet (cid 0) only starts the clock; every later packet
        updates the byte counter and prints the instantaneous and average
        receive rates. `done` is set once `packet_count` packets were seen.

        Raises:
            AssertionError: if the PDU arrives on an unexpected connection or
                out of sequence.
        """
        # perf_counter is monotonic and high resolution, unlike wall-clock time
        now = time.perf_counter()
        print(f'<<< Received packet {cid}: {len(pdu)} bytes')
        assert connection_handle == self.connection_handle
        assert cid == self.expected_cid
        self.expected_cid += 1
        if cid == 0:
            self.start_timestamp = now
        else:
            self.bytes_received += len(pdu)
            instant_rx_speed = self._throughput(len(pdu), now - self.last_timestamp)
            average_rx_speed = self._throughput(
                self.bytes_received, now - self.start_timestamp
            )
            print(
                color(
                    f'@@@ RX speed: instant={instant_rx_speed:.4f},'
                    f' average={average_rx_speed:.4f}',
                    'cyan',
                )
            )

        self.last_timestamp = now

        if self.expected_cid == self.packet_count:
            print(color('@@@ Received last packet', 'green'))
            self.done.set()

    async def run(self) -> None:
        """Run a loopback throughput test.

        Returns early, after printing an error, when no ACL packet queue is
        available, when `packet_size` does not fit in a single packet, when the
        loopback mode commands are not supported, or when the controller does
        not report the requested loopback mode.
        """
        print(color('>>> Connecting to HCI...', 'green'))
        async with await open_transport(self.transport) as (
            hci_source,
            hci_sink,
        ):
            print(color('>>> Connected', 'green'))

            host = Host(hci_source, hci_sink)
            await host.reset()

            # make sure data can fit in one l2cap pdu
            l2cap_header_size = 4

            # fall back to the LE queue when there is no classic ACL queue
            packet_queue = host.acl_packet_queue or host.le_acl_packet_queue
            if packet_queue is None:
                print(color('!!! No packet queue', 'red'))
                return
            max_packet_size = packet_queue.max_packet_size - l2cap_header_size
            if self.packet_size > max_packet_size:
                print(
                    color(
                        f'!!! Packet size ({self.packet_size}) larger than max supported'
                        f' size ({max_packet_size})',
                        'red',
                    )
                )
                return

            if not host.supports_command(
                HCI_WRITE_LOOPBACK_MODE_COMMAND
            ) or not host.supports_command(HCI_READ_LOOPBACK_MODE_COMMAND):
                print(color('!!! Loopback mode not supported', 'red'))
                return

            # set event callbacks
            host.on('connection', self.on_connection)
            host.on('l2cap_pdu', self.on_l2cap_pdu)

            loopback_mode = LoopbackMode.LOCAL

            print(color('### Setting loopback mode', 'blue'))
            # write the same value that is verified below
            await host.send_sync_command(
                HCI_Write_Loopback_Mode_Command(loopback_mode=loopback_mode),
            )

            print(color('### Checking loopback mode', 'blue'))
            response = await host.send_sync_command(HCI_Read_Loopback_Mode_Command())
            if response.loopback_mode != loopback_mode:
                print(color('!!! Loopback mode mismatch', 'red'))
                return

            # wait until on_connection has recorded a connection handle
            await self.connection_event.wait()
            assert self.connection_handle is not None
            print(color('### Connected', 'cyan'))

            print(color('=== Start sending', 'magenta'))
            start_time = time.perf_counter()
            bytes_sent = 0
            for cid in range(self.packet_count):
                # using the cid as an incremental index
                host.send_l2cap_pdu(
                    self.connection_handle, cid, bytes(self.packet_size)
                )
                print(
                    color(
                        f'>>> Sending packet {cid}: {self.packet_size} bytes', 'yellow'
                    )
                )
                bytes_sent += self.packet_size  # don't count L2CAP or HCI header sizes
                await asyncio.sleep(0)  # yield to allow packet receive

            await self.done.wait()
            print(color('=== Done!', 'magenta'))

            elapsed = time.perf_counter() - start_time
            average_tx_speed = self._throughput(bytes_sent, elapsed)
            print(
                color(
                    f'@@@ TX speed: average={average_tx_speed:.4f} ({bytes_sent} bytes'
                    f' in {elapsed:.2f} seconds)',
                    'green',
                )
            )


# -----------------------------------------------------------------------------
@click.command()
@click.option(
    '--packet-size',
    '-s',
    default=500,
    type=click.IntRange(8, 4096),
    metavar='SIZE',
    help='Packet size',
)
@click.option(
    '--packet-count',
    '-c',
    default=10,
    type=click.IntRange(1, 65535),
    metavar='COUNT',
    help='Packet count',
)
@click.argument('transport')
def main(packet_size, packet_count, transport):
    # Command-line entry point: configure logging, then run one loopback
    # throughput test to completion on a fresh asyncio event loop.
    # (Kept as a comment rather than a docstring, since click would surface a
    # docstring as the command's help text.)
    bumble.logging.setup_basic_logging()
    throughput_test = Loopback(
        packet_size=packet_size,
        packet_count=packet_count,
        transport=transport,
    )
    asyncio.run(throughput_test.run())


# -----------------------------------------------------------------------------
# Invoke the click command only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()