1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
|
#pragma once
#include <string>
#include <vector>
#include <c10/macros/Macros.h>
#include <c10/util/ThreadLocalDebugInfo.h>
#include <ATen/record_function.h>
#include <ATen/core/ivalue.h>
namespace torch {
// Name string declared here and defined in another translation unit.
// RECORD_PARAM_COMMS (below) passes it to RECORD_FUNCTION as the function name.
extern TORCH_API const std::string kParamCommsCallName;
// Debug-info record (a c10::DebugInfoBase) that stores the arguments supplied
// to RECORD_PARAM_COMMS: a rank, a column name, input/output message sizes,
// a scalar type, and input/output split-size lists.
//
// All accessors are const and only read the stored members. Accessors that
// return a reference hand out a reference into this object, so the result is
// valid only while this ParamCommsDebugInfo instance is alive.
class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
 public:
  // Default state: zero rank and sizes, empty name and split lists,
  // dtype at::kByte (see the in-class member initializers below).
  ParamCommsDebugInfo() = default;

  // Declared here, defined out of line.
  // NOTE: colName is an rvalue reference, so callers must pass a temporary
  // or an explicitly std::move'd string.
  // NOTE(review): units of inSize/outSize (bytes vs. elements) are not
  // visible from this header -- confirm against the definition/callers.
  ParamCommsDebugInfo(
      int rank,
      std::string&& colName,
      int inSize,
      int outSize,
      at::ScalarType dType,
      std::vector<int64_t> inSplitSizes,
      std::vector<int64_t> outSplitSizes);

  ~ParamCommsDebugInfo() override = default;

  // Returns the stored rank.
  int getRank() const {
    return rank_;
  }

  // Returns the stored column name.
  // Returned by const reference (like the split-size getters) so that a call
  // does not copy the string; callers that need to outlive this object should
  // copy the result into their own std::string.
  const std::string& getColumnName() const {
    return columnName_;
  }

  // Returns the stored input message size.
  int getInMessageSize() const {
    return inMessageSize_;
  }

  // Returns the stored output message size.
  int getOutMessageSize() const {
    return outMessageSize_;
  }

  // Returns the stored scalar type.
  at::ScalarType getDType() const {
    return dType_;
  }

  // Returns a reference to the stored input split sizes.
  const std::vector<int64_t>& getInputSplitSizes() const {
    return inputSplitSizes_;
  }

  // Returns a reference to the stored output split sizes.
  const std::vector<int64_t>& getOutputSplitSizes() const {
    return outputSplitSizes_;
  }

 private:
  // Member order is kept as-is: the out-of-line constructor's initializer
  // list is not visible here and may rely on it.
  int rank_{};
  std::string columnName_;
  int inMessageSize_{};
  int outMessageSize_{};
  at::ScalarType dType_ = at::kByte;
  std::vector<int64_t> inputSplitSizes_;
  std::vector<int64_t> outputSplitSizes_;
};
// RECORD_PARAM_COMMS(rank, colName, inSize, outSize, dType, inSplitSizes, outSplitSizes)
//
// Expands to three statements in the caller's scope:
//   1. builds a std::shared_ptr<torch::ParamCommsDebugInfo> from the arguments
//      and stores it in a local named `paramCommsInfo`;
//   2. constructs a c10::DebugInfoGuard local named `g` with kind
//      c10::DebugInfoKind::PARAM_COMMS_INFO and that pointer;
//   3. invokes RECORD_FUNCTION with torch::kParamCommsCallName and an empty
//      std::vector<c10::IValue> of inputs.
//
// Usage notes:
//   - The locals `paramCommsInfo` and `g` are introduced into the enclosing
//     scope, so the macro cannot be used twice in the same scope and will
//     clash with caller variables of the same names.
//   - `g` lives until the end of the enclosing scope.
//     NOTE(review): presumably the debug info stays installed for that
//     duration -- confirm against c10::DebugInfoGuard.
//   - `colName` is forwarded to a constructor parameter of type
//     `std::string&&`, so it must be a temporary / movable expression
//     (e.g. a string literal or std::move(name)), not a named std::string.
//   - The expansion already ends with `;`, so a trailing `;` at the call
//     site yields an extra empty statement.
#define RECORD_PARAM_COMMS(rank, colName, inSize, outSize, dType, inSplitSizes, outSplitSizes) \
auto paramCommsInfo = std::make_shared<torch::ParamCommsDebugInfo>( \
rank, \
colName, \
inSize, \
outSize, \
dType, \
inSplitSizes, \
outSplitSizes); \
c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \
RECORD_FUNCTION(torch::kParamCommsCallName, std::vector<c10::IValue>());
} // namespace torch
|