#! /usr/bin/env python

import rospy
import os
import numpy as np
import matplotlib.pyplot as plt
import math 
from sys import exit

from scipy import interpolate 
from std_msgs.msg import String,  Header
from sensor_msgs.msg import ChannelFloat32, JointState

# State for logged-data (sim) mode:
#   seq_index   -- number of messages processed; raw_data_callback wraps it back to 0 at 600
#   final_array -- placeholder; not read or written anywhere in the visible code
seq_index, final_array = 0, None

import time


# Callback for the raw soft-gripper sensor string topic.
def raw_data_callback(data):
  """Turn one raw sensor message into bone poses and publish them.

  The payload is parsed as whitespace-separated numbers, optionally wrapped in
  square brackets and spread over several lines. The numbers are fed through
  the module-level ``gripper`` network. The resulting 2D points are converted
  to per-bone positions/angles and published on ``pub`` as a ``JointState``.

  Args:
    data: incoming ``std_msgs/String`` message. Its ``data`` field is read but
      not modified.

  Side effects:
    Reads and updates the module-level ``seq_index`` counter, prints it, and
    publishes one ``JointState`` per processed message.
  """
  global seq_index

  # Logged-data (sim) mode: wrap the counter after 600 messages.
  # NOTE(review): the message that triggers the wrap is dropped without being
  # processed -- confirm this is intended.
  if seq_index == 600:
    seq_index = 0
    return

  # Drop the brackets and split on ANY whitespace (spaces, tabs, newlines).
  # Deleting newlines outright could glue the last number of one line to the
  # first number of the next; splitting keeps them separate. A local string is
  # used so the (possibly shared) message object is left untouched.
  raw = data.data.replace("[", " ").replace("]", " ")
  in_array = np.array(raw.split(), dtype=float)

  res = gripper.feed_forward(in_array)

  # Flip the sign of every coordinate (both axes) before visualisation.
  res = -res

  # Network output: even indices are x coordinates, odd indices are y.
  ang_bone_data, x_bone_data, y_bone_data = convert_to_anim_data(res[::2], res[1::2])

  # Six values per bone for the AR visualisation: x, -y, 0, 0, 0, angle.
  # Presumably [pos x, pos y, pos z, rot x, rot y, rot z] -- confirm on the Unity side.
  unity_data = []
  for ang, bone_x, bone_y in zip(ang_bone_data, x_bone_data, y_bone_data):
    unity_data.extend([bone_x, -bone_y, 0, 0, 0, ang])
  # Root of the system (fixed): the last output point (indices 10/11), shifted by -0.0825 in x.
  unity_data.extend([res[10] - 0.0825, -res[11], 0, 0, 0, -180])

  bone_config_array = JointState()
  # Stamp the header so consumers can tell when the pose was computed.
  bone_config_array.header = Header(stamp=rospy.Time.now())
  bone_config_array.position = unity_data
  pub.publish(bone_config_array)

  print(seq_index)
  # Advance the sequence index used in sim mode.
  seq_index += 1


# Convert the network's 2D points into the per-bone format used by the
# real-time visualisation.
def convert_to_anim_data(x, y, ti=(0.1, 0.3, 0.5, 0.7)):
  """Fit a parametric cubic spline through (x, y) and sample bone poses on it.

  Args:
    x, y: equal-length 1D sequences of point coordinates. splprep's default
      cubic degree requires at least 4 points.
    ti: spline parameter values in [0, 1] at which to place bones. The default
      gives four bones.

  Returns:
    (angles_deg, x_bones, y_bones):
      angles_deg[i] is the unsigned angle, in degrees within [0, 180], between
      the curve tangent at ti[i] and the +x axis. This equals the angle between
      the curve normal and the +y axis. x_bones[i] and y_bones[i] are the curve
      position at ti[i].
      A zero-length tangent yields NaN for that angle.
  """
  # Interpolating parametric cubic spline (no weights given -> s = 0).
  tck, _ = interpolate.splprep([x, y])
  ti = np.asarray(ti, dtype=float)

  # First derivative gives the tangent [dx, dy]; zeroth derivative the position.
  dx_dt, dy_dt = interpolate.splev(ti, tck, der=1)
  x_bones, y_bones = interpolate.splev(ti, tck, der=0)

  # The normal is the tangent rotated by 90 degrees, (-dy, dx). Its y component
  # divided by its length is dx / |tangent|, the cosine used for the bone angle.
  magnitude = np.sqrt(dy_dt ** 2 + dx_dt ** 2)
  cosines = dx_dt / magnitude

  # Clip guards arccos against tiny floating-point excursions outside [-1, 1].
  return (np.degrees(np.arccos(np.clip(cosines, -1, 1))), x_bones, y_bones)

# Feed-forward neural network evaluated in numpy: min-max input scaling,
# one 10-unit tanh hidden layer, a 12-unit linear output layer and reverse
# min-max output scaling. The constructor takes the input size and output size.
class GripperFFN:
    """Pre-trained feed-forward network mapping 16 sensor readings to 12 outputs.

    The 12 outputs are six 2D points: x coordinates at even indices, y
    coordinates at odd indices. All weights and scaling constants are fixed
    below. The ``x``/``y`` constructor arguments only size the ``input`` and
    ``output`` placeholder arrays.
    """

    def __init__(self, x, y):
        """Load the fixed network parameters.

        Args:
            x: input size; only used to size ``self.input``.
            y: output size; only used to size ``self.output``.
        """
        # Placeholder buffers (not used by feed_forward).
        self.input      = np.zeros(x)
        self.output     = np.zeros(y)

        # --- Input scaling (16 inputs): scaled = (x - xoffset) * gain + ymin ---
        self.x1_step_xoffset = np.array([881,1640,2031,4740,977,1697,1857,4633,966,1592,1836,4467,665,1149,1347,3265])
        self.x1_step_gain = np.array([0.00335570469798658,0.00123685837971552,0.000802246289610911,0.000418848167539267,0.00350262697022767,0.00156985871271586,0.00119260584376863,0.000557724484104852,0.00355239786856128,0.0016220600162206,0.00118623962040332,0.000576036866359447,0.00391389432485323,0.00158227848101266,0.00118203309692671,0.000567054153671676])
        self.x1_step_ymin = -1

        # --- Hidden layer: 10 tanh units ---
        # Biases, shape (10,)
        self.b1 = np.array([-0.059536523785815927134,1.3354269377371938976,0.098426896827479626317,0.33473418083876443907,-0.272454257342443662,0.4099605984914136414,1.2363367969248282296,1.3092362370854537801,0.55764586943457084001,-1.9337380614892656361])
        # Weights, shape (10, 16): one row per hidden unit
        self.LW_1 = np.array([[-0.25229928,  0.6664376 , -0.46633049, -0.49462393,  0.42432173,
                 1.11331037,  0.28567787,  0.33060954, -0.31926306,  0.53938021,
                -0.25233539, -0.33329282, -0.43221728, -0.70176504, -1.57622398,
                 0.31911451],
               [ 0.09273581,  0.22313012, -0.5335523 , -0.05852288, -0.28381566,
                -0.15162598,  0.39820567, -0.3213374 , -0.06784085,  0.75206884,
                -0.52453785,  0.40060456,  0.04754084,  0.61236734,  0.11014705,
                 0.51892458],
               [-0.04724974,  0.12253153,  0.27473859,  0.01739523, -0.05895964,
                -0.53846905, -0.38036763,  0.09968862,  0.21513616,  0.34594306,
                -0.47393672,  0.12704656,  0.40706977,  0.08749001, -0.40190599,
                 0.0026874 ],
               [-0.84871683,  0.0754264 ,  0.09966625, -0.45799309, -1.13728949,
                 0.45360379,  0.86973895, -0.17119379, -0.55898635,  1.02759755,
                 1.07640434,  0.23614207,  0.41518243,  0.02915285,  0.99586442,
                 0.10473361],
               [-0.70652461,  0.75161681,  0.40877995,  0.1799286 , -0.44703896,
                -1.03122817, -0.55224701,  0.23726419,  0.30621326,  0.79059325,
                -0.32665447,  0.35182751,  0.89775069,  0.14239136, -0.22282346,
                -0.21171346],
               [ 0.58679948, -0.70412274, -0.65183651,  0.07011823,  0.5645195 ,
                -0.22789521,  0.93631625,  0.98605432, -0.41457122, -0.36016688,
                 0.01884947, -0.44717467,  0.39244815, -0.75622699, -0.48747603,
                -0.11898306],
               [-0.26147487,  0.31589482,  0.27165069, -0.83439741, -1.14369983,
                 1.09543974,  0.38411187,  0.68745413, -0.64147317,  0.12672911,
                 0.1898453 ,  1.12904051, -1.32056391,  0.31521238, -0.66573237,
                -0.03115702],
               [ 0.27761092, -0.60415127,  0.59491452, -0.14189804, -0.08866412,
                -0.91049911,  0.9252058 , -0.57308518,  0.08705993,  0.76372833,
                -0.60118745, -0.63545701, -0.13569712,  0.15623965,  0.966148  ,
                 0.71620653],
               [ 0.34356366,  0.04986506,  0.08543531, -0.18432167,  0.0110901 ,
                 0.6707782 , -0.30847934, -0.91040402,  0.17441217, -0.0125193 ,
                -0.83481403,  0.32171197, -0.67891645,  0.77410147, -0.33987968,
                 0.1003797 ],
               [-0.07261834,  0.32549167, -0.09630331,  0.17519286,  0.28211322,
                 0.15930322, -0.39876469, -0.19727929,  0.15716225, -0.2845294 ,
                -0.28957616,  0.19944746, -0.01276154, -0.26365315, -0.92518425,
                -0.111112  ]])

        # --- Output layer: 12 linear units ---
        # Biases, shape (12,)
        self.b2 = np.array([2.7722624555636343757,-0.2787855966687131648,2.7942149228895627466,0.060700249977518591638,2.1915045996955222485,0.23106685653962399929,0.4764069872857832344,-0.2637457138919779287,0.33134654598836676964,-0.72834113383148024745,-0.3119019356328651682,-0.53418405859196860241])

        # Weights, shape (12, 10): one row per output unit
        self.LW_2 = np.array([[ 0.10421112,  0.7680723 ,  2.48857688, -0.02756092, -1.67336203,
                -0.86435246,  0.25531288, -1.39786097, -1.39454712,  2.92378361],
               [ 0.26085107, -0.21577632,  2.23884045,  0.05098436, -0.90077304,
                -0.23357188,  0.76919904, -0.88437467, -0.53455092, -0.39313212],
               [ 0.17624528, -0.36610551,  0.75103388,  0.04897287, -0.60279706,
                -0.4501455 , -0.00580824, -0.8565602 , -0.53284198,  2.41487332],
               [ 0.30009895, -0.28990705,  1.464966  ,  0.066562  , -0.00386515,
                 0.34203837,  0.65689584, -0.69523664, -0.21622301,  0.01187001],
               [ 0.30555049, -1.22906124,  0.64789283,  0.17085183, -0.72689651,
                -0.73018578,  0.09994386, -0.93485159, -0.54770792,  0.96813376],
               [ 0.33745815, -0.26605432,  0.93212846,  0.07955089,  0.59275412,
                 0.71858013,  0.6527635 , -0.51147458,  0.01244265,  0.38757613],
               [ 0.74012787, -1.9049974 ,  0.65005161,  0.32380358, -0.28424004,
                -0.6770737 ,  0.56993606,  0.40269937, -0.40109996, -0.13974754],
               [ 0.47333275, -0.6057067 ,  1.36313384,  0.20532371,  1.09665097,
                 1.21929967,  0.71364668,  0.21703109, -0.19369792,  0.17349297],
               [ 0.5088848 , -0.67249342, -0.4999847 , -0.00654422,  0.2012985 ,
                -1.00307764,  2.19861031, -0.45356571, -0.69040485,  0.22539155],
               [ 0.56002926,  0.37100376,  1.03353725,  0.19127787,  1.02945384,
                -0.01745113,  1.71643435,  0.03530322, -0.21201619,  1.13117471],
               [ 0.2403271 , -0.28938399, -0.61746123, -0.11806287,  0.36793455,
                -0.63015703,  1.2878456 ,  0.19080754, -0.18084243,  0.34267491],
               [ 0.05603849, -0.20012301,  0.42387984,  0.0499089 ,  0.08569474,
                -0.26498457,  0.80934803, -0.22407051, -0.36329022, -0.22149996]])

        # --- Output scaling (12 outputs; even idx = x coords, odd idx = y coords) ---
        # Inverse mapping: out = (layer2 - ymin) / gain + offset
        self.y1_step1_ymin = -1
        self.y1_step_gain = np.array([0.0181533796742763,0.00920083414176453,0.0387741304649813,0.0135700453441686,0.102859798214089,0.0256436952089366,0.373501441513425,0.0874260716609571,1.10718777627207,0.842242666981568,1.00732112829588,1.03090146075654])
        self.y1_step_offset = np.array([843.098068793099,467.42140251033,918.994793781876,465.104741417797,994.290675625534,473.780380723937,1072.68014375021,486.613329463474,1145.67046076056,492.679327255819,1223.51600409659,492.757276484495])

    def feed_forward(self, x):
        """Run the network on one input vector or on a batch of them.

        Args:
            x: array-like of shape (16,) or (N, 16) holding raw sensor readings.

        Returns:
            ndarray of shape (12,) or (N, 12). Even indices are x coordinates,
            odd indices are y coordinates.

        Raises:
            ValueError: if the last dimension of ``x`` is not 16.
        """
        x = np.asarray(x, dtype=float)
        n_inputs = self.x1_step_xoffset.shape[0]
        if x.shape[-1:] != (n_inputs,):
            raise ValueError(
                "expected input with last dimension %d, got shape %s" % (n_inputs, x.shape))

        processed_x = self.map_min_max(x, self.x1_step_xoffset, self.x1_step_gain, self.x1_step_ymin)

        # Hidden layer. Multiplying by the transposed weights on the right
        # handles a single vector (16,) and a batch (N, 16) identically.
        layer1 = self.sigmoid(self.b1 + np.matmul(processed_x, self.LW_1.T))

        # Linear output layer.
        layer2 = self.b2 + np.matmul(layer1, self.LW_2.T)

        return self.map_min_max_reverse(layer2, self.y1_step_offset, self.y1_step_gain, self.y1_step1_ymin)

    def sigmoid(self, n):
        """Hyperbolic-tangent sigmoid: 2 / (1 + exp(-2n)) - 1.

        Evaluated as ``np.tanh(n)``. The two are mathematically identical, but
        tanh does not overflow ``exp`` for large negative inputs.
        """
        return np.tanh(n)

    def map_min_max(self, x, offset, gain, ymin):
        """Scale raw values into network range: (x - offset) * gain + ymin."""
        return (x - offset) * gain + ymin

    def map_min_max_reverse(self, y, offset, gain, ymin):
        """Invert ``map_min_max``: (y - ymin) / gain + offset."""
        return (y - ymin) / gain + offset



# Wire the node: publish per-bone poses for the AR visualisation and listen for
# raw soft-gripper sensor strings. The publisher is created first so the `pub`
# name used inside raw_data_callback is bound before the subscription exists.
pub = rospy.Publisher('/sg_sensor_bone_configs', JointState, queue_size=10)
sub = rospy.Subscriber('/sg_sensor_data_raw', String, raw_data_callback)


if __name__ == '__main__':
  # Turn on interactive matplotlib mode.
  # NOTE(review): nothing in the visible code plots -- possibly leftover, confirm before removing.
  plt.ion()
  # Module-level network used by raw_data_callback: 16 sensor inputs -> 12
  # outputs (six x/y points). Presumably created before init_node so that it
  # exists before any callback is delivered.
  gripper = GripperFFN(16, 12)
  rospy.init_node("soft_gripper_node")
  # Block until the node is shut down; subscriber callbacks keep running meanwhile.
  rospy.spin()
     
