"""
Main file for Adaptive Kalman Filter (AKF) functions
Author: Ryane Bourkaib
Date: 2025-05-21
Description: This module contains functions to compute the
Adaptive Kalman Filter (AKF), including the output matrix H, the
initialization of the filter parameters, and other related functionalities.

References:
    -R. Bourkaib, M. Kok, H.C. Seyffert,
    Unidirectional and multi-directional wave estimation from ship motions 
    using an Adaptive Kalman Filter with the inclusion of varying forward speed,
    Probabilistic Engineering Mechanics,
    2025,
    103773,
    ISSN 0266-8920,
    https://doi.org/10.1016/j.probengmech.2025.103773.

    -Kim, H. S., Park, J. Y., Jin, C., Kim, M. H., & Lee, D. Y. (2023).
      "Real-time inverse estimation of multi-directional random waves from
      vessel-motion sensors using Kalman filter". Ocean Engineering, 280.
      https://doi.org/10.1016/j.oceaneng.2023.114501
    
"""
# ─────────────────────────────────────────────────────────────
# IMPORTS
# ─────────────────────────────────────────────────────────────
import os
import math
import numpy as np
from scipy.interpolate import interp1d
from scipy.interpolate import RegularGridInterpolator
import matplotlib.pyplot as plt


from src.analysis import seaState_func
# ─────────────────────────────────────────────────────────────
# FUNCTION: Output matrix H for Adaptive Kalman Filter (AKF)    
# ─────────────────────────────────────────────────────────────

def output_matrix_KF(
    TF_om_enc,
    dtime,
    TF_amp,
    TF_phase,
    chosen_motions,
    time_motion,
):
    """Calculates the output matrix H of the AKF over time
      for a ship motion model.

    For time step ``k`` (time ``t_k = k * dtime``), DOF ``d`` and wave
    component (frequency ``i``, direction ``j``) the two output-matrix
    entries are::

        A[d, i, j] * cos(om_e[i, j] * t_k + phi[d, i, j])   (real part)
        A[d, i, j] * sin(om_e[i, j] * t_k + phi[d, i, j])   (imaginary part)

    Parameters
    ----------
    TF_om_enc : array_like of shape (Nfreq, Nbeta) or (Nfreq,)
        Encounter wave frequencies [rad/s] for every frequency/direction
        pair. A 1-D vector of length Nfreq is used for all directions.
    dtime : float
        Time step [s]. Sample ``k`` is evaluated at time ``k * dtime``.
    TF_amp : array_like of shape (NDOF, Nfreq, Nbeta)
        Amplitude of the transfer function for each DOF, frequency, and direction.
    TF_phase : array_like of shape (NDOF, Nfreq, Nbeta)
        Phase of the transfer function for each DOF, frequency, and direction [rad].
    chosen_motions : list of str
        Chosen motion names (e.g., ['Surge', 'Sway', 'Heave']). Its length
        must equal NDOF, the first dimension of ``TF_amp``.
    time_motion : array_like of shape (Ntime,)
        Time vector for the motion [s]. Only its length is used; the
        sample times are reconstructed as ``k * dtime``.

    Returns
    -------
    H_tot : ndarray of shape (Ntime, NDOF, 2 * Nfreq * Nbeta)
        Output matrix H over time for the chosen motions. For component
        (i, j), column ``2 * (i * Nbeta + j)`` holds the cosine term and the
        next column holds the sine term.

    Raises
    ------
    ValueError
        If ``len(chosen_motions)`` differs from ``TF_amp.shape[0]``.

    Example
    -------
    >>> H_tot = output_matrix_KF(TF_om_enc, dtime, TF_amp,
    ...                          TF_phase, chosen_motions, time_motion)
    """
    amp = np.asarray(TF_amp, dtype=float)
    phase = np.asarray(TF_phase, dtype=float)
    n_dof_tf, n_freq, n_dir = amp.shape

    dof_number = len(chosen_motions)
    if dof_number != n_dof_tf:
        raise ValueError(
            f"len(chosen_motions)={dof_number} does not match the number of "
            f"DOFs in TF_amp ({n_dof_tf})"
        )

    # Accept either a per-(frequency, direction) grid or a plain frequency
    # vector; in both cases work with a (Nfreq, Nbeta) array.
    om_enc = np.asarray(TF_om_enc, dtype=float)
    if om_enc.ndim == 1:
        om_enc = om_enc[:, np.newaxis]
    om_enc = np.broadcast_to(om_enc, (n_freq, n_dir))

    n_time = len(time_motion)
    n_states = 2 * n_freq * n_dir
    H_tot = np.zeros((n_time, dof_number, n_states))

    for k in range(n_time):
        t = k * dtime
        # (Nfreq, Nbeta) frequencies broadcast against (NDOF, Nfreq, Nbeta) phases.
        arg = om_enc * t + phase
        # Stack real/imag terms on the last axis so the C-order reshape gives
        # column index 2 * (i * Nbeta + j) + {0: cos, 1: sin}.
        H_k = np.stack((amp * np.cos(arg), amp * np.sin(arg)), axis=-1)
        H_tot[k] = H_k.reshape(dof_number, n_states)
    return H_tot

# ─────────────────────────────────────────────────────────────
# FUNCTION: Initialize weighted error covariance matrices Pk and Q
# ─────────────────────────────────────────────────────────────
def init_weighted_covariances(
    Nomega,
    Ndirc,
    omega_KF,
    low_freq_threshold,
    high_freq_threshold,
    Pk_low,
    Q_low,
    Pk_mid=10.0,
    Q_mid=0.1,
):
    """
    Initialize weighted error covariance matrices Pk and
    Q based on frequency thresholds.

    Wave components whose frequency lies outside the band
    [low_freq_threshold, high_freq_threshold] get the weights ``Pk_low`` /
    ``Q_low``; components inside the band get ``Pk_mid`` / ``Q_mid``.
    The same weight is used for the real and imaginary part of a component
    and for every direction of a given frequency.

    Parameters:
    ----------
    - Nomega: int
        Number of frequency bins.
    - Ndirc: int
        Number of direction bins.
    - omega_KF: array_like, shape (Nomega,)
        Angular frequencies for the Kalman filter (only the first Nomega
        entries are used).
    - low_freq_threshold: float
        Frequency below which components are considered 'low'.
    - high_freq_threshold: float
        Frequency above which components are considered 'high'.
    - Pk_low: float
        Weight of low/high frequency components in Pk.
    - Q_low: float
        Weight of low/high frequency components in Q.
    - Pk_mid: float, optional
        Weight of in-band components in Pk (default 10.0).
    - Q_mid: float, optional
        Weight of in-band components in Q (default 0.1).

    Returns:
    -------
    - Pk: ndarray, shape (2*Nomega*Ndirc, 2*Nomega*Ndirc)
        Diagonal weighted error covariance matrix.
    - Q: ndarray, shape (2*Nomega*Ndirc, 2*Nomega*Ndirc)
        Diagonal weighted process noise covariance matrix.

    Notes:
    The state ordering is index 2*(i*Ndirc + j) for the real part and
    2*(i*Ndirc + j) + 1 for the imaginary part of component (i, j).
    """
    freqs = np.asarray(omega_KF, dtype=float).ravel()[:Nomega]
    outside_band = (freqs < low_freq_threshold) | (freqs > high_freq_threshold)

    pk_per_freq = np.where(outside_band, Pk_low, Pk_mid).astype(float)
    q_per_freq = np.where(outside_band, Q_low, Q_mid).astype(float)

    # Expand one weight per frequency to (frequency, direction, re/im); the
    # C-order ravel reproduces the state index 2*(i*Ndirc + j) + {0, 1}.
    layout = (Nomega, Ndirc, 2)
    Pk_weights = np.broadcast_to(pk_per_freq[:, None, None], layout).ravel()
    Q_weights = np.broadcast_to(q_per_freq[:, None, None], layout).ravel()

    # Create diagonal covariance matrices
    Pk = np.diag(Pk_weights)
    Q = np.diag(Q_weights)

    return Pk, Q

# ─────────────────────────────────────────────────────────────
# FUNCTION: Initialize the Kalman Filter parameters xk, R and phi
# ─────────────────────────────────────────────────────────────
def init_kf_param(
    selected_ship_resp,
    Nomega,
    Ndirc,
    omega_KF,
    low_freq_threshold,
    high_freq_threshold,
    Pk_low,
    Q_low,
    cov_noise_sensor_tran_motion,
    cov_noise_sensor_rot_motion,

):
    """
    Initialize the Kalman Filter parameters for the adaptive Kalman filter.

    Parameters:
    ----------
    - selected_ship_resp: list of str
        Selected ship responses (e.g., ['surge', 'sway', 'heave']). Names
        are matched case-insensitively; 'surge', 'sway' and 'heave' are
        treated as translational, every other name as rotational.
    - Nomega: int
        Number of frequency bins.
    - Ndirc: int
        Number of direction bins.
    - omega_KF: ndarray, shape (Nomega,)
        Array of angular frequencies for the Kalman filter.
    - low_freq_threshold: float
        Frequency below which components are considered 'low'.
    - high_freq_threshold: float
        Frequency above which components are considered 'high'.
    - Pk_low: float
        Scaling factor for low/high frequency components in Pk.
    - Q_low: float
        Scaling factor for low/high frequency components in Q.
    - cov_noise_sensor_tran_motion: float
        Noise variance for translational motion sensors.
    - cov_noise_sensor_rot_motion: float
        Noise variance for rotational motion sensors.

    Returns:
    -------
    - Xk: ndarray, shape (2*Nomega*Ndirc,)
        Initial state vector (all zeros).
    - Phi: ndarray, shape (2*Nomega*Ndirc, 2*Nomega*Ndirc)
        State transition matrix (identity, i.e. a random-walk state model).
    - R: ndarray, shape (len(selected_ship_resp), len(selected_ship_resp))
        Diagonal measurement noise covariance matrix.
    - Pk: ndarray, shape (2*Nomega*Ndirc, 2*Nomega*Ndirc)
        Weighted error covariance matrix from init_weighted_covariances.
    - Q: ndarray, shape (2*Nomega*Ndirc, 2*Nomega*Ndirc)
        Weighted process noise covariance matrix from init_weighted_covariances.

    Example:
    >>> Xk, Phi, R, Pk, Q = init_kf_param(
    ...     selected_ship_resp,
    ...     Nomega,
    ...     Ndirc,
    ...     omega_KF,
    ...     low_freq_threshold=0.08,
    ...     high_freq_threshold=1.2,
    ...     Pk_low=1,
    ...     Q_low=0.01,
    ...     cov_noise_sensor_tran_motion=0.01,
    ...     cov_noise_sensor_rot_motion=0.001,
    ... )
    """
    n_states = 2 * Nomega * Ndirc

    # State transition matrix: identity (wave components modelled as a random walk)
    Phi = np.eye(n_states)

    # State vector: real/imaginary parts of the wave components, start at zero
    Xk = np.zeros(n_states)

    # Measurement noise covariance: one sensor noise variance per selected
    # response. Compare names case-insensitively so 'Heave' and 'heave' are
    # both recognised as translational.
    translational = {"surge", "sway", "heave"}
    R_values = np.array(
        [
            cov_noise_sensor_tran_motion
            if str(name).strip().lower() in translational
            else cov_noise_sensor_rot_motion
            for name in selected_ship_resp
        ],
        dtype=float,
    )
    R = np.diag(R_values)

    # Initialize the covariance matrices Pk and Q
    Pk, Q = init_weighted_covariances(
        Nomega,
        Ndirc,
        omega_KF,
        low_freq_threshold,
        high_freq_threshold,
        Pk_low,
        Q_low
    )
    return Xk, Phi, R, Pk, Q


# ─────────────────────────────────────────────────────────────
# FUNCTION: Adaptive Kalman Filter (AKF) implementation
# for complex wave component estimation
# ─────────────────────────────────────────────────────────────
def kalman_filter_algo(
    motion_data, Pk, Q, R, Xk, Phi, H_dof,
    time_steps, dof_number, alpha=0.3
):
    """
    Implements the Kalman filter algorithm for estimating complex wave components
    from ship motion data.

    Each step performs a prediction with ``Phi`` and ``Q``, adapts the
    measurement noise covariance by exponential smoothing::

        R_k = alpha * R_{k-1} + (1 - alpha) * (e_k e_k^T + H_k P_k^- H_k^T)

    where e_k is the pre-fit innovation (outer product), and then applies
    the standard Kalman update.

    Parameters:
    ----------
    - motion_data: ndarray, shape (NDOF, Ntime)
        Ship motion data, one row per degree of freedom (DOF), one column
        per time step (at least ``time_steps`` columns).
    - Pk: ndarray, shape (2*Nomega*Ndirc, 2*Nomega*Ndirc)
        Initial error covariance matrix.
    - Q: ndarray, shape (2*Nomega*Ndirc, 2*Nomega*Ndirc)
        Process noise covariance matrix.
    - R: ndarray, shape (NDOF, NDOF)
        Initial measurement noise covariance matrix (adapted over time;
        the caller's array is not modified).
    - Xk: ndarray, shape (2*Nomega*Ndirc,)
        Initial state vector.
    - Phi: ndarray, shape (2*Nomega*Ndirc, 2*Nomega*Ndirc)
        State transition matrix. An identity matrix is detected and skipped
        so the common random-walk case avoids two O(n^3) products per step.
    - H_dof: ndarray, shape (Ntime, NDOF, 2*Nomega*Ndirc)
        Output matrix H for the Kalman filter over time.
    - time_steps: int
        Number of time steps to process.
    - dof_number: int
        Number of degrees of freedom (DOF) in the ship motion data.
    - alpha: float, optional
        Smoothing factor for the measurement noise covariance update (default is 0.3).

    Returns:
    -------
    - Xk: ndarray, shape (2*Nomega*Ndirc,)
        Final estimated state vector.
    - Xk_tot: ndarray, shape (Ntime, 2*Nomega*Ndirc)
        Estimated (posterior) state vector at every time step.
    - innovation: ndarray, shape (NDOF, Ntime)
        Pre-fit innovation z_k - H_k x_k^- for each DOF over time.
    - motion_residus: ndarray, shape (NDOF, Ntime)
        Post-fit residual z_k - H_k x_k^+ for each DOF over time.
    - predicted_motion_model: ndarray, shape (NDOF, Ntime)
        Predicted motion H_k x_k^- before the update.
    - innovation_tot: ndarray, shape (Ntime, NDOF, NDOF)
        Innovation covariance S_k at every time step.

    Example:
    -------
    >>> (Xk, Xk_tot, innovation, motion_residus, predicted_motion_model,
    ...  innovation_tot) = kalman_filter_algo(motion_data, Pk, Q, R, Xk,
    ...                                       Phi, H_dof, time_steps, dof_number)
    """
    motion_data = np.asarray(motion_data, dtype=float)
    R = np.asarray(R, dtype=float)
    Phi = np.asarray(Phi, dtype=float)
    n_states = H_dof.shape[2]

    # Initialisation:
    Xk_tot = np.zeros((time_steps, n_states))
    innovation = np.zeros(motion_data.shape)
    motion_residus = np.zeros(motion_data.shape)
    predicted_motion_model = np.zeros(motion_data.shape)
    innovation_tot = np.zeros((time_steps, dof_number, dof_number))

    # Identity check without allocating an n x n identity: all diagonal
    # entries equal 1 and no other non-zero entry exists.
    phi_is_identity = (
        Phi.shape == (n_states, n_states)
        and bool(np.all(np.diagonal(Phi) == 1.0))
        and np.count_nonzero(Phi) == n_states
    )

    for k in range(time_steps):
        H_k = H_dof[k]

        # 1- Prediction step:
        if phi_is_identity:
            Xk_pre = Xk
            Pk_pre = Pk + Q
        else:
            Xk_pre = Phi @ Xk
            Pk_pre = Phi @ Pk @ Phi.T + Q

        # 2- Update step:
        # Pre-fit prediction and innovation (measurement residual)
        predicted_motion_model[:, k] = H_k @ Xk_pre
        innovation[:, k] = motion_data[:, k] - predicted_motion_model[:, k]

        PHt = Pk_pre @ H_k.T   # reused for R, S and K
        HPHt = H_k @ PHt

        # Adaptive measurement noise covariance (outer product of the
        # innovation vector, not its scalar inner product).
        # NOTE(review): covariance-matching schemes in the literature differ
        # on whether the pre-fit innovation or the post-fit residual is used
        # here; the pre-fit innovation of the original code is kept — confirm
        # against the reference paper.
        R = alpha * R + (1 - alpha) * (
            np.outer(innovation[:, k], innovation[:, k]) + HPHt
        )

        # Innovation covariance
        S = HPHt + R
        innovation_tot[k] = S

        # Kalman gain K = P H^T S^-1, computed by solving S^T K^T = (P H^T)^T
        # instead of forming the inverse explicitly.
        K = np.linalg.solve(S.T, PHt.T).T

        # Update the estimate with the measurements
        Xk = Xk_pre + K @ innovation[:, k]

        # Update the error covariance; K @ (H @ P) keeps the cost O(n^2 m)
        # instead of the O(n^3) of (K @ H) @ P.
        Pk = Pk_pre - K @ (H_k @ Pk_pre)

        # Post-fit residual and stored state
        motion_residus[:, k] = motion_data[:, k] - H_k @ Xk
        Xk_tot[k, :] = Xk

    return (
        Xk,
        Xk_tot,
        innovation,
        motion_residus,
        predicted_motion_model,
        innovation_tot,
    )

# ─────────────────────────────────────────────────────────────
# FUNCTION: Sea state estimation using the Adaptive Kalman Filter
# ─────────────────────────────────────────────────────────────
def sse_using_Akf(
    time,
    fs,
    Pk,
    Q,
    R,
    Xk,
    Phi,
    output_matrix_over_time,
    dof_number,
    simulated_motion_data,
    TF_omega_inter,
    TF_beta,
    TF_om_enc_inter,
    psi,
    alpha=0.3,
    last_states_number=5000,
):
    """
    Estimate the sea state from ship motion data with the Adaptive Kalman Filter.

    Runs kalman_filter_algo on the motion data and converts the estimated
    wave components into directional wave spectra with ``seaState_func``.
    It then prints the sea state parameters of the spectrum averaged over
    the last ``last_states_number`` time steps. Finally it computes the sea
    state parameters at every time step and the estimated wave elevation.

    Parameters:
    ----------
    - time: ndarray, shape (Ntime,)
        Time vector; only its length is used.
    - fs: float
        Sampling frequency [Hz]; 1 / fs is passed as the time step to the
        wave elevation reconstruction.
    - Pk, Q, R, Xk, Phi: ndarray
        Initial error covariance, process noise covariance, measurement
        noise covariance, state vector and state transition matrix
        (as returned by init_kf_param).
    - output_matrix_over_time: ndarray, shape (Ntime, NDOF, 2*Nomega*Ndirc)
        Output matrix H over time (as returned by output_matrix_KF).
    - dof_number: int
        Number of measured degrees of freedom.
    - simulated_motion_data: ndarray, shape (NDOF, Ntime)
        Motion measurements, one row per DOF.
    - TF_omega_inter: array_like
        Wave frequencies used to build the wave spectra.
    - TF_beta: array_like
        Wave directions used to build the wave spectra.
    - TF_om_enc_inter: array_like
        Encounter frequencies used for the encounter spectrum and the
        wave elevation.
    - psi: float
        Forwarded to seaState_func.spec2d_to_sea_state (presumably the ship
        heading — confirm against seaState_func).
    - alpha: float, optional
        Smoothing factor of the adaptive R update (default 0.3).
    - last_states_number: int, optional
        Number of final time steps averaged to obtain the mean spectra
        (default 5000; all steps are used if fewer are available).

    Returns:
    -------
    Tuple, in this order:
    predicted_motion_model, est_S2D_mean, est_Hs, est_Tp, est_beta,
    est_S1D, est_theta, est_S2Dt, est_S2D_enct, est_wave_elv,
    std_wave_elv, H_s_over_time, T_p_over_time, beta_over_time,
    S1D_over_time, theta_m_over_time.
    est_beta and est_theta are angles in radians; they are printed in
    degrees via np.rad2deg.

    Raises:
    ------
    ValueError
        If ``last_states_number`` is smaller than 1 (a zero window would
        silently average every state because ``x[-0:]`` is the whole array).
    """
    if last_states_number < 1:
        raise ValueError("last_states_number must be a positive integer")

    n_time = time.shape[0]

    # Apply the Kalman filter (only the state history and the prediction
    # are needed here).
    (
        _,
        Xk_tot,
        _innovation,
        _motion_residus,
        predicted_motion_model,
        _S_tot,
    ) = kalman_filter_algo(
        simulated_motion_data,
        Pk,
        Q,
        R,
        Xk,
        Phi,
        output_matrix_over_time,
        n_time,
        dof_number,
        alpha=alpha,
    )

    ## Estimated wave spectra over time:
    # Encounter-frequency directional spectrum over time
    est_S2D_enct = seaState_func.wave_comp2dir_spec_enc(
        Xk_tot,
        TF_om_enc_inter,
        TF_beta,
    )

    # Directional spectrum over time
    est_S2Dt = seaState_func.wave_comp2dir_spec_over_time(
        Xk_tot, TF_omega_inter, TF_beta
    )

    # Average spectra over the last estimated states:
    est_S2D_mean = np.mean(est_S2Dt[-last_states_number:], axis=0)
    est_S2D_enc_mean = np.mean(est_S2D_enct[-last_states_number:], axis=0)

    # Sea state parameters from the averaged spectrum:
    est_Hs, est_Tp, est_beta, est_S1D, est_theta = seaState_func.spec2d_to_sea_state(
        est_S2D_mean, TF_omega_inter, TF_beta, psi
    )
    # Print the estimated sea state parameters:
    print("Estimated sea state parameters:")
    print("Significant wave height [m]:", est_Hs)
    print("Peak period [s]:", est_Tp)
    print("Relative Mean direction [deg]:", np.rad2deg(est_beta))
    print("Mean wave direction [deg]:", np.rad2deg(est_theta))

    # Sea state parameters at every time step:
    H_s_over_time = np.zeros(n_time)
    T_p_over_time = np.zeros(n_time)
    beta_over_time = np.zeros(n_time)
    theta_m_over_time = np.zeros(n_time)
    S1D_over_time = np.zeros((n_time, len(TF_omega_inter)))

    for i in range(n_time):
        (
            H_s_over_time[i],
            T_p_over_time[i],
            beta_over_time[i],
            S1D_over_time[i],
            theta_m_over_time[i],
        ) = seaState_func.spec2d_to_sea_state(
            est_S2Dt[i], TF_omega_inter, TF_beta, psi
        )

    # Estimated wave elevation and its standard deviation:
    est_wave_elv, std_wave_elv = seaState_func.wave_comp2_wave_elev(
        Xk_tot, TF_om_enc_inter, TF_beta, 1 / fs, n_time
    )

    return (
        predicted_motion_model,
        est_S2D_mean,
        est_Hs,
        est_Tp,
        est_beta,
        est_S1D,
        est_theta,
        est_S2Dt,
        est_S2D_enct,
        est_wave_elv,
        std_wave_elv,
        H_s_over_time,
        T_p_over_time,
        beta_over_time,
        S1D_over_time,
        theta_m_over_time
    )