/**
 * Class that merges a given array of RB trees into a single RB tree.
 * @author Lukas Armborst, University of Twente
 * @year 2021
 */

final class TreeMerger {

    /*********
     Resources
     *********/

    /*@
    /// initial configuration of an iterator acting as a source (consumer)
    resource src_initial(ListListQueue it)
        = it != null ** it.consumer() 
          ** it.toBagC() == bag<Node>{}
          ** \unfolding it.consumer() \in it.readHead==0;
    
    /// initial configuration of an iterator acting as a destination (producer)
    resource dst_initial(ListListQueue it)
        = it != null ** it.producer() ** Perm(Integer.MIN_VALUE, read)
          ** it.toBagP() == bag<Node>{}
          ** \unfolding it.producer() \in !it.finalised
                                            ** it.maxKey == Integer.MIN_VALUE;
    @*/
    
    /*@
    /// assert that given iterator is finished producing
    ///     and corresponds to bag at index idx of nodeBags.
    /// Should be boolean function, but Viper had trouble so resource instead.
    ///     To be self-framed, the producer is also encapsulated in this resource.
    resource equal_bags(seq<bag<Node>> nodeBags, ListListQueue iterator, int idx)
        = iterator != null
          ** 0<=idx ** idx<|nodeBags|
          ** iterator.producer()
          ** (\unfolding iterator.producer()
                  \in iterator.finalised)
          ** iterator.toBagP() == nodeBags[idx];
    @*/


    /****************
     Helper Functions
     ****************/

    /*@
    /// convert array of trees into sequence of bags (of Nodes)
        requires trees != null;
        requires (\forall* int j; 0<=j && j<trees.length; Perm(trees[j], 1\2));
        requires (\forall int j; 0<=j && j<trees.length; trees[j] != null);
        requires (\forall* int j; 0<=j && j<trees.length; Tree.tree_perm(trees[j]));
        ensures |\result| == trees.length;
        ensures (\forall int j; 0<=j && j<trees.length; 
                    \result[j] == Util.toBag(Tree.toSeq(trees[j])));
    pure static inline seq<bag<Node>> toBags(Node[] trees)
        = toBags(trees, 0);
    
    /// iterate over tree array to convert array of trees into sequence of bags (of Nodes)
        requires idx >= 0;
        requires trees != null;
        requires (\forall* int j; 0<=j && j<trees.length; Perm(trees[j], 1\2));
        requires (\forall int j; 0<=j && j<trees.length; trees[j] != null);
        requires (\forall* int j; 0<=j && j<trees.length; Tree.tree_perm(trees[j]));
        ensures idx <= trees.length ==> |\result| == trees.length - idx;
        ensures (\forall int j; idx<=j && j<trees.length;  
                    \result[j-idx] == Util.toBag(Tree.toSeq(trees[j])));
    pure static seq<bag<Node>> toBags(Node[] trees, int idx)
        = idx >= trees.length ? seq<bag<Node>>{}
            : seq<bag<Node>>{Util.toBag(Tree.toSeq(trees[idx]))} + toBags(trees, idx+1);
    
    @*/
    
    /*@
    /// ascertain that nodeBags indeed represents a merger structure,
    ///     i.e. merged bag equals sum of constituent bags.
        requires 0<=idx && 2*idx+1<|nodeBags| && 0<=offset && idx+offset < |nodeBags|;
        ensures \result ==> (\forall Node n; (n \memberof nodeBags[2*idx]) > 0;
                             (n \memberof nodeBags[idx + offset]) > 0);
        ensures \result ==> (\forall Node n; (n \memberof nodeBags[2*idx + 1]) > 0;
                             (n \memberof nodeBags[idx + offset]) > 0);
        ensures \result ==> (|nodeBags[idx+offset]| == |nodeBags[2*idx]| + |nodeBags[2*idx+1]|);
    static pure boolean bagSum(seq<bag<Node>> nodeBags, int idx, int offset)
        = nodeBags[idx + offset] == nodeBags[2 * idx] + nodeBags[2 * idx + 1];


    /// lemma ensuring that if we extend nodeBags, the bagSum property of previous entries is
    ///     preserved
        requires newSeq == oldSeq + diff;
        requires maxI>0 && 2*maxI - 1 < |oldSeq| && 0 <= offset && maxI + offset <= |oldSeq|;
        requires (\forall int j; 0 <= j && j < maxI; bagSum(oldSeq, j, offset));
        ensures (\forall int j; 0 <= j && j < maxI; bagSum(newSeq, j, offset));
        ensures \result;
    static inline pure boolean bagSumExtensionLemma(seq<bag<Node>> oldSeq,
                                                           seq<bag<Node>> newSeq,
                                                           seq<bag<Node>> diff,
                                                           int maxI,
                                                           int offset)
        = bagSumExtensionLemma(oldSeq, newSeq, diff, maxI, offset, 0);

    /// iterating over bag sequence to prove bagSumExtensionLemma
        requires newSeq == oldSeq + diff;
        requires 0<=idx && idx<maxI;
        requires 2*maxI - 1 < |oldSeq| && 0 <= offset && maxI + offset <= |oldSeq|;
        requires (\forall int j; idx <= j && j < maxI; bagSum(oldSeq, j, offset));
        ensures (\forall int j; idx <= j && j < maxI; bagSum(newSeq, j, offset));
        ensures \result;
    static pure boolean bagSumExtensionLemma(seq<bag<Node>> oldSeq,
                                             seq<bag<Node>> newSeq,
                                             seq<bag<Node>> diff,
                                             int maxI,
                                             int offset,
                                             int idx)
        = (oldSeq[2*idx] == newSeq[2*idx] 
           && oldSeq[2*idx+1] == newSeq[2*idx+1] 
           && oldSeq[offset+idx] == newSeq[offset+idx] 
           && bagSum(oldSeq, idx, offset) 
           ==> bagSum(newSeq, idx, offset))
          && (idx+1 < maxI ==> bagSumExtensionLemma(oldSeq, newSeq, diff, maxI, offset, idx+1));


    /// transitive application of bagSum to prove that if node is in nodeBags[idx],
    ///     then it is also in the last bag 
        requires |nodeBags| == 2*treesLen - 1 && 0 <= idx && idx < |nodeBags|;
        requires (\forall int j; 0 <= j && j < treesLen - 1;
                    {: bagSum(nodeBags, j, treesLen) :});
        ensures \result;
        ensures 0 < (node \memberof nodeBags[idx])
                ==> 0 < (node \memberof nodeBags[2*treesLen - 2]);
    static pure boolean transitivityLemma(seq<bag<Node>> nodeBags, int treesLen,
                                            int idx, Node node)
        = idx < 2*treesLen - 2 ==> (bagSum(nodeBags, idx/2, treesLen)
                                     && (\forall Node n; {: (n \memberof nodeBags[idx]) :} > 0;
                                         (n \memberof nodeBags[idx/2 + treesLen]) > 0 )
                                     && transitivityLemma(nodeBags, treesLen, idx/2+treesLen, node));

    /// transitive application of bagSum to prove that any node in an early bag 
    ///     (i.e. in a given tree), is also in the last bag 
        requires |nodeBags| == 2*treesLen - 1;
        requires (\forall int j; 0 <= j && j < treesLen - 1;
                    {: bagSum(nodeBags, j, treesLen) :});
        ensures \result;
        ensures (\forall int j; 0<=j && j<treesLen;
                    (\forall Node n; {: (n \memberof nodeBags[j]) :} > 0;
                        0 < (n \memberof nodeBags[2*treesLen - 2])));
    public static pure boolean transClosureBagSumLemma(seq<bag<Node>> nodeBags, int treesLen)
        = (\forall int j; 0<=j && j<treesLen;
            (\forall Node n; {: (n \memberof nodeBags[j]) :} > 0;
                transitivityLemma(nodeBags, treesLen, j, n)));

    @*/

    /// Changing nodeBags means replacing it (as bags are immutable), meaning the equal_bags
    ///     resources point to the wrong (outdated) bags now. This method fixes that, by unfolding
    ///     the resources with the old bags and re-folding them with the new bags.
    /// The method takes the old version of nodeBags and the new version, as well as the array of
    ///     iterators as argument. The last argument is the loop index of joinMergerThreads, that
    ///     determines for which iterators we currently have the equal_bags resource.
    /*@
        context 0 <= idx && 2*idx+2 <= |oldBag| && |oldBag| <= |newBag|;
        context iterators != null && |newBag| <= iterators.length;
        context (\forall* int j; 0<=j && j<iterators.length; Perm(iterators[j], 1\2));
        context (\forall int j; 0<=j && j<|oldBag|; newBag[j] == oldBag[j]);
        requires (\forall* int j; 2*idx+2<=j && j<|oldBag|; equal_bags(oldBag, iterators[j], j));
        ensures (\forall* int j; 2*idx+2<=j && j<|oldBag|; equal_bags(newBag, iterators[j], j));
    ghost void refoldEqual_bags(seq<bag<Node>> oldBag, seq<bag<Node>> newBag,
                                ListListQueue[] iterators, int idx)
    {
        ghost for (int k=2*idx+2; k<|oldBag|; k++)
            loop_invariant 0<=idx && 2*idx+2<=k && k<=|oldBag| && |oldBag| <= |newBag|;
            loop_invariant iterators != null  && |newBag| <= iterators.length;
            loop_invariant (\forall* int j; 0<=j && j<iterators.length; Perm(iterators[j], 1\2));
            loop_invariant (\forall* int j; 2*idx+2<=j && j<k;
                                equal_bags(newBag, iterators[j], j));
            loop_invariant (\forall* int j; k<=j && j<|oldBag|;
                                equal_bags(oldBag, iterators[j], j));
        {
            unfold equal_bags(oldBag, iterators[k], k);
            fold equal_bags(newBag, iterators[k], k);
        }
    }
    @*/

    /*@
    /// unfold equal_bags and join producer and consumer of the finished iterator
        requires 0<=idx && idx<|nodeBags|;
        requires equal_bags(nodeBags, iterator, idx);
        requires iterator.consumer();
        requires iterator.done() == true;
        ensures Perm(iterator.allP, 1\2);
        ensures iterator.toBag() == nodeBags[idx];
        ensures iterator.toBag() == \old(iterator.toBagC());
    ghost void joinIteratorPC(ListListQueue iterator, seq<bag<Node>> nodeBags, int idx) {
        unfold equal_bags(nodeBags, iterator, idx);
        assert iterator.producer();
        assert \unfolding iterator.producer()
                \in iterator.toBag() == nodeBags[idx];
        ghost iterator.joinPC();
    }
    @*/


    /**************
     Init for merge
     **************/

    /// (concurrently) initialise the first few iterators to represent the given trees
    /// and the rest to be empty ListListQueues ready to start merging
    /*@
        given seq<bag<Node>> treeBags;
        context trees != null && iterators != null;
        context iterators.length == 2*trees.length-1 && |treeBags| == trees.length;
        context (\forall* int j; 0<=j && j<iterators.length; Perm(iterators[j], write));
        context (\forall* int j; 0<=j && j<trees.length; Perm(trees[j], 1\2));
        context (\forall int j; 0<=j && j<trees.length; trees[j] != null);
        requires (\forall* int j; 0<=j && j<trees.length; Tree.tree_perm(trees[j]));
        requires (\forall int j; 0<=j && j<trees.length; Tree.validTree( {: trees[j] :} ));
        requires (\forall int j; 0<=j && j<trees.length; 
                    treeBags[j] == Util.toBag(Tree.toSeq(trees[j])));
        ensures (\forall int j; 0<=j && j<iterators.length; iterators[j]!=null);
        ensures (\forall int j; 0<=j && j<iterators.length; 
                    (\forall int k; 0<=k && k<j ; 
                        iterators[j] != iterators[k]));
        ensures (\forall* int j; 0<=j && j<|treeBags|; equal_bags(treeBags, iterators[j], j));
        ensures (\forall* int j; 0<=j && j<iterators.length; src_initial(iterators[j]));
        ensures (\forall* int j; trees.length<=j && j<iterators.length; 
                    dst_initial(iterators[j]));
    @*/
    public void initIterators(Node[] trees, ListListQueue[] iterators) {
        TreeConverterThread[] threads = new TreeConverterThread[trees.length];
        
        int i;
        
        /// start parallel conversion of trees into iterators
        /*@
            loop_invariant threads.length == trees.length;
            loop_invariant 0<=i && i<=trees.length;
            loop_invariant (\forall* int j; 0<=j && j<threads.length; 
                                Perm(threads[j], write));
            loop_invariant (\forall* int j; 0<=j && j<trees.length; Perm(trees[j], 1\2));
            loop_invariant (\forall int j; 0<=j && j<trees.length; trees[j] != null);
            loop_invariant (\forall* int j; i<=j && j<trees.length; Tree.tree_perm(trees[j]));
            loop_invariant (\forall int j; i<=j && j<trees.length; Tree.validTree( {: trees[j] :} ));
            loop_invariant (\forall int j; 0<=j && j<i; threads[j] != null);
            loop_invariant (\forall* int j; 0<=j && j<i; threads[j].join_token(write));
            loop_invariant (\forall* int j; 0<=j && j<i;
                                Perm(threads[j].treeAsSeq, 1\2)
                                ** Perm(threads[j].keysAsSeq, 1\2));
            loop_invariant (\forall int j; 0<=j && j<i; 
                                threads[j].treeAsSeq == \old(Tree.toSeq(trees[j]))
                                && threads[j].keysAsSeq == \old(Tree.toSeqKeys(trees[j])));
            loop_invariant (\forall int j; i<=j && j<trees.length; 
                                Tree.toSeq({: trees[j] :}) == \old(Tree.toSeq(trees[j]))
                                && Tree.toSeqKeys({: trees[j] :}) == \old(Tree.toSeqKeys(trees[j])));
        @*/
        for (i=0; i<trees.length; i++) {
            threads[i] = new TreeConverterThread(trees[i]);
            /*@ assume (\forall int j; 0<=j && j<i; threads[j] != threads[i]); @*/
            threads[i].start();
        } 
        
        /*@ assume Perm(Integer.MIN_VALUE, read); @*/
        
        /// initialise empty iterators for mergers
        /*@
            loop_invariant Perm(Integer.MIN_VALUE, read);
            loop_invariant iterators.length == 2*trees.length-1;
            loop_invariant trees.length<=i && i<=iterators.length;
            loop_invariant (\forall* int j; 0<=j && j<iterators.length; Perm(iterators[j], write));
            loop_invariant (\forall int j; trees.length<=j && j<i; iterators[j]!=null);
            loop_invariant (\forall* int j; trees.length<=j && j<i; 
                                dst_initial(iterators[j]) ** src_initial(iterators[j]));
            loop_invariant (\forall int j; trees.length<=j && j<i; 
                            (\forall int k; trees.length<=k && k<j; 
                                iterators[j]!=iterators[k]));
        @*/
        for (i=trees.length; i<2*trees.length-1; i++) {
            iterators[i] = new ListListQueue();
            /// The fresh iterator differs from the already initialised merger iterators.
            ///     Only slots j<i are compared: including j == i would assume
            ///     iterators[i] != iterators[i], i.e. false, making the rest of the method
            ///     verify vacuously.
            /*@ assume (\forall int j; trees.length<=j && j<i; 
                            iterators[j]!=iterators[i]); @*/
            /*@ fold src_initial(iterators[i]); @*/
            /*@ fold dst_initial(iterators[i]); @*/
        }
        
        /// join converter threads, now finished turning trees into iterators
        /*@
            loop_invariant Perm(Integer.MIN_VALUE, read);
            loop_invariant 0<=i && i<=threads.length && threads.length == trees.length;
            loop_invariant iterators.length == 2*trees.length-1;
            loop_invariant (\forall* int j; 0<=j && j<threads.length; Perm(threads[j], write));
            loop_invariant (\forall* int j; 0<=j && j<trees.length; Perm(iterators[j], write));
            loop_invariant (\forall* int j; trees.length<=j && j<iterators.length; 
                                Perm(iterators[j], 1\2));
            loop_invariant (\forall int j; 
                                (0<=j && j<i) || (trees.length<=j && j<iterators.length); 
                                iterators[j] != null);
            loop_invariant (\forall int j; 0<=j && j<threads.length; threads[j] != null);
            loop_invariant (\forall* int j; i<=j && j<trees.length; threads[j].join_token(write));
            loop_invariant (\forall* int j; i<=j && j<threads.length;
                                Perm(threads[j].treeAsSeq, 1\2)
                                ** Perm(threads[j].keysAsSeq, 1\2));
            loop_invariant (\forall int j; i<=j && j<threads.length; 
                                threads[j].treeAsSeq == \old(Tree.toSeq(trees[j]))
                                && threads[j].keysAsSeq == \old(Tree.toSeqKeys(trees[j])));
            loop_invariant (\forall int j; 
                                (0<=j && j<i) || (trees.length<=j && j<iterators.length); 
                                (\forall int k; 
                                    (0<=k && k<j && k<i) || (trees.length<=k && k<j); 
                                    iterators[j] != iterators[k]));
            loop_invariant (\forall* int j; 0<=j && j<i; src_initial(iterators[j]));
            loop_invariant (\forall* int j; 0<=j && j<i; equal_bags(treeBags, iterators[j], j));
        @*/
        for (i=0; i<trees.length; i++) {
            TreeConverterThread thread = threads[i];
            thread.join();
            /*@ unfold thread.post_join(write); @*/
            iterators[i] = thread.res;
            /// The converted iterator differs from every slot that is already filled (the tree
            ///     iterators before i and all merger iterators). Slot i itself must be excluded:
            ///     including j == i would assume iterators[i] != iterators[i], i.e. false.
            /*@ assume (\forall int j; 
                            (0<=j && j<i) || (trees.length<=j && j<iterators.length); 
                            iterators[j]!=iterators[i]); @*/
            /*@ fold src_initial(iterators[i]); @*/
            /*@ fold equal_bags(treeBags, iterators[i], i); @*/
        }
        
    }

    /*************************
     Thread handling for merge
     *************************/

    /// Initialise all merger threads in given array, and call "start" on them to kick off merging.
    /*@
        given seq<bag<Node>> treeBags;
        context iterators != null ** threads != null ** treesLen > 0;
        context iterators.length == 2*treesLen-1 && threads.length == treesLen-1;
        context |treeBags| == treesLen;
        context (\forall* int j; 0<=j && j<iterators.length; Perm(iterators[j], write));
        context (\forall int j; 0<=j && j<iterators.length; iterators[j] != null);
        context (\forall int j; 0<=j && j<iterators.length; 
                    (\forall int k; 0<=k && k<j ; 
                        iterators[j] != iterators[k]));
        context (\forall* int j; 0<=j && j<threads.length; Perm(threads[j], write));
        context (\forall* int j; 0<=j && j<|treeBags|; equal_bags(treeBags, iterators[j], j));
        /// all iterators can be used as source, while only the second half can be targets 
        ///     (first half is already set to be the given trees)
        requires (\forall* int j; 0<=j && j<iterators.length; 
                            src_initial(iterators[j]));
        requires (\forall* int j; treesLen<=j && j<iterators.length;
                            dst_initial(iterators[j]));
        /// only the last iterator (the final fully merged tree) can still be read, 
        ///     all others are assigned to some merger threads as source
        ensures src_initial(iterators[iterators.length-1]);
        ensures (\forall int j; 0<=j && j<threads.length; threads[j] != null);
        ensures (\forall* int j; 0<=j && j<threads.length; 
                    Perm(threads[j].src1, 1\2)
                    ** Perm(threads[j].src2, 1\2)
                    ** Perm(threads[j].dst, 1\2));
        ensures (\forall int j; 0<=j && j<threads.length; 
                            threads[j].src1 == iterators[2*j]
                            && threads[j].src2 == iterators[2*j+1]
                            && threads[j].dst == iterators[treesLen+j]);
        /// all threads are running and can be joined at some future point
        ensures (\forall* int j; 0<=j && j<threads.length; threads[j].join_token(write));
        ensures treeBags == \old(treeBags);
    @*/
    void startMergerThreads(ListListQueue[] iterators, ListMergerThread[] threads, int treesLen)
    {
        int i;
        /*@ assume  Perm(Integer.MIN_VALUE, read); @*/
        
        /*@
            loop_invariant 0<=i && i<=treesLen-1;
            loop_invariant Perm(Integer.MIN_VALUE, read);
            loop_invariant iterators.length == 2*treesLen-1 && threads.length == treesLen-1;
            loop_invariant |treeBags| == treesLen;
            loop_invariant (\forall* int j; 0<=j && j<iterators.length; Perm(iterators[j], write));
            loop_invariant (\forall int j; 0<=j && j<iterators.length; iterators[j] != null);
            loop_invariant (\forall int j; 0<=j && j<iterators.length; 
                                (\forall int k; 0<=k && k<j ; 
                                    iterators[j] != iterators[k]));
            loop_invariant (\forall* int j; 0<=j && j<threads.length; Perm(threads[j], write));
            loop_invariant (\forall* int j; 0<=j && j<|treeBags|;
                                equal_bags(treeBags, iterators[j], j));
            loop_invariant (\forall int j; 0<=j && j<i; threads[j] != null);
            loop_invariant (\forall* int j; 0<=j && j<i; 
                                Perm(threads[j].src1, 1\2)
                                ** Perm(threads[j].src2, 1\2)
                                ** Perm(threads[j].dst, 1\2));
            loop_invariant (\forall int j; 0<=j && j<i; 
                                threads[j].src1 == iterators[2*j]
                                && threads[j].src2 == iterators[2*j+1]
                                && threads[j].dst == iterators[treesLen+j]);
            loop_invariant (\forall* int j; 0<=j && j<i; threads[j].join_token(write));
            loop_invariant (\forall* int j; 2*i<=j && j<iterators.length; 
                                src_initial(iterators[j]));
            loop_invariant (\forall* int j; treesLen+i<=j && j<iterators.length;
                                dst_initial(iterators[j]));
            loop_invariant treeBags == \old(treeBags);
        @*/
        for (i=0; i<treesLen-1; i++) {
            ListListQueue d = iterators[treesLen+i];
            ListListQueue s1 = iterators[2*i];
            ListListQueue s2 = iterators[2*i+1];
            /*@ unfold dst_initial(d); @*/
            /*@ unfold src_initial(s1); @*/
            /*@ unfold src_initial(s2); @*/
            threads[i] = new ListMergerThread(s1, s2, d);
            /*@ assume (\forall int j; 0<=j && j<i; threads[j] != threads[i]); @*/
            threads[i].start();
        }
    }
    

    /// join all merger threads after they have finished merging.
    ///     Takes the threads to be joined, the associated ListListQueues (as ghost argument)
    ///     and the number of merged trees.
    /*@
        /// sequence of bags representing the trees to merge
        given seq<bag<Node>> treeBags;
        given ListListQueue[] iterators;
        yields ListListQueue merged;
        context iterators != null ** threads != null ** treesLen > 0;
        context iterators.length == 2*treesLen-1 && threads.length == treesLen-1;
        context (\forall* int j; 0<=j && j<iterators.length; Perm(iterators[j], write));
        context (\forall int j; 0<=j && j<iterators.length; iterators[j] != null);
        context (\forall int j; 0<=j && j<iterators.length; 
                    (\forall int k; 0<=k && k<j ; 
                        iterators[j] != iterators[k]));
        context src_initial(iterators[2*treesLen-2]);
        context (\forall* int j; 0<=j && j<threads.length; Perm(threads[j], write));
        context (\forall int j; 0<=j && j<threads.length; threads[j] != null);
        context (\forall int j; 0<=j && j<threads.length; 
                    (\forall int k; 0<=k && k<j; 
                        threads[j] != threads[k]));
        context (\forall* int j; 0<=j && j<threads.length; 
                    Perm(threads[j].src1, 1\2)
                    ** Perm(threads[j].src2, 1\2)
                    ** Perm(threads[j].dst, 1\2));
        context (\forall int j; 0<=j && j<threads.length; 
                    threads[j].src1 == iterators[2*j]
                    && threads[j].src2 == iterators[2*j+1]
                    && threads[j].dst == iterators[treesLen+j]);
        requires |treeBags| == treesLen;
        requires (\forall* int j; 0<=j && j<threads.length; threads[j].join_token(write));
        requires (\forall* int j; 0<=j && j<treesLen;
                    equal_bags(treeBags, iterators[j], j) );
        ensures merged == iterators[2*treesLen-2];
        ensures merged.producer() 
                ** \unfolding merged.producer() 
                        \in merged.finalised;
        ensures (\let bag<Node> last_bag = merged.toBagP();
                    (\forall int j; 0<=j && j<treesLen;
                        (\forall Node n; 
                            {: (n \memberof treeBags[j]) :} > 0;
                            (n \memberof last_bag) > 0)));
    @*/
    void joinMergerThreads(ListMergerThread[] threads, int treesLen) {
        /*@ ghost seq<bag<Node>> nodeBags = treeBags; @*/
        int i;
        /*@
            loop_invariant 0<=i && i<=treesLen-1;
            loop_invariant iterators.length == 2*treesLen - 1 && threads.length == treesLen-1;
            loop_invariant (\forall* int j; 0<=j && j<iterators.length; Perm(iterators[j], write));
            loop_invariant (\forall int j; 0<=j && j<iterators.length; iterators[j] != null);
            loop_invariant (\forall int j; 0<=j && j<iterators.length; 
                                (\forall int k; 0<=k && k<j ; 
                                    iterators[j] != iterators[k]));
            loop_invariant src_initial(iterators[2*treesLen-2]);
            loop_invariant (\forall* int j; 0<=j && j<threads.length; Perm(threads[j], write));
            loop_invariant (\forall int j; 0<=j && j<threads.length; threads[j] != null);
            loop_invariant (\forall int j; 0<=j && j<threads.length; 
                                (\forall int k; 0<=k && k<j; 
                                    threads[j] != threads[k]));
            loop_invariant (\forall* int j; 0<=j && j<threads.length; 
                                Perm(threads[j].src1, 1\2)
                                ** Perm(threads[j].src2, 1\2)
                                ** Perm(threads[j].dst, 1\2));
            loop_invariant (\forall int j; 0<=j && j<threads.length; 
                                threads[j].src1 == iterators[2*j]
                                && threads[j].src2 == iterators[2*j+1]
                                && threads[j].dst == iterators[treesLen+j]);

            loop_invariant |nodeBags| == treesLen + i;
            loop_invariant (\forall* int j; i<=j && j<threads.length; threads[j].join_token(write));
            loop_invariant (\forall* int j; 2*i<=j && j<treesLen+i;
                                equal_bags(nodeBags, iterators[j], j) );

            loop_invariant (\forall int j; 0<=j && j<treesLen;
                                nodeBags[j] == treeBags[j]);
            loop_invariant (\forall* int j; 0<=j && j<2*i; Perm(iterators[j].allP, 1\2));
            loop_invariant (\forall int j; 0<=j && j<i; bagSum(nodeBags, j, treesLen));
        @*/
        for (i=0; i<treesLen-1; i++) {
            threads[i].join();
            /*@ 
            unfold threads[i].post_join(write);

            ghost joinIteratorPC(iterators[2*i], nodeBags, 2*i);
            ghost joinIteratorPC(iterators[2*i+1], nodeBags, 2*i+1);

            /// add the iterator, which this thread produced to, to nodeBags
            ghost seq<bag<Node>> oldBag = nodeBags;
            unfold iterators[i+treesLen].producer() ;
            ghost seq<bag<Node>> new_bag = seq<bag<Node>>{iterators[i+treesLen].toBag()};
            fold iterators[i+treesLen].producer() ;
            ghost nodeBags = nodeBags + new_bag;
            
            /// bagSum still holds for all bags in the now extended nodeBags
            assert i>0 ==> bagSumExtensionLemma(oldBag, nodeBags, new_bag, i, treesLen);
                            
            /// let equal_bags refer to the new extended nodeBags
            ghost refoldEqual_bags(oldBag, nodeBags, iterators, i);
    
            fold equal_bags(nodeBags, iterators[i+treesLen], i+treesLen);
            @*/
        }
        
        /*@ 
        ghost int last_it = 2*treesLen-2;
        ghost merged = iterators[2*treesLen-2];
        assert transClosureBagSumLemma(nodeBags, treesLen);
        assert (\forall int j; 0<=j && j<treesLen;
                    (\forall Node n; 
                        {: (n \memberof nodeBags[j]) :} > 0;
                        (n \memberof nodeBags[last_it]) > 0));

        unfold equal_bags(nodeBags, merged, last_it);
        assert (\forall int j; 0<=j && j<treesLen;
                    (\forall Node n; 
                        {: (n \memberof nodeBags[j]) :} > 0;
                        (n \memberof merged.toBagP()) > 0));
        @*/
    }
    
    /// merge given array of RB trees into a single RB tree
    /*@
        requires trees != null ** trees.length > 0;
        requires (\forall* int i; 0<=i && i<trees.length; Perm(trees[i], 1\2));
        requires (\forall int i; 0<=i && i<trees.length; trees[i] != null);
        requires (\forall* int i; 0<=i && i<trees.length; Tree.tree_perm(trees[i]));
        requires (\forall* int i; 0<=i && i<trees.length; Tree.validTree( {: trees[i] :} ));
        ensures Tree.tree_perm(\result);
        ensures (\forall int j; 0<=j && j<trees.length;
                    (\forall Node n;
                        {: (n \memberof \old(Util.toBag(Tree.toSeq(trees[j])))) :} > 0;
                        (n \memberof Util.toBag(Tree.toSeq(\result))) > 0));
        ensures Tree.validTree(\result);
    @*/
    public Node mergeTrees(Node[] trees) {
        if (trees.length == 1) {
            /// trivial case, nothing to merge
            return trees[0];
        }

        int treesLen = trees.length;
        ListListQueue[] iterators = new ListListQueue[2*treesLen-1];
        ListMergerThread[] threads = new ListMergerThread[treesLen-1];

        /// bag representations of the given trees, and later also the joined lists
        /*@ ghost seq<bag<Node>> treeBags = toBags(trees); @*/
        /*@ ghost ListListQueue mergedIterator; @*/

        /// first few iterators traverse given trees
        initIterators(trees, iterators) /*@ with {treeBags=treeBags;} @*/;

        /// start threads to parallel merge trees, using remaining iterators
        startMergerThreads(iterators, threads, treesLen) /*@ with {treeBags=treeBags;} @*/;
        /// join threads when everything is merged completely
        joinMergerThreads(threads, treesLen) /*@ with {iterators=iterators; treeBags=treeBags;}
                                                             then {mergedIterator=merged;} @*/;

        int idxLastIterator = 2*treesLen-2;
        /*@
        assert (\forall int j; 0<=j && j<trees.length;
                    (\forall Node n;
                        {: (n \memberof \old(Util.toBag(Tree.toSeq(trees[j])))) :} > 0;
                        (n \memberof mergedIterator.toBagP()) > 0));
        @*/

        /// last iterator is completely merged list, turn it into RB tree
        ListListQueue merged = iterators[idxLastIterator];
        /*@ assert merged == mergedIterator; @*/
        /*@ unfold src_initial(merged); @*/
        Node res = merged.toTree();
        /*@ ghost merged.joinPC(); @*/

        return res;
    }
}



/**************
 Thread classes
 **************/

/// A thread whose task is converting an RB tree into a NodeList, wrapped in a
/// ListListQueue (field `res`) that can then serve as a merge source.
/// It carries its own adapted copy of the VerCorsThread resources
/// (join_token / pre_fork / post_join) instead of extending that class.
final class TreeConverterThread /// extends VerCorsThread
{
    /// tree to be converted
    Node tree;
    /// ListListQueue containing the single NodeList that was the tree;
    /// set to null by the constructor and assigned by run()
    ListListQueue res;
    /*@
    /// sequence representation of the NodeList and its keys
    ghost seq<Node> treeAsSeq;
    ghost seq<int> keysAsSeq;
    @*/

    /// Constructor: stores the tree and records its node and key sequences as
    /// ghost state, then packs everything run() needs into pre_fork.
    /// The other half of the permissions on `tree`, `treeAsSeq` and `keysAsSeq`
    /// is returned to the caller (see the ensures clauses).
    /*@
        requires tree != null ** Tree.tree_perm(tree) ** Tree.validTree(tree);
        ensures pre_fork(write);
        ensures Perm(this.tree, 1\2) ** Perm(treeAsSeq, 1\2) ** Perm(keysAsSeq, 1\2);
        ensures treeAsSeq == \old(Tree.toSeq(tree))
                && keysAsSeq == \old(Tree.toSeqKeys(tree));
        ensures this.tree == tree;
    @*/
    public TreeConverterThread(Node tree) {
        this.tree = tree;
        this.res = null;
        /// ghost snapshots of the input tree, used by post_join to describe `res`
        /*@ ghost treeAsSeq = Tree.toSeq(tree); @*/
        /*@ ghost keysAsSeq = Tree.toSeqKeys(tree); @*/
        /*@ fold pre_fork(write); @*/
    }

    /// This thread's main task: turning the tree into a NodeList, wrapping it
    /// in a NodeListIterator and a ListListQueue, and storing that in `res`.
    /*@
        requires pre_fork(write);
        ensures post_join(write);
    @*/
    public void run() {
        /*@ unfold pre_fork(write); @*/
        res = new ListListQueue(new NodeListIterator(NodeList.fromTree(tree)));
        /*@ fold post_join(write); @*/
    }

    /// Resources copied from VerCorsThread and adapted:
    ///  - pre_fork: the input tree with its ghost node/key sequences;
    ///  - post_join: `res` is a finalised producer whose nodes/keys equal the
    ///    ghost sequences, plus a finalised consumer that has not been read yet
    ///    (readHead == 0).
    /// NOTE(review): join_token is defined as `true`, so the specification does
    /// not force join() to be preceded by start(); post_join could be obtained
    /// without running the thread. Consider making it an abstract resource.

    /*@
    public resource join_token(frac p)
        = true;
    public resource pre_fork(frac p)
        = Perm(res, write) ** Perm(tree, 1\2) ** tree != null ** Tree.tree_perm(tree)
          ** Perm(treeAsSeq, 1\2) ** Perm(keysAsSeq, 1\2)
           ** Tree.validTree(tree)
           ** treeAsSeq == Tree.toSeq(tree) ** keysAsSeq == Tree.toSeqKeys(tree); 
    public resource post_join(frac p)
        = Perm(res, write) ** res != null ** res.consumer() ** res.producer()
          ** Perm(treeAsSeq, 1\2) ** Perm(keysAsSeq, 1\2) 
          ** (\unfolding res.producer() 
                \in res.finalised ** res.allP == treeAsSeq ** res.keysP == keysAsSeq)
          ** \unfolding res.consumer() \in res.readHead==0 ** res.finalisedC; 
    @*/

    
    /// Starts the thread; contract only (no body).
    /// Unlike VerCorsThread, the `given frac p` parameter is commented out.
    /*@
        /// given frac p;
        requires pre_fork(write);
        ensures  join_token(write);
    @*/
    public final void start();
  
    /// Waits for the thread; contract only (no body). Yields post_join.
    /*@
        /// given frac p;
        requires join_token(write);
        ensures  post_join(write);
    @*/
    public final void join();
    
}


/// Thread whose main task is merging two ListListQueues (src1, src2) into one
/// destination ListListQueue (dst), appending nodes in non-decreasing key order.
/// It carries its own adapted copy of the VerCorsThread resources
/// (join_token / pre_fork / post_join) instead of extending that class.
/// NOTE(review): relies on each source yielding keys in non-decreasing order,
/// presumably guaranteed by ListListQueue's consumer specification — confirm.
final class ListMergerThread /// extends VerCorsThread
{

    /// first source (consumed by run())
    ListListQueue src1;
    /// second source (consumed by run())
    ListListQueue src2;
    /// destination (produced and finalised by run())
    ListListQueue dst;

    /// Constructor initialising the three ListListQueues (two sources, one destination).
    /// Both sources must be unread consumers (readHead == 0); dst must be an empty,
    /// non-finalised producer whose maxKey is still Integer.MIN_VALUE.
    /// Everything is packed into pre_fork; half of each field permission is
    /// returned to the caller.
    /// NOTE(review): pre_fork additionally states that src1.toBagC() and
    /// src2.toBagC() are empty; this is expected to follow from readHead == 0.
    /*@
        requires Perm(Integer.MIN_VALUE, read);
        requires src1 != null && src2 != null && dst != null;
        requires src1.consumer() ** src2.consumer() ** dst.producer();
        requires \unfolding src1.consumer() \in src1.readHead == 0;
        requires \unfolding src2.consumer() \in src2.readHead == 0;
        requires dst.toBagP() == bag<Node>{};
        requires \unfolding dst.producer() \in !dst.finalised
                                                ** dst.maxKey == Integer.MIN_VALUE;
        ensures pre_fork(write);
        ensures Perm(this.src1, 1\2) ** Perm(this.src2, 1\2) ** Perm(this.dst, 1\2);
        ensures this.src1 == src1 ** this.src2 == src2 ** this.dst == dst;

    @*/
    public ListMergerThread(ListListQueue src1, ListListQueue src2, ListListQueue dst) {
        this.src1 = src1;
        this.src2 = src2;
        this.dst = dst;
        /*@ fold pre_fork(write); @*/
    }

    /// Main task of the thread: merge the two sources into the destination.
    /// Phase 1: while both sources have a current node, append the one with the
    ///          smaller key (ties are taken from src1).
    /// Phase 2: drain what is left of src1.
    /// Phase 3: drain what is left of src2.
    /// Finally dst is finalised.
    /*@
        requires pre_fork(write);
        ensures post_join(write);
    @*/
    public void run()
    {
        /*@ unfold pre_fork(write); @*/
        /// cur1/cur2: node last read from src1/src2 that is not yet in dst;
        /// null once the respective source is exhausted
        Node cur1 = null;
        Node cur2 = null;
        if (src1.hasNext()) {
            cur1 = src1.getNext();
        }
        if (src2.hasNext()) {
            cur2 = src2.getNext();
        }

        /// phase 1: as long as both sources have elements, join them according to order of keys
        /*@
            loop_invariant Perm(src1, 1\2) ** Perm(src2, 1\2) ** Perm(dst, 1\2);
            loop_invariant src1 != null ** src2 != null ** dst != null;
            loop_invariant src1.consumer() ** src2.consumer() ** dst.producer();
            loop_invariant \unfolding dst.producer() \in !dst.finalised;
            loop_invariant cur1 != null ==> Node.node_perm(cur1, write);
            loop_invariant cur2 != null ==> Node.node_perm(cur2, write);
            /// cur1 and cur2 are the current element of their respective source
            loop_invariant cur1 != null ==> \unfolding src1.consumer() 
                                                \in src1.readHead>0 
                                                    ** cur1.key==src1.keysC[src1.readHead-1];
            loop_invariant cur2 != null ==> \unfolding src2.consumer() 
                                                \in src2.readHead>0 
                                                    ** cur2.key==src2.keysC[src2.readHead-1];
            loop_invariant cur1 == null ==> src1.done();
            loop_invariant cur2 == null ==> src2.done();
            /// all elements in dst are smaller than current elements, i.e. order is preserved
            loop_invariant \unfolding dst.producer() 
                                \in (cur1 != null ==> dst.maxKey <= cur1.key) 
                                    && (cur2 != null ==> dst.maxKey <= cur2.key);
            /// dst contains all Nodes consumed from either source (except the current Nodes)
            loop_invariant dst.toBagP() 
                              + (cur1==null ? bag<Node>{} : bag<Node>{cur1}) 
                              + (cur2==null ? bag<Node>{} : bag<Node>{cur2})
                           == src1.toBagC()
                              + src2.toBagC();
        @*/
        while (cur1 != null && cur2 != null) {
            /// ghost snapshots of the bags at the start of this iteration
            /*@ ghost bag<Node> bd = dst.toBagP(); @*/
            /*@ ghost bag<Node> bs1 = src1.toBagC(); @*/
            /*@ ghost bag<Node> bs2 = src2.toBagC(); @*/
            if(cur1.key <= cur2.key) {
                /*@ ghost Node oldCur1 = cur1; @*/
                dst.append(cur1);
                if (src1.hasNext()) {
                    cur1 = src1.getNext();
                } else {
                    cur1 = null;
                }
                /// bag-arithmetic hints (reordering of summands) for the prover
                /*@ 
                assert bd + bag<Node>{oldCur1}
                          + (cur1==null ? bag<Node>{} : bag<Node>{cur1}) 
                          + bag<Node>{cur2}
                        == bd + bag<Node>{oldCur1} + bag<Node>{cur2}
                          + (cur1==null ? bag<Node>{} : bag<Node>{cur1});
                assert bs1 + bs2 + (cur1==null ? bag<Node>{} : bag<Node>{cur1})
                        == bs1 + (cur1==null ? bag<Node>{} : bag<Node>{cur1}) + bs2;
                @*/
            } else {
                /*@ ghost Node oldCur2 = cur2; @*/
                dst.append(cur2);
                if (src2.hasNext()) {
                    cur2 = src2.getNext();
                } else {
                    cur2 = null;
                }
                /// bag-arithmetic hint for the prover
                /*@ assert bd + bag<Node>{oldCur2} + bag<Node>{cur1}
                              + (cur2==null ? bag<Node>{} : bag<Node>{cur2}) 
                            == bs1 + bs2 + (cur2==null ? bag<Node>{} : bag<Node>{cur2}); @*/
            }
        }

        /// now at least one of the sources is exhausted, only need to iterate remaining source

        /// phase 2: iterate src1
        /*@
            loop_invariant Perm(src1, 1\2) ** Perm(src2, 1\2) ** Perm(dst, 1\2);
            loop_invariant src1 != null ** src2 != null ** dst != null;
            loop_invariant src1.consumer() ** src2.consumer() ** dst.producer();
            loop_invariant \unfolding dst.producer() \in !dst.finalised;
            loop_invariant cur1 != null ==> Node.node_perm(cur1, write);
            /// same fraction as in phases 1 and 3: cur2 is untouched here and
            /// phase 3 needs write permission on it again
            loop_invariant cur2 != null ==> Node.node_perm(cur2, write);
            /// at least one src is empty
            loop_invariant cur1 != null ==> cur2 == null;
            loop_invariant cur1 != null ==> \unfolding src1.consumer() 
                                                \in src1.readHead>0 
                                                    ** cur1.key==src1.keysC[src1.readHead-1];
            loop_invariant cur2 != null ==> \unfolding src2.consumer() 
                                                \in src2.readHead>0 
                                                    ** cur2.key==src2.keysC[src2.readHead-1];
            loop_invariant cur1 == null ==> src1.done();
            loop_invariant cur2 == null ==> src2.done();
            loop_invariant \unfolding dst.producer() \in 
                            cur1 != null ==> dst.maxKey <= cur1.key;
            loop_invariant \unfolding dst.producer() \in 
                            cur2 != null ==> dst.maxKey <= cur2.key;
            loop_invariant dst.toBagP() 
                              + (cur1==null ? bag<Node>{} : bag<Node>{cur1}) 
                              + (cur2==null ? bag<Node>{} : bag<Node>{cur2})
                           == src1.toBagC()
                              + src2.toBagC();
        @*/
        while (cur1 != null) {
            /// ghost snapshots of the bags at the start of this iteration
            /*@ ghost bag<Node> bd = dst.toBagP(); @*/
            /*@ ghost bag<Node> bs1 = src1.toBagC(); @*/
            /*@ ghost bag<Node> bs2 = src2.toBagC(); @*/
            /*@ ghost Node oldCur1 = cur1; @*/
            /*@ assert bd + bag<Node>{oldCur1} == bs1 + bs2; @*/
            dst.append(cur1);
            if (src1.hasNext()) {
                cur1 = src1.getNext();
            } else {
                cur1 = null;
            }
            /// bag-arithmetic hint (reordering of summands) for the prover
            /*@ 
                assert bd + bag<Node>{oldCur1}
                          + (cur1==null ? bag<Node>{} : bag<Node>{cur1}) 
                        == bs1 + (cur1==null ? bag<Node>{} : bag<Node>{cur1}) + bs2;
                @*/
        }

        /// phase 3: now src1 is definitely empty, iterate src2 if necessary
        /*@
            loop_invariant Perm(src1, 1\2) ** Perm(src2, 1\2) ** Perm(dst, 1\2);
            loop_invariant src1 != null ** src2 != null ** dst != null;
            loop_invariant src1.consumer() ** src2.consumer() ** dst.producer();
            /// src1 is definitely done
            loop_invariant src1.done();
            loop_invariant \unfolding dst.producer() \in !dst.finalised;
            loop_invariant cur2 != null ==> Node.node_perm(cur2, write);
            loop_invariant cur2 != null ==> \unfolding src2.consumer() 
                                                \in src2.readHead>0 
                                                    ** cur2.key==src2.keysC[src2.readHead-1];
            loop_invariant cur2 == null ==> src2.done();
            loop_invariant \unfolding dst.producer() \in 
                            cur2 != null ==> dst.maxKey <= cur2.key;
            loop_invariant dst.toBagP() 
                              + (cur2==null ? bag<Node>{} : bag<Node>{cur2})
                           == src1.toBagC()
                              + src2.toBagC();
        @*/
        while (cur2 != null) {
            dst.append(cur2);
            if (src2.hasNext()) {
                cur2 = src2.getNext();
            } else {
                cur2 = null;
            }
        }

        /// mark destination as complete
        dst.finalise();

        /*@ fold post_join(write); @*/
    }


    /// Resources copied from VerCorsThread and adapted:
    ///  - pre_fork: two unread, empty sources and an empty, non-finalised destination;
    ///  - post_join: both sources done, dst finalised and containing exactly the
    ///    nodes consumed from both sources.
    /// NOTE(review): join_token is defined as `true`, so the specification does
    /// not force join() to be preceded by start(); post_join could be obtained
    /// without running the thread. Consider making it an abstract resource.
    /*@ 
        public resource join_token(frac p)
            = true; 
        public resource pre_fork(frac p)
            = Perm(src1, 1\2) ** Perm(src2, 1\2) ** Perm(dst, 1\2) ** Perm(Integer.MIN_VALUE, read)
              ** src1 != null ** src2 != null ** dst != null
              ** src1.consumer() ** src2.consumer() ** dst.producer()
              ** (\unfolding src1.consumer() \in src1.readHead == 0)
              ** (\unfolding src2.consumer() \in src2.readHead == 0)
              ** (\unfolding dst.producer() \in !dst.finalised ** dst.maxKey == Integer.MIN_VALUE)
              ** src1.toBagC() == bag<Node>{} ** src2.toBagC() == bag<Node>{} 
              ** dst.toBagP() == bag<Node>{}
              ; 
        public resource post_join(frac p)
            = Perm(src1, 1\2) ** Perm(src2, 1\2) ** Perm(dst, 1\2)
              ** src1 != null ** src2 != null ** dst != null
              ** src1.consumer() ** src2.consumer() ** dst.producer()
              ** src1.done()
              ** src2.done()
              ** (\unfolding dst.producer() \in dst.finalised)
              ** dst.toBagP() == src1.toBagC() + src2.toBagC()
              ; 
    @*/


    /// Starts the thread; contract only (no body).
    /// Unlike VerCorsThread, the `given frac p` parameter is commented out.
    /*@
        /// given frac p;
        requires pre_fork(write);
        ensures  join_token(write);
    @*/
    public final void start();

    /// Waits for the thread; contract only (no body). Yields post_join.
    /*@
        /// given frac p;
        requires join_token(write);
        ensures  post_join(write);
    @*/
    public final void join();

}

/// Reference model of a thread whose behaviour VerCors can verify.
/// It is only kept as a template: the thread classes in this file are themselves
/// `final` and carry their own adapted copies of these resources and contracts
/// rather than extending this class.
final class VerCorsThread {

    /*@
    /// token handed out by start() and required by join()
    public resource join_token(frac p) = true;

    /// resources transferred to the thread when it is started
    public resource pre_fork(frac p) = true;

    /// resources transferred back to the caller of join()
    public resource post_join(frac p) = true;
    @*/

    /// The thread body; contract only (no implementation).
    /*@
        requires pre_fork(write);
        ensures post_join(write);
    @*/
    public void run();

    /// Starts the thread; contract only (no implementation).
    /*@
        given frac p;
        requires pre_fork(write);
        ensures join_token(write);
    @*/
    public final void start();

    /// Waits for the thread; contract only (no implementation).
    /*@
        given frac p;
        requires join_token(write);
        ensures post_join(write);
    @*/
    public final void join();

}
