import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from ipywidgets import SelectionSlider, interactive
from matplotlib.gridspec import GridSpec
from scipy.optimize import minimize_scalar
from scipy.signal import find_peaks, savgol_filter
from scipy.stats import zscore

# Resistance quantum (von Klitzing constant) RK = h/e^2, in Ohm.
# Multiplying a conductance in Siemens by RK expresses it in units of e^2/h,
# which is how convert_to_conductance uses it.
RK = 25812.80745  # Ohm


def plot_conductance_data_with_peaks(data, peaks):
    """
    Return an interactive widget showing one trace of ``data`` per value of
    its ``field`` coordinate, with the detected peaks overlaid in red.

    ``peaks`` is expected to be a table with "field", "pg" and "height"
    columns (e.g. produced by ``get_peaks`` with coord="pg"); only the rows
    whose "field" equals the selected slider value are drawn.
    """

    def _draw(bval):
        trace = data.sel(dict(field=bval))
        selected = peaks.loc[peaks["field"] == bval]
        figure = plt.figure()
        trace.plot(marker=".")
        for pg_position, peak_height in zip(selected["pg"], selected["height"]):
            plt.scatter(pg_position, peak_height, c="r")
        return figure

    field_options = [("%g" % value, value) for value in data.field.values]
    return interactive(_draw, bval=SelectionSlider(options=field_options))


def rotate_lockin_data(X, Y, method="fit", plot=True):
    """
    Rotate lock-in quadratures so that the signal lies (ideally) entirely in X.

    The complex signal Z = X + iY is multiplied by exp(i * theta), where
    theta is obtained from ``determine_rotation_phase(Z, method=method)``.

    Parameters
    ----------
    X, Y : xarray.DataArray
        In-phase and quadrature lock-in signals with identical dims/coords.
    method : str
        Forwarded to ``determine_rotation_phase``: "fit" (linear fit of Y
        versus X) or "minimum" (direct minimisation of the imaginary part).
    plot : bool
        If True, show a scatter plot of the raw and rotated data, normalised
        by the largest magnitude, together with the unit circle.

    Returns
    -------
    XR, YR : xarray.DataArray
        Rotated in-phase and quadrature components (returned as a tuple).
    """
    Z = X.values + 1j * Y.values
    theta = determine_rotation_phase(Z, method=method)
    ZR = (X + 1j * Y) * np.exp(1j * theta)
    XR, YR = np.real(ZR), np.imag(ZR)

    if plot:
        fig, ax = plt.subplots(figsize=(5, 5))
        # A rotation preserves |Z|, so R is the largest magnitude of both the
        # raw and the rotated data; dividing by it keeps points in the unit disk.
        R = np.max(np.abs(ZR)).item()
        ax.scatter(
            X.values.flatten() / R,
            Y.values.flatten() / R,
            marker=".",
            color="black",
            label="raw",
        )
        ax.scatter(
            XR.values.flatten() / R,
            YR.values.flatten() / R,
            marker=".",
            color="green",
            label="rotated",
        )
        ax.set_xlim([-1, 1])
        ax.set_ylim([-1, 1])
        ax.grid()
        ax.set_xlabel("Normalized Raw Lockin X")
        ax.set_ylabel("Normalized Raw Lockin Y")
        circ = plt.Circle((0, 0), radius=1, edgecolor="gray", facecolor="None")
        ax.add_patch(circ)
        ax.legend(loc="upper right")
        plt.show()

    return XR, YR


def determine_rotation_phase(zs, method="fit"):
    """
    Given an array of complex numbers, determine the rotation phase theta
    such that the imaginary part of ``zs * exp(1j * theta)`` is minimised.

    Parameters
    ----------
    zs : array_like of complex
        Data of any shape (it is flattened). Non-finite entries (NaN/inf in
        either the real or imaginary part) are ignored.
    method : {"fit", "minimum"}
        "fit": least-squares linear fit Im(z) = m * Re(z) + b; the data line
        makes the angle arctan(m) with the real axis, so theta = -arctan(m).
        "minimum": numerically minimise sum(Im(zs * exp(1j*alpha))**2).

    Returns
    -------
    theta : float
        Rotation phase in radians.

    Raises
    ------
    ValueError
        If ``method`` is unknown or ``zs`` has no finite entries.
    """
    zs = np.asarray(zs).ravel()
    zs = zs[np.isfinite(zs)]
    if zs.size == 0:
        raise ValueError("determine_rotation_phase: no finite data points")

    if method == "fit":
        xs = np.real(zs)
        ys = np.imag(zs)
        # Normalisation does not change the slope; it only improves conditioning.
        R = np.max(np.sqrt(xs ** 2 + ys ** 2))
        xs, ys = xs / R, ys / R
        m, _ = np.polyfit(xs, ys, deg=1)
        # arctan (not arcsin) is the angle of a line with slope m; arcsin is
        # only a small-slope approximation and is undefined for |m| > 1.
        theta = -np.arctan(m)
    elif method == "minimum":

        def imaginary_part_norm_after_rotation(alpha):
            zs_rotated = zs * np.exp(1j * alpha)
            return np.sum(np.imag(zs_rotated) ** 2)

        result = minimize_scalar(
            imaginary_part_norm_after_rotation, bracket=[0, 2 * np.pi]
        )
        theta = result.x
    else:
        raise ValueError(
            f"Unknown method {method!r}; expected 'fit' or 'minimum'"
        )
    return theta


def convert_to_conductance(raw, gain, dV, Rseries):
    """
    Convert a raw lock-in signal to conductance in units of e^2/h.

    The current is ``raw / gain``; the voltage across the device is taken as
    ``dV`` minus the drop across the series resistance ``Rseries``. The
    resulting current/voltage ratio is scaled by RK = h/e^2.
    (Assumes gain converts the raw signal to Ampere and dV is in Volt so the
    ratio is in Siemens — confirm against the measurement setup.)
    """
    current = raw / gain
    device_voltage = dV - Rseries * current
    return current / device_voltage * RK  # In units of e^2/h


def get_conductance_data(data, gain, dV, Rseries, dim_attrs, method, plot=True):
    """
    Given the raw lock-in data:
        1. Rotates the data so that the signal is all in the X quadrature.
        2. Converts the data to conductance given measurement parameters.
    The original data and the rotated ones are both returned in a single Dataset.

    Parameters
    ----------
    data : xarray.Dataset
        Must contain the lock-in quadratures "X" and "Y". Any other data
        variables are copied into the result without conversion.
    gain, dV, Rseries :
        Measurement parameters forwarded to ``convert_to_conductance``.
    dim_attrs : dict
        Maps dimension names and names of the extra data variables to the
        attribute dicts to assign to them. Names without an entry keep the
        attributes they already have.
    method : str
        Rotation method forwarded to ``rotate_lockin_data``.
    plot : bool
        If True, show the rotation diagnostic plot and a 2x2 grid with the
        X, XR, Y and YR conductance maps.

    Returns
    -------
    gdata : xarray.Dataset
        "X", "Y" (raw) and "XR", "YR" (rotated) in units of e^2/h, plus any
        extra variables from ``data``.
    """
    XR, YR = rotate_lockin_data(data["X"], data["Y"], method=method, plot=plot)
    gdata = convert_to_conductance(data["X"], gain, dV, Rseries).to_dataset(name="X")
    for name, quadrature in (("Y", data["Y"]), ("XR", XR), ("YR", YR)):
        gdata[name] = convert_to_conductance(quadrature, gain, dV, Rseries)

    for dim in gdata.dims:
        if dim in dim_attrs:
            gdata[dim].attrs = dim_attrs[dim]

    # Label only the conductance variables; extra variables are copied below
    # so they do not receive conductance units.
    for var in gdata.data_vars:
        gdata[var].attrs["units"] = r"$e^2/h$"
        gdata[var].attrs["long_name"] = var + " Conductance"

    # Copy other variables to gdata, if present
    for var in data.data_vars:
        if var in ("X", "Y"):
            continue
        gdata[var] = data[var]
        if var in dim_attrs:
            gdata[var].attrs = dim_attrs[var]

    if plot:
        fig = plt.figure(constrained_layout=True, figsize=(12, 8))
        gs = GridSpec(2, 2, figure=fig)
        layout = {"X": gs[0, 0], "XR": gs[0, 1], "Y": gs[1, 0], "YR": gs[1, 1]}
        for name, cell in layout.items():
            gdata[name].plot(ax=fig.add_subplot(cell))
        plt.show()
    return gdata


def plot_and_compare_conductance_data(data):
    """
    Return an interactive widget comparing raw and rotated conductance traces.

    Two selection sliders pick a value of the ``field`` and ``pg_coarse``
    coordinates of ``data``. For that selection the top panel overlays "X"
    and "XR" (with a legend) and the bottom panel overlays "Y" and "YR" (with
    a horizontal reference line at 0).
    """

    def _overlay(ax, raw, rotated, ylabel):
        # Draw a raw trace and its rotated counterpart on the same axes.
        raw.plot(ax=ax, label="raw")
        rotated.plot(ax=ax, label="after rotation")
        ax.set_title("")
        ax.set_ylabel(ylabel)
        ax.grid()

    def _plot_traces(bval, pval):
        selection = data.sel(dict(field=bval, pg_coarse=pval))

        fig = plt.figure(figsize=(10, 5), constrained_layout=True)
        grid = fig.add_gridspec(2, 1)

        top = fig.add_subplot(grid[0, 0])
        _overlay(top, selection["X"], selection["XR"], r"X component [$e^2 / h$]")
        top.legend(fontsize=12)

        bottom = fig.add_subplot(grid[1, 0])
        _overlay(bottom, selection["Y"], selection["YR"], r"Y component [$e^2 / h$]")
        bottom.axhline(0, c="k", lw=0.5)

        return fig

    def _slider(values):
        return SelectionSlider(options=[("%g" % v, v) for v in values])

    return interactive(
        _plot_traces,
        bval=_slider(data.field.values),
        pval=_slider(data.pg_coarse.values),
    )


def get_peaks(
    array, coord, filter_params, peak_params, height_threshold, zscore_threshold
):
    """
    Returns a table containing the location and the values
    of all peaks along a specified coordinate of a dataset.

    Every 1D trace of ``array`` along ``coord`` (one per combination of the
    other dimensions) is optionally smoothed, searched for peaks, and the
    peaks are filtered by height and by height z-score.

    Parameters:
    ----------
    array: xarray DataArray
    coord: dimension of array to be reduced by the peakfinder
    filter_params: dict, parameters to be passed to savgol_filter to smoothen
        the data; a falsy value (None or {}) disables smoothing
    peak_params: dict, parameters to be passed to find_peaks (which runs on
        the smoothed trace)
    height_threshold: float, a peak is kept only if its height in the
        unsmoothed trace exceeds height_threshold times the trace average
    zscore_threshold: float or None, if truthy a peak is kept only if the
        z-score of its height (relative to the other peaks of the same trace)
        exceeds it; None or 0 disables this filter. A trace with a single
        peak (or equal-height peaks) has NaN z-scores, which never pass.

    NaN values in a trace are replaced by 0 before averaging, smoothing and
    peak finding.

    Returns:
    --------
    peaks: a pandas dataframe with columns [coord, *other dims, "height",
        "height_zscore"], one row per peak kept. Values of the other
        dimensions are rounded to 4 decimals. If there are no traces at all,
        an empty dataframe with these columns is returned.
    """
    # get list of dimensions other than coord
    dims = [dim for dim in array.dims if dim != coord]
    # get list of their sizes
    sizes = tuple(array.sizes[dim] for dim in dims)
    # hoisted out of the loop: coordinate arrays do not change per trace
    coord_values = array[coord].values
    dim_values = [array[dim].values for dim in dims]
    columns = [coord, *dims, "height", "height_zscore"]

    peaks = []
    # multidimensional loop over all 1D traces of array
    for indices in np.ndindex(sizes):
        # Store location
        where = {
            dim: values[index] for dim, values, index in zip(dims, dim_values, indices)
        }
        # Positional selection: works with duplicate coordinate labels and
        # avoids exact-equality lookups of float coordinate values.
        y = array.isel(dict(zip(dims, indices))).values
        # Replace nans with 0
        if np.isnan(y).any():
            y = np.nan_to_num(y, nan=0)
        y_av = np.average(y)
        # Smoothen data
        yf = savgol_filter(y, **filter_params) if filter_params else y
        # Find peaks on the smoothed trace, but report raw heights
        peak_idx, _ = find_peaks(yf, **peak_params)
        peak_heights = y[peak_idx]
        # zscore of an empty array only produces RuntimeWarnings
        zvalues = zscore(peak_heights) if peak_heights.size else np.empty(0)
        peak_df = pd.DataFrame(
            {
                coord: coord_values[peak_idx],  # peak position along coord
                # peak positions along the other coords in data array
                **{key: np.round(val, 4) for key, val in where.items()},
                "height": peak_heights,
                "height_zscore": zvalues,
            }
        )
        peak_df = peak_df.loc[peak_df["height"] > height_threshold * y_av]
        if zscore_threshold:
            peak_df = peak_df.loc[peak_df["height_zscore"] > zscore_threshold]
        peaks.append(peak_df)

    if not peaks:
        # pd.concat raises on an empty list (e.g. a zero-length dimension)
        return pd.DataFrame(columns=columns)
    return pd.concat(peaks, ignore_index=True)


def get_aggregate_peak_info(peaks, by=None, dim="pg_fine"):
    """
    Summarise a peak table per group and return it as an xarray Dataset.

    Parameters
    ----------
    peaks : pandas.DataFrame
        Peak table with the grouping columns, the position column ``dim`` and
        a "height" column (e.g. the output of ``get_peaks``).
    by : list of str, optional
        Columns to group by; defaults to ["field", "pg_coarse"].
    dim : str
        Column with peak positions. Spacings are the differences between
        consecutive rows of each group, in the row order of ``peaks``.

    Returns
    -------
    xarray.Dataset
        Indexed by the ``by`` columns, with variables avg_spacing,
        median_spacing, std_spacing (population std, ddof=0),
        avg_peak_conductance, std_peak_conductance (sample std, ddof=1) and
        number_of_peaks. Groups with a single peak get NaN spacing statistics.
    """
    if by is None:
        by = ["field", "pg_coarse"]
    grouped = peaks.groupby(by)

    def _spacing_stat(func, name):
        # Apply ``func`` to the consecutive spacings of each group.
        return grouped[dim].aggregate(lambda x: func(np.diff(x))).rename(name)

    result = pd.concat(
        [
            _spacing_stat(np.mean, "avg_spacing"),
            _spacing_stat(np.median, "median_spacing"),
            _spacing_stat(np.std, "std_spacing"),
            grouped["height"].aggregate("mean").rename("avg_peak_conductance"),
            # "std" pins the groupby sample std (ddof=1) that pandas has
            # historically substituted for np.std, independent of version.
            grouped["height"].aggregate("std").rename("std_peak_conductance"),
            grouped.size().rename("number_of_peaks"),
        ],
        axis=1,
    )
    return result.to_xarray()


def return_spacing_histogram(peaks, coord, nbins, rng, dim="pg_fine"):
    """
    Histogram the spacings between consecutive peaks, separately for each
    value of ``coord``.

    Parameters
    ----------
    peaks : pandas.DataFrame
        Peak table (e.g. the output of ``get_peaks``) with columns ``coord``
        and ``dim``.
    coord : str
        Column to group by; becomes the first dimension of the result.
    nbins : int
        Number of histogram bins.
    rng : (float, float) or None
        Histogram range as for ``numpy.histogram``. If None, the range spans
        all retained spacings of all groups, so every group shares the bins.
    dim : str
        Column with peak positions; spacings are differences of consecutive
        rows within each group.

    Within each group, non-positive spacings are discarded (boundary cases,
    presumably jumps between traces), then spacings with |z-score| >= 3 are
    discarded as outliers. If all spacings of a group are equal their
    z-score is undefined (NaN) and they are kept.

    Returns
    -------
    xarray.DataArray
        Counts with dims (coord, "spacing"); the "spacing" coordinate holds
        the bin centers.
    """
    labels = []
    kept_spacings = []
    for label, group in peaks.groupby(coord):
        spacings = np.diff(group[dim].to_numpy())
        spacings = spacings[spacings > 0]  # eliminates boundary cases
        if spacings.size:
            with np.errstate(divide="ignore", invalid="ignore"):
                zs = zscore(spacings)
                # NaN z-scores (zero spread) compare False here and are kept
                spacings = spacings[~(np.abs(zs) >= 3)]  # eliminates outliers
        labels.append(label)
        kept_spacings.append(spacings)

    # One set of bin edges for every group, defined even with no groups.
    all_spacings = np.concatenate(kept_spacings) if kept_spacings else np.empty(0)
    bin_edges = np.histogram_bin_edges(all_spacings, bins=nbins, range=rng)
    bin_centers = (bin_edges[1:] + bin_edges[:-1]) / 2

    hist = np.zeros((len(kept_spacings), nbins))
    for row, spacings in zip(hist, kept_spacings):
        row[:] = np.histogram(spacings, bins=bin_edges)[0]

    return xr.DataArray(
        data=hist,
        dims=[coord, "spacing"],
        coords={"spacing": bin_centers, coord: np.asarray(labels)},
    )
