/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h"

#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Identifier.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TFL {

namespace {

constexpr char kNgrams[] = "tftext:Ngrams";
constexpr char kWhitespaceTokenizer[] = "tftext:WhitespaceTokenizer";
constexpr char kCustomSgnnProjection[] = "tftext:custom:SgnnProjection";
constexpr char kTFImplements[] = "tf._implements";

using mlir::TF::FuncAttr;
using mlir::TF::StringType;

inline OpaqueElementsAttr CustomOption(OpBuilder* builder,
                                       const std::string& content) {
  ShapedType type = RankedTensorType::get(
      {static_cast<int64_t>(content.size())}, builder->getIntegerType(8));
  return OpaqueElementsAttr::get(builder->getContext()->getLoadedDialect("tfl"),
                                 type,
                                 StringRef(content.data(), content.size()));
}
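// For illustration only: a hypothetical 4-byte option buffer
// {0x01, 0x02, 0x03, 0x04} becomes an opaque i8 elements attribute owned by
// the "tfl" dialect, printed roughly as
//   opaque<"tfl", "0x01020304"> : tensor<4xi8>
// (the exact textual form depends on the MLIR version in use).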

inline TensorType GetInputType(FuncOp func, int idx) {
  return func.getType().getInput(idx).dyn_cast_or_null<TensorType>();
}

inline TensorType GetResultType(FuncOp func, int idx) {
  return func.getType().getResult(idx).dyn_cast_or_null<TensorType>();
}

inline bool RankEquals(const TensorType& type, int rank) {
  return type && type.hasRank() && type.getRank() == rank;
}

LogicalResult VerifyWhitespaceTokenizer(FuncOp func) {
  // In the case of a rank-0 input tensor,
  // Whitespace tokenizer generates 1 output:
  // * String tensor for tokens.
  //
  // In the case of a 1-D input tensor,
  // Whitespace tokenizer generates 2 outputs to make up a ragged tensor:
  // * 1st output is the values of the ragged tensor;
  // * 2nd output is the offset.
  //
  // In the case of a batched input tensor,
  // Whitespace tokenizer has 3 outputs to make up a nested ragged tensor:
  // * 1st output is the values of the ragged tensor;
  // * 2nd output is the inner offset;
  // * 3rd output is the outer offset.
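  // Illustrative example (hypothetical values): for a 1-D input
  // ["hello world", "good day"], the expected outputs are the token values
  // ["hello", "world", "good", "day"] and row-split style offsets [0, 2, 4].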
  auto input_type = GetInputType(func, 0);
  if (!input_type || !input_type.getElementType().isa<StringType>() ||
      !input_type.hasRank()) {
    return func.emitError() << "Input should be a string tensor";
  }

  const std::vector<int> kValidNumOfOutput = {1, 2, 3};
  if (input_type.getRank() >= kValidNumOfOutput.size()) {
    return func.emitError()
           << "Unrecognized input rank: " << input_type.getRank();
  }
  if (func.getNumResults() != kValidNumOfOutput[input_type.getRank()]) {
    return func.emitError()
           << "Expect " << kValidNumOfOutput[input_type.getRank()]
           << " output(s) when input has rank " << input_type.getRank();
  }

  auto value_type = GetResultType(func, 0);
  if (!RankEquals(value_type, 1) ||
      !value_type.getElementType().isa<StringType>()) {
    return func.emitError() << "1st output should be string tensor";
  }
  if (func.getNumResults() > 1) {
    auto offset_type = GetResultType(func, 1);
    if (!RankEquals(offset_type, 1) ||
        !offset_type.getElementType().isInteger(64)) {
      return func.emitError() << "2nd output should be int64 tensor";
    }
  }
  if (func.getNumResults() > 2) {
    auto offset_type = GetResultType(func, 2);
    if (!RankEquals(offset_type, 1) ||
        !offset_type.getElementType().isInteger(64)) {
      return func.emitError() << "3rd output should be int64 tensor";
    }
  }

  return success();
}

LogicalResult ConvertWhitespaceTokenizer(FuncOp func, llvm::StringRef api,
                                         FuncAttr attr) {
  func.eraseBody();
  func.addEntryBlock();
  func->setAttr(kTFImplements, attr);
  OpBuilder builder(func.getBody());
  std::string empty_option_buffer;
  auto op = builder.create<CustomOp>(
      func.getLoc(), func.getType().getResults(), func.getArguments(), api,
      CustomOption(&builder, empty_option_buffer));
  builder.create<ReturnOp>(func.getLoc(), op.getResults());
  return success();
}
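// After this rewrite the function body reduces to a single TFL custom op.
// Schematically (types and names below are illustrative, not verbatim
// converter output):
//   func @whitespace_tokenizer(%arg0: tensor<1x!tf.string>) -> (...) {
//     %0:2 = "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer",
//                                 custom_option = opaque<"tfl", "0x">}
//     return %0#0, %0#1
//   }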

LogicalResult VerifyNgrams(FuncOp func) {
  // The inputs and outputs should be the same:
  // * A string tensor for tokens/ragged tensor values.
  // * Zero or more row_split tensors.
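  // For example (shapes are hypothetical), a ragged input may arrive as
  //   (values: tensor<?x!tf.string>, row_splits: tensor<?xi64>)
  // and the outputs are expected to have the same form.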
  constexpr int kValues = 0;
  constexpr int kRowSplits = 1;

  if (func.getType().getInputs().size() !=
      func.getType().getResults().size()) {
    return func.emitError() << "Mismatched number of inputs and outputs.";
  }

  int row_splits = func.getType().getInputs().size() - kRowSplits;
  if (row_splits == 0) {
    auto input_values = GetInputType(func, kValues);
    if (!input_values || !input_values.getElementType().isa<StringType>()) {
      return func.emitError()
             << "Input " << kValues << " should be a string tensor";
    }
    auto output_values = GetResultType(func, kValues);
    if (!output_values || !output_values.getElementType().isa<StringType>()) {
      return func.emitError()
             << "Output " << kValues << " should be a string tensor";
    }

    if (input_values.hasRank() && output_values.hasRank() &&
        input_values.getRank() != output_values.getRank()) {
      return func.emitError() << "Input " << kValues << " and output "
                              << kValues << " should have the same rank";
    }
  } else {
    auto input_values = GetInputType(func, kValues);
    if (!RankEquals(input_values, 1) ||
        !input_values.getElementType().isa<StringType>()) {
      return func.emitError()
             << "Input " << kValues << " should be a 1D string tensor";
    }
    auto output_values = GetResultType(func, kValues);
    if (!RankEquals(output_values, 1) ||
        !output_values.getElementType().isa<StringType>()) {
      return func.emitError()
             << "Output " << kValues << " should be a 1D string tensor";
    }

    for (int i = 0; i < row_splits; ++i) {
      const int row_index = i + kRowSplits;
      auto input_row_splits = GetInputType(func, row_index);
      if (!RankEquals(input_row_splits, 1) ||
          !input_row_splits.getElementType().isInteger(64)) {
        return func.emitError()
               << "Input " << row_index << " should be a 1D int64 tensor";
      }
      auto output_row_splits = GetResultType(func, row_index);
      if (!RankEquals(output_row_splits, 1) ||
          !output_row_splits.getElementType().isInteger(64)) {
        return func.emitError()
               << "Output " << row_index << " should be a 1D int64 tensor";
      }
    }
  }

  return success();
}

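// The custom option built below is a flexbuffer map; with hypothetical
// attribute values it would look roughly like
//   {width: 2, string_separator: " ", axis: -1, reduction_type: "STRING_JOIN"}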
LogicalResult CreateNgramsCustomOption(FuncOp func, DictionaryAttr attrs,
                                       std::string& custom_option_buffer) {
  flexbuffers::Builder fbb;
  size_t start_map = fbb.StartMap();

  auto width = attrs.get("width").dyn_cast_or_null<IntegerAttr>();
  if (!width) {
    return func.emitError() << "'width' attribute is not set or not an integer";
  }
  fbb.Int("width", width.getInt());

  auto string_separator =
      attrs.get("string_separator").dyn_cast_or_null<StringAttr>();
  if (!string_separator) {
    return func.emitError()
           << "'string_separator' attribute is not set or not a string";
  }
  // StringAttrs are not guaranteed to be NUL-terminated, but flexbuffers
  // expects NUL-terminated strings, so copy the value into a std::string.
  std::string string_separator_str(string_separator.getValue().data(),
                                   string_separator.getValue().size());
  fbb.String("string_separator", string_separator_str);

  auto axis = attrs.get("axis").dyn_cast_or_null<IntegerAttr>();
  if (!axis) {
    return func.emitError() << "'axis' attribute is not set or not an integer";
  }
  fbb.Int("axis", axis.getInt());

  auto reduction_type =
      attrs.get("reduction_type").dyn_cast_or_null<StringAttr>();
  if (!reduction_type) {
    return func.emitError()
           << "'reduction_type' attribute is not set or not a string";
  }
  // StringAttrs are not guaranteed to be NUL-terminated, but flexbuffers
  // expects NUL-terminated strings, so copy the value into a std::string.
  std::string reduction_type_str(reduction_type.getValue().data(),
                                 reduction_type.getValue().size());
  fbb.String("reduction_type", reduction_type_str);

  fbb.EndMap(start_map);
  fbb.Finish();
  custom_option_buffer.assign(fbb.GetBuffer().begin(), fbb.GetBuffer().end());
  return success();
}

LogicalResult ConvertNgrams(FuncOp func, llvm::StringRef api, FuncAttr attr) {
  func.eraseBody();
  func.addEntryBlock();
  func->setAttr(kTFImplements, attr);
  OpBuilder builder(func.getBody());
  std::string custom_option_buffer;
  if (failed(CreateNgramsCustomOption(func, attr.getAttrs(),
                                      custom_option_buffer))) {
    return failure();
  }
  auto op = builder.create<CustomOp>(
      func.getLoc(), func.getType().getResults(), func.getArguments(), api,
      CustomOption(&builder, custom_option_buffer));
  builder.create<ReturnOp>(func.getLoc(), op.getResults());
  return success();
}

LogicalResult VerifySgnnProjection(FuncOp func, FuncAttr attr) {
  if (func.getType().getNumInputs() != 2 ||
      func.getType().getNumResults() != 1) {
    return func.emitError() << "Expect 2 inputs and 1 output.";
  }
  auto values_type = GetInputType(func, 0);
  if (!values_type || !values_type.getElementType().isa<StringType>()) {
    return func.emitError() << "First input should be a string tensor";
  }
  auto row_splits_type = GetInputType(func, 1);
  if (!row_splits_type ||
      !row_splits_type.getElementType().isa<IntegerType>()) {
    return func.emitError() << "Second input should be an integer tensor";
  }

  auto hash_seed =
      attr.getAttrs().get("hash_seed").dyn_cast_or_null<ArrayAttr>();
  if (!hash_seed) {
    return func.emitError()
           << "'hash_seed' attribute is not set or not an array";
  }
  auto output_type = GetResultType(func, 0);
  if (!output_type || !output_type.getElementType().isa<FloatType>() ||
      !RankEquals(output_type, 2)) {
    return func.emitError() << "Output should be a 2D float tensor.";
  }
  if (output_type.getDimSize(1) != hash_seed.size()) {
    return func.emitError()
           << "Output 2nd dimension should be the number of hash seeds.";
  }

  auto buckets = attr.getAttrs().get("buckets").dyn_cast_or_null<IntegerAttr>();
  if (!buckets) {
    return func.emitError()
           << "'buckets' attribute is not set or not an integer";
  }

  return success();
}

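// The custom option built below is a flexbuffer map; with hypothetical
// attribute values it would look roughly like
//   {hash_seed: [1, 2, 3], buckets: 1073741824}
// where hash_seed is serialized as a typed (non-fixed) int32 vector.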
LogicalResult CreateSgnnProjectionCustomOption(
    FuncOp func, DictionaryAttr attrs, std::string& custom_option_buffer) {
  flexbuffers::Builder fbb;
  size_t start_map = fbb.StartMap();

  // "hash_seed" and "buckets" were already validated by VerifySgnnProjection,
  // so they are assumed to be present and correctly typed here.
  auto hash_seed = attrs.get("hash_seed").dyn_cast_or_null<ArrayAttr>();
  auto vector_start = fbb.StartVector("hash_seed");
  for (int i = 0; i < hash_seed.size(); i++) {
    fbb.Add(static_cast<int32_t>(
        (hash_seed.getValue().data() + i)->dyn_cast<IntegerAttr>().getInt()));
  }
  fbb.EndVector(vector_start, /*typed=*/true, /*fixed=*/false);

  auto buckets = attrs.get("buckets").dyn_cast_or_null<IntegerAttr>();
  fbb.Int("buckets", buckets.getInt());

  fbb.EndMap(start_map);
  fbb.Finish();
  custom_option_buffer.assign(fbb.GetBuffer().begin(), fbb.GetBuffer().end());
  return success();
}

LogicalResult ConvertSgnnProjection(FuncOp func, llvm::StringRef api,
                                    FuncAttr attr) {
  // See more details in tensorflow_models/sequence_projection/sgnn/sgnn.py
  func.eraseBody();
  func.addEntryBlock();
  func->setAttr(kTFImplements, attr);
  OpBuilder builder(func.getBody());
  std::string custom_option_buffer;
  if (failed(CreateSgnnProjectionCustomOption(func, attr.getAttrs(),
                                              custom_option_buffer))) {
    return failure();
  }
  auto op = builder.create<CustomOp>(
      func.getLoc(), func.getType().getResults(), func.getArguments(), api,
      CustomOption(&builder, custom_option_buffer));
  builder.create<ReturnOp>(func.getLoc(), op.getResults());
  return success();
}
}  // namespace
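// Dispatches on the fused API name carried by the function's
// "tf._implements" attribute: "api" is the name (e.g. "tftext:Ngrams") and
// "attr" supplies the fused op's configuration attributes. APIs other than
// the three handled below fall through and return failure().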

LogicalResult ConvertTFTextAPI(FuncOp func, llvm::StringRef api,
                               FuncAttr attr) {
  if (api.str() == kWhitespaceTokenizer) {
    if (succeeded(VerifyWhitespaceTokenizer(func))) {
      return ConvertWhitespaceTokenizer(func, api, attr);
    }
  } else if (api.str() == kNgrams) {
    if (succeeded(VerifyNgrams(func))) {
      return ConvertNgrams(func, api, attr);
    }
  } else if (api.str() == kCustomSgnnProjection) {
    if (succeeded(VerifySgnnProjection(func, attr))) {
      return ConvertSgnnProjection(func, api, attr);
    }
  }
  return failure();
}

bool IsTFTextRegistered(const tensorflow::OpRegistry* op_registry) {
  const std::vector<std::string> kTFTextOps = {
      "WhitespaceTokenizeWithOffsets",
  };
  for (const auto& iter : kTFTextOps) {
    if (op_registry->LookUp(iter)) {
      return true;
    }
  }
  return false;
}

}  // namespace TFL
}  // namespace mlir