# ################################################
#   
#   connectivityOperator.py - operators that support network connectivity calculations 
#
#    --------
#
#    Copyright (c) 2016 - N.J. Hardebol - N.J.Hardebol@tudelft.nl  - (a)
#     (a) Delft University of Technology, Department of Geotechnology,  Stevinweg 1, 2628 CN Delft, the Netherlands
#
#   ---------
# 
#   This program is free software: you can redistribute it and/or modify
#   it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see http://www.gnu.org/licenses/.
#
# #############################################

"""The connectivity operations module provides class(es) and (member) functions for window cursor operations:
    place a window within the digitizing surface, walk in a given direction (rows and columns) with a given step size, and perform and collect (spatial/statistical) calculations.
"""

import os , sys
import math
import collections
import itertools

import logging

#import numpy
#import shapely
#from shapely import geos
import qgis.core as QgsCore

# Module logger: a child of the "DigiFractLogger" logger, so (with default
# propagation) its records are handled by that parent logger's handlers.
logger = logging.getLogger("DigiFractLogger.connectivityOperator")

# #####################################################################
#   The connectivity operators for fracture-network cluster calculations
# #####################################################################
class baseConnectivityOperator (object):
    """Base class for connectivity operations, i.e. the operator.

    Groups the fractures (features of a lineament layer) into connected
    clusters and checks whether those clusters link an injector and a
    producer well geometry.

    ``self.clusters`` is a list of clusters; each cluster is a list of
    fracture indices, where an index is the position of a feature in
    ``list(self.pLineaments.getFeatures())``.

    ``set_precision()`` must be called before any cluster or well check: the
    precision is used as buffer distance when testing for intersections.
    """

    def __init__ (self , _pLineamentLyr=None, injector=None,  producer=None, _mode='c' ):
        """Start with no clusters and an empty spatial index.

        :param _pLineamentLyr: optional lineament layer (QgsVectorLayer).
        :param injector: optional injector well geometry used by check2wells().
        :param producer: optional producer well geometry used by check2wells().
        :param _mode: accepted for interface compatibility; unused here.
        """
        self.clusters = list()
        self.setSpatialIndex()

        # Only store wells that were actually given. Subclasses may re-run this
        # initializer (e.g. from a set_mode method), and wells set earlier
        # must not be wiped.
        if injector is not None:
            self.injector = injector
        if producer is not None:
            self.producer = producer

        if _pLineamentLyr:
            self.set_pLineaments(_pLineamentLyr)

    def setSpatialIndex(self):
        """Replace the spatial index with a new, empty QgsSpatialIndex."""
        self.spatialIndex = QgsCore.QgsSpatialIndex()

    def set_pLineaments(self,  _pLineamentLyr):
        """Set a 'pointer' to the lineament layer (QgsVectorLayer).

        The layer is held outside of (and not owned by) the connectivity
        operator. It may hold a single lineament, and more can be added later.

        :returns: True
        """
        self.pLineaments = _pLineamentLyr
        return True

    def add_feat2spatialindex(self,  _feat):
        """Add a feature to the QgsSpatialIndex.

        :returns: True
        """
        self.spatialIndex.insertFeature( _feat )
        return True

    def set_precision(self,  _precision):
        """Set the buffer distance used when testing geometries for intersection."""
        self.precision = _precision

    def checkSet4clusters(self,  fracSet=None ):
        """Run checkFrac2clusters() once for every item in ``fracSet``.

        :param fracSet: an iterable of features or a QgsVectorLayer. When falsy,
            all features of self.pLineaments are used.
        :returns: True

        NOTE(review): only the number of items in ``fracSet`` matters here.
        checkFrac2clusters() looks each index up in self.pLineaments, so
        ``fracSet`` is presumably meant to be (a prefix of) that layer. Confirm.
        """
        if not fracSet:
            fracSet = self.pLineaments.getFeatures()
        elif isinstance(fracSet, QgsCore.QgsVectorLayer):
            fracSet = fracSet.getFeatures()

        for fracIdx2Check, _frac in enumerate(fracSet):
            self.checkFrac2clusters( fracIdx2Check )

        return True

    def checkFrac2clusters_fast(self,  fracIdx2Check ):
        """Merge fracture ``fracIdx2Check`` into the clusters it intersects.

        The spatial index (R-tree) is queried for candidates whose bounding box
        meets the buffered fracture. Only those candidates get an exact
        intersection test. All clusters containing an intersecting fracture
        are merged with the new fracture into one cluster, which becomes the
        first entry of self.clusters. Untouched clusters keep their relative
        order after it.

        NOTE(review): QgsSpatialIndex returns feature ids, which are used here
        as positions in getFeatures(). This only holds when the layer's
        feature ids are 0..n-1 in iteration order. Confirm for the layer
        provider in use.

        :returns: True
        """
        featList = list(self.pLineaments.getFeatures())

        # Start by assuming the fracture is a cluster of its own.
        self.clusters.append([fracIdx2Check])
        currClusterIdx = len(self.clusters) - 1
        geom2check = featList[fracIdx2Check].geometry().buffer(self.precision, 4)

        candidateIdxs = self.spatialIndex.intersects(geom2check.boundingBox())
        intersectFracIdxs = set(idx for idx in candidateIdxs
                                if geom2check.intersects(featList[idx].geometry()))

        # isdisjoint() (not any() over the intersection) so that fracture
        # index 0 counts as a hit. The new fracture's own cluster always
        # joins, even when it is not (yet) in the spatial index.
        toJoin = list()
        toKeep = list()
        for icluster, fracIdxs in enumerate(self.clusters):
            if icluster == currClusterIdx or not intersectFracIdxs.isdisjoint(fracIdxs):
                toJoin.append(fracIdxs)
            else:
                toKeep.append(fracIdxs)

        self.clusters = [list(itertools.chain.from_iterable(toJoin))] + toKeep
        return True

    def checkFrac2clusters(self,  fracIdx2Check ):
        """Determine which cluster fracture ``fracIdx2Check`` belongs to.

        The fracture starts as a new cluster at the end of self.clusters. The
        pre-existing clusters are then walked backwards. Whenever one of their
        fractures intersects the buffered new fracture, the cluster currently
        holding the new fracture is appended to that cluster and removed.
        Clusters bridged by the new fracture are merged transitively this way.

        :returns: True
        """
        featList = list(self.pLineaments.getFeatures())

        self.clusters.append([fracIdx2Check])   #-- start as a cluster of its own
        currClusterIdx = len(self.clusters) - 1

        geom2check = featList[fracIdx2Check].geometry().buffer(self.precision, 5)

        logger.info(" # connected clusters: %s %s#", len(self.clusters), self.clusters)

        # Skip the fracture's own (last) cluster. The buffered geometry always
        # intersects itself, which would merge the cluster into itself.
        for i_cluster in range(currClusterIdx - 1, -1, -1):
            for i_frac in self.clusters[i_cluster]:
                if geom2check.intersects(featList[i_frac].geometry()):
                    # i_cluster < currClusterIdx, so deleting the latter does
                    # not shift the cluster we just merged into.
                    self.clusters[i_cluster].extend(self.clusters[currClusterIdx])
                    del self.clusters[currClusterIdx]
                    currClusterIdx = i_cluster
                    break

        logger.debug("established all clusters for current state of network with n_cluster: %s for n_fracture: %s" % ( len(self.clusters),  self.pLineaments.featureCount() ) )

        return True

    def assign_clusterIdx(self,  fieldIdx =2):
        """Write each fracture's cluster number into attribute field ``fieldIdx``.

        NOTE(review): changeAttributeValue() takes a feature id, while cluster
        entries are positions in getFeatures(). These coincide only when the
        feature ids are 0..n-1 in iteration order. Confirm.

        :returns: True when the edits were committed. Otherwise False, and the
            commit errors are logged.
        """
        self.pLineaments.startEditing()
        for icluster,  cluster in enumerate( self.clusters ):
            for fracIdx in cluster:
                self.pLineaments.changeAttributeValue( fracIdx,  fieldIdx,  icluster)
        committed = self.pLineaments.commitChanges()
        if not committed:
            logger.error("assign_clusterIdx: committing cluster indices failed: %s",
                         self.pLineaments.commitErrors())
        return bool(committed)

    def check2wells(self,  injector=None,  producer=None):
        """Check which clusters touch the injector and/or the producer well.

        :param injector: injector geometry. Defaults to self.injector.
        :param producer: producer geometry. Defaults to self.producer.
        :returns: tuple (wellsBridged, conn2injector, conn2producer).
            conn2injector and conn2producer are sets of cluster indices.
            wellsBridged is True when at least one cluster touches both wells.
        :raises ValueError: when no injector or producer geometry is available.
        """
        if injector is None:
            injector = getattr(self, 'injector', None)
        if producer is None:
            producer = getattr(self, 'producer', None)
        if injector is None or producer is None:
            raise ValueError("check2wells() needs both an injector and a producer geometry")

        # Loop invariants: the well buffers and the feature list are built once.
        injectorZone = injector.buffer(self.precision, 5)
        producerZone = producer.buffer(self.precision, 5)
        featList = list(self.pLineaments.getFeatures())

        conn2injector = set()
        conn2producer = set()
        for i_cluster, cluster in enumerate(self.clusters):
            hitsInjector = False
            hitsProducer = False
            for i_frac in cluster:
                geom2check = featList[i_frac].geometry()
                if not hitsInjector and geom2check.intersects(injectorZone):
                    conn2injector.add(i_cluster)
                    hitsInjector = True
                if not hitsProducer and geom2check.intersects(producerZone):
                    conn2producer.add(i_cluster)
                    hitsProducer = True
                if hitsInjector and hitsProducer:
                    break

        wellsBridged = not conn2injector.isdisjoint(conn2producer)
        return wellsBridged,  conn2injector,  conn2producer


class connectivityOperatorSingleLyr (baseConnectivityOperator):
    """Connectivity operator working on a single lineament layer.

    Adds a ``mode`` attribute and setters for the injector and producer well
    geometries on top of baseConnectivityOperator.
    """
    def __init__ (self , _pLineamentLyr=None, _mode='c',  injector=None,  producer=None ):
        """Initialise the operator.

        :param _pLineamentLyr: optional lineament layer (QgsVectorLayer).
        :param _mode: mode string stored as self.mode.
        :param injector: optional injector well geometry.
        :param producer: optional producer well geometry.
        """
        self.set_mode(_mode,  _pLineamentLyr)
        # Previously these arguments were silently dropped. They are stored
        # only after set_mode() so the base initializer cannot reset them.
        if injector is not None:
            self.set_injector(injector)
        if producer is not None:
            self.set_producer(producer)

    def set_mode(self, _mode,  _pLineamentLyr):
        """Store the mode and (re)run the base initializer for the layer.

        Re-running the base initializer resets the clusters and the spatial
        index.
        """
        self.mode = _mode
        baseConnectivityOperator.__init__( self,  _pLineamentLyr )

    def set_injector(self,  _geom):
        """Set the injector well geometry used by check2wells().

        :returns: True
        """
        self.injector = _geom
        return True

    def set_producer(self,  _geom):
        """Set the producer well geometry used by check2wells().

        :returns: True
        """
        self.producer = _geom
        return True
        

