"""
Image evaluation.

Author: Tino Weinkauf, December 2015
"""

from __future__ import division
import math 
import matplotlib.pyplot as plt
import numpy as np
import skimage
from PIL import Image
from skimage.measure import structural_similarity as ssim

def MeasurePaintAndPixels(RGBABuffer):
    """Measures the amount of paint in the given image and the number of painted pixels.
    Returns them as a tuple.

    Call this function like this:
    MeasuredPaint, MeasuredPaint2, NumPaintedPixels = MeasurePaintAndPixels(...)

    Arguments:
    RGBABuffer -- the scatter plot image without the axes as a numpy array;
                  the data is read as interleaved RGBA values, so every 4th
                  value of the flattened buffer is an alpha value in [0, 255]

    Output:
    MeasuredPaint    -- sum over all pixels of alpha / 255
    MeasuredPaint2   -- sum over all pixels of (alpha / 255)^2
    NumPaintedPixels -- number of pixels with a non-zero alpha
    """
    # Pick the alpha channel and convert it to floating point BEFORE any
    # arithmetic: image buffers are commonly uint8, and squaring uint8 values
    # with numpy wraps around modulo 256 (e.g. 255**2 -> 1).
    AlphaBuffer = np.asarray(RGBABuffer).flatten()[3::4].astype(np.float64)

    # Vectorized equivalent of a per-pixel loop that counts pixels with any
    # paint and accumulates alpha/255 and (alpha/255)^2.
    NumPaintedPixels = np.count_nonzero(AlphaBuffer)
    MeasuredPaint = AlphaBuffer.sum() / 255.0
    MeasuredPaint2 = np.square(AlphaBuffer).sum() / 65025.0  # 255^2

    return MeasuredPaint, MeasuredPaint2, NumPaintedPixels


def GetImageQualityFactors(RGBABuffer, MarkerSize, MarkerOpacity, NumOfDataPoints, SingleMarkerPixelCoverage,
                           DesiredMean = 0.5, DesiredContrast = 0.1):
    """Evaluates different quality metrics for the scatter plot.
    Returns them as a tuple, each value clamped to [0, 1].

    Call this function like this:
    OverlapFactor, OverplottingFactor, LightnessFactor, MeanFactor, ContrastFactor = GetImageQualityFactors(...)

    Arguments:
    RGBABuffer -- the scatter plot image without the axes as a numpy array
    MarkerSize -- radius of the marker (not used by the computation below)
    MarkerOpacity -- opacity of the marker (not used by the computation below)
    NumOfDataPoints -- number of data points that we are visualizing
    SingleMarkerPixelCoverage -- (paint, number of painted pixels) of a single marker,
                                 e.g. the first and third values returned by MeasurePaintAndPixels
    DesiredMean -- target average paint per painted pixel
    DesiredContrast -- target standard deviation of paint per painted pixel

    Output:
    OverlapFactor      -- 0 if no markers overlap, approaching 1 as they all overlap
    OverplottingFactor -- share of the expected paint lost, e.g. to saturated pixels
    LightnessFactor    -- 1 minus the mean paint; close to 1 means a whitish image
    MeanFactor         -- distance of the mean paint from DesiredMean
    ContrastFactor     -- distance of the paint standard deviation from DesiredContrast
    """
    single_paint, single_pixels = SingleMarkerPixelCoverage
    paint, paint_sq, painted = MeasurePaintAndPixels(RGBABuffer)

    # Pixels we would cover if no two markers shared a pixel.
    expected_pixels = NumOfDataPoints * single_pixels
    overlap = (1.0 - float(painted) / float(expected_pixels)
               if expected_pixels > 0 else 0)

    # Paint we would find if no pixel hit the opacity limit.
    expected_paint = NumOfDataPoints * single_paint
    overplotting = (1.0 - float(paint) / float(expected_paint)
                    if expected_paint != 0 else 0)

    # Average paint per painted pixel drives both lightness and mean factor.
    mean = float(paint) / float(painted) if painted > 0 else 0
    lightness = 1.0 - mean
    mean_factor = math.fabs(DesiredMean - mean)

    # Sample standard deviation from the running sums:
    # var = (sum(x^2) - n * mean^2) / (n - 1), see
    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
    # fabs guards against tiny negative values from round-off.
    if painted > 1:
        variance = math.fabs(paint_sq - painted * mean * mean) / (painted - 1)
        sigma = math.sqrt(variance)
    else:
        sigma = 0
    contrast = math.fabs(DesiredContrast - sigma)

    factors = (overlap, overplotting, lightness, mean_factor, contrast)
    return tuple(np.clip(factors, 0, 1))


def CompositeImages(Images, Alphas, SortedKeys, Schusterjunge = "Tino"):
    """Alpha-composite colour layers, in the given order, over a white canvas.

    Arguments:
    Images -- dict mapping a layer key to a float colour image whose last axis
              holds the channels
    Alphas -- dict mapping the same keys to alpha arrays shaped like the
              images without the channel axis (values presumably in [0, 1])
    SortedKeys -- keys in compositing order; later layers are drawn on top
    Schusterjunge -- a key to leave out of the composite; the default "Tino"
                     is a placeholder that normally matches no key

    Returns the composited float image.
    Raises ValueError if Images is empty.
    """
    if not Images:
        raise ValueError("CompositeImages needs at least one image.")

    # Opaque white canvas with the shape and dtype of the layers
    # (next(iter(...)) works on both Python 2 and Python 3 dicts).
    CompositeImage = np.full_like(next(iter(Images.values())), 1.0)
    for key in SortedKeys:
        # Compare by value, not identity: an equal but distinct key object
        # (e.g. a large int or a computed string) must still be skipped.
        if key == Schusterjunge:
            continue
        Alpha = Alphas[key][..., None]  # broadcast alpha over channels
        CompositeImage = (1.0 - Alpha) * CompositeImage + Alpha * Images[key]

    return CompositeImage


def CompositeAlphas(Alphas, SortedKeys):
    """Combine per-layer alpha arrays with the "over" operator, in the given order.

    Arguments:
    Alphas -- dict mapping a layer key to an alpha array (values presumably in [0, 1])
    SortedKeys -- keys in compositing order

    Returns the accumulated alpha, starting from fully transparent (0).
    Raises ValueError if Alphas is empty.
    """
    if not Alphas:
        raise ValueError("CompositeAlphas needs at least one alpha array.")

    # next(iter(...)) works on both Python 2 and Python 3 dicts.
    CompositeAlpha = np.full_like(next(iter(Alphas.values())), 0.0)
    for key in SortedKeys:
        Alpha = Alphas[key]
        CompositeAlpha = (1.0 - Alpha) * CompositeAlpha + Alpha

    return CompositeAlpha


def MeasureClusterPerceivability(ClusterImages, ClusterColors):
    """Checks how perceivable the clusters are within the final image.

    Every cluster image is converted to floating point and tinted with its
    cluster colour. The tinted layers are alpha-composited in sorted key order
    into the final image. For each cluster, a second image that leaves that
    cluster out is composited and compared with the final image. The
    comparison uses the alpha-masked SSIM of compare_ssim.

    Arguments:
    ClusterImages -- dict mapping a cluster key to an RGBA image; channels
                     0-2 (colour) and 3 (alpha) are divided by 255 here
                     (presumably uint8 images - confirm against the caller)
    ClusterColors -- dict mapping the same keys to a colour that broadcasts
                     against the RGB channels (presumably a float array of
                     3 values in [0, 1] - confirm)

    Output:
    FinalImage -- the composited float RGB image
    AlphaMask  -- the mask (from PrepareAlphaMask) the SSIM is averaged over
    the largest SSIM among all leave-one-cluster-out comparisons
    """
    Images = {}
    Alphas = {}
    for key in ClusterImages:
        rgba = ClusterImages[key]
        rgb = rgba[..., :3].astype(np.float64) / 255.
        Alphas[key] = rgba[..., 3].astype(np.float64) / 255.
        # Screen blend of cluster colour and image: 1 - (1 - a) * (1 - b)
        Images[key] = 1.0 - (1.0 - ClusterColors[key]) * (1.0 - rgb)

    # Fix the compositing order explicitly instead of relying on dict order.
    SortedKeys = sorted(Images)

    # Back-to-front alpha compositing of all clusters, plus the mask of where
    # anything was painted (SSIM is only evaluated there).
    FinalImage = CompositeImages(Images, Alphas, SortedKeys)
    AlphaMask = PrepareAlphaMask(CompositeAlphas(Alphas, SortedKeys))

    def _similarity_without(left_out):
        # Composite everything except one cluster and compare to the full image.
        reduced = CompositeImages(Images, Alphas, SortedKeys, left_out)
        return compare_ssim(reduced, FinalImage, AlphaMask,
                            multichannel = True,
                            gaussian_weights = True, sigma = 1.5, use_sample_covariance = False,
                            full = False)

    Similarities = np.array([_similarity_without(key) for key in SortedKeys])
    return (FinalImage, AlphaMask, Similarities.max())


def GetMaximumOutliersPerc(TotalOutliers, TotalOutliersColors, AllDataPoints, AllDataPointsColors, design, data_bbox, SP):
    """SSIM between a blank white image and an image showing only the outliers.

    The outliers are plotted alone and composited over white, then compared
    with an all-white image. The mask used is built from the alpha of a plot
    of all data points, so the SSIM is averaged over the area the full plot
    covers. The result is presumably the baseline that
    MeasureOutliersPerceivability rescales against (as MaxOutliersPerc) -
    confirm at the caller.

    Arguments:
    TotalOutliers, TotalOutliersColors -- outlier points and their colours
    AllDataPoints, AllDataPointsColors -- all points and their colours
    design -- its first four entries are forwarded to SP.GeneratePlot
    data_bbox -- forwarded to SP.GeneratePlot
    SP -- scatter plotter providing GeneratePlot(...) and fig2data(fig)

    Returns 1 when there are no outliers, otherwise the masked mean SSIM.
    """
    if len(TotalOutliers) == 0:
        return 1

    def _render(points, colors):
        # Plot without axes and grab the RGBA buffer; always release the
        # figure, even if rasterization fails, so pyplot does not leak it.
        fig, ax = SP.GeneratePlot(points, data_bbox, design[0], design[1], design[2], design[3], colors)
        try:
            ax.axis('off')
            return SP.fig2data(fig)
        finally:
            plt.close(fig)

    OutliersRGBA = _render(TotalOutliers, TotalOutliersColors)
    TotalRGBA = _render(AllDataPoints, AllDataPointsColors)

    # Composite the outliers over a white background.
    OutliersAlphas = OutliersRGBA[..., 3].astype(np.float64) / 255.
    OutliersImage = OutliersRGBA[..., :3].astype(np.float64) / 255.
    OutliersImage = (1.0 - OutliersAlphas[..., None]) + OutliersAlphas[..., None] * OutliersImage

    # Only the alpha of the full plot is needed: it defines the SSIM mask.
    TotalImageAlphas = TotalRGBA[..., 3].astype(np.float64) / 255.
    FinalAlphaMask = PrepareAlphaMask(TotalImageAlphas)

    WhiteFig = np.ones(OutliersImage.shape)

    StructuralSimilarityOutliersVsWhite = compare_ssim(WhiteFig, OutliersImage, FinalAlphaMask,
                        multichannel = True,
                        gaussian_weights = True, sigma = 1.5, use_sample_covariance = False,
                        full = False)

    return StructuralSimilarityOutliersVsWhite


def MeasureOutliersPerceivability(FinalImage, FinalAlphaMask, TotalPointsWithoutOutliers, TotalPointsWithoutOutliersColors,
                                  MaxOutliersPerc, design, data_bbox, scatter_plotter):
    """Checks how perceivable the outliers are within the final image.

    The points without outliers are plotted, composited over white and
    compared with FinalImage using the SSIM masked by FinalAlphaMask.
    FinalImage is presumably the image that includes the outliers (confirm
    at the caller). The raw SSIM is then rescaled against MaxOutliersPerc:

        max(0, (ssim - MaxOutliersPerc) / (1 - MaxOutliersPerc))

    Arguments:
    FinalImage -- float RGB image to compare against
    FinalAlphaMask -- mask of where the SSIM is averaged
    TotalPointsWithoutOutliers, TotalPointsWithoutOutliersColors -- points to plot and their colours
    MaxOutliersPerc -- baseline SSIM used for the rescaling above
    design -- its first four entries are forwarded to scatter_plotter.GeneratePlot
    data_bbox -- forwarded to scatter_plotter.GeneratePlot
    scatter_plotter -- object providing GeneratePlot(...) and fig2data(fig)

    Returns 1 when MaxOutliersPerc == 1, otherwise the rescaled score.
    """
    if MaxOutliersPerc == 1:
        return 1

    fig_no_out, scatter_no_out = scatter_plotter.GeneratePlot(TotalPointsWithoutOutliers, data_bbox, design[0], design[1],design[2],design[3], TotalPointsWithoutOutliersColors)
    # Always release the figure, even if rasterization fails.
    try:
        scatter_no_out.axis('off')
        ImageWithoutOutliers = scatter_plotter.fig2data(fig_no_out)
    finally:
        plt.close(fig_no_out)

    # Composite the plot over a white background.
    ImageWithoutOutliersAlphas = ImageWithoutOutliers[..., 3].astype(np.float64) / 255.
    ImageWithoutOutliers = ImageWithoutOutliers[..., :3].astype(np.float64) / 255.
    ImageWithoutOutliers = (ImageWithoutOutliersAlphas[..., None] * ImageWithoutOutliers) + (1 - ImageWithoutOutliersAlphas[..., None])

    MeanStructuralSimilarity = compare_ssim(ImageWithoutOutliers, FinalImage, FinalAlphaMask,
                        multichannel = True,
                        gaussian_weights = True, sigma = 1.5, use_sample_covariance = False,
                        full = False)

    # Rescale so that the baseline SSIM maps to 0 and a perfect match to 1.
    MeanStructuralSimilarity = max(0, (MeanStructuralSimilarity - MaxOutliersPerc)/(1 - MaxOutliersPerc))

    return MeanStructuralSimilarity


def PrepareAlphaMask(Alpha, sigma=1.5):
    """Build the mask of pixels over which the SSIM is averaged.

    Alpha is blurred with a Gaussian of the given sigma. Every pixel with a
    positive blurred value is set to 1, so for non-negative Alpha the result
    is a 0/1 mask. The blur grows the painted area by the Gaussian's support,
    presumably so that SSIM windows at marker edges are included. The default
    sigma=1.5 matches the sigma the callers pass to compare_ssim.

    Arguments:
    Alpha -- alpha array (values presumably in [0, 1])
    sigma -- standard deviation of the Gaussian blur

    Returns a new array; Alpha itself is not modified.
    """
    AlphaMask = gaussian_filter(Alpha, sigma=sigma)
    AlphaMask[AlphaMask > 0] = 1
    return AlphaMask


# Structural Similarity
# https://en.wikipedia.org/wiki/Structural_similarity
#
# SSIM is used for measuring the similarity between two images.
# Main use in broadcasting: Compare original with compressed image to assess image degradation due to compression.
# Perceptually motivated. Cited more than 10,000 times. Authors received an Emmy Award in 2015.
# SSIM is designed to improve on traditional methods such as peak signal-to-noise ratio (PSNR)
# and mean squared error (MSE), which have proven to be inconsistent with human visual perception.
#
# Main Paper:
# Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004).
# Image quality assessment: From error visibility to structural similarity.
# IEEE Transactions on Image Processing, 13, 600-612.
# https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf
#
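# The code below is adapted from scikit-image 0.12 (unreleased when this was written),
# which supports colour images among other things. It has been modified to take an
# AlphaMask and to average the SSIM only over the masked pixels.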
# compare_ssim
# http://scikit-image.org/docs/dev/api/skimage.measure.html#compare-ssim
#
# Other alternatives:
# https://github.com/helderc/src/blob/master/SSIM_Index.py
# http://isit.u-clermont1.fr/~anvacava/code.html
# https://ece.uwaterloo.ca/~z70wang/research/ssim/
#


from scipy.ndimage import uniform_filter, gaussian_filter
from skimage.util.dtype import dtype_range
from skimage.util.arraypad import crop


def compare_ssim(X, Y, AlphaMask, win_size=None, gradient=False,
                 dynamic_range=None, multichannel=False,
                 gaussian_weights=False, full=False, **kwargs):
    """Compute the masked mean structural similarity index between two images.
    Parameters
    ----------
    X, Y : ndarray
        Image.  Any dimensionality.
    AlphaMask : ndarray
        Mask with the shape of one channel of X (the shape of X itself when
        `multichannel` is False). The mean SSIM is taken only over pixels
        where the mask is non-zero; it is expected to hold 0/1 values (as
        produced by PrepareAlphaMask). If the mask is empty after cropping
        the border, the plain (unmasked) mean is used instead.
    win_size : int or None
        The side-length of the sliding window used in comparison.  Must be an
        odd value.  If `gaussian_weights` is True, this is ignored and the
        window size will depend on `sigma`.
    gradient : bool
        If True, also return the gradient.
    dynamic_range : float
        The dynamic range of the input image (distance between minimum and
        maximum possible values).  By default, this is estimated from the image
        data-type.
    multichannel : bool
        If True, treat the last dimension of the array as channels. Similarity
        calculations are done independently for each channel then averaged.
    gaussian_weights : bool
        If True, each patch has its mean and variance spatially weighted by a
        normalized Gaussian kernel of width sigma=1.5.
    full : bool
        If True, also return the full (unmasked, uncropped) structural
        similarity image.
    Other Parameters
    ----------------
    use_sample_covariance : bool
        if True, normalize covariances by N-1 rather than, N where N is the
        number of pixels within the sliding window.
    K1 : float
        algorithm parameter, K1 (small constant, see [1]_)
    K2 : float
        algorithm parameter, K2 (small constant, see [1]_)
    sigma : float
        sigma for the Gaussian when `gaussian_weights` is True.
    Returns
    -------
    mssim : float or ndarray
        The mean structural similarity over the masked part of the image.
    grad : ndarray
        The gradient of the structural similarity index between X and Y [2]_.
        This is only returned if `gradient` is set to True. It does not take
        AlphaMask into account.
    S : ndarray
        The full SSIM image.  This is only returned if `full` is set to True.
    Notes
    -----
    To match the implementation of Wang et. al. [1]_, set `gaussian_weights`
    to True, `sigma` to 1.5, and `use_sample_covariance` to False.
    References
    ----------
    .. [1] Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P.
       (2004). Image quality assessment: From error visibility to
       structural similarity. IEEE Transactions on Image Processing,
       13, 600-612.
       https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf
    .. [2] Avanaki, A. N. (2009). Exact global histogram specification
       optimized for structural similarity. Optical Review, 16, 613-621.
       http://arxiv.org/abs/0901.0065
    """
    if not X.dtype == Y.dtype:
        raise ValueError('Input images must have the same dtype.')

    if not X.shape == Y.shape:
        raise ValueError('Input images must have the same dimensions.')

    if multichannel:
        # loop over channels; the same 2D AlphaMask is used for every channel
        args = dict(win_size=win_size,
                    gradient=gradient,
                    dynamic_range=dynamic_range,
                    multichannel=False,
                    gaussian_weights=gaussian_weights,
                    full=full)
        args.update(kwargs)
        nch = X.shape[-1]
        mssim = np.empty(nch)
        if gradient:
            G = np.empty(X.shape)
        if full:
            S = np.empty(X.shape)
        for ch in range(nch):
            ch_result = structural_similarity(X[..., ch], Y[..., ch], AlphaMask, **args)
            if gradient and full:
                mssim[..., ch], G[..., ch], S[..., ch] = ch_result
            elif gradient:
                mssim[..., ch], G[..., ch] = ch_result
            elif full:
                mssim[..., ch], S[..., ch] = ch_result
            else:
                mssim[..., ch] = ch_result
        mssim = mssim.mean()
        if gradient and full:
            return mssim, G, S
        elif gradient:
            return mssim, G
        elif full:
            return mssim, S
        else:
            return mssim

    K1 = kwargs.pop('K1', 0.01)
    K2 = kwargs.pop('K2', 0.03)
    sigma = kwargs.pop('sigma', 1.5)
    if K1 < 0:
        raise ValueError("K1 must be positive")
    if K2 < 0:
        raise ValueError("K2 must be positive")
    if sigma < 0:
        raise ValueError("sigma must be positive")
    use_sample_covariance = kwargs.pop('use_sample_covariance', True)

    if win_size is None:
        if gaussian_weights:
            win_size = 11  # 11 to match Wang et. al. 2004
        else:
            win_size = 7   # backwards compatibility

    if np.any((np.asarray(X.shape) - win_size) < 0):
        raise ValueError("win_size exceeds image extent")

    if not (win_size % 2 == 1):
        raise ValueError('Window size must be odd.')

    if dynamic_range is None:
        dmin, dmax = dtype_range[X.dtype.type]
        dynamic_range = dmax - dmin

    ndim = X.ndim

    if gaussian_weights:
        # sigma = 1.5 to approximately match filter in Wang et. al. 2004
        # this ends up giving a 13-tap rather than 11-tap Gaussian
        filter_func = gaussian_filter
        filter_args = {'sigma': sigma}

    else:
        filter_func = uniform_filter
        filter_args = {'size': win_size}

    # ndimage filters need floating point data
    X = X.astype(np.float64)
    Y = Y.astype(np.float64)

    NP = win_size ** ndim

    # filter has already normalized by NP
    if use_sample_covariance:
        cov_norm = NP / (NP - 1)  # sample covariance
    else:
        cov_norm = 1.0  # population covariance to match Wang et. al. 2004

    # compute (weighted) means
    ux = filter_func(X, **filter_args)
    uy = filter_func(Y, **filter_args)

    # compute (weighted) variances and covariances
    uxx = filter_func(X * X, **filter_args)
    uyy = filter_func(Y * Y, **filter_args)
    uxy = filter_func(X * Y, **filter_args)
    vx = cov_norm * (uxx - ux * ux)
    vy = cov_norm * (uyy - uy * uy)
    vxy = cov_norm * (uxy - ux * uy)

    R = dynamic_range
    C1 = (K1 * R) ** 2
    C2 = (K2 * R) ** 2

    A1, A2, B1, B2 = ((2 * ux * uy + C1,
                       2 * vxy + C2,
                       ux ** 2 + uy ** 2 + C1,
                       vx + vy + C2))
    D = B1 * B2
    S = (A1 * A2) / D

    # to avoid edge effects will ignore filter radius strip around edges
    pad = (win_size - 1) // 2
    SCropped = crop(S, pad)
    ACropped = crop(AlphaMask, pad)

    # Average only over masked (painted) pixels: our images are not natural
    # images, and the plain mean would be dominated by the white background.
    NumMasked = np.count_nonzero(ACropped)
    if NumMasked > 0:
        mssim = (SCropped * ACropped).sum() / NumMasked
    else:
        # Empty mask: use the plain mean instead of returning 0/0 = nan.
        mssim = SCropped.mean()

    if gradient:
        # The following is Eqs. 7-8 of Avanaki 2009.
        grad = filter_func(A1 / D, **filter_args) * X
        grad += filter_func(-S / B2, **filter_args) * Y
        grad += filter_func((ux * (A2 - A1) - uy * (B2 - B1) * S) / D,
                            **filter_args)
        grad *= (2 / X.size)

        if full:
            return mssim, grad, S
        else:
            return mssim, grad
    else:
        if full:
            return mssim, S
        else:
            return mssim


def structural_similarity(X, Y, AlphaMask, win_size=None, gradient=False,
                          dynamic_range=None, multichannel=False,
                          gaussian_weights=False, full=False, **kwargs):
    """Thin wrapper that forwards every argument unchanged to compare_ssim.

    compare_ssim calls this once per channel in its multichannel branch.
    """
    options = dict(win_size=win_size, gradient=gradient,
                   dynamic_range=dynamic_range, multichannel=multichannel,
                   gaussian_weights=gaussian_weights, full=full)
    options.update(kwargs)
    return compare_ssim(X, Y, AlphaMask, **options)
