Skip to content
Snippets Groups Projects
Commit 513d8c20 authored by Walt Mankowski's avatar Walt Mankowski
Browse files

run gmm REPLICATES times and pick the best scoring result

Note that this makes the code run significantly longer
parent 84ea4d19
No related branches found
No related tags found
No related merge requests found
No preview for this file type
...@@ -7,6 +7,8 @@ import java.awt.image.*; ...@@ -7,6 +7,8 @@ import java.awt.image.*;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.FloatBuffer; import java.nio.FloatBuffer;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashSet;
import java.util.Random;
import static org.bytedeco.javacpp.opencv_core.*; import static org.bytedeco.javacpp.opencv_core.*;
import static org.bytedeco.javacpp.opencv_imgproc.*; import static org.bytedeco.javacpp.opencv_imgproc.*;
...@@ -17,6 +19,18 @@ import org.bytedeco.javacpp.opencv_ml.*; ...@@ -17,6 +19,18 @@ import org.bytedeco.javacpp.opencv_ml.*;
public class Demo_Pixel_Replication implements PlugInFilter { public class Demo_Pixel_Replication implements PlugInFilter {
int k; int k;
Random rnd;
// foreground points in ip
ArrayList<Point2f> UniqPts;
// pixel-replicated points in ip
ArrayList<Point2f> PrPts;
public Demo_Pixel_Replication() {
UniqPts = new ArrayList<>();
PrPts = new ArrayList<>();
rnd = new Random();
}
public int setup(String arg, ImagePlus imp) { public int setup(String arg, ImagePlus imp) {
if (IJ.versionLessThan("1.37j")) if (IJ.versionLessThan("1.37j"))
...@@ -46,11 +60,6 @@ public class Demo_Pixel_Replication implements PlugInFilter { ...@@ -46,11 +60,6 @@ public class Demo_Pixel_Replication implements PlugInFilter {
IplImage ipl = ip2ipl(ip); IplImage ipl = ip2ipl(ip);
Rectangle r = ip.getRoi(); Rectangle r = ip.getRoi();
// foreground points in ip
ArrayList<Point2f> UniqPts = new ArrayList<>();
// pixel-replicated points in ip
ArrayList<Point2f> PrPts = new ArrayList<>();
// use distance transform to populate UniqPts and PrPts // use distance transform to populate UniqPts and PrPts
CvMat distMat = DistTransform(ipl); CvMat distMat = DistTransform(ipl);
cvReleaseImage(ipl); cvReleaseImage(ipl);
...@@ -70,7 +79,7 @@ public class Demo_Pixel_Replication implements PlugInFilter { ...@@ -70,7 +79,7 @@ public class Demo_Pixel_Replication implements PlugInFilter {
} }
// compute the gmm // compute the gmm
EM em = computeGMM(PrPts, k); EM em = computeGMM(k);
// create a new image showing the partition // create a new image showing the partition
ImageProcessor ipOutPr = ip.duplicate(); ImageProcessor ipOutPr = ip.duplicate();
...@@ -150,15 +159,42 @@ public class Demo_Pixel_Replication implements PlugInFilter { ...@@ -150,15 +159,42 @@ public class Demo_Pixel_Replication implements PlugInFilter {
return emMat; return emMat;
} }
private EM computeGMM(ArrayList<Point2f> PrPts, int k) { private EM computeGMM(int k) {
EM bestEm = new EM();
double bestScore = -1e300;
for (int i = 1; i <= 10; i++) {
IJ.log(String.format("replicate %d", i));
EM em = new EM(); EM em = new EM();
Mat logLikelihoods = new Mat(PrPts.size(), 1, CV_64FC1);
Mat labels = new Mat(PrPts.size(), 1, CV_32FC1);
em.set("nclusters", k); em.set("nclusters", k);
em.set("covMatType", EM.COV_MAT_GENERIC); em.set("covMatType", EM.COV_MAT_GENERIC);
em.set("maxIters", 1000);
// em.set("epsilon", 1e-6);
Mat emMat = PrPts2Mat(PrPts); Mat emMat = PrPts2Mat(PrPts);
em.train(emMat); Mat initMeans = randomInitMeans(k);
// em.trainE(emMat, initMeans, noArray());
em.trainE(emMat, initMeans, new Mat(), new Mat(), logLikelihoods, labels, new Mat());
double score = addUp(logLikelihoods.row(0));
IJ.log(String.format("score = %f", score));
if (score > bestScore) {
bestEm = em;
bestScore = score;
}
}
return em; return bestEm;
}
private double addUp(Mat m) {
DoubleIndexer idx = m.createIndexer();
double total = 0;
for (int c = 0; c < m.cols(); c++)
total += idx.get(0, c);
return total;
} }
private EllipseRoi makeEllipseRoi(Mat center, Mat unitVec, double A, double B) { private EllipseRoi makeEllipseRoi(Mat center, Mat unitVec, double A, double B) {
...@@ -174,4 +210,26 @@ public class Demo_Pixel_Replication implements PlugInFilter { ...@@ -174,4 +210,26 @@ public class Demo_Pixel_Replication implements PlugInFilter {
return new EllipseRoi(end1Idx.get(0,0), end1Idx.get(0,1), end2Idx.get(0,0), end2Idx.get(0,1), aspectRatio); return new EllipseRoi(end1Idx.get(0,0), end1Idx.get(0,1), end2Idx.get(0,0), end2Idx.get(0,1), aspectRatio);
} }
// Initialize the means by picking 2 distinct points at random from UniqPts.
// We're using a simple naive algorithm since we assume k << n.
private Mat randomInitMeans(int k) {
HashSet<Integer> used = new HashSet<Integer>();
Mat means = new Mat(k, 2, CV_32FC1);
FloatIndexer meansIdx = means.createIndexer();
int i = 0;
while (i < k) {
int j = rnd.nextInt(UniqPts.size());
if (!used.contains(j)) {
Point2f p = UniqPts.get(j);
meansIdx.put(i, 0, p.x());
meansIdx.put(i, 1, p.y());
used.add(j);
IJ.log(String.format("random pt %d: (%.0f,%.0f)", i, p.x(), p.y()));
i++;
}
}
return means;
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment