# -*- coding: utf-8 -*-
"""
Created on Mon Sep 26 10:04:56 2022

@author: vedal + stefan
"""
from obspy import UTCDateTime
from obspy.clients.fdsn import Client
#from obspy.clients.iris import Client
client = Client()  # defaults to the IRIS FDSN web service
from obspy import read as obread
from obspy import read_inventory
from datetime import datetime, timezone, timedelta
import time
import schedule
import pickle
import numpy as np
import tensorflow as tf
import os
import sys
from keras.models import Model
#from keras.layers import Conv1D, BatchNormalization, Add, MaxPooling1D, Dropout, Dense, Flatten, CuDNNLSTM, ZeroPadding1D
from keras.layers import Conv1D, BatchNormalization, Add, MaxPooling1D, Dropout, Dense, Flatten, ZeroPadding1D
from keras.layers import Activation, Input, concatenate, GaussianNoise, GlobalMaxPooling1D, GlobalAveragePooling1D, Softmax, Permute, Multiply, Masking
#from keras.optimizers import Adam, RMSprop, SGD
from tensorflow.keras.optimizers import Adam, RMSprop, SGD
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from keras import backend as K
from keras.regularizers import *
import matplotlib
matplotlib.use('Agg')  # select the non-interactive backend before importing pyplot
from matplotlib import pyplot as plt
import pandas as pd
import random as rd
import copy


###########
# read / import mseed data for events:
data_base='old'

if data_base == 'old': ########### old events list (Veda):
    selectedevents = []
    allevs = [0, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 18,
              20, 23, 26, 30, 32, 35, 36, 37, 39, 41, 43, 46, 50, 51,
              52, 54, 57, 58, 59, 63, 65, 66, 69, 81, 87]
    testlist = copy.deepcopy(allevs)
    del testlist[1]  # drop event 3: its file is incomplete
    for i in testlist:
        selectedevents.append("./Original_data/seeddata32/TEST" + str(i) + ".mseed")
elif data_base == 'new': ########### new events list (Stefan):
    with open('windows_list1.txt') as f:
        w_names = [line.rstrip() for line in f]
    selectedevents = copy.deepcopy(w_names)
    for k in [6, 1, 0]:  # drop faulty data files (descending indices so earlier deletions don't shift later ones)
        del selectedevents[k]
else:
    sys.exit('please select a database ("old" or "new")')
############

window_length = 16384
s_s = 900*40; m_l = 1000*40
# stride of 900 seconds and window length of 1000 s, both at 40 Hz:
# macro-windows of ~16.7 minutes taken every 15 minutes (~1.7 min overlap)
def make_macrowindows(X1, sample_stride=s_s, mawi_length=m_l):
    # slice the trace into overlapping macro-windows of mawi_length samples,
    # starting a new window every sample_stride samples
    X2 = []
    for i in range(0, len(X1) - mawi_length, sample_stride):
        X2.append(X1[i:i + mawi_length])
    return X2
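
# Illustrative sanity check (my assumption about the intended sizes; not part
# of the pipeline, call it manually if wanted): a full 10-hour trace at 40 Hz
# (1,440,000 samples) should yield the 39 macro-windows that size the
# results_model array further down.
def _check_macrowindow_count():
    dummy = np.zeros((1440000, 3))
    assert len(make_macrowindows(dummy)) == 39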


def Preprocessing_one(mawik):
    def make_windows(X1, sample_stride=650):
        # split a macro-window into sub-windows of window_length samples,
        # starting a new sub-window every sample_stride samples
        X2 = []
        for i in range(0, len(X1) - window_length, sample_stride):
            X2.append(X1[i:i + window_length])
        return X2

    def normalise_old(X1):
        # Here the 3 components are mean-stripped and normalised together;
        # the results look strange but work for the CNN.
        X2 = []
        for data in X1:
            values = data - np.mean(data)
            values = values / np.linalg.norm(values)
            X2.append(values)
        return X2

    # Make windows of length window_length (overlapping, stride 650 timesteps),
    # then normalise each one.
    sub = make_windows(mawik, 650)
    sub_windows = normalise_old(sub)
    return sub_windows
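
# Illustrative sanity check (assumed sizes; not called by the pipeline): each
# 40,000-sample macro-window should split into 37 normalised sub-windows of
# 16,384 samples each at a stride of 650.
def _check_subwindow_count():
    dummy = np.random.randn(m_l, 3)
    subs = Preprocessing_one(dummy)
    assert len(subs) == 37
    assert subs[0].shape == (window_length, 3)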


def encoder(input_shape, nb_classes):

    input_layer = Input(input_shape)
    # learned projection of the raw input, added back onto it
    random = Conv1D(3, 3, padding='same', kernel_initializer='random_normal')(input_layer)
    x = Add()([input_layer, random])

    # two plain convolutional blocks
    x = Conv1D(filters=32, kernel_size=7, padding='same', strides=1, kernel_initializer='random_normal')(x)
    x = Conv1D(filters=32, kernel_size=7, padding='same', strides=1, kernel_initializer='random_normal')(x)
    x = MaxPooling1D(3, strides=1)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv1D(64, kernel_size=7, padding='same', strides=1, kernel_initializer='random_normal')(x)
    x = Conv1D(64, kernel_size=7, padding='same', strides=1, kernel_initializer='random_normal')(x)
    x = MaxPooling1D(3, strides=2)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # dilated convolutional blocks (rates 2, 4, 8, 16, 32) to widen the receptive field
    x = Conv1D(filters=64, kernel_size=5, padding='same', dilation_rate=2, kernel_initializer='random_normal')(x)
    x = MaxPooling1D(3, strides=1)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv1D(filters=128, kernel_size=5, padding='same', dilation_rate=4, kernel_initializer='random_normal')(x)
    x = MaxPooling1D(3, strides=1)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv1D(filters=128, kernel_size=5, padding='same', dilation_rate=8, kernel_initializer='random_normal')(x)
    x = MaxPooling1D(3, strides=1)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv1D(256, kernel_size=3, padding='same', dilation_rate=16, kernel_initializer='random_normal')(x)
    x = MaxPooling1D(3, strides=1)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv1D(filters=256, kernel_size=3, padding='same', dilation_rate=32, kernel_initializer='random_normal')(x)
    x = MaxPooling1D(3, strides=1)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Dropout(0.02)(x)

    # classification head
    Y = GlobalAveragePooling1D()(x)
    Y = Dense(256, activation='relu', kernel_initializer='random_normal')(Y)
    Y = Dropout(0.02)(Y)
    output_layer = Dense(nb_classes, activation='softmax', kernel_initializer='random_normal')(Y)

    model = Model(inputs=input_layer, outputs=output_layer)
    return model
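
# Quick shape check for the encoder (a sketch assuming TF2-era Keras; not
# called during the run): a batch of (window_length, 3) windows should map to
# nb_classes softmax probabilities.
def _check_encoder_shapes():
    m = encoder((window_length, 3), 2)
    assert m.output_shape == (None, 2)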

####################################

model = encoder((window_length, 3), 2)
# the model is only used for inference below, so the compile settings do not
# affect the predictions; the pretrained weights are loaded next
model.compile(loss='binary_crossentropy', optimizer=RMSprop(0.0001), metrics=['accuracy'])
#model.load_weights("model_0.h5")
model.load_weights("bestweights.h5")

results_model = np.zeros((len(selectedevents), 39))  # 39 macro-windows per 10-hour file
#selectedevents=selectedevents[0:2]
for i, w in enumerate(selectedevents):

    event_10h=obread(w)
    
    print('####################################################################')
    print('####################################################################')
    print('evaluating 10 hours of file', w, ' (nr.', i+1, ')')
    print(event_10h)
    stra = event_10h[0].data
    strb = event_10h[1].data
    strc = event_10h[2].data
    # keep the last 10 hours at 40 Hz: 10*3600*40 = 1,440,000 samples per component
    sta = np.expand_dims(stra[-1440000:], axis=-1)
    stb = np.expand_dims(strb[-1440000:], axis=-1)
    stc = np.expand_dims(strc[-1440000:], axis=-1)
    w_data = np.concatenate((sta, stb, stc), axis=1)
    # split the 10 hours into windows of 1000 s = 40k samples = ~16.7 min,
    # taking a new window every 15 min = 900 s = 36k samples (~1.7 min overlap)
    mawi = make_macrowindows(w_data)
    for k, mawik in enumerate(mawi):
        # split the macro-window into normalised sub-windows and evaluate the
        # model on the whole batch at once
        mawik_n = Preprocessing_one(mawik)
        estest = model(np.array(mawik_n))
        print('    ', k, estest)
        # *proba* = fraction of sub-windows whose argmax lands on class 1
        arg_maxes = np.argmax(estest, axis=1)
        estestest = len(np.where(arg_maxes == 1)[0]) / len(estest)
        # alternative soft score: estestest = np.average(estest[:, 1])
        print('    ', '*proba* value =', estestest)
        print('    ', '########################################################')
        results_model[i, k] = estestest
    # save this event's per-window probabilities under a name derived from the input file
    if data_base == 'new':
        w_name = w.strip()
        final_time = str(w_name).split('/')[1].split('.')[0]
        foot = 'results_' + final_time + '.npy'
    elif data_base == 'old':
        foot = 'results_' + w.split('/')[3][:-6] + '.npy'  # e.g. 'results_TEST0.npy'
    np.save(foot, results_model[i])
########################################
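
# Minimal sketch for inspecting one of the saved result files afterwards (the
# file name below is hypothetical; use any 'results_*.npy' written above).
# This is also where the non-interactive Agg matplotlib backend imported at
# the top becomes useful.
def plot_results(path='results_TEST0.npy'):
    probs = np.load(path)              # shape (39,): one proba per macro-window
    plt.figure()
    plt.plot(probs, marker='o')
    plt.xlabel('macro-window index (one per 15 min)')
    plt.ylabel('detection proba')
    plt.savefig(path.replace('.npy', '.png'))
    plt.close()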
