import random
import argparse
import os
import sys
import time


def _read_lengths(path, amount, file_type):
	"""Return the sequence-line lengths of at most ``amount + 2`` records of ``path``.

	Only single-line records are supported: a FASTQ record is four lines
	(header, sequence, '+', qualities), a FASTA record two lines (header,
	sequence). Fewer lengths are returned when the file holds fewer records.

	Raises:
		ValueError: if ``file_type`` is neither "fastq" nor "fasta".
	"""
	if file_type == "fastq":
		skip_after_seq = 2  # '+' separator line and quality line
	elif file_type == "fasta":
		skip_after_seq = 0
	else:
		raise ValueError("file_type must be 'fastq' or 'fasta', got %r" % (file_type,))
	lengths = []
	with open(path) as handle:
		header = handle.readline()
		while len(lengths) <= amount + 1 and header != "":
			# strip the line break rather than assuming one is always present
			lengths.append(len(handle.readline().rstrip("\r\n")))
			for _ in range(skip_after_seq):
				handle.readline()
			header = handle.readline()
	return lengths


def _pad_cyclic(lengths, amount, source):
	"""Extend ``lengths`` in place by cycling its entries until it has at least ``amount + 1`` items.

	Raises:
		ValueError: if ``lengths`` is empty (no reads were found in ``source``).
	"""
	if not lengths:
		raise ValueError("no reads found in %s to take read lengths from" % (source,))
	j = 0
	while len(lengths) <= amount:
		lengths.append(lengths[j])
		j += 1
	return lengths


def find_out_length(infiles,amount,file_type,single_end=None):
	"""Collect read lengths from real read file(s) so simulated reads can copy them.

	Args:
		infiles: paths of the real read files; ``infiles[0]`` holds the first
			mates, ``infiles[1]`` (paired-end mode only) the second mates.
		amount: number of reads to be simulated; the returned lists are
			long enough to be indexed with 0..amount.
		file_type: "fastq" or "fasta" (single-line records only).
		single_end: single-end mode flag; when None, the module-level
			``SingleEnd`` flag set by ``main`` is used.

	Returns:
		(lengths1, lengths2): per-read lengths of the first and second mates,
		each with at least ``amount + 1`` entries. If the file has fewer
		records, its lengths are repeated cyclically. In single-end mode
		``lengths2`` is all zeros.

	Raises:
		ValueError: on an unknown ``file_type`` or a read file without records.
	"""
	if single_end is None:
		single_end = SingleEnd
	lengths1 = _pad_cyclic(_read_lengths(infiles[0], amount, file_type), amount, infiles[0])
	if single_end:
		lengths2 = [0] * len(lengths1)
	else:
		lengths2 = _pad_cyclic(_read_lengths(infiles[1], amount, file_type), amount, infiles[1])
	return lengths1, lengths2
	

def checkread(read,verbose,allowed):
	"""Return True when every character of ``read`` belongs to the set ``allowed``.

	A rejected read is printed, preceded by an "invalid read:" line, but only
	when ``verbose`` is exactly ``True``.
	"""
	is_valid = set(read) <= allowed
	if not is_valid and verbose is True:
		print("invalid read:")
		print(read)
	return is_valid


def complement(s):
	"""Return the base-wise complement of the DNA sequence ``s``.

	The output is always upper-case: lower-case bases are complemented to
	upper-case letters. Any character other than A/C/G/T (either case)
	raises KeyError.
	"""
	pairing = {
		'A': 'T', 'C': 'G', 'G': 'C', 'T': 'A',
		'a': 'T', 'c': 'G', 'g': 'C', 't': 'A',
	}
	return ''.join(pairing[base] for base in s)


def revcom(s):
	"""Return the reverse complement of ``s`` (upper-case, via ``complement``)."""
	backwards = s[::-1]
	return complement(backwards)


def create_random_readpairs(infile, readlength1, readlength2, gapsize, pos, verbose, allowed):
	"""Cut one read (pair) out of an open sequence file starting at offset ``pos``.

	Collects ``readlength1 + gapsize + readlength2`` sequence characters from
	``pos`` onwards, skipping line breaks. The first ``readlength1`` become
	read1, and the ``readlength2`` after the gap become read2 (an empty string
	when ``readlength2`` is 0).

	NOTE(review): ``pos`` is passed straight to a text-mode ``seek``; this
	presumably relies on the file being plain ASCII so byte offsets are valid.

	Args:
		infile: open file object to read from.
		readlength1, readlength2: lengths of the two reads.
		gapsize: number of bases between the reads.
		pos: offset to start reading at.
		verbose, allowed: forwarded to ``checkread``.

	Returns:
		(read1, read2), or None if the file ends before enough sequence was
		collected or the span contains a character not in ``allowed``.
	"""
	needed = readlength1 + gapsize + readlength2
	infile.seek(pos)
	full_seq = infile.read(needed).replace('\n','').replace('\r','')
	# read() counts line breaks as characters, so top the span up until it holds
	# `needed` bases; otherwise read2 would slide back into read1 / the gap.
	while len(full_seq) < needed:
		more = infile.read(needed - len(full_seq))
		if not more:
			# end of file: the span cannot provide full-length reads
			return None
		full_seq += more.replace('\n','').replace('\r','')
	if not checkread(full_seq,verbose,allowed):
		return None
	read1 = full_seq[:readlength1]
	# slice forward from read2's start: full_seq[-0:] would be the whole span
	start2 = readlength1 + gapsize
	read2 = full_seq[start2:start2 + readlength2]
	return read1, read2


def _format_record(name, index, seq, out_format):
	"""Format one simulated read as a FASTA or FASTQ record named ``name + index``."""
	header = name + str(index)
	if out_format == "fastq":
		# dummy base qualities ('#'), one per base
		return "".join(["@", header, "\n", seq, "\n", "+\n", "#" * len(seq), "\n"])
	return "".join([">", header, "\n", seq, "\n"])


def _simulate_reads(infile, outfile1, outfile2, name, gapsize, verbose, chunk, amount, real_readfiles, purpose, out_format):
	"""Write ``amount`` reads (pairs) cut from random positions of ``infile``.

	Shared implementation of ``createrandomassignedreads`` and
	``createrandomunassignedreads``. Read lengths are copied from
	``real_readfiles`` via ``find_out_length``. Every other read (pair) is
	reverse-complemented. Reads are buffered in chunks of ``chunk`` before
	being written to ``outfile1`` (first mates) and, in paired-end mode,
	``outfile2`` (second mates). Uses the module-level ``SingleEnd`` flag.

	Raises:
		ValueError: on an unknown ``out_format``, or when ``infile`` is too small
			to hold the longest read (pair).
	"""
	if out_format not in ("fasta", "fastq"):
		raise ValueError("out_format must be 'fasta' or 'fastq', got %r" % (out_format,))
	if chunk > amount:
		chunk = amount
	rest = chunk
	filesize = os.path.getsize(infile)
	allowed = set("ACGTacgt")
	with open(infile) as reference:
		sys.stdout.write("Simulating reads of "+name+" for "+purpose+":\n")
		lengths1, lengths2 = find_out_length(real_readfiles,amount,"fastq")
		# Every start position must leave room for the longest read (pair), since
		# the length used at a sorted position is not known while drawing it.
		max_start = filesize - (max(lengths1) + max(lengths2) + gapsize)
		if rest > 0 and max_start < 0:
			raise ValueError("%s is shorter than the longest read (pair) to simulate" % (infile,))
		written = 0
		while rest > 0:
			# One spare position is drawn and the largest is dropped (original
			# sampling scheme); sorting keeps the seeks moving forward.
			positions = sorted(random.randint(0, max_start) for _ in range(rest + 1))
			records1 = []
			records2 = []
			for pos in positions[:-1]:
				readlength1 = lengths1[written]
				readlength2 = 0 if SingleEnd else lengths2[written]
				reads = create_random_readpairs(reference,readlength1,readlength2,gapsize,pos,verbose,allowed=allowed)
				if not reads:
					continue
				written += 1
				read1, read2 = reads
				# alternate strands: odd-numbered reads are reverse-complemented
				if written % 2 == 1:
					read1 = revcom(read1)
				records1.append(_format_record(name, written, read1, out_format))
				if not SingleEnd:
					if written % 2 == 1:
						read2 = revcom(read2)
					records2.append(_format_record(name, written, read2, out_format))
				if written % 1000 == 0:
					sys.stdout.write(str(float(written)/amount*100)+"%\r")
					sys.stdout.flush()
			outfile1.writelines(records1)
			if not SingleEnd:
				outfile2.writelines(records2)
			rest = min(rest, amount - written)
	print("\n")


def createrandomunassignedreads(infile, outfile1, outfile2, name, gapsize, verbose, chunk, amount, real_readfiles, out_format="fasta"):
	"""Simulate ``amount`` reads of ``name`` from the reference at path ``infile`` for evaluation.

	Records are named ``name`` + running number and written to the open
	handles ``outfile1``/``outfile2`` (``outfile2`` is unused in single-end
	mode). ``out_format`` selects "fasta" (default) or "fastq" records.
	"""
	_simulate_reads(infile, outfile1, outfile2, name, gapsize, verbose, chunk, amount, real_readfiles, "evaluation", out_format)


def createrandomassignedreads(infile, outfile1, outfile2, name, gapsize, verbose, chunk, amount, real_readfiles, out_format="fasta"):
	"""Simulate ``amount`` reads of ``name`` from the reference at path ``infile`` for training.

	Records are named ``name`` + running number and written to the open
	handles ``outfile1``/``outfile2`` (``outfile2`` is unused in single-end
	mode). ``out_format`` selects "fasta" (default) or "fastq" records.
	"""
	_simulate_reads(infile, outfile1, outfile2, name, gapsize, verbose, chunk, amount, real_readfiles, "training", out_format)


def _run_timed(func, *args):
	"""Call ``func(*args)`` and print the elapsed wall-clock time in seconds."""
	before = time.time()
	func(*args)
	print(time.time() - before)


def main(infile1,infile2,name1,name2,temppath,real_readfiles,SE,amount,gapsize=1,verbose=False,chunk=100000):
	"""Simulate training and evaluation read sets for two reference sequences.

	Writes FASTA files into ``temppath``:
	  * ``simulated_reads_<name>_<amount>[_1|_2].fasta`` per species
	    (via ``createrandomassignedreads``), and
	  * ``simulated_reads_<name1>_and_<name2>_<amount>[_1|_2].fasta``, holding
	    reads of both species (via ``createrandomunassignedreads``).
	In single-end mode (``SE``) each set is one ``.fasta`` file; otherwise it is
	an ``_1``/``_2`` pair. Sets the module-level ``SingleEnd`` flag read by the
	other functions. All output files are closed even if generation fails.
	"""
	global SingleEnd
	SingleEnd=SE

	opened = []  # every output handle opened so far; closed in `finally`

	def open_outputs(label):
		"""Open the output file(s) for ``label``; the second handle is None in SE mode."""
		prefix = os.path.join(temppath, "simulated_reads_" + label + "_" + str(amount))
		suffixes = [".fasta"] if SingleEnd else ["_1.fasta", "_2.fasta"]
		handles = []
		for suffix in suffixes:
			handle = open(prefix + suffix, "w")
			opened.append(handle)
			handles.append(handle)
		if SingleEnd:
			return handles[0], None
		return handles[0], handles[1]

	try:
		outf1, outf2 = open_outputs(name1)
		outf3, outf4 = open_outputs(name2)
		_run_timed(createrandomassignedreads, infile1, outf1, outf2, name1, gapsize, verbose, chunk, amount, real_readfiles)
		_run_timed(createrandomassignedreads, infile2, outf3, outf4, name2, gapsize, verbose, chunk, amount, real_readfiles)
		for handle in opened:
			handle.close()
		# reads of both species go into one shared file (pair)
		outf5, outf6 = open_outputs(name1 + "_and_" + name2)
		_run_timed(createrandomunassignedreads, infile1, outf5, outf6, name1, gapsize, verbose, chunk, amount, real_readfiles)
		_run_timed(createrandomunassignedreads, infile2, outf5, outf6, name2, gapsize, verbose, chunk, amount, real_readfiles)
	finally:
		# close() is a no-op on handles that were already closed above
		for handle in opened:
			handle.close()

if __name__ == "__main__":
	# Command-line entry point: parse options, validate them, then run main().
	parser = argparse.ArgumentParser()
	parser.add_argument('-i', '--infile', help='Input file of species 1 and species 2, given in fasta-format', required=True, action='append')
	parser.add_argument('-r', '--real_readfiles', help='File containing a subset of your reads in fastq-format; give it twice (first and second mates) in paired-end mode', required=True, action='append')
	parser.add_argument('-n', '--name', help='Name of species 1 and species 2', required=True, action='append')
	parser.add_argument('-o', '--outpath', help='Path to write resulting files', required=True)
	parser.add_argument('-g', '--gapsize', help='Approximate length of gap between reads', required=False, default=0, type=int)
	parser.add_argument('-v', '--verbose', help='print invalid reads and additional information',required=False,default=False,action='store_true')
	parser.add_argument('-a', '--amount', help='number of reads to be simulated for all kinds', required=False,default=500000, type=int)
	parser.add_argument('-c', '--chunk', help='number of reads to be generated before writing. Choose according to your RAM.', required=False,default=100000, type=int)
	parser.add_argument('-s', '--SE', help='Single End mode', action='store_true', default=False)
	args = parser.parse_args()
	# Report missing inputs as usage errors instead of a later IndexError.
	if len(args.infile) < 2:
		parser.error("-i/--infile must be given twice (species 1 and species 2)")
	if len(args.name) < 2:
		parser.error("-n/--name must be given twice (species 1 and species 2)")
	if not args.SE and len(args.real_readfiles) < 2:
		parser.error("paired-end mode needs -r/--real_readfiles twice (first and second mates)")
	verbose = args.verbose
	outpath = str(args.outpath)
	amount = args.amount
	chunk = args.chunk
	SE = args.SE
	# no gap between reads in single-end mode
	gapsize = 0 if SE else args.gapsize

	infile1 = args.infile[0]
	infile2 = args.infile[1]
	name1 = str(args.name[0])
	name2 = str(args.name[1])
	real_readfiles = args.real_readfiles
	main(infile1,infile2,name1,name2,outpath,real_readfiles,SE,amount,gapsize,verbose,chunk)
