Train the Model on the Device
==============================

Once the training artifacts have been generated, the model can be trained on the device using the ONNX Runtime training Python API (``onnxruntime.training.api``).

The expected training artifacts are:

1. The training ONNX model
2. The checkpoint state
3. The optimizer ONNX model
4. The eval ONNX model (optional)
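
Before constructing any of the objects below, it can help to confirm that the artifacts are where you expect them. The following is a minimal sketch that is not part of the ONNX Runtime API: ``TrainingArtifacts`` and ``locate_artifacts`` are hypothetical names, and the default file names are assumptions that you should adjust to match how your artifacts were generated. The three required artifacts must exist, while the eval model is returned as ``None`` when it is absent.

.. code-block:: python

    from dataclasses import dataclass
    from pathlib import Path
    from typing import Optional


    @dataclass(frozen=True)
    class TrainingArtifacts:
        """Hypothetical container for the paths of the on-device training artifacts."""

        training_model: Path
        checkpoint: Path
        optimizer_model: Path
        eval_model: Optional[Path] = None  # the eval model is optional


    def locate_artifacts(
        directory,
        training_model="training_model.onnx",  # assumed file names, adjust as needed
        checkpoint="checkpoint",
        optimizer_model="optimizer_model.onnx",
        eval_model="eval_model.onnx",
    ):
        """Resolve artifact paths under ``directory``; fail if a required one is missing."""
        root = Path(directory)
        required = {
            "training_model": root / training_model,
            "checkpoint": root / checkpoint,
            "optimizer_model": root / optimizer_model,
        }
        missing = [str(path) for path in required.values() if not path.exists()]
        if missing:
            raise FileNotFoundError(f"missing required training artifacts: {missing}")

        # The eval model is optional: report None instead of failing.
        eval_path = root / eval_model
        return TrainingArtifacts(
            **required,
            eval_model=eval_path if eval_path.exists() else None,
        )

The resolved paths can then be passed (as strings) to ``CheckpointState.load_checkpoint``, ``Module`` and ``Optimizer``.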
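Sample usage:

.. code-block:: python

    from onnxruntime.training.api import CheckpointState, Module, Optimizer

    # Paths to the training artifacts (placeholder values; point these at your own files)
    path_to_the_checkpoint_artifact = "checkpoint"
    path_to_the_training_model = "training_model.onnx"
    path_to_the_eval_model = "eval_model.onnx"
    path_to_the_optimizer_model = "optimizer_model.onnx"

    # Load the checkpoint state
    state = CheckpointState.load_checkpoint(path_to_the_checkpoint_artifact)

    # Create the module from the training model, the state and the (optional) eval model
    module = Module(path_to_the_training_model,
                    state,
                    path_to_the_eval_model,
                    device="cpu")

    # Create the optimizer from the optimizer model and the module
    optimizer = Optimizer(path_to_the_optimizer_model, module)

    # Training loop. ``num_epochs``, ``train_batches`` and ``eval_batches`` are placeholders
    # for your own epoch count and iterables of NumPy input batches; the positional
    # arguments passed to ``module`` must match the model's graph inputs.
    for epoch in range(num_epochs):
        module.train()
        for inputs, labels in train_batches:
            training_loss = module(inputs, labels)  # forward and backward pass
            optimizer.step()
            module.lazy_reset_grad()

        # Eval
        module.eval()
        for inputs, labels in eval_batches:
            eval_loss = module(inputs, labels)

    # Save the checkpoint
    CheckpointState.save_checkpoint(state, path_to_the_checkpoint_artifact)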
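
The training loop in the sample can be factored into a small helper. This is a sketch under assumptions: ``run_epoch`` is a hypothetical function, not part of ``onnxruntime.training.api``. It only relies on the calls already used in the sample (``train``, ``eval``, calling the module, and ``lazy_reset_grad`` on the module; ``step`` on the optimizer), assumes each batch is a tuple of positional model inputs, and assumes the model's only output is a scalar loss. The optional ``scheduler`` is assumed to expose a ``step()`` method.

.. code-block:: python

    def run_epoch(module, optimizer, train_batches, eval_batches=(), scheduler=None):
        """Run one training pass and one evaluation pass; return the mean losses.

        Each batch is a tuple of positional inputs matching the model's graph inputs.
        """
        module.train()
        train_losses = []
        for batch in train_batches:
            loss = module(*batch)        # forward and backward pass on the training model
            optimizer.step()             # apply the computed gradients
            if scheduler is not None:
                scheduler.step()         # advance the learning-rate schedule
            module.lazy_reset_grad()     # gradients are reset before the next backward pass
            train_losses.append(float(loss))

        module.eval()
        eval_losses = [float(module(*batch)) for batch in eval_batches]

        def mean(values):
            return sum(values) / len(values) if values else float("nan")

        return mean(train_losses), mean(eval_losses)

Keeping the per-step call order in one place makes it harder to forget ``lazy_reset_grad`` or to step the scheduler before the optimizer.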

.. autoclass:: onnxruntime.training.api.checkpoint_state.Parameter
   :members:
   :show-inheritance:
   :member-order: bysource
   :inherited-members:
   :special-members: __repr__

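
As a hedged sketch, assuming a ``Parameter`` exposes ``requires_grad`` and ``grad`` (a NumPy array, or ``None`` when no gradient is available), the hypothetical helper below computes the global L2 norm of the available gradients, which can serve as a quick sanity check after a training step. It accepts any iterable of parameter-like objects.

.. code-block:: python

    import math

    import numpy as np


    def global_grad_norm(parameters):
        """Return the L2 norm over all available gradients of ``parameters``.

        ``parameters`` is any iterable of Parameter-like objects exposing
        ``requires_grad`` and ``grad`` (a NumPy array, or ``None`` if absent).
        """
        squared_sum = 0.0
        for param in parameters:
            if not param.requires_grad or param.grad is None:
                continue  # frozen parameters and missing gradients are skipped
            squared_sum += float(np.sum(np.square(param.grad, dtype=np.float64)))
        return math.sqrt(squared_sum)

For example, ``global_grad_norm([state.parameters[name] for name in names])``, where ``names`` is your own list of parameter names.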

.. autoclass:: onnxruntime.training.api.checkpoint_state.Parameters
   :members:
   :show-inheritance:
   :member-order: bysource
   :inherited-members:
   :special-members: __getitem__, __setitem__, __contains__, __iter__, __repr__, __len__

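
``Parameters`` behaves like a mapping from parameter names to parameters. The sketch below uses only ``__contains__``, ``__getitem__`` and ``__setitem__`` to snapshot a few named parameters and later restore them. It assumes that ``parameters[name].data`` returns a NumPy array and that assigning a NumPy array through ``__setitem__`` overwrites the stored value; ``snapshot_parameters`` and ``restore_parameters`` are hypothetical helpers.

.. code-block:: python

    import numpy as np


    def snapshot_parameters(parameters, names):
        """Copy the current values of the named parameters into a plain dict."""
        missing = [name for name in names if name not in parameters]  # __contains__
        if missing:
            raise KeyError(f"unknown parameter names: {missing}")
        # Copy the arrays so later in-place updates cannot alter the snapshot.
        return {name: np.array(parameters[name].data, copy=True) for name in names}


    def restore_parameters(parameters, snapshot):
        """Write snapshotted values back (assumes __setitem__ accepts NumPy arrays)."""
        for name, value in snapshot.items():
            parameters[name] = value

A typical use is ``snapshot = snapshot_parameters(state.parameters, names)`` before an experimental update, followed by ``restore_parameters(state.parameters, snapshot)`` to roll it back.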

.. autoclass:: onnxruntime.training.api.checkpoint_state.Properties
   :members:
   :show-inheritance:
   :member-order: bysource
   :inherited-members:
   :special-members: __getitem__, __setitem__, __contains__, __iter__, __repr__, __len__

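
Properties are a convenient place for small pieces of training progress that should travel with the checkpoint. This is a minimal sketch, assuming property values are limited to ``int``, ``float`` and ``str`` and that they are persisted by ``CheckpointState.save_checkpoint``; the property names ``"epoch"`` and ``"best_eval_loss"`` and both helpers are hypothetical.

.. code-block:: python

    def record_progress(properties, epoch, best_eval_loss):
        """Store resumable progress; values are coerced to plain int and float."""
        properties["epoch"] = int(epoch)
        properties["best_eval_loss"] = float(best_eval_loss)


    def resume_epoch(properties):
        """Return the epoch to resume from: one past the saved epoch, or 0 if none was saved."""
        return properties["epoch"] + 1 if "epoch" in properties else 0

Call ``record_progress(state.properties, epoch, best_loss)`` before saving, and ``resume_epoch(state.properties)`` after loading, to pick up where training stopped.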

.. autoclass:: onnxruntime.training.api.CheckpointState
   :members:
   :show-inheritance:
   :member-order: bysource
   :inherited-members:

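
The sample usage saves the checkpoint back over the original artifact. If you would rather keep one checkpoint per epoch and resume from the newest one, a naming scheme like the hypothetical one below can choose the path to pass to ``CheckpointState.load_checkpoint`` and ``CheckpointState.save_checkpoint``. The ``checkpoint_epoch_<N>`` convention and both helpers are assumptions for illustration only.

.. code-block:: python

    import re
    from pathlib import Path

    # Hypothetical naming scheme: <directory>/checkpoint_epoch_<N>
    _CHECKPOINT_PATTERN = re.compile(r"checkpoint_epoch_(\d+)$")


    def checkpoint_path(directory, epoch):
        """Path under which the checkpoint for ``epoch`` would be saved."""
        return Path(directory) / f"checkpoint_epoch_{epoch}"


    def latest_checkpoint(directory, initial_checkpoint):
        """Return the newest per-epoch checkpoint in ``directory``, else the initial artifact."""
        best_epoch, best_path = -1, Path(initial_checkpoint)
        root = Path(directory)
        if root.is_dir():
            for entry in root.iterdir():
                match = _CHECKPOINT_PATTERN.match(entry.name)
                # Compare epochs numerically so that epoch 10 beats epoch 9.
                if match and int(match.group(1)) > best_epoch:
                    best_epoch, best_path = int(match.group(1)), entry
        return best_path

For example, load with ``CheckpointState.load_checkpoint(str(latest_checkpoint(run_dir, path_to_the_checkpoint_artifact)))`` and save with ``CheckpointState.save_checkpoint(state, str(checkpoint_path(run_dir, epoch)))``, where ``run_dir`` is your own output directory.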

.. autoclass:: onnxruntime.training.api.Module
   :members:
   :show-inheritance:
   :member-order: bysource
   :inherited-members:
   :special-members: __call__

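
Calling a ``Module`` passes its arguments positionally to the graph inputs of the active (training or eval) model, so the argument order matters. Assuming ``Module.input_names()`` returns the input names of the currently active model, the hypothetical helper below orders a dict of named arrays accordingly and fails early on a mismatch instead of inside the inference session.

.. code-block:: python

    def ordered_inputs(module, named_inputs):
        """Arrange ``named_inputs`` in the order expected by the module's active model."""
        expected = list(module.input_names())
        missing = [name for name in expected if name not in named_inputs]
        unexpected = [name for name in named_inputs if name not in expected]
        if missing or unexpected:
            raise ValueError(f"input mismatch: missing={missing}, unexpected={unexpected}")
        return [named_inputs[name] for name in expected]

Usage: ``loss = module(*ordered_inputs(module, {"input": x, "labels": y}))``, where the input names are placeholders for those in your own model.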

.. autoclass:: onnxruntime.training.api.Optimizer
   :members:
   :show-inheritance:
   :member-order: bysource
   :inherited-members:

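
The learning rate can also be adjusted by hand between steps. As a sketch, assuming ``Optimizer.set_learning_rate`` and ``Optimizer.get_learning_rate`` are available, the hypothetical ``step_decay`` helper below implements a simple step-decay schedule that multiplies the rate by ``decay_rate`` every ``decay_every`` epochs.

.. code-block:: python

    def step_decay(optimizer, epoch, initial_lr, decay_rate=0.5, decay_every=10):
        """Set the learning rate to ``initial_lr * decay_rate ** (epoch // decay_every)``."""
        new_lr = initial_lr * decay_rate ** (epoch // decay_every)
        optimizer.set_learning_rate(new_lr)
        return new_lr

Call it once at the start of each epoch and check the result with ``optimizer.get_learning_rate()``. A manual schedule like this should not be combined with an ``LinearLRScheduler`` attached to the same optimizer, since both would set the learning rate.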

.. autoclass:: onnxruntime.training.api.LinearLRScheduler
   :members:
   :show-inheritance:
   :member-order: bysource
   :inherited-members:

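
To visualise the shape of a linear schedule, here is a pure-Python sketch of a linear warmup followed by a linear decay to zero. It assumes the scheduler is parameterised by ``warmup_step_count``, ``total_step_count`` and ``initial_lr`` and follows the common warmup/decay formula; ``linear_lr`` is a hypothetical function, not ONNX Runtime's implementation, so consult the class reference above for the exact behaviour.

.. code-block:: python

    def linear_lr(step, warmup_step_count, total_step_count, initial_lr):
        """Learning rate after ``step`` scheduler steps: linear warmup, then linear decay."""
        if step < warmup_step_count:
            # Warmup: ramp linearly from 0 up to initial_lr.
            return initial_lr * step / max(1, warmup_step_count)
        # Decay: ramp linearly from initial_lr down to 0 at total_step_count, then stay at 0.
        remaining = max(0, total_step_count - step)
        return initial_lr * remaining / max(1, total_step_count - warmup_step_count)

With ``warmup_step_count=10``, ``total_step_count=110`` and ``initial_lr=1.0``, the rate reaches 1.0 at step 10, falls to 0.5 at step 60 and to 0.0 at step 110. In a training loop, the scheduler would typically be stepped once right after each ``optimizer.step()``.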