1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
17
18 #include <stddef.h>
19 #include <stdlib.h>
20
21 #include <cstdint>
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include <vector>
26
27 #include "absl/base/attributes.h"
28 #include "absl/container/flat_hash_map.h"
29 #include "absl/container/flat_hash_set.h"
30 #include "absl/strings/match.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_join.h"
33 #include "absl/strings/string_view.h"
34 #include "flatbuffers/flatbuffers.h" // from @flatbuffers
35 #include "flatbuffers/flexbuffers.h" // from @flatbuffers
36 #include "llvm/ADT/ArrayRef.h"
37 #include "llvm/ADT/DenseMap.h"
38 #include "llvm/ADT/None.h"
39 #include "llvm/ADT/Optional.h"
40 #include "llvm/ADT/StringRef.h"
41 #include "llvm/Support/Casting.h"
42 #include "llvm/Support/CommandLine.h"
43 #include "llvm/Support/FormatVariadic.h"
44 #include "llvm/Support/ToolOutputFile.h"
45 #include "llvm/Support/raw_ostream.h"
46 #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
47 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
48 #include "mlir/IR/Attributes.h" // from @llvm-project
49 #include "mlir/IR/Builders.h" // from @llvm-project
50 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
51 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
52 #include "mlir/IR/Location.h" // from @llvm-project
53 #include "mlir/IR/MLIRContext.h" // from @llvm-project
54 #include "mlir/IR/Operation.h" // from @llvm-project
55 #include "mlir/IR/Types.h" // from @llvm-project
56 #include "mlir/IR/Value.h" // from @llvm-project
57 #include "mlir/Support/LogicalResult.h" // from @llvm-project
58 #include "mlir/Translation.h" // from @llvm-project
59 #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
60 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
61 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
62 #include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h"
63 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
64 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
65 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
66 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
67 #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
68 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
69 #include "tensorflow/compiler/xla/statusor.h"
70 #include "tensorflow/core/framework/attr_value.pb.h"
71 #include "tensorflow/core/framework/node_def.pb.h"
72 #include "tensorflow/core/platform/errors.h"
73 #include "tensorflow/core/platform/logging.h"
74 #include "tensorflow/core/platform/status.h"
75 #include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h"
76 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
77 #include "tensorflow/lite/schema/schema_conversion_utils.h"
78 #include "tensorflow/lite/schema/schema_generated.h"
79 #include "tensorflow/lite/string_util.h"
80 #include "tensorflow/lite/tools/versioning/op_version.h"
81 #include "tensorflow/lite/tools/versioning/runtime_version.h"
82 #include "tensorflow/lite/version.h"
83
84 using llvm::dyn_cast;
85 using llvm::formatv;
86 using llvm::isa;
87 using llvm::Optional;
88 using llvm::StringRef;
89 using llvm::Twine;
90 using mlir::Dialect;
91 using mlir::ElementsAttr;
92 using mlir::FuncOp;
93 using mlir::MLIRContext;
94 using mlir::ModuleOp;
95 using mlir::NoneType;
96 using mlir::Operation;
97 using mlir::Region;
98 using mlir::StringAttr;
99 using mlir::TensorType;
100 using mlir::Type;
101 using mlir::UnknownLoc;
102 using mlir::Value;
103 using tensorflow::OpOrArgLocNameMapper;
104 using tensorflow::OpOrArgNameMapper;
105 using tensorflow::Status;
106 using tflite::flex::IsAllowlistedFlexOp;
107 using xla::StatusOr;
108
109 template <typename T>
110 using BufferOffset = flatbuffers::Offset<T>;
111
112 template <typename T>
113 using VectorBufferOffset = flatbuffers::Offset<flatbuffers::Vector<T>>;
114
115 using CustomOptionsOffset = VectorBufferOffset<uint8_t>;
116
117 namespace error = tensorflow::error;
118 namespace tfl = mlir::TFL;
119
120 ABSL_CONST_INIT const absl::string_view kFlexOpNamePrefix = "Flex";
121
// Use the same initial buffer size in the flatbuffer builder as the TOCO
// export used. (The rationale for that particular size is not documented.)
124 constexpr size_t kInitialBufferSize = 10240;
125
// Set `is_signed` to false if the `type` is an 8-bit unsigned integer type.
// Since TFLite doesn't support unsigned for other types, returns an error if
// `is_signed` is set to false for any other type.
static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
                                                  bool is_signed = true) {
131 if (!is_signed && type.isSignlessInteger(8)) {
132 return tflite::TensorType_UINT8;
133 }
134 if (!is_signed) {
    return Status(error::INVALID_ARGUMENT,
                  "'is_signed' can only be set to false for 8-bit integer types");
137 }
138
139 if (type.isF32()) {
140 return tflite::TensorType_FLOAT32;
141 } else if (type.isF16()) {
142 return tflite::TensorType_FLOAT16;
143 } else if (type.isF64()) {
144 return tflite::TensorType_FLOAT64;
145 } else if (type.isa<mlir::TF::StringType>()) {
146 return tflite::TensorType_STRING;
147 } else if (type.isa<mlir::TF::Quint8Type>()) {
148 return tflite::TensorType_UINT8;
149 } else if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
150 auto ftype = complex_type.getElementType();
151 if (ftype.isF32()) {
152 return tflite::TensorType_COMPLEX64;
153 }
154 if (ftype.isF64()) {
155 return tflite::TensorType_COMPLEX128;
156 }
157 return Status(error::INVALID_ARGUMENT, "Unsupported type");
158 } else if (auto itype = type.dyn_cast<mlir::IntegerType>()) {
159 switch (itype.getWidth()) {
160 case 1:
161 return tflite::TensorType_BOOL;
162 case 8:
163 return itype.isUnsigned() ? tflite::TensorType_UINT8
164 : tflite::TensorType_INT8;
165 case 16:
166 return tflite::TensorType_INT16;
167 case 32:
168 return itype.isUnsigned() ? tflite::TensorType_UINT32
169 : tflite::TensorType_INT32;
170 case 64:
171 return itype.isUnsigned() ? tflite::TensorType_UINT64
172 : tflite::TensorType_INT64;
173 }
174 } else if (auto q_uniform_type =
175 type.dyn_cast<mlir::quant::UniformQuantizedType>()) {
176 return GetTFLiteType(q_uniform_type.getStorageType(),
177 q_uniform_type.isSigned());
178 } else if (auto q_peraxis_type =
179 type.dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
180 return GetTFLiteType(q_peraxis_type.getStorageType(),
181 q_peraxis_type.isSigned());
182 } else if (auto q_calibrated_type =
183 type.dyn_cast<mlir::quant::CalibratedQuantizedType>()) {
184 return GetTFLiteType(q_calibrated_type.getExpressedType());
185 } else if (type.isa<mlir::TF::ResourceType>()) {
186 return tflite::TensorType_RESOURCE;
187 } else if (type.isa<mlir::TF::VariantType>()) {
188 return tflite::TensorType_VARIANT;
189 }
  // TFLite export fills FLOAT32 for unknown data types. For safety, return an
  // error for now; this could be revisited when required.
192 return Status(error::INVALID_ARGUMENT, "Unsupported type");
193 }
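
// A few illustrative mappings for GetTFLiteType (a sketch, not used by the
// exporter itself; assumes an mlir::Builder `b` over a registered context):
//
//   GetTFLiteType(b.getF32Type());                             // TensorType_FLOAT32
//   GetTFLiteType(b.getIntegerType(32));                       // TensorType_INT32
//   GetTFLiteType(b.getIntegerType(8), /*is_signed=*/false);   // TensorType_UINT8
//   GetTFLiteType(b.getIntegerType(1));                        // TensorType_BOOL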
194
static bool IsConst(Operation* op) {
196 return isa<mlir::ConstantOp, mlir::TF::ConstOp, tfl::ConstOp, tfl::QConstOp,
197 tfl::SparseConstOp, tfl::SparseQConstOp>(op);
198 }
199
static bool IsTFResourceOp(Operation* op) {
201 for (const auto& operand : op->getOperands()) {
202 auto elementType = getElementTypeOrSelf(operand.getType());
203 if (elementType.isa<mlir::TF::ResourceType>()) {
204 return true;
205 }
206 }
207 for (const auto& result : op->getResults()) {
208 auto elementType = getElementTypeOrSelf(result.getType());
209 if (elementType.isa<mlir::TF::ResourceType>()) {
210 return true;
211 }
212 }
213 return false;
214 }
215
216 // Create description of operation that could not be converted.
static std::string GetOpDescriptionForDebug(Operation* inst) {
218 const int kLargeElementsAttr = 16;
219 std::string op_str;
220 llvm::raw_string_ostream os(op_str);
221 inst->getName().print(os);
  // Print out attributes, except for large ElementsAttrs (which are rarely the
  // reason why legalization did not happen).
224 if (!inst->getAttrDictionary().empty()) {
225 os << " {";
226 bool first = true;
227 for (auto& named_attr : inst->getAttrDictionary()) {
228 os << (!first ? ", " : "");
229 first = false;
230 named_attr.first.print(os);
231 os << " = ";
232 if (auto element_attr = named_attr.second.dyn_cast<ElementsAttr>()) {
233 if (element_attr.getNumElements() <= kLargeElementsAttr) {
234 element_attr.print(os);
235 } else {
236 os << "<large>";
237 }
238 } else {
239 named_attr.second.print(os);
240 }
241 }
242 os << "}";
243 }
244 return os.str();
245 }
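
// For example, an unconverted op with one small and one large attribute would
// be rendered roughly as (hypothetical op and attribute names):
//
//   tf.SomeOp {data_format = "NHWC", weights = <large>}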
246
247 // Create a summary with the given information regarding op names and
248 // descriptions.
static std::string GetOpsSummary(
    const std::map<std::string, std::set<std::string>>& ops,
    const std::string& summary_title) {
252 std::string op_str;
253 llvm::raw_string_ostream os(op_str);
254
255 std::vector<std::string> keys;
256 keys.reserve(ops.size());
257
258 std::vector<std::string> values;
259 values.reserve(ops.size());
260
261 for (auto const& op_name_and_details : ops) {
262 keys.push_back(op_name_and_details.first);
263 for (auto const& op_detail : op_name_and_details.second) {
264 values.push_back(op_detail);
265 }
266 }
267
268 os << summary_title << " ops: " << absl::StrJoin(keys, ", ") << "\n";
269 os << "Details:\n\t" << absl::StrJoin(values, "\n\t");
270
271 return os.str();
272 }
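
// For example, with ops = {{"AddV2", {"tf.AddV2 {...}"}}} and summary_title =
// "Flex", the summary reads roughly:
//
//   Flex ops: AddV2
//   Details:
//   <tab>tf.AddV2 {...}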
273
274 template <typename T>
static bool HasValidTFLiteType(Value value, T& error_handler) {
276 // None type is allowed to represent unspecified operands.
277 if (value.getType().isa<NoneType>()) return true;
278
279 auto type = value.getType().dyn_cast<TensorType>();
280 if (!type) {
281 if (auto op = value.getDefiningOp()) {
282 error_handler.emitError()
283 << '\'' << op << "' should produce value of tensor type instead of "
284 << value.getType();
285 return false;
286 }
287 error_handler.emitError("expected tensor type, got ") << value.getType();
288 return false;
289 }
290
291 Type element_type = type.getElementType();
292 auto status = GetTFLiteType(element_type);
293 if (!status.ok()) {
294 return error_handler.emitError(
295 formatv("Failed to convert element type '{0}': {1}",
296 element_type, status.status().error_message())),
297 false;
298 }
299 return true;
300 }
301
302 // Returns true if the module holds all the invariants expected by the
303 // Translator class.
304 // TODO(hinsu): Now that translation is done by making a single pass over the
305 // MLIR module, consider inlining these validation checks at the place where
306 // these invariants are assumed instead of checking upfront.
static bool IsValidTFLiteMlirModule(ModuleOp module) {
308 MLIRContext* context = module.getContext();
309
310 // Verify that module has a function named main.
311 FuncOp main_fn = module.lookupSymbol<FuncOp>("main");
312 if (!main_fn) {
313 return emitError(UnknownLoc::get(context),
314 "should have a function named 'main'"),
315 false;
316 }
317
318 for (auto fn : module.getOps<FuncOp>()) {
319 if (!llvm::hasSingleElement(fn)) {
320 return fn.emitError("should have exactly one basic block"), false;
321 }
322 auto& bb = fn.front();
323
324 for (auto arg : bb.getArguments()) {
325 if (!HasValidTFLiteType(arg, fn)) {
326 auto elementType = getElementTypeOrSelf(arg.getType());
327 if (elementType.isa<mlir::TF::VariantType>()) {
328 return fn.emitError(
329 "function argument uses variant type. Currently, the "
330 "variant type is not natively supported in TFLite. Please "
331 "consider not using the variant type: ")
332 << arg.getType(),
333 false;
334 }
335 return fn.emitError("invalid TFLite type: ") << arg.getType(), false;
336 }
337 }
338
339 // Verify that all operations except the terminator have exactly one
340 // result of type supported by TFLite.
341 for (auto& inst : bb) {
342 if (inst.hasTrait<mlir::OpTrait::IsTerminator>()) break;
343
344 for (auto result : inst.getResults()) {
345 if (!HasValidTFLiteType(result, inst)) {
346 auto elementType = getElementTypeOrSelf(result.getType());
347 if (elementType.isa<mlir::TF::VariantType>()) {
348 return inst.emitError(
349 "operand result uses variant type. Currently, the "
350 "variant type is not natively supported in TFLite. "
351 "Please "
352 "consider not using the variant type: ")
353 << result.getType(),
354 false;
355 }
356 return fn.emitError("invalid TFLite type: ") << result.getType(),
357 false;
358 }
359 }
360 }
361 }
362
363 return true;
364 }
365
static std::unique_ptr<::tensorflow::NodeDef> GetTensorFlowNodeDef(
    ::mlir::Operation* inst) {
368 // We pass empty string for the original node_def name since Flex runtime
369 // does not care about this being set correctly on node_def. There is no
370 // "easy" (see b/120948529) way yet to get this from MLIR inst.
371 auto status_or_node_def = tensorflow::ConvertTFDialectOpToNodeDef(
372 inst, /*name=*/"", /*ignore_unregistered_attrs=*/true);
373 if (!status_or_node_def.ok()) {
374 inst->emitOpError(
375 Twine("failed to obtain TensorFlow nodedef with status: " +
376 status_or_node_def.status().ToString()));
377 return {};
378 }
379 return std::move(status_or_node_def.ValueOrDie());
380 }
381
382 // Converts a mlir padding StringRef to TfLitePadding.
383 // Returns llvm::None if conversion fails.
static Optional<TfLitePadding> GetTflitePadding(Operation* inst,
                                                llvm::StringRef padding) {
386 const tflite::Padding padding_attr =
387 std::move(llvm::StringSwitch<tflite::Padding>(padding)
388 .Case("SAME", tflite::Padding_SAME)
389 .Case("VALID", tflite::Padding_VALID));
390 if (padding_attr == tflite::Padding_SAME) {
391 return kTfLitePaddingSame;
392 }
393 if (padding_attr == tflite::Padding_VALID) {
394 return kTfLitePaddingValid;
395 }
396
397 return inst->emitOpError() << "Invalid padding attribute: " << padding,
398 llvm::None;
399 }
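
// For example, GetTflitePadding(inst, "SAME") yields kTfLitePaddingSame and
// GetTflitePadding(inst, "VALID") yields kTfLitePaddingValid; any other value
// falls through to the error path above and produces llvm::None.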
400
401 // Extracts TfLitePoolParams from a TFL custom op.
402 // Template parameter, TFLOp, should be a TFL custom op containing attributes
403 // generated from TfLitePoolParams.
404 // Returns llvm::None if conversion fails.
405 template <typename TFLOp>
static Optional<TfLitePoolParams> GetTflitePoolParams(Operation* inst,
                                                      TFLOp op) {
408 TfLitePoolParams pool_params;
409 pool_params.stride_height = op.stride_h().getSExtValue();
410 pool_params.stride_width = op.stride_w().getSExtValue();
411 pool_params.filter_height = op.filter_h().getSExtValue();
412 pool_params.filter_width = op.filter_w().getSExtValue();
413 const auto padding = GetTflitePadding(inst, op.padding());
414 if (padding) {
415 pool_params.padding = *padding;
416 pool_params.activation = kTfLiteActNone;
417 pool_params.computed.padding = TfLitePaddingValues{0, 0, 0, 0};
418 return pool_params;
419 }
420
421 return llvm::None;
422 }
423
424 namespace {
425
426 // Helper struct that wraps inputs/outputs of a single SignatureDef.
427 struct SignatureDefData {
  // Note: maps are used here to keep the order deterministic, which makes
  // testing easier.
430
431 // Inputs defined in the signature def mapped to tensor names.
432 std::map<std::string, std::string> inputs;
433 // Outputs defined in the signature def mapped to tensor names.
434 std::map<std::string, std::string> outputs;
435 // Method name exported by the signature def.
436 std::string method_name;
437 // SignatureDef key.
438 std::string signature_def_key;
439 };
440
441 // Translates an MLIR module in TFLite dialect to TFLite FlatBuffer.
442 class Translator {
443 public:
444 // Translates the given MLIR module into TFLite FlatBuffer format and returns
445 // the serialized output. Returns llvm::None on unsupported, invalid inputs or
446 // internal error.
447 static Optional<std::string> Translate(
448 ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
449 bool emit_custom_ops,
450 const std::unordered_set<std::string>& select_user_tf_ops,
451 const std::unordered_set<std::string>& tags,
452 OpOrArgNameMapper* op_or_arg_name_mapper);
453
454 private:
455 enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp };
  explicit Translator(ModuleOp module, bool emit_builtin_tflite_ops,
                      bool emit_select_tf_ops, bool emit_custom_ops,
                      const std::unordered_set<std::string>& select_user_tf_ops,
                      const std::unordered_set<std::string>& saved_model_tags,
                      OpOrArgNameMapper* op_or_arg_name_mapper)
461 : module_(module),
462 name_mapper_(*op_or_arg_name_mapper),
463 builder_(kInitialBufferSize),
464 saved_model_tags_(saved_model_tags),
465 select_user_tf_ops_(select_user_tf_ops) {
466 // The first buffer must be empty according to the schema definition.
467 empty_buffer_ = tflite::CreateBuffer(builder_);
468 buffers_.push_back(empty_buffer_);
469 if (emit_builtin_tflite_ops) {
470 enabled_op_types_.emplace(OpType::kTfliteBuiltin);
471 }
472 if (emit_select_tf_ops) {
473 enabled_op_types_.emplace(OpType::kSelectTf);
474 }
475 if (emit_custom_ops) {
476 enabled_op_types_.emplace(OpType::kCustomOp);
477 }
478 tf_dialect_ =
479 module.getContext()->getOrLoadDialect<mlir::TF::TensorFlowDialect>();
480 tfl_dialect_ = module.getContext()
481 ->getOrLoadDialect<mlir::TFL::TensorFlowLiteDialect>();
482 // Right now the TF executor dialect is still needed to build NodeDef.
483 module.getContext()
484 ->getOrLoadDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
485 }
486
487 Optional<std::string> TranslateInternal();
488
489 // Returns TFLite buffer populated with constant value if the operation is
490 // TFLite constant operation. Otherwise, returns an empty buffer. Emits error
491 // and returns llvm::None on failure.
492 Optional<BufferOffset<tflite::Buffer>> BuildBuffer(Operation* inst);
493
494 // Build TFLite tensor from the given type. This function is for tfl.lstm
495 // intermediates, which should have UniformQuantizedType.
496 Optional<BufferOffset<tflite::Tensor>> BuildTensorFromType(
497 mlir::Type type, const std::string& name);
498
499 // Builds TFLite tensor from the given value. `buffer_idx` is index of the
500 // corresponding buffer. Emits error and returns llvm::None on failure.
501 Optional<BufferOffset<tflite::Tensor>> BuildTensor(
502 Value value, const std::string& name, unsigned buffer_idx,
503 const Optional<BufferOffset<tflite::QuantizationParameters>>&
504 quant_parameters);
505
506 // TODO(b/137395003): Legalize control flow ops to TFLite dialect, and remove
507 // these 2 functions here.
508 BufferOffset<tflite::Operator> BuildIfOperator(
509 mlir::TF::IfOp op, const std::vector<int32_t>& operands,
510 const std::vector<int32_t>& results);
511 BufferOffset<tflite::Operator> BuildWhileOperator(
512 mlir::TF::WhileOp op, const std::vector<int32_t>& operands,
513 const std::vector<int32_t>& results);
514
515 // Build while operator where cond & body are regions.
516 Optional<BufferOffset<tflite::Operator>> BuildWhileOperator(
517 mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
518 const std::vector<int32_t>& results);
519
520 // Build call once operator.
521 BufferOffset<tflite::Operator> BuildCallOnceOperator(
522 mlir::TFL::CallOnceOp op, const std::vector<int32_t>& operands,
523 const std::vector<int32_t>& results);
524
525 BufferOffset<tflite::Operator> BuildNumericVerifyOperator(
526 mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
527 const std::vector<int32_t>& results);
528
529 // Builds Assign/Read Variable ops.
530 template <typename T>
531 BufferOffset<tflite::Operator> BuildVariableOperator(
532 T op, const std::string& op_name, const std::vector<int32_t>& operands,
533 const std::vector<int32_t>& results);
534
535 BufferOffset<tflite::Operator> BuildCustomOperator(
536 Operation* inst, mlir::TFL::CustomOp op,
537 const std::vector<int32_t>& operands,
538 const std::vector<int32_t>& results);
539
540 Optional<CustomOptionsOffset> CreateFlexOpCustomOptions(
541 const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
542
543 Optional<CustomOptionsOffset> CreateCustomOpCustomOptions(
544 const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
545
546 std::unique_ptr<flexbuffers::Builder> CreateFlexBuilderWithNodeAttrs(
547 const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
548
549 // Returns opcode index for op identified by the op_name, if already
550 // available. Otherwise, creates a new OperatorCode using the given `builtin`
551 // operator and associates it with `op_name`.
552 uint32_t GetOpcodeIndex(const std::string& op_name,
553 tflite::BuiltinOperator builtin);
554
555 // Builds operator for the given operation with specified operand and result
556 // tensor indices. Emits an error and returns llvm::None on failure.
557 Optional<BufferOffset<tflite::Operator>> BuildOperator(
558 Operation* inst, std::vector<int32_t> operands,
559 const std::vector<int32_t>& results,
560 const std::vector<int32_t>& intermediates);
561
562 // Returns the quantization parameters for output value of "quant.stats" op.
563 BufferOffset<tflite::QuantizationParameters>
564 GetQuantizationForQuantStatsOpOutput(mlir::quant::StatisticsOp stats_op);
565
  // Build a subgraph with the given name from a region corresponding either to
  // a function's body or to a while op.
568 Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(
569 const std::string& name, Region* region);
570
571 // Builds Metadata with the given `name` and buffer `content`.
572 BufferOffset<tflite::Metadata> BuildMetadata(StringRef name,
573 StringRef content);
574
575 // Encodes the `tfl.metadata` dictionary attribute of the module to the
576 // metadata section in the final model.
577 Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
578 CreateMetadataVector();
579
580 // Builds and returns list of tfl.SignatureDef sections in the model.
581 Optional<VectorBufferOffset<BufferOffset<tflite::SignatureDef>>>
582 CreateSignatureDefs(const std::vector<SignatureDefData>& signature_defs);
583
584 // Returns list of offsets for the passed 'items' in TensorMap structure
585 // inside the flatbuffer.
586 // 'items' is a map from tensor name in signatureDef to tensor name in
587 // the model.
588 std::vector<BufferOffset<tflite::TensorMap>> GetList(
589 const std::map<std::string, std::string>& items);
590
591 // Uses the tf.entry_function attribute (if set) to initialize the op to name
592 // mapping.
593 void InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr);
594
595 // Determines if the specified operation op's operand at operand_index
596 // is marked as a stateful operand.
597 bool IsStatefulOperand(mlir::Operation* op, int operand_index);
598
599 // Returns a unique name for `val`.
600 std::string UniqueName(mlir::Value val);
601
602 BufferOffset<tflite::SparsityParameters> BuildSparsityParameters(
603 const mlir::TFL::SparsityParameterAttr& s_attr);
604
605 ModuleOp module_;
606
607 tensorflow::OpOrArgNameMapper& name_mapper_;
608
609 flatbuffers::FlatBufferBuilder builder_;
610 BufferOffset<tflite::Buffer> empty_buffer_;
611
612 std::vector<BufferOffset<tflite::Buffer>> buffers_;
613 // Maps tensor name in the graph to the tensor index.
614 absl::flat_hash_map<std::string, int> tensor_index_map_;
615
616 // Maps op name to index of the corresponding OperatorCode in opcodes_ vector.
617 absl::flat_hash_map<std::string, uint32_t> opcode_index_map_;
618 std::vector<BufferOffset<tflite::OperatorCode>> opcodes_;
619
620 // Maps function name to index of the corresponding subgraph in the FlatBuffer
621 // model.
622 absl::flat_hash_map<std::string, int> subgraph_index_map_;
623 absl::flat_hash_set<OpType> enabled_op_types_;
624
625 // Points to TensorFlow and TFLite dialects, respectively. nullptr if the
626 // dialect is not registered.
627 const Dialect* tf_dialect_;
628 const Dialect* tfl_dialect_;
629
630 // The failed ops during legalization.
631 std::map<std::string, std::set<std::string>> failed_flex_ops_;
632 std::map<std::string, std::set<std::string>> failed_custom_ops_;
633
634 // Ops to provide warning messages.
635 std::map<std::string, std::set<std::string>> custom_ops_;
636 std::map<std::string, std::set<std::string>> flex_ops_;
637
638 // Resource ops to provide warning messages.
639 std::map<std::string, std::set<std::string>> resource_ops_;
640
641 // Set of saved model tags, if any.
642 const std::unordered_set<std::string> saved_model_tags_;
643 // User's defined ops allowed with Flex.
644 const std::unordered_set<std::string> select_user_tf_ops_;
645 };
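
// A minimal usage sketch of Translator::Translate (illustrative only; assumes
// `module` is a valid TFL-dialect ModuleOp and empty option sets are fine):
//
//   tensorflow::OpOrArgLocNameMapper mapper;
//   auto serialized = Translator::Translate(
//       module, /*emit_builtin_tflite_ops=*/true, /*emit_select_tf_ops=*/false,
//       /*emit_custom_ops=*/false, /*select_user_tf_ops=*/{}, /*tags=*/{},
//       &mapper);
//   if (serialized) {
//     // *serialized holds the bytes of the .tflite flatbuffer.
//   }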
646
std::string Translator::UniqueName(mlir::Value val) {
648 return std::string(name_mapper_.GetUniqueName(val));
649 }
650
Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
    Operation* inst) {
653 ElementsAttr attr;
654 if (auto cst = dyn_cast<mlir::ConstantOp>(inst)) {
    // ConstantOp has an ElementsAttr at this point due to validation of the
    // TFLite module.
657 attr = cst.getValue().cast<ElementsAttr>();
658 } else if (auto cst = dyn_cast<mlir::TF::ConstOp>(inst)) {
659 attr = cst.value();
660 } else if (auto cst = dyn_cast<tfl::ConstOp>(inst)) {
661 attr = cst.value();
662 } else if (auto cst = dyn_cast<tfl::QConstOp>(inst)) {
663 attr = cst.value();
664 } else if (auto cst = dyn_cast<tfl::SparseConstOp>(inst)) {
665 attr = cst.compressed_data();
666 } else if (auto cst = dyn_cast<tfl::SparseQConstOp>(inst)) {
667 attr = cst.compressed_data();
668 } else {
669 return empty_buffer_;
670 }
671
672 tensorflow::Tensor tensor;
673 auto status = tensorflow::ConvertToTensor(attr, &tensor);
674 if (!status.ok()) {
675 inst->emitError(
676 Twine("failed to convert value attribute to tensor with error: " +
677 status.ToString()));
678 return llvm::None;
679 }
680
  // TensorFlow and TensorFlow Lite use different string encoding formats.
  // Convert to the TensorFlow Lite format if it's a constant string tensor.
683 if (tensor.dtype() == tensorflow::DT_STRING) {
684 ::tflite::DynamicBuffer dynamic_buffer;
685 auto flat = tensor.flat<::tensorflow::tstring>();
686 for (int i = 0; i < flat.size(); ++i) {
687 const auto& str = flat(i);
688 dynamic_buffer.AddString(str.c_str(), str.length());
689 }
690 char* tensor_buffer;
691 int bytes = dynamic_buffer.WriteToBuffer(&tensor_buffer);
692 auto buffer_data =
693 builder_.CreateVector(reinterpret_cast<uint8_t*>(tensor_buffer), bytes);
694 free(tensor_buffer);
695 return tflite::CreateBuffer(builder_, buffer_data);
696 }
697
698 absl::string_view tensor_data = tensor.tensor_data();
699 auto buffer_data = builder_.CreateVector(
700 reinterpret_cast<const uint8_t*>(tensor_data.data()), tensor_data.size());
701 return tflite::CreateBuffer(builder_, buffer_data);
702 }
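
// For example, a tfl.pseudo_const holding dense<[1.0, 2.0]> : tensor<2xf32>
// becomes a Buffer with the 8 raw bytes of the tensor data, a DT_STRING
// constant is re-packed through tflite::DynamicBuffer into the TFLite string
// format, and any non-constant op simply maps to the shared empty buffer.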
703
Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensorFromType(
    mlir::Type type, const std::string& name) {
706 auto tensor_type = type.cast<TensorType>();
707
708 if (!tensor_type.hasStaticShape()) {
709 return llvm::None;
710 }
711 llvm::ArrayRef<int64_t> shape_ref = tensor_type.getShape();
712 std::vector<int32_t> shape(shape_ref.begin(), shape_ref.end());
713
714 auto element_type = tensor_type.getElementType();
715 tflite::TensorType tflite_element_type =
716 GetTFLiteType(tensor_type.getElementType()).ValueOrDie();
717 BufferOffset<tflite::QuantizationParameters> q_params = 0;
718 if (auto qtype = element_type.dyn_cast<mlir::quant::UniformQuantizedType>()) {
719 q_params = tflite::CreateQuantizationParameters(
720 builder_, /*min=*/0, /*max=*/0,
721 builder_.CreateVector<float>({static_cast<float>(qtype.getScale())}),
722 builder_.CreateVector<int64_t>({qtype.getZeroPoint()}));
723 } else if (auto qtype =
724 element_type
725 .dyn_cast<mlir::quant::CalibratedQuantizedType>()) {
726 q_params = tflite::CreateQuantizationParameters(
727 builder_,
728 builder_.CreateVector<float>({static_cast<float>(qtype.getMin())}),
729 builder_.CreateVector<float>({static_cast<float>(qtype.getMax())}));
730 }
731 return tflite::CreateTensor(
732 builder_, builder_.CreateVector(shape), tflite_element_type,
733 /*buffer=*/0, builder_.CreateString(name), q_params,
734 /*is_variable=*/false);
735 }
736
Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
    Value value, const std::string& name, unsigned buffer_idx,
    const Optional<BufferOffset<tflite::QuantizationParameters>>&
        quant_parameters) {
741 auto type = value.getType().cast<TensorType>();
742
  // TFLite requires tensor shape only for the inputs and constants.
  // However, we output all known shapes for better round-tripping.
745 auto check_shape =
746 [&](llvm::ArrayRef<int64_t> shape_ref) -> mlir::LogicalResult {
747 auto is_out_of_range = [](int64_t dim) {
748 return dim > std::numeric_limits<int32_t>::max();
749 };
750
751 if (std::any_of(shape_ref.begin(), shape_ref.end(), is_out_of_range))
752 return mlir::emitError(
753 value.getLoc(),
754 "result shape dimensions out of 32 bit int type range");
755
756 return mlir::success();
757 };
758
759 std::vector<int32_t> shape;
760 std::vector<int32_t> shape_signature;
761 auto* inst = value.getDefiningOp();
762 if (type.hasStaticShape()) {
763 llvm::ArrayRef<int64_t> shape_ref = type.getShape();
764 if (mlir::failed(check_shape(shape_ref))) return llvm::None;
765
766 shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
767 } else if (inst && IsConst(inst)) {
    // A const op can have a result of dynamically shaped type (e.g. due to
    // constant folding), but we can still derive the shape of the constant
    // tensor from its attribute type.
771 mlir::Attribute tensor_attr = inst->getAttr("value");
772 llvm::ArrayRef<int64_t> shape_ref =
773 tensor_attr.getType().cast<TensorType>().getShape();
774 if (mlir::failed(check_shape(shape_ref))) return llvm::None;
775
776 shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
777 } else if (type.hasRank()) {
778 llvm::ArrayRef<int64_t> shape_ref = type.getShape();
779 if (mlir::failed(check_shape(shape_ref))) return llvm::None;
780
781 shape.reserve(shape_ref.size());
782 for (auto& dim : shape_ref) {
783 shape.push_back(dim == -1 ? 1 : dim);
784 }
785 shape_signature = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
786 }
787
788 BufferOffset<tflite::SparsityParameters> s_params = 0;
789 if (auto* inst = value.getDefiningOp()) {
790 if (auto cst = dyn_cast<tfl::SparseConstOp>(inst)) {
791 s_params = BuildSparsityParameters(cst.s_param());
792 } else if (auto cst = dyn_cast<tfl::SparseQConstOp>(inst)) {
793 s_params = BuildSparsityParameters(cst.s_param());
794 }
795 }
796
797 Type element_type = type.getElementType();
798 tflite::TensorType tflite_element_type =
799 GetTFLiteType(type.getElementType()).ValueOrDie();
800
801 BufferOffset<tflite::QuantizationParameters> q_params;
802 if (auto qtype = element_type.dyn_cast<mlir::quant::UniformQuantizedType>()) {
803 q_params = tflite::CreateQuantizationParameters(
        // TODO(fengliuai): min and max values are not stored in the
        // quantized type, so both are set to 0. Because of this, the model
        // cannot be imported back into TensorFlow.
807 builder_, /*min=*/0, /*max=*/0,
808 builder_.CreateVector<float>({static_cast<float>(qtype.getScale())}),
809 builder_.CreateVector<int64_t>({qtype.getZeroPoint()}));
810 } else if (auto qtype =
811 element_type
812 .dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
813 std::vector<float> scales(qtype.getScales().begin(),
814 qtype.getScales().end());
815 q_params = tflite::CreateQuantizationParameters(
816 builder_, /*min=*/0, /*max=*/0, builder_.CreateVector<float>(scales),
817 builder_.CreateVector<int64_t>(qtype.getZeroPoints()),
818 tflite::QuantizationDetails_NONE, /*details=*/0,
819 qtype.getQuantizedDimension());
820 } else if (quant_parameters.hasValue()) {
821 q_params = quant_parameters.getValue();
822 } else {
823 q_params = tflite::CreateQuantizationParameters(builder_);
824 }
  // Check if any of the value's uses is at an operand index marked as
  // stateful. If so, set the tensor's is_variable to true. This matches v1
  // ref-variable semantics in the TFLite runtime.
828 bool is_variable = false;
829 for (auto& use : value.getUses()) {
830 is_variable = IsStatefulOperand(use.getOwner(), use.getOperandNumber());
831 if (is_variable) {
832 break;
833 }
834 }
835
836 if (shape_signature.empty()) {
837 return tflite::CreateTensor(
838 builder_, builder_.CreateVector(shape), tflite_element_type,
839 (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
840 /*is_variable=*/is_variable, s_params);
841 } else {
842 return tflite::CreateTensor(
843 builder_, builder_.CreateVector(shape), tflite_element_type,
844 (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
845 /*is_variable=*/is_variable, s_params,
846 /*shape_signature=*/builder_.CreateVector(shape_signature));
847 }
848 }
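
// Shape handling example: a value of type tensor<?x4xf32> is exported with
// shape = [1, 4] (dynamic dimensions are clamped to 1) and
// shape_signature = [-1, 4], while a static tensor<2x4xf32> gets
// shape = [2, 4] and no shape_signature.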
849
BufferOffset<tflite::Operator> Translator::BuildIfOperator(
    mlir::TF::IfOp op, const std::vector<int32_t>& operands,
    const std::vector<int32_t>& results) {
853 auto opcode_index = GetOpcodeIndex("if", tflite::BuiltinOperator_IF);
854 int then_subgraph_index = subgraph_index_map_.at(op.then_branch().str());
855 int else_subgraph_index = subgraph_index_map_.at(op.else_branch().str());
856 auto builtin_options = tflite::CreateIfOptions(builder_, then_subgraph_index,
857 else_subgraph_index)
858 .Union();
859 auto inputs = builder_.CreateVector(operands);
860 auto outputs = builder_.CreateVector(results);
861 return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
862 tflite::BuiltinOptions_IfOptions,
863 builtin_options);
864 }
865
BufferOffset<tflite::Operator> Translator::BuildCallOnceOperator(
    mlir::TFL::CallOnceOp op, const std::vector<int32_t>& operands,
    const std::vector<int32_t>& results) {
869 auto opcode_index =
870 GetOpcodeIndex("call_once", tflite::BuiltinOperator_CALL_ONCE);
871 int init_subgraph_index =
872 subgraph_index_map_.at(op.session_init_function().str());
873 auto builtin_options =
874 tflite::CreateCallOnceOptions(builder_, init_subgraph_index).Union();
875 auto inputs = builder_.CreateVector(operands);
876 auto outputs = builder_.CreateVector(results);
877 return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
878 tflite::BuiltinOptions_CallOnceOptions,
879 builtin_options);
880 }
881
BufferOffset<tflite::Operator> Translator::BuildWhileOperator(
    mlir::TF::WhileOp op, const std::vector<int32_t>& operands,
    const std::vector<int32_t>& results) {
885 auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE);
886 int cond_subgraph_index = subgraph_index_map_.at(op.cond().str());
887 int body_subgraph_index = subgraph_index_map_.at(op.body().str());
888 auto builtin_options = tflite::CreateWhileOptions(
889 builder_, cond_subgraph_index, body_subgraph_index)
890 .Union();
891 auto inputs = builder_.CreateVector(operands);
892 auto outputs = builder_.CreateVector(results);
893 return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
894 tflite::BuiltinOptions_WhileOptions,
895 builtin_options);
896 }
897
Optional<BufferOffset<tflite::Operator>> Translator::BuildWhileOperator(
    mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
    const std::vector<int32_t>& results) {
901 auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE);
902 auto get_call_index = [&](mlir::Block& b) -> Optional<int> {
903 if (b.getOperations().size() != 2) return llvm::None;
904 if (auto call_op = dyn_cast<mlir::CallOp>(b.front()))
905 return subgraph_index_map_.at(call_op.callee().str());
906 return llvm::None;
907 };
908 auto body_subgraph_index = get_call_index(op.body().front());
909 auto cond_subgraph_index = get_call_index(op.cond().front());
910 if (!body_subgraph_index || !cond_subgraph_index)
911 return op.emitOpError("only single call cond/body while export supported"),
912 llvm::None;
913 auto builtin_options =
914 tflite::CreateWhileOptions(builder_, *cond_subgraph_index,
915 *body_subgraph_index)
916 .Union();
917 auto inputs = builder_.CreateVector(operands);
918 auto outputs = builder_.CreateVector(results);
919 return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
920 tflite::BuiltinOptions_WhileOptions,
921 builtin_options);
922 }
923
BufferOffset<tflite::Operator> Translator::BuildNumericVerifyOperator(
    mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
    const std::vector<int32_t>& results) {
927 float tolerance = op.tolerance().convertToFloat();
928 bool log_if_failed = op.log_if_failed();
929 auto fbb = absl::make_unique<flexbuffers::Builder>();
930 fbb->Map([&]() {
931 fbb->Float("tolerance", tolerance);
932 fbb->Bool("log_if_failed", log_if_failed);
933 });
934 fbb->Finish();
935 auto f = std::unique_ptr<flexbuffers::Builder>(fbb.release());
936 auto custom_option = f->GetBuffer();
937 auto opcode_index =
938 GetOpcodeIndex("NumericVerify", tflite::BuiltinOperator_CUSTOM);
939 return tflite::CreateOperator(
940 builder_, opcode_index, builder_.CreateVector(operands),
941 builder_.CreateVector(results), tflite::BuiltinOptions_NONE,
942 /*builtin_options=*/0, builder_.CreateVector<uint8_t>(custom_option),
943 tflite::CustomOptionsFormat_FLEXBUFFERS);
944 }
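
// The custom options built above are a FlexBuffer map of the form
// {"tolerance": <float>, "log_if_failed": <bool>}, attached to a CUSTOM
// operator whose custom code is "NumericVerify".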
945
946 // Builds Assign/Read Variable ops.
947 template <typename T>
BufferOffset<tflite::Operator> Translator::BuildVariableOperator(
    T op, const std::string& op_name, const std::vector<int32_t>& operands,
    const std::vector<int32_t>& results) {
951 auto opcode_index = GetOpcodeIndex(op_name, tflite::BuiltinOperator_CUSTOM);
952 return tflite::CreateOperator(
953 builder_, opcode_index, builder_.CreateVector(operands),
954 builder_.CreateVector(results), tflite::BuiltinOptions_NONE);
955 }
956
BufferOffset<tflite::Operator> Translator::BuildCustomOperator(
    Operation* inst, mlir::TFL::CustomOp op,
    const std::vector<int32_t>& operands, const std::vector<int32_t>& results) {
960 const std::string attrs =
961 op.custom_option().cast<mlir::OpaqueElementsAttr>().getValue().str();
962 std::vector<uint8_t> custom_option_vector(attrs.size());
963 memcpy(custom_option_vector.data(), attrs.data(), attrs.size());
964 auto opcode_index =
965 GetOpcodeIndex(op.custom_code().str(), tflite::BuiltinOperator_CUSTOM);
966 return tflite::CreateOperator(
967 builder_, opcode_index, builder_.CreateVector(operands),
968 builder_.CreateVector(results), tflite::BuiltinOptions_NONE,
969 /*builtin_options=*/0,
970 builder_.CreateVector<uint8_t>(custom_option_vector),
971 tflite::CustomOptionsFormat_FLEXBUFFERS);
972 }
973
Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions(
    const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
976 std::string node_def_str;
977 if (!node_def.SerializeToString(&node_def_str)) {
978 return emitError(loc, "failed to serialize tensorflow node_def"),
979 llvm::None;
980 }
981
982 auto flex_builder = absl::make_unique<flexbuffers::Builder>();
983 flex_builder->Vector([&]() {
984 flex_builder->String(node_def.op());
985 flex_builder->String(node_def_str);
986 });
987 flex_builder->Finish();
988 return builder_.CreateVector(flex_builder->GetBuffer());
989 }
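
// The resulting Flex custom options are a two-element FlexBuffer vector,
// [<TF op name>, <serialized NodeDef bytes>], which the Flex (select TF ops)
// runtime uses to re-create the TensorFlow node.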
990
Optional<CustomOptionsOffset> Translator::CreateCustomOpCustomOptions(
    const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
993 auto flex_builder = CreateFlexBuilderWithNodeAttrs(node_def, loc);
994 return builder_.CreateVector(flex_builder->GetBuffer());
995 }
996
997 std::unique_ptr<flexbuffers::Builder>
Translator::CreateFlexBuilderWithNodeAttrs(
    const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
1000 auto flex_builder = absl::make_unique<flexbuffers::Builder>();
1001 size_t map_start = flex_builder->StartMap();
1002 using Item = std::pair<std::string, ::tensorflow::AttrValue>;
1003 std::vector<Item> attrs(node_def.attr().begin(), node_def.attr().end());
1004 std::sort(attrs.begin(), attrs.end(),
1005 [](Item& p1, Item& p2) -> bool { return p1.first < p2.first; });
1006 for (const Item& pair : attrs) {
1007 const char* key = pair.first.c_str();
1008 const ::tensorflow::AttrValue& attr = pair.second;
1009 switch (attr.value_case()) {
1010 case ::tensorflow::AttrValue::kS:
1011 flex_builder->String(key, attr.s());
1012 break;
1013 case ::tensorflow::AttrValue::kType: {
1014 auto status_or_tfl_type = tflite::TfTypeToTflType(attr.type());
1015 if (status_or_tfl_type.ok()) {
1016 flex_builder->Int(key, status_or_tfl_type.ValueOrDie());
1017 } else {
1018 emitWarning(loc, "ignoring unsupported tensorflow type: ")
1019 << std::to_string(attr.type());
1020 }
1021 break;
1022 }
1023 case ::tensorflow::AttrValue::kI:
1024 flex_builder->Int(key, attr.i());
1025 break;
1026 case ::tensorflow::AttrValue::kF:
1027 flex_builder->Float(key, attr.f());
1028 break;
1029 case ::tensorflow::AttrValue::kB:
1030 flex_builder->Bool(key, attr.b());
1031 break;
1032 case tensorflow::AttrValue::kList:
1033 if (attr.list().s_size() > 0) {
1034 auto start = flex_builder->StartVector(key);
1035 for (const std::string& v : attr.list().s()) {
1036 flex_builder->Add(v);
1037 }
1038 flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false);
1039 } else if (attr.list().i_size() > 0) {
1040 auto start = flex_builder->StartVector(key);
1041 for (const int64_t v : attr.list().i()) {
1042 flex_builder->Add(v);
1043 }
1044 flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false);
1045 } else if (attr.list().f_size() > 0) {
1046 auto start = flex_builder->StartVector(key);
1047 for (const float v : attr.list().f()) {
1048 flex_builder->Add(v);
1049 }
1050 flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false);
1051 } else {
1052 emitWarning(loc,
1053 "ignoring unsupported type in list attribute with key: ")
1054 << key;
1055 }
1056 break;
1057 default:
1058 emitWarning(loc, "ignoring unsupported attribute type with key: ")
1059 << key;
1060 break;
1061 }
1062 }
1063 flex_builder->EndMap(map_start);
1064 flex_builder->Finish();
1065 return flex_builder;
1066 }
1067
uint32_t Translator::GetOpcodeIndex(const std::string& op_name,
                                    tflite::BuiltinOperator builtin) {
1070 auto it = opcode_index_map_.insert({op_name, 0});
1071
1072 // If the insert succeeded, the opcode has not been created already. Create a
1073 // new operator code and update its index value in the map.
1074 if (it.second) {
1075 it.first->second = opcodes_.size();
1076 auto custom_code = builtin == tflite::BuiltinOperator_CUSTOM
1077 ? builder_.CreateString(op_name)
1078 : BufferOffset<flatbuffers::String>();
    // Use version 0 for builtin ops. Since 0 is not the schema default, the
    // version field gets serialized into the flatbuffer; it is corrected to
    // the real version later.
1081 int32_t op_version = builtin != tflite::BuiltinOperator_CUSTOM ? 0 : 1;
1082 opcodes_.push_back(CreateOperatorCode(builder_, /*builtin_code=*/builtin,
1083 custom_code, op_version));
1084 }
1085 return it.first->second;
1086 }
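
// For example, repeated calls to GetOpcodeIndex("tfl.add",
// tflite::BuiltinOperator_ADD) return the same index and create a single
// OperatorCode, while GetOpcodeIndex("MyOp", tflite::BuiltinOperator_CUSTOM)
// also records "MyOp" as the custom code string ("MyOp" is a hypothetical
// name).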
1087
Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
    Operation* inst, std::vector<int32_t> operands,
    const std::vector<int32_t>& results,
    const std::vector<int32_t>& intermediates) {
1092 const auto* dialect = inst->getDialect();
1093 if (!dialect) {
1094 inst->emitOpError("dialect is not registered");
1095 return llvm::None;
1096 }
1097
1098 // TODO(b/149099381): Remove this once the kernels are promoted as
1099 // builtin TFLite kernels.
1100 // We export the Assign/Read variable ops as custom ops.
1101 if (auto read_op = llvm::dyn_cast<mlir::TFL::ReadVariableOp>(inst)) {
1102 return BuildVariableOperator<mlir::TFL::ReadVariableOp>(
1103 read_op, "ReadVariable", operands, results);
1104 } else if (auto assign_op =
1105 llvm::dyn_cast<mlir::TFL::AssignVariableOp>(inst)) {
1106 return BuildVariableOperator<mlir::TFL::AssignVariableOp>(
1107 assign_op, "AssignVariable", operands, results);
1108 }
1109
1110 // If TFLite built in op, create operator as a builtin op.
1111 if (dialect == tfl_dialect_) {
    // Legalization converts TF ops to TFL ops only when built-in TFLite op
    // emission is enabled.
1114 if (!enabled_op_types_.contains(OpType::kTfliteBuiltin)) {
1115 return inst->emitOpError(
1116 "is a TFLite builtin op but builtin emission is not enabled"),
1117 llvm::None;
1118 }
1119
1120 auto builtin_code = GetBuiltinOpCode(inst);
1121 if (!builtin_code) {
1122 if (auto verify_op = dyn_cast<mlir::TFL::NumericVerifyOp>(inst)) {
1123 return BuildNumericVerifyOperator(verify_op, operands, results);
1124 }
1125 if (auto custom_op = dyn_cast<mlir::TFL::CustomOp>(inst)) {
1126 return BuildCustomOperator(inst, custom_op, operands, results);
1127 }
1128 if (auto whileOp = dyn_cast<mlir::TFL::WhileOp>(inst)) {
1129 if (inst->getNumOperands() != inst->getNumResults()) {
1130 inst->emitOpError(
1131 "number of operands and results don't match, only canonical "
1132 "TFL While supported");
1133 return llvm::None;
1134 }
1135 return BuildWhileOperator(whileOp, operands, results);
1136 }
1137
1138 inst->emitOpError("is not a supported TFLite op");
1139 return llvm::None;
1140 }
1141
1142 if (*builtin_code == tflite::BuiltinOperator_CALL_ONCE) {
1143 if (auto initOp = dyn_cast<mlir::TFL::CallOnceOp>(inst)) {
1144 return BuildCallOnceOperator(initOp, operands, results);
1145 }
1146 }
1147
1148 std::string op_name = inst->getName().getStringRef().str();
1149 uint32_t opcode_index = GetOpcodeIndex(op_name, *builtin_code);
1150
1151 // If this is TransposeConv we need to do a special case of ignoring the
1152 // optional tensor, to allow newly created models to run on old runtimes.
1153 if (*builtin_code == tflite::BuiltinOperator_TRANSPOSE_CONV) {
1154 if (operands.size() == 4 && operands.at(3) == -1) {
1155 operands.pop_back();
1156 }
1157 }
1158
1159 auto offset = CreateFlatBufferOperator(inst, opcode_index, operands,
1160 results, intermediates, &builder_);
1161 if (!offset) {
1162 inst->emitOpError("is not a supported TFLite op");
1163 }
1164 return offset;
1165 }
1166
1167 if (dialect == tf_dialect_) {
1168 if (auto ifOp = dyn_cast<mlir::TF::IfOp>(inst)) {
1169 return BuildIfOperator(ifOp, operands, results);
1170 } else if (auto whileOp = dyn_cast<mlir::TF::WhileOp>(inst)) {
1171 return BuildWhileOperator(whileOp, operands, results);
1172 }
1173
1174 CustomOptionsOffset custom_options;
1175
1176 // Ops in TF dialect can either be custom ops or flex ops.
1177 // The reason we go directly from TensorFlow dialect MLIR to tensorflow
1178 // node instead of going to TF table gen'd ops via generated code is that
1179 // we do not want to restrict custom and flex op conversion support to
1180 // only those TF ops that are currently registered in MLIR. The current
1181 // model is of an open op system.
1182 //
1183 // The following algorithm is followed:
1184 // if flex is enabled and the op is allowlisted as flex
1185 // we emit op as flex.
1186 // if custom is enabled
1187 // we emit the op as custom.
1188 auto node_def = GetTensorFlowNodeDef(inst);
1189 if (!node_def) {
1190 return llvm::None;
1191 }
1192
1193 std::string op_name = node_def->op();
1194 std::string op_desc = GetOpDescriptionForDebug(inst);
1195
1196 if (IsTFResourceOp(inst)) {
1197 resource_ops_[op_name].insert(op_desc);
1198 }
1199
1200 const bool is_allowed_flex_op =
1201 IsAllowlistedFlexOp(node_def->op()) ||
1202 ((select_user_tf_ops_.count(node_def->op()) != 0) &&
1203 (tensorflow::OpRegistry::Global()->LookUp(node_def->op()) != nullptr));
1204 // Flex op case
1205 // Eventually, the allowlist will go away and we will rely on some TF op
1206 // trait (e.g. No side effect) to determine if it is a supported "Flex"
1207 // op or not.
1208 if (is_allowed_flex_op && enabled_op_types_.contains(OpType::kSelectTf)) {
1209 // Construct ops as flex op encoding TensorFlow node definition
1210 // as custom options.
1211 // Flex ops are named with the kFlexOpNamePrefix prefix to the actual
1212 // TF op name.
1213 op_name = std::string(kFlexOpNamePrefix) + node_def->op();
1214 if (auto options = CreateFlexOpCustomOptions(*node_def, inst->getLoc())) {
1215 custom_options = *options;
1216 } else {
1217 return llvm::None;
1218 }
1219
1220 // Gather flex ops.
1221 flex_ops_[op_name].insert(op_desc);
1222 } else if (enabled_op_types_.contains(OpType::kCustomOp)) {
1223 // Generic case of custom ops - write using flex buffers since that
1224 // is the only custom options supported by TFLite today.
1225 op_name = node_def->op();
1226 if (auto options =
1227 CreateCustomOpCustomOptions(*node_def, inst->getLoc())) {
1228 custom_options = *options;
1229 } else {
1230 return llvm::None;
1231 }
1232
1233 // Gather custom ops.
1234 custom_ops_[op_name].insert(op_desc);
1235 } else {
      // Record the failed op in `failed_flex_ops_` or `failed_custom_ops_`.
1237 if (is_allowed_flex_op) {
1238 failed_flex_ops_[op_name].insert(op_desc);
1239 } else {
1240 failed_custom_ops_[op_name].insert(op_desc);
1241 }
1242 return inst->emitOpError("is neither a custom op nor a flex op"),
1243 llvm::None;
1244 }
1245
1246 uint32_t opcode_index =
1247 GetOpcodeIndex(op_name, tflite::BuiltinOperator_CUSTOM);
1248 auto inputs = builder_.CreateVector(operands);
1249 auto outputs = builder_.CreateVector(results);
1250
1251 return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
1252 tflite::BuiltinOptions_NONE,
1253 /*builtin_options=*/0,
1254 /*custom_options=*/custom_options,
1255 tflite::CustomOptionsFormat_FLEXBUFFERS,
1256 /*mutating_variable_inputs=*/0);
1257 }
1258
1259 return inst->emitOpError(
1260 "is not any of a builtin TFLite op, a flex TensorFlow op or a "
1261 "custom TensorFlow op"),
1262 llvm::None;
1263 }
1264
void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) {
1266 auto dict_attr = fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
1267 if (!dict_attr) return;
1268
1269 llvm::SmallVector<llvm::StringRef, 2> input_names;
1270 llvm::SmallVector<llvm::StringRef, 2> output_names;
1271 if (auto str = dict_attr.get("inputs").dyn_cast_or_null<mlir::StringAttr>()) {
1272 str.getValue().split(input_names, ',', /*MaxSplit=*/-1,
1273 /*KeepEmpty=*/false);
1274 if (input_names.size() != fn.getNumArguments()) {
1275 fn.emitWarning() << "invalid entry function specification";
1276 return;
1277 }
1278 for (auto it : llvm::enumerate(fn.getArguments())) {
1279 name_mapper_.InitOpName(it.value(), input_names[it.index()].trim());
1280 }
1281 *has_input_attr = true;
1282 }
1283
1284 if (auto str =
1285 dict_attr.get("outputs").dyn_cast_or_null<mlir::StringAttr>()) {
1286 str.getValue().split(output_names, ',', /*MaxSplit=*/-1,
1287 /*KeepEmpty=*/false);
1288 auto term = fn.back().getTerminator();
1289 if (output_names.size() != term->getNumOperands()) {
1290 fn.emitWarning() << "output names (" << output_names.size()
1291 << ") != terminator operands (" << term->getNumOperands()
1292 << ")";
1293 return;
1294 }
1295 for (const auto& it : llvm::enumerate(term->getOperands())) {
1296 name_mapper_.InitOpName(it.value(), output_names[it.index()].trim());
1297 }
1298 }
1299 }
1300
bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) {
1302 std::vector<int> operand_indices;
1303 if (!mlir::TFL::IsStatefulOp(op, &operand_indices)) return false;
1304 return absl::c_find(operand_indices, operand_index) != operand_indices.end();
1305 }
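
// For example, for a stateful op such as tfl.unidirectional_sequence_lstm,
// mlir::TFL::IsStatefulOp reports the recurrent/cell state operand indices,
// so BuildTensor marks the corresponding tensors with is_variable = true.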
1306
1307 BufferOffset<tflite::QuantizationParameters>
Translator::GetQuantizationForQuantStatsOpOutput(
    mlir::quant::StatisticsOp stats_op) {
1310 auto layer_stats = stats_op.layerStats().cast<mlir::DenseFPElementsAttr>();
1311 Optional<mlir::ElementsAttr> axis_stats = stats_op.axisStats();
1312 Optional<uint64_t> axis = stats_op.axis();
1313 std::vector<float> mins, maxs;
1314 mlir::DenseFPElementsAttr min_max_attr =
1315 axis_stats.hasValue()
1316 ? axis_stats.getValue().cast<mlir::DenseFPElementsAttr>()
1317 : layer_stats;
1318
1319 for (auto index_and_value : llvm::enumerate(min_max_attr.getFloatValues())) {
1320 const llvm::APFloat value = index_and_value.value();
1321 if (index_and_value.index() % 2 == 0) {
1322 mins.push_back(value.convertToFloat());
1323 } else {
1324 maxs.push_back(value.convertToFloat());
1325 }
1326 }
1327
1328 return tflite::CreateQuantizationParameters(
1329 builder_, builder_.CreateVector<float>(mins),
1330 builder_.CreateVector<float>(maxs), /*scale=*/0, /*zero_point=*/0,
1331 tflite::QuantizationDetails_NONE, /*details=*/0,
1332 /*quantized_dimension=*/axis.hasValue() ? axis.getValue() : 0);
1333 }
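
// For example, a "quant.stats" op with layerStats = dense<[-1.0, 1.0]>
// produces mins = {-1.0} and maxs = {1.0}; per-axis statistics interleave
// (min, max) pairs per slice, and quantized_dimension comes from the op's
// `axis` attribute (0 if absent).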
1334
Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
    const std::string& name, Region* region) {
1337 bool has_input_attr = false;
1338 if (auto fn = dyn_cast<FuncOp>(region->getParentOp())) {
1339 InitializeNamesFromAttribute(fn, &has_input_attr);
1340 }
1341 std::vector<BufferOffset<tflite::Tensor>> tensors;
1342 llvm::DenseMap<Value, int> tensor_index_map;
1343
1344 // Builds tensor and buffer for argument or operation result. Returns false
1345 // on failure.
1346 auto build_tensor_and_buffer = [&](Value value, const std::string& name) {
1347 // NoneType represents optional and may be skipped here.
1348 if (value.getType().isa<NoneType>()) {
1349 return true;
1350 }
1351
1352 tensor_index_map.insert({value, tensors.size()});
1353 tensor_index_map_[name] = tensors.size();
1354 Optional<BufferOffset<tflite::QuantizationParameters>> quant_parameters;
1355 if (value.hasOneUse()) {
1356 auto stats_op =
1357 llvm::dyn_cast<mlir::quant::StatisticsOp>(*value.user_begin());
1358 if (stats_op) {
1359 quant_parameters = GetQuantizationForQuantStatsOpOutput(stats_op);
1360 }
1361 }
1362 auto tensor_or =
1363 BuildTensor(value, name, buffers_.size(), quant_parameters);
1364 if (!tensor_or) return false;
1365 tensors.push_back(*tensor_or);
1366
1367 // TODO(ashwinm): For stateful tensors, check whether the Buffer also needs
1368 // to be left empty in addition to setting buffer_idx=0 in the Tensor. This
1369 // does not seem to affect runtime behavior for RNN/LSTM, but would be good
1370 // for reducing the memory footprint.
1371 if (auto* inst = value.getDefiningOp()) {
1372 auto buffer_or = BuildBuffer(inst);
1373 if (!buffer_or) return false;
1374 buffers_.push_back(*buffer_or);
1375 } else {
1376 buffers_.push_back(empty_buffer_);
1377 }
1378 return true;
1379 };
1380
1381 std::vector<BufferOffset<tflite::Operator>> operators;
1382 auto& bb = region->front();
1383
1384 // The main function's arguments are first passed to an `input` op, so they
1385 // don't have an associated tensor and buffer; build FlatBuffer tensors and
1386 // buffers for the arguments of the other functions.
1387 for (unsigned i = 0, e = bb.getNumArguments(); i < e; ++i) {
1388 mlir::BlockArgument arg = bb.getArgument(i);
1389 std::string name;
1390 if (has_input_attr) name = std::string(name_mapper_.GetUniqueName(arg));
1391 if (name.empty()) name = absl::StrCat("arg", i);
1392 if (!build_tensor_and_buffer(arg, name)) return llvm::None;
1393 }
1394
1395 bool failed_once = false;
1396 for (auto& inst : bb) {
1397 if (inst.hasTrait<mlir::OpTrait::IsTerminator>()) break;
1398 // For "quant.stats" op, it's used to store the quantization parameters info
1399 // and its output should be then replaced by its input value.
1400 if (auto quant_stats_op = llvm::dyn_cast<mlir::quant::StatisticsOp>(inst)) {
1401 continue;
1402 }
1403 std::vector<int32_t> intermediates;
1404 // Build intermediate tensors for tfl.lstm and insert these tensors into
1405 // flatbuffer.
1406 if (llvm::isa<mlir::TFL::LSTMOp, mlir::TFL::UnidirectionalSequenceLSTMOp>(
1407 inst)) {
1408 std::vector<std::string> intermediate_names = {
1409 "input_to_input_intermediate", "input_to_forget_intermediate",
1410 "input_to_cell_intermediate", "input_to_output_intermediate",
1411 "effective_hidden_scale_intermediate"};
1412 for (const std::string& intermediate : intermediate_names) {
1413 auto intermediate_attr = inst.getAttr(intermediate);
1414 if (auto attr = intermediate_attr.dyn_cast_or_null<mlir::TypeAttr>()) {
1415 Type qtype = attr.getValue();
1416 auto tensor_or = BuildTensorFromType(
1417 qtype, name_mapper_.GetUniqueName(intermediate).str());
1418 if (!tensor_or.hasValue()) {
1419 continue;
1420 } else {
1421 intermediates.push_back(tensors.size());
1422 tensors.push_back(tensor_or.getValue());
1423 }
1424 }
1425 }
1426 }
1427
1428 for (auto val : inst.getResults()) {
1429 std::string name = UniqueName(val);
1430 // For "tfl.numeric_verify" op, the name is used to find out the original
1431 // activation tensor rather than its own unique name in the visualization
1432 // or debugging tools.
1433 auto builtin_code = GetBuiltinOpCode(&inst);
1434 if (!builtin_code && dyn_cast<mlir::TFL::NumericVerifyOp>(&inst)) {
1435 // The first operand is the quantized activation, the target of this
1436 // NumericVerify op.
1437 auto quantized_op_val = inst.getOperands().front();
1438 name = "NumericVerify/" + UniqueName(quantized_op_val) + ":" +
1439 std::to_string(tensor_index_map[quantized_op_val]);
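        // This may yield a name like "NumericVerify/conv_out:3" when the
        // verified activation "conv_out" has tensor index 3 (illustrative).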
1440 }
1441 if (!build_tensor_and_buffer(val, name)) return llvm::None;
1442 }
1443
1444 // Skip constant ops as they don't represent a TFLite operator.
1445 if (IsConst(&inst)) continue;
1446
1447 // Fetch operand and result tensor indices.
1448 std::vector<int32_t> results;
1449 results.reserve(inst.getNumResults());
1450 for (auto result : inst.getResults()) {
1451 results.push_back(tensor_index_map.lookup(result));
1452 }
1453 Operation* real_inst = &inst;
1454 // CustomTfOp is just a wrapper around a TF op; we export the wrapped TF op,
1455 // not the wrapper, so we fetch the op from the region.
1456 if (auto custom_op = dyn_cast<mlir::TFL::CustomTfOp>(inst)) {
1457 // If the custom op has a non-empty region, use the first op in that
1458 // region; otherwise report an error and keep the wrapper op itself.
1459 if (!custom_op.body().empty()) {
1460 real_inst = &custom_op.body().front().front();
1461 } else {
1462 module_.emitError(
1463 "Invalid CustomTfOp: Custom TF op has an empty region.");
1464 }
1465 }
1466 std::vector<int32_t> operands;
1467 operands.reserve(real_inst->getNumOperands());
1468 for (auto operand : real_inst->getOperands()) {
1469 if (operand.getType().isa<NoneType>())
1470 operands.push_back(kTfLiteOptionalTensor);
1471 else if (auto stats_op =
1472 llvm::dyn_cast_or_null<mlir::quant::StatisticsOp>(
1473 operand.getDefiningOp()))
1474 operands.push_back(tensor_index_map.lookup(stats_op.arg()));
1475 else
1476 operands.push_back(tensor_index_map.lookup(operand));
1477 }
1478
1479 if (auto tfl_operator =
1480 BuildOperator(real_inst, operands, results, intermediates))
1481 operators.push_back(*tfl_operator);
1482 else
1483 failed_once = true;
1484 }
1485
1486 if (failed_once) return llvm::None;
1487
1488 // Get input and output tensor indices for the subgraph.
1489 std::vector<int32_t> inputs, outputs;
1490 for (auto arg : bb.getArguments()) {
1491 inputs.push_back(tensor_index_map[arg]);
1492 }
1493 for (auto result : bb.getTerminator()->getOperands()) {
1494 outputs.push_back(tensor_index_map[result]);
1495 }
1496
1497 return tflite::CreateSubGraph(
1498 builder_, builder_.CreateVector(tensors), builder_.CreateVector(inputs),
1499 builder_.CreateVector(outputs), builder_.CreateVector(operators),
1500 /*name=*/builder_.CreateString(name));
1501 }
1502
1503 BufferOffset<tflite::Metadata> Translator::BuildMetadata(StringRef name,
1504 StringRef content) {
1505 auto buffer_index = buffers_.size();
1506 auto buffer_data = builder_.CreateVector(
1507 reinterpret_cast<const uint8_t*>(content.data()), content.size());
1508 buffers_.push_back(tflite::CreateBuffer(builder_, buffer_data));
1509 return tflite::CreateMetadataDirect(builder_, name.data(), buffer_index);
1510 }
1511
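// Illustrative example (keys and values are made up): a module annotated with
//   tfl.metadata = {model_author = "example", model_version = "1"}
// produces one Metadata entry per key, each referencing a buffer that holds
// the corresponding string value.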
1512 Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
1513 Translator::CreateMetadataVector() {
1514 auto dict_attr = module_->getAttrOfType<mlir::DictionaryAttr>("tfl.metadata");
1515 std::vector<BufferOffset<tflite::Metadata>> metadata;
1516 if (dict_attr) {
1517 for (const auto& named_attr : dict_attr) {
1518 StringRef name = named_attr.first;
1519 mlir::Attribute attr = named_attr.second;
1520 if (auto content = attr.dyn_cast<StringAttr>()) {
1521 metadata.push_back(BuildMetadata(name, content.getValue()));
1522 } else {
1523 module_.emitError(
1524 "all values in tfl.metadata's dictionary key-value pairs should be "
1525 "string attributes");
1526 return llvm::None;
1527 }
1528 }
1529 }
1530 // The runtime version string is generated after we update the op
1531 // versions. Here we put a 16-byte dummy string as a placeholder; we choose
1532 // 16 bytes because it is the buffer alignment in flatbuffers, so no space
1533 // is wasted if the actual string is shorter than 16 bytes.
1534 metadata.push_back(
1535 BuildMetadata("min_runtime_version", std::string(16, '\0')));
1536 return builder_.CreateVector(metadata);
1537 }
1538
1539 // Helper method that returns the list of strings in the StringAttr
1540 // identified by 'attr_key', where values are separated by commas.
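// For example, a dictionary entry {inputs = "a,b,c"} queried with
// attr_key = "inputs" returns {"a", "b", "c"}; empty entries are dropped.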
1541 llvm::SmallVector<llvm::StringRef, 2> GetStringsFromAttrWithSeparator(
1542 mlir::DictionaryAttr attr, const std::string& attr_key) {
1543 llvm::SmallVector<llvm::StringRef, 2> result;
1544 if (auto str = attr.get(attr_key).dyn_cast_or_null<mlir::StringAttr>()) {
1545 str.getValue().split(result, ',', /*MaxSplit=*/-1,
1546 /*KeepEmpty=*/false);
1547 }
1548 return result;
1549 }
1550
1551 // Helper method that returns, for each dictionary attribute in `dict_attrs`,
1552 // the first string of the ArrayAttr identified by 'attr_name'.
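// For example (illustrative), an argument attribute such as
//   {tf_saved_model.index_path = ["serving_default_x"]}
// contributes "serving_default_x" to the result when attr_name is
// "tf_saved_model.index_path".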
1553 std::vector<std::string> GetStringsFromDictionaryAttr(
1554 const llvm::SmallVector<mlir::DictionaryAttr, 4>& dict_attrs,
1555 const std::string& attr_name) {
1556 std::vector<std::string> result;
1557 for (const auto& arg_attr : dict_attrs) {
1558 if (!arg_attr) continue;
1559
1560 auto attrs = arg_attr.getValue();
1561 for (const auto attr : attrs) {
1562 if (attr.first.str() == attr_name) {
1563 auto array_attr = attr.second.dyn_cast_or_null<mlir::ArrayAttr>();
1564 if (!array_attr || array_attr.empty()) continue;
1565 auto string_attr = array_attr[0].dyn_cast_or_null<mlir::StringAttr>();
1566 if (!string_attr) continue;
1567 result.push_back(string_attr.getValue().str());
1568 }
1569 }
1570 }
1571 return result;
1572 }
1573
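// Builds SignatureDef data for the exported model by pairing the SavedModel
// index paths attached to the main function's arguments/results with the
// tensor names from the "tf.entry_function" attribute. Returns an empty vector
// when the module was not converted from a SavedModel.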
1574 std::vector<SignatureDefData> BuildSignaturedef(
1575 FuncOp main_op, const std::string& saved_model_tag) {
1576 static const char kSignatureDefIndexPath[] = "tf_saved_model.index_path";
1577 static const char kEntryFunctionAttributes[] = "tf.entry_function";
1578
1579 // Fetch inputs and outputs from the signature.
1580 llvm::SmallVector<mlir::DictionaryAttr, 4> arg_attrs, res_attrs;
1581 main_op.getAllArgAttrs(arg_attrs);
1582 main_op.getAllResultAttrs(res_attrs);
1583 std::vector<std::string> sig_def_inputs =
1584 GetStringsFromDictionaryAttr(arg_attrs, kSignatureDefIndexPath);
1585 std::vector<std::string> sig_def_outputs =
1586 GetStringsFromDictionaryAttr(res_attrs, kSignatureDefIndexPath);
1587
1588 // If no SavedModel signature is defined, return an empty list. This can
1589 // happen when the model being converted did not come from a SavedModel.
1590 if (sig_def_inputs.empty() || sig_def_outputs.empty()) return {};
1591
1592 // Fetch function inputs and outputs tensor names.
1593 auto dict_attr =
1594 main_op->getAttrOfType<mlir::DictionaryAttr>(kEntryFunctionAttributes);
1595 if (!dict_attr) return {};
1596
1597 // Get input and output tensor names from the attribute.
1598 llvm::SmallVector<llvm::StringRef, 2> input_names =
1599 GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"inputs");
1600 llvm::SmallVector<llvm::StringRef, 2> output_names =
1601 GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"outputs");
1602
1603 // Verify the number of input names matches the number of arguments.
1604 if (input_names.size() != main_op.getNumArguments()) {
1605 main_op.emitWarning() << "invalid entry function specification";
1606 return {};
1607 }
1608 // Verify the number of output names matches the number of terminator operands.
1609 auto term = main_op.back().getTerminator();
1610 if (output_names.size() != term->getNumOperands()) {
1611 main_op.emitWarning() << "output names (" << output_names.size()
1612 << ") != terminator operands ("
1613 << term->getNumOperands() << ")";
1614 return {};
1615 }
1616 // Verify that the number of input and output tensors matches the size of
1617 // the corresponding lists in the signature def.
1618 if (input_names.size() != sig_def_inputs.size() ||
1619 output_names.size() != sig_def_outputs.size()) {
1620 main_op.emitWarning(
1621 "Mismatch between signature def inputs/outputs and main function "
1622 "arguments.");
1623 return {};
1624 }
1625 // Exported method name.
1626 auto exported_name =
1627 main_op->getAttrOfType<mlir::ArrayAttr>("tf_saved_model.exported_names");
1628 if (exported_name.empty()) {
1629 main_op.emitError("Empty exported names for main Function");
1630 return {};
1631 }
1632 // Fill the SignatureDefData container. We create a vector of size 1 since
1633 // TFLite currently supports only one SignatureDef.
1634 std::vector<SignatureDefData> result(1);
1635 for (int i = 0; i < input_names.size(); ++i) {
1636 result[0].inputs[sig_def_inputs[i]] = input_names[i].str();
1637 }
1638 for (int i = 0; i < output_names.size(); ++i) {
1639 result[0].outputs[sig_def_outputs[i]] = output_names[i].str();
1640 }
1641 if (auto name_attr = exported_name[0].dyn_cast_or_null<StringAttr>())
1642 result[0].method_name = name_attr.getValue().str();
1643 result[0].signature_def_key = saved_model_tag;
1644 return result;
1645 }
1646
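// Converts a map of {signature input/output name -> model tensor name} into
// TensorMap entries, resolving each model tensor name to its flatbuffer
// tensor index via tensor_index_map_.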
1647 std::vector<BufferOffset<tflite::TensorMap>> Translator::GetList(
1648 const std::map<std::string, std::string>& items) {
1649 std::vector<BufferOffset<tflite::TensorMap>> result;
1650 for (const auto& item : items) {
1651 auto name_buf = builder_.CreateString(item.first);
1652 tflite::TensorMapBuilder tensor_map_builder(builder_);
1653 tensor_map_builder.add_name(name_buf);
1654 tensor_map_builder.add_tensor_index(tensor_index_map_[item.second]);
1655 result.push_back(tensor_map_builder.Finish());
1656 }
1657 return result;
1658 }
1659
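// Serializes the collected SignatureDefData entries into flatbuffer
// SignatureDef tables, filling in their inputs, outputs, method name and
// signature key.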
1660 Optional<VectorBufferOffset<BufferOffset<tflite::SignatureDef>>>
1661 Translator::CreateSignatureDefs(
1662 const std::vector<SignatureDefData>& signature_defs) {
1663 std::vector<BufferOffset<tflite::SignatureDef>> signature_defs_buffer;
1664 for (const auto& signature_def_data : signature_defs) {
1665 auto inputs = GetList(signature_def_data.inputs);
1666 auto outputs = GetList(signature_def_data.outputs);
1667 auto inputs_buf = builder_.CreateVector(inputs);
1668 auto outputs_buf = builder_.CreateVector(outputs);
1669 auto method_name_buf =
1670 builder_.CreateString(signature_def_data.method_name);
1671 auto signature_def_key_buf =
1672 builder_.CreateString(signature_def_data.signature_def_key);
1673 tflite::SignatureDefBuilder sig_def_builder(builder_);
1674 sig_def_builder.add_inputs(inputs_buf);
1675 sig_def_builder.add_outputs(outputs_buf);
1676 sig_def_builder.add_method_name(method_name_buf);
1677 sig_def_builder.add_key(signature_def_key_buf);
1678 signature_defs_buffer.push_back(sig_def_builder.Finish());
1679 }
1680
1681 return builder_.CreateVector(signature_defs_buffer);
1682 }
1683
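// Ensures the module has an entry function named "main": if none exists, the
// single function carrying a non-empty "tf.entry_function" attribute is
// renamed to "main". Returns false when no unique candidate is found.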
1684 bool UpdateEntryFunction(ModuleOp module) {
1685 if (module.lookupSymbol<FuncOp>("main") != nullptr) {
1686 // We already have an entry function.
1687 return true;
1688 }
1689
1690 int entry_func_count = 0;
1691 FuncOp entry_func = nullptr;
1692 for (auto fn : module.getOps<FuncOp>()) {
1693 auto attrs = fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
1694 if (attrs && !attrs.empty()) {
1695 entry_func_count++;
1696 entry_func = fn;
1697 }
1698 }
1699
1700 // There should be one and only one entry function.
1701 if (entry_func_count != 1) return false;
1702
1703 // Update the entry func to main.
1704 entry_func.setName("main");
1705 return true;
1706 }
1707
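// Entry point of the translator: normalizes the entry function, validates the
// module, and returns the serialized flatbuffer string, or llvm::None on
// failure.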
1708 Optional<std::string> Translator::Translate(
1709 ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
1710 bool emit_custom_ops,
1711 const std::unordered_set<std::string>& select_user_tf_ops,
1712 const std::unordered_set<std::string>& tags,
1713 OpOrArgNameMapper* op_or_arg_name_mapper) {
1714 OpOrArgLocNameMapper default_op_or_arg_name_mapper;
1715 if (!op_or_arg_name_mapper)
1716 op_or_arg_name_mapper = &default_op_or_arg_name_mapper;
1717 if (!UpdateEntryFunction(module)) return llvm::None;
1718 if (!IsValidTFLiteMlirModule(module)) return llvm::None;
1719 Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops,
1720 emit_custom_ops, select_user_tf_ops, tags,
1721 op_or_arg_name_mapper);
1722 return translator.TranslateInternal();
1723 }
1724
1725 Optional<std::string> Translator::TranslateInternal() {
1726 // A list of named regions in the module, with the main function first. The
1727 // main function must come first because the first subgraph in the model is
1728 // the model's entry point.
1729 std::vector<std::pair<std::string, Region*>> named_regions;
1730 named_regions.reserve(std::distance(module_.begin(), module_.end()));
1731
1732 int subgraph_idx = 0;
1733 FuncOp main_fn = module_.lookupSymbol<FuncOp>("main");
1734 subgraph_index_map_[main_fn.getName().str()] = subgraph_idx++;
1735 named_regions.emplace_back("main", &main_fn.getBody());
1736 // Walk over the module to collect the remaining functions (e.g. those referenced by while ops).
1737 module_.walk([&](FuncOp fn) {
1738 if (fn != main_fn) {
1739 subgraph_index_map_[fn.getName().str()] = subgraph_idx++;
1740 named_regions.emplace_back(fn.getName().str(), &fn.getBody());
1741 }
1742 });
1743
1744 // Build subgraph for each of the named regions.
1745 std::vector<BufferOffset<tflite::SubGraph>> subgraphs;
1746 subgraphs.reserve(named_regions.size());
1747 int first_failed_func = -1;
1748 for (auto it : llvm::enumerate(named_regions)) {
1749 auto subgraph_or = BuildSubGraph(it.value().first, it.value().second);
1750 if (!subgraph_or) {
1751 if (first_failed_func == -1)
1752 // Record the index of the first region that cannot be converted.
1753 // Keep looping through all subgraphs in the module to make sure that
1754 // we collect the list of missing ops from the entire module.
1755 first_failed_func = it.index();
1756 } else {
1757 subgraphs.push_back(*subgraph_or);
1758 }
1759 }
1760
1761 if (!resource_ops_.empty()) {
1762 std::string resource_ops_summary =
1763 GetOpsSummary(resource_ops_, /*summary_title=*/"Resource");
1764 LOG(WARNING) << "Graph contains the following resource op(s) that use the "
1765 "resource type. Currently, the "
1766 "resource type is not natively supported in TFLite. Please "
1767 "consider avoiding the resource type if there are issues "
1768 "with either the TFLite converter or the TFLite runtime:\n"
1769 << resource_ops_summary;
1770 }
1771
1772 if (!flex_ops_.empty()) {
1773 std::string flex_ops_summary =
1774 GetOpsSummary(flex_ops_, /*summary_title=*/"Flex");
1775 LOG(WARNING) << "The TFLite interpreter needs to link the Flex delegate in "
1776 "order to run the model since it contains the following flex "
1777 "op(s):\n"
1778 << flex_ops_summary;
1779 }
1780
1781 if (!custom_ops_.empty()) {
1782 std::string custom_ops_summary =
1783 GetOpsSummary(custom_ops_, /*summary_title=*/"Custom");
1784 LOG(WARNING) << "The following operation(s) need TFLite custom op "
1785 "implementation(s):\n"
1786 << custom_ops_summary;
1787 }
1788
1789 if (first_failed_func != -1) {
1790 std::string failed_flex_ops_summary =
1791 GetOpsSummary(failed_flex_ops_, /*summary_title=*/"TF Select");
1792 std::string failed_custom_ops_summary =
1793 GetOpsSummary(failed_custom_ops_, /*summary_title=*/"Custom");
1794 std::string err;
1795 if (!failed_flex_ops_.empty())
1796 err +=
1797 "\nSome ops are not supported by the native TFLite runtime, you can "
1798 "enable TF kernels fallback using TF Select. See instructions: "
1799 "https://www.tensorflow.org/lite/guide/ops_select \n" +
1800 failed_flex_ops_summary + "\n";
1801 if (!failed_custom_ops_.empty())
1802 err +=
1803 "\nSome ops in the model are custom ops, "
1804 "See instructions to implement "
1805 "custom ops: https://www.tensorflow.org/lite/guide/ops_custom \n" +
1806 failed_custom_ops_summary + "\n";
1807
1808 auto& failed_region = named_regions[first_failed_func];
1809 return failed_region.second->getParentOp()->emitError()
1810 << "failed while converting: '" << failed_region.first
1811 << "': " << err,
1812 llvm::None;
1813 }
1814
1815 std::string model_description;
1816 if (auto attr = module_->getAttrOfType<StringAttr>("tfl.description")) {
1817 model_description = attr.getValue().str();
1818 } else {
1819 model_description = "MLIR Converted.";
1820 }
1821
1822 // Build the model and finish the model building process.
1823 auto description = builder_.CreateString(model_description.data());
1824 VectorBufferOffset<int32_t> metadata_buffer = 0; // Deprecated
1825 auto metadata = CreateMetadataVector();
1826 if (!metadata) return llvm::None;
1827
1828 // Build the SignatureDef. There is only one entry point, the 'main'
1829 // function, so we build only one signature def.
1830 auto main_fn_signature_def = BuildSignaturedef(
1831 main_fn, saved_model_tags_.empty() ? "" : *saved_model_tags_.begin());
1832 auto signature_defs = CreateSignatureDefs(main_fn_signature_def);
1833
1834 auto model = tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION,
1835 builder_.CreateVector(opcodes_),
1836 builder_.CreateVector(subgraphs),
1837 description, builder_.CreateVector(buffers_),
1838 metadata_buffer, *metadata, *signature_defs);
1839 tflite::FinishModelBuffer(builder_, model);
1840 tflite::UpdateOpVersion(builder_.GetBufferPointer());
1841 tflite::UpdateMinimumRuntimeVersionForModel(builder_.GetBufferPointer());
1842
1843 // Return serialized string for the built FlatBuffer.
1844 return std::string(reinterpret_cast<const char*>(builder_.GetBufferPointer()),
1845 builder_.GetSize());
1846 }
1847
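// Converts a TFL sparsity attribute into flatbuffer SparsityParameters. For
// each sparse (CSR) dimension, the segment and index arrays are stored with
// the narrowest element type that fits their maximum value (uint8, uint16 or
// int32) to keep the serialized model small.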
1848 BufferOffset<tflite::SparsityParameters> Translator::BuildSparsityParameters(
1849 const mlir::TFL::SparsityParameterAttr& s_attr) {
1850 const int dim_size = s_attr.dim_metadata().size();
1851 std::vector<flatbuffers::Offset<tflite::DimensionMetadata>> fb_dim_metadata(
1852 dim_size);
1853 for (int i = 0; i < dim_size; i++) {
1854 const auto dim_metadata =
1855 s_attr.dim_metadata()[i].dyn_cast<mlir::TFL::DimensionMetadataAttr>();
1856 if (dim_metadata.format().getValue() == "DENSE") {
1857 fb_dim_metadata[i] =
1858 tflite::CreateDimensionMetadata(builder_, tflite::DimensionType_DENSE,
1859 dim_metadata.dense_size().getInt());
1860
1861 } else {
1862 auto segments = dim_metadata.segments();
1863 std::vector<int> vector_segments(segments.size(), 0);
1864 for (int j = 0, end = segments.size(); j < end; j++) {
1865 vector_segments[j] = segments[j].dyn_cast<mlir::IntegerAttr>().getInt();
1866 }
1867 tflite::SparseIndexVector segments_type;
1868 BufferOffset<void> array_segments;
1869 // The segment array is sorted.
1870 // TODO(b/147449640): Clean this up with util functions.
1871 int max_of_segments = vector_segments[segments.size() - 1];
1872 if (max_of_segments <= UINT8_MAX) {
1873 segments_type = tflite::SparseIndexVector_Uint8Vector;
1874 std::vector<uint8_t> uint8_vector(vector_segments.begin(),
1875 vector_segments.end());
1876 array_segments = tflite::CreateUint8Vector(
1877 builder_, builder_.CreateVector(uint8_vector))
1878 .Union();
1879 } else if (max_of_segments <= UINT16_MAX) {
1880 segments_type = tflite::SparseIndexVector_Uint16Vector;
1881 std::vector<uint16_t> uint16_vector(vector_segments.begin(),
1882 vector_segments.end());
1883 array_segments = tflite::CreateUint16Vector(
1884 builder_, builder_.CreateVector(uint16_vector))
1885 .Union();
1886 } else {
1887 segments_type = tflite::SparseIndexVector_Int32Vector;
1888 array_segments = tflite::CreateInt32Vector(
1889 builder_, builder_.CreateVector(vector_segments))
1890 .Union();
1891 }
1892
1893 auto indices = dim_metadata.indices();
1894 std::vector<int> vector_indices(indices.size(), 0);
1895 int max_of_indices = 0;
1896 for (int j = 0, end = indices.size(); j < end; j++) {
1897 vector_indices[j] = indices[j].dyn_cast<mlir::IntegerAttr>().getInt();
1898 if (vector_indices[j] > max_of_indices) {
1899 max_of_indices = vector_indices[j];
1900 }
1901 }
1902 tflite::SparseIndexVector indices_type;
1903 BufferOffset<void> array_indices;
1904 if (max_of_indices <= UINT8_MAX) {
1905 indices_type = tflite::SparseIndexVector_Uint8Vector;
1906 std::vector<uint8_t> uint8_vector(vector_indices.begin(),
1907 vector_indices.end());
1908 array_indices = tflite::CreateUint8Vector(
1909 builder_, builder_.CreateVector(uint8_vector))
1910 .Union();
1911 } else if (max_of_indices <= UINT16_MAX) {
1912 indices_type = tflite::SparseIndexVector_Uint16Vector;
1913 std::vector<uint16_t> uint16_vector(vector_indices.begin(),
1914 vector_indices.end());
1915 array_indices = tflite::CreateUint16Vector(
1916 builder_, builder_.CreateVector(uint16_vector))
1917 .Union();
1918 } else {
1919 indices_type = tflite::SparseIndexVector_Int32Vector;
1920 array_indices = tflite::CreateInt32Vector(
1921 builder_, builder_.CreateVector(vector_indices))
1922 .Union();
1923 }
1924
1925 fb_dim_metadata[i] = tflite::CreateDimensionMetadata(
1926 builder_, tflite::DimensionType_SPARSE_CSR, 0, segments_type,
1927 array_segments, indices_type, array_indices);
1928 }
1929 }
1930
1931 std::vector<int> traversal_order(dim_size);
1932 for (int i = 0; i < dim_size; i++) {
1933 traversal_order[i] =
1934 s_attr.traversal_order()[i].dyn_cast<mlir::IntegerAttr>().getInt();
1935 }
1936 const int block_map_size = s_attr.block_map().size();
1937 std::vector<int> block_map(block_map_size);
1938 for (int i = 0; i < block_map_size; i++) {
1939 block_map[i] = s_attr.block_map()[i].dyn_cast<mlir::IntegerAttr>().getInt();
1940 }
1941
1942 return tflite::CreateSparsityParameters(
1943 builder_, builder_.CreateVector(traversal_order),
1944 builder_.CreateVector(block_map), builder_.CreateVector(fb_dim_metadata));
1945 }
1946
1947 } // namespace
1948
1949 namespace tflite {
1950 // TODO(hinsu): Support all valid MLIR modules in TFLite dialect by supporting
1951 // the following:
1952 //
1953 // * Quantization
1954 // * Ops with variable tensors
1955 //
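// Translates the given MLIR module into a serialized TFLite flatbuffer and
// writes it into `serialized_flatbuffer`. Returns true on success.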
1956 bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,
1957 const FlatbufferExportOptions& options,
1958 std::string* serialized_flatbuffer) {
1959 auto maybe_translated = Translator::Translate(
1960 module, options.emit_builtin_tflite_ops, options.emit_select_tf_ops,
1961 options.emit_custom_ops, options.select_user_tf_ops,
1962 options.saved_model_tags, options.op_or_arg_name_mapper);
1963 if (!maybe_translated) return false;
1964 *serialized_flatbuffer = std::move(*maybe_translated);
1965 return true;
1966 }
1967
1968 } // namespace tflite
1969