""" This module contains functions to calculate the quasiparticle tunneling 
currents passing through an SIS junction. 

Description:

    Given the voltages applied across an SIS junction, the quasiparticle 
    tunneling currents can be calculated using multi-tone spectral domain 
    analysis (MTSDA; see references in online docs). 

Note: 

    This code is largely based on P. Kittara's 2002 DPhil thesis (see 
    references in online docs). I include some inline comments to refer to
    specific equations.

    Also, all of the values in this module are normalized, i.e., voltages are
    normalized to the gap voltage, frequencies are normalized to the gap 
    frequency, etc.

"""

from timeit import default_timer as timer

#import numba as nb
import numpy as np
from scipy.special import jv as bessel


# Number of decimal places used when rounding normalized frequencies.
# Both the requested output frequency and every mixing-product frequency are
# rounded to this precision before they are compared for equality, so that
# floating-point round-off does not prevent a match.
ROUND_FREQ = 4


# Determine the dc/ac tunneling currents -------------------------------------

def qtcurrent(vj, cct, resp, freq_list, num_b=15, verbose=True, resp_matrix=None):
    """Calculate the quasiparticle tunneling current.

    This function uses multi-tone spectral domain analysis (MTSDA; see 
    references in online docs). The current is calculated based on the
    voltage applied across the junction (vj).

    Note:

        This function will return the tunneling current for all of the 
        frequencies listed in freq_list (normalized to the gap frequency).
        E.g., to solve for the dc tunneling current and the ac tunneling
        current at 230 GHz, the ``freq_list`` would be ``[0, 230e9 / fgap]``
        where ``fgap`` is the gap frequency.

        Maximum of 4 non-harmonically related tones.

    Args:
        vj (ndarray): Voltage across the SIS junction, indexed as
            ``vj[f, p, i]`` (tone, harmonic, bias voltage point)
        cct (qmix.circuit.EmbeddingCircuit): Embedding circuit
        resp (qmix.respfn.RespFn): Response function
        freq_list (float or iterable): Calculate the tunneling currents for
            these frequencies (normalized to the gap frequency)
        num_b (int/tuple, optional): Summation limits for phase factor
            coefficients, either one value used for every tone or one value
            per tone, default is 15
        verbose (bool, optional): Print info to the terminal if true, default
            is True
        resp_matrix (ndarray, optional): The interpolated response function
            matrix, generated by interpolate_respfn(), default is None. It
            should have been generated with the same ``num_b`` (only its
            number of dimensions and its last axis are checked here).

    Returns:
        ndarray: Quasiparticle tunneling current. If ``freq_list`` is
        iterable, the shape is ``(len(freq_list), npts)``. If it is a single
        value, a 1-D array of length ``npts`` is returned, which is
        real-valued when that frequency is 0 (dc current).

    Raises:
        ValueError: If the embedding circuit does not have 1, 2, 3 or 4 tones.

    """

    # Number of tones -> function that calculates the tunneling current
    current_functions = {
        1: _current_1_tone,
        2: _current_2_tones,
        3: _current_3_tones,
        4: _current_4_tones,
    }

    # Load, prepare and check input data -------------------------------------

    num_f = cct.num_f   # number of frequencies
    num_p = cct.num_p   # number of harmonics
    npts = cct.vb_npts  # number of bias voltages

    assert cct.freq[1:].min() > 0., "All freq must be > 0!"

    if num_f not in current_functions:
        raise ValueError("num_f must be 1, 2, 3 or 4!")

    # Accept either a single frequency or an iterable of frequencies
    try:
        freq_list = list(freq_list)
        freq_is_list = True
    except TypeError:
        freq_list = [float(freq_list)]
        freq_is_list = False

    # Round the output frequencies the same way the mixing products are
    # rounded, so that the two can be compared for equality
    freq_list = [round(freq_val, ROUND_FREQ) for freq_val in freq_list]
    freq_npts = len(freq_list)

    freq = cct.freq

    # One summation limit per tone. The same tuple is used for the phase
    # factor coefficients, the response function matrix and the current
    # functions, so all three always agree (this also accepts NumPy integers).
    nb_list = _unpack_num_b(num_b, num_f)

    if verbose:
        print("Calculating tunneling current...")
        print(" - {0} tone(s)".format(num_f))
        print(" - {0} harmonic(s)".format(num_p))
        start_time = timer()

    # Convolution coefficients ------------------------------------------------

    ccc = calculate_phase_factor_coeff(vj, freq, num_f, num_p, nb_list)

    # Interpolate response function ------------------------------------------

    if resp_matrix is None:
        resp_matrix = interpolate_respfn(cct, resp, nb_list)
    else:
        assert resp_matrix.ndim == num_f + 1
        assert resp_matrix.shape[-1] == npts

    # Calculate the current at each requested frequency -----------------------

    current_function = current_functions[num_f]
    current_out = np.zeros((freq_npts, npts), dtype=complex)
    for i, freq_out in enumerate(freq_list):
        current_out[i] = current_function(freq_out, ccc, freq, resp_matrix,
                                          num_p, npts, *nb_list)

    # Done --------------------------------------------------------------------

    if verbose:
        print("Done.")
        print("Time: {0:.4f} s\n".format(timer() - start_time))

    if freq_is_list:
        return current_out

    # Single frequency: the dc current is real-valued
    if freq_list[0] == 0.:
        return current_out[0].real
    return current_out[0]


# Response function matrices --------------------------------------------------
# The response function (the dc I-V curve and its Kramers-Kronig (KK)
# transform) needs to be repeatedly interpolated in this module. The functions
# below do all of the necessary interpolations at once to save time.
#
# Note: Interpolating the response function is one of the most time-consuming
# operations within this module.
#
# Two different methods are used below to generate the interpolation voltages:
#    - one using NumPy broadcasting (1 and 2 tones)
#    - one using explicit loops (3 and 4 tones)
# I've spent some time optimizing each and choosing the appropriate method for
# each number of tones.
#
# Runs at most once per qtcurrent function call (skipped when a pre-computed
# resp_matrix is passed in).

# TODO: write better tests for this function
def interpolate_respfn(cct, resp, num_b):
    """Interpolate the response function at all necessary voltages.

    I have included this as a stand-alone function because if you are going
    to be running ``qtcurrent`` over and over again with the same input 
    signal frequencies, it can save time by pre-interpolating the response 
    function.

    Args:
        cct (qmix.circuit.EmbeddingCircuit): Embedding circuit
        resp (qmix.respfn.RespFn): Response function
        num_b (int/tuple): Summation limits for phase factor coefficients

    Returns:
        ndarray: The interpolated response function as a matrix.

    Raises:
        ValueError: If ``cct.num_f`` is not 1, 2, 3 or 4.

    """

    nb_list = _unpack_num_b(num_b, cct.num_f)

    # Number of tones -> interpolation routine
    interpolators = {
        1: _interpolate_respfn_1_tone,
        2: _interpolate_respfn_2_tone,
        3: _interpolate_respfn_3_tone,
        4: _interpolate_respfn_4_tone,
    }
    try:
        interpolate = interpolators[cct.num_f]
    except KeyError:
        # Carry the message in the exception instead of printing it
        raise ValueError("num_f must be 1, 2, 3 or 4!") from None

    return interpolate(resp, cct.vb, cct.freq, *nb_list)


def _interpolate_respfn_1_tone(resp, vb, freq, num_b1):
    """Interpolate the response function for a single tone.

    The response function is evaluated once, on a 2-D voltage grid whose row
    ``k`` (``-num_b1 <= k <= num_b1``; negative ``k`` wraps around like
    ordinary NumPy indexing) is ``vb + k * freq[1]``.

    Args:
        resp (qmix.respfn.RespFn): Response function
        vb (ndarray): Bias voltages, normalized
        freq (ndarray): Frequencies, normalized
        num_b1 (int): Summation limits for phase factor coefficients

    Returns:
        ndarray: Response function, interpolated

    """

    # Offsets in the order 0, 1, ..., num_b1, -num_b1, ..., -1 so that a
    # negative offset k lands on row k (counted from the end)
    offsets = np.concatenate((np.arange(num_b1 + 1), np.arange(-num_b1, 0)))

    # One float copy of the bias voltages per offset, shape (n_offsets, npts)
    bias = np.ones((offsets.size, 1)) * vb[np.newaxis, :]
    shift = offsets[:, np.newaxis] * freq[1]

    return resp(bias + shift)


def _interpolate_respfn_2_tone(resp, vb, freq, num_b1, num_b2):
    """Interpolate the response function for two tones.

    The response function is evaluated once, on a 3-D voltage grid whose
    element ``[k, l]`` is ``vb + k * freq[1] + l * freq[2]`` (negative ``k``
    and ``l`` wrap around like ordinary NumPy indexing).

    Args:
        resp (qmix.respfn.RespFn): Response function
        vb (ndarray): Bias voltages, normalized
        freq (ndarray): Frequencies, normalized
        num_b1 (int): Summation limits for phase factor coefficients (tone 1)
        num_b2 (int): Summation limits for phase factor coefficients (tone 2)

    Returns:
        ndarray: Response function, interpolated

    """

    # Offsets ordered 0, 1, ..., N, -N, ..., -1 for each tone
    k_offsets = np.concatenate((np.arange(num_b1 + 1), np.arange(-num_b1, 0)))
    l_offsets = np.concatenate((np.arange(num_b2 + 1), np.arange(-num_b2, 0)))

    # Float copy of the bias voltages for every (k, l) pair
    bias = np.ones((k_offsets.size, l_offsets.size, 1)) * vb[np.newaxis, np.newaxis, :]

    # Per-tone voltage shifts, broadcast along the other axes
    k_shift = k_offsets[:, np.newaxis, np.newaxis] * freq[1]
    l_shift = l_offsets[np.newaxis, :, np.newaxis] * freq[2]

    return resp(bias + k_shift + l_shift)


def _interpolate_respfn_3_tone(resp, vb, freq, num_b1, num_b2, num_b3):
    """Interpolate the response function for three tones.

    The voltage grid element ``[k, l, m]`` (negative indices wrap around like
    ordinary NumPy indexing) is ``vb + k*freq[1] + l*freq[2] + m*freq[3]``;
    the response function is then evaluated on the whole grid in one call.

    Args:
        resp (qmix.respfn.RespFn): Response function
        vb (ndarray): Bias voltages, normalized
        freq (ndarray): Frequencies, normalized
        num_b1 (int): Summation limits for phase factor coefficients (tone 1)
        num_b2 (int): Summation limits for phase factor coefficients (tone 2)
        num_b3 (int): Summation limits for phase factor coefficients (tone 3)

    Returns:
        ndarray: Response function, interpolated

    """

    grid = np.zeros((2 * num_b1 + 1, 2 * num_b2 + 1, 2 * num_b3 + 1, len(vb)))

    # Partial sums are built up left to right, one tone per loop level
    for k in range(-num_b1, num_b1 + 1):
        v_k = vb + k * freq[1]
        for l in range(-num_b2, num_b2 + 1):
            v_kl = v_k + l * freq[2]
            for m in range(-num_b3, num_b3 + 1):
                grid[k, l, m] = v_kl + m * freq[3]

    return resp(grid)


def _interpolate_respfn_4_tone(resp, vb, freq, num_b1, num_b2, num_b3, num_b4):
    """Interpolate the response function (4 tones).

    The voltage grid element ``[k, l, m, n]`` (each index running over
    ``[-num_bX, num_bX]``, with negative values wrapping around like ordinary
    NumPy indexing) is
    ``vb + k*freq[1] + l*freq[2] + m*freq[3] + n*freq[4]``. The response
    function is then evaluated on the whole grid in one call.

    Args:
        resp (qmix.respfn.RespFn): Response function
        vb (ndarray): Bias voltages, normalized
        freq (ndarray): Frequencies, normalized
        num_b1 (int): Summation limits for phase factor coefficients (tone 1)
        num_b2 (int): Summation limits for phase factor coefficients (tone 2)
        num_b3 (int): Summation limits for phase factor coefficients (tone 3)
        num_b4 (int): Summation limits for phase factor coefficients (tone 4)

    Returns:
        ndarray: Response function, interpolated

    """

    npts = len(vb)
    voltage = np.zeros((num_b1 * 2 + 1, num_b2 * 2 + 1, num_b3 * 2 + 1, num_b4 * 2 + 1, npts))
    for k in range(-num_b1, num_b1 + 1):
        # Partial sums are hoisted out of the inner loops. The additions are
        # still performed left to right, so the values are unchanged.
        v_k = vb + k * freq[1]
        for l in range(-num_b2, num_b2 + 1):
            v_kl = v_k + l * freq[2]
            for m in range(-num_b3, num_b3 + 1):
                v_klm = v_kl + m * freq[3]
                for n in range(-num_b4, num_b4 + 1):
                    voltage[k, l, m, n, :] = v_klm + n * freq[4]
    resp_out = resp(voltage)

    return resp_out


# Calculate the overall phase factor spectrum coefficients -------------------

def calculate_phase_factor_coeff(vj, freq, num_f, num_p, num_b):
    """Calculate the overall phase factor spectrum coefficients.

    Runs once per qtcurrent function call.

    Eqns. 5.7 and 5.12 in Kittara's thesis.

    Args:
        vj (ndarray): Voltage across the SIS junction, indexed as
            ``vj[f, p, i]`` (tone, harmonic, bias voltage point)
        freq (ndarray): Frequencies, indexed by tone (``freq[f]``)
        num_f (int): Number of non-harmonically related frequencies
        num_p (int): Number of harmonics
        num_b (int/tuple): Summation limits for phase factor coefficients.
            Either a single integer (Python or NumPy) used for every tone, or
            a sequence with one integer per tone (``num_b[0]`` for tone 1).

    Returns:
        ndarray: Phase factor spectrum coefficients (C_k(H) in Kittara),
        indexed as ``ckh[f, k, i]``; negative ``k`` wraps around the
        ``2 * max(num_b) + 1`` long axis

    """

    # Number of bias voltage points
    npts = len(vj[0, 0, :])

    # Summation limits for phase factor coefficients, one per tone.
    # np.ndim() also recognizes NumPy integer scalars, which are not
    # instances of int (so e.g. num_b=np.int64(15) works).
    if np.ndim(num_b) == 0:
        num_b = (int(num_b),) * num_f
    else:
        num_b = tuple(int(nb) for nb in num_b)

    # Junction drive level:
    # alpha[f, p, i] in R^(num_f+1)(num_p+1)(npts)
    # Eqn. 5.5 in Kittara's thesis
    alpha = np.zeros_like(vj, dtype=float)
    for f in range(1, num_f + 1):
        for p in range(1, num_p + 1):
            alpha[f, p, :] = np.abs(vj[f, p, :]) / (p * freq[f])

    # Junction voltage phase:
    # phi[f, p, i] in R^(num_f+1)(num_p+1)(npts)
    phi = np.angle(vj)  # in radians

    # Complex coefficients from the Jacobi-Anger equality:
    # jac[f, p, n, i] in C^(num_f+1)(num_p+1)(num_b*2+1)(npts)
    # Equation 5.7 in Kittara's thesis
    # Note: This chunk of code dominates the computation time of this
    # function, so all harmonics of a tone are evaluated in one vectorized
    # call per Bessel order. (The recurrence relation was tried before, but
    # ran into numerical errors.)
    jac = np.zeros((num_f + 1, num_p + 1, max(num_b) * 2 + 1, npts), dtype=complex)
    harm = slice(1, num_p + 1)  # harmonics 1 ... num_p
    for f in range(1, num_f + 1):
        alpha_f = alpha[f, harm]
        phi_f = phi[f, harm]
        jac[f, harm, 0] = bessel(0, alpha_f)
        for n in range(1, num_b[f - 1] + 1):
            jac[f, harm, n] = bessel(n, alpha_f) * np.exp(-1j * n * phi_f)
            # Bessel identity J_-n(x) = (-1)^n J_n(x)
            jac[f, harm, -n] = (-1) ** n * np.conj(jac[f, harm, n])

    # Overall phase factor coefficients:
    # ckh[f, k, i] in C^(num_f+1)(num_b*2+1)(npts)
    ckh = _convolve_coefficients(jac)

    return ckh


#@nb.njit("c16[:,:,:](c16[:,:,:,:])")
def _convolve_coefficients(jac):  # pragma: no cover
    """Convolve the per-harmonic spectrum coefficients (recursively).

    See Withington and Kollberg, 1989, and Eqn. 5.12 in Kittara's thesis.

    The result starts from the coefficients of the fundamental (harmonic 1).
    Each higher harmonic ``p`` is then folded in, one at a time. Terms whose
    index falls outside ``[-num_b, num_b]`` are dropped. With only one
    harmonic, the harmonic-1 coefficients are returned unchanged.
    Calculation time grows with the number of harmonics.

    Args:
        jac (ndarray): Complex coefficients from the Jacobi-Anger equality
            (Eqn. 5.7 in Kittara's thesis), indexed ``jac[f, p, n, i]``

    Returns:
        ndarray: Overall phase factor spectrum coefficients, indexed
        ``ckh[f, k, i]``

    """

    _, n_harm_axis, n_coeff_axis, _ = jac.shape
    n_harm = n_harm_axis - 1              # number of harmonics
    n_bes = (n_coeff_axis - 1) // 2       # number of bessel functions

    result = jac[:, 1, :, :]
    if n_harm == 1:
        return result

    for p in range(2, n_harm + 1):
        accum = np.zeros_like(result)
        for k in range(-n_bes, n_bes + 1):
            # Keep only l with -n_bes <= k - p*l <= n_bes, i.e.
            # ceil((k - n_bes) / p) <= l <= floor((k + n_bes) / p)
            lo = max(-n_bes, -((n_bes - k) // p))
            hi = min(n_bes, (k + n_bes) // p)
            for l in range(lo, hi + 1):
                accum[1:, k] += result[1:, k - p * l] * jac[1:, p, l]
        result = accum

    return result


# Tunneling current functions ------------------------------------------------
# These are the functions that actually calculate the tunneling current.
# Different functions are provided for different numbers of tones. They are
# all built the same way except that every additional tone will add another
# layer of coefficients and for-loops.
# TODO: optimize further, vectorize for loops (?)
# TODO: write general function, for any number of tones

def _current_1_tone(freq_out, ccc, freq, resp_matrix, num_p, npts, num_b1):
    """Calculate the tunneling current at one output frequency (one tone).

    Every harmonic ``a * freq[1]`` with ``a`` in ``[-num_p, num_p]`` is
    rounded to ``ROUND_FREQ`` decimal places and compared with the (equally
    rounded) ``freq_out``. The current coefficients of all matching ``a`` are
    summed. Frequencies are normalized to the gap frequency.

    Args:
        freq_out (float): frequencies of output values
        ccc (ndarray): convolution coefficients
        freq (ndarray): frequencies
        resp_matrix (ndarray): Response function matrix, generated by 
            interpolate_respfn
        num_p (int): Number of harmonics
        npts (int): Number of bias voltage points
        num_b1 (int): Summation limits for phase factor coefficients (tone 1)

    Returns:
        ndarray: Tunneling current at specified frequency

    """

    target = round(freq_out, ROUND_FREQ)
    total = np.zeros(npts, dtype=complex)

    # Harmonic indices are visited from +num_p down to -num_p
    for a in reversed(range(-num_p, num_p + 1)):
        if round(a * freq[1], ROUND_FREQ) != target:
            continue
        total += _current_coeff_1_tone(a, ccc, resp_matrix, num_b1, npts)

    return total


#@nb.njit("c16[:](i4, c16[:,:,:], c16[:,:], i4, i4)")
def _current_coeff_1_tone(a, ccc, resp_matrix, num_b1, npts):  # pragma: no cover
    """Calculate the tunneling current coefficient I(a) for one tone.

    Equations 5.25 and 5.26 in Kittara's thesis.

    Args:
        a (int): Index a in Eqn. 5.25
        ccc (ndarray): Convolution coefficient
        resp_matrix (ndarray): Response function matrix, generated by 
            interpolate_respfn
        num_b1 (int): Summation limits for phase factor coefficients (tone 1)
        npts (int): Number of bias voltage points

    Returns:
        ndarray: Tunneling current coefficient

    """

    coeff = ccc[1]
    coeff_conj = np.conj(coeff)

    # Equation 5.17: sums for the +a and -a offsets
    sum_plus = np.zeros(npts, dtype=np.complex128)
    sum_minus = np.zeros(npts, dtype=np.complex128)
    for k in range(-num_b1, num_b1 + 1):
        if abs(k + a) <= num_b1:
            sum_plus += coeff[k] * coeff_conj[k + a] * resp_matrix[k]
        if abs(k - a) <= num_b1:
            sum_minus += coeff[k] * coeff_conj[k - a] * resp_matrix[k]

    # Current coefficient (equation 5.26); the dc term uses only the + sum
    if a == 0:
        return sum_plus.imag + 1j * 0
    return (sum_plus.imag + sum_minus.imag) - 1j * (sum_plus.real - sum_minus.real)


def _current_2_tones(freq_out, ccc, freq, resp_matrix, num_p, npts, num_b1, num_b2):
    """Calculate the tunneling current at one output frequency (two tones).

    Every mixing product ``a * freq[1] + b * freq[2]`` with ``a`` and ``b`` in
    ``[-num_p, num_p]`` is rounded to ``ROUND_FREQ`` decimal places and
    compared with the (equally rounded) ``freq_out``. The current
    coefficients of all matching ``(a, b)`` pairs are summed. Frequencies are
    normalized to the gap frequency.

    Args:
        freq_out (float): frequency to solve for
        ccc (ndarray): convolution coefficients
        freq (ndarray): frequencies
        resp_matrix (ndarray): Response function matrix, generated by 
            interpolate_respfn
        num_p (int): Number of harmonics
        npts (int): Number of bias voltage points
        num_b1 (int): Summation limits for phase factor coefficients (tone 1)
        num_b2 (int): Summation limits for phase factor coefficients (tone 2)

    Returns:
        ndarray: Tunneling current at specified frequency

    """

    target = round(freq_out, ROUND_FREQ)
    total = np.zeros(npts, dtype=complex)

    harmonics = range(num_p, -(num_p + 1), -1)  # +num_p down to -num_p
    for a in harmonics:
        f_a = a * freq[1]
        for b in harmonics:
            if round(f_a + b * freq[2], ROUND_FREQ) == target:
                total += _current_coeff_2_tones(a, b, ccc, resp_matrix,
                                                num_b1, num_b2, npts)

    return total


#@nb.njit("c16[:](i4, i4, c16[:,:,:], c16[:,:,:], i4, i4, i4)")
def _current_coeff_2_tones(a, b, ccc, resp_matrix, num_b1, num_b2, npts):  # pragma: no cover
    """Calculate the tunneling current coefficient I(a,b) for two tones.

    Only the requested (a, b) combination is evaluated, rather than the
    whole coefficient matrix for every (a, b) pair.

    Equations 5.25 and 5.26 in Kittara's thesis.

    Args:
        a (int): Index a in Eqn. 5.25
        b (int): Index b in Eqn. 5.25
        ccc (ndarray): Convolution coefficient
        resp_matrix (ndarray): Response function matrix, generated by 
            interpolate_respfn
        num_b1 (int): Summation limits for phase factor coefficients (tone 1)
        num_b2 (int): Summation limits for phase factor coefficients (tone 2)
        npts (int): Number of bias voltage points

    Returns:
        ndarray: Tunneling current coefficient

    """

    coeff1, coeff2 = ccc[1], ccc[2]
    conj1, conj2 = np.conj(coeff1), np.conj(coeff2)

    # Equation 5.25: sums for the +(a, b) and -(a, b) offsets
    sum_plus = np.zeros(npts, dtype=np.complex128)
    sum_minus = np.zeros(npts, dtype=np.complex128)
    for k in range(-num_b1, num_b1 + 1):
        for l in range(-num_b2, num_b2 + 1):
            response = resp_matrix[k, l]

            if abs(k + a) <= num_b1 and abs(l + b) <= num_b2:
                sum_plus += (coeff1[k] * conj1[k + a] *
                             coeff2[l] * conj2[l + b] *
                             response)

            if abs(k - a) <= num_b1 and abs(l - b) <= num_b2:
                sum_minus += (coeff1[k] * conj1[k - a] *
                              coeff2[l] * conj2[l - b] *
                              response)

    # Current coefficient (equation 5.26); the dc term uses only the + sum
    if a == 0 and b == 0:
        return sum_plus.imag + 1j * 0.
    return (sum_plus.imag + sum_minus.imag) - 1j * (sum_plus.real - sum_minus.real)


def _current_3_tones(freq_out, ccc, freq, resp_matrix, num_p, npts, num_b1, num_b2, num_b3):
    """Calculate the tunneling current at one output frequency (three tones).

    Every mixing product ``a*freq[1] + b*freq[2] + c*freq[3]`` with ``a``,
    ``b`` and ``c`` in ``[-num_p, num_p]`` is rounded to ``ROUND_FREQ``
    decimal places and compared with the (equally rounded) ``freq_out``. The
    current coefficients of all matching ``(a, b, c)`` combinations are
    summed. Frequencies are normalized to the gap frequency.

    Args:
        freq_out (float): frequency to solve for
        ccc (ndarray): convolution coefficients
        freq (ndarray): frequencies
        resp_matrix (ndarray): Response function matrix, generated by 
            interpolate_respfn
        num_p (int): Number of harmonics
        npts (int): Number of bias voltage points
        num_b1 (int): Summation limits for phase factor coefficients (tone 1)
        num_b2 (int): Summation limits for phase factor coefficients (tone 2)
        num_b3 (int): Summation limits for phase factor coefficients (tone 3)

    Returns:
        ndarray: Tunneling current at specified frequency

    """

    target = round(freq_out, ROUND_FREQ)
    total = np.zeros(npts, dtype=complex)

    harmonics = range(num_p, -(num_p + 1), -1)  # +num_p down to -num_p
    for a in harmonics:
        f_a = a * freq[1]
        for b in harmonics:
            f_ab = f_a + b * freq[2]
            for c in harmonics:
                if round(f_ab + c * freq[3], ROUND_FREQ) == target:
                    total += _current_coeff_3_tones(a, b, c, ccc, resp_matrix,
                                                    num_b1, num_b2, num_b3)

    return total


#@nb.njit("c16[:](i4, i4, i4, c16[:,:,:], c16[:,:,:,:], i4, i4, i4)")
def _current_coeff_3_tones(a, b, c, ccc, resp_matrix, num_b1, num_b2, num_b3):  # pragma: no cover
    """Calculate the tunneling current coefficient I(a,b,c) for three tones.

    Only the requested (a, b, c) combination is evaluated, rather than the
    whole coefficient matrix for every (a, b, c) combination.

    Equations 5.25 and 5.26 in Kittara's thesis.

    Args:
        a (int): Index a in Eqn. 5.25
        b (int): Index b in Eqn. 5.25
        c (int): Index c in Eqn. 5.25
        ccc (ndarray): Convolution coefficient
        resp_matrix (ndarray): Response function matrix, generated by 
            interpolate_respfn
        num_b1 (int): Summation limits for phase factor coefficients (tone 1)
        num_b2 (int): Summation limits for phase factor coefficients (tone 2)
        num_b3 (int): Summation limits for phase factor coefficients (tone 3)

    Returns:
        ndarray: Tunneling current coefficient

    """

    coeff1, coeff2, coeff3 = ccc[1], ccc[2], ccc[3]

    # Equation 5.25: sums for the +(a,b,c) and -(a,b,c) offsets
    sum_plus = np.zeros_like(coeff1[0, :], dtype=np.complex128)
    sum_minus = np.zeros_like(coeff1[0, :], dtype=np.complex128)

    for k in range(-num_b1, num_b1 + 1):
        k_up, k_down = abs(k + a) <= num_b1, abs(k - a) <= num_b1
        for l in range(-num_b2, num_b2 + 1):
            l_up, l_down = abs(l + b) <= num_b2, abs(l - b) <= num_b2
            for m in range(-num_b3, num_b3 + 1):
                base = coeff1[k] * coeff2[l] * coeff3[m]
                response = resp_matrix[k, l, m]

                if k_up and l_up and abs(m + c) <= num_b3:
                    partner = np.conj(coeff1[k + a, :] *
                                      coeff2[l + b, :] *
                                      coeff3[m + c, :])
                    sum_plus += partner * base * response

                if k_down and l_down and abs(m - c) <= num_b3:
                    partner = np.conj(coeff1[k - a, :] *
                                      coeff2[l - b, :] *
                                      coeff3[m - c, :])
                    sum_minus += partner * base * response

    # Current coefficient (equation 5.26); the dc term uses only the + sum
    if (a, b, c) == (0, 0, 0):
        return sum_plus.imag + 1j * 0.
    return (sum_plus.imag + sum_minus.imag) - 1j * (sum_plus.real - sum_minus.real)


# @nb.njit("c16[:](f4, c16[:,:,:], f8[:], c16[:,:,:,:,:], i4, i4, i4, i4, i4, i4)")
def _current_4_tones(freq_out, ccc, freq, resp_matrix, num_p, npts, num_b1, num_b2, num_b3, num_b4):  # pragma: no cover
    """Calculate the tunneling current at one output frequency (four tones).

    Every mixing product ``a*freq[1] + b*freq[2] + c*freq[3] + d*freq[4]``
    with ``a``, ``b``, ``c`` and ``d`` in ``[-num_p, num_p]`` is rounded to
    ``ROUND_FREQ`` decimal places and compared with the (equally rounded)
    ``freq_out``. The current coefficients of all matching combinations are
    summed. Frequencies are normalized to the gap frequency.

    Args:
        freq_out (float): frequency to solve for
        ccc (ndarray): convolution coefficients
        freq (ndarray): frequencies
        resp_matrix (ndarray): Response function matrix, generated by 
            interpolate_respfn
        num_p (int): Number of harmonics
        npts (int): Number of bias voltage points
        num_b1 (int): Summation limits for phase factor coefficients (tone 1)
        num_b2 (int): Summation limits for phase factor coefficients (tone 2)
        num_b3 (int): Summation limits for phase factor coefficients (tone 3)
        num_b4 (int): Summation limits for phase factor coefficients (tone 4)

    Returns:
        ndarray: Tunneling current at specified frequency

    """

    target = round(freq_out, ROUND_FREQ)
    total = np.zeros(npts, dtype=np.complex128)

    harmonics = range(num_p, -(num_p + 1), -1)  # +num_p down to -num_p
    for a in harmonics:
        f_a = a * freq[1]
        for b in harmonics:
            f_ab = f_a + b * freq[2]
            for c in harmonics:
                f_abc = f_ab + c * freq[3]
                for d in harmonics:
                    if round(f_abc + d * freq[4], ROUND_FREQ) == target:
                        total += _current_coeff_4_tones(a, b, c, d, ccc, resp_matrix,
                                                        num_b1, num_b2, num_b3, num_b4, npts)

    return total


#@nb.njit("c16[:](i4, i4, i4, i4, c16[:,:,:], c16[:,:,:,:,:], i4, i4, i4, i4, i4)")
def _current_coeff_4_tones(a, b, c, d, ccc, resp_matrix, num_b1, num_b2, num_b3, num_b4, npts):  # pragma: no cover
    """Calculate the tunneling current coefficient I(a,b,c,d) for four tones.

    Only the requested (a, b, c, d) combination is evaluated, rather than the
    whole coefficient matrix for every (a, b, c, d) combination.

    Equations 5.25 and 5.26 in Kittara's thesis.

    Args:
        a (int): Index a in Eqn. 5.25
        b (int): Index b in Eqn. 5.25
        c (int): Index c in Eqn. 5.25
        d (int): Index d in Eqn. 5.25
        ccc (ndarray): Convolution coefficient
        resp_matrix (ndarray): Response function matrix, generated by 
            interpolate_respfn
        num_b1 (int): Summation limits for phase factor coefficients (tone 1)
        num_b2 (int): Summation limits for phase factor coefficients (tone 2)
        num_b3 (int): Summation limits for phase factor coefficients (tone 3)
        num_b4 (int): Summation limits for phase factor coefficients (tone 4)
        npts (int): Number of bias voltage points

    Returns:
        ndarray: Tunneling current coefficient

    """

    coeff1, coeff2, coeff3, coeff4 = ccc[1], ccc[2], ccc[3], ccc[4]

    # R + jS for the +(a,b,c,d) and -(a,b,c,d) offsets: equation 5.25
    sum_plus = np.zeros(npts, dtype=np.complex128)
    sum_minus = np.zeros(npts, dtype=np.complex128)

    for k in range(-num_b1, num_b1 + 1):
        k_up, k_down = abs(k + a) <= num_b1, abs(k - a) <= num_b1
        for l in range(-num_b2, num_b2 + 1):
            l_up, l_down = abs(l + b) <= num_b2, abs(l - b) <= num_b2
            for m in range(-num_b3, num_b3 + 1):
                m_up, m_down = abs(m + c) <= num_b3, abs(m - c) <= num_b3
                for n in range(-num_b4, num_b4 + 1):
                    base = coeff1[k] * coeff2[l] * coeff3[m] * coeff4[n]
                    response = resp_matrix[k, l, m, n]

                    if k_up and l_up and m_up and abs(n + d) <= num_b4:
                        partner = np.conj(coeff1[k + a] *
                                          coeff2[l + b] *
                                          coeff3[m + c] *
                                          coeff4[n + d])
                        sum_plus += partner * base * response

                    if k_down and l_down and m_down and abs(n - d) <= num_b4:
                        partner = np.conj(coeff1[k - a] *
                                          coeff2[l - b] *
                                          coeff3[m - c] *
                                          coeff4[n - d])
                        sum_minus += partner * base * response

    # Current coefficient (equation 5.26); the dc term uses only the + sum
    if (a, b, c, d) == (0, 0, 0, 0):
        return sum_plus.imag + 1j * 0.
    return (sum_plus.imag + sum_minus.imag) - 1j * (sum_plus.real - sum_minus.real)


# Helper functions -----------------------------------------------------------

def _unpack_num_b(num_b, num_f):
    """Unpack num_b (summation limits for phase factor coefficients).

    Args:
        num_b (int or sequence): Summation limits for phase factor
            coefficients. Either a single value (Python or NumPy scalar) used
            for every tone, or a sequence (tuple, list or 1-D array) with at
            least one value per tone. Sequences are 0-indexed, i.e.,
            ``num_b[0]`` is the limit for the first fundamental frequency.
        num_f (int): Number of fundamental frequencies (tones)

    Returns:
        tuple: Summation limits for phase factor coefficients in tuple form, 
        with one value for each tone

    Raises:
        AssertionError: If a sequence with fewer than ``num_f`` values is
            given.

    """

    # A scalar applies to every tone. np.ndim() treats tuples, lists and
    # arrays alike, so e.g. num_b=[15, 8] is no longer replicated as a
    # single (list-valued) limit.
    if np.ndim(num_b) == 0:
        return tuple([num_b] * num_f)

    assert len(num_b) >= num_f, \
        "There must be one value of num_b for each fundamental frequency."
    return tuple(num_b[:num_f])
