File: value_refinement_utils.h

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (82 lines) | stat: -rw-r--r-- 2,669 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#pragma once

#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/peephole_list_idioms.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/utils/memory.h>

namespace torch {
namespace jit {

// Refinement from a Value of List type -> the known length of that list.
// If a mapping `list Value* -> len` is present in a block, that list is
// guaranteed to have exactly `len` elements there.
// TODO: a vector may be faster than a hash map here.
using ListRefinement = std::unordered_map<Value*, int64_t>;

// Combines two list-length refinements with intersection semantics and
// returns the result as a new map; neither input is modified (both are const).
// NOTE(review): the exact combination rule lives in the definition, which is
// not visible here — presumably only Values present in both maps with the same
// length survive. Confirm before relying on it.
TORCH_API auto intersectRefinements(
    const ListRefinement& ref1,
    const ListRefinement& ref2) -> ListRefinement;

// Combines two list-length refinements with union semantics and returns the
// result as a new map; neither input is modified (both are const).
// NOTE(review): the definition is not visible here — presumably entries from
// both maps are merged, and the behavior when both contain the same Value with
// different lengths must be checked in the implementation.
TORCH_API auto unionRefinements(
    const ListRefinement& ref1,
    const ListRefinement& ref2) -> ListRefinement;

// Represents the refinement information that can be carried on a boolean
struct BooleanRefinementMapping {
  BooleanRefinementMapping(
      ListRefinement true_refine,
      ListRefinement false_refine)
      : true_refine_(std::move(true_refine)),
        false_refine_(std::move(false_refine)){};
  BooleanRefinementMapping() = default; // empty

  static BooleanRefinementMapping FalseRefinements(
      ListRefinement false_refine) {
    return BooleanRefinementMapping({}, std::move(false_refine));
  }

  static BooleanRefinementMapping TrueRefinements(ListRefinement true_refine) {
    return BooleanRefinementMapping(std::move(true_refine), {});
  }

  BooleanRefinementMapping intersectBooleanRefinementMapping(
      BooleanRefinementMapping& other) {
    return BooleanRefinementMapping(
        intersectRefinements(true_refine_, other.true_refine()),
        intersectRefinements(false_refine_, other.false_refine()));
  }

  ListRefinement& true_refine() {
    return true_refine_;
  }

  ListRefinement& false_refine() {
    return false_refine_;
  }

 private:
  ListRefinement true_refine_;
  ListRefinement false_refine_;
};

// Joins the list-length refinements associated with the branches of an `if`
// node. Every container parameter is taken by non-const reference, so the
// definition may update any of them.
// NOTE(review): parameter roles below are inferred from names/types only —
// confirm against the definition:
//   if_node                 - presumably the prim::If node being processed
//   throwing_blocks         - set of blocks treated as throwing
//   curr_block_refinements  - refinements of the enclosing block
//   true_block_refinements  - refinements gathered in the true branch
//   false_block_refinements - refinements gathered in the false branch
//   info                    - boolean Value -> BooleanRefinementMapping table
TORCH_API auto joinIfRefinements(
    Node* if_node,
    std::unordered_set<Block*>& throwing_blocks,
    ListRefinement& curr_block_refinements,
    ListRefinement& true_block_refinements,
    ListRefinement& false_block_refinements,
    std::unordered_map<Value*, BooleanRefinementMapping>& info) -> void;

// Processes node `n`: adds blocks to `throwing_blocks` and propagates
// refinements through boolean comparisons, recording them in `info`.
// NOTE(review): the meaning of the bool result is not visible from this
// header (presumably "n was handled here") — confirm against the definition.
// The spelling "Refinent" is part of the exported name; do not "fix" it
// without updating every caller.
TORCH_API auto handleCommonRefinentOperators(
    Node* n,
    std::unordered_set<Block*>& throwing_blocks,
    std::unordered_map<Value*, BooleanRefinementMapping>& info) -> bool;

} // namespace jit
} // namespace torch