import gym
from gym import spaces

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.patches import Arc, RegularPolygon
from scipy import integrate
from scipy.interpolate import interp1d
import math

from utils.integrator import my_RK4
from general_problem_setup import Problem, Index
import utils.utils as utils
from dynamics.cylinder_rob_push_dynamics import cylinder_pushing_dynamics

class cylinder_robot_pushing_recBox(gym.Env):
    """Custom Environment that follows gym interface"""
    metadata = {'render.modes': ['human'], "render_fps": 30}

    def __init__(self):
        """Set up the pushing environment.

        Reads the physical parameters from ``Problem()``, samples a random
        initial contact state, converts it into the robot and object states
        used by the rest of the class, creates the matplotlib figure and
        defines the gym action and observation spaces.
        """
        super(cylinder_robot_pushing_recBox, self).__init__()

        self.pr = Problem()
        self.index = Index()

        self.diata_t = 0.1    # presumably the simulation time step -- TODO confirm against the stepping code

        # system setup
        self.r_r = self.pr.rob_size    # m (radius of the cylinder)
        [self.W_o, self.L_o] = self.pr.obj_size    # m
        self.obj_mass = self.pr.obj_mass    # kg
        self.obj_I_or = self.pr.obj_I_or    # moment of inertia (rotation center is
                                            # the center of the object) (for rectangle)
        self.mu_g = self.pr.mu_g            # friction coefficient of the ground-object contact surface
        self.grav_acc = self.pr.grav_acc    # gravity acceleration

        # state initialization
        # state = [x_r, y_r, theta_r, v_r, d_theta_r, phi, d, d_fai, d_d]
        # q = [x_r, y_r, theta_r, phi, d]
        # Velocities are drawn from a narrow band around zero and the two
        # contact-frame rates (d_fai, d_d) always start at exactly zero.
        state_low = [-2.4, -2.4, 0.0,
                     -self.pr.robot_maxVel*0.1, -self.pr.robot_maxOmega*0.01,
                     -0.2, -self.L_o/2.0, 0.0, 0.0]
        state_high = [0.0, 2.4, np.pi/2.0,
                      self.pr.robot_maxVel*0.1, self.pr.robot_maxOmega*0.01,
                      0.2, self.L_o/2.0, 0.0, 0.0]
        self.state = np.random.uniform(state_low, state_high)

        # The sampled heading is discarded: the robot always starts pointing at the origin.
        self.state[2] = np.arctan2(-self.state[1], -self.state[0])
        # phi is drawn a second time from the same range as above.
        self.state[5] = np.random.uniform(-0.2, 0.2)

        x_r, y_r, theta_r, v_r, d_theta_r, phi, d, d_fai, d_d = self.state

        d_x_r = np.cos(theta_r)*v_r
        d_y_r = np.sin(theta_r)*v_r

        # translate q to x
        contact_angle = phi + theta_r            # orientation of the object in the world frame
        d_contact_angle = d_fai + d_theta_r      # its rate of change
        centre_offset = self.W_o/2 + self.r_r    # robot centre to object centre, along the object x-axis

        # position of the object
        W_p_o = np.array([x_r + d*np.sin(contact_angle) + np.cos(contact_angle)*centre_offset,
                          y_r - d*np.cos(contact_angle) + np.sin(contact_angle)*centre_offset])
        # velocity of the object (time derivative of W_p_o)
        d_W_p_o = np.array([d_x_r + d_d*np.sin(contact_angle) + d*np.cos(contact_angle)*d_contact_angle
                            - np.sin(contact_angle)*d_contact_angle*centre_offset,
                            d_y_r - d_d*np.cos(contact_angle) + d*np.sin(contact_angle)*d_contact_angle
                            + np.cos(contact_angle)*d_contact_angle*centre_offset])
        theta_o = theta_r + phi
        d_theta_o = d_theta_r + d_fai

        self.obj_state = np.array([W_p_o[0], W_p_o[1], theta_o, d_W_p_o[0], d_W_p_o[1], d_theta_o])
        self.rob_state = np.array([x_r, y_r, theta_r, v_r, d_theta_r])

        # initialize object history path (one column holding x, y, theta)
        self.obj_his_path = np.array(self.obj_state[0:3]).reshape((-1, 1))

        # visualization initialization, range of the x-y axes
        self.ws_x = self.pr.ws_x
        self.ws_y = self.pr.ws_y

        # prepare for visualization (needs rob_state / obj_state set above)
        self.prepare_visualization()

        # Define action and observation space
        # They must be gym.spaces objects
        # NOTE(review): the lower action bound is zero, so negative commands
        # lie outside the declared action space -- confirm this is intended.
        self.action_space_low = np.zeros(2)             # lower bound of the 2-D action
        self.action_space_high = np.array([0.5, 0.5])   # upper bound of the 2-D action
        self.action_space = spaces.Box(self.action_space_low, self.action_space_high, dtype=np.float32)

        # The lower bound must be the negative of the upper bound; using +5
        # for both collapsed the observation space to a single point.
        self.observation_space_low = -5*np.ones(11)    # robot and object state
        self.observation_space_high = 5*np.ones(11)    # robot and object state
        self.observation_space = spaces.Box(low=self.observation_space_low, high=self.observation_space_high, dtype=np.float32)

        self.done = False


    def set_state(self, cur_rob_state, cur_obj_state):
        """Overwrite the robot and object states and restart the bookkeeping.

        The given states are stored as-is (no copy), the object history path
        is reset to a single column built from the first three entries of
        ``cur_obj_state``, and the ``done`` flag is cleared.
        """
        self.rob_state = cur_rob_state
        self.obj_state = cur_obj_state
        self.done = False

        # the history path starts over from the new object pose
        start_pose = np.array(cur_obj_state[0:3])
        self.obj_his_path = start_pose.reshape((-1, 1))


    def render(self, mode='human', close=False):
        """Push the current robot/object state into the existing artists and redraw.

        ``mode`` and ``close`` are accepted but not used by this method.
        The artists themselves are created in ``prepare_visualization``.
        """
        rob_x = self.rob_state[0]
        rob_y = self.rob_state[1]
        obj_angle = self.obj_state[2]
        obj_centre = np.array([self.obj_state[0], self.obj_state[1]])
        # utils.rot_M is used here as a 2-D rotation matrix -- presumably 2x2, confirm in utils
        obj_rot = utils.rot_M(obj_angle)

        # robot disc
        self.fig_robot_pos.set_center([rob_x, rob_y])

        # object rectangle: anchored at its lower-left corner, so that corner is
        # moved from the object frame to the world frame before the patch is updated
        corner_in_obj = np.array([self.W_o, self.L_o]).T / -2.0
        corner_in_world = self.obj_state[0:2].T + np.dot(obj_rot, corner_in_obj)
        rect = self.fig_object_pos
        rect.set_xy([corner_in_world[0], corner_in_world[1]])
        rect.angle = np.rad2deg(obj_angle)
        rect.set_width(self.W_o)
        rect.set_height(self.L_o)

        # robot heading indicator
        # NOTE(review): drawn with length 1.0 here but 2.0 in prepare_visualization
        heading_tip = np.dot(utils.rot_M(self.rob_state[2]), np.array([1.0, 0]).T) + np.array([rob_x, rob_y])
        self.line_robot_direction[0].set_data([rob_x, heading_tip[0]], [rob_y, heading_tip[1]])

        # object history path (rows 0 and 1 of the history are used as x and y)
        hist_x = self.obj_his_path[0, :]
        hist_y = self.obj_his_path[1, :]
        self.fig_object_path[0].set_data(np.concatenate((hist_x, hist_y)).reshape((2, -1)))

        # contact surface: a segment along the object's x = -W_o/2 face
        # NOTE(review): extends to +/- L_o here but +/- 2*L_o in prepare_visualization
        face_x = -self.W_o/2.0
        end_a = np.dot(obj_rot, np.array([face_x, self.L_o]).T) + obj_centre
        end_b = np.dot(obj_rot, np.array([face_x, -self.L_o]).T) + obj_centre
        self.fig_contact_surface[0].set_data([end_a[0], end_b[0]], [end_a[1], end_b[1]])

        # redraw
        canvas = self.fig_main.canvas
        canvas.draw()
        canvas.flush_events()
        
    def prepare_visualization(self):
        """Create the interactive figure and every artist that ``render`` updates.

        Requires ``self.rob_state``, ``self.obj_state``, ``self.ws_x`` and
        ``self.ws_y`` to be set. Stores the figure/axes in ``fig_main`` /
        ``ax_main`` and the artists in ``fig_robot_pos``,
        ``line_robot_direction``, ``fig_object_pos``, ``fig_object_path`` and
        ``fig_contact_surface``.
        """
        # ================== prepare for visualization ====================
        # Prepare a figure for visualization
        plt.ion()
        self.fig_main, self.ax_main = plt.subplots()
        self.ax_main.grid(visible=True, ls='-.')
        self.ax_main.set_aspect('equal')
        self.ax_main.set_xlim(self.ws_x)
        self.ax_main.set_ylim(self.ws_y)
        self.ax_main.set_xlabel('x [m]')
        self.ax_main.set_ylabel('y [m]')

        rob_xy = np.array([self.rob_state[0], self.rob_state[1]])
        obj_xy = np.array([self.obj_state[0], self.obj_state[1]])
        # utils.rot_M is used as a 2-D rotation matrix -- presumably 2x2, confirm in utils
        obj_rot = utils.rot_M(self.obj_state[2])

        # plot objects
        # robot current pos
        robot_pos_cir = mpatches.Circle(rob_xy, self.r_r, fc=(0, 0.4, 1, 0.8), ec=(0, 0, 1, 0.8), alpha = 0.4)
        self.fig_robot_pos = self.ax_main.add_artist(robot_pos_cir)

        # robot current orientation
        point_rd = np.dot(utils.rot_M(self.rob_state[2]), np.array([2.0, 0]).T) + rob_xy
        self.line_robot_direction = self.ax_main.plot([self.rob_state[0], point_rd[0]], [self.rob_state[1], point_rd[1]], linestyle = 'dotted', linewidth=2)

        # object current pos
        # The rectangle is anchored at its lower-left corner. Place that corner
        # at the object's actual world pose (same transform as render) so the
        # first frame agrees with the contact line drawn below.
        pos_corner_o = np.array([self.W_o, self.L_o]).T / -2.0   # position of the corner in the object frame
        pos_corner_w = self.obj_state[0:2].T + np.dot(obj_rot, pos_corner_o)   # position of the corner in the world frame
        # `angle` is passed by keyword: newer Matplotlib releases no longer
        # accept it as a fourth positional argument.
        object_pos_rec = mpatches.Rectangle([pos_corner_w[0], pos_corner_w[1]], self.W_o, self.L_o,
                                            angle=np.rad2deg(self.obj_state[2]),
                                            fc=(0, 1, 1, 0.1), ec=(0, 0, 1, 0.8))
        self.fig_object_pos = self.ax_main.add_artist(object_pos_rec)

        # object history path
        self.fig_object_path = self.ax_main.plot(0.0, 0.0, c='g', ls='-', lw=2.0)
        plt.draw()

        # contact surface
        point1_o = np.array([-self.W_o/2.0,  self.L_o*2])
        point2_o = np.array([-self.W_o/2.0, -self.L_o*2])
        point1_w = np.dot(obj_rot, point1_o.T) + obj_xy
        point2_w = np.dot(obj_rot, point2_o.T) + obj_xy
        self.fig_contact_surface = self.ax_main.plot([point1_w[0], point2_w[0]], [point1_w[1], point2_w[1]], linestyle = 'dotted')
        # ========================================================================

    def close(self):
        """Close the matplotlib figure stored in ``self.fig_main``."""
        figure = self.fig_main
        plt.close(figure)

    def init_path_candidates(self, N):
        """Create ``N`` red placeholder lines for the planned paths.

        Each entry of ``self.obj_mpc_planned_path`` is the list returned by
        ``Axes.plot``; the line itself is element ``[0]`` of that list.
        """
        self.obj_mpc_planned_path = [
            self.ax_main.plot(0.0, 0.0, c="red", ls='-', lw=2.0)
            for _ in range(N)
        ]

    def show_paths_candidates(self, mpc_path):
        """Update the pre-created path lines from ``mpc_path`` and redraw.

        ``mpc_path`` is indexed as ``mpc_path[:, i, 0]`` / ``mpc_path[:, i, 1]``
        for the x / y samples of candidate ``i``. One line is updated for every
        entry of ``self.obj_mpc_planned_path`` (see ``init_path_candidates``).
        """
        n_candidates = mpc_path.shape[1]
        xs_per_path = [mpc_path[:, k, 0] for k in range(n_candidates)]
        ys_per_path = [mpc_path[:, k, 1] for k in range(n_candidates)]

        # reuse the existing line artists instead of re-plotting
        for k, line_handle in enumerate(self.obj_mpc_planned_path):
            xy = np.concatenate((xs_per_path[k], ys_per_path[k])).reshape((2, -1))
            line_handle[0].set_data(xy)

        # Update canvas
        canvas = self.fig_main.canvas
        canvas.draw()
        canvas.flush_events()

    def init_goal_pos(self, goal_pos):
        """Add a circle of radius 0.2 at ``goal_pos[0:2]`` and keep it in ``self.fig_goal_pos``."""
        centre = np.array([goal_pos[0], goal_pos[1]])
        fill_rgba = (1, 0.1, 1, 0.5)
        edge_rgba = (1, 0, 0, 0.8)
        marker = mpatches.Circle(centre, 0.2, fc=fill_rgba, ec=edge_rgba, alpha=0.4)
        self.fig_goal_pos = self.ax_main.add_artist(marker)

    def set_goal_pos(self, goal_pos):
        """Move the goal circle created by ``init_goal_pos`` to ``goal_pos[0:2]``."""
        goal_x, goal_y = goal_pos[0], goal_pos[1]
        self.fig_goal_pos.set_center([goal_x, goal_y])

    def init_path_candidates_both(self, N):
        """Create ``N`` red and ``N`` magenta placeholder lines.

        The red lines go to ``self.obj_mpc_planned_path`` and the magenta ones
        to ``self.obj_mpc_planned_path_object``. They are added to the axes
        alternately (red, magenta, red, ...).
        """
        red_lines = []
        magenta_lines = []
        for _ in range(N):
            red_lines.append(self.ax_main.plot(0.0, 0.0, c="red", ls='-', lw=2.0))
            magenta_lines.append(self.ax_main.plot(0.0, 0.0, c="magenta", ls='-', lw=2.0))

        self.obj_mpc_planned_path = red_lines
        self.obj_mpc_planned_path_object = magenta_lines

    def show_paths_candidates_robot(self, mpc_path):
        """Update the lines in ``self.obj_mpc_planned_path`` from ``mpc_path`` and redraw.

        Columns 0 and 1 of ``mpc_path[:, i, :]`` are used as the x / y samples
        of candidate ``i``. This is exactly what ``show_paths_candidates``
        does, so the call is forwarded instead of keeping a second copy of the
        loop (the previous copy also built two lists it never used).
        """
        self.show_paths_candidates(mpc_path)

    def show_paths_candidates_both(self, mpc_path):
        """Update both families of planned-path lines from ``mpc_path`` and redraw.

        For candidate ``i``, columns 0/1 of ``mpc_path[:, i, :]`` feed the lines
        in ``self.obj_mpc_planned_path`` and columns 5/6 feed the lines in
        ``self.obj_mpc_planned_path_object`` (see ``init_path_candidates_both``).
        """
        def as_xy(x_samples, y_samples):
            # pack the two coordinate series into the (2, n) layout handed to set_data
            return np.concatenate((x_samples, y_samples)).reshape((2, -1))

        candidate_ids = range(mpc_path.shape[1])
        # slice everything up front so a bad column index fails before any line is touched
        first_x = [mpc_path[:, k, 0] for k in candidate_ids]
        first_y = [mpc_path[:, k, 1] for k in candidate_ids]
        second_x = [mpc_path[:, k, 5] for k in candidate_ids]
        second_y = [mpc_path[:, k, 6] for k in candidate_ids]

        # reuse the existing line artists instead of re-plotting
        for k in range(len(self.obj_mpc_planned_path)):
            self.obj_mpc_planned_path[k][0].set_data(as_xy(first_x[k], first_y[k]))
            self.obj_mpc_planned_path_object[k][0].set_data(as_xy(second_x[k], second_y[k]))

        # Update canvas
        canvas = self.fig_main.canvas
        canvas.draw()
        canvas.flush_events()

    def save_trajectory(self):
        """Append one flattened sample to ``self.full_trajectory``.

        The new row is ``[prev_rob_state, prev_obj_state, action, time]``
        flattened into a single line. Layout notes kept from the original
        author: action is [acc_r, acc_omega]; robot state is
        [x, y, theta, vr, omega]; object state is [x, y, theta, vx, vy, omega].
        All four attributes and ``self.full_trajectory`` must already exist;
        they are not created in this method.
        """
        sample = self.prev_rob_state
        for part in (self.prev_obj_state, self.action, self.time):
            # np.append without an axis flattens both operands
            sample = np.append(sample, part)

        new_row = sample.reshape(1, -1)
        self.full_trajectory = np.concatenate((self.full_trajectory, new_row), axis=0)

    def save_trajectory_data(self, i_episode, save_directory):
        """Write ``self.full_trajectory`` to ``<save_directory><i_episode>.npy``.

        ``save_directory`` is used as a plain string prefix: no path separator
        is inserted between it and the episode number.
        """
        target_file = "".join((save_directory, str(i_episode), '.npy'))
        np.save(target_file, self.full_trajectory)
