File: tensor_flatten.cpp

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (127 lines) | stat: -rw-r--r-- 3,828 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#include <torch/csrc/utils/tensor_flatten.h>

#include <map>
#include <unordered_map>

namespace torch {
namespace utils {

using namespace at;

// Buckets `tensors` into groups keyed by type_id(tensor), emitting a group
// once the tracked byte size reaches `size_limit`.
//
// Parameters:
//   tensors      - tensors to bucket, processed in order.
//   size_limit   - byte threshold. A bucket is emitted as soon as its tracked
//                  size is >= size_limit, so the tensor that crosses the
//                  threshold is included in the emitted bucket.
//   fine_grained - when true, a single running byte total is kept across all
//                  types and every pending per-type group is emitted whenever
//                  that total reaches the limit (the per-group `size` field is
//                  not accumulated in this mode). When false, each type
//                  accumulates its own `size` and is emitted independently.
//
// Returns the emitted groups, followed by any non-empty groups still pending
// once all tensors have been consumed.
std::vector<TensorGroup> take_tensors(
    TensorList tensors,
    size_t size_limit,
    bool fine_grained) {
  std::vector<TensorGroup> results;
  // An overapproximation (at most one group per tensor), but it guarantees the
  // vector never reallocates while groups are being appended.
  results.reserve(tensors.size());
  std::map<int64_t, TensorGroup> groups;
  size_t cur_group_size = 0;

  for (const auto& tensor : tensors) {
    size_t tensor_size = 0;
    if (tensor.is_sparse()) {
      // A sparse tensor's footprint is the sum of its two component buffers.
      // Each buffer must be sized with its OWN element size: the values
      // buffer's element width may differ from the indices buffer's.
      const auto& indices = tensor._indices();
      const auto& values = tensor._values();
      tensor_size = indices.numel() * indices.element_size() +
          values.numel() * values.element_size();
    } else {
      tensor_size = tensor.numel() * tensor.element_size();
    }

    auto& type_group = groups[type_id(tensor)];
    type_group.tensors.push_back(tensor);

    if (fine_grained) {
      cur_group_size += tensor_size;
      // Regardless of the type, the current total size has reached the limit.
      if (cur_group_size >= size_limit) {
        // Spill every pending type into its own group in the results.
        for (auto& entry : groups) {
          auto& group = entry.second;
          results.emplace_back(std::move(group));
        }
        cur_group_size = 0;
        // Drop the moved-from groups so the next tensor starts fresh.
        groups.clear();
      }
    } else {
      type_group.size += tensor_size;
      if (type_group.size >= size_limit) {
        // Swap the full group into results; this leaves a freshly
        // default-constructed group in the map for this type.
        results.emplace_back();
        std::swap(results.back(), type_group);
      }
    }
  }
  // End case. Look for any remaining groups and return them.
  for (auto& entry : groups) {
    auto& group = entry.second;
    // Groups that were just swapped out above are left empty; skip them.
    if (group.tensors.empty()) {
      continue;
    }
    results.emplace_back(std::move(group));
  }
  return results;
}

// Reorders `tensors` in place so that, position by position, each tensor has
// the same type_id as the tensor at that position in `order`.
//
// Tensors sharing a type_id keep their original relative order: the k-th
// tensor of a given type in `order` receives the k-th tensor of that type
// from `tensors`.
//
// Preconditions (checked with AT_ASSERT):
//   - `tensors` and `order` have the same length;
//   - for every type_id, `order` does not request more tensors of that type
//     than `tensors` contains.
void reorder_tensors_like(std::vector<Tensor>& tensors, TensorList order) {
  AT_ASSERT(tensors.size() == order.size());
  // For each type id, the positions in `tensors` holding that type, ascending.
  std::unordered_map<size_t, std::vector<size_t>> type_id_to_indices;
  for (size_t i = 0, num_tensors = tensors.size(); i < num_tensors; ++i) {
    type_id_to_indices[type_id(tensors[i])].push_back(i);
  }

  // For each type id, how many of its tensors have already been placed.
  std::unordered_map<size_t, size_t> type_id_to_type_used;
  std::vector<Tensor> ordered_tensors;
  ordered_tensors.reserve(tensors.size());
  for (auto& tmpl_tensor : order) {
    size_t tmpl_type_id = type_id(tmpl_tensor);
    auto& indices = type_id_to_indices[tmpl_type_id];
    auto& used = type_id_to_type_used[tmpl_type_id];
    // Without this check, a type mismatch between `tensors` and `order`
    // would index past the end of `indices` (undefined behavior).
    AT_ASSERT(used < indices.size());
    ordered_tensors.push_back(tensors[indices[used++]]);
  }
  std::swap(tensors, ordered_tensors);
}

// File-local accessors. They wrap Tensor member calls in free functions so
// they can be passed by address to fmap() in the sparse flatten/unflatten
// routines below.
namespace {

// Returns the result of `t._indices()`.
at::Tensor get_indices(const at::Tensor& t) {
  return t._indices();
}

// Returns the result of `t._values()`.
at::Tensor get_values(const at::Tensor& t) {
  return t._values();
}

} // namespace

// Flattens a list of sparse tensors into a pair of tensors:
//   first  - flatten_dense_tensors applied to every tensor's _indices()
//   second - flatten_dense_tensors applied to every tensor's _values()
std::pair<at::Tensor, at::Tensor> flatten_sparse_tensors(
    at::TensorList tensors) {
  // Gather and flatten the indices first, then the values.
  const auto index_parts = fmap(tensors, &get_indices);
  auto packed_indices = utils::flatten_dense_tensors(index_parts);

  const auto value_parts = fmap(tensors, &get_values);
  auto packed_values = utils::flatten_dense_tensors(value_parts);

  return {std::move(packed_indices), std::move(packed_values)};
}

// Inverse of flatten_sparse_tensors: splits `flat_indices` and `flat_values`
// back into per-tensor pieces, using `tensors` as the layout reference, and
// rebuilds one sparse COO tensor per reference tensor.
//
// Each output takes its sizes and its coalesced flag from the reference
// tensor at the same position. Returns an empty vector when `tensors` is
// empty.
std::vector<at::Tensor> unflatten_sparse_tensors(
    const at::Tensor& flat_indices,
    const at::Tensor& flat_values,
    at::TensorList tensors) {
  const size_t count = tensors.size();
  if (count == 0) {
    return {};
  }

  // Split each flat buffer using the reference tensors' component tensors.
  const auto index_chunks =
      utils::unflatten_dense_tensors(flat_indices, fmap(tensors, &get_indices));
  const auto value_chunks =
      utils::unflatten_dense_tensors(flat_values, fmap(tensors, &get_values));

  std::vector<at::Tensor> rebuilt;
  rebuilt.reserve(count);
  for (size_t pos = 0; pos < count; ++pos) {
    const auto& like = tensors[pos];
    auto sparse = at::_sparse_coo_tensor_unsafe(
        index_chunks[pos], value_chunks[pos], like.sizes());
    // Carry over the reference tensor's coalesced flag.
    rebuilt.emplace_back(sparse._coalesced_(like.is_coalesced()));
  }
  return rebuilt;
}

} // namespace utils
} // namespace torch