// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/windows_services/elevated_tracing_service/session_registry.h"

#include <objbase.h>

#include <atomic>
#include <utility>

#include "base/check_deref.h"
#include "base/check_op.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/memory/ptr_util.h"
#include "base/memory/ref_counted.h"

namespace elevated_tracing_service {

namespace {

SessionRegistry* g_instance = nullptr;

}  // namespace

// SessionRegistry::SessionCore ------------------------------------------------

// A thread-safe holder of a session's primary `IUnknown` pointer.
class SessionRegistry::SessionCore
    : public base::RefCountedThreadSafe<SessionCore> {
 public:
  explicit SessionCore(IUnknown* unknown) : unknown_(unknown) {}

  // Atomically takes the held pointer, leaving nullptr in its place. Returns
  // nullptr if another thread has already taken it.
  IUnknown* release() { return unknown_.exchange(nullptr); }

 private:
  friend class base::RefCountedThreadSafe<SessionCore>;
  ~SessionCore() = default;

  std::atomic<IUnknown*> unknown_;
};
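
// Illustrative sketch only: `illustrative_sketch::ClaimOnce` is a hypothetical
// name and is not used by SessionRegistry. It reduces SessionCore::release()
// to a plain `int*` to show why only one of OnSessionDestroyed() and
// OnClientTerminated() acts on a session: std::atomic::exchange() reads the
// old value and stores nullptr in a single atomic step, so however calls from
// different threads interleave, only the first caller observes the non-null
// pointer.
namespace illustrative_sketch {

class ClaimOnce {
 public:
  explicit ClaimOnce(int* value) : value_(value) {}

  // Returns the held pointer to the first caller and nullptr to every later
  // caller.
  int* Claim() { return value_.exchange(nullptr); }

 private:
  std::atomic<int*> value_;
};

}  // namespace illustrative_sketch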

// SessionRegistry::ScopedSession ----------------------------------------------

SessionRegistry::ScopedSession::~ScopedSession() {
  std::move(on_session_destroyed_).Run();
}

SessionRegistry::ScopedSession::ScopedSession(
    base::Process client_process,
    base::OnceClosure on_session_destroyed,
    base::OnceClosure on_client_terminated)
    : client_process_watcher_(std::move(client_process),
                              std::move(on_client_terminated)),
      on_session_destroyed_(std::move(on_session_destroyed)) {}

// SessionRegistry -------------------------------------------------------------

SessionRegistry::SessionRegistry() {
  CHECK_EQ(std::exchange(g_instance, this), nullptr);
}

SessionRegistry::~SessionRegistry() {
  CHECK_EQ(std::exchange(g_instance, nullptr), this);
}

// static
std::unique_ptr<SessionRegistry::ScopedSession>
SessionRegistry::RegisterActiveSession(IUnknown* session,
                                       base::Process client_process) {
  SessionRegistry& instance = CHECK_DEREF(g_instance);

  // Create a new SessionCore for this session and atomically make it the
  // active one if no other session is active. If that succeeds, wrap the core
  // in a new ScopedSession and return it; otherwise, return null.
  auto core = base::MakeRefCounted<SessionCore>(session);
  SessionCore* expected_null = nullptr;
  return instance.active_session_.compare_exchange_strong(expected_null,
                                                          core.get())
             ? base::WrapUnique(new ScopedSession(
                   std::move(client_process),
                   base::BindOnce(&SessionRegistry::OnSessionDestroyed,
                                  &instance, core),
                   base::BindOnce(&SessionRegistry::OnClientTerminated,
                                  &instance, core)))
             : nullptr;
}

void SessionRegistry::SetSessionClearedClosureForTesting(
    base::OnceClosure on_session_cleared) {
  on_session_cleared_ = std::move(on_session_cleared);
}

void SessionRegistry::OnSessionDestroyed(scoped_refptr<SessionCore> core) {
  // The session is being destroyed cleanly. Clear the IUnknown pointer held in
  // the core so that a race with the client process watcher doesn't try to
  // use it after it has become a dangling pointer.
  if (core->release() != nullptr) {
    // This task is handling session destruction before the termination task,
    // so take responsibility for clearing the active session. From this point
    // onward, a new call to RegisterActiveSession() will succeed.
    ClearActiveSession(core.get());
  }
}

void SessionRegistry::OnClientTerminated(scoped_refptr<SessionCore> core) {
  // The client process associated with the session has terminated. If the core
  // still holds the session's IUnknown pointer (meaning that the ScopedSession
  // has yet to be destroyed), tell COM to force a disconnect.
  if (IUnknown* unknown = core->release(); unknown != nullptr) {
    ::CoDisconnectObject(unknown, /*dwReserved=*/0);
    // This task is handling client termination before the session is
    // destroyed, so take responsibility for clearing the active session. From
    // this point onward, a new call to RegisterActiveSession() will succeed.
    ClearActiveSession(core.get());
  }
}

void SessionRegistry::ClearActiveSession(SessionCore* core) {
  SessionCore* expected_core = core;
  if (active_session_.compare_exchange_strong(expected_core, nullptr) &&
      on_session_cleared_) {
    std::move(on_session_cleared_).Run();
  }
}
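
// Illustrative sketch only: `illustrative_sketch::SingleSlot` is a
// hypothetical name and is not used by SessionRegistry. It isolates the
// compare_exchange_strong() pattern that RegisterActiveSession() and
// ClearActiveSession() apply to `active_session_`: a new owner can claim the
// slot only while it is empty, and an owner can empty the slot only while it
// still holds it, so a stale owner can never evict a newer one.
namespace illustrative_sketch {

class SingleSlot {
 public:
  // Claims the slot for `owner` if it is empty. On failure,
  // compare_exchange_strong() writes the current occupant into `expected`,
  // which this sketch ignores.
  bool TryOccupy(const void* owner) {
    const void* expected = nullptr;
    return slot_.compare_exchange_strong(expected, owner);
  }

  // Empties the slot if `owner` still occupies it.
  bool TryVacate(const void* owner) {
    const void* expected = owner;
    return slot_.compare_exchange_strong(expected, nullptr);
  }

 private:
  std::atomic<const void*> slot_{nullptr};
};

}  // namespace illustrative_sketch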

}  // namespace elevated_tracing_service