• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <numeric>
17 
18 #include "llvm/Support/Casting.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
20 #include "mlir/IR/Attributes.h"  // from @llvm-project
21 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
22 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
23 #include "mlir/Pass/Pass.h"  // from @llvm-project
24 #include "mlir/Support/FileUtilities.h"  // from @llvm-project
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
27 
28 namespace mlir {
29 namespace TF {
30 namespace {
31 
// Sentinel column selectors mirroring tf.lookup.TextFileIndex. WHOLE_LINE means
// "use the entire line" and LINE_NUMBER means "use the (zero-based) line
// number" instead of a delimiter-separated column. Internal linkage already
// comes from the enclosing anonymous namespace.
constexpr int kTextFileIndex_WholeLine = -2;
constexpr int kTextFileIndex_LineNumber = -1;
34 
35 // InitTextFileToImportPass converts InitializeTableFromTextFileV2Op to the
36 // corresponding LookupTableImportV2Op if possible.
37 class InitTextFileToImportPass
38     : public mlir::PassWrapper<InitTextFileToImportPass, FunctionPass> {
39  public:
InitTextFileToImportPass()40   explicit InitTextFileToImportPass() {}
41 
42  private:
43   void runOnFunction() override;
44 };
45 
46 class ConvertInitializeTableFromTextFileV2
47     : public OpRewritePattern<InitializeTableFromTextFileV2Op> {
48  public:
49   using OpRewritePattern::OpRewritePattern;
50 
matchAndRewrite(InitializeTableFromTextFileV2Op op,PatternRewriter & rewriter) const51   LogicalResult matchAndRewrite(InitializeTableFromTextFileV2Op op,
52                                 PatternRewriter& rewriter) const override {
53     // Now, this pattern matching only supports the following case, which is
54     // commonly used among inference use cases:
55     //
56     // tf.lookup.TextFileInitializer(
57     //   "test.txt", tf.string, tf.lookup.TextFileIndex.WHOLE_LINE,
58     //   tf.int64, tf.lookup.TextFileIndex.LINE_NUMBER, delimiter=" ")
59     //
60     // In the above case, the delimiter will be not used since the key is just a
61     // whole line and value is a line number.
62     if (op.key_index() != kTextFileIndex_WholeLine ||
63         op.value_index() != kTextFileIndex_LineNumber) {
64       return failure();
65     }
66 
67     // Try to find filename from constant op.
68     DenseStringElementsAttr filename_attr;
69     if (!matchPattern(op.filename().getDefiningOp(),
70                       m_Constant(&filename_attr))) {
71       return failure();
72     }
73     StringRef filename = filename_attr.getRawStringData()[0];
74 
75     // Read the content of the file.
76     std::string error_message;
77     auto file = openInputFile(filename, &error_message);
78     if (!file) {
79       return op.emitOpError("failed to open vocabulary file")
80              << " (" << filename.str() << "): " << error_message;
81     }
82 
83     // Splits into lines.
84     SmallVector<StringRef, 8> lines;
85     file->getBuffer().split(lines, "\n", -1, false);
86     // The resize method is used since split operator puts tail value in the end
87     // without splitting the leftovers.
88     if (op.vocab_size() != -1) lines.resize(op.vocab_size());
89 
90     // Map each line to line number, starting from zero.
91     SmallVector<int64_t, 8> line_nums;
92     line_nums.resize(lines.size());
93     std::iota(line_nums.begin(), line_nums.end(), 0);
94 
95     // Create constant ops for keys an values.
96     Value key_constant_tensor = rewriter.create<ConstantOp>(
97         op.getLoc(),
98         DenseStringElementsAttr::get(
99             RankedTensorType::get(static_cast<int64_t>(lines.size()),
100                                   StringType::get(rewriter.getContext())),
101             lines));
102 
103     Value value_constant_tensor = rewriter.create<ConstantOp>(
104         op.getLoc(), rewriter.getI64TensorAttr(line_nums));
105 
106     // Replace the given op with LookupTableImportV2Op.
107     rewriter.create<LookupTableImportV2Op>(op.getLoc(), op.table_handle(),
108                                            key_constant_tensor,
109                                            value_constant_tensor);
110     rewriter.eraseOp(op);
111     return success();
112   }
113 };
114 
runOnFunction()115 void InitTextFileToImportPass::runOnFunction() {
116   OwningRewritePatternList patterns;
117   MLIRContext* context = &getContext();
118   FuncOp func = getFunction();
119 
120   patterns.insert<ConvertInitializeTableFromTextFileV2>(context);
121   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
122 }
123 
124 }  // namespace
125 
126 // Replace InitializeTableFromTextFileV2Ops with LookupTableImportV2Ops.
CreateInitTextFileToImportPass()127 std::unique_ptr<OperationPass<FuncOp>> CreateInitTextFileToImportPass() {
128   return std::make_unique<InitTextFileToImportPass>();
129 }
130 
131 static PassRegistration<InitTextFileToImportPass> pass(
132     "tf-init-text-file-to-import",
133     "convert InitializeTableFromTextFileV2 ops to LookupTableImportV2Op to "
134     "remove the dependency on asset files");
135 
136 }  // namespace TF
137 }  // namespace mlir
138