Skip to content

Commit 77669c3

Browse files
YulunWmeta-codesync[bot]
authored andcommitted
Ensure it is safe to discard unenqueued events
Summary: Similar to D85124897, we did not invalidate the handle when we override the pendingEnqueueColl_. Although this is unlikely, we should at least make it an error rather than segfault. Change the code to invalidate the handle for the old event. Reviewed By: SuhitK Differential Revision: D85124893 fbshipit-source-id: 7ae0696983e2fae241d46c14c1d89ef326190a86
1 parent 514dd90 commit 77669c3

File tree

2 files changed

+73
-0
lines changed

2 files changed

+73
-0
lines changed

comms/utils/colltrace/CollTrace.cc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "comms/utils/colltrace/CollTrace.h"
44

55
#include <fmt/core.h>
6+
#include <folly/json.h>
67
#include <folly/logging/xlog.h>
78
#include <folly/stop_watch.h>
89

@@ -82,6 +83,32 @@ CollTrace::~CollTrace() {
8283
CommsMaybe<std::shared_ptr<CollTraceHandle>> CollTrace::recordCollective(
8384
std::unique_ptr<ICollMetadata> metadata,
8485
std::unique_ptr<ICollWaitEvent> waitEvent) noexcept {
86+
if (metadata == nullptr) {
87+
return folly::makeUnexpected(CommsError(
88+
"Received nullptr for metadata during recordCollective",
89+
commInternalError));
90+
}
91+
if (waitEvent == nullptr) {
92+
return folly::makeUnexpected(CommsError(
93+
"Received nullptr for waitEvent during recordCollective",
94+
commInternalError));
95+
}
96+
if (pendingEnqueueColl_ != nullptr) {
97+
XLOG_FIRST_N(
98+
ERR,
99+
1,
100+
fmt::format(
101+
"{}: Got another collective enqueued when a previous one haven't finished, colltrace result would be inaccurate. Previous: {}, Next:{}",
102+
logPrefix_,
103+
folly::toJson(pendingEnqueueColl_->collRecord->toDynamic()),
104+
folly::toJson(metadata->toDynamic())));
105+
auto handlePtr = eventToHandleMap_.find(pendingEnqueueColl_.get());
106+
if (handlePtr != eventToHandleMap_.end()) {
107+
handlePtr->second->invalidate();
108+
eventToHandleMap_.erase(pendingEnqueueColl_.get());
109+
}
110+
}
111+
85112
pendingEnqueueColl_ = std::make_unique<CollTraceEvent>(
86113
std::make_shared<CollRecord>(collId_.fetch_add(1), std::move(metadata)),
87114
std::move(waitEvent));

comms/utils/colltrace/tests/CollTraceUT.cc

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,3 +411,49 @@ TEST_F(CollTraceTest, CheckHandleValidityWhenPendingQueueFull) {
411411
}
412412
}
413413
}
414+
415+
// If multiple enqueue happened at the same time, colltrace would not be able
416+
// to handle them as there is only one slot for pending enqueue collectives.
417+
// We will change the behavior later to make sure we can enqueue from multiple
418+
// places, but for now let's at least make sure it won't cause segfault.
419+
TEST_F(CollTraceTest, CheckHandleValidityOverMultipleEnqueues) {
420+
std::vector<std::shared_ptr<ICollTraceHandle>> handles;
421+
for (int i = 0; i < 10; ++i) {
422+
// Create metadata and wait event. We will only call recordCollective, so
423+
// we don't expect getting any calls to them
424+
auto metadata = std::make_unique<NiceMock<MockCollMetadata>>();
425+
auto waitEvent = std::make_unique<NiceMock<MockCollWaitEvent>>();
426+
427+
// Set up expectations for the wait event
428+
ON_CALL(*waitEvent, beforeCollKernelScheduled())
429+
.WillByDefault(Return(folly::unit));
430+
431+
// Set up expectations for the wait event
432+
ON_CALL(*metadata, toDynamic())
433+
.WillByDefault(
434+
Return(static_cast<folly::dynamic>(folly::dynamic::object())));
435+
436+
// Record a collective
437+
auto handleMaybe =
438+
collTrace->recordCollective(std::move(metadata), std::move(waitEvent));
439+
440+
// Verify that the handle was created successfully
441+
EXPECT_VALUE(handleMaybe);
442+
EXPECT_NE(handleMaybe.value().get(), nullptr);
443+
444+
// Trigger the enqueue
445+
handleMaybe.value()->trigger(
446+
CollTraceHandleTriggerState::BeforeEnqueueKernel);
447+
448+
handles.emplace_back(std::move(handleMaybe.value()));
449+
}
450+
451+
for (auto& handle : handles) {
452+
// Make sure we can get the coll record without encountering segmentation
453+
// fault. Getting invalid record is expected.
454+
auto res = handle->getCollRecord();
455+
if (res.hasValue()) {
456+
EXPECT_NE(res.value(), nullptr);
457+
}
458+
}
459+
}

0 commit comments

Comments
 (0)