1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398
|
#include <torch/csrc/distributed/c10d/logger.hpp>

#include <c10/util/CallOnce.h>
#include <c10/util/StringUtil.h>
#include <fmt/format.h>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <torch/csrc/distributed/c10d/debug.h>

#include <algorithm>
#include <set>
#include <sstream>
#include <string>
#include <utility>

#ifdef USE_C10D_GLOO
#include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp>
#endif
namespace c10d {
// Values of num_iterations_stats_recorded_ (the number of *sampled* iterations
// so far, not the raw training iteration) at which set_runtime_stats_and_log()
// forwards the runtime stats to at::LogPyTorchDDPUsage.
// Stats are only collected on sampling iterations, so the training iterations
// these correspond to are spread out according to the runtime-logging sample
// rate (presumably 10, then roughly 10 + (20 - 10) * sample_rate,
// 10 + (50 - 10) * sample_rate, ... -- confirm against
// Reducer::should_collect_runtime_stats).
constexpr int LoggingIterations[] = {10, 20, 50, 100, 500, 800, 1000}; // NOLINT
// Writes a human-readable summary of the recorded DDP stats (rank/world size,
// iteration, module name, unused parameter size and average timings, plus the
// comm hook and join() status when set) to `output`.
//
// All lookups are read-only: a missing key renders as 0 / "". The previous
// version used std::map::operator[], which inserts a default entry for every
// missing key, so merely printing the logger (e.g. LOG(INFO) << *this in
// set_runtime_stats_and_log) added entries such as an empty "comm_hook" to the
// data later passed to at::LogPyTorchDDPUsage.
std::ostream& operator<<(std::ostream& output, const Logger& logger) {
  const auto& ddp_logging_data = *logger.ddp_logging_data_;
  const auto& ints = ddp_logging_data.ints_map;
  const auto& strs = ddp_logging_data.strs_map;

  // Returns m[key] if present, otherwise a value-initialized mapped value,
  // without modifying the map.
  const auto lookup = [](const auto& m, const std::string& key) {
    const auto it = m.find(key);
    using Value = decltype(it->second);
    if (it == m.end()) {
      return Value{};
    }
    return Value(it->second);
  };

  std::string loggerInfo = fmt::format(
      "[Rank {} / {}] [before iteration {}] Training {} unused_parameter_size={} \n "
      "Avg forward compute time: {} \n Avg backward compute time: {} \n"
      "Avg backward comm. time: {} \n Avg backward comm/comp overlap time: {}",
      lookup(ints, "rank"),
      lookup(ints, "world_size"),
      lookup(ints, "iteration"),
      lookup(strs, "module_name"),
      lookup(ints, "unused_parameter_size"),
      lookup(ints, "avg_forward_compute_time"),
      lookup(ints, "avg_backward_compute_time"),
      lookup(ints, "avg_backward_comm_time"),
      lookup(ints, "avg_backward_compute_comm_overlap_time"));

  // An empty comm hook name means no hook was registered (see set_comm_hook).
  const auto comm_hook = lookup(strs, "comm_hook");
  if (!comm_hook.empty()) {
    loggerInfo += fmt::format("\n Gradient comm. hook: {}", comm_hook);
  }
  if (lookup(ints, "join_uneven_inputs")) {
    loggerInfo += "\n Uneven input detection with join() enabled.";
  }
  return output << loggerInfo;
}
// Creates a logger that reads DDP state from `reducer` and stores the collected
// fields in a freshly allocated at::DDPLoggingData.
// The shared_ptr parameter is moved into reducer_ rather than copied, avoiding
// an extra atomic reference-count increment/decrement.
Logger::Logger(std::shared_ptr<c10d::Reducer> reducer)
    : reducer_(std::move(reducer)) {
  ddp_logging_data_ = std::make_unique<at::DDPLoggingData>();
}
// Namespace-scope flag shared by every Logger: the static-graph report below is
// emitted at most once, regardless of how many Logger instances call
// log_if_graph_static().
c10::once_flag log_graph_static_flag;
// On the first call only, records whether the graph could have been marked
// static ("can_set_static_graph") together with the reducer's current
// iteration count, and sends the data to at::LogPyTorchDDPUsage.
void Logger::log_if_graph_static(bool is_static) {
  c10::call_once(log_graph_static_flag, [is_static, this]() {
    auto& ints = ddp_logging_data_->ints_map;
    ints["can_set_static_graph"] = is_static;
    // Also report reducer_->num_iterations_ at this point (intended to be the
    // iteration at which training finished).
    ints["iteration"] = reducer_->num_iterations_;
    at::LogPyTorchDDPUsage(*ddp_logging_data_);
  });
}
// Records distributed-training environment variables (via parse_env) into
// strs_map. The generic ones are always recorded; NCCL- or Gloo-specific ones
// only when the process group's backend name is "nccl" or "gloo".
void Logger::set_env_variables() {
  auto& strs = ddp_logging_data_->strs_map;
  strs["master_port"] = parse_env("MASTER_PORT");
  strs["master_addr"] = parse_env("MASTER_ADDR");
  strs["torch_distributed_debug"] = parse_env("TORCH_DISTRIBUTED_DEBUG");
  strs["cuda_visible_devices"] = parse_env("CUDA_VISIBLE_DEVICES");

  const auto backend = reducer_->process_group_->getBackendName();
  if (backend == "nccl") {
    strs["nccl_socket_ifname"] = parse_env("NCCL_SOCKET_IFNAME");
    strs["nccl_blocking_wait"] = parse_env("NCCL_BLOCKING_WAIT");
    strs["nccl_async_error_handling"] = parse_env("NCCL_ASYNC_ERROR_HANDLING");
    strs["nccl_debug"] = parse_env("NCCL_DEBUG");
    strs["nccl_nthreads"] = parse_env("NCCL_NTHREADS");
    strs["nccl_ib_timeout"] = parse_env("NCCL_IB_TIMEOUT");
  } else if (backend == "gloo") {
    strs["gloo_socket_ifname"] = parse_env("GLOO_SOCKET_IFNAME");
    strs["gloo_device_transport"] = parse_env("GLOO_DEVICE_TRANSPORT");
#ifdef USE_C10D_GLOO
    // NOTE(review): this static_cast is only valid if the process group object
    // itself is a ProcessGroupGloo whenever its backend name is "gloo";
    // confirm against the ProcessGroup class hierarchy in this tree.
    auto* gloo_pg =
        static_cast<c10d::ProcessGroupGloo*>(reducer_->process_group_.get());
    ddp_logging_data_->ints_map["gloo_num_threads"] = gloo_pg->getNumThreads();
#endif
  }
}
// Records parameter statistics:
//   - "num_parameter_tensors": number of parameter tensors,
//   - "total_parameter_size_bytes": sum of numel() * element_size(),
//   - "dtypes": comma-separated, sorted set of distinct dtype names (there can
//     be several, e.g. with mixed-precision training).
void Logger::set_parameter_stats() {
  const auto& params = reducer_->params_;
  auto& ints = ddp_logging_data_->ints_map;
  ints["num_parameter_tensors"] = params.size();

  int64_t total_bytes = 0;
  std::set<std::string> dtype_names;
  for (const auto& param : params) {
    total_bytes += param.numel() * param.element_size();
    dtype_names.insert(std::string(param.dtype().name()));
  }
  ints["total_parameter_size_bytes"] = total_bytes;
  ddp_logging_data_->strs_map["dtypes"] = c10::Join(", ", dtype_names);
}
// Returns one entry per reducer bucket (in bucket order) holding a copy of
// that bucket's variable_indices.
std::vector<std::vector<size_t>> Logger::get_per_bucket_variable_indices() {
  const auto& buckets = reducer_->buckets_;
  std::vector<std::vector<size_t>> result;
  result.reserve(buckets.size());
  for (const auto& bucket : buckets) {
    result.emplace_back(bucket.variable_indices);
  }
  return result;
}
// Returns the size in bytes of each reducer bucket, in bucket order, computed
// as the sum of numel() * element_size() over the bucket's variables.
std::vector<int64_t> Logger::get_bucket_sizes() {
  std::vector<int64_t> bucket_sizes;
  // Exactly one entry per bucket: reserve up front (as
  // get_per_bucket_variable_indices does) to avoid repeated regrowth.
  bucket_sizes.reserve(reducer_->buckets_.size());
  for (const auto& bucket : reducer_->buckets_) {
    int64_t bucket_size = 0;
    for (const auto& v : bucket.variables) {
      bucket_size += v.numel() * v.element_size();
    }
    bucket_sizes.push_back(bucket_size);
  }
  return bucket_sizes;
}
// Records the name of the registered gradient communication hook under
// "comm_hook". An empty string means no hook is set; operator<< omits the
// hook line in that case.
void Logger::set_comm_hook(const std::string& hook) {
  auto& strs = ddp_logging_data_->strs_map;
  strs["comm_hook"] = hook;
}
// Marks that training runs under the model.join() context manager for DDP
// uneven inputs ("join_uneven_inputs" = 1); operator<< reports this.
void Logger::set_uneven_input_join() {
  auto& ints = ddp_logging_data_->ints_map;
  ints["join_uneven_inputs"] = true;
}
// Copies the reducer's static_graph_ setting into "static_graph".
void Logger::set_static_graph() {
  const auto static_graph = reducer_->static_graph_;
  ddp_logging_data_->ints_map["static_graph"] = static_graph;
}
// Records the data available while DistributedDataParallel is being
// constructed: module name, world size and rank, parameter and bucket stats,
// environment variables, the DDP constructor arguments and reducer settings,
// and the backend name. Dumps every recorded field via LOG(INFO) unless the
// debug level is Off, then sends the data to at::LogPyTorchDDPUsage.
void Logger::set_construction_data_and_log(
    const std::string& module_name,
    const std::vector<int>& device_ids,
    int output_device,
    bool broadcast_buffers,
    bool has_sync_bn,
    bool static_graph) {
  // No lock is taken: this is expected to be called only from the
  // DistributedDataParallel constructor.
  if (static_graph) {
    set_static_graph();
  }

  auto& ints = ddp_logging_data_->ints_map;
  auto& strs = ddp_logging_data_->strs_map;
  const auto& pg = reducer_->process_group_;

  strs["module_name"] = module_name;
  ints["world_size"] = pg->getSize();
  ints["rank"] = pg->getRank();
  // Iteration of the training loop at which get_ddp_logging_data() fetched the
  // DDPLoggingData; 0 means it was fetched before the training loop.
  ints["iteration"] = 0;
  ints["is_multi_device_module"] = reducer_->is_multi_device_module_;

  set_parameter_stats();
  // Bucket sizes (bytes) as computed at construction time.
  strs["bucket_sizes"] = c10::Join(", ", get_bucket_sizes());
  set_env_variables();

  // DistributedDataParallel constructor arguments and reducer settings.
  strs["device_ids"] = c10::Join(", ", device_ids);
  ints["output_device"] = output_device;
  ints["broadcast_buffers"] = broadcast_buffers;
  ints["has_sync_bn"] = has_sync_bn;
  ints["bucket_cap_bytes"] = reducer_->bucket_bytes_cap_;
  ints["find_unused_parameters"] = reducer_->find_unused_parameters_;
  ints["gradient_as_bucket_view"] = reducer_->gradient_as_bucket_view_;
  strs["backend_name"] = pg->getBackendName();

  if (debug_level() != DebugLevel::Off) {
    const std::string header =
        fmt::format("[Rank {}]: DDP Initialized with: \n", ints["rank"]);
    std::stringstream body;
    for (const auto& kv : ints) {
      body << kv.first << ": " << kv.second << "\n";
    }
    for (const auto& kv : strs) {
      body << kv.first << ": " << kv.second << "\n";
    }
    LOG(INFO) << header << body.str();
  }
  at::LogPyTorchDDPUsage(*ddp_logging_data_);
}
// Stores the timestamp the timer recorded for `event` into event_time.
// If the timer has no timestamp for that event, event_time is left unchanged.
void Logger::set_event_time(
    int64_t& event_time,
    Timer& timer,
    Timer::Event event) {
  const auto recorded = timer.getTimestamp(event);
  if (recorded == c10::nullopt) {
    return;
  }
  // TODO: should we set this as human-readable time instead of unixtime?
  event_time = *recorded;
}
void Logger::calculate_avg_time(
int64_t& avg_time,
int64_t& time_duration,
Timer& timer,
Timer::Event start_event,
Timer::Event end_event) {
TORCH_CHECK(num_iterations_stats_recorded_ > 0);
c10::optional<int64_t> maybe_time_duration =
timer.measureDifference(start_event, end_event);
if (!maybe_time_duration.has_value()) {
return;
}
time_duration = maybe_time_duration.value();
avg_time = (time_duration + avg_time * (num_iterations_stats_recorded_ - 1)) /
num_iterations_stats_recorded_;
}
// Zeroes the per-iteration timing fields and event timestamps. Any field the
// timer cannot fill for the current sampled iteration therefore reads 0
// rather than a stale value from an earlier one. The avg_* fields are not
// touched.
void Logger::reset_performance_stats() {
  for (const char* key :
       {"forward_compute_time",
        "backward_comm_time",
        "backward_compute_time",
        "backward_compute_comm_overlap_time",
        "forward_compute_time_start",
        "backward_compute_time_start",
        "backward_comm_time_start",
        "backward_compute_time_end",
        "backward_comm_time_end"}) {
    ddp_logging_data_->ints_map[key] = 0;
  }
}
// Samples runtime stats from the reducer (under reducer_->mutex_) and, on
// selected sampled iterations, sends them to at::LogPyTorchDDPUsage.
//
// Does nothing unless reducer_->should_collect_runtime_stats() is true; each
// sampled call increments num_iterations_stats_recorded_. It records:
//   - the current iteration and the size of unused parameters,
//   - rebuilt-bucket info, refreshed whenever has_rebuilt_bucket_ changes,
//   - the previous iteration's gradient-ready order,
//   - for CPU modules and single-device CUDA modules, forward/backward timing
//     stats from reducer_->timer_.
// With TORCH_DISTRIBUTED_DEBUG=DETAIL the summary is also printed via
// LOG(INFO).
void Logger::set_runtime_stats_and_log() {
  // Sync with reducer's data
  std::lock_guard<std::mutex> lock(reducer_->mutex_);
  // Set runtime stats at the sampling iterations.
  if (!reducer_->should_collect_runtime_stats()) {
    return;
  }
  num_iterations_stats_recorded_++;
  // Set ith iteration when the runtime stats are set.
  ddp_logging_data_->ints_map["iteration"] = reducer_->num_iterations_;
  // When get_ddp_logging_data() is called, "unused_parameter_size",
  // "has_rebuilt_buckets" and "rebuilt_bucket_sizes" are updated in the latest
  // sampling iteration.
  // unused_parameters_ is calculated in the forward call of each iteration.
  // "unused_parameter_size" must therefore be recomputed from scratch here.
  // Previously it was added onto the value left by the previous sampled
  // iteration, so it kept growing with every sample. It is left untouched only
  // when find_unused_parameters_ is off and there are no unused parameters.
  if (reducer_->find_unused_parameters_ ||
      !reducer_->unused_parameters_.empty()) {
    int64_t unused_parameter_size = 0;
    for (const auto& unused_index : reducer_->unused_parameters_) {
      const auto& v = reducer_->params_[unused_index];
      unused_parameter_size += v.numel() * v.element_size();
    }
    ddp_logging_data_->ints_map["unused_parameter_size"] =
        unused_parameter_size;
  }
  // rebuilt_bucket_sizes will not change once buckets are rebuilt,
  // so it only needs to set once during whole training loop.
  // Rebuild buckets stats after 1st iteration
  if (ddp_logging_data_->ints_map["has_rebuilt_buckets"] !=
      reducer_->has_rebuilt_bucket_) {
    ddp_logging_data_->ints_map["has_rebuilt_buckets"] =
        reducer_->has_rebuilt_bucket_;
    ddp_logging_data_->strs_map["rebuilt_bucket_sizes"] =
        c10::Join(", ", get_bucket_sizes());
    // Log per-bucket variable indices
    std::vector<std::string> per_bucket_variable_indices;
    auto indices = get_per_bucket_variable_indices();
    per_bucket_variable_indices.reserve(indices.size());
    for (const auto& bucket_indices : indices) {
      per_bucket_variable_indices.push_back(c10::Join(" ", bucket_indices));
    }
    ddp_logging_data_->strs_map["rebuilt_per_bucket_param_indices"] =
        c10::Join(", ", per_bucket_variable_indices);
  }
  // Log gradient ready order
  if (!reducer_->grad_ready_order_indices_.empty()) {
    // Note that the indices are for the previous iteration as
    // this function is called in forward pass, and we last computed gradient
    // ready order in the last backward pass.
    ddp_logging_data_->strs_map["prev_iteration_grad_ready_order_indices"] =
        c10::Join(", ", reducer_->grad_ready_order_indices_);
  }
  reset_performance_stats();
  // Cuda time stats are only collected for single device modules.
  if (reducer_->params_[0].is_cuda() && reducer_->is_multi_device_module_) {
    TORCH_WARN_ONCE(
        "Cuda time stats are not collected for multi-device modules.");
    return;
  }
  if (!reducer_->params_[0].is_cuda() && !reducer_->params_[0].is_cpu()) {
    TORCH_WARN_ONCE(
        "Time stats are currently only collected for CPU and CUDA devices. "
        "Please refer to CpuTimer or CudaTimer for how to register timer "
        "for other device type.");
    return;
  }
  TORCH_INTERNAL_ASSERT(reducer_->timer_);
  calculate_avg_time(
      ddp_logging_data_->ints_map["avg_forward_compute_time"],
      ddp_logging_data_->ints_map["forward_compute_time"],
      *reducer_->timer_,
      Timer::Event::kForwardStart,
      Timer::Event::kBackwardComputeStart);
  calculate_avg_time(
      ddp_logging_data_->ints_map["avg_backward_compute_time"],
      ddp_logging_data_->ints_map["backward_compute_time"],
      *reducer_->timer_,
      Timer::Event::kBackwardComputeStart,
      Timer::Event::kBackwardComputeEnd);
  calculate_avg_time(
      ddp_logging_data_->ints_map["avg_backward_comm_time"],
      ddp_logging_data_->ints_map["backward_comm_time"],
      *reducer_->timer_,
      Timer::Event::kBackwardCommStart,
      Timer::Event::kBackwardCommEnd);
  calculate_avg_time(
      ddp_logging_data_->ints_map["avg_backward_compute_comm_overlap_time"],
      ddp_logging_data_->ints_map["backward_compute_comm_overlap_time"],
      *reducer_->timer_,
      Timer::Event::kBackwardCommStart,
      Timer::Event::kBackwardComputeEnd);
  set_event_time(
      ddp_logging_data_->ints_map["forward_compute_time_start"],
      *reducer_->timer_,
      Timer::Event::kForwardStart);
  set_event_time(
      ddp_logging_data_->ints_map["backward_compute_time_start"],
      *reducer_->timer_,
      Timer::Event::kBackwardComputeStart);
  set_event_time(
      ddp_logging_data_->ints_map["backward_comm_time_start"],
      *reducer_->timer_,
      Timer::Event::kBackwardCommStart);
  set_event_time(
      ddp_logging_data_->ints_map["backward_compute_time_end"],
      *reducer_->timer_,
      Timer::Event::kBackwardComputeEnd);
  set_event_time(
      ddp_logging_data_->ints_map["backward_comm_time_end"],
      *reducer_->timer_,
      Timer::Event::kBackwardCommEnd);
  // Log runtime stats to stderr if TORCH_DISTRIBUTED_DEBUG=DETAIL is enabled.
  if (debug_level() == DebugLevel::Detail) {
    LOG(INFO) << *this;
  }
  // Log runtime (e.g. avg performance) stats at the beginning and also
  // after a larger number of iterations. Choosing 10/1000/10000 is
  // not scientific here, it assumes most of applications will run
  // at least 10 iterations. stats could have smaller variance if
  // selected num_iterations_ is larger.
  if (std::find(
          std::begin(LoggingIterations),
          std::end(LoggingIterations),
          num_iterations_stats_recorded_) != std::end(LoggingIterations)) {
    at::LogPyTorchDDPUsage(*ddp_logging_data_);
  }
}
// Returns a copy of the collected DDP logging data. The copy is taken while
// holding reducer_->mutex_, the same mutex set_runtime_stats_and_log() holds
// while it updates the data.
at::DDPLoggingData Logger::get_ddp_logging_data() {
  const std::lock_guard<std::mutex> guard(reducer_->mutex_);
  at::DDPLoggingData snapshot = *ddp_logging_data_;
  return snapshot;
}
} // namespace c10d
|