import itertools
import random
import shutil
from collections import defaultdict

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import seaborn as sns

from ema_workbench import Model, RealParameter, Constant, Policy, ScalarOutcome, ema_logging
from ema_workbench.em_framework.optimization import GenerationalBorg
from ema_workbench.em_framework.evaluators import MultiprocessingEvaluator, SequentialEvaluator
from ema_workbench.em_framework.optimization import ArchiveLogger, OperatorProbabilities

from DPS.problem_formulation import get_problem_formulation_rbfs
from network import manhattan_graph
from DPS.sort_and_filter import sort_and_filter_pol_fug_city as sort_and_filter_nodes
from fugitive_interception_model_grid import fugitive_interception_model
from DPS.rbf import rbf_gaussian, rbf_cubic, rbf_linear

# Send ema_workbench log output at INFO level to stderr so optimization progress is visible
# while the (long-running) experiment sweep below executes.
ema_logging.log_to_stderr(ema_logging.INFO)


def _load_instance_data(graph_type, manhattan_diameter, t_max, n_realizations, num_units, num_sensors, instance):
    """Load the pickled data of one problem instance.

    Returns ``(police_start, fugitive_start, fugitive_routes, sensor_locations)`` exactly as
    stored on disk under ``../../data/<graph_type>/``.
    """
    # NOTE(review): pd.read_pickle can execute arbitrary code; these files are assumed to be
    # trusted, locally generated instance data.
    suffix = (f"N{manhattan_diameter}_T{t_max}_R{n_realizations}_U{num_units}"
              f"_numsensors{num_sensors}_instance{instance}.pkl")

    def load(prefix):
        return pd.read_pickle(f"../../data/{graph_type}/{prefix}_{suffix}")

    police_start = load("units_start")
    fugitive_start = load("fugitive_start")
    fugitive_routes = load("fugitive_routes")
    sensor_locations = load("sensors")
    return police_start, fugitive_start, fugitive_routes, sensor_locations


def _compute_sensor_detections(sensor_locations, fugitive_routes_labeled, n_realizations, t_max):
    """Return ``{sensor_index: array of shape (n_realizations, t_max)}`` of 0.0/1.0 detection flags.

    Entry ``[r, t]`` is 1.0 when the sensor's node occurs in realization ``r`` of the labeled
    fugitive route at any time step ``<= t``; once tripped, a sensor stays flipped.
    """
    detections = {i: np.zeros((n_realizations, t_max)) for i in range(len(sensor_locations))}

    for t_interval in range(t_max):
        # Nodes visited up to and including t_interval, per realization (sets for O(1) lookup).
        visited = [{node for t, node in fugitive_routes_labeled[realization].items() if t <= t_interval}
                   for realization in range(n_realizations)]

        for i, sensor_loc in enumerate(sensor_locations):
            for realization in range(n_realizations):
                detections[i][realization, t_interval] = sensor_loc in visited[realization]

    return detections


def _save_convergence(archive_path, csv_path):
    """Flatten the archives written by ``ArchiveLogger`` into one CSV with an added ``nfe`` column."""
    convergence = ArchiveLogger.load_archives(archive_path)
    frames = [archive.assign(nfe=archive_nfe) for archive_nfe, archive in convergence.items()]
    # pd.concat rejects an empty list; fall back to an empty frame (empty CSV) in that case.
    convergence_df = pd.concat(frames) if frames else pd.DataFrame()
    convergence_df.to_csv(csv_path)


def run_dps(nfe=5e4, graph_type='manhattan', manhattan_diameter=10, n_realizations=10, num_units=1, num_sensors=1,
            num_rbfs=1, rbf=None, instance=0, seed=0, n_processes=12):
    """Optimise the RBF-parameterised fugitive-interception policy for one problem instance.

    Builds the Manhattan grid graph, loads the pre-generated instance data from
    ``../../data/<graph_type>/``, converts node ids to labels, precomputes the sensor
    detection flags, and runs ``GenerationalBorg`` over the model levers returned by
    ``get_problem_formulation_rbfs``.

    Parameters
    ----------
    nfe : number of function evaluations for the optimizer.
    graph_type : graph family; only ``'manhattan'`` is supported. Also selects the data sub-directory.
    manhattan_diameter : grid size passed to ``manhattan_graph``; the time horizon is
        ``t_max = int(5 + 0.5 * manhattan_diameter)``.
    n_realizations : number of fugitive route realizations; the epsilon of every objective
        is ``1 / n_realizations``.
    num_units, num_sensors, instance : identify the instance data files to load.
    num_rbfs, rbf : number and kind of radial basis functions used by the policy;
        ``rbf.__name__`` is part of the output file names.
    seed : seeds Python's ``random`` and NumPy's global RNG before optimizing; part of the
        output file names.
    n_processes : number of worker processes for ``MultiprocessingEvaluator``.

    Side effects
    ------------
    Removes ``./results/tmp`` (best effort), prints the optimization results, and writes
    ``./results/archives_<run>.tar.gz`` and ``./results/convergence_<run>.csv``.

    Raises
    ------
    ValueError
        If ``graph_type`` is not ``'manhattan'`` or ``rbf`` is None.
    """
    # Validate before any filesystem side effects or expensive loading.
    if graph_type != 'manhattan':
        raise ValueError(f"unsupported graph_type {graph_type!r}; only 'manhattan' is implemented")
    if rbf is None:
        raise ValueError("rbf must be given (e.g. rbf_linear, rbf_cubic or rbf_gaussian)")

    t_max = int(5 + (0.5 * manhattan_diameter))
    mode = 'optimization'

    # Best-effort cleanup of a previous run's scratch directory (presumably ArchiveLogger's
    # staging area under ./results - confirm); a missing directory is not an error.
    shutil.rmtree('./results/tmp', ignore_errors=True)

    graph, labels, labels_inv, pos = manhattan_graph(manhattan_diameter)

    police_start, fugitive_start, fugitive_routes, sensor_locations = _load_instance_data(
        graph_type, manhattan_diameter, t_max, n_realizations, num_units, num_sensors, instance)

    labels_perunit_sorted, labels_perunit_inv_sorted, labels_full_sorted, labels_full_sorted_inv = sort_and_filter_nodes(
        graph=graph,
        fugitive_start=fugitive_start,
        fugitive_routes=fugitive_routes,
        police_start=police_start,
        t_max=t_max)

    # Everything downstream works on node labels rather than raw node ids.
    fugitive_routes_labeled = defaultdict(dict)
    for realization in range(n_realizations):
        for t, node in fugitive_routes[realization].items():
            fugitive_routes_labeled[realization][t] = labels[node]

    sensor_locations = [labels[loc] for loc in sensor_locations]
    police_start = [labels[police_start[u]] for u in range(len(police_start))]

    sensor_detections = _compute_sensor_detections(sensor_locations, fugitive_routes_labeled, n_realizations, t_max)

    model = Model('fugitiveinterception', function=fugitive_interception_model)

    levers, constants, outcomes = get_problem_formulation_rbfs(mode, graph, police_start, n_realizations, t_max,
                                                               sensor_locations,
                                                               fugitive_routes_labeled,
                                                               rbf, num_rbfs,
                                                               labels, labels_perunit_sorted,
                                                               labels_perunit_inv_sorted,
                                                               labels_inv,
                                                               sensor_detections, )
    model.levers = levers
    model.constants = constants
    model.outcomes = outcomes

    lever_names = [lever.name for lever in model.levers]
    # Only non-INFO outcomes are objectives; epsilons and the archive logger must agree on them.
    objective_names = [o.name for o in model.outcomes if o.kind != o.INFO]

    run_id = (f"N{manhattan_diameter}_T{t_max}_R{n_realizations}_U{num_units}_S{num_sensors}"
              f"_{rbf.__name__}_numrbf{num_rbfs}_instance{instance}_seed{seed}")
    archive_filename = f"archives_{run_id}.tar.gz"

    convergence_metrics = [
        ArchiveLogger("./results/", lever_names, objective_names, base_filename=archive_filename),
    ]

    # Make runs with the same seed reproducible. Assumes the optimizer draws from Python's
    # `random` and NumPy's global RNG - TODO confirm for the installed platypus/ema_workbench.
    random.seed(seed)
    np.random.seed(seed)

    # Use SequentialEvaluator(model) instead for single-process debugging.
    with MultiprocessingEvaluator(model, n_processes=n_processes) as evaluator:
        results = evaluator.optimize(algorithm=GenerationalBorg,
                                     nfe=nfe, searchover='levers',
                                     epsilons=[1 / n_realizations] * len(objective_names),
                                     convergence=convergence_metrics, convergence_freq=10)

    print(results)

    _save_convergence(f"./results/{archive_filename}", f"./results/convergence_{run_id}.csv")


if __name__ == '__main__':
    # Optimizer budget per grid size. A lookup (rather than an if/elif chain) makes an
    # unlisted diameter fail with KeyError instead of silently reusing a stale budget.
    nfe_per_diameter = {10: 5e3, 30: 5e4}

    # Full-factorial experiment sweep. itertools.product varies the last factor fastest,
    # i.e. the same order as the equivalent nested for-loops.
    experiments = itertools.product(
        [30, 10],                               # manhattan_diameter
        [100, 30],                              # n_realizations
        [3, 10],                                # num_sensors
        [3, 10],                                # num_units
        [2, 6],                                 # num_rbfs
        [rbf_linear, rbf_cubic, rbf_gaussian],  # rbf
        range(3),                               # instance
        range(3),                               # seed
    )

    for manhattan_diameter, n_realizations, num_sensors, num_units, num_rbfs, rbf, instance, seed in experiments:
        nfe = nfe_per_diameter[manhattan_diameter]

        run_dps(nfe=nfe,
                graph_type='manhattan',
                manhattan_diameter=manhattan_diameter,
                n_realizations=n_realizations,
                num_units=num_units,
                num_sensors=num_sensors,
                num_rbfs=num_rbfs,
                rbf=rbf,
                instance=instance,
                seed=seed
                )

