import numpy as np
import matplotlib.pyplot as plt
import cv2


class Mello_Costa:

    def __init__(self, image):
        """
        Initialization of the Mello_Costa class, that thresholds using the method described by Mello and
        Costa (2005).
        :param image: the image to be thresholded; its values are cast to int64.
        """
        self.image = image.astype(np.int64)
        # The transition search in binarize() only works if black is the foreground in the original image:
        # |x - 255| maps black (0) to 255 and white (255) to 0, so the foreground becomes the bright class.
        self.image = np.abs(self.image - 255)

    def binarize(self):
        """
        Implementation of thresholding as explained in Mello and Costa (2005).

        An initial threshold is computed from the entropies of the grey-level histogram up to and including
        its mode (H_b) and above it (H_w). It is then moved one grey level per iteration until the transition
        rate (see ``_transition``) lies in [0.89, 0.91]. At most 254 adjustments are made; if that cap is
        reached while the threshold has drifted to <= 0 or >= 255, the initial entropy-based threshold is
        used instead.
        :return: an array shaped like the input holding 255 where the inverted intensity exceeds the final
                 threshold and 0 elsewhere.
        """
        h_b, h_w = self._entropies()
        mw, mb = self._entropy_weights(h_b + h_w)

        # Initial threshold; also kept as the fallback used when the search below runs out of iterations.
        thresh = orig_thresh = int((mw * h_w + mb * h_b) * 256)

        # The image does not change during the search, so its grey levels are computed only once.
        unique_values = np.unique(self.image)

        # Adjusting the threshold based on the ROC curve
        transition = 0
        iterations = 0
        while transition < 0.89 or transition > 0.91:
            iterations += 1
            # This condition prevents infinite loops from happening. Infinite loops can happen in two cases.
            # One, the threshold infinitely gets increased or decreased or, two, the threshold bounces between
            # values because it's smaller than 0.89 once and larger than 0.91 another time.
            if iterations >= 255:
                if thresh <= 0 or thresh >= 255:
                    return np.where(self.image > orig_thresh, 255, 0)
                return np.where(self.image > thresh, 255, 0)

            # When no transition exists for this threshold the previous rate is kept, so a non-zero rate
            # keeps moving the threshold in the same direction (this mirrors the original loop).
            transition = self._transition(thresh, transition)

            # A zero rate needs special handling: if there are no grey levels below the threshold it must be
            # increased rather than decreased (see _zero_transition_step).
            if transition == 0:
                thresh += self._zero_transition_step(unique_values, thresh)
            elif transition > 0.91:
                thresh += 1
            elif transition < 0.89:
                thresh -= 1

        return np.where(self.image > thresh, 255, 0)

    def _entropies(self):
        """
        Compute the entropies of the 256-bin grey-level histogram on either side of its mode.

        The logarithm base is the number of pixels; a tiny epsilon keeps log() finite for empty bins.
        :return: tuple (H_b, H_w): entropy of bins 0..mode inclusive, and of bins above the mode.
        """
        hist, _ = np.histogram(self.image.ravel(), 256, [0, 256])
        n_pixels = self.image.size
        prob = np.divide(hist, n_pixels)
        # Most frequent intensity value in the image.
        mode = np.argmax(hist)
        epsilon = 0.000000001
        log_prob = np.emath.logn(n_pixels, prob + epsilon)
        h_b = -1 * np.sum(np.multiply(prob[:mode + 1], log_prob[:mode + 1]))
        h_w = -1 * np.sum(np.multiply(prob[mode + 1:], log_prob[mode + 1:]))
        return h_b, h_w

    @staticmethod
    def _entropy_weights(entropy):
        """
        Select the (mw, mb) weights applied to H_w and H_b from the total entropy H.
        :param entropy: H = H_b + H_w.
        :return: tuple (mw, mb).
        """
        if entropy <= 0.25:
            return 2, 3
        if entropy < 0.3:
            return 1, 2.6
        return 1, 1

    def _transition(self, thresh, previous):
        """
        Return the false-alarm rate at the point where the detection rate drops from 1 to 0.

        The original formulation builds, for every candidate t in 0..255,
        pd[t] = round(#(image > t and image == thresh) / #(image == thresh)) and
        pfa[t] = #(image > t and image != thresh) / #(image != thresh), and reads pfa where pd drops from 1
        to 0. pd[t] is 1 exactly when t < thresh (provided some pixel equals thresh), so the drop is at
        t == thresh and the rate reduces to #(image > thresh) / #(image != thresh), computed here directly.
        :param thresh: the current threshold.
        :param previous: the rate from the previous iteration.
        :return: the new rate, or ``previous`` unchanged when no drop exists (no pixel equals ``thresh``, or
                 ``thresh`` is outside 1..255).
        """
        if not 1 <= thresh <= 255:
            return previous
        n_1 = np.count_nonzero(self.image == thresh)
        if n_1 == 0:
            return previous
        n_0 = self.image.size - n_1
        # The small constant mirrors the original guard against division by zero.
        return np.count_nonzero(self.image > thresh) / (n_0 + 0.00001)

    @staticmethod
    def _zero_transition_step(unique_values, thresh):
        """
        Direction in which to move the threshold when the transition rate is 0.

        If no grey level lies below the threshold, or the only one is 0 (pure white in the original image),
        the threshold is increased; otherwise it is decreased.
        :param unique_values: sorted distinct grey levels of the (inverted) image.
        :param thresh: the current threshold.
        :return: +1 or -1.
        """
        below = unique_values[unique_values < thresh]
        if below.size == 0 or np.array_equal(below, [0]):
            return 1
        return -1
