1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182
|
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef CAFFE2_CORE_NET_ASYNC_TRACING_H_
#define CAFFE2_CORE_NET_ASYNC_TRACING_H_
#include "caffe2/core/common.h"
#include "caffe2/core/net_async_base.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/timer.h"
// Flags used by async-net tracing. They are only declared in this header via
// the C10_DECLARE_* macros.
// NOTE(review): judging by the names, these select the trace output path, the
// net names to trace, and an every-nth sampling rate -- confirm against the
// flag definitions, which are not visible here.
C10_DECLARE_string(caffe2_net_async_tracing_filepath);
C10_DECLARE_string(caffe2_net_async_names_to_trace);
C10_DECLARE_int(caffe2_net_async_tracing_nth);
namespace caffe2 {
namespace tracing {
// One trace record, as passed to Tracer::recordEvent() and
// Tracer::serializeEvent(). Every scalar field has an in-class default
// (-1 / nullptr / false), so a default-constructed event carries no
// indeterminate values.
struct TORCH_API TracerEvent {
  int op_id_ = -1;
  int task_id_ = -1;
  int stream_id_ = -1;

  // Raw, non-owning C strings: the struct neither copies nor frees them, so
  // the pointed-to text has to outlive the event.
  const char* name_ = nullptr;
  const char* category_ = nullptr;

  // Integral timestamp. Initialised with an integer literal; the previous
  // `-1.0` relied on an implicit double -> long conversion.
  // NOTE(review): the time unit is not visible in this header -- confirm.
  long timestamp_ = -1;

  // NOTE(review): presumably distinguishes the begin marker of a traced span
  // from its end marker -- confirm in the implementation.
  bool is_beginning_ = false;

  long thread_label_ = -1;
  // Default-constructed std::thread::id, i.e. "no thread".
  std::thread::id tid_;
  int iter_ = -1;
};
// Tag naming the piece of event metadata that a value passed to
// TracerGuard::addArgument(field, value) belongs to.
// Kept as an unscoped enum: callers spell the enumerators without a
// TracingField:: prefix. The numeric values are written out explicitly and
// match the implicit numbering.
enum TracingField {
  TRACE_OP = 0,
  TRACE_TASK = 1,
  TRACE_STREAM = 2,
  TRACE_THREAD = 3,
  TRACE_NAME = 4,
  TRACE_CATEGORY = 5,
  TRACE_ITER = 6,
};
// Sampling strategy, selected through TracingConfig::mode.
enum class TracingMode {
  // Iteration-count based; tuned by the *_every_nth_batch settings.
  EVERY_K_ITERATIONS = 0,
  // Wall-clock based; tuned by the *_n_ms settings.
  GLOBAL_TIMESLICE = 1,
};
// Settings bundle handed to the Tracer constructor. A default-constructed
// config selects TracingMode::EVERY_K_ITERATIONS with filepath "/tmp".
struct TracingConfig {
  TracingMode mode = TracingMode::EVERY_K_ITERATIONS;
  std::string filepath = "/tmp";

  // Settings for TracingMode::EVERY_K_ITERATIONS.
  int64_t trace_every_nth_batch{100};
  int64_t dump_every_nth_batch{10000};

  // Settings for TracingMode::GLOBAL_TIMESLICE (values in milliseconds).
  int64_t trace_every_n_ms{2 * 60 * 1000}; // 2 minutes
  int64_t trace_for_n_ms{1000}; // 1 second
};
// Per-net trace collector. Only the interface is declared here; apart from
// config(), the member functions are defined elsewhere, so the comments below
// are limited to what the declarations themselves show.
class TORCH_API Tracer {
 public:
  // `net` is kept as a non-owning pointer. The third parameter is the tracing
  // configuration and defaults to a default-constructed TracingConfig.
  Tracer(
      const NetBase* net,
      const std::string& net_name,
      TracingConfig = TracingConfig{});

  void recordEvent(const TracerEvent& event);
  std::string opTraceName(const OperatorBase* op);
  std::string opBlobsInfo(const OperatorBase& op);
  std::string serializeEvent(const TracerEvent& event);
  void linearizeEvents();
  void renameThreads();

  void setEnabled(bool enabled);
  bool isEnabled() const;

  // Read-only view of the configuration. Const-qualified so it can also be
  // called through a `const Tracer&`; existing non-const callers still work.
  const TracingConfig& config() const {
    return config_;
  }

  int bumpIter();
  int getIter();
  int bumpDumpingIter();

  // Dump the tracing result to file with given suffix, and then
  // clear current events.
  void dumpTracingResultAndClearEvents(const std::string& file_suffix);

  virtual ~Tracer();

 private:
  const NetBase* net_ = nullptr;
  std::string filename_;
  std::vector<TracerEvent> events_;
  // NOTE(review): presumably serializes access to events_ -- confirm in the
  // implementation file.
  std::mutex tracer_mutex_;
  bool enabled_ = false;
  Timer timer_;
  // In-class defaults keep the counters from being indeterminate should a
  // constructor ever omit them; a constructor initializer still overrides.
  int iter_ = 0;
  int dumping_iter_ = 0;
  TracingConfig config_;

  // Grants TracerGuard access to the private members above.
  friend class TracerGuard;
};
// Helper object that the TRACE_EVENT / TRACE_EVENT_IF macros declare in the
// enclosing scope and, when tracing is active, bind to a Tracer via init().
// NOTE(review): the destructor is defined elsewhere; presumably it records the
// end of the event when the guard is enabled -- confirm in the implementation.
class TORCH_API TracerGuard {
 public:
  // Every member now has an in-class initializer (tracer_ included), so the
  // defaulted constructor leaves nothing indeterminate and the former
  // UninitializedObject lint suppression is no longer needed.
  TracerGuard() = default;

  void init(Tracer* tracer);

  // Zero-argument overload: terminates the variadic recursion below.
  void addArgument();
  void addArgument(TracingField field, const char* value);
  void addArgument(TracingField field, int value);

  // Accepts a flat list of (field, value) pairs: handles the first pair, then
  // recurses on the remaining arguments.
  // NOTE(review): the inner two-argument call relies on the non-template
  // overloads above. For a value type that is neither int nor const char*
  // (e.g. long), overload resolution can pick this template again with an
  // empty pack and recurse forever -- keep values to int / C strings.
  template <typename T, typename... Args>
  void addArgument(TracingField field, const T& value, const Args&... args) {
    addArgument(field, value);
    addArgument(args...);
  }

  void recordEventStart();
  virtual ~TracerGuard();
  static TracerGuard* getCurrentTracerGuard();
  void disable();

 private:
  bool enabled_ = false;
  TracerEvent event_;
  // Non-owning; stays null until init() supplies a tracer.
  Tracer* tracer_ = nullptr;
};
// Extracts the shard id from a name of the form "...shard:123...".
// Returns -1 if no shard id is found in `name`.
TORCH_API int extractShardId(const std::string& name);
// Checks whether `net_name` is white-listed for tracing. The list of
// traceable names is specified via a command-line flag.
TORCH_API bool isTraceableNetName(const std::string& net_name);
// Factory returning a shared_ptr<Tracer> for the given net and net name.
// NOTE(review): whether the result can be null (e.g. for a net name that is
// not traceable) is not visible from this header -- confirm before
// dereferencing unconditionally. The TRACE_EVENT macros in this file do
// null-check `tracer_` before using it.
TORCH_API std::shared_ptr<Tracer> create(
const NetBase* net,
const std::string& net_name);
// Takes the tracer by const reference to its shared_ptr (no ownership taken).
// NOTE(review): presumably called at the start of a net iteration, with the
// returned bool telling whether that iteration is traced -- neither is
// visible from this declaration; confirm in the implementation.
TORCH_API bool startIter(const std::shared_ptr<Tracer>& tracer);
} // namespace tracing
// Helpers that build the identifier of the guard object declared by the
// TRACE_EVENT macros: `str` followed by the current line number.
//
// The `##` operator pastes its operands *before* they are macro-expanded, so
// the paste has to happen one level down: TRACE_NAME_CONCATENATE lets its
// arguments expand first (turning __LINE__ into the actual line number) and
// only then forwards them to TRACE_NAME_CONCATENATE_IMPL, which performs the
// paste. Pasting directly produced the literal identifier `<str>__LINE__`
// for every expansion, regardless of the line.
#define TRACE_NAME_CONCATENATE_IMPL(s1, s2) s1##s2
#define TRACE_NAME_CONCATENATE(s1, s2) TRACE_NAME_CONCATENATE_IMPL(s1, s2)
#define TRACE_ANONYMOUS_NAME(str) TRACE_NAME_CONCATENATE(str, __LINE__)
// Sets up the guard object named TRACE_ANONYMOUS_NAME(trace_guard), which
// must already be declared at the expansion site: binds it to tracer_.get(),
// forwards the macro arguments to addArgument(), then calls
// recordEventStart(), in that order.
// Expands to three separate statements (including the trailing semicolon),
// not to a single expression, and expects a `tracer_` with a .get() member
// to be in scope.
#define TRACE_EVENT_INIT(...) \
TRACE_ANONYMOUS_NAME(trace_guard).init(tracer_.get()); \
TRACE_ANONYMOUS_NAME(trace_guard).addArgument(__VA_ARGS__); \
TRACE_ANONYMOUS_NAME(trace_guard).recordEventStart();
// Supposed to be used only once per scope in AsyncNetBase-derived nets
// Declares a tracing::TracerGuard in the enclosing scope (so it lives until
// that scope ends) and runs TRACE_EVENT_INIT on it only when `tracer_` is
// non-null and tracer_->isEnabled() is true; otherwise the guard is left
// default-constructed.
// The guard declaration sits outside the `if`, so this macro expands to a
// declaration plus a statement and cannot be used as the sole body of an
// unbraced if/for. `tracing::` is not qualified with caffe2::, so the macro
// must be expanded where that name resolves.
#define TRACE_EVENT(...) \
tracing::TracerGuard TRACE_ANONYMOUS_NAME(trace_guard); \
if (tracer_ && tracer_->isEnabled()) { \
TRACE_EVENT_INIT(__VA_ARGS__) \
}
// Conditional variant of TRACE_EVENT: the guard is declared unconditionally,
// but it is initialized and started only when `tracer_` is non-null,
// tracer_->isEnabled() is true, and `cond` holds. `cond` is evaluated last
// (after the tracer checks, via short-circuit &&) and is parenthesized, so
// any expression may be passed.
#define TRACE_EVENT_IF(cond, ...) \
tracing::TracerGuard TRACE_ANONYMOUS_NAME(trace_guard); \
if (tracer_ && tracer_->isEnabled() && (cond)) { \
TRACE_EVENT_INIT(__VA_ARGS__) \
}
} // namespace caffe2
#endif // CAFFE2_CORE_NET_ASYNC_TRACING_H_
|