1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.h"
17
18 #include "llvm/ADT/DenseSet.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/FormatVariadic.h"
23 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
24 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
25 #include "mlir/IR/Operation.h" // from @llvm-project
26 #include "mlir/IR/Value.h" // from @llvm-project
27 #include "mlir/Support/LogicalResult.h" // from @llvm-project
28
29 #define DEBUG_TYPE "cluster-ops-by-policy"
30
31 namespace mlir {
32 namespace TFDevice {
33
34 // -------------------------------------------------------------------------- //
35 // ValueConstraint.
36 // -------------------------------------------------------------------------- //
37
Merge(ValueConstraint a,ValueConstraint b)38 ValueConstraint Merge(ValueConstraint a, ValueConstraint b) {
39 return a > b ? a : b;
40 }
41
IsStaticallyResolved(Value value,ValueConstraint constraint)42 LogicalResult IsStaticallyResolved(Value value, ValueConstraint constraint) {
43 // Resolve constraints inferred from the tensor type.
44 if (auto tensor = value.getType().dyn_cast<TensorType>()) {
45 if (constraint == ValueConstraint::kRank && tensor.hasRank())
46 return success();
47 if (constraint == ValueConstraint::kShape && tensor.hasStaticShape())
48 return success();
49 }
50
51 return failure();
52 }
53
operator <<(raw_ostream & os,const ValueConstraint & constraint)54 raw_ostream &operator<<(raw_ostream &os, const ValueConstraint &constraint) {
55 auto str = [](ValueConstraint constraint) -> StringRef {
56 switch (constraint) {
57 case ValueConstraint::kRank:
58 return "rank";
59 case ValueConstraint::kShape:
60 return "shape";
61 case ValueConstraint::kValue:
62 return "value";
63 default:
64 llvm_unreachable("unknown value constraint");
65 }
66 };
67
68 os << str(constraint);
69 return os;
70 }
71
72 // -------------------------------------------------------------------------- //
73 // ValuesConstraintSet.
74 // -------------------------------------------------------------------------- //
75
Insert(ValueRange values,ValueConstraint constraint)76 void ValuesConstraintSet::Insert(ValueRange values,
77 ValueConstraint constraint) {
78 for (Value value : values) Insert(value, constraint);
79 }
80
Insert(Value value,ValueConstraint constraint)81 std::pair<ValueConstraint, bool> ValuesConstraintSet::Insert(
82 Value value, ValueConstraint constraint) {
83 auto emplaced = constraints_.try_emplace(value, constraint);
84 ValueConstraint persisted = emplaced.first->getSecond();
85
86 // We've just inserted a new constraint for the value.
87 if (emplaced.second) return {persisted, true};
88
89 // Update existing constraint with a new one.
90 auto merged = Merge(constraint, persisted);
91 return {constraints_[value] = merged, merged != persisted};
92 }
93
Walk(llvm::function_ref<void (Value,ValueConstraint)> walk) const94 void ValuesConstraintSet::Walk(
95 llvm::function_ref<void(Value, ValueConstraint)> walk) const {
96 for (auto &kv : constraints_) walk(kv.getFirst(), kv.getSecond());
97 }
98
GetConstraint(Value value) const99 Optional<ValueConstraint> ValuesConstraintSet::GetConstraint(
100 Value value) const {
101 auto it = constraints_.find(value);
102 if (it == constraints_.end()) return None;
103 return it->getSecond();
104 }
105
HasConstraint(Value value) const106 bool ValuesConstraintSet::HasConstraint(Value value) const {
107 return GetConstraint(value).hasValue();
108 }
109
MergeAll(const ValuesConstraintSet & other)110 void ValuesConstraintSet::MergeAll(const ValuesConstraintSet &other) {
111 other.Walk([this](Value value, ValueConstraint constraint) {
112 Insert(value, constraint);
113 });
114 }
115
Resolve()116 ValuesConstraintSet &ValuesConstraintSet::Resolve() {
117 llvm::SmallDenseSet<Value, 4> resolved;
118 Walk([&](Value value, ValueConstraint constraint) {
119 if (succeeded(IsStaticallyResolved(value, constraint)))
120 resolved.insert(value);
121 });
122 for (Value value : resolved) constraints_.erase(value);
123 return *this;
124 }
125
Reset()126 ValuesConstraintSet &ValuesConstraintSet::Reset() {
127 constraints_.clear();
128 return *this;
129 }
130
Size() const131 size_t ValuesConstraintSet::Size() const { return constraints_.size(); }
132
Empty() const133 bool ValuesConstraintSet::Empty() const { return constraints_.empty(); }
134
135 // -------------------------------------------------------------------------- //
136 // Discovering clusters of operations based on the policy.
137 // -------------------------------------------------------------------------- //
138
139 namespace {
// Name of the operation attribute that is compared between an operation and
// its users when looking for clustering candidates, and that is copied to the
// created cluster operation.
constexpr char kDeviceAttr[] = "device";

// A type that abstracts over types that have uses accessible via `getUses`:
// a member of the clustering is either an operation or a block argument.
using Source = PointerUnion<Operation *, BlockArgument *>;
144
145 // We use union-find algorithm to build clusters of connected operations based
146 // on the user provided policy. If an operation can be clustered (one of the
147 // user provided policies accepts it under given constraints), it will become
148 // a "member" that will participate in the union-find cluster construction.
149 //
150 // A block argument can also become a member (or even a root member), however
151 // only operations will become a part of the outline `tf_device.cluster`, block
152 // arguments will stay as block arguments, and will later become cluster
153 // function inputs.
154 struct Member {
Membermlir::TFDevice::__anonaac0da7e0411::Member155 Member(unsigned root, Source source, Operation *insertion_point,
156 ValuesConstraintSet constraints = {})
157 : root(root),
158 source(source),
159 insertion_point(insertion_point),
160 constraints(constraints) {}
161
162 unsigned root;
163 Source source;
164
165 // After construction:
166 // For basic block argument source this will be a first operation in the
167 // basic block, and for operations it will be an operation iself.
168 //
169 // During the union-find cluster formation:
170 // The root member will have the location in the basic block, where the
171 // cluster operation will be inserted. We use the location of the last
172 // operation in the cluster, so that during cluster construction we can
173 // ensure that all operands are above the insertion point, and all users are
174 // below the insertion point.
175 //
176 // Example:
177 //
178 // %0 = "clustered_op"(...)
179 // %1 = "non_clustered_op"(...)
180 // %2 = "clustered_op"(%1) <<<--- insert cluster here
181 // %3 = "cluster_result_user"(%1, %2)
182 //
183 // By using `%2` location as an insertion point we ensure that all operands
184 // (%1 in this example) dominate the cluster operation, and that the cluster
185 // operation dominates all the users (%3 in this example).
186 Operation *insertion_point;
187
188 // After construction:
189 // A set of constraints on the clustered operation operands that must be
190 // satisfied in order to add operation to the cluster. For basic block
191 // source this will be always empty.
192 //
193 // During the union-find cluster formation:
194 // The root member will have constraints merged from all of the cluster
195 // members.
196 ValuesConstraintSet constraints;
197 };
198
199 using Members = llvm::SmallVector<Member>;
200
201 struct ClusteringState {
202 // Storage backing an array based union-find algorithm for clustering. Index
203 // in this vector is the member id.
204 llvm::SmallVector<Member> members;
205
206 // Mapping from the member operation (block argument) to the member id.
207 llvm::SmallDenseMap<Source, unsigned> member_ids;
208
209 // Puts `a` and `b` members into the same cluster if it is possible. Returns
210 // success if union operation was completed successfully, otherwise returns
211 // failure.
212 //
213 // Members can be clustered together:
214 // 1. This will not break dominance property of the IR.
215 // 2. New clustering policy constraints can be propagated through the
216 // already clustered operations.
217 LogicalResult Union(unsigned a, unsigned b,
218 const ClusteringPolicySet &policies);
219
220 bool IsMember(Operation *op) const;
221 unsigned FindRoot(unsigned id);
222
223 // Verifies that merging `src_root` cluster with a `dst_root` cluster, and
224 // inserting it at `insertion_point` location will not break the dominance
225 // property: all users of the `src_root` cluster results must be below the
226 // insertion point in the block.
227 LogicalResult VerifyDominanceProperty(unsigned src_root, unsigned dst_root,
228 Operation *insertion_point);
229
230 // Verifies that all constraints on the values defined by the `dst_root`
231 // cluster can be propagated through the nodes in the `src_root` cluster, and
232 // updates `src_root` constraints on success.
233 LogicalResult VerifyValueConstraints(unsigned src_root, unsigned dst_root,
234 const ClusteringPolicySet &policies);
235 };
236
237 } // namespace
238
IsMember(Operation * op) const239 bool ClusteringState::IsMember(Operation *op) const {
240 return member_ids.find(op) != member_ids.end();
241 }
242
FindRoot(unsigned id)243 unsigned ClusteringState::FindRoot(unsigned id) {
244 if (members[id].root == id) return id;
245 return members[id].root = FindRoot(members[id].root);
246 }
247
VerifyDominanceProperty(unsigned src_root,unsigned dst_root,Operation * insertion_point)248 LogicalResult ClusteringState::VerifyDominanceProperty(
249 unsigned src_root, unsigned dst_root, Operation *insertion_point) {
250 // TODO(ezhulenev): Optimize this linear scan with a map lookup.
251 for (auto &member : members) {
252 unsigned root = FindRoot(member.root);
253 if (root != src_root) continue;
254
255 // Block arguments do not really participate in clustering, they are only
256 // used to connect independent operation using the same argument.
257 if (member.source.is<BlockArgument *>()) continue;
258
259 Operation *op = member.source.dyn_cast<Operation *>();
260 assert(op && "member operation must be not null");
261
262 for (Operation *user : op->getUsers()) {
263 // Skip users in other blocks.
264 if (user->getBlock() != op->getBlock()) continue;
265
266 // Skip users is in the `dst_root` or `src_root` clusters, if we'll merge
267 // roots they'll become a single cluster and will not violate the
268 // dominance property after that.
269 auto it = member_ids.find(user);
270 if (it != member_ids.end() && (FindRoot(it->getSecond()) == dst_root ||
271 FindRoot(it->getSecond()) == src_root))
272 continue;
273
274 if (user->isBeforeInBlock(insertion_point)) {
275 LLVM_DEBUG(llvm::dbgs()
276 << " Failure: user is before the insertion point: "
277 << *user << "\n";);
278 return failure();
279 }
280 }
281 }
282
283 return success();
284 }
285
VerifyValueConstraints(unsigned src_root,unsigned dst_root,const ClusteringPolicySet & policies)286 LogicalResult ClusteringState::VerifyValueConstraints(
287 unsigned src_root, unsigned dst_root, const ClusteringPolicySet &policies) {
288 // Propagate constraints only through operations in the `src_root` cluster.
289 auto filter = [&](Operation *op) -> bool {
290 auto it = member_ids.find(op);
291 return it != member_ids.end() && FindRoot(it->getSecond()) == src_root;
292 };
293
294 // Start from all operations in the `src_root` cluster.
295 llvm::SmallVector<Operation *> worklist;
296 for (Member &member : members)
297 if (Operation *op = member.source.dyn_cast<Operation *>())
298 if (FindRoot(member.root) == src_root) worklist.emplace_back(op);
299
300 // Collect `dst_root` constraints that are applicable to the values defined in
301 // the `src_root` cluster.
302 ValuesConstraintSet constraints = members[src_root].constraints;
303 members[dst_root].constraints.Walk([&](Value v, ValueConstraint constraint) {
304 Operation *op = v.getDefiningOp();
305 if (op && filter(op)) constraints.Insert(v, constraint);
306 });
307
308 // Update `src_root` constraints only if we can propagate them.
309 if (succeeded(PropagateValuesConstraints(worklist, filter, policies,
310 constraints, /*resolve=*/true))) {
311 members[src_root].constraints = constraints;
312 return success();
313 }
314
315 return failure();
316 }
317
Union(unsigned a,unsigned b,const ClusteringPolicySet & policies)318 LogicalResult ClusteringState::Union(unsigned a, unsigned b,
319 const ClusteringPolicySet &policies) {
320 unsigned a_root = FindRoot(a);
321 unsigned b_root = FindRoot(b);
322
323 // Already members of the same cluster.
324 if (a_root == b_root) return failure();
325
326 // Verify that merging two clusters will not break dominance property.
327 Operation *a_insertion_point = members[a_root].insertion_point;
328 Operation *b_insertion_point = members[b_root].insertion_point;
329 bool a_is_before_b = a_insertion_point->isBeforeInBlock(b_insertion_point);
330
331 // Use clusters position in the block to select merging src and dst.
332 unsigned src_root = a_is_before_b ? a_root : b_root; // merge `src_root` ...
333 unsigned dst_root = a_is_before_b ? b_root : a_root; // ... into `dst_root`
334 Operation *insertion_point =
335 a_is_before_b ? b_insertion_point : a_insertion_point;
336
337 // Print operations in the `root` cluster to debug stream.
338 auto debug_clustered_ops = [&](unsigned root) {
339 for (Member &member : members)
340 if (FindRoot(member.root) == root) {
341 if (auto *op = member.source.dyn_cast<Operation *>()) {
342 llvm::dbgs() << " " << *op << "\n";
343 } else if (auto *arg = member.source.dyn_cast<BlockArgument *>()) {
344 llvm::dbgs() << " " << *arg;
345 }
346 }
347 };
348 (void)debug_clustered_ops;
349
350 LLVM_DEBUG({
351 llvm::dbgs() << "\n\n--- Try to merge cluster:\n";
352 debug_clustered_ops(src_root);
353 llvm::dbgs() << "\n--- With cluster:\n";
354 debug_clustered_ops(dst_root);
355 LLVM_DEBUG(llvm::dbgs() << "\n--- Diagnostics:\n");
356 });
357
358 // Check if merging `src_root` with `dst_root` will not violate SSA dominance
359 // property (all operands before the cluster, all results after the cluster).
360 if (failed(VerifyDominanceProperty(src_root, dst_root, insertion_point)))
361 return failure();
362
363 // Check if `dst_root` constraints can be propagated to the `src_root`
364 // constraints.
365 if (failed(VerifyValueConstraints(src_root, dst_root, policies)))
366 return failure();
367
368 // Set `dst_root` as a new root for `src_root`.
369 members[src_root].root = dst_root;
370 // Update insertion point of the new root.
371 members[dst_root].insertion_point = insertion_point;
372 // Merge all constraints from `src_root` into `dst_root`.
373 members[dst_root].constraints.MergeAll(members[src_root].constraints);
374
375 LLVM_DEBUG(llvm::dbgs() << " Clusters successfully merged\n");
376
377 return success();
378 }
379
380 // Returns constraints on the operands specified by the clustering policy if the
381 // operation can be clustered (constraints could be empty). Otherwise return
382 // empty optional.
CanBeClustered(Operation * op,const ClusteringPolicySet & policies,const std::function<bool (Operation * op)> & filter)383 static Optional<ValuesConstraintSet> CanBeClustered(
384 Operation *op, const ClusteringPolicySet &policies,
385 const std::function<bool(Operation *op)> &filter) {
386 // Check that op has no side effects. This guarantees that we will not
387 // reorder side-effecting ops during cluster formation.
388 if (!MemoryEffectOpInterface::hasNoEffect(op)) return llvm::None;
389
390 // Operation rejected by the custom filter.
391 if (filter && !filter(op)) return llvm::None;
392
393 // Initially we do not have any constraints on the operation results.
394 ValuesConstraintSet result_constraints;
395
396 for (auto &policy : policies.policies()) {
397 ValuesConstraintSet operands_constraints;
398 if (succeeded(policy->MatchAndUpdateConstraints(op, result_constraints,
399 operands_constraints)))
400 return operands_constraints.Resolve();
401 }
402
403 return llvm::None;
404 }
405
406 // Compute initial clustering state based on the clustering polocy.
InitializeClusteringState(Block * block,const ClusteringPolicySet & policies,const std::function<bool (Operation * op)> & filter)407 static ClusteringState InitializeClusteringState(
408 Block *block, const ClusteringPolicySet &policies,
409 const std::function<bool(Operation *op)> &filter) {
410 ClusteringState state;
411
412 // Create members for all block arguments.
413 for (BlockArgument &arg : block->getArguments()) {
414 if (!arg.getUsers().empty())
415 state.members.emplace_back(state.members.size(), &arg, &block->front());
416 }
417
418 int num_bbarg_members = state.members.size();
419 (void)num_bbarg_members;
420
421 // Create members for operations that can be clustered based on the policy.
422 for (Operation &op : block->getOperations()) {
423 if (auto constraints = CanBeClustered(&op, policies, filter))
424 state.members.emplace_back(state.members.size(), &op, &op, *constraints);
425 }
426
427 // Initialize mapping from the member operation (block argument) to the id.
428 for (auto &tuple : llvm::enumerate(state.members)) {
429 state.member_ids.try_emplace(tuple.value().source, tuple.index());
430 }
431
432 LLVM_DEBUG(llvm::dbgs() << "Found "
433 << (state.members.size() - num_bbarg_members)
434 << " clustering candidate operations in the block\n");
435
436 return state;
437 }
438
439 // Users of the `source` that are candidates for clustering.
GetClusteringCandidates(const ClusteringState & state,Source source)440 static llvm::SmallVector<Operation *> GetClusteringCandidates(
441 const ClusteringState &state, Source source) {
442 // Users of operation result must be in the same block and placed on the same
443 // device.
444 if (auto op = source.dyn_cast<Operation *>()) {
445 auto range = llvm::make_filter_range(op->getUsers(), [&](Operation *user) {
446 bool same_block = user->getBlock() == op->getBlock();
447 bool same_device = op->getAttr(kDeviceAttr) == user->getAttr(kDeviceAttr);
448 return same_block && same_device && state.IsMember(user);
449 });
450 return {range.begin(), range.end()};
451 }
452
453 // Users of block argument must be in the same block.
454 if (auto arg = source.dyn_cast<BlockArgument *>()) {
455 auto range = llvm::make_filter_range(arg->getUsers(), [&](Operation *user) {
456 bool same_block = user->getBlock() == arg->getOwner();
457 return same_block && state.IsMember(user);
458 });
459 return {range.begin(), range.end()};
460 }
461
462 llvm_unreachable("Unexpected type in the union.");
463 }
464
465 // Cluster members with their result users. Returns `true` if merged at least a
466 // pair of members into a new cluster.
RunClusteringPass(ClusteringState & state,const ClusteringPolicySet & policies)467 static bool RunClusteringPass(ClusteringState &state,
468 const ClusteringPolicySet &policies) {
469 bool clustered = false;
470
471 for (auto &tuple : llvm::enumerate(state.members)) {
472 size_t member_id = tuple.index();
473 Member &member = tuple.value();
474
475 llvm::SmallVector<Operation *> users =
476 GetClusteringCandidates(state, member.source);
477
478 // Process candidates according to their order in the block to minimize
479 // the number of dominance property violations.
480 llvm::sort(users, [](auto *a, auto *b) { return a->isBeforeInBlock(b); });
481
482 for (Operation *user : users) {
483 auto user_member_id = state.member_ids.lookup(user);
484 if (succeeded(state.Union(member_id, user_member_id, policies)))
485 clustered = true;
486 }
487 }
488
489 return clustered;
490 }
491
FindClustersInTheBlock(Block * block,const ClusteringPolicySet & policies,std::function<bool (Operation * op)> filter)492 llvm::SmallVector<Cluster> FindClustersInTheBlock(
493 Block *block, const ClusteringPolicySet &policies,
494 std::function<bool(Operation *op)> filter) {
495 // It is impossible to build a cluster in the empty block.
496 if (block->empty()) return {};
497
498 ClusteringState state = InitializeClusteringState(block, policies, filter);
499
500 // Run clustering passes until the convergence. Limit the number of iterations
501 // to guard from the infinite loop in presence of bugs.
502 constexpr int max_iterations = 100;
503 for (unsigned i = 0; i < max_iterations; ++i)
504 if (!RunClusteringPass(state, policies)) break;
505
506 // Form clusters found by the union-find algorithm.
507 llvm::DenseMap<unsigned, Cluster> root_clusters;
508
509 for (Member &member : state.members) {
510 unsigned root = state.FindRoot(member.root);
511 Cluster &cluster = root_clusters.FindAndConstruct(root).getSecond();
512
513 // If member is a root of the cluster, copy inferred constraints.
514 if (state.FindRoot(member.root) == member.root)
515 cluster.constraints = std::move(member.constraints);
516
517 // Add operation to the cluster.
518 if (auto op = member.source.dyn_cast<Operation *>())
519 cluster.operations.emplace_back(op);
520 }
521
522 llvm::SmallVector<Cluster> clusters;
523 for (auto &kv : root_clusters) {
524 Cluster &cluster = kv.getSecond();
525 // Skip degenerate clusters formed by a single basic block argument.
526 if (!cluster.operations.empty()) clusters.emplace_back(std::move(cluster));
527 }
528
529 LLVM_DEBUG(llvm::dbgs() << "Found " << clusters.size() << " clusters\n");
530
531 return clusters;
532 }
533
534 // -------------------------------------------------------------------------- //
535 // Create `tf_device.cluster` operation from the discovered ops cluster.
536 // -------------------------------------------------------------------------- //
537
CreateClusterOp(Cluster & cluster,StringAttr policy)538 tf_device::ClusterOp CreateClusterOp(Cluster &cluster, StringAttr policy) {
539 // Find all the values that are used outside of the cluster. These values
540 // will be returned from the created cluster operation.
541 llvm::DenseSet<Operation *> in_cluster;
542 for (Operation *op : cluster.operations) in_cluster.insert(op);
543
544 llvm::SetVector<Value> return_values;
545 llvm::SmallVector<Type> return_types;
546
547 for (Operation *op : cluster.operations)
548 for (OpOperand &use : op->getUses()) {
549 // User is inside the cluster.
550 if (in_cluster.contains(use.getOwner())) continue;
551 // Do not return the same value multiple times.
552 if (return_values.contains(use.get())) continue;
553
554 return_values.insert(use.get());
555 return_types.emplace_back(use.get().getType());
556 }
557
558 // Sort matched operations by their position in the block.
559 llvm::sort(cluster.operations, [](Operation *a, Operation *b) -> bool {
560 return a->isBeforeInBlock(b);
561 });
562
563 // Create tf_device::ClusterOp before the last operation in the block that
564 // is a part of a match set.
565 auto back = cluster.operations.back();
566 auto loc = back->getLoc();
567 OpBuilder builder(back);
568
569 auto cluster_op =
570 builder.create<tf_device::ClusterOp>(loc, return_types, policy);
571
572 // Create block in cluster_op's region and move 'cluster.operations' into
573 // it.
574 auto block = builder.createBlock(&cluster_op.body());
575 auto block_end = block->end();
576 for (auto op : cluster.operations) op->moveBefore(block, block_end);
577
578 // Add 'tf_device::ReturnOp' at the end of the block.
579 builder.setInsertionPointToEnd(block);
580 builder.create<tf_device::ReturnOp>(loc, return_values.getArrayRef());
581
582 // Set device attribute
583 if (auto device = back->getAttr(kDeviceAttr))
584 cluster_op->setAttr(kDeviceAttr, device);
585
586 // Update all users of the operations moved into the cluster region.
587 for (auto tuple : llvm::zip(return_values, cluster_op.getResults())) {
588 Value old_value = std::get<0>(tuple);
589 Value new_value = std::get<1>(tuple);
590 old_value.replaceUsesWithIf(new_value, [&](OpOperand &operand) -> bool {
591 // Do not update users in the same cluster.
592 return operand.getOwner()->getBlock() != block;
593 });
594 }
595
596 return cluster_op;
597 }
598
599 // -------------------------------------------------------------------------- //
600 // Helper functions for value constraints propagations and analysis.
601 // -------------------------------------------------------------------------- //
602
PropagateValuesConstraints(llvm::ArrayRef<Operation * > root,std::function<bool (Operation *)> filter,const ClusteringPolicySet & policies,ValuesConstraintSet & constraints,bool resolve)603 mlir::LogicalResult PropagateValuesConstraints(
604 llvm::ArrayRef<Operation *> root, std::function<bool(Operation *)> filter,
605 const ClusteringPolicySet &policies, ValuesConstraintSet &constraints,
606 bool resolve) {
607 // A set of constraints for operation results.
608 llvm::DenseMap<Operation *, ValuesConstraintSet> op_results_constraints;
609 assert(filter && "filter predicate must be defined");
610
611 // Use initial constraints to initialize op results constraints.
612 for (std::pair<Value, ValueConstraint> pair : constraints) {
613 Value value = pair.first;
614 ValueConstraint constraint = pair.second;
615
616 // Value must be defined by an operation and accepted by the filter.
617 Operation *op = value.getDefiningOp();
618 if (!op || !filter(op)) continue;
619
620 op_results_constraints[op].Insert(value, constraint);
621 }
622
623 // Keep a worklist of operations that need their constraints to be updated.
624 llvm::SetVector<Operation *> worklist;
625 for (Operation *op : root) worklist.insert(op);
626
627 while (!worklist.empty()) {
628 Operation *op = worklist.pop_back_val();
629
630 // Use results constraints to infer operands constraints.
631 const ValuesConstraintSet &results = op_results_constraints[op];
632 ValuesConstraintSet operands;
633
634 // Walk through all policies until we find one that matches the operation.
635 bool updated = false;
636 for (auto &policy : policies.policies()) {
637 auto matched =
638 policy->MatchAndUpdateConstraints(op, results, operands.Reset());
639 if (succeeded(matched)) {
640 updated = true;
641 break;
642 }
643 }
644
645 // Signal a failure if could not propagate non-empty constraints on the
646 // operation results to the operands.
647 if (!updated && !results.Empty()) {
648 op->emitError("failed to propagate results constraints");
649 return failure();
650 }
651
652 // Update results constraints based on inferred operands constraints.
653 operands.Walk([&](Value value, ValueConstraint constraint) {
654 // Resolve constraint based on the static type information.
655 if (resolve && succeeded(IsStaticallyResolved(value, constraint))) return;
656
657 // Update constraint for a value.
658 auto updated = constraints.Insert(value, constraint);
659 if (!updated.second) return;
660
661 // Maybe update constaint on the operation result, but do not follow
662 // operations that are not accepted by the filter predicate.
663 Operation *op = value.getDefiningOp();
664 if (!op || !filter(op)) return;
665
666 // Add updated operation to the worklist.
667 auto inserted = op_results_constraints[op].Insert(value, updated.first);
668 if (inserted.second) worklist.insert(op);
669 });
670 }
671
672 return success();
673 }
674
PropagateValuesConstraints(mlir::Region & region,const ClusteringPolicySet & policies,ValuesConstraintSet & constraints,bool resolve)675 mlir::LogicalResult PropagateValuesConstraints(
676 mlir::Region ®ion, const ClusteringPolicySet &policies,
677 ValuesConstraintSet &constraints, bool resolve) {
678 // Propagate constraints for all operations in the region.
679 llvm::SmallVector<Operation *> worklist;
680 region.walk([&](Operation *op) { worklist.emplace_back(op); });
681
682 // Propagate constraints only through operations inside the `region`.
683 auto filter = [&](Operation *op) -> bool {
684 return region.findAncestorBlockInRegion(*op->getBlock());
685 };
686
687 return PropagateValuesConstraints(worklist, filter, policies, constraints,
688 resolve);
689 }
690
EmitValueConstraintsRemarks(const ValuesConstraintSet & constraints)691 void EmitValueConstraintsRemarks(const ValuesConstraintSet &constraints) {
692 constraints.Walk([](Value value, ValueConstraint constraint) {
693 for (OpOperand &operand : value.getUses())
694 operand.getOwner()->emitRemark(
695 llvm::formatv("operand #{0} constrained to: {1}",
696 operand.getOperandNumber(), constraint));
697 });
698 }
699
EmitInputsConstraintsRemarks(FuncOp func,const ValuesConstraintSet & constraints)700 void EmitInputsConstraintsRemarks(FuncOp func,
701 const ValuesConstraintSet &constraints) {
702 constraints.Walk([&](Value value, ValueConstraint constraint) {
703 if (auto arg = value.dyn_cast<BlockArgument>())
704 if (arg.getOwner() == &func.body().front())
705 func.emitRemark(llvm::formatv("input #{0} constrained to: {1}",
706 arg.getArgNumber(), constraint));
707 });
708 }
709
InferFunctionBodyValuesConstraints(FuncOp func,ValuesConstraintSet & constraints)710 LogicalResult InferFunctionBodyValuesConstraints(
711 FuncOp func, ValuesConstraintSet &constraints) {
712 for (unsigned i = 0; i < func.getNumResults(); ++i) {
713 auto str = func.getResultAttrOfType<StringAttr>(i, "tf.constraint");
714 if (!str) continue;
715
716 ValueConstraint constraint = StringSwitch<ValueConstraint>(str.getValue())
717 .Case("rank", ValueConstraint::kRank)
718 .Case("shape", ValueConstraint::kShape)
719 .Case("value", ValueConstraint::kValue);
720
721 // Propagate constraints through function return operations.
722 for (Block &block : func.body()) {
723 ReturnOp ret = dyn_cast<ReturnOp>(block.back());
724 if (ret) constraints.Insert(ret.getOperand(i), constraint);
725 }
726 }
727
728 return success();
729 }
730
731 } // namespace TFDevice
732 } // namespace mlir
733