import matplotlib.pyplot as plt
import numpy as np

class Visualisation:
    """Live plot of per-generation fitness statistics.

    Each call to ``update_plot`` appends one generation's max, mean, min and
    spread bounds (see ``process_data``) and redraws the figure, using
    matplotlib's interactive mode so the window updates while the caller keeps
    running. Call ``finalise_plot`` at the end to leave the window open.
    """

    # x position of the end-of-line text labels in axes-fraction units
    # (1.0 is the right edge of the plotting area).
    _LABEL_X = 1.01

    def __init__(self):
        plt.ion()  # Turn on interactive mode for dynamic plotting
        self.fig, self.ax = plt.subplots()
        # Per-generation history of each statistic; entry i belongs to the
        # i-th call of update_plot.
        self.max_fitness_list = []
        self.avg_fitness_list = []
        self.min_fitness_list = []
        self.std_top_list = []
        self.std_low_list = []
        # Shaded band between the two spread bounds; rebuilt on every update.
        self.fill_between_reference = None
        self.initialise_plot()

    def process_data(self, fitness_stats):
        """Compute summary statistics for one generation's fitness values.

        Parameters:
        - fitness_stats: a non-empty sequence (or NumPy array) of numeric
          fitness values, one per candidate.

        Returns a dict with keys 'max', 'mean', 'min', 'std_top' and
        'std_low'. 'std_top' is the mean plus the standard deviation of the
        values strictly above the mean; 'std_low' is the mean minus the
        standard deviation of the values strictly below it. If no value lies
        on one side of the mean, that bound collapses to the mean itself.

        NOTE(review): each one-sided std measures the spread of that subset
        around its *own* mean, not its distance from the overall mean (an
        upper semi-deviation would be sqrt(mean((x - mean)**2)) over the
        above-mean values). Confirm this is the intended band.

        Raises ValueError if fitness_stats is empty.
        """
        values = np.asarray(fitness_stats)
        if values.size == 0:
            # Fail clearly instead of a NumPy "mean of empty slice" warning
            # followed by a generic max() error.
            raise ValueError("fitness_stats must contain at least one value")

        mean = values.mean()
        above_mean = values[values > mean]
        below_mean = values[values < mean]

        # An empty side (e.g. all values equal) has no spread to add.
        std_top = mean + above_mean.std() if above_mean.size else mean
        std_low = mean - below_mean.std() if below_mean.size else mean
        return {
            'max': values.max(),
            'mean': mean,
            'min': values.min(),
            'std_top': std_top,
            'std_low': std_low,
        }

    @staticmethod
    def _y_limits(lowest, highest):
        """Return (bottom, top) y-axis limits that contain [lowest, highest].

        The axis stays anchored at 0 and gets 10% headroom on any side where
        the data extends past 0, so non-negative data yields (0, highest*1.1).
        An all-zero range is widened to (0, 1) so matplotlib never receives
        identical limits.
        """
        bottom = lowest * 1.1 if lowest < 0 else 0
        top = highest * 1.1 if highest > 0 else 0
        if bottom == top:
            top = 1
        return bottom, top

    def initialise_plot(self):
        '''
        Initial setup for the plot: axis limits and labels, one empty line per
        statistic, and one text label per statistic.
        '''
        self.ax.set_xlim(0, 10)  # Initial x-axis limit
        self.ax.set_ylim(0, 110)  # Initial y-axis limit; rescaled on each update
        self.ax.set_xlabel('Generation')
        self.ax.set_ylabel('Fitness Score')
        # One (initially empty) line per statistic, keyed like process_data's result
        self.lines = {
            'max': self.ax.plot([], [], 'k-', label='Max Fitness')[0],
            'mean': self.ax.plot([], [], 'k-', label='Average Fitness')[0],
            'min': self.ax.plot([], [], 'k-', label='Min Fitness')[0],
            'std_top': self.ax.plot([], [], 'r-', label='Std Top')[0],
            'std_low': self.ax.plot([], [], 'r-', label='Std Low')[0],
        }
        # End-of-line labels are used instead of a legend. Their x coordinate
        # is in axes fractions and their y coordinate in data units, so they
        # always sit just right of the plotting area whatever the x-limits are.
        label_transform = self.ax.get_yaxis_transform()
        self.text_labels = {
            'Max': self.ax.text(self._LABEL_X, 0, 'max', va='center', ha='left', transform=label_transform),
            'Avg': self.ax.text(self._LABEL_X, 0, 'μ', va='center', ha='left', transform=label_transform),
            'Min': self.ax.text(self._LABEL_X, 0, 'min', va='center', ha='left', transform=label_transform),
            'StdTop': self.ax.text(self._LABEL_X, 0, 'μ+σ', va='center', ha='left', transform=label_transform),
            'StdLow': self.ax.text(self._LABEL_X, 0, 'μ-σ', va='center', ha='left', transform=label_transform),
        }
        # Title this figure's axes explicitly rather than pyplot's "current" one.
        self.ax.set_title('Evolution of Fitness Over Generations')

    def update_plot(self, generation, fitness_stats):
        '''
        Append one generation's statistics and redraw the plot.

        Parameters:
        - generation: The current generation; it only sizes the x-axis. Data
          points are plotted at x = 0, 1, 2, ... in the order of the calls.
        - fitness_stats: A non-empty list of all the candidates' fitness values

        Raises ValueError (from process_data) if fitness_stats is empty.
        '''
        processed_data = self.process_data(fitness_stats)

        # Each plotted line paired with the history list that feeds it
        histories = {
            'max': self.max_fitness_list,
            'mean': self.avg_fitness_list,
            'min': self.min_fitness_list,
            'std_top': self.std_top_list,
            'std_low': self.std_low_list,
        }
        for key, history in histories.items():
            history.append(processed_data[key])

        # x positions are call indices starting at 0, up to len - 1
        x_values = range(len(self.std_low_list))

        # fill_between returns a new artist each time, so remove the old band
        # before drawing one that covers the latest data points.
        previous_fill = getattr(self, 'fill_between_reference', None)
        if previous_fill is not None:
            previous_fill.remove()
        self.fill_between_reference = self.ax.fill_between(
            x_values, self.std_low_list, self.std_top_list, color='red', alpha=0.5
        )

        for key, history in histories.items():
            self.lines[key].set_data(x_values, history)

        # For generation <= 1 pad the right x-limit so min and max are never
        # identical (matplotlib warns about a singular transformation).
        self.ax.set_xlim(0, generation + 1 if generation <= 1 else generation)
        self.ax.set_ylim(*self._y_limits(
            min(self.min_fitness_list + self.std_low_list),
            max(self.max_fitness_list + self.std_top_list),
        ))

        # Move each label to the latest value of its statistic (x is fixed
        # in axes coordinates by the transform set in initialise_plot).
        latest_values = {
            'Max': self.max_fitness_list[-1],
            'Avg': self.avg_fitness_list[-1],
            'Min': self.min_fitness_list[-1],
            'StdTop': self.std_top_list[-1],
            'StdLow': self.std_low_list[-1],
        }
        for label_key, y_value in latest_values.items():
            self.text_labels[label_key].set_position((self._LABEL_X, y_value))

        # Redraw this figure (not pyplot's current one), then briefly run the
        # GUI event loop so the update becomes visible.
        self.fig.canvas.draw_idle()
        plt.pause(0.1)

    def finalise_plot(self):
        '''
        Leave interactive mode and show the finished figure until the user
        closes it.
        '''
        plt.ioff()  # Turn off interactive mode
        plt.show()

# def test(generations):
#     visualiser = Visualisation()
#     for generation in range(generations):
#         # Simulate generation data with random integers
#         step_data = np.random.randint(0, 100, 100)
#         visualiser.update_plot(generation, step_data)
#     visualiser.finalise_plot()

# if __name__ == '__main__':
#     test(100)