• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <vector>
18 
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "mlir/IR/Builders.h"  // from @llvm-project
22 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
23 #include "mlir/IR/UseDefLists.h"  // from @llvm-project
24 #include "mlir/Pass/Pass.h"  // from @llvm-project
25 #include "mlir/Support/LLVM.h"  // from @llvm-project
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
28 
29 namespace mlir {
30 namespace tf_saved_model {
31 namespace {
32 using mlir::Operation;
33 using mlir::TF::VarHandleOp;
34 
35 class RemoveVariablesInSessionInitializerPass
36     : public PassWrapper<RemoveVariablesInSessionInitializerPass,
37                          OperationPass<ModuleOp>> {
38  public:
getArgument() const39   StringRef getArgument() const final {
40     return "tf-saved-model-remove-vars-in-session-initializer";
41   }
42 
getDescription() const43   StringRef getDescription() const final {
44     return "Remove variables in tf saved model's session initializer.";
45   }
46 
47   void runOnOperation() override;
48 };
49 
RecursiveRemove(Operation * op,llvm::SmallVectorImpl<Operation * > & erase_list,llvm::SmallPtrSetImpl<Operation * > & dead_ops)50 void RecursiveRemove(Operation* op,
51                      llvm::SmallVectorImpl<Operation*>& erase_list,
52                      llvm::SmallPtrSetImpl<Operation*>& dead_ops) {
53   for (mlir::Value res : op->getResults()) {
54     for (Operation* user : res.getUsers()) {
55       if (!dead_ops.insert(user).second) continue;
56       RecursiveRemove(user, erase_list, dead_ops);
57     }
58   }
59 
60   erase_list.push_back(op);
61 
62   for (auto& use : op->getOpOperands()) {
63     if (auto op_result = use.get().dyn_cast<mlir::OpResult>()) {
64       Operation* def = op_result.getDefiningOp();
65       if (!dead_ops.insert(def).second) continue;
66       RecursiveRemove(def, erase_list, dead_ops);
67     }
68   }
69 }
70 
RemoveVariables(llvm::ArrayRef<VarHandleOp> vars)71 void RemoveVariables(llvm::ArrayRef<VarHandleOp> vars) {
72   // TODO(b/160906885): Repalce the following code with an non-recursive one.
73   llvm::SmallVector<Operation*, 4> erase_list;
74   llvm::SmallPtrSet<Operation*, 4> dead_ops;
75 
76   // Marks all the variables dead.
77   dead_ops.insert(vars.begin(), vars.end());
78 
79   // Removes relevant ops in topological order.
80   for (auto& op : vars) RecursiveRemove(op, erase_list, dead_ops);
81 
82   // Erases the ops.
83   for (auto op : erase_list) op->erase();
84 }
85 
runOnOperation()86 void RemoveVariablesInSessionInitializerPass::runOnOperation() {
87   ModuleOp module = getOperation();
88   SessionInitializerOp session_init_op = GetSessionInitializerOp(module);
89 
90   if (!session_init_op) return;
91 
92   SymbolTable symbol_table(module);
93 
94   for (auto sym_ref : session_init_op.initializers()) {
95     FuncOp init_func_op = symbol_table.lookup<mlir::FuncOp>(
96         sym_ref.cast<FlatSymbolRefAttr>().getValue());
97 
98     if (!init_func_op) {
99       module.emitError("no session initializer function found");
100       return signalPassFailure();
101     }
102 
103     if (init_func_op.getBlocks().size() != 1) {
104       init_func_op.emitError("expects exactly one block in the MLIR function");
105       return signalPassFailure();
106     }
107 
108     auto var_handle_ops =
109         init_func_op.getBlocks().front().getOps<VarHandleOp>();
110     llvm::SmallVector<VarHandleOp, 4> init_vars(var_handle_ops.begin(),
111                                                 var_handle_ops.end());
112     RemoveVariables(init_vars);
113   }
114 }
115 
116 }  // namespace
117 
118 static PassRegistration<RemoveVariablesInSessionInitializerPass> pass;
119 
120 std::unique_ptr<OperationPass<ModuleOp>>
CreateRemoveVariablesInSessionInitializerPass()121 CreateRemoveVariablesInSessionInitializerPass() {
122   return std::make_unique<RemoveVariablesInSessionInitializerPass>();
123 }
124 
125 }  // namespace tf_saved_model
126 }  // namespace mlir
127