import numpy as np
from global_methods.Kavallieratou import Kavallieratou


class Rao_et_al:

    def __init__(self, image):
        """
        Initialization of the Rao_et_al class, which thresholds using the method described by Rao et al. (2009).
        :param image: the image to be thresholded. A copy converted to int64 is stored on the instance.
        """
        self.image = image.astype(np.int64)

    def _windows(self, window_size):
        """
        Tiles the first two axes of the image into non-overlapping window_size x window_size windows, in row-major
        order. Windows on the bottom and right edges are truncated to the image bounds by slicing.
        :param window_size: the side length of each window (expected to be positive).
        :return: a generator of (row_slice, column_slice) pairs.
        """
        n_rows, n_cols = self.image.shape[0], self.image.shape[1]
        for top in range(0, n_rows, window_size):
            for left in range(0, n_cols, window_size):
                yield slice(top, top + window_size), slice(left, left + window_size)

    def binarize(self, window_size=31, k=0.01):
        """
        Implementation of thresholding as explained in Rao et al. (2009).
        Note: This implementation uses Kavallieratou's method as part of it.

        The image is first binarized globally with Kavallieratou's iterative method. It is then split into
        window_size x window_size windows. Every window whose number of black (0) pixels exceeds
        mean + k * std (taken over all windows) is re-binarized on its own with Kavallieratou's method. All other
        windows keep their globally binarized values.
        :param window_size: the size of the window applied to the local thresholding. The image is split into regions of
        the window size during local thresholding.
        :param k: a constant that determines the sensitivity of the local thresholding. The higher the value of k, the lower
        the sensitivity.
        :return: the binarized image, as a float64 array with the same shape as the input image.
        :raises ValueError: if window_size is smaller than 1.
        """
        if window_size < 1:
            # A negative step would make the tiling empty and silently yield an all-zero result.
            raise ValueError("window_size must be a positive integer")

        windows = list(self._windows(window_size))
        if not windows:
            # Empty image: there is nothing to threshold, and the window statistics below would be NaN.
            return np.zeros(self.image.shape)

        # First applies the global iterative threshold. A copy is passed so self.image is isolated from this pass.
        global_thresh = Kavallieratou(self.image.astype(np.int64))
        bin_image, iterations = global_thresh.binarize()

        # Counts the black pixels in each window once. The counts feed both the statistics and the
        # per-window decision below.
        num_black_pixels = np.array(
            [np.count_nonzero(bin_image[rows, cols] == 0) for rows, cols in windows]
        )
        mean = np.mean(num_black_pixels)
        std_div = np.std(num_black_pixels)
        threshold = mean + std_div * k

        # Applies the local thresholding step. Windows with more black pixels than the threshold are re-binarized
        # on their own with the global iterative method. All other windows are copied from the global result.
        local_image = np.zeros(self.image.shape)
        for (rows, cols), black_count in zip(windows, num_black_pixels):
            if black_count > threshold:
                thresh = Kavallieratou(self.image[rows, cols])
                # The iteration count from the global pass is forwarded to the local pass.
                new_window, _ = thresh.binarize(iterations)
                local_image[rows, cols] = new_window
            else:
                local_image[rows, cols] = bin_image[rows, cols]

        return local_image
