import sys
import time
import numpy as np
import pandas as pd  
import os
import threading
import sys
import select
import tty
import termios
import serial
from Thermal_image import generate_thermal_image, get_temp_result
from Controller import SMP_Controller
from mlx.mlx90640 import Mlx9064x
from unittest.mock import MagicMock


def save_to_csv(frame_count, f, amb, label,
                output_dir="/home/zxh/Desktop/mlx9064x-driver-py-master/mlx/examples/data"):
    """Save one compensated camera frame as a CSV file.

    The first line of the file holds the ambient temperature (prefixed with a
    single space). The next 24 lines hold the 24x32 pixel matrix as
    comma-separated values, with no header and no index column.

    Args:
        frame_count: Frame number, embedded in the file name.
        f: Flat sequence of 768 pixel values (reshaped to 24x32).
        amb: Ambient value written on the first line.
        label: Prefix of the file name (e.g. the current mode).
        output_dir: Directory to write into. Defaults to the data directory
            that was previously hard-coded.

    Returns:
        The path of the written file.

    Raises:
        ValueError: if ``f`` does not contain exactly 24 * 32 values.
    """
    f_matrix = np.array(f).reshape(24, 32)
    df = pd.DataFrame(f_matrix)

    filename = os.path.join(output_dir, f"{label}_temp_matrix_frame_{frame_count}.csv")

    # Ambient line first, then pandas appends the matrix rows below it.
    with open(filename, 'w') as file:
        file.write(f" {amb}\n")
    df.to_csv(filename, mode='a', header=False, index=False)
    return filename

def initialize_all_connections(serial_port="/dev/ttyACM0", baudrate=9600, timeout=1, i2c_port="I2C-1", frame_rate=2):
    """Open the Arduino serial link and initialize the MLX90640 camera.

    Args:
        serial_port: Serial device of the Arduino.
        baudrate: Serial baud rate.
        timeout: Read timeout (seconds) for the serial port.
        i2c_port: I2C port name passed to ``Mlx9064x``.
        frame_rate: Camera frame rate passed to ``Mlx9064x``.

    Returns:
        A dict with ``'serial'`` (the open ``serial.Serial``) and ``'mlx'``
        (the initialized ``Mlx9064x``).

    Exits the process with status 1 if either device cannot be set up. If the
    camera fails, the already-open serial port is closed first.
    """
    connections = {}

    # 1. Serial connection (Raspberry Pi -> Arduino)
    try:
        ser = serial.Serial(serial_port, baudrate, timeout=timeout)
    except serial.SerialException as e:
        print(f"Error connecting to Arduino: {e}")
        sys.exit(1)
    print(f"Connected to Arduino on {serial_port} at {baudrate} baud.")
    connections['serial'] = ser

    # 2. MLX90640 over I2C
    try:
        dev = Mlx9064x(i2c_port, frame_rate=frame_rate)
        dev.init()
        dev.emissivity = 0.95
    except Exception as e:
        print(f"Error initializing MLX90640: {e}")
        # Don't leak the serial port we already opened.
        ser.close()
        sys.exit(1)
    print(f"MLX90640 initialized on {i2c_port} with frame rate {frame_rate} Hz.")
    connections['mlx'] = dev

    return connections

def wake_arduino(ser, response_wait=2):
    """Send the ``wake`` command to the Arduino and report its reply.

    Args:
        ser: Open serial connection (needs ``write``, ``in_waiting`` and ``read``).
        response_wait: Seconds to wait before checking for a reply (default 2).

    Returns:
        The decoded reply, or ``None`` if nothing arrived or communication
        failed. Bytes that are not valid UTF-8 are replaced with U+FFFD instead
        of aborting the read.
    """
    try:
        ser.write(b'wake\n')
        print("Sent wake command to Arduino.")

        # Give the Arduino time to answer.
        time.sleep(response_wait)
        pending = ser.in_waiting
        if pending > 0:
            response = ser.read(pending).decode('utf-8', errors='replace')
            print(f"Arduino responded: {response}")
            return response
        print("No response from Arduino.")
    except Exception as e:
        print(f"Error communicating with Arduino: {e}")
    return None


def set_nonblocking_input():
    """Switch the stdin terminal to cbreak mode and return the previous attributes.

    In cbreak mode single keypresses can be read without waiting for Enter.
    Despite the name, reads are still blocking; the caller in this file polls
    with ``select`` before reading. Hand the returned value to
    ``restore_input_mode`` to undo the change.
    """
    stdin_fd = sys.stdin.fileno()
    saved_attrs = termios.tcgetattr(stdin_fd)
    tty.setcbreak(stdin_fd)
    return saved_attrs

def restore_input_mode(old_settings):
    """Reapply terminal attributes saved by ``set_nonblocking_input``.

    ``TCSADRAIN`` applies the change once pending output has been written.
    """
    termios.tcsetattr(sys.stdin.fileno(), termios.TCSADRAIN, old_settings)

# Emergency stop: keyboard input 's'
def emergency_stop(emergency_stop_condition):
    """Watch the keyboard and set *emergency_stop_condition* when 's' is pressed.

    Meant to run in a background thread. While listening, the terminal is put
    in cbreak mode so single keypresses arrive without Enter. The previous
    settings are restored when the loop ends. The loop ends once the event is
    set, either by the 's'/'S' key or by another thread.

    If stdin is not an interactive terminal, the termios calls cannot work.
    In that case a warning is printed and the function returns without
    setting the event.
    """
    if not sys.stdin.isatty():
        print("Emergency stop listener unavailable: stdin is not a terminal.")
        return

    print("Listening for emergency stop signal ('s')...")
    old_settings = set_nonblocking_input()
    try:
        while not emergency_stop_condition.is_set():
            # Zero-timeout select: only read when a key is already waiting, so
            # the loop keeps noticing when the event is set elsewhere.
            if sys.stdin in select.select([sys.stdin], [], [], 0)[0]:
                input_str = sys.stdin.read(1).strip()
                print(f"Detected input: {input_str}")
                if input_str.lower() == 's':
                    emergency_stop_condition.set()
                    print("Emergency stop triggered!")
            time.sleep(0.1)
    finally:
        restore_input_mode(old_settings)


# Standby mode: SMP pre-heating
def standby_mode(max_vol, frame_rate, dev, ser, smp_con, standby_condition):
    """Keep the SMP pre-heated whenever no other mode is running.

    Meant to run forever in a daemon thread. Each iteration:

    * blocks until *standby_condition* is set;
    * reads one camera frame;
    * estimates the target temperature with ``generate_thermal_image`` /
      ``get_temp_result``;
    * drives ports 5-8 with the PID output toward a target of 50, while
      ports 1-4 are commanded to 0.

    Args:
        max_vol: Maximum drive voltage; the PID output is sent as a
            percentage of it.
        frame_rate: Passed to ``dev.clear_error`` after a failed frame read.
        dev: Initialized MLX9064x camera.
        ser: Open serial connection to the Arduino.
        smp_con: PID controller, shared with the other modes.
        standby_condition: Event; standby only acts while it is set.
    """
    global frame_id
    label = "Standby"  # used by the optional save_to_csv call below
    target_temp = 50
    num_target = 1

    standby_condition.wait()
    print("Standby mode started.")
    while True:
        standby_condition.wait()

        try:
            frame = dev.read_frame()
        except Exception as e:
            print("ERROR:", e)
            dev.clear_error(frame_rate)
            continue
        if frame is None:
            continue

        # A mode may have paused standby while this frame was being read. Skip
        # the iteration so standby neither retargets/feeds the shared PID
        # controller nor writes to the Arduino on the mode's behalf.
        # NOTE(review): this only narrows the window; standby and the modes
        # still share dev/ser/smp_con without a lock.
        if not standby_condition.is_set():
            continue

        f = dev.do_compensation(frame, add_ambient_temperature=True)
        f = dev.do_handle_bad_pixels(f)
        amb = f.pop()
        # save_to_csv(frame_id, f, amb, label)
        f_matrix = np.array(f).reshape(24, 32)
        mask, temp_matrix = generate_thermal_image(temp_matrix=f_matrix, amb=amb, scale_factor=3, num_target=num_target)
        cal_result = get_temp_result(num_target=num_target, mask=mask, temp_matrix=temp_matrix, amb=amb)

        # Other modes change the controller target, so re-assert it each time.
        smp_con.set_target(target_temp)
        smp_con.sampling_update(cal_result)
        action_smp = smp_con.pid_controller()
        pwm_value_smp = round(action_smp / max_vol * 100, 2)
        # Same compact format as the other modes' commands (no embedded run
        # of indentation spaces inside the string).
        command = (
            f"1:0:0,2:0:0,3:0:0,4:0:0,"
            f"5:{pwm_value_smp}:0,6:{pwm_value_smp}:0,7:{pwm_value_smp}:0, 8:{pwm_value_smp}:0\n"
        )

        try:
            ser.write(command.encode())  # Send command to Arduino
        except Exception as e:
            print(f"Error sending command to Arduino: {e}")

        frame_id += 1

# Pre-heat mode: SMA pre-heating
def preheat_mode(direction, action_sma_pre, frame_rate, dev, ser, emergency_stop_condition):
    """Drive the SMA ports selected by *direction* for a fixed number of frames.

    Direction mapping: 1-6 select a single port or a pair of ports 1-3, and
    any other value drives all three. Selected ports get PWM value 99 and
    every other port gets 0.

    For each successfully read frame, the target temperature is computed and
    printed, and the command is sent.

    The mode ends after ``action_sma_pre`` processed frames, or as soon as
    *emergency_stop_condition* is set. In the emergency case an all-zero
    command is sent to the Arduino first.
    """
    global frame_id
    print("Pre-heating Mode started!!!")
    num_target = 2
    pwm_value_sma = 99
    # Which of ports 1-3 are driven for each direction; anything else -> all three.
    direction_ports = {
        1: (1, 0, 0),
        2: (0, 1, 0),
        3: (0, 0, 1),
        4: (1, 1, 0),
        5: (1, 0, 1),
        6: (0, 1, 1),
    }
    port1, port2, port3 = (pwm_value_sma * on for on in direction_ports.get(direction, (1, 1, 1)))
    # The direction is fixed for the whole mode, so the command is too.
    command = (
        f"1:{port1}:0,2:{port2}:0,3:{port3}:0,"
        f"4:0:0,5:0:0,6:0:0,7:0:0, 8:0:0\n"
    )
    stop_command = "1:0:0,2:0:0,3:0:0,4:0:0,5:0:0,6:0:0,7:0:0, 8:0:0\n"

    sensor_count = 0
    while sensor_count < action_sma_pre:
        if emergency_stop_condition.is_set():
            print("Emergency stop detected during pre-heating mode!")
            # Actually switch every actuator off before leaving the mode.
            try:
                ser.write(stop_command.encode())
                print(f"Sent command to Arduino: {stop_command.strip()}")
            except Exception as e:
                print(f"Error sending command to Arduino: {e}")
            break

        try:
            frame = dev.read_frame()
        except Exception as e:
            print("ERROR:", e)
            dev.clear_error(frame_rate)
            continue
        if frame is None:
            continue

        f = dev.do_compensation(frame, add_ambient_temperature=True)
        f = dev.do_handle_bad_pixels(f)
        amb = f.pop()
        # save_to_csv(frame_id, f, amb, "Preheat")
        f_matrix = np.array(f).reshape(24, 32)
        mask, temp_matrix = generate_thermal_image(temp_matrix=f_matrix, amb=amb, scale_factor=3, num_target=num_target)
        cal_result = get_temp_result(num_target=num_target, mask=mask, temp_matrix=temp_matrix, amb=amb)
        print(cal_result)

        try:
            ser.write(command.encode())  # Send command to Arduino
            print(f"Sent command to Arduino: {command.strip()}")
        except Exception as e:
            print(f"Error sending command to Arduino: {e}")
        sensor_count += 1
        frame_id += 1
    print("Pre-heat Mode completed. Returning to standby mode.")

# Mode 1 & 2: small and large steps
def execute_mode1_and_2(user_input, direction, steps, action_sma,action_sma_pre, action_sma_interval, max_vol, frame_rate, dev, ser, smp_con, emergency_stop_condition):
    """Run Mode 1 (small steps) or Mode 2 (large steps).

    Ports 5-8 are driven by the PID controller toward a target of 110. Once a
    measured value is within 15 of that target, the SMA sequence starts and
    advances one tick per processed frame:

    * the first ``action_sma_pre`` ticks drive all three SMA ports (1-3);
    * then ``steps`` cycles follow. Each cycle has ``action_sma`` ticks that
      drive the ports selected by ``direction`` (1-7), then
      ``action_sma_interval`` ticks with all SMA ports off.

    The mode ends when the sequence is complete, or as soon as
    *emergency_stop_condition* is set. In the emergency case an all-zero
    command is sent to the Arduino first.

    Args:
        user_input: '1' for small steps, anything else is reported as large.
        direction: 1-7, which SMA ports act during a step.
        steps: Number of step cycles.
        action_sma: Ticks (frames) of SMA actuation per step.
        action_sma_pre: Ticks of the all-port preparation phase.
        action_sma_interval: Ticks of rest between steps.
        max_vol: Maximum drive voltage; PID output is sent as a percentage of it.
        frame_rate: Passed to ``dev.clear_error`` after a failed frame read.
        dev, ser, smp_con: Camera, Arduino serial link and shared PID controller.
        emergency_stop_condition: Event that aborts the mode when set.
    """
    global frame_id
    condition = "Small" if user_input == '1' else "Large"
    label = f"{condition}_stepping"  # used by the optional save_to_csv call below
    print(f"Executing Mode {user_input}: {condition} steps in direction {direction} for {steps} steps.")
    target_temp = 110
    num_target = 2
    pwm_on = 99
    stop_command = "1:0:0,2:0:0,3:0:0,4:0:0,5:0:0,6:0:0,7:0:0, 8:0:0\n"
    # Which of SMA ports 1-3 are driven during a step, per direction.
    direction_ports = {
        1: (1, 0, 0),
        2: (0, 1, 0),
        3: (0, 0, 1),
        4: (1, 1, 0),
        5: (1, 0, 1),
        6: (0, 1, 1),
        7: (1, 1, 1),
    }
    active_ports = direction_ports.get(direction, (0, 0, 0))

    smp_con.set_target(target_temp)
    start_sma = False
    sma_action_count = 0
    length_circle = action_sma + action_sma_interval
    total_ticks = steps * length_circle + action_sma_pre
    while sma_action_count < total_ticks:
        if emergency_stop_condition.is_set():
            print(f"Emergency stop detected during Mode {user_input}!")
            # Actually switch every actuator off before leaving the mode.
            try:
                ser.write(stop_command.encode())
                print(f"Sent command to Arduino: {stop_command.strip()}")
            except Exception as e:
                print(f"Error sending command to Arduino: {e}")
            break

        try:
            frame = dev.read_frame()
        except Exception as e:
            print("ERROR:", e)
            dev.clear_error(frame_rate)
            continue
        if frame is None:
            continue

        f = dev.do_compensation(frame, add_ambient_temperature=True)
        f = dev.do_handle_bad_pixels(f)
        amb = f.pop()
        # save_to_csv(frame_id, f, amb, label)
        f_matrix = np.array(f).reshape(24, 32)
        mask, temp_matrix = generate_thermal_image(temp_matrix=f_matrix, amb=amb, scale_factor=3, num_target=num_target)
        cal_result = get_temp_result(num_target=num_target, mask=mask, temp_matrix=temp_matrix, amb=amb)
        print(cal_result)
        if abs(cal_result - target_temp) <= 15:
            start_sma = True

        smp_con.sampling_update(cal_result)
        action_smp = smp_con.pid_controller()
        pwm_value_smp = round(action_smp / max_vol * 100, 2)

        pwm_port1 = pwm_port2 = pwm_port3 = 0
        if start_sma:
            if sma_action_count < action_sma_pre:
                # Preparation phase: drive every SMA port.
                pwm_port1 = pwm_port2 = pwm_port3 = pwm_on
            elif (sma_action_count - action_sma_pre) % length_circle < action_sma:
                # Action part of a step: only the ports for this direction.
                pwm_port1, pwm_port2, pwm_port3 = (pwm_on * on for on in active_ports)
            # Otherwise: rest interval between steps, SMA ports stay off.
            sma_action_count += 1

        command = (
            f"1:{pwm_port1}:0,2:{pwm_port2}:0,3:{pwm_port3}:0,"
            f"4:0:0,5:{pwm_value_smp}:0,6:{pwm_value_smp}:0,7:{pwm_value_smp}:0, 8:{pwm_value_smp}:0\n"
        )

        try:
            ser.write(command.encode())  # Send command to Arduino
            print(f"Sent command to Arduino: {command.strip()}")
        except Exception as e:
            print(f"Error sending command to Arduino: {e}")

        frame_id += 1

    print(f"Mode {user_input} completed. Returning to standby mode.")

# Mode 3: forming (cooling)
def execute_mode3(cooling_vol, max_vol, frame_rate, dev, ser, emergency_stop_condition):
    """Run Mode 3 (Forming): drive the cooling port until the target is below 40.

    On every iteration the cooling command is sent: port 4 at
    ``cooling_vol / max_vol`` as a percentage, rounded to 2 decimals like the
    other PWM values, and every other port at 0. A frame is then read.

    The mode ends when a measured value drops below 40, or as soon as
    *emergency_stop_condition* is set. In the emergency case an all-zero
    command is sent to the Arduino first.

    Args:
        cooling_vol: Voltage to apply to the cooling port.
        max_vol: Maximum drive voltage used to convert volts to a percentage.
        frame_rate: Passed to ``dev.clear_error`` after a failed frame read.
        dev, ser: Camera and Arduino serial link.
        emergency_stop_condition: Event that aborts the mode when set.
    """
    global frame_id
    print("Executing Mode 3: Forming started.")
    label = "Forming"  # used by the optional save_to_csv call below
    num_target = 1
    pwm_value_cooling = round(cooling_vol / max_vol * 100, 2)
    command = f"1:0:0,2:0:0,3:0:0,4:{pwm_value_cooling}:10,5:0:0,6:0:0,7:0:0, 8:0:0\n"
    stop_command = "1:0:0,2:0:0,3:0:0,4:0:0,5:0:0,6:0:0,7:0:0, 8:0:0\n"

    cooling = True
    while cooling:
        if emergency_stop_condition.is_set():
            print("Emergency stop detected during Mode 3!")
            # Actually switch every actuator off before leaving the mode.
            try:
                ser.write(stop_command.encode())
                print(f"Sent command to Arduino: {stop_command.strip()}")
            except Exception as e:
                print(f"Error sending command to Arduino: {e}")
            break

        try:
            ser.write(command.encode())  # Send command to Arduino
            print(f"Sent command to Arduino: {command.strip()}")
        except Exception as e:
            print(f"Error sending command to Arduino: {e}")

        try:
            frame = dev.read_frame()
        except Exception as e:
            print("ERROR:", e)
            dev.clear_error(frame_rate)
            continue
        if frame is None:
            continue

        f = dev.do_compensation(frame, add_ambient_temperature=True)
        f = dev.do_handle_bad_pixels(f)
        amb = f.pop()
        # save_to_csv(frame_id, f, amb, label)
        f_matrix = np.array(f).reshape(24, 32)
        mask, temp_matrix = generate_thermal_image(temp_matrix=f_matrix, amb=amb, scale_factor=3, num_target=num_target)
        cal_result = get_temp_result(num_target=num_target, mask=mask, temp_matrix=temp_matrix, amb=amb)
        print(cal_result)
        if cal_result < 40:
            cooling = False

        frame_id += 1
    print("Mode 3 completed. Returning to standby mode.")

# Mode 4: recovery
def execute_mode4(action_sma, max_vol, frame_rate, dev, ser, smp_con, emergency_stop_condition):
    """Run Mode 4 (Recovery).

    Ports 5-8 are driven by the PID controller toward a target of 110. Once a
    measured value is within 15 of that target, all three SMA ports (1-3) are
    driven at 99 for ``action_sma`` processed frames.

    After the SMA phase has started, the operator is asked every
    ``5 * frame_rate`` processed frames whether the shape has recovered. The
    mode ends on a "y" answer, or as soon as *emergency_stop_condition* is
    set. In the emergency case an all-zero command is sent to the Arduino
    first.

    Args:
        action_sma: Number of frames to drive the SMA ports.
        max_vol: Maximum drive voltage; PID output is sent as a percentage of it.
        frame_rate: Used for the prompt interval and ``dev.clear_error``.
        dev, ser, smp_con: Camera, Arduino serial link and shared PID controller.
        emergency_stop_condition: Event that aborts the mode when set.
    """
    global frame_id
    print("Executing Mode 4: Recovery started.")
    label = "Recovery"  # used by the optional save_to_csv call below
    target_temp = 110
    num_target = 2
    input_recovery = 5  # seconds between recovery prompts (counted in frames)
    stop_command = "1:0:0,2:0:0,3:0:0,4:0:0,5:0:0,6:0:0,7:0:0, 8:0:0\n"
    smp_con.set_target(target_temp)
    recovery = False
    start_sma = False
    sma_action_count = 0
    sensor_count = 0
    while not recovery:
        if emergency_stop_condition.is_set():
            print("Emergency stop detected during Mode 4!")
            # Actually switch every actuator off before leaving the mode.
            try:
                ser.write(stop_command.encode())
                print(f"Sent command to Arduino: {stop_command.strip()}")
            except Exception as e:
                print(f"Error sending command to Arduino: {e}")
            break

        try:
            frame = dev.read_frame()
        except Exception as e:
            print("ERROR:", e)
            dev.clear_error(frame_rate)
            continue
        if frame is None:
            continue

        f = dev.do_compensation(frame, add_ambient_temperature=True)
        f = dev.do_handle_bad_pixels(f)
        amb = f.pop()
        # save_to_csv(frame_id, f, amb, label)
        f_matrix = np.array(f).reshape(24, 32)
        mask, temp_matrix = generate_thermal_image(temp_matrix=f_matrix, amb=amb, scale_factor=3, num_target=num_target)
        cal_result = get_temp_result(num_target=num_target, mask=mask, temp_matrix=temp_matrix, amb=amb)
        print(cal_result)
        if abs(cal_result - target_temp) <= 15:
            start_sma = True
        smp_con.sampling_update(cal_result)
        action_smp = smp_con.pid_controller()
        pwm_value_smp = round(action_smp / max_vol * 100, 2)
        pwm_value_sma = 0
        if start_sma and sma_action_count < action_sma:
            pwm_value_sma = 99
            sma_action_count += 1
        command = (
            f"1:{pwm_value_sma}:0,2:{pwm_value_sma}:0,3:{pwm_value_sma}:0,"
            f"4:0:0,5:{pwm_value_smp}:0,6:{pwm_value_smp}:0,7:{pwm_value_smp}:0, 8:{pwm_value_smp}:0\n"
        )

        try:
            ser.write(command.encode())  # Send command to Arduino
            print(f"Sent command to Arduino: {command.strip()}")
        except Exception as e:
            print(f"Error sending command to Arduino: {e}")
        sensor_count += 1
        if sensor_count >= input_recovery * frame_rate:
            sensor_count = 0
            if start_sma:
                # NOTE(review): input() blocks this loop, so no new PID output
                # is sent while the prompt is open. Presumably the Arduino keeps
                # applying the last command - confirm that is acceptable.
                # The caller's emergency-stop listener thread may also be
                # reading stdin (in cbreak mode) at the same time; verify the
                # answer reliably reaches this prompt.
                recovery_condition = input("Whether the shape has recovered? (y/n): ")
                # Tolerate surrounding whitespace and an upper-case answer.
                if recovery_condition.strip().lower() == "y":
                    recovery = True

        frame_id += 1
        print(f"The ID is :{sma_action_count}")
    print("Mode 4 completed. Returning to standby mode.")
    
def _run_mode(standby_condition, mode_fn):
    """Run one operating mode with standby paused and an emergency-stop listener active.

    ``mode_fn`` is called with a fresh ``threading.Event``; the keyboard
    listener (``emergency_stop``) sets it when 's' is pressed. However the mode
    ends - normally, by emergency stop, or with an exception - the listener is
    shut down (which restores the terminal) and standby is resumed.
    """
    emergency_stop_condition = threading.Event()
    listener = threading.Thread(target=emergency_stop, args=(emergency_stop_condition,))
    listener.daemon = True
    standby_condition.clear()
    listener.start()
    try:
        mode_fn(emergency_stop_condition)
    finally:
        # emergency_stop() only returns once its event is set. Without this,
        # every finished mode would leave a listener behind that keeps the
        # terminal in cbreak mode and competes with input() for keystrokes.
        emergency_stop_condition.set()
        listener.join()
        standby_condition.set()


# Main loop: wait for user input and dispatch the selected mode.
def main_loop(dev, ser, frame_rate = 2):
    """Start the standby thread, then run operator-selected modes until 'q'.

    Modes:
        '1': small steps
        '2': large steps
        '3': forming (cooling)
        '4': recovery
        'p': SMA pre-heating
        'q': quit

    While a mode runs, the background standby thread is paused, and pressing
    's' triggers an emergency stop.

    On exit (via 'q' or any exception, e.g. Ctrl+C), standby is paused, the
    Arduino is sent ``q`` to stop all actuators, and the serial port is
    closed.
    """
    # SMA phase lengths are counted in processed frames (times assume 2 Hz).
    action_sma_pre = 20                         # 10s for preparation
    action_sma_small = 4                        # 2.0s for small step
    action_sma_large = 12                       # 6.0s for large step
    action_sma_interval = 4                     # 2.0s interval between steps
    action_sma_recovery = 10
    cooling_vol = 12
    max_vol = 18
    kp = 0.5
    kd = 0.4
    ki = 0
    smp_con = SMP_Controller(kp=kp, ki=ki, kd=kd, max_vol=max_vol)

    # Standby (SMP pre-heating) runs in a daemon thread while this event is set.
    standby_condition = threading.Event()
    monitor_thread = threading.Thread(target=standby_mode, args=(max_vol, frame_rate, dev, ser, smp_con, standby_condition))
    monitor_thread.daemon = True
    monitor_thread.start()
    standby_condition.set()

    try:
        while True:
            user_input = input("Enter mode (1-4, q to quit): ")

            if user_input == 'q':
                print("Exiting program")
                break

            # Mode 1: Small steps
            elif user_input == '1':
                try:
                    direction = int(input("Enter direction (1-7): "))
                    if direction not in range(1, 8):
                        raise ValueError("Invalid direction")
                    steps = int(input("Enter steps (1-10): "))
                    if steps not in range(1, 11):
                        raise ValueError("Invalid steps")
                except ValueError as e:
                    print(f"Input error: {e}")
                    continue
                print("Mode Small Stepping started")
                _run_mode(standby_condition, lambda stop: execute_mode1_and_2(
                    user_input, direction, steps, action_sma_small, action_sma_pre,
                    action_sma_interval, max_vol, frame_rate, dev, ser, smp_con, stop))

            # Mode 2: Large steps
            elif user_input == '2':
                try:
                    direction = int(input("Enter direction (1-7): "))
                    if direction not in range(1, 8):
                        raise ValueError("Invalid direction")
                    steps = int(input("Enter steps (1-5): "))
                    if steps not in range(1, 6):
                        raise ValueError("Invalid steps")
                except ValueError as e:
                    print(f"Input error: {e}")
                    continue
                print("Mode 2 started")
                _run_mode(standby_condition, lambda stop: execute_mode1_and_2(
                    user_input, direction, steps, action_sma_large, action_sma_pre,
                    action_sma_interval, max_vol, frame_rate, dev, ser, smp_con, stop))

            # Mode 3: Forming
            elif user_input == '3':
                print("Mode 3 started")
                _run_mode(standby_condition, lambda stop: execute_mode3(
                    cooling_vol, max_vol, frame_rate, dev, ser, stop))

            # Mode 4: Recovery
            elif user_input == '4':
                print("Mode 4 started")
                _run_mode(standby_condition, lambda stop: execute_mode4(
                    action_sma_recovery, max_vol, frame_rate, dev, ser, smp_con, stop))

            # Mode p: SMA pre-heating
            elif user_input == 'p':
                print("Pre-heating mode started")
                try:
                    direction = int(input("Enter direction (1-7): "))
                    if direction not in range(1, 8):
                        raise ValueError("Invalid direction")
                except ValueError as e:
                    print(f"Input error: {e}")
                    continue
                _run_mode(standby_condition, lambda stop: preheat_mode(
                    direction, action_sma_pre, frame_rate, dev, ser, stop))

            else:
                print("Invalid mode, please enter 1-4 or q to quit.")
    finally:
        # Pause standby first and let an iteration that is already in flight
        # finish. Otherwise standby could re-enable the SMP heaters right after
        # the Arduino received 'q'.
        # NOTE(review): two frame periods is a heuristic, not a handshake with
        # standby_mode - confirm it covers one full standby iteration.
        standby_condition.clear()
        time.sleep(2.0 / frame_rate if frame_rate else 1.0)
        try:
            ser.write(b'q\n')  # Notify Arduino to stop all actuators
        except Exception as e:
            print(f"Error sending command to Arduino: {e}")
        ser.close()  # Close serial connection
    print("Program ended")


if __name__ == "__main__":
    # To run without hardware, the serial port can be replaced by a mock, e.g.:
    #   serial.Serial = MagicMock()
    #   serial_mock = serial.Serial()
    #   serial_mock.write = MagicMock(return_value=None)
    #   serial_mock.read = MagicMock(return_value=b'OK')

    # Frame counter shared by all mode functions through ``global frame_id``.
    frame_id = 0

    # Open the Arduino serial link and the MLX90640 camera (exits on failure).
    connections = initialize_all_connections()
    arduino = connections['serial']
    camera = connections['mlx']

    wake_arduino(arduino)

    # Start standby and hand control to the interactive mode menu.
    main_loop(camera, arduino)
