diff --git a/internal/internal_types.h b/internal/internal_types.h
index 1f42582..9d4b0ec 100644
--- a/internal/internal_types.h
+++ b/internal/internal_types.h
@@ -31,6 +31,12 @@ typedef struct
     double *val;
 } cu_sparse_matrix_csr_t;
 
+typedef enum
+{
+    CUSPARSE_UPDATE,
+    FUSED_UPDATE,
+} pdhg_update_algorithm_t;
+
 typedef struct
 {
     int num_variables;
@@ -123,6 +129,9 @@ typedef struct
     double *ones_primal_d;
     double *ones_dual_d;
 
+    pdhg_update_algorithm_t primal_update_algorithm;
+    pdhg_update_algorithm_t dual_update_algorithm;
+
     double feasibility_polishing_time;
     int feasibility_iteration;
 } pdhg_solver_state_t;
@@ -136,3 +145,5 @@ typedef struct
     double obj_vec_rescale;
     double rescaling_time_sec;
 } rescale_info_t;
+
+
diff --git a/internal/utils.h b/internal/utils.h
index 4998801..34fc559 100644
--- a/internal/utils.h
+++ b/internal/utils.h
@@ -125,6 +125,10 @@ extern "C"
     int coo_to_csr(const matrix_desc_t *desc, int **row_ptr, int **col_ind,
                    double **vals, int *nnz_out);
 
+    int calculate_max_nnz_row(int m, const int* matA_row_ptr);
+
+    int calculate_max_nnz_col(int n, const int* matAt_row_ptr);
+
     void check_feas_polishing_termination_criteria(
         pdhg_solver_state_t *solver_state,
         const termination_criteria_t *criteria,
diff --git a/src/solver.cu b/src/solver.cu
index 4129501..89b4302 100644
--- a/src/solver.cu
+++ b/src/solver.cu
@@ -60,7 +60,10 @@ __global__ void compute_delta_solution_kernel(
     const double *initial_primal, const double *pdhg_primal,
     double *delta_primal, const double *initial_dual, const double *pdhg_dual,
     double *delta_dual, int n_vars, int n_cons);
+
+static void fused_compute_next_pdhg_primal_solution(pdhg_solver_state_t *state);
 static void compute_next_pdhg_primal_solution(pdhg_solver_state_t *state);
+static void fused_compute_next_pdhg_dual_solution(pdhg_solver_state_t *state);
 static void compute_next_pdhg_dual_solution(pdhg_solver_state_t *state);
 static void halpern_update(pdhg_solver_state_t *state,
                            double reflection_coefficient);
@@ -77,6 +80,8 @@ initialize_solver_state(const lp_problem_t *original_problem,
 static void compute_fixed_point_error(pdhg_solver_state_t *state);
 void pdhg_solver_state_free(pdhg_solver_state_t *state);
 void rescale_info_free(rescale_info_t *info);
+static void decide_fused_update_usage(pdhg_solver_state_t *state,
+                                      const pdhg_parameters_t *params);
 static void perform_primal_restart(pdhg_solver_state_t *state);
 static void perform_dual_restart(pdhg_solver_state_t *state);
 
@@ -101,6 +106,7 @@ cupdlpx_result_t *optimize(const pdhg_parameters_t *params,
 
     pdhg_solver_state_t *state =
         initialize_solver_state(original_problem, rescale_info);
+    decide_fused_update_usage(state, params);
     rescale_info_free(rescale_info);
     initialize_step_size_and_primal_weight(state, params);
     clock_t start_time = clock();
@@ -597,6 +603,11 @@ __global__ void compute_delta_solution_kernel(
 
 static void compute_next_pdhg_primal_solution(pdhg_solver_state_t *state)
 {
+    if (state->primal_update_algorithm == FUSED_UPDATE)
+    {
+        fused_compute_next_pdhg_primal_solution(state);
+        return;
+    }
     CUSPARSE_CHECK(cusparseDnVecSetValues(state->vec_dual_sol,
                                           state->current_dual_solution));
     CUSPARSE_CHECK(
@@ -634,6 +645,11 @@ static void compute_next_pdhg_primal_solution(pdhg_solver_state_t *state)
 
 static void compute_next_pdhg_dual_solution(pdhg_solver_state_t *state)
 {
+    if (state->dual_update_algorithm == FUSED_UPDATE)
+    {
+        fused_compute_next_pdhg_dual_solution(state);
+        return;
+    }
     CUSPARSE_CHECK(cusparseDnVecSetValues(state->vec_primal_sol,
                                           state->reflected_primal_solution));
     CUSPARSE_CHECK(
@@ -978,6 +994,144 @@ void set_default_parameters(pdhg_parameters_t *params)
     params->restart_params.i_smooth = 0.3;
 }
 
+// Fused primal update, light variant. Thread i owns variable i: it sums row i
+// of the CSR matrix passed as matAt against dual_solution, stores the sum in
+// dual_product[i], and writes the reflection of the projected step to
+// reflected_primal[i]. inv_step_size is accepted but not read by this kernel.
+// Launch with a 1D grid that supplies at least n_vars threads.
+__global__ void fused_compute_next_pdhg_primal_solution_kernel(
+    const int * __restrict__ matAt_row_ptr,
+    const int * __restrict__ matAt_col_ind,
+    const double * __restrict__ matAt_val,
+    const double * __restrict__ dual_solution,
+    double * __restrict__ dual_product,
+    const double * __restrict__ current_primal,
+    double * __restrict__ reflected_primal,
+    const double * __restrict__ objective,
+    const double * __restrict__ var_lb,
+    const double * __restrict__ var_ub,
+    double step_size,
+    double inv_step_size,
+    int n_vars)
+{
+    int i = blockIdx.x * blockDim.x + threadIdx.x;
+    if (i < n_vars)
+    {
+        //Compute dual product
+        double sum = 0.0;
+        int row_start = matAt_row_ptr[i];
+        int row_end = matAt_row_ptr[i + 1];
+        for (int j = row_start; j < row_end; ++j)
+        {
+            int col = matAt_col_ind[j];
+            double val = matAt_val[j];
+            sum += val * __ldg(&dual_solution[col]);
+        }
+        double dual_prod = sum;
+
+        //Compute PDHG primal solution
+        double temp = current_primal[i] - step_size * (objective[i] - dual_prod);
+        double temp_proj = fmax(var_lb[i], fmin(temp, var_ub[i]));
+        reflected_primal[i] = 2.0 * temp_proj - current_primal[i];
+        dual_product[i] = dual_prod;
+    }
+    return;
+}
+
+// Fused primal update, full variant. Performs the same computation as
+// fused_compute_next_pdhg_primal_solution_kernel and additionally stores the
+// projected point in pdhg_primal[i] and
+// (projected - unprojected) * inv_step_size in dual_slack[i].
+// Launch with a 1D grid that supplies at least n_vars threads.
+__global__ void fused_compute_next_pdhg_primal_solution_major_kernel(
+    const int * __restrict__ matAt_row_ptr,
+    const int * __restrict__ matAt_col_ind,
+    const double * __restrict__ matAt_val,
+    const double * __restrict__ dual_solution,
+    double * __restrict__ dual_product,
+    const double * __restrict__ current_primal,
+    double * __restrict__ reflected_primal,
+    double * __restrict__ pdhg_primal,
+    double * __restrict__ dual_slack,
+    const double * __restrict__ objective,
+    const double * __restrict__ var_lb,
+    const double * __restrict__ var_ub,
+    double step_size,
+    double inv_step_size,
+    int n_vars)
+{
+    int i = blockIdx.x * blockDim.x + threadIdx.x;
+    if (i < n_vars)
+    {
+        //Compute dual product
+        double sum = 0.0;
+        int row_start = matAt_row_ptr[i];
+        int row_end = matAt_row_ptr[i + 1];
+        for (int j = row_start; j < row_end; ++j)
+        {
+            int col = matAt_col_ind[j];
+            double val = matAt_val[j];
+            sum += val * __ldg(&dual_solution[col]);
+        }
+        double dual_prod = sum;
+
+        //Compute PDHG primal solution
+        double temp = current_primal[i] - step_size * (objective[i] - dual_prod);
+        double temp_proj = fmax(var_lb[i], fmin(temp, var_ub[i]));
+        reflected_primal[i] = 2.0 * temp_proj - current_primal[i];
+
+        dual_slack[i] = (temp_proj - temp) * inv_step_size;
+
+        pdhg_primal[i] = temp_proj;
+        dual_product[i] = dual_prod;
+    }
+    return;
+}
+
+// Fused dual update, full variant. Thread i owns constraint i: it sums row i
+// of the CSR matrix passed as matA against primal_solution, stores the sum in
+// primal_product[i], and writes the new dual value to pdhg_dual[i] and its
+// reflection (2 * new - current) to reflected_dual[i].
+// Launch with a 1D grid that supplies at least n_cons threads.
+__global__ void fused_compute_next_pdhg_dual_solution_major_kernel(
+    const int * __restrict__ matA_row_ptr,
+    const int * __restrict__ matA_col_ind,
+    const double * __restrict__ matA_val,
+    const double * __restrict__ primal_solution,
+    double * __restrict__ primal_product,
+    const double * __restrict__ current_dual,
+    double * __restrict__ reflected_dual,
+    double * __restrict__ pdhg_dual,
+    const double * __restrict__ const_lb,
+    const double * __restrict__ const_ub,
+    double step_size,
+    double inv_step_size,
+    int n_cons)
+{
+    int i = blockIdx.x * blockDim.x + threadIdx.x;
+    if (i < n_cons)
+    {
+        //Compute primal product
+        double sum = 0.0;
+        int row_start = matA_row_ptr[i];
+        int row_end = matA_row_ptr[i + 1];
+        for (int j = row_start; j < row_end; ++j)
+        {
+            int col = matA_col_ind[j];
+            double val = matA_val[j];
+            sum += val * __ldg(&primal_solution[col]);
+        }
+        double primal_prod = sum;
+        //Compute PDHG dual solution
+        double temp = current_dual[i]* inv_step_size - primal_prod;
+        double temp_proj = fmax(-const_ub[i], fmin(temp, -const_lb[i]));
+        pdhg_dual[i] = (temp - temp_proj) * step_size;
+        reflected_dual[i] = 2.0 * pdhg_dual[i] - current_dual[i];
+        primal_product[i] = primal_prod;
+    }
+    return;
+}
+
 //Feasibility Polishing
 void feasibility_polish(const pdhg_parameters_t *params, pdhg_solver_state_t *state)
 {
@@ -1091,6 +1245,47 @@ void primal_feasibility_polish(const pdhg_parameters_t *params, pdhg_solver_stat
     return;
 }
 
+// Fused dual update, light variant. Performs the same computation as
+// fused_compute_next_pdhg_dual_solution_major_kernel but writes only
+// reflected_dual[i] and primal_product[i]; there is no pdhg_dual output.
+// Launch with a 1D grid that supplies at least n_cons threads.
+__global__ void fused_compute_next_pdhg_dual_solution_kernel(
+    const int * __restrict__ matA_row_ptr,
+    const int * __restrict__ matA_col_ind,
+    const double * __restrict__ matA_val,
+    const double * __restrict__ primal_solution,
+    double * __restrict__ primal_product,
+    const double * __restrict__ current_dual,
+    double * __restrict__ reflected_dual,
+    const double * __restrict__ const_lb,
+    const double * __restrict__ const_ub,
+    double step_size,
+    double inv_step_size,
+    int n_cons)
+{
+    int i = blockIdx.x * blockDim.x + threadIdx.x;
+    if (i < n_cons)
+    {
+        //Compute primal product
+        double sum = 0.0;
+        int row_start = matA_row_ptr[i];
+        int row_end = matA_row_ptr[i + 1];
+        for (int j = row_start; j < row_end; ++j)
+        {
+            int col = matA_col_ind[j];
+            double val = matA_val[j];
+            sum += val * __ldg(&primal_solution[col]);
+        }
+        double primal_prod = sum;
+        //Compute PDHG dual solution
+        double temp = current_dual[i] * inv_step_size - primal_prod;
+        double temp_proj = fmax(-const_ub[i], fmin(temp, -const_lb[i]));
+        reflected_dual[i] = 2.0 * (temp - temp_proj) * step_size - current_dual[i];
+        primal_product[i] = primal_prod;
+    }
+    return;
+}
+
 void dual_feasibility_polish(const pdhg_parameters_t *params, pdhg_solver_state_t *state, const pdhg_solver_state_t *ori_state)
 {
     print_initial_feas_polish_info(false, params);
@@ -1137,6 +1332,105 @@ void dual_feasibility_polish(const pdhg_parameters_t *params, pdhg_solver_state_
     return;
 }
 
+// Host-side launcher for the fused primal step. The primal step size is
+// state->step_size / state->primal_weight. The "major" kernel, which also
+// writes pdhg_primal_solution and dual_slack, is used when
+// is_this_major_iteration is set or total_count + 2 is a multiple of
+// get_print_frequency(total_count + 2); otherwise the lighter kernel runs.
+// NOTE(review): the launches are not followed by an error check.
+static void fused_compute_next_pdhg_primal_solution(pdhg_solver_state_t *state)
+{
+    double step_size = state->step_size / state->primal_weight;
+    double inv_step_size = 1.0 / step_size;
+    if (state->is_this_major_iteration || ((state->total_count + 2) % get_print_frequency(state->total_count + 2)) == 0)
+    {
+        fused_compute_next_pdhg_primal_solution_major_kernel<<<state->num_blocks_primal, THREADS_PER_BLOCK>>>(
+            state->constraint_matrix_t->row_ptr, state->constraint_matrix_t->col_ind, state->constraint_matrix_t->val,
+            state->current_dual_solution, state->dual_product,
+            state->current_primal_solution, state->reflected_primal_solution,
+            state->pdhg_primal_solution, state->dual_slack,
+            state->objective_vector, state->variable_lower_bound,
+            state->variable_upper_bound, step_size, inv_step_size,
+            state->num_variables);
+    }
+    else
+    {
+        fused_compute_next_pdhg_primal_solution_kernel<<<state->num_blocks_primal, THREADS_PER_BLOCK>>>(
+            state->constraint_matrix_t->row_ptr, state->constraint_matrix_t->col_ind, state->constraint_matrix_t->val,
+            state->current_dual_solution, state->dual_product,
+            state->current_primal_solution, state->reflected_primal_solution,
+            state->objective_vector, state->variable_lower_bound,
+            state->variable_upper_bound, step_size, inv_step_size,
+            state->num_variables);
+    }
+}
+
+// Host-side launcher for the fused dual step. The dual step size is
+// state->step_size * state->primal_weight. Kernel selection uses the same
+// condition as the primal launcher; the "major" kernel also writes
+// pdhg_dual_solution.
+static void fused_compute_next_pdhg_dual_solution(pdhg_solver_state_t *state)
+{
+    double step_size = state->step_size * state->primal_weight;
+    double inv_step_size = 1.0 / step_size;
+
+    if (state->is_this_major_iteration || ((state->total_count + 2) % get_print_frequency(state->total_count + 2)) == 0)
+    {
+        fused_compute_next_pdhg_dual_solution_major_kernel<<<state->num_blocks_dual, THREADS_PER_BLOCK>>>(
+            state->constraint_matrix->row_ptr, state->constraint_matrix->col_ind, state->constraint_matrix->val,
+            state->reflected_primal_solution, state->primal_product,
+            state->current_dual_solution, state->reflected_dual_solution, state->pdhg_dual_solution,
+            state->constraint_lower_bound, state->constraint_upper_bound,
+            step_size, inv_step_size,
+            state->num_constraints);
+    }
+    else
+    {
+        fused_compute_next_pdhg_dual_solution_kernel<<<state->num_blocks_dual, THREADS_PER_BLOCK>>>(
+            state->constraint_matrix->row_ptr, state->constraint_matrix->col_ind, state->constraint_matrix->val,
+            state->reflected_primal_solution, state->primal_product,
+            state->current_dual_solution, state->reflected_dual_solution,
+            state->constraint_lower_bound, state->constraint_upper_bound,
+            step_size, inv_step_size,
+            state->num_constraints);
+    }
+}
+
+// Picks, independently for the primal and the dual step, between the cuSPARSE
+// path and the fused kernels, and records the choice in the solver state.
+// A fused kernel is chosen when the longest CSR row it would traverse is at
+// most min(100, 1% of n_cons) for the primal step or min(100, 1% of n_vars)
+// for the dual step.
+static void decide_fused_update_usage(pdhg_solver_state_t *state,
+                                      const pdhg_parameters_t *params)
+{
+    int n_cons = state->num_constraints;
+    int n_vars = state->num_variables;
+    int max_nnz_A_row = calculate_max_nnz_row(n_cons, state->constraint_matrix->row_ptr);
+    int max_nnz_At_row = calculate_max_nnz_row(n_vars, state->constraint_matrix_t->row_ptr);
+    int fusion_nnz_threshold = 100;
+    double fusion_density_threshold = 0.01;
+    // Integer minimum of the absolute cap and the density-based cap.
+    int primal_density_cap = (int)(fusion_density_threshold * n_cons);
+    int dual_density_cap = (int)(fusion_density_threshold * n_vars);
+    int primal_threshold = primal_density_cap < fusion_nnz_threshold ? primal_density_cap : fusion_nnz_threshold;
+    int dual_threshold = dual_density_cap < fusion_nnz_threshold ? dual_density_cap : fusion_nnz_threshold;
+    if (max_nnz_A_row > dual_threshold) state->dual_update_algorithm = CUSPARSE_UPDATE;
+    else state->dual_update_algorithm = FUSED_UPDATE;
+    if (max_nnz_At_row > primal_threshold) state->primal_update_algorithm = CUSPARSE_UPDATE;
+    else state->primal_update_algorithm = FUSED_UPDATE;
+    if (params->verbose)
+    {
+        if (state->primal_update_algorithm == FUSED_UPDATE)
+            printf("Using fused primal update kernel.\n");
+        else
+            printf("Using cuSPARSE primal update kernel.\n");
+        if (state->dual_update_algorithm == FUSED_UPDATE)
+            printf("Using fused dual update kernel.\n");
+        else
+            printf("Using cuSPARSE dual update kernel.\n");
+    }
+}
+
 static pdhg_solver_state_t *initialize_primal_feas_polish_state(
     const pdhg_solver_state_t *original_state)
 {
diff --git a/src/utils.cu b/src/utils.cu
index 96ae8d2..62f6bac 100644
--- a/src/utils.cu
+++ b/src/utils.cu
@@ -964,6 +964,29 @@ int coo_to_csr(const matrix_desc_t *desc, int **row_ptr, int **col_ind,
     return 0;
 }
 
+// Returns the largest row length (row_ptr[i + 1] - row_ptr[i]) of a CSR matrix
+// with m rows whose row-pointer array lives in device memory. The m + 1
+// offsets are copied to a temporary host buffer. Returns 0 when m <= 0.
+int calculate_max_nnz_row(int m, const int* matA_row_ptr) {
+    if (m <= 0) return 0;
+    int max_nnz = 0;
+    int* host_ptr = (int*)safe_malloc((size_t)(m + 1) * sizeof(int));
+    CUDA_CHECK(cudaMemcpy(host_ptr, matA_row_ptr, (size_t)(m + 1) * sizeof(int), cudaMemcpyDeviceToHost));
+    for (int i = 0; i < m; ++i) {
+        int row_nnz = host_ptr[i + 1] - host_ptr[i];
+        if (row_nnz > max_nnz) max_nnz = row_nnz;
+    }
+    free(host_ptr);
+    return max_nnz;
+}
+
+// Same computation as calculate_max_nnz_row, applied to the n + 1 device-side
+// row pointers passed as matAt_row_ptr; kept as a separate entry point for
+// callers that use the column-oriented name.
+int calculate_max_nnz_col(int n, const int* matAt_row_ptr) {
+    return calculate_max_nnz_row(n, matAt_row_ptr);
+}
+
 void check_feas_polishing_termination_criteria(
     pdhg_solver_state_t *solver_state,
     const termination_criteria_t *criteria,