1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SetVector.h"
19 #include "llvm/Support/CommandLine.h"
20 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
21 #include "mlir/IR/Builders.h" // from @llvm-project
22 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
23 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
24 #include "mlir/IR/Identifier.h" // from @llvm-project
25 #include "mlir/IR/Location.h" // from @llvm-project
26 #include "mlir/IR/MLIRContext.h" // from @llvm-project
27 #include "mlir/IR/Matchers.h" // from @llvm-project
28 #include "mlir/IR/SymbolTable.h" // from @llvm-project
29 #include "mlir/Pass/Pass.h" // from @llvm-project
30 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
31 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
32 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
33 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
35
36 namespace mlir {
37 namespace TFL {
38 namespace {
39
40 // This pass outlines the cond/body region of the TFL WhileOp into functions and
41 // replaces the regions with calls to these outlined functions.
42 class WhileOutlinePass
43 : public mlir::PassWrapper<WhileOutlinePass, OperationPass<ModuleOp>> {
getDependentDialects(DialectRegistry & registry) const44 void getDependentDialects(DialectRegistry& registry) const override {
45 registry.insert<TF::TensorFlowDialect>();
46 }
47
48 public:
WhileOutlinePass()49 explicit WhileOutlinePass() {}
50
getArgument() const51 StringRef getArgument() const final {
52 // This is the argument used to refer to the pass in
53 // the textual format (on the commandline for example).
54 return "tfl-while-loop-outline";
55 }
getDescription() const56 StringRef getDescription() const final {
57 // This is a brief description of the pass.
58 return "Hoist while op regions into functions";
59 }
60
61 private:
62 void runOnOperation() override;
63
64 // Outlines the regions of the WhileOp's cond and body and insert function
65 // calls instead,
66 void OutlineWhile(WhileOp while_op);
67
68 // Get unique name by using the loc to name mapping.
69 std::string GetName(Operation* op, StringRef suffix);
70
71 tensorflow::OpOrArgLocNameMapper mapper_;
72 };
73
GetName(Operation * op,StringRef suffix)74 std::string WhileOutlinePass::GetName(Operation* op, StringRef suffix) {
75 return (mapper_.GetUniqueName(op) + suffix).str();
76 }
77
78 // Returns whether the WhileOp is already outlined (e.g., only consists of calls
79 // to functions).
IsAlreadyOutlined(WhileOp while_op)80 bool IsAlreadyOutlined(WhileOp while_op) {
81 auto just_call = [](Region& region) {
82 auto it = region.front().begin();
83 if (!isa<CallOp>(*it)) return false;
84 ++it;
85 if (!isa<YieldOp>(*it)) return false;
86 return true;
87 };
88 return just_call(while_op.body()) && just_call(while_op.cond());
89 }
90
IsCompatibleTypeWithTFLCastOp(Type type)91 bool IsCompatibleTypeWithTFLCastOp(Type type) {
92 auto elemType = getElementTypeOrSelf(type);
93 // F32 and BF16 types are allowed.
94 if (elemType.isBF16() || elemType.isF32()) return true;
95
96 // I1, I16, I32, I64 types are allowed.
97 if (elemType.isInteger(1) || elemType.isInteger(16) ||
98 elemType.isInteger(32) || elemType.isInteger(64))
99 return true;
100
101 // Complex<F<32>> is allowed.
102 if (elemType.isa<ComplexType>() &&
103 elemType.cast<ComplexType>().getElementType().isF32())
104 return true;
105
106 // QUINT8 and UI8 are allowed.
107 if (elemType.isa<TF::Quint8Type>() ||
108 (elemType.isInteger(8) && elemType.cast<IntegerType>().isUnsigned()))
109 return true;
110
111 return false;
112 }
113
CreateOutlineFunc(StringRef name,Region & region,bool passthru_extra_args,int num_loop_carried,const llvm::SetVector<Value> & extern_values,const SmallVectorImpl<Type> & types,Location loc)114 FuncOp CreateOutlineFunc(StringRef name, Region& region,
115 bool passthru_extra_args, int num_loop_carried,
116 const llvm::SetVector<Value>& extern_values,
117 const SmallVectorImpl<Type>& types, Location loc) {
118 MLIRContext* context = loc.getContext();
119 OpBuilder builder(context);
120 FunctionType type;
121 if (passthru_extra_args) {
122 type = FunctionType::get(context, types, types);
123 } else {
124 SmallVector<Type, 4> result_types;
125 auto operands = region.front().getTerminator()->getOperandTypes();
126 result_types.append(operands.begin(), operands.end());
127 type = FunctionType::get(context, types, result_types);
128 }
129
130 auto outlined_func = builder.create<FuncOp>(loc, name, type);
131 outlined_func.getBody().takeBody(region);
132 Region& func_region = outlined_func.getBody();
133
134 // Replace all external uses with block args and update uses.
135 llvm::SmallVector<Value, 4> new_args;
136 new_args.reserve(extern_values.size());
137 Block& block = func_region.front();
138 for (Value value : extern_values) {
139 auto arg = block.addArgument(value.getType());
140 replaceAllUsesInRegionWith(value, arg, func_region);
141 new_args.push_back(arg);
142 }
143
144 // Replace yield op with return.
145 Operation* yield_op = outlined_func.getBody().front().getTerminator();
146 OpBuilder b(yield_op);
147 llvm::SmallVector<Value, 4> args;
148 auto loop_carried_yield_operands =
149 yield_op->getOperands().take_front(num_loop_carried);
150 args.reserve(loop_carried_yield_operands.size() + new_args.size());
151 if (passthru_extra_args) {
152 // Add operands of yield to the return, inserting casts if needed.
153 for (auto it : llvm::zip_first(loop_carried_yield_operands, types)) {
154 auto value = std::get<0>(it);
155 auto type = std::get<1>(it);
156 if (value.getType() == type) {
157 args.push_back(value);
158 } else {
159 if (IsCompatibleTypeWithTFLCastOp(value.getType()) &&
160 IsCompatibleTypeWithTFLCastOp(type)) {
161 auto cast = b.create<CastOp>(yield_op->getLoc(), type, value);
162 args.push_back(cast);
163 } else {
164 auto cast = b.create<TF::CastOp>(yield_op->getLoc(), type, value);
165 args.push_back(cast);
166 }
167 }
168 }
169 args.append(new_args.begin(), new_args.end());
170 } else {
171 args.append(yield_op->operand_begin(), yield_op->operand_end());
172 }
173 b.create<ReturnOp>(yield_op->getLoc(), args);
174 yield_op->erase();
175 SymbolTable(region.getParentOfType<ModuleOp>()).insert(outlined_func);
176 outlined_func.setPrivate();
177 return outlined_func;
178 }
179
180 // Replace region with call to outline function.
ReplaceRegionWithCall(StringRef name,Region & region,bool passthru_extra_args,int num_loop_carried,const llvm::SetVector<Value> & extern_values,const SmallVectorImpl<Type> & types,Location loc)181 void ReplaceRegionWithCall(StringRef name, Region& region,
182 bool passthru_extra_args, int num_loop_carried,
183 const llvm::SetVector<Value>& extern_values,
184 const SmallVectorImpl<Type>& types, Location loc) {
185 auto func = CreateOutlineFunc(name, region, passthru_extra_args,
186 num_loop_carried, extern_values, types, loc);
187 OpBuilder b(region);
188 // The body of the region is empty/has been outlined into the function.
189 auto block = b.createBlock(®ion);
190 SmallVector<Value, 4> new_operands;
191 new_operands.reserve(types.size());
192 for (Type t : llvm::makeArrayRef(types).drop_back(extern_values.size()))
193 new_operands.push_back(block->addArgument(t));
194 for (Value v : extern_values) new_operands.push_back(v);
195 auto call = b.create<CallOp>(loc, func, new_operands);
196 b.create<YieldOp>(loc, call.getResults());
197 }
198
OutlineWhile(WhileOp while_op)199 void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
200 OpBuilder builder(&getContext());
201 // Collect external values used.
202 llvm::SetVector<Value> extern_values;
203
204 // The basic block arguments correspond to values that are loop carried, while
205 // all those post are loop independent. Initialize extern_values with while_op
206 // not loop carried operands.
207 auto num_loop_carried = while_op.cond().getNumArguments();
208 auto not_carried_operands =
209 while_op.getOperands().drop_front(num_loop_carried);
210 extern_values.insert(not_carried_operands.begin(),
211 not_carried_operands.end());
212 auto old_extern_values_size = extern_values.size();
213
214 llvm::SmallVector<Region*, 2> regions{&while_op.cond(), &while_op.body()};
215 for (auto it : llvm::enumerate(regions)) {
216 llvm::SetVector<Value> region_extern_values;
217 getUsedValuesDefinedAbove(*it.value(), region_extern_values);
218
219 // Sink down constants into the functions.
220 for (auto extern_value : region_extern_values) {
221 if (!matchPattern(extern_value, m_Constant())) {
222 extern_values.insert(extern_value);
223 continue;
224 }
225 // Add constant at start of region.
226 auto const_builder =
227 OpBuilder(&it.value()->front(), it.value()->front().begin());
228 auto const_value = const_builder.clone(*extern_value.getDefiningOp());
229 replaceAllUsesInRegionWith(extern_value, const_value->getResult(0),
230 *it.value());
231 }
232 }
233
234 bool has_extra_extern_values = old_extern_values_size != extern_values.size();
235 // If an extern value is already an operand post the loop carried operands,
236 // then it need not be passed in again.
237 // Compute all the extra operands that have to be added to the while.
238 llvm::SetVector<Value> extra_operands;
239 if (has_extra_extern_values) {
240 auto new_extern =
241 extern_values.getArrayRef().drop_front(old_extern_values_size);
242 extra_operands.insert(new_extern.begin(), new_extern.end());
243 }
244
245 // Skip if already just calls.
246 if (extra_operands.empty() && IsAlreadyOutlined(while_op)) return;
247
248 // Collect new types.
249 SmallVector<Type, 4> types;
250 types.reserve(extra_operands.size() + while_op.getNumOperands());
251 for (Type type : while_op.cond().getArgumentTypes()) types.push_back(type);
252 for (Value operand : extern_values) types.push_back(operand.getType());
253
254 // Create outline function from region. Optional pass extra arguments through
255 // to yield.
256 ReplaceRegionWithCall(GetName(while_op.getOperation(), "_cond"),
257 while_op.cond(), false, num_loop_carried, extern_values,
258 types, while_op.getLoc());
259 ReplaceRegionWithCall(GetName(while_op.getOperation(), "_body"),
260 while_op.body(), true, num_loop_carried, extern_values,
261 types, while_op.getLoc());
262
263 // If there are extern values used then the result type of the while has to
264 // change, so replace with new while op.
265 if (extra_operands.empty()) return;
266
267 const int operands_size = while_op.getNumOperands() + extra_operands.size();
268 SmallVector<Value, 4> operands;
269 operands.reserve(operands_size);
270 operands.append(while_op.getOperands().begin(), while_op.getOperands().end());
271 operands.append(extra_operands.begin(), extra_operands.end());
272 SmallVector<Type, 4> new_types;
273 new_types.reserve(operands_size);
274 new_types.append(while_op.getResultTypes().begin(),
275 while_op.getResultTypes().end());
276 for (auto extra_operand : extra_operands)
277 new_types.push_back(extra_operand.getType());
278
279 auto new_while_op = OpBuilder(while_op).create<WhileOp>(
280 while_op.getLoc(), new_types, operands, while_op->getAttrs());
281 new_while_op.cond().takeBody(while_op.cond());
282 new_while_op.body().takeBody(while_op.body());
283 while_op.replaceAllUsesWith(
284 new_while_op.getResults().take_front(while_op.getNumResults()));
285 while_op.erase();
286 }
287
runOnOperation()288 void WhileOutlinePass::runOnOperation() {
289 getOperation().walk(
290 [&](mlir::TFL::WhileOp while_op) { OutlineWhile(while_op); });
291 }
292 } // namespace
293
294 // Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
CreateWhileOutlinePass()295 std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass() {
296 return std::make_unique<WhileOutlinePass>();
297 }
298
299 static PassRegistration<WhileOutlinePass> pass;
300
301 } // namespace TFL
302 } // namespace mlir
303