1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
|
import copy
import click
import torch
class Serializer(torch.nn.Module):
    """Container module that exposes every entry of ``data`` as an attribute.

    For each ``name -> value`` pair in ``data``, the instance gets an
    attribute ``name`` bound to ``value``. In this script the instance is
    passed to ``torch.jit.script(...).save(...)`` to persist the mapping.
    """

    def __init__(self, data):
        super().__init__()
        for name, value in data.items():
            # nn.Module.__setattr__ decides how each value is stored
            # (submodule, parameter, or plain attribute).
            setattr(self, name, value)
@click.command()
@click.option(
    "--input-path",
    type=str,
    # No default: with default="" click considers the option present, so
    # required=True would never fire and "" would reach torch.export.load.
    required=True,
    help="path to the ExportedProgram",
)
@click.option(
    "--output-path",
    type=str,
    required=True,
    help="path where the scripted Serializer archive is saved",
)
def main(
    input_path: str = "",
    output_path: str = "",
) -> None:
    """Package an ExportedProgram and its reference results into one archive.

    Loads the ExportedProgram at ``input_path`` and, using its recorded
    positional example inputs, collects:

    * ``script_module``: ``torch.jit.trace`` of a deep copy of ``ep.module()``;
    * ``model_so_path``: the value returned by ``torch._inductor.aot_compile``;
    * ``inputs``: the positional example inputs, as a list;
    * ``outputs``: the result of running the graph with
      ``torch.fx.Interpreter``, always wrapped as a list.

    The collected values are attached to a ``Serializer`` module, which is
    scripted with ``torch.jit.script`` and saved to ``output_path``.
    """
    ep = torch.export.load(input_path)
    with torch.no_grad():
        # Index 0 selects the positional example inputs.
        # NOTE(review): assumes example_inputs is an (args, kwargs) pair;
        # any kwargs would be ignored here — confirm this is intended.
        example_inputs = ep.example_inputs[0]
        # Scripted (traced) version of the original module.
        module = torch.jit.trace(copy.deepcopy(ep.module()), example_inputs)
        # AOT-compiled version of the module.
        so_path = torch._inductor.aot_compile(ep.module(), example_inputs)
        # Reference outputs from interpreting the FX graph directly.
        # Interpreter.run takes the graph inputs as *args.
        runner = torch.fx.Interpreter(ep.module())
        output = runner.run(*example_inputs)
        outputs = list(output) if isinstance(output, (list, tuple)) else [output]
        data = {
            "script_module": module,
            "model_so_path": so_path,
            "inputs": list(example_inputs),
            "outputs": outputs,
        }
        torch.jit.script(Serializer(data)).save(output_path)
if __name__ == "__main__":
    # Script entry point: main is a click command, so calling it parses the
    # command-line options (--input-path / --output-path) from sys.argv.
    main()
|