"""
Main script for Sea state estimation using an Adaptive Kalman Filter (AKF) 
based on ship motion responses.

Author: Ryane Bourkaib  
Date: 2025-05-26  
Description:
    Implements an AKF to estimate directional sea state parameters
    (Hs, Tp, mean direction) from simulated ship motion signals (e.g., heave, pitch).
    The AKF simultaneously estimates complex wave elevation components and the corresponding 
    wave spectrum in the time domain.

References:
    - Kim, H. S., Park, J. Y., Jin, C., Kim, M. H., & Lee, D. Y. (2023).
      "Real-time inverse estimation of multi-directional random waves from
      vessel-motion sensors using Kalman filter". Ocean Engineering, 280.
      https://doi.org/10.1016/j.oceaneng.2023.114501
   

"""

# ─────────────────────────────────────────────────────────────
# IMPORTS
# ─────────────────────────────────────────────────────────────
import numpy as np
import matplotlib.pyplot as plt
import os

from src.tools import general_func
import src.analysis.adaptiveKalmanFilter as akf_func
from results.plots import ploting
# ─────────────────────────────────────────────────────────────
# MAIN EXECUTION BLOCK
# ─────────────────────────────────────────────────────────────
def _load_sim_dict(file_path):
    """Load a ``.npy`` file that stores a dictionary of simulation results.

    Parameters
    ----------
    file_path : str
        Path to the ``.npy`` file.

    Returns
    -------
    dict or None
        The stored dictionary, or ``None`` when the file is missing,
        cannot be read, or does not hold a 0-d array wrapping a dict.
    """
    if not os.path.exists(file_path):
        print(f"File {file_path} does not exist. Please check the path.")
        return None
    try:
        # allow_pickle is needed because the dict is stored as an object
        # array; only load trusted, locally generated simulation files.
        data = np.load(file_path, allow_pickle=True)
        is_dict_payload = (
            isinstance(data, np.ndarray)
            and data.shape == ()
            and isinstance(data.item(), dict)
        )
    except Exception as e:  # script boundary: report and give up cleanly
        print(f"An error occurred: {e}")
        return None
    if not is_dict_payload:
        print("Motion data is not a numpy array containing a dictionary.")
        return None
    data_dic = data.item()
    print(data_dic.keys())
    return data_dic


def _first_close_index(values, target, **isclose_kwargs):
    """Return the first index where ``values`` is close to ``target``.

    ``isclose_kwargs`` (e.g. ``atol``/``rtol``) are forwarded to
    ``np.isclose``. Raises ``ValueError`` when no element matches, instead of
    the opaque ``IndexError`` an empty ``np.where`` result would produce.
    """
    matches = np.flatnonzero(np.isclose(values, target, **isclose_kwargs))
    if matches.size == 0:
        raise ValueError(f"No grid value close to {target!r} ({isclose_kwargs}).")
    return int(matches[0])


def _wrap_rao_periodic(TF_beta, TF_amps, TF_phases, TF_enc_om):
    """Wrap the RAO direction grid and extend it periodically when needed.

    The direction grid is first passed through
    ``general_func.wrap_angle_range``. If the first and last directions do
    not coincide modulo 360 deg, the grid, the per-DOF amplitudes/phases and
    the encounter-frequency grid are extended with
    ``general_func.extend_with_wrap`` (presumably appending the wrapped first
    direction so the grid closes the circle — see that helper to confirm).

    Returns
    -------
    tuple
        ``(TF_beta, TF_amps, TF_phases, TF_enc_om)`` after wrapping.
    """
    TF_beta = general_func.wrap_angle_range(TF_beta)
    if TF_beta[0] % 360 == TF_beta[-1] % 360:
        return TF_beta, TF_amps, TF_phases, TF_enc_om

    TF_beta = general_func.extend_with_wrap(TF_beta, axis=0, offset=360)
    TF_amps = np.array(
        [general_func.extend_with_wrap(amp, axis=1) for amp in TF_amps]
    )
    TF_phases = np.array(
        [general_func.extend_with_wrap(phase, axis=1) for phase in TF_phases]
    )
    # The encounter-frequency grid is shared by all DOFs, so wrap it once.
    TF_enc_om = np.array(general_func.extend_with_wrap(TF_enc_om, axis=1))
    return TF_beta, TF_amps, TF_phases, TF_enc_om


def main(*, sim_num=1, type_sim="doubleSum", seed=0):
    """Run sea-state estimation with the Adaptive Kalman Filter on one simulation.

    The pipeline has six steps:

    1. Load the simulated ship-motion file.
    2. Select one time series.
    3. Build a placeholder (random) RAO motion model.
    4. Interpolate and crop it to the filter grid.
    5. Run the AKF.
    6. Plot the estimated against the simulated motions and sea-state
       parameters.

    Parameters
    ----------
    sim_num : int
        Simulation number embedded in the data file name.
    type_sim : str
        Simulation method sub-folder ("equalEnergy" or "doubleSum").
        NOTE(review): extraction below uses the double-sum extractor, so
        other values may not be supported — confirm.
    seed : int
        Index of the simulated time series to process. Presumably valid
        from 0 to the number of simulated series minus one (see ``Num_Ts``).

    Returns
    -------
    None
        The function returns early, without plotting, when the data cannot
        be loaded.
    """
    print("Starting Sea State Estimation using Adaptive Kalman Filter...")

    file_path = os.path.join(
        "results",
        "data",
        f"shipMotion_{type_sim}",
        f"sea_Hs_3.0_Tp_9.0_beta_180.0_U_9_sim_{sim_num}.npy",
    )
    data_dic = _load_sim_dict(file_path)
    if data_dic is None:
        print("Simulation data could not be loaded. Skipping extraction.")
        return

    # Extract the ship motion data (single call; the result is unpacked here).
    (
        # Sea State
        S1D, omega_wave, omega, S2D, betas_wave, om_enc,
        Hs_2d, Tp_2d, theta0_2d, beta0_2d, theta0,
        # Ship Info
        U, psi,
        # Time and Sampling
        time, fs, Num_Ts,
        # Wave Components
        wave_amplitude, epsilon, epsilon_seed,
        # Clean Motions
        wavet, surget, swayt, heavet, rollt, pitcht, yawt,
        # Noisy Motions
        surget_noisy, swayt_noisy, heavet_noisy, rollt_noisy,
        pitcht_noisy, yawt_noisy,
        # Added noise
        snr,
    ) = general_func.extract_sim_ship_data_doubelSum(data_dic)

    print(time.shape)
    print(wavet.shape)
    print(surget.shape)
    time = time[seed, 0, :]  # time[seed, 0, :] is the double-sum time axis
    wavet = wavet[seed, :]

    # Selected ship motion responses and their (clean / noisy) signals.
    selected_ship_resp = ["surge", "sway", "heave", "roll", "pitch"]
    dof_number = len(selected_ship_resp)
    simulated_motion_data = np.array(
        [surget, swayt, heavet, rollt, pitcht]
    )[:, seed, :]
    noisy_simulated_motion_data = np.array(
        [surget_noisy, swayt_noisy, heavet_noisy, rollt_noisy, pitcht_noisy]
    )[:, seed, :]

    # Placeholder RAO motion model (TF amplitude and phase [(m,deg)/m and deg]).
    # NOTE: the real TFs are confidential; random values are used instead.
    omega_TF = np.arange(0.1, 2.0, 0.01)  # TF frequencies [rad/s]
    TF_beta = np.arange(0, 360, 10)  # relative wave directions [deg]
    DOF = 6  # number of degrees of freedom
    TF_amps = np.random.rand(DOF, len(omega_TF), len(TF_beta))  # in [0, 1)
    TF_phases = np.random.rand(DOF, len(omega_TF), len(TF_beta)) * 360  # [deg]
    # NOTE(review): overrides the U read from the file; the placeholder RAO
    # assumes 9 m/s, matching the "U_9" data file name.
    U = 9  # ship forward speed [m/s]
    g = 9.81  # gravitational acceleration [m/s^2]
    # Encounter frequency: w_e = w - w^2 * U / g * cos(beta), on a
    # (frequency x direction) grid shared by every DOF.
    TF_enc_om = omega_TF.reshape(-1, 1) - omega_TF.reshape(-1, 1) ** 2 * U / g * np.cos(
        np.deg2rad(TF_beta).reshape(1, -1)
    )
    print("🚢 RAO Motion Model Parameters shape:")
    print(f"TF frequencies: {omega_TF.shape}")
    print(f"TF relative wave directions: {TF_beta.shape}")
    print(f"TF amplitudes: {TF_amps.shape}")
    print(f"TF phases: {TF_phases.shape}")
    print(f"TF encounter frequencies: {TF_enc_om.shape}")

    # Ensure periodic wrapping of the angular data.
    TF_beta, TF_amps, TF_phases, TF_enc_om = _wrap_rao_periodic(
        TF_beta, TF_amps, TF_phases, TF_enc_om
    )

    # Convert angles to radians.
    TF_beta = np.deg2rad(TF_beta)
    TF_phases = np.deg2rad(TF_phases)  # phases of all DOFs
    TF_amps[3:] = np.deg2rad(TF_amps[3:])  # rotational DOF amplitudes
    print("🚢 RAO Motion Model size:")
    print(f"Number of DOFs: {TF_amps.shape[0]}")
    print(f"Number of TF frequencies: {len(omega_TF)}")
    print(f"Number of TF encounter frequencies: {len(TF_enc_om)}")
    print(f"Number of TF relative wave direction: {len(TF_beta)}")

    # Interpolate the TFs onto the Kalman-filter frequency grid.
    # NOTE: a smaller domega adds unknowns and degrades filter performance.
    domega_kf = 0.04  # frequency step for the Kalman filter [rad/s]
    (
        TF_omega_inter,
        TF_amp_inter,
        TF_phase_inter,
        TF_om_enc_inter,
    ) = general_func.rao_interp_new_domega(
        TF_amps, TF_phases, TF_enc_om, omega_TF, domega_kf
    )

    # Direction/frequency window for the filter. Tolerant comparisons avoid
    # relying on exact float equality after the deg -> rad conversion.
    beta_index_first = _first_close_index(TF_beta, np.deg2rad(90))
    beta_index_last = _first_close_index(TF_beta, np.deg2rad(280))  # exclusive
    omega_indx_last = _first_close_index(TF_omega_inter, 2, atol=1e-1)  # exclusive
    beta_spacing = 3  # every 3rd direction -> 30 deg spacing on a 10 deg grid
    beta_slice = slice(beta_index_first, beta_index_last, beta_spacing)
    TF_beta = TF_beta[beta_slice]

    # Crop the interpolated RAO model to the selected window.
    TF_omega_inter = TF_omega_inter[:omega_indx_last]
    TF_amp_inter = TF_amp_inter[:, :omega_indx_last, beta_slice]
    TF_phase_inter = TF_phase_inter[:, :omega_indx_last, beta_slice]
    TF_om_enc_inter = TF_om_enc_inter[:omega_indx_last, beta_slice]

    # Keep only the RAO data of the selected ship responses.
    TF_amp_inter, TF_phase_inter = general_func.rao_data_slected_dof(
        selected_ship_resp, TF_amp_inter, TF_phase_inter
    )

    print("🚢 RAO Motion Model Interpolation size:")
    print(f"Number of DOFs: {TF_amp_inter.shape[0]}")
    print(f"Number of TF frequencies: {TF_omega_inter.shape[0]}")
    print(f"Number of TF encounter frequencies: {TF_om_enc_inter.shape[0]}")
    print(f"Number of TF relative wave direction: {TF_beta.shape[0]}")

    # Output matrix H over time (takes the time vector itself, not its length).
    output_matrix_over_time = akf_func.output_matrix_KF(
        TF_om_enc_inter,
        1 / fs,  # time step [s]
        TF_amp_inter,
        TF_phase_inter,
        selected_ship_resp,
        time,
    )
    print("Output matrix H over time shape:", output_matrix_over_time.shape)

    # Kalman filter initial state, transition, and noise covariances.
    Xk, Phi, R, Pk, Q = akf_func.init_kf_param(
        selected_ship_resp,
        len(TF_omega_inter),
        len(TF_beta),
        TF_omega_inter,
        low_freq_threshold=0,
        high_freq_threshold=1.15,
        Pk_low=1,
        Q_low=0.01,
        cov_noise_sensor_tran_motion=0.01,
        cov_noise_sensor_rot_motion=0.001,
    )

    # Sea-state estimation with the Adaptive Kalman Filter.
    (
        predicted_motion_model,
        est_S2D_mean,
        est_Hs,
        est_Tp,
        est_beta,
        est_S1D,
        est_theta,
        est_S2Dt,
        est_S2D_enct,
        est_wave_elv,
        std_wave_elv,
        H_s_over_time,
        T_p_over_time,
        beta_over_time,
        S1D_over_time,
        theta_m_over_time,
    ) = akf_func.sse_using_Akf(
        time,
        fs,
        Pk,
        Q,
        R,
        Xk,
        Phi,
        output_matrix_over_time,
        dof_number,
        noisy_simulated_motion_data,
        TF_omega_inter,
        TF_beta,
        TF_om_enc_inter,
        psi=np.deg2rad(psi),
    )

    # Evaluation: simulated vs. estimated motions for each selected DOF.
    for i in range(dof_number):
        ploting.plot_sim_and_est_motion_data(
            time,
            simulated_motion_data[i, :],
            predicted_motion_model[i, :],
            dof_index=i,
        )

    # Estimated directional wave spectrum.
    ploting.plt_polar_S2D(
        TF_omega_inter,
        TF_beta,
        est_S2D_mean,
        ship_heading_rad=np.deg2rad(psi),
    )
    ploting.plot_estimated_sea_state_over_time(
        time,
        estimated_ss=H_s_over_time,
        chosen_ss="Significant wave height [m]",
        simulated_ss=Hs_2d,
        case="multidirectional case",
    )
    ploting.plot_estimated_sea_state_over_time(
        time,
        estimated_ss=T_p_over_time,
        chosen_ss="Peak period [s]",
        simulated_ss=Tp_2d,
        case="multidirectional case",
    )
    ploting.plot_estimated_sea_state_over_time(
        time,
        estimated_ss=np.rad2deg(beta_over_time),
        chosen_ss="Relative Mean direction [deg]",
        simulated_ss=np.rad2deg(beta0_2d),
        case="multidirectional case",
    )
    ploting.plot_simu_and_est_S1D(
        TF_omega_inter,
        est_S1D,
        omega_wave,
        S1D,
    )

    # Correlation between estimated and simulated wave elevation.
    wave_corr = np.corrcoef(wavet, est_wave_elv)
    print("wave_corr", wave_corr)

    # Estimated vs. simulated wave elevation over the last 1000 samples.
    n_tail = 1000
    plt.plot(time[-n_tail:], est_wave_elv[-n_tail:], linewidth=2, label="Estimated")
    plt.plot(time[-n_tail:], wavet[-n_tail:], linewidth=2, label="Simulated")
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.xlabel("Time [s]", fontsize=16)
    plt.ylabel("Amplitude [m]", fontsize=16)
    plt.legend(fontsize=14)
    plt.grid(True)
    plt.show()

if __name__ == "__main__":
    # Execute the estimation demo only when run as a script, never on import.
    main()
    