diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index 410624e27..62bcdfa80 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -262,10 +262,8 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]: ) ssd_tbe_params["cache_sets"] = int(max_cache_sets) - if "kvzch_eviction_tbe_config" in fused_params and config.is_using_virtual_table(): - ssd_tbe_params["kvzch_eviction_tbe_config"] = fused_params.get( - "kvzch_eviction_tbe_config" - ) + if "kvzch_tbe_config" in fused_params and config.is_using_virtual_table(): + ssd_tbe_params["kvzch_tbe_config"] = fused_params.get("kvzch_tbe_config") ssd_tbe_params["table_names"] = [table.name for table in config.embedding_tables] @@ -359,10 +357,10 @@ def _populate_zero_collision_tbe_params( l2_cache_size = tbe_params["l2_cache_size"] assert ( - "kvzch_eviction_tbe_config" in tbe_params - ), "kvzch_eviction_tbe_config should be in tbe_params" - eviction_tbe_config = tbe_params["kvzch_eviction_tbe_config"] - tbe_params.pop("kvzch_eviction_tbe_config") + "kvzch_tbe_config" in tbe_params + ), "kvzch_tbe_config should be in tbe_params" + eviction_tbe_config = tbe_params["kvzch_tbe_config"] + tbe_params.pop("kvzch_tbe_config") eviction_trigger_mode = eviction_tbe_config.kvzch_eviction_trigger_mode eviction_free_mem_threshold_gb = ( eviction_tbe_config.eviction_free_mem_threshold_gb diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index 0fb74665e..9e8c0cc83 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -33,7 +33,7 @@ from fbgemm_gpu.split_table_batched_embeddings_ops_common import ( BoundsCheckMode, CacheAlgorithm, - KVZCHEvictionTBEConfig, + KVZCHTBEConfig, MultiPassPrefetchConfig, ) @@ -668,7 +668,7 @@ class KeyValueParams: lazy_bulk_init_enabled: bool: whether to enable lazy(async) bulk init for SSD TBE enable_raw_embedding_streaming: Optional[bool]: enable raw embedding streaming for SSD TBE res_store_shards: Optional[int] = None: the number of shards to store the raw embeddings - kvzch_eviction_tbe_config: Optional[KVZCHEvictionTBEConfig]: KVZCH eviction config for TBE + kvzch_tbe_config: Optional[KVZCHTBEConfig]: KVZCH config for TBE # Parameter Server (PS) Attributes ps_hosts (Optional[Tuple[Tuple[str, int]]]): List of PS host ip addresses @@ -694,7 +694,7 @@ class KeyValueParams: None # enable raw embedding streaming for SSD TBE ) res_store_shards: Optional[int] = None # shards to store the raw embeddings - kvzch_eviction_tbe_config: Optional[KVZCHEvictionTBEConfig] = None + kvzch_tbe_config: Optional[KVZCHTBEConfig] = None # Parameter Server (PS) Attributes ps_hosts: Optional[Tuple[Tuple[str, int], ...]] = None @@ -723,7 +723,7 @@ def __hash__(self) -> int: self.lazy_bulk_init_enabled, self.enable_raw_embedding_streaming, self.res_store_shards, - self.kvzch_eviction_tbe_config, + self.kvzch_tbe_config, ) )