@@ -327,20 +327,45 @@ def __init__(
327327 else :
328328 shard_offsets_for_kv_zch = None
329329
330- self ._emb_module : IntNBitTableBatchedEmbeddingBagsCodegen = tbe_clazz (
331- embedding_specs = embedding_specs ,
332- device = device ,
333- pooling_mode = self ._pooling ,
334- feature_table_map = self ._feature_table_map ,
335- row_alignment = self ._tbe_row_alignment ,
336- uvm_host_mapped = True , # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
337- bounds_check_mode = (
330+ # Determine embedding cache mode for KV embedding tables
331+ embedding_cache_mode = False # Default: False = randomized initialization
332+ if tbe_clazz == KVEmbeddingInference :
333+ # For KV embedding tables, set cache mode based on embedding table configuration
334+ # Check if any table has NoEvictionPolicy - use zero init for those
335+ for table in config .embedding_tables :
336+ if (
337+ table .virtual_table_eviction_policy is not None
338+ and type (table .virtual_table_eviction_policy ).__name__
339+ == "NoEvictionPolicy"
340+ ):
341+ embedding_cache_mode = True # True = zero initialization
342+ break
343+
344+ # Build kwargs for module construction
345+ module_kwargs : Dict [str , Any ] = {
346+ "embedding_specs" : embedding_specs ,
347+ "device" : device ,
348+ "pooling_mode" : self ._pooling ,
349+ "feature_table_map" : self ._feature_table_map ,
350+ "row_alignment" : self ._tbe_row_alignment ,
351+ "uvm_host_mapped" : True , # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
352+ "bounds_check_mode" : (
338353 bounds_check_mode if bounds_check_mode else BoundsCheckMode .WARNING
339354 ),
340- feature_names_per_table = [
355+ " feature_names_per_table" : [
341356 table .feature_names for table in config .embedding_tables
342357 ],
343- ** (tbe_fused_params (fused_params ) or {}),
358+ }
359+
360+ # Add KV-specific parameters
361+ if tbe_clazz == KVEmbeddingInference :
362+ module_kwargs ["embedding_cache_mode" ] = embedding_cache_mode
363+
364+ # Add fused params
365+ module_kwargs .update (** (tbe_fused_params (fused_params ) or {}))
366+
367+ self ._emb_module : IntNBitTableBatchedEmbeddingBagsCodegen = tbe_clazz (
368+ ** module_kwargs
344369 )
345370 if device is not None :
346371 self ._emb_module .initialize_weights ()
@@ -495,6 +520,7 @@ def __init__(
495520
496521 managed : List [EmbeddingLocation ] = []
497522 is_virtual_table = False
523+ embedding_cache_mode = False
498524 for table in config .embedding_tables :
499525 if device is not None and device .type == "cuda" :
500526 managed .append (
@@ -504,6 +530,8 @@ def __init__(
504530 managed .append (EmbeddingLocation .HOST )
505531 if table .use_virtual_table :
506532 is_virtual_table = True
533+ if table .enable_embedding_update :
534+ embedding_cache_mode = True
507535 self ._config : GroupedEmbeddingConfig = config
508536 self ._emb_module_registered : bool = is_fused_param_register_tbe (fused_params )
509537 self ._quant_state_dict_split_scale_bias : bool = (
@@ -529,8 +557,9 @@ def __init__(
529557 else :
530558 shard_offsets_for_kv_zch = None
531559
532- self ._emb_module : IntNBitTableBatchedEmbeddingBagsCodegen = embedding_clazz (
533- embedding_specs = [
560+ # Build kwargs for module construction
561+ module_kwargs : Dict [str , Any ] = {
562+ "embedding_specs" : [
534563 (
535564 table .name ,
536565 local_rows ,
@@ -549,15 +578,25 @@ def __init__(
549578 managed ,
550579 )
551580 ],
552- device = device ,
553- pooling_mode = PoolingMode .NONE ,
554- feature_table_map = self ._feature_table_map ,
555- row_alignment = self ._tbe_row_alignment ,
556- uvm_host_mapped = True , # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
557- feature_names_per_table = [
581+ " device" : device ,
582+ " pooling_mode" : PoolingMode .NONE ,
583+ " feature_table_map" : self ._feature_table_map ,
584+ " row_alignment" : self ._tbe_row_alignment ,
585+ " uvm_host_mapped" : True , # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
586+ " feature_names_per_table" : [
558587 table .feature_names for table in config .embedding_tables
559588 ],
560- ** (tbe_fused_params (fused_params ) or {}),
589+ }
590+
591+ # Add KV-specific parameters
592+ if embedding_clazz == KVEmbeddingInference :
593+ module_kwargs ["embedding_cache_mode" ] = embedding_cache_mode
594+
595+ # Add fused params
596+ module_kwargs .update (** (tbe_fused_params (fused_params ) or {}))
597+
598+ self ._emb_module : IntNBitTableBatchedEmbeddingBagsCodegen = embedding_clazz (
599+ ** module_kwargs
561600 )
562601 if device is not None :
563602 self ._emb_module .initialize_weights ()
0 commit comments