• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <vector>
18 
19 #include "mlir/IR/Builders.h"  // from @llvm-project
20 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
21 #include "mlir/IR/UseDefLists.h"  // from @llvm-project
22 #include "mlir/Pass/Pass.h"  // from @llvm-project
23 #include "mlir/Support/LLVM.h"  // from @llvm-project
24 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
25 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
26 
27 namespace mlir {
28 namespace tf_saved_model {
29 namespace {
30 
31 // This pass will replace a func's bound inputs which are bound to
32 // tf.ReadVariable ops global tensors with tf.Const ops inside the func's body.
33 // If this pass runs successfully, the resultant IR will be guaranteed to:
34 //
35 // 1. Not contain any tf_saved_model.global_tensor ops
36 // 2. Not contain any  tf_saved_model.bound_input arg attrs on tf_saved_model
37 // exported functions
38 // Else, the pass fails.
39 //
40 // The reason this pass has this contract is so that once this succeeds, we know
41 // the IR is in correct form for inference backends (like lite) that do not
42 // support resources/variables . Further, this contract also ensures that this
43 // pass lowers from saved model to pure TF. Hence it fails, if it cannot lower.
44 struct FreezeGlobalTensorsPass
45     : public PassWrapper<FreezeGlobalTensorsPass, OperationPass<ModuleOp>> {
46   FreezeGlobalTensorsPass() = default;
47 
FreezeGlobalTensorsPassmlir::tf_saved_model::__anon85f56c7b0111::FreezeGlobalTensorsPass48   explicit FreezeGlobalTensorsPass(bool allow_mutable_tensors) {
49     this->allow_mutable_tensors = allow_mutable_tensors;
50   }
FreezeGlobalTensorsPassmlir::tf_saved_model::__anon85f56c7b0111::FreezeGlobalTensorsPass51   FreezeGlobalTensorsPass(const FreezeGlobalTensorsPass& pass) {}
52 
53   void runOnOperation() override;
54 
55  private:
56   // Force a specified data format for all layout sensitive operations.
57   Option<bool> allow_mutable_tensors{
58       *this, "allow-mutable-tensors",
59       llvm::cl::desc("Allows mutable tensors to be in the graph. Default is "
60                      "false which means only immutable are allowed.")};
61 };
62 
runOnOperation()63 void FreezeGlobalTensorsPass::runOnOperation() {
64   auto module = getOperation();
65   if (!tf_saved_model::HasTfSavedModelSemantics(module)) {
66     return;
67   }
68   SymbolTable symbol_table(module);
69   DenseSet<Operation*> frozen_global_tensors;
70 
71   for (auto func : module.getOps<FuncOp>()) {
72     SmallVector<unsigned, 4> args_to_erase;
73     OpBuilder builder(func.getBody());
74 
75     for (int i = 0, e = func.getNumArguments(); i < e; ++i) {
76       SmallVector<TF::ReadVariableOp, 4> read_variable_ops_to_erase;
77       auto global_tensor =
78           LookupBoundInputOfType<GlobalTensorOp>(func, i, symbol_table);
79 
80       if (!global_tensor) continue;
81 
82       // This pass assumes that all global tensors as immutable (e.g. by a
83       // previous optimize global tensors pass). If not, this pass has to fail
84       // since it cannot perform one of its goals.
85       if (global_tensor.is_mutable()) {
86         if (allow_mutable_tensors) continue;
87         global_tensor.emitError()
88             << "is not immutable, try removing mutable variables in your model "
89                "since mutable variables are currently not supported through "
90                "this converter";
91         return signalPassFailure();
92       }
93       frozen_global_tensors.insert(global_tensor);
94 
95       auto arg = func.getArgument(i);
96       for (auto user : arg.getUsers()) {
97         if (auto read_op = llvm::dyn_cast<TF::ReadVariableOp>(user)) {
98           // Collect all read variable ops so that all its uses can be replaced
99           // with the tf.constant corresponding to the global tensor op.
100           read_variable_ops_to_erase.push_back(read_op);
101         } else {
102           // Current assumption is all users are tf.ReadVariableOp. Need to
103           // expand this to handle control flow and call ops.
104           user->emitError() << "could not rewrite use of immutable bound input";
105           return signalPassFailure();
106         }
107       }
108 
109       // Replace the arg with a tf.Const op in the function body.
110       builder.setInsertionPointToStart(&func.getBody().front());
111       auto const_op = builder.create<TF::ConstOp>(global_tensor.getLoc(),
112                                                   global_tensor.value());
113       args_to_erase.push_back(i);
114       for (auto read_op : read_variable_ops_to_erase) {
115         read_op.getResult().replaceAllUsesWith(const_op.getResult());
116         read_op.erase();
117       }
118     }
119     func.eraseArguments(args_to_erase);
120   }
121   // Erase all global tensors that were frozen.
122   for (auto global_tensor : frozen_global_tensors) {
123     global_tensor->erase();
124   }
125 
126   if (!allow_mutable_tensors && !module.getOps<GlobalTensorOp>().empty()) {
127     module.emitError() << "could not freeze all global tensors in the module";
128     return signalPassFailure();
129   }
130 }
131 
132 }  // namespace
133 
134 // For "opt" to pick up this pass.
135 static PassRegistration<FreezeGlobalTensorsPass> pass(
136     "tf-saved-model-freeze-global-tensors",
137     "Freeze tf_saved_model.global_tensor's in func bodies.");
138 
CreateFreezeGlobalTensorsPass(bool allow_mutable_tensors)139 std::unique_ptr<OperationPass<ModuleOp>> CreateFreezeGlobalTensorsPass(
140     bool allow_mutable_tensors) {
141   return std::make_unique<FreezeGlobalTensorsPass>(allow_mutable_tensors);
142 }
143 
144 }  // namespace tf_saved_model
145 }  // namespace mlir
146