from sqt.align import (edit_distance as ed, GlobalAlignment as GA, consensus,
	hamming_distance)
from random import choice, seed, randint
import pytest

# Pairs of strings that the tests in this file iterate over. The list starts
# with short hand-written cases (empty strings, identical strings) and
# continues with longer sequences.
STRING_PAIRS = [
	('', ''),  # both strings empty
	('', 'A'),  # first string empty
	('A', 'A'),  # identical strings
	('AB', ''),  # second string empty
	('AB', 'ABC'),  # first string is a prefix of the second
	('TGAATCCC', 'CCTGAATC'),
	('ANANAS', 'BANANA'),
	('SISSI', 'MISSISSIPPI'),
	('GGAATCCC', 'TGAGGGATAAATATTTAGAATTTAGTAGTAGTGTT'),
	('TCTGTTCCCTCCCTGTCTCA', 'TTTTAGGAAATACGCC'),
	('TGAGACACGCAACATGGGAAAGGCAAGGCACACAGGGGATAGG', 'AATTTATTTTATTGTGATTTTTTGGAGGTTTGGAAGCCACTAAGCTATACTGAGACACGCAACAGGGGAAAGGCAAGGCACA'),
	('TCCATCTCATCCCTGCGTGTCCCATCTGTTCCCTCCCTGTCTCA', 'TTTTAGGAAATACGCCTGGTGGGGTTTGGAGTATAGTGAAAGATAGGTGAGTTGGTCGGGTG'),
	('A', 'TCTGCTCCTGGCCCATGATCGTATAACTTTCAAATTT'),  # single character vs. long string
	('GCGCGGACT', 'TAAATCCTGG'),
	]


# Seed the random module's shared generator with a constant so that the
# random strings produced in this file are the same on every run.
seed(10)

def randstring():
	"""
	Return a random string made up of the characters 'A' and 'C'.

	The length is chosen with randint(0, 10), so it lies between 0 and 10
	inclusive; each character is then picked independently with choice().
	"""
	length = randint(0, 10)
	letters = [choice('AC') for _ in range(length)]
	return ''.join(letters)

# Grow the list in place with 100000 additional pairs of random strings.
STRING_PAIRS += [(randstring(), randstring()) for _ in range(100000)]


def test_edit_distance():
	"""
	Check the edit distance of a few hand-picked str pairs, then, for every
	pair in STRING_PAIRS, the distance to the empty string and symmetry.
	"""
	fixed_cases = [
		('', '', 0),
		('', 'A', 1),
		('A', 'B', 1),
		('A', 'A', 0),
		('A', 'AB', 1),
		('BA', 'AB', 2),
	]
	for left, right, expected in fixed_cases:
		assert ed(left, right) == expected

	for left, right in STRING_PAIRS:
		# The distance to the empty string must equal the string's length
		assert ed(left, '') == len(left)
		assert ed('', left) == len(left)
		# Swapping the arguments must not change the result
		assert ed(left, right) == ed(right, left)


def test_edit_distance_bytes():
	"""
	Same checks as for str input, but with bytes objects: a few hand-picked
	pairs, then every pair in STRING_PAIRS after ASCII-encoding it.
	"""
	fixed_cases = [
		(b'', b'', 0),
		(b'', b'A', 1),
		(b'A', b'B', 1),
		(b'A', b'A', 0),
		(b'A', b'AB', 1),
		(b'BA', b'AB', 2),
	]
	for left, right, expected in fixed_cases:
		assert ed(left, right) == expected

	for text_s, text_t in STRING_PAIRS:
		left = text_s.encode('ascii')
		right = text_t.encode('ascii')
		# The distance to the empty bytes object must equal the length
		assert ed(left, b'') == len(left)
		assert ed(b'', left) == len(left)
		# Swapping the arguments must not change the result
		assert ed(left, right) == ed(right, left)


def assert_banded(s, t, maxdiff):
	"""
	Assert that ed(s, t, maxdiff=maxdiff) is consistent with ed(s, t).

	If the unrestricted distance is at most maxdiff, both results must be
	equal. Otherwise the restricted result only needs to exceed maxdiff.
	"""
	limited = ed(s, t, maxdiff=maxdiff)
	exact = ed(s, t)
	if exact <= maxdiff:
		assert limited == exact
	else:
		assert limited > maxdiff


def test_edit_distance_banded():
	"""
	Run assert_banded for maxdiff values 0 to 4 on a fixed example and on
	every pair in STRING_PAIRS, in both argument orders and against the
	empty string.
	"""
	for maxdiff in range(5):
		assert_banded('ABC', '', maxdiff)

		for first, second in STRING_PAIRS:
			combinations = (
				(first, ''),
				('', first),
				(first, second),
				(second, first),
			)
			for left, right in combinations:
				assert_banded(left, right, maxdiff)


def nongap_characters(row):
	"""
	Return an alignment row with all gap characters (NUL) removed.

	Removal with a bytes pattern is tried first; if the row rejects that
	with a TypeError, removal with a str pattern is used instead.
	"""
	try:
		stripped = row.replace(b'\0', b'')
	except TypeError:
		stripped = row.replace('\0', '')
	return stripped


def count_gaps(row):
	"""
	Return how many gap characters (NUL) an alignment row contains.

	Counting with a bytes pattern is tried first; if the row rejects that
	with a TypeError, counting with a str pattern is used instead.
	"""
	try:
		n_gaps = row.count(b'\0')
	except TypeError:
		n_gaps = row.count('\0')
	return n_gaps


def count_mismatches(row1, row2):
	"""
	Count the alignment columns in which both rows hold a non-gap character
	and the two characters differ.

	row1, row2 -- the two rows of an alignment. The rows are walked in
	lockstep with zip(), so surplus characters of a longer row are ignored.

	The gap symbol is chosen from the type of row1: the NUL character for
	str rows (including str subclasses), otherwise the integer 0, because
	iterating over bytes yields integers.
	"""
	# isinstance() rather than an exact type check, so that str subclasses
	# are also compared against the NUL character and not against 0.
	gap = '\0' if isinstance(row1, str) else 0
	return sum(
		1
		for c1, c2 in zip(row1, row2)
		if c1 != gap and c2 != gap and c1 != c2
	)


def test_global_alignment():
	"""
	For every pair in STRING_PAIRS, check that the global alignment has rows
	of equal length, reports as many errors as the edit distance, reproduces
	the inputs once gaps are removed, and that its error count equals the
	number of gaps plus the number of mismatches.
	"""
	for s, t in STRING_PAIRS:
		expected_errors = ed(s, t)
		alignment = GA(s, t)
		top, bottom = alignment.row1, alignment.row2
		assert len(top) == len(bottom)
		assert alignment.errors == expected_errors
		assert nongap_characters(top) == s
		assert nongap_characters(bottom) == t
		gaps = count_gaps(top) + count_gaps(bottom)
		assert alignment.errors == gaps + count_mismatches(top, bottom)


def test_consensus():
	"""
	Check that the consensus of five three-letter sequences is 'AAA', both
	when they are passed as a dict and as the dict's values.
	"""
	sequences = {'a': 'AAA', 'b': 'ACA', 'c': 'AAG', 'd': 'TAA', 'e': 'AAA'}
	for collection in (sequences, sequences.values()):
		assert consensus(collection) == 'AAA'

def test_hamming_distance():
	"""
	Check the Hamming distance of equal-length strings: identical strings
	give 0, strings that differ at every position give their length.
	"""
	cases = [
		('', '', 0),
		('A', 'A', 0),
		('HELLO', 'HELLO', 0),
		('ABC', 'DEF', 3),
	]
	for left, right, expected in cases:
		assert hamming_distance(left, right) == expected


def test_hamming_distance_incorrect_length():
	"""
	Check that hamming_distance raises IndexError when the two strings have
	different lengths.
	"""
	shorter, longer = 'A', 'BC'
	with pytest.raises(IndexError):
		hamming_distance(shorter, longer)
