package lib.blocks.printing;

import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.Map.Entry;

import lib.asal.ASALContext;
import lib.asal.ASALDataType;
import lib.asal.ASALFunction;
import lib.asal.ASALVariable;
import lib.asal.parsing.ASALException;
import lib.asal.parsing.api.*;
import lib.behave.*;
import lib.behave.proto.*;
import lib.blocks.models.UnifyingBlock;
import lib.blocks.models.UnifyingBlock.*;

public class MCRL2Printer extends AbstractMCRL2Printer {
	
	/**
	 * Creates a printer for the given unifying block by delegating to the
	 * {@link AbstractMCRL2Printer} constructor.
	 *
	 * @param target the block model that is passed on to the superclass
	 * @throws ASALException propagated from the superclass constructor
	 */
	public MCRL2Printer(UnifyingBlock target) throws ASALException {
		super(target);
	}
	
	/**
	 * Prints an mCRL2 {@code sort} declaration named {@code name}.
	 * <p>
	 * With at least one item the sort becomes a {@code struct} whose alternatives are
	 * the items in iteration order (each item is printed via its string conversion).
	 * Without items only the bare sort name is declared.
	 *
	 * @param name  name of the sort to declare
	 * @param items alternatives of the struct; may be empty
	 */
	private void printStdStruct(String name, Collection<?> items) {
		println("sort");
		
		if (items.isEmpty()) {
			println("\t" + name + ";");
			return;
		}
		
		boolean isFirstAlternative = true;
		
		for (Object alternative : items) {
			if (isFirstAlternative) {
				// The first alternative shares its line with the sort name.
				println("\t" + name + " = struct " + alternative);
				isFirstAlternative = false;
			} else {
				println("\t\t| " + alternative);
			}
		}
		
		println("\t;");
	}
	
	/**
	 * Prints an mCRL2 {@code map} declaration {@code name: keyType -> itemType} and,
	 * when {@code items} is non-empty, one equation {@code name(key) = value;} per entry.
	 *
	 * @param name     name of the mapping
	 * @param keyType  mCRL2 sort of the mapping's argument
	 * @param itemType mCRL2 sort of the mapping's result
	 * @param items    entries turned into equations; keys and values are printed via
	 *                 their string conversion
	 */
	private void printStdMapping(String name, String keyType, String itemType, Map<? extends Object, ? extends Object> items) {
		String signature = name + ": " + keyType + " -> " + itemType + ";";
		
		if (items.isEmpty()) {
			// Nothing to define: a single-line declaration without an eqn section.
			println("map " + signature);
			return;
		}
		
		println("map");
		println("\t" + signature);
		println("eqn");
		
		for (Map.Entry<?, ?> equation : items.entrySet()) {
			println("\t" + name + "(" + equation.getKey() + ") = " + equation.getValue() + ";");
		}
	}
	
//	private void printStdList(String name, String itemType, Collection<?> items) {
//		if (items.size() > 0) {
//			println("map");
//			println("\t" + name + ": List(" + itemType + ");");
//			println("eqn");
//			
//			int lastIndex = items.size() - 1;
//			List<?> list = new ArrayList<Object>(items);
//			println("\t" + name + " = [");
//			
//			for (int index = 0; index < lastIndex; index++) {
//				println("\t\t\t" + list.get(index) + ",");
//			}
//			
//			println("\t\t\t" + list.get(lastIndex));
//			println("\t\t];");
//		} else {
//			println("map " + name + ": List(" + itemType + ");");
//			println("eqn " + name + " = [];");
//		}
//	}
	
	/**
	 * Prints a banner made of mCRL2 comment lines: a separator line, the header text,
	 * any additional lines, and a closing separator line.
	 *
	 * @param header        first line of text inside the banner
	 * @param optionalLines further lines printed below the header, in the given order
	 */
	protected void printHeader(String header, String... optionalLines) {
		println("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%");
		println("%% " + header);
		
		int extraLineCount = optionalLines.length;
		
		for (int lineIndex = 0; lineIndex < extraLineCount; lineIndex++) {
			println("%% " + optionalLines[lineIndex]);
		}
		
		println("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%");
	}
	
	/**
	 * Renders instructions as an mCRL2 list literal: the elements separated by
	 * {@code ", "} and wrapped in square brackets. An empty list yields {@code "[]"}.
	 *
	 * @param instructions already-rendered instructions, in output order
	 * @return the list literal
	 */
	private static String instructionsToStr(List<String> instructions) {
		return "[" + String.join(", ", instructions) + "]";
	}
	
	/**
	 * Translates an ASAL expression into an mCRL2 instruction list literal.
	 *
	 * @param sm       context handed to the {@code ASALA2mCRL2Visitor}
	 * @param expr     expression to translate; may be {@code null}
	 * @param fallback value returned unchanged when {@code expr} is {@code null}
	 * @return the rendered instruction list, or {@code fallback}
	 * @throws Error wrapping any {@link ASALException} raised by the visitor
	 */
	private String exprToStr(ASALContext sm, ASALExpr expr, String fallback) {
		if (expr != null) {
			ASALA2mCRL2Visitor translator = new ASALA2mCRL2Visitor(this, sm);
			
			try {
				List<String> instructions = translator.visitExpr(expr);
				return instructionsToStr(instructions);
			} catch (ASALException e) {
				throw new Error(e);
			}
		}
		
		return fallback;
	}
	
	/**
	 * Renders an ASAL literal as an mCRL2 {@code Value_*} constructor application.
	 * <ul>
	 * <li>BOOLEAN: {@code Value_Bool(<lower-cased text>)}</li>
	 * <li>NUMBER: {@code Value_Int(<text>)}</li>
	 * <li>STRING: {@code Value_String(<getName(lit)>)}</li>
	 * </ul>
	 *
	 * @param lit literal to render
	 * @return the mCRL2 value expression
	 * @throws Error if the literal has any other type
	 */
	private String literalToString(ASALLiteral lit) {
		switch (lit.getType()) {
		case BOOLEAN:
			// Locale.ROOT keeps the lower-casing independent of the JVM's default locale,
			// so the generated text is the same on every machine.
			return "Value_Bool(" + lit.getText().toLowerCase(Locale.ROOT) + ")";
		case NUMBER:
			return "Value_Int(" + lit.getText() + ")";
		case STRING:
			return "Value_String(" + getName(lit) + ")";
		default:
			throw new Error("Should not happen! Unsupported literal type: " + lit.getType());
		}
	}
	
	/**
	 * Renders the statement of a transition as an mCRL2 instruction list literal.
	 *
	 * @param sm context used for the translation
	 * @param t  transition whose statement is rendered; may be {@code null}
	 * @return {@code "[]"} when {@code t} is {@code null}, otherwise the rendering of
	 *         {@code t.getStatement()}
	 */
	private String statToStr(ASALContext sm, TritoTransition t) {
		return t == null ? "[]" : statToStr(sm, t.getStatement());
	}
	
	/**
	 * Translates an ASAL statement into an mCRL2 instruction list literal.
	 *
	 * @param sm   context handed to the {@code ASALA2mCRL2Visitor}
	 * @param stat statement to translate; may be {@code null}
	 * @return the rendered instruction list; {@code "[]"} when {@code stat} is
	 *         {@code null}
	 * @throws Error wrapping any {@link ASALException} raised by the visitor
	 */
	private String statToStr(ASALContext sm, ASALStatement stat) {
		if (stat == null) {
			// Every caller embeds the result where an instruction list is expected
			// (transition effects, function bodies, entry/exit actions). An empty string
			// would leave a hole in the generated text, so emit the empty list instead,
			// matching the TritoTransition overload.
			return "[]";
		}
		
		ASALA2mCRL2Visitor v = new ASALA2mCRL2Visitor(this, sm);
		
		try {
			return instructionsToStr(v.visitStat(stat));
		} catch (ASALException e) {
			throw new Error(e);
		}
	}
	
	//Helper class that pairs a component name with a port (ReprPort).
	//The component is stored as a plain string; getPortChannels() fills it either with
	//the result of getName(...) for the port's owner or with the literal "Environment".
		private static class ComponentPortPair {
			//Name of the component on this end of a communication
			final String component;
			//Port used on this end of a communication
			final ReprPort port;
			
			ComponentPortPair(String comp, ReprPort port) {
				this.component = comp;
				this.port = port;
			}
		}
		
		//Helper class to store communication labels: one sending end and the
		//receiving ends that take part in the same communication.
		private static class PortChannel {
			//The sending component/port
			public final ComponentPortPair sender;
			//We account for the possibility of multiple receivers by using a list
			public final List<ComponentPortPair> receivers;
			
			//Stores both arguments as given; the receivers list is not copied.
			public PortChannel(ComponentPortPair sender, List<ComponentPortPair> receivers) {
				this.sender = sender;
				this.receivers = receivers;
			}
		}
		
		//Computes the communication labels:
		// - one PortChannel per communication event of the target (source port -> target ports)
		// - one PortChannel per port for which target.isPortToEnvironment(...) holds, with
		//   "Environment" as sender (IN ports) or as the single receiver (OUT ports)
		private List<PortChannel> getPortChannels() {
			List<PortChannel> channels = new ArrayList<PortChannel>();

			//Communication between components
			for (ReprCommEvt commEvt : target.reprCommEvts) {
				ComponentPortPair sendingEnd = new ComponentPortPair(getName(commEvt.source.owner), commEvt.source);
				List<ComponentPortPair> receivingEnds = new ArrayList<ComponentPortPair>();
				
				for (ReprPort targetPort : commEvt.targets) {
					receivingEnds.add(new ComponentPortPair(getName(targetPort.owner), targetPort));
				}
				
				channels.add(new PortChannel(sendingEnd, receivingEnds));
			}
			
			//Communication with the environment
			for (ReprPort envPort : target.reprPorts) {
				if (!target.isPortToEnvironment(envPort)) {
					continue;
				}
				
				String senderName;
				String receiverName;
				
				switch (envPort.getDir()) {
					case IN: //the environment sends, the owner of the port receives
						senderName = "Environment";
						receiverName = getName(envPort.owner);
						break;
					case OUT: //the owner of the port sends, the environment receives
						senderName = getName(envPort.owner);
						receiverName = "Environment";
						break;
					default: //any other direction does not produce a channel
						continue;
				}
				
				//Both ends refer to the same port object
				List<ComponentPortPair> receivingEnds = new ArrayList<ComponentPortPair>();
				receivingEnds.add(new ComponentPortPair(receiverName, envPort));
				channels.add(new PortChannel(new ComponentPortPair(senderName, envPort), receivingEnds));
			}
			
			return channels;
		}
		
	/**
	 * Renders the trigger of a transition as an mCRL2 event term.
	 *
	 * @param event event to render; may be {@code null}
	 * @param sm    state machine used as context when translating expressions
	 * @return {@code "none"} for a {@code null} event, a {@code TimeoutEvent(...)},
	 *         {@code CallEvent(...)} or {@code TriggerEvent(...)} term for the
	 *         corresponding event classes, and {@code "UnknownEventType"} otherwise
	 */
	private String translateEvent(ASALEvent event, TritoStateMachine sm) {
		if (event == null) {
			return "none";
		}
		
		if (event instanceof ASALTimeout) {
			ASALTimeout timeout = (ASALTimeout) event;
			return "TimeoutEvent(" + exprToStr(sm, timeout.getDuration(), "") + ")";
		}
		
		if (event instanceof ASALCall) {
			ASALCall call = (ASALCall) event;
			return "CallEvent(" + call.getMethodName() + ")";
		}
		
		if (event instanceof ASALTrigger) {
			ASALTrigger trigger = (ASALTrigger) event;
			return "TriggerEvent(" + exprToStr(sm, trigger.getExpr(), "") + ")";
		}
		
		return "UnknownEventType";
	}
	
	/**
	 * Renders the given strings as an mCRL2 list literal: the elements separated by
	 * {@code ","} (no space) and wrapped in square brackets. An empty list yields
	 * {@code "[]"}.
	 *
	 * @param list already-rendered elements, in output order
	 * @return the list literal
	 */
	private String transformTomCRL2List(List<String> list) {
		return "[" + String.join(",", list) + "]";
	}
	
	/**
	 * Prints each of the given strings on its own line, in the given order.
	 *
	 * @param lines the lines to print
	 */
	private void printlines(String ...lines) {
		for (int lineIndex = 0; lineIndex < lines.length; lineIndex++) {
			println(lines[lineIndex]);
		}
	}
	
	/**
	 * Prints the complete mCRL2 specification for {@code target}.
	 * <p>
	 * Output order: the bundled {@code static.mcrl2} prelude, the name structs
	 * (String, StateName, CompName, VarName, FunctionName), the function
	 * parameter/body mappings, the per-state and per-component mappings, the parent
	 * relation, the initial state configuration / monitors / valuation, the
	 * {@code MessagingIntermediary} and {@code Environment} processes, and finally
	 * the {@code init} block.
	 * <p>
	 * Side effects: adds {@code "root"} to {@code stateNames} and
	 * {@code "Environment"} to {@code beqNames} before the corresponding structs are
	 * printed.
	 *
	 * @param mode   not used by this implementation
	 * @param object not used by this implementation
	 * @throws NullPointerException if the {@code static.mcrl2} resource cannot be found
	 */
	@Override
	protected void print(String mode, Object object) {
		String staticContent = "";
		// try-with-resources makes sure the resource stream is closed again.
		try (InputStream in = getClass().getResourceAsStream("static.mcrl2")) {
			if (in == null) {
				// getResourceAsStream reports a missing resource by returning null; fail
				// with an explicit message instead of an anonymous NullPointerException.
				throw new NullPointerException("Classpath resource static.mcrl2 not found (looked up via " + getClass().getName() + ")");
			}
			staticContent = new String(in.readAllBytes(), StandardCharsets.US_ASCII);
		} catch (IOException e) {
			// Best effort: report the problem and continue with an empty prelude.
			e.printStackTrace();
		}
		println(staticContent);
		
		printHeader("Struct containing all concrete strings:");
		printStdStruct("String", stringValues);
		
		printHeader("Struct containing all state names. The names are prefixed with the name of the statemachine (TODO is this required? or can we optimize?):");
		stateNames.add("root");
		printStdStruct("StateName", stateNames);
		
		beqNames.add("Environment");
		printHeader("Struct containing the names of all components/statemachines:");
		printStdStruct("CompName", beqNames);
		
		printHeader("Struct containing all port/variable names:");
		printStdStruct("VarName", varNames);
		
		printHeader("Functions:");
		println("% All function names:");
		printStdStruct("FunctionName", funcNames);
		
		Map<String, String> paramsPerFunc = new HashMap<String, String>();
		Map<String, String> bodyPerFunc = new HashMap<String, String>();
		
		for (UnifyingBlock.ReprStateMachine rsm : target.reprStateMachines) {
			for (Map.Entry<String, ASALFunction> entry : rsm.representedStateMachine.functions.entrySet()) {
				String fctName = getName(entry.getValue());
				// Parameter lists are always printed as empty.
				paramsPerFunc.put(fctName, "[]");
				bodyPerFunc.put(fctName, statToStr(entry.getValue().createContext(rsm.representedStateMachine), entry.getValue().getBody()));
			}
		}
		
		printStdMapping("getFunctionParams", "FunctionName", "List(VarName)", paramsPerFunc);
		printStdMapping("getFunctionBodies", "FunctionName", "Instructions", bodyPerFunc);
		
		Map<String, String> entryActionPerState = new HashMap<String, String>();
		Map<String, String> exitActionPerState = new HashMap<String, String>();
		Map<String, String> transitionsPerComp = new HashMap<String, String>();
		Map<String, String> portNamesPerComp = new HashMap<String, String>();
		Map<String, String> pulsePortNamesPerComp = new HashMap<String, String>();
		Map<String, String> finalVerticesPerComp = new HashMap<String, String>();
		Map<String, String> initialVerticesPerComp = new HashMap<String, String>();
		Map<String, String> junctionVerticesPerComp = new HashMap<String, String>();
		Map<String, String> compositeStatesPerComp = new HashMap<String, String>();
		Map<String, String> forkVerticesPerComp = new HashMap<String, String>();
		Map<String, String> joinVerticesPerComp = new HashMap<String, String>();
		Map<String, String> choiceVerticesPerComp = new HashMap<String, String>();
		
		for (ReprBlock rb : target.reprBlocks) {
			String component = getName(rb);
			List<String> transitions = new ArrayList<String>();
			List<String> portNames = new ArrayList<String>();
			List<String> pulsePortNames = new ArrayList<String>();
			List<String> joinVertices = new ArrayList<String>();
			List<String> forkVertices = new ArrayList<String>();
			List<String> junctionVertices = new ArrayList<String>();
			List<String> compositeStates = new ArrayList<String>();
			List<String> finalVertices = new ArrayList<String>();
			List<String> initialVertices = new ArrayList<String>();
			List<String> choiceVertices = new ArrayList<String>();
			
			if (rb.getOwnedStateMachine() != null) {
				TritoStateMachine sm = rb.getOwnedStateMachine().representedStateMachine;
				
				for (TritoVertex v : sm.vertices) {
					// Classify the vertex; the checks are independent, so a vertex class
					// that matches several of them is listed in each matching category.
					if (CompositeState.class.isAssignableFrom(v.getClz())) {
						compositeStates.add(getName(v));
					}
					if (InitialVertex.class.isAssignableFrom(v.getClz())) {
						initialVertices.add(getName(v));
					}
					if (ForkVertex.class.isAssignableFrom(v.getClz())) {
						forkVertices.add(getName(v));
					}
					if (JoinVertex.class.isAssignableFrom(v.getClz())) {
						joinVertices.add(getName(v));
					}
					if (JunctionVertex.class.isAssignableFrom(v.getClz())) {
						junctionVertices.add(getName(v));
					}
					if (FinalVertex.class.isAssignableFrom(v.getClz())) {
						finalVertices.add(getName(v));
					}
					if (ChoiceVertex.class.isAssignableFrom(v.getClz())) {
						choiceVertices.add(getName(v));
					}
					if (State.class.isAssignableFrom(v.getClz())) {
						String vertexName = getName(v);
						String key = component + ", " + vertexName;
						entryActionPerState.put(key, statToStr(sm, v.getOnEntry()));
						exitActionPerState.put(key, statToStr(sm, v.getOnExit()));

						// "Do" transitions are printed as self-loops with the last field set to true.
						// NOTE(review): the {{{bool}}} text in the guard fallback is emitted verbatim;
						// confirm that this is intended.
						for (TritoTransition t : v.getOnDo()) {
							transitions.add(
								"ProtoTransition(" +
								"\n\t\t\t\t" + getName(v) + ", % source internal" +
								"\n\t\t\t\t" + translateEvent(t.getEvent(),sm) + ", % trigger" +
								"\n\t\t\t\t" + exprToStr(sm, t.getGuard(), "[ASALA_PushValue({{{bool}}}(true))]") + ", % guard" +
								"\n\t\t\t\t" + statToStr(sm, t.getStatement()) + ", % effect" +
								"\n\t\t\t\t" + getName(v) + ", % target" +
								"\n\t\t\t\t true % target" +
								"\n\t\t\t)"
							);
						}
					} else {
						// Vertices that are not states get empty entry/exit actions.
						String key = component + ", " + getName(v);
						entryActionPerState.put(key, "[]");
						exitActionPerState.put(key, "[]");
					}
				}
				
				// Regular transitions are printed with the last field set to false.
				for (TritoTransition t : sm.transitions) {
					transitions.add(
						"ProtoTransition(" +
						"\n\t\t\t\t" + getName(t.getSourceVertex()) + ", % source" +
						"\n\t\t\t\t" + translateEvent(t.getEvent(),sm) + ", % trigger" +
						"\n\t\t\t\t" + exprToStr(sm, t.getGuard(), "[ASALA_PushValue({{{bool}}}(true))]") + ", % guard" +
						"\n\t\t\t\t" + statToStr(sm, t.getStatement()) + ", % effect" +
						"\n\t\t\t\t" + getName(t.getTargetVertex()) + ", % target" +
						"\n\t\t\t\t false % target" +
						"\n\t\t\t)"
					);
				}
			}
			
			// Blocks without a state machine still get (empty) entries in every mapping.
			finalVerticesPerComp.put(component, transformTomCRL2List(finalVertices));
			initialVerticesPerComp.put(component, transformTomCRL2List(initialVertices));
			junctionVerticesPerComp.put(component, transformTomCRL2List(junctionVertices));
			compositeStatesPerComp.put(component, transformTomCRL2List(compositeStates));
			forkVerticesPerComp.put(component, transformTomCRL2List(forkVertices));
			joinVerticesPerComp.put(component, transformTomCRL2List(joinVertices));
			choiceVerticesPerComp.put(component, transformTomCRL2List(choiceVertices));
			transitionsPerComp.put(component, transformTomCRL2List(transitions));
			
			// Port names are de-duplicated while keeping their first-seen order.
			for (ReprPort port : rb.ownedPorts) {
				if (!portNames.contains(getName(port))) {
					portNames.add(getName(port));
				}
				if (!pulsePortNames.contains(getName(port)) && port.getType() == ASALDataType.PULSE) {
					pulsePortNames.add(getName(port));
				}
			}
			portNamesPerComp.put(component, transformTomCRL2List(portNames));
			pulsePortNamesPerComp.put(component, transformTomCRL2List(pulsePortNames));
		}
		
		println("% The following mappings are never updated:");
		printStdMapping("entryAction", "CompName#StateName", "Instructions", entryActionPerState);
		printStdMapping("exitAction", "CompName#StateName", "Instructions", exitActionPerState);
		printStdMapping("protoTransitions", "CompName", "List(ProtoTransition)", transitionsPerComp);
		printStdMapping("ports", "CompName", "List(VarName)", portNamesPerComp);
		printStdMapping("pulsePorts", "CompName", "List(VarName)", pulsePortNamesPerComp);
		printStdMapping("initialStates", "CompName", "List(StateName)", initialVerticesPerComp);
		printStdMapping("compositeStates", "CompName", "List(StateName)", compositeStatesPerComp);
		printStdMapping("forkVertices", "CompName", "List(StateName)", forkVerticesPerComp);
		printStdMapping("joinVertices", "CompName", "List(StateName)", joinVerticesPerComp);
		printStdMapping("junctionVertices", "CompName", "List(StateName)", junctionVerticesPerComp);
		printStdMapping("finalStates", "CompName", "List(StateName)", finalVerticesPerComp);
		printStdMapping("choiceVertices", "CompName", "List(StateName)", choiceVerticesPerComp);
		
		println("%Parent relation");
		println("map");
		println("parent: CompName#StateName -> StateName;");
		println("eqn");
		for (ReprBlock rb : target.reprBlocks) {
			if (rb.getOwnedStateMachine() != null) {
				TritoStateMachine sm = rb.getOwnedStateMachine().representedStateMachine;
				for (TritoVertex v : sm.vertices) {
					if (v.getParentVertex() == null) {
						// Vertices without a parent are attached to the artificial "root" state.
						println__("parent("+getName(rb)+","+getName(v)+") = root;");
					} else {
						println__("parent("+getName(rb)+","+getName(v)+") = "+getName(v.getParentVertex())+";");
					}
				}
			}
		}
		
		Map<String, String> initialStateConfigPerComp = new HashMap<String, String>();
		Map<String, String> initialMonitorsPerComp = new HashMap<String, String>();
		// Collected in a builder; the equations of all blocks are printed in one go below.
		StringBuilder initialValuation = new StringBuilder();

		for (ReprBlock block : target.reprBlocks) {
			List<String> monitors = new ArrayList<String>();
			if (block.getOwnedStateMachine() != null) {
				ReprStateMachine m = block.getOwnedStateMachine();
				List<String> initialVertices = new ArrayList<String>();
				for (TritoVertex v : m.representedStateMachine.rootVertex.getInitialVertices()) {
					initialVertices.add("StateConfig(" + getName(v) + ", [])");
				}
				String initialStateConfig = "StateConfig(" + getName(m.representedStateMachine.rootVertex)
					+ "," + transformTomCRL2List(initialVertices) + ")";
				initialStateConfigPerComp.put(getName(block), initialStateConfig);
				
				// One monitor per transition that is triggered by an expression.
				for (TritoTransition t : m.representedStateMachine.transitions) {
					if (t.getEvent() instanceof ASALTrigger) {
						monitors.add("Monitor(" + exprToStr(m.representedStateMachine,((ASALTrigger) t.getEvent()).getExpr(),"") + ", false)");
					}
				}

				Set<Entry<String, ASALLiteral>> initializedVariables = m.representedStateMachine.initialValuation.entrySet();
				// Names that already received an explicit initial value.
				Set<String> allInitialized = new HashSet<String>();
				
				// Values given by the state machine itself
				for (Map.Entry<String,ASALLiteral> entry : initializedVariables) {
					initialValuation.append("initialValuation(" + getName(block) + ")(" + entry.getKey() + ") = ");
					initialValuation.append(literalToString(entry.getValue()) + ";\n");
					allInitialized.add(entry.getKey());
				}
				
				// Values given for environment ports owned by this block
				for (Map.Entry<ReprPort,ASALLiteral> entry : target.getInitValuesEnvPorts().entrySet()) {
					if (block == entry.getKey().owner) {
						initialValuation.append("initialValuation(" + getName(block) + ")(" + getName(entry.getKey()) + ") = ");
						initialValuation.append(literalToString(entry.getValue()) + ";\n");
						allInitialized.add(getName(entry.getKey()));
					}
				}
				
				// Every remaining variable gets a default value that depends on its type.
				Set<ASALVariable<?>> allVariables = new HashSet<ASALVariable<?>>();
				allVariables.addAll(m.representedStateMachine.stateMachineVars.values());
				allVariables.addAll(m.representedStateMachine.inPortVars.values());
				allVariables.addAll(m.representedStateMachine.outPortVars.values());
				for (ASALVariable<?> var : allVariables) {
					if (!allInitialized.contains(var.getName())) {
						String defaultValue = "";
						switch (var.getType()) {
							case VOID:
								defaultValue = "Value_None";
								break;
							case PULSE:
								defaultValue = "Value_Bool(false)";
								break;
							case BOOLEAN:
								defaultValue = "Value_Bool(false)";
								break;
							case NUMBER:
								defaultValue = "Value_Int(0)";
								break;
							case STRING:
								defaultValue = "Value_String(STR_)";
								break;
							default:
								defaultValue = "Value_None";
						}
						initialValuation.append("initialValuation(" + getName(block) + ")(" + var.getName() + ") = ");
						initialValuation.append(defaultValue + ";\n");
					}
				}
			}
			initialMonitorsPerComp.put(getName(block), transformTomCRL2List(monitors));
		}
		
		println("% The following mappings is never updated:");
		printStdMapping("initialStateConfig", "CompName", "StateConfig", initialStateConfigPerComp);
		
		println("% The following mapping is never updated:");
		printStdMapping("initialMonitors", "CompName", "List(Monitor)", initialMonitorsPerComp);
		
		printlines(
				"% The following mapping is never updated:",
				"map",
				"\tinitialValuation: CompName -> VarName -> Value;",
				"eqn");
		println(initialValuation.toString());
		
		println("proc");
		println__("MessagingIntermediary = delta");
		List<PortChannel> portChannels = this.getPortChannels();
		for (PortChannel pc : portChannels) {
			// One summand per channel: the send synchronises with all receives at once.
			String line = "+ (sum v:Value. ";
			line = line + "send_intermediate(" + pc.sender.component + "," + getName(pc.sender.port) + ",v)" ;
			for (ComponentPortPair receiver : pc.receivers) {
				line = line + "|receive_intermediate(" + receiver.component + "," + getName(receiver.port) + ",v)";
			}
			line = line + ".MessagingIntermediary)";
			println____(line);
		}
		println__(";");
		
		println__("Environment = delta");
		for (PortChannel pc : portChannels) {
			// Strings are compared with equals(); identity comparison only worked as long
			// as both sides happened to be the same interned literal.
			if ("Environment".equals(pc.sender.component)) {
				String valueRestriction;
				String extraValueRestriction;
				if (target.getEnvRestrictions().containsKey(pc.receivers.get(0).port)) {
					// Restrict the values the environment may send to the listed literals.
					extraValueRestriction = "&& v in ";
					List<String> valueOptions = new ArrayList<String>();
					for (ASALLiteral l : target.getEnvRestrictions().get(pc.receivers.get(0).port)) {
						valueOptions.add(this.literalToString(l));
					}
					extraValueRestriction += this.transformTomCRL2List(valueOptions);
				} else {
					extraValueRestriction = "";
				}
				switch (pc.sender.port.getType()) {
					case VOID:
						valueRestriction = "sum v:Value. (v == Value_None) ->";
						break;
					case PULSE:
						valueRestriction = "sum v:Value. (v == Value_Bool(true)"
								+ extraValueRestriction + ") ->";
						break;
					case BOOLEAN:
						// The disjunction is parenthesised because && binds tighter than || in
						// mCRL2; without the parentheses an appended restriction would only
						// constrain the Value_Bool(false) alternative.
						valueRestriction = "sum v:Value. ((v == Value_Bool(true) || v == Value_Bool(false))"
								+ extraValueRestriction + ") ->";
						break;
					case NUMBER:
						valueRestriction = "sum v:Value, i:Int. (v == Value_Int(i) && i == 1 "
								+ extraValueRestriction + ") ->";
						break;
					case STRING:
						valueRestriction = "sum v:Value, s:String. (v == Value_String(s) "
								+ extraValueRestriction + ") ->";
						break;
					default:
						valueRestriction = "";
				}
				println____("+ (" + valueRestriction + "sendP(Environment,"+getName(pc.sender.port)+",v).Environment)");
			}
			for (ComponentPortPair rec : pc.receivers) {
				if ("Environment".equals(rec.component)) {
					println____("+ (sum v:Value. receiveP(Environment,"+getName(rec.port)+",v).Environment)");
				}
			}
		}
		println__(";");
		
		printlines("% Initialization)",
		"init",
		"\tallow({discard_event,in_state,buffer_contains,fire_event,execute_code,start_transition,continue_transition,",
		"\t\tsend|receive,send|receive|receive,send|receive|receive|receive,send|receive|receive|receive|receive",
		"\t},",
		"\t\tcomm({",
		"\t\t\tsendP|send_intermediate -> send,",
		"\t\t\treceive_intermediate|receiveP -> receive",
		"\t\t},",
		"\t\t\tMessagingIntermediary||Environment");
		// One StateMachine process per block, composed in parallel with the two processes above.
		for (ReprBlock rb : target.reprBlocks) {
			String component = getName(rb);
			println______("||StateMachine(" + component + ",initialStateConfig("+component+"),[],initialMonitors("+component+"),"
					+ "initialValuation("+component+"),{},[], EmptyFctFrame, [])");
		}
		println("\t));");
	}
}





