• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/UseDefLists.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
#include "tensorflow/core/platform/path.h"
29 
30 namespace mlir {
31 namespace tf_saved_model {
32 namespace {
33 
34 // This pass will replace a func's saved model asset bound inputs which are
35 // bound to tf.InitializeTableFromTextFileV2Op ops with tf.Const ops inside the
36 // func's body.
37 struct FreezeAssetsPass
38     : public PassWrapper<FreezeAssetsPass, OperationPass<ModuleOp>> {
39   FreezeAssetsPass() = default;
40 
FreezeAssetsPassmlir::tf_saved_model::__anon0e2084170111::FreezeAssetsPass41   FreezeAssetsPass(const FreezeAssetsPass& pass) {}
FreezeAssetsPassmlir::tf_saved_model::__anon0e2084170111::FreezeAssetsPass42   explicit FreezeAssetsPass(std::string saved_model_dir) {
43     this->saved_model_dir = saved_model_dir;
44   }
45 
getArgumentmlir::tf_saved_model::__anon0e2084170111::FreezeAssetsPass46   StringRef getArgument() const final { return "tf-saved-model-freeze-assets"; }
47 
getDescriptionmlir::tf_saved_model::__anon0e2084170111::FreezeAssetsPass48   StringRef getDescription() const final {
49     return "Freeze tf_saved_model.asset's in func bodies.";
50   }
51 
52   void runOnOperation() override;
53 
54  private:
55   std::string saved_model_dir;
56 };
57 
runOnOperation()58 void FreezeAssetsPass::runOnOperation() {
59   auto module = getOperation();
60   if (!tf_saved_model::HasTfSavedModelSemantics(module)) {
61     return;
62   }
63   SymbolTable symbol_table(module);
64 
65   for (auto func : module.getOps<FuncOp>()) {
66     SmallVector<unsigned, 4> args_to_erase;
67     OpBuilder builder(func.getBody());
68 
69     for (int i = 0, e = func.getNumArguments(); i < e; ++i) {
70       SmallVector<TF::InitializeTableFromTextFileV2Op, 4>
71           init_table_from_text_file_ops_to_erase;
72       auto asset = LookupBoundInputOfType<AssetOp>(func, i, symbol_table);
73 
74       if (!asset) continue;
75 
76       auto arg = func.getArgument(i);
77       bool arg_is_deletable = true;
78       for (auto user : arg.getUsers()) {
79         if (auto read_op =
80                 llvm::dyn_cast<TF::InitializeTableFromTextFileV2Op>(user)) {
81           init_table_from_text_file_ops_to_erase.push_back(read_op);
82         } else {
83           arg_is_deletable = false;
84           continue;
85         }
86       }
87       if (arg_is_deletable) {
88         args_to_erase.push_back(i);
89       }
90 
91       // Replace the arg with a tf.Const op in the function body.
92       builder.setInsertionPointToStart(&func.getBody().front());
93 
94       std::string asset_filename = asset.filename().str();
95       std::string filename =
96           tensorflow::io::JoinPath(saved_model_dir, asset_filename);
97       ShapedType shaped_type =
98           RankedTensorType::get({1}, TF::StringType::get(builder.getContext()));
99       auto const_op = builder.create<TF::ConstOp>(
100           asset.getLoc(),
101           DenseStringElementsAttr::get(shaped_type, {filename}));
102       for (auto init_op : init_table_from_text_file_ops_to_erase) {
103         // Replace the InitializeTableFromTextFileV2Op to use the saved model's
104         // asset filepath.
105         builder.setInsertionPoint(init_op);
106         builder.create<TF::InitializeTableFromTextFileV2Op>(
107             init_op.getLoc(), init_op.table_handle(), const_op.getResult(),
108             init_op.key_index(), init_op.value_index(), init_op.vocab_size(),
109             init_op.delimiter());
110         init_op.erase();
111       }
112     }
113     func.eraseArguments(args_to_erase);
114   }
115 }
116 
117 }  // namespace
118 
119 // For "opt" to pick up this pass.
120 static PassRegistration<FreezeAssetsPass> freeze_assets_pass;
121 
CreateFreezeAssetsPass(std::string saved_model_dir)122 std::unique_ptr<OperationPass<ModuleOp>> CreateFreezeAssetsPass(
123     std::string saved_model_dir) {
124   return std::make_unique<FreezeAssetsPass>(saved_model_dir);
125 }
126 
127 }  // namespace tf_saved_model
128 }  // namespace mlir
129