from collections import defaultdict
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats

from complexity.computational_complexity import ConvergenceConfig, TestCase, _compute_complexity
from complexity.infrastructures.unidirectional_corridor import UnidirectionalCorridor
from complexity.models.social_force import SocialForceModel

# Cache file for the measured run times: get_run_times() loads it when it
# exists and otherwise writes the freshly computed run times to it.
# NOTE(review): absolute, machine-specific path - adjust before running elsewhere.
filename = Path(r'C:\Users\martijnsparnaa\surfdrive\PhD\Complexity, Scalability & Validity\data\test_01.npz')

def get_run_times():
    """Return the run times per experiment, using the cache file when present.

    If ``filename`` exists, the arrays stored in that ``.npz`` archive are
    read and returned. Otherwise the run times are computed with
    ``_compute_complexity`` for ``SocialForceModel`` on the configured test
    cases, written to ``filename`` and returned.

    Returns:
        dict: Maps the string form of each experiment to a NumPy array with
        its run times. Both branches return a plain ``dict``.
    """
    if filename.is_file():
        # np.load returns a lazy NpzFile that keeps the archive open; read all
        # arrays inside a ``with`` block so the file handle is released and
        # the caller gets the same type as on the compute path.
        with np.load(filename) as cached:
            return {name: cached[name] for name in cached.files}

    # Make sure the target directory exists *before* the computation, so a
    # missing folder cannot discard the results when saving afterwards.
    filename.parent.mkdir(parents=True, exist_ok=True)

    unidir_corridor = UnidirectionalCorridor.get_instance(0)
    # Alternative test cases, currently disabled.
    test_cases = [
        #TestCase(unidir_corridor, 1, 0),
        #TestCase(unidir_corridor, 10, 0),
        TestCase(unidir_corridor, 50, 0),
        #TestCase(unidir_corridor, 100.area, 0)
    ]

    convergence_config = ConvergenceConfig(init_rep_count=3000)
    run_times_per_exp = _compute_complexity([SocialForceModel], test_cases, convergence_config)

    # np.savez needs string keys, hence str(experiment).
    save_data = {
        str(experiment): np.array(run_times)
        for experiment, run_times in run_times_per_exp.items()
    }
    np.savez(filename, **save_data)

    return save_data

# Smallest subsample size examined by the convergence routines
# (sample sizes are start, start + step, start + 2*step, ...).
start = 30
# Increment between consecutive subsample sizes. The has_*_convergence helpers
# also use it as the number of trailing-element truncations they compare with.
step = 5


def get_convergence_cts(run_times):
    """Collect means and medians of random subsamples that pass the KS check.

    For every sample size ``N = start, start + step, ...`` (the number of
    sizes is ``(len(run_times) - start) // step``), ``M`` subsamples of size
    ``N`` are drawn from ``run_times`` without replacement. For each subsample
    accepted by ``has_ks_convergence`` the mean and the median are recorded.

    Args:
        run_times: One-dimensional array-like of run times to sample from.

    Returns:
        tuple[dict, dict]: ``(means_per_size, medians_per_size)``. Each maps a
        sample size to the list of means / medians of the accepted subsamples
        of that size. Sizes for which no subsample was accepted are absent.
    """
    M = 100  # number of random subsamples drawn per sample size
    # Derive the number of sizes from ``step`` (previously a hard-coded 5);
    # clamp at zero for inputs shorter than ``start``.
    steps = max(0, (len(run_times) - start) // step)
    means_per_size = defaultdict(list)
    medians_per_size = defaultdict(list)
    for ii in range(steps):
        N = start + ii * step
        for _ in range(M):
            sample = np.random.choice(run_times, size=N, replace=False)
            if has_ks_convergence(sample):
                means_per_size[N].append(np.mean(sample))
                medians_per_size[N].append(np.median(sample))

    # Hand the collected statistics back to the caller; they were previously
    # computed and then dropped.
    return dict(means_per_size), dict(medians_per_size)
        

def check_convergence(run_times):
    """Print how often random subsamples pass the mean and KS convergence checks.

    For up to 50 sample sizes ``N = start, start + step, ...``, ``M``
    subsamples of size ``N`` are drawn from ``run_times`` without replacement.
    One line is printed per size with the number of subsamples accepted by
    ``has_mean_convergence`` and by ``has_ks_convergence``.

    Args:
        run_times: One-dimensional array-like of run times to sample from.
    """
    M = 100  # number of random subsamples drawn per sample size
    max_steps = 50
    # Sampling without replacement requires N <= len(run_times); limit the
    # number of sizes accordingly so short inputs do not raise ValueError.
    # With 275 or more run times this still yields exactly 50 sizes.
    available_steps = (len(run_times) - start) // step + 1
    steps = max(0, min(max_steps, available_steps))
    for ii in range(steps):
        N = start + ii * step
        mean_conv_count = 0
        ks_conv_count = 0
        for _ in range(M):
            sample = np.random.choice(run_times, size=N, replace=False)
            mean_conv_count += int(has_mean_convergence(sample))
            ks_conv_count += int(has_ks_convergence(sample))
        print(f'N = {N:04d}, mean = {mean_conv_count:03d}/{M}, ks = {ks_conv_count:03d}/{M}')
            
def has_mean_convergence(sample):
    """Tell whether the sample mean is stable when trailing elements are dropped.

    The mean of the full sample is compared with the mean of the sample
    shortened by 1, 2, ..., ``step`` trailing elements.

    Args:
        sample: Sliceable sequence of numbers (e.g. a NumPy array).

    Returns:
        bool: ``False`` if any relative difference (with respect to the
        full-sample mean) is at least 0.005, ``True`` otherwise.
    """
    full_mean = np.mean(sample)
    # Keep the ``>=`` comparison as-is: a NaN difference counts as "not exceeded".
    return not any(
        abs((full_mean - np.mean(sample[:-dropped])) / full_mean) >= 0.005
        for dropped in range(1, step + 1)
    )

def has_ks_convergence(sample):
    """Tell whether the sample distribution is stable when trailing elements are dropped.

    The full sample is compared, using the two-sample Kolmogorov-Smirnov test,
    with the sample shortened by 1, 2, ..., ``step`` trailing elements.

    Args:
        sample: Sliceable sequence of numbers (e.g. a NumPy array).

    Returns:
        bool: ``False`` if any comparison yields a p-value below 0.99,
        ``True`` otherwise.
    """
    return not any(
        scipy.stats.ks_2samp(sample, sample[:-dropped]).pvalue < 0.99
        for dropped in range(1, step + 1)
    )

# Script driver: obtain the run times per experiment (from the cache file or
# by computing them) and print the convergence counts for each experiment.
run_times_per_exp = get_run_times()
for name, run_times in run_times_per_exp.items():
    print(name)
    check_convergence(run_times)



