Skip to content

Commit 906379d

Browse files
committed
remove diffs and use pad_slot_id as var in tests
1 parent 69ebab8 commit 906379d

File tree

3 files changed

+9
-6
lines changed

3 files changed

+9
-6
lines changed

csrc/mamba/causal_conv1d/causal_conv1d.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ void set_conv_params_fwd(ConvParamsBase &params,
5858
int64_t pad_slot_id,
5959
const c10::optional<at::Tensor>& query_start_loc = std::nullopt,
6060
const c10::optional<at::Tensor>& cache_indices = std::nullopt,
61-
const c10::optional<at::Tensor>& has_initial_state = std::nullopt
62-
) {
61+
const c10::optional<at::Tensor>& has_initial_state = std::nullopt) {
62+
6363
// Reset the parameters
6464
memset(&params, 0, sizeof(params));
6565

tests/kernels/test_causal_conv1d.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,8 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
307307
weight,
308308
bias,
309309
activation=activation,
310-
conv_state_indices=padded_state_indices)
310+
conv_state_indices=padded_state_indices,
311+
pad_slot_id=PAD_SLOT_ID)
311312
out_ref = causal_conv1d_update_ref(x_ref[:batch_size],
312313
conv_state_ref,
313314
weight,
@@ -397,7 +398,7 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
397398

398399
out = causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
399400
padded_state_indices, has_initial_states,
400-
final_states, activation)
401+
final_states, activation, PAD_SLOT_ID)
401402
out_ref = []
402403
out_ref_b = []
403404

tests/kernels/test_mamba_ssm.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,8 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
594594
z=z,
595595
dt_bias=dt_bias,
596596
dt_softplus=True,
597-
state_batch_indices=padded_state_indices)
597+
state_batch_indices=padded_state_indices,
598+
pad_slot_id=PAD_SLOT_ID)
598599
out_ref = selective_state_update_ref(state_ref,
599600
x[:batch_size],
600601
dt[:batch_size],
@@ -694,7 +695,8 @@ def test_selective_state_update_with_heads_with_batch_indices(
694695
z=z,
695696
dt_bias=dt_bias,
696697
dt_softplus=True,
697-
state_batch_indices=state_indices)
698+
state_batch_indices=state_indices,
699+
pad_slot_id=PAD_SLOT_ID)
698700
out_ref = selective_state_update_ref(state_ref,
699701
x,
700702
dt,

0 commit comments

Comments
 (0)