From 1caff107ba0176e4b46bbcb83093a2caa1731c6c Mon Sep 17 00:00:00 2001 From: Eddy Li Date: Tue, 4 Nov 2025 13:14:51 -0800 Subject: [PATCH] Fix is_using_virtual_table bug in kernel Summary: Ease fix for a misused func calling. Differential Revision: D86233647 --- torchrec/distributed/batched_embedding_kernel.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index 6e5daaaef..4684dc472 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -194,7 +194,7 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]: ) # populate init min and max - if config.is_using_virtual_table: + if config.is_using_virtual_table(): _generate_init_range_for_virtual_tables(ssd_tbe_params, config) if ( @@ -242,7 +242,7 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]: ) ssd_tbe_params["cache_sets"] = int(max_cache_sets) - if "kvzch_eviction_tbe_config" in fused_params and config.is_using_virtual_table: + if "kvzch_eviction_tbe_config" in fused_params and config.is_using_virtual_table(): ssd_tbe_params["kvzch_eviction_tbe_config"] = fused_params.get( "kvzch_eviction_tbe_config" ) @@ -1969,7 +1969,7 @@ def __init__( len({table.embedding_dim for table in config.embedding_tables}) == 1 ), "Currently we expect all tables in SSD TBE to have the same embedding dimension." assert ( - config.is_using_virtual_table + config.is_using_virtual_table() ), "Try to create ZeroCollisionKeyValueEmbedding for non virtual tables" assert embedding_cache_mode == config.enable_embedding_update, ( f"Embedding_cache kernel is {embedding_cache_mode} " @@ -2883,7 +2883,7 @@ def __init__( len({table.embedding_dim for table in config.embedding_tables}) == 1 ), "Currently we expect all tables in SSD TBE to have the same embedding dimension." assert ( - config.is_using_virtual_table + config.is_using_virtual_table() ), "Try to create ZeroCollisionKeyValueEmbeddingBag for non virtual tables" for table in config.embedding_tables: