1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <algorithm>
17 #include <vector>
18
19 #include "mlir/IR/Builders.h" // from @llvm-project
20 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
21 #include "mlir/IR/UseDefLists.h" // from @llvm-project
22 #include "mlir/Pass/Pass.h" // from @llvm-project
23 #include "mlir/Support/LLVM.h" // from @llvm-project
24 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
25 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
26
27 namespace mlir {
28 namespace tf_saved_model {
29 namespace {
30
// This pass replaces a function's bound inputs that refer to global tensors
// (read through tf.ReadVariableOp ops) with tf.Const ops inside the function's
// body. If the pass runs successfully, the resulting IR is guaranteed to:
//
// 1. Not contain any tf_saved_model.global_tensor ops.
// 2. Not contain any tf_saved_model.bound_input arg attrs on tf_saved_model
//    exported functions.
//
// Otherwise, the pass fails. (When the `allow-mutable-tensors` option is set,
// mutable global tensors and the bound inputs referring to them are left in
// place, so these guarantees then only cover immutable global tensors.)
//
// The pass has this contract so that, once it succeeds, we know the IR is in
// the correct form for inference backends (like lite) that do not support
// resources/variables. This contract also ensures that the pass lowers from
// saved model to pure TF; hence it fails if it cannot lower.
44 struct FreezeGlobalTensorsPass
45 : public PassWrapper<FreezeGlobalTensorsPass, OperationPass<ModuleOp>> {
46 FreezeGlobalTensorsPass() = default;
47
FreezeGlobalTensorsPassmlir::tf_saved_model::__anon85f56c7b0111::FreezeGlobalTensorsPass48 explicit FreezeGlobalTensorsPass(bool allow_mutable_tensors) {
49 this->allow_mutable_tensors = allow_mutable_tensors;
50 }
FreezeGlobalTensorsPassmlir::tf_saved_model::__anon85f56c7b0111::FreezeGlobalTensorsPass51 FreezeGlobalTensorsPass(const FreezeGlobalTensorsPass& pass) {}
52
53 void runOnOperation() override;
54
55 private:
56 // Force a specified data format for all layout sensitive operations.
57 Option<bool> allow_mutable_tensors{
58 *this, "allow-mutable-tensors",
59 llvm::cl::desc("Allows mutable tensors to be in the graph. Default is "
60 "false which means only immutable are allowed.")};
61 };
62
runOnOperation()63 void FreezeGlobalTensorsPass::runOnOperation() {
64 auto module = getOperation();
65 if (!tf_saved_model::HasTfSavedModelSemantics(module)) {
66 return;
67 }
68 SymbolTable symbol_table(module);
69 DenseSet<Operation*> frozen_global_tensors;
70
71 for (auto func : module.getOps<FuncOp>()) {
72 SmallVector<unsigned, 4> args_to_erase;
73 OpBuilder builder(func.getBody());
74
75 for (int i = 0, e = func.getNumArguments(); i < e; ++i) {
76 SmallVector<TF::ReadVariableOp, 4> read_variable_ops_to_erase;
77 auto global_tensor =
78 LookupBoundInputOfType<GlobalTensorOp>(func, i, symbol_table);
79
80 if (!global_tensor) continue;
81
82 // This pass assumes that all global tensors as immutable (e.g. by a
83 // previous optimize global tensors pass). If not, this pass has to fail
84 // since it cannot perform one of its goals.
85 if (global_tensor.is_mutable()) {
86 if (allow_mutable_tensors) continue;
87 global_tensor.emitError()
88 << "is not immutable, try removing mutable variables in your model "
89 "since mutable variables are currently not supported through "
90 "this converter";
91 return signalPassFailure();
92 }
93 frozen_global_tensors.insert(global_tensor);
94
95 auto arg = func.getArgument(i);
96 for (auto user : arg.getUsers()) {
97 if (auto read_op = llvm::dyn_cast<TF::ReadVariableOp>(user)) {
98 // Collect all read variable ops so that all its uses can be replaced
99 // with the tf.constant corresponding to the global tensor op.
100 read_variable_ops_to_erase.push_back(read_op);
101 } else {
102 // Current assumption is all users are tf.ReadVariableOp. Need to
103 // expand this to handle control flow and call ops.
104 user->emitError() << "could not rewrite use of immutable bound input";
105 return signalPassFailure();
106 }
107 }
108
109 // Replace the arg with a tf.Const op in the function body.
110 builder.setInsertionPointToStart(&func.getBody().front());
111 auto const_op = builder.create<TF::ConstOp>(global_tensor.getLoc(),
112 global_tensor.value());
113 args_to_erase.push_back(i);
114 for (auto read_op : read_variable_ops_to_erase) {
115 read_op.getResult().replaceAllUsesWith(const_op.getResult());
116 read_op.erase();
117 }
118 }
119 func.eraseArguments(args_to_erase);
120 }
121 // Erase all global tensors that were frozen.
122 for (auto global_tensor : frozen_global_tensors) {
123 global_tensor->erase();
124 }
125
126 if (!allow_mutable_tensors && !module.getOps<GlobalTensorOp>().empty()) {
127 module.emitError() << "could not freeze all global tensors in the module";
128 return signalPassFailure();
129 }
130 }
131
132 } // namespace
133
134 // For "opt" to pick up this pass.
135 static PassRegistration<FreezeGlobalTensorsPass> pass(
136 "tf-saved-model-freeze-global-tensors",
137 "Freeze tf_saved_model.global_tensor's in func bodies.");
138
CreateFreezeGlobalTensorsPass(bool allow_mutable_tensors)139 std::unique_ptr<OperationPass<ModuleOp>> CreateFreezeGlobalTensorsPass(
140 bool allow_mutable_tensors) {
141 return std::make_unique<FreezeGlobalTensorsPass>(allow_mutable_tensors);
142 }
143
144 } // namespace tf_saved_model
145 } // namespace mlir
146