1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584
|
#ifndef CAFFE2_OPERATORS_FILLER_OP_H_
#define CAFFE2_OPERATORS_FILLER_OP_H_
#include <c10/util/irange.h>
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
// Common base for the fill operators. It works out the output shape, resizes
// output 0 to it and then hands the tensor to the subclass's Fill().
//
// How the shape is determined depends on whether the operator has inputs:
//
// - No inputs: the shape is taken verbatim from the "shape" argument.
//   Supplying "extra_shape" or "input_as_shape" in this mode is an error.
//
// - One or more inputs: the "shape" argument must not be set. Normally the
//   sizes of input 0 are used; when "input_as_shape" is true, input 0 is
//   instead read as a non-empty 1D int64_t tensor holding the sizes. Either
//   way, the "extra_shape" dimensions are appended at the end.
template <class Context>
class FillerOp : public Operator<Context> {
 public:
  // Reads the "shape", "extra_shape" and "input_as_shape" arguments and throws
  // if they contradict the presence or absence of inputs.
  template <class... Args>
  explicit FillerOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        shape_(this->template GetRepeatedArgument<int64_t>("shape")),
        extra_shape_(ToVectorint64_t(
            this->template GetRepeatedArgument<int>("extra_shape"))),
        input_as_shape_(
            this->template GetSingleArgument<bool>("input_as_shape", false)) {
    if (InputSize() != 0) {
      // With an input present the shape comes from that input, so an explicit
      // "shape" argument would be ambiguous.
      if (!shape_.empty()) {
        CAFFE_THROW(
            "Cannot set the shape argument and pass in an input at "
            "the same time");
      }
      return;
    }

    // No inputs from here on: everything that refers to an input is invalid.
    if (!extra_shape_.empty()) {
      CAFFE_THROW("Cannot set extra_shape when there is no input");
    }
    if (input_as_shape_) {
      CAFFE_THROW("An input must be given if input_as_shape is true");
    }
    // "shape" parsed as an empty list although a single int argument of that
    // name exists: the caller passed a scalar instead of a list.
    const bool shape_is_scalar = shape_.empty() &&
        this->template HasSingleArgumentOfType<int>("shape");
    if (shape_is_scalar) {
      CAFFE_THROW("Fill 'shape' argument was a scalar, list expected");
    }
  }

  virtual ~FillerOp() {}

  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Resolves the output shape, resizes output 0 and returns the result of
  // Fill(). When an input is present, the resolved shape is also stored back
  // into shape_.
  bool RunOnDevice() override {
    auto* output = Operator<Context>::Output(0);

    if (InputSize() == 0) {
      output->Resize(shape_);
      return Fill(output);
    }

    vector<int64_t> dims;
    if (!input_as_shape_) {
      // Mirror the sizes of input 0.
      const auto& like = Input(0);
      const auto like_sizes = like.sizes();
      dims.assign(like_sizes.begin(), like_sizes.end());
    } else if (this->InputIsTensorType(0, CPU)) {
      // Input 0 is a CPU tensor: read the sizes in place.
      const auto& dims_tensor = this->template Input<Tensor>(0, CPU);
      CAFFE_ENFORCE_EQ(
          dims_tensor.dim(),
          1,
          "When input_as_shape is true, the input must be a 1D tensor of "
          "data type int64_t");
      CAFFE_ENFORCE(dims_tensor.numel() > 0);
      const int64_t* first = dims_tensor.template data<int64_t>();
      dims.assign(first, first + dims_tensor.dim32(0));
    } else {
      // Input 0 is not a CPU tensor (the original note mentions the ONNX case,
      // where the shape may live in CUDA context): copy the sizes to host
      // memory through the context before using them.
      const auto& dims_tensor = Input(0);
      CAFFE_ENFORCE_EQ(
          dims_tensor.dim(),
          1,
          "When input_as_shape is true, the input must be a 1D tensor of "
          "data type int64_t");
      CAFFE_ENFORCE(dims_tensor.numel() > 0);
      const int64_t* device_dims = dims_tensor.template data<int64_t>();
      const auto count = dims_tensor.dim32(0);
      dims.resize(count);
      context_.template CopyToCPU<int64_t>(count, device_dims, dims.data());
    }

    dims.insert(dims.end(), extra_shape_.begin(), extra_shape_.end());
    output->Resize(dims);
    shape_ = dims;
    return Fill(output);
  }

  // Writes the values into the already resized output tensor.
  virtual bool Fill(Tensor* output) = 0;

 protected:
  // Output shape: the "shape" argument, or the last shape resolved from an
  // input.
  vector<int64_t> shape_;
  // Dimensions appended after the input-derived dimensions.
  vector<int64_t> extra_shape_;
  // If true, the values (not the sizes) of input 0 define the shape.
  bool input_as_shape_;
};
// Fills the output with values produced by math::RandUniform using the bounds
// min and max.
//
// The bounds come from the "min" / "max" arguments (defaults 0 and 1). When the
// operator has exactly three inputs they are instead read at run time from
// input 1 (min) and input 2 (max), each of which must hold a single element;
// in that mode the "min" / "max" arguments must not be set.
template <typename T, class Context>
class UniformFillOp final : public FillerOp<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit UniformFillOp(Args&&... args)
      : FillerOp<Context>(std::forward<Args>(args)...),
        min_(this->template GetSingleArgument<T>("min", 0)),
        max_(this->template GetSingleArgument<T>("max", 1)) {
    if (InputSize() == 3) {
      // Bounds are supplied as blobs; arguments would be ambiguous.
      CAFFE_ENFORCE(
          !this->template HasSingleArgumentOfType<T>("min"),
          "Cannot set both min arg and min input blob");
      CAFFE_ENFORCE(
          !this->template HasSingleArgumentOfType<T>("max"),
          "Cannot set both max arg and max input blob");
    } else {
      CAFFE_ENFORCE_LT(
          min_, max_, "Max value should be bigger than min value.");
    }
  }

  // Fills `output` with uniform random values. With blob-supplied bounds where
  // min > max, the output is instead resized so that its first dimension is 0
  // and nothing is generated.
  bool Fill(Tensor* output) override {
    T min = min_;
    T max = max_;
    if (InputSize() == 3) {
      CAFFE_ENFORCE_EQ(1, Input(1).numel(), "min blob must be scalar");
      CAFFE_ENFORCE_EQ(1, Input(2).numel(), "max blob must be scalar");
      min = *Input(1).template data<T>();
      max = *Input(2).template data<T>();
      if (min > max) {
        auto shape = output->sizes().vec();
        // A 0-dimensional output has no leading dimension to zero out;
        // indexing shape[0] would write past the end of an empty vector.
        CAFFE_ENFORCE(
            !shape.empty(),
            "Cannot produce an empty output for a 0-dimensional tensor when "
            "min is greater than max");
        shape[0] = 0;
        output->Resize(shape);
        // Touch the data so the (empty) output still gets its element type.
        output->template mutable_data<T>();
        return true;
      }
    }
    math::RandUniform<T, Context>(
        output->numel(),
        min,
        max,
        output->template mutable_data<T>(),
        &context_);
    return true;
  }

 private:
  // Bounds taken from the operator arguments; used unless blobs override them.
  T min_;
  T max_;
};
// Fills the output through math::RandUniformUnique using the "min" and "max"
// arguments. The element type is chosen by the "dtype" argument (INT32 by
// default, INT64 also accepted). If the operator has at least two inputs,
// input 1 is passed to RandUniformUnique as the list of values to avoid.
template <class Context>
class UniqueUniformFillOp final : public FillerOp<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Validates the range for the requested dtype and binds the matching typed
  // fill routine; throws for undefined or unsupported dtypes.
  template <class... Args>
  explicit UniqueUniformFillOp(Args&&... args)
      : FillerOp<Context>(std::forward<Args>(args)...) {
    const auto dtype = static_cast<TensorProto_DataType>(
        this->template GetSingleArgument<int>(
            "dtype", TensorProto_DataType_INT32));

    if (dtype == TensorProto_DataType_INT32) {
      CheckRange<int>();
      body_ = &UniqueUniformFillOp::FillWithType<int>;
    } else if (dtype == TensorProto_DataType_INT64) {
      CheckRange<int64_t>();
      body_ = &UniqueUniformFillOp::FillWithType<int64_t>;
    } else if (dtype == TensorProto_DataType_UNDEFINED) {
      CAFFE_THROW(
          "UniqueUniformFill op cannot have undefined 'dtype' argument");
    } else {
      CAFFE_THROW("Unexpected 'dtype' argument value: ", dtype);
    }
  }

  // Dispatches to the typed routine selected in the constructor.
  bool Fill(Tensor* output) override {
    return (this->*body_)(output);
  }

 private:
  // Requires both "min" and "max" to be present as single arguments of type T
  // and min to be strictly less than max.
  template <typename T>
  void CheckRange() {
    CAFFE_ENFORCE(this->template HasSingleArgumentOfType<T>("min"));
    CAFFE_ENFORCE(this->template HasSingleArgumentOfType<T>("max"));
    CAFFE_ENFORCE_LT(
        this->template GetSingleArgument<T>("min", 0),
        this->template GetSingleArgument<T>("max", 0),
        "Max value should be bigger than min value.");
  }

  // Typed fill: forwards the bounds and the optional avoid list (input 1) to
  // math::RandUniformUnique.
  template <typename T>
  bool FillWithType(Tensor* output) {
    const T lower = this->template GetSingleArgument<T>("min", 0);
    const T upper = this->template GetSingleArgument<T>("max", 0);

    // Without a second input there is nothing to avoid.
    const T* excluded = nullptr;
    size_t excluded_count = 0;
    const bool has_avoid_input = InputSize() >= 2;
    if (has_avoid_input) {
      const auto& avoid_tensor = Input(1);
      excluded = avoid_tensor.template data<T>();
      excluded_count = avoid_tensor.numel();
    }

    math::RandUniformUnique<T, Context>(
        output->numel(),
        lower,
        upper,
        output->template mutable_data<T>(),
        excluded_count,
        excluded,
        &context_);
    return true;
  }

  // Typed fill routine chosen from "dtype" at construction time.
  bool (UniqueUniformFillOp::*body_)(Tensor* output);
};
// Fills every element of the output with one constant.
//
// The element type is selected by the "dtype" argument (FLOAT by default). If
// "dtype" is absent but "value" is present, the type is inferred from the
// argument: FLOAT for a float value, INT64 for an int64 value; anything else
// throws. The constant is the "value" argument; for non-string types it can be
// overridden at run time by input 1 (see FillWithType).
template <class Context>
class ConstantFillOp final : public FillerOp<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Resolves the element type and binds the matching fill routine; throws for
  // undefined or unsupported dtypes.
  template <class... Args>
  explicit ConstantFillOp(Args&&... args)
      : FillerOp<Context>(std::forward<Args>(args)...) {
    TensorProto_DataType dtype =
        static_cast<TensorProto_DataType>(this->template GetSingleArgument<int>(
            "dtype", TensorProto_DataType_FLOAT));

    if (!OperatorBase::HasArgument("dtype") &&
        OperatorBase::HasArgument("value")) {
      // If 'dtype' is not provided, infer type based on the type of 'value'
      // Currently, single argument contains either float, int64 or bytes
      if (this->template HasSingleArgumentOfType<float>("value")) {
        dtype = TensorProto_DataType_FLOAT;
      } else if (this->template HasSingleArgumentOfType<int64_t>("value")) {
        dtype = TensorProto_DataType_INT64;
      } else {
        CAFFE_THROW("Argument 'value' is of unexpected type");
      }
      VLOG(1) << "Argument 'dtype' is not provided. Assume the data type is "
              << "the same as that of argument 'value': " << dtype;
    }

    switch (dtype) {
      case TensorProto_DataType_FLOAT:
        body_ = &ConstantFillOp::FillWithType<float>;
        break;
      case TensorProto_DataType_DOUBLE:
        body_ = &ConstantFillOp::FillWithType<double>;
        break;
      case TensorProto_DataType_BOOL:
        body_ = &ConstantFillOp::FillWithType<bool>;
        break;
      case TensorProto_DataType_INT8:
        body_ = &ConstantFillOp::FillWithType<int8_t>;
        break;
      case TensorProto_DataType_INT16:
        body_ = &ConstantFillOp::FillWithType<int16_t>;
        break;
      case TensorProto_DataType_INT32:
        body_ = &ConstantFillOp::FillWithType<int>;
        break;
      case TensorProto_DataType_INT64:
        body_ = &ConstantFillOp::FillWithType<int64_t>;
        break;
      case TensorProto_DataType_UINT8:
        body_ = &ConstantFillOp::FillWithType<uint8_t>;
        break;
      case TensorProto_DataType_UINT16:
        body_ = &ConstantFillOp::FillWithType<uint16_t>;
        break;
      case TensorProto_DataType_STRING:
        body_ = &ConstantFillOp::FillWithString;
        break;
      case TensorProto_DataType_UNDEFINED:
        // CAFFE_THROW does not return, so no break is needed here.
        CAFFE_THROW("ConstantFill op cannot have undefined 'dtype' argument");
      default:
        CAFFE_THROW("Unexpected 'dtype' argument value: ", dtype);
    }
  }

  // Dispatches to the fill routine selected in the constructor.
  bool Fill(Tensor* output) override {
    return (this->*body_)(output);
  }

  // Fills `output` with a value of type T. The value is the "value" argument
  // (default 0); with exactly two inputs, and if input 1 tests true
  // (presumably: the tensor is defined -- confirm against Tensor), its single
  // element is used instead.
  template <typename T>
  bool FillWithType(Tensor* output) {
    T value = this->template GetSingleArgument<T>("value", 0);
    if (InputSize() == 2) {
      auto& value_vec = Input(1);
      if (value_vec) {
        CAFFE_ENFORCE_EQ(
            value_vec.size(), 1, "value vector must have 1 element");
        value = value_vec.template data<T>()[0];
      }
    }
    // mutable_data is requested even for empty outputs; only the Set call is
    // skipped when there is nothing to write.
    auto* data = output->template mutable_data<T>();
    if (output->numel()) {
      math::Set<T, Context>(output->numel(), value, data, &context_);
    }
    return true;
  }

  // Fills `output` with copies of the string "value" argument (default "").
  // Taking the value from an input tensor is not supported for strings.
  bool FillWithString(Tensor* output) {
    CAFFE_ENFORCE_LT(
        InputSize(), 2, "constant fill string from tensor is not supported");
    auto value = this->template GetSingleArgument<std::string>("value", "");
    auto* data = output->template mutable_data<std::string>();
    // Iterate with the element count's own integer type; a plain `int` index
    // would overflow for outputs with more than INT_MAX elements.
    for (const auto i : c10::irange(output->numel())) {
      data[i] = value;
    }
    return true;
  }

 private:
  // Fill routine chosen from the resolved dtype at construction time.
  bool (ConstantFillOp::*body_)(Tensor* output);
};
// Filler whose typed routine (FillWithType, defined outside this section)
// presumably writes a value along the diagonal of the output -- see its
// definition for the exact behavior.
//
// The element type is selected by the "dtype" argument (FLOAT by default). If
// "dtype" is absent but "value" is present, the type is inferred from the
// argument: FLOAT for a float value, INT64 for an int64 value; anything else
// throws.
template <class Context>
class DiagonalFillOp final : public FillerOp<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Resolves the element type and binds the matching fill routine; throws for
  // undefined or unsupported dtypes.
  template <class... Args>
  explicit DiagonalFillOp(Args&&... args)
      : FillerOp<Context>(std::forward<Args>(args)...) {
    TensorProto_DataType dtype =
        static_cast<TensorProto_DataType>(this->template GetSingleArgument<int>(
            "dtype", TensorProto_DataType_FLOAT));

    if (!OperatorBase::HasArgument("dtype") &&
        OperatorBase::HasArgument("value")) {
      // If 'dtype' is not provided, infer type based on the type of 'value'
      // Currently, single argument contains either float, int64 or bytes
      if (this->template HasSingleArgumentOfType<float>("value")) {
        dtype = TensorProto_DataType_FLOAT;
      } else if (this->template HasSingleArgumentOfType<int64_t>("value")) {
        dtype = TensorProto_DataType_INT64;
      } else {
        CAFFE_THROW("Argument 'value' is of unexpected type");
      }
      VLOG(1) << "Argument 'dtype' is not provided. Assume the data type is "
              << "the same as that of argument 'value': " << dtype;
    }

    switch (dtype) {
      case TensorProto_DataType_FLOAT:
        body_ = &DiagonalFillOp::FillWithType<float>;
        break;
      case TensorProto_DataType_DOUBLE:
        body_ = &DiagonalFillOp::FillWithType<double>;
        break;
      case TensorProto_DataType_BOOL:
        body_ = &DiagonalFillOp::FillWithType<bool>;
        break;
      case TensorProto_DataType_INT8:
        body_ = &DiagonalFillOp::FillWithType<int8_t>;
        break;
      case TensorProto_DataType_INT16:
        body_ = &DiagonalFillOp::FillWithType<int16_t>;
        break;
      case TensorProto_DataType_INT32:
        body_ = &DiagonalFillOp::FillWithType<int>;
        break;
      case TensorProto_DataType_INT64:
        body_ = &DiagonalFillOp::FillWithType<int64_t>;
        break;
      case TensorProto_DataType_UINT8:
        body_ = &DiagonalFillOp::FillWithType<uint8_t>;
        break;
      case TensorProto_DataType_UINT16:
        body_ = &DiagonalFillOp::FillWithType<uint16_t>;
        break;
      case TensorProto_DataType_UNDEFINED:
        // CAFFE_THROW does not return, so no break is needed here.
        CAFFE_THROW("Cannot have undefined 'dtype' argument");
      default:
        CAFFE_THROW("Unexpected 'dtype' argument value: ", dtype);
    }
  }

  // Dispatches to the fill routine selected in the constructor.
  bool Fill(Tensor* output) override {
    return (this->*body_)(output);
  }

  // Typed fill routine; declared here, defined elsewhere.
  template <typename T>
  bool FillWithType(Tensor* output);

 private:
  // Throws unless the output has at least two dimensions.
  void VerifyOutputShape(Tensor* output) {
    CAFFE_ENFORCE(output->dim() >= 2, "Input shape must be >= 2D");
  }

  // Computes a step size from the output's dimensions (presumably the
  // flat-index distance between consecutive diagonal elements -- confirm
  // against FillWithType).
  //
  // - 2D output: size(1) + 1.
  // - Otherwise: every dimension must equal size(0) (throws if not) and, for
  //   n dimensions of length d, the result is 1 + d + d^2 + ... + d^(n-1).
  int64_t GetStepSize(Tensor* output) {
    if (output->dim() == 2) {
      return output->size(1) + 1;
    }

    const int64_t side = output->size(0);
    for (auto d : output->sizes()) {
      if (d != side) {
        CAFFE_THROW("All dimensions of input must be of equal length");
      }
    }

    // Sum the running products d, d^2, ..., d^(n-1) on top of the initial 1.
    const int64_t rank = output->dim();
    int64_t step = 1;
    int64_t running_product = 1;
    for (int64_t k = 1; k < rank; ++k) {
      running_product *= side;
      step += running_product;
    }
    // Kept as a trace only: logging the bare number at verbosity 0 would emit
    // it on every call with default logging settings.
    VLOG(1) << "DiagonalFill step size: " << step;
    return step;
  }

  // Fill routine chosen from the resolved dtype at construction time.
  bool (DiagonalFillOp::*body_)(Tensor* output);
};
// Fills the output with values produced by math::RandGaussian using the
// "mean" (default 0) and "std" (default 1) arguments. Both arguments are read
// as float and stored as T.
template <typename T, class Context>
class GaussianFillOp final : public FillerOp<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit GaussianFillOp(Args&&... args)
      : FillerOp<Context>(std::forward<Args>(args)...),
        mean_(this->template GetSingleArgument<float>("mean", 0)),
        std_(this->template GetSingleArgument<float>("std", 1)) {
    // The check is strictly greater-than, so the message says "positive";
    // a value of exactly 0 is rejected by this check as well.
    TORCH_DCHECK_GT(std_, 0) << "Standard deviation should be positive.";
  }

  // Fills `output` with Gaussian random values of type T.
  bool Fill(Tensor* output) override {
    math::RandGaussian<T, Context>(
        output->numel(),
        mean_,
        std_,
        output->template mutable_data<T>(),
        &context_);
    return true;
  }

 private:
  T mean_;
  T std_;
};
// Fills the output through math::RandUniform with bounds -scale and +scale,
// where scale = sqrt(3 / fan_in) and fan_in = numel / size of dimension 0.
template <typename T, class Context>
class XavierFillOp final : public FillerOp<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit XavierFillOp(Args&&... args)
      : FillerOp<Context>(std::forward<Args>(args)...) {}

  // Fills `output` in place. An output without elements is left empty.
  bool Fill(Tensor* output) override {
    // Query dimension 0 first so that whatever validation dim32 performs on
    // the tensor's rank still happens for empty outputs.
    const int64_t leading_dim = output->dim32(0);
    const int64_t total = output->numel();
    T* data = output->template mutable_data<T>();
    if (total == 0) {
      // Nothing to generate. This also avoids dividing by a zero-sized
      // leading dimension (integer division by zero) below.
      return true;
    }
    // total != 0 implies leading_dim != 0. Kept as int64_t so that large
    // tensors are not truncated to int.
    const int64_t fan_in = total / leading_dim;
    T scale = std::sqrt(T(3) / fan_in);
    math::RandUniform<T, Context>(total, -scale, scale, data, &context_);
    return true;
  }
};
// Fills the output through math::RandGaussian with mean 0 and standard
// deviation scale = sqrt(2 / fan_out), where fan_out = numel / size of
// dimension 1.
template <typename T, class Context>
class MSRAFillOp final : public FillerOp<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit MSRAFillOp(Args&&... args)
      : FillerOp<Context>(std::forward<Args>(args)...) {}

  // Fills `output` in place. An output without elements is left empty.
  bool Fill(Tensor* output) override {
    // Query dimension 1 first so that whatever validation dim32 performs on
    // the tensor's rank still happens for empty outputs.
    const int64_t second_dim = output->dim32(1);
    const int64_t total = output->numel();
    T* data = output->template mutable_data<T>();
    if (total == 0) {
      // Nothing to generate. This also avoids dividing by a zero-sized
      // dimension 1 (integer division by zero) below.
      return true;
    }
    // total != 0 implies second_dim != 0. Kept as int64_t so that large
    // tensors are not truncated to int.
    const int64_t fan_out = total / second_dim;
    T scale = std::sqrt(T(2) / fan_out);
    math::RandGaussian<T, Context>(total, 0.0, scale, data, &context_);
    return true;
  }
};
// Mainly a debugging aid: the tensor is filled sequentially with the values
// 0, 1, 2, ..., which makes element positions easy to read back, e.g. when
// checking reshape operations. Fill() is only declared here; its definition
// is not part of this class body.
template <typename T, class Context>
class RangeFillOp final : public FillerOp<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Forwards all constructor arguments to FillerOp.
  template <class... CtorArgs>
  explicit RangeFillOp(CtorArgs&&... ctor_args)
      : FillerOp<Context>(std::forward<CtorArgs>(ctor_args)...) {}

  bool Fill(Tensor* output) override;
};
// Expands a 1D int32 tensor of lengths into a single 1D int32 output: for each
// length L, in order, the values 0, 1, ..., L-1 are written to the next L
// output slots. The output holds sum(lengths) elements.
template <class Context>
class LengthsRangeFillOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  USE_SIMPLE_CTOR_DTOR(LengthsRangeFillOp);

  // Throws if the input is not 1D or contains a negative length.
  bool RunOnDevice() override {
    auto& input = Input(0);
    auto* input_data = input.template data<int32_t>();
    CAFFE_ENFORCE_EQ(input.dim(), 1, "Input must be a vector.");

    // Validate and total the lengths before allocating the output. A negative
    // length would otherwise hand std::iota a reversed range. The sum is kept
    // in 64 bits so it cannot wrap around while adding int32 values.
    int64_t len_sum = 0;
    for (const auto i : c10::irange(input.numel())) {
      CAFFE_ENFORCE_GE(input_data[i], 0, "Lengths must be non-negative.");
      len_sum += input_data[i];
    }

    auto* output = Output(0, {len_sum}, at::dtype<int32_t>());
    auto* output_data = output->template mutable_data<int32_t>();

    int64_t offset = 0;
    for (const auto i : c10::irange(input.numel())) {
      const auto len = input_data[i];
      auto* start = output_data + offset;
      // TODO: make the start value (third argument) an argument of this
      // operator.
      std::iota(start, start + len, 0);
      offset += len;
    }
    return true;
  }
};
template <int VALUE_TYPE = TensorProto_DataType_FLOAT>
inline std::vector<TensorShape> FillerTensorInference(
const OperatorDef& def,
const vector<TensorShape>& in) {
vector<TensorShape> out(1);
ArgumentHelper helper(def);
out[0].set_data_type(static_cast<TensorProto_DataType>(
helper.GetSingleArgument<int>("dtype", VALUE_TYPE)));
if (in.size()) {
// TODO
bool input_as_shape =
helper.GetSingleArgument<bool>("input_as_shape", false);
if (input_as_shape) {
out[0].set_unknown_shape(true);
return out;
}
for (auto d : in[0].dims()) {
out[0].add_dims(d);
}
} else {
auto shape = helper.GetRepeatedArgument<int64_t>("shape");
for (auto d : shape) {
out[0].add_dims(d);
}
}
return out;
}
} // namespace caffe2
#endif // CAFFE2_OPERATORS_FILLER_OP_H_
|