Skip to content

Commit ee26ed4

Browse files
emlinmeta-codesync[bot]
authored and committed
optimization: move set_metadata out of main stream (#5082)
Summary: Pull Request resolved: #5082 X-link: https://github.com/facebookresearch/FBGEMM/pull/2090 with feature score eviction, tbe will call backend to update feature score metadata separately in forward pass. this process is designed for asynchronous update without blocking forward/backward pass, however the cpu blocking operation blocked the main stream, so after get_cuda, all2all cannot be started immediately. from dummy profile, we can see this trace: {F1983224804} the set metadata operation becomes a blocker in critical path, which took 217ms With this change, we can see the trace is updated to: {F1983224830} where overall prefetch is reduced to less than 70ms, also the get_cuda is followed by all2all immediately without other waiting and stream sync https://www.internalfb.com/ai_infra/zoomer/profiling-run/overview?profilingRunID=1913270729575721 Reviewed By: steven1327, kathyxuyy Differential Revision: D86013406 fbshipit-source-id: 2fad88bd17d8e83104706540cfcd3311545af613
1 parent 5c0d969 commit ee26ed4

File tree

1 file changed

+65
-17
lines changed

1 file changed

+65
-17
lines changed

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 65 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -762,8 +762,10 @@ def __init__(
762762
(low_priority, high_priority) = torch.cuda.Stream.priority_range()
763763
# GPU stream for SSD cache eviction
764764
self.ssd_eviction_stream = torch.cuda.Stream(priority=low_priority)
765-
# GPU stream for SSD memory copy
765+
# GPU stream for SSD memory copy (also reused for feature score D2H)
766766
self.ssd_memcpy_stream = torch.cuda.Stream(priority=low_priority)
767+
# GPU stream for async metadata operation
768+
self.feature_score_stream = torch.cuda.Stream(priority=low_priority)
767769

768770
# SSD get completion event
769771
self.ssd_event_get = torch.cuda.Event()
@@ -1675,6 +1677,56 @@ def _update_cache_counter_and_pointers(
16751677
unique_indices_length_curr=curr_data.actions_count_gpu,
16761678
)
16771679

1680+
def _update_feature_score_metadata(
    self,
    linear_cache_indices: Tensor,
    weights: Tensor,
    d2h_stream: torch.cuda.Stream,
    write_stream: torch.cuda.Stream,
    pre_event_for_write: torch.cuda.Event,
    post_event: Optional[torch.cuda.Event] = None,
) -> None:
    """
    Write feature score metadata to DRAM, off the main stream.

    This method performs the D2H copy of `linear_cache_indices` and `weights`
    on `d2h_stream`, then issues the backend metadata write on `write_stream`
    after `write_stream` has waited on both `pre_event_for_write` and the
    D2H copies. The caller is responsible for ensuring `d2h_stream` doesn't
    compete with other D2H operations.

    Args:
        linear_cache_indices: GPU tensor containing cache indices
        weights: GPU tensor containing feature scores
        d2h_stream: Stream for the D2H copy (should already be synchronized
            appropriately by the caller)
        write_stream: Stream on which the metadata write is enqueued
        pre_event_for_write: Event to wait on before writing metadata
            (e.g., wait for eviction to complete)
        post_event: Optional event recorded on `write_stream` once the
            metadata write has been enqueued, so callers can order later
            work after it
    """
    # Start D2H copy on d2h_stream
    with torch.cuda.stream(d2h_stream):
        # Mark both GPU tensors as in-use by d2h_stream so the caching
        # allocator does not free/reuse their memory before the async
        # copies below complete.
        linear_cache_indices.record_stream(d2h_stream)
        weights.record_stream(d2h_stream)
        # Do the D2H copy (to_pinned_cpu presumably returns a pinned-memory
        # CPU copy, enabling an async transfer — confirm against its impl)
        linear_cache_indices_cpu = self.to_pinned_cpu(linear_cache_indices)
        score_weights_cpu = self.to_pinned_cpu(weights)

    # Write feature score metadata to DRAM on write_stream, keeping the
    # blocking backend call off the main stream.
    with record_function("## ssd_write_feature_score_metadata ##"):
        with torch.cuda.stream(write_stream):
            # Order the write after the caller-supplied event (e.g. cache
            # eviction) and after the D2H copies issued above.
            write_stream.wait_event(pre_event_for_write)
            write_stream.wait_stream(d2h_stream)
            self.record_function_via_dummy_profile(
                "## ssd_write_feature_score_metadata ##",
                self.ssd_db.set_feature_score_metadata_cuda,
                linear_cache_indices_cpu,
                # Row count of the score tensor, passed as a CPU int64 tensor.
                torch.tensor(
                    [score_weights_cpu.shape[0]], device="cpu", dtype=torch.long
                ),
                score_weights_cpu,
            )

    # record_event targets write_stream explicitly, letting callers wait
    # for the enqueued metadata write.
    if post_event is not None:
        write_stream.record_event(post_event)
1729+
16781730
def prefetch(
16791731
self,
16801732
indices: Tensor,
@@ -1747,12 +1799,6 @@ def _prefetch( # noqa C901
17471799

17481800
self.timestep += 1
17491801
self.timesteps_prefetched.append(self.timestep)
1750-
if self.backend_type == BackendType.DRAM and weights is not None:
1751-
# DRAM backend supports feature score eviction, if there is weights available
1752-
# in the prefetch call, we will set metadata for feature score eviction asynchronously
1753-
cloned_linear_cache_indices = linear_cache_indices.clone()
1754-
else:
1755-
cloned_linear_cache_indices = None
17561802

17571803
# Lookup and virtually insert indices into L1. After this operator,
17581804
# we know:
@@ -2114,16 +2160,18 @@ def _prefetch( # noqa C901
21142160
name="cache",
21152161
is_bwd=False,
21162162
)
2117-
if self.backend_type == BackendType.DRAM and weights is not None:
2118-
# Write feature score metadata to DRAM
2119-
self.record_function_via_dummy_profile(
2120-
"## ssd_write_feature_score_metadata ##",
2121-
self.ssd_db.set_feature_score_metadata_cuda,
2122-
cloned_linear_cache_indices.cpu(),
2123-
torch.tensor(
2124-
[weights.shape[0]], device="cpu", dtype=torch.long
2125-
),
2126-
weights.cpu(),
2163+
if (
2164+
self.backend_type == BackendType.DRAM
2165+
and weights is not None
2166+
and linear_cache_indices.numel() > 0
2167+
):
2168+
# Reuse ssd_memcpy_stream for feature score D2H since critical D2H is done
2169+
self._update_feature_score_metadata(
2170+
linear_cache_indices=linear_cache_indices,
2171+
weights=weights,
2172+
d2h_stream=self.ssd_memcpy_stream,
2173+
write_stream=self.feature_score_stream,
2174+
pre_event_for_write=self.ssd_event_cache_evict,
21272175
)
21282176

21292177
# Generate row addresses (pointing to either L1 or the current

0 commit comments

Comments
 (0)