import numpy as np
from keras import backend as K
import tensorflow as tf


def cross_entropy(y_true, y_pred):
    """Return the backend categorical cross-entropy of ``y_pred`` against ``y_true``.

    Thin adapter with the ``(y_true, y_pred)`` signature Keras expects for a
    ``loss=`` callable; it forwards both tensors unchanged to
    ``K.categorical_crossentropy`` using that function's default arguments.

    Args:
        y_true: target tensor passed straight through as the first argument.
        y_pred: prediction tensor passed straight through as the second argument.

    Returns:
        Whatever ``K.categorical_crossentropy`` returns for these inputs.
    """
    per_sample_ce = K.categorical_crossentropy(y_true, y_pred)
    return per_sample_ce


def total_loss(alpha, y_est):
    """Build a loss that blends cross-entropy to ``y_true`` and to a fixed ``y_est``.

    The returned function computes::

        alpha * mean(CE(y_true, y_pred)) + (1 - alpha) * mean(CE(y_est, y_pred))

    where ``CE(t, p) = -sum(t * log(p), axis=-1)`` and ``mean`` averages the
    per-row values into a single scalar.

    Args:
        alpha: weight of the ground-truth term; ``1 - alpha`` weights the
            ``y_est`` term.
        y_est: target captured once by the closure and reused for every call.
            NOTE(review): it is not sliced per batch, so it must broadcast
            against ``y_pred`` (e.g. same batch size) — confirm at the call site.

    Returns:
        ``loss(y_true, y_pred)`` returning a scalar tensor.
    """
    def loss(y_true, y_pred):
        # Keep probabilities strictly inside (0, 1): log(0) = -inf and
        # 0 * -inf = NaN, which would poison the loss and its gradients
        # whenever the prediction saturates.
        y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())
        # K.log works on both TF1 and TF2 backends (tf.log was removed in TF2).
        log_pred = K.log(y_pred)
        true_term = K.mean(-K.sum(y_true * log_pred, axis=-1))
        est_term = K.mean(-K.sum(y_est * log_pred, axis=-1))
        return alpha * true_term + (1 - alpha) * est_term
    return loss

def custom_loss(alpha, num_classes=10):
    """Build a loss whose ``y_true`` carries both ground truth and expert targets.

    ``y_true`` is expected to be the column-wise concatenation
    ``[truth | expert]``: the first ``num_classes`` columns are the
    ground-truth target and all remaining columns are the expert target.
    The returned function computes::

        alpha * mean(CE(truth, y_pred)) + (1 - alpha) * mean(CE(expert, y_pred))

    where ``CE(t, p) = -sum(t * log(p), axis=-1)``.

    Args:
        alpha: weight of the ground-truth term; ``1 - alpha`` weights the
            expert term.
        num_classes: number of leading ``y_true`` columns holding the ground
            truth (default 10, the previously hard-coded value).
            NOTE(review): the expert slice should also be ``num_classes`` wide
            so it matches ``y_pred`` — confirm against the data pipeline.

    Returns:
        ``loss(y_true, y_pred)`` returning a scalar tensor.
    """
    def loss(y_true, y_pred):
        truth = y_true[:, :num_classes]
        expert = y_true[:, num_classes:]
        # Keep probabilities strictly inside (0, 1): log(0) = -inf and
        # 0 * -inf = NaN, which would poison the loss and its gradients.
        y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())
        # K.log works on both TF1 and TF2 backends (tf.log was removed in TF2).
        log_pred = K.log(y_pred)
        truth_term = alpha * K.mean(-K.sum(truth * log_pred, axis=-1))
        expert_term = (1 - alpha) * K.mean(-K.sum(expert * log_pred, axis=-1))
        return truth_term + expert_term
    return loss
 
def custom_sparse_categorical_accuracy(y_true, y_pred):
    """Element-wise accuracy of integer labels against class-score predictions.

    Both inputs are flattened, so any leading dimensions (for example batch
    and time steps) are compared position by position.

    Args:
        y_true: integer class labels whose total element count equals the
            number of positions in ``y_pred`` excluding its last axis
            (e.g. ``(batch, steps)`` or ``(batch, steps, 1)``).
        y_pred: class scores with the classes on the last axis.

    Returns:
        A rank-1 float tensor holding 1.0 where ``argmax(y_pred)`` equals the
        label and 0.0 elsewhere.
    """
    # Both sides must be rank 1: an (N, 1) label tensor compared against
    # (N,) predicted labels broadcasts to an (N, N) matrix that compares every
    # label with every prediction, so its mean is not the accuracy.
    labels = K.cast(K.flatten(y_true), K.floatx())
    # argmax over the class axis, then flatten; equivalent to reshaping y_pred
    # to (-1, num_classes) first, but needs no static class count.
    predicted = K.cast(K.flatten(K.argmax(y_pred, axis=-1)), K.floatx())
    return K.cast(K.equal(labels, predicted), K.floatx())