import time
from collections import namedtuple
from typing import Tuple

import torch
import torch.distributed as dist
from utils import dist_utils, metrics

_LG = dist_utils.getLogger(__name__)

Metric = namedtuple("Metric", ["si_snri", "sdri"])
Metric.__str__ = lambda self: f"SI-SNRi: {self.si_snri:10.3e}, SDRi: {self.sdri:10.3e}"
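# For example, str(Metric(12.3, 11.8)) renders as
# "SI-SNRi:  1.230e+01, SDRi:  1.180e+01".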


def si_sdr_improvement(
    estimate: torch.Tensor, reference: torch.Tensor, mix: torch.Tensor, mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the improvement of scale-invariant SDR. (SI-SNRi) and bare SDR (SDRi).

    Args:
        estimate (torch.Tensor): Estimated source signals.
            Shape: [batch, speakers, time frame]
        reference (torch.Tensor): Reference (original) source signals.
            Shape: [batch, speakers, time frame]
        mix (torch.Tensor): Mixed source signals, from which the estimated signals were generated.
            Shape: [batch, speakers == 1, time frame]
        mask (torch.Tensor): Mask indicating padded (0) versus valid (1) samples.
            Shape: [batch, 1, time frame]

    Returns:
        torch.Tensor: SI-SNR improvement (SI-SNRi). Shape: [batch, ]
        torch.Tensor: SDR improvement (SDRi). Shape: [batch, ]

    References:
        - Conv-TasNet: Surpassing Ideal Time-Frequency Magnitude Masking for Speech Separation
          Luo, Yi and Mesgarani, Nima
          https://arxiv.org/abs/1809.07454
    """
    with torch.no_grad():
        # Plain SDRi is reported for monitoring only, so no gradient is needed.
        sdri = metrics.sdri(estimate, reference, mix, mask=mask)

    # Zero-mean the signals along the time axis; SI-SDR is defined on zero-mean signals.
    estimate = estimate - estimate.mean(dim=2, keepdim=True)
    reference = reference - reference.mean(dim=2, keepdim=True)
    mix = mix - mix.mean(dim=2, keepdim=True)

    si_sdri = metrics.sdri(estimate, reference, mix, mask=mask)
    return si_sdri, sdri
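
# A minimal usage sketch for si_sdr_improvement, assuming random tensors with the
# documented shapes and an all-ones mask (i.e. no padding). Illustrative only;
# real inputs come from the data loader batches used by Trainer below.
#
#     batch, speakers, time = 4, 2, 16000
#     estimate = torch.rand(batch, speakers, time)
#     reference = torch.rand(batch, speakers, time)
#     mix = reference.sum(dim=1, keepdim=True)
#     mask = torch.ones(batch, 1, time)
#     si_snri, sdri = si_sdr_improvement(estimate, reference, mix, mask)  # each: [batch]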


class OccasionalLogger:
    """Simple helper class to log once in a while or when progress is quick enough"""

    def __init__(self, time_interval=180, progress_interval=0.1):
        self.time_interval = time_interval
        self.progress_interval = progress_interval

        self.last_time = 0.0
        self.last_progress = 0.0

    def log(self, metric, progress, force=False):
        now = time.monotonic()
        if (
            force
            or now > self.last_time + self.time_interval
            or progress > self.last_progress + self.progress_interval
        ):
            self.last_time = now
            self.last_progress = progress
            _LG.info_on_master("train: %s [%3d%%]", metric, 100 * progress)


class Trainer:
    def __init__(
        self,
        model,
        optimizer,
        train_loader,
        valid_loader,
        eval_loader,
        grad_clip,
        device,
        *,
        debug,
    ):
        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.eval_loader = eval_loader
        self.grad_clip = grad_clip
        self.device = device
        self.debug = debug

    def train_one_epoch(self):
        self.model.train()
        logger = OccasionalLogger()

        num_batches = len(self.train_loader)
        for i, batch in enumerate(self.train_loader, start=1):
            mix = batch.mix.to(self.device)
            src = batch.src.to(self.device)
            mask = batch.mask.to(self.device)

            estimate = self.model(mix)

            si_snri, sdri = si_sdr_improvement(estimate, src, mix, mask)
            si_snri = si_snri.mean()
            sdri = sdri.mean()

            # Maximize SI-SNRi by minimizing its negation.
            loss = -si_snri
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip, norm_type=2.0)
            self.optimizer.step()

            metric = Metric(si_snri.item(), sdri.item())
            logger.log(metric, progress=i / num_batches, force=i == num_batches)

            if self.debug:
                break

    def evaluate(self):
        with torch.no_grad():
            return self._test(self.eval_loader)

    def validate(self):
        with torch.no_grad():
            return self._test(self.valid_loader)

    def _test(self, loader):
        self.model.eval()

        total_si_snri = torch.zeros(1, dtype=torch.float32, device=self.device)
        total_sdri = torch.zeros(1, dtype=torch.float32, device=self.device)

        for batch in loader:
            mix = batch.mix.to(self.device)
            src = batch.src.to(self.device)
            mask = batch.mask.to(self.device)

            estimate = self.model(mix)

            si_snri, sdri = si_sdr_improvement(estimate, src, mix, mask)

            total_si_snri += si_snri.sum()
            total_sdri += sdri.sum()

            if self.debug:
                break

        # Sum the per-worker totals across all distributed workers.
        dist.all_reduce(total_si_snri, dist.ReduceOp.SUM)
        dist.all_reduce(total_sdri, dist.ReduceOp.SUM)

        num_samples = len(loader.dataset)
        metric = Metric(total_si_snri.item() / num_samples, total_sdri.item() / num_samples)
        return metric
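
# A minimal sketch of the expected training flow, assuming an initialized
# torch.distributed process group (required by _test) and data loaders whose
# batches expose mix/src/mask fields. build_model, num_epochs, and the loaders
# are hypothetical placeholders, not part of this module:
#
#     model = build_model().to(device)
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#     trainer = Trainer(
#         model, optimizer, train_loader, valid_loader, eval_loader,
#         grad_clip=5.0, device=device, debug=False,
#     )
#     for epoch in range(num_epochs):
#         trainer.train_one_epoch()
#         _LG.info_on_master("valid: %s", trainer.validate())
#     _LG.info_on_master("eval: %s", trainer.evaluate())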
