Skip to content

Commit bda9700

Browse files
author
Frederik Mellbye
committed
Reflect GCL API changes
Reviewers: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, christiana
Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, christiana
Subscribers: danielb
Maniphest Tasks: T68415
Differential Revision: https://phabricator.sourcevertex.net/D74214
1 parent f359135 commit bda9700

File tree

5 files changed

+63
-49
lines changed

5 files changed

+63
-49
lines changed

tensorflow/compiler/plugin/poplar/driver/ops/custom_ops/popops/within_replicas.cc

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ class AllGatherWithinReplicaOp : public PoplarOpDef {
121121
gcl::Chunks chunks;
122122

123123
std::vector<poplar::Tensor> shards;
124+
std::vector<gcl::Chunk> chunks_vec;
124125
for (auto i = 0u; i < inst->operand_count(); ++i) {
125126
TF_ASSIGN_OR_RETURN(
126127
poplar::Tensor input,
@@ -131,14 +132,16 @@ class AllGatherWithinReplicaOp : public PoplarOpDef {
131132
// to the subtensor after indexing the outermost dimension). Having
132133
// index=i means that the chunks in our gathered tensor will be in the
133134
// same order as the inputs.
134-
chunks.chunks.push_back({input, /*index*/ i, /*offset*/ 0});
135135

136+
chunks_vec.emplace_back(input, i, 0);
136137
shards.push_back(input);
137138
}
138139

140+
chunks.setChunks(chunks_vec);
141+
139142
auto original_input = poplar::concat(shards);
140-
chunks.originalInput =
141-
original_input.expand({0}).broadcast(inst->operand_count(), 0);
143+
chunks.setOriginalInput(
144+
original_input.expand({0}).broadcast(inst->operand_count(), 0));
142145
return chunks;
143146
}
144147
};
@@ -173,15 +176,15 @@ class ReduceScatterWithinReplicaOp : public PoplarOpDef {
173176
{debug_info, "ReduceScatterWithinReplica"},
174177
GetReplicatedCollectiveOptions(res));
175178

176-
CHECK_EQ(ipu_count, chunks.chunks.size())
179+
CHECK_EQ(ipu_count, chunks.getChunks().size())
177180
<< "Expecting to have a chunk for each IPU.";
178-
TF_CHECK_OK(SetOutputs(chunks.chunks, inst, res, tensor_map));
181+
TF_CHECK_OK(SetOutputs(chunks.getChunks(), inst, res, tensor_map));
179182

180183
return seq;
181184
}
182185

183186
private:
184-
Status SetOutputs(std::vector<gcl::Chunk>& output_chunks,
187+
Status SetOutputs(const std::vector<gcl::Chunk>& output_chunks,
185188
const HloInstruction* inst, CompilerResources& res,
186189
TensorMap& tensor_map) {
187190
const auto output_tensor_shape =
@@ -191,7 +194,7 @@ class ReduceScatterWithinReplicaOp : public PoplarOpDef {
191194
const auto output_tensor_size = output_tensor_shape.dimensions(0);
192195

193196
for (auto i = 0; i < output_chunks.size(); ++i) {
194-
auto tensor = output_chunks[i].tensor;
197+
auto tensor = output_chunks[i].getTensor();
195198
CHECK_EQ(tensor.rank(), 1);
196199

197200
// Pad everything to a consistent shape. We don't know how

tensorflow/compiler/plugin/poplar/driver/poplar_executable.cc

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -72,19 +72,24 @@ void PoplarExecutableCore::PopulateCollectiveBalanceReorderHostRerrangements() {
7272
for (auto& param_host_rearrangement_entry :
7373
GetRemoteParameterHostRearrangements()) {
7474
int64_t id = param_host_rearrangement_entry.first;
75-
auto& param_host_rearrangement = param_host_rearrangement_entry.second;
75+
const auto& param_host_rearrangement =
76+
param_host_rearrangement_entry.second;
7677
gcl::CollectiveBalancedHostRearrangement host_rearrangement;
77-
host_rearrangement.replicationFactor =
78-
param_host_rearrangement.replication_factor;
79-
host_rearrangement.totalElementsPerReplica =
80-
param_host_rearrangement.total_elements_per_replica;
81-
host_rearrangement.gatheredToRefSlices.reserve(
78+
host_rearrangement.setReplicationFactor(
79+
param_host_rearrangement.replication_factor);
80+
host_rearrangement.setTotalElementsPerReplica(
81+
param_host_rearrangement.total_elements_per_replica);
82+
83+
std::vector<poplar::Interval> gathered_to_ref_slices(
84+
host_rearrangement.getGatheredToRefSlices());
85+
gathered_to_ref_slices.reserve(
8286
param_host_rearrangement.gathered_to_ref_slice.size());
83-
for (auto& slice : param_host_rearrangement.gathered_to_ref_slice) {
84-
host_rearrangement.gatheredToRefSlices.emplace_back(slice.first,
85-
slice.second);
87+
88+
for (const auto& slice : param_host_rearrangement.gathered_to_ref_slice) {
89+
gathered_to_ref_slices.emplace_back(slice.first, slice.second);
8690
}
87-
host_rearrangement.elementMap = param_host_rearrangement.element_map;
91+
host_rearrangement.setGatheredToRefSlices(gathered_to_ref_slices);
92+
host_rearrangement.setElementMap(param_host_rearrangement.element_map);
8893
cbr_host_rearrangements_[id] = std::move(host_rearrangement);
8994
}
9095
}

tensorflow/compiler/plugin/poplar/driver/poplar_executor.cc

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -307,8 +307,8 @@ PoplarExecutor::TensorControl::TensorControl(size_t size_) {
307307
std::size_t PoplarExecutor::TensorControl::GetRemoteBufferSize() const {
308308
if (host_rearrangement) {
309309
return ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(
310-
element_type, {host_rearrangement->replicationFactor,
311-
host_rearrangement->totalElementsPerReplica}));
310+
element_type, {host_rearrangement->getReplicationFactor(),
311+
host_rearrangement->getTotalElementsPerReplica()}));
312312
}
313313
return size;
314314
}
@@ -2848,11 +2848,12 @@ Status PoplarExecutor::MoveDeviceToHost() {
28482848
// code.
28492849
std::vector<char> temp(buffer.size());
28502850
tc->host_rearrangement->undoRearrangeForCollective(
2851-
buffer.data(), temp.data(), bytes_per_element);
2851+
buffer, temp, bytes_per_element);
28522852
memcpy(tc->data, temp.data(), tc->size);
28532853
} else {
28542854
tc->host_rearrangement->undoRearrangeForCollective(
2855-
buffer.data(), tc->data, bytes_per_element);
2855+
buffer.data(), buffer.size(), tc->data, tc->size,
2856+
bytes_per_element);
28562857
}
28572858
}
28582859
} else {
@@ -2975,10 +2976,11 @@ Status PoplarExecutor::MoveHostToDevice() {
29752976
std::vector<char> temp(buffer_size);
29762977
memcpy(temp.data(), tc->data, tc->size);
29772978
tc->host_rearrangement->rearrangeForCollective(
2978-
temp.data(), buffer.data(), bytes_per_element);
2979+
temp, buffer, bytes_per_element);
29792980
} else {
29802981
tc->host_rearrangement->rearrangeForCollective(
2981-
tc->data, buffer.data(), bytes_per_element);
2982+
tc->data, tc->size, buffer.data(), buffer.size(),
2983+
bytes_per_element);
29822984
}
29832985
} else {
29842986
buffer.resize(bytes_per_replica);

tensorflow/compiler/plugin/poplar/driver/visitors/partitioned_elementwise_cluster_visitor.cc

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,9 @@ Status PartitionedElementwiseClusterVisitor::ValidateShape(
138138
const gcl::CollectiveBalancedHostRearrangement& host_rearrangement =
139139
cbr->getHostRearrangement();
140140
const int64_t replicated_element_count =
141-
host_rearrangement.totalElementsPerReplica;
141+
host_rearrangement.getTotalElementsPerReplica();
142142
const int64_t non_replicated_element_count =
143-
replicated_element_count * host_rearrangement.replicationFactor;
143+
replicated_element_count * host_rearrangement.getReplicationFactor();
144144
const int64_t xla_element_count = ShapeUtil::ElementsIn(shape);
145145
VLOG(3) << "CBR slice element count: " << replicated_element_count
146146
<< ", total collectives elements: " << non_replicated_element_count
@@ -200,7 +200,7 @@ PartitionedElementwiseClusterVisitor::MakeParameterAllocationFunction(
200200
TF_ASSIGN_OR_RETURN(poplar::Type type, PoplarDataType(shape));
201201
auto element_count = ShapeUtil::ElementsIn(shape);
202202
auto& host_rearrangement = cbr->getHostRearrangement();
203-
if (host_rearrangement.totalElementsPerReplica == element_count) {
203+
if (host_rearrangement.getTotalElementsPerReplica() == element_count) {
204204
tensor_like = DriverTensor(cbr->createReplicaSlice(type));
205205
}
206206
}
@@ -291,13 +291,16 @@ PartitionedElementwiseClusterVisitor::UpdateRemoteBufferInformation(
291291
cbr_info->host_rearrangement_id) ==
292292
remote_parameter_host_rearrangements.end()) {
293293
RemoteParameterHostRearrangement host_rearrangement;
294-
host_rearrangement.replication_factor = src.replicationFactor;
295-
host_rearrangement.total_elements_per_replica = src.totalElementsPerReplica;
296-
for (auto& interval : src.gatheredToRefSlices) {
294+
host_rearrangement.replication_factor = src.getReplicationFactor();
295+
host_rearrangement.total_elements_per_replica =
296+
src.getTotalElementsPerReplica();
297+
for (const auto& interval : src.getGatheredToRefSlices()) {
297298
host_rearrangement.gathered_to_ref_slice.emplace_back(interval.begin(),
298299
interval.end());
299300
}
300-
host_rearrangement.element_map = src.elementMap;
301+
302+
host_rearrangement.element_map = src.getElementMap();
303+
301304
remote_parameter_host_rearrangements[cbr_info->host_rearrangement_id] =
302305
std::move(host_rearrangement);
303306
}
@@ -314,7 +317,7 @@ PartitionedElementwiseClusterVisitor::UpdateRemoteBufferInformation(
314317
for (auto param_idx : merged_params) {
315318
TF_RETURN_IF_ERROR(SetRemoteBufferHostRearrangementId(
316319
graph, entry_comp, param_idx, cbr_info->host_rearrangement_id,
317-
src.totalElementsPerReplica));
320+
src.getTotalElementsPerReplica()));
318321
}
319322

320323
return true;

tensorflow/compiler/plugin/poplar/tests/replicated_resource_update_elementwise_clustering_hw_test.cc

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -231,26 +231,29 @@ class ReplicatedResourceUpdateElementwiseClusteringHwTest
231231
info.host_rearrangement_id);
232232
CHECK(hr_it !=
233233
annotations.remote_parameter_host_rearrangements.end());
234-
auto& host_rearrangement = hr_it->second;
234+
const auto& host_rearrangement = hr_it->second;
235235
auto& gcl = host_rearrangements[index];
236-
gcl.replicationFactor = host_rearrangement.replication_factor;
237-
gcl.totalElementsPerReplica =
238-
host_rearrangement.total_elements_per_replica;
239-
for (auto& slice : host_rearrangement.gathered_to_ref_slice) {
240-
gcl.gatheredToRefSlices.emplace_back(slice.first, slice.second);
236+
gcl.setReplicationFactor(host_rearrangement.replication_factor);
237+
gcl.setTotalElementsPerReplica(
238+
host_rearrangement.total_elements_per_replica);
239+
std::vector<poplar::Interval> gathered_to_ref_slices =
240+
gcl.getGatheredToRefSlices();
241+
242+
for (const auto& slice : host_rearrangement.gathered_to_ref_slice) {
243+
gathered_to_ref_slices.emplace_back(slice.first, slice.second);
241244
}
242-
gcl.elementMap = host_rearrangement.element_map;
243245

244-
CHECK_EQ(gcl.replicationFactor, param.replication_factor);
245-
per_replica_size = gcl.totalElementsPerReplica;
246-
aligned_size = per_replica_size * gcl.replicationFactor;
246+
gcl.setGatheredToRefSlices(gathered_to_ref_slices);
247+
gcl.setElementMap(host_rearrangement.element_map);
248+
249+
CHECK_EQ(gcl.getReplicationFactor(), param.replication_factor);
250+
per_replica_size = gcl.getTotalElementsPerReplica();
251+
aligned_size = per_replica_size * gcl.getReplicationFactor();
247252
buffer.resize(aligned_size);
248253

249254
std::vector<float> tmp(aligned_size);
250255
VLOG(1) << "Rearranging for collective...";
251-
gcl.rearrangeForCollective(
252-
reinterpret_cast<const char*>(buffer.data()),
253-
reinterpret_cast<char*>(tmp.data()), 4);
256+
gcl.rearrangeForCollective(buffer, tmp);
254257
buffer = std::move(tmp);
255258
}
256259

@@ -320,9 +323,9 @@ class ReplicatedResourceUpdateElementwiseClusteringHwTest
320323
auto host_rearrangement_it = host_rearrangements.find(index);
321324
if (host_rearrangement_it != host_rearrangements.end()) {
322325
auto& host_rearrangement = host_rearrangement_it->second;
323-
per_replica_size = host_rearrangement.totalElementsPerReplica;
326+
per_replica_size = host_rearrangement.getTotalElementsPerReplica();
324327
aligned_size =
325-
per_replica_size * host_rearrangement.replicationFactor;
328+
per_replica_size * host_rearrangement.getReplicationFactor();
326329
}
327330

328331
VLOG(1) << "Downloading data from " << info.buffer_name
@@ -341,9 +344,7 @@ class ReplicatedResourceUpdateElementwiseClusteringHwTest
341344
EXPECT_TRUE(cluster);
342345
std::vector<float> tmp(buffer.size());
343346
VLOG(1) << "Undo rearrangement for collective...";
344-
host_rearrangement.undoRearrangeForCollective(
345-
reinterpret_cast<const char*>(buffer.data()),
346-
reinterpret_cast<char*>(tmp.data()), 4);
347+
host_rearrangement.undoRearrangeForCollective(buffer, tmp);
347348
buffer = std::move(tmp);
348349
}
349350
buffer.resize(size);

0 commit comments

Comments
 (0)