'''
This file contains functions for plotting necessary graphs.

'''

# Importing necessary libraries
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrow
from matplotlib.patches import Polygon
from matplotlib.transforms import Affine2D
# ─────────────────────────────────────────────────────────────
# FUNCTION: Plotting 1D wave spectrum
# ─────────────────────────────────────────────────────────────
def plt_S1D(ang_freq, wave_spectrum, wave_spectra_type, Hs=None, Tp=None):
    """
    Plot a 1D wave spectrum against angular frequency and show the figure.

    Parameters:
    - ang_freq: Values for the x-axis (axis is labelled in rad/s).
    - wave_spectrum: Values for the y-axis, one per entry of ang_freq.
    - wave_spectra_type: Name of the spectrum, inserted into the title.
    - Hs: Optional value shown in the legend as "Hs=<value>m"
      (presumably the significant wave height; confirm with the caller).
    - Tp: Optional value shown in the legend as "Tp=<value>"
      (presumably the peak period; confirm with the caller).

    Returns:
    - None. The plot is drawn on the current pyplot figure and displayed.
    """
    # Only mention the parameters that were actually supplied, so the
    # default call does not produce a "(Hs=Nonem, Tp=None)" legend entry.
    legend_parts = []
    if Hs is not None:
        legend_parts.append(f"Hs={Hs}m")
    if Tp is not None:
        legend_parts.append(f"Tp={Tp}")
    curve_label = f"({', '.join(legend_parts)})" if legend_parts else None

    plt.plot(ang_freq, wave_spectrum, label=curve_label)
    plt.title(f"Wave Spectrum '{wave_spectra_type}' ", fontweight="bold", fontsize=14)
    plt.xlabel(r"Angular frequency  $\omega$ [rad/s]", fontsize=12)
    plt.ylabel(r"Energy density $S(\omega)$ [m$^2\cdot$s/rad]", fontsize=12)

    # Drawing a legend without any labelled artist only triggers a warning,
    # so skip it when there is nothing to show.
    handles, _ = plt.gca().get_legend_handles_labels()
    if handles:
        plt.legend()

    plt.grid(True)
    plt.show()
    return

# ─────────────────────────────────────────────────────────────
# FUNCTION: Plotting 2D wave spectrum
# ─────────────────────────────────────────────────────────────
def plt_S2D_polar(ang_freq, directions, spectrum, tick_step=0.5):
    """
    Draw a filled-contour polar plot of a directional wave spectrum.

    Parameters:
    - ang_freq: Array of wave angular frequencies (radial axis).
    - directions: Array of wave directions in radians (angular axis).
    - spectrum: 2D array of spectrum values, laid out as
      (frequencies x directions) to match the meshgrid built here.
    - tick_step: Spacing between radial (frequency) ticks (default: 0.5).

    The figure is shown and then closed; nothing is returned.
    """
    fig, ax = plt.subplots(subplot_kw={"projection": "polar"})

    # Angle/radius grids with the same layout as `spectrum`.
    angle_grid, radius_grid = np.meshgrid(directions, ang_freq)
    filled = ax.contourf(angle_grid, radius_grid, spectrum, cmap='viridis')

    # Radial ticks start at 0 and advance by tick_step; the stop value is
    # pushed one step past the largest frequency so that it is covered.
    tick_stop = np.max(ang_freq) + tick_step
    radial_ticks = np.arange(0, tick_stop, tick_step)
    radial_labels = ["{:.2f}".format(tick) for tick in radial_ticks]
    ax.set_yticks(radial_ticks)
    ax.set_yticklabels(radial_labels)

    # No title is set on these axes.
    fig.colorbar(filled, label="Wave Energy [m^2.s/rad]")

    plt.show()
    plt.close(fig)

# ────────────────────────────────────────────────────────────
# FUNCTION: Plotting 2D wave spectrum in polar coordinates
# ────────────────────────────────────────────────────────────
def plt_polar_S2D(omega_wave, theta, S2D, ship_heading_rad=0.0, n_ticks=6):
    """
    Show a polar contour plot of a 2D wave spectrum with a ship outline.

    The angular axis is configured compass-style (zero at North, angles
    increasing clockwise). A five-vertex hull outline anchored at the
    origin is drawn in red with its bow at angle ``ship_heading_rad``.

    Parameters:
        omega_wave (np.ndarray): Wave frequencies [rad/s], radial axis.
        theta (np.ndarray): Wave directions [rad], angular axis.
        S2D (np.ndarray): 2D spectrum of shape (len(omega_wave), len(theta)).
        ship_heading_rad (float): Ship heading angle [rad], 0 = North.
        n_ticks (int): Number of radial ticks to display.

    Raises:
        ValueError: If S2D does not have shape (len(omega_wave), len(theta)).

    Note:
        All existing pyplot figures are closed before the new one is created.
    """
    required_shape = (len(omega_wave), len(theta))
    if S2D.shape != required_shape:
        raise ValueError("S2D shape must be (len(omega_wave), len(theta))")

    plt.close('all')
    fig, ax = plt.subplots(subplot_kw={'projection': 'polar'}, figsize=(8, 6))

    # Spectrum as filled contours
    filled = ax.contourf(theta, omega_wave, S2D, cmap='viridis')

    # Evenly spaced radial ticks spanning the frequency range
    freq_low = np.min(omega_wave)
    freq_high = np.max(omega_wave)
    radial_ticks = np.linspace(freq_low, freq_high, n_ticks)
    ax.set_yticks(radial_ticks)
    ax.set_yticklabels([f"{tick:.2f}" for tick in radial_ticks])

    # Compass-style orientation of the angular axis
    ax.set_theta_zero_location('N')
    ax.set_theta_direction(-1)

    # Hull outline in Cartesian (x, y), bow along +y, stern on the origin.
    # Its size is tied to the largest plotted frequency.
    hull_len = freq_high * 0.81
    half_beam = (hull_len * 0.3) / 2
    shoulder_y = hull_len * 0.4
    outline_xy = np.array([
        [0.0, hull_len],           # bow
        [-half_beam, shoulder_y],  # left front
        [-half_beam, 0.0],         # left back
        [half_beam, 0.0],          # right back
        [half_beam, shoulder_y],   # right front
    ])

    # Rotating by (heading - pi/2) moves the bow from +y to the polar
    # angle equal to the heading.
    turn = Affine2D().rotate(ship_heading_rad - np.pi/2)
    turned_xy = turn.transform(outline_xy)
    xs = turned_xy[:, 0]
    ys = turned_xy[:, 1]

    # Cartesian -> polar for every vertex
    vertex_r = np.hypot(xs, ys)
    vertex_theta = np.arctan2(ys, xs)

    # Repeat the first vertex at the end so the outline is closed
    closed_theta = np.concatenate((vertex_theta, vertex_theta[:1]))
    closed_r = np.concatenate((vertex_r, vertex_r[:1]))
    ax.plot(closed_theta, closed_r, color='red', linewidth=2, label='Ship')

    # Colorbar, title and legend
    fig.colorbar(filled, ax=ax, label='Wave Energy')
    ax.set_title("2D Wave Spectrum with Ship Orientation", va='bottom')
    ax.legend(loc='upper right')
    plt.tight_layout()
    plt.show()

# ────────────────────────────────────────────────────────────
# FUNCTION: Plotting the transfer function amplitude or phase of the
# ith DOF and jth relative direction
# ────────────────────────────────────────────────────────────
def plt_TF_inter(omega_TF, omega, TF, TF_new, beta_inedx, type="amplitude"):
    """
    Compare an original and an interpolated transfer function for one direction.

    Parameters:
    - omega_TF: Frequencies of the original transfer function (x-axis).
    - omega: Frequencies of the interpolated transfer function (x-axis).
    - TF: 2D array; column `beta_inedx` is plotted as "Original".
    - TF_new: 2D array; column `beta_inedx` is plotted as "Interpolated".
    - beta_inedx: Column index selecting the direction in both arrays.
    - type: "amplitude" labels the plot "Amplitude"; any other value
      labels it "Phase".

    A new figure is created and shown; nothing is returned.
    """
    quantity = "Amplitude" if type == "amplitude" else "Phase"

    plt.figure()
    curves = (
        (omega_TF, TF, "Original"),
        (omega, TF_new, "Interpolated"),
    )
    for frequencies, table, legend_text in curves:
        plt.plot(frequencies, table[:, beta_inedx], label=legend_text)

    plt.title(f" Motion TF {quantity}", fontsize=14)
    plt.xlabel("Angular frequency [rad/s]", fontsize=12)
    plt.ylabel(quantity, fontsize=12)
    plt.grid()
    plt.legend()
    plt.show()

def plt_TF_inter_Eq_Method(omega_TF, omega, TF, TF_new, chosen_beta_deg, Tf_beta, beta_wave, type="amplitude"):
    """
    Compare an original and an interpolated transfer function at one direction.

    The column plotted from each array is the one whose direction is closest
    to `chosen_beta_deg`.

    Parameters:
    - omega_TF: Frequencies of the original transfer function (x-axis).
    - omega: Frequencies of the interpolated transfer function (x-axis).
    - TF: 2D array with one column per entry of `Tf_beta`.
    - TF_new: 2D array with one column per entry of `beta_wave`.
    - chosen_beta_deg: Scalar direction to plot, in degrees.
    - Tf_beta: Directions [rad] associated with the columns of `TF`.
    - beta_wave: Directions [rad] associated with the columns of `TF_new`.
    - type: "amplitude" labels the plot "Amplitude"; any other value
      labels it "Phase".

    A new figure is created and shown; nothing is returned.
    """
    quantity = "Amplitude" if type == "amplitude" else "Phase"

    # Pick the nearest tabulated direction on both grids. An exact float
    # comparison can miss the target entirely and would then select no
    # column at all, so the same nearest-match rule is used for both arrays.
    target_rad = np.deg2rad(chosen_beta_deg)
    beta_index_TF = np.argmin(np.abs(np.asarray(Tf_beta) - target_rad))
    beta_index_TFnew = np.argmin(np.abs(np.asarray(beta_wave) - target_rad))

    plt.figure()
    plt.plot(omega_TF, TF[:, beta_index_TF], label="Original")
    plt.plot(omega, TF_new[:, beta_index_TFnew], label="Interpolated")
    plt.title(f" Motion TF {quantity}", fontsize=14)
    plt.xlabel("Angular frequency [rad/s]", fontsize=12)
    plt.ylabel(f"{quantity}", fontsize=12)
    plt.grid()
    plt.legend()
    plt.show()


# ────────────────────────────────────────────────────────────
# FUNCTION: Plotting cumulative energy distribution function P(θ)
# ────────────────────────────────────────────────────────────
def plt_cumulative_energy_distribution(theta, P_theta, theta_inter, D, energy_steps):
    """
    Plot the cumulative energy distribution P(θ) with its marked points.

    Parameters:
    -----------
    - theta: array-like, wave directions in radians (shown in degrees).
    - P_theta: array-like, cumulative distribution values, one per theta.
    - theta_inter: array-like, directions in radians of the marked points.
    - D: indexable whose first entry `D[0]` is plotted against theta as the
      directional distribution.
    - energy_steps: array-like, y-values of the marked points, one per
      entry of theta_inter.

    The plot is drawn on the current pyplot figure and shown.
    """
    theta_deg = np.rad2deg(theta)
    marker_deg = np.rad2deg(theta_inter)

    # Cumulative curve, the marked intersection points, then D[0] on top
    plt.plot(theta_deg, P_theta, label="P(θ)")
    plt.plot(marker_deg, energy_steps, "ro", label="θ intersections")
    plt.plot(theta_deg, D[0], "g--", label="Directional Distribution")

    plt.title("Cumulative Energy Distribution Function")
    plt.xlabel("Wave direction θ [°]")
    plt.ylabel("Cumulative Energy Distribution Function P(θ)")
    plt.grid()
    plt.legend()
    plt.show()


# ─────────────────────────────────────────────────────────────
# FUNCTION -AKF: plotting the simulated and estimated motion
#  data
# ─────────────────────────────────────────────────────────────
# Note: matplotlib raises "x and y must have same first dimension" when the shapes are e.g. (10000,) and (1, 10000)
def plot_sim_and_est_motion_data(
    motion_time, motion_data, predicted_motion_data, dof_index
):
    """
    Plot a motion signal and its predicted counterpart against time.

    Parameters:
    - motion_time: Time samples for the x-axis.
    - motion_data: Motion signal, plotted as "motion data".
    - predicted_motion_data: Predicted signal, plotted as
      "predicted motion data".
    - dof_index: Zero-based index; `dof_index + 1` is inserted in the title.

    Inputs given as a single-row 2D array of shape (1, N), where N is the
    number of time samples, are unwrapped to 1D before plotting. Matplotlib
    would otherwise reject them against a 1D time axis of length N.
    Any other input is passed to matplotlib unchanged.

    The plot is drawn on the current pyplot figure and shown.
    """
    n_samples = np.size(motion_time)

    def _as_series(values):
        # Unwrap only the (1, N) row-vector case; leave everything else as is.
        shape = np.shape(values)
        if len(shape) == 2 and shape[0] == 1 and shape[1] == n_samples:
            return np.asarray(values)[0]
        return values

    time_axis = _as_series(motion_time)
    plt.plot(time_axis, _as_series(motion_data), label="motion data")
    plt.plot(time_axis, _as_series(predicted_motion_data), label="predicted motion data")
    plt.title(f" Theorical and estimated motion data of the {dof_index+1} ", fontsize=14)
    plt.xlabel("Time (s)", fontsize=12)
    plt.ylabel("Displacement(m,rad)", fontsize=12)
    plt.grid()
    plt.legend()
    plt.show()

# ─────────────────────────────────────────────────────────────
# FUNCTION -AKF: plotting the simulated and estimated sea state
#  data over time
# ─────────────────────────────────────────────────────────────
def plot_estimated_sea_state_over_time(motion_time, estimated_ss, chosen_ss, simulated_ss, case):
    """
    Plot an estimated sea-state parameter over time against its simulated value.

    Parameters:
    - motion_time: Time samples for the x-axis.
    - estimated_ss: Estimated values, one per time sample (dashed line).
    - chosen_ss: Name of the plotted quantity, used as the y-axis label.
    - simulated_ss: Single simulated value, drawn as a constant line with
      one point per entry of estimated_ss.
    - case: Not used by this function; kept so existing calls keep working.

    The plot is drawn on the current pyplot figure and shown.
    """
    # Constant reference line: the simulated value repeated for every sample.
    simulated_ss_over_time = np.full(len(estimated_ss), simulated_ss, dtype=float)

    plt.plot(motion_time, estimated_ss, linestyle='--', linewidth=2, label='Estimated')
    plt.plot(motion_time, simulated_ss_over_time, linewidth=2.5, label='Simulated')
    # No title is set on this figure.
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.xlabel("Time [s]", fontsize=16)
    plt.ylabel(f"{chosen_ss}", fontsize=16)
    plt.grid(True)
    plt.legend(fontsize=14)
    plt.show()

# ─────────────────────────────────────────────────────────────
# FUNCTION -AKF: plotting the simulated and estimated 1D wave
# spectrum
# ─────────────────────────────────────────────────────────────
def plot_simu_and_est_S1D(omega_inter, S1D, original_omega, original_S1D):
    """
    Overlay an estimated and a simulated 1D wave spectrum and show the plot.

    Parameters:
    - omega_inter: Frequencies of the estimated spectrum (x-axis).
    - S1D: Estimated spectrum values, plotted as "Estimated".
    - original_omega: Frequencies of the simulated spectrum (x-axis).
    - original_S1D: Simulated spectrum values, plotted as "Simulated".

    The plot is drawn on the current pyplot figure; no title is set.
    """
    series = (
        (omega_inter, S1D, "Estimated"),
        (original_omega, original_S1D, "Simulated"),
    )
    for frequencies, values, legend_text in series:
        plt.plot(frequencies, values, linewidth=2, label=legend_text)

    # Axis labels, then tick label sizes
    plt.xlabel("Frequency [rad/s]", fontsize=16)
    plt.ylabel("Spectrum [m^2.s/rad]", fontsize=16)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)

    plt.legend(fontsize=14)
    plt.grid(True)
    plt.show()
