File: main.py

package info (click to toggle)
pytorch-audio 0.7.2-1
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 5,512 kB
  • sloc: python: 15,606; cpp: 1,352; sh: 257; makefile: 21
file content (71 lines) | stat: -rwxr-xr-x 1,984 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#!/usr/bin/env python3
"""Generate or validate TorchScript objects of a specific torchaudio version.

This requires that the corresponding torchaudio (and torch) is installed.
"""
import os
import sys
import argparse


# Default for --base-obj-dir: an ``assets`` folder beside this script. The
# script path is made absolute first, so the result does not depend on the
# current working directory.
_BASE_OBJ_DIR = os.path.join(
    os.path.split(os.path.abspath(__file__))[0],
    'assets',
)


def _parse_args(argv=None):
    """Parse this script's command-line options.

    Args:
        argv: Argument list to parse, excluding the program name. When
            ``None`` (the default), argparse reads ``sys.argv[1:]``, which
            matches the previous zero-argument behavior.

    Returns:
        ``argparse.Namespace`` with ``mode``, ``version`` and ``base_obj_dir``.

    Raises:
        SystemExit: raised by argparse on missing or invalid options.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
    )
    parser.add_argument(
        '--mode', choices=['generate', 'validate'], required=True,
        help=(
            '"generate" generates TorchScript objects of the specific torchaudio '
            'version in the given directory. '
            '"validate" validates if the objects in the given directory are compatible '
            'with the current torchaudio.'
        ),
    )
    parser.add_argument(
        '--version', choices=['0.6.0'], required=True,
        help='torchaudio version.',
    )
    parser.add_argument(
        # Read at call time, so the module-level default is always current.
        '--base-obj-dir', default=_BASE_OBJ_DIR,
        help='Directory where objects are saved/loaded.',
    )
    return parser.parse_args(argv)


def _generate(version, output_dir):
    """Generate TorchScript objects for ``version`` into ``output_dir``.

    The work is delegated to the version-specific module's ``generate``
    (``ver_060`` for 0.6.0).

    Raises:
        ValueError: if ``version`` has no corresponding module.
    """
    if version != '0.6.0':
        raise ValueError(f'Unexpected torchaudio version: {version}')
    # Imported only after the version check, so just the requested
    # version's module has to be importable.
    import ver_060
    ver_060.generate(output_dir)


def _validate(version, input_dir):
    """Validate the TorchScript objects stored in ``input_dir``.

    Validation is delegated to the version-specific module's ``validate``
    (``ver_060`` for 0.6.0).

    Raises:
        ValueError: if ``version`` has no corresponding module.
    """
    # Resolve the handler module first; unknown versions fail before any work.
    if version == '0.6.0':
        import ver_060 as version_module
    else:
        raise ValueError(f'Unexpected torchaudio version: {version}')
    version_module.validate(input_dir)


def _get_obj_dir(base_dir, version):
    """Return the object directory for ``version`` under ``base_dir``.

    The leaf name is ``<version>-py<major>.<minor>``, built from the running
    interpreter (e.g. ``0.6.0-py3.8``). Objects from different Python
    versions therefore go to separate directories.
    """
    major, minor = sys.version_info[:2]
    leaf = '{}-py{}.{}'.format(version, major, minor)
    return os.path.join(base_dir, leaf)


def _main():
    """Script entry point: parse the options and run the requested mode."""
    args = _parse_args()
    obj_dir = _get_obj_dir(args.base_obj_dir, args.version)
    # Each --mode choice maps to the function that implements it; both take
    # (version, directory).
    handlers = {
        'generate': _generate,
        'validate': _validate,
    }
    handler = handlers.get(args.mode)
    if handler is None:
        # argparse already restricts --mode; this guards direct misuse.
        raise ValueError(f'Unexpected mode: {args.mode}')
    handler(args.version, obj_dir)


if __name__ == '__main__':
    _main()