Skip to content

Commit 0d04ad3

Browse files
authored
Fix antialias downsample on CUDA EP (microsoft#25265)
### Description — This PR addresses 3 issues: (1) compilation errors when the DISABLE_CONTRIB_OPS flag is on; (2) a CUDA compute-kernel launch-setup issue in the Resize op with cubic filter and antialiasing; (3) the cubic_coeff_a parameter being ignored in the CUDA kernel of the Resize op. ### Motivation and Context — Fixes microsoft#25264.
1 parent cb4d4af commit 0d04ad3

File tree

6 files changed

+131
-13
lines changed

6 files changed

+131
-13
lines changed

onnxruntime/core/optimizer/graph_transformer_utils.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
442442
// PR #6351 implemented similar fusion-pattern for CUDA only, and can only fuse conv-add-relu,
443443
// while we can fuse more activation.
444444
transformers.emplace_back(std::make_unique<ConvAddActivationFusion>(cpu_ep));
445+
#else
446+
ORT_UNUSED_PARAMETER(logger);
445447
#endif
446448

447449
} break;
@@ -533,6 +535,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
533535
}
534536
#else
535537
ORT_UNUSED_PARAMETER(cpu_execution_provider);
538+
ORT_UNUSED_PARAMETER(logger);
536539
#endif
537540
}
538541
} break;

onnxruntime/core/providers/cuda/cuda_execution_provider.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,9 @@ std::optional<bool> CUDAExecutionProvider::ShouldConvertDataLayoutForOp([[maybe_
348348
(node_domain == kMSDomain && node_op_type == "GridSample");
349349

350350
#else // defined(ENABLE_CUDA_NHWC_OPS)
351+
ORT_UNUSED_PARAMETER(node_domain);
352+
ORT_UNUSED_PARAMETER(node_op_type);
353+
ORT_UNUSED_PARAMETER(target_data_layout);
351354
return std::nullopt;
352355
#endif
353356
}

onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -961,6 +961,7 @@ void ResizeBicubicUpsample(cudaStream_t stream,
961961
int rank,
962962
const UpsampleMode /*upsample_mode*/,
963963
ResizeCoordinateTransformationMode coordinate_transform_mode,
964+
const float cubic_coeff_a,
964965
gsl::span<const int64_t> /*input_shape*/,
965966
gsl::span<const int64_t> /*output_shape*/,
966967
int64_t batch_size, int64_t num_channels,
@@ -982,19 +983,22 @@ void ResizeBicubicUpsample(cudaStream_t stream,
982983
const bool use_extrapolation = extrapolation.has_value();
983984
const float extrapolation_value = use_extrapolation ? *extrapolation : 0.f;
984985

985-
int blocksPerGrid = narrow<int>(CeilDiv(N, GridDim::maxThreadsPerBlock));
986-
const fast_divmod div_output_image = (rank > 2) ? output_div_pitches[rank - 4]
987-
: fast_divmod(gsl::narrow_cast<int>(N));
988-
const fast_divmod& div_output_width = output_div_pitches[rank - 2];
989-
990-
constexpr float support_value = antialias_constants::kBiCubicSupportSize;
991-
992986
int64_t input_depth, input_height, input_width;
993987
std::tie(input_depth, input_height, input_width) = inferred_input_dims;
994988

995989
int64_t output_depth, output_height, output_width;
996990
std::tie(output_depth, output_height, output_width) = inferred_output_dims;
997991

992+
const auto temp_buf_size = SafeInt<int64_t>(batch_size) * num_channels * input_height * output_width;
993+
994+
int blocksPerGridL2 = narrow<int>(CeilDiv(N, GridDim::maxThreadsPerBlock));
995+
int blocksPerGridL1 = narrow<int>(CeilDiv(temp_buf_size, GridDim::maxThreadsPerBlock));
996+
const fast_divmod div_output_image = (rank > 2) ? output_div_pitches[rank - 4]
997+
: fast_divmod(gsl::narrow_cast<int>(N));
998+
const fast_divmod& div_output_width = output_div_pitches[rank - 2];
999+
1000+
constexpr float support_value = antialias_constants::kBiCubicSupportSize;
1001+
9981002
int blocksPerDimsMappingGrid =
9991003
narrow<int>(CeilDiv((output_depth + output_height + output_width), 32));
10001004

@@ -1027,7 +1031,6 @@ void ResizeBicubicUpsample(cudaStream_t stream,
10271031
AccumType* y_weighted_buffer = GetTyped<AccumType>(weighted_buffer_ptr);
10281032
AccumType* w_weighted_buffer = y_weighted_buffer + weighted_y_size;
10291033

1030-
const auto temp_buf_size = SafeInt<int64_t>(batch_size) * num_channels * input_height * output_width;
10311034
auto image_temp_buffer = AllocateTyped<T>(allocate_temp_space, narrow<size_t>(temp_buf_size));
10321035

10331036
// clang-format off
@@ -1042,15 +1045,15 @@ void ResizeBicubicUpsample(cudaStream_t stream,
10421045
std::make_tuple(roi_vals[rank - 2 + rank], roi_vals[rank - 1 + rank]), // roi ends h, w
10431046
std::make_tuple(h_scaled_support, w_scaled_support),
10441047
std::make_tuple(h_window_size, w_window_size),
1045-
onnxruntime::antialias_constants::kCubicCoeffA, exclude_outside,
1048+
cubic_coeff_a, exclude_outside,
10461049
GetTyped<int64_t>(bounds_buffer_ptr),
10471050
GetTyped<int64_t>(out_of_bounds_buffer_ptr),
10481051
std::make_tuple(y_weighted_buffer, w_weighted_buffer));
10491052
});
10501053
// clang-format on
10511054
const fast_divmod div_step_image(narrow<int>(num_channels * input_height * output_width));
10521055
// clang-format off
1053-
_ComputeInterpolationAtLevel1<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
1056+
_ComputeInterpolationAtLevel1<T><<<blocksPerGridL1, GridDim::maxThreadsPerBlock, 0, stream>>>(
10541057
num_channels, input_height, input_width, input_height, output_width,
10551058
div_output_width,
10561059
div_step_image,
@@ -1064,7 +1067,7 @@ void ResizeBicubicUpsample(cudaStream_t stream,
10641067

10651068
const fast_divmod div_output_height{narrow<int>(output_height * output_width)};
10661069
// clang-format off
1067-
_ComputeInterpolationAtLevel2<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
1070+
_ComputeInterpolationAtLevel2<T><<<blocksPerGridL2, GridDim::maxThreadsPerBlock, 0, stream>>>(
10681071
num_channels, input_height, output_width, output_height, output_width,
10691072
div_output_height,
10701073
div_output_width,
@@ -1085,6 +1088,7 @@ void ResizeAntiAliasImpl(
10851088
int rank,
10861089
const UpsampleMode upsample_mode,
10871090
ResizeCoordinateTransformationMode coordinate_transform_mode,
1091+
const float cubic_coeff_a,
10881092
gsl::span<const int64_t> input_shape,
10891093
gsl::span<const int64_t> output_shape,
10901094
int64_t batch_size, int64_t num_channels,
@@ -1132,7 +1136,7 @@ void ResizeAntiAliasImpl(
11321136
} break;
11331137
case CUBIC: {
11341138
if (is_2D) {
1135-
ResizeBicubicUpsample<T>(stream, rank, upsample_mode, coordinate_transform_mode,
1139+
ResizeBicubicUpsample<T>(stream, rank, upsample_mode, coordinate_transform_mode, cubic_coeff_a,
11361140
input_shape, output_shape, batch_size, num_channels,
11371141
inferred_input_dims, inferred_output_dims, inferred_dim_rscales,
11381142
output_div_pitches, roi_vals, extrapolation, exclude_outside,
@@ -1153,6 +1157,7 @@ void ResizeAntiAliasImpl(
11531157
int rank, \
11541158
const UpsampleMode upsample_mode, \
11551159
ResizeCoordinateTransformationMode coordinate_transform_mode, \
1160+
float cubic_coeff_a, \
11561161
gsl::span<const int64_t> input_shape, \
11571162
gsl::span<const int64_t> output_shape, \
11581163
int64_t batch_size, int64_t num_channels, \

onnxruntime/core/providers/cuda/tensor/resize_impl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ void ResizeAntiAliasImpl(
9898
int rank,
9999
const UpsampleMode upsample_mode,
100100
ResizeCoordinateTransformationMode coordinate_transform_mode,
101+
float cubic_coeff_a,
101102
gsl::span<const int64_t> input_shape,
102103
gsl::span<const int64_t> output_shape,
103104
int64_t batch_size, int64_t num_channels,

onnxruntime/core/providers/cuda/tensor/upsample.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ Status Upsample<T>::BaseCompute(OpKernelContext* context,
159159
rank,
160160
mode_,
161161
coordinate_transform_mode_,
162+
cubic_coeff_a_,
162163
X_dims, output_dims,
163164
batch_size, num_channels,
164165
std::make_tuple(0, input_height, input_width),
@@ -201,6 +202,7 @@ Status Upsample<T>::BaseCompute(OpKernelContext* context,
201202
rank,
202203
mode_,
203204
coordinate_transform_mode_,
205+
cubic_coeff_a_,
204206
X_dims, output_dims,
205207
batch_size, num_channels,
206208
std::make_tuple(input_depth, input_height, input_width),
@@ -246,7 +248,7 @@ Status Upsample<T>::BaseCompute(OpKernelContext* context,
246248
const float height_scale = is_2D ? scales[0] : scales[2];
247249
const float width_scale = is_2D ? scales[1] : scales[3];
248250

249-
ResizeAntiAliasImpl(Stream(context), rank, mode_, coordinate_transform_mode_,
251+
ResizeAntiAliasImpl(Stream(context), rank, mode_, coordinate_transform_mode_, cubic_coeff_a_,
250252
X_dims, output_dims,
251253
batch_size, num_channels,
252254
std::make_tuple(0, input_height, input_width),

0 commit comments

Comments (0)