/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

#include "absl/strings/str_cat.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/subgraph.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

namespace mlir {
namespace TFL {
namespace tac {
namespace {

// A Subgraph is an intermediate data structure that holds a group of ops:
// all ops in it share the same "target" and are kept in topological order.
// The subgraphs collected here are later used to generate func ops.
// Subgraphs must not form cyclic dependencies, i.e. we should not have:
//     subgraph1
//        \
//       subgraph2
//        /
//     subgraph1
struct Subgraph {
  // All ops must be inserted in their topological order.
  llvm::SetVector<Operation*> all_ops;
  int subgraph_id;
  InferenceDeviceType inference_device_type;
};

// Matches TFL dialect ops but excludes block arguments, constants, and
// quantize/dequantize ops.
inline bool IsTFLNonConstQuatnizeOp(Operation* op) {
  return IsTFLDialectNonConstOp(op) && IsTFLNonQuantDequantizeOp(op);
}

inline bool IsTFLNonConstQuatnizeOp(const Value& value) {
  auto* op = value.getDefiningOp();
  if (op == nullptr) return false;
  return IsTFLNonConstQuatnizeOp(op);
}

// This pass groups non-const TFL dialect ops that have the same target
// together and raises each group as a FuncOp.
// For example:
//
//     op1 (GPU)
//       \   op2 (GPU)
//        \   |
//         \ op3 (GPU)
//          \ /
//       op4 (CPU)
//
// will be raised as 3 subgraphs:
//   Subgraph 1: {op1}, GPU      -> Func_1_GPU
//   Subgraph 2: {op2, op3}, GPU -> Func_2_GPU
//   Subgraph 3: {op4}, CPU      -> Func_3_CPU
//
// MainFunc:
//   %0 = call @Func_1_GPU
//   %1 = call @Func_2_GPU
//   %2 = call @Func_3_CPU(%0, %1)
class RaiseTargetSubgraphsPass
    : public mlir::PassWrapper<RaiseTargetSubgraphsPass,
                               mlir::OperationPass<ModuleOp>> {
 private:
  void runOnOperation() override;

  void RaiseTargetSubgraphsForBlock(Block* block, OpBuilder* builder,
                                    ModuleOp module);

  void ExtractSubgraphToFunc(Subgraph* subgraph, OpBuilder* builder,
                             ModuleOp module);

  FuncOp BuildFuncOp(Subgraph* subgraph, OpBuilder* builder, ModuleOp module_op,
                     SmallVector<Value, 4>* inputs,
                     SmallVector<Value, 4>* outputs,
                     InferenceDeviceType* inference_device_type);

  int subgraph_count_ = 0;
};

// Collects the input values for the given set of ops, i.e. operands that are
// defined outside of the set.
// See the example:
//
//   value1   value2
//      \      /
//        op1
//         \     value3
//          \    /
//           op2
//            |
//           op3
//
// Then the inputs will be: {value1, value2, value3}
void CollectInputs(const llvm::SetVector<Operation*>& all_ops,
                   SmallVector<Value, 4>* inputs) {
  for (Operation* op : all_ops) {
    for (Value input : op->getOperands()) {
      Operation* input_op = input.getDefiningOp();
      const bool input_within_subgraph =
          (input_op && all_ops.count(input_op) == 1);
      if (!input_within_subgraph) {
        inputs->push_back(input);
      }
    }
  }
}

// Collects the output values for the given set of ops, i.e. results that are
// consumed outside of the set.
// See the example:
//
//        op1
//       /   \
//  value1    \
//            op2
//            |  \
//           op3  value2
//            |
//          value3
//
// Then the outputs will be: {value1, value2, value3}
void CollectOutputs(const llvm::SetVector<Operation*>& all_ops,
                    SmallVector<Value, 4>* outputs) {
  for (Operation* op : all_ops) {
    for (Value output : op->getResults()) {
      bool output_consumed_outside_subgraph = false;
      for (Operation* consumer : output.getUsers()) {
        if (all_ops.count(consumer) == 0) {
          output_consumed_outside_subgraph = true;
        }
      }
      if (output_consumed_outside_subgraph) {
        outputs->push_back(output);
      }
    }
  }
}

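// Collects the MLIR types of the given values.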
void BuildTypes(const SmallVector<Value, 4>& values,
                SmallVector<Type, 4>* types) {
  for (auto value : values) {
    types->push_back(value.getType());
  }
}

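// Derives the interface name ("func_<id>") and the full function name
// ("func_<id>_<hardware>_<inference type>") for the given subgraph.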
void GetFunctionName(const Subgraph& subgraph, std::string* function_name,
                     std::string* interface_name) {
  *interface_name = absl::StrCat("func_", std::to_string(subgraph.subgraph_id));
  *function_name = absl::StrCat(
      (*interface_name), "_", subgraph.inference_device_type.hardware, "_",
      GetInferenceString(subgraph.inference_device_type.inference_type));
}

FuncOp RaiseTargetSubgraphsPass::BuildFuncOp(
    Subgraph* subgraph, OpBuilder* builder, ModuleOp module_op,
    SmallVector<Value, 4>* inputs, SmallVector<Value, 4>* outputs,
    InferenceDeviceType* inference_device_type) {
  CollectInputs(subgraph->all_ops, inputs);
  CollectOutputs(subgraph->all_ops, outputs);

  SmallVector<Type, 4> input_types;
  SmallVector<Type, 4> return_types;

  BuildTypes(*inputs, &input_types);
  BuildTypes(*outputs, &return_types);

  FunctionType function_type =
      builder->getFunctionType(input_types, return_types);

  SmallVector<NamedAttribute, 4> attrs;
  // Function name.
  std::string function_name;
  std::string interface_name;
  GetFunctionName(*subgraph, &function_name, &interface_name);
  attrs.push_back(builder->getNamedAttr(
      kInterfaceNameAttr, builder->getStringAttr(interface_name)));

  // Inference device type.
  attrs.push_back(builder->getNamedAttr(
      kDevice,
      builder->getStringAttr(subgraph->inference_device_type.hardware)));
  attrs.push_back(builder->getNamedAttr(
      kInferenceType, builder->getStringAttr(GetInferenceString(
                          subgraph->inference_device_type.inference_type))));
  *inference_device_type = subgraph->inference_device_type;

  FuncOp new_func = FuncOp::create(builder->getUnknownLoc(), function_name,
                                   function_type, llvm::makeArrayRef(attrs));
  new_func.setPrivate();

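  // The entry block's arguments correspond one-to-one to the collected inputs.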
  new_func.addEntryBlock();

  // Function argument mapping.
  llvm::DenseMap<Value, int> function_argument_mapping;
  for (int i = 0; i < inputs->size(); ++i) {
    function_argument_mapping.insert({(*inputs)[i], i});
  }

  OpBuilder function_builder(new_func.getBody());

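  // Clone every op of the subgraph into the new function's body, recording the
  // mapping from each original result value to its cloned counterpart.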
  llvm::DenseMap<Operation*, Operation*> op_cloned_op_mapping;
  llvm::DenseMap<Value, Value> output_cloned_op_output_mapping;
  for (Operation* op : subgraph->all_ops) {
    Operation* cloned_op = function_builder.clone(*op);
    op_cloned_op_mapping.insert({op, cloned_op});
    for (int i = 0; i < op->getNumResults(); ++i) {
      Value op_output = op->getResult(i);
      Value cloned_op_output = cloned_op->getResult(i);
      output_cloned_op_output_mapping.insert({op_output, cloned_op_output});
    }
  }

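  // Rewire the operands of the cloned ops: values coming from outside the
  // subgraph become function arguments, values produced inside map to the
  // corresponding cloned results.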
  for (Operation* op : subgraph->all_ops) {
    Operation* cloned_op = op_cloned_op_mapping.find(op)->second;
    for (int i = 0; i < op->getNumOperands(); ++i) {
      Value input = op->getOperand(i);
      Value cloned_op_input;
      // The input is a function argument.
      if (function_argument_mapping.count(input) > 0) {
        int function_argument = function_argument_mapping.find(input)->second;
        cloned_op_input = new_func.getArgument(function_argument);
      } else {
        // The input is defined within the subgraph.
        cloned_op_input = output_cloned_op_output_mapping.find(input)->second;
      }
      cloned_op->setOperand(i, cloned_op_input);
    }
  }

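  // Return the cloned values that correspond to the subgraph outputs.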
  SmallVector<Value, 4> final_outputs;
  for (auto output : *outputs) {
    auto cloned_output = output_cloned_op_output_mapping.find(output)->second;
    final_outputs.push_back(cloned_output);
  }
  function_builder.create<mlir::ReturnOp>(new_func.getLoc(), final_outputs);

  module_op.push_back(new_func);
  return new_func;
}

void RaiseTargetSubgraphsPass::ExtractSubgraphToFunc(Subgraph* subgraph,
                                                     OpBuilder* builder,
                                                     ModuleOp module) {
  SmallVector<Value, 4> func_inputs;
  SmallVector<Value, 4> func_outputs;

  InferenceDeviceType inference_device_type;
  FuncOp func = BuildFuncOp(subgraph, builder, module, &func_inputs,
                            &func_outputs, &inference_device_type);

  // We just use the location of the last op in the subgraph as the location
  // for the call_op.
  Operation* last_output = subgraph->all_ops.back();

  // TODO(renjieliu): we should add func attributes to the call op.
  builder->setInsertionPoint(last_output);
  auto call_op =
      builder->create<CallOp>(last_output->getLoc(), func, func_inputs);

  auto interface_name = GetInterFaceName(func);

  // Set the call op attributes: interface_name, hardware, inference type.
  call_op->setAttr(kInterfaceNameAttr,
                   builder->getStringAttr(interface_name.getValue()));
  call_op->setAttr(kDevice,
                   builder->getStringAttr(inference_device_type.hardware));
  call_op->setAttr(kInferenceType, builder->getStringAttr(GetInferenceString(
                                       inference_device_type.inference_type)));

  // Rewire the outputs.
  if (call_op.getNumResults() != func_outputs.size()) {
    module.emitError("the constructed func op has mismatched returns");
    signalPassFailure();
  }

  for (int i = 0; i < func_outputs.size(); ++i) {
    Value output = func_outputs[i];
    output.replaceAllUsesWith(call_op.getResult(i));
  }

  // Remove the original ops; they have been cloned into the new function.
  for (auto* op : subgraph->all_ops) {
    op->dropAllDefinedValueUses();
    op->dropAllReferences();
    op->erase();
  }
}

// TODO(renjieliu): We may need to consider side-effecting ops: we may leave
// those ops alone when building the subgraph.
void RaiseTargetSubgraphsPass::RaiseTargetSubgraphsForBlock(Block* block,
                                                            OpBuilder* builder,
                                                            ModuleOp module) {
  // This is a very naive implementation:
  // It greedily groups adjacent ops that have the same inference device type
  // into a subgraph.
  llvm::DenseMap<int, Subgraph> all_subgraphs;
  llvm::Optional<InferenceDeviceType> previous_device_type = llvm::None;
  int current_subgraph_id = -1;
  for (auto& op : *block) {
    // We only care about non-const TFL dialect ops.
    if (IsTFLNonConstQuatnizeOp(&op)) {
      auto current_device_type = GetInferenceDeviceTypeForOp(&op);
      if (!(current_device_type.hasValue() &&
            current_device_type == previous_device_type)) {
        // Start a new subgraph.
        Subgraph new_subgraph;
        new_subgraph.inference_device_type = current_device_type.getValue();
        new_subgraph.subgraph_id = subgraph_count_++;
        all_subgraphs.insert({new_subgraph.subgraph_id, new_subgraph});
        current_subgraph_id = new_subgraph.subgraph_id;
      }
      previous_device_type = current_device_type;
      all_subgraphs.find(current_subgraph_id)->second.all_ops.insert(&op);
    }
  }

  // Create a FuncOp for each subgraph and replace the original ops with a
  // call to it.
  for (auto& subgraph : all_subgraphs) {
    ExtractSubgraphToFunc(&subgraph.second, builder, module);
  }
}

void RaiseTargetSubgraphsPass::runOnOperation() {
  auto module = getOperation();
  SmallVector<FuncOp, 16> funcs(module.getOps<FuncOp>());
  for (auto func : funcs) {
    for (auto& block : func) {
      auto builder = OpBuilder::atBlockBegin(&block);
      RaiseTargetSubgraphsForBlock(&block, &builder, module);
    }
  }
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateRaiseTargetSubgraphsPass() {
  return std::make_unique<RaiseTargetSubgraphsPass>();
}

static PassRegistration<RaiseTargetSubgraphsPass> pass(
    "tfl-raise-target-subgraphs",
    "This pass groups target-annotated TFL ops with the same target together "
    "and raises them as functions.");

}  // namespace tac
}  // namespace TFL
}  // namespace mlir