import numpy as np
from matplotlib import pyplot as plt
from scipy import integrate

import utils.utils as utils
from dynamics.cylinder_rob_push_dynamics import cylinder_pushing_dynamics
from utils.integrator import my_RK4

class MPPI():
    """
    Sampling-based MPPI controller for the robot/object pushing system.

    Each call to get_action() samples N noisy control sequences around the
    nominal sequence self.a, rolls every sequence out through
    cylinder_pushing_dynamics with my_RK4, scores the rollouts with the stage
    cost l() and the terminal cost m(), and moves the nominal sequence
    towards the cheap rollouts.

    Typical use: init_matrices(x0) once, then per control step
    update_real_state(x), get_action(), init_next_action().
    """

    def __init__(self, horizon, N):
        """
        Create a controller with fixed tuning; no state is set yet.

        horizon: how many timesteps in the future each rollout covers
        N: number of rollouts sampled per call to get_action()

        init_matrices() has to be called before get_action().
        """
        # self.lam, sig_omega_acc, sig_acc can be tuned to get more or less noise for the action sampling
        # horizon and N can also be tuned
        self.lam = 0.5              # temperature of the exponential weighting
        self.sig_omega_acc = 0.5    # std of the noise on the second control channel
        self.sig_acc = 0.1          # std of the noise on the first control channel

        self.horizon = horizon
        self.dt = 0.1  # timestep

        # Set by init_matrices().
        self.a = None          # nominal control sequence, shape (2, horizon)
        self.a0 = None         # control appended when the horizon is shifted
        self.cur_state = None  # latest state of the system

        # Bookkeeping that is not read anywhere inside this class.
        self.fin_path = None
        self.fin_control = []
        self.fin_time = [0]

        # Rollouts of the most recent get_action() call; None until then.
        self.mpc_path = None

        self.N = N
        self.beta = 0.7  # smoothing parameter, can be tuned

        self.goal = np.array([1.0, 1.0])

        self.max_acc = 1
        self.max_omega_acc = 1

        self.t_cur = 0

        # Last action returned by get_action(); seeds the smoothing filter.
        self.past_action = np.array([0.0, 0.0])

        self.rob_size = 0.15          # m (radius of the cylinder)
        self.obj_size = [0.5, 0.5]    # m
        self.obj_mass = 3             # kg
        # moment of inertia (rotation center is the center of the object) (for rectangle)
        self.obj_I_or = 1/12 * self.obj_mass * \
                        (self.obj_size[0]**2 + self.obj_size[1]**2)
        self.mu_g = 0.3                     # friction coefficient of the ground-object contact surface
        self.grav_acc = 9.8                 # gravity acceleration

    def init_matrices(self, x0):
        """
        Reset the nominal control sequence to zeros and store the start state.

        x0: sequence of state values; its first (at most) 12 entries are used,
            laid out as
            [x_r, y_r, theta_r, vel_r, omega_r, x_o, y_o, theta_o, d_x_o, d_y_o, omega_o]
        """
        # NOTE(review): the layout above names 11 entries while the slice keeps
        # up to 12 - confirm the real state dimension against the dynamics.
        # Copying decouples the controller from the caller's array.
        initial_state = np.array(x0[0:12], dtype=float)

        self.a = np.zeros((2, self.horizon))  # 2 x horizon
        # Must NOT be a view of self.a: get_action() updates self.a in place
        # and would otherwise overwrite the padding action as well.
        self.a0 = np.zeros(2)
        self.cur_state = initial_state

        # final path taken by the robot; holds its own copy of the start state
        self.fin_path = [initial_state.copy()]

    def update_real_state(self, x):
        """
        Replace the current state with a new measurement.

        x: state of the system; an ndarray is stored as-is (no copy), any
           other sequence is converted to an ndarray because get_action()
           reads its shape.
        """
        self.cur_state = np.asarray(x)

    @staticmethod
    def _clamped_distance_cost(dx, dy, weight, min_dist=0.2):
        """
        Return weight * (dx**2 + dy**2) / max(distance, min_dist), elementwise.

        Clamping the distance from below keeps the division finite when the
        two points coincide.
        """
        dist = np.maximum(np.sqrt(dx ** 2 + dy ** 2), min_dist)
        return weight * dx ** 2 / dist + weight * dy ** 2 / dist

    def l(self, x, a, eps):
        """
        stage cost function
        All inputs should correspond to the same time stamp
        x: the states of the rollouts, one row per rollout (first N rows are used)
        a: the control taken (2 values)
        eps: the sampled controls; accepted for interface compatibility, not used

        returns: array of shape (N, 1) with one cost per rollout
        """
        states = np.asarray(x)[:self.N]
        a = np.asarray(a).reshape(1, 2)

        # Keep the robot (columns 0, 1) close to the object (columns 5, 6).
        # NOTE(review): there is no object-to-goal term in the stage cost;
        # only the terminal cost m() pulls the object towards the goal.
        # Confirm that this is intended.
        cost_pos_dist = self._clamped_distance_cost(
            states[:, 0] - states[:, 5], states[:, 1] - states[:, 6], 10)

        # Penalise control effort; identical for every rollout.
        cost_husky_input = 10 * (a[0][0]) ** 2 + 10 * (a[0][1]) ** 2

        # Penalise the velocity entries (columns 3, 4).
        cost_husky_vel = 1 * states[:, 3] ** 2 + 1 * states[:, 4] ** 2

        output = np.zeros([self.N, 1])
        output[:, 0] = cost_husky_input + cost_pos_dist + cost_husky_vel
        return output

    def m(self, x):
        """
        terminal cost function
        x: the states at the end of the horizon, one row per rollout
           (first N rows are used)

        returns: array of shape (N, 1) with one cost per rollout
        """
        states = np.asarray(x)[:self.N]

        # Distance between the goal and the object position (columns 5, 6).
        output = np.zeros([self.N, 1])
        output[:, 0] = self._clamped_distance_cost(
            self.goal[0] - states[:, 5], self.goal[1] - states[:, 6], 200)
        return output

    def set_goal(self, x, y):
        """
        Function used to update the goal to drive towards
        """
        self.goal = np.array([x, y])

    def plot_paths(self):
        """
        Plot columns 0 and 1 of every rollout of the last get_action() call
        into figure 0 and save it as "fig.png".

        raises: RuntimeError if get_action() has not produced rollouts yet
        """
        paths = getattr(self, "mpc_path", None)
        if paths is None:
            raise RuntimeError("no rollouts to plot; call get_action() first")

        # Figure 0 is reused and not cleared, so repeated calls overlay lines.
        plt.figure(0)
        for rollout in range(paths.shape[1]):
            plt.plot(paths[:, rollout, 0], paths[:, rollout, 1], color="blue")
        plt.savefig("fig.png")

    def _sample_action_sequences(self):
        """
        Draw N noisy, smoothed and clipped control sequences.

        returns: array of shape (horizon, N, 2)
        """
        noise = np.zeros((self.horizon, self.N, 2))
        noise[:, :, 0] = np.random.normal(0, self.sig_acc, size=(self.horizon, self.N))
        noise[:, :, 1] = np.random.normal(0, self.sig_omega_acc, size=(self.horizon, self.N))

        # First-order low-pass filter along time, seeded with the action that
        # was returned last, so consecutive controls do not jump.
        sampled = np.zeros_like(noise)
        previous = self.past_action
        for t in range(self.horizon):
            sampled[t, :, :] = self.beta * (self.a[:, t] + noise[t, :, :]) \
                               + (1 - self.beta) * previous
            previous = sampled[t, :, :]

        # Symmetric limits per control channel.
        limits = np.array([self.max_acc, self.max_omega_acc], dtype=float)
        return np.clip(sampled, -limits, limits)

    def _propagate(self, states, controls):
        """
        Advance every rollout by one timestep of length self.dt.

        states: array with one state per row
        controls: array with one control per row, same number of rows
        returns: array with one successor state per row
        """
        # The trailing [] is forwarded to my_RK4 unchanged; its meaning is
        # defined by that helper.
        return np.array([
            np.ravel(my_RK4(state, control, cylinder_pushing_dynamics, self.dt, []))
            for state, control in zip(states, controls)
        ])

    def get_action(self):
        """
        perform mppi optimization and get the current action

        Updates self.a (nominal controls), self.mpc_path (rollouts) and
        self.past_action as a side effect.

        returns: (action, mpc_path) where action is the clipped first control
                 of the updated nominal sequence and mpc_path has shape
                 (horizon, N, state dimension)
        raises: RuntimeError if init_matrices() has not been called
        """
        if self.a is None or self.cur_state is None:
            raise RuntimeError("init_matrices() must be called before get_action()")

        sampled = self._sample_action_sequences()

        all_states = np.zeros((self.horizon, self.N, self.cur_state.shape[0]))
        rollout_states = np.tile(self.cur_state, (self.N, 1))

        costs = []  # one (N, 1) array per timestep, terminal cost last
        for t in range(self.horizon):
            costs.append(self.l(rollout_states, self.a[:, t], sampled[t, :, :]))
            rollout_states = self._propagate(rollout_states, sampled[t, :, :])
            all_states[t, :, :] = rollout_states

        self.mpc_path = all_states
        costs.append(self.m(rollout_states))

        # Reverse cumulative sum: entry t is the cost from timestep t to the end.
        cost_to_go = np.flip(np.cumsum(np.flip(np.asarray(costs), 0), axis=0), 0)

        for t in range(self.horizon):
            # log sum exp trick: the cheapest rollout gets exponent 0, so the
            # weights cannot all underflow to zero.
            shifted = cost_to_go[t] - np.amin(cost_to_go[t])

            weights = np.exp(-shifted / self.lam) + 1e-8
            weights /= np.sum(weights)

            # Weighted mean of the deviation of the samples from the nominal control.
            correction = np.dot(weights.T, sampled[t, :, :] - self.a[:, t]).ravel()
            self.a[:, t] = self.a[:, t] + correction

        # make sure action is between bounds (the first channel is kept
        # non-negative here although the rollouts may use negative values)
        action_acc = np.clip(self.a[0, 0], 0, self.max_acc)
        action_omega_acc = np.clip(self.a[1, 0], -self.max_omega_acc, self.max_omega_acc)

        action = np.array([action_acc, action_omega_acc])
        self.past_action = action

        return action, self.mpc_path

    def init_next_action(self):
        """
        advance control matrix for next step: drop the first column of self.a
        and append self.a0 as the new last column
        """
        self.a = np.concatenate([self.a[:, 1:], np.array(self.a0).reshape(2, 1)], axis=1)


def main():
    """Script entry point; intentionally does nothing at the moment."""
    return None


if __name__ == "__main__":
    main()