#!/usr/bin/python
# Contains classes for easy access to sample data
# Greg Donahue, 06-26-2017

import sys, string
from chipseq import *

class SampleData:
    """Patient / cohort assignments loaded from "SampleSheet.csv".

    The sheet is opened relative to the current working directory. Every
    non-blank line is treated as a data row of comma-separated fields:
    patient, cohort, cohort code. Any further fields are ignored.
    """

    # patients is a list of patients
    # cohorts is a list of cohorts (study groups)
    # patient_map gives the cohort for each patient
    # cohort_map gives the patient list for each cohort
    # cohort_codes is a map of single letter cohort codes

    def __init__(self):
        """Read "SampleSheet.csv" and build the lookup tables.

        Raises IOError/OSError if the file cannot be opened and IndexError if
        a non-blank row has fewer than three fields.
        """
        self.patient_map = dict()
        self.cohort_map = dict()
        self.cohort_codes = dict()
        # "with" closes the handle even if reading fails part-way.
        with open("SampleSheet.csv") as f:
            lines = f.readlines()
        for line in lines:
            # Strip only the line terminator; slicing off the last character
            # would eat real data when the final line has no trailing newline.
            line = line.rstrip("\r\n")
            if not line.strip():
                continue  # tolerate blank / trailing empty lines
            t = line.split(",")
            patient, cohort, code = t[0], t[1], t[2]
            self.patient_map[patient] = cohort
            # A patient listed on several rows is appended once per row.
            self.cohort_map.setdefault(cohort, []).append(patient)
            self.cohort_codes[cohort] = code
        # Dictionary keys are already unique, so no de-duplication is needed.
        self.patients = list(self.patient_map)
        self.cohorts = list(self.cohort_map)

class CovariateData:
    """Per-patient covariates loaded from "SampleCovariates.csv".

    The file is opened relative to the current working directory. Every
    non-blank line is treated as a data row of comma-separated fields:
    patient, neuronal fraction (parsed as a float), batch. Any further
    fields are ignored.
    """

    # patients is a list of patients
    # neuronal_fraction maps patients to their neuron / total tissue proportion
    # batch maps patient samples to their sequencing batch

    def __init__(self):
        """Read "SampleCovariates.csv" and build the covariate maps.

        Raises IOError/OSError if the file cannot be opened, IndexError if a
        non-blank row has fewer than three fields, and ValueError if the
        second field of a row is not a number.
        """
        self.neuronal_fraction = dict()
        self.batch = dict()
        # "with" closes the handle even if reading fails part-way.
        with open("SampleCovariates.csv") as f:
            lines = f.readlines()
        for line in lines:
            # Strip only the line terminator; slicing off the last character
            # would truncate the batch label when the final line has no
            # trailing newline.
            line = line.rstrip("\r\n")
            if not line.strip():
                continue  # tolerate blank / trailing empty lines
            t = line.split(",")
            patient = t[0]
            self.neuronal_fraction[patient] = float(t[1])
            self.batch[patient] = t[2]
        # Materialise a real list so the attribute is a list on every Python
        # version (dict.keys() is only a view on Python 3).
        self.patients = list(self.batch)

class RNAseqData:
    """Expression values and p-values loaded from "DESeq Table.txt".

    The file is opened relative to the current working directory and parsed
    as tab-separated text. The first line is a header whose first field is
    skipped. Each remaining header is either a comparison column written as
    "p(A:B)" or a cohort column. Each following non-blank line starts with a
    transcript identifier followed by one value per header column.
    """

    # cohorts is a list of cohorts in this analysis
    # comparisons is the list of comparisons for which p-values were calculated
    # expression is a dictionary mapping cohort average values for each transcript
    # significance is a dictionary mapping p-values for each transcript in each comparison

    def __init__(self):
        """Read "DESeq Table.txt" and build the expression / p-value tables.

        Fields that cannot be parsed as a float are stored as the string
        "N/A". Raises IOError/OSError if the file cannot be opened and
        IndexError if the file is empty or a row has more value fields than
        the header has columns.
        """
        # "with" closes the handle even if reading fails part-way.
        with open("DESeq Table.txt") as f:
            lines = f.readlines()
        # Strip only the line terminator; slicing off the last character
        # corrupts the final field of a line that has no trailing newline.
        headers = lines[0].rstrip("\r\n").split("\t")[1:]
        self.cohorts, self.comparisons = list(), list()
        self.expression, self.significance = dict(), dict()

        # Classify each header once: (is_pvalue, key) per column, so the row
        # loop does not re-parse header text for every cell.
        columns = list()
        for H in headers:
            # The slice below assumes the header *starts* with "p(", so test
            # for the prefix; a substring test would misread a cohort header
            # such as "exp(X)" as a comparison.
            if H.startswith("p("):
                comparison = tuple(H[2:-1].split(":"))
                self.comparisons.append(comparison)
                columns.append((True, comparison))
            else:
                self.cohorts.append(H)
                columns.append((False, H))

        for line in lines[1:]:
            line = line.rstrip("\r\n")
            if not line.strip():
                continue  # tolerate blank / trailing empty lines
            t = line.split("\t")
            expression, significance = dict(), dict()
            self.expression[t[0]], self.significance[t[0]] = expression, significance
            for i, field in enumerate(t[1:]):
                try:
                    value = float(field)
                except ValueError:
                    value = "N/A"
                is_pvalue, key = columns[i]
                if is_pvalue:
                    significance[key] = value
                else:
                    expression[key] = value

# Script entry point: when run directly, call main with the command-line arguments.
# NOTE(review): main is not defined in this file; presumably it is supplied by
# "from chipseq import *" -- confirm, otherwise this line raises NameError.
if __name__ == "__main__": main(sys.argv)
