
import copy
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np

import sys # added!
sys.path.append("..")

from complexity.computational_complexity import ConvergenceConfig, TestCase, _compute_complexity, load_results_from_file
from complexity.infrastructures.unidirectional_corridor import UnidirectionalCorridor
from complexity.models.social_force import SocialForceModel


# Results file for the run-time experiments. get_run_times_per_exp() loads it when it
# already exists; otherwise it is passed as ``output_file`` to _compute_complexity
# (presumably where the freshly computed results are stored — confirm).
# NOTE(review): machine-specific absolute Windows path; consider making it configurable.
output_file = Path(r'C:\Users\martijnsparnaa\surfdrive\PhD\Complexity, Scalability & Validity\data\test_01.npz')

def get_run_times_per_exp(
    cache_file: "str | Path | None" = None,
    *,
    ped_counts=(1, 10, 50, 100),
    infra_steps=range(4),
    ddfs=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
):
    """Return the run times per experiment, loading cached results when available.

    If the results file exists, it is loaded with ``load_results_from_file`` and the
    grid arguments below are ignored. Otherwise, every combination of pedestrian
    count, corridor variant and density distribution factor is run with
    ``SocialForceModel`` through ``_compute_complexity``. The same file is passed to
    that call as ``output_file``.

    Args:
        cache_file: Results file to load or pass on. Defaults to the module-level
            ``output_file``, so a plain ``get_run_times_per_exp()`` behaves as before.
        ped_counts: Pedestrian counts to test (only used when computing).
        infra_steps: Arguments for ``UnidirectionalCorridor.get_instance`` (only used
            when computing).
        ddfs: Density distribution factors to test (only used when computing).

    Returns:
        The result of ``load_results_from_file`` or ``_compute_complexity``. The
        plotting code in this script uses it as a mapping from experiment to run
        times (presumably the same structure on both paths — confirm).
    """
    results_file = output_file if cache_file is None else Path(cache_file)
    if results_file.is_file():
        return load_results_from_file(results_file)

    infras = [UnidirectionalCorridor.get_instance(step) for step in infra_steps]

    # Same ordering as a nested loop: pedestrian count outermost, ddf innermost.
    test_cases = [
        TestCase(infra, ped_count, ddf)
        for ped_count in ped_counts
        for infra in infras
        for ddf in ddfs
    ]

    convergence_config = ConvergenceConfig(sig_level=0.05, allowable_err_perc=0.1)

    return _compute_complexity(
        [SocialForceModel], test_cases, convergence_config, output_file=results_file
    )


# Load (or compute) the results, then collect the distinct values of each experiment
# dimension in ascending order for use by the plotting code below.
run_times_per_exp = get_run_times_per_exp()
infras = sorted(set(experiment.infra_tag for experiment in run_times_per_exp))
ped_counts = sorted(set(experiment.ped_count for experiment in run_times_per_exp))
ddfs = sorted(set(experiment.density_distr_factor for experiment in run_times_per_exp))

from matplotlib import cm

# Visual encoding: the line style identifies the infrastructure, and the colour
# identifies the density distribution factor (a darker colour means a larger factor).
line_styles = ['-', '--', ':', '-.']
# Cycle the styles so that more than four infrastructures no longer raise IndexError.
infra_to_line_style = {
    infra: line_styles[ii % len(line_styles)] for ii, infra in enumerate(infras)
}
# plt.get_cmap replaces matplotlib.cm.get_cmap (deprecated in 3.7, removed in 3.9).
colors = plt.get_cmap('YlOrRd')(np.linspace(0.2, 1.0, len(ddfs)))
ddf_to_color = dict(zip(ddfs, colors))
_, axes = plt.subplots(1, 2, sharex=True, sharey=True)

# Group run times per plotted line (infrastructure, ddf), keyed by pedestrian count.
run_times_per_line: dict[tuple[str, float], dict] = {}
for exp, run_times in run_times_per_exp.items():
    tag = (exp.infra_tag, exp.density_distr_factor)
    run_times_per_line.setdefault(tag, {})[exp.ped_count] = run_times

for infra_tag, ddf in sorted(run_times_per_line):
    run_times_by_count = run_times_per_line[(infra_tag, ddf)]
    # Each line uses its own pedestrian counts as x values. With one global x, a
    # missing experiment makes len(x) != len(y) and Axes.plot raises ValueError.
    x = sorted(run_times_by_count)
    line_props = {
        'linestyle': infra_to_line_style[infra_tag],
        'color': ddf_to_color[ddf],
        'label': f'{infra_tag[-3:]} - {ddf:.2f}',
    }
    axes[0].plot(x, [np.mean(run_times_by_count[count]) for count in x], **line_props)
    axes[1].plot(x, [np.median(run_times_by_count[count]) for count in x], **line_props)

# Both panels share labels, so a single legend (on the right panel) suffices.
axes[1].legend()
axes[0].set_title('Mean')
axes[1].set_title('Median')
axes[0].set_ylabel('Run time')
for ax in axes:
    ax.set_xlabel('Pedestrian count')
    ax.grid()
plt.show()


# def plot_lines(ax, data, infras, ped_counts, ddfs):
#     x = range(len(base_set))
#     for entry_1 in set_1:
#         for entry_2 in set_2:
#             y = []
#             for entry in base_set: 
#                 y.append()
#             ax.plot(x,y,**line_props)


# labels = []
# positions = []
# means = []

# experiments = sorted(run_times_per_exp.keys(), key=lambda exp: exp.mean_dens)

# for ii, experiment in enumerate(experiments):
#     #bins = list(np.arange(0,max(run_times) + 0.01, 0.01))
#     box_plot_data.append(run_times_per_exp[experiment])
#     positions.append(ii)
#     means.append(np.mean(run_times_per_exp[experiment]))
#     labels.append(f'{experiment.mean_dens}')

# plt.boxplot(box_plot_data, positions=positions, notch=True, labels=labels)#, meanline=True)
# plt.plot(positions, means)
# plt.grid()
# plt.show()

