/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include <numeric>
17 
18 #include "llvm/ADT/StringExtras.h"
19 #include "llvm/Support/Casting.h"
20 #include "llvm/Support/MemoryBuffer.h"
21 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
22 #include "mlir/IR/Attributes.h"  // from @llvm-project
23 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
24 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
25 #include "mlir/Pass/Pass.h"  // from @llvm-project
26 #include "mlir/Support/FileUtilities.h"  // from @llvm-project
27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
29 #include "tensorflow/core/lib/io/path.h"
30 
31 namespace mlir {
32 namespace TF {
33 namespace {
34 
// Sentinel column indices compared against the key_index / value_index
// attributes of InitializeTableFromTextFileV2Op.
// -2: the whole line is used (tf.lookup.TextFileIndex.WHOLE_LINE).
constexpr int kTextFileIndex_WholeLine{-2};
// -1: the line number is used (tf.lookup.TextFileIndex.LINE_NUMBER).
constexpr int kTextFileIndex_LineNumber{-1};
37 
38 // InitTextFileToImportPass converts InitializeTableFromTextFileV2Op to the
39 // corresponding LookupTableImportV2Op if possible.
40 class InitTextFileToImportPass
41     : public mlir::PassWrapper<InitTextFileToImportPass, FunctionPass> {
42  public:
InitTextFileToImportPass()43   InitTextFileToImportPass() {}
InitTextFileToImportPass(const InitTextFileToImportPass &)44   InitTextFileToImportPass(const InitTextFileToImportPass&) {}
InitTextFileToImportPass(std::string saved_model_dir)45   explicit InitTextFileToImportPass(std::string saved_model_dir) {
46     saved_model_dir_ = saved_model_dir;
47   }
48 
getArgument() const49   StringRef getArgument() const final { return "tf-init-text-file-to-import"; }
50 
getDescription() const51   StringRef getDescription() const final {
52     return "convert InitializeTableFromTextFileV2 ops to LookupTableImportV2Op "
53            "to remove the dependency on asset files";
54   }
55 
56  private:
57   void runOnFunction() override;
58 
59   Option<std::string> saved_model_dir_{
60       *this, "tf-saved-model-dir",
61       llvm::cl::desc("Directory containing the model exported as a TensorFlow "
62                      "SavedModel. If your model is not based on the TensorFlow "
63                      "SavedModel, use an empty value."),
64       llvm::cl::init("")};
65 };
66 
67 class ConvertInitializeTableFromTextFileV2
68     : public OpRewritePattern<InitializeTableFromTextFileV2Op> {
69  public:
ConvertInitializeTableFromTextFileV2(mlir::MLIRContext * context,StringRef saved_model_dir)70   explicit ConvertInitializeTableFromTextFileV2(mlir::MLIRContext* context,
71                                                 StringRef saved_model_dir)
72       : OpRewritePattern<InitializeTableFromTextFileV2Op>(context),
73         saved_model_dir_(saved_model_dir) {}
74 
matchAndRewrite(InitializeTableFromTextFileV2Op op,PatternRewriter & rewriter) const75   LogicalResult matchAndRewrite(InitializeTableFromTextFileV2Op op,
76                                 PatternRewriter& rewriter) const override {
77     // Now, this pattern matching only supports the following case, which is
78     // commonly used among inference use cases:
79     //
80     // tf.lookup.TextFileInitializer(
81     //   "test.txt", tf.string, tf.lookup.TextFileIndex.WHOLE_LINE,
82     //   tf.int64, tf.lookup.TextFileIndex.LINE_NUMBER, delimiter=" ")
83     //
84     // In the above case, the delimiter will be not used since the key is just a
85     // whole line and value is a line number.
86     if (op.key_index() != kTextFileIndex_WholeLine ||
87         op.value_index() != kTextFileIndex_LineNumber) {
88       return failure();
89     }
90 
91     // Try to find filename from constant op.
92     DenseStringElementsAttr filename_attr;
93     if (!matchPattern(op.filename().getDefiningOp(),
94                       m_Constant(&filename_attr))) {
95       return failure();
96     }
97 
98     if (filename_attr.getRawStringData().size() != 1) {
99       return failure();
100     }
101     std::string filename = filename_attr.getRawStringData()[0].str();
102 
103     if (!saved_model_dir_.empty()) {
104       filename = tensorflow::io::JoinPath(
105           saved_model_dir_.str(),
106           tensorflow::io::JoinPath("assets",
107                                    tensorflow::io::Basename(filename)));
108     }
109 
110     // Read the content of the file.
111     std::string error_message;
112     auto file = openInputFile(filename, &error_message);
113     if (!file) {
114       return op.emitOpError("failed to open vocabulary file")
115              << " (" << filename << "): " << error_message;
116     }
117 
118     // Splits into lines.
119     SmallVector<StringRef, 8> lines;
120     file->getBuffer().split(lines, "\n", -1, false);
121     // The resize method is used since split operator puts tail value in the end
122     // without splitting the leftovers.
123     if (op.vocab_size() != -1) lines.resize(op.vocab_size());
124 
125     // Map each line to line number, starting from zero.
126     SmallVector<int64_t, 8> line_nums;
127     line_nums.resize(lines.size());
128     std::iota(line_nums.begin(), line_nums.end(), 0);
129 
130     // Create constant ops for keys an values.
131     Value key_constant_tensor = rewriter.create<ConstantOp>(
132         op.getLoc(),
133         DenseStringElementsAttr::get(
134             RankedTensorType::get(static_cast<int64_t>(lines.size()),
135                                   StringType::get(rewriter.getContext())),
136             lines));
137 
138     Value value_constant_tensor = rewriter.create<ConstantOp>(
139         op.getLoc(), rewriter.getI64TensorAttr(line_nums));
140 
141     // Replace the given op with LookupTableImportV2Op.
142     rewriter.create<LookupTableImportV2Op>(op.getLoc(), op.table_handle(),
143                                            key_constant_tensor,
144                                            value_constant_tensor);
145     rewriter.eraseOp(op);
146     return success();
147   }
148 
149  private:
150   StringRef saved_model_dir_;
151 };
152 
runOnFunction()153 void InitTextFileToImportPass::runOnFunction() {
154   OwningRewritePatternList patterns(&getContext());
155   MLIRContext* context = &getContext();
156   FuncOp func = getFunction();
157 
158   patterns.insert<ConvertInitializeTableFromTextFileV2>(
159       context, StringRef(saved_model_dir_));
160   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
161 }
162 
163 }  // namespace
164 
165 // Replace InitializeTableFromTextFileV2Ops with LookupTableImportV2Ops.
CreateInitTextFileToImportPass(std::string saved_model_dir)166 std::unique_ptr<OperationPass<FuncOp>> CreateInitTextFileToImportPass(
167     std::string saved_model_dir) {
168   return std::make_unique<InitTextFileToImportPass>(saved_model_dir);
169 }
170 
171 static PassRegistration<InitTextFileToImportPass> pass;
172 
173 }  // namespace TF
174 }  // namespace mlir
175