1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
|
#include <torch/csrc/jit/passes/update_differentiable_graph_requires_grad.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
namespace torch {
namespace jit {
// Walks `block` (and, recursively, every block nested inside its nodes) and
// rewrites the `profiled_type` attribute of each `prim::profile` node so that
// the recorded TensorType carries `new_requires_grad`.
//
// Nodes are visited in program order. A node's own profile is updated before
// its nested blocks are descended into.
// NOTE(review): `expect<TensorType>()` fails if a profile node's recorded type
// is not a TensorType; this pass assumes profile nodes only record tensors.
void UpdateDifferentiableGraphRequiresGrad(
    Block* block,
    c10::optional<bool> new_requires_grad) {
  for (Node* node : block->nodes()) {
    if (node->kind() == prim::profile) {
      const auto recorded_type =
          node->ty(attr::profiled_type)->expect<TensorType>();
      node->ty_(
          attr::profiled_type,
          recorded_type->withRequiresGrad(new_requires_grad));
    }
    // Control-flow nodes (e.g. If/Loop) own sub-blocks that may also
    // contain profile nodes.
    for (Block* nested : node->blocks()) {
      UpdateDifferentiableGraphRequiresGrad(nested, new_requires_grad);
    }
  }
}
void UpdateDifferentiableGraphRequiresGrad(
std::shared_ptr<Graph>& diff_forward_graph,
c10::optional<bool> new_requires_grad) {
UpdateDifferentiableGraphRequiresGrad(
diff_forward_graph->block(), new_requires_grad);
}
} // namespace jit
} // namespace torch
|