# -*- coding: utf-8 -*-
"""
Created on Fri Mar 18 10:38:01 2022

@author: LokinLR
"""
import os, re, sys
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt


class Dude:
    """
    Reader for the output files of a normal DUDE model run.

    Instantiate with the path of a run folder; the methods below load the
    individual output files from that folder on demand.
    """

    # Names of all possible output files of normal runs
    general = 'out_general1.txt'
    outConfig = 'out_config.txt'

    bottomfile = 'out_bottom1.txt'
    zetafile = 'out_zeta1.txt'

    velufile = 'out_velu.txt'
    velwfile = 'out_velw.txt'
    velubfile = 'out_velub.txt'

    bssfile = 'out_bss1.txt'
    dhdxfile = 'out_dhdx1.txt'
    fluxfile = 'out_flux1.txt'
    fszfile = 'out_fsz1.txt'
    intfile = 'out_int1.txt'
    Sroundfile = 'out_Sround1.txt'

    # Units in out_config.txt whose values config() stores without conversion.
    _SI_UNITS = (None, 's', 'sec', 'm', '-', 'm/m', 'm2/s', 'm/s2', 'kg/m3')
    # Units config() converts to SI; each entry maps a value to its SI value.
    _UNIT_CONVERSIONS = {
        'hr': lambda v: v * 3600,
        'min': lambda v: v * 60,
        'day': lambda v: v * (60 * 60 * 24),
        'mm': lambda v: v / 1000,
        'deg': lambda v: v * np.pi / 180,
    }

    def __init__(self, runfolder):
        """
        Read the basic run settings from the model output folder.

        Parameters
        ----------
        runfolder : str
            Path to the folder with the model output files.

        Sets self.run, self.Npx, self.Npz, self.dt, self.dt_write,
        self.floodwave (False, or the name of the flood-wave file) and
        self.time, the times at which output is written (first column of
        out_zeta1.txt).  Exits via sys.exit when neither out_config.txt nor
        out_general1.txt is present.
        """
        self.run = runfolder
        # Constant discharge unless the settings name a flood-wave file.
        self.floodwave = False

        modelinput = os.path.join(runfolder, self.outConfig)
        generalinput = os.path.join(runfolder, self.general)
        if os.path.exists(modelinput):
            self._read_out_config(modelinput)
        elif os.path.exists(generalinput):
            self._read_out_general(generalinput)
        else:
            sys.exit('Error: no configuration output data in this folder, '
                     'run DUDE')

        # Only the time column is needed; ndmin=1 keeps a 1-D array even
        # when a single output time was written.
        self.time = np.loadtxt(os.path.join(self.run, self.zetafile),
                               usecols=0, ndmin=1)

    @staticmethod
    def _parse_number(token):
        """Return token as int when it is an integer literal, else as float."""
        try:
            return int(token)
        except ValueError:
            return float(token)

    def _read_out_config(self, path):
        """
        Read Npx, Npz, dt, dt_write and the flood-wave file name from
        out_config.txt, whose lines look like 'ENV name = value [unit]'.
        """
        with open(path, 'r') as inputfile:
            for line in inputfile:
                # The value is the third whitespace-separated token once the
                # 'ENV' prefix is dropped, e.g. 'Npx = 100 [-]' -> '100'.
                tokens = line.replace('ENV', ' ').split()
                if len(tokens) < 3:
                    continue
                value = tokens[2]
                if 'Npx' in line:
                    self.Npx = int(value)
                elif 'Npz' in line:
                    self.Npz = int(value)
                elif 'dtr' in line:
                    self.dt = self._parse_number(value)
                elif 'dt_write' in line:
                    self.dt_write = self._parse_number(value)
                elif 'readfw' in line:
                    # Quoted file name; an empty name means no flood wave.
                    self.floodwave = value.strip('\'"') or False

    def _read_out_general(self, path):
        """
        Read the same settings from out_general1.txt, taking column 0 of
        fixed rows.
        """
        settings = pd.read_csv(path, header=None)
        self.Npx = int(settings.iloc[2, 0])
        self.Npz = int(settings.iloc[4, 0])
        self.dt = float(settings.iloc[6, 0])
        self.dt_write = float(settings.iloc[7, 0])
        # An 18th row, when present, holds the flood-wave file name.
        self.floodwave = settings.iloc[17, 0] if len(settings) == 18 else False

    def _load(self, filename):
        """Load a whitespace-delimited output file of this run as a 2-D array."""
        return np.loadtxt(os.path.join(self.run, filename), ndmin=2)

    def config(self):
        """
        Read out_config.txt into a dictionary of the run's input variables.

        Lines look like 'ENV name = value [unit]'.  Numeric values become
        floats and are converted to SI for the units hr, min, day, mm and
        deg; 'true'/'t'/'True' and 'false'/'f'/'False' become booleans.

        Returns
        -------
        dict mapping variable name to value; empty when out_config.txt
        does not exist.
        """
        variables = {}
        modelinput = os.path.join(self.run, self.outConfig)
        if not os.path.exists(modelinput):
            return variables

        with open(modelinput, 'r') as inputfile:
            for line in inputfile:
                if '=' not in line:
                    continue
                line = line.replace('ENV', ' ')

                if '[' in line:
                    line_var, _, unit = line.strip().partition('[')
                    # Empty brackets '[]' mean "no unit".
                    unit = unit.split(']', 1)[0].strip() or None
                else:
                    line_var, unit = line, None
                if '=' not in line_var:
                    continue  # the only '=' was inside the unit brackets

                name, _, value = line_var.partition('=')
                variable_name, value = name.strip(), value.strip()
                try:
                    value = float(value)
                except ValueError:
                    if value in ('true', 't', 'True'):
                        value = True
                    elif value in ('false', 'f', 'False'):
                        value = False

                if unit in self._SI_UNITS:
                    variables[variable_name] = value
                elif unit in self._UNIT_CONVERSIONS and isinstance(value, float):
                    variables[variable_name] = self._UNIT_CONVERSIONS[unit](value)
                # NOTE(review): values with any other unit (e.g. 'm/s') are
                # not stored — confirm this is intended.
        return variables

    def zeta(self):
        """
        Water level variations: out_zeta1.txt without its first (time) column.
        """
        return self._load(self.zetafile)[:, 1:]

    def qinput(self, t_eval=None):
        """
        Input discharge of the run.

        Parameters
        ----------
        t_eval : array_like, optional
            Times at which a flood-wave discharge is interpolated; defaults
            to the times listed in the flood-wave file.  Ignored when the
            discharge is constant.

        Returns
        -------
        float for a constant discharge, otherwise an array of discharges
        linearly interpolated at t_eval.
        """
        if not self.floodwave:
            # NOTE(review): assumes the constant discharge is row 31 of
            # out_config.txt as read by pandas, with the value as the
            # second-to-last space-separated token — confirm against the
            # DUDE writer.  Runs that only have out_general1.txt fail here.
            qin = pd.read_csv(os.path.join(self.run, self.outConfig),
                              header=None).iloc[31].values[0]
            return float(qin.split(' ')[-2])

        # Flood-wave file: one header line, then columns (time, discharge).
        qin = np.loadtxt(os.path.join(self.run, self.floodwave),
                         skiprows=1, ndmin=2)
        if t_eval is None or np.size(t_eval) == 0:
            t_eval = qin[:, 0]
        return np.interp(t_eval, qin[:, 0], qin[:, 1])

    def q(self):
        """Specific discharge per output time (column 3 of the bottom file)."""
        return self._load(self.bottomfile)[::2, 3]

    def H(self):
        """Calculated water depth per output time (column 4 of the bottom file)."""
        return self._load(self.bottomfile)[::2, 4]

    def L(self):
        """Calculated domain length per output time (column 5 of the bottom file)."""
        return self._load(self.bottomfile)[::2, 5]

    def X(self):
        """
        Horizontal grid positions in m, shape (n_times, Npx): Npx points
        spaced L/Npx apart starting at 0, for every output time.
        """
        dx = self.L() / self.Npx
        return dx[:, np.newaxis] * np.arange(self.Npx)

    def Z(self):
        """
        Vertical grid positions in m, shape (n_times, Npz), based on the
        water depth only (NOT including bed and water level variations):
        Npz points spaced H/Npz apart starting at 0.
        """
        dz = self.H() / self.Npz
        return dz[:, np.newaxis] * np.arange(self.Npz)

    def bottom(self):
        """Bed elevation: even rows of the bottom file, from column 6 on."""
        return self._load(self.bottomfile)[::2, 6:]

    def flowsep(self):
        """Flow separation line: odd rows of the bottom file, from column 6 on."""
        return self._load(self.bottomfile)[1::2, 6:]

    def velu(self):
        """
        Velocity component in x direction, reshaped to (n_times, Npx, Npz)
        after dropping the first (time) column.
        """
        u = self._load(self.velufile)[:, 1:]
        return u.reshape(u.shape[0], self.Npx, self.Npz)

    def velw(self):
        """
        Vertical (w) velocity component, reshaped to (n_times, Npx, Npz)
        after dropping the first (time) column.
        """
        w = self._load(self.velwfile)[:, 1:]
        return w.reshape(w.shape[0], self.Npx, self.Npz)

    def ubed(self):
        """Contents of out_velub.txt without the first (time) column."""
        return self._load(self.velubfile)[:, 1:]

    def shearstress(self):
        """
        Even rows of out_bss1.txt without the first column (presumably
        time — TODO confirm).
        """
        return self._load(self.bssfile)[::2, 1:]

    def dhdx(self):
        """Contents of out_dhdx1.txt without the first (time) column."""
        return self._load(self.dhdxfile)[:, 1:]

    def qSflux(self):
        """Contents of out_flux1.txt without the first (time) column."""
        return self._load(self.fluxfile)[:, 1:]
   
    
    
    
class linstab:
    """
    Reader for the linear-stability output files (out_stab<N>.out) of a run.
    """

    # Names of the stability files read by __init__; <N> orders the columns.
    _STABFILE = re.compile(r'out_stab(\d+)\.out')

    def __init__(self, runfolder):
        """
        Read all out_stab<N>.out files in runfolder, ordered by N.

        Each file has a 5-line preamble followed by a table whose columns
        0-3 are read as lengths, water depth, growth rates and migration
        rates.  Line 1 of the preamble provides the output time, the
        (1-based) table row used for Lmax/Grmax/Migmax and, when present,
        Av and S.

        Sets, with one entry/column per file:
            times, H, Lmax, Grmax, Migmax, Av, S : shape (n_files,)
                (Av and S are 0 when the preamble does not contain them)
            L, Gr, Mig : shape (n_rows, n_files)

        Raises
        ------
        FileNotFoundError
            If runfolder contains no out_stab<N>.out files.
        """
        self.run = runfolder

        numbered = []
        for name in os.listdir(runfolder):
            match = self._STABFILE.fullmatch(name)
            if match:
                numbered.append((int(match.group(1)), name))
        if not numbered:
            raise FileNotFoundError(
                f"No out_stab<N>.out files found in {runfolder}")
        stabfiles = [name for _, name in sorted(numbered)]

        times, rows, Av, S, tables = [], [], [], [], []
        for name in stabfiles:
            path = os.path.join(runfolder, name)
            time, row, av, s = self._read_preamble(path)
            times.append(time)
            rows.append(row)
            Av.append(av)
            S.append(s)
            # ndmin=2 keeps column indexing valid for single-row tables.
            tables.append(np.loadtxt(path, skiprows=5, ndmin=2))

        self.times = np.array(times, dtype=float)

        # Values at the row selected in each file's preamble.
        self.Lmax = np.array([tab[r, 0] for tab, r in zip(tables, rows)],
                             dtype=float)
        self.Grmax = np.array([tab[r, 2] for tab, r in zip(tables, rows)],
                              dtype=float)
        self.Migmax = np.array([tab[r, 3] for tab, r in zip(tables, rows)],
                               dtype=float)

        self.H = np.array([tab[0, 1] for tab in tables], dtype=float)
        self.L = np.column_stack([tab[:, 0] for tab in tables])
        self.Gr = np.column_stack([tab[:, 2] for tab in tables])
        self.Mig = np.column_stack([tab[:, 3] for tab in tables])

        self.Av = np.array(Av, dtype=float)
        self.S = np.array(S, dtype=float)

    @staticmethod
    def _read_preamble(path):
        """
        Return (time, row, Av, S) parsed from line 1 of a stability file.

        The line is split on single spaces and '='; fields 2 and 4 hold the
        time and the 1-based row, fields 10 and 12 hold Av and S when the
        line has more than 10 fields (otherwise both are returned as 0.0).
        """
        with open(path, 'r') as stabfile:
            # Line 0 may hold a WARNING (the original hint was to consider
            # increasing Lmax); it is not used.
            stabfile.readline()
            fields = re.split(' |=', stabfile.readline())
        time = float(fields[2])
        row = int(fields[4]) - 1  # convert to a 0-based table row
        if len(fields) > 10:
            return time, row, float(fields[10]), float(fields[12])
        return time, row, 0.0, 0.0

    def linstab(self):
        return ('Linstab reader')
        