import time
import numpy as np
from scipy.sparse.linalg import svds
from sklearn.neighbors import NearestNeighbors
from ._cut_cluster_classify import iterative_cut_cluster, block_knn_classify


def _column1_subims(n_rows, n_cols, patch_size):
    """Subscripts of every pixel of every patch anchored in the first column.

    A patch is "in the first column" when its top-left corner has column
    subscript 0. There are ``n_rows - patch_size + 1`` such patches, one per
    admissible corner row.

    Parameters
    ----------
    n_rows : int
        Number of rows of the image.
    n_cols : int
        Number of columns of the image (only used as the width of the
        corner grid when the corner indices are unravelled).
    patch_size : int
        Side length of the square patches.

    Returns
    -------
    column1_subs : NumPy array
        Integer array of shape (2, n_patches * patch_size**2). Row 0 holds
        the row subscripts and row 1 the column subscripts. Patches are
        ordered by increasing corner row; inside a patch the pixels are
        listed column by column (the row offset varies fastest).
    """
    n_patches = n_rows - patch_size + 1
    pixels_per_patch = patch_size**2

    # Top-left corners of the patches: (0,0), (1,0), ..., (n_patches-1,0).
    corner_rows, corner_cols = np.unravel_index(
        range(n_patches), (n_patches, n_cols), order='F')
    corners = np.column_stack((corner_rows, corner_cols))

    # Repeat every corner once per pixel of its patch -> (2, total pixels).
    corners = np.repeat(corners, pixels_per_patch, axis=0).T

    # Offsets of the pixels inside a single patch, relative to its corner.
    steps = np.arange(patch_size)
    within_patch = np.vstack((np.tile(steps, patch_size),
                              np.repeat(steps, patch_size)))

    # The same offsets apply to every patch of the column.
    offsets = np.tile(within_patch, (1, n_patches))

    return corners + offsets


def _rand_indices(flat_imgs, n_rows, n_cols, patch_size, n_samples, sample_indices=0,
                  indices_shape=(0,0), stacking='side'):
    """Collect a set of patches and the flat pixel indices they cover.

    Parameters
    ----------
    flat_imgs : NumPy array
        Matrix of shape = (n_rows*n_cols) by n (i.e. number of sheets).
        Each column contains the column-major vectorized version of a
        n_rows by n_cols image sheet (an RGB image has 3 sheets).
    n_rows : int
        Number of rows in the un-vectorized version of the image.
    n_cols : int
        Number of columns in the un-vectorized version of the image.
    patch_size : int
        Dimension of the square patches.
    n_samples : int
        Number of patches to pick when the patches are drawn at random.
        If fewer patches exist, all of them are returned. Ignored when
        `sample_indices` is an array (its length is used instead).
    sample_indices : int or array of ints, optional
        Any integer requests a random draw of patch corners. An array is
        interpreted as column-major linear indices of the top-left patch
        corners inside a grid of shape `indices_shape`.
    indices_shape : 2-tuple, optional
        Shape (rows, columns) of the corner grid that an array-valued
        `sample_indices` refers to. Unused for a random draw, where the
        grid is (n_rows-patch_size+1, n_cols-patch_size+1).
    stacking : str, optional
        String indicating how to do the stacking when there are several
        sheets. If 'side', the sample patches have shape
        (n_picked, n*(patch_size)^2) with the sheets laid side by side.
        If 'up', they have shape (n_picked, (patch_size)^2, n).
        With a single sheet the shape is always (n_picked, (patch_size)^2).

    Returns
    -------
    sample_patches : NumPy array
        Matrix containing the collection of patches. Pixels inside a patch
        are listed column by column.
    sample_inds : NumPy array
        Matrix of size = (n_picked, (patch_size)^2) containing the
        column-major linear pixel indices of the sample patches.

    Raises
    ------
    ValueError
        If `stacking` is neither 'side' nor 'up'.

    Notes
    -----
    A random draw reseeds NumPy's global random generator with seed 1, so
    the draw is reproducible. This is a process-wide side effect.
    """

    # Set up
    n_sheets = flat_imgs.shape[1]
    if stacking not in ('up', 'side'):
        raise ValueError('Stacking must be either "side" or "up".')

    patch_area = patch_size**2

    if isinstance(sample_indices, (int, np.integer)):
        # Counting number of subimages available in each direction
        rs = int(n_rows - patch_size + 1)
        cs = int(n_cols - patch_size + 1)
        # The global generator is seeded on purpose (reproducible samples);
        # code running afterwards may depend on this state, so keep it.
        np.random.seed(1)
        corner_inds = np.random.permutation(rs*cs)[:n_samples]
    else:
        corner_inds = np.asarray(sample_indices).ravel()
        rs, cs = indices_shape

    # Size everything by the number of corners actually available; this may
    # differ from n_samples (short image, or user-supplied index array).
    n_picked = corner_inds.size

    # Top-left corner of every patch - subscript form
    corner_rows, corner_cols = np.unravel_index(corner_inds, (rs, cs), order='F')

    # Offsets of the pixels inside one patch (row offset varies fastest)
    steps = np.arange(patch_size)
    row_offsets = np.tile(steps, patch_size)
    col_offsets = np.repeat(steps, patch_size)

    # Broadcast corners against offsets: one row per patch, one column per pixel
    pixel_rows = corner_rows[:, np.newaxis] + row_offsets[np.newaxis, :]
    pixel_cols = corner_cols[:, np.newaxis] + col_offsets[np.newaxis, :]
    sample_inds = np.ravel_multi_index((pixel_rows, pixel_cols),
                                       (n_rows, n_cols), order='F')
    sample_inds = np.reshape(sample_inds, (n_picked, patch_area))

    # One gather for all sheets: shape (n_picked, patch_area, n_sheets)
    gathered = flat_imgs[sample_inds]

    if n_sheets == 1:
        # Single sheet: drop the sheet axis, keep the image dtype
        sample_patches = gathered[:, :, 0]
    else:
        # Several sheets are returned as floating point patches
        sample_patches = gathered.astype(np.float64)

    if (n_sheets > 1) and (stacking == 'side'):
        # Lay the sheets side by side: sheet 0 pixels, then sheet 1, ...
        sample_patches = np.reshape(sample_patches.transpose(0, 2, 1),
                                    (n_picked, n_sheets*patch_area))

    return sample_patches, sample_inds


def segment(image, n_regions, patch_size, n_pca=4, n_samples=3000,
            n_near_neighbor_density=100, sample_indices=0,
            timeit=True, return_all_info=False):
    """Uses the Cut-Cluster-Classify algorithm to segment an image.

    This function segments an image by manifold learning in the space of
    patches extracted from the image using Cut-Cluster-Classify algorithm.

    For this, it first extracts `n_samples` number of random patches from
    the image (or uses the ones specified by the `sample_indices` array),
    and it clusters them into `n_regions` number of clusters (by forming
    a similarity graph and identifying its components). This requires
    estimating probability density of patches in patch space.
    Once the clusters/components are identified, they are used to classify
    the rest of the patches from the image, and all patches are given region
    labels. The label of a pixel is determined by voting from the patches
    containing the pixel. The pixel labels are returned as a `segmentation'
    array of integer labels, the same size as the original image.

    Parameters
    ----------
    image : NumPy array
        The image to be segmented. It can be a grayscale or RGB image,
        given in the form of a (m,n) or (m,n,3) shaped NumPy array.
    n_regions : int
        The number of regions or phases to segment the image into.
    patch_size : int
        Size of square patches of pixels, e.g patch_size=7 means 7x7 patches.
        Must be at least 1 and no larger than either image dimension.
    n_pca : int, optional
        Number of Principal Components to take into account (from Principal
        Component Analysis (PCA)) in reducing dimensions of patches.
    n_samples : int, optional
        How many sample patches to collect for the clustering stage. Must
        not exceed the number of patches in the image. Ignored (replaced by
        the array length) when `sample_indices` is an array.
    n_near_neighbor_density : int, optional
        The number of nearest neighbors used for sample density estimation.
    sample_indices : integer or array of integers, optional
        Integer number if no good starting sample is known; the samples are
        then drawn at random. Alternatively an array of (n_samples,)
        indices if a good starting sample is known (or desired). The
        indices are column-major linear indices of the top-left patch
        corners in the grid of shape
        (m - patch_size + 1, n - patch_size + 1).
    timeit : bool, optional
        Flag indicating whether or not to time the method.
    return_all_info : bool, optional
        Return additional information about the statistical estimation
        process (see below).

    Returns
    -------
    segmentation : NumPy array
        An array of integer region labels indicating which region the
        corresponding pixel belongs to.
    detailed_info : dict, optional
        If the input parameter `return_all_info` is set to True, then detailed
        information on the intermediate calculations of the algorithm is
        returned in a dictionary containing the following keys and data
        structures:
        'tally', an array of shape (n_regions, m, n) recording, for each
        pixel, the total number of votes cast for each region by the
        patches containing that pixel; this can be used as a measure of how
        confident the algorithm is on the pixel label,
        'patch projection', an array of size (# patches, n_pca),
        containing projections of the patches onto the principal components,
        'patch label', an array containing the labels of the patches,
        'sample density', an array containing the density estimates of the
        sample patches kept by the cut/cluster stage,
        'sample projection', the projected sample patches as returned by
        the cut/cluster stage,
        'sample label', an array containing the labels for all samples,
        'total time', time in seconds for the segmentation to be computed
        if `timeit` is True, otherwise it is 0.0.

    Raises
    ------
    ValueError
        If the image is not 2- or 3-dimensional, if `patch_size` does not
        fit inside the image, or if more random samples are requested than
        there are patches in the image.
    """

    #### Set up
    if image.ndim == 2:
        length, width = image.shape
        depth = 1
        flat_imgs = image.T.flatten()[:,np.newaxis]
    elif image.ndim == 3:
        length, width, depth = image.shape
        flat_imgs = np.reshape( image.T, (depth,length*width) ).T
    else:
        raise ValueError("The image should be 2 or 3-dimensional!")

    # Shape of the grid of admissible top-left patch corners
    nrows = int( length - patch_size + 1 )
    ncols = int( width  - patch_size + 1 )
    patch_area = patch_size**2

    if patch_size < 1 or nrows < 1 or ncols < 1:
        raise ValueError("patch_size must be at least 1 and must not exceed "
                         "the image dimensions!")

    if isinstance(sample_indices, (int, np.integer)):
        # Random sampling; hand a plain Python int to the sampler
        sample_indices = int(sample_indices)
        if n_samples > nrows*ncols:
            raise ValueError("n_samples cannot exceed the number of patches "
                             "in the image!")
    else:
        # User-supplied sample: its length defines the number of samples
        sample_indices = np.asarray(sample_indices).ravel()
        n_samples = sample_indices.size

    if timeit: # if you need to time method
        start_time = time.time()

    #### The first set of subimages (normalized)
    # The corner grid has shape (nrows, ncols); it is what explicit
    # sample_indices refer to.
    T, _ = _rand_indices( flat_imgs, length, width, patch_size, n_samples,
                          sample_indices, (nrows, ncols), 'side' )
    mean_T = np.mean( T, axis=0 )
    sample = T - mean_T
    stds = np.std( sample )
    sample = sample / stds
    del T

    #### Base projections
    U1,S1,V1 = svds( sample, n_pca ) ## Note: S1 comes in ascending order
    U1 = U1 * S1   # scale each left singular vector by its singular value
    V1 = V1.T

    del S1

    #### Calculating the density of sample set
    nbrs = NearestNeighbors( n_neighbors=n_near_neighbor_density,
                             algorithm='ball_tree' ).fit( U1 )
    dis,_ = nbrs.kneighbors(U1)
    # Density estimate: inverse distance to the farthest requested neighbor
    dis = dis[:, n_near_neighbor_density-1 ]
    sample_density = 1 / dis[:,np.newaxis]

    del dis, nbrs

    #### Cutting and Clustering of sample set
    # NOTE(review): n_samples/2 is a float under Python 3; confirm whether
    # iterative_cut_cluster expects an integer count here.
    clustNum1, _,_, sample_projection, rind = \
        iterative_cut_cluster( sample_density, n_samples/2, U1[:,0:n_pca],
                               n_regions, n_regions+1, 30 )
    sample = sample[rind,:]
    sample_density = sample_density[rind]

    #### Consistent Labeling: Sorting cluster label by average pixel mean
    mpixval = np.zeros((n_regions,))
    for i in range(n_regions):
        mpixval[i] = np.mean( sample[ np.where( clustNum1 == i+1 )[0], :] )

    tind = np.argsort( mpixval )
    sample_label = np.zeros( clustNum1.shape )

    # The cluster with the lowest mean keeps label 0 (the initial value)
    for i in range(1,n_regions):
        sample_label[ np.where( clustNum1 == tind[i]+1 ) ] = i

    del sample, mpixval, tind, clustNum1, U1

    #### Block-wise: the rest of the subimages
    tally = np.zeros( (n_regions, length, width) )
    patch_label = np.zeros( (ncols*nrows,1) ).astype(int)
    patch_projection = np.zeros( (ncols*nrows, n_pca) )
    # Pixel subscripts of all patches anchored in image column 0; the
    # patches anchored in column j are obtained by shifting the columns by j.
    SUBS = _column1_subims( length, width, patch_size )
    count = 0

    for j in range(ncols):
        pix_rows = SUBS[0,:]
        pix_cols = SUBS[1,:] + j

        if depth == 1:
            T = np.reshape( image[pix_rows, pix_cols], (nrows, patch_area) )
        else:
            # Lay the sheets side by side, as done for the sample patches
            T = np.zeros( (nrows, depth*patch_area) )
            for d in range(depth):
                T[:, d*patch_area:(d+1)*patch_area ] = \
                    np.reshape( image[ pix_rows, pix_cols, d ],
                                (nrows, patch_area) )

        # Same normalization as the sample set
        subimcor = (T - mean_T) / stds

        # The PCA projection
        patch_projection[ count:count+nrows, :] = subimcor.dot(V1)

        # Classifying the patches of this column block
        patch_label[ count:count+nrows ] = \
            block_knn_classify( sample_projection, sample_label,
                                patch_projection[count:count+nrows,:] )

        # Every pixel of a patch votes for the label of that patch
        lbl = np.tile( patch_label[count:count+nrows],
                       (1,patch_area) ).flatten()

        # Update tally of votes for the pixels. Vertically overlapping
        # patches repeat (label, row, col) triples, so an unbuffered add is
        # required: `tally[idx] = tally[idx] + 1` would count each repeated
        # triple only once.
        np.add.at( tally, (lbl, pix_rows, pix_cols), 1 )

        count = count + nrows
        del T, subimcor, lbl

    segmentation = np.argmax( tally, axis=0 )

    if timeit: # if you need to time method
        total_time = time.time() - start_time
    else:
        total_time = 0.0

    if not return_all_info:
        return segmentation

    detailed_info = {'tally': tally,
                     'total time': total_time,
                     'patch projection': patch_projection,
                     'patch label': patch_label,
                     'sample density': sample_density,
                     'sample projection': sample_projection,
                     'sample label': sample_label }

    return segmentation, detailed_info
