• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
17 #include "mlir/Pass/Pass.h"  // from @llvm-project
18 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
19 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
20 
21 namespace mlir {
22 namespace TFL {
23 namespace {
24 
25 // This pass inserts a TFL::CallOnce op when tf_saved_model's session
26 // initializer is given.
27 class InsertCallOnceOpFromSessionInitializerPass
28     : public mlir::PassWrapper<InsertCallOnceOpFromSessionInitializerPass,
29                                OperationPass<ModuleOp>> {
getDependentDialects(DialectRegistry & registry) const30   void getDependentDialects(DialectRegistry &registry) const override {
31     registry.insert<TensorFlowLiteDialect>();
32   }
33 
getArgument() const34   StringRef getArgument() const final {
35     // This is the argument used to refer to the pass in
36     // the textual format (on the commandline for example).
37     return "tfl-insert-call-once-op";
38   }
getDescription() const39   StringRef getDescription() const final {
40     // This is a brief description of the pass.
41     return "Insert CallOnce op when tf_saved_model's session initializer is "
42            "given";
43   }
44 
45  private:
46   void runOnOperation() override;
47 };
48 
runOnOperation()49 void InsertCallOnceOpFromSessionInitializerPass::runOnOperation() {
50   ModuleOp module = getOperation();
51   tf_saved_model::SessionInitializerOp session_init_op =
52       tf_saved_model::GetSessionInitializerOp(module);
53 
54   if (!session_init_op) return;
55 
56   SymbolTable symbol_table(module);
57 
58   for (auto sym_ref : session_init_op.initializers()) {
59     FuncOp init_func_op = symbol_table.lookup<mlir::FuncOp>(
60         sym_ref.cast<FlatSymbolRefAttr>().getValue());
61 
62     if (!init_func_op) {
63       module.emitError("no session initializer function found");
64       return signalPassFailure();
65     }
66 
67     for (auto func : module.getOps<FuncOp>()) {
68       auto dict_attr =
69           func->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
70       if (!dict_attr) continue;
71 
72       OpBuilder builder(func.getContext());
73       builder.setInsertionPointToStart(&func.getBlocks().front());
74       builder.create<TFL::CallOnceOp>(func.getLoc(), init_func_op.getName());
75     }
76   }
77 }
78 
79 }  // namespace
80 
81 // Inserts a TFL::CallOnce op when tf_saved_model's session initializer is
82 // given.
83 std::unique_ptr<OperationPass<ModuleOp>>
CreateInsertCallOnceOpFromSessionInitializerPass()84 CreateInsertCallOnceOpFromSessionInitializerPass() {
85   return std::make_unique<InsertCallOnceOpFromSessionInitializerPass>();
86 }
87 
88 static PassRegistration<InsertCallOnceOpFromSessionInitializerPass> pass;
89 
90 }  // namespace TFL
91 }  // namespace mlir
92