/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

#include "absl/strings/str_cat.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/subgraph.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

namespace mlir {
namespace TFL {
namespace tac {
namespace {

// Subgraph here is an intermediate data structure that holds the ops:
// the ops within it share the same "target" and are topologically sorted.
// Each subgraph is later used to populate a generated func op.
// Subgraphs must not create cyclic dependencies, so we should not have:
//     subgraph1
//             \
//            subgraph2
//            /
//       subgraph1
struct Subgraph {
  // All ops must be inserted in their topological order.
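  // llvm::SetVector preserves insertion order, so iterating all_ops stays
  // topological while membership checks remain cheap.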
  llvm::SetVector<Operation*> all_ops;
  int subgraph_id;
  InferenceDeviceType inference_device_type;
};

// This will exclude arguments & consts & quantize/dequantize ops.
inline bool IsTFLNonConstQuantizeOp(Operation* op) {
  return IsTFLDialectNonConstOp(op) && IsTFLNonQuantDequantizeOp(op);
}

inline bool IsTFLNonConstQuantizeOp(const Value& value) {
  auto* op = value.getDefiningOp();
  if (op == nullptr) return false;
  return IsTFLNonConstQuantizeOp(op);
}

// This pass groups ops (non-const TFL dialect ops) that have the same
// target together and raises them as FuncOps.
// See the following example:
//
//     op1 (GPU)
//       \       op2 (GPU)
//       \        |
//        \      op3 (GPU)
//         \     /
//         op4 (CPU)
//
// The ops above will be raised as 3 subgraphs:
// Subgraph 1: {op1}, GPU -> Func_1_GPU
// Subgraph 2: {op2, op3}, GPU -> Func_2_GPU
// Subgraph 3: {op4}, CPU -> Func_3_CPU
//
// MainFunc:
//   %0 = call @Func_1_GPU
//   %1 = call @Func_2_GPU
//   %2 = call @Func_3_CPU(%0, %1)
class RaiseTargetSubgraphsPass
    : public mlir::PassWrapper<RaiseTargetSubgraphsPass,
                               mlir::OperationPass<ModuleOp>> {
 private:
  void runOnOperation() override;

  void RaiseTargetSubgraphsForBlock(Block* block, OpBuilder* builder,
                                    ModuleOp module);

  void ExtractSubgraphToFunc(Subgraph* subgraph, OpBuilder* builder,
                             ModuleOp module);

  FuncOp BuildFuncOp(Subgraph* subgraph, OpBuilder* builder, ModuleOp module_op,
                     SmallVector<Value, 4>* inputs,
                     SmallVector<Value, 4>* outputs,
                     InferenceDeviceType* inference_device_type);

  int subgraph_count_ = 0;
};

// This is to collect input arguments for the given set of ops.
// See the example:
//
//   value1  value2
//    \     /
//      op1
//        \     value3
//        \   /
//         op2
//         |
//         op3
//
//  Then the arguments will be {value1, value2, value3}
void CollectInputs(const llvm::SetVector<Operation*>& all_ops,
                   SmallVector<Value, 4>* inputs) {
  for (Operation* op : all_ops) {
    for (Value input : op->getOperands()) {
      Operation* input_op = input.getDefiningOp();
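      // A value with no defining op is a block argument, which always comes
      // from outside the subgraph.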
      const bool input_within_subgraph =
          (input_op && all_ops.count(input_op) == 1);
      if (!input_within_subgraph) {
        inputs->push_back(input);
      }
    }
  }
}

// This is to collect the output values for the given set of ops.
// See the example:
//
//      op1
//      /    \
//   value1   \
//           op2
//           |  \
//         op3  value2
//         |
//       value3
//
//  Then the outputs will be {value1, value2, value3}
void CollectOutputs(const llvm::SetVector<Operation*>& all_ops,
                    SmallVector<Value, 4>* outputs) {
  for (Operation* op : all_ops) {
    for (Value output : op->getResults()) {
      bool output_consumed_outside_subgraph = false;
      for (Operation* consumer : output.getUsers()) {
        if (all_ops.count(consumer) == 0) {
          output_consumed_outside_subgraph = true;
        }
      }
      if (output_consumed_outside_subgraph) {
        outputs->push_back(output);
      }
    }
  }
}

void BuildTypes(const SmallVector<Value, 4>& values,
                SmallVector<Type, 4>* types) {
  for (auto value : values) {
    types->push_back(value.getType());
  }
}

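// For example, subgraph 1 on GPU gets the interface name "func_1" and a
// function name like "func_1_GPU_<inference type>".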
void GetFunctionName(const Subgraph& subgraph, std::string* function_name,
                     std::string* interface_name) {
  *interface_name = absl::StrCat("func_", subgraph.subgraph_id);
  *function_name = absl::StrCat(
      (*interface_name), "_", subgraph.inference_device_type.hardware, "_",
      GetInferenceString(subgraph.inference_device_type.inference_type));
}

FuncOp RaiseTargetSubgraphsPass::BuildFuncOp(
    Subgraph* subgraph, OpBuilder* builder, ModuleOp module_op,
    SmallVector<Value, 4>* inputs, SmallVector<Value, 4>* outputs,
    InferenceDeviceType* inference_device_type) {
  CollectInputs(subgraph->all_ops, inputs);
  CollectOutputs(subgraph->all_ops, outputs);

  SmallVector<Type, 4> input_types;
  SmallVector<Type, 4> return_types;

  BuildTypes(*inputs, &input_types);
  BuildTypes(*outputs, &return_types);

  FunctionType function_type =
      builder->getFunctionType(input_types, return_types);

  SmallVector<NamedAttribute, 4> attrs;
  // Function name.
  std::string function_name;
  std::string interface_name;
  GetFunctionName(*subgraph, &function_name, &interface_name);
  attrs.push_back(builder->getNamedAttr(
      kInterfaceNameAttr, builder->getStringAttr(interface_name)));

  // Inference device type.
  attrs.push_back(builder->getNamedAttr(
      kDevice,
      builder->getStringAttr(subgraph->inference_device_type.hardware)));
  attrs.push_back(builder->getNamedAttr(
      kInferenceType, builder->getStringAttr(GetInferenceString(
                          subgraph->inference_device_type.inference_type))));
  *inference_device_type = subgraph->inference_device_type;

  FuncOp new_func = FuncOp::create(builder->getUnknownLoc(), function_name,
                                   function_type, llvm::makeArrayRef(attrs));
  new_func.setPrivate();

  new_func.addEntryBlock();

  // Function argument mapping.
  llvm::DenseMap<Value, int> function_argument_mapping;
  for (int i = 0; i < inputs->size(); ++i) {
    function_argument_mapping.insert({(*inputs)[i], i});
  }

  OpBuilder function_builder(new_func.getBody());

  llvm::DenseMap<Operation*, Operation*> op_cloned_op_mapping;
  llvm::DenseMap<Value, Value> output_cloned_op_output_mapping;
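  // First pass: clone every op into the new function body, and record how
  // each original result maps to its cloned counterpart.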
  for (Operation* op : subgraph->all_ops) {
    Operation* cloned_op = function_builder.clone(*op);
    op_cloned_op_mapping.insert({op, cloned_op});
    for (int i = 0; i < op->getNumResults(); ++i) {
      Value op_output = op->getResult(i);
      Value cloned_op_output = cloned_op->getResult(i);
      output_cloned_op_output_mapping.insert({op_output, cloned_op_output});
    }
  }

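  // Second pass: remap each cloned op's operands, either to a function
  // argument or to the cloned result of another op in the subgraph.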
  for (Operation* op : subgraph->all_ops) {
    Operation* cloned_op = op_cloned_op_mapping.find(op)->second;
    for (int i = 0; i < op->getNumOperands(); ++i) {
      Value input = op->getOperand(i);
      Value cloned_op_input;
      // If the input is actually a function argument.
      if (function_argument_mapping.count(input) > 0) {
        int function_argument = function_argument_mapping.find(input)->second;
        cloned_op_input = new_func.getArgument(function_argument);
      } else {
        // The input is actually within the subgraph.
        cloned_op_input = output_cloned_op_output_mapping.find(input)->second;
      }
      cloned_op->setOperand(i, cloned_op_input);
    }
  }

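  // Return the cloned values that correspond to the subgraph's outputs.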
  SmallVector<Value, 4> final_outputs;
  for (auto output : *outputs) {
    auto cloned_output = output_cloned_op_output_mapping.find(output)->second;
    final_outputs.push_back(cloned_output);
  }
  function_builder.create<mlir::ReturnOp>(new_func.getLoc(), final_outputs);

  module_op.push_back(new_func);
  return new_func;
}

void RaiseTargetSubgraphsPass::ExtractSubgraphToFunc(Subgraph* subgraph,
                                                     OpBuilder* builder,
                                                     ModuleOp module) {
  SmallVector<Value, 4> func_inputs;
  SmallVector<Value, 4> func_outputs;

  InferenceDeviceType inference_device_type;
  FuncOp func = BuildFuncOp(subgraph, builder, module, &func_inputs,
                            &func_outputs, &inference_device_type);

  // We just use the location of the last op in the subgraph as the location
  // for the call_op.
  Operation* last_output = subgraph->all_ops.back();

  // TODO(renjieliu): we should add func attributes to the call op.
  builder->setInsertionPoint(last_output);
  auto call_op =
      builder->create<CallOp>(last_output->getLoc(), func, func_inputs);

  auto interface_name = GetInterFaceName(func);

  // Set the call op attributes: interface_name, device & inference type.
  call_op->setAttr(kInterfaceNameAttr,
                   builder->getStringAttr(interface_name.getValue()));
  call_op->setAttr(kDevice,
                   builder->getStringAttr(inference_device_type.hardware));
  call_op->setAttr(kInferenceType, builder->getStringAttr(GetInferenceString(
                                       inference_device_type.inference_type)));

  // Rewire the outputs.
  if (call_op.getNumResults() != func_outputs.size()) {
    module.emitError("the constructed func op has mismatched returns");
    signalPassFailure();
    return;
  }

  for (int i = 0; i < func_outputs.size(); ++i) {
    Value output = func_outputs[i];
    output.replaceAllUsesWith(call_op.getResult(i));
  }

  // Clear the subgraph: the original ops have been cloned into the new func
  // and should now be removed.
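  // Dropping all uses and references first allows the ops to be erased in
  // any order, even though they may still feed one another.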
  for (auto* op : subgraph->all_ops) {
    op->dropAllDefinedValueUses();
    op->dropAllReferences();
    op->erase();
  }
}

// TODO(renjieliu): We may need to consider side-effecting ops: we may leave
// those ops alone when building the subgraph.
void RaiseTargetSubgraphsPass::RaiseTargetSubgraphsForBlock(Block* block,
                                                            OpBuilder* builder,
                                                            ModuleOp module) {
  // This is a very naive implementation:
  // It greedily groups adjacent ops that have the same inference device type
  // into a subgraph.
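  // For example, a run of ops annotated [GPU, GPU, CPU, GPU] is split into
  // three subgraphs: {op1, op2} on GPU, {op3} on CPU, and {op4} on GPU.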
  llvm::DenseMap<int, Subgraph> all_subgraphs;
  llvm::Optional<InferenceDeviceType> previous_device_type = llvm::None;
  int current_subgraph_id = -1;
  for (auto& op : *block) {
    // We only care about ops in the TFL dialect.
    if (IsTFLNonConstQuantizeOp(&op)) {
      auto current_device_type = GetInferenceDeviceTypeForOp(&op);
      if (!(current_device_type.hasValue() &&
            current_device_type == previous_device_type)) {
        // We should start a new subgraph.
        Subgraph new_subgraph;
        new_subgraph.inference_device_type = current_device_type.getValue();
        new_subgraph.subgraph_id = subgraph_count_++;
        all_subgraphs.insert({new_subgraph.subgraph_id, new_subgraph});
        current_subgraph_id = new_subgraph.subgraph_id;
      }
      previous_device_type = current_device_type;
      all_subgraphs.find(current_subgraph_id)->second.all_ops.insert(&op);
    }
  }

  // Create a FuncOp from each subgraph and replace the original ops' uses
  // with the results of its call op.
  for (auto& subgraph : all_subgraphs) {
    ExtractSubgraphToFunc(&subgraph.second, builder, module);
  }
}

void RaiseTargetSubgraphsPass::runOnOperation() {
  auto module = getOperation();
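  // Snapshot the existing FuncOps first: ExtractSubgraphToFunc appends newly
  // raised FuncOps to the module, and those must not be visited again.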
  SmallVector<FuncOp, 16> funcs(module.getOps<FuncOp>());
  for (auto func : funcs) {
    for (auto& block : func) {
      auto builder = OpBuilder::atBlockBegin(&block);
      RaiseTargetSubgraphsForBlock(&block, &builder, module);
    }
  }
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateRaiseTargetSubgraphsPass() {
  return std::make_unique<RaiseTargetSubgraphsPass>();
}

static PassRegistration<RaiseTargetSubgraphsPass> pass(
    "tfl-raise-target-subgraphs",
    "This pass groups the target-annotated TFL ops together & raises them as "
    "functions.");

}  // namespace tac
}  // namespace TFL
}  // namespace mlir