# %%
import pandas as pd
import os
import statsmodels.formula.api as smf
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import scipy.stats as stats
from statsmodels.stats.outliers_influence import variance_inflation_factor
import logging

# Configure logging: INFO and above, printed as "LEVEL: message"
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

# Option to disable logging for simplified output
# Uncomment the next line to disable all logging
# logging.disable(logging.CRITICAL)

# Info on linear mixed effects model: https://www.statsmodels.org/stable/mixed_linear.html
# Info on function (smf.mixedlm): https://www.statsmodels.org/devel/generated/statsmodels.regression.mixed_linear_model.MixedLM.html

# Specify the product to analyse
product = 'Acetate'  # Change this to 'Product1', 'Product2', etc., as needed

# Load Excel file (the relative path is resolved against the current working directory)
cwd = os.getcwd()
file_path = 'DoE_Product_Concentrations_HAc.xlsx'
if not os.path.isfile(file_path):
    # Fail early and name the directory that was searched, so a wrong
    # working directory is easy to diagnose.
    raise FileNotFoundError(
        f"Input file '{file_path}' not found in working directory '{cwd}'."
    )
df = pd.read_excel(file_path, sheet_name='Sheet2')

# Ensure that the specified product exists in the dataframe
if product not in df.columns:
    raise ValueError(f"Product '{product}' not found in the dataframe columns.")

# The analysis also needs the three predictors and the grouping column;
# report all missing ones at once instead of failing later with a KeyError.
required_columns = ['pH', 'CO2_rate', 'Added_HAc', 'Reactor']
missing_columns = [col for col in required_columns if col not in df.columns]
if missing_columns:
    raise ValueError(
        f"Required column(s) {missing_columns} not found in the dataframe columns."
    )

# Compute Variance Inflation Factor (VIF) for each predictor.
# variance_inflation_factor() regresses one column of the matrix it is given
# on the remaining columns and does NOT add an intercept itself. Without a
# constant column the auxiliary regressions are forced through the origin and
# the reported VIFs are inflated, so a constant is appended here and VIFs are
# reported for the real predictors only (the constant is the last column).
X = df[['pH', 'CO2_rate', 'Added_HAc']]
X_design = X.assign(const=1.0).to_numpy(dtype=float)
vif_data = pd.DataFrame({
    "feature": X.columns,
    "VIF": [variance_inflation_factor(X_design, idx) for idx in range(X.shape[1])],
})
logging.info(vif_data)

# Base model with main effects
formula = f'{product} ~ pH + CO2_rate + Added_HAc'

# Fit the mixed-effects model
try:
    # Define and initialize a linear mixed-effects model
    # - 'formula' specifies the dependent and independent variables
    # - 'data=df' indicates the DataFrame containing the data
    # - 'groups="Reactor"' sets the grouping factor for random effects (each reactor gets its own intercept)
    # - 're_formula="1"' specifies that only a random intercept is included (no random slopes)
    model = smf.mixedlm(formula, data=df, groups='Reactor', re_formula="1")
    # Fit the model using:
    # - 'method="lbfgs"' -> L-BFGS optimisation algorithm for maximum likelihood estimation
    # - 'maxiter=1000' -> Allows up to 1000 iterations for optimisation
    # - 'full_output=True' -> Returns detailed output of the optimisation process
    # - 'reml=False' -> Uses full maximum likelihood (ML) estimation instead of restricted maximum likelihood (REML)
    model_fit = model.fit(method='lbfgs', maxiter=1000, full_output=True, reml=False)
except RuntimeError as e:
    logging.error("RuntimeError during model fitting: %s", e)
    # Bare 'raise' re-raises the active exception with its original traceback
    raise
else:
    # Estimates from a non-converged optimisation are unreliable; flag them
    # before they are reported. (getattr guards against results objects that
    # do not expose 'converged'.)
    if not getattr(model_fit, 'converged', True):
        logging.warning("Mixed model optimisation did not converge; interpret estimates with caution.")
    logging.info(model_fit.summary())

# Construct the regression equation from the fitted fixed-effect coefficients
coefficients = model_fit.params
intercept_original = coefficients['Intercept']
# Take the predictor names from the fitted model's fixed-effect design matrix
# so the equation always matches the formula that was actually fitted.
predictor_names = [name for name in model_fit.model.exog_names if name != 'Intercept']
coefficients_original = {predictor: coefficients[predictor] for predictor in predictor_names}

# Print each term as "<sign> <magnitude> * <name>". Using abs() with an explicit
# sign avoids output such as "- -0.00" for a coefficient of exactly zero.
terms = [
    f" {'-' if coef < 0 else '+'} {abs(coef):.2f} * {predictor}"
    for predictor, coef in coefficients_original.items()
]
equation = f"{product} = {intercept_original:.2f}" + "".join(terms)

# Print the regression equation
logging.info("\nRegression Equation in Original Scale:")
logging.info(equation)

# Variance of the "fixed-only" linear predictor X·beta
#     - model_fit.model.exog  is the fixed-effect design matrix
#     - model_fit.fe_params   holds the fixed-effect coefficients
beta_fixed = np.asarray(model_fit.fe_params)
linpred_fixed = model_fit.model.exog.dot(beta_fixed)
var_fixed = linpred_fixed.var()  # ndarray.var defaults to ddof=0 (population variance)

# Random-effect variance: first entry of the random-effects covariance
# (the model is assumed to have a single random intercept)
var_random = float(model_fit.cov_re.iat[0, 0])

# Residual variance reported by the fitted model
var_resid = model_fit.scale

# Marginal R²: share of total variance attributed to the fixed effects.
# Conditional R²: share attributed to fixed plus random effects.
var_total = var_fixed + var_random + var_resid
marginal_r2 = var_fixed / var_total
conditional_r2 = (var_fixed + var_random) / var_total

logging.info(f"Marginal R² = {marginal_r2:.3f}")
logging.info(f"Conditional R² = {conditional_r2:.3f}")

# Store the model's fitted values on df, plus observed-minus-fitted residuals
df['predicted'] = model_fit.fittedvalues
df['residuals'] = df[product].sub(df['predicted'])

# 1. Predicted vs. Observed: scatter with a red y = x reference line drawn
#    across the range of the predicted values
fig, ax = plt.subplots(figsize=(8, 6))
sns.scatterplot(data=df, x='predicted', y=product, ax=ax)
pred_lo, pred_hi = df['predicted'].min(), df['predicted'].max()
ax.plot([pred_lo, pred_hi], [pred_lo, pred_hi], color='red', lw=2)
ax.set_xlabel(f"Predicted {product}")
ax.set_ylabel(f"Observed {product}")
ax.set_title('Predicted vs. Observed Values')
plt.show()

# 2. Residuals against predicted values, with a red zero line for reference
fig, ax = plt.subplots(figsize=(8, 6))
sns.scatterplot(data=df, x='predicted', y='residuals', ax=ax)
ax.axhline(0, color='red', lw=2)
ax.set_xlabel(f"Predicted {product}")
ax.set_ylabel('Residuals')
ax.set_title('Residual Plot')
plt.show()

# 3. Normal Q-Q plot of the residuals (scipy.stats.probplot draws onto the axes;
#    its default title and axis labels are replaced afterwards)
fig, ax = plt.subplots(figsize=(8, 6))
stats.probplot(df['residuals'], dist="norm", plot=ax)
ax.set_title('Q-Q Plot of Residuals')
ax.set_xlabel('Theoretical Quantiles')
ax.set_ylabel('Sample Quantiles')
plt.show()
# %%
