
import unittest

import numpy as np

import pbcore.data
from pbcore.io.align import BamReader
from pbcore.io.align.PacBioBamIndex import PacBioBamIndex, StreamingBamIndex

class TestPbIndex(unittest.TestCase):
    """
    Consistency tests for the PacBio BAM index (.pbi) readers, run against
    the unaligned BAM file bundled with pbcore.data and its companion
    ".pbi" file.
    """
    BAM_FILE_NAME = pbcore.data.getUnalignedBam()
    # Per-record index columns compared between readers in the tests below.
    INDEX_FIELDS = ("qId", "holeNumber", "qStart", "qEnd")

    @classmethod
    def setUpClass(cls):
        """Open the BAM file and load its full index once for all tests."""
        cls._bam = BamReader(cls.BAM_FILE_NAME)
        cls._pbi = PacBioBamIndex(cls.BAM_FILE_NAME + ".pbi")

    def test_pbindex_bam_consistency(self):
        """
        Every BAM record must agree with the index row at the same position,
        and both must contain the same number of records (117).
        """
        self.assertEqual(len(self._pbi), 117)
        # Count explicitly so that an empty reader produces an assertion
        # failure below rather than a NameError on an unbound loop variable.
        n_records = 0
        for i_rec, rec in enumerate(self._bam):
            self.assertEqual(rec.qId, self._pbi.qId[i_rec])
            self.assertEqual(rec.HoleNumber, self._pbi.holeNumber[i_rec])
            self.assertEqual(rec.qStart, self._pbi.qStart[i_rec])
            self.assertEqual(rec.qEnd, self._pbi.qEnd[i_rec])
            n_records += 1
        self.assertEqual(n_records, 117)

    def test_pbindex_streaming(self):
        """
        Concatenating all streamed chunks must reproduce the full index, and
        get_chunk(i) must return the same data as the i-th iterated chunk.
        """
        streamed = StreamingBamIndex(self.BAM_FILE_NAME + ".pbi", 20)
        self.assertEqual(streamed.nchunks, 6)
        chunks = list(streamed)
        for attr in self.INDEX_FIELDS:
            combined = np.concatenate([getattr(c, attr) for c in chunks])
            self.assertEqual(len(combined), len(self._pbi))
            # array_equal also checks shape, so a length mismatch fails
            # cleanly instead of breaking the element-wise comparison.
            self.assertTrue(
                np.array_equal(combined, getattr(self._pbi, attr)))
        chunk = streamed.get_chunk(1)
        for attr in self.INDEX_FIELDS:
            self.assertTrue(
                np.array_equal(getattr(chunk, attr),
                               getattr(chunks[1], attr)))

    def test_pbindex_streaming_entire(self):
        """
        With the default chunk size there should be just one chunk, identical
        to the whole index.
        """
        streamed = StreamingBamIndex(self.BAM_FILE_NAME + ".pbi")
        self.assertEqual(streamed.nchunks, 1)
        chunked = list(streamed)[0]
        self.assertEqual(len(chunked), len(self._pbi))
        for attr in self.INDEX_FIELDS:
            self.assertTrue(
                np.array_equal(getattr(chunked, attr),
                               getattr(self._pbi, attr)))

    def test_pbindex_with_zmw_index(self):
        """
        Test that the built in sub-index of ZMW start positions is correct.
        """
        streamed = StreamingBamIndex(self.BAM_FILE_NAME + ".pbi", 20)
        unique_zmws = set()
        n_indexed_zmws = 0
        for chunk, zmw_idx in streamed.iter_with_zmw_index():
            for k, zmw_start in enumerate(zmw_idx):
                # A ZMW's rows run up to the next ZMW's start position, or to
                # the end of the chunk for the last ZMW in the sub-index.
                if k < len(zmw_idx) - 1:
                    idx_max = zmw_idx[k + 1]
                else:
                    idx_max = len(chunk)
                zmws = chunk.holeNumber[zmw_start:idx_max]
                # Each indexed slice must contain exactly one hole number.
                # (assertTrue(x, 1) would treat 1 as the failure message and
                # pass for any non-empty slice.)
                self.assertEqual(len(set(zmws)), 1)
                unique_zmws.add(zmws[0])
                n_indexed_zmws += 1
        # No hole number may be indexed more than once across all chunks.
        self.assertEqual(len(unique_zmws), n_indexed_zmws)
