#!/opt/software/python/2.7.5/bin/python
# Filters/collapses peaks, then assesses and re-defines MTLs
# Greg Donahue, 06-26-2017
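#
# Pipeline overview (see main()):
#   1. Convert MACS2 peak XLS files in Peaks/ into FDR-filtered BED files
#      (peak files are expected to follow the H4K16ac.<cohort|patient>_peaks naming used below)
#   2. Intersect each cohort's pooled peaks with each patient's peaks (Associations/)
#   3. Keep pooled cohort peaks supported by at least min_patients_per_peak patients
#      (Peaks/H4K16ac.<cohort>.Multipatient.bed)
#   4. Collapse nearby peaks within each cohort (Peaks/H4K16ac.<cohort>.Joined.bed)
#   5. Join peak centers across cohorts into MTLs (MTLs/MTLs.bed)
#   6. Expand each MTL to the bounding box of its overlapping joined peaks and merge MTLs
#      that end up with identical coordinates (MTLs/MTLs.Expanded.bed)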

import sys, string, os, subprocess
from chipseq import *
from sampledata import *

meta = SampleData()
fdr_cutoff = 0.001                  # FDR threshold applied when extracting MACS2 peaks
min_patients_per_peak = 1           # minimum number of patients whose peaks must overlap a pooled cohort peak
within_cohort_join_distance = 150   # maximum gap (bp) between peaks collapsed within a cohort
mtl_join_distance = 200             # maximum distance (bp) between peak centers joined into one MTL
mtl_margin = 50                     # padding (bp) added on each side of an MTL's outermost peak centers

def main(args):
    
    # Process all MACS2 peak XLS files in Peaks/ to get FDR-filtered BED files
    for filename in os.listdir("Peaks"):
        if filename.endswith(".xls"): extractMacs2Peaks("Peaks/"+filename, fdr_cutoff)
        
    # Intersect each cohort's pooled peak file with each patient's peak file;
    # the overlaps are counted below to measure patient support per pooled peak
    for cohort in meta.cohorts:
        for patient in meta.cohort_map[cohort]:
            command = [ "bedtools", "intersect", "-wa", "-wb",
                        "-a", "Peaks/H4K16ac."+cohort+"_peaks.FDR_"+str(fdr_cutoff)+".bed",
                        "-b", "Peaks/H4K16ac."+patient+"_peaks.FDR_"+str(fdr_cutoff)+".bed" ]
            f = open("Associations/"+cohort+"."+patient+".txt", 'w')
            f.write(subprocess.check_output(command))
            f.close()

    # Count patient support for each pooled cohort peak and keep those seen in
    # at least min_patients_per_peak patients
    association_counts = loadAssociationCounts()
    for cohort in meta.cohorts:
        cohort_peaks = filterPeaksPerPatient(association_counts[cohort])
        f = open("Peaks/H4K16ac."+cohort+".Multipatient.bed", 'w')
        for CP in sorted(cohort_peaks): f.write(string.join([ str(x) for x in CP ], "\t")+"\n")
        f.close()
        print "Found", len(cohort_peaks), "peaks satisfying patient threshold among", cohort, "patients"

    # Collapse nearby peaks (within within_cohort_join_distance bp) in each cohort, then build MTLs across cohorts
    processed_peaks = dict()
    for cohort in meta.cohorts:
        collapsePeaks(cohort)
        processed_peaks[cohort] = loadLocusDictionary("Peaks/H4K16ac."+cohort+".Joined.bed")
    cohereMTLs(processed_peaks)
    del processed_peaks

    # Associate pooled peaks with MTLs and redefine MTLs using peak dimensions
    for cohort in meta.cohorts:
        command = [ "bedtools", "intersect", "-wa", "-wb",
                    "-a", "Peaks/H4K16ac."+cohort+".Joined.bed",
                    "-b", "MTLs/MTLs.bed" ]
        f = open("Associations/"+cohort+".MTLs.txt", 'w')
        f.write(subprocess.check_output(command))
        f.close()
    mtls = loadMTLs("MTLs/MTLs.bed")
    refineMTLs(mtls)
    collapseRedundantMTLs()

# Collapses redundant MTLs (same locus, different membership) by merging membership
def collapseRedundantMTLs():
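    # Example (hypothetical cohort names): two lines
    #     chr1  950  1170  AD       0  +
    #     chr1  950  1170  Control  0  +
    # are rewritten as a single line with membership "AD.Control"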
    f = open("MTLs/MTLs.Expanded.bed"); lines = f.readlines(); f.close()
    mtls = dict()
    for line in lines:
        t = line[:-1].split("\t")
        mtls.setdefault((t[0],t[1],t[2]), []).extend(t[3].split("."))
    f = open("MTLs/MTLs.Expanded.bed", 'w')
    for M in mtls.keys():
        membership = sorted(list(set(mtls[M])))
        f.write(M[0]+"\t"+M[1]+"\t"+M[2]+"\t"+string.join(membership, ".")+"\t0\t+\n")
    f.close()

# Redefine MTLs using overlapping pooled peaks as a bounding box
def refineMTLs(mtls):
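    # Each MTL from MTLs/MTLs.bed is written to MTLs/MTLs.Expanded.bed with its start/stop
    # replaced by the minimum start and maximum stop of the joined cohort peaks overlapping it
    # (from Associations/<cohort>.MTLs.txt); the membership string is carried over unchanged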
    associated = dict()
    for cohort in meta.cohorts: associated[cohort] = loadAssociations("Associations/"+cohort+".MTLs.txt", reverse=True)
    f = open("MTLs/MTLs.Expanded.bed", 'w')
    for M in mtls.keys():
        total = list()
        for cohort in meta.cohorts:
            try: total.extend(associated[cohort][M])
            except KeyError: pass # no joined peak from this cohort overlaps the MTL
        if len(total) == 0: raise Exception("Cannot find peaks for MTL "+M[0]+":"+str(M[1])+"-"+str(M[2]))
        start, stop = min([ x[1] for x in total ]), max([ x[2] for x in total ])
        f.write(M[0]+"\t"+str(start)+"\t"+str(stop)+"\t"+mtls[M]+"\t0\t+\n")
    f.close()

# Loads all MTLs into a dictionary, giving the membership
def loadMTLs(filename):
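    # Keys are full BED6 records as tuples; values are the name field (the "."-joined membership)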
    ret = dict()
    f = open(filename); lines = f.readlines(); f.close()
    for line in lines:
        t = line[:-1].split("\t")
        ret[(t[0],int(t[1]),int(t[2]),t[3],float(t[4]),t[5])] = t[3]
    return ret

# Create MTLs from pooled-patient peaks in different cohorts
def cohereMTLs(peaks):
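    # Consecutive peak centers on a chromosome that are less than mtl_join_distance bp apart
    # join the same MTL, which spans from the first center minus mtl_margin to the last center
    # plus mtl_margin. Example (hypothetical cohorts): centers at chr1:1000 (AD) and
    # chr1:1120 (Control) are 120 bp apart and yield chr1:950-1170 with membership "AD.Control"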
    pan = getPanLocusDictionary(peaks)
    f = open("MTLs/MTLs.bed", 'w')
    for chromosome in pan.keys():
        pan[chromosome] = sorted(pan[chromosome])
        mtls, previous, current = list(), pan[chromosome][0], [ pan[chromosome][0] ]
        for locus in pan[chromosome][1:]:
            if locus[0]-previous[0] < mtl_join_distance: current.append(locus)
            else:
                members = sorted(list(set([ x[1] for x in current ])))
                coordinates = (current[0][0]-mtl_margin,current[-1][0]+mtl_margin)
                mtls.append((coordinates, members))
                current = [ locus ]
            previous = locus
        members = sorted(list(set([ x[1] for x in current ])))
        coordinates = (current[0][0]-mtl_margin,current[-1][0]+mtl_margin)
        mtls.append((coordinates, members))
        for M in mtls: f.write(chromosome+"\t"+str(M[0][0])+"\t"+str(M[0][1])+"\t"+string.join(M[1], ".")+"\t0\t+\n")
    f.close()

# Get a merged locus dictionary, retaining peak centers + cohort info
def getPanLocusDictionary(loci):
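    # Flattens {cohort: {chromosome: [peak, ...]}} into {chromosome: [(center, cohort), ...]},
    # where center is the integer midpoint start + (stop - start) / 2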
    ret = dict()
    for cohort in loci.keys():
        for chromosome in loci[cohort].keys():
            if chromosome not in ret: ret[chromosome] = list()
            for locus in loci[cohort][chromosome]:
                C = int(locus[1]+(locus[2]-locus[1])/2)
                ret[chromosome].append((C, cohort))
    return ret
    
# Filters out cohort peaks supported by fewer than min_patients_per_peak patients
def filterPeaksPerPatient(counts):
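    # counts maps each cohort peak (BED6 tuple) to the number of patients whose peaks overlap it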
    ret = list()
    for peak in counts.keys():
        if counts[peak] >= min_patients_per_peak: ret.append(peak)
    return list(set(ret))

# Collapse peaks, creating a new joined peak file
def collapsePeaks(cohort):
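    # Peaks in Peaks/H4K16ac.<cohort>.Multipatient.bed whose start lies within
    # within_cohort_join_distance bp of the preceding peak's stop are merged; each merged run
    # is written to Peaks/H4K16ac.<cohort>.Joined.bed as one BED6 interval (min start, max stop)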
    peaks = loadLocusDictionary("Peaks/H4K16ac."+cohort+".Multipatient.bed")
    f = open("Peaks/H4K16ac."+cohort+".Joined.bed", 'w')
    for chromosome in peaks.keys():
        if len(peaks[chromosome]) == 1:
            P = peaks[chromosome][0]
            f.write(string.join([ str(x) for x in P[:3] ], "\t")+"\t.\t0\t+\n")
            continue
        loci = sorted(peaks[chromosome])
        previous, current = loci[0], [ loci[0] ]
        for locus in loci[1:]:
            if locus[1]-previous[2] <= within_cohort_join_distance: current.append(locus)
            else:
                start, stop = min([ x[1] for x in current ]), max([ x[2] for x in current ])
                f.write(chromosome+"\t"+str(start)+"\t"+str(stop)+"\t.\t0\t+\n")
                current = [ locus ]
            previous = locus
        start, stop = min([ x[1] for x in current ]), max([ x[2] for x in current ])
        f.write(chromosome+"\t"+str(start)+"\t"+str(stop)+"\t.\t0\t+\n")
    f.close()

# For each cohort peak, records the number of patients with overlapping peaks
# Cohort peaks without overlaps do not appear in the reported dictionary
def loadAssociationCounts():
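    # Returns {cohort: {cohort peak: number of patients with at least one overlapping peak}},
    # built from the per-patient Associations/<cohort>.<patient>.txt files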
    ret = dict()
    for cohort in meta.cohorts:
        ret[cohort] = dict()
        for patient in meta.cohort_map[cohort]:
            associations = loadAssociations("Associations/"+cohort+"."+patient+".txt")
            for cohort_peak in associations.keys():
                ret[cohort][cohort_peak] = ret[cohort].get(cohort_peak, 0) + 1
            del associations
    return ret

# Loads a map between two BED6 files (bedtools intersect output)
def loadAssociations(filename, reverse=False):
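    # Parses 12-column "bedtools intersect -wa -wb" output: six BED6 fields from the -a file
    # followed by six from the -b file. Keys are -a peaks (or -b peaks if reverse=True),
    # values are lists of the overlapping partner peaks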
    ret = dict()
    f = open(filename); lines = f.readlines(); f.close()
    for line in lines:
        t = line[:-1].split("\t")
        cohort_peak = (t[0],int(t[1]),int(t[2]),t[3],float(t[4]),t[5])
        patient_peak = (t[6],int(t[7]),int(t[8]),t[9],float(t[10]),t[11])
        if reverse: ret.setdefault(patient_peak, []).append(cohort_peak)
        else: ret.setdefault(cohort_peak, []).append(patient_peak)
    return ret

if __name__ == "__main__": main(sys.argv)
