[BE] split seq_id to collective_seq_id and p2p_seq_id
Summary:
Split seq_id into collective_seq_id and p2p_seq_id. The main idea is
that collectives that run on all machines should carry identical
collective_seq_id values, which makes it easier to spot when one of the
machines isn't handling a collective operation.
Next, we can attempt to match up p2p operations to ensure that the
sender(s) and receiver(s) are in sync.

Resolves issue: #125173
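
As a rough illustration of the debugging flow this enables, a minimal sketch
that compares collective_seq_id across per-rank flight-recorder dumps (the
dump paths, world_size value, and helper name are hypothetical; the trace
format follows the tests in this PR):

import pickle

world_size = 2  # hypothetical; use the job's actual world size

def last_collective_seq(path):
    # Each dump file is assumed to hold the pickled trace blob produced by
    # torch._C._distributed_c10d._dump_nccl_trace() (see the tests below).
    with open(path, "rb") as f:
        trace = pickle.load(f)
    # Heuristic (assumption): skip send/recv entries, which advance
    # p2p_seq_id rather than collective_seq_id.
    seqs = [
        e["collective_seq_id"]
        for e in trace["entries"]
        if not e["profiling_name"].startswith(("nccl:send", "nccl:recv"))
    ]
    return max(seqs, default=0)

latest = {r: last_collective_seq(f"/tmp/trace_{r}") for r in range(world_size)}
behind = [r for r, s in latest.items() if s < max(latest.values())]
print("ranks possibly stuck on a collective:", behind)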

ghstack-source-id: c31b3164d2e51efeab210e6a949cd4c8d1ecd3d7
Pull Request resolved: #125727
c-p-i-o committed May 13, 2024
1 parent 1becdbe commit 10d3cef
Showing 5 changed files with 99 additions and 52 deletions.
4 changes: 2 additions & 2 deletions test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp
@@ -69,7 +69,7 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL {
const std::vector<at::Tensor>& outputs = {},
bool record = false) override {
return c10::make_intrusive<WorkNCCLSimulateErrors>(
- device, simulateError_, rank, opType, seq_);
+ device, simulateError_, rank, opType, seqCollective_);
}

size_t getNCCLCommCacheSize() {
@@ -131,7 +131,7 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors {
const std::vector<at::Tensor>& outputs = {},
bool record = false) override {
return c10::make_intrusive<WorkNCCLTimedoutErrors>(
- device, setTimedoutError_, rank, opType, seq_);
+ device, setTimedoutError_, rank, opType, seqCollective_);
}

void setTimedoutError() {
32 changes: 18 additions & 14 deletions test/distributed/test_c10d_nccl.py
@@ -3402,7 +3402,7 @@ def test_short(self, timing_enabled):

t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
ver = t["version"]
- self.assertEqual(ver, "1.5")
+ self.assertEqual(ver, "2.5")
pg_config = t["pg_config"]
self.assertEqual(len(pg_config), 1)
default_pg_info = pg_config["0"]
@@ -3426,7 +3426,7 @@ def test_short(self, timing_enabled):
self.assertIn("test_c10d_nccl.py", str(last["frames"]))
self.assertEqual(last["input_sizes"], ((3, 4),))
self.assertEqual(last["output_sizes"], ((3, 4),))
- self.assertEqual(last["seq_id"], 2)
+ self.assertEqual(last["collective_seq_id"], 2)
now = datetime.now()
event_created_time = datetime.fromtimestamp(
last["time_created_ns"] / 1000000000
@@ -3507,7 +3507,7 @@ def test_long(self):
self.assertIn("test_c10d_nccl.py", str(last["frames"]))
self.assertEqual(last["input_sizes"], ((3, 4),))
self.assertEqual(last["output_sizes"], ((3, 4),))
- self.assertEqual(last["seq_id"] - first["seq_id"], 9)
+ self.assertEqual(last["collective_seq_id"] - first["collective_seq_id"], 9)

@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@@ -3537,10 +3537,10 @@ def test_trace_while_active(self, timing_enabled):
t = t["entries"]
self.assertEqual(t[-1]["profiling_name"], "nccl:all_reduce")
if self.rank == 0:
- self.assertEqual(t[-1]["seq_id"], 1)
+ self.assertEqual(t[-1]["collective_seq_id"], 1)
self.assertEqual(t[-1]["state"], "completed")
else:
- self.assertEqual(t[-1]["seq_id"], 2)
+ self.assertEqual(t[-1]["collective_seq_id"], 2)
self.assertEqual(
t[-1]["state"], self.started_or_scheduled(timing_enabled)
)
@@ -3582,10 +3582,10 @@ def gather_trace():
t = t["entries"]
self.assertEqual(t[-1]["profiling_name"], "nccl:all_reduce")
if self.rank == 0:
- self.assertEqual(t[-1]["seq_id"], 1)
+ self.assertEqual(t[-1]["collective_seq_id"], 1)
self.assertEqual(t[-1]["state"], "completed")
else:
- self.assertEqual(t[-1]["seq_id"], 2)
+ self.assertEqual(t[-1]["collective_seq_id"], 2)
self.assertEqual(
t[-1]["state"], self.started_or_scheduled(timing_enabled)
)
@@ -3677,7 +3677,9 @@ def test_batched_send_recv(self, op_sizes_per_coalesce, timing_enabled):
self.assertEqual(
t["entries"][p2p_op_idx]["profiling_name"], profiling_name
)
self.assertEqual(t["entries"][p2p_op_idx]["seq_id"], expected_seq)
self.assertEqual(
t["entries"][p2p_op_idx]["collective_seq_id"], expected_seq
)
self.assertEqual(t["entries"][p2p_op_idx]["op_id"], expected_op_id)
expected_op_id += 1
self.assertEqual(t["entries"][p2p_op_idx]["input_sizes"], [input_sizes])
@@ -3697,7 +3699,9 @@ def test_batched_send_recv(self, op_sizes_per_coalesce, timing_enabled):
self.assertEqual(
t["entries"][coalesced_op]["profiling_name"], "nccl:coalesced"
)
self.assertEqual(t["entries"][coalesced_op]["seq_id"], expected_seq)
self.assertEqual(
t["entries"][coalesced_op]["collective_seq_id"], expected_seq
)
expected_seq += 1
self.assertEqual(t["entries"][coalesced_op]["state"], "completed")
self.assertEqual(t["entries"][coalesced_op]["input_sizes"], [])
@@ -3753,7 +3757,7 @@ def test_individual_send_recv(self, op_sizes, timing_enabled):
input_sizes = op_sizes[seq % ops_per_repeat]
profiling_name = "nccl:recv 0<-1" if self.rank == 0 else "nccl:send 1->0"
self.assertEqual(t["entries"][seq]["profiling_name"], profiling_name)
self.assertEqual(t["entries"][seq]["seq_id"], expected_seq)
self.assertEqual(t["entries"][seq]["p2p_seq_id"], expected_seq)
expected_seq += 1
self.assertEqual(t["entries"][seq]["op_id"], expected_op_id)
expected_op_id += 1
@@ -3813,7 +3817,7 @@ def test_coalescing_manager_collective(self, timing_enabled):
self.assertEqual(
t["entries"][0]["profiling_name"], "nccl:reduce_scatter_tensor_coalesced"
)
self.assertEqual(t["entries"][0]["seq_id"], 1)
self.assertEqual(t["entries"][0]["collective_seq_id"], 1)
self.assertEqual(t["entries"][0]["input_sizes"], [[2, 2], [2, 2]])
self.assertEqual(
t["entries"][0]["output_sizes"],
@@ -3881,9 +3885,9 @@ def test_timeout_dumps(self, timing_enabled):
t = pickle.load(f)
t = t["entries"]
self.assertEqual(len(t), 2)
self.assertEqual(t[0]["seq_id"], 1)
self.assertEqual(t[0]["collective_seq_id"], 1)
self.assertEqual(t[0]["state"], "completed")
self.assertEqual(t[1]["seq_id"], 2)
self.assertEqual(t[1]["collective_seq_id"], 2)
self.assertEqual(
t[1]["state"], self.started_or_scheduled(timing_enabled)
)
@@ -3944,7 +3948,7 @@ def test_timeout_dumps_on_stuck_ranks(self):
t = pickle.load(f)
t = t["entries"]
self.assertEqual(len(t), 1)
self.assertEqual(t[0]["seq_id"], 1)
self.assertEqual(t[0]["collective_seq_id"], 1)
self.assertEqual(t[0]["state"], "completed")
return

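The same dumps can drive the p2p matching idea from the summary: a send on
rank 1 and the corresponding recv on rank 0 should carry the same
p2p_seq_id. A minimal sketch under the same assumptions as above
(hypothetical dump paths; profiling_name values as asserted in
test_individual_send_recv):

import pickle

def p2p_entries(path):
    with open(path, "rb") as f:
        trace = pickle.load(f)
    # Identify p2p entries by profiling_name, e.g. "nccl:send 1->0" or
    # "nccl:recv 0<-1" as asserted in the tests above.
    return [
        e for e in trace["entries"]
        if e["profiling_name"].startswith(("nccl:send", "nccl:recv"))
    ]

# Walk matching entries pairwise; the first divergence points at the
# out-of-sync operation.
for s, r in zip(p2p_entries("/tmp/trace_1"), p2p_entries("/tmp/trace_0")):
    if s["p2p_seq_id"] != r["p2p_seq_id"]:
        print("p2p desync:", s["profiling_name"], "vs", r["profiling_name"])
        break
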
62 changes: 42 additions & 20 deletions torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -1,4 +1,5 @@

+ #include <torch/csrc/distributed/c10d/Work.hpp>
#ifdef USE_C10D_NCCL

#include <exception>
@@ -926,7 +927,7 @@ void ProcessGroupNCCL::setSequenceNumberForGroup() {
} // NCCL just starts sequence numbers at 0.

uint64_t ProcessGroupNCCL::getSequenceNumberForGroup() {
- return seq_;
+ return seqCollective_;
}

void ProcessGroupNCCL::registerOnCompletionHook(
@@ -2239,14 +2240,15 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
device,
rank,
opType,
- seq_,
+ seqCollective_,
profilingTitle,
profilingTitle != nullptr ? c10::optional<std::vector<at::Tensor>>(inputs)
: c10::nullopt,
desyncDebug_,
enableTiming_.load(),
dist_debug_level_);
if (record) {
+ bool isP2P = isP2POp(opType);
// Ideally record every work that we enqueue, rather than every work we
// create.
// - at the time of this PR we do not currently enqueue every created work
@@ -2263,13 +2265,15 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
r->trace_id_ = NCCLTraceBuffer::get()->record(
uid_,
std::make_tuple(pg_name_, pg_desc_),
- seq_,
+ seqCollective_,
+ seqP2P_,
op_id_,
profilingTitle ? profilingTitle : "",
inputs,
outputs,
r->ncclStartEvent_.get(),
- r->ncclEndEvent_.get());
+ r->ncclEndEvent_.get(),
+ isP2P);
}
return r;
}
@@ -2321,10 +2325,6 @@ ProcessGroupNCCL::Options::Options(bool is_high_priority_stream)
static constexpr int CoalActive = 0x01, CoalColl = 0x02, CoalP2P = 0x04;

void ProcessGroupNCCL::startCoalescing() {
- coalescedDevice_.set_index(-1);
- coalescedComm_ = nullptr;
- coalescing_state_ |= CoalActive;
- groupStart();
// Other collective ops bump seq_ before creating a work. Thus, if coalesced
// ops bump seq_ only after initing a work they will collide with (reuse) the
// seq_ of the last non-coalesced collective. Previously, seq_ was bumped
@@ -2333,10 +2333,19 @@ void ProcessGroupNCCL::startCoalescing() {
// same seq_ for those ops and its 'endCoalescing' op. Hence we bump during
// start, which has one minor downside- we burn a seq_ if someone ever does a
// 'start' and 'end' coalescing region without doing an operation inbetween.
- seq_++;

- // Don't bump op_id_ here, becuase startCoalescing isn't a logical operation.
+ // Don't bump op_id_ here, because startCoalescing isn't a logical operation.
// Bump it for each logical op inside the coalescing group.
+ if (coalescing_state_ & CoalP2P) {
+   seqP2P_++;
+ } else {
+   seqCollective_++;
+ }

+ coalescedDevice_.set_index(-1);
+ coalescedComm_ = nullptr;
+ coalescing_state_ |= CoalActive;
+ groupStart();
}

// `optype` is for specifying a composite optype, such as ALLGATHER and
@@ -2431,7 +2440,11 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
errorIfCapturingNonCapturableNCCL(capture_status);

// Bump collective counter
- seq_++;
+ if (isP2POp(opType)) {
+   seqP2P_++;
+ } else {
+   seqCollective_++;
+ }
op_id_++;

auto device = getDevice(input);
@@ -2586,9 +2599,13 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
errorIfCapturingNonCapturableNCCL(capture_status);

// Bump collective counter
- seq_++;
+ if (isP2POp(opType)) {
+   seqP2P_++;
+ } else {
+   seqCollective_++;
+ }
// For coalescingManager collectives, there is no individual c++ call per
- // collective so there is no flight record and we increment seq_ and op_id_
+ // collective so there is no flight record and we increment seq*_ and op_id_
// together. Compare this to startCoalesing/endCoalescing flow where we
// increment seq_ once per group and increment op_id_ once per indvidual
// operation within the group
@@ -2814,8 +2831,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(

if (!coalescing_state_) {
// Bump sequence number. Don't do so if it's a batch P2P, it will be
- // bumped in `endCoalescing`.
- seq_++;
+ // bumped in `startCoalescing`.
+ seqP2P_++;
}
}

@@ -2845,6 +2862,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
// First let NCCL streams wait for input tensors allocation streams
syncStream(device, ncclEvents_[key], ncclStream);

+ bool isP2P = isP2POp(opType);
// Work itself will create the CUDA events on all GPUs of tensors
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> work;
if (coalescing_state_) {
@@ -2856,13 +2874,15 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
auto trace_id = NCCLTraceBuffer::get()->record(
uid_,
std::make_tuple(pg_name_, pg_desc_),
- seq_,
+ seqCollective_,
+ seqP2P_,
op_id_,
profilingTitle,
{tensor},
{tensor},
nullptr,
- nullptr);
+ nullptr,
+ isP2P);
// TODO(whc) if we want to make the per-p2p-op flightrecorder entries get
// their timings/states updated by proxy when the Work obj representing the
// coalesce group gets its update, we could accumulate these trace_ids
@@ -2881,19 +2901,21 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
// output, not sure what
work->outputs_ = std::make_shared<std::vector<at::Tensor>>();
work->outputs_->push_back(tensor);
- // TODO(whc) becuase we don't pass output {tensor} to initWork, we tell
+ // TODO(whc) because we don't pass output {tensor} to initWork, we tell
// initWork to not record, and then we manually call record passing all the
// information it wants.
work->trace_id_ = NCCLTraceBuffer::get()->record(
uid_,
std::make_tuple(pg_name_, pg_desc_),
- seq_,
+ seqCollective_,
+ seqP2P_,
op_id_,
profilingTitle,
{tensor},
{tensor},
work->ncclStartEvent_.get(),
- work->ncclEndEvent_.get());
+ work->ncclEndEvent_.get(),
+ isP2P);
}

// is gpuGuard needed for the if block below, or can i swap them
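To make the new bumping rules concrete, a rough Python model of the three
counters (a simplification of the C++ above, not the implementation; note
that per the updated tests, batched p2p groups report collective_seq_id, so
the coalescing bump in this model goes to the collective counter):

seq_collective = seq_p2p = op_id = 0

def collective():
    global seq_collective, op_id
    seq_collective += 1  # one kernel launch -> one collective seq number
    op_id += 1           # every logical op bumps op_id

def individual_p2p():
    global seq_p2p, op_id
    seq_p2p += 1         # a non-coalesced send/recv advances the p2p counter
    op_id += 1

def start_coalescing():
    global seq_collective
    seq_collective += 1  # bumped once per group, even an empty one

def coalesced_op():
    global op_id
    op_id += 1           # ops inside the group only advance op_id

collective()                    # counters: (1, 0, 1)
individual_p2p()                # counters: (1, 1, 2)
start_coalescing()              # counters: (2, 1, 2)
coalesced_op(); coalesced_op()  # counters: (2, 1, 4)
print(seq_collective, seq_p2p, op_id)  # -> 2 1 4
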
11 changes: 7 additions & 4 deletions torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -1050,24 +1050,27 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// Counting for the sequential number of NCCL collective call.
// (specifically, how many actual kernels we launched, which differs from
// op_id_ when coalescing is enabled)
- uint64_t seq_{0};
+ uint64_t seqCollective_{0};

+ // Counting for the sequential number of NCCL P2P calls.
+ uint64_t seqP2P_{0};

// Incrementing counter for logical operations (collective or p2p) issued on
// the ProcessGroup
uint64_t op_id_{0};

- // the sequential number of the last colletive enqueued into workMetaList_
+ // the sequential number of the last collective enqueued into workMetaList_
// This is useful for indentifying a rank that has not join a collective
// initialized to be -1 to indicate no collective has been enqueued
int64_t lastEnqueuedSeq_{-1};

// the name of the last collective enqueued into workMetaList_
std::string lastEnqueuedWorkName_;

- // the sequential number of the last colletive started as the kernal
+ // the sequential number of the last collective started as the kernel
int64_t lastStartedSeq_{-1};

// the name of the last collective started as the kernal
// the name of the last collective started as the kernel
std::string lastStartedWorkName_;

// the sequential number of the last colletive completed marked by
