diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py index d8554edea..9ddb6e91c 100644 --- a/torchrec/distributed/embedding.py +++ b/torchrec/distributed/embedding.py @@ -1469,7 +1469,7 @@ def input_dist( ) -> Awaitable[Awaitable[KJTList]]: need_permute: bool = True if isinstance(features, TensorDict): - feature_keys = list(features.keys()) # pyre-ignore[6] + feature_keys = list(features.keys()) if self._features_order: feature_keys = [feature_keys[i] for i in self._features_order] need_permute = False diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index fd6117884..4375baa4f 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -1513,7 +1513,7 @@ def input_dist( in advance """ if isinstance(features, TensorDict): - feature_keys = list(features.keys()) # pyre-ignore[6] + feature_keys = list(features.keys()) if len(self._features_order) > 0: feature_keys = [feature_keys[i] for i in self._features_order] self._has_features_permute = False # feature_keys are in order diff --git a/torchrec/distributed/test_utils/test_model.py b/torchrec/distributed/test_utils/test_model.py index cb7004670..d0d0c7a8d 100644 --- a/torchrec/distributed/test_utils/test_model.py +++ b/torchrec/distributed/test_utils/test_model.py @@ -298,7 +298,7 @@ def _validate_pooling_factor( idlist_features, global_idlist_indices, global_idlist_lengths ) } - global_idlist_input = TensorDict(source=dict_of_nt) + global_idlist_input = TensorDict(source=dict_of_nt) # pyre-ignore[6] assert ( len(idscore_features) == 0 @@ -409,7 +409,7 @@ def _validate_pooling_factor( local_idlist_lengths, ) } - local_idlist_input = TensorDict(source=dict_of_nt) + local_idlist_input = TensorDict(source=dict_of_nt) # pyre-ignore[6] assert ( len(idscore_features) == 0 ), "TensorDict does not support weighted features"