import numpy as np
import cv2
from global_methods.Otsu import Otsu
import matplotlib.pyplot as plt


class Bolan_Su:
    """
    Document image binarization following the method described by Bolan Su et al. (2013).

    ``binarize`` runs these stages:
      1. build an adaptive contrast image from each pixel's local max/min intensities;
      2. combine an Otsu binarization of that contrast image with its Canny edge map into a stroke map;
      3. estimate a thresholding window size from distances between stroke-edge transitions;
      4. threshold the stroke map locally with that window;
      5. post-process: drop single-pixel components, split symmetric neighbour pairs, morphological closing.
    """

    def __init__(self, image):
        """
        Initialization of the Bolan_Su class, that thresholds using the method described by Bolan Su et al. (2013).
        :param image: the image to be thresholded (a NumPy array; ``binarize`` requires it to be 2-D). It is stored
                      as int64 so that intensity sums and differences cannot overflow the input dtype.
        """
        self.image = image.astype(np.int64)

    def binarize(self, contrast_window=7, gamma=0):
        """
        Implementation of thresholding as explained in Bolan Su et al. (2013).
        :param contrast_window: the size of the window applied to get the contrast image. The window actually used
                                is 2 * (contrast_window // 2) + 1 pixels wide, clipped at the image border.
        :param gamma: exponent of alpha = (std(image) / 128) ** gamma, the weight given to the local contrast versus
                      the normalized local gradient. gamma=0 gives alpha=1, i.e. pure local contrast.
        :return: the binarized image, an int64 array of 0 and 255 values.
        :raises ValueError: if the image is not two-dimensional.
        """
        if self.image.ndim != 2:
            raise ValueError(f"Bolan_Su expects a 2-D (single-channel) image, got shape {self.image.shape}")

        alpha = (np.std(self.image) / 128) ** gamma

        # Adaptive contrast image, scaled to [0, 255].
        i_max, i_min = self._local_extrema(self.image, contrast_window // 2)
        contrast_image = self._adaptive_contrast(i_max, i_min, alpha)

        # Stroke map: 0 where the Otsu map of the contrast image is falsy AND Canny marks an edge, 255 elsewhere.
        # NOTE(review): this treats 0 as "stroke edge pixel", which presumes Otsu.binarize() returns a falsy value
        # for the high-contrast class — confirm against global_methods.Otsu.
        contrast_map = Otsu(contrast_image).binarize()
        edge_map = cv2.Canny(np.uint8(contrast_image), 0, 20)
        stroke_map = np.where(np.logical_or(contrast_map, edge_map != 255), 255, 0)

        # Thresholding window derived from the most frequent stroke width.
        window_size = self._estimate_window_size(stroke_map, self.image)
        binary_image = self._local_threshold(stroke_map, window_size // 2)

        # Post-processing. Both helpers mutate stroke_map / binary_image in place, in this order.
        _, labels = cv2.connectedComponents(np.uint8(stroke_map))
        self._remove_isolated_pixels(labels, stroke_map, binary_image)
        self._separate_symmetric_neighbours(stroke_map, self.image, binary_image)

        # Morphological closing stands in for the paper's single-pixel convexity/concavity/hole removal.
        closed = cv2.morphologyEx(binary_image.astype(np.uint8), cv2.MORPH_CLOSE,
                                  cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)),
                                  iterations=1)
        return 255 - closed.astype(np.int64)

    @staticmethod
    def _local_extrema(image, half_width):
        """
        Return (max, min) over each pixel's (2 * half_width + 1)-square neighbourhood, clipped at the image border.
        Uses grey-level dilation/erosion; with OpenCV's default border value, out-of-image pixels do not take part.
        :param image: 2-D integer image.
        :param half_width: half the window size.
        :return: two float64 arrays of the same shape as ``image``.
        """
        size = 2 * half_width + 1
        kernel = np.ones((size, size), np.uint8)
        # float64 holds the integer intensities exactly and is a depth cv2 morphology accepts.
        as_float = image.astype(np.float64)
        return cv2.dilate(as_float, kernel), cv2.erode(as_float, kernel)

    @staticmethod
    def _adaptive_contrast(i_max, i_min, alpha, epsilon=0.00001):
        """
        Mix local contrast and local gradient as in the paper's adaptive contrast, then scale to [0, 255].
        :param i_max: local maximum intensities.
        :param i_min: local minimum intensities.
        :param alpha: weight of the local contrast term; (1 - alpha) weights the gradient term.
        :param epsilon: keeps the contrast denominator non-zero where i_max + i_min == 0.
        :return: float64 array clipped to [0, 255].
        """
        i_max = np.asarray(i_max, dtype=np.float64)
        i_min = np.asarray(i_min, dtype=np.float64)
        gradient = i_max - i_min
        local_contrast = gradient / (i_max + i_min + epsilon)
        # The gradient is normalized to [0, 1] (8-bit intensities) so both terms share a scale; left in 0-255 units,
        # any alpha < 1 would saturate the contrast image at 255 after the final scaling.
        contrast = alpha * local_contrast + (1 - alpha) * (gradient / 255)
        return np.clip(contrast * 255, 0, 255)

    @staticmethod
    def _estimate_window_size(stroke_map, image, max_width=50):
        """
        Estimate the local-thresholding window size from stroke widths.

        Per row: find columns where a 0 stroke-map pixel is followed by a 255 one; keep those whose image intensity
        is >= that of the next such column; take the distance inside consecutive, non-overlapping pairs of kept
        columns.
        :param stroke_map: 2-D array of 0/255 values.
        :param image: 2-D intensity image of the same shape.
        :param max_width: distances beyond this are ignored.
        :return: 1 + the most frequent distance (1 when no distance was found).
        """
        distances = []
        for row_edges, row_intensity in zip(stroke_map, image):
            starts = np.flatnonzero((row_edges[:-1] == 0) & (row_edges[1:] == 255))
            kept = starts[:-1][row_intensity[starts[:-1]] >= row_intensity[starts[1:]]]
            paired = kept[: len(kept) // 2 * 2]
            distances.extend((paired[1::2] - paired[::2]).tolist())
        hist, _ = np.histogram(distances, max_width, [0, max_width])
        return int(np.argmax(hist)) + 1

    @staticmethod
    def _local_threshold(stroke_map, half_width):
        """
        Threshold each stroke-map pixel against mean + std / 2 of its (2 * half_width + 1)-square neighbourhood
        (clipped at the image border).
        :param stroke_map: 2-D array.
        :param half_width: half the window size.
        :return: int64 array with 0 where the pixel is <= its local threshold and 255 elsewhere.
        """
        # NOTE(review): the paper compares image intensities with statistics of the stroke-edge pixels in the
        # window; this compares stroke_map values with their own window statistics instead — confirm intent.
        rows, cols = stroke_map.shape
        binary = np.empty((rows, cols), dtype=np.int64)
        for i in range(rows):
            top, bottom = max(0, i - half_width), min(rows - 1, i + half_width) + 1
            for j in range(cols):
                window = stroke_map[top:bottom, max(0, j - half_width): min(cols - 1, j + half_width) + 1]
                threshold = np.mean(window) + np.std(window) / 2
                binary[i, j] = 0 if stroke_map[i, j] <= threshold else 255
        return binary

    @staticmethod
    def _remove_isolated_pixels(labels, stroke_map, binary_image):
        """
        In place, set to 0 every pixel of ``stroke_map`` and ``binary_image`` that forms a one-pixel connected
        component in ``labels`` (the label image from cv2.connectedComponents). Label 0, which that function uses
        for the zero-valued background, is never treated as a component.
        """
        sizes = np.bincount(labels.ravel())
        isolated = sizes == 1
        isolated[:1] = False  # slice so an empty label image does not raise
        mask = isolated[labels]
        stroke_map[mask] = 0
        binary_image[mask] = 0

    @classmethod
    def _separate_symmetric_neighbours(cls, stroke_map, image, binary_image):
        """
        For every non-255 stroke-map pixel, visited in row-major order, split its vertical and then its horizontal
        neighbour pair (when both neighbours are inside the image) into different classes of ``binary_image``.
        Mutates ``binary_image`` in place; later pixels may overwrite earlier assignments.
        """
        rows, cols = stroke_map.shape
        edge_rows, edge_cols = np.nonzero(stroke_map != 255)
        for r, c in zip(edge_rows, edge_cols):
            if 0 < r < rows - 1:
                cls._split_pair(stroke_map, image, binary_image, (r - 1, c), (r + 1, c))
            if 0 < c < cols - 1:
                cls._split_pair(stroke_map, image, binary_image, (r, c - 1), (r, c + 1))

    @staticmethod
    def _split_pair(stroke_map, image, binary_image, first, second):
        """
        If ``first`` and ``second`` fall on the same side of the stroke map (both 0 or both non-zero), give the
        lower-intensity one 0 and the other 255 in ``binary_image`` (on ties ``first`` gets 255).
        """
        # NOTE(review): the paper checks whether the pair shares a class in the binarized result; this checks the
        # stroke map instead — confirm intent.
        if (stroke_map[first] == 0) != (stroke_map[second] == 0):
            return
        if image[first] < image[second]:
            binary_image[first], binary_image[second] = 0, 255
        else:
            binary_image[first], binary_image[second] = 255, 0