import numpy as np
from matplotlib import pyplot as plt


class PCenter:
    """A single cue's p-center estimate together with its reliability.

    Attributes:
        noise: standard deviation of the cue's noise; the integrator derives
            the cue's weight from it.
        value: the estimated p-center; values within 1e-4 of zero count as
            "no estimate" (see is_valid).
        weight: the cue's share in the integrated estimate.
    """

    def __init__(self, noise, value=0, weight=0):
        self.value = value
        self.weight = weight
        self.noise = noise

    def is_valid(self):
        """Return True when value is meaningfully non-zero (|value| > 1e-4)."""
        magnitude = abs(self.value)
        return magnitude > 1e-4

    @property
    def weighted(self):
        """This cue's contribution to the integrated estimate (weight * value)."""
        contribution = self.weight * self.value
        return contribution


class CueIntegrator:
    """Fuse position, force and vision cues into one p-center estimate.

    Each cue yields a candidate time ("pc"). A cue whose signal cannot be
    told apart from its own noise yields value 0 and is ignored. The
    remaining candidates are combined with normalized inverse-variance
    weights (weight proportional to noise**-2).
    """

    # Noise levels reported in [1] (see REFERENCES at the end of this file);
    # kept for reference only:
    #   NOISE_POS = 8e-7
    #   NOISE_FORCE = 1e-8

    # Noise standard deviations actually used ("some fitting values"). They
    # are both the simulated noise added to each cue and the cue's
    # reliability in the weighting.
    NOISE_POS = 0.02
    NOISE_FORCE = 0.01
    NOISE_VISION = 0.03

    FORCE_NORM = 1      # typical force during a button press
    POS_NORM = 0.06     # typical displacement during a button press

    # A vision p-center at or beyond this delay after the sync signal is
    # treated as "no activation" (same units as the vision time column;
    # presumably ms -- TODO confirm).
    VISION_MAX_PC = 950

    def __init__(self):
        # Inverse-variance weights for the case where all three cues are
        # valid. These attributes are not read by any method of this class;
        # cue_integrate() recomputes weights from the cues valid in each call.
        noise_sum = (self.NOISE_POS**-2 + self.NOISE_FORCE**-2
                     + self.NOISE_VISION**-2)
        self._w_pos = (self.NOISE_POS**-2) / noise_sum
        self._w_fsr = (self.NOISE_FORCE**-2) / noise_sum
        self._w_vision = (self.NOISE_VISION**-2) / noise_sum

    def _normalize(self, v):
        '''Return v divided by its Frobenius/L2 norm.

        v is returned unchanged when its norm is 0. Currently not used.
        '''
        norm = np.linalg.norm(v)
        if norm == 0:
            return v
        return v / norm

    def _get_weights(self, deviations):
        '''Return normalized inverse-variance weights.

        Computes w_i = s_i**-2 / sum_j s_j**-2, so the weights sum to 1.
        '''
        # Force a float dtype: numpy refuses to raise integer arrays to
        # negative integer powers.
        deviations = np.asarray(deviations, dtype=float)
        inv_var = deviations**-2
        return inv_var / np.sum(inv_var)

    def _set_weights(self, *pcenters):
        '''Assign inverse-variance weights, in place, to the valid pcenters.

        Invalid pcenters (value ~ 0) are skipped and keep their current
        weight.
        '''
        valid = [pc for pc in pcenters if pc.is_valid()]
        weights = self._get_weights([pc.noise for pc in valid])
        for pc, weight in zip(valid, weights):
            pc.weight = weight

    def _normalize_pos(self, v):
        '''Scale positions by the typical button-press displacement.'''
        return v / (self.POS_NORM)

    def _normalize_fsr(self, v):
        '''Scale forces by the typical button-press force.'''
        return v / self.FORCE_NORM

    def _get_vision_pc(self, signal):
        '''Return the vision p-center from a binarized vision signal.

        signal[i, 0] is read as the time and signal[i, 1] as a 0/1 state.
        Timing starts at the most recent falling edge (1 -> 0), i.e. after
        the Arduino gives its synchronization signal. The result is the time
        from there to the next rising edge (0 -> 1, the first activation).

        Returns 0 in two cases: no activation is found, or the activation
        comes VISION_MAX_PC or more time units after the start.
        '''
        start_time = None   # stays None until the sync signal has dropped
        pc_vision = 0
        for prev, cur in zip(signal[:-1], signal[1:]):
            if prev[1] == 1 and cur[1] == 0:
                # falling edge of synchronization signal
                start_time = cur[0]
            elif start_time is not None and prev[1] == 0 and cur[1] == 1:
                # rising edge of button activation
                pc_vision = cur[0] - start_time
                break
        if pc_vision >= self.VISION_MAX_PC:
            pc_vision = 0
        return pc_vision

    def _get_vision_pcenter(self, vision):
        '''Return the vision PCenter for one trial.

        Its value is 0 when vision is None or no activation is found.
        Otherwise it is the detected p-center plus simulated Gaussian noise
        with std NOISE_VISION.
        '''
        pc_vision = PCenter(self.NOISE_VISION)
        if vision is None:
            return pc_vision
        # NOTE(review): np.around also rounds the time column, not only the
        # 0/1 state -- confirm timestamps are integral or this is intended.
        pc_vision_perfect = self._get_vision_pc(np.around(vision))
        if abs(pc_vision_perfect) < 0.0001:
            pc_vision.value = 0.0
        else:
            pc_vision.value = pc_vision_perfect + np.random.normal(
                scale=self.NOISE_VISION)
        return pc_vision

    def _get_signal_pcenter(self, times, signal, noise):
        '''Return the PCenter of one continuous cue (position or force).

        Noise of std `noise` is added to `signal`. The value is the entry of
        `times` at the maximum of the noisy signal, or 0 when the noisy
        signal's std does not exceed the noise level by at least 10%.
        '''
        noisy = self._add_noise(signal, noise)
        pc = PCenter(noise)
        # ignore signal if it is buried in noise
        if np.std(noisy) < (noise + 0.1*noise):
            pc.value = 0
        else:
            pc.value = times[np.argmax(noisy)]
        return pc

    def _get_pcs(self, data, vision):
        '''Build the position, force and vision PCenters for one trial.

        Args:
            data: normalized trial array. Column 0 is used as the time axis,
                column 1 as position and column 2 as force.
            vision: vision signal (see _get_vision_pc), or None.

        Returns:
            (pc_pos, pc_force, pc_vision)
        '''
        # Order matters for reproducibility under a fixed seed: the random
        # draws happen for vision first, then position, then force.
        pc_vision = self._get_vision_pcenter(vision)
        pc_pos = self._get_signal_pcenter(data[:, 0], data[:, 1],
                                          self.NOISE_POS)
        pc_force = self._get_signal_pcenter(data[:, 0], data[:, 2],
                                            self.NOISE_FORCE)
        return pc_pos, pc_force, pc_vision

    def _add_noise(self, data, deviation):
        '''Return data plus zero-mean Gaussian noise with std `deviation`.'''
        noise = np.random.normal(scale=deviation, size=data.shape)
        return data+noise

    def _plot(self, data, pc_pos, pc_fsr, pc_vision, pc_obs):
        '''Plot position and force with a vertical line for each p-center.

        Zero (invalid) cue p-centers are not drawn. The noise added here is
        a fresh draw, not the one used to compute the estimates.
        '''
        plt.plot(data[:, 0], self._add_noise(data[:, 1], self.NOISE_POS),
                 label='position')
        plt.plot(data[:, 0], self._add_noise(data[:, 2], self.NOISE_FORCE),
                 label='force')
        for value, color, name in ((pc_fsr, 'y', 'pc_fsr'),
                                   (pc_pos, 'b', 'pc_pos'),
                                   (pc_vision, 'g', 'pc_vis')):
            if value:
                plt.axvline(x=value, linestyle='--', color=color,
                            label='%s %.0f' % (name, value))
        plt.axvline(x=pc_obs, color='r', label='pc_obs %.0f' % pc_obs)
        plt.legend()
        plt.show()

    def cue_integrate(self, data, vision=None, show=False):
        '''Return the integrated p-center estimate for one trial.

        Args:
            data: 2-D array. Column 0 is used as time, column 1 as position
                and column 2 as force. Position and force are scaled by
                POS_NORM and FORCE_NORM on a copy; data is not modified.
            vision: optional vision signal (see _get_vision_pc).
            show: if true, print the individual and integrated estimates and
                plot them.

        Returns:
            The inverse-variance weighted sum of the valid cues' values
            (0 if no cue is valid).
        '''
        # Float copy: np.copy would keep an integer dtype, and the normalized
        # columns written back into it would be silently truncated.
        ndata = np.array(data, dtype=float)
        ndata[:, 1] = self._normalize_pos(ndata[:, 1])
        ndata[:, 2] = self._normalize_fsr(ndata[:, 2])
        pc_pos, pc_force, pc_vision = self._get_pcs(ndata, vision)
        self._set_weights(pc_pos, pc_force, pc_vision)
        pc_obs = pc_pos.weighted + pc_force.weighted + pc_vision.weighted
        if show:
            print("pc_pos: %.2f, pc_fsr: %.2f, pc_vision: %.2f"
                  ", pc_obs: %.2f" % (pc_pos.value, pc_force.value,
                                      pc_vision.value, pc_obs))
            self._plot(ndata, pc_pos.value, pc_force.value, pc_vision.value,
                       pc_obs)
        return pc_obs

# REFERENCES
#   [1]   Oulasvirta, Antti, Sunjun Kim, and Byungjoo Lee. "Neuromechanics of a
#   Button Press." Proceedings of the 36th Annual ACM Conference on Human
#   Factors in Computing Systems (CHI '18). ACM Press, 2018.
#   DOI: https://doi.org/10.1145/3173574.3174082
