1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126
|
import unittest
from functools import partial
from typing import List
import numpy as np
import torch
from torch.testing import make_tensor
from torch.testing._internal.common_device_type import tol, toleranceOverride
from torch.testing._internal.common_dtype import (
all_types_and,
all_types_and_complex_and,
complex_types,
floating_and_complex_types_and,
floating_types_and,
integral_types,
)
from torch.testing._internal.opinfo.core import (
DecorateInfo,
gradcheck_wrapper_masked_operation,
gradcheck_wrapper_masked_pointwise_operation,
M,
OpInfo,
ReductionOpInfo,
S,
sample_inputs_reduction,
SampleInput,
)
from torch.testing._internal.opinfo.utils import reference_reduction_numpy
# Used for log_softmax, softmax, softmin
def sample_inputs_softmax_variant(
    op_info, device, dtype, requires_grad, with_dtype=False, **kwargs
):
    """Build samples for softmax-like ops (log_softmax, softmax, softmin).

    Each sample pairs a random tensor with a one-element tuple holding the
    dimension to operate over, passed as the positional ``args``. When
    ``with_dtype`` is true every sample also receives
    ``dtype=torch.float64`` as a keyword argument; otherwise ``kwargs=None``
    is passed.

    Returns:
        A list of ``SampleInput``.
    """
    make_arg = partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )
    shape_dim_pairs = [
        ((S,), (0,)),
        ((S, S), (0,)),
        ((S, S), (1,)),
        ((S, S), (-1,)),
        ((S, M, S), (2,)),
    ]
    # PyTorch on XLA throws an error when passed with dim argument for 0d tensor.
    # See https://github.com/pytorch/xla/issues/3061 for more details.
    if torch.device(device).type != "xla":
        shape_dim_pairs.append(((), (0,)))
    # NOTE: one dict object (or None) is shared by every sample produced here.
    sample_kwargs = {"dtype": torch.float64} if with_dtype else None
    samples = []
    for shape, dim in shape_dim_pairs:
        samples.append(SampleInput(make_arg(shape), args=dim, kwargs=sample_kwargs))
    return samples
def _generate_masked_op_mask(input_shape, device, **kwargs):
    """Yield the masks used to build masked-operator samples.

    The sequence always starts with ``None`` (no mask), followed by a random
    bool mask of exactly ``input_shape``. For inputs with more than two
    dimensions, five more bool masks follow whose shapes must be broadcast
    against the input. Extra keyword arguments are accepted and ignored.
    """
    make_mask = partial(
        make_tensor, dtype=torch.bool, device=device, requires_grad=False
    )
    yield None
    yield make_mask(input_shape)
    if len(input_shape) <= 2:
        return
    # broadcast last mask dimension:
    yield make_mask(input_shape[:-1] + (1,))
    # broadcast middle mask dimension:
    yield make_mask(input_shape[:1] + (1,) + input_shape[2:])
    # broadcast first mask dimension:
    yield make_mask((1,) + input_shape[1:])
    # mask.ndim < input.ndim
    yield make_mask(input_shape[1:])
    # mask.ndim == 1
    yield make_mask(input_shape[-1:])
    # Masks that would require broadcasting of the input (mask.ndim >
    # input.ndim) are not supported; this may be reconsidered if there is
    # demand for such degenerate cases.
def sample_inputs_masked_reduction(op_info, device, dtype, requires_grad, **kwargs):
    """Sample inputs for masked reduction operators.

    Masked reduction operator is a reduction operator with trailing
    mask optional argument. A mask is a bool tensor with the same
    shape as input or a shape that is broadcastable to input shape.

    Every sample from ``sample_inputs_reduction`` is paired with each mask
    produced by ``_generate_masked_op_mask`` (including ``None``). For 2-D
    floating-point inputs without autograd and with a full-shape mask, three
    extra samples are emitted whose diagonal is filled with ``inf``,
    ``-inf`` and ``nan`` respectively.
    """
    kwargs["supports_multiple_dims"] = op_info.supports_multiple_dims
    for sample_input in sample_inputs_reduction(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        for mask in _generate_masked_op_mask(
            sample_input.input.shape, device, **kwargs
        ):
            sample_input_args = sample_input.args
            sample_input_kwargs = dict(mask=mask, **sample_input.kwargs)
            yield SampleInput(
                sample_input.input.detach().requires_grad_(requires_grad),
                args=sample_input_args,
                kwargs=sample_input_kwargs,
            )
            if (
                not requires_grad
                and dtype.is_floating_point
                and sample_input.input.ndim == 2
                and mask is not None
                and mask.shape == sample_input.input.shape
            ):
                for v in [torch.inf, -torch.inf, torch.nan]:
                    # detach() alone shares storage with sample_input.input
                    # (and with every sample aliasing it), so the in-place
                    # diagonal fill would leak extreme values into samples
                    # that were already yielded or are derived from the same
                    # tensor later. Clone to keep each sample independent.
                    t = sample_input.input.detach().clone()
                    t.diagonal(0, -2, -1).fill_(v)
                    yield SampleInput(
                        t.requires_grad_(requires_grad),
                        args=sample_input_args,
                        kwargs=sample_input_kwargs,
                    )
def sample_inputs_sparse_coo_masked_reduction(
    op_info, device, dtype, requires_grad, **kwargs
):
    """Sample inputs for masked reduction operators that support inputs
    with sparse COO layouts.

    Each masked-reduction sample is converted to sparse COO (input and, when
    present, mask). Unmasked samples are skipped for ``prod``, ``amax`` and
    ``amin``. Nothing is produced when ``op_info.supports_sparse`` is false.
    """
    if not op_info.supports_sparse:
        return
    reduction_name = op_info.name.replace("masked.", "")
    # FIXME: for now reductions with non-zero reduction identity and
    # unspecified mask are not supported for sparse COO tensors, see
    # torch.masked.prod implementation for details.
    needs_explicit_mask = {"prod", "amax", "amin"}
    for base in sample_inputs_masked_reduction(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        dense_mask = base.kwargs.get("mask")
        if dense_mask is None:
            if reduction_name in needs_explicit_mask:
                continue
            coo_kwargs = base.kwargs
        else:
            coo_kwargs = base.kwargs.copy()
            coo_kwargs.update(mask=dense_mask.to_sparse())
        yield SampleInput(
            base.input.to_sparse(),
            args=base.args,
            kwargs=coo_kwargs,
        )
def sample_inputs_sparse_csr_masked_reduction(
    op_info, device, dtype, requires_grad, **kwargs
):
    """Sample inputs for masked reduction operators that support inputs
    with sparse CSR layouts.

    Only 2-D masked-reduction samples with a truthy ``keepdim`` are kept, and
    both input and mask (when present) are converted to sparse CSR.
    Unmasked samples are skipped for ``prod``, ``amax``, ``amin`` and
    ``mean``. Every kept sample with ``dim == 0`` is followed by a copy that
    reduces over ``dim=1``. Nothing is produced when
    ``op_info.supports_sparse_csr`` is false.
    """
    if not op_info.supports_sparse_csr:
        return
    reduction_name = op_info.name.replace("masked.", "")
    # reductions with non-zero reduction identity and unspecified mask is
    # not supported for sparse CSR tensors, see torch.masked.prod
    # implementation for details.
    needs_explicit_mask = ["prod", "amax", "amin", "mean"]
    for base in sample_inputs_masked_reduction(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        # - sparse CSR tensors are always 2-D tensors
        # - masked reduction on CSR tensors are defined only if keepdim is True.
        if base.input.ndim != 2 or not base.kwargs.get("keepdim"):
            continue
        dense_mask = base.kwargs.get("mask")
        if dense_mask is None:
            if reduction_name in needs_explicit_mask:
                continue
            csr_kwargs = base.kwargs
        else:
            csr_kwargs = base.kwargs.copy()
            csr_kwargs.update(mask=dense_mask.to_sparse_csr())
        csr_sample = SampleInput(
            base.input.to_sparse_csr(),
            args=base.args,
            kwargs=csr_kwargs,
        )
        yield csr_sample
        if base.kwargs["dim"] != 0:
            continue
        # CSR reductions use different implementations for the inner and
        # outer dimensions, so tests need at least dim=0, dim=1 and
        # dim=(0, 1) with keepdim=True. The dim=1 case is derived here from
        # the dim=0 one.
        inner_kwargs = csr_sample.kwargs.copy()
        inner_kwargs.update(dim=1)
        yield SampleInput(
            csr_sample.input.clone(),
            args=base.args,
            kwargs=inner_kwargs,
        )
def sample_inputs_masked_norm(op_info, device, dtype, requires_grad, **kwargs):
    """Sample inputs for masked norm.

    Masked-reduction samples are regenerated for each norm order in
    ``[2.0, 1, inf, -inf, 0]``. The order is prepended to the positional
    args, and the input is cloned for each sample.
    """
    norm_orders = [2.0, 1, float("inf"), float("-inf"), 0]
    for order in norm_orders:
        for base in sample_inputs_masked_reduction(
            op_info, device, dtype, requires_grad, **kwargs
        ):
            norm_args = (order,) + base.args
            norm_kwargs = base.kwargs.copy()
            yield SampleInput(
                base.input.clone().requires_grad_(requires_grad),
                args=norm_args,
                kwargs=norm_kwargs,
            )
def sample_inputs_masked_std_var(op_info, device, dtype, requires_grad, **kwargs):
    """Sample inputs for masked std/var.

    Every masked-reduction sample is emitted once with ``unbiased=False``
    and once with ``unbiased=True``. If the sample has positional args,
    ``unbiased`` is inserted right after the first one (taken as ``dim``);
    otherwise it is added as a keyword argument.

    With ``requires_grad`` set, a sample is dropped when the smallest count
    of (unmasked) elements per reduction is ``<= int(unbiased) + 1``.
    """
    for unbiased in [False, True]:
        for base in sample_inputs_masked_reduction(
            op_info, device, dtype, requires_grad, **kwargs
        ):
            if not base.args:
                dim = base.kwargs.get("dim")
                new_args = base.args
                new_kwargs = dict(base.kwargs, unbiased=unbiased)
            else:
                dim = base.args[0]
                new_args = base.args[:1] + (unbiased,) + base.args[1:]
                new_kwargs = base.kwargs.copy()

            if requires_grad:
                shape = base.input.shape
                if new_kwargs.get("mask") is not None:
                    input_mask = torch.masked._input_mask(
                        base.input, *new_args, **new_kwargs
                    )
                    element_count = torch.masked.sum(
                        input_mask.new_ones(shape, dtype=torch.int64),
                        dim,
                        keepdim=True,
                        mask=input_mask,
                    )
                else:
                    element_count = torch.masked.sum(
                        torch.ones(shape, dtype=torch.int64),
                        dim,
                        keepdim=True,
                    )
                if element_count.min() <= int(unbiased) + 1:
                    # Skip samples that lead to singularities in var
                    # computation resulting nan values both in var and
                    # autograd output that test_grad_fn cannot handle
                    # correctly. Also, skip samples when the autograd output
                    # for std could not be handled correctly due to torch.sqrt
                    continue

            yield SampleInput(
                base.input.detach().requires_grad_(requires_grad),
                args=new_args,
                kwargs=new_kwargs,
            )
def sample_inputs_masked_softmax(
    op_info, device, dtype, requires_grad, with_dtype=False, **kwargs
):
    """Sample inputs for masked softmax, log_softmax, and softmin.

    Masked normalization operator is a reduction operator with
    trailing mask optional argument. A mask is a bool tensor with the
    same shape as input or a shape that is broadcastable to input
    shape.

    Each softmax-variant sample is paired with every mask yielded by
    ``_generate_masked_op_mask`` (including ``None``), passed as the
    ``mask`` keyword. The input is cloned per sample.

    Returns:
        A list of ``SampleInput``.
    """
    return [
        SampleInput(
            base.input.clone().requires_grad_(requires_grad),
            args=base.args,
            kwargs=dict(mask=mask, **base.kwargs),
        )
        for base in sample_inputs_softmax_variant(
            op_info, device, dtype, requires_grad, with_dtype=with_dtype, **kwargs
        )
        for mask in _generate_masked_op_mask(base.input.shape, device, **kwargs)
    ]
def sample_inputs_masked_cumops(op_info, device, dtype, requires_grad, **kwargs):
    """Sample inputs for masked cumsum and cumprod.

    Reuses the softmax-variant samples and pairs each with every tensor mask
    from ``_generate_masked_op_mask``. The ``None`` mask is skipped. Any
    ``keepdim`` keyword is dropped, and ``dim`` is always passed
    positionally. Samples that have neither positional args nor a ``dim``
    keyword are skipped.

    Returns:
        A list of ``SampleInput``.
    """
    inputs: List[SampleInput] = []
    for sample_input in sample_inputs_softmax_variant(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        for mask in _generate_masked_op_mask(
            sample_input.input.shape, device, **kwargs
        ):
            # Only real tensor masks are used; the ``None`` mask is skipped.
            if not isinstance(mask, torch.Tensor):
                continue
            sample_input_args = sample_input.args
            sample_input_kwargs = dict(mask=mask, **sample_input.kwargs)
            # Cumulative ops take no keepdim argument.
            sample_input_kwargs.pop("keepdim", None)
            # dimension is required: move it from kwargs to args if needed.
            if not sample_input_args:
                if "dim" not in sample_input_kwargs:
                    continue
                sample_input_args = (sample_input_kwargs.pop("dim"),)
            inputs.append(
                SampleInput(
                    sample_input.input.clone().requires_grad_(requires_grad),
                    args=sample_input_args,
                    kwargs=sample_input_kwargs,
                )
            )
    return inputs
def sample_inputs_masked_logaddexp(op_info, device, dtype, requires_grad, **kwargs):
    """Sample inputs for masked logaddexp.

    For each shape in ``[(S,), (S, S), (S, M, S)]``, two independent mask
    sequences are drawn from ``_generate_masked_op_mask`` and zipped
    pairwise. Each pair becomes one sample with fresh random ``input`` and
    ``other`` tensors.

    Returns:
        A list of ``SampleInput`` with ``args=(other,)`` and keyword args
        ``input_mask`` / ``other_mask``.
    """
    inputs: List[SampleInput] = []
    shapes = [(S,), (S, S), (S, M, S)]
    input_mask_lists = [
        list(_generate_masked_op_mask(shape, device, **kwargs)) for shape in shapes
    ]
    other_mask_lists = [
        list(_generate_masked_op_mask(shape, device, **kwargs)) for shape in shapes
    ]
    make_arg = partial(
        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )
    for shape, input_masks, other_masks in zip(
        shapes, input_mask_lists, other_mask_lists
    ):
        for input_mask, other_mask in zip(input_masks, other_masks):
            # make_tensor already applies requires_grad and each tensor is
            # fresh, so use the tensors directly. Cloning them first only
            # adds a copy and, when requires_grad is set, turns each sample
            # argument into a non-leaf tensor.
            lhs = make_arg(shape)
            rhs = make_arg(shape)
            inputs.append(
                SampleInput(
                    lhs,
                    args=(rhs,),
                    kwargs=dict(input_mask=input_mask, other_mask=other_mask),
                )
            )
    return inputs
def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwargs):
    """Sample inputs for masked normalize.

    Softmax-variant samples are regenerated for each norm order in
    ``[2.0, 1, inf, -inf, 0]``. The order is prepended to the positional
    args, and the input is cloned per sample.

    Returns:
        A list of ``SampleInput``.
    """
    norm_orders = [2.0, 1, float("inf"), float("-inf"), 0]
    return [
        SampleInput(
            base.input.clone().requires_grad_(requires_grad),
            args=(order,) + base.args,
            kwargs=base.kwargs.copy(),
        )
        for order in norm_orders
        for base in sample_inputs_softmax_variant(
            op_info, device, dtype, requires_grad, **kwargs
        )
    ]
op_db: List[OpInfo] = [
ReductionOpInfo(
"masked.sum",
ref=reference_reduction_numpy(np.sum),
method_variant=None,
identity=0,
nan_policy="propagate",
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_sparse=True,
supports_sparse_csr=True,
promotes_int_to_int64=True,
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
skips=(
DecorateInfo(
unittest.skip("Failing on some jobs"),
"TestReductions",
"test_reference_masked",
dtypes=(torch.bool, torch.int8, torch.int16, torch.int32),
),
DecorateInfo(
unittest.expectedFailure,
"TestNormalizeOperators",
"test_normalize_operator_exhaustive",
),
# FIXME: sum reduces all dimensions when dim=[]
DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
DecorateInfo(
unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
),
# RuntimeError: undefined value tensor
DecorateInfo(
unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
),
),
decorators=[
DecorateInfo(
toleranceOverride({torch.bfloat16: tol(atol=1e-03, rtol=1e-03)}),
"TestReductions",
"test_reference_masked",
),
DecorateInfo(
toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-03)}),
"TestReductions",
"test_reference_masked",
),
DecorateInfo(
toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-03)}),
"TestReductions",
"test_ref_small_input",
),
],
sample_inputs_func=sample_inputs_masked_reduction,
sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction,
sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction,
),
ReductionOpInfo(
"masked.prod",
ref=reference_reduction_numpy(np.prod),
method_variant=None,
identity=1,
nan_policy="propagate",
# https://github.com/pytorch/pytorch/issues/80411
gradcheck_fast_mode=True,
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_sparse=True,
supports_sparse_csr=True,
promotes_int_to_int64=True,
# FIXME: "prod_cpu" not implemented for 'Half' or 'BFloat16'
dtypes=all_types_and_complex_and(torch.bool),
dtypesIfCUDA=all_types_and_complex_and(
torch.bool, torch.float16, torch.bfloat16
),
skips=(
DecorateInfo(
unittest.expectedFailure,
"TestNormalizeOperators",
"test_normalize_operator_exhaustive",
),
DecorateInfo(
unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
),
DecorateInfo(
unittest.skip("Failing on some jobs"),
"TestReductions",
"test_reference_masked",
dtypes=(torch.bool, torch.int8, torch.int16, torch.int32),
),
# integer overflow
DecorateInfo(
unittest.skip("Skipped!"),
"TestReductions",
"test_ref_small_input",
dtypes=(torch.int8, torch.int16, torch.int32),
),
# FIXME: "cuda_scatter_gather_base_kernel_func" not implemented for ... (used for sparse_coo inputs)
DecorateInfo(
unittest.skip("Skipped!"),
"TestMasked",
"test_mask_layout",
device_type="cuda",
dtypes=(torch.bool, *integral_types(), *complex_types()),
),
),
decorators=[
DecorateInfo(
toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-02)}),
"TestReductions",
"test_reference_masked",
),
DecorateInfo(
toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-03)}),
"TestReductions",
"test_ref_duplicate_values",
),
],
sample_inputs_func=sample_inputs_masked_reduction,
sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction,
sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction,
),
OpInfo(
"masked.cumsum",
dtypes=all_types_and_complex_and(torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
method_variant=None,
# Runs very slowly on slow gradcheck - alternatively reduce input sizes
gradcheck_fast_mode=True,
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
skips=(
DecorateInfo(
unittest.expectedFailure,
"TestNormalizeOperators",
"test_normalize_operator_exhaustive",
),
# NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
DecorateInfo(
unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
),
),
# Can reuse the same inputs; dim is required in both
sample_inputs_func=sample_inputs_masked_cumops,
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
),
OpInfo(
"masked.cumprod",
dtypes=all_types_and_complex_and(torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
method_variant=None,
# Runs very slowly on slow gradcheck - alternatively reduce input sizes
gradcheck_fast_mode=True,
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
skips=(
# NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
DecorateInfo(
unittest.expectedFailure,
"TestNormalizeOperators",
"test_normalize_operator_exhaustive",
),
# NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
DecorateInfo(
unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
),
# RuntimeError: "prod_cpu" not implemented for 'BFloat16'
DecorateInfo(
unittest.expectedFailure,
"TestDecomp",
"test_comprehensive",
dtypes=(torch.bfloat16,),
device_type="cpu",
),
),
# Can reuse the same inputs; dim is required in both
sample_inputs_func=sample_inputs_masked_cumops,
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
),
ReductionOpInfo(
"masked.amax",
nan_policy="propagate",
supports_out=False,
dtypes=all_types_and(torch.float16, torch.bfloat16),
supports_sparse=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_sparse_csr=True,
ref=reference_reduction_numpy(np.amax),
skips=(
DecorateInfo(
unittest.expectedFailure,
"TestNormalizeOperators",
"test_normalize_operator_exhaustive",
),
# FIXME: amax reduces all dimensions when dim=[]
DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
DecorateInfo(
unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
),
# RuntimeError: Unknown builtin op: aten::iinfo
DecorateInfo(
unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
),
# FIXME: "cuda_scatter_gather_base_kernel_func" not implemented for ... (used for sparse_coo inputs)
# FIXME: "_segment_reduce_lengths_cpu/cuda" not implemented for ... (used for sparse_csr inputs)
DecorateInfo(
unittest.skip("Skipped!"),
"TestMasked",
"test_mask_layout",
dtypes=(torch.bool, *integral_types(), *complex_types()),
),
),
sample_inputs_func=sample_inputs_masked_reduction,
sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction,
sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction,
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
),
ReductionOpInfo(
"masked.amin",
nan_policy="propagate",
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
dtypes=all_types_and(torch.float16, torch.bfloat16),
supports_sparse=True,
supports_sparse_csr=True,
ref=reference_reduction_numpy(np.amin),
skips=(
DecorateInfo(
unittest.expectedFailure,
"TestNormalizeOperators",
"test_normalize_operator_exhaustive",
),
# FIXME: amax reduces all dimensions when dim=[]
DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
DecorateInfo(
unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
),
# RuntimeError: Unknown builtin op: aten::iinfo
DecorateInfo(
unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
),
# FIXME: "cuda_scatter_gather_base_kernel_func" not implemented for ... (used for sparse_coo inputs)
# FIXME: "_segment_reduce_lengths_cpu/cuda" not implemented for ... (used for sparse_csr inputs)
DecorateInfo(
unittest.skip("Skipped!"),
"TestMasked",
"test_mask_layout",
dtypes=(torch.bool, *integral_types(), *complex_types()),
),
),
sample_inputs_func=sample_inputs_masked_reduction,
sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction,
sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction,
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
),
ReductionOpInfo(
"masked.argmax",
supports_out=False,
supports_multiple_dims=False,
supports_autograd=False,
dtypes=all_types_and(torch.float16, torch.bfloat16),
ref=reference_reduction_numpy(np.argmax, supports_keepdims=False),
skips=(
DecorateInfo(
unittest.expectedFailure,
"TestNormalizeOperators",
"test_normalize_operator_exhaustive",
),
# initial is not a keyword for argmax
DecorateInfo(
unittest.expectedFailure, "TestReductions", "test_reference_masked"
),
# NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
DecorateInfo(
unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
),
),
sample_inputs_func=sample_inputs_masked_reduction,
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
),
ReductionOpInfo(
"masked.argmin",
supports_out=False,
supports_multiple_dims=False,
supports_autograd=False,
dtypes=all_types_and(torch.float16, torch.bfloat16),
ref=reference_reduction_numpy(np.argmin, supports_keepdims=False),
skips=(
DecorateInfo(
unittest.expectedFailure,
"TestNormalizeOperators",
"test_normalize_operator_exhaustive",
),
# initial is not a keyword for argmin
DecorateInfo(
unittest.expectedFailure, "TestReductions", "test_reference_masked"
),
# NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
DecorateInfo(
unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
),
),
sample_inputs_func=sample_inputs_masked_reduction,
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
),
ReductionOpInfo(
"masked.mean",
ref=reference_reduction_numpy(np.mean)
if np.lib.NumpyVersion(np.__version__) >= "1.20.2"
else None,
method_variant=None,
nan_policy="propagate",
supports_out=False,
supports_sparse_csr=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
promotes_int_to_float=True,
dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool),
skips=(
DecorateInfo(
unittest.expectedFailure,
"TestReductions",
"test_ref_duplicate_values",
dtypes=(torch.bool,),
),
DecorateInfo(
unittest.expectedFailure,
"TestReductions",
"test_reference_masked",
dtypes=(torch.bool,),
),
DecorateInfo(
unittest.expectedFailure,
"TestReductions",
"test_ref_small_input",
dtypes=(torch.bool,),
),
DecorateInfo(
unittest.expectedFailure,
"TestNormalizeOperators",
"test_normalize_operator_exhaustive",
),
# FIXME: sum reduces all dimensions when dim=[]
DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
DecorateInfo(
unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
),
# RuntimeError: undefined value tensor
DecorateInfo(
unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
),
# FIXME: "_segment_reduce_lengths_cpu/cuda" not implemented for ... (used for sparse_csr inputs)
DecorateInfo(
unittest.skip("Skipped!"),
"TestMasked",
"test_mask_layout",
dtypes=(torch.bool, *integral_types(), *complex_types()),
),
),
decorators=[
DecorateInfo(
toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-03)}),
"TestReductions",
"test_reference_masked",
),
],
sample_inputs_func=sample_inputs_masked_reduction,
sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction,
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
),
OpInfo(
"masked.median",
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.float16),
method_variant=None,
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
skips=(
DecorateInfo(
unittest.expectedFailure,
"TestNormalizeOperators",
"test_normalize_operator_exhaustive",
),
# NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
DecorateInfo(
unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
),
),
sample_inputs_func=sample_inputs_masked_softmax,
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
),
ReductionOpInfo(
"masked.norm",
identity=0,
method_variant=None,
nan_policy="propagate",
supports_out=False,
promotes_int_to_float=True,
dtypes=floating_types_and(torch.float16, torch.bfloat16),
skips=(
DecorateInfo(
unittest.expectedFailure,
"TestNormalizeOperators",
"test_normalize_operator_exhaustive",
),
# FIXME: sum reduces all dimensions when dim=[]
DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
DecorateInfo(
unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
),
# torch.jit.frontend.NotSupportedError: Compiled functions
# can't take variable number of arguments or use
# keyword-only arguments with defaults
DecorateInfo(
unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
),
),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_masked_norm,
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
),
ReductionOpInfo(
"masked.var",
ref=reference_reduction_numpy(np.var)
if np.lib.NumpyVersion(np.__version__) >= "1.20.2"
else None,
method_variant=None,
nan_policy="propagate",
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
# See https://github.com/pytorch/pytorch/pull/78358
check_batched_forward_grad=False,
promotes_int_to_float=True,
dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
skips=(
# Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
DecorateInfo(
unittest.skip("Skipped!"),
"TestSchemaCheckModeOpInfo",
"test_schema_correctness",
dtypes=(torch.complex64, torch.complex128),
),
DecorateInfo(
unittest.expectedFailure,
"TestNormalizeOperators",
"test_normalize_operator_exhaustive",
),
# FIXME: sum reduces all dimensions when dim=[]
DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
DecorateInfo(
unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
),
# RuntimeError: undefined value tensor
DecorateInfo(
unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
),
),
decorators=[
DecorateInfo(
toleranceOverride(
{
torch.float16: tol(atol=1e-02, rtol=1e-02),
torch.bfloat16: tol(atol=1e-03, rtol=1e-03),
}
),
"TestReductions",
"test_reference_masked",
),
DecorateInfo(
toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
"TestReductions",
"test_ref_small_input",
),
DecorateInfo(
toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
"TestMasked",
"test_reference_masked",
),
DecorateInfo(
toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
"TestCudaFuserOpInfo",
"test_nvfuser_correctness",
),
],
sample_inputs_func=sample_inputs_masked_std_var,
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
check_batched_grad=True,
),
# Test configuration for "masked.std", registered as a ReductionOpInfo so the
# generic reduction test suites exercise it.
ReductionOpInfo(
"masked.std",
# A NumPy reference is supplied only when NumPy >= 1.20.2; otherwise
# ref=None and no reference function is provided.
# NOTE(review): the cutoff presumably reflects a NumPy feature the reference
# wrapper depends on — confirm before changing it.
ref=reference_reduction_numpy(np.std)
if np.lib.NumpyVersion(np.__version__) >= "1.20.2"
else None,
# Explicitly None: no Tensor-method variant is registered for this op.
method_variant=None,
# NaNs in the input are expected to propagate into the result.
nan_policy="propagate",
# Runs very slowly on slow gradcheck - alternatively reduce input sizes
gradcheck_fast_mode=True,
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
# See https://github.com/pytorch/pytorch/pull/78358
check_batched_forward_grad=False,
# Integer inputs are expected to produce floating-point results.
promotes_int_to_float=True,
# CUDA additionally covers float16; the default dtype list does not.
dtypes=all_types_and_complex_and(torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
skips=(
# Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
# (restricted to the complex dtypes below).
DecorateInfo(
unittest.skip("Skipped!"),
"TestSchemaCheckModeOpInfo",
"test_schema_correctness",
dtypes=(torch.complex64, torch.complex128),
),
DecorateInfo(
unittest.expectedFailure,
"TestNormalizeOperators",
"test_normalize_operator_exhaustive",
),
# FIXME: sum reduces all dimensions when dim=[]
# (applies to both the plain and keepdim variants of the test).
DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
DecorateInfo(
unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
),
# RuntimeError: undefined value tensor
DecorateInfo(
unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
),
# nvFuser correctness is skipped for float16 only.
DecorateInfo(
unittest.skip("Skipped!"),
"TestCudaFuserOpInfo",
"test_nvfuser_correctness",
dtypes=(torch.float16,),
),
),
# Loosened float16 tolerances (atol=rtol=1e-2) for reference comparisons.
decorators=[
# NOTE(review): this targets TestReductions.test_reference_masked, while
# the same override is also applied to TestMasked.test_reference_masked
# below; confirm TestReductions actually defines this test, otherwise
# this entry never matches anything.
DecorateInfo(
toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
"TestReductions",
"test_reference_masked",
),
DecorateInfo(
toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
"TestReductions",
"test_ref_small_input",
),
DecorateInfo(
toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
"TestMasked",
"test_reference_masked",
),
],
# Same sample-input generator as the preceding std/var-style entry.
sample_inputs_func=sample_inputs_masked_std_var,
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
check_batched_grad=True,
),
# Test configuration for "masked.softmax" (a plain OpInfo, not a reduction).
OpInfo(
"masked.softmax",
# Explicitly None: no Tensor-method variant is registered for this op.
method_variant=None,
# Floating types + bfloat16 everywhere; CUDA additionally covers half.
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
# Shared with masked.log_softmax and masked.softmin.
sample_inputs_func=sample_inputs_masked_softmax,
skips=(
# NOTE(review): no reason is recorded for these two expected failures;
# the same pair appears on the other masked softmax-family entries.
DecorateInfo(
unittest.expectedFailure,
"TestNormalizeOperators",
"test_normalize_operator_exhaustive",
),
DecorateInfo(
unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
),
),
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_out=False,
),
# Test configuration for "masked.log_softmax" (a plain OpInfo, not a reduction).
OpInfo(
"masked.log_softmax",
# Explicitly None: no Tensor-method variant is registered for this op.
method_variant=None,
# Floating types + bfloat16 everywhere; CUDA additionally covers half.
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
# Shared with masked.softmax and masked.softmin.
sample_inputs_func=sample_inputs_masked_softmax,
skips=(
# NOTE(review): no reason is recorded for these two expected failures.
DecorateInfo(
unittest.expectedFailure,
"TestNormalizeOperators",
"test_normalize_operator_exhaustive",
),
DecorateInfo(
unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
),
),
decorators=[
# Loosened bfloat16 tolerance (atol=rtol=1e-2) for the masked
# reference comparison; unlike masked.softmax/softmin, this entry
# carries an explicit override.
DecorateInfo(
toleranceOverride({torch.bfloat16: tol(atol=1e-02, rtol=1e-02)}),
"TestMasked",
"test_reference_masked",
),
],
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_out=False,
),
# Test configuration for "masked.softmin" (a plain OpInfo, not a reduction).
OpInfo(
"masked.softmin",
# Explicitly None: no Tensor-method variant is registered for this op.
method_variant=None,
# Floating types + bfloat16 everywhere; CUDA additionally covers half.
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
# Shared with masked.softmax and masked.log_softmax.
sample_inputs_func=sample_inputs_masked_softmax,
skips=(
# NOTE(review): no reason is recorded for these two expected failures.
DecorateInfo(
unittest.expectedFailure,
"TestNormalizeOperators",
"test_normalize_operator_exhaustive",
),
DecorateInfo(
unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
),
),
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_out=False,
),
# Test configuration for "masked.normalize" (a plain OpInfo, not a reduction).
OpInfo(
"masked.normalize",
# Explicitly None: no Tensor-method variant is registered for this op.
method_variant=None,
# Floating + complex types plus half and bfloat16; no separate
# dtypesIfCUDA is given for this entry.
dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
sample_inputs_func=sample_inputs_masked_normalize,
skips=(
# NOTE(review): no reason is recorded for these two expected failures.
DecorateInfo(
unittest.expectedFailure,
"TestNormalizeOperators",
"test_normalize_operator_exhaustive",
),
DecorateInfo(
unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
),
# RuntimeError: "clamp_min_cpu" not implemented for 'Half'
# (so the reference test is expected to fail on CPU/half only).
DecorateInfo(
unittest.expectedFailure,
"TestMasked",
"test_reference_masked",
device_type="cpu",
dtypes=[torch.half],
),
),
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
# Runs very slowly on slow gradcheck - alternatively reduce input sizes
gradcheck_fast_mode=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_out=False,
),
# Test configuration for "masked.logaddexp". Unlike the other entries here it
# uses the *pointwise* masked gradcheck wrapper (see gradcheck_wrapper below)
# and does not pass method_variant.
OpInfo(
"masked.logaddexp",
# Floating types + bfloat16; no separate dtypesIfCUDA is given.
dtypes=floating_types_and(torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
# NOTE(review): batched forward-grad checking is disabled without a
# recorded reason — consider linking the tracking issue.
check_batched_forward_grad=False,
skips=(
DecorateInfo(
unittest.expectedFailure,
"TestNormalizeOperators",
"test_normalize_operator_exhaustive",
),
# NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
DecorateInfo(
unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
),
# NOTE(review): second-order gradient check is skipped with no recorded
# reason, even though supports_fwgrad_bwgrad=True above — confirm why.
DecorateInfo(
unittest.skip("Skipped!"), "TestGradients", "test_fn_gradgrad"
),
),
sample_inputs_func=sample_inputs_masked_logaddexp,
gradcheck_wrapper=gradcheck_wrapper_masked_pointwise_operation,
),
# Test configuration for "masked.logsumexp", registered as a ReductionOpInfo
# so the generic reduction test suites exercise it. No NumPy reference (ref)
# is supplied for this entry.
ReductionOpInfo(
"masked.logsumexp",
# all_types_and(...) includes integer dtypes; CUDA additionally covers
# float16.
dtypes=all_types_and(torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
# Explicitly None: no Tensor-method variant is registered for this op.
method_variant=None,
# NaNs in the input are expected to propagate into the result.
nan_policy="propagate",
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
# Note: these entries use unittest.skip rather than expectedFailure, so
# the tests are not run at all (an unexpected pass would go unnoticed).
skips=(
DecorateInfo(
unittest.skip("Skipped!"),
"TestNormalizeOperators",
"test_normalize_operator_exhaustive",
),
# FIXME: reduces all dimensions when dim=[]
DecorateInfo(unittest.skip("Skipped!"), "TestReductions", "test_dim_empty"),
DecorateInfo(
unittest.skip("Skipped!"), "TestReductions", "test_dim_empty_keepdim"
),
# Identity can't be -torch.inf without overflow
DecorateInfo(
unittest.skip("Skipped!"),
"TestReductions",
"test_empty_tensor_empty_slice",
),
# NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
DecorateInfo(
unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
),
# all the values are the same except for -inf vs nan
DecorateInfo(unittest.skip("Skipped!"), "TestDecomp", "test_comprehensive"),
),
sample_inputs_func=sample_inputs_masked_reduction,
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
),
]
|