File: test_util.h

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (120 lines) | stat: -rw-r--r-- 3,129 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#ifndef NOM_TESTS_TEST_UTIL_H
#define NOM_TESTS_TEST_UTIL_H

#include "caffe2/core/common.h"
#include "nomnigraph/Graph/Graph.h"
#include "nomnigraph/Graph/Algorithms.h"
#include "nomnigraph/Representations/NeuralNet.h"
#include "nomnigraph/Converters/Dot.h"

#include <cstdint>
#include <map>

// Empty payload type used as the node data of nom::Graph<TestClass> in the
// generic graph tests. It carries no state; the special members are spelled
// out (defaulted) rather than left implicit.
class TestClass {
 public:
  TestClass() = default;
  ~TestClass() = default;
};

struct NNEquality {
  static bool equal(
      const typename nom::repr::NNGraph::NodeRef& a,
      const typename nom::repr::NNGraph::NodeRef& b) {
    if (
        !nom::repr::nn::is<nom::repr::NeuralNetOperator>(a) ||
        !nom::repr::nn::is<nom::repr::NeuralNetOperator>(b)) {
      return false;
    }
    auto a_ = nom::repr::nn::get<nom::repr::NeuralNetOperator>(a);
    auto b_ = nom::repr::nn::get<nom::repr::NeuralNetOperator>(b);

    bool sameKind = a_->getKind() == b_->getKind();
    if (sameKind && a_->getKind() == nom::repr::NeuralNetOperator::NNKind::GenericOperator) {
      return a_->getName() == b_->getName();
    }
    return sameKind;
  }
};

// Very simple random number generator used to generate platform independent
// random test data.
//
// It is a 32-bit linear congruential generator:
//   state' = (A * state + C) mod 2^32
// The state and the arithmetic are pinned to exactly 32 bits (std::uint32_t)
// so the produced sequence does not depend on the width of `unsigned int`
// or `int` on the target platform.
class TestRandom {
 public:
  /// @param seed initial generator state; only its low 32 bits are kept.
  TestRandom(unsigned int seed) : seed_(static_cast<std::uint32_t>(seed)) {}

  /// Advances the generator by one step.
  /// @return the new 32-bit state.
  unsigned int nextInt() {
    // Multiply in 64 bits so no promotion to (signed) int can overflow, then
    // truncate back to 32 bits: the cast is the "mod 2^32" step.
    seed_ = static_cast<std::uint32_t>(
        static_cast<std::uint64_t>(A) * seed_ + C);
    return seed_;
  }

 private:
  static constexpr std::uint32_t A = 1103515245u;
  static constexpr std::uint32_t C = 12345u;
  std::uint32_t seed_;
};

/** Builds and returns the small string-labelled test graph drawn below. Each
 * node's data is the label shown in its box; the definition lives out of line.
 * Our test graph looks like this:
 *           +-------+
 *           | entry |
 *           +-------+
 *             |
 *             |
 *             v
 *           +-------+
 *           |   1   |
 *           +-------+
 *             |
 *             |
 *             v
 * +---+     +-------+
 * | 4 | <-- |   2   |
 * +---+     +-------+
 *   |         |
 *   |         |
 *   |         v
 *   |       +-------+
 *   |       |   3   |
 *   |       +-------+
 *   |         |
 *   |         |
 *   |         v
 *   |       +-------+
 *   +-----> |   6   |
 *           +-------+
 *             |
 *             |
 *             v
 * +---+     +-------+
 * | 5 | --> |   7   |
 * +---+     +-------+
 *             |
 *             |
 *             v
 *           +-------+
 *           | exit  |
 *           +-------+
 *
 * Here is the code used to generate the DOT file for it (the node's string
 * data is used as its DOT "label" attribute):
 *
 *  auto str = nom::converters::convertToDotString(&graph,
 *    [](nom::Graph<std::string>::NodeRef node) {
 *      std::map<std::string, std::string> labelMap;
 *      labelMap["label"] = node->data();
 *      return labelMap;
 *    });
 */
TORCH_API nom::Graph<std::string> createGraph();

// Returns another string-labelled test graph. Judging by its name it contains
// a cycle; the exact topology is only in the out-of-line definition.
// TODO(review): confirm the shape against the definition.
TORCH_API nom::Graph<std::string> createGraphWithCycle();

// The three printers below map a node/edge to a string->string attribute map,
// the same return type as the label lambda shown in the comment above;
// presumably they are meant to be passed to nom::converters::convertToDotString.
// NOTE(review): unlike the neighbouring declarations these three are not
// marked TORCH_API; if they must be visible across a DLL boundary (e.g. on
// Windows) they may need it — confirm.

// Attribute map for a basic-block node of an NNCFGraph.
std::map<std::string, std::string> BBPrinter(typename nom::repr::NNCFGraph::NodeRef node);

// Attribute map for an edge of an NNCFGraph.
std::map<std::string, std::string> cfgEdgePrinter(typename nom::repr::NNCFGraph::EdgeRef edge);

// Attribute map for a node of an NNGraph.
std::map<std::string, std::string> NNPrinter(typename nom::repr::NNGraph::NodeRef node);

// Takes the graph by mutable reference and returns a node handle; presumably
// it adds a fresh TestClass node to `g` — confirm in the definition.
TORCH_API nom::Graph<TestClass>::NodeRef createTestNode(
    nom::Graph<TestClass>& g);

// String->string attribute map for a node of a nom::Graph<TestClass>
// (presumably used as a DOT printer, like the printers above).
TORCH_API std::map<std::string, std::string> TestNodePrinter(
    nom::Graph<TestClass>::NodeRef node);
#endif // NOM_TESTS_TEST_UTIL_H