#!/usr/bin/env python
"""
Script for filtering out alignments with deficient properties (RQS, gaps, mismatches). The script generates a new SAM file from the input SAM file
that contains only those alignments that fulfill the given quality thresholds.
nm = numberOfMisMatches
rq = readqualityscore
MQ = mappingquality
"""



"""
Copyright (c) 2012, Sven H. Giese, gieseS@rki.de, 
Robert Koch-Institut, Germany,
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
    * Redistributions of source code must retain the above copyright
      notice, this list of conditions and the following disclaimer.
    * Redistributions in binary form must reproduce the above copyright
      notice, this list of conditions and the following disclaimer in the
      documentation and/or other materials provided with the distribution.
    * The name of the author may not be used to endorse or promote products
      derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL MARTIN S. LINDNER BE LIABLE FOR ANY
DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""




import HTSeq
import re
import sys
import time
from optparse import OptionParser


from core import AnalyseMapping as AM


# Global counters, updated by FilterSAM: counter1 counts the good alignments
# (thresholds fulfilled), counter2 counts the bad alignments (thrown away).
counter1 = 0
counter2 = 0



def writeFilteredReads(iddic,fastqfile,fastqfileoutput):
    """
    
    Copies every read whose name is contained in iddic from a fastq input to an
    output file and reports how many reads were written.
    
    input: fastq reads
    output: number of reads written


    @type  iddic: dictionary
    @param iddic: dictionary containing read ids (only membership of the keys is tested).

    @type  fastqfile: fastq
    @param fastqfile: iterable of reads; each read needs a 'name' attribute and a
                      write_to_fastq_file(file) method (presumably an HTSeq.FastqReader - verify against caller).

    @type  fastqfileoutput: file
    @param fastqfileoutput: open, writable file. Reads that pass the filter are written to this file.
                            The file is closed before the function returns.

    @rtype:  integer
    @return: number of reads written to fastqfileoutput.
    
    """
    counter = 0
    try:
        for read in fastqfile:
            # does the read fulfill the conditions?
            if read.name in iddic:
                read.write_to_fastq_file(fastqfileoutput)
                counter+=1
    finally:
        # close the output even if reading or writing fails half way through
        fastqfileoutput.close()

    print (str(len(iddic)) + " ids" )
    print (str(counter) + " reads to file written..." )
    return(counter)

def getMisGapRQ(alngt):
    """
    
    Calculates the number of gaps / mismatches from an HTSeq alngt. object.


    @type  alngt: HTSeq alignment object
    @param alngt: Alignment to evaluate. Its cigar operations and its "MD" optional field are read.


    @rtype:  tuple (integer, integer)
    @return: (gaps, mism): number of gaps / mismatches of the alignment.
    
    """
    cigar = alngt.cigar
    # a gap is generated by insertions / deletions
    gaps = sum(op.size for op in cigar if op.type in ("I", "D"))

    # mismatches cannot be distinguished in a cigar string: "M" covers matches as
    # well as mismatches. Use the combination of cigar and MD tag instead.
    aligned = sum(op.size for op in cigar if op.type == "M")
    # every number in the MD tag is handed to AM.getsum (presumably the summed
    # length of the matching stretches - verify against core.AnalyseMapping)
    matched = AM.getsum(re.findall(r"(\d+)", alngt.optional_field("MD")))
    mism = aligned - matched
    return(gaps,mism)

def FilterSAM(samfile,MISM,GAPS,RQ,fsam):
    """
    
    Filters the alignments of a SAM file. An alignment is written to fsam if
    mismatches <= MISM, gaps <= GAPS and read quality score >= RQ; otherwise it
    is thrown away. The global counters counter1 (passed) and counter2 (removed)
    are increased accordingly and a summary is printed at the end.


    @type  samfile: string
    @param samfile: Initial SAM file for filtering (handed to HTSeq.SAM_Reader).

    @type  MISM: integer
    @param MISM: Maximum number of mismatches of an alignment.

    @type  GAPS: integer
    @param GAPS: Maximum number of gaps of an alignment.

    @type  RQ: number
    @param RQ: Minimum read quality score of an alignment.

    @type  fsam: file
    @param fsam: Open, writable file for the alignments that pass. It is not closed here.
    
    """
    
    #counter1 = good alignments ( better than threshold)
    global counter1
    
    #counter2 = bad alignments ( worse than threshold)
    global counter2
    SAM = HTSeq.SAM_Reader(samfile)

    # In compressed sam files, the sequence as well as the quality for one read is
    # only once contained. For avoiding trouble just remember the score and the
    # sequence of the last complete record.
    rqs = None
    sequence = None
    
    for alngt in SAM:
    
        # get gaps,mismatches
        gaps,mismatches = getMisGapRQ(alngt)

        # Only the extraction of quality / sequence is guarded, so that errors
        # while writing the output are not mistaken for an incomplete record.
        try:
            current_rqs = float(AM.ComputeRQScore(alngt.read.qualstr))
            current_sequence = alngt.read.seq
        except Exception:
            # if a read is incomplete (missing sequence, quality) then its immediate
            # predecessor must have the same sequence. Without a predecessor there
            # is nothing to fall back on, so the original error is passed on.
            if rqs is None:
                raise
            complete = False
        else:
            rqs = current_rqs
            sequence = current_sequence
            complete = True

        # threshold fulfilled?
        if not (int(mismatches) <= MISM and rqs >= RQ and int(gaps) <= GAPS):
            # if alignment worse than threshold --> throw away
            counter2 +=1
            continue

        if complete:
            # "\n" instead of a hard coded "\r\n": the header lines are copied with
            # their native line ending, the alignment lines should not differ.
            fsam.write(alngt.get_sam_line()+"\n")
        else:
            # construct cigar
            cigar  ="".join([str(i.size)+str(i.type) for i in alngt.cigar])

            #generate tagdic for being able to write additional information (NM, MD tags)
            tagdic =  dict(alngt.optional_fields)

            # HTSeq intervals are 0-based while the POS column of a SAM line is
            # 1-based (get_sam_line adds 1 as well), hence iv.start + 1.
            # NO PE Support --> '*'; the quality string is not available --> '*'
            tempstr = "\t".join([alngt.read.name,str(alngt.flag),alngt.iv.chrom,str(alngt.iv.start + 1),str(alngt.aQual),cigar,'*','0','0',sequence,"*","NM:i:"+str(tagdic["NM"]),"MD:Z:"+tagdic['MD']])
            fsam.write(tempstr+"\n")
        counter1 +=1
                
    print ("PASSED\tREMOVED\tOVERALL")
    # one tab separated line that lines up with the header printed above
    print ("%s\t%s\t%s" % (counter1, counter2, counter1+counter2))

def readHeader(samfile):
    """
    Function to retrieve the header from a SAM file.

    @type  samfile: string
    @param samfile: Path of the SAM file.

    @rtype:  string
    @return: All leading lines that start with "@" (line endings included),
             concatenated. Empty string if the file does not start with a header.
    """
    
    headerlines = []
    
    # the context manager closes the file, also when reading stops early
    with open(samfile) as fobj:
        for line in fobj:
            if not line.startswith("@"):
                # no header lines are left...
                break
            headerlines.append(line)
    
    return("".join(headerlines))

def filter():
    """
    Function for filtering a given sam file. The alignments are evaluated individually in regard to RQS,mismatches and gaps. The input needs to be sorted.

    Reads the two positional arguments (input SAM file, output file) and the
    thresholds from the command line, copies the header of the input to the
    output and lets FilterSAM append all alignments that pass the thresholds.
    Exits with status 1 if the number of positional arguments is not 2.
    
    """
    
    usage = """usage: %prog [INPUT SAM] [OUTPUTFILE] [OPTIONS]...
    Required Arguments:
    
    OUTPUTFILE:\t Outputfilename
    INPUT: \t Single SAM file to filter.
    OPTIONS: \t Paramters for filtering.
    
    Script to filter and a SAM file based on the analysis with ARDEN. It is necessary to specify the desired cutoffs for RQS,GAPS and MM.
    Type: python filter -e 1 for printing some examples.
    """
    
    parser = OptionParser(usage=usage,version="%prog 1.0")
    parser.add_option("-r", "--rqs", type='int',dest="rqs",action="store",  default=0,help="Threshold for minimum read quality score.. [default: %default]")
    parser.add_option("-m", "--mismatches", type='int',dest="mm",action="store",  default=3,help="Threshold for maximum number of mismatches in an alignment [default: %default]")
    parser.add_option("-g", "--gaps", type='int',dest="gaps",action="store",  default=3,help="Threshold for maximum number of gaps in an alignment. [default: %default]")
    # NOTE(review): the value of -e/--examples is parsed but never read in this function.
    parser.add_option("-e", "--examples", type='int',dest="pexamples",action="store",  default=0,help="set to 1 if you want to print examples")
    parser.description=""
    (options, args) = parser.parse_args() 

    # the number of positional arguments has to be equal to # required parameters
    if len(args) != 2:
        print("!!!Wrong number of parameters!!!")
        parser.print_help()
        sys.exit(1)

    # local names chosen so that the builtin input() is not shadowed
    samfile, output = args
    # adjust options, otherwise default is used
    rqs = options.rqs
    mismatches = options.mm
    gaps = options.gaps

    print("###############Your Settings:######################")

    print("inputfile:\t"+ str(samfile))
    print("outputfile:\t"+ str(output))
    print("RQS >=:\t\t"+ str(rqs))
    print("MM <=:\t\t" +  str(mismatches))
    print("GAPS <=:\t" + str(gaps))
    print("###################################################")
    print ("Start...")
    start = time.time()    
    header = readHeader(samfile)
    
    # open the output; the context manager closes it even if filtering fails
    with open( output, "w" ) as fsam:
        fsam.write(header)
        # actually filter
        FilterSAM(samfile,mismatches,gaps,rqs,fsam)
        
    end = time.time()
    print ("DONE in %f seconds!" %(end-start))

        
# Run the command line filter only when the file is executed as a script.
if __name__ == "__main__":
    filter()


