"""
Ship motion simulation using the double-sum method or equal energy method.

Author: Ryane Bourkaib
Date: 2025-05-21
Description:
    Generate synthetic ship motion in 6DOF (heave, pitch, etc.) using the 
    double-sum method or equal energy method based on the wave energy 
    spectrum and RAOs.

References:
    - Bourkaib, R., Kok, M., & Seyffert, H. C. (2025).
      "Unidirectional and multi-directional wave estimation from ship
      motions using an Adaptive Kalman Filter with the inclusion of varying
      forward speed". Probabilistic Engineering Mechanics, 103773.
      ISSN 0266-8920. https://doi.org/10.1016/j.probengmech.2025.103773

    - Kim, H. S., Park, J. Y., Jin, C., Kim, M. H., & Lee, D. Y. (2023).
      "Real-time inverse estimation of multi-directional random waves from
      vessel-motion sensors using Kalman filter". Ocean Engineering, 280.
      https://doi.org/10.1016/j.oceaneng.2023.114501
    - Duarte, T., Gueydon, S., Jonkman, J., & Sarmento, A. (2014).
      "Computation of Wave Loads Under Multidirectional Sea States for
      Floating Offshore Wind Turbines". Proceedings of the ASME 2014 33rd
      International Conference on Ocean, Offshore and Arctic Engineering,
      Volume 9B: Ocean Renewable Energy. San Francisco, California, USA,
      June 8–13, 2014. V09BT09A023. ASME.
      https://doi.org/10.1115/OMAE2014-24148

Usage:
------

    from src.simulation.shipResp_func import ship_motion_double_sum
    motion_surge = ship_motion_double_sum(
        wave_amp, epsilon_seed, om_enc, time, seed,
        TRF_amps=TRF_amps, TRF_phases=TRF_phases,
        RespStr="Motions", AddNoise=True, snr=20
    )
"""

# ─────────────────────────────────────────────────────────────
# IMPORTS
# ─────────────────────────────────────────────────────────────
import math
import numpy as np
import matplotlib.pyplot as plt
from numpy.random import default_rng

# ─────────────────────────────────────────────────────────────
# FUNCTION: Generate ship motion using double sum method
# ─────────────────────────────────────────────────────────────
def ship_motion_double_sum(
    wave_amp,
    epsilon_seed,
    om_enc,
    time,
    seed,
    TRF_amps=0,
    TRF_phases=0,
    RespStr="Waves",
    AddNoise=False,
    snr=20,
    *,
    rng=None,
):
    r"""Sum wave components to generate a wave elevation or motion time series.

    The broadcast product of the component arrays is summed over axes 0 and 1
    (frequency and direction), i.e. the double-sum method:

    - ``RespStr="Waves"``:   :math:`\sum a \cos(\omega_e t + \epsilon)`
    - ``RespStr="Motions"``: :math:`\sum a |H| \cos(\omega_e t + \epsilon + \phi_H)`

    Parameters
    ----------
    wave_amp : array_like of shape (Nfreq,Nbeta,1)
        Amplitude of the wave components.
    epsilon_seed : array_like of shape (Nfreq,Nbeta,1)
        Random phases [rad] of the wave components. Those are uniformly
        distributed variables between 0 and :math:`2\pi`.
    om_enc : array_like of shape (Nfreq,1,1)
        Vector of encounter wave frequencies [rad/s].
    time : array_like of shape (1,1,Nt)
        Vector of time [s].
    seed : int
        Index of the seed number. Not used by the computation itself.
    TRF_amps : array_like or scalar, optional
        Transfer-function (RAO) amplitudes, broadcastable against
        ``wave_amp``. Only used when ``RespStr="Motions"``. Default 0.
    TRF_phases : array_like or scalar, optional
        Transfer-function (RAO) phases [rad], broadcastable against
        ``epsilon_seed``. Only used when ``RespStr="Motions"``. Default 0.
    RespStr : {"Waves", "Motions"}, optional
        Select wave elevation or motion response simulation.
    AddNoise : bool, optional
        If True, and only when ``RespStr="Motions"``, add zero-mean
        Gaussian noise to the response.
    snr : float, optional
        Linear signal-to-noise power ratio used for the added noise:
        the noise standard deviation is ``std(result) / sqrt(snr)``.
    rng : None, int or numpy.random.Generator, optional (keyword-only)
        Randomness source for the noise, forwarded to
        :func:`numpy.random.default_rng`. ``None`` (default) draws fresh,
        non-reproducible noise; pass an int seed or a Generator to make
        noisy output reproducible.

    Returns
    -------
    result : numpy.ndarray of shape (Nt,)
        Simulated wave/response time series. For ``RespStr="Motions"``,
        an all-zero series is returned (after printing a warning) if any of
        ``wave_amp``, ``epsilon_seed``, ``om_enc`` or ``time`` contains NaN.

    Raises
    ------
    ValueError
        If ``RespStr`` is neither ``"Waves"`` nor ``"Motions"``.

    Example
    -------
    >>> motion_heave = ship_motion_double_sum(  # doctest: +SKIP
    ...     wave_amp, epsilon_seed, om_enc, time, seed,
    ...     TRF_amps=TRF_amps, TRF_phases=TRF_phases,
    ...     RespStr="Motions", AddNoise=True, snr=20,
    ... )
    """
    # Wave elevation: plain sum of the components; no noise is added here.
    if RespStr == "Waves":
        return np.sum(wave_amp * np.cos(om_enc * time + epsilon_seed), axis=(0, 1))

    # Previously an unknown value fell through and crashed with an obscure
    # UnboundLocalError on `result`; fail early with a clear message instead.
    if RespStr != "Motions":
        raise ValueError(
            f'RespStr must be "Waves" or "Motions", got {RespStr!r}.'
        )

    # Motion response: each component is scaled/shifted by the RAO.
    result = np.sum(
        wave_amp * TRF_amps * np.cos(om_enc * time + epsilon_seed + TRF_phases),
        axis=(0, 1),
    )

    # Report which inputs carry NaN and return a zero series of the same
    # shape instead of a NaN-polluted one.
    nan_inputs = [
        name
        for name, values in (
            ("wave_amp", wave_amp),
            ("epsilon_seed", epsilon_seed),
            ("om_enc", om_enc),
            ("time", time),
        )
        if np.isnan(values).any()
    ]
    if nan_inputs:
        print("Warning: NaN values detected in input parameters. Check wave_amp, epsilon_seed, om_enc, and time.")
        for name in nan_inputs:
            print(f"{name} contains NaN values.")
        return np.zeros_like(result)

    if AddNoise:
        rng_noise = default_rng(rng)
        sigma_motiont = np.std(result)
        print("sigma_motiont", sigma_motiont)
        # snr is a linear power ratio: var(noise) = var(signal) / snr.
        sigma_noise = sigma_motiont / np.sqrt(snr)
        print("sigma_noise", sigma_noise)
        # Size the noise from the result itself rather than assuming `time`
        # is exactly 3-D.
        result += rng_noise.normal(0, sigma_noise, size=np.shape(result))
    return result


# ─────────────────────────────────────────────────────────────
# FUNCTION: Generate ship motion using Equal Energy Method
# ─────────────────────────────────────────────────────────────

def ship_motion_equal_energy(
    wave_amp,
    epsilon_seed,
    om_enc,
    time,
    seed,
    TRF_amps=0,
    TRF_phases=0,
    AddNoise=False,
    snr=20,
    *,
    rng=None,
):
    r"""Generate a wave elevation or motion time series with the equal energy method.

    Each component contributes
    :math:`a\,|H|\cos(\omega_e t + \epsilon + \phi_H)`, and the components are
    summed over axis 0. Passing ``TRF_amps=1`` and ``TRF_phases=0`` yields
    the wave elevation itself.

    Parameters
    ----------
    wave_amp : array_like of shape (Nfreq,1)
        Amplitude of the wave components.
    epsilon_seed : array_like of shape (Nfreq,1)
        Random phases [rad] of the wave components. Those are uniformly
        distributed variables between 0 and :math:`2\pi`.
    om_enc : array_like of shape (Nfreq,1)
        Vector of encounter wave frequencies [rad/s].
    time : array_like of shape (1,Nt)
        Vector of time [s].
    seed : int
        Index of the seed number. Not used by the computation itself.
    TRF_amps : array_like or scalar, optional
        Transfer-function (RAO) amplitudes, broadcastable against
        ``wave_amp``. The default 0 yields an all-zero series.
    TRF_phases : array_like or scalar, optional
        Transfer-function (RAO) phases [rad]. Default 0.
    AddNoise : bool, optional
        If True, add zero-mean Gaussian noise to the result.
    snr : float, optional
        Linear signal-to-noise power ratio used for the added noise:
        the noise standard deviation is ``std(result) / sqrt(snr)``.
    rng : None, int or numpy.random.Generator, optional (keyword-only)
        Randomness source for the noise, forwarded to
        :func:`numpy.random.default_rng`. ``None`` (default) draws fresh,
        non-reproducible noise; pass an int seed or a Generator to make
        noisy output reproducible.

    Returns
    -------
    result : numpy.ndarray of shape (Nt,)
        Simulated wave/response time series.

    Example
    -------
    >>> motion_heave = ship_motion_equal_energy(  # doctest: +SKIP
    ...     wave_amp, epsilon_seed, om_enc, time, seed,
    ...     TRF_amps=TRF_amps, TRF_phases=TRF_phases,
    ...     AddNoise=True, snr=20,
    ... )

    Notes
    -----
    See also the ship_motion_double_sum function for the double-sum method.
    """
    # Single sum over the frequency axis (one component per frequency).
    result = np.sum(
        wave_amp * TRF_amps * np.cos(om_enc * time + epsilon_seed + TRF_phases),
        axis=0,
    )

    if AddNoise:
        rng_noise = default_rng(rng)
        sigma_motiont = np.std(result)
        print("sigma_motiont", sigma_motiont)
        # snr is a linear power ratio: var(noise) = var(signal) / snr.
        sigma_noise = sigma_motiont / np.sqrt(snr)
        print("sigma_noise", sigma_noise)
        # Size the noise from the result itself rather than assuming `time`
        # is exactly 2-D.
        result += rng_noise.normal(0, sigma_noise, size=np.shape(result))
    return result


