• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <iterator>
18 #include <memory>
19 #include <tuple>
20 #include <utility>
21 
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/DenseMap.h"
24 #include "llvm/ADT/DenseSet.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/StringRef.h"
29 #include "llvm/ADT/iterator_range.h"
30 #include "llvm/Support/Casting.h"
31 #include "mlir/IR/Attributes.h"  // from @llvm-project
32 #include "mlir/IR/Builders.h"  // from @llvm-project
33 #include "mlir/IR/Identifier.h"  // from @llvm-project
34 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
35 #include "mlir/IR/Operation.h"  // from @llvm-project
36 #include "mlir/IR/Types.h"  // from @llvm-project
37 #include "mlir/IR/Value.h"  // from @llvm-project
38 #include "mlir/Pass/Pass.h"  // from @llvm-project
39 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
40 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
41 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
42 #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
43 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
44 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
45 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
46 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
47 
48 namespace mlir {
49 namespace TFTPU {
50 
51 namespace {
52 
// Attribute whose (non-empty) string value groups ops into one TPU cluster.
constexpr char kTPUReplicateAttr[] = "_tpu_replicate";
// Device placement attribute; stripped from clustered ops unless it names a
// TPU replicated core device (see CreateClusterOp).
constexpr char kDeviceAttr[] = "device";
// `name` attribute on TPUReplicateMetadata; dropped when collecting metadata.
constexpr char kNameAttr[] = "name";
constexpr char kNumCoresPerReplicaAttr[] = "num_cores_per_replica";
constexpr char kNumReplicasAttr[] = "num_replicas";
// Output attributes attached to the created `tf_device.replicate` op.
constexpr char kReplicatedInputIndicesAttr[] = "_replicated_input_indices";
constexpr char kMirroredVariableIndicesAttr[] = "_mirrored_variable_indices";

// Shared error message for a missing/non-string/empty `_tpu_replicate`.
constexpr char kBadTPUReplicateAttrMsg[] =
    "requires '_tpu_replicate' string attribute";

// Mapping for `_tpu_replicate` attribute to TPUReplicateMetadata attributes.
using MetadataMap = llvm::SmallDenseMap<llvm::StringRef, NamedAttrList, 8>;

// A set of operations in a cluster.
using ClusterOps = llvm::SmallSetVector<Operation*, 8>;

// Mapping for `_tpu_replicate` attribute to ops of a cluster.
using ClusterMap = llvm::SmallDenseMap<llvm::StringRef, ClusterOps, 8>;
72 
// Pass that groups ops sharing a `_tpu_replicate` attribute into
// `tf_device.cluster` ops and, when `num_replicas` > 1, wraps each cluster in
// a `tf_device.replicate` op.
struct TPUClusterFormationPass
    : public TF::TPUClusterFormationPassBase<TPUClusterFormationPass> {
  // This pass creates `tf_device` ops, so the dialect must be registered
  // before the pass runs.
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<tf_device::TensorFlowDeviceDialect>();
  }

  void runOnOperation() override;
};
81 
82 // Creates a mapping from the TPUReplicateMetadata ops `_tpu_replicate`
83 // attribute to its attributes and removes the ops. If multiple
84 // TPUReplicateMetadata ops have the same `_tpu_replicate` attribute, an error
85 // will be returned.
CollectMetadata(Block * block,MetadataMap * metadata_map)86 LogicalResult CollectMetadata(Block* block, MetadataMap* metadata_map) {
87   // Just look at top-level operations in the block (not nested ones)
88   for (Operation& op : llvm::make_early_inc_range(*block)) {
89     auto metadata_op = dyn_cast<TF::TPUReplicateMetadataOp>(op);
90     if (!metadata_op) continue;
91 
92     NamedAttrList attrs(metadata_op->getAttrDictionary());
93 
94     // Missing or bad `_tpu_replicate` attribute.
95     auto tpu_replicate_attr = attrs.get(kTPUReplicateAttr);
96     if (!tpu_replicate_attr)
97       return metadata_op.emitError() << kBadTPUReplicateAttrMsg;
98 
99     auto tpu_replicate_attr_str = tpu_replicate_attr.dyn_cast<StringAttr>();
100     if (!tpu_replicate_attr_str || tpu_replicate_attr_str.getValue().empty())
101       return metadata_op.emitError() << kBadTPUReplicateAttrMsg;
102 
103     // Remove `name` attribute.
104     attrs.erase(Identifier::get(kNameAttr, metadata_op.getContext()));
105 
106     auto it = metadata_map->try_emplace(tpu_replicate_attr_str.getValue(),
107                                         std::move(attrs));
108 
109     // There are multiple TPUReplicateMetadata ops with the same
110     // `_tpu_replicate` attribute.
111     if (!it.second) {
112       return metadata_op.emitError()
113              << "multiple TPUReplicateMetadata ops with the same '"
114              << kTPUReplicateAttr << "' attribute '"
115              << tpu_replicate_attr_str.getValue() << "' found";
116     }
117     metadata_op.erase();
118   }
119   return success();
120 }
121 
122 // Collects and clusters ops with the same `_tpu_replicate` attribute. This will
123 // return an error if a `_tpu_replicate` attribute of an op is empty.
CollectAndGroupClusterOps(Block * block,ClusterMap * clusters)124 LogicalResult CollectAndGroupClusterOps(Block* block, ClusterMap* clusters) {
125   for (Operation& op : *block) {
126     if (auto attr = op.getAttrOfType<StringAttr>(kTPUReplicateAttr)) {
127       if (attr.getValue().empty())
128         return op.emitError()
129                << "attribute '" << kTPUReplicateAttr << "' is empty";
130 
131       auto it = clusters->try_emplace(attr.getValue());
132       it.first->getSecond().insert(&op);
133     }
134   }
135 
136   return success();
137 }
138 
139 // Collects all resource ids from an op.
CollectResourceIdsFromOp(Operation & op,const TF::ResourceAliasAnalysis::Info & resource_alias_analysis,llvm::SmallDenseSet<int64_t> & observed_resource_ids)140 void CollectResourceIdsFromOp(
141     Operation& op,
142     const TF::ResourceAliasAnalysis::Info& resource_alias_analysis,
143     llvm::SmallDenseSet<int64_t>& observed_resource_ids) {
144   op.walk([&](Operation* inner_op) {
145     for (Value operand : TF::filter_resources(inner_op->getOperands())) {
146       if (resource_alias_analysis.IsUnknownResource(operand)) continue;
147       const auto& ids = resource_alias_analysis.GetResourceUniqueIds(operand);
148       observed_resource_ids.insert(ids.begin(), ids.end());
149     }
150     for (Value result : TF::filter_resources(inner_op->getResults())) {
151       if (resource_alias_analysis.IsUnknownResource(result)) continue;
152       const auto& ids = resource_alias_analysis.GetResourceUniqueIds(result);
153       observed_resource_ids.insert(ids.begin(), ids.end());
154     }
155   });
156 }
157 
158 // Checks if an op should be moved after a cluster. There may be users of a
159 // cluster interleaved among the cluster ops.
// Returns true if `op` must be moved after the cluster, i.e. if `op` (or any
// op nested in it) consumes a value defined by a cluster op or by an already
// identified preceding user, or touches a resource id the cluster has already
// been observed to use. The walk is interrupted on the first such dependency.
bool ShouldMoveOpAfterCluster(
    Block* block, Operation* op, const ClusterOps& cluster_ops,
    const llvm::SmallSetVector<Operation*, 8>& preceding_users,
    const TF::ResourceAliasAnalysis::Info& resource_alias_analysis,
    const llvm::SmallDenseSet<int64_t>& observed_resource_ids) {
  const bool is_replicate = llvm::isa<tf_device::ReplicateOp>(op);
  auto result = op->walk([&](Operation* inner_op) {
    for (Value operand : inner_op->getOperands()) {
      Operation* def = operand.getDefiningOp();
      // Operands may not have a defining op (BlockArgument) or is from a
      // different block.
      if (!def || def->getBlock() != block) continue;

      if (cluster_ops.count(def) != 0 || preceding_users.count(def) != 0) {
        // Op is a user of a cluster or another op that is a user of the
        // cluster (transitively), but is before the cluster.
        return WalkResult::interrupt();
      }
    }

    // Don't visit replicate op inner op operands as new resource
    // values/arguments may have been created but are not known in
    // `resource_alias_analysis`.
    if (is_replicate && inner_op != op) return WalkResult::advance();

    // Check for uses of any resource in or after cluster.
    for (Value operand : TF::filter_resources(inner_op->getOperands())) {
      if (resource_alias_analysis.IsUnknownResource(operand)) continue;
      auto ids = resource_alias_analysis.GetResourceUniqueIds(operand);
      for (const auto& id : ids)
        if (observed_resource_ids.contains(id)) return WalkResult::interrupt();
    }
    return WalkResult::advance();
  });

  // Any interruption above means a dependency on the cluster was found.
  return result.wasInterrupted();
}
197 
198 // Collects ops that are before ops in the cluster but are users of other ops
199 // in the cluster. This may happen because users of individual ops in the
200 // cluster may be interleaved with other ops in the cluster. Resource id's are
201 // also captured, to keep track of resource usage before, in, or after the
202 // cluster.
203 // TODO(b/175701589): Extend this to handle all side effecting ops while
204 // handling transitive data dependencies.
CollectClusterPrecedingUsers(Block * block,const ClusterOps & cluster_ops,const TF::ResourceAliasAnalysis::Info & resource_alias_analysis)205 llvm::SmallSetVector<Operation*, 8> CollectClusterPrecedingUsers(
206     Block* block, const ClusterOps& cluster_ops,
207     const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
208   llvm::SmallSetVector<Operation*, 8> preceding_users;
209   llvm::SmallDenseSet<int64_t> observed_resource_ids;
210 
211   auto front = Block::iterator(cluster_ops.front());
212   auto back = Block::iterator(cluster_ops.back());
213   for (Operation& op : llvm::make_range(front, back)) {
214     if (cluster_ops.contains(&op)) {
215       CollectResourceIdsFromOp(op, resource_alias_analysis,
216                                observed_resource_ids);
217     } else if (ShouldMoveOpAfterCluster(
218                    block, &op, cluster_ops, preceding_users,
219                    resource_alias_analysis, observed_resource_ids)) {
220       preceding_users.insert(&op);
221       CollectResourceIdsFromOp(op, resource_alias_analysis,
222                                observed_resource_ids);
223     }
224   }
225 
226   return preceding_users;
227 }
228 
229 // Collects results and associated types of the cluster that are used outside of
230 // the cluster. These results and types are used to create the clusters
231 // `tf_device.cluster` and associated terminator. Results that have no uses
232 // outside of the cluster (i.e. results of ops in the cluster are only consumed
233 // by other ops in the cluster) are pruned.
CollectClusterResults(Block * block,const ClusterOps & cluster_ops)234 llvm::SmallVector<Value, 8> CollectClusterResults(
235     Block* block, const ClusterOps& cluster_ops) {
236   llvm::SmallVector<Value, 8> results;
237 
238   for (Operation* op : cluster_ops) {
239     for (Value result : op->getResults()) {
240       for (Operation* user : result.getUsers()) {
241         // Check if user is not an op in the cluster.
242         if (cluster_ops.count(block->findAncestorOpInBlock(*user)) == 0) {
243           results.push_back(result);
244           break;
245         }
246       }
247     }
248   }
249 
250   return results;
251 }
252 
253 // Creates a `tf_device.cluster` to wrap cluster ops.
// Creates a `tf_device.cluster` wrapping `cluster_ops`, placed at the last
// cluster op. Moves the cluster ops into the new body, adds a
// `tf_device.return` over `results`, rewires external uses to the cluster's
// results, and moves `preceding_users` to just after the cluster.
tf_device::ClusterOp CreateClusterOp(
    Block* block, const ClusterOps& cluster_ops, llvm::ArrayRef<Value> results,
    llvm::ArrayRef<Operation*> preceding_users) {
  // `tf_device.cluster` will be placed at where the last op of the cluster is.
  Operation* last_cluster_op = cluster_ops.back();
  OpBuilder builder(last_cluster_op);

  // Cluster result types mirror the exported values one-to-one.
  llvm::SmallVector<Type, 8> result_types;
  for (Value result : results) result_types.push_back(result.getType());
  auto cluster = builder.create<tf_device::ClusterOp>(last_cluster_op->getLoc(),
                                                      result_types);

  // The body block is owned by the cluster's region after push_back.
  Block* body = new Block;
  cluster.body().push_back(body);

  // Move cluster ops to the cluster body. Also remove `_tpu_replicate` and
  // `device` attribute from ops in the cluster when that information is
  // redundant will the `tf_device.cluster`. Do this for all ops including
  // nested ops.
  for (Operation* cluster_op : cluster_ops) {
    cluster_op->moveBefore(body, body->end());
    cluster_op->walk([&](Operation* inner_op) {
      inner_op->removeAttr(kTPUReplicateAttr);

      if (auto attr = inner_op->getAttrOfType<StringAttr>(kDeviceAttr)) {
        // Preserve device attribute if the op is placed on a replicated core
        // device. Device attribute is used to infer the appropriate sharding
        // within TPUs for this op.
        // TODO(b/183598857): Use explicit sharding ops from the front-end.
        // For example, dequeue ops generated by
        // tensorflow/python/tpu/tpu_feed.py
        if (!tensorflow::IsTPUReplicatedCore(attr.getValue())) {
          inner_op->removeAttr(kDeviceAttr);
        }
      }
    });
  }

  // Add terminator.
  builder.setInsertionPointToEnd(body);
  builder.create<tf_device::ReturnOp>(last_cluster_op->getLoc(), results);

  // Replaces uses of cluster ops results outside of cluster with the associated
  // `tf_device.cluster` results.
  for (auto ret_vals : llvm::zip(results, cluster.getResults())) {
    Value old_ret = std::get<0>(ret_vals);
    Value new_ret = std::get<1>(ret_vals);
    for (auto& use : llvm::make_early_inc_range(old_ret.getUses())) {
      Operation* user = use.getOwner();
      // Uses inside the cluster body (including nested) keep the original
      // value; only external uses are rewired.
      if (!body->findAncestorOpInBlock(*user)) use.set(new_ret);
    }
  }

  // Move users of cluster that are before the cluster to after the cluster.
  Operation* op_after_cluster = cluster.getOperation()->getNextNode();
  for (Operation* user : preceding_users) user->moveBefore(op_after_cluster);
  return cluster;
}
312 
313 // Sorts `tf.TPUReplicatedInput` ops by `index` attribute. Ops with an `index`
314 // of -1 are always after ops with a non negative `index`, and an arbitrary
315 // ordering is used as there are no dependencies on their relative ordering. If
316 // there are multiple `tf.TPUReplicatedInput` ops with the same non negative
317 // index or if indices are less than -1, an error will be returned.
SortTPUReplicatedInputsByIndex(llvm::ArrayRef<Operation * > inputs,llvm::SmallVectorImpl<Operation * > * sorted_inputs)318 LogicalResult SortTPUReplicatedInputsByIndex(
319     llvm::ArrayRef<Operation*> inputs,
320     llvm::SmallVectorImpl<Operation*>* sorted_inputs) {
321   llvm::SmallDenseSet<int64_t, 8> unique_indices;
322   for (Operation* input : inputs) {
323     int64_t index = llvm::cast<TF::TPUReplicatedInputOp>(input).index();
324     if (index < -1)
325       return input->emitOpError()
326              << "requires index to be at least -1, but got " << index;
327     if (index == -1) continue;
328     if (!unique_indices.insert(index).second)
329       return input->emitOpError()
330              << "requires indices to be unique, but found multiple '"
331              << input->getName() << "' ops with index " << index;
332   }
333 
334   // Sort all TPUReplicatedInputs by `index` attribute to have
335   // TPUReplicatedInputs with indices be added to the `tf_device.replicate` op
336   // deterministically. If `index` attribute is -1, instead move them to the
337   // end.
338   sorted_inputs->assign(inputs.begin(), inputs.end());
339   std::stable_sort(
340       sorted_inputs->begin(), sorted_inputs->end(),
341       [](Operation* l, Operation* r) {
342         int64_t l_index = llvm::cast<TF::TPUReplicatedInputOp>(l).index();
343         int64_t r_index = llvm::cast<TF::TPUReplicatedInputOp>(r).index();
344         if (l_index == -1 && r_index != -1) return false;
345         if (r_index == -1 && l_index != -1) return true;
346         return l_index < r_index;
347       });
348 
349   return success();
350 }
351 
352 // Creates a `tf_device.replicate` to represent replication for the cluster, if
353 // necessary.
// Wraps `cluster` in a `tf_device.replicate` op with `num_replicas` replicas,
// rewiring `tf.TPUReplicatedInput`/`tf.TPUReplicatedOutput` ops to the
// replicate op's block arguments/results. A no-op when `num_replicas` is 1;
// fails when `num_replicas` < 1 or the replicated input/output ops have
// operand/result counts inconsistent with `num_replicas`.
LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas,
                               int num_cores_per_replica) {
  // No need to replicate.
  if (num_replicas == 1) return success();

  if (num_replicas < 1)
    return cluster.emitError() << "requires '" << kNumReplicasAttr
                               << "' int attribute to be at least 1";

  // Deferred error from inside the value-visit callback below, checked after
  // the traversal completes.
  LogicalResult status = success();
  // Collect all used TPUReplicatedInput ops and sort by `index`.
  llvm::SmallSetVector<Operation*, 8> unique_replicated_input_ops;
  mlir::visitUsedValuesDefinedAbove(
      cluster.body(), cluster.body(), [&](mlir::OpOperand* operand) {
        Operation* def = operand->get().getDefiningOp();
        if (llvm::isa_and_nonnull<TF::TPUReplicatedInputOp>(def))
          unique_replicated_input_ops.insert(def);
        // When model parallelism is used in conjunction with data parallelism
        // for resource inputs, we need to collect the per replica resource
        // inputs from input to `tf.TPUPartitionedInput` ops.
        if (auto pi = llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>(def)) {
          if (pi->getNumOperands() != num_cores_per_replica)
            status = pi.emitOpError()
                     << "requires " << num_cores_per_replica
                     << " operands but found " << pi->getNumOperands();
          for (auto operand : pi.inputs()) {
            if (llvm::isa_and_nonnull<TF::TPUReplicatedInputOp>(
                    operand.getDefiningOp()))
              unique_replicated_input_ops.insert(operand.getDefiningOp());
          }
        }
      });

  if (failed(status)) return failure();
  llvm::SmallVector<Operation*, 8> replicated_input_ops;
  if (failed(SortTPUReplicatedInputsByIndex(
          unique_replicated_input_ops.getArrayRef(), &replicated_input_ops)))
    return failure();

  // Index attribute value stored on TPUReplicatedInput op. These will be used
  // later for dynamic padder.
  llvm::SmallVector<int64_t, 8> replicated_input_indices;
  llvm::SmallVector<int64_t, 8> packed_input_indices;
  bool has_replicated_input_index = false;

  // Indices of the replicate op's arguments that are mirrored variables.
  llvm::SmallVector<int64_t, 8> mirrored_variable_indices;

  // Check if number of operands of each used TPUReplicatedInput op matches
  // `num_replicas` or 1. Collect all their operands and associated type for
  // creating the replicate op.
  llvm::SmallVector<std::pair<ValueRange, Type>, 8> replicated_inputs;
  llvm::SmallVector<Value, 8> packed_inputs;
  llvm::SmallVector<Operation*, 8> replicated_ops;
  llvm::SmallVector<Operation*, 8> packed_ops;
  for (auto& pos_and_input : llvm::enumerate(replicated_input_ops)) {
    auto input = pos_and_input.value();
    bool is_packed = llvm::cast<TF::TPUReplicatedInputOp>(input).is_packed();
    const int num_operands = input->getNumOperands();
    // Packed inputs carry a single operand shared across replicas; otherwise
    // one operand per replica is required.
    int num_inputs = is_packed ? 1 : num_replicas;
    if (num_operands != num_inputs)
      return input->emitOpError() << "requires " << num_inputs << " operands";

    auto tpu_replicated_input = llvm::cast<TF::TPUReplicatedInputOp>(input);
    int64_t tpu_replicated_input_index = tpu_replicated_input.index();
    if (is_packed) {
      packed_inputs.push_back(input->getOperand(0));
      packed_input_indices.push_back(tpu_replicated_input_index);
      packed_ops.push_back(input);
    } else {
      replicated_inputs.push_back(
          {input->getOperands(), input->getOperand(0).getType()});
      replicated_input_indices.push_back(tpu_replicated_input_index);
      replicated_ops.push_back(input);
    }
    if (tpu_replicated_input_index != -1) has_replicated_input_index = true;

    if (tpu_replicated_input.is_mirrored_variable())
      mirrored_variable_indices.push_back(pos_and_input.index());
  }

  // Packed indices follow replicated ones, matching the replicate op's
  // argument order (replicated block args first, then packed).
  replicated_input_indices.append(packed_input_indices.begin(),
                                  packed_input_indices.end());

  // Create replicate op.
  OpBuilder builder(cluster);
  auto replicate_op = builder.create<tf_device::ReplicateOp>(
      cluster.getLoc(), num_replicas,
      llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>(),
      replicated_inputs, packed_inputs, cluster.getResultTypes());
  if (has_replicated_input_index)
    replicate_op->setAttr(kReplicatedInputIndicesAttr,
                          builder.getI64ArrayAttr(replicated_input_indices));

  if (!mirrored_variable_indices.empty())
    replicate_op->setAttr(kMirroredVariableIndicesAttr,
                          builder.getI64ArrayAttr(mirrored_variable_indices));

  // Replace replicated cluster results with replicate op results.
  for (auto result_and_idx : llvm::enumerate(cluster.getResults())) {
    Value result = result_and_idx.value();
    int idx = result_and_idx.index();
    // Replicate results are laid out contiguously per cluster result:
    // [idx * num_replicas, (idx + 1) * num_replicas).
    auto replicate_outputs = llvm::make_range(
        std::next(replicate_op.result_begin(), idx * num_replicas),
        std::next(replicate_op.result_begin(), (idx + 1) * num_replicas));

    for (auto& use : llvm::make_early_inc_range(result.getUses())) {
      Operation* def = use.getOwner();
      if (!llvm::isa<TF::TPUReplicatedOutputOp>(def)) {
        // If user is not a `tf.TPUReplicatedOutput`, simply forward the first
        // replica output. Certain Graphs under V1 create `tf.Identity` users of
        // replicated ops to pin the TPU computation for execution.
        use.set(*replicate_outputs.begin());
        continue;
      }

      const int def_num_results = def->getNumResults();
      if (def_num_results != num_replicas)
        return def->emitOpError() << "requires " << num_replicas << " results";

      def->replaceAllUsesWith(replicate_outputs);
    }
  }

  // Collect all `tf.TPUPartitionedInput` ops to be moved inside the
  // `tf_device.replicate` later.
  llvm::SmallSet<Operation*, 4> partitioned_inputs;
  // Update replicated inputs with replicate op block arguments.
  auto ordered_tpu_replicate_inputs =
      llvm::concat<Operation*>(replicated_ops, packed_ops);
  for (auto input_and_block_arg :
       llvm::zip(ordered_tpu_replicate_inputs,
                 replicate_op.GetBody().getArguments())) {
    Operation* input = std::get<0>(input_and_block_arg);
    Value block_arg = std::get<1>(input_and_block_arg);
    mlir::replaceAllUsesInRegionWith(input->getResult(0), block_arg,
                                     cluster.body());
    // Update replicated input use in tf.TPUPartitionedInput op.
    for (auto& use : input->getUses()) {
      auto pi = llvm::dyn_cast<TF::TPUPartitionedInputOp>(use.getOwner());
      if (pi) {
        pi.setOperand(use.getOperandNumber(), block_arg);
        partitioned_inputs.insert(pi.getOperation());
      }
    }
  }

  // Create terminator for replicate op and move `tf_device.cluster` and
  // `tf.TPUPartitionedInput`(s) into replicate body.
  builder.setInsertionPointToEnd(&replicate_op.GetBody());
  auto return_op = builder.create<tf_device::ReturnOp>(replicate_op.getLoc(),
                                                       cluster.getResults());
  for (auto pi : partitioned_inputs) pi->moveBefore(return_op);

  cluster.getOperation()->moveBefore(return_op);

  return success();
}
512 
513 // Forms clusters with ops of the same `_tpu_replicate` attribute under a block.
514 //
515 // For a given block, clusters are formed via grouping ops by `_tpu_replicate`
516 // attributes.
517 // For every cluster formed:
518 //   1. Find associated TPUReplicateMetadata attributes with the same
519 //      `_tpu_replicate` attribute.
520 //   2. Find users not in cluster that are interleaved between cluster ops.
521 //   3. Find external uses of cluster ops.
522 //   4. Create `tf_device.cluster` with results consisting of the external uses
523 //      of cluster ops determined at 3.
524 //   5. Move cluster ops to `tf_device.cluster` body.
525 //   6. Replace external uses of cluster ops uses with `tf_device.cluster`
526 //      results.
527 //   7. Move users from 2 to after the `tf_device.cluster`.
528 //   8. Wrap cluster (`tf_device.cluster`) in a `tf_device.replicate` if
529 //      attribute `num_replicas` is greater than 1.
530 //   9. Copy over TPUReplicateMetadata attributes to `tf_device.cluster`.
// Forms `tf_device.cluster` ops (and, when replicated, `tf_device.replicate`
// wrappers) for every `_tpu_replicate` group found in `block`. When no
// TPUReplicateMetadata op exists at this level, recurses into single-block
// regions of the block's ops instead. See the file-level comment above for
// the step-by-step algorithm.
LogicalResult FormClustersInBlock(
    Block* block,
    const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
  MetadataMap metadata_map;
  LogicalResult result = CollectMetadata(block, &metadata_map);
  if (failed(result)) return result;

  // If there is no TPUReplicateMetadata op in this block, process blocks in
  // regions attached to the op's in the block.
  if (metadata_map.empty()) {
    for (Operation& op : *block) {
      for (Region& region : op.getRegions()) {
        if (!llvm::hasSingleElement(region))
          return op.emitOpError("Expected single block region");
        if (failed(
                FormClustersInBlock(&region.front(), resource_alias_analysis)))
          return failure();
      }
    }
    return success();
  }

  ClusterMap clusters;
  result = CollectAndGroupClusterOps(block, &clusters);
  if (failed(result)) return result;

  for (const auto& cluster_metadata_and_ops : clusters) {
    const auto& cluster_ops = cluster_metadata_and_ops.getSecond();

    auto cluster_metadata =
        metadata_map.find(cluster_metadata_and_ops.getFirst());

    // No TPUReplicateMetadata for a `_tpu_replicate` attribute.
    // Only a warning: such clusters are skipped, not treated as errors.
    if (cluster_metadata == metadata_map.end()) {
      cluster_ops.front()->emitWarning()
          << "TPUReplicateMetadata for associated '" << kTPUReplicateAttr
          << "' attribute '" << cluster_metadata_and_ops.getFirst()
          << "' is missing";
      continue;
    }

    llvm::SmallSetVector<Operation*, 8> preceding_users =
        CollectClusterPrecedingUsers(block, cluster_ops,
                                     resource_alias_analysis);

    llvm::SmallVector<Value, 8> results =
        CollectClusterResults(block, cluster_ops);

    tf_device::ClusterOp cluster = CreateClusterOp(
        block, cluster_ops, results, preceding_users.getArrayRef());

    auto num_replicas = cluster_metadata->getSecond().get(kNumReplicasAttr);
    if (!num_replicas || !num_replicas.isa<mlir::IntegerAttr>())
      return cluster.emitError()
             << "requires '" << kNumReplicasAttr << "' int attribute";

    // `num_cores_per_replica` defaults to 1 when the attribute is absent or
    // not an integer.
    int num_cores_per_replica = 1;
    auto num_cores_per_replica_attr =
        cluster_metadata->getSecond()
            .get(kNumCoresPerReplicaAttr)
            .dyn_cast_or_null<mlir::IntegerAttr>();
    if (num_cores_per_replica_attr)
      num_cores_per_replica = num_cores_per_replica_attr.getInt();

    if (failed(ReplicateCluster(cluster,
                                num_replicas.cast<mlir::IntegerAttr>().getInt(),
                                num_cores_per_replica)))
      return failure();

    // Copy TPUReplicateMetadata attributes to `tf_device.cluster`.
    cluster->setAttrs(
        cluster_metadata->second.getDictionary(cluster.getContext()));
    // Exclude `num_replicas` as cluster should be replicated if necessary.
    cluster->removeAttr(kNumReplicasAttr);
  }

  return success();
}
609 
// Forms clusters in the (single) block of `func`, then erases now-redundant
// `tf.TPUReplicatedInput`/`tf.TPUReplicatedOutput` ops. Fails if the function
// has more than one block or if a leftover replicated input/output op still
// has uses after forwarding.
LogicalResult FormClustersInFunction(
    FuncOp func,
    const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
  if (!llvm::hasSingleElement(func))
    return func.emitOpError("Expecting a single block function");

  if (failed(FormClustersInBlock(&func.front(), resource_alias_analysis)))
    return failure();

  // Remove TPUReplicatedInput and TPUReplicatedOutput nodes.
  auto remove_result = func.walk([&](Operation* op) {
    if (!llvm::isa<TF::TPUReplicatedInputOp, TF::TPUReplicatedOutputOp>(op))
      return WalkResult::advance();

    // Forward operand to result. When `num_replicas` attribute is 1, no
    // `tf_device.replicate` is created and replicated (1) operands/results are
    // untouched.
    if (op->getNumOperands() == 1 && op->getNumResults() == 1)
      op->getResult(0).replaceAllUsesWith(op->getOperand(0));

    // Leftover TPUReplicatedInput/TPUReplicatedOutput that are not of
    // `num_replicas` to 1.
    if (!op->use_empty()) {
      op->emitOpError() << "is expected to have no uses, but it is operand#"
                        << op->use_begin()->getOperandNumber() << " of "
                        << *op->use_begin()->getOwner();
      return WalkResult::interrupt();
    }

    op->erase();

    return WalkResult::advance();
  });

  // An interrupted walk means a leftover op with uses was found above.
  return failure(remove_result.wasInterrupted());
}
646 
runOnOperation()647 void TPUClusterFormationPass::runOnOperation() {
648   auto& resource_alias_analysis = getAnalysis<TF::ResourceAliasAnalysis>();
649   for (auto func : getOperation().getOps<FuncOp>())
650     if (!func.isExternal() &&
651         failed(FormClustersInFunction(
652             func, resource_alias_analysis.GetAnalysisForFunc(func))))
653       return signalPassFailure();
654 }
655 }  // anonymous namespace
656 
CreateTPUClusterFormationPass()657 std::unique_ptr<OperationPass<ModuleOp>> CreateTPUClusterFormationPass() {
658   return std::make_unique<TPUClusterFormationPass>();
659 }
660 
661 }  // namespace TFTPU
662 }  // namespace mlir
663