• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <string>
17 
18 #include "llvm/ADT/ArrayRef.h"
19 #include "llvm/ADT/None.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SetVector.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/raw_ostream.h"
26 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
27 #include "mlir/IR/Attributes.h"  // from @llvm-project
28 #include "mlir/IR/Builders.h"  // from @llvm-project
29 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
30 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
31 #include "mlir/IR/Identifier.h"  // from @llvm-project
32 #include "mlir/IR/Location.h"  // from @llvm-project
33 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
34 #include "mlir/IR/Operation.h"  // from @llvm-project
35 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
36 #include "mlir/IR/Visitors.h"  // from @llvm-project
37 #include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
38 #include "mlir/Pass/Pass.h"  // from @llvm-project
39 #include "mlir/Support/LLVM.h"  // from @llvm-project
40 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
41 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
42 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
43 #include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"
44 #include "tensorflow/compiler/mlir/lite/utils/nms_utils.h"
45 #include "tensorflow/compiler/mlir/lite/utils/perception_ops_utils.h"
46 #include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h"
47 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
48 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
49 
50 // The cmd line flag to turn on/off Tf.Text API fusion.
51 // NOLINTNEXTLINE
52 static llvm::cl::opt<bool> fuse_tftext_flag(
53     "tfl-fuse-tftext", llvm::cl::value_desc("bool"),
54     llvm::cl::desc("Fuse TF.Text API ops when it's true"),
55     llvm::cl::init(false));
56 
57 namespace mlir {
58 namespace TFL {
59 namespace {
60 
// Name of the string attribute read by the pass to find functions tagged with
// tf.api_implements (e.g. Keras LSTM layers named "lstm_...").
constexpr char kTFAPIImplements[] = "tf.api_implements";

// Prefix of tf._implements API names that denote TF.Text APIs.
constexpr char kTFTextAPIPrefix[] = "tftext:";

// API names recognized in FuncAttr-valued tf._implements attributes.
constexpr char kCustomSSDPostprocessing[] = "TFLite_Detection_PostProcess";
constexpr char kCustomMaxUnpooling[] = "addons:MaxUnpooling2D";

// API names recognized in string-valued tf._implements attributes.
constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2";
constexpr char kCustomDenseImageWarp[] = "addons:DenseImageWarp";
67 
// Lets this file name the TF dialect's function attribute type (used for
// tf._implements attributes that carry an API name plus attributes) without
// the mlir::TF:: qualifier.
using mlir::TF::FuncAttr;
69 
70 // Abstracts the conversion of the embedded lookup composite function.
71 class ConvertEmbeddedLookupFunc {
72  public:
ConvertEmbeddedLookupFunc(FuncOp func)73   explicit ConvertEmbeddedLookupFunc(FuncOp func) : func_(func) {}
74 
RewriteFunc()75   void RewriteFunc() {
76     func_->setAttr(kTFImplements,
77                    StringAttr::get(func_.getContext(), "embedding_lookup"));
78     Value lookup = func_.getArgument(1);
79     Value value = func_.getArgument(0);
80     auto output_type = func_.getType().getResult(0);
81 
82     OpBuilder builder(func_.getBody());
83     auto op = builder.create<mlir::TFL::EmbeddingLookupOp>(
84         func_.getLoc(), output_type, lookup, value);
85 
86     builder.create<mlir::ReturnOp>(func_.getLoc(), op.getResult());
87   }
88 
VerifySignature()89   LogicalResult VerifySignature() {
90     if (func_.getNumArguments() != 2) {
91       return func_.emitError()
92              << "Invalid number of arguments in the embedding "
93                 "matmul composite function";
94     }
95     if (func_.getType().getNumResults() != 1) {
96       return func_.emitError() << "Invalid number of results in the embedding "
97                                   "matmul composite function";
98     }
99     return success();
100   }
101 
102  private:
103   FuncOp func_;
104 };
105 
106 // This pass uses mechanisms listed in RFC:
107 // https://github.com/tensorflow/community/pull/113
108 // It prepares composite functions that are attributed to indicate
109 // a specific interface (LSTM, SVDF, Embedding lookup etc.) by replacing the
110 // body with the corresponding fused TFLite op. The replacement need not always
111 // be a fused op, though that is the primary use case.
112 class PrepareCompositeFunctionsPass
113     : public PassWrapper<PrepareCompositeFunctionsPass,
114                          OperationPass<ModuleOp>> {
getDependentDialects(DialectRegistry & registry) const115   void getDependentDialects(DialectRegistry& registry) const override {
116     registry.insert<TFL::TensorFlowLiteDialect>();
117   }
118 
119  public:
PrepareCompositeFunctionsPass()120   explicit PrepareCompositeFunctionsPass() {}
121 
122  private:
123   // TODO(b/160915525): Consolidate FuncAttr and StringAttr into one.
124   void ConvertTFImplements(FuncOp func, StringAttr attr);
125   void ConvertTFImplementsWithAttributes(FuncOp func, FuncAttr attr);
126   void ConvertTFAPIImplements(FuncOp func, StringAttr attr, ModuleOp module);
127   void runOnOperation() override;
128 };
129 
CheckFusableLayerNormalizedLstmCellSimple(FuncOp lstm_func)130 LogicalResult CheckFusableLayerNormalizedLstmCellSimple(FuncOp lstm_func) {
131   for (int i = 0; i < 5; ++i) {
132     auto input = lstm_func.getArgument(i);
133     auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
134     if (!input_type) {
135       lstm_func.emitWarning(
136           "we cannot fuse this lstm func because all the inputs have not "
137           "ranked tensor type.");
138       return failure();
139     }
140   }
141 
142   return success();
143 }
144 
CheckFusableLstmCellSimple(FuncOp lstm_func)145 LogicalResult CheckFusableLstmCellSimple(FuncOp lstm_func) {
146   for (int i = 0; i < 4; ++i) {
147     auto input = lstm_func.getArgument(i);
148     auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
149     if (!input_type) {
150       lstm_func.emitWarning(
151           "we cannot fuse this lstm func because all the inputs have not "
152           "ranked tensor type.");
153       return failure();
154     }
155   }
156 
157   return success();
158 }
159 
CheckOutputConsumer(Operation * call_op,int expected_num_outputs,llvm::DenseSet<int> expected_consumer_indices)160 LogicalResult CheckOutputConsumer(
161     Operation* call_op, int expected_num_outputs,
162     llvm::DenseSet<int> expected_consumer_indices) {
163   const int num_results = call_op->getNumResults();
164   if (num_results != expected_num_outputs) return failure();
165 
166   for (int i = 0; i < expected_num_outputs; ++i) {
167     auto it = expected_consumer_indices.find(i);
168     if (it == expected_consumer_indices.end()) {
169       // Unexpected consumer.
170       if (!call_op->getResult(i).use_empty()) return failure();
171     }
172   }
173   return success();
174 }
175 
CheckFusableKerasLstm(FuncOp lstm_func,ModuleOp module)176 LogicalResult CheckFusableKerasLstm(FuncOp lstm_func, ModuleOp module) {
177   for (auto func : module.getOps<FuncOp>()) {
178     if (func == lstm_func) continue;
179     auto result = func.walk([&](CallOpInterface op) {
180       if (dyn_cast<FuncOp>(op.resolveCallable()) == lstm_func) {
181         // Keras LSTM have 5 outputs.
182         // We should make sure only the first or the second output are
183         // consumed.
184         if (failed(CheckOutputConsumer(op.getOperation(), 5, {0, 1})))
185           return WalkResult::interrupt();
186       }
187       return WalkResult::advance();
188     });
189 
190     if (result.wasInterrupted()) return failure();
191   }
192 
193   // We should know the batch size in advance for the lstm fusion.
194   // A good indicator of batch size is both cell state and input state (indices
195   // 1 & 2) have fixed shape and other input tenors should have ranked tensor
196   // types.
197   for (int i = 0; i < 6; ++i) {
198     auto input = lstm_func.getArgument(i);
199     auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
200     if (!input_type) {
201       lstm_func.emitWarning(
202           "we cannot fuse this lstm func because all the inputs have not "
203           "ranked tensor type.");
204       return failure();
205     }
206     switch (i) {
207       case 1:  // output_init_state
208       case 2:  // hidden_init_state
209         if (!input_type.hasStaticShape()) {
210           lstm_func.emitWarning(
211               "we cannot fuse this lstm func because the batch size is not "
212               "fixed, please consider setting fixed batch size like "
213               "https://github.com/tensorflow/tensorflow/blob/master/tensorflow/"
214               "lite/examples/experimental_new_converter/"
215               "Keras_LSTM_fusion_Codelab.ipynb");
216           return failure();
217         }
218         break;
219       case 3:  // wiehgt
220       case 4:  // recurrent_kernel
221       case 5:  // bias
222         if (!input_type.hasStaticShape()) {
223           lstm_func.emitWarning(
224               "we cannot fuse this lstm func because the weight & bias are not "
225               "fixed, please consider setting fixed batch size like "
226               "https://github.com/tensorflow/tensorflow/blob/master/tensorflow/"
227               "lite/examples/experimental_new_converter/"
228               "Keras_LSTM_fusion_Codelab.ipynb");
229           return failure();
230         }
231         break;
232       default:
233         // No op.
234         break;
235     }
236   }
237 
238   return success();
239 }
240 
ConvertTFImplements(FuncOp func,StringAttr attr)241 void PrepareCompositeFunctionsPass::ConvertTFImplements(FuncOp func,
242                                                         StringAttr attr) {
243   if (attr.getValue() == "embedding_matmul") {
244     func.eraseBody();
245     func.addEntryBlock();
246     // Convert the composite embedding_matmul function body to a
247     // TFLite fused embedding_lookup op.
248     ConvertEmbeddedLookupFunc convert_embedded_lookup(func);
249     if (failed(convert_embedded_lookup.VerifySignature())) {
250       return signalPassFailure();
251     }
252     convert_embedded_lookup.RewriteFunc();
253   } else if (attr.getValue() == mlir::TFL::kLstmCellSimple) {
254     // Check if the lstm cell simple can be fused, if not, we just don't do
255     // anything.
256     if (failed(CheckFusableLstmCellSimple(func))) return;
257     func.eraseBody();
258     func.addEntryBlock();
259     ConvertLSTMCellSimpleToFusedLSTM convert_lstm_cell_simple(func);
260     if (failed(convert_lstm_cell_simple.RewriteFunc())) {
261       return signalPassFailure();
262     }
263   } else if (attr.getValue() == mlir::TFL::kLayerNormalizedLstmCellSimple) {
264     // Check if the layer normalized lstm cell simple can be fused, if not, we
265     // just don't do anything.
266     if (failed(CheckFusableLayerNormalizedLstmCellSimple(func))) return;
267     func.eraseBody();
268     func.addEntryBlock();
269     ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM
270         convert_layer_norm_lstm_cell_simple(func);
271     if (failed(convert_layer_norm_lstm_cell_simple.RewriteFunc())) {
272       return signalPassFailure();
273     }
274   } else if (attr.getValue() == kTfNMSPadded) {
275     func.eraseBody();
276     func.addEntryBlock();
277     ConvertNMSPaddedFunc convert_nms_padded(func);
278     if (failed(convert_nms_padded.VerifySignature())) {
279       return signalPassFailure();
280     }
281     convert_nms_padded.RewriteFunc();
282   } else if (attr.getValue() == kCustomDenseImageWarp) {
283     ConvertDenseImageWarpFunc image_warping(func);
284     if (failed(image_warping.VerifySignature()) ||
285         failed(image_warping.RewriteFunc())) {
286       return signalPassFailure();
287     }
288   }
289 }
290 
ConvertTFImplementsWithAttributes(FuncOp func,FuncAttr attr)291 void PrepareCompositeFunctionsPass::ConvertTFImplementsWithAttributes(
292     FuncOp func, FuncAttr attr) {
293   auto api_name = attr.GetName().getLeafReference();
294   bool enable_fuse_tftext =
295       fuse_tftext_flag || IsTFTextRegistered(tensorflow::OpRegistry::Global());
296   if (api_name.startswith(kTFTextAPIPrefix) && enable_fuse_tftext) {
297     if (failed(ConvertTFTextAPI(func, api_name, attr))) {
298       return signalPassFailure();
299     }
300   } else if (api_name == kCustomSSDPostprocessing) {
301     ConvertSSDPostProcessFunc convert_ssd_postprocess(func, attr);
302     if (failed(convert_ssd_postprocess.VerifySignature()) ||
303         failed(convert_ssd_postprocess.RewriteFunc())) {
304       return signalPassFailure();
305     }
306   } else if (api_name == kCustomMaxUnpooling) {
307     ConvertMaxUnpoolingFunc max_unpooling(func, attr);
308     if (failed(max_unpooling.VerifySignature()) ||
309         failed(max_unpooling.RewriteFunc())) {
310       return signalPassFailure();
311     }
312   }
313 }
314 
ConvertTFAPIImplements(FuncOp func,StringAttr attr,ModuleOp module)315 void PrepareCompositeFunctionsPass::ConvertTFAPIImplements(FuncOp func,
316                                                            StringAttr attr,
317                                                            ModuleOp module) {
318   // Keras lstm tf.api_implements usually has attribute like "lstm_abcde91...".
319   // TODO(b/147436982): we need to make sure that only the
320   // outputs(full sequence) is used, not the last_output, not the new_states.
321   // We will discard everything except the outputs.
322   // And the outputs is in the shape of [batch, time, units].
323   if (attr.getValue().startswith("lstm_")) {
324     // Check if the keras lstm can be fused, if not, we just don't do anything.
325     if (failed(CheckFusableKerasLstm(func, module))) return;
326 
327     func.eraseBody();
328     func.addEntryBlock();
329 
330     OpBuilder builder(func.getBody());
331     if (failed(ConvertKerasLSTMLayer(func, &builder)))
332       return signalPassFailure();
333   }
334 }
335 
runOnOperation()336 void PrepareCompositeFunctionsPass::runOnOperation() {
337   auto module = getOperation();
338   for (auto func : module.getOps<FuncOp>()) {
339     // We have three kinds of implements:
340     // 1) tf._implements, with string attributes.
341     // 2) tf._implements, with proto attributes.
342     // 3) tf.api_implements.
343     // We need to handle them separately.
344     auto tf_implements_attr_str =
345         func->getAttrOfType<StringAttr>(kTFImplements);
346     if (tf_implements_attr_str) {
347       ConvertTFImplements(func, tf_implements_attr_str);
348       continue;
349     }
350 
351     auto tf_implements_attr = func->getAttrOfType<FuncAttr>(kTFImplements);
352     if (tf_implements_attr) {
353       ConvertTFImplementsWithAttributes(func, tf_implements_attr);
354       continue;
355     }
356 
357     auto tf_api_implements_attr =
358         func->getAttrOfType<StringAttr>(kTFAPIImplements);
359     if (tf_api_implements_attr) {
360       // TODO(b/147536816): Keras lstm should set up the correct attributes.
361       ConvertTFAPIImplements(func, tf_api_implements_attr, module);
362     }
363   }
364 }
365 }  // namespace
366 
CreatePrepareCompositeFunctionsPass()367 std::unique_ptr<OperationPass<ModuleOp>> CreatePrepareCompositeFunctionsPass() {
368   return std::make_unique<PrepareCompositeFunctionsPass>();
369 }
370 
371 static PassRegistration<PrepareCompositeFunctionsPass> pass(
372     "tfl-prepare-composite-funcs-tf",
373     "Prepares composite functions in Tensorflow dialect of MLIR ");
374 
375 }  // namespace TFL
376 }  // namespace mlir
377