File: main.cpp

package info (click to toggle)
miopen 6.4.3+dfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 66,788 kB
  • sloc: cpp: 300,511; lisp: 29,731; ansic: 2,683; sh: 471; python: 323; makefile: 155
file content (189 lines) | stat: -rw-r--r-- 6,441 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
#include <sqlite3.h>

#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <functional>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_map>

std::unique_ptr<sqlite3, int (*)(sqlite3*)> OpenDb(const char* filename, int flags)
{
    sqlite3* db;
    if(sqlite3_open_v2(filename, &db, flags, nullptr) != SQLITE_OK)
        abort();
    if(db == nullptr)
        abort();
    return {db, &sqlite3_close_v2};
}

/// Compiles exactly one SQL statement from @p sql against connection @p db.
///
/// If compilation fails (or yields no statement), the SQLite error and the offending SQL
/// are written to stderr and the process aborts. If @p sql contains anything beyond the
/// first statement, the leftover text is reported and the process aborts as well.
/// The returned handle finalizes the statement with sqlite3_finalize when destroyed.
std::unique_ptr<sqlite3_stmt, int (*)(sqlite3_stmt*)> PrepareStatement(sqlite3* db,
                                                                       const std::string& sql)
{
    sqlite3_stmt* prepared = nullptr;
    const char* rest       = nullptr;

    const int rc =
        sqlite3_prepare_v2(db, sql.c_str(), static_cast<int>(sql.length()), &prepared, &rest);
    const bool compiled = (rc == SQLITE_OK) && (prepared != nullptr);

    if(!compiled)
    {
        std::cerr << "Error while preparing SQL statement: " << sqlite3_errmsg(db) << std::endl;
        std::cerr << "Statement: {" << sql << "}" << std::endl;
        abort();
    }

    // sqlite3_prepare_v2 stops after the first statement; `rest` must sit on the terminator.
    const char* const sql_end = sql.data() + sql.size();
    if(rest == sql_end)
        return {prepared, &sqlite3_finalize};

    std::cerr << "Statement leftover: {" << rest << "}" << std::endl;
    abort();
}

/// Description of one convolution problem, as stored in the `config` table.
///
/// The Visit/VisitAll helpers enumerate every field together with its column name in one
/// fixed order. GetFieldNames() (the SQL column list), the row reader in main() and
/// Serialize() (the output key) all iterate in that same order, so they stay consistent.
struct ProblemConfig
{
    // Default member initializers: without them a default-initialized ProblemConfig
    // (e.g. `ProblemConfig problem;`) holds indeterminate integers, and reading a field
    // that was never filled in would be undefined behavior. The type remains an aggregate.
    int64_t in_d{}, in_h{}, in_w{};
    int64_t fil_d{}, fil_h{}, fil_w{};
    int64_t pad_d{}, pad_h{}, pad_w{};
    int64_t conv_stride_d{}, conv_stride_h{}, conv_stride_w{};
    int64_t dilation_d{}, dilation_h{}, dilation_w{};
    int64_t spatial_dim{}, out_channels{}, in_channels{}, batchsize{}, group_count{}, bias{};
    std::string layout, data_type, direction;

    /// Calls @p f(field, column_name) for every integer field, in column order.
    template <class Self>
    static void Visit(Self&& self, std::function<void(int64_t&, std::string)> f)
    {
        // The column names match the driver command line argument names
        f(self.spatial_dim, "spatial_dim");
        f(self.in_channels, "in_channels");
        f(self.in_h, "in_h");
        f(self.in_w, "in_w");
        f(self.in_d, "in_d");
        f(self.fil_h, "fil_h");
        f(self.fil_w, "fil_w");
        f(self.fil_d, "fil_d");
        f(self.out_channels, "out_channels");
        f(self.batchsize, "batchsize");
        f(self.pad_h, "pad_h");
        f(self.pad_w, "pad_w");
        f(self.pad_d, "pad_d");
        f(self.conv_stride_h, "conv_stride_h");
        f(self.conv_stride_w, "conv_stride_w");
        f(self.conv_stride_d, "conv_stride_d");
        f(self.dilation_h, "dilation_h");
        f(self.dilation_w, "dilation_w");
        f(self.dilation_d, "dilation_d");
        f(self.bias, "bias");
        f(self.group_count, "group_count");
    }

    /// Calls @p f(field, column_name) for every string field, in column order.
    template <class Self>
    static void Visit(Self&& self, std::function<void(std::string&, std::string)> f)
    {
        f(self.layout, "layout");
        f(self.data_type, "data_type");
        f(self.direction, "direction");
    }

    /// Visits all integer fields followed by all string fields with one generic visitor.
    template <class Self, class Visitor>
    static void VisitAll(Self&& self, const Visitor& f)
    {
        Visit(std::forward<Self>(self), [&](int64_t& value, std::string name) { f(value, name); });
        Visit(std::forward<Self>(self),
              [&](std::string& value, std::string name) { f(value, name); });
    }

    /// Comma-separated column names in visiting order (computed once and cached).
    [[nodiscard]] static const std::string& GetFieldNames()
    {
        static const std::string value = []() {
            std::ostringstream ss;
            ProblemConfig::VisitAll(ProblemConfig{}, [&](auto&&, auto name) {
                if(ss.tellp() != 0)
                    ss << ", ";
                ss << name;
            });
            return ss.str();
        }();
        return value;
    }

    /// All field values in visiting order, joined with 'x' (used as the output key).
    [[nodiscard]] std::string Serialize()
    {
        std::ostringstream ss;
        ProblemConfig::VisitAll(*this, [&](auto&& value, auto&&) {
            if(ss.tellp() != 0)
                ss << "x";
            ss << value;
        });
        return ss.str();
    }
};

int main(int argn, char** args)
{
    if(argn < 2 || argn > 3)
    {
        std::cerr << "Usage:" << std::endl;
        std::cerr << args[0] << " input_path [output_path]" << std::endl;
        std::cerr << "input_path - path to the input file, expected to be sqlite3 db." << std::endl;
        std::cerr << "output_path - optional path to the output file. Existing file would be "
                     "replaced. Defaults to the input_path with .txt appended to the end"
                  << std::endl;
    }

    const std::string in_filename  = args[1];
    const std::string out_filename = argn > 2 ? args[2] : (in_filename + ".txt");
    constexpr const int db_flags   = SQLITE_OPEN_READONLY;

    const auto select_query = "SELECT solver, params, " + ProblemConfig::GetFieldNames() +
                              " FROM perf_db "
                              "INNER JOIN config ON perf_db.config = config.id";

    const auto db   = OpenDb(in_filename.c_str(), db_flags);
    const auto stmt = PrepareStatement(db.get(), select_query);
    auto db_content = std::unordered_map<std::string, std::string>{};

    for(int step_result = sqlite3_step(stmt.get()); step_result != SQLITE_DONE;
        step_result     = sqlite3_step(stmt.get()))
    {
        if(step_result == SQLITE_BUSY)
        {
            sqlite3_sleep(10);
            continue;
        }

        if(step_result == SQLITE_ERROR)
        {
            std::cerr << sqlite3_errmsg(db.get()) << std::endl;
            abort();
        }

        if(step_result == SQLITE_MISUSE)
            abort();

        int col             = 0;
        std::string solver  = reinterpret_cast<const char*>(sqlite3_column_text(stmt.get(), col++));
        std::string perfcgf = reinterpret_cast<const char*>(sqlite3_column_text(stmt.get(), col++));
        ProblemConfig problem;

        ProblemConfig::VisitAll(problem, [&](auto& value, auto) {
            if constexpr(std::is_convertible_v<decltype(value), int>)
                value = sqlite3_column_int(stmt.get(), col++);
            else if constexpr(std::is_convertible_v<decltype(value), std::string>)
                value = reinterpret_cast<const char*>(sqlite3_column_text(stmt.get(), col++));
            else
                static_assert(false, "unsupported type");
        });

        if(sqlite3_column_count(stmt.get()) != col)
            abort();

        auto& record = db_content[problem.Serialize()];
        if(!record.empty())
            record.append(";");
        record.append(solver).append(":").append(perfcgf);
    }

    auto out = std::ofstream{out_filename};
    for(const auto& line : db_content)
        out << line.first << "=" << line.second << std::endl;

    return 0;
}