#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jan 31 18:41:11 2023

@author: kailiu
"""
# Plot the error of the approximation in Pub_1: a joint density of analytical vs. approximated values, annotated with the mean absolute error (MAE).
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error
import pandas

def read_file(filename):
    """Read a two-column, whitespace-separated numeric text file.

    Every non-blank line must contain exactly two fields, each of which is
    parsed with ``float``. Blank lines (e.g. a trailing empty line) are skipped.

    Parameters
    ----------
    filename : str or path-like
        Path of the file to read.

    Returns
    -------
    tuple of (list of float, list of float)
        The values of the first column and of the second column, in file order.

    Raises
    ------
    ValueError
        If a non-blank line does not have exactly two fields, or a field is
        not a valid float.
    """
    list_x = []
    list_y = []
    # The ``with`` statement closes the file on exit; no explicit close() needed.
    with open(filename) as f:
        for lineno, line in enumerate(f, start=1):
            fields = line.split()
            if not fields:
                # Tolerate blank lines instead of failing on tuple unpacking.
                continue
            if len(fields) != 2:
                raise ValueError(
                    f"{filename}:{lineno}: expected 2 columns, got {len(fields)}"
                )
            x, y = fields
            list_x.append(float(x))
            list_y.append(float(y))
    return list_x, list_y

# ---------------------------------------------------------------------------
# Compare the analytical values with the approximation for one metal: draw a
# joint KDE with marginal distributions, annotate the mean absolute error
# (MAE) and save the figure as SVG.
# ---------------------------------------------------------------------------
Metal = 'Li'
# Input "<Metal>.txt" and output "<Metal>.svg" both live in this directory.
# Deriving both paths from ``Metal`` keeps the label, data and figure in sync.
data_dir = '/home/kailiu/Documents/MATLAB/main/AccurancyOfSimplifiedEquation'

AcLirate, Prediction = read_file(f'{data_dir}/{Metal}.txt')
mae_all = mean_absolute_error(AcLirate, Prediction)

Listnew = pandas.DataFrame({'Analytical': AcLirate, 'Approximation': Prediction})

# Shared axis range and tick spacing for both axes.
low = -1.5
high = 1.5
lims = [low, high]
intv = 0.5
# linspace with an explicit count avoids the float-step rounding of np.arange,
# which can add or drop the end tick for steps such as 0.1.
ticks = np.linspace(low, high, int(round((high - low) / intv)) + 1)

g = sns.JointGrid(data=Listnew, x="Analytical", y="Approximation", space=0, height=3)
# Current seaborn spellings of the deprecated shade / shade_lowest / n_levels:
# fill=True shades the contours, thresh=0.05 (the default that applied with
# shade_lowest=False) leaves the lowest-density region unshaded, and
# levels=20 sets the number of contour levels.
g.plot_joint(sns.kdeplot, color="C0", fill=True, thresh=0.05, levels=20)
# histplot(kde=True, stat="density") is the documented replacement for the
# deprecated distplot (density-normalised histogram plus a KDE curve).
g.plot_marginals(sns.histplot, color="C0", kde=True, stat="density")

g.ax_marg_x.set_xlim(lims[0], lims[1])
g.ax_marg_y.set_ylim(lims[0], lims[1])

# y = x reference line (Approximation == Analytical); drawn once, on top of the KDE.
g.ax_joint.plot(lims, lims, '--k', linewidth=0.5)

g.ax_joint.set_xticks(ticks)
g.ax_joint.set_yticks(ticks)

g.ax_joint.text(0.05, 0.95, f"{Metal} \nMAE={mae_all:.2f}",
                horizontalalignment='left', verticalalignment='top',
                transform=g.ax_joint.transAxes)

# Save BEFORE show(): once the interactive window is closed the figure is
# discarded and savefig would write an empty image.
plt.savefig(f'{data_dir}/{Metal}.svg', format='svg', dpi=600)
plt.show()