Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions torchrec/distributed/batched_embedding_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
)

# populate init min and max
if config.is_using_virtual_table:
if config.is_using_virtual_table():
_generate_init_range_for_virtual_tables(ssd_tbe_params, config)

if (
Expand Down Expand Up @@ -242,7 +242,7 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
)
ssd_tbe_params["cache_sets"] = int(max_cache_sets)

if "kvzch_eviction_tbe_config" in fused_params and config.is_using_virtual_table:
if "kvzch_eviction_tbe_config" in fused_params and config.is_using_virtual_table():
ssd_tbe_params["kvzch_eviction_tbe_config"] = fused_params.get(
"kvzch_eviction_tbe_config"
)
Expand Down Expand Up @@ -1969,7 +1969,7 @@ def __init__(
len({table.embedding_dim for table in config.embedding_tables}) == 1
), "Currently we expect all tables in SSD TBE to have the same embedding dimension."
assert (
config.is_using_virtual_table
config.is_using_virtual_table()
), "Try to create ZeroCollisionKeyValueEmbedding for non virtual tables"
assert embedding_cache_mode == config.enable_embedding_update, (
f"Embedding_cache kernel is {embedding_cache_mode} "
Expand Down Expand Up @@ -2883,7 +2883,7 @@ def __init__(
len({table.embedding_dim for table in config.embedding_tables}) == 1
), "Currently we expect all tables in SSD TBE to have the same embedding dimension."
assert (
config.is_using_virtual_table
config.is_using_virtual_table()
), "Try to create ZeroCollisionKeyValueEmbeddingBag for non virtual tables"

for table in config.embedding_tables:
Expand Down
Loading