#!/usr/bin/env python3
"""
Test speech endpointing: compare pocketsphinx's Endpointer with a
pure-Python reference implementation on the LibriVox test recordings.
"""

from pocketsphinx import Vad, Endpointer, set_loglevel
from contextlib import closing
import unittest
import subprocess
import wave
import sys
import os

# Directory containing the LibriVox test data used below: a "fileids" list
# plus a <fileid>.wav recording and a <fileid>.lab label file for each ID.
# It is resolved relative to this source file, not the current directory.
DATADIR = os.path.join(os.path.dirname(__file__), "../../test/data/librivox")


class VadQ:
    """Fixed-capacity ring buffer of per-frame VAD decisions and audio.

    Each entry is an ``(is_speech, pcm)`` pair.  Pushing onto a full
    queue overwrites the oldest entry.  ``start_time`` advances by
    ``frame_length`` every time an entry leaves the queue, whether by
    eviction or by ``pop()``.  So when frames are pushed back-to-back
    starting at time 0, it is the start time of the oldest queued frame.
    ``clear()`` discards entries without advancing ``start_time``.
    """

    def __init__(self, vad_frames=10, frame_length=0.03):
        """Create an empty queue of at most *vad_frames* entries, each
        representing *frame_length* seconds of audio."""
        self.maxlen = vad_frames
        self.frame_length = frame_length
        self.frames = [None for _ in range(vad_frames)]
        self.is_speech = [0 for _ in range(vad_frames)]
        self.pos = 0  # ring index of the oldest entry
        self.n = 0  # number of entries currently queued
        self.start_time = 0.0

    def __len__(self):
        return self.n

    def empty(self):
        """Return True if the queue holds no entries."""
        return not self.n

    def full(self):
        """Return True if the queue is at capacity."""
        return self.maxlen == self.n

    def clear(self):
        """Drop all entries; ``start_time`` is left unchanged."""
        self.n = 0

    def push(self, is_speech, pcm):
        """Append an entry, evicting the oldest one if the queue is full."""
        slot = (self.pos + self.n) % self.maxlen
        self.frames[slot] = pcm
        self.is_speech[slot] = is_speech
        if not self.full():
            self.n += 1
            return
        # The new entry landed on the oldest entry's slot: move the head
        # forward and account for the evicted frame's duration.
        self.start_time += self.frame_length
        self.pos = (self.pos + 1) % self.maxlen

    def pop(self):
        """Remove and return the oldest ``(is_speech, pcm)`` entry.

        Raises:
            IndexError: if the queue is empty.
        """
        if self.empty():
            raise IndexError("Queue is empty")
        head = self.pos
        self.pos = (head + 1) % self.maxlen
        self.n -= 1
        self.start_time += self.frame_length
        return self.is_speech[head], self.frames[head]

    def speech_count(self):
        """Return how many queued entries were marked as speech."""
        return sum(
            self.is_speech[(self.pos + offset) % self.maxlen]
            for offset in range(self.n)
        )


class PyEndpointer(Vad):
    """Pure-Python reference implementation of the speech endpointer.

    EndpointerTest feeds the same audio to this class and to
    pocketsphinx.Endpointer and requires identical output and nearly
    identical segment times.  Changes to the arithmetic or statement
    order here can break that comparison.

    Per-frame speech decisions come from the inherited ``is_speech``
    method.  The most recent decisions are kept in a VadQ window.  A
    speech region starts when more than ``start_frames`` frames in the
    window are speech, and ends when fewer than ``end_frames`` are.

    Attributes:
        timestamp: total duration of audio consumed so far, in seconds.
        in_speech: True while inside a speech region.
        speech_start: start time (seconds) of the current or last region.
        speech_end: end time (seconds) of the last region; None while a
            region is open (and before any region has been seen).
    """

    def __init__(
        self,
        window=0.3,
        ratio=0.9,
        vad_mode=Vad.LOOSE,
        sample_rate=Vad.DEFAULT_SAMPLE_RATE,
        frame_length=Vad.DEFAULT_FRAME_LENGTH,
    ):
        """Create the endpointer.

        Args:
            window: length of the VAD decision window in seconds, rounded
                to the nearest whole number of frames.
            ratio: a region starts when more than ``int(ratio * frames)``
                frames of the window are speech.  It ends when fewer than
                ``round((1 - ratio) * frames)`` are.
            vad_mode, sample_rate, frame_length: passed to ``Vad.__init__``.
        """
        super(PyEndpointer, self).__init__(vad_mode, sample_rate, frame_length)
        # Uses self.frame_length as set by Vad rather than the argument;
        # presumably Vad may adjust it for the sample rate (TODO confirm).
        maxlen = int(window / self.frame_length + 0.5)
        # Start threshold is truncated, end threshold is rounded.
        self.start_frames = int(ratio * maxlen)
        self.end_frames = int((1.0 - ratio) * maxlen + 0.5)
        print("Threshold %d%% of %.3fs window (>%d frames <%d frames of %d)" %
              (int(ratio * 100.0 + 0.5),
               maxlen * self.frame_length,
               self.start_frames, self.end_frames, maxlen))
        self.vadq = VadQ(maxlen, self.frame_length)
        self.timestamp = 0.0
        self.in_speech = False
        self.speech_start = self.speech_end = None

    def end_stream(self, frame):
        """Flush the endpointer at the end of the stream.

        *frame* is the final, possibly short, chunk of audio, at most
        ``frame_bytes`` bytes.  If a speech region is open, it is closed.
        Queued frames are returned up to the first non-speech frame.  If
        every queued frame was speech, *frame* itself is appended as well.

        Returns:
            The remaining speech audio joined into one bytes object, or
            None if no speech region was open.

        Raises:
            IndexError: if *frame* is longer than ``frame_bytes``.
        """
        if len(frame) > self.frame_bytes:
            raise IndexError(
                "Last frame size must be %d bytes or less" % self.frame_bytes
            )
        speech_frames = []
        # 0.5 * bytes / rate: assumes 2 bytes per sample (16-bit mono PCM,
        # as produced by the sox command in EndpointerTest).
        self.timestamp += len(frame) * 0.5 / self.sample_rate
        if not self.in_speech:
            return None
        self.in_speech = False
        self.speech_end = self.vadq.start_time
        # Drain leading speech frames; each one extends the region's end.
        while not self.vadq.empty():
            is_speech, pcm = self.vadq.pop()
            if is_speech:
                speech_frames.append(pcm)
                self.speech_end = self.vadq.start_time
            else:
                break
        # If we used all the VAD queue, add the trailing samples.
        # speech_end equals start_time only if nothing non-speech was
        # popped, i.e. the whole queue was speech.
        if self.vadq.empty() and self.speech_end == self.vadq.start_time:
            speech_frames.append(frame)
            self.speech_end = self.timestamp
        self.vadq.clear()
        return b"".join(speech_frames)

    def process(self, frame):
        """Process one full frame of audio.

        Pushes the frame's VAD decision into the window and updates the
        speech/non-speech state.

        Returns:
            While in a speech region, including the call that leaves it,
            the oldest queued frame.  Output therefore lags input by the
            frames still held in the window.  Otherwise None.

        Raises:
            IndexError: if *frame* is not exactly ``frame_bytes`` long.
        """
        if self.in_speech:
            # Every call made while in speech pushes one frame and pops
            # one, so the queue is never full at this point.
            assert not self.vadq.full(), "VAD queue overflow (should not happen)"
        if len(frame) != self.frame_bytes:
            raise IndexError("Frame size must be %d bytes" % self.frame_bytes)
        self.vadq.push(self.is_speech(frame), frame)
        self.timestamp += self.frame_length
        speech_count = self.vadq.speech_count()
        # Debugging aid: per-frame state trace.
        #print("%.2f %d %d %d" % (self.timestamp, speech_count, self.start_frames, self.end_frames))
        # Handle state transitions
        if self.in_speech:
            if speech_count < self.end_frames:
                # Return only the first frame.  Either way it's sort
                # of arbitrary, but this avoids having to drain the
                # queue to prevent overlapping segments.  It's also
                # closer to what human annotators will do.
                _, outframe = self.vadq.pop()
                self.speech_end = self.vadq.start_time
                self.in_speech = False
                return outframe
        else:
            if speech_count > self.start_frames:
                # The region starts at the oldest frame in the window.
                self.speech_start = self.vadq.start_time
                self.speech_end = None
                self.in_speech = True
        # Return a buffer if we are in a speech region
        if self.in_speech:
            _, outframe = self.vadq.pop()
            return outframe
        else:
            return None


def get_wavfile_length(path):
    """Return the duration of the WAV file at *path*, in seconds."""
    # Wave_read is its own context manager and closes the file on exit.
    with wave.open(path) as wav:
        frame_count = wav.getnframes()
        return frame_count / wav.getframerate()


def get_labels(path, pos):
    """Read a label file and return a list of ``(time, label)`` pairs.

    Each non-blank line of *path* must hold exactly three
    whitespace-separated fields.  The first is a time in seconds relative
    to the start of the recording, and the third is the label.  The
    second field is ignored.  Every time is offset by *pos*, and the
    result starts with a synthetic ``(pos, "silence")`` entry.

    Raises:
        ValueError: if a non-blank line does not have exactly three
            fields, or its first field is not a number.
    """
    labels = [(pos, "silence")]
    with open(path, "rt") as infh:
        for line in infh:
            # str.split() with no arguments also strips surrounding whitespace.
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing empty line) instead
                # of failing on the three-way unpack below.
                continue
            # The labels are a bit odd: only the first and third columns
            # are used.
            start, _, label = fields
            labels.append((pos + float(start), label))
    return labels


def make_single_track():
    """Describe the concatenation of all test recordings as one track.

    Reads the file IDs listed in ``DATADIR/fileids``.  For each non-blank
    ID, it collects ``<id>.wav`` and the labels from ``<id>.lab``, shifted
    by the total duration of the recordings that precede it.

    Returns:
        A pair ``(infiles, segments)``.  *infiles* is the list of WAV
        paths in order.  *segments* is a list of
        ``(start_time, end_time, "speech")`` tuples, in seconds, giving
        the speech regions of the concatenated track.
    """
    labels = []
    infiles = []
    pos = 0.0  # start time of the current recording within the track
    with open(os.path.join(DATADIR, "fileids"), "rt") as infh:
        for spam in infh:
            fileid = spam.strip()
            if not fileid:
                # Ignore blank lines rather than looking for "<DATADIR>/.wav".
                continue
            wavpath = os.path.join(DATADIR, fileid + ".wav")
            infiles.append(wavpath)
            nsec = get_wavfile_length(wavpath)
            labpath = os.path.join(DATADIR, fileid + ".lab")
            labels.extend(get_labels(labpath, pos))
            pos += nsec
    return infiles, _speech_segments(labels, pos)


def _speech_segments(labels, end_time):
    """Merge ``(time, label)`` change points into speech segments.

    *labels* is a list of ``(time, label)`` pairs, expected in time
    order, each marking where *label* begins.  *end_time* is where the
    last label ends.  Runs of consecutive entries with the same label
    are merged.  Only runs labelled ``"speech"`` are returned, as
    ``(start_time, end_time, "speech")`` tuples.  An empty *labels*
    gives an empty list.
    """
    segments = []
    if not labels:
        return segments
    start_time, label = labels[0]
    for change_time, next_label in labels[1:]:
        if next_label != label:
            if label == "speech":
                segments.append((start_time, change_time, label))
            start_time = change_time
        label = next_label
    if label == "speech":
        segments.append((start_time, end_time, label))
    return segments


class EndpointerTest(unittest.TestCase):
    """Cross-check pocketsphinx.Endpointer against PyEndpointer.

    All LibriVox test recordings are concatenated by ``sox`` into one
    raw, signed 16-bit, mono stream at the sample rate under test.  The
    stream is fed frame by frame to both endpointers.  Their outputs and
    segment boundaries must agree, and every completed segment must lie
    close to the corresponding labelled speech region.
    """

    def srtest(self, sample_rate):
        """Run the comparison at *sample_rate* Hz.

        Skips the test if the ``sox`` executable cannot be launched.
        Fails if sox exits with a non-zero status.
        """
        ep = Endpointer(vad_mode=3, sample_rate=sample_rate)
        pyep = PyEndpointer(vad_mode=3, sample_rate=sample_rate)
        self.assertEqual(ep.frame_bytes, pyep.frame_bytes)
        soxcmd = ["sox"]
        files, labels = make_single_track()
        soxcmd.extend(files)
        soxcmd.extend("-c 1 -b 16 -e signed-integer -D -G -r".split())
        soxcmd.append("%d" % ep.sample_rate)
        soxcmd.extend("-t raw -".split())
        try:
            sox = subprocess.Popen(soxcmd, stdout=subprocess.PIPE)
        except OSError as err:
            # Only a failure to launch sox means it is missing.  Any other
            # OSError (e.g. absent test data) must fail the test, not skip it.
            self.skipTest("sox not installed: %s" % err)
        with sox:
            idx = 0  # index of the next expected labelled segment
            while True:
                frame = sox.stdout.read(ep.frame_bytes)
                if len(frame) == 0:
                    break
                elif len(frame) < ep.frame_bytes:
                    # Short final read: flush both endpointers.
                    speech = ep.end_stream(frame)
                    pyspeech = pyep.end_stream(frame)
                    self.assertEqual(speech, pyspeech)
                else:
                    speech = ep.process(frame)
                    pyspeech = pyep.process(frame)
                    self.assertEqual(speech, pyspeech)
                if speech is not None:
                    self.assertEqual(ep.in_speech, pyep.in_speech)
                    if not ep.in_speech:
                        # A segment just ended: check its boundaries.
                        start_time, end_time, _ = labels[idx]
                        start_diff = abs(start_time - ep.speech_start)
                        end_diff = abs(end_time - ep.speech_end)
                        print(
                            "%.2f:%.2f (py: %.2f:%.2f) (truth: %.2f:%.2f) (diff:%.2f:%.2f)"
                            % (
                                ep.speech_start,
                                ep.speech_end,
                                pyep.speech_start,
                                pyep.speech_end,
                                start_time,
                                end_time,
                                start_diff,
                                end_diff,
                            )
                        )
                        self.assertAlmostEqual(ep.speech_start, pyep.speech_start, 3)
                        self.assertAlmostEqual(ep.speech_end, pyep.speech_end, 3)
                        self.assertLess(start_diff, 0.06)
                        self.assertLess(end_diff, 0.21)
                        idx += 1
            # If sox fails partway, its output just ends early and the loop
            # above would quietly check fewer segments than it should.
            self.assertEqual(sox.wait(), 0, "sox exited with an error")

    def testEndpointer(self):
        """Run srtest() at each sample rate known to give matching results."""
        set_loglevel("INFO")
        # 8000, 44100, 48000 give slightly different results unfortunately
        for sample_rate in 11025, 16000, 22050, 32000:
            print(sample_rate)
            self.srtest(sample_rate)


# Allow running the tests directly as a script.
if __name__ == "__main__":
    unittest.main()
