"""
Main script for ship motion generation.

Author: Ryane Bourkaib
Date: 2025-05-21
Description: This script simulates ship motions in a stochastic sea environment using an equal-energy approach.
It generates time series for wave elevation and six-degree-of-freedom (DOF) ship motions (surge, sway, heave, roll, pitch, yaw)
based on a specified wave spectrum, ship speed, and transfer functions (RAOs).
Key Steps:
    - Defines wave spectrum parameters (significant wave height, peak period).
    - Computes 1D and 2D wave spectra and extracts sea state parameters.
    - Builds ship motion transfer functions (RAOs) — random placeholders in this example, since the real RAOs are confidential — and interpolates them for the selected ship speed and wave directions.
    - Generates random phases for the wave components.
    - Calculates wave elevation
    - Simulates time series for wave elevation and ship motions, with and without added noise.
    - Stores all relevant simulation data in a dictionary and saves it to a .npy file.
Inputs:
    None (parameters are set within the script).
Outputs:
    - Saves a dictionary containing simulation results and parameters to a .npy file in the results directory
      (results/data/shipMotion_EqualEnergy).
    
Example:
    Run this script directly:
        python main_sim_EqualEnergy.py
"""

# ─────────────────────────────────────────────────────────────
# IMPORTS
# ─────────────────────────────────────────────────────────────
import os
import math
import numpy as np
import matplotlib.pyplot as plt
from numpy.random import default_rng
from scipy.interpolate import interp1d
from scipy.interpolate import RegularGridInterpolator


import src.tools.env_cond as env_cond
import src.analysis.seaState_func as seaState_func
import results.plots.ploting as ploting
import src.tools.general_func as general_func
from src.simulations.shipResp_func import  ship_motion_equal_energy
# ─────────────────────────────────────────────────────────────
# MAIN FUNCTION
# ─────────────────────────────────────────────────────────────

# Names of the six rigid-body DOFs, in the order of the first axis of the TF
# arrays. Each name is also the prefix of its stored time-series key
# (e.g. "surge" -> "surget", "surget_noisy").
_DOF_NAMES = ("surge", "sway", "heave", "roll", "pitch", "yaw")


def _extend_tf_periodic(TF_beta, TF_amps, TF_phases, TF_enc_om):
    """Wrap the TF direction grid and close it periodically if needed.

    ``TF_beta`` is first passed through ``general_func.wrap_angle_range``. If
    its first and last directions are not the same angle (mod 360 deg), every
    grid is extended along its direction axis with
    ``general_func.extend_with_wrap``. The intent is that interpolation also
    works between the last and the first tabulated directions.

    Args:
        TF_beta: relative wave directions of the TF grid [deg], shape (Nbeta,).
        TF_amps: TF amplitudes, shape (DOF, Nomega_TF, Nbeta).
        TF_phases: TF phases [deg], shape (DOF, Nomega_TF, Nbeta).
        TF_enc_om: encounter frequencies [rad/s], shape (Nomega_TF, Nbeta).

    Returns:
        Tuple ``(TF_beta, TF_amps, TF_phases, TF_enc_om)``, extended along the
        direction axis when the grid was not already closed.
    """
    TF_beta = general_func.wrap_angle_range(TF_beta)
    if TF_beta[0] % 360 == TF_beta[-1] % 360:
        # Grid already closed: nothing to extend.
        return TF_beta, TF_amps, TF_phases, TF_enc_om

    TF_beta = general_func.extend_with_wrap(TF_beta, axis=0, offset=360)
    TF_amps = np.array(
        [general_func.extend_with_wrap(amp, axis=1) for amp in TF_amps]
    )
    TF_phases = np.array(
        [general_func.extend_with_wrap(pha, axis=1) for pha in TF_phases]
    )
    # The encounter-frequency grid is shared by all DOFs, so extend it once.
    TF_enc_om = np.asarray(general_func.extend_with_wrap(TF_enc_om, axis=1))
    return TF_beta, TF_amps, TF_phases, TF_enc_om


def _interp_tf_grid(omega_TF, TF_beta, values, query_points):
    """Linearly interpolate a (len(omega_TF), len(TF_beta)) grid.

    Queries outside the grid raise ``ValueError`` (``bounds_error=True``).
    Returns the interpolated values as a column of shape (N, 1), where N is
    the number of rows of ``query_points`` (columns: omega, beta).
    """
    interpolator = RegularGridInterpolator(
        (omega_TF, TF_beta), values, method="linear", bounds_error=True
    )
    return np.expand_dims(interpolator(query_points), axis=1)


def _interp_tf_phase(omega_TF, TF_beta, phase_grid, query_points):
    """Circularly interpolate a TF phase grid [rad] at ``query_points``.

    Interpolating raw phase values linearly goes "the long way round" when
    neighbouring grid values sit on opposite sides of the 0/2*pi cut. For
    example, 350 deg and 10 deg would interpolate to 180 deg instead of 0 deg.
    This helper interpolates the unit phasor (cos, sin) instead and recovers
    the angle with ``arctan2``. The result is an (N, 1) array in (-pi, pi].

    NOTE(review): assumes ``ship_motion_equal_energy`` uses the phase only
    inside periodic (cos/sin) terms, so the (-pi, pi] range is equivalent to
    the original [0, 2*pi) range — confirm.
    """
    cos_part = _interp_tf_grid(omega_TF, TF_beta, np.cos(phase_grid), query_points)
    sin_part = _interp_tf_grid(omega_TF, TF_beta, np.sin(phase_grid), query_points)
    return np.arctan2(sin_part, cos_part)


def main():
    """Simulate wave elevation and 6-DOF ship motions with the equal-energy method.

    Steps:
        1. Build a JONSWAP 1D spectrum (Hs, Tp), spread it into a 2D spectrum
           and print the sea-state parameters recovered from both.
        2. Discretise the wave directions with ``env_cond.equal_energy_theta``
           and convert them to directions relative to the ship heading ``psi``.
        3. Build a placeholder RAO model (random amplitudes/phases, because the
           real TFs are confidential). Close it periodically in direction and
           interpolate it at the simulation (omega, beta) pairs.
        4. Draw uniform random phases and sum the components, advanced at the
           encounter frequency, into wave-elevation and motion time series
           (clean and noisy).
        5. Save all inputs and results as a dict in a .npy file under
           ``results/data/shipMotion_EqualEnergy``; the directory is created
           if missing.

    Returns:
        dict: the saved simulation data (same keys as in the .npy file). To
        read the file back, use ``np.load(path, allow_pickle=True).item()``.
    """
    # ── 1D wave spectrum ──────────────────────────────────────
    g = 9.81  # gravity [m/s^2]
    Hs = 3.0  # significant wave height [m]
    Tp = 9.0  # peak period [s]
    domega_wave = 0.001  # rad/s
    omega_wave_end = 2  # rad/s
    omega_wave = np.arange(domega_wave, omega_wave_end, domega_wave)
    Nomega_wave = len(omega_wave)
    print(f"Number of wave frequencies: {Nomega_wave}")
    S1D = env_cond.JONSWAP_DNV(Tp, Hs, omega_wave, gamma="standard")
    # ploting.plt_S1D(omega_wave, S1D, "JONSWAP", Hs=Hs, Tp=Tp)

    # Sea-state parameters recovered from S1D (sanity check of the input).
    # The bandwidth is stored as `eps_bw`, not `epsilon`, so it is not
    # clobbered by the random phases generated further down.
    Hs_1d, Tp_1d, Tm01, Tm02, Tm24, TE, Sm02, eps_bw = (
        seaState_func.spec1d_to_sea_state(S1D, omega_wave)
    )
    print("🌊 Sea State Parameters from 1D Spectrum")
    print(f"Significant wave height Hs        : {Hs_1d:.3f} m")
    print(f"Peak period Tp                    : {Tp_1d:.3f} s")
    print(f"Mean wave period Tm01             : {Tm01:.3f} s")
    print(f"Zero up-crossing period Tm02      : {Tm02:.3f} s")
    print(f"Crest period Tm24                 : {Tm24:.3f} s")
    print(f"Wave energy period TE             : {TE:.3f} s")
    print(f"Wave steepness Sm02               : {Sm02:.3f} [-]")
    print(f"Spectral bandwidth ε              : {eps_bw:.3f} [-]")

    # ── 2D wave spectrum ──────────────────────────────────────
    # Spreading parameter `s`, passed to both the 2D spectrum and the
    # equal-energy direction discretisation so the two stay consistent.
    s_spread = 2
    theta = np.deg2rad(np.linspace(-180, 180, 1000))  # wave directions [rad]
    # NOTE: Assuming ship heading psi is aligned with the wave direction theta0
    theta0 = np.deg2rad(0)  # mean wave direction [rad]
    psi = np.deg2rad(0)  # ship heading [rad]
    S2D = env_cond.wave_spec_1dto2d(
        S1D, theta, theta0, s=s_spread, max_theta=np.pi / 2
    )
    # ploting.plt_polar_S2D(omega_wave, theta, S2D, ship_heading_rad=psi)

    Hs_2d, Tp_2d, theta0_2d, s2d_to_s1d, beta_s2d = seaState_func.spec2d_to_sea_state(
        S2D, omega_wave, theta, psi
    )
    print("🌊 Sea State Parameters from 2D Spectrum")
    print(f"Significant wave height Hs        : {Hs_2d:.3f} m")
    print(f"Peak period Tp                    : {Tp_2d:.3f} s")
    print(f"Mean wave direction θ0            : {np.rad2deg(theta0_2d):.3f} °")

    # ── Equal-energy direction discretisation ─────────────────
    # Each simulated frequency is paired with exactly one of these directions.
    theta_num = 500  # number of wave directions
    theta_inter, P_theta, D, energy_steps = env_cond.equal_energy_theta(
        theta, theta0, theta_num, s=s_spread
    )
    # ploting.plt_cumulative_energy_distribution(theta, P_theta, theta_inter,
    #                                            D, energy_steps)

    # Relative wave direction = angle between ship heading and wave direction.
    beta_wave, _ = seaState_func.calculate_relative_dirc(
        np.rad2deg(theta_inter), np.rad2deg(psi), "from"
    )
    beta_wave = np.deg2rad(beta_wave)  # [rad]
    # NOTE(review): omega (below) is sorted ascending. If theta_inter is also
    # sorted, the index-wise pairing couples frequency and direction
    # systematically. Consider shuffling the directions — confirm against the
    # method's reference.

    # ── Placeholder RAO model ─────────────────────────────────
    # The real TFs (amplitude [(m,deg)/m], phase [deg]) are confidential; a
    # random model is used for this example.
    rng = default_rng()
    omega_TF = np.arange(0.1, 2.0, 0.01)  # TF frequencies [rad/s]
    TF_beta = np.arange(0, 360, 10)  # TF relative wave directions [deg]
    DOF = len(_DOF_NAMES)  # number of degrees of freedom
    U = 9  # ship forward speed [m/s]
    TF_amps = rng.random((DOF, len(omega_TF), len(TF_beta)))  # random amplitudes
    TF_phases = rng.random((DOF, len(omega_TF), len(TF_beta))) * 360  # random phases [deg]
    # Encounter frequency grid: omega_e = omega - omega^2 * U / g * cos(beta).
    TF_enc_om = (
        omega_TF[:, None]
        - omega_TF[:, None] ** 2 * U / g * np.cos(np.deg2rad(TF_beta))[None, :]
    )
    print("🚢 RAO Motion Model Parameters shape:")
    print(f"TF frequencies: {omega_TF.shape}")
    print(f"TF relative wave directions: {TF_beta.shape}")
    print(f"TF amplitudes: {TF_amps.shape}")
    print(f"TF phases: {TF_phases.shape}")
    print(f"TF encounter frequencies: {TF_enc_om.shape}")

    # All DOFs share the same encounter-frequency grid, so a single grid is
    # wrapped and interpolated for all of them.
    TF_beta, TF_amps, TF_phases, TF_enc_om = _extend_tf_periodic(
        TF_beta, TF_amps, TF_phases, TF_enc_om
    )

    # Convert angles to rad (directions, phases, rotational amplitudes).
    TF_beta = np.deg2rad(TF_beta)
    TF_phases = np.deg2rad(TF_phases)
    TF_amps = np.asarray(TF_amps, dtype=float)
    TF_amps[3:] = np.deg2rad(TF_amps[3:])  # roll/pitch/yaw amplitudes
    print("🚢 RAO Motion Model size:")
    print(f"Number of DOFs: {TF_amps.shape[0]}")
    print(f"Number of TF frequencies: {len(omega_TF)}")
    print(f"Shape of TF encounter frequency grid: {TF_enc_om.shape}")
    print(f"Number of TF relative wave direction: {len(TF_beta)}")

    # ── Simulation frequency grid ─────────────────────────────
    # The equal-energy method needs one direction per frequency, so the
    # number of frequencies is taken from the direction set itself.
    Nomega = beta_wave.size
    ang_low = max(omega_wave[0], omega_TF[0])
    ang_high = min(omega_wave[-1], omega_TF[-1])
    omega = np.linspace(ang_low, ang_high, Nomega)
    # Bin widths; the last bin reuses the previous spacing.
    domega = np.append(np.diff(omega), np.abs(omega[-1] - omega[-2]))
    print(f"Number of angular frequencies: {len(omega)}")

    S1D_interpolator = interp1d(omega_wave, S1D, kind="linear", bounds_error=False)
    S1D_inter = S1D_interpolator(omega)  # 1D wave spectrum [m^2.s/rad]
    print("🌊 Interpolated Wave Spectra:")
    print(f"Interpolated 1D wave spectrum shape: {S1D_inter.shape}")

    # ── Interpolate TFs at the (omega_i, beta_i) pairs ────────
    query_points = np.column_stack((omega, beta_wave))
    tf_amp_new = {
        name: _interp_tf_grid(omega_TF, TF_beta, TF_amps[k], query_points)
        for k, name in enumerate(_DOF_NAMES)
    }
    tf_phase_new = {
        name: _interp_tf_phase(omega_TF, TF_beta, TF_phases[k], query_points)
        for k, name in enumerate(_DOF_NAMES)
    }
    TF_omc_enc_new = _interp_tf_grid(omega_TF, TF_beta, TF_enc_om, query_points)
    print("🚢 Interpolated Transfer Functions:")
    print(f"Shape of TF amplitude of jth DOF: {tf_amp_new['surge'].shape}")
    print(f"Shape of TF phase of jth DOF: {tf_phase_new['surge'].shape}")
    print(f"Shape of TF encounter frequency of jth DOF: {TF_omc_enc_new.shape}")

    # ── Wave components ───────────────────────────────────────
    wave_amp = np.sqrt(2 * S1D_inter * domega).reshape((-1, 1))  # [m], (Nfreq, 1)
    kappa = omega**2 / g  # deep-water wave number [rad/m]
    om_enc = (omega - kappa * U * np.cos(beta_wave)).reshape((-1, 1))  # [rad/s], (Nfreq, 1)
    print("🌊 Wave Parameters reshaped (Nfreq, 1):")
    print(f"Encounter frequency shape:      {om_enc.shape}")
    print(f"Wave frequency shape:           {omega.shape}")
    print(f"Wave direction shape:           {beta_wave.shape}")
    print(f"Wave amplitude shape:           {wave_amp.shape}")

    # ── Time-series settings ──────────────────────────────────
    T = 1000  # simulation time [s]
    fs = 10  # sampling frequency [Hz]
    n_timeSeries = 1  # number of time series to simulate
    time = np.arange(0, T, 1 / fs).reshape((1, -1))  # (1, Nt)

    # On a uniform omega grid, the summed signal repeats with period
    # 2*pi/domega. It only looks random for records shorter than that.
    a = 2 * np.pi / (omega[1] - omega[0])
    print("max T shoud be chosen", a)
    if T < a:
        print("The signal is stochastic")
    else:
        print("The signal is not stochastic")

    snr = 20  # signal-to-noise ratio [dB], used for every noisy series
    Nt = time.shape[1]

    wavet = np.zeros((n_timeSeries, Nt))  # wave elevation [m]
    motions = {name: np.zeros((n_timeSeries, Nt)) for name in _DOF_NAMES}
    motions_noisy = {name: np.zeros((n_timeSeries, Nt)) for name in _DOF_NAMES}

    # Random phases, uniform on [0, 2*pi), drawn once for all realisations.
    # This makes the stored `epsilon` match the phases actually used.
    epsilon = rng.uniform(0, 2 * np.pi, size=(Nomega, n_timeSeries))

    for seed in range(n_timeSeries):
        epsilon_seed = np.expand_dims(epsilon[:, seed], axis=1)  # (Nfreq, 1)

        # Wave elevation, with components advanced at the encounter frequency.
        phase = np.outer(om_enc, time) + epsilon_seed
        wavet[seed, :] = np.sum(wave_amp * np.cos(phase), axis=0)

        # Clean then noisy response per DOF (same call order as before).
        for name in _DOF_NAMES:
            for add_noise, target in ((False, motions), (True, motions_noisy)):
                target[name][seed, :] = ship_motion_equal_energy(
                    wave_amp,
                    epsilon_seed,
                    om_enc,
                    time,
                    seed,
                    tf_amp_new[name],
                    tf_phase_new[name],
                    AddNoise=add_noise,
                    snr=snr,
                )

    # ── Collect results ───────────────────────────────────────
    data_dict = {
        # Sea state information
        "S1D": S1D,                    # 1D wave spectrum [m²·s/rad]
        "omega_wave": omega_wave,      # wave frequencies [rad/s]
        "theta": theta,                # discretised wave directions [rad]
        "theta0": theta0,              # mean wave direction [rad]
        "S2D": S2D,                    # 2D wave spectrum [m²·s/rad]
        "omega": omega,                # simulation angular frequencies [rad/s]
        "betas_wave": beta_wave,       # relative wave directions [rad]
        "om_enc": om_enc,              # encounter frequencies [rad/s]
        "Hs_2d": Hs_2d,                # Hs from 2D spectrum [m]
        "Tp_2d": Tp_2d,                # Tp from 2D spectrum [s]
        "theta0_2d": theta0_2d,        # mean wave direction from 2D spectrum [rad]
        "beta_s2d": beta_s2d,          # presumably mean direction relative to psi [rad] — confirm
        # Ship & environment
        "U": U,                        # ship forward speed [m/s]
        "psi": psi,                    # ship heading [rad]
        # Time & sampling
        "T": T,                        # total simulation time [s]
        "fs": fs,                      # sampling frequency [Hz]
        "time": time,                  # time vector, shape (1, Nt)
        "n_timeSeries": n_timeSeries,  # number of time series
        "snr": snr,                    # signal-to-noise ratio [dB]
        # Wave components
        "wave_amplitude": wave_amp,    # component amplitudes [m]
        "epsilon": epsilon,            # random phases [rad], (Nfreq, n_timeSeries)
        # Wave elevation
        "wavet": wavet,                # [m]
    }
    # Clean motions: surget, swayt, heavet [m]; rollt, pitcht, yawt [rad].
    data_dict.update({f"{name}t": motions[name] for name in _DOF_NAMES})
    # Noisy motions: <dof>t_noisy.
    data_dict.update(
        {f"{name}t_noisy": motions_noisy[name] for name in _DOF_NAMES}
    )

    # ── Save (portable path; directory created if missing) ────
    simulation_number = 1
    out_dir = os.path.join("results", "data", "shipMotion_EqualEnergy")
    os.makedirs(out_dir, exist_ok=True)
    file_name = (
        f"sea_Hs_{np.round(Hs, 2)}_Tp_{np.round(Tp, 2)}"
        f"_beta_{np.round(np.rad2deg(beta_s2d), 2)}"
        f"_U_{np.round(U, 2)}_sim_{simulation_number}.npy"
    )
    np.save(os.path.join(out_dir, file_name), data_dict)
    return data_dict
    
        
    
        
    
if __name__ == "__main__":
    # Script entry point: run the simulation only when executed directly,
    # never as a side effect of importing this module.
    main()
    