import h5py
import os
from pathlib import Path
import tempfile
from scipy.stats import poisson
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
import pycorrelate as pyc
import numpy as np
import pandas as pd
import datetime
import subprocess as sp
import socket
import matplotlib
import seaborn as sns
import re



# On the host 'TUD207901' matplotlib is switched to the non-interactive 'Agg'
# backend (presumably a machine without a display - confirm).
hostname = socket.gethostname()
if hostname in {'TUD207901'}:
    matplotlib.use('Agg')

# Plot colour palette. Every entry is a '#'-prefixed hex string ('#rrggbb' or
# '#rrggbbaa'); without the leading '#' matplotlib does not parse the string as
# a colour, which is why the entries below that lacked it were fixed.
mylightblue = '#a6cee3'
mydarkblue = '#1f78b4'
mylightgreen = '#b2df8a'
mydarkgreen = '#33a02c'
myred = '#e31a1c'
myyellow = '#ffff33ff'
mypurple = '#984ea3ff'
myorange = '#ff7f00ff'
mybrown = '#a65628ff'
mygrey = '#999999ff'
mypink = '#fb9a99'

# Second palette (TUD_ prefix).
TUD_cyan = '#00a6d6ff'
TUD_lightblue = '#6ebbd5ff'
TUD_darkblue = '#1d1c73ff'
TUD_orange = '#e64616ff'
TUD_red = '#e21a1aff'
TUD_yellow = '#ffd100ff'
TUD_green = '#008891ff'
TUD_lightgreen = '#a5ca1aff'
TUD_purple = '#6d177fff'
TUD_grey = '#6b8689ff'
TUD_brown = '#a65628ff'
TUD_pink = '#fb9a99'


def cmtoinch(*tupel):
    '''
    Convert lengths from centimetres to inches (e.g. for a matplotlib figsize).

    Accepts separate numbers, ``cmtoinch(8, 6)``, or a tuple/list as first
    argument, ``cmtoinch((8, 6))`` / ``cmtoinch([8, 6])``. In the latter case
    only the first argument is converted. Returns a tuple (empty when called
    without arguments).
    '''
    cm_per_inch = 2.54
    if tupel and isinstance(tupel[0], (tuple, list)):
        tupel = tupel[0]
    return tuple(j / cm_per_inch for j in tupel)

def get_photonhdf_loc(results_hdf_loc, absolute=False):
    '''
    Return the location of the photon HDF5 that belongs to a results HDF5.

    Only the file name of the results file's 'photon_hdf5_loc' entry is used. It is
    looked up in the 'Photon_HDF5' folder that is a sibling of the folder containing
    the results file (<results_hdf_loc>/../../Photon_HDF5/<name>).

    With absolute=False the path is returned relative to the folder five levels
    above it.
    '''
    with h5py.File(results_hdf_loc, "r") as results_hdf:
        stored = results_hdf['photon_hdf5_loc'][()]
    # h5py stores Python str as UTF-8; decoding as ASCII failed for non-ASCII paths.
    if isinstance(stored, bytes):
        stored = stored.decode('utf-8')
    absolutepath = Path(results_hdf_loc).parent.parent / 'Photon_HDF5' / Path(stored).name
    if absolute:
        return absolutepath
    return absolutepath.relative_to(absolutepath.parent.parent.parent.parent.parent)
        

def get_results_hdf_loc(photonhdf, absolute=False):
    '''
    Map a photon HDF5 path to the path of its results HDF5.

    <dir>/<sub>/<name>.hdf5 becomes <dir>/Results/<name>_results.hdf5, where the last
    five characters of the file name (the '.hdf5' suffix) are dropped.

    With absolute=False the path is returned relative to the folder five levels
    above it.
    '''
    photon_path = Path(photonhdf)
    stem = photon_path.name[:-5]
    target = photon_path.parent.parent / 'Results' / (stem + '_results.hdf5')
    if not absolute:
        base = target.parent.parent.parent.parent.parent
        return target.relative_to(base)
    return target



def logger(logstr, logtype = 'c', init=False):
    '''
    Print a timestamped message and append it to the log file.

    The log file is 'logfile_TUD.log' on host 'TUD207901' and 'logfile.log'
    elsewhere, relative to the current working directory.

    Parameters
    ----------
    logstr : str
        Message; non-str values are converted with str().
    logtype : str
        'w' -> Warning, 'e' -> Error, 'c' -> Comment; anything else -> No type.
    init : bool
        If True the log file is truncated before writing, otherwise appended to.
    '''
    lognm = "logfile_TUD.log" if socket.gethostname() == 'TUD207901' else "logfile.log"
    prefixes = {'w': ': Warning: ', 'e': ': Error: ', 'c': ': Comment: '}
    tempstr = (str(datetime.datetime.now())
               + prefixes.get(logtype, ': No type: ')
               + str(logstr))

    print(tempstr)
    # the context manager closes the file even if the write fails
    with open(lognm, "w" if init else "a") as log:
        log.write(tempstr + '\n')


def check_hdf_and_group_exist(hdfloc, grp=None, latest_run=True):
    '''
    Return True if the file hdfloc exists and, when grp is given, contains grp.

    With latest_run (default) and an 'amnt_runs' entry in the file, grp is looked
    up relative to the latest run group 'run_<amnt_runs>/'. A missing file is logged
    as an error, a missing group as a comment.
    '''
    if not os.path.exists(hdfloc):
        logger('There is no file %s'%hdfloc, 'e')
        return False
    if grp is None:
        return True

    # read-only: the check must not require (or take) write access to the file
    with h5py.File(hdfloc, "r") as hdf5:
        if latest_run and ('amnt_runs' in hdf5):
            testgrp = 'run_%s/'%hdf5['amnt_runs'][...] + grp
        else:
            testgrp = grp
        found = testgrp in hdf5

    if not found:
        logger('%s is missing %s'%(hdfloc, testgrp))
    return found

def init_results_hdf5(analydir, photonhdf5, overwrite_last_run = False,
                      skip_existing=False, this_run_description = '', copyfrom = 'None'):
    '''
    Initialise (or reopen) the results HDF5 of a photon HDF5 and start a run group in it.

    The results file is <experiment>/Results/<photon file name minus its last five
    characters>_results.hdf5, where <experiment> is two levels above analydir + photonhdf5.
    It holds 'photon_hdf5_loc', a scalar counter 'amnt_runs' and one group 'run_<N>'
    per evaluation run.

    Parameters
    ----------
    analydir : str
        Prefix that is string-concatenated with photonhdf5 (so it must end with a separator).
        Stored as 'analydir' in the run group.
    photonhdf5 : str
        Photon HDF5 path relative to analydir; written as 'photon_hdf5_loc' only if
        that entry does not exist yet.
    overwrite_last_run : bool
        If True and a run exists, the latest run group is deleted and its number reused
        instead of incrementing - unless copyfrom names that very run.
    skip_existing : bool
        If True and the results file already exists, return 'Skip' without opening it.
    this_run_description : str
        Stored as the 'comment' attribute of the run group.
    copyfrom : str
        'None' starts an empty run; 'latest_run' copies run_<N-1> into run_<N>; any other
        value is used as the name of the group to copy. Copied runs lose their 'Plotting'
        subgroup. The very first run (N == 1) is always created empty.

    Returns
    -------
    str
        The results HDF5 path, 'Skip' if skipped, or 'None' if the photon file is missing
        or opening/creating the results file raised an OSError.
    '''
    logger('starting on %s with init'%photonhdf5, 'c')
    # plain string concatenation: analydir has to end with a path separator
    full_photonhdf5 = analydir + photonhdf5
    if not os.path.exists(full_photonhdf5):
        logger('File %s not found. Rerun the converter on this file.'%full_photonhdf5, logtype='e')
        return 'None'
    exp_path = Path(full_photonhdf5).parent.parent

    resultspath = exp_path / 'Results'
    resultspath.mkdir(exist_ok=True)
    
    # name[:-5] drops the 5-character '.hdf5' extension of the photon file
    results_hdf5_str = (str(resultspath) + '/' + 
                        Path(full_photonhdf5).name[:-5] + '_results.hdf5') # path can handle overlaps
    if skip_existing and (os.path.exists(results_hdf5_str)):
        return 'Skip'
    try:
        with h5py.File(results_hdf5_str, "a") as f:
            # add photonhdf5 loc (only when the file is set up for the first time)
            if not 'photon_hdf5_loc' in f:
                f['photon_hdf5_loc'] = photonhdf5
            # make the runs group

            # amnt_runs is the number N of the latest group 'run_<N>'
            if 'amnt_runs' in f:
                amnt_runs = f['amnt_runs'][...] 
            else:
                amnt_runs = 0

            # Overwrite: delete the latest run and reuse its number, unless it is the copy
            # source. Otherwise the new run gets the next number.
            if (overwrite_last_run & (amnt_runs > 0) & (copyfrom != 'run_%s' % amnt_runs)):
                del f['run_%s' % amnt_runs]
            else:
                amnt_runs += 1

            if 'amnt_runs' in f:
                f['amnt_runs'][...] = amnt_runs
            else:
                f['amnt_runs'] = amnt_runs
            if (copyfrom == 'None') or (amnt_runs == 1):
                logger('Creating new group run_%s'%amnt_runs)
                grp = f.create_group('run_%s'%amnt_runs)
                grp.attrs['generation time'] = str(datetime.datetime.now())
                grp.attrs['comment'] = this_run_description
                grp['analydir'] = analydir
            elif copyfrom == 'latest_run':
                # copies run_<N-1>; after an overwrite this is the run before the deleted one
                logger('copying from %s to run_%s'%(amnt_runs - 1, amnt_runs), 'c')
                f.copy('run_%s'%(amnt_runs - 1), 'run_%s'%amnt_runs)
                grp = f['run_%s'%amnt_runs]
                grp.attrs.create('generation time', str(datetime.datetime.now()))
                grp.attrs.create('comment', this_run_description) 
                grp['analydir'][...] = analydir
                # plotting output is not carried over into the new run
                if 'Plotting' in grp:
                    del grp['Plotting']
            else:
                # copyfrom is the name of an existing group, e.g. 'run_3'
                logger('copying from %s to run_%s'%(copyfrom, amnt_runs), 'c')
                f.copy(copyfrom, 'run_%s'%amnt_runs)
                grp = f['run_%s'%amnt_runs]
                grp.attrs.create('generation time', str(datetime.datetime.now()))
                grp.attrs.create('comment', this_run_description) 
                grp['analydir'][...] = analydir
                if 'Plotting' in grp:
                    del grp['Plotting']


    except OSError:
        logger('File %s is probably opened and could not be created.'%results_hdf5_str, 'e')
        return 'None'
    logger('File %s generated.'%results_hdf5_str, logtype='c')
    return results_hdf5_str

def _replace_dataset(grp, name, value):
    '''Store value under name in the h5py group grp, deleting an existing entry first.

    h5py refuses to assign to an existing name. This happens when a run was copied from an
    earlier run or when a previous attempt already wrote 'exist_unique'.
    '''
    if name in grp:
        del grp[name]
    grp[name] = value


def _run_number(runname):
    '''Return N for a run group name of the form 'run_N'.'''
    return int(str(runname).rsplit('_', 1)[-1])


def get_cps(results_hdf_loc, changepointloc, chnl, 
            splitpts = 'None', getfromrun = 'None',
            params = (0.01, 0.9, 2), test = False, amnt_cores = 1):
    '''
    Find changepoints in the photon timetrace of channel chnl and store them in the results HDF5.

    The inter-photon times of the channel (in TCSPC units) are written in chunks of
    splitpts photons to 'no_synchronise_timediffs_<nr>.txt' in the current working
    directory. Each chunk is handed to the external changepoint program, with up to
    amnt_cores processes at a time. The second-to-last chunk absorbs the short
    remainder of the trace.

    Parameters
    ----------
    results_hdf_loc : str or Path
        Results HDF5; its latest run 'run_<amnt_runs>' receives the output.
    changepointloc : str
        Executable, called as ``changepointloc <file> <tcspc_unit> params[0] params[1] params[2]``.
    chnl : str
        Channel name such as 'chnl0'; its last character is the detector number.
    splitpts : int, str or None
        Photons per chunk; None / 'None' means 20000.
    getfromrun : str or None
        Name of an earlier run ('run_N'). If N is lower than the current run number, that
        run's spikes datasets are hard-linked into the current run instead of recomputed.
    params : tuple
        Passed to the changepoint program. params[-1] is also the highest '.em.<n>' output
        index that is cleaned up.
    test : bool
        Only process the first chunk.
    amnt_cores : int
        Maximum number of changepoint processes running at the same time (at least 1).

    On success, writes to run_N/<chnl>/spikes:
    'cpt_params', 'cps', 'em2', 'states_unique' and 'exist_unique' = True.
    Each row of 'states_unique' is [start photon index, end photon index, state], for
    states above 1 only. On failure only 'exist_unique' = False is written.

    Returns None.

    Raises
    ------
    Exception
        If the changepoint program exits with 0 but leaves no '.cp' output file.
    '''
    amnt_cores = max(1, int(amnt_cores))
    # resolve the photon file before the results file is opened for writing
    photon_hdf_loc = get_photonhdf_loc(results_hdf_loc, absolute=True)

    # context managers close both HDF5 files and the stdout sink on every exit path
    with h5py.File(results_hdf_loc, "a") as results_hdf, \
            h5py.File(photon_hdf_loc, "r") as photon_hdf, \
            tempfile.TemporaryFile() as tfil:
        current_run = int(results_hdf['amnt_runs'][()])
        rungrp = results_hdf['run_%s' % current_run]
        logger('\n --- \n\n\n starting changepoint now \n\n\n ---', 'c')
        spkgrp = rungrp.require_group(chnl).require_group('spikes')

        # compare full run numbers; the last digit alone misorders e.g. run_9 and run_10
        retrieve = False
        if (getfromrun is not None) and (getfromrun != 'None'):
            retrieve = _run_number(getfromrun) < current_run

        if retrieve:
            # if it has to be retrieved from a specific run
            srcgrp = results_hdf[getfromrun][chnl]['spikes']
            if srcgrp['exist_unique'][()]:
                for name in ['cpt_params', 'cps', 'em2', 'states_unique', 'exist_unique']:
                    # assigning an h5py Dataset creates a hard link to it
                    _replace_dataset(spkgrp, name, srcgrp[name])
            else:
                _replace_dataset(spkgrp, 'exist_unique', False)
        else:
            photon_data = photon_hdf['photon_data']
            in_chnl = photon_data['detectors'][...] == int(chnl[-1])
            times = photon_data['timestamps'][()][in_chnl]
            # TCSPC arrival times; these were previously read from 'timestamps' by mistake
            nanotimes = photon_data['nanotimes'][()][in_chnl]
            timeres = photon_data['timestamps_specs']['timestamps_unit'][()]
            nanotimeres = photon_data['nanotimes_specs']['tcspc_unit'][()]

            if len(times) < 2:
                logger('too few photons in %s for changepoint detection'%chnl, 'w')
                _replace_dataset(spkgrp, 'exist_unique', False)
                return

            # The photon rate is an indicator of the amount of events that should be observed.
            photonrate = len(times) / (times[-1] * timeres)
            logger('photonrate of %s is %s'%(chnl, photonrate), 'c')

            if (splitpts is None) or (splitpts == 'None'):
                splitpts = 20000
            else:
                splitpts = int(splitpts)

            # timestamp unit in TCSPC units; this might introduce a 0.035 / 3125 error
            timeresratio = int(np.round(timeres / nanotimeres))
            split = int((len(times) / splitpts)) + 1

            # signed integers so that the differences cannot wrap around for unsigned input
            times_accu = times.astype(np.int64) * timeresratio + nanotimes.astype(np.int64)
            timediffs = times_accu[1:] - times_accu[:-1]
            if test:
                split = 1
            cptr = str(np.round(nanotimeres, decimals=13))

            cpfull = None
            em2full = None
            i = 0
            while i < split:
                logger('At %i of %i splitsteps for finding cps.'%(i, split), 'c')
                # chunk split - 1 never runs on its own: chunk split - 2 takes the rest,
                # which avoids a very short trace at the end
                if split == 1:
                    amnt_threads = 1
                else:
                    amnt_threads = min(amnt_cores, split - i - 1)
                nr_list = range(amnt_threads)
                for nr in nr_list:
                    start = (i + nr) * splitpts
                    if split == 1:
                        # single chunk: test mode, or the trace is shorter than splitpts
                        stop = splitpts
                    elif i + nr == split - 2:
                        stop = None
                    else:
                        stop = start + splitpts
                    np.savetxt('no_synchronise_timediffs_%i.txt'%(nr), timediffs[start:stop], fmt='%i')

                # run changepoint
                proc_list = [sp.Popen([changepointloc, 'no_synchronise_timediffs_%i.txt'%(nr),
                                       cptr, str(params[0]), str(params[1]), str(params[2])],
                                      stdout=tfil) for nr in nr_list]
                returns = {nr: proc.wait() for nr, proc in zip(nr_list, proc_list)}
                logger('Done with changepoint for step %s'%i)

                for nr in nr_list:
                    if returns[nr] != 0:
                        logger('Changepoint program exited with nonzero exit code. ' + 
                               'The number where it failed is %s.'%nr, 'e')
                        _replace_dataset(spkgrp, 'exist_unique', False)
                        return
                    base = 'no_synchronise_timediffs_%i.txt'%(nr)
                    try:
                        # shift chunk-local photon indices to indices in the whole trace
                        cps = np.loadtxt(base + '.cp') + ((i + nr) * splitpts)
                    except OSError:
                        logger('no changepoint output at %s'%nr, 'e')
                        raise Exception('no changepoint output at %s'%nr)
                    em2 = np.loadtxt(base + '.em.2')

                    try:
                        os.remove(base)
                        os.remove(base + '.cp')
                        for n in np.arange(1, params[-1]) + 1:
                            os.remove(base + '.em.%i'%n)
                        os.remove(base + '.ah')
                    except OSError:
                        logger('something went wrong when deleting basic files', 'w')
                    logger('deleted basic files', 'c')
                    for ext, missing_msg in (('.asc', 'no .asc file'),
                                             ('.bic', 'no .bic file'),
                                             ('.cp0', 'no .cp0 file')):
                        try:
                            os.remove(base + ext)
                        except OSError:
                            logger(missing_msg, 'c')

                    # relabel column 0: rows with the lowest column-1 value -> 1, highest -> 2
                    minrate = np.min(em2[:,1])
                    maxrate = np.max(em2[:,1])
                    if minrate == maxrate:
                        em2[:,0] = 1
                    else:
                        em2[:,0][em2[:,1] == minrate] = 1
                        em2[:,0][em2[:,1] == maxrate] = 2

                    # accumulate; later chunks drop their first em2 row when appended
                    if (i + nr) == 0 or np.shape(cpfull)[0] == 0:
                        cpfull = cps
                        em2full = em2
                    elif np.shape(cps)[0] != 0:
                        cpfull = np.vstack([cpfull, cps])
                        em2full = np.vstack([em2full, em2[1:]])
                i += amnt_cores
                logger('now next step')

            # check the accumulated arrays once all chunks are done, so that an empty first
            # chunk no longer aborts the whole detection
            # NOTE(review): a single row of shape (3,) becomes a (3, 1) column here, so
            # cpfull[:, 0] then yields three values; confirm the intended layout.
            if np.shape(cpfull) == (3,):
                cpfull = cpfull[:,np.newaxis]
            elif np.shape(cpfull)[0] == 0:
                logger('no changepoints found', 'w')
                _replace_dataset(spkgrp, 'exist_unique', False)
                return

            if np.shape(em2full) == (3,):
                em2full = em2full[:,np.newaxis]
            elif np.shape(em2full)[0] == 0:
                logger('no em2 found', 'w')
                _replace_dataset(spkgrp, 'exist_unique', False)
                return

            # remove duplicates where no state change is from cps and em2
            statechanges = em2full[:-1,0] - em2full[1:,0]
            cps_unique = cpfull[:,0][statechanges != 0]
            states_unique_tp = em2full[:-1,0][statechanges != 0]
            states_unique_tp = np.append(states_unique_tp, em2full[-1,0])
            cps_ext = np.concatenate([np.array([0]), cps_unique, np.array([len(times) - 1])])

            high_state_mask = states_unique_tp > 1
            states_unique = np.array([cps_ext[:-1][high_state_mask], cps_ext[1:][high_state_mask],
                                      states_unique_tp[high_state_mask]]).T

            # store the data, replacing whatever an earlier attempt or a copied run left behind
            for name, value in (('cpt_params', 'splitpts %s and changepointparams: %s'%(splitpts, params)),
                                ('cps', cpfull),
                                ('em2', em2full),
                                ('states_unique', states_unique),
                                ('exist_unique', True)):
                _replace_dataset(spkgrp, name, value)

    logger('\n----------------------\n\n\n done with changepoint\n\n\n----------------------', 'c')


def sample_info(hdf_loc, chnl, is_photonhdf=False, fullname=False):
    '''
    Return a dictionary describing the sample expected in detector channel chnl.

    hdf_loc is the photon HDF5 itself if is_photonhdf, otherwise a results HDF5 whose
    photon HDF5 is located with get_photonhdf_loc.

    The channel's detection wavelength (setup/detection_wavelengths[<last character of
    chnl>]) selects the dye family and focus_size (in μm). An analyte in user/meas_spec
    whose dye belongs to that family supplies the keys color, concentration, dye, labeff,
    time_threshold, rgyr, diff_coeff, diff_coeff_err and analyte; if several analytes
    match, the last one wins. Without a match these are 'None' and color is TUD_brown.

    voltage, pore_diam, pore_area, coating, membrane_buffer and reservoir_buffer are
    always added. Analyte names longer than 20 characters become 'Complex sample'
    unless fullname is True.

    Raises KeyError if the detection wavelength matches none of the known wavelengths.
    '''
    greens = ('Yoyo1', 'Atto 488', 'eGFP', 'Alexa 488', '6-FAM', 'FITC')
    reds = ('Atto 647N', 'Alexa 647', 'Cy5')

    # detection wavelength (presumably in m) -> (dyes expected there, focus_size in μm)
    col_to_prop = {
        6.9e-7 : (reds, 0.35),
        5.5e-7 : (greens, 0.25),
        7.0e-7 : (reds, 0.35),
        5.25e-7 : (greens, 0.25)
    }

    if is_photonhdf:
        photon_loc = hdf_loc
    else:
        photon_loc = str(get_photonhdf_loc(hdf_loc, absolute=True))

    with h5py.File(photon_loc, "r") as photon_hdf:
        spchnnr = int(chnl[-1])
        thiscol = float(photon_hdf['setup']['detection_wavelengths'][spchnnr][()])
        # tolerant match: a wavelength stored as float32 is not bit-identical to the keys
        matches = [wl for wl in col_to_prop if np.isclose(wl, thiscol, rtol=1e-6, atol=0.0)]
        if not matches:
            raise KeyError(thiscol)
        dyes, focus_size = col_to_prop[matches[0]]

        meas = photon_hdf['user']['meas_spec']

        def text(key):
            # string entries come back as bytes from h5py
            return meas[key][()].decode()

        d = {}
        for i in range(meas['nranalytes'][()]):
            for dye_nr in range(2):
                if text('analyte_%s_dye_%s'%(i, dye_nr)) in dyes:
                    d.update({'color' : text('analyte_%s_pltclr_%s'%(i, dye_nr)),
                              'concentration' : text('analyte_%s_conc'%i),
                              'dye' : text('analyte_%s_dye_%s'%(i, dye_nr)),
                              'labeff' : text('analyte_%s_dye_%s_labeff'%(i, dye_nr)),
                              'time_threshold' : text('analyte_%s_time_threshold'%i),
                              'rgyr' : text('analyte_%s_rgyr'%i),
                              'diff_coeff' : text('analyte_%s_diff_coeff'%i),
                              'diff_coeff_err' : text('analyte_%s_diff_coeff_err'%i),
                              'analyte' : text('analyte_%s'%i),
                              'focus_size' : focus_size
                              })
        if 'color' not in d:
            d.update({'color' : TUD_brown,
                      'concentration' : 'None',
                      'dye' : 'None',
                      'labeff' : 'None',
                      'time_threshold' : 'None',
                      'diff_coeff' : 'None',
                      'diff_coeff_err' : 'None',
                      'analyte' : 'None',
                      'rgyr':'None',
                      'focus_size':focus_size})
        d.update({'voltage' : text('voltage'),
                  'pore_diam': text('pore_diam'),
                  'pore_area': text('pore_area'),
                  'coating': text('coating'),
                  'membrane_buffer': text('mem_buf'),
                  'reservoir_buffer': text('res_buf'),
                  })

    if (len(d['analyte']) > 20) and (not fullname):
        d['analyte'] = 'Complex sample'
    return d

def gen_pltcps(states, phystimes, bounds = 'None'):
    '''
    Build a step curve (x, y) that outlines changepoint states for plotting.

    Each state row is (start photon index, end photon index, level). The curve is 0
    outside the states and equals level between phystimes[start] and phystimes[end].
    The curve starts at the lower bound and ends at the upper bound, and points
    outside [lower, upper] are dropped.

    Parameters
    ----------
    states : iterable of sequences
        Rows whose first three entries are start index, end index and level.
    phystimes : indexable of float
        Photon arrival times, indexed by the state start/end indices.
    bounds : (lower, upper), None or 'None'
        Time window; None / 'None' means (0, phystimes[-1]). Any indexable pair works,
        including numpy arrays.

    Returns
    -------
    numpy.ndarray or str
        Array of shape (2, n) with x in row 0 and y in row 1, or the string 'None'
        if at most two points remain.
    '''
    # compare with 'None' only for strings: '==' on a numpy array is elementwise
    if bounds is None or (isinstance(bounds, str) and bounds == 'None'):
        lower, upper = 0, phystimes[-1]
    else:
        lower, upper = bounds[0], bounds[1]

    pltcps = [[lower, 0]]
    for state in states:
        starttime = phystimes[state[0]]
        endtime = phystimes[state[1]]
        level = state[2]
        pltcps.extend([[starttime, 0], [starttime, level], [endtime, level], [endtime, 0]])
    # upper might not be sorted here; after cutting out the window it is the last entry
    pltcps.append([upper, 0])

    pltcps = np.array(pltcps).T
    pltcps = pltcps[:, (pltcps[0] >= lower) & (pltcps[0] <= upper)]

    if pltcps.shape[1] <= 2:
        return 'None'
    return pltcps


def plot_cp_trace(results_hdf_loc, tt_dict):
    '''
    Plot the binned intensity timetrace of one channel into an axis, optionally with
    changepoint states.

    tt_dict needs to contain:
    * tt_binning: binning time in s
    * chnl: channel name such as 'chnl0' (the last character is the detector number)
    * ax: matplotlib axis to draw into

    optional (missing entries are filled into tt_dict with the defaults used):
    * tt_start & tt_end in seconds (default: first and last photon of the channel)
    * stages: changepoint stages to overlay, read from spikes/exist_<stage> and
      spikes/states_<stage> of the latest run (default ['final']). The last stage is
      drawn dotted black, the others grey.
    * plt_cp: whether to overlay the changepoint states (default True)
    * lw: line width (default 1.5)
    * color: trace colour (default: the sample colour from sample_info)
    * lab: legend label (default '')
    * axylab / secaxylab: labels of the left (kcps) and right (photons per bin) axes

    Returns True.
    '''
    logger('Now plotting cp trace', 'c')
    chnl = tt_dict['chnl']
    ax = tt_dict['ax']
    stages = tt_dict.setdefault('stages', ['final'])
    plt_cp = tt_dict.setdefault('plt_cp', True)
    lw = tt_dict.setdefault('lw', 1.5)
    lab = tt_dict.setdefault('lab', '')
    tt_dict.setdefault('axylab', 'Intensity [kcps]')
    tt_dict.setdefault('secaxylab', 'Nr of photons')

    sample_d = sample_info(results_hdf_loc, chnl)
    color = tt_dict.setdefault('color', sample_d['color'])

    photon_hdf_loc = get_photonhdf_loc(results_hdf_loc, absolute=True)
    # both files are only read; the context managers close them on every exit path
    with h5py.File(results_hdf_loc, "r") as results_hdf, \
            h5py.File(photon_hdf_loc, "r") as photon_hdf:
        rungrp = results_hdf['run_%s'%results_hdf['amnt_runs'][...]]

        photon_data = photon_hdf['photon_data']
        detectors = photon_data['detectors'][...]
        times = photon_data['timestamps'][()]
        times = times[detectors == int(chnl[-1])]
        timeres = photon_data['timestamps_specs']['timestamps_unit'][()]
        logger('loaded times from photonhdf5 for %s'%chnl, 'c')

        # (stage, states) for every requested stage whose states exist
        stage_states = []
        if plt_cp:
            for stage in stages:
                if rungrp[chnl]['spikes']['exist_' + stage][()]:
                    states = rungrp[chnl]['spikes']['states_' + stage][()].astype(int)
                    stage_states.append((stage, states))

    all_phystimes = times * timeres
    tt_start = tt_dict.setdefault('tt_start', all_phystimes[0])
    tt_end = tt_dict.setdefault('tt_end', all_phystimes[-1])
    phystimes = all_phystimes[(all_phystimes > tt_start) & (all_phystimes < tt_end)]

    tt_binning = tt_dict['tt_binning']
    # np.histogram needs at least one bin
    nrbins = max(1, int((phystimes[-1] - phystimes[0]) / tt_binning))
    cts, edges = np.histogram(phystimes, bins=nrbins)
    # plot at the bin centres, np.histogram returns the edges
    ts = (edges[:-1] + edges[1:]) / 2
    maxcts = np.max(cts)

    ax.plot(ts, cts * 10 ** (-3) / tt_binning,
            color = color, linewidth=lw, label=lab)

    for stage, states in stage_states:
        pltcps = gen_pltcps(states, all_phystimes, bounds=(tt_start, tt_end))
        # gen_pltcps signals "nothing to draw" with the string 'None'; '!=' on an
        # array would be elementwise
        if isinstance(pltcps, str):
            continue
        # the state levels are scaled to a fifth of the maximum intensity
        if stage == stages[-1]:
            ax.plot(pltcps[0], pltcps[1] * maxcts / 5 * (10 ** (-3) / tt_binning),
                    color='k', alpha=1, linestyle='dotted', linewidth=lw)
        else:
            ax.plot(pltcps[0], pltcps[1] * maxcts / 5 * (10 ** (-3) / tt_binning),
                    color=TUD_grey, alpha=0.5, linewidth=lw)

    ax.set_ylabel(tt_dict['axylab'])
    ax.set_xlabel('Time [s]')
    if sample_d['analyte'] != 'None':
        if (sample_d['analyte'] == sample_d['dye']) or ('IBB-GFP' == sample_d['analyte']):
            ax.set_title('%s at %s, binning %s s'%(sample_d['analyte'],
                                                   sample_d['concentration'],
                                                   tt_binning))
        else:
            ax.set_title('%s with %s at %s, binning %s s'%(sample_d['analyte'],
                                                           sample_d['dye'],
                                                           sample_d['concentration'],
                                                           tt_binning))
    else:
        ax.set_title('No analyte expected, binning %s s'%tt_binning)
    ax.set_xlim(ts[0], ts[-1])

    # secondary y axis in photons per bin; the two functions must be exact inverses
    def kcps_to_phots(kcps):
        return kcps * tt_binning * 1000

    def phots_to_kcps(phots):
        return phots / tt_binning / 1000

    secaxy = ax.secondary_yaxis('right', functions=(kcps_to_phots, phots_to_kcps))
    secaxy.set_ylabel(tt_dict['secaxylab'])
    ax.tick_params(top=True, which='both', direction='in')
    return True


def gen_fcs(results_hdf_loc, chnls = ['chnl0', 'chnl1'], overwrite=True):
    '''
    Calculate FCS auto-correlations (and the cross-correlation for two channels) and store them.

    Parameters
    ----------
    results_hdf_loc : str or Path
        Results HDF5; the curves go into run_<amnt_runs>/fcs.
    chnls : list of str
        One or two channel names such as 'chnl0'; the last character is the detector
        number. The default list is only read, never modified.
    overwrite : bool
        If True, an existing fcs group of the run is deleted and recalculated. If False
        and the group exists, nothing is recalculated.

    Stored layout:
    * fcs/<chnl>/Gs, and fcs/cross/Gs for two channels: pyc.pcorrelate with
      normalize=True on the lag bins pyc.make_loglags(1, 8, 15)
    * fcs/mean_lags: bin centres times timestamps_unit, only if a curve was computed
    * fcs/exist: names of the computed curves

    A curve is computed only for channels with more than one photon.

    Returns True on success or skip, False if chnls does not have one or two entries
    (the existing fcs group is then left untouched).
    '''
    logger('Now calculating FCS for %s'%chnls, 'c')
    photon_hdf_loc = get_photonhdf_loc(results_hdf_loc, absolute=True)

    # the with-block closes the file on the early returns as well
    with h5py.File(results_hdf_loc, "a") as results_hdf:
        rungrp = results_hdf['run_%s'%results_hdf['amnt_runs'][()]]

        if ('fcs' in rungrp) and not overwrite:
            logger('FCS aleady calculated, skipping', 'c')
            return True
        if len(chnls) not in (1, 2):
            logger('Chnls were wrongly given, please check')
            return False
        if 'fcs' in rungrp:
            del rungrp['fcs']
            logger('Deleted old fcs group', 'c')
        fcsgrp = rungrp.require_group('fcs')

        with h5py.File(photon_hdf_loc, "r") as photon_hdf:
            photon_data = photon_hdf['photon_data']
            detectors = photon_data['detectors'][...]
            timestamps = photon_data['timestamps'][()]
            zs_unit = photon_data['timestamps_specs']['timestamps_unit'][()]

        bins = pyc.make_loglags(1, 8, 15)
        mean_lags = np.mean([bins[1:], bins[:-1]], 0) * zs_unit

        # timestamps of the first ("zero") and, if requested, the second ("one") channel
        zs = timestamps[detectors == int(chnls[0][-1])]
        ones = timestamps[detectors == int(chnls[1][-1])] if len(chnls) == 2 else None

        # calculate the correlations and save them in the HDF
        exist = []
        if len(zs) > 1:
            fcsgrp.require_group(chnls[0])['Gs'] = pyc.pcorrelate(zs, zs, bins, normalize=True)
            exist.append(chnls[0])
        if ones is not None and len(ones) > 1:
            fcsgrp.require_group(chnls[1])['Gs'] = pyc.pcorrelate(ones, ones, bins, normalize=True)
            exist.append(chnls[1])
            if len(zs) > 1:
                fcsgrp.require_group('cross')['Gs'] = pyc.pcorrelate(zs, ones, bins, normalize=True)
                exist.append('cross')

        if exist:
            # meaning sth. was generated
            fcsgrp['mean_lags'] = mean_lags

        # dt needed to store a list of str in hdf5
        dt = h5py.special_dtype(vlen=str)
        fcsgrp['exist'] = np.array(exist, dtype=dt)

    logger('Calculated FCS curves, closed files', 'c')
    return True

def diffusion_3d(t, A0, taud, kappa):
    '''
    Model correlation curve for 3D diffusion through a focus:

        G(t) = 1 + A0 / (1 + t/taud) / sqrt(1 + t/(kappa * taud))

    Works elementwise on numpy arrays t. NOTE(review): kappa enters linearly here;
    the common textbook form uses the squared aspect ratio at this position - confirm
    which is meant.
    '''
    lateral = 1 + (t / taud)
    axial = 1 + (t / (kappa * taud))
    return 1 + A0 * lateral ** (-1) * axial ** (- 1 / 2)

def _fit_diffusion_3d(x, y, a0_guess, fix_kappa=None):
    '''Fit diffusion_3d to (x, y). If fix_kappa is given, kappa is held at that value.

    Returns popt = (A0, taud, kappa) and a 3x3 pcov. For a fixed kappa, the kappa row
    and column of pcov are 0.
    '''
    if fix_kappa is None:
        return curve_fit(diffusion_3d, x, y, p0 = (a0_guess, 0.0001, 5),
                         bounds=((0,0.0000001,1), (10000,100,10)))

    # scipy requires lower < upper for every bound, so a fixed kappa cannot be expressed
    # through equal bounds; fit the two free parameters instead
    def fixed_kappa_model(t, A0, taud):
        return diffusion_3d(t, A0, taud, fix_kappa)

    popt2, pcov2 = curve_fit(fixed_kappa_model, x, y, p0 = (a0_guess, 0.0001),
                             bounds=((0,0.0000001), (10000,100)))
    popt = np.append(popt2, float(fix_kappa))
    pcov = np.zeros((3, 3))
    pcov[:2, :2] = pcov2
    return popt, pcov


def fit_fcs(results_hdf_loc, print_out=False, fit_to = 0.00001, fix_kappa=None):
    '''
    Fit diffusion_3d to the stored FCS curves of the latest run and store the parameters.

    For each of 'chnl0', 'chnl1' and 'cross' listed in fcs/exist, only lags with
    mean_lags > fit_to are fitted. popt = (A0, taud, kappa) and the 3x3 pcov are written
    to fcs/<chnl>/popt and fcs/<chnl>/pcov, replacing earlier values.

    Parameters
    ----------
    results_hdf_loc : str or Path
        Results HDF5 with a run_<amnt_runs>/fcs group (see gen_fcs).
    print_out : bool
        Print channel name, popt and pcov.
    fit_to : float
        Only lags above this value (same units as mean_lags) are fitted.
    fix_kappa : float or None
        If given, kappa is held at this value and only A0 and taud are fitted.

    Returns True if all present curves were fitted, and False as soon as one fit fails
    or has no lags above fit_to.
    '''
    logger('Now fitting FCS curves', 'c')
    with h5py.File(results_hdf_loc, "a") as results_hdf:
        rungrp = results_hdf['run_%s'%results_hdf['amnt_runs'][()]]
        fcsgrp = rungrp['fcs']
        # h5py may return the stored names as bytes or str
        exist = {e.decode() if isinstance(e, bytes) else e for e in fcsgrp['exist'][()]}

        for chnl in ['chnl0', 'chnl1', 'cross']:
            if chnl not in exist:
                continue
            Gs = fcsgrp[chnl]['Gs'][...]
            mean_lags = fcsgrp['mean_lags'][...]

            y = Gs[mean_lags > fit_to]
            x = mean_lags[mean_lags > fit_to]
            if len(x) == 0:
                logger('No lags above %s to fit at %s of %s'%(fit_to, chnl, results_hdf_loc), 'w')
                return False

            # the amplitude start value must lie inside its bounds [0, 10000]
            a0_guess = min(max(y[0] - 1, 0), 10000)
            try:
                popt, pcov = _fit_diffusion_3d(x, y, a0_guess, fix_kappa)
            except ValueError:
                logger('Fitting failed with a ValueError at %s of %s'%(chnl, results_hdf_loc), 'w')
                return False
            except RuntimeError:
                logger('Fitting failed with a runtime error %s of %s'%(chnl, results_hdf_loc), 'w')
                return False

            if print_out:
                print(chnl)
                print(popt)
                print(pcov)
            for name in ('popt', 'pcov'):
                if name in fcsgrp[chnl]:
                    del fcsgrp[chnl][name]
            fcsgrp[chnl]['popt'] = popt
            fcsgrp[chnl]['pcov'] = pcov

    logger('Calculated FCS curves, closed files', 'c')
    return True

def plot_fcs(results_hdf_loc, ax, fcs_dict):
    '''
    Plot the FCS curves of the latest run to ax.

    fcs_dict might contain (missing entries are filled into fcs_dict with the defaults used):
    * normalize: if False 1/G(0) = N, if True G(t[normalize_to]) = 1 (default False)
    * normalize_to: index of the lag to normalize to (default 0). t[0] = 2.71e-7s,
      t[9] = 1.08e-6s, t[24] = 1.08e-5s. If 'fit' is True, normalized by the fitted A0.
    * fit: Bool, whether the fitted result should be plotted (default False)
    * fit_cc: Bool, whether the fitted result for the crosscorrelation should be plotted
    * chnls: list of 'chnl0', 'chnl1' or 'cross' (default all three)
    * colors: list of colors, one per entry of chnls
    * labels: list of legend labels, one per entry of chnls
    * fcs_marker: marker of the data points (default 'x')

    Fitted detector channels get τD and D = focus_size**2 / (4 τD) appended to their
    label. The cross-correlation gets τD only, as it has no single detector channel.

    Returns False if the run has no fcs group, True otherwise.
    '''
    logger('Now plotting fcs', 'c')
    fcs_dict.setdefault('normalize', False)
    fcs_dict.setdefault('fit', False)
    fcs_dict.setdefault('fit_cc', False)
    fcs_dict.setdefault('normalize_to', 0)
    fcs_dict.setdefault('chnls', ['chnl0', 'chnl1', 'cross'])

    # sample_info opens the photon HDF5 on every call; look each channel up only once
    sample_cache = {}

    def sample(chnl):
        if chnl not in sample_cache:
            sample_cache[chnl] = sample_info(results_hdf_loc, chnl)
        return sample_cache[chnl]

    if 'colors' not in fcs_dict:
        fcs_dict['colors'] = ['k' if chnl == 'cross' else sample(chnl)['color']
                              for chnl in fcs_dict['chnls']]
    if 'labels' not in fcs_dict:
        fcs_dict['labels'] = ['Crosscorrelation' if chnl == 'cross'
                              else '%s at %s'%(sample(chnl)['analyte'], sample(chnl)['concentration'])
                              for chnl in fcs_dict['chnls']]
    mark = fcs_dict.setdefault('fcs_marker', 'x')

    with h5py.File(results_hdf_loc, "r") as results_hdf:
        rungrp = results_hdf['run_%s'%results_hdf['amnt_runs'][()]]
        if 'fcs' not in rungrp:
            logger('No FCS group found, please run gen_fcs before', 'e')
            return False
        fcsgrp = rungrp['fcs']
        # h5py may return the stored names as bytes or str
        exist = {e.decode() if isinstance(e, bytes) else e for e in fcsgrp['exist'][()]}

        for i, chnl in enumerate(fcs_dict['chnls']):
            if chnl not in exist:
                continue
            col = fcs_dict['colors'][i]
            lab = fcs_dict['labels'][i]

            G = fcsgrp[chnl]['Gs'][...]
            mean_lags = fcsgrp['mean_lags'][...]

            if fcs_dict['normalize']:
                G -= 1
                normby = G[fcs_dict['normalize_to']]
            else:
                normby = 1

            chnlgrp = fcsgrp[chnl]
            if fcs_dict['fit'] and ('popt' in chnlgrp) and (chnl != 'cross' or fcs_dict['fit_cc']):
                popt = chnlgrp['popt'][()]
                fitys = diffusion_3d(mean_lags, *popt)
                # sample_info needs a detector channel, so 'cross' has no focus size
                focus_size = 'None' if chnl == 'cross' else sample(chnl)['focus_size']
                if focus_size != 'None':
                    lab += ', τD=%d μs, D~%i μm²/s'%(popt[1] * 10**6, focus_size ** 2 / (4 * popt[1]))
                else:
                    lab += ', τD=%d μs'%(popt[1] * 10**6)
                if fcs_dict['normalize']:
                    normby = popt[0]
                    fitys -= 1
                ax.semilogx(mean_lags, fitys / normby, '-', color=col)

            ax.semilogx(mean_lags, G / normby, mark, color=col, label=lab)

    ax.set_xlabel('Timedifference τ [s]')
    if fcs_dict['normalize']:
        ax.set_ylabel('Normalized G(τ)')
        ax.set_ylim(-0.1, 2)
    else:
        ax.set_ylabel('G(τ)')
    ax.tick_params(top=True, right=True, which='both', direction='in')
    ax.legend()
    return True

def sort_fct(x, bckgrd_rate):
    '''
    Empirical threshold curve (photon count vs. duration) separating events from background.

    The curve is the expected number of background photons, x * bckgrd_rate, plus
    three saturating offsets that approach 10, 20 and 50 photons on scales of
    1, 10 and 100 expected background photons. At x = 0 it equals 4 photons.
    It was determined empirically such that the background counts fall below it.

    x: duration(s) in s, scalar or numpy array
    bckgrd_rate: background photon rate in 1/s
    '''
    expected = x * bckgrd_rate
    curve = expected
    # (plateau, amplitude, scale) of the saturating offsets, added in this order
    for plateau, amplitude, scale in ((10, 6, 1), (20, 20, 10), (50, 50, 100)):
        curve = curve + plateau
        curve = curve - amplitude * np.exp(-expected / scale)
    return curve


def poiss_prob(eventphots, eventdurations, bckgrndrate):
    '''
    Poisson probability that each event is not just background.

    For every event the number of background photons expected during its duration
    is mu = bckgrndrate * eventduration. The returned value is the Poisson CDF
    P(N <= eventphots | mu): the probability that a pure background process with
    this rate yields at most the observed number of photons. A value close to 1
    therefore means the event is unlikely to be background.

    Events whose photon count does not exceed the empirical curve
    sort_fct(eventduration, bckgrndrate) are not evaluated and get probability 0.

    eventphots: photon counts per event (array-like)
    eventdurations: event durations in s (array-like, same shape as eventphots)
    bckgrndrate: background photon rate in 1/s
    Returns a float numpy array with the shape of eventphots.
    '''
    # accept lists/scalars as well as numpy arrays (boolean indexing below needs arrays)
    eventphots = np.asarray(eventphots)
    eventdurations = np.asarray(eventdurations)

    bckgrd_phots = bckgrndrate * eventdurations
    good_events = (sort_fct(eventdurations, bckgrndrate) < eventphots)
    probs = np.zeros(np.shape(eventphots))
    # prob to find less or equal the amount of photons from background alone
    probs[good_events] = poisson.cdf(eventphots[good_events], bckgrd_phots[good_events])
    return probs

def generate_event_stats(results_hdf_loc, stage='unique', chnllst = ['chnl0', 'chnl1']):
    '''
    Compute and store per-event statistics for the states of `stage` in each channel.

    For every channel in chnllst whose spikes group has exist_<stage> set, these
    datasets are (re)written to run_<amnt_runs>/<chnl>/spikes:
    * intereventdurations_<stage>: time (s) from the end of an event to the start of the next
    * eventdurations_<stage>: event durations (s)
    * eventphots_<stage>: end minus start photon index of each event
    * eventprobs_<stage>: Poisson probabilities from poiss_prob
    * eventctr_<stage>: number of events
    * eventrate_<stage>: number of events divided by the time between the first
      and the last photon of the channel

    The background rate comes from the background file named in
    user/meas_spec/bgrd_file of the photon HDF5 file (its 3-character extension
    replaced by 'hdf5', looked up in the Photon_HDF5 folder). If that entry is
    'None', the overall photon rate of this channel is used. The rate is stored
    as <chnl>/background_rate.

    chnllst: channel names; the detector number is the last character (e.g. 'chnl0')
    '''
    phot_hdf_fold_str = str(Path(results_hdf_loc).parent.parent / 'Photon_HDF5')
    stat_names = ('intereventdurations', 'eventdurations', 'eventphots',
                  'eventprobs', 'eventctr', 'eventrate')

    # with-blocks close both files (and release the write lock on the results file)
    # even if processing one of the channels raises
    with h5py.File(results_hdf_loc, "a") as results_hdf, \
         h5py.File(get_photonhdf_loc(results_hdf_loc, absolute=True), "r") as photon_hdf:
        rungrp = results_hdf['run_%s'%results_hdf['amnt_runs'][()]]
        detectors = photon_hdf['photon_data']['detectors'][...]
        # timestamps are shared by all channels: read them once, on first use
        all_times = None

        for chnl in chnllst:
            logger('generating event stats for %s'%chnl)
            spkgrp = rungrp[chnl]['spikes']
            if not spkgrp['exist_' + stage][()]:
                logger('No events in %s so no stats generated'%stage, 'w')
                continue

            if all_times is None:
                all_times = photon_hdf['photon_data']['timestamps'][()]
                timeres = photon_hdf['photon_data']['timestamps_specs']['timestamps_unit'][()]
            phystimes = all_times[detectors == int(chnl[-1])] * timeres
            states = spkgrp['states_' + stage][()].astype(int)

            starttimes = phystimes[states[:, 0]]
            endtimes = phystimes[states[:, 1]]
            intereventtimes = starttimes[1:] - endtimes[:-1]
            eventdurations = endtimes - starttimes
            eventphots = states[:, 1] - states[:, 0]

            bgrd_name = photon_hdf['user/meas_spec/bgrd_file'][()].decode()
            if bgrd_name != 'None':
                bgrd_file = phot_hdf_fold_str + '/' + bgrd_name[:-3] + 'hdf5'
                logger('Background rate of file %s used.'%bgrd_file, 'c')
                with h5py.File(bgrd_file, "r") as bgrd_hdf:
                    bgdets = bgrd_hdf['photon_data']['detectors'][...]
                    bgtimes = bgrd_hdf['photon_data']['timestamps'][()]
                    bgtimeres = bgrd_hdf['photon_data']['timestamps_specs']['timestamps_unit'][()]
                bgphystimes = bgtimes[bgdets == int(chnl[-1])] * bgtimeres
                bgrd_rate = len(bgphystimes) / bgphystimes[-1]
            else:
                logger('No background file given, so the photon rate of this file used', 'c')
                bgrd_rate = len(phystimes) / phystimes[-1]

            probs = poiss_prob(eventphots, eventdurations, bgrd_rate)

            # keep the stored background rate in sync with the one used for probs;
            # h5py cannot overwrite an existing dataset, so delete it first
            if 'background_rate' in rungrp[chnl]:
                del rungrp[chnl]['background_rate']
            rungrp[chnl]['background_rate'] = bgrd_rate

            nr_events = np.shape(states)[0]
            new_stats = (intereventtimes, eventdurations, eventphots, probs, nr_events,
                         nr_events / (phystimes[-1] - phystimes[0]))
            for name, value in zip(stat_names, new_stats):
                key = '%s_%s'%(name, stage)
                # clear old entries of a rerun one by one, so a partially written
                # group does not raise a KeyError
                if key in spkgrp:
                    del spkgrp[key]
                spkgrp[key] = value

def sort_hsnr_states(results_hdf_loc, chnllst=['chnl0','chnl1'], alpha = 0.99):
    '''
    Select the high signal-to-noise (hsnr) events from the unique events.

    An event of the 'unique' stage is kept when its Poisson probability
    (eventprobs_unique, written by generate_event_stats) exceeds alpha. For each
    channel the selection is stored as states_hsnr and exist_hsnr is set
    accordingly in run_<amnt_runs>/<chnl>/spikes.

    Rerunning replaces earlier results. The *_hsnr statistics of an earlier run are
    deleted, so generate_event_stats has to be rerun with stage='hsnr'. If no event
    passes, no states_hsnr is stored and exist_hsnr is False.

    alpha: probability threshold; events with eventprob > alpha are kept
    '''
    # h5py cannot overwrite datasets, so everything of an earlier run is removed first
    stale_keys = ['exist_hsnr', 'states_hsnr'] + [
        '%s_hsnr'%name for name in ('intereventdurations', 'eventdurations', 'eventphots',
                                    'eventprobs', 'eventctr', 'eventrate')]

    with h5py.File(results_hdf_loc, "a") as results_hdf:
        rungrp = results_hdf['run_%s'%results_hdf['amnt_runs'][()]]

        for chnl in chnllst:
            logger('sorting for HSNR events now for %s'%chnl, 'c')
            spkgrp = rungrp[chnl]['spikes']

            states_hsnr = None
            if spkgrp['exist_unique'][()]:
                states_unique = spkgrp['states_unique'][()].astype(int)
                probs = spkgrp['eventprobs_unique'][()]
                states_hsnr = states_unique[probs > alpha]

            for key in stale_keys:
                if key in spkgrp:
                    del spkgrp[key]

            if states_hsnr is None or states_hsnr.shape[0] == 0:
                spkgrp['exist_hsnr'] = False
            else:
                spkgrp['states_hsnr'] = states_hsnr
                spkgrp['exist_hsnr'] = True


def plot_interevents(results_hdf_loc,  plt_int_dict):
    '''
    Plot the times between consecutive events of one channel.

    plt_int_dict should contain:
    * chnl
    * ax: 2x2 nested axes ((ax0, ax1), (ax2, ax3)), e.g. from plt.subplots(2, 2).
      ax0/ax2 get a horizontally jittered scatter plot (ax2 with log y-axis),
      ax1/ax3 a horizontal histogram (ax3 with logarithmically spaced bins)
    it can contain
    * threshold ('protein' to use the time_threshold from sample_info, or a time in s);
      interevent times below it are drawn in pink and a line is drawn at it
    * stage (default 'hsnr')
    * ylim (only interevent times below this value in s are plotted)

    Missing optional keys are filled into plt_int_dict with their defaults.
    Returns True if something was plotted, otherwise False.
    '''
    ((ax0, ax1), (ax2, ax3)) = plt_int_dict['ax']
    chnl = plt_int_dict['chnl']
    if 'stage' not in plt_int_dict:
        plt_int_dict['stage'] = 'hsnr'
    stage = plt_int_dict['stage']

    logger('plotting interevents for %s in %s'%(chnl, stage), 'c')

    # the results file is only read: open it read-only and close it before plotting,
    # so it is closed on every return path and also if plotting raises
    with h5py.File(results_hdf_loc, "r") as results_hdf:
        spkgrp = results_hdf['run_%s'%results_hdf['amnt_runs'][()]][chnl]['spikes']
        has_states = spkgrp['exist_' + stage][()]
        intereventtimes = spkgrp['intereventdurations_' + stage][()] if has_states else None

    if not has_states:
        logger('no states for %s found so interevents skipped.'%stage, 'w')
        return False

    if len(intereventtimes) < 1:
        logger('There are no intereventtimes, skipping to plot it for %s in %s'%(chnl, stage), 'c')
        return False

    if 'ylim' in plt_int_dict:
        intereventtimes = intereventtimes[intereventtimes<plt_int_dict['ylim']]

    if len(intereventtimes) < 1:
        logger('There are no intereventtimes left after applying the ylim,' +
               ' skipping to plot it for %s in %s'%(chnl, stage), 'c')
        return False

    sample_d = sample_info(results_hdf_loc, chnl)

    if 'threshold' in plt_int_dict:
        if plt_int_dict['threshold'] == 'protein':
            plt_int_dict['threshold'] = sample_d['time_threshold']
    else:
        plt_int_dict['threshold'] = 'None'
    thrshold = plt_int_dict['threshold']

    # True for interevent times that are not below the threshold (all without threshold)
    not_below = np.ones(intereventtimes.shape, dtype=bool)
    if thrshold != 'None':
        thrshold = float(thrshold)
        not_below = ~(intereventtimes < thrshold)

    for ax in [ax0, ax2]:
        if thrshold != 'None':
            below = intereventtimes[~not_below]
            ax.scatter(np.random.rand(len(below)), below,
                       facecolors='none', alpha=0.7, edgecolors=TUD_pink)
            ax.hlines(thrshold, -0.1, 1.1, color='k')
        rest = intereventtimes[not_below]
        ax.scatter(np.random.rand(len(rest)), rest,
                   facecolors='none', alpha=0.7, edgecolors=sample_d['color'])

        ax.set_xlim(-0.1,1.1)
        ax.set_ylim(np.min(intereventtimes), np.max(intereventtimes))
        ax.set(xticklabels=[])
        ax.set_xticks([], [])
        ax.set_ylabel('Interevent duration [s]')
        ax.tick_params(right=True, which='both', direction='in')

    ax1.hist(intereventtimes, bins=30, orientation="horizontal", color=sample_d['color'])
    ax3.hist(intereventtimes, bins=np.logspace(np.log10(np.min(intereventtimes)),np.log10(np.max(intereventtimes)), 30),
             orientation="horizontal", log=False, color=sample_d['color'])
    if thrshold != 'None':
        ax0.set_title('Threshold at %2.1e s'%(thrshold))
    ax3.set_xlabel('Amount')
    ax1.tick_params(top=True, which='both', direction='in')
    ax3.tick_params(top=True, which='both', direction='in')
    ax2.set_yscale('log')
    return True

def _merge_close_states(states, phystimes, threshold):
    '''
    Merge consecutive states of the same kind whose gap is below threshold.

    states: integer array with one row per state, columns (start index, end index, kind),
            the indices pointing into phystimes
    phystimes: photon arrival times in s
    Returns a new integer array of merged states; an empty input gives an empty array
    with the same number of columns.
    '''
    merged = []
    for state in states:
        if merged:
            current = merged[-1]
            gap = phystimes[state[0]] - phystimes[current[1]]
            if (gap < threshold) and (current[2] == state[2]):
                # reentry: extend the current event up to the end of this one
                current[1] = state[1]
                continue
        merged.append(state.copy())
    if not merged:
        return np.empty((0,) + np.shape(states)[1:], dtype=int)
    return np.array(merged, dtype=int)


def combine_reentries(results_hdf_loc, chnllst=['chnl0', 'chnl1'], stage='hsnr', threshold = 'None'):
    '''
    Combine events of `stage` whose gap between end and next start is below a threshold.

    Only consecutive events of the same kind (third column of the states) are merged.
    The result is stored as states_combined (with the used threshold as string in its
    'Threshold' attribute) together with exist_combined in run_<amnt_runs>/<chnl>/spikes.
    Results and *_combined statistics of an earlier run are deleted, so
    generate_event_stats has to be rerun with stage='combined'.

    threshold: maximal gap in s (number or numeric string). If 'None', the
               'time_threshold' from sample_info is used for each channel
               separately; if that is 'None' as well, 1 ms is used.
    '''
    # h5py cannot overwrite datasets, so results of an earlier run are removed first
    stale_keys = ['exist_combined', 'states_combined'] + [
        '%s_combined'%name for name in ('intereventdurations', 'eventdurations', 'eventphots',
                                        'eventprobs', 'eventctr', 'eventrate')]

    with h5py.File(results_hdf_loc, "a") as results_hdf, \
         h5py.File(get_photonhdf_loc(results_hdf_loc, absolute=True), "r") as photon_hdf:
        rungrp = results_hdf['run_%s'%results_hdf['amnt_runs'][()]]
        detectors = photon_hdf['photon_data']['detectors'][...]
        # timestamps are shared by all channels: read them once, on first use
        all_times = None

        for chnl in chnllst:
            logger('combining reentries for %s'%chnl)
            spkgrp = rungrp[chnl]['spikes']
            states_combined = None

            if spkgrp['exist_' + stage][()]:
                # resolve the threshold per channel, without overwriting the argument,
                # so that every channel uses its own sample threshold
                chnl_threshold = threshold
                if chnl_threshold == 'None':
                    chnl_threshold = sample_info(results_hdf_loc,chnl)['time_threshold']
                    if chnl_threshold == 'None':
                        chnl_threshold = 0.001
                        logger('No threshold for combine reentries given, so just 1ms used', 'w')
                chnl_threshold = float(chnl_threshold)

                if all_times is None:
                    all_times = photon_hdf['photon_data']['timestamps'][()]
                    timeres = photon_hdf['photon_data']['timestamps_specs']['timestamps_unit'][()]
                phystimes = all_times[detectors == int(chnl[-1])] * timeres

                states = spkgrp['states_' + stage][()].astype(int)
                states_combined = _merge_close_states(states, phystimes, chnl_threshold)
            else:
                logger('No events in %s, so no combined events generated.'%stage)

            for key in stale_keys:
                if key in spkgrp:
                    del spkgrp[key]

            if states_combined is not None and states_combined.shape[0] != 0:
                spkgrp['exist_combined'] = True
                spkgrp['states_combined'] = states_combined
                spkgrp['states_combined'].attrs['Threshold'] = str(chnl_threshold)
            else:
                spkgrp['exist_combined'] = False

def plot_event_stats(results_hdf_loc, psd_dict):
    '''
    Plot a property of the events of state `state` in scatter plots and histograms.

    In the psd_dict there need to be:
    * fig
    * ax: either (ax0, ax1), a 2x2 grid ((ax0, ax1), (ax2, ax3)) as returned by
      plt.subplots, or a flat (ax0, ax1, ax2, ax3).
      ax0/ax1 get scatter plots (ax1 with log y-axis, and log x-axis unless what='physical'),
      ax2/ax3 (if given) histograms of all values and with outliers (>= median + std) removed
    * chnl
    * what = 'physical' --> plots the physical length of the state in s,
             'photon' --> plots the amount of photons of the particular state
             'photon_rate' --> plots the photon rate (1/ms) of the particular state
             'probability' --> plots the probability of a state to be not background from poisson distribution
      For everything but 'physical' the x-axis of the scatter plots is the event duration in ms.

    optional are:
    * stages: list of stages of the state processing to plot, e.g. ['unique', 'combined', 'hsnr'];
      all but the last one are drawn in grey (default ['combined'])
    * state: kind of state (third column of the states) to plot (default 2)
    * threshold: bool, whether the sort_fct curve used for the hsnr sorting is drawn
      in the photon plot (default False)
    * color: color of the last stage (default: color from sample_info)
    * setalpha: transparency of the scatter points (default depends on the number of events)

    Missing optional keys are filled into psd_dict with their defaults.
    Returns True if exist_<stage> is set for at least one of the stages.
    Raises ValueError for an unknown `what` or an unsupported number of axes.
    '''
    plotted_sth = False

    fig = psd_dict['fig']
    chnl = psd_dict['chnl']
    what = psd_dict['what']
    # y-label of the scatter plots, also x-label of the non-physical histograms
    value_labels = {'physical': 'Event duration (s)',
                    'photon': 'Number of photons',
                    'photon_rate': 'Photon rate (1/ms)',
                    'probability': 'P(<N photons)'}
    if what not in value_labels:
        raise ValueError("psd_dict['what'] has to be one of %s, not %r"%(sorted(value_labels), what))

    # accept numpy arrays from plt.subplots as well as (nested) tuples/lists of axes
    axs = np.asarray(psd_dict['ax'], dtype=object)
    if axs.shape == (2, 2):
        ((ax0, ax1), (ax2, ax3)) = axs
    elif axs.shape == (4,):
        (ax0, ax1, ax2, ax3) = axs
    elif axs.shape == (2,):
        (ax0, ax1) = axs
    else:
        raise ValueError("psd_dict['ax'] needs 2 or 2x2 axes, got shape %s"%(axs.shape,))
    with_hists = axs.size == 4

    if 'stages' not in psd_dict:
        psd_dict['stages'] = ['combined']
    stages = psd_dict['stages']
    if 'state' not in psd_dict:
        psd_dict['state'] = 2
    state = psd_dict['state']
    if 'threshold' not in psd_dict:
        psd_dict['threshold'] = False
    thrshold = psd_dict['threshold']

    logger('Now plotting event stats for %s'%stages)

    if state != 2:
        logger('watch out, the nr of events could be wrong here as the event counter' +
               ' in generate_state_stats counted all transitions from state 1 to any state!', 'w')

    with h5py.File(results_hdf_loc, "r") as results_hdf:
        rungrp = results_hdf['run_%s'%results_hdf['amnt_runs'][()]]
        spkgrp = rungrp[chnl]['spikes']

        sample_d = sample_info(results_hdf_loc, chnl)

        for stage in stages:
            if stage == stages[-1]:
                if 'color' not in psd_dict:
                    psd_dict['color'] = sample_d['color']
                clr = psd_dict['color']
            else:
                clr = TUD_grey

            if not spkgrp['exist_' + stage][()]:
                continue
            plotted_sth = True
            states = spkgrp['states_' + stage][()].astype(int)
            if states.shape[0] <= 1:
                continue

            if 'setalpha' not in psd_dict:
                if states.shape[0] > 300:
                    psd_dict['setalpha'] = 210./states.shape[0]
                else:
                    psd_dict['setalpha'] = 0.7
            setalpha = psd_dict['setalpha']

            # select the events of the requested state kind
            selected = states[:, 2] == state
            durations = spkgrp['eventdurations_' + stage][()][selected]
            if what == 'physical':
                plt_lengths = durations
                scatter_xs = np.random.rand(len(plt_lengths))
            else:
                scatter_xs = durations * 1000
                if what == 'photon':
                    plt_lengths = spkgrp['eventphots_' + stage][()][selected]
                elif what == 'photon_rate':
                    plt_lengths = spkgrp['eventphots_' + stage][()][selected] / scatter_xs
                else:  # 'probability'
                    plt_lengths = spkgrp['eventprobs_' + stage][()][selected]
            if len(plt_lengths) == 0:
                # no events of this kind: the axis limits below cannot be computed
                continue

            # plot the scatter plots
            ax0.scatter(scatter_xs, plt_lengths, facecolors='none', alpha=setalpha, edgecolors=clr)
            ax1.scatter(scatter_xs, plt_lengths, facecolors='none', alpha=setalpha, edgecolors=clr)
            ax1.set_yscale('log')
            if what != 'physical':
                ax1.set_xscale('log')

            # plot the histograms
            if with_hists:
                outlier_thrshld = np.median(plt_lengths) + np.std(plt_lengths)
                inliers = plt_lengths[plt_lengths < outlier_thrshld]
                nbins = 20 if what == 'photon' else 50
                ax2.hist(plt_lengths, bins=nbins, color = clr)
                # durations are shown in ms in the outlier-free histogram
                ax3.hist(inliers * 1000 if what == 'physical' else inliers, bins=nbins, color = clr)

            # plot thrsholds for sorting
            if what == 'photon' and thrshold:
                x = np.logspace(-10, 10, num=10000)
                bgrd_rate = rungrp[chnl]['background_rate'][()]
                for ax in [ax0, ax1]:
                    ax.plot(x * 1000., sort_fct(x, bgrd_rate), 'k-')

            # fix the axes labelling and margins
            for ax in [ax0, ax1]:
                ax.set_ylim(np.min(plt_lengths), np.max(plt_lengths))
                ax.set_xlim(np.min(scatter_xs), np.max(scatter_xs))
                if what == 'physical':
                    ax.set_xlim(-0.1,1.1)
                    ax.set(xticklabels=[])
                    ax.set_xticks([], [])
                    ax.tick_params(right=True, which='both', direction='in')
                else:
                    ax.set_xlabel('Event duration (ms)')
                    ax.tick_params(top=True, right=True, which='both', direction='in')
                ax.set_ylabel(value_labels[what])
            if with_hists:
                if what == 'physical':
                    ax2.set_xlabel('Duration (s)')
                    ax3.set_xlabel('Duration (ms)')
                else:
                    ax2.set_xlabel(value_labels[what])
                    ax3.set_xlabel(value_labels[what])

                ax2.set_ylabel('# Events')
                ax2.tick_params(top=True, right=True, which='both', direction='in')
                ax3.tick_params(top=True, right=True, which='both', direction='in')
                ax3.set_title('Outliers removed')
                ax3.set_ylabel('# Events')
            if sample_d['analyte'] != 'None':
                fig.suptitle('%i %s events at %1.1f events/s'%(len(plt_lengths), sample_d['analyte'],
                                                                spkgrp['eventrate_' + stage][()]))
            else:
                fig.suptitle('%i background events at %1.1f events/s'%(len(plt_lengths),
                                                                       spkgrp['eventrate_' + stage][()]))
    return plotted_sth
    


def savefig_with_params(results_hdf_loc, fig, figdict, kind, changes='', mod='std', saveassvg='None'):
    '''
    Save a figure as png in the Plotting folder and store its parameters in the results HDF5.

    kind is one of the following: cptt, correl, int_ev, evstat, evov, special
    The png goes to <folder of the Results folder>/Plotting. Its name is built from the
    results file name without its last 13 characters ('_results.hdf5'), kind, mod,
    the run and a timestamp pid (yymmddHHMMSSffffff), plus kind-specific parts.
    All entries of figdict except 'ax' and 'fig', together with 'changes' and
    'kind', are stored in run_<amnt_runs>/Plotting/<pid> of the results file.

    figdict: parameters of the plot; required keys depend on kind
             (cptt/evstat: 'stages'; evstat/evov/int_ev: 'chnl'; evstat: 'what';
             correl: 'chnls', 'colors', 'labels', 'fcs_marker')
    changes: free text describing changes, stored with the parameters
    mod: tag that is put into the file name
    saveassvg: if not 'None', the figure is additionally saved to this path
    '''
    plotfold = Path(results_hdf_loc).parent.parent / 'Plotting'
    plotfold.mkdir(exist_ok=True)
    plotfold = str(plotfold)

    handle = Path(results_hdf_loc).name[:-13]
    # the with-block closes (and unlocks) the results file even if saving fails
    with h5py.File(results_hdf_loc, "a") as results_hdf:
        runstr = 'run_%s'%results_hdf['amnt_runs'][()]

        pid = datetime.datetime.now().strftime('%y%m%d%H%M%S%f')
        logger('plotting plot: %s'%pid)

        if kind == 'cptt':
            stagestr = ''.join(figdict['stages'])
            figname = '%s/%s_%s_%s_%s_%s_%s.png'%(plotfold, handle, kind, mod, stagestr, runstr, pid)
        elif kind == 'correl':
            figname = '%s/%s_%s_%s_%s_%s.png'%(plotfold, handle, kind, mod, runstr, pid)
        elif kind == 'evstat':
            stagestr = ''.join(figdict['stages'])
            figname = '%s/%s_%s_%s_%s_%s_%s_%s.png'%(plotfold, handle, figdict['chnl'], kind, figdict['what'], stagestr, runstr, pid)
        elif (kind == 'evov') or (kind == 'int_ev'):
            figname = '%s/%s_%s_%s_%s_%s_%s.png'%(plotfold, handle, figdict['chnl'], kind,  mod, runstr, pid)
        else:
            figname = '%s/%s_%s_%s_%s_%s.png'%(plotfold, handle, kind, mod, runstr, pid)
        fig.savefig(figname, dpi=fig.dpi)
        if saveassvg != 'None':
            fig.savefig(saveassvg)

        # now save in resultshdf5
        pidgrp = results_hdf.require_group(runstr + '/Plotting/' + pid)
        # dt needed to store a list of str in hdf5
        dt = h5py.special_dtype(vlen=str)
        this_d = figdict.copy()
        this_d['changes'] = np.array(changes, dtype=dt)
        this_d['kind'] = kind
        if kind == 'cptt':
            this_d['chnl'] = np.array(['chnl0', 'chnl1'], dtype=dt)
            this_d['stages'] = np.array(this_d['stages'], dtype=dt)
        elif kind == 'correl':
            for key in ('chnls', 'colors', 'labels', 'fcs_marker'):
                this_d[key] = np.array(this_d[key], dtype=dt)
        elif kind == 'evstat':
            this_d['stages'] = np.array(this_d['stages'], dtype=dt)
        for key, value in this_d.items():
            # axes and figure objects cannot be stored in hdf5
            if key not in ('ax', 'fig'):
                pidgrp[key] = value
        if socket.gethostname() == 'TUD207901':
            plt.close(fig)
    logger('saved %s fig'%kind, 'c')


def plot_event_overlay(results_hdf_loc, ovdict):
    '''
    Overlay binned intensity time traces of events, aligned at their start.

    ovdict needs:
    * chnl
    * ax
    optional (missing keys are filled into ovdict with their defaults):
    * amnt: number of randomly chosen events to overlay (default: all events)
    * t_before: time in s shown before the event start (default 0.001)
    * t_after: time in s shown after the end of the longest plotted event (default 0.001)
    * stage (default 'combined')
    * binning: bin width in s (default 0.0001)
    * setalpha: transparency of the individual traces (default depends on the number of events)
    * thisevent: index of a single event to plot, or 'None' to overlay events (default 'None')
    * xmax (the upper xlim in ms)
    * ymax (the upper ylim in kcps)
    * setcolor: color of the individual traces (default: color from sample_info)
    * avcol: color of the average trace (default 'k')

    When more than one event is plotted, their average is drawn on top. Events whose
    window would start before 0 or end after the last photon are skipped, and
    ovdict['amnt'] is reduced accordingly.
    Returns True if something was plotted, otherwise False.
    '''
    chnl = ovdict['chnl']
    ax = ovdict['ax']

    if 't_before' not in ovdict:
        ovdict['t_before'] = 0.001
    t_before = ovdict['t_before']

    if 't_after' not in ovdict:
        ovdict['t_after'] = 0.001
    t_after = ovdict['t_after']

    if 'xmax' not in ovdict:
        ovdict['xmax'] = 'None'
    xmax = ovdict['xmax']

    if 'ymax' not in ovdict:
        ovdict['ymax'] = 'None'
    ymax = ovdict['ymax']

    if 'stage' not in ovdict:
        ovdict['stage'] = 'combined'
    stage = ovdict['stage']

    if 'binning' not in ovdict:
        ovdict['binning'] = 0.0001
    tt_binning = ovdict['binning']

    if 'thisevent' not in ovdict:
        ovdict['thisevent'] = 'None'
    thisevent = ovdict['thisevent']

    logger('plotting event overlay of %s'%chnl)

    # read all needed data and close the files again before plotting, so they are
    # closed on every return path and also if plotting raises
    with h5py.File(results_hdf_loc, "r") as results_hdf:
        spkgrp = results_hdf['run_%s'%results_hdf['amnt_runs'][()]][chnl]['spikes']
        has_states = spkgrp['exist_' + stage][()]
        states = spkgrp['states_' + stage][()].astype(int) if has_states else None
    if not has_states:
        return False

    with h5py.File(get_photonhdf_loc(results_hdf_loc, absolute=True), "r") as photon_hdf:
        detectors = photon_hdf['photon_data']['detectors'][...]
        times = photon_hdf['photon_data']['timestamps'][()]
        timeres = photon_hdf['photon_data']['timestamps_specs']['timestamps_unit'][()]
    phystimes = times[detectors == int(chnl[-1])] * timeres

    if thisevent != 'None':
        startr = thisevent
        ovdict['amnt'] = thisevent + 1
        plt_states = states
    else:
        startr = 0
        if 'amnt' not in ovdict:
            ovdict['amnt'] = states.shape[0]
        if ovdict['amnt'] < states.shape[0]:
            plt_states = states[np.random.choice(states.shape[0], ovdict['amnt'], replace=False), :]
        else:
            plt_states = states
            ovdict['amnt'] = states.shape[0]
    amnt = ovdict['amnt']

    if 'setalpha' not in ovdict:
        if (amnt - startr) > 10:
            ovdict['setalpha'] = 10./amnt
        elif (amnt - startr) == 1:
            ovdict['setalpha'] = 1
        else:
            ovdict['setalpha']  = 0.5
    setalpha = ovdict['setalpha']

    sample_d = sample_info(results_hdf_loc, chnl)

    if 'setcolor' not in ovdict:
        ovdict['setcolor'] = sample_d['color']
    setcolor = ovdict['setcolor']

    if 'avcol' not in ovdict:
        ovdict['avcol'] = 'k'
    avcol = ovdict['avcol']

    candidates = plt_states[startr:amnt]
    if len(candidates) == 0:
        # e.g. amnt of 0 or thisevent out of range: np.max below would fail
        logger('No events to plot for %s in %s, so no overlay generated.'%(chnl, stage), 'w')
        return False

    # all windows get the same length: the longest plotted event plus t_after
    maxdur = np.max(phystimes[candidates[:, 1]] - phystimes[candidates[:, 0]]) + t_after

    traces = []
    last_window = None
    for event in candidates:
        pltstart = phystimes[event[0]] - t_before
        pltend = phystimes[event[0]] + maxdur
        if (pltstart < 0) or (pltend > phystimes[-1]):
            # window exceeds the measurement: skip the event and do not count it
            ovdict['amnt'] -= 1
            amnt -= 1
            continue
        last_window = (pltstart, pltend)

        plt_phystimes = phystimes[(phystimes > pltstart) & (phystimes < pltend)] - phystimes[event[0]]

        nrbins = int((pltend - pltstart) / tt_binning)
        counts, edges = np.histogram(plt_phystimes, bins=nrbins)
        # make it plottable: np.histogram gives the edges, plot at the bin centers
        binned_tt = np.array(((edges[:-1] + edges[1:]) / 2, counts))

        if not traces:
            ax.plot(binned_tt[0]* 1000, binned_tt[1] * 10 ** (-3) / tt_binning,
                    color = setcolor, alpha=setalpha, label='Individual timetrace')
        else:
            ax.plot(binned_tt[0]* 1000, binned_tt[1] * 10 ** (-3) / tt_binning,
                    color = setcolor, alpha=setalpha)
        traces.append(plt_phystimes)

    if (thisevent == 'None') and (amnt - startr > 1):
        # concatenate once instead of growing the array per event
        alltimes = np.concatenate(traces)
        nrbins = int((np.max(alltimes) - np.min(alltimes)) / tt_binning)
        counts, edges = np.histogram(alltimes, bins=nrbins)
        # average over the plotted events
        binned_tt = np.array(((edges[:-1] + edges[1:]) / 2, counts / float(amnt)))

        ax.plot(binned_tt[0] * 1000, binned_tt[1] * 10 ** (-3) / tt_binning,
                color = avcol,label='Average of %i events'%amnt, linewidth=2)
        ax.set_title('%s events at %ss binning'%(sample_d['analyte'], tt_binning))
        ax.legend(loc='upper right')
        ax.tick_params(top=True, right=True, which='both', direction='in')
    elif (amnt - startr) < 1:
        return False
    else:
        # exactly one event was plotted: shade its duration
        pltstart, pltend = last_window
        ax.axvspan(0, (pltend - t_after - t_before - pltstart) * 1000, alpha=0.5, color=TUD_grey)
        ax.set_title('One %s event at %ss binning'%(sample_d['analyte'], tt_binning))

    ax.set_xlim(binned_tt[0,0] * 1000, binned_tt[0,-1] * 1000)
    ax.set_ylabel('Intensity [kcps]')
    ax.set_xlabel('Δt [ms]')
    if xmax != 'None':
        ax.set_xlim(right=xmax)
    if ymax != 'None':
        ax.set_ylim(top=ymax)

    return True


def get_eventrate_etc(results_hdf_loc , resdict, stage='hsnr', chnllst=['chnl0', 'chnl1']):
    '''
    Append the results of one measurement to the lists in resdict.

    For each channel in chnllst the photon count, the sample information and the
    event statistics of `stage` are appended. The channel-independent measurement
    info (meas_dur, voltage, pore_*, coating, buffers) is taken from the first channel.
    FCS fit results (A0, taud, kappa and their errors) are appended for each channel
    and for 'cross'.
    Event statistics are None when the stage (or a stage it requires) does not exist
    for a channel. FCS values are 0 when no fit ('popt') is stored.

    resdict: dict of lists containing all keys that are appended to (see gen_respd)
    Returns resdict (it is also modified in place).
    '''
    event_keys = ('event_rate', 'event_ctr', 'eventphotons', 'eventphotons_err',
                  'eventphotrate', 'eventphotrate_err')

    with h5py.File(results_hdf_loc, "r") as results_hdf:
        rungrp = results_hdf['run_%s'%results_hdf['amnt_runs'][()]]
        resdict['fname'].append(os.path.basename(results_hdf_loc))
        resdict['results_hdf_loc'].append(results_hdf_loc)

        # the photon data is the same for all channels, so read it only once
        with h5py.File(get_photonhdf_loc(results_hdf_loc, absolute=True), "r") as photon_hdf:
            detectors = photon_hdf['photon_data']['detectors'][...]
            all_times = photon_hdf['photon_data']['timestamps'][()]
            timeres = photon_hdf['photon_data']['timestamps_specs']['timestamps_unit'][()]

        first = True
        for chnl in chnllst:
            # get measurement info
            phystimes = all_times[detectors == int(chnl[-1])] * timeres
            resdict['%s photons'%chnl].append(len(phystimes))
            if first:
                resdict['meas_dur'].append(phystimes[-1])

            sampled = sample_info(results_hdf_loc, chnl)
            resdict['%s analyte'%chnl].append(sampled['analyte'])
            if sampled['labeff'] != 'None':
                resdict['%s labeff'%chnl].append(float(sampled['labeff']))
            else:
                resdict['%s labeff'%chnl].append(1.0)
            resdict['%s color'%chnl].append(sampled['color'])
            resdict['%s concentration'%chnl].append(sampled['concentration'])
            resdict['%s diff_coeff'%chnl].append(sampled['diff_coeff'])
            resdict['%s diff_coeff_err'%chnl].append(sampled['diff_coeff_err'])
            resdict['%s rgyr'%chnl].append(sampled['rgyr'])
            if first:
                if sampled['voltage'] != 'None':
                    resdict['voltage'].append(float(sampled['voltage']))
                else:
                    resdict['voltage'].append(0)
                resdict['pore_area'].append(sampled['pore_area'])
                resdict['pore_diam'].append(sampled['pore_diam'])
                if 'x' in sampled['pore_diam']:
                    # elliptic pore given as '<short>x<long>': plot the mean diameter
                    shortax, longax = sampled['pore_diam'].split('x')
                    resdict['pore_diam_plot'].append((float(shortax) + float(longax)) / 2)
                elif sampled['pore_diam'] == 'None':
                    resdict['pore_diam_plot'].append(np.nan)
                else:
                    resdict['pore_diam_plot'].append(float(sampled['pore_diam']))

                resdict['coating'].append(sampled['coating'])
                resdict['membrane_buffer'].append(sampled['membrane_buffer'])
                resdict['reservoir_buffer'].append(sampled['reservoir_buffer'])
                first = False

            # get the event statistics; missing exist_* flags count as False
            spkgrp = rungrp[chnl]['spikes']
            exists = {name: ('exist_%s'%name in spkgrp) and bool(spkgrp['exist_%s'%name][()])
                      for name in ('unique', 'hsnr', stage)}
            if stage == 'unique':
                get_the_data = exists['unique']
            elif stage in ('hsnr', 'combined'):
                # hsnr events are selected from the unique ones; the stage itself
                # must exist too, otherwise its statistics are not stored
                get_the_data = exists['unique'] and exists['hsnr'] and exists[stage]
            else:
                get_the_data = False

            if get_the_data:
                eventphots = spkgrp['eventphots_%s'%stage][()]
                photrates = eventphots / spkgrp['eventdurations_%s'%stage][()]
                values = (spkgrp['eventrate_%s'%stage][()],
                          len(spkgrp['states_%s'%stage][()]),
                          np.mean(eventphots), np.std(eventphots) / np.sqrt(len(eventphots)),
                          np.mean(photrates), np.std(photrates) / np.sqrt(len(photrates)))
            else:
                values = (None,) * len(event_keys)
            for key, value in zip(event_keys, values):
                resdict['%s %s'%(chnl, key)].append(value)

        # get FCS results
        fcsgrp = rungrp['fcs'] if 'fcs' in rungrp else None
        for chnl in list(chnllst) + ['cross']:
            if fcsgrp is not None and chnl in fcsgrp and 'popt' in fcsgrp[chnl]:
                popt = fcsgrp[chnl]['popt'][()]
                pcov = fcsgrp[chnl]['pcov'][()]
                resdict['%s A0'%chnl].append(popt[0])
                resdict['%s A0_err'%chnl].append(pcov[0,0] ** 0.5)
                resdict['%s taud'%chnl].append(popt[1])
                resdict['%s taud_err'%chnl].append(pcov[1,1] ** 0.5)
                resdict['%s kappa'%chnl].append(popt[2])
                resdict['%s kappa_err'%chnl].append(pcov[2,2] ** 0.5)
            else:
                for key in ('A0', 'A0_err', 'taud', 'taud_err', 'kappa', 'kappa_err'):
                    resdict['%s %s'%(chnl, key)].append(0)
    return resdict


def _concentration_value(conc, unit_len):
    '''Parse a concentration string such as '10nM' by stripping `unit_len` trailing unit characters.

    'None' maps to 0; anything that cannot be parsed (including non-strings) maps to NaN.
    '''
    if conc == 'None':
        return 0
    try:
        return float(conc[:-unit_len])
    except (ValueError, TypeError):
        return np.nan


def gen_respd(reslist, stage, analytype='protein', labefs=None):
    '''
    Generate a DataFrame from a list of results HDF5 files.

    Every file is read with ``get_eventrate_etc`` using the event_rate etc. of
    `stage`. The DataFrame is indexed and sorted by ``fname``.

    Parameters
    ----------
    reslist : iterable
        Locations of the results HDF5 files.
    stage : str
        Analysis stage passed on to ``get_eventrate_etc``.
    analytype : str
        'protein': the last 2 characters of the concentration strings are
        stripped before parsing, and event rates/errors are additionally
        normalised by concentration * labelling efficiency.
        'DNA': the last 5 characters are stripped.
        Any other value: no concentration-derived columns are added.
    labefs : ignored
        Unused; kept only for backward compatibility of the signature.

    Returns
    -------
    pandas.DataFrame
        One row per results file. Per channel it adds '<chnl> rate_err' and
        '<chnl> concentration unit'. For 'protein' it also adds
        '<chnl> event_rate_norm' and '<chnl> rate_err_norm'. An 'order'
        column numbers the sorted rows.
    '''
    resdict = {'chnl0 event_rate':[], 'chnl1 event_rate':[],
               'chnl0 analyte':[], 'chnl1 analyte':[],
               'chnl0 labeff':[], 'chnl1 labeff':[],
               'chnl0 color':[], 'chnl1 color':[],
               'chnl0 diff_coeff':[], 'chnl1 diff_coeff':[],
               'chnl0 diff_coeff_err':[], 'chnl1 diff_coeff_err':[],
               'chnl0 concentration':[], 'chnl1 concentration':[],
               'chnl0 event_ctr':[], 'chnl1 event_ctr':[],
               'chnl0 rgyr':[], 'chnl1 rgyr':[],
               'chnl0 taud' : [], 'chnl1 taud':[], 'cross taud':[],
               'chnl0 taud_err' : [], 'chnl1 taud_err':[], 'cross taud_err':[],
               'chnl0 A0' : [], 'chnl1 A0':[], 'cross A0':[],
               'chnl0 A0_err' : [], 'chnl1 A0_err':[], 'cross A0_err':[],
               'chnl0 kappa' : [], 'chnl1 kappa':[], 'cross kappa':[],
               'chnl0 kappa_err' : [], 'chnl1 kappa_err':[], 'cross kappa_err':[],
               'chnl0 photons' : [], 'chnl1 photons' : [],
               'chnl0 eventphotrate' : [], 'chnl1 eventphotrate' : [],
               'chnl0 eventphotrate_err' : [], 'chnl1 eventphotrate_err' : [],
               'chnl0 eventphotons' : [], 'chnl1 eventphotons' : [],
               'chnl0 eventphotons_err' : [], 'chnl1 eventphotons_err' : [],
               'fname':[], 'voltage':[], 'pore_diam':[], 'pore_diam_plot':[], 'pore_area':[], 'coating':[],
               'reservoir_buffer':[], 'membrane_buffer':[], 'results_hdf_loc':[], 'meas_dur':[],
               }

    for reshdf in reslist:
        resdict = get_eventrate_etc(reshdf, resdict, stage=stage)

    respd = pd.DataFrame(resdict).set_index('fname')

    # number of trailing unit characters stripped from the concentration strings
    unit_len = {'protein': 2, 'DNA': 5}.get(analytype)
    if unit_len is not None:
        for chnl in ['chnl0', 'chnl1']:
            conc_values = [_concentration_value(conc, unit_len)
                           for conc in respd[chnl + ' concentration']]

            # Poisson error of the rate, rate / sqrt(N), written as sqrt(N) / (N / rate)
            respd[chnl + ' rate_err'] = np.sqrt(respd[chnl + ' event_ctr']) / (respd[chnl + ' event_ctr'] / respd[chnl + ' event_rate'])
            respd[chnl + ' concentration unit'] = conc_values

            if analytype == 'protein':
                norm = respd[chnl + ' concentration unit'] * respd[chnl + ' labeff']
                # zero concentration gives inf; report that as "not available"
                respd[chnl + ' event_rate_norm'] = (respd[chnl + ' event_rate'] / norm).replace(np.inf, np.nan)
                respd[chnl + ' rate_err_norm'] = (respd[chnl + ' rate_err'] / norm).replace(np.inf, np.nan)

    respd = respd.sort_values(by='fname')
    respd['order'] = np.arange(len(respd))

    return respd


def subcategorybar(ax, X, vals, collst, alphalst, labellst, width=0.8):
    '''
    Draw a grouped bar chart with one group per entry of X and one bar per series.

    Parameters
    ----------
    ax : axes to draw on.
    X : sequence of group labels, one tick per tested condition.
    vals : sequence whose k-th element holds the bar heights of series k.
    collst, alphalst, labellst : colour, alpha and legend label of each series.
        The number of series is taken from len(collst).
    width : total width occupied by one group of bars.
    '''
    n_series = len(collst)
    centers = np.arange(len(X))
    for k, colour in enumerate(collst):
        # bars are edge-aligned; bar k starts k/n_series of the group width from the group's left edge
        left_edges = centers - width / 2. + k / float(n_series) * width
        ax.bar(left_edges, vals[k], width=width / float(n_series), align="edge",
               color=colour, alpha=alphalst[k], label=labellst[k])

    ax.set_xticks(centers)
    ax.set_xticklabels(X, rotation='vertical')
    ax.legend()

def diff_theory_smol(c, r, D):
    '''
    Diffusion-limited event rate J = 2 * pi * c * r * D.

    Parameters
    ----------
    c : concentration in nM
    r : radius in nm
    D : diffusion coefficient in μm²/s

    Returns
    -------
    Event rate in events/s.
    '''
    # Unit conversion to SI, folded into one factor:
    #   nM    -> molecules/m³ : 6.02214076e23 * 1e3 * 1e-9 = 6.02214076e17
    #   nm    -> m            : 1e-9
    #   μm²/s -> m²/s         : 1e-12
    # => 6.02214076e17 * 1e-9 * 1e-12 = 6.02214076e-4
    si_factor = 6.02214076e-4
    return 2 * np.pi * c * r * D * si_factor

def diff_theory_fick(c, A, D, d):
    '''
    Event rate from Fick's first law through a pore: (c / d) * D * A.

    Parameters
    ----------
    c : concentration in nM
    A : pore area in μm²
    D : diffusion coefficient in μm²/s
    d : membrane thickness in nm

    Returns
    -------
    Particle flux through the area A in events/s.
    '''
    # SI conversion folded into one factor:
    #   nM -> molecules/m³ : 6.02214076e17
    #   μm²/s -> m²/s      : 1e-12
    #   μm² -> m²          : 1e-12
    #   1/nm -> 1/m        : 1e9
    # => 6.02214076e17 * 1e-12 * 1e-12 * 1e9 = 6.02214076e2
    si_factor = 6.02214076e2
    return c / d * D * A * si_factor

def D_constrained(Dopen, rgyr, rpore):
    '''
    Scale a free diffusion coefficient by a pore hindrance polynomial in lambda = rgyr / rpore.

    The polynomial is
        1 + 9/8 * lam * ln(lam) - 1.56034 lam + 0.528155 lam² + 1.91521 lam³
          - 2.81903 lam⁴ + 0.270788 lam⁵ + 1.10115 lam⁶ - 0.435933 lam⁷
    and evaluates to 0 at lam = 1. NOTE(review): the coefficients appear to
    match the Dechadilok & Deen (2006) hindered-diffusion factor for a sphere in
    a cylindrical pore; confirm before citing.

    Parameters
    ----------
    Dopen : free (unconstrained) diffusion coefficient; the result has the same unit.
    rgyr : radius of gyration in nm.
    rpore : pore radius in nm.
    '''
    lam = rgyr / rpore
    hindrance = 1 + 9/8 * lam * np.log(lam)
    # add the polynomial terms in ascending power, in the same order as the formula above
    for power, coeff in enumerate((-1.56034, 0.528155, 1.91521, -2.81903,
                                   0.270788, 1.10115, -0.435933), start=1):
        hindrance = hindrance + coeff * lam ** power

    return Dopen * hindrance


def _transloc_pore_radius(pore_diam):
    '''Return half the pore diameter [nm]; 'AxB' diameter strings are averaged first.'''
    if 'x' in pore_diam:
        return np.average(np.array(pore_diam.split('x')).astype(float)) / 2.
    return float(pore_diam) / 2.


def _plot_transloc_channel(ax, respd, chnl, otherchnl, color, analyte_label, theory, typlst):
    '''Plot `chnl` event rate vs its concentration, one series per `otherchnl` concentration.'''
    i = 0
    for conc in respd[otherchnl + ' concentration unit'].unique():
        smallpd = respd[respd[otherchnl + ' concentration unit'] == conc]
        if len(smallpd) < 2:
            continue

        smallpd = smallpd.sort_values(by=chnl + ' concentration unit')
        other_analyte = smallpd[otherchnl + ' analyte'].iloc[0]
        if other_analyte == 'None':
            labelstr = 'No other analyte'
        else:
            labelstr = other_analyte + ' at ' + smallpd[otherchnl + ' concentration'].iloc[0]

        # cycle through the marker styles instead of running out after len(typlst) series
        ax.errorbar(smallpd[chnl + ' concentration unit'], smallpd[chnl + ' event_rate'], smallpd[chnl + ' rate_err'],
                    fmt=typlst[i % len(typlst)], color=color, label=labelstr, ms=15, mec='k')

        if i == 0:
            ax.set_xlabel(analyte_label + ' concentration [nM]')
            ax.set_ylabel('Eventrate [1/s]')
            if theory:
                prad = _transloc_pore_radius(smallpd['pore_diam'].iloc[-1])
                D_con = D_constrained(float(smallpd[chnl + ' diff_coeff'].iloc[-1]),
                                      float(smallpd[chnl + ' rgyr'].iloc[-1]), prad)
                if respd['pore_area'].iloc[0] == 'None':
                    # circular pore: nm² -> μm²
                    pa = np.pi * prad ** 2 * 10 ** (-6)
                else:
                    # strip the 3-character unit suffix of the stored area string
                    pa = float(respd['pore_area'].iloc[0][:-3])
                ax.plot(smallpd[chnl + ' concentration unit'],
                        diff_theory_fick(smallpd[chnl + ' concentration unit'], pa, D_con, 100),
                        color=TUD_grey, label="Fick's law")
                logger('Membrane thickness of 100nm used, watch out if it is changed at some point', 'w')
        i += 1

    ax.legend(prop={'size': 14})
    ax.tick_params(top=True, right=True, which='both', direction='in')


def plot_transloc_overview(axs, respd, pltdict=None):
    '''
    Plot an overview of translocation event rates for both channels.

    ax0 : chnl0 event rate vs chnl0 concentration, one series per chnl1 concentration.
    ax1 : chnl1 event rate vs chnl1 concentration, one series per chnl0 concentration.
    ax2 : bar chart of both channels' event rates for every measurement.

    Series with fewer than two points are skipped. If 'theory' is enabled, the
    first plotted series of ax0/ax1 is overlaid with diff_theory_fick. That
    curve uses the pore-constrained diffusion coefficient and a fixed 100 nm
    membrane thickness.

    Parameters
    ----------
    axs : tuple of three axes (ax0, ax1, ax2).
    respd : DataFrame as produced by gen_respd (needs the 'rate_err' and
        'concentration unit' columns).
    pltdict : dict, optional
        'theory' (default True).
        'collst' [chnl0 colour, chnl1 colour] and 'labellst' [chnl0 analyte,
        chnl1 analyte]; both default to the last row of respd.
        Missing keys are written into the passed dict. A fresh dict is used
        when none is given, so defaults no longer leak between calls.
    '''
    if pltdict is None:
        pltdict = {}

    typlst = ['X-', 'o-', '*-', 'd-', 'P-']

    if 'theory' not in pltdict:
        pltdict['theory'] = True
    theory = pltdict['theory']

    (ax0, ax1, ax2) = axs

    if 'collst' not in pltdict:
        pltdict['collst'] = [respd['chnl0 color'].iloc[-1], respd['chnl1 color'].iloc[-1]]
    collst = pltdict['collst']

    if 'labellst' not in pltdict:
        pltdict['labellst'] = [respd['chnl0 analyte'].iloc[-1], respd['chnl1 analyte'].iloc[-1]]
    labellst = pltdict['labellst']

    _plot_transloc_channel(ax0, respd, 'chnl0', 'chnl1', collst[0], labellst[0], theory, typlst)
    _plot_transloc_channel(ax1, respd, 'chnl1', 'chnl0', collst[1], labellst[1], theory, typlst)

    # one bar-group label per measurement describing which analytes were present
    strlst = []
    for _, line in respd.iterrows():
        has0 = line['chnl0 analyte'] != 'None'
        has1 = line['chnl1 analyte'] != 'None'
        if has0 and has1:
            strlst.append('%s at %s, %s at %s'%(line['chnl0 analyte'], line['chnl0 concentration'],
                                                line['chnl1 analyte'], line['chnl1 concentration']))
        elif has1:
            strlst.append('%s at %s'%(line['chnl1 analyte'], line['chnl1 concentration']))
        elif has0:
            strlst.append('%s at %s'%(line['chnl0 analyte'], line['chnl0 concentration']))
        else:
            strlst.append('No analyte')

    alphalst = [1, 1]

    subcategorybar(ax2, strlst, [respd['chnl0 event_rate'], respd['chnl1 event_rate']], collst, alphalst, labellst)
    ax2.set_ylabel('Eventrate [1/s]')
    ax2.tick_params(axis='x', which='both', bottom=False, top=False)

    ax0.set_ylim(bottom = - 0.05 * np.max(respd['chnl0 event_rate']), top = np.max(respd['chnl0 event_rate']) * 1.1)
    ax1.set_ylim(bottom = - 0.05 * np.max(respd['chnl1 event_rate']), top = np.max(respd['chnl1 event_rate']) * 1.1)

def plot_photon_swarm(results_hdf_loc, ax, tt_dict):
    '''
    Swarm-plot the photon arrival times of one detector channel.

    Parameters
    ----------
    results_hdf_loc : location of the results HDF5 file. The matching photon
        HDF5 file is found via get_photonhdf_loc.
    ax : axes to draw on.
    tt_dict : dict
        'chnl' : channel name such as 'chnl0'; its last character is the
            detector number.
        'tt_start' / 'tt_end' (optional) : time window [s]. If either bound is
            given, only photons strictly inside the window are plotted. Missing
            bounds are filled into tt_dict (in place) with the first/last photon
            time of the channel.
    '''
    chnl = tt_dict['chnl']

    sample_d = sample_info(results_hdf_loc, chnl)

    # read everything needed and close the photon file right away
    with h5py.File(get_photonhdf_loc(results_hdf_loc, absolute=True), "r") as photon_hdf:
        detectors = photon_hdf['photon_data']['detectors'][...]
        times = photon_hdf['photon_data']['timestamps'][()]
        timeres = photon_hdf['photon_data']['timestamps_specs']['timestamps_unit'][()]

    phystimes = times[detectors == int(chnl[-1])] * timeres

    user_window = 'tt_start' in tt_dict or 'tt_end' in tt_dict
    if 'tt_start' not in tt_dict:
        tt_dict['tt_start'] = phystimes[0]
    if 'tt_end' not in tt_dict:
        tt_dict['tt_end'] = phystimes[-1]
    if user_window:
        tt_mask = (phystimes > tt_dict['tt_start']) & (phystimes < tt_dict['tt_end'])
        phystimes = phystimes[tt_mask]

    # pass the data as x explicitly: newer seaborn binds a positional argument to `data`
    sns.swarmplot(x=phystimes, ax=ax, color=sample_d['color'])

    ax.set_xlim(left=tt_dict['tt_start'], right=tt_dict['tt_end'])
    ax.tick_params(top=True, right=True, which='both', direction='in')


def plot_pore_size_overview(axs, respd, pltdict=None):
    '''
    Fit the slope of event rate vs concentration per pore size and per concentration of the other analyte.

    Each such group of measurements is reduced to one slope value (np.polyfit,
    degree 1). With more than two points the slope error comes from the fit
    covariance; otherwise it is 0. Groups with fewer than two points are
    skipped.

    For each channel this plots:
      - on the concentration axis: the slope minus the per-pore-size average
        slope vs the concentration of the grouping channel (ax0 / ax2);
      - on the size axis: the slope vs pore diameter, one series per
        concentration (ax1 / ax3).

    Parameters
    ----------
    axs : tuple of four axes (ax0, ax1, ax2, ax3).
    respd : DataFrame as produced by gen_respd, indexed by fname.
    pltdict : dict, optional
        'excludelst' : list of regex patterns; measurements whose fname matches
            any of them are excluded, e.g. r'B[0-9]X'.
        'labellst' : [chnl0 analyte, chnl1 analyte].
        'collst' : [chnl0 colour, chnl1 colour].
        Missing keys are filled into the passed dict. A fresh dict is used when
        none is given, so defaults no longer leak between calls.
    '''
    if pltdict is None:
        pltdict = {}

    conctyplst = ['X-', 'o-', '*-', 'd-', 'P-', 'H-', 's-', '<-', '^-', '8-', 'v-']
    sizetyplst = ['X', 'o', '*', 'd', 'P', 'H', 's', '<', '^', '8', 'v']

    (ax0, ax1, ax2, ax3) = axs

    if 'excludelst' not in pltdict:
        pltdict['excludelst'] = []
    excludelst = pltdict['excludelst']

    remlist = [index for index in respd.index
               if any(re.search(exclude, index) for exclude in excludelst)]
    usepd = respd.drop(index=remlist)

    if 'labellst' not in pltdict:
        pltdict['labellst'] = [usepd['chnl0 analyte'].iloc[-1], usepd['chnl1 analyte'].iloc[-1]]
    labellst = pltdict['labellst']

    if 'collst' not in pltdict:
        pltdict['collst'] = [usepd['chnl0 color'].iloc[-1], usepd['chnl1 color'].iloc[-1]]
    collst = pltdict['collst']

    slopedict = {'chnl0':[], 'chnl1':[]}
    for chnl in slopedict:
        # group by `chnl` concentration and fit the event rate of the *other* channel
        if chnl == 'chnl0':
            otherchnl = 'chnl1'
            concax = ax0
            sizeax = ax1
            thiscol = collst[1]
            concaxxlab = labellst[0] + ' concentration [nM]'
            sizeaxleg = labellst[0] + ' at '
            ylabpart = labellst[1]
        else:
            otherchnl = 'chnl0'
            concax = ax2
            sizeax = ax3
            thiscol = collst[0]
            concaxxlab = labellst[1] + ' concentration [nM]'
            sizeaxleg = labellst[1] + ' at '
            ylabpart = labellst[0]

        for poresize in usepd['pore_diam'].unique():
            porepd = usepd[usepd['pore_diam'] == poresize]

            for conc in porepd['%s concentration unit'%chnl].unique():
                smallpd = porepd[porepd['%s concentration unit'%chnl] == conc]
                if len(smallpd) < 2:
                    continue

                xvals = smallpd['%s concentration unit'%otherchnl]
                yvals = smallpd['%s event_rate'%otherchnl]
                if len(smallpd) > 2:
                    (slope, _), cov = np.polyfit(xvals, yvals, 1, cov=True)
                    slope_err = np.sqrt(cov[0, 0])
                else:
                    # two points: exact line, covariance not estimable
                    slope, _ = np.polyfit(xvals, yvals, 1)
                    slope_err = 0

                slopedict[chnl].append([conc, slope, slope_err, poresize])

        if not slopedict[chnl]:
            # no group had enough points to fit; nothing to plot for this channel
            continue

        # mixed numbers/strings -> numpy converts the whole array to strings,
        # hence the explicit float conversions below
        dataarr = np.array(slopedict[chnl])
        psizes = np.unique(dataarr[:, 3])
        for i, poresize in enumerate(psizes):
            psarr = dataarr[dataarr[:, 3] == poresize]
            slopes = psarr[:, 1].astype(float)
            concax.errorbar(x=psarr[:, 0].astype(float), y=slopes - np.average(slopes), yerr=psarr[:, 2].astype(float),
                            label=poresize, fmt=conctyplst[i % len(conctyplst)], ms=15, mec='k', color=thiscol)

        concax.legend(prop={'size':14})
        concax.set_xlabel(concaxxlab)
        concax.set_ylabel('change in normalised ' + ylabpart + ' eventrate [1/(nM s)]')

        conces = np.sort(np.unique(dataarr[:, 0]).astype(float)).astype(str)
        for i, oneconc in enumerate(conces):
            psarr = dataarr[dataarr[:, 0] == oneconc]
            poresizes = []
            for ps in psarr[:, 3]:
                if 'x' in ps:
                    # 'AxB' pore diameters are represented by their mean
                    lowerps, upperps = ps.split('x')
                    poresizes.append(np.average([float(lowerps), float(upperps)]))
                else:
                    poresizes.append(float(ps))

            sizeax.errorbar(x=poresizes, y=psarr[:, 1].astype(float), yerr=psarr[:, 2].astype(float),
                            label=sizeaxleg + oneconc + ' nM', fmt=sizetyplst[i % len(sizetyplst)],
                            ms=15, mec='k', color=thiscol)

        sizeax.legend(prop={'size':14})
        sizeax.set_xlabel('Pore diameter [nm]')
        sizeax.set_ylabel('normalised ' + ylabpart +  ' eventrate [1/(nM s)]')
        sizeax.set_ylim(bottom=0)



def _fcs_finite_max(values):
    '''Return the largest finite value in `values`, or None if there is none.'''
    arr = np.asarray(values, dtype=float)
    arr = arr[np.isfinite(arr)]
    return arr.max() if arr.size else None


def plot_fcs_fitres(respd, ax, ax2, ax3, pltdict=None):
    '''
    Plot the FCS fit results of both channels against the measurement order.

    ax shows τ_D (stored in s, plotted in μs), ax2 shows A0 and ax3 shows κ.
    A channel's points are only shown for measurements where that channel's
    concentration is > 0. Each pore diameter gets its own marker.

    Parameters
    ----------
    respd : DataFrame as produced by gen_respd, indexed by fname.
    ax, ax2, ax3 : axes for τ_D, A0 and κ.
    pltdict : dict, optional
        'excludelst' : regex patterns; measurements whose fname matches any of
            them are dropped.
        'collst' : [chnl0 colour, chnl1 colour]; defaults to the colours stored
            in the data.
        Missing keys are filled into the passed dict. A fresh dict is used when
        none is given, so defaults no longer leak between calls.
    '''
    if pltdict is None:
        pltdict = {}

    if 'excludelst' not in pltdict:
        pltdict['excludelst'] = []
    excludelst = pltdict['excludelst']

    remlist = [index for index in respd.index
               if any(re.search(exclude, index) for exclude in excludelst)]
    usepd = respd.drop(index=remlist)

    sizetyplst = ['X', 'o', '*', 'd', 'P', 'H', 's', '<', '^', '8', 'v']
    taud_vals = []
    a0_vals = []
    # the marker counter starts at 1 (as before); wrap around instead of overflowing
    for i, poresize in enumerate(usepd['pore_diam'].unique(), start=1):
        marker = sizetyplst[i % len(sizetyplst)]
        porepd = usepd[usepd['pore_diam'] == poresize]

        chnl0pd = porepd[porepd['chnl0 concentration unit'] > 0]
        chnl1pd = porepd[porepd['chnl1 concentration unit'] > 0]

        if 'collst' not in pltdict:
            pltdict['collst'] = [chnl0pd['chnl0 color'].iloc[-1], chnl1pd['chnl1 color'].iloc[-1]]
        collst = pltdict['collst']

        # τ_D is plotted in μs, so its error bars need the same 1e6 scaling
        ax.errorbar(chnl0pd['order'], chnl0pd['chnl0 taud'] * 1e6, yerr=chnl0pd['chnl0 taud_err'] * 1e6,
                    label=poresize, fmt=marker, ms=10, mec='k', color=collst[0])
        ax.errorbar(chnl1pd['order'], chnl1pd['chnl1 taud'] * 1e6, yerr=chnl1pd['chnl1 taud_err'] * 1e6,
                    label=poresize, fmt=marker, ms=10, mec='k', color=collst[1])

        ax2.errorbar(chnl0pd['order'], chnl0pd['chnl0 A0'], yerr=chnl0pd['chnl0 A0_err'],
                     label=poresize, fmt=marker, ms=10, mec='k', color=collst[0])
        ax2.errorbar(chnl1pd['order'], chnl1pd['chnl1 A0'], yerr=chnl1pd['chnl1 A0_err'],
                     label=poresize, fmt=marker, ms=10, mec='k', color=collst[1])

        ax3.errorbar(chnl0pd['order'], chnl0pd['chnl0 kappa'], yerr=chnl0pd['chnl0 kappa_err'],
                     label=poresize, fmt=marker, ms=10, mec='k', color=collst[0])
        ax3.errorbar(chnl1pd['order'], chnl1pd['chnl1 kappa'], yerr=chnl1pd['chnl1 kappa_err'],
                     label=poresize, fmt=marker, ms=10, mec='k', color=collst[1])

        taud_vals.extend(chnl0pd['chnl0 taud'].tolist() + chnl1pd['chnl1 taud'].tolist())
        a0_vals.extend(chnl0pd['chnl0 A0'].tolist() + chnl1pd['chnl1 A0'].tolist())

    ax.set_ylabel('Diffusion time $τ_D$[μs]')
    ax.set_xlabel('Order')
    ax2.set_ylabel('$A_0$')
    ax2.set_xlabel('Order')
    ax3.set_ylim(0,10)
    ax3.set_ylabel('κ')
    ax3.set_xlabel('Order')

    # y-limits over all pore sizes (previously only the last pore size counted)
    taud_top = _fcs_finite_max(taud_vals)
    if taud_top is not None:
        ax.set_ylim(0, taud_top * 1e6)
    a0_top = _fcs_finite_max(a0_vals)
    if a0_top is not None:
        ax2.set_ylim(0, a0_top)


def average_full_exps(respdlist, pore_ids=range(1, 9), pore_letters='ABCDEFGHI'):
    '''
    Average the normalised event rates per pore over a list of result DataFrames.

    For every DataFrame in `respdlist` and every pore number i in `pore_ids`,
    the rows whose fname index contains 'B<i>' followed by one of
    `pore_letters` form one pore. The index-number pair is printed for each
    pore. For every non-empty pore one entry is appended to each output array:

      pore          : mean of 'pore_diam_plot'
      kap_rate(_err): mean (and SEM) of 'chnl0 event_rate_norm' over rows with chnl0 concentration > 0
      bsa_rate(_err): mean (and SEM) of 'chnl1 event_rate_norm' over rows with chnl1 concentration > 0
      ratio(_err)   : kap_rate / bsa_rate with Gaussian error propagation

    Parameters
    ----------
    respdlist : iterable of DataFrames as produced by gen_respd(analytype='protein').
    pore_ids : iterable of int, pore numbers to look for (default 1..8).
    pore_letters : str, allowed letters following the pore number (default 'ABCDEFGHI').

    Returns
    -------
    dict mapping the keys above to numpy arrays.
    '''
    thisexpdict = {'pore':[], 'kap_rate':[], 'kap_rate_err':[], 'bsa_rate':[], 'bsa_rate_err':[],
                   'ratio':[], 'ratio_err':[]}

    for loopdf in respdlist:
        for i in pore_ids:
            poreidf = loopdf[loopdf.index.str.contains(r'B%i[%s]'%(i, re.escape(pore_letters)))]
            print(i, len(poreidf))
            if len(poreidf) == 0:
                continue

            kap = poreidf[poreidf['chnl0 concentration unit'] > 0]['chnl0 event_rate_norm']
            bsa = poreidf[poreidf['chnl1 concentration unit'] > 0]['chnl1 event_rate_norm']
            kap_rate, kap_err = kap.mean(), kap.sem()
            bsa_rate, bsa_err = bsa.mean(), bsa.sem()

            thisexpdict['pore'].append(poreidf['pore_diam_plot'].mean())
            thisexpdict['kap_rate'].append(kap_rate)
            thisexpdict['kap_rate_err'].append(kap_err)
            thisexpdict['bsa_rate'].append(bsa_rate)
            thisexpdict['bsa_rate_err'].append(bsa_err)

            # error of k/b: sqrt((dk/b)^2 + (k/b^2 * db)^2)
            thisexpdict['ratio'].append(kap_rate / bsa_rate)
            thisexpdict['ratio_err'].append(np.sqrt((kap_err / bsa_rate) ** 2 +
                                                    (kap_rate / bsa_rate ** 2 * bsa_err) ** 2))

    return {key: np.array(values) for key, values in thisexpdict.items()}