#!/usr/bin/python
# Gets a vector heatmap of sequencing data
# Changes max brightness algorithm
# Greg Donahue, 02-05-2014
# ------------------------------------------------------------------------------

import sys, os, datetime, time, string, math
from PIL import Image, ImageDraw, ImageFont, ImageOps

# ------------------------------------------------------------------------------
# GLOBALS

# TrueType font file and point size used for the track labels
# (both are handed to ImageFont.truetype)
font_file = "arial.ttf"
font_size = 14

# The desired plot height in pixel rows; collapse() averages a track down to
# this many rows when it has at least this many
plot_height = 600

# The margin in pixels placed around and between the parts of the image
margin = 1

# ------------------------------------------------------------------------------
# FUNCTIONS
# The main function
def main(args):
    """Render the tracks named in args into one "Densities <stamp>.png" image.

    args is the raw argument vector (args[0] is the script name). Every other
    entry is a track file, except an entry starting with "--sort=", which
    names a profile file whose row sums decide the row order (largest sum
    first). Without it the row sums of the first track are used.

    The first line of every track file is a header of the form
        ..."LABEL" COLOR_CODE SEPARATOR [LOCAL_MAX]
    (see _parseHeader); the remaining lines are read with loadProfiles.
    The image is written to the current working directory.
    """

    # Check arguments
    if len(args) == 1: sys.exit("USAGE: python heatvector.py FILE_1 FILE_2 "+ \
                                    "... FILE_N --sort=SORT_FILE")

    # Width in pixels of the bar plotTrack draws for each column
    bar_width = 30

    # Load data and meta-data
    data, meta, filenames, orderer = list(), list(), list(), dict()
    print("Loading...")
    for filename in args[1:]:
        if filename.startswith("--sort="):
            # Split once only, so a path containing "=" survives intact
            sort_file = filename.split("=", 1)[1]
            print("\tsorting on "+sort_file)
            orderer = loadProfiles(sort_file)
            continue
        print("\t"+filename)
        filenames.append(filename)
        with open(filename) as f: header = f.readline()
        meta.append(_parseHeader(header, filename))
        data.append(loadProfiles(filename))
    if not data: sys.exit("No track files given")

    # Sort data as needed (by the sort file if given, else by the first track)
    reference = orderer if orderer else data[0]
    tuples = [ (sum(reference[k]),k) for k in reference.keys() ]
    keys = [ k[1] for k in sorted(tuples, reverse=True) ]
    sorted_data = list()
    print("Sorting...")
    for i in range(len(data)):
        print("\t"+filenames[i])
        sorted_data.append([ data[i][k] for k in keys ])
    del data, tuples, keys, reference

    # Collapse into plot vectors
    plot_data = list()
    print("Collapsing...")
    for i in range(len(sorted_data)):
        print("\t"+filenames[i])
        plot_data.append(collapse(sorted_data[i]))
    del sorted_data

    # Get the max value for each track (a non-zero header value wins)
    max_values = list()
    for i in range(len(plot_data)):
        if meta[i][3] != 0.0: max_values.append(meta[i][3])
        else:
            values = [ v for V in plot_data[i] for v in V if v > 0 ]
            # A track with no positive signal has no maximum; 1.0 keeps the
            # scaling in plotTrack defined and leaves the track black
            max_values.append(max(values) if values else 1.0)
    print("Choosing local maximum values...")
    for i in range(len(plot_data)):
        print("\t"+meta[i][0]+"\t"+str(max_values[i]))

    # Determine dimensions
    dimensions = getDimensions(len(plot_data),
                               len(plot_data[0][0])*bar_width, meta)

    # Create the image
    image = Image.new("RGB", dimensions["T"], "White")
    pen = ImageDraw.Draw(image)

    # Add labels (rotate 90 degrees to save space)
    F = ImageFont.truetype(font_file, font_size)
    x, y = margin+len(plot_data[0][0])//2-font_size//2, margin
    for i in range(len(meta)):
        x += meta[i][2]
        itemp = Image.new("RGB", dimensions["F"], "White")
        pentemp = ImageDraw.Draw(itemp)
        pentemp.text((0,0), meta[i][0], font=F, fill=(0,0,0))
        # expand=1 swaps width and height; without it current Pillow keeps
        # the wide canvas and crops the rotated text
        rotated = itemp.rotate(90, expand=1)
        image.paste(rotated, (x,y))
        x += len(plot_data[i][0])+margin+bar_width

    # Draw plots (advance exactly like the label loop so both stay aligned)
    x, y = margin, y+2*margin+dimensions["F"][0]
    for i in range(len(plot_data)):
        x += meta[i][2]
        plotTrack(plot_data[i], meta[i][1], x, y, max_values[i], pen)
        x += len(plot_data[i][0])+margin+bar_width

    # Add legend...nah, maybe not
    # It's meaningless, just a normalized tag density AUC

    # Save image
    stamp = "".join([ str(s) for s in time.localtime()[:6] ])
    image.save("Densities "+stamp+".png", "PNG")

# Parses the header line of a track file
def _parseHeader(header, filename):
    """Return (label, color_code, separator, local_maximum) for a header line.

    label is the text between the first pair of double quotes. The text after
    the closing quote is split on whitespace into: a comma-separated color
    code (returned as a list of strings), an integer separator (pixels added
    before the track), and an optional float maximum (0.0 when absent).
    Raises ValueError, naming filename, when those fields are missing.
    """
    pieces = header.rstrip("\r\n").split("\"")
    fields = pieces[2].split() if len(pieces) > 2 else []
    if len(fields) < 2:
        raise ValueError("Malformed header in "+filename+": expected "+
                         "\"LABEL\" COLOR_CODE SEPARATOR [LOCAL_MAX]")
    local_maximum = float(fields[2]) if len(fields) > 2 else 0.0
    return (pieces[1], fields[0].split(","), int(fields[1]), local_maximum)

# Plots a track
def plotTrack(data, C, x, y, M, pen):
    """Draw one track with its top-left corner at (x, y).

    data -- list of rows, each a list of numbers; row r is drawn at y+r
    C    -- three channel flags (red, green, blue); a channel whose flag is
            "X" receives the scaled value, any other flag leaves it at 0
    M    -- the value that maps to full brightness
    pen  -- drawing object providing line(xy, fill=...)

    Each value is scaled to 256*value/M and clamped to 0..255, so values
    above M saturate and negative values draw as black. Column i is drawn as
    a horizontal line from x+i to x+i+30, so a later column overdraws all but
    the first pixel of the one before it.
    """
    # The channel mask is the same for every pixel, so resolve it once
    use_r, use_g, use_b = C[0] == "X", C[1] == "X", C[2] == "X"
    dy = y
    for D in data:
        for i, v in enumerate(D):
            value = max(0, min(255, int(256*float(v)/M)))
            rgb = (value if use_r else 0,
                   value if use_g else 0,
                   value if use_b else 0)
            pen.line((x+i,dy,x+i+30,dy), fill=rgb)
        dy += 1

# Get the dimensions of the plot, legend, labels, etc
def getDimensions(num_tracks, track_length, meta):
    """Compute the pixel sizes of the pieces of the image.

    num_tracks   -- number of tracks to be plotted
    track_length -- width in pixels reserved for one track
    meta         -- per-track tuples; item 0 is the label text and item 2 the
                    horizontal offset in pixels added before the track

    Returns a dict of (width, height) tuples:
        "F" -- box large enough for any label drawn horizontally
        "L" -- legend
        "P" -- plot area of a single track
        "T" -- the entire image
    """

    # Get the label size: measure every label, since in a proportional font
    # the label with the most characters need not be the widest
    font = ImageFont.truetype(font_file, font_size)
    scratch = Image.new("RGB", (plot_height,plot_height), "White")
    pen = ImageDraw.Draw(scratch)
    sizes = [ _textExtent(pen, M[0], font) for M in meta ]
    if len(sizes) == 0: sizes = [ _textExtent(pen, "", font) ]
    label = (max([ s[0] for s in sizes ]), max([ s[1] for s in sizes ]))

    # Get the plot dimensions
    plot = (track_length,plot_height)

    # Get the legend dimensions
    # Format top-to-bottom: band->M->labels
    # Band is the same width as the plot and the same height as the margin
    offsets = sum([ M[2] for M in meta ])
    legend = (track_length*num_tracks+margin*(num_tracks-1)+offsets,
              2*margin+label[1])

    # Get the dimensions of the entire image
    # Format top-to-bottom: M->labels->M->plot->M->legend->M
    # (labels are drawn rotated, so their width, label[0], counts as height)
    total = (2*margin+legend[0],4*margin+label[0]+plot[1]+legend[1])

    # Return
    return { "F":label, "L":legend, "P":plot, "T":total }

# Measures a piece of text drawn at the origin
def _textExtent(pen, text, font):
    """Return the (width, height) needed to hold text drawn at (0,0).

    Uses ImageDraw.textbbox where the installed Pillow provides it and falls
    back to the older ImageDraw.textsize otherwise.
    """
    if hasattr(pen, "textbbox"):
        box = pen.textbbox((0,0), text, font=font)
        # right/bottom edges, so any offset from the origin is included
        return (box[2], box[3])
    return pen.textsize(text, font=font)

# Collapse adjacent vectors into vector averages to shrink to fit plot height
def collapse(data, height=None):
    """Shrink a list of equal-length row vectors to exactly height rows.

    data   -- list of rows, each a list of numbers
    height -- number of rows wanted; defaults to the module's plot_height

    If data has fewer than height rows it is returned unchanged (the same
    object). Otherwise the rows are divided into height consecutive bins of
    len(data)/height rows each, and every output row is the average of the
    input rows in its bin, weighted by how much of each row the bin covers
    (rows straddling a bin edge are shared between neighbouring bins).
    """
    if height is None: height = plot_height

    # Base case - if we don't need to collapse, just return the vector as-is
    n = len(data)
    if n < height: return data

    # Walk the bins by integer index so rounding error cannot accumulate and
    # change the number of output rows
    ret, bin_size = list(), float(n)/height
    for b in range(height):
        start = b*bin_size
        # Pin the final edge to n so the last bin ends exactly at the data
        stop = n if b == height-1 else (b+1)*bin_size

        first = int(start)
        total = [ 0.0 for z in range(len(data[first])) ]
        weight, row = 0.0, first
        while row < n and row < stop:
            # Portion of input row [row, row+1) that lies inside this bin
            overlap = min(row+1, stop)-max(row, start)
            if overlap > 0:
                V = data[row]
                for z in range(len(total)): total[z] += V[z]*overlap
                weight += overlap
            row += 1

        # Normalise by the covered length, not by the number of rows touched
        ret.append([ t/weight for t in total ])

    return ret

# Gives a profile dictionary of the form (chr,start,stop)->[ VECTOR ]
def loadProfiles(filename):
    """Read a tab-separated profile file into a dict of key -> list of floats.

    The first line of the file is a header and is skipped. On every other
    line the first tab-separated field is the key and the remaining fields
    are converted with float(). Blank lines are ignored. If a key occurs more
    than once the last occurrence wins.

    Raises ValueError if a value field is not a number.
    """
    ret = dict()
    with open(filename) as f:
        next(f, None)   # skip the header line (tolerates an empty file)
        for line in f:
            # Strip only the line ending; slicing off the last character
            # would eat a digit when the final line has no newline
            line = line.rstrip("\r\n")
            if not line.strip(): continue
            t = line.split("\t")
            ret[t[0]] = [ float(x) for x in t[1:] ]
    return ret

if __name__ == "__main__":
    # Executed as a script: hand the raw argument vector to main
    main(sys.argv)
