@@ -418,6 +418,31 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
418418 typename Ktraits::BlockStoreT (smem_store).Store (out, out_vals_store, seqlen - chunk * kChunkSize );
419419 }
420420 out += kChunkSize ;
421+
422+ int final_state_position = ((seqlen - (kWidth - 1 )) - (n_chunks - 1 ) * kChunkSize );
423+ // in case the final state is separated between the last "smem_exchange" and
424+ // and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2),
425+ // (which occurs when `final_state_position` is a non-positivie index)
426+ // we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it
427+ if (final_state_position < 0 && seqlen > kWidth ){
428+ input_t vals_load[kNElts ] = {0 };
429+ if ((chunk == n_chunks - 2 ) && (tidx == kNThreads - 1 )){
430+ // chunk = n_chunks - 2, a segment of the final state sits in the last index
431+ reinterpret_cast <vec_t *>(vals_load)[0 ] = smem_exchange[kNThreads - 1 ];
432+ #pragma unroll
433+ for (int w = 0 ; w < -final_state_position; ++w){
434+ conv_states[w] = vals_load[kNElts + final_state_position + w];
435+ }
436+ }
437+ if ((chunk == n_chunks - 1 ) && tidx == 0 ){
438+ // chunk = n_chunks - 1, the second segment of the final state first positions
439+ reinterpret_cast <vec_t *>(vals_load)[0 ] = smem_exchange[0 ];
440+ for (int w = -final_state_position; w < kWidth - 1 ; ++w){
441+ conv_states[w] = vals_load[w + final_state_position];
442+ }
443+ return ;
444+ }
445+ }
421446 }
422447 // Final state is stored in the smem_exchange last token slot,
423448 // in case seqlen < kWidth, we would need to take the final state from the
@@ -446,9 +471,14 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
446471 }
447472 else {
448473 // in case the final state is in between the threads data
449- reinterpret_cast <vec_t *>(x_vals_load)[1 ] = smem_exchange[last_thread + 1 ];
450- reinterpret_cast <vec_t *>(x_vals_load)[0 ] = smem_exchange[last_thread];
451474 const int offset = ((seqlen - (kWidth - 1 )) % (kNElts ));
475+ if ((offset + kWidth - 2 ) >= kNElts && (last_thread + 1 < kNThreads )){
476+ // In case last_thread == kNThreads - 1, accessing last_thread + 1 will result in a
477+ // illegal access error on H100.
478+ // Therefore, we access last_thread + 1, only if the final state data sits there
479+ reinterpret_cast <vec_t *>(x_vals_load)[1 ] = smem_exchange[last_thread + 1 ];
480+ }
481+ reinterpret_cast <vec_t *>(x_vals_load)[0 ] = smem_exchange[last_thread];
452482 #pragma unroll
453483 for (int w = 0 ; w < kWidth - 1 ; ++w){
454484 conv_states[w] = x_vals_load[offset + w ];
0 commit comments