1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491
|
/*
* Copyright (C) 2020 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "adb/pairing/pairing_connection.h"

#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include <atomic>
#include <functional>
#include <memory>
#include <string_view>
#include <thread>
#include <vector>

#include <adb/pairing/pairing_auth.h>
#include <adb/tls/tls_connection.h>
#include <android-base/endian.h>
#include <android-base/logging.h>
#include <android-base/macros.h>
#include <android-base/unique_fd.h>

#include "pairing.pb.h"
using namespace adb;
using android::base::unique_fd;
using TlsError = tls::TlsConnection::TlsError;
// Version written into every PairingPacketHeader we create (see CreateHeader()).
const uint8_t kCurrentKeyHeaderVersion = 1;
// Inclusive range of peer header versions that ReadHeader() accepts.
const uint8_t kMinSupportedKeyHeaderVersion = 1;
const uint8_t kMaxSupportedKeyHeaderVersion = 1;
// Upper bound on the payload size ReadHeader() accepts from the peer; a header
// announcing a larger (or zero-sized) payload is rejected.
const uint32_t kMaxPayloadSize = kMaxPeerInfoSize * 2;
// Fixed-size header that precedes every pairing packet payload on the wire.
// The struct is packed because its raw bytes are written to, and parsed from,
// the TLS stream. |payload| is converted to network byte order before being
// sent and back to host byte order after being read.
struct PairingPacketHeader {
uint8_t version; // PairingPacket version
uint8_t type; // the type of packet (PairingPacket.Type)
uint32_t payload; // Size of the payload in bytes
} __attribute__((packed));
// Deleter that releases a PairingAuthCtx through pairing_auth_destroy().
struct PairingAuthDeleter {
    void operator()(PairingAuthCtx* ctx) { pairing_auth_destroy(ctx); }
};

// Owning handle for a PairingAuthCtx; the context is destroyed with the handle.
using PairingAuthPtr = std::unique_ptr<PairingAuthCtx, PairingAuthDeleter>;
// PairingConnectionCtx encapsulates the protocol to authenticate two peers with
// each other. This class handles the pairing process: a TLS handshake, a
// SPAKE2 message exchange and an exchange of encrypted PeerInfo. On completion,
// the result callback receives the peer's PeerInfo if successful, otherwise,
// the pairing failed and the callback receives nullptr.
//
// NOTE(review): earlier revisions of this comment said the class opens the tcp
// sockets itself on a hardcoded port; the code visible here only operates on
// an already-opened fd passed to Start() -- confirm and drop if stale.
//
// Each PairingConnectionCtx instance represents a different device trying to
// pair. So for the device, we can have multiple PairingConnectionCtxs while the
// host may have only one (unless host has a PairingServer).
//
// See pairing_connection_test.cpp for example usage.
//
struct PairingConnectionCtx {
public:
// Raw byte buffer used for passwords, certificates and private keys.
using Data = std::vector<uint8_t>;
using ResultCallback = pairing_result_cb;
// Which side of the pairing this instance plays; selects both the TLS role
// and the kind of pairing-auth context that is created.
enum class Role {
Client,
Server,
};
// Copies all arguments. |pswd|, |certificate| and |priv_key| must be
// non-empty (enforced with CHECK).
explicit PairingConnectionCtx(Role role, const Data& pswd, const PeerInfo& peer_info,
const Data& certificate, const Data& priv_key);
// Closes the fd and joins the worker thread if one was started.
virtual ~PairingConnectionCtx();
// Starts the pairing connection on a separate thread.
// Upon completion, if the pairing was successful,
// |cb| will be called with the peer information.
// Otherwise, |cb| will be called with a null PeerInfo pointer. |fd| should
// already be opened. PairingConnectionCtx will take ownership of the |fd|.
//
// Pairing is successful if both server/client uses the same non-empty
// |pswd|, and they are able to exchange the information. |pswd| and
// |certificate| must be non-empty. Start() can only be called once in the
// lifetime of this object.
//
// Returns true if the thread was successfully started, false otherwise
// (negative |fd|, or the connection is no longer in the Ready state).
bool Start(int fd, ResultCallback cb, void* opaque);
private:
// Setup the tls connection: creates |tls_|, performs the handshake, appends
// the exported keying material to |pswd_| and creates |auth_|.
bool SetupTlsConnection();
/************ PairingPacketHeader methods ****************/
// Tries to write out the header and payload. The header's payload size is
// converted to network byte order on a copy; |header| itself is untouched.
bool WriteHeader(const PairingPacketHeader* header, std::string_view payload);
// Tries to parse incoming data into the |header|. Returns true if header
// is valid and header version is supported. |header| is filled on success.
// |header| may contain garbage if unsuccessful.
bool ReadHeader(PairingPacketHeader* header);
// Creates a PairingPacketHeader with the current version, |type| and
// |payload_size| (host byte order).
void CreateHeader(PairingPacketHeader* header, adb::proto::PairingPacket::Type type,
uint32_t payload_size);
// Checks if actual matches expected. Logs an error on mismatch.
bool CheckHeaderType(adb::proto::PairingPacket::Type expected, uint8_t actual);
/*********** State related methods **************/
// Handles the State::ExchangingMsgs state.
bool DoExchangeMsgs();
// Handles the State::ExchangingPeerInfo state.
bool DoExchangePeerInfo();
// The background task to do the pairing.
void StartWorker();
// Calls |cb_| and sets the state to Stopped.
void NotifyResult(const PeerInfo* p);
// Creates the client or server pairing-auth context for |role| from |pswd|.
static PairingAuthPtr CreatePairingAuthPtr(Role role, const Data& pswd);
// Steps of the pairing state machine. Start() moves Ready -> ExchangingMsgs;
// the worker moves on to ExchangingPeerInfo and ends in Stopped.
enum class State {
Ready,
ExchangingMsgs,
ExchangingPeerInfo,
Stopped,
};
std::atomic<State> state_{State::Ready};
Role role_;
// Pairing password; SetupTlsConnection() appends exported TLS keying
// material to it before it is handed to the pairing-auth context.
Data pswd_;
// Our own info, sent (encrypted) to the peer.
PeerInfo peer_info_;
// Certificate and private key passed to the TLS connection.
Data cert_;
Data priv_key_;
// Peer's info
PeerInfo their_info_;
// Result callback and the opaque pointer passed back to it.
ResultCallback cb_;
void* opaque_ = nullptr;
std::unique_ptr<tls::TlsConnection> tls_;
PairingAuthPtr auth_;
// Connection fd handed to Start(); closed in the destructor.
unique_fd fd_;
// Worker thread running StartWorker(); joined in the destructor.
std::thread thread_;
// Number of bytes of keying material requested from the TLS connection.
static constexpr size_t kExportedKeySize = 64;
}; // PairingConnectionCtx
// Stores copies of the role, password, local peer info, certificate and
// private key. Aborts (CHECK) if the password, certificate or private key is
// empty, since pairing cannot proceed without any of them.
PairingConnectionCtx::PairingConnectionCtx(Role role, const Data& pswd, const PeerInfo& peer_info,
                                           const Data& cert, const Data& priv_key)
    : role_(role),
      pswd_(pswd),
      peer_info_(peer_info),
      cert_(cert),
      priv_key_(priv_key) {
    CHECK(!pswd_.empty() && !cert_.empty() && !priv_key_.empty());
}
// Tears the connection down: the fd is closed first, then the worker thread
// (if one was ever started) is joined before the members are destroyed.
PairingConnectionCtx::~PairingConnectionCtx() {
    fd_.reset();

    if (!thread_.joinable()) {
        return;
    }
    thread_.join();
}
// Establishes the TLS session on |fd_| and prepares the pairing-auth context.
//
// Steps: create the TlsConnection for our role, accept any peer certificate,
// run the handshake, export keying material from the session, append it to
// |pswd_| and create |auth_| from the combined secret.
//
// Returns false (after logging) if the connection cannot be created, the
// handshake fails, or no keying material could be exported.
bool PairingConnectionCtx::SetupTlsConnection() {
    const auto tls_role = (role_ == Role::Server) ? tls::TlsConnection::Role::Server
                                                  : tls::TlsConnection::Role::Client;
    const std::string_view cert_view(reinterpret_cast<const char*>(cert_.data()), cert_.size());
    const std::string_view key_view(reinterpret_cast<const char*>(priv_key_.data()),
                                    priv_key_.size());

    tls_ = tls::TlsConnection::Create(tls_role, cert_view, key_view, fd_);
    if (tls_ == nullptr) {
        LOG(ERROR) << "Unable to start TlsConnection. Unable to pair fd=" << fd_.get();
        return false;
    }

    // Allow any peer certificate
    tls_->SetCertVerifyCallback([](X509_STORE_CTX*) { return 1; });

    // SSL doesn't seem to behave correctly with fdevents so just do a blocking
    // read for the pairing data.
    if (tls_->DoHandshake() != TlsError::Success) {
        LOG(ERROR) << "Failed to handshake with the peer fd=" << fd_.get();
        return false;
    }

    // To ensure the connection is not stolen while we do the PAKE, append the
    // exported key material from the tls connection to the password.
    const std::vector<uint8_t> key_material = tls_->ExportKeyingMaterial(kExportedKeySize);
    if (key_material.empty()) {
        LOG(ERROR) << "Failed to export key material";
        return false;
    }
    pswd_.insert(pswd_.end(), key_material.begin(), key_material.end());

    auth_ = CreatePairingAuthPtr(role_, pswd_);
    return true;
}
bool PairingConnectionCtx::WriteHeader(const PairingPacketHeader* header,
std::string_view payload) {
PairingPacketHeader network_header = *header;
network_header.payload = htonl(network_header.payload);
if (!tls_->WriteFully(std::string_view(reinterpret_cast<const char*>(&network_header),
sizeof(PairingPacketHeader))) ||
!tls_->WriteFully(payload)) {
LOG(ERROR) << "Failed to write out PairingPacketHeader";
state_ = State::Stopped;
return false;
}
return true;
}
// Reads one PairingPacketHeader from the TLS connection and validates it.
//
// On success |header| holds the parsed fields, with |payload| converted to
// host byte order. Returns true only if a complete header was read, the
// version lies within [kMinSupportedKeyHeaderVersion,
// kMaxSupportedKeyHeaderVersion], the type is a valid PairingPacket::Type and
// the payload size is non-zero and at most kMaxPayloadSize. On failure
// |header| may be partially filled.
bool PairingConnectionCtx::ReadHeader(PairingPacketHeader* header) {
    const auto data = tls_->ReadFully(sizeof(PairingPacketHeader));
    // Anything shorter than a full header (including an empty read) cannot be
    // parsed without reading past the buffer.
    if (data.size() < sizeof(PairingPacketHeader)) {
        return false;
    }

    const uint8_t* p = data.data();
    // First byte is always PairingPacketHeader version
    header->version = *p;
    ++p;
    if (header->version < kMinSupportedKeyHeaderVersion ||
        header->version > kMaxSupportedKeyHeaderVersion) {
        // Cast to an integer type: streaming a uint8_t would print it as a raw
        // character instead of a number.
        LOG(ERROR) << "PairingPacketHeader version mismatch (us="
                   << static_cast<uint32_t>(kCurrentKeyHeaderVersion)
                   << " them=" << static_cast<uint32_t>(header->version) << ")";
        return false;
    }

    // Next byte is the PairingPacket::Type
    if (!adb::proto::PairingPacket::Type_IsValid(*p)) {
        LOG(ERROR) << "Unknown PairingPacket type=" << static_cast<uint32_t>(*p);
        return false;
    }
    header->type = *p;
    ++p;

    // Last, the payload size. |p| is not 4-byte aligned at this offset, so copy
    // the bytes out instead of dereferencing a casted uint32_t pointer.
    uint32_t payload_network_order = 0;
    ::memcpy(&payload_network_order, p, sizeof(payload_network_order));
    header->payload = ntohl(payload_network_order);
    if (header->payload == 0 || header->payload > kMaxPayloadSize) {
        LOG(ERROR) << "header payload not within a safe payload size (size=" << header->payload
                   << ")";
        return false;
    }

    return true;
}
// Fills |header| for an outgoing packet: the current header version, |type|
// narrowed to its single-byte wire representation, and |payload_size| in host
// byte order.
void PairingConnectionCtx::CreateHeader(PairingPacketHeader* header,
                                        adb::proto::PairingPacket::Type type,
                                        uint32_t payload_size) {
    const auto wire_type = static_cast<uint8_t>(static_cast<int>(type));

    header->version = kCurrentKeyHeaderVersion;
    header->type = wire_type;
    header->payload = payload_size;
}
// Verifies that the packet type byte received from the peer (|actual|) matches
// |expected_type|. Logs both values and returns false on mismatch.
bool PairingConnectionCtx::CheckHeaderType(adb::proto::PairingPacket::Type expected_type,
                                           uint8_t actual) {
    // Convert by value, the same way CreateHeader() does. Reading the first
    // byte of the enum's storage through a pointer cast only yields the right
    // value on little-endian targets.
    const uint8_t expected = static_cast<uint8_t>(static_cast<int>(expected_type));
    if (actual != expected) {
        LOG(ERROR) << "Unexpected header type (expected=" << static_cast<uint32_t>(expected)
                   << " actual=" << static_cast<uint32_t>(actual) << ")";
        return false;
    }
    return true;
}
// Reports the pairing outcome to the registered callback, passing |p|, the
// connection's fd and the caller-supplied opaque pointer, and then marks the
// connection as Stopped.
void PairingConnectionCtx::NotifyResult(const PeerInfo* p) {
    const int raw_fd = fd_.get();
    cb_(p, raw_fd, opaque_);
    state_.store(State::Stopped);
}
bool PairingConnectionCtx::Start(int fd, ResultCallback cb, void* opaque) {
if (fd < 0) {
return false;
}
fd_.reset(fd);
State expected = State::Ready;
if (!state_.compare_exchange_strong(expected, State::ExchangingMsgs)) {
return false;
}
cb_ = cb;
opaque_ = opaque;
thread_ = std::thread([this] { StartWorker(); });
return true;
}
bool PairingConnectionCtx::DoExchangeMsgs() {
uint32_t payload = pairing_auth_msg_size(auth_.get());
std::vector<uint8_t> msg(payload);
pairing_auth_get_spake2_msg(auth_.get(), msg.data());
PairingPacketHeader header;
CreateHeader(&header, adb::proto::PairingPacket::SPAKE2_MSG, payload);
// Write our SPAKE2 msg
if (!WriteHeader(&header,
std::string_view(reinterpret_cast<const char*>(msg.data()), msg.size()))) {
LOG(ERROR) << "Failed to write SPAKE2 msg.";
return false;
}
// Read the peer's SPAKE2 msg header
if (!ReadHeader(&header)) {
LOG(ERROR) << "Invalid PairingPacketHeader.";
return false;
}
if (!CheckHeaderType(adb::proto::PairingPacket::SPAKE2_MSG, header.type)) {
return false;
}
// Read the SPAKE2 msg payload and initialize the cipher for
// encrypting the PeerInfo and certificate.
auto their_msg = tls_->ReadFully(header.payload);
if (their_msg.empty() ||
!pairing_auth_init_cipher(auth_.get(), their_msg.data(), their_msg.size())) {
LOG(ERROR) << "Unable to initialize pairing cipher [their_msg.size=" << their_msg.size()
<< "]";
return false;
}
return true;
}
bool PairingConnectionCtx::DoExchangePeerInfo() {
// Encrypt PeerInfo
std::vector<uint8_t> buf;
uint8_t* p = reinterpret_cast<uint8_t*>(&peer_info_);
buf.assign(p, p + sizeof(peer_info_));
std::vector<uint8_t> outbuf(pairing_auth_safe_encrypted_size(auth_.get(), buf.size()));
CHECK(!outbuf.empty());
size_t outsize;
if (!pairing_auth_encrypt(auth_.get(), buf.data(), buf.size(), outbuf.data(), &outsize)) {
LOG(ERROR) << "Failed to encrypt peer info";
return false;
}
outbuf.resize(outsize);
// Write out the packet header
PairingPacketHeader out_header;
out_header.version = kCurrentKeyHeaderVersion;
out_header.type = static_cast<uint8_t>(static_cast<int>(adb::proto::PairingPacket::PEER_INFO));
out_header.payload = htonl(outbuf.size());
if (!tls_->WriteFully(
std::string_view(reinterpret_cast<const char*>(&out_header), sizeof(out_header)))) {
LOG(ERROR) << "Unable to write PairingPacketHeader";
return false;
}
// Write out the encrypted payload
if (!tls_->WriteFully(
std::string_view(reinterpret_cast<const char*>(outbuf.data()), outbuf.size()))) {
LOG(ERROR) << "Unable to write encrypted peer info";
return false;
}
// Read in the peer's packet header
PairingPacketHeader header;
if (!ReadHeader(&header)) {
LOG(ERROR) << "Invalid PairingPacketHeader.";
return false;
}
if (!CheckHeaderType(adb::proto::PairingPacket::PEER_INFO, header.type)) {
return false;
}
// Read in the encrypted peer certificate
buf = tls_->ReadFully(header.payload);
if (buf.empty()) {
return false;
}
// Try to decrypt the certificate
outbuf.resize(pairing_auth_safe_decrypted_size(auth_.get(), buf.data(), buf.size()));
if (outbuf.empty()) {
LOG(ERROR) << "Unsupported payload while decrypting peer info.";
return false;
}
if (!pairing_auth_decrypt(auth_.get(), buf.data(), buf.size(), outbuf.data(), &outsize)) {
LOG(ERROR) << "Failed to decrypt";
return false;
}
outbuf.resize(outsize);
// The decrypted message should contain the PeerInfo.
if (outbuf.size() != sizeof(PeerInfo)) {
LOG(ERROR) << "Got size=" << outbuf.size() << "PeerInfo.size=" << sizeof(PeerInfo);
return false;
}
p = outbuf.data();
::memcpy(&their_info_, p, sizeof(PeerInfo));
p += sizeof(PeerInfo);
return true;
}
void PairingConnectionCtx::StartWorker() {
// Setup the secure transport
if (!SetupTlsConnection()) {
NotifyResult(nullptr);
return;
}
for (;;) {
switch (state_) {
case State::ExchangingMsgs:
if (!DoExchangeMsgs()) {
NotifyResult(nullptr);
return;
}
state_ = State::ExchangingPeerInfo;
break;
case State::ExchangingPeerInfo:
if (!DoExchangePeerInfo()) {
NotifyResult(nullptr);
return;
}
NotifyResult(&their_info_);
return;
case State::Ready:
case State::Stopped:
LOG(FATAL) << __func__ << ": Got invalid state";
return;
}
}
}
// static
// Creates the pairing-auth context matching |role|, seeded with |pswd|.
// Returns an owning pointer wrapping the result of pairing_auth_client_new()
// or pairing_auth_server_new().
PairingAuthPtr PairingConnectionCtx::CreatePairingAuthPtr(Role role, const Data& pswd) {
    switch (role) {
        case Role::Client:
            return PairingAuthPtr(pairing_auth_client_new(pswd.data(), pswd.size()));
        case Role::Server:
            return PairingAuthPtr(pairing_auth_server_new(pswd.data(), pswd.size()));
    }
    // Only reachable with a corrupted |role|. Fail loudly instead of flowing
    // off the end of a non-void function, which is undefined behavior.
    LOG(FATAL) << __func__ << ": Got invalid role";
    return nullptr;
}
// Shared implementation behind the client/server constructors of the C API.
// Validates that every buffer is non-null and non-empty and that |peer_info|
// is non-null (aborting via CHECK otherwise), copies the raw buffers into
// vectors and returns a heap-allocated PairingConnectionCtx owned by the
// caller.
static PairingConnectionCtx* CreateConnection(PairingConnectionCtx::Role role, const uint8_t* pswd,
                                              size_t pswd_len, const PeerInfo* peer_info,
                                              const uint8_t* x509_cert_pem, size_t x509_size,
                                              const uint8_t* priv_key_pem, size_t priv_size) {
    CHECK(pswd);
    CHECK_GT(pswd_len, 0U);
    CHECK(x509_cert_pem);
    CHECK_GT(x509_size, 0U);
    CHECK(priv_key_pem);
    CHECK_GT(priv_size, 0U);
    CHECK(peer_info);

    using Bytes = PairingConnectionCtx::Data;
    const Bytes password(pswd, pswd + pswd_len);
    const Bytes certificate(x509_cert_pem, x509_cert_pem + x509_size);
    const Bytes private_key(priv_key_pem, priv_key_pem + priv_size);

    return new PairingConnectionCtx(role, password, *peer_info, certificate, private_key);
}
// Creates a client-role pairing connection from the given password, local peer
// info, certificate and private key. All validation and copying is delegated
// to CreateConnection(). The returned context is heap-allocated.
PairingConnectionCtx* pairing_connection_client_new(const uint8_t* pswd, size_t pswd_len,
                                                    const PeerInfo* peer_info,
                                                    const uint8_t* x509_cert_pem, size_t x509_size,
                                                    const uint8_t* priv_key_pem, size_t priv_size) {
    constexpr auto kRole = PairingConnectionCtx::Role::Client;
    return CreateConnection(kRole, pswd, pswd_len, peer_info, x509_cert_pem, x509_size,
                            priv_key_pem, priv_size);
}
// Creates a server-role pairing connection from the given password, local peer
// info, certificate and private key. All validation and copying is delegated
// to CreateConnection(). The returned context is heap-allocated.
PairingConnectionCtx* pairing_connection_server_new(const uint8_t* pswd, size_t pswd_len,
                                                    const PeerInfo* peer_info,
                                                    const uint8_t* x509_cert_pem, size_t x509_size,
                                                    const uint8_t* priv_key_pem, size_t priv_size) {
    constexpr auto kRole = PairingConnectionCtx::Role::Server;
    return CreateConnection(kRole, pswd, pswd_len, peer_info, x509_cert_pem, x509_size,
                            priv_key_pem, priv_size);
}
// Destroys a pairing connection. |ctx| must be non-null (enforced with CHECK).
// Deleting the context runs its destructor, which closes the fd and joins the
// worker thread.
void pairing_connection_destroy(PairingConnectionCtx* ctx) {
    CHECK(ctx);

    delete ctx;
}
// Starts pairing on |ctx| using the already-opened |fd|; |cb| is invoked with
// |opaque| when pairing completes. |ctx| must be non-null (enforced with
// CHECK, matching pairing_connection_destroy()). Returns whatever
// PairingConnectionCtx::Start() returns: true if the worker thread was
// started, false otherwise.
bool pairing_connection_start(PairingConnectionCtx* ctx, int fd, pairing_result_cb cb,
                              void* opaque) {
    CHECK(ctx);
    return ctx->Start(fd, cb, opaque);
}
|