1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/UseDefLists.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
#include "tensorflow/core/platform/path.h"
29
30 namespace mlir {
31 namespace tf_saved_model {
32 namespace {
33
34 // This pass will replace a func's saved model asset bound inputs which are
35 // bound to tf.InitializeTableFromTextFileV2Op ops with tf.Const ops inside the
36 // func's body.
37 struct FreezeAssetsPass
38 : public PassWrapper<FreezeAssetsPass, OperationPass<ModuleOp>> {
39 FreezeAssetsPass() = default;
40
FreezeAssetsPassmlir::tf_saved_model::__anon0e2084170111::FreezeAssetsPass41 FreezeAssetsPass(const FreezeAssetsPass& pass) {}
FreezeAssetsPassmlir::tf_saved_model::__anon0e2084170111::FreezeAssetsPass42 explicit FreezeAssetsPass(std::string saved_model_dir) {
43 this->saved_model_dir = saved_model_dir;
44 }
45
getArgumentmlir::tf_saved_model::__anon0e2084170111::FreezeAssetsPass46 StringRef getArgument() const final { return "tf-saved-model-freeze-assets"; }
47
getDescriptionmlir::tf_saved_model::__anon0e2084170111::FreezeAssetsPass48 StringRef getDescription() const final {
49 return "Freeze tf_saved_model.asset's in func bodies.";
50 }
51
52 void runOnOperation() override;
53
54 private:
55 std::string saved_model_dir;
56 };
57
runOnOperation()58 void FreezeAssetsPass::runOnOperation() {
59 auto module = getOperation();
60 if (!tf_saved_model::HasTfSavedModelSemantics(module)) {
61 return;
62 }
63 SymbolTable symbol_table(module);
64
65 for (auto func : module.getOps<FuncOp>()) {
66 SmallVector<unsigned, 4> args_to_erase;
67 OpBuilder builder(func.getBody());
68
69 for (int i = 0, e = func.getNumArguments(); i < e; ++i) {
70 SmallVector<TF::InitializeTableFromTextFileV2Op, 4>
71 init_table_from_text_file_ops_to_erase;
72 auto asset = LookupBoundInputOfType<AssetOp>(func, i, symbol_table);
73
74 if (!asset) continue;
75
76 auto arg = func.getArgument(i);
77 bool arg_is_deletable = true;
78 for (auto user : arg.getUsers()) {
79 if (auto read_op =
80 llvm::dyn_cast<TF::InitializeTableFromTextFileV2Op>(user)) {
81 init_table_from_text_file_ops_to_erase.push_back(read_op);
82 } else {
83 arg_is_deletable = false;
84 continue;
85 }
86 }
87 if (arg_is_deletable) {
88 args_to_erase.push_back(i);
89 }
90
91 // Replace the arg with a tf.Const op in the function body.
92 builder.setInsertionPointToStart(&func.getBody().front());
93
94 std::string asset_filename = asset.filename().str();
95 std::string filename =
96 tensorflow::io::JoinPath(saved_model_dir, asset_filename);
97 ShapedType shaped_type =
98 RankedTensorType::get({1}, TF::StringType::get(builder.getContext()));
99 auto const_op = builder.create<TF::ConstOp>(
100 asset.getLoc(),
101 DenseStringElementsAttr::get(shaped_type, {filename}));
102 for (auto init_op : init_table_from_text_file_ops_to_erase) {
103 // Replace the InitializeTableFromTextFileV2Op to use the saved model's
104 // asset filepath.
105 builder.setInsertionPoint(init_op);
106 builder.create<TF::InitializeTableFromTextFileV2Op>(
107 init_op.getLoc(), init_op.table_handle(), const_op.getResult(),
108 init_op.key_index(), init_op.value_index(), init_op.vocab_size(),
109 init_op.delimiter());
110 init_op.erase();
111 }
112 }
113 func.eraseArguments(args_to_erase);
114 }
115 }
116
117 } // namespace
118
119 // For "opt" to pick up this pass.
120 static PassRegistration<FreezeAssetsPass> freeze_assets_pass;
121
CreateFreezeAssetsPass(std::string saved_model_dir)122 std::unique_ptr<OperationPass<ModuleOp>> CreateFreezeAssetsPass(
123 std::string saved_model_dir) {
124 return std::make_unique<FreezeAssetsPass>(saved_model_dir);
125 }
126
127 } // namespace tf_saved_model
128 } // namespace mlir
129