1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/None.h"
17 #include "llvm/Support/raw_ostream.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
19 #include "mlir/IR/Attributes.h" // from @llvm-project
20 #include "mlir/IR/Builders.h" // from @llvm-project
21 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
22 #include "mlir/IR/Operation.h" // from @llvm-project
23 #include "mlir/IR/PatternMatch.h" // from @llvm-project
24 #include "mlir/Pass/Pass.h" // from @llvm-project
25 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
26 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
31
32 namespace mlir {
33 namespace TFL {
34 namespace {
// Name of the function attribute that holds the list of names a function is
// exported under.
constexpr char kTfSavedModelExportedNameAttr[] =
    "tf_saved_model.exported_names";
// Exported name that identifies an already existing session initializer
// function.
constexpr char kTfSavedModelSessionInitNameAttr[] =
    "__tf_saved_model_session_initializer";
39
40 // Returns Value representing the resource_id.
GetResourceIDAsI32(int resource_id,Location loc,mlir::OpBuilder & rewriter)41 Value GetResourceIDAsI32(int resource_id, Location loc,
42 mlir::OpBuilder& rewriter) {
43 return rewriter.create<ConstOp>(
44 loc, DenseElementsAttr::get(
45 RankedTensorType::get({1}, rewriter.getIntegerType(32)),
46 resource_id));
47 }
48
49 // Helper method that fetches the global tensor that 'op' points to it.
50 template <typename T>
GetGlobalTensor(const SymbolTable & symbol_table,T op,FuncOp func)51 tf_saved_model::GlobalTensorOp GetGlobalTensor(const SymbolTable& symbol_table,
52 T op, FuncOp func) {
53 auto block_arg = op.resource().template dyn_cast<BlockArgument>();
54 if (!block_arg) return nullptr;
55 int index = block_arg.getArgNumber();
56 auto sym = func.template getArgAttrOfType<FlatSymbolRefAttr>(
57 index, "tf_saved_model.bound_input");
58 if (!sym) {
59 return nullptr;
60 }
61 return symbol_table.lookup<tf_saved_model::GlobalTensorOp>(sym.getValue());
62 }
63
64 // Pass which Initializes TF variables which are already passed as bounded
65 // arguments to functions, to a TFLite variables.
66 class InitializeVariablesPass
67 : public PassWrapper<InitializeVariablesPass, OperationPass<ModuleOp>> {
68 public:
69 InitializeVariablesPass() = default;
InitializeVariablesPass(const InitializeVariablesPass &)70 InitializeVariablesPass(const InitializeVariablesPass&) {}
71
72 // Initializes a single variable identified by 'var_id' with value 'value'
73 // in 'session_init' function.
InitializeVariable(int var_id,ElementsAttr value,FuncOp session_init)74 void InitializeVariable(int var_id, ElementsAttr value, FuncOp session_init) {
75 // TODO(b/149099381): Initialize using TF::AssignVariableOp instead
76 // and let legalization be handled by Legalize variables pass.
77 mlir::OpBuilder builder(&getContext());
78 builder.setInsertionPoint(&session_init.getBlocks().front().front());
79 auto resource_op =
80 GetResourceIDAsI32(var_id, session_init.body().getLoc(), builder);
81 auto value_op =
82 builder.create<ConstOp>(session_init.body().getLoc(), value);
83 builder.create<TFL::AssignVariableOp>(session_init.body().getLoc(),
84 resource_op, value_op);
85 }
86
GetGlobalTensorOp(mlir::Operation * op,SymbolTable symbol_table,FuncOp func)87 tf_saved_model::GlobalTensorOp GetGlobalTensorOp(mlir::Operation* op,
88 SymbolTable symbol_table,
89 FuncOp func) {
90 if (auto read_var = llvm::dyn_cast_or_null<TF::ReadVariableOp>(op))
91 return GetGlobalTensor<TF::ReadVariableOp>(symbol_table, read_var, func);
92 else if (auto assign_var = llvm::dyn_cast_or_null<TF::AssignVariableOp>(op))
93 return GetGlobalTensor<TF::AssignVariableOp>(symbol_table, assign_var,
94 func);
95 return nullptr;
96 }
97
98 // Initializes all variables in the module.
InitializeVariables(const std::map<std::string,int> & global_tensor_id,SymbolTable symbol_table)99 void InitializeVariables(const std::map<std::string, int>& global_tensor_id,
100 SymbolTable symbol_table) {
101 auto module = getOperation();
102 // Check if there is Session init func already, if not create one.
103 FuncOp session_init_func = nullptr;
104 for (auto func : module.getOps<FuncOp>()) {
105 if (auto attr = func->getAttr(kTfSavedModelExportedNameAttr)) {
106 auto exported_names = attr.dyn_cast<ArrayAttr>();
107 if (!exported_names) continue;
108 for (auto exported_name : exported_names) {
109 if (auto name = exported_name.dyn_cast_or_null<StringAttr>())
110 if (name.getValue() == kTfSavedModelSessionInitNameAttr)
111 session_init_func = func;
112 }
113 if (session_init_func) break;
114 }
115 }
116 // TODO(b/149099381): Refactor to separate function in saved model util.
117 if (!session_init_func) session_init_func = CreateSessionInitFunc();
118
119 std::set<tf_saved_model::GlobalTensorOp> tensors_to_initialize;
120 for (auto func : module.getOps<FuncOp>()) {
121 func->walk([&](Operation* op) {
122 // TODO(b/149099381): Make sure to verify flex compatability
123 // with ops that accepts resource as input.
124 if (!llvm::isa<TF::ReadVariableOp, TF::AssignVariableOp>(op))
125 return WalkResult::advance();
126 tensors_to_initialize.insert(GetGlobalTensorOp(op, symbol_table, func));
127 return WalkResult::advance();
128 });
129 }
130 for (auto global_tensor : tensors_to_initialize) {
131 InitializeVariable(global_tensor_id.at(global_tensor.sym_name().str()),
132 global_tensor.value(), session_init_func);
133 }
134 }
135 // Create a new function in the module which is SessionInitializerOp.
CreateSessionInitFunc()136 FuncOp CreateSessionInitFunc() {
137 constexpr char kSessionInitFuncName[] = "SessionInitializerFunction";
138 auto module = getOperation();
139
140 mlir::OpBuilder builder(module.body());
141 auto func_type = FunctionType::get(&getContext(), {}, {});
142 auto func = builder.create<FuncOp>(module->getLoc(), kSessionInitFuncName,
143 func_type);
144 func->setAttr(kTfSavedModelExportedNameAttr,
145 builder.getStrArrayAttr({kSessionInitFuncName}));
146 func.setVisibility(mlir::FuncOp::Visibility::Public);
147 auto funcBuilder = OpBuilder::atBlockBegin(func.addEntryBlock());
148 funcBuilder.create<mlir::ReturnOp>(func.getLoc());
149 builder.create<tf_saved_model::SessionInitializerOp>(
150 module->getLoc(),
151 builder.getArrayAttr(builder.getSymbolRefAttr(kSessionInitFuncName)));
152 return func;
153 }
154
runOnOperation()155 void runOnOperation() override {
156 auto module = getOperation();
157 // Use ordered container to make sure ids are deterministic if we got tensor
158 // ids from different part, since we have different passes that touches
159 // variables.
160 // TODO(b/149099381): Remove integer IDs after adding the new variable
161 // handle type.
162 std::map<std::string, int> global_tensor_id;
163 int id = 0;
164 for (auto global_tensor : module.getOps<tf_saved_model::GlobalTensorOp>()) {
165 global_tensor_id[global_tensor.sym_name().str()];
166 }
167 for (auto& tensor : global_tensor_id) tensor.second = id++;
168 SymbolTable symbol_table(module);
169
170 // Initialize all variables.
171 InitializeVariables(global_tensor_id, symbol_table);
172 }
173 };
174 } // namespace
175
CreateInitializeVariablesPass()176 std::unique_ptr<OperationPass<ModuleOp>> CreateInitializeVariablesPass() {
177 return std::make_unique<InitializeVariablesPass>();
178 }
179
180 static PassRegistration<InitializeVariablesPass> pass(
181 "tfl-initialize-variables-tf",
182 "Initialize TensorFlow variables to TensorFlow Lite dialect");
183
184 } // namespace TFL
185 } // namespace mlir
186