Skip to content

Commit 67b7c3b

Browse files
voldymanmeta-codesync[bot]
authored andcommitted
Sync tensordict
Summary: Imported commits from public tensordict repo using ``` python3 pytorch/import.py --project_name tensordict --no_submit ``` Manual Change: * adding a new dependency: pyvers * imported a new dependency pyvers in the stack * some import failures due to checked in code missing new lines * hypothesis: code formatter removed the new lines * manually added the new lines back and re-ran the script * added pyi files to library source to fix linter errors * disabled tests that use ray because we can't initialize a ray cluster on sandcastle hosts Differential Revision: D85157386
1 parent a4ca26f commit 67b7c3b

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

torchrec/distributed/embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1469,7 +1469,7 @@ def input_dist(
14691469
) -> Awaitable[Awaitable[KJTList]]:
14701470
need_permute: bool = True
14711471
if isinstance(features, TensorDict):
1472-
feature_keys = list(features.keys()) # pyre-ignore[6]
1472+
feature_keys = list(features.keys())
14731473
if self._features_order:
14741474
feature_keys = [feature_keys[i] for i in self._features_order]
14751475
need_permute = False

torchrec/distributed/embeddingbag.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1513,7 +1513,7 @@ def input_dist(
15131513
in advance
15141514
"""
15151515
if isinstance(features, TensorDict):
1516-
feature_keys = list(features.keys()) # pyre-ignore[6]
1516+
feature_keys = list(features.keys())
15171517
if len(self._features_order) > 0:
15181518
feature_keys = [feature_keys[i] for i in self._features_order]
15191519
self._has_features_permute = False # feature_keys are in order

torchrec/distributed/test_utils/test_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ def _validate_pooling_factor(
298298
idlist_features, global_idlist_indices, global_idlist_lengths
299299
)
300300
}
301-
global_idlist_input = TensorDict(source=dict_of_nt)
301+
global_idlist_input = TensorDict(source=dict_of_nt) # pyre-ignore[6]
302302

303303
assert (
304304
len(idscore_features) == 0
@@ -409,7 +409,7 @@ def _validate_pooling_factor(
409409
local_idlist_lengths,
410410
)
411411
}
412-
local_idlist_input = TensorDict(source=dict_of_nt)
412+
local_idlist_input = TensorDict(source=dict_of_nt) # pyre-ignore[6]
413413
assert (
414414
len(idscore_features) == 0
415415
), "TensorDict does not support weighted features"

0 commit comments

Comments
 (0)