#!/usr/bin/env python
import numpy as np
import rospy
from timeit import default_timer as timer
import sys
from nav_msgs.msg import Odometry
from scipy.spatial.transform import Rotation as R

class State_Estimator():
    """Collect robot and object odometry from ROS topics into one state vector.

    The combined state returned by ``get_state`` has 11 entries::

        [0:5]  robot  -> x, y, yaw, speed, yaw_rate
        [5:11] object -> x, y, yaw, vel_x, vel_y, yaw_rate

    Angles are in radians (yaw is extracted with ``degrees=False``). The
    robot/object parts are overwritten by the two subscriber callbacks;
    ``get_state`` combines whatever each callback stored last.
    """

    # Velocity entries whose magnitude exceeds this are treated as glitches
    # and replaced by the value from the previous full state.
    VELOCITY_LIMIT = 2.5

    # Indices of the velocity entries in the 11-element full state:
    # robot speed, robot yaw rate, object vel_x, object vel_y, object yaw rate.
    _VELOCITY_INDICES = (3, 4, 8, 9, 10)

    # mode -> (robot topic, object topic)
    _TOPICS = {
        "mocap": ("/Bebop1/position_velocity_orientation_estimation",
                  "/Bebop2/position_velocity_orientation_estimation"),
        "simulation": ("/poseTwist_of_husky_wrt_world",
                       "/pose_of_box_wrt_world"),
    }

    def __init__(self, mode):
        """Initialise zeroed states and subscribe to the odometry topics.

        Args:
            mode: ``"mocap"`` or ``"simulation"``; selects which pair of
                ``nav_msgs/Odometry`` topics to subscribe to.

        Raises:
            ValueError: if ``mode`` is not a supported mode. (Previously this
                surfaced as an AttributeError on ``topic_name_robot``.)
        """
        # init robot and object state
        self.robot_state = [0.0] * 5
        self.object_state = [0.0] * 6
        self.full_state = [0.0] * 11
        self.prev_full_state = [0.0] * 11

        if mode not in self._TOPICS:
            raise ValueError("Unknown mode %r; expected one of: %s"
                             % (mode, ", ".join(sorted(self._TOPICS))))
        self.topic_name_robot, self.topic_name_object = self._TOPICS[mode]

        self.sub_robot_state_gazebo = rospy.Subscriber(
            self.topic_name_robot, Odometry, self.receive_robot_state)
        self.sub_object_state = rospy.Subscriber(
            self.topic_name_object, Odometry, self.receive_object_state)

    @staticmethod
    def _yaw_from_orientation(q):
        """Return the yaw angle (radians) of quaternion ``q`` (fields x, y, z, w).

        Uses the intrinsic ``'ZYX'`` (yaw-pitch-roll) decomposition. The
        extrinsic ``'zyx'`` sequence only agrees with it when roll and pitch
        are zero, so it gives a wrong heading for tilted bodies.
        """
        rot = R.from_quat([q.x, q.y, q.z, q.w])
        return rot.as_euler('ZYX', degrees=False)[0]

    def receive_robot_state(self, odom_robot):
        """Subscriber callback: store ``[x, y, yaw, speed, yaw_rate]``.

        ``speed`` is the magnitude of the planar linear twist (always >= 0,
        so direction of travel is not preserved); ``yaw_rate`` is
        ``twist.angular.z``.
        """
        pose = odom_robot.pose.pose
        twist = odom_robot.twist.twist
        yaw = self._yaw_from_orientation(pose.orientation)
        vel = np.hypot(twist.linear.x, twist.linear.y)
        self.robot_state = [pose.position.x, pose.position.y, yaw, vel,
                            twist.angular.z]

    def receive_object_state(self, obj_pos_msg):
        """Subscriber callback: store ``[x, y, yaw, vel_x, vel_y, yaw_rate]``."""
        pose = obj_pos_msg.pose.pose
        twist = obj_pos_msg.twist.twist
        yaw = self._yaw_from_orientation(pose.orientation)
        self.object_state = [pose.position.x, pose.position.y, yaw,
                             twist.linear.x, twist.linear.y, twist.angular.z]

    def check_for_anomalies(self, full_state, limit=None):
        """Replace out-of-range velocity entries with their previous values.

        Args:
            full_state: 11-element state (array-like). Not modified.
            limit: magnitude threshold; defaults to ``VELOCITY_LIMIT``.

        Returns:
            A new float ``np.ndarray`` in which every velocity entry with
            ``abs(v) > limit`` is taken from ``self.prev_full_state``.
        """
        if limit is None:
            limit = self.VELOCITY_LIMIT
        # Copy: the old code aliased the input, so it mutated the caller's
        # array and its "did anything change?" check compared an array with
        # itself and never fired.
        full_state_new = np.array(full_state, dtype=float)
        prev = np.asarray(self.prev_full_state, dtype=float)

        idx = np.array(self._VELOCITY_INDICES)
        bad = idx[np.abs(full_state_new[idx]) > limit]
        if bad.size:
            full_state_new[bad] = prev[bad]
            print("Weird velocities detected!")
        return full_state_new

    def get_state(self):
        """Return the current 11-element state with velocity glitches filtered.

        The previously returned state is remembered first so that
        ``check_for_anomalies`` can fall back to it.
        """
        # first set the previous state to remember it
        self.prev_full_state = self.full_state

        # robot (5) + object (6) -> 11 entries
        full_state_arr = np.array(self.robot_state + self.object_state)

        # check if there are some weird data values
        full_state_arr = self.check_for_anomalies(full_state_arr)

        self.full_state = full_state_arr
        return full_state_arr


if __name__ == '__main__':
    # Executing this module directly performs no work.
    ...