#!/usr/bin/python

import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import matplotlib as mpl

# Directory the figure_* functions write their PNG files into. Paths are
# built by plain string concatenation, so keep the trailing slash.
save_direc = './'
# Directory the figure_* functions read their .npz data archives from
# (same trailing-slash requirement).
data_direc = './'

# Apply the shared paper style (path relative to the working directory)
# BEFORE reading rcParams below, so the colour cycle reflects this style.
plt.style.use('./paper.mplstyle')

# Colours of the active property cycle; the figure functions index into
# this list to give related curves the same colour explicitly.
prop_cycle = plt.rcParams['axes.prop_cycle']
colors = prop_cycle.by_key()['color']
# Line styles indexed by position (figure_dsi_lim_delta uses one per mu).
linestyles = ['solid', 'dotted', 'dashed', 'dashdot']

def figure_si_monoresonance():
    '''Growth rate and logarithmic mu dependence mSI.

    Reads ``fig_si_monoresonance.npz`` from ``data_direc`` (keys ``kx``,
    ``kz``, ``mu``, ``growth_h``, ``growth_l``) and writes
    ``fig_SI_monoresonance.png`` to ``save_direc``.

    Top panel: log10 of Im(growth_h) on the (K_x, K_z) grid.
    Bottom panel: log(Im(growth_h) / Im(growth_l)) / log(mu[1] / mu[0]),
    i.e. the power-law exponent of the growth rate in mu -- assuming
    growth_h / growth_l were computed at mu[1] / mu[0]; confirm against
    the data generator.
    '''
    # The context manager closes the archive. Each member is read exactly
    # once, because every NpzFile lookup re-reads it from the archive.
    with np.load(data_direc + 'fig_si_monoresonance.npz') as npzfile:
        kx = npzfile['kx']
        kz = npzfile['kz']
        mu = npzfile['mu']
        growth_h = np.imag(npzfile['growth_h'])
        growth_l = np.imag(npzfile['growth_l'])

    # 1.0e-10 keeps the ratio finite where growth_l is exactly zero.
    res = np.log(growth_h/(growth_l+1.0e-10))/np.log(mu[1]/mu[0])

    fig, axs = plt.subplots(nrows=2, ncols=1, figsize=(4.5,7), layout='constrained')

    # Both panels use identical log-log wavenumber axes spanning the grid.
    for ax in axs:
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_xlabel(r'$K_x$')
        ax.set_ylabel(r'$K_z$')
        ax.set_xlim([kx[0], kx[-1]])
        ax.set_ylim([kz[0], kz[-1]])

    # Growth rate on a log10 colour scale. The data are transposed so rows
    # follow kz and columns follow kx, as pcolormesh(kx, kz, C) expects.
    log_ticks = list(range(-7, 1))
    pcm = axs[0].pcolormesh(kx, kz, np.transpose(np.log10(growth_h)), vmin=-7, vmax=0,
                            cmap='RdBu_r', shading='gouraud', rasterized=True)
    cbar = fig.colorbar(pcm, ax=axs[0], ticks=log_ticks)
    cbar.ax.set_yticklabels([rf'$10^{{{t}}}$' for t in log_ticks])

    # Exponent map on a linear colour scale.
    pcm = axs[1].pcolormesh(kx, kz, np.transpose(res), vmin=0, vmax=1.01,
                            cmap='RdBu_r', shading='gouraud', rasterized=True)
    fig.colorbar(pcm, ax=axs[1])

    fig.savefig(save_direc + 'fig_SI_monoresonance.png')
    plt.close(fig)

def figure_si_resonance():
    '''Growth rate and logarithmic mu dependence (SI resonance figure).

    Reads ``fig_si_resonance.npz`` from ``data_direc`` (keys ``kx``,
    ``kz``, ``mu``, ``growth_h``, ``growth_l``) and writes
    ``fig_SI_resonance.png`` to ``save_direc``.

    Top panel: log10 of Im(growth_h) on the (K_x, K_z) grid.
    Bottom panel: log(Im(growth_h) / Im(growth_l)) / log(mu[1] / mu[0]),
    i.e. the power-law exponent of the growth rate in mu -- assuming
    growth_h / growth_l were computed at mu[1] / mu[0]; confirm against
    the data generator.
    '''
    # The context manager closes the archive. Each member is read exactly
    # once, because every NpzFile lookup re-reads it from the archive.
    with np.load(data_direc + 'fig_si_resonance.npz') as npzfile:
        kx = npzfile['kx']
        kz = npzfile['kz']
        mu = npzfile['mu']
        growth_h = np.imag(npzfile['growth_h'])
        growth_l = np.imag(npzfile['growth_l'])

    # 1.0e-10 keeps the ratio finite where growth_l is exactly zero.
    res = np.log(growth_h/(growth_l+1.0e-10))/np.log(mu[1]/mu[0])

    fig, axs = plt.subplots(nrows=2, ncols=1, figsize=(4.5,7), layout='constrained')

    # Both panels use identical log-log wavenumber axes spanning the grid.
    for ax in axs:
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_xlabel(r'$K_x$')
        ax.set_ylabel(r'$K_z$')
        ax.set_xlim([kx[0], kx[-1]])
        ax.set_ylim([kz[0], kz[-1]])

    # Growth rate on a log10 colour scale. The data are transposed so rows
    # follow kz and columns follow kx, as pcolormesh(kx, kz, C) expects.
    log_ticks = list(range(-7, 1))
    pcm = axs[0].pcolormesh(kx, kz, np.transpose(np.log10(growth_h)), vmin=-7, vmax=0,
                            cmap='RdBu_r', shading='gouraud', rasterized=True)
    cbar = fig.colorbar(pcm, ax=axs[0], ticks=log_ticks)
    cbar.ax.set_yticklabels([rf'$10^{{{t}}}$' for t in log_ticks])

    # Exponent map on a linear colour scale.
    pcm = axs[1].pcolormesh(kx, kz, np.transpose(res), vmin=0, vmax=1.01,
                            cmap='RdBu_r', shading='gouraud', rasterized=True)
    fig.colorbar(pcm, ax=axs[1])

    fig.savefig(save_direc + 'fig_SI_resonance.png')
    plt.close(fig)

def figure_si_cont():
    '''Plot growth-rate contributions against K_x for Delta = 0 and 0.3.

    Reads ``fig_si_cont.npz`` from ``data_direc`` and writes
    ``fig_SI_contribution.png`` to ``save_direc``. Solid lines come from
    the ``growth_mono*`` arrays (legend: Delta=0) and dashed lines from the
    ``growth_poly*`` arrays (legend: Delta=0.3). Within each set the curves
    are total, density and drift. The drift (``*_dv``) curves are plotted
    with their sign flipped.
    '''
    # The K_x grid is not stored in the archive. It is assumed to match the
    # grid the data were generated on (100 points) -- confirm.
    kx = np.linspace(1.5, 4.5, 100)

    # The context manager closes the archive once the arrays are read.
    with np.load(data_direc + 'fig_si_cont.npz') as npzfile:
        mono = [np.imag(npzfile['growth_mono']),
                np.imag(npzfile['growth_mono_dens']),
                -np.imag(npzfile['growth_mono_dv'])]
        poly = [np.imag(npzfile['growth_poly']),
                np.imag(npzfile['growth_poly_dens']),
                -np.imag(npzfile['growth_poly_dv'])]

    # Explicit figure/axes, so nothing is drawn onto a stale open figure.
    fig, ax = plt.subplots()
    ax.set_yscale('log')
    ax.set_xlabel(r'$K_x$')
    ax.set_ylabel(r'$\Im(\omega)$')

    # Delta = 0: colours taken from the default property cycle.
    mono_lines = [ax.plot(kx, y)[0] for y in mono]
    # Delta = 0.3: dashed, coloured to match the corresponding Delta = 0 curve.
    poly_lines = [ax.plot(kx, y, color=colors[idx], linestyle='--')[0]
                  for idx, y in enumerate(poly)]

    # Two legends: curve meaning (by colour) and Delta (by line style).
    leg1 = ax.legend(mono_lines, ['total', 'density', 'drift'], loc='upper right')
    ax.add_artist(leg1)
    ax.legend([mono_lines[0], poly_lines[0]], [r'$\Delta=0$', r'$\Delta=0.3$'], loc='lower right')

    fig.savefig(save_direc + 'fig_SI_contribution.png')
    plt.close(fig)

def figure_si_delta():
    '''Plot Im(omega) against Delta for mu = 1e-4, 1e-3, 1e-2, 1e-1.

    Reads ``fig_si_delta.npz`` (keys ``growth_4`` ... ``growth_1``) from
    ``data_direc`` and writes ``fig_SI_delta.png`` to ``save_direc``.
    '''
    # The Delta grid is not stored in the archive. It is assumed to match
    # the grid the data were generated on (100 points) -- confirm.
    delta_arr = np.logspace(-3, np.log10(0.5), 100)
    # Key growth_<p> is plotted with the label mu = 10^-p, in this order.
    exponents = (4, 3, 2, 1)

    # The context manager closes the archive once the arrays are read.
    with np.load(data_direc + 'fig_si_delta.npz') as npzfile:
        growth = [np.imag(npzfile[f'growth_{p}']) for p in exponents]

    # Explicit figure/axes, so nothing is drawn onto a stale open figure.
    fig, ax = plt.subplots()
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlabel(r'$\Delta$')
    ax.set_ylabel(r'$\Im(\omega)$')

    for p, values in zip(exponents, growth):
        ax.plot(delta_arr, values, label=rf'$\mu=10^{{-{p}}}$')

    ax.legend()

    fig.savefig(save_direc + 'fig_SI_delta.png')
    plt.close(fig)

def figure_dsi_monoresonance():
    '''Growth rate and logarithmic mu dependence (DSI, monodisperse).

    Reads ``fig_dsi_monoresonance.npz`` from ``data_direc`` (keys ``kx``,
    ``kz``, ``mu``, ``growth_h``, ``growth_l``) and writes
    ``fig_dsi_monoresonance.png`` to ``save_direc``.

    Top panel: log10 of Im(growth_h) on the (K_x, K_z) grid.
    Bottom panel: log(Im(growth_h) / Im(growth_l)) / log(mu[1] / mu[0]),
    i.e. the power-law exponent of the growth rate in mu -- assuming
    growth_h / growth_l were computed at mu[1] / mu[0]; confirm against
    the data generator.
    '''
    # The context manager closes the archive. Each member is read exactly
    # once, because every NpzFile lookup re-reads it from the archive.
    with np.load(data_direc + 'fig_dsi_monoresonance.npz') as npzfile:
        kx = npzfile['kx']
        kz = npzfile['kz']
        mu = npzfile['mu']
        growth_h = np.imag(npzfile['growth_h'])
        growth_l = np.imag(npzfile['growth_l'])

    # 1.0e-10 keeps the ratio finite where growth_l is exactly zero.
    res = np.log(growth_h/(growth_l+1.0e-10))/np.log(mu[1]/mu[0])

    fig, axs = plt.subplots(nrows=2, ncols=1, figsize=(4.5,7), layout='constrained')

    # Both panels use identical log-log wavenumber axes spanning the grid.
    for ax in axs:
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_xlabel(r'$K_x$')
        ax.set_ylabel(r'$K_z$')
        ax.set_xlim([kx[0], kx[-1]])
        ax.set_ylim([kz[0], kz[-1]])

    # Growth rate on a log10 colour scale. The data are transposed so rows
    # follow kz and columns follow kx, as pcolormesh(kx, kz, C) expects.
    log_ticks = list(range(-7, 1))
    pcm = axs[0].pcolormesh(kx, kz, np.transpose(np.log10(growth_h)), vmin=-7, vmax=0,
                            cmap='RdBu_r', shading='gouraud', rasterized=True)
    cbar = fig.colorbar(pcm, ax=axs[0], ticks=log_ticks)
    cbar.ax.set_yticklabels([rf'$10^{{{t}}}$' for t in log_ticks])

    # Exponent map on a linear colour scale.
    pcm = axs[1].pcolormesh(kx, kz, np.transpose(res), vmin=0, vmax=1.01,
                            cmap='RdBu_r', shading='gouraud', rasterized=True)
    fig.colorbar(pcm, ax=axs[1])

    fig.savefig(save_direc + 'fig_dsi_monoresonance.png')
    plt.close(fig)

def figure_dsi_contour_growth():
    '''Two stacked maps of log10 Im(omega) over (K_x, K_z).

    Reads ``fig_dsi_contour_growth.npz`` from ``data_direc``. The top panel
    shows ``growth_m`` and the bottom panel ``growth_p``, both on the same
    log10 colour scale [-7, 0]. Writes ``fig_dsi_contour_growth.png`` to
    ``save_direc``.
    '''
    # The context manager closes the archive. Each member is read exactly
    # once, because every NpzFile lookup re-reads it from the archive.
    with np.load(data_direc + 'fig_dsi_contour_growth.npz') as npzfile:
        kx = npzfile['kx']
        kz = npzfile['kz']
        panels = [np.imag(npzfile['growth_m']), np.imag(npzfile['growth_p'])]

    fig, axs = plt.subplots(nrows=2, ncols=1, figsize=(4.5,7), layout='constrained')

    log_ticks = list(range(-7, 1))
    for ax, growth in zip(axs, panels):
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_xlabel(r'$K_x$')
        ax.set_ylabel(r'$K_z$')
        ax.set_xlim([kx[0], kx[-1]])
        ax.set_ylim([kz[0], kz[-1]])

        # Transposed so rows follow kz and columns follow kx, as
        # pcolormesh(kx, kz, C) expects.
        pcm = ax.pcolormesh(kx, kz, np.transpose(np.log10(growth)), vmin=-7, vmax=0,
                            cmap='RdBu_r', shading='gouraud', rasterized=True)
        cbar = fig.colorbar(pcm, ax=ax, ticks=log_ticks)
        cbar.ax.set_yticklabels([rf'$10^{{{t}}}$' for t in log_ticks])

    fig.savefig(save_direc + 'fig_dsi_contour_growth.png')
    plt.close(fig)

def figure_dsi_poly_kx():
    '''Plot Im(omega) against K_x for Delta = 0, 0.1, 0.5 at two mu values.

    Reads ``fig_dsi_poly_kx.npz`` from ``data_direc``. For each Delta tag
    ``d`` the ``growth_<d>_l`` curve is solid and ``growth_<d>_h`` dotted,
    in the same colour. The legends label solid/dotted as mu=1e-4/mu=1e-3.
    Writes ``fig_dsi_poly_kx.png`` to ``save_direc``.
    '''
    # (npz key tag, legend label) per Delta, in plotting and colour order.
    deltas = [('0', r'$\Delta=0$'), ('1', r'$\Delta=0.1$'), ('5', r'$\Delta=0.5$')]

    # The context manager closes the archive once the arrays are read.
    with np.load(data_direc + 'fig_dsi_poly_kx.npz') as npzfile:
        kx = npzfile['kx']
        growth = [(np.imag(npzfile['growth_' + tag + '_l']),
                   np.imag(npzfile['growth_' + tag + '_h']))
                  for tag, _ in deltas]

    # Explicit figure/axes, so nothing is drawn onto a stale open figure.
    fig, ax = plt.subplots()
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlabel(r'$K_x$')
    ax.set_ylabel(r'$\Im(\omega)$')

    solid_lines = []
    dotted_lines = []
    for idx, ((_, label), (low, high)) in enumerate(zip(deltas, growth)):
        line_l, = ax.plot(kx, low, color=colors[idx], label=label)
        line_h, = ax.plot(kx, high, color=colors[idx], linestyle='dotted')
        solid_lines.append(line_l)
        dotted_lines.append(line_h)

    # Line-style legend (mu), illustrated with the Delta = 0 pair.
    leg1 = ax.legend([solid_lines[0], dotted_lines[0]], [r'$\mu=10^{-4}$', r'$\mu=10^{-3}$'], loc='lower right')
    ax.add_artist(leg1)
    # Colour legend (Delta), one solid line per Delta.
    ax.legend(solid_lines, [label for _, label in deltas], loc='upper left')

    fig.savefig(save_direc + 'fig_dsi_poly_kx.png')
    plt.close(fig)

def figure_dsi_lim_delta():
    '''Plot the analytic delta_lim(K_x) curve for two Stokes numbers and two mu.

    For each (mu, St) pair this evaluates
        w0(K)        = 0.5 * (2 St mu K)^(1/3)
        delta_lim(K) = 1 / (1/3 + 4 K St^3 / (1 + St^2) / w0(K))
    on 100 log-spaced K_x from
        k_lim = (pi/2 - arctan(0.5 / sqrt(0.001)))^3 / (mu St)
    up to 1e4. Colour encodes St and line style encodes mu. Writes
    ``fig_dsi_lim_delta.png`` to ``save_direc``.
    '''
    stokes_list = [0.01, 0.1]
    mu_list = [1.0e-2, 1.0e-4]

    # Explicit figure/axes, so nothing is drawn onto a stale open figure.
    fig, ax = plt.subplots()

    plt_list = []
    for i, mu in enumerate(mu_list):
        for j, st0 in enumerate(stokes_list):
            # np.arctan rather than np.atan: the atan alias only exists in
            # NumPy >= 2.0. The 0.001 constant is kept from the original
            # derivation; its meaning is not documented here.
            k_lim = (0.5*np.pi - np.arctan(0.5/np.sqrt(0.001)))**3/mu/st0
            Kx = np.logspace(np.log10(k_lim), 4, 100)

            w0 = 0.5*(2*st0*mu*Kx)**(1/3)
            delta_lim = 1/(1/3 + 4*Kx*st0**3/(1+st0**2)/w0)

            line, = ax.plot(Kx, delta_lim, color=colors[j], linestyle=linestyles[i])
            plt_list.append(line)

    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlabel(r'$K_x$')
    ax.set_ylabel(r'$\delta_{\rm lim}$')

    # plt_list is ordered (mu, St): [mu0-St0, mu0-St1, mu1-St0, mu1-St1].
    # Colour legend (St) uses entries 0 and 1; line-style legend (mu) uses 0 and 2.
    leg1 = ax.legend([plt_list[0], plt_list[1]], [r'${\rm St}=0.01$', r'${\rm St}=0.1$'], loc='center left')
    ax.add_artist(leg1)
    ax.legend([plt_list[0], plt_list[2]], [r'$\mu=10^{-2}$', r'$\mu=10^{-4}$'], loc='lower left')

    fig.savefig(save_direc + 'fig_dsi_lim_delta.png')
    plt.close(fig)

def figure_dsi_double_growth():
    '''Plot Im(omega) against K_x for three Delta values at two mu values.

    Reads ``fig_dsi_double_growth.npz`` from ``data_direc``. Each Delta has
    its own K_x array (``kx_mono``, ``kx_poly_1``, ``kx_poly_7``). Its
    ``growth_<d>_l`` curve is solid and ``growth_<d>_h`` dashed, in the
    same colour; the legends label solid/dashed as mu=1e-3/mu=1e-2.
    Writes ``fig_dsi_double_growth.png`` to ``save_direc``.
    '''
    # (K_x key, growth key tag, legend label) per Delta, in colour order.
    series = [('kx_mono', '0', r'$\Delta=0$'),
              ('kx_poly_1', '1', r'$\Delta=0.01$'),
              ('kx_poly_7', '7', r'$\Delta=0.7$')]

    # The context manager closes the archive once the arrays are read.
    with np.load(data_direc + 'fig_dsi_double_growth.npz') as npzfile:
        curves = [(npzfile[kx_key],
                   np.imag(npzfile['growth_' + tag + '_l']),
                   np.imag(npzfile['growth_' + tag + '_h']))
                  for kx_key, tag, _ in series]

    # Explicit figure/axes, so nothing is drawn onto a stale open figure.
    fig, ax = plt.subplots()
    ax.set_yscale('log')
    ax.set_ylim([1.0e-4, 2.0])
    ax.set_xlabel(r'$K_x$')
    ax.set_ylabel(r'$\Im(\omega)$')

    solid_lines = []
    dashed_lines = []
    for idx, (kx, low, high) in enumerate(curves):
        line_l, = ax.plot(kx, low, color=colors[idx])
        line_h, = ax.plot(kx, high, color=colors[idx], linestyle='dashed')
        solid_lines.append(line_l)
        dashed_lines.append(line_h)

    # Colour legend (Delta), one solid line per Delta.
    leg1 = ax.legend(solid_lines, [label for _, _, label in series], loc='upper left')
    ax.add_artist(leg1)
    # Line-style legend (mu), illustrated with the Delta = 0 pair.
    ax.legend([solid_lines[0], dashed_lines[0]], [r'$\mu=10^{-3}$', r'$\mu=10^{-2}$'], loc='upper right')

    fig.savefig(save_direc + 'fig_dsi_double_growth.png')
    plt.close(fig)

def figure_dsi_visc():
    '''2 x 5 grid of log10 Im(omega) maps for several alpha values.

    Reads ``fig_dsi_visc.npz`` from ``data_direc``. Columns correspond to
    alpha = 0, 1e-10, 1e-8, 1e-6, 1e-5 (npz key suffixes 0, 10, 8, 6, 5).
    The top row shows ``growth_m_*`` (annotated Delta=0) and the bottom row
    ``growth_p_*`` (annotated Delta=0.99). All panels share one colour
    scale [-7, 0] and one colorbar. Writes ``fig_dsi_visc.png`` to
    ``save_direc``.
    '''
    # (npz key suffix, column title) for each alpha value, left to right.
    columns = [
        ('0', r'$\alpha=0$'),
        ('10', r'$\alpha=10^{-10}$'),
        ('8', r'$\alpha=10^{-8}$'),
        ('6', r'$\alpha=10^{-6}$'),
        ('5', r'$\alpha=10^{-5}$'),
    ]
    # (npz key prefix, annotation text) for each row, top to bottom.
    rows = [('m', r'$\Delta=0$'), ('p', r'$\Delta=0.99$')]

    # The context manager closes the archive once the arrays are read.
    # Maps are transposed so rows follow kz and columns follow kx, as
    # pcolormesh(kx, kz, C) expects.
    with np.load(data_direc + 'fig_dsi_visc.npz') as npzfile:
        kx = npzfile['kx']
        kz = npzfile['kz']
        log_growth = [
            [np.transpose(np.log10(np.imag(npzfile['growth_' + prefix + '_' + suffix])))
             for suffix, _ in columns]
            for prefix, _ in rows
        ]

    fig = plt.figure(figsize=(18,6))
    gs = fig.add_gridspec(2, 5, hspace=0, wspace=0)
    axs = gs.subplots(sharex='col', sharey='row')

    for ax in axs.flat:
        ax.set_xscale('log')
        ax.set_yscale('log')
        # Limits pulled 1% inside the grid edges, presumably so the edge
        # ticks of the touching panels do not collide.
        ax.set_xlim([kx[0]*1.01, kx[-1]*0.99])
        ax.set_ylim([kz[0]*1.01, kz[-1]*0.99])

    for col, (_, title) in enumerate(columns):
        axs[0, col].set_title(title)
        axs[-1, col].set_xlabel(r'$K_x$')
    for row, (_, annotation) in enumerate(rows):
        axs[row, 0].set_ylabel(r'$K_z$')
        # Placed at data coordinates (10, 100) of the rightmost panel.
        axs[row, -1].text(10, 100, annotation)

    for row, row_maps in enumerate(log_growth):
        for col, f in enumerate(row_maps):
            pcm = axs[row, col].pcolormesh(kx, kz, f, vmin=-7, vmax=0, cmap='RdBu_r',
                                           shading='gouraud', rasterized=True)

    # All meshes share vmin/vmax/cmap, so the last one stands in for all panels.
    log_ticks = list(range(-7, 1))
    cbar = fig.colorbar(pcm, ax=axs, ticks=log_ticks, use_gridspec=True)
    # Label ticks as powers of ten, like every other log10 growth colorbar here.
    cbar.ax.set_yticklabels([rf'$10^{{{t}}}$' for t in log_ticks])

    fig.savefig(save_direc + 'fig_dsi_visc.png')
    plt.close(fig)

def figure_dsi_convergence():
    '''Growth-rate map plus the dependence of Im(omega) on N.

    Reads ``fig_dsi_convergence.npz`` from ``data_direc``.
    Left panel: log10 Im(growth_poly) over (K_x, K_z).
    Right panel: Im(growth_1) and Im(growth_7) against ``N``, with legend
    labels Delta=0.1 and Delta=0.99.
    Writes ``fig_dsi_convergence.png`` to ``save_direc``.
    '''
    # The context manager closes the archive. Each member is read exactly
    # once, because every NpzFile lookup re-reads it from the archive.
    with np.load(data_direc + 'fig_dsi_convergence.npz') as npzfile:
        kx = npzfile['kx']
        kz = npzfile['kz']
        n_values = npzfile['N']
        growth_poly = np.imag(npzfile['growth_poly'])
        growth_1 = np.imag(npzfile['growth_1'])
        growth_7 = np.imag(npzfile['growth_7'])

    fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(8,0.8*4.5), layout='constrained')

    # Left: growth-rate map on a log10 colour scale.
    axs[0].set_xscale('log')
    axs[0].set_yscale('log')
    axs[0].set_xlabel(r'$K_x$')
    axs[0].set_ylabel(r'$K_z$')
    axs[0].set_xlim([kx[0], kx[-1]])
    axs[0].set_ylim([kz[0], kz[-1]])

    # Transposed so rows follow kz and columns follow kx, as
    # pcolormesh(kx, kz, C) expects.
    log_ticks = list(range(-7, 1))
    pcm = axs[0].pcolormesh(kx, kz, np.transpose(np.log10(growth_poly)), vmin=-7, vmax=0,
                            cmap='RdBu_r', shading='gouraud', rasterized=True)
    cbar = fig.colorbar(pcm, ax=axs[0], ticks=log_ticks)
    cbar.ax.set_yticklabels([rf'$10^{{{t}}}$' for t in log_ticks])

    # Right: Im(omega) against N.
    axs[1].set_xscale('log')
    axs[1].set_yscale('log')
    axs[1].set_xlabel(r'$N$')
    axs[1].set_ylabel(r'$\Im(\omega)$')

    axs[1].plot(n_values, growth_1, label=r'$\Delta=0.1$')
    axs[1].plot(n_values, growth_7, label=r'$\Delta=0.99$')

    axs[1].legend()

    fig.savefig(save_direc + 'fig_dsi_convergence.png')
    plt.close(fig)

def figure_dsi_convergence_max():
    '''Compare log10 Im(omega) maps from ``growth_direct`` and ``growth_exact``.

    Reads ``fig_dsi_convergence_max.npz`` from ``data_direc``. The top panel
    shows ``growth_direct`` and the bottom panel ``growth_exact``, both on
    the same log10 colour scale [-7, 0] and the same axis limits. Writes
    ``fig_dsi_convergence_max.png`` to ``save_direc``.
    '''
    # The context manager closes the archive. Each member is read exactly
    # once, because every NpzFile lookup re-reads it from the archive.
    with np.load(data_direc + 'fig_dsi_convergence_max.npz') as npzfile:
        kx = npzfile['kx']
        kz = npzfile['kz']
        panels = [np.imag(npzfile['growth_direct']), np.imag(npzfile['growth_exact'])]

    fig, axs = plt.subplots(nrows=2, ncols=1, figsize=[4,7], layout='constrained')

    log_ticks = list(range(-7, 1))
    for ax, growth in zip(axs, panels):
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_xlabel(r'$K_x$')
        ax.set_ylabel(r'$K_z$')
        # Both panels get explicit grid limits. The bottom one previously
        # relied on autoscaling, which made it inconsistent with the top.
        ax.set_xlim([kx[0], kx[-1]])
        ax.set_ylim([kz[0], kz[-1]])

        # Transposed so rows follow kz and columns follow kx, as
        # pcolormesh(kx, kz, C) expects.
        pcm = ax.pcolormesh(kx, kz, np.transpose(np.log10(growth)), vmin=-7, vmax=0,
                            cmap='RdBu_r', shading='gouraud', rasterized=True)
        cbar = fig.colorbar(pcm, ax=ax, ticks=log_ticks)
        cbar.ax.set_yticklabels([rf'$10^{{{t}}}$' for t in log_ticks])

    fig.savefig(save_direc + 'fig_dsi_convergence_max.png')
    plt.close(fig)


if __name__ == '__main__':
    # Regenerate every figure only when run as a script. Importing this
    # module just defines the figure functions and writes no files.
    figure_si_monoresonance()
    figure_si_resonance()
    figure_si_cont()
    figure_si_delta()

    figure_dsi_monoresonance()
    figure_dsi_contour_growth()
    figure_dsi_poly_kx()
    figure_dsi_lim_delta()
    figure_dsi_double_growth()
    figure_dsi_visc()

    figure_dsi_convergence()
    figure_dsi_convergence_max()
