#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Nov 29 09:38:10 2023

Plotting heatwaves as single event
Also for all hotspot plots

@author: sutanto
"""


import numpy as np
import netCDF4
#from mpl_toolkits.basemap import Basemap
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from matplotlib import pyplot as plt
#from matplotlib import colors
from matplotlib.colors import BoundaryNorm
from matplotlib.ticker import FormatStrFormatter


#########################################################
# Heatwave summaries: read every model file, collapse it to one value per
# grid cell, build the five-model ensemble medians for the past, ssp126 and
# ssp585 runs, and finally the future-minus-past changes.

# Directory holding the hw_summary_<model>[_<scenario>].nc files.
HW_DIR = '/Users/Sutanto/Documents/postdoc_WSG/D3C2_2023/Data/Merve/Heatwaves/'


def _load_hw_summary(filename):
    """Read one heatwave summary file and collapse it to 2-D fields.

    The file is opened through a context manager so its handle is released
    as soon as the three variables have been copied into memory.

    Parameters
    ----------
    filename : str
        Name of the netCDF file inside ``HW_DIR``.

    Returns
    -------
    tuple
        ``(number, totdur, freque, avedur)``:

        * ``number`` -- NaN-ignoring sum of ``hweve`` over axis 3, then 2.
        * ``totdur`` -- NaN-ignoring sum of ``hwdur`` over axis 3, then 2.
        * ``freque`` -- NaN-ignoring mean of ``hwint`` over axis 3, then the
          NaN-ignoring mean of that result over axis 2.
        * ``avedur`` -- ``totdur / number``.
    """
    with netCDF4.Dataset(HW_DIR + filename, 'r') as dataset:
        hweve = dataset.variables['hweve'][:]
        hwdur = dataset.variables['hwdur'][:]
        hwint = dataset.variables['hwint'][:]

    number = np.nansum(np.nansum(hweve, 3), 2)
    totdur = np.nansum(np.nansum(hwdur, 3), 2)
    # Deliberately a mean of means (axis 3 first, then axis 2): this is not
    # the same as a single mean over both axes when NaN counts differ.
    # NOTE(review): the script stores the 'hwint' aggregate under the name
    # "freque" -- confirm which quantity this variable really holds.
    freque = np.nanmean(np.nanmean(hwint, 3), 2)
    # NOTE(review): where a model has no events this is 0/0. Whether that
    # ends up as NaN or as a masked value depends on the array type netCDF4
    # returns, and np.median below does not skip NaN -- confirm intended.
    avedur = totdur / number
    return number, totdur, freque, avedur


def _ensemble_median(members, shape):
    """Stack per-model 2-D fields and take the median across models.

    Parameters
    ----------
    members : sequence of array-like
        One 2-D field per model, each assignable to an array of ``shape``.
    shape : tuple of int
        ``(len(lon), len(lat))`` of the grid.

    Returns
    -------
    tuple
        ``(stack, median)`` where ``stack`` has shape
        ``shape + (len(members),)`` and ``median`` is
        ``np.median(stack, axis=2)``.
    """
    stack = np.zeros(shape + (len(members),))
    for k, member in enumerate(members):
        stack[:, :, k] = member
    return stack, np.median(stack, axis=2)


# Grid coordinates come from the historical gfdl file; every model field is
# stored into arrays of this (len(lon), len(lat)) shape below.
with netCDF4.Dataset(HW_DIR + 'hw_summary_gfdl.nc', 'r') as _grid_file:
    lon = _grid_file.variables['x'][:]
    lat = _grid_file.variables['y'][:]
grid_shape = (len(lon), len(lat))


#########################################################
# Heatwave data (past)
(pgfdl_number, pgfdl_totdur,
 pgfdl_freque, pgfdl_avedur) = _load_hw_summary('hw_summary_gfdl.nc')
(pipsl_number, pipsl_totdur,
 pipsl_freque, pipsl_avedur) = _load_hw_summary('hw_summary_ipsl.nc')
(pmpi_number, pmpi_totdur,
 pmpi_freque, pmpi_avedur) = _load_hw_summary('hw_summary_mpi.nc')
(pmri_number, pmri_totdur,
 pmri_freque, pmri_avedur) = _load_hw_summary('hw_summary_mri.nc')
(pukesm_number, pukesm_totdur,
 pukesm_freque, pukesm_avedur) = _load_hw_summary('hw_summary_ukesm.nc')

#### Ensemble median
ptotdur, ptotdurmed = _ensemble_median(
    [pgfdl_totdur, pipsl_totdur, pmpi_totdur, pmri_totdur, pukesm_totdur],
    grid_shape)
pavgdur, pavgdurmed = _ensemble_median(
    [pgfdl_avedur, pipsl_avedur, pmpi_avedur, pmri_avedur, pukesm_avedur],
    grid_shape)
pnumbeve, pnumbevemed = _ensemble_median(
    [pgfdl_number, pipsl_number, pmpi_number, pmri_number, pukesm_number],
    grid_shape)
pfreque, pfrequemed = _ensemble_median(
    [pgfdl_freque, pipsl_freque, pmpi_freque, pmri_freque, pukesm_freque],
    grid_shape)

# Filtering ocean again: cells whose historical gfdl total duration is
# exactly 0 are blanked in every ensemble median (past and future).
nanloc = np.where(pgfdl_totdur == 0)
for _field in (ptotdurmed, pavgdurmed, pnumbevemed, pfrequemed):
    _field[nanloc] = float("nan")


#########################################################
# Heatwave data (ssp126)
(gfdl_number_ssp126, gfdl_totdur_ssp126,
 gfdl_freque_ssp126, gfdl_avedur_ssp126) = _load_hw_summary('hw_summary_gfdl_2.6.nc')
(ipsl_number_ssp126, ipsl_totdur_ssp126,
 ipsl_freque_ssp126, ipsl_avedur_ssp126) = _load_hw_summary('hw_summary_ipsl_2.6.nc')
(mpi_number_ssp126, mpi_totdur_ssp126,
 mpi_freque_ssp126, mpi_avedur_ssp126) = _load_hw_summary('hw_summary_mpi_2.6.nc')
(mri_number_ssp126, mri_totdur_ssp126,
 mri_freque_ssp126, mri_avedur_ssp126) = _load_hw_summary('hw_summary_mri_2.6.nc')
(ukesm_number_ssp126, ukesm_totdur_ssp126,
 ukesm_freque_ssp126, ukesm_avedur_ssp126) = _load_hw_summary('hw_summary_ukesm_2.6.nc')

#### Ensemble median
f1totdur, f1totdurmed = _ensemble_median(
    [gfdl_totdur_ssp126, ipsl_totdur_ssp126, mpi_totdur_ssp126,
     mri_totdur_ssp126, ukesm_totdur_ssp126],
    grid_shape)
f1avgdur, f1avgdurmed = _ensemble_median(
    [gfdl_avedur_ssp126, ipsl_avedur_ssp126, mpi_avedur_ssp126,
     mri_avedur_ssp126, ukesm_avedur_ssp126],
    grid_shape)
f1numbeve, f1numbevemed = _ensemble_median(
    [gfdl_number_ssp126, ipsl_number_ssp126, mpi_number_ssp126,
     mri_number_ssp126, ukesm_number_ssp126],
    grid_shape)
f1freque, f1frequemed = _ensemble_median(
    [gfdl_freque_ssp126, ipsl_freque_ssp126, mpi_freque_ssp126,
     mri_freque_ssp126, ukesm_freque_ssp126],
    grid_shape)

# filtering ocean again (same cells as for the past period)
for _field in (f1totdurmed, f1avgdurmed, f1numbevemed, f1frequemed):
    _field[nanloc] = float("nan")


#########################################################
# Heatwave data (ssp585)
(gfdl_number_ssp585, gfdl_totdur_ssp585,
 gfdl_freque_ssp585, gfdl_avedur_ssp585) = _load_hw_summary('hw_summary_gfdl_8.5.nc')
(ipsl_number_ssp585, ipsl_totdur_ssp585,
 ipsl_freque_ssp585, ipsl_avedur_ssp585) = _load_hw_summary('hw_summary_ipsl_8.5.nc')
(mpi_number_ssp585, mpi_totdur_ssp585,
 mpi_freque_ssp585, mpi_avedur_ssp585) = _load_hw_summary('hw_summary_mpi_8.5.nc')
(mri_number_ssp585, mri_totdur_ssp585,
 mri_freque_ssp585, mri_avedur_ssp585) = _load_hw_summary('hw_summary_mri_8.5.nc')
(ukesm_number_ssp585, ukesm_totdur_ssp585,
 ukesm_freque_ssp585, ukesm_avedur_ssp585) = _load_hw_summary('hw_summary_ukesm_8.5.nc')

#### Ensemble median
f2totdur, f2totdurmed = _ensemble_median(
    [gfdl_totdur_ssp585, ipsl_totdur_ssp585, mpi_totdur_ssp585,
     mri_totdur_ssp585, ukesm_totdur_ssp585],
    grid_shape)
f2avgdur, f2avgdurmed = _ensemble_median(
    [gfdl_avedur_ssp585, ipsl_avedur_ssp585, mpi_avedur_ssp585,
     mri_avedur_ssp585, ukesm_avedur_ssp585],
    grid_shape)
f2numbeve, f2numbevemed = _ensemble_median(
    [gfdl_number_ssp585, ipsl_number_ssp585, mpi_number_ssp585,
     mri_number_ssp585, ukesm_number_ssp585],
    grid_shape)
f2freque, f2frequemed = _ensemble_median(
    [gfdl_freque_ssp585, ipsl_freque_ssp585, mpi_freque_ssp585,
     mri_freque_ssp585, ukesm_freque_ssp585],
    grid_shape)

# filtering ocean again (same cells as for the past period)
for _field in (f2totdurmed, f2avgdurmed, f2numbevemed, f2frequemed):
    _field[nanloc] = float("nan")


#########################################################
# Future changes (ssp126): ensemble median of the future run minus the
# ensemble median of the past run.
Diff_totdur1 = f1totdurmed - ptotdurmed
Diff_avgdur1 = f1avgdurmed - pavgdurmed
Diff_numbeve1 = f1numbevemed - pnumbevemed
Diff_freque1 = f1frequemed - pfrequemed

# Future changes (ssp585)
Diff_totdur2 = f2totdurmed - ptotdurmed
Diff_avgdur2 = f2avgdurmed - pavgdurmed
Diff_numbeve2 = f2numbevemed - pnumbevemed
Diff_freque2 = f2frequemed - pfrequemed




#########################################################
# Plotting


def plot_heatwave_map(field, lon, lat, boundaries, cbar_label,
                      cmap='coolwarm', title='',
                      xlim=(-23, 45), ylim=(35, 75)):
    """Draw one 2-D field on a PlateCarree map with a discrete colour bar.

    Parameters
    ----------
    field : 2-D array
        Values to plot, shaped ``(len(lon), len(lat))``.
    lon, lat : 1-D arrays
        Grid coordinates; expanded to 2-D with ``np.meshgrid``.
    boundaries : sequence of float
        Colour-class edges, also used as the colour-bar tick positions.
    cbar_label : str
        Label written next to the colour bar.
    cmap : str, optional
        Matplotlib colormap name.
    title : str, optional
        Axes title (empty by default).
    xlim, ylim : tuple of float, optional
        Axis limits applied after the layout has been tightened.

    Returns
    -------
    tuple
        ``(fig, ax)`` of the created figure and GeoAxes.
    """
    # lon and lat are 1-D, so expand them to 2-D coordinate arrays; with
    # this argument order both have shape (len(lon), len(lat)).
    late, lone = np.meshgrid(lat, lon)

    # Create a figure and a GeoAxes with a PlateCarree projection
    fig, ax = plt.subplots(figsize=(10, 6),
                           subplot_kw={'projection': ccrs.PlateCarree()})

    # Add features (coastlines, borders, labelled grid lines)
    ax.coastlines(resolution='50m')   # 110m, 50m, 10m
    ax.add_feature(cfeature.BORDERS, linestyle=':')
    gl = ax.gridlines(draw_labels=True)
    gl.top_labels = False
    gl.right_labels = False
    gl.xlabel_style = {'size': 12}
    gl.ylabel_style = {'size': 12}

    # Map the values onto discrete colour classes and draw them
    norm = BoundaryNorm(boundaries, ncolors=256)
    mesh = ax.pcolormesh(lone, late, field, cmap=cmap, norm=norm,
                         transform=ccrs.PlateCarree())

    # Colour bar with one integer-formatted tick per class edge
    cbar = fig.colorbar(mesh, ax=ax, shrink=0.4, pad=0.05)
    cbar.set_label(cbar_label, size=12)
    cbar.set_ticks(boundaries)
    cbar.ax.yaxis.set_major_formatter(FormatStrFormatter('%.0f'))

    ax.set_title(title, fontsize=10)
    fig.tight_layout()

    # Limit the plot to the region of interest (same order as before:
    # layout first, limits second).
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)
    return fig, ax


# Seven evenly spaced class edges: 0, 100, ..., 600
boundaries = np.linspace(0, 600, num=7)

fig, ax = plot_heatwave_map(f2numbevemed, lon, lat, boundaries,
                            "Number of event")

#plt.savefig('/Users/Sutanto/Documents/postdoc_WSG/D3C2_2023/Data/Results/Changes_Average_Duration_Median_ssp126_Heatwaves.jpg', dpi=300)
plt.savefig('/Users/Sutanto/Documents/postdoc_WSG/D3C2_2023/Data/Results/Number_Median_ssp585_Heatwaves_HS2.jpg', dpi=300)
plt.show()

