@@ -337,13 +337,19 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
337337 "target_attn_1" : mock .MagicMock (),
338338 "target_attn_2" : mock .MagicMock ()
339339 }
340+ target_indx_layers : dict [str , mock .MagicMock ] = {}
340341 # Draft model has one extra attention layer compared to target model
341342 all_attn_layers = {
342343 ** target_attn_layers , "draft_extra_attn" : mock .MagicMock ()
343344 }
344345
346+ all_indx_layers : dict [str , mock .MagicMock ] = {}
347+
345348 # Make mock_get_layers return different values for each call
346- mock_get_layers .side_effect = [target_attn_layers , all_attn_layers ]
349+ mock_get_layers .side_effect = [
350+ target_attn_layers , target_indx_layers , all_attn_layers ,
351+ all_indx_layers
352+ ]
347353
348354 # Setup mock for pp group to return the appropriate value for world size
349355 mock_pp_group = mock .MagicMock ()
@@ -658,6 +664,9 @@ def create_deterministic_logits(token_ids, k: int):
658664 # Mock runner for attention metadata building.
659665 proposer .runner = mock .MagicMock ()
660666 proposer .runner .attn_groups .append ([mock .MagicMock ()])
667+ proposer .runner .attn_groups [0 ][0 ].metadata_builders = [
668+ attn_metadata_builder
669+ ]
661670 proposer .runner .attn_groups [0 ][0 ].get_metadata_builder .return_value = \
662671 attn_metadata_builder
663672 proposer ._get_attention_metadata_builder = mock .MagicMock (
0 commit comments