/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <utility>
#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/session_utils.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/public/session.h"

namespace mlir {
namespace tf_saved_model {
namespace {

// Builds and returns an ElementsAttr holding the data in 'tensor'.
ElementsAttr GetTensorValueAsElementsAttr(const tensorflow::Tensor& tensor,
                                          OpBuilder builder) {
  tensorflow::StatusOr<ElementsAttr> tensor_attr_or =
      tensorflow::ConvertTensor(tensor, &builder);
  if (!tensor_attr_or.ok()) return nullptr;
  return tensor_attr_or.ValueOrDie();
}

// Creates a ConstantOp holding 'tensor_elements'.
mlir::ConstantOp GetConstOpFromElementsAttr(ElementsAttr tensor_elements,
                                            OpBuilder builder, Location loc) {
  return builder.create<mlir::ConstantOp>(loc, tensor_elements.getType(),
                                          tensor_elements);
}

// Returns an ElementsAttr with the value held by 'resource_tensor'.
ElementsAttr GetTensorValueAsElementsAttr(
    TF::VarHandleOp var_handle_op, const tensorflow::Tensor& resource_tensor,
    const tensorflow::DeviceMgr* mgr, OpBuilder builder) {
  if (resource_tensor.dtype() != tensorflow::DT_RESOURCE) {
    return GetTensorValueAsElementsAttr(resource_tensor, builder);
  }

  auto handle = resource_tensor.scalar<tensorflow::ResourceHandle>()();
  auto* var_ptr = tf_saved_model::GetVariableFromSession(var_handle_op,
                                                         handle.device(), mgr);
  if (!var_ptr) {
    return nullptr;
  }
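  // Wrap the raw pointer so the variable's reference is released when 'var'
  // goes out of scope.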
  tensorflow::core::RefCountPtr<tensorflow::Var> var(var_ptr);
  auto* tensor = var_ptr->tensor();

  return GetTensorValueAsElementsAttr(*tensor, builder);
}

// Replaces all uses of the result of 'read_variable_op' with a constant op
// holding 'value'.
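// For example (an illustrative IR sketch), the result of:
//   %v = "tf.ReadVariableOp"(%handle) : (tensor<!tf.resource>) -> tensor<f32>
// has its uses rewired to a freshly created:
//   %cst = constant dense<...> : tensor<f32>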
void PropagateUsage(TF::ReadVariableOp read_variable_op, ElementsAttr value) {
  OpBuilder builder(read_variable_op);
  read_variable_op->getResult(0).replaceAllUsesWith(
      GetConstOpFromElementsAttr(value, builder, read_variable_op->getLoc()));
}

// Propagates the usage of a resource across the graph, where 'user_op'
// receives the resource at operand 'argument_index'. The resource should be
// replaced by 'value'.
// Output params:
// - work_list: updated with new regions to process that are called by
//   'user_op';
// - arguments_to_erase: captures the updates to apply to the graph - which
//   arguments to remove from each op.
void PropagateUsage(
    Operation* user_op, int argument_index, ElementsAttr value,
    llvm::SmallVector<std::pair<Region*, int>, 4>* work_list,
    llvm::DenseMap<Operation*, llvm::SmallVector<unsigned int, 4>>*
        arguments_to_erase) {
  if (auto read_variable_op = dyn_cast<TF::ReadVariableOp>(user_op)) {
    PropagateUsage(read_variable_op, value);
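    // Creating the map entry with an empty index list marks 'read_variable_op'
    // for erasure without removing any operands.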
    (*arguments_to_erase)[read_variable_op];
  } else if (auto call = dyn_cast<CallOpInterface>(user_op)) {
    if (auto func = dyn_cast<FuncOp>(call.resolveCallable())) {
      (*arguments_to_erase)[func].push_back(argument_index);
      work_list->push_back(std::make_pair(&func.getRegion(), argument_index));
    }
    (*arguments_to_erase)[call].push_back(argument_index);
  } else if (auto if_op = dyn_cast<TF::IfOp>(user_op)) {
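    // The If op's first operand is the condition; it is not forwarded to the
    // branch functions, so branch argument indices are shifted down by one.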
    for (auto callee : {if_op.then_function(), if_op.else_function()}) {
      (*arguments_to_erase)[callee].push_back(argument_index - 1);
      work_list->push_back(std::make_pair(&callee.body(), argument_index - 1));
    }
    (*arguments_to_erase)[if_op].push_back(argument_index);
  } else if (auto if_op = dyn_cast<TF::IfRegionOp>(user_op)) {
    for (auto callee : {&if_op.then_branch(), &if_op.else_branch()}) {
      work_list->push_back(std::make_pair(callee, argument_index));
    }
    (*arguments_to_erase)[if_op].push_back(argument_index);
  } else if (auto while_op = dyn_cast<TF::WhileOp>(user_op)) {
    for (auto callee : {while_op.cond_function(), while_op.body_function()}) {
      (*arguments_to_erase)[callee].push_back(argument_index);
      work_list->push_back(std::make_pair(&callee.body(), argument_index));
    }
    (*arguments_to_erase)[while_op].push_back(argument_index);
  } else if (auto while_op = dyn_cast<TF::WhileRegionOp>(user_op)) {
    for (auto callee : {&while_op.cond(), &while_op.body()}) {
      work_list->push_back(std::make_pair(callee, argument_index));
    }
    (*arguments_to_erase)[while_op].push_back(argument_index);
  }
}

// An overload that takes a region.
void PropagateUsage(
    Region* region, ElementsAttr value, int argument_index,
    llvm::SmallVector<std::pair<Region*, int>, 4>* work_list,
    llvm::DenseMap<Operation*, llvm::SmallVector<unsigned int, 4>>*
        arguments_to_erase) {
  auto arg = region->getArgument(argument_index);
  for (auto& usage : arg.getUses()) {
    auto* user_op = usage.getOwner();
    int operand_index = usage.getOperandNumber();
    PropagateUsage(user_op, operand_index, value, work_list,
                   arguments_to_erase);
  }
}

// Traces the usage of 'var_handle_op' and replaces it with the constant
// 'value'.
// All operand updates are captured in 'arguments_to_erase'.
void ReplaceVarWithConstant(
    TF::VarHandleOp var_handle_op, ElementsAttr value,
    llvm::DenseMap<Operation*, llvm::SmallVector<unsigned int, 4>>*
        arguments_to_erase) {
  llvm::SmallVector<std::pair<Region*, int>, 4> work_list;
  for (auto& usage : var_handle_op->getUses()) {
    auto* user_op = usage.getOwner();
    int operand_index = usage.getOperandNumber();
    PropagateUsage(user_op, operand_index, value, &work_list,
                   arguments_to_erase);
  }
  // Container marking visited regions to avoid infinite loops.
  llvm::DenseSet<std::pair<Region*, int>> visited;
  while (!work_list.empty()) {
    auto work_item = work_list.pop_back_val();
    if (visited.contains(work_item)) continue;
    PropagateUsage(work_item.first, value, work_item.second, &work_list,
                   arguments_to_erase);
    visited.insert(work_item);
  }
}

// A pass that tries to freeze (constant fold) read-only variables in the
// graph. The pass analyzes variable usage and, if a variable is immutable,
// replaces it with a constant tensor holding its value.
// Note: this pass currently only handles variables that are initialized in
// the session init function. Expanding to other uses is a TODO.
class FreezeVariablesPass
    : public PassWrapper<FreezeVariablesPass, OperationPass<ModuleOp>> {
 public:
  // If no session is provided (i.e. it is null), the pass is a no-op.
  explicit FreezeVariablesPass(tensorflow::Session* session = nullptr)
      : session_(session) {}

  void runOnOperation() override;

 private:
  tensorflow::Session* session_;
};

// Helper that returns the FuncOp of the SessionInit function, which is called
// to initialize all resources.
// Returns nullptr if no such function is found.
FuncOp GetSessionInitializerFunc(ModuleOp module) {
  auto session_init_op = tf_saved_model::GetSessionInitializerOp(module);
  SymbolTable symbol_table(module);
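  // Only the first initializer function listed on the session initializer op
  // is considered here.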
  if (session_init_op && !session_init_op.initializers().empty()) {
    FuncOp init_func_op = symbol_table.lookup<mlir::FuncOp>(
        session_init_op.initializers()[0].cast<FlatSymbolRefAttr>().getValue());
    return init_func_op;
  }
  return nullptr;
}

// Returns a (device, container, shared_name) key identifying a resource.
std::tuple<llvm::StringRef, llvm::StringRef, llvm::StringRef> GetResourceKey(
    Operation* op) {
  llvm::StringRef device;
  if (auto attr = op->getAttrOfType<mlir::StringAttr>("device")) {
    device = attr.getValue();
  }

  llvm::StringRef container;
  if (auto attr = op->getAttrOfType<mlir::StringAttr>("container")) {
    container = attr.getValue();
  }

  llvm::StringRef shared_name;
  if (auto attr = op->getAttrOfType<mlir::StringAttr>("shared_name")) {
    shared_name = attr.getValue();
  }

  return std::tuple<llvm::StringRef, llvm::StringRef, llvm::StringRef>{
      device, container, shared_name};
}

// Removes the initialization of the variables in 'var_handle_ops' from
// the session init function 'session_init_func'.
void RemoveVariablesInitializations(
    const llvm::SmallVector<TF::VarHandleOp, 4>& var_handle_ops,
    FuncOp session_init_func) {
  // We identify the variables using the (device, container, shared_name) of
  // the resource. Capture them here and use them to identify the now-useless
  // initializations.
  llvm::SetVector<std::tuple<llvm::StringRef, llvm::StringRef, llvm::StringRef>>
      variables;
  for (auto var_handle_op : var_handle_ops)
    variables.insert(GetResourceKey(var_handle_op));
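  // Seed the work list with the VarHandleOps in the init function that refer
  // to the frozen variables.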
  llvm::SmallVector<Operation*, 4> work_list;
  for (auto var_handle_op : session_init_func.getOps<TF::VarHandleOp>()) {
    if (variables.count(GetResourceKey(var_handle_op)))
      work_list.push_back(var_handle_op);
  }

  // Capture list of ops to be erased by traversing usage starting from
  // the VarHandle ops.
  llvm::SetVector<Operation*> erase_list;
  while (!work_list.empty()) {
    auto* operation = work_list.pop_back_val();
    erase_list.insert(operation);
    for (auto& use : operation->getUses()) {
      if (erase_list.count(use.getOwner())) continue;
      work_list.push_back(use.getOwner());
    }
  }

  for (auto* op : erase_list) {
    op->dropAllUses();
    op->erase();
  }
}

// Updates the terminator operands of 'func' after removing the arguments
// specified in 'arguments_to_erase'.
template <typename T>
void UpdateTerminatorArguments(
    T& func, const llvm::SmallVector<unsigned, 4>& arguments_to_erase) {
  llvm::BitVector erase_indices(func.front().getTerminator()->getNumOperands());
  for (auto arg_index : arguments_to_erase) {
    auto argument = func.getArgument(arg_index);
    for (auto& use : argument.getUses()) {
      if (llvm::isa<ReturnOp, TF::YieldOp>(use.getOwner())) {
        erase_indices.set(use.getOperandNumber());
      }
    }
    func.getArgument(arg_index).dropAllUses();
  }
  if (llvm::isa<ReturnOp, TF::YieldOp>(func.front().getTerminator())) {
    func.front().getTerminator()->eraseOperands(erase_indices);
  }
}

// Rebuilds 'while_op' with the arguments listed in 'arguments_to_erase'
// removed, and returns the new op.
template <typename T>
T GetUpdatedWhileOp(T while_op,
                    const llvm::SmallVector<unsigned, 4>& arguments_to_erase) {
  OpBuilder builder(while_op);
  llvm::SmallVector<Type, 4> new_operand_types;
  llvm::SmallVector<Value> new_operands;
  auto operands = while_op->getOperands();
  const int num_operands = while_op->getNumOperands();
  llvm::BitVector skip_indices(num_operands);
  for (int i : arguments_to_erase) skip_indices.set(i);
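  // Keep only the operands whose indices are not marked for removal.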
  for (int i = 0; i < num_operands; ++i) {
    if (!skip_indices.test(i)) {
      new_operand_types.emplace_back(operands[i].getType());
      new_operands.emplace_back(operands[i]);
    }
  }
  auto new_while_op = builder.create<T>(while_op->getLoc(), new_operand_types,
                                        new_operands, while_op->getAttrs());
  int new_result_index = 0;
  for (int i = 0; i < num_operands; ++i) {
    if (!skip_indices.test(i)) {
      while_op->getResult(i).replaceAllUsesWith(
          new_while_op->getResult(new_result_index++));
    }
  }
  return new_while_op;
}

void FreezeVariablesPass::runOnOperation() {
  if (!session_) return;
  ModuleOp module = getOperation();
  const tensorflow::DeviceMgr* mgr = nullptr;
  auto status = session_->LocalDeviceManager(&mgr);
  if (!status.ok()) {
    module->emitError("failed to fetch device manager: " +
                      status.error_message());
    return signalPassFailure();
  }

  FuncOp session_init_func = GetSessionInitializerFunc(module);

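  // With skip_session_init=true the analyzer does not count writes inside the
  // session init function; those initializations are removed below anyway.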
  TF::ResourceAnalyzer analyzer(module, /*skip_session_init=*/true);
  llvm::SmallVector<TF::VarHandleOp, 4> variables;
  // Capture the list of all read-only variables.
  for (auto func : module.getOps<FuncOp>()) {
    if (func == session_init_func) continue;
    for (auto var_handle_op : func.getOps<TF::VarHandleOp>()) {
      if (!analyzer.IsPotentiallyWritten(var_handle_op.resource())) {
        variables.push_back(var_handle_op);
      }
    }
  }

  // Fetch the values to replace the VarHandleOps with.
  auto resource_tensors_or =
      tf_saved_model::GetResourcesFromSession(variables, session_);
  if (!resource_tensors_or.ok()) {
    module->emitError(resource_tensors_or.status().message().data());
    return signalPassFailure();
  }

  auto* context = module.getContext();
  OpBuilder builder(context);
  // Note: we can't modify the graph while navigating through it, as erasing
  // invalidates pointers.
  // So instead we capture all the updates in the map below, and then process
  // them afterwards.

  // Container holding all update actions on ops.
  // Key: operation to update.
  // Value: optional list of argument indices to delete from this op.
  llvm::DenseMap<Operation*, llvm::SmallVector<unsigned int, 4>>
      arguments_to_erase;
  for (auto variable_value_pair :
       llvm::zip(variables, resource_tensors_or.value())) {
    auto var_handle_op = std::get<0>(variable_value_pair);
    builder.setInsertionPointAfterValue(var_handle_op);
    auto elements_attr = GetTensorValueAsElementsAttr(
        var_handle_op, std::get<1>(variable_value_pair), mgr, builder);
    ReplaceVarWithConstant(var_handle_op, elements_attr, &arguments_to_erase);
  }

  // All updates to the different ops are captured in 'arguments_to_erase'.
  // Now loop over them and update each op according to its type.
  for (auto& items : arguments_to_erase) {
    if (auto func = dyn_cast<FuncOp>(items.getFirst())) {
      UpdateTerminatorArguments(func, items.getSecond());
      func.eraseArguments(items.getSecond());
    } else if (auto read_var = dyn_cast<TF::ReadVariableOp>(items.getFirst())) {
      // The read variable was already replaced by a constant op; just remove
      // the op.
      read_var->erase();
    } else if (auto while_op = dyn_cast<TF::WhileOp>(items.getFirst())) {
      auto new_while_op =
          GetUpdatedWhileOp<TF::WhileOp>(while_op, items.getSecond());
      new_while_op.body_function().eraseResults(items.getSecond());
      while_op->erase();
    } else if (auto while_op = dyn_cast<TF::WhileRegionOp>(items.getFirst())) {
      auto new_while_op = GetUpdatedWhileOp(while_op, items.getSecond());
      new_while_op.cond().takeBody(while_op.cond());
      new_while_op.body().takeBody(while_op.body());
      UpdateTerminatorArguments(new_while_op.body(), items.getSecond());
      new_while_op.body().front().eraseArguments(items.getSecond());
      new_while_op.cond().front().eraseArguments(items.getSecond());
      while_op->erase();
    } else {
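      // Any other op: simply drop the operands that carried the resource.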
      llvm::BitVector erase_indices(items.getFirst()->getNumOperands());
      for (auto operand_index : items.getSecond()) {
        erase_indices.set(operand_index);
      }
      items.getFirst()->eraseOperands(erase_indices);
    }
  }

  // Remove the initialization of the now-unused variables.
  if (session_init_func)
    RemoveVariablesInitializations(variables, session_init_func);

  // Remove the unused VarHandleOps.
  for (auto var_handle_op : variables) {
    if (var_handle_op) var_handle_op->erase();
  }
}
}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateFreezeVariablesPass(
    tensorflow::Session* session) {
  return std::make_unique<FreezeVariablesPass>(session);
}
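
// Example usage (a sketch; assumes the caller owns a live
// tensorflow::Session):
//   mlir::PassManager pm(module.getContext());
//   pm.addPass(mlir::tf_saved_model::CreateFreezeVariablesPass(session));
//   (void)pm.run(module);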

}  // namespace tf_saved_model
}  // namespace mlir