import ij.*;
import ij.plugin.filter.PlugInFilter;
import ij.process.*;
import ij.gui.*;
import java.awt.*;
import java.awt.image.*;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Random;
import java.util.Collections;

import static org.bytedeco.javacpp.opencv_core.*;
import static org.bytedeco.javacpp.opencv_imgproc.*;
import org.bytedeco.javacpp.indexer.*;
import org.bytedeco.javacpp.indexer.DoubleIndexer;
import org.bytedeco.javacpp.opencv_ml.*;

/**
 * ImageJ plugin demonstrating "pixel replication" for Gaussian-mixture segmentation.
 *
 * <p>Every foreground pixel inside the ROI bounding box is weighted by its rounded
 * distance-transform value: it is stored once in {@link #UniqPts} and that many times in
 * {@link #PrPts}. An OpenCV EM model with {@code k} components is fitted to the replicated
 * points (best of {@link #Replicates} random restarts). Each unique point is then labelled
 * with its most likely component in a new "Segmentation" image, and each component is
 * outlined by a red ellipse derived from its covariance matrix.
 *
 * <p>NOTE(review): OpenCV's distance transform expects a single-channel 8-bit image, but
 * {@link #setup} advertises DOES_ALL, so RGB input will probably fail inside OpenCV (confirm).
 * Because DOES_STACKS is set, ImageJ presumably calls {@link #run} once per slice, which also
 * shows the dialog once per slice.
 */
public class Demo_Pixel_Replication implements PlugInFilter {

    /** Number of mixture components (ellipses), chosen in the dialog. */
    int k;
    /** Source of randomness for picking initial means. */
    Random rnd;
    /** Number of independent EM restarts; the best-scoring one is kept. */
    int Replicates = 5;
    /** Maximum number of EM iterations per restart. */
    int MaxIter = 100;

    // foreground points in ip, one entry per pixel
    ArrayList<Point2f> UniqPts;
    // pixel-replicated points in ip (each foreground pixel repeated by its distance value)
    ArrayList<Point2f> PrPts;

    public Demo_Pixel_Replication() {
        UniqPts = new ArrayList<>();
        PrPts = new ArrayList<>();
        rnd = new Random();
    }

    public int setup(String arg, ImagePlus imp) {
        if (IJ.versionLessThan("1.37j"))
            return DONE;
        else
            return DOES_ALL+DOES_STACKS+SUPPORTS_MASKING;
    }

    /**
     * Asks the user for the number of ellipses, replicates and maximum iterations.
     *
     * @return false if the dialog was cancelled, true otherwise
     */
    public boolean showDialog() {
        String[] kChoices = {"2", "3", "4", "5", "6"};

        GenericDialog gd = new GenericDialog("Number of ellipses");
        gd.addChoice("Number of ellipses:", kChoices, kChoices[2]);
        gd.addNumericField("Number of replicates", Replicates, 0);
        gd.addNumericField("Maximum number of iterations", MaxIter, 0);

        gd.showDialog();
        if (gd.wasCanceled()) {
            return false;
        }

        // Choice index 0 corresponds to "2" ellipses.
        k = gd.getNextChoiceIndex() + 2;
        // An unparsable or non-positive entry would leave computeGMM with nothing to train
        // (a NaN casts to 0), so require at least one replicate and one iteration.
        Replicates = Math.max(1, (int) gd.getNextNumber());
        MaxIter = Math.max(1, (int) gd.getNextNumber());
        return true;
    }

    public void run(ImageProcessor ip) {
        if (!showDialog())
            return;

        // Start from empty lists: otherwise points from a previous call (e.g. the previous
        // slice of a stack) would be mixed into this fit and painted into this output.
        UniqPts.clear();
        PrPts.clear();
        collectPoints(ip);

        IJ.log(String.format("%d unique points, %d pixel rep points", UniqPts.size(), PrPts.size()));

        // randomInitMeans needs k distinct foreground points.
        if (UniqPts.size() < k) {
            IJ.log(String.format("need at least %d foreground points, found %d", k, UniqPts.size()));
            return;
        }

        EM em = computeGMM(k);
        if (em == null) {
            IJ.log("no replicate was trained successfully");
            return;
        }

        ImagePlus segIP = makeSegmentation(ip, em);
        segIP.show();
        segIP.setOverlay(makeEllipses(em));
    }

    /**
     * Fills UniqPts and PrPts from the distance transform of ip, scanning the ROI bounding box.
     * A pixel whose rounded distance value n is positive is added once to UniqPts and n times
     * to PrPts.
     */
    private void collectPoints(ImageProcessor ip) {
        IplImage ipl = ip2ipl(ip);
        Rectangle r = ip.getRoi();

        IplImage distImg;
        try {
            distImg = DistTransform(ipl);
        } finally {
            cvReleaseImage(ipl);
        }

        try {
            CvMat distMat = distImg.asCvMat();
            FloatIndexer distMatIdx = distMat.createIndexer();
            for (int y = r.y; y < r.y + r.height; y++) {
                for (int x = r.x; x < r.x + r.width; x++) {
                    // The rounded distance value doubles as the replication count.
                    int n = Math.round(distMatIdx.get(y, x));
                    if (n > 0) {
                        Point2f p = new Point2f(x, y);
                        UniqPts.add(p);
                        for (int i = 0; i < n; i++) {
                            PrPts.add(p);
                        }
                    }
                }
            }
        } finally {
            // distMat may share distImg's pixel data, so release only after the scan.
            cvReleaseImage(distImg);
        }
    }

    /**
     * Builds the "Segmentation" image: a copy of ip in which every unique foreground point is
     * replaced by a grey value identifying its most likely mixture component.
     */
    private ImagePlus makeSegmentation(ImageProcessor ip, EM em) {
        ImageProcessor ipOutPr = ip.duplicate();
        Mat PrMat = new Mat(1, 2, CV_32FC1);
        FloatIndexer PrMatIndexer = PrMat.createIndexer();
        int clrOffset = 256 / k;
        for (Point2f p : UniqPts) {
            PrMatIndexer.put(0, 0, p.x());
            PrMatIndexer.put(0, 1, p.y());
            // OpenCV's EM::predict packs (log-likelihood, component index); only the index is used.
            Point2d pr = em.predict(PrMat);
            int val = ((int) pr.y() + 1) * clrOffset - 1;
            ipOutPr.set((int) p.x(), (int) p.y(), val);
        }
        return new ImagePlus("Segmentation", ipOutPr);
    }

    /**
     * Builds an overlay with one red ellipse per mixture component. Each ellipse is centred on
     * the component mean and aligned with the eigenvector of the largest covariance eigenvalue.
     * Its semi-axes are sqrt(eigenvalue * 20 / 3).
     */
    private Overlay makeEllipses(EM em) {
        Mat means = em.getMat("means");
        MatVector covs = em.getMatVector("covs");
        Overlay ellipses = new Overlay();
        for (int i = 0; i < k; i++) {
            // compute eigenvalues and eigenvectors of the covariance matrix
            Mat eigVal = new Mat();
            Mat eigVec = new Mat();
            eigen(covs.get(i), eigVal, eigVec);
            DoubleIndexer valIdx = eigVal.createIndexer();
            int hi = valIdx.get(0, 0) > valIdx.get(1, 0) ? 0 : 1;
            int lo = 1 - hi;

            // Clamp tiny negative eigenvalues (round-off) so sqrt does not produce NaN.
            double A = Math.sqrt(Math.max(0.0, valIdx.get(hi, 0)) * 20 / 3);
            double B = Math.sqrt(Math.max(0.0, valIdx.get(lo, 0)) * 20 / 3);

            EllipseRoi elRoi = makeEllipseRoi(means.row(i), eigVec.row(hi), A, B);
            elRoi.setStrokeWidth(2);
            elRoi.setStrokeColor(Color.red);
            ellipses.add(elRoi);
        }
        return ellipses;
    }

    private IplImage ip2ipl(ImageProcessor src) {
        BufferedImage bi = src.getBufferedImage();
        return IplImage.createFrom(bi);
    }

    /**
     * Returns a newly allocated single-channel 32-bit float image holding the distance
     * transform of iplIn. The caller must release it with cvReleaseImage.
     */
    private IplImage DistTransform(IplImage iplIn) {
        IplImage iplOut = cvCreateImage(cvGetSize(iplIn), IPL_DEPTH_32F, 1);
        try {
            cvDistTransform(iplIn, iplOut);
        } catch (RuntimeException e) {
            cvReleaseImage(iplOut);
            throw e;
        }
        return iplOut;
    }

    /**
     * Converts a random subset of the given points to an n x 2 CV_32FC1 matrix.
     *
     * @param pct fraction of the points to keep; expected to lie in [0, 1]
     */
    private Mat PrPts2Mat(ArrayList<Point2f> PrPts, double pct) {
        ArrayList<Point2f> ShufPts = new ArrayList<Point2f>(PrPts);
        Collections.shuffle(ShufPts);
        return PrPts2Mat(new ArrayList<Point2f>(ShufPts.subList(0, (int) (ShufPts.size() * pct))));
    }

    /** Converts the points to an n x 2 CV_32FC1 matrix with one (x, y) row per point. */
    private Mat PrPts2Mat(ArrayList<Point2f> PrPts) {
        Mat emMat = new Mat(PrPts.size(), 2, CV_32FC1);
        FloatIndexer emMatIndexer = emMat.createIndexer();
        for (int i = 0; i < PrPts.size(); i++) {
            Point2f p = PrPts.get(i);
            emMatIndexer.put(i, 0, p.x());
            emMatIndexer.put(i, 1, p.y());
        }

        return emMat;
    }

    /**
     * Fits a k-component Gaussian mixture to PrPts, restarting Replicates times from random
     * initial means and keeping the model with the largest summed sample log-likelihood.
     *
     * @return the best trained model, or null if no replicate trained successfully
     */
    private EM computeGMM(int k) {
        EM bestEm = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        int bestI = 0;

        // The samples and the (empty) initial covariances/weights are the same for every
        // replicate; only the random initial means differ.
        Mat emMat = PrPts2Mat(PrPts);
        Mat covs0 = makeInitCovs(k);
        Mat weights0 = makeInitWeights(k);

        for (int i = 1; i <= Replicates; i++) {
            IJ.log(String.format("replicate %d", i));
            EM em = new EM();
            Mat logLikelihoods = new Mat(PrPts.size(), 1, CV_64FC1);
            Mat labels = new Mat(PrPts.size(), 1, CV_32FC1);
            em.set("nclusters", k);
            em.set("covMatType", EM.COV_MAT_GENERIC);
            em.set("maxIters", MaxIter);
            em.set("epsilon", 1e-6);

            Mat initMeans = randomInitMeans(k);
            IJ.log("calling trainE");
            if (!em.trainE(emMat, initMeans, covs0, weights0, logLikelihoods, labels, new Mat())) {
                // logLikelihoods may not have been filled, so this replicate cannot be scored.
                IJ.log("trainE() returned false!");
                continue;
            }

            double score = addUp(logLikelihoods.col(0));
            IJ.log(String.format("score = %f", score));
            if (bestEm == null || score > bestScore) {
                bestEm = em;
                bestScore = score;
                bestI = i;
            }
        }

        IJ.log(String.format("best replicate was %d", bestI));
        return bestEm;
    }

    /** Sums column 0 of a CV_64F matrix. */
    private double addUp(Mat m) {
        DoubleIndexer idx = m.createIndexer();
        double total = 0;
        for (int r = 0; r < m.rows(); r++) {
            total += idx.get(r, 0);
        }

        return total;
    }

    /**
     * Returns the initial covariances passed to trainE. It is currently an empty Mat, so no
     * explicit initial covariances are supplied; OpenCV chooses its own initialisation
     * (TODO confirm exact behaviour for the bound OpenCV version).
     */
    private Mat makeInitCovs(int k) {
        return new Mat();
    }

    /**
     * Returns the initial mixture weights passed to trainE. It is currently an empty Mat, so
     * no explicit initial weights are supplied.
     */
    private Mat makeInitWeights(int k) {
        return new Mat();
    }

    /**
     * Builds an EllipseRoi whose major axis runs from center - A*unitVec to center + A*unitVec,
     * with aspect ratio B / A.
     */
    private EllipseRoi makeEllipseRoi(Mat center, Mat unitVec, double A, double B) {
        Mat end1 = new Mat();
        Mat end2 = new Mat();

        scaleAdd(unitVec, A, center, end1);
        scaleAdd(unitVec, -A, center, end2);
        double aspectRatio = B / A;

        DoubleIndexer end1Idx = end1.createIndexer();
        DoubleIndexer end2Idx = end2.createIndexer();

        return new EllipseRoi(end1Idx.get(0,0), end1Idx.get(0,1), end2Idx.get(0,0), end2Idx.get(0,1), aspectRatio);
    }

    /**
     * Initializes the means by picking k distinct points at random from UniqPts. Uses simple
     * rejection sampling, which is fine when k is much smaller than the number of points.
     *
     * @return a k x 2 CV_64FC1 matrix with one (x, y) mean per row
     * @throws IllegalArgumentException if UniqPts holds fewer than k points (the sampling loop
     *         would otherwise never terminate)
     */
    private Mat randomInitMeans(int k) {
        if (k > UniqPts.size()) {
            throw new IllegalArgumentException(String.format(
                    "cannot pick %d distinct means from %d points", k, UniqPts.size()));
        }

        HashSet<Integer> used = new HashSet<Integer>();
        Mat means = new Mat(k, 2, CV_64FC1);
        DoubleIndexer meansIdx = means.createIndexer();

        int i = 0;
        while (i < k) {
            int j = rnd.nextInt(UniqPts.size());
            if (used.add(j)) {
                Point2f p = UniqPts.get(j);
                meansIdx.put(i, 0, p.x());
                meansIdx.put(i, 1, p.y());
                IJ.log(String.format("random pt %d: (%.0f,%.0f)", i, p.x(), p.y()));
                i++;
            }
        }

        for (int r = 0; r < means.rows(); r++) {
            for (int c = 0; c < means.cols(); c++) {
                IJ.log(String.format("means(%d,%d) = %f", r, c, meansIdx.get(r, c)));
            }
        }

        return means;
    }
}
