using CounterfactualExplanations
using CounterfactualExplanations.Benchmark
using CounterfactualExplanations.Generators
using CounterfactualExplanations.Data
using CounterfactualExplanations.Models
using DataFrames
using Flux
using LinearAlgebra
using MLUtils
using Random
# Upper bound used in the "Trivial case" test below: the largest absolute difference
# between the factual (`counterfactual.x`) and the decoded counterfactual state must
# stay strictly below this value.
init_perturbation = 2.0

# NOTE:
# This is probably the most important/useful test script, because it runs through the whole process of: 
# - loading artifacts
# - setting up counterfactual search for various models and generators
# - running counterfactual search

# LOOP:
# Runs the counterfactual workflow for every generator × synthetic dataset × model
# combination. `generators` and `synthetic` are not defined in this script and must
# already be in scope (presumably set up by the test runner — TODO confirm).
for (generator_key, generator_) ∈ generators
    @testset "$(uppercasefirst(string(generator_key)))" begin
        @testset "Models for synthetic data" begin
            for (dataset_key, dataset) ∈ synthetic
                @testset "$(string(dataset_key))" begin
                    counterfactual_data = dataset[:data]
                    X = counterfactual_data.X

                    for (likelihood, model) ∈ dataset[:models]
                        @testset "$(string(likelihood))" begin
                            M = model[:model]

                            # Fresh generator per model: the convergence tests below
                            # mutate `generator.decision_threshold`, and that change must
                            # not leak into the searches of later models/datasets.
                            generator = generator_()

                            # Randomly selected factual (seeded for reproducibility):
                            Random.seed!(123)
                            x = select_factual(counterfactual_data, rand(1:size(X, 2)))
                            multiple_x = select_factual(
                                counterfactual_data, rand(1:size(X, 2), 5)
                            )

                            # Choose target based on the predicted label of the factual:
                            y = predict_label(M, counterfactual_data, x)
                            target = get_target(counterfactual_data, y[1])

                            # Single sample:
                            counterfactual = generate_counterfactual(
                                x, target, counterfactual_data, M, generator
                            )
                            # Multiple samples:
                            counterfactuals = generate_counterfactual(
                                multiple_x, target, counterfactual_data, M, generator
                            )

                            @testset "Predetermined outputs" begin
                                if generator isa Generators.AbstractLatentSpaceGenerator
                                    @test counterfactual.latent_space
                                end
                                @test counterfactual.target == target
                                @test counterfactual.x == x &&
                                      CounterfactualExplanations.factual(counterfactual) ==
                                      x
                                @test CounterfactualExplanations.factual_label(
                                    counterfactual
                                ) == y
                                @test CounterfactualExplanations.factual_probability(
                                    counterfactual
                                ) == probs(M, x)
                            end

                            @testset "Benchmark" begin
                                @test isa(benchmark(counterfactual), DataFrame)
                                @test isa(
                                    benchmark(counterfactuals; to_dataframe = false),
                                    Dict,
                                )
                            end

                            @testset "Convergence" begin

                                @testset "Non-trivial case" begin
                                    # Drop any generative model stored on the data
                                    # (presumably so it is rebuilt if needed — TODO confirm).
                                    counterfactual_data.generative_model = nothing
                                    # Threshold reached if converged:
                                    γ = 0.9
                                    generator.decision_threshold = γ
                                    # `T` is passed on as keyword `T`; the path length is
                                    # checked against it below.
                                    T = 1000
                                    counterfactual = generate_counterfactual(
                                        x,
                                        target,
                                        counterfactual_data,
                                        M,
                                        generator;
                                        T = T,
                                    )
                                    # NOTE(review): this name is not used anywhere in the
                                    # visible script; kept in case other test files rely on
                                    # it being in scope — confirm before removing.
                                    using CounterfactualExplanations:
                                        counterfactual_probability
                                    # Either not converged or threshold reached:
                                    @test !converged(counterfactual) ||
                                          target_probs(counterfactual)[1] >= γ
                                    @test !converged(counterfactual) ||
                                          length(path(counterfactual)) <= T
                                end

                                @testset "Trivial case (already in target class)" begin
                                    counterfactual_data.generative_model = nothing
                                    # Already in target and exceeding threshold probability:
                                    # the target is the label the model already predicts.
                                    y_factual = predict_label(M, counterfactual_data, x)
                                    trivial_target = y_factual[1]
                                    n_levels = length(counterfactual_data.y_levels)
                                    generator.decision_threshold = min(1 / n_levels, 0.5)
                                    counterfactual = generate_counterfactual(
                                        x,
                                        trivial_target,
                                        counterfactual_data,
                                        M,
                                        generator,
                                    )
                                    @test length(path(counterfactual)) == 1
                                    # Decoded state must stay within `init_perturbation`
                                    # of the factual in every coordinate.
                                    @test maximum(
                                        abs,
                                        counterfactual.x .-
                                        CounterfactualExplanations.decode_state(
                                            counterfactual
                                        ),
                                    ) < init_perturbation
                                    @test converged(counterfactual)
                                    @test CounterfactualExplanations.terminated(
                                        counterfactual
                                    )
                                    @test CounterfactualExplanations.total_steps(
                                        counterfactual
                                    ) == 0
                                end
                            end
                        end
                    end
                end
            end
        end
    end
end
