1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <algorithm>
17 #include <iterator>
18 #include <memory>
19 #include <tuple>
20 #include <utility>
21
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/DenseMap.h"
24 #include "llvm/ADT/DenseSet.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/StringRef.h"
29 #include "llvm/ADT/iterator_range.h"
30 #include "llvm/Support/Casting.h"
31 #include "mlir/IR/Attributes.h" // from @llvm-project
32 #include "mlir/IR/Builders.h" // from @llvm-project
33 #include "mlir/IR/Identifier.h" // from @llvm-project
34 #include "mlir/IR/MLIRContext.h" // from @llvm-project
35 #include "mlir/IR/Operation.h" // from @llvm-project
36 #include "mlir/IR/Types.h" // from @llvm-project
37 #include "mlir/IR/Value.h" // from @llvm-project
38 #include "mlir/Pass/Pass.h" // from @llvm-project
39 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
40 #include "mlir/Support/LogicalResult.h" // from @llvm-project
41 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
42 #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
43 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
44 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
45 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
46
47 namespace mlir {
48 namespace TFTPU {
49
50 namespace {
51
// Attribute identifying which TPU computation (cluster) an op belongs to; its
// string value is the cluster name shared with the TPUReplicateMetadata op.
constexpr char kTPUReplicateAttr[] = "_tpu_replicate";
// Device placement attribute; removed from ops once they are wrapped in a
// `tf_device.cluster`.
constexpr char kDeviceAttr[] = "device";
// Name attribute on TPUReplicateMetadata ops; dropped when collecting metadata
// since it is op specific.
constexpr char kNameAttr[] = "name";
// Metadata attribute: number of cores each replica is partitioned across
// (model parallelism).
constexpr char kNumCoresPerReplicaAttr[] = "num_cores_per_replica";
// Metadata attribute: number of replicas (data parallelism); a value > 1
// triggers wrapping the cluster in a `tf_device.replicate`.
constexpr char kNumReplicasAttr[] = "num_replicas";
// Attribute set on the created `tf_device.replicate` holding the `index`
// values of the replicated/packed inputs (used later by the dynamic padder).
constexpr char kReplicatedInputIndicesAttr[] = "_replicated_input_indices";
// Attribute set on the created `tf_device.replicate` holding block argument
// indices that correspond to mirrored variables.
constexpr char kMirroredVariableIndicesAttr[] = "_mirrored_variable_indices";

// Error message used when a TPUReplicateMetadata op lacks a usable
// `_tpu_replicate` string attribute.
constexpr char kBadTPUReplicateAttrMsg[] =
    "requires '_tpu_replicate' string attribute";

// Mapping for `_tpu_replicate` attribute to TPUReplicateMetadata attributes.
using MetadataMap = llvm::SmallDenseMap<llvm::StringRef, NamedAttrList, 8>;

// A set of operations in a cluster.
using ClusterOps = llvm::SmallSetVector<Operation*, 8>;

// Mapping for `_tpu_replicate` attribute to ops of a cluster.
using ClusterMap = llvm::SmallDenseMap<llvm::StringRef, ClusterOps, 8>;
71
// Pass that groups ops sharing a `_tpu_replicate` attribute into
// `tf_device.cluster` ops (and `tf_device.replicate` ops when `num_replicas`
// > 1). See FormClustersInBlock for the per-block algorithm.
struct TPUClusterFormationPass
    : public TF::TPUClusterFormationPassBase<TPUClusterFormationPass> {
  // This pass creates tf_device ops, so the tf_device dialect must be
  // registered as a dependency before the pass runs.
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<tf_device::TensorFlowDeviceDialect>();
  }

  void runOnOperation() override;
};
80
81 // Creates a mapping from the TPUReplicateMetadata ops `_tpu_replicate`
82 // attribute to its attributes and removes the ops. If multiple
83 // TPUReplicateMetadata ops have the same `_tpu_replicate` attribute, an error
84 // will be returned.
CollectMetadata(Block * block,MetadataMap * metadata_map)85 LogicalResult CollectMetadata(Block* block, MetadataMap* metadata_map) {
86 // Just look at top-level operations in the block (not nested ones)
87 for (Operation& op : llvm::make_early_inc_range(*block)) {
88 auto metadata_op = dyn_cast<TF::TPUReplicateMetadataOp>(op);
89 if (!metadata_op) continue;
90
91 NamedAttrList attrs(metadata_op->getAttrDictionary());
92
93 // Missing or bad `_tpu_replicate` attribute.
94 auto tpu_replicate_attr = attrs.get(kTPUReplicateAttr);
95 if (!tpu_replicate_attr)
96 return metadata_op.emitError() << kBadTPUReplicateAttrMsg;
97
98 auto tpu_replicate_attr_str = tpu_replicate_attr.dyn_cast<StringAttr>();
99 if (!tpu_replicate_attr_str || tpu_replicate_attr_str.getValue().empty())
100 return metadata_op.emitError() << kBadTPUReplicateAttrMsg;
101
102 // Remove `name` attribute.
103 attrs.erase(Identifier::get(kNameAttr, metadata_op.getContext()));
104
105 auto it = metadata_map->try_emplace(tpu_replicate_attr_str.getValue(),
106 std::move(attrs));
107
108 // There are multiple TPUReplicateMetadata ops with the same
109 // `_tpu_replicate` attribute.
110 if (!it.second) {
111 return metadata_op.emitError()
112 << "multiple TPUReplicateMetadata ops with the same '"
113 << kTPUReplicateAttr << "' attribute '"
114 << tpu_replicate_attr_str.getValue() << "' found";
115 }
116 metadata_op.erase();
117 }
118 return success();
119 }
120
121 // Collects and clusters ops with the same `_tpu_replicate` attribute. This will
122 // return an error if a `_tpu_replicate` attribute of an op is empty.
CollectAndGroupClusterOps(Block * block,ClusterMap * clusters)123 LogicalResult CollectAndGroupClusterOps(Block* block, ClusterMap* clusters) {
124 for (Operation& op : *block) {
125 if (auto attr = op.getAttrOfType<StringAttr>(kTPUReplicateAttr)) {
126 if (attr.getValue().empty())
127 return op.emitError()
128 << "attribute '" << kTPUReplicateAttr << "' is empty";
129
130 auto it = clusters->try_emplace(attr.getValue());
131 it.first->getSecond().insert(&op);
132 }
133 }
134
135 return success();
136 }
137
138 // Collects all resource ids from an op.
CollectResourceIdsFromOp(Operation & op,const TF::ResourceAliasAnalysis::Info & resource_alias_analysis,llvm::SmallDenseSet<int64_t> & observed_resource_ids)139 void CollectResourceIdsFromOp(
140 Operation& op,
141 const TF::ResourceAliasAnalysis::Info& resource_alias_analysis,
142 llvm::SmallDenseSet<int64_t>& observed_resource_ids) {
143 op.walk([&](Operation* inner_op) {
144 for (Value operand : TF::filter_resources(inner_op->getOperands())) {
145 if (resource_alias_analysis.IsUnknownResource(operand)) continue;
146 const auto& ids = resource_alias_analysis.GetResourceUniqueIds(operand);
147 observed_resource_ids.insert(ids.begin(), ids.end());
148 }
149 for (Value result : TF::filter_resources(inner_op->getResults())) {
150 if (resource_alias_analysis.IsUnknownResource(result)) continue;
151 const auto& ids = resource_alias_analysis.GetResourceUniqueIds(result);
152 observed_resource_ids.insert(ids.begin(), ids.end());
153 }
154 });
155 }
156
157 // Checks if an op should be moved after a cluster. There may be users of a
158 // cluster interleaved among the cluster ops.
// Checks if an op should be moved after a cluster. There may be users of a
// cluster interleaved among the cluster ops.
//
// Returns true when `op` (or any op nested inside it) either consumes a value
// produced by a cluster op / an already-collected preceding user, or touches a
// resource whose id has been observed in the cluster so far — i.e. `op` has a
// dependency on the cluster and must be relocated after it.
bool ShouldMoveOpAfterCluster(
    Block* block, Operation* op, const ClusterOps& cluster_ops,
    const llvm::SmallSetVector<Operation*, 8>& preceding_users,
    const TF::ResourceAliasAnalysis::Info& resource_alias_analysis,
    const llvm::SmallDenseSet<int64_t>& observed_resource_ids) {
  const bool is_replicate = llvm::isa<tf_device::ReplicateOp>(op);
  auto result = op->walk([&](Operation* inner_op) {
    for (Value operand : inner_op->getOperands()) {
      Operation* def = operand.getDefiningOp();
      // Operands may not have a defining op (BlockArgument) or is from a
      // different block.
      if (!def || def->getBlock() != block) continue;

      if (cluster_ops.count(def) != 0 || preceding_users.count(def) != 0) {
        // Op is a user of a cluster or another op that is a user of the
        // cluster (transitively), but is before the cluster.
        return WalkResult::interrupt();
      }
    }

    // Don't visit replicate op inner op operands as new resource
    // values/arguments may have been created but are not known in
    // `resource_alias_analysis`.
    if (is_replicate && inner_op != op) return WalkResult::advance();

    // Check for uses of any resource in or after cluster.
    for (Value operand : TF::filter_resources(inner_op->getOperands())) {
      if (resource_alias_analysis.IsUnknownResource(operand)) continue;
      auto ids = resource_alias_analysis.GetResourceUniqueIds(operand);
      for (const auto& id : ids)
        if (observed_resource_ids.contains(id)) return WalkResult::interrupt();
    }
    return WalkResult::advance();
  });

  // An interrupted walk signals that a data or resource dependency on the
  // cluster was found.
  return result.wasInterrupted();
}
196
197 // Collects ops that are before ops in the cluster but are users of other ops
198 // in the cluster. This may happen because users of individual ops in the
199 // cluster may be interleaved with other ops in the cluster. Resource id's are
200 // also captured, to keep track of resource usage before, in, or after the
201 // cluster.
202 // TODO(b/175701589): Extend this to handle all side effecting ops while
203 // handling transitive data dependencies.
// Collects ops that are before ops in the cluster but are users of other ops
// in the cluster. This may happen because users of individual ops in the
// cluster may be interleaved with other ops in the cluster. Resource id's are
// also captured, to keep track of resource usage before, in, or after the
// cluster.
// TODO(b/175701589): Extend this to handle all side effecting ops while
// handling transitive data dependencies.
llvm::SmallSetVector<Operation*, 8> CollectClusterPrecedingUsers(
    Block* block, const ClusterOps& cluster_ops,
    const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
  llvm::SmallSetVector<Operation*, 8> preceding_users;
  llvm::SmallDenseSet<int64_t> observed_resource_ids;

  auto front = Block::iterator(cluster_ops.front());
  auto back = Block::iterator(cluster_ops.back());
  // NOTE: the range is half-open [front, back), so the last cluster op itself
  // is not visited — nothing after it can be "interleaved" with the cluster,
  // and its resource ids are not needed once iteration stops.
  for (Operation& op : llvm::make_range(front, back)) {
    if (cluster_ops.contains(&op)) {
      // Op is part of the cluster: record the resources it touches so later
      // non-cluster ops can be checked against them.
      CollectResourceIdsFromOp(op, resource_alias_analysis,
                               observed_resource_ids);
    } else if (ShouldMoveOpAfterCluster(
                   block, &op, cluster_ops, preceding_users,
                   resource_alias_analysis, observed_resource_ids)) {
      // Interleaved op depends on the cluster (or on a previously found
      // preceding user); it must be moved after the cluster, and its resource
      // usage also becomes part of the observed set.
      preceding_users.insert(&op);
      CollectResourceIdsFromOp(op, resource_alias_analysis,
                               observed_resource_ids);
    }
  }

  return preceding_users;
}
227
228 // Collects results and associated types of the cluster that are used outside of
229 // the cluster. These results and types are used to create the clusters
230 // `tf_device.cluster` and associated terminator. Results that have no uses
231 // outside of the cluster (i.e. results of ops in the cluster are only consumed
232 // by other ops in the cluster) are pruned.
CollectClusterResults(Block * block,const ClusterOps & cluster_ops)233 llvm::SmallVector<Value, 8> CollectClusterResults(
234 Block* block, const ClusterOps& cluster_ops) {
235 llvm::SmallVector<Value, 8> results;
236
237 for (Operation* op : cluster_ops) {
238 for (Value result : op->getResults()) {
239 for (Operation* user : result.getUsers()) {
240 // Check if user is not an op in the cluster.
241 if (cluster_ops.count(block->findAncestorOpInBlock(*user)) == 0) {
242 results.push_back(result);
243 break;
244 }
245 }
246 }
247 }
248
249 return results;
250 }
251
252 // Creates a `tf_device.cluster` to wrap cluster ops.
// Creates a `tf_device.cluster` to wrap cluster ops.
//
// `results` are the cluster-op results used outside the cluster (from
// CollectClusterResults); they become the results of the new
// `tf_device.cluster`. `preceding_users` (from CollectClusterPrecedingUsers)
// are moved after the cluster to preserve dominance.
tf_device::ClusterOp CreateClusterOp(
    Block* block, const ClusterOps& cluster_ops, llvm::ArrayRef<Value> results,
    llvm::ArrayRef<Operation*> preceding_users) {
  // `tf_device.cluster` will be placed at where the last op of the cluster is.
  Operation* last_cluster_op = cluster_ops.back();
  OpBuilder builder(last_cluster_op);

  llvm::SmallVector<Type, 8> result_types;
  for (Value result : results) result_types.push_back(result.getType());
  auto cluster = builder.create<tf_device::ClusterOp>(last_cluster_op->getLoc(),
                                                      result_types);

  // The cluster body block is created empty; ownership passes to the region.
  Block* body = new Block;
  cluster.body().push_back(body);

  // Move cluster ops to the cluster body. Also remove `_tpu_replicate` and
  // `device` attribute from ops in the cluster as that information will be
  // present in the `tf_device.cluster`. Do this for all ops including nested
  // ops.
  for (Operation* cluster_op : cluster_ops) {
    cluster_op->moveBefore(body, body->end());
    cluster_op->walk([&](Operation* inner_op) {
      inner_op->removeAttr(kTPUReplicateAttr);
      inner_op->removeAttr(kDeviceAttr);
    });
  }

  // Add terminator.
  builder.setInsertionPointToEnd(body);
  builder.create<tf_device::ReturnOp>(last_cluster_op->getLoc(), results);

  // Replaces uses of cluster ops results outside of cluster with the associated
  // `tf_device.cluster` results.
  for (auto ret_vals : llvm::zip(results, cluster.getResults())) {
    Value old_ret = std::get<0>(ret_vals);
    Value new_ret = std::get<1>(ret_vals);
    // Early-inc iteration: `use.set` rewrites the use list being traversed.
    for (auto& use : llvm::make_early_inc_range(old_ret.getUses())) {
      Operation* user = use.getOwner();
      // Uses inside the cluster body (including the terminator) keep the
      // original value; only external uses are redirected.
      if (!body->findAncestorOpInBlock(*user)) use.set(new_ret);
    }
  }

  // Move users of cluster that are before the cluster to after the cluster.
  Operation* op_after_cluster = cluster.getOperation()->getNextNode();
  for (Operation* user : preceding_users) user->moveBefore(op_after_cluster);
  return cluster;
}
300
301 // Sorts `tf.TPUReplicatedInput` ops by `index` attribute. Ops with an `index`
302 // of -1 are always after ops with a non negative `index`, and an arbitrary
303 // ordering is used as there are no dependencies on their relative ordering. If
304 // there are multiple `tf.TPUReplicatedInput` ops with the same non negative
305 // index or if indices are less than -1, an error will be returned.
SortTPUReplicatedInputsByIndex(llvm::ArrayRef<Operation * > inputs,llvm::SmallVectorImpl<Operation * > * sorted_inputs)306 LogicalResult SortTPUReplicatedInputsByIndex(
307 llvm::ArrayRef<Operation*> inputs,
308 llvm::SmallVectorImpl<Operation*>* sorted_inputs) {
309 llvm::SmallDenseSet<int64_t, 8> unique_indices;
310 for (Operation* input : inputs) {
311 int64_t index = llvm::cast<TF::TPUReplicatedInputOp>(input).index();
312 if (index < -1)
313 return input->emitOpError()
314 << "requires index to be at least -1, but got " << index;
315 if (index == -1) continue;
316 if (!unique_indices.insert(index).second)
317 return input->emitOpError()
318 << "requires indices to be unique, but found multiple '"
319 << input->getName() << "' ops with index " << index;
320 }
321
322 // Sort all TPUReplicatedInputs by `index` attribute to have
323 // TPUReplicatedInputs with indices be added to the `tf_device.replicate` op
324 // deterministically. If `index` attribute is -1, instead move them to the
325 // end.
326 sorted_inputs->assign(inputs.begin(), inputs.end());
327 std::stable_sort(
328 sorted_inputs->begin(), sorted_inputs->end(),
329 [](Operation* l, Operation* r) {
330 int64_t l_index = llvm::cast<TF::TPUReplicatedInputOp>(l).index();
331 int64_t r_index = llvm::cast<TF::TPUReplicatedInputOp>(r).index();
332 if (l_index == -1 && r_index != -1) return false;
333 if (r_index == -1 && l_index != -1) return true;
334 return l_index < r_index;
335 });
336
337 return success();
338 }
339
340 // Creates a `tf_device.replicate` to represent replication for the cluster, if
341 // necessary.
// Creates a `tf_device.replicate` to represent replication for the cluster, if
// necessary.
//
// When `num_replicas` > 1, the cluster's `tf.TPUReplicatedInput` feeders are
// turned into `tf_device.replicate` operands (replicated or packed), cluster
// results are fanned out through `tf.TPUReplicatedOutput` users, and the
// cluster is moved inside the replicate body. `num_cores_per_replica` is used
// to validate `tf.TPUPartitionedInput` ops when model parallelism is combined
// with data parallelism.
LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas,
                               int num_cores_per_replica) {
  // No need to replicate.
  if (num_replicas == 1) return success();

  if (num_replicas < 1)
    return cluster.emitError() << "requires '" << kNumReplicasAttr
                               << "' int attribute to be at least 1";

  // `status` accumulates an error from inside the visitor callback below,
  // which cannot return a LogicalResult directly.
  LogicalResult status = success();
  // Collect all used TPUReplicatedInput ops and sort by `index`.
  llvm::SmallSetVector<Operation*, 8> unique_replicated_input_ops;
  mlir::visitUsedValuesDefinedAbove(
      cluster.body(), cluster.body(), [&](mlir::OpOperand* operand) {
        Operation* def = operand->get().getDefiningOp();
        if (llvm::isa_and_nonnull<TF::TPUReplicatedInputOp>(def))
          unique_replicated_input_ops.insert(def);
        // When model parallelism is used in conjunction with data parallelism
        // for resource inputs, we need to collect the per replica resource
        // inputs from input to `tf.TPUPartitionedInput` ops.
        if (auto pi = llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>(def)) {
          if (pi->getNumOperands() != num_cores_per_replica)
            status = pi.emitOpError()
                     << "requires " << num_cores_per_replica
                     << " operands but found " << pi->getNumOperands();
          for (auto operand : pi.inputs()) {
            if (llvm::isa_and_nonnull<TF::TPUReplicatedInputOp>(
                    operand.getDefiningOp()))
              unique_replicated_input_ops.insert(operand.getDefiningOp());
          }
        }
      });

  if (failed(status)) return failure();
  llvm::SmallVector<Operation*, 8> replicated_input_ops;
  if (failed(SortTPUReplicatedInputsByIndex(
          unique_replicated_input_ops.getArrayRef(), &replicated_input_ops)))
    return failure();

  // Index attribute value stored on TPUReplicatedInput op. These will be used
  // later for dynamic padder.
  llvm::SmallVector<int64_t, 8> replicated_input_indices;
  llvm::SmallVector<int64_t, 8> packed_input_indices;
  bool has_replicated_input_index = false;

  // Indices of the replicate op's arguments that are mirrored variables.
  llvm::SmallVector<int64_t, 8> mirrored_variable_indices;

  // Check if number of operands of each used TPUReplicatedInput op matches
  // `num_replicas` or 1. Collect all their operands and associated type for
  // creating the replicate op.
  llvm::SmallVector<std::pair<ValueRange, Type>, 8> replicated_inputs;
  llvm::SmallVector<Value, 8> packed_inputs;
  for (auto& pos_and_input : llvm::enumerate(replicated_input_ops)) {
    auto input = pos_and_input.value();
    bool is_packed = llvm::cast<TF::TPUReplicatedInputOp>(input).is_packed();
    const int num_operands = input->getNumOperands();
    // Packed inputs carry one value broadcast to all replicas; replicated
    // inputs carry one value per replica.
    int num_inputs = is_packed ? 1 : num_replicas;
    if (num_operands != num_inputs)
      return input->emitOpError() << "requires " << num_inputs << " operands";

    auto tpu_replicated_input = llvm::cast<TF::TPUReplicatedInputOp>(input);
    int64_t tpu_replicated_input_index = tpu_replicated_input.index();
    if (is_packed) {
      packed_inputs.push_back(input->getOperand(0));
      packed_input_indices.push_back(tpu_replicated_input_index);
    } else {
      replicated_inputs.push_back(
          {input->getOperands(), input->getOperand(0).getType()});
      replicated_input_indices.push_back(tpu_replicated_input_index);
    }
    if (tpu_replicated_input_index != -1) has_replicated_input_index = true;

    if (tpu_replicated_input.is_mirrored_variable())
      mirrored_variable_indices.push_back(pos_and_input.index());
  }

  // Packed indices go after the replicated ones, matching the replicate op's
  // operand ordering (replicated inputs first, then packed inputs).
  replicated_input_indices.append(packed_input_indices.begin(),
                                  packed_input_indices.end());

  // Create replicate op.
  OpBuilder builder(cluster);
  auto replicate_op = builder.create<tf_device::ReplicateOp>(
      cluster.getLoc(), num_replicas,
      llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>(),
      replicated_inputs, packed_inputs, cluster.getResultTypes());
  // Only attach `_replicated_input_indices` when at least one input carried an
  // explicit (non -1) index.
  if (has_replicated_input_index)
    replicate_op->setAttr(kReplicatedInputIndicesAttr,
                          builder.getI64ArrayAttr(replicated_input_indices));

  if (!mirrored_variable_indices.empty())
    replicate_op->setAttr(kMirroredVariableIndicesAttr,
                          builder.getI64ArrayAttr(mirrored_variable_indices));

  // Replace replicated cluster results with replicate op results.
  for (auto result_and_idx : llvm::enumerate(cluster.getResults())) {
    Value result = result_and_idx.value();
    int idx = result_and_idx.index();
    // Replicate op results are laid out `num_replicas` per cluster result.
    auto replicate_outputs = llvm::make_range(
        std::next(replicate_op.result_begin(), idx * num_replicas),
        std::next(replicate_op.result_begin(), (idx + 1) * num_replicas));

    for (auto& use : llvm::make_early_inc_range(result.getUses())) {
      Operation* def = use.getOwner();
      if (!llvm::isa<TF::TPUReplicatedOutputOp>(def)) {
        // If user is not a `tf.TPUReplicatedOutput`, simply forward the first
        // replica output. Certain Graphs under V1 create `tf.Identity` users of
        // replicated ops to pin the TPU computation for execution.
        use.set(*replicate_outputs.begin());
        continue;
      }

      const int def_num_results = def->getNumResults();
      if (def_num_results != num_replicas)
        return def->emitOpError() << "requires " << num_replicas << " results";

      def->replaceAllUsesWith(replicate_outputs);
    }
  }

  // Collect all `tf.TPUPartitionedInput` ops to be moved inside the
  // `tf_device.replicate` later.
  llvm::SmallSet<Operation*, 4> partitioned_inputs;
  // Update replicated inputs with replicate op block arguments.
  for (auto input_and_block_arg :
       llvm::zip(replicated_input_ops, replicate_op.GetBody().getArguments())) {
    Operation* input = std::get<0>(input_and_block_arg);
    Value block_arg = std::get<1>(input_and_block_arg);
    mlir::replaceAllUsesInRegionWith(input->getResult(0), block_arg,
                                     cluster.body());
    // Update replicated input use in tf.TPUPartitionedInput op.
    for (auto& use : input->getUses()) {
      auto pi = llvm::dyn_cast<TF::TPUPartitionedInputOp>(use.getOwner());
      if (pi) {
        pi.setOperand(use.getOperandNumber(), block_arg);
        partitioned_inputs.insert(pi.getOperation());
      }
    }
  }

  // Create terminator for replicate op and move `tf_device.cluster` and
  // `tf.TPUPartitionedInput`(s) into replicate body.
  builder.setInsertionPointToEnd(&replicate_op.GetBody());
  auto return_op = builder.create<tf_device::ReturnOp>(replicate_op.getLoc(),
                                                       cluster.getResults());
  for (auto pi : partitioned_inputs) pi->moveBefore(return_op);

  cluster.getOperation()->moveBefore(return_op);

  return success();
}
493
494 // Forms clusters with ops of the same `_tpu_replicate` attribute under a block.
495 //
496 // For a given block, clusters are formed via grouping ops by `_tpu_replicate`
497 // attributes.
498 // For every cluster formed:
499 // 1. Find associated TPUReplicateMetadata attributes with the same
500 // `_tpu_replicate` attribute.
501 // 2. Find users not in cluster that are interleaved between cluster ops.
502 // 3. Find external uses of cluster ops.
503 // 4. Create `tf_device.cluster` with results consisting of the external uses
504 // of cluster ops determined at 3.
505 // 5. Move cluster ops to `tf_device.cluster` body.
506 // 6. Replace external uses of cluster ops uses with `tf_device.cluster`
507 // results.
508 // 7. Move users from 2 to after the `tf_device.cluster`.
509 // 8. Wrap cluster (`tf_device.cluster`) in a `tf_device.replicate` if
510 // attribute `num_replicas` is greater than 1.
511 // 9. Copy over TPUReplicateMetadata attributes to `tf_device.cluster`.
// Forms clusters with ops of the same `_tpu_replicate` attribute under a block.
//
// For a given block, clusters are formed via grouping ops by `_tpu_replicate`
// attributes.
// For every cluster formed:
//   1. Find associated TPUReplicateMetadata attributes with the same
//      `_tpu_replicate` attribute.
//   2. Find users not in cluster that are interleaved between cluster ops.
//   3. Find external uses of cluster ops.
//   4. Create `tf_device.cluster` with results consisting of the external uses
//      of cluster ops determined at 3.
//   5. Move cluster ops to `tf_device.cluster` body.
//   6. Replace external uses of cluster ops uses with `tf_device.cluster`
//      results.
//   7. Move users from 2 to after the `tf_device.cluster`.
//   8. Wrap cluster (`tf_device.cluster`) in a `tf_device.replicate` if
//      attribute `num_replicas` is greater than 1.
//   9. Copy over TPUReplicateMetadata attributes to `tf_device.cluster`.
LogicalResult FormClustersInBlock(
    Block* block,
    const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
  MetadataMap metadata_map;
  LogicalResult result = CollectMetadata(block, &metadata_map);
  if (failed(result)) return result;

  // If there is no TPUReplicateMetadata op in this block, process blocks in
  // regions attached to the op's in the block.
  if (metadata_map.empty()) {
    for (Operation& op : *block) {
      for (Region& region : op.getRegions()) {
        if (!llvm::hasSingleElement(region))
          return op.emitOpError("Expected single block region");
        if (failed(
                FormClustersInBlock(&region.front(), resource_alias_analysis)))
          return failure();
      }
    }
    return success();
  }

  ClusterMap clusters;
  result = CollectAndGroupClusterOps(block, &clusters);
  if (failed(result)) return result;

  for (const auto& cluster_metadata_and_ops : clusters) {
    const auto& cluster_ops = cluster_metadata_and_ops.getSecond();

    auto cluster_metadata =
        metadata_map.find(cluster_metadata_and_ops.getFirst());

    // No TPUReplicateMetadata for a `_tpu_replicate` attribute. This is only a
    // warning: the cluster is skipped rather than failing the pass.
    if (cluster_metadata == metadata_map.end()) {
      cluster_ops.front()->emitWarning()
          << "TPUReplicateMetadata for associated '" << kTPUReplicateAttr
          << "' attribute '" << cluster_metadata_and_ops.getFirst()
          << "' is missing";
      continue;
    }

    // Steps 2-3: find interleaved users and externally used results.
    llvm::SmallSetVector<Operation*, 8> preceding_users =
        CollectClusterPrecedingUsers(block, cluster_ops,
                                     resource_alias_analysis);

    llvm::SmallVector<Value, 8> results =
        CollectClusterResults(block, cluster_ops);

    // Steps 4-7: materialize the `tf_device.cluster` and fix up uses.
    tf_device::ClusterOp cluster = CreateClusterOp(
        block, cluster_ops, results, preceding_users.getArrayRef());

    auto num_replicas = cluster_metadata->getSecond().get(kNumReplicasAttr);
    if (!num_replicas || !num_replicas.isa<mlir::IntegerAttr>())
      return cluster.emitError()
             << "requires '" << kNumReplicasAttr << "' int attribute";

    // `num_cores_per_replica` defaults to 1 when absent or non-integer.
    int num_cores_per_replica = 1;
    auto num_cores_per_replica_attr =
        cluster_metadata->getSecond()
            .get(kNumCoresPerReplicaAttr)
            .dyn_cast_or_null<mlir::IntegerAttr>();
    if (num_cores_per_replica_attr)
      num_cores_per_replica = num_cores_per_replica_attr.getInt();

    // Step 8: wrap in `tf_device.replicate` when num_replicas > 1.
    if (failed(ReplicateCluster(cluster,
                                num_replicas.cast<mlir::IntegerAttr>().getInt(),
                                num_cores_per_replica)))
      return failure();

    // Copy TPUReplicateMetadata attributes to `tf_device.cluster`.
    cluster->setAttrs(
        cluster_metadata->second.getDictionary(cluster.getContext()));
    // Exclude `num_replicas` as cluster should be replicated if necessary.
    cluster.removeAttr(kNumReplicasAttr);
  }

  return success();
}
590
FormClustersInFunction(FuncOp func,const TF::ResourceAliasAnalysis::Info & resource_alias_analysis)591 LogicalResult FormClustersInFunction(
592 FuncOp func,
593 const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
594 if (!llvm::hasSingleElement(func))
595 return func.emitOpError("Expecting a single block function");
596
597 if (failed(FormClustersInBlock(&func.front(), resource_alias_analysis)))
598 return failure();
599
600 // Remove TPUReplicatedInput and TPUReplicatedOutput nodes.
601 auto remove_result = func.walk([&](Operation* op) {
602 if (!llvm::isa<TF::TPUReplicatedInputOp, TF::TPUReplicatedOutputOp>(op))
603 return WalkResult::advance();
604
605 // Forward operand to result. When `num_replicas` attribute is 1, no
606 // `tf_device.replicate` is created and replicated (1) operands/results are
607 // untouched.
608 if (op->getNumOperands() == 1 && op->getNumResults() == 1)
609 op->getResult(0).replaceAllUsesWith(op->getOperand(0));
610
611 // Leftover TPUReplicatedInput/TPUReplicatedOutput that are not of
612 // `num_replicas` to 1.
613 if (!op->use_empty()) {
614 op->emitOpError() << "is expected to have no uses, but it is operand#"
615 << op->use_begin()->getOperandNumber() << " of "
616 << *op->use_begin()->getOwner();
617 return WalkResult::interrupt();
618 }
619
620 op->erase();
621
622 return WalkResult::advance();
623 });
624
625 return failure(remove_result.wasInterrupted());
626 }
627
runOnOperation()628 void TPUClusterFormationPass::runOnOperation() {
629 auto& resource_alias_analysis = getAnalysis<TF::ResourceAliasAnalysis>();
630 for (auto func : getOperation().getOps<FuncOp>())
631 if (!func.isExternal() &&
632 failed(FormClustersInFunction(
633 func, resource_alias_analysis.GetAnalysisForFunc(func))))
634 return signalPassFailure();
635 }
636 } // anonymous namespace
637
CreateTPUClusterFormationPass()638 std::unique_ptr<OperationPass<ModuleOp>> CreateTPUClusterFormationPass() {
639 return std::make_unique<TPUClusterFormationPass>();
640 }
641
642 } // namespace TFTPU
643 } // namespace mlir
644