Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 19 additions & 15 deletions modules/cudaarithm/src/cuda/threshold.cu
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ __global__ void otsu_sums(uint *histogram, uint *threshold_sums, unsigned long l
}

__global__ void
otsu_variance(float2 *variance, uint *histogram, uint *threshold_sums, unsigned long long *sums)
otsu_variance(float2 *variance, uint *histogram, uint *threshold_sums, unsigned long long *sums, uint n_samples)
{
const uint n_bins = 256;

Expand All @@ -137,7 +137,6 @@ otsu_variance(float2 *variance, uint *histogram, uint *threshold_sums, unsigned
int bin_idx = threadIdx.x;
int threshold = blockIdx.x;

uint n_samples = threshold_sums[0];
uint n_samples_above = threshold_sums[threshold];
uint n_samples_below = n_samples - n_samples_above;

Expand All @@ -149,15 +148,21 @@ otsu_variance(float2 *variance, uint *histogram, uint *threshold_sums, unsigned
float threshold_variance_below_f32 = 0;
if (bin_idx > threshold)
{
float mean = (float) sum_above / n_samples_above;
float sigma = bin_idx - mean;
threshold_variance_above_f32 = sigma * sigma;
if (n_samples_above > 0)
{
float mean = (float) sum_above / n_samples_above;
float sigma = bin_idx - mean;
threshold_variance_above_f32 = sigma * sigma;
}
}
else
{
float mean = (float) sum_below / n_samples_below;
float sigma = bin_idx - mean;
threshold_variance_below_f32 = sigma * sigma;
if (n_samples_below > 0)
{
float mean = (float) sum_below / n_samples_below;
float sigma = bin_idx - mean;
threshold_variance_below_f32 = sigma * sigma;
}
}

uint bin_count = histogram[bin_idx];
Expand Down Expand Up @@ -198,15 +203,14 @@ __device__ bool has_lowest_score(
}

__global__ void
otsu_score(uint *otsu_threshold, uint *threshold_sums, float2 *variance)
otsu_score(uint *otsu_threshold, uint *threshold_sums, float2 *variance, uint n_samples)
{
const uint n_thresholds = 256;

__shared__ float shared_memory[n_thresholds];

int threshold = threadIdx.x;

uint n_samples = threshold_sums[0];
uint n_samples_above = threshold_sums[threshold];
uint n_samples_below = n_samples - n_samples_above;

Expand Down Expand Up @@ -241,7 +245,7 @@ otsu_score(uint *otsu_threshold, uint *threshold_sums, float2 *variance)
}
}

void compute_otsu(uint *histogram, uint *otsu_threshold, Stream &stream)
void compute_otsu(uint *histogram, uint *otsu_threshold, uint n_samples, Stream &stream)
{
const uint n_bins = 256;
const uint n_thresholds = 256;
Expand All @@ -261,12 +265,12 @@ void compute_otsu(uint *histogram, uint *otsu_threshold, Stream &stream)
otsu_sums<<<grid_all, block_all, 0, cuda_stream>>>(
histogram, gpu_threshold_sums.ptr<uint>(), gpu_sums.ptr<unsigned long long>());
otsu_variance<<<grid_all, block_all, 0, cuda_stream>>>(
gpu_variances.ptr<float2>(), histogram, gpu_threshold_sums.ptr<uint>(), gpu_sums.ptr<unsigned long long>());
gpu_variances.ptr<float2>(), histogram, gpu_threshold_sums.ptr<uint>(), gpu_sums.ptr<unsigned long long>(), n_samples);
otsu_score<<<grid_score, block_score, 0, cuda_stream>>>(
otsu_threshold, gpu_threshold_sums.ptr<uint>(), gpu_variances.ptr<float2>());
otsu_threshold, gpu_threshold_sums.ptr<uint>(), gpu_variances.ptr<float2>(), n_samples);
}

// TODO: Replace this is cv::cuda::calcHist
// TODO: Replace this with cv::cuda::calcHist
template <uint n_bins>
__global__ void histogram_kernel(
uint *histogram, const uint8_t *image, uint width,
Expand Down Expand Up @@ -334,7 +338,7 @@ double cv::cuda::threshold(InputArray _src, OutputArray _dst, double thresh, dou
calcHist(src, gpu_histogram, stream);

GpuMat gpu_otsu_threshold(1, 1, CV_32SC1, pool.getAllocator());
compute_otsu(gpu_histogram.ptr<uint>(), gpu_otsu_threshold.ptr<uint>(), stream);
compute_otsu(gpu_histogram.ptr<uint>(), gpu_otsu_threshold.ptr<uint>(), src.rows * src.cols, stream);

cv::Mat mat_otsu_threshold;
gpu_otsu_threshold.download(mat_otsu_threshold, stream);
Expand Down