1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/SmallSet.h"
18 #include "llvm/ADT/StringRef.h"
19 #include "llvm/ADT/StringSet.h"
20 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
21 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
22 #include "mlir/IR/SymbolTable.h" // from @llvm-project
23 #include "mlir/Pass/Pass.h" // from @llvm-project
24 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
25 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
27 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
28 #include "tensorflow/compiler/mlir/tensorflow/utils/session_utils.h"
29 #include "tensorflow/core/framework/resource_var.h"
30 #include "tensorflow/core/public/session.h"
31
32 namespace mlir {
33 namespace tf_saved_model {
34 namespace {
35
36 class InitializeVariablesInSessionInitializerPass
37 : public PassWrapper<InitializeVariablesInSessionInitializerPass,
38 OperationPass<ModuleOp>> {
39 public:
InitializeVariablesInSessionInitializerPass(tensorflow::Session * session)40 explicit InitializeVariablesInSessionInitializerPass(
41 tensorflow::Session* session)
42 : session_(session) {}
43
getArgument() const44 StringRef getArgument() const final {
45 return "tf-saved-model-initialize-variables-in-session-init";
46 }
47
getDescription() const48 StringRef getDescription() const final {
49 return "Initialize variables in session initializer function.";
50 }
51
52 void runOnOperation() override;
53
54 private:
55 void InitializeVariable(TF::VarHandleOp var_handle_op,
56 tensorflow::Tensor* tensor, FuncOp session_init_func,
57 OpBuilder builder);
58
59 tensorflow::Session* session_ = nullptr;
60 };
61
InitializeVariable(TF::VarHandleOp var_handle_op,tensorflow::Tensor * tensor,FuncOp session_init_func,OpBuilder builder)62 void InitializeVariablesInSessionInitializerPass::InitializeVariable(
63 TF::VarHandleOp var_handle_op, tensorflow::Tensor* tensor,
64 FuncOp session_init_func, OpBuilder builder) {
65 tensorflow::StatusOr<ElementsAttr> tensor_attr_or =
66 tensorflow::ConvertTensor(*tensor, &builder);
67 assert(tensor_attr_or.ok() && "Expect valid tensor");
68 ElementsAttr tensor_attr = tensor_attr_or.ValueOrDie();
69
70 builder.setInsertionPointToStart(&session_init_func.getBlocks().front());
71 auto var_handle_op_in_init = var_handle_op->clone();
72 builder.insert(var_handle_op_in_init);
73 auto const_op = builder.create<mlir::ConstantOp>(
74 session_init_func.getLoc(), tensor_attr.getType(), tensor_attr);
75
76 builder.create<TF::AssignVariableOp>(
77 session_init_func.getLoc(), llvm::ArrayRef<mlir::Type>{},
78 llvm::ArrayRef<mlir::Value>{var_handle_op_in_init->getResult(0),
79 const_op.getResult()});
80 }
81
// Name of the function attribute that lists a saved-model function's
// exported names.
constexpr char kTfSavedModelExportedNameAttr[] = "tf_saved_model.exported_names";
84
CreateSessionInitFunc(ModuleOp module)85 FuncOp CreateSessionInitFunc(ModuleOp module) {
86 constexpr char kSessionInitFuncName[] = "SessionInitializerFunction";
87
88 mlir::OpBuilder builder(module.body());
89 auto func_type =
90 FunctionType::get(module.getContext(), /*inputs=*/{}, /*results=*/{});
91 auto func =
92 builder.create<FuncOp>(module->getLoc(), kSessionInitFuncName, func_type);
93 func->setAttr(kTfSavedModelExportedNameAttr,
94 builder.getStrArrayAttr({kSessionInitFuncName}));
95 func.setVisibility(mlir::FuncOp::Visibility::Public);
96 auto func_builder = OpBuilder::atBlockBegin(func.addEntryBlock());
97 func_builder.create<mlir::ReturnOp>(func.getLoc());
98 // In cases where there is a session initializer op with empty initializer,
99 // replace the session initializer with the new one that points to the session
100 // initializer func.
101 SessionInitializerOp session_init_op = GetSessionInitializerOp(module);
102 auto new_session_init_op =
103 builder.create<tf_saved_model::SessionInitializerOp>(
104 module->getLoc(),
105 builder.getArrayAttr(builder.getSymbolRefAttr(kSessionInitFuncName)));
106 if (session_init_op) {
107 session_init_op->replaceAllUsesWith(new_session_init_op);
108 session_init_op->erase();
109 }
110 return func;
111 }
112
GetOrCreateSessionInitFunc(ModuleOp module)113 FuncOp GetOrCreateSessionInitFunc(ModuleOp module) {
114 SessionInitializerOp session_init_op = GetSessionInitializerOp(module);
115 if (!session_init_op) return CreateSessionInitFunc(module);
116
117 SymbolTable symbol_table(module);
118 if (!session_init_op.initializers().empty()) {
119 FuncOp init_func_op = symbol_table.lookup<mlir::FuncOp>(
120 session_init_op.initializers()[0].cast<FlatSymbolRefAttr>().getValue());
121 return init_func_op;
122 }
123 return CreateSessionInitFunc(module);
124 }
125
runOnOperation()126 void InitializeVariablesInSessionInitializerPass::runOnOperation() {
127 ModuleOp module = getOperation();
128 if (!session_) return;
129
130 const tensorflow::DeviceMgr* mgr = nullptr;
131 auto status = session_->LocalDeviceManager(&mgr);
132 if (!status.ok()) {
133 module->emitError("failed to fetch device manager: " +
134 status.error_message());
135 return signalPassFailure();
136 }
137
138 // Fetch all VarHandleOp.
139 llvm::StringSet<> variable_names;
140 llvm::SmallVector<TF::VarHandleOp, 4> var_ops;
141 for (auto func_op : module.getOps<FuncOp>()) {
142 for (auto var_handle_op : func_op.getOps<TF::VarHandleOp>()) {
143 auto variable_name = GetVariableName(var_handle_op);
144 if (variable_names.count(variable_name)) continue;
145 var_ops.emplace_back(var_handle_op);
146 variable_names.insert(variable_name);
147 }
148 }
149
150 // Get resources from Session.
151 auto resource_tensors_or = GetResourcesFromSession(var_ops, session_);
152 if (!resource_tensors_or.ok()) {
153 module->emitError(resource_tensors_or.status().message().data());
154 return signalPassFailure();
155 }
156
157 auto session_init_func = GetOrCreateSessionInitFunc(module);
158 OpBuilder builder(session_init_func.getContext());
159
160 for (auto var_and_tensor : llvm::zip(var_ops, resource_tensors_or.value())) {
161 auto& var_op = std::get<0>(var_and_tensor);
162 auto& resource_tensor = std::get<1>(var_and_tensor);
163 if (resource_tensor.dtype() != tensorflow::DT_RESOURCE) {
164 InitializeVariable(var_op, &resource_tensor, session_init_func, builder);
165 continue;
166 }
167
168 auto handle = resource_tensor.scalar<tensorflow::ResourceHandle>()();
169 auto* var_ptr = GetVariableFromSession(var_op, handle.device(), mgr);
170 if (!var_ptr) {
171 // If no value in session, then just skip this variable.
172 // This can happen if the variable is not saved in checkpoint.
173 // For example, when the variable is created on every call.
174 continue;
175 }
176 tensorflow::core::RefCountPtr<tensorflow::Var> var(var_ptr);
177 auto* tensor = var_ptr->tensor();
178
179 InitializeVariable(var_op, tensor, session_init_func, builder);
180 }
181 }
182
183 } // namespace
184
185 std::unique_ptr<OperationPass<ModuleOp>>
CreateInitializeVariablesInSessionInitializerPass(tensorflow::Session * session)186 CreateInitializeVariablesInSessionInitializerPass(
187 tensorflow::Session* session) {
188 return std::make_unique<InitializeVariablesInSessionInitializerPass>(session);
189 }
190 } // namespace tf_saved_model
191 } // namespace mlir
192