#!/usr/bin/env python

import rospy
from controllers.PID.PID import PID
import numpy as np
from scipy.spatial.transform import Rotation as R
from controllers.PID.helpers import *


class DynamicPush:
    """Adaptive controller that pushes an object towards a fixed goal.

    Each cycle the controller tracks two angles:

    * ``aPOR`` -- ``angle3Points(goal, object, robot)`` wrapped to
      ``[0, 2*pi)`` (presumably the angle at the object between the goal
      and the robot -- confirm against ``angle3Points``).
    * ``gamma`` -- the deviation between the direction the object should
      have moved (towards the goal) and the direction it actually moved,
      wrapped to ``[-pi, pi)``.

    Whenever ``|gamma|`` shrinks, the ``aPOR`` of the previous cycle is used
    as a sample for a Gaussian estimate (prior -> posterior -> posterior
    predictive) of the pushing angle. The posterior predictive then sets the
    blend between a pushing velocity (``psi_push``) and a relocating
    velocity (``psi_rel``).

    The ``sigma_*`` attributes hold variances (see the update formulas in
    ``_updateAlphaPosterior``), not standard deviations.

    Typical use: feed poses through ``receiveObjectStatePush`` and
    ``receiveRobotStatePush``, call ``initPIDs`` and ``init`` once, then
    call ``update`` followed by ``getPushCmd`` every cycle.
    """

    def __init__(self, pushGoal_):
        """Create the controller.

        :param pushGoal_: indexable ``(x, y)`` goal position for the object.
        """
        # Latest poses; None until the corresponding receive*StatePush call.
        self.object_pose_x = None
        self.robot_pose_x = None
        self.object_pose_y = None
        self.robot_pose_y = None
        self.robot_angle = None

        self.pushGoal_ = pushGoal_

        # init PID parameters (heading controller, see initPIDs)
        self.P_alfa = 0.99
        self.I_alfa = 0.0
        self.D_alfa = 0.0

        # Symmetric output limit applied to the heading PID.
        self.I_alfa_max = 0.8

        self.PID_alfa = None

        # NOTE(review): ang_vel_max is not read anywhere in this class.
        self.ang_vel_max = 0.1
        # Bounds on the commanded linear speed.
        self.linear_max = 0.2
        self.linear_min = 0.1

        # init pars activation functions
        # Running mean / variance of the push-direction deviation gamma.
        self.mi_gamma = 0.0
        self.sigma_gamma = 0.0
        # Running mean / variance of the aPOR samples taken when |gamma| shrank.
        self.mi_alpha = np.pi
        self.sigma_alpha = np.pi/4
        # Gaussian posterior / prior / posterior predictive over the pushing angle.
        self.mi_posterior = np.pi
        self.sigma_posterior = np.pi/4
        self.mi_prior = np.pi
        self.sigma_prior = np.pi/4
        self.mi_posterior_pred = np.pi
        self.sigma_posterior_pred = np.pi
        # Pseudo-count: the alpha statistics start as if count_dr samples equal
        # to mi_alpha had already been observed.
        self.count_dr = 100
        self.sum_alpha = self.count_dr * self.mi_alpha
        # 1 + number of cycles in which the object was seen to move.
        self.count_all = 1

        # Histories for inspection; each grows without bound.
        self.alfa_list = []
        self.gamma_list = []

        # Object position at the last detected movement. (0.8, 0.0) looks like
        # a hard-coded starting guess -- TODO confirm.
        self.prev_object_x = 0.8
        self.prev_object_y = 0.0

        self.gamma = 0.0
        self.old_gamma = 0.0

        # aPOR of the current / previous cycle, and the commanded-direction
        # filter flag (written by update(), not read elsewhere in this class).
        self.aPORp = None
        self.aPOR = None
        self.filt_com = None

        # Push / relocate activations computed by update().
        self.psi_push = None
        self.psi_rel = None

    def receiveObjectStatePush(self, object_pose_x, object_pose_y):
        """Store the latest object position."""
        self.object_pose_x = object_pose_x
        self.object_pose_y = object_pose_y

    def receiveRobotStatePush(self, robot_pose_x, robot_pose_y, robot_angle):
        """Store the latest robot position and heading."""
        self.robot_pose_x = robot_pose_x
        self.robot_pose_y = robot_pose_y
        self.robot_angle = robot_angle

    def initPIDs(self):
        """Create the heading PID (setpoint 0, output limited to +/-I_alfa_max)."""
        self.PID_alfa = PID(self.P_alfa, self.I_alfa, self.D_alfa, setpoint=0)
        self.PID_alfa.output_limits = (-self.I_alfa_max, self.I_alfa_max)

    def init(self):
        """Seed aPOR / aPORp from the current poses (call once poses are known)."""
        # Same [0, 2*pi) convention as update(), so the first alpha sample is
        # not fed into the statistics as a negative angle.
        self.aPOR = self._angleGoalObjectRobot()
        self.aPORp = self.aPOR

    def _posesReady(self):
        """Return True once both the object and the robot position are known."""
        return all(v is not None for v in (self.object_pose_x, self.object_pose_y,
                                           self.robot_pose_x, self.robot_pose_y))

    @staticmethod
    def _wrapToPi(angle):
        """Wrap ``angle`` (radians) into ``[-pi, pi)``."""
        return (angle + np.pi) % (2 * np.pi) - np.pi

    def _angleGoalObjectRobot(self):
        """Return ``angle3Points(goal, object, robot)`` wrapped to ``[0, 2*pi)``."""
        angle = angle3Points(self.pushGoal_[0], self.pushGoal_[1],
                             self.object_pose_x, self.object_pose_y,
                             self.robot_pose_x, self.robot_pose_y)
        if angle < 0:
            angle = 2*np.pi + angle
        return angle

    def update(self):
        """Update the learned statistics and the push/relocate activations.

        Writes ``aPOR``, ``aPORp``, ``gamma`` and its statistics, the alpha
        posterior (when ``|gamma|`` shrank), ``psi_push``, ``psi_rel`` and
        ``filt_com``. When poses are missing it only prints a message.
        """
        if not self._posesReady():
            print("No object/robot pose received yet")
            return

        dR2O = distPoints(self.robot_pose_x, self.robot_pose_y,
                          self.object_pose_x, self.object_pose_y)
        self.aPOR = self._angleGoalObjectRobot()
        self.alfa_list.append(self.aPOR)

        # Learn only when the object actually moved since the last movement.
        if distPoints(self.object_pose_x, self.object_pose_y,
                      self.prev_object_x, self.prev_object_y) > 0.0001:
            expected_dir = getVectorAngle(self.pushGoal_[0] - self.prev_object_x,
                                          self.pushGoal_[1] - self.prev_object_y)
            executed_dir = getVectorAngle(self.object_pose_x - self.prev_object_x,
                                          self.object_pose_y - self.prev_object_y)

            self.prev_object_x = self.object_pose_x
            self.prev_object_y = self.object_pose_y

            self.count_all += 1

            # Wrapped to [-pi, pi) so a push straight at the goal gives
            # gamma ~ 0 (the original comparisons added 2*pi to almost
            # every value).
            self.gamma = self._wrapToPi(expected_dir - executed_dir)

            # Incremental mean / variance of gamma.
            prev_mi_gamma = self.mi_gamma
            self.mi_gamma = self.mi_gamma + (self.gamma - self.mi_gamma) / self.count_all
            self.sigma_gamma = ((self.count_all - 1) * self.sigma_gamma
                                + (self.gamma - prev_mi_gamma) * (self.gamma - self.mi_gamma)) / self.count_all

            # The deviation shrank: treat the previous cycle's angle as a
            # sample of a good pushing angle. Fall back to the current angle
            # if init() was never called.
            if (np.abs(self.gamma) - np.abs(self.old_gamma)) < -0.01:
                sample = self.aPORp if self.aPORp is not None else self.aPOR
                self._updateAlphaPosterior(sample)

            self.gamma_list.append(self.gamma)
            self.old_gamma = self.gamma

        self.aPORp = self.aPOR

        self._updateActivations(dR2O)

    def _updateAlphaPosterior(self, alpha_sample):
        """Fold one pushing-angle sample into the alpha statistics and posterior.

        Gaussian mean update with the sample variance ``sigma_alpha`` treated
        as known; every ``sigma_*`` here is a variance.
        """
        self.count_dr += 1
        n = self.count_dr
        prev_mi_alpha = self.mi_alpha
        self.mi_alpha = self.mi_alpha + (alpha_sample - self.mi_alpha) / n
        self.sigma_alpha = ((n - 1) * self.sigma_alpha
                            + (alpha_sample - prev_mi_alpha) * (alpha_sample - self.mi_alpha)) / n
        self.sum_alpha = self.sum_alpha + alpha_sample
        # The previous posterior becomes the prior for this update.
        self.mi_prior = self.mi_posterior
        self.sigma_prior = self.sigma_posterior
        self.mi_posterior = ((self.sigma_alpha * self.mi_prior / n + self.sigma_prior * self.sum_alpha / n)
                             / (self.sigma_alpha / n + self.sigma_prior))
        self.sigma_posterior = 1 / (n / self.sigma_alpha + 1 / self.sigma_prior)
        # Posterior predictive: same mean, posterior variance + sample variance.
        self.mi_posterior_pred = self.mi_posterior
        self.sigma_posterior_pred = self.sigma_posterior + self.sigma_alpha

    def _updateActivations(self, dR2O):
        """Compute ``psi_push``, ``psi_rel`` and ``filt_com`` for the current aPOR.

        :param dR2O: robot-to-object distance of this cycle.
        """
        cos_a = np.cos(self.aPOR)
        self.psi_push = getGaussianVal(self.aPOR, self.sigma_posterior_pred, self.mi_posterior_pred)

        # NOTE(review): compares with the signed cosine but clamps to its
        # absolute value, so whenever cos(aPOR) < 0 psi_push becomes
        # |cos(aPOR)| regardless of the Gaussian value. Kept as-is; confirm
        # whether `> np.abs(cos_a)` was intended.
        if self.psi_push > cos_a:
            self.psi_push = np.abs(cos_a)

        self.filt_com = 1.0

        if dR2O > 1.0:
            # Robot far from the object: push only, no relocation.
            self.psi_rel = 0.0
            self.filt_com = 0.0
            self.psi_push = 1.0
            print("big difference between object and robot pose, ", dR2O)
        elif cos_a > 0.0:
            self.psi_push = -self.psi_push
            print("push dyn alpha")
        elif np.abs(self.aPOR - np.pi) > 0.4 and cos_a < 0:
            self.psi_push = 0.0

        # Clip so that |psi_push| slightly above 1 cannot yield NaN.
        self.psi_rel = np.sqrt(max(0.0, 1 - self.psi_push**2))
        print(self.psi_push, "psi push")
        print(self.psi_rel, "psi rel")

    def getPushCmd(self):
        """Compute the velocity command for the current cycle.

        Requires ``initPIDs`` and ``update`` to have been called.

        :returns: ``(rx, ry, wz)`` -- linear velocity components after
            rotating the push/relocate direction by the goal-heading error
            (presumably into the robot frame -- confirm against
            ``rotatePoint``), and the heading-PID output. Returns
            ``(0.0, 0.0, 0.0)`` while poses or activations are unavailable.
        """
        if not self._posesReady():
            print("No object/robot pose received yet")
            return 0.0, 0.0, 0.0
        if self.robot_angle is None or self.psi_push is None:
            print("Robot angle or push activations not available yet")
            return 0.0, 0.0, 0.0

        aO2P = getVectorAngle(self.pushGoal_[0] - self.object_pose_x,
                              self.pushGoal_[1] - self.object_pose_y)
        aR2O = getVectorAngle(self.object_pose_x - self.robot_pose_x,
                              self.object_pose_y - self.robot_pose_y)

        theta_v = self.aPOR
        cos_t = np.cos(theta_v)
        sin_t = np.sin(theta_v)

        # Push component along theta_v (sign-normalised by cos(theta_v)).
        push_gain = self.psi_push * np.sign(cos_t)
        vx_push = push_gain * cos_t
        vy_push = push_gain * sin_t

        # Relocation component perpendicular to theta_v; the side is chosen
        # by the sign of sin(mi_alpha - theta_v).
        rel_gain = self.psi_rel * np.sign(np.sin(self.mi_alpha - theta_v))
        vx_relocate = -rel_gain * sin_t
        vy_relocate = rel_gain * cos_t

        vx = vx_push + vx_relocate
        vy = vy_push + vy_relocate

        # rotationDifference is used as a signed error (PID with setpoint 0).
        rot_diff = rotationDifference(aO2P, self.robot_angle)
        v_res_x, v_res_y = rotatePoint(vx, vy, rot_diff)

        # Slow down as the mean push deviation grows, never below linear_min.
        V = max(self.linear_max / (1 + np.abs(self.mi_gamma)), self.linear_min)

        norm = np.hypot(v_res_x, v_res_y)
        if norm > 0:
            rx = V * v_res_x / norm
            ry = V * v_res_y / norm
        else:
            # No direction to normalise: do not translate.
            rx = 0.0
            ry = 0.0

        # Translate only while the robot heading is within 0.3 rad (either
        # side) of the robot->object direction.
        orient_error = rotationDifference(aR2O, self.robot_angle)
        if np.abs(orient_error) > 0.3:
            rx = 0.0
            ry = 0.0

        # Evaluate the (stateful) PID exactly once per cycle.
        wz = self.PID_alfa(rot_diff)
        return rx, ry, wz


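# Example driver (disabled). It was written against an older API: the
# constructor now requires the push goal, the poses live in object_pose_x /
# robot_pose_x (not object_pose_ / robot_pose_), and there is no rate_
# attribute, so those names are updated below. Poses must be supplied through
# receiveObjectStatePush / receiveRobotStatePush (e.g. from subscribers), and
# the command returned by getPushCmd still has to be published.
# if __name__ == "__main__":
#     rospy.init_node('DynamicPush', anonymous=True)
#     push = DynamicPush(push_goal)  # push_goal: (x, y) target for the object
#     rate = rospy.Rate(10)  # TODO: choose the real control rate
#     while push.object_pose_x is None or push.robot_pose_x is None:
#         rate.sleep()
#     push.initPIDs()
#     push.init()
#
#     while not rospy.is_shutdown():
#         push.update()
#         rx, ry, wz = push.getPushCmd()
#         rate.sleep()

#     rospy.spin()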

