#!/usr/bin/env python3

import getopt
import html
import os
import re
import subprocess
import sys
from pathlib import Path

from graphviz import Digraph

class Program:
    """An input file loaded into memory as a list of lines.

    Attributes:
        name:  label given to the program by the caller.
        file:  path the lines were read from.
        lines: every line of the file, line endings included.
    """

    def __init__(self, name, file):
        self.name = name
        self.file = file

        # Read the whole file up front; consumers iterate over `lines`.
        with open(file) as handle:
            self.lines = handle.readlines()

class Parser:
    """Line-based state machine that extracts structure from textual LLVM IR.

    The parser walks ``program.lines`` once and collects the module id, the
    target triple, string metadata definitions (``metamap``) and, for every
    ``define`` line it recognises, a ``Function`` with its ``BasicBlock``
    objects and their ``Instruction`` objects.  Parsing is run from the
    constructor, so a constructed parser is already populated.
    """

    def __init__(self, program):
        self.program = program

        self.module_id = None
        self.target_triple = None

        self.functions = []
        self.current_function = None
        # Defined up front so the attribute exists even when the input
        # contains no function at all.
        self.current_basic_block = None

        # Maps a metadata key (e.g. '!5') to the string it is defined as
        self.metamap = {}

        # Possible states:
        #   - search-function-start
        #   - in-basic-block
        self.state = "search-function-start"

        self.parse()

    def parse(self):
        """Feed every line of the program to the state machine and print a summary."""
        for line in self.program.lines:
            self.parse_line(line.rstrip())

        print('ModuleID:', self.module_id)
        print('Target triple:', self.target_triple)

    def parse_line(self, line):
        """Dispatch a single line to the handler of the current state."""
        # Skip empty (whitespace-only) lines
        if not line.strip():
            return

        if self.state == "search-function-start":
            self.parse_search_function_start(line)
        elif self.state == "in-basic-block":
            self.parse_in_basic_block(line)
        else:
            print('Unrecognised state:', self.state)

    def parse_search_function_start(self, line):
        """Handle a line outside of any function body.

        Recognises the ModuleID comment, the target triple, string metadata
        definitions and function ``define`` lines; anything else is ignored.
        """
        # Match ModuleID
        match = re.match(r"^;\sModuleID = '(.*)'", line)
        if match:
            self.module_id = match.group(1)
            return

        # Match target_triple
        match = re.match(r'^target triple = "(.*)"', line)
        if match:
            self.target_triple = match.group(1)
            return

        # Match string metadata definition
        match = re.match(r'(!\S+)\s=\s!{!"(.*)"}', line)
        if match:
            self.metamap[match.group(1)] = match.group(2)
            return

        # Match Function define
        match = re.match(r'^define .* \S* @(\S*)\((.*)\) #(\S*) (.*) {', line)
        if not match:
            return

        func_name, func_args, func_operand, func_metadata = match.groups()
        self.start_function(line, func_name, func_operand, func_metadata)

        # The basic block operand depends on the number of arguments
        # We count the arguments and fill it in here.
        # If we run into a case where this does not work, we can also fill it in
        # When we process the BB lines if the BB.operand is None for example
        func_args_count = 0  # No arguments: the first block is operand 0
        if func_args != '':
            # Remove all brackets and their content, e.g., (i32, i32) as that's one arg
            flattened = re.sub(r'\(.*?\)[^,]+', '', func_args)
            # Now we can count the commas to determine the number of arguments
            func_args_count = flattened.count(',') + 1

        # Function also starts the first basic block (one after the args)
        self.start_basic_block(line, str(func_args_count), "")

        # Change the state of the FSM
        self.state = "in-basic-block"

    def parse_in_basic_block(self, line):
        """Handle a line inside a function body.

        The line is either a basic block label, the closing brace of the
        function, or (by default) an instruction of the current basic block.
        """
        # New basic block: "<label>:   ; <comment>"
        match = re.match(r'^(.*):\s+;\s*(.*)', line)
        if match:
            self.start_basic_block(line, match.group(1), match.group(2))
            return

        # Label without a trailing comment (presumably an explicitly named
        # entry block, which has no predecessors to list).  Instruction lines
        # always contain whitespace, so they cannot match this pattern.
        match = re.match(r'^(\S+):$', line)
        if match:
            label = match.group(1)
            block = self.current_basic_block
            blocks = self.current_function.basic_blocks
            if len(blocks) == 1 and block is blocks[0] and not block.instructions:
                # Nothing has been emitted into the implicit first block yet:
                # the label names that block, so later "preds = %<label>"
                # references resolve to it.
                block.operand = label
            else:
                self.start_basic_block(line, label, "")
            return

        # Match Function end
        if re.match(r'^}', line):
            self.stop_function()
            self.stop_basic_block()

            self.state = "search-function-start"
            return

        # Else it's an instruction
        self.current_basic_block.instructions.append(Instruction(line.lstrip()))

    def start_function(self, line, name, operand, metadata):
        """Create a ``Function``, register it and make it the current one."""
        func = Function(line, name, operand, metadata)
        self.functions.append(func)
        self.current_function = func

    def stop_function(self):
        """Mark that no function is currently being parsed."""
        self.current_function = None

    def start_basic_block(self, line, operand, metadata):
        """Create a ``BasicBlock`` in the current function and make it current.

        Predecessor names are taken from every ``%name`` token found in
        ``metadata`` (the comment that follows the label).
        """
        # Find the predecessors (preds= in the metadata)
        preds = re.findall(r'%([^,\s]+)', metadata)

        basic_block = BasicBlock(line, operand, preds, metadata)
        self.current_basic_block = basic_block
        # Add the BB to the current function
        self.current_function.basic_blocks.append(basic_block)

    def stop_basic_block(self):
        """Mark that no basic block is currently being parsed."""
        self.current_basic_block = None


class ControlFlowGraph:
    """Control flow graph built from a parsed program.

    The constructor runs the parser, links every basic block to its
    successor and predecessor blocks, and creates the ``Drawer`` used to
    render the result.
    """

    def __init__(self, program):
        self.program = program

        parser = Parser(program)
        self.parser = parser
        self.module_id = parser.module_id
        self.target_triple = parser.target_triple
        self.functions = parser.functions
        self.metamap = parser.metamap

        self.construct()

        # Construct the drawer to create .pdf files
        self.drawer = Drawer(self)

    def construct(self):
        """Build the graph edges from the information gathered by the parser."""
        self.find_successors()
        self.convert_pred_succ_string_to_ref()

    def find_successors(self):
        """Derive each block's successor names by inverting the predecessor lists."""
        for function in self.functions:
            # predecessor name -> names of the blocks that list it as predecessor
            successors_of = {}
            for block in function.basic_blocks:
                for pred_name in block.predecessors:
                    successors_of.setdefault(pred_name, []).append(block.operand)

            for block in function.basic_blocks:
                block.successors = successors_of.get(block.operand, [])

    def convert_pred_succ_string_to_ref(self):
        """Replace the textual block names in the edge lists by block objects.

        Raises KeyError when a name does not belong to a block of the same
        function.
        """
        for function in self.functions:
            # Index the blocks of this function by their operand
            blocks_by_name = {block.operand: block for block in function.basic_blocks}

            for block in function.basic_blocks:
                block.successors = [blocks_by_name[name] for name in block.successors]
                block.predecessors = [blocks_by_name[name] for name in block.predecessors]

class Instruction:
    """A single instruction line of a basic block.

    Attributes:
        line:            the instruction text as handed in by the parser.
        predecessor:     initialised to None; not set by this class.
        successor:       initialised to None; not set by this class.
        is_debug:        True when the line calls an ``@llvm.dbg`` function.
        instruction_str: ``line`` with alignment and ``!metadata``
                         attachments stripped, used for display.
    """

    def __init__(self, line):
        self.line = line
        self.predecessor = None
        self.successor = None
        self.is_debug = "call void @llvm.dbg" in line

        # Create a "bare" instruction string.  The patterns are raw strings
        # so that \s and \S reach the regex engine instead of being
        # (invalid) string escape sequences.
        bare = line
        bare = re.sub(r',*\salign\s\S+', '', bare)   # ", align 4"
        bare = re.sub(r',*\s!\S+\s\S+', '', bare)    # ", !dbg !12"
        bare = re.sub(r',*\s!\S+', '', bare)         # remaining lone "!name"
        self.instruction_str = bare

class BasicBlock:
    """A basic block: a label plus the instructions that follow it.

    Attributes:
        line:         source line that opened the block.
        operand:      name of the block.
        metadata:     text that accompanied the label.
        predecessors: predecessor list supplied by the caller (stored as is).
        successors:   starts empty.
        instructions: starts empty.
    """

    def __init__(self, line, operand, predecessors, metadata):
        # Values supplied by the caller
        self.line = line
        self.operand = operand
        self.metadata = metadata
        self.predecessors = predecessors

        # Filled in after construction
        self.successors = []
        self.instructions = []

class Function:
    """A function definition and the basic blocks collected for it.

    Attributes:
        line:         the ``define`` line of the function.
        name:         function name.
        operand:      operand string supplied by the caller.
        metadata:     metadata string supplied by the caller.
        basic_blocks: starts empty.
        signature:    "<type> <name>(<args>)" taken from ``line`` with every
                      '@' removed; falls back to ``name`` when ``line``
                      does not contain such a pattern.
    """

    def __init__(self, line, name, operand, metadata):
        self.line, self.name = line, name
        self.operand, self.metadata = operand, metadata

        self.basic_blocks = []

        self.instructions = []
        self.predecessors = []
        self.successors = []

        # Get the function signature, using the plain name as fallback
        found = re.search(r'\S+\s@\S+\(.*\)', line)
        self.signature = found.group(0).replace('@', '') if found else name

class Drawer:
    """Renders the functions of a control flow graph to .pdf files.

    Every selected function is rendered to its own pdf with Graphviz; the
    per-function files are then merged into one pdf with the external
    ``pdfunite`` tool.
    """

    def __init__(self, cfg):
        self.cfg = cfg

        self.output_dir = None
        self.output_file = None

        # Highlight columns; replaced on every call to draw()
        self.highlights = []

        self.function_files = None

    def draw(self, functions, output_dir, output_file, highlights, keep_intermediate=False):
        """Draw the selected functions and merge them into one pdf.

        Args:
            functions: names of the functions to draw; empty means all.
            output_dir: directory that receives all generated files.
            output_file: name of the merged pdf, placed in ``output_dir``.
            highlights: objects providing ``name`` and ``string(line)``; each
                one adds a column to every basic block table.
            keep_intermediate: keep the per-function pdf files after merging.
        """
        self.output_dir = output_dir
        self.output_file = output_file
        self.highlights = highlights
        self.function_files = []

        # Draw a function per pdf
        for f in self.cfg.functions:
            if not functions or f.name in functions:
                filepath = os.path.join(self.output_dir, self.function_filename(f))
                self.draw_function(f, filepath)
                self.function_files.append(filepath + '.pdf')

        if not self.function_files:
            print('No functions were drawn, nothing to merge')
            return

        # Combine the pdf files using pdfunite
        merged = False
        target = os.path.join(self.output_dir, self.output_file)
        try:
            result = subprocess.run(['pdfunite'] + self.function_files + [target])
        except OSError as e:
            # Raised when the pdfunite executable cannot be started
            print('Exception:', e)
            print('To output a merged .pdf of all the functions please install pdfunite')
        else:
            merged = result.returncode == 0
            if not merged:
                print('pdfunite failed with exit status', result.returncode)

        # Remove the function files, but only once their content is safely
        # in the merged pdf; otherwise they are the only output there is.
        if merged and not keep_intermediate:
            for f in self.function_files:
                try:
                    os.remove(f)
                except OSError:
                    print('Failed to remove function file:', f)

    def function_filename(self, function):
        """Return the file name (without extension) used for ``function``."""
        return 'func-' + function.name

    def escape_string(self, string):
        """Escape ``&``, ``<`` and ``>`` so ``string`` is literal text in an HTML-like label."""
        return html.escape(string, quote=False)

    def generate_label(self, basic_block):
        """Build the Graphviz HTML-like table label for ``basic_block``.

        The first row holds the block name and one header cell per highlight;
        every following row holds one non-debug instruction and the
        highlight text for it.
        """
        esc = self.escape_string

        parts = [
            '<',
            "<table border='0' cellborder='0' cellspacing='0' cellpadding='1'>",
            "<tr>",
            "<td colspan='1' border='1' align='left' color='grey' bgcolor='grey'> <b>%"
            + esc(basic_block.operand) + "</b></td>",
        ]

        # Add columns for the highlights
        for hl in self.highlights:
            parts.append("<td border='1' color='grey' bgcolor='grey' align='left'> <b>"
                         + esc(hl.name) + " </b></td>")
        parts.append('</tr>')

        # Build the table
        for instr in basic_block.instructions:
            if instr.is_debug:
                continue
            parts.append('<tr>')
            parts.append("<td bgcolor='white' align='left'> "
                         + esc(instr.instruction_str) + " </td>")
            for hl in self.highlights:
                parts.append("<td border='1' color='grey' bgcolor='white' align='left'> "
                             + esc(hl.string(instr.line)) + " </td>")
            parts.append('</tr>')

        # End the table
        parts.append('</table> >')

        return ''.join(parts)

    def draw_function(self, function, filename):
        """Render ``function`` to ``filename`` + '.pdf'."""
        print('Drawing:', function.name)
        parent_dot = Digraph()

        parent_dot.attr('graph', ranksep="1")
        parent_dot.attr('graph', splines="true")

        # The function is drawn as a cluster labelled with its signature
        dot = Digraph(name='cluster_' + function.name)

        dot.attr(label=function.signature)
        dot.attr('node', shape='record')
        dot.attr('node', fontname='Source Code Pro')
        dot.attr('graph', fontname='Source Code Pro')
        dot.attr(concentrate='true')

        # Add the nodes
        for bb in function.basic_blocks:
            label = self.generate_label(bb)
            dot.node(bb.operand, label=label, style='filled', fillcolor='white')

        # Add the edges
        for bb in function.basic_blocks:
            for successor in bb.successors:
                dot.edge(bb.operand, successor.operand)

        # Render the pdf
        parent_dot.subgraph(dot)
        parent_dot.render(filename)
        # render() leaves its source file at `filename`; only the pdf is wanted
        os.remove(filename)


# Parse the arguments
class Arguments:
    """Command line arguments of the script.

    Parses ``args`` (the argument list without the program name), validates
    that exactly one existing input file was given and derives the file name
    parts from it.  On ``-h``/``--help`` or on any error, usage or an error
    message is printed and the process exits via ``sys.exit``.
    """

    def __init__(self, args):
        # Provided
        self.filepath = None
        self.output_file = None

        # The functions to process (empty = ALL)
        self.functions = []

        # The metadata to highlight
        self.meta_highlights = []

        # Computed
        self.directory = None
        self.filename = None
        self.basename = None

        # Parse
        try:
            opts, args = getopt.getopt(
                args, 'ho:f:m:', ['help', 'output=', 'function=', 'meta='])
        except getopt.GetoptError:
            print('Error parsing arguments')
            self.usage()
            sys.exit(2)

        for opt, arg in opts:
            if opt in ('-h', '--help'):
                self.usage()
                sys.exit()
            elif opt in ('-o', '--output'):
                self.output_file = arg
            elif opt in ('-f', '--function'):
                self.functions.append(arg)
            elif opt in ('-m', '--meta'):
                self.meta_highlights.append(arg)

        # Exactly one positional argument is expected: the input file
        if not args:
            print('Input file required')
            self.usage()
            sys.exit(2)
        elif len(args) > 1:
            print('Only one input file is supported')
            self.usage()
            sys.exit(2)
        else:
            self.filepath = args[0]

        file_path = Path(self.filepath)
        if not file_path.is_file():
            print('Input file:', self.filepath, 'does not exist')
            sys.exit(2)

        # Get information from the filepath
        self.filename = str(file_path.name)
        self.directory = str(file_path.parent)
        self.basename = str(file_path.stem)

        # Default output name: the input file name with a .pdf extension
        if self.output_file is None:
            self.output_file = self.basename + '.pdf'

    def print(self):
        """Print all argument values, one per line."""
        print('Input file:', self.filepath)
        print('filename:', self.filename)
        print('basename:', self.basename)
        print('directory:', self.directory)
        print('Output file:', self.output_file)
        print('Functions:', self.functions)
        print('Meta highlights:', self.meta_highlights)

    def usage(self):
        """Print a short description of the command line."""
        print('llvmir2dot.py -o <output.pdf> <input.ll>')
        print('  -o, --output <file>     name of the output .pdf (default: <input basename>.pdf)')
        print('  -f, --function <name>   function to process, may be repeated (default: all)')
        print('  -m, --meta <name>       metadata to highlight, may be repeated')
        print('  -h, --help              show this message')

# name,regex
class MetaHighlight:
    """Looks up the value of one kind of metadata attachment on a line.

    Attributes:
        name:    metadata name without the leading '!' (e.g. 'dbg'); it is
                 inserted into a regular expression as is.
        metamap: maps metadata keys such as '!12' to their string value.
    """

    def __init__(self, name, metamap, title = None):
        self.name = name
        # Kept under a private name: an instance attribute called `title`
        # would shadow the title() method and make it uncallable.
        self._title = title
        self.metamap = metamap

    def title(self):
        """Return the explicit title, or the name when no title was given."""
        if self._title is not None:
            return self._title
        return self.name

    # Return the text to be displayed
    def string(self, line):
        """Return the text to show for ``line``.

        Looks for ``!<name> !<key>`` in ``line``.  Returns the ``metamap``
        entry for the key when there is one, the key itself when there is
        not, and an empty string when the line has no such attachment.
        """
        match = re.search(r'!' + self.name + r'\s(!\S+),?', line)
        if not match:
            return ''

        # \S+ also swallows a trailing comma, so drop commas from the key
        meta = match.group(1).replace(',', '')
        return self.metamap.get(meta, meta)

# Parse the command line and echo the resulting settings
args = Arguments(sys.argv[1:])
args.print()

# Process the IR: load the input file and build its control flow graph
p = Program(args.basename, args.filepath)
cfg = ControlFlowGraph(p)

# One highlight per metadata name requested with -m/--meta
meta_highlights = [MetaHighlight(h, cfg.metamap) for h in args.meta_highlights]

# Draw the requested functions into the directory of the input file
cfg.drawer.draw(args.functions, args.directory, args.output_file, meta_highlights)
