import time
import csv
import multiprocessing as mp

import numpy as np
from GPyOpt.methods import BayesianOptimization
from serial.serialutil import SerialTimeoutException
from serial.serialutil import SerialException

import robot_finger
import signal_utility
import cue_integrator
import vision


class Controller:
    """Tune robot-finger press signals with Bayesian optimization.

    Each candidate parameter set proposed by the GPyOpt optimizer is turned
    into a signal (``signal_utility.make_signal``), sent to the robot
    (``robot.send_signal``) and scored by ``_get_objective_value``.  A
    ``vision.Vision`` worker is driven through two multiprocessing queues:
    'start' is put on ``input_q`` before every trial, one result is read
    from ``output_q`` after it, and 'end' is put on ``input_q`` by
    ``close()``.
    """

    # error weights
    PCE_WEIGHT = 1.5
    PRESSED_WEIGHT = 1.0
    FORCE_WEIGHT = 2.5
    OVERSHOOT_WEIGHT = 0.5
    DURATION_WEIGHT = 0.7

    TOUCHED_WEIGHT = 0.0  # since finger start on top of the button
    JERK_WEIGHT = 0.0     # appears to only make movements slower

    # normalisation bounds used when scaling raw measurements
    MAX_FORCE = 25
    MAX_IMPULSE = 8500
    MIN_IMPULSE = 900
    MAX_JERK = 1e-11
    MIN_JERK = 1e-14

    # passed to signal_utility.make_signal together with the tuned offset
    DEFAULT_OFFSET = 50

    def __init__(self, time_step, antagonist=False, trails=1, robot=None,
                 n_init=20, init_x=None):
        """Start the vision worker, connect the robot, build the optimizer.

        Args:
            time_step: signal time step in ms.
            antagonist: when True the antagonist amplitude and duration are
                added to the search domain; otherwise both are fixed at 0.
            trails: number of trials run for each parameter set.
            robot: an already-connected robot object; when None a new
                ``robot_finger.RobotFinger`` is created.
            n_init: number of points in the optimizer's initial design.
            init_x: initial inputs handed to the optimizer as ``X``.
        """
        np.set_printoptions(precision=5, suppress=True)
        self.antagonist = antagonist

        # vision related init settings
        self.input_q = mp.Queue()
        self.output_q = mp.Queue()
        self.vision = vision.Vision(self.input_q, self.output_q)
        self.vision.start()
        time.sleep(5)

        # other settings
        self.time_step = time_step  # ms
        self.trails = trails
        # first row is the header of the per-trial log written by
        # save_results()
        self.accuracies = [['iteration', 'trail', 'pc_exp', 'pc_obs',
                            'pc_true', 'diff_exp', 'diff_true',
                            'objective_value']]
        self.iteration = 1          # this is used for logging information
        self.total_duration = 1000  # ms
        self.integrator = cue_integrator.CueIntegrator()
        self.results = None         # filled in by get_result()
        if robot is None:
            print("connecting...")
            self.robot = robot_finger.RobotFinger()
            print("connected")
        else:
            self.robot = robot

        offset_range = tuple(np.arange(0, 301))
        duration_range = tuple(np.arange(0, 500))
        pc_range = tuple(np.arange(0, 1001))
        # the order of the entries fixes the column order of the arrays
        # that the optimizer passes to _f()
        domain = [{'name': 'offset', 'type': 'discrete', 'domain':
                   offset_range},
                  {'name': 'amplitude_ago', 'type': 'continuous', 'domain':
                   (0, 1)},
                  {'name': 'duration_ago', 'type': 'discrete', 'domain':
                   duration_range},
                  {'name': 'pc_exp', 'type': 'discrete', 'domain':
                   pc_range}]
        if antagonist:
            domain.append({'name': 'amplitude_ant', 'type': 'continuous',
                           'domain': (0, 1)})
            domain.append({'name': 'duration_ant', 'type': 'discrete',
                           'domain': duration_range})

        self.optimizer = BayesianOptimization(f=self._f,
                                              domain=domain,
                                              X=init_x,
                                              model_type='GP',
                                              maximize=True,
                                              initial_design_numdata=n_init,
                                              initial_design_type='latin',
                                              # acquisition_type='EI')
                                              acquisition_type='LCB')
        print("initialization complete\n")

    def _f(self, data, get_data=False, accurate=False):
        '''
        Objective function: run ``self.trails`` trials with one parameter
        set and score them.

        format of arg data (2-D array holding a single row):
            0 - offset between two signals
            1 - amplitude of agonist
            2 - duration of agonist
            3 - expected perceptual center
            4 - amplitude of antagonist
            5 - duration of antagonist

        get_data: when True, extra information is printed and the method
            returns ``(values, outputs)`` (one entry per trial) instead of
            the mean objective value.
        accurate: not used by this method.

        Returns the mean objective value over all trials, or the tuple
        described above when ``get_data`` is True.
        '''
        # .item() pulls the single scalar out of each one-element column;
        # calling int()/float() directly on such an array is deprecated in
        # recent NumPy releases.
        offset = int(data[:, 0].item())
        amplitude_ago = float(data[:, 1].item())
        duration_ago = int(data[:, 2].item())
        pc_exp = int(data[:, 3].item())

        if self.antagonist:
            amplitude_ant = float(data[:, 4].item())
            duration_ant = int(data[:, 5].item())
        else:
            amplitude_ant = 0.0
            duration_ant = 0

        print('Offset: %d, d_ago: %d, a_ago: %.2f, d_ant: %d,'
              ' a_ant: %.2f, pc_exp: %d' % (
            offset, duration_ago, amplitude_ago, duration_ant, amplitude_ant,
            pc_exp))

        signal = signal_utility.make_signal(
                self.time_step, self.total_duration, self.DEFAULT_OFFSET,
                offset, amplitude_ant, duration_ant, amplitude_ago,
                duration_ago, get_data)
        values = []
        outputs = []

        # recalibration
        # (uncomment the code below if you are training for long time and would
        # like to calibrate in between training session)
        # if self.iteration % 40 == 0:
        #     print("\nRobot entering setup mode for calibration")
        #     print(self.robot.setup())
        #     input('\nDone?')
        #     print()

        for i in range(self.trails):
            # keep retrying the same trial until the robot answers
            # NOTE(review): 'start' is queued again on every retry; confirm
            # the vision worker tolerates repeated 'start' messages.
            while True:
                try:
                    self.input_q.put('start')
                    time.sleep(0.2)  # this is system specific delay to match
                                     # camera process with fingers motion
                    output, press_n, impulse = self.robot.send_signal(
                            self.time_step, signal)
                    break
                except (robot_finger.RobotError, SerialTimeoutException,
                        SerialException) as e:
                    print(e)
                    print("reboot please, then press enter")
                    input()
                    self.robot = robot_finger.RobotFinger()

            # poll until the vision worker has delivered its result
            while self.output_q.empty():
                # print("waiting for vision...")
                time.sleep(0.3)
            vision_output = self.output_q.get()

            pc_obs = self.integrator.cue_integrate(
                    output, vision_output, show=get_data)

            if get_data:
                print("Button was pressed at %d ms" % press_n)
                print("Impulse: %.2f" % impulse)

            pressed = press_n > 0
            touched = self._was_touched(output)
            # normalize, scale impulse
            impulse = (impulse - self.MIN_IMPULSE) / (
                    self.MAX_IMPULSE - self.MIN_IMPULSE)
            # normalize, scale and clip jerk
            jerk_ms = min(1, (Controller._get_jerk(output) - self.MIN_JERK) / (
                    self.MAX_JERK - self.MIN_JERK))
            overshoot = self._get_overshoot(output)
            duration = self._get_duration(duration_ago, duration_ant)

            value = self._get_objective_value(
                    pc_exp, pc_obs, pressed, touched, impulse, jerk_ms,
                    overshoot, duration)
            self.accuracies.append([
                self.iteration, i, pc_exp, pc_obs, press_n, pc_exp-pc_obs,
                pc_exp-press_n, value])
            values.append(value)
            outputs.append(output)
        self.iteration += 1

        # return everything if its requested
        if get_data:
            return values, outputs

        # return average value over all trials with these parameters
        return sum(values)/len(values)

    @staticmethod
    def _get_jerk(output):
        """Return the mean squared jerk of one recorded trial.

        Column 0 of ``output`` is used as time and column 1 as position.
        Velocity, acceleration and jerk are successive backward finite
        differences; the first sample of each derivative is left at zero,
        so the first differences are taken against that zero.
        """
        samples = output.shape[0]
        dt = np.diff(output[:, 0])

        velocity = np.zeros(samples)
        velocity[1:] = np.diff(output[:, 1]) / dt

        acceleration = np.zeros(samples)
        acceleration[1:] = np.diff(velocity) / dt

        jerk = np.zeros(samples)
        jerk[1:] = np.diff(acceleration) / dt

        return np.mean(jerk**2)

    def _get_objective_value(self, pc_exp, pc_obs, pressed, touched,
                             peak_force, jerk, overshoot, duration):
        """Combine the weighted penalty terms into one objective value.

        Args:
            pc_exp: expected perceptual center (ms).
            pc_obs: observed perceptual center (ms).
            pressed: whether the button was pressed.
            touched: whether the button was touched.
            peak_force: force term; ``_f`` passes the normalized impulse.
            jerk: normalized jerk.
            overshoot: value from ``_get_overshoot`` (zero or negative).
            duration: value from ``_get_duration`` (zero or negative).

        Returns:
            The sum of all penalty terms; the optimizer maximizes it.
        """
        # Optimizer will maximize this value
        # below is a list of penalties that will be summed up

        # pc difference divided by 1000 to convert from ms to s
        e_pcs_dif = - self.PCE_WEIGHT * abs(pc_exp-pc_obs) / 1000

        e_press = - self.PRESSED_WEIGHT * float(not pressed)
        e_touch = - self.TOUCHED_WEIGHT * float(not touched)
        e_force = - self.FORCE_WEIGHT * peak_force
        e_jerk = - self.JERK_WEIGHT * jerk
        # overshoot and duration already arrive as non-positive numbers, so
        # they are weighted without an extra minus sign
        e_overshoot = self.OVERSHOOT_WEIGHT * overshoot
        e_duration = self.DURATION_WEIGHT * duration
        # the duration term is part of the sum; a term is switched off by
        # setting its weight to 0, as done for touch and jerk
        value = (e_pcs_dif + e_press + e_touch + e_force + e_jerk
                 + e_overshoot + e_duration)
        print("* pce: {:.3g}, press: {:.3g}, touch {:.3g}, force {:.3g},\n"
              "* jerk {:.3g}, overshoot {:.3g}, duration {:.3g},"
              " [SUM {:.3g}]\n".format(
                  e_pcs_dif, e_press, e_touch, e_force, e_jerk, e_overshoot,
                  e_duration, value))
        return value

    def _was_pressed(self, data):
        """Return True when the maximum of column 3 exceeds 0.1."""
        return np.max(data[:, 3]) > 0.1

    def _was_touched(self, data):
        """Return True when column 2 varies by more than 0.02."""
        return np.max(data[:, 2])-np.min(data[:, 2]) > 0.02

    def _get_peak_force(self, data):
        """Return the largest |column 4 - column 5|, scaled by MAX_FORCE."""
        return np.max(abs(data[:, 4]-data[:, 5]))/self.MAX_FORCE

    def _get_overshoot(self, data):
        """Return the sum of the displacement column clipped at -0.001.

        The result is always zero or negative.
        """
        # implemented by summing all the points of displacement (column 1)
        # that lie beyond the -0.001 threshold
        # NOTE(review): the original comment described the threshold as
        # "10 mm" above the top of the keycap; confirm the unit of column 1.
        # NOTE(review): samples that do not pass the threshold are clipped to
        # -0.001 and therefore still add -0.001 each; confirm this constant
        # offset is intended.
        clipped = np.clip(data[:, 1], -np.inf, -0.001)
        overshoot = np.sum(clipped)
        return overshoot

    def _get_duration(self, duration_ago, duration_ant):
        """Return the negated total signal duration, scaled by 1/600."""
        return - (duration_ago + duration_ant) / 600

    def run(self, n=10, trails=3, max_time=np.inf):
        """Run the optimizer.

        Args:
            n: maximum number of optimizer iterations.
            trails: trials per parameter set from now on (replaces the
                value given to the constructor).
            max_time: time budget passed on to the optimizer.
        """
        # print("params:\n", n, trails, max_time)
        self.trails = trails
        self.optimizer.run_optimization(max_iter=n, max_time=max_time,
                                        eps=-1, verbosity=True)

    def save_results(self, path, name):
        """Write the optimizer's files and the per-trial log.

        ``path`` and ``name`` are concatenated as plain strings, so ``path``
        has to end with a path separator.  Four files are written, with the
        suffixes ``_e.txt``, ``_m.txt``, ``_r.txt`` (optimizer evaluations,
        models and report) and ``_a.txt`` (``self.accuracies`` as CSV).
        """
        self.optimizer.save_evaluations(path+name+'_e.txt')
        self.optimizer.save_models(path+name+'_m.txt')
        self.optimizer.save_report(path+name+'_r.txt')
        # newline='' lets the csv module control line endings; without it
        # every row is followed by a blank line on Windows
        with open(path + name+'_a.txt', 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerows(self.accuracies)

    def get_result(self):
        """Store and return the best parameters together with the trial log.

        Returns:
            ``(results, accuracies)`` where ``results`` is the row of
            ``optimizer.X`` at the position of the smallest ``optimizer.Y``.
        """
        # NOTE(review): argmin is only the best point if the optimizer
        # stores the negated objective when maximize=True; confirm against
        # the GPyOpt version in use.
        self.results = self.optimizer.X[np.argmin(self.optimizer.Y)]
        return self.results, self.accuracies

    def get_accuracies(self):
        """Return the per-trial log (header row first)."""
        return self.accuracies

    def run_result(self):
        """Replay the best parameter set on the robot.

        Uses the result stored by ``get_result()``; when none is stored yet
        it is computed first.

        Returns:
            ``(evaluations, outputs)`` as returned by
            ``_f(..., get_data=True)``.
        """
        if getattr(self, 'results', None) is None:
            self.get_result()
        data = np.reshape(self.results, (1, -1))
        evaluations, outputs = self._f(data, get_data=True)
        return evaluations, outputs

    def close(self):
        """Close the robot connection and send 'end' to the vision worker."""
        self.robot.close()
        self.input_q.put('end')
