1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <string>
17
18 #include "llvm/ADT/ArrayRef.h"
19 #include "llvm/ADT/None.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SetVector.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/raw_ostream.h"
26 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
27 #include "mlir/IR/Attributes.h" // from @llvm-project
28 #include "mlir/IR/Builders.h" // from @llvm-project
29 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
30 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
31 #include "mlir/IR/Identifier.h" // from @llvm-project
32 #include "mlir/IR/Location.h" // from @llvm-project
33 #include "mlir/IR/MLIRContext.h" // from @llvm-project
34 #include "mlir/IR/Operation.h" // from @llvm-project
35 #include "mlir/IR/SymbolTable.h" // from @llvm-project
36 #include "mlir/IR/Visitors.h" // from @llvm-project
37 #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
38 #include "mlir/Pass/Pass.h" // from @llvm-project
39 #include "mlir/Support/LLVM.h" // from @llvm-project
40 #include "mlir/Support/LogicalResult.h" // from @llvm-project
41 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
42 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
43 #include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"
44 #include "tensorflow/compiler/mlir/lite/utils/nms_utils.h"
45 #include "tensorflow/compiler/mlir/lite/utils/perception_ops_utils.h"
46 #include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h"
47 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
48 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
49
50 // The cmd line flag to turn on/off Tf.Text API fusion.
51 // NOLINTNEXTLINE
52 static llvm::cl::opt<bool> fuse_tftext_flag(
53 "tfl-fuse-tftext", llvm::cl::value_desc("bool"),
54 llvm::cl::desc("Fuse TF.Text API ops when it's true"),
55 llvm::cl::init(false));
56
57 namespace mlir {
58 namespace TFL {
59 namespace {
60
61 constexpr char kTFAPIImplements[] = "tf.api_implements";
62 constexpr char kTFTextAPIPrefix[] = "tftext:";
63 constexpr char kCustomSSDPostprocessing[] = "TFLite_Detection_PostProcess";
64 constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2";
65 constexpr char kCustomMaxUnpooling[] = "addons:MaxUnpooling2D";
66 constexpr char kCustomDenseImageWarp[] = "addons:DenseImageWarp";
67
68 using mlir::TF::FuncAttr;
69
70 // Abstracts the conversion of the embedded lookup composite function.
71 class ConvertEmbeddedLookupFunc {
72 public:
ConvertEmbeddedLookupFunc(FuncOp func)73 explicit ConvertEmbeddedLookupFunc(FuncOp func) : func_(func) {}
74
RewriteFunc()75 void RewriteFunc() {
76 func_->setAttr(kTFImplements,
77 StringAttr::get(func_.getContext(), "embedding_lookup"));
78 Value lookup = func_.getArgument(1);
79 Value value = func_.getArgument(0);
80 auto output_type = func_.getType().getResult(0);
81
82 OpBuilder builder(func_.getBody());
83 auto op = builder.create<mlir::TFL::EmbeddingLookupOp>(
84 func_.getLoc(), output_type, lookup, value);
85
86 builder.create<mlir::ReturnOp>(func_.getLoc(), op.getResult());
87 }
88
VerifySignature()89 LogicalResult VerifySignature() {
90 if (func_.getNumArguments() != 2) {
91 return func_.emitError()
92 << "Invalid number of arguments in the embedding "
93 "matmul composite function";
94 }
95 if (func_.getType().getNumResults() != 1) {
96 return func_.emitError() << "Invalid number of results in the embedding "
97 "matmul composite function";
98 }
99 return success();
100 }
101
102 private:
103 FuncOp func_;
104 };
105
106 // This pass uses mechanisms listed in RFC:
107 // https://github.com/tensorflow/community/pull/113
108 // It prepares composite functions that are attributed to indicate
109 // a specific interface (LSTM, SVDF, Embedding lookup etc.) by replacing the
110 // body with the corresponding fused TFLite op. The replacement need not always
111 // be a fused op, though that is the primary use case.
112 class PrepareCompositeFunctionsPass
113 : public PassWrapper<PrepareCompositeFunctionsPass,
114 OperationPass<ModuleOp>> {
getDependentDialects(DialectRegistry & registry) const115 void getDependentDialects(DialectRegistry& registry) const override {
116 registry.insert<TFL::TensorFlowLiteDialect>();
117 }
118
119 public:
PrepareCompositeFunctionsPass()120 explicit PrepareCompositeFunctionsPass() {}
121
122 private:
123 // TODO(b/160915525): Consolidate FuncAttr and StringAttr into one.
124 void ConvertTFImplements(FuncOp func, StringAttr attr);
125 void ConvertTFImplementsWithAttributes(FuncOp func, FuncAttr attr);
126 void ConvertTFAPIImplements(FuncOp func, StringAttr attr, ModuleOp module);
127 void runOnOperation() override;
128 };
129
CheckFusableLayerNormalizedLstmCellSimple(FuncOp lstm_func)130 LogicalResult CheckFusableLayerNormalizedLstmCellSimple(FuncOp lstm_func) {
131 for (int i = 0; i < 5; ++i) {
132 auto input = lstm_func.getArgument(i);
133 auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
134 if (!input_type) {
135 lstm_func.emitWarning(
136 "we cannot fuse this lstm func because all the inputs have not "
137 "ranked tensor type.");
138 return failure();
139 }
140 }
141
142 return success();
143 }
144
CheckFusableLstmCellSimple(FuncOp lstm_func)145 LogicalResult CheckFusableLstmCellSimple(FuncOp lstm_func) {
146 for (int i = 0; i < 4; ++i) {
147 auto input = lstm_func.getArgument(i);
148 auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
149 if (!input_type) {
150 lstm_func.emitWarning(
151 "we cannot fuse this lstm func because all the inputs have not "
152 "ranked tensor type.");
153 return failure();
154 }
155 }
156
157 return success();
158 }
159
CheckOutputConsumer(Operation * call_op,int expected_num_outputs,llvm::DenseSet<int> expected_consumer_indices)160 LogicalResult CheckOutputConsumer(
161 Operation* call_op, int expected_num_outputs,
162 llvm::DenseSet<int> expected_consumer_indices) {
163 const int num_results = call_op->getNumResults();
164 if (num_results != expected_num_outputs) return failure();
165
166 for (int i = 0; i < expected_num_outputs; ++i) {
167 auto it = expected_consumer_indices.find(i);
168 if (it == expected_consumer_indices.end()) {
169 // Unexpected consumer.
170 if (!call_op->getResult(i).use_empty()) return failure();
171 }
172 }
173 return success();
174 }
175
CheckFusableKerasLstm(FuncOp lstm_func,ModuleOp module)176 LogicalResult CheckFusableKerasLstm(FuncOp lstm_func, ModuleOp module) {
177 for (auto func : module.getOps<FuncOp>()) {
178 if (func == lstm_func) continue;
179 auto result = func.walk([&](CallOpInterface op) {
180 if (dyn_cast<FuncOp>(op.resolveCallable()) == lstm_func) {
181 // Keras LSTM have 5 outputs.
182 // We should make sure only the first or the second output are
183 // consumed.
184 if (failed(CheckOutputConsumer(op.getOperation(), 5, {0, 1})))
185 return WalkResult::interrupt();
186 }
187 return WalkResult::advance();
188 });
189
190 if (result.wasInterrupted()) return failure();
191 }
192
193 // We should know the batch size in advance for the lstm fusion.
194 // A good indicator of batch size is both cell state and input state (indices
195 // 1 & 2) have fixed shape and other input tenors should have ranked tensor
196 // types.
197 for (int i = 0; i < 6; ++i) {
198 auto input = lstm_func.getArgument(i);
199 auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
200 if (!input_type) {
201 lstm_func.emitWarning(
202 "we cannot fuse this lstm func because all the inputs have not "
203 "ranked tensor type.");
204 return failure();
205 }
206 switch (i) {
207 case 1: // output_init_state
208 case 2: // hidden_init_state
209 if (!input_type.hasStaticShape()) {
210 lstm_func.emitWarning(
211 "we cannot fuse this lstm func because the batch size is not "
212 "fixed, please consider setting fixed batch size like "
213 "https://github.com/tensorflow/tensorflow/blob/master/tensorflow/"
214 "lite/examples/experimental_new_converter/"
215 "Keras_LSTM_fusion_Codelab.ipynb");
216 return failure();
217 }
218 break;
219 case 3: // wiehgt
220 case 4: // recurrent_kernel
221 case 5: // bias
222 if (!input_type.hasStaticShape()) {
223 lstm_func.emitWarning(
224 "we cannot fuse this lstm func because the weight & bias are not "
225 "fixed, please consider setting fixed batch size like "
226 "https://github.com/tensorflow/tensorflow/blob/master/tensorflow/"
227 "lite/examples/experimental_new_converter/"
228 "Keras_LSTM_fusion_Codelab.ipynb");
229 return failure();
230 }
231 break;
232 default:
233 // No op.
234 break;
235 }
236 }
237
238 return success();
239 }
240
ConvertTFImplements(FuncOp func,StringAttr attr)241 void PrepareCompositeFunctionsPass::ConvertTFImplements(FuncOp func,
242 StringAttr attr) {
243 if (attr.getValue() == "embedding_matmul") {
244 func.eraseBody();
245 func.addEntryBlock();
246 // Convert the composite embedding_matmul function body to a
247 // TFLite fused embedding_lookup op.
248 ConvertEmbeddedLookupFunc convert_embedded_lookup(func);
249 if (failed(convert_embedded_lookup.VerifySignature())) {
250 return signalPassFailure();
251 }
252 convert_embedded_lookup.RewriteFunc();
253 } else if (attr.getValue() == mlir::TFL::kLstmCellSimple) {
254 // Check if the lstm cell simple can be fused, if not, we just don't do
255 // anything.
256 if (failed(CheckFusableLstmCellSimple(func))) return;
257 func.eraseBody();
258 func.addEntryBlock();
259 ConvertLSTMCellSimpleToFusedLSTM convert_lstm_cell_simple(func);
260 if (failed(convert_lstm_cell_simple.RewriteFunc())) {
261 return signalPassFailure();
262 }
263 } else if (attr.getValue() == mlir::TFL::kLayerNormalizedLstmCellSimple) {
264 // Check if the layer normalized lstm cell simple can be fused, if not, we
265 // just don't do anything.
266 if (failed(CheckFusableLayerNormalizedLstmCellSimple(func))) return;
267 func.eraseBody();
268 func.addEntryBlock();
269 ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM
270 convert_layer_norm_lstm_cell_simple(func);
271 if (failed(convert_layer_norm_lstm_cell_simple.RewriteFunc())) {
272 return signalPassFailure();
273 }
274 } else if (attr.getValue() == kTfNMSPadded) {
275 func.eraseBody();
276 func.addEntryBlock();
277 ConvertNMSPaddedFunc convert_nms_padded(func);
278 if (failed(convert_nms_padded.VerifySignature())) {
279 return signalPassFailure();
280 }
281 convert_nms_padded.RewriteFunc();
282 } else if (attr.getValue() == kCustomDenseImageWarp) {
283 ConvertDenseImageWarpFunc image_warping(func);
284 if (failed(image_warping.VerifySignature()) ||
285 failed(image_warping.RewriteFunc())) {
286 return signalPassFailure();
287 }
288 }
289 }
290
ConvertTFImplementsWithAttributes(FuncOp func,FuncAttr attr)291 void PrepareCompositeFunctionsPass::ConvertTFImplementsWithAttributes(
292 FuncOp func, FuncAttr attr) {
293 auto api_name = attr.GetName().getLeafReference();
294 bool enable_fuse_tftext =
295 fuse_tftext_flag || IsTFTextRegistered(tensorflow::OpRegistry::Global());
296 if (api_name.startswith(kTFTextAPIPrefix) && enable_fuse_tftext) {
297 if (failed(ConvertTFTextAPI(func, api_name, attr))) {
298 return signalPassFailure();
299 }
300 } else if (api_name == kCustomSSDPostprocessing) {
301 ConvertSSDPostProcessFunc convert_ssd_postprocess(func, attr);
302 if (failed(convert_ssd_postprocess.VerifySignature()) ||
303 failed(convert_ssd_postprocess.RewriteFunc())) {
304 return signalPassFailure();
305 }
306 } else if (api_name == kCustomMaxUnpooling) {
307 ConvertMaxUnpoolingFunc max_unpooling(func, attr);
308 if (failed(max_unpooling.VerifySignature()) ||
309 failed(max_unpooling.RewriteFunc())) {
310 return signalPassFailure();
311 }
312 }
313 }
314
ConvertTFAPIImplements(FuncOp func,StringAttr attr,ModuleOp module)315 void PrepareCompositeFunctionsPass::ConvertTFAPIImplements(FuncOp func,
316 StringAttr attr,
317 ModuleOp module) {
318 // Keras lstm tf.api_implements usually has attribute like "lstm_abcde91...".
319 // TODO(b/147436982): we need to make sure that only the
320 // outputs(full sequence) is used, not the last_output, not the new_states.
321 // We will discard everything except the outputs.
322 // And the outputs is in the shape of [batch, time, units].
323 if (attr.getValue().startswith("lstm_")) {
324 // Check if the keras lstm can be fused, if not, we just don't do anything.
325 if (failed(CheckFusableKerasLstm(func, module))) return;
326
327 func.eraseBody();
328 func.addEntryBlock();
329
330 OpBuilder builder(func.getBody());
331 if (failed(ConvertKerasLSTMLayer(func, &builder)))
332 return signalPassFailure();
333 }
334 }
335
runOnOperation()336 void PrepareCompositeFunctionsPass::runOnOperation() {
337 auto module = getOperation();
338 for (auto func : module.getOps<FuncOp>()) {
339 // We have three kinds of implements:
340 // 1) tf._implements, with string attributes.
341 // 2) tf._implements, with proto attributes.
342 // 3) tf.api_implements.
343 // We need to handle them separately.
344 auto tf_implements_attr_str =
345 func->getAttrOfType<StringAttr>(kTFImplements);
346 if (tf_implements_attr_str) {
347 ConvertTFImplements(func, tf_implements_attr_str);
348 continue;
349 }
350
351 auto tf_implements_attr = func->getAttrOfType<FuncAttr>(kTFImplements);
352 if (tf_implements_attr) {
353 ConvertTFImplementsWithAttributes(func, tf_implements_attr);
354 continue;
355 }
356
357 auto tf_api_implements_attr =
358 func->getAttrOfType<StringAttr>(kTFAPIImplements);
359 if (tf_api_implements_attr) {
360 // TODO(b/147536816): Keras lstm should set up the correct attributes.
361 ConvertTFAPIImplements(func, tf_api_implements_attr, module);
362 }
363 }
364 }
365 } // namespace
366
CreatePrepareCompositeFunctionsPass()367 std::unique_ptr<OperationPass<ModuleOp>> CreatePrepareCompositeFunctionsPass() {
368 return std::make_unique<PrepareCompositeFunctionsPass>();
369 }
370
371 static PassRegistration<PrepareCompositeFunctionsPass> pass(
372 "tfl-prepare-composite-funcs-tf",
373 "Prepares composite functions in Tensorflow dialect of MLIR ");
374
375 } // namespace TFL
376 } // namespace mlir
377