#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jan  9 15:49:09 2024

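Plot ML heatwave-impact predictions for the historical period (1953-2014) and two future scenarios (SSP1-2.6 and SSP5-8.5, 2039-2100), compared with reported impacts.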

@author: sutanto
"""

import numpy as np
from matplotlib import pyplot as plt # import libraries

#################################################
# Models past=1953-2014

# Historical ML prediction (ERA5-driven), one row per month.
# Columns used: 0 = year, 1 = month, 2 = predicted impact
# (presumably a 0/1 flag per month -- TODO confirm against the CSV).
result = np.loadtxt(
    '/Users/Sutanto/Documents/Postdoc_WSG/D3C2_2023/ML/Heatwave/Historical/ERA5_monthly_hw_history.csv',
    skiprows=1, delimiter=',')
year = result[:,0]
month = result[:,1]
model_impact = result[:,2]

# TODO(review): removing single-month impacts was considered here but is
# not implemented; every nonzero month is kept.

# Number of impact events = number of runs of consecutive nonzero rows.
# A row starts an event when it is nonzero and the previous row is zero
# (or it is the first row). Assumes rows are consecutive months in time order.
# (The former "prepend index 0, count gaps != 1" trick missed an event that
# starts exactly at row 1.)
is_impact = model_impact != 0
number_dr_impact = (int(is_impact[:1].sum())
                    + int(np.sum(is_impact[1:] & ~is_impact[:-1])))

# Total impact over the period (total duration if values are 0/1 flags).
totdur_dr_impact = np.sum(model_impact)
# Mean event duration; NaN when no event occurred (avoids a 0/0 division).
avgdur_dr_impact = (totdur_dr_impact/number_dr_impact
                    if number_dr_impact > 0 else np.nan)
# Impact per year over the reference period 1953-2014 (62 years).
n_years_his = 2014 - 1953 + 1
freq_dr_impact = totdur_dr_impact/n_years_his

# Total predicted impact for each year of the reference period 1953-2014:
# sum of the monthly values whose year column matches.
yearr = np.arange(1953, 2015, 1)
dr_impact = np.array([np.sum(model_impact[year == yr]) for yr in yearr])

# Yearly totals restricted to the years in which some impact was predicted.
dur_his = dr_impact[dr_impact != 0]

#################################################
# Models Future=2039-2100
# Scenario selector for the future ML predictions (one CSV per GCM).
times = 'Future_26'   # Future_26, Future_85
rcp = '26'    # 26, 85

# Column `col` of fmodel_impact holds the impact series (CSV column 2) of the
# col-th model; year/month are taken from the first model (gfdl). Every file
# must have the same number of rows as the first one.
for col, model in enumerate(['gfdl', 'ipsl', 'mpi', 'mri', 'ukesm']):
    resultf = np.loadtxt('/Users/Sutanto/Documents/Postdoc_WSG/D3C2_2023/ML/Heatwave/'+times+'/'+model+'_monthly_hw_'+rcp+'.csv',
                         skiprows=1, delimiter=',')
    if col == 0:
        fyear = resultf[:,0]
        fmonth = resultf[:,1]
        fmodel_impact = np.zeros((len(resultf[:,0]), 5))
    fmodel_impact[:,col] = resultf[:,2]


# Number of impact events per model (SSP1-2.6): runs of consecutive nonzero
# rows, counted column by column. A row starts an event when it is nonzero
# and the previous row is zero (or it is the first row). Assumes rows are
# consecutive months in time order.
# (The former "prepend index 0, count gaps != 1" trick missed an event that
# starts exactly at row 1.)
is_impact_f26 = fmodel_impact != 0
number_dr_impact_f26 = (is_impact_f26[:1, :].sum(axis=0)
                        + np.sum(is_impact_f26[1:, :] & ~is_impact_f26[:-1, :],
                                 axis=0)).astype(float)

# Total impact per model over the whole period.
totdur_dr_impact_f26 = np.sum(fmodel_impact, axis=0)

# Ensemble medians across the models. NOTE: the average duration is the
# ratio of the medians, not the median of the per-model averages.
mednumber_dr_impact_f26 = np.median(number_dr_impact_f26)
medtotdur_dr_impact_f26 = np.median(totdur_dr_impact_f26)
medavgdur_dr_impact_f26 = (medtotdur_dr_impact_f26/mednumber_dr_impact_f26
                           if mednumber_dr_impact_f26 > 0 else np.nan)
# Impact per year over the future period 2039-2100 (62 years).
n_years_fut = 2100 - 2039 + 1
medfreq_dr_impact_f26 = medtotdur_dr_impact_f26/n_years_fut


# Total predicted impact per year (rows) and per model (columns), SSP1-2.6.
# The period is 2039-2100 inclusive (62 years, as stated in the section
# header); np.arange excludes its stop value, so the stop must be 2101 --
# a stop of 2100 silently dropped the final year.
fyearr = np.arange(2039, 2101, 1)
fdr_impact = np.zeros((len(fyearr), fmodel_impact.shape[1]))
for j, fy in enumerate(fyearr):
    fdr_impact[j, :] = np.sum(fmodel_impact[fyear == fy, :], axis=0)

# Ensemble median of the yearly totals across the models.
fmedian_impact = np.median(fdr_impact, axis=1)

# Yearly totals pooled over all models (row-major order), keeping only
# year/model pairs with some impact.
dur_his26 = fdr_impact[fdr_impact != 0]

# Models Future=2039-2100
# Scenario selector for the future ML predictions (one CSV per GCM).
times = 'Future_85'   # Future_26, Future_85
rcp = '85'    # 26, 85

# Column `col` of fmodel_impact holds the impact series (CSV column 2) of the
# col-th model; year/month come from the first model (gfdl). All files must
# have the same number of rows as the first one.
gcm_names = ('gfdl', 'ipsl', 'mpi', 'mri', 'ukesm')
for col, model in enumerate(gcm_names):
    resultf = np.loadtxt('/Users/Sutanto/Documents/Postdoc_WSG/D3C2_2023/ML/Heatwave/'+times+'/'+model+'_monthly_hw_'+rcp+'.csv',
                         skiprows=1, delimiter=',')
    if col == 0:
        fyear = resultf[:,0]
        fmonth = resultf[:,1]
        fmodel_impact = np.zeros((len(resultf[:,0]), 5))
    fmodel_impact[:,col] = resultf[:,2]

# Number of impact events per model (SSP5-8.5): runs of consecutive nonzero
# rows, counted column by column. A row starts an event when it is nonzero
# and the previous row is zero (or it is the first row). Assumes rows are
# consecutive months in time order.
# (The former "prepend index 0, count gaps != 1" trick missed an event that
# starts exactly at row 1.)
is_impact_f85 = fmodel_impact != 0
number_dr_impact_f85 = (is_impact_f85[:1, :].sum(axis=0)
                        + np.sum(is_impact_f85[1:, :] & ~is_impact_f85[:-1, :],
                                 axis=0)).astype(float)

# Total impact per model over the whole period.
totdur_dr_impact_f85 = np.sum(fmodel_impact, axis=0)

# Ensemble medians across the models. NOTE: the average duration is the
# ratio of the medians, not the median of the per-model averages.
mednumber_dr_impact_f85 = np.median(number_dr_impact_f85)
medtotdur_dr_impact_f85 = np.median(totdur_dr_impact_f85)
medavgdur_dr_impact_f85 = (medtotdur_dr_impact_f85/mednumber_dr_impact_f85
                           if mednumber_dr_impact_f85 > 0 else np.nan)
# Impact per year over the future period 2039-2100 (62 years).
n_years_fut = 2100 - 2039 + 1
medfreq_dr_impact_f85 = medtotdur_dr_impact_f85/n_years_fut



# Total predicted impact per year (rows) and per model (columns), SSP5-8.5.
# The period is 2039-2100 inclusive (62 years, as stated in the section
# header); np.arange excludes its stop value, so the stop must be 2101 --
# a stop of 2100 silently dropped the final year.
fyearr = np.arange(2039, 2101, 1)
fdr_impact85 = np.zeros((len(fyearr), fmodel_impact.shape[1]))
for j, fy in enumerate(fyearr):
    fdr_impact85[j, :] = np.sum(fmodel_impact[fyear == fy, :], axis=0)

# Ensemble median of the yearly totals across the models.
fmedian_impact85 = np.median(fdr_impact85, axis=1)

# Yearly totals pooled over all models (row-major order), keeping only
# year/model pairs with some impact.
dur_his85 = fdr_impact85[fdr_impact85 != 0]

#################################################
# Box plot for past and future events.
# Three 20-row windows of the SSP1-2.6 yearly totals with all five models
# pooled (20 years x 5 models = 100 values per window).
# NOTE(review): with fyearr starting at 2039 the windows 0:20, 20:40, 40:60
# cover 2039-2058, 2059-2078 and 2079-2098, which does not match `labels`.
fdr_impact_bp = np.zeros((100, 3))
for col, start in enumerate((0, 20, 40)):
    fdr_impact_bp[:, col] = fdr_impact[start:start + 20, :].flatten()

labels = ['2039-2059','2060-2080','2081-2100']

# Box plot past
labels2 = ['1953-2014']

# Inputs for the combined Reference / SSP1-2.6 / SSP5-8.5 box plot: each
# array fills a single one of the three columns, the other two stay NaN.
fdr_impact_bp2x = np.full((len(dur_his), 3), np.nan)
fdr_impact_bp2x[:, 0] = dur_his

fdr_impact_bp2 = np.full((len(dur_his26), 3), np.nan)
fdr_impact_bp2[:, 1] = dur_his26

fdr_impact_bp2xx = np.full((len(dur_his85), 3), np.nan)
fdr_impact_bp2xx[:, 2] = dur_his85

labels3 = ['Reference','SSP1-2.6','SSP5-8.5']



#################################################
# Reported heatwave impacts for Germany.
# Columns used: 0 = year, 2 = impact value.
resulti = np.loadtxt(
    '/Users/Sutanto/Documents/Postdoc_WSG/D3C2_2023/ML/Heatwave/Heatwave_Germany_impact.csv',
    skiprows=1, delimiter=',')
yeari = resulti[:,0]
impact = resulti[:,2]

# Total reported impact per year over the reference period 1953-2014
# (0 for years without any record).
iyearr = np.arange(1953, 2015, 1)
hw_impact = np.array([np.sum(impact[yeari == yr]) for yr in iyearr])


    
"""
Plotting
"""
### Plotting historical
fig, ax = plt.subplots(figsize=(10,5))
bar_width = 0.4
opacity = 1

plt.bar(yearr+(-0.5*bar_width), dr_impact, bar_width, alpha=opacity, color='blue', label='ML model')
plt.bar(yearr+(0.5*bar_width), hw_impact, bar_width, alpha=opacity, color='red', label='Impact')

plt.title('Impact prediction for future heatwave')
plt.xlabel('Year',fontsize=12)
#plt.xticks(yearr+bar_width, fontsize=12)
plt.yticks(np.arange(0, 4, 1))
#plt.yticks(fontsize=12)
plt.ylabel('Total impact',fontsize=12)
plt.legend(fontsize=10, loc='upper left')
plt.ylim(0,4)
plt.xlim(1950,2020)

#plt.savefig('/Users/Sutanto/Documents/postdoc_WSG/D3C2_2023/ML/Figure/Heatwave_Impact_His.jpg', dpi=300)



### Plotting historical and future: distribution of the yearly impact totals
# (years with impact only) for the reference period and both scenarios.
fig2, ax2 = plt.subplots(figsize=(7, 5), dpi=300)

# A single boxplot call on the three ragged samples, with colours set per box
# afterwards. This replaces three overlaid calls on NaN-padded arrays and
# avoids the `labels=` keyword (renamed `tick_labels` in Matplotlib 3.9).
bp = ax2.boxplot([dur_his, dur_his26, dur_his85], patch_artist=True,
                 medianprops=dict(color="darkblue", linewidth=4))
for box, colour in zip(bp['boxes'], ['yellowgreen', 'yellow', 'orangered']):
    box.set_facecolor(colour)
ax2.set_xticks([1, 2, 3])
ax2.set_xticklabels(labels3)

ax2.set_xlabel('Scenario')
ax2.set_ylabel('Number of Impact')

# 'Postdoc_WSG' matches the capitalisation of every input path; the former
# 'postdoc_WSG' only resolved on a case-insensitive filesystem.
fig2.savefig('/Users/Sutanto/Documents/Postdoc_WSG/D3C2_2023/ML/Figure/Heatwave_Impact_combined.jpg', dpi=300)

plt.show()
    