"""
Scatter plot creation and evaluation using matplotlib.

Authors: Gregorio Palmas and Tino Weinkauf, January 2016
"""

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import ImageMetrics
import utilities
import csv


class ScatterPlotter():
    """Creates scatter plots and evaluates their qualities.

    MaxMarkerSize: maximum MarkerSize to be tested in this particular run of
    the optimizer (e.g., if marker_sizes=[1,2,3,4,5] in the main script of the
    optimizer, then MaxMarkerSize=5 even if the design currently under
    consideration uses MarkerSize=1). FillPlot uses it to pad the axes limits.

    DPI: resolution that FillPlot applies to every figure it fills.

    Many aspects of matplotlib are given in the unit 'points'.
    A point is 1/72 inch. For a dpi=72, we then have 1 pixel = 1 point.
    A dpi setting of 72 is the default for monitors on Windows.
    """

    def __init__(self, MaxMarkerSize, DPI=72):
        # Maps (MarkerSize, MarkerOpacity, ImageWidth) to the tuple
        # (MeasuredPaint, NumPaintedPixels) of a single isolated marker.
        self.PixelCoverageCache = dict()
        self.MaxMarkerSize = MaxMarkerSize
        self.DPI = DPI

    def is_number(self, s):
        """Return True if float(s) succeeds, False otherwise."""
        try:
            float(s)
        except (TypeError, ValueError):
            # float() raises TypeError (not ValueError) for None, lists, ...
            return False
        return True

    def get_axes_size(self, axes, fig):
        """Return (width, height) of the given axes in pixels."""
        # Window extent -> inches -> pixels at the figure's current dpi.
        bbox = axes.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
        return bbox.width * fig.dpi, bbox.height * fig.dpi

    def fig2data(self, fig):
        """Render a matplotlib figure and return its pixels.

        fig: a matplotlib figure

        Returns a numpy uint8 array of shape (height, width, 4) with the
        channels in RGBA order.
        """
        # Draw the renderer so that the pixel buffer is up to date.
        fig.canvas.draw()

        # fig.bbox.bounds holds floats; array dimensions have to be integers.
        _, _, width, height = fig.bbox.bounds

        # np.frombuffer replaces the deprecated binary mode of np.fromstring.
        argb = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8)
        argb = argb.reshape(int(height), int(width), 4)

        # canvas.tostring_argb gives the pixmap in ARGB mode. Roll the ALPHA
        # channel to the end to have it in RGBA mode. np.roll returns a new,
        # writable array (the frombuffer view itself is read-only).
        return np.roll(argb, 3, axis=2)

    def GetSingleMarkerPixelCoverage(self, MarkerSize, MarkerOpacity, ImageWidth):
        """Measure how much a single, isolated marker paints.

        A plot with one point at the origin (aspect ratio 1.0, no bounding
        box) is rendered and measured with ImageMetrics.MeasurePaintAndPixels.
        Results are cached per (MarkerSize, MarkerOpacity, ImageWidth).

        Returns the tuple (MeasuredPaint, NumPaintedPixels).
        """
        Key = MarkerSize, MarkerOpacity, ImageWidth

        if Key not in self.PixelCoverageCache:
            # Not computed before: render a single point at the origin.
            Table = np.zeros((1, 2))
            fig, scatter = self.GeneratePlot(Table, [], MarkerSize, MarkerOpacity, ImageWidth, 1.0)
            try:
                scatter.axis('off')
                MeasuredPaint, _, NumPaintedPixels = \
                    ImageMetrics.MeasurePaintAndPixels(self.fig2data(fig))
            finally:
                # Always release the figure, even if measuring failed.
                plt.close(fig)
            self.PixelCoverageCache[Key] = MeasuredPaint, NumPaintedPixels

        return self.PixelCoverageCache[Key]

    def GeneratePlotImageFile(self, Table, outImageFilePath,
                              data_bbox, MarkerSize, MarkerOpacity,
                              ImageWidth, ImageAspectRatio, bSaveWithAxes, PlotColor = (0,0,0) ):
        """Render a scatter plot, evaluate it and optionally save it.

        Table: a ndarray with shape (n, 2)
        outImageFilePath: the file path where the generated scatterplot image
            file should be saved; nothing is saved if it is empty or None
        data_bbox: the bounding box of the original data, needed to render
            the clusters
        MarkerSize: the size of the markers
        MarkerOpacity: the opacity of the markers (divided by 255 in FillPlot)
        ImageWidth: the width of the scatter plot
        ImageAspectRatio: the aspect ratio of the scatter plot
        bSaveWithAxes: if true, the axes are switched on for the saved image;
            the evaluated image is always rendered without axes
        PlotColor: the color used for the scatterplot, black by default,
            otherwise supposed to be an array of n 3D tuples (R,G,B)

        Returns the result of ImageMetrics.GetImageQualityFactors with the
        analyzed RGBA buffer appended as last element.
        """
        SingleMarkerPixelCoverage = self.GetSingleMarkerPixelCoverage(MarkerSize, MarkerOpacity, ImageWidth)

        fig, scatter_plot = self.GeneratePlot(Table, data_bbox,
                                              MarkerSize, MarkerOpacity,
                                              ImageWidth, ImageAspectRatio, PlotColor)
        try:
            scatter_plot.axis('off')
            RGBABuffer = self.fig2data(fig)
            factors = ImageMetrics.GetImageQualityFactors(RGBABuffer, MarkerSize, MarkerOpacity,
                                                          Table.shape[0], SingleMarkerPixelCoverage)

            # Save with exactly the same dpi setting as we analyze the figure
            if outImageFilePath:
                utilities.ensureDirectory(outImageFilePath)
                if bSaveWithAxes:
                    scatter_plot.axis('on')
                fig.savefig(outImageFilePath, dpi=fig.dpi)
        finally:
            # Close the figure, also when evaluating or saving raised.
            plt.close(fig)

        # NOTE(review): assumes GetImageQualityFactors returns a tuple.
        return factors + (RGBABuffer,)

    def GeneratePlot(self,
                     Table,
                     data_bbox, MarkerSize, MarkerOpacity,
                     ImageWidth, ImageAspectRatio, PointColors = (0,0,0)):
        """Create a new figure without axes or grid and fill it via FillPlot.

        The figure background is set to fully transparent white.

        Returns (fig, scatter_plot). The caller owns the figure and has to
        close it with plt.close(fig).
        """
        fig, scatter_plot = plt.subplots()
        try:
            fig.tight_layout()
            fig.patch.set_facecolor((1, 1, 1, 0))
            scatter_plot.axis('off')
            scatter_plot.grid(False)

            self.FillPlot(fig, scatter_plot, Table, data_bbox, MarkerSize, MarkerOpacity,
                          ImageWidth, ImageAspectRatio, PointColors)
        except Exception:
            # Nobody else holds a reference yet; do not leak the figure.
            plt.close(fig)
            raise

        return fig, scatter_plot

    @staticmethod
    def _scatter_colors(PointColors):
        """Convert PointColors into a value that scatter(c=...) reads as colors.

        A single numeric RGB(A) sequence becomes a 2D array with one row.
        Passed as a flat sequence it would be ambiguous for matplotlib and
        would be colormapped whenever the plot has exactly 3 (or 4) points.
        Anything else is handed over as a list, one entry per point.
        """
        try:
            colors = np.asarray(PointColors)
        except ValueError:
            # Ragged input cannot be a single color; pass it on unchanged.
            return list(PointColors)

        if colors.ndim == 1 and colors.size in (3, 4) and colors.dtype.kind in "iuf":
            return colors.astype(float).reshape(1, -1)
        return list(PointColors)

    def FillPlot(self, fig, scatter_plot,
                 Table,
                 data_bbox, MarkerSize, MarkerOpacity,
                 ImageWidth, ImageAspectRatio, PointColors = (0,0,0)):
        """Draw the points of Table into scatter_plot and fix the axes limits.

        fig, scatter_plot: the figure and the axes to draw into
        Table: a ndarray with shape (n, 2)
        data_bbox: (xmin, xmax, ymin, ymax); ignored if it has less than
            4 entries, in which case the limits chosen by matplotlib are kept
        MarkerSize: marker size; its square is passed as scatter area 's'
        MarkerOpacity: opacity in the range 0..255
        ImageWidth: width of the figure in pixels
        ImageAspectRatio: figure height divided by figure width
        PointColors: a single (R,G,B) tuple or an array of n (R,G,B) tuples
        """
        # Many aspects of matplotlib are given in the unit 'points'.
        # A point is 1/72 inch. For a dpi=72, we then have 1 pixel = 1 point.
        # float() keeps the divisions below true divisions for integer input.
        dpi = float(self.DPI)
        fig.set_dpi(dpi)
        fig.set_size_inches(ImageWidth / dpi, (ImageWidth * ImageAspectRatio) / dpi, forward=True)

        ColTable = Table.T
        scatter_plot.scatter(ColTable[0], ColTable[1],
                             s=MarkerSize * MarkerSize,
                             marker="o",
                             c=self._scatter_colors(PointColors),
                             alpha=float(MarkerOpacity) / 255.0,
                             edgecolor='none')

        # Freeze the limits of the axes we were given (not whatever axes
        # happens to be current on the figure).
        scatter_plot.set_autoscale_on(False)

        if len(data_bbox) >= 4:
            scatter_plot.set_xlim([data_bbox[0], data_bbox[1]])
            scatter_plot.set_ylim([data_bbox[2], data_bbox[3]])

        # Add some space to the axes such that the largest marker will fit:
        # 1.5 times the largest marker size, converted from pixels to data
        # units, on every side.
        x_low, x_high = scatter_plot.get_xlim()
        y_low, y_high = scatter_plot.get_ylim()
        axes_width, axes_height = self.get_axes_size(scatter_plot, fig)

        x_margin = 1.5 * (x_high - x_low) * (self.MaxMarkerSize / axes_width)
        scatter_plot.set_xlim([x_low - x_margin, x_high + x_margin])

        y_margin = 1.5 * (y_high - y_low) * (self.MaxMarkerSize / axes_height)
        scatter_plot.set_ylim([y_low - y_margin, y_high + y_margin])





# Generate a scatterplot image with the given design for the given data
# Markers are drawn with the given color
# Axes can be hidden; Background can be transparent or white
def GeneratePlotWithDesignForData (dataFilePath, outputPlotFilePath, \
                                   max_marker_size, marker_size, marker_opacity, \
                                   image_width, image_aspect_ratio, \
                                   marker_rgb_color=(0,0,0), showAxes=False, transparentBackground=True):
    """Render the clusters of a CSV data file with one given plot design.

    dataFilePath: CSV file read via utilities.getDataPointsWithClustersFromCSVFile
        (with heading, without shuffling)
    outputPlotFilePath: file path the rendered image is saved to
    max_marker_size: largest marker size of the optimizer run; it determines
        the margin added around the data (see ScatterPlotter.FillPlot)
    marker_size: the size of the markers
    marker_opacity: the opacity of the markers in the range 0..255
    image_width: the width of the image
    image_aspect_ratio: the aspect ratio of the image
    marker_rgb_color: (R,G,B) color used for all markers
    showAxes: whether the axes are drawn
    transparentBackground: passed to savefig as 'transparent'
    """
    SP = ScatterPlotter(max_marker_size)

    # Only the points and their clusters are needed here; the remaining
    # three return values are unpacked but not used.
    datapoints, datapointsClusters, _woOutliers, _outliers, _corr = \
            utilities.getDataPointsWithClustersFromCSVFile(dataFilePath, hasHeading=True, shuffle=False)

    # All clusters share the bounding box of the complete data set.
    DataBoundingBox = utilities.getBBox(datapoints)

    color = np.array(marker_rgb_color)

    # Index-based access is kept on purpose: the container type returned by
    # the loader is not visible from here.
    # NOTE(review): every cluster is saved to the same outputPlotFilePath, so
    # the file ends up showing the last cluster only -- confirm this is intended.
    for d in range(len(datapointsClusters)):
        fig, scatter_plot = plt.subplots()
        try:
            fig.tight_layout()
            scatter_plot.axis('on' if showAxes else 'off')
            scatter_plot.grid(False)

            SP.FillPlot(fig, scatter_plot, datapointsClusters[d],
                        DataBoundingBox,
                        marker_size,
                        marker_opacity,
                        image_width,
                        image_aspect_ratio,
                        color)

            fig.savefig(outputPlotFilePath, dpi=fig.dpi, transparent=transparentBackground)
        finally:
            # pyplot keeps every figure alive until it is closed explicitly.
            plt.close(fig)
    

