@@ -173,14 +173,20 @@ def __init__(self,
173173 if self .enforce_eager is None :
174174 self .enforce_eager = False
175175
176- if (not self .disable_sliding_window
177- and self .hf_text_config .model_type == "gemma2"
178- and self .hf_text_config .sliding_window is not None ):
176+ sliding_window = getattr (self .hf_text_config , "sliding_window" , None )
177+ has_interleaved_attention = (sliding_window is not None ) and (
178+ isinstance (sliding_window , list ) or
179+ (self .hf_text_config .model_type in ["gemma2" ]))
180+
181+ if (not self .disable_sliding_window and has_interleaved_attention ):
182+ sliding_window_len_min = get_min_sliding_window (
183+ self .hf_text_config .sliding_window )
184+
179185 print_warning_once (
180- "Gemma 2 uses sliding window attention for every odd layer , "
186+ f" { self . hf_text_config . model_type } has interleaved attention, "
181187 "which is currently not supported by vLLM. Disabling sliding "
182188 "window and capping the max length to the sliding window size "
183- f"({ self . hf_text_config . sliding_window } )." )
189+ f"({ sliding_window_len_min } )." )
184190 self .disable_sliding_window = True
185191
186192 self .max_model_len = _get_and_verify_max_len (
@@ -431,7 +437,8 @@ def verify_with_parallel_config(
431437 "pipeline parallelism currently. Disabling it." )
432438 self .use_async_output_proc = False
433439
434- def get_hf_config_sliding_window (self ) -> Optional [int ]:
440+ def get_hf_config_sliding_window (
441+ self ) -> Union [Optional [int ], List [Optional [int ]]]:
435442 """Get the sliding window size, or None if disabled."""
436443
437444 # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
@@ -442,8 +449,9 @@ def get_hf_config_sliding_window(self) -> Optional[int]:
442449 return None
443450 return getattr (self .hf_text_config , "sliding_window" , None )
444451
445- def get_sliding_window (self ) -> Optional [int ]:
446- """Get the sliding window size, or None if disabled."""
452+ def get_sliding_window (self ) -> Optional [Union [int , List [Optional [int ]]]]:
453+ """Get the sliding window size, or None if disabled.
454+ """
447455 # If user disables sliding window, return None.
448456 if self .disable_sliding_window :
449457 return None
@@ -1717,7 +1725,7 @@ def _get_and_verify_max_len(
17171725 hf_config : PretrainedConfig ,
17181726 max_model_len : Optional [int ],
17191727 disable_sliding_window : bool ,
1720- sliding_window_len : Optional [int ],
1728+ sliding_window_len : Optional [Union [ int , List [ Optional [ int ]]] ],
17211729 spec_target_max_model_len : Optional [int ] = None ,
17221730) -> int :
17231731 """Get and verify the model's maximum length."""
@@ -1750,10 +1758,12 @@ def _get_and_verify_max_len(
17501758 # If sliding window is manually disabled, max_length should be less
17511759 # than the sliding window length in the model config.
17521760 if disable_sliding_window and sliding_window_len is not None :
1753- max_len_key = ("sliding_window"
1754- if sliding_window_len < derived_max_model_len else
1755- max_len_key )
1756- derived_max_model_len = min (derived_max_model_len , sliding_window_len )
1761+
1762+ sliding_window_len_min = get_min_sliding_window (sliding_window_len )
1763+ max_len_key = "sliding_window" \
1764+ if sliding_window_len_min < derived_max_model_len else max_len_key
1765+ derived_max_model_len = min (derived_max_model_len ,
1766+ sliding_window_len_min )
17571767
17581768 # If none of the keys were found in the config, use a default and
17591769 # log a warning.
@@ -1836,6 +1846,14 @@ def _get_and_verify_max_len(
18361846 return int (max_model_len )
18371847
18381848
def get_min_sliding_window(
        sliding_window: Union[int, List[Optional[int]]]) -> int:
    """Return the smallest sliding window size described by a config value.

    Args:
        sliding_window: Either a single window size, or a sequence of
            window sizes in which individual entries may be ``None``
            (presumably one entry per layer, with ``None`` marking a layer
            that has no sliding window -- confirm against the model config).

    Returns:
        ``sliding_window`` itself when it is not a sequence, otherwise the
        minimum of its non-``None`` entries.

    Raises:
        ValueError: If a sequence is given that contains no non-``None``
            entry, so no minimum window size can be derived.
    """
    # Scalars are passed through untouched; only sequences need reducing.
    # Tuples are reduced like lists so they cannot leak through unreduced
    # and break a later integer comparison.
    if not isinstance(sliding_window, (list, tuple)):
        return sliding_window

    window_sizes = [size for size in sliding_window if size is not None]
    if not window_sizes:
        # A bare min() over an empty sequence raises ValueError too, but with
        # a message that says nothing about the offending config value.
        raise ValueError(
            "Cannot determine the minimum sliding window: `sliding_window` "
            f"contains no window sizes (got {sliding_window!r}).")
    return min(window_sizes)
1855+
1856+
18391857def get_served_model_name (model : str ,
18401858 served_model_name : Optional [Union [str , List [str ]]]):
18411859 """
0 commit comments