File: bounds_overlap.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (127 lines) | stat: -rw-r--r-- 4,486 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#pragma once

#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/ir.h>

#include <cstddef>
#include <deque>
#include <functional>
#include <utility>
#include <vector>

namespace torch {
namespace jit {
namespace tensorexpr {
namespace analysis {

// A simple class containing the start and end of a range in a single
// dimension. Both ends are expressions and are inclusive (see the
// subtractBound example, where (0, 10) minus (2, 4) leaves (0, 1) and (5, 10)).
struct TORCH_API Bound {
  ExprPtr start{nullptr};
  ExprPtr end{nullptr};

  // Records whether start and end have been exchanged by swap() an odd number
  // of times. This occurs when the bound is in a loop with a negative stride.
  bool swapped{false};

  Bound() = default;
  // The expressions are taken by value and moved into place. ExprPtr is
  // presumably a reference-counted smart pointer, so this avoids an extra
  // reference-count increment/decrement per argument; for any other pointer
  // type the move is simply a copy.
  Bound(ExprPtr s, ExprPtr e) : start(std::move(s)), end(std::move(e)) {}

  void print() const;
  bool equals(const Bound& other) const;

  // The comparison operators are conservative. A true result means every
  // element of the bounds satisfies the comparison. A false result does NOT
  // imply that the opposite comparison holds; it may or may not.
  bool operator==(const Bound& other) const;
  bool operator!=(const Bound& other) const;
  bool operator<(const Bound& other) const;
  bool operator<=(const Bound& other) const;
  bool operator>(const Bound& other) const;
  bool operator>=(const Bound& other) const;

  // Exchanges start and end, and toggles `swapped` to record that it happened.
  void swap() {
    std::swap(start, end);
    swapped = !swapped;
  }
};

// Hash functor for Bound, for use as the Hash parameter of unordered
// containers. Each end is hashed with std::hash<ExprPtr> (pointer identity
// when ExprPtr is a std::shared_ptr). The `swapped` flag is not hashed.
//
// The two component hashes are mixed with a boost::hash_combine-style step
// instead of a plain XOR. XOR is symmetric, so Bound(x, y) and Bound(y, x)
// always collided. It also mapped every bound whose start and end hash
// equally (e.g. the same expression pointer) to 0.
struct BoundHash {
  size_t operator()(const Bound& b) const {
    size_t seed = std::hash<ExprPtr>()(b.start);
    const size_t end_hash = std::hash<ExprPtr>()(b.end);
    // 0x9e3779b9 is the golden-ratio constant used by boost::hash_combine;
    // the shifts make the result depend on the order of the two operands.
    seed ^= end_hash + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    return seed;
  }
};

// The type of overlap found between Bound A and Bound B. Each condition is
// true only if none of the previous conditions hold (the enumerators below
// are declared in that same order).
//     ContainedOrEqual: All elements in Bound A are in Bound B (this
//                       includes the case where the bounds are equal).
//     Contains: All elements in Bound B are in Bound A.
//     PartialOverlap: At least one element of Bound B is in Bound A, but
//                     neither bound contains the other.
//     NoOverlap: No elements in Bound A are in Bound B.
enum class OverlapKind {
  ContainedOrEqual,
  Contains,
  PartialOverlap,
  NoOverlap
};

// The result of comparing Bounds with a comparison operator (see
// compareBound).
//     True: Every element always satisfies the given comparison operator.
//     False: Every element always does NOT satisfy the given comparison
//     operator.
//     NotDetermined: Some elements satisfy the given comparison operator and
//     some do not. Because the comparison is conservative, this is presumably
//     also the answer when the analysis cannot decide; confirm in
//     compareBound's implementation.
enum class CmpEvalResult { True, False, NotDetermined };

// Returns the kind of overlap between Bound A and Bound B in a single
// dimension, interpreted from A's point of view (e.g. Contains means A
// contains B; see OverlapKind).
OverlapKind TORCH_API boundOverlap(Bound A, Bound B);

// Compares Bound a against Bound b using cmp_op. The comparison is
// conservative: True (or False) is returned only when every element of the
// compared Bounds is guaranteed to satisfy (or not satisfy) cmp_op.
// See CmpEvalResult for the meaning of each result.
CmpEvalResult TORCH_API compareBound(
    const Bound& a,
    const Bound& b,
    const CompareSelectOperation& cmp_op);

// A multi-dimensional bound representing the bound of a set of indices:
// one Bound per dimension.
using IndexBounds = std::vector<Bound>;

// Returns true if the two IndexBounds A and B are equivalent.
bool TORCH_API indexBoundsEquals(const IndexBounds& A, const IndexBounds& B);

// Flattens a multi-dimensional bound to a single dimension. The IndexBounds
// "a" *must* encapsulate the entire range of the buffer.
Bound TORCH_API flattenBounds(const IndexBounds& a);

// Determines the kind of overlap between two multi-dimensional IndexBounds,
// a and b (see OverlapKind).
OverlapKind TORCH_API overlaps(const IndexBounds& a, const IndexBounds& b);

// Returns the Bound slices created by subtracting bound B from bound A.
// Multiple Bounds can be returned in the case where B slices A into two
// distinct regions with no overlap. Bound ends are inclusive, as the example
// below shows.
//
// For example:
//    subtractBound((0, 10), (2, 4)) => [(0, 1), (5, 10)]
//       bound A: (0, 10)
//       bound B: (2, 4)
//       If we remove slice (2, 4) from the slice (0, 10), we will be left
//       with 2 slices, one at the start (0, 1), and one at the end (5, 10).
//       So, the result of this subtraction is [(0, 1), (5, 10)].
//
// Note: this doesn't use IndexBounds because the Bounds returned do not
// represent multiple different dimensions.
std::vector<Bound> TORCH_API subtractBound(Bound a, Bound b);

// Returns the bound slices created by subtracting the IndexBounds B from A.
// The first overload takes `overlap`, presumably the precomputed result of
// overlaps(A, B). The second overload presumably computes it itself.
// TODO(review): confirm against the implementation.
std::vector<IndexBounds> TORCH_API subtractIndicesBounds(
    const IndexBounds& A,
    const IndexBounds& B,
    OverlapKind overlap);
std::vector<IndexBounds> TORCH_API
subtractIndicesBounds(const IndexBounds& A, const IndexBounds& B);

} // namespace analysis
} // namespace tensorexpr
} // namespace jit
} // namespace torch