1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
|
#pragma once
#include <torch/csrc/autograd/variable.h>
namespace torch::autograd {

/// Metadata snapshot of a Variable: layout, device, scalar type, sizes, the
/// requires_grad flag, an emptiness flag, and optionally the variable itself.
///
/// NOTE(review): the constructors and zeros() are defined out of line and are
/// not visible here; the semantics noted below are inferred from names and
/// should be confirmed against the .cpp definitions.
struct TORCH_API VariableInfo {
  /// Default constructor (defined out of line).
  explicit VariableInfo();

  /// Captures metadata from `var`. Presumably, when `use_zeros_like` is true,
  /// `var` is retained in `the_var` so zeros() can go through zeros_like()
  /// (TODO confirm against the definition).
  explicit VariableInfo(const Variable& var, bool use_zeros_like = false);

  /// Builds a Variable described by this metadata; presumably zero-filled,
  /// with `device_guard` used to select `device` (TODO confirm).
  Variable zeros(at::OptionalDeviceGuard& device_guard) const;

  at::Layout layout = at::Layout::Strided;
  at::Device device = at::kCPU;
  at::ScalarType scalar_type = at::kFloat;
  std::vector<c10::SymInt> size;
  // Given default member initializers so that no constructor can leave these
  // flags indeterminate (reading an uninitialized bool is undefined
  // behaviour). A constructor that sets them in its mem-initializer list
  // overrides these defaults.
  bool requires_grad = false;
  bool is_empty = false;
  // needed for e.g. NJTs since they only support zeros_like()
  std::optional<Variable> the_var;
};

} // namespace torch::autograd
|