using Gen
using JSON
using Distributions 

"""
    get_departure_times_from_plan(plan)

Get the departure times from a parsed plan (e.g. the final plan file).

`plan` is expected to be a dictionary with an `"actions"` entry. Those actions are handed
to `get_departure_times_from_actions`, and its `(departure_times, actions)` tuple is
returned unchanged.
"""
function get_departure_times_from_plan(plan)
    return get_departure_times_from_actions(plan["actions"])
end


"""
    get_departure_times_from_actions(actions)

Collect the departures found in a list of plan actions.

An action counts as a departure when its `taskType` dictionary has `"predefined"` equal to
`"Exit"`. For each such action an `(end_time, shunting_unit_id)` tuple is recorded, with
both values parsed from strings to `Int`, in the order the actions appear.

Returns `(departures, actions)`; `actions` is passed through untouched.
"""
function get_departure_times_from_actions(actions)
    exits = Any[]
    for act in actions
        kind = get(act["taskType"], "predefined", -1)
        kind == "Exit" || continue
        push!(exits, (parse(Int, act["endTime"]), parse(Int, act["shuntingUnit"]["id"])))
    end
    return exits, actions
end


"""
    get_arrival_times_from_actions(actions)

Collect the arrivals found in a list of plan actions.

An action counts as an arrival when its `taskType` dictionary has `"predefined"` equal to
`"Arrive"`. For each such action an `(end_time, shunting_unit_id)` tuple is recorded, with
both values parsed from strings to `Int`, in the order the actions appear.

Returns `(arrivals, actions)`; `actions` is passed through untouched.
"""
function get_arrival_times_from_actions(actions)
    is_arrival(a) = get(a["taskType"], "predefined", -1) == "Arrive"

    arrivals = Any[]
    for a in Iterators.filter(is_arrival, actions)
        push!(arrivals, (parse(Int, a["endTime"]), parse(Int, a["shuntingUnit"]["id"])))
    end
    return arrivals, actions
end


"""
    get_departures_from_best_plan()

Get departures from the final plan if it was written. Otherwise use the most recent
temporary plan in `/workspace/algorithm-files/plans_temp_TS`.

Temporary plans are `.json` files whose name contains `temp_plan_<n>`; the one with the
largest `<n>` is used, and files that do not follow this pattern are ignored. For the
chosen plan, the companion `.txt` file (same path with `.json` replaced by `.txt`) is read
as the human-readable schedule.

Returns `(departure_times, actions, schedule)`. `schedule` starts with a `Final plan:` or
`Temp plan:` header line. When no usable plan exists, returns `([], [], "")`, so callers
can always destructure three values.
"""
function get_departures_from_best_plan()
    final_plan_path = "/workspace/algorithm-files/plans/plan.json"
    temp_plans_TS_path = "/workspace/algorithm-files/plans_temp_TS"
    # NOTE(review): plans_temp_SA is removed by remove_generated_plans but is never read
    # here. Confirm whether SA temp plans should also be considered as a fallback.

    # Final plan has been generated, use that
    if isfile(final_plan_path)
        println("***Departures from final plan")
        return _plan_departures_and_schedule(final_plan_path, "Final plan:\n")
    end

    # Planner crashed, no final plan: use the temp plan with the highest index
    if isdir(temp_plans_TS_path)
        println("***Departures from temp plan")
        best_index, best_path = -1, ""
        for path in readdir(temp_plans_TS_path; join=true)
            endswith(path, ".json") || continue
            m = match(r"temp_plan_(\d+)", basename(path))
            m === nothing && continue
            index = parse(Int, m.captures[1])
            # `>=` keeps the last file in directory order on ties, like a stable sort + last
            if index >= best_index
                best_index, best_path = index, path
            end
        end
        if best_index >= 0
            return _plan_departures_and_schedule(best_path, "Temp plan:\n")
        end
    end

    # No usable plan at all (previously this case could fall through and return `nothing`)
    return [], [], ""
end

# Parse the plan JSON at `plan_path`, extract its departures and append the companion
# `.txt` schedule to `header`.
function _plan_departures_and_schedule(plan_path, header)
    plan = JSON.parsefile(plan_path)
    departure_times, actions = get_departure_times_from_plan(plan)
    schedule_text = read(replace(plan_path, r"\.json$" => ".txt"), String)
    return departure_times, actions, header * schedule_text * "\n"
end


"""
    remove_generated_plans()

Delete the output of a previous solver run so stale plans cannot be picked up.

Removes `/workspace/algorithm-files/plans/plan.json` if it is a file, then recursively
removes the `plans_temp_TS` and `plans_temp_SA` directories under
`/workspace/algorithm-files` if they exist. Returns `nothing`.
"""
function remove_generated_plans()
    final_plan = "/workspace/algorithm-files/plans/plan.json"
    isfile(final_plan) && rm(final_plan)

    for temp_dir in ("/workspace/algorithm-files/plans_temp_TS",
                     "/workspace/algorithm-files/plans_temp_SA")
        isdir(temp_dir) && rm(temp_dir; recursive=true)
    end

    return nothing
end


"""
    run_solver(solve_config_path)

Run the robust-rail solver (`dotnet run`) with the solver config at `solve_config_path`.

Previously generated plans are removed first. The solver's stdout is written to
`/workspace/algorithm-files/solver_output.txt`. A solver failure is deliberately
tolerated: the error is reported and the best plan that was written anyway is used. An
`InterruptException` (Ctrl-C) is rethrown so the run can still be aborted.

The relative `cd` target assumes the current directory is `/workspace`; the `finally`
clause always returns there.

Returns `(departure_times, actions, output_schedule)` from `get_departures_from_best_plan`.
"""
function run_solver(solve_config_path)
    println("***Run solver")
    remove_generated_plans()

    try
        cd(raw"robust-rail-solver/ServiceSiteScheduling")
        command = `dotnet run -- --config=$(solve_config_path)`
        solver_output = "/workspace/algorithm-files/solver_output.txt"
        run(pipeline(command, solver_output))
        println("***Solver run completed")
    catch err
        # Let the user abort; everything else is a (possibly expected) solver crash.
        err isa InterruptException && rethrow()
        println("***Solver failed, using best available plan: ", sprint(showerror, err))
    finally
        cd("/workspace")
    end

    departure_times, actions, output_schedule = get_departures_from_best_plan()
    return departure_times, actions, output_schedule
end


"""
    run_generator(gen_config_path, scenario_name)

Generate a scenario by running `python create_scenario.py` from
`robust-rail-generator/src/scenario_generator`.

Input: path to the generation config (custom scenario), scenario name (`--scenario-file`).
Output: the scenario file is produced by the Python script. The working directory is
reset to `/workspace` afterwards, even if the script fails. Returns `nothing`.
"""
function run_generator(gen_config_path, scenario_name)
    generator_dir = raw"robust-rail-generator/src/scenario_generator"
    cd(generator_dir)
    println("***Generating scenario")
    try
        generate_cmd = `python create_scenario.py --config "$(gen_config_path)" --scenario-file "$(scenario_name)"`
        run(generate_cmd)
        println("***Scenario generated")
    finally
        cd("/workspace")
    end
    return nothing
end

"""
    run_evaluator(path_location, path_scenario, path_plan, format)

Check a location/scenario pair for correctness, or evaluate a plan against it.
Not used in the research.

Input: location path, scenario path, plan path (may be empty when only checking
correctness), plan format (`"HIP"` or `"TORS"`).
Output: the evaluator's report is written to
`/workspace/algorithm-files/evaluation_output.txt`. The working directory is reset to
`/workspace` afterwards. Returns `nothing`.
"""
function run_evaluator(path_location, path_scenario, path_plan, format)
    println("***Evaluating plan")
    report_file = "/workspace/algorithm-files/evaluation_output.txt"
    try
        cd(raw"robust-rail-evaluator")
        evaluator_cmd = `./build/TORS --mode "EVAL" --path_location "$(path_location)" --path_scenario "$(path_scenario)" --path_plan "$(path_plan)" --plan_type "$(format)"`
        run(pipeline(evaluator_cmd, report_file))
        println("***Plan evaluated")
    finally
        cd("/workspace")
    end
    return nothing
end


"""
    run_full_pipeline(gen_config_path, scenario_name, solve_config_path;
                      amount_of_tries=2, max_attempts=5)

Generate a scenario and solve it to a plan, keeping the best plan over several solver runs.

The solver is randomised, so the result of one run is compared with further runs:
1. the solver is run until it yields departures, at most `max_attempts` times;
2. it is then run `amount_of_tries - 1` more times, and the plan with the smallest sum of
   departure times is kept. A successful run always replaces an empty best, so a plan is
   still obtained if only the later runs succeed.

Input: path to generation config, scenario name, path to solver config.
Keywords: `amount_of_tries` (default 2) and `max_attempts` (default 5), the previously
hard-coded values.
Output: best departure times, actions of the plan, output schedule. If no run produced
departures, the (empty) result of the last run is returned.
"""
function run_full_pipeline(gen_config_path, scenario_name, solve_config_path;
                           amount_of_tries=2, max_attempts=5)
    run_generator(gen_config_path, scenario_name)

    # Quality metric: lower total departure time is considered better
    total_departure_time(times) = sum(first, times)

    best_departure_times = []
    best_actions = []
    output_schedule = ""

    # Try to get a working plan; if this keeps failing the scenario is unsolvable for the
    # solver (Error: "No feasible matching possible. Unmatched arrivals = ....")
    attempts = 0
    while isempty(best_departure_times) && attempts < max_attempts
        attempts += 1
        best_departure_times, best_actions, output_schedule = run_solver(solve_config_path)
    end

    # Run multiple times and keep the best plan to reduce the effect of solver randomness
    for _ in 1:(amount_of_tries - 1)
        departure_times, actions, candidate_schedule = run_solver(solve_config_path)
        isempty(departure_times) && continue

        if isempty(best_departure_times) ||
           total_departure_time(departure_times) < total_departure_time(best_departure_times)
            best_departure_times = departure_times
            best_actions = actions
            output_schedule = candidate_schedule
        end
    end

    return best_departure_times, best_actions, output_schedule
end


"""
    get_arrival_times_from_config(config_dict)

Get arrival times from the generation config.

Reads `config_dict["custom_trains"]` and returns, in config order, the `"arrival_time"` of
every train that has one. Trains without the key are skipped. Because of this, a value
of `-1` is indistinguishable from a missing key.
"""
function get_arrival_times_from_config(config_dict)
    arrival_times = []

    for train in config_dict["custom_trains"]
        arrival_time = get(train, "arrival_time", -1)
        arrival_time != -1 && push!(arrival_times, arrival_time)
    end

    return arrival_times
end


"""
    get_departure_times_from_config(config_dict)

Get departure times from the generation config.

Walks `config_dict["custom_trains"]` in order and, for every train with a
`"departure_time"`, records the tuple `(departure_time, id)`. `id` falls back to `-1`
when the train has no `"id"`. Trains without a departure time are skipped.

input: config dictionary
return: array of `(departure_time, id)` tuples
"""
function get_departure_times_from_config(config_dict)
    departures = Any[]
    for train in config_dict["custom_trains"]
        dep = get(train, "departure_time", -1)
        dep == -1 && continue
        push!(departures, (dep, get(train, "id", -1)))
    end
    return departures
end


"""
    get_cleaning_times_from_config(config_dict)

Get cleaning times from the generation config.

Walks `config_dict["custom_servicing_tasks"]` in order and keeps the `"duration"` of every
task whose `"type"` is `"Reinigingsperron"`. Tasks of other types, and cleaning tasks
without a duration, are skipped.

return: array with the cleaning durations
"""
function get_cleaning_times_from_config(config_dict)
    durations = Any[]
    for task in config_dict["custom_servicing_tasks"]
        get(task, "type", "") == "Reinigingsperron" || continue
        d = get(task, "duration", -1)
        d != -1 && push!(durations, d)
    end
    return durations
end


"""
    create_random_arrival_time_config_dict(config_dict, random_arrivals, c_path, new_config_file_name)

Create a new config file with random arrival times.

`random_arrivals` must be aligned with `get_arrival_times_from_config(config_dict)`: one
value per train that has an `"arrival_time"`, in config order. Those trains get the new
values; trains without an arrival time are left untouched.

Input: config dictionary, the generated random arrival times, path to the original config
file (currently unused), path of the new config file to write.
Output: a deep copy of `config_dict` with the random arrival times, also written as JSON
to `new_config_file_name`. `config_dict` itself is not modified.
"""
function create_random_arrival_time_config_dict(config_dict, random_arrivals, c_path, new_config_file_name)
    random_config_dict = deepcopy(config_dict)

    # Use a separate counter: indexing `random_arrivals` by the train's position would
    # misalign (or go out of bounds) as soon as a train without an arrival time precedes
    # one that has it.
    next_arrival = 1
    for train in random_config_dict["custom_trains"]
        if get(train, "arrival_time", -1) != -1
            train["arrival_time"] = random_arrivals[next_arrival]
            next_arrival += 1
        end
    end

    open(new_config_file_name, "w") do file
        JSON.print(file, random_config_dict)
    end

    return random_config_dict
end


"""
    create_random_cleaning_time_config_dict(config_dict, random_cleaning, c_path, new_config_file_name)

Create a new config file with random cleaning times.

`random_cleaning` must be aligned with `get_cleaning_times_from_config(config_dict)`: one
value per servicing task of type `"Reinigingsperron"` that has a `"duration"`, in config
order. Only those tasks are updated; other servicing tasks keep their duration.

Input: config dictionary, the generated random cleaning times, path to the original config
file (currently unused), path of the new config file to write.
Output: a deep copy of `config_dict` with the random cleaning times, also written as JSON
to `new_config_file_name`. `config_dict` itself is not modified.
"""
function create_random_cleaning_time_config_dict(config_dict, random_cleaning, c_path, new_config_file_name)
    random_config_dict = deepcopy(config_dict)

    # Select tasks exactly like get_cleaning_times_from_config and walk them with a
    # separate counter. Indexing by task position would misalign as soon as a non-cleaning
    # task or a task without a duration precedes a cleaning task.
    next_value = 1
    for task in random_config_dict["custom_servicing_tasks"]
        if get(task, "type", "") == "Reinigingsperron" && get(task, "duration", -1) != -1
            task["duration"] = random_cleaning[next_value]
            next_value += 1
        end
    end

    open(new_config_file_name, "w") do file
        JSON.print(file, random_config_dict)
    end

    return random_config_dict
end


"""
Generates random values from a normal distribution centered around the original values

For each `values[i]`, a sample is drawn from `normal(values[i], 50)` (standard deviation
50, in the same units as the values) and traced at address `(:random_value, i)`. The
samples are then rounded to the nearest `Int64`. The trace therefore records the
unrounded float, while the returned values are rounded.

input: array of values
output: array of randomized values (`Int64`)
"""
@gen function get_random_values_from_normal(values)
    # Same fixed standard deviation for every value
    STD = repeat([50], length(values))

    random_values = []
    for (i, value) in enumerate(values)
        # Traced random choice, addressed by the value's position
        random_value = ({(:random_value, i)} ~ normal(value, STD[i]))
        push!(random_values, random_value)
    end

    # Rounding happens after tracing; it does not affect the recorded choices
    random_values = map(x -> round(Int64, x), random_values)

    return random_values
end


"""
calculate the delays
input: scheduled departure times, actual departure times
output: array of delays

`scheduled_departures` holds `(time, train_id)` tuples and `departure_times` holds
`(time, unit_id)` tuples. Entries are paired by position via `zip`, so both arrays are
assumed to be in matching order. Surplus entries in the longer array are ignored.

For each pair, the absolute time difference is used as the mean of a normal distribution
with standard deviation 0.1. The sample is traced at address `(:delays, train_id)` so
that observations can constrain it. The returned array holds the sampled delays, not
the exact differences.

NOTE(review): if fewer departures than scheduled are produced (e.g. a failed plan), some
`(:delays, train_id)` addresses are never visited. Confirm how Gen treats such unvisited
observations during importance sampling.
"""
@gen function calculate_delays(scheduled_departures, departure_times)
    delays = []

    for (i, (scheduled, actual)) in enumerate(zip(scheduled_departures, departure_times))
        delay_val = abs(actual[1] - scheduled[1])
        # Sample the delays from a normal distribution with a very small spread (std 0.1) so that they can be traced by Gen
        delay = ({(:delays, scheduled[2])} ~ normal(delay_val, 1e-1))
        push!(delays, delay)
    end


    return delays
end


"""
Generate a scenario for the planner

Generative model used for importance sampling:
1. parse the generation config at `scenario_config_path` and read its scheduled
   departures as `(departure_time, id)` tuples;
2. perturb the arrival times with `get_random_values_from_normal` and write the perturbed
   config to `<config path without ".json">_uncertain.json`;
3. generate and solve the perturbed scenario with `run_full_pipeline`;
4. trace the delays of the resulting plan with `calculate_delays`, at addresses
   `(:delays, id)`.

Both sub-functions are traced without an address. Their choices presumably appear
directly in this trace; the observations in `perform_inference` (addresses
`(:delays, id)`) rely on that. Confirm with the Gen version in use.

Returns `(scheduled_departures, departure_times, actions, output_schedule)`.
"""
@gen function generate_scenario_and_run_planner(scenario_config_path)
    # Parse the generation config and read the scheduled departures (departure_time, id)
    config_dict = JSON.parsefile(scenario_config_path)
    scheduled_departures = get_departure_times_from_config(config_dict)


    # Choose an uncertainty #################################################

    # Create new config file with random arrival times
    arrivals = get_arrival_times_from_config(config_dict)
    random_arrivals = @trace(get_random_values_from_normal(arrivals))
    # Drops the last 5 characters, i.e. assumes the config path ends in ".json"
    scenario_random_config_path = "$(scenario_config_path[1:end-5])_uncertain.json"
    random_config_dict = create_random_arrival_time_config_dict(config_dict, random_arrivals, scenario_config_path, scenario_random_config_path)

    # Uncomment to create new config file with random cleaning times
    # cleaning_times = get_cleaning_times_from_config(config_dict)
    # random_cleaning = @trace(get_random_values_from_normal(cleaning_times))
    # scenario_random_config_path = "$(scenario_config_path[1:end-5])_uncertain.json"
    # random_config_dict = create_random_cleaning_time_config_dict(config_dict, random_cleaning, scenario_config_path, scenario_random_config_path)

    #########################################################################

    
    # Solver config is fixed; generate + solve the perturbed scenario
    solve_config_path = "/workspace/algorithm-files/solver-config.yaml"
    departure_times, actions, output_schedule = run_full_pipeline(scenario_random_config_path, "scenario_name", solve_config_path)

    # Assume the order of departure times matches the order of scheduled departures in the arrays, the IDs are different.
    # `delays` itself is unused; the call matters for the (:delays, id) choices it traces.
    delays = @trace(calculate_delays(scheduled_departures, departure_times))

    return (scheduled_departures, departure_times, actions, output_schedule)
end


"""
    perform_inference(scenario_config_path, iterations; observed_delay=0)

Perform importance sampling over `generate_scenario_and_run_planner`.

Each scheduled departure `(departure_time, id)` in the config is conditioned on a delay of
`observed_delay` at address `(:delays, id)`, the address `calculate_delays` traces.

This is a plain function: it makes no random choices of its own. It was previously
declared `@gen`, which only added overhead when called directly.

input: path to scenario config file, number of importance samples (`iterations`)
output: traces, log normalised weights, and log marginal likelihood (LML) estimate
"""
function perform_inference(scenario_config_path, iterations; observed_delay=0)
    config_dict = JSON.parsefile(scenario_config_path)
    scheduled_departures = get_departure_times_from_config(config_dict)

    observations = Gen.choicemap()
    for (_, id) in scheduled_departures
        observations[(:delays, id)] = observed_delay
    end

    traces, log_norm_weights, lml_est = Gen.importance_sampling(
        generate_scenario_and_run_planner, (scenario_config_path,), observations, iterations)

    return traces, log_norm_weights, lml_est
end
    
"""
    main()

Run the uncertainty experiment end to end and write a report.

1. Solve the original scenario (no uncertainty) once as a baseline.
2. Run importance sampling (`perform_inference`) over scenarios with perturbed arrivals.
3. For every sample with a usable plan, compute the per-train delays
   (actual departure minus scheduled departure) and write everything to
   `/workspace/algorithm-files/inference_samples_output.txt`.

A sample is skipped if its plan failed (no departures, no arrivals, or an empty schedule).
It is also skipped if its number of departures differs from the number of scheduled
departures. The actions, schedule and log-normalised weight of a skipped sample are
dropped too, so all per-iteration arrays in the report stay index-aligned.
"""
function main()
    # If you change scenario, also change the Location path in the solver config file (/workspace/algorithm-files/solver-config.yaml)
    # scenario_config_path = "/workspace/algorithm-files/generation-config/three-three-configuration.json"
    # scenario_config_path = "/workspace/algorithm-files/generation-config/simple-configuration.json"
    scenario_config_path = "/workspace/algorithm-files/generation-config/cleaning_configuration.json"
    solve_config_path = "/workspace/algorithm-files/solver-config.yaml"

    ITERATIONS = 15

    # Run solver on the original scenario without uncertainty
    departure_times_original, _, output_schedule_original = run_full_pipeline(scenario_config_path, "scenario_name", solve_config_path)

    # Perform inference with importance sampling
    traces, log_norm_weights, lml_est = perform_inference(scenario_config_path, ITERATIONS)

    # Parse the original generation config once and read the scheduled times from it
    config_dict = JSON.parsefile(scenario_config_path)
    scheduled_arrival_times = get_arrival_times_from_config(config_dict)
    scheduled_departure_times = map(x -> x[1], get_departure_times_from_config(config_dict))

    arrival_times_per_iteration = []
    departure_times_per_iteration = []
    delays_per_iteration = []
    actions_per_iteration = []
    output_schedule_per_iteration = []
    summed_delays = []
    unsolved_plans = []

    for (i, t) in enumerate(traces)
        (_, departure_times, actions, output_schedule) = Gen.get_retval(t)

        arrivals, _ = get_arrival_times_from_actions(actions)
        arrivals = map(x -> x[1], arrivals)
        departure_times = map(x -> x[1], departure_times)

        # Skip if a plan has failed (Error: "No feasible matching possible. Unmatched arrivals = ....")
        if isempty(departure_times) || isempty(arrivals) || output_schedule == ""
            push!(unsolved_plans, i)
            continue
        end

        # Element-wise delays need exactly one departure per scheduled departure
        if length(departure_times) != length(scheduled_departure_times)
            println("***Skipping sample $i: $(length(departure_times)) departures, expected $(length(scheduled_departure_times))")
            push!(unsolved_plans, i)
            continue
        end

        # Assumes departures appear in the same order as the scheduled departures
        delays = departure_times - scheduled_departure_times

        push!(summed_delays, sum(delays))
        push!(arrival_times_per_iteration, arrivals)
        push!(departure_times_per_iteration, departure_times)
        push!(delays_per_iteration, delays)
        # Recorded only for solved samples so they stay aligned with the arrays above
        push!(actions_per_iteration, actions)
        push!(output_schedule_per_iteration, output_schedule)
    end

    # Remove unsolved plans from log weights array.
    # NOTE: the remaining weights are not renormalised over the kept samples.
    log_norm_weights = log_norm_weights[setdiff(1:length(log_norm_weights), unsolved_plans)]

    open("/workspace/algorithm-files/inference_samples_output.txt", "w") do file
        write(file, "Scheduled arrival times: $(scheduled_arrival_times)\n")
        write(file, "Scheduled departure times: $(scheduled_departure_times)\n")
        write(file, "Original scenario schedule: \n$(output_schedule_original)\n")
        write(file, "Departure times from original scenario: $(departure_times_original)\n\n\n\n")

        write(file, "Summed delays: $summed_delays\n")
        write(file, "All log norm weights: $(log_norm_weights)\n")
        write(file, "LML estimate: $(lml_est)\n\n")

        for i in eachindex(arrival_times_per_iteration)
            write(file, "Iteration $i\n")
            write(file, "Arrival times: $(arrival_times_per_iteration[i])\n")
            write(file, "Departure times: $(departure_times_per_iteration[i])\n")
            write(file, "Delays: $(delays_per_iteration[i])\n")
            write(file, "Log norm weights: $(log_norm_weights[i])\n")
            write(file, "Output schedule: \n$(output_schedule_per_iteration[i])\n\n")
            # write(file, "Actions: $(actions_per_iteration[i])\n\n\n")
        end
    end

    return nothing
end


# Script entry point: loading this file immediately runs the whole experiment.
main()