"""
Wave environment generation module

Author: Ryane Bourkaib
Date: 2025-05-21
Description:
    Generates sea state conditions for simulation, including:
    - JONSWAP wave spectrum
    - Directional spreading function
    - Unidirectional wave spectrum
    - Multidirectional wave spectrum
    Used as input to ship motion simulation.

Example:
    >>> from src.env_conditions import JONSWAP_DNV, spreading_fun
    >>> S = JONSWAP_DNV(Tp, Hs, omega, gamma='standard')
    >>> D = spreading_fun(theta, theta0=0, s=10)
    >>> S_2D = np.outer(S, D)
"""

# ─────────────────────────────────────────────────────────────
# IMPORTS
# ─────────────────────────────────────────────────────────────
import math
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import gamma
from scipy.interpolate import interp1d
from numpy.random import default_rng

# ─────────────────────────────────────────────────────────────
# FUNCTION: JONSWAP Spectrum
# ─────────────────────────────────────────────────────────────
def JONSWAP_DNV(Tp, Hs, omega, gamma='standard'):
    """Computes the JONSWAP spectrum corresponding to the input sea state 
    parameters.

    The JONSWAP spectrum is formulated as a modification of
    a Pierson-Moskowitz spectrum for a developing sea state in a fetch
    limited situation.

    Parameters
    ----------    
    Tp : float
        Peak period [s].
    Hs : float
        Significant wave height [m].
    omega : array_like of shape (Nfreq,)
        Vector of angular frequencies [rad/s].
    gamma : {'standard','DNV',float}, optional
        Peak shape parameter [-]. The value can be user-provided as a float. 
        Alternatively, if ``'standard'`` is input, then ``gamma`` will take the 
        standard value of 3.3., while a value ``'DNV'`` as input leads to 
        following the procedure 3.5.5.5 described in DNV-RP-C205.
        .. tip::
            Use ``gamma = 1`` to output a standard Pierson-Moskowitz spectrum.

    Returns
    -------
    S : array_like of shape (Nfreq,)
        Standard wave spectrum [m^2.s/rad].

    Example
    -------
    >>> S = JONSWAP_DNV(Tp, Hs, omega, gamma='standard') 

    References
    ----------
    - Fossen T. (2011). Handbook of Marine Craft Hydrodynamics and Motion
        Control. John Wiley & Sons.

    """
    # Constants
    g = 9.81 # m/s^2
    omega_p = 2*np.pi/Tp  # angular spectral peak frequency [rad/s]

    # Pierson-Moskowitz spectrum:
    S_PM = (5/16*Hs**2*omega_p**4)*omega**(-5)*np.exp(-5/4*(omega_p/omega)**(4))
    
    # JONSWAP spectrum:
    crit = Tp/np.sqrt(Hs) # Criterion defined in DNV to select an appropriate gamma.
    
    if gamma == 'standard':
        gamma = 3.3    
    elif gamma == 'DNV':
        if crit<=3.6:
            gamma = 5 # steep waves → high gamma = 5.
            print('JONSWAP spectrum should be used with caution for the given (Tp,Hs)')
        elif crit>=5:
            gamma = 1 # broad waves → gamma = 1
            print('JONSWAP spectrum should be used with caution for the given (Tp,Hs)')
        else:
            gamma = np.exp(5.75-1.15*crit)
    
    print('JONSWAP spectrum with gamma =',gamma)
    
    A_gamma = 1-0.287*np.log(gamma);
    sigma_a = 0.07; sigma_b = 0.09
    sigma = sigma_a*np.ones(np.shape(omega))  #spectral width parameter [n.d.]
    sigma[np.where(omega>omega_p)] = sigma_b
    S = A_gamma*S_PM*gamma**(np.exp(-0.5*((omega-omega_p)/(sigma*omega_p))**2))
    
    return S

# ─────────────────────────────────────────────────────────────
# FUNCTION: Spreading function
# ─────────────────────────────────────────────────────────────
def spreading_fun(theta, theta0, s):
    """Evaluate the cosine-2s directional spreading function.

    D(theta) = C(s) * [cos((theta - theta0)/2)**2]**s, with
    C(s) = 2**(2s-1)/pi * Gamma(s+1)**2 / Gamma(2s+1),
    which makes D integrate to 1 over any direction interval of length 2*pi.

    Parameters
    ----------
    theta : array_like
        Wave directions [rad].
    theta0 : float
        Mean wave direction [rad].
    s : int or float
        Spreading parameter [-]; larger values give a more peaked spread.

    Returns
    -------
    D_theta : ndarray of shape (1, theta.size)
        Spreading function [1/rad], returned as a row vector so that it
        broadcasts against a column of frequency values.

    Example
    -------
    >>> D = spreading_fun(theta, theta0=np.pi/4, s=2)
    """
    # Squaring before raising to s keeps the base non-negative, so
    # non-integer s is safe even where the half-angle cosine is negative.
    half_offset = (theta - theta0) / 2
    cos_sq = np.cos(half_offset)**2
    norm_const = 2**(2*s-1)/np.pi*gamma(s+1)**2/gamma(2*s+1)
    row = norm_const*cos_sq**s
    return np.reshape(row, (1, -1))

def limited_cosine_spreading(theta, s, theta0, theta_max):
    """
    Limited-range cosine spreading function for directional wave spectrum.

    D(theta) = C * |cos(pi * dtheta / (2 * theta_max))|**(2s) for
    |dtheta| <= theta_max and 0 elsewhere, where dtheta is the angular
    deviation of theta from theta0 wrapped into [-pi, pi], and
    C = sqrt(pi) * Gamma(s+1) / (2 * theta_max * Gamma(s+1/2))
    normalises D to a unit integral over its support.

    Parameters:
    ----------
    theta : array_like
        Incoming wave directions [rad]. Converted to a float array, so lists
        and integer arrays are accepted.
    s : float
        Spreading parameter [-]
    theta0 : float
        Mean wave direction [rad]
    theta_max : float
        Max allowed deviation from mean direction [rad], expected in (0, pi].

    Returns:
    --------
    D : np.ndarray
        Normalized spreading function values [1/rad], same shape as ``theta``.

    Raises:
    -------
    ValueError
        If ``theta_max`` is not strictly positive.
    """
    if theta_max <= 0:
        raise ValueError("theta_max must be strictly positive.")
    # Float dtype matters: zeros_like on an int array would truncate D to 0.
    theta = np.asarray(theta, dtype=float)

    # Angular deviation from the mean direction. Deviations outside [-pi, pi]
    # are wrapped, so e.g. theta = 2*pi - 0.1 counts as 0.1 rad away from
    # theta0 = 0; in-range deviations are left untouched (no extra rounding).
    delta = theta - theta0
    delta = np.where(np.abs(delta) > np.pi,
                     np.mod(delta + np.pi, 2 * np.pi) - np.pi,
                     delta)

    mask = np.abs(delta) <= theta_max
    C = (np.sqrt(np.pi) * gamma(s + 1)) / (2 * theta_max * gamma(s + 0.5))
    D = np.zeros_like(theta)
    D[mask] = C * np.abs(np.cos(np.pi * delta[mask] / (2 * theta_max))) ** (2 * s)

    return D

# ─────────────────────────────────────────────────────────────
# FUNCTION: 1D to 2D wave spectrum conversion
# ─────────────────────────────────────────────────────────────
def wave_spec_1dto2d(S1D, theta, theta0, s=2, max_theta=np.pi/2):
    """Spread a 1D frequency spectrum over directions to build a 2D spectrum.

    Each entry is ``S2D[i, j] = S1D[i] * D[j]``, where ``D`` is the cos-2s
    spreading function returned by ``spreading_fun(theta, theta0, s)``.

    Parameters
    ----------
    S1D : array_like of shape (Nfreq,)
        1D wave spectrum [m^2.s/rad].
    theta : array_like of shape (Ndir,)
        Wave directions [rad].
    theta0 : float
        Mean wave direction [rad].
    s : int, optional
        Spreading parameter [-]. Default is 2.
    max_theta : float, optional
        Currently unused. It is only relevant to the limited-range spreading
        alternative (``limited_cosine_spreading``), which is not called here.

    Returns
    -------
    S2D : ndarray of shape (Nfreq, Ndir)
        Directional wave spectrum: ``S1D`` units times the spreading density.

    Example
    -------
    >>> S2D = wave_spec_1dto2d(S1D, theta, theta0, s=2, max_theta=np.pi/2)
    """
    n_freq, n_dir = len(S1D), len(theta)
    freq_column = np.reshape(S1D, (-1, 1))
    # Replicate the frequency spectrum across every direction column ...
    tiled = freq_column*np.ones((n_freq, n_dir))
    # ... then weight each column by the (1, Ndir) spreading row vector.
    return spreading_fun(theta, theta0, s)*tiled


# ─────────────────────────────────────────────────────────────
# FUNCTION: Calculate the energy step vector for wave direction discretisation
# ─────────────────────────────────────────────────────────────

def energy_step_vector(direction_number):
    """Return the bin-centre cumulative-energy levels for equal-energy
    discretisation of wave directions.

    The unit interval [0, 1] of cumulative directional energy is split into
    ``direction_number`` equal bins, and the centre of each bin is returned,
    i.e. ``(k + 0.5) / direction_number`` for ``k = 0 .. direction_number-1``.

    Parameters
    ----------
    direction_number : int
        Number of discrete directions (energy bins) [-]; must be >= 1.

    Returns
    -------
    E : ndarray of shape (direction_number,)
        Cumulative energy fractions [-], strictly inside (0, 1).

    Raises
    ------
    ValueError
        If ``direction_number`` is smaller than 1.

    Example
    -------
    >>> energy_step_vector(4)
    array([0.125, 0.375, 0.625, 0.875])
    """
    if direction_number < 1:
        raise ValueError("direction_number must be a positive integer.")

    # First and last bin centres sit half a bin away from 0 and 1.
    half_bin = 1 / (2 * direction_number)
    E = np.linspace(half_bin, 1 - half_bin, num=direction_number, endpoint=True)
    return E

# ─────────────────────────────────────────────────────────────
# FUNCTION: Equal-energy discretisation of wave directions
# ─────────────────────────────────────────────────────────────

def equal_energy_theta(theta, theta0, theta_num=10, s=2):
    """
    Computes discrete wave directions based on the equal energy method.

    The cos-2s spreading function ``D`` is integrated along ``theta`` with the
    trapezoidal rule to build the cumulative energy distribution ``P_theta``,
    which starts at 0. The returned directions are where ``P_theta`` reaches
    the bin-centre levels produced by ``energy_step_vector``, so each returned
    direction represents an equal share of the directional energy.

    Parameters:
    ----------
    - theta: array-like, wave directions in radians (high-resolution,
      monotonic; non-uniform spacing is supported)
    - theta0: float, mean or peak wave direction in radians
    - theta_num: int, number of discrete wave directions to compute (default = 10)
    - s: float, spreading parameter for the directional spreading function (default = 2)

    Returns:
    -------
    - theta_inter: array of interpolated wave directions (in radians)
      corresponding to equal energy partitions
    - P_theta: cumulative energy distribution values, shape (len(theta),),
      with P_theta[0] == 0
    - D: directional spreading function values over theta, as returned by
      ``spreading_fun`` (row vector of shape (1, len(theta)))
    - energy_steps: cumulative energy levels [-] used for the interpolation,
      shape (theta_num,)

    Raises:
    -------
    - ValueError: raised by ``interp1d`` when an energy step lies outside
      the range of P_theta (e.g. when theta covers too little of the spread).

    Example:
    -------
    >>> theta_inter, P_theta, D, energy_steps = equal_energy_theta(theta, theta0, theta_num=10, s=2)
    """
    # 1. Directional spreading function (row vector), flattened for integration:
    D = spreading_fun(theta, theta0, s=s)
    D_flat = np.ravel(D)

    # 2. Cumulative energy distribution (trapezoidal rule, starting at 0).
    #    A plain cumsum(D)*dtheta is a rectangle rule that shifts every
    #    interpolated direction by about half a grid step. |dtheta| keeps P
    #    increasing even if theta is given in descending order.
    dtheta = np.abs(np.diff(theta))
    increments = 0.5 * (D_flat[1:] + D_flat[:-1]) * dtheta
    P_theta = np.concatenate(([0.0], np.cumsum(increments)))

    # 3. Energy step vector (bin-centre cumulative-energy levels):
    energy_steps = energy_step_vector(theta_num)

    # 4. Invert P_theta(theta) by linear interpolation to find the directions:
    interpolator = interp1d(P_theta, theta, kind="linear")
    theta_inter = interpolator(energy_steps)

    return theta_inter, P_theta, D, energy_steps


