# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import inspect
import warnings
from typing import Any, Callable, cast, Optional, Sequence, Tuple
import torch
import torch.distributed.tensor._dispatch as op_dispatch
import torch.distributed.tensor._random as random
import torch.nn as nn
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed.tensor._collective_utils import check_tensor_meta, mesh_broadcast
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._random import (
is_rng_supported_mesh,
OffsetBasedRNGTracker,
)
from torch.distributed.tensor._redistribute import (
Redistribute,
redistribute_local_tensor,
)
from torch.distributed.tensor._utils import (
compute_global_tensor_info,
compute_local_shape_and_global_offset,
normalize_to_torch_size,
)
from torch.distributed.tensor.placement_types import (
Partial,
Placement,
Replicate,
Shard,
)
__all__ = [
"DTensor",
"distribute_tensor",
"distribute_module",
"ones",
"empty",
"full",
"rand",
"randn",
"zeros",
]
aten = torch.ops.aten
# NOTE [Autograd interaction between torch.Tensor]
#
# The autograd functions defined below are being used by the public
# facing APIs (i.e. from_local, to_local) to ensure DTensor to work
# together with torch.Tensor within the autograd engine. This
# allows DTensor to only exist on part of the module hierarchy.
#
# As an example, suppose we have a module that consists of submodules
# A, B, and C, the execution flow would be like:
# input(torch.Tensor) -> Module A -> Module B -> Module C -> output (torch.Tensor)
#
# Suppose I only want to make Module B be a sharded module with
# DTensor params, the following forward/backward should work:
#
# input(torch.Tensor) -> Module A
# -> DTensor input (from_local) -> Sharded Module B -> DTensor output
# -> torch.Tensor output (to_local) -> Module C
#
# So from_local/to_local must be Autograd functions.
#
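# A minimal sketch of that pattern (purely illustrative; `mesh`, `module_a`,
# `sharded_module_b`, and `module_c` are assumed to exist and are not defined here):
#
#   x = module_a(inp)                                # plain torch.Tensor
#   dt_in = DTensor.from_local(x, mesh, [Shard(0)])  # enter the DTensor region
#   dt_out = sharded_module_b(dt_in)                 # sharded computation on DTensor
#   y = dt_out.to_local()                            # back to torch.Tensor
#   out = module_c(y)
#   out.sum().backward()                             # gradients flow through both boundaries
#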
class _ToTorchTensor(torch.autograd.Function):
@staticmethod
def forward( # type: ignore[override]
ctx,
input: "DTensor",
grad_placements: Optional[Sequence[Placement]],
):
ctx.dtensor_spec = input._spec
ctx.grad_placements = grad_placements
local_tensor = input._local_tensor
# We need to return a fresh Tensor object here as autograd metadata
# will be written into it in place. We don't want to pollute the Tensor
# object stored in the _local_tensor of this DTensor.
return local_tensor.view_as(local_tensor)
@staticmethod
def backward(ctx, grad_output: torch.Tensor): # type: ignore[override]
dtensor_spec = ctx.dtensor_spec
mesh = dtensor_spec.mesh
grad_placements = ctx.grad_placements
dtensor_meta = dtensor_spec.tensor_meta
_, tensor_stride = compute_global_tensor_info(
grad_output, mesh, dtensor_spec.placements
)
tensor_stride = tuple(tensor_stride)
grad_placements = grad_placements or dtensor_spec.placements
grad_spec = DTensorSpec(
mesh,
grad_placements,
tensor_meta=TensorMeta(
shape=dtensor_meta.shape,
stride=tensor_stride,
dtype=dtensor_meta.dtype,
),
)
return (
DTensor(
grad_output,
grad_spec,
requires_grad=grad_output.requires_grad,
),
None,
)
class _FromTorchTensor(torch.autograd.Function):
@staticmethod
def forward( # type: ignore[override]
ctx, # pyre-ignore[2]: Parameter must be annotated.
input: torch.Tensor,
device_mesh: DeviceMesh,
placements: Tuple[Placement, ...],
run_check: bool,
shape: Optional[torch.Size] = None,
stride: Optional[Tuple[int, ...]] = None,
) -> "DTensor":
ctx.previous_placement = placements
ctx.previous_device_mesh = device_mesh
if shape and stride:
tensor_shape, tensor_stride = shape, stride
elif not shape and not stride:
# if shape/stride are not passed, we assume the user is certain that each
# rank has a local tensor of the same shape, and we use that to calculate the
# global shape/stride
global_shape, global_stride = compute_global_tensor_info(
input, device_mesh, placements
)
tensor_shape, tensor_stride = torch.Size(global_shape), tuple(global_stride)
else:
raise RuntimeError(
f"Found shape:{shape}, stride:{stride}.",
"Please pass both shape and stride at the same time.",
)
if device_mesh.get_coordinate() is None:
# if the global rank is not participating in the device mesh, we
# simply set the local tensor to an empty tensor
input = input.new_empty(0, requires_grad=input.requires_grad)
elif run_check:
# TODO: support uneven sharding when global shape/stride not passed, by
# building the global TensorMeta during check_tensor_meta
check_shape_stride = not shape and not stride
check_tensor_meta(input, check_shape_stride=check_shape_stride)
# TODO: See if we need to make this run_check logic
# have a corresponding backward.
for idx, placement in enumerate(placements):
if placement.is_replicate():
# broadcast rank 0 tensor to all ranks
# only broadcast if run_check is True
input = input.contiguous()
mesh_broadcast(input, device_mesh, mesh_dim=idx)
dist_spec = DTensorSpec(
device_mesh,
placements,
tensor_meta=TensorMeta(
tensor_shape,
tensor_stride,
input.dtype,
),
)
# We want a fresh Tensor object that shares memory with the input tensor
dist_tensor = DTensor(
input.view_as(input),
dist_spec,
# requires_grad of the dist tensor depends on if input
# requires_grad or not
requires_grad=input.requires_grad,
)
return dist_tensor
@staticmethod
def backward(ctx, grad_output: "DTensor"): # type: ignore[override]
previous_placement = ctx.previous_placement
previous_device_mesh = ctx.previous_device_mesh
# reshard to the placement when creating DistributedTensor
# so that the gradient layout matches, and we could return
# local gradients directly
if grad_output.placements != previous_placement:
current_spec = grad_output._spec
target_spec = DTensorSpec(
previous_device_mesh,
previous_placement,
tensor_meta=grad_output._spec.tensor_meta,
)
local_tensor = grad_output._local_tensor
output = redistribute_local_tensor(
local_tensor, current_spec, target_spec, is_backward=True
)
# TODO: return the redistributed local tensor directly without
# differentiable backward. See if this makes sense for all cases.
return output, None, None, None, None, None
# TODO: backward is also differentiable now, add a test
# to test higher level gradients.
return grad_output.to_local(), None, None, None, None, None
class DTensor(torch.Tensor):
"""
``DTensor`` (Distributed Tensor) is a subclass of ``torch.Tensor`` that provides single-device like
abstraction to program with multi-device ``torch.Tensor``. It describes the distributed tensor sharding
layout (DTensor Layout) through the :class:`DeviceMesh` and following types of :class:`Placement`:
* :class:`Shard`: Tensor sharded on the tensor dimension ``dim`` on the devices of the ``DeviceMesh`` dimension
* :class:`Replicate`: Tensor replicated on the devices of the ``DeviceMesh`` dimension
* :class:`Partial`: Tensor is pending reduction on the devices of the ``DeviceMesh`` dimension
When calling PyTorch operators, ``DTensor`` overrides the PyTorch operators to perform sharded computation and issue
communications whenever necessary. Along with the operator computation, ``DTensor`` will transform or propagate the
placements (DTensor Layout) properly (based on the operator semantic itself) and generate new ``DTensor`` outputs.
To ensure numerical correctness of the ``DTensor`` sharded computation when calling PyTorch operators, ``DTensor``
requires every Tensor argument of the operator to be a DTensor.
.. note:: Directly using the Tensor subclass constructor here is not the recommended way to create a ``DTensor``
(i.e. it does not handle autograd correctly hence is not the public API). Please refer to the `create_dtensor`_
section to see how to create a ``DTensor``.
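A minimal usage sketch (assumes a default process group of ``world_size`` ranks is
already initialized; see ``distribute_tensor`` and ``DTensor.from_local`` below):
>>> # xdoctest: +SKIP("need world_size and rank")
>>> from torch.distributed.device_mesh import init_device_mesh
>>> mesh = init_device_mesh("cuda", (world_size,))
>>> global_tensor = torch.randn(world_size * 4, 8)
>>> dtensor = distribute_tensor(global_tensor, mesh, [Shard(0)])  # shard dim 0 across the mesh
>>> dtensor.placements, dtensor.to_local().shape  # (Shard(dim=0),), torch.Size([4, 8])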
"""
_local_tensor: torch.Tensor
_spec: DTensorSpec
__slots__ = ["_local_tensor", "_spec"]
# _op_dispatcher instance as a class attribute to handle runtime dispatching logic
_op_dispatcher: op_dispatch.OpDispatcher = op_dispatch.OpDispatcher()
@staticmethod
@torch._disable_dynamo
def __new__(
cls,
local_tensor: torch.Tensor,
spec: DTensorSpec,
*,
requires_grad: bool,
) -> "DTensor":
"""
Construct a DTensor from a local tensor, a device mesh and placements, and
other tensor properties (i.e. shape, requires_grad, strides, etc.).
.. note:: This is not a public API and it's only supposed to be used by the
operator implementations and internals. If you want to construct a
DTensor from a local tensor, consider using ``DTensor.from_local``, if
you want to construct a DTensor from a "global" tensor (where you
already have a tensor initialized and want to shard it),
consider using ``distribute_tensor``.
"""
if local_tensor.requires_grad and not requires_grad:
warnings.warn(
"To construct DTensor from torch.Tensor, it's recommended to "
"use local_tensor.detach() and make requires_grad consistent."
)
# __new__ constructs the wrapper tensor from local_tensor and attaches the
# placement spec; it does not do the actual distribution
assert spec.tensor_meta is not None, "TensorMeta should not be None!"
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
cls,
spec.tensor_meta.shape,
strides=spec.tensor_meta.stride,
dtype=local_tensor.dtype,
device=local_tensor.device,
layout=local_tensor.layout,
requires_grad=requires_grad,
)
r._spec = spec
r._local_tensor = local_tensor
return r
# pyre-fixme[14]: `__repr__` overrides method defined in `DTensor` inconsistently.
# pyre-fixme[3]: Return type must be annotated.
def __repr__(self): # type: ignore[override]
# TODO: consider all_gather the local tensors for better debugging
return f"DTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})"
def __tensor_flatten__(self):
"""
protocol to inform how to flatten a DTensor to local tensor
for PT2 tracing
"""
return ["_local_tensor"], (self._spec, self.requires_grad)
@staticmethod
def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
assert (
flatten_spec is not None
), "Expecting spec to be not None from `__tensor_flatten__` return value!"
local_tensor = inner_tensors["_local_tensor"]
spec, requires_grad = flatten_spec
unflatten_tensor_meta = TensorMeta(
shape=outer_size,
stride=outer_stride,
dtype=spec.tensor_meta.dtype,
)
unflatten_spec = DTensorSpec(
spec.mesh,
spec.placements,
tensor_meta=unflatten_tensor_meta,
)
return DTensor(
local_tensor,
unflatten_spec,
requires_grad=requires_grad,
)
def __coerce_tangent_metadata__(self):
if not any(isinstance(p, Partial) for p in self.placements):
return self
placements = [
Replicate() if isinstance(p, Partial) else p for p in self.placements
]
return self.redistribute(device_mesh=self.device_mesh, placements=placements)
def __coerce_same_metadata_as_tangent__(self, flatten_spec, expected_type=None):
if expected_type is not None:
return None
(spec, _) = flatten_spec # Result of tensor_flatten()
return self.redistribute(
device_mesh=self.device_mesh,
placements=spec.placements,
)
@classmethod
@torch._disable_dynamo
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
return DTensor._op_dispatcher.dispatch(
func,
args,
kwargs or {},
)
@staticmethod
def from_local(
local_tensor: torch.Tensor,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
*,
run_check: bool = False,
shape: Optional[torch.Size] = None,
stride: Optional[Tuple[int, ...]] = None,
) -> "DTensor":
"""
Create a :class:`DTensor` from a local torch.Tensor on each rank
according to the ``device_mesh`` and ``placements`` specified.
Args:
local_tensor (torch.Tensor): local torch.Tensor on each rank.
device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the
tensor, if not specified, must be called under a DeviceMesh
context manager, default: None
placements (List[:class:`Placement`], optional): the placements that
describes how to place the local torch.Tensor on DeviceMesh, must
have the same number of elements as ``device_mesh.ndim``.
Keyword args:
run_check (bool, optional): at a cost of extra communications, perform
sanity check across ranks to check each local tensor's meta information
to ensure correctness. If ``placements`` contains :class:`Replicate`, the
data on the first rank of the device mesh dimension will be broadcast
to other ranks. default: False
shape (torch.Size, optional): A List of int which specifies the size of
the DTensor built on top of ``local_tensor``. Note this needs to be
provided if the shape of ``local_tensor`` is different across the ranks.
If not provided, ``shape`` will be computed assuming the given distributed
tensor is evenly sharded across ranks. default: None
stride (tuple, optional): A List of int which specifies the stride of DTensor.
If not provided, ``stride`` will be computed assuming the given distributed
tensor is evenly sharded across ranks. default: None
Returns:
A :class:`DTensor` object
.. note:: When ``run_check=False``, it is the user's responsibility to ensure the
local tensor passed in is correct across ranks (i.e. the tensor is sharded for
the ``Shard(dim)`` placement or replicated for the ``Replicate()`` placement).
If not, the behavior of the created DTensor is undefined.
.. note:: ``from_local`` is differentiable; the ``requires_grad`` of the created
``DTensor`` object will depend on whether ``local_tensor`` requires grad or not.
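Example (a minimal sketch; assumes an initialized default process group and that each
rank holds an even shard of the global tensor on dim 0):
>>> # xdoctest: +SKIP("need world_size and rank")
>>> from torch.distributed.device_mesh import init_device_mesh
>>> mesh = init_device_mesh("cuda", (world_size,))
>>> local_shard = torch.randn(4, 8)  # this rank's piece of a (4 * world_size, 8) tensor
>>> dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)])
>>> dtensor.shape  # global shape, computed assuming even sharding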
"""
# if same shape/dtype, no need to run_check, if not, must allgather
# the metadatas to check the size/dtype across ranks
# There should be no data communication unless there's replication
# strategy, where we broadcast the replication from the first rank
# in the mesh dimension
device_mesh = device_mesh or _mesh_resources.get_current_mesh()
device_type = device_mesh.device_type
# convert the local tensor to the desired device based on the device mesh's device_type
if device_type != local_tensor.device.type and not local_tensor.is_meta:
local_tensor = local_tensor.to(device_type)
# set default placements to replicated if not specified
if placements is None:
placements = [Replicate() for _ in range(device_mesh.ndim)]
else:
placements = list(placements)
for idx, placement in enumerate(placements):
# normalize shard dim to be positive
if placement.is_shard():
placement = cast(Shard, placement)
if placement.dim < 0:
placements[idx] = Shard(placement.dim + local_tensor.ndim)
# `from_local` is differentiable, and the gradients of the dist tensor this function
# creates should flow back to the local_tensor, so we use an autograd
# function to construct the dist tensor.
return _FromTorchTensor.apply( # pyre-ignore[16]: autograd func
local_tensor,
device_mesh,
tuple(placements),
run_check,
shape,
stride,
)
def to_local(
self, *, grad_placements: Optional[Sequence[Placement]] = None
) -> torch.Tensor:
"""
Get the local tensor of this DTensor on its current rank. For sharding it returns
a local shard of the logical tensor view, for replication it returns the replica on
its current rank.
Keyword args:
grad_placements (List[:class:`Placement`], optional): the placements that
describe the future layout of any gradient of the Tensor returned from this
function.
`to_local` converts DTensor to a local tensor and the returned local tensor
might not be used with the original DTensor layout later in the code. This
argument is the hint that the user can give to autograd in case the gradient
layout of the returned tensor does not match the original DTensor layout.
If not specified, we will assume the gradient layout remains the same
as the original DTensor and use that for gradient computation.
Returns:
A :class:`torch.Tensor` or ``AsyncCollectiveTensor`` object. It represents the
local tensor on its current rank. When an ``AsyncCollectiveTensor`` object is returned,
it means the local tensor is not ready yet (i.e. communication is not finished). In this
case, the user needs to call ``wait`` to wait for the local tensor to be ready.
.. note:: ``to_local`` is differentiable; the ``requires_grad`` of the local tensor returned
will depend on whether the ``DTensor`` requires grad or not.
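Example (a minimal sketch; assumes ``dtensor`` is an existing :class:`DTensor` that
requires grad, e.g. created via ``distribute_tensor`` or ``DTensor.from_local``):
>>> # xdoctest: +SKIP("need world_size and rank")
>>> local = dtensor.to_local()
>>> (local * 2).sum().backward()  # differentiable; gradients flow back to the DTensor graph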
"""
if not torch.is_grad_enabled():
return self._local_tensor
if grad_placements is not None and not isinstance(grad_placements, tuple):
grad_placements = tuple(grad_placements)
return _ToTorchTensor.apply(
self, grad_placements
) # pyre-ignore[16]: autograd func
def redistribute(
self,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
*,
async_op: bool = False,
) -> "DTensor":
"""
``redistribute`` performs the necessary collective operations to redistribute the current
DTensor from its current placements to new placements, or from its current DeviceMesh
to a new DeviceMesh, i.e. we can turn a Sharded DTensor into a Replicated DTensor by
specifying a Replicate placement for each dimension of the DeviceMesh.
When redistributing from the current to the new placements on one device mesh dimension, we
will perform the following operations, either a communication collective or a local operation:
1. ``Shard(dim)`` -> ``Replicate()``: ``all_gather``
2. ``Shard(src_dim)`` -> ``Shard(dst_dim)``: ``all_to_all``
3. ``Replicate()`` -> ``Shard(dim)``: local chunking (i.e. ``torch.chunk``)
4. ``Partial()`` -> ``Replicate()``: ``all_reduce``
5. ``Partial()`` -> ``Shard(dim)``: ``reduce_scatter``
``redistribute`` would correctly figure out the necessary redistribute steps for DTensors
that are created either on 1-D or N-D DeviceMesh.
Args:
device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the
DTensor. If not specified, it would use the current DTensor's DeviceMesh.
default: None
placements (List[:class:`Placement`], optional): the new placements that
describes how to place the DTensor into the DeviceMesh, must
have the same number of elements as ``device_mesh.ndim``.
default: replicate on all mesh dimensions
Keyword args:
async_op (bool, optional): whether to perform the DTensor redistribute operation
asynchronously or not. Default: False
Returns:
A :class:`DTensor` object
.. note:: ``redistribute`` is differentiable, which means users do not need to worry about
the backward formula of the redistribute operation.
.. note:: ``redistribute`` currently only supports redistributing DTensor on the same DeviceMesh.
Please file an issue if you need to redistribute DTensor to a different DeviceMesh.
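Example (a minimal sketch; assumes ``dtensor`` is a 2-D DTensor sharded on dim 0 over
a 1-D device mesh):
>>> # xdoctest: +SKIP("need world_size and rank")
>>> replicated = dtensor.redistribute(placements=[Replicate()])  # Shard(0) -> Replicate(): all_gather
>>> resharded = dtensor.redistribute(placements=[Shard(1)])      # Shard(0) -> Shard(1): all_to_all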
"""
# NOTE: This redistribute API currently only supports out
# of place redistribution, i.e. it always creates a new
# DTensor object and leaves the original one unchanged.
# if device_mesh is not specified, use the current device_mesh
device_mesh = device_mesh or self.device_mesh
# raise error if new placements not specified
if placements is None:
raise RuntimeError("placements is needed for redistribute!")
placements = list(placements)
for i, placement in enumerate(placements):
if placement.is_partial():
raise RuntimeError(
"Can not redistribute to Partial, redistributing to Partial is for internal use only!"
)
elif isinstance(placement, Shard) and placement.dim < 0:
# normalize shard dim to be positive
placements[i] = Shard(placement.dim + self.ndim)
placements = tuple(placements)
# pyre-fixme[16]: `Redistribute` has no attribute `apply`.
return Redistribute.apply(self, device_mesh, placements, async_op)
def full_tensor(
self, *, grad_placements: Optional[Sequence[Placement]] = None
) -> torch.Tensor:
"""
Return the full tensor of this DTensor. It will perform necessary collectives
to gather the local tensors from other ranks in its DeviceMesh and concatenate
them together. It is syntactic sugar for the following code:
``dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()``
Keyword args:
grad_placements (List[:class:`Placement`], optional): the placements that
describe the future layout of any gradient of the full Tensor returned from this
function.
`full_tensor` converts DTensor to a full torch.Tensor and the returned torch.Tensor
might not be used with the original replicated DTensor layout later in the code. This
argument is the hint that the user can give to autograd in case the gradient
layout of the returned tensor does not match the original replicated DTensor layout.
If not specified, we will assume the gradient layout of the full tensor to be replicated.
Returns:
A :class:`torch.Tensor` object that represents the full tensor of this DTensor.
.. note:: ``full_tensor`` is differentiable.
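Example (a minimal sketch; assumes ``dtensor`` is an existing :class:`DTensor`):
>>> # xdoctest: +SKIP("need world_size and rank")
>>> global_tensor = dtensor.full_tensor()  # gather/replicate, then return the local torch.Tensor
>>> global_tensor.shape == dtensor.shape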
"""
redist_res = self.redistribute(
placements=[Replicate()] * self.device_mesh.ndim, async_op=False
)
return _ToTorchTensor.apply(redist_res, grad_placements)
@property
def device_mesh(self) -> DeviceMesh:
"""
The :class:`DeviceMesh` attribute associated with this DTensor object.
.. note:: ``device_mesh`` is a read-only property, it can not be set.
"""
return self._spec.mesh
@property
def placements(self) -> Tuple[Placement, ...]:
"""
The placements attribute of this DTensor that describes the layout of this
DTensor on its DeviceMesh.
.. note:: ``placements`` is a read-only property, it can not be set.
"""
return self._spec.placements
def __create_write_items__(self, fqn: str, object: Any):
from torch.distributed.checkpoint.planner_helpers import (
_create_write_items_for_dtensor,
)
if hasattr(self._local_tensor, "__create_write_items__"):
return self._local_tensor.__create_write_items__(fqn, object) # type: ignore[attr-defined]
elif isinstance(self._local_tensor, torch.Tensor):
return [_create_write_items_for_dtensor(fqn, object)]
else:
raise RuntimeError("Unsupported tensor type!")
def __create_chunk_list__(self):
from torch.distributed.checkpoint.planner_helpers import (
_create_chunk_from_dtensor,
)
if hasattr(self._local_tensor, "__create_chunk_list__"):
return self._local_tensor.__create_chunk_list__() # type: ignore[attr-defined]
elif isinstance(self._local_tensor, torch.Tensor):
return [_create_chunk_from_dtensor(self)]
else:
raise RuntimeError("Unsupported tensor type!")
def __get_tensor_shard__(self, index):
if hasattr(self._local_tensor, "__get_tensor_shard__"):
return self._local_tensor.__get_tensor_shard__(index) # type: ignore[attr-defined]
elif isinstance(self._local_tensor, torch.Tensor):
return self.to_local()
else:
raise RuntimeError("Unsupported tensor type!")
def distribute_tensor(
tensor: torch.Tensor,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
"""
Distribute a leaf ``torch.Tensor`` (i.e. nn.Parameter/buffers) to the ``device_mesh`` according
to the ``placements`` specified. The rank of ``device_mesh`` and the number of ``placements`` must be the
same. The ``tensor`` to distribute is the logical or "global" tensor, and the API uses
the ``tensor`` from the first rank of the DeviceMesh dimension as the source of truth to preserve
the single-device semantic. If you want to construct a DTensor in the middle of the Autograd
computation, please use :meth:`DTensor.from_local` instead.
Args:
tensor (torch.Tensor): torch.Tensor to be distributed. Note that if you
want to shard a tensor on a dimension whose size is not evenly divisible by
the number of devices in that mesh dimension, we use ``torch.chunk``
semantic to shard the tensor and scatter the shards. The uneven sharding
behavior is experimental and subject to change.
device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to distribute the
tensor, if not specified, must be called under a DeviceMesh context
manager, default: None
placements (List[:class:`Placement`], optional): the placements that
describes how to place the tensor on DeviceMesh, must have the same
number of elements as ``device_mesh.ndim``. If not specified, we will
by default replicate the tensor across the ``device_mesh`` from the
first rank of each dimension of the `device_mesh`.
Returns:
A :class:`DTensor` or ``XLAShardedTensor`` object.
.. note::
When initializing the DeviceMesh with the ``xla`` device_type, ``distribute_tensor``
returns an `XLAShardedTensor` instead. See `this issue <https://github.com/pytorch/pytorch/issues/92909>`__
for more details. The XLA integration is experimental and subject to change.
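Example (a minimal sketch; assumes a default process group of ``world_size`` ranks is
already initialized):
>>> # xdoctest: +SKIP("need world_size and rank")
>>> from torch.distributed.device_mesh import init_device_mesh
>>> mesh = init_device_mesh("cuda", (world_size,))
>>> weight = torch.randn(world_size * 16, 32)  # the full ("global") tensor on every rank
>>> dweight = distribute_tensor(weight, mesh, [Shard(0)])  # each rank keeps one row-shard
>>> dweight.to_local().shape  # torch.Size([16, 32]) on every rank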
"""
torch._C._log_api_usage_once("torch.dtensor.distribute_tensor")
# get default device mesh if there's nothing specified
device_mesh = device_mesh or _mesh_resources.get_current_mesh()
device_type = device_mesh.device_type
if device_type == "xla":
try:
# call PyTorch/XLA SPMD for `xla` backend type device mesh.
# This returns XLAShardedTensor
from torch_xla.distributed.spmd import ( # type:ignore[import]
xla_distribute_tensor,
)
return xla_distribute_tensor(
tensor, device_mesh, placements
) # type:ignore[return-value]
except ImportError as e:
msg = "To use DTensor API with xla, you must install the torch_xla package!"
raise ImportError(msg) from e
# instantiate an RNG tracker if we haven't already. By default DTensor uses an
# OffsetBasedRNGTracker to perform random operators.
# TODO: the value assignment to the global variable is not the ideal solution;
# we can replace it in the future.
if not random._rng_tracker and is_rng_supported_mesh(device_mesh):
random._rng_tracker = OffsetBasedRNGTracker(device_type)
if not tensor.is_leaf:
raise RuntimeError(
"`distribute_tensor` should be used to distribute leaf tensors! but found non-leaf tensor!"
)
# convert tensor to the corresponding device type if it's not in that device type
if device_type != tensor.device.type and not tensor.is_meta:
tensor = tensor.to(device_type)
# set default placements to replicated if not specified
if placements is None:
placements = [Replicate() for _ in range(device_mesh.ndim)]
if len(placements) != device_mesh.ndim:
raise ValueError(
f"`placements` must have the same length as `device_mesh.ndim`! "
f"Found placements length: {len(placements)}, and device_mesh.ndim: {device_mesh.ndim}."
)
if isinstance(tensor, DTensor):
# if the tensor is already a DTensor, we need to check:
# 1. if we can further shard this DTensor if the two device meshes belong to
# the same parent mesh and further sharding is possible.
# 2. check if the device mesh and placements are the same
if tensor.device_mesh != device_mesh:
raise ValueError(
f"Cannot distribute a DTensor with device mesh {tensor.device_mesh} "
f"to a different device mesh {device_mesh}."
)
if tensor.placements != tuple(placements):
raise ValueError(
f"Cannot distribute a DTensor with placements {tensor.placements} "
f"to a different placements {placements}. do you want to call "
f"`redistribute` instead?"
)
return tensor
local_tensor = tensor.detach()
# TODO(xilun): address sharding order
# distribute the tensor according to the placements.
placements = list(placements)
for idx, placement in enumerate(placements):
if placement.is_shard():
placement = cast(Shard, placement)
if placement.dim < 0:
# normalize shard placement dim
placement = Shard(placement.dim + tensor.ndim)
placements[idx] = placement
local_tensor = placement._shard_tensor(local_tensor, device_mesh, idx)
elif placement.is_replicate():
placement = cast(Replicate, placement)
local_tensor = placement._replicate_tensor(local_tensor, device_mesh, idx)
else:
raise RuntimeError(
f"Trying to distribute tensor with unsupported placements {placement} on device mesh dimension {idx}!"
)
placements = tuple(placements)
assert local_tensor is not None, "distributing a tensor should not produce a None local tensor"
# detach the local tensor passed to DTensor since after the construction
# of DTensor, autograd would work on top of DTensor instead of local tensor
spec = DTensorSpec(
mesh=device_mesh,
placements=placements,
tensor_meta=TensorMeta(
shape=tensor.size(),
stride=tensor.stride(),
dtype=tensor.dtype,
),
)
return DTensor(
local_tensor.requires_grad_(tensor.requires_grad),
spec,
requires_grad=tensor.requires_grad,
)
def _shard_tensor(
full_tensor: torch.Tensor,
placements: Sequence[Shard],
device_mesh: Optional[DeviceMesh] = None,
) -> "DTensor":
"""
Locally shards a full tensor based on indicated sharding arrangement, and
returns a DTensor containing the local shard.
.. warning:: This is a private API that is subject to change. It skips the
communication otherwise required by `distribute_tensor`. It is only
applicable to cases where all ranks have the same `full_tensor`. For
example, in distributed inference all ranks load from the same
checkpoint. This API will not check for data equality between ranks, it
is thus the user's responsibility to ensure the `full_tensor` is the same
across ranks.
Args:
full_tensor (torch.Tensor): the full tensor to be sharded.
placements (Sequence[:class:`Shard`]): the placements that
describes how to place the local tensor on DeviceMesh.
device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the
DTensor. Must have the same number of dimensions as the number of placements.
If not specified, it will be retrieved from the current context.
Returns:
A :class:`DTensor` object with the shard as its local tensor.
Examples:
>>> # xdoctest: +SKIP("need world_size and rank")
>>> device_mesh = dist.init_device_mesh("cuda", (world_size,))
>>> full_tensor = torch.arange(world_size, device=f"cuda:{rank}")
>>> dtensor = _shard_tensor(full_tensor, [Shard(0)], device_mesh)
"""
device_mesh = device_mesh or _mesh_resources.get_current_mesh()
shape, offset = compute_local_shape_and_global_offset(
full_tensor.shape, device_mesh, placements
)
slices = [
slice(cur_offset, cur_offset + cur_shape)
for cur_shape, cur_offset in zip(shape, offset)
]
local_tensor = full_tensor[slices]
return DTensor.from_local(local_tensor, device_mesh, placements)
def distribute_module(
module: nn.Module,
device_mesh: Optional[DeviceMesh] = None,
partition_fn: Optional[Callable[[str, nn.Module, DeviceMesh], None]] = None,
input_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None,
output_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None,
) -> nn.Module:
"""
This function exposes three functions to control the parameters/inputs/outputs of the module:
1. To perform sharding on the module before runtime execution by specifying the
``partition_fn`` (i.e. allow user to convert Module parameters to :class:`DTensor`
parameters according to the `partition_fn` specified).
2. To control the inputs or outputs of the module during runtime execution by
specifying the ``input_fn`` and ``output_fn``. (i.e. convert the input to
:class:`DTensor`, convert the output back to ``torch.Tensor``)
Args:
module (:class:`nn.Module`): user module to be partitioned.
device_mesh (:class:`DeviceMesh`): the device mesh to place the module.
partition_fn (Callable): the function to partition parameters (i.e. shard certain
parameters across the ``device_mesh``). If ``partition_fn`` is not specified,
by default we replicate all module parameters of ``module`` across the mesh.
input_fn (Callable): specify the input distribution, i.e. could control how the
input of the module is sharded. ``input_fn`` will be installed as a module
``forward_pre_hook`` (pre forward hook).
output_fn (Callable): specify the output distribution, i.e. could control how the
output is sharded, or convert it back to torch.Tensor. ``output_fn`` will be
installed as a module ``forward_hook`` (post forward hook).
Returns:
A module that contains parameters/buffers that are all ``DTensor`` s.
.. note::
When initializing the DeviceMesh with the ``xla`` device_type, ``distribute_module``
returns an nn.Module with PyTorch/XLA SPMD annotated parameters. See
`this issue <https://github.com/pytorch/pytorch/issues/92909>`__
for more details. The XLA integration is experimental and subject to change.
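Example (a minimal sketch; assumes a default process group is initialized; the
``shard_params`` policy below is illustrative, not a built-in):
>>> # xdoctest: +SKIP("need world_size and rank")
>>> from torch.distributed.device_mesh import init_device_mesh
>>> mesh = init_device_mesh("cuda", (world_size,))
>>> def shard_params(name, module, device_mesh):
...     # shard every ``weight`` on dim 0 and replicate everything else
...     for p_name, param in list(module.named_parameters(recurse=False)):
...         placement = [Shard(0)] if p_name == "weight" else [Replicate()]
...         dist_param = nn.Parameter(distribute_tensor(param, device_mesh, placement))
...         module.register_parameter(p_name, dist_param)
>>> sharded_mod = distribute_module(nn.Linear(8, 8), mesh, partition_fn=shard_params)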
"""
torch._C._log_api_usage_once("torch.dtensor.distribute_module")
device_mesh = device_mesh or _mesh_resources.get_current_mesh()
device_type = device_mesh.device_type
if device_type == "xla":
try:
# This function annotates all module parameters for auto-partitioning with
# PyTorch/XLA SPMD or explicitly partition to :class:`XLAShardedTensor` parameters
# according to the `partition_fn` specified.
from torch_xla.distributed.spmd import ( # type:ignore[import]
xla_distribute_module,
)
return xla_distribute_module(
module, device_mesh, partition_fn, input_fn, output_fn
) # type:ignore[return-value]
except ImportError as e:
msg = "To use DTensor API with xla, you must install the torch_xla package!"
raise ImportError(msg) from e
def replicate_module_params_buffers(m: nn.Module, mesh: DeviceMesh) -> None:
# This function loops over the immediate module parameters and
# buffers, replicating all non-DTensor params/buffers to DTensor
# parameters/buffers if they have not been partitioned in the
# partition_fn. We can't easily use `module._apply` here
# because we don't know what happened inside partition_fn, as the
# user could do anything, e.g. install hooks, and we want to
# preserve those.
full_replicate = [Replicate()] * mesh.ndim
for key, param in m._parameters.items():
if param is not None and not isinstance(param, DTensor):
m.register_parameter(
key,
nn.Parameter(distribute_tensor(param.data, mesh, full_replicate)),
)
for key, buffer in m._buffers.items():
if buffer is not None and not isinstance(buffer, DTensor):
m._buffers[key] = distribute_tensor(buffer, mesh, full_replicate)
if partition_fn is None:
# if partition_fn not specified, we by default replicate
# all module params/buffers
for name, submod in module.named_modules():
replicate_module_params_buffers(submod, device_mesh)
else:
# apply partition_fn to submodules
for name, submod in module.named_modules():
partition_fn(name, submod, device_mesh)
replicate_module_params_buffers(submod, device_mesh)
# register input_fn as module forward pre hook
if input_fn is not None:
# check the input_fn signature
num_args = len(inspect.signature(input_fn).parameters)
if num_args == 2:
# input_fn only takes in inputs and device mesh
warnings.warn(
"Deprecating input_fn that takes two arguments (inputs, device_mesh), "
"please use input_fn that takes in (module, inputs, device_mesh) instead!",
FutureWarning,
stacklevel=2,
)
module.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, device_mesh)) # type: ignore[call-arg]
elif num_args == 3:
# input_fn takes in module, inputs, device mesh
module.register_forward_pre_hook(
lambda mod, inputs: input_fn(mod, inputs, device_mesh)
)
else:
raise ValueError(
f"input_fn should take in 3 arguments, but got {num_args} arguments!"
)
# register output_fn as module forward hook
if output_fn is not None:
num_args = len(inspect.signature(output_fn).parameters)
if num_args == 2:
# output_fn only takes in outputs and device mesh
warnings.warn(
"Deprecating output_fn that takes two arguments (inputs, device_mesh), "
"please use output_fn that takes in (module, inputs, device_mesh) instead!",
FutureWarning,
stacklevel=2,
)
module.register_forward_hook(
lambda mod, inputs, outputs: output_fn(outputs, device_mesh) # type: ignore[call-arg]
)
elif num_args == 3:
module.register_forward_hook(
lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh)
)
else:
raise ValueError(
f"output_fn should take in 3 arguments, but got {num_args} arguments!"
)
return module
# Below are tensor factory function APIs, which are used to create a DTensor directly. We need
# to make separate factory function APIs because tensor subclasses cannot override the tensor
# factory methods, and we need users to call the factory functions with their intended device_mesh
# and placements to create a proper DTensor.
def _dtensor_init_helper( # type: ignore[no-untyped-def]
init_op,
size: torch.Size,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
**kwargs,
) -> DTensor:
# from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
# if device_mesh is None, use the one from mesh resources
device_mesh = device_mesh or _mesh_resources.get_current_mesh()
kwargs["device"] = device_mesh.device_type
# set default placements to replicated if not specified
placements = placements or tuple(Replicate() for _ in range(device_mesh.ndim))
# check device_mesh against placements
assert device_mesh.ndim == len(
placements
), "mesh dimension does not match the length of placements"
assert kwargs["layout"] == torch.strided, "layout value not supported!"
torch_stride = torch._prims_common.make_contiguous_strides_for(size)
# get local tensor shape
local_shape, _ = compute_local_shape_and_global_offset(
size, device_mesh, placements
)
# initialize the local tensor
if init_op == torch.full:
fill_value = kwargs.pop("fill_value", 0)
local_tensor = init_op(local_shape, fill_value, **kwargs)
elif init_op == torch.rand or init_op == torch.randn:
# this tensor meta is not used except for `shape`
dtype = kwargs.get("dtype", torch.get_default_dtype())
tensor_meta = TensorMeta(size, (0,), dtype)
spec = DTensorSpec(device_mesh, tuple(placements), tensor_meta=tensor_meta)
if random.is_rng_supported_mesh(device_mesh) and not random._rng_tracker:
random._rng_tracker = random.OffsetBasedRNGTracker()
assert random._rng_tracker is not None
with random._rng_tracker._distribute_region(spec):
local_tensor = init_op(local_shape, **kwargs)
else:
local_tensor = init_op(local_shape, **kwargs)
spec = DTensorSpec(
device_mesh,
tuple(placements),
tensor_meta=TensorMeta(
size,
torch_stride,
local_tensor.dtype,
),
)
return DTensor(
local_tensor,
spec,
requires_grad=kwargs["requires_grad"],
)
def ones( # type: ignore[no-untyped-def]
*size,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
requires_grad: bool = False,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with the scalar value 1, with the shape defined
by the variable argument ``size``.
Args:
size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
Can be a variable number of arguments or a collection like a list or tuple.
E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..))
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
layout (:class:`torch.layout`, optional): the desired layout of returned DTensor.
Default: ``torch.strided``.
requires_grad (bool, optional): If autograd should record operations on the
returned :class:`DTensor`. Default: ``False``.
device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``
Returns:
A :class:`DTensor` object on each rank
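Example (a minimal sketch; assumes a 1-D ``mesh`` of ``world_size`` ranks already exists):
>>> # xdoctest: +SKIP("need world_size and rank")
>>> dist_ones = ones(world_size * 4, 8, device_mesh=mesh, placements=[Shard(0)])
>>> dist_ones.to_local().shape  # torch.Size([4, 8]); only the local shard is materialized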
"""
torch_size = normalize_to_torch_size(size)
return _dtensor_init_helper(
torch.ones,
torch_size,
dtype=dtype,
layout=layout,
requires_grad=requires_grad,
device_mesh=device_mesh,
placements=placements,
)
def empty( # type: ignore[no-untyped-def]
*size,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
requires_grad: bool = False,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with uninitialized data. The shape of the :class:`DTensor`
is defined by the variable argument ``size``.
Args:
size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
Can be a variable number of arguments or a collection like a list or tuple.
E.g.: empty(1,2,3..) or empty([1,2,3..]) or empty((1,2,3..))
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`.
Default: ``torch.strided``.
requires_grad (bool, optional): If autograd should record operations on the
returned :class:`DTensor`. Default: ``False``.
device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``
Returns:
A :class:`DTensor` object on each rank
"""
torch_size = normalize_to_torch_size(size)
return _dtensor_init_helper(
torch.empty,
torch_size,
dtype=dtype,
layout=layout,
requires_grad=requires_grad,
device_mesh=device_mesh,
placements=placements,
)
def full( # type: ignore[no-untyped-def]
size,
fill_value,
*,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
requires_grad: bool = False,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with ``fill_value`` according to ``device_mesh`` and
``placements``, with the shape defined by the argument ``size``.
Args:
size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
Can be a collection like a list or tuple.
E.g.: full((1,2,3..), fill_value) or full([1,2,3..], fill_value)
fill_value(Scalar): the value to fill the output tensor with.
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
layout (:class:`torch.layout`, optional): the desired layout of returned DTensor.
Default: ``torch.strided``.
requires_grad (bool, optional): If autograd should record operations on the
returned :class:`DTensor`. Default: ``False``.
device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks.
placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``
Returns:
A :class:`DTensor` object on each rank
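Example (a minimal sketch; assumes a 1-D ``mesh`` already exists; note that ``size`` is a
single collection here, not variadic):
>>> # xdoctest: +SKIP("need world_size and rank")
>>> dist_full = full((world_size * 4, 8), 3.14, device_mesh=mesh, placements=[Shard(0)])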
"""
torch_size = normalize_to_torch_size(size)
return _dtensor_init_helper(
torch.full,
torch_size,
fill_value=fill_value,
dtype=dtype,
layout=layout,
requires_grad=requires_grad,
device_mesh=device_mesh,
placements=placements,
)
def rand( # type: ignore[no-untyped-def]
*size,
requires_grad: bool = False,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with random numbers from a uniform distribution
on the interval ``[0, 1)``. The shape of the tensor is defined by the variable
argument ``size``.
Args:
size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
Can be a variable number of arguments or a collection like a list or tuple.
E.g.: rand(1,2,3..) or rand([1,2,3..]) or rand((1,2,3..))
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
layout (:class:`torch.layout`, optional): the desired layout of returned DTensor.
Default: ``torch.strided``.
requires_grad (bool, optional): If autograd should record operations on the
returned :class:`DTensor`. Default: ``False``.
device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks.
placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``
Returns:
A :class:`DTensor` object on each rank
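Example (a minimal sketch; assumes a 1-D ``mesh`` on a device type with RNG support, e.g. cuda):
>>> # xdoctest: +SKIP("need world_size and rank")
>>> dist_rand = rand(world_size * 4, 8, device_mesh=mesh, placements=[Shard(0)])
>>> # each rank initializes only its local shard; the RNG tracker offsets the random
>>> # stream per shard (see OffsetBasedRNGTracker in `_dtensor_init_helper` above)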
"""
torch_size = normalize_to_torch_size(size)
return _dtensor_init_helper(
torch.rand,
torch_size,
dtype=dtype,
layout=layout,
requires_grad=requires_grad,
device_mesh=device_mesh,
placements=placements,
)
def randn( # type: ignore[no-untyped-def]
*size,
requires_grad: bool = False,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with random numbers from a normal distribution
with mean 0 and variance 1. The shape of the tensor is defined by the variable
argument ``size``.
Args:
size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
Can be a variable number of arguments or a collection like a list or tuple.
E.g.: randn(1,2,3..) or randn([1,2,3..]) or randn((1,2,3..))
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
layout (:class:`torch.layout`, optional): the desired layout of returned DTensor.
Default: ``torch.strided``.
requires_grad (bool, optional): If autograd should record operations on the
returned :class:`DTensor`. Default: ``False``.
device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks.
placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``
Returns:
A :class:`DTensor` object on each rank
"""
torch_size = normalize_to_torch_size(size)
return _dtensor_init_helper(
torch.randn,
torch_size,
dtype=dtype,
layout=layout,
requires_grad=requires_grad,
device_mesh=device_mesh,
placements=placements,
)
def zeros( # type: ignore[no-untyped-def]
*size,
requires_grad: bool = False,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with the scalar value 0.
Args:
size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
Can be a variable number of arguments or a collection like a list or tuple.
E.g.: zeros(1,2,3..) or zeros([1,2,3..]) or zeros((1,2,3..))
Keyword args:
requires_grad (bool, optional): If autograd should record operations on the
returned :class:`DTensor`. Default: ``False``.
dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`.
Default: ``torch.strided``.
device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``
Returns:
A :class:`DTensor` object on each rank
"""
torch_size = normalize_to_torch_size(size)
return _dtensor_init_helper(
torch.zeros,
torch_size,
dtype=dtype,
layout=layout,
requires_grad=requires_grad,
device_mesh=device_mesh,
placements=placements,
)