%
% Waterfall segmentation:
% merge watershed basins until k components (cells) remain.
%
function L = Waterfall(distIm, k)
% Waterfall  Segment distIm into k components by greedily merging watershed basins.
%
%   L = Waterfall(distIm, k)
%
%   distIm - image passed to watershed(). Pixels equal to +/-Inf are treated
%            as invalid: they are excluded from basin minima and from the
%            basin adjacency search.
%   k      - target number of components.
%
%   Returns:
%     k == 1 : logical mask of pixels whose (reassigned) label is > 1.
%     k >  1 : label image renumbered by remapLabels, with 0 as background.
%
%   Watershed label 1 is treated as background. Ridge pixels (label 0) and
%   valid pixels of label 1 are reassigned to the nearest basin with
%   label > 1. NOTE(review): this assumes label 1 is the invalid/background
%   basin (e.g. a -Inf background region) - confirm with callers.
%
%   Merging: edgeM(r,c) holds the lowest distIm value among valid pixels of
%   basin c that touch basin r. At each step the pair minimising
%   |edgeM(r,c) - min(distIm over r)| is merged, until k basins remain.
%
%   Errors if more than k basins remain once no adjacent basins are left to
%   merge.

    bValid = ~isinf(distIm);
    iterL = watershed(distIm);

    % Assign ridge pixels (label 0) and valid background-basin pixels
    % (label 1) to their nearest foreground basin (label > 1).
    bBorder = (iterL==0) | (bValid & iterL==1);
    bFg = (iterL > 1);
    if ( any(bFg(:)) )
        [~,nearIdx] = bwdist(bFg);
        iterL(bBorder) = iterL(nearIdx(bBorder));
    else
        % No foreground basin exists, so bwdist would return all-zero
        % indices; treat these pixels as background instead.
        iterL(bBorder) = 1;
    end

    if ( k == 1 )
        L = (iterL > 1);
        return;
    end

    curLabels = double(unique(iterL(iterL > 1)));
    totalComp = length(curLabels);
    if ( k >= totalComp )
        % Already at (or below) the requested count; nothing to merge.
        L = remapLabels(iterL);
        return;
    end

    % Full box neighbourhood (3x3, 3x3x3, ...) used to find touching basins.
    d = ndims(distIm);
    se = strel(ones(repmat(3,1,d)));

    nLabels = max(curLabels);
    edgeM = Inf(nLabels,nLabels);
    basinVal = -Inf(nLabels,1);
    for i=1:length(curLabels)
        lbl = curLabels(i);
        bwComp = bValid & (iterL==lbl);
        if ( ~any(bwComp(:)) )
            % Basin has no finite pixels: it has no minimum, and no valid
            % pixel can touch it. Leave its row Inf / basinVal -Inf.
            continue;
        end

        basinVal(lbl) = min(distIm(bwComp(:)));
        chkIdx = find(bValid & imdilate(bwComp, se) & ~bwComp);
        if ( isempty(chkIdx) )
            continue;
        end

        % Sort the ring pixels ascending so the first occurrence of each
        % neighbouring label carries that neighbour's lowest boundary value.
        [edgeVals,srtIdx] = sort(distIm(chkIdx));
        chkIdx = chkIdx(srtIdx);
        [adjIdx,ia] = unique(iterL(chkIdx),'stable');

        edgeM(lbl,adjIdx) = edgeVals(ia);
    end

    % labelMap(x) is the basin that original label x has been merged into.
    % The image is relabelled once after merging, not on every merge.
    labelMap = 1:double(max(iterL(:)));
    while ( totalComp > k )
        A = abs(edgeM - repmat(basinVal,1,length(basinVal)));

        [mergeVal,bestIdx] = min(A(:));
        if ( isinf(mergeVal) )
            % No adjacent pair of basins is left to merge.
            break;
        end

        [minR,minC] = ind2sub(size(A), bestIdx);

        % Merge basin minC into minR; the merged basin keeps the lower minimum.
        basinVal(minR) = min(basinVal(minR),basinVal(minC));
        basinVal(minC) = -Inf;

        % Edges of the merged basin are the lower of the two basins' edges.
        edgeM(minR,:) = min([edgeM(minR,:);edgeM(minC,:)],[],1);
        edgeM(:,minR) = min([edgeM(:,minR) edgeM(:,minC)],[],2);

        edgeM(minC,:) = Inf;
        edgeM(:,minC) = Inf;
        edgeM(minR,minR) = Inf;

        labelMap(labelMap==minC) = minR;
        % The diagonal of edgeM is always Inf, so minC ~= minR and each merge
        % removes exactly one label.
        totalComp = totalComp - 1;
    end

    if ( totalComp > k )
        error(['Unable to merge to ' num2str(k) ' components in watershed.']);
    end

    iterL = reshape(labelMap(iterL), size(iterL));
    L = remapLabels(iterL);
end

function L = remapLabels(inL)
% remapLabels  Renumber a watershed label image for output.
%
%   L = remapLabels(inL)
%
%   Labels 0 (ridge) and 1 (background basin) map to 0. The remaining labels
%   map, in ascending order, to 1..n. L is double and has the same size as
%   inL. Label 1 does not need to be present: the smallest foreground label
%   is never mapped to background.

    bFg = (inL > 1);
    fgLabels = unique(inL(bFg));

    L = zeros(size(inL));
    if ( isempty(fgLabels) )
        return;
    end

    % Lookup table: original foreground label -> consecutive output label.
    remapIdx = zeros(1, double(max(fgLabels)));
    remapIdx(fgLabels) = 1:length(fgLabels);

    L(bFg) = remapIdx(inL(bFg));
end
