1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574
|
#pragma once
#include "caffe2/core/operator.h"
#include "caffe2/perfkernels/adagrad.h"
#if defined(USE_FBGEMM) && !defined(__NVCC__)
#include "fbgemm/FbgemmEmbedding.h"
#endif
namespace caffe2 {
/**
 * Context-generic Adagrad entry point that takes the learning rate by pointer.
 *
 * Reads the single learning-rate value lr[0] and forwards everything else
 * unchanged to the scalar-learning-rate adagrad_update overload (presumably
 * the one declared in caffe2/perfkernels/adagrad.h — the only overload that
 * accepts a float in that position). The context argument is not used.
 */
template <typename Context>
void adagrad_update(
    int N,
    const float* w,
    const float* g,
    const float* h,
    float* nw,
    float* nh,
    float epsilon,
    float decay,
    const float* lr,
    Context* /*context*/,
    float weight_decay = 0.f) {
  const float learning_rate = lr[0];
  adagrad_update(
      N, w, g, h, nw, nh, epsilon, decay, learning_rate, weight_decay);
}
/**
 * Dense Adagrad step that also records the per-element effective learning
 * rate.
 *
 * For each k in [0, N):
 *   g           = gradIn[k] + weight_decay * paramIn[k]   (fused via fma)
 *   momentOut[k] = decay * momentIn[k] + g * g
 *   effectiveLROut[k] = lr[0] / (sqrt(momentOut[k]) + epsilon)
 *   paramOut[k] = paramIn[k] + effectiveLROut[k] * g
 *
 * The step is added to the parameter; any sign convention on lr is the
 * caller's. paramOut/momentOut may alias paramIn/momentIn because each input
 * element is read before the matching output element is written. The context
 * argument is not used.
 */
template <typename Context>
void adagrad_update_output_effective_lr(
    int N,
    const float* paramIn,
    const float* gradIn,
    const float* momentIn,
    float* paramOut,
    float* momentOut,
    float* effectiveLROut,
    float epsilon,
    float decay,
    const float* lr,
    Context* /*context*/,
    float weight_decay = 0.f) {
  for (int k = 0; k < N; ++k) {
    // L2 weight decay folded into the gradient.
    const float g_k = std::fma(weight_decay, paramIn[k], gradIn[k]);
    const float h_k = decay * momentIn[k] + g_k * g_k;
    momentOut[k] = h_k;
    const float eta_k = lr[0] / (std::sqrt(h_k) + epsilon);
    effectiveLROut[k] = eta_k;
    paramOut[k] = paramIn[k] + eta_k * g_k;
  }
}
/**
 * Dense Adagrad step that records both the per-element effective learning
 * rate and the applied update.
 *
 * For each k in [0, N):
 *   g            = gradIn[k] + weight_decay * paramIn[k]   (fused via fma)
 *   momentOut[k] = decay * momentIn[k] + g * g
 *   effectiveLROut[k] = lr[0] / (sqrt(momentOut[k]) + epsilon)
 *   updateOut[k] = effectiveLROut[k] * g
 *   paramOut[k]  = paramIn[k] + updateOut[k]
 *
 * paramOut/momentOut may alias paramIn/momentIn because each input element is
 * read before the matching output element is written. The context argument is
 * not used.
 */
template <typename Context>
void adagrad_update_output_effective_lr_and_update(
    int N,
    const float* paramIn,
    const float* gradIn,
    const float* momentIn,
    float* paramOut,
    float* momentOut,
    float* effectiveLROut,
    float* updateOut,
    float epsilon,
    float decay,
    const float* lr,
    Context* /*context*/,
    float weight_decay = 0.f) {
  for (int k = 0; k < N; ++k) {
    // L2 weight decay folded into the gradient.
    const float g_k = std::fma(weight_decay, paramIn[k], gradIn[k]);
    const float h_k = decay * momentIn[k] + g_k * g_k;
    momentOut[k] = h_k;
    const float eta_k = lr[0] / (std::sqrt(h_k) + epsilon);
    effectiveLROut[k] = eta_k;
    const float delta_k = eta_k * g_k;
    updateOut[k] = delta_k;
    paramOut[k] = paramIn[k] + delta_k;
  }
}
/**
 * Dense Adagrad operator.
 *
 * Inputs:  PARAM, MOMENT_1, GRAD, LR (only LR[0] is read).
 * Outputs: OUTPUT_PARAM and OUTPUT_MOMENT_1; with 3 outputs also
 *          OUTPUT_EFFECTIVE_LR, and with any other output count also
 *          OUTPUT_UPDATE.
 * Arguments: "epsilon" (default 1e-5), "decay" (default 1.0),
 *            "weight_decay" (default 0.0).
 */
template <class Context>
class AdagradOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  AdagradOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)),
        decay_(this->template GetSingleArgument<float>("decay", 1.0f)),
        weight_decay_(
            this->template GetSingleArgument<float>("weight_decay", 0.f)) {
    VLOG(1) << "gradient optimization operator in use: "
            << "AdagradOp"
            << " weight_decay_=" << weight_decay_;
  }

  bool RunOnDevice() override {
    CAFFE_ENFORCE_EQ(
        Input(GRAD).numel(),
        Input(MOMENT_1).numel(),
        "PARAM size: ",
        Input(PARAM).numel(),
        ", GRAD size: ",
        Input(GRAD).numel(),
        ", MOMENT_1 size: ",
        Input(MOMENT_1).numel(),
        ", LR size: ",
        Input(LR).numel());
    CAFFE_ENFORCE_EQ(Input(GRAD).numel(), Input(PARAM).numel());
    // Every kernel below dereferences lr[0]; reject an empty LR tensor instead
    // of reading past its end.
    CAFFE_ENFORCE_GE(
        Input(LR).numel(), 1, "LR must contain at least one element");
    Output(OUTPUT_PARAM)->ResizeLike(Input(PARAM));
    Output(OUTPUT_MOMENT_1)->ResizeLike(Input(MOMENT_1));
    if (OutputSize() == 2) {
      adagrad_update<Context>(
          Input(GRAD).numel(),
          Input(PARAM).template data<float>(),
          Input(GRAD).template data<float>(),
          Input(MOMENT_1).template data<float>(),
          Output(OUTPUT_PARAM)->template mutable_data<float>(),
          Output(OUTPUT_MOMENT_1)->template mutable_data<float>(),
          epsilon_,
          decay_,
          Input(LR).template data<float>(),
          &context_,
          weight_decay_);
    } else if (OutputSize() == 3) {
      Output(OUTPUT_EFFECTIVE_LR)->ResizeLike(Input(GRAD));
      adagrad_update_output_effective_lr<Context>(
          Input(GRAD).numel(),
          Input(PARAM).template data<float>(),
          Input(GRAD).template data<float>(),
          Input(MOMENT_1).template data<float>(),
          Output(OUTPUT_PARAM)->template mutable_data<float>(),
          Output(OUTPUT_MOMENT_1)->template mutable_data<float>(),
          Output(OUTPUT_EFFECTIVE_LR)->template mutable_data<float>(),
          epsilon_,
          decay_,
          Input(LR).template data<float>(),
          &context_,
          weight_decay_);
    } else {
      // Any other output count is handled as the 4-output form that also
      // emits the applied update.
      Output(OUTPUT_EFFECTIVE_LR)->ResizeLike(Input(GRAD));
      Output(OUTPUT_UPDATE)->ResizeLike(Input(GRAD));
      adagrad_update_output_effective_lr_and_update<Context>(
          Input(GRAD).numel(),
          Input(PARAM).template data<float>(),
          Input(GRAD).template data<float>(),
          Input(MOMENT_1).template data<float>(),
          Output(OUTPUT_PARAM)->template mutable_data<float>(),
          Output(OUTPUT_MOMENT_1)->template mutable_data<float>(),
          Output(OUTPUT_EFFECTIVE_LR)->template mutable_data<float>(),
          Output(OUTPUT_UPDATE)->template mutable_data<float>(),
          epsilon_,
          decay_,
          Input(LR).template data<float>(),
          &context_,
          weight_decay_);
    }
    return true;
  }

 protected:
  float epsilon_;
  float decay_;
  float weight_decay_;
  INPUT_TAGS(PARAM, MOMENT_1, GRAD, LR);
  OUTPUT_TAGS(
      OUTPUT_PARAM,
      OUTPUT_MOMENT_1,
      OUTPUT_EFFECTIVE_LR,
      OUTPUT_UPDATE);
};
/**
 * Sparse (row-indexed) Adagrad operator on CPU.
 *
 * For each i, INDICES[i] selects a block of
 * block_size = GRAD.numel() / INDICES.numel() contiguous elements of PARAM /
 * MOMENT_1, which are updated with GRAD's i-th block. The block_size == 1
 * path below spells the update out:
 *   g  = grad + weight_decay * w
 *   h += g * g
 *   w += lr[0] * g / (sqrt(h) + epsilon)
 * The FBGEMM and prefetching kernels are assumed to compute the same update.
 *
 * Results are written through OUTPUT_PARAM / OUTPUT_MOMENT_1, which are never
 * resized here; the op presumably runs in place (outputs aliasing PARAM /
 * MOMENT_1) — TODO confirm against the operator schema.
 *
 * Arguments: "epsilon" (default 1e-5), "weight_decay" (default 0.0) and
 * "decay", which must be 1.0.
 */
class SparseAdagradOp final : public Operator<CPUContext> {
 public:
  SparseAdagradOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<CPUContext>(operator_def, ws),
        epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)),
        weight_decay_(
            this->template GetSingleArgument<float>("weight_decay", 0.f)) {
    VLOG(1) << "gradient optimization operator in use: "
            << "SparseAdagradOp"
            << " weight_decay_=" << weight_decay_;
    const float decay = this->template GetSingleArgument<float>("decay", 1.0);
    CAFFE_ENFORCE_EQ(
        decay, 1.0, "Decay is not supported for SparseAdagradOp");
  }

  bool RunOnDevice() override {
    // Enforce shapes
    // input(embedding/momentum) == outputs(embedding/momentum)
    CAFFE_ENFORCE_EQ(
        Input(PARAM).numel(),
        Input(MOMENT_1).numel(),
        "Input Param size: ",
        Input(PARAM).numel(),
        " Input Moment size: ",
        Input(MOMENT_1).numel());
    CAFFE_ENFORCE_EQ(Input(LR).numel(), 1);
    CAFFE_ENFORCE_EQ(
        Input(PARAM).size_from_dim(1),
        Input(GRAD).size_from_dim(Input(INDICES).dim()));
    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, Input(INDICES));
  }

  template <typename SIndex>
  bool DoRunWithType() {
    const auto* lr = Input(LR).template data<float>();
    auto n = Input(INDICES).numel();
    const auto* indices = Input(INDICES).template data<SIndex>();
    const auto* gradIn = Input(GRAD).template data<float>();
    auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<float>();
    auto* momentOut = Output(OUTPUT_MOMENT_1)->template mutable_data<float>();
    if (n == 0) {
      return true;
    }
    auto block_size = Input(GRAD).numel() / n;
    // input(grad) is compatible with size of indexes
    CAFFE_ENFORCE_EQ(
        Input(GRAD).numel() % n,
        0,
        "Incorrect gradient size:",
        Input(GRAD).numel(),
        " size of indexes:",
        n);
#if defined(USE_FBGEMM) && !defined(__NVCC__)
    VLOG(1) << "using fbgemm::GenerateSparseAdaGrad in SparseAdagradOp";
    // Kernels are JIT-generated per block size; regenerate only on change.
    if (block_size != last_block_size_) {
      last_block_size_ = block_size;
      if (std::is_same<SIndex, std::int32_t>::value) {
        kernel_i32_ = fbgemm::GenerateSparseAdaGrad<std::int32_t>(
            block_size,
            /*rowwise=*/false,
            /*prefetch=*/16,
            weight_decay_ != 0.0f);
      } else {
        CAFFE_ENFORCE((std::is_same<SIndex, std::int64_t>::value));
        kernel_i64_ = fbgemm::GenerateSparseAdaGrad<std::int64_t>(
            block_size,
            /*rowwise=*/false,
            /*prefetch=*/16,
            weight_decay_ != 0.0f);
      }
    }
    int num_rows_processed;
    if (std::is_same<SIndex, std::int32_t>::value) {
      num_rows_processed = kernel_i32_(
          n,
          Input(PARAM).numel(),
          paramOut,
          gradIn,
          momentOut,
          reinterpret_cast<const std::int32_t*>(indices),
          epsilon_,
          lr[0],
          weight_decay_,
          /*counter=*/nullptr,
          /*counter_halflife=*/0);
    } else {
      num_rows_processed = kernel_i64_(
          n,
          Input(PARAM).numel(),
          paramOut,
          gradIn,
          momentOut,
          reinterpret_cast<const std::int64_t*>(indices),
          epsilon_,
          lr[0],
          weight_decay_,
          /*counter=*/nullptr,
          /*counter_halflife=*/0);
    }
    if (num_rows_processed < n) {
      // The kernel stopped before consuming every index; report the row it
      // stopped at (presumably an out-of-range index).
      CAFFE_ENFORCE_GE(
          indices[num_rows_processed],
          0,
          this->debug_def().input(PARAM),
          ", negative idx:",
          indices[num_rows_processed],
          " for input i:",
          num_rows_processed);
      CAFFE_ENFORCE_GE(
          Input(PARAM).numel(),
          (indices[num_rows_processed] + 1) * block_size,
          this->debug_def().input(PARAM),
          ", out of bound, idx:",
          indices[num_rows_processed],
          " for input i:",
          num_rows_processed,
          " and block_size:",
          block_size,
          " max size:",
          Input(PARAM).numel());
      return false;
    }
    return true;
#else
    VLOG(1)
        << "using internal::adagrad_update_prefetch_inlined in SparseAdagradOp";
    const auto* paramIn = Input(PARAM).template data<float>();
    const auto* momentIn = Input(MOMENT_1).template data<float>();
    // Prefetch distance (in indices) for the block_size > 1 path.
    constexpr std::int64_t prefdist_T0 = 16;
    for (const auto i : c10::irange(n)) {
      auto idx = indices[i];
      auto offsetI = i * block_size;
      auto offsetIdx = idx * block_size;
      // Enforce: the selected row lies inside PARAM (and MOMENT_1, whose size
      // RunOnDevice enforced equal). Negative indices would otherwise slip
      // past the upper-bound check and index before the buffers.
      CAFFE_ENFORCE_GE(
          idx,
          0,
          this->debug_def().input(PARAM),
          ", negative idx:",
          idx,
          " for input i:",
          i);
      CAFFE_ENFORCE_GE(
          Input(PARAM).numel(),
          block_size + offsetIdx,
          this->debug_def().input(PARAM),
          ", out of bound, idx:",
          idx,
          " for input i:",
          i,
          " and block size:",
          block_size,
          " max size:",
          Input(PARAM).numel());
      if (block_size == 1) {
        float gi = std::fma(weight_decay_, paramIn[idx], gradIn[i]);
        float hi = momentOut[idx] = momentIn[idx] + gi * gi;
        paramOut[idx] = paramIn[idx] + lr[0] * gi / (std::sqrt(hi) + epsilon_);
      } else {
        // Row to prefetch: prefdist_T0 indices ahead, or the current row near
        // the end of the index list. 64-bit to match n and i.
        const auto i_pref = (i < n - prefdist_T0) ? i + prefdist_T0 : i;
        std::size_t idx_pref = indices[i_pref];
        internal::adagrad_update_prefetch_inlined(
            block_size,
            paramIn + offsetIdx,
            &paramIn[idx_pref * block_size],
            gradIn + offsetI,
            momentIn + offsetIdx,
            &momentIn[idx_pref * block_size],
            paramOut + offsetIdx,
            &paramOut[idx_pref * block_size],
            momentOut + offsetIdx,
            &momentOut[idx_pref * block_size],
            epsilon_,
            lr[0],
            weight_decay_);
      }
    }
    return true;
#endif // USE_FBGEMM && !__NVCC__
  }

 protected:
  float epsilon_;
  const float weight_decay_;
#if defined(USE_FBGEMM) && !defined(__NVCC__)
  fbgemm::SparseAdaGradSignature<std::int32_t>::Type kernel_i32_;
  fbgemm::SparseAdaGradSignature<std::int64_t>::Type kernel_i64_;
  std::int64_t last_block_size_{-1};
#endif
  INPUT_TAGS(PARAM, MOMENT_1, INDICES, GRAD, LR);
  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1);
};
/**
 * Row-wise sparse Adagrad operator.
 *
 * Like SparseAdagradOp, but MOMENT_1 keeps a single value per PARAM row. For
 * each selected row the plain (non-FBGEMM) path below computes:
 *   g_j  = grad_j + weight_decay * freq * w_j
 *   h   += mean_j(g_j * g_j)
 *   w_j += lr[0] * g_j / (sqrt(h) + epsilon)
 * where freq = counter_halflife / COUNTER[row] when counter_halflife > 0 and
 * COUNTER[row] > 0, and 1 otherwise. The FBGEMM kernel is assumed to compute
 * the same update.
 *
 * Results are written through OUTPUT_PARAM / OUTPUT_MOMENT_1 without resizing;
 * the op presumably runs in place — TODO confirm against the operator schema.
 *
 * Arguments: "epsilon" (default 1e-5), "weight_decay" (default 0.0),
 * "counter_halflife" (default -1; the COUNTER input is only read when > 0).
 */
template <class Context>
class RowWiseSparseAdagradOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  RowWiseSparseAdagradOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)),
        weight_decay_(
            this->template GetSingleArgument<float>("weight_decay", 0.f)),
        counter_halflife_(
            this->template GetSingleArgument<int64_t>("counter_halflife", -1)) {
    VLOG(1) << "gradient optimization operator in use: "
            << "RowWiseSparseAdagradOp"
            << " weight_decay_=" << weight_decay_
            << " counter_halflife=" << counter_halflife_;
  }

  bool RunOnDevice() override {
    // Enforce shapes
    CAFFE_ENFORCE_EQ(Input(PARAM).sizes()[0], Input(MOMENT_1).numel());
    CAFFE_ENFORCE_EQ(Input(LR).numel(), 1);
    CAFFE_ENFORCE_EQ(
        Input(PARAM).size_from_dim(1),
        Input(GRAD).size_from_dim(Input(INDICES).dim()));
    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, Input(INDICES));
  }

  template <typename SIndex>
  bool DoRunWithType() {
    const auto* lr = Input(LR).template data<float>();
    auto* param = Output(OUTPUT_PARAM)->template mutable_data<float>();
    auto* moment = Output(OUTPUT_MOMENT_1)->template mutable_data<float>();
    const auto* indices = Input(INDICES).template data<SIndex>();
    const auto* gradIn = Input(GRAD).template data<float>();
    // COUNTER is only consulted when counter_halflife_ > 0 (both paths below
    // ignore it otherwise), so only fetch it in that case; the input is
    // presumably optional when the feature is off.
    const double* count = counter_halflife_ > 0
        ? Input(COUNTER).template data<double>()
        : nullptr;
    auto n = Input(INDICES).numel();
    if (n == 0) {
      return true;
    }
    auto block_size = Input(GRAD).numel() / n;
    // Enforce:
    // Input(embedding/momentum) == outputs(embedding/momentum)
    CAFFE_ENFORCE_EQ(
        Input(PARAM).numel() / block_size,
        Input(MOMENT_1).numel(),
        "Input Param size: ",
        Input(PARAM).numel(),
        " Block size: ",
        block_size,
        " Input Moment size: ",
        Input(MOMENT_1).numel());
    // input(grad) is compatible with size of indexes
    CAFFE_ENFORCE_EQ(
        Input(GRAD).numel() % n,
        0,
        "Incorrect gradient size:",
        Input(GRAD).numel(),
        " size of indexes:",
        n);
#if defined(USE_FBGEMM) && !defined(__NVCC__)
    VLOG(1) << "using fbgemm::GenerateSparseAdaGrad in RowWiseSparseAdagradOp";
    // Kernels are JIT-generated per block size; regenerate only on change.
    if (block_size != last_block_size_) {
      last_block_size_ = block_size;
      if (std::is_same<SIndex, std::int32_t>::value) {
        kernel_i32_ = fbgemm::GenerateSparseAdaGrad<std::int32_t>(
            block_size,
            /*rowwise=*/true,
            /*prefetch=*/16,
            weight_decay_ != 0.0f);
      } else {
        CAFFE_ENFORCE((std::is_same<SIndex, std::int64_t>::value));
        kernel_i64_ = fbgemm::GenerateSparseAdaGrad<std::int64_t>(
            block_size,
            /*rowwise=*/true,
            /*prefetch=*/16,
            weight_decay_ != 0.0f);
      }
    }
    int num_rows_processed;
    if (std::is_same<SIndex, std::int32_t>::value) {
      num_rows_processed = kernel_i32_(
          n,
          Input(PARAM).numel(),
          param,
          gradIn,
          moment,
          reinterpret_cast<const std::int32_t*>(indices),
          epsilon_,
          lr[0],
          weight_decay_,
          count, // nullptr unless counter_halflife_ > 0
          counter_halflife_);
    } else {
      num_rows_processed = kernel_i64_(
          n,
          Input(PARAM).numel(),
          param,
          gradIn,
          moment,
          reinterpret_cast<const std::int64_t*>(indices),
          epsilon_,
          lr[0],
          weight_decay_,
          count, // nullptr unless counter_halflife_ > 0
          counter_halflife_);
    }
    if (num_rows_processed < n) {
      // The kernel stopped before consuming every index; report the row it
      // stopped at (presumably an out-of-range index).
      CAFFE_ENFORCE_GE(
          indices[num_rows_processed],
          0,
          this->debug_def().input(PARAM),
          ", negative idx:",
          indices[num_rows_processed],
          " for input i:",
          num_rows_processed);
      CAFFE_ENFORCE_GE(
          Input(PARAM).numel(),
          (indices[num_rows_processed] + 1) * block_size,
          this->debug_def().input(PARAM),
          ", out of bound, idx:",
          indices[num_rows_processed],
          " for input i:",
          num_rows_processed,
          " and block size:",
          block_size,
          " max size:",
          Input(PARAM).numel());
      return false;
    }
    return true;
#else
    VLOG(1) << "using plain adagrad updates in RowWiseSparseAdagradOp";
    for (const auto i : c10::irange(n)) {
      auto idx = indices[i];
      auto offsetI = i * block_size;
      auto offsetIdx = idx * block_size;
      // Validate the row in every build mode and for every block size, before
      // any of param / moment / count is indexed with it. Without this, a bad
      // index is a silent out-of-bounds read/write in release builds.
      CAFFE_ENFORCE_GE(
          idx,
          0,
          this->debug_def().input(PARAM),
          ", negative idx:",
          idx,
          " for input i:",
          i);
      CAFFE_ENFORCE_GE(
          Input(PARAM).numel(),
          block_size + offsetIdx,
          this->debug_def().input(PARAM),
          ", out of bound, idx:",
          idx,
          " for input i:",
          i,
          " and block size:",
          block_size);
      // Scale weight decay by how rarely this row is seen (count is non-null
      // whenever counter_halflife_ > 0).
      float freq = (counter_halflife_ > 0 && count[idx] > 0)
          ? counter_halflife_ / count[idx]
          : 1.0;
      if (block_size == 1) {
        float gi = std::fma(weight_decay_ * freq, param[idx], gradIn[i]);
        float hi = moment[idx] = moment[idx] + gi * gi;
        param[idx] = param[idx] + lr[0] * gi / (std::sqrt(hi) + epsilon_);
      } else {
#ifndef NDEBUG
        // Guaranteed by the GRAD.numel() % n check above; kept as a debug
        // sanity check.
        CAFFE_ENFORCE_GE(
            Input(GRAD).numel(),
            block_size + offsetI,
            this->debug_def().input(GRAD),
            ", out of bound idx, idx:",
            idx,
            " for input i:",
            i);
#endif
        float* w = param + offsetIdx;
        const float* g = gradIn + offsetI;
        float* h = moment + idx;
        // Row-wise moment: accumulate the mean squared (decayed) gradient.
        float hs = 0.;
        for (const auto j : c10::irange(block_size)) {
          float gj = std::fma(weight_decay_ * freq, w[j], g[j]);
          hs += gj * gj;
        }
        float hi = h[0] = h[0] + hs / block_size;
        float step = lr[0] / (std::sqrt(hi) + epsilon_);
        for (const auto j : c10::irange(block_size)) {
          float gj = std::fma(weight_decay_ * freq, w[j], g[j]);
          w[j] = w[j] + gj * step;
        }
      }
    }
    return true;
#endif // !USE_FBGEMM
  }

 protected:
  float epsilon_;
  const float weight_decay_;
  const int64_t counter_halflife_;
#if defined(USE_FBGEMM) && !defined(__NVCC__)
  fbgemm::SparseAdaGradSignature<std::int32_t>::Type kernel_i32_;
  fbgemm::SparseAdaGradSignature<std::int64_t>::Type kernel_i64_;
  std::int64_t last_block_size_{-1};
#endif
  INPUT_TAGS(PARAM, MOMENT_1, INDICES, GRAD, LR, COUNTER);
  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1);
};
} // namespace caffe2
|