"""
Tests for the sqt.io.fasta module
"""
from io import StringIO
from nose.tools import raises

from sqt.io.fasta import (FastaReader, FastaWriter, Sequence, FastqWriter,
	SequenceReader, fastq_header)
import os.path


def dpath(path):
	"""Return *path* joined onto the directory that contains this module."""
	here = os.path.dirname(__file__)
	return os.path.join(here, path)


def test_fastqwriter():
	# Write two records, then check that the writer closed its file and that
	# the file holds the expected four-line records with a bare '+' line.
	tmp = dpath("tmp.fastq")
	try:
		with FastqWriter(tmp) as fq:
			fq.write("name", "CCATA", "!#!#!")
			fq.write("name2", "HELLO", "&&&!&&")
		assert fq._file.closed
		with open(tmp) as t:
			assert t.read() == '@name\nCCATA\n+\n!#!#!\n@name2\nHELLO\n+\n&&&!&&\n'
	finally:
		# Remove the scratch file even when an assertion above fails, so a
		# failing run does not leave it behind next to the test data.
		if os.path.exists(tmp):
			os.remove(tmp)

def test_fastqwriter_twoheaders():
	# With twoheaders=True the record name is expected to be repeated after
	# the '+' separator of each record.
	tmp = dpath("tmp.fastq")
	try:
		with FastqWriter(tmp, twoheaders=True) as fq:
			fq.write("name", "CCATA", "!#!#!")
			fq.write("name2", "HELLO", "&&&!&&")
		assert fq._file.closed
		with open(tmp) as t:
			assert t.read() == '@name\nCCATA\n+name\n!#!#!\n@name2\nHELLO\n+name2\n&&&!&&\n'
	finally:
		# Remove the scratch file even when an assertion above fails, so a
		# failing run does not leave it behind next to the test data.
		if os.path.exists(tmp):
			os.remove(tmp)


def test_fastawriter():
	# Write two (name, sequence) records and compare the resulting file
	# contents; the writer must have closed its file on leaving the block.
	tmp = dpath("tmp.fasta")
	try:
		with FastaWriter(tmp) as fw:
			fw.write("name", "CCATA")
			fw.write("name2", "HELLO")
		assert fw._file.closed
		with open(tmp) as t:
			assert t.read() == '>name\nCCATA\n>name2\nHELLO\n'
	finally:
		# Remove the scratch file even when an assertion above fails, so a
		# failing run does not leave it behind next to the test data.
		if os.path.exists(tmp):
			os.remove(tmp)


def test_fastawriter_linelength():
	# With line_length=3 the sequences are expected to be wrapped into lines
	# of at most three characters.
	tmp = dpath("tmp.fasta")
	try:
		with FastaWriter(tmp, line_length=3) as fw:
			fw.write("name", "CCAT")
			fw.write("name2", "TACCAG")
		assert fw._file.closed
		with open(tmp) as t:
			d = t.read()
			assert d == '>name\nCCA\nT\n>name2\nTAC\nCAG\n'
	finally:
		# Remove the scratch file even when an assertion above fails, so a
		# failing run does not leave it behind next to the test data.
		if os.path.exists(tmp):
			os.remove(tmp)


def test_fastawriter_sequence():
	# write() is also called with a single Sequence object instead of
	# separate name and sequence arguments; the output must be the same.
	tmp = dpath("tmp.fasta")
	try:
		with FastaWriter(tmp) as fw:
			fw.write(Sequence("name", "CCATA"))
			fw.write(Sequence("name2", "HELLO"))
		assert fw._file.closed
		with open(tmp) as t:
			assert t.read() == '>name\nCCATA\n>name2\nHELLO\n'
	finally:
		# Remove the scratch file even when an assertion above fails, so a
		# failing run does not leave it behind next to the test data.
		if os.path.exists(tmp):
			os.remove(tmp)


@raises(ValueError)
def test_fastawriter_contextmanager():
	# Using the same writer in two consecutive with-blocks must raise
	# ValueError (presumably on the second entry, once the file is closed).
	tmp = dpath("tmp.fasta")
	fw = FastaWriter(tmp)
	try:
		with fw:
			pass
		with fw:
			pass
	finally:
		# Delete the scratch file only after the writer is done with it:
		# removing a file that is still open fails on Windows.
		if os.path.exists(tmp):
			os.remove(tmp)


def test_fastareader():
	# case='keep': sequences are compared against mixed-case expectations.
	with FastaReader(dpath("seq.fa"), case='keep') as reader:
		records = list(reader)
	assert reader._file.closed
	assert len(records) == 3
	# Safe to unpack now that the record count has been checked.
	first, second, third = records
	assert first.qualities is None
	assert first.name == 'Chr1'
	assert second.name == 'Chr2 CHROMOSOME dumped from ADB: Jun/20/09 14:54; last updated: 2009-02-02'
	assert len(first.sequence) == 1235
	assert first.sequence.startswith('CCCTAAACCCTAAACCCTAAACCCTAAACCTCTGAATCCTTAATC')
	assert second.sequence.startswith('ctcgaccaggacgatgaatgggc')
	assert third.sequence.endswith('AATCTTGCAAGTTCCAACTAATT')


def test_fastareader_upper():
	# No 'case' argument is given; the expected sequences are all upper case.
	with FastaReader(dpath("seq.fa")) as reader:
		records = list(reader)
	chr1 = records[0]
	assert chr1.name == 'Chr1'
	assert len(chr1.sequence) == 1235
	assert chr1.sequence.startswith('CCCTAAACCCTAAACCCTAAACCCTAAACCTCTGAATCCTTAATC')
	chr2 = records[1]
	assert chr2.sequence.startswith('CTCGACCAGGACGATGAATGGGC')


def test_fastareader_lower():
	# case='lower': the expected sequences are all lower case.
	with FastaReader(dpath("seq.fa"), case='lower') as reader:
		records = list(reader)
	chr1 = records[0]
	assert chr1.name == 'Chr1'
	assert len(chr1.sequence) == 1235
	assert chr1.sequence.startswith('ccctaaaccctaaaccctaaaccctaaacctctgaatccttaatc')
	chr2 = records[1]
	assert chr2.sequence.startswith('ctcgaccaggacgatgaatgggc')


def test_fastareader_binary():
	# Binary reading is requested in two ways (binary=True and mode='rb');
	# both readers are held to the same expectations, for each value of
	# wholefile.
	for wholefile in False, True:
		print('wholefile:', wholefile)
		with FastaReader(dpath("seq.fa"), binary=True, wholefile=wholefile, case='keep') as fr1:
			seqs1 = list(fr1)
		# Check each reader on its own; a single shared name would only
		# ever refer to the reader opened last.
		assert fr1._file.closed

		with FastaReader(dpath("seq.fa"), mode='rb', wholefile=wholefile, case='keep') as fr2:
			seqs2 = list(fr2)
		assert fr2._file.closed

		for seqs in seqs1, seqs2:
			assert len(seqs) == 3
			assert seqs[0].qualities is None
			assert seqs[0].name == 'Chr1'
			assert seqs[2].name == 'Chr3 CHROMOSOME dumped from ADB: Jun/20/09 14:54; last updated: 2009-02-02'
			assert len(seqs[0].sequence) == 1235
			# Sequences are compared against bytes literals in binary mode.
			assert seqs[0].sequence.startswith(b'CCCTAAACCCTAAACCCTAAACCCTAAACCTCTGAATCCTTAATC')
			assert seqs[1].sequence.startswith(b'ctcgaccaggacgatgaatgggc')
			assert seqs[2].sequence.endswith(b'AATCTTGCAAGTTCCAACTAATT')


def test_sequence_reader():
	# SequenceReader should auto-detect FASTA vs FASTQ and report the
	# detected format through its 'format' attribute.
	cases = (
		("seq.fa", 'fasta'),
		("seq.fastq", 'fastq'),
	)
	for filename, expected_format in cases:
		with SequenceReader(dpath(filename)) as sr:
			assert sr.format == expected_format


@raises(ValueError)
def test_fastareader_contextmanager():
	# Using the same reader in two consecutive with-blocks must raise
	# ValueError (presumably on the second entry).
	reader = FastaReader(dpath("seq.fa"))
	for _ in range(2):
		with reader:
			pass


def test_fastq_header():
	def parse(first_line):
		# fastq_header() is given a file-like object, so wrap the line.
		return fastq_header(StringIO(first_line))

	# Headers that carry instrument, run, flowcell, lane and barcode.
	full = parse('@HWI-ST344:204:D14G8ACXX:8:1101:1638:2116 1:N:0:CGATGT')
	assert full.instrument == 'HWI-ST344'
	assert full.run == 204
	assert full.flowcell == 'D14G8ACXX'
	assert full.lane == 8
	assert full.barcode == 'CGATGT'

	full_miseq = parse('@MISEQ:56:000000000-A4YM7:1:1101:15071:2257 1:N:0:CTTGTA')
	assert full_miseq.instrument == 'MISEQ'
	assert full_miseq.run == 56
	assert full_miseq.flowcell == '000000000-A4YM7'
	assert full_miseq.lane == 1
	assert full_miseq.barcode == 'CTTGTA'

	# Shorter headers: run and flowcell are expected to be None.
	# A '#0' suffix is expected to yield no barcode.
	short_no_barcode = parse('@HWI-ST552_0:4:1101:1179:1939#0/1')
	print(short_no_barcode)
	assert short_no_barcode.instrument == 'HWI-ST552_0'
	assert short_no_barcode.run is None
	assert short_no_barcode.flowcell is None
	assert short_no_barcode.lane == 4
	assert short_no_barcode.barcode is None

	short_barcode = parse('@HWI_ST139:8:1:1202:1874#GATCAG/1')
	assert short_barcode.instrument == 'HWI_ST139'
	assert short_barcode.run is None
	assert short_barcode.flowcell is None
	assert short_barcode.lane == 8
	assert short_barcode.barcode == 'GATCAG'

	# NOTE(review): the case below is disabled; presumably this header
	# layout is not supported by fastq_header() yet -- confirm.
	#h = fastq_header(StringIO('@FCD20MKACXX:8:1101:1215:2155#TCGTAAGC/1'))
	#assert h.instrument is None
	#assert h.run is None
	#assert h.flowcell == 'FCD20MKACXX'
	#assert h.lane == 8
	#assert h.barcode == 'TCGTAAGC'
