1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This transformation pass takes ops with the same `_tpu_replicate` attribute
// in a block and clusters them together under a `tf_device::LaunchOp`.
// Associated TPUReplicateMetadata ops are removed and their attributes are
// copied over to the associated `tf_device::LaunchOp`. If a cluster should be
// replicated, the associated `tf_device::LaunchOp` will be wrapped further with
// a `tf_device.replicate`. This pass also assumes ops of the same cluster do
// not have ops outside of the cluster that are both operands and results of the
// cluster. Note, this does not handle side-effecting ops yet.
24
25 #include <iterator>
26 #include <memory>
27 #include <tuple>
28 #include <utility>
29
30 #include "llvm/ADT/ArrayRef.h"
31 #include "llvm/ADT/DenseMap.h"
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/SetVector.h"
34 #include "llvm/ADT/SmallVector.h"
35 #include "llvm/ADT/StringRef.h"
36 #include "llvm/ADT/iterator_range.h"
37 #include "llvm/Support/Casting.h"
38 #include "mlir/IR/Attributes.h" // TF:llvm-project
39 #include "mlir/IR/Builders.h" // TF:llvm-project
40 #include "mlir/IR/Identifier.h" // TF:llvm-project
41 #include "mlir/IR/MLIRContext.h" // TF:llvm-project
42 #include "mlir/IR/Operation.h" // TF:llvm-project
43 #include "mlir/IR/Types.h" // TF:llvm-project
44 #include "mlir/IR/Value.h" // TF:llvm-project
45 #include "mlir/Pass/Pass.h" // TF:llvm-project
46 #include "mlir/Pass/PassRegistry.h" // TF:llvm-project
47 #include "mlir/Support/LogicalResult.h" // TF:llvm-project
48 #include "mlir/Transforms/RegionUtils.h" // TF:llvm-project
49 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
50 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
51 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
52
53 namespace mlir {
54 namespace TFTPU {
55
56 namespace {
57
// Attribute naming the TPU cluster that an op, or a TPUReplicateMetadata op,
// belongs to.
constexpr char kTPUReplicateAttr[] = "_tpu_replicate";
// Attribute stripped from ops when they are moved into a launch body.
constexpr char kDeviceAttr[] = "device";
// Attribute stripped from TPUReplicateMetadata attributes before they are
// recorded.
constexpr char kNameAttr[] = "name";
// TPUReplicateMetadata attribute holding the number of replicas of a cluster.
constexpr char kNumReplicasAttr[] = "num_replicas";
// Attribute set on a created replicate op listing the positions of its
// replicated inputs that are mirrored variables.
constexpr char kMirroredVariableIndicesAttr[] = "_mirrored_variable_indices";

// Error message for a TPUReplicateMetadata op whose `_tpu_replicate` attribute
// is missing, is not a string, or is empty.
constexpr char kBadTPUReplicateAttrMsg[] =
    "requires '_tpu_replicate' string attribute";
66
// Mapping from a `_tpu_replicate` attribute value to the attributes of the
// TPUReplicateMetadata op carrying that value.
using MetadataMap = llvm::SmallDenseMap<llvm::StringRef, NamedAttributeList, 8>;

// Mapping from a `_tpu_replicate` attribute value to the ops of that cluster.
// The set vector keeps the ops in insertion order without duplicates.
using ClusterMap = llvm::SmallDenseMap<llvm::StringRef,
                                       llvm::SmallSetVector<Operation*, 8>, 8>;
73
// Function pass that forms TPU clusters; the work is done in `runOnFunction`,
// which is defined later in this file.
struct TPUClusterFormation : public FunctionPass<TPUClusterFormation> {
  void runOnFunction() override;
};
77
78 // Creates a mapping from the TPUReplicateMetadata ops `_tpu_replicate`
79 // attribute to its attributes and removes the ops. If multiple
80 // TPUReplicateMetadata ops have the same `_tpu_replicate` attribute, an error
81 // will be returned.
CollectMetadata(Operation * op,MetadataMap * metadata_map)82 LogicalResult CollectMetadata(Operation* op, MetadataMap* metadata_map) {
83 auto result =
84 op->walk([&](TF::TPUReplicateMetadataOp metadata_op) -> WalkResult {
85 NamedAttributeList attrs = metadata_op.getAttrs();
86
87 // Missing or bad `_tpu_replicate` attribute.
88 auto tpu_replicate_attr = attrs.get(kTPUReplicateAttr);
89 if (!tpu_replicate_attr)
90 return metadata_op.emitError() << kBadTPUReplicateAttrMsg;
91
92 auto tpu_replicate_attr_str = tpu_replicate_attr.dyn_cast<StringAttr>();
93 if (!tpu_replicate_attr_str ||
94 tpu_replicate_attr_str.getValue().empty())
95 return metadata_op.emitError() << kBadTPUReplicateAttrMsg;
96
97 // Remove `name` attribute.
98 attrs.remove(Identifier::get(kNameAttr, metadata_op.getContext()));
99
100 auto it = metadata_map->try_emplace(tpu_replicate_attr_str.getValue(),
101 std::move(attrs));
102
103 // There are multiple TPUReplicateMetadata ops with the same
104 // `_tpu_replicate` attribute.
105 if (!it.second) {
106 return metadata_op.emitError()
107 << "multiple TPUReplicateMetadata ops with the same '"
108 << kTPUReplicateAttr << "' attribute '"
109 << tpu_replicate_attr_str.getValue() << "' found";
110 }
111
112 metadata_op.erase();
113 return WalkResult::advance();
114 });
115
116 // Return failure if the walk was interrupted.
117 return failure(result.wasInterrupted());
118 }
119
120 // Collects and clusters ops with the same `_tpu_replicate` attribute. This will
121 // return an error if a `_tpu_replicate` attribute of an op is empty.
CollectAndGroupClusterOps(Block * block,ClusterMap * clusters)122 LogicalResult CollectAndGroupClusterOps(Block* block, ClusterMap* clusters) {
123 for (Operation& op : *block) {
124 if (auto attr = op.getAttrOfType<StringAttr>(kTPUReplicateAttr)) {
125 if (attr.getValue().empty())
126 return op.emitError()
127 << "attribute '" << kTPUReplicateAttr << "' is empty";
128
129 auto it = clusters->try_emplace(attr.getValue());
130 it.first->getSecond().insert(&op);
131 }
132 }
133
134 return success();
135 }
136
137 // Checks if an op should be moved after a cluster. There may be users of a
138 // cluster interleaved among the cluster ops.
ShouldMoveOpAfterCluster(Block * block,Operation * op,const llvm::SmallSetVector<Operation *,8> & cluster_ops,const llvm::SmallSetVector<Operation *,8> & preceding_users)139 bool ShouldMoveOpAfterCluster(
140 Block* block, Operation* op,
141 const llvm::SmallSetVector<Operation*, 8>& cluster_ops,
142 const llvm::SmallSetVector<Operation*, 8>& preceding_users) {
143 auto result = op->walk([&](Operation* op) {
144 for (Value operand : op->getOperands()) {
145 Operation* def = operand.getDefiningOp();
146 // Operands may not have a defining op (BlockArgument) or is from a
147 // different block.
148 if (!def || def->getBlock() != block) continue;
149
150 if (cluster_ops.count(def) != 0 || preceding_users.count(def) != 0) {
151 // Op is a user of a cluster or another op that is a user of the
152 // cluster (transitively), but is before the cluster.
153 return WalkResult::interrupt();
154 }
155 }
156 return WalkResult::advance();
157 });
158
159 return result.wasInterrupted();
160 }
161
162 // Collects ops that are before ops in the cluster but are users of other ops
163 // in the cluster. This may happen because users of individual ops in the
164 // cluster may be interleaved with other ops in the cluster.
CollectClusterPrecedingUsers(Block * block,const llvm::SmallSetVector<Operation *,8> & cluster_ops)165 llvm::SmallSetVector<Operation*, 8> CollectClusterPrecedingUsers(
166 Block* block, const llvm::SmallSetVector<Operation*, 8>& cluster_ops) {
167 llvm::SmallSetVector<Operation*, 8> preceding_users;
168
169 for (Operation& op : llvm::make_range(Block::iterator(cluster_ops.front()),
170 Block::iterator(cluster_ops.back())))
171 if (cluster_ops.count(&op) == 0 &&
172 ShouldMoveOpAfterCluster(block, &op, cluster_ops, preceding_users))
173 preceding_users.insert(&op);
174
175 return preceding_users;
176 }
177
178 // Collects results and associated types of the cluster that are used outside of
179 // the cluster. These results and types are used to create the clusters
180 // `tf_device::LaunchOp` and associated terminator. Results that have no uses
181 // outside of the cluster (i.e. results of ops in the cluster are only consumed
182 // by other ops in the cluster) are pruned.
CollectClusterResults(Block * block,const llvm::SmallSetVector<Operation *,8> & cluster_ops)183 llvm::SmallVector<Value, 8> CollectClusterResults(
184 Block* block, const llvm::SmallSetVector<Operation*, 8>& cluster_ops) {
185 llvm::SmallVector<Value, 8> results;
186
187 for (Operation* op : cluster_ops) {
188 for (Value result : op->getResults()) {
189 for (Operation* user : result.getUsers()) {
190 // Check if user is not an op in the cluster.
191 if (cluster_ops.count(block->findAncestorOpInBlock(*user)) == 0) {
192 results.push_back(result);
193 break;
194 }
195 }
196 }
197 }
198
199 return results;
200 }
201
202 // Creates a `tf_device::LaunchOp` to wrap cluster ops.
CreateLaunchOpForCluster(Operation * last_cluster_op,llvm::ArrayRef<Value> results)203 tf_device::LaunchOp CreateLaunchOpForCluster(Operation* last_cluster_op,
204 llvm::ArrayRef<Value> results) {
205 // `tf_device::LaunchOp` will be placed at where the last op of the cluster
206 // is.
207 OpBuilder builder(last_cluster_op);
208
209 llvm::SmallVector<Type, 8> result_types;
210 for (Value result : results) result_types.push_back(result.getType());
211
212 // An empty string placeholder is used for the device as that will be later
213 // populated with the device of the associated TPUReplicateMetadata op.
214 auto launch_op = builder.create<tf_device::LaunchOp>(
215 last_cluster_op->getLoc(), builder.getStringAttr(""), result_types);
216
217 launch_op.body().push_back(new Block);
218
219 // Add terminator.
220 builder.setInsertionPointToEnd(&launch_op.GetBody());
221 builder.create<tf_device::ReturnOp>(last_cluster_op->getLoc(), results);
222
223 return launch_op;
224 }
225
226 // Moves cluster ops to associated `tf_device.LaunchOp` body.
MoveClusterOpsToLaunchOp(tf_device::LaunchOp launch_op,const llvm::SmallSetVector<Operation *,8> & cluster_ops)227 void MoveClusterOpsToLaunchOp(
228 tf_device::LaunchOp launch_op,
229 const llvm::SmallSetVector<Operation*, 8>& cluster_ops) {
230 MLIRContext* context = launch_op.getContext();
231 Operation* terminator = &launch_op.GetBody().back();
232
233 for (Operation* cluster_op : cluster_ops) {
234 // Remove `_tpu_replicate` and `device` attribute from ops in the cluster
235 // as that information will be present in the `tf_device.LaunchOp`.
236 cluster_op->removeAttr(Identifier::get(kTPUReplicateAttr, context));
237 cluster_op->removeAttr(Identifier::get(kDeviceAttr, context));
238 cluster_op->moveBefore(terminator);
239 }
240 }
241
242 // Replaces uses of cluster ops results outside of cluster with the associated
243 // `tf_device::LaunchOp` results.
UpdateLaunchOpResultExternalUses(tf_device::LaunchOp launch_op,llvm::ArrayRef<Value> results)244 void UpdateLaunchOpResultExternalUses(tf_device::LaunchOp launch_op,
245 llvm::ArrayRef<Value> results) {
246 Block& launch_op_block = launch_op.GetBody();
247 for (auto ret_vals : llvm::zip(results, launch_op.getResults())) {
248 Value old_ret = std::get<0>(ret_vals);
249 Value new_ret = std::get<1>(ret_vals);
250 for (auto& use : llvm::make_early_inc_range(old_ret.getUses()))
251 if (!launch_op_block.findAncestorOpInBlock(*use.getOwner()))
252 use.set(new_ret);
253 }
254 }
255
256 // Moves users of cluster that are before the cluster to after the cluster.
MovePrecedingClusterUsers(tf_device::LaunchOp launch_op,llvm::ArrayRef<Operation * > preceding_users)257 void MovePrecedingClusterUsers(tf_device::LaunchOp launch_op,
258 llvm::ArrayRef<Operation*> preceding_users) {
259 Operation* op_after_launch_op = launch_op.getOperation()->getNextNode();
260 for (Operation* user : preceding_users) user->moveBefore(op_after_launch_op);
261 }
262
263 // Sorts `tf.TPUReplicatedInput` ops by `index` attribute. Ops with an `index`
264 // of -1 are always after ops with a non negative `index`, and an arbitrary
265 // ordering is used as there are no dependencies on their relative ordering.
SortTPUReplicatedInputsByIndex(llvm::ArrayRef<Operation * > inputs,llvm::SmallVectorImpl<Operation * > * sorted_inputs)266 LogicalResult SortTPUReplicatedInputsByIndex(
267 llvm::ArrayRef<Operation*> inputs,
268 llvm::SmallVectorImpl<Operation*>* sorted_inputs) {
269 const int input_size = inputs.size();
270 sorted_inputs->resize(input_size, nullptr);
271 int last_index = input_size - 1;
272
273 for (Operation* input : inputs) {
274 int64_t index =
275 llvm::cast<TF::TPUReplicatedInputOp>(input).index().getLimitedValue();
276
277 if (index >= input_size || index < -1)
278 return input->emitError() << "'" << input->getName().getStringRef()
279 << "' index is not in range [-1, " << input_size
280 << "), got " << index;
281
282 if (index == -1)
283 (*sorted_inputs)[last_index--] = input;
284 else
285 (*sorted_inputs)[index] = input;
286 }
287
288 if (llvm::any_of(*sorted_inputs, [](Operation* op) { return op == nullptr; }))
289 return inputs.front()->emitError()
290 << "failed to sort '" << inputs.front()->getName().getStringRef()
291 << "' ops, gap(s) found in indices";
292
293 return success();
294 }
295
296 // Creates a `tf_device.replicate` to represent replication for the cluster, if
297 // necessary.
ReplicateCluster(tf_device::LaunchOp launch_op,int num_replicas)298 LogicalResult ReplicateCluster(tf_device::LaunchOp launch_op,
299 int num_replicas) {
300 // No need to replicate.
301 if (num_replicas == 1) return success();
302
303 if (num_replicas < 1)
304 return launch_op.emitError() << "requires '" << kNumReplicasAttr
305 << "' int attribute to be at least 1";
306
307 // Collect all used TPUReplicatedInput ops and sort by `index`.
308 llvm::SmallSetVector<Operation*, 8> unique_replicated_input_ops;
309 mlir::visitUsedValuesDefinedAbove(
310 launch_op.body(), launch_op.body(), [&](mlir::OpOperand* operand) {
311 Operation* def = operand->get().getDefiningOp();
312 if (def && llvm::isa<TF::TPUReplicatedInputOp>(def))
313 unique_replicated_input_ops.insert(def);
314 });
315 llvm::SmallVector<Operation*, 8> replicated_input_ops;
316 if (failed(SortTPUReplicatedInputsByIndex(
317 unique_replicated_input_ops.getArrayRef(), &replicated_input_ops)))
318 return failure();
319
320 // Indices of the replicate op's arguments that are mirrored variables.
321 llvm::SmallVector<int64_t, 8> mirrored_variable_indices;
322
323 // Check if number of operands of each used TPUReplicatedInput op matches
324 // `num_replicas`. Collect all their operands and associated type for creating
325 // the replicate op.
326 llvm::SmallVector<std::pair<Operation::operand_range, Type>, 8>
327 replicated_inputs;
328 for (auto& pos_and_input : llvm::enumerate(replicated_input_ops)) {
329 auto input = pos_and_input.value();
330 if (input->getNumOperands() != num_replicas)
331 return input->emitOpError() << "requires " << num_replicas << " operands";
332
333 replicated_inputs.push_back(
334 {input->getOperands(), input->getOperand(0).getType()});
335 if (llvm::cast<TF::TPUReplicatedInputOp>(input).is_mirrored_variable())
336 mirrored_variable_indices.push_back(pos_and_input.index());
337 }
338
339 // Create replicate op.
340 OpBuilder builder(launch_op);
341 auto replicate_op = builder.create<tf_device::ReplicateOp>(
342 launch_op.getLoc(), num_replicas, llvm::ArrayRef<llvm::StringRef>(),
343 replicated_inputs, launch_op.getResultTypes());
344 if (!mirrored_variable_indices.empty())
345 replicate_op.setAttr(kMirroredVariableIndicesAttr,
346 builder.getI64ArrayAttr(mirrored_variable_indices));
347
348 // Replace replicated cluster results with replicate op results.
349 for (auto result_and_idx : llvm::enumerate(launch_op.getResults())) {
350 Value result = result_and_idx.value();
351 int idx = result_and_idx.index();
352 for (auto& use : result.getUses()) {
353 Operation* def = use.getOwner();
354 if (!def || !llvm::isa<TF::TPUReplicatedOutputOp>(def))
355 return launch_op.emitError()
356 << "requires output of " << launch_op.getOperationName()
357 << " to lead to a 'tf.TPUReplicatedOutput' op";
358
359 if (def->getNumResults() != num_replicas)
360 return def->emitOpError() << "requires " << num_replicas << " results";
361
362 auto replicate_outputs = llvm::make_range(
363 std::next(replicate_op.result_begin(), idx * num_replicas),
364 std::next(replicate_op.result_begin(), (idx + 1) * num_replicas));
365 def->replaceAllUsesWith(replicate_outputs);
366 }
367 }
368
369 // Update replicated inputs with replicate op block arguments.
370 for (auto input_and_block_arg :
371 llvm::zip(replicated_input_ops, replicate_op.GetBody().getArguments())) {
372 Operation* input = std::get<0>(input_and_block_arg);
373 Value block_arg = std::get<1>(input_and_block_arg);
374 mlir::replaceAllUsesInRegionWith(input->getResult(0), block_arg,
375 launch_op.body());
376 }
377
378 // Create terminator for replicate op and move launch into replicate.
379 builder.setInsertionPointToEnd(&replicate_op.GetBody());
380 auto return_op = builder.create<tf_device::ReturnOp>(replicate_op.getLoc(),
381 launch_op.getResults());
382 launch_op.getOperation()->moveBefore(return_op);
383
384 return success();
385 }
386
387 // Forms clusters with ops of the same `_tpu_replicate` attribute under a block.
388 //
389 // For a given block, clusters are formed via grouping ops by `_tpu_replicate`
390 // attributes.
391 // For every cluster formed:
392 // 1. Find associated TPUReplicateMetadata attributes with the same
393 // `_tpu_replicate` attribute.
394 // 2. Find users not in cluster that are interleaved between cluster ops.
395 // 3. Find external uses of cluster ops.
396 // 4. Create `tf_device::LaunchOp` with results consisting of the external
397 // uses of cluster ops determined at 3.
398 // 5. Move cluster ops to `tf_device::LaunchOp` body.
399 // 6. Replace external uses of cluster ops uses with `tf_device::LaunchOp`
400 // results.
401 // 7. Move users from 2 to after the `tf_device::LaunchOp`.
402 // 8. Wrap cluster (`tf_device::LaunchOp`) in a `tf_device.replicate` if
403 // attribute `num_replicas` is greater than 1.
404 // 9. Copy over TPUReplicateMetadata attributes to `tf_device::LaunchOp`.
FormClustersInBlock(Block * block,const MetadataMap & metadata_map)405 LogicalResult FormClustersInBlock(Block* block,
406 const MetadataMap& metadata_map) {
407 ClusterMap clusters;
408 LogicalResult result = CollectAndGroupClusterOps(block, &clusters);
409 if (failed(result)) return result;
410
411 for (const auto& cluster : clusters) {
412 const auto& cluster_ops = cluster.getSecond();
413
414 auto cluster_metadata = metadata_map.find(cluster.getFirst());
415
416 // No TPUReplicateMetadata for a `_tpu_replicate` attribute.
417 if (cluster_metadata == metadata_map.end()) {
418 cluster_ops.front()->emitWarning()
419 << "TPUReplicateMetadata for associated '" << kTPUReplicateAttr
420 << "' attribute '" << cluster.getFirst() << "' is missing";
421 continue;
422 }
423
424 llvm::SmallSetVector<Operation*, 8> preceding_users =
425 CollectClusterPrecedingUsers(block, cluster_ops);
426
427 llvm::SmallVector<Value, 8> results =
428 CollectClusterResults(block, cluster_ops);
429
430 tf_device::LaunchOp launch_op =
431 CreateLaunchOpForCluster(cluster_ops.back(), results);
432
433 MoveClusterOpsToLaunchOp(launch_op, cluster_ops);
434
435 UpdateLaunchOpResultExternalUses(launch_op, results);
436
437 MovePrecedingClusterUsers(launch_op, preceding_users.getArrayRef());
438
439 auto num_replicas = cluster_metadata->getSecond().get(kNumReplicasAttr);
440 if (!num_replicas || !num_replicas.isa<mlir::IntegerAttr>())
441 return launch_op.emitError()
442 << "requires '" << kNumReplicasAttr << "' int attribute";
443
444 if (failed(ReplicateCluster(
445 launch_op, num_replicas.cast<mlir::IntegerAttr>().getInt())))
446 return failure();
447
448 // Copy TPUReplicateMetadata attributes to launch.
449 launch_op.setAttrs(cluster_metadata->second);
450 // Exclude `num_replicas` as cluster should be replicated if necessary.
451 launch_op.removeAttr(kNumReplicasAttr);
452 }
453
454 return success();
455 }
456
runOnFunction()457 void TPUClusterFormation::runOnFunction() {
458 MetadataMap metadata_map;
459 if (failed(CollectMetadata(getFunction(), &metadata_map)))
460 return signalPassFailure();
461
462 for (Block& block : getFunction())
463 if (failed(FormClustersInBlock(&block, metadata_map)))
464 return signalPassFailure();
465
466 auto island_result = getFunction().walk([&](tf_executor::IslandOp island) {
467 if (failed(FormClustersInBlock(&island.GetBody(), metadata_map)))
468 return WalkResult::interrupt();
469
470 return WalkResult::advance();
471 });
472
473 if (island_result.wasInterrupted()) return signalPassFailure();
474
475 // Remove TPUReplicatedInput and TPUReplicatedOutput nodes.
476 auto remove_result = getFunction().walk([&](Operation* op) {
477 if (!llvm::isa<TF::TPUReplicatedInputOp>(op) &&
478 !llvm::isa<TF::TPUReplicatedOutputOp>(op))
479 return WalkResult::advance();
480
481 // Forward operand to result. When `num_replicas` attribute is 1, no
482 // `tf_device.replicate` is created and replicated (1) operands/results are
483 // untouched.
484 if (op->getNumOperands() == 1 && op->getNumResults() == 1)
485 op->getResult(0).replaceAllUsesWith(op->getOperand(0));
486
487 // Leftover TPUReplicatedInput/TPUReplicatedOutput that are not of
488 // `num_replicas` to 1.
489 if (!op->use_empty()) {
490 op->emitOpError() << "expects " << op->getName().getStringRef()
491 << " to have no uses";
492 return WalkResult::interrupt();
493 }
494
495 op->erase();
496
497 return WalkResult::advance();
498 });
499
500 if (remove_result.wasInterrupted()) return signalPassFailure();
501 }
502 } // anonymous namespace
503
CreateTPUClusterFormationPass()504 std::unique_ptr<OpPassBase<FuncOp>> CreateTPUClusterFormationPass() {
505 return std::make_unique<TPUClusterFormation>();
506 }
507
// Registers the pass under the name `tf-tpu-cluster-formation`, with the
// description given below.
static PassRegistration<TPUClusterFormation> pass(
    "tf-tpu-cluster-formation",
    "Form clusters from operations assigned to the same TPU cluster");
511
512 } // namespace TFTPU
513 } // namespace mlir
514