using SpinShuttling
import StatsBase

include("../data_io.jl")
# Relative directory passed as `dir` to every `save`/`load` call in this script.
data_dir = "./data/error_data/"
import Statistics: mean, std, var

# function fidelity2phase(f::Real)::Complex
#     a=2*f-1
#     b=sqrt(1 - a^2)
#     return a+sign(rand()-1/2)*b*im
# end

"""
    objfunc(model::ShuttlingModel, randseq::Vector{<:Real}; isarray::Bool=false)::Number

Return one Monte-Carlo sample of the phase factor `exp(im * ϕ)` for `model`.

`randseq` is passed to `model.R` to obtain the noise sequence `A`. In this script it is
`randn(N)` for one spin and `randn(2N)` for two spins. `model.R` presumably maps i.i.d.
normal draws to correlated field samples; confirm in SpinShuttling. The accumulated
phase is the rectangle sum `ϕ = dt * Σ Z` with `dt = model.T / model.N`, where

- `model.n == 1`: `Z = A`;
- `model.n == 2`: `Z = A[1:N] - A[N+1:end]`. This is only valid for the two-spin EPR pair
  ψ = 1/√2(|↑↓⟩ - |↓↑⟩), whose phase is the difference of the two spins' phases.

# Keywords
- `isarray`: currently unused; accepted only so existing call sites keep working.

# Throws
- `ArgumentError` if `model.n` is neither 1 nor 2.
"""
function objfunc(model::ShuttlingModel, randseq::Vector{<:Real}; isarray::Bool=false)::Number
    N = model.N
    dt = model.T / N
    A = model.R(randseq)
    if model.n == 1
        Z = A
    elseif model.n == 2
        # only valid for two-spin EPR pair, ψ=1/√2(|↑↓⟩-|↓↑⟩);
        # views avoid copying the two halves before subtracting them
        Z = @views A[1:N] .- A[N+1:end]
    else
        # previously `Z = missing`, which only failed later with an opaque MethodError
        throw(ArgumentError("objfunc supports only n = 1 or n = 2 spins, got n = $(model.n)"))
    end
    ϕ = sum(Z) * dt
    # cis(ϕ) == exp(im * ϕ) for real ϕ
    return cis(ϕ)
end

"""
    samplingerror(M_arr::Vector{Int}, sample::Vector{<:Number}, f_exp::Number)::DataFrame

Measure how a Monte-Carlo estimate converges with the number of samples `M`.

`f_exp` is converted to `W = 2 * f_exp - 1`, the value the sample mean is compared
against. The callers in this script pass `f_exp = (1 + W(...)) / 2`. For every `M` in
`M_arr`, `M` values are drawn from the full pool `sample` with `StatsBase.sample`,
which draws with replacement by default, and their statistics are compared with `W`.

# Returns
A `DataFrame` with one row per entry of `M_arr` and the following columns:
- `M`: number of drawn values.
- `epsilon`: relative error of the mean, `abs(mean(x) - W) / W`.
- `sigma`: estimated relative standard error, `sqrt(var(x) / M) / W`.
- `delta`: mean squared deviation from `W`, `mean(abs2(xᵢ - W))`. This is real for
  complex samples as well.
- `cpu_time`: elapsed seconds reported by `@timed` for drawing and evaluating the row.
- `ram_bytes`: bytes allocated in that same span, as reported by `@timed`.
"""
function samplingerror(M_arr::Vector{Int}, sample::Vector{<:Number}, f_exp::Number)::DataFrame
    println("scanning parameter: M= ", minimum(M_arr), " ~ ", maximum(M_arr))
    data = Vector{NamedTuple}(undef, length(M_arr))
    # convert the fidelity-style prediction (1+W)/2 into the expected sample mean W
    W_exp = 2 * f_exp - 1

    @showprogress for i in eachindex(M_arr)
        M = M_arr[i]
        counting = @timed begin
            # Draw from the ORIGINAL pool every iteration. Binding the draw to `sample`
            # would overwrite the argument, so each later M would resample from the
            # previous draw instead of the full pool.
            draw = StatsBase.sample(sample, M)
            (
                abs(mean(draw) - W_exp) / W_exp,
                sqrt(var(draw) / M) / W_exp,
                # abs2 keeps this a real mean squared error for complex samples,
                # consistent with abs/var above; identical to (x - W)^2 for real x
                mean(x -> abs2(x - W_exp), draw),
            )
        end
        data[i] = (M=M,
            epsilon=counting.value[1], sigma=counting.value[2], delta=counting.value[3],
            cpu_time=counting.time, ram_bytes=counting.bytes)
    end
    return DataFrame(data)
end


begin "running!"
    # Noise field shared by all runs. meta_info records σ as `sigma`, κₜ as `corr_t` and
    # κₓ as `corr_x`. The leading 0 is presumably the field mean; confirm against
    # SpinShuttling's OrnsteinUhlenbeckField.
    σ =2; κₜ=1; κₓ=1;
    B=OrnsteinUhlenbeckField(0,[κₜ,κₓ],σ)
    # Sample sizes to scan: m points evenly spaced in 1/√M, mapped back to integer M and
    # de-duplicated. M_arr therefore runs from M_max down to M_min.
    M_max=10^6; M_min=10^3; m=1000;
    # M_max=10^7; M_min=10^4; m=1000;
    M_arr = unique(round.(Int,  (1 ./(range(1/sqrt(M_max), 1/sqrt(M_min), m))).^2))
    ##
    # One spin, one-way shuttling.
    let T=2, L=10,  N=501;
        meta_info=(description="Sampling Error of One Spin One-way Shuttling", sigma=σ, T=T, L=L, N=N, corr_t=κₜ, corr_x=κₓ)
        save(meta_info, "SmpError_S1_OW_OU_MC.json", dir=data_dir)
        println(meta_info)
        
        model=OneSpinModel(T, L, N, B);
        # Enable this block to regenerate the sample pool instead of loading it. The
        # loaded file name below carries a prefix that `save` presumably adds; confirm
        # in data_io.jl and update the name after regenerating.
        # sample=zeros(ComplexF64, M_max)
        # @showprogress for i in 1:M_max
        #     sample[i]=objfunc(model, randn(N))
        # end
        # save(DataFrame(:vals => sample), "SmpVals_S1_OW_OU_MC.csv", dir=data_dir)
        # println("sample values saved!")
        # Load the pre-computed pool; values are read back as strings and parsed.
        sample=load("Ben02_20240712_SmpVals_S1_OW_OU_MC.csv", dir=data_dir)[!, :vals]
        sample=parse.(ComplexF64, sample)
        # Analytic prediction. W itself is printed; samplingerror receives (1+W)/2.
        W_pred=W(T, L, B)
        f_exp=(1+W_pred)/2;
        println("prediction: W=", W_pred)
        ##
        save(samplingerror(M_arr,sample,f_exp), "SmpError_S1_OW_OU_MC.csv", dir=data_dir)
        println("sampling error data saved!")
    end 

    ##
    # One spin, forth-back shuttling.
    let T=2, L=5,  N=501;
        meta_info=(description="Sampling Error of One Spin Forth-back Shuttling", sigma=σ, T=T, L=L, N=N, corr_t=κₜ, corr_x=κₓ)
        save(meta_info, "SmpError_S1_FB_OU_MC.json", dir=data_dir)
        println(meta_info)

        model=OneSpinForthBackModel(T, L, N, B);
        # Enable this block to regenerate the sample pool instead of loading it
        # (update the loaded file name below accordingly).
        # sample=zeros(ComplexF64, M_max)
        # @showprogress for i in 1:M_max
        #     sample[i]=objfunc(model, randn(N))
        # end
        # save(DataFrame(:vals => sample), "SmpVals_S1_FB_OU_MC.csv", dir=data_dir)
        # println("sample values saved!")
        # Load the pre-computed pool; values are read back as strings and parsed.
        sample=load("Ben02_20240712_SmpVals_S1_FB_OU_MC.csv", dir=data_dir)[!, :vals]
        sample=parse.(ComplexF64, sample)
        # Analytic prediction. W itself is printed; samplingerror receives (1+W)/2.
        W_pred=W(T, L, B; path=:forthback)
        f_exp=(1+W_pred)/2;
        println("prediction: W=", W_pred)
        ##
        save(samplingerror(M_arr,sample,f_exp), "SmpError_S1_FB_OU_MC.csv", dir=data_dir)
        println("sampling error data saved!")
    end 

    ##
    # Two spins, sequential shuttling: here the pool is generated (2N normal draws per
    # sample) and saved rather than loaded.
    let T0=1, T1=1.5, L=10, N=501;
        meta_info=(description="Sampling Error of Two Spin Sequential Shuttling", 
        sigma=σ, T0=T0, T1=T1, L=L, N=N, corr_t=κₜ, corr_x=κₓ)
        save(meta_info, "SmpError_S2_SQ_OU_MC.json", dir=data_dir)
        println(meta_info)

        model=TwoSpinSequentialModel(T0, T1, L, N, B);
        sample=zeros(ComplexF64, M_max)
        @showprogress for i in 1:M_max
            sample[i]=objfunc(model, randn(2*N))
        end
        save(DataFrame(:vals => sample), "SmpVals_S2_SQ_OU_MC.csv", dir=data_dir)
        println("sample values saved!")
        # sample=load("Ben02_20240707_SmpVals_S2_SQ_OU_MC.csv", dir=data_dir)[!, :vals]
        # sample=parse.(ComplexF64, sample)
        # Analytic prediction. W itself is printed; samplingerror receives (1+W)/2.
        W_pred=W(T0,T1,L,B)
        f_exp=(1+W_pred)/2;
        println("prediction: W=", W_pred)

        save(samplingerror(M_arr,sample, f_exp), "SmpError_S2_SQ_OU_MC.csv", dir=data_dir)
        println("sampling error data saved!")
    end
end