1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211
|
import math
from typing import List, Optional, Union
import torch
import torch._prims_common as utils
from torch import Tensor
from torch._prims_common import (
check,
corresponding_complex_dtype,
corresponding_real_dtype,
elementwise_dtypes,
ELEMENTWISE_TYPE_PROMOTION_KIND,
)
from torch._prims_common.wrappers import out_wrapper
from torch._refs import _broadcast_shapes
from torch.utils._pytree import tree_map
# Shorthand for the ATen operator namespace used by the registrations below.
aten = torch.ops.aten
# Library handle through which Meta-key implementations of aten ops are
# installed; as its name says, go through register_meta() instead of using it
# directly.
_meta_lib_dont_use_me_use_register_meta = torch.library.Library("aten", "IMPL", "Meta")
# Maps each registered operator overload to its Python meta function
# (populated by register_meta()).
meta_table = {}
def register_meta(op, register_dispatcher=True):
    """Build a decorator that registers a function as the meta kernel of ``op``.

    Args:
        op: a single operator overload or a pytree (for example a list) of
            overloads; every leaf is registered.
        register_dispatcher: when true, the function is additionally installed
            through the module-level "aten"/"IMPL"/"Meta" library handle.

    Returns:
        A decorator that records the function in ``meta_table`` for each
        overload and hands the function back unchanged.
    """

    def decorator(f):
        def _register_one(overload):
            meta_table[overload] = f
            if register_dispatcher:
                # The default overload is registered under the packet name,
                # every other overload under its own name.
                if overload._overloadname != "default":
                    name = overload.__name__
                else:
                    name = overload.overloadpacket.__name__
                _meta_lib_dont_use_me_use_register_meta.impl(name, f)
            overload.py_impl(torch._C.DispatchKey.Meta)(f)

        tree_map(_register_one, op)
        return f

    return decorator
def toRealValueType(dtype):
    """Return the real dtype matching a complex ``dtype``.

    Any dtype that is not one of the three complex dtypes listed below is
    returned unchanged.
    """
    real_of = {
        torch.complex32: torch.half,
        torch.cfloat: torch.float,
        torch.cdouble: torch.double,
    }
    try:
        return real_of[dtype]
    except KeyError:
        return dtype
@register_meta(aten._fft_c2c.default)
def meta_fft_c2c(self, dim, normalization, forward):
    """Meta kernel for ``_fft_c2c``: the output has the same shape as the input.

    Asserts that the input dtype is complex; ``dim``, ``normalization`` and
    ``forward`` do not influence the result.
    """
    assert self.dtype.is_complex
    out_shape = self.size()
    return self.new_empty(out_shape)
@register_meta(aten._fft_r2c.default)
def meta_fft_r2c(self, dim, normalization, onesided):
    """Meta kernel for ``_fft_r2c``.

    Asserts a floating-point input and returns an empty tensor of the
    corresponding complex dtype. With ``onesided`` the last transformed
    dimension (``dim[-1]``) shrinks to ``n // 2 + 1``.
    """
    assert self.dtype.is_floating_point
    out_shape = list(self.size())
    if onesided:
        axis = dim[-1]
        out_shape[axis] = out_shape[axis] // 2 + 1
    complex_dtype = utils.corresponding_complex_dtype(self.dtype)
    return self.new_empty(out_shape, dtype=complex_dtype)
@register_meta(aten.randperm.generator_out)
def meta_randperm(n, *, generator=None, out):
    """Meta kernel for ``randperm.generator_out``.

    Asserts that ``out`` is one-dimensional with ``n`` elements and returns it.
    """
    assert out.ndim == 1
    assert out.size(0) == n
    return out
@register_meta([aten._fft_c2r.default, aten._fft_c2r.out])
@out_wrapper()
def meta_fft_c2r(self, dim, normalization, lastdim):
    """Meta kernel for ``_fft_c2r``.

    Asserts a complex input and returns an empty tensor of the matching real
    dtype whose size along ``dim[-1]`` is ``lastdim``.
    """
    assert self.dtype.is_complex
    out_shape = list(self.size())
    last_axis = dim[-1]
    out_shape[last_axis] = lastdim
    real_dtype = toRealValueType(self.dtype)
    return self.new_empty(out_shape, dtype=real_dtype)
@register_meta(aten.copy_.default, register_dispatcher=False)
def meta_copy_(self, src, non_blocking=False):
    """Meta kernel for in-place ``copy_``: returns ``self`` untouched."""
    return self
# Implementations below are taken from https://github.com/albanD/subclass_zoo/blob/main/python_meta_tensor.py
@register_meta(aten.index_select.default)
def meta_index_select(self, dim, index):
    """Meta kernel for ``index_select``.

    The output keeps the input's sizes except along ``dim``, which becomes
    ``index.numel()``. A 0-dim input keeps its (empty) shape.
    """
    out_shape = list(self.size())
    is_scalar = self.dim() <= 0
    if not is_scalar:
        out_shape[dim] = index.numel()
    return self.new_empty(out_shape)
@register_meta(aten.index_select.out)
def meta_index_select_out(self, dim, index, out):
    """Meta kernel for the ``out=`` variant of ``index_select``.

    Args:
        self: tensor being indexed.
        dim: dimension along which ``index`` selects.
        index: tensor of indices.
        out: destination tensor; resized to the shape of the result.

    Returns:
        ``out`` after the resize and copy.
    """
    result = torch.index_select(self, dim, index)
    # Resize to the result's shape: along `dim` it has index.numel() entries,
    # which in general differs from the input's size there.
    torch._resize_output_(out, result.size(), self.device)
    return out.copy_(result)
@register_meta([aten.max.default, aten.min.default])
def meta_max(self):
    """Meta kernel for full-reduction ``max``/``min``: returns a 0-dim tensor."""
    scalar_shape = ()
    return self.new_empty(scalar_shape)
@register_meta(aten.angle.default)
def meta_angle(self):
    """Meta kernel for ``angle``.

    The output has the input's shape. Its dtype is the corresponding real
    dtype for complex input, otherwise the INT_TO_FLOAT promotion result.
    """
    if not self.is_complex():
        _, out_dtype = elementwise_dtypes(
            self, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
        )
    else:
        out_dtype = corresponding_real_dtype(self.dtype)
    return self.new_empty(self.size(), dtype=out_dtype)
@register_meta(aten.angle.out)
def meta_angle_out(self, out):
    """Meta kernel for the ``out=`` variant of ``angle``.

    Resizes ``out`` to the input's shape and copies the ``torch.angle`` result
    into it.
    """
    torch._resize_output_(out, self.size(), self.device)
    result = torch.angle(self)
    return out.copy_(result)
def squareCheckInputs(self, f_name):
    """Assert that ``self`` is a batch of square matrices.

    Args:
        self: tensor to validate.
        f_name: operator name used as the prefix of the assertion messages.

    Raises:
        AssertionError: if the tensor has fewer than two dimensions or its
            last two dimensions differ.
    """
    ndim = self.dim()
    assert ndim >= 2, f"{f_name}: The input tensor must have at least 2 dimensions."
    rows = self.size(-2)
    cols = self.size(-1)
    assert (
        cols == rows
    ), f"{f_name}: A must be batches of square matrices, but they are {rows} by {cols} matrices"
def checkUplo(uplo: str):
    """Assert that ``uplo`` is a single character equal to 'U' or 'L' (any case).

    Args:
        uplo: the triangle selector to validate.

    Raises:
        AssertionError: if ``uplo`` is not exactly one of 'U', 'u', 'L', 'l'.
    """
    uplo_uppercase = uplo.upper()
    # Parenthesised explicitly: without the parentheses `and` binds tighter
    # than `or`, so the length test applied only to the "U" alternative.
    assert len(uplo) == 1 and (
        uplo_uppercase == "U" or uplo_uppercase == "L"
    ), f"Expected UPLO argument to be 'L' or 'U', but got {uplo}"
# @register_meta(aten.linalg_eigh.default)
def meta_linalg_eigh(self, uplo="L"):
    """Meta function for ``linalg_eigh`` (registration is currently commented out).

    Args:
        self: batch of square matrices, shape ``(*, n, n)``.
        uplo: 'L' or 'U' (any case), validated by ``checkUplo``.

    Returns:
        ``(values, vectors)``: ``values`` has shape ``(*, n)`` and the real
        counterpart of the input dtype; ``vectors`` has the input's shape and
        dtype.
    """
    squareCheckInputs(self, "linalg_eigh")
    checkUplo(uplo)
    real_dtype = toRealValueType(self.dtype)
    assert self.dim() >= 2
    # Eigenvalues: one real value per row of each matrix.
    values = self.new_empty(self.shape[:-1], dtype=real_dtype)
    # Eigenvectors: full matrices in the input dtype, with the last two
    # dimensions transposed in place (as the original code did for the
    # matrix-shaped output).
    vectors = self.new_empty(self.shape)
    vectors.transpose_(-2, -1)
    return (values, vectors)
@register_meta(aten.reflection_pad2d.default)
def meta_pad2d(self, padding):
    """Meta kernel for ``reflection_pad2d``.

    Args:
        self: 3D ``(C, H, W)`` or 4D ``(N, C, H, W)`` tensor; every dimension
            except the first must be non-zero.
        padding: four values ``(left, right, top, bottom)``.

    Returns:
        An empty tensor whose height grows by ``top + bottom`` and whose width
        grows by ``left + right``.

    Raises:
        RuntimeError: (via ``check``) when the input is not a valid 3D/4D tensor.
    """
    ndim = self.ndim
    # Test the rank before touching individual sizes so that 1D/2D input
    # produces the message below instead of an IndexError from size().
    check(
        ndim in (3, 4) and all(self.size(d) != 0 for d in range(1, ndim)),
        lambda: f"3D or 4D (batch mode) tensor expected for input, but got: {self}",
    )
    *leading, input_h, input_w = self.shape
    pad_l, pad_r, pad_t, pad_b = padding
    output_h = input_h + pad_t + pad_b
    output_w = input_w + pad_l + pad_r
    # `leading` is (C,) for 3D input and (N, C) for 4D input.
    return self.new_empty((*leading, output_h, output_w))
def dot_check(self, other):
    """Require both operands of a dot product to be one-dimensional."""
    both_vectors = self.dim() == 1 and other.dim() == 1

    def _message():
        return f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors"

    check(both_vectors, _message)
@register_meta(aten.dot.default)
def meta_dot(self, tensor):
    """Meta kernel for ``dot``: validates both operands and returns a 0-dim tensor."""
    dot_check(self, tensor)
    scalar_shape = ()
    return self.new_empty(scalar_shape)
@register_meta([aten.mm.default], register_dispatcher=False)
def meta_mm(a, b):
    """Meta kernel for ``mm``: ``(N, M) @ (M, P)`` yields an empty ``(N, P)`` tensor.

    Both operands must be 2D and share the reduction dimension.
    """
    check(a.dim() == 2, lambda: "a must be 2D")
    check(b.dim() == 2, lambda: "b must be 2D")
    rows, inner_a = a.shape
    inner_b, cols = b.shape
    check(inner_a == inner_b, lambda: "a and b must have same reduction dim")
    return a.new_empty(rows, cols)
def _compute_reduction_shape(self, dims, keepdim):
if keepdim:
return tuple(self.shape[i] if i not in dims else 1 for i in range(self.ndim))
return utils.compute_reduction_output_shape(self.shape, dims)
@register_meta(aten.bernoulli.out)
def meta_bernoulli(self, *, generator=None, out):
    """Meta kernel for ``bernoulli.out``: resizes ``out`` to the input's shape and returns it."""
    target_shape = self.size()
    torch._resize_output_(out, target_shape, self.device)
    return out
@register_meta(aten.convolution.default)
def meta_conv(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    is_transposed: bool,
    output_padding: List[int],
    groups: int,
):
    """Meta kernel for ``convolution``: computes the output shape and memory format.

    The spatial output sizes follow the Conv / ConvTranspose formulas linked
    below; the channel count is ``weight.shape[0]`` for a regular convolution
    and ``groups * weight.shape[1]`` for a transposed one. ``bias`` is not
    inspected.

    Raises:
        RuntimeError: for a regular convolution whose
            ``weight.shape[1] * groups`` differs from the input channel count.
    """

    def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
        """
        Output length of one spatial dimension of a regular convolution.
        See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
        Args:
            ln: length of the dimension
            p: padding in that dim
            d: dilation in that dim
            k: kernel size in that dim
            s: stride in that dim
        Returns:
            The output length
        """
        return (ln + 2 * p - d * (k - 1) - 1) // s + 1

    def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int:
        """
        Output length of one spatial dimension of a transposed convolution.
        See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
        Args:
            ln: length of the dimension
            p: padding in that dim
            d: dilation in that dim
            k: kernel size in that dim
            s: stride in that dim
            op: output padding in that dim
        Returns:
            The output length
        """
        return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1

    def _per_dim(value, rank):
        # Expand an int or a one-element sequence to `rank` entries; longer
        # sequences are passed through untouched.
        if isinstance(value, int):
            return [value] * rank
        if len(value) == 1:
            return [value[0]] * rank
        return value

    def calc_conv_nd_return_shape(
        dims: torch.Size,
        kernel_size: torch.Size,
        stride: Union[List[int], int],
        padding: Union[List[int], int],
        dilation: Union[List[int], int],
        output_padding: Optional[Union[List[int], int]] = None,
    ):
        rank = len(dims)
        stride = _per_dim(stride, rank)
        padding = _per_dim(padding, rank)
        dilation = _per_dim(dilation, rank)
        output_padding_list: Optional[List[int]] = None
        if output_padding:
            output_padding_list = _per_dim(output_padding, rank)
        # If output_padding is present, we are dealing with a transposed convolution
        if output_padding_list:
            return [
                _formula_transposed(
                    dims[i],
                    padding[i],
                    dilation[i],
                    kernel_size[i],
                    stride[i],
                    output_padding_list[i],
                )
                for i in range(rank)
            ]
        return [
            _formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])
            for i in range(rank)
        ]

    def is_channels_last(ten):
        return torch._prims_common.suggest_memory_format(ten) == torch.channels_last

    def pick_memory_format(device_hint):
        # On "cuda" a channels-last weight is enough to pick channels_last;
        # on any other device only the input is consulted.
        if is_channels_last(input_tensor):
            return torch.channels_last
        if device_hint == "cuda" and is_channels_last(weight):
            return torch.channels_last
        if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
            return torch.contiguous_format
        if input_tensor.is_contiguous(memory_format=torch.preserve_format):
            return torch.preserve_format
        # No branch matched: the result is None, as in the original code.
        return None

    kernel_size = weight.shape[2:]
    spatial_dims = input_tensor.shape[2:]
    if not is_transposed:
        out_channels = weight.shape[0]
        if weight.shape[1] * groups != input_tensor.shape[1]:
            raise RuntimeError("Invalid channel dimensions")
        shape_out = calc_conv_nd_return_shape(
            spatial_dims, kernel_size, stride, padding, dilation
        )
    else:
        out_channels = groups * weight.shape[1]
        shape_out = calc_conv_nd_return_shape(
            spatial_dims,
            kernel_size,
            stride,
            padding,
            dilation,
            output_padding,
        )

    out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))

    from torch._subclasses.fake_tensor import FakeTensor

    if isinstance(input_tensor, FakeTensor):
        device_hint = input_tensor.fake_device.type
    else:
        device_hint = "cuda"  # default to cuda

    mem_fmt = pick_memory_format(device_hint)
    out = out.to(memory_format=mem_fmt)  # type: ignore[call-overload]
    return out
# from check_dim_size() in aten/src/ATen/TensorUtils.cpp.
def check_dim_size(tensor, dim, dim_size, size):
    """Require ``tensor`` to have ``dim`` dimensions and ``size`` entries along ``dim_size``.

    Args:
        tensor: tensor to validate.
        dim: expected number of dimensions.
        dim_size: index of the dimension whose extent is checked.
        size: expected extent of that dimension.
    """

    def _message():
        expected = f"Expected a tensor of dimension {dim} and tensor.size[{dim_size}] == {size}, "
        actual = f"but got : dimension {tensor.dim()} and tensor.size[{dim_size}] = {tensor.shape[dim_size]}"
        return expected + actual

    matches = tensor.dim() == dim and tensor.shape[dim_size] == size
    check(matches, _message)
@register_meta(aten.avg_pool2d.default, register_dispatcher=False)
def meta_avg_pool2d(
    input,
    kernel_size,
    stride=(),
    padding=(0,),
    ceil_mode=False,
    count_include_pad=True,
    divisor_override=None,
):
    """Meta kernel for ``avg_pool2d``.

    Normalises kernel size, stride and padding to ``(H, W)`` pairs, computes
    the pooled height and width with dilation 1, validates everything through
    ``pool2d_shape_check`` and returns an empty tensor in the memory format
    suggested for the input. ``count_include_pad`` does not affect the shape.
    """

    def unpack(name, val):
        check(
            len(val) in [1, 2],
            lambda: f"avg_pool2d: {name} must either be a single int, or a tuple of two ints",
        )
        if len(val) == 1:
            # A single value applies to both height and width.
            return val[0], val[0]
        return val[0], val[1]

    kH, kW = unpack("kernel_size", kernel_size)

    check(
        len(stride) in [0, 1, 2],
        lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
    )
    # An omitted stride defaults to the kernel size.
    dH, dW = unpack("stride", stride) if len(stride) > 0 else (kH, kW)

    padH, padW = unpack("padding", padding)

    check(
        divisor_override is None or divisor_override != 0,
        lambda: "divisor must be not zero",
    )

    nbatch = input.size(-4) if input.dim() == 4 else 1
    nInputPlane = input.size(-3)
    inputHeight = input.size(-2)
    inputWidth = input.size(-1)

    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)

    memory_format = utils.suggest_memory_format(input)
    pool2d_shape_check(
        input,
        kH,
        kW,
        dH,
        dW,
        padH,
        padW,
        1,
        1,
        nInputPlane,
        inputHeight,
        inputWidth,
        outputHeight,
        outputWidth,
        memory_format,
    )

    size = [nInputPlane, outputHeight, outputWidth]
    if input.dim() != 3:
        size = [nbatch] + size
    return torch.empty(
        size, dtype=input.dtype, device=input.device, memory_format=memory_format
    )
# from avg_pool2d_backward_shape_check() in aten/src/ATen/native/Pool.h.
def avg_pool2d_backward_shape_check(
    input,
    gradOutput,
    nbatch,
    kH,
    kW,
    dH,
    dW,
    padH,
    padW,
    nInputPlane,
    inputHeight,
    inputWidth,
    outputHeight,
    outputWidth,
    mem_format,
):
    """Validate the arguments of the ``avg_pool2d`` backward pass.

    Runs the forward shape check with dilation 1, then requires ``gradOutput``
    to have the input's rank and trailing sizes
    ``(nInputPlane, outputHeight, outputWidth)``. ``nbatch`` is not used.
    """
    pool2d_shape_check(
        input,
        kH,
        kW,
        dH,
        dW,
        padH,
        padW,
        1,
        1,
        nInputPlane,
        inputHeight,
        inputWidth,
        outputHeight,
        outputWidth,
        mem_format,
    )

    ndim = input.dim()
    # The output has as many planes as the input.
    expected_trailing = ((3, nInputPlane), (2, outputHeight), (1, outputWidth))
    for offset_from_end, expected in expected_trailing:
        check_dim_size(gradOutput, ndim, ndim - offset_from_end, expected)
@register_meta(aten._adaptive_avg_pool2d.default)
def meta_adaptive_avg_pool2d(self, output_size):
    """Meta kernel for ``_adaptive_avg_pool2d``.

    Requires a 3D or 4D input and replaces its last two sizes with
    ``output_size``.
    """
    check(
        self.ndim in (3, 4),
        lambda: f"Expected 3D or 4D tensor, but got {self.shape}",
    )
    leading = self.shape[:-2]
    return self.new_empty(leading + tuple(output_size))
@register_meta(aten._adaptive_avg_pool3d.default)
def meta_adaptive_avg_pool3d(self, output_size):
    """Meta kernel for ``_adaptive_avg_pool3d``.

    Requires a 4D or 5D input and replaces its last three sizes with
    ``output_size``.
    """
    check(
        self.ndim in (4, 5),
        lambda: f"Expected 4D or 5D tensor, but got {self.shape}",
    )
    leading = self.shape[:-3]
    return self.new_empty(leading + tuple(output_size))
@register_meta(aten.repeat_interleave.Tensor)
def meta_repeat_interleave_Tensor(repeats, output_size=None):
    """Meta kernel for ``repeat_interleave.Tensor``.

    Returns an empty tensor of ``output_size`` elements.

    Raises:
        RuntimeError: if ``output_size`` is not given.
    """
    if output_size is not None:
        return repeats.new_empty(output_size)
    raise RuntimeError("cannot repeat_interleave a meta tensor without output_size")
@register_meta([aten.complex.default, aten.complex.out])
@out_wrapper()
def meta_complex(real, imag):
    """Meta kernel for ``complex``.

    Asserts that both parts are floating point and returns an empty tensor of
    the broadcast shape, in the complex dtype corresponding to ``real``.
    """
    for part in (real, imag):
        assert part.dtype.is_floating_point
    broadcast_shape = _broadcast_shapes(real.shape, imag.shape)
    complex_dtype = corresponding_complex_dtype(real.dtype)
    return real.new_empty(broadcast_shape, dtype=complex_dtype)
@register_meta(aten.vdot.default)
def vdot(self, other):
    """Meta kernel for ``vdot``.

    Real inputs are forwarded to ``torch.dot``. Complex inputs carrying the
    conjugate bit are rewritten in terms of ``dot``/``vdot`` on the
    de-conjugated operands. Otherwise both operands are checked to be 1D and
    a 0-dim tensor is returned.
    """
    # is_complex is a method: it must be called. Testing the bound method
    # itself is always truthy, which made this branch unreachable.
    if not self.is_complex():
        return torch.dot(self, other)

    if self.is_conj():
        if other.is_conj():
            return torch.vdot(other.conj(), self.conj())
        else:
            return torch.dot(self.conj(), other)
    elif other.is_conj():
        return torch.dot(self, other.conj()).conj()

    dot_check(self, other)
    return self.new_empty(())
# Leaving this function around because a python implementation
# of indexing shape inference is useful,
# but not registering it to the dispatcher because we already
# get shape inference through structured kernels
@register_meta(aten.index.Tensor, register_dispatcher=False)
def meta_index_Tensor(self, indices):
    """Shape inference for advanced indexing (``aten::index``).

    Args:
        self: tensor being indexed.
        indices: sequence of optional index tensors, one per leading dimension;
            ``None`` entries leave the dimension untouched.

    Returns:
        An empty tensor with the shape of the indexing result.

    Raises:
        RuntimeError / IndexError: (via ``check``) for invalid index dtypes,
            too many indices, or a mask whose shape does not match ``self``.
    """
    check(indices, lambda: "at least one index must be provided")
    # aten::index is the internal advanced indexing implementation
    # checkIndexTensorTypes and expandTensors
    result: List[Optional[Tensor]] = []
    for i, index in enumerate(indices):
        if index is not None:
            # "byte" in the message means uint8; it is accepted alongside the
            # int8 dtype the original list contained.
            check(
                index.dtype in [torch.long, torch.uint8, torch.int8, torch.bool],
                lambda: "tensors used as indices must be long, byte or bool tensors",
            )
            if index.dtype in [torch.uint8, torch.int8, torch.bool]:
                # A mask is expanded into one index tensor per mask dimension.
                nonzero = index.nonzero()
                k = len(result)
                check(
                    k + index.ndim <= self.ndim,
                    lambda: f"too many indices for tensor of dimension {self.ndim}",
                    IndexError,
                )
                for j in range(index.ndim):
                    check(
                        index.shape[j] == self.shape[k + j],
                        lambda: f"The shape of the mask {index.shape} at index {i} "
                        f"does not match the shape of the indexed tensor {self.shape} at index {k + j}",
                        IndexError,
                    )
                    result.append(nonzero.select(1, j))
            else:
                result.append(index)
        else:
            result.append(index)
    indices = result
    check(
        len(indices) <= self.ndim,
        lambda: f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})",
    )
    # expand_outplace
    import torch._refs as refs  # avoid import cycle in mypy

    indices = list(refs._maybe_broadcast(*indices))
    # add missing null tensors
    while len(indices) < self.ndim:
        indices.append(None)

    # hasContiguousSubspace
    #   true if all non-null tensors are adjacent
    # See:
    # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
    # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
    # state 0: before the first tensor, 1: inside the tensor run, 2: after it.
    state = 0
    has_contiguous_subspace = False
    for index in indices:
        if state == 0:
            if index is not None:
                state = 1
        elif state == 1:
            if index is None:
                state = 2
        else:
            if index is not None:
                break
    else:
        has_contiguous_subspace = True

    # transposeToFront
    # This is the logic that causes the newly inserted dimensions to show up
    # at the beginning of the tensor, if they're not contiguous
    if not has_contiguous_subspace:
        dims = []
        transposed_indices = []
        for i, index in enumerate(indices):
            if index is not None:
                dims.append(i)
                transposed_indices.append(index)
        for i, index in enumerate(indices):
            if index is None:
                dims.append(i)
                transposed_indices.append(index)
        self = self.permute(dims)
        indices = transposed_indices

    # AdvancedIndex::AdvancedIndex
    # Now we can assume the indices have contiguous subspace
    # This is simplified from AdvancedIndex which goes to more effort
    # to put the input and indices in a form so that TensorIterator can
    # take them.  If we write a ref for this, probably that logic should
    # get implemented
    before_shape: List[int] = []
    after_shape: List[int] = []
    replacement_shape: List[int] = []
    for dim, index in enumerate(indices):
        if index is None:
            if replacement_shape:
                after_shape.append(self.shape[dim])
            else:
                before_shape.append(self.shape[dim])
        else:
            replacement_shape = list(index.shape)
    return self.new_empty(before_shape + replacement_shape + after_shape)
@register_meta([aten.addbmm.default, aten.addbmm.out])
@out_wrapper()
def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1):
    """Meta kernel for ``addbmm``.

    Args:
        self: tensor expandable to ``(batch1.size(1), batch2.size(2))``.
        batch1: 3D tensor ``(B, N, M)``.
        batch2: 3D tensor ``(B, M, P)``.
        beta, alpha: scaling factors; they do not affect the output shape.

    Returns:
        An empty ``(N, P)`` tensor.

    Raises:
        RuntimeError: (via ``check``) when the batches are not 3D or their
            sizes are incompatible.
    """
    # Check the ranks before indexing sizes; otherwise size(1)/size(2) fails
    # first on low-rank input and these messages are never shown.
    check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
    check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
    dim1 = batch1.size(1)
    dim2 = batch2.size(2)
    self = self.expand((dim1, dim2))
    check(
        batch1.size(0) == batch2.size(0),
        lambda: f"batch1 and batch2 must have same number of batches, got {batch1.size(0)} and {batch2.size(0)}",
    )
    check(
        batch1.size(2) == batch2.size(1),
        lambda: (
            f"Incompatible matrix sizes for bmm ({batch1.size(1)}x{batch1.size(2)} "
            f"and {batch2.size(1)}x{batch2.size(2)})"
        ),
    )
    check(
        self.size(0) == dim1 and self.size(1) == dim2,
        lambda: "self tensor does not match matmul output shape",
    )
    return self.new_empty(self.size())
@register_meta(aten._cdist_forward.default)
def meta_cdist_forward(x1, x2, p, compute_mode):
    """Meta kernel for ``_cdist_forward``.

    Args:
        x1: floating-point tensor ``(*, r1, m)``.
        x2: floating-point tensor ``(*, r2, m)``.
        p: non-negative norm order.
        compute_mode: integer in ``[0, 2]``.

    Returns:
        An empty tensor of shape ``broadcast(batch dims) + (r1, r2)``.

    Raises:
        RuntimeError: (via ``check``) when any of the constraints above fails.
    """
    check(
        x1.dim() >= 2,
        lambda: f"cdist only supports at least 2D tensors, X1 got: {x1.dim()}D",
    )
    check(
        x2.dim() >= 2,
        lambda: f"cdist only supports at least 2D tensors, X2 got: {x2.dim()}D",
    )
    check(
        x1.size(-1) == x2.size(-1),
        lambda: f"X1 and X2 must have the same number of columns. X1: {x1.size(-1)} X2: {x2.size(-1)}",
    )
    # These two messages were missing the f prefix, so the dtype placeholder
    # was printed literally.
    check(
        utils.is_float_dtype(x1.dtype),
        lambda: f"cdist only supports floating-point dtypes, X1 got: {x1.dtype}",
    )
    check(
        utils.is_float_dtype(x2.dtype),
        lambda: f"cdist only supports floating-point dtypes, X2 got: {x2.dtype}",
    )
    check(p >= 0, lambda: "cdist only supports non-negative p values")
    check(
        compute_mode >= 0 and compute_mode <= 2,
        lambda: f"possible modes: 0, 1, 2, but was: {compute_mode}",
    )
    r1 = x1.size(-2)
    r2 = x2.size(-2)
    batch_tensor1 = x1.shape[:-2]
    batch_tensor2 = x2.shape[:-2]
    output_shape = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
    output_shape.extend([r1, r2])
    return x1.new_empty(output_shape)
@register_meta(aten._embedding_bag.default)
def meta_embedding_bag(
    weight,
    indices,
    offsets,
    scale_grad_by_freq=False,
    mode=0,
    sparse=False,
    per_sample_weights=None,
    include_last_offset=False,
    padding_idx=-1,
):
    """Meta kernel for ``_embedding_bag``.

    Validates dtypes (and ``per_sample_weights`` when given) and returns
    empty ``(output, offset2bag, bag_size, max_indices)`` tensors. The sizes
    of the three auxiliary tensors depend on whether ``offsets`` lives on the
    CPU, on ``mode`` and, on CPU, on whether the fast index-select path
    applies.
    """
    check(
        indices.dtype in (torch.long, torch.int),
        lambda: f"expected indices to be long or int, got {indices.dtype}",
    )
    check(
        offsets.dtype in (torch.long, torch.int),
        lambda: f"expected offsets to be long or int, got {offsets.dtype}",
    )
    check(
        utils.is_float_dtype(weight.dtype),
        lambda: f"expected weight to be floating point type, got {weight.dtype}",
    )

    num_bags = offsets.size(0)
    if include_last_offset:
        check(
            num_bags >= 1, lambda: "include_last_offset: numBags should be at least 1"
        )
        # The trailing offset only marks the end of the last bag.
        num_bags -= 1

    output = weight.new_empty(num_bags, weight.size(1))
    MODE_SUM, MODE_MEAN, MODE_MAX = 0, 1, 2

    if per_sample_weights is not None:
        check(
            mode == MODE_SUM,
            lambda: "embedding_bag: per_sample_weights only supported with mode='sum'",
        )
        check(
            per_sample_weights.dtype == weight.dtype,
            lambda: f"expected weight ({weight.dtype}) and per_sample_weights ({per_sample_weights.dtype}) to have same dtype",
        )
        check(
            per_sample_weights.ndim == 1,
            lambda: f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D",
        )
        check(
            per_sample_weights.numel() == indices.numel(),
            lambda: (
                f"expected per_sample_weights.numel() ({per_sample_weights.numel()} "
                f"to be the same as indices.numel() ({indices.numel()})"
            ),
        )

    def is_fast_path(src, scale, output, padding_idx):
        # float/half weights with unit inner stride on both source and output,
        # and no padding index; a scale tensor must have unit stride as well.
        select_ok = (
            src.dtype in (torch.float, torch.half)
            and src.stride(1) == 1
            and output.stride(1) == 1
            and padding_idx < 0
        )
        if scale is None:
            return select_ok
        return select_ok and scale.stride(0) == 1

    if offsets.device.type == "cpu":
        fast_path_sum = is_fast_path(weight, per_sample_weights, output, padding_idx)
        needs_offset2bag = mode in (MODE_MEAN, MODE_MAX) or not fast_path_sum
        if needs_offset2bag:
            offset2bag = offsets.new_empty(indices.size(0))
        else:
            offset2bag = offsets.new_empty(0)
        bag_size = offsets.new_empty(num_bags)
        max_indices = offsets.new_empty(bag_size.size())
    else:
        offset2bag = indices.new_empty(indices.size(0))
        bag_size = indices.new_empty(offsets.size())
        if mode == MODE_MAX:
            max_indices = indices.new_empty(num_bags, weight.size(1))
        else:
            max_indices = indices.new_empty(0)
    return output, offset2bag, bag_size, max_indices
@register_meta([aten.diag.default, aten.diag.out])
@out_wrapper()
def meta_diag(self, dim=0):
    """Meta kernel for ``diag``.

    A 1D input of length ``n`` yields a square matrix of side
    ``n + abs(dim)``; a 2D input yields a vector holding the diagonal with
    offset ``dim``.
    """
    check(self.dim() in (1, 2), lambda: "matrix or a vector expected")
    if self.dim() == 1:
        side = self.size(0) + abs(dim)
        return self.new_empty((side, side))

    # 2D input: the diagonal length depends on the sign of the offset.
    rows, cols = self.size(0), self.size(1)
    if dim < 0:
        length = min(rows + dim, cols)
    else:
        length = min(rows, cols - dim)
    return self.new_empty((length,))
@register_meta(aten._embedding_bag_forward_only.default)
def meta_embedding_bag_forward_only(weight, indices, offsets, *args):
    """Meta kernel for ``_embedding_bag_forward_only``.

    Delegates to ``meta_embedding_bag``; when ``offsets`` is on the CPU,
    ``bag_size`` is replaced by an empty tensor with the size of ``offsets``.
    """
    output, offset2bag, bag_size, max_indices = meta_embedding_bag(
        weight, indices, offsets, *args
    )
    on_cpu = offsets.device.type == "cpu"
    if on_cpu:
        bag_size = offsets.new_empty(offsets.size())
    return output, offset2bag, bag_size, max_indices
def _get_reduction_dtype(input, dtype, promote_int_to_long=True):
# if specified, dtype takes precedence
if dtype:
return dtype
if input.dtype.is_floating_point or input.dtype.is_complex:
return input.dtype
elif promote_int_to_long:
return torch.long
return input.dtype
@register_meta([aten.nansum.default, aten.nansum.out])
@out_wrapper()
def meta_nansum(input, dims=None, keepdim=False, *, dtype=None):
    """Meta kernel for ``nansum``: an empty tensor with the reduced shape and dtype."""
    result_dtype = _get_reduction_dtype(input, dtype, promote_int_to_long=True)
    reduced_dims = utils.reduction_dims(input.shape, dims)
    result_shape = _compute_reduction_shape(input, reduced_dims, keepdim)
    return input.new_empty(result_shape, dtype=result_dtype)
@register_meta(aten.nanmedian.default)
def meta_nanmedian(input):
    """Meta kernel for full-reduction ``nanmedian``: reduces over every dimension."""
    all_dims = tuple(range(input.dim()))
    result_shape = utils.compute_reduction_output_shape(input.shape, all_dims)
    return input.new_empty(result_shape)
@register_meta([aten.nanmedian.dim, aten.nanmedian.dim_values])
@out_wrapper("values", "indices")
def meta_nanmedian_dim(input, dim=-1, keepdim=False):
    """Meta kernel for ``nanmedian`` along one dimension.

    Returns ``(values, indices)``: both have the reduced shape, and
    ``indices`` has dtype ``torch.long``.
    """
    reduced_dims = utils.reduction_dims(input.shape, (dim,))
    result_shape = _compute_reduction_shape(input, reduced_dims, keepdim)
    values = input.new_empty(result_shape)
    indices = input.new_empty(result_shape, dtype=torch.long)
    return (values, indices)
@register_meta(aten.logical_not_.default)
def meta_logical_not_(self):
    """Meta kernel for in-place ``logical_not_``: returns ``self`` untouched."""
    return self
@register_meta(aten.repeat.default)
def meta_repeat(self, repeats):
    """Meta kernel for ``repeat``.

    ``repeats`` must have at least as many entries as the tensor has
    dimensions; each output size is the (left-padded) input size times its
    repeat count.
    """
    check(
        len(repeats) >= self.dim(),
        lambda: "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor",
    )
    # Pad the shape with leading 1s so that it lines up with `repeats`.
    extra_dims = len(repeats) - self.dim()
    padded_shape = (1,) * extra_dims + tuple(self.shape)
    result_shape = [padded_shape[axis] * repeats[axis] for axis in range(len(repeats))]
    return self.new_empty(result_shape)
@register_meta(aten.zero_.default, register_dispatcher=False)
def meta_zero_(self):
    """Meta kernel for in-place ``zero_``: returns ``self`` untouched."""
    return self
@register_meta(
    [aten.fill.Tensor, aten.fill.Scalar, aten.fill_.Tensor, aten.fill_.Scalar],
    register_dispatcher=False,
)
def meta_fill_(self, val):
    """Meta kernel shared by the ``fill`` and ``fill_`` overloads: returns ``self``.

    NOTE(review): the non-in-place ``fill`` overloads also get ``self`` back
    rather than a new tensor; confirm that this aliasing is intended.
    """
    return self
@register_meta(aten.relu_.default, register_dispatcher=False)
def meta_relu_(self):
    """Meta kernel for in-place ``relu_``: returns ``self`` untouched."""
    return self
@register_meta(aten.index_put.default, register_dispatcher=False)
def meta_index_put(self, indices, values, accumulate=False):
    """Meta kernel for ``index_put``: a new empty tensor with the shape of ``self``."""
    result_shape = self.size()
    return self.new_empty(result_shape)
@register_meta(aten.masked_fill_.Scalar, register_dispatcher=False)
def meta_masked_fill_(self, mask, value):
    """Meta kernel for in-place ``masked_fill_``: returns ``self`` untouched."""
    return self
@register_meta(aten.index_put_.default, register_dispatcher=False)
def meta_index_put_(self, indices, values, accumulate=False):
    """Meta kernel for in-place ``index_put_``: returns ``self`` untouched."""
    return self
@register_meta(aten.alias.default, register_dispatcher=False)
def meta_alias(self):
    """Meta kernel for ``alias``: a view of ``self`` with the same shape."""
    same_shape = self.shape
    return self.view(same_shape)
def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None):
    """Shared shape inference for ``bmm`` and ``baddbmm``.

    Args:
        batch1: 3D tensor ``(B, N, M)``.
        batch2: 3D tensor ``(B, M, P)``.
        is_bmm: true for plain ``bmm``; when false and ``self_baddbmm`` is
            given, that tensor is validated against the output shape.
        self_baddbmm: optional ``baddbmm`` input tensor.

    Returns:
        An empty ``(B, N, P)`` tensor created from ``batch2``.

    Raises:
        RuntimeError: (via ``check``) on rank or size mismatches.
    """
    check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
    check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")

    batch1_sizes = batch1.size()
    batch2_sizes = batch2.size()

    bs = batch1_sizes[0]
    contraction_size = batch1_sizes[2]
    res_rows = batch1_sizes[1]
    res_cols = batch2_sizes[2]
    output_size = (bs, res_rows, res_cols)

    check(
        batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size,
        lambda: f"Expected size for first two dimensions of batch2 tensor to be: [{bs}"
        f", {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}].",
    )

    # TODO: handle out

    output = batch2.new_empty(output_size)

    if not is_bmm and self_baddbmm is not None:
        check(self_baddbmm.dim() == 3, lambda: "self must be a 3D tensor")
        # The message is now an f-string and reports self_baddbmm's shape; the
        # old text printed its placeholders literally and named an undefined
        # `self`.
        check(
            self_baddbmm.size() == output_size,
            lambda: f"Expected an input tensor shape with shape {output_size} but got shape: {self_baddbmm.size()}",
        )

    return output
@register_meta(aten.bmm.default, register_dispatcher=False)
def meta_bmm(self, mat2):
    """Meta kernel for ``bmm``: delegates to the shared bmm/baddbmm shape check."""
    is_bmm = True
    return common_meta_baddbmm_bmm(self, mat2, is_bmm)
def div_rtn(x, y):
    """Divide ``x`` by ``y``, rounding the quotient toward negative infinity.

    Starts from ``x // y`` and subtracts one when the remainder is non-zero
    and its sign differs from the sign of ``y``.
    """
    quotient = x // y
    remainder = x % y
    if remainder == 0:
        return quotient
    # WARNING: explicit bool conversion here is necessary;
    # would be fixed by SymBool
    if bool(remainder < 0) != bool(y < 0):
        quotient -= 1
    return quotient
def pooling_output_shape_pad_lr(
    inputSize, kernelSize, pad_l, pad_r, stride, dilation, ceil_mode
):
    """Output length of one pooled dimension with separate left/right padding.

    In ``ceil_mode`` the division rounds up, and the last window is dropped
    if it would start at or beyond ``inputSize + pad_l``.
    """
    # Adding (stride - 1) before the floor division rounds the result up.
    rounding = stride - 1 if ceil_mode else 0
    span = inputSize + pad_l + pad_r - dilation * (kernelSize - 1) - 1 + rounding
    outputSize = div_rtn(span, stride) + 1
    if ceil_mode and (outputSize - 1) * stride >= inputSize + pad_l:
        outputSize -= 1
    return outputSize
def pooling_output_shape(inputSize, kernelSize, pad, stride, dilation, ceil_mode):
    """Output length of one pooled dimension with symmetric padding.

    Requires a non-zero stride and a padding between 0 and half the kernel
    size, then delegates to ``pooling_output_shape_pad_lr`` with the same
    padding on both sides.
    """
    check(stride != 0, lambda: "stride should not be zero")
    check(pad >= 0, lambda: f"pad must be non-negative, but got pad: {pad}")
    half_kernel = kernelSize // 2
    check(
        pad <= half_kernel,
        lambda: f"pad should be at most half of kernel size, but got pad={pad} and kernel_size={kernelSize}",
    )
    pad_left = pad_right = pad
    return pooling_output_shape_pad_lr(
        inputSize, kernelSize, pad_left, pad_right, stride, dilation, ceil_mode
    )
def pool2d_shape_check(
    input,
    kH,
    kW,
    dH,
    dW,
    padH,
    padW,
    dilationH,
    dilationW,
    nInputPlane,
    inputHeight,
    inputWidth,
    outputHeight,
    outputWidth,
    memory_format,
):
    """Validate the arguments of a 2D pooling operation.

    Checks that kernel size, stride and dilation are positive, that the input
    is a non-empty 3D/4D tensor (4D only for ``channels_last``), that the
    padding is at most half the kernel size and that the computed output is
    at least 1x1.

    Raises:
        RuntimeError: (via ``check``) when any constraint is violated.
    """
    ndim = input.dim()
    nOutputPlane = nInputPlane

    # The three messages below and the channels_last one were missing the f
    # prefix, so their placeholders were printed literally.
    check(
        kW > 0 and kH > 0,
        lambda: f"kernel size should be greater than zero, but got kH: {kH}, kW: {kW}",
    )
    check(
        dW > 0 and dH > 0,
        lambda: f"stride should be greater than zero, but got dH: {dH}, dW: {dW}",
    )
    check(
        dilationH > 0 and dilationW > 0,
        lambda: f"dilation should be greater than zero, but got dilationH: {dilationH}, dilationW: {dilationW}",
    )

    valid_dims = input.size(1) != 0 and input.size(2) != 0

    if memory_format == torch.channels_last:
        check(
            ndim == 4 and valid_dims and input.size(3) != 0,
            lambda: "Expected 4D (batch mode) tensor expected for input with channels_last layout"
            f" with optional 0 dim batch size for input, but got: {input.size()}",
        )
    else:
        check(
            (ndim == 3 and input.size(0) != 0 and valid_dims)
            or (ndim == 4 and valid_dims and input.size(3) != 0),
            lambda: f"Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got: {input.size()}",
        )

    check(
        kW // 2 >= padW and kH // 2 >= padH,
        lambda: "pad should be smaller than or equal to half of kernel size, but got "
        f"padW = {padW}, padH = {padH}, kW = {kW}, kH = {kH}",
    )

    check(
        outputWidth >= 1 and outputHeight >= 1,
        lambda: f"Given input size: ({nInputPlane}x{inputHeight}x{inputWidth}). "
        f"Calculated output size: ({nOutputPlane}x{outputHeight}x{outputWidth}). "
        "Output size is too small",
    )
@register_meta(aten.max_pool2d_with_indices.default, register_dispatcher=False)
def meta_max_pool2d_with_indices(
    input, kernel_size, stride=(), padding=(0,), dilation=(1,), ceil_mode=False
):
    """Meta implementation of ``max_pool2d_with_indices``.

    Validates the pooling arguments, computes the pooled height and width,
    and returns two uninitialized tensors of the resulting size: the first
    with the dtype of ``input``, the second with dtype ``torch.int64``. Both
    use the device of ``input`` and the memory format suggested for it.

    Reference: aten/src/ATen/native/DilatedMaxPool2d.cpp
    """

    def as_pair(name, val):
        # A one-element sequence supplies the same value for height and width.
        count = len(val)
        check(
            count in [1, 2],
            lambda: f"max_pool2d: {name} must either be a single int, or a tuple of two ints",
        )
        first = val[0]
        if count == 1:
            return first, first
        return first, val[1]

    kH, kW = as_pair("kernel_size", kernel_size)

    stride_len = len(stride)
    check(
        stride_len in [0, 1, 2],
        lambda: "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
    )
    # An omitted stride falls back to the kernel size.
    dH, dW = (kH, kW) if stride_len == 0 else as_pair("stride", stride)

    padH, padW = as_pair("padding", padding)
    dilationH, dilationW = as_pair("dilation", dilation)

    memory_format = utils.suggest_memory_format(input)
    ndim = input.dim()
    if memory_format == torch.channels_last:
        check(
            ndim == 4,
            lambda: "non-empty 4D (batch mode) tensor expected for input with channels_last layout",
        )
    elif memory_format == torch.contiguous_format:
        check(
            ndim in [3, 4],
            lambda: "non-empty 3D or 4D (batch mode) tensor expected for input",
        )
    else:
        check(
            False,
            lambda: "Unsupport memory format. Supports only ChannelsLast, Contiguous",
        )

    # Sizes are read from the trailing dimensions so 3D and 4D inputs share code.
    nbatch = input.size(-4) if ndim == 4 else 1
    nInputPlane = input.size(-3)
    inputHeight = input.size(-2)
    inputWidth = input.size(-1)

    outputHeight = pooling_output_shape(
        inputHeight, kH, padH, dH, dilationH, ceil_mode
    )
    outputWidth = pooling_output_shape(
        inputWidth, kW, padW, dW, dilationW, ceil_mode
    )

    pool2d_shape_check(
        input,
        kH, kW,
        dH, dW,
        padH, padW,
        dilationH, dilationW,
        nInputPlane,
        inputHeight, inputWidth,
        outputHeight, outputWidth,
        memory_format,
    )

    # Only a 3D input omits the leading batch dimension.
    size = [nInputPlane, outputHeight, outputWidth]
    if ndim != 3:
        size.insert(0, nbatch)

    pooled = torch.empty(
        size, dtype=input.dtype, device=input.device, memory_format=memory_format
    )
    indices = torch.empty(
        size, dtype=torch.int64, device=input.device, memory_format=memory_format
    )
    return pooled, indices
@register_meta([aten.full.default])
def full(size, fill_value, *args, **kwargs):
    """Meta implementation of ``aten.full``.

    ``fill_value`` is not used: the result is created by ``torch.empty`` from
    ``size`` plus any remaining positional and keyword arguments.
    """
    result = torch.empty(size, *args, **kwargs)
    return result
@register_meta(
    [
        aten.randint_like.default,
        aten.randint_like.low_dtype,
        aten.randn_like.default,
        aten.rand_like.default,
        aten.full_like.default,
        aten.zeros_like.default,
        aten.ones_like.default,
    ]
)
def meta_like(self, *args, **kwargs):
    """Shared meta implementation for the registered ``*_like`` overloads.

    Returns ``aten.empty_like.default(self, **kwargs)``. Extra positional
    arguments are accepted but not forwarded.
    """
    empty_like = aten.empty_like.default
    return empty_like(self, **kwargs)
# hacky: Please remove after math.ceil works with arange
@register_meta(aten.arange.default)
def arange(end, **kwargs):
    """Meta implementation of ``aten.arange.default``.

    Returns ``aten.empty([end], **kwargs)``, where a float ``end`` is first
    rounded up with ``math.ceil``. When no ``dtype`` is supplied (missing or
    ``None``) and the caller's ``end`` is an ``int``/``bool``, ``dtype`` is
    set to ``torch.int64``; for a float ``end`` the dtype is left to the
    callee's default.
    """
    # Classify the caller's original ``end`` BEFORE rounding: math.ceil()
    # returns an int, so testing afterwards would treat every float ``end``
    # as integral and wrongly force an int64 result.
    # bool is a subclass of int, so a single isinstance check covers both.
    end_is_integral = isinstance(end, int)

    if isinstance(end, float):
        end = math.ceil(end)

    if kwargs.get("dtype", None) is None and end_is_integral:
        kwargs["dtype"] = torch.int64

    return aten.empty([end], **kwargs)
@register_meta(aten.arange.start)
def arange_start(start, end, **kwargs):
    """Meta implementation of ``aten.arange.start``.

    Forwards to ``aten.arange`` with ``end - start`` as its single positional
    argument, passing the keyword arguments through unchanged.
    """
    span = end - start
    return aten.arange(span, **kwargs)
# We must also trigger meta registrations from PrimTorch ref
# decompositions
import torch._refs
import torch._refs.nn.functional
import torch._refs.special
|