diff --git a/src/ImageJ/Demo_Pixel_Replication.class b/src/ImageJ/Demo_Pixel_Replication.class
index aae38ecec2287eee0c651123e85a142620694308..4fb78d61790c294acb0c5eec90bc6d1b5026c554 100644
Binary files a/src/ImageJ/Demo_Pixel_Replication.class and b/src/ImageJ/Demo_Pixel_Replication.class differ
diff --git a/src/ImageJ/Demo_Pixel_Replication.java b/src/ImageJ/Demo_Pixel_Replication.java
index 85f5fdb2895ff1605339231fb9eba5458af83733..345dd02aadc4390a181ae2aaf0bbac12da4a316d 100755
--- a/src/ImageJ/Demo_Pixel_Replication.java
+++ b/src/ImageJ/Demo_Pixel_Replication.java
@@ -21,6 +21,8 @@ public class Demo_Pixel_Replication implements PlugInFilter {
 
     int k;
     Random rnd;
+    int Replicates = 5;
+    int MaxIter = 100;
 
     // foreground points in ip
     ArrayList<Point2f> UniqPts;
@@ -44,12 +46,17 @@ public class Demo_Pixel_Replication implements PlugInFilter {
         String[] kChoices = {"2", "3", "4", "5", "6"};
 
         GenericDialog gd = new GenericDialog("Number of ellipses");
-        gd.addChoice("Number of ellipses:", kChoices, kChoices[0]);
+        gd.addChoice("Number of ellipses:", kChoices, kChoices[2]);
+        gd.addNumericField("Number of replicates", Replicates, 0);
+        gd.addNumericField("Maximum number of iterations", MaxIter, 0);
+
         gd.showDialog();
         if (gd.wasCanceled()) {
             return false;
         } else {
             k = gd.getNextChoiceIndex() + 2;
+            Replicates = (int) gd.getNextNumber();
+            MaxIter = (int) gd.getNextNumber();
             return true;
         }
     }
@@ -72,8 +79,8 @@ public class Demo_Pixel_Replication implements PlugInFilter {
                 if (n > 0) {
                     Point2f p = new Point2f(x, y);
                     UniqPts.add(p);
-                    // long rep = Math.round((double)n/10.0);
-                    long rep = 1;
+                    long rep = Math.round((double)n/1.0);
+                    // long rep = 1;
                     if (rep == 0) rep = 1;
                     for (int i = 0; i < rep; i++) {
                         PrPts.add(p);
@@ -176,42 +183,83 @@ public class Demo_Pixel_Replication implements PlugInFilter {
     private EM computeGMM(int k) {
         EM bestEm = new EM();
         double bestScore = -1e300;
+        int bestI = 0;
+        // Mat initMeans = randomInitMeans(k);
 
-        for (int i = 1; i <= 10; i++) {
+        for (int i = 1; i <= Replicates; i++) {
             IJ.log(String.format("replicate %d", i));
             EM em = new EM();
             Mat logLikelihoods = new Mat(PrPts.size(), 1, CV_64FC1);
             Mat labels = new Mat(PrPts.size(), 1, CV_32FC1);
+            Mat covs0 = makeInitCovs(k);
+            Mat weights0 = makeInitWeights(k);
             em.set("nclusters", k);
             em.set("covMatType", EM.COV_MAT_GENERIC);
-            em.set("maxIters", 1000);
+            em.set("maxIters", MaxIter);
             em.set("epsilon", 1e-6);
 
             // Mat emMat = PrPts2Mat(PrPts, 0.10);
             Mat emMat = PrPts2Mat(PrPts);
             Mat initMeans = randomInitMeans(k);
             // em.trainE(emMat, initMeans, noArray());
-            em.trainE(emMat, initMeans, new Mat(), new Mat(), logLikelihoods, labels, new Mat());
-            double score = addUp(logLikelihoods.row(0));
+            // if (!em.trainE(emMat, initMeans, new Mat(), new Mat(), logLikelihoods, labels, new Mat()))
+            // try {
+                IJ.log("calling trainE");
+                if (!em.trainE(emMat, initMeans, covs0, weights0, logLikelihoods, labels, new Mat()))
+                // if (!em.train(emMat, logLikelihoods, new Mat(), new Mat()))
+                    IJ.log("trainE() returned false!");
+            // } catch (Exception e) {
+            //     IJ.log(String.format("exception calling trainE: %s", e.getMessage()));
+            // }
+            double score = addUp(logLikelihoods.col(0));
             IJ.log(String.format("score = %f", score));
             if (score > bestScore) {
                 bestEm = em;
                 bestScore = score;
+                bestI = i;
             }
         }
 
+        IJ.log(String.format("best replicate was %d", bestI));
         return bestEm;
     }
 
     private double addUp(Mat m) {
         DoubleIndexer idx = m.createIndexer();
         double total = 0;
-        for (int c = 0; c < m.cols(); c++)
-            total += idx.get(0, c);
+        for (int r = 0; r < m.rows(); r++) {
+            total += idx.get(r, 0);
+        }
 
         return total;
     }
 
+    private Mat makeInitCovs(int k) {
+        // Mat covs0 = new Mat(1, 4*k, CV_64FC1); 
+        // DoubleIndexer idx = covs0.createIndexer();
+        // for (int i = 0; i < 4*k; i += 4) {
+        //     idx.put(0, i,   1);
+        //     idx.put(0, i+1, 0);
+        //     idx.put(0, i+2, 0);
+        //     idx.put(0, i+3, 1);
+        // }
+
+        // return covs0;
+        return new Mat();
+    }
+
+    private Mat makeInitWeights(int k) {
+        // Mat weights0 = new Mat(k, 1, CV_64FC1);
+        // DoubleIndexer idx = weights0.createIndexer();
+        // double weight = 1/k;
+
+        // for (int r = 0; r < k; r++)
+        //     idx.put(r, 0, weight);
+
+        // return weights0;
+        return new Mat();
+    }
+
     private EllipseRoi makeEllipseRoi(Mat center, Mat unitVec, double A, double B) {
         Mat end1 = new Mat();
         Mat end2 = new Mat();
@@ -230,8 +278,8 @@ public class Demo_Pixel_Replication implements PlugInFilter {
     // We're using a simple naive algorithm since we assume k << n.
     private Mat randomInitMeans(int k) {
         HashSet<Integer> used = new HashSet<Integer>();
-        Mat means = new Mat(k, 2, CV_32FC1);
-        FloatIndexer meansIdx = means.createIndexer();
+        Mat means = new Mat(k, 2, CV_64FC1);
+        DoubleIndexer meansIdx = means.createIndexer();
 
         int i = 0;
         while (i < k) {
@@ -245,6 +293,13 @@ public class Demo_Pixel_Replication implements PlugInFilter {
                 i++;
             }
         }
+
+        for (int r = 0; r < means.rows(); r++) {
+            for (int c = 0; c < means.cols(); c++) {
+                IJ.log(String.format("means(%d,%d) = %f", r, c, meansIdx.get(r, c)));
+            }
+        }
+
         return means;
     }
 }