From c9f4f6dc682af1fbf7a2d6978927741a42b2d73b Mon Sep 17 00:00:00 2001
From: actb <andrew.r.cohen@drexel.edu>
Date: Fri, 24 Apr 2020 11:25:50 -0400
Subject: [PATCH] fix NL means double dipping output buffer

---
 src/c/Cuda/CudaDeviceImages.cuh |  8 ++++++++
 src/c/Cuda/CudaNLMeans.cuh      | 21 ++++++++++-----------
 2 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/src/c/Cuda/CudaDeviceImages.cuh b/src/c/Cuda/CudaDeviceImages.cuh
index 88f5ae6e..062ca5b3 100644
--- a/src/c/Cuda/CudaDeviceImages.cuh
+++ b/src/c/Cuda/CudaDeviceImages.cuh
@@ -72,6 +72,14 @@ public:
 		return deviceImages[trd];
 	}
 
+	CudaImageContainer<PixelType>* getBuffer(int idx)
+	{
+		if (numBuffers <= idx)
+			return NULL;
+
+		return deviceImages[idx];
+	}
+
 	void incrementBuffer()
 	{
 		DEBUG_KERNEL_CHECK();
diff --git a/src/c/Cuda/CudaNLMeans.cuh b/src/c/Cuda/CudaNLMeans.cuh
index cd3e124c..96d1de1d 100644
--- a/src/c/Cuda/CudaNLMeans.cuh
+++ b/src/c/Cuda/CudaNLMeans.cuh
@@ -18,8 +18,8 @@
 // this is an approximate nl means. uses fisher discriminant as a distance, with mean, variance of patches computed by previous cuda filter
 template <class PixelTypeIn, class PixelTypeOut>
 // params needed: a,h, search window size, comparison nhood
-__global__ void cudaNLMeans_mv(CudaImageContainer<PixelTypeIn> imageIn, CudaImageContainer<PixelTypeOut> imageOutMean, CudaImageContainer<PixelTypeOut> imageVariance,
-	double h, int searchWindowRadius, int nhoodRadius, PixelTypeOut minValue, PixelTypeOut maxValue)
+__global__ void cudaNLMeans_mv(CudaImageContainer<PixelTypeIn> imageIn, CudaImageContainer<PixelTypeOut> imageMean, CudaImageContainer<PixelTypeOut> imageVariance,
+	CudaImageContainer<PixelTypeOut> imageOut, double h, int searchWindowRadius, int nhoodRadius, PixelTypeOut minValue, PixelTypeOut maxValue)
 {
 	Vec<std::size_t> threadCoordinate;
 	GetThreadBlockCoordinate(threadCoordinate);
@@ -30,7 +30,7 @@ __global__ void cudaNLMeans_mv(CudaImageContainer<PixelTypeIn> imageIn, CudaImag
 
 		// 
 		float inputVal = (float)imageIn(threadCoordinate);
-		float inputMeanVal = (float)imageOutMean(threadCoordinate);
+		float inputMeanVal = (float)imageMean(threadCoordinate);
 		float inputVarVal = (float)imageVariance(threadCoordinate);
 		// 
 		double wMax = 0.;
@@ -43,7 +43,7 @@ __global__ void cudaNLMeans_mv(CudaImageContainer<PixelTypeIn> imageIn, CudaImag
 			if (kernelPos == threadCoordinate)
 				continue;
 
-			float kernelMeanVal = (float)imageOutMean(kernelPos);
+			float kernelMeanVal = (float)imageMean(kernelPos);
 			float kernelVarVal = (float)imageVariance(kernelPos);
 			float kernelVal = (float)imageIn(kernelPos);
 
@@ -61,7 +61,7 @@ __global__ void cudaNLMeans_mv(CudaImageContainer<PixelTypeIn> imageIn, CudaImag
 		// now normalize
 		double outVal = outputAccumulator / wAccumulator;
 
-		imageOutMean(threadCoordinate) = (PixelTypeOut)CLAMP(outVal, minValue, maxValue);
+		imageOut(threadCoordinate) = (PixelTypeOut)CLAMP(outVal, minValue, maxValue);
 	}
 } // cudaNLMeans_mv
 
@@ -146,7 +146,7 @@ void cNLMeans(ImageView<PixelTypeIn> imageIn, ImageView<PixelTypeOut> imageOut,
 {
 	const PixelTypeOut MIN_VAL = std::numeric_limits<PixelTypeOut>::lowest();
 	const PixelTypeOut MAX_VAL = std::numeric_limits<PixelTypeOut>::max();
-	const int NUM_BUFF_NEEDED = 3;
+	const int NUM_BUFF_NEEDED = 4;
 
 	CudaDevices cudaDevs(cudaNLMeans_mv<PixelTypeIn, PixelTypeOut>, device);
 
@@ -179,13 +179,12 @@ void cNLMeans(ImageView<PixelTypeIn> imageIn, ImageView<PixelTypeOut> imageOut,
 
 			deviceImages.setAllDims(chunks[i].getFullChunkSize());
 			
-			cudaMeanAndVariance << <chunks[i].blocks, chunks[i].threads >> > (*(deviceImages.getCurBuffer()), *(deviceImages.getNextBuffer()), *(deviceImages.getThirdBuffer()), constKernelMem, MIN_VAL, MAX_VAL);
+			cudaMeanAndVariance << <chunks[i].blocks, chunks[i].threads >> > (*(deviceImages.getCurBuffer()), *(deviceImages.getBuffer(1)), *(deviceImages.getBuffer(2)), constKernelMem, MIN_VAL, MAX_VAL);
 
-			cudaNLMeans_mv <<<chunks[i].blocks, chunks[i].threads>>>(*(deviceImages.getCurBuffer()), *(deviceImages.getNextBuffer()), *(deviceImages.getThirdBuffer()),
+			cudaNLMeans_mv <<<chunks[i].blocks, chunks[i].threads>>>(*(deviceImages.getCurBuffer()), *(deviceImages.getBuffer(1)), *(deviceImages.getBuffer(2)), *(deviceImages.getBuffer(3)),
 				h, searchWindowRadius, nhoodRadius, MIN_VAL, MAX_VAL);
-			
-			deviceImages.incrementBuffer();
-			chunks[i].retriveROI(imageOut, deviceImages.getCurBuffer());
+						
+			chunks[i].retriveROI(imageOut, deviceImages.getBuffer(3));
 		}
 	}
 }
-- 
GitLab