1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <numeric>
#include <utility>

#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/MemoryBuffer.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/FileUtilities.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/lib/io/path.h"
30
31 namespace mlir {
32 namespace TF {
33 namespace {
34
// Special column indices of tf.lookup.TextFileIndex as they appear in the
// key_index/value_index attributes:
//   WHOLE_LINE  (-2): the entire line is used as the key/value.
//   LINE_NUMBER (-1): the line number (counted from zero) is used.
constexpr int kTextFileIndex_WholeLine = -2;
constexpr int kTextFileIndex_LineNumber = -1;
37
38 // InitTextFileToImportPass converts InitializeTableFromTextFileV2Op to the
39 // corresponding LookupTableImportV2Op if possible.
40 class InitTextFileToImportPass
41 : public mlir::PassWrapper<InitTextFileToImportPass, FunctionPass> {
42 public:
InitTextFileToImportPass()43 InitTextFileToImportPass() {}
InitTextFileToImportPass(const InitTextFileToImportPass &)44 InitTextFileToImportPass(const InitTextFileToImportPass&) {}
InitTextFileToImportPass(std::string saved_model_dir)45 explicit InitTextFileToImportPass(std::string saved_model_dir) {
46 saved_model_dir_ = saved_model_dir;
47 }
48
getArgument() const49 StringRef getArgument() const final { return "tf-init-text-file-to-import"; }
50
getDescription() const51 StringRef getDescription() const final {
52 return "convert InitializeTableFromTextFileV2 ops to LookupTableImportV2Op "
53 "to remove the dependency on asset files";
54 }
55
56 private:
57 void runOnFunction() override;
58
59 Option<std::string> saved_model_dir_{
60 *this, "tf-saved-model-dir",
61 llvm::cl::desc("Directory containing the model exported as a TensorFlow "
62 "SavedModel. If your model is not based on the TensorFlow "
63 "SavedModel, use an empty value."),
64 llvm::cl::init("")};
65 };
66
67 class ConvertInitializeTableFromTextFileV2
68 : public OpRewritePattern<InitializeTableFromTextFileV2Op> {
69 public:
ConvertInitializeTableFromTextFileV2(mlir::MLIRContext * context,StringRef saved_model_dir)70 explicit ConvertInitializeTableFromTextFileV2(mlir::MLIRContext* context,
71 StringRef saved_model_dir)
72 : OpRewritePattern<InitializeTableFromTextFileV2Op>(context),
73 saved_model_dir_(saved_model_dir) {}
74
matchAndRewrite(InitializeTableFromTextFileV2Op op,PatternRewriter & rewriter) const75 LogicalResult matchAndRewrite(InitializeTableFromTextFileV2Op op,
76 PatternRewriter& rewriter) const override {
77 // Now, this pattern matching only supports the following case, which is
78 // commonly used among inference use cases:
79 //
80 // tf.lookup.TextFileInitializer(
81 // "test.txt", tf.string, tf.lookup.TextFileIndex.WHOLE_LINE,
82 // tf.int64, tf.lookup.TextFileIndex.LINE_NUMBER, delimiter=" ")
83 //
84 // In the above case, the delimiter will be not used since the key is just a
85 // whole line and value is a line number.
86 if (op.key_index() != kTextFileIndex_WholeLine ||
87 op.value_index() != kTextFileIndex_LineNumber) {
88 return failure();
89 }
90
91 // Try to find filename from constant op.
92 DenseStringElementsAttr filename_attr;
93 if (!matchPattern(op.filename().getDefiningOp(),
94 m_Constant(&filename_attr))) {
95 return failure();
96 }
97
98 if (filename_attr.getRawStringData().size() != 1) {
99 return failure();
100 }
101 std::string filename = filename_attr.getRawStringData()[0].str();
102
103 if (!saved_model_dir_.empty()) {
104 filename = tensorflow::io::JoinPath(
105 saved_model_dir_.str(),
106 tensorflow::io::JoinPath("assets",
107 tensorflow::io::Basename(filename)));
108 }
109
110 // Read the content of the file.
111 std::string error_message;
112 auto file = openInputFile(filename, &error_message);
113 if (!file) {
114 return op.emitOpError("failed to open vocabulary file")
115 << " (" << filename << "): " << error_message;
116 }
117
118 // Splits into lines.
119 SmallVector<StringRef, 8> lines;
120 file->getBuffer().split(lines, "\n", -1, false);
121 // The resize method is used since split operator puts tail value in the end
122 // without splitting the leftovers.
123 if (op.vocab_size() != -1) lines.resize(op.vocab_size());
124
125 // Map each line to line number, starting from zero.
126 SmallVector<int64_t, 8> line_nums;
127 line_nums.resize(lines.size());
128 std::iota(line_nums.begin(), line_nums.end(), 0);
129
130 // Create constant ops for keys an values.
131 Value key_constant_tensor = rewriter.create<ConstantOp>(
132 op.getLoc(),
133 DenseStringElementsAttr::get(
134 RankedTensorType::get(static_cast<int64_t>(lines.size()),
135 StringType::get(rewriter.getContext())),
136 lines));
137
138 Value value_constant_tensor = rewriter.create<ConstantOp>(
139 op.getLoc(), rewriter.getI64TensorAttr(line_nums));
140
141 // Replace the given op with LookupTableImportV2Op.
142 rewriter.create<LookupTableImportV2Op>(op.getLoc(), op.table_handle(),
143 key_constant_tensor,
144 value_constant_tensor);
145 rewriter.eraseOp(op);
146 return success();
147 }
148
149 private:
150 StringRef saved_model_dir_;
151 };
152
runOnFunction()153 void InitTextFileToImportPass::runOnFunction() {
154 OwningRewritePatternList patterns(&getContext());
155 MLIRContext* context = &getContext();
156 FuncOp func = getFunction();
157
158 patterns.insert<ConvertInitializeTableFromTextFileV2>(
159 context, StringRef(saved_model_dir_));
160 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
161 }
162
163 } // namespace
164
165 // Replace InitializeTableFromTextFileV2Ops with LookupTableImportV2Ops.
CreateInitTextFileToImportPass(std::string saved_model_dir)166 std::unique_ptr<OperationPass<FuncOp>> CreateInitTextFileToImportPass(
167 std::string saved_model_dir) {
168 return std::make_unique<InitTextFileToImportPass>(saved_model_dir);
169 }
170
171 static PassRegistration<InitTextFileToImportPass> pass;
172
173 } // namespace TF
174 } // namespace mlir
175