File: grad_layout_contract.h

Package: pytorch-cuda 2.6.0+dfsg-7
#pragma once

#include <ATen/Tensor.h>

namespace torch::autograd::utils {

// Helper functions to enforce the "Gradient Layout Contract" described in
// torch/csrc/autograd/functions/accumulate_grad.h.
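//
// In brief: if `variable` is non-overlapping and dense, its stashed gradient
// should have exactly `variable`'s strides; otherwise the gradient should be
// rowmajor contiguous (at::MemoryFormat::Contiguous). obeys_layout_contract
// checks this property, and clone_obey_contract produces a grad that
// satisfies it.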

// Checks if grad obeys the contract with variable.
inline bool obeys_layout_contract(
    const at::Tensor& grad,
    const at::Tensor& variable) {
  TORCH_INTERNAL_ASSERT(!grad.is_sparse());
  TORCH_INTERNAL_ASSERT(!grad.is_sparse_csr());
  TORCH_INTERNAL_ASSERT(!variable.is_sparse_csr());

  // NOLINTNEXTLINE(bugprone-branch-clone)
  if (variable.is_nested()) {
    // TODO: Nested Tensor does not have an implementation of detach. The
    // current nested tensor implementation likely does obey the gradient
    // layout contract, and this could return true, but that may change in
    // the future.
    return false;
  } else if (variable.is_sparse()) {
    // Gradient Layout Contract is not applicable for sparse layouts
    return false;
  } else if (variable.is_non_overlapping_and_dense()) {
    // Only look at strides for dimensions that are not of size 1.
    const auto& grad_sizes = grad.sym_sizes();
    const auto& grad_strides = grad.sym_strides();
    const auto& variable_strides = variable.sym_strides();
    for (const auto idx : c10::irange(grad_sizes.size())) {
      if (grad_sizes[idx] != 1) {
        if (grad_strides[idx] != variable_strides[idx]) {
          return false;
        }
      } else {
        // This check should not be needed, but we don't verify whether a
        // Tensor is a view before stashing it, and 0-strided Tensors of
        // size 1 are actually views for ops like cat.
        // TODO: Actually detect views in the accumulateGrad function so that
        // this Tensor is not considered at all.
        if (grad_strides[idx] == 0) {
          return false;
        }
      }
    }
    return true;
  } else {
    return grad.is_contiguous(at::MemoryFormat::Contiguous);
  }
}
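
// Illustrative sketch (assuming the libtorch C++ frontend from
// <torch/torch.h>): a transposed variable is non-overlapping and dense, so a
// rowmajor grad with different strides violates the contract, while a grad
// sharing the variable's strides obeys it.
//
//   auto var = torch::rand({2, 3}).t();  // sizes (3, 2), strides (1, 3)
//   auto g1  = torch::ones({3, 2});      // sizes (3, 2), strides (2, 1)
//   obeys_layout_contract(g1, var);      // false: strides differ
//   auto g2  = torch::ones({2, 3}).t();  // sizes (3, 2), strides (1, 3)
//   obeys_layout_contract(g2, var);      // true: strides match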

// Creates a clone of new_grad that obeys the contract with variable.
// The clone should attach to new_grad's history if GradMode::is_enabled().
inline at::Tensor clone_obey_contract(
    const at::Tensor& new_grad,
    const at::Tensor& variable) {
  if (variable.is_non_overlapping_and_dense()) {
    // (1)
    // Does this dicey-looking sequence attach the result to new_grad's
    // history if GradMode::is_enabled()?  Yes, and @alband says it should.
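    // (It works because copy_ is differentiable: new_empty_strided_symint
    // produces an uninitialized tensor with variable's layout, and the
    // in-place copy from new_grad is recorded by autograd when grad mode is
    // enabled and new_grad requires grad, keeping the result connected to
    // new_grad's graph.)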
    return std::move(new_grad
                         .new_empty_strided_symint(
                             variable.sym_sizes(),
                             variable.sym_strides(),
                             variable.options().memory_format(std::nullopt))
                         .copy_(new_grad));
  } else {
    // (2)
    return new_grad.clone(at::MemoryFormat::Contiguous);
  }
}
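
// Illustrative sketch (again assuming <torch/torch.h>): restoring the
// contract for a grad whose strides differ from its variable's.
//
//   auto var   = torch::rand({2, 3}).t();         // strides (1, 3)
//   auto grad  = torch::ones({3, 2});             // strides (2, 1)
//   auto fixed = clone_obey_contract(grad, var);  // takes branch (1)
//   obeys_layout_contract(fixed, var);            // true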

} // namespace torch::autograd::utils