• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <vector>
18 
19 #include "mlir/IR/Builders.h"  // from @llvm-project
20 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
21 #include "mlir/IR/UseDefLists.h"  // from @llvm-project
22 #include "mlir/Pass/Pass.h"  // from @llvm-project
23 #include "mlir/Support/LLVM.h"  // from @llvm-project
24 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
25 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
26 #include "tensorflow/compiler/mlir/tensorflow/transforms/savedmodel_passes_detail.h"
27 
28 namespace mlir {
29 namespace tf_saved_model {
30 namespace {
31 
32 struct FreezeGlobalTensorsPass
33     : public FreezeGlobalTensorsPassBase<FreezeGlobalTensorsPass> {
FreezeGlobalTensorsPassmlir::tf_saved_model::__anonba2acfa00111::FreezeGlobalTensorsPass34   explicit FreezeGlobalTensorsPass(bool allow_mutable_tensors) {
35     this->allow_mutable_tensors = allow_mutable_tensors;
36   }
37   void runOnOperation() override;
38 };
39 
runOnOperation()40 void FreezeGlobalTensorsPass::runOnOperation() {
41   auto module = getOperation();
42   if (!tf_saved_model::HasTfSavedModelSemantics(module)) {
43     return;
44   }
45   SymbolTable symbol_table(module);
46   DenseSet<Operation*> frozen_global_tensors;
47 
48   for (auto func : module.getOps<FuncOp>()) {
49     SmallVector<unsigned, 4> args_to_erase;
50     OpBuilder builder(func.getBody());
51 
52     for (int i = 0, e = func.getNumArguments(); i < e; ++i) {
53       SmallVector<TF::ReadVariableOp, 4> read_variable_ops_to_erase;
54       auto global_tensor =
55           LookupBoundInputOfType<GlobalTensorOp>(func, i, symbol_table);
56 
57       if (!global_tensor) continue;
58 
59       // This pass assumes that all global tensors as immutable (e.g. by a
60       // previous optimize global tensors pass). If not, this pass has to fail
61       // since it cannot perform one of its goals.
62       if (global_tensor.is_mutable()) {
63         if (allow_mutable_tensors) continue;
64         global_tensor.emitError()
65             << "is not immutable, try removing mutable variables in your model "
66                "since mutable variables are currently not supported through "
67                "this converter";
68         return signalPassFailure();
69       }
70       frozen_global_tensors.insert(global_tensor);
71 
72       auto arg = func.getArgument(i);
73       for (auto user : arg.getUsers()) {
74         if (auto read_op = llvm::dyn_cast<TF::ReadVariableOp>(user)) {
75           // Collect all read variable ops so that all its uses can be replaced
76           // with the tf.constant corresponding to the global tensor op.
77           read_variable_ops_to_erase.push_back(read_op);
78         } else {
79           // Current assumption is all users are tf.ReadVariableOp. Need to
80           // expand this to handle control flow and call ops.
81           user->emitError() << "could not rewrite use of immutable bound input";
82           return signalPassFailure();
83         }
84       }
85 
86       // Replace the arg with a tf.Const op in the function body.
87       builder.setInsertionPointToStart(&func.getBody().front());
88       auto const_op = builder.create<TF::ConstOp>(global_tensor.getLoc(),
89                                                   global_tensor.value());
90       args_to_erase.push_back(i);
91       for (auto read_op : read_variable_ops_to_erase) {
92         read_op.getResult().replaceAllUsesWith(const_op.getResult());
93         read_op.erase();
94       }
95     }
96     func.eraseArguments(args_to_erase);
97   }
98   // Erase all global tensors that were frozen.
99   for (auto global_tensor : frozen_global_tensors) {
100     global_tensor->erase();
101   }
102 
103   if (!allow_mutable_tensors && !module.getOps<GlobalTensorOp>().empty()) {
104     module.emitError() << "could not freeze all global tensors in the module";
105     return signalPassFailure();
106   }
107 }
108 
109 }  // namespace
110 
CreateFreezeGlobalTensorsPass(bool allow_mutable_tensors)111 std::unique_ptr<OperationPass<ModuleOp>> CreateFreezeGlobalTensorsPass(
112     bool allow_mutable_tensors) {
113   return std::make_unique<FreezeGlobalTensorsPass>(allow_mutable_tensors);
114 }
115 
116 }  // namespace tf_saved_model
117 }  // namespace mlir
118