File: symbolic_shape_runtime_fusion.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (55 lines) | stat: -rw-r--r-- 2,392 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#pragma once

#include <torch/csrc/Export.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>

#include <unordered_map>

namespace torch {
namespace jit {

// Takes in a TensorExprGraph of static shapes and generalizes the input shapes
// to symbolic dimensions. Dimensions of value 1 will be preserved, otherwise
// dimensions with the same value will be bucketed to the same symbolic shape.
// E.g. Tensor(5, 3), Tensor(3, 1) -> Tensor(SS(-1), SS(-2)), Tensor(SS(-2), 1)
// From there, runs symbolic shape inference on the graph, and creates a
// versionining if in the graph with prim::TensorExprDynamicGuard checking if
// the inputs at runtime match the Generalized Symbolic Shapes that are inputs
// to the TE Kernel. The computate to calculate all symbolic dimensions is
// inlined in to the if block with the TE Kernel. All Sym Dim Value* are
// appended to the end of the TE Kernel Graph/Node inputs, and the Node is
// augmented with a integer list attr `symbolic_shape_inputs` that gives the
// mapping from Value * -> Symbolic Shape int64_t value. For more lengthy IR
// examples and walkthrough look at ShapeAnalysisTest.DynamicShapesFusion in
// `test_shape_analysis` Returns True on Success, False on Failure, can fail if
// shape propagation fails to propagate # of dims or if complete shapes on
// inputs not set

TORCH_API bool GenerateGuard(
    Node* tensorexpr_graph_node,
    bool add_composed_op = false);

TORCH_API void runTensorExprDynamicGroup(const Code& code, Stack& stack);

enum class StrideInput {
  // Tensors natively store whether they are contiguous or not as a property
  // this makes it faster to query `is_contiguous` or
  // `is_contiguous(memory_format=channels_last)`
  // than looping through the sizes/strides yourself
  // For tensors with these properties, we only store one value:
  TENSOR_CONT,
  TENSOR_CONT_CHANNELS_LAST,
  // now, we describe other cases, where there is one stride enum
  // per dimension
  S_ONE, // STRIDE_ONE: packed
  S_CONT, // STRIDE_CONTIGUOUS: stride[i + 1] * sizes[i + 1]
  S_TRAN_CONT, // STRIDE_TRANSPOSED_CONTIGUOUS: stride[i-1] * sizes[i-1]
  S_AS_ARG, // STRIDE_AS_ARG: stride passed in as runtime value
};

TORCH_API std::string toString(StrideInput si);
TORCH_API StrideInput strideInputFromString(const std::string& si);

} // namespace jit
} // namespace torch