
Commit af25076

emlin authored and meta-codesync[bot] committed
disable random init in inference operator for embedding cache (#3466)
Summary: X-link: pytorch/FBGEMM#5026 Pull Request resolved: #3466 X-link: https://github.com/facebookresearch/FBGEMM/pull/2040

In embedding cache mode we do not expect random values when a lookup misses the cache. This diff passes the embedding cache mode through to the inference operator and uses it to disable the backend's random initialization.

Differential Revision: D84367061

fbshipit-source-id: 83687bcb7c097f60b583c00bf80956efcdcd3a9d
1 parent dbabd25 commit af25076
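As context for the diff below, here is a minimal, self-contained sketch (not the real torchrec/FBGEMM code) of the construction pattern this commit adopts: the TBE arguments are collected in a kwargs dict, and embedding_cache_mode is added only when the target class is the KV inference operator. TBEStub, KVEmbeddingInferenceStub, and build_emb_module are hypothetical names used purely for illustration.

from typing import Any, Dict, Optional, Type

class TBEStub:
    """Stand-in for IntNBitTableBatchedEmbeddingBagsCodegen (illustration only)."""
    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs

class KVEmbeddingInferenceStub(TBEStub):
    """Stand-in for the KV inference operator, which accepts embedding_cache_mode."""
    def __init__(self, embedding_cache_mode: bool = False, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # True => zero-init on cache miss; False => keep randomized init.
        self.embedding_cache_mode = embedding_cache_mode

def build_emb_module(
    tbe_clazz: Type[TBEStub],
    embedding_cache_mode: bool,
    fused_params: Optional[Dict[str, Any]] = None,
) -> TBEStub:
    module_kwargs: Dict[str, Any] = {"embedding_specs": [], "device": None}
    # Only the KV inference operator understands embedding_cache_mode.
    if tbe_clazz is KVEmbeddingInferenceStub:
        module_kwargs["embedding_cache_mode"] = embedding_cache_mode
    module_kwargs.update(fused_params or {})
    return tbe_clazz(**module_kwargs)

emb = build_emb_module(KVEmbeddingInferenceStub, embedding_cache_mode=True)
assert emb.embedding_cache_mode is True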

File tree

2 files changed (+92, -34 lines)

torchrec/distributed/quant_embedding_kernel.py

Lines changed: 58 additions & 19 deletions
@@ -327,20 +327,45 @@ def __init__(
         else:
             shard_offsets_for_kv_zch = None
 
-        self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = tbe_clazz(
-            embedding_specs=embedding_specs,
-            device=device,
-            pooling_mode=self._pooling,
-            feature_table_map=self._feature_table_map,
-            row_alignment=self._tbe_row_alignment,
-            uvm_host_mapped=True,  # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
-            bounds_check_mode=(
+        # Determine embedding cache mode for KV embedding tables
+        embedding_cache_mode = False  # Default: False = randomized initialization
+        if tbe_clazz == KVEmbeddingInference:
+            # For KV embedding tables, set cache mode based on embedding table configuration
+            # Check if any table has NoEvictionPolicy - use zero init for those
+            for table in config.embedding_tables:
+                if (
+                    table.virtual_table_eviction_policy is not None
+                    and type(table.virtual_table_eviction_policy).__name__
+                    == "NoEvictionPolicy"
+                ):
+                    embedding_cache_mode = True  # True = zero initialization
+                    break
+
+        # Build kwargs for module construction
+        module_kwargs: Dict[str, Any] = {
+            "embedding_specs": embedding_specs,
+            "device": device,
+            "pooling_mode": self._pooling,
+            "feature_table_map": self._feature_table_map,
+            "row_alignment": self._tbe_row_alignment,
+            "uvm_host_mapped": True,  # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
+            "bounds_check_mode": (
                 bounds_check_mode if bounds_check_mode else BoundsCheckMode.WARNING
             ),
-            feature_names_per_table=[
+            "feature_names_per_table": [
                 table.feature_names for table in config.embedding_tables
             ],
-            **(tbe_fused_params(fused_params) or {}),
+        }
+
+        # Add KV-specific parameters
+        if tbe_clazz == KVEmbeddingInference:
+            module_kwargs["embedding_cache_mode"] = embedding_cache_mode
+
+        # Add fused params
+        module_kwargs.update(**(tbe_fused_params(fused_params) or {}))
+
+        self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = tbe_clazz(
+            **module_kwargs
         )
         if device is not None:
             self._emb_module.initialize_weights()
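To make the cache-mode detection in the hunk above easier to follow in isolation, here is a hedged, standalone sketch of the same check. TableConfig and NoEvictionPolicy are stand-ins, not the real torchrec config types.

from dataclasses import dataclass
from typing import List, Optional

class NoEvictionPolicy:
    """Stand-in for torchrec's NoEvictionPolicy (illustration only)."""

@dataclass
class TableConfig:
    name: str
    virtual_table_eviction_policy: Optional[object] = None

def detect_embedding_cache_mode(tables: List[TableConfig]) -> bool:
    # Mirrors the diff: zero-init (cache mode) if any table is configured
    # with a NoEvictionPolicy; otherwise keep randomized initialization.
    for table in tables:
        if (
            table.virtual_table_eviction_policy is not None
            and type(table.virtual_table_eviction_policy).__name__ == "NoEvictionPolicy"
        ):
            return True
    return False

print(detect_embedding_cache_mode([TableConfig("t0")]))                      # False
print(detect_embedding_cache_mode([TableConfig("t1", NoEvictionPolicy())]))  # True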
@@ -495,6 +520,7 @@ def __init__(
 
         managed: List[EmbeddingLocation] = []
         is_virtual_table = False
+        embedding_cache_mode = False
         for table in config.embedding_tables:
             if device is not None and device.type == "cuda":
                 managed.append(
@@ -504,6 +530,8 @@ def __init__(
                 managed.append(EmbeddingLocation.HOST)
             if table.use_virtual_table:
                 is_virtual_table = True
+            if table.enable_embedding_update:
+                embedding_cache_mode = True
         self._config: GroupedEmbeddingConfig = config
         self._emb_module_registered: bool = is_fused_param_register_tbe(fused_params)
         self._quant_state_dict_split_scale_bias: bool = (
@@ -529,8 +557,9 @@ def __init__(
         else:
             shard_offsets_for_kv_zch = None
 
-        self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = embedding_clazz(
-            embedding_specs=[
+        # Build kwargs for module construction
+        module_kwargs: Dict[str, Any] = {
+            "embedding_specs": [
                 (
                     table.name,
                     local_rows,
@@ -549,15 +578,25 @@ def __init__(
                     managed,
                 )
             ],
-            device=device,
-            pooling_mode=PoolingMode.NONE,
-            feature_table_map=self._feature_table_map,
-            row_alignment=self._tbe_row_alignment,
-            uvm_host_mapped=True,  # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
-            feature_names_per_table=[
+            "device": device,
+            "pooling_mode": PoolingMode.NONE,
+            "feature_table_map": self._feature_table_map,
+            "row_alignment": self._tbe_row_alignment,
+            "uvm_host_mapped": True,  # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
+            "feature_names_per_table": [
                 table.feature_names for table in config.embedding_tables
             ],
-            **(tbe_fused_params(fused_params) or {}),
+        }
+
+        # Add KV-specific parameters
+        if embedding_clazz == KVEmbeddingInference:
+            module_kwargs["embedding_cache_mode"] = embedding_cache_mode
+
+        # Add fused params
+        module_kwargs.update(**(tbe_fused_params(fused_params) or {}))
+
+        self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = embedding_clazz(
+            **module_kwargs
         )
         if device is not None:
             self._emb_module.initialize_weights()
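The embedding-bag path above derives the flag differently: any table whose enable_embedding_update attribute is set puts the whole group into cache mode. A small sketch of that aggregation, again with a hypothetical TableConfig:

from dataclasses import dataclass
from typing import List

@dataclass
class TableConfig:
    name: str
    use_virtual_table: bool = False
    enable_embedding_update: bool = False

def cache_mode_from_tables(tables: List[TableConfig]) -> bool:
    # One updatable table is enough to disable random init for the group.
    return any(t.enable_embedding_update for t in tables)

tables = [
    TableConfig("plain"),
    TableConfig("kv", use_virtual_table=True, enable_embedding_update=True),
]
assert cache_mode_from_tables(tables) is True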

torchrec/quant/embedding_modules.py

Lines changed: 34 additions & 15 deletions
@@ -764,9 +764,9 @@ def __init__( # noqa C901
         self._output_dtype = output_dtype
         self._device = device
         self.row_alignment = row_alignment
-        self._key_to_tables: Dict[Tuple[DataType, bool], List[EmbeddingConfig]] = (
-            defaultdict(list)
-        )
+        self._key_to_tables: Dict[
+            Tuple[DataType, bool, bool], List[EmbeddingConfig]
+        ] = defaultdict(list)
         self._feature_names: List[str] = []
         self._features_order: Optional[List[int]] = None
 
@@ -789,12 +789,24 @@ def __init__( # noqa C901
                     + f" {self._embedding_dim}"
                 )
             if hasattr(table, "use_virtual_table"):
-                key = (table.data_type, table.use_virtual_table)
+                key = (table.data_type, table.use_virtual_table, False)
+                if hasattr(table, "use_virtual_table") and hasattr(
+                    table, "enable_embedding_update"
+                ):
+                    key = (
+                        table.data_type,
+                        table.use_virtual_table,
+                        table.enable_embedding_update,
+                    )
             else:
-                key = (table.data_type, False)
+                key = (table.data_type, False, False)
             self._key_to_tables[key].append(table)
         self._feature_splits: List[int] = []
-        for (data_type, use_virtual_table), emb_configs in self._key_to_tables.items():
+        for (
+            data_type,
+            use_virtual_table,
+            enable_embedding_update,
+        ), emb_configs in self._key_to_tables.items():
             embedding_specs = []
             weight_lists: Optional[
                 List[Tuple[torch.Tensor, Optional[torch.Tensor]]]
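The grouping key now carries three elements (data_type, use_virtual_table, enable_embedding_update), so tables that need cache mode end up in their own TBE group. A self-contained sketch of that grouping with defaultdict; TableConfig and the string data types are stand-ins for the real enums:

from collections import defaultdict
from dataclasses import dataclass
from typing import DefaultDict, List, Tuple

@dataclass
class TableConfig:
    name: str
    data_type: str  # the real code uses torchrec's DataType enum
    use_virtual_table: bool = False
    enable_embedding_update: bool = False

key_to_tables: DefaultDict[Tuple[str, bool, bool], List[TableConfig]] = defaultdict(list)
for table in [
    TableConfig("t0", "int8"),
    TableConfig("t1", "int8", use_virtual_table=True),
    TableConfig("t2", "int8", use_virtual_table=True, enable_embedding_update=True),
]:
    key = (
        table.data_type,
        getattr(table, "use_virtual_table", False),
        getattr(table, "enable_embedding_update", False),
    )
    key_to_tables[key].append(table)

# "t2" lands in its own group, so only that group's TBE would be built with
# embedding_cache_mode=True.
for key, tables in key_to_tables.items():
    print(key, [t.name for t in tables])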
@@ -825,15 +837,20 @@ def __init__( # noqa C901
                 if use_virtual_table
                 else IntNBitTableBatchedEmbeddingBagsCodegen
             )
-            emb_module = embedding_clazz(
-                embedding_specs=embedding_specs,
-                pooling_mode=PoolingMode.NONE,
-                weight_lists=weight_lists,
-                device=device,
-                output_dtype=data_type_to_sparse_type(dtype_to_data_type(output_dtype)),
-                row_alignment=row_alignment,
-                feature_table_map=feature_table_map,
-            )
+            kwargs: Dict[str, Any] = {
+                "embedding_specs": embedding_specs,
+                "pooling_mode": PoolingMode.NONE,
+                "weight_lists": weight_lists,
+                "device": device,
+                "output_dtype": data_type_to_sparse_type(
+                    dtype_to_data_type(output_dtype)
+                ),
+                "row_alignment": row_alignment,
+                "feature_table_map": feature_table_map,
+            }
+            if embedding_clazz == KVEmbeddingInference:
+                kwargs["embedding_cache_mode"] = enable_embedding_update
+            emb_module = embedding_clazz(**kwargs)
             if weight_lists is None:
                 emb_module.initialize_weights()
             self._emb_modules.append(emb_module)
@@ -869,6 +886,7 @@ def __init__( # noqa C901
                     "weight_qbias", qbias
                 )
 
+        # pyre-ignore [8]
         self._embedding_names_by_batched_tables: Dict[
             Tuple[DataType, bool], List[str]
         ] = {
@@ -934,6 +952,7 @@ def forward(
             f = kjts_per_key[i]
             lengths = _get_feature_length(f)
             indices, offsets = _fx_trec_unwrap_kjt(f)
+            # pyre-ignore [6]
            embedding_names = self._embedding_names_by_batched_tables[key]
            lookup = (
                emb_module(indices=indices, offsets=offsets)
