File: normalization.cpp

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (983 lines) | stat: -rw-r--r-- 33,120 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
#include <torch/csrc/jit/codegen/cuda/arith.h>
#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
#include <torch/csrc/jit/codegen/cuda/ops/normalization.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

int nonNegativeAxis(int axis, int ndims) {
  return (axis >= 0) ? axis : (ndims + axis);
}

Val* numFeatures(TensorView* x, const std::vector<int>& dims, int ndims) {
  Val* num_features = IrBuilder::create<Double>(x->container(), 1);
  for (const auto dim : dims) {
    const int axis = nonNegativeAxis(dim, ndims);
    num_features = mul(num_features, x->domain()->domain()[axis]->extent());
  }
  return num_features;
}

TensorView* mean(TensorView* x, const std::vector<int>& dims, bool keepdim) {
  TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");

  const int kNumberOfDims =
      TensorDomain::noReductions(x->getMaybeRFactorDomain()).size();

  auto sum_x = sum(x, dims, keepdim);
  auto y = div(sum_x, numFeatures(x, dims, kNumberOfDims));
  return y;
}

TensorView* variance(
    TensorView* x,
    const std::vector<int>& dims,
    bool unbiased,
    bool keepdim) {
  TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
  int64_t correction = unbiased ? 1 : 0;
  return variance(x, dims, correction, keepdim);
}

TensorView* variance(
    TensorView* x,
    const std::vector<int>& dims,
    int64_t correction,
    bool keepdim) {
  TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");

  TORCH_CHECK(
      correction >= 0, "correction must be non-negative, but got ", correction);

  const int kNumberOfDims =
      TensorDomain::noReductions(x->getMaybeRFactorDomain()).size();

  auto bcast_mean = mean(x, dims, true /* keepdim */);
  auto x_mean_sub = sub(x, bcast_mean);
  auto x_mean_sub_sq = mul(x_mean_sub, x_mean_sub);
  auto sum_x_mean_sub_sq = sum(x_mean_sub_sq, dims, keepdim);

  auto num_features = numFeatures(x, dims, kNumberOfDims);
  if (correction > 0) {
    num_features =
        sub(num_features, IrBuilder::create<Int>(x->container(), correction));
  }
  auto y = div(sum_x_mean_sub_sq, num_features);

  return y;
}

TORCH_CUDA_CU_API VarMeanResult variance_mean(
    TensorView* x,
    const std::vector<int>& dims,
    int64_t correction,
    bool keepdim) {
  TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");

  TORCH_CHECK(
      correction >= 0, "correction must be non-negative, but got ", correction);

  // There are compilation errors for half precision
  auto dtype = x->getDataType().value();
  TORCH_CHECK(
      !(dtype == DataType::Half || dtype == DataType::BFloat16),
      "variance_mean is not supported for ",
      dtype,
      " please upcast to float");

  if (isComplexType(x->getDataType().value())) {
    // There are compilation errors:
    // __tmp_kernel1.cu(6727): error: namespace "CudaCodeGen::std" has no member
    // "imagf"
    // __tmp_kernel1.cu(6753): error: namespace "CudaCodeGen::std" has no member
    // "realf"
    TORCH_CHECK(false, "var_mean is not supported for complex types.");
    auto out_real = variance_mean(real(x), dims, correction, keepdim);
    auto out_imag = variance_mean(imag(x), dims, correction, keepdim);
    // variance of a complex tensor is the sum of real and imaginary variances
    // and is real mean of a complex tensor is complex complex(out_real.mean,
    // out_imag.mean) It seems construction of a complex tensor from two real
    // tensors is not supported yet
    return {add(out_real.var, out_imag.var), nullptr};
  }

  const int kNumberOfDims =
      TensorDomain::noReductions(x->getMaybeRFactorDomain()).size();
  auto num_features = numFeatures(x, dims, kNumberOfDims);
  if (correction > 0) {
    num_features =
        sub(num_features, IrBuilder::create<Int>(x->container(), correction));
  }

  auto welford_out = Welford(x, dims);
  auto mean = welford_out.avg;
  auto var = mul(welford_out.var_sum, reciprocal(num_features));

  if (keepdim) {
    std::vector<bool> is_broadcast(kNumberOfDims, false);
    for (auto dim : dims) {
      is_broadcast[dim] = true;
    }
    var = broadcast(var, is_broadcast);
    mean = broadcast(mean, is_broadcast);
  }

  return {var, mean};
}

TensorView* standard_deviation(
    TensorView* x,
    const std::vector<int>& dims,
    bool unbiased,
    bool keepdim) {
  TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
  return sqrt(variance(x, dims, unbiased, keepdim));
}

TensorView* softmax(TensorView* x, int dim) {
  TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");

  const int kNumberOfDims =
      TensorDomain::noReductions(x->getMaybeRFactorDomain()).size();
  const int kReductionAxis = (dim < 0) ? dim + kNumberOfDims : dim;
  TORCH_INTERNAL_ASSERT(kReductionAxis >= 0 && kReductionAxis < kNumberOfDims);

  std::vector<bool> broadcast_mask(kNumberOfDims, false);
  broadcast_mask[kReductionAxis] = true;

  auto max_val = max(x, {kReductionAxis});
  auto bcast_max = broadcast(max_val, broadcast_mask);
  auto x_max_sub = sub(x, bcast_max);
  auto exp_val = exp(x_max_sub);
  auto sum_exp = sum(exp_val, {kReductionAxis});
  auto bcast_sum = broadcast(sum_exp, broadcast_mask);
  auto y = mul(exp_val, reciprocal(bcast_sum));

  return y;
}

TensorView* softmax_backward(TensorView* dy, TensorView* y, int dim) {
  TORCH_INTERNAL_ASSERT(dy != nullptr, "Grad Output is invalid.");
  TORCH_INTERNAL_ASSERT(y != nullptr, "Output is invalid.");

  const int kNumberOfDims =
      TensorDomain::noReductions(y->getMaybeRFactorDomain()).size();
  const int kReductionAxis = (dim < 0) ? dim + kNumberOfDims : dim;
  TORCH_INTERNAL_ASSERT(kReductionAxis >= 0 && kReductionAxis < kNumberOfDims);

  std::vector<bool> broadcast_mask(kNumberOfDims, false);
  broadcast_mask[kReductionAxis] = true;

  auto new_grad = mul(dy, y);
  auto sum_new_grad = sum(new_grad, {kReductionAxis});
  auto bcast_sum = broadcast(sum_new_grad, broadcast_mask);
  auto output_sum_mul = mul(y, bcast_sum);
  auto dx = sub(new_grad, output_sum_mul);

  return dx;
}

TensorView* log_softmax(TensorView* x, int dim) {
  TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");

  const int kNumberOfDims =
      TensorDomain::noReductions(x->getMaybeRFactorDomain()).size();
  const int kReductionAxis = (dim < 0) ? dim + kNumberOfDims : dim;
  TORCH_INTERNAL_ASSERT(kReductionAxis >= 0 && kReductionAxis < kNumberOfDims);

  std::vector<bool> broadcast_mask(kNumberOfDims, false);
  broadcast_mask[kReductionAxis] = true;

  auto max_val = max(x, {kReductionAxis});
  auto bcast_max = broadcast(max_val, broadcast_mask);
  auto x_max_sub = sub(x, bcast_max);
  auto exp_val = exp(x_max_sub);
  auto bcast_sum = sum(exp_val, {kReductionAxis}, true /* keepdim */);
  auto log_sum_exp = log(bcast_sum);
  auto y = sub(x_max_sub, log_sum_exp);

  return y;
}

TensorView* log_softmax_backward(TensorView* dy, TensorView* y, int dim) {
  TORCH_INTERNAL_ASSERT(dy != nullptr, "Grad Output is invalid.");
  TORCH_INTERNAL_ASSERT(y != nullptr, "Output is invalid.");

  const int kNumberOfDims =
      TensorDomain::noReductions(y->getMaybeRFactorDomain()).size();
  const int kReductionAxis = (dim < 0) ? dim + kNumberOfDims : dim;
  TORCH_INTERNAL_ASSERT(kReductionAxis >= 0 && kReductionAxis < kNumberOfDims);

  auto bcast_sum_grad = sum(dy, {kReductionAxis}, true /* keepdim */);
  auto softmax = exp(y);
  auto softmax_sum_mul = mul(softmax, bcast_sum_grad);
  auto dx = sub(dy, softmax_sum_mul);

  return dx;
}

ForwardNormResult layer_norm(
    TensorView* x,
    const std::vector<int64_t>& norm_shape,
    TensorView* weight,
    TensorView* bias,
    Val* eps) {
  return layer_norm(x, norm_shape.size(), weight, bias, eps);
}

auto norm_properties_from_num_dims(
    const TensorView* x,
    const size_t kNormShapeNumDims) {
  // (B, C, H, W, D) tensor
  // norm_shape = [H, W, D]
  // M = outer = product of remaining dimensions = B * C
  // N = reduction = product of norm_shape = H * W * D
  // weight = bias = norm_shape tensor
  const size_t kNumberOfDims =
      TensorDomain::noReductions(x->getMaybeRFactorDomain()).size();
  const size_t kOuterNumDims = kNumberOfDims - kNormShapeNumDims;

  std::vector<int> outer_reduction_axes(kOuterNumDims);
  std::vector<bool> outer_broadcast_mask(kNumberOfDims, false);
  std::vector<int> inner_reduction_axes(kNormShapeNumDims);
  std::vector<bool> inner_broadcast_mask(kNumberOfDims, false);

  for (const auto idx : c10::irange(kOuterNumDims)) {
    outer_reduction_axes[idx] = idx;
    outer_broadcast_mask[idx] = true;
  }

  Val* num_features = IrBuilder::create<Double>(x->container(), 1);
  for (const auto idx : c10::irange(kNormShapeNumDims)) {
    const size_t axis = kNumberOfDims - 1 - idx;
    inner_reduction_axes[idx] = axis;
    inner_broadcast_mask[axis] = true;
    num_features = mul(num_features, x->domain()->domain()[axis]->extent());
  }
  struct result {
    std::vector<int> outer_reduction_axes;
    std::vector<bool> outer_broadcast_mask;
    std::vector<int> inner_reduction_axes;
    std::vector<bool> inner_broadcast_mask;
    Val* num_features = nullptr;
  } r;
  r.outer_reduction_axes = outer_reduction_axes;
  r.outer_broadcast_mask = outer_broadcast_mask;
  r.inner_reduction_axes = inner_reduction_axes;
  r.inner_broadcast_mask = inner_broadcast_mask;
  r.num_features = num_features;
  return r;
}

ForwardNormResult layer_norm(
    TensorView* x,
    const size_t kNormShapeNumDims,
    TensorView* weight,
    TensorView* bias,
    Val* eps) {
  TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
  TORCH_INTERNAL_ASSERT(
      eps != nullptr && eps->getDataType().has_value() &&
          eps->getDataType().value() == DataType::Double,
      "Epsilon (eps) is not a valid Double.");

  auto r = norm_properties_from_num_dims(x, kNormShapeNumDims);

  // Main algorithm
  auto welford_out = Welford(x, r.inner_reduction_axes);
  auto mean_bcast = broadcast(welford_out.avg, r.inner_broadcast_mask);
  auto x_sub_mean = sub(x, mean_bcast);

  auto var_sum_bcast = broadcast(welford_out.var_sum, r.inner_broadcast_mask);
  auto var = mul(var_sum_bcast, reciprocal(r.num_features));
  auto var_eps = add(var, eps);
  auto invstd = rsqrt(var_eps);

  auto y = mul(x_sub_mean, invstd);

  // Optional: norm * weight
  if (weight != nullptr) {
    auto weight_bcast = broadcast(weight, r.outer_broadcast_mask);
    y = mul(y, weight_bcast);
  }

  // Optional: norm * weight + bias
  if (bias != nullptr) {
    auto bias_bcast = broadcast(bias, r.outer_broadcast_mask);
    y = add(y, bias_bcast);
  }

  return {y, mean_bcast, invstd};
}

ForwardRMSNormResult rms_norm(
    TensorView* x,
    const std::vector<int64_t>& norm_shape,
    TensorView* weight,
    Val* eps) {
  return rms_norm(x, norm_shape.size(), weight, eps);
}

ForwardRMSNormResult rms_norm(
    TensorView* x,
    const size_t kNormShapeNumDims,
    TensorView* weight,
    Val* eps) {
  TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
  TORCH_INTERNAL_ASSERT(
      eps != nullptr && eps->getDataType().has_value() &&
          eps->getDataType().value() == DataType::Double,
      "Epsilon (eps) is not a valid Double.");

  auto r = norm_properties_from_num_dims(x, kNormShapeNumDims);

  // Main algorithm
  auto var_sum = sum(mul(x, x), r.inner_reduction_axes);
  auto var_sum_bcast = broadcast(var_sum, r.inner_broadcast_mask);
  auto var = mul(var_sum_bcast, reciprocal(r.num_features));
  auto var_eps = add(var, eps);
  auto invstd = rsqrt(var_eps);

  auto y = mul(x, invstd);

  // Optional: norm * weight
  if (weight != nullptr) {
    auto weight_bcast = broadcast(weight, r.outer_broadcast_mask);
    y = mul(y, weight_bcast);
  }

  return {y, invstd};
}

BackwardNormResult layer_norm_backward(
    TensorView* dy,
    TensorView* x,
    const std::vector<int64_t>& norm_shape,
    TensorView* mean,
    TensorView* invstd,
    TensorView* weight,
    TensorView* bias,
    const std::vector<bool>& output_mask) {
  TORCH_INTERNAL_ASSERT(dy != nullptr, "Grad Output is invalid.");
  TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
  TORCH_INTERNAL_ASSERT(mean != nullptr, "Mean is invalid.");
  TORCH_INTERNAL_ASSERT(invstd != nullptr, "Inv std is invalid.");

  auto r = norm_properties_from_num_dims(x, norm_shape.size());

  auto x_hat = mul(sub(x, mean), invstd);

  TensorView* grad_x_hat = nullptr;
  if (weight != nullptr) {
    auto* bcast_weight = broadcast(weight, r.outer_broadcast_mask);
    grad_x_hat = mul(dy, bcast_weight);
  } else {
    grad_x_hat = dy;
  }

  auto a = mul(r.num_features, grad_x_hat);

  auto b = sum(grad_x_hat, r.inner_reduction_axes);
  auto bcast_b = broadcast(b, r.inner_broadcast_mask);

  auto c1 = mul(grad_x_hat, x_hat);
  auto c2 = sum(c1, r.inner_reduction_axes);
  auto bcast_c2 = broadcast(c2, r.inner_broadcast_mask);
  auto c3 = mul(x_hat, bcast_c2);

  auto inner = sub(sub(a, bcast_b), c3);
  auto reciprocal_size = reciprocal(r.num_features);

  TensorView* dx = nullptr;
  if (output_mask[0]) {
    dx = mul(mul(reciprocal_size, invstd), inner);
  }

  TensorView* dw = nullptr;
  if (output_mask[1] && weight != nullptr) {
    dw = sum(mul(dy, x_hat), r.outer_reduction_axes);
  }

  TensorView* db = nullptr;
  if (output_mask[2] && bias != nullptr) {
    db = sum(dy, r.outer_reduction_axes);
  }
  return {dx, dw, db};
}

BackwardRMSNormResult rms_norm_backward(
    TensorView* dy,
    TensorView* x,
    const std::vector<int64_t>& norm_shape,
    TensorView* invstd,
    TensorView* weight,
    const std::vector<bool>& output_mask) {
  TORCH_INTERNAL_ASSERT(dy != nullptr, "Grad Output is invalid.");
  TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
  TORCH_INTERNAL_ASSERT(invstd != nullptr, "Inv std is invalid.");

  auto r = norm_properties_from_num_dims(x, norm_shape.size());

  auto x_hat = mul(x, invstd);

  TensorView* grad_x_hat = nullptr;
  if (weight != nullptr) {
    auto* bcast_weight = broadcast(weight, r.outer_broadcast_mask);
    grad_x_hat = mul(dy, bcast_weight);
  } else {
    grad_x_hat = dy;
  }

  auto a = mul(r.num_features, grad_x_hat);

  auto b = sum(grad_x_hat, r.inner_reduction_axes);
  auto bcast_b = broadcast(b, r.inner_broadcast_mask);

  auto c1 = mul(grad_x_hat, x_hat);
  auto c2 = sum(c1, r.inner_reduction_axes);
  auto bcast_c2 = broadcast(c2, r.inner_broadcast_mask);
  auto c3 = mul(x_hat, bcast_c2);

  auto inner = sub(sub(a, bcast_b), c3);
  auto reciprocal_size = reciprocal(r.num_features);

  TensorView* dx = nullptr;
  if (output_mask[0]) {
    dx = mul(mul(reciprocal_size, invstd), inner);
  }

  TensorView* dw = nullptr;
  if (output_mask[1] && weight != nullptr) {
    dw = sum(mul(dy, x_hat), r.outer_reduction_axes);
  }

  return {dx, dw};
}

ForwardNormResult batch_norm(
    TensorView* x,
    TensorView* weight,
    TensorView* bias,
    TensorView* running_mean,
    TensorView* running_var,
    const bool kTraining,
    Val* momentum,
    Val* eps,
    bool channels_last) {
  auto fusion = FusionGuard::getCurFusion();

  TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");

  TORCH_INTERNAL_ASSERT(
      !((running_var == nullptr) ^ (running_mean == nullptr)),
      "running stats should comes in pairs");

  TORCH_INTERNAL_ASSERT(
      momentum != nullptr && momentum->getDataType().has_value() &&
          momentum->getDataType().value() == DataType::Double,
      "Momentum is not a valid Double.");

  TORCH_INTERNAL_ASSERT(
      eps != nullptr && eps->getDataType().has_value() &&
          eps->getDataType().value() == DataType::Double,
      "Epsilon (eps) is not a valid Double.");

  // (B, C, H, W, D) tensor
  // M = outer = channels
  // N = reduction = B * H * W * D
  // weight = bias = (C) tensor
  const size_t kNumberOfDims =
      TensorDomain::noReductions(x->getMaybeRFactorDomain()).size();
  // channels last format means C dimension is at axis kNumberOfDims-1 at x
  size_t c_axis = channels_last ? kNumberOfDims - 1 : 1;

  std::vector<int> reduction_axes;
  std::vector<bool> broadcast_mask(kNumberOfDims, false);
  Val* num_features = IrBuilder::create<Double>(x->container(), 1);

  for (const auto axis : c10::irange(kNumberOfDims)) {
    if (axis != c_axis) {
      reduction_axes.push_back(axis);
      broadcast_mask[axis] = true;
      num_features = mul(num_features, x->domain()->domain()[axis]->extent());
    }
  }

  TensorView* y = nullptr;
  TensorView* mean = nullptr;
  TensorView* invstd = nullptr;
  if (kTraining || running_mean == nullptr) {
    // Algorithm
    auto welford_out = Welford(x, reduction_axes);

    // updating running mean and running var
    if (running_mean != nullptr && running_var != nullptr) {
      // Note: kTraining is true here!
      TORCH_INTERNAL_ASSERT(
          kTraining,
          "When running stats are provided, batch stats should only be computed during training");

      auto rev_momentum =
          sub(IrBuilder::create<Double>(x->container(), 1.0), momentum);
      auto current_mean_hat = mul(welford_out.avg, momentum);
      auto mean_hat = mul(running_mean, rev_momentum);
      auto new_mean_hat = add(mean_hat, current_mean_hat);

      auto num_feature_decrement = sub(num_features, x->container()->oneVal());
      auto unbiased_var =
          mul(welford_out.var_sum, reciprocal(num_feature_decrement));
      auto current_var_hat = mul(unbiased_var, momentum);
      auto var_hat = mul(running_var, rev_momentum);
      auto new_var_hat = add(var_hat, current_var_hat);

      // when inputs have been cast by parser. We want to alias the output to
      // the pre-cast input, so we can still update running stats
      auto cast_to_input_dtype = [fusion](
                                     Val* cast_input, Val* aliased_output) {
        auto unary_op = cast_input->definition();
        TORCH_INTERNAL_ASSERT(
            unary_op->isA<UnaryOp>() &&
                unary_op->as<UnaryOp>()->getUnaryOpType() == UnaryOpType::Cast,
            "check for cast op");
        auto input_to_cast = unary_op->input(0);
        TORCH_INTERNAL_ASSERT(
            input_to_cast->isFusionInput(),
            "IO_tensor batch_norm::running_stats can only updating input tensor to fusion");
        auto rm_dtype = input_to_cast->getDataType();
        TORCH_INTERNAL_ASSERT(
            rm_dtype.has_value(),
            "Input running stats must have dtype defined");
        auto cast_output = castOp(*rm_dtype, aliased_output);

        fusion->aliasOutputToInput(cast_output, input_to_cast);
      };

      if (running_mean->isFusionInput()) {
        fusion->aliasOutputToInput(new_mean_hat, running_mean);
      } else {
        cast_to_input_dtype(running_mean, new_mean_hat);
      }

      if (running_var->isFusionInput()) {
        fusion->aliasOutputToInput(new_var_hat, running_var);
      } else {
        cast_to_input_dtype(running_var, new_var_hat);
      }
    }

    mean = welford_out.avg;
    auto mean_bcast = broadcast(mean, broadcast_mask);
    auto x_sub_mean = sub(x, mean_bcast);

    auto var = mul(welford_out.var_sum, reciprocal(num_features));
    auto var_eps = add(var, eps);
    invstd = rsqrt(var_eps);
    auto invstd_bcast = broadcast(invstd, broadcast_mask);

    y = mul(x_sub_mean, invstd_bcast);
  } else {
    // This is inference mode with running stats
    auto r_mean_bcasted = broadcast(running_mean, broadcast_mask);
    auto x_sub_mean = sub(x, r_mean_bcasted);

    auto var_eps = add(running_var, eps);
    auto unbiased_invstd = rsqrt(var_eps);
    auto invstd_bcast = broadcast(unbiased_invstd, broadcast_mask);

    // During inference, mean/invstd output are empty tensors
    // on CPU, but not on CUDA. We need to make sure we have the same
    // behavior as with eager mode on CUDA.
    mean = set(running_mean); // use set to avoid "trivial input forwarding NOT
                              // IMPLEMENTED" error
    invstd = unbiased_invstd;
    y = mul(x_sub_mean, invstd_bcast);
  }

  // Optional: norm * weight
  if (weight) {
    auto weight_bcast = broadcast(weight, broadcast_mask);
    y = mul(y, weight_bcast);
  }

  // Optional: norm * weight + bias
  if (bias) {
    auto bias_bcast = broadcast(bias, broadcast_mask);
    y = add(y, bias_bcast);
  }
  return {y, mean, invstd};
}

BackwardNormResult batch_norm_backward(
    TensorView* input,
    TensorView* grad_output,
    TensorView* weight,
    TensorView* running_mean,
    TensorView* running_var,
    TensorView* save_mean,
    TensorView* save_invstd,
    const bool kTraining,
    Val* eps,
    const std::vector<bool>& output_mask,
    bool channels_last) {
  TORCH_INTERNAL_ASSERT(input != nullptr, "Input is invalid.");
  TORCH_INTERNAL_ASSERT(grad_output != nullptr, "Grad Output is invalid.");
  TORCH_INTERNAL_ASSERT(
      eps != nullptr && eps->getDataType().has_value() &&
          eps->getDataType().value() == DataType::Double,
      "Epsilon (eps) is not a valid Double.");

  // (B, C, H, W, D) tensor
  // M = outer = channels
  // N = reduction = B * H * W * D
  // weight = bias = (C) tensor
  const size_t kNumberOfDims =
      TensorDomain::noReductions(input->getMaybeRFactorDomain()).size();
  // channels last format means C dimension is at axis kNumberOfDims-1 at x /
  // grad_out
  size_t c_axis = channels_last ? kNumberOfDims - 1 : 1;

  std::vector<int> reduction_axes;
  std::vector<bool> broadcast_mask(kNumberOfDims, false);
  Val* num_features = nullptr;
  for (const auto axis : c10::irange(kNumberOfDims)) {
    if (axis != c_axis) {
      reduction_axes.push_back(axis);
      broadcast_mask[axis] = true;
      if (num_features == nullptr) {
        num_features =
            castOp(DataType::Double, input->domain()->domain()[axis]->extent());
      } else {
        num_features =
            mul(num_features, input->domain()->domain()[axis]->extent());
      }
    }
  }

  auto mean = save_mean;
  auto invstd = save_invstd;
  if (kTraining) {
    TORCH_INTERNAL_ASSERT(
        save_mean != nullptr && save_invstd != nullptr,
        "When training=True, save_mean and save_invstd are required.");
  } else {
    mean = running_mean;
    invstd = rsqrt(add(running_var, eps));
  }

  mean = broadcast(mean, broadcast_mask);

  auto norm = reciprocal(num_features);

  auto grad_output_sum = sum(grad_output, reduction_axes);
  auto dot_p = sum(mul(grad_output, sub(input, mean)), reduction_axes);

  auto grad_mean = broadcast(mul(grad_output_sum, norm), broadcast_mask);
  auto proj_scale =
      broadcast(mul(mul(dot_p, norm), mul(invstd, invstd)), broadcast_mask);
  TensorView* grad_scale = nullptr;

  if (weight == nullptr) {
    grad_scale =
        mul(broadcast(invstd, broadcast_mask),
            IrBuilder::create<Double>(input->container(), 1));
  } else {
    grad_scale = mul(
        broadcast(invstd, broadcast_mask), broadcast(weight, broadcast_mask));
  }

  TensorView* grad_input = nullptr;
  if (kTraining) {
    auto proj = mul(sub(input, mean), proj_scale);
    grad_input = mul(sub(sub(grad_output, proj), grad_mean), grad_scale);
  } else {
    grad_input = mul(grad_output, grad_scale);
  }

  TensorView* grad_weight = nullptr;
  if (output_mask[1]) {
    grad_weight = mul(dot_p, invstd);
  }

  TensorView* grad_bias = nullptr;
  if (output_mask[2]) {
    grad_bias = grad_output_sum;
  }

  return {grad_input, grad_weight, grad_bias};
}

ForwardNormResult instance_norm(
    TensorView* x,
    TensorView* weight,
    TensorView* bias,
    TensorView* running_mean,
    TensorView* running_var,
    const bool kUseInputStats,
    Val* momentum,
    Val* eps,
    bool channels_last) {
  auto fusion = FusionGuard::getCurFusion();

  TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");

  TORCH_INTERNAL_ASSERT(
      !((running_var == nullptr) ^ (running_mean == nullptr)),
      "running stats should comes in pairs");

  TORCH_INTERNAL_ASSERT(
      momentum != nullptr && momentum->getDataType().has_value() &&
          momentum->getDataType().value() == DataType::Double,
      "Momentum is not a valid Double.");

  TORCH_INTERNAL_ASSERT(
      eps != nullptr && eps->getDataType().has_value() &&
          eps->getDataType().value() == DataType::Double,
      "Epsilon (eps) is not a valid Double.");

  // (B, C, H, W, D) tensor
  // M = outer = B * C
  // N = reduction = H * W * D
  // weight = bias = C tensor
  const size_t kBatchDim = 0;
  const size_t kNumberOfDims =
      TensorDomain::noReductions(x->getMaybeRFactorDomain()).size();
  const size_t kChannelsDim = channels_last ? kNumberOfDims - 1 : 1;

  std::vector<int> x_reduction_axes;
  std::vector<bool> x_broadcast_mask(kNumberOfDims, false);
  Val* N = IrBuilder::create<Double>(x->container(), 1);
  for (const auto axis : c10::irange(kNumberOfDims)) {
    if (axis != kBatchDim && axis != kChannelsDim) {
      x_reduction_axes.push_back(axis);
      x_broadcast_mask[axis] = true;
      N = mul(N, x->domain()->domain()[axis]->extent());
    }
  }
  Val* B = IrBuilder::create<Double>(x->container(), 1);
  B = mul(B, x->domain()->domain()[kBatchDim]->extent());

  std::vector<bool> channels_only_broadcast_mask(kNumberOfDims, false);
  for (const auto axis : c10::irange(kNumberOfDims)) {
    if (axis != kChannelsDim) {
      channels_only_broadcast_mask[axis] = true;
    }
  }

  TensorView* y = nullptr;
  TensorView* mean = nullptr;
  TensorView* invstd = nullptr;
  if (kUseInputStats || running_mean == nullptr) {
    // Algorithm
    auto welford_out = Welford(x, x_reduction_axes);

    // updating running mean and running var
    if (running_mean != nullptr && running_var != nullptr) {
      auto _running_mean = running_mean;
      auto _running_var = running_var;
      if (_running_mean->getDataType().value() == DataType::Half ||
          _running_mean->getDataType().value() == DataType::BFloat16) {
        _running_mean = castOp(DataType::Float, _running_mean);
      }
      if (_running_var->getDataType().value() == DataType::Half ||
          _running_var->getDataType().value() == DataType::BFloat16) {
        _running_var = castOp(DataType::Float, running_var);
      }
      auto rev_momentum =
          sub(IrBuilder::create<Double>(x->container(), 1.0), momentum);
      auto current_mean_hat = mul(welford_out.avg, momentum);
      auto mean_hat = mul(_running_mean, rev_momentum);
      auto new_mean_hat = add(mean_hat, current_mean_hat);

      // NS: static_cast to workaround VC++ error, see
      // https://godbolt.org/z/6Prd77xYs
      auto new_mean_sum = sum(new_mean_hat, {static_cast<int>(kBatchDim)});
      auto new_mean_channels_only = mul(new_mean_sum, reciprocal(B));
      if (running_mean->getDataType().value() == DataType::Half ||
          running_mean->getDataType().value() == DataType::BFloat16) {
        new_mean_channels_only =
            castOp(running_mean->getDataType().value(), new_mean_channels_only);
      }
      // fusion->addOutput(new_mean_channels_only);
      fusion->aliasOutputToInput(new_mean_channels_only, running_mean);

      auto num_feature_decrement = sub(N, x->container()->oneVal());
      auto unbiased_var =
          mul(welford_out.var_sum, reciprocal(num_feature_decrement));
      auto current_var_hat = mul(unbiased_var, momentum);
      auto var_hat = mul(_running_var, rev_momentum);
      auto new_var_hat = add(var_hat, current_var_hat);

      // NS: static_cast to workaround VC++ error, see
      // https://godbolt.org/z/6Prd77xYs
      auto new_var_sum = sum(new_var_hat, {static_cast<int>(kBatchDim)});
      auto new_var_channels_only = mul(new_var_sum, reciprocal(B));
      if (running_var->getDataType().value() == DataType::Half ||
          running_var->getDataType().value() == DataType::BFloat16) {
        new_var_channels_only =
            castOp(running_var->getDataType().value(), new_var_channels_only);
      }
      // fusion->addOutput(new_var_channels_only);
      fusion->aliasOutputToInput(new_var_channels_only, running_var);
    }

    mean = welford_out.avg;
    auto mean_bcast = broadcast(mean, x_broadcast_mask);
    auto x_sub_mean = sub(x, mean_bcast);

    auto var = mul(welford_out.var_sum, reciprocal(N));
    auto var_eps = add(var, eps);
    invstd = rsqrt(var_eps);
    auto invstd_bcast = broadcast(invstd, x_broadcast_mask);

    y = mul(x_sub_mean, invstd_bcast);
  } else {
    // This is inference mode with running stats
    auto r_mean_bcasted = broadcast(running_mean, channels_only_broadcast_mask);
    auto x_sub_mean = sub(x, r_mean_bcasted);

    auto var_eps = add(running_var, eps);
    auto unbiased_invstd = rsqrt(var_eps);
    auto invstd_bcast =
        broadcast(unbiased_invstd, channels_only_broadcast_mask);

    // During inference, mean/invstd output are empty tensors
    mean = TensorViewBuilder().shape(std::vector<int64_t>{0}).build();
    invstd = TensorViewBuilder().shape(std::vector<int64_t>{0}).build();
    y = mul(x_sub_mean, invstd_bcast);
  }

  // Optional: norm * weight
  if (weight) {
    auto weight_bcast = broadcast(weight, channels_only_broadcast_mask);
    y = mul(y, weight_bcast);
  }

  // Optional: norm * weight + bias
  if (bias) {
    auto bias_bcast = broadcast(bias, channels_only_broadcast_mask);
    y = add(y, bias_bcast);
  }
  return {y, mean, invstd};
}

BackwardNormResult instance_norm_backward(
    TensorView* input,
    TensorView* grad_output,
    TensorView* weight,
    TensorView* running_mean,
    TensorView* running_var,
    TensorView* save_mean,
    TensorView* save_invstd,
    const bool kTraining,
    Val* eps,
    const std::vector<bool>& output_mask,
    bool channels_last) {
  TORCH_INTERNAL_ASSERT(input != nullptr, "Input is invalid.");
  TORCH_INTERNAL_ASSERT(grad_output != nullptr, "Grad Output is invalid.");
  TORCH_INTERNAL_ASSERT(
      eps != nullptr && eps->getDataType().has_value() &&
          eps->getDataType().value() == DataType::Double,
      "Epsilon (eps) is not a valid Double.");

  // (B, C, H, W, D) tensor
  // M = outer = channels
  // N = reduction = B * H * W * D
  // weight = bias = (C) tensor
  const size_t kNumberOfDims =
      TensorDomain::noReductions(input->getMaybeRFactorDomain()).size();
  // channels last format means C dimension is at axis kNumberOfDims-1 at x /
  // grad_out
  const size_t b_axis = 0; // for clarity
  const size_t c_axis = channels_last ? kNumberOfDims - 1 : 1;

  std::vector<int> reduction_axes;
  std::vector<bool> broadcast_mask(kNumberOfDims, false);
  // weight has its own broadcast mask as it is broadcast for the batch unlike
  // mean/var
  std::vector<bool> weight_broadcast_mask(kNumberOfDims, false);
  Val* num_features = nullptr;
  for (const auto axis : c10::irange(kNumberOfDims)) {
    if (axis != c_axis) {
      weight_broadcast_mask[axis] = true;
      if (axis != b_axis) {
        reduction_axes.push_back(axis);
        broadcast_mask[axis] = true;
        if (num_features == nullptr) {
          num_features = castOp(
              DataType::Double, input->domain()->domain()[axis]->extent());
        } else {
          num_features =
              mul(num_features, input->domain()->domain()[axis]->extent());
        }
      }
    }
  }

  auto mean = save_mean;
  auto invstd = save_invstd;
  if (kTraining) {
    TORCH_INTERNAL_ASSERT(
        save_mean != nullptr && save_invstd != nullptr,
        "When training=True, save_mean and save_invstd are required.");
  } else {
    mean = running_mean;
    invstd = rsqrt(add(running_var, eps));
  }
  mean = broadcast(mean, broadcast_mask);

  auto norm = reciprocal(num_features);

  auto grad_output_sum = sum(grad_output, reduction_axes);
  auto dot_p = sum(mul(grad_output, sub(input, mean)), reduction_axes);

  auto grad_mean = broadcast(mul(grad_output_sum, norm), broadcast_mask);

  auto proj_scale =
      broadcast(mul(mul(dot_p, norm), mul(invstd, invstd)), broadcast_mask);

  TensorView* grad_scale = nullptr;

  if (weight == nullptr) {
    grad_scale =
        mul(broadcast(invstd, broadcast_mask),
            IrBuilder::create<Double>(input->container(), 1));
  } else {
    grad_scale =
        mul(broadcast(invstd, broadcast_mask),
            broadcast(weight, weight_broadcast_mask));
  }

  TensorView* grad_input = nullptr;
  if (kTraining) {
    auto proj = mul(sub(input, mean), proj_scale);
    grad_input = mul(sub(sub(grad_output, proj), grad_mean), grad_scale);
  } else {
    grad_input = mul(grad_output, grad_scale);
  }

  TensorView* grad_weight = nullptr;
  TensorView* grad_weight_reduced = nullptr;
  if (output_mask[1]) {
    grad_weight = mul(dot_p, invstd);
    // TODO: grad weight needs to be reduced across batch-dim but is this the
    // most efficient place or can reduction happen earlier?
    grad_weight_reduced = sum(grad_weight, {0});
  }

  TensorView* grad_bias = nullptr;
  TensorView* grad_bias_reduced = nullptr;
  if (output_mask[2]) {
    grad_bias = grad_output_sum;
    // TODO: same as above for grad weight
    grad_bias_reduced = sum(grad_bias, {0});
  }

  return {grad_input, grad_weight_reduced, grad_bias_reduced};
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch