1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189
|
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <functional>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>

#include <sqlite3.h>
std::unique_ptr<sqlite3, int (*)(sqlite3*)> OpenDb(const char* filename, int flags)
{
sqlite3* db;
if(sqlite3_open_v2(filename, &db, flags, nullptr) != SQLITE_OK)
abort();
if(db == nullptr)
abort();
return {db, &sqlite3_close_v2};
}
/// Compiles `sql` into a prepared statement bound to connection `db`.
///
/// The process aborts, after printing diagnostics to stderr, when:
///  - sqlite3_prepare_v2 reports an error or yields no statement (e.g. empty SQL), or
///  - `sql` has text left over after the first statement, which would otherwise be
///    silently ignored.
/// The returned smart pointer finalizes the statement when destroyed.
std::unique_ptr<sqlite3_stmt, int (*)(sqlite3_stmt*)> PrepareStatement(sqlite3* db,
                                                                       const std::string& sql)
{
    sqlite3_stmt* prepared = nullptr;
    const char* unparsed = nullptr;
    const int rc = sqlite3_prepare_v2(
        db, sql.c_str(), static_cast<int>(sql.length()), &prepared, &unparsed);

    const bool failed = (rc != SQLITE_OK) || (prepared == nullptr);
    if(failed)
    {
        std::cerr << "Error while preparing SQL statement: " << sqlite3_errmsg(db) << std::endl;
        std::cerr << "Statement: {" << sql << "}" << std::endl;
        abort();
    }

    // Everything up to the terminating NUL must have been consumed by the single statement.
    const char* const sql_end = sql.c_str() + sql.length();
    if(unparsed != sql_end)
    {
        std::cerr << "Statement leftover: {" << unparsed << "}" << std::endl;
        abort();
    }

    return {prepared, &sqlite3_finalize};
}
/// Description of one convolution problem.
///
/// The static Visit/VisitAll helpers enumerate every field together with its column name
/// in a fixed order. That order defines both the comma-separated column list returned by
/// GetFieldNames() and the token order of the key produced by Serialize().
struct ProblemConfig
{
    int64_t in_d, in_h, in_w;
    int64_t fil_d, fil_h, fil_w;
    int64_t pad_d, pad_h, pad_w;
    int64_t conv_stride_d, conv_stride_h, conv_stride_w;
    int64_t dilation_d, dilation_h, dilation_w;
    int64_t spatial_dim, out_channels, in_channels, batchsize, group_count, bias;
    std::string layout, data_type, direction;

    /// Calls `visit_field(field, column_name)` for each integer field, in column order.
    template <class Self>
    static void Visit(Self&& cfg, std::function<void(int64_t&, std::string)> visit_field)
    {
        struct Column
        {
            int64_t ProblemConfig::*member;
            const char* name;
        };
        // The column names match the driver command line argument names.
        static constexpr Column kColumns[] = {
            {&ProblemConfig::spatial_dim, "spatial_dim"},
            {&ProblemConfig::in_channels, "in_channels"},
            {&ProblemConfig::in_h, "in_h"},
            {&ProblemConfig::in_w, "in_w"},
            {&ProblemConfig::in_d, "in_d"},
            {&ProblemConfig::fil_h, "fil_h"},
            {&ProblemConfig::fil_w, "fil_w"},
            {&ProblemConfig::fil_d, "fil_d"},
            {&ProblemConfig::out_channels, "out_channels"},
            {&ProblemConfig::batchsize, "batchsize"},
            {&ProblemConfig::pad_h, "pad_h"},
            {&ProblemConfig::pad_w, "pad_w"},
            {&ProblemConfig::pad_d, "pad_d"},
            {&ProblemConfig::conv_stride_h, "conv_stride_h"},
            {&ProblemConfig::conv_stride_w, "conv_stride_w"},
            {&ProblemConfig::conv_stride_d, "conv_stride_d"},
            {&ProblemConfig::dilation_h, "dilation_h"},
            {&ProblemConfig::dilation_w, "dilation_w"},
            {&ProblemConfig::dilation_d, "dilation_d"},
            {&ProblemConfig::bias, "bias"},
            {&ProblemConfig::group_count, "group_count"},
        };
        for(const auto& [member, name] : kColumns)
            visit_field(cfg.*member, name);
    }

    /// Calls `visit_field(field, column_name)` for each string field, in column order.
    template <class Self>
    static void Visit(Self&& cfg, std::function<void(std::string&, std::string)> visit_field)
    {
        struct Column
        {
            std::string ProblemConfig::*member;
            const char* name;
        };
        static constexpr Column kColumns[] = {
            {&ProblemConfig::layout, "layout"},
            {&ProblemConfig::data_type, "data_type"},
            {&ProblemConfig::direction, "direction"},
        };
        for(const auto& [member, name] : kColumns)
            visit_field(cfg.*member, name);
    }

    /// Visits all integer fields, then all string fields, with one generic `visitor`
    /// callable as `visitor(field&, std::string column_name)`.
    template <class Self, class Visitor>
    static void VisitAll(Self&& cfg, const Visitor& visitor)
    {
        const auto forward_integer = [&](int64_t& field, std::string column) {
            visitor(field, column);
        };
        const auto forward_text = [&](std::string& field, std::string column) {
            visitor(field, column);
        };
        Visit(std::forward<Self>(cfg), forward_integer);
        Visit(std::forward<Self>(cfg), forward_text);
    }

    /// Comma-separated list of all column names in visitation order.
    /// Computed once on first call and cached for the process lifetime.
    [[nodiscard]] static const std::string& GetFieldNames()
    {
        static const std::string joined = [] {
            std::ostringstream out;
            const char* separator = "";
            VisitAll(ProblemConfig{}, [&](auto&&, auto column) {
                out << separator << column;
                separator = ", ";
            });
            return out.str();
        }();
        return joined;
    }

    /// All field values in visitation order, joined with 'x'.
    [[nodiscard]] std::string Serialize()
    {
        std::ostringstream out;
        const char* separator = "";
        VisitAll(*this, [&](auto&& field, auto&&) {
            out << separator << field;
            separator = "x";
        });
        return out.str();
    }
};
int main(int argn, char** args)
{
if(argn < 2 || argn > 3)
{
std::cerr << "Usage:" << std::endl;
std::cerr << args[0] << " input_path [output_path]" << std::endl;
std::cerr << "input_path - path to the input file, expected to be sqlite3 db." << std::endl;
std::cerr << "output_path - optional path to the output file. Existing file would be "
"replaced. Defaults to the input_path with .txt appended to the end"
<< std::endl;
}
const std::string in_filename = args[1];
const std::string out_filename = argn > 2 ? args[2] : (in_filename + ".txt");
constexpr const int db_flags = SQLITE_OPEN_READONLY;
const auto select_query = "SELECT solver, params, " + ProblemConfig::GetFieldNames() +
" FROM perf_db "
"INNER JOIN config ON perf_db.config = config.id";
const auto db = OpenDb(in_filename.c_str(), db_flags);
const auto stmt = PrepareStatement(db.get(), select_query);
auto db_content = std::unordered_map<std::string, std::string>{};
for(int step_result = sqlite3_step(stmt.get()); step_result != SQLITE_DONE;
step_result = sqlite3_step(stmt.get()))
{
if(step_result == SQLITE_BUSY)
{
sqlite3_sleep(10);
continue;
}
if(step_result == SQLITE_ERROR)
{
std::cerr << sqlite3_errmsg(db.get()) << std::endl;
abort();
}
if(step_result == SQLITE_MISUSE)
abort();
int col = 0;
std::string solver = reinterpret_cast<const char*>(sqlite3_column_text(stmt.get(), col++));
std::string perfcgf = reinterpret_cast<const char*>(sqlite3_column_text(stmt.get(), col++));
ProblemConfig problem;
ProblemConfig::VisitAll(problem, [&](auto& value, auto) {
if constexpr(std::is_convertible_v<decltype(value), int>)
value = sqlite3_column_int(stmt.get(), col++);
else if constexpr(std::is_convertible_v<decltype(value), std::string>)
value = reinterpret_cast<const char*>(sqlite3_column_text(stmt.get(), col++));
else
static_assert(false, "unsupported type");
});
if(sqlite3_column_count(stmt.get()) != col)
abort();
auto& record = db_content[problem.Serialize()];
if(!record.empty())
record.append(";");
record.append(solver).append(":").append(perfcgf);
}
auto out = std::ofstream{out_filename};
for(const auto& line : db_content)
out << line.first << "=" << line.second << std::endl;
return 0;
}
|