diff --git a/src/c/Cuda/CudaDeviceImages.cuh b/src/c/Cuda/CudaDeviceImages.cuh index 88f5ae6ec6a25194c1bcfd48dbd07230352d85ac..062ca5b37e9e13484c4fdd45fd856c890782f577 100644 --- a/src/c/Cuda/CudaDeviceImages.cuh +++ b/src/c/Cuda/CudaDeviceImages.cuh @@ -72,6 +72,15 @@ public: return deviceImages[trd]; } + /* Returns the buffer stored at deviceImages[idx], or NULL when idx is outside [0, numBuffers); callers that dereference the result must keep idx in range. */ + CudaImageContainer<PixelType>* getBuffer(int idx) + { + if (idx < 0 || numBuffers <= idx) + return NULL; + + return deviceImages[idx]; + } + void incrementBuffer() { DEBUG_KERNEL_CHECK(); diff --git a/src/c/Cuda/CudaNLMeans.cuh b/src/c/Cuda/CudaNLMeans.cuh index cd3e124c95f856c9e939475997cc3d5bc93c57fb..96d1de1dacd818562dbd04603bab048cb43e75e9 100644 --- a/src/c/Cuda/CudaNLMeans.cuh +++ b/src/c/Cuda/CudaNLMeans.cuh @@ -18,8 +18,8 @@ // this is an approximate nl means. uses fisher discriminant as a distance, with mean, variance of patches computed by previous cuda filter template <class PixelTypeIn, class PixelTypeOut> // params needed: a,h, search window size, comparison nhood -__global__ void cudaNLMeans_mv(CudaImageContainer<PixelTypeIn> imageIn, CudaImageContainer<PixelTypeOut> imageOutMean, CudaImageContainer<PixelTypeOut> imageVariance, - double h, int searchWindowRadius, int nhoodRadius, PixelTypeOut minValue, PixelTypeOut maxValue) +__global__ void cudaNLMeans_mv(CudaImageContainer<PixelTypeIn> imageIn, CudaImageContainer<PixelTypeOut> imageMean, CudaImageContainer<PixelTypeOut> imageVariance, + CudaImageContainer<PixelTypeOut> imageOut, double h, int searchWindowRadius, int nhoodRadius, PixelTypeOut minValue, PixelTypeOut maxValue) { Vec<std::size_t> threadCoordinate; GetThreadBlockCoordinate(threadCoordinate); @@ -30,7 +30,7 @@ __global__ void cudaNLMeans_mv(CudaImageContainer<PixelTypeIn> imageIn, CudaImag // float inputVal = (float)imageIn(threadCoordinate); - float inputMeanVal = (float)imageOutMean(threadCoordinate); + float inputMeanVal = (float)imageMean(threadCoordinate); float inputVarVal = (float)imageVariance(threadCoordinate); // double wMax = 0.; @@ 
-43,7 +43,7 @@ __global__ void cudaNLMeans_mv(CudaImageContainer<PixelTypeIn> imageIn, CudaImag if (kernelPos == threadCoordinate) continue; - float kernelMeanVal = (float)imageOutMean(kernelPos); + float kernelMeanVal = (float)imageMean(kernelPos); float kernelVarVal = (float)imageVariance(kernelPos); float kernelVal = (float)imageIn(kernelPos); @@ -61,7 +61,7 @@ __global__ void cudaNLMeans_mv(CudaImageContainer<PixelTypeIn> imageIn, CudaImag // now normalize double outVal = outputAccumulator / wAccumulator; - imageOutMean(threadCoordinate) = (PixelTypeOut)CLAMP(outVal, minValue, maxValue); + imageOut(threadCoordinate) = (PixelTypeOut)CLAMP(outVal, minValue, maxValue); } } // cudaNLMeans_mv @@ -146,7 +146,7 @@ void cNLMeans(ImageView<PixelTypeIn> imageIn, ImageView<PixelTypeOut> imageOut, { const PixelTypeOut MIN_VAL = std::numeric_limits<PixelTypeOut>::lowest(); const PixelTypeOut MAX_VAL = std::numeric_limits<PixelTypeOut>::max(); - const int NUM_BUFF_NEEDED = 3; + const int NUM_BUFF_NEEDED = 4; CudaDevices cudaDevs(cudaNLMeans_mv<PixelTypeIn, PixelTypeOut>, device); @@ -179,13 +179,13 @@ void cNLMeans(ImageView<PixelTypeIn> imageIn, ImageView<PixelTypeOut> imageOut, deviceImages.setAllDims(chunks[i].getFullChunkSize()); - cudaMeanAndVariance << <chunks[i].blocks, chunks[i].threads >> > (*(deviceImages.getCurBuffer()), *(deviceImages.getNextBuffer()), *(deviceImages.getThirdBuffer()), constKernelMem, MIN_VAL, MAX_VAL); + cudaMeanAndVariance <<<chunks[i].blocks, chunks[i].threads>>>(*(deviceImages.getCurBuffer()), *(deviceImages.getBuffer(1)), *(deviceImages.getBuffer(2)), constKernelMem, MIN_VAL, MAX_VAL); - cudaNLMeans_mv <<<chunks[i].blocks, chunks[i].threads>>>(*(deviceImages.getCurBuffer()), *(deviceImages.getNextBuffer()), *(deviceImages.getThirdBuffer()), + cudaNLMeans_mv <<<chunks[i].blocks, chunks[i].threads>>>(*(deviceImages.getCurBuffer()), *(deviceImages.getBuffer(1)), *(deviceImages.getBuffer(2)), *(deviceImages.getBuffer(3)), h, 
searchWindowRadius, nhoodRadius, MIN_VAL, MAX_VAL); - - deviceImages.incrementBuffer(); - chunks[i].retriveROI(imageOut, deviceImages.getCurBuffer()); + + DEBUG_KERNEL_CHECK(); /* incrementBuffer() ran this check before the read-back; keep it now that the call is removed */ + chunks[i].retriveROI(imageOut, deviceImages.getBuffer(3)); } } }