1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
|
"""
PyTorch version: https://github.com/pytorch/examples/blob/master/mnist/main.py
TensorFlow version: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/mnist/mnist.py
"""
# pip install thinc ml_datasets typer
from thinc.api import Model, chain, Relu, Softmax, Adam
import ml_datasets
from wasabi import msg
from tqdm import tqdm
import typer
def main(
    n_hidden: int = 256,
    dropout: float = 0.2,
    n_iter: int = 10,
    batch_size: int = 128,
):
    # Train a two-hidden-layer network on MNIST and report dev accuracy per epoch.
    #
    #   n_hidden   -- output width (nO) of each of the two Relu layers
    #   dropout    -- dropout rate handed to both Relu layers
    #   n_iter     -- number of passes over the training batches
    #   batch_size -- batch size used for both the training and the dev batches
    #
    # NOTE: documented with comments rather than a docstring on purpose, so
    # that the help text Typer derives for this command stays as it was.

    # Assemble the network: two Relu layers feeding a Softmax output layer.
    first_hidden = Relu(nO=n_hidden, dropout=dropout)
    second_hidden = Relu(nO=n_hidden, dropout=dropout)
    output_layer = Softmax()
    model: Model = chain(first_hidden, second_hidden, output_layer)

    # Fetch MNIST as a (train, dev) pair of (inputs, labels) tuples.
    train_split, dev_split = ml_datasets.mnist()
    train_X, train_Y = train_split
    dev_X, dev_Y = dev_split

    # A handful of examples lets the model fill in any dimensions that were
    # not given explicitly above.
    model.initialize(X=train_X[:5], Y=train_Y[:5])

    # Only the training batches are shuffled.
    train_batches = model.ops.multibatch(
        batch_size, train_X, train_Y, shuffle=True
    )
    dev_batches = model.ops.multibatch(batch_size, dev_X, dev_Y)

    optimizer = Adam(0.001)

    for epoch in range(n_iter):
        # Training pass: forward, backpropagate, then apply the update.
        for inputs, targets in tqdm(train_batches, leave=False):
            guesses, backprop = model.begin_update(inputs)
            # The prediction/target difference is what gets backpropagated.
            d_guesses = guesses - targets
            backprop(d_guesses)
            model.finish_update(optimizer)

        # Evaluation pass: fraction of dev examples whose predicted class
        # (argmax over axis 1) matches the argmax of the label row.
        n_correct = 0
        n_seen = 0
        for inputs, targets in dev_batches:
            guesses = model.predict(inputs)
            predicted_classes = guesses.argmax(axis=1)
            true_classes = targets.argmax(axis=1)
            n_correct += (predicted_classes == true_classes).sum()
            n_seen += guesses.shape[0]
        score = n_correct / n_seen
        msg.row((epoch, f"{score:.3f}"), widths=(3, 5))
# Script entry point: only when this file is executed directly (not imported),
# hand `main` to Typer, which runs it as the command-line program.
if __name__ == "__main__":
    typer.run(main)
|