• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SetVector.h"
19 #include "llvm/Support/CommandLine.h"
20 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
21 #include "mlir/IR/Builders.h"  // from @llvm-project
22 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
23 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
24 #include "mlir/IR/Identifier.h"  // from @llvm-project
25 #include "mlir/IR/Location.h"  // from @llvm-project
26 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
27 #include "mlir/IR/Matchers.h"  // from @llvm-project
28 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
29 #include "mlir/Pass/Pass.h"  // from @llvm-project
30 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
31 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
32 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
33 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
35 
36 namespace mlir {
37 namespace TFL {
38 namespace {
39 
40 // This pass outlines the cond/body region of the TFL WhileOp into functions and
41 // replaces the regions with calls to these outlined functions.
42 class WhileOutlinePass
43     : public mlir::PassWrapper<WhileOutlinePass, OperationPass<ModuleOp>> {
getDependentDialects(DialectRegistry & registry) const44   void getDependentDialects(DialectRegistry& registry) const override {
45     registry.insert<TF::TensorFlowDialect>();
46   }
47 
48  public:
WhileOutlinePass()49   explicit WhileOutlinePass() {}
50 
getArgument() const51   StringRef getArgument() const final {
52     // This is the argument used to refer to the pass in
53     // the textual format (on the commandline for example).
54     return "tfl-while-loop-outline";
55   }
getDescription() const56   StringRef getDescription() const final {
57     // This is a brief description of the pass.
58     return "Hoist while op regions into functions";
59   }
60 
61  private:
62   void runOnOperation() override;
63 
64   // Outlines the regions of the WhileOp's cond and body and insert function
65   // calls instead,
66   void OutlineWhile(WhileOp while_op);
67 
68   // Get unique name by using the loc to name mapping.
69   std::string GetName(Operation* op, StringRef suffix);
70 
71   tensorflow::OpOrArgLocNameMapper mapper_;
72 };
73 
GetName(Operation * op,StringRef suffix)74 std::string WhileOutlinePass::GetName(Operation* op, StringRef suffix) {
75   return (mapper_.GetUniqueName(op) + suffix).str();
76 }
77 
78 // Returns whether the WhileOp is already outlined (e.g., only consists of calls
79 // to functions).
IsAlreadyOutlined(WhileOp while_op)80 bool IsAlreadyOutlined(WhileOp while_op) {
81   auto just_call = [](Region& region) {
82     auto it = region.front().begin();
83     if (!isa<CallOp>(*it)) return false;
84     ++it;
85     if (!isa<YieldOp>(*it)) return false;
86     return true;
87   };
88   return just_call(while_op.body()) && just_call(while_op.cond());
89 }
90 
IsCompatibleTypeWithTFLCastOp(Type type)91 bool IsCompatibleTypeWithTFLCastOp(Type type) {
92   auto elemType = getElementTypeOrSelf(type);
93   // F32 and BF16 types are allowed.
94   if (elemType.isBF16() || elemType.isF32()) return true;
95 
96   // I1, I16, I32, I64 types are allowed.
97   if (elemType.isInteger(1) || elemType.isInteger(16) ||
98       elemType.isInteger(32) || elemType.isInteger(64))
99     return true;
100 
101   // Complex<F<32>> is allowed.
102   if (elemType.isa<ComplexType>() &&
103       elemType.cast<ComplexType>().getElementType().isF32())
104     return true;
105 
106   // QUINT8 and UI8 are allowed.
107   if (elemType.isa<TF::Quint8Type>() ||
108       (elemType.isInteger(8) && elemType.cast<IntegerType>().isUnsigned()))
109     return true;
110 
111   return false;
112 }
113 
CreateOutlineFunc(StringRef name,Region & region,bool passthru_extra_args,int num_loop_carried,const llvm::SetVector<Value> & extern_values,const SmallVectorImpl<Type> & types,Location loc)114 FuncOp CreateOutlineFunc(StringRef name, Region& region,
115                          bool passthru_extra_args, int num_loop_carried,
116                          const llvm::SetVector<Value>& extern_values,
117                          const SmallVectorImpl<Type>& types, Location loc) {
118   MLIRContext* context = loc.getContext();
119   OpBuilder builder(context);
120   FunctionType type;
121   if (passthru_extra_args) {
122     type = FunctionType::get(context, types, types);
123   } else {
124     SmallVector<Type, 4> result_types;
125     auto operands = region.front().getTerminator()->getOperandTypes();
126     result_types.append(operands.begin(), operands.end());
127     type = FunctionType::get(context, types, result_types);
128   }
129 
130   auto outlined_func = builder.create<FuncOp>(loc, name, type);
131   outlined_func.getBody().takeBody(region);
132   Region& func_region = outlined_func.getBody();
133 
134   // Replace all external uses with block args and update uses.
135   llvm::SmallVector<Value, 4> new_args;
136   new_args.reserve(extern_values.size());
137   Block& block = func_region.front();
138   for (Value value : extern_values) {
139     auto arg = block.addArgument(value.getType());
140     replaceAllUsesInRegionWith(value, arg, func_region);
141     new_args.push_back(arg);
142   }
143 
144   // Replace yield op with return.
145   Operation* yield_op = outlined_func.getBody().front().getTerminator();
146   OpBuilder b(yield_op);
147   llvm::SmallVector<Value, 4> args;
148   auto loop_carried_yield_operands =
149       yield_op->getOperands().take_front(num_loop_carried);
150   args.reserve(loop_carried_yield_operands.size() + new_args.size());
151   if (passthru_extra_args) {
152     // Add operands of yield to the return, inserting casts if needed.
153     for (auto it : llvm::zip_first(loop_carried_yield_operands, types)) {
154       auto value = std::get<0>(it);
155       auto type = std::get<1>(it);
156       if (value.getType() == type) {
157         args.push_back(value);
158       } else {
159         if (IsCompatibleTypeWithTFLCastOp(value.getType()) &&
160             IsCompatibleTypeWithTFLCastOp(type)) {
161           auto cast = b.create<CastOp>(yield_op->getLoc(), type, value);
162           args.push_back(cast);
163         } else {
164           auto cast = b.create<TF::CastOp>(yield_op->getLoc(), type, value);
165           args.push_back(cast);
166         }
167       }
168     }
169     args.append(new_args.begin(), new_args.end());
170   } else {
171     args.append(yield_op->operand_begin(), yield_op->operand_end());
172   }
173   b.create<ReturnOp>(yield_op->getLoc(), args);
174   yield_op->erase();
175   SymbolTable(region.getParentOfType<ModuleOp>()).insert(outlined_func);
176   outlined_func.setPrivate();
177   return outlined_func;
178 }
179 
180 // Replace region with call to outline function.
ReplaceRegionWithCall(StringRef name,Region & region,bool passthru_extra_args,int num_loop_carried,const llvm::SetVector<Value> & extern_values,const SmallVectorImpl<Type> & types,Location loc)181 void ReplaceRegionWithCall(StringRef name, Region& region,
182                            bool passthru_extra_args, int num_loop_carried,
183                            const llvm::SetVector<Value>& extern_values,
184                            const SmallVectorImpl<Type>& types, Location loc) {
185   auto func = CreateOutlineFunc(name, region, passthru_extra_args,
186                                 num_loop_carried, extern_values, types, loc);
187   OpBuilder b(region);
188   // The body of the region is empty/has been outlined into the function.
189   auto block = b.createBlock(&region);
190   SmallVector<Value, 4> new_operands;
191   new_operands.reserve(types.size());
192   for (Type t : llvm::makeArrayRef(types).drop_back(extern_values.size()))
193     new_operands.push_back(block->addArgument(t));
194   for (Value v : extern_values) new_operands.push_back(v);
195   auto call = b.create<CallOp>(loc, func, new_operands);
196   b.create<YieldOp>(loc, call.getResults());
197 }
198 
OutlineWhile(WhileOp while_op)199 void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
200   OpBuilder builder(&getContext());
201   // Collect external values used.
202   llvm::SetVector<Value> extern_values;
203 
204   // The basic block arguments correspond to values that are loop carried, while
205   // all those post are loop independent. Initialize extern_values with while_op
206   // not loop carried operands.
207   auto num_loop_carried = while_op.cond().getNumArguments();
208   auto not_carried_operands =
209       while_op.getOperands().drop_front(num_loop_carried);
210   extern_values.insert(not_carried_operands.begin(),
211                        not_carried_operands.end());
212   auto old_extern_values_size = extern_values.size();
213 
214   llvm::SmallVector<Region*, 2> regions{&while_op.cond(), &while_op.body()};
215   for (auto it : llvm::enumerate(regions)) {
216     llvm::SetVector<Value> region_extern_values;
217     getUsedValuesDefinedAbove(*it.value(), region_extern_values);
218 
219     // Sink down constants into the functions.
220     for (auto extern_value : region_extern_values) {
221       if (!matchPattern(extern_value, m_Constant())) {
222         extern_values.insert(extern_value);
223         continue;
224       }
225       // Add constant at start of region.
226       auto const_builder =
227           OpBuilder(&it.value()->front(), it.value()->front().begin());
228       auto const_value = const_builder.clone(*extern_value.getDefiningOp());
229       replaceAllUsesInRegionWith(extern_value, const_value->getResult(0),
230                                  *it.value());
231     }
232   }
233 
234   bool has_extra_extern_values = old_extern_values_size != extern_values.size();
235   // If an extern value is already an operand post the loop carried operands,
236   // then it need not be passed in again.
237   // Compute all the extra operands that have to be added to the while.
238   llvm::SetVector<Value> extra_operands;
239   if (has_extra_extern_values) {
240     auto new_extern =
241         extern_values.getArrayRef().drop_front(old_extern_values_size);
242     extra_operands.insert(new_extern.begin(), new_extern.end());
243   }
244 
245   // Skip if already just calls.
246   if (extra_operands.empty() && IsAlreadyOutlined(while_op)) return;
247 
248   // Collect new types.
249   SmallVector<Type, 4> types;
250   types.reserve(extra_operands.size() + while_op.getNumOperands());
251   for (Type type : while_op.cond().getArgumentTypes()) types.push_back(type);
252   for (Value operand : extern_values) types.push_back(operand.getType());
253 
254   // Create outline function from region. Optional pass extra arguments through
255   // to yield.
256   ReplaceRegionWithCall(GetName(while_op.getOperation(), "_cond"),
257                         while_op.cond(), false, num_loop_carried, extern_values,
258                         types, while_op.getLoc());
259   ReplaceRegionWithCall(GetName(while_op.getOperation(), "_body"),
260                         while_op.body(), true, num_loop_carried, extern_values,
261                         types, while_op.getLoc());
262 
263   // If there are extern values used then the result type of the while has to
264   // change, so replace with new while op.
265   if (extra_operands.empty()) return;
266 
267   const int operands_size = while_op.getNumOperands() + extra_operands.size();
268   SmallVector<Value, 4> operands;
269   operands.reserve(operands_size);
270   operands.append(while_op.getOperands().begin(), while_op.getOperands().end());
271   operands.append(extra_operands.begin(), extra_operands.end());
272   SmallVector<Type, 4> new_types;
273   new_types.reserve(operands_size);
274   new_types.append(while_op.getResultTypes().begin(),
275                    while_op.getResultTypes().end());
276   for (auto extra_operand : extra_operands)
277     new_types.push_back(extra_operand.getType());
278 
279   auto new_while_op = OpBuilder(while_op).create<WhileOp>(
280       while_op.getLoc(), new_types, operands, while_op->getAttrs());
281   new_while_op.cond().takeBody(while_op.cond());
282   new_while_op.body().takeBody(while_op.body());
283   while_op.replaceAllUsesWith(
284       new_while_op.getResults().take_front(while_op.getNumResults()));
285   while_op.erase();
286 }
287 
runOnOperation()288 void WhileOutlinePass::runOnOperation() {
289   getOperation().walk(
290       [&](mlir::TFL::WhileOp while_op) { OutlineWhile(while_op); });
291 }
292 }  // namespace
293 
294 // Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
CreateWhileOutlinePass()295 std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass() {
296   return std::make_unique<WhileOutlinePass>();
297 }
298 
299 static PassRegistration<WhileOutlinePass> pass;
300 
301 }  // namespace TFL
302 }  // namespace mlir
303