Skip to content
Snippets Groups Projects
Commit 228b55bc authored by Walt Mankowski's avatar Walt Mankowski
Browse files

run multiple replicates with random means

parent 10613373
No related branches found
No related tags found
No related merge requests found
Pipeline #
......@@ -7,6 +7,8 @@ import java.awt.image.*;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Random;
import static org.bytedeco.javacpp.opencv_core.*;
import static org.bytedeco.javacpp.opencv_imgproc.*;
......@@ -17,6 +19,18 @@ import org.bytedeco.javacpp.opencv_ml.*;
public class Demo_Pixel_Replication implements PlugInFilter {
int k;
Random rnd;
// foreground points in ip
ArrayList<Point2f> UniqPts;
// pixel-replicated points in ip
ArrayList<Point2f> PrPts;
public Demo_Pixel_Replication() {
UniqPts = new ArrayList<>();
PrPts = new ArrayList<>();
rnd = new Random();
}
public int setup(String arg, ImagePlus imp) {
if (IJ.versionLessThan("1.37j"))
......@@ -46,11 +60,6 @@ public class Demo_Pixel_Replication implements PlugInFilter {
IplImage ipl = ip2ipl(ip);
Rectangle r = ip.getRoi();
// foreground points in ip
ArrayList<Point2f> UniqPts = new ArrayList<>();
// pixel-replicated points in ip
ArrayList<Point2f> PrPts = new ArrayList<>();
// use distance transform to populate UniqPts and PrPts
CvMat distMat = DistTransform(ipl);
cvReleaseImage(ipl);
......@@ -70,7 +79,7 @@ public class Demo_Pixel_Replication implements PlugInFilter {
}
// compute the gmm
EM em = computeGMM(PrPts, k);
EM em = computeGMM(k);
// create a new image showing the partition
ImageProcessor ipOutPr = ip.duplicate();
......@@ -150,15 +159,42 @@ public class Demo_Pixel_Replication implements PlugInFilter {
return emMat;
}
private EM computeGMM(ArrayList<Point2f> PrPts, int k) {
private EM computeGMM(int k) {
EM bestEm = new EM();
double bestScore = -1e300;
for (int i = 1; i <= 5; i++) {
IJ.log(String.format("replicate %d", i));
EM em = new EM();
Mat logLikelihoods = new Mat(PrPts.size(), 1, CV_64FC1);
Mat labels = new Mat(PrPts.size(), 1, CV_32FC1);
em.set("nclusters", k);
em.set("covMatType", EM.COV_MAT_GENERIC);
em.set("maxIters", 1000);
// em.set("epsilon", 1e-6);
Mat emMat = PrPts2Mat(PrPts);
em.train(emMat);
Mat initMeans = randomInitMeans(k);
// em.trainE(emMat, initMeans, noArray());
em.trainE(emMat, initMeans, new Mat(), new Mat(), logLikelihoods, labels, new Mat());
double score = addUp(logLikelihoods.row(0));
IJ.log(String.format("score = %f", score));
if (score > bestScore) {
bestEm = em;
bestScore = score;
}
}
return em;
return bestEm;
}
private double addUp(Mat m) {
DoubleIndexer idx = m.createIndexer();
double total = 0;
for (int c = 0; c < m.cols(); c++)
total += idx.get(0, c);
return total;
}
private EllipseRoi makeEllipseRoi(Mat center, Mat unitVec, double A, double B) {
......@@ -174,4 +210,26 @@ public class Demo_Pixel_Replication implements PlugInFilter {
return new EllipseRoi(end1Idx.get(0,0), end1Idx.get(0,1), end2Idx.get(0,0), end2Idx.get(0,1), aspectRatio);
}
// Initialize the means by picking 2 distinct points at random from UniqPts.
// We're using a simple naive algorithm since we assume k << n.
private Mat randomInitMeans(int k) {
HashSet<Integer> used = new HashSet<Integer>();
Mat means = new Mat(k, 2, CV_32FC1);
FloatIndexer meansIdx = means.createIndexer();
int i = 0;
while (i < k) {
int j = rnd.nextInt(UniqPts.size());
if (!used.contains(j)) {
Point2f p = UniqPts.get(j);
meansIdx.put(i, 0, p.x());
meansIdx.put(i, 1, p.y());
used.add(j);
IJ.log(String.format("random pt %d: (%.0f,%.0f)", i, p.x(), p.y()));
i++;
}
}
return means;
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment