import numpy as np
import casadi as cd
import time
import timeit

import sys
import os
from pathlib import Path

# Resolve this file's location and make its parent, grandparent and
# great-grandparent directories importable (needed by the project imports below).
FILE_THIS = Path(__file__).resolve()
PARENT, GPARENT, GGPARENT = (FILE_THIS.parents[level] for level in (0, 1, 2))
sys.path.extend(str(folder) for folder in (PARENT, GPARENT, GGPARENT))

from general_problem_setup import Problem, Index, Mpc_model
from dynamics.cylinder_rob_push_dynamics_casadi import cylinder_pushing_dynamics_casadi
from dynamics.cylinder_rob_push_dynamics import cylinder_pushing_dynamics
from utils.integrator import my_RK4, my_RK4_withoutParam

class Free_Push_mpc_controller:
    """Nonlinear MPC controller for a robot pushing an object (CasADi + IPOPT).

    The optimal control problem is a multiple-shooting NLP over ``N`` stages
    whose decision vector is ``w = [U_0, X_1, U_1, X_2, ..., U_{N-1}, X_N]``
    (each ``U_k`` has ``nu`` entries, each ``X_k`` has ``nx`` entries).
    :meth:`generate_solver` code-generates and compiles the NLP solver, and
    :meth:`run_mpc` solves it once per control step.
    """

    def __init__(self):
        # Problem description and index maps into the state/control/parameter vectors.
        self.pr = Problem()
        self.index = Index()

        # MPC model dimensions
        self.model = Mpc_model()
        self.model.dt = self.pr.dt  # sampling time
        self.model.N = self.pr.N    # horizon length
        self.model.nx = self.pr.nx  # number of states
        self.model.nu = self.pr.nu  # number of control inputs
        # Real-time parameters (one set for the whole horizon):
        # goal_pos_x, goal_pos_y, goal_theta (3) + initial state (9) = 12
        self.model.np = 12

        # State bounds, ordered as [x_r, y_r, theta_r, v_r, omega_r, phi, d, d_phi, d_d]
        self.model.xl = [self.pr.ws_x[0], self.pr.ws_y[0], -1.5 * np.pi,
                         self.pr.robot_minVel, -self.pr.robot_maxOmega,
                         -self.pr.phi_max, -self.pr.d_max,
                         -self.pr.phi_maxVel, -self.pr.d_maxVel]
        self.model.xu = [self.pr.ws_x[1], self.pr.ws_y[1], 1.5 * np.pi,
                         self.pr.robot_maxVel, self.pr.robot_maxOmega,
                         self.pr.phi_max, self.pr.d_max,
                         self.pr.phi_maxVel, self.pr.d_maxVel]
        # Control bounds, ordered as [acc_r, omega_acc_r]
        self.model.ul = [self.pr.robot_minAcc, self.pr.robot_minOmegaAcc]
        self.model.uu = [self.pr.robot_maxAcc, self.pr.robot_maxOmegaAcc]
        # Slack bounds (not referenced by the NLP built in casadi_nlp_form)
        self.model.sl = [0.0]
        self.model.su = [10.0]

        # Robot/contact state: [x_r, y_r, theta_r, vel_r, omega_r, phi, d, d_phi, d_d]
        self.state_current = [0.0] * 9
        # Full system state:
        # [x_r, y_r, theta_r, vel_r, omega_r, x_o, y_o, theta_o, vel_ox, vel_oy, omega_o]
        self.system_state_current = [0.0] * 11
        self.robot_control_current = [0.0, 0.0]           # [vel, omega]
        self.robot_acc_omega_acc_current = [0.0, 0.0]     # [acc, omega_acc]

        self.state_next = self.state_current
        self.obj_pos_goal = self.pr.obj_pos_goal
        self.obj_theta_goal = self.pr.obj_theta_goal

        # MPC info
        self.mpc_x_plan = np.tile(np.array(self.state_current).reshape((-1, 1)), (1, self.pr.N))
        self.mpc_u_plan = np.zeros((self.pr.nu, self.pr.N))
        self.mpc_feasible = False
        self.mpc_step_count = 0
        self.mpc_failure_count = 0
        self.mpc_failure_rate = 0.0
        self.mpc_solve_time = 0.0
        self.mpc_solve_time_all = 0.0

        # Planned trajectory: stage 0 is the current state, stages 1..N the plan.
        # Same shape that run_mpc produces.
        self.system_state_plan = np.zeros((self.pr.N + 1, 1, len(self.system_state_current)))

        # Stop condition
        self.dist_obj_goal = 100000
        self.dist_obj_goal_last = 100000
        self.stop = np.array([])

    @staticmethod
    def _object_xy(x_r, y_r, theta_r, phi, d, offset):
        """Return the object position ``(x_o, y_o)`` implied by the robot/contact state.

        ``offset`` is the distance along the direction ``theta_r + phi``
        (callers pass ``obj_size[0] / 2 + rob_size``), and ``d`` shifts the
        object sideways. Works element-wise on floats and NumPy arrays, and on
        CasADi symbols (via NumPy's dispatch to their ``sin``/``cos``).
        """
        heading = phi + theta_r
        x_o = x_r + d * np.sin(heading) + np.cos(heading) * offset
        y_o = y_r - d * np.cos(heading) + np.sin(heading) * offset
        return x_o, y_o

    def casadi_nlp_form(self, model, index, pr, dt=0.1):
        """Build the multiple-shooting NLP and an IPOPT solver for it.

        Args:
            model: MPC model providing ``N``, ``nx``, ``nu``, ``np`` and the bounds
                ``xl``/``xu`` (states) and ``ul``/``uu`` (controls).
            index: index map into the state, control and parameter vectors.
            pr: problem description (``obj_size``, ``rob_size``, ``nx``, ``nu``).
            dt: step of the fixed-step RK4 discretisation. The default keeps the
                previously hard-coded 0.1. NOTE(review): this need not equal
                ``model.dt`` — confirm which one is intended.

        Returns:
            ``(solver, lbw, ubw, lbg, ubg)``: the CasADi ``nlpsol`` object, the
            per-variable bounds of ``w`` and the (all-zero) bounds of the
            shooting-continuity constraints.
        """
        # Symbolic stage variables.
        # x = [x_r, y_r, theta_r, v_r, d_theta_r, phi, d, d_phi, d_d]
        x = cd.SX.sym('x', model.nx)
        x_r = x[index.x_r]
        y_r = x[index.y_r]
        theta_r = x[index.theta_r]
        phi = x[index.phi]
        d = x[index.d]
        d_phi = x[index.d_phi]
        d_d = x[index.d_d]
        # u = [acc_r, omega_acc_r]
        u = cd.SX.sym('u', model.nu)
        # p = [goal_pos_x, goal_pos_y, goal_theta, initial_state]
        # (goal_theta is a parameter but not used by the cost below)
        p_sym = cd.SX.sym('p', model.np)
        goal_pos_x = p_sym[index.p_goal_pos_x]
        goal_pos_y = p_sym[index.p_goal_pos_y]

        # Object position implied by the robot pose and contact state.
        x_o, y_o = self._object_xy(x_r, y_r, theta_r, phi, d, pr.obj_size[0] / 2 + pr.rob_size)

        # Stage cost on the contact state, and terminal cost = squared object-goal distance.
        stage_expr = 10. * phi * phi + 3. * d * d + 1 * d_d * d_d + 1 * d_phi * d_phi
        terminal_expr = (x_o - goal_pos_x) * (x_o - goal_pos_x) + (y_o - goal_pos_y) * (y_o - goal_pos_y)
        terminal_cost = cd.Function('l', [x, p_sym], [terminal_expr])
        stage_cost = cd.Function('l_p', [x, p_sym], [stage_expr])
        # The stage cost is currently switched off (previously written as `* 0 / 20`).
        stage_cost_weight = 0.0

        # Discrete-time dynamics: one fixed-step RK4 step of length dt.
        xdot = cylinder_pushing_dynamics_casadi(x, u)
        f = cd.Function('f', [x, u], [xdot])
        X0 = cd.MX.sym('X0', pr.nx)
        U = cd.MX.sym('U', pr.nu)
        k1 = f(X0, U)
        k2 = f(X0 + dt / 2 * k1, U)
        k3 = f(X0 + dt / 2 * k2, U)
        k4 = f(X0 + dt * k3, U)
        X_next = X0 + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        F = cd.Function('F', [X0, U], [X_next], ['x0', 'p'], ['xf'])

        # Assemble the NLP.
        w, lbw, ubw = [], [], []   # decision variables and their bounds
        g, lbg, ubg = [], [], []   # continuity constraints and their bounds
        J = 0                      # cost accumulator

        P = cd.MX.sym('P', model.np)   # real-time parameters
        Xk = P[index.p_initial_state]  # stage 0 state is fixed by the parameters
        for k in range(model.N):
            Uk = cd.MX.sym('U_' + str(k), model.nu)
            w += [Uk]
            lbw += model.ul
            ubw += model.uu

            Xk_end = F(x0=Xk, p=Uk)['xf']
            if k == model.N - 1:
                # NOTE(review): the terminal cost is evaluated on X_{N-1} (the
                # state at the start of the last interval), not on X_N, so
                # U_{N-1} does not enter the cost. Confirm this is intended.
                J = J + terminal_cost(Xk, P)
            else:
                J = J + stage_cost(Xk, P) * stage_cost_weight

            # New NLP variable for the state at the end of the interval.
            Xk = cd.MX.sym('X_' + str(k + 1), model.nx)
            w += [Xk]
            lbw += model.xl
            ubw += model.xu

            # Multiple-shooting continuity: integrated end state == next state variable.
            g += [Xk_end - Xk]
            lbg += [0] * model.nx
            ubg += [0] * model.nx

        prob = {'f': J, 'x': cd.vertcat(*w), 'g': cd.vertcat(*g), 'p': P}
        opts = {'ipopt.print_level': 2, 'print_time': 0}
        solver = cd.nlpsol('solver', 'ipopt', prob, opts)

        return solver, lbw, ubw, lbg, ubg

    def generate_solver(self, recompile, solver_build_name):
        """Load (and optionally first code-generate and compile) the NLP solver.

        Args:
            recompile: if True, build the NLP, generate ``<name>.c`` and compile
                it with gcc into ``<name>.so``. If False, reuse an existing
                ``<name>.so``.
            solver_build_name: base file name of the generated C file / library.

        Raises:
            subprocess.CalledProcessError: if the gcc compilation fails.
        """
        import subprocess

        solver_build_c = solver_build_name + '.c'
        solver_build_o = solver_build_name + '.so'
        if recompile:
            nlp_solver = self.casadi_nlp_form(self.model, self.index, self.pr)[0]
            nlp_solver.generate_dependencies(solver_build_c)
            print('Compiling...')
            # Argument list (no shell) and check=True so a failed build is not silently ignored.
            subprocess.run(['gcc', '-fPIC', '-shared', solver_build_c, '-o', solver_build_o], check=True)
            print('Done Compiling!')
        # os.path.join keeps './<name>.so' for relative names and leaves absolute paths intact.
        self.solver_comp = cd.nlpsol('solver', 'ipopt', os.path.join('.', solver_build_o),
                                     {'print_time': 0, 'ipopt.print_level': 0})

    def set_system_state_current(self, state_current):
        """Set the current robot/contact state used as the MPC initial state."""
        self.state_current = state_current

    def set_obj_pos_theta_goal(self, obj_pos_goal, obj_theta_goal):
        """Set the object goal position ``[x, y]`` and orientation used by run_mpc."""
        self.obj_pos_goal = obj_pos_goal
        self.obj_theta_goal = obj_theta_goal

    def _parameter_vector(self):
        """Return the real-time parameter vector: goal and current state at their index slots."""
        parameters = np.zeros(self.model.np)
        parameters[self.index.p_goal_pos_x] = self.obj_pos_goal[0]
        parameters[self.index.p_goal_pos_y] = self.obj_pos_goal[1]
        # atleast_1d accepts both a scalar and a one-element sequence.
        parameters[self.index.p_goal_theta] = np.atleast_1d(self.obj_theta_goal)[0]
        parameters[self.index.p_initial_state] = self.state_current
        return parameters

    def _initial_guess(self):
        """Return the IPOPT initial guess for ``w``.

        After a successful solve, the previous solution is shifted by one stage
        and its last stage is repeated. Otherwise it is a cold start: zero
        controls and the current state, repeated over the horizon.
        """
        stride = self.model.nu + self.model.nx
        if self.mpc_feasible:
            w_prev = np.asarray(self.w_opt, dtype=float).reshape(-1)
            return np.concatenate((w_prev[stride:], w_prev[-stride:]))
        u_init = np.zeros(self.model.nu)
        x_init = np.asarray(self.state_current, dtype=float).reshape(-1)
        return np.tile(np.concatenate((u_init, x_init)), self.model.N)

    def _update_system_state_plan(self, x_plan):
        """Fill ``system_state_plan`` from the current state and the planned states.

        Args:
            x_plan: array of shape (N, nx); row k is the planned state X_{k+1}.

        Returns:
            ``(obj_x, obj_y)``: object positions for stages 0..N (stage 0 = current state).
        """
        ix = self.index
        states = np.vstack((np.asarray(self.state_current, dtype=float).reshape(1, -1), x_plan))
        offset = self.pr.obj_size[0] / 2 + self.pr.rob_size
        obj_x, obj_y = self._object_xy(states[:, ix.x_r], states[:, ix.y_r], states[:, ix.theta_r],
                                       states[:, ix.phi], states[:, ix.d], offset)
        self.system_state_plan = np.zeros((states.shape[0], 1, len(self.system_state_current)))
        plan = self.system_state_plan[:, 0, :]  # view: writes go into system_state_plan
        plan[:, 0] = states[:, ix.x_r]
        plan[:, 1] = states[:, ix.y_r]
        plan[:, 5] = obj_x
        plan[:, 6] = obj_y
        return obj_x, obj_y

    def _update_stop_condition(self, x_o, y_o, tol=0.01):
        """Update the object-to-goal distance and flag a stop when it starts growing near the goal.

        Once the distance is below ``tol`` and larger than at the previous call,
        a ``1`` is appended to ``self.stop``.
        """
        self.dist_obj_goal_last = self.dist_obj_goal
        self.dist_obj_goal = float(np.hypot(x_o - self.obj_pos_goal[0], y_o - self.obj_pos_goal[1]))
        if self.dist_obj_goal < tol:
            print(np.size(self.stop))
            if self.dist_obj_goal > self.dist_obj_goal_last:
                self.stop = np.append(self.stop, 1)

    def run_mpc(self):
        """Solve the MPC problem once from ``self.state_current`` and update the plan.

        Uses the goal stored in ``self.obj_pos_goal`` / ``self.obj_theta_goal``.
        Updates ``w_opt``, ``mpc_feasible``, ``mpc_solve_time``, the failure
        statistics, ``mpc_x_plan`` (nx, N) / ``mpc_u_plan`` (nu, N),
        ``system_state_plan``, the first acceleration/velocity commands and the
        stop-condition bookkeeping.
        """
        self.mpc_step_count += 1
        nu, nx, N = self.model.nu, self.model.nx, self.model.N
        stride = nu + nx  # per-stage layout of w: [U_k, X_{k+1}]

        nlp_p = self._parameter_vector()
        wo = self._initial_guess()

        # Bounds repeat per stage in the [U_k, X_{k+1}] layout; continuity constraints are == 0.
        nlp_lbx = np.tile(np.array(list(self.model.ul) + list(self.model.xl), dtype=float), N)
        nlp_ubx = np.tile(np.array(list(self.model.uu) + list(self.model.xu), dtype=float), N)
        nlp_g_bound = np.zeros(nx * N)

        tic = time.perf_counter()
        sol = self.solver_comp(x0=wo, p=nlp_p, lbx=nlp_lbx, ubx=nlp_ubx,
                               lbg=nlp_g_bound, ubg=nlp_g_bound)
        self.w_opt = sol['x']
        self.mpc_feasible = self.solver_comp.stats()["success"]
        self.mpc_solve_time = time.perf_counter() - tic

        if not self.mpc_feasible:
            self.mpc_failure_count += 1
        self.mpc_failure_rate = self.mpc_failure_count / self.mpc_step_count

        # Row k of the reshaped solution is [U_k, X_{k+1}].
        w_rows = np.asarray(self.w_opt, dtype=float).reshape(N, stride)
        u_plan = w_rows[:, :nu]
        x_plan = w_rows[:, nu:]
        self.mpc_u_plan = u_plan.T.copy()
        self.mpc_x_plan = x_plan.T.copy()

        ix = self.index
        self.robot_acc_omega_acc_current = np.array([u_plan[0, 0], u_plan[0, 1]])
        self.robot_control_current = np.array([x_plan[0, ix.v_r], x_plan[0, ix.d_theta_r]])

        obj_x, obj_y = self._update_system_state_plan(x_plan)
        # Stop condition uses the current (stage 0) object position against the goal.
        self._update_stop_condition(obj_x[0], obj_y[0])


if __name__ == '__main__':
    # Build, code-generate and compile the MPC solver, then load it.
    controller = Free_Push_mpc_controller()
    controller.generate_solver(recompile=True,
                               solver_build_name='free_push_mpc_controller_casadi')