# Code and data from the CPAIOR 2021 paper "Learning Variable Activity Initialisation for Lazy Clause Generation Solvers"
#
# Ronald van Driel et al.
# Delft University of Technology, NL
# April 2021
# copyright (c) 2020-2021 by R. van Driel et al.
# GNU General Public Licence 3
#
# contact: n.yorke-smith@tudelft.nl


import errno
import glob
import gzip
import ntpath
import os
import pickle
import random
import re
import subprocess
import sys

import numpy as np
from scipy.sparse import csr_matrix


# Root directory that every data path below is resolved against.
basePath = "./"
# GCN project directory, relative to basePath (its pickles/ and data/
# sub-directories are used by the main script).
gcnPath = "gcn/gcn/"


def _slashGlob(pattern):
    """Return the matches of glob *pattern* with backslashes turned into "/".

    Normalising Windows-style separators keeps the "/"-based path handling
    used later in this script working on every platform.
    """
    return [match.replace('\\', '/') for match in glob.glob(pattern)]


# Gzipped FlatZinc models used for training.
trainFlatZinc = _slashGlob(basePath + "FlatZincUnsat/*.fzn.gz")
# Gzipped FlatZinc models to predict on (one sub-directory level deep).
predictFlatZinc = _slashGlob(basePath + "FlatZincPaths/*/*.fzn.gz")
# Per-model text files for training and prediction.
trainFiles = _slashGlob(basePath + "TrainFiles/*.txt")
predictFiles = _slashGlob(basePath + "PredictFiles/*.txt")


## parses a gzipped fzn file
def parseFlatzinc(fznFile, textfile):
    """Parse a gzipped FlatZinc model and flag each constraint listed in *textfile*.

    The model is decompressed in memory. No ``.fzn`` file is written next to
    the archive or deleted afterwards, so concurrent runs on the same model
    cannot remove each other's input.

    :param fznFile: path to a gzip-compressed FlatZinc model (``*.fzn.gz``).
    :param textfile: text file of whitespace-separated integers; presumably
        the indices of the constraints in an unsatisfiable core (callers pass
        ``*_MUC*`` files). Index ``k`` refers to the ``k``-th ``constraint``
        line of the model, counting from 0.
    :return: ``False`` if *fznFile* does not exist; otherwise
        ``(valArrays, varArrays, varTypes, constraints, allVars)``:

        * ``valArrays``   -- parameter-array name -> element count
          (number of commas + 1)
        * ``varArrays``   -- variable-array name ->
          ``(count, type, min, max, [X_INTRODUCED_* element names])``
        * ``varTypes``    -- scalar variable name -> ``(type, min, max)`` with
          type ``"int"``, ``"bool"``, ``"float"`` or ``"set"``
        * ``constraints`` -- constraint index -> ``(index listed in textfile,
          sorted declared names mentioned by the constraint)``
        * ``allVars``     -- every declared scalar-variable and array name
    """
    print(fznFile + " " + textfile)
    if not os.path.exists(fznFile):
        return False

    with open(textfile) as txt:
        mucIds = {int(n) for line in txt for n in line.split()}

    with _openFzn(fznFile) as fzn:
        parsed = _parseFznLines(fzn, mucIds)
    print(fznFile + " done")
    return parsed


# Default bounds for unbounded declarations: the 32-bit signed int range and
# the largest finite float32 value.
_INT_MIN, _INT_MAX = -2147483648, 2147483647
_FLOAT_MAX = 3.4028234663852886e+38

# Explicit bounds are read from the text directly after "var ". This way an
# array's index set (the "1..10" in "array [1..10] of var 0..5") is never
# mistaken for its element domain.
_INT_RANGE = re.compile(r'(?<=var )(-?[0-9]+)\.\.(-?[0-9]+)')
_FLOAT_RANGE = re.compile(r'(?<=var )(-?[0-9]+\.[0-9]+)\.\.(-?[0-9]+\.[0-9]+)')
_SET_DOMAIN = re.compile(r'\{.*\}:')
_ARRAY_NAME = re.compile(r'(?<=: )[^\s]*(?=( .*|::.* )=)')
_VAR_NAME = re.compile(r'(?<=: )[^;:= ]*')
_INTRODUCED = re.compile(r'X_INTRODUCED_[0-9]*_')
# Whole identifiers. Matching tokens (not substrings) stops a variable named
# e.g. "q" from being "found" inside "int_lin_eq".
_IDENTIFIER = re.compile(r'[A-Za-z_][A-Za-z0-9_]*')


def _openFzn(fznFile):
    """Open a FlatZinc model as a text stream.

    *fznFile* is read through gzip. If it does not exist but the same path
    without its last extension does (``model.fzn`` for ``model.fzn.gz``),
    that uncompressed file is opened instead. Raises FileNotFoundError when
    neither exists.
    """
    if not os.path.exists(fznFile):
        plainFile = fznFile.rsplit(".", 1)[0]
        if os.path.exists(plainFile):
            return open(plainFile, encoding="utf-8")
    return gzip.open(fznFile, "rt", encoding="utf-8")


def _fznDomain(text):
    """Return ``(type, min, max)`` for the ``var`` declaration in *text*, or None."""
    if "var int" in text:
        return ("int", _INT_MIN, _INT_MAX)
    if "var bool" in text:
        return ("bool", 0, 1)
    if "var float" in text:
        return ("float", -_FLOAT_MAX, _FLOAT_MAX)
    match = _INT_RANGE.search(text)
    if match:
        return ("int", int(match.group(1)), int(match.group(2)))
    match = _FLOAT_RANGE.search(text)
    if match:
        return ("float", float(match.group(1)), float(match.group(2)))
    if "var {" in text:
        match = _SET_DOMAIN.search(text)
        if match:
            # Keep the sign: "{-2,4}" has minimum -2, not 2.
            numbers = [int(n) for n in re.findall(r'-?[0-9]+', match.group(0))]
            if numbers:
                return ("set", min(numbers), max(numbers))
    return None


def _parseFznLines(lines, mucIds=None):
    """Build the parse tables shared by parseFlatzinc and parseFlatzincPredict.

    :param lines: iterable of FlatZinc source lines.
    :param mucIds: set of 0-based constraint indices to flag True, or None to
        flag every constraint True.
    :return: ``(valArrays, varArrays, varTypes, constraints, allVars)``.
    """
    varArrays = {}
    valArrays = {}
    varTypes = {}
    constraints = {}
    allVars = set()
    count = -1
    for line in lines:
        if line.startswith("array"):
            match = _ARRAY_NAME.search(line)
            if match is None:
                print("Weird Array: ")
                print(line)
                continue
            arrayName = match.group(0)
            parts = line.split('=')  # _ARRAY_NAME guarantees an '='
            firstpart = parts[0]
            secondpart = parts[1].split('::')[0]
            total = secondpart.count(",") + 1
            allVars.add(arrayName)
            if "var " not in line:
                valArrays[arrayName] = total
                continue
            domain = _fznDomain(firstpart)
            if domain is None:
                print("Weird Array: ")
                print(line)
                continue
            variables = _INTRODUCED.findall(secondpart)
            varArrays[arrayName] = (total,) + domain + (variables,)
        elif line.startswith("var"):
            match = _VAR_NAME.search(line)
            if match is None:
                print("Weird line: ")
                print(line)
                continue
            varName = match.group(0)
            allVars.add(varName)
            domain = _fznDomain(line)
            if domain is None:
                print("Weird line: ")
                print(line)
                continue
            varTypes[varName] = domain
        elif line.startswith("constraint"):
            count += 1
            names = set(_IDENTIFIER.findall(line)) & allVars
            flag = mucIds is None or count in mucIds
            constraints[count] = (flag, sorted(names))
    return valArrays, varArrays, varTypes, constraints, allVars


def graphData(baseID, valArrays, varArrays, varTypes, constraints, allVars):
    """Build GCN node features, labels and adjacency for one parsed model.

    Every scalar variable in *varTypes* becomes a node with ID
    ``baseID + position`` (insertion order of *varTypes*). Two variables are
    adjacent when they occur in a common constraint. They can occur either
    directly or as an element of a variable array that the constraint
    references.

    :param baseID: ID of the first node; callers offset it per model to keep
        IDs distinct.
    :param valArrays: unused; kept for call compatibility.
    :param varArrays: var-array name -> 5-tuple whose last item lists element names.
    :param varTypes: scalar variable name -> ``(type, min, max)``.
    :param constraints: constraint key -> ``(flag, referenced names)``.
    :param allVars: unused; kept for call compatibility.
    :return: ``(features, labels, graph, len(varTypes), varIDName)``:
        ``features`` holds one ``[bool, int, float, set, min, max, max - min]``
        vector per variable. ``labels`` is ``[0, 1]`` if the variable occurs in
        a flagged constraint and ``[1, 0]`` otherwise. ``graph`` maps node ID
        to the set of neighbour IDs. ``varIDName`` maps node ID to variable name.
    """
    nodeID = {name: baseID + i for i, name in enumerate(varTypes)}

    # Invert the constraint table once (name -> keys of constraints that
    # mention it) instead of rescanning every constraint for every variable.
    constraintsOf = {}
    for key, (_, names) in constraints.items():
        for name in set(names):
            constraintsOf.setdefault(name, []).append(key)

    features = []
    labels = []
    graph = {}
    varIDName = {}
    for varName, (varType, minV, maxV) in varTypes.items():
        varID = nodeID[varName]
        varIDName[varID] = varName
        featureVector = [
            1.0 if varType == "bool" else 0.0,
            1.0 if varType == "int" else 0.0,
            1.0 if varType == "float" else 0.0,
            1.0 if varType == "set" else 0.0,
            float(minV),
            float(maxV),
            float(maxV - minV),
        ]
        neighbours = set()
        label = [1, 0]
        for key in constraintsOf.get(varName, ()):
            muc, names = constraints[key]
            if muc:
                label = [0, 1]
            for name in names:
                if name in varTypes and name != varName:
                    neighbours.add(nodeID[name])
                elif name in varArrays:
                    _, _, _, _, arrayVars = varArrays[name]
                    for arrayVar in arrayVars:
                        # Skip elements never declared as scalars (e.g. a
                        # "Weird line"), which have no node ID.
                        if arrayVar != varName and arrayVar in nodeID:
                            neighbours.add(nodeID[arrayVar])
        graph[varID] = neighbours
        labels.append(label)
        features.append(featureVector)
    return features, labels, graph, len(varTypes), varIDName


def parseFlatzincPredict(fznFile):
    """Parse a FlatZinc model whose constraint labels are unknown.

    Returns the same ``(valArrays, varArrays, varTypes, constraints, allVars)``
    tables as parseFlatzinc, except that every constraint is flagged ``True``.
    If *fznFile* is missing, the uncompressed ``.fzn`` next to it is read
    instead.

    :raises FileNotFoundError: if neither file exists.
    """
    print(fznFile)
    with _openFzn(fznFile) as fzn:
        parsed = _parseFznLines(fzn, None)
    print(fznFile + " done")
    return parsed


## main

def findFlatZinc(name, candidates):
    """Return the path in *candidates* whose file name is exactly ``name + ".fzn.gz"``.

    The whole base name is compared, not a substring, so looking for ``foo``
    never picks ``barfoo.fzn.gz``. Returns None when nothing matches.
    """
    target = name + ".fzn.gz"
    for path in candidates:
        if ntpath.basename(path) == target:
            return path
    return None


# Script entry point. It parses every training model (MUC-labelled, cached as
# pickles) and every prediction model. It then writes the combined features,
# labels and graph to <basePath><gcnPath>data/ind.flatzinc.*.
if __name__ == '__main__':
    numberOfFiles = len(trainFiles)
    numberOfPFiles = len(predictFiles)
    fileCount = 0  # next free node ID; IDs are unique across all models

    pickleDir = basePath + gcnPath + "pickles/"
    dataDir = basePath + gcnPath + "data/"
    predictionDir = basePath + "PredictFiles/Predictions/"
    for directory in (pickleDir, dataDir, predictionDir):
        os.makedirs(directory, exist_ok=True)

    trainFeatures = []
    trainLabels = []
    testFeatures = []
    testLabels = []
    graph = {}

    random.shuffle(trainFiles)
    for i, textfile in enumerate(trainFiles):
        nameMatch = re.search("(?<=Files/).*(?=_MUC)", textfile)
        if nameMatch is None:
            print("Skipping " + textfile + ": no _MUC suffix in file name")
            continue
        name = nameMatch.group(0)
        print(name)
        fznFile = findFlatZinc(name, trainFlatZinc)
        if fznFile is None:
            print("Skipping " + name + ": no matching .fzn.gz model")
            continue
        pickleFile = pickleDir + name + ".pkl"

        # Parsed models are cached; delete the pickle to force a re-parse.
        if not os.path.exists(pickleFile):
            print(pickleFile + " doesn't exist")
            parsed = parseFlatzinc(fznFile, textfile)
            with open(pickleFile, "wb") as f:
                pickle.dump(list(parsed), f)
        else:
            # NOTE: pickle.load can execute arbitrary code; only load caches
            # written by this script.
            with open(pickleFile, "rb") as f:
                parsed = pickle.load(f)
        valArrays, varArrays, varTypes, constraints, allVars = parsed

        features, labels, graph2, shift, _ = graphData(fileCount, valArrays, varArrays, varTypes, constraints, allVars)
        trainFeatures.extend(features)
        trainLabels.extend(labels)
        graph.update(graph2)
        fileCount += shift
        print(str(i), "training files parsed out of " + str(numberOfFiles))

    trainsize = len(trainLabels)
    print("---")
    print("training size " + str(trainsize))
    print("---")
    for i, textfile in enumerate(predictFiles):
        nameMatch = re.search("(?<=Files/).*(?=_MUC)", textfile)
        if nameMatch is None:
            print("Skipping " + textfile + ": no _MUC suffix in file name")
            continue
        name = nameMatch.group(0)
        print(name)
        fznFile = findFlatZinc(name, predictFlatZinc)
        if fznFile is None:
            print("Skipping " + name + ": no matching .fzn.gz model")
            continue
        valArrays, varArrays, varTypes, constraints, allVars = parseFlatzincPredict(fznFile)
        features, labels, graph2, shift, varIDName = graphData(fileCount, valArrays, varArrays, varTypes, constraints, allVars)
        testFeatures.extend(features)
        testLabels.extend(labels)
        graph.update(graph2)
        fileCount += shift
        # One "nodeID,variableName" line per variable of this model.
        with open(predictionDir + name + "_P.txt", "w") as classifierFile:
            for nodeID, varName in varIDName.items():
                classifierFile.write(str(nodeID) + "," + varName + "\n")
        print(str(i), "prediction files parsed out of " + str(numberOfPFiles))

    testsize = len(testLabels)
    print(testsize)

    x = csr_matrix(trainFeatures)
    tx = csr_matrix(testFeatures)
    y = np.array(trainLabels)
    ty = np.array(testLabels)
    # allx/ally are simply the training set again (x/y).
    outputs = {
        "x": x,
        "tx": tx,
        "allx": x,
        "y": y,
        "ty": ty,
        "ally": y,
        "graph": graph,
    }
    for suffix, obj in outputs.items():
        with open(dataDir + "ind.flatzinc." + suffix, "wb") as f:
            pickle.dump(obj, f)
    with open(dataDir + "ind.flatzinc.test.index", "w") as indexFile:
        indexFile.write(str(trainsize) + "\n")
        indexFile.write(str(testsize))

    print("done")
