File: transform_view.cpp

#include <torch/csrc/jit/codegen/cuda/transform_view.h>

#include <torch/csrc/jit/codegen/cuda/arith.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
#include <torch/csrc/jit/codegen/cuda/ir_internal_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
#include <torch/csrc/jit/codegen/cuda/transform_iter.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

//! There are four domains associated with performing a view operation:
//! 1) Original Domain:
//!   This domain is the original input to the view operation. It has no
//!   transforms on it; it is, however, passed in without its reduction domains
//!   (as is expected since we're trying to generate the output of the
//!   operations).
//!
//! 2) Trivially reduced domain:
//!   Predicting which domains get trivially reduced is itself not trivial. If
//!   a broadcast is between two iter domains in the original domain that must
//!   be merged for the view transform:
//!     - If the broadcast domain lines up with a broadcast domain in the final
//!       tensor domain, keep it.
//!     - If the domain is size-1 but not marked as a broadcast domain (runtime
//!       size==1), merge it into a neighboring domain.
//!       Note: This isn't something we generally support consistently.
//!     - If the broadcast domain is marked as a compile-time broadcast domain
//!       and doesn't line up with a broadcast domain in the final result,
//!       trivially reduce it.
//!   The index for these transformations is marked as the index of the
//!   original domain, as that's the input for the trivial reduction. This
//!   produces the trivially reduced domain.
//!
//! 3) Post-view Domain:
//!   This domain is the original domain after the trivial reductions and all
//!   transformations. This domain holds the rfactor domains determined by
//!   merge/split operations of the find transformations pass. It is the final
//!   domain without all the broadcast operations (it can have some that were
//!   preserved through the transformations).
//!       For example: {1, 2, 1, 4} -> {1, 2, 1, 2, 2} doesn't have any
//!         conflicts between the view transformation and the broadcast
//!         dimensions, so the broadcasts won't be trivially reduced; they will
//!         simply be propagated through the view.
//!         {1, 2, 1, 4} -> {1, 8, 1} does have the second 1 dimension in
//!         between the 2 and the 4 that must be merged to form the 8. The
//!         first broadcast axis will be propagated through the domains
//!         unaffected, yet the second broadcast axis will be trivially
//!         reduced, then rebroadcasted.
//!  The transformation indices marked for the splits/merges that produce this
//!  domain are based on an "in progress" tensor view (called the transform
//!  view index in the find transformation pass). This allows us to simply
//!  apply these transformations serially to produce this domain.
//!
//! 4) Post-broadcast Domain:
//!    This domain finally matches the output of the view operation fully and
//!    can be used in further computations.
//!
//! View process at compute time:
//!   1) View takes in the input TensorView x, the original runtime sizes
//!      (std::vector<int64_t>), and the viewed runtime sizes
//!      (std::vector<int64_t>).
//!   2) AnalyzeView is called, which will figure out what series of
//!      transformations is required to go from the input tensor to the output
//!      tensor. These transformations are recorded.
//!   3) Sum operation is called on the trivial reduction axes from the
//!      analysis.
//!   4) applyViewTransforms will generate the output domain of the view
//!      operation.
//!        Calls TensorDomain::view(view_analysis) which returns the rfactored
//!        domain.
//!        Gets forwarded to transformView(TensorDomain, view_analysis)
//!        Gets forwarded to createViewDomain(TensorDomain, view_analysis)
//!        createViewDomain creates the new root domain, and calls
//!        createRfactorDomain on view_analysis.transforms().
//!   5) broadcast will be called with view_analysis.broadcast_axes.
//!
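//! Worked example (an illustrative trace of the process above): for the view
//! {1, 2, 1, 4} -> {1, 8, 1}, with both size-1 input dimensions marked as
//! broadcast, the analysis would record, in order:
//!   TrivialReductionTransform(2)  // reduce the broadcast between the 2 and 4
//!   MergeTransform(1)             // merge the 2 and the 4 into 8
//!   BroadcastTransform(2)         // rebroadcast the last size-1 output dim
//! The first broadcast dimension lines up in both views, so it is simply
//! propagated through the view.
//!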
//! TODO: Caching assumes that all size-1 inputs are correctly marked as a
//! broadcast dimension. We should probably remove the runtime size-1 merge
//! support in find transformation.
//!
//! Simple abstract class to record transformation and the indices required to
//! apply it.
class Transform : public PolymorphicBase {
 public:
  virtual std::string toString() const = 0;

  int64_t index() const {
    return index_;
  }

 protected:
  // Relevant location information for the transformation. The stored index
  // records where/when the transformation has to be applied (see the long
  // comment at the top of this file).
  Transform(int64_t index) : index_(index) {}

  const int64_t index_ = 0;
};

class ViewTransform : public Transform {
 public:
  // Function to apply the transformation. Transformation is applied on
  // current_transformed_domain. root_domain is required here to replace
  // IterDomains so we can flip the rfactor flag on the root domain if it's
  // involved in merge/split transforms to produce the rfactor domain.
  virtual void createRfactorDomain(
      std::vector<IterDomain*>& root_domain,
      std::vector<IterDomain*>& current_transformed_domain) = 0;

  // Convenience function to replace id in root_domain with an id that has its
  // expanded extent resolved and the rfactor flag turned on.
  static IterDomain* replaceRootIdWithRFactor(
      std::vector<IterDomain*>& root_domain,
      IterDomain* id) {
    auto root_domain_it = std::find(root_domain.begin(), root_domain.end(), id);

    TORCH_INTERNAL_ASSERT(
        root_domain_it != root_domain.end(),
        "Wanted to replace ",
        id->toString(),
        " in root with an rfactor dimension, but IterDomain was not found in root.");

    auto root_domain_pos = std::distance(root_domain.begin(), root_domain_it);

    bool is_expanded_dim = id->hasExpandedExtent();

    auto extent = is_expanded_dim ? id->expandedExtent() : id->extent();

    auto cloned_id =
        IterDomainBuilder(id)
            .iter_type(
                is_expanded_dim ? IterType::Iteration : id->getIterType())
            .extent(extent)
            .expanded_extent(nullptr)
            .is_rfactor_domain(true)
            .build();

    root_domain.erase(root_domain.begin() + root_domain_pos);
    root_domain.insert(root_domain.begin() + root_domain_pos, cloned_id);
    return cloned_id;
  }
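
  // For example (illustrative): if root_domain is [I0, I1] and I1 is an
  // expanded broadcast (extent 1, expanded extent 8), the replacement is an
  // Iteration-type IterDomain with extent 8, no expanded extent, and the
  // rfactor flag set; root_domain becomes [I0, I1'].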

  // Debugging utility to convert the transformation into a string.
  virtual std::string toString() const = 0;

 protected:
  ViewTransform(const int64_t& index) : Transform(index) {}
};

namespace {
//! The merge transformation either combines two root iterDomains together OR
//! the last rfactor iterDomain with a root iterDomain. Unlike the general
//! TensorView merge there's no merging across axes not placed in consecutive
//! positions for View.
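//!
//! E.g. (illustrative): applying a merge at index 0 to a transformed domain
//! [I0{2}, I1{3}, I2{4}] produces [I0*I1{6}, I2{4}]; the merged extent is
//! mul(2, 3) and the new IterDomain is marked as an rfactor domain.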
class MergeTransform final : public ViewTransform {
 public:
  MergeTransform(int64_t index) : ViewTransform(index) {}

  virtual std::string toString() const override {
    std::stringstream ss;
    ss << "Merge at index: " << index_;
    return ss.str();
  }

  void createRfactorDomain(
      std::vector<IterDomain*>& root_domain,
      std::vector<IterDomain*>& current_transformed_domain) override {
    TORCH_INTERNAL_ASSERT(
        (index_ + 1) < current_transformed_domain.size(),
        "Tried to apply: ",
        toString(),
        "\t To domain: \t",
        current_transformed_domain);

    // Assumed to never merge over non-contiguous dimensions.
    IterDomain* outer_id = current_transformed_domain[index_];
    if (!outer_id->isRFactorProduct()) {
      outer_id = replaceRootIdWithRFactor(root_domain, outer_id);
    }

    IterDomain* inner_id = current_transformed_domain[index_ + 1];
    if (!inner_id->isRFactorProduct()) {
      inner_id = replaceRootIdWithRFactor(root_domain, inner_id);
    }

    TORCH_INTERNAL_ASSERT(
        outer_id->start()->isZeroInt() && inner_id->start()->isZeroInt(),
        "Didn't expect to apply view transformations on an iter domain",
        " starting at a non-zero position.");

    auto merged_extent = mul(outer_id->extent(), inner_id->extent());

    auto new_merged_id =
        IterDomainBuilder(FusionGuard::getCurFusion()->zeroVal(), merged_extent)
            .is_rfactor_domain(true)
            .build();

    IrBuilder::create<Merge>(new_merged_id, outer_id, inner_id);

    current_transformed_domain.erase(
        current_transformed_domain.begin() + index_);
    current_transformed_domain.erase(
        current_transformed_domain.begin() + index_);
    current_transformed_domain.insert(
        current_transformed_domain.begin() + index_, new_merged_id);
  }
};

//! The split transformation creates two new iterDomains via an outer split.
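//!
//! E.g. (illustrative): applying a split at index 0 with factor 2 to
//! [I0{6}, I1{4}] produces [I0o{2}, I0i{3}, I1{4}]; with an outer split the
//! factor becomes the outer extent and the inner extent is ceilDiv(6, 2).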
class SplitTransform final : public ViewTransform {
 public:
  SplitTransform(const int64_t index, int64_t split_factor)
      : ViewTransform(index), split_factor_(split_factor) {
    TORCH_INTERNAL_ASSERT(
        split_factor > 0,
        "Split factors must be greater than 0, but found ",
        split_factor,
        " during view transformation.");
  }

  virtual std::string toString() const override {
    std::stringstream ss;
    ss << "Split Index at: " << index_ << " by: " << split_factor_ << std::endl;
    return ss.str();
  }

  void createRfactorDomain(
      std::vector<IterDomain*>& root_domain,
      std::vector<IterDomain*>& current_transformed_domain) override {
    TORCH_INTERNAL_ASSERT(
        index_ < current_transformed_domain.size(),
        "Index: \t",
        index_,
        "\t Domain Size:\t",
        current_transformed_domain.size());

    auto factor = IrBuilder::create<Int>(split_factor_);

    IterDomain* id = current_transformed_domain[index_];
    if (!id->isRFactorProduct()) {
      id = replaceRootIdWithRFactor(root_domain, id);
    }

    TORCH_INTERNAL_ASSERT(
        id->start()->isZeroInt(),
        "Didn't expect to apply view transformations on an iter domain",
        " starting at a non-zero position.");

    Val* remainder = ceilDiv(id->extent(), factor);

    // outer loop IterDomain
    IterDomain* factor_id =
        IterDomainBuilder(FusionGuard::getCurFusion()->zeroVal(), factor)
            .parallel_type(id->getParallelType())
            .iter_type(id->getIterType())
            .is_rfactor_domain(true)
            .build();

    // inner loop IterDomain
    IterDomain* remainder_id =
        IterDomainBuilder(
            FusionGuard::getCurFusion()->zeroVal(), remainder->as<Int>())
            .is_rfactor_domain(true)
            .build();

    IrBuilder::create<Split>(factor_id, remainder_id, id, factor, false);

    current_transformed_domain.erase(
        current_transformed_domain.begin() + index_);
    current_transformed_domain.insert(
        current_transformed_domain.begin() + index_, remainder_id);
    current_transformed_domain.insert(
        current_transformed_domain.begin() + index_, factor_id);
  }

  int64_t split_factor() const {
    return split_factor_;
  }

 private:
  const int64_t split_factor_ = 0;
};

//! For any singleton dimensions in the new view, we create an implicit
//! broadcast dimension. We apply these transforms after the trivial reduction
//! and view transformation steps.
class BroadcastTransform final : public Transform {
 public:
  BroadcastTransform(int64_t index) : Transform(index) {}

  virtual std::string toString() const override {
    std::stringstream ss;
    ss << "Broadcast at: " << index_ << std::endl;
    return ss.str();
  }
};

//! For any implicit broadcast dimensions in the original view, we remove
//! them using a trivial reduction.
class TrivialReductionTransform final : public Transform {
 public:
  TrivialReductionTransform(int64_t index) : Transform(index) {}

  virtual std::string toString() const override {
    std::stringstream ss;
    ss << "Trivial reduction at: " << index_ << std::endl;
    return ss.str();
  }
};

//! The primary class that generates the transformations to go from
//! the original view to the new view.
class AnalyzeViewTransformation {
 public:
  AnalyzeViewTransformation(
      const std::vector<int64_t>& original_view,
      const std::vector<int64_t>& new_view,
      std::vector<IterDomain*> root_domain = {})
      : root_domain_not_provided_(root_domain.empty()),
        root_domain_(root_domain),
        root_is_transformed_(original_view.size(), false),
        original_view_(original_view),
        new_view_(new_view) {
    TORCH_INTERNAL_ASSERT(
        root_domain.empty() || original_view.size() == root_domain.size(),
        "Incoming domain must match the original view sizes for view.");
    // Check that the element-count products of the original and new view
    // sizes are equal.
    const int64_t kOriginalNumElements = std::accumulate(
        original_view_.begin(), original_view_.end(), 1, std::multiplies<>());
    const int64_t kNewNumElements = std::accumulate(
        new_view_.begin(), new_view_.end(), 1, std::multiplies<>());
    TORCH_INTERNAL_ASSERT(
        kOriginalNumElements == kNewNumElements,
        "Total element counts across view operation must match.");
  }

  AnalyzeViewConstraint constraint() {
    findTransformation();

    AnalyzeViewConstraint constraint;
    constraint.original_constraint =
        std::vector<int64_t>(original_view_.begin(), original_view_.end());
    for (auto i : c10::irange(constraint.original_constraint.size())) {
      if (constraint.original_constraint[i] != 1) {
        constraint.original_constraint[i] = 0;
      }
    }

    constraint.new_constraint =
        std::vector<int64_t>(new_view_.begin(), new_view_.end());
    for (auto i : c10::irange(constraint.new_constraint.size())) {
      if (constraint.new_constraint[i] != 1) {
        constraint.new_constraint[i] = 0;
      }
    }

    for (auto trivial_reduce : trivial_reduction_transforms_) {
      constraint.trivial_reduction_string.push_back(trivial_reduce->index());
    }

    for (auto broadcast : broadcast_transforms_) {
      constraint.broadcast_string.push_back(broadcast->index());
    }

    // Delimiter for split/merge transforms is -2
    for (auto split_merge : view_transforms_) {
      if (split_merge->isA<SplitTransform>()) {
        constraint.split_merge_string.push_back(split_merge->index());
        constraint.split_merge_string.push_back(
            split_merge->as<SplitTransform>()->split_factor());
        constraint.split_merge_string.push_back(-2);
      } else {
        TORCH_INTERNAL_ASSERT(
            split_merge->isA<MergeTransform>(),
            "Unrecognized transformation found.");
        constraint.split_merge_string.push_back(split_merge->index());
        constraint.split_merge_string.push_back(-2);
      }
    }

    return constraint;
  }
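
  // E.g. (an illustrative encoding): for {1, 2, 1, 4} -> {1, 8, 1}, the
  // constraint above would hold original_constraint = {1, 0, 1, 0},
  // new_constraint = {1, 0, 1}, trivial_reduction_string = {2},
  // broadcast_string = {2}, and split_merge_string = {1, -2} (a merge at
  // index 1, followed by the -2 delimiter).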

  // Fill out all the information needed in AnalyzeViewResult; the result
  // should contain everything required to perform the view operation.
  AnalyzeViewResult run() {
    // Find all the transformations to go from the original tensor domain to the
    // final output of the view operations.
    findTransformation();

    auto trivial_reduction_axes = generateTrivialReductionAxes();
    auto broadcast_axes = generateBroadcastAxes();

    // Move data to AnalyzeViewResult and return it.
    return {broadcast_axes, trivial_reduction_axes, view_transforms_};
  }

 private:
  // Returns the bool flags that should be used to broadcast the output view
  // tensor
  std::vector<bool> generateBroadcastAxes() {
    std::vector<bool> broadcast_axes(new_view_.size(), false);
    for (auto& bcast : broadcast_transforms_) {
      broadcast_axes.at(bcast->index()) = true;
    }
    return broadcast_axes;
  }

  // Returns the positions for the trivial reductions to be performed before the
  // view operation
  std::vector<int> generateTrivialReductionAxes() {
    std::vector<int> reduction_axes;
    for (auto& tred : trivial_reduction_transforms_) {
      reduction_axes.push_back(tred->index());
    }
    return reduction_axes;
  }

  std::string toString() {
    std::stringstream output;
    output << "===============================" << std::endl;
    output << "old:";
    for (auto s : original_view_) {
      output << " " << s;
    }
    output << std::endl;

    output << "===============================" << std::endl;
    output << "new:";
    for (auto s : new_view_) {
      output << " " << s;
    }
    output << std::endl;

    output << "===============================" << std::endl;
    for (auto& trivial_reduction : trivial_reduction_transforms_) {
      output << trivial_reduction->toString() << "\n";
    }
    for (auto& split_or_merge : view_transforms_) {
      output << split_or_merge->toString() << "\n";
    }
    for (auto& broadcast : broadcast_transforms_) {
      output << broadcast->toString() << "\n";
    }
    output << "===============================" << std::endl;
    return output.str();
  }

  // Validation check after transformations are all found

  bool isImplicitBroadcast(int64_t original_view_index) const {
    if (root_domain_not_provided_) {
      return original_view_[original_view_index] == 1;
    } else {
      TORCH_INTERNAL_ASSERT(original_view_index < root_domain_.size());
      return root_domain_[original_view_index]->isImplicitBroadcast() &&
          !root_domain_[original_view_index]->hasExpandedExtent();
    }
  }

  //! Find the broadcast, merge and split operations necessary
  //! to transform the original view into the new view
  void findTransformation() {
    // There are three particularly important state indices we're working with.
    // There is:
    //   1) original_view_index which is indexing into the original tensor
    //      domain after all reductions are removed. This lines up with the last
    //      domain in original view that we added to current_size.
    //   2) transform_view_index which is the index of the transformations as
    //      we're virtually "developing" the output tensor domain (split/merge
    //      transformations post trivial reductions).
    //   3) The new_view_index which is directly associated with the new_view
    //      and the dimension in new_view we're currently trying to create.
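    //
    // E.g. (an illustrative mid-analysis state): while processing
    // {2, 3, 4} -> {6, 4}, right after merging the first two input dimensions
    // we'd have original_view_index == 1, transform_view_index == 0,
    // new_view_index == 0, and current_size == 6.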

    int64_t original_view_index = 0;
    int64_t transform_view_index = 0;
    int64_t new_view_index = 0;
    int64_t current_size = original_view_[0];

    // Safety counters to make sure we don't end up in an infinite loop.
    int64_t prev_original_view_index = std::numeric_limits<int64_t>::max();
    int64_t prev_new_view_index = std::numeric_limits<int64_t>::max();

    TORCH_INTERNAL_ASSERT(
        view_transforms_.empty(),
        "Already ran find transformation pass for View op, cannot run a second time.");

    // Iterate until original view is completely consumed and new view is
    // completely generated.
    while (original_view_index < original_view_.size() ||
           new_view_index < new_view_.size()) {
      TORCH_INTERNAL_ASSERT(
          !(prev_new_view_index == new_view_index &&
            prev_original_view_index == original_view_index),
          "Infinite loop detected in AnalyzeViewTransformation::findTransformation(). Bailing.");

      prev_new_view_index = new_view_index;
      prev_original_view_index = original_view_index;

      if (new_view_index >= new_view_.size()) {
        TORCH_INTERNAL_ASSERT(
            current_size == 1,
            "View is complete, but there's still some elements to distribute.");
      }

      if ((new_view_index + 1 >= new_view_.size() ||
           (new_view_[new_view_index + 1] != 1)) &&
          original_view_index + 1 < original_view_.size() &&
          original_view_[original_view_index + 1] == 1 &&
          !isImplicitBroadcast(original_view_index + 1)) {
        // Next index in original_view is runtime size 1 and next new view is
        // not, merge the size 1 into the current view before moving on. Even if
        // the current size and new view size match we could have a trailing
        // size 1 dimension on the input that needs to be merged in.
        view_transforms_.push_back(
            std::make_shared<MergeTransform>(transform_view_index));
        ++original_view_index;
        continue;
      }

      if (new_view_index < new_view_.size() &&
          // Still new dimensions to resolve and current size does resolve it.
          current_size == new_view_[new_view_index]) {
        // Keep this dimension, it's good to go, we hit a boundary where there's
        // a multiple of original dims, that matches a multiple of view dims.
        // Increment state and keep going.

        ++transform_view_index;
        ++new_view_index;
        ++original_view_index;

        // Update current_size with the next size in original view
        if (original_view_index < original_view_.size()) {
          current_size = original_view_[original_view_index];
        } else {
          current_size = 0;
        }
        continue;
      }

      // Compile time broadcast in new view, but not a matching one in original
      // view. Insert broadcast and increment new_view. Size 1 dimensions in
      // new_view that don't match up with runtime size 1's in original view are
      // assumed to be broadcast (not a split from a runtime domain).
      if (new_view_index < new_view_.size() && new_view_[new_view_index] == 1) {
        broadcast_transforms_.push_back(
            std::make_shared<BroadcastTransform>(new_view_index));
        ++new_view_index;
        continue;
      }

      // If we run out of original_view dimensions we could still have broadcast
      // dimensions for new_view, but that should be hit before this point.
      TORCH_INTERNAL_ASSERT(
          current_size != 0,
          "View analysis failed, should never process an empty size unless we ",
          "simply need to add broadcasts to the post-view domain.");

      if (current_size == 1 && isImplicitBroadcast(original_view_index)) {
        // Original view has a compile time size 1 dimension, and it's not found
        // in the new_view_ (otherwise would have been caught in a branch
        // above). Do a trivial reduction.
        trivial_reduction_transforms_.push_back(
            std::make_shared<TrivialReductionTransform>(original_view_index));
        ++original_view_index;

        // Update original position and current size.
        if (original_view_index < original_view_.size()) {
          current_size = original_view_[original_view_index];
        } else {
          current_size = 0;
        }

        continue;
      }

      if (original_view_index + 1 < original_view_.size() &&
          isImplicitBroadcast(original_view_index + 1)) {
        // Original view has a compile time size 1 dimension, and it's
        // interfering with necessary transformations. Do a trivial reduction.
        ++original_view_index;
        trivial_reduction_transforms_.push_back(
            std::make_shared<TrivialReductionTransform>(original_view_index));

        continue;
      }

      // We're only left with performing transformations to match a new_view
      // dimension; there must be an active new_view dimension.
      TORCH_INTERNAL_ASSERT(
          new_view_index < new_view_.size(),
          "Expecting to still have new dimensions to work on in view, but none left.");

      if (new_view_index < new_view_.size() &&
          current_size % new_view_[new_view_index] == 0) {
        // Insert split to generate the next new_view domain.
        view_transforms_.push_back(std::make_shared<SplitTransform>(
            transform_view_index, new_view_[new_view_index]));
        current_size /= new_view_[new_view_index];
        TORCH_INTERNAL_ASSERT(current_size > 1, "This should be unreachable.");
        // Update transform and new since a split doesn't increment from the
        // original domain we're working on.
        ++transform_view_index;
        ++new_view_index;
        continue;
      }

      // Need more of the original_view dimension to resolve the new_view
      // dimension, merge the next dimension in.
      TORCH_INTERNAL_ASSERT(
          original_view_index + 1 < original_view_.size(),
          "Expecting to still have original dimensions to work on in view, but none left.");

      view_transforms_.push_back(
          std::make_shared<MergeTransform>(transform_view_index));
      current_size *= original_view_[++original_view_index];
    }
  }

 private:
  std::vector<std::shared_ptr<ViewTransform>> view_transforms_;
  std::vector<std::shared_ptr<BroadcastTransform>> broadcast_transforms_;
  std::vector<std::shared_ptr<TrivialReductionTransform>>
      trivial_reduction_transforms_;

  // If root domain isn't provided always assume size-1 dimensions are
  // compile-time dimensions. TODO: Remove runtime size-1 dimension support.
  // This should be cached higher in the stack.
  const bool root_domain_not_provided_ = true;

  const std::vector<IterDomain*> root_domain_;
  // Track if the root ID was transformed or kept as-is.
  std::vector<bool> root_is_transformed_;
  const std::vector<int64_t>& original_view_;
  const std::vector<int64_t>& new_view_;
};

//! Create new TensorDomain with a new root domain and modified rfactor domains
//! using the specified view transformations. Original domain should already be
//! without reduction axes.
TensorDomain* createViewDomain(
    TensorDomain* original_domain,
    const AnalyzeViewResult& view_analysis) {
  FUSER_PERF_SCOPE("createViewDomain");
  TORCH_INTERNAL_ASSERT(!view_analysis.transforms.empty());

  std::vector<IterDomain*> new_root_domain;
  auto orig_root_domain = original_domain->getMaybeRFactorDomain();

  // Apply trivial reductions.
  for (auto id_i : c10::irange(orig_root_domain.size())) {
    auto id = orig_root_domain[id_i];
    if (id->isReduction()) {
      continue;
    }
    if (std::find(
            view_analysis.trivial_reduction_axes.begin(),
            view_analysis.trivial_reduction_axes.end(),
            (int)id_i) != view_analysis.trivial_reduction_axes.end()) {
      continue;
    }

    new_root_domain.push_back(id->cloneWithoutRFactor());
  }

  std::vector<IterDomain*> new_rfactor_domain(
      new_root_domain.begin(), new_root_domain.end());

  // Apply rfactor transformations.
  for (auto& t : view_analysis.transforms) {
    t->createRfactorDomain(new_root_domain, new_rfactor_domain);
  }

  return IrBuilder::create<TensorDomain>(
      new_root_domain,
      new_rfactor_domain,
      new_rfactor_domain,
      std::vector<bool>(new_rfactor_domain.size(), true));
}
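
// E.g. (illustrative): for {2, 3, 4} -> {6, 4} with a single merge transform,
// the new root domain would be [I0{2}, I1{3}, I2{4}] (I0 and I1 re-flagged as
// rfactor inputs) and the rfactor domain would be [I0*I1{6}, I2{4}].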

} // namespace

std::pair<std::vector<int64_t>, std::vector<int64_t>> inferViewShapes(
    const std::vector<int64_t>& original_sizes,
    const std::vector<int64_t>& new_sizes) {
  bool valid_original_sizes = std::all_of(
      original_sizes.begin(), original_sizes.end(), [](int64_t dim) {
        return dim > 0;
      });
  TORCH_INTERNAL_ASSERT(valid_original_sizes);

  std::vector<int64_t> original_view(
      original_sizes.begin(), original_sizes.end());
  std::vector<int64_t> new_view(new_sizes.size());

  // TODO: refactor
  int64_t dynamic_index = -1;
  int64_t new_size_num_elements = 1;
  for (int64_t idx = 0; idx < new_sizes.size(); ++idx) {
    if (new_sizes[idx] == -1) {
      TORCH_INTERNAL_ASSERT(
          dynamic_index == -1, "Only one dimension can be inferred.");
      dynamic_index = idx;
    } else {
      TORCH_INTERNAL_ASSERT(new_sizes[idx] > 0);
      new_size_num_elements *= new_sizes[idx];
      new_view[idx] = new_sizes[idx];
    }
  }

  const int64_t kNumElements = std::accumulate(
      original_view.begin(), original_view.end(), 1, std::multiplies<>());
  if (dynamic_index != -1) {
    new_view[dynamic_index] = kNumElements / new_size_num_elements;
  }

  return {original_view, new_view};
}
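
// E.g. (illustrative): inferViewShapes({2, 3, 4}, {6, -1}) returns
// {{2, 3, 4}, {6, 4}}; the -1 entry is inferred as 24 / 6.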

//! Generates the transformations necessary to convert
//! from the original view into the new view.
AnalyzeViewResult analyzeView(
    const TensorView* original_view_tv,
    const std::vector<int64_t>& original_sizes,
    const std::vector<int64_t>& new_sizes) {
  FUSER_PERF_SCOPE("analyzeView");
  TORCH_INTERNAL_ASSERT(
      original_sizes.size() > 0,
      "Empty original size not supported for view operatioon.");

  TORCH_INTERNAL_ASSERT(
      TensorDomain::noReductions(original_view_tv->getMaybeRFactorDomain())
          .size() == original_sizes.size());

  // Fill the -1 dimension in new_sizes with the size inferred from all of the
  // other values.
  auto sizes = inferViewShapes(original_sizes, new_sizes);

  // Analyze the transformations required to go from original_sizes to
  // new_sizes.
  AnalyzeViewTransformation analyzer(
      sizes.first /* original_view */,
      sizes.second /* new_view */,
      TensorDomain::noReductions(original_view_tv->getMaybeRFactorDomain()));
  return analyzer.run();
}

AnalyzeViewConstraint analyzeViewConstraint(
    const std::vector<int64_t>& original_sizes,
    const std::vector<int64_t>& new_sizes) {
  FUSER_PERF_SCOPE("analyzeViewConstraint");
  auto sizes = inferViewShapes(original_sizes, new_sizes);
  AnalyzeViewTransformation analyzer(
      sizes.first /* original_view */, sizes.second /* new_view */);
  return analyzer.constraint();
}

//! Create new TensorDomain with a modified rfactor domain using the specified
//! view transformations
TensorDomain* transformView(
    TensorDomain* original_domain,
    const AnalyzeViewResult& view_analysis) {
  FUSER_PERF_SCOPE("transformView");
  return createViewDomain(original_domain, view_analysis);
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch