/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass takes ops with the same `_tpu_replicate` attribute
// in a block and clusters them together under a `tf_device::LaunchOp`.
// Associated TPUReplicateMetadata ops are removed and their attributes are
// copied over to the associated `tf_device::LaunchOp`. If a cluster should be
// replicated, the associated `tf_device::LaunchOp` is further wrapped with a
// `tf_device.replicate`. This pass also assumes ops of the same cluster do not
// have ops outside of the cluster that are both operands and results of the
// cluster. Note that this pass does not yet handle side effecting ops.
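//
// As an illustrative sketch (op names are hypothetical and the IR syntax is
// abbreviated, not exact), ops tagged with the same `_tpu_replicate`
// attribute:
//
//   %a = "tf.OpA"() {_tpu_replicate = "cluster0"} : ...
//   %b = "tf.OpB"(%a) {_tpu_replicate = "cluster0"} : ...
//
// are grouped into the body of a single `tf_device::LaunchOp`, whose
// `tf_device.return` terminator forwards the results used outside the
// cluster:
//
//   %launch = "tf_device.launch"() ( {
//     %a = "tf.OpA"() : ...
//     %b = "tf.OpB"(%a) : ...
//     "tf_device.return"(%b) : ...
//   }) {device = ""} : ...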

#include <iterator>
#include <memory>
#include <tuple>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Builders.h"  // TF:llvm-project
#include "mlir/IR/Identifier.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Operation.h"  // TF:llvm-project
#include "mlir/IR/Types.h"  // TF:llvm-project
#include "mlir/IR/Value.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Pass/PassRegistry.h"  // TF:llvm-project
#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
#include "mlir/Transforms/RegionUtils.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TFTPU {

namespace {

constexpr char kTPUReplicateAttr[] = "_tpu_replicate";
constexpr char kDeviceAttr[] = "device";
constexpr char kNameAttr[] = "name";
constexpr char kNumReplicasAttr[] = "num_replicas";
constexpr char kMirroredVariableIndicesAttr[] = "_mirrored_variable_indices";

constexpr char kBadTPUReplicateAttrMsg[] =
    "requires '_tpu_replicate' string attribute";

// Mapping for `_tpu_replicate` attribute to TPUReplicateMetadata attributes.
using MetadataMap = llvm::SmallDenseMap<llvm::StringRef, NamedAttributeList, 8>;

// Mapping for `_tpu_replicate` attribute to ops of a cluster.
using ClusterMap = llvm::SmallDenseMap<llvm::StringRef,
                                       llvm::SmallSetVector<Operation*, 8>, 8>;

struct TPUClusterFormation : public FunctionPass<TPUClusterFormation> {
  void runOnFunction() override;
};

// Creates a mapping from each TPUReplicateMetadata op's `_tpu_replicate`
// attribute to its attributes and removes the op. If multiple
// TPUReplicateMetadata ops have the same `_tpu_replicate` attribute, an error
// is returned.
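// For example (hypothetical attribute values), two TPUReplicateMetadata ops
// that both carry `_tpu_replicate = "cluster0"` trigger the duplicate-metadata
// error below, while a metadata op with a missing, non-string, or empty
// `_tpu_replicate` attribute triggers kBadTPUReplicateAttrMsg.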
LogicalResult CollectMetadata(Operation* op, MetadataMap* metadata_map) {
  auto result =
      op->walk([&](TF::TPUReplicateMetadataOp metadata_op) -> WalkResult {
        NamedAttributeList attrs = metadata_op.getAttrs();

        // Missing or bad `_tpu_replicate` attribute.
        auto tpu_replicate_attr = attrs.get(kTPUReplicateAttr);
        if (!tpu_replicate_attr)
          return metadata_op.emitError() << kBadTPUReplicateAttrMsg;

        auto tpu_replicate_attr_str = tpu_replicate_attr.dyn_cast<StringAttr>();
        if (!tpu_replicate_attr_str ||
            tpu_replicate_attr_str.getValue().empty())
          return metadata_op.emitError() << kBadTPUReplicateAttrMsg;

        // Remove `name` attribute.
        attrs.remove(Identifier::get(kNameAttr, metadata_op.getContext()));

        auto it = metadata_map->try_emplace(tpu_replicate_attr_str.getValue(),
                                            std::move(attrs));

        // There are multiple TPUReplicateMetadata ops with the same
        // `_tpu_replicate` attribute.
        if (!it.second) {
          return metadata_op.emitError()
                 << "multiple TPUReplicateMetadata ops with the same '"
                 << kTPUReplicateAttr << "' attribute '"
                 << tpu_replicate_attr_str.getValue() << "' found";
        }

        metadata_op.erase();
        return WalkResult::advance();
      });

  // Return failure if the walk was interrupted.
  return failure(result.wasInterrupted());
}

// Collects and clusters ops with the same `_tpu_replicate` attribute. An error
// is returned if an op's `_tpu_replicate` attribute is empty.
LogicalResult CollectAndGroupClusterOps(Block* block, ClusterMap* clusters) {
  for (Operation& op : *block) {
    if (auto attr = op.getAttrOfType<StringAttr>(kTPUReplicateAttr)) {
      if (attr.getValue().empty())
        return op.emitError()
               << "attribute '" << kTPUReplicateAttr << "' is empty";

      auto it = clusters->try_emplace(attr.getValue());
      it.first->getSecond().insert(&op);
    }
  }

  return success();
}

// Checks if an op should be moved after a cluster. There may be users of a
// cluster interleaved among the cluster ops.
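// For example (hypothetical ops), given the block order
//   [OpA {_tpu_replicate}, OpX(using OpA's result), OpB {_tpu_replicate}],
// OpX uses a cluster op's result but appears before the end of the cluster,
// so it should be moved after the formed `tf_device::LaunchOp`.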
bool ShouldMoveOpAfterCluster(
    Block* block, Operation* op,
    const llvm::SmallSetVector<Operation*, 8>& cluster_ops,
    const llvm::SmallSetVector<Operation*, 8>& preceding_users) {
  auto result = op->walk([&](Operation* op) {
    for (Value operand : op->getOperands()) {
      Operation* def = operand.getDefiningOp();
      // Operands may not have a defining op (BlockArgument) or may be defined
      // in a different block.
      if (!def || def->getBlock() != block) continue;

      if (cluster_ops.count(def) != 0 || preceding_users.count(def) != 0) {
        // Op is a user of a cluster op or of another op that is a user of the
        // cluster (transitively), but appears before the cluster.
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });

  return result.wasInterrupted();
}

// Collects ops that appear before the end of the cluster but are users of
// other ops in the cluster. This may happen because users of individual ops
// in the cluster may be interleaved with other ops in the cluster.
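// Continuing the sketch above, once OpX is collected as a preceding user, an
// op OpY that uses OpX's result and also appears before the end of the
// cluster is collected as well, since ShouldMoveOpAfterCluster also checks
// `preceding_users` (transitive users of the cluster).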
llvm::SmallSetVector<Operation*, 8> CollectClusterPrecedingUsers(
    Block* block, const llvm::SmallSetVector<Operation*, 8>& cluster_ops) {
  llvm::SmallSetVector<Operation*, 8> preceding_users;

  for (Operation& op : llvm::make_range(Block::iterator(cluster_ops.front()),
                                        Block::iterator(cluster_ops.back())))
    if (cluster_ops.count(&op) == 0 &&
        ShouldMoveOpAfterCluster(block, &op, cluster_ops, preceding_users))
      preceding_users.insert(&op);

  return preceding_users;
}

// Collects results and associated types of the cluster that are used outside
// of the cluster. These results and types are used to create the cluster's
// `tf_device::LaunchOp` and associated terminator. Results that have no uses
// outside of the cluster (i.e. results of ops in the cluster are only consumed
// by other ops in the cluster) are pruned.
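// For example (hypothetical ops), if a cluster op defines %r and %r is used
// by both another cluster op and an op outside the cluster, %r becomes a
// launch result; if %r is only used inside the cluster, it is pruned and does
// not appear in the `tf_device.return`.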
llvm::SmallVector<Value, 8> CollectClusterResults(
    Block* block, const llvm::SmallSetVector<Operation*, 8>& cluster_ops) {
  llvm::SmallVector<Value, 8> results;

  for (Operation* op : cluster_ops) {
    for (Value result : op->getResults()) {
      for (Operation* user : result.getUsers()) {
        // Check if user is not an op in the cluster.
        if (cluster_ops.count(block->findAncestorOpInBlock(*user)) == 0) {
          results.push_back(result);
          break;
        }
      }
    }
  }

  return results;
}

// Creates a `tf_device::LaunchOp` to wrap cluster ops.
tf_device::LaunchOp CreateLaunchOpForCluster(Operation* last_cluster_op,
                                             llvm::ArrayRef<Value> results) {
  // The `tf_device::LaunchOp` is placed where the last op of the cluster is.
  OpBuilder builder(last_cluster_op);

  llvm::SmallVector<Type, 8> result_types;
  for (Value result : results) result_types.push_back(result.getType());

  // An empty string placeholder is used for the device, as it will later be
  // populated with the device of the associated TPUReplicateMetadata op.
  auto launch_op = builder.create<tf_device::LaunchOp>(
      last_cluster_op->getLoc(), builder.getStringAttr(""), result_types);

  launch_op.body().push_back(new Block);

  // Add terminator.
  builder.setInsertionPointToEnd(&launch_op.GetBody());
  builder.create<tf_device::ReturnOp>(last_cluster_op->getLoc(), results);

  return launch_op;
}

// Moves cluster ops to associated `tf_device.LaunchOp` body.
void MoveClusterOpsToLaunchOp(
    tf_device::LaunchOp launch_op,
    const llvm::SmallSetVector<Operation*, 8>& cluster_ops) {
  MLIRContext* context = launch_op.getContext();
  Operation* terminator = &launch_op.GetBody().back();

  for (Operation* cluster_op : cluster_ops) {
    // Remove `_tpu_replicate` and `device` attribute from ops in the cluster
    // as that information will be present in the `tf_device.LaunchOp`.
    cluster_op->removeAttr(Identifier::get(kTPUReplicateAttr, context));
    cluster_op->removeAttr(Identifier::get(kDeviceAttr, context));
    cluster_op->moveBefore(terminator);
  }
}

// Replaces uses of cluster op results outside of the cluster with the
// associated `tf_device::LaunchOp` results.
void UpdateLaunchOpResultExternalUses(tf_device::LaunchOp launch_op,
                                      llvm::ArrayRef<Value> results) {
  Block& launch_op_block = launch_op.GetBody();
  for (auto ret_vals : llvm::zip(results, launch_op.getResults())) {
    Value old_ret = std::get<0>(ret_vals);
    Value new_ret = std::get<1>(ret_vals);
    for (auto& use : llvm::make_early_inc_range(old_ret.getUses()))
      if (!launch_op_block.findAncestorOpInBlock(*use.getOwner()))
        use.set(new_ret);
  }
}

// Moves users of the cluster that are before the cluster to after the cluster.
void MovePrecedingClusterUsers(tf_device::LaunchOp launch_op,
                               llvm::ArrayRef<Operation*> preceding_users) {
  Operation* op_after_launch_op = launch_op.getOperation()->getNextNode();
  for (Operation* user : preceding_users) user->moveBefore(op_after_launch_op);
}

// Sorts `tf.TPUReplicatedInput` ops by `index` attribute. Ops with an `index`
// of -1 are always placed after ops with a non-negative `index`, and an
// arbitrary ordering is used as there are no dependencies on their relative
// ordering.
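// For example, inputs with `index` values [1, -1, 0] are sorted to
// [op with index 0, op with index 1, op with index -1]. An `index` outside
// [-1, number of inputs) is an error, as is leaving any position unfilled
// ("gap(s) found in indices").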
LogicalResult SortTPUReplicatedInputsByIndex(
    llvm::ArrayRef<Operation*> inputs,
    llvm::SmallVectorImpl<Operation*>* sorted_inputs) {
  const int input_size = inputs.size();
  sorted_inputs->resize(input_size, nullptr);
  int last_index = input_size - 1;

  for (Operation* input : inputs) {
    int64_t index =
        llvm::cast<TF::TPUReplicatedInputOp>(input).index().getLimitedValue();

    if (index >= input_size || index < -1)
      return input->emitError() << "'" << input->getName().getStringRef()
                                << "' index is not in range [-1, " << input_size
                                << "), got " << index;

    if (index == -1)
      (*sorted_inputs)[last_index--] = input;
    else
      (*sorted_inputs)[index] = input;
  }

  if (llvm::any_of(*sorted_inputs, [](Operation* op) { return op == nullptr; }))
    return inputs.front()->emitError()
           << "failed to sort '" << inputs.front()->getName().getStringRef()
           << "' ops, gap(s) found in indices";

  return success();
}

// Creates a `tf_device.replicate` to represent replication for the cluster, if
// necessary.
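// For example, with `num_replicas` = 2, each used `tf.TPUReplicatedInput` op
// must have 2 operands (one per replica) and becomes a replicated operand /
// block argument of the replicate op; result `i` of the launch then maps to
// the pair of replicate results (2 * i, 2 * i + 1), which replaces the results
// of the corresponding 2-result `tf.TPUReplicatedOutput` op.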
LogicalResult ReplicateCluster(tf_device::LaunchOp launch_op,
                               int num_replicas) {
  // No need to replicate.
  if (num_replicas == 1) return success();

  if (num_replicas < 1)
    return launch_op.emitError() << "requires '" << kNumReplicasAttr
                                 << "' int attribute to be at least 1";

  // Collect all used TPUReplicatedInput ops and sort by `index`.
  llvm::SmallSetVector<Operation*, 8> unique_replicated_input_ops;
  mlir::visitUsedValuesDefinedAbove(
      launch_op.body(), launch_op.body(), [&](mlir::OpOperand* operand) {
        Operation* def = operand->get().getDefiningOp();
        if (def && llvm::isa<TF::TPUReplicatedInputOp>(def))
          unique_replicated_input_ops.insert(def);
      });
  llvm::SmallVector<Operation*, 8> replicated_input_ops;
  if (failed(SortTPUReplicatedInputsByIndex(
          unique_replicated_input_ops.getArrayRef(), &replicated_input_ops)))
    return failure();

  // Indices of the replicate op's arguments that are mirrored variables.
  llvm::SmallVector<int64_t, 8> mirrored_variable_indices;

  // Check if number of operands of each used TPUReplicatedInput op matches
  // `num_replicas`. Collect all their operands and associated type for
  // creating the replicate op.
  llvm::SmallVector<std::pair<Operation::operand_range, Type>, 8>
      replicated_inputs;
  for (auto& pos_and_input : llvm::enumerate(replicated_input_ops)) {
    auto input = pos_and_input.value();
    if (input->getNumOperands() != num_replicas)
      return input->emitOpError() << "requires " << num_replicas << " operands";

    replicated_inputs.push_back(
        {input->getOperands(), input->getOperand(0).getType()});
    if (llvm::cast<TF::TPUReplicatedInputOp>(input).is_mirrored_variable())
      mirrored_variable_indices.push_back(pos_and_input.index());
  }

  // Create replicate op.
  OpBuilder builder(launch_op);
  auto replicate_op = builder.create<tf_device::ReplicateOp>(
      launch_op.getLoc(), num_replicas, llvm::ArrayRef<llvm::StringRef>(),
      replicated_inputs, launch_op.getResultTypes());
  if (!mirrored_variable_indices.empty())
    replicate_op.setAttr(kMirroredVariableIndicesAttr,
                         builder.getI64ArrayAttr(mirrored_variable_indices));

  // Replace replicated cluster results with replicate op results.
  for (auto result_and_idx : llvm::enumerate(launch_op.getResults())) {
    Value result = result_and_idx.value();
    int idx = result_and_idx.index();
    for (auto& use : result.getUses()) {
      Operation* def = use.getOwner();
      if (!def || !llvm::isa<TF::TPUReplicatedOutputOp>(def))
        return launch_op.emitError()
               << "requires output of " << launch_op.getOperationName()
               << " to lead to a 'tf.TPUReplicatedOutput' op";

      if (def->getNumResults() != num_replicas)
        return def->emitOpError() << "requires " << num_replicas << " results";

      auto replicate_outputs = llvm::make_range(
          std::next(replicate_op.result_begin(), idx * num_replicas),
          std::next(replicate_op.result_begin(), (idx + 1) * num_replicas));
      def->replaceAllUsesWith(replicate_outputs);
    }
  }

  // Update replicated inputs with replicate op block arguments.
  for (auto input_and_block_arg :
       llvm::zip(replicated_input_ops, replicate_op.GetBody().getArguments())) {
    Operation* input = std::get<0>(input_and_block_arg);
    Value block_arg = std::get<1>(input_and_block_arg);
    mlir::replaceAllUsesInRegionWith(input->getResult(0), block_arg,
                                     launch_op.body());
  }

  // Create terminator for replicate op and move launch into replicate.
  builder.setInsertionPointToEnd(&replicate_op.GetBody());
  auto return_op = builder.create<tf_device::ReturnOp>(replicate_op.getLoc(),
                                                       launch_op.getResults());
  launch_op.getOperation()->moveBefore(return_op);

  return success();
}

// Forms clusters with ops of the same `_tpu_replicate` attribute under a block.
//
// For a given block, clusters are formed via grouping ops by `_tpu_replicate`
// attributes.
// For every cluster formed:
//   1. Find associated TPUReplicateMetadata attributes with the same
//      `_tpu_replicate` attribute.
//   2. Find users not in cluster that are interleaved between cluster ops.
//   3. Find external uses of cluster ops.
//   4. Create `tf_device::LaunchOp` with results consisting of the external
//      uses of cluster ops determined at 3.
//   5. Move cluster ops to `tf_device::LaunchOp` body.
//   6. Replace external uses of cluster op results with `tf_device::LaunchOp`
//      results.
//   7. Move users from 2 to after the `tf_device::LaunchOp`.
//   8. Wrap cluster (`tf_device::LaunchOp`) in a `tf_device.replicate` if
//      attribute `num_replicas` is greater than 1.
//   9. Copy over TPUReplicateMetadata attributes to `tf_device::LaunchOp`.
LogicalResult FormClustersInBlock(Block* block,
                                  const MetadataMap& metadata_map) {
  ClusterMap clusters;
  LogicalResult result = CollectAndGroupClusterOps(block, &clusters);
  if (failed(result)) return result;

  for (const auto& cluster : clusters) {
    const auto& cluster_ops = cluster.getSecond();

    auto cluster_metadata = metadata_map.find(cluster.getFirst());

    // No TPUReplicateMetadata for a `_tpu_replicate` attribute.
    if (cluster_metadata == metadata_map.end()) {
      cluster_ops.front()->emitWarning()
          << "TPUReplicateMetadata for associated '" << kTPUReplicateAttr
          << "' attribute '" << cluster.getFirst() << "' is missing";
      continue;
    }

    llvm::SmallSetVector<Operation*, 8> preceding_users =
        CollectClusterPrecedingUsers(block, cluster_ops);

    llvm::SmallVector<Value, 8> results =
        CollectClusterResults(block, cluster_ops);

    tf_device::LaunchOp launch_op =
        CreateLaunchOpForCluster(cluster_ops.back(), results);

    MoveClusterOpsToLaunchOp(launch_op, cluster_ops);

    UpdateLaunchOpResultExternalUses(launch_op, results);

    MovePrecedingClusterUsers(launch_op, preceding_users.getArrayRef());

    auto num_replicas = cluster_metadata->getSecond().get(kNumReplicasAttr);
    if (!num_replicas || !num_replicas.isa<mlir::IntegerAttr>())
      return launch_op.emitError()
             << "requires '" << kNumReplicasAttr << "' int attribute";

    if (failed(ReplicateCluster(
            launch_op, num_replicas.cast<mlir::IntegerAttr>().getInt())))
      return failure();

    // Copy TPUReplicateMetadata attributes to launch.
    launch_op.setAttrs(cluster_metadata->second);
    // Exclude `num_replicas` as the cluster has already been replicated if
    // necessary.
    launch_op.removeAttr(kNumReplicasAttr);
  }

  return success();
}

void TPUClusterFormation::runOnFunction() {
  MetadataMap metadata_map;
  if (failed(CollectMetadata(getFunction(), &metadata_map)))
    return signalPassFailure();

  for (Block& block : getFunction())
    if (failed(FormClustersInBlock(&block, metadata_map)))
      return signalPassFailure();

  auto island_result = getFunction().walk([&](tf_executor::IslandOp island) {
    if (failed(FormClustersInBlock(&island.GetBody(), metadata_map)))
      return WalkResult::interrupt();

    return WalkResult::advance();
  });

  if (island_result.wasInterrupted()) return signalPassFailure();

  // Remove TPUReplicatedInput and TPUReplicatedOutput nodes.
  auto remove_result = getFunction().walk([&](Operation* op) {
    if (!llvm::isa<TF::TPUReplicatedInputOp>(op) &&
        !llvm::isa<TF::TPUReplicatedOutputOp>(op))
      return WalkResult::advance();

    // Forward the operand to the result. When the `num_replicas` attribute is
    // 1, no `tf_device.replicate` is created, so the replicated (1)
    // operands/results are left untouched.
    if (op->getNumOperands() == 1 && op->getNumResults() == 1)
      op->getResult(0).replaceAllUsesWith(op->getOperand(0));

    // Leftover TPUReplicatedInput/TPUReplicatedOutput ops that still have uses
    // (i.e. did not come from a cluster with `num_replicas` of 1) are invalid.
    if (!op->use_empty()) {
      op->emitOpError() << "expects " << op->getName().getStringRef()
                        << " to have no uses";
      return WalkResult::interrupt();
    }

    op->erase();

    return WalkResult::advance();
  });

  if (remove_result.wasInterrupted()) return signalPassFailure();
}
}  // anonymous namespace

std::unique_ptr<OpPassBase<FuncOp>> CreateTPUClusterFormationPass() {
  return std::make_unique<TPUClusterFormation>();
}

static PassRegistration<TPUClusterFormation> pass(
    "tf-tpu-cluster-formation",
    "Form clusters from operations assigned to the same TPU cluster");

}  // namespace TFTPU
}  // namespace mlir