#!/usr/bin/env python3
"""Train the HuBERTPretrainModel by using labels generated by KMeans clustering.
Example:
python train.py --dataset-path ./exp/data/mfcc/ --feature-type mfcc --num-classes 100
"""

import logging
import pathlib
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, RawDescriptionHelpFormatter
from typing import Tuple

from lightning.pytorch import seed_everything, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

from lightning_modules import HuBERTPreTrainModule


logger = logging.getLogger(__name__)


class _Formatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter):
    # To use ArgumentDefaultsHelpFormatter as the formatter_class and
    # RawDescriptionHelpFormatter to add custom formatting to description or epilog.
    # Check: https://stackoverflow.com/a/18462760
    pass


def run_train(args):
    seed_everything(1337)
    checkpoint_dir = args.exp_dir / f"checkpoints_{args.dataset}_{args.model_name}"
    checkpoint = ModelCheckpoint(
        checkpoint_dir,
        monitor="val_loss",
        mode="min",
        save_top_k=5,
        save_weights_only=False,
        verbose=True,
    )
    train_checkpoint = ModelCheckpoint(
        checkpoint_dir,
        monitor="train_loss",
        mode="min",
        save_top_k=5,
        save_weights_only=False,
        verbose=True,
    )
    callbacks = [
        checkpoint,
        train_checkpoint,
    ]
    trainer = Trainer(
        default_root_dir=args.exp_dir,
        max_steps=args.max_updates,
        num_nodes=args.num_nodes,
        devices=args.gpus,
        accelerator="gpu",
        strategy="ddp_find_unused_parameters_true",
        use_distributed_sampler=False,
        callbacks=callbacks,
        reload_dataloaders_every_n_epochs=1,
    )

    model = HuBERTPreTrainModule(
        model_name=args.model_name,
        feature_grad_mult=args.feature_grad_mult,
        num_classes=args.num_classes,
        dataset=args.dataset,
        dataset_path=args.dataset_path,
        feature_type=args.feature_type,
        seconds_per_batch=args.seconds_per_batch,
        learning_rate=args.learning_rate,
        betas=args.betas,
        eps=args.eps,
        weight_decay=args.weight_decay,
        clip_norm=args.clip_norm,
        warmup_updates=args.warmup_updates,
        max_updates=args.max_updates,
    )
    trainer.fit(model, ckpt_path=args.resume_checkpoint)


def _parse_args():
    parser = ArgumentParser(
        description=__doc__,
        formatter_class=_Formatter,
    )
    parser.add_argument(
        "--dataset-path",
        type=pathlib.Path,
        required=True,
        help="Path to the feature and label directories.",
    )
    parser.add_argument(
        "--resume-checkpoint",
        type=pathlib.Path,
        default=None,
        help="Path to the feature and label directories. (Default: None)",
    )
    parser.add_argument(
        "--feature-type",
        choices=["mfcc", "hubert"],
        type=str,
        required=True,
    )
    parser.add_argument(
        "--feature-grad-mult",
        default=0.1,
        type=float,
        help="The scaling factor to multiply the feature extractor gradient. (Default: 0.1)",
    )
    parser.add_argument(
        "--num-classes",
        choices=[100, 500],
        type=int,
        required=True,
        help="The ``num_class`` when building the hubert_pretrain_base model.",
    )
    parser.add_argument(
        "--model-name",
        default="hubert_pretrain_base",
        choices=["hubert_pretrain_base", "hubert_pretrain_large", "hubert_pretrain_xlarge"],
        type=str,
        help="The HuBERT model to train. (Default: 'hubert_pretrain_base')",
    )
    parser.add_argument(
        "--exp-dir",
        default=pathlib.Path("./exp"),
        type=pathlib.Path,
        help="Directory to save checkpoints and logs to. (Default: './exp')",
    )
    parser.add_argument(
        "--dataset",
        default="librispeech",
        choices=["librispeech", "librilight"],
        type=str,
        help="The dataset for training. (Default: 'librispeech')",
    )
    parser.add_argument(
        "--learning-rate",
        default=0.0005,
        type=float,
        help="The peak learning rate. (Default: 0.0005)",
    )
    parser.add_argument(
        "--betas",
        default=(0.9, 0.98),
        type=Tuple,
        help="The coefficients for computing running averages of gradient and its square (Default: (0.9, 0.98))",
    )
    parser.add_argument(
        "--eps",
        default=1e-6,
        type=float,
        help="Epsilon value in Adam optimizer. (Default: 1e-6)",
    )
    parser.add_argument(
        "--weight-decay",
        default=0.01,
        type=float,
        help="Weight decay (L2 penalty) (default: 0.01)",
    )
    parser.add_argument(
        "--clip-norm",
        default=10.0,
        type=float,
        help="The gradient norm value to clip. (Default: 10.0)",
    )
    parser.add_argument(
        "--num-nodes",
        default=4,
        type=int,
        help="Number of nodes to use for training. (Default: 4)",
    )
    parser.add_argument(
        "--gpus",
        default=8,
        type=int,
        help="Number of GPUs per node to use for training. (Default: 8)",
    )
    parser.add_argument(
        "--warmup-updates",
        default=32000,
        type=int,
        help="Number of steps for warm up the learning rate. (Default: 32000)",
    )
    parser.add_argument(
        "--max-updates",
        default=250000,
        type=int,
        help="Total number of training steps. (Default: 250000)",
    )
    parser.add_argument(
        "--seconds-per-batch",
        default=87.5,
        type=float,
        help="Number of seconds of audio in a mini-batch. (Default: 87.5)",
    )
    parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
    return parser.parse_args()


def _init_logger(debug):
    fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
    level = logging.DEBUG if debug else logging.INFO
    logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")


def cli_main():
    args = _parse_args()
    _init_logger(args.debug)
    run_train(args)


if __name__ == "__main__":
    cli_main()
