1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
17
18 #include <stddef.h>
19 #include <stdlib.h>
20
21 #include <cstdint>
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include <vector>
26
27 #include "absl/base/attributes.h"
28 #include "absl/container/flat_hash_map.h"
29 #include "absl/container/flat_hash_set.h"
30 #include "absl/strings/match.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_format.h"
33 #include "absl/strings/str_join.h"
34 #include "absl/strings/string_view.h"
35 #include "flatbuffers/flatbuffers.h" // from @flatbuffers
36 #include "flatbuffers/flexbuffers.h" // from @flatbuffers
37 #include "llvm/ADT/ArrayRef.h"
38 #include "llvm/ADT/DenseMap.h"
39 #include "llvm/ADT/None.h"
40 #include "llvm/ADT/Optional.h"
41 #include "llvm/ADT/STLExtras.h"
42 #include "llvm/ADT/StringRef.h"
43 #include "llvm/Support/Casting.h"
44 #include "llvm/Support/CommandLine.h"
45 #include "llvm/Support/FormatVariadic.h"
46 #include "llvm/Support/ToolOutputFile.h"
47 #include "llvm/Support/raw_ostream.h"
48 #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
49 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
50 #include "mlir/IR/Attributes.h" // from @llvm-project
51 #include "mlir/IR/Builders.h" // from @llvm-project
52 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
53 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
54 #include "mlir/IR/Location.h" // from @llvm-project
55 #include "mlir/IR/MLIRContext.h" // from @llvm-project
56 #include "mlir/IR/Operation.h" // from @llvm-project
57 #include "mlir/IR/Types.h" // from @llvm-project
58 #include "mlir/IR/Value.h" // from @llvm-project
59 #include "mlir/Support/LogicalResult.h" // from @llvm-project
60 #include "mlir/Translation.h" // from @llvm-project
61 #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
62 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
63 #include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h"
64 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
65 #include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h"
66 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
67 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
68 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
69 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
70 #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
71 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
72 #include "tensorflow/compiler/xla/statusor.h"
73 #include "tensorflow/core/framework/attr_value.pb.h"
74 #include "tensorflow/core/framework/node_def.pb.h"
75 #include "tensorflow/core/platform/errors.h"
76 #include "tensorflow/core/platform/logging.h"
77 #include "tensorflow/core/platform/status.h"
78 #include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h"
79 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
80 #include "tensorflow/lite/schema/schema_conversion_utils.h"
81 #include "tensorflow/lite/schema/schema_generated.h"
82 #include "tensorflow/lite/string_util.h"
83 #include "tensorflow/lite/tools/versioning/op_version.h"
84 #include "tensorflow/lite/tools/versioning/runtime_version.h"
85 #include "tensorflow/lite/version.h"
86
87 using llvm::dyn_cast;
88 using llvm::formatv;
89 using llvm::isa;
90 using llvm::Optional;
91 using llvm::StringRef;
92 using llvm::Twine;
93 using mlir::Dialect;
94 using mlir::ElementsAttr;
95 using mlir::FuncOp;
96 using mlir::MLIRContext;
97 using mlir::ModuleOp;
98 using mlir::NoneType;
99 using mlir::Operation;
100 using mlir::Region;
101 using mlir::StringAttr;
102 using mlir::TensorType;
103 using mlir::Type;
104 using mlir::UnknownLoc;
105 using mlir::Value;
106 using mlir::WalkResult;
107 using tensorflow::OpOrArgLocNameMapper;
108 using tensorflow::OpOrArgNameMapper;
109 using tensorflow::Status;
110 using tflite::flex::IsAllowlistedFlexOp;
111 using xla::StatusOr;
112
113 template <typename T>
114 using BufferOffset = flatbuffers::Offset<T>;
115
116 template <typename T>
117 using VectorBufferOffset = flatbuffers::Offset<flatbuffers::Vector<T>>;
118
119 using CustomOptionsOffset = VectorBufferOffset<uint8_t>;
120
121 namespace error = tensorflow::error;
122 namespace tfl = mlir::TFL;
123
124 ABSL_CONST_INIT const absl::string_view kFlexOpNamePrefix = "Flex";
125
126 // Use initial buffer size in flatbuffer builder to be same as the initial size
127 // used by the TOCO export. (It does not explain rationale for this choice.)
128 constexpr size_t kInitialBufferSize = 10240;
129
130 // Set `isSigned` to false if the `type` is an 8-bit unsigned integer type.
131 // Since tflite doesn't support unsigned for other types, returns error if
132 // `isSigned` is set to false for other types.
GetTFLiteType(Type type,bool is_signed=true)133 static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
134 bool is_signed = true) {
135 if (!is_signed && type.isSignlessInteger(8)) {
136 return tflite::TensorType_UINT8;
137 }
138 if (!is_signed) {
139 return Status(error::INVALID_ARGUMENT,
140 "'isSigned' can only be set for 8-bits integer type");
141 }
142
143 if (type.isF32()) {
144 return tflite::TensorType_FLOAT32;
145 } else if (type.isF16()) {
146 return tflite::TensorType_FLOAT16;
147 } else if (type.isF64()) {
148 return tflite::TensorType_FLOAT64;
149 } else if (type.isa<mlir::TF::StringType>()) {
150 return tflite::TensorType_STRING;
151 } else if (type.isa<mlir::TF::Quint8Type>()) {
152 return tflite::TensorType_UINT8;
153 } else if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
154 auto ftype = complex_type.getElementType();
155 if (ftype.isF32()) {
156 return tflite::TensorType_COMPLEX64;
157 }
158 if (ftype.isF64()) {
159 return tflite::TensorType_COMPLEX128;
160 }
161 return Status(error::INVALID_ARGUMENT, "Unsupported type");
162 } else if (auto itype = type.dyn_cast<mlir::IntegerType>()) {
163 switch (itype.getWidth()) {
164 case 1:
165 return tflite::TensorType_BOOL;
166 case 8:
167 return itype.isUnsigned() ? tflite::TensorType_UINT8
168 : tflite::TensorType_INT8;
169 case 16:
170 return tflite::TensorType_INT16;
171 case 32:
172 return itype.isUnsigned() ? tflite::TensorType_UINT32
173 : tflite::TensorType_INT32;
174 case 64:
175 return itype.isUnsigned() ? tflite::TensorType_UINT64
176 : tflite::TensorType_INT64;
177 }
178 } else if (auto q_uniform_type =
179 type.dyn_cast<mlir::quant::UniformQuantizedType>()) {
180 return GetTFLiteType(q_uniform_type.getStorageType(),
181 q_uniform_type.isSigned());
182 } else if (auto q_peraxis_type =
183 type.dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
184 return GetTFLiteType(q_peraxis_type.getStorageType(),
185 q_peraxis_type.isSigned());
186 } else if (auto q_calibrated_type =
187 type.dyn_cast<mlir::quant::CalibratedQuantizedType>()) {
188 return GetTFLiteType(q_calibrated_type.getExpressedType());
189 } else if (type.isa<mlir::TF::ResourceType>()) {
190 return tflite::TensorType_RESOURCE;
191 } else if (type.isa<mlir::TF::VariantType>()) {
192 return tflite::TensorType_VARIANT;
193 }
194 // TFLite export fills FLOAT32 for unknown data types. Returning an error
195 // for now for safety and this could be revisited when required.
196 return Status(error::INVALID_ARGUMENT, "Unsupported type");
197 }
198
IsConst(Operation * op)199 static bool IsConst(Operation* op) {
200 return isa<mlir::ConstantOp, mlir::TF::ConstOp, tfl::ConstOp, tfl::QConstOp,
201 tfl::SparseConstOp, tfl::SparseQConstOp>(op);
202 }
203
IsTFResourceOp(Operation * op)204 static bool IsTFResourceOp(Operation* op) {
205 for (const auto& operand : op->getOperands()) {
206 auto elementType = getElementTypeOrSelf(operand.getType());
207 if (elementType.isa<mlir::TF::ResourceType>()) {
208 return true;
209 }
210 }
211 for (const auto& result : op->getResults()) {
212 auto elementType = getElementTypeOrSelf(result.getType());
213 if (elementType.isa<mlir::TF::ResourceType>()) {
214 return true;
215 }
216 }
217 return false;
218 }
219
220 // Create description of operation that could not be converted.
GetOpDescriptionForDebug(Operation * inst)221 static std::string GetOpDescriptionForDebug(Operation* inst) {
222 const int kLargeElementsAttr = 16;
223 std::string op_str;
224 llvm::raw_string_ostream os(op_str);
225 inst->getName().print(os);
226 os << "(";
227 if (!inst->getOperandTypes().empty()) {
228 bool first = true;
229 for (Type operand_type : inst->getOperandTypes()) {
230 os << (!first ? ", " : "");
231 first = false;
232 os << operand_type;
233 }
234 }
235 os << ") -> (";
236 if (!inst->getResultTypes().empty()) {
237 bool first = true;
238 for (Type result_type : inst->getResultTypes()) {
239 os << (!first ? ", " : "");
240 first = false;
241 os << result_type;
242 }
243 }
244 os << ")";
245 // Print out attributes except for large elementsattributes (which should
246 // rarely be the cause why the legalization didn't happen).
247 if (!inst->getAttrDictionary().empty()) {
248 os << " : {";
249 bool first = true;
250 for (auto& named_attr : inst->getAttrDictionary()) {
251 os << (!first ? ", " : "");
252 first = false;
253 named_attr.first.print(os);
254 os << " = ";
255 if (auto element_attr = named_attr.second.dyn_cast<ElementsAttr>()) {
256 if (element_attr.getNumElements() <= kLargeElementsAttr) {
257 element_attr.print(os);
258 } else {
259 os << "<large>";
260 }
261 } else {
262 named_attr.second.print(os);
263 }
264 }
265 os << "}";
266 }
267 return os.str();
268 }
269
270 // Create a summary with the given information regarding op names and
271 // descriptions.
GetOpsSummary(const std::map<std::string,std::set<std::string>> & ops,const std::string & summary_title)272 static std::string GetOpsSummary(
273 const std::map<std::string, std::set<std::string>>& ops,
274 const std::string& summary_title) {
275 std::string op_str;
276 llvm::raw_string_ostream os(op_str);
277
278 std::vector<std::string> keys;
279 keys.reserve(ops.size());
280
281 std::vector<std::string> values;
282 values.reserve(ops.size());
283
284 for (auto const& op_name_and_details : ops) {
285 keys.push_back(op_name_and_details.first);
286 for (auto const& op_detail : op_name_and_details.second) {
287 values.push_back(op_detail);
288 }
289 }
290
291 os << summary_title << " ops: " << absl::StrJoin(keys, ", ") << "\n";
292 os << "Details:\n\t" << absl::StrJoin(values, "\n\t");
293
294 return os.str();
295 }
296
297 template <typename T>
HasValidTFLiteType(Value value,T & error_handler)298 static bool HasValidTFLiteType(Value value, T& error_handler) {
299 // None type is allowed to represent unspecified operands.
300 if (value.getType().isa<NoneType>()) return true;
301
302 auto type = value.getType().dyn_cast<TensorType>();
303 if (!type) {
304 if (auto op = value.getDefiningOp()) {
305 error_handler.emitError()
306 << '\'' << op << "' should produce value of tensor type instead of "
307 << value.getType();
308 return false;
309 }
310 error_handler.emitError("expected tensor type, got ") << value.getType();
311 return false;
312 }
313
314 Type element_type = type.getElementType();
315 auto status = GetTFLiteType(element_type);
316 if (!status.ok()) {
317 return error_handler.emitError(
318 formatv("Failed to convert element type '{0}': {1}",
319 element_type, status.status().error_message())),
320 false;
321 }
322 return true;
323 }
324
325 // Returns true if the module holds all the invariants expected by the
326 // Translator class.
327 // TODO(hinsu): Now that translation is done by making a single pass over the
328 // MLIR module, consider inlining these validation checks at the place where
329 // these invariants are assumed instead of checking upfront.
IsValidTFLiteMlirModule(ModuleOp module)330 static bool IsValidTFLiteMlirModule(ModuleOp module) {
331 MLIRContext* context = module.getContext();
332
333 // Verify that module has a function named main.
334 FuncOp main_fn = module.lookupSymbol<FuncOp>("main");
335 if (!main_fn) {
336 int entry_func_count = 0;
337 for (auto fn : module.getOps<FuncOp>()) {
338 auto attrs = fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
339 if (attrs && !attrs.empty()) {
340 ++entry_func_count;
341 }
342 }
343
344 // Verify that module has a least one enrty function.
345 if (entry_func_count == 0) {
346 return emitError(UnknownLoc::get(context),
347 "should have a least one entry function"),
348 false;
349 }
350 }
351
352 for (auto fn : module.getOps<FuncOp>()) {
353 if (!llvm::hasSingleElement(fn)) {
354 return fn.emitError("should have exactly one basic block"), false;
355 }
356 auto& bb = fn.front();
357
358 for (auto arg : bb.getArguments()) {
359 if (!HasValidTFLiteType(arg, fn)) {
360 auto elementType = getElementTypeOrSelf(arg.getType());
361 if (elementType.isa<mlir::TF::VariantType>()) {
362 return fn.emitError(
363 "function argument uses variant type. Currently, the "
364 "variant type is not natively supported in TFLite. Please "
365 "consider not using the variant type: ")
366 << arg.getType(),
367 false;
368 }
369 return fn.emitError("invalid TFLite type: ") << arg.getType(), false;
370 }
371 }
372
373 // Verify that all operations except the terminator have exactly one
374 // result of type supported by TFLite.
375 for (auto& inst : bb) {
376 if (inst.hasTrait<mlir::OpTrait::IsTerminator>()) break;
377
378 for (auto result : inst.getResults()) {
379 if (!HasValidTFLiteType(result, inst)) {
380 auto elementType = getElementTypeOrSelf(result.getType());
381 if (elementType.isa<mlir::TF::VariantType>()) {
382 return inst.emitError(
383 "operand result uses variant type. Currently, the "
384 "variant type is not natively supported in TFLite. "
385 "Please "
386 "consider not using the variant type: ")
387 << result.getType(),
388 false;
389 }
390 return fn.emitError("invalid TFLite type: ") << result.getType(),
391 false;
392 }
393 }
394 }
395 }
396
397 return true;
398 }
399
GetTensorFlowNodeDef(::mlir::Operation * inst)400 static std::unique_ptr<::tensorflow::NodeDef> GetTensorFlowNodeDef(
401 ::mlir::Operation* inst) {
402 // We pass empty string for the original node_def name since Flex runtime
403 // does not care about this being set correctly on node_def. There is no
404 // "easy" (see b/120948529) way yet to get this from MLIR inst.
405 auto status_or_node_def = tensorflow::ConvertTFDialectOpToNodeDef(
406 inst, /*name=*/"", /*ignore_unregistered_attrs=*/true);
407 if (!status_or_node_def.ok()) {
408 inst->emitOpError(
409 Twine("failed to obtain TensorFlow nodedef with status: " +
410 status_or_node_def.status().ToString()));
411 return {};
412 }
413 return std::move(status_or_node_def.ValueOrDie());
414 }
415
416 // Converts a mlir padding StringRef to TfLitePadding.
417 // Returns llvm::None if conversion fails.
GetTflitePadding(Operation * inst,llvm::StringRef padding)418 static Optional<TfLitePadding> GetTflitePadding(Operation* inst,
419 llvm::StringRef padding) {
420 const tflite::Padding padding_attr =
421 std::move(llvm::StringSwitch<tflite::Padding>(padding)
422 .Case("SAME", tflite::Padding_SAME)
423 .Case("VALID", tflite::Padding_VALID));
424 if (padding_attr == tflite::Padding_SAME) {
425 return kTfLitePaddingSame;
426 }
427 if (padding_attr == tflite::Padding_VALID) {
428 return kTfLitePaddingValid;
429 }
430
431 return inst->emitOpError() << "Invalid padding attribute: " << padding,
432 llvm::None;
433 }
434
435 // Extracts TfLitePoolParams from a TFL custom op.
436 // Template parameter, TFLOp, should be a TFL custom op containing attributes
437 // generated from TfLitePoolParams.
438 // Returns llvm::None if conversion fails.
439 template <typename TFLOp>
GetTflitePoolParams(Operation * inst,TFLOp op)440 static Optional<TfLitePoolParams> GetTflitePoolParams(Operation* inst,
441 TFLOp op) {
442 TfLitePoolParams pool_params;
443 pool_params.stride_height = op.stride_h().getSExtValue();
444 pool_params.stride_width = op.stride_w().getSExtValue();
445 pool_params.filter_height = op.filter_h().getSExtValue();
446 pool_params.filter_width = op.filter_w().getSExtValue();
447 const auto padding = GetTflitePadding(inst, op.padding());
448 if (padding) {
449 pool_params.padding = *padding;
450 pool_params.activation = kTfLiteActNone;
451 pool_params.computed.padding = TfLitePaddingValues{0, 0, 0, 0};
452 return pool_params;
453 }
454
455 return llvm::None;
456 }
457
458 namespace {
459
// Helper struct that wraps inputs/outputs of a single SignatureDef.
struct SignatureDefData {
  // std::map (rather than a hash map) keeps iteration order deterministic,
  // which makes test expectations simple.

  // Inputs defined in the signature def mapped to tensor names.
  std::map<std::string, std::string> inputs;
  // Outputs defined in the signature def mapped to tensor names.
  std::map<std::string, std::string> outputs;
  // Signature key.
  std::string signature_key;
  // Subgraph index.
  uint32_t subgraph_index;
};
474
475 // Translates an MLIR module in TFLite dialect to TFLite FlatBuffer.
476 class Translator {
477 public:
478 // Translates the given MLIR module into TFLite FlatBuffer format and returns
479 // the serialized output. Returns llvm::None on unsupported, invalid inputs or
480 // internal error.
481 static Optional<std::string> Translate(
482 ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
483 bool emit_custom_ops, bool allow_all_select_tf_ops,
484 const std::unordered_set<std::string>& select_user_tf_ops,
485 const std::unordered_set<std::string>& tags,
486 OpOrArgNameMapper* op_or_arg_name_mapper,
487 const std::map<std::string, std::string>& metadata);
488
489 private:
490 enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp };
Translator(ModuleOp module,bool emit_builtin_tflite_ops,bool emit_select_tf_ops,bool emit_custom_ops,bool allow_all_select_tf_ops,const std::unordered_set<std::string> & select_user_tf_ops,const std::unordered_set<std::string> & saved_model_tags,OpOrArgNameMapper * op_or_arg_name_mapper,const std::map<std::string,std::string> & metadata)491 explicit Translator(ModuleOp module, bool emit_builtin_tflite_ops,
492 bool emit_select_tf_ops, bool emit_custom_ops,
493 bool allow_all_select_tf_ops,
494 const std::unordered_set<std::string>& select_user_tf_ops,
495 const std::unordered_set<std::string>& saved_model_tags,
496 OpOrArgNameMapper* op_or_arg_name_mapper,
497 const std::map<std::string, std::string>& metadata)
498 : module_(module),
499 name_mapper_(*op_or_arg_name_mapper),
500 builder_(kInitialBufferSize),
501 saved_model_tags_(saved_model_tags),
502 allow_all_select_tf_ops_(allow_all_select_tf_ops),
503 select_user_tf_ops_(select_user_tf_ops),
504 metadata_(metadata) {
505 // The first buffer must be empty according to the schema definition.
506 empty_buffer_ = tflite::CreateBuffer(builder_);
507 buffers_.push_back(empty_buffer_);
508 if (emit_builtin_tflite_ops) {
509 enabled_op_types_.emplace(OpType::kTfliteBuiltin);
510 }
511 if (emit_select_tf_ops) {
512 enabled_op_types_.emplace(OpType::kSelectTf);
513 }
514 if (emit_custom_ops) {
515 enabled_op_types_.emplace(OpType::kCustomOp);
516 }
517 tf_dialect_ =
518 module.getContext()->getOrLoadDialect<mlir::TF::TensorFlowDialect>();
519 tfl_dialect_ = module.getContext()
520 ->getOrLoadDialect<mlir::TFL::TensorFlowLiteDialect>();
521 // Right now the TF executor dialect is still needed to build NodeDef.
522 module.getContext()
523 ->getOrLoadDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
524 }
525
526 Optional<std::string> TranslateInternal();
527
528 // Returns TFLite buffer populated with constant value if the operation is
529 // TFLite constant operation. Otherwise, returns an empty buffer. Emits error
530 // and returns llvm::None on failure.
531 Optional<BufferOffset<tflite::Buffer>> BuildBuffer(Operation* inst);
532
533 // Build TFLite tensor from the given type. This function is for tfl.lstm
534 // intermediates, which should have UniformQuantizedType.
535 Optional<BufferOffset<tflite::Tensor>> BuildTensorFromType(
536 mlir::Type type, const std::string& name);
537
538 // Builds TFLite tensor from the given value. `buffer_idx` is index of the
539 // corresponding buffer. Emits error and returns llvm::None on failure.
540 Optional<BufferOffset<tflite::Tensor>> BuildTensor(
541 Value value, const std::string& name, unsigned buffer_idx,
542 const Optional<BufferOffset<tflite::QuantizationParameters>>&
543 quant_parameters);
544
545 // TODO(b/137395003): Legalize tf.IfOp to TFLite dialect, and change the
546 // following method to handle TFL::IfOp.
547 BufferOffset<tflite::Operator> BuildIfOperator(
548 mlir::TF::IfOp op, const std::vector<int32_t>& operands,
549 const std::vector<int32_t>& results);
550
551 // Build while operator where cond & body are regions.
552 Optional<BufferOffset<tflite::Operator>> BuildWhileOperator(
553 mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
554 const std::vector<int32_t>& results);
555
556 // Build call once operator.
557 BufferOffset<tflite::Operator> BuildCallOnceOperator(
558 mlir::TFL::CallOnceOp op, const std::vector<int32_t>& operands,
559 const std::vector<int32_t>& results);
560
561 BufferOffset<tflite::Operator> BuildNumericVerifyOperator(
562 mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
563 const std::vector<int32_t>& results);
564
565 BufferOffset<tflite::Operator> BuildCustomOperator(
566 Operation* inst, mlir::TFL::CustomOp op,
567 const std::vector<int32_t>& operands,
568 const std::vector<int32_t>& results);
569
570 Optional<CustomOptionsOffset> CreateFlexOpCustomOptions(
571 const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
572
573 Optional<CustomOptionsOffset> CreateCustomOpCustomOptions(
574 const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
575
576 std::unique_ptr<flexbuffers::Builder> CreateFlexBuilderWithNodeAttrs(
577 const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
578
579 // Returns opcode index for op identified by the op_name, if already
580 // available. Otherwise, creates a new OperatorCode using the given `builtin`
581 // operator and associates it with `op_name`.
582 uint32_t GetOpcodeIndex(const std::string& op_name,
583 tflite::BuiltinOperator builtin);
584
585 // Builds operator for the given operation with specified operand and result
586 // tensor indices. Emits an error and returns llvm::None on failure.
587 Optional<BufferOffset<tflite::Operator>> BuildOperator(
588 Operation* inst, std::vector<int32_t> operands,
589 const std::vector<int32_t>& results,
590 const std::vector<int32_t>& intermediates);
591
592 // Returns the quantization parameters for output value of "quant.stats" op.
593 BufferOffset<tflite::QuantizationParameters>
594 GetQuantizationForQuantStatsOpOutput(mlir::quant::StatisticsOp stats_op);
595
596 // Build a subgraph with a given name out of the region either corresponding
597 // to a function's body or while op.
598 Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(
599 const std::string& name, Region* region, const int index);
600
601 // Builds Metadata with the given `name` and buffer `content`.
602 BufferOffset<tflite::Metadata> BuildMetadata(StringRef name,
603 StringRef content);
604
605 // Encodes the `tfl.metadata` dictionary attribute of the module to the
606 // metadata section in the final model.
607 Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
608 CreateMetadataVector();
609
610 // Builds and returns list of tfl.SignatureDef sections in the model.
611 Optional<VectorBufferOffset<BufferOffset<tflite::SignatureDef>>>
612 CreateSignatureDefs(const std::vector<SignatureDefData>& signature_defs);
613
614 // Returns list of offsets for the passed 'items' in TensorMap structure
615 // inside the flatbuffer.
616 // 'items' is a map from tensor name in signatureDef to tensor name in
617 // the subgraph, specified by the 'subgraph_index' argument.
618 std::vector<BufferOffset<tflite::TensorMap>> GetList(
619 const int subgraph_index,
620 const std::map<std::string, std::string>& items);
621
622 // Uses the tf.entry_function attribute (if set) to initialize the op to name
623 // mapping.
624 void InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr);
625
626 // Determines if the specified operation op's operand at operand_index
627 // is marked as a stateful operand.
628 bool IsStatefulOperand(mlir::Operation* op, int operand_index);
629
630 // Returns a unique name for `val`.
631 std::string UniqueName(mlir::Value val);
632
633 BufferOffset<tflite::SparsityParameters> BuildSparsityParameters(
634 const mlir::TFL::SparsityParameterAttr& s_attr);
635
636 bool EstimateArithmeticCount(int64_t* count);
637
638 ModuleOp module_;
639
640 tensorflow::OpOrArgNameMapper& name_mapper_;
641
642 flatbuffers::FlatBufferBuilder builder_;
643 BufferOffset<tflite::Buffer> empty_buffer_;
644
645 std::vector<BufferOffset<tflite::Buffer>> buffers_;
646 // Maps subgraph index and tensor name in the graph to the tensor index.
647 absl::flat_hash_map<int, absl::flat_hash_map<std::string, int>>
648 tensor_index_map_;
649
650 // Maps op name to index of the corresponding OperatorCode in opcodes_ vector.
651 absl::flat_hash_map<std::string, uint32_t> opcode_index_map_;
652 std::vector<BufferOffset<tflite::OperatorCode>> opcodes_;
653
654 // Maps function name to index of the corresponding subgraph in the FlatBuffer
655 // model.
656 absl::flat_hash_map<std::string, int> subgraph_index_map_;
657 absl::flat_hash_set<OpType> enabled_op_types_;
658
659 // Points to TensorFlow and TFLite dialects, respectively. nullptr if the
660 // dialect is not registered.
661 const Dialect* tf_dialect_;
662 const Dialect* tfl_dialect_;
663
664 // The failed ops during legalization.
665 std::map<std::string, std::set<std::string>> failed_flex_ops_;
666 std::map<std::string, std::set<std::string>> failed_custom_ops_;
667
668 // Ops to provide warning messages.
669 std::map<std::string, std::set<std::string>> custom_ops_;
670 std::map<std::string, std::set<std::string>> flex_ops_;
671
672 // Resource ops to provide warning messages.
673 std::map<std::string, std::set<std::string>> resource_ops_;
674
675 // Set of saved model tags, if any.
676 const std::unordered_set<std::string> saved_model_tags_;
677 // Allows automatic pass through of TF ops as select Tensorflow ops.
678 const bool allow_all_select_tf_ops_;
679 // User's defined ops allowed with Flex.
680 const std::unordered_set<std::string> select_user_tf_ops_;
681 // Map of key value pairs of metadata to export.
682 const std::map<std::string, std::string> metadata_;
683 };
684
EstimateArithmeticCount(int64_t * count)685 bool Translator::EstimateArithmeticCount(int64_t* count) {
686 int64_t result = 0;
687 bool encounter_undetermined_mac = false;
688 module_->walk([&](mlir::TFL::TflArithmeticCountOpInterface op) {
689 int64_t mac_count = op.GetArithmeticCount(op);
690 if (mac_count < 0) {
691 encounter_undetermined_mac = true;
692 return;
693 }
694 result += mac_count;
695 });
696
697 *count = result;
698 return !encounter_undetermined_mac;
699 }
700
UniqueName(mlir::Value val)701 std::string Translator::UniqueName(mlir::Value val) {
702 return std::string(name_mapper_.GetUniqueName(val));
703 }
704
BuildBuffer(Operation * inst)705 Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
706 Operation* inst) {
707 ElementsAttr attr;
708 if (auto cst = dyn_cast<mlir::ConstantOp>(inst)) {
709 // ConstantOp have ElementAttr at this point due to validation of the TFLite
710 // module.
711 attr = cst.getValue().cast<ElementsAttr>();
712 } else if (auto cst = dyn_cast<mlir::TF::ConstOp>(inst)) {
713 attr = cst.value();
714 } else if (auto cst = dyn_cast<tfl::ConstOp>(inst)) {
715 attr = cst.value();
716 } else if (auto cst = dyn_cast<tfl::QConstOp>(inst)) {
717 attr = cst.value();
718 } else if (auto cst = dyn_cast<tfl::SparseConstOp>(inst)) {
719 attr = cst.compressed_data();
720 } else if (auto cst = dyn_cast<tfl::SparseQConstOp>(inst)) {
721 attr = cst.compressed_data();
722 } else {
723 return empty_buffer_;
724 }
725
726 tensorflow::Tensor tensor;
727 auto status = tensorflow::ConvertToTensor(attr, &tensor);
728 if (!status.ok()) {
729 inst->emitError(
730 Twine("failed to convert value attribute to tensor with error: " +
731 status.ToString()));
732 return llvm::None;
733 }
734
735 // TensorFlow and TensorFlow Lite use different string encoding formats.
736 // Convert to TensorFlow Lite format is it's a constant string tensor.
737 if (tensor.dtype() == tensorflow::DT_STRING) {
738 ::tflite::DynamicBuffer dynamic_buffer;
739 auto flat = tensor.flat<::tensorflow::tstring>();
740 for (int i = 0; i < flat.size(); ++i) {
741 const auto& str = flat(i);
742 dynamic_buffer.AddString(str.c_str(), str.length());
743 }
744 char* tensor_buffer;
745 int bytes = dynamic_buffer.WriteToBuffer(&tensor_buffer);
746 auto buffer_data =
747 builder_.CreateVector(reinterpret_cast<uint8_t*>(tensor_buffer), bytes);
748 free(tensor_buffer);
749 return tflite::CreateBuffer(builder_, buffer_data);
750 }
751
752 absl::string_view tensor_data = tensor.tensor_data();
753 auto buffer_data = builder_.CreateVector(
754 reinterpret_cast<const uint8_t*>(tensor_data.data()), tensor_data.size());
755 return tflite::CreateBuffer(builder_, buffer_data);
756 }
757
BuildTensorFromType(mlir::Type type,const std::string & name)758 Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensorFromType(
759 mlir::Type type, const std::string& name) {
760 auto tensor_type = type.cast<TensorType>();
761
762 if (!tensor_type.hasStaticShape()) {
763 return llvm::None;
764 }
765 llvm::ArrayRef<int64_t> shape_ref = tensor_type.getShape();
766 std::vector<int32_t> shape(shape_ref.begin(), shape_ref.end());
767
768 auto element_type = tensor_type.getElementType();
769 tflite::TensorType tflite_element_type =
770 GetTFLiteType(tensor_type.getElementType()).ValueOrDie();
771 BufferOffset<tflite::QuantizationParameters> q_params = 0;
772 if (auto qtype = element_type.dyn_cast<mlir::quant::UniformQuantizedType>()) {
773 q_params = tflite::CreateQuantizationParameters(
774 builder_, /*min=*/0, /*max=*/0,
775 builder_.CreateVector<float>({static_cast<float>(qtype.getScale())}),
776 builder_.CreateVector<int64_t>({qtype.getZeroPoint()}));
777 } else if (auto qtype =
778 element_type
779 .dyn_cast<mlir::quant::CalibratedQuantizedType>()) {
780 q_params = tflite::CreateQuantizationParameters(
781 builder_,
782 builder_.CreateVector<float>({static_cast<float>(qtype.getMin())}),
783 builder_.CreateVector<float>({static_cast<float>(qtype.getMax())}));
784 }
785 return tflite::CreateTensor(
786 builder_, builder_.CreateVector(shape), tflite_element_type,
787 /*buffer=*/0, builder_.CreateString(name), q_params,
788 /*is_variable=*/false);
789 }
790
BuildTensor(Value value,const std::string & name,unsigned buffer_idx,const Optional<BufferOffset<tflite::QuantizationParameters>> & quant_parameters)791 Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
792 Value value, const std::string& name, unsigned buffer_idx,
793 const Optional<BufferOffset<tflite::QuantizationParameters>>&
794 quant_parameters) {
795 auto type = value.getType().cast<TensorType>();
796
797 // TFLite requires tensor shape only for the inputs and constants.
798 // However, we output all known shapes for better round-tripping
799 auto check_shape =
800 [&](llvm::ArrayRef<int64_t> shape_ref) -> mlir::LogicalResult {
801 auto is_out_of_range = [](int64_t dim) {
802 return dim > std::numeric_limits<int32_t>::max();
803 };
804
805 if (std::any_of(shape_ref.begin(), shape_ref.end(), is_out_of_range))
806 return mlir::emitError(
807 value.getLoc(),
808 "result shape dimensions out of 32 bit int type range");
809
810 return mlir::success();
811 };
812
813 std::vector<int32_t> shape;
814 std::vector<int32_t> shape_signature;
815 auto* inst = value.getDefiningOp();
816 if (type.hasStaticShape()) {
817 llvm::ArrayRef<int64_t> shape_ref = type.getShape();
818 if (mlir::failed(check_shape(shape_ref))) return llvm::None;
819
820 shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
821 } else if (inst && IsConst(inst)) {
822 // Const op can have a result of dynamic shaped type (e.g. due to constant
823 // folding), but we can still derive the shape of a constant tensor for
824 // its attribute type.
825 mlir::Attribute tensor_attr = inst->getAttr("value");
826 llvm::ArrayRef<int64_t> shape_ref =
827 tensor_attr.getType().cast<TensorType>().getShape();
828 if (mlir::failed(check_shape(shape_ref))) return llvm::None;
829
830 shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
831 } else if (type.hasRank()) {
832 llvm::ArrayRef<int64_t> shape_ref = type.getShape();
833 if (mlir::failed(check_shape(shape_ref))) return llvm::None;
834
835 shape.reserve(shape_ref.size());
836 for (auto& dim : shape_ref) {
837 shape.push_back(dim == -1 ? 1 : dim);
838 }
839 shape_signature = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
840 }
841
842 BufferOffset<tflite::SparsityParameters> s_params = 0;
843 if (auto* inst = value.getDefiningOp()) {
844 if (auto cst = dyn_cast<tfl::SparseConstOp>(inst)) {
845 s_params = BuildSparsityParameters(cst.s_param());
846 } else if (auto cst = dyn_cast<tfl::SparseQConstOp>(inst)) {
847 s_params = BuildSparsityParameters(cst.s_param());
848 }
849 }
850
851 Type element_type = type.getElementType();
852 tflite::TensorType tflite_element_type =
853 GetTFLiteType(type.getElementType()).ValueOrDie();
854
855 BufferOffset<tflite::QuantizationParameters> q_params;
856 if (auto qtype = element_type.dyn_cast<mlir::quant::UniformQuantizedType>()) {
857 q_params = tflite::CreateQuantizationParameters(
858 // min and max values are not stored in the quantized type from MLIR, so
859 // both are set to 0 in the flatbuffer when they are exported.
860 builder_, /*min=*/0, /*max=*/0,
861 builder_.CreateVector<float>({static_cast<float>(qtype.getScale())}),
862 builder_.CreateVector<int64_t>({qtype.getZeroPoint()}));
863 } else if (auto qtype =
864 element_type
865 .dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
866 std::vector<float> scales(qtype.getScales().begin(),
867 qtype.getScales().end());
868 q_params = tflite::CreateQuantizationParameters(
869 builder_, /*min=*/0, /*max=*/0, builder_.CreateVector<float>(scales),
870 builder_.CreateVector<int64_t>(qtype.getZeroPoints()),
871 tflite::QuantizationDetails_NONE, /*details=*/0,
872 qtype.getQuantizedDimension());
873 } else if (quant_parameters.hasValue()) {
874 q_params = quant_parameters.getValue();
875 } else {
876 q_params = tflite::CreateQuantizationParameters(builder_);
877 }
878 // Check if the value's uses includes an op and usage at an operand index
879 // marked as a stateful. If so, set the tensor's is_variable as true
880 // This is v1 ref variable semantics in the TFLite runtime.
881 bool is_variable = false;
882 for (auto& use : value.getUses()) {
883 is_variable = IsStatefulOperand(use.getOwner(), use.getOperandNumber());
884 if (is_variable) {
885 break;
886 }
887 }
888
889 if (shape_signature.empty()) {
890 return tflite::CreateTensor(
891 builder_, builder_.CreateVector(shape), tflite_element_type,
892 (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
893 /*is_variable=*/is_variable, s_params);
894 } else {
895 return tflite::CreateTensor(
896 builder_, builder_.CreateVector(shape), tflite_element_type,
897 (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
898 /*is_variable=*/is_variable, s_params,
899 /*shape_signature=*/builder_.CreateVector(shape_signature));
900 }
901 }
902
BuildIfOperator(mlir::TF::IfOp op,const std::vector<int32_t> & operands,const std::vector<int32_t> & results)903 BufferOffset<tflite::Operator> Translator::BuildIfOperator(
904 mlir::TF::IfOp op, const std::vector<int32_t>& operands,
905 const std::vector<int32_t>& results) {
906 auto opcode_index = GetOpcodeIndex("if", tflite::BuiltinOperator_IF);
907 int then_subgraph_index = subgraph_index_map_.at(op.then_branch().str());
908 int else_subgraph_index = subgraph_index_map_.at(op.else_branch().str());
909 auto builtin_options = tflite::CreateIfOptions(builder_, then_subgraph_index,
910 else_subgraph_index)
911 .Union();
912 auto inputs = builder_.CreateVector(operands);
913 auto outputs = builder_.CreateVector(results);
914 return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
915 tflite::BuiltinOptions_IfOptions,
916 builtin_options);
917 }
918
BuildCallOnceOperator(mlir::TFL::CallOnceOp op,const std::vector<int32_t> & operands,const std::vector<int32_t> & results)919 BufferOffset<tflite::Operator> Translator::BuildCallOnceOperator(
920 mlir::TFL::CallOnceOp op, const std::vector<int32_t>& operands,
921 const std::vector<int32_t>& results) {
922 auto opcode_index =
923 GetOpcodeIndex("call_once", tflite::BuiltinOperator_CALL_ONCE);
924 int init_subgraph_index =
925 subgraph_index_map_.at(op.session_init_function().str());
926 auto builtin_options =
927 tflite::CreateCallOnceOptions(builder_, init_subgraph_index).Union();
928 auto inputs = builder_.CreateVector(operands);
929 auto outputs = builder_.CreateVector(results);
930 return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
931 tflite::BuiltinOptions_CallOnceOptions,
932 builtin_options);
933 }
934
BuildWhileOperator(mlir::TFL::WhileOp op,const std::vector<int32_t> & operands,const std::vector<int32_t> & results)935 Optional<BufferOffset<tflite::Operator>> Translator::BuildWhileOperator(
936 mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
937 const std::vector<int32_t>& results) {
938 auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE);
939 auto get_call_index = [&](mlir::Block& b) -> Optional<int> {
940 if (b.getOperations().size() != 2) return llvm::None;
941 if (auto call_op = dyn_cast<mlir::CallOp>(b.front()))
942 return subgraph_index_map_.at(call_op.callee().str());
943 return llvm::None;
944 };
945 auto body_subgraph_index = get_call_index(op.body().front());
946 auto cond_subgraph_index = get_call_index(op.cond().front());
947 if (!body_subgraph_index || !cond_subgraph_index)
948 return op.emitOpError("only single call cond/body while export supported"),
949 llvm::None;
950 auto builtin_options =
951 tflite::CreateWhileOptions(builder_, *cond_subgraph_index,
952 *body_subgraph_index)
953 .Union();
954 auto inputs = builder_.CreateVector(operands);
955 auto outputs = builder_.CreateVector(results);
956 return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
957 tflite::BuiltinOptions_WhileOptions,
958 builtin_options);
959 }
960
BuildNumericVerifyOperator(mlir::TFL::NumericVerifyOp op,const std::vector<int32_t> & operands,const std::vector<int32_t> & results)961 BufferOffset<tflite::Operator> Translator::BuildNumericVerifyOperator(
962 mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
963 const std::vector<int32_t>& results) {
964 float tolerance = op.tolerance().convertToFloat();
965 bool log_if_failed = op.log_if_failed();
966 auto fbb = absl::make_unique<flexbuffers::Builder>();
967 fbb->Map([&]() {
968 fbb->Float("tolerance", tolerance);
969 fbb->Bool("log_if_failed", log_if_failed);
970 });
971 fbb->Finish();
972 auto f = std::unique_ptr<flexbuffers::Builder>(fbb.release());
973 auto custom_option = f->GetBuffer();
974 auto opcode_index =
975 GetOpcodeIndex("NumericVerify", tflite::BuiltinOperator_CUSTOM);
976 return tflite::CreateOperator(
977 builder_, opcode_index, builder_.CreateVector(operands),
978 builder_.CreateVector(results), tflite::BuiltinOptions_NONE,
979 /*builtin_options=*/0, builder_.CreateVector<uint8_t>(custom_option),
980 tflite::CustomOptionsFormat_FLEXBUFFERS);
981 }
982
BuildCustomOperator(Operation * inst,mlir::TFL::CustomOp op,const std::vector<int32_t> & operands,const std::vector<int32_t> & results)983 BufferOffset<tflite::Operator> Translator::BuildCustomOperator(
984 Operation* inst, mlir::TFL::CustomOp op,
985 const std::vector<int32_t>& operands, const std::vector<int32_t>& results) {
986 const std::string attrs =
987 op.custom_option().cast<mlir::OpaqueElementsAttr>().getValue().str();
988 std::vector<uint8_t> custom_option_vector(attrs.size());
989 memcpy(custom_option_vector.data(), attrs.data(), attrs.size());
990 auto opcode_index =
991 GetOpcodeIndex(op.custom_code().str(), tflite::BuiltinOperator_CUSTOM);
992 return tflite::CreateOperator(
993 builder_, opcode_index, builder_.CreateVector(operands),
994 builder_.CreateVector(results), tflite::BuiltinOptions_NONE,
995 /*builtin_options=*/0,
996 builder_.CreateVector<uint8_t>(custom_option_vector),
997 tflite::CustomOptionsFormat_FLEXBUFFERS);
998 }
999
CreateFlexOpCustomOptions(const::tensorflow::NodeDef & node_def,const mlir::Location & loc)1000 Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions(
1001 const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
1002 std::string node_def_str;
1003 if (!node_def.SerializeToString(&node_def_str)) {
1004 return emitError(loc, "failed to serialize tensorflow node_def"),
1005 llvm::None;
1006 }
1007
1008 auto flex_builder = absl::make_unique<flexbuffers::Builder>();
1009 flex_builder->Vector([&]() {
1010 flex_builder->String(node_def.op());
1011 flex_builder->String(node_def_str);
1012 });
1013 flex_builder->Finish();
1014 return builder_.CreateVector(flex_builder->GetBuffer());
1015 }
1016
CreateCustomOpCustomOptions(const::tensorflow::NodeDef & node_def,const mlir::Location & loc)1017 Optional<CustomOptionsOffset> Translator::CreateCustomOpCustomOptions(
1018 const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
1019 auto flex_builder = CreateFlexBuilderWithNodeAttrs(node_def, loc);
1020 return builder_.CreateVector(flex_builder->GetBuffer());
1021 }
1022
1023 std::unique_ptr<flexbuffers::Builder>
CreateFlexBuilderWithNodeAttrs(const::tensorflow::NodeDef & node_def,const mlir::Location & loc)1024 Translator::CreateFlexBuilderWithNodeAttrs(
1025 const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
1026 auto flex_builder = absl::make_unique<flexbuffers::Builder>();
1027 size_t map_start = flex_builder->StartMap();
1028 using Item = std::pair<std::string, ::tensorflow::AttrValue>;
1029 std::vector<Item> attrs(node_def.attr().begin(), node_def.attr().end());
1030 std::sort(attrs.begin(), attrs.end(),
1031 [](Item& p1, Item& p2) -> bool { return p1.first < p2.first; });
1032 for (const Item& pair : attrs) {
1033 const char* key = pair.first.c_str();
1034 const ::tensorflow::AttrValue& attr = pair.second;
1035 switch (attr.value_case()) {
1036 case ::tensorflow::AttrValue::kS:
1037 flex_builder->String(key, attr.s());
1038 break;
1039 case ::tensorflow::AttrValue::kType: {
1040 auto status_or_tfl_type = tflite::TfTypeToTflType(attr.type());
1041 if (status_or_tfl_type.ok()) {
1042 flex_builder->Int(key, status_or_tfl_type.ValueOrDie());
1043 } else {
1044 emitWarning(loc, "ignoring unsupported tensorflow type: ")
1045 << std::to_string(attr.type());
1046 }
1047 break;
1048 }
1049 case ::tensorflow::AttrValue::kI:
1050 flex_builder->Int(key, attr.i());
1051 break;
1052 case ::tensorflow::AttrValue::kF:
1053 flex_builder->Float(key, attr.f());
1054 break;
1055 case ::tensorflow::AttrValue::kB:
1056 flex_builder->Bool(key, attr.b());
1057 break;
1058 case tensorflow::AttrValue::kList:
1059 if (attr.list().s_size() > 0) {
1060 auto start = flex_builder->StartVector(key);
1061 for (const std::string& v : attr.list().s()) {
1062 flex_builder->Add(v);
1063 }
1064 flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false);
1065 } else if (attr.list().i_size() > 0) {
1066 auto start = flex_builder->StartVector(key);
1067 for (const int64_t v : attr.list().i()) {
1068 flex_builder->Add(v);
1069 }
1070 flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false);
1071 } else if (attr.list().f_size() > 0) {
1072 auto start = flex_builder->StartVector(key);
1073 for (const float v : attr.list().f()) {
1074 flex_builder->Add(v);
1075 }
1076 flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false);
1077 } else {
1078 emitWarning(loc,
1079 "ignoring unsupported type in list attribute with key: ")
1080 << key;
1081 }
1082 break;
1083 default:
1084 emitWarning(loc, "ignoring unsupported attribute type with key: ")
1085 << key;
1086 break;
1087 }
1088 }
1089 flex_builder->EndMap(map_start);
1090 flex_builder->Finish();
1091 return flex_builder;
1092 }
1093
GetOpcodeIndex(const std::string & op_name,tflite::BuiltinOperator builtin)1094 uint32_t Translator::GetOpcodeIndex(const std::string& op_name,
1095 tflite::BuiltinOperator builtin) {
1096 auto it = opcode_index_map_.insert({op_name, 0});
1097
1098 // If the insert succeeded, the opcode has not been created already. Create a
1099 // new operator code and update its index value in the map.
1100 if (it.second) {
1101 it.first->second = opcodes_.size();
1102 auto custom_code = builtin == tflite::BuiltinOperator_CUSTOM
1103 ? builder_.CreateString(op_name)
1104 : BufferOffset<flatbuffers::String>();
1105 // Use version 0 for builtin op. This is a way to serialize version field to
1106 // flatbuffer (since 0 is non default) and it will be corrected later.
1107 int32_t op_version = builtin != tflite::BuiltinOperator_CUSTOM ? 0 : 1;
1108 opcodes_.push_back(CreateOperatorCode(builder_, /*builtin_code=*/builtin,
1109 custom_code, op_version));
1110 }
1111 return it.first->second;
1112 }
1113
// Translates `inst` into a serialized TFLite Operator, or llvm::None (after
// emitting an error) when the op cannot be exported under the enabled op
// types. `operands`/`results`/`intermediates` are tensor indices into the
// current subgraph; `operands` is taken by value because the TransposeConv
// special case below may drop the trailing optional operand.
Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
    Operation* inst, std::vector<int32_t> operands,
    const std::vector<int32_t>& results,
    const std::vector<int32_t>& intermediates) {
  const auto* dialect = inst->getDialect();
  if (!dialect) {
    inst->emitOpError("dialect is not registered");
    return llvm::None;
  }

  // If TFLite built in op, create operator as a builtin op.
  if (dialect == tfl_dialect_) {
    // Only if built-in TFLite op emission is enabled, would legalization have
    // converted any TF->TFL.
    if (!enabled_op_types_.contains(OpType::kTfliteBuiltin)) {
      return inst->emitOpError(
                 "is a TFLite builtin op but builtin emission is not enabled"),
             llvm::None;
    }

    auto builtin_code = GetBuiltinOpCode(inst);
    if (!builtin_code) {
      // No builtin opcode: handle the special TFL ops that serialize through
      // dedicated custom paths.
      if (auto verify_op = dyn_cast<mlir::TFL::NumericVerifyOp>(inst)) {
        return BuildNumericVerifyOperator(verify_op, operands, results);
      }
      if (auto custom_op = dyn_cast<mlir::TFL::CustomOp>(inst)) {
        return BuildCustomOperator(inst, custom_op, operands, results);
      }
      if (auto whileOp = dyn_cast<mlir::TFL::WhileOp>(inst)) {
        if (inst->getNumOperands() != inst->getNumResults()) {
          inst->emitOpError(
              "number of operands and results don't match, only canonical "
              "TFL While supported");
          return llvm::None;
        }
        return BuildWhileOperator(whileOp, operands, results);
      }

      inst->emitOpError("is not a supported TFLite op");
      return llvm::None;
    }

    if (*builtin_code == tflite::BuiltinOperator_CALL_ONCE) {
      if (auto initOp = dyn_cast<mlir::TFL::CallOnceOp>(inst)) {
        return BuildCallOnceOperator(initOp, operands, results);
      }
    }

    std::string op_name = inst->getName().getStringRef().str();
    uint32_t opcode_index = GetOpcodeIndex(op_name, *builtin_code);

    // If this is TransposeConv we need to do a special case of ignoring the
    // optional tensor, to allow newly created models to run on old runtimes.
    if (*builtin_code == tflite::BuiltinOperator_TRANSPOSE_CONV) {
      if (operands.size() == 4 && operands.at(3) == -1) {
        operands.pop_back();
      }
    }

    auto offset = CreateFlatBufferOperator(inst, opcode_index, operands,
                                           results, intermediates, &builder_);
    if (!offset) {
      inst->emitOpError("is not a supported TFLite op");
    }
    return offset;
  }

  if (dialect == tf_dialect_) {
    // TF If is the one TF control-flow op exported as a TFLite builtin.
    if (auto ifOp = dyn_cast<mlir::TF::IfOp>(inst)) {
      return BuildIfOperator(ifOp, operands, results);
    }

    CustomOptionsOffset custom_options;

    // Ops in TF dialect can either be custom ops or flex ops.
    // The reason we go directly from TensorFlow dialect MLIR to tensorflow
    // node instead of going to TF table gen'd ops via generated code is that
    // we do not want to restrict custom and flex op conversion support to
    // only those TF ops that are currently registered in MLIR. The current
    // model is of an open op system.
    //
    // The following algorithm is followed:
    //   if flex is enabled and the op is allowlisted as flex
    //     we emit op as flex.
    //   if custom is enabled
    //     we emit the op as custom.
    auto node_def = GetTensorFlowNodeDef(inst);
    if (!node_def) {
      return llvm::None;
    }

    std::string op_name = node_def->op();
    std::string op_desc = GetOpDescriptionForDebug(inst);

    // Track resource-touching ops for later reporting.
    if (IsTFResourceOp(inst)) {
      resource_ops_[op_name].insert(op_desc);
    }

    // A flex op is allowed when it is on the allowlist, or when the user
    // explicitly selected it (or selected all TF ops) and it is registered
    // in the TF op registry.
    const bool is_allowed_flex_op =
        IsAllowlistedFlexOp(node_def->op()) ||
        (((select_user_tf_ops_.count(node_def->op()) != 0) ||
          allow_all_select_tf_ops_) &&
         (tensorflow::OpRegistry::Global()->LookUp(node_def->op()) != nullptr));
    // Flex op case
    // Eventually, the allowlist will go away and we will rely on some TF op
    // trait (e.g. No side effect) to determine if it is a supported "Flex"
    // op or not.
    if (is_allowed_flex_op && enabled_op_types_.contains(OpType::kSelectTf)) {
      // Construct ops as flex op encoding TensorFlow node definition
      // as custom options.
      // Flex ops are named with the kFlexOpNamePrefix prefix to the actual
      // TF op name.
      op_name = std::string(kFlexOpNamePrefix) + node_def->op();
      if (auto options = CreateFlexOpCustomOptions(*node_def, inst->getLoc())) {
        custom_options = *options;
      } else {
        return llvm::None;
      }

      // Gather flex ops.
      flex_ops_[op_name].insert(op_desc);
    } else if (enabled_op_types_.contains(OpType::kCustomOp)) {
      // Generic case of custom ops - write using flex buffers since that
      // is the only custom options supported by TFLite today.
      op_name = node_def->op();
      if (auto options =
              CreateCustomOpCustomOptions(*node_def, inst->getLoc())) {
        custom_options = *options;
      } else {
        return llvm::None;
      }

      // Gather custom ops.
      custom_ops_[op_name].insert(op_desc);
    } else {
      // Neither export path is enabled/allowed for this op: record it in the
      // appropriate failure bucket and attach a converter error code.
      if (is_allowed_flex_op) {
        failed_flex_ops_[op_name].insert(op_desc);
        tfl::AttachErrorCode(
            inst->emitOpError("is neither a custom op nor a flex op"),
            tflite::metrics::ConverterErrorData::ERROR_NEEDS_FLEX_OPS);
      } else {
        failed_custom_ops_[op_name].insert(op_desc);
        tfl::AttachErrorCode(
            inst->emitOpError("is neither a custom op nor a flex op"),
            tflite::metrics::ConverterErrorData::ERROR_NEEDS_CUSTOM_OPS);
      }
      return llvm::None;
    }

    uint32_t opcode_index =
        GetOpcodeIndex(op_name, tflite::BuiltinOperator_CUSTOM);
    auto inputs = builder_.CreateVector(operands);
    auto outputs = builder_.CreateVector(results);

    return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
                                  tflite::BuiltinOptions_NONE,
                                  /*builtin_options=*/0,
                                  /*custom_options=*/custom_options,
                                  tflite::CustomOptionsFormat_FLEXBUFFERS,
                                  /*mutating_variable_inputs=*/0);
  }

  return inst->emitOpError(
             "is not any of a builtin TFLite op, a flex TensorFlow op or a "
             "custom TensorFlow op"),
         llvm::None;
}
1282
InitializeNamesFromAttribute(FuncOp fn,bool * has_input_attr)1283 void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) {
1284 auto dict_attr = fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
1285 if (!dict_attr) return;
1286
1287 llvm::SmallVector<llvm::StringRef, 2> input_names;
1288 llvm::SmallVector<llvm::StringRef, 2> output_names;
1289 if (auto str = dict_attr.get("inputs").dyn_cast_or_null<mlir::StringAttr>()) {
1290 str.getValue().split(input_names, ',', /*MaxSplit=*/-1,
1291 /*KeepEmpty=*/false);
1292 if (input_names.size() != fn.getNumArguments()) {
1293 fn.emitWarning() << "invalid entry function specification";
1294 return;
1295 }
1296 for (auto it : llvm::enumerate(fn.getArguments())) {
1297 name_mapper_.InitOpName(it.value(), input_names[it.index()].trim());
1298 }
1299 *has_input_attr = true;
1300 }
1301
1302 if (auto str =
1303 dict_attr.get("outputs").dyn_cast_or_null<mlir::StringAttr>()) {
1304 str.getValue().split(output_names, ',', /*MaxSplit=*/-1,
1305 /*KeepEmpty=*/false);
1306 auto term = fn.back().getTerminator();
1307 if (output_names.size() != term->getNumOperands()) {
1308 fn.emitWarning() << "output names (" << output_names.size()
1309 << ") != terminator operands (" << term->getNumOperands()
1310 << ")";
1311 return;
1312 }
1313 for (const auto& it : llvm::enumerate(term->getOperands())) {
1314 name_mapper_.InitOpName(it.value(), output_names[it.index()].trim());
1315 }
1316 }
1317 }
1318
IsStatefulOperand(mlir::Operation * op,int operand_index)1319 bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) {
1320 std::vector<int> operand_indices;
1321 if (!mlir::TFL::IsStatefulOp(op, &operand_indices)) return false;
1322 return absl::c_find(operand_indices, operand_index) != operand_indices.end();
1323 }
1324
1325 BufferOffset<tflite::QuantizationParameters>
GetQuantizationForQuantStatsOpOutput(mlir::quant::StatisticsOp stats_op)1326 Translator::GetQuantizationForQuantStatsOpOutput(
1327 mlir::quant::StatisticsOp stats_op) {
1328 auto layer_stats = stats_op.layerStats().cast<mlir::DenseFPElementsAttr>();
1329 Optional<mlir::ElementsAttr> axis_stats = stats_op.axisStats();
1330 Optional<uint64_t> axis = stats_op.axis();
1331 std::vector<float> mins, maxs;
1332 mlir::DenseFPElementsAttr min_max_attr =
1333 axis_stats.hasValue()
1334 ? axis_stats.getValue().cast<mlir::DenseFPElementsAttr>()
1335 : layer_stats;
1336
1337 for (auto index_and_value : llvm::enumerate(min_max_attr.getFloatValues())) {
1338 const llvm::APFloat value = index_and_value.value();
1339 if (index_and_value.index() % 2 == 0) {
1340 mins.push_back(value.convertToFloat());
1341 } else {
1342 maxs.push_back(value.convertToFloat());
1343 }
1344 }
1345
1346 return tflite::CreateQuantizationParameters(
1347 builder_, builder_.CreateVector<float>(mins),
1348 builder_.CreateVector<float>(maxs), /*scale=*/0, /*zero_point=*/0,
1349 tflite::QuantizationDetails_NONE, /*details=*/0,
1350 /*quantized_dimension=*/axis.hasValue() ? axis.getValue() : 0);
1351 }
1352
// Converts `region` (the body of function `name`, exported as subgraph
// number `index`) into a TFLite SubGraph: builds tensors and buffers for
// block arguments and op results, then serializes each non-constant op via
// BuildOperator. Returns llvm::None if any tensor or operator fails to
// convert (all failing ops are still reported before returning).
Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
    const std::string& name, Region* region, const int index) {
  bool has_input_attr = false;
  if (auto fn = dyn_cast<FuncOp>(region->getParentOp())) {
    InitializeNamesFromAttribute(fn, &has_input_attr);
  }
  std::vector<BufferOffset<tflite::Tensor>> tensors;
  llvm::DenseMap<Value, int> tensor_index_map;

  // Builds tensor and buffer for argument or operation result. Returns false
  // on failure.
  auto build_tensor_and_buffer = [&](Value value, const int subgraph_index,
                                     const std::string& tensor_name) {
    // NoneType represents optional and may be skipped here.
    if (value.getType().isa<NoneType>()) {
      return true;
    }

    tensor_index_map.insert({value, tensors.size()});
    tensor_index_map_[subgraph_index][tensor_name] = tensors.size();
    // If the value's only use is a quant.stats op, attach that op's min/max
    // statistics to the tensor being built.
    Optional<BufferOffset<tflite::QuantizationParameters>> quant_parameters;
    if (value.hasOneUse()) {
      auto stats_op =
          llvm::dyn_cast<mlir::quant::StatisticsOp>(*value.user_begin());
      if (stats_op) {
        quant_parameters = GetQuantizationForQuantStatsOpOutput(stats_op);
      }
    }
    auto tensor_or =
        BuildTensor(value, tensor_name, buffers_.size(), quant_parameters);
    if (!tensor_or) return false;
    tensors.push_back(*tensor_or);

    // TODO(ashwinm): Check if for stateful tensors, if it is also needed to
    // make the Buffer empty apart from setting the buffer_idx=0 in the
    // Tensor. This does not seem to affect runtime behavior for RNN/LSTM,
    // but would be good for reducing memory footprint.
    if (auto* inst = value.getDefiningOp()) {
      auto buffer_or = BuildBuffer(inst);
      if (!buffer_or) return false;
      buffers_.push_back(*buffer_or);
    } else {
      // Block arguments have no constant data; point at the empty buffer.
      buffers_.push_back(empty_buffer_);
    }
    return true;
  };

  std::vector<BufferOffset<tflite::Operator>> operators;
  auto& bb = region->front();

  // Main function's arguments are first passed to `input` op so they don't
  // have associated tensor and buffer. Build FlatBuffer tensor and buffer for
  // other functions.
  for (unsigned i = 0, e = bb.getNumArguments(); i < e; ++i) {
    mlir::BlockArgument arg = bb.getArgument(i);
    std::string tensor_name;
    if (has_input_attr)
      tensor_name = std::string(name_mapper_.GetUniqueName(arg));
    if (tensor_name.empty()) tensor_name = absl::StrCat("arg", i);
    if (!build_tensor_and_buffer(arg, index, tensor_name)) return llvm::None;
  }

  bool failed_once = false;
  for (auto& inst : bb) {
    if (inst.hasTrait<mlir::OpTrait::IsTerminator>()) break;
    // For "quant.stats" op, it's used to store the quantization parameters
    // info and its output should be then replaced by its input value.
    if (auto quant_stats_op = llvm::dyn_cast<mlir::quant::StatisticsOp>(inst)) {
      continue;
    }
    std::vector<int32_t> intermediates;
    // Build intermediate tensors for tfl.lstm and insert these tensors into
    // flatbuffer.
    if (llvm::isa<mlir::TFL::LSTMOp, mlir::TFL::UnidirectionalSequenceLSTMOp>(
            inst)) {
      std::vector<std::string> intermediate_names = {
          "input_to_input_intermediate", "input_to_forget_intermediate",
          "input_to_cell_intermediate", "input_to_output_intermediate",
          "effective_hidden_scale_intermediate"};
      for (const std::string& intermediate : intermediate_names) {
        auto intermediate_attr = inst.getAttr(intermediate);
        if (auto attr = intermediate_attr.dyn_cast_or_null<mlir::TypeAttr>()) {
          Type qtype = attr.getValue();
          auto tensor_or = BuildTensorFromType(
              qtype, name_mapper_.GetUniqueName(intermediate).str());
          if (!tensor_or.hasValue()) {
            // Dynamically shaped intermediate types are silently skipped.
            continue;
          } else {
            intermediates.push_back(tensors.size());
            tensors.push_back(tensor_or.getValue());
          }
        }
      }
    }

    for (auto val : inst.getResults()) {
      std::string tensor_name = UniqueName(val);
      // For "tfl.numeric_verify" op, the name is used to find out the original
      // activation tensor rather than its own unique name in the visualization
      // or debugging tools.
      auto builtin_code = GetBuiltinOpCode(&inst);
      if (!builtin_code && dyn_cast<mlir::TFL::NumericVerifyOp>(&inst)) {
        // The first operand is the quantized activation, the target of this
        // NumericVerify op.
        auto quantized_op_val = inst.getOperands().front();
        tensor_name = "NumericVerify/" + UniqueName(quantized_op_val) + ":" +
                      std::to_string(tensor_index_map[quantized_op_val]);
      }
      if (!build_tensor_and_buffer(val, index, tensor_name)) return llvm::None;
    }

    // Skip constant ops as they don't represent a TFLite operator.
    if (IsConst(&inst)) continue;

    // Fetch operand and result tensor indices.
    std::vector<int32_t> results;
    results.reserve(inst.getNumResults());
    for (auto result : inst.getResults()) {
      results.push_back(tensor_index_map.lookup(result));
    }
    Operation* real_inst = &inst;
    std::vector<int32_t> operands;
    operands.reserve(real_inst->getNumOperands());
    for (auto operand : real_inst->getOperands()) {
      if (operand.getType().isa<NoneType>())
        operands.push_back(kTfLiteOptionalTensor);
      else if (auto stats_op =
                   llvm::dyn_cast_or_null<mlir::quant::StatisticsOp>(
                       operand.getDefiningOp()))
        // quant.stats ops are skipped above, so route through their input.
        operands.push_back(tensor_index_map.lookup(stats_op.arg()));
      else
        operands.push_back(tensor_index_map.lookup(operand));
    }

    // CustomTfOp is just a wrapper around a TF op, we export the custom Op
    // not the wrapper, so we fetch the op from the region.
    if (auto custom_op = dyn_cast<mlir::TFL::CustomTfOp>(inst)) {
      // If we have custom op with a region, then use the first op in the
      // region, if it exists, otherwise just use params for custom op.
      if (!custom_op.body().empty()) {
        real_inst = &custom_op.body().front().front();
      } else {
        module_.emitError(
            "Invalid CustomTfOp: Custom TF Op have empty region.");
      }
    }

    // Keep going after a failure so that every unsupported op is reported.
    if (auto tfl_operator =
            BuildOperator(real_inst, operands, results, intermediates))
      operators.push_back(*tfl_operator);
    else
      failed_once = true;
  }

  if (failed_once) return llvm::None;

  // Get input and output tensor indices for the subgraph.
  std::vector<int32_t> inputs, outputs;
  for (auto arg : bb.getArguments()) {
    inputs.push_back(tensor_index_map[arg]);
  }
  for (auto result : bb.getTerminator()->getOperands()) {
    outputs.push_back(tensor_index_map[result]);
  }

  return tflite::CreateSubGraph(
      builder_, builder_.CreateVector(tensors), builder_.CreateVector(inputs),
      builder_.CreateVector(outputs), builder_.CreateVector(operators),
      /*name=*/builder_.CreateString(name));
}
1523
BuildMetadata(StringRef name,StringRef content)1524 BufferOffset<tflite::Metadata> Translator::BuildMetadata(StringRef name,
1525 StringRef content) {
1526 auto buffer_index = buffers_.size();
1527 auto buffer_data = builder_.CreateVector(
1528 reinterpret_cast<const uint8_t*>(content.data()), content.size());
1529 buffers_.push_back(tflite::CreateBuffer(builder_, buffer_data));
1530 return tflite::CreateMetadataDirect(builder_, name.data(), buffer_index);
1531 }
1532
1533 Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
CreateMetadataVector()1534 Translator::CreateMetadataVector() {
1535 auto dict_attr = module_->getAttrOfType<mlir::DictionaryAttr>("tfl.metadata");
1536 std::vector<BufferOffset<tflite::Metadata>> metadata;
1537 if (dict_attr) {
1538 for (const auto& named_attr : dict_attr) {
1539 StringRef name = named_attr.first;
1540 mlir::Attribute attr = named_attr.second;
1541 if (auto content = attr.dyn_cast<StringAttr>()) {
1542 metadata.push_back(BuildMetadata(name, content.getValue()));
1543 } else {
1544 module_.emitError(
1545 "all values in tfl.metadata's dictionary key-value pairs should be "
1546 "string attributes");
1547 return llvm::None;
1548 }
1549 }
1550 }
1551 // Runtime version string is generated after we update the op
1552 // versions. Here we put a 16-byte dummy string as a placeholder. We choose
1553 // 16-byte because it's the alignment of buffers in flatbuffer, so it won't
1554 // cause any waste of space if the actual string is shorter than 16 bytes.
1555 constexpr std::size_t kByteStringSize = 16;
1556 metadata.push_back(
1557 BuildMetadata("min_runtime_version", std::string(kByteStringSize, '\0')));
1558 for (const auto& kv : metadata_) {
1559 const std::string& val = kv.second;
1560 // Only take the first kByteStringSize values.
1561 const int count = std::min(kByteStringSize, val.length());
1562 std::string value = std::string(kByteStringSize, '\0')
1563 .assign(val.begin(), val.begin() + count);
1564 metadata.push_back(BuildMetadata(kv.first, value));
1565 }
1566 return builder_.CreateVector(metadata);
1567 }
1568
1569 // Helper method that returns list of all strings in a StringAttr identified
1570 // by 'attr_key' and values are separated by a comma.
GetStringsFromAttrWithSeparator(mlir::DictionaryAttr attr,const std::string & attr_key)1571 llvm::SmallVector<llvm::StringRef, 2> GetStringsFromAttrWithSeparator(
1572 mlir::DictionaryAttr attr, const std::string& attr_key) {
1573 llvm::SmallVector<llvm::StringRef, 2> result;
1574 if (auto str = attr.get(attr_key).dyn_cast_or_null<mlir::StringAttr>()) {
1575 str.getValue().split(result, ',', /*MaxSplit=*/-1,
1576 /*KeepEmpty=*/false);
1577 }
1578 return result;
1579 }
1580
1581 // Helper method that return list of string for all the StringAttr in the
1582 // Attribute identified by 'attr_name'.
GetStringsFromDictionaryAttr(const llvm::SmallVector<mlir::DictionaryAttr,4> & dict_attrs,const std::string & attr_name)1583 std::vector<std::string> GetStringsFromDictionaryAttr(
1584 const llvm::SmallVector<mlir::DictionaryAttr, 4>& dict_attrs,
1585 const std::string& attr_name) {
1586 std::vector<std::string> result;
1587 for (const auto& arg_attr : dict_attrs) {
1588 if (!arg_attr) continue;
1589
1590 auto attrs = arg_attr.getValue();
1591 for (const auto attr : attrs) {
1592 if (attr.first.str() == attr_name) {
1593 auto array_attr = attr.second.dyn_cast_or_null<mlir::ArrayAttr>();
1594 if (!array_attr || array_attr.empty()) continue;
1595 auto string_attr = array_attr[0].dyn_cast_or_null<mlir::StringAttr>();
1596 if (!string_attr) continue;
1597 result.push_back(string_attr.getValue().str());
1598 }
1599 }
1600 }
1601 return result;
1602 }
1603
BuildSignaturedef(FuncOp main_op,const std::string & saved_model_tag,const uint32_t subgraph_index,tensorflow::OpOrArgNameMapper & name_mapper)1604 std::vector<SignatureDefData> BuildSignaturedef(
1605 FuncOp main_op, const std::string& saved_model_tag,
1606 const uint32_t subgraph_index, tensorflow::OpOrArgNameMapper& name_mapper) {
1607 static const char kSignatureDefIndexPath[] = "tf_saved_model.index_path";
1608 static const char kEntryFunctionAttributes[] = "tf.entry_function";
1609
1610 // Fetch inputs and outputs from the signature.
1611 llvm::SmallVector<mlir::DictionaryAttr, 4> arg_attrs, res_attrs;
1612 main_op.getAllArgAttrs(arg_attrs);
1613 main_op.getAllResultAttrs(res_attrs);
1614 std::vector<std::string> sig_def_inputs =
1615 GetStringsFromDictionaryAttr(arg_attrs, kSignatureDefIndexPath);
1616 std::vector<std::string> sig_def_outputs =
1617 GetStringsFromDictionaryAttr(res_attrs, kSignatureDefIndexPath);
1618
1619 // If no defined saved model signature, then return empty list.
1620 // This can happen when we are converting model not from SavedModel.
1621 if (sig_def_inputs.empty() && sig_def_outputs.empty()) return {};
1622
1623 // Fetch function inputs and outputs tensor names.
1624 auto dict_attr =
1625 main_op->getAttrOfType<mlir::DictionaryAttr>(kEntryFunctionAttributes);
1626 if (!dict_attr) return {};
1627
1628 // Get Input and output tensor names from attribute.
1629 llvm::SmallVector<llvm::StringRef, 2> input_names =
1630 GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"inputs");
1631 llvm::SmallVector<llvm::StringRef, 2> output_names =
1632 GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"outputs");
1633
1634 // Verify input size match the number of arguments.
1635 if (input_names.size() != main_op.getNumArguments()) {
1636 main_op.emitWarning() << "invalid entry function specification";
1637 return {};
1638 }
1639 // Verify output size match the number of arguments.
1640 auto term = main_op.back().getTerminator();
1641 if (output_names.size() != term->getNumOperands()) {
1642 main_op.emitWarning() << "output names (" << output_names.size()
1643 << ") != terminator operands ("
1644 << term->getNumOperands() << ")";
1645 return {};
1646 }
1647 // Verify number of tensors for inputs and outputs matches size
1648 // of the list in the signature def.
1649 if (input_names.size() != sig_def_inputs.size() ||
1650 output_names.size() != sig_def_outputs.size()) {
1651 main_op.emitWarning(
1652 "Mismatch between signature def inputs/outputs and main function "
1653 "arguments.");
1654 return {};
1655 }
1656 // Exported method name.
1657 auto exported_name =
1658 main_op->getAttrOfType<mlir::ArrayAttr>("tf_saved_model.exported_names");
1659 if (exported_name.empty()) {
1660 main_op.emitError("Empty exported names for main Function");
1661 return {};
1662 }
1663 // Fill the SignatureDefData container.
1664 // We create vector of size 1 as TFLite now supports only 1 signatureDef.
1665 std::vector<SignatureDefData> result(1);
1666 for (int i = 0; i < input_names.size(); ++i) {
1667 result[0].inputs[sig_def_inputs[i]] = input_names[i].str();
1668 }
1669 for (int i = 0; i < output_names.size(); ++i) {
1670 // Fetch the name from the actual operand and not rely on names from
1671 // outputs as deduping can make them invalid after conversion.
1672 auto& operand = term->getOpOperand(i);
1673 auto unique_name = std::string(name_mapper.GetUniqueName(operand.get()));
1674 result[0].outputs[sig_def_outputs[i]] = unique_name;
1675 }
1676 if (auto name_attr = exported_name[0].dyn_cast_or_null<StringAttr>())
1677 result[0].signature_key = name_attr.getValue().str();
1678 result[0].subgraph_index = subgraph_index;
1679 return result;
1680 }
1681
GetList(const int subgraph_index,const std::map<std::string,std::string> & items)1682 std::vector<BufferOffset<tflite::TensorMap>> Translator::GetList(
1683 const int subgraph_index, const std::map<std::string, std::string>& items) {
1684 std::vector<BufferOffset<tflite::TensorMap>> result;
1685 for (const auto& item : items) {
1686 auto name_buf = builder_.CreateString(item.first);
1687 tflite::TensorMapBuilder tensor_map_builder(builder_);
1688 tensor_map_builder.add_name(name_buf);
1689 tensor_map_builder.add_tensor_index(
1690 tensor_index_map_[subgraph_index][item.second]);
1691 result.push_back(tensor_map_builder.Finish());
1692 }
1693 return result;
1694 }
1695
1696 Optional<VectorBufferOffset<BufferOffset<tflite::SignatureDef>>>
CreateSignatureDefs(const std::vector<SignatureDefData> & signature_defs)1697 Translator::CreateSignatureDefs(
1698 const std::vector<SignatureDefData>& signature_defs) {
1699 std::vector<BufferOffset<tflite::SignatureDef>> signature_defs_buffer;
1700 // When we export each function in the module op, intentionally, we export the
1701 // entry functions at the beginning of the subgraph list and the
1702 // subgraph_index is the index in entry functions and at the same, is the
1703 // index in the subgraph list.
1704 int subgraph_index = 0;
1705 for (const auto& signature_def_data : signature_defs) {
1706 auto inputs = GetList(subgraph_index, signature_def_data.inputs);
1707 auto outputs = GetList(subgraph_index, signature_def_data.outputs);
1708 auto inputs_buf = builder_.CreateVector(inputs);
1709 auto outputs_buf = builder_.CreateVector(outputs);
1710 auto signature_key_buf =
1711 builder_.CreateString(signature_def_data.signature_key);
1712 tflite::SignatureDefBuilder sig_def_builder(builder_);
1713 sig_def_builder.add_inputs(inputs_buf);
1714 sig_def_builder.add_outputs(outputs_buf);
1715 sig_def_builder.add_signature_key(signature_key_buf);
1716 sig_def_builder.add_subgraph_index(signature_def_data.subgraph_index);
1717 signature_defs_buffer.push_back(sig_def_builder.Finish());
1718 ++subgraph_index;
1719 }
1720
1721 return builder_.CreateVector(signature_defs_buffer);
1722 }
1723
UpdateEntryFunction(ModuleOp module)1724 bool UpdateEntryFunction(ModuleOp module) {
1725 if (module.lookupSymbol<FuncOp>("main") != nullptr) {
1726 // We already have an entry function.
1727 return true;
1728 }
1729
1730 int entry_func_count = 0;
1731 FuncOp entry_func = nullptr;
1732 for (auto fn : module.getOps<FuncOp>()) {
1733 auto attrs = fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
1734 if (!attrs || attrs.empty()) continue;
1735 ++entry_func_count;
1736 entry_func = fn;
1737 }
1738
1739 // We should have at least one entry function.
1740 if (entry_func_count == 0) return false;
1741
1742 if (entry_func_count == 1) {
1743 // Update the entry func to main when the entry func is only & one.
1744 entry_func.setName("main");
1745 }
1746 return true;
1747 }
1748
Translate(ModuleOp module,bool emit_builtin_tflite_ops,bool emit_select_tf_ops,bool emit_custom_ops,bool allow_all_select_tf_ops,const std::unordered_set<std::string> & select_user_tf_ops,const std::unordered_set<std::string> & tags,OpOrArgNameMapper * op_or_arg_name_mapper,const std::map<std::string,std::string> & metadata)1749 Optional<std::string> Translator::Translate(
1750 ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
1751 bool emit_custom_ops, bool allow_all_select_tf_ops,
1752 const std::unordered_set<std::string>& select_user_tf_ops,
1753 const std::unordered_set<std::string>& tags,
1754 OpOrArgNameMapper* op_or_arg_name_mapper,
1755 const std::map<std::string, std::string>& metadata) {
1756 OpOrArgLocNameMapper default_op_or_arg_name_mapper;
1757 if (!op_or_arg_name_mapper)
1758 op_or_arg_name_mapper = &default_op_or_arg_name_mapper;
1759 if (!UpdateEntryFunction(module)) return llvm::None;
1760 if (!IsValidTFLiteMlirModule(module)) return llvm::None;
1761 Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops,
1762 emit_custom_ops, allow_all_select_tf_ops,
1763 select_user_tf_ops, tags, op_or_arg_name_mapper,
1764 metadata);
1765 return translator.TranslateInternal();
1766 }
1767
TranslateInternal()1768 Optional<std::string> Translator::TranslateInternal() {
1769 // A list of named regions in the module with main function being the first in
1770 // the list. The main function is required as the first subgraph in the model
1771 // is entry point for the model.
1772 std::vector<std::pair<std::string, Region*>> named_regions;
1773 named_regions.reserve(std::distance(module_.begin(), module_.end()));
1774
1775 int subgraph_idx = 0;
1776
1777 // Entry functions for signature defs.
1778 std::vector<FuncOp> entry_functions;
1779 std::vector<FuncOp> non_entry_functions;
1780 FuncOp main_fn = module_.lookupSymbol<FuncOp>("main");
1781 if (main_fn != nullptr) {
1782 // Treat the main function as a signature def when the given main function
1783 // contains on the tf.entry_function attribute.
1784 auto attrs =
1785 main_fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
1786 if (attrs && !attrs.empty()) {
1787 entry_functions.push_back(main_fn);
1788 } else {
1789 non_entry_functions.push_back(main_fn);
1790 }
1791 }
1792
1793 // Walk over the module collection ops with functions and while ops.
1794 module_.walk([&](FuncOp fn) {
1795 if (main_fn == fn) return WalkResult::advance();
1796 auto attrs = fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
1797 if (attrs && !attrs.empty()) {
1798 entry_functions.push_back(fn);
1799 } else {
1800 non_entry_functions.push_back(fn);
1801 }
1802 return WalkResult::advance();
1803 });
1804
1805 // Assign the subgraph index. Among the given functions, it will put entry
1806 // functions at the beginning of the list of the subgrahs.
1807 for (auto fn : entry_functions) {
1808 subgraph_index_map_[fn.getName().str()] = subgraph_idx++;
1809 named_regions.emplace_back(fn.getName().str(), &fn.getBody());
1810 }
1811 for (auto fn : non_entry_functions) {
1812 subgraph_index_map_[fn.getName().str()] = subgraph_idx++;
1813 named_regions.emplace_back(fn.getName().str(), &fn.getBody());
1814 }
1815
1816 // Build subgraph for each of the named regions.
1817 std::vector<BufferOffset<tflite::SubGraph>> subgraphs;
1818 subgraphs.reserve(named_regions.size());
1819 int first_failed_func = -1;
1820
1821 // When we export each function in the module op, intentionally, we export the
1822 // entry functions at the beginning of the subgraph list and the
1823 // subgraph_index is the index in entry functions and at the same, is the
1824 // index in the subgraph list.
1825 int subgraph_index = 0;
1826 for (auto it : llvm::enumerate(named_regions)) {
1827 auto subgraph_or =
1828 BuildSubGraph(it.value().first, it.value().second, subgraph_index);
1829 if (!subgraph_or) {
1830 if (first_failed_func == -1)
1831 // Record the index of the first region that cannot be converted.
1832 // Keep looping through all subgraphs in the module to make sure that
1833 // we collect the list of missing ops from the entire module.
1834 first_failed_func = it.index();
1835 } else {
1836 subgraphs.push_back(*subgraph_or);
1837 ++subgraph_index;
1838 }
1839 }
1840
1841 if (!resource_ops_.empty()) {
1842 std::string resource_ops_summary =
1843 GetOpsSummary(resource_ops_, /*summary_title=*/"Resource");
1844 LOG(WARNING) << "Graph contains the following resource op(s), that use(s) "
1845 "resource type. Currently, the "
1846 "resource type is not natively supported in TFLite. Please "
1847 "consider not using the resource type if there are issues "
1848 "with either TFLite converter or TFLite runtime:\n"
1849 << resource_ops_summary;
1850 }
1851
1852 if (!flex_ops_.empty()) {
1853 std::string flex_ops_summary =
1854 GetOpsSummary(flex_ops_, /*summary_title=*/"Flex");
1855 LOG(WARNING) << "TFLite interpreter needs to link Flex delegate in order "
1856 "to run the model since it contains the following Select TF"
1857 "op(s):\n"
1858 << flex_ops_summary
1859 << "\nSee instructions: "
1860 "https://www.tensorflow.org/lite/guide/ops_select";
1861 }
1862
1863 if (!custom_ops_.empty()) {
1864 std::string custom_ops_summary =
1865 GetOpsSummary(custom_ops_, /*summary_title=*/"Custom");
1866 LOG(WARNING) << "The following operation(s) need TFLite custom op "
1867 "implementation(s):\n"
1868 << custom_ops_summary
1869 << "\nSee instructions: "
1870 "https://www.tensorflow.org/lite/guide/ops_custom";
1871 }
1872
1873 if (first_failed_func != -1) {
1874 std::string failed_flex_ops_summary =
1875 GetOpsSummary(failed_flex_ops_, /*summary_title=*/"TF Select");
1876 std::string failed_custom_ops_summary =
1877 GetOpsSummary(failed_custom_ops_, /*summary_title=*/"Custom");
1878 std::string err;
1879 if (!failed_flex_ops_.empty())
1880 err +=
1881 "\nSome ops are not supported by the native TFLite runtime, you can "
1882 "enable TF kernels fallback using TF Select. See instructions: "
1883 "https://www.tensorflow.org/lite/guide/ops_select \n" +
1884 failed_flex_ops_summary + "\n";
1885 if (!failed_custom_ops_.empty())
1886 err +=
1887 "\nSome ops in the model are custom ops, "
1888 "See instructions to implement "
1889 "custom ops: https://www.tensorflow.org/lite/guide/ops_custom \n" +
1890 failed_custom_ops_summary + "\n";
1891
1892 auto& failed_region = named_regions[first_failed_func];
1893 return failed_region.second->getParentOp()->emitError()
1894 << "failed while converting: '" << failed_region.first
1895 << "': " << err,
1896 llvm::None;
1897 }
1898
1899 // Log MAC count.
1900 int64_t ops_count;
1901 if (EstimateArithmeticCount(&ops_count)) {
1902 const int64_t million = 1e6;
1903 const int64_t billion = 1e9;
1904 std::string flops_str;
1905 std::string mac_str;
1906 if (ops_count < 10000) {
1907 flops_str = absl::StrFormat("%ld ", ops_count);
1908 mac_str = absl::StrFormat("%ld ", ops_count / 2);
1909 } else if (ops_count < billion) {
1910 flops_str =
1911 absl::StrFormat("%.3f M ", static_cast<double>(ops_count) / million);
1912 mac_str = absl::StrFormat("%.3f M ",
1913 static_cast<double>(ops_count / 2) / million);
1914 } else {
1915 flops_str =
1916 absl::StrFormat("%.3f G ", static_cast<double>(ops_count) / billion);
1917 mac_str = absl::StrFormat("%.3f G ",
1918 static_cast<double>(ops_count / 2) / billion);
1919 }
1920 std::string mac_out_str;
1921 llvm::raw_string_ostream os(mac_out_str);
1922 os << "Estimated count of arithmetic ops: " << flops_str
1923 << " ops, equivalently " << mac_str << " MACs"
1924 << "\n";
1925 os.flush();
1926 LOG(INFO) << mac_out_str;
1927 std::cout << mac_out_str;
1928 }
1929
1930 std::string model_description;
1931 if (auto attr = module_->getAttrOfType<StringAttr>("tfl.description")) {
1932 model_description = attr.getValue().str();
1933 } else {
1934 model_description = "MLIR Converted.";
1935 }
1936
1937 // Build the model and finish the model building process.
1938 auto description = builder_.CreateString(model_description.data());
1939 VectorBufferOffset<int32_t> metadata_buffer = 0; // Deprecated
1940 auto metadata = CreateMetadataVector();
1941 if (!metadata) return llvm::None;
1942
1943 std::vector<SignatureDefData> signature_defs_vec;
1944 subgraph_index = 0;
1945 // Build SignatureDefs for the tf.entry_function based func ops.
1946 for (auto fn : entry_functions) {
1947 auto signature_defs = BuildSignaturedef(
1948 fn, saved_model_tags_.empty() ? "" : *saved_model_tags_.begin(),
1949 subgraph_index, name_mapper_);
1950 for (const auto& signature_def : signature_defs) {
1951 signature_defs_vec.push_back(signature_def);
1952 }
1953 // When we export each function in the module op, intentionally, we export
1954 // the entry functions at the beginning of the subgraph list and the
1955 // subgraph_index is the index in entry functions and at the same, is the
1956 // index in the subgraph list.
1957 ++subgraph_index;
1958 }
1959 auto signature_defs = CreateSignatureDefs(signature_defs_vec);
1960
1961 auto model = tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION,
1962 builder_.CreateVector(opcodes_),
1963 builder_.CreateVector(subgraphs),
1964 description, builder_.CreateVector(buffers_),
1965 metadata_buffer, *metadata, *signature_defs);
1966 tflite::FinishModelBuffer(builder_, model);
1967 tflite::UpdateOpVersion(builder_.GetBufferPointer());
1968 tflite::UpdateMinimumRuntimeVersionForModel(builder_.GetBufferPointer());
1969
1970 // Return serialized string for the built FlatBuffer.
1971 return std::string(reinterpret_cast<const char*>(builder_.GetBufferPointer()),
1972 builder_.GetSize());
1973 }
1974
BuildSparsityParameters(const mlir::TFL::SparsityParameterAttr & s_attr)1975 BufferOffset<tflite::SparsityParameters> Translator::BuildSparsityParameters(
1976 const mlir::TFL::SparsityParameterAttr& s_attr) {
1977 const int dim_size = s_attr.dim_metadata().size();
1978 std::vector<flatbuffers::Offset<tflite::DimensionMetadata>> fb_dim_metadata(
1979 dim_size);
1980 for (int i = 0; i < dim_size; i++) {
1981 const auto dim_metadata =
1982 s_attr.dim_metadata()[i].dyn_cast<mlir::TFL::DimensionMetadataAttr>();
1983 if (dim_metadata.format().getValue() == "DENSE") {
1984 fb_dim_metadata[i] =
1985 tflite::CreateDimensionMetadata(builder_, tflite::DimensionType_DENSE,
1986 dim_metadata.dense_size().getInt());
1987
1988 } else {
1989 auto segments = dim_metadata.segments();
1990 std::vector<int> vector_segments(segments.size(), 0);
1991 for (int j = 0, end = segments.size(); j < end; j++) {
1992 vector_segments[j] = segments[j].dyn_cast<mlir::IntegerAttr>().getInt();
1993 }
1994 tflite::SparseIndexVector segments_type;
1995 BufferOffset<void> array_segments;
1996 // The segment array is sorted.
1997 // TODO(b/147449640): Clean this up with util functions.
1998 int max_of_segments = vector_segments[segments.size() - 1];
1999 if (max_of_segments <= UINT8_MAX) {
2000 segments_type = tflite::SparseIndexVector_Uint8Vector;
2001 std::vector<uint8_t> uint8_vector(vector_segments.begin(),
2002 vector_segments.end());
2003 array_segments = tflite::CreateUint8Vector(
2004 builder_, builder_.CreateVector(uint8_vector))
2005 .Union();
2006 } else if (max_of_segments <= UINT16_MAX) {
2007 segments_type = tflite::SparseIndexVector_Uint16Vector;
2008 std::vector<uint16_t> uint16_vector(vector_segments.begin(),
2009 vector_segments.end());
2010 array_segments = tflite::CreateUint16Vector(
2011 builder_, builder_.CreateVector(uint16_vector))
2012 .Union();
2013 } else {
2014 segments_type = tflite::SparseIndexVector_Int32Vector;
2015 array_segments = tflite::CreateInt32Vector(
2016 builder_, builder_.CreateVector(vector_segments))
2017 .Union();
2018 }
2019
2020 auto indices = dim_metadata.indices();
2021 std::vector<int> vector_indices(indices.size(), 0);
2022 int max_of_indices = 0;
2023 for (int j = 0, end = indices.size(); j < end; j++) {
2024 vector_indices[j] = indices[j].dyn_cast<mlir::IntegerAttr>().getInt();
2025 if (vector_indices[j] > max_of_indices) {
2026 max_of_indices = vector_indices[j];
2027 }
2028 }
2029 tflite::SparseIndexVector indices_type;
2030 BufferOffset<void> array_indices;
2031 if (max_of_indices <= UINT8_MAX) {
2032 indices_type = tflite::SparseIndexVector_Uint8Vector;
2033 std::vector<uint8_t> uint8_vector(vector_indices.begin(),
2034 vector_indices.end());
2035 array_indices = tflite::CreateUint8Vector(
2036 builder_, builder_.CreateVector(uint8_vector))
2037 .Union();
2038 } else if (max_of_indices <= UINT16_MAX) {
2039 indices_type = tflite::SparseIndexVector_Uint16Vector;
2040 std::vector<uint16_t> uint16_vector(vector_indices.begin(),
2041 vector_indices.end());
2042 array_indices = tflite::CreateUint16Vector(
2043 builder_, builder_.CreateVector(uint16_vector))
2044 .Union();
2045 } else {
2046 indices_type = tflite::SparseIndexVector_Int32Vector;
2047 array_indices = tflite::CreateInt32Vector(
2048 builder_, builder_.CreateVector(vector_indices))
2049 .Union();
2050 }
2051
2052 fb_dim_metadata[i] = tflite::CreateDimensionMetadata(
2053 builder_, tflite::DimensionType_SPARSE_CSR, 0, segments_type,
2054 array_segments, indices_type, array_indices);
2055 }
2056 }
2057
2058 std::vector<int> traversal_order(dim_size);
2059 for (int i = 0; i < dim_size; i++) {
2060 traversal_order[i] =
2061 s_attr.traversal_order()[i].dyn_cast<mlir::IntegerAttr>().getInt();
2062 }
2063 const int block_map_size = s_attr.block_map().size();
2064 std::vector<int> block_map(block_map_size);
2065 for (int i = 0; i < block_map_size; i++) {
2066 block_map[i] = s_attr.block_map()[i].dyn_cast<mlir::IntegerAttr>().getInt();
2067 }
2068
2069 return tflite::CreateSparsityParameters(
2070 builder_, builder_.CreateVector(traversal_order),
2071 builder_.CreateVector(block_map), builder_.CreateVector(fb_dim_metadata));
2072 }
2073
2074 } // namespace
2075
2076 namespace tflite {
2077 // TODO(hinsu): Support all valid MLIR modules in TFLite dialect by supporting
2078 // the following:
2079 //
2080 // * Quantization
2081 // * Ops with variable tensors
2082 //
MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,const FlatbufferExportOptions & options,std::string * serialized_flatbuffer)2083 bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,
2084 const FlatbufferExportOptions& options,
2085 std::string* serialized_flatbuffer) {
2086 auto maybe_translated = Translator::Translate(
2087 module, options.emit_builtin_tflite_ops, options.emit_select_tf_ops,
2088 options.emit_custom_ops, options.allow_all_select_tf_ops,
2089 options.select_user_tf_ops, options.saved_model_tags,
2090 options.op_or_arg_name_mapper, options.metadata);
2091 if (!maybe_translated) return false;
2092 *serialized_flatbuffer = std::move(*maybe_translated);
2093 return true;
2094 }
2095
2096 } // namespace tflite
2097