File: xnn_compiler.cpp

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (121 lines) | stat: -rw-r--r-- 4,312 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
// Copyright (c) Meta Platforms, Inc. and affiliates.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <caffe2/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h>
#include <torch/csrc/jit/backends/xnnpack/serialization/schema_generated.h>

#include <ATen/Utils.h>

namespace torch {
namespace jit {
namespace xnnpack {
namespace delegate {

void XNNCompiler::compileModel(
    const void* buffer_pointer,
    size_t num_bytes,
    XNNExecutor* executor) {
  auto output_min = -std::numeric_limits<float>::infinity();
  auto output_max = std::numeric_limits<float>::infinity();

  auto flatbuffer_graph = fb_xnnpack::GetXNNGraph(buffer_pointer);
  // initialize xnnpack
  xnn_status status = xnn_initialize(/*allocator =*/nullptr);
  TORCH_CHECK(xnn_status_success == status, "Failed to initialize xnnpack");

  // create xnnpack subgraph
  xnn_subgraph_t subgraph_ptr = nullptr;
  status = xnn_create_subgraph(
      /*external_value_ids=*/flatbuffer_graph->num_externs(),
      /*flags=*/0,
      &subgraph_ptr);
  TORCH_CHECK(xnn_status_success == status, "Failed to create xnn subgraph");

  // mapping from old ids to new created value ids
  // The old ids that were serialied were generated AoT, since
  // we are re-defining tensor values, the defined IDs could be
  // different from the ones generated AoT, as a result, we need
  // a new mapping from the old ids to the newly created ones
  std::unordered_map<uint32_t, uint32_t> remapped_ids;

  for (auto value : *flatbuffer_graph->xvalues()) {
    switch (value->xvalue_type()) {
      case fb_xnnpack::XValueUnion::XNNTensorValue: {
        auto tensor_value = value->xvalue_as_XNNTensorValue();

        std::vector<size_t> dims_data;
        for (auto dim : *tensor_value->dims()) {
          dims_data.push_back(static_cast<size_t>(dim));
        }

        uint32_t id = XNN_INVALID_VALUE_ID;
        const auto& constant_buffer = *flatbuffer_graph->constant_buffer();
        auto buffer_idx = tensor_value->constant_buffer_idx();
        const auto buffer_ptr = buffer_idx == 0
            ? nullptr
            : constant_buffer[buffer_idx]->storage()->data();
        status = xnn_define_tensor_value(
            /*subgraph=*/subgraph_ptr,
            /*datatype=*/xnn_datatype_fp32,
            /*num_dims=*/tensor_value->num_dims(),
            /*dims=*/dims_data.data(),
            /*data=*/buffer_ptr,
            /*external_id=*/tensor_value->external_id(),
            /*flags=*/tensor_value->flags(),
            /*id_out=*/&id);
        TORCH_CHECK(
            status == xnn_status_success,
            "Failed to define tensor values in graph")
        // map serialized id to newly generated id
        remapped_ids.emplace(std::make_pair(tensor_value->id_out(), id));
        break;
      }
      default: {
        TORCH_CHECK(false, "Unhandled value type found in deserialization");
      }
    }
  }

  for (auto node : *flatbuffer_graph->xnodes()) {
    switch (node->xnode_type()) {
      case fb_xnnpack::XNodeUnion::XNNAdd: {
        auto graph_node = node->xnode_as_XNNAdd();
        status = xnn_define_add2(
            subgraph_ptr,
            output_min,
            output_max,
            remapped_ids.at(graph_node->input1_id()),
            remapped_ids.at(graph_node->input2_id()),
            remapped_ids.at(graph_node->output_id()),
            graph_node->flags());
        TORCH_CHECK(status == xnn_status_success, "Failed to create add node")
        break;
      }
      default:
        TORCH_CHECK(false, "Unhandled node type found in deserialization");
    }
  }

  xnn_runtime_t runtime_ptr = nullptr;
  status = xnn_create_runtime_v2(subgraph_ptr, nullptr, 0, &runtime_ptr);
  TORCH_CHECK(xnn_status_success == status);

  executor->runtime_ =
      std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)>(
          runtime_ptr, xnn_delete_runtime);

  for (auto old_id : *flatbuffer_graph->input_ids()) {
    executor->input_ids_.emplace_back(remapped_ids.at(old_id));
  }

  for (auto old_id : *flatbuffer_graph->output_ids()) {
    executor->output_ids_.emplace_back(remapped_ids.at(old_id));
  }
};

} // namespace delegate
} // namespace xnnpack
} // namespace jit
} // namespace torch