1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <queue>
17
18 #include "llvm/ADT/STLExtras.h"
19 #include "mlir/IR/SymbolTable.h" // from @llvm-project
20 #include "mlir/Pass/Pass.h" // from @llvm-project
21 #include "mlir/Support/LogicalResult.h" // from @llvm-project
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
23 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
24 #include "tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h"
25 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
26 #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes_detail.h"
27
28 namespace mlir {
29 namespace TFDevice {
30 namespace {
31
// Error message emitted when pattern application fails to reach a fixed point.
constexpr char kBadDecompositionMessage[] =
    "Resource ops decomposition did not converge";
// Upper bound on local pattern-application sweeps before giving up.
// TODO(prakalps): This can probably be reduced to much smaller number.
constexpr int kMaxIterations = 100;
36
37 // Populates `reachable_functions` with all functions that can be reached from
38 // device cluster ops.
PopulateClusterReachableFunctions(ModuleOp module,SmallPtrSetImpl<Operation * > & reachable_functions)39 void PopulateClusterReachableFunctions(
40 ModuleOp module, SmallPtrSetImpl<Operation*>& reachable_functions) {
41 SymbolTableCollection table;
42 SymbolUserMap symbol_map(table, module);
43
44 // Create map from caller to set of all callee(s).
45 llvm::DenseMap<FuncOp, llvm::DenseSet<FuncOp>> caller_callee_map;
46
47 // Use worklist to populate the set of reachable functions.
48 std::queue<FuncOp> function_worklist;
49
50 // Iterates over all functions within the module to (1) create caller-callee
51 // map, and (2) initialize function worklist with functions referenced from
52 // device cluster ops.
53 for (auto func : module.getOps<FuncOp>()) {
54 for (auto user : symbol_map.getUsers(func)) {
55 // Populate caller-callee map.
56 if (FuncOp caller = user->getParentOfType<FuncOp>())
57 caller_callee_map[caller].insert(func);
58 // Initialize function worklist with functions refrerenced in device
59 // cluster.
60 if (auto cluster = user->getParentOfType<tf_device::ClusterOp>()) {
61 if (reachable_functions.insert(func).second)
62 function_worklist.push(func);
63 }
64 }
65 }
66
67 // Uses worklist algorithm to insert all functions reachable from device
68 // cluster ops.
69 while (!function_worklist.empty()) {
70 FuncOp caller = function_worklist.front();
71 function_worklist.pop();
72 for (auto callee : caller_callee_map[caller]) {
73 if (reachable_functions.insert(callee).second)
74 function_worklist.push(callee);
75 }
76 }
77 }
78
79 // Applies patterns locally on ops within `cluster` until convergence or
80 // `max_iterations` are reached. Returns failure if resource ops decomposition
81 // does not converge after `max_iterations`.
82 // TODO(prakalps): This can be useful to a lot of other passes in bridge.
83 // Extract out as a separate utility.
ApplyPatternsLocallyUntilConverged(Operation * op_with_regions,FrozenRewritePatternSet & patterns,int max_iterations)84 LogicalResult ApplyPatternsLocallyUntilConverged(
85 Operation* op_with_regions, FrozenRewritePatternSet& patterns,
86 int max_iterations) {
87 bool changed = true;
88 int iteration = 0;
89 while (changed && (iteration++ < max_iterations)) {
90 changed = false;
91 auto walk_result =
92 op_with_regions->walk([&patterns, &changed](Operation* operation) {
93 bool op_changed;
94 if (failed(applyOpPatternsAndFold(operation, patterns, &op_changed)))
95 return WalkResult::interrupt();
96 changed |= op_changed;
97 return WalkResult::advance();
98 });
99 if (walk_result.wasInterrupted()) return failure();
100 }
101 // Return failure is `op_with_region` was modified changed in last iteration.
102 return success(!changed);
103 }
104
105 // Applies patterns in only device clusters and functions reachable from such
106 // clusters. Returns failure if it fails to converge in `max_iterations`.
107 // TODO(prakalps): This can be useful to a lot of other passes in bridge.
108 // Extract out as a separate utility.
ApplyPatternsInClusterAndReachableFunctions(ModuleOp module,FrozenRewritePatternSet & patterns,int max_iterations)109 LogicalResult ApplyPatternsInClusterAndReachableFunctions(
110 ModuleOp module, FrozenRewritePatternSet& patterns, int max_iterations) {
111 SmallPtrSet<Operation*, 16> reachable_functions;
112 PopulateClusterReachableFunctions(module, reachable_functions);
113
114 // Apply patterns to reachable functions.
115 for (Operation* op : reachable_functions) {
116 assert(isa<FuncOp>(op));
117 if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
118 return op->emitError() << kBadDecompositionMessage;
119 }
120 }
121
122 // Apply patterns to device cluster ops.
123 // Note: This module search for cluster ops is a bit wasteful as we could have
124 // collected many cluster ops when we were populating reachable functions. But
125 // we would still need to do a walk to find all clusters that do not
126 // reference any function.
127 for (FuncOp func : module.getOps<FuncOp>()) {
128 // If we have already applied patterns to a function then we can skip
129 // applying patterns to any device clusters it contains.
130 if (reachable_functions.contains(func)) continue;
131
132 auto walk_result = func.walk([&](tf_device::ClusterOp cluster) {
133 // Cluster ops are not isolated from above so we cannot use
134 // `applyPatternsAndFoldGreedily` utility. Instead we apply patterns
135 // locally on each op within the cluster until convergence.
136 if (failed(ApplyPatternsLocallyUntilConverged(cluster, patterns,
137 max_iterations))) {
138 cluster.emitError() << kBadDecompositionMessage;
139 return WalkResult::interrupt();
140 }
141 return WalkResult::advance();
142 });
143 if (walk_result.wasInterrupted()) return failure();
144 }
145
146 return success();
147 }
148
149 struct DecomposeResourceOpsPass
150 : public DecomposeResourceOpsPassBase<DecomposeResourceOpsPass> {
runOnFunctionmlir::TFDevice::__anonf45b070f0111::DecomposeResourceOpsPass151 void runOnFunction() override {
152 // Add lowering patterns to the list.
153 OwningRewritePatternList patterns(&getContext());
154 TF::PopulateDecomposeResourceOpsPatterns(&getContext(), &patterns);
155
156 if (failed(
157 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
158 getFunction().emitError() << kBadDecompositionMessage;
159 signalPassFailure();
160 }
161 }
162 };
163
164 struct DecomposeResourceOpsInClusterPass
165 : public DecomposeResourceOpsInClusterPassBase<
166 DecomposeResourceOpsInClusterPass> {
runOnOperationmlir::TFDevice::__anonf45b070f0111::DecomposeResourceOpsInClusterPass167 void runOnOperation() override {
168 // Add lowering patterns to the list.
169 OwningRewritePatternList patterns(&getContext());
170 TF::PopulateDecomposeResourceOpsPatterns(&getContext(), &patterns);
171 FrozenRewritePatternSet frozen_patterns(std::move(patterns));
172
173 if (failed(ApplyPatternsInClusterAndReachableFunctions(
174 getOperation(), frozen_patterns, kMaxIterations)))
175 signalPassFailure();
176 }
177 };
178
179 } // namespace
180
// Returns a function-level pass that decomposes composite resource ops into
// primitive ones everywhere in the function.
std::unique_ptr<OperationPass<FuncOp>> CreateDecomposeResourceOpsPass() {
  return std::make_unique<DecomposeResourceOpsPass>();
}
184
// Returns a module-level pass that decomposes composite resource ops, but only
// within device cluster ops and functions reachable from such clusters.
std::unique_ptr<OperationPass<ModuleOp>>
CreateDecomposeResourceOpsInClusterPass() {
  return std::make_unique<DecomposeResourceOpsInClusterPass>();
}
189
190 } // namespace TFDevice
191 } // namespace mlir
192
193