import json
import os
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import List
from argparse import ArgumentParser
import dacite
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import math
import numpy as np

@dataclass
class DataPoint:
    """One measurement of a sweep: naive and optimised results for one program size.

    Instances are created by ``dacite`` from the ``data_points`` list of a
    result file (see ``load_data``).
    """

    # Makespans: create_plots divides these by 1e6 and labels the axis
    # "Makespan (ms)", i.e. the plotting code treats them as nanoseconds.
    naive_makespan: float
    opt_makespan: float
    # Success probabilities of the naive and the optimised variant.
    naive_succ_prob: float
    opt_succ_prob: float
    prog_size: int  # Program size; points are grouped into one curve per size
    param_name: str  # Name of param being varied
    param_value: float  # Value of the varied param (used as the x coordinate)

@dataclass
class DataMeta:
    """Metadata stored alongside the data points of one result file.

    Within this file only ``hardware``, ``prog_name``, ``prog_sizes``,
    ``num_iterations`` and ``param_name`` are read; the remaining fields are
    loaded by ``dacite`` so that the dataclass matches the file contents.
    """

    timestamp: str
    sim_duration: float
    hardware: str  # Hardware label; appended to the names of saved plots
    qia_sga: int
    prog_name: str  # Program name; "rotation" gets special-cased curves in create_plots
    prog_sizes: List[int]  # Program sizes; get_vals builds one series per entry
    num_iterations: List[int]  # create_plots divides makespans by the first entry
    num_clients: int
    linear: bool
    cc: float 
    t1: float 
    t2: float 
    single_gate_dur: float 
    two_gate_dur: float 
    all_gate_dur: float
    single_gate_fid: float
    two_gate_fid: float
    all_gate_fid: float
    qnos_instr_proc_time: float
    host_instr_time: float
    host_peer_latency: float
    # NOTE(review): the two qubit counts are typed float although they look
    # like integer counts -- confirm against the files being loaded.
    client_num_qubits: float
    server_num_qubits: float
    # use_netschedule: bool
    # bin_length: int
    param_name: str  # The parameter being varied; selects the x-axis label
    link_duration: int
    link_fid: float

@dataclass
class Data:
    """Contents of one result file: sweep metadata plus its data points."""

    meta: DataMeta
    data_points: List[DataPoint]

    def filter_data_points(self, value):
        """Keep only the data points whose ``param_value`` is at most ``value``.

        The list stored in ``self.data_points`` is replaced by a new list;
        the relative order of the surviving points is unchanged.
        """
        kept = []
        for point in self.data_points:
            if point.param_value <= value:
                kept.append(point)
        self.data_points = kept

def relative_to_cwd(file: str) -> str:
    """Join ``file`` onto the directory that contains this script.

    Despite the name, the base directory is taken from ``__file__``, not from
    the process's current working directory.
    """
    script_dir, _script_name = os.path.split(__file__)
    return os.path.join(script_dir, file)

def create_pdf(filename: str):
    """Save the current matplotlib figure as ``plots/<filename>.pdf``.

    The ``plots`` directory lives next to this script and is created when it
    does not exist yet. The path of the written file is printed.
    """
    plots_dir = relative_to_cwd("plots")
    os.makedirs(plots_dir, exist_ok=True)

    target = os.path.join(plots_dir, "{}.pdf".format(filename))
    plt.savefig(target, format="pdf", transparent=True, dpi=1000)
    print(f"plot written to {target}")

def create_png(filename: str):
    """Save the current matplotlib figure as ``plots/<filename>.png``.

    The ``plots`` directory lives next to this script and is created when it
    does not exist yet. The path of the written file is printed.
    """
    plots_dir = relative_to_cwd("plots")
    os.makedirs(plots_dir, exist_ok=True)

    target = os.path.join(plots_dir, "{}.png".format(filename))
    plt.savefig(target, transparent=True, dpi=1000)
    print(f"plot written to {target}")

def get_vals(data: Data):
    """Split the data points of ``data`` into per-program-size series.

    Returns a 5-tuple of dicts, each keyed by the sizes listed in
    ``data.meta.prog_sizes`` and holding one list per size, in data-point order:

    1. x values (``param_value``)
    2. naive makespans
    3. naive success probabilities
    4. optimised makespans
    5. optimised success probabilities

    Data points whose ``prog_size`` is not listed in ``prog_sizes`` are
    ignored; a listed size without any data point maps to empty lists.
    """
    # Bucket the points once; every returned dict is a projection of this.
    points_by_size = {size: [] for size in data.meta.prog_sizes}
    for point in data.data_points:
        if point.prog_size in points_by_size:
            points_by_size[point.prog_size].append(point)

    def series(attribute):
        return {
            size: [getattr(point, attribute) for point in points]
            for size, points in points_by_size.items()
        }

    return (
        series("param_value"),
        series("naive_makespan"),
        series("naive_succ_prob"),
        series("opt_makespan"),
        series("opt_succ_prob"),
    )

def _per_size_means(data, opt_value, naive_value):
    """Return ``[opt_mean, naive_mean]`` for every size in ``data.meta.prog_sizes``.

    ``opt_value`` and ``naive_value`` map a data point to the number being
    averaged. A size without any data point yields ``nan`` for both means.
    """
    rows = []
    for size in data.meta.prog_sizes:
        points = [dp for dp in data.data_points if dp.prog_size == size]
        if not points:
            rows.append([math.nan, math.nan])
            continue
        rows.append([
            sum(opt_value(dp) for dp in points) / len(points),
            sum(naive_value(dp) for dp in points) / len(points),
        ])
    return rows


def _relative_gain_percent(rows):
    """Return ``(opt - naive) / opt * 100`` per row; ``nan`` where ``opt`` is zero."""
    return [(opt - naive) / opt * 100 if opt else math.nan for opt, naive in rows]


def find_worst(path:str, param:str, hardware:str, program:str, savefile:bool=False, timestamp=None):
    """Find and plot the result files in which the optimisation helped least.

    Scans ``path`` (resolved with ``relative_to_cwd``) for ``.json`` files
    whose names contain ``param``, ``hardware`` and ``program`` as substrings.
    For every file the mean makespan reduction (naive - opt) and the mean
    success-probability gain (opt - naive) over all data points are computed;
    the file with the smallest value of each is considered the "worst".
    Summary numbers are printed and the worst files are plotted with
    ``create_plots``.

    Args:
        path: Folder with the result files, relative to this script.
        param: Name of the swept parameter, matched against file names.
        hardware: Hardware label, matched against file names.
        program: Program name, matched against file names.
        savefile: Save the plots as PDF when true, otherwise show them.
        timestamp: Prefix for saved plot names. When saving and no timestamp
            is given, the current time is used.

    Returns:
        None. When no matching file contains data points, a message is
        printed and nothing is plotted.
    """
    files = [
        f for f in os.listdir(relative_to_cwd(path))
        if f.endswith(".json") and param in f and hardware in f and program in f
    ]

    # Files without data points cannot be averaged, so they are skipped.
    candidates = []
    for name in files:
        data = load_data(os.path.join(path, name))
        if data.data_points:
            candidates.append((name, data))

    if not candidates:
        print(
            f"find_worst: no usable .json results in {path} for "
            f"param={param!r}, hardware={hardware!r}, program={program!r}"
        )
        return

    if savefile and timestamp is None:
        # create_plots builds file names from the timestamp.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    def mean_gain(data, gain):
        return sum(gain(dp) for dp in data.data_points) / len(data.data_points)

    makespan_gains = [
        mean_gain(data, lambda dp: dp.naive_makespan - dp.opt_makespan)
        for _, data in candidates
    ]
    succprob_gains = [
        mean_gain(data, lambda dp: dp.opt_succ_prob - dp.naive_succ_prob)
        for _, data in candidates
    ]

    # min() returns the first minimum, so ties resolve to the earliest file.
    worst_makespan_idx = min(range(len(candidates)), key=makespan_gains.__getitem__)
    worst_succprob_idx = min(range(len(candidates)), key=succprob_gains.__getitem__)
    worst_makespan_file, worst_makespan_data = candidates[worst_makespan_idx]
    worst_succprob_file, worst_succprob_data = candidates[worst_succprob_idx]

    # Per-size [opt, naive] means for the file with the worst success
    # probability gain. The means use the real number of points per size.
    succprob_diff = _per_size_means(
        worst_succprob_data,
        lambda dp: dp.opt_succ_prob,
        lambda dp: dp.naive_succ_prob,
    )
    succsec_diff = _per_size_means(
        worst_succprob_data,
        lambda dp: dp.opt_succ_prob / dp.opt_makespan,
        lambda dp: dp.naive_succ_prob / dp.naive_makespan,
    )

    print("Avg Makespan diff: ", makespan_gains[worst_makespan_idx], worst_makespan_file)
    print("Avg Succprob diff: ", succprob_gains[worst_succprob_idx],worst_succprob_file)
    print(succprob_diff)
    print(succsec_diff)
    print("Succprob, ", _relative_gain_percent(succprob_diff))

    print("\n\nSuccsec", _relative_gain_percent(succsec_diff))

    # The already loaded objects are reused instead of re-reading the files.
    create_plots(timestamp, worst_makespan_data, "makespan", savefile)
    create_plots(timestamp, worst_succprob_data, "succprob", savefile)
    create_plots(timestamp, worst_succprob_data, "succsec", savefile)

def load_data(path: str) -> Data:
    """Read the JSON result file at ``path`` and convert it into a ``Data``.

    ``path`` is resolved with ``relative_to_cwd``; the parsed JSON object is
    converted with ``dacite.from_dict``.
    """
    raw_text = Path(relative_to_cwd(path)).read_text()
    parsed = json.loads(raw_text)
    return dacite.from_dict(data_class=Data, data=parsed)

def _sweep_axis_label(param_name: str) -> str:
    """Return the x-axis label for the swept parameter ``param_name``.

    Unknown parameter names are used as the label unchanged.
    """
    known_labels = {
        "single_gate_fid": "Single qubit gate fidelity",
        "distance": "Distance (km)",
        "cc": "Classical Communication Latency as a fraction of T2 time",
    }
    return known_labels.get(param_name, param_name)


def create_plots(timestamp, data: Data, plottype:str, save=True):
    """Plot naive vs. optimised results of one result file.

    Args:
        timestamp: Prefix for the names of saved plots. When saving and no
            timestamp is given, the current time is used.
        data: The loaded result file.
        plottype: ``"makespan"``, ``"succprob"`` or ``"succsec"`` to draw one
            plot, or ``""`` to draw all three in that order. Any other value
            draws nothing.
        save: Write each plot with ``create_pdf`` when true, otherwise show it
            interactively.

    Each plot has one naive curve per program size. For the program named
    "rotation" only a single optimised curve (taken from size 10) is drawn;
    for all other programs there is one optimised curve per size.
    NOTE(review): presumably the optimised "rotation" curves coincide for all
    sizes, which is why only one is drawn -- confirm.
    """
    meta = data.meta
    x_val_map, naive_makespan_map, naive_succprob_map, opt_makespan_map, opt_succprob_map = get_vals(data)
    label_fontsize = 14
    opt_markersize = 8
    xlabel = _sweep_axis_label(meta.param_name)
    is_rotation = meta.prog_name == "rotation"

    if save and timestamp is None:
        # Without this, building the file name from a None timestamp fails.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    def shows_opt_curve(size):
        return not is_rotation or size == 10

    def plot_opt(size, y_values, rotation_label, **rotation_style):
        # "rotation" uses its own label/style; other programs label by size.
        if is_rotation:
            plt.plot(x_val_map[size], y_values, label=rotation_label, marker="d",
                     markersize=opt_markersize, **rotation_style)
        else:
            plt.plot(x_val_map[size], y_values, label=f"Opt $n$={size}", marker="d",
                     markersize=opt_markersize)

    def per_iteration_ms(makespans):
        # Makespan per iteration, converted to milliseconds for the axis.
        return [(val / meta.num_iterations[0]) / 1e6 for val in makespans]

    def successes_per_second(succ_probs, makespans):
        return [prob / makespan * 1e9 for prob, makespan in zip(succ_probs, makespans)]

    def finish(ylabel, legend_loc, kind):
        # Shared tail of every plot: decorate, then save or show, then clear.
        plt.legend(loc=legend_loc, fontsize=11)
        plt.xlabel(xlabel, fontsize=label_fontsize)
        plt.ylabel(ylabel, fontsize=label_fontsize)
        if save:
            create_pdf(f"{timestamp}_{meta.prog_name}_{meta.param_name}_{kind}_n_{meta.hardware}")
        else:
            plt.show()
        plt.cla()

    # plt.xscale('log')
    if plottype in ("makespan", ""):
        for size, x_values in x_val_map.items():
            plt.plot(x_values, per_iteration_ms(naive_makespan_map[size]),
                     label=f"Subopt $n$={size}", marker="s")
            if shows_opt_curve(size):
                plot_opt(size, per_iteration_ms(opt_makespan_map[size]), "Opt", color="red")
        finish("Makespan (ms)", "upper left", "makespan")

    if plottype in ("succprob", ""):
        plt.ylim(0.6, 1.01)
        for size, x_values in x_val_map.items():
            plt.plot(x_values, naive_succprob_map[size], label=f"Unopt $n$={size}", marker="s")
            if shows_opt_curve(size):
                plot_opt(size, opt_succprob_map[size], "Optimized", color="red")
        finish("Success Probability", "lower right", "succprob")

    if plottype in ("succsec", ""):
        for size, x_values in x_val_map.items():
            plt.plot(x_values,
                     successes_per_second(naive_succprob_map[size], naive_makespan_map[size]),
                     label=f"Subopt $n$={size}", marker="s")
            if shows_opt_curve(size):
                plot_opt(size,
                         successes_per_second(opt_succprob_map[size], opt_makespan_map[size]),
                         "Opt")
        finish("Successes / s", "upper right", "succsec")

def average_and_plot(path: str, param: str, hardware: str, program: str, savefile: bool = False, timestamp=None):
    """Average matching result files and plot makespan and success probability.

    All ``.json`` files in ``path`` (resolved with ``relative_to_cwd``) whose
    names contain ``param``, ``hardware`` and ``program`` as substrings are
    loaded. Data points sharing the same (program size, parameter value) are
    averaged across files; the error bars show ``np.std`` of the same samples.
    The average relative improvement per program size is printed.

    Two plots are produced: makespan (naive and optimised curve for every
    program size) and success probability (naive curve for every program
    size, optimised curve for program size 10 only).

    Args:
        path: Folder with the result files, relative to this script.
        param: Name of the swept parameter; also selects the x-axis label.
        hardware: Hardware label, matched against file names.
        program: Program name, matched against file names.
        savefile: Save the plots as PDF when true, otherwise show them.
        timestamp: Prefix for saved plot names. When saving and no timestamp
            is given, the current time is used.

    Returns:
        None. When no file matches, a message is printed and nothing is
        plotted.
    """
    files = [
        f for f in os.listdir(relative_to_cwd(path))
        if f.endswith(".json") and param in f and hardware in f and program in f
    ]
    if not files:
        print(
            f"average_and_plot: no .json files in {path} match "
            f"param={param!r}, hardware={hardware!r}, program={program!r}"
        )
        return

    if savefile and timestamp is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    runs = [load_data(os.path.join(path, f)) for f in files]

    # Collect the samples of every metric per (program size, parameter value).
    metrics = ("naive_makespan", "opt_makespan", "naive_succ_prob", "opt_succ_prob")
    samples = {}
    for run in runs:
        for dp in run.data_points:
            key = (dp.prog_size, dp.param_value)
            if key not in samples:
                samples[key] = {metric: [] for metric in metrics}
            for metric in metrics:
                samples[key][metric].append(getattr(dp, metric))

    # Regroup per program size, keeping the order in which sizes first appear.
    points_by_size = {}
    for (prog_size, param_value), bucket in samples.items():
        points_by_size.setdefault(prog_size, []).append((param_value, bucket))

    # One series per program size, sorted by parameter value so that the
    # plotted line runs left to right regardless of the order in the files.
    plot_data = {}
    for prog_size, points in points_by_size.items():
        points.sort(key=lambda point: point[0])
        series = {"param_values": [param_value for param_value, _ in points]}
        for metric in metrics:
            series[f"{metric}_avgs"] = [np.mean(bucket[metric]) for _, bucket in points]
            series[f"{metric}_stds"] = [np.std(bucket[metric]) for _, bucket in points]
        plot_data[prog_size] = series

    # Print percentage decrease in makespan and percentage increase in success probability
    for prog_size, series in plot_data.items():
        avg_makespan_decrease = np.mean([
            (naive - opt) / naive * 100
            for naive, opt in zip(series["naive_makespan_avgs"], series["opt_makespan_avgs"])
        ])
        avg_succprob_increase = np.mean([
            (opt - naive) / naive * 100
            for naive, opt in zip(series["naive_succ_prob_avgs"], series["opt_succ_prob_avgs"])
        ])
        print(f"Program Size {prog_size}:")
        print(f"  Average Makespan Decrease (Optimal vs Suboptimal): {avg_makespan_decrease:.2f}%")
        print(f"  Average Success Probability Increase (Optimal vs Suboptimal): {avg_succprob_increase:.2f}%")

    xlabel = {
        "cc": "Classical communication latency as a fraction of T2 time",
        "single_gate_fid": "Single qubit gate fidelity",
        "distance": "Distance (km)",
    }.get(param, "")

    # Makespan: naive and optimised curve for every program size.
    # NOTE(review): unlike create_plots, the makespans here are not divided by
    # meta.num_iterations[0] before being shown as "Makespan (ms)" -- confirm
    # which normalisation is intended.
    for prog_size, series in plot_data.items():
        plt.errorbar(
            series["param_values"],
            [val / 1e6 for val in series["naive_makespan_avgs"]],
            yerr=[std / 1e6 for std in series["naive_makespan_stds"]],
            label=f"Unopt $n$={prog_size}", marker="s", capsize=5
        )
        plt.errorbar(
            series["param_values"],
            [val / 1e6 for val in series["opt_makespan_avgs"]],
            yerr=[std / 1e6 for std in series["opt_makespan_stds"]],
            label=f"Opt $n$={prog_size}", marker="d", markersize=8, capsize=5,
        )

    plt.legend(loc="center left", fontsize=11)
    plt.xlabel(xlabel, fontsize=14)
    plt.ylabel("Makespan (ms)", fontsize=14)
    if savefile:
        create_pdf(f"{timestamp}_{program}_{param}_makespan_plot_{hardware}")
    else:
        plt.show()
    plt.cla()

    # Success probability: naive curve for every program size.
    for prog_size, series in plot_data.items():
        plt.errorbar(
            series["param_values"],
            series["naive_succ_prob_avgs"],
            yerr=series["naive_succ_prob_stds"],
            label="$S_{unopt}$"+f" $n$={prog_size}", marker="s", capsize=5
        )
    # Optimised curve only for program size 10; the label names that size
    # explicitly instead of whatever size the loop above visited last.
    highlighted_size = 10
    if highlighted_size in plot_data:
        series = plot_data[highlighted_size]
        plt.errorbar(
            series["param_values"],
            series["opt_succ_prob_avgs"],
            yerr=series["opt_succ_prob_stds"],
            label="$S_{opt}$"+f" $n$={highlighted_size}", marker="d", markersize=8, capsize=5,
            color="red",
        )
    plt.legend(loc="lower right", fontsize=11)
    plt.xlabel(xlabel, fontsize=14)
    plt.ylabel("Success Probability", fontsize=14)
    plt.ylim(0.6, 1.01)
    if savefile:
        create_pdf(f"{timestamp}_{program}_{param}_succprob_plot_{hardware}")
    else:
        plt.show()
    plt.cla()

if __name__ == "__main__":
    # One timestamp per invocation, shared by all plots written in this run.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    parser = ArgumentParser()
    parser.add_argument("--folder", "-f", type=str, required=True)
    parser.add_argument("--save", "-s", action="store_true", default=False)
    # The three sweep dimensions each take one or more values.
    for list_flag in ("--params", "--hardware", "--programs"):
        parser.add_argument(list_flag, type=str, nargs="+", required=True)
    args = parser.parse_args()

    # find_worst reads this module-level name, so it must stay a global
    # called exactly saveFile.
    saveFile = args.save

    # Every (parameter, hardware, program) combination gets its own plots.
    # find_worst(folder, param, hw, program, saveFile, timestamp) takes the
    # same arguments and can be called here instead of average_and_plot.
    combinations = [
        (param_name, hw_name, program_name)
        for param_name in args.params
        for hw_name in args.hardware
        for program_name in args.programs
    ]
    for param, hw, program in combinations:
        average_and_plot(args.folder, param, hw, program, saveFile, timestamp)
                

        
