import os 
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from scatterplotimage import ScatterPlotter
#import clustering
import model
import utilities
import colors
import optimizers
import plottingintermdata
import sys
import numpy as np
from scipy import stats
import itertools

from skimage import color
from skimage.feature import canny

def GeneratePlotImageFileFast(SP,Table, outImageFilePath,
                          data_bbox, MarkerSize, MarkerOpacity,
                          ImageWidth, ImageAspectRatio, bSaveWithAxes,  PlotColor = (0,0,0) ):
    """Render a scatterplot, return its pixel buffer and optionally save it.

    Parameters
    ----------
    SP : object providing ``GeneratePlot(...)`` and ``fig2data(fig)``
        (a ``ScatterPlotter`` in this script).
    Table : the data points to draw; forwarded unchanged to ``SP.GeneratePlot``.
    outImageFilePath : path of the image file to write, or a falsy value
        (``None`` / ``""``) to skip saving.
    data_bbox, MarkerSize, MarkerOpacity, ImageWidth, ImageAspectRatio :
        forwarded unchanged to ``SP.GeneratePlot``.
    bSaveWithAxes : when true, the axes are switched back on before the figure
        is written to disk. The returned buffer is always captured with the
        axes switched off.
    PlotColor : colour specification forwarded to ``SP.GeneratePlot``;
        defaults to ``(0, 0, 0)``.

    Returns
    -------
    Whatever ``SP.fig2data(fig)`` returns -- presumably an RGBA pixel array of
    the rendered figure (TODO confirm against ScatterPlotter).
    """
    fig, scatter_plot = SP.GeneratePlot(Table, data_bbox,
                                        MarkerSize, MarkerOpacity,
                                        ImageWidth, ImageAspectRatio, PlotColor)

    try:
        # The buffer that gets analysed never contains axes.
        scatter_plot.axis('off')
        RGBABuffer = SP.fig2data(fig)

        # Save with exactly the same dpi setting as we analyze the figure
        if outImageFilePath:
            utilities.ensureDirectory(outImageFilePath)
            if bSaveWithAxes:
                scatter_plot.axis('on')
            fig.savefig(outImageFilePath, dpi=fig.dpi)
    finally:
        # Close the figure even when capturing or saving fails; pyplot keeps
        # every open figure alive until it is closed explicitly.
        plt.close(fig)

    return RGBABuffer


def getPerceivedPointsCannyParam(image_rgba, cannysigma = 4):
    """Detect the perceived edges of the points drawn in a scatterplot image.

    Parameters
    ----------
    image_rgba : array-like holding the pixels of the scatterplot image.
        Both 3-channel (RGB) and 4-channel (RGBA) images are accepted; the
        alpha channel, if present, is ignored.
    cannysigma : ``sigma`` passed to the Canny edge detector (default 4).
        More about parameter setting: http://wintopo.com/help/html/canny-opt.htm

    Returns
    -------
    ``(True, data_edges)`` when at least one edge point was found, where
    ``data_edges`` is the result of ``model.convertBooleanPixelMapToDataPoints``
    applied to the boolean edge map; ``(False, None)`` otherwise.
    """
    # rgb2gray expects exactly three colour channels. Newer scikit-image
    # releases reject 4-channel input, while older ones silently dropped the
    # alpha channel, so drop it explicitly to behave the same on both.
    image_shape = np.shape(image_rgba)
    if len(image_shape) == 3 and image_shape[-1] == 4:
        image_rgb = np.asarray(image_rgba)[..., :3]
    else:
        image_rgb = image_rgba

    # change to grayscale
    image_gray = color.rgb2gray(image_rgb)

    # use canny edge detector to detect edges in the grayscale image
    edges = canny(image_gray, sigma=cannysigma)

    # convert boolean edges to points
    data_edges = model.convertBooleanPixelMapToDataPoints(edges)
    if len(data_edges) == 0:
        return (False, None)

    return (True, data_edges)

def PlotChart(design, plotFilePath, datapointsClusters, datapoints, ClustersColors, SP, DataBoundingBox, plotAxes):
    """Draw all clusters into one scatterplot image, one colour per cluster.

    Parameters
    ----------
    design : indexable whose first four entries are passed on as marker size,
        marker opacity, image width and image aspect ratio.
    plotFilePath : path of the image file to write.
    datapointsClusters : the clusters, either a sequence or a mapping from
        cluster id to that cluster's points. Only the length of each cluster
        is used here.
    datapoints : all points to draw. They are assumed to be ordered cluster by
        cluster, in ascending cluster id, because colours are assigned to
        consecutive row ranges.
    ClustersColors : indexable by cluster id, giving the colour of each cluster.
    SP : plotter object forwarded to ``GeneratePlotImageFileFast``.
    DataBoundingBox : bounding box forwarded to ``GeneratePlotImageFileFast``.
    plotAxes : whether the saved image should include the axes.
    """
    # A mapping may have gaps in its ids (e.g. clusters for which nothing was
    # detected), so walk its actual keys instead of assuming 0..n-1.
    if hasattr(datapointsClusters, "keys"):
        cluster_ids = sorted(datapointsClusters.keys())
    else:
        cluster_ids = range(len(datapointsClusters))

    # creating the vector of color for each data point
    total_colors = np.zeros((len(datapoints), 3))

    start_off = 0
    for cluster_id in cluster_ids:
        end_off = start_off + len(datapointsClusters[cluster_id])
        total_colors[start_off:end_off] = ClustersColors[cluster_id]
        start_off = end_off

    GeneratePlotImageFileFast(SP, datapoints, plotFilePath, DataBoundingBox,
                              design[0], design[1], design[2], design[3],
                              plotAxes, total_colors)



# ---------------------------------------------------------------------------
# Script: for every data file, save the original coloured scatterplot and, for
# every Canny sigma, a plot of the edge points detected on each cluster.
# ---------------------------------------------------------------------------

# Alternative input data set:
#dataFilesDir = "../data/figuresData/"
#dataFiles = ["randNorm_1000pnts_2clusters_aspect_ratioex.csv"]

dataFilesDir = "../data/studyData/clusters/"
dataFiles = ["randNorm_10000pnts_3clusters_3.csv"]

optimizedplotFilesDir = "../cannyEdge/"

if not os.path.exists(optimizedplotFilesDir):
    os.makedirs(optimizedplotFilesDir)


# design = [marker size, marker opacity, image width, image aspect ratio]
design = [15.5, 67.5, 1000, 0.5]
#design = [15.5, 67.5, 1000, 1]

# Canny sigmas to try: 0.0, 1.0, ..., 20.0
cannySigmas = np.linspace(0, 20, 21)
#cannySigmas = [4]

SP = ScatterPlotter(design[0])

colorDeltas = colors.getDeltaEBetweenCategoricalColorBrewerColors()

for dataFileName in dataFiles:

    dataFilePath = dataFilesDir + dataFileName

    originalPlotFileName = optimizedplotFilesDir + dataFileName.replace(".csv","_original.png")

    datapoints, datapointsClusters, datapointsClusters_woOutliers, \
    outliers_foreachcluster, corr_perCluster = \
            utilities.getDataPointsWithClustersFromCSVFile(dataFilePath, hasHeading=True, shuffle=False)

    DataBoundingBox = utilities.getBBox(datapoints)
    #datapointsClusters = {0:datapoints}
    #datapointsClusters = clustering.clusterDataPoints(datapoints, maxK=maxClusters)  # use this to detect the number of clusters in the data points

    # Index pairs (i, j) with i < j, one row per pair of clusters
    TriangleIndices = np.stack(np.triu_indices(len(datapointsClusters), 1), axis=-1)

    # Compute covariance ellipse and its properties for each actual data points cluster
    clustersMeasures = {}
    actualCovEllipses = {}
    #actualCorrelations = {}
    actualSDySDxRatios = {}
    for d in range(0, len(datapointsClusters)):
        datapointsCluster = datapointsClusters[d]
        actual_covellipse, actual_minorOnMajorAxis, actual_angle = model.getDataCovarianceEllipse(datapointsCluster)
        clustersMeasures[d] = (actual_minorOnMajorAxis, actual_angle)
        actualCovEllipses[d] = actual_covellipse
        #actualCorrelations[d] = stats.pearsonr(datapointsCluster[:,0], datapointsCluster[:,1])[0]
        actualSDySDxRatios[d] = model.getSDySDxRatioOfPointCloud(datapointsCluster)

    # Compute the cluster overlap measure for each pair of clusters
    actualPairwiseClusterOverlapMeasures, actualCovEllipsesOverlapAreas = model.getPairwiseClusterOverlapMeasures(actualCovEllipses, datapointsClusters, TriangleIndices)

    # Compute rendering order of clusters
    datapoints, \
    datapointsClusters, \
    outliers_foreachcluster, \
    actualPairwiseClusterOverlapMeasures_relabelled, \
    clustersMeasures_relabelled, \
    actualCovEllipses_relabelled, \
    actualSDySDxRatios_relabelled = model.getRenderingOrderOfClusters(datapoints, \
                                                                      datapointsClusters, \
                                                                      outliers_foreachcluster, \
                                                                      actualPairwiseClusterOverlapMeasures, \
                                                                      clustersMeasures, \
                                                                      actualCovEllipses, \
                                                                      actualSDySDxRatios, \
                                                                      TriangleIndices)

    # Get distinguishable RGB colors for clusters
    # Each color is a tuple of the form (r,g,b) where each of r, g and b values is in [0,1]
    color_foreachcluster = colors.getRGBColorsForClusters(actualPairwiseClusterOverlapMeasures_relabelled, colorDeltas, TriangleIndices)

    PlotChart(design, originalPlotFileName, datapointsClusters, datapoints, color_foreachcluster, SP, DataBoundingBox, True)

    # Render every cluster on its own exactly once. The rendering arguments do
    # not depend on the Canny sigma, so the buffers are reused for all sigmas
    # (assumes SP.GeneratePlot gives the same image for the same arguments).
    clusterImageBuffers = []
    for d in range(0, len(datapointsClusters)):
        # Data points in the dth index cluster
        datapointsCluster = datapointsClusters[d]
        ImageBuffer = GeneratePlotImageFileFast(SP, datapointsCluster,
                                                None,
                                                DataBoundingBox,
                                                design[0],
                                                design[1],
                                                design[2],
                                                design[3],
                                                False)
        clusterImageBuffers.append(ImageBuffer)

    for sigma in cannySigmas:

        sigmaPlotFileName = originalPlotFileName.replace("_original.png","_s"+str(sigma)+".png")

        # Edge points and colour of every cluster for which edges were found.
        # Both lists are appended to together, so they stay aligned and
        # contiguous even when a cluster produces no edges at this sigma.
        PerceivedPoints = []
        PerceivedColors = []

        for d, ImageBuffer in enumerate(clusterImageBuffers):
            perceived, perceivedPoints = getPerceivedPointsCannyParam(ImageBuffer, cannysigma = sigma)

            if perceived:
                PerceivedPoints.append(perceivedPoints)
                PerceivedColors.append(color_foreachcluster[d])

        # Stack the per-cluster edge points in cluster order
        TotalPerceivedPoints = np.array([])
        if PerceivedPoints:
            TotalPerceivedPoints = np.concatenate(PerceivedPoints, axis = 0)

        if TotalPerceivedPoints.size > 0:
            PerceivedBoundingBox = utilities.getBBox(TotalPerceivedPoints)
            PlotChart(design, sigmaPlotFileName, PerceivedPoints, TotalPerceivedPoints, PerceivedColors, SP, PerceivedBoundingBox, True)
   