File: db_throughput.cc

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (104 lines) | stat: -rw-r--r-- 3,256 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
/**
 * Copyright (c) 2016-present, Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cstdio>
#include <thread>
#include <vector>

#include "caffe2/core/db.h"
#include "caffe2/core/init.h"
#include "caffe2/core/timer.h"
#include "caffe2/core/logging.h"

// Command-line flags controlling the throughput benchmark.

// Location of the database to read; passed as the source to CreateDB() or
// DBReader depending on --use_reader.
C10_DEFINE_string(input_db, "", "The input db.");
// Database type string handed to CreateDB() / DBReader alongside --input_db.
C10_DEFINE_string(input_db_type, "", "The input db type.");
// Number of records read per timed round; one throughput line is printed
// after every such batch.
C10_DEFINE_int(report_interval, 1000, "The report interval.");
// Number of timed rounds to run (per thread when --use_reader is set).
C10_DEFINE_int(repeat, 10, "The number to repeat the throughput test.");
// Selects the benchmark path in main(): DBReader when true, raw Cursor
// otherwise.
C10_DEFINE_bool(use_reader, false, "If true, use the reader interface.");
// Number of worker threads spawned by TestThroughputWithReader(); not
// consulted by the cursor-based path.
C10_DEFINE_int(
    num_read_threads,
    1,
    "The number of concurrent reading threads.");

using caffe2::db::Cursor;
using caffe2::db::DB;
using caffe2::db::DBReader;
using caffe2::string;

void TestThroughputWithDB() {
  std::unique_ptr<DB> in_db(caffe2::db::CreateDB(
      FLAGS_input_db_type, FLAGS_input_db, caffe2::db::READ));
  std::unique_ptr<Cursor> cursor(in_db->NewCursor());
  for (int iter_id = 0; iter_id < FLAGS_repeat; ++iter_id) {
    caffe2::Timer timer;
    for (int i = 0; i < FLAGS_report_interval; ++i) {
      string key = cursor->key();
      string value = cursor->value();
      //VLOG(1) << "Key " << key;
      cursor->Next();
      if (!cursor->Valid()) {
        cursor->SeekToFirst();
      }
    }
    double elapsed_seconds = timer.Seconds();
    printf(
        "Iteration %03d, took %4.5f seconds, throughput %f items/sec.\n",
        iter_id,
        elapsed_seconds,
        FLAGS_report_interval / elapsed_seconds);
  }
}

void TestThroughputWithReaderWorker(const DBReader* reader, int thread_id) {
  string key, value;
  for (int iter_id = 0; iter_id < FLAGS_repeat; ++iter_id) {
    caffe2::Timer timer;
    for (int i = 0; i < FLAGS_report_interval; ++i) {
      reader->Read(&key, &value);
    }
    double elapsed_seconds = timer.Seconds();
    printf(
        "Thread %03d iteration %03d, took %4.5f seconds, "
        "throughput %f items/sec.\n",
        thread_id,
        iter_id,
        elapsed_seconds,
        FLAGS_report_interval / elapsed_seconds);
  }
}

void TestThroughputWithReader() {
  caffe2::db::DBReader reader(FLAGS_input_db_type, FLAGS_input_db);
  std::vector<std::unique_ptr<std::thread>> reading_threads(
      FLAGS_num_read_threads);
  for (int i = 0; i < reading_threads.size(); ++i) {
    reading_threads[i].reset(new std::thread(
        TestThroughputWithReaderWorker, &reader, i));
  }
  for (int i = 0; i < reading_threads.size(); ++i) {
    reading_threads[i]->join();
  }
}

// Entry point: initializes Caffe2 with the command-line arguments, then runs
// the cursor-based benchmark unless --use_reader asks for the DBReader one.
int main(int argc, char** argv) {
  caffe2::GlobalInit(&argc, &argv);
  if (!FLAGS_use_reader) {
    TestThroughputWithDB();
    return 0;
  }
  TestThroughputWithReader();
  return 0;
}