import pandas as pd
import matplotlib.pyplot as plt
from itertools import cycle
import numpy as np
import time
import datetime
from scipy.optimize import curve_fit


################## settings
# Named colour palette. The plot loops walk colors.values() in insertion order,
# so the order of the entries decides which series gets which colour.
colors = dict(
    blue='#4477AA',
    cyan='#66CCEE',
    green='#228833',
    yellow='#CCBB44',
    orange='#EE7733',
    red='#EE6677',
    purple='#AA3377',
    grey='#BBBBBB',
)
color_cycle = cycle(colors.values())

# One single-character marker code per entry.
marker_styles = list('osD^v<>p*hH+xd_|')
marker_cycle = cycle(marker_styles)

# Named dash patterns as (name, (offset, dash_sequence)) pairs.
# NOTE(review): 'dotted' and 'densely dotted' carry the same pattern here.
line_styles = [
    # dotted family
    ('loosely dotted', (0, (1, 10))),
    ('dotted', (0, (1, 1))),
    ('densely dotted', (0, (1, 1))),
    # dashed family
    ('long dash with offset', (5, (10, 3))),
    ('loosely dashed', (0, (5, 10))),
    #('dashed',                (0, (5, 5))),
    ('densely dashed', (0, (5, 1))),
    # dash-dot family
    ('loosely dashdotted', (0, (3, 10, 1, 10))),
    ('dashdotted', (0, (3, 5, 1, 5))),
    ('densely dashdotted', (0, (3, 1, 1, 1))),
    # dash-dot-dot family
    ('dashdotdotted', (0, (3, 5, 1, 5, 1, 5))),
    ('loosely dashdotdotted', (0, (3, 10, 1, 10, 1, 10))),
    ('densely dashdotdotted', (0, (3, 1, 1, 1, 1, 1))),
]
line_cycle = cycle(line_styles)
################################################################################
### read the data (flat geometry)
# kb and T divide w_chain in the work panels, whose y-label reads "[kT]".
# NOTE(review): presumably the Boltzmann constant in J/K (rounded) and a
# temperature in K -- confirm.
kb = 1.38e-23
T = 298
df_flat = pd.read_excel("./fix_ch_flat20250503_202800.xlsx")
variable_name = "Csalt"  # column whose distinct values define the plotted series
geometry = 'flat'
ct = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')  # current time, already a string
ct_str = str(ct)

# One slot per distinct value of the scanned column; the fit loop fills them.
n_salt_flat = len(np.unique(df_flat[variable_name].values))
shifts_flat = np.zeros(n_salt_flat)
coeffs_flat = np.zeros(n_salt_flat)
def osb_function(x, const, const1):
    """Return ``const1 + 200 * const * x ** 0.5``.

    Fit model used with ``curve_fit`` in this script: a square-root law in
    ``x`` with the fitted prefactor ``const`` (multiplied by the fixed factor
    200) and the fitted offset ``const1``.
    """
    root_term = 200*const*x**(0.5)
    return const1 + root_term

### fit osb_function once per salt value (flat)
# NOTE(review): values[1:] drops the first *row* before np.unique, not the first
# distinct value -- confirm that this is the intended selection.
for i, csalt_value in enumerate(np.unique(df_flat[variable_name].values[1:])):
    rows = df_flat.loc[df_flat["Csalt"] == csalt_value]
    # The first three rows of each series are left out of the fit
    # (original note: exclude case m=1).
    x_data = rows["p2_charge"].iloc[3:]
    y_data = rows["delta1lin"].iloc[3:]

    # Least-squares fit of osb_function; the covariance is not used afterwards.
    params, covariance = curve_fit(osb_function, x_data, y_data)
    a_fit, b_fit = params
    print(f"flat {csalt_value}  Fitted parameters: a = {a_fit}, b = {b_fit}")
    coeffs_flat[i] = a_fit
    shifts_flat[i] = b_fit

######################################################################################################
##################################### read cylindrical part
df_cyl = pd.read_excel("./fix_ch_results_cylindrical20250503_202657.xlsx")
variable_name = "Csalt"
geometry = 'cylindrical'

# One slot per distinct value of the scanned column; the loop below fills them.
n_salt_cyl = len(np.unique(df_cyl[variable_name].values))
shifts_cyl = np.zeros(n_salt_cyl)
coeffs_cyl = np.zeros(n_salt_cyl)

### fit osb_function once per salt value (cylindrical)
# NOTE(review): values[1:] drops the first *row* before np.unique, not the first
# distinct value -- confirm that this is the intended selection.
for i, csalt_value in enumerate(np.unique(df_cyl[variable_name].values[1:])):
    rows = df_cyl.loc[df_cyl["Csalt"] == csalt_value]
    # The first three rows of each series are left out of the fit
    # (original note: exclude case m=1).
    x_data = rows["p2_charge"].iloc[3:]
    y_data = rows["delta1lin"].iloc[3:]

    # Least-squares fit of osb_function; the covariance is not used afterwards.
    params, covariance = curve_fit(osb_function, x_data, y_data)
    a_fit, b_fit = params
    print(f"cyl {csalt_value}  Fitted parameters: a = {a_fit}, b = {b_fit}")
    coeffs_cyl[i] = a_fit
    shifts_cyl[i] = b_fit

########################### in 1 graph
breit = 8.27  # figure width handed to figsize
long = breit  # figure height; same value, so the figure is square

# x-values of the Csalt == 0.001 series of the cylindrical table. Both fitted
# curves (flat and cylindrical) are evaluated on these same x-values.
x_datanew = df_cyl.loc[df_cyl["Csalt"] == 0.001, "p2_charge"]

# 3 x 2 grid: ax1..ax3 form the left column, ax4..ax6 the right column.
fig, axes = plt.subplots(3, 2, figsize=(breit, long))
(ax1, ax4), (ax2, ax5), (ax3, ax6) = axes

# Fitted curves built from the first entry of each coefficient/shift array.
x_fit = np.array(x_datanew)
y_fit_flat = osb_function(x_fit, coeffs_flat[0], shifts_flat[0])
y_fit_cyl = osb_function(x_fit, coeffs_cyl[0], shifts_cyl[0])

 # strain
# Strain panel (flat): one curve plus one dashed reference line per salt value.
# The palette and marker list are cycled so that every salt value is drawn even
# when there are more values than colours; zip() over the bare palette would
# stop at its 8 entries and silently drop the remaining series.
for salt, line_color, marker in zip(np.unique(df_flat[variable_name].values),
                                    cycle(colors.values()), cycle(marker_styles)):
    subset = df_flat.loc[df_flat[variable_name] == salt]  # rows of this salt value
    x = subset["p2_charge"]
    delta1lin = subset["delta1lin"].values
    delta0lin = subset["delta0lin"].values[0]  # first entry, drawn as a horizontal line
    ax1.plot(x, delta1lin, marker=marker, label=str(salt) + ' M', color=line_color)
    ax1.axhline(delta0lin, ls='dashed', c=line_color)
ax1.set_ylim(0, 0.75)
ax1.set_ylabel(r"$\delta_{\mathrm{eq}}$", fontsize=16)
ax1.plot(x_datanew, y_fit_flat, label='osmotic brush', color='k', linewidth=2, ls='dashed')
# delta0lin still holds the value of the last salt series of the loop above,
# so the annotation sits just above that series' reference line.
ax1.text(0.85, delta0lin+0.03, r'$\delta_{\mathrm{eq}}^{\mathrm{0}}$', fontsize=16)

#### delta p (flat)
# Two passes over the salt values, one per column, so all "deltap_pro3" curves
# are added before the "deltap_dep3" curves. Each pass restarts the colour and
# marker cycles, so both curves of a salt value share colour and marker.
# Cycling (instead of zipping the bare palette) keeps salt values beyond the
# 8 palette entries from being silently dropped.
for column in ("deltap_pro3", "deltap_dep3"):
    for salt, line_color, marker in zip(np.unique(df_flat[variable_name].values),
                                        cycle(colors.values()), cycle(marker_styles)):
        subset = df_flat.loc[df_flat[variable_name] == salt]  # rows of this salt value
        x = subset["p2_charge"]
        # Divided by 1e6 to match the "[MPa]" axis label.
        scaled = subset[column].values/1000000
        ax2.plot(x, scaled, marker=marker, label=f'$10^{{{np.log10(salt):.0f}}}$ M', color=line_color)
ax2.axhline(0, ls='dashed', c='grey')
ax2.set_ylabel(r"$\Delta \Pi$ [MPa]", fontsize=16)

#### work (flat)
# The palette and marker list are cycled so that salt values beyond the
# 8 palette entries are still drawn instead of being cut off by zip().
for salt, line_color, marker in zip(np.unique(df_flat[variable_name].values),
                                    cycle(colors.values()), cycle(marker_styles)):
    subset = df_flat.loc[df_flat[variable_name] == salt]  # rows of this salt value
    x = subset["p2_charge"]
    # "int_bad" divided element-wise by "M", then by kb and T for plotting.
    w_chain = subset["int_bad"].values / subset["M"].values
    ax3.plot(x, w_chain/kb/T, marker=marker, label=str(salt) + ' M', color=line_color)
ax3.axhline(0, ls='dashed', c='grey')
# Legend of the left column, anchored below the bottom panel.
legend = ax3.legend(title=r'$c_{\mathrm{s}}$:',bbox_to_anchor=(0, -0.6, 1, -0.6), loc="lower center",
                mode="expand", borderaxespad=0, ncol=2, fontsize = 11, title_fontsize=12)
ax3.set_ylabel(r"$W_{\mathrm{chain}}$ $[kT]$", fontsize=16)

ax3.set_xlabel(r"charge fraction", fontsize=14)
# Only the bottom panel keeps its x tick labels.
ax1.tick_params(labelbottom=False, direction="in")
ax2.tick_params(labelbottom=False, direction="in")
ax3.tick_params(direction="inout")
# Remove the vertical gap between the stacked panels.
plt.subplots_adjust(hspace=0)

######################################
###################################################cyl ################
################################


################## settings
# Rebinds the palette for the plot loops that follow: grey is the first entry
# here, so the first series of those loops is drawn in grey.
colors = dict(
    grey='#BBBBBB',
    blue='#4477AA',
    cyan='#66CCEE',
    green='#228833',
    yellow='#CCBB44',
    orange='#EE7733',
    red='#EE6677',
    purple='#AA3377',
)
color_cycle = cycle(colors.values())

# One single-character marker code per entry; '|' comes first in this list.
marker_styles = list('|osD^v<>p*hH+xd_')
marker_cycle = cycle(marker_styles)

# Named dash patterns as (name, (offset, dash_sequence)) pairs.
# NOTE(review): 'dotted' and 'densely dotted' carry the same pattern here.
line_styles = [
    # dotted family
    ('loosely dotted', (0, (1, 10))),
    ('dotted', (0, (1, 1))),
    ('densely dotted', (0, (1, 1))),
    # dashed family
    ('long dash with offset', (5, (10, 3))),
    ('loosely dashed', (0, (5, 10))),
    #('dashed',                (0, (5, 5))),
    ('densely dashed', (0, (5, 1))),
    # dash-dot family
    ('loosely dashdotted', (0, (3, 10, 1, 10))),
    ('dashdotted', (0, (3, 5, 1, 5))),
    ('densely dashdotted', (0, (3, 1, 1, 1))),
    # dash-dot-dot family
    ('dashdotdotted', (0, (3, 5, 1, 5, 1, 5))),
    ('loosely dashdotdotted', (0, (3, 10, 1, 10, 1, 10))),
    ('densely dashdotdotted', (0, (3, 1, 1, 1, 1, 1))),
]
line_cycle = cycle(line_styles)


 # strain
# Strain panel (cylindrical): one open-marker curve plus one dashed reference
# line per salt value. The palette and marker list are cycled so that every
# salt value is drawn even when there are more values than colours; zip() over
# the bare palette would stop at its 8 entries and silently drop the rest.
for salt, line_color, marker in zip(np.unique(df_cyl[variable_name].values),
                                    cycle(colors.values()), cycle(marker_styles)):
    subset = df_cyl.loc[df_cyl[variable_name] == salt]  # rows of this salt value
    x = subset["p2_charge"]
    delta1lin = subset["delta1lin"].values
    delta0lin = subset["delta0lin"].values[0]  # first entry, drawn as a horizontal line
    ax4.plot(x, delta1lin, marker=marker, label=str(salt) + ' M', fillstyle='none', color=line_color)
    ax4.axhline(delta0lin, ls='dashed', c=line_color)
ax4.set_ylim(0, 0.75)
ax4.plot(x_datanew, y_fit_cyl, label='osmotic brush', color='k', linewidth=2, ls='dashed')
# delta0lin still holds the value of the last salt series of the loop above,
# so the annotation sits just above that series' reference line.
ax4.text(0.85, delta0lin+0.035, r'$\delta_{\mathrm{eq}}^{\mathrm{0}}$', fontsize=16)

#### delta p (cylindrical)
# Two passes over the salt values, one per column, so all "deltap_pro3" curves
# are added before the "deltap_dep3" curves. Each pass restarts the colour and
# marker cycles, so both curves of a salt value share colour and marker.
# Cycling (instead of zipping the bare palette) keeps salt values beyond the
# 8 palette entries from being silently dropped.
for column in ("deltap_pro3", "deltap_dep3"):
    for salt, line_color, marker in zip(np.unique(df_cyl[variable_name].values),
                                        cycle(colors.values()), cycle(marker_styles)):
        subset = df_cyl.loc[df_cyl[variable_name] == salt]  # rows of this salt value
        x = subset["p2_charge"]
        # Same 1e6 scaling as used for the panel labelled "[MPa]".
        scaled = subset[column].values/1000000
        ax5.plot(x, scaled, marker=marker, fillstyle='none', color=line_color)
ax5.axhline(0, ls='dashed', c='grey')

#### work (cylindrical)
# The palette and marker list are cycled so that salt values beyond the
# 8 palette entries are still drawn instead of being cut off by zip().
for salt, line_color, marker in zip(np.unique(df_cyl[variable_name].values),
                                    cycle(colors.values()), cycle(marker_styles)):
    subset = df_cyl.loc[df_cyl[variable_name] == salt]  # rows of this salt value
    x = subset["p2_charge"]
    # "int_bad" divided element-wise by "M", then by kb and T for plotting.
    w_chain = subset["int_bad"].values / subset["M"].values
    ax6.plot(x, w_chain/kb/T, marker=marker, label=str(salt) + ' M', fillstyle='none', color=line_color)
ax6.axhline(0, ls='dashed', c='grey')

ax6.set_xlabel(r"charge fraction", fontsize=14)
# Only the bottom panel keeps its x tick labels.
ax4.tick_params(labelbottom=False, direction="in")
ax5.tick_params(labelbottom=False, direction="in")
ax6.tick_params(direction="inout")
# Legend of the right column, anchored below the bottom panel.
legend2 = ax6.legend(title=r'$c_{\mathrm{s}}$:' ,bbox_to_anchor=(0, -0.7, 1, -0.7), loc="lower center",
                mode="expand", borderaxespad=0, ncol=2, fontsize = 11, title_fontsize=12)
# Remove the vertical gap between the stacked panels.
plt.subplots_adjust(hspace=0)

# Column headings sit on the top panel of each column.
ax1.set_title("Flat", fontsize=14)
ax4.set_title("Cylindrical", fontsize=14)

# Same tick label size on every panel.
for panel in (ax1, ax2, ax3, ax4, ax5, ax6):
    panel.tick_params(axis='both', labelsize=12)

# The output name carries the scanned column name.
plt.savefig(f"fix_charge_vary{variable_name}.pdf", bbox_inches="tight")