#!/usr/bin/python3

import re
import gzip
import struct
import sys
import time
import getopt

# Module-level debug switch. It is not referenced by any code visible in
# this module.
_dbenable=False


#####
#
#
# Generic file function
#
#
#####

def universalOpen(file):
    """Return a readable text stream for *file*.

    Parameters:
        file: either a path (``str``) or an already opened file-like object.

    Returns:
        * for a path ending in ``.gz``: a text-mode gzip stream,
        * for any other path: the result of ``open(file)``,
        * for anything that is not a ``str``: *file* itself, unchanged.
    """
    if not isinstance(file, str):
        return file
    if file.endswith('.gz'):
        # 'rt' makes the gzip stream yield str lines, exactly like open()
        # does for plain files; the binary default would yield bytes and
        # break every str comparison done by the readers.
        return gzip.open(file, 'rt')
    return open(file)

def universalTell(file):
    """Return the current byte offset in the file underlying *file*.

    Text wrappers are unwrapped to their binary buffer, because ``tell()``
    on a text file raises ``OSError`` while the file is being iterated.
    Gzip streams are unwrapped to the compressed file, so the offset is
    comparable with the on-disk size.

    Because of read-ahead buffering the value can be ahead of what has
    actually been consumed.
    """
    binary = getattr(file, 'buffer', file)
    if isinstance(binary, gzip.GzipFile):
        # myfileobj is only set when GzipFile opened the file itself;
        # fall back to fileobj for streams built from an existing object.
        binary = binary.myfileobj or binary.fileobj
    return binary.tell()

def fileSize(file):
    """Return the size in bytes of the file underlying *file*.

    Text wrappers are unwrapped to their binary buffer and gzip streams to
    the compressed file, so the result is the on-disk size. The current
    position of the measured stream is restored before returning, even if
    seeking fails part-way.
    """
    binary = getattr(file, 'buffer', file)
    if isinstance(binary, gzip.GzipFile):
        # myfileobj is None when the GzipFile wraps a caller-supplied
        # object; fileobj is set in both cases.
        binary = binary.myfileobj or binary.fileobj
    saved = binary.tell()
    try:
        binary.seek(0, 2)
        return binary.tell()
    finally:
        binary.seek(saved, 0)

def progressBar(pos,max,reset=False,delta=[]):
    """Draw a one-line progress bar on stderr.

    Parameters:
        pos:   current position (integer).
        max:   position corresponding to 100 %.
        reset: when true, restart the internal timer.
        delta: internal state ``[start_time, last_time]``. The mutable
               default is deliberate: it keeps the timer between calls.

    The remaining time is extrapolated from the elapsed time. It is shown
    as ``--:--:--`` while no progress has been made, and never goes
    negative when *pos* overshoots *max*.
    """
    if reset:
        del delta[:]
    now = time.time()
    if not delta:
        delta.extend((now, now))
    delta[1] = now
    elapsed = delta[1] - delta[0]

    # An empty input (max == 0) is considered complete.
    percent = float(pos) / max * 100 if max else 100.0

    if percent > 0:
        seconds_left = elapsed / percent * (100 - percent)
        if seconds_left < 0:
            seconds_left = 0
        remain = time.strftime('%H:%M:%S', time.gmtime(seconds_left))
    else:
        remain = '--:--:--'

    # Keep the bar inside its 50 columns even if pos overshoots max.
    filled = min(int(percent / 2), 50)
    if filled < 0:
        filled = 0
    bar = '#' * filled
    bar += '|/-\\-'[pos % 5]
    bar += ' ' * (50 - filled)
    sys.stderr.write('\r%5.1f %% |%s] remain : %s' % (percent, bar, remain))

#####
#
#
# NCBI Dump Taxonomy reader
#
#
#####

def endLessIterator(endedlist):
    """Yield every item of *endedlist*, then its last item forever.

    *endedlist* must support ``[-1]`` indexing; an empty sequence raises
    ``IndexError`` once the (empty) first pass is over.
    """
    yield from endedlist
    while True:
        yield endedlist[-1]
   
class ColumnFile(object):
    """Iterate over a delimited text stream, one list of fields per line.

    Parameters:
        stream: a path (opened with ``open``) or any iterable of lines,
                such as an open file, a generator or a list of strings.
        sep:    field delimiter passed to ``str.split`` (``None`` splits
                on runs of whitespace).
        strip:  strip surrounding whitespace from every field.
        types:  optional sequence of converters applied column by column.
                When a line has more fields than converters, the last
                converter is reused. ``bool`` is replaced by ``str2bool``.
                Fields are always stripped when *types* is given.

    Raises:
        ValueError: if *stream* is neither a string nor iterable.
    """

    # First characters recognised by str2bool (V as in the French "Vrai").
    _TRUE_FLAGS = frozenset('TV123456789')
    _FALSE_FLAGS = frozenset('F0')

    def __init__(self,stream,sep=None,strip=True,types=None):
        if isinstance(stream,str):
            self._stream = open(stream)
        else:
            # Python 3 iterators expose __next__, not next; iter() accepts
            # files, generators and plain sequences alike.
            try:
                self._stream = iter(stream)
            except TypeError:
                raise ValueError('stream must be string or an iterator') from None
        self._delimiter=sep
        self._strip=strip
        if types:
            self._types=[ColumnFile.str2bool if t is bool else t
                         for t in types]
        else:
            self._types=None

    @staticmethod
    def str2bool(x):
        """Convert a textual flag to ``bool`` from its first character.

        ``T``, ``V`` and the digits 1-9 are true; ``F`` and ``0`` are
        false (letters are case-insensitive). Anything else, including an
        empty string, raises ``ValueError``.
        """
        flag = x.strip()[:1].upper()
        if flag in ColumnFile._TRUE_FLAGS:
            return True
        if flag in ColumnFile._FALSE_FLAGS:
            return False
        raise ValueError('cannot interpret %r as a boolean' % (x,))

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next line as a list of (optionally converted) fields."""
        line = next(self._stream)
        fields = line.split(self._delimiter)
        if self._strip or self._types:
            fields = [field.strip() for field in fields]
        if self._types:
            converters = self._types
            last = converters[-1]
            # Columns beyond the declared types reuse the last converter.
            fields = [(converters[i] if i < len(converters) else last)(field)
                      for i, field in enumerate(fields)]
        return fields
    
def taxonCmp(t1,t2):
    """Three-way comparison of two taxon records on their first item.

    Returns -1, 0 or 1 when ``t1[0]`` is respectively lower than, equal
    to or greater than ``t2[0]``.
    """
    left, right = t1[0], t2[0]
    if left < right:
        return -1
    return 1 if left > right else 0

def bsearchTaxon(taxonomy,taxid):
    """Binary-search *taxonomy* for the record whose first item is *taxid*.

    Parameters:
        taxonomy: sequence of records sorted in increasing order of their
                  first item.
        taxid:    value to look for.

    Returns:
        The index of a matching record, or ``None`` when there is none
        (including when *taxonomy* is empty).
    """
    low, high = 0, len(taxonomy)
    while low < high:
        # Floor division: a float cannot be used as a list index.
        middle = (low + high) // 2
        current = taxonomy[middle][0]
        if current == taxid:
            return middle
        if current < taxid:
            low = middle + 1
        else:
            high = middle
    return None
        
    
    
def readNodeTable(file):
    """Load a '|'-separated node table into an indexed taxonomy.

    Parameters:
        file: path or open stream of the node table. Column 0 is read as
              the taxid, column 1 as the parent taxid and column 2 as the
              rank name.

    Returns:
        A tuple ``(taxonomy, ranks, index)`` where

        * ``taxonomy`` is a list of ``[taxid, rank_index, parent_index]``
          records sorted by taxid,
        * ``ranks`` maps each rank name to its position in the sorted
          list of rank names,
        * ``index`` maps each taxid to its position in ``taxonomy``.

    Raises:
        KeyError: if a parent taxid does not exist in the table.
    """
    opened_here = isinstance(file, str)
    file = universalOpen(file)
    try:
        nodes = ColumnFile(file,
                           sep='|',
                           types=(int,int,str,
                                  str,str,bool,
                                  int,bool,int,
                                  bool,bool,bool,str))
        print("Reading taxonomy dump file...", file=sys.stderr)
        taxonomy=[[n[0],n[2],n[1]] for n in nodes]
    finally:
        # Only close what this function opened itself.
        if opened_here:
            file.close()

    print("List all taxonomy rank...", file=sys.stderr)
    ranks = {name: position
             for position, name in enumerate(sorted(set(x[1] for x in taxonomy)))}

    print("Sorting taxons...", file=sys.stderr)
    taxonomy.sort(key=lambda taxon: taxon[0])

    print("Indexing taxonomy...", file=sys.stderr)
    # The list is sorted, so each position is known directly; searching
    # for every taxid again is unnecessary.
    index = {taxon[0]: position for position, taxon in enumerate(taxonomy)}

    print("Indexing parent and rank...", file=sys.stderr)
    for t in taxonomy:
        t[1]=ranks[t[1]]
        t[2]=index[t[2]]

    return taxonomy,ranks,index

def nameIterator(file):
    """Yield ``(taxid, name, classname)`` for each line of a names table.

    Every line must split into exactly five '|'-separated fields; the
    third and the fifth are discarded.
    """
    column_types = (int, str, str, str)
    rows = ColumnFile(universalOpen(file), sep='|', types=column_types)
    for taxid, name, _unique_name, name_class, _trailing in rows:
        yield taxid, name, name_class
        
def mergedNodeIterator(file):
    """Yield ``(taxid, current)`` for each line of a merged-nodes table.

    Every line must split into exactly three '|'-separated fields; the
    last one is discarded.
    """
    column_types = (int, int, str)
    rows = ColumnFile(universalOpen(file), sep='|', types=column_types)
    for old_taxid, new_taxid, _trailing in rows:
        yield old_taxid, new_taxid
    
def deletedNodeIterator(file):
    """Yield the taxid found on each line of a deleted-nodes table.

    Every line must split into exactly two '|'-separated fields; the
    second one is discarded.
    """
    column_types = (int, str)
    rows = ColumnFile(universalOpen(file), sep='|', types=column_types)
    for removed_taxid, _trailing in rows:
        yield removed_taxid
    
def readTaxonomyDump(taxdir):
    """Load a complete taxonomy dump from directory *taxdir*.

    Reads ``nodes.dmp``, ``names.dmp``, ``merged.dmp`` and
    ``delnodes.dmp`` from *taxdir*.

    Returns:
        ``(taxonomy, ranks, alternativeName, index)`` where

        * ``taxonomy`` and ``ranks`` come from ``readNodeTable``; each
          taxon record additionally receives its scientific name(s),
        * ``alternativeName`` lists ``(name, classname, taxon_index)``
          for every name read,
        * ``index`` maps taxids to taxon positions, extended with merged
          taxids (pointing to their replacement) and deleted taxids
          (mapped to ``None``).
    """
    taxonomy, ranks, index = readNodeTable('%s/nodes.dmp' % taxdir)

    print("Adding scientific name...", file=sys.stderr)
    alternativeName = []
    for taxid, name, classname in nameIterator('%s/names.dmp' % taxdir):
        position = index[taxid]
        alternativeName.append((name, classname, position))
        if classname == 'scientific name':
            taxonomy[position].append(name)

    print("Adding taxid alias...", file=sys.stderr)
    for old_taxid, new_taxid in mergedNodeIterator('%s/merged.dmp' % taxdir):
        index[old_taxid] = index[new_taxid]

    print("Adding deleted taxid...", file=sys.stderr)
    for removed_taxid in deletedNodeIterator('%s/delnodes.dmp' % taxdir):
        index[removed_taxid] = None

    return taxonomy, ranks, alternativeName, index


#####
#
#
#  Genbank/EMBL sequence reader
#
#
#####

def entryIterator(file):
    """Yield flat-file entries (GenBank/EMBL style) one string at a time.

    An entry is every line up to and including a terminator line made of
    exactly ``//``. The terminator is recognised even when it is the last
    line of the file and has no trailing newline, so the final entry is
    not lost. Lines after the last terminator are discarded.
    """
    file = universalOpen(file)
    lines = []
    for line in file:
        lines.append(line)
        if line in ('//\n', '//'):
            yield ''.join(lines)
            lines = []
            
def fastaEntryIterator(file):
    """Yield FASTA records one string at a time.

    A new record starts on each line beginning with ``>``. Lines found
    before the first header are yielded together with the first record.
    """
    stream = universalOpen(file)
    pending = []
    for line in stream:
        starts_record = line[0] == '>'
        if starts_record and pending:
            yield ''.join(pending)
            pending = []
        pending.append(line)
    # Flush the record still being accumulated when the input ends.
    if pending:
        yield ''.join(pending)
    
# Runs of spaces, newlines and digits, i.e. the layout characters of a
# flat-file sequence block.
_cleanSeq = re.compile('[ \n0-9]+')

def cleanSeq(seq):
    """Return *seq* without any space, newline or digit."""
    kept = _cleanSeq.split(seq)
    return ''.join(kept)
    
    
# Raw strings: sequences such as "\." are invalid escapes in ordinary
# literals and trigger a SyntaxWarning on recent Python versions.
_gbParseID = re.compile(r'(?<=^LOCUS {7})[^ ]+(?= )',re.MULTILINE)
_gbParseDE = re.compile(r'(?<=^DEFINITION {2}).+?\. *$(?=[^ ])',re.MULTILINE|re.DOTALL)
_gbParseSQ = re.compile(r'(?<=^ORIGIN).+?(?=^//$)',re.MULTILINE|re.DOTALL)
_gbParseTX = re.compile(r'(?<= /db_xref="taxon:)[0-9]+(?=")')

def genbankEntryParser(entry):
    """Extract the fields of interest from one GenBank flat-file entry.

    Returns:
        A dict with keys ``id`` (token following ``LOCUS``), ``taxid``
        (int taken from the first ``/db_xref="taxon:..."`` qualifier, or
        ``None`` when there is none), ``definition`` (whitespace
        normalised) and ``sequence`` (upper case, layout removed).

    Raises:
        IndexError: if the LOCUS, DEFINITION or ORIGIN part is not found.
    """
    Id = _gbParseID.findall(entry)[0]
    De = ' '.join(_gbParseDE.findall(entry)[0].split())
    Sq = cleanSeq(_gbParseSQ.findall(entry)[0].upper())
    try:
        Tx = int(_gbParseTX.findall(entry)[0])
    except IndexError:
        # No taxon cross-reference in this entry.
        Tx = None
    return {'id':Id,'taxid':Tx,'definition':De,'sequence':Sq}

######################

# A line break, optionally followed by the "DE   " prefix of an EMBL
# continuation line. A character class such as [\nDE] would instead
# delete every 'D' and 'E' of the description itself.
_cleanDef = re.compile(r'\n(?:DE {3})?')

def cleanDef(definition):
    """Join a multi-line EMBL description into a single line.

    Each line break, together with the ``DE   `` prefix of the following
    line when present, is replaced by one space so that words on
    consecutive lines do not run together.
    """
    return _cleanDef.sub(' ', definition)

# Raw strings: sequences such as "\." are invalid escapes in ordinary
# literals and trigger a SyntaxWarning on recent Python versions.
_emblParseID = re.compile(r'(?<=^ID {3})[^ ]+(?=;)',re.MULTILINE)
_emblParseDE = re.compile(r'(?<=^DE {3}).+?\. *$(?=[^ ])',re.MULTILINE|re.DOTALL)
_emblParseSQ = re.compile(r'(?<=^  ).+?(?=^//$)',re.MULTILINE|re.DOTALL)
_emblParseTX = re.compile(r'(?<= /db_xref="taxon:)[0-9]+(?=")')

def emblEntryParser(entry):
    """Extract the fields of interest from one EMBL flat-file entry.

    Returns:
        A dict with keys ``id`` (token following ``ID`` up to the first
        ``;``), ``taxid`` (int taken from the first
        ``/db_xref="taxon:..."`` qualifier, or ``None`` when there is
        none), ``definition`` (cleaned by ``cleanDef`` and whitespace
        normalised) and ``sequence`` (upper case, layout removed).

    Raises:
        IndexError: if the ID, DE or sequence part is not found.
    """
    Id = _emblParseID.findall(entry)[0]
    De = ' '.join(cleanDef(_emblParseDE.findall(entry)[0]).split())
    Sq = cleanSeq(_emblParseSQ.findall(entry)[0].upper())
    try:
        Tx = int(_emblParseTX.findall(entry)[0])
    except IndexError:
        # No taxon cross-reference in this entry.
        Tx = None
    return {'id':Id,'taxid':Tx,'definition':De,'sequence':Sq}


######################

# Raw string: "\W" is an invalid escape in an ordinary literal.
_fastaSplit=re.compile(r';\W*')

def parseFasta(seq):
    """Split one FASTA record into its components.

    The header is ``>identifier`` optionally followed by ';'-separated
    fields. Fields of the form ``key=value`` are collected in a dict;
    the others are joined with spaces to form the definition.

    Returns:
        ``(identifier, sequence, definition, info)`` with the sequence in
        upper case and free of line breaks. A header without identifier
        yields an empty identifier instead of raising.
    """
    lines = seq.split('\n')
    title = lines[0].strip()[1:].split(None,1)
    identifier = title[0] if title else ''
    fields = _fastaSplit.split(title[1]) if len(title) == 2 else []
    info = dict(x.split('=',1) for x in fields if '=' in x)
    definition = ' '.join(x for x in fields if '=' not in x)
    sequence = ''.join(x.strip() for x in lines[1:]).upper()
    return identifier,sequence,definition,info

  
def fastaEntryParser(entry):
    """Turn one FASTA record into the common sequence dictionary.

    Returns a dict with keys ``id``, ``taxid``, ``definition`` and
    ``sequence``. ``taxid`` is the integer value of the ``taxid=...``
    header field, or ``None`` when the header has no such field.
    """
    identifier, sequence, definition, info = parseFasta(entry)
    raw_taxid = info.get('taxid')
    taxid = None if raw_taxid is None else int(raw_taxid)
    return {'id': identifier,
            'taxid': taxid,
            'definition': definition,
            'sequence': sequence}
    

def sequenceIteratorFactory(entryParser,entryIterator):
    """Combine an entry splitter and an entry parser into one reader.

    Returns a generator function taking a file and yielding
    ``entryParser(entry)`` for every entry produced by
    ``entryIterator(file)``.
    """
    def sequenceIterator(file):
        # map() is lazy, so entries are parsed one at a time on demand.
        yield from map(entryParser, entryIterator(file))
    return sequenceIterator


def taxonomyInfo(entry,connection):
    """Annotate *entry* with taxonomic information read from a database.

    Parameters:
        entry:      dict holding an integer under ``'taxid'``.
        connection: object providing ``cursor()``; the cursor must
                    provide ``execute(sql)`` and ``fetchone()``.

    The single row fetched is stored in *entry* under the keys
    ``current_taxid``, ``species``, ``genus``, ``family``,
    ``scientific_name``, ``species_sn``, ``genus_sn`` and ``family_sn``.
    *entry* is modified in place and returned.
    """
    taxid = entry['taxid']
    cursor = connection.cursor()
    # %d only accepts a number, so the taxid cannot carry SQL text.
    query = """
                        select taxid,species,genus,family,
                               taxonomy.scientificName(taxid) as sn,
                               taxonomy.scientificName(species) as species_sn,
                               taxonomy.scientificName(genus) as genus_sn,
                               taxonomy.scientificName(family) as family_sn
                        from
                            (   
                             select alias                      as taxid,
                               taxonomy.getSpecies(alias) as species,
                               taxonomy.getGenus(alias)   as genus,
                               taxonomy.getFamily(alias)  as family
                                from taxonomy.aliases
                               where id=%d ) as tax
                    """ % taxid
    cursor.execute(query)
    row = cursor.fetchone()
    columns = ('current_taxid', 'species', 'genus', 'family',
               'scientific_name', 'species_sn', 'genus_sn', 'family_sn')
    for position, key in enumerate(columns):
        entry[key] = row[position]
    return entry
    
#####
#
#
# Binary writer
#
#
#####
    
def ecoSeqPacker(sq):
    """Serialise one sequence record to its binary form.

    Parameters:
        sq: dict with keys ``taxid`` (unsigned int), ``id``,
            ``definition`` and ``sequence`` (``str`` or ``bytes``).

    Layout (big-endian): record size excluding this field, taxid, id
    padded or truncated to 20 bytes, definition length, uncompressed
    sequence length, compressed sequence length, definition,
    zlib-compressed sequence.

    Text values are encoded as UTF-8 because ``struct`` and ``zlib`` only
    accept bytes; the stored lengths are byte lengths.
    """
    import zlib

    def as_bytes(value):
        return value.encode('utf-8') if isinstance(value, str) else value

    identifier = as_bytes(sq['id'])
    definition = as_bytes(sq['definition'])
    sequence = as_bytes(sq['sequence'])

    compactseq = zlib.compress(sequence,9)
    cptseqlength = len(compactseq)
    delength = len(definition)

    totalSize = 4 + 20 + 4 + 4 + 4 + cptseqlength + delength

    packed = struct.pack('> I I 20s I I I %ds %ds' % (delength,cptseqlength),
                         totalSize,
                         sq['taxid'],
                         identifier,
                         delength,
                         len(sequence),
                         cptseqlength,
                         definition,
                         compactseq)

    # Internal consistency check of the size bookkeeping above.
    assert len(packed) == totalSize+4, "error in sequence packing"

    return packed

def ecoTaxPacker(tx):
    """Serialise one taxon record to its binary form.

    Parameters:
        tx: sequence whose items 0, 1 and 2 are unsigned integers and
            whose item 3 is the taxon name (``str`` or ``bytes``).

    Layout (big-endian): record size excluding this field, ``tx[0]``,
    ``tx[1]``, ``tx[2]``, name length in bytes, name. A text name is
    encoded as UTF-8 because ``struct`` only packs bytes.
    """
    name = tx[3].encode('utf-8') if isinstance(tx[3], str) else tx[3]
    namelength = len(name)

    totalSize = 4 + 4 + 4 + 4 + namelength

    packed = struct.pack('> I I I I I %ds' % namelength,
                         totalSize,
                         tx[0],
                         tx[1],
                         tx[2],
                         namelength,
                         name)

    return packed

def ecoRankPacker(rank):
    """Serialise one rank name to its binary form.

    Layout (big-endian): name length in bytes, name. A text name is
    encoded as UTF-8 because ``struct`` only packs bytes.
    """
    name = rank.encode('utf-8') if isinstance(rank, str) else rank
    namelength = len(name)

    packed = struct.pack('> I %ds' % namelength,
                         namelength,
                         name)

    return packed
                
def ecoNamePacker(name):
    """Serialise one taxon name to its binary form.

    Parameters:
        name: ``(name, classname, taxon_index)``; the two texts may be
              ``str`` or ``bytes``.

    Layout (big-endian): record size excluding this field, 1 if the class
    is 'scientific name' else 0, name length, class length, taxon index,
    name, class name. Text is encoded as UTF-8 because ``struct`` only
    packs bytes; lengths are byte lengths.
    """
    text = name[0].encode('utf-8') if isinstance(name[0], str) else name[0]
    classname = name[1].encode('utf-8') if isinstance(name[1], str) else name[1]

    namelength = len(text)
    classlength= len(classname)
    totalSize =  namelength + classlength + 4 + 4 + 4 + 4

    packed = struct.pack('> I I I I I %ds %ds' % (namelength,classlength),
                         totalSize,
                         int(classname == b'scientific name'),
                         namelength,
                         classlength,
                         name[2],
                         text,
                         classname)

    return packed
    
def ecoSeqWriter(file,input,taxindex,parser):
    """Write the sequences of *input* to the binary file *file*.

    Parameters:
        file:     path of the output file.
        input:    path or open stream of the sequence file.
        taxindex: mapping from taxid to taxon index (``None`` for taxids
                  that must be rejected).
        parser:   callable turning the input stream into an iterator of
                  sequence dicts (keys ``id``, ``taxid``, ...).

    The file starts with the number of sequences written, followed by one
    packed record per kept sequence. Entries without taxid, or whose
    taxid is unknown or maps to ``None``, are skipped.

    Returns:
        The list of identifiers of the skipped entries.

    Both files are closed when the function exits, even on error; an
    input stream supplied by the caller is left open.
    """
    seqcount=0
    skipped = []

    with open(file,'wb') as output:
        opened_here = isinstance(input, str)
        input  = universalOpen(input)
        try:
            inputsize = fileSize(input)
            entries = parser(input)

            # Placeholder count, rewritten once the real value is known.
            output.write(struct.pack('> I',seqcount))

            progressBar(1, inputsize,reset=True)
            for entry in entries:
                if entry['taxid'] is not None:
                    try:
                        entry['taxid']=taxindex[entry['taxid']]
                    except KeyError:
                        entry['taxid']=None
                    if entry['taxid'] is not None:
                        seqcount+=1
                        output.write(ecoSeqPacker(entry))
                    else:
                        skipped.append(entry['id'])
                    where = universalTell(input)
                    progressBar(where, inputsize)
                    print(" Readed sequences : %d     " % seqcount, end=' ', file=sys.stderr)
                else:
                    skipped.append(entry['id'])

            print(file=sys.stderr)
            output.seek(0,0)
            output.write(struct.pack('> I',seqcount))
        finally:
            # Only close what this function opened itself.
            if opened_here:
                input.close()

    return skipped
        

def ecoTaxWriter(file,taxonomy):
    """Write the taxon table to the binary file *file*.

    The file holds the number of taxa followed by one packed record per
    taxon, in the order of *taxonomy*. The file is closed even if packing
    a record fails.
    """
    with open(file,'wb') as output:
        output.write(struct.pack('> I',len(taxonomy)))
        for tx in taxonomy:
            output.write(ecoTaxPacker(tx))
    
def ecoRankWriter(file,ranks):
    """Write the rank names to the binary file *file*.

    Parameters:
        file:  path of the output file.
        ranks: mapping whose keys are the rank names.

    The file holds the number of ranks followed by the packed names in
    sorted order. The file is closed even if packing a name fails.
    """
    with open(file,'wb') as output:
        output.write(struct.pack('> I',len(ranks)))
        for rank in sorted(ranks):
            output.write(ecoRankPacker(rank))

def nameCmp(n1,n2):
    """Three-way, case-insensitive comparison of two name records.

    Compares the upper-cased first items and returns -1, 0 or 1.
    """
    left, right = n1[0].upper(), n2[0].upper()
    if left < right:
        return -1
    return 1 if left > right else 0


def ecoNameWriter(file,names):
    """Write the taxon names to the binary file *file*.

    Parameters:
        file:  path of the output file.
        names: list of ``(name, classname, taxon_index)`` records. It is
               sorted in place, case-insensitively, on the name.

    The file holds the number of names followed by one packed record per
    name. The file is closed even if packing a record fails.
    """
    # list.sort() no longer takes a comparison function in Python 3; the
    # key reproduces the case-insensitive ordering on the name.
    names.sort(key=lambda record: record[0].upper())

    with open(file,'wb') as output:
        output.write(struct.pack('> I',len(names)))
        for name in names:
            output.write(ecoNamePacker(name))
    
def ecoDBWriter(prefix,taxonomy,seqFileNames,parser):
    """Write a complete database under the file-name prefix *prefix*.

    Parameters:
        prefix:       common prefix of every file created.
        taxonomy:     ``(taxa, ranks, names, taxid_index)``.
        seqFileNames: sequence files to convert, one numbered ``.sdx``
                      file each (numbering starts at 1).
        parser:       sequence reader passed to ``ecoSeqWriter``.

    Identifiers skipped by ``ecoSeqWriter`` are reported on stderr.
    """
    tables = (('%s.rdx', ecoRankWriter, 1),
              ('%s.tdx', ecoTaxWriter, 0),
              ('%s.ndx', ecoNameWriter, 2))
    for pattern, writer, position in tables:
        writer(pattern % prefix, taxonomy[position])

    for filecount, filename in enumerate(seqFileNames, 1):
        sk = ecoSeqWriter('%s_%03d.sdx' % (prefix,filecount),
                          filename,
                          taxonomy[3],
                          parser)
        if sk:
            print("Skipped entry :", file=sys.stderr)
            print(sk, file=sys.stderr)
        
def ecoParseOptions(arguments):
    """Parse the command-line *arguments* (without the program name).

    Returns:
        ``(opt, filenames)`` where ``opt`` holds ``prefix`` (default
        ``'ecodb'``), ``taxdir`` (default ``'taxdump'``), ``parser``
        (GenBank reader by default) and, when a taxonomy option is given,
        ``taxmod`` set to ``'dump'``; ``filenames`` are the remaining
        positional arguments.

    ``-h``/``--help`` prints the help and exits. ``getopt.GetoptError``
    is raised for options that are not recognised.
    """
    # Entry parser and entry splitter for each supported format.
    formats = {
        '-g': (genbankEntryParser, entryIterator),
        '--genbank': (genbankEntryParser, entryIterator),
        '-f': (fastaEntryParser, fastaEntryIterator),
        '--fasta': (fastaEntryParser, fastaEntryIterator),
        '-e': (emblEntryParser, entryIterator),
        '--embl': (emblEntryParser, entryIterator),
    }

    opt = {
            'prefix' : 'ecodb',
            'taxdir' : 'taxdump',
            'parser' : sequenceIteratorFactory(genbankEntryParser,
                                                  entryIterator)
           }

    o,filenames = getopt.getopt(arguments,
                                'ht:n:gfe',
                                ['help',
                                 'taxonomy=',
                                 'name=',
                                 'genbank',
                                 'fasta',
                                 'embl'])

    for name,value in o:
        if name in ('-h','--help'):
            printHelp()
            # sys.exit() rather than exit(): the latter is injected by
            # the site module and is not always available.
            sys.exit()
        elif name in ('-t','--taxonomy'):
            opt['taxmod']='dump'
            opt['taxdir']=value
        elif name in ('-n','--name'):
            opt['prefix']=value
        elif name in formats:
            opt['parser']=sequenceIteratorFactory(*formats[name])
        else:
            raise ValueError('Unknown option %s' % name)

    return opt,filenames


def printHelp():
    """Print the command-line usage summary on standard output."""
    rule = "-----------------------------------"
    lines = (
        rule,
        " ecoPCRFormat.py",
        rule,
        "ecoPCRFormat.py [option] <argument>",
        rule,
        "-e    --embl        :[E]mbl format",
        "-f    --fasta       :[F]asta format",
        "-g    --genbank     :[G]enbank format",
        "-h    --help        :[H]elp - print this help",
        "-n    --name        :[N]ame of the new database created",
        "-t    --taxonomy    :[T]axonomy - path to the taxonomy database",
        "                    :bcp-like dump from GenBank taxonomy database.",
        rule,
    )
    for line in lines:
        print(line)

if __name__ == '__main__':
    # Script entry point: parse the command line, load the taxonomy dump,
    # then write the database files.
    options, sequenceFiles = ecoParseOptions(sys.argv[1:])

    loadedTaxonomy = readTaxonomyDump(options['taxdir'])

    ecoDBWriter(options['prefix'],
                loadedTaxonomy,
                sequenceFiles,
                options['parser'])
    

