diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index 754b9e6fa..2d43db376 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -55,6 +55,7 @@
     FUSED_PARAM_IS_SSD_TABLE,
     FUSED_PARAM_SSD_TABLE_LIST,
 )
+from torchrec.distributed.logger import _torchrec_method_logger
 from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding
 from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
 from torchrec.distributed.sharding.dynamic_sharding import (
@@ -466,6 +467,7 @@ class ShardedEmbeddingBagCollection(
     This is part of the public API to allow for manual data dist pipelining.
     """
 
+    @_torchrec_method_logger()
     def __init__(
         self,
         module: EmbeddingBagCollectionInterface,
@@ -2021,6 +2023,7 @@ class ShardedEmbeddingBag(
     This is part of the public API to allow for manual data dist pipelining.
     """
 
+    @_torchrec_method_logger()
     def __init__(
         self,
         module: nn.EmbeddingBag,
diff --git a/torchrec/distributed/planner/planners.py b/torchrec/distributed/planner/planners.py
index 84a21bd12..214015844 100644
--- a/torchrec/distributed/planner/planners.py
+++ b/torchrec/distributed/planner/planners.py
@@ -18,6 +18,7 @@
 from torch import nn
 from torchrec.distributed.collective_utils import invoke_on_rank_and_broadcast_result
 from torchrec.distributed.comm import get_local_size
+from torchrec.distributed.logger import _torchrec_method_logger
 from torchrec.distributed.planner.constants import BATCH_SIZE, MAX_SIZE
 from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
 from torchrec.distributed.planner.partitioners import (
@@ -498,6 +499,7 @@ def collective_plan(
             sharders,
         )
 
+    @_torchrec_method_logger()
     def plan(
         self,
         module: nn.Module,
diff --git a/torchrec/distributed/planner/shard_estimators.py b/torchrec/distributed/planner/shard_estimators.py
index e9fe723e8..3650324fc 100644
--- a/torchrec/distributed/planner/shard_estimators.py
+++ b/torchrec/distributed/planner/shard_estimators.py
@@ -16,6 +16,7 @@
 import torchrec.optim as trec_optim
 from torch import nn
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
+from torchrec.distributed.logger import _torchrec_method_logger
 from torchrec.distributed.planner.constants import (
     BATCHED_COPY_PERF_FACTOR,
     BIGINT_DTYPE,
@@ -955,6 +956,7 @@ class EmbeddingStorageEstimator(ShardEstimator):
         is_inference (bool): If the model is inference model. Default to False.
     """
 
+    @_torchrec_method_logger()
     def __init__(
         self,
         topology: Topology,
diff --git a/torchrec/distributed/shard.py b/torchrec/distributed/shard.py
index 0a27711a7..bc5a811ae 100644
--- a/torchrec/distributed/shard.py
+++ b/torchrec/distributed/shard.py
@@ -15,6 +15,7 @@
 from torch.distributed._composable.contract import contract
 from torchrec.distributed.comm import get_local_size
 from torchrec.distributed.global_settings import get_propogate_device
+from torchrec.distributed.logger import _torchrec_method_logger
 from torchrec.distributed.model_parallel import get_default_sharders
 from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
 from torchrec.distributed.sharding_plan import (
@@ -146,6 +147,7 @@ def _shard(
 
 # pyre-ignore
 @contract()
+@_torchrec_method_logger()
 def shard_modules(
     module: nn.Module,
     env: Optional[ShardingEnv] = None,
@@ -194,6 +196,7 @@ def init_weights(m):
     return _shard_modules(module, env, device, plan, sharders, init_params)
 
 
+@_torchrec_method_logger()
 def _shard_modules(  # noqa: C901
     module: nn.Module,
     # TODO: Consolidate to using Dict[str, ShardingEnv]
diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index 61c430bd8..4977477f3 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -32,6 +32,7 @@
 import torch
 from torch.autograd.profiler import record_function
 from torchrec.distributed.dist_data import KJTAllToAllTensorsAwaitable
+from torchrec.distributed.logger import _torchrec_method_logger
 from torchrec.distributed.model_parallel import ShardedModule
 from torchrec.distributed.train_pipeline.pipeline_context import (
     EmbeddingTrainPipelineContext,
@@ -106,6 +107,8 @@ class TrainPipeline(abc.ABC, Generic[In, Out]):
     def progress(self, dataloader_iter: Iterator[In]) -> Out:
         pass
 
+    # pyre-ignore [56]
+    @_torchrec_method_logger()
     def __init__(self) -> None:
         # pipeline state such as in foward, in backward etc, used in training recover scenarios
         self._state: PipelineState = PipelineState.IDLE
diff --git a/torchrec/modules/mc_embedding_modules.py b/torchrec/modules/mc_embedding_modules.py
index 129da1e69..04d4e2489 100644
--- a/torchrec/modules/mc_embedding_modules.py
+++ b/torchrec/modules/mc_embedding_modules.py
@@ -12,6 +12,7 @@
 import torch
 import torch.nn as nn
 
+from torchrec.distributed.logger import _torchrec_method_logger
 from torchrec.modules.embedding_modules import (
     EmbeddingBagCollection,
@@ -125,6 +126,7 @@ class ManagedCollisionEmbeddingCollection(BaseManagedCollisionEmbeddingCollectio
     """
 
+    @_torchrec_method_logger()
     def __init__(
         self,
         embedding_collection: EmbeddingCollection,
@@ -164,6 +166,7 @@ class ManagedCollisionEmbeddingBagCollection(BaseManagedCollisionEmbeddingCollec
     """
 
+    @_torchrec_method_logger()
     def __init__(
         self,
         embedding_bag_collection: EmbeddingBagCollection,