import time
import numpy as np
from scipy import stats
from sklearn.neighbors import NearestNeighbors
from sklearn.decomposition import PCA, KernelPCA
from .CkNN_clustering import persistence_graph_cluster


def halton(base_p, n_quasirand):
    """Generate the leading terms of the Halton sequence in a given base.

    Term ``j`` (counting from zero) is obtained by writing the integer
    ``j + 1`` in base ``base_p`` and mirroring its digits about the radix
    point. The construction follows Program 9.1 (Quasi-random number
    generator - Halton sequence in base p) of Tim Sauer's Numerical Analysis.

    Parameters
    ----------
    base_p : integer
        Integer prime number used as the base.
    n_quasirand : integer
        How many terms of the sequence to produce.

    Returns
    -------
    quasirand_seq : NumPy array
        (n_quasirand,)-shaped array of quasi-random numbers in the interval
        [0,1].

    Examples
    ---------
    >>> rand_seq = halton(2,100)
    """
    # How many base-p digits the largest index (n_quasirand) can occupy.
    n_digits = int(np.ceil(np.log(n_quasirand + 1) / np.log(base_p)))

    # The j-th term is derived from the integer j + 1.
    indices = np.arange(1, n_quasirand + 1)
    quasirand_seq = np.zeros(n_quasirand,)

    # Peel off one digit position at a time, least significant first, and
    # add it back with the mirrored weight base_p**-(position + 1).
    for place in range(n_digits):
        digit = (indices // base_p ** place) % base_p
        quasirand_seq = quasirand_seq + digit * base_p ** (-(place + 1))

    return quasirand_seq


def uniform_rejection_sample(given_density, n_samples):
    """Rejection sampling that thins points in inverse proportion to density.

    Each point is accepted with probability ``min(density) / density[point]``,
    so points in dense regions are kept less often than points in sparse
    regions. Points are visited in a random order; up to 11 passes are made
    over the points that have not been accepted yet, stopping as soon as
    `n_samples` points have been collected. A point is never returned twice.

    Parameters
    ----------
    given_density : array_like
        An array shaped (# of points,1) or (# of points,) containing the
        density.
    n_samples : integer
        Number of samples to collect. Fewer indices are returned when not
        enough points are accepted within the allowed passes, or when there
        are fewer than `n_samples` points in total.

    Returns
    -------
    uniform_sample_indices : NumPy array
        1-D integer array of distinct indices into `given_density`, listed in
        the order in which they were accepted.
    """
    density = np.ravel(np.asarray(given_density, dtype=float))
    n_points = density.size
    # Fractional requests are rounded up, matching a `count < n_samples` loop.
    n_target = int(np.ceil(n_samples))
    if n_points == 0 or n_target <= 0:
        return np.zeros(0, dtype=int)

    # Acceptance probability of every point. A zero minimum density yields
    # probabilities of 0 (or NaN), i.e. nothing is ever accepted.
    with np.errstate(divide='ignore', invalid='ignore'):
        accept_prob = np.min(density) / density

    visit_order = np.random.permutation(n_points)
    # Explicit bookkeeping of which points are still selectable; this is what
    # guarantees that no index is emitted more than once across passes.
    available = np.ones(n_points, dtype=bool)
    accepted_chunks = []
    n_accepted = 0
    max_passes = 11

    for _ in range(max_passes):
        if n_accepted >= n_target:
            break
        candidates = visit_order[available[visit_order]]
        if candidates.size == 0:
            break  # every point has already been taken
        draws = np.random.rand(candidates.size)
        accepted = candidates[draws < accept_prob[candidates]]
        # Keep only as many as are still needed, in visiting order.
        accepted = accepted[:n_target - n_accepted]
        available[accepted] = False
        accepted_chunks.append(accepted)
        n_accepted += accepted.size

    if not accepted_chunks:
        return np.zeros(0, dtype=int)

    uniform_sample_indices = np.concatenate(accepted_chunks).astype(int)

    return uniform_sample_indices


def iterative_cut_cluster(density, n_samples, data, n_clusters, max_clusters,
                          n_quasirand, metric='euclidean'):
    """Iterative thresholding given density to discover the clusters.

    `iterative_cut_cluster` is the iterative routine that (1) keeps only the
    points whose density is at or above a quantile-based threshold, (2) draws
    a subsample of the surviving points and clusters it with
    `persistence_graph_cluster`, and then (3) checks if the clustering is
    meaningful in the sense that the smallest of the `n_clusters` clusters
    holds more than 1% of the sampled points. The loop ends after 30
    iterations or as soon as `tol` is no longer greater than 0.10.

    Parameters
    ----------
    density : array_like
        An array of shape (number of points, 1) containing the density
        estimation (the second axis is squeezed with ``axis=1``).
    n_samples : integer
        Number of sample points to be picked from each collection of points.
        It is lowered to the number of surviving points whenever fewer than
        the originally requested number survive the cut.
    data : array_like
        Either N by d array of data points, or N by N array of pairwise distances.
    n_clusters : integer
        Number of clusters to be found.
    max_clusters : integer
        Upper bound for number of clusters to be found; forwarded to
        `persistence_graph_cluster`.
    n_quasirand : integer
        Number of quasirandom (Halton, base 2) threshold fractions to
        generate.
    metric : str, optional
        The metric to use when calculating distance between data points.
        Its default value is 'euclidean'. If metric is set to 'precomputed',
        data is assumed to be a distance matrix, and must be a square array.

    Returns
    -------
    sample_labels : NumPy array
        The columns of the label array returned by
        `persistence_graph_cluster` in the final iteration for which the
        reported number of clusters equals `n_clusters`.
    cut_summary : NumPy array
        A float-valued array of three rows, each of which has entries as many
        as the number of iterations, ordered by threshold fraction. Row 1
        contains the threshold fractions. Row 2 contains the fraction of
        sampled points in the smallest cluster given the thresholds in Row 1.
        Row 3 contains the corresponding persistence of clusters = n_clusters.
    threshold : float
        The density value used as threshold in the final iteration.
    sample : NumPy array
        The sampled rows of `data` from the final iteration (for
        'precomputed', the sampled square sub-matrix of distances).
    sample_inds : NumPy array
        Array containing the indices (into `data`) of the samples.
    """


    ##### Set up
    sorted_density = np.sort( density, axis=0 )
    lq = len(density)
    ite = 0 # position inside the current list of threshold fractions `prs`
    mit = 0 # total number of iterations performed so far
    pnum = n_samples # originally requested sample size (n_samples may shrink)

    N = data.shape[0]

    ##### Storage of peaks on # of points in smaller cluster + enough persistence
    # b is one minus n_samples/N rounded up to a whole percent
    b = 1 - (1/100.) * np.ceil(100*n_samples/float(N))
    ml = n_quasirand + 2
    # Candidate threshold fractions: 0.01, b, then the Halton points
    prs = np.zeros((ml,))
    prs[0:2] = [0.01, b]
    prs[2:n_quasirand+2] = halton(2,n_quasirand)
    # ck: per-iteration records in the order they were computed
    # cut_summary: the same records sorted by threshold fraction
    # NOTE(review): both have ml = n_quasirand + 2 columns while the loop
    # below may run up to 30 times -- confirm n_quasirand >= 28 is guaranteed.
    ck = np.zeros((3,ml))
    cut_summary = np.zeros((3,ml))
    tol = 10
    delt = 0.09

    ##### Bracket
    ### Compute # in smallest cluster + persistence
    while (mit < 30) & (tol > 0.10):
        # Density value at the prs[ite] quantile; keep points at or above it
        threshold = sorted_density[ int(np.ceil(prs[ite]*lq)) ]
        q1 = np.where( np.squeeze( density >= threshold, axis=1 ) )[0]
        q2 = density[q1]

        # Getting a random subcollection via rejection sampling
        lq1 = len(q2);
        if lq1 < pnum:
            # Too few survivors: use all of them
            n_samples = lq1
            sample_inds = np.array(range(0,lq1))
        else:
            sample_inds = uniform_rejection_sample(q2,n_samples)

        # Translate positions within the survivors back to indices into data
        sample_inds = q1[ sample_inds.astype(int) ]
        # Neighbours per point for the non-precomputed branch (25% of sample)
        ll = int( np.ceil(0.25*n_samples) )

        if metric=='precomputed':
            ### Chosing the right distances
            sample = data[sample_inds,:][:,sample_inds]
            # Sort every row of the sub-distance matrix; `grid` holds the row
            # index of each entry so that Xsd[i,j] = sample[i, inds[i,j]]
            inds = np.argsort(sample,axis=1)
            grid = np.meshgrid(range(0,n_samples),range(0,n_samples))[1]
            Xsd = sample[ grid, inds ]
        else:
            sample = data[sample_inds,:]

            # Getting the distances to do the clustering
            nbrs = NearestNeighbors( n_neighbors=ll, metric = metric,
                                     algorithm='ball_tree' ).fit(sample)
            Xsd,inds = nbrs.kneighbors(sample)

        # clustering + persistance + # in smallest cluster
        numClusters,sample_labels,_,tpts,_ = \
            persistence_graph_cluster( Xsd, inds, max_clusters, makeplots=0 )

        # Persistence
        ck[0,mit] = prs[ite] # storing threshold value
        # Storing persistence: difference of two consecutive entries of tpts,
        # divided by the number of point pairs in the sample
        # (presumably tpts holds the values at which the cluster count
        # changes -- confirm against persistence_graph_cluster)
        ck[2,mit] = (tpts[n_clusters-2] - tpts[n_clusters-1]) / (n_samples*(n_samples-1)//2)

        # Pick the labelling with exactly n_clusters clusters; when several
        # columns qualify only the first one is used
        tt = numClusters==n_clusters
        if np.sum(tt)>1:
            indxt = np.array( range(len(tt==1)) )[ tt==1 ]
            temp2 = sample_labels[:,indxt[0]]
        else:
            temp2 = sample_labels[:,numClusters==n_clusters]

        # Calculating # of points in smallest cluster (labels 1..n_clusters)
        s = np.zeros((n_clusters,))
        for j in range(0,n_clusters):
            s[j] = np.sum(temp2==(j+1))

        s = np.sort(s)
        # storing percentage of data in smallest cluster
        ck[1,mit] = s[0] / float(n_samples)

        ### Checking which threshold resulted in a pass/fail
        # Re-sort all records gathered so far by threshold fraction
        sind = np.argsort(ck[0,0:mit+1])
        cut_summary[0:3,0:mit+1] = ck[0:3,sind]
        # NOTE(review): the slice 0:mit leaves out the last sorted record
        # (0:mit+1 were just written) -- confirm this is intended.
        dpp = cut_summary[1,0:mit] > 0.01 # Pass/fail

        ### The bracketing
        if (np.sum(dpp>0)>0): #If there at least one that passes
            indx = np.array( range(len(dpp==1)) )[dpp==1]
            # NOTE(review): the second entry is read from row 1 (smallest
            # cluster fraction) while the first is read from row 0 (threshold
            # fraction); also indx[0]-1 is -1 when the first record passes.
            # Looks unintended -- confirm.
            idp = np.array( [ cut_summary[0,indx[0]-1], cut_summary[1,indx[0]] ] )
            delt = np.diff(idp)
            # New candidate fractions: Halton points rescaled by delt and
            # shifted by idp[0]
            prs = delt * halton(2,n_quasirand) + idp[0]
            # NOTE(review): ite is incremented below before the next use, so
            # prs[0] of the new list is never tried -- confirm.
            ite = 0

        # Check if any pass # of small
        if (ck[1,mit] > 0.01) :
            # The current cut passed: tol takes the bracket width, which ends
            # the loop once that width is no larger than 0.10
            tol = delt

        ### If none passed, go to the next quasi-random number
        ite = ite +1
        mit = mit + 1;

    # Keep only the labelling(s) with n_clusters clusters and the records of
    # the iterations that were actually run
    sample_labels = sample_labels[:, numClusters==n_clusters ]
    cut_summary = cut_summary[:,0:mit]

    return sample_labels, cut_summary, threshold, sample, sample_inds


def knn_classifier(reference_pts, reference_labels, query_pts, k, metric='euclidean'):
    """K-nearest neighbors classifier to classify the query points.

    `knn_classifier` assigns to every query point the most common label among
    its `k` nearest reference points.

    Parameters
    ----------
    reference_pts : NumPy array
        An array containing the coordinates of all the reference points, or
        if metric is set to 'precomputed', then an array of distances with
        shape (number of query points) by (number of reference points).
    reference_labels : NumPy array
        An array containing the corresponding class labels for the reference
        points, shaped (number of reference points,) or
        (number of reference points, 1).
    query_pts : NumPy array
        An array containing the coordinates of all query points. If metric
        is set to 'precomputed', `query_pts` is not used for the
        classification and a plain number (e.g. 0) may be passed.
    k : integer
        Number of nearest neighbors to be considered.
    metric : str, optional
        The metric to use when calculating distance between data points.
        Its default value is 'euclidean'. If metric is set to 'precomputed',
        then reference_pts must be an array of distances with shape
        (number of query points) by (number of reference points).

    Returns
    -------
    query_labels : NumPy array
        An array containing the predicted labels for the query points. Its
        shape is `reference_labels[inds].shape` with the neighbor axis
        removed: (number of query points,) for 1-D labels and
        (number of query points, 1) for column-vector labels. When
        `reference_pts` has a zero-length axis, an empty array is returned:
        shape (0, 1) if `query_pts` is a plain number, otherwise
        (number of query points, 0).
    """
    # No reference points: nothing to vote with, return an empty result
    if np.min(reference_pts.shape) == 0:
        # A plain number stands in for "no query points"
        if isinstance(query_pts, (int, float)):
            return np.zeros((0, 1))
        return np.zeros((query_pts.shape[0], 0))

    # Finding the k nearest neighbor reference points
    if metric == 'precomputed':
        # Rows are query points, columns reference points
        inds = np.argsort(reference_pts, axis=1)[:, 0:k]
    else:
        nbrs = NearestNeighbors(n_neighbors=k, metric=metric,
                                algorithm='ball_tree').fit(reference_pts)
        _, inds = nbrs.kneighbors(query_pts)

    # Majority vote over the neighbor axis (axis 1)
    neighbor_labels = np.asarray(reference_labels)[inds]
    query_labels = np.asarray(stats.mode(neighbor_labels, axis=1)[0])

    # Older SciPy releases keep the reduced axis with length one, newer ones
    # drop it by default; only squeeze when it is still there so the result
    # has the same shape under both.
    if query_labels.ndim == neighbor_labels.ndim:
        query_labels = np.squeeze(query_labels, axis=1)

    return query_labels


def block_knn_classify(reference_pts, reference_labels, query_pts,
                       metric='euclidean', block_size=5000):
    """Blockwise nearest-neighbor classification of the query points.

    `block_knn_classify` classifies the query points with `knn_classifier`.
    When there are fewer than `block_size` query points they are classified
    in a single call, using ``k = min(5, smallest count among the labels
    1, ..., numC - 1)`` where ``numC = max(label) - min(label) + 1``.
    Otherwise the query points are processed in consecutive blocks of at
    most `block_size` points with ``k = 5``.

    Parameters
    ----------
    reference_pts : NumPy array
        An array containing the coordinates of all the reference points, or
        if metric is set to 'precomputed', then an array of distances with
        shape (number of query points) by (number of reference points).
    reference_labels : NumPy array
        An array containing the corresponding class labels for the reference
        points
    query_pts : NumPy array
        An array containing the coordinates of all query points. If metric
        is set to 'precomputed', `query_pts` is only passed through to
        `knn_classifier` unchanged (a placeholder such as 0 is fine).
    metric : str, optional
        The metric to use when calculating distance between data points.
        Its default value is 'euclidean'. If metric is set to 'precomputed',
        then reference_pts must be an array of distances with shape
        (number of query points) by (number of reference points).
    block_size : integer, optional
        Maximum number of query points classified per call in the blockwise
        path. Its default value is 5000.

    Returns
    -------
    query_labels : NumPy array
        An array containing the predicted labels for the query points. In
        the single-call path this is exactly what `knn_classifier` returns;
        in the blockwise path it is a float array of shape
        (number of query points, 1).

    Raises
    ------
    ValueError
        If `block_size` is smaller than 1.
    """
    if block_size < 1:
        raise ValueError('block_size must be a positive integer')
    block_size = int(block_size)

    # With precomputed distances the query points are the rows of
    # reference_pts; otherwise they are the rows of query_pts.
    precomputed = metric == 'precomputed'
    n_queries = reference_pts.shape[0] if precomputed else query_pts.shape[0]

    if n_queries < block_size:
        # Cap k by the size of the smallest class among labels 1..numC-1
        numC = int(np.max(reference_labels) - np.min(reference_labels) + 1)
        mn = 5
        for label in range(1, numC):
            mn = int(np.min([mn, np.sum(reference_labels == label)]))

        return knn_classifier(reference_pts, reference_labels,
                              query_pts, mn, metric=metric)

    query_labels = np.zeros((n_queries, 1))

    # range() with a step covers the trailing partial block and never
    # produces an empty one when n_queries is a multiple of block_size.
    for start in range(0, n_queries, block_size):
        stop = min(start + block_size, n_queries)
        if precomputed:
            # Slice the distance rows belonging to this block of queries
            block_labels = knn_classifier(reference_pts[start:stop, :],
                                          reference_labels, query_pts,
                                          5, metric=metric)
        else:
            block_labels = knn_classifier(reference_pts, reference_labels,
                                          query_pts[start:stop, :],
                                          5, metric=metric)
        # The classifier may return (n,) or (n, 1); store as a column
        query_labels[start:stop] = np.reshape(block_labels, (stop - start, -1))

    return query_labels


## def cut_cluster_classify(data, n_clusters, n_samples, knn_for_density=30,
##                          dim_reduction_alg='PCA', n_components=2,
##                          metric='euclidean', metric_params=None, algorithm='auto',
##                          leaf_size=30, p=None, n_jobs=1,timeit=1):
##     """CCC takes an image and performs the Cut-Cluster-Classify algorithm to
##     segment the image

##     Parameters
##     ----------
##     data :
##         number of points by dimension matrix containing
##                          the data or squared matrix of distances
##     n_clusters      - integer dictating how many clusters to look for
##     n_samples       - integer dictating how many samples to collect
##     knn_for_density - integer number of nearest neighbors for sample density estimation
##     doPCA           - 1 if Principal Component Analysis (PCA) will be
##                          used to project down data (this is only
##                          recommended for high dimensional data);0 otherwise
##     n_pca           - number of principal components to use from PCA
##                          (only needed for doPCA = 1).
##     metric      - The metric to use when calculating distance between
##                      data points. If metric is "precomputed", data is assumed
##                      to be a distance matrix & must be square.
##     metric_params- Optional additional keyword arguments for the metric
##                      function.
##     algorithm   - string naming which nearest neighbor search
##                      algorithm to use. The ones available are: ['auto',
##                      'ball_tree', 'kd_tree', 'brute']
##     leaf_size   - integer determining leaf size passed to BallTree or
##                      cKDTree. This can affect the speed of the construction and
##                      query, as well as the memory required to store the tree.
##                      The optimal value depends on the nature of the problem.
##     p           - float giving the power of the Minkowski metric to be used
##                      to calculate distance between points.
##     n_jobs      - optional integer giving number of parallel jobs to run. If
##                      -1, then the number of jobs is set to the number of CPU
##                      cores.
##     timeit      - binary number (0 or 1) dictating whether or not to
##                      time the method

##     Returns
##     -------
##     all_labels        - number of points by 1 array containing the
##                                calculated classification of the data
##     Tally             - number of points by number of clusters matrix
##                                with the total times a vote was casted for each
##                                data point
##     all_projection    - number of points by n_pca (if doPCA = 0, n_pca
##                                = data dimension) matrix containing the PCA
##                                coordinates (if doPCA=0, these are the original
##                                coordinates)
##     sample_density    - number of samples by 1 array containing the
##                                sample density estimation
##     sample_projection - number of samples by n_pca (if doPCA = 0, n_pca
##                                = data dimension) matrix containing the
##                                PCA coordinates (if doPCA=0, these are the
##                                original coordinates) of the sample points
##     sample_labels     - number of points by 1 cluster id vector
##     cut_summary       - 3 by number of iterations array containing:
##                                row 1 - threshold, row 2 - percent of data in
##                                smallest cluster, row 3 - persistence of
##                                n_clusters
##     """

##     ######## Set up
##     lng,wd = data.shape

##     if timeit == 1: # if you need to time method
##         a = time.time()

##     ######## The first set of points (normalized)
##     IND = np.random.permutation(lng)[0:n_samples]

##     if metric == 'precomputed':
##         ### Note that since we're dealing with distances, there is no need for PCA projection
##         sample = data[IND,:][:,IND]
##         U1 = sample
##         n_pca=n_samples
##         sdis = np.sort(U1,axis=1)
##         sample_density= 1/(sdis[:,knn_for_density-1][:,np.newaxis])
##         del sdis
##     else:
##         T = data[IND,:]

##         mt = np.mean(T,axis=0)
##         sample = T- mt
##         stds = np.std(sample)
##         sample = sample/stds
##         del T,IND

##         ######## Base projections
##         if dim_reduction_alg == 'none':
##             U1 = sample
##             n_pca=wd
##         elif dim_reduction_alg == 'KernelPCA':
##             dim_algorithm = KernelPCA(kernel="rbf",n_components=n_components)
##             U1 = dim_algorithm.fit_transform(sample)
##         elif dim_reduction_alg=='PCA':
##             dim_algorithm = PCA(n_components)
##             U1 = dim_algorithm.fit_transform(sample)
##         else:
##             print('Currently only doing none, PCA, or KernelPCA')

##         ######## Calculating the density of sample set
##         neighbors_model = NearestNeighbors(n_neighbors=knn_for_density,
##                                            algorithm=algorithm,leaf_size=leaf_size,
##                                            metric=metric,metric_params=metric_params,
##                                            p=p,n_jobs=n_jobs)
##         nbrs = neighbors_model.fit(U1)
##         dis,_ = nbrs.kneighbors(U1)
##         dis = dis[:,knn_for_density-1];
##         sample_density = 1/dis[:,np.newaxis]

##         del dis, nbrs, knn_for_density

##     ######## Cutting and Clustering of sample set
##     clustNum1,cut_summary,thresh,sample_projection,rind = iterCutClus(sample_density,n_samples/2,U1[:,0:n_components],n_clusters,n_clusters+1,30,metric=metric)
##     sample = sample[rind,:]

##     ######## Consistent Labeling: Sorting cluster label by average pixel mean
##     mpixval = np.zeros((n_clusters,))
##     for i in range(0,n_clusters):
##         mpixval[i] = np.mean(sample[np.where(clustNum1==i+1)[0],:])

##     tind = np.argsort(mpixval)
##     sample_labels = np.zeros(clustNum1.shape)
##     for i in range(1,n_clusters):
##         sample_labels[np.where(clustNum1==tind[i]+1)] = i

##     del sample, mpixval, tind, clustNum1, U1

##     ######## Block-wise: the rest of the subimages
##     Tally = np.zeros((lng,n_clusters)).astype(int)
##     all_labels = np.zeros((lng,1)).astype(int)
##     all_projection = np.zeros((lng,n_components))

##     if wd>= 450:
##         block_size = 1250
##     else:
##         block_size = 4000

##     cs = lng/block_size

##     if metric == 'precomputed':
##         del sample_projection

##         sample_dis = data[IND[rind],:].T
##         for j in range(0,cs):
##             subimcor = sample_dis[block_size*j:block_size*(j+1),:]

##             ######## Note that since we're dealing with distances, there is no need for projections
##             ######## Classifying and plotting
##             all_labels[block_size*j:block_size*(j+1)] = block_knn_classify(subimcor,sample_labels,0,metric=metric)

##             ######## Tally
##             Tally[block_size*j:block_size*(j+1),all_labels[block_size*j:block_size*(j+1)]] = Tally[block_size*j:block_size*(j+1),all_labels[block_size*j:block_size*(j+1)]] + 1

##             del subimcor


##         subimcor = sample_dis[cs*block_size:lng,:]

##         ######## Note that since we're dealing with distances, there is no need for projections
##         ######## Classifying and plotting
##         all_labels[cs*block_size:lng] = block_knn_classify(subimcor,sample_labels,0,metric=metric)

##         ######## Tally
##         Tally[cs*block_size:lng,all_labels[cs*block_size:lng]] = Tally[cs*block_size:lng,all_labels[cs*block_size:lng]] + 1

##         del subimcor
##         sample_projection = np.zeros((0,n_samples))
##         all_projection = np.zeros((0,data.shape[0]))
##     else:
##         for j in range(0,cs):

##             T = data[block_size*j:block_size*(j+1),:]
##             subimcor = T-mt
##             subimcor = subimcor/stds

##             ### The PCA projection
##             if dim_reduction_alg == 'none':
##                 all_projection[block_size*j:block_size*(j+1),:] = subimcor
##             else:
##                 all_projection[block_size*j:block_size*(j+1),:] = dim_algorithm.transform(subimcor)

##             ######## Classifying and plotting
##             all_labels[block_size*j:block_size*(j+1)] = block_knn_classify(sample_projection,sample_labels,all_projection[block_size*j:block_size*(j+1),:])

##             ######## Tally
##             Tally[block_size*j:block_size*(j+1),all_labels[block_size*j:block_size*(j+1)]] = Tally[block_size*j:block_size*(j+1),all_labels[block_size*j:block_size*(j+1)]] + 1

##             del T, subimcor


##         T = data[cs*block_size:lng,:]
##         subimcor = T-mt
##         subimcor = subimcor/stds

##         ### The PCA projection
##         if dim_reduction_alg == 'none':
##             all_projection[cs*block_size:lng,:] = subimcor
##         else:
##             all_projection[cs*block_size:lng,:] = dim_algorithm.transform(subimcor)

##         ######## Classifying and plotting
##         all_labels[cs*block_size:lng] = block_knn_classify(sample_projection,sample_labels,all_projection[cs*block_size:lng,:])

##         ######## Tally
##         Tally[cs*block_size:lng,all_labels[cs*block_size:lng]] = Tally[cs*block_size:lng,all_labels[cs*block_size:lng]] + 1

##         del T, subimcor

##     if timeit==1: # if you need to time method
##         total_time = time.time()-a
##     else:
##         total_time = np.zeros((1,0))

##     sample_density = sample_density[rind]

##     return all_labels,Tally,all_projection,sample_density,sample_projection,sample_labels,cut_summary,thresh,total_time
