"""
Sea state parameter calculation module.

Author: Ryane Bourkaib  
Date: 2025-05-21  
Description:
    This module contains functions to calculate sea state parameters from
    1D (S1D) or 2D (S2D) wave spectra. These parameters are commonly used
    in oceanography and ship motion simulations.

    The calculated sea state parameters include:
    - Significant wave height, H_s [m]
    - Peak period, T_p [s]
    - Mean wave direction [rad]

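Example:
    from src.analysis.seaState_func import spec1d_to_sea_state
    Hs, Tp, Tm01, Tm02, Tm24, TE, Sm02, epsilon = spec1d_to_sea_state(S1D, omega)
    # S1D and omega are 1D arrays of equal length (omega in rad/s).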

References:
    - R. Bourkaib, M. Kok, H.C. Seyffert,
    Unidirectional and multi-directional wave estimation from ship 
    motions using an Adaptive Kalman Filter with the inclusion of varying forward speed,
    Probabilistic Engineering Mechanics,
    2025,
    103773,
    ISSN 0266-8920,
    https://doi.org/10.1016/j.probengmech.2025.103773.
"""

# ─────────────────────────────────────────────────────────────
# IMPORTS
# ─────────────────────────────────────────────────────────────
import numpy as np
from scipy import integrate
from scipy.integrate import trapezoid

from src.tools import general_func
# ─────────────────────────────────────────────────────────────
# FUNCTIONS: Sea state parameters from 1D spectrum 
# ─────────────────────────────────────────────────────────────
def spec1d_to_sea_state(S1D, omega):
    """Calculate sea state parameters from a 1D wave spectrum.

    The spectral moments ``m_n`` are the integrals of ``omega**n * S1D`` over
    ``omega``, evaluated with Simpson's rule. Because ``omega`` is an angular
    frequency, every characteristic period carries a factor ``2*pi`` so that
    it is expressed in seconds, consistently with the peak period ``Tp``.

    Parameters
    ----------
    S1D : array_like of shape (Nfreq,)
        1D wave spectrum [m^2.s/rad].
    omega : array_like of shape (Nfreq,)
        Vector of angular frequencies [rad/s]. Entries must be non-zero
        because the energy period uses ``omega**-1``.

    Returns
    -------
    Hs : float
        Significant wave height, ``4*sqrt(m0)`` [m].
    Tp : float
        Peak period, ``2*pi / omega`` at the maximum of ``S1D`` [s].
    Tm01 : float
        Mean wave period, ``2*pi*m0/m1`` [s].
    Tm02 : float
        Zero up-crossing period, ``2*pi*sqrt(m0/m2)`` [s].
    Tm24 : float
        Mean crest period, ``2*pi*sqrt(m2/m4)`` [s].
    TE : float
        Mean energy period, ``2*pi*m_{-1}/m0`` [s].
    Sm02 : float
        Significant wave steepness, ``2*pi*Hs / (g*Tm02**2)`` [-].
    epsilon : float
        Spectral bandwidth, ``sqrt(1 - m2**2 / (m0*m4))`` [-].

    References
    ----------
    - NetSSE Mounet, R. E. G. (Creator) & Nielsen, U. D. (Supervisor),
      Technical University of Denmark, 27 Jul 2023 DOI: 10.11583/DTU.26379811,
      https://gitlab.gbar.dtu.dk/regmo/NetSSE
    """
    g = 9.81  # Gravitational acceleration [m/s^2]
    two_pi = 2 * np.pi

    # Float arrays: accepts lists and keeps omega**-1 defined for
    # integer-valued input.
    S1D = np.asarray(S1D, dtype=float)
    omega = np.asarray(omega, dtype=float)

    def moment(order):
        # x is passed by keyword: positional use is deprecated in recent SciPy.
        return integrate.simpson(omega**order * S1D, x=omega)

    m_neg1, m0, m1, m2, m4 = (moment(order) for order in (-1, 0, 1, 2, 4))

    Hs = 4 * np.sqrt(m0)                  # Significant wave height [m]
    Tm02 = two_pi * np.sqrt(m0 / m2)      # Zero up-crossing period [s]
    Tm01 = two_pi * m0 / m1               # Mean wave period [s]
    Tm24 = two_pi * np.sqrt(m2 / m4)      # Mean crest period [s]
    TE = two_pi * m_neg1 / m0             # Mean energy period [s]
    Sm02 = two_pi / g * Hs / Tm02**2      # Significant wave steepness [-]
    epsilon = np.sqrt(1 - m2**2 / (m0 * m4))  # Spectral bandwidth [-]

    # Peak period from the frequency bin holding the spectral maximum.
    Tp = two_pi / omega[np.argmax(S1D)]

    # Hs is derived from m0, so Hs**2/16 can only disagree with m0 when the
    # integrated variance is negative or NaN, i.e. when Hs itself is NaN.
    if np.isnan(Hs):
        print("Warning: Total variance in the spectrum does not match wave height variance.")

    return Hs, Tp, Tm01, Tm02, Tm24, TE, Sm02, epsilon


# ─────────────────────────────────────────────────────────────
# FUNCTIONS: Sea state parameters from 2D spectrum
# ─────────────────────────────────────────────────────────────
def spec2d_to_sea_state(S2D, omega, theta, psi, conv="from"):
    """Calculate sea state parameters from a 2D wave spectrum.

    Parameters
    ----------
    S2D : array_like of shape (Nfreq, Ntheta)
        2D wave spectrum [m^2.s/rad].
    omega : array_like of shape (Nfreq,)
        Vector of angular frequencies [rad/s].
    theta : array_like of shape (Ntheta,)
        Wave directions [rad].
    psi : float
        Ship heading [rad].
    conv : str, optional
        Direction convention, either "from" or "to". Default is "from".

    Returns
    -------
    Hs : float
        Significant wave height, ``4*sqrt(m0)`` [m].
    Tp : float
        Peak period, ``2*pi / omega`` at the maximum of the 1D spectrum [s].
    theta_m : float
        Energy-weighted mean wave direction, shifted to be non-negative [rad].
    S1D : ndarray of shape (Nfreq,)
        1D spectrum obtained by integrating ``S2D`` over ``theta``.
    beta : float
        Mean wave direction relative to the heading [rad]. It is wrapped in
        degrees by ``general_func.wrap_angle_range(..., start=0)`` and then
        converted back to radians.

    Raises
    ------
    NameError
        If ``conv`` is neither "from" nor "to".
    """
    S2D = np.asarray(S2D, dtype=float)
    omega = np.asarray(omega, dtype=float)
    theta = np.asarray(theta, dtype=float)

    # Collapse the directional axis to get the 1D frequency spectrum.
    # x is passed by keyword: positional use is deprecated in recent SciPy.
    S1D = integrate.simpson(S2D, x=theta, axis=1)

    # Only the zeroth moment (total variance) is needed for Hs.
    m0 = integrate.simpson(S1D, x=omega)
    Hs = 4 * np.sqrt(m0)                    # Significant wave height [m]
    Tp = 2 * np.pi / omega[np.argmax(S1D)]  # Peak period [s]

    # Overall mean direction from the energy-weighted sine/cosine components.
    d = trapezoid(trapezoid(S2D * np.sin(theta), x=omega, axis=0), x=theta, axis=0)
    c = trapezoid(trapezoid(S2D * np.cos(theta), x=omega, axis=0), x=theta, axis=0)
    theta_m = np.arctan2(d, c)  # Mean wave direction [rad]
    if theta_m < 0:
        theta_m += 2 * np.pi

    # Mean direction relative to the ship heading; both conventions are
    # wrapped with the same helper so they share one output range.
    beta = theta_m - psi
    if conv == "from":
        beta = general_func.wrap_angle_range(np.rad2deg(np.pi + beta), start=0)
    elif conv == "to":
        beta = general_func.wrap_angle_range(np.rad2deg(beta), start=0)
    else:
        raise NameError('Invalid direction convention. Only accepts "from" or "to".')

    return Hs, Tp, theta_m, S1D, np.deg2rad(beta)


# ─────────────────────────────────────────────────────────────
# FUNCTIONS: Relative wave direction calculation
# ─────────────────────────────────────────────────────────────
def calculate_relative_dirc(theta_deg, psi_deg, conv="from"):
    """Compute wave directions relative to the ship heading.

    Parameters
    ----------
    theta_deg : array_like
        Incoming wave directions [deg]. Must support element-wise
        subtraction and ``len``.
    psi_deg : float
        Ship heading [deg].
    conv : str, optional
        Direction convention, either "from" or "to". Default is "from".
        With "from", 180 degrees are added before wrapping.

    Returns
    -------
    beta : array_like
        Relative wave direction [deg], as wrapped by
        ``general_func.wrap_angle_range`` with its default range.
    Ndir : int
        Number of directions, i.e. ``len(theta_deg)``.

    Raises
    ------
    NameError
        If ``conv`` is neither "from" nor "to".

    Example
    -------
    >>> beta, Ndir = calculate_relative_dirc(theta_deg, psi_deg, conv="from")
    """
    # Every entry of theta_deg is kept: a closing direction that repeats the
    # first one (e.g. 0 and 360) is not removed.
    relative = theta_deg - psi_deg
    n_directions = len(theta_deg)

    if conv == "from":
        return general_func.wrap_angle_range(180 + relative), n_directions
    if conv == "to":
        return general_func.wrap_angle_range(relative), n_directions
    raise NameError('Invalid direction convention. Only accepts "from" or "to".')


# ─────────────────────────────────────────────────────────────
# FUNCTIONS: Calculate 2D encounter spectrum from estimated complex
# wave components
# ─────────────────────────────────────────────────────────────
def wave_comp2dir_spec_enc(estimated_states, om_enc, beta):
    """Calculate the 2D encounter spectrum from estimated wave components.

    For every time step ``k``, frequency ``j`` and direction ``m`` the
    spectral density is::

        0.5 * (a**2 + b**2) / (dbeta[m] * dom_enc[j, m])

    where ``a`` and ``b`` are the two consecutive columns
    ``2*(j*Ndir + m)`` and ``2*(j*Ndir + m) + 1`` of ``estimated_states``.

    Parameters
    ----------
    estimated_states : array_like of shape (Ntime, 2*Nfreq*Ndir)
        Estimated wave component pairs, one row per time step. Columns
        beyond ``2*Nfreq*Ndir`` are ignored.
    om_enc : array_like of shape (Nfreq, Ndir)
        Encounter wave frequencies [rad/s]. Needs at least two rows.
    beta : array_like of shape (Ndir,)
        Relative wave directions [rad]. Needs at least two entries.

    Returns
    -------
    S2D_enct : ndarray of shape (Ntime, Nfreq, Ndir)
        2D encounter spectrum for every time step.

    Raises
    ------
    IndexError
        If ``estimated_states`` is not 2D or has fewer than
        ``2*Nfreq*Ndir`` columns.

    Example
    -------
    >>> S2D_enct = wave_comp2dir_spec_enc(estimated_states,
      enc_frequency, beta)
    """
    states = np.asarray(estimated_states)
    om_enc = np.asarray(om_enc)
    beta = np.asarray(beta)
    n_freq, n_dir = len(om_enc), len(beta)

    # Bin widths; the last bin reuses the previous spacing.
    dbeta = np.diff(beta)
    dbeta = np.append(dbeta, dbeta[-1])
    # The last row is the absolute final spacing, whereas the other rows keep
    # the sign of the difference between consecutive encounter frequencies.
    dom_enc = np.vstack(
        (np.diff(om_enc, axis=0), np.abs(om_enc[-1] - om_enc[-2]))
    )

    n_used = 2 * n_freq * n_dir
    if states.ndim != 2 or states.shape[1] < n_used:
        raise IndexError(
            f"estimated_states must have shape (Ntime, >= {n_used}); "
            f"got {states.shape}."
        )

    # View the interleaved columns as (Ntime, Nfreq, Ndir, 2) coefficient pairs.
    pairs = states[:, :n_used].reshape(-1, n_freq, n_dir, 2)
    energy = pairs[..., 0] ** 2 + pairs[..., 1] ** 2

    # Broadcasts (Ndir,) and (Nfreq, Ndir) against (Ntime, Nfreq, Ndir).
    S2D_enct = np.zeros((states.shape[0], n_freq, n_dir))
    S2D_enct[...] = 0.5 / dbeta / dom_enc * energy
    return S2D_enct


# ─────────────────────────────────────────────────────────────
# FUNCTIONS: Calculate 2D absolute spectrum from estimated complex
# wave components
# ─────────────────────────────────────────────────────────────
def wave_comp2dir_spec_over_time(estimated_states, omega, beta):
    """Calculate the 2D absolute spectrum from estimated wave components.

    For every time step ``k``, frequency ``j`` and direction ``m`` the
    spectral density is::

        0.5 * (a**2 + b**2) / (dbeta[m] * domega[j])

    where ``a`` and ``b`` are the two consecutive columns
    ``2*(j*Ndir + m)`` and ``2*(j*Ndir + m) + 1`` of ``estimated_states``.

    Parameters
    ----------
    estimated_states : array_like of shape (Ntime, 2*Nfreq*Ndir)
        Estimated wave component pairs, one row per time step. Columns
        beyond ``2*Nfreq*Ndir`` are ignored.
    omega : array_like of shape (Nfreq,)
        Wave angular frequencies [rad/s]. Needs at least two entries.
    beta : array_like of shape (Ndir,)
        Relative wave directions [rad]. Needs at least two entries.

    Returns
    -------
    S2Dt : ndarray of shape (Ntime, Nfreq, Ndir)
        2D absolute spectrum for every time step.

    Raises
    ------
    IndexError
        If ``estimated_states`` is not 2D or has fewer than
        ``2*Nfreq*Ndir`` columns.

    Example
    -------
    >>> S2Dt = wave_comp2dir_spec_over_time(estimated_states, omega, beta)
    """
    states = np.asarray(estimated_states)
    omega = np.asarray(omega)
    beta = np.asarray(beta)
    n_freq, n_dir = len(omega), len(beta)

    # Bin widths; the last bin reuses the previous spacing.
    dbeta = np.diff(beta)
    dbeta = np.append(dbeta, dbeta[-1])
    domega = np.diff(omega)
    domega = np.append(domega, domega[-1])

    n_used = 2 * n_freq * n_dir
    if states.ndim != 2 or states.shape[1] < n_used:
        raise IndexError(
            f"estimated_states must have shape (Ntime, >= {n_used}); "
            f"got {states.shape}."
        )

    # View the interleaved columns as (Ntime, Nfreq, Ndir, 2) coefficient pairs.
    pairs = states[:, :n_used].reshape(-1, n_freq, n_dir, 2)
    energy = pairs[..., 0] ** 2 + pairs[..., 1] ** 2

    # domega varies along the frequency axis, dbeta along the direction axis.
    S2Dt = np.zeros((states.shape[0], n_freq, n_dir))
    S2Dt[...] = 0.5 / dbeta / domega[:, np.newaxis] * energy
    return S2Dt


# ─────────────────────────────────────────────────────────────
# FUNCTIONS: Calculate the wave elevation from estimated complex
# wave components
# ─────────────────────────────────────────────────────────────
def wave_comp2_wave_elev(
    estimated_states, frequency, direction, delta_time, time_steps
):
    """Calculate the wave elevation from estimated wave components.

    At time ``t = k * delta_time`` the elevation is the sum over all
    frequencies ``j`` and directions ``m`` of::

        a * cos(frequency[j, m] * t) + b * sin(frequency[j, m] * t)

    where ``a`` and ``b`` are columns ``2*(j*Ndir + m)`` and
    ``2*(j*Ndir + m) + 1`` of row ``k`` of ``estimated_states``.

    Parameters
    ----------
    estimated_states : array_like of shape (time_steps, 2*Nfreq*Ndir)
        Estimated wave component pairs, one row per time step.
    frequency : array_like of shape (Nfreq, Ndir)
        Wave angular frequencies [rad/s].
    direction : array_like of shape (Ndir,)
        Relative wave directions [rad]. Only its length is used.
    delta_time : float
        Time step size [s].
    time_steps : int
        Number of time steps.

    Returns
    -------
    estimated_wave_elevation : ndarray of shape (time_steps,)
        Estimated wave elevation [m].
    st_dev_elevation : float
        Four times the standard deviation of the elevation, i.e. a
        significant wave height estimate [m], not the standard deviation.

    Raises
    ------
    IndexError
        If ``estimated_states`` or ``frequency`` is too small for the
        requested number of time steps, frequencies and directions.

    Example
    -------
    >>> estimated_wave_elevation, st_dev_elevation = wave_comp2_wave_elev(
        estimated_states, frequency, direction, delta_time, time_steps
    )
    """
    states = np.asarray(estimated_states)
    freq = np.asarray(frequency)
    n_freq, n_dir = len(freq), len(direction)
    n_used = 2 * n_freq * n_dir

    if states.ndim != 2 or states.shape[1] < n_used:
        raise IndexError(
            f"estimated_states must have shape (time_steps, >= {n_used}); "
            f"got {states.shape}."
        )
    if freq.ndim != 2 or freq.shape[1] < n_dir:
        raise IndexError(
            f"frequency must have shape (Nfreq, >= {n_dir}); got {freq.shape}."
        )
    freq = freq[:, :n_dir]

    estimated_wave_elevation = np.zeros(time_steps)
    # One pass per time step keeps memory at O(Nfreq*Ndir) while the
    # frequency/direction sums run inside NumPy.
    for k in range(time_steps):
        phase = freq * (k * delta_time)
        pairs = states[k, :n_used].reshape(n_freq, n_dir, 2)
        estimated_wave_elevation[k] = np.sum(
            pairs[..., 0] * np.cos(phase) + pairs[..., 1] * np.sin(phase)
        )

    st_dev_elevation = 4 * np.std(estimated_wave_elevation)
    print(" Significant wave from wave elevation:", st_dev_elevation)
    return estimated_wave_elevation, st_dev_elevation
