#include <torch/csrc/utils/tensor_flatten.h>

#include <map>
#include <unordered_map>

namespace torch {
namespace utils {

using namespace at;
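
// Partition `tensors` into groups of tensors that share the same type, each
// group staying under `size_limit` bytes. With fine_grained == false, every
// type accumulates independently and its group is flushed once it exceeds the
// limit on its own. With fine_grained == true, the limit applies to the
// running total across all types, and all open groups are flushed as soon as
// it is reached.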
std::vector<TensorGroup> take_tensors(
    TensorList tensors,
    size_t size_limit,
    bool fine_grained) {
  std::vector<TensorGroup> results;
  // an overapproximation, but at least we won't have to copy stuff around
  results.reserve(tensors.size());
  std::map<int64_t, TensorGroup> groups;
  size_t cur_group_size = 0;
  for (const auto& tensor : tensors) {
    size_t tensor_size = 0;
    if (tensor.is_sparse()) {
      const auto& indices = tensor._indices();
      const auto& values = tensor._values();
      tensor_size = indices.numel() * indices.element_size() +
          values.numel() * values.element_size();
    } else {
      tensor_size = tensor.numel() * tensor.element_size();
    }

    auto& type_group = groups[type_id(tensor)];
    type_group.tensors.push_back(tensor);
    if (fine_grained) {
      cur_group_size += tensor_size;
      // Regardless of the type, the current total size exceeds the limit.
      if (cur_group_size >= size_limit) {
        // Spill all types to separate groups in results.
        for (auto& entry : groups) {
          auto& group = entry.second;
          results.emplace_back(std::move(group));
        }
        cur_group_size = 0;
        groups.clear();
      }
    } else {
      type_group.size += tensor_size;
      if (type_group.size >= size_limit) {
        // This type's group exceeds the limit on its own; flush it to results
        // and leave a fresh (empty) group for the type in the map.
        results.emplace_back();
        std::swap(results.back(), type_group);
      }
    }
  }
  // End case. Look for any remaining groups and return them.
  for (auto& entry : groups) {
    auto& group = entry.second;
    if (group.tensors.empty()) {
      continue;
    }
    results.emplace_back(std::move(group));
  }
  return results;
}
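
// Reorder `tensors` in place so that the sequence of tensor types matches the
// reference `order` list, keeping the original relative order of tensors that
// share a type. Both lists must hold the same number of tensors of each type.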
void reorder_tensors_like(std::vector<Tensor>& tensors, TensorList order) {
  AT_ASSERT(tensors.size() == order.size());
  std::unordered_map<size_t, std::vector<size_t>> type_id_to_indices;
  for (size_t i = 0, num_tensors = tensors.size(); i < num_tensors; ++i)
    type_id_to_indices[type_id(tensors[i])].push_back(i);

  std::unordered_map<size_t, size_t> type_id_to_type_used;
  std::vector<Tensor> ordered_tensors;
  ordered_tensors.reserve(tensors.size());
  for (auto& tmpl_tensor : order) {
    size_t tmpl_type_id = type_id(tmpl_tensor);
    auto& indices = type_id_to_indices[tmpl_type_id];
    auto& used = type_id_to_type_used[tmpl_type_id];
    ordered_tensors.push_back(tensors[indices[used++]]);
  }
  std::swap(tensors, ordered_tensors);
}
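
// Helpers that project out the COO components of a sparse tensor.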
namespace {
at::Tensor get_indices(const at::Tensor& t) {
  return t._indices();
}

at::Tensor get_values(const at::Tensor& t) {
  return t._values();
}
} // namespace
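
// Flatten a list of sparse COO tensors into two contiguous dense buffers:
// one holding all of their indices and one holding all of their values.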
std::pair<at::Tensor, at::Tensor> flatten_sparse_tensors(
    at::TensorList tensors) {
  auto flat_indices = utils::flatten_dense_tensors(fmap(tensors, &get_indices));
  auto flat_values = utils::flatten_dense_tensors(fmap(tensors, &get_values));
  return std::make_pair(flat_indices, flat_values);
}
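
// Inverse of flatten_sparse_tensors: split the flat indices/values buffers
// back up, using `tensors` as references for shapes and coalesced-ness, and
// rebuild each sparse tensor.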
std::vector<at::Tensor> unflatten_sparse_tensors(
    const at::Tensor& flat_indices,
    const at::Tensor& flat_values,
    at::TensorList tensors) {
  if (tensors.size() == 0)
    return {};

  auto indices =
      utils::unflatten_dense_tensors(flat_indices, fmap(tensors, &get_indices));
  auto values =
      utils::unflatten_dense_tensors(flat_values, fmap(tensors, &get_values));

  std::vector<at::Tensor> outputs;
  outputs.reserve(tensors.size());
  for (size_t i = 0, num_tensors = tensors.size(); i < num_tensors; ++i) {
    auto& ref_t = tensors[i];
    auto t =
        at::_sparse_coo_tensor_unsafe(indices[i], values[i], ref_t.sizes());
    outputs.emplace_back(t._coalesced_(ref_t.is_coalesced()));
  }
  return outputs;
}
} // namespace utils
} // namespace torch