• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <queue>
17 
18 #include "llvm/ADT/STLExtras.h"
19 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
20 #include "mlir/Pass/Pass.h"  // from @llvm-project
21 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
23 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
24 #include "tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h"
25 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
26 #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes_detail.h"
27 
28 namespace mlir {
29 namespace TFDevice {
30 namespace {
31 
// Diagnostic attached to the offending op whenever resource-op decomposition
// fails to reach a fixed point.
constexpr char kBadDecompositionMessage[]{
    "Resource ops decomposition did not converge"};

// Maximum number of rewrite sweeps performed over a single device cluster
// before decomposition is reported as non-converging.
// TODO(prakalps): This can probably be reduced to much smaller number.
constexpr int kMaxIterations{100};
36 
37 // Populates `reachable_functions` with all functions that can be reached from
38 // device cluster ops.
PopulateClusterReachableFunctions(ModuleOp module,SmallPtrSetImpl<Operation * > & reachable_functions)39 void PopulateClusterReachableFunctions(
40     ModuleOp module, SmallPtrSetImpl<Operation*>& reachable_functions) {
41   SymbolTableCollection table;
42   SymbolUserMap symbol_map(table, module);
43 
44   // Create map from caller to set of all callee(s).
45   llvm::DenseMap<FuncOp, llvm::DenseSet<FuncOp>> caller_callee_map;
46 
47   // Use worklist to populate the set of reachable functions.
48   std::queue<FuncOp> function_worklist;
49 
50   // Iterates over all functions within the module to (1) create caller-callee
51   // map, and (2) initialize function worklist with functions referenced from
52   // device cluster ops.
53   for (auto func : module.getOps<FuncOp>()) {
54     for (auto user : symbol_map.getUsers(func)) {
55       // Populate caller-callee map.
56       if (FuncOp caller = user->getParentOfType<FuncOp>())
57         caller_callee_map[caller].insert(func);
58       // Initialize function worklist with functions refrerenced in device
59       // cluster.
60       if (auto cluster = user->getParentOfType<tf_device::ClusterOp>()) {
61         if (reachable_functions.insert(func).second)
62           function_worklist.push(func);
63       }
64     }
65   }
66 
67   // Uses worklist algorithm to insert all functions reachable from device
68   // cluster ops.
69   while (!function_worklist.empty()) {
70     FuncOp caller = function_worklist.front();
71     function_worklist.pop();
72     for (auto callee : caller_callee_map[caller]) {
73       if (reachable_functions.insert(callee).second)
74         function_worklist.push(callee);
75     }
76   }
77 }
78 
79 // Applies patterns locally on ops within `cluster` until convergence or
80 // `max_iterations` are reached. Returns failure if resource ops decomposition
81 // does not converge after `max_iterations`.
82 // TODO(prakalps): This can be useful to a lot of other passes in bridge.
83 // Extract out as a separate utility.
ApplyPatternsLocallyUntilConverged(Operation * op_with_regions,FrozenRewritePatternSet & patterns,int max_iterations)84 LogicalResult ApplyPatternsLocallyUntilConverged(
85     Operation* op_with_regions, FrozenRewritePatternSet& patterns,
86     int max_iterations) {
87   bool changed = true;
88   int iteration = 0;
89   while (changed && (iteration++ < max_iterations)) {
90     changed = false;
91     auto walk_result =
92         op_with_regions->walk([&patterns, &changed](Operation* operation) {
93           bool op_changed;
94           if (failed(applyOpPatternsAndFold(operation, patterns, &op_changed)))
95             return WalkResult::interrupt();
96           changed |= op_changed;
97           return WalkResult::advance();
98         });
99     if (walk_result.wasInterrupted()) return failure();
100   }
101   // Return failure is `op_with_region` was modified changed in last iteration.
102   return success(!changed);
103 }
104 
105 // Applies patterns in only device clusters and functions reachable from such
106 // clusters. Returns failure if it fails to converge in `max_iterations`.
107 // TODO(prakalps): This can be useful to a lot of other passes in bridge.
108 // Extract out as a separate utility.
ApplyPatternsInClusterAndReachableFunctions(ModuleOp module,FrozenRewritePatternSet & patterns,int max_iterations)109 LogicalResult ApplyPatternsInClusterAndReachableFunctions(
110     ModuleOp module, FrozenRewritePatternSet& patterns, int max_iterations) {
111   SmallPtrSet<Operation*, 16> reachable_functions;
112   PopulateClusterReachableFunctions(module, reachable_functions);
113 
114   // Apply patterns to reachable functions.
115   for (Operation* op : reachable_functions) {
116     assert(isa<FuncOp>(op));
117     if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
118       return op->emitError() << kBadDecompositionMessage;
119     }
120   }
121 
122   // Apply patterns to device cluster ops.
123   // Note: This module search for cluster ops is a bit wasteful as we could have
124   // collected many cluster ops when we were populating reachable functions. But
125   // we would still need to do a walk to find all clusters that do not
126   // reference any function.
127   for (FuncOp func : module.getOps<FuncOp>()) {
128     // If we have already applied patterns to a function then we can skip
129     // applying patterns to any device clusters it contains.
130     if (reachable_functions.contains(func)) continue;
131 
132     auto walk_result = func.walk([&](tf_device::ClusterOp cluster) {
133       // Cluster ops are not isolated from above so we cannot use
134       // `applyPatternsAndFoldGreedily` utility. Instead we apply patterns
135       // locally on each op within the cluster until convergence.
136       if (failed(ApplyPatternsLocallyUntilConverged(cluster, patterns,
137                                                     max_iterations))) {
138         cluster.emitError() << kBadDecompositionMessage;
139         return WalkResult::interrupt();
140       }
141       return WalkResult::advance();
142     });
143     if (walk_result.wasInterrupted()) return failure();
144   }
145 
146   return success();
147 }
148 
149 struct DecomposeResourceOpsPass
150     : public DecomposeResourceOpsPassBase<DecomposeResourceOpsPass> {
runOnFunctionmlir::TFDevice::__anonf45b070f0111::DecomposeResourceOpsPass151   void runOnFunction() override {
152     // Add lowering patterns to the list.
153     OwningRewritePatternList patterns(&getContext());
154     TF::PopulateDecomposeResourceOpsPatterns(&getContext(), &patterns);
155 
156     if (failed(
157             applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
158       getFunction().emitError() << kBadDecompositionMessage;
159       signalPassFailure();
160     }
161   }
162 };
163 
164 struct DecomposeResourceOpsInClusterPass
165     : public DecomposeResourceOpsInClusterPassBase<
166           DecomposeResourceOpsInClusterPass> {
runOnOperationmlir::TFDevice::__anonf45b070f0111::DecomposeResourceOpsInClusterPass167   void runOnOperation() override {
168     // Add lowering patterns to the list.
169     OwningRewritePatternList patterns(&getContext());
170     TF::PopulateDecomposeResourceOpsPatterns(&getContext(), &patterns);
171     FrozenRewritePatternSet frozen_patterns(std::move(patterns));
172 
173     if (failed(ApplyPatternsInClusterAndReachableFunctions(
174             getOperation(), frozen_patterns, kMaxIterations)))
175       signalPassFailure();
176   }
177 };
178 
179 }  // namespace
180 
CreateDecomposeResourceOpsPass()181 std::unique_ptr<OperationPass<FuncOp>> CreateDecomposeResourceOpsPass() {
182   return std::make_unique<DecomposeResourceOpsPass>();
183 }
184 
185 std::unique_ptr<OperationPass<ModuleOp>>
CreateDecomposeResourceOpsInClusterPass()186 CreateDecomposeResourceOpsInClusterPass() {
187   return std::make_unique<DecomposeResourceOpsInClusterPass>();
188 }
189 
190 }  // namespace TFDevice
191 }  // namespace mlir
192 
193