# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import os
import unittest
from functools import wraps
from typing import Any, Callable, Dict, Tuple

import numpy as np

import torch
from torch import nn
from torch.distributed._tensor import (
    DeviceMesh,
    distribute_module,
    distribute_tensor,
    Replicate,
    Shard,
)
from torch.testing._internal.common_utils import run_tests, TestCase


# wrapper to check xla test requirements
def with_xla(func: Callable) -> Callable:
    assert func is not None

    @wraps(func)  # pyre-ignore[6]
    def wrapper(
        self, *args: Tuple[object], **kwargs: Dict[str, Any]  # type: ignore[misc]
    ) -> None:
        # TODO(yeounoh) replace this with xr.use_spmd() when we deprecate the flag.
        os.environ["XLA_USE_SPMD"] = "1"
        try:
            import torch_xla  # type:ignore[import]  # noqa: F401
        except ImportError as exc:
            raise unittest.SkipTest("torch_xla is not installed.") from exc
        self.device_type = "xla"
        func(self, *args, **kwargs)  # type: ignore[misc]
        os.environ["XLA_USE_SPMD"] = "0"

    return wrapper
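
# Usage sketch (illustrative only; it mirrors the tests below): decorating a test
# method with @with_xla enables SPMD via the XLA_USE_SPMD flag, sets
# self.device_type to "xla", and skips the test when torch_xla is not installed.
# The method name below is hypothetical:
#
#     @with_xla
#     def test_something(self):
#         assert self.device_type == "xla"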


class DTensorXLAIntegrationTest(TestCase):
    class SimpleLinear(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.fc1 = nn.Linear(128, 64)
            self.relu = nn.ReLU()
            self.fc2 = nn.Linear(64, 1)

        def forward(self, x):
            y = self.relu(self.fc1(x))
            z = self.fc2(y)
            return z

    @with_xla
    def test_xla_distribute_tensor_1d_shard(self):
        import torch_xla.runtime as xr  # type:ignore[import]

        device_count = xr.global_runtime_device_count()
        if device_count > 1:
            device_mesh = DeviceMesh("xla", list(range(device_count)))
            shard_spec = [Shard(0)]
            for requires_grad in [True, False]:
                tensor_to_shard = torch.randn(
                    3 * device_count, 3, requires_grad=requires_grad
                )
                dist_tensor = distribute_tensor(
                    tensor_to_shard, device_mesh, shard_spec
                )
                # TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor
                assert type(dist_tensor).__name__ == "XLAShardedTensor"
                global_tensor = dist_tensor.global_tensor  # type:ignore[attr-defined]
                self.assertEqual(
                    global_tensor.size(), torch.Size([3 * device_count, 3])
                )
                local_tensor = dist_tensor.local_shards[0].data
                self.assertEqual(local_tensor.size(), torch.Size([3, 3]))
                if requires_grad:
                    self.assertTrue(dist_tensor.global_tensor.requires_grad)
                    self.assertTrue(dist_tensor.is_leaf)

    @with_xla
    def test_xla_distribute_tensor_1d_replicate(self):
        import torch_xla.runtime as xr  # type:ignore[import]

        device_count = xr.global_runtime_device_count()
        device_mesh = DeviceMesh("xla", list(range(device_count)))
        shard_spec = [Replicate()]
        for requires_grad in [True, False]:
            tensor_to_shard = torch.randn(
                3 * device_count, 3, requires_grad=requires_grad
            )
            dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
            # TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor
            assert type(dist_tensor).__name__ == "XLAShardedTensor"
            global_tensor = dist_tensor.global_tensor  # type:ignore[attr-defined]
            self.assertEqual(global_tensor.size(), torch.Size([3 * device_count, 3]))
            local_tensor = dist_tensor.local_shards[0].data
            self.assertEqual(local_tensor.size(), torch.Size([3 * device_count, 3]))
            if requires_grad:
                self.assertTrue(dist_tensor.global_tensor.requires_grad)
                self.assertTrue(dist_tensor.is_leaf)

    @with_xla
    def test_xla_distribute_tensor_2d(self):
        import torch_xla.runtime as xr  # type:ignore[import]

        device_count = xr.global_runtime_device_count()
        if device_count > 1:
            device_mesh = DeviceMesh(
                "xla", np.array(range(device_count)).reshape(2, device_count // 2)
            )
            shard_spec = [Replicate(), Shard(0)]
            for requires_grad in [True, False]:
                tensor_to_shard = torch.randn(
                    3 * device_count // 2, 3, requires_grad=requires_grad
                )
                dist_tensor = distribute_tensor(
                    tensor_to_shard, device_mesh, shard_spec
                )
                # TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor
                assert type(dist_tensor).__name__ == "XLAShardedTensor"
                global_tensor = dist_tensor.global_tensor  # type:ignore[attr-defined]
                self.assertEqual(
                    global_tensor.size(), torch.Size([3 * device_count // 2, 3])
                )
                local_tensor = dist_tensor.local_shards[0].data
                self.assertEqual(local_tensor.size(), torch.Size([3, 3]))
                if requires_grad:
                    self.assertTrue(dist_tensor.global_tensor.requires_grad)
                    self.assertTrue(dist_tensor.is_leaf)

    @with_xla
    def test_xla_distribute_module(self):
        import torch_xla  # type:ignore[import]
        import torch_xla.core.xla_model as xm  # type:ignore[import]
        import torch_xla.runtime as xr  # type:ignore[import]

        model = self.SimpleLinear().to(xm.xla_device())
        device_count = xr.global_runtime_device_count()
        device_mesh = DeviceMesh("xla", list(range(device_count)))

        def shard_params(mod_name, mod, mesh):
            shard_spec = [Shard(0)]
            # annotate fc1 and fc2
            if isinstance(mod, nn.Linear):
                for name, param in mod.named_parameters():
                    # annotate the parameter tensors directly; the sharding
                    # annotation is attached to the XLA tensor in place, so the
                    # return value is intentionally ignored here
                    distribute_tensor(param, mesh, shard_spec)

        sharded_model = distribute_module(model, device_mesh, shard_params)
        # a non-empty sharding spec string means the parameter was annotated
        self.assertTrue(
            torch_xla._XLAC._get_xla_sharding_spec(sharded_model.fc1.weight) != ""
        )
        self.assertTrue(
            torch_xla._XLAC._get_xla_sharding_spec(sharded_model.fc2.weight) != ""
        )


if __name__ == "__main__":
    run_tests()