#! /usr/bin/env python3

import os
import shutil
import subprocess
import sys
import tarfile
import tempfile

from six.moves.urllib.request import urlretrieve

from caffe2.python.models.download import downloadFromURLToFile, getURLFromName, deleteDirectory

class SomeClass:
    """Helpers for caching Caffe2 and ONNX model-zoo files locally.

    Caffe2 files are kept under ``~/.caffe2/models/<model>`` and ONNX
    files under ``~/.onnx/models/<model>``.
    """

    # largely copied from
    # https://github.com/onnx/onnx-caffe2/blob/master/tests/caffe2_ref_test.py
    def _download(self, model):
        """Download ``predict_net.pb``, ``init_net.pb`` and ``value_info.json``.

        The files are written into ``self._caffe2_model_dir(model)``, which
        must not exist yet. If any download fails, the model directory is
        handed to ``deleteDirectory`` and the process exits with status 1.
        """
        model_dir = self._caffe2_model_dir(model)
        assert not os.path.exists(model_dir)
        os.makedirs(model_dir)
        for f in ['predict_net.pb', 'init_net.pb', 'value_info.json']:
            url = getURLFromName(model, f)
            dest = os.path.join(model_dir, f)
            try:
                try:
                    downloadFromURLToFile(url, dest,
                                          show_progress=False)
                except TypeError:
                    # show_progress not supported prior to
                    # Caffe2 78c014e752a374d905ecfb465d44fa16e02a28f1
                    # (Sep 17, 2017)
                    downloadFromURLToFile(url, dest)
            except Exception as e:
                print("Abort: {reason}".format(reason=e))
                print("Cleaning up...")
                deleteDirectory(model_dir)
                # sys.exit instead of the site-provided exit() builtin,
                # which does not exist when Python is started with -S.
                sys.exit(1)

    def _caffe2_model_dir(self, model):
        """Return the Caffe2 cache directory ``~/.caffe2/models/<model>``."""
        caffe2_home = os.path.expanduser('~/.caffe2')
        models_dir = os.path.join(caffe2_home, 'models')
        return os.path.join(models_dir, model)

    def _onnx_model_dir(self, model):
        """Return ``(model_dir, models_dir)`` for the ONNX cache.

        ``model_dir`` is ``~/.onnx/models/<model>``. ``models_dir`` is its
        parent, where the ``<model>.tar.gz`` archives are read and written.
        """
        onnx_home = os.path.expanduser('~/.onnx')
        models_dir = os.path.join(onnx_home, 'models')
        model_dir = os.path.join(models_dir, model)
        return model_dir, os.path.dirname(model_dir)

    # largely copied from
    # https://github.com/onnx/onnx/blob/master/onnx/backend/test/runner/__init__.py
    def _prepare_model_data(self, model):
        """Download and unpack the ONNX tarball for *model* unless cached.

        Does nothing if the ONNX model directory already exists. On failure
        the error is printed, the freshly created model directory is
        removed (so a later run retries instead of skipping it), and the
        exception is re-raised.
        """
        model_dir, models_dir = self._onnx_model_dir(model)
        if os.path.exists(model_dir):
            return
        os.makedirs(model_dir)
        url = 'https://s3.amazonaws.com/download.onnx/models/{}.tar.gz'.format(model)

        # On Windows, NamedTemporaryFile cannot be opened for a
        # second time
        download_file = tempfile.NamedTemporaryFile(delete=False)
        try:
            download_file.close()
            print('Start downloading model {} from {}'.format(model, url))
            urlretrieve(url, download_file.name)
            print('Done')
            with tarfile.open(download_file.name) as t:
                # The archive comes from the network: when the interpreter
                # supports extraction filters (PEP 706), reject members
                # that would escape models_dir (absolute paths, '..',
                # links outside the tree, device files).
                if hasattr(tarfile, 'data_filter'):
                    t.extractall(models_dir, filter='data')
                else:
                    t.extractall(models_dir)
        except Exception as e:
            print('Failed to prepare data for model {}: {}'.format(model, e))
            # Without this, the existence check above would treat the
            # half-populated directory as a finished download forever.
            shutil.rmtree(model_dir, ignore_errors=True)
            raise
        finally:
            os.remove(download_file.name)

# Model-zoo entries handled by every step of this script.
# Deliberately excluded for now:
#   squeezenet -- TODO currently onnx can't translate squeezenet :(
#   vgg19      -- TODO doesn't work in the CI environment, possibly due to OOM
models = [
    'bvlc_alexnet',
    'densenet121',
    'inception_v1',
    'inception_v2',
    'resnet50',
    'vgg16',
]

def download_models():
    """Ensure the Caffe2 files and the ONNX data for every model are cached.

    Each cache is fetched only when its directory does not exist yet.
    """
    helper = SomeClass()
    for name in models:
        print('update-caffe2-models.py:  downloading', name)
        if not os.path.exists(helper._caffe2_model_dir(name)):
            helper._download(name)
        onnx_dir = helper._onnx_model_dir(name)[0]
        if not os.path.exists(onnx_dir):
            helper._prepare_model_data(name)

def generate_models():
    """Convert each cached Caffe2 model to ONNX and pack it as <model>.tar.gz.

    Runs ``convert-caffe2-to-onnx`` to write ``model.pb`` into the ONNX
    model directory, then runs ``tar`` in that directory's parent.
    """
    helper = SomeClass()
    for name in models:
        print('update-caffe2-models.py:  generating', name)
        c2_dir = helper._caffe2_model_dir(name)
        out_dir, archive_root = helper._onnx_model_dir(name)
        subprocess.check_call(['echo', name])
        info_path = os.path.join(c2_dir, 'value_info.json')
        with open(info_path, 'r') as fh:
            value_info = fh.read()
        # The JSON text itself (not its path) is passed on the command line.
        convert_cmd = [
            'convert-caffe2-to-onnx',
            '--caffe2-net-name', name,
            '--caffe2-init-net', os.path.join(c2_dir, 'init_net.pb'),
            '--value-info', value_info,
            '-o', os.path.join(out_dir, 'model.pb'),
            os.path.join(c2_dir, 'predict_net.pb'),
        ]
        subprocess.check_call(convert_cmd)
        # Running from the parent directory stores entries under <model>/.
        subprocess.check_call(['tar', '-czf', name + '.tar.gz', name],
                              cwd=archive_root)

def upload_models():
    """Copy every <model>.tar.gz to s3://download.onnx/models/ as public-read."""
    helper = SomeClass()
    for name in models:
        print('update-caffe2-models.py:  uploading', name)
        archive_root = helper._onnx_model_dir(name)[1]
        archive = name + '.tar.gz'
        destination = "s3://download.onnx/models/{}.tar.gz".format(name)
        subprocess.check_call(
            ['aws', 's3', 'cp', archive, destination, '--acl', 'public-read'],
            cwd=archive_root,
        )

def cleanup():
    """Delete the <model>.tar.gz archive next to each ONNX model directory.

    Raises if an archive is missing (``os.remove`` is not guarded).
    """
    helper = SomeClass()
    for name in models:
        archive_root = os.path.dirname(helper._onnx_model_dir(name)[0])
        os.remove(os.path.join(archive_root, name + '.tar.gz'))

if __name__ == '__main__':
    # Dispatch table for the sub-commands accepted as argv[1].
    commands = {
        'download': download_models,
        'generate': generate_models,
        'upload': upload_models,
        'cleanup': cleanup,
    }
    # Validate the command first: previously a missing argument crashed
    # with IndexError and an unknown one silently did nothing.
    if len(sys.argv) < 2 or sys.argv[1] not in commands:
        print('usage: update-caffe2-models.py {download|generate|upload|cleanup}',
              file=sys.stderr)
        sys.exit(2)
    try:
        subprocess.check_call(['aws', 'sts', 'get-caller-identity'])
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: the aws CLI exited non-zero (e.g. no valid
        # credentials). OSError: the aws executable could not be run at all.
        # Narrowed from a bare except so Ctrl-C is no longer swallowed.
        print('update-caffe2-models.py:  please run `aws configure` manually to set up credentials')
        sys.exit(1)
    commands[sys.argv[1]]()
