1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
|
#pragma once
#include <cstdio>
#include <limits>
#include <set>
#include <string>
#include <type_traits>
namespace caffe2 {

// Writes an m x n matrix to "<matrix_name>.mtx" in MatrixMarket dense
// ("array") format, at most once per base name.
//
// Deduplication: the base name is the part of matrix_name after the last '/'.
// Only the first call for a given base name writes a file; later calls with
// the same base name are no-ops, even if they point at a different directory.
// The base name is recorded before the file is opened, so if fopen fails the
// matrix is silently skipped and never retried. The registry is a
// function-local static, so it is separate for each element type T and is
// not protected by any lock (NOTE(review): concurrent callers race on it).
//
// Parameters:
//   m           - number of rows written to the header line.
//   n           - number of columns written to the header line.
//   a           - entries; entry (i, j) is read from a[j * m + i], i.e. the
//                 data is treated as column-major. Must point to at least
//                 m * n elements when m > 0 and n > 0.
//                 NOTE(review): a row-major caller would get the transpose.
//   matrix_name - output path without the ".mtx" extension.
//
// Integral T is written as "integer" entries (full 64-bit range). Any other T
// is written as "real" entries with enough significant digits to round-trip
// the value.
template <typename T>
void StoreMatrixInMatrixMarketFormat(
    int m,
    int n,
    const T* a,
    const std::string& matrix_name) {
  static std::set<std::string> dumped_matrix_names;

  const std::string::size_type slash = matrix_name.rfind('/');
  const std::string base_name = slash == std::string::npos
      ? matrix_name
      : matrix_name.substr(slash + 1);
  if (!dumped_matrix_names.insert(base_name).second) {
    return; // this base name was already dumped (or attempted)
  }

  FILE* fp = std::fopen((matrix_name + ".mtx").c_str(), "w");
  if (!fp) {
    return; // best-effort debug dump: give up quietly
  }

  const bool is_integer = std::is_integral<T>::value;
  // "%%%%" is printf-escaped and emits the literal "%%" banner prefix.
  if (is_integer) {
    std::fprintf(fp, "%%%%MatrixMarket matrix array integer general\n");
  } else {
    std::fprintf(fp, "%%%%MatrixMarket matrix array real general\n");
  }
  std::fprintf(fp, "%d %d\n", m, n);

  // Significant digits needed to round-trip a T through text (9 for float,
  // 17 for double). Values are printed as double, so cap at double's limit.
  // Fall back to float's digit count when numeric_limits<T> says nothing.
  using Limits = std::numeric_limits<T>;
  constexpr int kDoubleDigits = std::numeric_limits<double>::max_digits10;
  constexpr int kRealDigits = (!Limits::is_specialized || Limits::max_digits10 <= 0)
      ? std::numeric_limits<float>::max_digits10
      : (Limits::max_digits10 < kDoubleDigits ? Limits::max_digits10
                                              : kDoubleDigits);

  // MatrixMarket array format lists entries column by column. Because entry
  // (i, j) lives at a[j * m + i], that column-major walk is plain storage
  // order. size_t arithmetic avoids the int overflow of j * m + i once
  // m * n exceeds INT_MAX. Non-positive dimensions write no entries.
  const std::size_t count = (m > 0 && n > 0)
      ? static_cast<std::size_t>(m) * static_cast<std::size_t>(n)
      : 0;
  for (std::size_t k = 0; k < count; ++k) {
    // NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
    const T& value = a[k];
    if (is_integer) {
      // Go through 64-bit types so int64/uint64 entries are not truncated.
      if (std::is_signed<T>::value) {
        std::fprintf(fp, "%lld\n", static_cast<long long>(value));
      } else {
        std::fprintf(fp, "%llu\n", static_cast<unsigned long long>(value));
      }
    } else {
      // "%f" rounded to 6 decimals (1e-8 became 0.000000). "%.*g" keeps
      // every significant digit of T.
      std::fprintf(fp, "%.*g\n", kRealDigits, static_cast<double>(value));
    }
  }

  std::fclose(fp);
}

} // namespace caffe2
|