1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337
|
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
"""
This module implements Paged Attention on top of flex_attention.
This module is experimental and subject to change.
"""
from typing import Optional, Union
import torch
from torch.nn.attention.flex_attention import (
_identity,
_mask_mod_signature,
_score_mod_signature,
BlockMask,
noop_mask,
)
# Public API of this module: only the PagedAttention class is exported.
__all__ = ["PagedAttention"]
def _cdiv(
x: Union[int, float, torch.Tensor], multiple: Union[int, float, torch.Tensor]
):
return (x + multiple - 1) // multiple
class PagedAttention:
"""
PagedAttention supports flex attention inference with a large batch size.
With PagedAttention, a batch of key/value tensors with varying kv length
is splitted into tensor blocks of fixed length and cached in a compact way.
Thus we can avoid redundant memory consumption due to varying kv length and
support a larger batch size.
"""
def __init__(
self,
n_pages: int,
page_size: int,
max_batch_size: int,
device: str = "cuda",
):
# number of pages
self.n_pages = n_pages
# number of tokens per page
self.page_size = page_size
# page table: [batch, logical_block_idx] -> physical_page_idx
self.page_table = -torch.ones(
(max_batch_size, self.n_pages), dtype=torch.int64, device=device
)
# capacity: batch_idx -> allocated sequence length
self.capacity = torch.zeros(max_batch_size, dtype=torch.int64, device=device)
# index of empty pages that is available for allocation
self.empty_pages = list(range(n_pages - 1, -1, -1))
# mapping from physical page index to logical page index
self.physical_to_logical = -torch.ones(
(max_batch_size, n_pages), dtype=torch.int64, device=device
)
def reserve(self, batch_idx: torch.Tensor, seq_len: torch.Tensor) -> None:
"""
Requests the capacity of a given batch to be at least enough to
hold `seq_len` elements.
Args:
batch_idx (Tensor): batch index to be reserved; shape :math:`(1)`.
seq_len (Tensor): minimum capacity for the given batch; shape :math:`(1)`.
"""
if seq_len <= self.capacity[batch_idx]:
return
num_pages_to_allocate = _cdiv(
seq_len - self.capacity[batch_idx], self.page_size
)
assert len(self.empty_pages) >= num_pages_to_allocate, (
f"requested {num_pages_to_allocate.item()} pages "
f"but there are only {len(self.empty_pages)} empty pages"
)
start_page_idx = self.capacity[batch_idx] // self.page_size
end_page_idx = start_page_idx + num_pages_to_allocate
# find empty physical pages
allocated_pages = torch.tensor(
self.empty_pages[-num_pages_to_allocate:],
device=num_pages_to_allocate.device,
)
self.empty_pages = self.empty_pages[:-num_pages_to_allocate]
# update page table
self.page_table[
batch_idx,
start_page_idx:end_page_idx,
] = allocated_pages
# update metadata
self.physical_to_logical[batch_idx, allocated_pages] = torch.arange(
start_page_idx.item(),
end_page_idx.item(),
device=num_pages_to_allocate.device,
)
self.capacity[batch_idx] += num_pages_to_allocate * self.page_size
def erase(self, batch_idx: torch.Tensor) -> None:
"""
Removes a single batch from paged attention.
Args:
batch_idx (Tensor): batch index to be removed; shape :math:`(1)`.
"""
# find allocated pages
allocated_page_idx = self.page_table[batch_idx] != -1
allocated_pages = self.page_table[batch_idx][allocated_page_idx]
# clean metadata
self.capacity[batch_idx] = 0
self.empty_pages += allocated_pages.tolist()
self.physical_to_logical[batch_idx][:, allocated_pages] = -1
self.page_table[batch_idx] = -1
def assign(
self,
batch_idx: torch.Tensor,
input_pos: torch.Tensor,
k_val: torch.Tensor,
v_val: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
) -> None:
"""
Assigns new contents `val` to the storage `cache` at the location
`batch_idx` and `input_pos`.
Args:
batch_idx (Tensor): batch index; shape :math:`(B)`.
input_pos (Tensor): input positions to be assigned for the given batch; shape :math:`(S)`.
val (Tensor): value to be assigned; shape :math:`(B, H, S, D)`
cache (Tensor): the cache to store the values; shape:`(1, H, MAX_S, D)`
"""
if k_val.requires_grad:
raise RuntimeError("val must not require gradient")
B, H, S, K_D = k_val.shape
V_D = v_val.shape[3]
if B != batch_idx.shape[0]:
raise RuntimeError(
f"Expect val and batch_idx have the same batch size "
f"but got B={B} and B={batch_idx.shape[0]}."
)
if H != k_cache.shape[1]:
raise RuntimeError(
f"Expect val and cache has the same number of heads "
f"but got H={H} and H={k_cache.shape[1]}."
)
if S != input_pos.shape[0]:
raise RuntimeError(
f"Expect val and input_pos has the same length "
f"but got S={S} and S={input_pos.shape[0]}."
)
if K_D != k_cache.shape[3]:
raise RuntimeError(
f"Expect k_val and k_cache has the same hidden dim "
f"but got D={K_D} and D={k_cache.shape[3]}."
)
if V_D != v_cache.shape[3]:
raise RuntimeError(
f"Expect v_val and v_cache has the same hidden dim "
f"but got D={V_D} and D={v_cache.shape[3]}."
)
# find address
logical_block_idx = input_pos // self.page_size # [S]
logical_block_offset = input_pos % self.page_size # [S]
physical_block_idx = self.page_table[batch_idx][:, logical_block_idx] # [B, S]
addr = (
physical_block_idx * self.page_size + logical_block_offset[None, :]
).view(
-1
) # [B*S]
k_val = k_val.permute(1, 0, 2, 3).contiguous().view(1, H, B * S, K_D)
v_val = v_val.permute(1, 0, 2, 3).contiguous().view(1, H, B * S, V_D)
k_cache[:, :, addr, :] = k_val
v_cache[:, :, addr, :] = v_val
def convert_logical_block_mask(
self,
block_mask: BlockMask,
batch_idx: Optional[torch.Tensor] = None,
) -> BlockMask:
"""
Converts a logical block mask by mapping its logical kv indices to the corresponding
physical kv indices.
Args:
block_mask (BlockMask): logical block mask;
kv_indices shape :math:`(B, H, ROWS, MAX_BLOCKS_IN_COL)`.
batch_idx (Tensor): batch index corresponding to the block_mask
batch dimension. This provides flexibility to convert a
block mask with smaller batch size than the page table;
shape :math:`(1)`.
"""
B, H, ROWS, MAX_BLOCKS_IN_COL = block_mask.kv_indices.shape
if block_mask.BLOCK_SIZE[1] != self.page_size:
raise RuntimeError(
f"Expect block_mask has the same column block size as page_size"
f"but got size={block_mask.BLOCK_SIZE[1]} and size={self.page_size}"
)
# Increase the num columns of converted block mask from logical block mask's
# num columns to n_pages, since a) the converted block mask
# may have larger indices values; and b) `_ordered_to_dense` realizes
# a dense tensor with these converted indices. There would be an IndexError
# if using the logical block mask's num columns.
device = block_mask.kv_num_blocks.device
if batch_idx is None:
batch_idx = torch.arange(B, device=device)
page_table = self.page_table[batch_idx]
new_kv_num_blocks = block_mask.kv_num_blocks.clone()
new_kv_indices = torch.zeros(
(B, H, ROWS, self.n_pages), dtype=torch.int32, device=device
)
new_kv_indices[:, :, :, :MAX_BLOCKS_IN_COL] = (
torch.gather(
page_table, 1, block_mask.kv_indices.view(B, -1).to(torch.int64)
)
.view(block_mask.kv_indices.shape)
.to(torch.int32)
)
new_full_kv_indices, new_full_kv_num_blocks = None, None
if block_mask.full_kv_num_blocks is not None:
assert block_mask.full_kv_indices is not None
new_full_kv_num_blocks = block_mask.full_kv_num_blocks.clone()
new_full_kv_indices = torch.zeros(
(B, H, ROWS, self.n_pages), dtype=torch.int32, device=device
)
new_full_kv_indices[:, :, :, :MAX_BLOCKS_IN_COL] = (
torch.gather(
page_table,
1,
block_mask.full_kv_indices.view(B, -1).to(torch.int64),
)
.view(block_mask.full_kv_indices.shape)
.to(torch.int32)
)
new_mask_mod = self.get_mask_mod(block_mask.mask_mod)
seq_lengths = (block_mask.seq_lengths[0], self.n_pages * self.page_size)
return BlockMask.from_kv_blocks(
new_kv_num_blocks,
new_kv_indices,
new_full_kv_num_blocks,
new_full_kv_indices,
block_mask.BLOCK_SIZE,
new_mask_mod,
seq_lengths=seq_lengths,
)
def get_mask_mod(
self, mask_mod: Optional[_mask_mod_signature]
) -> _mask_mod_signature:
"""
Converts a mask_mod based on mapping from the physical block index to the logical
block index.
Args:
mask_mod (_mask_mod_signature): mask_mod based on the logical block index.
"""
if mask_mod is None:
mask_mod = noop_mask
def new_mask_mod(
b: torch.Tensor,
h: torch.Tensor,
q_idx: torch.Tensor,
physical_kv_idx: torch.Tensor,
):
physical_kv_block = physical_kv_idx // self.page_size
physical_kv_offset = physical_kv_idx % self.page_size
logical_block_idx = self.physical_to_logical[b, physical_kv_block]
logical_kv_idx = logical_block_idx * self.page_size + physical_kv_offset
return torch.where(
logical_block_idx >= 0, mask_mod(b, h, q_idx, logical_kv_idx), False
)
return new_mask_mod
def get_score_mod(
self, score_mod: Optional[_score_mod_signature]
) -> _score_mod_signature:
"""
Converts a score_mod based on mapping from the physical block index to the logical
block index.
Args:
score_mod (_score_mod_signature): score_mod based on the logical block index.
"""
if score_mod is None:
score_mod = _identity
def new_score_mod(
score: torch.Tensor,
b: torch.Tensor,
h: torch.Tensor,
q_idx: torch.Tensor,
physical_kv_idx: torch.Tensor,
):
physical_kv_block = physical_kv_idx // self.page_size
physical_kv_offset = physical_kv_idx % self.page_size
logical_block_idx = self.physical_to_logical[b, physical_kv_block]
logical_kv_idx = logical_block_idx * self.page_size + physical_kv_offset
return torch.where(
logical_block_idx >= 0,
score_mod(score, b, h, q_idx, logical_kv_idx),
float("-inf"),
)
return new_score_mod
|