# PIXEL_REP_LIB houses all PixelReplication functions including GMM fitting
# and cluster drawing

# -------------------------------------------------------------------------------
# Copyright (c) 2016, Drexel University
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of PixelRep nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# -------------------------------------------------------------------------------


# PixelReplicate() accepts [a] a logical (binary) image containing only one connected component.
# It implements the pixel replication algorithm: each foreground pixel is replicated as many
# times as its rounded depth in the Euclidean distance transform of the image.
def PixelReplicate(bwim):
    """Replicate each foreground pixel according to its distance-transform depth.

    Every nonzero pixel of ``bwim`` is emitted as a (row, col) point, repeated
    ``round(d)`` times, where ``d`` is that pixel's value in the Euclidean
    distance transform of ``bwim``. Pixels deep inside the shape therefore
    contribute more points than pixels near its boundary.

    Parameters
    ----------
    bwim : numpy.ndarray
        Binary image, expected to contain a single connected component
        (assumed 8-bit single-channel, as cv2.distanceTransform expects).
        Any nonzero pixel is treated as foreground.

    Returns
    -------
    numpy.ndarray, shape (M, 2), dtype float64
        Replicated (row, col) points, grouped by source pixel in row-major
        order. M is the sum of the rounded depths of all foreground pixels.
    """
    import numpy as np
    import cv2

    # Compute Euclidean Distance Transform
    print("Computing Distance Transform...")
    bwd = cv2.distanceTransform(bwim, distanceType=cv2.DIST_L2, maskSize=cv2.DIST_MASK_5)
    bwd = np.asarray(bwd, dtype=float)

    # Foreground coordinates. distanceTransform treats every nonzero pixel as
    # foreground, so the same definition is used here. Selecting only `== 1`
    # would leave a 0/255 image with replication counts but no matching points.
    rows, cols = np.nonzero(bwim)

    # POINT REPLICATION STEP: repeat each (row, col) pair round(depth) times,
    # keeping all copies of a pixel contiguous.
    print("Replicating Pixels...")
    nRep = np.round(bwd[rows, cols]).astype(int)
    pts = np.column_stack((rows, cols))
    ptsRep = np.repeat(pts, nRep, axis=0).astype(float)
    return ptsRep # returns replicated points (array)


# fitGMM() accepts [a] the replicated points (array) and [b] k, the number of clusters (int).
# It uses the Expectation-Maximization (EM) algorithm to fit a Gaussian mixture model (GMM).
# The fit is initialized 10 times, and each run is configured for at most 1000 iterations with
# a convergence tolerance of 1e-6. The covariance matrix type is "GENERIC" (a full symmetric,
# positive-definite matrix).
def fitGMM(ptsRep, k, nInit=10, maxIter=1000, tol=1e-6):
    """Fit a k-component Gaussian mixture model to replicated points.

    OpenCV's EM trainer is run ``nInit`` times. Each run starts from ``k``
    initial means drawn at random from the unique (non-replicated) points.
    The run with the highest total log-likelihood is returned.

    Parameters
    ----------
    ptsRep : array-like, shape (N, 2)
        Replicated points, e.g. the output of PixelReplicate().
    k : int
        Number of mixture components (clusters).
    nInit : int, optional
        Number of random initializations (default 10).
    maxIter : int, optional
        Maximum number of EM iterations per run (default 1000).
    tol : float, optional
        Convergence tolerance passed as the TermCriteria epsilon (default 1e-6).

    Returns
    -------
    cv2.ml.EM or list
        The best fitted model. An empty list is returned if no run succeeded
        with a log-likelihood greater than -inf.

    Raises
    ------
    ValueError
        If ``ptsRep`` contains no points.
    """
    import numpy as np
    import cv2
    print("Fitting GMM...")
    bestGMM = []
    bestLL = float("-inf") # initialize lowest possible log likelihood value (-inf)

    ptsRep = np.asarray(ptsRep, dtype=float)
    if ptsRep.size == 0:
        raise ValueError("fitGMM: ptsRep contains no points")

    # Get unique points from replicated points, in order of first occurrence.
    # np.unique sorts its rows, so sort the first-occurrence indices back.
    _, firstIdx = np.unique(ptsRep, axis=0, return_index=True)
    uniqPts = ptsRep[np.sort(firstIdx)]
    nUniq = uniqPts.shape[0]

    # Draw k distinct initial means, since duplicate seeds start EM from
    # coincident components. Fall back to sampling with replacement only when
    # there are fewer unique points than clusters.
    sampleWithReplacement = nUniq < k

    # The TermCriteria type is a bit mask. It must include both MAX_ITER and
    # EPS, or maxIter and tol are ignored and EM falls back to its defaults.
    termCrit = (cv2.TERM_CRITERIA_MAX_ITER + cv2.TERM_CRITERIA_EPS, maxIter, tol)

    # GAUSSIAN MIXTURE MODEL FITTING
    for i in range(nInit): # Initialization Loop
        rndIdx = np.random.choice(nUniq, size=k, replace=sampleWithReplacement)
        rndMeans = uniqPts[rndIdx]

        # Initialize GMM model with full ("generic") covariance matrices
        gmm0 = cv2.ml.EM_create()
        gmm0.setCovarianceMatrixType(cv2.ml.EM_COV_MAT_GENERIC)
        gmm0.setClustersNumber(k)
        gmm0.setTermCriteria(termCrit)

        # Train the GMM, starting from the Expectation step with the seed means
        trained, logLikelihoods, _, _ = gmm0.trainE(samples=ptsRep, means0=rndMeans)
        if not trained:
            continue

        # Total log likelihood of the fit (sum over all samples)
        LL = float(np.sum(logLikelihoods))
        print('Likelihood Logarithm of GMM fit (Rep. ', i, ') = ', LL)

        # Keep the fit with the highest (closest to zero) log likelihood
        if LL > bestLL:
            bestGMM = gmm0
            bestLL = LL
    return bestGMM # return best gmm model (openCV EM class)


# drawClusters() accepts [a] a GMM model, [b] the binary image, and [c] the true edge points of
# each ellipse. It extracts the ellipse points, clusters them with the GMM, and plots the
# clusters with the true boundaries overlaid. It returns a bwLabel matrix of the same size as
# the input image, in which each pixel is labeled according to the cluster it belongs to.
def drawClusters(gmm0, bwim, ptsEdge):
    """Label foreground pixels by GMM cluster and plot them with the true edges.

    Parameters
    ----------
    gmm0 : cv2.ml.EM
        Trained mixture model (e.g. from fitGMM()); queried one pixel at a
        time with predict2().
    bwim : numpy.ndarray
        Binary image; every nonzero pixel is treated as an ellipse point.
    ptsEdge : array-like, shape (P, 2)
        (row, col) coordinates of the true ellipse boundaries. Nothing is
        overlaid when it is empty or all zeros.

    Returns
    -------
    numpy.ndarray, dtype uint8
        bwLabel with the same height/width as ``bwim``. Background pixels are
        0; each foreground pixel holds its cluster index + 1.
    """
    import copy
    import numpy as np
    import matplotlib.pyplot as plt

    # Get ellipse points as (row, col) pairs
    r, c = np.nonzero(bwim)
    pts = np.column_stack((r, c))

    # Use the EM::predict2() method to cluster ellipse points. predict2()
    # takes a single sample; retval[1] is the most probable cluster index.
    print("Clustering...")
    clusters = np.empty(pts.shape[0], dtype=int)
    for i, pt in enumerate(pts):
        retval, _ = gmm0.predict2(sample=pt.reshape(1, 2).astype(np.float64))
        clusters[i] = int(retval[1])
    print("Clusters Found: ", np.unique(clusters)+1)

    # Create bwLabel matrix: 0 = background, cluster index + 1 elsewhere.
    # Direct assignment also handles an image with no foreground pixels.
    print("Plotting...")
    bwLabel = np.zeros(bwim.shape[:2], dtype=int)
    bwLabel[pts[:, 0], pts[:, 1]] = clusters + 1
    bwLabel = np.uint8(bwLabel)

    # Plot Result. Copy the colormap so set_under() does not modify
    # matplotlib's shared hsv instance. pyplot.hold() is not used: it was
    # removed in Matplotlib 3.0, and overlaying on the current axes is already
    # the default behavior.
    plt.figure()
    cmap = copy.copy(plt.cm.hsv)
    cmap.set_under(color="black") # values below vmin (the 0 background) are drawn black
    plt.imshow(bwLabel, cmap=cmap, vmin=0.00000001)

    # Plot boundaries if any edge coordinate is nonzero
    ptsEdge = np.asarray(ptsEdge)
    if np.any(ptsEdge):
        plt.scatter(ptsEdge[:, 1], ptsEdge[:, 0], s=45, c='w', lw=0.5, alpha=1, edgecolors='w')

    # Configure figure display parameters
    plt.axis('off')
    plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)
    plt.margins(0)
    plt.autoscale(enable=True, axis='both', tight=True)
    plt.show()
    return bwLabel # returns bwLabel matrix (uint8)