#!/usr/bin/python
import sys, re, time, logging
from collections import defaultdict

from pysam import Samfile
"""
Realigns a BAM file generated by BLASR so that indels in different reads
are more likely to line up over the same reference positions.
"""

def realign(read):
    """
    Realign a read's alignment so that equivalent indels are placed
    consistently, making indels from different reads more likely to line up
    over the same reference positions.

    The alignment comes from expandAlign(read) as two equal-length lists of
    single characters (query, target), with '-' marking a gap.  Columns are
    scanned left to right:

    * gap in the query: the gap is swapped with the first query base after
      the gap run when that base equals the target base in this column
      (deletion shifted right)::

          ATT-TCG          ATTT-CG
          ATTTTCG   -->    ATTTTCG

    * gap in the target: the same swap is done on the target, compared
      against the query base (insertion shifted right)::

          ATTTCG           ATTTCG
          AT-TCG    -->    ATT-CG

    * a gap-free mismatching column is split into a deletion column followed
      by an insertion column, which lengthens both lists by one.

    read.is_reverse is not consulted: every alignment is scanned in the same
    left-to-right direction.

    Returns (query, target) as lists.  The read object is not modified.
    """
    bases = frozenset("ATCGatcg")

    def gapRunEnd(seq, pos):
        # Index of the first A/C/G/T (either case) following the run of '-'
        # that starts at pos; if anything else (e.g. 'N') or the end of seq
        # follows the run, the index of the run's last '-'.
        n = len(seq)
        end = pos
        while end < n and seq[end] == '-':
            end += 1
        if end < n and seq[end] in bases:
            return end
        return end - 1

    def shiftGaps(query, target):
        # Edits query/target in place and returns them.
        length = len(target)
        pos = 0
        while pos < length:
            atGap = False
            if query[pos] == '-':
                j = gapRunEnd(query, pos)
                if target[pos] == query[j]:
                    query[pos], query[j] = query[j], query[pos]
                atGap = True
            if target[pos] == '-':
                j = gapRunEnd(target, pos)
                if query[pos] == target[j]:
                    target[pos], target[j] = target[j], target[pos]
                atGap = True
            if not atGap and query[pos] != target[pos]:
                # mismatch -> deletion at pos, insertion at pos + 1
                query.insert(pos, '-')
                target.insert(pos + 1, '-')
                length += 1
            pos += 1
        return query, target

    query, target = expandAlign(read)
    return shiftGaps(query, target)

def expandAlign(alignment):
    """
    Expand a pysam alignment into gapped, column-by-column sequences.

    The CIGAR (via expandCigar) and the MD tag (via expandMd) are walked
    together to rebuild the reference bases the read covers:

    * aligned column: the query base against the target base, which comes
      from the MD tag at a mismatch and is copied from the query at a match;
    * insertion: the query base against '-' in the target;
    * deletion: '-' in the query against the deleted base from the MD tag.

    For example:
        query     =  ATCGC-GT
        reference =  AT-GCGGA
        Where C inserted, G deleted, A->T Sub

    Returns a (query, target) tuple of equal-length lists of single
    characters.  Logs an error and exits with status 1 if the alignment has
    no MD tag.
    """
    seq = alignment.query
    cigar = expandCigar(alignment.cigar)

    mdTag = None
    for tag in alignment.tags:
        if tag[0] == "MD":
            mdTag = expandMd(tag[1])

    if mdTag is None:
        logging.error("MD tag is absent. Run samtools calmd")
        # sys.exit instead of the exit() builtin, which the site module adds
        # for interactive use and which is not guaranteed to exist
        sys.exit(1)

    qPos = 0
    tPos = 0
    qSeq = []
    tSeq = []
    for op in cigar:
        if op == 0:
            qBase = seq[qPos]
            qSeq.append(qBase)
            # expandMd marks matching reference positions with '-'
            tSeq.append(qBase if mdTag[tPos] == '-' else mdTag[tPos])
            qPos += 1
            tPos += 1
        elif op == 1:
            # insertion: consumes the query only
            qSeq.append(seq[qPos])
            tSeq.append("-")
            qPos += 1
        elif op == 2:
            # deletion: consumes the reference only
            qSeq.append("-")
            tSeq.append(mdTag[tPos])
            tPos += 1
    return (qSeq, tSeq)

def replace(read, query, target):
    """
    Rewrite a read's CIGAR and tags to match a realigned (query, target) pair.

    query, target -- equal-length sequences of single characters (as produced
                     by realign); '-' marks a gap.

    Each column becomes M (0) when both sides hold a base, I (1) when only
    the target is gapped, and D (2) when only the query is gapped.  Columns
    gapped on both sides are dropped.  Runs of clipping operations
    (S = 4, H = 5) at either end of the original CIGAR are kept, because the
    expanded alignment does not contain clipped bases.

    The MD tag is removed; it is not regenerated here, and `samtools calmd`
    can rebuild it.  An existing NM tag is recomputed as the number of
    non-matching columns.  The read is modified in place; returns None.
    """
    clipOps = (4, 5)

    # -- CIGAR for the aligned part ---------------------------------------
    ops = []
    prev = None
    count = 0
    editDist = 0
    for qBase, tBase in zip(query, target):
        if tBase == '-':
            if qBase == '-':
                continue  # gap on both sides: not an alignment column
            cur = 1  # I
        elif qBase == '-':
            cur = 2  # D
        else:
            cur = 0  # M (match or mismatch)
        if qBase != tBase:
            editDist += 1
        if cur == prev:
            count += 1
        else:
            if prev is not None:
                ops.append((prev, count))
            prev = cur
            count = 1
    if prev is not None:
        ops.append((prev, count))

    # -- carry over clipping from both ends of the original CIGAR ---------
    oldCigar = list(read.cigar or [])
    lead = []
    for op in oldCigar:
        if op[0] not in clipOps:
            break
        lead.append(op)
    tail = []
    # scan only what is left after the leading clips so an all-clip CIGAR
    # is not duplicated
    for op in reversed(oldCigar[len(lead):]):
        if op[0] not in clipOps:
            break
        tail.append(op)
    tail.reverse()

    read.cigar = lead + ops + tail

    # -- tags: MD is stale after realignment, NM is recomputed --------------
    newTags = []
    for tag in read.tags:
        if tag[0] == "MD":
            continue
        if tag[0] == "NM":
            tag = ("NM", editDist)
        newTags.append(tag)
    read.tags = newTags

    
def expandCigar(cigar):
    """
    Turn an abbreviated CIGAR (list of (op, length) pairs) into one op code
    per alignment column.

    Output codes:
        0 = M  (sequence match '=' (7) and mismatch 'X' (8) also map to 0,
                since they consume query and reference exactly like M)
        1 = I
        2 = D
    All other operations (clips, N, P) are dropped.
    """
    columnOp = {0: 0, 1: 1, 2: 2, 7: 0, 8: 0}
    ret = []
    for op, length in cigar:
        code = columnOp.get(op)
        if code is None:
            continue
        ret.extend([code] * length)
    return ret

def expandMd(md):
    """
    Turn an abbreviated MD tag into one entry per reference position it covers.

    Matching positions become '-', mismatches become the reference base, and
    deleted reference bases (the letters after '^') are listed as-is.  Any
    uppercase letter is accepted as a base, as the SAM MD grammar allows.

    >>> expandMd("2A0C1^GT1")
    ['-', '-', 'A', 'C', '-', 'G', 'T', '-']
    """
    ret = []
    # raw string: "\d" in a normal string literal is an invalid escape
    for token in re.findall(r"\d+|\^?[A-Z]+", md):
        if token.startswith('^'):
            ret.extend(token[1:])
        elif token[0].isdigit():
            ret.extend(['-'] * int(token))
        else:
            ret.extend(token)
    return ret

if __name__ == '__main__':
    import os

    if len(sys.argv) < 2:
        sys.stderr.write("usage: %s <input.bam>\n" % sys.argv[0])
        sys.exit(1)
    inPath = sys.argv[1]
    # "<name>.bam" -> "<name>_realign.bam"; splitext also handles inputs
    # that do not end in a four-character extension
    outPath = os.path.splitext(inPath)[0] + "_realign.bam"

    f = Samfile(inPath)
    out = Samfile(outPath, 'wb', template=f)
    try:
        total = f.mapped  # progress is reported relative to mapped reads
        count = 0
        n = 0.05  # next progress fraction to report at
        for read in f:
            if read.is_unmapped:
                # nothing to realign; unmapped reads normally have no MD tag,
                # which would make expandAlign exit, so copy them through
                out.write(read)
                continue
            query, target = realign(read)
            replace(read, query, target)
            out.write(read)
            count += 1
            if total and float(count) / total > n:
                n += 0.05
                print("[%s] -- parsed %d of %d reads (%.2f)" % (time.asctime(), count, total, float(count) / total))
    finally:
        out.close()
        f.close()
