• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.h"
17 
18 #include "llvm/ADT/DenseSet.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/FormatVariadic.h"
23 #include "llvm/Support/raw_ostream.h"
24 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
26 #include "mlir/IR/Operation.h"  // from @llvm-project
27 #include "mlir/IR/Value.h"  // from @llvm-project
28 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
29 
30 #define DEBUG_TYPE "cluster-ops-by-policy"
31 
32 namespace mlir {
33 namespace TFDevice {
34 
35 // -------------------------------------------------------------------------- //
36 // ValueConstraint.
37 // -------------------------------------------------------------------------- //
38 
Merge(ValueConstraint a,ValueConstraint b)39 ValueConstraint Merge(ValueConstraint a, ValueConstraint b) {
40   return a > b ? a : b;
41 }
42 
IsStaticallyResolved(Value value,ValueConstraint constraint)43 LogicalResult IsStaticallyResolved(Value value, ValueConstraint constraint) {
44   // Resolve constraints inferred from the tensor type.
45   if (auto tensor = value.getType().dyn_cast<TensorType>()) {
46     if (constraint == ValueConstraint::kRank && tensor.hasRank())
47       return success();
48     if (constraint == ValueConstraint::kShape && tensor.hasStaticShape())
49       return success();
50   }
51 
52   return failure();
53 }
54 
operator <<(raw_ostream & os,const ValueConstraint & constraint)55 raw_ostream &operator<<(raw_ostream &os, const ValueConstraint &constraint) {
56   auto str = [](ValueConstraint constraint) -> StringRef {
57     switch (constraint) {
58       case ValueConstraint::kRank:
59         return "rank";
60       case ValueConstraint::kShape:
61         return "shape";
62       case ValueConstraint::kValue:
63         return "value";
64       default:
65         llvm_unreachable("unknown value constraint");
66     }
67   };
68 
69   os << str(constraint);
70   return os;
71 }
72 
73 // -------------------------------------------------------------------------- //
74 // ValuesConstraintSet.
75 // -------------------------------------------------------------------------- //
76 
Insert(ValueRange values,ValueConstraint constraint)77 void ValuesConstraintSet::Insert(ValueRange values,
78                                  ValueConstraint constraint) {
79   for (Value value : values) Insert(value, constraint);
80 }
81 
Insert(Value value,ValueConstraint constraint)82 std::pair<ValueConstraint, bool> ValuesConstraintSet::Insert(
83     Value value, ValueConstraint constraint) {
84   auto emplaced = constraints_.try_emplace(value, constraint);
85   ValueConstraint persisted = emplaced.first->getSecond();
86 
87   // We've just inserted a new constraint for the value.
88   if (emplaced.second) return {persisted, true};
89 
90   // Update existing constraint with a new one.
91   auto merged = Merge(constraint, persisted);
92   return {constraints_[value] = merged, merged != persisted};
93 }
94 
Walk(llvm::function_ref<void (Value,ValueConstraint)> walk) const95 void ValuesConstraintSet::Walk(
96     llvm::function_ref<void(Value, ValueConstraint)> walk) const {
97   for (auto &kv : constraints_) walk(kv.getFirst(), kv.getSecond());
98 }
99 
GetConstraint(Value value) const100 Optional<ValueConstraint> ValuesConstraintSet::GetConstraint(
101     Value value) const {
102   auto it = constraints_.find(value);
103   if (it == constraints_.end()) return None;
104   return it->getSecond();
105 }
106 
HasConstraint(Value value) const107 bool ValuesConstraintSet::HasConstraint(Value value) const {
108   return GetConstraint(value).has_value();
109 }
110 
MergeAll(const ValuesConstraintSet & other)111 void ValuesConstraintSet::MergeAll(const ValuesConstraintSet &other) {
112   other.Walk([this](Value value, ValueConstraint constraint) {
113     Insert(value, constraint);
114   });
115 }
116 
Resolve()117 ValuesConstraintSet &ValuesConstraintSet::Resolve() {
118   llvm::SmallDenseSet<Value, 4> resolved;
119   Walk([&](Value value, ValueConstraint constraint) {
120     if (succeeded(IsStaticallyResolved(value, constraint)))
121       resolved.insert(value);
122   });
123   for (Value value : resolved) constraints_.erase(value);
124   return *this;
125 }
126 
Reset()127 ValuesConstraintSet &ValuesConstraintSet::Reset() {
128   constraints_.clear();
129   return *this;
130 }
131 
Size() const132 size_t ValuesConstraintSet::Size() const { return constraints_.size(); }
133 
Empty() const134 bool ValuesConstraintSet::Empty() const { return constraints_.empty(); }
135 
136 // -------------------------------------------------------------------------- //
137 // Discovering clusters of operations based on the policy.
138 // -------------------------------------------------------------------------- //
139 
140 namespace {
141 constexpr char kDeviceAttr[] = "device";
142 
143 // A type that abstracts over types that have uses accessible via `getUses`.
144 using Source = PointerUnion<Operation *, BlockArgument *>;
145 
146 // We use union-find algorithm to build clusters of connected operations based
147 // on the user provided policy. If an operation can be clustered (one of the
148 // user provided policies accepts it under given constraints), it will become
149 // a "member" that will participate in the union-find cluster construction.
150 //
151 // A block argument can also become a member (or even a root member), however
152 // only operations will become a part of the outline `tf_device.cluster`, block
153 // arguments will stay as block arguments, and will later become cluster
154 // function inputs.
155 struct Member {
Membermlir::TFDevice::__anon501e02370411::Member156   Member(unsigned root, Source source, Operation *insertion_point,
157          ValuesConstraintSet constraints = {})
158       : root(root),
159         source(source),
160         insertion_point(insertion_point),
161         constraints(constraints) {}
162 
163   unsigned root;
164   Source source;
165 
166   // After construction:
167   //  For basic block argument source this will be a first operation in the
168   //  basic block, and for operations it will be an operation iself.
169   //
170   // During the union-find cluster formation:
171   //  The root member will have the location in the basic block, where the
172   //  cluster operation will be inserted. We use the location of the last
173   //  operation in the cluster, so that during cluster construction we can
174   //  ensure that all operands are above the insertion point, and all users are
175   //  below the insertion point.
176   //
177   // Example:
178   //
179   //   %0 = "clustered_op"(...)
180   //   %1 = "non_clustered_op"(...)
181   //   %2 = "clustered_op"(%1)              <<<--- insert cluster here
182   //   %3 = "cluster_result_user"(%1, %2)
183   //
184   // By using `%2` location as an insertion point we ensure that all operands
185   // (%1 in this example) dominate the cluster operation, and that the cluster
186   // operation dominates all the users (%3 in this example).
187   Operation *insertion_point;
188 
189   // After construction:
190   //   A set of constraints on the clustered operation operands that must be
191   //   satisfied in order to add operation to the cluster. For basic block
192   //   source this will be always empty.
193   //
194   // During the union-find cluster formation:
195   //   The root member will have constraints merged from all of the cluster
196   //   members.
197   ValuesConstraintSet constraints;
198 };
199 
200 using Members = llvm::SmallVector<Member>;
201 
202 struct ClusteringState {
203   // Storage backing an array based union-find algorithm for clustering. Index
204   // in this vector is the member id.
205   llvm::SmallVector<Member> members;
206 
207   // Mapping from the member operation (block argument) to the member id.
208   llvm::SmallDenseMap<Source, unsigned> member_ids;
209 
210   // Puts `a` and `b` members into the same cluster if it is possible. Returns
211   // success if union operation was completed successfully, otherwise returns
212   // failure.
213   //
214   // Members can be clustered together:
215   //   1. This will not break dominance property of the IR.
216   //   2. New clustering policy constraints can be propagated through the
217   //      already clustered operations.
218   LogicalResult Union(unsigned a, unsigned b,
219                       const ClusteringPolicySet &policies);
220 
221   bool IsMember(Operation *op) const;
222   unsigned FindRoot(unsigned id);
223 
224   // Verifies that merging `src_root` cluster with a `dst_root` cluster, and
225   // inserting it at `insertion_point` location will not break the dominance
226   // property: all users of the `src_root` cluster results must be below the
227   // insertion point in the block.
228   LogicalResult VerifyDominanceProperty(unsigned src_root, unsigned dst_root,
229                                         Operation *insertion_point);
230 
231   // Verifies that all constraints on the values defined by the `dst_root`
232   // cluster can be propagated through the nodes in the `src_root` cluster, and
233   // updates `src_root` constraints on success.
234   LogicalResult VerifyValueConstraints(unsigned src_root, unsigned dst_root,
235                                        const ClusteringPolicySet &policies);
236 };
237 
238 }  // namespace
239 
IsMember(Operation * op) const240 bool ClusteringState::IsMember(Operation *op) const {
241   return member_ids.find(op) != member_ids.end();
242 }
243 
FindRoot(unsigned id)244 unsigned ClusteringState::FindRoot(unsigned id) {
245   if (members[id].root == id) return id;
246   return members[id].root = FindRoot(members[id].root);
247 }
248 
VerifyDominanceProperty(unsigned src_root,unsigned dst_root,Operation * insertion_point)249 LogicalResult ClusteringState::VerifyDominanceProperty(
250     unsigned src_root, unsigned dst_root, Operation *insertion_point) {
251   // TODO(ezhulenev): Optimize this linear scan with a map lookup.
252   for (auto &member : members) {
253     unsigned root = FindRoot(member.root);
254     if (root != src_root) continue;
255 
256     // Block arguments do not really participate in clustering, they are only
257     // used to connect independent operation using the same argument.
258     if (member.source.is<BlockArgument *>()) continue;
259 
260     Operation *op = member.source.dyn_cast<Operation *>();
261     assert(op && "member operation must be not null");
262 
263     for (Operation *user : op->getUsers()) {
264       // Skip users in other blocks.
265       if (user->getBlock() != op->getBlock()) continue;
266 
267       // Skip users is in the `dst_root` or `src_root` clusters, if we'll merge
268       // roots they'll become a single cluster and will not violate the
269       // dominance property after that.
270       auto it = member_ids.find(user);
271       if (it != member_ids.end() && (FindRoot(it->getSecond()) == dst_root ||
272                                      FindRoot(it->getSecond()) == src_root))
273         continue;
274 
275       if (user->isBeforeInBlock(insertion_point)) {
276         LLVM_DEBUG(llvm::dbgs()
277                        << "  Failure: user is before the insertion point: "
278                        << *user << "\n";);
279         return failure();
280       }
281     }
282   }
283 
284   return success();
285 }
286 
VerifyValueConstraints(unsigned src_root,unsigned dst_root,const ClusteringPolicySet & policies)287 LogicalResult ClusteringState::VerifyValueConstraints(
288     unsigned src_root, unsigned dst_root, const ClusteringPolicySet &policies) {
289   // Propagate constraints only through operations in the `src_root` cluster.
290   auto filter = [&](Operation *op) -> bool {
291     auto it = member_ids.find(op);
292     return it != member_ids.end() && FindRoot(it->getSecond()) == src_root;
293   };
294 
295   // Start from all operations in the `src_root` cluster.
296   llvm::SmallVector<Operation *> worklist;
297   for (Member &member : members)
298     if (Operation *op = member.source.dyn_cast<Operation *>())
299       if (FindRoot(member.root) == src_root) worklist.emplace_back(op);
300 
301   // Collect `dst_root` constraints that are applicable to the values defined in
302   // the `src_root` cluster.
303   ValuesConstraintSet constraints = members[src_root].constraints;
304   members[dst_root].constraints.Walk([&](Value v, ValueConstraint constraint) {
305     Operation *op = v.getDefiningOp();
306     if (op && filter(op)) constraints.Insert(v, constraint);
307   });
308 
309   // Update `src_root` constraints only if we can propagate them.
310   if (succeeded(PropagateValuesConstraints(worklist, filter, policies,
311                                            constraints, /*resolve=*/true))) {
312     members[src_root].constraints = constraints;
313     return success();
314   }
315 
316   return failure();
317 }
318 
Union(unsigned a,unsigned b,const ClusteringPolicySet & policies)319 LogicalResult ClusteringState::Union(unsigned a, unsigned b,
320                                      const ClusteringPolicySet &policies) {
321   unsigned a_root = FindRoot(a);
322   unsigned b_root = FindRoot(b);
323 
324   // Already members of the same cluster.
325   if (a_root == b_root) return failure();
326 
327   // Verify that merging two clusters will not break dominance property.
328   Operation *a_insertion_point = members[a_root].insertion_point;
329   Operation *b_insertion_point = members[b_root].insertion_point;
330   bool a_is_before_b = a_insertion_point->isBeforeInBlock(b_insertion_point);
331 
332   // Use clusters position in the block to select merging src and dst.
333   unsigned src_root = a_is_before_b ? a_root : b_root;  // merge `src_root` ...
334   unsigned dst_root = a_is_before_b ? b_root : a_root;  // ... into `dst_root`
335   Operation *insertion_point =
336       a_is_before_b ? b_insertion_point : a_insertion_point;
337 
338   // Print operations in the `root` cluster to debug stream.
339   auto debug_clustered_ops = [&](unsigned root) {
340     for (Member &member : members)
341       if (FindRoot(member.root) == root) {
342         if (auto *op = member.source.dyn_cast<Operation *>()) {
343           llvm::dbgs() << "  " << *op << "\n";
344         } else if (auto *arg = member.source.dyn_cast<BlockArgument *>()) {
345           llvm::dbgs() << "  " << *arg;
346         }
347       }
348   };
349   (void)debug_clustered_ops;
350 
351   LLVM_DEBUG({
352     llvm::dbgs() << "\n\n--- Try to merge cluster:\n";
353     debug_clustered_ops(src_root);
354     llvm::dbgs() << "\n--- With cluster:\n";
355     debug_clustered_ops(dst_root);
356     LLVM_DEBUG(llvm::dbgs() << "\n--- Diagnostics:\n");
357   });
358 
359   // Check if merging `src_root` with `dst_root` will not violate SSA dominance
360   // property (all operands before the cluster, all results after the cluster).
361   if (failed(VerifyDominanceProperty(src_root, dst_root, insertion_point)))
362     return failure();
363 
364   // Check if `dst_root` constraints can be propagated to the `src_root`
365   // constraints.
366   if (failed(VerifyValueConstraints(src_root, dst_root, policies)))
367     return failure();
368 
369   // Set `dst_root` as a new root for `src_root`.
370   members[src_root].root = dst_root;
371   // Update insertion point of the new root.
372   members[dst_root].insertion_point = insertion_point;
373   // Merge all constraints from `src_root` into `dst_root`.
374   members[dst_root].constraints.MergeAll(members[src_root].constraints);
375 
376   LLVM_DEBUG(llvm::dbgs() << "  Clusters successfully merged\n");
377 
378   return success();
379 }
380 
381 // Returns constraints on the operands specified by the clustering policy if the
382 // operation can be clustered (constraints could be empty). Otherwise return
383 // empty optional.
CanBeClustered(Operation * op,const ClusteringPolicySet & policies,const std::function<bool (Operation * op)> & filter)384 static Optional<ValuesConstraintSet> CanBeClustered(
385     Operation *op, const ClusteringPolicySet &policies,
386     const std::function<bool(Operation *op)> &filter) {
387   // Check that op has no side effects. This guarantees that we will not
388   // reorder side-effecting ops during cluster formation.
389   if (!MemoryEffectOpInterface::hasNoEffect(op)) return llvm::None;
390 
391   // Operation rejected by the custom filter.
392   if (filter && !filter(op)) return llvm::None;
393 
394   // Initially we do not have any constraints on the operation results.
395   ValuesConstraintSet result_constraints;
396 
397   for (auto &policy : policies.policies()) {
398     ValuesConstraintSet operands_constraints;
399     if (succeeded(policy->MatchAndUpdateConstraints(op, result_constraints,
400                                                     operands_constraints)))
401       return operands_constraints.Resolve();
402   }
403 
404   return llvm::None;
405 }
406 
407 // Compute initial clustering state based on the clustering polocy.
InitializeClusteringState(Block * block,const ClusteringPolicySet & policies,const std::function<bool (Operation * op)> & filter)408 static ClusteringState InitializeClusteringState(
409     Block *block, const ClusteringPolicySet &policies,
410     const std::function<bool(Operation *op)> &filter) {
411   ClusteringState state;
412 
413   // Create members for all block arguments.
414   for (BlockArgument &arg : block->getArguments()) {
415     if (!arg.getUsers().empty())
416       state.members.emplace_back(state.members.size(), &arg, &block->front());
417   }
418 
419   int num_bbarg_members = state.members.size();
420   (void)num_bbarg_members;
421 
422   // Create members for operations that can be clustered based on the policy.
423   for (Operation &op : block->getOperations()) {
424     if (auto constraints = CanBeClustered(&op, policies, filter))
425       state.members.emplace_back(state.members.size(), &op, &op, *constraints);
426   }
427 
428   // Initialize mapping from the member operation (block argument) to the id.
429   for (auto &tuple : llvm::enumerate(state.members)) {
430     state.member_ids.try_emplace(tuple.value().source, tuple.index());
431   }
432 
433   LLVM_DEBUG(llvm::dbgs() << "Found "
434                           << (state.members.size() - num_bbarg_members)
435                           << " clustering candidate operations in the block\n");
436 
437   return state;
438 }
439 
440 // Users of the `source` that are candidates for clustering.
GetClusteringCandidates(const ClusteringState & state,Source source)441 static llvm::SmallVector<Operation *> GetClusteringCandidates(
442     const ClusteringState &state, Source source) {
443   // Users of operation result must be in the same block and placed on the same
444   // device.
445   if (auto op = source.dyn_cast<Operation *>()) {
446     auto range = llvm::make_filter_range(op->getUsers(), [&](Operation *user) {
447       bool same_block = user->getBlock() == op->getBlock();
448       bool same_device = op->getAttr(kDeviceAttr) == user->getAttr(kDeviceAttr);
449       return same_block && same_device && state.IsMember(user);
450     });
451     return {range.begin(), range.end()};
452   }
453 
454   // Users of block argument must be in the same block.
455   if (auto arg = source.dyn_cast<BlockArgument *>()) {
456     auto range = llvm::make_filter_range(arg->getUsers(), [&](Operation *user) {
457       bool same_block = user->getBlock() == arg->getOwner();
458       return same_block && state.IsMember(user);
459     });
460     return {range.begin(), range.end()};
461   }
462 
463   llvm_unreachable("Unexpected type in the union.");
464 }
465 
466 // Cluster members with their result users. Returns `true` if merged at least a
467 // pair of members into a new cluster.
RunClusteringPass(ClusteringState & state,const ClusteringPolicySet & policies)468 static bool RunClusteringPass(ClusteringState &state,
469                               const ClusteringPolicySet &policies) {
470   bool clustered = false;
471 
472   for (auto &tuple : llvm::enumerate(state.members)) {
473     size_t member_id = tuple.index();
474     Member &member = tuple.value();
475 
476     llvm::SmallVector<Operation *> users =
477         GetClusteringCandidates(state, member.source);
478 
479     // Process candidates according to their order in the block to minimize
480     // the number of dominance property violations.
481     llvm::sort(users, [](auto *a, auto *b) { return a->isBeforeInBlock(b); });
482 
483     for (Operation *user : users) {
484       auto user_member_id = state.member_ids.lookup(user);
485       if (succeeded(state.Union(member_id, user_member_id, policies)))
486         clustered = true;
487     }
488   }
489 
490   return clustered;
491 }
492 
FindClustersInTheBlock(Block * block,const ClusteringPolicySet & policies,std::function<bool (Operation * op)> filter)493 llvm::SmallVector<Cluster> FindClustersInTheBlock(
494     Block *block, const ClusteringPolicySet &policies,
495     std::function<bool(Operation *op)> filter) {
496   // It is impossible to build a cluster in the empty block.
497   if (block->empty()) return {};
498 
499   ClusteringState state = InitializeClusteringState(block, policies, filter);
500 
501   // Run clustering passes until the convergence. Limit the number of iterations
502   // to guard from the infinite loop in presence of bugs.
503   constexpr int max_iterations = 100;
504   for (unsigned i = 0; i < max_iterations; ++i)
505     if (!RunClusteringPass(state, policies)) break;
506 
507   // Form clusters found by the union-find algorithm.
508   llvm::DenseMap<unsigned, Cluster> root_clusters;
509 
510   for (Member &member : state.members) {
511     unsigned root = state.FindRoot(member.root);
512     Cluster &cluster = root_clusters.FindAndConstruct(root).getSecond();
513 
514     // If member is a root of the cluster, copy inferred constraints.
515     if (state.FindRoot(member.root) == member.root)
516       cluster.constraints = std::move(member.constraints);
517 
518     // Add operation to the cluster.
519     if (auto op = member.source.dyn_cast<Operation *>())
520       cluster.operations.emplace_back(op);
521   }
522 
523   llvm::SmallVector<Cluster> clusters;
524   for (auto &kv : root_clusters) {
525     Cluster &cluster = kv.getSecond();
526     // Skip degenerate clusters formed by a single basic block argument.
527     if (!cluster.operations.empty()) clusters.emplace_back(std::move(cluster));
528   }
529 
530   LLVM_DEBUG(llvm::dbgs() << "Found " << clusters.size() << " clusters\n");
531 
532   return clusters;
533 }
534 
535 // -------------------------------------------------------------------------- //
536 // Create `tf_device.cluster` operation from the discovered ops cluster.
537 // -------------------------------------------------------------------------- //
538 
CreateClusterOp(Cluster & cluster,StringAttr policy)539 tf_device::ClusterOp CreateClusterOp(Cluster &cluster, StringAttr policy) {
540   // Find all the values that are used outside of the cluster. These values
541   // will be returned from the created cluster operation.
542   llvm::DenseSet<Operation *> in_cluster;
543   for (Operation *op : cluster.operations) in_cluster.insert(op);
544 
545   llvm::SetVector<Value> return_values;
546   llvm::SmallVector<Type> return_types;
547 
548   for (Operation *op : cluster.operations)
549     for (OpOperand &use : op->getUses()) {
550       // User is inside the cluster.
551       if (in_cluster.contains(use.getOwner())) continue;
552       // Do not return the same value multiple times.
553       if (return_values.contains(use.get())) continue;
554 
555       return_values.insert(use.get());
556       return_types.emplace_back(use.get().getType());
557     }
558 
559   // Sort matched operations by their position in the block.
560   llvm::sort(cluster.operations, [](Operation *a, Operation *b) -> bool {
561     return a->isBeforeInBlock(b);
562   });
563 
564   // Create tf_device::ClusterOp before the last operation in the block that
565   // is a part of a match set.
566   auto back = cluster.operations.back();
567   auto loc = back->getLoc();
568   OpBuilder builder(back);
569 
570   auto cluster_op =
571       builder.create<tf_device::ClusterOp>(loc, return_types, policy);
572 
573   // Create block in cluster_op's region and move 'cluster.operations' into
574   // it.
575   auto block = builder.createBlock(&cluster_op.body());
576   auto block_end = block->end();
577   for (auto op : cluster.operations) op->moveBefore(block, block_end);
578 
579   // Add 'tf_device::ReturnOp' at the end of the block.
580   builder.setInsertionPointToEnd(block);
581   builder.create<tf_device::ReturnOp>(loc, return_values.getArrayRef());
582 
583   // Set device attribute
584   if (auto device = back->getAttr(kDeviceAttr))
585     cluster_op->setAttr(kDeviceAttr, device);
586 
587   // Update all users of the operations moved into the cluster region.
588   for (auto tuple : llvm::zip(return_values, cluster_op.getResults())) {
589     Value old_value = std::get<0>(tuple);
590     Value new_value = std::get<1>(tuple);
591     old_value.replaceUsesWithIf(new_value, [&](OpOperand &operand) -> bool {
592       // Do not update users in the same cluster.
593       return operand.getOwner()->getBlock() != block;
594     });
595   }
596 
597   return cluster_op;
598 }
599 
600 // -------------------------------------------------------------------------- //
601 // Helper functions for value constraints propagations and analysis.
602 // -------------------------------------------------------------------------- //
603 
PropagateValuesConstraints(llvm::ArrayRef<Operation * > root,std::function<bool (Operation *)> filter,const ClusteringPolicySet & policies,ValuesConstraintSet & constraints,bool resolve,bool emit_remarks)604 mlir::LogicalResult PropagateValuesConstraints(
605     llvm::ArrayRef<Operation *> root, std::function<bool(Operation *)> filter,
606     const ClusteringPolicySet &policies, ValuesConstraintSet &constraints,
607     bool resolve, bool emit_remarks) {
608   // A set of constraints for operation results.
609   llvm::DenseMap<Operation *, ValuesConstraintSet> op_results_constraints;
610   assert(filter && "filter predicate must be defined");
611 
612   // Use initial constraints to initialize op results constraints.
613   for (std::pair<Value, ValueConstraint> pair : constraints) {
614     Value value = pair.first;
615     ValueConstraint constraint = pair.second;
616 
617     // Value must be defined by an operation and accepted by the filter.
618     Operation *op = value.getDefiningOp();
619     if (!op || !filter(op)) continue;
620 
621     op_results_constraints[op].Insert(value, constraint);
622   }
623 
624   // Keep a worklist of operations that need their constraints to be updated.
625   llvm::SetVector<Operation *> worklist;
626   for (Operation *op : root) worklist.insert(op);
627 
628   while (!worklist.empty()) {
629     Operation *op = worklist.pop_back_val();
630 
631     // Use results constraints to infer operands constraints.
632     const ValuesConstraintSet &results = op_results_constraints[op];
633     ValuesConstraintSet operands;
634 
635     // Walk through all policies until we find one that matches the operation.
636     bool updated = false;
637     for (auto &policy : policies.policies()) {
638       auto matched =
639           policy->MatchAndUpdateConstraints(op, results, operands.Reset());
640       if (succeeded(matched)) {
641         updated = true;
642         break;
643       }
644     }
645 
646     // Signal a failure if could not propagate non-empty constraints on the
647     // operation results to the operands.
648     if (!updated && !results.Empty()) {
649       if (emit_remarks) {
650         std::string err_msg;
651         llvm::raw_string_ostream os(err_msg);
652         for (unsigned i = 0; i < op->getNumResults(); ++i)
653           os << " " << i << ":" << results.GetConstraint(op->getResult(i));
654         op->emitError(llvm::formatv(
655             "failed to propagate results constraints:{0}", os.str()));
656       }
657       return failure();
658     }
659 
660     // Update results constraints based on inferred operands constraints.
661     operands.Walk([&](Value value, ValueConstraint constraint) {
662       // Resolve constraint based on the static type information.
663       if (resolve && succeeded(IsStaticallyResolved(value, constraint))) return;
664 
665       // Update constraint for a value.
666       auto updated = constraints.Insert(value, constraint);
667       if (!updated.second) return;
668 
669       // Maybe update constaint on the operation result, but do not follow
670       // operations that are not accepted by the filter predicate.
671       Operation *op = value.getDefiningOp();
672       if (!op || !filter(op)) return;
673 
674       // Add updated operation to the worklist.
675       auto inserted = op_results_constraints[op].Insert(value, updated.first);
676       if (inserted.second) worklist.insert(op);
677     });
678   }
679 
680   return success();
681 }
682 
PropagateValuesConstraints(mlir::Region & region,const ClusteringPolicySet & policies,ValuesConstraintSet & constraints,bool resolve,bool emit_remarks)683 mlir::LogicalResult PropagateValuesConstraints(
684     mlir::Region &region, const ClusteringPolicySet &policies,
685     ValuesConstraintSet &constraints, bool resolve, bool emit_remarks) {
686   // Propagate constraints for all operations in the region.
687   llvm::SmallVector<Operation *> worklist;
688   region.walk([&](Operation *op) { worklist.emplace_back(op); });
689 
690   // Propagate constraints only through operations inside the `region`.
691   auto filter = [&](Operation *op) -> bool {
692     return region.findAncestorBlockInRegion(*op->getBlock());
693   };
694 
695   return PropagateValuesConstraints(worklist, filter, policies, constraints,
696                                     resolve, emit_remarks);
697 }
698 
EmitValueConstraintsRemarks(const ValuesConstraintSet & constraints)699 void EmitValueConstraintsRemarks(const ValuesConstraintSet &constraints) {
700   constraints.Walk([](Value value, ValueConstraint constraint) {
701     for (OpOperand &operand : value.getUses())
702       operand.getOwner()->emitRemark(
703           llvm::formatv("operand #{0} constrained to: {1}",
704                         operand.getOperandNumber(), constraint));
705   });
706 }
707 
EmitInputsConstraintsRemarks(func::FuncOp func,const ValuesConstraintSet & constraints)708 void EmitInputsConstraintsRemarks(func::FuncOp func,
709                                   const ValuesConstraintSet &constraints) {
710   constraints.Walk([&](Value value, ValueConstraint constraint) {
711     if (auto arg = value.dyn_cast<BlockArgument>())
712       if (arg.getOwner() == &func.getBody().front())
713         func.emitRemark(llvm::formatv("input #{0} constrained to: {1}",
714                                       arg.getArgNumber(), constraint));
715   });
716 }
717 
InferFunctionBodyValuesConstraints(func::FuncOp func,ValuesConstraintSet & constraints)718 LogicalResult InferFunctionBodyValuesConstraints(
719     func::FuncOp func, ValuesConstraintSet &constraints) {
720   for (unsigned i = 0; i < func.getNumResults(); ++i) {
721     auto str = func.getResultAttrOfType<StringAttr>(i, "tf.constraint");
722     if (!str) continue;
723 
724     ValueConstraint constraint = StringSwitch<ValueConstraint>(str.getValue())
725                                      .Case("rank", ValueConstraint::kRank)
726                                      .Case("shape", ValueConstraint::kShape)
727                                      .Case("value", ValueConstraint::kValue);
728 
729     // Propagate constraints through function return operations.
730     for (Block &block : func.getBody()) {
731       func::ReturnOp ret = dyn_cast<func::ReturnOp>(block.back());
732       if (ret) constraints.Insert(ret.getOperand(i), constraint);
733     }
734   }
735 
736   return success();
737 }
738 
739 }  // namespace TFDevice
740 }  // namespace mlir
741