#= 
This Julia script finds the ground-state energy of the transverse-field Ising model
(TFIM) for a range of external field strengths. It then converts the MPS
representation of each ground-state wavefunction into a vector in the computational
basis (inefficient) so that the magic can be computed in the same fashion as the
results from the other methods.

Note: in the loop around line 130, the conversion from MPS to vector is hard-coded
for a specific number of qubits (8 as written; a commented-out line covers 12). This
is a shortcut taken by the author, not a requirement of the method.

Before running for a given number of qubits N, create the output directories
`./outputs/<N>_qubits/` and `./outputs/<N>_qubits/wavefunctions/`; the script writes
its results there but does not create them.

Special thanks to https://gist.github.com/hershsingh/e6bd6c1fab997530c1c0e31e0a1c87d5
=#

using ITensors, ITensorMPS, DelimitedFiles, Printf
include("utils.jl")

# DMRG sweep schedule shared by every call to `dmrg_ising`: at most 100 sweeps.
# The first row names the columns (maximum bond dimension, minimum bond dimension,
# truncation cutoff, noise strength); each following row configures one sweep in
# order, ramping the bond dimension up and the noise down to 0 by the seventh sweep.
# NOTE(review): presumably sweeps past the last row reuse that row's values
# (ITensors `Sweeps(nsweep, table)` semantics) — confirm against the ITensors docs.
# DMRG may also stop before sweep 100 when the `DemoObserver` created in
# `dmrg_ising` reports that the relative energy change fell below its tolerance.
sweeps = Sweeps(100, [
    "maxdim" "mindim" "cutoff" "noise"
    10       10       1e-12    1E-7
    10       10       1e-12    1E-7
    20       10       1e-12    1E-7
    50       10       1e-12    1E-7
    100      20       1e-12    1E-8
    200      20       1e-12    1E-10
    200      20       1e-12    0
])


"""
    DemoObserver(energy_tol=0.0)

DMRG observer that tells the DMRG algorithm when to stop early.

Its `checkdone!` method records, after each sweep, the sweep number and the
relative energy change, and signals a stop once that change drops below
`energy_tol`.
"""
mutable struct DemoObserver <: AbstractObserver
    energy_tol::Float64          # stop threshold on the relative energy change
    last_energy::Float64         # energy after the previous sweep (1000.0 before any sweep)
    num_sweeps::Int64            # number of the most recently completed sweep
    energy_tol_current::Float64  # relative energy change measured at that sweep

    function DemoObserver(energy_tol=0.0)
        initial_energy = 1000.0
        return new(energy_tol, initial_energy, 0, 0.0)
    end
end

# Early-stopping hook for DMRG (presumably queried by ITensors once per sweep).
# Reads the `sweep` and `energy` keywords, stores the relative energy change
# |E - E_prev| / |E| plus the sweep number on the observer, and returns `true`
# (stop) as soon as that change is below `o.energy_tol`.
function ITensors.checkdone!(o::DemoObserver; kwargs...)
    sweep_number = kwargs[:sweep]
    current_energy = kwargs[:energy]

    o.energy_tol_current = abs(current_energy - o.last_energy) / abs(current_energy)
    o.last_energy = current_energy
    o.num_sweeps = sweep_number

    converged = o.energy_tol_current < o.energy_tol
    converged && println("Stopping DMRG after sweep $sweep_number")
    return converged
end

"""
    dmrg_ising(N, h, sites; energy_tol=1e-12, init_linkdims=10, sweep_schedule=sweeps)

Find the ground state of the periodic transverse-field Ising chain

    H = -Σᵢ Zᵢ Zᵢ₊₁ - h Σᵢ Xᵢ      (i = 1..N, with Z_{N+1} ≡ Z_1)

with DMRG, starting from a random MPS.

# Arguments
- `N`: number of sites; must equal `length(sites)`.
- `h`: transverse-field strength.
- `sites`: site indices the Hamiltonian and MPS are built on.

# Keywords
- `energy_tol`: DMRG stops early once the relative energy change between
  consecutive sweeps is below this value (via `DemoObserver`).
- `init_linkdims`: bond dimension of the random initial MPS.
- `sweep_schedule`: sweep schedule passed to `dmrg`; defaults to the global `sweeps`.

# Returns
`(energy, num_sweeps, energy_tol_current, psi)`: the DMRG energy, the number of the
last sweep performed, the relative energy change at that sweep, and the final MPS.

# Throws
`ArgumentError` if `length(sites) != N`.
"""
function dmrg_ising(N::Integer, h::Real, sites;
                    energy_tol::Real=1e-12, init_linkdims::Integer=10,
                    sweep_schedule=sweeps)
    length(sites) == N ||
        throw(ArgumentError("expected $N site indices, got $(length(sites))"))

    println("="^80)
    println("DMRG for Ising model")
    println("N = $N, h = $h")
    println("-"^20)

    os = OpSum()
    for i in 1:N-1
        os += -1.0, "Z", i, "Z", i+1
        os += -h, "X", i
    end
    os += -h, "X", N
    # Periodic boundary condition: couple the last site back to the first.
    os += -1.0, "Z", N, "Z", 1
    H = MPO(os, sites)

    psi0 = randomMPS(sites; linkdims=init_linkdims)

    observer = DemoObserver(energy_tol)
    energy, psi = dmrg(H, psi0, sweep_schedule; observer=observer)

    return energy, observer.num_sweeps, observer.energy_tol_current, psi
end

"""
    from_strings_to_MPO(N, all_Pauli_strings, sites)

Convert each Pauli string into the MPO of the corresponding tensor-product operator.
Character `k` of a string (e.g. `"XZIY"`) is the single-qubit operator acting on
qubit `k`; the resulting operator is P₁ ⊗ P₂ ⊗ … (coefficient 1.0).

# Arguments
- `N`: number of qubits (kept for interface compatibility; not used in the body).
- `all_Pauli_strings`: vector of Pauli strings (any iterable of characters/strings).
- `sites`: site indices the MPOs are built on.

# Returns
A `Vector{Any}` with one `MPO` per input string, in the same order.

NOTE(review): identity factors are passed through as the op name `"I"`; presumably
ITensors resolves it to the identity for `"S=1/2"` sites — confirm.
"""
function from_strings_to_MPO(N::Integer, all_Pauli_strings::AbstractVector, sites)
    return Any[_pauli_string_to_mpo(pauli, sites) for pauli in all_Pauli_strings]
end

# Build the MPO for one Pauli string as a SINGLE product term (1.0, P₁, 1, P₂, 2, …).
# Adding each factor as its own term (`os += gate, qubit`) would instead produce the
# sum P₁ + P₂ + …, which is not the operator a Pauli string denotes.
function _pauli_string_to_mpo(pauli, sites)
    term = Any[1.0]
    for (qubit, gate) in enumerate(pauli)
        push!(term, string(gate), qubit)
    end
    os = OpSum()
    os += Tuple(term)
    return MPO(os, sites)
end

# Field strengths to scan and the chain length.
h_arr = collect(0.0:0.25:3.0)
N = 8
sites = siteinds("S=1/2", N)

# eigenvalues[k]  = [h, ground-state energy]
# eigenvectors[k] = [h, ground-state MPS]  (consumed by the conversion code below)
eigenvalues = Any[]
eigenvectors = Any[]

# Create the output directory up front: otherwise a missing directory only makes
# `writedlm` fail after every (expensive) DMRG run has already finished.
output_dir = "./outputs/$(N)_qubits"
mkpath(output_dir)

for h in h_arr
    energy, _, _, psi = dmrg_ising(N, h, sites)
    push!(eigenvalues, [h, energy])
    push!(eigenvectors, [h, psi])
end
writedlm("$(output_dir)/energies.dat", eigenvalues)

# Contract every ground-state MPS, site by site, into one dense ITensor that holds
# all amplitudes (its size grows exponentially with N). Each tensor is wrapped in a
# one-element vector because the amplitude loop reads it back as `wavefunction[1]`.
wavefunctions_as_tensors = Any[[foldl(*, psi; init=1)] for (_, psi) in eigenvectors]
@show wavefunctions_as_tensors


# Evaluate each dense wavefunction tensor at every computational-basis configuration
# and write the amplitudes (one per line) to a file per field value.
all_configs = generate_all_sequences(N, ["1","2"])
configs_as_ints = from_string_to_integers(N, all_configs)
n_configs, _ = size(configs_as_ints)

wavefunction_dir = "./outputs/$(N)_qubits/wavefunctions"
mkpath(wavefunction_dir)

for (idx, wavefunction) in enumerate(wavefunctions_as_tensors)
    tensor = wavefunction[1]
    wavefunction_as_vector_of_amplitudes = Any[]
    for config_idx in 1:n_configs
        # Address the element by (site index => value) pairs rather than by position:
        # this works for any N instead of a hand-written 8-argument index list, and it
        # does not depend on the index order the MPS contraction happened to produce.
        amplitude = tensor[(sites[k] => configs_as_ints[config_idx, k] for k in 1:N)...]
        push!(wavefunction_as_vector_of_amplitudes, [amplitude])
    end
    # eigenvectors[idx] is [h, psi]; "%.3f" always yields at least "0.000".
    formatted_h = @sprintf("%.3f", eigenvectors[idx][1])
    filename = "$(wavefunction_dir)/wavefunction_for_mu-$(formatted_h).dat"
    writedlm(filename, wavefunction_as_vector_of_amplitudes)
end
