#include <test/cpp/lazy/test_lazy_ops_util.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/ir_dump_util.h>
#include <torch/csrc/lazy/core/tensor_impl.h>
#include <iostream>
#include <string>
namespace torch {
namespace lazy {
namespace {
// Returns true iff |tensor| is backed by a lazy-tensor implementation
// (i.e. its TensorImpl is an LTCTensorImpl).
bool IsLtcTensor(const at::Tensor& tensor) {
  auto* impl = tensor.unsafeGetTensorImpl();
  return dynamic_cast<torch::lazy::LTCTensorImpl*>(impl) != nullptr;
}
// Builds the set of counter names that should be skipped when doing
// is-any-counter-changed assertions. The returned set is heap-allocated;
// it is intentionally kept alive for the process lifetime by the static
// in GetIgnoredCounters().
std::unordered_set<std::string>* CreateIgnoredCounters() {
  auto* icounters = new std::unordered_set<std::string>();
  // Add below the counters whose name need to be ignored when doing
  // is-any-counter-changed assertions.
  icounters->insert("aten::rand");
  return icounters;
}
} // namespace
// Returns the process-wide set of counter names to ignore in
// counter-change assertions. Built once on first use, never freed.
const std::unordered_set<std::string>* GetIgnoredCounters() {
  static const std::unordered_set<std::string>* const kIgnored =
      CreateIgnoredCounters();
  return kIgnored;
}
// Materializes |tensor| on the CPU. For lazy tensors the .to() call
// implicitly triggers a device sync.
at::Tensor ToCpuTensor(const at::Tensor& tensor) {
  at::Tensor cpu_tensor = tensor.to(torch::kCPU);
  return cpu_tensor;
}
// Produces an independent copy of |tensor| living on |device|. The clone
// guarantees fresh storage and /*copy=*/true forces a copy even when the
// source already resides on the target device.
torch::Tensor CopyToDevice(
    const torch::Tensor& tensor,
    const torch::Device& device) {
  torch::Tensor cloned = tensor.clone();
  return cloned.to(device, /*non_blocking=*/false, /*copy=*/true);
}
// Exact element-wise equality check after syncing both tensors to CPU.
// NaNs are compared positionally: the NaN masks of both tensors must match
// (asserted via EXPECT_TRUE through a recursive call on the masks), after
// which NaNs are zeroed so equal() can compare the remaining values.
// Logs to stderr and returns false on shape/dtype mismatch; if the scalar
// types differ, tensor1 is cast to tensor2's type before comparing.
bool EqualValues(at::Tensor tensor1, at::Tensor tensor2) {
  tensor1 = ToCpuTensor(tensor1);
  tensor2 = ToCpuTensor(tensor2);
  // Probe BOTH tensors for NaNs: checking only tensor1 missed the case
  // where tensor2 alone contains NaNs, silently skipping the NaN-mask
  // assertion below.
  if (torch::isnan(tensor1).any().item<bool>() ||
      torch::isnan(tensor2).any().item<bool>()) {
    EXPECT_TRUE(EqualValues(torch::isnan(tensor1), torch::isnan(tensor2)));
    tensor1.nan_to_num_();
    tensor2.nan_to_num_();
  }
  if (tensor1.sizes() != tensor2.sizes() ||
      tensor1.dtype() != tensor2.dtype()) {
    std::cerr << "Different shape:\n"
              << tensor1.dtype() << " " << tensor1.sizes() << "\n-vs-\n"
              << tensor2.dtype() << " " << tensor2.sizes() << "\n";
    return false;
  }
  at::ScalarType type1 = tensor1.scalar_type();
  at::ScalarType type2 = tensor2.scalar_type();
  if (type1 != type2) {
    tensor1 = tensor1.toType(type2);
  }
  return tensor1.equal(tensor2);
}
// Equality check that tolerates differing element types: shapes must match
// (mismatch logs to stderr and returns false), but tensor1 is cast to
// tensor2's scalar type before the element-wise comparison.
bool EqualValuesNoElementTypeCheck(at::Tensor tensor1, at::Tensor tensor2) {
  at::Tensor lhs = ToCpuTensor(tensor1);
  at::Tensor rhs = ToCpuTensor(tensor2);
  if (lhs.sizes() != rhs.sizes()) {
    std::cerr << "Different shape:\n"
              << lhs.dtype() << " " << lhs.sizes() << "\n-vs-\n"
              << rhs.dtype() << " " << rhs.sizes() << "\n";
    return false;
  }
  if (lhs.scalar_type() != rhs.scalar_type()) {
    lhs = lhs.toType(rhs.scalar_type());
  }
  return lhs.equal(rhs);
}
// Invokes |devfn| once for each available device. The TorchScript backend
// supports a single hardware type per process (selected via env), and
// distributed / multi-device training is not supported yet, so the ordinal
// is always 0 and devfn runs exactly once.
void ForEachDevice(const std::function<void(const torch::Device&)>& devfn) {
  const torch::lazy::BackendDevice backend_device;
  const torch::Device aten_device =
      torch::lazy::backendDeviceToAtenDevice(backend_device);
  devfn(aten_device);
}
// Approximate comparison (allclose with |rtol|/|atol|) after syncing both
// tensors to CPU. NaNs are handled positionally: the NaN masks must match
// (asserted via EXPECT_TRUE), then NaNs are zeroed before the tolerance
// comparison. Logs to stderr and returns false on shape/dtype mismatch.
bool CloseValues(
    at::Tensor tensor1,
    at::Tensor tensor2,
    double rtol,
    double atol) {
  tensor1 = ToCpuTensor(tensor1);
  tensor2 = ToCpuTensor(tensor2);
  // Probe BOTH tensors for NaNs: checking only tensor1 missed the case
  // where tensor2 alone contains NaNs, silently skipping the NaN-mask
  // assertion below.
  if (torch::isnan(tensor1).any().item<bool>() ||
      torch::isnan(tensor2).any().item<bool>()) {
    EXPECT_TRUE(EqualValues(torch::isnan(tensor1), torch::isnan(tensor2)));
    tensor1.nan_to_num_();
    tensor2.nan_to_num_();
  }
  if (tensor1.sizes() != tensor2.sizes() ||
      tensor1.dtype() != tensor2.dtype()) {
    std::cerr << "Different shape:\n"
              << tensor1.dtype() << " " << tensor1.sizes() << "\n-vs-\n"
              << tensor2.dtype() << " " << tensor2.sizes() << "\n";
    return false;
  }
  return tensor1.allclose(tensor2, rtol, atol);
}
// Renders the lazy IR graph rooted at |tensor|'s node as plain text.
std::string GetTensorTextGraph(at::Tensor tensor) {
  auto lazy = torch::lazy::TryGetLtcTensor(tensor);
  auto* root = lazy->GetIrValue().node.get();
  return torch::lazy::DumpUtil::ToText({root});
}
// Renders the lazy IR graph rooted at |tensor|'s node in Graphviz dot form.
std::string GetTensorDotGraph(at::Tensor tensor) {
  auto lazy = torch::lazy::TryGetLtcTensor(tensor);
  auto* root = lazy->GetIrValue().node.get();
  return torch::lazy::DumpUtil::ToDot({root});
}
// Runs |testfn| twice -- once on CPU clones of |inputs| and once on copies
// placed on |device| -- and checks that forward outputs and gradients up to
// |derivative_level|-th order agree within |rtol|/|atol|.
// Undefined entries in |inputs| are carried through as undefined
// placeholders so testfn sees the same positional layout on both sides.
void TestBackward(
    const std::vector<torch::Tensor>& inputs,
    const torch::Device& device,
    const std::function<torch::Tensor(const std::vector<torch::Tensor>&)>&
        testfn,
    double rtol,
    double atol,
    int derivative_level) {
  std::vector<torch::Tensor> input_vars;   // CPU-side clones of inputs.
  std::vector<torch::Tensor> xinput_vars;  // device-side copies of inputs.
  std::vector<torch::Tensor> inputs_w_grad;   // grad-requiring subset (CPU).
  std::vector<torch::Tensor> xinputs_w_grad;  // grad-requiring subset (device).
  for (size_t i = 0; i < inputs.size(); ++i) {
    const torch::Tensor& input = inputs[i];
    if (input.defined()) {
      // Detach so each side builds an independent autograd graph, while
      // preserving the original requires_grad flag.
      torch::Tensor oinput =
          input.clone().detach().set_requires_grad(input.requires_grad());
      input_vars.push_back(oinput);
      torch::Tensor xinput = CopyToDevice(input, device)
                                 .detach()
                                 .set_requires_grad(input.requires_grad());
      xinput_vars.push_back(xinput);
      if (input.requires_grad()) {
        inputs_w_grad.push_back(oinput);
        xinputs_w_grad.push_back(xinput);
      }
    } else {
      // Keep positions aligned for undefined (optional) inputs.
      input_vars.emplace_back();
      xinput_vars.emplace_back();
    }
  }
  // Forward pass must agree on both devices before gradients are compared.
  torch::Tensor output = testfn(input_vars);
  torch::Tensor xoutput = testfn(xinput_vars);
  torch::lazy::AllClose(output, xoutput, rtol, atol);
  std::vector<torch::Tensor> outs = {output};
  std::vector<torch::Tensor> xouts = {xoutput};
  for (int d = 1; d <= derivative_level; ++d) {
    // Check grad of sum(outs) w.r.t inputs_w_grad.
    // Seed the accumulators with a zero scalar of the right type/device so
    // the += below is well-defined even if no element requires grad.
    torch::Tensor sum = torch::zeros_like(outs[0]).sum();
    torch::Tensor xsum = torch::zeros_like(xouts[0]).sum();
    for (size_t i = 0; i < outs.size(); ++i) {
      if (outs[i].requires_grad()) {
        sum += outs[i].sum();
        xsum += xouts[i].sum();
      }
    }
    // Calculating higher order derivative requires create_graph=true;
    // on the final level the graph is no longer needed.
    bool create_graph = d != derivative_level;
    outs = torch::autograd::grad(
        {sum},
        inputs_w_grad,
        /*grad_outputs=*/{},
        /*retain_graph=*/c10::nullopt,
        /*create_graph=*/create_graph,
        /*allow_unused=*/true);
    xouts = torch::autograd::grad(
        {xsum},
        xinputs_w_grad,
        /*grad_outputs=*/{},
        /*retain_graph=*/c10::nullopt,
        /*create_graph=*/create_graph,
        /*allow_unused=*/true);
    // Gradients must be defined on both sides (or neither, when
    // allow_unused kicked in) and numerically close when defined.
    for (size_t i = 0; i < outs.size(); ++i) {
      ASSERT_EQ(outs[i].defined(), xouts[i].defined());
      if (outs[i].defined()) {
        AllClose(outs[i], xouts[i], rtol, atol);
      }
    }
  }
}
} // namespace lazy
} // namespace torch