#!/usr/bin/python
# Creates a heatmap of input data tracks
# Greg Donahue, 10-26-2011
# ------------------------------------------------------------------------------
# NOTES ON INPUT
# 1. Takes any number of text files
# 2. Each text file should be a two-column spreadsheet, with a primary key in
#    the first column and a numerical data value in the second column
# 3. Primary keys must agree between data files, and only the set of all primary
#    keys in common between all files is used to generate the final image
# 4. Input files should also have a unique identifier on the first line
# 5. The unique identifier line has a double-quoted description followed by
#    space-separated tokens: color, space, and an optional integer maximum for
#    color scaling; see example files for details of how to construct this line
# 6. We always list the tracks vertically in the order they have on the command-
#    line, and primary keys are always sorted by their values in the first track
# ------------------------------------------------------------------------------
import sys, os, datetime, time, string, math, Image, ImageDraw, ImageFont
# ------------------------------------------------------------------------------
# GLOBALS
# The dimensions of each track in pixels, as (width, height); the width is also
# the number of bins each track is averaged into (one bin per pixel column)
dimensions = (500,25)

# The amount of vertical whitespace between subsequent tracks in pixels
whitespace = 2

# The amount of border whitespace around the entire image in pixels
border = 20

# Should we label this heatmap? When True, each track's description is drawn
# to the left of the track
use_labels = True

# The font used for labeling: font file name and size, both passed to
# ImageFont.truetype
font_name = "arial.ttf"
font_size = 14

# When True, print the maximum binned value of each track as it is drawn
verbose = True

# ------------------------------------------------------------------------------
# FUNCTIONS
# Load one data track from an input file
def _loadTrack(filename):
    """Parse one input file into (description, track_meta, values).

    The first line is the identifier: a double-quoted description followed by
    space-separated color, space and (optionally) an integer color maximum.
    Every following non-blank line is "KEY<TAB>VALUE".

    Returns the description string, a (color, space, hard_color) tuple where
    hard_color is None when absent, and a dict mapping each key to its float
    value. Exits with a message on an empty or malformed file.
    """
    with open(filename) as f:
        lines = f.readlines()
    if not lines: sys.exit("Empty input file "+filename)

    # Strip only the line terminator; slicing off the last character would
    # eat real data when the final line has no trailing newline
    header = lines[0].rstrip("\r\n")
    try:
        parts = header.split("\"")
        description = parts[1]
        t = parts[2].split(" ")
        if len(t) > 3: hc = int(t[3])
        else: hc = None
        track_meta = (t[1],int(t[2]),hc)
    except (IndexError, ValueError):
        sys.exit("Malformed identifier line in "+filename)

    values = dict()
    for number, line in enumerate(lines[1:], 2):
        line = line.rstrip("\r\n")
        if not line.strip(): continue
        t = line.split("\t")
        try:
            values[t[0]] = float(t[1])
        except (IndexError, ValueError):
            sys.exit("Malformed data on line %d of %s" % (number, filename))
    return description, track_meta, values

# Main function
def main(args):
    """Build the heatmap from the files named in args[1:] and save it as a PNG.

    args is the full argument vector (args[0] is the program name). Tracks are
    drawn top to bottom in command-line order, and keys are ordered by their
    value in the first track. The image is written to the current directory as
    Heatmap_<timestamp>.png.
    """

    # Check arguments
    if len(args) < 2: sys.exit("USAGE: python heatmap.py FILE_1 FILE_2 ..."+ \
                                   " FILE_N")
    for filename in args[1:]:
        if not os.path.isfile(filename):
            sys.exit("Could not find "+filename)

    # Load data tracks from each input file
    data, meta, order = dict(), dict(), list()
    for filename in args[1:]:
        description, track_meta, values = _loadTrack(filename)
        # A repeated identifier would overwrite the earlier track's data while
        # still being drawn twice, so refuse it up front
        if description in data:
            sys.exit("Duplicate track identifier \""+description+"\" in "+ \
                         filename)
        order.append(description)
        data[description] = values
        meta[description] = track_meta
    print("Loaded %d input files" % (len(args)-1))

    # Get the common set of unique primary keys in rank-order
    keys = set(data[order[0]])
    for D in order[1:]: keys.intersection_update(data[D])
    first = data[order[0]]
    keys = sorted(keys, key=lambda k: first[k])
    print("Found %d primary keys in common" % len(keys))
    if not keys: sys.exit("No primary keys in common between input files")

    # Figure out the dimensions of the heatmap and any text labels
    width, height, text_x, text_y = getDimensions(meta)

    # Create the image
    image = Image.new("RGB", (width,height), "White")
    pen = ImageDraw.Draw(image)

    # Get the position of the first track
    x, y = border, border
    if use_labels: x += (text_x+text_x//4)

    # Get the vertical increment size
    increment = whitespace+dimensions[1]
    if use_labels and text_y > increment: increment = whitespace+text_y

    # Draw the heatmap tracks
    for D in order:
        # Each track's "space" token is extra vertical offset above the track
        y += meta[D][1]
        drawTrack(pen, keys, data[D], (x,y), meta[D][0], hard_color=meta[D][2])
        y += increment

    # If we are including labels, include labels
    if use_labels:
        F = ImageFont.truetype(font_name, font_size)
        x, y = border-border//4, border+dimensions[1]//4
        for D in order:
            y += meta[D][1]
            pen.text((x,y), D, font=F, fill=(0,0,0))
            y += increment

    # Save the image
    stamp = "".join([ str(s) for s in time.localtime()[:6] ])
    image.save("Heatmap_"+stamp+".png", "PNG")

# Get the dimensions of the heatmap image
def getDimensions(meta):
    """Return (width, height, label_width, label_height) for the image.

    meta maps each track description to a tuple whose element [1] is that
    track's extra vertical spacing in pixels. width and height are the full
    image size in pixels. label_width and label_height are the largest width
    and height over all rendered descriptions, or 0 when labels are disabled
    (the font is not loaded at all in that case).
    """
    count = len(meta)

    # Set the width and height to include the tracks and the border
    width, height = dimensions[0]+border*2, dimensions[1]*count+border*2

    # Measure every label: with a proportional font the label with the most
    # characters is not necessarily the widest one in pixels
    x, y = 0, 0
    if use_labels and count > 0:
        F = ImageFont.truetype(font_name, font_size)
        pen = ImageDraw.Draw(Image.new("RGB", (300,300), "White"))
        for label in meta:
            w, h = pen.textsize(label, font=F)
            x, y = max(x, w), max(y, h)

        # Add the width of the biggest label + buffer
        width += (x+x//4)

    # If we are using labels and the height of the biggest label is bigger than
    # the track height plus vertical whitespace, use the label height as the
    # whitespace quantity
    # NOTE(review): main() steps by whitespace+label height in this case, so
    # this allots more vertical room than main() actually uses
    increment = whitespace
    if use_labels and y > increment+dimensions[1]: increment = y
    height += increment*max(count-1, 0)

    # Add any other vertical spacing specified by individual tracks
    height += sum([ meta[k][1] for k in meta ])

    # Return the width and height
    return width, height, x, y

# Draw a heatmap track
def drawTrack(pen, keys, data, origin, color_code, hard_color=None):
    """Draw one track as dimensions[0] vertical lines, one per pixel column.

    pen         -- drawing object; only its line() method is used
    keys        -- ordered list of primary keys to plot
    data        -- dict mapping each key to its numeric value
    origin      -- (x, y) pixel position of the track's top-left corner
    color_code  -- "R,G,B" color string handed to getColor
    hard_color  -- if not None, used as the color-scale maximum instead of
                   the largest binned value

    Keys are split into dimensions[0] contiguous bins and each bin's mean
    sets the color of its column. Nothing is drawn when keys is empty.
    """
    x, y = origin
    columns, count = dimensions[0], len(keys)
    if count == 0 or columns <= 0: return

    # Proportional bin edges cover every key even when the count is not a
    # multiple of the width; when there are fewer keys than columns each bin
    # still holds at least one key, so the mean is always defined
    averages = []
    for i in range(columns):
        start = i*count//columns
        end = max((i+1)*count//columns, start+1)
        values = [ data[k] for k in keys[start:end] ]
        averages.append(float(sum(values))/len(values))

    minimum = min(averages)
    if hard_color is None: maximum = max(averages)
    else: maximum = hard_color
    if verbose: print(max(averages))

    for i, v in enumerate(averages):
        color = getColor(color_code, v, minimum, maximum)
        pen.line((x+i,y,x+i,y+dimensions[1]), fill=color)

# Get the color for a given pixel in a heatmap track
def getColor(color_code, value, minimum, maximum):
    """Return the (R, G, B) tuple for one heatmap column.

    color_code is an "R,G,B" string in which each channel is either an integer
    literal, used as-is, or "X", which is replaced by the scaled intensity
    255*value/(maximum-minimum) clamped to 0..255.

    When maximum equals minimum there is no range to scale over, so the
    intensity saturates: 255 for a positive value, 0 otherwise.
    Raises ValueError if color_code does not have exactly three channels.
    """
    R, G, B = color_code.split(",")
    channels = (R, G, B)

    # Only compute the intensity when some channel actually needs it
    intensity = None
    if "X" in channels:
        span = maximum-minimum
        if span == 0:
            if value > 0: intensity = 255
            else: intensity = 0
        else:
            intensity = int(255*value/float(span))
            # Values beyond the scale (e.g. above a hard maximum) saturate
            intensity = max(0, min(255, intensity))

    return tuple([ intensity if c == "X" else int(c) for c in channels ])

# ------------------------------------------------------------------------------
# Script entry point: hand the full argument vector to main() when this file
# is run from the command line rather than imported
if __name__ == "__main__":
    main(sys.argv)

# ------------------------------------------------------------------------------
# EOF
