1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39
|
#include <exception>
#include <iostream>
#include <string>

#include <torch/script.h>
int main(int argc, char* argv[]) {
if (argc != 3) {
std::cerr << "Usage: " << argv[0] << " <JIT_OBJECT_DIR> <INPUT_AUDIO_FILE>"
<< std::endl;
return -1;
}
torch::jit::script::Module loader, encoder, decoder;
std::cout << "Loading module from: " << argv[1] << std::endl;
try {
loader = torch::jit::load(std::string(argv[1]) + "/loader.zip");
} catch (const c10::Error& error) {
std::cerr << "Failed to load the module:" << error.what() << std::endl;
return -1;
}
try {
encoder = torch::jit::load(std::string(argv[1]) + "/encoder.zip");
} catch (const c10::Error& error) {
std::cerr << "Failed to load the module:" << error.what() << std::endl;
return -1;
}
try {
decoder = torch::jit::load(std::string(argv[1]) + "/decoder.zip");
} catch (const c10::Error& error) {
std::cerr << "Failed to load the module:" << error.what() << std::endl;
return -1;
}
std::cout << "Loading the audio" << std::endl;
auto waveform = loader.forward({c10::IValue(argv[2])});
std::cout << "Running inference" << std::endl;
auto emission = encoder.forward({waveform});
std::cout << "Generating the transcription" << std::endl;
auto result = decoder.forward({emission});
std::cout << result.toStringRef() << std::endl;
std::cout << "Done." << std::endl;
}
|