• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h"
17 
18 #include "flatbuffers/flexbuffers.h"  // from @flatbuffers
19 #include "llvm/ADT/ArrayRef.h"
20 #include "llvm/ADT/None.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/StringRef.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/raw_ostream.h"
25 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
26 #include "mlir/IR/Attributes.h"  // from @llvm-project
27 #include "mlir/IR/Builders.h"  // from @llvm-project
28 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
29 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
30 #include "mlir/IR/Identifier.h"  // from @llvm-project
31 #include "mlir/IR/Location.h"  // from @llvm-project
32 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
33 #include "mlir/IR/Matchers.h"  // from @llvm-project
34 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
35 #include "mlir/IR/Operation.h"  // from @llvm-project
36 #include "mlir/IR/Types.h"  // from @llvm-project
37 #include "mlir/IR/Value.h"  // from @llvm-project
38 #include "mlir/Support/LLVM.h"  // from @llvm-project
39 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
40 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
41 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
42 
43 namespace mlir {
44 namespace TFL {
45 
46 namespace {
47 
// API names recorded by the TF.Text python library in the `tf._implements`
// attribute of the composite functions it produces; used below to dispatch
// to the matching verifier/converter.
constexpr char kNgrams[] = "tftext:Ngrams";
constexpr char kWhitespaceTokenizer[] = "tftext:WhitespaceTokenizer";
constexpr char kCustomSgnnProjection[] = "tftext:custom:SgnnProjection";
// Attribute key that marks a FuncOp as implementing a composite API.
constexpr char kTFImplements[] = "tf._implements";

using mlir::TF::FuncAttr;
using mlir::TF::StringType;
55 
CustomOption(OpBuilder * builder,const std::string & content)56 inline OpaqueElementsAttr CustomOption(OpBuilder* builder,
57                                        const std::string& content) {
58   ShapedType type = RankedTensorType::get(
59       {static_cast<int64_t>(content.size())}, builder->getIntegerType(8));
60   return OpaqueElementsAttr::get(builder->getContext()->getLoadedDialect("tfl"),
61                                  type,
62                                  StringRef(content.data(), content.size()));
63 }
64 
GetInputType(FuncOp func,int idx)65 inline TensorType GetInputType(FuncOp func, int idx) {
66   return func.getType().getInput(idx).dyn_cast_or_null<TensorType>();
67 }
68 
GetResultType(FuncOp func,int idx)69 inline TensorType GetResultType(FuncOp func, int idx) {
70   return func.getType().getResult(idx).dyn_cast_or_null<TensorType>();
71 }
72 
RankEquals(const TensorType & type,int rank)73 inline bool RankEquals(const TensorType& type, int rank) {
74   return type && type.hasRank() && type.getRank() == rank;
75 }
76 
VerifyWhitespaceTokenizer(FuncOp func)77 LogicalResult VerifyWhitespaceTokenizer(FuncOp func) {
78   // In the case of input tensor with 0 rank.
79   // Whitespace tokenizer generates 1 output:
80   // * String tensor for tokens.
81   //
82   // In the case of 1-D input tensor,
83   // Whitespace tokenizer generates 2 outputs to make up a ragged tensor:
84   // * 1st output is the value of ragged tensor;
85   // * 2nd output is the offset.
86   //
87   // In the case of batched input tesnor,
88   // Whitespace tokenizer has 3 outputs to make up a nested ragged tensor:
89   // * 1st output is the value of ragged tensor;
90   // * 2nd output is the inner offset;
91   // * 3rd output is the outer offset.
92   auto input_type = GetInputType(func, 0);
93   if (!input_type || !input_type.getElementType().isa<StringType>() ||
94       !input_type.hasRank()) {
95     return func.emitError() << "Input should be a string tensor";
96   }
97 
98   const std::vector<int> kValidNumOfOutput = {1, 2, 3};
99   if (input_type.getRank() >= kValidNumOfOutput.size()) {
100     return func.emitError()
101            << "Unrecognized input rank: " << input_type.getRank();
102   }
103   if (func.getNumResults() != kValidNumOfOutput[input_type.getRank()]) {
104     return func.emitError()
105            << "Expect " << kValidNumOfOutput[input_type.getRank()]
106            << "output(s) when input has rank " << input_type.getRank();
107   }
108 
109   auto value_type = GetResultType(func, 0);
110   if (!RankEquals(value_type, 1) ||
111       !value_type.getElementType().isa<StringType>()) {
112     return func.emitError() << "1st output should be string tensor";
113   }
114   if (func.getNumResults() > 1) {
115     auto offset_type = GetResultType(func, 1);
116     if (!RankEquals(offset_type, 1) ||
117         !offset_type.getElementType().isInteger(64)) {
118       return func.emitError() << "2nd output should be int64 tensor";
119     }
120   }
121   if (func.getNumResults() > 2) {
122     auto offset_type = GetResultType(func, 2);
123     if (!RankEquals(offset_type, 1) ||
124         !offset_type.getElementType().isInteger(64)) {
125       return func.emitError() << "3rd output should be int64 tensor";
126     }
127   }
128 
129   return success();
130 }
131 
// Replaces the body of `func` with a single TFL::CustomOp carrying the
// WhitespaceTokenizer API name. The op needs no custom options, so an empty
// buffer is attached. The function signature is left unchanged.
LogicalResult ConvertWhitespaceTokenizer(FuncOp func, llvm::StringRef api,
                                         FuncAttr attr) {
  func.eraseBody();
  func.addEntryBlock();
  // Re-record the composite-API marker on the rewritten function.
  func->setAttr(kTFImplements, attr);
  OpBuilder builder(func.getBody());
  std::string empty_option_buffer;
  auto op = builder.create<CustomOp>(
      func.getLoc(), func.getType().getResults(), func.getArguments(), api,
      CustomOption(&builder, empty_option_buffer));
  builder.create<ReturnOp>(func.getLoc(), op.getResults());
  return success();
}
145 
VerifyNgrams(FuncOp func)146 LogicalResult VerifyNgrams(FuncOp func) {
147   // The inputs and outputs should be the same:
148   // * A string tensor for tokens/ragged tensor values.
149   // * Zero or more row_split tensors.
150   constexpr int kValues = 0;
151   constexpr int kRowSplits = 1;
152 
153   if (func.getType().getInputs().size() != func.getType().getResults().size()) {
154     return func.emitError() << "Mismatched number of inputs and outputs.";
155   }
156 
157   int row_splits = func.getType().getInputs().size() - kRowSplits;
158   if (row_splits == 0) {
159     auto input_values = GetInputType(func, kValues);
160     if (!input_values || !input_values.getElementType().isa<StringType>()) {
161       return func.emitError()
162              << "Input " << kValues << " should be a string tensor";
163     }
164     auto output_values = GetResultType(func, kValues);
165     if (!output_values || !output_values.getElementType().isa<StringType>()) {
166       return func.emitError()
167              << "Output " << kValues << " should be a string tensor";
168     }
169 
170     if (input_values.hasRank() && output_values.hasRank() &&
171         input_values.getRank() != output_values.getRank()) {
172       return func.emitError() << "Input " << kValues << " and output "
173                               << kValues << " should have the same rank";
174     }
175   } else {
176     auto input_values = GetInputType(func, kValues);
177     if (!RankEquals(input_values, 1) ||
178         !input_values.getElementType().isa<StringType>()) {
179       return func.emitError()
180              << "Input " << kValues << " should be a 1D string tensor";
181     }
182     auto output_values = GetResultType(func, kValues);
183     if (!RankEquals(output_values, 1) ||
184         !output_values.getElementType().isa<StringType>()) {
185       return func.emitError()
186              << "Output " << kValues << " should be a 1D string tensor";
187     }
188 
189     for (int i = 0; i < row_splits; ++i) {
190       const int row_index = i + kRowSplits;
191       auto input_row_splits = GetInputType(func, row_index);
192       if (!RankEquals(input_row_splits, 1) ||
193           !input_row_splits.getElementType().isInteger(64)) {
194         return func.emitError()
195                << "Input " << row_index << " should be a 1D int64 tensor";
196       }
197       auto output_row_splits = GetResultType(func, row_index);
198       if (!RankEquals(output_row_splits, 1) ||
199           !output_row_splits.getElementType().isInteger(64)) {
200         return func.emitError()
201                << "Output " << row_index << " should be a 1D int64 tensor";
202       }
203     }
204   }
205 
206   return success();
207 }
208 
// Serializes the Ngrams attributes ("width", "string_separator", "axis",
// "reduction_type") from `attrs` into a flexbuffer map stored in
// `custom_option_buffer` — the payload the TFLite Ngrams custom op reads.
// Emits an error on `func` and fails if any required attribute is missing or
// has the wrong type.
LogicalResult CreateNgramsCustomOption(FuncOp func, DictionaryAttr attrs,
                                       std::string& custom_option_buffer) {
  flexbuffers::Builder fbb;
  size_t start_map = fbb.StartMap();

  auto width = attrs.get("width").dyn_cast_or_null<IntegerAttr>();
  if (!width) {
    return func.emitError() << "'width' attribute is not set or not an integer";
  }
  fbb.Int("width", width.getInt());

  auto string_separator =
      attrs.get("string_separator").dyn_cast_or_null<StringAttr>();
  if (!string_separator) {
    return func.emitError()
           << "'string_separator' attribute is not set or not a string";
  }
  // StringAttrs are not guaranteed to be NUL terminated, but flexbuffers
  // strings expect NUL terminated strings.
  std::string string_separator_str(string_separator.getValue().data(),
                                   string_separator.getValue().size());
  fbb.String("string_separator", string_separator_str);

  auto axis = attrs.get("axis").dyn_cast_or_null<IntegerAttr>();
  if (!axis) {
    return func.emitError() << "'axis' attribute is not set or not an integer";
  }
  fbb.Int("axis", axis.getInt());

  auto reduction_type =
      attrs.get("reduction_type").dyn_cast_or_null<StringAttr>();
  if (!reduction_type) {
    return func.emitError()
           << "'reduction_type' attribute is not set or not a string";
  }
  // StringAttrs are not guaranteed to be NUL terminated, but flexbuffers
  // strings expect NUL terminated strings.
  std::string reduction_type_str(reduction_type.getValue().data(),
                                 reduction_type.getValue().size());
  fbb.String("reduction_type", reduction_type_str);

  fbb.EndMap(start_map);
  fbb.Finish();
  custom_option_buffer.assign(fbb.GetBuffer().begin(), fbb.GetBuffer().end());
  return success();
}
255 
// Rewrites `func` into a single TFL::CustomOp implementing Ngrams, with the
// flexbuffer-encoded attributes from `attr` as the op's custom options.
// Fails (an error has already been emitted) if those attributes are invalid.
LogicalResult ConvertNgrams(FuncOp func, llvm::StringRef api, FuncAttr attr) {
  func.eraseBody();
  func.addEntryBlock();
  func->setAttr(kTFImplements, attr);
  OpBuilder builder(func.getBody());
  std::string custom_option_buffer;
  // Serialize the attributes before emitting any ops so an invalid function
  // is left with just the empty entry block rather than a bogus custom op.
  if (failed(CreateNgramsCustomOption(func, attr.getAttrs(),
                                      custom_option_buffer))) {
    return failure();
  }
  auto op = builder.create<CustomOp>(
      func.getLoc(), func.getType().getResults(), func.getArguments(), api,
      CustomOption(&builder, custom_option_buffer));
  builder.create<ReturnOp>(func.getLoc(), op.getResults());
  return success();
}
272 
// Checks the signature of an SgnnProjection composite function:
// (string values, integer row_splits) -> 2-D float projection whose second
// dimension equals the number of hash seeds. Also validates the 'hash_seed'
// and 'buckets' attributes that CreateSgnnProjectionCustomOption later
// dereferences without re-checking.
LogicalResult VerifySgnnProjection(FuncOp func, FuncAttr attr) {
  if (func.getType().getNumInputs() != 2 ||
      func.getType().getNumResults() != 1) {
    return func.emitError() << "Mismatched number of inputs and outputs.";
  }
  auto values_type = GetInputType(func, 0);
  if (!values_type || !values_type.getElementType().isa<StringType>()) {
    return func.emitError() << "First input should be a string tensor";
  }
  auto row_splits_type = GetInputType(func, 1);
  if (!row_splits_type ||
      !row_splits_type.getElementType().isa<IntegerType>()) {
    return func.emitError() << "Second input should be an integer tensor";
  }

  auto hash_seed =
      attr.getAttrs().get("hash_seed").dyn_cast_or_null<ArrayAttr>();
  if (!hash_seed) {
    return func.emitError()
           << "'hash_seed' attribute is not set or not an array";
  }
  auto output_type = GetResultType(func, 0);
  if (!output_type || !output_type.getElementType().isa<FloatType>() ||
      !RankEquals(output_type, 2)) {
    return func.emitError() << "Output should be a 2D float tensor.";
  }
  // NOTE(review): signed getDimSize(1) is compared against an unsigned
  // size(); a dynamic second dimension would show up as a sentinel value and
  // fail this check — confirm rejecting a dynamic dim here is intended.
  if (output_type.getDimSize(1) != hash_seed.size()) {
    return func.emitError()
           << "Output 2nd dimension should be the num of hash seeds.";
  }

  auto buckets = attr.getAttrs().get("buckets").dyn_cast_or_null<IntegerAttr>();
  if (!buckets) {
    return func.emitError() << "'buckets' attribute is not set or not int";
  }

  return success();
}
311 
CreateSgnnProjectionCustomOption(FuncOp func,DictionaryAttr attrs,std::string & custom_option_buffer)312 LogicalResult CreateSgnnProjectionCustomOption(
313     FuncOp func, DictionaryAttr attrs, std::string& custom_option_buffer) {
314   flexbuffers::Builder fbb;
315   size_t start_map = fbb.StartMap();
316 
317   auto hash_seed = attrs.get("hash_seed").dyn_cast_or_null<ArrayAttr>();
318   auto vector_start = fbb.StartVector("hash_seed");
319   for (int i = 0; i < hash_seed.size(); i++) {
320     fbb.Add(static_cast<int32_t>(
321         (hash_seed.getValue().data() + i)->dyn_cast<IntegerAttr>().getInt()));
322   }
323   fbb.EndVector(vector_start, /*typed=*/true, /*fixed=*/false);
324 
325   auto buckets = attrs.get("buckets").dyn_cast_or_null<IntegerAttr>();
326   fbb.Int("buckets", buckets.getInt());
327 
328   fbb.EndMap(start_map);
329   fbb.Finish();
330   custom_option_buffer.assign(fbb.GetBuffer().begin(), fbb.GetBuffer().end());
331   return success();
332 }
333 
// Rewrites `func` into a single TFL::CustomOp implementing SgnnProjection,
// with the flexbuffer-encoded attributes from `attr` as custom options.
// See more details in tensorflow_models/sequence_projection/sgnn/sgnn.py
LogicalResult ConvertSgnnProjection(FuncOp func, llvm::StringRef api,
                                    FuncAttr attr) {
  func.eraseBody();
  func.addEntryBlock();
  func->setAttr(kTFImplements, attr);
  OpBuilder builder(func.getBody());
  std::string custom_option_buffer;
  // Serialize attributes first so a failure leaves only the empty entry
  // block behind instead of a half-built op.
  if (failed(CreateSgnnProjectionCustomOption(func, attr.getAttrs(),
                                              custom_option_buffer))) {
    return failure();
  }
  auto op = builder.create<CustomOp>(
      func.getLoc(), func.getType().getResults(), func.getArguments(), api,
      CustomOption(&builder, custom_option_buffer));
  builder.create<ReturnOp>(func.getLoc(), op.getResults());
  return success();
}
352 }  // namespace
353 
ConvertTFTextAPI(FuncOp func,llvm::StringRef api,FuncAttr attr)354 LogicalResult ConvertTFTextAPI(FuncOp func, llvm::StringRef api,
355                                FuncAttr attr) {
356   if (api.str() == kWhitespaceTokenizer) {
357     if (succeeded(VerifyWhitespaceTokenizer(func))) {
358       return ConvertWhitespaceTokenizer(func, api, attr);
359     }
360   } else if (api.str() == kNgrams) {
361     if (succeeded(VerifyNgrams(func))) {
362       return ConvertNgrams(func, api, attr);
363     }
364   } else if (api.str() == kCustomSgnnProjection) {
365     if (succeeded(VerifySgnnProjection(func, attr))) {
366       return ConvertSgnnProjection(func, api, attr);
367     }
368   }
369   return failure();
370 }
371 
IsTFTextRegistered(const tensorflow::OpRegistry * op_registery)372 bool IsTFTextRegistered(const tensorflow::OpRegistry* op_registery) {
373   const std::vector<std::string> kTFTextOps = {
374       "WhitespaceTokenizeWithOffsets",
375   };
376   for (const auto& iter : kTFTextOps) {
377     if (op_registery->LookUp(iter)) {
378       return true;
379     }
380   }
381   return false;
382 }
383 
384 }  // namespace TFL
385 }  // namespace mlir
386