 from ..attention_backend import get_sparse_attn_kv_cache_manager
 from ..model_config import ModelConfig
 from ..speculative import get_num_extra_kv_tokens, get_spec_decoder
-from .config import PyTorchConfig
 from .config_utils import is_mla, is_nemotron_hybrid, is_qwen3_next
 from .guided_decoder import GuidedDecoder
 from .kv_cache_connector import KvCacheConnectorManager
@@ -73,7 +72,7 @@ def __init__(
         max_seq_len: int,
         max_batch_size: int,
         kv_cache_config: KvCacheConfig,
-        pytorch_backend_config: PyTorchConfig,
+        llm_args: TorchLlmArgs,
         speculative_config: SpeculativeConfig,
         sparse_attention_config: SparseAttentionConfig,
         profiling_stage_data: Optional[dict],
@@ -86,7 +85,7 @@ def __init__(
         self._max_num_tokens = max_num_tokens
         self._max_beam_width = max_beam_width
         self._kv_connector_manager = kv_connector_manager
-        self._pytorch_backend_config = pytorch_backend_config
+        self._llm_args = llm_args
         self._speculative_config = speculative_config
         self._sparse_attention_config = sparse_attention_config
         self._tokens_per_block = tokens_per_block
@@ -248,9 +247,8 @@ def _get_token_num_for_estimation(self) -> int:
         # estimate_max_kv_cache_tokens submits self._dummy_reqs
         num_cache_blocks = 0
         num_extra_tokens_per_seq = 1  # account for generated tokens
-        pytorch_backend_config = self._pytorch_backend_config
         spec_cfg = self._speculative_config
-        if not pytorch_backend_config.disable_overlap_scheduler:
+        if not self._llm_args.disable_overlap_scheduler:
             num_extra_tokens_per_seq = num_extra_tokens_per_seq + 1
         if spec_cfg is not None:
             num_extra_tokens_per_seq += spec_cfg.max_total_draft_tokens
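For reference, the token-budget arithmetic in this hunk reduces to the sketch below. It is illustrative only: the helper name and its standalone form are not part of the patch, which reads the same values from self._llm_args and self._speculative_config.

```python
from typing import Optional


def extra_tokens_per_seq(disable_overlap_scheduler: bool,
                         max_total_draft_tokens: Optional[int]) -> int:
    extra = 1  # always reserve room for the next generated token
    if not disable_overlap_scheduler:
        extra += 1  # overlap scheduler keeps one extra token in flight
    if max_total_draft_tokens is not None:
        extra += max_total_draft_tokens  # speculative draft tokens per step
    return extra


# e.g. overlap scheduling enabled and 3 draft tokens: 1 + 1 + 3 = 5
assert extra_tokens_per_seq(False, 3) == 5
```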
@@ -653,7 +651,7 @@ def create_py_executor_instance(
         dist,
         resources,
         mapping,
-        pytorch_backend_config,
+        llm_args,
         ctx_chunk_config,
         model_engine,
         start_worker,
@@ -680,7 +678,7 @@ def create_py_executor_instance(
680678 f"max_seq_len={ max_seq_len } , max_num_requests={ max_batch_size } , max_num_tokens={ max_num_tokens } , max_batch_size={ max_batch_size } "
681679 )
682680
683- for key , value in pytorch_backend_config .extra_resource_managers .items ():
681+ for key , value in llm_args .extra_resource_managers .items ():
684682 if key in resources :
685683 raise ValueError (
686684 f"Cannot overwrite existing resource manager { key } ." )
@@ -805,8 +803,7 @@ def create_py_executor_instance(
         drafter=drafter,
         dist=dist,
         max_num_sequences=max_num_sequences,
-        disable_overlap_scheduler=pytorch_backend_config.
-        disable_overlap_scheduler,
+        disable_overlap_scheduler=llm_args.disable_overlap_scheduler,
         max_batch_size=max_batch_size,
         max_beam_width=max_beam_width,
         max_draft_len=spec_config.max_draft_len
@@ -842,13 +839,11 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
     )


-def instantiate_sampler(engine: PyTorchModelEngine,
-                        pytorch_backend_config: PyTorchConfig, mapping: Mapping,
-                        max_batch_size: int, max_beam_width: int,
-                        max_seq_len: int, mm_encoder_only: bool,
-                        speculative_config: SpeculativeConfig,
-                        decoding_config: trtllm.DecodingConfig,
-                        kv_cache_config: KvCacheConfig):
+def instantiate_sampler(
+        engine: PyTorchModelEngine, llm_args: TorchLlmArgs, mapping: Mapping,
+        max_batch_size: int, max_beam_width: int, max_seq_len: int,
+        mm_encoder_only: bool, speculative_config: SpeculativeConfig,
+        decoding_config: trtllm.DecodingConfig, kv_cache_config: KvCacheConfig):
     sampler_args = create_torch_sampler_args(
         mapping,
         max_seq_len=engine.max_seq_len,
@@ -858,7 +853,7 @@ def instantiate_sampler(engine: PyTorchModelEngine,
     decoding_mode = get_decoding_mode(decoding_config=decoding_config,
                                       max_beam_width=max_beam_width)
     if mapping.cp_config.get('cp_type') == CpType.STAR:
-        assert pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION", "attention backend of star attention should be 'FLASHINFER_STAR_ATTENTION'"
+        assert llm_args.attn_backend == "FLASHINFER_STAR_ATTENTION", "attention backend of star attention should be 'FLASHINFER_STAR_ATTENTION'"
         return TorchSampler(sampler_args)
     if engine.spec_config is not None and engine.spec_config.spec_dec_mode.has_spec_decoder(
     ):
@@ -867,15 +862,15 @@ def instantiate_sampler(engine: PyTorchModelEngine,
     if mm_encoder_only:
         # NOTE: handle model outputs specially for mm encoder executor/engine
         return EarlyStopWithMMResult()
-    if pytorch_backend_config.sampler_type == SamplerType.TRTLLMSampler or (
-            pytorch_backend_config.sampler_type == SamplerType.auto
+    if llm_args.sampler_type == SamplerType.TRTLLMSampler or (
+            llm_args.sampler_type == SamplerType.auto
             and decoding_mode.isBeamSearch()):
         logger.debug(f"DecodingMode: {decoding_mode.name}")
         return TRTLLMSampler(engine.model,
                              engine.dtype,
                              mapping,
                              decoding_mode,
-                             pytorch_backend_config.disable_overlap_scheduler,
+                             llm_args.disable_overlap_scheduler,
                              max_seq_len=max_seq_len,
                              max_batch_size=max_batch_size,
                              max_beam_width=max_beam_width,
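Taken together, the instantiate_sampler hunks above encode a fixed selection order. The sketch below restates that order as standalone pseudologic under stated assumptions: sampler_type is modeled as a plain string rather than the SamplerType enum, and the final fallback lies outside the hunks shown here.

```python
def pick_sampler_kind(sampler_type: str, is_beam_search: bool,
                      is_star_attention: bool, has_spec_decoder: bool,
                      mm_encoder_only: bool) -> str:
    if is_star_attention:
        return "TorchSampler"           # STAR attention always uses TorchSampler
    if has_spec_decoder:
        return "spec decoder"           # built via get_spec_decoder(...)
    if mm_encoder_only:
        return "EarlyStopWithMMResult"  # mm encoder-only engines skip sampling
    if sampler_type == "TRTLLMSampler" or (sampler_type == "auto"
                                           and is_beam_search):
        return "TRTLLMSampler"
    return "default"                    # remaining path is outside this diff


print(pick_sampler_kind("auto", True, False, False, False))  # TRTLLMSampler
```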
@@ -937,7 +932,12 @@ def _try_infer_num_experts(model_config: ModelConfig) -> int:
     return num_experts


-def _adjust_torch_mem_fraction(pytorch_backend_config: PyTorchConfig):
+def _adjust_torch_mem_fraction():
+    # If true, adjust PyTorch CUDA memory fraction to correspond to the
+    # total GPU memory minus the statically allocated engine memory.
+    # If false, set the PyTorch CUDA memory fraction to 1.0.
+    _limit_torch_cuda_mem_fraction: bool = True
+
     # FIXME: PyTorch only uses the garbage_collection_threshold setting
     # if a memory fraction is set, cf.
     # https://github.com/pytorch/pytorch/blob/cd995bfb2aac8891465809be3ce29543bd524287/c10/cuda/CUDACachingAllocator.cpp#L1357
@@ -966,7 +966,7 @@ def _adjust_torch_mem_fraction(pytorch_backend_config: PyTorchConfig):
     # lead PyTorch to release all unused memory before hitting the set fraction. This
     # still mitigates OOM, although at a higher performance impact, because it
     # effectively resets the allocator cache.
-    if not pytorch_backend_config._limit_torch_cuda_mem_fraction:
+    if not _limit_torch_cuda_mem_fraction:
         return
     mem_reserved = torch.cuda.memory_reserved()
     mem_free, mem_total = torch.cuda.mem_get_info()
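The tail of _adjust_torch_mem_fraction is not part of this diff. Assuming the behavior described in the comment introduced above (cap PyTorch at total GPU memory minus the statically allocated engine memory), the adjustment could look roughly like the sketch below; the helper name and the exact formula are assumptions, only the torch.cuda calls mirror the visible hunk and standard PyTorch APIs.

```python
import torch


def _set_torch_mem_fraction_sketch() -> None:
    # Assumed illustration, not the function's actual tail.
    if not torch.cuda.is_available():
        return
    mem_reserved = torch.cuda.memory_reserved()      # bytes cached by PyTorch's allocator
    mem_free, mem_total = torch.cuda.mem_get_info()  # driver-level free/total bytes
    # Treat whatever is neither free nor owned by PyTorch as statically
    # allocated engine memory (weights, KV cache, workspaces, ...).
    mem_engine = mem_total - mem_free - mem_reserved
    fraction = max(0.0, min(1.0, (mem_total - mem_engine) / mem_total))
    torch.cuda.set_per_process_memory_fraction(fraction)
```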