/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <iterator>
#include <memory>
#include <tuple>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Identifier.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"

namespace mlir {
namespace TFTPU {

namespace {

constexpr char kTPUReplicateAttr[] = "_tpu_replicate";
constexpr char kDeviceAttr[] = "device";
constexpr char kNameAttr[] = "name";
constexpr char kNumCoresPerReplicaAttr[] = "num_cores_per_replica";
constexpr char kNumReplicasAttr[] = "num_replicas";
constexpr char kReplicatedInputIndicesAttr[] = "_replicated_input_indices";
constexpr char kMirroredVariableIndicesAttr[] = "_mirrored_variable_indices";

constexpr char kBadTPUReplicateAttrMsg[] =
    "requires '_tpu_replicate' string attribute";

// Mapping for `_tpu_replicate` attribute to TPUReplicateMetadata attributes.
using MetadataMap = llvm::SmallDenseMap<llvm::StringRef, NamedAttrList, 8>;

// A set of operations in a cluster.
using ClusterOps = llvm::SmallSetVector<Operation*, 8>;

// Mapping for `_tpu_replicate` attribute to ops of a cluster.
using ClusterMap = llvm::SmallDenseMap<llvm::StringRef, ClusterOps, 8>;

struct TPUClusterFormationPass
    : public TF::TPUClusterFormationPassBase<TPUClusterFormationPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<tf_device::TensorFlowDeviceDialect>();
  }

  void runOnOperation() override;
};

// Creates a mapping from each TPUReplicateMetadata op's `_tpu_replicate`
// attribute to the op's attributes, and removes the ops. If multiple
// TPUReplicateMetadata ops have the same `_tpu_replicate` attribute, an error
// will be returned.
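// As an illustrative sketch (attribute list abbreviated, not the op's full
// set of attributes), a metadata op in the input may look roughly like:
//
//   "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster0",
//                                num_replicas = 2 : i64,
//                                num_cores_per_replica = 1 : i64} : () -> ()
//
// and the resulting map entry would key "cluster0" to the op's attributes.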
LogicalResult CollectMetadata(Block* block, MetadataMap* metadata_map) {
  // Just look at top-level operations in the block (not nested ones)
  for (Operation& op : llvm::make_early_inc_range(*block)) {
    auto metadata_op = dyn_cast<TF::TPUReplicateMetadataOp>(op);
    if (!metadata_op) continue;

    NamedAttrList attrs(metadata_op->getAttrDictionary());

    // Missing or bad `_tpu_replicate` attribute.
    auto tpu_replicate_attr = attrs.get(kTPUReplicateAttr);
    if (!tpu_replicate_attr)
      return metadata_op.emitError() << kBadTPUReplicateAttrMsg;

    auto tpu_replicate_attr_str = tpu_replicate_attr.dyn_cast<StringAttr>();
    if (!tpu_replicate_attr_str || tpu_replicate_attr_str.getValue().empty())
      return metadata_op.emitError() << kBadTPUReplicateAttrMsg;

    // Remove `name` attribute.
    attrs.erase(Identifier::get(kNameAttr, metadata_op.getContext()));

    auto it = metadata_map->try_emplace(tpu_replicate_attr_str.getValue(),
                                        std::move(attrs));

    // There are multiple TPUReplicateMetadata ops with the same
    // `_tpu_replicate` attribute.
    if (!it.second) {
      return metadata_op.emitError()
             << "multiple TPUReplicateMetadata ops with the same '"
             << kTPUReplicateAttr << "' attribute '"
             << tpu_replicate_attr_str.getValue() << "' found";
    }
    metadata_op.erase();
  }
  return success();
}

// Collects and clusters ops with the same `_tpu_replicate` attribute. This
// will return an error if a `_tpu_replicate` attribute of an op is empty.
LogicalResult CollectAndGroupClusterOps(Block* block, ClusterMap* clusters) {
  for (Operation& op : *block) {
    if (auto attr = op.getAttrOfType<StringAttr>(kTPUReplicateAttr)) {
      if (attr.getValue().empty())
        return op.emitError()
               << "attribute '" << kTPUReplicateAttr << "' is empty";

      auto it = clusters->try_emplace(attr.getValue());
      it.first->getSecond().insert(&op);
    }
  }

  return success();
}

// Collects all resource ids from an op.
void CollectResourceIdsFromOp(
    Operation& op,
    const TF::ResourceAliasAnalysis::Info& resource_alias_analysis,
    llvm::SmallDenseSet<int64_t>& observed_resource_ids) {
  op.walk([&](Operation* inner_op) {
    for (Value operand : TF::filter_resources(inner_op->getOperands())) {
      if (resource_alias_analysis.IsUnknownResource(operand)) continue;
      const auto& ids = resource_alias_analysis.GetResourceUniqueIds(operand);
      observed_resource_ids.insert(ids.begin(), ids.end());
    }
    for (Value result : TF::filter_resources(inner_op->getResults())) {
      if (resource_alias_analysis.IsUnknownResource(result)) continue;
      const auto& ids = resource_alias_analysis.GetResourceUniqueIds(result);
      observed_resource_ids.insert(ids.begin(), ids.end());
    }
  });
}

// Checks if an op should be moved after a cluster. There may be users of a
// cluster interleaved among the cluster ops.
bool ShouldMoveOpAfterCluster(
    Block* block, Operation* op, const ClusterOps& cluster_ops,
    const llvm::SmallSetVector<Operation*, 8>& preceding_users,
    const TF::ResourceAliasAnalysis::Info& resource_alias_analysis,
    const llvm::SmallDenseSet<int64_t>& observed_resource_ids) {
  const bool is_replicate = llvm::isa<tf_device::ReplicateOp>(op);
  auto result = op->walk([&](Operation* inner_op) {
    for (Value operand : inner_op->getOperands()) {
      Operation* def = operand.getDefiningOp();
      // An operand may have no defining op (it is a BlockArgument) or may be
      // defined in a different block.
      if (!def || def->getBlock() != block) continue;

      if (cluster_ops.count(def) != 0 || preceding_users.count(def) != 0) {
        // Op is a user of a cluster op, or of another op that is a user of
        // the cluster (transitively), but appears before the cluster.
        return WalkResult::interrupt();
      }
    }

    // Don't visit replicate op inner op operands as new resource
    // values/arguments may have been created but are not known in
    // `resource_alias_analysis`.
    if (is_replicate && inner_op != op) return WalkResult::advance();

    // Check for uses of any resource in or after the cluster.
    for (Value operand : TF::filter_resources(inner_op->getOperands())) {
      if (resource_alias_analysis.IsUnknownResource(operand)) continue;
      auto ids = resource_alias_analysis.GetResourceUniqueIds(operand);
      for (const auto& id : ids)
        if (observed_resource_ids.contains(id)) return WalkResult::interrupt();
    }
    return WalkResult::advance();
  });

  return result.wasInterrupted();
}

// Collects ops that are before ops in the cluster but are users of other ops
// in the cluster. This may happen because users of individual ops in the
// cluster may be interleaved with other ops in the cluster. Resource ids are
// also captured, to keep track of resource usage before, in, or after the
// cluster.
// TODO(b/175701589): Extend this to handle all side effecting ops while
// handling transitive data dependencies.
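// For example (hypothetical "tf.OpA"/"tf.OpU"/"tf.OpB" ops, types omitted),
// given the block order
//
//   %a = "tf.OpA"() {_tpu_replicate = "cluster0"}  // cluster op
//   %u = "tf.OpU"(%a)                              // user, not in the cluster
//   %b = "tf.OpB"() {_tpu_replicate = "cluster0"}  // cluster op
//
// "tf.OpU" is a preceding user: it consumes a cluster op result but sits
// before the last cluster op, so it must later be moved after the formed
// `tf_device.cluster`.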
llvm::SmallSetVector<Operation*, 8> CollectClusterPrecedingUsers(
    Block* block, const ClusterOps& cluster_ops,
    const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
  llvm::SmallSetVector<Operation*, 8> preceding_users;
  llvm::SmallDenseSet<int64_t> observed_resource_ids;

  auto front = Block::iterator(cluster_ops.front());
  auto back = Block::iterator(cluster_ops.back());
  for (Operation& op : llvm::make_range(front, back)) {
    if (cluster_ops.contains(&op)) {
      CollectResourceIdsFromOp(op, resource_alias_analysis,
                               observed_resource_ids);
    } else if (ShouldMoveOpAfterCluster(
                   block, &op, cluster_ops, preceding_users,
                   resource_alias_analysis, observed_resource_ids)) {
      preceding_users.insert(&op);
      CollectResourceIdsFromOp(op, resource_alias_analysis,
                               observed_resource_ids);
    }
  }

  return preceding_users;
}

// Collects results and associated types of the cluster that are used outside
// of the cluster. These results and types are used to create the cluster's
// `tf_device.cluster` and associated terminator. Results that have no uses
// outside of the cluster (i.e. results of ops in the cluster are only consumed
// by other ops in the cluster) are pruned.
llvm::SmallVector<Value, 8> CollectClusterResults(
    Block* block, const ClusterOps& cluster_ops) {
  llvm::SmallVector<Value, 8> results;

  for (Operation* op : cluster_ops) {
    for (Value result : op->getResults()) {
      for (Operation* user : result.getUsers()) {
        // Check if user is not an op in the cluster.
        if (cluster_ops.count(block->findAncestorOpInBlock(*user)) == 0) {
          results.push_back(result);
          break;
        }
      }
    }
  }

  return results;
}

// Creates a `tf_device.cluster` to wrap cluster ops.
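// As an illustrative sketch (hypothetical "tf.OpA"/"tf.OpB" ops, types
// abbreviated), cluster ops such as
//
//   %a = "tf.OpA"() {_tpu_replicate = "cluster0"} : () -> tensor<i32>
//   %b = "tf.OpB"(%a) {_tpu_replicate = "cluster0"}
//          : (tensor<i32>) -> tensor<i32>
//
// are wrapped as
//
//   %cluster = "tf_device.cluster"() ({
//     %a = "tf.OpA"() : () -> tensor<i32>
//     %b = "tf.OpB"(%a) : (tensor<i32>) -> tensor<i32>
//     tf_device.return %b : tensor<i32>
//   }) : () -> tensor<i32>
//
// where only results with uses outside the cluster (%b here) become cluster
// results.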
tf_device::ClusterOp CreateClusterOp(
    Block* block, const ClusterOps& cluster_ops, llvm::ArrayRef<Value> results,
    llvm::ArrayRef<Operation*> preceding_users) {
  // `tf_device.cluster` will be placed at where the last op of the cluster is.
  Operation* last_cluster_op = cluster_ops.back();
  OpBuilder builder(last_cluster_op);

  llvm::SmallVector<Type, 8> result_types;
  for (Value result : results) result_types.push_back(result.getType());
  auto cluster = builder.create<tf_device::ClusterOp>(
      last_cluster_op->getLoc(), result_types);

  Block* body = new Block;
  cluster.body().push_back(body);

  // Move cluster ops to the cluster body. Also remove the `_tpu_replicate`
  // and `device` attributes from ops in the cluster when that information is
  // redundant with the `tf_device.cluster`. Do this for all ops including
  // nested ops.
  for (Operation* cluster_op : cluster_ops) {
    cluster_op->moveBefore(body, body->end());
    cluster_op->walk([&](Operation* inner_op) {
      inner_op->removeAttr(kTPUReplicateAttr);

      if (auto attr = inner_op->getAttrOfType<StringAttr>(kDeviceAttr)) {
        // Preserve the device attribute if the op is placed on a replicated
        // core device. The device attribute is used to infer the appropriate
        // sharding within TPUs for this op.
        // TODO(b/183598857): Use explicit sharding ops from the front-end.
        // For example, dequeue ops generated by
        // tensorflow/python/tpu/tpu_feed.py
        if (!tensorflow::IsTPUReplicatedCore(attr.getValue())) {
          inner_op->removeAttr(kDeviceAttr);
        }
      }
    });
  }

  // Add terminator.
  builder.setInsertionPointToEnd(body);
  builder.create<tf_device::ReturnOp>(last_cluster_op->getLoc(), results);

  // Replace uses of cluster op results outside of the cluster with the
  // associated `tf_device.cluster` results.
  for (auto ret_vals : llvm::zip(results, cluster.getResults())) {
    Value old_ret = std::get<0>(ret_vals);
    Value new_ret = std::get<1>(ret_vals);
    for (auto& use : llvm::make_early_inc_range(old_ret.getUses())) {
      Operation* user = use.getOwner();
      if (!body->findAncestorOpInBlock(*user)) use.set(new_ret);
    }
  }

  // Move users of cluster that are before the cluster to after the cluster.
  Operation* op_after_cluster = cluster.getOperation()->getNextNode();
  for (Operation* user : preceding_users) user->moveBefore(op_after_cluster);
  return cluster;
}

// Sorts `tf.TPUReplicatedInput` ops by `index` attribute. Ops with an `index`
// of -1 are always after ops with a non-negative `index`, and an arbitrary
// ordering is used as there are no dependencies on their relative ordering. If
// there are multiple `tf.TPUReplicatedInput` ops with the same non-negative
// index or if indices are less than -1, an error will be returned.
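// For example, inputs with `index` values [3, -1, 0, 1] are ordered as
// [0, 1, 3, -1]; entries with index -1 keep their original relative order at
// the end (via the stable sort below).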
LogicalResult SortTPUReplicatedInputsByIndex(
    llvm::ArrayRef<Operation*> inputs,
    llvm::SmallVectorImpl<Operation*>* sorted_inputs) {
  llvm::SmallDenseSet<int64_t, 8> unique_indices;
  for (Operation* input : inputs) {
    int64_t index = llvm::cast<TF::TPUReplicatedInputOp>(input).index();
    if (index < -1)
      return input->emitOpError()
             << "requires index to be at least -1, but got " << index;
    if (index == -1) continue;
    if (!unique_indices.insert(index).second)
      return input->emitOpError()
             << "requires indices to be unique, but found multiple '"
             << input->getName() << "' ops with index " << index;
  }

  // Sort all TPUReplicatedInputs by `index` attribute to have
  // TPUReplicatedInputs with indices be added to the `tf_device.replicate` op
  // deterministically. If `index` attribute is -1, instead move them to the
  // end.
  sorted_inputs->assign(inputs.begin(), inputs.end());
  std::stable_sort(
      sorted_inputs->begin(), sorted_inputs->end(),
      [](Operation* l, Operation* r) {
        int64_t l_index = llvm::cast<TF::TPUReplicatedInputOp>(l).index();
        int64_t r_index = llvm::cast<TF::TPUReplicatedInputOp>(r).index();
        if (l_index == -1 && r_index != -1) return false;
        if (r_index == -1 && l_index != -1) return true;
        return l_index < r_index;
      });

  return success();
}

// Creates a `tf_device.replicate` to represent replication for the cluster, if
// necessary.
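// As a rough sketch (custom assembly shown approximately, types abbreviated),
// with `num_replicas` = 2 a cluster fed by a `tf.TPUReplicatedInput` with
// operands %in0 and %in1 becomes something like:
//
//   tf_device.replicate([%in0, %in1] as %ri: tensor<i32>) {n = 2 : i32} {
//     %res = "tf_device.cluster"() ({
//       ... ops now use %ri in place of the tf.TPUReplicatedInput result ...
//       tf_device.return ... : tensor<i32>
//     }) : () -> tensor<i32>
//     tf_device.return %res : tensor<i32>
//   }
//
// with each replica's results forwarded to the corresponding
// `tf.TPUReplicatedOutput` users.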
LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas,
                               int num_cores_per_replica) {
  // No need to replicate.
  if (num_replicas == 1) return success();

  if (num_replicas < 1)
    return cluster.emitError() << "requires '" << kNumReplicasAttr
                               << "' int attribute to be at least 1";

  LogicalResult status = success();
  // Collect all used TPUReplicatedInput ops and sort by `index`.
  llvm::SmallSetVector<Operation*, 8> unique_replicated_input_ops;
  mlir::visitUsedValuesDefinedAbove(
      cluster.body(), cluster.body(), [&](mlir::OpOperand* operand) {
        Operation* def = operand->get().getDefiningOp();
        if (llvm::isa_and_nonnull<TF::TPUReplicatedInputOp>(def))
          unique_replicated_input_ops.insert(def);
        // When model parallelism is used in conjunction with data parallelism
        // for resource inputs, the per-replica resource inputs need to be
        // collected from the inputs of `tf.TPUPartitionedInput` ops.
        if (auto pi = llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>(def)) {
          if (pi->getNumOperands() != num_cores_per_replica)
            status = pi.emitOpError()
                     << "requires " << num_cores_per_replica
                     << " operands but found " << pi->getNumOperands();
          for (auto operand : pi.inputs()) {
            if (llvm::isa_and_nonnull<TF::TPUReplicatedInputOp>(
                    operand.getDefiningOp()))
              unique_replicated_input_ops.insert(operand.getDefiningOp());
          }
        }
      });

  if (failed(status)) return failure();
  llvm::SmallVector<Operation*, 8> replicated_input_ops;
  if (failed(SortTPUReplicatedInputsByIndex(
          unique_replicated_input_ops.getArrayRef(), &replicated_input_ops)))
    return failure();

  // `index` attribute values stored on TPUReplicatedInput ops. These will be
  // used later for the dynamic padder.
  llvm::SmallVector<int64_t, 8> replicated_input_indices;
  llvm::SmallVector<int64_t, 8> packed_input_indices;
  bool has_replicated_input_index = false;

  // Indices of the replicate op's arguments that are mirrored variables.
  llvm::SmallVector<int64_t, 8> mirrored_variable_indices;

  // Check if number of operands of each used TPUReplicatedInput op matches
  // `num_replicas` or 1. Collect all their operands and associated type for
  // creating the replicate op.
  llvm::SmallVector<std::pair<ValueRange, Type>, 8> replicated_inputs;
  llvm::SmallVector<Value, 8> packed_inputs;
  llvm::SmallVector<Operation*, 8> replicated_ops;
  llvm::SmallVector<Operation*, 8> packed_ops;
  for (auto& pos_and_input : llvm::enumerate(replicated_input_ops)) {
    auto input = pos_and_input.value();
    bool is_packed = llvm::cast<TF::TPUReplicatedInputOp>(input).is_packed();
    const int num_operands = input->getNumOperands();
    int num_inputs = is_packed ? 1 : num_replicas;
    if (num_operands != num_inputs)
      return input->emitOpError() << "requires " << num_inputs << " operands";

    auto tpu_replicated_input = llvm::cast<TF::TPUReplicatedInputOp>(input);
    int64_t tpu_replicated_input_index = tpu_replicated_input.index();
    if (is_packed) {
      packed_inputs.push_back(input->getOperand(0));
      packed_input_indices.push_back(tpu_replicated_input_index);
      packed_ops.push_back(input);
    } else {
      replicated_inputs.push_back(
          {input->getOperands(), input->getOperand(0).getType()});
      replicated_input_indices.push_back(tpu_replicated_input_index);
      replicated_ops.push_back(input);
    }
    if (tpu_replicated_input_index != -1) has_replicated_input_index = true;

    if (tpu_replicated_input.is_mirrored_variable())
      mirrored_variable_indices.push_back(pos_and_input.index());
  }

  replicated_input_indices.append(packed_input_indices.begin(),
                                  packed_input_indices.end());

  // Create replicate op.
  OpBuilder builder(cluster);
  auto replicate_op = builder.create<tf_device::ReplicateOp>(
      cluster.getLoc(), num_replicas,
      llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>(),
      replicated_inputs, packed_inputs, cluster.getResultTypes());
  if (has_replicated_input_index)
    replicate_op->setAttr(kReplicatedInputIndicesAttr,
                          builder.getI64ArrayAttr(replicated_input_indices));

  if (!mirrored_variable_indices.empty())
    replicate_op->setAttr(kMirroredVariableIndicesAttr,
                          builder.getI64ArrayAttr(mirrored_variable_indices));

  // Replace replicated cluster results with replicate op results.
  for (auto result_and_idx : llvm::enumerate(cluster.getResults())) {
    Value result = result_and_idx.value();
    int idx = result_and_idx.index();
    auto replicate_outputs = llvm::make_range(
        std::next(replicate_op.result_begin(), idx * num_replicas),
        std::next(replicate_op.result_begin(), (idx + 1) * num_replicas));

    for (auto& use : llvm::make_early_inc_range(result.getUses())) {
      Operation* def = use.getOwner();
      if (!llvm::isa<TF::TPUReplicatedOutputOp>(def)) {
        // If the user is not a `tf.TPUReplicatedOutput`, simply forward the
        // first replica output. Certain Graphs under V1 create `tf.Identity`
        // users of replicated ops to pin the TPU computation for execution.
        use.set(*replicate_outputs.begin());
        continue;
      }

      const int def_num_results = def->getNumResults();
      if (def_num_results != num_replicas)
        return def->emitOpError() << "requires " << num_replicas << " results";

      def->replaceAllUsesWith(replicate_outputs);
    }
  }

  // Collect all `tf.TPUPartitionedInput` ops to be moved inside the
  // `tf_device.replicate` later.
  llvm::SmallSet<Operation*, 4> partitioned_inputs;
  // Update replicated inputs with replicate op block arguments.
  auto ordered_tpu_replicate_inputs =
      llvm::concat<Operation*>(replicated_ops, packed_ops);
  for (auto input_and_block_arg :
       llvm::zip(ordered_tpu_replicate_inputs,
                 replicate_op.GetBody().getArguments())) {
    Operation* input = std::get<0>(input_and_block_arg);
    Value block_arg = std::get<1>(input_and_block_arg);
    mlir::replaceAllUsesInRegionWith(input->getResult(0), block_arg,
                                     cluster.body());
    // Update replicated input use in tf.TPUPartitionedInput op.
    for (auto& use : input->getUses()) {
      auto pi = llvm::dyn_cast<TF::TPUPartitionedInputOp>(use.getOwner());
      if (pi) {
        pi.setOperand(use.getOperandNumber(), block_arg);
        partitioned_inputs.insert(pi.getOperation());
      }
    }
  }

  // Create terminator for replicate op and move `tf_device.cluster` and
  // `tf.TPUPartitionedInput`(s) into replicate body.
  builder.setInsertionPointToEnd(&replicate_op.GetBody());
  auto return_op = builder.create<tf_device::ReturnOp>(replicate_op.getLoc(),
                                                       cluster.getResults());
  for (auto pi : partitioned_inputs) pi->moveBefore(return_op);

  cluster.getOperation()->moveBefore(return_op);

  return success();
}

// Forms clusters with ops of the same `_tpu_replicate` attribute under a
// block.
//
// For a given block, clusters are formed via grouping ops by `_tpu_replicate`
// attributes.
// For every cluster formed:
//   1. Find associated TPUReplicateMetadata attributes with the same
//      `_tpu_replicate` attribute.
//   2. Find users not in cluster that are interleaved between cluster ops.
//   3. Find external uses of cluster ops.
//   4. Create `tf_device.cluster` with results consisting of the external
//      uses of cluster ops determined at 3.
//   5. Move cluster ops to `tf_device.cluster` body.
//   6. Replace external uses of cluster ops with `tf_device.cluster` results.
//   7. Move users from 2 to after the `tf_device.cluster`.
//   8. Wrap cluster (`tf_device.cluster`) in a `tf_device.replicate` if
//      attribute `num_replicas` is greater than 1.
//   9. Copy over TPUReplicateMetadata attributes to `tf_device.cluster`.
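// As an end-to-end illustration (op names, operands and types abbreviated;
// "tf.OpA" is a placeholder), a block such as
//
//   "tf.TPUReplicateMetadata"() {_tpu_replicate = "c0", num_replicas = 2}
//   %ri = "tf.TPUReplicatedInput"(%in0, %in1)
//   %x = "tf.OpA"(%ri) {_tpu_replicate = "c0"}
//   %out:2 = "tf.TPUReplicatedOutput"(%x)
//
// is rewritten into a `tf_device.replicate` containing a single
// `tf_device.cluster` that carries the metadata attributes (minus
// `num_replicas`), with %in0/%in1 becoming replicated inputs and the
// `tf.TPUReplicatedOutput` results being replaced by the per-replica
// replicate results.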
LogicalResult FormClustersInBlock(
    Block* block,
    const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
  MetadataMap metadata_map;
  LogicalResult result = CollectMetadata(block, &metadata_map);
  if (failed(result)) return result;

  // If there is no TPUReplicateMetadata op in this block, process blocks in
  // regions attached to the ops in the block.
  if (metadata_map.empty()) {
    for (Operation& op : *block) {
      for (Region& region : op.getRegions()) {
        if (!llvm::hasSingleElement(region))
          return op.emitOpError("Expected single block region");
        if (failed(
                FormClustersInBlock(&region.front(), resource_alias_analysis)))
          return failure();
      }
    }
    return success();
  }

  ClusterMap clusters;
  result = CollectAndGroupClusterOps(block, &clusters);
  if (failed(result)) return result;

  for (const auto& cluster_metadata_and_ops : clusters) {
    const auto& cluster_ops = cluster_metadata_and_ops.getSecond();

    auto cluster_metadata =
        metadata_map.find(cluster_metadata_and_ops.getFirst());

    // No TPUReplicateMetadata for a `_tpu_replicate` attribute.
    if (cluster_metadata == metadata_map.end()) {
      cluster_ops.front()->emitWarning()
          << "TPUReplicateMetadata for associated '" << kTPUReplicateAttr
          << "' attribute '" << cluster_metadata_and_ops.getFirst()
          << "' is missing";
      continue;
    }

    llvm::SmallSetVector<Operation*, 8> preceding_users =
        CollectClusterPrecedingUsers(block, cluster_ops,
                                     resource_alias_analysis);

    llvm::SmallVector<Value, 8> results =
        CollectClusterResults(block, cluster_ops);

    tf_device::ClusterOp cluster = CreateClusterOp(
        block, cluster_ops, results, preceding_users.getArrayRef());

    auto num_replicas = cluster_metadata->getSecond().get(kNumReplicasAttr);
    if (!num_replicas || !num_replicas.isa<mlir::IntegerAttr>())
      return cluster.emitError()
             << "requires '" << kNumReplicasAttr << "' int attribute";

    int num_cores_per_replica = 1;
    auto num_cores_per_replica_attr =
        cluster_metadata->getSecond()
            .get(kNumCoresPerReplicaAttr)
            .dyn_cast_or_null<mlir::IntegerAttr>();
    if (num_cores_per_replica_attr)
      num_cores_per_replica = num_cores_per_replica_attr.getInt();

    if (failed(ReplicateCluster(cluster,
                                num_replicas.cast<mlir::IntegerAttr>().getInt(),
                                num_cores_per_replica)))
      return failure();

    // Copy TPUReplicateMetadata attributes to `tf_device.cluster`.
    cluster->setAttrs(
        cluster_metadata->second.getDictionary(cluster.getContext()));
    // Exclude `num_replicas` as the cluster has already been replicated if
    // necessary.
    cluster->removeAttr(kNumReplicasAttr);
  }

  return success();
}

LogicalResult FormClustersInFunction(
    FuncOp func,
    const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
  if (!llvm::hasSingleElement(func))
    return func.emitOpError("Expecting a single block function");

  if (failed(FormClustersInBlock(&func.front(), resource_alias_analysis)))
    return failure();

  // Remove TPUReplicatedInput and TPUReplicatedOutput nodes.
  auto remove_result = func.walk([&](Operation* op) {
    if (!llvm::isa<TF::TPUReplicatedInputOp, TF::TPUReplicatedOutputOp>(op))
      return WalkResult::advance();

    // Forward the operand to the result. When the `num_replicas` attribute is
    // 1, no `tf_device.replicate` is created and the replicated (1)
    // operands/results are left untouched.
    if (op->getNumOperands() == 1 && op->getNumResults() == 1)
      op->getResult(0).replaceAllUsesWith(op->getOperand(0));

    // Leftover TPUReplicatedInput/TPUReplicatedOutput ops (those whose
    // `num_replicas` is not 1) should have no remaining uses at this point.
    if (!op->use_empty()) {
      op->emitOpError() << "is expected to have no uses, but it is operand#"
                        << op->use_begin()->getOperandNumber() << " of "
                        << *op->use_begin()->getOwner();
      return WalkResult::interrupt();
    }

    op->erase();

    return WalkResult::advance();
  });

  return failure(remove_result.wasInterrupted());
}

void TPUClusterFormationPass::runOnOperation() {
  auto& resource_alias_analysis = getAnalysis<TF::ResourceAliasAnalysis>();
  for (auto func : getOperation().getOps<FuncOp>())
    if (!func.isExternal() &&
        failed(FormClustersInFunction(
            func, resource_alias_analysis.GetAnalysisForFunc(func))))
      return signalPassFailure();
}
}  // anonymous namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateTPUClusterFormationPass() {
  return std::make_unique<TPUClusterFormationPass>();
}

}  // namespace TFTPU
}  // namespace mlir