# Copyright (c) 2006 Carnegie Mellon University
#
# You may copy and modify this freely under the same terms as
# Sphinx-III
"""Compute MFCC coefficients.

This module provides functions for computing MFCC (mel-frequency
cepstral coefficients) as used in the Sphinx speech recognition
system.
"""

__author__ = "David Huggins-Daines <dhdaines@gmail.com>"
__version__ = "$Revision$"

import numpy
import math
import numpy.fft


def mel(f):
    """Convert a frequency in Hz (scalar or array) to the mel scale."""
    ratio = f / 700.
    return numpy.log10(1. + ratio) * 2595.


def melinv(m):
    """Convert a mel-scale value (scalar or array) back to Hz."""
    decades = m / 2595.
    return (numpy.power(10., decades) - 1.) * 700.


def s2dctmat(nfilt, ncep, freqstep):
    """Build the 'legacy' not-quite-DCT matrix used by Sphinx.

    The result has shape (ncep, nfilt).  Row i is
    cos(pi * i / nfilt * (j + 0.5)) for filter index j, and the first
    column (j == 0) is then halved.  The freqstep argument is accepted
    but not used.
    """
    melcos = numpy.empty((ncep, nfilt), 'double')
    # Filter-bin centres: 0.5, 1.5, ..., nfilt - 0.5
    centres = numpy.arange(0.5, float(nfilt) + 0.5, 1.0, 'double')
    for i, row in enumerate(melcos):
        angular_step = numpy.pi * float(i) / nfilt
        row[:] = numpy.cos(centres * angular_step)
    # The first filter gets half weight in the legacy transform
    melcos[:, 0] *= 0.5
    return melcos


def logspec2s2mfc(logspec, ncep=13):
    """Apply the 'legacy' Sphinx transform to log-power-spectrum bins.

    logspec is a 2-D array with one frame per row; the result has one
    row per frame and ncep columns.
    """
    _, nbins = logspec.shape
    basis = s2dctmat(nbins, ncep, 1. / nbins).T
    projected = numpy.dot(logspec, basis)
    return projected / nbins


def dctmat(N, K, freqstep, orthogonalize=True):
    """Return the orthogonal DCT-II/DCT-III matrix of size NxK.

    For computing or inverting MFCCs, N is the number of
    log-power-spectrum bins while K is the number of cepstra.

    Element [n, k] is cos(freqstep * (n + 0.5) * k).  freqstep is the
    angular step between bins (the callers in this module pass
    pi / N).  When orthogonalize is true, column 0 is additionally
    scaled by 1/sqrt(2).
    """
    # Sample points (n + 0.5) * freqstep, one per row
    points = freqstep * (numpy.arange(N, dtype='double') + 0.5)
    # The outer product with the cepstral indices yields every cosine
    # argument at once, replacing the element-by-element Python loop.
    cosmat = numpy.cos(numpy.outer(points, numpy.arange(K, dtype='double')))
    if orthogonalize:
        cosmat[:, 0] /= numpy.sqrt(2)
    return cosmat


def dct(input, K=13, htk=False):
    """Convert log-power-spectrum to MFCC using the orthogonal DCT-II.

    input is a 2-D array with one frame per row; K is the number of
    cepstra to produce.  With htk set, the result is scaled by
    sqrt(2/N) instead of sqrt(1/N), N being the number of input bins.
    """
    _, nbins = input.shape
    basis = dctmat(nbins, K, numpy.pi / nbins)
    numerator = 2.0 if htk else 1.0
    return numpy.dot(input, basis) * numpy.sqrt(numerator / nbins)


def dct2(input, K=13):
    """Convert log-power-spectrum to MFCC using the normalized DCT-II.

    input is a 2-D array with one frame per row; the DCT matrix is
    built without the 1/sqrt(2) scaling of column 0 and the result is
    multiplied by 2/N, N being the number of input bins.
    """
    _, nbins = input.shape
    basis = dctmat(nbins, K, numpy.pi / nbins, False)
    scale = 2.0 / nbins
    return numpy.dot(input, basis) * scale


def idct(input, K=40):
    """Convert MFCC to log-power-spectrum using the orthogonal DCT-III.

    input is a 2-D array with one frame of cepstra per row; K is the
    number of log-power-spectrum bins to reconstruct.
    """
    _, ncep = input.shape
    basis = dctmat(K, ncep, numpy.pi / K)
    # Transposed so that cepstra (columns of input) map onto K bins
    return numpy.dot(input, basis.T) * numpy.sqrt(2.0 / K)


def dct3(input, K=40):
    """Convert MFCC to log-power-spectrum using the unnormalized DCT-III.

    input is a 2-D array with one frame of cepstra per row; K is the
    number of log-power-spectrum bins to reconstruct.
    """
    _, ncep = input.shape
    basis = dctmat(K, ncep, numpy.pi / K, False)
    # The zeroth cepstral term contributes with half weight
    basis[:, 0] *= 0.5
    return numpy.dot(input, basis.T)


class MFCC(object):
    """Compute log mel spectra and MFCCs from a sampled signal.

    The constructor precomputes the Hamming window, the triangular mel
    filter bank, the two DCT matrices and the liftering weights; the
    ``sig2*`` methods then cut a signal into overlapping frames and
    convert each frame.

    Note that the object is stateful: ``pre_emphasis`` remembers the
    last sample it saw in ``self.prior``.
    """

    def __init__(self,
                 nfilt=40,
                 ncep=13,
                 lowerf=133.3333,
                 upperf=6855.4976,
                 alpha=0.97,
                 samprate=16000,
                 frate=100,
                 wlen=0.0256,
                 nfft=512,
                 transform="legacy",
                 lifter=0):
        """Set up the feature extractor.

        nfilt -- number of mel filters
        ncep -- number of cepstral coefficients per frame
        lowerf, upperf -- lower and upper edge of the filter bank, in Hz
        alpha -- pre-emphasis coefficient
        samprate -- sampling rate, in samples per second
        frate -- frame rate, in frames per second
        wlen -- analysis window length, in seconds
        nfft -- FFT size, in points
        transform -- "legacy" (Sphinx transform) or "dct"
        lifter -- liftering parameter; 0 disables liftering

        Raises ValueError if upperf is above half the sampling rate.
        """
        # Store parameters
        self.lowerf = lowerf
        self.upperf = upperf
        self.nfft = nfft
        self.ncep = ncep
        self.nfilt = nfilt
        self.frate = frate
        # Frame shift in samples (may be fractional)
        self.fshift = float(samprate) / frate

        # Build Hamming window
        self.wlen = int(wlen * samprate)
        self.win = numpy.hamming(self.wlen)

        # Prior sample for pre-emphasis
        self.prior = 0
        self.alpha = alpha

        # Build mel filter matrix
        self.filters = numpy.zeros((nfft // 2 + 1, nfilt), 'd')
        dfreq = float(samprate) / nfft
        if upperf > samprate / 2:
            raise ValueError(
                "Upper frequency %f exceeds Nyquist %f"
                % (upperf, samprate / 2))
        melmax = mel(upperf)
        melmin = mel(lowerf)
        dmelbw = (melmax - melmin) / (nfilt + 1)
        # Filter edges, in Hz
        filt_edge = melinv(melmin +
                           dmelbw * numpy.arange(nfilt + 2, dtype='d'))

        for whichfilt in range(0, nfilt):
            # Filter triangles, in DFT points.  int() guarantees usable
            # array indices even where round() returns a float type.
            leftfr = int(round(filt_edge[whichfilt] / dfreq))
            centerfr = int(round(filt_edge[whichfilt + 1] / dfreq))
            rightfr = int(round(filt_edge[whichfilt + 2] / dfreq))
            # For some reason this is calculated in Hz, though I think
            # it doesn't really matter
            fwidth = (rightfr - leftfr) * dfreq
            height = 2. / fwidth

            if centerfr != leftfr:
                leftslope = height / (centerfr - leftfr)
            else:
                leftslope = 0
            # Rising edge
            freq = leftfr + 1
            while freq < centerfr:
                self.filters[freq, whichfilt] = (freq - leftfr) * leftslope
                freq = freq + 1
            # Peak.  This holds whenever centerfr > leftfr; if the two
            # coincide, freq is already centerfr + 1 and no peak is set.
            if freq == centerfr:
                self.filters[freq, whichfilt] = height
                freq = freq + 1
            if centerfr != rightfr:
                rightslope = height / (centerfr - rightfr)
            else:
                # The falling-edge loop below cannot run in this case
                rightslope = 0
            # Falling edge
            while freq < rightfr:
                self.filters[freq, whichfilt] = (freq - rightfr) * rightslope
                freq = freq + 1

        # Build DCT matrix
        self.s2dct = s2dctmat(nfilt, ncep, 1. / nfilt)
        self.dct = dctmat(nfilt, ncep, numpy.pi / nfilt)

        # Choice of transform
        self.transform = transform

        # Liftering weights
        if lifter:
            self.lifter = (1 + lifter / 2
                           * numpy.sin(numpy.arange(ncep) * numpy.pi / lifter))
        else:
            self.lifter = numpy.ones(ncep)

    def _process_frames(self, sig, frame_func, ncols):
        """Cut sig into frames and apply frame_func to each one.

        Frames are self.wlen samples long and start every self.fshift
        samples.  Frames running past the end of sig are zero-padded.
        Returns an array with one row of ncols values per frame.
        """
        nfr = int(len(sig) / self.fshift + 1)
        out = numpy.zeros((nfr, ncols), 'd')
        for fr in range(nfr):
            start = int(round(fr * self.fshift))
            end = min(len(sig), start + self.wlen)
            frame = numpy.asarray(sig[start:end])
            if len(frame) < self.wlen:
                # numpy.resize would fill the tail with repeated copies
                # of the data, so pad explicitly with zeros instead.
                padded = numpy.zeros(self.wlen, frame.dtype)
                padded[:len(frame)] = frame
                frame = padded
            out[fr] = frame_func(frame)
        return out

    def sig2s2mfc(self, sig):
        """Return the MFCCs of sig, one row of ncep values per frame."""
        return self._process_frames(sig, self.frame2s2mfc, self.ncep)

    def sig2logspec(self, sig):
        """Return the log mel spectrum of sig, one row of nfilt values
        per frame."""
        return self._process_frames(sig, self.frame2logspec, self.nfilt)

    def pre_emphasis(self, frame):
        """Apply the pre-emphasis filter y[i] = x[i] - alpha * x[i-1].

        The sample preceding the frame is taken from self.prior, which
        is then updated to the last sample of this frame.  Returns a
        new array of doubles.
        """
        frame = numpy.asarray(frame, 'd')
        outfr = numpy.empty(len(frame), 'd')
        outfr[0] = frame[0] - self.alpha * self.prior
        outfr[1:] = frame[1:] - self.alpha * frame[:-1]
        # NOTE(review): when consecutive frames overlap (wlen longer
        # than the frame shift) this is not the sample immediately
        # before the next frame's first sample.  Kept as is so that
        # existing feature values do not shift.
        self.prior = frame[-1]
        return outfr

    def frame2logspec(self, frame):
        """Return the log mel spectrum (nfilt values) of one frame of
        self.wlen samples."""
        frame = self.pre_emphasis(frame) * self.win
        fft = numpy.fft.rfft(frame, self.nfft)
        # Square of absolute value
        power = fft.real * fft.real + fft.imag * fft.imag
        # Floor the filter outputs so that the logarithm stays finite
        return numpy.log(numpy.dot(power, self.filters).clip(1e-5, numpy.inf))

    def frame2s2mfc(self, frame):
        """Return the liftered MFCCs (ncep values) of one frame.

        Raises RuntimeError if self.transform is neither "legacy" nor
        "dct".
        """
        logspec = self.frame2logspec(frame)
        if self.transform == "legacy":
            mfcc = numpy.dot(logspec, self.s2dct.T) / self.nfilt
        elif self.transform == "dct":
            mfcc = numpy.dot(logspec, self.dct) * numpy.sqrt(2.0 / self.nfilt)
        else:
            raise RuntimeError("Unknown transform %s" % self.transform)
        return mfcc * self.lifter
