import numpy as np
import networkx as nx
from sklearn.neighbors import NearestNeighbors


def persistence_graph_cluster(k_dist, k_inds, max_clusters, makeplots=0):
    """Clustering of graph nodes using persistence to decide clusters.

    Builds an undirected graph whose edges are the (point, neighbor) pairs
    given by `k_inds`, ordered by increasing weight from `k_dist`.  Keeping
    only the lightest `m` edges and counting connected components gives a
    "persistence curve" (number of clusters vs. number of edges).  The
    transition points of that curve are located by bisection.  Clusters that
    survive over a long range of edge counts are considered the correct ones.

    Parameters
    ----------
    k_dist : NumPy array
        An N by k matrix of weighted distances (of N points to k nearest
        neighbors).
    k_inds : NumPy array
        An N by k matrix of indices of the k nearest neighbors of each of
        N points.
    max_clusters : integer
        Maximum number of clusters to consider. So the number of clusters
        discovered can be at most `max_clusters`.
    makeplots : integer, optional
        Non-negative figure number to plot persistence. If no persistence
        plot is needed, then it should be set to 0.

    Returns
    -------
    all_clusters : NumPy array
        Integer array with the number of connected components found at each
        entry of `all_midpoints`.
    all_labels : NumPy array
        Array of shape (N, len(all_transitions) - 1).  Column i holds the
        component label (1, 2, ...) of every node for the graph built from the
        first `all_midpoints[i]` edges.  Points that appear in no edge keep
        label 0.
    all_midpoints : NumPy array
        Edge counts halfway between consecutive transition points; the graph
        is clustered at each of these.
    all_transitions : NumPy array
        Array of shape (n, 1) with the edge counts at which the number of
        clusters changes.  Entry i is where the graph goes from having at
        least `i + min_clusters + 1` components to fewer.  `min_clusters` is
        the smallest recorded cluster count, which is 1 in practice.
    sorted_dist : NumPy array
        Sorted list of weights of unique edges corresponding to edge numbers
        in `all_midpoints` and `all_transitions`.

    Raises
    ------
    ValueError
        If even the graph with no edges has fewer than `max_clusters`
        components.  Without this check the bracketing search would never
        terminate.
    """

    ### Finding all edges
    N, k = np.shape(k_dist)
    # Transposed flatten: a flat index unravels to (neighbor rank, point index).
    flat_dist = np.asarray(k_dist).T.flatten()
    dii = np.argsort(flat_dist)
    sorted_dist = np.sort(flat_dist)

    _, isc = np.unravel_index(dii, [k, N])
    jsc = np.asarray(k_inds).T.flatten()[dii]
    # Store each edge as (smaller id, larger id) so that i->j and j->i coincide.
    alledges = np.sort(np.column_stack((isc, jsc)), axis=1)

    ### Keep only the first (lightest) occurrence of every undirected edge
    _, ia = np.unique(alledges, axis=0, return_index=True)
    ia = np.sort(ia)
    alledges = alledges[ia, :]      # unique edges in order of weight
    sorted_dist = sorted_dist[ia]   # corresponding weights

    ### Samples of the persistence curve: (number of clusters, number of edges).
    # The full edge set is assumed to yield a single cluster.
    numEdges = alledges.shape[0]
    numClusSteps = [1]
    numEdgesSteps = [numEdges]
    allnodes = np.unique(alledges)

    ### Graph set up: start from the lightest half of the edges
    g = nx.Graph()
    g.add_nodes_from(allnodes)
    n0 = numEdges // 2
    g.add_edges_from(alledges[0:n0, 0:2])

    ### Persistence steps: halve the edge count until enough clusters appear,
    ### giving brackets for the bisection below.
    numClus = 1
    while numClus < max_clusters:
        numClus = nx.number_connected_components(g)
        numClusSteps.append(numClus)
        numEdgesSteps.append(n0)
        if numClus < max_clusters and n0 == 0:
            # No edges left to remove: max_clusters can never be reached.
            raise ValueError(
                "cannot find %d clusters: the graph without edges has only "
                "%d components" % (max_clusters, numClus))

        numEdges = n0 // 2
        g.remove_edges_from(map(tuple, alledges[numEdges:n0, 0:2]))
        n0 = numEdges

    numClusSteps = np.array(numClusSteps).astype(float)
    numEdgesSteps = np.array(numEdgesSteps).astype(float)

    min_clusters = np.min(numClusSteps)

    ### Persistence steps: find all the transition points via bisection
    n_transitions = int(max_clusters - min_clusters)
    all_transitions = np.zeros([n_transitions, 1])

    for i in range(n_transitions):
        clusterNumber = i + min_clusters + 1
        # Largest sampled edge count that still has >= clusterNumber clusters,
        # and smallest sampled edge count that has fewer.
        lb = int(np.max(numEdgesSteps[numClusSteps >= clusterNumber]))
        ub = int(np.min(numEdgesSteps[numClusSteps < clusterNumber]))
        all_transitions[i, 0] = (lb + ub) // 2

        # "Upper" graph: always contains exactly the first `ub` edges.
        # All nodes are added so isolated points count as components.
        gu = nx.Graph()
        gu.add_nodes_from(allnodes)
        gu.add_edges_from(alledges[0:ub, 0:2])

        while ub - lb > 1:
            # Try the midpoint: drop the edges beyond it and count components.
            numEdges = int(all_transitions[i, 0])
            removed = alledges[numEdges:ub, 0:2]
            gu.remove_edges_from(map(tuple, removed))
            numClus = nx.number_connected_components(gu)

            # Record the new sample, keeping the curve sorted by edge count.
            numClusSteps = np.append(numClusSteps, numClus)
            numEdgesSteps = np.append(numEdgesSteps, numEdges)
            sinds = np.argsort(numEdgesSteps)
            numEdgesSteps = numEdgesSteps[sinds]
            numClusSteps = numClusSteps[sinds]

            if numClus >= clusterNumber:
                # Went too far: restore the edges and raise the lower bound.
                gu.add_edges_from(removed)
                lb = numEdges
            else:
                # The midpoint becomes the new upper bound.
                ub = numEdges

            all_transitions[i, 0] = (lb + ub) // 2

            # Making persistence plot (redrawn after every bisection step)
            if makeplots:
                import matplotlib.pyplot as plt
                plt.figure(makeplots)
                plt.clf()
                plt.semilogx(numEdgesSteps, numClusSteps)
                plt.semilogx(all_transitions[0:i + 1, 0],
                             np.arange(0, i + 1) + min_clusters + 1,
                             marker='o', color='r', linestyle='none')
                plt.xlim(np.min(numEdgesSteps), 1.5 * np.max(all_transitions))
                plt.show()

    ### Set up for finding midpoints and producing the clusters from these pts
    if min_clusters > 1:
        # Treat the full edge set as an extra leading transition.
        all_transitions = np.vstack(([[np.max(numEdgesSteps)]], all_transitions))
    # One persistence region lies between each pair of consecutive transitions.
    steps = max(all_transitions.shape[0] - 1, 0)

    all_labels = np.zeros((N, steps))
    all_clusters = np.zeros((steps,))
    all_midpoints = np.zeros((steps,))

    ### Cluster at the midpoints of the persistence regions
    for i in range(steps):
        lo = all_transitions[i, 0]
        hi = all_transitions[i + 1, 0]
        all_midpoints[i] = np.floor((lo + hi) / 2.)

        g = nx.Graph()
        g.add_nodes_from(allnodes)
        g.add_edges_from(alledges[0:int(all_midpoints[i]), :])

        # The actual clusters: connected components, labelled from 1.
        components = list(nx.connected_components(g))
        for j, comp in enumerate(components):
            all_labels[list(comp), i] = j + 1

        all_clusters[i] = len(components)

    all_clusters = all_clusters.astype(int)

    return all_clusters, all_labels, all_midpoints, all_transitions, sorted_dist


def persistent_clustering(data, is_neighbor=16, knn_dist=100, max_clusters=5,
                          required_nclusters=0, metric='euclidean', metric_params=None,
                          algorithm='auto', leaf_size=30, p=None, n_jobs=1, makeplots=0):
    """Finds the persistent clusters of the points in the given data.

    Given the points in `data`, `persistent_clustering` constructs an adjacency
    graph of the points, and uses the connected components of the graph to cluster
    the points. The number of clusters is decided using the persistence plot.

    Parameters
    ----------
    data : Numpy array or tuple
        Either an N by d matrix of data points, or N by N matrix of pairwise
        distances, or a tuple of sorted distances and corresponding indices.
    is_neighbor : integer, optional
        Integer indicating what a neighbor will be (4 <= is_neighbor <= 16).
        Column `is_neighbor` of the sorted distances is used as each point's
        local scale, so it must be smaller than the number of neighbor columns.
    knn_dist : integer, optional
        Integer used for knn distance finding (is_neighbor <= knn_dist <= 0.1*N).
        Only used when distances are computed here, i.e. when `metric` is
        neither 'knn' nor 'precomputed'.
    max_clusters : integer, optional
        Maximum number of clusters to consider (should be greater or equal to 2).
    required_nclusters : integer, optional
        Non-negative integer indicating the required/desired number of clusters
        to look for. If 0, then return array `required_label` will be empty,
        but it will still return the suggested number of clusters.
    metric : str, optional
        The metric to use when calculating distance between data points.
        If metric is 'precomputed', data is assumed to be a distance matrix.
        If metric is 'knn', data needs to be a tuple with the first element
        being the N by k matrix of sorted distances of the first k nearest
        neighbors of all the points, and the second element the matching
        N by k matrix of neighbor indices.
    metric_params : dict, optional
        Optional additional keyword arguments for the metric function.
    algorithm : str, optional
        String indicating which nearest neighbor distance finding algorithm
        to use. The ones available are: 'auto', 'ball_tree', 'kd_tree', 'brute'.
    leaf_size : integer, optional
        Leaf size passed to BallTree or cKDTree. This can affect the speed of
        the construction and query, as well as the memory required to store the
        tree. The optimal value depends on the nature of the problem.
    p : float, optional
        The power of the Minkowski metric used to calculate distance between points.
    n_jobs : integer, optional
        Optional integer giving number of parallel jobs to run. If it is set
        to -1, then the number of jobs is set to the number of CPU cores.
    makeplots : integer, optional
        Non-negative integer indicating the figure number for the persistence
        plot. If 0, no plot.

    Returns
    -------
    required_label : Numpy array
        Length-N integer array with labels of all points according to the
        required number of clusters given. Empty if required_nclusters = 0.
        NOTE(review): if no clustering with exactly `required_nclusters`
        clusters was found, the first column of `all_labels` is used.
    persistent_label : Numpy array
        Length-N integer array with labels of all points according to the
        suggested number of clusters from the persistence algorithm.
    persistent_nclusters : integer
        Suggested number of clusters inherent to the data, based on the
        persistence principle.
    all_clusters : Numpy array
        Integer array of the number of clusters found at each midpoint.
    all_labels : Numpy array
        N-by-M array storing the cluster number of each point for each of the
        M persistence regions.  M is `max_clusters - 2` in the usual case
        where the coarsest clustering has a single cluster.
    """

    ### Loading data: ds = sorted neighbor distances, di = neighbor indices
    if metric == 'knn':
        ds, di = data
        ds = np.asarray(ds)
        di = np.asarray(di)
    elif metric == 'precomputed':
        # Sort each row of the distance matrix, keeping the column indices.
        di = np.argsort(data, axis=1)
        ds = np.take_along_axis(data, di, axis=1)
    else:
        neighbors_model = NearestNeighbors(n_neighbors=knn_dist, algorithm=algorithm,
                                           leaf_size=leaf_size, metric=metric,
                                           metric_params=metric_params, p=p,
                                           n_jobs=n_jobs)
        nbrs = neighbors_model.fit(data)
        ds, di = nbrs.kneighbors(data)

    # Scale data: divide d(i, j) by sqrt(s_i * s_j), where s_i is point i's
    # distance in column `is_neighbor` (its local scale).
    dst = ds[:, is_neighbor]
    ds = ds / np.sqrt(dst[:, np.newaxis] * dst[di])

    ### Clustering Algorithm
    all_clusters, all_labels, _, tpts, _ = \
        persistence_graph_cluster(ds, di, max_clusters, makeplots)

    ### Getting the right cluster
    # The widest gap between consecutive transitions marks the most persistent
    # region.  Region j lies between the transitions for j + 2 and j + 3
    # clusters, so it holds j + 2 clusters.
    persistent_nclusters = np.argmax(-np.diff(np.squeeze(tpts, 1))) + 2
    persistent_col = np.argmax(all_clusters == persistent_nclusters)
    persistent_label = all_labels[:, persistent_col].astype(int)
    print('Most persistence found with number of clusters = ', persistent_nclusters)

    if required_nclusters > 0:
        rc = np.argmax(all_clusters == required_nclusters)
        required_label = all_labels[:, rc].astype(int)
    else:
        required_label = np.array([])

    return required_label, persistent_label, persistent_nclusters, all_clusters, all_labels
