File: graph_passes.py (torch/ao/ns/fx, pytorch 1.13.1+dfsg-4)

import torch
from torch.fx import GraphModule, map_arg
from torch.fx.graph import Graph, Node
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix

from .utils import (
    get_node_first_input_and_output_type,
    getattr_from_fqn,
    NodeInputOrOutputType,
    return_first_non_observer_node,
    get_number_of_non_param_args,
    get_target_type_str,
    get_arg_indices_of_inputs_to_log,
    get_node_input_qparams,
    op_type_supports_shadowing,
    get_normalized_nth_input,
)

from .ns_types import (
    NSSingleResultValuesType,
    NSSubgraph,
    NSNodeTargetType,
)
from torch.ao.ns.fx.mappings import (
    get_node_type_to_io_type_map,
)
from torch.ao.quantization.quantize import is_activation_post_process

from typing import Dict, Tuple, Callable, List, Any, Union, Optional, Set

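# Usage sketch (hedged; model_a / model_b below are placeholders): these
# passes are the internal building blocks of FX Numeric Suite and are
# normally driven through the public entry points in
# torch.ao.ns._numeric_suite_fx, roughly:
#
#   import torch.ao.ns._numeric_suite_fx as ns
#   from torch.ao.ns._numeric_suite_fx import OutputLogger
#
#   # unshadowed comparison (uses add_loggers_to_model below)
#   m_a, m_b = ns.add_loggers('fp32', model_a, 'int8', model_b, OutputLogger)
#   # shadowed comparison (uses create_a_shadows_b below)
#   m_ab = ns.add_shadow_loggers('fp32', model_a, 'int8', model_b, OutputLogger)
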
def _maybe_get_fqn(node: Node, gm: GraphModule) -> Optional[str]:
    fqn = None
    if hasattr(gm, '_node_name_to_scope'):
        # fqn on observers is not present, because they do not
        # exist when the fqns are created during tracing. If this is
        # an observer, get the fqn of the node being observed.
        node_to_use_for_fqn = node
        if node.op == 'call_module':
            assert isinstance(node.target, str)
            module = getattr_from_fqn(gm, node.target)
            if is_activation_post_process(module):
                node_to_use_for_fqn = get_normalized_nth_input(node, gm, 0)
        fqn = gm._node_name_to_scope[node_to_use_for_fqn.name][0]  # type: ignore[index]
    return fqn  # type: ignore[return-value]

def _insert_logger_after_node(
    node: Node,
    gm: GraphModule,
    logger_cls: Callable,
    logger_node_name_suffix: str,
    ref_node_name: str,
    model_name: str,
    ref_name: str,
    ref_node_target_type: str,
    results_type: str,
    index_within_arg: int,
    index_of_arg: int,
    fqn: Optional[str],
) -> Node:
    """
    Given a starting graph of

    prev_node -> node -> next_node

    This function creates a new logger_cls obj and adds it
    after node, resulting in

    prev_node -> node -> logger_obj -> next_node
    """
    # create new name
    logger_node_name = \
        get_new_attr_name_with_prefix(node.name + logger_node_name_suffix)(gm)
    target_type = get_target_type_str(node, gm)
    # create the logger object
    logger_obj = logger_cls(
        ref_node_name, node.name, model_name, ref_name, target_type,
        ref_node_target_type,
        results_type, index_within_arg, index_of_arg, fqn)
    # attach the logger object to the parent module
    setattr(gm, logger_node_name, logger_obj)
    logger_node = node.graph.create_node(
        'call_module', logger_node_name, (node,), {})
    return logger_node

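# Naming note (hedged): for a node named `conv1` with suffix '_ns_logger_',
# get_new_attr_name_with_prefix yields an unused attribute name on gm such
# as 'conv1_ns_logger_0', incrementing the trailing index on collision.
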
def add_loggers_to_model(
    gm: GraphModule,
    node_to_instrument_inputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
    node_to_instrument_outputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Takes the graph of gm, adds loggers to the output
    of each node in nodes_to_instrument. Returns a GraphModule with the new
    graph.
    """

    new_graph = Graph()
    env: Dict[str, Any] = {}
    modules = dict(gm.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in gm.graph.nodes:
        if node.op == 'output':
            new_graph.output(map_arg(get_normalized_nth_input(node, gm, 0), load_arg))
            continue

        if (
            (node in node_to_instrument_inputs_to_ref_node_name) or
            (node in node_to_instrument_outputs_to_ref_node_name)
        ):
            fqn = _maybe_get_fqn(node, gm)

            if node in node_to_instrument_inputs_to_ref_node_name:
                ref_name, ref_node_type = node_to_instrument_inputs_to_ref_node_name[node]
                # Ops such as add and mul are special because either
                # one or two of the first two arguments can be tensors,
                # and if one argument is a tensor it can be first or
                # second (x + 1 versus 1 + x).
                arg_indices_to_log = get_arg_indices_of_inputs_to_log(node)
                for node_arg_idx in arg_indices_to_log:
                    node_arg = get_normalized_nth_input(node, gm, node_arg_idx)
                    if type(node_arg) == Node:
                        # create a single input logger
                        prev_node = env[node_arg.name]
                        env[node_arg.name] = _insert_logger_after_node(
                            prev_node, gm, logger_cls, '_ns_logger_', node.name,
                            model_name, ref_name, ref_node_type,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0, index_of_arg=node_arg_idx,
                            fqn=fqn)
                    elif type(node_arg) == torch.fx.immutable_collections.immutable_list:
                        # create N input loggers, one for each node
                        for arg_idx, arg in enumerate(node_arg):  # type: ignore[var-annotated, arg-type]
                            prev_node = env[arg.name]
                            env[prev_node.name] = _insert_logger_after_node(
                                prev_node, gm, logger_cls, '_ns_logger_', node.name,
                                model_name, ref_name, ref_node_type,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=arg_idx, index_of_arg=node_arg_idx,
                                fqn=fqn)
                    else:
                        pass

            # ensure env is populated with base node
            # Note: runs for both inputs and outputs
            env[node.name] = new_graph.node_copy(node, load_arg)

            if node in node_to_instrument_outputs_to_ref_node_name:
                ref_name, ref_node_type = node_to_instrument_outputs_to_ref_node_name[node]
                # add the logger after the base node
                env[node.name] = _insert_logger_after_node(
                    env[node.name], gm, logger_cls, '_ns_logger_', node.name,
                    model_name, ref_name, ref_node_type,
                    NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0, index_of_arg=0, fqn=fqn)

        else:
            env[node.name] = new_graph.node_copy(node, load_arg)

    new_gm = GraphModule(gm, new_graph)
    return new_gm

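# Note on the pattern above: rebuilding the graph through the `env` dict
# (original node name -> new node) is the standard FX transformation idiom;
# pointing env[name] at a logger node makes all downstream users consume the
# logger instead (the NS loggers return their input unchanged, so numerics
# are unaffected).
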
def _insert_quantize_per_tensor_node(
    prev_node_c: Node,
    node_a: Node,
    gm_b: GraphModule,
    graph_c: Graph,
    scale: Union[torch.Tensor, float],
    zero_point: Union[torch.Tensor, int],
    dtype_cast_name: str,
) -> Node:
    # copy scale
    scale_node_name = \
        get_new_attr_name_with_prefix(
            node_a.name + '_input_scale_')(gm_b)
    setattr(gm_b, scale_node_name, scale)
    scale_node = graph_c.create_node(
        'get_attr', scale_node_name, (), {}, scale_node_name)
    # copy zero_point
    zero_point_node_name = \
        get_new_attr_name_with_prefix(
            node_a.name + '_input_zero_point_')(gm_b)
    setattr(gm_b, zero_point_node_name, zero_point)
    zero_point_node = graph_c.create_node(
        'get_attr', zero_point_node_name, (), {}, zero_point_node_name)
    # create the quantize_per_tensor call
    return graph_c.create_node(
        'call_function', torch.quantize_per_tensor,
        (prev_node_c, scale_node, zero_point_node, torch.quint8), {},
        dtype_cast_name)

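# For reference, the node created above is the graph equivalent of the eager
# call torch.quantize_per_tensor(x, scale, zero_point, torch.quint8), with
# scale and zero_point materialized as get_attr nodes on gm_b.
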
def _insert_dtype_cast_after_node(
    node_a: Node,
    node_c: Node,
    prev_node_c: Union[Node, List[Node]],
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
    node_name_prefix: str,
    logger_cls: Callable,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Union[Node, List[Node]]:
    """
    Given a starting graph C (derived from graph B) of

    ... -> prev_node_c -> node_c -> ...

    And a corresponding related node_a, inserts the correct dtype
    cast node after prev_node_c to cast into the dtype expected
    by node_a, resulting in:

                          dtype_cast
                        /
    ... -> prev_node_c -> node_c -> ...

    For example, if node_c is an int8 op and node_a is an fp32 op, this function
    will insert a dequant.
    """
    dtype_cast_op = None
    dtype_cast_mod_cls = None
    dtype_cast_method = None
    dtype_cast_method_dtype = None
    dtype_cast_scale = None
    dtype_cast_zero_point = None
    node_input_type_a, _node_output_type_a = \
        get_node_first_input_and_output_type(
            node_a, gm_a, logger_cls, node_type_to_io_type_map)
    node_input_type_c, _node_output_type_c = \
        get_node_first_input_and_output_type(
            node_c, gm_b, logger_cls, node_type_to_io_type_map)

    if (
        (node_input_type_a == NodeInputOrOutputType.FP32 and
         node_input_type_c == NodeInputOrOutputType.INT8) or
        (node_input_type_a == NodeInputOrOutputType.FP32 and
         node_input_type_c == NodeInputOrOutputType.FP16) or
        # TODO(future PR): determine the actual dtype of node_c,
        # the current code only works because dequantize works with
        # multiple input dtypes.
        (node_input_type_a == NodeInputOrOutputType.FP32 and
         node_input_type_c == NodeInputOrOutputType.FP32_OR_INT8)
    ):
        dtype_cast_op = torch.dequantize
    elif (
        node_input_type_a == node_input_type_c and
        node_input_type_a != NodeInputOrOutputType.UNKNOWN
    ):
        dtype_cast_mod_cls = torch.nn.Identity
    elif (
        node_input_type_a == NodeInputOrOutputType.INT8 and
        node_input_type_c == NodeInputOrOutputType.FP32
    ):
        # int8 shadows fp32, the dtype cast needs to quantize to int8
        # with the right qparams.
        node_a_input_qparams = get_node_input_qparams(
            node_a, gm_a, node_type_to_io_type_map)
        if node_a_input_qparams is not None:
            dtype_cast_op = torch.quantize_per_tensor  # type: ignore[assignment]
            dtype_cast_scale, dtype_cast_zero_point = node_a_input_qparams
    elif (
        node_input_type_a == NodeInputOrOutputType.FP16 and
        node_input_type_c == NodeInputOrOutputType.FP32
    ):
        dtype_cast_method = 'to'
        dtype_cast_method_dtype = torch.float16
    else:
        raise AssertionError(
            f"dtype cast from {node_input_type_c} {node_c.format_node()} to " +
            f"{node_input_type_a} {node_a.format_node()} needs to be implemented")

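    # Summary of the decision chain above (cast from C's input dtype to A's
    # input dtype):
    #   int8 / fp16 / fp32_or_int8 -> fp32 : torch.dequantize
    #   same known dtype                   : torch.nn.Identity (pass-through)
    #   fp32 -> int8 (qparams known)       : torch.quantize_per_tensor
    #   fp32 -> fp16                       : Tensor.to(torch.float16)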
    if isinstance(prev_node_c, Node):
        new_dtype_cast_name = \
            get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        if dtype_cast_op:
            if dtype_cast_scale is not None and dtype_cast_zero_point is not None:
                return _insert_quantize_per_tensor_node(
                    prev_node_c, node_a, gm_b, graph_c, dtype_cast_scale,
                    dtype_cast_zero_point, new_dtype_cast_name)
            else:
                return graph_c.create_node(
                    'call_function', dtype_cast_op, (prev_node_c,), {},
                    new_dtype_cast_name)
        elif dtype_cast_method:
            return graph_c.create_node(
                'call_method', dtype_cast_method,
                (prev_node_c, dtype_cast_method_dtype), {}, new_dtype_cast_name)
        else:
            assert dtype_cast_mod_cls
            dtype_cast_mod = dtype_cast_mod_cls()
            setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
            return graph_c.create_node(
                'call_module', new_dtype_cast_name, (prev_node_c,), {},
                new_dtype_cast_name)
    elif isinstance(prev_node_c, list):
        results = []
        for prev_node_c_inner in prev_node_c:
            new_dtype_cast_name = \
                get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
            if dtype_cast_op:
                # TODO(future PR): add handling for quantize_per_tensor
                new_dtype_cast_node = graph_c.create_node(
                    'call_function', dtype_cast_op, (prev_node_c_inner,), {},
                    new_dtype_cast_name)
                results.append(new_dtype_cast_node)
            else:
                assert dtype_cast_mod_cls
                dtype_cast_mod = dtype_cast_mod_cls()
                setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
                new_dtype_cast_node = graph_c.create_node(
                    'call_module', new_dtype_cast_name, (prev_node_c_inner,), {},
                    new_dtype_cast_name)
                results.append(new_dtype_cast_node)
        return results
    else:
        raise AssertionError(f"type f{type(prev_node_c)} is not handled")

# TODO(future PR): look into using copy_node API instead
def _copy_node_from_a_to_c(
    node_a: Node,
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
) -> Node:
    """
    Simple copy of node_a to graph_c.
    """
    if node_a.op == 'get_attr':
        node_a_copy_name = \
            get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)
        node_a_obj = getattr_from_fqn(gm_a, node_a.target)  # type: ignore[arg-type]
        if torch.is_tensor(node_a_obj):
            node_a_obj = node_a_obj.detach()
        setattr(gm_b, node_a_copy_name, node_a_obj)
        node_a_copy = graph_c.create_node(
            node_a.op, node_a_copy_name, (), {}, node_a_copy_name)
        return node_a_copy
    elif node_a.op == 'call_method':
        assert node_a.target in ('dequantize', 'to'), \
            f"target {node_a.target} is not implemented"
        if node_a.target == 'dequantize':
            arg_copy = _copy_node_from_a_to_c(
                get_normalized_nth_input(node_a, gm_a, 0),
                gm_a, gm_b, graph_c)  # type: ignore[arg-type]
            node_a_copy_name = \
                get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)
            node_a_copy = graph_c.create_node(
                node_a.op, node_a.target, (arg_copy,), {}, node_a_copy_name)
            return node_a_copy
        else:  # to
            arg_copy = _copy_node_from_a_to_c(
                get_normalized_nth_input(node_a, gm_a, 0), gm_a, gm_b, graph_c)  # type: ignore[arg-type]
            node_a_copy_name = \
                get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)
            node_a_copy = graph_c.create_node(
                node_a.op, node_a.target,
                (arg_copy, get_normalized_nth_input(node_a, gm_a, 1)),
                {}, node_a_copy_name)
            return node_a_copy

    else:
        raise AssertionError(
            f"handling of node {node_a.format_node()} with op {node_a.op} is not implemented")

def _can_insert_copy_of_subgraph_a(
    subgraph_a: NSSubgraph,
    gm_a: GraphModule,
    num_non_param_args_node_a: int,
) -> bool:
    """
    This function returns `False` if the input subgraph cannot be copied by
    `_insert_copy_of_subgraph_a_after_input_node_c`. This usually means
    that the subgraph contains corner-case logic for which copying is not
    yet implemented.
    """
    # populate the list of nodes we need to check
    nodes = []
    cur_node = subgraph_a.end_node
    while cur_node != subgraph_a.start_node:
        nodes.append(cur_node)
        cur_node = get_normalized_nth_input(cur_node, gm_a, 0)  # type: ignore[assignment]
    nodes.append(cur_node)
    nodes.reverse()

    def _can_insert(node_a_arg, gm_a):
        if isinstance(node_a_arg, Node):
            arg_a = return_first_non_observer_node(node_a_arg, gm_a)
            if arg_a.op == 'call_method':
                return arg_a.target in ('dequantize', 'to')
            elif arg_a.op == 'get_attr':
                return True
            else:
                return False
        elif isinstance(node_a_arg, (list, tuple)):
            for el in node_a_arg:
                if not isinstance(el, Node):
                    return False
        return True

    # For each node, check if we handle the copy behavior. This follows the
    # logic in `_insert_copy_of_subgraph_a_after_input_node_c`.
    for node_a in nodes:

        local_num_non_param_args_node_a = num_non_param_args_node_a \
            if node_a is nodes[0] else 1

        norm_args_kwargs = node_a.normalized_arguments(
            gm_a, normalize_to_only_use_kwargs=True)
        if norm_args_kwargs is not None:
            norm_args, norm_kwargs = norm_args_kwargs
        else:
            norm_args, norm_kwargs = node_a.args, node_a.kwargs

        cur_idx = 0

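        # the arg at index 0 (and at index 1 for the subgraph's first node
        # when it has two non-param inputs, e.g. LSTM) is stitched to a graph
        # C input later, so it does not need to be copyable from graph A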
        while cur_idx < len(norm_args):
            if cur_idx == 0:
                pass
            elif cur_idx == 1 and local_num_non_param_args_node_a == 2:
                pass
            else:
                if not _can_insert(norm_args[cur_idx], gm_a):
                    return False
            cur_idx += 1

        for kwarg_name, kwarg_val in norm_kwargs.items():
            # stitch the inputs from base graph
            if cur_idx == 0:
                pass
            elif cur_idx == 1 and local_num_non_param_args_node_a == 2:
                pass
            else:
                if not _can_insert(kwarg_val, gm_a):
                    return False
            cur_idx += 1

    return True

def _insert_copy_of_subgraph_a_after_input_node_c(
    input_node_c: Union[Node, List[Node]],
    input_node_c_2: Optional[Union[Node, List[Node]]],
    subgraph_a: NSSubgraph,
    gm_a: GraphModule,
    gm_b: GraphModule,
    node_name_prefix: str,
) -> Node:
    """
    TODO(before land): real docblock
    """
    if isinstance(input_node_c, Node):
        graph_c = input_node_c.graph
    else:
        assert isinstance(input_node_c, list)
        graph_c = input_node_c[0].graph

    # create a sequential list of the subgraph's nodes from start to end,
    # because we need to add the nodes to graph C in non-reverse order
    nodes_of_a = [subgraph_a.end_node]
    cur_node = subgraph_a.end_node
    while cur_node != subgraph_a.start_node:
        cur_node = get_normalized_nth_input(cur_node, gm_a, 0)  # type: ignore[assignment]
        nodes_of_a.insert(0, cur_node)

    # go through nodes of a in order, and insert them into the graph of c
    # sequentially
    cur_node_a = nodes_of_a[0]
    cur_node_c = _insert_copy_of_node_a_after_input_node_c(
        input_node_c,
        input_node_c_2,
        cur_node_a,
        gm_a,
        gm_b,
        node_name_prefix)
    for cur_idx_a in range(1, len(nodes_of_a)):
        cur_node_a = nodes_of_a[cur_idx_a]
        prev_node_c = cur_node_c  # previous added node is the input to next node
        cur_node_c = _insert_copy_of_node_a_after_input_node_c(
            prev_node_c,
            # TODO(future PR): enable multiple inputs for nodes which are not at start of subgraph
            None,
            cur_node_a,
            gm_a,
            gm_b,
            node_name_prefix)
    # return the last inserted node
    return cur_node_c


def _insert_copy_of_node_a_after_input_node_c(
    input_node_c: Union[Node, List[Node]],
    input_node_c_2: Optional[Union[Node, List[Node]]],
    node_a: Node,
    gm_a: GraphModule,
    gm_b: GraphModule,
    node_name_prefix: str,
) -> Node:
    """
    Assume that node_a from graph_a has
      args (input, (input2)?, arg1, ...), and
      kwargs {kw0: kwarg0, ...}

    Note: input2 is optional. If it is None, we assume that the op
    has a single non-param input. If it is specified, we assume that the op
    has two non-param inputs.

    Copies the underlying values of arg1..argn and kwarg0..kwargn into gm_b,
    and creates the corresponding nodes in graph_c. Note: observers are ignored,
    so if an arg is an observer we navigate up until we find a non-observer parent.

    If node_a is a call_module, attaches the module referenced by node_a to gm_b.

    Creates the copy of node_a in graph_c, with input as the first arg,
    and all other args and kwargs pointing to the copies of the objects
    in gm_b created above.

    An example in pictures:

    graph A:
    ========

    input -------------> node_a
                         / / /
    (input_2)?----------/ / /
                         / /
    weight -> weight_obs  /
                         /
    bias ----------------

    graph C (derived from B):
    =========================

    input_node_c --> node_a_copy
                     / / /
    (input_node_c_2)? / /
                     / /
    weight_copy ----/ /
                     /
    bias_copy ------/
    """
    if isinstance(input_node_c, Node):
        graph_c = input_node_c.graph
    else:
        assert isinstance(input_node_c, list)
        graph_c = input_node_c[0].graph

    norm_args_kwargs = node_a.normalized_arguments(
        gm_a, normalize_to_only_use_kwargs=True)
    if norm_args_kwargs is not None:
        norm_args, norm_kwargs = norm_args_kwargs
    else:
        norm_args, norm_kwargs = node_a.args, node_a.kwargs

    new_args = []
    new_kwargs = {}

    def _copy_arg(arg):
        # copy the other inputs from the other graph
        if isinstance(arg, Node):
            arg = return_first_non_observer_node(arg, gm_a)
            arg = _copy_node_from_a_to_c(arg, gm_a, gm_b, graph_c)
            return arg
        elif isinstance(arg, (int, float, torch.dtype)):
            return arg
        elif isinstance(arg, (list, tuple)):
            for el in arg:
                assert not isinstance(el, Node), \
                    "handling of Node inside list is not implemented"
            return arg
        else:
            raise AssertionError(
                f"handling for arg of type {type(arg)} is not implemented")

    cur_idx = 0

    while cur_idx < len(norm_args):
        if cur_idx == 0:
            new_arg = input_node_c
        elif cur_idx == 1 and input_node_c_2 is not None:
            new_arg = input_node_c_2
        else:
            new_arg = _copy_arg(norm_args[cur_idx])
        new_args.append(new_arg)
        cur_idx += 1

    for kwarg_name, kwarg_val in norm_kwargs.items():
        # stitch the inputs from base graph
        if cur_idx == 0:
            new_kwargs[kwarg_name] = input_node_c
        elif cur_idx == 1 and input_node_c_2 is not None:
            new_kwargs[kwarg_name] = input_node_c_2
        else:
            new_kwargs[kwarg_name] = _copy_arg(kwarg_val)
        cur_idx += 1

    new_args = tuple(new_args)  # type: ignore[assignment]

    node_a_shadows_c_name = \
        get_new_attr_name_with_prefix(node_name_prefix)(gm_b)

    if node_a.op == 'call_module':
        # if target is a module, we point to the module from gm_b
        new_mod_copy_name = \
            get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        # fetch the corresponding module from gm_a
        assert isinstance(node_a.target, str)
        mod_a = getattr_from_fqn(gm_a, node_a.target)
        setattr(gm_b, new_mod_copy_name, mod_a)
        node_a_shadows_c = graph_c.create_node(
            node_a.op, new_mod_copy_name, new_args,
            new_kwargs, node_a_shadows_c_name)
        return node_a_shadows_c
    else:
        assert node_a.op in ('call_function', 'call_method')
        node_a_shadows_c = graph_c.create_node(
            node_a.op, node_a.target, new_args,
            new_kwargs, node_a_shadows_c_name)
        return node_a_shadows_c

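# Note: in the call_module branch above, the module instance fetched from
# gm_a is attached to gm_b directly (shared, not deep-copied), so the shadow
# copy reuses A's parameters and buffers.
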
def create_a_shadows_b(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    matched_subgraph_pairs: Dict[str, Tuple[NSSubgraph, NSSubgraph]],
    logger_cls: Callable,
    should_log_inputs: bool,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> GraphModule:
    """
    Creates a new GraphModule consisting of the graph of C, with the meaningful
    nodes of A shadowing the corresponding nodes of B.  For example,

    Graph A:
    a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2

    Graph B:
    b0 -> op0_int8 -> b1 -> op1_int8 -> b2

    matched_subgraph_pairs: {'op0': (op0_fp32, op0_int8), 'op1': (op1_fp32, op1_int8)}

    Graph C (A shadows B):

        / dequant0 -> op0_fp32 -> logger_a_0  / dequant_1 -> op1_fp32 -> logger_a_1
       /                                     /
    b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1

    In a nutshell, this function does the following for each node pair:
    * copies the necessary attributes and modules from gm_a to gm_b,
      keeping names unique
    * adds a dtype cast op (dequant, quant, etc)
    * adds a copy of node_a in gm_b's graph
    * adds loggers to the outputs of node_a and node_b
    """

    if node_type_to_io_type_map is None:
        node_type_to_io_type_map = get_node_type_to_io_type_map()

    # graph_c is the graph created from copying the nodes of graph_b and inserting
    # the shadows with the nodes copied from graph_a
    graph_c = Graph()
    env_c: Dict[str, Any] = {}
    modules = dict(gm_b.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env_c[node.name])

    start_node_b_to_matched_subgraph_a_and_name = {}
    end_node_b_to_matched_subgraph_a_and_name = {}
    for match_name, match in matched_subgraph_pairs.items():
        subgraph_a, subgraph_b = match
        ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
        ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)
        start_node_b_to_matched_subgraph_a_and_name[subgraph_b.start_node] = \
            (subgraph_a, match_name, ref_node_type_a, ref_node_type_b)
        end_node_b_to_matched_subgraph_a_and_name[subgraph_b.end_node] = \
            (subgraph_a, match_name, ref_node_type_a, ref_node_type_b)

    for node_b in gm_b.graph.nodes:
        if node_b.op == 'output':
            graph_c.output(map_arg(node_b.args[0], load_arg))
            continue

        # calculate the flags to determine what to do with this node
        node_b_is_start_node = node_b in start_node_b_to_matched_subgraph_a_and_name
        node_b_is_end_node = node_b in end_node_b_to_matched_subgraph_a_and_name

        if (node_b_is_start_node or node_b_is_end_node):

            if node_b_is_start_node:
                subgraph_a, ref_name, ref_node_type_a, ref_node_type_b = \
                    start_node_b_to_matched_subgraph_a_and_name[node_b]
            else:
                assert node_b_is_end_node
                subgraph_a, ref_name, ref_node_type_a, ref_node_type_b = \
                    end_node_b_to_matched_subgraph_a_and_name[node_b]

            all_op_types_support_shadowing = (
                op_type_supports_shadowing(subgraph_a.start_node) and
                op_type_supports_shadowing(node_b)
            )
            if not all_op_types_support_shadowing:
                print(
                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
                    ', unsupported')
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                continue

            # For both start_node and end_node verify that we know how to do
            # the dtype cast. If we do not, skip.
            node_input_type_a, node_output_type_a = \
                get_node_first_input_and_output_type(
                    subgraph_a.start_node, gm_a, logger_cls,
                    node_type_to_io_type_map)
            node_input_type_b, node_output_type_b = \
                get_node_first_input_and_output_type(
                    node_b, gm_b, logger_cls,
                    node_type_to_io_type_map)
            node_io_types_known_a_and_b = (
                node_input_type_a != NodeInputOrOutputType.UNKNOWN and
                node_output_type_a != NodeInputOrOutputType.UNKNOWN and
                node_input_type_b != NodeInputOrOutputType.UNKNOWN and
                node_output_type_b != NodeInputOrOutputType.UNKNOWN
            )
            if not node_io_types_known_a_and_b:
                print(
                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
                    ', unknown dtype cast')
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                continue

            # If we are shadowing from fp32 to int8, we need to insert
            # quantize_per_tensor call with qparams from the previous node.
            # Only do this if we are able to infer these qparams from the graph.
            if (
                node_input_type_a == NodeInputOrOutputType.INT8 and
                node_input_type_b == NodeInputOrOutputType.FP32
            ):
                node_a_input_qparams = get_node_input_qparams(
                    subgraph_a.start_node, gm_a, node_type_to_io_type_map)
                if not node_a_input_qparams:
                    print(
                        f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
                        f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
                        ', unknown input qparams')
                    env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                    continue

            num_non_param_args_node_a = \
                get_number_of_non_param_args(subgraph_a.start_node, gm_a)
            if not _can_insert_copy_of_subgraph_a(subgraph_a, gm_a, num_non_param_args_node_a):
                print(
                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
                    ', unhandled logic in subgraph copy')
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                continue

            fqn_base_a = _maybe_get_fqn(subgraph_a.base_op_node, gm_a)
            fqn_base_b = _maybe_get_fqn(subgraph_b.base_op_node, gm_b)

            if node_b_is_start_node:

                # if necessary, log the input of node_c
                if should_log_inputs:
                    prev_node_b = get_normalized_nth_input(node_b, gm_b, 0)
                    if isinstance(prev_node_b, Node):
                        prev_node_c = env_c[prev_node_b.name]
                        env_c[prev_node_c.name] = _insert_logger_after_node(
                            prev_node_c, gm_b, logger_cls, '_ns_logger_b_inp_',
                            node_b.name, name_b, ref_name, ref_node_type_b,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0, index_of_arg=0,
                            fqn=fqn_base_b)
                    elif isinstance(prev_node_b, list):
                        # first, save the prev_node instances, because they
                        # will be overwritten in the env after the first logger
                        # is added
                        prev_node_c_list = [env_c[arg.name] for arg in prev_node_b]

                        for arg_idx, arg in enumerate(prev_node_b):
                            prev_node_c = prev_node_c_list[arg_idx]
                            env_c[prev_node_c.name] = _insert_logger_after_node(
                                prev_node_c, gm_b, logger_cls, '_ns_logger_b_inp_',
                                node_b.name, name_b, ref_name, ref_node_type_b,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=arg_idx, index_of_arg=0,
                                fqn=fqn_base_b)
                    else:
                        # logging of inputs which are neither Nodes nor lists is not supported yet
                        raise AssertionError(f"type {type(prev_node_b)} is not handled yet")
                # subgraph so far:
                #
                # (prev_node_c)+ -> (logger_c_input)?

            # Note: this if statement is always True, spelling it out to clarify code
            # intent.
            if node_b_is_start_node or node_b_is_end_node:
                # ensure env_c is populated with base node
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                node_c = env_c[node_b.name]

                # after this point,
                #
                # node_a is the original node from graph_a, with parent module gm_a
                # node_b is the original node from graph_b, with parent module gm_b
                # node_c is the copy of node_b in graph_c
                #
                # subgraph so far:
                #
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

            if node_b_is_start_node:

                # cast dtype from the dtype of node_c's input to the dtype of
                # node_a's input (dequant, etc)
                # prev_node_c = node_c.args[0]
                prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)
                if should_log_inputs:
                    # skip the input logger when inserting a dtype cast
                    if isinstance(prev_node_c, Node):
                        prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)
                    elif isinstance(prev_node_c, list):
                        prev_node_c = [get_normalized_nth_input(arg, gm_b, 0) for arg in prev_node_c]
                dtype_cast_node = _insert_dtype_cast_after_node(
                    subgraph_a.start_node, node_c, prev_node_c, gm_a, gm_b, graph_c,
                    node_b.name + '_dtype_cast_', logger_cls,
                    node_type_to_io_type_map)
                # note: not inserting to env_c because all nodes which use the dtype
                #   casts are copied from graph_a
                #
                # subgraph so far:
                #
                #           (dtype_cast_node)+
                #                  /
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                # if input logging is enabled, log the input to the subgraph
                if should_log_inputs:
                    # the ref_node_name is left empty here and filled in below,
                    # once the subgraph copy exists
                    ref_node_name = ''
                    if isinstance(dtype_cast_node, Node):
                        dtype_cast_node = _insert_logger_after_node(
                            dtype_cast_node, gm_b, logger_cls, '_ns_logger_a_inp_',
                            ref_node_name, name_a, ref_name, ref_node_type_a,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0, index_of_arg=0,
                            fqn=fqn_base_a)
                        input_logger: Union[Node, List[Node]] = dtype_cast_node
                    else:
                        assert isinstance(dtype_cast_node, list)
                        new_loggers = []
                        for dtype_cast_idx, dtype_cast_node_inner in enumerate(dtype_cast_node):
                            dtype_cast_logger = _insert_logger_after_node(
                                dtype_cast_node_inner, gm_b, logger_cls, '_ns_logger_a_inp_',
                                ref_node_name, name_a, ref_name, ref_node_type_a,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=dtype_cast_idx,
                                index_of_arg=0,
                                fqn=fqn_base_a)
                            new_loggers.append(dtype_cast_logger)
                        dtype_cast_node = new_loggers
                        input_logger = dtype_cast_node
                    # subgraph so far:
                    #
                    #       (dtype_cast_node)+ -> (logger_a_input)?
                    #                  /
                    # prev_node_c -> (logger_c_input)? -> node_start_c

                # hook up the new mod_a copy to be in the graph, receiving the
                # same inputs as mod_b does, with dtype cast to match a
                # Some ops, such as LSTMs, have two non-param inputs. If we have
                # such an op, pass the second input as well. Note: dtype casting
                # for the second input is not implemented yet; it can be added
                # later if there is a use case.
                node_c_second_non_param_arg = None
                num_non_param_args_node_a = get_number_of_non_param_args(subgraph_a.start_node, gm_a)
                if num_non_param_args_node_a == 2:
                    # node_c_second_non_param_arg = node_c.args[1]
                    node_c_second_non_param_arg = get_normalized_nth_input(node_c, gm_b, 1)
                node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
                    dtype_cast_node, node_c_second_non_param_arg,
                    subgraph_a, gm_a, gm_b, node_c.name + '_shadow_copy_')
                env_c[node_a_shadows_c.name] = node_a_shadows_c
                # subgraph so far:
                #
                #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy(args/kwargs not shown)
                #                  /
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                if should_log_inputs:
                    # When we created the input logger, we left the ref_node_name
                    # as an empty string, because the subgraph copy did not exist
                    # yet. Now that the subgraph copy exists, we modify this name
                    # to its true value.
                    # Note: the alternative to this is to create the input logger
                    # after creating the subgraph, which is slightly more
                    # complicated. This is the lesser of two evils.
                    # input_logger = env_c[dtype_cast_node.name]
                    # Find the first node in the subgraph
                    cur_node = node_a_shadows_c
                    while get_normalized_nth_input(cur_node, gm_b, 0) != input_logger:
                        cur_node = get_normalized_nth_input(cur_node, gm_b, 0)  # type: ignore[assignment]
                    if isinstance(input_logger, Node):
                        input_logger_mod = getattr(gm_b, input_logger.name)
                        input_logger_mod.ref_node_name = cur_node.name
                    else:
                        assert isinstance(input_logger, list)
                        for input_logger_inner in input_logger:
                            input_logger_mod = getattr(gm_b, input_logger_inner.name)
                            input_logger_mod.ref_node_name = cur_node.name

                # hook up a logger to the mod_a copy
                env_c[node_a_shadows_c.name] = _insert_logger_after_node(
                    env_c[node_a_shadows_c.name], gm_b, logger_cls, '_ns_logger_a_',
                    node_a_shadows_c.name, name_a, ref_name, ref_node_type_a,
                    NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0, index_of_arg=0,
                    fqn=fqn_base_a)
                # subgraph so far:
                #
                #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
                #                  /
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

            if node_b_is_end_node:

                # hook up a logger to the mod_b copy
                env_c[node_b.name] = _insert_logger_after_node(
                    env_c[node_b.name], gm_b, logger_cls, '_ns_logger_b_',
                    node_b.name, name_b, ref_name, ref_node_type_b,
                    NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0, index_of_arg=0,
                    fqn=fqn_base_b)
                # subgraph so far:
                #
                #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
                #                  /
                # (prev_node_c+) -> (logger_c_input)? -> node_start_c -> ... -> node_end_c -> logger_c
                #
                # Note: node_start_c may be the same node as node_end_c, or they
                # may have nodes in between.

        else:
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)

    gm_c = GraphModule(gm_b, graph_c)
    return gm_c
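
# End-to-end sketch (hedged; m_fp32 / m_int8 / example_input are
# placeholders): create_a_shadows_b is normally reached via
# torch.ao.ns._numeric_suite_fx.add_shadow_loggers, after which the shadow
# model is run on calibration data and logger results are extracted:
#
#   import torch.ao.ns._numeric_suite_fx as ns
#   from torch.ao.ns._numeric_suite_fx import OutputLogger
#   m_shadow = ns.add_shadow_loggers('fp32', m_fp32, 'int8', m_int8, OutputLogger)
#   m_shadow(example_input)  # run once to populate the loggers
#   results = ns.extract_shadow_logger_info(m_shadow, OutputLogger, 'int8')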