import pandas as pd
import numpy as np
import sys
import os
import cv2
from sklearn.cluster import KMeans, DBSCAN
# from skimage.filters import threshold_otsu

def normalize(temp_matrix_flatted):
    """Min-max scale the values of an array into the range [0, 1].

    Args:
        temp_matrix_flatted: numeric array (callers pass a flattened
            ``(N, 1)`` temperature column).

    Returns:
        Array of the same shape where the minimum maps to 0 and the maximum
        to 1. When every value is identical (zero range) an all-zero float
        array is returned instead of dividing by zero, which would otherwise
        produce NaN/inf values.
    """
    min_temp = np.min(temp_matrix_flatted)
    gap = np.max(temp_matrix_flatted) - min_temp
    if gap == 0:
        # Constant input: nothing to scale, avoid 0/0 -> NaN.
        return np.zeros(np.shape(temp_matrix_flatted), dtype=float)
    return (temp_matrix_flatted - min_temp) / gap

def increase_resolution(image, scale_factor, interpolation=cv2.INTER_LANCZOS4):
    """Resize ``image`` by ``scale_factor`` in both directions with OpenCV.

    Args:
        image: array whose first two dimensions are (rows, columns).
        scale_factor: multiplier applied to both height and width; the
            resulting sizes are truncated to integers.
        interpolation: OpenCV interpolation flag (Lanczos 8x8 by default).

    Returns:
        The resized array produced by ``cv2.resize``.
    """
    rows, cols = image.shape[:2]
    # cv2.resize takes the target size as (width, height).
    target_size = (int(cols * scale_factor), int(rows * scale_factor))
    return cv2.resize(image, target_size, interpolation=interpolation)

def detail_enhancement(image):  # DDE
    """Sharpen ``image`` by convolving it with a 3x3 sharpening kernel.

    The kernel is the identity plus the negated 4-neighbour Laplacian
    (centre 5, direct neighbours -1), applied via ``cv2.filter2D`` with the
    output depth equal to the input depth (``ddepth=-1``).
    """
    sharpen_kernel = np.array([
        [0, -1, 0],
        [-1, 5, -1],
        [0, -1, 0],
    ])
    return cv2.filter2D(image, -1, sharpen_kernel)

def denoise_image(image, strength=20):  #DNR
    """Denoise a colour image with OpenCV's non-local-means filter.

    Args:
        image: colour image accepted by ``cv2.fastNlMeansDenoisingColored``.
        strength: filter strength, used for both the luminance (``h``) and
            colour (``hColor``) components.

    Returns:
        The denoised image.
    """
    template_window = 7
    search_window = 21
    return cv2.fastNlMeansDenoisingColored(
        image, None, strength, strength, template_window, search_window
    )

def mask_seq(temp_matrix, amb, mask, mode=0):
    """Relabel the clusters of ``mask`` according to their mean temperature.

    Clusters are ranked by ``mean(temp_matrix[cluster]) - amb`` (ascending).

    mode=0: replace every cluster label by its rank, so 0 is the coolest
            cluster and the hottest cluster gets the largest index.
    mode=1: if there are at least 2 clusters, paint the hottest cluster 70
            (red); all other labels are kept unchanged.
    mode=2: if there are at least 3 clusters, paint the hottest cluster 70
            (red) and the second hottest 80 (light red); all other labels are
            kept unchanged.

    Args:
        temp_matrix: temperature array with the same shape as ``mask``.
        amb: ambient temperature subtracted from each cluster mean for ranking.
        mask: integer label array. It is NOT modified.
        mode: 0, 1 or 2 as described above (any other value behaves like 0).

    Returns:
        A new ``uint8`` label array.
    """
    cluster_means = [
        (cluster, np.mean(temp_matrix[mask == cluster]))
        for cluster in np.unique(mask)
    ]
    cluster_means.sort(key=lambda item: item[1] - amb)

    if mode == 1:
        # Work on a copy and always test against the ORIGINAL labels, so the
        # caller's array is untouched and freshly written 70s cannot be
        # mistaken for a cluster that originally carried label 70.
        re_mask = mask.copy()
        if len(cluster_means) >= 2:
            re_mask[mask == cluster_means[-1][0]] = 70
    elif mode == 2:
        re_mask = mask.copy()
        if len(cluster_means) >= 3:
            re_mask[mask == cluster_means[-1][0]] = 70
            re_mask[mask == cluster_means[-2][0]] = 80
    else:
        re_mask = np.zeros_like(mask)
        for rank, (cluster, _) in enumerate(cluster_means):
            re_mask[mask == cluster] = rank

    return re_mask.astype(np.uint8)

def kmeans_mask(temp_matrix, amb, target_num, n_clusters=4, static_threashold = 0):
    """Cluster a 2-D temperature matrix with K-means and rank/highlight the clusters.

    Args:
        temp_matrix: 2-D temperature array.
        amb: ambient temperature, forwarded to ``mask_seq`` for ranking.
        target_num: forwarded to ``mask_seq`` as ``mode`` (0 = rank all
            clusters, 1 = highlight hottest as 70, 2 = also second hottest as 80).
        n_clusters: number of K-means clusters.
        static_threashold: static lower limit of the interesting range; pixels
            whose temperature is not above it are forced to label 0.

    Returns:
        ``uint8`` label mask with the same shape as ``temp_matrix``.
    """
    rows, cols = temp_matrix.shape
    # One feature per pixel: its min-max normalized temperature.
    features = normalize(temp_matrix.reshape(-1, 1))

    model = KMeans(n_clusters=n_clusters, random_state=0, n_init=1)
    model.fit(features)
    labels = model.labels_.reshape(rows, cols)

    ranked = mask_seq(temp_matrix=temp_matrix, amb=amb, mask=labels, mode=target_num)
    # 100 above the threshold (leaves every label <= 100 intact), 0 elsewhere.
    # NOTE(review): 0 is also a valid cluster label, so suppressed pixels are
    # indistinguishable from cluster 0 in the result.
    above_threshold = np.where(temp_matrix > static_threashold, 100, 0).astype(np.uint8)
    return np.minimum(ranked, above_threshold)

def dbscan_remask(mask, temp_matrix, amb,num_target, eps=5, min_samples=5):   # value=60 means the noise points
    """Split every cluster of ``mask`` into spatial sub-clusters with DBSCAN.

    Every pixel starts as noise (60). For each distinct label in ``mask``
    with at least ``min_samples`` pixels, DBSCAN is run on the (row, col)
    coordinates of those pixels; each DBSCAN sub-cluster receives a fresh
    sequential label, while DBSCAN noise (-1) and pixels of too-small labels
    stay 60. The result is then passed through ``mask_seq`` in mode 1 (hottest
    sub-cluster -> 70) and, when ``num_target == 2``, in mode 2 (also second
    hottest -> 80).

    Sequential labels skip the reserved values 60, 70 and 80, so a real
    sub-cluster is never merged with the noise label or pre-painted as a
    highlighted target.

    Args:
        mask: 2-D integer label array.
        temp_matrix: temperature array with the same shape as ``mask``.
        amb: ambient temperature forwarded to ``mask_seq``.
        num_target: 1 or 2 targets to highlight.
        eps, min_samples: DBSCAN parameters.

    Returns:
        ``uint8`` label mask.
    """
    reserved = (60, 70, 80)  # noise / target sentinel values used by this module
    h, w = mask.shape
    re_mask = np.full((h, w), 60, dtype=np.uint8)
    next_label = 0
    for cluster_label in np.unique(mask):
        rows, cols = np.nonzero(mask == cluster_label)
        if rows.size < min_samples:
            continue

        points = np.column_stack((rows, cols))
        dbscan = DBSCAN(eps=eps, min_samples=min_samples)
        sub_labels = dbscan.fit_predict(points)
        # DBSCAN numbers sub-clusters 0..k-1 and marks noise as -1; noise
        # pixels simply keep the 60 fill value.
        for sub_label in range(sub_labels.max() + 1):
            while next_label in reserved:
                next_label += 1
            members = sub_labels == sub_label
            # NOTE(review): re_mask is uint8, so more than ~250 sub-clusters in
            # total cannot be represented.
            re_mask[rows[members], cols[members]] = next_label
            next_label += 1

    re_mask = mask_seq(temp_matrix, amb, re_mask, mode=1)
    if num_target == 2:
        re_mask = mask_seq(temp_matrix, amb, re_mask, mode=2)
    return re_mask

def draw_b_w_mask(mask):    # B, G, R
    """Render a label mask as a colour image in OpenCV's B, G, R channel order.

    Special values: 60 -> cyan (noise points), 70 -> red (hottest target),
    80 -> light red (second target). Labels 0-9 get distinct colours. Every
    other label, including 10, is drawn light grey.

    Args:
        mask: 2-D integer label array.

    Returns:
        ``uint8`` array of shape ``(rows, cols, 3)``.
    """
    # All triples are B, G, R (the named colours were previously partly RGB).
    palette = {
        60: (255, 255, 0),    # Cyan
        70: (0, 0, 255),      # Red
        80: (100, 100, 255),  # Light Red
        0: (255, 255, 255),   # White
        1: (0, 0, 0),         # Black
        2: (0, 165, 255),     # Orange
        3: (128, 0, 128),     # Purple
        4: (0, 255, 255),     # Yellow
        5: (230, 216, 173),   # Light Blue
        6: (128, 128, 0),     # Teal
        7: (203, 192, 255),   # Pink
        8: (0, 140, 255),     # Dark Orange
        9: (139, 0, 0),       # Dark Blue
    }
    # Light grey is the default for any label not in the palette.
    new_image = np.full((mask.shape[0], mask.shape[1], 3), 200, dtype=np.uint8)
    for label, bgr in palette.items():
        new_image[mask == label] = bgr
    return new_image

def edge_detection(image, lower_limit=100, higher_limit=200):
    """Overlay Canny edges on a BGR image.

    Args:
        image: BGR image (converted with ``cv2.COLOR_BGR2GRAY`` for detection).
        lower_limit: lower hysteresis threshold for ``cv2.Canny``.
        higher_limit: upper hysteresis threshold for ``cv2.Canny``.

    Returns:
        Blend of 20% of the original image and 80% of the 3-channel edge map.
    """
    grey = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    edge_map = cv2.Canny(grey, lower_limit, higher_limit)
    # Replicate the single-channel edge map so it can be blended with BGR.
    edge_overlay = cv2.merge([edge_map] * 3)
    return cv2.addWeighted(image, 0.2, edge_overlay, 0.8, 0)

def add_temperature_labels(colorbar, color_min, color_max, num_labels=7):
    """Append a strip of evenly spaced temperature labels to the right of a colorbar.

    Args:
        colorbar: 3-channel image ``(height, width, 3)``.
        color_min: temperature written at the bottom label.
        color_max: temperature written at the top label.
        num_labels: number of labels, spaced with ``np.linspace``. A value of
            1 places a single ``color_min`` label at the bottom; 0 adds none.

    Returns:
        New ``uint8`` image ``(height, width + 130, 3)`` containing the
        colorbar on the left and white label text on a black strip.
    """
    # Create the free area for the labels.
    label_width = 130
    height, width, _ = colorbar.shape
    colorbar_with_labels = np.zeros((height, width + label_width, 3), dtype=np.uint8)
    colorbar_with_labels[:, :width] = colorbar

    temperature_values = np.linspace(color_min, color_max, num_labels)

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1
    thickness = 2
    # With a single label, num_labels - 1 would be 0 and divide by zero.
    steps = max(num_labels - 1, 1)

    for i, temp in enumerate(temperature_values):
        # i == 0 sits at the bottom row, the last label at the top.
        y_position = int(height - (i / steps) * height)
        text = f'{temp:.1f}C'

        # Place text labels.
        text_size = cv2.getTextSize(text, font, font_scale, thickness)[0]
        text_x = width + 20  # text's horizontal location inside the label strip
        # Centre the text on y_position, clamped so it stays inside the image.
        text_y = min(max(y_position + text_size[1] // 2, text_size[1]), height - 5)

        cv2.putText(colorbar_with_labels, text, (text_x, text_y), font, font_scale, (255, 255, 255), thickness)

    return colorbar_with_labels

def generate_thermal_image(temp_matrix, amb, scale_factor = 2, num_target=1):
    """Upscale a temperature matrix and segment it into a target mask.

    Args:
        temp_matrix: 2-D temperature matrix.
        amb: ambient temperature, forwarded to the clustering step.
        scale_factor: resolution multiplier passed to ``increase_resolution``.
        num_target: 1 or 2 interesting targets; forwarded to ``kmeans_mask``
            as ``target_num``.

    Returns:
        Tuple ``(mask, temp_matrix_high_res)``: the 5-cluster K-means label
        mask and the upscaled temperature matrix it was computed from.
    """
    upscaled = increase_resolution(temp_matrix, scale_factor=scale_factor)
    # NOTE: the optional DBSCAN refinement (dbscan_remask) is intentionally not applied.
    target_mask = kmeans_mask(
        temp_matrix=upscaled, amb=amb, target_num=num_target, n_clusters=5
    )
    return target_mask, upscaled

    
def get_temp_result(num_target, mask, temp_matrix, amb):
    """Return the mean temperature of a highlighted target region.

    Args:
        num_target: 1 or 2. With 2, the second target (label 80) is reported.
        mask: label array where 70 marks the hottest target and 80 the second.
        temp_matrix: temperature array with the same shape as ``mask``.
        amb: ambient temperature, returned when the region is missing.

    Returns:
        Mean of the 70 region rounded to 2 decimals (``num_target != 2``), or
        of the 80 region (``num_target == 2``, falling back to ``amb`` if no 80
        pixels exist). If no 70 pixels exist at all, a message is printed and
        ``amb`` is returned.
    """
    if 70 not in mask:
        print('No hot target!!!')
        return amb

    if num_target != 2:
        return round(np.mean(temp_matrix[mask == 70]), 2)

    if 80 in mask:
        return round(np.mean(temp_matrix[mask == 80]), 2)
    return amb
