1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393
|
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import unittest
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.experimental._attention import (
_AttentionContextParallel,
_CausalBehavior,
_cp_options,
_is_causal_behavior,
_RotateMethod,
context_parallel,
context_parallel_unshard,
set_rotate_method,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import parallelize_module
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_FLASH_ATTENTION,
PLATFORM_SUPPORTS_FUSED_ATTENTION,
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import (
decorateIf,
instantiate_parametrized_tests,
parametrize,
run_tests,
skipIfRocm,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
ModelArgs,
Transformer,
with_comms,
)
c10d_functional = torch.ops.c10d_functional
backends = []
if PLATFORM_SUPPORTS_FLASH_ATTENTION:
backends.append(SDPBackend.FLASH_ATTENTION)
if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION:
backends.append(SDPBackend.EFFICIENT_ATTENTION)
rotater_enum_to_str = {
_RotateMethod.ALL_GATHER: "allgather",
_RotateMethod.ALL_TO_ALL: "alltoall",
} # mapping from _RotateMethod enum to string
class RingAttentionTest(DTensorTestBase):
    """Tests for ring attention under context parallelism.

    The SDPA test compares ``F.scaled_dot_product_attention`` on full tensors
    against the same call on tensors sharded with ``context_parallel()``.  The
    transformer tests parallelize attention modules with
    ``_AttentionContextParallel`` and check the number of collectives issued
    in the forward and backward passes.
    """

    @property
    def world_size(self) -> int:
        """Return the number of visible CUDA devices, used as the world size."""
        return torch.cuda.device_count()

    def _assert_ring_comm_counts(
        self,
        comm_mode: CommDebugMode,
        rotater: _RotateMethod,
        num_layers: int,
        *,
        backward: bool,
    ) -> None:
        """Check the collectives recorded for ``num_layers`` attention layers.

        Args:
            comm_mode: The ``CommDebugMode`` that recorded the pass.
            rotater: Rotate method that was active during the pass.
            num_layers: Number of context-parallel attention layers executed.
            backward: ``True`` to check the backward-pass counts, ``False``
                for the forward-pass counts.
        """
        if rotater == _RotateMethod.ALL_TO_ALL:
            if backward:
                per_layer = self.world_size * 2 - 1
            else:
                per_layer = self.world_size - 1
            expected = {c10d_functional.all_to_all_single: per_layer * num_layers}
        else:
            expected = {c10d_functional.all_gather_into_tensor: num_layers}
            if backward:
                # The backward pass records all_to_all_single calls even when
                # the all-gather rotate method is selected.
                expected[c10d_functional.all_to_all_single] = (
                    self.world_size * num_layers
                )
        self.assertDictEqual(comm_mode.get_comm_counts(), expected)

    @skip_if_lt_x_gpu(2)
    @skipIfRocm  # Missing _c10d_functional_autograd::all_to_all_single
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FUSED_ATTENTION,
        "Does not support flash nor efficient attention",
    )
    @with_comms
    @decorateIf(
        unittest.skip, lambda params: params["load_balance"] and not params["is_causal"]
    )
    @parametrize("is_causal", [True, False])
    @parametrize("compiled", [True, False])
    @parametrize("backend", backends)
    @parametrize("load_balance", [True, False])
    @parametrize("rotater", [_RotateMethod.ALL_TO_ALL, _RotateMethod.ALL_GATHER])
    def test_ring_attention_sdpa(
        self,
        is_causal: bool,
        compiled: bool,
        backend: SDPBackend,
        load_balance: bool,
        rotater: _RotateMethod,
    ) -> None:
        """Compare context-parallel SDPA with single-device SDPA.

        The reference output and gradients are computed on the full q/k/v
        tensors.  The same data is then sharded along the sequence dimension
        with ``context_parallel()``, run through SDPA (optionally compiled),
        unsharded, and compared against the reference.
        """
        set_rotate_method(rotater_enum_to_str[rotater])
        self.assertEqual(_cp_options.rotate_method, rotater)
        device_mesh = DeviceMesh(self.device_type, torch.arange(0, self.world_size))
        bs = 8
        query_tokens = 64
        context_tokens = 64
        dim = 32
        nheads = 8
        torch.manual_seed(10)
        dtype = (
            torch.bfloat16 if backend == SDPBackend.FLASH_ATTENTION else torch.float32
        )

        _cp_options.enable_load_balance = load_balance

        def rand_input(seq_len: int) -> torch.Tensor:
            return torch.rand(
                (bs, nheads, seq_len, dim),
                device=self.device_type,
                dtype=dtype,
                requires_grad=True,
            )

        # Creation order (q, k, v) is kept so the seeded RNG stream is unchanged.
        q = rand_input(self.world_size * query_tokens)
        k = rand_input(self.world_size * context_tokens)
        v = rand_input(self.world_size * context_tokens)

        # Ensure all ranks have the same initialization data.
        with torch.no_grad():
            for tensor in (q, k, v):
                dist.broadcast(tensor, src=0)

        # Reference result on the full (unsharded) tensors.
        with sdpa_kernel(backend):
            out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
            out.sum().backward()

        cp_q, cp_k, cp_v = (t.detach().clone() for t in (q, k, v))
        # Theoretically, context_parallel() should not be used to shard
        # parameters because when requires_grad is True, resize_ is not
        # allowed. But requires_grad of cp_q, cp_k, and cp_v are False
        # now. So we can just use context_parallel() to shard q, k, v.
        # In reality, context_parallel() should be used to shard the input.
        with context_parallel(
            device_mesh, buffers=(cp_q, cp_k, cp_v), buffer_seq_dims=(2, 2, 2)
        ):
            cp_q.requires_grad = True
            cp_k.requires_grad = True
            cp_v.requires_grad = True
            try:
                with CommDebugMode() as comm_mode, sdpa_kernel(backend):
                    if compiled:
                        fn = torch.compile(
                            F.scaled_dot_product_attention,
                            fullgraph=True,
                            backend="aot_eager",
                        )
                    else:
                        fn = F.scaled_dot_product_attention

                    cp_out = fn(cp_q, cp_k, cp_v, is_causal=is_causal)
                    cp_out.sum().backward()

                    if not compiled and rotater == _RotateMethod.ALL_TO_ALL:
                        # Compiler and CommDebugMode do not work well together.
                        # The count covers both the forward and the backward
                        # pass, since both run under comm_mode.
                        self.assertDictEqual(
                            comm_mode.get_comm_counts(),
                            {
                                c10d_functional.all_to_all_single: self.world_size * 3
                                - 2
                            },
                        )

                cp_out, cp_dq, cp_dk, cp_dv = context_parallel_unshard(
                    device_mesh,
                    [cp_out, cp_q.grad, cp_k.grad, cp_v.grad],
                    [2, 2, 2, 2],
                )

                # Due to numerical error, we need to choose different atol for
                # different attention kernels.
                atol = (
                    1e-08
                    if backend == SDPBackend.EFFICIENT_ATTENTION
                    else 1e-3 * self.world_size
                )
                self.assertTrue(
                    torch.allclose(out, cp_out, atol=atol),
                    msg="context-parallel output does not match the reference",
                )

                atol = (
                    2e-06
                    if backend == SDPBackend.EFFICIENT_ATTENTION
                    else 8e-3 * self.world_size
                )
                for name, ref_grad, cp_grad in (
                    ("q", q.grad, cp_dq),
                    ("k", k.grad, cp_dk),
                    ("v", v.grad, cp_dv),
                ):
                    self.assertTrue(
                        torch.allclose(ref_grad, cp_grad, atol=atol),
                        msg=f"context-parallel grad of {name} does not match "
                        "the reference",
                    )
            finally:
                # Drop grads and requires_grad before leaving context_parallel(),
                # even when an assertion above failed; presumably the context
                # manager resizes the buffers back on exit, which the note
                # above says is not allowed while requires_grad is True.
                for tensor in (cp_q, cp_k, cp_v):
                    tensor.grad = None
                    tensor.requires_grad = False

    def test_is_causal_behavior(self) -> None:
        """Check ``_is_causal_behavior`` with and without load balancing."""
        causal = _CausalBehavior.IS_CAUSAL
        not_causal = _CausalBehavior.NOT_IS_CAUSAL
        skip = _CausalBehavior.SKIP
        # Expected behavior indexed as [rank][i] for world_size=2, paired with
        # the load-balance setting under which it applies.
        expected_by_load_balance = (
            (False, [[causal, skip], [causal, not_causal]]),
            (True, [[causal, not_causal], [causal, not_causal]]),
        )

        previous = _cp_options.enable_load_balance
        try:
            _cp_options.enable_load_balance = False
            self.assertEqual(
                _is_causal_behavior(rank=0, world_size=4, i=0, is_causal=False),
                _CausalBehavior.NOT_IS_CAUSAL,
            )

            for load_balance, ranks in expected_by_load_balance:
                _cp_options.enable_load_balance = load_balance
                for rank, iters in enumerate(ranks):
                    for i, behavior in enumerate(iters):
                        self.assertEqual(
                            _is_causal_behavior(
                                rank=rank, world_size=2, i=i, is_causal=True
                            ),
                            behavior,
                        )
        finally:
            # Do not leak the module-level option to code that runs afterwards.
            _cp_options.enable_load_balance = previous

    @skip_if_lt_x_gpu(2)
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention"
    )
    @with_comms
    @sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION])
    @parametrize("is_causal", [True, False])
    @parametrize("rotater", [_RotateMethod.ALL_GATHER, _RotateMethod.ALL_TO_ALL])
    def test_ring_attention_native_transformer(
        self, is_causal: bool, rotater: _RotateMethod
    ) -> None:
        """Check collective counts for a context-parallel nn.TransformerEncoder."""
        _cp_options.enable_load_balance = is_causal
        set_rotate_method(rotater_enum_to_str[rotater])
        self.assertEqual(_cp_options.rotate_method, rotater)
        device_mesh = DeviceMesh(
            self.device_type,
            torch.arange(0, self.world_size),
        )
        dtype = torch.bfloat16
        bs = 8
        ntokens = 8
        dim = 32
        nheads = 8
        num_layers = 2

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=nheads,
            dim_feedforward=dim,
            batch_first=True,
        ).to(dtype)
        encoder_layer = parallelize_module(
            module=encoder_layer,
            device_mesh=device_mesh,
            parallelize_plan={
                "self_attn": _AttentionContextParallel(),
            },
        )
        model = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        model = model.to(self.device_type).to(dtype)

        mask = (
            nn.Transformer.generate_square_subsequent_mask(
                ntokens, device=self.device_type, dtype=dtype
            )
            if is_causal
            else None
        )
        seq = torch.rand((bs, ntokens, dim), device=self.device_type, dtype=dtype)

        with CommDebugMode() as comm_mode:
            out = model(seq, mask=mask, is_causal=is_causal)
        self._assert_ring_comm_counts(comm_mode, rotater, num_layers, backward=False)

        with CommDebugMode() as comm_mode:
            out.sum().backward()
        self._assert_ring_comm_counts(comm_mode, rotater, num_layers, backward=True)

    @skip_if_lt_x_gpu(2)
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention"
    )
    @with_comms
    @sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION])
    @parametrize("rotater", [_RotateMethod.ALL_GATHER, _RotateMethod.ALL_TO_ALL])
    def test_ring_attention_custom_transformer(self, rotater: _RotateMethod) -> None:
        """Check collective counts for the context-parallel test ``Transformer``."""
        set_rotate_method(rotater_enum_to_str[rotater])
        self.assertEqual(_cp_options.rotate_method, rotater)
        device_mesh = DeviceMesh(
            self.device_type,
            torch.arange(0, self.world_size),
        )
        dtype = torch.bfloat16
        bs = 2
        args = ModelArgs()

        model = Transformer(args).to(dtype).to(self.device_type)
        model = parallelize_module(
            module=model,
            device_mesh=device_mesh,
            parallelize_plan={
                f"layers.{i}.attention": _AttentionContextParallel()
                for i in range(args.n_layers)
            },
        )

        seq = torch.randint(
            args.vocab_size, (bs, args.max_seq_len), device=self.device_type
        )

        with CommDebugMode() as comm_mode:
            out = model(seq)
        self._assert_ring_comm_counts(
            comm_mode, rotater, args.n_layers, backward=False
        )

        with CommDebugMode() as comm_mode:
            out.sum().backward()
        self._assert_ring_comm_counts(comm_mode, rotater, args.n_layers, backward=True)
# Expand the @parametrize-decorated tests into concrete test methods, but only
# when at least one fused SDPA backend was detected on this platform.
# NOTE(review): presumably parametrizing over an empty backend list is not
# useful -- confirm before removing this guard.
if backends:
    instantiate_parametrized_tests(RingAttentionTest)


# Script entry point: hand control to the PyTorch test runner.
if __name__ == "__main__":
    run_tests()
|