File: onednn.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (652 lines) | stat: -rw-r--r-- 19,252 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
# mypy: allow-untyped-defs
import itertools
import operator

import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.quantized.reference as nnqr
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization.fuser_method_mappings import _sequential_wrapper2
from torch.ao.quantization.utils import MatchAllNode

from ._common_operator_config_utils import (
    _get_binary_op_configs,
    _get_bn_configs,
    _get_cat_config,
    _get_conv_configs,
    _get_default_op_configs,
    _get_embedding_op_configs,
    _get_fixed_qparams_op_configs,
    _get_linear_configs,
    _get_ln_configs,
    _get_rnn_op_configs,
    _get_share_qparams_op_configs,
)
from .backend_config import (
    BackendConfig,
    BackendPatternConfig,
    DTypeConfig,
    ObservationType,
)


# ===================
# |  DTYPE CONFIGS  |
# ===================

# Static int8 for weighted ops: quint8 activations in and out, qint8 weights,
# float bias.
onednn_weighted_op_int8_dtype_config = DTypeConfig(
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
)

# Ops without weights: quint8 activations in and out.
onednn_op_quint8_dtype_config = DTypeConfig(
    output_dtype=torch.quint8,
    input_dtype=torch.quint8,
)

# Dynamic quantization: quint8 input, float output, qint8 weights, float bias.
onednn_dynamic_int8_dtype_config = DTypeConfig(
    is_dynamic=True,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
    input_dtype=torch.quint8,
    output_dtype=torch.float,
)

# Weight-only: activations stay float, only the weight is qint8.
onednn_weight_only_qint8_dtype_config = DTypeConfig(
    weight_dtype=torch.qint8,
    input_dtype=torch.float,
    output_dtype=torch.float,
)

# Input/output-only: quint8 activations, weight and bias kept in float.
onednn_input_output_only_quint8_dtype_config = DTypeConfig(
    weight_dtype=torch.float,
    bias_dtype=torch.float,
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
)

# ===================
# |  FUSER METHODS  |
# ===================


def _fuse_linear_bn_leaky_relu(is_qat, linear, bn, leaky_relu):
    r"""Given the linear, bn and leaky_relu modules, fuses them and returns the fused module
    Args:
        is_qat: a flag for whether we are using quantization aware training fusion
                or post training quantization fusion
        linear: Module instance of type Linear
        bn: BatchNorm1d instance that needs to be fused with the linear layer
        leaky_relu: LeakyReLU instance that needs to be fused with the linear layer
    Examples::
        >>> # xdoctest: +SKIP(failing)
        >>> m1 = nn.Linear(20, 10)
        >>> b1 = nn.BatchNorm1d(10)
        >>> lr = nn.LeakyReLU(0.01)
        >>> m2 = _fuse_linear_bn_leaky_relu(m1, b1, lr)
    """
    assert (
        linear.training == bn.training and bn.training == leaky_relu.training
    ), "Linear, BN and LeakyReLU all must be in the same mode (train or eval)."

    if is_qat:
        raise NotImplementedError(
            f"Cannot fuse train modules: {(linear, bn, leaky_relu)}"
        )
    else:
        map_to_fused_module_eval = {
            nn.Linear: nni.LinearLeakyReLU,
        }
        fused_module = map_to_fused_module_eval.get(type(linear), None)
        if fused_module is not None:
            fused_linear = nn.utils.fusion.fuse_linear_bn_eval(linear, bn)
            fm = fused_module(fused_linear, leaky_relu)
            return fm
        else:
            raise NotImplementedError(
                f"Cannot fuse eval modules: {(linear, bn, leaky_relu)}"
            )


# ======================
# |  CONFIGS FOR CONV  |
# ======================
conv_dtype_configs = [onednn_weighted_op_int8_dtype_config]
# Observation type shared by the conv / fused-conv pattern configs (and reused
# by the linear eltwise fusions).
observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
# Start from the common conv configs; the ConvAdd / ConvAddReLU patterns are
# appended to this list afterwards.
conv_configs = _get_conv_configs(conv_dtype_configs)

# (1) Conv2d + Add

# conv2d   Y
#   \   /
#    add

# include:
# conv2d conv2d
#   \   /
#    add


def _fuse_conv_add_left(is_qat, add, conv, _):
    return nni.ConvAdd2d(conv, add)


def _conv_add_root_node_getter_left(pattern):
    _, conv, _ = pattern
    return conv


def _conv_add_extra_inputs_getter_left(pattern):
    """get inputs pattern for extra inputs, inputs for root node
    are assumed to be copied over from root node to the fused node
    """
    _, conv, extra_input = pattern
    return [extra_input]


# conv2d
#  \
#  bn   Y
#   \   /
#    add


def _fuse_conv_bn_add_left(is_qat, add, bn_conv, _):
    bn, conv = bn_conv
    if is_qat:
        raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add)}")
    else:
        fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
        return nni.ConvAdd2d(fused_conv, add)


def _conv_bn_add_root_node_getter_left(add_pattern):
    _, bn_conv, _ = add_pattern
    bn, conv = bn_conv
    return conv


def _conv_bn_add_extra_inputs_getter_left(add_pattern):
    """get inputs pattern for extra inputs, inputs for root node
    are assumed to be copied over from root node to the fused node
    """
    _, bn_conv, extra_input = add_pattern
    bn, conv = bn_conv
    return [extra_input]


conv_add_left_optioins = itertools.product(
    [True, False],  # with_bn
    [torch.add, operator.add],  # add_op
)

# Register ``(add, conv[-> bn], Y)`` patterns: conv, optionally followed by
# BatchNorm2d, feeds the first operand of the add. Both variants share the
# same config shape; only the conv sub-pattern and the helper callables differ.
for with_bn, add_op in conv_add_left_optioins:
    conv_configs.append(
        BackendPatternConfig()
        ._set_pattern_complex_format(
            (
                add_op,
                (nn.BatchNorm2d, nn.Conv2d) if with_bn else nn.Conv2d,
                MatchAllNode,
            )
        )
        .set_observation_type(observation_type)
        .set_dtype_configs(conv_dtype_configs)
        .set_fuser_method(
            _fuse_conv_bn_add_left if with_bn else _fuse_conv_add_left
        )
        ._set_root_node_getter(
            _conv_bn_add_root_node_getter_left
            if with_bn
            else _conv_add_root_node_getter_left
        )
        ._set_extra_inputs_getter(
            _conv_bn_add_extra_inputs_getter_left
            if with_bn
            else _conv_add_extra_inputs_getter_left
        )
        .set_fused_module(nni.ConvAdd2d)
    )

#  Y   conv2d
#   \   /
#    add


def _fuse_conv_add_right(is_qat, add, _, conv):
    return nni.ConvAdd2d(conv, add)


def _conv_add_root_node_getter_right(pattern):
    add, _, conv = pattern
    return conv


def _conv_add_extra_inputs_getter_right(pattern):
    """get inputs pattern for extra inputs, inputs for root node
    are assumed to be copied over from root node to the fused node
    """
    _, extra_input, conv = pattern
    return [extra_input]


#      conv2d
#        /
#  Y    bn
#   \   /
#    add


def _fuse_conv_bn_add_right(is_qat, add, _, bn_conv):
    bn, conv = bn_conv
    if is_qat:
        raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add)}")
    else:
        fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
        return nni.ConvAdd2d(fused_conv, add)


def _conv_bn_add_root_node_getter_right(pattern):
    add, _, bn_conv = pattern
    bn, conv = bn_conv
    return conv


def _conv_bn_add_extra_inputs_getter_right(pattern):
    """get inputs pattern for extra inputs, inputs for root node
    are assumed to be copied over from root node to the fused node
    """
    _, extra_input, bn_conv = pattern
    bn, conv = bn_conv
    return [extra_input]


conv_add_optioins = itertools.product(
    [True, False],  # with_bn
    [torch.add, operator.add],  # add_op
)

# Register ``(add, Y, conv[-> bn])`` patterns: the mirror image of the
# left-hand case, with conv (optionally followed by BatchNorm2d) feeding the
# second operand of the add.
for with_bn, add_op in conv_add_optioins:
    conv_configs.append(
        BackendPatternConfig()
        ._set_pattern_complex_format(
            (
                add_op,
                MatchAllNode,
                (nn.BatchNorm2d, nn.Conv2d) if with_bn else nn.Conv2d,
            )
        )
        .set_observation_type(observation_type)
        .set_dtype_configs(conv_dtype_configs)
        .set_fuser_method(
            _fuse_conv_bn_add_right if with_bn else _fuse_conv_add_right
        )
        ._set_root_node_getter(
            _conv_bn_add_root_node_getter_right
            if with_bn
            else _conv_add_root_node_getter_right
        )
        ._set_extra_inputs_getter(
            _conv_bn_add_extra_inputs_getter_right
            if with_bn
            else _conv_add_extra_inputs_getter_right
        )
        .set_fused_module(nni.ConvAdd2d)
    )

# Config for the fused ConvAdd2d module itself, rooted at nn.Conv2d with
# nnqr.Conv2d as its reference quantized module.
conv_configs.append(
    BackendPatternConfig(nni.ConvAdd2d)
    .set_root_module(nn.Conv2d)
    .set_reference_quantized_module(nnqr.Conv2d)
    .set_observation_type(observation_type)
    .set_dtype_configs(conv_dtype_configs)
)

# (2) Conv2d + Add + Relu

# conv2d Y
#   \   /
#    add
#     \
#     relu


def _fuse_conv_add_relu_left(is_qat, relu, add_pattern):
    add, conv, _ = add_pattern
    return nni.ConvAddReLU2d(conv, add, relu)


def _conv_add_relu_root_node_getter_left(pattern):
    relu, add_pattern = pattern
    _, conv, _ = add_pattern
    return conv


def _conv_add_relu_extra_inputs_getter_left(pattern):
    """get inputs pattern for extra inputs, inputs for root node
    are assumed to be copied over from root node to the fused node
    """
    relu, add_pattern = pattern
    _, conv, extra_input = add_pattern
    return [extra_input]


# conv2d
#  \
#  bn   Y
#   \   /
#    add
#     \
#     relu


def _fuse_conv_bn_add_relu_left(is_qat, relu, add_pattern):
    add, bn_conv, _ = add_pattern
    bn, conv = bn_conv
    if is_qat:
        raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add, relu)}")
    else:
        fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
        return nni.ConvAddReLU2d(fused_conv, add, relu)


def _conv_bn_add_relu_root_node_getter_left(pattern):
    relu, add_pattern = pattern
    _, bn_conv, _ = add_pattern
    bn, conv = bn_conv
    return conv


def _conv_bn_add_relu_extra_inputs_getter_left(pattern):
    """get inputs pattern for extra inputs, inputs for root node
    are assumed to be copied over from root node to the fused node
    """
    relu, add_pattern = pattern
    _, bn_conv, extra_input = add_pattern
    bn, conv = bn_conv
    return [extra_input]


conv_add_relu_left_optioins = itertools.product(
    [True, False],  # with_bn
    [torch.add, operator.add],  # add_op
)

# Register ``(relu, (add, conv[-> bn], Y))`` patterns: conv, optionally
# followed by BatchNorm2d, feeds the first add operand and a ReLU follows the
# add. Only the conv sub-pattern and the helper callables depend on with_bn.
for with_bn, add_op in conv_add_relu_left_optioins:
    conv_configs.append(
        BackendPatternConfig()
        ._set_pattern_complex_format(
            (
                nn.ReLU,
                (
                    add_op,
                    (nn.BatchNorm2d, nn.Conv2d) if with_bn else nn.Conv2d,
                    MatchAllNode,
                ),
            )
        )
        .set_observation_type(observation_type)
        .set_dtype_configs(conv_dtype_configs)
        .set_fuser_method(
            _fuse_conv_bn_add_relu_left if with_bn else _fuse_conv_add_relu_left
        )
        ._set_root_node_getter(
            _conv_bn_add_relu_root_node_getter_left
            if with_bn
            else _conv_add_relu_root_node_getter_left
        )
        ._set_extra_inputs_getter(
            _conv_bn_add_relu_extra_inputs_getter_left
            if with_bn
            else _conv_add_relu_extra_inputs_getter_left
        )
        .set_fused_module(nni.ConvAddReLU2d)
    )

#  Y   conv2d
#   \   /
#    add
#     \
#     relu


def _fuse_conv_add_relu_right(is_qat, relu, add_pattern):
    add, _, conv = add_pattern
    return nni.ConvAddReLU2d(conv, add, relu)


def _conv_add_relu_root_node_getter_right(pattern):
    relu, add_pattern = pattern
    _, _, conv = add_pattern
    return conv


def _conv_add_relu_extra_inputs_getter_right(pattern):
    """get inputs pattern for extra inputs, inputs for root node
    are assumed to be copied over from root node to the fused node
    """
    relu, add_pattern = pattern
    _, extra_input, conv = add_pattern
    return [extra_input]


#      conv2d
#        /
#  Y    bn
#   \   /
#    add
#     \
#     relu


def _fuse_conv_bn_add_relu_right(is_qat, relu, add_pattern):
    add, _, bn_conv = add_pattern
    bn, conv = bn_conv
    if is_qat:
        raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add, relu)}")
    else:
        fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
        return nni.ConvAddReLU2d(fused_conv, add, relu)


def _conv_bn_add_relu_root_node_getter_right(pattern):
    relu, add_pattern = pattern
    _, _, bn_conv = add_pattern
    bn, conv = bn_conv
    return conv


def _conv_bn_add_relu_extra_inputs_getter_right(pattern):
    """get inputs pattern for extra inputs, inputs for root node
    are assumed to be copied over from root node to the fused node
    """
    relu, add_pattern = pattern
    _, extra_input, bn_conv = add_pattern
    bn, conv = bn_conv
    return [extra_input]


conv_add_relu_optioins = itertools.product(
    [True, False],  # with_bn
    [torch.add, operator.add],  # add_op
)

# Register ``(relu, (add, Y, conv[-> bn]))`` patterns: conv, optionally
# followed by BatchNorm2d, feeds the second add operand and a ReLU follows.
for with_bn, add_op in conv_add_relu_optioins:
    conv_configs.append(
        BackendPatternConfig()
        ._set_pattern_complex_format(
            (
                nn.ReLU,
                (
                    add_op,
                    MatchAllNode,
                    (nn.BatchNorm2d, nn.Conv2d) if with_bn else nn.Conv2d,
                ),
            )
        )
        .set_observation_type(observation_type)
        .set_dtype_configs(conv_dtype_configs)
        .set_fuser_method(
            _fuse_conv_bn_add_relu_right if with_bn else _fuse_conv_add_relu_right
        )
        ._set_root_node_getter(
            _conv_bn_add_relu_root_node_getter_right
            if with_bn
            else _conv_add_relu_root_node_getter_right
        )
        ._set_extra_inputs_getter(
            _conv_bn_add_relu_extra_inputs_getter_right
            if with_bn
            else _conv_add_relu_extra_inputs_getter_right
        )
        .set_fused_module(nni.ConvAddReLU2d)
    )

# Config for the fused ConvAddReLU2d module itself, rooted at nn.Conv2d with
# nnqr.Conv2d as its reference quantized module.
conv_configs.append(
    BackendPatternConfig(nni.ConvAddReLU2d)
    .set_root_module(nn.Conv2d)
    .set_reference_quantized_module(nnqr.Conv2d)
    .set_observation_type(observation_type)
    .set_dtype_configs(conv_dtype_configs)
)

# ========================
# |  CONFIGS FOR LINEAR  |
# ========================

# Linear is registered for both the static int8 and the dynamic int8 configs.
linear_dtype_configs = [onednn_weighted_op_int8_dtype_config, onednn_dynamic_int8_dtype_config]
linear_configs = _get_linear_configs(linear_dtype_configs)


def _add_eltwise_fusion_configs(
    configs,
    root_module,
    root_op,
    post_module,
    post_op,
    dtype_configs,
    fuser_method,
    fused_module,
    observation_type,
    ref_quant_module,
):
    # 1 base module + op module fusion config
    configs.append(
        BackendPatternConfig((root_module, post_module))
        .set_dtype_configs(dtype_configs)  # noqa: E131
        .set_fuser_method(fuser_method)
        .set_fused_module(fused_module)
    )
    # base module + functional post op
    configs.append(
        BackendPatternConfig((root_module, post_op))
        .set_dtype_configs(dtype_configs)  # noqa: E131
        .set_fuser_method(fuser_method)
        .set_fused_module(fused_module)
    )

    # 2 fused module configs
    configs.append(
        BackendPatternConfig(fused_module)
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
        .set_root_module(root_module)
        .set_reference_quantized_module(ref_quant_module)
    )

    # 3 functional base op + post op configs
    configs.append(
        BackendPatternConfig((root_op, post_module))
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
    )
    configs.append(
        BackendPatternConfig((root_op, post_op))
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
    )


# Configs for linear + leaky_relu fusion
_add_eltwise_fusion_configs(
    configs=linear_configs,
    root_module=nn.Linear,
    root_op=F.linear,
    post_module=nn.LeakyReLU,
    post_op=F.leaky_relu,
    dtype_configs=linear_dtype_configs,
    fuser_method=_sequential_wrapper2(nni.LinearLeakyReLU),
    fused_module=nni.LinearLeakyReLU,
    observation_type=observation_type,
    ref_quant_module=nnqr.Linear,
)

# Linear -> BatchNorm1d -> LeakyReLU also fuses into LinearLeakyReLU, with the
# batch norm folded in by _fuse_linear_bn_leaky_relu.
linear_configs.append(
    BackendPatternConfig((nn.Linear, nn.BatchNorm1d, nn.LeakyReLU))
    .set_dtype_configs(linear_dtype_configs)
    .set_fuser_method(_fuse_linear_bn_leaky_relu)
    .set_fused_module(nni.LinearLeakyReLU)
)

# Linear + tanh fusion.
_add_eltwise_fusion_configs(
    configs=linear_configs,
    root_module=nn.Linear,
    root_op=F.linear,
    post_module=nn.Tanh,
    post_op=torch.tanh,
    dtype_configs=linear_dtype_configs,
    fuser_method=_sequential_wrapper2(nni.LinearTanh),
    fused_module=nni.LinearTanh,
    observation_type=observation_type,
    ref_quant_module=nnqr.Linear,
)

# ===========================
# |  CONFIGS FOR OTHER OPS  |
# ===========================

# Weightless ops: quint8 in / quint8 out. Each category keeps its own list.
binary_op_dtype_configs = [onednn_op_quint8_dtype_config]
default_op_dtype_configs = [onednn_op_quint8_dtype_config]
fixed_qparams_op_dtype_configs = [onednn_op_quint8_dtype_config]
share_qparams_op_dtype_configs = [onednn_op_quint8_dtype_config]

# Ops with their own dtype requirements.
layer_norm_op_dtype_configs = [onednn_input_output_only_quint8_dtype_config]
embedding_op_dtype_configs = [onednn_weight_only_qint8_dtype_config]
rnn_op_dtype_configs = [onednn_dynamic_int8_dtype_config]

# =====================
# |  BACKEND CONFIGS  |
# =====================


def get_onednn_backend_config() -> BackendConfig:
    """
    Build and return the `BackendConfig` describing PyTorch's native ONEDNN backend.
    """
    backend_config = BackendConfig("onednn")
    # Register pattern configs one group at a time, in a fixed order.
    backend_config.set_backend_pattern_configs(conv_configs)
    backend_config.set_backend_pattern_configs(linear_configs)
    backend_config.set_backend_pattern_configs(
        _get_binary_op_configs(binary_op_dtype_configs)
    )
    # cat contributes a single config rather than a list.
    backend_config.set_backend_pattern_config(_get_cat_config(default_op_dtype_configs))
    backend_config.set_backend_pattern_configs(
        _get_default_op_configs(default_op_dtype_configs)
    )
    backend_config.set_backend_pattern_configs(
        _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)
    )
    backend_config.set_backend_pattern_configs(
        _get_share_qparams_op_configs(share_qparams_op_dtype_configs)
    )
    backend_config.set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs))
    backend_config.set_backend_pattern_configs(
        _get_ln_configs(layer_norm_op_dtype_configs)
    )
    backend_config.set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs))
    backend_config.set_backend_pattern_configs(
        _get_embedding_op_configs(embedding_op_dtype_configs)
    )
    return backend_config


# Public API of this module.
__all__ = ["get_onednn_backend_config"]