@@ -454,43 +454,6 @@ def test_forward_decode_only_swa_seq_len_mismatch(
 
         assert output.shape == (10, 8 * 64)
 
-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    @patch('vllm_ascend.utils.get_ascend_device_type',
-           return_value=AscendDeviceType._910_93)
-    @patch('torch_npu._npu_reshape_and_cache')
-    @patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')
-    def test_forward_head_size_192(self, mock_vanilla_prefill,
-                                   mock_npu_reshape_and_cache,
-                                   mock_soc_version, mock_get_forward_context):
-        """Test forward pass when head_size is 192"""
-
-        self.impl.head_size = 192
-        query = torch.randn(10, 8 * 192)
-        key = torch.randn(10, 8 * 192)
-        value = torch.randn(10, 8 * 192)
-        kv_cache = torch.empty(2, 5, 128, 8, 192)
-        output = torch.empty_like(query)
-
-        mock_get_forward_context.return_value = MagicMock(capturing=False)
-
-        metadata = self.attn_metadata
-        metadata.attn_mask = torch.randn(1, 1, 10, 10)
-        metadata.query_lens = torch.tensor([10])
-        metadata.seq_lens = torch.tensor([10])
-        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
-        metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
-        metadata.num_decodes = 10
-        metadata.num_prefills = 0
-        layer = self.layer_no_quant
-        mock_vanilla_prefill.return_value = MagicMock()
-
-        output = self.impl_192.forward(layer, query, key, value, kv_cache,
-                                       metadata, output)
-
-        mock_vanilla_prefill.assert_called_once()
-        assert output.shape == (10, 8 * 192)
-
     @patch('vllm_ascend.attention.attention_v1.get_forward_context')
     @patch('torch_npu.npu_fused_infer_attention_score')
     @patch('torch_npu._npu_reshape_and_cache')