From 310ed03ec19316a47fa67f5b2873e3774bf4c43a Mon Sep 17 00:00:00 2001 From: Lhongpei <1453244320@qq.com> Date: Mon, 17 Nov 2025 02:58:01 +0000 Subject: [PATCH 1/2] Fused kernel for optimizing efficiency on sparse cases --- internal/internal_types.h | 11 ++ internal/utils.h | 4 + src/solver.cu | 219 ++++++++++++++++++++++++++++++++++++++ src/utils.cu | 24 +++++ 4 files changed, 258 insertions(+) diff --git a/internal/internal_types.h b/internal/internal_types.h index e66600d..20fe8b5 100644 --- a/internal/internal_types.h +++ b/internal/internal_types.h @@ -31,6 +31,12 @@ typedef struct double *val; } cu_sparse_matrix_csr_t; +typedef enum +{ + CUSPARSE_UPDATE, + FUSED_UPDATE, +} pdhg_update_algorithm_t; + typedef struct { int num_variables; @@ -122,6 +128,9 @@ typedef struct double *ones_primal_d; double *ones_dual_d; + + pdhg_update_algorithm_t primal_update_algorithm; + pdhg_update_algorithm_t dual_update_algorithm; } pdhg_solver_state_t; typedef struct @@ -133,3 +142,5 @@ typedef struct double obj_vec_rescale; double rescaling_time_sec; } rescale_info_t; + + diff --git a/internal/utils.h b/internal/utils.h index 62da072..f84466c 100644 --- a/internal/utils.h +++ b/internal/utils.h @@ -125,6 +125,10 @@ extern "C" int coo_to_csr(const matrix_desc_t *desc, int **row_ptr, int **col_ind, double **vals, int *nnz_out); + int calculate_max_nnz_row(int m, const int* matA_row_ptr); + + int calculate_max_nnz_col(int n, const int* matAt_row_ptr); + #ifdef __cplusplus } diff --git a/src/solver.cu b/src/solver.cu index 2880e9a..6fd1fa3 100644 --- a/src/solver.cu +++ b/src/solver.cu @@ -60,7 +60,10 @@ __global__ void compute_delta_solution_kernel( const double *initial_primal, const double *pdhg_primal, double *delta_primal, const double *initial_dual, const double *pdhg_dual, double *delta_dual, int n_vars, int n_cons); + +static void fused_compute_next_pdhg_primal_solution(pdhg_solver_state_t *state); static void compute_next_pdhg_primal_solution(pdhg_solver_state_t *state); +static void fused_compute_next_pdhg_dual_solution(pdhg_solver_state_t *state); static void compute_next_pdhg_dual_solution(pdhg_solver_state_t *state); static void halpern_update(pdhg_solver_state_t *state, double reflection_coefficient); @@ -78,6 +81,8 @@ static void compute_fixed_point_error(pdhg_solver_state_t *state); void lp_problem_free(lp_problem_t *prob); void pdhg_solver_state_free(pdhg_solver_state_t *state); void rescale_info_free(rescale_info_t *info); +static void decide_fused_update_usage(pdhg_solver_state_t *state, + const pdhg_parameters_t *params); cupdlpx_result_t *optimize(const pdhg_parameters_t *params, const lp_problem_t *original_problem) @@ -87,6 +92,7 @@ cupdlpx_result_t *optimize(const pdhg_parameters_t *params, pdhg_solver_state_t *state = initialize_solver_state(original_problem, rescale_info); + decide_fused_update_usage(state, params); rescale_info_free(rescale_info); initialize_step_size_and_primal_weight(state, params); clock_t start_time = clock(); @@ -575,6 +581,11 @@ __global__ void compute_delta_solution_kernel( static void compute_next_pdhg_primal_solution(pdhg_solver_state_t *state) { + if (state->dual_update_algorithm == FUSED_UPDATE) + { + fused_compute_next_pdhg_primal_solution(state); + return; + } CUSPARSE_CHECK(cusparseDnVecSetValues(state->vec_dual_sol, state->current_dual_solution)); CUSPARSE_CHECK( @@ -612,6 +623,11 @@ static void compute_next_pdhg_primal_solution(pdhg_solver_state_t *state) static void compute_next_pdhg_dual_solution(pdhg_solver_state_t *state) { + if (state->primal_update_algorithm == FUSED_UPDATE) + { + fused_compute_next_pdhg_dual_solution(state); + return; + } CUSPARSE_CHECK(cusparseDnVecSetValues(state->vec_primal_sol, state->reflected_primal_solution)); CUSPARSE_CHECK( @@ -951,4 +967,207 @@ void set_default_parameters(pdhg_parameters_t *params) params->restart_params.k_i = 0.01; params->restart_params.k_d = 0.0; params->restart_params.i_smooth = 0.3; +} + +__global__ void fused_compute_next_pdhg_primal_solution_kernel( + const int *matAt_row_ptr, const int *matAt_col_ind, const double *matAt_val, + double *dual_solution, double *dual_product, + double *current_primal, double *reflected_primal, + const double *objective, const double *var_lb, const double *var_ub, double step_size, + int n_vars) +{ + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n_vars) + { + //Compute dual product + double sum = 0.0; + int row_start = matAt_row_ptr[i]; + int row_end = matAt_row_ptr[i + 1]; + for (int j = row_start; j < row_end; ++j) + { + int col = matAt_col_ind[j]; + double val = matAt_val[j]; + sum += val * dual_solution[col]; + } + double dual_prod = sum; + //Compute PDHG primal solution + double temp = current_primal[i] - step_size * (objective[i] - dual_prod); + double temp_proj = fmax(var_lb[i], fmin(temp, var_ub[i])); + reflected_primal[i] = 2.0 * temp_proj - current_primal[i]; + dual_product[i] = dual_prod; + } + return; +} + +__global__ void fused_compute_next_pdhg_primal_solution_major_kernel( + const int *matAt_row_ptr, const int *matAt_col_ind, const double *matAt_val, + double *dual_solution, double *dual_product, + double *current_primal, double *reflected_primal, double *pdhg_primal, double *dual_slack, + const double *objective, const double *var_lb, const double *var_ub, double step_size, + int n_vars) +{ + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n_vars) + { + //Compute dual product + double sum = 0.0; + int row_start = matAt_row_ptr[i]; + int row_end = matAt_row_ptr[i + 1]; + for (int j = row_start; j < row_end; ++j) + { + int col = matAt_col_ind[j]; + double val = matAt_val[j]; + sum += val * dual_solution[col]; + } + double dual_prod = sum; + //Compute PDHG primal solution + double temp = current_primal[i] - step_size * (objective[i] - dual_prod); + double temp_proj = fmax(var_lb[i], fmin(temp, var_ub[i])); + reflected_primal[i] = 2.0 * temp_proj - current_primal[i]; + dual_slack[i] = (temp_proj - temp) / step_size; + pdhg_primal[i] = temp_proj; + dual_product[i] = dual_prod; + } + return; +} + +__global__ void fused_compute_next_pdhg_dual_solution_kernel( + const int *matA_row_ptr, const int *matA_col_ind, const double *matA_val, + double *primal_solution, double *primal_product, + double *current_dual, double *reflected_dual, + const double *const_lb, const double *const_ub, double step_size, + int n_cons) +{ + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n_cons) + { + //Compute primal product + double sum = 0.0; + int row_start = matA_row_ptr[i]; + int row_end = matA_row_ptr[i + 1]; + for (int j = row_start; j < row_end; ++j) + { + int col = matA_col_ind[j]; + double val = matA_val[j]; + sum += val * primal_solution[col]; + } + double primal_prod = sum; + //Compute PDHG dual solution + double temp = current_dual[i] / step_size - primal_prod; + double temp_proj = fmax(-const_ub[i], fmin(temp, -const_lb[i])); + reflected_dual[i] = 2.0 * (temp - temp_proj) * step_size - current_dual[i]; + primal_product[i] = primal_prod; + } + return; +} + +__global__ void fused_compute_next_pdhg_dual_major_solution( + const int *matA_row_ptr, const int *matA_col_ind, const double *matA_val, + double *primal_solution, double *primal_product, + double *current_dual, double *reflected_dual, double *pdhg_dual, + const double *const_lb, const double *const_ub, double step_size, + int n_cons) +{ + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n_cons) + { + //Compute primal product + double sum = 0.0; + int row_start = matA_row_ptr[i]; + int row_end = matA_row_ptr[i + 1]; + for (int j = row_start; j < row_end; ++j) + { + int col = matA_col_ind[j]; + double val = matA_val[j]; + sum += val * primal_solution[col]; + } + double primal_prod = sum; + //Compute PDHG dual solution + double temp = current_dual[i] / step_size - primal_prod; + double temp_proj = fmax(-const_ub[i], fmin(temp, -const_lb[i])); + pdhg_dual[i] = (temp - temp_proj) * step_size; + reflected_dual[i] = 2.0 * pdhg_dual[i] - current_dual[i]; + primal_product[i] = primal_prod; + } + return; +} + +static void fused_compute_next_pdhg_primal_solution(pdhg_solver_state_t *state) +{ + double step_size = state->step_size / state->primal_weight; + + if (state->is_this_major_iteration || ((state->total_count + 2) % get_print_frequency(state->total_count + 2)) == 0) + { + fused_compute_next_pdhg_primal_solution_major_kernel<<num_blocks_primal, THREADS_PER_BLOCK>>>( + state->constraint_matrix_t->row_ptr, state->constraint_matrix_t->col_ind, state->constraint_matrix_t->val, + state->current_dual_solution, state->dual_product, + state->current_primal_solution, state->reflected_primal_solution, + state->pdhg_primal_solution, state->dual_slack, + state->objective_vector, state->variable_lower_bound, + state->variable_upper_bound, step_size, + state->num_variables); + } + else + { + fused_compute_next_pdhg_primal_solution_kernel<<num_blocks_primal, THREADS_PER_BLOCK>>>( + state->constraint_matrix_t->row_ptr, state->constraint_matrix_t->col_ind, state->constraint_matrix_t->val, + state->current_dual_solution, state->dual_product, + state->current_primal_solution, state->reflected_primal_solution, + state->objective_vector, state->variable_lower_bound, + state->variable_upper_bound, step_size, + state->num_variables); + } +} + +static void fused_compute_next_pdhg_dual_solution(pdhg_solver_state_t *state) +{ + double step_size = state->step_size * state->primal_weight; + + if (state->is_this_major_iteration || ((state->total_count + 2) % get_print_frequency(state->total_count + 2)) == 0) + { + fused_compute_next_pdhg_dual_major_solution<<num_blocks_dual, THREADS_PER_BLOCK>>>( + state->constraint_matrix->row_ptr, state->constraint_matrix->col_ind, state->constraint_matrix->val, + state->reflected_primal_solution, state->primal_product, + state->current_dual_solution, state->reflected_dual_solution, state->pdhg_dual_solution, + state->constraint_lower_bound, state->constraint_upper_bound, + step_size, + state->num_constraints); + } + else + { + fused_compute_next_pdhg_dual_solution_kernel<<num_blocks_dual, THREADS_PER_BLOCK>>>( + state->constraint_matrix->row_ptr, state->constraint_matrix->col_ind, state->constraint_matrix->val, + state->reflected_primal_solution, state->primal_product, + state->current_dual_solution, state->reflected_dual_solution, + state->constraint_lower_bound, state->constraint_upper_bound, + step_size, + state->num_constraints); + } +} + +static void decide_fused_update_usage(pdhg_solver_state_t *state, + const pdhg_parameters_t *params) +{ + // Heuristic to decide whether to use fused kernels or not + // Currently, we use fused kernels when the number of non-zeros is less than a threshold + int n_cons = state->num_constraints; + int n_vars = state->num_variables; + int max_nnz_A_row = calculate_max_nnz_row(n_cons, state->constraint_matrix->row_ptr); + int max_nnz_At_row = calculate_max_nnz_row(n_vars, state->constraint_matrix_t->row_ptr); + int fusion_nnz_threshold = 100; + if (max_nnz_A_row > fusion_nnz_threshold) state->dual_update_algorithm = CUSPARSE_UPDATE; + else state->dual_update_algorithm = FUSED_UPDATE; + if (max_nnz_At_row > fusion_nnz_threshold) state->primal_update_algorithm = CUSPARSE_UPDATE; + else state->primal_update_algorithm = FUSED_UPDATE; + if (params->verbose) + { + if (state->primal_update_algorithm == FUSED_UPDATE) + printf("Using fused primal update kernel.\n"); + else + printf("Using cuSPARSE primal update kernel.\n"); + if (state->dual_update_algorithm == FUSED_UPDATE) + printf("Using fused dual update kernel.\n"); + else + printf("Using cuSPARSE dual update kernel.\n"); + } } \ No newline at end of file diff --git a/src/utils.cu b/src/utils.cu index 83ba194..a22929c 100644 --- a/src/utils.cu +++ b/src/utils.cu @@ -961,3 +961,27 @@ int coo_to_csr(const matrix_desc_t *desc, int **row_ptr, int **col_ind, *nnz_out = nnz; return 0; } + +int calculate_max_nnz_row(int m, const int* matA_row_ptr) { + int max_nnz = 0; + int* host_ptr = (int*)safe_malloc((size_t)(m + 1) * sizeof(int)); + CUDA_CHECK(cudaMemcpy(host_ptr, matA_row_ptr, (size_t)(m + 1) * sizeof(int), cudaMemcpyDeviceToHost)); + for (int i = 0; i < m; ++i) { + int row_nnz = host_ptr[i + 1] - host_ptr[i]; + if (row_nnz > max_nnz) max_nnz = row_nnz; + } + free(host_ptr); + return max_nnz; +} + +int calculate_max_nnz_col(int n, const int* matAt_row_ptr) { + int max_nnz = 0; + int* host_ptr = (int*)safe_malloc((size_t)(n + 1) * sizeof(int)); + CUDA_CHECK(cudaMemcpy(host_ptr, matAt_row_ptr, (size_t)(n + 1) * sizeof(int), cudaMemcpyDeviceToHost)); + for (int j = 0; j < n; ++j) { + int col_nnz = host_ptr[j + 1] - host_ptr[j]; + if (col_nnz > max_nnz) max_nnz = col_nnz; + } + free(host_ptr); + return max_nnz; +} \ No newline at end of file From 0e9b04498be423021d0566c3cf265f0e0b168808 Mon Sep 17 00:00:00 2001 From: Lhongpei <1453244320@qq.com> Date: Mon, 17 Nov 2025 14:43:01 +0000 Subject: [PATCH 2/2] Improve kernel efficiency --- src/solver.cu | 119 ++++++++++++++++++++++++++++++++++---------------- 1 file changed, 81 insertions(+), 38 deletions(-) diff --git a/src/solver.cu b/src/solver.cu index 6fd1fa3..2454ef0 100644 --- a/src/solver.cu +++ b/src/solver.cu @@ -581,7 +581,7 @@ __global__ void compute_delta_solution_kernel( static void compute_next_pdhg_primal_solution(pdhg_solver_state_t *state) { - if (state->dual_update_algorithm == FUSED_UPDATE) + if (state->primal_update_algorithm == FUSED_UPDATE) { fused_compute_next_pdhg_primal_solution(state); return; @@ -623,7 +623,7 @@ static void compute_next_pdhg_primal_solution(pdhg_solver_state_t *state) static void compute_next_pdhg_dual_solution(pdhg_solver_state_t *state) { - if (state->primal_update_algorithm == FUSED_UPDATE) + if (state->dual_update_algorithm == FUSED_UPDATE) { fused_compute_next_pdhg_dual_solution(state); return; @@ -970,10 +970,18 @@ void set_default_parameters(pdhg_parameters_t *params) } __global__ void fused_compute_next_pdhg_primal_solution_kernel( - const int *matAt_row_ptr, const int *matAt_col_ind, const double *matAt_val, - double *dual_solution, double *dual_product, - double *current_primal, double *reflected_primal, - const double *objective, const double *var_lb, const double *var_ub, double step_size, + const int * __restrict__ matAt_row_ptr, + const int * __restrict__ matAt_col_ind, + const double * __restrict__ matAt_val, + const double * __restrict__ dual_solution, + double * __restrict__ dual_product, + const double * __restrict__ current_primal, + double * __restrict__ reflected_primal, + const double * __restrict__ objective, + const double * __restrict__ var_lb, + const double * __restrict__ var_ub, + double step_size, + double inv_step_size, int n_vars) { int i = blockIdx.x * blockDim.x + threadIdx.x; @@ -987,9 +995,10 @@ __global__ void fused_compute_next_pdhg_primal_solution_kernel( { int col = matAt_col_ind[j]; double val = matAt_val[j]; - sum += val * dual_solution[col]; + sum += val * __ldg(&dual_solution[col]); } double dual_prod = sum; + //Compute PDHG primal solution double temp = current_primal[i] - step_size * (objective[i] - dual_prod); double temp_proj = fmax(var_lb[i], fmin(temp, var_ub[i])); @@ -1000,10 +1009,20 @@ __global__ void fused_compute_next_pdhg_primal_solution_kernel( } __global__ void fused_compute_next_pdhg_primal_solution_major_kernel( - const int *matAt_row_ptr, const int *matAt_col_ind, const double *matAt_val, - double *dual_solution, double *dual_product, - double *current_primal, double *reflected_primal, double *pdhg_primal, double *dual_slack, - const double *objective, const double *var_lb, const double *var_ub, double step_size, + const int * __restrict__ matAt_row_ptr, + const int * __restrict__ matAt_col_ind, + const double * __restrict__ matAt_val, + const double * __restrict__ dual_solution, + double * __restrict__ dual_product, + const double * __restrict__ current_primal, + double * __restrict__ reflected_primal, + double * __restrict__ pdhg_primal, + double * __restrict__ dual_slack, + const double * __restrict__ objective, + const double * __restrict__ var_lb, + const double * __restrict__ var_ub, + double step_size, + double inv_step_size, int n_vars) { int i = blockIdx.x * blockDim.x + threadIdx.x; @@ -1017,25 +1036,36 @@ __global__ void fused_compute_next_pdhg_primal_solution_major_kernel( { int col = matAt_col_ind[j]; double val = matAt_val[j]; - sum += val * dual_solution[col]; + sum += val * __ldg(&dual_solution[col]); } double dual_prod = sum; + //Compute PDHG primal solution double temp = current_primal[i] - step_size * (objective[i] - dual_prod); double temp_proj = fmax(var_lb[i], fmin(temp, var_ub[i])); reflected_primal[i] = 2.0 * temp_proj - current_primal[i]; - dual_slack[i] = (temp_proj - temp) / step_size; + + dual_slack[i] = (temp_proj - temp) * inv_step_size; + pdhg_primal[i] = temp_proj; dual_product[i] = dual_prod; } return; } -__global__ void fused_compute_next_pdhg_dual_solution_kernel( - const int *matA_row_ptr, const int *matA_col_ind, const double *matA_val, - double *primal_solution, double *primal_product, - double *current_dual, double *reflected_dual, - const double *const_lb, const double *const_ub, double step_size, +__global__ void fused_compute_next_pdhg_dual_solution_major_kernel( + const int * __restrict__ matA_row_ptr, + const int * __restrict__ matA_col_ind, + const double * __restrict__ matA_val, + const double * __restrict__ primal_solution, + double * __restrict__ primal_product, + const double * __restrict__ current_dual, + double * __restrict__ reflected_dual, + double * __restrict__ pdhg_dual, + const double * __restrict__ const_lb, + const double * __restrict__ const_ub, + double step_size, + double inv_step_size, int n_cons) { int i = blockIdx.x * blockDim.x + threadIdx.x; @@ -1049,23 +1079,31 @@ __global__ void fused_compute_next_pdhg_dual_solution_kernel( { int col = matA_col_ind[j]; double val = matA_val[j]; - sum += val * primal_solution[col]; + sum += val * __ldg(&primal_solution[col]); } double primal_prod = sum; //Compute PDHG dual solution - double temp = current_dual[i] / step_size - primal_prod; + double temp = current_dual[i]* inv_step_size - primal_prod; double temp_proj = fmax(-const_ub[i], fmin(temp, -const_lb[i])); - reflected_dual[i] = 2.0 * (temp - temp_proj) * step_size - current_dual[i]; + pdhg_dual[i] = (temp - temp_proj) * step_size; + reflected_dual[i] = 2.0 * pdhg_dual[i] - current_dual[i]; primal_product[i] = primal_prod; } return; } -__global__ void fused_compute_next_pdhg_dual_major_solution( - const int *matA_row_ptr, const int *matA_col_ind, const double *matA_val, - double *primal_solution, double *primal_product, - double *current_dual, double *reflected_dual, double *pdhg_dual, - const double *const_lb, const double *const_ub, double step_size, +__global__ void fused_compute_next_pdhg_dual_solution_kernel( + const int * __restrict__ matA_row_ptr, + const int * __restrict__ matA_col_ind, + const double * __restrict__ matA_val, + const double * __restrict__ primal_solution, + double * __restrict__ primal_product, + const double * __restrict__ current_dual, + double * __restrict__ reflected_dual, + const double * __restrict__ const_lb, + const double * __restrict__ const_ub, + double step_size, + double inv_step_size, int n_cons) { int i = blockIdx.x * blockDim.x + threadIdx.x; @@ -1079,14 +1117,13 @@ __global__ void fused_compute_next_pdhg_dual_major_solution( { int col = matA_col_ind[j]; double val = matA_val[j]; - sum += val * primal_solution[col]; + sum += val * __ldg(&primal_solution[col]); } double primal_prod = sum; //Compute PDHG dual solution - double temp = current_dual[i] / step_size - primal_prod; + double temp = current_dual[i] * inv_step_size - primal_prod; double temp_proj = fmax(-const_ub[i], fmin(temp, -const_lb[i])); - pdhg_dual[i] = (temp - temp_proj) * step_size; - reflected_dual[i] = 2.0 * pdhg_dual[i] - current_dual[i]; + reflected_dual[i] = 2.0 * (temp - temp_proj) * step_size - current_dual[i]; primal_product[i] = primal_prod; } return; @@ -1095,7 +1132,7 @@ __global__ void fused_compute_next_pdhg_dual_major_solution( static void fused_compute_next_pdhg_primal_solution(pdhg_solver_state_t *state) { double step_size = state->step_size / state->primal_weight; - + double inv_step_size = 1.0 / step_size; if (state->is_this_major_iteration || ((state->total_count + 2) % get_print_frequency(state->total_count + 2)) == 0) { fused_compute_next_pdhg_primal_solution_major_kernel<<num_blocks_primal, THREADS_PER_BLOCK>>>( @@ -1104,7 +1141,7 @@ static void fused_compute_next_pdhg_primal_solution(pdhg_solver_state_t *state) state->current_primal_solution, state->reflected_primal_solution, state->pdhg_primal_solution, state->dual_slack, state->objective_vector, state->variable_lower_bound, - state->variable_upper_bound, step_size, + state->variable_upper_bound, step_size, inv_step_size, state->num_variables); } else @@ -1114,7 +1151,7 @@ static void fused_compute_next_pdhg_primal_solution(pdhg_solver_state_t *state) state->current_dual_solution, state->dual_product, state->current_primal_solution, state->reflected_primal_solution, state->objective_vector, state->variable_lower_bound, - state->variable_upper_bound, step_size, + state->variable_upper_bound, step_size, inv_step_size, state->num_variables); } } @@ -1122,15 +1159,16 @@ static void fused_compute_next_pdhg_primal_solution(pdhg_solver_state_t *state) static void fused_compute_next_pdhg_dual_solution(pdhg_solver_state_t *state) { double step_size = state->step_size * state->primal_weight; + double inv_step_size = 1.0 / step_size; if (state->is_this_major_iteration || ((state->total_count + 2) % get_print_frequency(state->total_count + 2)) == 0) { - fused_compute_next_pdhg_dual_major_solution<<num_blocks_dual, THREADS_PER_BLOCK>>>( + fused_compute_next_pdhg_dual_solution_major_kernel<<num_blocks_dual, THREADS_PER_BLOCK>>>( state->constraint_matrix->row_ptr, state->constraint_matrix->col_ind, state->constraint_matrix->val, state->reflected_primal_solution, state->primal_product, state->current_dual_solution, state->reflected_dual_solution, state->pdhg_dual_solution, state->constraint_lower_bound, state->constraint_upper_bound, - step_size, + step_size, inv_step_size, state->num_constraints); } else @@ -1140,7 +1178,7 @@ static void fused_compute_next_pdhg_dual_solution(pdhg_solver_state_t *state) state->reflected_primal_solution, state->primal_product, state->current_dual_solution, state->reflected_dual_solution, state->constraint_lower_bound, state->constraint_upper_bound, - step_size, + step_size, inv_step_size, state->num_constraints); } } @@ -1155,9 +1193,14 @@ static void decide_fused_update_usage(pdhg_solver_state_t *state, int max_nnz_A_row = calculate_max_nnz_row(n_cons, state->constraint_matrix->row_ptr); int max_nnz_At_row = calculate_max_nnz_row(n_vars, state->constraint_matrix_t->row_ptr); int fusion_nnz_threshold = 100; - if (max_nnz_A_row > fusion_nnz_threshold) state->dual_update_algorithm = CUSPARSE_UPDATE; + double fusion_density_threshold = 0.01; + int primal_threshold = fmin(fusion_nnz_threshold, + (int)(fusion_density_threshold * n_cons)); + int dual_threshold = fmin(fusion_nnz_threshold, + (int)(fusion_density_threshold * n_vars)); + if (max_nnz_A_row > dual_threshold) state->dual_update_algorithm = CUSPARSE_UPDATE; else state->dual_update_algorithm = FUSED_UPDATE; - if (max_nnz_At_row > fusion_nnz_threshold) state->primal_update_algorithm = CUSPARSE_UPDATE; + if (max_nnz_At_row > primal_threshold) state->primal_update_algorithm = CUSPARSE_UPDATE; else state->primal_update_algorithm = FUSED_UPDATE; if (params->verbose) {