1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <memory>
#include <tuple>
#include <utility>
#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/session_utils.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/public/session.h"
42
43 namespace mlir {
44 namespace tf_saved_model {
45 namespace {
46
47 // Build and returns ElementsAttr which holds the data in 'tensor'.
GetTensorValueAsElementsAttr(const tensorflow::Tensor & tensor,OpBuilder builder)48 ElementsAttr GetTensorValueAsElementsAttr(const tensorflow::Tensor& tensor,
49 OpBuilder builder) {
50 tensorflow::StatusOr<ElementsAttr> tensor_attr_or =
51 tensorflow::ConvertTensor(tensor, &builder);
52 if (!tensor_attr_or.ok()) return nullptr;
53 return tensor_attr_or.ValueOrDie();
54 }
55
56 // Creates ConstantOp which holds 'tensor_elements'.
GetConstOpFromElementsAttr(ElementsAttr tensor_elements,OpBuilder builder,Location loc)57 mlir::ConstantOp GetConstOpFromElementsAttr(ElementsAttr tensor_elements,
58 OpBuilder builder, Location loc) {
59 return builder.create<mlir::ConstantOp>(loc, tensor_elements.getType(),
60 tensor_elements);
61 }
62
63 // Returns ElementsAttr which has the value held by 'resource_tensor'.
GetTensorValueAsElementsAttr(TF::VarHandleOp var_handle_op,const tensorflow::Tensor & resource_tensor,const tensorflow::DeviceMgr * mgr,OpBuilder builder)64 ElementsAttr GetTensorValueAsElementsAttr(
65 TF::VarHandleOp var_handle_op, const tensorflow::Tensor& resource_tensor,
66 const tensorflow::DeviceMgr* mgr, OpBuilder builder) {
67 if (resource_tensor.dtype() != tensorflow::DT_RESOURCE) {
68 return GetTensorValueAsElementsAttr(resource_tensor, builder);
69 }
70
71 auto handle = resource_tensor.scalar<tensorflow::ResourceHandle>()();
72 auto* var_ptr = tf_saved_model::GetVariableFromSession(var_handle_op,
73 handle.device(), mgr);
74 if (!var_ptr) {
75 return nullptr;
76 }
77 tensorflow::core::RefCountPtr<tensorflow::Var> var(var_ptr);
78 auto* tensor = var_ptr->tensor();
79
80 return GetTensorValueAsElementsAttr(*tensor, builder);
81 }
82
83 // Replace usage of 'read_variable_op' with 'value'.
PropagateUsage(TF::ReadVariableOp read_variable_op,ElementsAttr value)84 void PropagateUsage(TF::ReadVariableOp read_variable_op, ElementsAttr value) {
85 OpBuilder builder(read_variable_op);
86 read_variable_op->getResult(0).replaceAllUsesWith(
87 GetConstOpFromElementsAttr(value, builder, read_variable_op->getLoc()));
88 }
89
// Propagates a resource usage across the graph where
// 'user_op' uses a resource and is passed to this op at 'argument_index'.
// This resource should be replaced by 'value'.
// Output params:
// - work_list: Is updated with new regions to process that is called
//   by 'user_op';
// - arguments_to_erase: Captures updates to the graph - which arguments
//   to remove from the op;
void PropagateUsage(
    Operation* user_op, int argument_index, ElementsAttr value,
    llvm::SmallVector<std::pair<Region*, int>, 4>* work_list,
    llvm::DenseMap<Operation*, llvm::SmallVector<unsigned int, 4>>*
        arguments_to_erase) {
  if (auto read_variable_op = dyn_cast<TF::ReadVariableOp>(user_op)) {
    PropagateUsage(read_variable_op, value);
    // Deliberate subscript-with-no-append: creates an (empty) entry so the
    // ReadVariableOp is still visited -- and erased -- by the update loop in
    // runOnOperation, even though it has no operand indices to record.
    (*arguments_to_erase)[read_variable_op];
  } else if (auto call = dyn_cast<CallOpInterface>(user_op)) {
    if (auto func = dyn_cast<FuncOp>(call.resolveCallable())) {
      // Schedule the callee's region so the value is propagated to uses of
      // the corresponding block argument inside the function body.
      (*arguments_to_erase)[func].push_back(argument_index);
      work_list->push_back(std::make_pair(&func.getRegion(), argument_index));
    }
    (*arguments_to_erase)[call].push_back(argument_index);
  } else if (auto if_op = dyn_cast<TF::IfOp>(user_op)) {
    for (auto callee : {if_op.then_function(), if_op.else_function()}) {
      // Branch-function arguments are offset by one relative to the op's
      // operands (operand 0 is the branch condition), hence the '- 1'.
      (*arguments_to_erase)[callee].push_back(argument_index - 1);
      work_list->push_back(std::make_pair(&callee.body(), argument_index - 1));
    }
    (*arguments_to_erase)[if_op].push_back(argument_index);
  } else if (auto if_op = dyn_cast<TF::IfRegionOp>(user_op)) {
    // Region-based If: only the regions are scheduled; there is no separate
    // callee function whose arguments would need erasing.
    for (auto callee : {&if_op.then_branch(), &if_op.else_branch()}) {
      work_list->push_back(std::make_pair(callee, argument_index));
    }
    (*arguments_to_erase)[if_op].push_back(argument_index);
  } else if (auto while_op = dyn_cast<TF::WhileOp>(user_op)) {
    // While passes its operands through to cond/body unchanged, so callee
    // argument indices match the op's operand indices.
    for (auto callee : {while_op.cond_function(), while_op.body_function()}) {
      (*arguments_to_erase)[callee].push_back(argument_index);
      work_list->push_back(std::make_pair(&callee.body(), argument_index));
    }
    (*arguments_to_erase)[while_op].push_back(argument_index);
  } else if (auto while_op = dyn_cast<TF::WhileRegionOp>(user_op)) {
    for (auto callee : {&while_op.cond(), &while_op.body()}) {
      work_list->push_back(std::make_pair(callee, argument_index));
    }
    (*arguments_to_erase)[while_op].push_back(argument_index);
  }
}
136
137 // An override that takes region.
PropagateUsage(Region * region,ElementsAttr value,int argument_index,llvm::SmallVector<std::pair<Region *,int>,4> * work_list,llvm::DenseMap<Operation *,llvm::SmallVector<unsigned int,4>> * arguments_to_erase)138 void PropagateUsage(
139 Region* region, ElementsAttr value, int argument_index,
140 llvm::SmallVector<std::pair<Region*, int>, 4>* work_list,
141 llvm::DenseMap<Operation*, llvm::SmallVector<unsigned int, 4>>*
142 arguments_to_erase) {
143 auto arg = region->getArgument(argument_index);
144 for (auto& usage : arg.getUses()) {
145 auto* user_op = usage.getOwner();
146 int operand_index = usage.getOperandNumber();
147 PropagateUsage(user_op, operand_index, value, work_list,
148 arguments_to_erase);
149 }
150 }
151
152 // Traces usage of 'var_handle_op' and replaces it's usage with constant value
153 // 'value'.
154 // All op operands updates are captured in 'arguments_to_erase'.
ReplaceVarWithConstant(TF::VarHandleOp var_handle_op,ElementsAttr value,llvm::DenseMap<Operation *,llvm::SmallVector<unsigned int,4>> * arguments_to_erase)155 void ReplaceVarWithConstant(
156 TF::VarHandleOp var_handle_op, ElementsAttr value,
157 llvm::DenseMap<Operation*, llvm::SmallVector<unsigned int, 4>>*
158 arguments_to_erase) {
159 llvm::SmallVector<std::pair<Region*, int>, 4> work_list;
160 for (auto& usage : var_handle_op->getUses()) {
161 auto* user_op = usage.getOwner();
162 int operand_index = usage.getOperandNumber();
163 PropagateUsage(user_op, operand_index, value, &work_list,
164 arguments_to_erase);
165 }
166 // Container to mark visited regions to avoid infinite loop.
167 llvm::DenseSet<std::pair<Region*, int>> visited;
168 while (!work_list.empty()) {
169 auto work_item = work_list.pop_back_val();
170 if (visited.contains(work_item)) continue;
171 PropagateUsage(work_item.first, value, work_item.second, &work_list,
172 arguments_to_erase);
173 visited.insert(work_item);
174 }
175 }
176
// A pass that tries to freeze / constant fold read only variables in the graph.
// The pass will analyze the variable usage and if the variable is immutable
// it will replace it with constant tensor which has its value.
// Note: This pass currently only works with variables that are initialized
// during Session init function only. Expanding to other uses is a todo.
class FreezeVariablesPass
    : public PassWrapper<FreezeVariablesPass, OperationPass<ModuleOp>> {
 public:
  // If no session is provided or null the pass is no-op.
  explicit FreezeVariablesPass(tensorflow::Session* session = nullptr)
      : session_(session) {}

  // Entry point; defined out-of-line below.
  void runOnOperation() override;

 private:
  // Session used to read variable values. Not owned; may be null, in which
  // case runOnOperation() returns immediately without touching the module.
  tensorflow::Session* session_;
};
194
195 // Helper that returns the FuncOp that is the SessionInit function which
196 // will be called to initialize all resources.
197 // Returns nullptr if no function is found.
GetSessionInitializerFunc(ModuleOp module)198 FuncOp GetSessionInitializerFunc(ModuleOp module) {
199 auto session_init_op = tf_saved_model::GetSessionInitializerOp(module);
200 SymbolTable symbol_table(module);
201 if (session_init_op && !session_init_op.initializers().empty()) {
202 FuncOp init_func_op = symbol_table.lookup<mlir::FuncOp>(
203 session_init_op.initializers()[0].cast<FlatSymbolRefAttr>().getValue());
204 return init_func_op;
205 }
206 return nullptr;
207 }
208
209 // Returns ID for identifying a resource.
GetResourceKey(Operation * op)210 std::tuple<llvm::StringRef, llvm::StringRef, llvm::StringRef> GetResourceKey(
211 Operation* op) {
212 llvm::StringRef device;
213 if (auto attr = op->getAttrOfType<mlir::StringAttr>("device")) {
214 device = attr.getValue();
215 }
216
217 llvm::StringRef container;
218 if (auto attr = op->getAttrOfType<mlir::StringAttr>("container")) {
219 container = attr.getValue();
220 }
221
222 llvm::StringRef shared_name;
223 if (auto attr = op->getAttrOfType<mlir::StringAttr>("shared_name")) {
224 shared_name = attr.getValue();
225 }
226
227 return std::tuple<llvm::StringRef, llvm::StringRef, llvm::StringRef>{
228 device, container, shared_name};
229 }
230
// Remove the initialization of the variables in 'var_handle_ops' from
// the session init function 'sesion_init_func'
void RemoveVariablesInitializations(
    const llvm::SmallVector<TF::VarHandleOp, 4>& var_handle_ops,
    FuncOp sesion_init_func) {
  // We identify the variables using (device, container, shared_name) of the
  // resource. Capture them here and use them to identify the useless
  // initializations.
  llvm::SetVector<std::tuple<llvm::StringRef, llvm::StringRef, llvm::StringRef>>
      variables;
  for (auto var_handle_op : var_handle_ops)
    variables.insert(GetResourceKey(var_handle_op));

  // Seed the traversal with the init function's VarHandleOps that refer to
  // one of the frozen variables.
  llvm::SmallVector<Operation*, 4> work_list;
  for (auto var_handle_op : sesion_init_func.getOps<TF::VarHandleOp>()) {
    if (variables.count(GetResourceKey(var_handle_op)))
      work_list.push_back(var_handle_op);
  }

  // Capture list of ops to be erased by traversing usage starting from
  // the VarHandle ops.
  llvm::SetVector<Operation*> erase_list;
  while (!work_list.empty()) {
    auto* operation = work_list.pop_back_val();
    erase_list.insert(operation);
    for (auto& use : operation->getUses()) {
      // Skip ops already collected to avoid re-enqueueing in cyclic use
      // chains.
      if (erase_list.count(use.getOwner())) continue;
      work_list.push_back(use.getOwner());
    }
  }

  for (auto* op : erase_list) {
    // Drop uses first so erasure is valid regardless of the order the ops
    // were collected in.
    op->dropAllUses();
    op->erase();
  }
}
267
268 // Updates terminator op arguments of 'func' after removing arguments
269 // specified in 'arguments_to_erase'.
270 template <typename T>
UpdateTerminatorArguments(T & func,const llvm::SmallVector<unsigned,4> & arguments_to_erase)271 void UpdateTerminatorArguments(
272 T& func, const llvm::SmallVector<unsigned, 4>& arguments_to_erase) {
273 llvm::BitVector erase_indices(func.front().getTerminator()->getNumOperands());
274 for (auto arg_index : arguments_to_erase) {
275 auto argument = func.getArgument(arg_index);
276 for (auto& use : argument.getUses()) {
277 if (llvm::isa<ReturnOp, TF::YieldOp>(use.getOwner())) {
278 int operand_index = argument.getUses().begin()->getOperandNumber();
279 erase_indices.set(operand_index);
280 }
281 }
282 func.getArgument(arg_index).dropAllUses();
283 }
284 if (llvm::isa<ReturnOp, TF::YieldOp>(func.front().getTerminator())) {
285 func.front().getTerminator()->eraseOperands(erase_indices);
286 }
287 }
288
289 // Updates 'while_op' signatures based on which arguments should be removed
290 // in 'arguments_to_erase'.
291 template <typename T>
GetUpdatedWhileOp(T while_op,const llvm::SmallVector<unsigned,4> & arguments_to_erase)292 T GetUpdatedWhileOp(T while_op,
293 const llvm::SmallVector<unsigned, 4>& arguments_to_erase) {
294 OpBuilder builder(while_op);
295 llvm::SmallVector<Type, 4> new_operand_types;
296 llvm::SmallVector<Value> new_operands;
297 auto operands = while_op->getOperands();
298 const int num_operands = while_op->getNumOperands();
299 llvm::BitVector skip_indices(num_operands);
300 for (int i : arguments_to_erase) skip_indices.set(i);
301 for (int i = 0; i < num_operands; ++i) {
302 if (!skip_indices.test(i)) {
303 new_operand_types.emplace_back(operands[i].getType());
304 new_operands.emplace_back(operands[i]);
305 }
306 }
307 auto new_while_op = builder.create<T>(while_op->getLoc(), new_operand_types,
308 new_operands, while_op->getAttrs());
309 for (int i = 0; i < num_operands; ++i) {
310 if (!skip_indices.test(i)) {
311 while_op->getResult(i).replaceAllUsesWith(new_while_op->getResult(i));
312 }
313 }
314 return new_while_op;
315 }
316
runOnOperation()317 void FreezeVariablesPass::runOnOperation() {
318 if (!session_) return;
319 ModuleOp module = getOperation();
320 const tensorflow::DeviceMgr* mgr = nullptr;
321 auto status = session_->LocalDeviceManager(&mgr);
322 if (!status.ok()) {
323 module->emitError("failed to fetch device manager: " +
324 status.error_message());
325 return signalPassFailure();
326 }
327
328 FuncOp session_init_func = GetSessionInitializerFunc(module);
329
330 TF::ResourceAnalyzer analyzer(module, /*skip_session_init=*/true);
331 llvm::SmallVector<TF::VarHandleOp, 4> variables;
332 // Capture list of all read only variables.
333 for (auto func : module.getOps<FuncOp>()) {
334 if (func == session_init_func) continue;
335 for (auto var_handle_op : func.getOps<TF::VarHandleOp>()) {
336 if (!analyzer.IsPotentiallyWritten(var_handle_op.resource())) {
337 variables.push_back(var_handle_op);
338 }
339 }
340 }
341
342 // Fetch the values to replace the VarHandleOps with.
343 auto resource_tensors_or =
344 tf_saved_model::GetResourcesFromSession(variables, session_);
345 if (!resource_tensors_or.ok()) {
346 module->emitError(resource_tensors_or.status().message().data());
347 signalPassFailure();
348 }
349
350 auto* context = module.getContext();
351 OpBuilder builder(context);
352 // Note: We can't modify the graph while navigating through it, as erasing
353 // invalidate pointers.
354 // So instead we capture all the updates in the below map, and then
355 // process them after.
356
357 // Container to hold all update actions on ops.
358 // Key: Operation to update.
359 // Value: optional list of arguments to delete from this op.
360 llvm::DenseMap<Operation*, llvm::SmallVector<unsigned int, 4>>
361 arguments_to_erase;
362 for (auto variable_value_pair :
363 llvm::zip(variables, resource_tensors_or.value())) {
364 auto var_handle_op = std::get<0>(variable_value_pair);
365 builder.setInsertionPointAfterValue(var_handle_op);
366 auto elements_attr = GetTensorValueAsElementsAttr(
367 var_handle_op, std::get<1>(variable_value_pair), mgr, builder);
368 ReplaceVarWithConstant(var_handle_op, elements_attr, &arguments_to_erase);
369 }
370
371 // All updates to different ops are captured in 'arguments_to_erase'.
372 // Now loop on them and based on each item type update accordingly.
373 for (auto& items : arguments_to_erase) {
374 if (auto func = dyn_cast<FuncOp>(items.getFirst())) {
375 UpdateTerminatorArguments(func, items.getSecond());
376 func.eraseArguments(items.getSecond());
377 } else if (auto read_var = dyn_cast<TF::ReadVariableOp>(items.getFirst())) {
378 // Read variables was already replaced by constant op. Just remove the op.
379 read_var->erase();
380 } else if (auto while_op = dyn_cast<TF::WhileOp>(items.getFirst())) {
381 auto new_while_op =
382 GetUpdatedWhileOp<TF::WhileOp>(while_op, items.getSecond());
383 new_while_op.body_function().eraseResults(items.getSecond());
384 while_op->erase();
385 } else if (auto while_op = dyn_cast<TF::WhileRegionOp>(items.getFirst())) {
386 auto new_while_op = GetUpdatedWhileOp(while_op, items.getSecond());
387 new_while_op.cond().takeBody(while_op.cond());
388 new_while_op.body().takeBody(while_op.body());
389 UpdateTerminatorArguments(new_while_op.body(), items.getSecond());
390 new_while_op.body().front().eraseArguments(items.getSecond());
391 new_while_op.cond().front().eraseArguments(items.getSecond());
392 while_op->erase();
393 } else {
394 llvm::BitVector erase_indices(items.getFirst()->getNumOperands());
395 for (auto operand_index : items.getSecond()) {
396 erase_indices.set(operand_index);
397 }
398 items.getFirst()->eraseOperands(erase_indices);
399 }
400 }
401
402 // Remove initialization of unused variables.
403 if (session_init_func)
404 RemoveVariablesInitializations(variables, session_init_func);
405
406 // Remove the unused VarHandleOp.
407 for (auto var_handle_op : variables) {
408 if (var_handle_op) var_handle_op->erase();
409 }
410 }
411 } // namespace
412
CreateFreezeVariablesPass(tensorflow::Session * session)413 std::unique_ptr<OperationPass<ModuleOp>> CreateFreezeVariablesPass(
414 tensorflow::Session* session) {
415 return std::make_unique<FreezeVariablesPass>(session);
416 }
417
418 } // namespace tf_saved_model
419 } // namespace mlir
420