import matplotlib.pyplot as plt
import colors
import model
plt.ion()
#import os 
from scatterplotimage import ScatterPlotter
#import clustering
#import model
import utilities
#import colors
#import optimizers
#import plottingintermdata
#import sys
import numpy as np
#from scipy import stats
#import itertools


def tic():
    """Start (or restart) a MATLAB-style stopwatch.

    Stores the current wall-clock time, as returned by ``time.time()``, in
    the module-level global ``startTime_for_tictoc`` so that a later
    ``toc()`` can report how long has passed. Calling it again overwrites
    the stored start time.
    """
    global startTime_for_tictoc
    from time import time as wall_clock
    startTime_for_tictoc = wall_clock()

def toc():
    """Print the wall-clock seconds elapsed since the last ``tic()``.

    Reads the module-level global ``startTime_for_tictoc`` written by
    ``tic()``. If ``tic()`` has not been called yet, a notice is printed
    instead of raising.
    """
    import time
    try:
        started = startTime_for_tictoc
    except NameError:
        print("Toc: start time not set")
        return
    print("Elapsed time is " + str(time.time() - started) + " seconds.")

def MahalanobisDist(x, y):
    """Return the Mahalanobis distance of each (x[i], y[i]) pair from the mean.

    ``x`` and ``y`` hold the two coordinates of the same points as separate
    1-D sequences of equal length. The distance of pair ``i`` is
    ``sqrt(d_i^T * inv(C) * d_i)``, where ``d_i`` is the pair minus the
    per-coordinate sample means and ``C`` is the 2x2 sample covariance of
    (x, y) computed by ``np.cov``.

    Returns
    -------
    list
        One distance per pair, in input order (kept as a list of numpy
        floats so existing callers see the same container type).

    Raises
    ------
    numpy.linalg.LinAlgError
        If the sample covariance of (x, y) is singular.
    """
    covariance_xy = np.cov(x, y, rowvar=0)
    inv_covariance_xy = np.linalg.inv(covariance_xy)
    # Rows of diff_xy are (x_i - mean_x, y_i - mean_y).
    diff_xy = np.column_stack((np.asarray(x, dtype=float) - np.mean(x),
                               np.asarray(y, dtype=float) - np.mean(y)))
    # Row-wise quadratic form d_i . inv(C) . d_i, vectorised instead of a
    # Python loop over indices.
    squared = np.sum(diff_xy.dot(inv_covariance_xy) * diff_xy, axis=1)
    # Rounding can push a near-zero quadratic form slightly negative; clip so
    # sqrt never yields NaN.
    return list(np.sqrt(np.maximum(squared, 0.0)))

def MD_removeOutliers(x, y, Threshold):
    """Return the indices of (x, y) pairs flagged as Mahalanobis outliers.

    A pair is an outlier when its distance from ``MahalanobisDist`` is
    strictly greater than ``Threshold`` times the mean distance over all
    pairs. Despite the name, nothing is removed here: the caller receives
    the indices and decides what to do with them.

    Returns
    -------
    numpy.ndarray of int
        Positions of the outlying pairs in ascending order. When there are
        none the array is empty but still integer-typed, so it can be used
        directly as an index (``np.array([])`` would be float and cannot).
    """
    distances = np.asarray(MahalanobisDist(x, y))
    cutoff = np.mean(distances) * Threshold  # adjust Threshold accordingly
    return np.flatnonzero(distances > cutoff)


def MahalanobisDistance(datapoints):
    """Return the Mahalanobis distance of every row from the sample mean.

    Parameters
    ----------
    datapoints : array_like, shape (n_points, n_dims)
        One observation per row; ``rowvar=0`` makes ``np.cov`` treat the
        columns as the variables.

    Returns
    -------
    numpy.ndarray, shape (n_points,)
        ``sqrt((p - mean)^T * inv(C) * (p - mean))`` for each row ``p``,
        where ``mean`` is the column-wise sample mean and ``C`` the sample
        covariance from ``np.cov``.

    Raises
    ------
    numpy.linalg.LinAlgError
        If the sample covariance matrix is singular (e.g. too few points or
        collinear data).
    """
    points = np.asarray(datapoints, dtype=float)
    inv_covariance = np.linalg.inv(np.cov(points, rowvar=0))
    diff = points - np.mean(points, axis=0)
    # Row-wise quadratic form diff[i] . inv(C) . diff[i], vectorised; this
    # replaces a Python loop that also depended on Python-2-only xrange.
    squared = np.sum(diff.dot(inv_covariance) * diff, axis=1)
    # Clip tiny negative values caused by rounding so sqrt cannot give NaN.
    return np.sqrt(np.maximum(squared, 0.0))


def MahalanobisOutlierDetection(datapoints, Threshold):
    """Return the row indices of ``datapoints`` with a large Mahalanobis distance.

    A row is flagged when its distance, as computed by
    ``MahalanobisDistance``, is strictly greater than ``Threshold`` times
    the mean distance of all rows.

    Returns
    -------
    numpy.ndarray of int
        Indices of the flagged rows in ascending order (possibly empty).
    """
    distances = MahalanobisDistance(datapoints)
    is_outlier = distances > Threshold * np.mean(distances)  # adjust Threshold accordingly
    return np.flatnonzero(is_outlier)


# Get a scatter plotter; Filter() draws through SP.FillPlot.
# NOTE(review): the meaning of the constructor argument 53 is defined in
# scatterplotimage.ScatterPlotter and is not visible here -- confirm.
SP = ScatterPlotter(53)

# Input CSV, given as a relative path (presumably resolved against the
# current working directory -- confirm how utilities opens it).
# getDataPointsWithClustersFromCSVFile returns five values; within this file
# only `datapoints` is used afterwards.
# NOTE(review): datapoints is assumed to be an (n_points, 2) array -- Filter
# passes it to MahalanobisDistance, which uses np.cov(..., rowvar=0); confirm.
dataFilePath = "../data/studyData/corr/randNorm_10000pnts_1clusters_6.csv"
datapoints, datapointsClusters, datapointsClusters_woOutliers, \
outliers_foreachcluster, corr_perCluster = \
        utilities.getDataPointsWithClustersFromCSVFile(dataFilePath, hasHeading=True, shuffle=False)

# Bounding box and rendering settings; Filter() passes all of these
# unchanged to SP.FillPlot.
DataBoundingBox = utilities.getBBox(datapoints)
marker_size = 20
marker_opacity = 100
image_width = 1000
image_aspect_ratio = 1

# One figure/axes pair shared by every Filter() call (Filter clears the axes
# before redrawing). The figure background is set fully transparent
# (alpha 0), and the axis decorations and grid are switched off.
fig, scatter_plot = plt.subplots();
fig.tight_layout()
background = fig.patch
background.set_facecolor((1,1,1,0))
scatter_plot.axis('off')
scatter_plot.grid(False)


def Filter(NewThreshold):
    """Redraw the scatter plot with Mahalanobis outliers highlighted.

    Steps:
      1. Clear the module-level ``scatter_plot`` axes.
      2. Flag the rows of the module-level ``datapoints`` whose Mahalanobis
         distance exceeds ``NewThreshold`` times the mean distance, printing
         the detection time via tic()/toc().
      3. Pass the points and a per-point colour array to ``SP.FillPlot``,
         together with the module-level plot settings. The colour array is
         (0, 0, 0) for inliers and (1, 0, 0) for outliers (presumably read
         as RGB by FillPlot).

    Parameters
    ----------
    NewThreshold : float
        Multiplier on the mean Mahalanobis distance that gives the outlier
        cutoff.
    """
    scatter_plot.clear()

    point_colors = np.zeros((len(datapoints), 3))
    tic()
    outlier_rows = MahalanobisOutlierDetection(datapoints, NewThreshold)
    toc()
    # Inliers keep (0, 0, 0); outlier rows get their first channel set to 1.
    if len(outlier_rows):
        point_colors[outlier_rows] += [1, 0, 0]

    SP.FillPlot(
        fig, scatter_plot, datapoints, DataBoundingBox,
        marker_size, marker_opacity, image_width, image_aspect_ratio,
        point_colors,
    )



# Initial render: flag points whose Mahalanobis distance exceeds 1.5x the mean.
Filter(1.5)

#if __name__ == "__main__":

