#!/usr/bin/python

###############################################################################
# Generate plots for the paper "Resonant Drag Instabilities in Polydisperse
# Dust I. The Acoustic Resonant Drag Instability", by Paardekooper & Aly.
###############################################################################

import numpy as np
import scipy as sp
import matplotlib.pyplot as plt

# Directory the PNG figures are written to. The plotting functions join it to
# the file name by plain string concatenation, so keep the trailing slash.
save_direc = './'
# Directory holding the precomputed .npz data files read by the plotting
# functions (also joined by string concatenation; keep the trailing slash).
data_direc = './'

# Use the same Matplotlib style as in the paper. The path is relative, so it is
# resolved against the current working directory, not this script's location.
plt.style.use('./paper.mplstyle')

# Colour cycle of the active style. The plotting functions index into this
# list (up to colors[3]) so that dotted analytic/reference curves can reuse the
# colour of a corresponding numerical curve; the style is assumed to define at
# least four colours.
prop_cycle = plt.rcParams['axes.prop_cycle']
colors = prop_cycle.by_key()['color']
# Line styles for multi-curve plots; not referenced by any function in this file.
linestyles = ['solid', 'dotted', 'dashed', 'dashdot']


def fig_adi_kx(show=False):
    '''ARDI growth rates as a function of wave number for 3 different size distributions.

    Plots Im(omega~) against K for three size distributions labelled by
    delta = 1e-3, 1e-2 and 1e-1 (numerical growth rates read from
    ``fig_adi_kx.npz`` in ``data_direc``). It overlays the analytic
    monodisperse growth rate 0.5*sqrt(mu*K) and vertical dotted lines at
    K_lim = mu/(4*delta**2). The figure is saved as ``fig_ADI_Kmu.png`` in
    ``save_direc``.

    Parameters
    ----------
    show : bool, optional
        If True, call ``plt.show()`` after saving. Otherwise (the default)
        the figure is closed so it cannot leak into later plots.
    '''
    delta = np.asarray([0.001, 0.01, 0.1])
    # Parameters describing the delta = 1e-3 calculation. Only the
    # dust-to-gas ratio is needed for the analytic curves drawn here.
    ardi_param_dict = {
        'dust_to_gas_ratio': 0.01,
        'stokes_range' : [1-delta[0], 1+delta[0]],
        'single_size_flag' : False,
    }
    mu = ardi_param_dict['dust_to_gas_ratio']

    k = np.logspace(-3, 4, 100)

    # Analytic results
    k_lim = 0.25*mu/delta**2
    k_mono = 0.5*np.sqrt(mu)*np.sqrt(k)

    # Pull the arrays out inside the context manager so the .npz file
    # handle is closed as soon as the data is in memory.
    with np.load(data_direc + 'fig_adi_kx.npz') as npzfile:
        curves = [(npzfile['growth_3'], r'$\delta=10^{-3}$'),
                  (npzfile['growth_2'], r'$\delta=10^{-2}$'),
                  (npzfile['growth_1'], r'$\delta=10^{-1}$')]

    # The pyplot calls below draw on the "current" figure. Start a fresh one
    # so curves and axis limits from a previously generated plot are not
    # drawn into this one.
    fig = plt.figure()

    plt.xscale('log')
    plt.yscale('log')
    plt.ylim([0.0002, 2])
    plt.xlabel(r'$K$')
    plt.ylabel(r'$\Im(\tilde\omega)$')

    plt.plot(k, k_mono, color=colors[3], linestyle='dotted')
    # Vertical markers at K_lim, one per delta, coloured like the delta curves.
    for k_l, col in zip(k_lim, colors[:3]):
        plt.plot([k_l, k_l], [0.0001, 3], color=col, linestyle='dotted')

    for growth, label in curves:
        plt.plot(k, np.imag(growth), label=label)

    plt.legend()
    plt.tight_layout()

    fig.savefig(save_direc + 'fig_ADI_Kmu.png')

    if show:
        plt.show()
    else:
        plt.close(fig)

def fig_adi_backward(show=False):
    '''ADI backward propagating sound wave.

    Plots three curves against K:

    - the numerical growth rate read from ``fig_adi_backward.npz``,
      multiplied by K before plotting;
    - a low-mu approximation, obtained by numerically integrating a complex
      integrand over the Stokes-number range [smin, smax];
    - the constant K << 1 limit of that approximation.

    The figure is saved as ``fig_ADI_backward.png`` in ``save_direc``.

    Parameters
    ----------
    show : bool, optional
        If True, call ``plt.show()`` after saving. Otherwise (the default)
        the figure is closed so it cannot leak into later plots.
    '''
    K = np.logspace(-3, 1, 100)
    smin = 0.5
    smax = 1.2

    ardi_param_dict = {
        'dust_to_gas_ratio': 0.001,
        'stokes_range' : [smin, smax],
        'single_size_flag' : False,
    }
    mu = ardi_param_dict['dust_to_gas_ratio']

    # Copy the data out so the .npz file handle is closed immediately.
    with np.load(data_direc + 'fig_adi_backward.npz') as npzfile:
        growth = npzfile['growth']

    def integrand(x, kk):
        # Complex integrand of the low-mu approximation at wave number kk.
        return (kk*(1+x)**2 - 1j)/((1+x)*(1 + 1j*x*kk*(1+x)))

    # quad forwards `args` to the integrand, so no per-K closure is needed.
    res = np.array([sp.integrate.quad(integrand, smin, smax, args=(kk,),
                                      complex_func=True)[0]
                    for kk in K])
    res = 0.5*(res/(smax - smin) + 1j)

    # K -> 0 limit of mu*Im(res): the integrand tends to -1j/(1+x).
    limKlow = 0.5*mu*(1 - np.log((1 + smax)/(1 + smin))/(smax - smin))

    # Fresh figure so earlier plots do not bleed into this one.
    fig = plt.figure()

    plt.xscale('log')
    plt.yscale('log')
    plt.xlabel(r'$K$')
    plt.ylabel(r'$\Im(\tilde\omega)$')

    plt.plot(K, np.imag(growth)*K, label=r'exact')
    plt.plot(K, mu*np.imag(res), label=r'low $\mu$')
    plt.plot([K[0], K[-1]], [limKlow, limKlow], linestyle='dotted', color=colors[0], label=r'$K\ll1$')

    plt.legend()
    plt.tight_layout()
    fig.savefig(save_direc + 'fig_ADI_backward.png')

    if show:
        plt.show()
    else:
        plt.close(fig)

def fig_adi_mu(show=False):
    '''ARDI growth rates as a function of mu

    At fixed K = 1, plots the numerical Im(omega~) (``res_num`` from
    ``fig_adi_mu.npz``) against mu. Two dotted reference curves are overlaid:
    the monodisperse estimate 0.5*sqrt(mu*K) and the ``res_poly`` curve
    stored in the same file. The figure is saved as ``fig_ADI_mu.png`` in
    ``save_direc``.

    Parameters
    ----------
    show : bool, optional
        If True, call ``plt.show()`` after saving. Otherwise (the default)
        the figure is closed so it cannot leak into later plots.
    '''
    mu = np.logspace(-4, -1, 100)
    k = 1.0

    # Analytic results
    res_mono = 0.5*np.sqrt(mu)*np.sqrt(k)

    # Copy the arrays out so the .npz file handle is closed immediately.
    with np.load(data_direc + 'fig_adi_mu.npz') as npzfile:
        res_poly = npzfile['res_poly']
        res_num = npzfile['res_num']

    # Fresh figure so earlier plots do not bleed into this one.
    fig = plt.figure()

    plt.xscale('log')
    plt.yscale('log')
    plt.ylim([0.0002, 0.2])
    plt.xlabel(r'$\mu$')
    plt.ylabel(r'$\Im(\tilde\omega)$')

    plt.plot(mu, res_mono, color=colors[2], linestyle='dotted')
    plt.plot(mu, res_poly, color=colors[1], linestyle='dotted')

    plt.plot(mu, np.imag(res_num), color=colors[0])

    plt.tight_layout()

    fig.savefig(save_direc + 'fig_ADI_mu.png')

    if show:
        plt.show()
    else:
        plt.close(fig)

def fig_adi_asym(show=False):
    '''ARDI growth rates as a function of wave number for 3 different mu

    Plots the numerical Im(omega~) against K for mu = 1e-5, 1e-4 and 1e-3,
    read from ``fig_adi_asym.npz``. Each curve is overlaid with a dotted
    estimate that is linear in mu. The estimate uses a size distribution
    spanning [1-DeltaL, 1+DeltaR] and is evaluated here by numerical
    quadrature. The figure is saved as ``fig_ADI_asym.png`` in
    ``save_direc``.

    Parameters
    ----------
    show : bool, optional
        If True, call ``plt.show()`` after saving. Otherwise (the default)
        the figure is closed so it cannot leak into later plots.
    '''
    k = np.logspace(-3, 1, 100)

    # Copy the arrays out so the .npz file handle is closed immediately.
    with np.load(data_direc + 'fig_adi_asym.npz') as npzfile:
        curves = [(npzfile['growth_5'], r'$\mu=10^{-5}$'),
                  (npzfile['growth_4'], r'$\mu=10^{-4}$'),
                  (npzfile['growth_3'], r'$\mu=10^{-3}$')]

    # Fresh figure so earlier plots do not bleed into this one.
    fig = plt.figure()

    plt.xscale('log')
    plt.yscale('log')
    plt.ylim([2e-8, 2e-3])
    plt.xlabel(r'$K$')
    plt.ylabel(r'$\Im(\tilde\omega)$')

    for growth, label in curves:
        plt.plot(k, np.imag(growth), label=label)

    DeltaL = 0.5
    DeltaR = 0.2
    width = DeltaL + DeltaR

    def integrand(x, kk):
        # Complex integrand of the size integral at wave number kk.
        return 1/(1 - 1j*x*kk*(1-x))

    # The size integral depends only on K, not on mu. Evaluate it once
    # instead of repeating all the quadratures for every mu value.
    integral_im = np.array([
        np.imag(sp.integrate.quad(integrand, 1-DeltaL, 1+DeltaR,
                                  args=(kk,), complex_func=True)[0])
        for kk in k])

    for j, mu in enumerate([1.0e-5, 1.0e-4, 1.0e-3]):
        # Growth rate, linear in mu; algebraically identical to
        # K*[0.5*mu*(ln(DL/DR)/(DL+DR) - 1)/K - 0.5*mu*Im(I)/(DL+DR)].
        growth_pert = 0.5*mu*(np.log(DeltaL/DeltaR)/width - 1
                              - k*integral_im/width)

        plt.plot(k, growth_pert, color=colors[j], linestyle='dotted')

    plt.legend()
    plt.tight_layout()

    fig.savefig(save_direc + 'fig_ADI_asym.png')

    if show:
        plt.show()
    else:
        plt.close(fig)

def growth_adi_visc(K, mu, st, alpha):
    '''Return the most unstable eigenfrequency of the viscous single-size ARDI problem.

    Builds the 4x4 complex matrix of the linear problem for a single Stokes
    number ``st`` and returns its eigenvalue with the largest imaginary part.
    That is the fastest-growing mode, whose imaginary part the figures plot
    as Im(omega~).

    Parameters
    ----------
    K : float
        Dimensionless wave number.
    mu : float
        Dust-to-gas ratio (``mu`` in the figures).
    st : float
        Stokes number of the single dust species.
    alpha : float
        Viscosity parameter, applied to both gas and dust
        (alpha_g = alpha_d in the figure labels).

    Returns
    -------
    complex
        The eigenvalue with maximal imaginary part.
    '''
    # A plain complex ndarray is all eigvals needs. np.matrix is discouraged
    # by NumPy and emits a PendingDeprecationWarning on construction.
    A = np.array([[0, K, 0, 0],
                  [K - 1j*mu, -1j*mu/st - 4j*alpha*K*K/3, 1j*mu, 1j*mu/st],
                  [1j*alpha*K*K/(1+mu), 0, st*K - 1j*alpha*K*K/(1+mu), K],
                  [0, 1j/st, 0, st*K - 1j/st]], dtype=complex)
    eigenvalues = np.linalg.eigvals(A)
    return eigenvalues[np.argmax(np.imag(eigenvalues))]

def fig_adi_visc(show=False):
    '''ARDI growth rates versus wave number for three viscosities.

    For mu = 0.01 and st = 1, plots Im(omega~) of the most unstable mode
    returned by ``growth_adi_visc`` against K. It uses
    alpha_g = alpha_d = 1e-7, 1e-5 and 5e-3. The figure is saved as
    ``fig_ADI_visc.png`` in ``save_direc``.

    Parameters
    ----------
    show : bool, optional
        If True, call ``plt.show()`` after saving. Otherwise (the default)
        the figure is closed so it cannot leak into later plots.
    '''
    K = np.logspace(-3, 4, 100)

    mu = 0.01
    st = 1.0
    alpha = [1.0e-7, 1.0e-5, 5.0e-3]
    labels = [r'$\alpha_{\rm g}=\alpha_{\rm d}=10^{-7}$',
              r'$\alpha_{\rm g}=\alpha_{\rm d}=10^{-5}$',
              r'$\alpha_{\rm g}=\alpha_{\rm d}=0.005$']

    # Fresh figure so earlier plots do not bleed into this one.
    fig = plt.figure()

    for a, label in zip(alpha, labels):
        growth = np.array([growth_adi_visc(kk, mu, st, a) for kk in K])
        plt.plot(K, np.imag(growth), label=label)

    plt.xscale('log')
    plt.yscale('log')
    plt.xlabel(r'$K$')
    plt.ylabel(r'$\Im(\tilde\omega)$')
    plt.legend()

    plt.tight_layout()

    fig.savefig(save_direc + 'fig_ADI_visc.png')

    if show:
        plt.show()
    else:
        plt.close(fig)

def fig_adi_conv(show=False):
    '''Convergence of the computed ADI growth rate with N.

    Draws two stacked panels: mu = 0.0004 on top and mu = 0.0001 below.
    Each panel plots Im(omega~) against N for two runs labelled
    omega~_0 = K and omega~_0 = -K, with data from ``fig_adi_conv.npz``.
    The matching exact growth rates are drawn as dotted horizontal lines.
    The figure is saved as ``fig_ADI_conv.png`` in ``save_direc``.

    N is presumably the number of size bins used in the calculation
    (TODO confirm).

    Parameters
    ----------
    show : bool, optional
        If True, call ``plt.show()`` after saving. Otherwise (the default)
        the figure is closed.
    '''
    # Copy everything needed out of the .npz so its file handle is closed.
    with np.load(data_direc + 'fig_adi_conv.npz') as npzfile:
        (exact_growth_lomu, exact_growth_back_lomu,
         exact_growth_himu, exact_growth_back_himu) = npzfile['exact_growth']
        N = npzfile['N']
        # (title, res for omega_0=K, res for omega_0=-K, exact, exact backward)
        panels = [
            (r'$\mu=0.0004$', npzfile['res1_himu'], npzfile['res2_himu'],
             exact_growth_himu, exact_growth_back_himu),
            (r'$\mu=0.0001$', npzfile['res1_lomu'], npzfile['res2_lomu'],
             exact_growth_lomu, exact_growth_back_lomu),
        ]

    # Spacing is handled by the constrained layout engine. Calling
    # plt.tight_layout() afterwards would replace it (with a warning),
    # so it is deliberately not called.
    fig, axs = plt.subplots(nrows=2, ncols=1, figsize=(4.5, 7), layout='constrained')

    for ax, (title, res_plus, res_minus, exact, exact_back) in zip(axs, panels):
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_xlabel(r'$N$')
        ax.set_ylabel(r'$\Im(\tilde\omega)$')
        ax.set_title(title)

        ax.plot(N, np.imag(res_plus), label=r'$\tilde\omega_0=K$')
        ax.plot(N, np.imag(res_minus), label=r'$\tilde\omega_0=-K$')
        ax.plot([N[0], N[-1]], [exact, exact], color=colors[0], linestyle='dotted')
        ax.plot([N[0], N[-1]], [exact_back, exact_back], color=colors[1], linestyle='dotted')

        ax.legend()

    fig.savefig(save_direc + 'fig_ADI_conv.png')

    if show:
        plt.show()
    else:
        plt.close(fig)

# Generate all paper figures (except figure 1, which is a cartoon) when this
# file is run as a script; importing it only defines the plotting functions.
if __name__ == '__main__':
    fig_adi_kx()
    fig_adi_backward()
    fig_adi_mu()
    fig_adi_asym()
    fig_adi_visc()
    fig_adi_conv()
