from collections import defaultdict
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats

import sys # added!
sys.path.append("..")

from complexity.computational_complexity import load_results_from_file

# Location of the stored run-time results that get_run_times() loads.
# This is an absolute, machine-specific Windows path; it must be edited when
# running the script on another machine.
# NOTE(review): consider taking this from a command-line argument instead.
filename = Path(r'C:\Users\martijnsparnaa\surfdrive\PhD\Complexity, Scalability & Validity\data\test_01.npz')

def get_run_times(path=None):
    """Load the stored run-time results from disk.

    Args:
        path: File to read. Defaults to the module-level ``filename``.

    Returns:
        Whatever ``load_results_from_file`` returns for that file. The script
        below uses it as a mapping (``.keys()`` and indexing) from experiment
        keys to run-time arrays.

    Raises:
        FileNotFoundError: if ``path`` does not point to an existing file.
    """
    results_path = Path(filename if path is None else path)
    if not results_path.is_file():
        # Fail loudly here; silently returning None would only surface later
        # as a confusing AttributeError on ``.keys()``.
        raise FileNotFoundError(f"Run-time results file not found: {results_path}")
    return load_results_from_file(results_path)
 
# start: smallest sample size / prefix length examined.
# step:  increment between successive sample sizes; the has_*_convergence
#        checks also use it as the number of trailing values to drop.
start, step = 30, 30

def get_convergence_cts(run_times, *, samples_per_size=100, convergence_check=None,
                        rng=None, first_size=None, size_step=None):
    """Collect means/medians of random subsamples that pass a convergence check.

    Sample sizes are ``first_size, first_size + size_step, ...``, using as many
    whole steps as fit in ``len(run_times) - first_size``. For each size,
    ``samples_per_size`` subsamples are drawn without replacement. Every
    subsample accepted by ``convergence_check`` contributes its mean and its
    median.

    Args:
        run_times: 1-D array-like of run times to subsample.
        samples_per_size: Number of subsamples drawn per size (default 100).
        convergence_check: Predicate ``sample -> bool``. Defaults to
            ``has_mean_convergence``.
        rng: Optional ``numpy.random.Generator`` for reproducible draws.
            Defaults to the global ``np.random`` state.
        first_size: Smallest sample size. Defaults to the module-level ``start``.
        size_step: Increment between sizes. Defaults to the module-level ``step``.

    Returns:
        ``(means_per_size, medians_per_size)``: two ``defaultdict(list)``
        objects keyed by sample size. Sizes with no accepted subsample are
        absent.
    """
    if convergence_check is None:
        convergence_check = has_mean_convergence
    if first_size is None:
        first_size = start
    if size_step is None:
        size_step = step
    draw = np.random.choice if rng is None else rng.choice

    # The largest size is at most len(run_times) - size_step, so sampling
    # without replacement never asks for more values than exist.
    n_sizes = max(0, (len(run_times) - first_size) // size_step)
    means_per_size = defaultdict(list)
    medians_per_size = defaultdict(list)
    for size in (first_size + k * size_step for k in range(n_sizes)):
        for _ in range(samples_per_size):
            sample = draw(run_times, size=size, replace=False)
            if convergence_check(sample):
                means_per_size[size].append(np.mean(sample))
                medians_per_size[size].append(np.median(sample))

    return means_per_size, medians_per_size

def has_ks_convergence(sample, window=None, min_p_value=0.99):
    """Return True if dropping trailing values barely changes the distribution.

    For every k in ``1..window``, the prefix ``sample[:-k]`` is compared with
    the full sample using a two-sample Kolmogorov-Smirnov test. The sample
    counts as converged only if every p-value is at least ``min_p_value``.

    Args:
        sample: 1-D sequence of values.
        window: Maximum number of trailing values to drop. Defaults to the
            module-level ``step``.
        min_p_value: Smallest acceptable KS p-value (default 0.99).

    Returns:
        bool
    """
    if window is None:
        window = step
    n = len(sample)
    # Compare only non-empty prefixes: ks_2samp raises on empty input. An
    # empty prefix would otherwise occur whenever len(sample) <= window,
    # e.g. a 30-value sample with window 30.
    for dropped in range(1, min(window, n - 1) + 1):
        _, p_val = scipy.stats.ks_2samp(sample, sample[:n - dropped])
        if p_val < min_p_value:
            return False

    return True

def has_mean_convergence(sample, window=None, tolerance=0.005):
    """Return True if dropping trailing values barely changes the mean.

    For every k in ``1..window``, the mean of the prefix ``sample[:-k]`` is
    compared with the full-sample mean. The sample counts as converged only
    if every relative difference is below ``tolerance``.

    Args:
        sample: 1-D sequence of numbers.
        window: Maximum number of trailing values to drop. Defaults to the
            module-level ``step``.
        tolerance: Relative-difference threshold (default 0.005, i.e. 0.5%).

    Returns:
        bool
    """
    if window is None:
        window = step
    n = len(sample)
    sample_mean = np.mean(sample)
    # Skip empty prefixes explicitly. Their mean is NaN (with a
    # RuntimeWarning), and such a comparison carries no information.
    for dropped in range(1, min(window, n - 1) + 1):
        prefix_mean = np.mean(sample[:n - dropped])
        rel_diff = abs((sample_mean - prefix_mean) / sample_mean)
        if rel_diff >= tolerance:
            return False

    return True

def _find_experiment(results, ped_count):
    """Return the first experiment key matching the corridor setup used here.

    The match is infra_tag "unidirectional_corridor_028", density_distr_factor
    0.0 and the given ped_count.

    Raises:
        LookupError: if no key in ``results`` matches.
    """
    matches = (
        exp for exp in results.keys()
        if exp.infra_tag == "unidirectional_corridor_028"
        and exp.density_distr_factor == 0.0
        and exp.ped_count == ped_count
    )
    try:
        return next(matches)
    except StopIteration:
        raise LookupError(
            f"No unidirectional_corridor_028 experiment with ped_count={ped_count}"
        ) from None


def _plot_convergence(mean_ax, median_ax, run_times, ped_count):
    """Box-plot converged subsample means/medians per size on two axes.

    On each axis, a horizontal line marks the corresponding statistic of the
    full ``run_times`` data.
    """
    means_per_size, medians_per_size = get_convergence_cts(run_times)
    panels = (
        (mean_ax, means_per_size, np.mean(run_times), 'Mean'),
        (median_ax, medians_per_size, np.median(run_times), 'Median'),
    )
    for ax, per_size, reference, stat_name in panels:
        sizes = sorted(per_size.keys())
        ax.boxplot([per_size[size] for size in sizes], labels=sizes)
        ax.axhline(reference, linewidth=2, label=stat_name)
        ax.set_title(f'#peds = {ped_count} - {stat_name}')


run_times_per_exp = get_run_times()
_, axes = plt.subplots(2, 3)

# Row 0: a single pedestrian; row 1: 100 pedestrians. Columns 0/1 hold the
# mean and median convergence plots.
exp_1 = _find_experiment(run_times_per_exp, 1)
_plot_convergence(axes[0][0], axes[0][1], run_times_per_exp[exp_1], 1)

exp_100 = _find_experiment(run_times_per_exp, 100)
_plot_convergence(axes[1][0], axes[1][1], run_times_per_exp[exp_100], 100)


def get_box_plot_data(run_times):
    """Build growing prefixes of ``run_times`` for a box plot.

    Returns:
        ``(prefixes, sizes)``. ``sizes`` is ``start, start + step, ...``,
        using as many whole steps as fit in ``len(run_times) - start``, and
        ``prefixes[i] == run_times[:sizes[i]]``.
    """
    n_sizes = int((len(run_times) - start) / step)
    sizes = [start + k * step for k in range(n_sizes)]
    prefixes = [run_times[:size] for size in sizes]
    return prefixes, sizes

# Right-hand column: raw run-time distributions of growing prefixes, one row
# per experiment (1 and 100 pedestrians).
for row, (exp, ped_count) in enumerate(((exp_1, 1), (exp_100, 100))):
    box_plot_data, labels = get_box_plot_data(run_times_per_exp[exp])
    run_time_ax = axes[row][2]
    run_time_ax.boxplot(box_plot_data, labels=labels)
    run_time_ax.set_title(f'Run times - #peds = {ped_count}')

for ax in axes.flat:
    ax.grid()

plt.show()
