1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <string>
17
18 #include "llvm/ADT/ArrayRef.h"
19 #include "llvm/ADT/None.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SetVector.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/raw_ostream.h"
26 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
27 #include "mlir/IR/Attributes.h" // from @llvm-project
28 #include "mlir/IR/Builders.h" // from @llvm-project
29 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
30 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
31 #include "mlir/IR/Identifier.h" // from @llvm-project
32 #include "mlir/IR/Location.h" // from @llvm-project
33 #include "mlir/IR/MLIRContext.h" // from @llvm-project
34 #include "mlir/IR/Operation.h" // from @llvm-project
35 #include "mlir/IR/SymbolTable.h" // from @llvm-project
36 #include "mlir/IR/Visitors.h" // from @llvm-project
37 #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
38 #include "mlir/Pass/Pass.h" // from @llvm-project
39 #include "mlir/Support/LLVM.h" // from @llvm-project
40 #include "mlir/Support/LogicalResult.h" // from @llvm-project
41 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
42 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
43 #include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"
44 #include "tensorflow/compiler/mlir/lite/utils/nms_utils.h"
45 #include "tensorflow/compiler/mlir/lite/utils/perception_ops_utils.h"
46 #include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h"
47 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
48 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
49
// The cmd line flag to turn on/off Tf.Text API fusion.
// Defaults to off; ConvertTFImplementsWithAttributes also enables TF.Text
// fusion when the TF.Text ops are registered in the binary, so this flag is
// only needed to force fusion on.
// NOLINTNEXTLINE
static llvm::cl::opt<bool> fuse_tftext_flag(
    "tfl-fuse-tftext", llvm::cl::value_desc("bool"),
    llvm::cl::desc("Fuse TF.Text API ops when it's true"),
    llvm::cl::init(false));
56
57 namespace mlir {
58 namespace TFL {
59 namespace {
60
// Function attribute carrying the TF API a composite function implements
// (e.g. the Keras LSTM annotation handled by ConvertTFAPIImplements).
constexpr char kTFAPIImplements[] = "tf.api_implements";
// Name prefix identifying TF.Text composite functions.
constexpr char kTFTextAPIPrefix[] = "tftext:";
// Composite-function names recognized for conversion to TFLite custom/fused
// ops.
constexpr char kCustomSSDPostprocessing[] = "TFLite_Detection_PostProcess";
constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2";
constexpr char kCustomMaxUnpooling[] = "addons:MaxUnpooling2D";
constexpr char kCustomDenseImageWarp[] = "addons:DenseImageWarp";

using mlir::TF::FuncAttr;
69
70 // Abstracts the conversion of the embedded lookup composite function.
71 class ConvertEmbeddedLookupFunc {
72 public:
ConvertEmbeddedLookupFunc(FuncOp func)73 explicit ConvertEmbeddedLookupFunc(FuncOp func) : func_(func) {}
74
RewriteFunc()75 void RewriteFunc() {
76 func_->setAttr(kTFImplements,
77 StringAttr::get(func_.getContext(), "embedding_lookup"));
78 Value lookup = func_.getArgument(1);
79 Value value = func_.getArgument(0);
80 auto output_type = func_.getType().getResult(0);
81
82 OpBuilder builder(func_.getBody());
83 auto op = builder.create<mlir::TFL::EmbeddingLookupOp>(
84 func_.getLoc(), output_type, lookup, value);
85
86 builder.create<mlir::ReturnOp>(func_.getLoc(), op.getResult());
87 }
88
VerifySignature()89 LogicalResult VerifySignature() {
90 if (func_.getNumArguments() != 2) {
91 return func_.emitWarning()
92 << "Invalid number of arguments in the embedding "
93 "matmul composite function";
94 }
95 if (func_.getType().getNumResults() != 1) {
96 return func_.emitWarning() << "Invalid number of results in the "
97 "embedding matmul composite function";
98 }
99 return success();
100 }
101
102 private:
103 FuncOp func_;
104 };
105
106 // This pass uses mechanisms listed in RFC:
107 // https://github.com/tensorflow/community/pull/113
108 // It prepares composite functions that are attributed to indicate
109 // a specific interface (LSTM, SVDF, Embedding lookup etc.) by replacing the
110 // body with the corresponding fused TFLite op. The replacement need not always
111 // be a fused op, though that is the primary use case.
class PrepareCompositeFunctionsPass
    : public PassWrapper<PrepareCompositeFunctionsPass,
                         OperationPass<ModuleOp>> {
  // The rewrites create TFLite ops, so the TFL dialect must be loaded before
  // this pass runs.
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<TFL::TensorFlowLiteDialect>();
  }

 public:
  explicit PrepareCompositeFunctionsPass() {}

  StringRef getArgument() const final {
    // This is the argument used to refer to the pass in
    // the textual format (on the commandline for example).
    return "tfl-prepare-composite-funcs-tf";
  }
  StringRef getDescription() const final {
    // This is a brief description of the pass.
    return "Prepares composite functions in Tensorflow dialect of MLIR";
  }

 private:
  // Handles functions annotated with tf._implements as a plain StringAttr.
  // TODO(b/160915525): Consolidate FuncAttr and StringAttr into one.
  void ConvertTFImplements(FuncOp func, StringAttr attr);
  // Handles functions annotated with tf._implements as a FuncAttr (proto
  // attribute form, carrying extra op attributes).
  void ConvertTFImplementsWithAttributes(FuncOp func, FuncAttr attr);
  // Handles functions annotated with tf.api_implements (e.g. Keras LSTM).
  void ConvertTFAPIImplements(FuncOp func, StringAttr attr, ModuleOp module);
  void runOnOperation() override;
};
139
CheckFusableLayerNormalizedLstmCellSimple(FuncOp lstm_func)140 LogicalResult CheckFusableLayerNormalizedLstmCellSimple(FuncOp lstm_func) {
141 for (int i = 0; i < 5; ++i) {
142 auto input = lstm_func.getArgument(i);
143 auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
144 if (!input_type) {
145 lstm_func.emitWarning(
146 "we cannot fuse this lstm func because all the inputs have not "
147 "ranked tensor type.");
148 return failure();
149 }
150 }
151
152 return success();
153 }
154
CheckFusableLstmCellSimple(FuncOp lstm_func)155 LogicalResult CheckFusableLstmCellSimple(FuncOp lstm_func) {
156 for (int i = 0; i < 4; ++i) {
157 auto input = lstm_func.getArgument(i);
158 auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
159 if (!input_type) {
160 lstm_func.emitWarning(
161 "we cannot fuse this lstm func because all the inputs have not "
162 "ranked tensor type.");
163 return failure();
164 }
165 }
166
167 return success();
168 }
169
CheckOutputConsumer(Operation * call_op,int expected_num_outputs,llvm::DenseSet<int> expected_consumer_indices)170 LogicalResult CheckOutputConsumer(
171 Operation* call_op, int expected_num_outputs,
172 llvm::DenseSet<int> expected_consumer_indices) {
173 const int num_results = call_op->getNumResults();
174 if (num_results != expected_num_outputs) return failure();
175
176 for (int i = 0; i < expected_num_outputs; ++i) {
177 auto it = expected_consumer_indices.find(i);
178 if (it == expected_consumer_indices.end()) {
179 // Unexpected consumer.
180 if (!call_op->getResult(i).use_empty()) return failure();
181 }
182 }
183 return success();
184 }
185
// Decides whether a Keras LSTM composite function can be fused into
// UnidirectionalSequenceLSTMOp. Returns failure (sometimes with a warning)
// when any precondition is violated:
//   - Any caller of `lstm_func` may only consume outputs 0 or 1.
//   - A 7-argument signature (mask input) is not supported.
//   - The first six inputs must be ranked tensors; init states (1, 2) and
//     weights/bias (3, 4, 5) must additionally have static shapes.
LogicalResult CheckFusableKerasLstm(FuncOp lstm_func, ModuleOp module) {
  // Walk every other function in the module and inspect how they call
  // `lstm_func`.
  for (auto func : module.getOps<FuncOp>()) {
    if (func == lstm_func) continue;
    auto result = func.walk([&](CallOpInterface op) {
      if (dyn_cast<FuncOp>(op.resolveCallable()) == lstm_func) {
        // Keras LSTM have 5 outputs.
        // We should make sure only the first or the second output are
        // consumed.
        if (failed(CheckOutputConsumer(op.getOperation(), 5, {0, 1})))
          return WalkResult::interrupt();
      }
      return WalkResult::advance();
    });

    if (result.wasInterrupted()) return failure();
  }
  // Current UnidirectionalSequenceLSTMOp doesn't support mask input.
  if (lstm_func.getNumArguments() == 7) return failure();

  // We should know the batch size in advance for the lstm fusion.
  // A good indicator of batch size is both cell state and input state (indices
  // 1 & 2) have fixed shape and other input tensors should have ranked tensor
  // types.
  for (int i = 0; i < 6; ++i) {
    auto input = lstm_func.getArgument(i);
    auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
    if (!input_type) {
      lstm_func.emitWarning(
          "we cannot fuse this lstm func because all the inputs have not "
          "ranked tensor type.");
      return failure();
    }
    switch (i) {
      case 1:  // output_init_state
      case 2:  // hidden_init_state
        if (!input_type.hasStaticShape()) {
          lstm_func.emitWarning(
              "we cannot fuse this lstm func because the batch size is not "
              "fixed, please consider setting fixed batch size like "
              "https://github.com/tensorflow/tensorflow/blob/master/tensorflow/"
              "lite/examples/experimental_new_converter/"
              "Keras_LSTM_fusion_Codelab.ipynb");
          return failure();
        }
        break;
      case 3:  // weight (kernel)
      case 4:  // recurrent_kernel
      case 5:  // bias
        if (!input_type.hasStaticShape()) {
          lstm_func.emitWarning(
              "we cannot fuse this lstm func because the weight & bias are not "
              "fixed, please consider setting fixed batch size like "
              "https://github.com/tensorflow/tensorflow/blob/master/tensorflow/"
              "lite/examples/experimental_new_converter/"
              "Keras_LSTM_fusion_Codelab.ipynb");
          return failure();
        }
        break;
      default:
        // No op.
        break;
    }
  }

  return success();
}
252
// Dispatches on the string value of the tf._implements attribute and rewrites
// the function body into the corresponding fused TFLite op. For each
// recognized value, the signature is verified first; only then is the body
// erased and rebuilt (eraseBody + addEntryBlock must precede RewriteFunc).
// Unrecognized values are left untouched.
void PrepareCompositeFunctionsPass::ConvertTFImplements(FuncOp func,
                                                        StringAttr attr) {
  if (attr.getValue() == "embedding_matmul") {
    // Convert the composite embedding_matmul function body to a
    // TFLite fused embedding_lookup op.
    ConvertEmbeddedLookupFunc convert_embedded_lookup(func);
    if (failed(convert_embedded_lookup.VerifySignature())) return;
    func.eraseBody();
    func.addEntryBlock();
    convert_embedded_lookup.RewriteFunc();
  } else if (attr.getValue() == mlir::TFL::kLstmCellSimple) {
    // Check if the lstm cell simple can be fused, if not, we just don't do
    // anything.
    if (failed(CheckFusableLstmCellSimple(func))) return;
    func.eraseBody();
    func.addEntryBlock();
    ConvertLSTMCellSimpleToFusedLSTM convert_lstm_cell_simple(func);
    if (failed(convert_lstm_cell_simple.RewriteFunc())) {
      return signalPassFailure();
    }
  } else if (attr.getValue() == mlir::TFL::kLayerNormalizedLstmCellSimple) {
    // Check if the layer normalized lstm cell simple can be fused, if not, we
    // just don't do anything.
    if (failed(CheckFusableLayerNormalizedLstmCellSimple(func))) return;
    func.eraseBody();
    func.addEntryBlock();
    ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM
        convert_layer_norm_lstm_cell_simple(func);
    if (failed(convert_layer_norm_lstm_cell_simple.RewriteFunc())) {
      return signalPassFailure();
    }
  } else if (attr.getValue() == kTfNMSPadded) {
    // Padded non-max-suppression becomes a TFLite custom op.
    ConvertNMSPaddedFunc convert_nms_padded(func);
    if (failed(convert_nms_padded.VerifySignature())) return;
    func.eraseBody();
    func.addEntryBlock();
    convert_nms_padded.RewriteFunc();
  } else if (attr.getValue() == kCustomDenseImageWarp) {
    // DenseImageWarp rewrites in place; no eraseBody here.
    ConvertDenseImageWarpFunc image_warping(func);
    if (failed(image_warping.VerifySignature())) return;
    if (failed(image_warping.RewriteFunc())) {
      return signalPassFailure();
    }
  }
}
298
ConvertTFImplementsWithAttributes(FuncOp func,FuncAttr attr)299 void PrepareCompositeFunctionsPass::ConvertTFImplementsWithAttributes(
300 FuncOp func, FuncAttr attr) {
301 auto api_name = attr.getName().getLeafReference();
302 bool enable_fuse_tftext =
303 fuse_tftext_flag || IsTFTextRegistered(tensorflow::OpRegistry::Global());
304 if (api_name.startswith(kTFTextAPIPrefix) && enable_fuse_tftext) {
305 if (failed(ConvertTFTextAPI(func, api_name, attr))) {
306 return signalPassFailure();
307 }
308 } else if (api_name == kCustomSSDPostprocessing) {
309 ConvertSSDPostProcessFunc convert_ssd_postprocess(func, attr);
310 if (failed(convert_ssd_postprocess.VerifySignature())) return;
311 if (failed(convert_ssd_postprocess.RewriteFunc())) {
312 return signalPassFailure();
313 }
314 } else if (api_name == kCustomMaxUnpooling) {
315 ConvertMaxUnpoolingFunc max_unpooling(func, attr);
316 if (failed(max_unpooling.VerifySignature())) return;
317 if (failed(max_unpooling.RewriteFunc())) {
318 return signalPassFailure();
319 }
320 }
321 }
322
ConvertTFAPIImplements(FuncOp func,StringAttr attr,ModuleOp module)323 void PrepareCompositeFunctionsPass::ConvertTFAPIImplements(FuncOp func,
324 StringAttr attr,
325 ModuleOp module) {
326 // Keras lstm tf.api_implements usually has attribute like "lstm_abcde91...".
327 // TODO(b/147436982): we need to make sure that only the
328 // outputs(full sequence) is used, not the last_output, not the new_states.
329 // We will discard everything except the outputs.
330 // And the outputs is in the shape of [batch, time, units].
331 if (attr.getValue().startswith("lstm_")) {
332 // Check if the keras lstm can be fused, if not, we just don't do anything.
333 if (failed(CheckFusableKerasLstm(func, module))) return;
334 func.eraseBody();
335 func.addEntryBlock();
336 OpBuilder builder(func.getBody());
337 if (failed(ConvertKerasLSTMLayer(func, &builder)))
338 return signalPassFailure();
339 }
340 }
341
runOnOperation()342 void PrepareCompositeFunctionsPass::runOnOperation() {
343 auto module = getOperation();
344 for (auto func : module.getOps<FuncOp>()) {
345 // We have three kinds of implements:
346 // 1) tf._implements, with string attributes.
347 // 2) tf._implements, with proto attributes.
348 // 3) tf.api_implements.
349 // We need to handle them separately.
350 auto tf_implements_attr_str =
351 func->getAttrOfType<StringAttr>(kTFImplements);
352 if (tf_implements_attr_str) {
353 ConvertTFImplements(func, tf_implements_attr_str);
354 continue;
355 }
356
357 auto tf_implements_attr = func->getAttrOfType<FuncAttr>(kTFImplements);
358 if (tf_implements_attr) {
359 ConvertTFImplementsWithAttributes(func, tf_implements_attr);
360 continue;
361 }
362
363 auto tf_api_implements_attr =
364 func->getAttrOfType<StringAttr>(kTFAPIImplements);
365 if (tf_api_implements_attr) {
366 // TODO(b/147536816): Keras lstm should set up the correct attributes.
367 ConvertTFAPIImplements(func, tf_api_implements_attr, module);
368 }
369 }
370 }
371 } // namespace
372
// Public factory for the pass (declared in transforms/passes.h).
std::unique_ptr<OperationPass<ModuleOp>> CreatePrepareCompositeFunctionsPass() {
  return std::make_unique<PrepareCompositeFunctionsPass>();
}
376
377 static PassRegistration<PrepareCompositeFunctionsPass> pass;
378
379 } // namespace TFL
380 } // namespace mlir
381