#pragma once
#include "CudaImageContainer.cuh"
#include "CudaDeviceImages.cuh"
#include "CudaUtilities.h"
#include "CudaDeviceInfo.h"
#include "Kernel.cuh"
#include "KernelIterator.cuh"
#include "ImageDimensions.cuh"
#include "ImageChunk.h"
#include "Defines.h"
#include "Vec.h"

#include <cuda_runtime.h>
#include <limits>
#include <omp.h>


// Approximate non-local means: every thread filters the single voxel it is mapped to.
// The patch distance is a Fisher discriminant built from per-voxel patch mean and
// variance images, which are expected to have been computed by a previous cuda filter.
//
//   imageIn             source intensities
//   imageMean           patch mean for each voxel
//   imageVariance       patch variance for each voxel
//   imageOut            receives the filtered value
//   h                   filter strength; each distance is divided by h^2 before exp()
//   searchWindowRadius  whole part is the x/y search radius. A fractional part of at
//                       least 1e-6 is multiplied by that radius to get the z radius,
//                       otherwise z uses the x/y radius.
//   nhoodRadius         comparison neighborhood radius; not referenced in this kernel
//   minValue, maxValue  clamp range applied to the result before it is stored
template <class PixelTypeIn, class PixelTypeOut>
__global__ void cudaNLMeans_mv(CudaImageContainer<PixelTypeIn> imageIn, CudaImageContainer<PixelTypeOut> imageMean, CudaImageContainer<PixelTypeOut> imageVariance,
	CudaImageContainer<PixelTypeOut> imageOut, double h, double searchWindowRadius, int nhoodRadius, PixelTypeOut minValue, PixelTypeOut maxValue)
{
	Vec<std::size_t> threadCoordinate;
	GetThreadBlockCoordinate(threadCoordinate);

	// threads mapped outside of the image have nothing to do
	if (!(threadCoordinate < imageIn.getDims()))
		return;

	// split the requested radius into its whole and fractional parts
	double wholeRadius = std::floor(searchWindowRadius);
	double fracRadius = searchWindowRadius - wholeRadius;

	int xySearch = (int)wholeRadius;
	int zSearch = (fracRadius < 1e-6) ? xySearch : (int)(fracRadius * (double)xySearch);

	// never ask for a radius larger than (image size - 1) in any dimension
	Vec<int> requestedRadius = Vec<int>(xySearch, xySearch, zSearch);
	Vec<int> searchVec = Vec<int>::min(Vec<int>(imageIn.getDims()) - 1, requestedRadius);

	// statistics of the voxel being filtered
	float centerVal = (float)imageIn(threadCoordinate);
	float centerMean = (float)imageMean(threadCoordinate);
	float centerVar = (float)imageVariance(threadCoordinate);

	double weightedSum = 0.;
	double weightSum = 0.;
	double maxWeight = 0.;

	KernelIterator kIt(threadCoordinate, imageIn.getDims(), searchVec * 2 + 1);
	while (!kIt.end())
	{
		Vec<float> samplePos = kIt.getImageCoordinate();

		float sampleMean = (float)imageMean(samplePos);
		float sampleVar = (float)imageVariance(samplePos);
		float sampleVal = (float)imageIn(samplePos);

		// squared mean difference over summed variances; 1e-9 keeps the denominator non-zero
		double dist = SQR(centerMean - sampleMean) / (centerVar + sampleVar + 1e-9);
		double weight = exp(-dist / SQR(h));

		maxWeight = (weight > maxWeight) ? weight : maxWeight;
		weightSum += weight;
		weightedSum += weight*sampleVal;

		++kIt;
	}

	// the voxel being filtered is added once more, using the largest weight seen
	// NOTE(review): if the iterator also visits threadCoordinate, that voxel contributes twice - confirm intended
	weightedSum += maxWeight * centerVal;
	weightSum += maxWeight;

	// weighted average, clamped into the output range
	double outVal = weightedSum / weightSum;
	imageOut(threadCoordinate) = (PixelTypeOut)CLAMP(outVal, minValue, maxValue);
} // cudaNLMeans_mv

// Host-side (cpu) entry point that the front end calls to run approximate non-local means.
//
// The image is split into chunks, one OpenMP thread is started per cuda device (never more
// threads than chunks), and for every chunk cudaMeanAndVariance is launched followed by
// cudaNLMeans_mv. The result of each chunk is copied back into imageOut.
//
//   imageIn             image to filter
//   imageOut            receives the filtered image
//   h                   filter strength, forwarded to cudaNLMeans_mv
//   searchWindowRadius  search window radius, forwarded to cudaNLMeans_mv
//   nhoodRadius         radius of the all-ones kernel (side 2*nhoodRadius+1) given to cudaMeanAndVariance
//   device              forwarded to CudaDevices (presumably -1 selects devices automatically - confirm)
//
// Throws std::runtime_error when a chunk could not be sent to a device.
//
// todo - if we chunk, make sure search window doesn't go off chunk
// NOTE(review): the kernel size handed to calculateBuffers is derived from nhoodRadius only,
// not from searchWindowRadius.
template <class PixelTypeIn, class PixelTypeOut>
void cNLMeans(ImageView<PixelTypeIn> imageIn, ImageView<PixelTypeOut> imageOut, double h, double searchWindowRadius, int nhoodRadius, int device = -1)
{
	const PixelTypeOut MIN_VAL = std::numeric_limits<PixelTypeOut>::lowest();
	const PixelTypeOut MAX_VAL = std::numeric_limits<PixelTypeOut>::max();
	const int NUM_BUFF_NEEDED = 4;

	CudaDevices cudaDevs(cudaNLMeans_mv<PixelTypeIn, PixelTypeOut>, device);

	std::size_t maxTypeSize = MAX(sizeof(PixelTypeIn), sizeof(PixelTypeOut));
	std::vector<ImageChunk> chunks = calculateBuffers(imageIn.getDims(), NUM_BUFF_NEEDED, cudaDevs, maxTypeSize, Vec<std::size_t>(2*nhoodRadius+1));

	// nothing to process; this also avoids asking OpenMP for zero threads below
	if (chunks.empty())
		return;

	Vec<std::size_t> maxDeviceDims;
	setMaxDeviceDims(chunks, maxDeviceDims);

	// all-ones kernel for the mean/variance pass. The vector owns the host memory and
	// outlives the parallel region, so it is released on every exit path.
	Vec<int> kernelDims = Vec<int>(1+2*nhoodRadius);
	std::vector<float> kernelMem((std::size_t)kernelDims.product(), 1.0f);
	ImageView<float> kernel(kernelMem.data(), kernelDims);

	const int NUM_CHUNKS = (int)chunks.size();

	// An exception must not leave an OpenMP parallel region, so a failed transfer is
	// recorded here and reported once all threads have finished.
	bool sendFailed = false;

	omp_set_num_threads(MIN(chunks.size(), cudaDevs.getNumDevices()));
	#pragma omp parallel default(shared)
	{
		const int CUDA_IDX = omp_get_thread_num();
		const int N_THREADS = omp_get_num_threads();
		const int CUR_DEVICE = cudaDevs.getDeviceIdx(CUDA_IDX);

		CudaDeviceImages<PixelTypeOut> deviceImages(NUM_BUFF_NEEDED, maxDeviceDims, CUR_DEVICE);
		Kernel constKernelMem(kernel, CUR_DEVICE);

		// each thread takes every N_THREADS-th chunk, starting at its own index
		for (int i = CUDA_IDX; i < NUM_CHUNKS; i += N_THREADS)
		{
			if (!chunks[i].sendROI(imageIn, deviceImages.getCurBuffer()))
			{
				#pragma omp critical
				{
					sendFailed = true;
				}
				// do not launch kernels on a buffer that was never filled
				break;
			}

			deviceImages.setAllDims(chunks[i].getFullChunkSize());

			// buffers 1 and 2 are handed to the mean/variance pass and then read by
			// cudaNLMeans_mv as its mean and variance images
			cudaMeanAndVariance<<<chunks[i].blocks, chunks[i].threads>>>(*(deviceImages.getCurBuffer()), *(deviceImages.getBuffer(1)), *(deviceImages.getBuffer(2)), constKernelMem, MIN_VAL, MAX_VAL);

			// buffer 3 receives the filtered chunk
			cudaNLMeans_mv<<<chunks[i].blocks, chunks[i].threads>>>(*(deviceImages.getCurBuffer()), *(deviceImages.getBuffer(1)), *(deviceImages.getBuffer(2)), *(deviceImages.getBuffer(3)),
				h, searchWindowRadius, nhoodRadius, MIN_VAL, MAX_VAL);

			chunks[i].retriveROI(imageOut, deviceImages.getBuffer(3));
		}
	}

	if (sendFailed)
		throw std::runtime_error("Error sending ROI to device!");
}
