import numpy as np
import cv2
from local_methods.Sauvola import Sauvola
import scipy


class Gatos:
    """
    Document binarization following Gatos et al. (2006).

    Pipeline (see ``binarize``): Wiener pre-filtering, rough foreground estimate via
    Sauvola, background-surface estimation, adaptive thresholding against that surface,
    then shrink/swell post-processing.
    """

    def __init__(self, image):
        """
        Initialization of the Gatos class, that thresholds using the method described by Gatos et al. (2006).
        :param image: the image to be thresholded.
        """
        self.image = image.astype(np.int64)

    def binarize(self, *, k=0.005, back_window=(25, 25), q=0.6, p1=0.5, p2=0.8, shrink_radius=3):
        """
        Binarize ``self.image``.

        All keyword arguments default to the values previously hard-coded here.

        :param k: Sauvola ``k`` used for the rough foreground estimate.
        :param back_window: (rx, ry) half-sizes of the window used to interpolate the background surface.
        :param q: Gatos ``q`` weighting of the average foreground/background distance.
        :param p1: Gatos ``p1`` sigmoid parameter.
        :param p2: Gatos ``p2`` sigmoid parameter.
        :param shrink_radius: half-size of the post-processing window.
        :return: float array of ``self.image``'s shape with values 0 or 255; 255 marks pixels
                 whose final label is 1 (see ``_threshold``/``_post_process``).
        """
        # Explicit submodule import: a bare ``import scipy`` does not guarantee that
        # ``scipy.signal`` is loaded on older SciPy releases.
        from scipy.signal import wiener

        # Pre-processing
        filtered_image = np.clip(wiener(self.image, (3, 3)), 0, 255)

        # Rough estimation of foreground regions. Assumes Sauvola returns 0/255 images;
        # after inversion S == 1 marks pixels Sauvola mapped to 0.
        S = (Sauvola(filtered_image).binarize(k=k) / 255).astype(np.int64)
        S = np.abs(S - 1)

        B = self._estimate_background(filtered_image, S, back_window)
        T = self._threshold(filtered_image, B, S, q=q, p1=p1, p2=p2)
        T = self._post_process(T, shrink_radius=shrink_radius)
        return T * 255

    @staticmethod
    def _box_sum(a, rx, ry):
        """
        Return, for each pixel, the sum of ``a`` over the (2*rx+1) x (2*ry+1) window centred on it.

        Window parts that fall outside the array contribute nothing (the window is clipped),
        computed with an integral image instead of per-pixel slicing.
        """
        h, w = a.shape
        # One extra leading zero row/column so the inclusion-exclusion below needs no special cases.
        padded = np.pad(a, ((rx + 1, rx), (ry + 1, ry)), mode="constant")
        c = padded.cumsum(axis=0).cumsum(axis=1)
        hr, hc = 2 * rx + 1, 2 * ry + 1
        return (c[hr:hr + h, hc:hc + w]
                - c[:h, hc:hc + w]
                - c[hr:hr + h, :w]
                + c[:h, :w])

    @staticmethod
    def _estimate_background(image, S, back_window=(25, 25)):
        """
        Estimate the background surface B.

        Where ``S == 0`` B equals the image; elsewhere B is the mean of the image over the
        pixels with ``S == 0`` inside the (2*rx+1) x (2*ry+1) window (clipped at the borders).
        A window containing no such pixel yields ~0 because of the 1e-6 guard.

        :param image: 2-D float image.
        :param S: 2-D integer array, 1 where the rough estimate marked foreground.
        :param back_window: (rx, ry) window half-sizes.
        :return: float64 array with ``image``'s shape.
        """
        rx, ry = back_window
        bg = 1 - S
        num = Gatos._box_sum(image * bg, rx, ry)
        den = Gatos._box_sum(bg, rx, ry)
        return np.where(S == 0, image, num / (den + 0.000001)).astype(np.float64)

    @staticmethod
    def _threshold(image, B, S, q=0.6, p1=0.5, p2=0.8):
        """
        Final adaptive thresholding: T = 1 where ``B - image`` exceeds the sigmoid threshold d(B).

        :return: float64 array of 0/1 with ``image``'s shape. All zeros when S marks no pixel
                 (or every pixel) as foreground, since delta / b are then undefined.
        """
        T = np.zeros(image.shape)
        n_fg = np.sum(S)
        n_bg = np.sum(1 - S)
        if n_fg == 0 or n_bg == 0:
            return T

        delta = np.sum(B - image) / n_fg   # average background/image distance over foreground
        b = np.sum(B * (1 - S)) / n_bg     # average background value
        # b may be 0 (dark background): let inf/nan propagate as the scalar loop did, silently.
        with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
            d_B = q * delta * (p2 + (1 - p2) / (1 + np.exp((-4 * B) / (b * (1 - p1))
                                                          + (2 * (1 + p1)) / (1 - p1))))
            T[(B - image) > d_B] = 1
        return T

    @staticmethod
    def _post_process(T, shrink_radius=3, centroid_tol=None):
        """
        Shrink/swell clean-up of a 0/1 image, scanned once in row-major order, in place on a copy.

        Each pixel at least ``shrink_radius`` away from the border is examined with the
        (2r+1) x (2r+1) window centred on it; counts are taken before the pixel is updated.

        NOTE(review): ``_threshold`` marks pixels darker than the background surface as 1, yet the
        steps below treat 0 as the class being shrunk/swollen (in Gatos et al. the shrink filter acts
        on foreground pixels). Polarity is kept as in the original implementation — confirm intent.

        :param T: 2-D array of 0/1 values (not modified).
        :param shrink_radius: window half-size r.
        :param centroid_tol: max |x - x_a|, |y - y_a| for the 2nd step; defaults to 0.25 * (2r+1).
                             NOTE(review): ratio taken from the paper's dx = dy = 0.25 n — verify.
        :return: post-processed copy of ``T``.
        """
        T = T.copy()
        r = shrink_radius
        side = 2 * r + 1
        k_sh = 7 * r
        k_sw = 3 * r
        k_sw1 = 7 * r
        tol = 0.25 * side if centroid_tol is None else centroid_tol

        rows, cols = T.shape
        for x in range(r, rows - r):
            for y in range(r, cols - r):
                # Centred window; astype() copies, so counts are frozen before T[x, y] changes.
                window = T[x - r: x + r + 1, y - r: y + r + 1].astype(np.uint8)
                zero_rows, zero_cols = np.nonzero(window == 0)
                P_sh = int(np.count_nonzero(window == 1))
                P_sw = zero_rows.size

                # 1st step
                if T[x, y] == 0 and P_sh > k_sh:
                    T[x, y] = 1
                # 2nd step: only when the 0-pixels are spread evenly around (x, y).
                if T[x, y] == 1 and P_sw > k_sw:
                    # Convert window-local indices to image coordinates before comparing.
                    x_a = x - r + zero_rows.mean()
                    y_a = y - r + zero_cols.mean()
                    if abs(x - x_a) < tol and abs(y - y_a) < tol:
                        T[x, y] = 0
                # 3rd step
                if T[x, y] == 1 and P_sw > k_sw1:
                    T[x, y] = 0
        return T

    # Upsampling code. Not used
    # def I_u(self, x_prime, y_prime, M, image):
    #     y = int(y_prime / M)
    #     b = (y_prime / M) - y
    #     return (-1 * b * ((1 - b) ** 2) * self.F(x_prime, max(0, y - 1), M, image)) + \
    #            ((1 - 2 * (b ** 2) + (b ** 3)) * self.F(x_prime, y, M, image)) + \
    #            (b * (1 + b - (b ** 2)) * self.F(x_prime, min(image.shape[1] - 1, y+1), M, image)) - \
    #            ((b ** 2) * (1 - b) * self.F(x_prime, min(image.shape[1] - 1, y + 2), M, image))
    #
    #
    # def F(self, x_prime, m, M, image):
    #     x = int(x_prime / M)
    #     a = (x_prime / M) - x
    #     return (-1 * a * ((1 - a) ** 2) * image[max(0, x - 1), m]) + \
    #            ((1 - 2 * (a ** 2) + (a ** 3)) * image[x, m]) + \
    #            (a * (1 + a - (a ** 2)) * image[min(image.shape[0] - 1, x + 1), m]) - \
    #            ((a ** 2) * (1 - a) * image[min(image.shape[0] - 1, x + 2), m])
