File tree Expand file tree Collapse file tree 2 files changed +2
-2
lines changed Expand file tree Collapse file tree 2 files changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -136,7 +136,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
136136 : reinterpret_cast <scan_t *>(params.x_ptr ) + (batch_id * params.dim + dim_id) * (params.n_chunks ) * params.dstate ;
137137 float dD_val = 0 ;
138138 float ddelta_bias_val = 0 ;
139- long *cu_seqlens = reinterpret_cast <long *>(params.cu_seqlens_ptr ) + batch_id * params.u_batch_stride
139+ long *cu_seqlens = reinterpret_cast <long *>(params.cu_seqlens_ptr ) + batch_id * params.u_batch_stride ;
140140
141141 constexpr int kChunkSize = kNThreads * kNItems ;
142142 u += (params.n_chunks - 1 ) * kChunkSize ;
Original file line number Diff line number Diff line change @@ -107,7 +107,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
107107 weight_t *C = reinterpret_cast <weight_t *>(params.C_ptr ) + dim_id * kNRows * params.C_d_stride ;
108108 input_t *Cvar = reinterpret_cast <input_t *>(params.C_ptr ) + batch_id * params.C_batch_stride + group_id * params.C_group_stride ;
109109 scan_t *x = reinterpret_cast <scan_t *>(params.x_ptr ) + (batch_id * params.dim + dim_id * kNRows ) * params.n_chunks * params.dstate ;
110- long *cu_seqlens = reinterpret_cast <long *>(params.cu_seqlens_ptr ) + batch_id * params.u_batch_stride
110+ long *cu_seqlens = reinterpret_cast <long *>(params.cu_seqlens_ptr ) + batch_id * params.u_batch_stride ;
111111
112112 float D_val[kNRows ] = {0 };
113113 if (params.D_ptr != nullptr ) {
You can’t perform that action at this time.
0 commit comments