• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/SmallSet.h"
18 #include "llvm/ADT/StringRef.h"
19 #include "llvm/ADT/StringSet.h"
20 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
21 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
22 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
23 #include "mlir/Pass/Pass.h"  // from @llvm-project
24 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
25 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
27 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
28 #include "tensorflow/compiler/mlir/tensorflow/utils/session_utils.h"
29 #include "tensorflow/core/framework/resource_var.h"
30 #include "tensorflow/core/public/session.h"
31 
32 namespace mlir {
33 namespace tf_saved_model {
34 namespace {
35 
36 class InitializeVariablesInSessionInitializerPass
37     : public PassWrapper<InitializeVariablesInSessionInitializerPass,
38                          OperationPass<ModuleOp>> {
39  public:
InitializeVariablesInSessionInitializerPass(tensorflow::Session * session)40   explicit InitializeVariablesInSessionInitializerPass(
41       tensorflow::Session* session)
42       : session_(session) {}
43 
getArgument() const44   StringRef getArgument() const final {
45     return "tf-saved-model-initialize-variables-in-session-init";
46   }
47 
getDescription() const48   StringRef getDescription() const final {
49     return "Initialize variables in session initializer function.";
50   }
51 
52   void runOnOperation() override;
53 
54  private:
55   void InitializeVariable(TF::VarHandleOp var_handle_op,
56                           tensorflow::Tensor* tensor, FuncOp session_init_func,
57                           OpBuilder builder);
58 
59   tensorflow::Session* session_ = nullptr;
60 };
61 
InitializeVariable(TF::VarHandleOp var_handle_op,tensorflow::Tensor * tensor,FuncOp session_init_func,OpBuilder builder)62 void InitializeVariablesInSessionInitializerPass::InitializeVariable(
63     TF::VarHandleOp var_handle_op, tensorflow::Tensor* tensor,
64     FuncOp session_init_func, OpBuilder builder) {
65   tensorflow::StatusOr<ElementsAttr> tensor_attr_or =
66       tensorflow::ConvertTensor(*tensor, &builder);
67   assert(tensor_attr_or.ok() && "Expect valid tensor");
68   ElementsAttr tensor_attr = tensor_attr_or.ValueOrDie();
69 
70   builder.setInsertionPointToStart(&session_init_func.getBlocks().front());
71   auto var_handle_op_in_init = var_handle_op->clone();
72   builder.insert(var_handle_op_in_init);
73   auto const_op = builder.create<mlir::ConstantOp>(
74       session_init_func.getLoc(), tensor_attr.getType(), tensor_attr);
75 
76   builder.create<TF::AssignVariableOp>(
77       session_init_func.getLoc(), llvm::ArrayRef<mlir::Type>{},
78       llvm::ArrayRef<mlir::Value>{var_handle_op_in_init->getResult(0),
79                                   const_op.getResult()});
80 }
81 
// Name of the function attribute that lists the names under which a function
// is exported from the saved model.
constexpr char kTfSavedModelExportedNameAttr[] =
    "tf_saved_model.exported_names";
84 
CreateSessionInitFunc(ModuleOp module)85 FuncOp CreateSessionInitFunc(ModuleOp module) {
86   constexpr char kSessionInitFuncName[] = "SessionInitializerFunction";
87 
88   mlir::OpBuilder builder(module.body());
89   auto func_type =
90       FunctionType::get(module.getContext(), /*inputs=*/{}, /*results=*/{});
91   auto func =
92       builder.create<FuncOp>(module->getLoc(), kSessionInitFuncName, func_type);
93   func->setAttr(kTfSavedModelExportedNameAttr,
94                 builder.getStrArrayAttr({kSessionInitFuncName}));
95   func.setVisibility(mlir::FuncOp::Visibility::Public);
96   auto func_builder = OpBuilder::atBlockBegin(func.addEntryBlock());
97   func_builder.create<mlir::ReturnOp>(func.getLoc());
98   // In cases where there is a session initializer op with empty initializer,
99   // replace the session initializer with the new one that points to the session
100   // initializer func.
101   SessionInitializerOp session_init_op = GetSessionInitializerOp(module);
102   auto new_session_init_op =
103       builder.create<tf_saved_model::SessionInitializerOp>(
104           module->getLoc(),
105           builder.getArrayAttr(builder.getSymbolRefAttr(kSessionInitFuncName)));
106   if (session_init_op) {
107     session_init_op->replaceAllUsesWith(new_session_init_op);
108     session_init_op->erase();
109   }
110   return func;
111 }
112 
GetOrCreateSessionInitFunc(ModuleOp module)113 FuncOp GetOrCreateSessionInitFunc(ModuleOp module) {
114   SessionInitializerOp session_init_op = GetSessionInitializerOp(module);
115   if (!session_init_op) return CreateSessionInitFunc(module);
116 
117   SymbolTable symbol_table(module);
118   if (!session_init_op.initializers().empty()) {
119     FuncOp init_func_op = symbol_table.lookup<mlir::FuncOp>(
120         session_init_op.initializers()[0].cast<FlatSymbolRefAttr>().getValue());
121     return init_func_op;
122   }
123   return CreateSessionInitFunc(module);
124 }
125 
runOnOperation()126 void InitializeVariablesInSessionInitializerPass::runOnOperation() {
127   ModuleOp module = getOperation();
128   if (!session_) return;
129 
130   const tensorflow::DeviceMgr* mgr = nullptr;
131   auto status = session_->LocalDeviceManager(&mgr);
132   if (!status.ok()) {
133     module->emitError("failed to fetch device manager: " +
134                       status.error_message());
135     return signalPassFailure();
136   }
137 
138   // Fetch all VarHandleOp.
139   llvm::StringSet<> variable_names;
140   llvm::SmallVector<TF::VarHandleOp, 4> var_ops;
141   for (auto func_op : module.getOps<FuncOp>()) {
142     for (auto var_handle_op : func_op.getOps<TF::VarHandleOp>()) {
143       auto variable_name = GetVariableName(var_handle_op);
144       if (variable_names.count(variable_name)) continue;
145       var_ops.emplace_back(var_handle_op);
146       variable_names.insert(variable_name);
147     }
148   }
149 
150   // Get resources from Session.
151   auto resource_tensors_or = GetResourcesFromSession(var_ops, session_);
152   if (!resource_tensors_or.ok()) {
153     module->emitError(resource_tensors_or.status().message().data());
154     return signalPassFailure();
155   }
156 
157   auto session_init_func = GetOrCreateSessionInitFunc(module);
158   OpBuilder builder(session_init_func.getContext());
159 
160   for (auto var_and_tensor : llvm::zip(var_ops, resource_tensors_or.value())) {
161     auto& var_op = std::get<0>(var_and_tensor);
162     auto& resource_tensor = std::get<1>(var_and_tensor);
163     if (resource_tensor.dtype() != tensorflow::DT_RESOURCE) {
164       InitializeVariable(var_op, &resource_tensor, session_init_func, builder);
165       continue;
166     }
167 
168     auto handle = resource_tensor.scalar<tensorflow::ResourceHandle>()();
169     auto* var_ptr = GetVariableFromSession(var_op, handle.device(), mgr);
170     if (!var_ptr) {
171       // If no value in session, then just skip this variable.
172       // This can happen if the variable is not saved in checkpoint.
173       // For example, when the variable is created on every call.
174       continue;
175     }
176     tensorflow::core::RefCountPtr<tensorflow::Var> var(var_ptr);
177     auto* tensor = var_ptr->tensor();
178 
179     InitializeVariable(var_op, tensor, session_init_func, builder);
180   }
181 }
182 
183 }  // namespace
184 
185 std::unique_ptr<OperationPass<ModuleOp>>
CreateInitializeVariablesInSessionInitializerPass(tensorflow::Session * session)186 CreateInitializeVariablesInSessionInitializerPass(
187     tensorflow::Session* session) {
188   return std::make_unique<InitializeVariablesInSessionInitializerPass>(session);
189 }
190 }  // namespace tf_saved_model
191 }  // namespace mlir
192