File: generate.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (649 lines) | stat: -rw-r--r-- 18,991 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
import dataclasses
import itertools
import platform
import time
from typing import Optional, Tuple

import torchao
from mixtral_moe_model import ConditionalFeedForward, Transformer as MixtralMoE
from mixtral_moe_quantize import (
    ConditionalFeedForwardInt8,
    WeightOnlyInt8QuantHandler as MixtralMoEWeightOnlyInt8QuantHandler,
)
from model import Transformer as LLaMA
from quantize import WeightOnlyInt8QuantHandler as LLaMAWeightOnlyInt8QuantHandler

import torch
import torch._inductor.config


torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
torch._inductor.config.assert_indirect_indexing = False

compiled = False


@dataclasses.dataclass
class GPTModelConfig:
    """Static description of one benchmark configuration.

    The three ``float`` target fields are reference numbers that are reported
    next to the values measured by ``run_experiment``.
    """

    # Model name handed to ``module.from_name``.
    name: str
    # Model class providing a ``from_name`` constructor (e.g. LLaMA, MixtralMoE).
    module: type
    # Quantization path; this file uses "bfloat16", "int8", "autoquant" and "autoquant_v2".
    mode: Optional[str]
    # Handler called as ``quantizer(model).convert_for_runtime()`` when mode == "int8";
    # None for configurations that never use it (the autoquant ones).
    quantizer: Optional[type]
    # Target decode throughput, tokens/s.
    token_per_sec: float
    # Target achieved memory bandwidth, GB/s.
    memory_bandwidth: float
    # Target compilation (warm-up run) time, seconds.
    compilation_time: float
    # Batch size forwarded to autoquant_v2; unused by the other modes.
    batch_size: Optional[int] = None


def device_sync(device):
    """Wait for outstanding work on ``device`` before host-side timing.

    CUDA devices are synchronized explicitly; CPU is a no-op. Any other device
    type is reported on stdout and skipped rather than raising.

    Args:
        device: a device string such as "cuda", "cuda:0" or "cpu", or a
            ``torch.device`` (normalized with ``str()`` before matching).
    """
    # str() so a torch.device is accepted too; `"cuda" in torch.device(...)`
    # would raise TypeError.
    device_str = str(device)
    if "cuda" in device_str:
        torch.cuda.synchronize(device)
    elif "cpu" in device_str:
        pass
    else:
        print(f"device={device} is not yet supported")


def get_arch_name() -> str:
    """Return a label for the hardware the benchmark runs on.

    The CUDA device name when CUDA is available, otherwise the machine
    architecture string reported by ``platform.machine()``.
    """
    if not torch.cuda.is_available():
        # e.g. x86_64, or arm64 for aarch64 machines.
        return platform.machine()
    return torch.cuda.get_device_name()


def multinomial_sample_one_no_sync(probs_sort):
    """Draw one index per row of ``probs_sort`` without a CUDA sync.

    Dividing the probabilities by i.i.d. Exponential(1) noise and taking the
    argmax samples proportionally to ``probs_sort`` entirely on device, unlike
    ``torch.multinomial``.

    Returns:
        An int32 tensor with the last dimension kept (size 1).
    """
    noise = torch.empty_like(probs_sort).exponential_(1)
    scores = probs_sort / noise
    winner = scores.argmax(dim=-1, keepdim=True)
    return winner.to(dtype=torch.int)


def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Turn logits into a probability distribution over the last dimension.

    Args:
        logits: unnormalized scores; the last dimension is the vocabulary.
        temperature: divisor for the logits, clamped below at 1e-5 so that a
            temperature of 0 behaves like (near-)greedy decoding.
        top_k: if given, every logit strictly smaller than the k-th largest
            (k capped at the vocabulary size) is masked out before softmax.
    """
    scaled = logits / max(temperature, 1e-5)
    if top_k is None:
        return torch.nn.functional.softmax(scaled, dim=-1)

    k = min(top_k, scaled.size(-1))
    # Smallest of the k largest values, kept as a trailing size-1 dim for broadcasting.
    kth_largest = torch.topk(scaled, k).values[..., -1:]
    masked = scaled.masked_fill(scaled < kth_largest, -float("Inf"))
    return torch.nn.functional.softmax(masked, dim=-1)


def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Sample the next token from the last position of the first sequence.

    Only ``logits[0, -1]`` is used (presumably a (batch, seq, vocab) tensor with
    batch size 1 here — TODO confirm against the model classes).

    Returns:
        ``(next_token, probs)``: the sampled index and the distribution it was
        drawn from.
    """
    last_step_logits = logits[0, -1]
    probs = logits_to_probs(last_step_logits, temperature, top_k)
    return multinomial_sample_one_no_sync(probs), probs


def prefill(
    model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
) -> torch.Tensor:
    """Run the whole prompt through ``model`` and sample the first new token.

    ``input_pos`` holds the positions of the tokens in ``x``; generate() passes
    a 1-D ``arange`` over the prompt. Only the sampled token is returned; its
    probabilities are discarded.
    """
    next_token, _probs = sample(model(x, input_pos), **sampling_kwargs)
    return next_token


def decode_one_token(
    model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Advance the model by a single position and sample from its output.

    Returns:
        ``(next_token, probs)`` as produced by ``sample``.
    """
    # Exactly one new position per call (generate() passes a 1-element tensor).
    assert input_pos.shape[-1] == 1
    return sample(model(x, input_pos), **sampling_kwargs)


def decode_n_tokens(
    model: torch.nn.Module,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    **sampling_kwargs,
):
    """Decode ``num_new_tokens`` tokens one step at a time.

    ``input_pos`` is incremented in place after every step, so the caller's
    tensor is advanced too.

    Returns:
        ``(tokens, probs)``: per-step lists of sampled tokens and of their
        sampling distributions.
    """
    tokens, probs = [], []
    # The MATH SDPA backend is actually better for Inductor to codegen attention here.
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
        for _ in range(num_new_tokens):
            token, prob = decode_one_token(
                model, cur_token, input_pos, **sampling_kwargs
            )
            input_pos += 1
            # Copies are kept, presumably because the compiled ("reduce-overhead")
            # decode step may reuse its output buffers on the next call.
            tokens.append(token.clone())
            probs.append(prob.clone())
            cur_token = token.view(1, -1)
    return tokens, probs


@torch.no_grad()
def generate(
    model: torch.nn.Module, prompt: torch.Tensor, max_new_tokens: int, **sampling_kwargs
) -> torch.Tensor:
    """Extend a 1-D ``prompt`` by ``max_new_tokens`` sampled tokens.

    Sets up the model caches for a single sequence, prefills the prompt to get
    the first new token, then decodes the rest one token at a time.

    Args:
        model: model exposing ``config.block_size`` and ``setup_caches``.
        prompt: 1-D tensor of prompt token ids.
        max_new_tokens: number of tokens to append; must be at least 1.
        **sampling_kwargs: forwarded to sampling (``temperature``, ``top_k``).

    Returns:
        A 1-D tensor of length ``len(prompt) + max_new_tokens`` with the same
        dtype/device as ``prompt``: the prompt followed by the new tokens.
    """
    device, dtype = prompt.device, prompt.dtype
    T = prompt.size(0)
    T_new = T + max_new_tokens
    max_seq_length = min(T_new, model.config.block_size)

    with torch.device(device):
        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

    # Preallocate the final output and copy the prompt into its head.
    seq = torch.empty(T_new, dtype=dtype, device=device)
    seq[:T] = prompt
    input_pos = torch.arange(0, T, device=device)

    next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs)
    seq[T] = next_token

    input_pos = torch.tensor([T], device=device, dtype=torch.int)

    generated_tokens, _ = decode_n_tokens(
        model, next_token.view(1, -1), input_pos, max_new_tokens - 1, **sampling_kwargs
    )
    # With max_new_tokens == 1 nothing is decoded, and torch.cat rejects an
    # empty list, so only fill the tail when there is something to write.
    if generated_tokens:
        seq[T + 1 :] = torch.cat(generated_tokens)
    return seq


def _load_model(x: GPTModelConfig, device="cuda", precision=torch.bfloat16):
    """Instantiate the model described by ``x`` and fill it with random weights.

    The model is built on the meta device, cast to ``precision``, optionally
    converted with ``x.quantizer`` (mode "int8"), and then every state-dict
    entry is replaced by a random tensor of the same shape and dtype on
    ``device``. Returns the model in eval mode.
    """
    # Meta-device construction allocates no real storage for the weights.
    with torch.device("meta"):
        model = x.module.from_name(x.name)
    model = model.to(dtype=precision)

    if x.mode == "int8":
        print("Using int8 weight-only quantization!")
        model = x.quantizer(model).convert_for_runtime()

    # Random values instead of a checkpoint — presumably fine because only
    # speed is measured. The same state-dict object is refilled in place.
    weights = model.state_dict()
    for key in list(weights):
        template = weights[key]
        noise = torch.randn(template.shape, device=device)
        weights[key] = torch.nn.Parameter(
            noise.to(dtype=template.dtype),
            requires_grad=template.requires_grad,
        )
    model.load_state_dict(weights, assign=True)
    return model.eval()


# Only count activated parameters and buffers.
def _get_model_size(model):
    model_size = 0
    for name, child in model.named_children():
        if not isinstance(child, torch.nn.Embedding):
            model_size += sum(
                p.numel() * p.dtype.itemsize
                for p in itertools.chain(child.parameters(), child.buffers())
            )

    # Remove the inactivated experts from the model size if this is mixture of experts
    # architecture, since only activated experts are loaded.
    if hasattr(model.config, "num_experts"):
        config = model.config
        for submodule in model.modules():
            if isinstance(
                submodule, (ConditionalFeedForward, ConditionalFeedForwardInt8)
            ):
                model_size -= (
                    sum(
                        p.numel() * p.dtype.itemsize
                        for p in itertools.chain(
                            submodule.parameters(), child.buffers()
                        )
                    )
                    * (config.num_experts - config.num_activated_experts)
                    / config.num_experts
                )

    return model_size


def run_experiment(
    x: GPTModelConfig,
    num_samples: int = 5,
    max_new_tokens: int = 200,
    top_k: int = 200,
    temperature: float = 0.8,
    device: str = "cuda",
) -> Tuple[float, float, Optional[float]]:
    """Benchmark token generation for one model configuration.

    Loads ``x`` with random weights, optionally applies autoquant or
    autoquant_v2, compiles ``prefill`` and ``decode_one_token`` (once per
    process), runs one warm-up generation whose wall time is reported as the
    compilation time, and then ``num_samples`` timed generations.

    Args:
        x: configuration to benchmark.
        num_samples: number of timed generations averaged into the results.
        max_new_tokens: tokens generated per run.
        top_k: top-k filter used while sampling.
        temperature: sampling temperature.
        device: device string, e.g. "cuda" or "cpu".

    Returns:
        ``(tokens_per_sec, memory_bandwidth_gb_s, compilation_time_s)``; the
        first two are means over the timed runs.
    """
    # This function rebinds the module-level step functions to their
    # torch.compile'd versions; generate() looks them up at call time.
    global decode_one_token, prefill, compiled

    print(f"Loading model {x.name}")
    t0 = time.time()
    model = _load_model(x, device=device)
    device_sync(device=device)  # MKG
    print(f"Time to load model: {time.time() - t0:.02f} seconds")

    prompt = torch.tensor(
        [1, 15043, 29892, 590, 1024, 338], device=device, dtype=torch.int32
    )
    prompt_length = prompt.size(0)

    torch.manual_seed(1234)
    model_size = _get_model_size(model)

    aggregate_metrics = {"tokens_per_sec": [], "memory_bandwidth": []}
    start = -1  # iteration `start` is the warm-up / compilation run
    compilation_time = None

    if x.mode == "autoquant":
        print("Using autoquant")
        model = torchao.autoquant(model, manual=True, error_on_unseen=False)
        generate(model, prompt, max_new_tokens, temperature=temperature, top_k=top_k)
        model.finalize_autoquant()

    if x.mode == "autoquant_v2":
        print("Using autoquant_v2")
        from torchao.prototype.quantization.autoquant_v2 import autoquant_v2

        p = prompt.view(1, -1)
        T = prompt.size(0)
        T_new = T + max_new_tokens
        max_seq_length = min(T_new, model.config.block_size)
        input_pos = torch.arange(0, T, device=device)
        example_input = (p, input_pos)

        with torch.device(device):
            model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
        model = autoquant_v2(
            model,
            manual=True,
            error_on_unseen=False,
            example_input=example_input,
            batch_size=x.batch_size,
        )
        torch.compiler.cudagraph_mark_step_begin()
        generate(model, prompt, max_new_tokens, temperature=temperature, top_k=top_k)
        model.finalize_autoquant()

    if not compiled:
        compiled = True
        decode_one_token = torch.compile(
            decode_one_token, mode="reduce-overhead", fullgraph=True
        )
        prefill = torch.compile(prefill, fullgraph=True)

    for i in range(start, num_samples):
        device_sync(device=device)  # MKG

        torch.compiler.cudagraph_mark_step_begin()
        t0 = time.perf_counter()
        y = generate(
            model, prompt, max_new_tokens, temperature=temperature, top_k=top_k
        )
        # Wait for queued device work before reading the clock, for the warm-up
        # run as well as the timed ones.
        device_sync(device=device)  # MKG
        t = time.perf_counter() - t0

        if i == start:
            compilation_time = t
            print(f"Compilation time: {compilation_time:.2f} seconds")
            continue

        tokens_generated = y.size(0) - prompt_length
        tokens_sec = tokens_generated / t
        aggregate_metrics["tokens_per_sec"].append(tokens_sec)
        aggregate_metrics["memory_bandwidth"].append(model_size * tokens_sec / 1e9)

    token_per_sec = torch.mean(torch.tensor(aggregate_metrics["tokens_per_sec"])).item()
    memory_bandwidth = torch.mean(
        torch.tensor(aggregate_metrics["memory_bandwidth"])
    ).item()
    print(f"Average tokens/sec: {token_per_sec:.2f} tokens/sec")
    print(f"Average bandwidth achieved: {memory_bandwidth:.02f} GB/s")
    print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
    return token_per_sec, memory_bandwidth, compilation_time


# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
def run_llama2_7b_bf16(device: str = "cuda"):
    """Benchmark Llama-2-7b-chat-hf in bfloat16; returns Experiment records."""
    model = GPTModelConfig(
        "Llama-2-7b-chat-hf",
        LLaMA,
        "bfloat16",
        LLaMAWeightOnlyInt8QuantHandler,
        94,
        1253,
        133,
    )
    return _run_and_report(model, device)


# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
def run_llama2_7b_int8(device: str = "cuda"):
    """Benchmark Llama-2-7b-chat-hf with int8 weight-only quantization."""
    model = GPTModelConfig(
        "Llama-2-7b-chat-hf",
        LLaMA,
        "int8",
        LLaMAWeightOnlyInt8QuantHandler,
        144,
        957,
        136,
    )
    return _run_and_report(model, device)


# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
def run_mixtral_8x7b_int8(device: str = "cuda"):
    """Benchmark Mixtral-8x7B-v0.1 with int8 weight-only quantization."""
    # We reduced the original number of layers from 32 to 16 to adapt CI memory limitation.
    model = GPTModelConfig(
        "Mixtral-8x7B-v0.1",
        MixtralMoE,
        "int8",
        MixtralMoEWeightOnlyInt8QuantHandler,
        175,
        1130,
        133,
    )
    return _run_and_report(model, device)


# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
def run_llama2_7b_autoquant(device: str = "cuda"):
    """Benchmark Llama-2-7b-chat-hf with torchao autoquant."""
    model = GPTModelConfig(
        "Llama-2-7b-chat-hf",
        LLaMA,
        "autoquant",
        None,
        144,
        957,
        136,
    )
    return _run_and_report(model, device)


# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
def run_mixtral_8x7b_autoquant(device: str = "cuda"):
    """Benchmark Mixtral-8x7B-v0.1 with torchao autoquant."""
    # We reduced the original number of layers from 32 to 16 to adapt CI memory limitation.
    model = GPTModelConfig(
        "Mixtral-8x7B-v0.1",
        MixtralMoE,
        "autoquant",
        None,
        175,
        1130,
        133,
    )
    return _run_and_report(model, device)


# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
def run_llama2_7b_autoquant_v2(device: str = "cuda"):
    """Benchmark Llama-2-7b-chat-hf with torchao autoquant_v2."""
    model = GPTModelConfig(
        "Llama-2-7b-chat-hf",
        LLaMA,
        "autoquant_v2",
        None,
        144,
        957,
        136,
        6,  # batch_size
    )
    return _run_and_report(model, device)


# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
def run_mixtral_8x7b_autoquant_v2(device: str = "cuda"):
    """Benchmark Mixtral-8x7B-v0.1 with torchao autoquant_v2."""
    # We reduced the original number of layers from 32 to 16 to adapt CI memory limitation.
    model = GPTModelConfig(
        "Mixtral-8x7B-v0.1",
        MixtralMoE,
        "autoquant_v2",
        None,
        175,
        1130,
        133,
        6,  # batch_size
    )
    return _run_and_report(model, device)


def _run_and_report(model, device):
    """Run the benchmark for ``model`` on ``device`` and wrap the results.

    Shared by all ``run_*`` entry points. Returns three ``Experiment`` records
    (tokens/sec, memory bandwidth in GB/s, compilation time in s), each pairing
    the target stored on ``model`` with the measured value formatted to two
    decimals.
    """
    # Imported before the run so an unavailable harness fails before the slow benchmark.
    from benchmark import Experiment

    token_per_sec, memory_bandwidth, compilation_time = run_experiment(
        model, device=device
    )
    arch_name = get_arch_name()
    metrics = (
        ("token_per_sec", model.token_per_sec, token_per_sec),
        ("memory_bandwidth(GB/s)", model.memory_bandwidth, memory_bandwidth),
        ("compilation_time(s)", model.compilation_time, compilation_time),
    )
    return [
        Experiment(
            model.name,
            metric,
            target,
            f"{measured:.02f}",
            model.mode,
            device,
            arch_name,
            True,
        )
        for metric, target, measured in metrics
    ]