import argparse
import logging
import os
from collections import defaultdict
from datetime import datetime
from time import time
from typing import List

import torch
import torchaudio
from datasets import collate_factory, split_process_dataset
from losses import LongCrossEntropyLoss, MoLLoss
from processing import NormalizeDB
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchaudio.models.wavernn import WaveRNN
from utils import count_parameters, MetricLogger, save_checkpoint


def parse_args():
    """Parse the command-line arguments for WaveRNN training.

    Returns:
        argparse.Namespace: parsed training, dataset, audio-processing and
        model hyperparameters.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--workers",
        default=4,
        type=int,
        metavar="N",
        help="number of data loading workers",
    )
    parser.add_argument(
        "--checkpoint",
        default="",
        type=str,
        metavar="PATH",
        help="path to latest checkpoint",
    )
    parser.add_argument(
        "--epochs",
        default=8000,
        type=int,
        metavar="N",
        help="number of total epochs to run",
    )
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="manual epoch number")
    parser.add_argument(
        "--print-freq",
        default=10,
        type=int,
        metavar="N",
        help="print frequency in epochs",
    )
    parser.add_argument(
        "--dataset",
        default="ljspeech",
        choices=["ljspeech", "libritts"],
        type=str,
        help="select dataset to train with",
    )
    parser.add_argument("--batch-size", default=256, type=int, metavar="N", help="mini-batch size")
    parser.add_argument(
        "--learning-rate",
        default=1e-4,
        type=float,
        metavar="LR",
        help="learning rate",
    )
    parser.add_argument("--clip-grad", metavar="NORM", type=float, default=4.0)
    # NOTE(review): ``action="store_true"`` combined with ``default=True`` makes
    # this flag a no-op -- mulaw is always enabled.  Left unchanged to preserve
    # the existing CLI behavior; consider BooleanOptionalAction if a
    # ``--no-mulaw`` switch is ever needed.
    parser.add_argument(
        "--mulaw",
        default=True,
        action="store_true",
        help="if used, waveform is mulaw encoded",
    )
    parser.add_argument("--jit", default=False, action="store_true", help="if used, model is jitted")
    # BUGFIX: the original used ``type=List[int]``, which is not a callable
    # converter -- argparse raised a TypeError whenever this flag was actually
    # supplied.  ``nargs="+"`` with ``type=int`` accepts e.g.
    # ``--upsample-scales 5 5 11`` and still defaults to [5, 5, 11].
    parser.add_argument(
        "--upsample-scales",
        default=[5, 5, 11],
        nargs="+",
        type=int,
        help="the list of upsample scales",
    )
    parser.add_argument(
        "--n-bits",
        default=8,
        type=int,
        help="the bits of output waveform",
    )
    parser.add_argument(
        "--sample-rate",
        default=22050,
        type=int,
        help="the rate of audio dimensions (samples per second)",
    )
    parser.add_argument(
        "--hop-length",
        default=275,
        type=int,
        help="the number of samples between the starts of consecutive frames",
    )
    parser.add_argument(
        "--win-length",
        default=1100,
        type=int,
        help="the length of the STFT window",
    )
    parser.add_argument(
        "--f-min",
        default=40.0,
        type=float,
        help="the minimum frequency",
    )
    parser.add_argument(
        "--min-level-db",
        default=-100,
        type=float,
        help="the minimum db value for spectrogram normalization",
    )
    parser.add_argument(
        "--n-res-block",
        default=10,
        type=int,
        help="the number of ResBlock in stack",
    )
    parser.add_argument(
        "--n-rnn",
        default=512,
        type=int,
        help="the dimension of RNN layer",
    )
    parser.add_argument(
        "--n-fc",
        default=512,
        type=int,
        help="the dimension of fully connected layer",
    )
    parser.add_argument(
        "--kernel-size",
        default=5,
        type=int,
        help="the number of kernel size in the first Conv1d layer",
    )
    parser.add_argument(
        "--n-freq",
        default=80,
        type=int,
        help="the number of spectrogram bins to use",
    )
    parser.add_argument(
        "--n-hidden-melresnet",
        default=128,
        type=int,
        help="the number of hidden dimensions of resblock in melresnet",
    )
    parser.add_argument(
        "--n-output-melresnet",
        default=128,
        type=int,
        help="the output dimension of melresnet",
    )
    parser.add_argument(
        "--n-fft",
        default=2048,
        type=int,
        help="the number of Fourier bins",
    )
    parser.add_argument(
        "--loss",
        default="crossentropy",
        choices=["crossentropy", "mol"],
        type=str,
        help="the type of loss",
    )
    parser.add_argument(
        "--seq-len-factor",
        default=5,
        type=int,
        help="the length of each waveform to process per batch = hop_length * seq_len_factor",
    )
    parser.add_argument(
        "--val-ratio",
        default=0.1,
        type=float,
        help="the ratio of waveforms for validation",
    )
    parser.add_argument(
        "--file-path",
        default="",
        type=str,
        help="the path of audio files",
    )
    # NOTE(review): same no-op pattern as --mulaw (store_true with default=True);
    # normalization is effectively always on.
    parser.add_argument(
        "--normalization",
        default=True,
        action="store_true",
        help="if True, spectrogram is normalized",
    )

    args = parser.parse_args()
    return args


def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch):
    """Train ``model`` for one epoch over ``data_loader``.

    Logs per-iteration and per-epoch metrics through ``MetricLogger``.

    Args:
        model: the WaveRNN model (possibly wrapped in DataParallel).
        criterion: loss callable applied to (output, target).
        optimizer: optimizer stepping the model parameters.
        data_loader: yields (waveform, specgram, target) batches.
        device: device to move batches to.
        epoch (int): current epoch index, for logging only.

    Note:
        Reads the module-level ``args`` global (set under ``__main__``) for
        the ``clip_grad`` threshold.
    """
    model.train()

    sums = defaultdict(lambda: 0.0)
    epoch_start = time()

    metric = MetricLogger("train_iteration")
    metric["epoch"] = epoch

    for waveform, specgram, target in data_loader:

        iter_start = time()

        waveform = waveform.to(device)
        specgram = specgram.to(device)
        target = target.to(device)

        output = model(waveform, specgram)
        output, target = output.squeeze(1), target.squeeze(1)

        loss = criterion(output, target)
        loss_item = loss.item()
        sums["loss"] += loss_item
        metric["loss"] = loss_item

        optimizer.zero_grad()
        loss.backward()

        if args.clip_grad > 0:
            gradient = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
            # .item() forces a device sync; call it once and reuse.
            gradient_item = gradient.item()
            sums["gradient"] += gradient_item
            metric["gradient"] = gradient_item

        optimizer.step()

        metric["iteration"] = sums["iteration"]
        metric["time"] = time() - iter_start
        metric()
        sums["iteration"] += 1

    avg_loss = sums["loss"] / len(data_loader)

    metric = MetricLogger("train_epoch")
    metric["epoch"] = epoch
    metric["loss"] = avg_loss
    # BUGFIX: the epoch "gradient" metric previously logged the average loss
    # (copy-paste of avg_loss); report the average clipped-gradient norm.
    metric["gradient"] = sums["gradient"] / len(data_loader)
    metric["time"] = time() - epoch_start
    metric()


def validate(model, criterion, data_loader, device, epoch):
    """Evaluate ``model`` on ``data_loader`` and return the average loss.

    Runs under ``torch.no_grad()`` with the model in eval mode, and logs the
    epoch, average loss and wall-clock time via ``MetricLogger``.
    """
    model.eval()
    total_loss = 0.0
    started = time()

    with torch.no_grad():
        for waveform, specgram, target in data_loader:
            waveform = waveform.to(device)
            specgram = specgram.to(device)
            target = target.to(device)

            prediction = model(waveform, specgram).squeeze(1)
            total_loss += criterion(prediction, target.squeeze(1)).item()

    avg_loss = total_loss / len(data_loader)

    logger = MetricLogger("validation")
    logger["epoch"] = epoch
    logger["loss"] = avg_loss
    logger["time"] = time() - started
    logger()

    return avg_loss


def main(args):
    """Build the dataset pipeline, model, and optimizer, then train.

    Optionally resumes from ``args.checkpoint``; validates every
    ``args.print_freq`` epochs and saves the best checkpoint.

    Args:
        args (argparse.Namespace): parsed CLI arguments from ``parse_args``.
    """
    devices = ["cuda" if torch.cuda.is_available() else "cpu"]

    logging.info("Start time: {}".format(str(datetime.now())))

    # Spectrogram parameters shared by the MelSpectrogram transform.
    melkwargs = {
        "n_fft": args.n_fft,
        "power": 1,
        "hop_length": args.hop_length,
        "win_length": args.win_length,
    }

    transforms = torch.nn.Sequential(
        torchaudio.transforms.MelSpectrogram(
            sample_rate=args.sample_rate,
            n_mels=args.n_freq,
            f_min=args.f_min,
            mel_scale="slaney",
            norm="slaney",
            **melkwargs,
        ),
        NormalizeDB(min_level_db=args.min_level_db, normalization=args.normalization),
    )

    train_dataset, val_dataset = split_process_dataset(args, transforms)

    loader_training_params = {
        "num_workers": args.workers,
        "pin_memory": False,
        "shuffle": True,
        "drop_last": False,
    }
    # Validation uses the same loader settings except for shuffling.
    loader_validation_params = loader_training_params.copy()
    loader_validation_params["shuffle"] = False

    collate_fn = collate_factory(args)

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        collate_fn=collate_fn,
        **loader_training_params,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        collate_fn=collate_fn,
        **loader_validation_params,
    )

    # Cross-entropy predicts one class per quantization level; the mixture of
    # logistics loss uses a fixed 30-dimensional output.
    n_classes = 2**args.n_bits if args.loss == "crossentropy" else 30

    model = WaveRNN(
        upsample_scales=args.upsample_scales,
        n_classes=n_classes,
        hop_length=args.hop_length,
        n_res_block=args.n_res_block,
        n_rnn=args.n_rnn,
        n_fc=args.n_fc,
        kernel_size=args.kernel_size,
        n_freq=args.n_freq,
        n_hidden=args.n_hidden_melresnet,
        n_output=args.n_output_melresnet,
    )

    if args.jit:
        model = torch.jit.script(model)

    model = torch.nn.DataParallel(model)
    model = model.to(devices[0], non_blocking=True)

    n = count_parameters(model)
    logging.info(f"Number of parameters: {n}")

    # Optimizer
    optimizer_params = {
        "lr": args.learning_rate,
    }

    optimizer = Adam(model.parameters(), **optimizer_params)

    criterion = LongCrossEntropyLoss() if args.loss == "crossentropy" else MoLLoss()

    # Initial "best" loss; any validation loss below this triggers a
    # best-checkpoint save (overwritten when resuming from a checkpoint).
    best_loss = 10.0

    if args.checkpoint and os.path.isfile(args.checkpoint):
        logging.info(f"Checkpoint: loading '{args.checkpoint}'")
        # BUGFIX: map_location makes a GPU-saved checkpoint resumable on a
        # CPU-only machine (torch.load would otherwise fail deserializing
        # CUDA tensors).
        checkpoint = torch.load(args.checkpoint, map_location=devices[0])

        args.start_epoch = checkpoint["epoch"]
        best_loss = checkpoint["best_loss"]

        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])

        logging.info(f"Checkpoint: loaded '{args.checkpoint}' at epoch {checkpoint['epoch']}")
    else:
        logging.info("Checkpoint: not found")

        # Save an initial checkpoint so the path exists from epoch 0.
        save_checkpoint(
            {
                "epoch": args.start_epoch,
                "state_dict": model.state_dict(),
                "best_loss": best_loss,
                "optimizer": optimizer.state_dict(),
            },
            False,
            args.checkpoint,
        )

    for epoch in range(args.start_epoch, args.epochs):

        train_one_epoch(
            model,
            criterion,
            optimizer,
            train_loader,
            devices[0],
            epoch,
        )

        # Validate (and checkpoint) every `print_freq` epochs and on the
        # final epoch.
        if not (epoch + 1) % args.print_freq or epoch == args.epochs - 1:

            sum_loss = validate(model, criterion, val_loader, devices[0], epoch)

            is_best = sum_loss < best_loss
            best_loss = min(sum_loss, best_loss)
            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "state_dict": model.state_dict(),
                    "best_loss": best_loss,
                    "optimizer": optimizer.state_dict(),
                },
                is_best,
                args.checkpoint,
            )

    logging.info(f"End time: {datetime.now()}")


if __name__ == "__main__":

    # Route INFO-level logs (epoch metrics, checkpoint messages) to stderr.
    logging.basicConfig(level=logging.INFO)
    # NOTE: ``args`` must remain a module-level global -- train_one_epoch
    # reads it directly for the clip_grad threshold.
    args = parse_args()
    main(args)
