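// Unit tests for LazyGraphExecutor's computation cache. Each test builds a
// small lazy graph, syncs it so the compiled computation lands in the cache,
// and then exercises cache invalidation (a full clear vs. evicting a single
// entry by its graph hash).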
#include <gtest/gtest.h>

#include <test/cpp/lazy/test_lazy_ops_util.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
#include <torch/csrc/lazy/core/tensor.h>

#include <memory>
#include <vector>

namespace torch {
namespace lazy {
namespace {

class LazyGraphExecutorTest : public ::testing::Test {
 protected:
  void SetUp() override {
    executor_ = LazyGraphExecutor::Get();
  }
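
  // Convenience alias and lookup helper for entries in the computation cache.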
  using CachedComputationType = LazyGraphExecutor::CachedComputation;

  std::shared_ptr<CachedComputationType> GetCachedComputation(hash_t hash) {
    return executor_->GetComputationCache()->Get(hash);
  }
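
  // Syncs `tensors` so the backend compiles the graph, then asserts that a
  // cache entry exists for `hash`.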
  void EnsureComputationIsCached(
      std::vector<LazyTensorPtr>& tensors,
      hash_t hash) {
    // Force computation to be cached by syncing the tensors.
    executor_->SyncTensorsGraph(
        &tensors, /* devices */ {}, /* wait */ true, /* sync_ltc_data */ true);
    // Ensure that the computation cache entry exists.
    auto cached_computation = GetCachedComputation(hash);
    EXPECT_NE(cached_computation, nullptr)
        << "Computation should be cached after sync";
  }

  LazyGraphExecutor* executor_;
};
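
// NOTE: LazyGraphExecutor::Get() returns a process-wide singleton, so these
// tests assume the computation cache starts out empty; entries left behind by
// earlier tests in the same process would break the Numel() expectations.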

TEST_F(LazyGraphExecutorTest, TestClearComputationCache) {
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor tensor_a =
        torch::rand({2, 2}, at::TensorOptions(torch::kFloat));
    torch::Tensor tensor_b =
        torch::rand({2, 2}, at::TensorOptions(torch::kFloat));
    torch::Tensor lazy_tensor_a = CopyToDevice(tensor_a, device);
    torch::Tensor lazy_tensor_b = CopyToDevice(tensor_b, device);
    torch::Tensor result = lazy_tensor_a + lazy_tensor_b;
    std::vector<LazyTensorPtr> tensors{TryGetLtcTensor(result)};
    hash_t hash = executor_->GetGraphHash(tensors);
    EnsureComputationIsCached(tensors, hash);
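    // Exactly one cache entry is expected: the graph synced above.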
    EXPECT_EQ(executor_->GetComputationCache()->Numel(), 1);

    // Clear the entire computation cache.
    executor_->ClearComputationCache();

    // Ensure that there are no cache entries left.
    EXPECT_EQ(executor_->GetComputationCache()->Numel(), 0);
    auto cached_computation = GetCachedComputation(hash);
    EXPECT_EQ(cached_computation, nullptr)
        << "Cache entry should be null after clearing";
  });
}
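
// Unlike the full clear above, the next test evicts a single entry by its
// graph hash and checks that removing an already-evicted hash is a no-op.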

TEST_F(LazyGraphExecutorTest, TestRemoveSpecificCacheEntry) {
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor tensor_a =
        torch::rand({2, 2}, at::TensorOptions(torch::kFloat));
    torch::Tensor tensor_b =
        torch::rand({2, 2}, at::TensorOptions(torch::kFloat));
    torch::Tensor lazy_tensor_a = CopyToDevice(tensor_a, device);
    torch::Tensor lazy_tensor_b = CopyToDevice(tensor_b, device);
    torch::Tensor result = lazy_tensor_a + lazy_tensor_b;
    std::vector<LazyTensorPtr> tensors{TryGetLtcTensor(result)};
    hash_t hash = executor_->GetGraphHash(tensors);
    EnsureComputationIsCached(tensors, hash);

    // Remove the specific cache entry.
    executor_->RemoveFromComputationCache(hash);

    // Ensure that the cache entry has been removed.
    auto cached_computation = GetCachedComputation(hash);
    EXPECT_EQ(cached_computation, nullptr)
        << "Cache entry should be null after removal";

    // Removing the same hash again should be a harmless no-op.
    executor_->RemoveFromComputationCache(hash);
  });
}

} // namespace
} // namespace lazy
} // namespace torch