File: test_lazy_graph_executor.cpp

package info (click to toggle)
pytorch 2.9.1+dfsg-1~exp2
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 180,096 kB
  • sloc: python: 1,473,255; cpp: 942,030; ansic: 79,796; asm: 7,754; javascript: 2,502; java: 1,962; sh: 1,809; makefile: 628; xml: 8
file content (97 lines) | stat: -rw-r--r-- 3,252 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#include <gtest/gtest.h>

#include <test/cpp/lazy/test_lazy_ops_util.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>

#include <vector>

namespace torch {
namespace lazy {
namespace {

class LazyGraphExecutorTest : public ::testing::Test {
 protected:
  // Fetches the executor and empties its computation cache before each test.
  // The cache is obtained through LazyGraphExecutor::Get() rather than being
  // owned by this fixture, so without the clear, entries left behind by
  // earlier tests would leak into size-based assertions such as Numel().
  void SetUp() override {
    executor_ = LazyGraphExecutor::Get();
    executor_->ClearComputationCache();
  }

  using CachedComputationType = LazyGraphExecutor::CachedComputation;

  // Looks up `hash` in the executor's computation cache. The tests in this
  // file expect a null pointer back when no entry exists for `hash`.
  std::shared_ptr<CachedComputationType> GetCachedComputation(hash_t hash) {
    return executor_->GetComputationCache()->Get(hash);
  }

  // Syncs `tensors` (waiting for completion and syncing their data) so the
  // graph gets compiled and cached, then checks that an entry now exists
  // under `hash`.
  //
  // `tensors` is a non-const reference because SyncTensorsGraph takes a
  // pointer to the vector.
  void EnsureComputationIsCached(
      std::vector<LazyTensorPtr>& tensors,
      hash_t hash) {
    // Force computation to be cached by syncing the tensors.
    executor_->SyncTensorsGraph(
        &tensors, /* devices */ {}, /* wait */ true, /* sync_ltc_data */ true);

    // Ensure that the computation cache entry exists.
    auto cached_computation = GetCachedComputation(hash);
    EXPECT_NE(cached_computation, nullptr)
        << "Computation should be cached after sync";
  }

  // Pointer returned by LazyGraphExecutor::Get(); assigned in SetUp() and
  // never freed by the fixture. Initialized so it is never indeterminate.
  LazyGraphExecutor* executor_ = nullptr;
};

// Verifies that ClearComputationCache() drops every cached computation.
TEST_F(LazyGraphExecutorTest, TestClearComputationCache) {
  ForEachDevice([&](const torch::Device& device) {
    // Start from an empty cache so the size check below counts only the graph
    // compiled in this iteration, not entries left behind by other tests.
    executor_->ClearComputationCache();

    torch::Tensor tensor_a =
        torch::rand({2, 2}, at::TensorOptions(torch::kFloat));
    torch::Tensor tensor_b =
        torch::rand({2, 2}, at::TensorOptions(torch::kFloat));

    torch::Tensor lazy_tensor_a = CopyToDevice(tensor_a, device);
    torch::Tensor lazy_tensor_b = CopyToDevice(tensor_b, device);
    torch::Tensor result = lazy_tensor_a + lazy_tensor_b;

    std::vector<LazyTensorPtr> tensors{TryGetLtcTensor(result)};
    hash_t hash = executor_->GetGraphHash(tensors);
    EnsureComputationIsCached(tensors, hash);
    EXPECT_EQ(executor_->GetComputationCache()->Numel(), 1);

    // Clear the entire computation cache.
    executor_->ClearComputationCache();

    // Ensure that there are no cache entries.
    EXPECT_EQ(executor_->GetComputationCache()->Numel(), 0);
    auto cached_computation = GetCachedComputation(hash);
    EXPECT_EQ(cached_computation, nullptr)
        << "Cache entry should be null after clearing";
  });
}

// Verifies that RemoveFromComputationCache() evicts only the requested entry,
// leaves other entries in place, and is a harmless no-op when repeated.
TEST_F(LazyGraphExecutorTest, TestRemoveSpecificCacheEntry) {
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor tensor_a =
        torch::rand({2, 2}, at::TensorOptions(torch::kFloat));
    torch::Tensor tensor_b =
        torch::rand({2, 2}, at::TensorOptions(torch::kFloat));

    torch::Tensor lazy_tensor_a = CopyToDevice(tensor_a, device);
    torch::Tensor lazy_tensor_b = CopyToDevice(tensor_b, device);

    // Cache the graph whose entry will be removed.
    torch::Tensor sum = lazy_tensor_a + lazy_tensor_b;
    std::vector<LazyTensorPtr> sum_tensors{TryGetLtcTensor(sum)};
    hash_t hash = executor_->GetGraphHash(sum_tensors);
    EnsureComputationIsCached(sum_tensors, hash);

    // Cache a structurally different graph that must survive the removal;
    // without it, an implementation that wiped the whole cache would pass.
    torch::Tensor product = lazy_tensor_a * lazy_tensor_b;
    std::vector<LazyTensorPtr> product_tensors{TryGetLtcTensor(product)};
    hash_t other_hash = executor_->GetGraphHash(product_tensors);
    EnsureComputationIsCached(product_tensors, other_hash);
    ASSERT_NE(hash, other_hash) << "Test graphs must have distinct hashes";

    // Remove a specific cache entry.
    executor_->RemoveFromComputationCache(hash);

    // Ensure that the cache entry has been removed, and only that one.
    EXPECT_EQ(GetCachedComputation(hash), nullptr)
        << "Cache entry should be null after removal";
    EXPECT_NE(GetCachedComputation(other_hash), nullptr)
        << "Removing one entry must not evict unrelated entries";

    // Attempting to remove again should not do anything.
    executor_->RemoveFromComputationCache(hash);
    EXPECT_EQ(GetCachedComputation(hash), nullptr);
    EXPECT_NE(GetCachedComputation(other_hash), nullptr);
  });
}

} // namespace
} // namespace lazy
} // namespace torch