package lib.behave.proto;

import java.lang.reflect.*;
import java.util.*;

import lib.asal.*;
import lib.behave.*;
import lib.blocks.common.ConnectionPt;
import lib.blocks.models.*;
import lib.utils.Dir;

/**
 * Prototype of a state machine. It collects the declarations of a {@link StateMachine}
 * subclass and of the block(s) the machine is meant for:
 * <ul>
 * <li>port variables (incoming/outgoing flows) derived from the blocks' primitive ports,</li>
 * <li>variables and functions declared as {@code public static final} fields,</li>
 * <li>vertices (the nested member classes), and</li>
 * <li>transitions between vertices.</li>
 * </ul>
 * The state hierarchy is validated during construction. Every vertex must be reachable;
 * states that cannot be left are reported as a warning. Validation failures are thrown as
 * {@link Error}s.
 */
public class ProtoStateMachine implements ASALContextDecls {
	/** The blocks whose primitive ports define the flows of this machine. */
	public final Set<Block> targetBlocks;
	/** The state-machine class that declares the variables, functions and vertices. */
	public final Class<? extends StateMachine> clz;
	/** One variable per distinct IN port name across all target blocks. */
	public final Map<String, Variable> inPortVars;
	/** One variable per distinct OUT port name across all target blocks. */
	public final Map<String, Variable> outPortVars;
	/** Variables declared as {@code public static final} fields of {@link #clz}, by field name. */
	public final Map<String, Variable> stateMachineVars;
	/** Functions declared as {@code public static final} fields of {@link #clz}, by field name. */
	public final Map<String, Function> functions;
	/** The vertex for {@link #clz} itself; root of the vertex hierarchy. */
	public final ProtoVertex rootVertex;
	/** All vertices, including {@link #rootVertex}, keyed by the class that defines them. */
	public final Map<Class<?>, ProtoVertex> vertices;
	/** Copy of the initial vertices declared directly inside {@link #rootVertex}. */
	public final Set<ProtoVertex> initialVertices;
	/**
	 * Transitions built from the vertices' incoming/outgoing declarations.
	 * Entry, do and exit behaviour is stored on the vertices themselves.
	 */
	public final Set<ProtoTransition> transitions;
	
	/**
	 * Creates the prototype of a state machine for a single block.
	 *
	 * @param clz the state-machine class to analyse
	 * @param targetBlock the block whose ports provide the flows
	 * @throws Error if the declarations or the state hierarchy are invalid
	 */
	public ProtoStateMachine(Class<? extends StateMachine> clz, Block targetBlock) {
		this(clz, Collections.singleton(targetBlock));
	}
	
	/**
	 * Creates the prototype of a state machine for one or more blocks.
	 * <p>
	 * Ports with the same name and direction on different blocks share one variable,
	 * provided that they have the same type.
	 *
	 * @param clz the state-machine class to analyse
	 * @param targetBlockCandidates the blocks whose ports provide the flows (duplicates are ignored)
	 * @throws Error if the declarations or the state hierarchy are invalid
	 */
	public ProtoStateMachine(Class<? extends StateMachine> clz, Collection<Block> targetBlockCandidates) {
		this.clz = clz;
		targetBlocks = new HashSet<Block>(targetBlockCandidates);
		
		//Flows that enter and leave the block(s):
		inPortVars = new HashMap<String, Variable>();
		outPortVars = new HashMap<String, Variable>();
		collectPortVars();
		
		//Variables and functions (port variables must be known to detect name clashes):
		stateMachineVars = new HashMap<String, Variable>();
		functions = new HashMap<String, Function>();
		collectFieldDecls();
		
		//Vertices:
		vertices = new HashMap<Class<?>, ProtoVertex>();
		rootVertex = new ProtoVertex(null, clz);
		addVertex(rootVertex); //This recursively adds child vertices, too!
		initialVertices = new HashSet<ProtoVertex>(rootVertex.initialVertices);
		
		//Entry/exit/do behaviour and transitions (all vertices must be known first):
		transitions = new HashSet<ProtoTransition>();
		
		for (ProtoVertex s : vertices.values()) {
			collectBehaviour(s);
		}
		
		checkReachability();
		checkLiveness();
	}
	
	/**
	 * Creates a variable in {@link #inPortVars} / {@link #outPortVars} for every primitive
	 * IN / OUT port of every target block. Ports with other directions are ignored.
	 *
	 * @throws Error if two blocks declare a same-named port of the same direction with
	 *         different types, or if inspecting the ports fails
	 */
	private void collectPortVars() {
		try {
			for (Block block : targetBlocks) {
				for (ConnectionPt cp : block.getOwnedPts()) {
					if (!(cp instanceof ConnectionPt.PrimitivePt)) {
						continue;
					}
					
					ConnectionPt.PrimitivePt pp = (ConnectionPt.PrimitivePt)cp;
					Dir dir = pp.getDir();
					Map<String, Variable> portVars;
					
					switch (dir) {
						case IN:
							portVars = inPortVars;
							break;
						case OUT:
							portVars = outPortVars;
							break;
						default:
							portVars = null;
							break;
					}
					
					if (portVars == null) {
						continue;
					}
					
					String portName = pp.field.getName();
					Variable existing = portVars.get(portName);
					
					if (existing == null) {
						portVars.put(portName, new Variable(pp.getType()));
					} else if (existing.getType() != pp.getType()) {
						throw new Error("Blocks define different types/directions for port named " + portName + "!");
					}
				}
			}
		} catch (Exception e) {
			//Errors raised above are not Exceptions and propagate unchanged.
			throw new Error(e);
		}
	}
	
	/**
	 * Registers every public field of {@link #clz} as a state-machine variable or a function.
	 *
	 * @throws Error if a field is not {@code public static final}, if its value is neither a
	 *         {@link Variable} nor exactly a {@link Function} (a {@code null} value also fails),
	 *         or if a variable has the same name as a port
	 */
	private void collectFieldDecls() {
		for (Field field : clz.getFields()) {
			int mod = field.getModifiers();
			
			if (!(Modifier.isFinal(mod) && Modifier.isPublic(mod) && Modifier.isStatic(mod))) {
				throw new Error("Could not categorize field " + field.getName() + "!");
			}
			
			Object fieldValue;
			
			try {
				fieldValue = field.get(null);
			} catch (IllegalArgumentException | IllegalAccessException e) {
				throw new Error(e);
			}
			
			if (fieldValue instanceof Variable) {
				if (outPortVars.containsKey(field.getName())) {
					throw new Error("Block already has an outgoing flow named " + field.getName() + "!");
				}
				
				if (inPortVars.containsKey(field.getName())) {
					throw new Error("Block already has an incoming flow named " + field.getName() + "!");
				}
				
				stateMachineVars.put(field.getName(), (Variable)fieldValue);
			} else if (fieldValue != null && fieldValue.getClass() == Function.class) {
				functions.put(field.getName(), (Function)fieldValue);
			} else {
				throw new Error("Could not interpret value of field " + field.getName() + "!");
			}
		}
	}
	
	/**
	 * Instantiates the class of {@code s} through its public no-argument constructor and
	 * records the vertex's behaviour:
	 * <ul>
	 * <li>for {@link State}s, the entry/do/exit code, stored on {@code s};</li>
	 * <li>for every vertex, its incoming/outgoing transitions, added to {@link #transitions}.</li>
	 * </ul>
	 *
	 * @throws Error if the class cannot be instantiated, a reflective call fails, or a
	 *         transition refers to a class that is not a known vertex
	 */
	private void collectBehaviour(ProtoVertex s) {
		try {
			//newInstance() is called without arguments, so only the no-arg constructor can work.
			Vertex state = (Vertex)s.clz.getConstructor().newInstance();
			
			if (State.class.isAssignableFrom(s.clz)) {
				//onEntry
				LocalTransition onEntry = (LocalTransition)State.class.getMethod("onEntry").invoke(state);
				
				if (onEntry != null) {
					s.onEntry = new ProtoTransition(null, s, onEntry.getCode(), true);
				}
				
				//onDo
				LocalTransition[] onDo = (LocalTransition[])State.class.getMethod("onDo").invoke(state);
				
				if (onDo != null) {
					for (LocalTransition o : onDo) {
						s.onDo.add(new ProtoTransition(s, s, o.getCode(), true));
					}
				}
				
				//onExit
				LocalTransition onExit = (LocalTransition)State.class.getMethod("onExit").invoke(state);
				
				if (onExit != null) {
					s.onExit = new ProtoTransition(s, null, onExit.getCode(), true);
				}
			}
			
			//Incoming transitions (the cast above guarantees that state is a Vertex):
			Incoming[] incoming = (Incoming[])Vertex.class.getMethod("getIncoming").invoke(state);
			
			if (incoming != null) {
				for (Incoming i : incoming) {
					ProtoVertex otherState = vertices.get(i.getSourceState());
					
					if (otherState == null) {
						throw new Error("Unknown state (" + i.getSourceState().getCanonicalName() + ")!");
					}
					
					transitions.add(new ProtoTransition(otherState, s, i.getCode(), false));
				}
			}
			
			//Outgoing transitions:
			Outgoing[] outgoing = (Outgoing[])Vertex.class.getMethod("getOutgoing").invoke(state);
			
			if (outgoing != null) {
				for (Outgoing o : outgoing) {
					ProtoVertex otherState = vertices.get(o.getTargetState());
					
					if (otherState == null) {
						throw new Error("Unknown state (" + o.getTargetState().getCanonicalName() + ")!");
					}
					
					transitions.add(new ProtoTransition(s, otherState, o.getCode(), false));
				}
			}
		} catch (NoSuchMethodException | SecurityException | InstantiationException | IllegalAccessException
				| IllegalArgumentException | InvocationTargetException e) {
			throw new Error(e);
		}
	}
	
	/**
	 * Adds {@code parentVertex} to {@link #vertices}. For a {@link CompositeState}, it also
	 * recursively adds every member class as a child vertex.
	 *
	 * @throws Error if a composite state ends up without an {@link InitialVertex} child, or
	 *         if a non-composite vertex declares member classes
	 */
	private void addVertex(ProtoVertex parentVertex) {
		vertices.put(parentVertex.clz, parentVertex);
		
		if (CompositeState.class.isAssignableFrom(parentVertex.clz)) {
			for (Class<?> childClass : parentVertex.clz.getDeclaredClasses()) {
				addChildClass(parentVertex, childClass);
			}
			
			if (parentVertex.initialVertices.isEmpty()) {
				throw new Error("Composite states and descendants should have " + InitialVertex.class.getCanonicalName() + "!");
			}
		} else if (parentVertex.clz.getDeclaredClasses().length > 0) {
			throw new Error("Only state-machines and composite states can define internal states!");
		}
	}
	
	/**
	 * Validates {@code childClass} and adds it, with its own descendants, as a child vertex
	 * of {@code parentVertex}. Children that are {@link InitialVertex}es are also recorded as
	 * initial vertices of the parent.
	 *
	 * @throws Error if {@code childClass} is not a public static class, extends {@link Vertex}
	 *         directly, is not a {@link Vertex}, or is a {@link StateMachine}
	 */
	private void addChildClass(ProtoVertex parentVertex, Class<?> childClass) {
		int mod = childClass.getModifiers();
		
		if (!(Modifier.isPublic(mod) && !Modifier.isInterface(mod) && Modifier.isStatic(mod))) {
			//Report the offending member class, not the class that declares it.
			throw new Error("Could not categorize class " + childClass.getCanonicalName() + "!");
		}
		
		if (childClass.getSuperclass() == Vertex.class) {
			throw new Error("Class should NOT be an immediate subclass of " + Vertex.class.getCanonicalName() + "!");
		}
		
		if (!Vertex.class.isAssignableFrom(childClass)) {
			throw new Error("Class should be a subclass of " + Vertex.class.getCanonicalName() + "!");
		}
		
		if (StateMachine.class.isAssignableFrom(childClass)) {
			throw new Error("Internal state cannot be a " + StateMachine.class.getCanonicalName() + "!");
		}
		
		ProtoVertex childVertex = new ProtoVertex(parentVertex, childClass);
		
		if (InitialVertex.class.isAssignableFrom(childClass)) {
			parentVertex.initialVertices.add(childVertex);
		}
		
		addVertex(childVertex);
	}
	
	/**
	 * Looks up a variable that may be written.
	 *
	 * @return the OUT port variable named {@code name}, otherwise the state-machine variable
	 *         named {@code name}, otherwise {@code null}
	 */
	@Override
	public Variable getWritableVariableDecl(String name) {
		Variable result = outPortVars.get(name);
		return result != null ? result : stateMachineVars.get(name);
	}
	
	/**
	 * Looks up a variable that may be read. Writable variables take precedence over IN port
	 * variables of the same name.
	 *
	 * @return the writable variable named {@code name}, otherwise the IN port variable named
	 *         {@code name}, otherwise {@code null}
	 */
	@Override
	public Variable getVariableDecl(String name) {
		Variable v = getWritableVariableDecl(name);
		return v != null ? v : inPortVars.get(name);
	}
	
	/** @return the function named {@code name}, or {@code null} if there is none */
	@Override
	public Function getFunctionDecl(String name) {
		return functions.get(name);
	}
	
	/**
	 * Builds a message suffix that lists the declarations in scope (functions, variables,
	 * IN ports not shadowed by an OUT port, OUT ports).
	 *
	 * @return {@code "!"} if nothing is in scope, otherwise a {@code ", perhaps you meant ..."} suggestion
	 */
	@Override
	public String getScopeSuggestions() {
		List<String> suggestions = new ArrayList<String>();
		
		for (Map.Entry<String, Function> entry : functions.entrySet()) {
			suggestions.add("FCT " + entry.getKey() + "(): " + entry.getValue().getReturnType().name);
		}
		
		for (Map.Entry<String, Variable> entry : stateMachineVars.entrySet()) {
			suggestions.add("VAR " + entry.getKey() + ": " + entry.getValue().getType().name);
		}
		
		for (Map.Entry<String, Variable> entry : inPortVars.entrySet()) {
			if (!outPortVars.containsKey(entry.getKey())) {
				suggestions.add("IN " + entry.getKey() + ": " + entry.getValue().getType().name);
			}
		}
		
		for (Map.Entry<String, Variable> entry : outPortVars.entrySet()) {
			suggestions.add("OUT " + entry.getKey() + ": " + entry.getValue().getType().name);
		}
		
		switch (suggestions.size()) {
			case 0:
				return "!";
			case 1:
				return ", perhaps you meant " + suggestions.get(0) + "?";
			case 2:
				return ", perhaps you meant " + suggestions.get(0) + " or " + suggestions.get(1) + "?";
			default:
				StringBuilder result = new StringBuilder(", perhaps you meant one of the following:");
				
				for (String suggestion : suggestions) {
					result.append("\n\t").append(suggestion);
				}
				
				return result.toString();
		}
	}
	
	/**
	 * Checks that every vertex can be reached from the root vertex. A vertex counts as
	 * reachable if it is the target of a transition from a reachable vertex, or an initial
	 * vertex of a reachable vertex.
	 *
	 * @throws Error listing the unreachable vertices, if there are any
	 */
	public void checkReachability() {
		Map<ProtoVertex, Set<ProtoVertex>> successors = new HashMap<ProtoVertex, Set<ProtoVertex>>();
		
		for (ProtoVertex vertex : vertices.values()) {
			successors.put(vertex, new HashSet<ProtoVertex>());
		}
		
		for (ProtoTransition transition : transitions) {
			successors.get(transition.sourceState).add(transition.targetState);
		}
		
		ProtoVertex root = vertices.get(clz);
		Set<ProtoVertex> reached = new HashSet<ProtoVertex>();
		Deque<ProtoVertex> toVisit = new ArrayDeque<ProtoVertex>();
		reached.add(root);
		toVisit.add(root);
		
		while (!toVisit.isEmpty()) {
			ProtoVertex current = toVisit.poll();
			
			for (ProtoVertex next : successors.get(current)) {
				if (reached.add(next)) {
					toVisit.add(next);
				}
			}
			
			for (ProtoVertex next : current.initialVertices) {
				if (reached.add(next)) {
					toVisit.add(next);
				}
			}
		}
		
		Set<ProtoVertex> unreachedVertices = new HashSet<ProtoVertex>(vertices.values());
		unreachedVertices.removeAll(reached);
		
		if (!unreachedVertices.isEmpty()) {
			throw new Error("Could not reach the following vertices/states:" + formatVertexList(unreachedVertices));
		}
	}
	
	/**
	 * Prints a warning to {@code System.err} for every {@link State} without initial vertices
	 * that is not the source of any transition. Such states can never be left.
	 * This check only warns; it does not throw.
	 */
	public void checkLiveness() {
		Set<ProtoVertex> verticesWithOutgoingTransitions = new HashSet<ProtoVertex>();
		
		for (ProtoTransition transition : transitions) {
			verticesWithOutgoingTransitions.add(transition.sourceState);
		}
		
		Set<ProtoVertex> livelessStateVertices = new HashSet<ProtoVertex>();
		
		for (ProtoVertex vertex : vertices.values()) {
			if (State.class.isAssignableFrom(vertex.clz) && vertex.initialVertices.isEmpty()
					&& !verticesWithOutgoingTransitions.contains(vertex)) {
				livelessStateVertices.add(vertex);
			}
		}
		
		if (!livelessStateVertices.isEmpty()) {
			System.err.println("Warning! Could not leave the following states:" + formatVertexList(livelessStateVertices));
		}
	}
	
	/** Formats the canonical class names of {@code vs}, each on its own tab-indented line. */
	private static String formatVertexList(Collection<ProtoVertex> vs) {
		StringBuilder sb = new StringBuilder();
		
		for (ProtoVertex v : vs) {
			sb.append("\n\t").append(v.clz.getCanonicalName());
		}
		
		return sb.toString();
	}
}
