1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141
|
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0.
import argparse
from awscrt import io, mqtt
from awscrt.io import LogLevel
import threading
import uuid
# Upper bound, in seconds, for every blocking wait performed by this test.
TIMEOUT = 5
# A fresh random suffix so simultaneously-running copies of this test do not
# share a client id, topic, or message and interfere with each other.
UNIQUE_ID = str(uuid.uuid4())
CLIENT_ID = 'test_pubsub_{}'.format(UNIQUE_ID)
TOPIC = 'test/pubsub/{}'.format(UNIQUE_ID)
MESSAGE = 'test message {}'.format(UNIQUE_ID)
parser = argparse.ArgumentParser()

# Broker location: the endpoint is mandatory, the port optional.
parser.add_argument(
    '--endpoint',
    required=True,
    help="Connect to this endpoint (aka host-name)",
)
parser.add_argument(
    '--port',
    type=int,
    help="Override default connection port",
)

# Optional PEM file paths used to build the TLS configuration.
parser.add_argument(
    '--cert',
    help="File path to your client certificate, in PEM format",
)
parser.add_argument(
    '--key',
    help="File path to your private key, in PEM format",
)
parser.add_argument(
    '--root-ca',
    help="File path to root certificate authority, in PEM format",
)
# Route CRT logging, at the Trace level, to stderr for the whole run.
io.init_logging(
    LogLevel.Trace,
    'stderr',
)
def on_connection_interrupted(connection, error, **kwargs):
    """Log that the MQTT connection was interrupted.

    Registered as the connection's ``on_connection_interrupted`` callback.
    It only prints a message and takes no recovery action itself.

    :param connection: the interrupted connection (not used here).
    :param error: the error reported for the interruption; printed via str().
    :param kwargs: any additional keyword arguments (ignored).
    """
    parts = ("Connection has been interrupted with error", error)
    print(*parts)
def on_connection_resumed(connection, return_code, session_present, **kwargs):
    """Log a resumed connection and re-subscribe when no session was kept.

    Registered as the connection's ``on_connection_resumed`` callback.

    :param connection: the resumed connection; used to re-issue subscriptions.
    :param return_code: the return code reported for the reconnect.
    :param session_present: whether the previous session was present; when it
        is falsy, all existing subscriptions are re-sent.
    :param kwargs: any additional keyword arguments (ignored).

    If the re-subscribe completes with an error, a mismatched packet id, or a
    topic whose granted QoS is None, the whole test process is terminated
    with a non-zero status.
    """
    import os
    import sys

    print("Connection has been resumed with return code", return_code, "and session present:", session_present)
    if not session_present:
        print("Resubscribing to existing topics")
        resubscribe_future, packet_id = connection.resubscribe_existing_topics()

        def on_resubscribe_complete(resubscribe_future):
            # Runs as a future done-callback, not in the main script flow, so
            # failures must terminate the process explicitly (see below).
            try:
                resubscribe_results = resubscribe_future.result()
                print("Resubscribe results:", resubscribe_results)
                # Explicit checks instead of `assert`, which `python -O` strips.
                if resubscribe_results['packet_id'] != packet_id:
                    raise RuntimeError("Resubscribe packet_id {} does not match expected {}".format(
                        resubscribe_results['packet_id'], packet_id))
                for (topic, qos) in resubscribe_results['topics']:
                    if qos is None:
                        raise RuntimeError("Server rejected resubscribe to topic: {}".format(topic))
            except Exception as e:
                print("Resubscribe failure:", e)
                # The builtin exit() only raises SystemExit in the current
                # thread, which does not stop the test process from a callback.
                # os._exit ends the process immediately but skips stdio
                # flushing, so flush first to keep the diagnostics above.
                sys.stdout.flush()
                sys.stderr.flush()
                os._exit(-1)

        resubscribe_future.add_done_callback(on_resubscribe_complete)
# State shared between the subscription callback and the waiting script:
# on_receive_message stores the delivery fields here, then sets the event.
receive_results = dict()
receive_event = threading.Event()
def on_receive_message(topic, payload, dup, qos, retain, **kwargs):
    """Record a delivered message and release the waiting script.

    Passed as the ``callback`` for the test subscription. Copies every
    delivery field into the module-level ``receive_results`` dict, then sets
    ``receive_event``.

    :param topic: topic the message arrived on.
    :param payload: message body; the script later calls ``.decode()`` on it.
    :param dup: duplicate-delivery flag as reported by the client.
    :param qos: QoS of the delivered message.
    :param retain: retain flag as reported by the client.
    :param kwargs: any additional keyword arguments (ignored).
    """
    receive_results.update(
        topic=topic,
        payload=payload,
        dup=dup,
        qos=qos,
        retain=retain,
    )
    # Set the event last so every field is stored before the waiter resumes.
    receive_event.set()
# Run
args = parser.parse_args()

# CRT plumbing for the client: a one-thread event loop group, a host
# resolver, and the bootstrap that combines them.
event_loop_group = io.EventLoopGroup(1)
host_resolver = io.DefaultHostResolver(event_loop_group)
client_bootstrap = io.ClientBootstrap(event_loop_group, host_resolver)

# TLS options are built only when a certificate-related option was given;
# otherwise tls_options stays None and the client gets no TLS context.
tls_options = None
if args.cert or args.key or args.root_ca:
    if args.cert or args.key:
        # Mutual TLS needs both the certificate and its private key. Report a
        # usage error rather than using `assert` (stripped by `python -O`) or
        # silently ignoring a lone --key.
        if not (args.cert and args.key):
            parser.error("--cert and --key must be provided together")
        tls_options = io.TlsContextOptions.create_client_with_mtls_from_path(args.cert, args.key)
    else:
        tls_options = io.TlsContextOptions()
    if args.root_ca:
        with open(args.root_ca, mode='rb') as ca:
            rootca = ca.read()
        tls_options.override_default_trust_store(rootca)

# Port selection: an explicit --port wins; otherwise 443 when ALPN is
# available (advertising 'x-amzn-mqtt-ca' if TLS is configured), else 8883.
# NOTE(review): with no TLS options this still picks 443/8883, which are
# normally TLS ports — confirm plaintext on these ports is intended.
if args.port:
    port = args.port
elif io.is_alpn_available():
    port = 443
    if tls_options:
        tls_options.alpn_list = ['x-amzn-mqtt-ca']
else:
    port = 8883

tls_context = io.ClientTlsContext(tls_options) if tls_options else None
mqtt_client = mqtt.Client(client_bootstrap, tls_context)
# Connect
print("Connecting to {}:{} with client-id:{}".format(args.endpoint, port, CLIENT_ID))
mqtt_connection = mqtt.Connection(
    client=mqtt_client,
    host_name=args.endpoint,
    port=port,
    client_id=CLIENT_ID,
    on_connection_interrupted=on_connection_interrupted,
    on_connection_resumed=on_connection_resumed)
# Wait at most TIMEOUT seconds for the connect future to resolve.
connect_results = mqtt_connection.connect().result(TIMEOUT)
# The test expects no prior session for its freshly generated client id.
assert not connect_results['session_present'], \
    "Unexpected session_present on first connect: {}".format(connect_results)
# Subscribe
print("Subscribing to:", TOPIC)
qos = mqtt.QoS.AT_LEAST_ONCE
subscribe_future, subscribe_packet_id = mqtt_connection.subscribe(
    topic=TOPIC,
    qos=qos,
    callback=on_receive_message)
subscribe_results = subscribe_future.result(TIMEOUT)
# Print the results before validating them so a failed check still shows
# what the subscribe future returned.
print(subscribe_results)
assert subscribe_results['packet_id'] == subscribe_packet_id, \
    "subscribe packet_id mismatch: expected {}".format(subscribe_packet_id)
assert subscribe_results['topic'] == TOPIC, \
    "subscribe topic mismatch: expected {}".format(TOPIC)
assert subscribe_results['qos'] == qos, \
    "subscribe qos mismatch: expected {}".format(qos)
# Publish
print("Publishing to '{}': {}".format(TOPIC, MESSAGE))
publish_future, publish_packet_id = mqtt_connection.publish(
    qos=mqtt.QoS.AT_LEAST_ONCE,
    payload=MESSAGE,
    topic=TOPIC,
)
# The publish future's result carries a 'packet_id' that must equal the id
# returned alongside the future.
publish_results = publish_future.result(TIMEOUT)
assert publish_results['packet_id'] == publish_packet_id
# Receive Message
print("Waiting to receive message")
# Wait outside the assert: `python -O` strips asserts, which would skip the
# wait itself and leave nothing synchronizing with on_receive_message.
received = receive_event.wait(TIMEOUT)
assert received, "Timed out after {}s waiting for a message on '{}'".format(TIMEOUT, TOPIC)
assert receive_results['topic'] == TOPIC, \
    "received on unexpected topic: {}".format(receive_results['topic'])
assert receive_results['payload'].decode() == MESSAGE, \
    "received unexpected payload: {!r}".format(receive_results['payload'])
# Unsubscribe
print("Unsubscribing from topic")
unsubscribe_future, unsubscribe_packet_id = mqtt_connection.unsubscribe(TOPIC)
unsubscribe_results = unsubscribe_future.result(TIMEOUT)
# The unsubscribe acknowledgement must echo the packet id we were given.
assert unsubscribe_results['packet_id'] == unsubscribe_packet_id

# Disconnect, blocking up to TIMEOUT seconds for it to complete.
print("Disconnecting")
mqtt_connection.disconnect().result(TIMEOUT)

# Reached only when no earlier step raised.
print("Test Success")
|