1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.h"
17
18 #include "llvm/ADT/DenseSet.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/FormatVariadic.h"
23 #include "llvm/Support/raw_ostream.h"
24 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
25 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
26 #include "mlir/IR/Operation.h" // from @llvm-project
27 #include "mlir/IR/Value.h" // from @llvm-project
28 #include "mlir/Support/LogicalResult.h" // from @llvm-project
29
30 #define DEBUG_TYPE "cluster-ops-by-policy"
31
32 namespace mlir {
33 namespace TFDevice {
34
35 // -------------------------------------------------------------------------- //
36 // ValueConstraint.
37 // -------------------------------------------------------------------------- //
38
Merge(ValueConstraint a,ValueConstraint b)39 ValueConstraint Merge(ValueConstraint a, ValueConstraint b) {
40 return a > b ? a : b;
41 }
42
IsStaticallyResolved(Value value,ValueConstraint constraint)43 LogicalResult IsStaticallyResolved(Value value, ValueConstraint constraint) {
44 // Resolve constraints inferred from the tensor type.
45 if (auto tensor = value.getType().dyn_cast<TensorType>()) {
46 if (constraint == ValueConstraint::kRank && tensor.hasRank())
47 return success();
48 if (constraint == ValueConstraint::kShape && tensor.hasStaticShape())
49 return success();
50 }
51
52 return failure();
53 }
54
operator <<(raw_ostream & os,const ValueConstraint & constraint)55 raw_ostream &operator<<(raw_ostream &os, const ValueConstraint &constraint) {
56 auto str = [](ValueConstraint constraint) -> StringRef {
57 switch (constraint) {
58 case ValueConstraint::kRank:
59 return "rank";
60 case ValueConstraint::kShape:
61 return "shape";
62 case ValueConstraint::kValue:
63 return "value";
64 default:
65 llvm_unreachable("unknown value constraint");
66 }
67 };
68
69 os << str(constraint);
70 return os;
71 }
72
73 // -------------------------------------------------------------------------- //
74 // ValuesConstraintSet.
75 // -------------------------------------------------------------------------- //
76
Insert(ValueRange values,ValueConstraint constraint)77 void ValuesConstraintSet::Insert(ValueRange values,
78 ValueConstraint constraint) {
79 for (Value value : values) Insert(value, constraint);
80 }
81
Insert(Value value,ValueConstraint constraint)82 std::pair<ValueConstraint, bool> ValuesConstraintSet::Insert(
83 Value value, ValueConstraint constraint) {
84 auto emplaced = constraints_.try_emplace(value, constraint);
85 ValueConstraint persisted = emplaced.first->getSecond();
86
87 // We've just inserted a new constraint for the value.
88 if (emplaced.second) return {persisted, true};
89
90 // Update existing constraint with a new one.
91 auto merged = Merge(constraint, persisted);
92 return {constraints_[value] = merged, merged != persisted};
93 }
94
Walk(llvm::function_ref<void (Value,ValueConstraint)> walk) const95 void ValuesConstraintSet::Walk(
96 llvm::function_ref<void(Value, ValueConstraint)> walk) const {
97 for (auto &kv : constraints_) walk(kv.getFirst(), kv.getSecond());
98 }
99
GetConstraint(Value value) const100 Optional<ValueConstraint> ValuesConstraintSet::GetConstraint(
101 Value value) const {
102 auto it = constraints_.find(value);
103 if (it == constraints_.end()) return None;
104 return it->getSecond();
105 }
106
HasConstraint(Value value) const107 bool ValuesConstraintSet::HasConstraint(Value value) const {
108 return GetConstraint(value).has_value();
109 }
110
MergeAll(const ValuesConstraintSet & other)111 void ValuesConstraintSet::MergeAll(const ValuesConstraintSet &other) {
112 other.Walk([this](Value value, ValueConstraint constraint) {
113 Insert(value, constraint);
114 });
115 }
116
Resolve()117 ValuesConstraintSet &ValuesConstraintSet::Resolve() {
118 llvm::SmallDenseSet<Value, 4> resolved;
119 Walk([&](Value value, ValueConstraint constraint) {
120 if (succeeded(IsStaticallyResolved(value, constraint)))
121 resolved.insert(value);
122 });
123 for (Value value : resolved) constraints_.erase(value);
124 return *this;
125 }
126
Reset()127 ValuesConstraintSet &ValuesConstraintSet::Reset() {
128 constraints_.clear();
129 return *this;
130 }
131
Size() const132 size_t ValuesConstraintSet::Size() const { return constraints_.size(); }
133
Empty() const134 bool ValuesConstraintSet::Empty() const { return constraints_.empty(); }
135
136 // -------------------------------------------------------------------------- //
137 // Discovering clusters of operations based on the policy.
138 // -------------------------------------------------------------------------- //
139
140 namespace {
141 constexpr char kDeviceAttr[] = "device";
142
143 // A type that abstracts over types that have uses accessible via `getUses`.
144 using Source = PointerUnion<Operation *, BlockArgument *>;
145
146 // We use union-find algorithm to build clusters of connected operations based
147 // on the user provided policy. If an operation can be clustered (one of the
148 // user provided policies accepts it under given constraints), it will become
149 // a "member" that will participate in the union-find cluster construction.
150 //
151 // A block argument can also become a member (or even a root member), however
152 // only operations will become a part of the outline `tf_device.cluster`, block
153 // arguments will stay as block arguments, and will later become cluster
154 // function inputs.
155 struct Member {
Membermlir::TFDevice::__anon501e02370411::Member156 Member(unsigned root, Source source, Operation *insertion_point,
157 ValuesConstraintSet constraints = {})
158 : root(root),
159 source(source),
160 insertion_point(insertion_point),
161 constraints(constraints) {}
162
163 unsigned root;
164 Source source;
165
166 // After construction:
167 // For basic block argument source this will be a first operation in the
168 // basic block, and for operations it will be an operation iself.
169 //
170 // During the union-find cluster formation:
171 // The root member will have the location in the basic block, where the
172 // cluster operation will be inserted. We use the location of the last
173 // operation in the cluster, so that during cluster construction we can
174 // ensure that all operands are above the insertion point, and all users are
175 // below the insertion point.
176 //
177 // Example:
178 //
179 // %0 = "clustered_op"(...)
180 // %1 = "non_clustered_op"(...)
181 // %2 = "clustered_op"(%1) <<<--- insert cluster here
182 // %3 = "cluster_result_user"(%1, %2)
183 //
184 // By using `%2` location as an insertion point we ensure that all operands
185 // (%1 in this example) dominate the cluster operation, and that the cluster
186 // operation dominates all the users (%3 in this example).
187 Operation *insertion_point;
188
189 // After construction:
190 // A set of constraints on the clustered operation operands that must be
191 // satisfied in order to add operation to the cluster. For basic block
192 // source this will be always empty.
193 //
194 // During the union-find cluster formation:
195 // The root member will have constraints merged from all of the cluster
196 // members.
197 ValuesConstraintSet constraints;
198 };
199
200 using Members = llvm::SmallVector<Member>;
201
202 struct ClusteringState {
203 // Storage backing an array based union-find algorithm for clustering. Index
204 // in this vector is the member id.
205 llvm::SmallVector<Member> members;
206
207 // Mapping from the member operation (block argument) to the member id.
208 llvm::SmallDenseMap<Source, unsigned> member_ids;
209
210 // Puts `a` and `b` members into the same cluster if it is possible. Returns
211 // success if union operation was completed successfully, otherwise returns
212 // failure.
213 //
214 // Members can be clustered together:
215 // 1. This will not break dominance property of the IR.
216 // 2. New clustering policy constraints can be propagated through the
217 // already clustered operations.
218 LogicalResult Union(unsigned a, unsigned b,
219 const ClusteringPolicySet &policies);
220
221 bool IsMember(Operation *op) const;
222 unsigned FindRoot(unsigned id);
223
224 // Verifies that merging `src_root` cluster with a `dst_root` cluster, and
225 // inserting it at `insertion_point` location will not break the dominance
226 // property: all users of the `src_root` cluster results must be below the
227 // insertion point in the block.
228 LogicalResult VerifyDominanceProperty(unsigned src_root, unsigned dst_root,
229 Operation *insertion_point);
230
231 // Verifies that all constraints on the values defined by the `dst_root`
232 // cluster can be propagated through the nodes in the `src_root` cluster, and
233 // updates `src_root` constraints on success.
234 LogicalResult VerifyValueConstraints(unsigned src_root, unsigned dst_root,
235 const ClusteringPolicySet &policies);
236 };
237
238 } // namespace
239
IsMember(Operation * op) const240 bool ClusteringState::IsMember(Operation *op) const {
241 return member_ids.find(op) != member_ids.end();
242 }
243
FindRoot(unsigned id)244 unsigned ClusteringState::FindRoot(unsigned id) {
245 if (members[id].root == id) return id;
246 return members[id].root = FindRoot(members[id].root);
247 }
248
VerifyDominanceProperty(unsigned src_root,unsigned dst_root,Operation * insertion_point)249 LogicalResult ClusteringState::VerifyDominanceProperty(
250 unsigned src_root, unsigned dst_root, Operation *insertion_point) {
251 // TODO(ezhulenev): Optimize this linear scan with a map lookup.
252 for (auto &member : members) {
253 unsigned root = FindRoot(member.root);
254 if (root != src_root) continue;
255
256 // Block arguments do not really participate in clustering, they are only
257 // used to connect independent operation using the same argument.
258 if (member.source.is<BlockArgument *>()) continue;
259
260 Operation *op = member.source.dyn_cast<Operation *>();
261 assert(op && "member operation must be not null");
262
263 for (Operation *user : op->getUsers()) {
264 // Skip users in other blocks.
265 if (user->getBlock() != op->getBlock()) continue;
266
267 // Skip users is in the `dst_root` or `src_root` clusters, if we'll merge
268 // roots they'll become a single cluster and will not violate the
269 // dominance property after that.
270 auto it = member_ids.find(user);
271 if (it != member_ids.end() && (FindRoot(it->getSecond()) == dst_root ||
272 FindRoot(it->getSecond()) == src_root))
273 continue;
274
275 if (user->isBeforeInBlock(insertion_point)) {
276 LLVM_DEBUG(llvm::dbgs()
277 << " Failure: user is before the insertion point: "
278 << *user << "\n";);
279 return failure();
280 }
281 }
282 }
283
284 return success();
285 }
286
VerifyValueConstraints(unsigned src_root,unsigned dst_root,const ClusteringPolicySet & policies)287 LogicalResult ClusteringState::VerifyValueConstraints(
288 unsigned src_root, unsigned dst_root, const ClusteringPolicySet &policies) {
289 // Propagate constraints only through operations in the `src_root` cluster.
290 auto filter = [&](Operation *op) -> bool {
291 auto it = member_ids.find(op);
292 return it != member_ids.end() && FindRoot(it->getSecond()) == src_root;
293 };
294
295 // Start from all operations in the `src_root` cluster.
296 llvm::SmallVector<Operation *> worklist;
297 for (Member &member : members)
298 if (Operation *op = member.source.dyn_cast<Operation *>())
299 if (FindRoot(member.root) == src_root) worklist.emplace_back(op);
300
301 // Collect `dst_root` constraints that are applicable to the values defined in
302 // the `src_root` cluster.
303 ValuesConstraintSet constraints = members[src_root].constraints;
304 members[dst_root].constraints.Walk([&](Value v, ValueConstraint constraint) {
305 Operation *op = v.getDefiningOp();
306 if (op && filter(op)) constraints.Insert(v, constraint);
307 });
308
309 // Update `src_root` constraints only if we can propagate them.
310 if (succeeded(PropagateValuesConstraints(worklist, filter, policies,
311 constraints, /*resolve=*/true))) {
312 members[src_root].constraints = constraints;
313 return success();
314 }
315
316 return failure();
317 }
318
Union(unsigned a,unsigned b,const ClusteringPolicySet & policies)319 LogicalResult ClusteringState::Union(unsigned a, unsigned b,
320 const ClusteringPolicySet &policies) {
321 unsigned a_root = FindRoot(a);
322 unsigned b_root = FindRoot(b);
323
324 // Already members of the same cluster.
325 if (a_root == b_root) return failure();
326
327 // Verify that merging two clusters will not break dominance property.
328 Operation *a_insertion_point = members[a_root].insertion_point;
329 Operation *b_insertion_point = members[b_root].insertion_point;
330 bool a_is_before_b = a_insertion_point->isBeforeInBlock(b_insertion_point);
331
332 // Use clusters position in the block to select merging src and dst.
333 unsigned src_root = a_is_before_b ? a_root : b_root; // merge `src_root` ...
334 unsigned dst_root = a_is_before_b ? b_root : a_root; // ... into `dst_root`
335 Operation *insertion_point =
336 a_is_before_b ? b_insertion_point : a_insertion_point;
337
338 // Print operations in the `root` cluster to debug stream.
339 auto debug_clustered_ops = [&](unsigned root) {
340 for (Member &member : members)
341 if (FindRoot(member.root) == root) {
342 if (auto *op = member.source.dyn_cast<Operation *>()) {
343 llvm::dbgs() << " " << *op << "\n";
344 } else if (auto *arg = member.source.dyn_cast<BlockArgument *>()) {
345 llvm::dbgs() << " " << *arg;
346 }
347 }
348 };
349 (void)debug_clustered_ops;
350
351 LLVM_DEBUG({
352 llvm::dbgs() << "\n\n--- Try to merge cluster:\n";
353 debug_clustered_ops(src_root);
354 llvm::dbgs() << "\n--- With cluster:\n";
355 debug_clustered_ops(dst_root);
356 LLVM_DEBUG(llvm::dbgs() << "\n--- Diagnostics:\n");
357 });
358
359 // Check if merging `src_root` with `dst_root` will not violate SSA dominance
360 // property (all operands before the cluster, all results after the cluster).
361 if (failed(VerifyDominanceProperty(src_root, dst_root, insertion_point)))
362 return failure();
363
364 // Check if `dst_root` constraints can be propagated to the `src_root`
365 // constraints.
366 if (failed(VerifyValueConstraints(src_root, dst_root, policies)))
367 return failure();
368
369 // Set `dst_root` as a new root for `src_root`.
370 members[src_root].root = dst_root;
371 // Update insertion point of the new root.
372 members[dst_root].insertion_point = insertion_point;
373 // Merge all constraints from `src_root` into `dst_root`.
374 members[dst_root].constraints.MergeAll(members[src_root].constraints);
375
376 LLVM_DEBUG(llvm::dbgs() << " Clusters successfully merged\n");
377
378 return success();
379 }
380
381 // Returns constraints on the operands specified by the clustering policy if the
382 // operation can be clustered (constraints could be empty). Otherwise return
383 // empty optional.
CanBeClustered(Operation * op,const ClusteringPolicySet & policies,const std::function<bool (Operation * op)> & filter)384 static Optional<ValuesConstraintSet> CanBeClustered(
385 Operation *op, const ClusteringPolicySet &policies,
386 const std::function<bool(Operation *op)> &filter) {
387 // Check that op has no side effects. This guarantees that we will not
388 // reorder side-effecting ops during cluster formation.
389 if (!MemoryEffectOpInterface::hasNoEffect(op)) return llvm::None;
390
391 // Operation rejected by the custom filter.
392 if (filter && !filter(op)) return llvm::None;
393
394 // Initially we do not have any constraints on the operation results.
395 ValuesConstraintSet result_constraints;
396
397 for (auto &policy : policies.policies()) {
398 ValuesConstraintSet operands_constraints;
399 if (succeeded(policy->MatchAndUpdateConstraints(op, result_constraints,
400 operands_constraints)))
401 return operands_constraints.Resolve();
402 }
403
404 return llvm::None;
405 }
406
407 // Compute initial clustering state based on the clustering polocy.
InitializeClusteringState(Block * block,const ClusteringPolicySet & policies,const std::function<bool (Operation * op)> & filter)408 static ClusteringState InitializeClusteringState(
409 Block *block, const ClusteringPolicySet &policies,
410 const std::function<bool(Operation *op)> &filter) {
411 ClusteringState state;
412
413 // Create members for all block arguments.
414 for (BlockArgument &arg : block->getArguments()) {
415 if (!arg.getUsers().empty())
416 state.members.emplace_back(state.members.size(), &arg, &block->front());
417 }
418
419 int num_bbarg_members = state.members.size();
420 (void)num_bbarg_members;
421
422 // Create members for operations that can be clustered based on the policy.
423 for (Operation &op : block->getOperations()) {
424 if (auto constraints = CanBeClustered(&op, policies, filter))
425 state.members.emplace_back(state.members.size(), &op, &op, *constraints);
426 }
427
428 // Initialize mapping from the member operation (block argument) to the id.
429 for (auto &tuple : llvm::enumerate(state.members)) {
430 state.member_ids.try_emplace(tuple.value().source, tuple.index());
431 }
432
433 LLVM_DEBUG(llvm::dbgs() << "Found "
434 << (state.members.size() - num_bbarg_members)
435 << " clustering candidate operations in the block\n");
436
437 return state;
438 }
439
440 // Users of the `source` that are candidates for clustering.
GetClusteringCandidates(const ClusteringState & state,Source source)441 static llvm::SmallVector<Operation *> GetClusteringCandidates(
442 const ClusteringState &state, Source source) {
443 // Users of operation result must be in the same block and placed on the same
444 // device.
445 if (auto op = source.dyn_cast<Operation *>()) {
446 auto range = llvm::make_filter_range(op->getUsers(), [&](Operation *user) {
447 bool same_block = user->getBlock() == op->getBlock();
448 bool same_device = op->getAttr(kDeviceAttr) == user->getAttr(kDeviceAttr);
449 return same_block && same_device && state.IsMember(user);
450 });
451 return {range.begin(), range.end()};
452 }
453
454 // Users of block argument must be in the same block.
455 if (auto arg = source.dyn_cast<BlockArgument *>()) {
456 auto range = llvm::make_filter_range(arg->getUsers(), [&](Operation *user) {
457 bool same_block = user->getBlock() == arg->getOwner();
458 return same_block && state.IsMember(user);
459 });
460 return {range.begin(), range.end()};
461 }
462
463 llvm_unreachable("Unexpected type in the union.");
464 }
465
466 // Cluster members with their result users. Returns `true` if merged at least a
467 // pair of members into a new cluster.
RunClusteringPass(ClusteringState & state,const ClusteringPolicySet & policies)468 static bool RunClusteringPass(ClusteringState &state,
469 const ClusteringPolicySet &policies) {
470 bool clustered = false;
471
472 for (auto &tuple : llvm::enumerate(state.members)) {
473 size_t member_id = tuple.index();
474 Member &member = tuple.value();
475
476 llvm::SmallVector<Operation *> users =
477 GetClusteringCandidates(state, member.source);
478
479 // Process candidates according to their order in the block to minimize
480 // the number of dominance property violations.
481 llvm::sort(users, [](auto *a, auto *b) { return a->isBeforeInBlock(b); });
482
483 for (Operation *user : users) {
484 auto user_member_id = state.member_ids.lookup(user);
485 if (succeeded(state.Union(member_id, user_member_id, policies)))
486 clustered = true;
487 }
488 }
489
490 return clustered;
491 }
492
FindClustersInTheBlock(Block * block,const ClusteringPolicySet & policies,std::function<bool (Operation * op)> filter)493 llvm::SmallVector<Cluster> FindClustersInTheBlock(
494 Block *block, const ClusteringPolicySet &policies,
495 std::function<bool(Operation *op)> filter) {
496 // It is impossible to build a cluster in the empty block.
497 if (block->empty()) return {};
498
499 ClusteringState state = InitializeClusteringState(block, policies, filter);
500
501 // Run clustering passes until the convergence. Limit the number of iterations
502 // to guard from the infinite loop in presence of bugs.
503 constexpr int max_iterations = 100;
504 for (unsigned i = 0; i < max_iterations; ++i)
505 if (!RunClusteringPass(state, policies)) break;
506
507 // Form clusters found by the union-find algorithm.
508 llvm::DenseMap<unsigned, Cluster> root_clusters;
509
510 for (Member &member : state.members) {
511 unsigned root = state.FindRoot(member.root);
512 Cluster &cluster = root_clusters.FindAndConstruct(root).getSecond();
513
514 // If member is a root of the cluster, copy inferred constraints.
515 if (state.FindRoot(member.root) == member.root)
516 cluster.constraints = std::move(member.constraints);
517
518 // Add operation to the cluster.
519 if (auto op = member.source.dyn_cast<Operation *>())
520 cluster.operations.emplace_back(op);
521 }
522
523 llvm::SmallVector<Cluster> clusters;
524 for (auto &kv : root_clusters) {
525 Cluster &cluster = kv.getSecond();
526 // Skip degenerate clusters formed by a single basic block argument.
527 if (!cluster.operations.empty()) clusters.emplace_back(std::move(cluster));
528 }
529
530 LLVM_DEBUG(llvm::dbgs() << "Found " << clusters.size() << " clusters\n");
531
532 return clusters;
533 }
534
535 // -------------------------------------------------------------------------- //
536 // Create `tf_device.cluster` operation from the discovered ops cluster.
537 // -------------------------------------------------------------------------- //
538
CreateClusterOp(Cluster & cluster,StringAttr policy)539 tf_device::ClusterOp CreateClusterOp(Cluster &cluster, StringAttr policy) {
540 // Find all the values that are used outside of the cluster. These values
541 // will be returned from the created cluster operation.
542 llvm::DenseSet<Operation *> in_cluster;
543 for (Operation *op : cluster.operations) in_cluster.insert(op);
544
545 llvm::SetVector<Value> return_values;
546 llvm::SmallVector<Type> return_types;
547
548 for (Operation *op : cluster.operations)
549 for (OpOperand &use : op->getUses()) {
550 // User is inside the cluster.
551 if (in_cluster.contains(use.getOwner())) continue;
552 // Do not return the same value multiple times.
553 if (return_values.contains(use.get())) continue;
554
555 return_values.insert(use.get());
556 return_types.emplace_back(use.get().getType());
557 }
558
559 // Sort matched operations by their position in the block.
560 llvm::sort(cluster.operations, [](Operation *a, Operation *b) -> bool {
561 return a->isBeforeInBlock(b);
562 });
563
564 // Create tf_device::ClusterOp before the last operation in the block that
565 // is a part of a match set.
566 auto back = cluster.operations.back();
567 auto loc = back->getLoc();
568 OpBuilder builder(back);
569
570 auto cluster_op =
571 builder.create<tf_device::ClusterOp>(loc, return_types, policy);
572
573 // Create block in cluster_op's region and move 'cluster.operations' into
574 // it.
575 auto block = builder.createBlock(&cluster_op.body());
576 auto block_end = block->end();
577 for (auto op : cluster.operations) op->moveBefore(block, block_end);
578
579 // Add 'tf_device::ReturnOp' at the end of the block.
580 builder.setInsertionPointToEnd(block);
581 builder.create<tf_device::ReturnOp>(loc, return_values.getArrayRef());
582
583 // Set device attribute
584 if (auto device = back->getAttr(kDeviceAttr))
585 cluster_op->setAttr(kDeviceAttr, device);
586
587 // Update all users of the operations moved into the cluster region.
588 for (auto tuple : llvm::zip(return_values, cluster_op.getResults())) {
589 Value old_value = std::get<0>(tuple);
590 Value new_value = std::get<1>(tuple);
591 old_value.replaceUsesWithIf(new_value, [&](OpOperand &operand) -> bool {
592 // Do not update users in the same cluster.
593 return operand.getOwner()->getBlock() != block;
594 });
595 }
596
597 return cluster_op;
598 }
599
600 // -------------------------------------------------------------------------- //
601 // Helper functions for value constraints propagations and analysis.
602 // -------------------------------------------------------------------------- //
603
PropagateValuesConstraints(llvm::ArrayRef<Operation * > root,std::function<bool (Operation *)> filter,const ClusteringPolicySet & policies,ValuesConstraintSet & constraints,bool resolve,bool emit_remarks)604 mlir::LogicalResult PropagateValuesConstraints(
605 llvm::ArrayRef<Operation *> root, std::function<bool(Operation *)> filter,
606 const ClusteringPolicySet &policies, ValuesConstraintSet &constraints,
607 bool resolve, bool emit_remarks) {
608 // A set of constraints for operation results.
609 llvm::DenseMap<Operation *, ValuesConstraintSet> op_results_constraints;
610 assert(filter && "filter predicate must be defined");
611
612 // Use initial constraints to initialize op results constraints.
613 for (std::pair<Value, ValueConstraint> pair : constraints) {
614 Value value = pair.first;
615 ValueConstraint constraint = pair.second;
616
617 // Value must be defined by an operation and accepted by the filter.
618 Operation *op = value.getDefiningOp();
619 if (!op || !filter(op)) continue;
620
621 op_results_constraints[op].Insert(value, constraint);
622 }
623
624 // Keep a worklist of operations that need their constraints to be updated.
625 llvm::SetVector<Operation *> worklist;
626 for (Operation *op : root) worklist.insert(op);
627
628 while (!worklist.empty()) {
629 Operation *op = worklist.pop_back_val();
630
631 // Use results constraints to infer operands constraints.
632 const ValuesConstraintSet &results = op_results_constraints[op];
633 ValuesConstraintSet operands;
634
635 // Walk through all policies until we find one that matches the operation.
636 bool updated = false;
637 for (auto &policy : policies.policies()) {
638 auto matched =
639 policy->MatchAndUpdateConstraints(op, results, operands.Reset());
640 if (succeeded(matched)) {
641 updated = true;
642 break;
643 }
644 }
645
646 // Signal a failure if could not propagate non-empty constraints on the
647 // operation results to the operands.
648 if (!updated && !results.Empty()) {
649 if (emit_remarks) {
650 std::string err_msg;
651 llvm::raw_string_ostream os(err_msg);
652 for (unsigned i = 0; i < op->getNumResults(); ++i)
653 os << " " << i << ":" << results.GetConstraint(op->getResult(i));
654 op->emitError(llvm::formatv(
655 "failed to propagate results constraints:{0}", os.str()));
656 }
657 return failure();
658 }
659
660 // Update results constraints based on inferred operands constraints.
661 operands.Walk([&](Value value, ValueConstraint constraint) {
662 // Resolve constraint based on the static type information.
663 if (resolve && succeeded(IsStaticallyResolved(value, constraint))) return;
664
665 // Update constraint for a value.
666 auto updated = constraints.Insert(value, constraint);
667 if (!updated.second) return;
668
669 // Maybe update constaint on the operation result, but do not follow
670 // operations that are not accepted by the filter predicate.
671 Operation *op = value.getDefiningOp();
672 if (!op || !filter(op)) return;
673
674 // Add updated operation to the worklist.
675 auto inserted = op_results_constraints[op].Insert(value, updated.first);
676 if (inserted.second) worklist.insert(op);
677 });
678 }
679
680 return success();
681 }
682
PropagateValuesConstraints(mlir::Region & region,const ClusteringPolicySet & policies,ValuesConstraintSet & constraints,bool resolve,bool emit_remarks)683 mlir::LogicalResult PropagateValuesConstraints(
684 mlir::Region ®ion, const ClusteringPolicySet &policies,
685 ValuesConstraintSet &constraints, bool resolve, bool emit_remarks) {
686 // Propagate constraints for all operations in the region.
687 llvm::SmallVector<Operation *> worklist;
688 region.walk([&](Operation *op) { worklist.emplace_back(op); });
689
690 // Propagate constraints only through operations inside the `region`.
691 auto filter = [&](Operation *op) -> bool {
692 return region.findAncestorBlockInRegion(*op->getBlock());
693 };
694
695 return PropagateValuesConstraints(worklist, filter, policies, constraints,
696 resolve, emit_remarks);
697 }
698
EmitValueConstraintsRemarks(const ValuesConstraintSet & constraints)699 void EmitValueConstraintsRemarks(const ValuesConstraintSet &constraints) {
700 constraints.Walk([](Value value, ValueConstraint constraint) {
701 for (OpOperand &operand : value.getUses())
702 operand.getOwner()->emitRemark(
703 llvm::formatv("operand #{0} constrained to: {1}",
704 operand.getOperandNumber(), constraint));
705 });
706 }
707
EmitInputsConstraintsRemarks(func::FuncOp func,const ValuesConstraintSet & constraints)708 void EmitInputsConstraintsRemarks(func::FuncOp func,
709 const ValuesConstraintSet &constraints) {
710 constraints.Walk([&](Value value, ValueConstraint constraint) {
711 if (auto arg = value.dyn_cast<BlockArgument>())
712 if (arg.getOwner() == &func.getBody().front())
713 func.emitRemark(llvm::formatv("input #{0} constrained to: {1}",
714 arg.getArgNumber(), constraint));
715 });
716 }
717
InferFunctionBodyValuesConstraints(func::FuncOp func,ValuesConstraintSet & constraints)718 LogicalResult InferFunctionBodyValuesConstraints(
719 func::FuncOp func, ValuesConstraintSet &constraints) {
720 for (unsigned i = 0; i < func.getNumResults(); ++i) {
721 auto str = func.getResultAttrOfType<StringAttr>(i, "tf.constraint");
722 if (!str) continue;
723
724 ValueConstraint constraint = StringSwitch<ValueConstraint>(str.getValue())
725 .Case("rank", ValueConstraint::kRank)
726 .Case("shape", ValueConstraint::kShape)
727 .Case("value", ValueConstraint::kValue);
728
729 // Propagate constraints through function return operations.
730 for (Block &block : func.getBody()) {
731 func::ReturnOp ret = dyn_cast<func::ReturnOp>(block.back());
732 if (ret) constraints.Insert(ret.getOperand(i), constraint);
733 }
734 }
735
736 return success();
737 }
738
739 } // namespace TFDevice
740 } // namespace mlir
741