• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.h"
17 
18 #include "llvm/ADT/DenseSet.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/FormatVariadic.h"
23 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
24 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
25 #include "mlir/IR/Operation.h"  // from @llvm-project
26 #include "mlir/IR/Value.h"  // from @llvm-project
27 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
28 
29 #define DEBUG_TYPE "cluster-ops-by-policy"
30 
31 namespace mlir {
32 namespace TFDevice {
33 
34 // -------------------------------------------------------------------------- //
35 // ValueConstraint.
36 // -------------------------------------------------------------------------- //
37 
Merge(ValueConstraint a,ValueConstraint b)38 ValueConstraint Merge(ValueConstraint a, ValueConstraint b) {
39   return a > b ? a : b;
40 }
41 
IsStaticallyResolved(Value value,ValueConstraint constraint)42 LogicalResult IsStaticallyResolved(Value value, ValueConstraint constraint) {
43   // Resolve constraints inferred from the tensor type.
44   if (auto tensor = value.getType().dyn_cast<TensorType>()) {
45     if (constraint == ValueConstraint::kRank && tensor.hasRank())
46       return success();
47     if (constraint == ValueConstraint::kShape && tensor.hasStaticShape())
48       return success();
49   }
50 
51   return failure();
52 }
53 
operator <<(raw_ostream & os,const ValueConstraint & constraint)54 raw_ostream &operator<<(raw_ostream &os, const ValueConstraint &constraint) {
55   auto str = [](ValueConstraint constraint) -> StringRef {
56     switch (constraint) {
57       case ValueConstraint::kRank:
58         return "rank";
59       case ValueConstraint::kShape:
60         return "shape";
61       case ValueConstraint::kValue:
62         return "value";
63       default:
64         llvm_unreachable("unknown value constraint");
65     }
66   };
67 
68   os << str(constraint);
69   return os;
70 }
71 
72 // -------------------------------------------------------------------------- //
73 // ValuesConstraintSet.
74 // -------------------------------------------------------------------------- //
75 
Insert(ValueRange values,ValueConstraint constraint)76 void ValuesConstraintSet::Insert(ValueRange values,
77                                  ValueConstraint constraint) {
78   for (Value value : values) Insert(value, constraint);
79 }
80 
Insert(Value value,ValueConstraint constraint)81 std::pair<ValueConstraint, bool> ValuesConstraintSet::Insert(
82     Value value, ValueConstraint constraint) {
83   auto emplaced = constraints_.try_emplace(value, constraint);
84   ValueConstraint persisted = emplaced.first->getSecond();
85 
86   // We've just inserted a new constraint for the value.
87   if (emplaced.second) return {persisted, true};
88 
89   // Update existing constraint with a new one.
90   auto merged = Merge(constraint, persisted);
91   return {constraints_[value] = merged, merged != persisted};
92 }
93 
Walk(llvm::function_ref<void (Value,ValueConstraint)> walk) const94 void ValuesConstraintSet::Walk(
95     llvm::function_ref<void(Value, ValueConstraint)> walk) const {
96   for (auto &kv : constraints_) walk(kv.getFirst(), kv.getSecond());
97 }
98 
GetConstraint(Value value) const99 Optional<ValueConstraint> ValuesConstraintSet::GetConstraint(
100     Value value) const {
101   auto it = constraints_.find(value);
102   if (it == constraints_.end()) return None;
103   return it->getSecond();
104 }
105 
HasConstraint(Value value) const106 bool ValuesConstraintSet::HasConstraint(Value value) const {
107   return GetConstraint(value).hasValue();
108 }
109 
MergeAll(const ValuesConstraintSet & other)110 void ValuesConstraintSet::MergeAll(const ValuesConstraintSet &other) {
111   other.Walk([this](Value value, ValueConstraint constraint) {
112     Insert(value, constraint);
113   });
114 }
115 
Resolve()116 ValuesConstraintSet &ValuesConstraintSet::Resolve() {
117   llvm::SmallDenseSet<Value, 4> resolved;
118   Walk([&](Value value, ValueConstraint constraint) {
119     if (succeeded(IsStaticallyResolved(value, constraint)))
120       resolved.insert(value);
121   });
122   for (Value value : resolved) constraints_.erase(value);
123   return *this;
124 }
125 
Reset()126 ValuesConstraintSet &ValuesConstraintSet::Reset() {
127   constraints_.clear();
128   return *this;
129 }
130 
Size() const131 size_t ValuesConstraintSet::Size() const { return constraints_.size(); }
132 
Empty() const133 bool ValuesConstraintSet::Empty() const { return constraints_.empty(); }
134 
135 // -------------------------------------------------------------------------- //
136 // Discovering clusters of operations based on the policy.
137 // -------------------------------------------------------------------------- //
138 
139 namespace {
140 constexpr char kDeviceAttr[] = "device";
141 
142 // A type that abstracts over types that have uses accessible via `getUses`.
143 using Source = PointerUnion<Operation *, BlockArgument *>;
144 
145 // We use union-find algorithm to build clusters of connected operations based
146 // on the user provided policy. If an operation can be clustered (one of the
147 // user provided policies accepts it under given constraints), it will become
148 // a "member" that will participate in the union-find cluster construction.
149 //
150 // A block argument can also become a member (or even a root member), however
151 // only operations will become a part of the outline `tf_device.cluster`, block
152 // arguments will stay as block arguments, and will later become cluster
153 // function inputs.
154 struct Member {
Membermlir::TFDevice::__anonaac0da7e0411::Member155   Member(unsigned root, Source source, Operation *insertion_point,
156          ValuesConstraintSet constraints = {})
157       : root(root),
158         source(source),
159         insertion_point(insertion_point),
160         constraints(constraints) {}
161 
162   unsigned root;
163   Source source;
164 
165   // After construction:
166   //  For basic block argument source this will be a first operation in the
167   //  basic block, and for operations it will be an operation iself.
168   //
169   // During the union-find cluster formation:
170   //  The root member will have the location in the basic block, where the
171   //  cluster operation will be inserted. We use the location of the last
172   //  operation in the cluster, so that during cluster construction we can
173   //  ensure that all operands are above the insertion point, and all users are
174   //  below the insertion point.
175   //
176   // Example:
177   //
178   //   %0 = "clustered_op"(...)
179   //   %1 = "non_clustered_op"(...)
180   //   %2 = "clustered_op"(%1)              <<<--- insert cluster here
181   //   %3 = "cluster_result_user"(%1, %2)
182   //
183   // By using `%2` location as an insertion point we ensure that all operands
184   // (%1 in this example) dominate the cluster operation, and that the cluster
185   // operation dominates all the users (%3 in this example).
186   Operation *insertion_point;
187 
188   // After construction:
189   //   A set of constraints on the clustered operation operands that must be
190   //   satisfied in order to add operation to the cluster. For basic block
191   //   source this will be always empty.
192   //
193   // During the union-find cluster formation:
194   //   The root member will have constraints merged from all of the cluster
195   //   members.
196   ValuesConstraintSet constraints;
197 };
198 
199 using Members = llvm::SmallVector<Member>;
200 
201 struct ClusteringState {
202   // Storage backing an array based union-find algorithm for clustering. Index
203   // in this vector is the member id.
204   llvm::SmallVector<Member> members;
205 
206   // Mapping from the member operation (block argument) to the member id.
207   llvm::SmallDenseMap<Source, unsigned> member_ids;
208 
209   // Puts `a` and `b` members into the same cluster if it is possible. Returns
210   // success if union operation was completed successfully, otherwise returns
211   // failure.
212   //
213   // Members can be clustered together:
214   //   1. This will not break dominance property of the IR.
215   //   2. New clustering policy constraints can be propagated through the
216   //      already clustered operations.
217   LogicalResult Union(unsigned a, unsigned b,
218                       const ClusteringPolicySet &policies);
219 
220   bool IsMember(Operation *op) const;
221   unsigned FindRoot(unsigned id);
222 
223   // Verifies that merging `src_root` cluster with a `dst_root` cluster, and
224   // inserting it at `insertion_point` location will not break the dominance
225   // property: all users of the `src_root` cluster results must be below the
226   // insertion point in the block.
227   LogicalResult VerifyDominanceProperty(unsigned src_root, unsigned dst_root,
228                                         Operation *insertion_point);
229 
230   // Verifies that all constraints on the values defined by the `dst_root`
231   // cluster can be propagated through the nodes in the `src_root` cluster, and
232   // updates `src_root` constraints on success.
233   LogicalResult VerifyValueConstraints(unsigned src_root, unsigned dst_root,
234                                        const ClusteringPolicySet &policies);
235 };
236 
237 }  // namespace
238 
IsMember(Operation * op) const239 bool ClusteringState::IsMember(Operation *op) const {
240   return member_ids.find(op) != member_ids.end();
241 }
242 
FindRoot(unsigned id)243 unsigned ClusteringState::FindRoot(unsigned id) {
244   if (members[id].root == id) return id;
245   return members[id].root = FindRoot(members[id].root);
246 }
247 
VerifyDominanceProperty(unsigned src_root,unsigned dst_root,Operation * insertion_point)248 LogicalResult ClusteringState::VerifyDominanceProperty(
249     unsigned src_root, unsigned dst_root, Operation *insertion_point) {
250   // TODO(ezhulenev): Optimize this linear scan with a map lookup.
251   for (auto &member : members) {
252     unsigned root = FindRoot(member.root);
253     if (root != src_root) continue;
254 
255     // Block arguments do not really participate in clustering, they are only
256     // used to connect independent operation using the same argument.
257     if (member.source.is<BlockArgument *>()) continue;
258 
259     Operation *op = member.source.dyn_cast<Operation *>();
260     assert(op && "member operation must be not null");
261 
262     for (Operation *user : op->getUsers()) {
263       // Skip users in other blocks.
264       if (user->getBlock() != op->getBlock()) continue;
265 
266       // Skip users is in the `dst_root` or `src_root` clusters, if we'll merge
267       // roots they'll become a single cluster and will not violate the
268       // dominance property after that.
269       auto it = member_ids.find(user);
270       if (it != member_ids.end() && (FindRoot(it->getSecond()) == dst_root ||
271                                      FindRoot(it->getSecond()) == src_root))
272         continue;
273 
274       if (user->isBeforeInBlock(insertion_point)) {
275         LLVM_DEBUG(llvm::dbgs()
276                        << "  Failure: user is before the insertion point: "
277                        << *user << "\n";);
278         return failure();
279       }
280     }
281   }
282 
283   return success();
284 }
285 
VerifyValueConstraints(unsigned src_root,unsigned dst_root,const ClusteringPolicySet & policies)286 LogicalResult ClusteringState::VerifyValueConstraints(
287     unsigned src_root, unsigned dst_root, const ClusteringPolicySet &policies) {
288   // Propagate constraints only through operations in the `src_root` cluster.
289   auto filter = [&](Operation *op) -> bool {
290     auto it = member_ids.find(op);
291     return it != member_ids.end() && FindRoot(it->getSecond()) == src_root;
292   };
293 
294   // Start from all operations in the `src_root` cluster.
295   llvm::SmallVector<Operation *> worklist;
296   for (Member &member : members)
297     if (Operation *op = member.source.dyn_cast<Operation *>())
298       if (FindRoot(member.root) == src_root) worklist.emplace_back(op);
299 
300   // Collect `dst_root` constraints that are applicable to the values defined in
301   // the `src_root` cluster.
302   ValuesConstraintSet constraints = members[src_root].constraints;
303   members[dst_root].constraints.Walk([&](Value v, ValueConstraint constraint) {
304     Operation *op = v.getDefiningOp();
305     if (op && filter(op)) constraints.Insert(v, constraint);
306   });
307 
308   // Update `src_root` constraints only if we can propagate them.
309   if (succeeded(PropagateValuesConstraints(worklist, filter, policies,
310                                            constraints, /*resolve=*/true))) {
311     members[src_root].constraints = constraints;
312     return success();
313   }
314 
315   return failure();
316 }
317 
Union(unsigned a,unsigned b,const ClusteringPolicySet & policies)318 LogicalResult ClusteringState::Union(unsigned a, unsigned b,
319                                      const ClusteringPolicySet &policies) {
320   unsigned a_root = FindRoot(a);
321   unsigned b_root = FindRoot(b);
322 
323   // Already members of the same cluster.
324   if (a_root == b_root) return failure();
325 
326   // Verify that merging two clusters will not break dominance property.
327   Operation *a_insertion_point = members[a_root].insertion_point;
328   Operation *b_insertion_point = members[b_root].insertion_point;
329   bool a_is_before_b = a_insertion_point->isBeforeInBlock(b_insertion_point);
330 
331   // Use clusters position in the block to select merging src and dst.
332   unsigned src_root = a_is_before_b ? a_root : b_root;  // merge `src_root` ...
333   unsigned dst_root = a_is_before_b ? b_root : a_root;  // ... into `dst_root`
334   Operation *insertion_point =
335       a_is_before_b ? b_insertion_point : a_insertion_point;
336 
337   // Print operations in the `root` cluster to debug stream.
338   auto debug_clustered_ops = [&](unsigned root) {
339     for (Member &member : members)
340       if (FindRoot(member.root) == root) {
341         if (auto *op = member.source.dyn_cast<Operation *>()) {
342           llvm::dbgs() << "  " << *op << "\n";
343         } else if (auto *arg = member.source.dyn_cast<BlockArgument *>()) {
344           llvm::dbgs() << "  " << *arg;
345         }
346       }
347   };
348   (void)debug_clustered_ops;
349 
350   LLVM_DEBUG({
351     llvm::dbgs() << "\n\n--- Try to merge cluster:\n";
352     debug_clustered_ops(src_root);
353     llvm::dbgs() << "\n--- With cluster:\n";
354     debug_clustered_ops(dst_root);
355     LLVM_DEBUG(llvm::dbgs() << "\n--- Diagnostics:\n");
356   });
357 
358   // Check if merging `src_root` with `dst_root` will not violate SSA dominance
359   // property (all operands before the cluster, all results after the cluster).
360   if (failed(VerifyDominanceProperty(src_root, dst_root, insertion_point)))
361     return failure();
362 
363   // Check if `dst_root` constraints can be propagated to the `src_root`
364   // constraints.
365   if (failed(VerifyValueConstraints(src_root, dst_root, policies)))
366     return failure();
367 
368   // Set `dst_root` as a new root for `src_root`.
369   members[src_root].root = dst_root;
370   // Update insertion point of the new root.
371   members[dst_root].insertion_point = insertion_point;
372   // Merge all constraints from `src_root` into `dst_root`.
373   members[dst_root].constraints.MergeAll(members[src_root].constraints);
374 
375   LLVM_DEBUG(llvm::dbgs() << "  Clusters successfully merged\n");
376 
377   return success();
378 }
379 
380 // Returns constraints on the operands specified by the clustering policy if the
381 // operation can be clustered (constraints could be empty). Otherwise return
382 // empty optional.
CanBeClustered(Operation * op,const ClusteringPolicySet & policies,const std::function<bool (Operation * op)> & filter)383 static Optional<ValuesConstraintSet> CanBeClustered(
384     Operation *op, const ClusteringPolicySet &policies,
385     const std::function<bool(Operation *op)> &filter) {
386   // Check that op has no side effects. This guarantees that we will not
387   // reorder side-effecting ops during cluster formation.
388   if (!MemoryEffectOpInterface::hasNoEffect(op)) return llvm::None;
389 
390   // Operation rejected by the custom filter.
391   if (filter && !filter(op)) return llvm::None;
392 
393   // Initially we do not have any constraints on the operation results.
394   ValuesConstraintSet result_constraints;
395 
396   for (auto &policy : policies.policies()) {
397     ValuesConstraintSet operands_constraints;
398     if (succeeded(policy->MatchAndUpdateConstraints(op, result_constraints,
399                                                     operands_constraints)))
400       return operands_constraints.Resolve();
401   }
402 
403   return llvm::None;
404 }
405 
406 // Compute initial clustering state based on the clustering polocy.
InitializeClusteringState(Block * block,const ClusteringPolicySet & policies,const std::function<bool (Operation * op)> & filter)407 static ClusteringState InitializeClusteringState(
408     Block *block, const ClusteringPolicySet &policies,
409     const std::function<bool(Operation *op)> &filter) {
410   ClusteringState state;
411 
412   // Create members for all block arguments.
413   for (BlockArgument &arg : block->getArguments()) {
414     if (!arg.getUsers().empty())
415       state.members.emplace_back(state.members.size(), &arg, &block->front());
416   }
417 
418   int num_bbarg_members = state.members.size();
419   (void)num_bbarg_members;
420 
421   // Create members for operations that can be clustered based on the policy.
422   for (Operation &op : block->getOperations()) {
423     if (auto constraints = CanBeClustered(&op, policies, filter))
424       state.members.emplace_back(state.members.size(), &op, &op, *constraints);
425   }
426 
427   // Initialize mapping from the member operation (block argument) to the id.
428   for (auto &tuple : llvm::enumerate(state.members)) {
429     state.member_ids.try_emplace(tuple.value().source, tuple.index());
430   }
431 
432   LLVM_DEBUG(llvm::dbgs() << "Found "
433                           << (state.members.size() - num_bbarg_members)
434                           << " clustering candidate operations in the block\n");
435 
436   return state;
437 }
438 
439 // Users of the `source` that are candidates for clustering.
GetClusteringCandidates(const ClusteringState & state,Source source)440 static llvm::SmallVector<Operation *> GetClusteringCandidates(
441     const ClusteringState &state, Source source) {
442   // Users of operation result must be in the same block and placed on the same
443   // device.
444   if (auto op = source.dyn_cast<Operation *>()) {
445     auto range = llvm::make_filter_range(op->getUsers(), [&](Operation *user) {
446       bool same_block = user->getBlock() == op->getBlock();
447       bool same_device = op->getAttr(kDeviceAttr) == user->getAttr(kDeviceAttr);
448       return same_block && same_device && state.IsMember(user);
449     });
450     return {range.begin(), range.end()};
451   }
452 
453   // Users of block argument must be in the same block.
454   if (auto arg = source.dyn_cast<BlockArgument *>()) {
455     auto range = llvm::make_filter_range(arg->getUsers(), [&](Operation *user) {
456       bool same_block = user->getBlock() == arg->getOwner();
457       return same_block && state.IsMember(user);
458     });
459     return {range.begin(), range.end()};
460   }
461 
462   llvm_unreachable("Unexpected type in the union.");
463 }
464 
465 // Cluster members with their result users. Returns `true` if merged at least a
466 // pair of members into a new cluster.
RunClusteringPass(ClusteringState & state,const ClusteringPolicySet & policies)467 static bool RunClusteringPass(ClusteringState &state,
468                               const ClusteringPolicySet &policies) {
469   bool clustered = false;
470 
471   for (auto &tuple : llvm::enumerate(state.members)) {
472     size_t member_id = tuple.index();
473     Member &member = tuple.value();
474 
475     llvm::SmallVector<Operation *> users =
476         GetClusteringCandidates(state, member.source);
477 
478     // Process candidates according to their order in the block to minimize
479     // the number of dominance property violations.
480     llvm::sort(users, [](auto *a, auto *b) { return a->isBeforeInBlock(b); });
481 
482     for (Operation *user : users) {
483       auto user_member_id = state.member_ids.lookup(user);
484       if (succeeded(state.Union(member_id, user_member_id, policies)))
485         clustered = true;
486     }
487   }
488 
489   return clustered;
490 }
491 
FindClustersInTheBlock(Block * block,const ClusteringPolicySet & policies,std::function<bool (Operation * op)> filter)492 llvm::SmallVector<Cluster> FindClustersInTheBlock(
493     Block *block, const ClusteringPolicySet &policies,
494     std::function<bool(Operation *op)> filter) {
495   // It is impossible to build a cluster in the empty block.
496   if (block->empty()) return {};
497 
498   ClusteringState state = InitializeClusteringState(block, policies, filter);
499 
500   // Run clustering passes until the convergence. Limit the number of iterations
501   // to guard from the infinite loop in presence of bugs.
502   constexpr int max_iterations = 100;
503   for (unsigned i = 0; i < max_iterations; ++i)
504     if (!RunClusteringPass(state, policies)) break;
505 
506   // Form clusters found by the union-find algorithm.
507   llvm::DenseMap<unsigned, Cluster> root_clusters;
508 
509   for (Member &member : state.members) {
510     unsigned root = state.FindRoot(member.root);
511     Cluster &cluster = root_clusters.FindAndConstruct(root).getSecond();
512 
513     // If member is a root of the cluster, copy inferred constraints.
514     if (state.FindRoot(member.root) == member.root)
515       cluster.constraints = std::move(member.constraints);
516 
517     // Add operation to the cluster.
518     if (auto op = member.source.dyn_cast<Operation *>())
519       cluster.operations.emplace_back(op);
520   }
521 
522   llvm::SmallVector<Cluster> clusters;
523   for (auto &kv : root_clusters) {
524     Cluster &cluster = kv.getSecond();
525     // Skip degenerate clusters formed by a single basic block argument.
526     if (!cluster.operations.empty()) clusters.emplace_back(std::move(cluster));
527   }
528 
529   LLVM_DEBUG(llvm::dbgs() << "Found " << clusters.size() << " clusters\n");
530 
531   return clusters;
532 }
533 
534 // -------------------------------------------------------------------------- //
535 // Create `tf_device.cluster` operation from the discovered ops cluster.
536 // -------------------------------------------------------------------------- //
537 
CreateClusterOp(Cluster & cluster,StringAttr policy)538 tf_device::ClusterOp CreateClusterOp(Cluster &cluster, StringAttr policy) {
539   // Find all the values that are used outside of the cluster. These values
540   // will be returned from the created cluster operation.
541   llvm::DenseSet<Operation *> in_cluster;
542   for (Operation *op : cluster.operations) in_cluster.insert(op);
543 
544   llvm::SetVector<Value> return_values;
545   llvm::SmallVector<Type> return_types;
546 
547   for (Operation *op : cluster.operations)
548     for (OpOperand &use : op->getUses()) {
549       // User is inside the cluster.
550       if (in_cluster.contains(use.getOwner())) continue;
551       // Do not return the same value multiple times.
552       if (return_values.contains(use.get())) continue;
553 
554       return_values.insert(use.get());
555       return_types.emplace_back(use.get().getType());
556     }
557 
558   // Sort matched operations by their position in the block.
559   llvm::sort(cluster.operations, [](Operation *a, Operation *b) -> bool {
560     return a->isBeforeInBlock(b);
561   });
562 
563   // Create tf_device::ClusterOp before the last operation in the block that
564   // is a part of a match set.
565   auto back = cluster.operations.back();
566   auto loc = back->getLoc();
567   OpBuilder builder(back);
568 
569   auto cluster_op =
570       builder.create<tf_device::ClusterOp>(loc, return_types, policy);
571 
572   // Create block in cluster_op's region and move 'cluster.operations' into
573   // it.
574   auto block = builder.createBlock(&cluster_op.body());
575   auto block_end = block->end();
576   for (auto op : cluster.operations) op->moveBefore(block, block_end);
577 
578   // Add 'tf_device::ReturnOp' at the end of the block.
579   builder.setInsertionPointToEnd(block);
580   builder.create<tf_device::ReturnOp>(loc, return_values.getArrayRef());
581 
582   // Set device attribute
583   if (auto device = back->getAttr(kDeviceAttr))
584     cluster_op->setAttr(kDeviceAttr, device);
585 
586   // Update all users of the operations moved into the cluster region.
587   for (auto tuple : llvm::zip(return_values, cluster_op.getResults())) {
588     Value old_value = std::get<0>(tuple);
589     Value new_value = std::get<1>(tuple);
590     old_value.replaceUsesWithIf(new_value, [&](OpOperand &operand) -> bool {
591       // Do not update users in the same cluster.
592       return operand.getOwner()->getBlock() != block;
593     });
594   }
595 
596   return cluster_op;
597 }
598 
599 // -------------------------------------------------------------------------- //
600 // Helper functions for value constraints propagations and analysis.
601 // -------------------------------------------------------------------------- //
602 
PropagateValuesConstraints(llvm::ArrayRef<Operation * > root,std::function<bool (Operation *)> filter,const ClusteringPolicySet & policies,ValuesConstraintSet & constraints,bool resolve)603 mlir::LogicalResult PropagateValuesConstraints(
604     llvm::ArrayRef<Operation *> root, std::function<bool(Operation *)> filter,
605     const ClusteringPolicySet &policies, ValuesConstraintSet &constraints,
606     bool resolve) {
607   // A set of constraints for operation results.
608   llvm::DenseMap<Operation *, ValuesConstraintSet> op_results_constraints;
609   assert(filter && "filter predicate must be defined");
610 
611   // Use initial constraints to initialize op results constraints.
612   for (std::pair<Value, ValueConstraint> pair : constraints) {
613     Value value = pair.first;
614     ValueConstraint constraint = pair.second;
615 
616     // Value must be defined by an operation and accepted by the filter.
617     Operation *op = value.getDefiningOp();
618     if (!op || !filter(op)) continue;
619 
620     op_results_constraints[op].Insert(value, constraint);
621   }
622 
623   // Keep a worklist of operations that need their constraints to be updated.
624   llvm::SetVector<Operation *> worklist;
625   for (Operation *op : root) worklist.insert(op);
626 
627   while (!worklist.empty()) {
628     Operation *op = worklist.pop_back_val();
629 
630     // Use results constraints to infer operands constraints.
631     const ValuesConstraintSet &results = op_results_constraints[op];
632     ValuesConstraintSet operands;
633 
634     // Walk through all policies until we find one that matches the operation.
635     bool updated = false;
636     for (auto &policy : policies.policies()) {
637       auto matched =
638           policy->MatchAndUpdateConstraints(op, results, operands.Reset());
639       if (succeeded(matched)) {
640         updated = true;
641         break;
642       }
643     }
644 
645     // Signal a failure if could not propagate non-empty constraints on the
646     // operation results to the operands.
647     if (!updated && !results.Empty()) {
648       op->emitError("failed to propagate results constraints");
649       return failure();
650     }
651 
652     // Update results constraints based on inferred operands constraints.
653     operands.Walk([&](Value value, ValueConstraint constraint) {
654       // Resolve constraint based on the static type information.
655       if (resolve && succeeded(IsStaticallyResolved(value, constraint))) return;
656 
657       // Update constraint for a value.
658       auto updated = constraints.Insert(value, constraint);
659       if (!updated.second) return;
660 
661       // Maybe update constaint on the operation result, but do not follow
662       // operations that are not accepted by the filter predicate.
663       Operation *op = value.getDefiningOp();
664       if (!op || !filter(op)) return;
665 
666       // Add updated operation to the worklist.
667       auto inserted = op_results_constraints[op].Insert(value, updated.first);
668       if (inserted.second) worklist.insert(op);
669     });
670   }
671 
672   return success();
673 }
674 
PropagateValuesConstraints(mlir::Region & region,const ClusteringPolicySet & policies,ValuesConstraintSet & constraints,bool resolve)675 mlir::LogicalResult PropagateValuesConstraints(
676     mlir::Region &region, const ClusteringPolicySet &policies,
677     ValuesConstraintSet &constraints, bool resolve) {
678   // Propagate constraints for all operations in the region.
679   llvm::SmallVector<Operation *> worklist;
680   region.walk([&](Operation *op) { worklist.emplace_back(op); });
681 
682   // Propagate constraints only through operations inside the `region`.
683   auto filter = [&](Operation *op) -> bool {
684     return region.findAncestorBlockInRegion(*op->getBlock());
685   };
686 
687   return PropagateValuesConstraints(worklist, filter, policies, constraints,
688                                     resolve);
689 }
690 
EmitValueConstraintsRemarks(const ValuesConstraintSet & constraints)691 void EmitValueConstraintsRemarks(const ValuesConstraintSet &constraints) {
692   constraints.Walk([](Value value, ValueConstraint constraint) {
693     for (OpOperand &operand : value.getUses())
694       operand.getOwner()->emitRemark(
695           llvm::formatv("operand #{0} constrained to: {1}",
696                         operand.getOperandNumber(), constraint));
697   });
698 }
699 
EmitInputsConstraintsRemarks(FuncOp func,const ValuesConstraintSet & constraints)700 void EmitInputsConstraintsRemarks(FuncOp func,
701                                   const ValuesConstraintSet &constraints) {
702   constraints.Walk([&](Value value, ValueConstraint constraint) {
703     if (auto arg = value.dyn_cast<BlockArgument>())
704       if (arg.getOwner() == &func.body().front())
705         func.emitRemark(llvm::formatv("input #{0} constrained to: {1}",
706                                       arg.getArgNumber(), constraint));
707   });
708 }
709 
InferFunctionBodyValuesConstraints(FuncOp func,ValuesConstraintSet & constraints)710 LogicalResult InferFunctionBodyValuesConstraints(
711     FuncOp func, ValuesConstraintSet &constraints) {
712   for (unsigned i = 0; i < func.getNumResults(); ++i) {
713     auto str = func.getResultAttrOfType<StringAttr>(i, "tf.constraint");
714     if (!str) continue;
715 
716     ValueConstraint constraint = StringSwitch<ValueConstraint>(str.getValue())
717                                      .Case("rank", ValueConstraint::kRank)
718                                      .Case("shape", ValueConstraint::kShape)
719                                      .Case("value", ValueConstraint::kValue);
720 
721     // Propagate constraints through function return operations.
722     for (Block &block : func.body()) {
723       ReturnOp ret = dyn_cast<ReturnOp>(block.back());
724       if (ret) constraints.Insert(ret.getOperand(i), constraint);
725     }
726   }
727 
728   return success();
729 }
730 
731 }  // namespace TFDevice
732 }  // namespace mlir
733