1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SetVector.h"
19 #include "llvm/Support/CommandLine.h"
20 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
21 #include "mlir/IR/Builders.h" // from @llvm-project
22 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
23 #include "mlir/IR/Identifier.h" // from @llvm-project
24 #include "mlir/IR/Location.h" // from @llvm-project
25 #include "mlir/IR/MLIRContext.h" // from @llvm-project
26 #include "mlir/IR/Matchers.h" // from @llvm-project
27 #include "mlir/IR/SymbolTable.h" // from @llvm-project
28 #include "mlir/Pass/Pass.h" // from @llvm-project
29 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
30 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
31 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
32 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
33 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
34
35 namespace mlir {
36 namespace TFL {
37 namespace {
38
39 // This pass outlines the cond/body region of the TFL WhileOp into functions and
40 // replaces the regions with calls to these outlined functions.
41 class WhileOutlinePass
42 : public mlir::PassWrapper<WhileOutlinePass, OperationPass<ModuleOp>> {
43 public:
WhileOutlinePass()44 explicit WhileOutlinePass() {}
45
46 private:
47 void runOnOperation() override;
48
49 // Outlines the regions of the WhileOp's cond and body and insert function
50 // calls instead,
51 void OutlineWhile(WhileOp while_op);
52
53 // Get unique name by using the loc to name mapping.
54 std::string GetName(Operation* op, StringRef suffix);
55
56 tensorflow::OpOrArgLocNameMapper mapper_;
57 };
58
GetName(Operation * op,StringRef suffix)59 std::string WhileOutlinePass::GetName(Operation* op, StringRef suffix) {
60 return (mapper_.GetUniqueName(op) + suffix).str();
61 }
62
63 // Returns whether the WhileOp is already outlined (e.g., only consists of calls
64 // to functions).
IsAlreadyOutlined(WhileOp while_op)65 bool IsAlreadyOutlined(WhileOp while_op) {
66 auto just_call = [](Region& region) {
67 auto it = region.front().begin();
68 if (!isa<CallOp>(*it)) return false;
69 ++it;
70 if (!isa<YieldOp>(*it)) return false;
71 return true;
72 };
73 return just_call(while_op.body()) && just_call(while_op.cond());
74 }
75
IsCompatibleTypeWithTFLCastOp(Type type)76 bool IsCompatibleTypeWithTFLCastOp(Type type) {
77 auto elemType = getElementTypeOrSelf(type);
78 // F32 and BF16 types are allowed.
79 if (elemType.isBF16() || elemType.isF32()) return true;
80
81 // I1, I16, I32, I64 types are allowed.
82 if (elemType.isInteger(1) || elemType.isInteger(16) ||
83 elemType.isInteger(32) || elemType.isInteger(64))
84 return true;
85
86 // Complex<F<32>> is allowed.
87 if (elemType.isa<ComplexType>() &&
88 elemType.cast<ComplexType>().getElementType().isF32())
89 return true;
90
91 // QUINT8 and UI8 are allowed.
92 if (elemType.isa<TF::Quint8Type>() ||
93 (elemType.isInteger(8) && elemType.cast<IntegerType>().isUnsigned()))
94 return true;
95
96 return false;
97 }
98
OutlineWhile(WhileOp while_op)99 void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
100 OpBuilder builder(&getContext());
101 // Collect external values used.
102 llvm::SetVector<Value> extern_values;
103
104 // The basic block arguments correspond to values that are loop carried, while
105 // all those post are loop independent. Initialize extern_values with while_op
106 // not loop carried operands.
107 auto num_loop_carried = while_op.cond().getNumArguments();
108 auto not_carried_operands =
109 while_op.getOperands().drop_front(num_loop_carried);
110 extern_values.insert(not_carried_operands.begin(),
111 not_carried_operands.end());
112 auto old_extern_values_size = extern_values.size();
113
114 llvm::SmallVector<Region*, 2> regions{&while_op.cond(), &while_op.body()};
115 for (auto it : llvm::enumerate(regions)) {
116 llvm::SetVector<Value> region_extern_values;
117 getUsedValuesDefinedAbove(*it.value(), region_extern_values);
118
119 // Sink down constants into the functions.
120 for (auto extern_value : region_extern_values) {
121 if (!matchPattern(extern_value, m_Constant())) {
122 extern_values.insert(extern_value);
123 continue;
124 }
125 // Add constant at start of region.
126 auto const_builder =
127 OpBuilder(&it.value()->front(), it.value()->front().begin());
128 auto const_value = const_builder.clone(*extern_value.getDefiningOp());
129 replaceAllUsesInRegionWith(extern_value, const_value->getResult(0),
130 *it.value());
131 }
132 }
133
134 bool has_extra_extern_values = old_extern_values_size != extern_values.size();
135 // If an extern value is already an operand post the loop carried operands,
136 // then it need not be passed in again.
137 // Compute all the extra operands that have to be added to the while.
138 llvm::SetVector<Value> extra_operands;
139 if (has_extra_extern_values) {
140 auto new_extern =
141 extern_values.getArrayRef().drop_front(old_extern_values_size);
142 extra_operands.insert(new_extern.begin(), new_extern.end());
143 }
144
145 // Skip if already just calls.
146 if (extra_operands.empty() && IsAlreadyOutlined(while_op)) return;
147
148 // Collect new types.
149 SmallVector<Type, 4> types;
150 types.reserve(extra_operands.size() + while_op.getNumOperands());
151 for (Type type : while_op.cond().getArgumentTypes()) types.push_back(type);
152 for (Value operand : extern_values) types.push_back(operand.getType());
153
154 // Create outline function from region. Optional pass extra arguments through
155 // to yield.
156 SymbolTable symbol_table(getOperation());
157 auto create_outline_func = [&](StringRef name, Region& region,
158 bool passthru_extra_args) {
159 FunctionType type;
160 if (passthru_extra_args) {
161 type = FunctionType::get(&getContext(), types, types);
162 } else {
163 SmallVector<Type, 4> result_types;
164 auto operands = region.front().getTerminator()->getOperandTypes();
165 result_types.append(operands.begin(), operands.end());
166 type = FunctionType::get(&getContext(), types, result_types);
167 }
168
169 auto outlined_func = builder.create<FuncOp>(while_op.getLoc(), name, type);
170 outlined_func.getBody().takeBody(region);
171 Region& func_region = outlined_func.getBody();
172
173 // Replace all external uses with block args and update uses.
174 llvm::SmallVector<Value, 4> new_args;
175 new_args.reserve(extern_values.size());
176 Block& block = func_region.front();
177 for (Value value : extern_values) {
178 auto arg = block.addArgument(value.getType());
179 replaceAllUsesInRegionWith(value, arg, func_region);
180 new_args.push_back(arg);
181 }
182
183 // Replace yield op with return.
184 Operation* yield_op = outlined_func.getBody().front().getTerminator();
185 OpBuilder b(yield_op);
186 llvm::SmallVector<Value, 4> args;
187 auto loop_carried_yield_operands =
188 yield_op->getOperands().take_front(num_loop_carried);
189 args.reserve(loop_carried_yield_operands.size() + new_args.size());
190 if (passthru_extra_args) {
191 // Add operands of yield to the return, inserting casts if needed.
192 for (auto it : llvm::zip_first(loop_carried_yield_operands, types)) {
193 auto value = std::get<0>(it);
194 auto type = std::get<1>(it);
195 if (value.getType() == type) {
196 args.push_back(value);
197 } else {
198 if (IsCompatibleTypeWithTFLCastOp(value.getType()) &&
199 IsCompatibleTypeWithTFLCastOp(type)) {
200 auto cast = b.create<CastOp>(yield_op->getLoc(), type, value);
201 args.push_back(cast);
202 } else {
203 auto cast = b.create<TF::CastOp>(yield_op->getLoc(), type, value);
204 args.push_back(cast);
205 }
206 }
207 }
208 args.append(new_args.begin(), new_args.end());
209 } else {
210 args.append(yield_op->operand_begin(), yield_op->operand_end());
211 }
212 b.create<ReturnOp>(yield_op->getLoc(), args);
213 yield_op->erase();
214 symbol_table.insert(outlined_func);
215 outlined_func.setPrivate();
216 return outlined_func;
217 };
218
219 // Replace region with call to outline function.
220 auto replace_with_call = [&](StringRef name, Region& region,
221 bool passthru_extra_args) {
222 auto func = create_outline_func(name, region, passthru_extra_args);
223 OpBuilder b(region);
224 // The body of the region is empty/has been outlined into the function.
225 auto block = b.createBlock(®ion);
226 SmallVector<Value, 4> new_operands;
227 new_operands.reserve(types.size());
228 for (Type t : llvm::makeArrayRef(types).drop_back(extern_values.size()))
229 new_operands.push_back(block->addArgument(t));
230 for (Value v : extern_values) new_operands.push_back(v);
231 auto call = b.create<CallOp>(while_op.getLoc(), func, new_operands);
232 b.create<YieldOp>(while_op.getLoc(), call.getResults());
233 };
234
235 replace_with_call(GetName(while_op.getOperation(), "_cond"), while_op.cond(),
236 false);
237 replace_with_call(GetName(while_op.getOperation(), "_body"), while_op.body(),
238 true);
239
240 // If there are extern values used then the result type of the while has to
241 // change, so replace with new while op.
242 if (extra_operands.empty()) return;
243
244 const int operands_size = while_op.getNumOperands() + extra_operands.size();
245 SmallVector<Value, 4> operands;
246 operands.reserve(operands_size);
247 operands.append(while_op.getOperands().begin(), while_op.getOperands().end());
248 operands.append(extra_operands.begin(), extra_operands.end());
249 SmallVector<Type, 4> new_types;
250 new_types.reserve(operands_size);
251 new_types.append(while_op.getResultTypes().begin(),
252 while_op.getResultTypes().end());
253 for (auto extra_operand : extra_operands)
254 new_types.push_back(extra_operand.getType());
255
256 auto new_while_op = OpBuilder(while_op).create<WhileOp>(
257 while_op.getLoc(), new_types, operands, while_op.getAttrs());
258 new_while_op.cond().takeBody(while_op.cond());
259 new_while_op.body().takeBody(while_op.body());
260 while_op.replaceAllUsesWith(
261 new_while_op.getResults().take_front(while_op.getNumResults()));
262 while_op.erase();
263 }
264
runOnOperation()265 void WhileOutlinePass::runOnOperation() {
266 getOperation().walk(
267 [&](mlir::TFL::WhileOp while_op) { OutlineWhile(while_op); });
268 }
269 } // namespace
270
271 // Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
CreateWhileOutlinePass()272 std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass() {
273 return std::make_unique<WhileOutlinePass>();
274 }
275
// Static registration object for WhileOutlinePass, associating it with the
// argument string "tfl-while-loop-outline" and the description below.
static PassRegistration<WhileOutlinePass> pass(
    "tfl-while-loop-outline", "Hoist while op regions into functions");
278
279 } // namespace TFL
280 } // namespace mlir
281