• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <string>
17 
18 #include "llvm/ADT/ArrayRef.h"
19 #include "llvm/ADT/None.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SetVector.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/raw_ostream.h"
26 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
27 #include "mlir/IR/Attributes.h"  // from @llvm-project
28 #include "mlir/IR/Builders.h"  // from @llvm-project
29 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
30 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
31 #include "mlir/IR/Identifier.h"  // from @llvm-project
32 #include "mlir/IR/Location.h"  // from @llvm-project
33 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
34 #include "mlir/IR/Operation.h"  // from @llvm-project
35 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
36 #include "mlir/IR/Visitors.h"  // from @llvm-project
37 #include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
38 #include "mlir/Pass/Pass.h"  // from @llvm-project
39 #include "mlir/Support/LLVM.h"  // from @llvm-project
40 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
41 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
42 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
43 #include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"
44 #include "tensorflow/compiler/mlir/lite/utils/nms_utils.h"
45 #include "tensorflow/compiler/mlir/lite/utils/perception_ops_utils.h"
46 #include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h"
47 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
48 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
49 
// The cmd line flag to turn on/off Tf.Text API fusion.
// When set, TF.Text composite functions are fused even if the TF.Text ops are
// not registered in the global op registry (see
// ConvertTFImplementsWithAttributes, which ORs this flag with
// IsTFTextRegistered). Defaults to off.
// NOLINTNEXTLINE
static llvm::cl::opt<bool> fuse_tftext_flag(
    "tfl-fuse-tftext", llvm::cl::value_desc("bool"),
    llvm::cl::desc("Fuse TF.Text API ops when it's true"),
    llvm::cl::init(false));
56 
57 namespace mlir {
58 namespace TFL {
59 namespace {
60 
// Name of the function attribute used by Keras-style composite functions;
// handled by ConvertTFAPIImplements.
constexpr char kTFAPIImplements[] = "tf.api_implements";
// Prefix of tf.api_implements values that identify TF.Text composite
// functions.
constexpr char kTFTextAPIPrefix[] = "tftext:";
// Implements-attribute values recognized by this pass and rewritten to fused
// or custom TFLite ops.
constexpr char kCustomSSDPostprocessing[] = "TFLite_Detection_PostProcess";
constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2";
constexpr char kCustomMaxUnpooling[] = "addons:MaxUnpooling2D";
constexpr char kCustomDenseImageWarp[] = "addons:DenseImageWarp";

using mlir::TF::FuncAttr;
69 
70 // Abstracts the conversion of the embedded lookup composite function.
71 class ConvertEmbeddedLookupFunc {
72  public:
ConvertEmbeddedLookupFunc(FuncOp func)73   explicit ConvertEmbeddedLookupFunc(FuncOp func) : func_(func) {}
74 
RewriteFunc()75   void RewriteFunc() {
76     func_->setAttr(kTFImplements,
77                    StringAttr::get(func_.getContext(), "embedding_lookup"));
78     Value lookup = func_.getArgument(1);
79     Value value = func_.getArgument(0);
80     auto output_type = func_.getType().getResult(0);
81 
82     OpBuilder builder(func_.getBody());
83     auto op = builder.create<mlir::TFL::EmbeddingLookupOp>(
84         func_.getLoc(), output_type, lookup, value);
85 
86     builder.create<mlir::ReturnOp>(func_.getLoc(), op.getResult());
87   }
88 
VerifySignature()89   LogicalResult VerifySignature() {
90     if (func_.getNumArguments() != 2) {
91       return func_.emitWarning()
92              << "Invalid number of arguments in the embedding "
93                 "matmul composite function";
94     }
95     if (func_.getType().getNumResults() != 1) {
96       return func_.emitWarning() << "Invalid number of results in the "
97                                     "embedding matmul composite function";
98     }
99     return success();
100   }
101 
102  private:
103   FuncOp func_;
104 };
105 
// This pass uses mechanisms listed in RFC:
// https://github.com/tensorflow/community/pull/113
// It prepares composite functions that are attributed to indicate
// a specific interface (LSTM, SVDF, Embedding lookup etc.) by replacing the
// body with the corresponding fused TFLite op. The replacement need not always
// be a fused op, though that is the primary use case.
class PrepareCompositeFunctionsPass
    : public PassWrapper<PrepareCompositeFunctionsPass,
                         OperationPass<ModuleOp>> {
  // Declares the TFLite dialect so the fused TFL ops created by this pass can
  // be built even if nothing else has loaded the dialect yet.
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<TFL::TensorFlowLiteDialect>();
  }

 public:
  explicit PrepareCompositeFunctionsPass() {}

  StringRef getArgument() const final {
    // This is the argument used to refer to the pass in
    // the textual format (on the commandline for example).
    return "tfl-prepare-composite-funcs-tf";
  }
  StringRef getDescription() const final {
    // This is a brief description of the pass.
    return "Prepares composite functions in Tensorflow dialect of MLIR";
  }

 private:
  // TODO(b/160915525): Consolidate FuncAttr and StringAttr into one.
  // Handlers for the three flavors of implements annotations; see
  // runOnOperation for how a function is dispatched to exactly one of them.
  void ConvertTFImplements(FuncOp func, StringAttr attr);
  void ConvertTFImplementsWithAttributes(FuncOp func, FuncAttr attr);
  void ConvertTFAPIImplements(FuncOp func, StringAttr attr, ModuleOp module);
  void runOnOperation() override;
};
139 
CheckFusableLayerNormalizedLstmCellSimple(FuncOp lstm_func)140 LogicalResult CheckFusableLayerNormalizedLstmCellSimple(FuncOp lstm_func) {
141   for (int i = 0; i < 5; ++i) {
142     auto input = lstm_func.getArgument(i);
143     auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
144     if (!input_type) {
145       lstm_func.emitWarning(
146           "we cannot fuse this lstm func because all the inputs have not "
147           "ranked tensor type.");
148       return failure();
149     }
150   }
151 
152   return success();
153 }
154 
CheckFusableLstmCellSimple(FuncOp lstm_func)155 LogicalResult CheckFusableLstmCellSimple(FuncOp lstm_func) {
156   for (int i = 0; i < 4; ++i) {
157     auto input = lstm_func.getArgument(i);
158     auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
159     if (!input_type) {
160       lstm_func.emitWarning(
161           "we cannot fuse this lstm func because all the inputs have not "
162           "ranked tensor type.");
163       return failure();
164     }
165   }
166 
167   return success();
168 }
169 
CheckOutputConsumer(Operation * call_op,int expected_num_outputs,llvm::DenseSet<int> expected_consumer_indices)170 LogicalResult CheckOutputConsumer(
171     Operation* call_op, int expected_num_outputs,
172     llvm::DenseSet<int> expected_consumer_indices) {
173   const int num_results = call_op->getNumResults();
174   if (num_results != expected_num_outputs) return failure();
175 
176   for (int i = 0; i < expected_num_outputs; ++i) {
177     auto it = expected_consumer_indices.find(i);
178     if (it == expected_consumer_indices.end()) {
179       // Unexpected consumer.
180       if (!call_op->getResult(i).use_empty()) return failure();
181     }
182   }
183   return success();
184 }
185 
// Decides whether a Keras LSTM composite function can be fused into a
// UnidirectionalSequenceLSTMOp. Returns failure (without mutating anything)
// when fusion would be unsafe.
LogicalResult CheckFusableKerasLstm(FuncOp lstm_func, ModuleOp module) {
  // Inspect every other function in the module for calls to the LSTM func.
  for (auto func : module.getOps<FuncOp>()) {
    if (func == lstm_func) continue;
    auto result = func.walk([&](CallOpInterface op) {
      if (dyn_cast<FuncOp>(op.resolveCallable()) == lstm_func) {
        // Keras LSTM have 5 outputs.
        // We should make sure only the first or the second output are
        // consumed.
        if (failed(CheckOutputConsumer(op.getOperation(), 5, {0, 1})))
          return WalkResult::interrupt();
      }
      return WalkResult::advance();
    });

    if (result.wasInterrupted()) return failure();
  }
  // Current UnidirectionalSequenceLSTMOp doesn't support mask input.
  if (lstm_func.getNumArguments() == 7) return failure();

  // We should know the batch size in advance for the lstm fusion.
  // A good indicator of batch size is both cell state and input state (indices
  // 1 & 2) have fixed shape and other input tensors should have ranked tensor
  // types.
  for (int i = 0; i < 6; ++i) {
    auto input = lstm_func.getArgument(i);
    auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
    if (!input_type) {
      lstm_func.emitWarning(
          "we cannot fuse this lstm func because all the inputs have not "
          "ranked tensor type.");
      return failure();
    }
    switch (i) {
      case 1:  // output_init_state
      case 2:  // hidden_init_state
        // The init states must be statically shaped so the batch size is
        // known at conversion time.
        if (!input_type.hasStaticShape()) {
          lstm_func.emitWarning(
              "we cannot fuse this lstm func because the batch size is not "
              "fixed, please consider setting fixed batch size like "
              "https://github.com/tensorflow/tensorflow/blob/master/tensorflow/"
              "lite/examples/experimental_new_converter/"
              "Keras_LSTM_fusion_Codelab.ipynb");
          return failure();
        }
        break;
      case 3:  // weight
      case 4:  // recurrent_kernel
      case 5:  // bias
        // Weights and biases must also have fully static shapes for fusion.
        if (!input_type.hasStaticShape()) {
          lstm_func.emitWarning(
              "we cannot fuse this lstm func because the weight & bias are not "
              "fixed, please consider setting fixed batch size like "
              "https://github.com/tensorflow/tensorflow/blob/master/tensorflow/"
              "lite/examples/experimental_new_converter/"
              "Keras_LSTM_fusion_Codelab.ipynb");
          return failure();
        }
        break;
      default:
        // No op.
        break;
    }
  }

  return success();
}
252 
ConvertTFImplements(FuncOp func,StringAttr attr)253 void PrepareCompositeFunctionsPass::ConvertTFImplements(FuncOp func,
254                                                         StringAttr attr) {
255   if (attr.getValue() == "embedding_matmul") {
256     // Convert the composite embedding_matmul function body to a
257     // TFLite fused embedding_lookup op.
258     ConvertEmbeddedLookupFunc convert_embedded_lookup(func);
259     if (failed(convert_embedded_lookup.VerifySignature())) return;
260     func.eraseBody();
261     func.addEntryBlock();
262     convert_embedded_lookup.RewriteFunc();
263   } else if (attr.getValue() == mlir::TFL::kLstmCellSimple) {
264     // Check if the lstm cell simple can be fused, if not, we just don't do
265     // anything.
266     if (failed(CheckFusableLstmCellSimple(func))) return;
267     func.eraseBody();
268     func.addEntryBlock();
269     ConvertLSTMCellSimpleToFusedLSTM convert_lstm_cell_simple(func);
270     if (failed(convert_lstm_cell_simple.RewriteFunc())) {
271       return signalPassFailure();
272     }
273   } else if (attr.getValue() == mlir::TFL::kLayerNormalizedLstmCellSimple) {
274     // Check if the layer normalized lstm cell simple can be fused, if not, we
275     // just don't do anything.
276     if (failed(CheckFusableLayerNormalizedLstmCellSimple(func))) return;
277     func.eraseBody();
278     func.addEntryBlock();
279     ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM
280         convert_layer_norm_lstm_cell_simple(func);
281     if (failed(convert_layer_norm_lstm_cell_simple.RewriteFunc())) {
282       return signalPassFailure();
283     }
284   } else if (attr.getValue() == kTfNMSPadded) {
285     ConvertNMSPaddedFunc convert_nms_padded(func);
286     if (failed(convert_nms_padded.VerifySignature())) return;
287     func.eraseBody();
288     func.addEntryBlock();
289     convert_nms_padded.RewriteFunc();
290   } else if (attr.getValue() == kCustomDenseImageWarp) {
291     ConvertDenseImageWarpFunc image_warping(func);
292     if (failed(image_warping.VerifySignature())) return;
293     if (failed(image_warping.RewriteFunc())) {
294       return signalPassFailure();
295     }
296   }
297 }
298 
ConvertTFImplementsWithAttributes(FuncOp func,FuncAttr attr)299 void PrepareCompositeFunctionsPass::ConvertTFImplementsWithAttributes(
300     FuncOp func, FuncAttr attr) {
301   auto api_name = attr.getName().getLeafReference();
302   bool enable_fuse_tftext =
303       fuse_tftext_flag || IsTFTextRegistered(tensorflow::OpRegistry::Global());
304   if (api_name.startswith(kTFTextAPIPrefix) && enable_fuse_tftext) {
305     if (failed(ConvertTFTextAPI(func, api_name, attr))) {
306       return signalPassFailure();
307     }
308   } else if (api_name == kCustomSSDPostprocessing) {
309     ConvertSSDPostProcessFunc convert_ssd_postprocess(func, attr);
310     if (failed(convert_ssd_postprocess.VerifySignature())) return;
311     if (failed(convert_ssd_postprocess.RewriteFunc())) {
312       return signalPassFailure();
313     }
314   } else if (api_name == kCustomMaxUnpooling) {
315     ConvertMaxUnpoolingFunc max_unpooling(func, attr);
316     if (failed(max_unpooling.VerifySignature())) return;
317     if (failed(max_unpooling.RewriteFunc())) {
318       return signalPassFailure();
319     }
320   }
321 }
322 
ConvertTFAPIImplements(FuncOp func,StringAttr attr,ModuleOp module)323 void PrepareCompositeFunctionsPass::ConvertTFAPIImplements(FuncOp func,
324                                                            StringAttr attr,
325                                                            ModuleOp module) {
326   // Keras lstm tf.api_implements usually has attribute like "lstm_abcde91...".
327   // TODO(b/147436982): we need to make sure that only the
328   // outputs(full sequence) is used, not the last_output, not the new_states.
329   // We will discard everything except the outputs.
330   // And the outputs is in the shape of [batch, time, units].
331   if (attr.getValue().startswith("lstm_")) {
332     // Check if the keras lstm can be fused, if not, we just don't do anything.
333     if (failed(CheckFusableKerasLstm(func, module))) return;
334     func.eraseBody();
335     func.addEntryBlock();
336     OpBuilder builder(func.getBody());
337     if (failed(ConvertKerasLSTMLayer(func, &builder)))
338       return signalPassFailure();
339   }
340 }
341 
runOnOperation()342 void PrepareCompositeFunctionsPass::runOnOperation() {
343   auto module = getOperation();
344   for (auto func : module.getOps<FuncOp>()) {
345     // We have three kinds of implements:
346     // 1) tf._implements, with string attributes.
347     // 2) tf._implements, with proto attributes.
348     // 3) tf.api_implements.
349     // We need to handle them separately.
350     auto tf_implements_attr_str =
351         func->getAttrOfType<StringAttr>(kTFImplements);
352     if (tf_implements_attr_str) {
353       ConvertTFImplements(func, tf_implements_attr_str);
354       continue;
355     }
356 
357     auto tf_implements_attr = func->getAttrOfType<FuncAttr>(kTFImplements);
358     if (tf_implements_attr) {
359       ConvertTFImplementsWithAttributes(func, tf_implements_attr);
360       continue;
361     }
362 
363     auto tf_api_implements_attr =
364         func->getAttrOfType<StringAttr>(kTFAPIImplements);
365     if (tf_api_implements_attr) {
366       // TODO(b/147536816): Keras lstm should set up the correct attributes.
367       ConvertTFAPIImplements(func, tf_api_implements_attr, module);
368     }
369   }
370 }
371 }  // namespace
372 
// Returns a fresh instance of the pass; this is the public factory declared
// in transforms/passes.h.
std::unique_ptr<OperationPass<ModuleOp>> CreatePrepareCompositeFunctionsPass() {
  return std::make_unique<PrepareCompositeFunctionsPass>();
}

// Registers the pass so it can be referenced from textual pass pipelines via
// its getArgument() name.
static PassRegistration<PrepareCompositeFunctionsPass> pass;
378 
379 }  // namespace TFL
380 }  // namespace mlir
381