File: function_extraction.cpp

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (1187 lines) | stat: -rw-r--r-- 38,328 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/onnx/function_extraction.h>
#include <torch/csrc/jit/passes/onnx/naming.h>

namespace torch {
namespace jit {
namespace onnx {

namespace {

// Shorthand for a list of scope pointers.
using scope_list = std::vector<ScopePtr>;

// Annotated attributes retrieved from module by inspecting module annotations.
// These attributes are not used inside the subgraph of ONNX local function
// because they are not created by PyTorch JIT tracing, but they may be used by
// consumers to determine whether or not to replace the function with a
// particular fused kernel.
//
// scope_attr_map_ maps a module scope to a node whose attributes are copied
// onto the generated local function definition and call nodes
// (CreateFunctionDefNode / CreateFunctionNode). NOTE(review): the code that
// fills these tables is not visible here; the nodes are presumably created
// in scope_attr_graph_, which keeps them alive — confirm.
// Both objects are file-level mutable state shared by every FunctionExtractor
// in this translation unit.
static std::unordered_map<ScopePtr, Node*> scope_attr_map_;
static std::shared_ptr<Graph> scope_attr_graph_ = std::make_shared<Graph>();

// Compares attribute `attr` of nodes `a` and `b`. FunctionContext treats a
// false result as "the attribute value differs" and turns that attribute
// into a function attribute. Defined elsewhere in this file.
static bool HasSameAttribute(
    const Node* a,
    const Node* b,
    const c10::Symbol& attr);

// Extracts ONNX local functions from a graph. Nodes are partitioned by module
// scope (ScopeContext). Scope instances that can share one function are
// described by a FunctionContext and converted into a single
// onnx::LocalFunctionDef node plus one call node per instance
// (ConvertScopeToFunction).
struct FunctionExtractor {
 public:
  // `module_names` and `param_names` are copied into sets owned by the
  // extractor. `param_names` holds debug names of values that are ordered
  // after the regular inputs of each function (PopulateInputsOutputs).
  // NOTE(review): the use of module_names_ is not visible in this section
  // (presumably in run()) — confirm.
  FunctionExtractor(
      std::shared_ptr<Graph>& graph,
      const std::unordered_set<std::string>& module_names,
      const std::vector<std::string>& param_names)
      : graph_(graph),
        module_names_(module_names.begin(), module_names.end()),
        param_names_(param_names.begin(), param_names.end()) {}
  // Entry point. NOTE(review): definition not visible in this section.
  NodeAttrNameMap run();

 private:
  // Per-scope bookkeeping.
  struct ScopeContext {
    // Valid scopes recorded as direct children of this scope.
    std::unordered_set<ScopePtr> children_;
    ScopePtr scope_;
    // Nodes of this scope and of its descendant scopes, in recording order.
    node_list nlist_;
    // Filled by PopulateInputsOutputs: values consumed from outside nlist_
    // (parameters last) and values used outside nlist_.
    value_list inputs_;
    value_list outputs_;
    // Main-graph value -> corresponding value in the function body graph
    // (filled by ConstructFuncGraph).
    std::unordered_map<Value*, Value*> env_to_subgraph_;

    void PopulateInputsOutputs(
        const std::unordered_set<std::string>& param_names);
    // True if `other_ctx` may be exported as the same function: same module
    // class, input/output counts, and node kinds. (Name spelling sic.)
    bool IsIdenticalFuncion(const ScopeContext& other_ctx) const;
  };

  // Non-owning aliases; contexts are allocated with `new` when partitioning.
  // NOTE(review): where they are freed is not visible in this section.
  using ScopeCtxPtr = ScopeContext*;
  using scope_ctx_map = std::unordered_map<ScopePtr, ScopeCtxPtr>;

  // One local function: the reference scope, every scope instance sharing
  // the function, and the attributes whose values differ among instances.
  struct FunctionContext {
    FunctionContext(
        ScopePtr key,
        const scope_list& scopes,
        scope_ctx_map& scope_ctxs);
    void DebugPrint() const;
    // Register / look up the function-attribute name of `attr` on the body
    // graph node that corresponds to reference node `ref_n`.
    void SetAttrName(Node* ref_n, Symbol attr, const std::string& name);
    c10::optional<std::string> FindAttrName(Node* ref_n, Symbol attr);
    c10::optional<std::string> FindAttrName(Node* ref_const_n);

    // Reference scope whose nodes form the function body.
    ScopePtr scope_key_;
    // Contexts of all scope instances of this function, including the key.
    scope_ctx_map scope_ctxs_;
    // Reference node -> attribute -> nodes of other instances whose value
    // for that attribute differs from, or is missing relative to, the
    // reference node.
    std::unordered_map<
        Node*,
        std::unordered_map<Symbol, std::unordered_set<Node*>>>
        attribute_map_;

    // Body graph node -> attribute name -> function attribute name.
    // Passed later to serialization.
    NodeAttrNameMap node_attr_to_name_;
  };

  using FunctionCtxPtr = FunctionContext*;
  using func_ctx_map = std::unordered_map<ScopePtr, FunctionCtxPtr>;

  // Scope helpers; a scope is "valid" when it is neither root nor blank.
  static bool IsValidScope(ScopePtr s);
  static c10::optional<ScopePtr> InferScope(Node* n);
  static bool IsAncestor(ScopePtr parent, ScopePtr child);
  static c10::optional<ScopePtr> FindCommonAncestor(ScopePtr a, ScopePtr b);
  static c10::optional<ScopePtr> FindCommonAncestor(const scope_list& scopes);
  std::shared_ptr<Graph> ConstructFuncGraph(FunctionContext& ctx);

  void ConvertScopeToFunction(
      const ScopePtr& scope_key,
      const scope_list& scope_list,
      scope_ctx_map& scope_ctxs,
      const std::shared_ptr<Graph>& graph);

  static void HandleNoScopeNodes(scope_ctx_map&, node_list no_scope_nlist);
  std::tuple<scope_ctx_map, node_list> PartitionNodesByScope(Block* b);
  scope_ctx_map PartitionNodesByScope(const std::shared_ptr<Graph>& graph);
  static std::unordered_map<ScopePtr, scope_list> PartitionIdenticalScopes(
      scope_ctx_map& scope_ctxs);
  static scope_list SortScopesByMaxDepth(
      std::unordered_map<ScopePtr, scope_list>&);
  Node* CreateFunctionDefNode(
      FunctionContext& func_ctx,
      const std::shared_ptr<Graph>& graph,
      const std::string& domain_name,
      const std::string& func_name);
  Node* CreateFunctionNode(
      FunctionContext& func_ctx,
      ScopeContext& scope_ctx,
      const std::shared_ptr<Graph>& graph,
      const std::string& domain_name,
      const std::string& func_name);

  static void DebugPrintScopeContexts(const scope_ctx_map&);
  static void DebugPrintGraphWithFunction(const std::shared_ptr<Graph>& g);
  static void DebugPrintConstantDiff(const FunctionContext&);

  std::shared_ptr<Graph> graph_;
  std::unordered_set<std::string> module_names_;
  std::unordered_set<std::string> param_names_;
  // Track modules with same module name that are exported as different onnx
  // local functions.
  std::unordered_map<std::string, int> module_variant_count_;
  func_ctx_map func_ctxs_;
};

FunctionExtractor::FunctionContext::FunctionContext(
    ScopePtr key,
    const scope_list& scopes,
    scope_ctx_map& scope_ctxs)
    : scope_key_(std::move(key)) {
  // Builds the context of one local function whose body is taken from the
  // reference scope `scope_key_`.
  //
  // Every scope in `scopes` (each must be present in `scope_ctxs`) is
  // recorded in scope_ctxs_. Each non-reference scope's nodes are paired
  // position-by-position with the reference nodes; equal counts and equal
  // node kinds are asserted. An attribute of a reference node is recorded in
  // attribute_map_ when the paired node either lacks it or holds a different
  // value according to HasSameAttribute.
  //
  // NOTE(review): attributes that exist only on the paired node are not
  // recorded. For attributes missing on the paired node, CreateFunctionNode
  // later copies the value from that node, which presumably fails — confirm.
  GRAPH_UPDATE(
      "Process function context for scope ",
      scope_key_->name().toDisplayString());
  TORCH_INTERNAL_ASSERT(scopes.size() > 0);
  const auto& ref_ctx = scope_ctxs[scope_key_];
  // NOTE: Function scopes must have same number and order of nodes.
  GRAPH_DEBUG(
      "Initialized function context for scope ",
      scope_key_->name().toDisplayString());

  for (const auto& scope : scopes) {
    GRAPH_DEBUG(
        "Process function context for scope ", scope->name().toDisplayString());
    TORCH_INTERNAL_ASSERT(scope_ctxs.find(scope) != scope_ctxs.end());
    scope_ctxs_[scope] = scope_ctxs[scope];
    if (scope == scope_key_) {
      // The reference scope is the baseline; nothing to compare.
      continue;
    }

    const auto& ns_a = ref_ctx->nlist_;
    const auto& ns_b = scope_ctxs[scope]->nlist_;
    TORCH_INTERNAL_ASSERT(ns_a.size() == ns_b.size());

    GRAPH_DEBUG("Process nodes of scope ", scope->name().toDisplayString());
    for (const auto i : c10::irange(ns_a.size())) {
      TORCH_INTERNAL_ASSERT(ns_a[i]->kind() == ns_b[i]->kind());
      Node* ref_node = ns_a[i];
      Node* peer_node = ns_b[i];

      auto ref_attrs = ref_node->attributeNames();
      auto peer_attrs = peer_node->attributeNames();
      std::sort(ref_attrs.begin(), ref_attrs.end());
      std::sort(peer_attrs.begin(), peer_attrs.end());

      // Attributes the reference node has but the peer node lacks.
      std::vector<c10::Symbol> only_on_ref;
      std::set_difference(
          ref_attrs.begin(),
          ref_attrs.end(),
          peer_attrs.begin(),
          peer_attrs.end(),
          std::back_inserter(only_on_ref));
      // Attributes both nodes carry; their values still need comparing.
      std::vector<c10::Symbol> on_both;
      std::set_intersection(
          ref_attrs.begin(),
          ref_attrs.end(),
          peer_attrs.begin(),
          peer_attrs.end(),
          std::back_inserter(on_both));

      for (const auto& missing_attr : only_on_ref) {
        attribute_map_[ref_node][missing_attr].insert(peer_node);
      }
      for (const auto& shared_attr : on_both) {
        if (HasSameAttribute(ref_node, peer_node, shared_attr)) {
          continue;
        }
        attribute_map_[ref_node][shared_attr].insert(peer_node);
      }
    }
    GRAPH_DEBUG("Process scope complete. ", scope->name().toDisplayString());
  }

  GRAPH_DEBUG(
      "Process function context complete. ",
      scope_key_->name().toDisplayString());
  DebugPrint();
}

void FunctionExtractor::FunctionContext::DebugPrint() const {
  GRAPH_DEBUG("Scope name: ", scope_key_->name().toDisplayString());

  for (const auto& it : attribute_map_) {
    for (const auto& attr_it : it.second) {
      GRAPH_DEBUG(
          "Attribute value difference for attribute ",
          attr_it.first.toDisplayString());
      GRAPH_DEBUG(*it.first);
      for (auto n : attr_it.second) {
        GRAPH_DEBUG(*n);
      }
    }
  }
}

void FunctionExtractor::FunctionContext::SetAttrName(
    Node* ref_n,
    Symbol attr,
    const std::string& name) {
  // Registers `name` as the function-attribute name of attribute `attr` on
  // the function body node that corresponds to reference node `ref_n`.
  // The body node is found through `ref_n`'s first output in the reference
  // scope's env_to_subgraph_, so ConstructFuncGraph must have run first;
  // asserts that such a counterpart exists.
  auto v_it =
      scope_ctxs_[scope_key_]->env_to_subgraph_.find(ref_n->outputs().at(0));
  TORCH_INTERNAL_ASSERT(
      v_it != scope_ctxs_[scope_key_]->env_to_subgraph_.end());
  auto* n_in_def = v_it->second->node();
  node_attr_to_name_[n_in_def][attr.toUnqualString()] = name;
}

c10::optional<std::string> FunctionExtractor::FunctionContext::FindAttrName(
    Node* ref_n,
    Symbol attr) {
  // Returns the function-attribute name previously registered via
  // SetAttrName for attribute `attr` of reference node `ref_n`, or nullopt
  // if the node has no body-graph counterpart or no name was registered.
  auto& env = scope_ctxs_[scope_key_]->env_to_subgraph_;
  auto value_it = env.find(ref_n->outputs().at(0));
  if (value_it == env.end()) {
    return c10::nullopt;
  }

  auto node_it = node_attr_to_name_.find(value_it->second->node());
  if (node_it == node_attr_to_name_.end()) {
    return c10::nullopt;
  }

  const auto& registered = node_it->second;
  auto name_it = registered.find(attr.toUnqualString());
  if (name_it != registered.end()) {
    return name_it->second;
  }
  return c10::nullopt;
}

void FunctionExtractor::DebugPrintScopeContexts(
    const scope_ctx_map& scope_ctxs) {
  for (auto& it : scope_ctxs) {
    GRAPH_UPDATE(
        "Scope name: ",
        it.first->namesFromRoot(),
        " ",
        it.first->name().toDisplayString());
    GRAPH_UPDATE("Children scopes: ", [&]() {
      std::stringstream ss;
      for (const auto& child_scope : it.second->children_) {
        ss << child_scope->name().toDisplayString() << " ";
      }
      return ss.str();
    }());
    GRAPH_UPDATE("Node types: \n", [&]() {
      std::stringstream ss;
      for (auto n : it.second->nlist_) {
        ss << "  " << *n;
      }
      return ss.str();
    }());
    GRAPH_UPDATE("Node count: ", it.second->nlist_.size());
  }
}

void FunctionExtractor::DebugPrintGraphWithFunction(
    const std::shared_ptr<Graph>& g) {
  GRAPH_UPDATE("Local function definitions:");
  for (auto* n : g->nodes()) {
    if (n->kind() == Symbol::onnx("LocalFunctionDef")) {
      GRAPH_UPDATE(
          n->s(attr::name),
          " graph: ",
          n->g(Symbol::attr("graph"))->toString());
    }
  }
  GRAPH_UPDATE("Main graph: ", g->toString());
}

bool FunctionExtractor::IsValidScope(ScopePtr s) {
  // Root and blank scopes carry no module information, so they are unusable
  // for function extraction.
  const bool unusable = s->isRoot() || s->isBlank();
  return !unusable;
}

bool FunctionExtractor::IsAncestor(ScopePtr parent, ScopePtr child) {
  // Returns true if `parent` is a strict ancestor of `child`. Both scopes
  // must be valid, and an ancestor must be shallower than its descendant.
  if (!IsValidScope(parent) || !IsValidScope(child)) {
    return false;
  }
  if (parent->getDepth() >= child->getDepth()) {
    return false;
  }

  // Walk upward from the child; give up once the walk leaves valid scopes.
  ScopePtr cursor = child->parent();
  while (true) {
    if (cursor == parent) {
      return true;
    }
    if (!IsValidScope(cursor)) {
      return false;
    }
    cursor = cursor->parent();
  }
}

c10::optional<ScopePtr> FunctionExtractor::FindCommonAncestor(
    ScopePtr a,
    ScopePtr b) {
  // Returns the deepest valid scope that is `a`, `b`, or an ancestor of both.
  // Returns nullopt if either input is invalid or the two share no valid
  // ancestor.
  if (!IsValidScope(a) || !IsValidScope(b)) {
    return c10::nullopt;
  }

  // Lift whichever scope is deeper until both are at the same depth. The
  // number of steps is the absolute depth difference, so this works no
  // matter which argument is deeper. Both scopes are valid (non-root), so
  // the lift never reaches the root.
  const auto depth_a = static_cast<int64_t>(a->getDepth());
  const auto depth_b = static_cast<int64_t>(b->getDepth());
  ScopePtr& deeper = depth_a > depth_b ? a : b;
  auto steps = depth_a > depth_b ? depth_a - depth_b : depth_b - depth_a;
  for (; steps > 0; --steps) {
    deeper = deeper->parent();
  }

  // With equal depths, walk both upward in lockstep until they meet.
  while (IsValidScope(a) && IsValidScope(b)) {
    if (a == b) {
      return a;
    }
    a = a->parent();
    b = b->parent();
  }

  return c10::nullopt;
}

c10::optional<ScopePtr> FunctionExtractor::FindCommonAncestor(
    const scope_list& scopes) {
  // Folds the pairwise FindCommonAncestor over `scopes`, starting from the
  // first one. Returns nullopt for an empty list or as soon as any pair has
  // no common ancestor.
  if (scopes.empty()) {
    return c10::nullopt;
  }

  ScopePtr running = scopes.front();
  for (const auto& scope : scopes) {
    auto merged = FindCommonAncestor(running, scope);
    if (!merged.has_value()) {
      return c10::nullopt;
    }
    running = merged.value();
  }
  return running;
}

// Infers a scope for node `n` from the scopes of its neighbors. It is called
// for nodes that lack a valid scope. Returns nullopt when no valid scope can
// be determined.
//
// Side effect: users of `n`'s outputs that have no valid scope are inferred
// recursively, and their scope is set (setScope) when that inference
// succeeds. NOTE(review): failed inferences are not memoized, so the same
// downstream users may be re-inferred many times on long chains of
// scope-less nodes.
c10::optional<ScopePtr> FunctionExtractor::InferScope(Node* n) {
  // The scope of node n is assigned based on the following rules.
  // 1. If all uses of outputs of n belongs to the same scope,
  //    assign that scope, otherwise
  // 2. If all nodes of inputs of n belongs to the same scope,
  //    assign that scope, otherwise
  // 3. Find common ancestor of the scopes of uses of outputs of n,
  //    and the scopes of nodes of inputs of n.
  scope_list input_scopes;
  scope_list output_scopes;
  // Scopes of the nodes producing n's inputs (valid or not).
  for (auto input : n->inputs()) {
    input_scopes.emplace_back(input->node()->scope());
  }
  // Scopes of the nodes consuming n's outputs. A consumer without a valid
  // scope is inferred first and updated in place when inference succeeds.
  for (auto output : n->outputs()) {
    for (auto use : output->uses()) {
      if (!IsValidScope(use.user->scope())) {
        auto inferred_output_scope = InferScope(use.user);
        if (inferred_output_scope.has_value() &&
            IsValidScope(inferred_output_scope.value())) {
          use.user->setScope(inferred_output_scope.value());
        }
      }
      output_scopes.emplace_back(use.user->scope());
    }
  }
  // Rules 1 and 2 need at least one scope, and every scope must be valid
  // and identical to the first one.
  if (output_scopes.size() > 0 &&
      std::all_of(
          output_scopes.begin(),
          output_scopes.end(),
          [&output_scopes](ScopePtr scope) -> bool {
            return IsValidScope(scope) && scope == output_scopes.at(0);
          })) {
    return output_scopes.at(0);
  } else if (
      input_scopes.size() > 0 &&
      std::all_of(
          input_scopes.begin(),
          input_scopes.end(),
          [&input_scopes](ScopePtr scope) -> bool {
            return IsValidScope(scope) && scope == input_scopes.at(0);
          })) {
    return input_scopes.at(0);
  } else {
    // Rule 3: common ancestor (as computed by FindCommonAncestor) of all
    // valid producer and consumer scopes.
    scope_list scopes;
    std::copy_if(
        input_scopes.begin(),
        input_scopes.end(),
        std::back_inserter(scopes),
        IsValidScope);
    std::copy_if(
        output_scopes.begin(),
        output_scopes.end(),
        std::back_inserter(scopes),
        IsValidScope);
    if (scopes.size() > 0) {
      auto common_ancestor = FindCommonAncestor(scopes);
      if (common_ancestor.has_value() &&
          IsValidScope(common_ancestor.value())) {
        return common_ancestor.value();
      }
    }
  }

  return c10::nullopt;
}

std::shared_ptr<Graph> FunctionExtractor::ConstructFuncGraph(
    FunctionContext& func_ctx) {
  // Builds the body graph of a local function from the nodes of the
  // reference scope (func_ctx.scope_key_):
  //   - every value the scope consumes from outside becomes a graph input,
  //     ordered as computed by PopulateInputsOutputs;
  //   - every node of the scope is cloned, in order;
  //   - every value used outside the scope becomes a graph output.
  // The outer-value -> body-value mapping is left in the reference scope's
  // env_to_subgraph_, where SetAttrName / FindAttrName read it.
  ScopeContext& ref_ctx = *func_ctx.scope_ctxs_[func_ctx.scope_key_];
  auto& env = ref_ctx.env_to_subgraph_;

  auto body = std::make_shared<Graph>();
  GRAPH_DEBUG("Constructing graph for ", ref_ctx.scope_->namesFromRoot());

  // TODO: Update input names of function to match those in Module source code
  // signature.
  // This requires mapping between function node inputs and Module inputs.
  // Due to the lack of such mapping, currently debugName is used as input
  // names.
  ref_ctx.PopulateInputsOutputs(param_names_);
  for (Value* outer : ref_ctx.inputs_) {
    Value* inner = body->addInput()->copyMetadata(outer);
    env[outer] = inner;
    GRAPH_DEBUG(
        "Add input value ",
        inner->debugName(),
        " for outer scope value ",
        outer->debugName(),
        " from ",
        *outer->node());
  }

  // Operands of cloned nodes must already have a body-graph counterpart.
  const auto remap = [&env](Value* v) {
    TORCH_INTERNAL_ASSERT(env.find(v) != env.end());
    return env[v];
  };
  for (Node* n : ref_ctx.nlist_) {
    Node* clone = body->createClone(n, remap);
    const auto clone_outputs = clone->outputs();
    for (size_t idx = 0; idx < clone_outputs.size(); ++idx) {
      env[n->output(idx)] = clone_outputs[idx];
    }
    body->insertNode(clone);
  }

  // If values are used outside of this graph, set as graph output.
  for (Value* v : ref_ctx.outputs_) {
    TORCH_INTERNAL_ASSERT(env.find(v) != env.end());
    body->registerOutput(env[v]);
  }

  GRAPH_DEBUG(body->toString());
  return body;
}

// Creates the onnx::LocalFunctionDef node for `func_ctx` and prepends it to
// `graph`. The node carries:
//   - "graph": the function body built by ConstructFuncGraph;
//   - attr::name / "domain": `func_name` / `domain_name`;
//   - "attributes": the names of all function attributes. The inferred ones
//     (attributes whose values differ between scope instances) come first,
//     followed by the annotated ones found in scope_attr_map_.
// Each inferred name is also registered via FunctionContext::SetAttrName so
// that CreateFunctionNode can look it up. Returns the created node.
Node* FunctionExtractor::CreateFunctionDefNode(
    FunctionContext& func_ctx,
    const std::shared_ptr<Graph>& graph,
    const std::string& domain_name,
    const std::string& func_name) {
  const auto func_def_nk = Symbol::onnx("LocalFunctionDef");
  const auto func_g_attr = Symbol::attr("graph");
  const auto func_name_attr = attr::name;
  const auto func_domain_attr = Symbol::attr("domain");

  auto func_graph = ConstructFuncGraph(func_ctx);

  // create and insert local function definition node
  auto func_def_n = graph->create(func_def_nk, 0);
  func_def_n->g_(func_g_attr, func_graph);
  func_def_n->s_(func_name_attr, func_name);
  func_def_n->s_(func_domain_attr, domain_name);
  graph->prependNode(func_def_n);

  // set constants and attributes of different values as function attributes.
  std::unordered_map<std::string, int> base_attr_name_count;
  std::vector<std::string> final_attr_names;

  // Makes names unique within this function. The first use of a base name
  // is kept unchanged; later uses get ".1", ".2", ... appended.
  auto adjust_attr_name = [&](std::string attr_name) {
    if (base_attr_name_count.find(attr_name) != base_attr_name_count.end()) {
      attr_name =
          attr_name + "." + std::to_string(base_attr_name_count[attr_name]++);
    } else {
      base_attr_name_count[attr_name] = 1;
    }
    return attr_name;
  };

  // NOTE(review): attribute_map_ (and its inner maps) are unordered_maps,
  // so the order in which names and ".N" suffixes are assigned does not
  // follow graph order and may vary between runs.
  for (const auto& n_it : func_ctx.attribute_map_) {
    auto* n = n_it.first;
    for (const auto& attr_it : n_it.second) {
      const auto& attr = attr_it.first;
      // Add prefix "inferred::" to name of inferred attribute.
      // This is to differentiate from annotated attributes picked up
      // from python module annotation.
      auto attr_name = "inferred::" + std::string(n->kind().toUnqualString()) +
          '_' + attr.toUnqualString();
      auto final_attr_name = adjust_attr_name(attr_name);
      final_attr_names.emplace_back(final_attr_name);
      func_ctx.SetAttrName(n, attr, final_attr_name);
    }
  }

  // Set annotated attributes
  // The first scope instance (in scope_ctxs_ iteration order) that has an
  // entry in scope_attr_map_ defines the set of annotated attribute names.
  // Every later instance with an entry must not introduce a name outside
  // that set; otherwise TORCH_CHECK throws.
  // NOTE(review): the check is one-directional. A later instance that lacks
  // some of the names is not rejected here.
  std::unordered_set<Symbol> annotated_attr_names;
  bool first_iteration = true;
  for (const auto& it : func_ctx.scope_ctxs_) {
    auto scope = it.first;
    auto annotated_attr_node = scope_attr_map_.find(scope);
    if (annotated_attr_node != scope_attr_map_.end()) {
      auto names = annotated_attr_node->second->attributeNames();
      if (first_iteration) {
        std::copy(
            names.begin(),
            names.end(),
            std::inserter(annotated_attr_names, annotated_attr_names.end()));
        first_iteration = false;
      } else {
        auto unseen_attr_name = std::find_if(
            names.begin(),
            names.end(),
            [&annotated_attr_names](const Symbol& name) {
              return annotated_attr_names.find(name) ==
                  annotated_attr_names.end();
            });
        // The message (which dereferences unseen_attr_name) is only built
        // when the check fails, i.e. when the iterator is dereferenceable.
        TORCH_CHECK(
            unseen_attr_name == names.end(),
            "Found outstanding annotated attribute ",
            *unseen_attr_name,
            " from module ",
            scope->name(),
            ". Please ensure module instances of the same class have the same set of annotated attributes.");
      }
    }
  }
  // Annotated names follow the inferred ones (unordered_set order).
  for (auto attr_name : annotated_attr_names) {
    final_attr_names.emplace_back(attr_name.toUnqualString());
  }

  func_def_n->ss_(Symbol::attr("attributes"), final_attr_names);

  return func_def_n;
}

// Creates the call node `domain_name::func_name` for one scope instance
// (`scope_ctx`) of the function described by `func_ctx`, and inserts it into
// `graph` right after the instance's last node. The call node:
//   - takes scope_ctx.inputs_ (recomputed here) as inputs;
//   - has one output per scope_ctx.outputs_, with the same types, and all
//     uses of the original outputs are redirected to it;
//   - copies node metadata from the instance's last node;
//   - carries this instance's values of the inferred function attributes,
//     plus any annotated attributes from scope_attr_map_.
// The instance's original nodes are not removed here.
// Assumes scope_ctx.nlist_ is non-empty and that CreateFunctionDefNode has
// already registered the inferred attribute names (FindAttrName(...).value()).
Node* FunctionExtractor::CreateFunctionNode(
    FunctionContext& func_ctx,
    ScopeContext& scope_ctx,
    const std::shared_ptr<Graph>& graph,
    const std::string& domain_name,
    const std::string& func_name) {
  const auto& func_scope = func_ctx.scope_key_;
  GRAPH_DEBUG(
      "Create and insert local function for scope: ",
      func_scope->namesFromRoot());
  scope_ctx.PopulateInputsOutputs(param_names_);
  auto last_n = *scope_ctx.nlist_.rbegin();
  auto func_n = graph->create(
      Symbol::fromQualString(domain_name + "::" + func_name),
      scope_ctx.outputs_.size());
  func_n->copyMetadata(last_n);
  for (auto* v : scope_ctx.inputs_) {
    func_n->addInput(v);
  }
  for (const auto i : c10::irange(scope_ctx.outputs_.size())) {
    func_n->output(i)->setType(scope_ctx.outputs_[i]->type());
    scope_ctx.outputs_[i]->replaceAllUsesWith(func_n->output(i));
  }

  // set attributes of different values as function attributes.
  // copy_attr copies attribute `attr` of node `a` onto node `b` under the
  // name `new_name`. Only the kinds handled by COPY_ATTR are supported; any
  // other kind triggers an internal assertion.
  auto copy_attr =
      [](Node* a, Node* b, Symbol attr, const std::string& new_name) {
#define COPY_ATTR(kind)                                \
  case AttributeKind::kind: {                          \
    b->kind##_(Symbol::attr(new_name), a->kind(attr)); \
    break;                                             \
  }
        switch (a->kindOf(attr)) {
          COPY_ATTR(f)
          COPY_ATTR(fs)
          COPY_ATTR(i)
          COPY_ATTR(is)
          COPY_ATTR(s)
          COPY_ATTR(ss)
          COPY_ATTR(t)
          COPY_ATTR(ts)
#undef COPY_ATTR
          case AttributeKind::ival:
          case AttributeKind::g:
          case AttributeKind::gs:
          case AttributeKind::ty:
          case AttributeKind::tys:
          case AttributeKind::c:
          default:
            TORCH_INTERNAL_ASSERT(
                false,
                "Unexpected attribute type ",
                static_cast<int>(a->kindOf(attr)),
                " from node ",
                *a);
            break;
        }
      };

  // For each inferred function attribute, start from the reference node's
  // value. If one of this instance's nodes is in the differing set, its
  // value overrides the reference value (first match only).
  // NOTE(review): a node is also in the differing set when it lacks the
  // attribute entirely; copy_attr on such a node presumably fails — confirm.
  for (const auto& it : func_ctx.attribute_map_) {
    auto* ref_n = it.first;
    for (const auto& attr_it : it.second) {
      const auto& attr = attr_it.first;
      auto attr_name = func_ctx.FindAttrName(ref_n, attr).value();
      copy_attr(ref_n, func_n, attr, attr_name);
      for (auto* n : scope_ctx.nlist_) {
        if (attr_it.second.find(n) != attr_it.second.end()) {
          copy_attr(n, func_n, attr, attr_name);
          break;
        }
      }
    }
  }

  // annotated attributes
  auto scope = scope_ctx.scope_;
  auto annotated_attr_node = scope_attr_map_.find(scope);
  if (annotated_attr_node != scope_attr_map_.end()) {
    auto node = annotated_attr_node->second;
    for (auto attr : node->attributeNames()) {
      copy_attr(node, func_n, attr, attr.toUnqualString());
    }
  }

  func_n->insertAfter(last_n);
  return func_n;
}

// Converts the scope instances in `scope_list`, with `scope_key` as the
// reference scope, into calls of one ONNX local function, mutating `graph`
// and the contexts in `scope_ctxs` in place.
//
// The function name is the part of the reference scope's class name after
// its last '.', and the domain is the part before it. For a repeated name,
// ".N" is appended, starting at ".1" for the second occurrence.
void FunctionExtractor::ConvertScopeToFunction(
    const ScopePtr& scope_key,
    const scope_list& scope_list,
    scope_ctx_map& scope_ctxs,
    const std::shared_ptr<Graph>& graph) {
  // This function needs to be called always on inner most scopes.
  // 1. Generate function context, this identifies different constants and
  // attributes.
  // 2. Create function definition node, and insert to main graph.
  // 3. Create function node for each call, and replace subgraph nodes in parent
  // functions.

  // NOTE(review): if `scope_key` is already in func_ctxs_, insert() keeps
  // the existing entry, so the newly allocated FunctionContext leaks and
  // func_ctx refers to the old context. Where these raw pointers are freed
  // is not visible in this section.
  func_ctxs_.insert(std::make_pair(
      scope_key, new FunctionContext(scope_key, scope_list, scope_ctxs)));
  auto& func_ctx = *func_ctxs_[scope_key];

  const std::string module_class_name(
      ONNXScopeName::className(func_ctx.scope_key_));
  auto pos = module_class_name.rfind('.');
  TORCH_INTERNAL_ASSERT(pos != std::string::npos);

  // First occurrence of a name: count set to 0, name unchanged. Later
  // occurrences: count incremented, then ".<count>" appended.
  auto construct_unique_module_name = [&](std::string module_name) {
    auto module_name_variant = module_variant_count_.find(module_name);
    if (module_name_variant != module_variant_count_.end()) {
      module_variant_count_[module_name]++;
      module_name += ("." + std::to_string(module_name_variant->second));
    } else {
      module_variant_count_[module_name] = 0;
    }
    return module_name;
  };

  const auto domain_name = module_class_name.substr(0, pos);
  const auto func_name =
      construct_unique_module_name(module_class_name.substr(pos + 1));

  CreateFunctionDefNode(func_ctx, graph, domain_name, func_name);

  // create and insert local function node to graph.
  for (const auto& it : func_ctx.scope_ctxs_) {
    auto scope = it.first;
    auto& scope_ctx = *it.second;
    auto func_n =
        CreateFunctionNode(func_ctx, scope_ctx, graph, domain_name, func_name);

    std::unordered_set<Node*> old_nodes(
        scope_ctx.nlist_.begin(), scope_ctx.nlist_.end());

    auto last_n = *scope_ctx.nlist_.rbegin();
    // replace function body nodes in parent scopes with local function node.
    // (This inner `it` shadows the outer loop variable.)
    for (auto& it : scope_ctxs) {
      const auto& parent_scope = it.first;
      auto& parent_ctx = *it.second;

      if (!IsAncestor(parent_scope, scope)) {
        continue;
      }

      auto& ctx_nlist = parent_ctx.nlist_;
      GRAPH_DEBUG(
          "Replace local function node in parent scope: ",
          it.first->namesFromRoot(),
          " nodes to remove: ",
          old_nodes.size(),
          " parent total nodes: ",
          ctx_nlist.size());

      // insert local function node
      // Placed just before last_n (or at the end if last_n is not in the
      // list); last_n itself is removed by the erase below.
      auto last_n_it = std::find(ctx_nlist.begin(), ctx_nlist.end(), last_n);
      ctx_nlist.insert(last_n_it, func_n);

      // remove replaced nodes from list
      ctx_nlist.erase(
          std::remove_if(
              ctx_nlist.begin(),
              ctx_nlist.end(),
              [&old_nodes](Node* n) {
                return old_nodes.find(n) != old_nodes.end();
              }),
          ctx_nlist.end());

      GRAPH_DEBUG("Parent total nodes after remove: ", ctx_nlist.size());

      // refresh inputs/outputs.
      parent_ctx.PopulateInputsOutputs(param_names_);
    }
  }

  for (const auto& it : func_ctx.scope_ctxs_) {
    auto& scope_ctx = *it.second;
    // delete replaced nodes in graph.
    // Iterates in reverse list order, presumably so that each node is
    // destroyed after the in-scope nodes that use its outputs. Uses outside
    // the scope were already redirected to the call node.
    for (auto it = scope_ctx.nlist_.rbegin(); it != scope_ctx.nlist_.rend();) {
      auto* n = *it;
      it++;
      GRAPH_DEBUG("Destroying node ", *n);
      n->destroy();
    }
  }
}

bool FunctionExtractor::ScopeContext::IsIdenticalFuncion(
    const ScopeContext& other_ctx) const {
  // Differentiate same function under different inputs.
  // When constants are passed in place of inputs, it leads to different
  // input count and node count. Likewise, due to different uses, output
  // count can be different as well.
  // For now export them as different functions.
  // Covered by `test_local_function_overloads` in
  // `test/onnx/test_utility_funs.py`.
  if (&other_ctx == this) {
    return true;
  }
  if (ONNXScopeName::className(this->scope_) !=
      ONNXScopeName::className(other_ctx.scope_)) {
    return false;
  }
  if (this->inputs_.size() != other_ctx.inputs_.size() ||
      this->outputs_.size() != other_ctx.outputs_.size()) {
    return false;
  }
  const auto& ns_a = this->nlist_;
  const auto& ns_b = other_ctx.nlist_;
  if (ns_a.size() != ns_b.size()) {
    return false;
  }
  for (const auto i : c10::irange(ns_a.size())) {
    if (ns_a[i]->kind() != ns_b[i]->kind()) {
      return false;
    }
  }

  return true;
}

void FunctionExtractor::ScopeContext::PopulateInputsOutputs(
    const std::unordered_set<std::string>& param_names) {
  inputs_.clear();
  outputs_.clear();
  const auto& nlist = this->nlist_;
  std::unordered_set<Value*> v_set;
  std::unordered_set<Node*> n_set;

  value_list input_list;
  value_list initializer_list;

  // Add initializers after inputs.
  for (auto* n : nlist) {
    for (auto* v : n->inputs()) {
      if (v_set.find(v) == v_set.end()) {
        if (param_names.find(v->debugName()) != param_names.end()) {
          initializer_list.emplace_back(v);
        } else {
          input_list.emplace_back(v);
        }
        v_set.insert(v);
      }
    }
    for (auto* v : n->outputs()) {
      v_set.insert(v);
    }
    n_set.insert(n);
  }
  for (auto* v : input_list) {
    inputs_.emplace_back(v);
  }
  for (auto* v : initializer_list) {
    inputs_.emplace_back(v);
  }

  for (auto* n : nlist) {
    for (auto* v : n->outputs()) {
      bool used_outside = false;
      for (auto use : v->uses()) {
        used_outside |= (n_set.find(use.user) == n_set.end());
      }
      if (used_outside) {
        outputs_.emplace_back(v);
      }
    }
  }
}

// Reports every node whose scope could be neither read nor inferred, then
// fails with an internal assertion if there was at least one such node.
// `scope_ctxs` is currently unused by this implementation.
void FunctionExtractor::HandleNoScopeNodes(
    scope_ctx_map& scope_ctxs,
    node_list no_scope_nlist) {
  const auto unresolved_count = no_scope_nlist.size();
  GRAPH_UPDATE("No scope node count: ", unresolved_count);
  // Emit one warning per offending node so all of them are visible before the
  // assertion below aborts the pass.
  for (const auto* unresolved : no_scope_nlist) {
    TORCH_WARN(
        "ONNX function extraction cannot determine the scope for node: ",
        *unresolved);
  }
  TORCH_INTERNAL_ASSERT(
      no_scope_nlist.empty(),
      "ONNX function extraction cannot determine the scope for the above nodes.");
}

// Groups the nodes of block `b` (and, recursively, of its nested blocks) by
// scope.
//
// Every node is recorded in the context of its own scope and in the context of
// each valid ancestor scope. Each parent context also records the child scope
// in `children_`. Nodes without a valid scope get one from InferScope() when
// possible; otherwise they are returned in the second tuple element.
//
// Returns (scope -> heap-allocated ScopeContext*, nodes with no scope). The
// returned contexts are owned by the caller, which must `delete` them; run()
// does this.
std::tuple<FunctionExtractor::scope_ctx_map, node_list> FunctionExtractor::
    PartitionNodesByScope(Block* b) {
  scope_ctx_map scope_ctxs = {};
  node_list no_scope_nlist;

  // Returns the context for `scope`, allocating an empty one on first sight.
  // A single operator[] lookup value-initializes the slot to nullptr when the
  // key is new.
  auto find_or_create_scope_ctx = [](scope_ctx_map& scope_ctxs,
                                     const ScopePtr& scope) {
    auto& slot = scope_ctxs[scope];
    if (slot == nullptr) {
      slot = new ScopeContext();
    }
    return slot;
  };

  // Adds `n` to the context of its scope and of every valid ancestor scope,
  // wiring parent->child links along the way.
  auto record_node_scope = [&scope_ctxs, &find_or_create_scope_ctx](Node* n) {
    const auto& scope = n->scope();
    find_or_create_scope_ctx(scope_ctxs, scope)->scope_ = scope;
    auto tmp_scope = scope;
    while (IsValidScope(tmp_scope)) {
      find_or_create_scope_ctx(scope_ctxs, tmp_scope)->nlist_.emplace_back(n);
      if (IsValidScope(tmp_scope->parent())) {
        find_or_create_scope_ctx(scope_ctxs, tmp_scope->parent())
            ->children_.insert(tmp_scope);
      }
      tmp_scope = tmp_scope->parent();
    }
  };

  for (auto* n : b->nodes()) {
    auto scope = n->scope();
    if (scope && IsValidScope(scope)) {
      record_node_scope(n);
    } else {
      auto inferred_scope = InferScope(n);

      if (inferred_scope.has_value() && IsValidScope(inferred_scope.value())) {
        n->setScope(inferred_scope.value());
        record_node_scope(n);
      } else {
        GRAPH_UPDATE("Cannot infer proper scope for node: ", *n);
        no_scope_nlist.emplace_back(n);
      }
    }

    for (auto* sub_b : n->blocks()) {
      scope_ctx_map subblock_scope_ctxs;
      node_list subblock_no_scope_nlist;
      std::tie(subblock_scope_ctxs, subblock_no_scope_nlist) =
          PartitionNodesByScope(sub_b);

      for (auto& it : subblock_scope_ctxs) {
        auto existing = scope_ctxs.find(it.first);
        if (existing == scope_ctxs.end()) {
          // Ownership of the sub-block context moves to this level's map.
          scope_ctxs.insert(std::make_pair(it.first, it.second));
        } else {
          auto* target_ctx = existing->second;
          for (auto* s_n : it.second->nlist_) {
            target_ctx->nlist_.emplace_back(s_n);
          }
          for (const auto& s_child_scope : it.second->children_) {
            target_ctx->children_.insert(s_child_scope);
          }
          // The sub-block context has been folded into `target_ctx` and is
          // referenced by nothing else; free it instead of leaking it.
          delete it.second;
          it.second = nullptr;
        }
      }

      no_scope_nlist.insert(
          no_scope_nlist.end(),
          subblock_no_scope_nlist.begin(),
          subblock_no_scope_nlist.end());
    }
  }

  for (auto& it : scope_ctxs) {
    it.second->scope_ = it.first;
    it.second->PopulateInputsOutputs(param_names_);
  }

  // Move rather than copy (std::tie would copy both containers).
  return std::make_tuple(std::move(scope_ctxs), std::move(no_scope_nlist));
}

// Graph-level entry point: partitions the top-level block of `graph` by scope
// and asserts (via HandleNoScopeNodes) that no node was left without a scope.
// Returns the scope -> ScopeContext* map; the caller owns the contexts.
FunctionExtractor::scope_ctx_map FunctionExtractor::PartitionNodesByScope(
    const std::shared_ptr<Graph>& graph) {
  auto parts = PartitionNodesByScope(graph->block());
  auto& partitioned_ctxs = std::get<0>(parts);

  HandleNoScopeNodes(partitioned_ctxs, std::get<1>(parts));

  return std::move(partitioned_ctxs);
}

// Buckets scopes whose contexts are considered identical by
// ScopeContext::IsIdenticalFuncion.
//
// Each key of the returned map is the first scope (in `scope_ctxs` iteration
// order) of its group. The mapped list holds every scope in that group,
// including the key itself.
std::unordered_map<ScopePtr, scope_list> FunctionExtractor::
    PartitionIdenticalScopes(FunctionExtractor::scope_ctx_map& scope_ctxs) {
  std::unordered_map<ScopePtr, scope_list> groups;

  for (auto& entry : scope_ctxs) {
    const auto candidate = entry.first;
    auto* const candidate_ctx = entry.second;

    // Look for an existing group whose representative matches this scope.
    auto matching_group =
        std::find_if(groups.begin(), groups.end(), [&](const auto& group) {
          return scope_ctxs[group.first]->IsIdenticalFuncion(*candidate_ctx);
        });

    if (matching_group != groups.end()) {
      matching_group->second.emplace_back(candidate);
    } else {
      // No match: this scope starts (and represents) a new group.
      groups[candidate].emplace_back(candidate);
    }
  }

  return groups;
}

static bool HasSameAttribute(
    const Node* a,
    const Node* b,
    const c10::Symbol& attr) {
  if (!a->hasAttribute(attr) && !b->hasAttribute(attr)) {
    return true;
  }
  if (!a->hasAttribute(attr) || !b->hasAttribute(attr)) {
    return false;
  }
  auto a_kind = a->kindOf(attr);
  auto b_kind = b->kindOf(attr);
  if (a_kind != b_kind) {
    return false;
  }

#define COMP_ATTR(kind)              \
  case AttributeKind::kind: {        \
    const auto& a_v = a->kind(attr); \
    const auto& b_v = b->kind(attr); \
    return a_v == b_v;               \
  }

  switch (a_kind) {
    COMP_ATTR(f)
    COMP_ATTR(fs)
    COMP_ATTR(i)
    COMP_ATTR(is)
    COMP_ATTR(s)
    COMP_ATTR(ss)
#undef COMP_ATTR
    case AttributeKind::t: {
      const auto& a_v = a->t(attr);
      const auto& b_v = b->t(attr);
      return a_v.equal(b_v);
    }
    case AttributeKind::ts: {
      const auto& a_v = a->ts(attr);
      const auto& b_v = b->ts(attr);
      return std::equal(
          a_v.begin(),
          a_v.end(),
          b_v.begin(),
          b_v.end(),
          [](const at::Tensor& a_t, const at::Tensor& b_t) {
            return a_t.equal(b_t);
          });
    }
    case AttributeKind::ival:
    case AttributeKind::g:
    case AttributeKind::gs:
    case AttributeKind::ty:
    case AttributeKind::tys:
    case AttributeKind::c:
    default:
      TORCH_INTERNAL_ASSERT(
          false,
          "Unexpected attribute type ",
          static_cast<int>(a_kind),
          " from node ",
          *a);
      break;
  }

  return true;
}

// Orders the keys of `identical_scope_map` so that groups containing deeper
// scopes come first.
//
// A group's depth is the maximum getDepth() over all scopes in the group. The
// relative order of groups with equal depth is unspecified.
//
// The previous comparator used `>=`, which is not a strict weak ordering and
// makes std::sort's behavior undefined when two groups share a depth. Depths
// are now precomputed into pairs and compared with `>`, which also avoids hash
// lookups inside the comparator.
scope_list FunctionExtractor::SortScopesByMaxDepth(
    std::unordered_map<ScopePtr, scope_list>& identical_scope_map) {
  std::vector<std::pair<ScopePtr, size_t>> keyed_depths;
  keyed_depths.reserve(identical_scope_map.size());
  for (const auto& it : identical_scope_map) {
    size_t max_depth = 0;
    for (const auto& scope : it.second) {
      const size_t depth = scope->getDepth();
      if (depth > max_depth) {
        max_depth = depth;
      }
    }
    keyed_depths.emplace_back(it.first, max_depth);
  }

  std::sort(
      keyed_depths.begin(),
      keyed_depths.end(),
      [](const std::pair<ScopePtr, size_t>& a,
         const std::pair<ScopePtr, size_t>& b) -> bool {
        return a.second > b.second;
      });

  scope_list sorted_scopes;
  sorted_scopes.reserve(keyed_depths.size());
  for (const auto& entry : keyed_depths) {
    sorted_scopes.emplace_back(entry.first);
  }
  return sorted_scopes;
}

// Runs the extraction pipeline on `graph_` (see ONNXFunctionExtraction for the
// step-by-step description).
//
// Scopes whose class name is in `module_names_` are converted to local
// functions, deepest first. Returns the merged node-attribute -> function
// attribute name mapping collected from all created function contexts. All
// ScopeContext and function-context allocations are freed before returning.
//
// NOTE(review): the raw deletes at the end are skipped if an earlier step
// throws, leaking the contexts on that path.
NodeAttrNameMap FunctionExtractor::run() {
  auto scope_ctxs = PartitionNodesByScope(graph_);
  DebugPrintScopeContexts(scope_ctxs);
  auto identical_scope_map = PartitionIdenticalScopes(scope_ctxs);
  // Deepest scope comes first, guaranteeing no other scope can be its child.
  auto sorted_scope_keys = SortScopesByMaxDepth(identical_scope_map);
  for (const auto& scope_key : sorted_scope_keys) {
    if (module_names_.find(ONNXScopeName::className(scope_key)) !=
        module_names_.end()) {
      ConvertScopeToFunction(
          scope_key, identical_scope_map[scope_key], scope_ctxs, graph_);
    }
    GRAPH_DEBUG("Main graph afterwards: ", graph_->toString());
  }
  DebugPrintGraphWithFunction(graph_);

  // Construct return mappings.
  NodeAttrNameMap node_attr_to_name;

  for (const auto& it : func_ctxs_) {
    // Bind by reference: the previous `auto` copied every function's map.
    const auto& func_ref_map = it.second->node_attr_to_name_;
    node_attr_to_name.insert(func_ref_map.begin(), func_ref_map.end());
  }

  // Release all contexts owned by this extractor.
  for (auto& it : scope_ctxs) {
    delete it.second;
  }
  scope_ctxs.clear();
  for (auto& it : func_ctxs_) {
    delete it.second;
  }
  func_ctxs_.clear();

  return node_attr_to_name;
}

// Retrieves the node representing the most recent ScopePtr.
//
// This function should only be invoked from a module forward hook. At that
// point the module forward call has completed and the most recent ScopePtr has
// been popped from TracingState.
//
// `forward_node` must be a prim::TracedModuleForward. Its first sub-block is
// scanned from last node to first. The search descends (recursively) into the
// first child TracedModuleForward whose scope is not yet recorded in
// `scope_attr_map_`. If every child is already recorded, `forward_node` itself
// is returned.
Node* NodeOfMostRecentScope(Node* forward_node) {
  TORCH_INTERNAL_ASSERT(
      forward_node->kind() == prim::TracedModuleForward,
      "forward_node got kind: ",
      forward_node->kind().toDisplayString());
  auto* block = forward_node->blocks()[0];
  for (auto* node : block->nodes().reverse()) {
    // Check the recorded-scope condition before recursing. The recursive
    // result was only ever used for unrecorded scopes, so descending into
    // recorded subtrees first was wasted work that grew with the size of the
    // already-traced module tree on every hook invocation.
    if (node->kind() == prim::TracedModuleForward &&
        scope_attr_map_.find(node->scope()) == scope_attr_map_.end()) {
      return NodeOfMostRecentScope(node);
    }
  }
  return forward_node;
}

} // namespace

// Extracts ONNX local functions from `graph`, modifying the graph in place.
//
// FunctionExtractor performs the following steps:
//    1. Partition nodes into groups by their scope information. Each scope
//    represents an individual nn.Module call, and a ScopeContext object is
//    created for each group.
//    2. Compare the groups from step 1 and find those with the same subgraph
//    pattern.
//    3. Scopes are nested. Starting from the deepest scope, extract the
//    subgraph pattern and define it as a local function node. Replace the
//    subgraph pattern with a single node of the new local function node type.
//    A FunctionContext object is created for each function.
//    4. Construct a NodeAttrNameMap that maps each attribute name of an IR
//    node inside a function subgraph to the corresponding function attribute
//    name.
//
// `module_names` selects which module classes become functions;
// `param_names` is forwarded to FunctionExtractor.
NodeAttrNameMap ONNXFunctionExtraction(
    std::shared_ptr<Graph>& graph,
    const std::unordered_set<std::string>& module_names,
    const std::vector<std::string>& param_names) {
  GRAPH_UPDATE(
      "Export these module forward calls as functions: ",
      std::vector<std::string>(module_names.cbegin(), module_names.cend()));
  return FunctionExtractor(graph, module_names, param_names).run();
}

// Locates the node of the most recently completed module scope and registers
// a fresh prim::TracedModuleForward node in `scope_attr_graph_` for that scope.
//
// The search starts at the node preceding `graph->nodes().back()`. The new
// node is stored in `scope_attr_map_` under that scope and returned.
// Internally asserts that the scope had not been registered before.
Node* ONNXGetPreviousScope(std::shared_ptr<Graph>& graph) {
  Node* const search_root = graph->nodes().back()->prev();
  Node* const scope_node = NodeOfMostRecentScope(search_root);
  const auto recent_scope = scope_node->scope();

  Node* const attr_node =
      scope_attr_graph_->create(prim::TracedModuleForward);
  attr_node->setScope(recent_scope);

  const bool already_recorded =
      scope_attr_map_.find(recent_scope) != scope_attr_map_.end();
  TORCH_INTERNAL_ASSERT(
      !already_recorded,
      "Found duplicated scope. Scope ",
      recent_scope->namesFromRoot(),
      " already processed.");

  scope_attr_map_[recent_scope] = attr_node;
  return attr_node;
}

// Resets the scope-tracking state: forgets every recorded scope -> attribute
// node mapping and replaces `scope_attr_graph_` with a new, empty Graph.
void ONNXClearScopeRecords() {
  scope_attr_map_.clear();
  auto fresh_attr_graph = std::make_shared<Graph>();
  scope_attr_graph_ = std::move(fresh_attr_graph);
}

void ONNXTrackScopeAttributes(
    std::shared_ptr<Graph>& graph,
    std::map<std::string, IValue>& attributes) {
  // Skip the "real" last node which is `return_node`.
  auto* last_node = graph->nodes().back()->prev();
  auto* scope_node = NodeOfMostRecentScope(last_node);
  auto* attr_node = scope_attr_graph_->create(prim::TracedModuleForward);
  attr_node->setScope(scope_node->scope());
  TORCH_INTERNAL_ASSERT(
      scope_attr_map_.find(scope_node->scope()) == scope_attr_map_.end());
  scope_attr_map_[scope_node->scope()] = attr_node;

  for (const auto& it : attributes) {
    auto k = Symbol::attr(it.first);
    auto v = it.second;
    if (v.isTensor()) {
      attr_node->t_(k, v.toTensor());
    } else if (v.isInt()) {
      attr_node->i_(k, v.toInt());
    } else if (v.isDouble()) {
      attr_node->f_(k, v.toDouble());
    } else if (v.isBool()) {
      attr_node->i_(k, v.toBool());
    } else if (v.isString()) {
      attr_node->s_(k, v.toStringRef());
    } else if (v.isIntList()) {
      attr_node->is_(k, v.toIntList().vec());
    } else if (v.isBoolList()) {
      auto bool_list = v.toBoolList();
      attr_node->is_(
          k, std::vector<int64_t>(bool_list.begin(), bool_list.end()));
    } else if (v.isDoubleList()) {
      attr_node->fs_(k, v.toDoubleList().vec());
    }
  }
}

} // namespace onnx
} // namespace jit
} // namespace torch