@@ -961,6 +961,7 @@ void ResizeBicubicUpsample(cudaStream_t stream,
961961 int rank,
962962 const UpsampleMode /* upsample_mode*/ ,
963963 ResizeCoordinateTransformationMode coordinate_transform_mode,
964+ const float cubic_coeff_a,
964965 gsl::span<const int64_t > /* input_shape*/ ,
965966 gsl::span<const int64_t > /* output_shape*/ ,
966967 int64_t batch_size, int64_t num_channels,
@@ -982,19 +983,22 @@ void ResizeBicubicUpsample(cudaStream_t stream,
982983 const bool use_extrapolation = extrapolation.has_value ();
983984 const float extrapolation_value = use_extrapolation ? *extrapolation : 0 .f ;
984985
985- int blocksPerGrid = narrow<int >(CeilDiv (N, GridDim::maxThreadsPerBlock));
986- const fast_divmod div_output_image = (rank > 2 ) ? output_div_pitches[rank - 4 ]
987- : fast_divmod (gsl::narrow_cast<int >(N));
988- const fast_divmod& div_output_width = output_div_pitches[rank - 2 ];
989-
990- constexpr float support_value = antialias_constants::kBiCubicSupportSize ;
991-
992986 int64_t input_depth, input_height, input_width;
993987 std::tie (input_depth, input_height, input_width) = inferred_input_dims;
994988
995989 int64_t output_depth, output_height, output_width;
996990 std::tie (output_depth, output_height, output_width) = inferred_output_dims;
997991
992+ const auto temp_buf_size = SafeInt<int64_t >(batch_size) * num_channels * input_height * output_width;
993+
994+ int blocksPerGridL2 = narrow<int >(CeilDiv (N, GridDim::maxThreadsPerBlock));
995+ int blocksPerGridL1 = narrow<int >(CeilDiv (temp_buf_size, GridDim::maxThreadsPerBlock));
996+ const fast_divmod div_output_image = (rank > 2 ) ? output_div_pitches[rank - 4 ]
997+ : fast_divmod (gsl::narrow_cast<int >(N));
998+ const fast_divmod& div_output_width = output_div_pitches[rank - 2 ];
999+
1000+ constexpr float support_value = antialias_constants::kBiCubicSupportSize ;
1001+
9981002 int blocksPerDimsMappingGrid =
9991003 narrow<int >(CeilDiv ((output_depth + output_height + output_width), 32 ));
10001004
@@ -1027,7 +1031,6 @@ void ResizeBicubicUpsample(cudaStream_t stream,
10271031 AccumType* y_weighted_buffer = GetTyped<AccumType>(weighted_buffer_ptr);
10281032 AccumType* w_weighted_buffer = y_weighted_buffer + weighted_y_size;
10291033
1030- const auto temp_buf_size = SafeInt<int64_t >(batch_size) * num_channels * input_height * output_width;
10311034 auto image_temp_buffer = AllocateTyped<T>(allocate_temp_space, narrow<size_t >(temp_buf_size));
10321035
10331036 // clang-format off
@@ -1042,15 +1045,15 @@ void ResizeBicubicUpsample(cudaStream_t stream,
10421045 std::make_tuple (roi_vals[rank - 2 + rank], roi_vals[rank - 1 + rank]), // roi ends h, w
10431046 std::make_tuple (h_scaled_support, w_scaled_support),
10441047 std::make_tuple (h_window_size, w_window_size),
1045- onnxruntime::antialias_constants:: kCubicCoeffA , exclude_outside,
1048+ cubic_coeff_a , exclude_outside,
10461049 GetTyped<int64_t >(bounds_buffer_ptr),
10471050 GetTyped<int64_t >(out_of_bounds_buffer_ptr),
10481051 std::make_tuple (y_weighted_buffer, w_weighted_buffer));
10491052 });
10501053 // clang-format on
10511054 const fast_divmod div_step_image (narrow<int >(num_channels * input_height * output_width));
10521055 // clang-format off
1053- _ComputeInterpolationAtLevel1<T><<<blocksPerGrid , GridDim::maxThreadsPerBlock, 0 , stream>>> (
1056+ _ComputeInterpolationAtLevel1<T><<<blocksPerGridL1 , GridDim::maxThreadsPerBlock, 0 , stream>>> (
10541057 num_channels, input_height, input_width, input_height, output_width,
10551058 div_output_width,
10561059 div_step_image,
@@ -1064,7 +1067,7 @@ void ResizeBicubicUpsample(cudaStream_t stream,
10641067
10651068 const fast_divmod div_output_height{narrow<int >(output_height * output_width)};
10661069 // clang-format off
1067- _ComputeInterpolationAtLevel2<T><<<blocksPerGrid , GridDim::maxThreadsPerBlock, 0 , stream>>> (
1070+ _ComputeInterpolationAtLevel2<T><<<blocksPerGridL2 , GridDim::maxThreadsPerBlock, 0 , stream>>> (
10681071 num_channels, input_height, output_width, output_height, output_width,
10691072 div_output_height,
10701073 div_output_width,
@@ -1085,6 +1088,7 @@ void ResizeAntiAliasImpl(
10851088 int rank,
10861089 const UpsampleMode upsample_mode,
10871090 ResizeCoordinateTransformationMode coordinate_transform_mode,
1091+ const float cubic_coeff_a,
10881092 gsl::span<const int64_t > input_shape,
10891093 gsl::span<const int64_t > output_shape,
10901094 int64_t batch_size, int64_t num_channels,
@@ -1132,7 +1136,7 @@ void ResizeAntiAliasImpl(
11321136 } break ;
11331137 case CUBIC: {
11341138 if (is_2D) {
1135- ResizeBicubicUpsample<T>(stream, rank, upsample_mode, coordinate_transform_mode,
1139+ ResizeBicubicUpsample<T>(stream, rank, upsample_mode, coordinate_transform_mode, cubic_coeff_a,
11361140 input_shape, output_shape, batch_size, num_channels,
11371141 inferred_input_dims, inferred_output_dims, inferred_dim_rscales,
11381142 output_div_pitches, roi_vals, extrapolation, exclude_outside,
@@ -1153,6 +1157,7 @@ void ResizeAntiAliasImpl(
11531157 int rank, \
11541158 const UpsampleMode upsample_mode, \
11551159 ResizeCoordinateTransformationMode coordinate_transform_mode, \
1160+ float cubic_coeff_a, \
11561161 gsl::span<const int64_t > input_shape, \
11571162 gsl::span<const int64_t > output_shape, \
11581163 int64_t batch_size, int64_t num_channels, \
0 commit comments