// Declarations of the Caffe2 cross-entropy operators and their gradient operators.
#ifndef CAFFE2_OPERATORS_CROSS_ENTROPY_OP_H_
#define CAFFE2_OPERATORS_CROSS_ENTROPY_OP_H_
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
// Operator declaration for the label-based cross-entropy forward pass.
// The computation itself lives in RunOnDevice(), which is defined outside
// this header.
//
// Input:  X, label
// Output: Y
template <typename T, class Context>
class LabelCrossEntropyOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  USE_SIMPLE_CTOR_DTOR(LabelCrossEntropyOp);

  bool RunOnDevice() override;

 protected:
  // Returns the constant 1e-20 converted to T.
  // NOTE(review): presumably a lower bound applied before taking a log;
  // the use site is not visible in this header -- confirm in the .cc/.cu.
  static constexpr T kLOG_THRESHOLD() {
    return static_cast<T>(1e-20);
  }
};
// Operator declaration for the gradient of LabelCrossEntropyOp.
// RunOnDevice() is defined outside this header.
//
// Input:  X, label, dY
// Output: dX. There is no gradient with respect to the label.
template <typename T, class Context>
class LabelCrossEntropyGradientOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  USE_SIMPLE_CTOR_DTOR(LabelCrossEntropyGradientOp);

  bool RunOnDevice() override;

 protected:
  // Returns the constant 1e-20 converted to T (same value as the forward
  // operator's threshold).
  // NOTE(review): intended use is not visible in this header -- confirm.
  static constexpr T kLOG_THRESHOLD() {
    return static_cast<T>(1e-20);
  }
};
// Hacky helper for binary classification: converts a vector of
// probabilities into a 2-column matrix holding the complementary
// probabilities alongside the originals.
// RunOnDevice() is defined outside this header.
//
// Input:  X
// Output: Y = vstack(1-X, X)
template <typename T, class Context>
class MakeTwoClassOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  USE_SIMPLE_CTOR_DTOR(MakeTwoClassOp);

  bool RunOnDevice() override;
};
// Operator declaration for the gradient of MakeTwoClassOp.
// RunOnDevice() is defined outside this header.
//
// Input:  dY
// Output: dX
template <typename T, class Context>
class MakeTwoClassGradientOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  USE_SIMPLE_CTOR_DTOR(MakeTwoClassGradientOp);

  bool RunOnDevice() override;
};
// Operator declaration for sigmoid cross-entropy computed from logits.
// RunOnDevice() is defined outside this header.
//
// Operator arguments read at construction (both default to false):
//   "log_D_trick"      -> log_D_trick_
//   "unjoined_lr_loss" -> unjoined_lr_loss_
// The two options are mutually exclusive; construction fails the enforce
// check below when both are true.
template <typename T, class Context>
class SigmoidCrossEntropyWithLogitsOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Forwards every constructor argument to Operator<Context>, then caches
  // the two boolean options and validates that they are not both enabled.
  template <class... Args>
  explicit SigmoidCrossEntropyWithLogitsOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        log_D_trick_(
            this->template GetSingleArgument<bool>("log_D_trick", false)),
        unjoined_lr_loss_(
            this->template GetSingleArgument<bool>("unjoined_lr_loss", false)) {
    CAFFE_ENFORCE(
        !(log_D_trick_ && unjoined_lr_loss_),
        "log_D_trick_ and unjoined_lr_loss_ cannot be set as True simultaneously");
  }

  bool RunOnDevice() override;

 protected:
  // Cached value of the "log_D_trick" argument.
  bool log_D_trick_;
  // Cached value of the "unjoined_lr_loss" argument.
  bool unjoined_lr_loss_;
};
// Operator declaration for the gradient of sigmoid cross-entropy with
// logits. RunOnDevice() is defined outside this header.
//
// Operator arguments read at construction (both default to false):
//   "log_D_trick"      -> log_D_trick_
//   "unjoined_lr_loss" -> unjoined_lr_loss_
// NOTE(review): unlike SigmoidCrossEntropyWithLogitsOp, this constructor
// does not reject the case where both options are true -- confirm whether
// that asymmetry is intentional.
template <typename T, class Context>
class SigmoidCrossEntropyWithLogitsGradientOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Forwards every constructor argument to Operator<Context> and caches the
  // two boolean options; no further validation is performed here.
  template <class... Args>
  explicit SigmoidCrossEntropyWithLogitsGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        log_D_trick_(
            this->template GetSingleArgument<bool>("log_D_trick", false)),
        unjoined_lr_loss_(
            this->template GetSingleArgument<bool>("unjoined_lr_loss", false)) {}

  bool RunOnDevice() override;

 protected:
  // Cached value of the "log_D_trick" argument.
  bool log_D_trick_;
  // Cached value of the "unjoined_lr_loss" argument.
  bool unjoined_lr_loss_;
};
// Operator declaration for the weighted variant of sigmoid cross-entropy
// with logits. It takes no operator arguments of its own in this header;
// RunOnDevice() (and therefore the input/output contract) is defined
// elsewhere.
template <typename T, class Context>
class WeightedSigmoidCrossEntropyWithLogitsOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  USE_SIMPLE_CTOR_DTOR(WeightedSigmoidCrossEntropyWithLogitsOp);

  bool RunOnDevice() override;
};
// Operator declaration for the gradient of the weighted sigmoid
// cross-entropy with logits. RunOnDevice() (and therefore the input/output
// contract) is defined outside this header.
template <typename T, class Context>
class WeightedSigmoidCrossEntropyWithLogitsGradientOp final
    : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  USE_SIMPLE_CTOR_DTOR(WeightedSigmoidCrossEntropyWithLogitsGradientOp);

  bool RunOnDevice() override;
};
// Operator declaration for the cross-entropy forward pass, exported via
// TORCH_API. RunOnDevice() is defined outside this header.
//
// Input:  X, label
// Output: Y
template <typename T, class Context>
class TORCH_API CrossEntropyOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  USE_SIMPLE_CTOR_DTOR(CrossEntropyOp);

  bool RunOnDevice() override;

 protected:
  // Returns the constant 1e-20 converted to T.
  // NOTE(review): presumably a lower bound applied before taking a log;
  // the use site is not visible in this header -- confirm in the .cc/.cu.
  static constexpr T kLOG_THRESHOLD() {
    return static_cast<T>(1e-20);
  }
};
// Operator declaration for the gradient of CrossEntropyOp, exported via
// TORCH_API. RunOnDevice() is defined outside this header.
//
// Input:  X, label, dY
// Output: dX. There is no gradient with respect to the label.
template <typename T, class Context>
class TORCH_API CrossEntropyGradientOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  USE_SIMPLE_CTOR_DTOR(CrossEntropyGradientOp);

  bool RunOnDevice() override;

 protected:
  // Returns the constant 1e-20 converted to T (same value as the forward
  // operator's threshold).
  // NOTE(review): intended use is not visible in this header -- confirm.
  static constexpr T kLOG_THRESHOLD() {
    return static_cast<T>(1e-20);
  }
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_CROSS_ENTROPY_OP_H_
|