import numpy as np
from local_methods.Kamel_Zhao import Kamel_Zhao
from local_methods.Gatos import Gatos

import cv2


class Lantink:

    def __init__(self, image):
        """
        Initialization of the Lantink class, that thresholds using a method created by me, Anna Lantink.
        :param image: the image to be thresholded. It is stored as an int64 copy. The code below indexes
                      shape[0]/shape[1] as height/width, so a single-channel 2D image is presumably
                      expected (TODO confirm).
        """
        self.image = image.astype(np.int64)

    def _foreground_fraction(self, binary):
        """
        Fraction of pixels in `binary` that are foreground (value 0), relative to the image area.
        :param binary: binarized image with the same height/width as self.image.
        :return: float in [0, 1].
        """
        return np.count_nonzero(binary == 0) / (self.image.shape[0] * self.image.shape[1])

    @staticmethod
    def _window_radius(fraction_foreground):
        """
        Pick the neighbourhood radius used to grow the low-detail foreground.
        Sparser low-detail foreground gets a larger radius.
        :param fraction_foreground: foreground fraction of the cleaned low-detail image.
        :return: radius in pixels (15, 10 or 5).
        """
        if fraction_foreground < 0.02:
            return 15
        if fraction_foreground < 0.1:
            return 10
        return 5

    def binarize(self, T=-0.03):
        """
        Binarize self.image by combining a sparse ("low detail") and a dense ("high detail")
        Kamel_Zhao binarization.

        Foreground is 0 (black) and background is 255 throughout.

        :param T: base threshold passed to Kamel_Zhao.binarize. The other thresholds used are small
                  offsets from it. Defaults to -0.03, the previously hard-coded value.
        :return: np.uint8 array with the same shape as the Kamel_Zhao output.
        """
        cross = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))

        # Probe binarization at the base threshold. It is cleaned with MORPH_OPEN first; because the
        # foreground is 0, opening the image acts as a closing of the foreground.
        high_detail = Kamel_Zhao(self.image).binarize(T=T)
        probe = cv2.morphologyEx(high_detail.astype(np.uint8), cv2.MORPH_OPEN, cross)

        # Shift the thresholds based on how much foreground the probe found, so that both the low and
        # the high detail versions end up with a usable amount of foreground.
        fraction_foreground = self._foreground_fraction(probe)
        if fraction_foreground > 0.2:
            high_detail = Kamel_Zhao(self.image).binarize(T=T - 0.01)
            low_detail = Kamel_Zhao(self.image).binarize(T=T - 0.015)
        elif fraction_foreground < 0.05:
            low_detail = high_detail
            high_detail = Kamel_Zhao(self.image).binarize(T=T + 0.005)
        else:
            low_detail = Kamel_Zhao(self.image).binarize(T=T - 0.005)

        # Open the high detail version (closes foreground gaps). Close the low detail version: with a
        # 0 foreground, this removes small isolated foreground specks (noise).
        high_detail = cv2.morphologyEx(high_detail.astype(np.uint8), cv2.MORPH_OPEN, cross)
        low_detail = cv2.morphologyEx(low_detail.astype(np.uint8), cv2.MORPH_CLOSE, cross)

        window_radius = self._window_radius(self._foreground_fraction(low_detail))

        # Low detail is a sparse reading of the watermark: it should contain very little background,
        # but also only a weak foreground. Keep every high-detail pixel that lies inside the
        # (2r+1) x (2r+1) window (clipped at the image border) around any low-detail foreground
        # pixel. Everything else becomes background.
        # The union of those windows is exactly the dilation of the low-detail foreground mask with a
        # square kernel. One vectorized dilation therefore replaces a per-pixel Python loop.
        side = 2 * window_radius + 1
        seeds = (low_detail == 0).astype(np.uint8)
        near_seed = cv2.dilate(seeds, np.ones((side, side), np.uint8)).astype(bool)
        total_contrast = np.where(near_seed, high_detail, 255).astype(np.uint8)

        # Final clean-up depends on the amount of foreground. Since black is foreground, the
        # erosion actually works as a dilation of the foreground.
        if self._foreground_fraction(total_contrast) > 0.05:
            return cv2.morphologyEx(total_contrast, cv2.MORPH_OPEN, cross)
        return cv2.erode(total_contrast, cross)
