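# TorchDynamo end-to-end training/evaluation benchmark: fine-tune
# bert-base-cased on yelp_review_full with native PyTorch, repeat the
# same run through a TorchDynamo backend, then compare final losses and
# per-epoch wall-clock time. Typical flags (see parse_args below):
#   --epochs 3 --backend inductor --evaluation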
import argparse
import inspect
import os
import sys
import time
from datetime import timedelta
from datasets import load_dataset, load_metric
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
import torch._dynamo
from torch.utils.data import DataLoader
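# Allow TF32 in matmuls: faster fp32 GEMMs on Ampere and newer GPUs at
# slightly reduced precision.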
torch.backends.cuda.matmul.allow_tf32 = True
# Note: running this end-to-end training/evaluation example downloads around 84 GB of data.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
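
# Tokenize yelp_review_full with the bert-base-cased tokenizer and build
# small train/eval DataLoaders over the first num_samples examples.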
def data_processing(num_samples, batch_size):
    dataset = load_dataset("yelp_review_full")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

    def tokenize_function(examples):
        return tokenizer(examples["text"], padding="max_length", truncation=True)

    tokenized_datasets = dataset.map(tokenize_function, batched=True)
    tokenized_datasets = tokenized_datasets.remove_columns(["text"])
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
    tokenized_datasets.set_format("torch")

    small_train_dataset = tokenized_datasets["train"].select(range(num_samples))
    small_eval_dataset = tokenized_datasets["test"].select(range(num_samples))

    train_dataloader = DataLoader(small_train_dataset, batch_size=batch_size)
    eval_dataloader = DataLoader(small_eval_dataset, batch_size=batch_size)

    return train_dataloader, eval_dataloader
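
# A single forward/backward/optimizer step. This function, rather than the
# model alone, is what gets compiled when a Dynamo backend is selected.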
def training_iter_fn(batch, model, optimizer):
    outputs = model(**batch)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss
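
# Train the model (and optionally evaluate it), either eagerly
# (backend=None) or with the training iteration and the eval forward pass
# wrapped in torch._dynamo.optimize(backend). On PyTorch 2.x,
# torch.compile(..., backend=...) is the equivalent entry point.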
def model_training_evaluation(
    backend, train_dataloader, eval_dataloader, model, optimizer, num_epochs, evaluation
):
    model.to(device)
    model.train()
    loss_history = []
    if not backend:
        # Run with native PyTorch
        opt_training_iter_fn = training_iter_fn
    else:
        # Supported backends: eager, aot_eager, aot_nvfuser and inductor
        opt_training_iter_fn = torch._dynamo.optimize(backend)(training_iter_fn)
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, batch in enumerate(train_dataloader, 0):
            batch = {k: v.to(device) for k, v in batch.items()}
            loss = opt_training_iter_fn(batch, model, optimizer)
            running_loss += loss.item()
            # Record the mean loss over each window of 100 iterations
            if i % 100 == 99:
                loss_history.append(running_loss / 100)
                running_loss = 0.0
    if evaluation:
        metric = load_metric("accuracy")
        model.eval()
        if not backend:
            opt_model = model
        else:
            opt_model = torch._dynamo.optimize(backend)(model)
        for batch in eval_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.no_grad():
                outputs = opt_model(**batch)
            logits = outputs.logits
            predictions = torch.argmax(logits, dim=-1)
            metric.add_batch(predictions=predictions, references=batch["labels"])
        return loss_history, metric.compute()
    else:
        return loss_history, None
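
# The compiled run's trailing mean loss should not exceed the native run's
# by more than a small tolerance.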
def check_loss(ref_loss, res_loss):
    assert len(ref_loss) == len(res_loss)
    length = len(ref_loss)
    x = min(length, 10)
    # Compare the mean of the last x recorded losses, with a 0.1 tolerance
    return sum(res_loss[-x:]) / x <= sum(ref_loss[-x:]) / x + 0.1
def parse_args():
    parser = argparse.ArgumentParser(
        description="TorchDynamo end to end training/evaluation benchmark"
    )
    parser.add_argument(
        "--epochs", type=int, default=10, help="number of epochs to train (default: 10)"
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=1000,
        help="number of samples to train/eval (default: 1000)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=8,
        help="input batch size for training (default: 8)",
    )
    parser.add_argument(
        "--lr", type=float, default=5e-5, help="learning rate (default: 5e-5)"
    )
    parser.add_argument(
        "--backend",
        choices=torch._dynamo.list_backends(exclude_tags=None),
        default="inductor",
        help="train/evaluate model with a given backend (default: inductor)",
    )
    parser.add_argument(
        "--optimizer",
        default="Adam",
        help="train model using a given optimizer (default: Adam)",
    )
    parser.add_argument(
        "--evaluation",
        action="store_true",
        help="run evaluation after model training",
    )
    args = parser.parse_args()
    return args
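
# Benchmark driver: run the training/evaluation loop twice over the same
# model and optimizer, first eagerly as the reference and then with the
# requested backend. Note that the second run continues from the weights
# produced by the first.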
def main():
    args = parse_args()
    train_dataloader, eval_dataloader = data_processing(
        args.num_samples, args.batch_size
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        "bert-base-cased", num_labels=5
    )
    optimizer_cls = getattr(sys.modules["torch.optim"], args.optimizer)
    # Use capturable=True when the optimizer supports it, so the optimizer
    # step can be captured (e.g., by CUDA graphs) under the compiled backend
    if "capturable" in inspect.signature(optimizer_cls).parameters.keys():
        optimizer = optimizer_cls(model.parameters(), lr=args.lr, capturable=True)
    else:
        optimizer = optimizer_cls(model.parameters(), lr=args.lr)
    native_start = time.time()
    ref_loss, accuracy = model_training_evaluation(
        None,
        train_dataloader,
        eval_dataloader,
        model,
        optimizer,
        args.epochs,
        args.evaluation,
    )
    native_end = time.time()
    res_loss, accuracy = model_training_evaluation(
        args.backend,
        train_dataloader,
        eval_dataloader,
        model,
        optimizer,
        args.epochs,
        args.evaluation,
    )
    dynamo_end = time.time()
    if check_loss(ref_loss, res_loss):
        print(
            "[PASSED] TorchDynamo end to end training loss is less than or equal to native PyTorch"
        )
    else:
        print(
            "[FAILED] TorchDynamo end to end training loss is greater than native PyTorch"
        )
    if args.evaluation:
        print(f"Model accuracy: {accuracy}")
    native_elapsed = native_end - native_start
    dynamo_elapsed = dynamo_end - native_end
    print(
        f"Train model on {args.epochs} epochs with backend {args.backend} and optimizer {args.optimizer}:"
    )
    print(f"PyTorch spent {timedelta(seconds=native_elapsed / args.epochs)} per epoch")
    print(
        f"TorchDynamo spent {timedelta(seconds=dynamo_elapsed / args.epochs)} per epoch"
    )


if __name__ == "__main__":
    main()