
"""
General helper functions for various purposes.
Author: Ryane Bourkaib  
Date: 2025-05-21  
Description:
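    This module contains helper functions for various purposes, including:
    - Wrapping angles to a specified range and extending periodic arrays
    - Selecting and interpolating RAO (response amplitude operator) data
    - Converting encounter frequencies to absolute wave frequencies
    - Unpacking simulated ship-motion data sets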


Example:
    from src.tools.general_func import ...

"""

# ─────────────────────────────────────────────────────────────
# IMPORTS
# ─────────────────────────────────────────────────────────────
import numpy as np
from scipy import integrate
from scipy.interpolate import CubicSpline


#from src.model.RAOs_USCG.USCG_Motion_Mode_Parser import USCGMotionModeParser

# ─────────────────────────────────────────────────────────────
# FUNCTION: Rearrange angles to a new range
# ─────────────────────────────────────────────────────────────
def wrap_angle_range(angle, start=0):
    """Wrap angles (in degrees) into the half-open range ``[start, start + 360)``.

    Parameters
    ----------
    angle : float or array_like
        Angle(s) in degrees to be wrapped. Any shape is accepted.
    start : float, default 0
        Start of the target range in degrees. Output angles lie within
        ``[start, start + 360)``.

    Returns
    -------
    wrapped_angle : scalar or ndarray
        A NumPy scalar when ``angle`` contains a single value; otherwise an
        array with the same shape as ``angle``. Integer input keeps its
        integer dtype. Empty input returns an empty array, and NaN/inf
        inputs map to NaN.

    References
    ----------
    - NetSSE Mounet, R. E. G. (Creator) & Nielsen, U. D. (Supervisor),
      Technical University of Denmark, 27 Jul 2023 DOI: 10.11583/DTU.26379811,
      https://gitlab.gbar.dtu.dk/regmo/NetSSE

    Example
    -------
    >>> float(wrap_angle_range(370))
    10.0
    >>> wrap_angle_range(np.array([-190, 190]), start=-180).tolist()
    [170, -170]
    """
    arr = np.asarray(angle)
    if arr.size == 0:
        # Nothing to wrap (the former element-wise loop raised IndexError here).
        return arr.copy()

    # Remove whole turns in a single step. The previous add/subtract-360
    # loop needed O(|angle| / 360) iterations per element and never
    # terminated for NaN or inf inputs.
    with np.errstate(invalid="ignore"):
        turns = np.floor((arr - start) / 360.0)
        wrapped = arr - 360.0 * turns
        # Rounding can land a value exactly on the excluded upper bound
        # (e.g. -1e-20 + 360 == 360.0) or just below ``start``; fold it back.
        wrapped = np.where(wrapped >= start + 360, wrapped - 360.0, wrapped)
        wrapped = np.where(wrapped < start, wrapped + 360.0, wrapped)

    if np.issubdtype(arr.dtype, np.integer):
        # Integer angles shifted by whole turns of 360 are still integers.
        wrapped = wrapped.astype(arr.dtype)

    if arr.size < 2:
        # Single value: return a scalar, as the function always has.
        return wrapped.reshape(-1)[0]
    return wrapped

# ─────────────────────────────────────────────────────────────
# FUNCTION: Extend an array by wrapping its first row/column to the end
# ─────────────────────────────────────────────────────────────
def extend_with_wrap(arr, axis=1, offset=0):
    """Extend an array by repeating its first element/row/column at the end.

    For example, ``extend_with_wrap(np.array([0, 90, 180, 270]), offset=360)``
    returns ``[0, 90, 180, 270, 360]``.

    Parameters
    ----------
    arr : array_like of shape (n,) or with ndim >= 2
        Input array to be extended. Nested lists are accepted.
    axis : int, default 1
        For N-D input (N >= 2), the axis along which the first slice is
        repeated; negative values count from the end. For 2-D input, ``1``
        repeats the first column and ``0`` the first row. Ignored for 1-D
        input.
    offset : float or array_like, default 0
        Value added to the repeated part before concatenation. It must
        broadcast to the shape of the repeated slice (e.g. shape ``(n,)``
        when repeating the first column of an ``(n, m)`` array).

    Returns
    -------
    wrapped_arr : ndarray
        ``arr`` with one extra element/slice appended along the chosen axis.

    Raises
    ------
    ValueError
        If ``arr`` is 0-D or ``axis`` is out of range for ``arr``.

    References
    ----------
    - NetSSE Mounet, R. E. G. (Creator) & Nielsen, U. D. (Supervisor),
      Technical University of Denmark, 27 Jul 2023 DOI: 10.11583/DTU.26379811,
      https://gitlab.gbar.dtu.dk/regmo/NetSSE

    Example
    -------
    >>> wrapped = extend_with_wrap(arr, axis=1, offset=0)
    """
    arr = np.asarray(arr)
    if arr.ndim == 0:
        raise ValueError("extend_with_wrap expects an array with at least 1 dimension")

    if arr.ndim == 1:
        # 1-D input: ``axis`` is ignored and the first element is appended.
        return np.concatenate([arr, np.reshape(arr[0] + offset, (1,))])

    if not -arr.ndim <= axis < arr.ndim:
        # Previously fell through to an UnboundLocalError.
        raise ValueError(
            f"axis {axis} is out of bounds for an array of dimension {arr.ndim}"
        )

    # First slice along ``axis`` (that axis removed), shifted by ``offset``,
    # then given back a length-1 axis so it can be appended along ``axis``.
    first = np.take(arr, 0, axis=axis) + offset
    return np.concatenate([arr, np.expand_dims(first, axis)], axis=axis)

# ─────────────────────────────────────────────────────────────
# FUNCTION: RAO data for the selected DOF
# ─────────────────────────────────────────────────────────────
def rao_data_slected_dof(
    chosen_motion_vector, rao_amp, rao_phase
):
    """Select the RAO amplitude and phase of the requested degrees of freedom.

    Parameters
    ----------
    chosen_motion_vector : sequence of str
        Names of the motions to extract, in the desired output order. Each
        must be one of ``"surge"``, ``"sway"``, ``"heave"``, ``"roll"``,
        ``"pitch"``, ``"yaw"``, which select rows 0-5 of the RAO arrays.
    rao_amp : ndarray of shape (ndof, Nomega, Nbeta)
        RAO amplitude for all degrees of freedom.
    rao_phase : ndarray of shape (ndof, Nomega, Nbeta)
        RAO phase for all degrees of freedom.

    Returns
    -------
    new_rao_amplitude : ndarray of shape (len(chosen_motion_vector), Nomega, Nbeta)
        Float copy of the selected amplitude rows.
    new_rao_phase : ndarray of shape (len(chosen_motion_vector), Nomega, Nbeta)
        Float copy of the selected phase rows.

    Raises
    ------
    ValueError
        If a name in ``chosen_motion_vector`` is not a recognised motion
        (such entries used to be returned silently as all-zero rows).

    Example
    -------
    >>> new_rao_amplitude, new_rao_phase = rao_data_slected_dof(
    ...     ["heave", "pitch"], rao_amp, rao_phase
    ... )
    """
    # Row of the RAO arrays that holds each named degree of freedom.
    dof_row = {"surge": 0, "sway": 1, "heave": 2, "roll": 3, "pitch": 4, "yaw": 5}

    rows = []
    for name in chosen_motion_vector:
        if name not in dof_row:
            raise ValueError(
                f"Unknown motion {name!r}; expected one of {', '.join(dof_row)}"
            )
        rows.append(dof_row[name])
    # Explicit int dtype so an empty selection still indexes correctly.
    rows = np.asarray(rows, dtype=int)

    # Fancy indexing copies; astype(float) keeps the float64 output of the
    # former zero-initialised buffers.
    new_rao_amplitude = np.asarray(rao_amp)[rows].astype(float)
    new_rao_phase = np.asarray(rao_phase)[rows].astype(float)
    return new_rao_amplitude, new_rao_phase

# ─────────────────────────────────────────────────────────────
# FUNCTION: Interpolate the RAO data
# ─────────────────────────────────────────────────────────────
def rao_interp_new_domega(
    rao_amp, rao_phase, om_enc, omega, new_domega
):
    """Interpolate the RAO model onto a uniform grid with spacing ``new_domega``.

    Parameters
    ----------
    rao_amp : array_like of shape (dof, Nomega, Nbeta)
        RAO amplitude [m/m] or [deg/m].
    rao_phase : array_like of shape (dof, Nomega, Nbeta)
        RAO phase [deg].
    om_enc : array_like of shape (Nomega, Nbeta)
        Encounter frequencies [rad/s].
    omega : array_like of shape (Nomega,)
        Angular frequencies [rad/s], strictly increasing (required by
        ``CubicSpline``).
    new_domega : float
        New desired frequency resolution [rad/s]. Must be positive.

    Returns
    -------
    interpolated_frequency : ndarray of shape (num_points,)
        Uniform grid from ``omega[0]`` to ``omega[-1]``. Its spacing equals
        ``new_domega`` when the span is a multiple of it, and is otherwise
        the closest achievable spacing. At least two points are returned.
    interpolated_rao_amplitude : ndarray of shape (dof, num_points, Nbeta)
        Interpolated RAO amplitude [m/m] or [deg/m].
    interpolated_rao_phase : ndarray of shape (dof, num_points, Nbeta)
        Interpolated RAO phase [deg].
    interpolated_encounter_frequency : ndarray of shape (num_points, Nbeta)
        Interpolated encounter frequencies [rad/s].

    Raises
    ------
    ValueError
        If ``new_domega`` is not positive.

    Notes
    -----
    Every (dof, heading) column is interpolated with a cubic spline along
    the frequency axis. NOTE(review): the phase is interpolated as given; if
    it contains +/-360 deg jumps between neighbouring frequencies the spline
    will overshoot there. Consider unwrapping it before calling.

    Example
    -------
    >>> omega_inter, rao_amp_inter, rao_phase_inter, rao_om_enc_inter = (
    ...     rao_interp_new_domega(rao_amp, rao_phase, om_enc, omega, new_domega)
    ... )
    """
    if new_domega <= 0:
        raise ValueError(f"new_domega must be positive, got {new_domega}")

    # Count intervals of size new_domega (rounded, so 35.9999999 counts as 36)
    # and add one to get points. The former int(span / new_domega) produced
    # one point too few, so the grid spacing came out coarser than requested.
    span = omega[-1] - omega[0]
    num_points = max(int(np.round(span / new_domega)) + 1, 2)

    omega_inter = np.linspace(omega[0], omega[-1], num_points)

    # Frequency is axis 1 of the (dof, Nomega, Nbeta) RAO arrays and axis 0
    # of the (Nomega, Nbeta) encounter-frequency array. One spline per array
    # interpolates all columns at once (the om_enc spline used to be rebuilt
    # inside the DOF loop).
    rao_amp_inter = CubicSpline(omega, rao_amp, axis=1)(omega_inter)
    rao_phase_inter = CubicSpline(omega, rao_phase, axis=1)(omega_inter)
    rao_om_enc_inter = CubicSpline(omega, om_enc, axis=0)(omega_inter)

    return (
        omega_inter,
        rao_amp_inter,
        rao_phase_inter,
        rao_om_enc_inter,
    )


# ─────────────────────────────────────────────────────────────
# FUNCTION: Transform encounter frequency to absolute frequency      
# ─────────────────────────────────────────────────────────────
def encounter_frequency_to_absolute_frequency(
    encounter_frequency, direction_rad, forward_speed, g=9.81
):
    """Convert encounter frequencies back to absolute wave frequencies.

    For every heading ``beta`` the relation
    ``omega_e = |omega - A * omega**2|`` with ``A = U * cos(beta) / g``
    is inverted for ``omega``:

    - ``90 deg <= beta <= 270 deg``, or ``A <= 0`` (e.g. zero forward
      speed): the mapping is one-to-one, and the single root of
      ``A*omega**2 - omega + omega_e = 0`` is used.
    - Otherwise (``A > 0``): Region III (``A*omega**2 - omega = omega_e``)
      always contributes one root. When ``omega_e < 1 / (4 A)``, Regions I
      and II (the two roots of ``omega - A*omega**2 = omega_e``) contribute
      two more.

    Parameters
    ----------
    encounter_frequency : array_like of shape (Nomega, Nbeta)
        Encounter frequencies [rad/s], one column per heading.
    direction_rad : sequence of float, length Nbeta
        Headings [rad]. The head/beam test compares against 90 and 270 deg
        (in radians), so values are presumably expected in [0, 2*pi).
        NOTE(review): confirm the convention with callers.
    forward_speed : float
        Ship speed U [m/s].
    g : float, default 9.81
        Gravitational acceleration [m/s^2].

    Returns
    -------
    omega_total : ndarray
        Absolute frequencies, sorted ascending per heading. The shape is
        ``(Nbeta, n)`` when every heading yields the same number ``n`` of
        frequencies. Otherwise it is a 1-D object array of length ``Nbeta``
        holding one 1-D float array per heading.

    Notes
    -----
    Prints the shape of the result.
    """
    encounter_frequency = np.asarray(encounter_frequency, dtype=float)
    beam_low, beam_high = np.deg2rad(90), np.deg2rad(270)
    omega_total = []

    for m, beta in enumerate(direction_rad):
        A = forward_speed / g * np.cos(beta)
        one_to_one = (beam_low <= beta <= beam_high) or A <= 0
        w = []
        for omega_e in encounter_frequency[:, m]:
            if one_to_one:
                # Root of A*w**2 - w + omega_e = 0 in the rationalised form
                # 2*omega_e / (1 + sqrt(1 - 4*A*omega_e)). It is algebraically
                # equal to (1 - sqrt(1 - 4*A*omega_e)) / (2*A), but it stays
                # finite and accurate as A -> 0 (zero speed, beta = 90/270 deg),
                # where the old form gave 0/0 or cancelled to 0.
                w.append(2.0 * omega_e / (1.0 + np.sqrt(1.0 - 4.0 * A * omega_e)))
                continue

            # Region III: A*w**2 - w - omega_e = 0.
            w.append((1.0 + np.sqrt(1.0 + 4.0 * A * omega_e)) / (2.0 * A))

            if omega_e < 1.0 / (4.0 * A):
                # Regions I and II. The discriminant is only evaluated here,
                # where it is positive (no sqrt-of-negative warnings).
                root = np.sqrt(1.0 - 4.0 * A * omega_e)
                w.append(2.0 * omega_e / (1.0 + root))  # Region I
                w.append((1.0 + root) / (2.0 * A))      # Region II

        # Sort the absolute frequencies of this heading from low to high.
        omega_total.append(np.sort(np.asarray(w, dtype=float)))

    if len({len(w) for w in omega_total}) <= 1:
        omega_total = np.array(omega_total)
    else:
        # Following seas can yield a different number of roots per heading.
        # np.array() on such a ragged list raises on modern NumPy, so store
        # the per-heading arrays in an object array instead.
        ragged = np.empty(len(omega_total), dtype=object)
        for m, w in enumerate(omega_total):
            ragged[m] = w
        omega_total = ragged
    print("omega_sorted_total", np.shape(omega_total))

    return omega_total


# ─────────────────────────────────────────────────────────────
# FUNCTION: Extract the simulated ship motion data based on 
# the equalEnergy method
# ─────────────────────────────────────────────────────────────
def extract_sim_ship_data_equalEnergy(data):
    """Unpack a simulated ship-motion data set (equal-energy simulation).

    Reads the sea-state parameters, ship information, sampling settings,
    wave components and clean/noisy motion time series written by
    ``main_sim_equalEnergy.py``. It prints a short summary and returns
    everything as one flat tuple.

    Parameters
    ----------
    data : mapping
        Any object supporting ``data[key]`` lookup (for example a ``dict``).
        Note the key/variable renames: ``"betas_wave"`` -> ``beta_wave``,
        ``"beta_s2d"`` -> ``beta0_2d``, ``"n_timeSeries"`` -> ``Num_Ts``,
        ``"snr"`` -> ``srn``.

    Returns
    -------
    tuple
        33 values in this order::

            S1D, omega_wave, omega, S2D, beta_wave, om_enc,
            Hs_2d, Tp_2d, theta0_2d, beta0_2d, theta0,      # sea state
            U, psi,                                          # ship info
            time, fs, Num_Ts,                                # sampling
            wave_amplitude, epsilon, epsilon_seed,           # wave components
            wavet, surget, swayt, heavet, rollt, pitcht, yawt,
            surget_noisy, swayt_noisy, heavet_noisy,
            rollt_noisy, pitcht_noisy, yawt_noisy,
            srn

    Raises
    ------
    KeyError
        If a required key is missing from ``data``.

    Notes
    -----
    NOTE(review): ``epsilon_seed`` is read from the same ``"epsilon"`` key as
    ``epsilon``, although it was commented as a "random seed". Confirm
    whether a dedicated seed key was intended.
    """
    # Keys in the exact order they are looked up, so a missing key raises the
    # same KeyError as before. "epsilon" deliberately appears twice.
    lookup_keys = (
        # sea state and ship
        "S1D", "omega_wave", "betas_wave", "S2D", "omega", "om_enc",
        "Hs_2d", "Tp_2d", "theta0_2d", "beta_s2d", "U", "theta0", "psi",
        # time series
        "n_timeSeries", "time", "wavet",
        "surget", "swayt", "heavet", "rollt", "pitcht", "yawt",
        "surget_noisy", "swayt_noisy", "heavet_noisy",
        "rollt_noisy", "pitcht_noisy", "yawt_noisy",
        # noise level, sampling and wave components
        "snr", "fs", "epsilon", "epsilon", "wave_amplitude",
    )
    (
        S1D, omega_wave, beta_wave, S2D, omega, om_enc,
        Hs_2d, Tp_2d, theta0_2d, beta0_2d, U, theta0, psi,
        Num_Ts, time, wavet,
        surget, swayt, heavet, rollt, pitcht, yawt,
        surget_noisy, swayt_noisy, heavet_noisy,
        rollt_noisy, pitcht_noisy, yawt_noisy,
        srn, fs, epsilon, epsilon_seed, wave_amplitude,
    ) = [data[key] for key in lookup_keys]

    # Human-readable summary of the loaded simulation.
    print(f"Number of time series: {Num_Ts} realizations")
    print(f"Simulation time: {time} s")
    print(f"Sampling frequency of time serie: {fs} Hz")
    print(f"Ship speed: {U}m/s")
    print(f"sea state parameters:, {Hs_2d}m {Tp_2d}s,{np.rad2deg(beta0_2d)}deg, {np.rad2deg(theta0_2d)}deg")

    sea_state = (
        S1D, omega_wave, omega, S2D, beta_wave, om_enc,
        Hs_2d, Tp_2d, theta0_2d, beta0_2d, theta0,
    )
    ship_info = (U, psi)
    sampling = (time, fs, Num_Ts)
    wave_components = (wave_amplitude, epsilon, epsilon_seed)
    clean_motions = (wavet, surget, swayt, heavet, rollt, pitcht, yawt)
    noisy_motions = (
        surget_noisy, swayt_noisy, heavet_noisy,
        rollt_noisy, pitcht_noisy, yawt_noisy,
    )
    return (
        sea_state + ship_info + sampling + wave_components
        + clean_motions + noisy_motions + (srn,)
    )

# ─────────────────────────────────────────────────────────────
# FUNCTION: Extract the simulated ship motion data based on 
# the doubleSum method
# ─────────────────────────────────────────────────────────────
def extract_sim_ship_data_doubelSum(data):
    """Unpack a simulated ship-motion data set (double-summation simulation).

    Reads the sea-state parameters, ship information, sampling settings,
    wave components and clean/noisy motion time series written by
    ``main_sim_doubelSum.py``. It prints a short summary and returns
    everything as one flat tuple.

    Parameters
    ----------
    data : mapping
        Any object supporting ``data[key]`` lookup (for example a ``dict``).
        Note the key/variable renames: ``"betas_wave"`` -> ``beta_wave``,
        ``"beta_s2d"`` -> ``beta0_2d``, ``"n_timeSeries"`` -> ``Num_Ts``,
        ``"snr"`` -> ``srn``.

    Returns
    -------
    tuple
        33 values: sea state ``(S1D, omega_wave, omega, S2D, beta_wave,
        om_enc, Hs_2d, Tp_2d, theta0_2d, beta0_2d, theta0)``, ship info
        ``(U, psi)``, sampling ``(time, fs, Num_Ts)``, wave components
        ``(wave_amplitude, epsilon, epsilon_seed)``, clean motions
        ``(wavet, surget, swayt, heavet, rollt, pitcht, yawt)``, noisy
        motions ``(surget_noisy, swayt_noisy, heavet_noisy, rollt_noisy,
        pitcht_noisy, yawt_noisy)`` and finally ``srn``.

    Raises
    ------
    KeyError
        If a required key is missing from ``data``.

    Notes
    -----
    NOTE(review): ``epsilon_seed`` is read from the same ``"epsilon"`` key as
    ``epsilon``, although it was commented as a "random seed". Confirm
    whether a dedicated seed key was intended.
    """
    sea_state_keys = (
        "S1D", "omega_wave", "betas_wave", "S2D", "omega", "om_enc",
        "Hs_2d", "Tp_2d", "theta0_2d", "beta_s2d", "U", "theta0", "psi",
    )
    motion_keys = (
        "n_timeSeries", "time", "wavet",
        "surget", "swayt", "heavet", "rollt", "pitcht", "yawt",
        "surget_noisy", "swayt_noisy", "heavet_noisy",
        "rollt_noisy", "pitcht_noisy", "yawt_noisy",
    )
    extra_keys = ("snr", "fs", "epsilon", "epsilon", "wave_amplitude")

    # The three groups are read one after another, which keeps the original
    # lookup order (so the first missing key is the one reported).
    (S1D, omega_wave, beta_wave, S2D, omega, om_enc,
     Hs_2d, Tp_2d, theta0_2d, beta0_2d, U, theta0, psi) = (
        data[key] for key in sea_state_keys
    )
    (Num_Ts, time, wavet,
     surget, swayt, heavet, rollt, pitcht, yawt,
     surget_noisy, swayt_noisy, heavet_noisy,
     rollt_noisy, pitcht_noisy, yawt_noisy) = (
        data[key] for key in motion_keys
    )
    srn, fs, epsilon, epsilon_seed, wave_amplitude = (
        data[key] for key in extra_keys
    )

    # Human-readable summary of the loaded simulation.
    print(f"Number of time series: {Num_Ts} realizations")
    print(f"Simulation time: {time} s")
    print(f"Sampling frequency of time serie: {fs} Hz")
    print(f"Ship speed: {U}m/s")
    print(f"sea state parameters:, {Hs_2d}m {Tp_2d}s,{np.rad2deg(beta0_2d)}deg, {np.rad2deg(theta0_2d)}deg")

    return (
        S1D, omega_wave, omega, S2D, beta_wave, om_enc,
        Hs_2d, Tp_2d, theta0_2d, beta0_2d, theta0,
        U, psi,
        time, fs, Num_Ts,
        wave_amplitude, epsilon, epsilon_seed,
        wavet, surget, swayt, heavet, rollt, pitcht, yawt,
        surget_noisy, swayt_noisy, heavet_noisy, rollt_noisy,
        pitcht_noisy, yawt_noisy,
        srn,
    )