1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
17 
18 #include <stddef.h>
19 #include <stdlib.h>
20 
21 #include <cstdint>
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include <vector>
26 
27 #include "absl/base/attributes.h"
28 #include "absl/container/flat_hash_map.h"
29 #include "absl/container/flat_hash_set.h"
30 #include "absl/strings/match.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_format.h"
33 #include "absl/strings/str_join.h"
34 #include "absl/strings/string_view.h"
35 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
36 #include "flatbuffers/flexbuffers.h"  // from @flatbuffers
37 #include "llvm/ADT/ArrayRef.h"
38 #include "llvm/ADT/DenseMap.h"
39 #include "llvm/ADT/None.h"
40 #include "llvm/ADT/Optional.h"
41 #include "llvm/ADT/STLExtras.h"
42 #include "llvm/ADT/StringRef.h"
43 #include "llvm/Support/Casting.h"
44 #include "llvm/Support/CommandLine.h"
45 #include "llvm/Support/FormatVariadic.h"
46 #include "llvm/Support/ToolOutputFile.h"
47 #include "llvm/Support/raw_ostream.h"
48 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
49 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
50 #include "mlir/IR/Attributes.h"  // from @llvm-project
51 #include "mlir/IR/Builders.h"  // from @llvm-project
52 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
53 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
54 #include "mlir/IR/Location.h"  // from @llvm-project
55 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
56 #include "mlir/IR/Operation.h"  // from @llvm-project
57 #include "mlir/IR/Types.h"  // from @llvm-project
58 #include "mlir/IR/Value.h"  // from @llvm-project
59 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
60 #include "mlir/Translation.h"  // from @llvm-project
61 #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
62 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
63 #include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h"
64 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
65 #include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h"
66 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
67 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
68 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
69 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
70 #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
71 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
72 #include "tensorflow/compiler/xla/statusor.h"
73 #include "tensorflow/core/framework/attr_value.pb.h"
74 #include "tensorflow/core/framework/node_def.pb.h"
75 #include "tensorflow/core/platform/errors.h"
76 #include "tensorflow/core/platform/logging.h"
77 #include "tensorflow/core/platform/status.h"
78 #include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h"
79 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
80 #include "tensorflow/lite/schema/schema_conversion_utils.h"
81 #include "tensorflow/lite/schema/schema_generated.h"
82 #include "tensorflow/lite/string_util.h"
83 #include "tensorflow/lite/tools/versioning/op_version.h"
84 #include "tensorflow/lite/tools/versioning/runtime_version.h"
85 #include "tensorflow/lite/version.h"
86 
87 using llvm::dyn_cast;
88 using llvm::formatv;
89 using llvm::isa;
90 using llvm::Optional;
91 using llvm::StringRef;
92 using llvm::Twine;
93 using mlir::Dialect;
94 using mlir::ElementsAttr;
95 using mlir::FuncOp;
96 using mlir::MLIRContext;
97 using mlir::ModuleOp;
98 using mlir::NoneType;
99 using mlir::Operation;
100 using mlir::Region;
101 using mlir::StringAttr;
102 using mlir::TensorType;
103 using mlir::Type;
104 using mlir::UnknownLoc;
105 using mlir::Value;
106 using mlir::WalkResult;
107 using tensorflow::OpOrArgLocNameMapper;
108 using tensorflow::OpOrArgNameMapper;
109 using tensorflow::Status;
110 using tflite::flex::IsAllowlistedFlexOp;
111 using xla::StatusOr;
112 
113 template <typename T>
114 using BufferOffset = flatbuffers::Offset<T>;
115 
116 template <typename T>
117 using VectorBufferOffset = flatbuffers::Offset<flatbuffers::Vector<T>>;
118 
119 using CustomOptionsOffset = VectorBufferOffset<uint8_t>;
120 
121 namespace error = tensorflow::error;
122 namespace tfl = mlir::TFL;
123 
124 ABSL_CONST_INIT const absl::string_view kFlexOpNamePrefix = "Flex";
125 
126 // Use the same initial buffer size in the flatbuffer builder as the initial
127 // size used by the TOCO export. (The rationale for this choice is not documented.)
128 constexpr size_t kInitialBufferSize = 10240;
129 
130 // Set `is_signed` to false if the `type` is an 8-bit unsigned integer type.
131 // Since TFLite doesn't support unsigned variants of other types, an error is
132 // returned if `is_signed` is set to false for any other type.
133 static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
134                                                   bool is_signed = true) {
135   if (!is_signed && type.isSignlessInteger(8)) {
136     return tflite::TensorType_UINT8;
137   }
138   if (!is_signed) {
139     return Status(error::INVALID_ARGUMENT,
140                   "'isSigned' can only be set for 8-bits integer type");
141   }
142 
143   if (type.isF32()) {
144     return tflite::TensorType_FLOAT32;
145   } else if (type.isF16()) {
146     return tflite::TensorType_FLOAT16;
147   } else if (type.isF64()) {
148     return tflite::TensorType_FLOAT64;
149   } else if (type.isa<mlir::TF::StringType>()) {
150     return tflite::TensorType_STRING;
151   } else if (type.isa<mlir::TF::Quint8Type>()) {
152     return tflite::TensorType_UINT8;
153   } else if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
154     auto ftype = complex_type.getElementType();
155     if (ftype.isF32()) {
156       return tflite::TensorType_COMPLEX64;
157     }
158     if (ftype.isF64()) {
159       return tflite::TensorType_COMPLEX128;
160     }
161     return Status(error::INVALID_ARGUMENT, "Unsupported type");
162   } else if (auto itype = type.dyn_cast<mlir::IntegerType>()) {
163     switch (itype.getWidth()) {
164       case 1:
165         return tflite::TensorType_BOOL;
166       case 8:
167         return itype.isUnsigned() ? tflite::TensorType_UINT8
168                                   : tflite::TensorType_INT8;
169       case 16:
170         return tflite::TensorType_INT16;
171       case 32:
172         return itype.isUnsigned() ? tflite::TensorType_UINT32
173                                   : tflite::TensorType_INT32;
174       case 64:
175         return itype.isUnsigned() ? tflite::TensorType_UINT64
176                                   : tflite::TensorType_INT64;
177     }
178   } else if (auto q_uniform_type =
179                  type.dyn_cast<mlir::quant::UniformQuantizedType>()) {
180     return GetTFLiteType(q_uniform_type.getStorageType(),
181                          q_uniform_type.isSigned());
182   } else if (auto q_peraxis_type =
183                  type.dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
184     return GetTFLiteType(q_peraxis_type.getStorageType(),
185                          q_peraxis_type.isSigned());
186   } else if (auto q_calibrated_type =
187                  type.dyn_cast<mlir::quant::CalibratedQuantizedType>()) {
188     return GetTFLiteType(q_calibrated_type.getExpressedType());
189   } else if (type.isa<mlir::TF::ResourceType>()) {
190     return tflite::TensorType_RESOURCE;
191   } else if (type.isa<mlir::TF::VariantType>()) {
192     return tflite::TensorType_VARIANT;
193   }
194   // TFLite export fills FLOAT32 for unknown data types. Return an error
195   // for now for safety; this could be revisited when required.
196   return Status(error::INVALID_ARGUMENT, "Unsupported type");
197 }
198 
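// Returns true if `op` is one of the constant ops recognized by this exporter.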
199 static bool IsConst(Operation* op) {
200   return isa<mlir::ConstantOp, mlir::TF::ConstOp, tfl::ConstOp, tfl::QConstOp,
201              tfl::SparseConstOp, tfl::SparseQConstOp>(op);
202 }
203 
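// Returns true if any operand or result of `op` has a TensorFlow resource
// element type.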
204 static bool IsTFResourceOp(Operation* op) {
205   for (const auto& operand : op->getOperands()) {
206     auto elementType = getElementTypeOrSelf(operand.getType());
207     if (elementType.isa<mlir::TF::ResourceType>()) {
208       return true;
209     }
210   }
211   for (const auto& result : op->getResults()) {
212     auto elementType = getElementTypeOrSelf(result.getType());
213     if (elementType.isa<mlir::TF::ResourceType>()) {
214       return true;
215     }
216   }
217   return false;
218 }
219 
220 // Create description of operation that could not be converted.
221 static std::string GetOpDescriptionForDebug(Operation* inst) {
222   const int kLargeElementsAttr = 16;
223   std::string op_str;
224   llvm::raw_string_ostream os(op_str);
225   inst->getName().print(os);
226   os << "(";
227   if (!inst->getOperandTypes().empty()) {
228     bool first = true;
229     for (Type operand_type : inst->getOperandTypes()) {
230       os << (!first ? ", " : "");
231       first = false;
232       os << operand_type;
233     }
234   }
235   os << ") -> (";
236   if (!inst->getResultTypes().empty()) {
237     bool first = true;
238     for (Type result_type : inst->getResultTypes()) {
239       os << (!first ? ", " : "");
240       first = false;
241       os << result_type;
242     }
243   }
244   os << ")";
245   // Print out attributes except for large ElementsAttrs (which are rarely the
246   // reason legalization didn't happen).
247   if (!inst->getAttrDictionary().empty()) {
248     os << " : {";
249     bool first = true;
250     for (auto& named_attr : inst->getAttrDictionary()) {
251       os << (!first ? ", " : "");
252       first = false;
253       named_attr.first.print(os);
254       os << " = ";
255       if (auto element_attr = named_attr.second.dyn_cast<ElementsAttr>()) {
256         if (element_attr.getNumElements() <= kLargeElementsAttr) {
257           element_attr.print(os);
258         } else {
259           os << "<large>";
260         }
261       } else {
262         named_attr.second.print(os);
263       }
264     }
265     os << "}";
266   }
267   return os.str();
268 }
269 
270 // Create a summary with the given information regarding op names and
271 // descriptions.
272 static std::string GetOpsSummary(
273     const std::map<std::string, std::set<std::string>>& ops,
274     const std::string& summary_title) {
275   std::string op_str;
276   llvm::raw_string_ostream os(op_str);
277 
278   std::vector<std::string> keys;
279   keys.reserve(ops.size());
280 
281   std::vector<std::string> values;
282   values.reserve(ops.size());
283 
284   for (auto const& op_name_and_details : ops) {
285     keys.push_back(op_name_and_details.first);
286     for (auto const& op_detail : op_name_and_details.second) {
287       values.push_back(op_detail);
288     }
289   }
290 
291   os << summary_title << " ops: " << absl::StrJoin(keys, ", ") << "\n";
292   os << "Details:\n\t" << absl::StrJoin(values, "\n\t");
293 
294   return os.str();
295 }
296 
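// Returns true if `value` has a type that can be represented as a TFLite tensor;
// otherwise reports the problem through `error_handler` and returns false.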
297 template <typename T>
298 static bool HasValidTFLiteType(Value value, T& error_handler) {
299   // None type is allowed to represent unspecified operands.
300   if (value.getType().isa<NoneType>()) return true;
301 
302   auto type = value.getType().dyn_cast<TensorType>();
303   if (!type) {
304     if (auto op = value.getDefiningOp()) {
305       error_handler.emitError()
306           << '\'' << op << "' should produce value of tensor type instead of "
307           << value.getType();
308       return false;
309     }
310     error_handler.emitError("expected tensor type, got ") << value.getType();
311     return false;
312   }
313 
314   Type element_type = type.getElementType();
315   auto status = GetTFLiteType(element_type);
316   if (!status.ok()) {
317     return error_handler.emitError(
318                formatv("Failed to convert element type '{0}': {1}",
319                        element_type, status.status().error_message())),
320            false;
321   }
322   return true;
323 }
324 
325 // Returns true if the module holds all the invariants expected by the
326 // Translator class.
327 // TODO(hinsu): Now that translation is done by making a single pass over the
328 // MLIR module, consider inlining these validation checks at the place where
329 // these invariants are assumed instead of checking upfront.
330 static bool IsValidTFLiteMlirModule(ModuleOp module) {
331   MLIRContext* context = module.getContext();
332 
333   // Verify that module has a function named main.
334   FuncOp main_fn = module.lookupSymbol<FuncOp>("main");
335   if (!main_fn) {
336     int entry_func_count = 0;
337     for (auto fn : module.getOps<FuncOp>()) {
338       auto attrs = fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
339       if (attrs && !attrs.empty()) {
340         ++entry_func_count;
341       }
342     }
343 
344     // Verify that the module has at least one entry function.
345     if (entry_func_count == 0) {
346       return emitError(UnknownLoc::get(context),
347                        "should have at least one entry function"),
348              false;
349     }
350   }
351 
352   for (auto fn : module.getOps<FuncOp>()) {
353     if (!llvm::hasSingleElement(fn)) {
354       return fn.emitError("should have exactly one basic block"), false;
355     }
356     auto& bb = fn.front();
357 
358     for (auto arg : bb.getArguments()) {
359       if (!HasValidTFLiteType(arg, fn)) {
360         auto elementType = getElementTypeOrSelf(arg.getType());
361         if (elementType.isa<mlir::TF::VariantType>()) {
362           return fn.emitError(
363                      "function argument uses variant type. Currently, the "
364                      "variant type is not natively supported in TFLite. Please "
365                      "consider not using the variant type: ")
366                      << arg.getType(),
367                  false;
368         }
369         return fn.emitError("invalid TFLite type: ") << arg.getType(), false;
370       }
371     }
372 
373     // Verify that all operations except the terminator produce only results
374     // with types supported by TFLite.
375     for (auto& inst : bb) {
376       if (inst.hasTrait<mlir::OpTrait::IsTerminator>()) break;
377 
378       for (auto result : inst.getResults()) {
379         if (!HasValidTFLiteType(result, inst)) {
380           auto elementType = getElementTypeOrSelf(result.getType());
381           if (elementType.isa<mlir::TF::VariantType>()) {
382             return inst.emitError(
383                        "operand result uses variant type. Currently, the "
384                        "variant type is not natively supported in TFLite. "
385                        "Please "
386                        "consider not using the variant type: ")
387                        << result.getType(),
388                    false;
389           }
390           return fn.emitError("invalid TFLite type: ") << result.getType(),
391                  false;
392         }
393       }
394     }
395   }
396 
397   return true;
398 }
399 
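// Converts the given TF-dialect operation into a TensorFlow NodeDef. Emits an
// error on the op and returns an empty pointer if the conversion fails.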
400 static std::unique_ptr<::tensorflow::NodeDef> GetTensorFlowNodeDef(
401     ::mlir::Operation* inst) {
402   // We pass empty string for the original node_def name since Flex runtime
403   // does not care about this being set correctly on node_def. There is no
404   // "easy" (see b/120948529) way yet to get this from MLIR inst.
405   auto status_or_node_def = tensorflow::ConvertTFDialectOpToNodeDef(
406       inst, /*name=*/"", /*ignore_unregistered_attrs=*/true);
407   if (!status_or_node_def.ok()) {
408     inst->emitOpError(
409         Twine("failed to obtain TensorFlow nodedef with status: " +
410               status_or_node_def.status().ToString()));
411     return {};
412   }
413   return std::move(status_or_node_def.ValueOrDie());
414 }
415 
416 // Converts a mlir padding StringRef to TfLitePadding.
417 // Returns llvm::None if conversion fails.
418 static Optional<TfLitePadding> GetTflitePadding(Operation* inst,
419                                                 llvm::StringRef padding) {
420   const tflite::Padding padding_attr =
421       std::move(llvm::StringSwitch<tflite::Padding>(padding)
422                     .Case("SAME", tflite::Padding_SAME)
423                     .Case("VALID", tflite::Padding_VALID));
424   if (padding_attr == tflite::Padding_SAME) {
425     return kTfLitePaddingSame;
426   }
427   if (padding_attr == tflite::Padding_VALID) {
428     return kTfLitePaddingValid;
429   }
430 
431   return inst->emitOpError() << "Invalid padding attribute: " << padding,
432          llvm::None;
433 }
434 
435 // Extracts TfLitePoolParams from a TFL custom op.
436 // Template parameter, TFLOp, should be a TFL custom op containing attributes
437 // generated from TfLitePoolParams.
438 // Returns llvm::None if conversion fails.
439 template <typename TFLOp>
440 static Optional<TfLitePoolParams> GetTflitePoolParams(Operation* inst,
441                                                       TFLOp op) {
442   TfLitePoolParams pool_params;
443   pool_params.stride_height = op.stride_h().getSExtValue();
444   pool_params.stride_width = op.stride_w().getSExtValue();
445   pool_params.filter_height = op.filter_h().getSExtValue();
446   pool_params.filter_width = op.filter_w().getSExtValue();
447   const auto padding = GetTflitePadding(inst, op.padding());
448   if (padding) {
449     pool_params.padding = *padding;
450     pool_params.activation = kTfLiteActNone;
451     pool_params.computed.padding = TfLitePaddingValues{0, 0, 0, 0};
452     return pool_params;
453   }
454 
455   return llvm::None;
456 }
457 
458 namespace {
459 
460 // Helper struct that wraps inputs/outputs of a single SignatureDef.
461 struct SignatureDefData {
462   // Note: maps are used here to make the ordering deterministic,
463   // purely to simplify testing.
464 
465   // Inputs defined in the signature def mapped to tensor names.
466   std::map<std::string, std::string> inputs;
467   // Outputs defined in the signature def mapped to tensor names.
468   std::map<std::string, std::string> outputs;
469   // Signature key.
470   std::string signature_key;
471   // Subgraph index.
472   uint32_t subgraph_index;
473 };
474 
475 // Translates an MLIR module in TFLite dialect to TFLite FlatBuffer.
476 class Translator {
477  public:
478   // Translates the given MLIR module into TFLite FlatBuffer format and returns
479   // the serialized output. Returns llvm::None on unsupported or invalid inputs,
480   // or on internal error.
481   static Optional<std::string> Translate(
482       ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
483       bool emit_custom_ops, bool allow_all_select_tf_ops,
484       const std::unordered_set<std::string>& select_user_tf_ops,
485       const std::unordered_set<std::string>& tags,
486       OpOrArgNameMapper* op_or_arg_name_mapper,
487       const std::map<std::string, std::string>& metadata);
488 
489  private:
490   enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp };
491   explicit Translator(ModuleOp module, bool emit_builtin_tflite_ops,
492                       bool emit_select_tf_ops, bool emit_custom_ops,
493                       bool allow_all_select_tf_ops,
494                       const std::unordered_set<std::string>& select_user_tf_ops,
495                       const std::unordered_set<std::string>& saved_model_tags,
496                       OpOrArgNameMapper* op_or_arg_name_mapper,
497                       const std::map<std::string, std::string>& metadata)
498       : module_(module),
499         name_mapper_(*op_or_arg_name_mapper),
500         builder_(kInitialBufferSize),
501         saved_model_tags_(saved_model_tags),
502         allow_all_select_tf_ops_(allow_all_select_tf_ops),
503         select_user_tf_ops_(select_user_tf_ops),
504         metadata_(metadata) {
505     // The first buffer must be empty according to the schema definition.
506     empty_buffer_ = tflite::CreateBuffer(builder_);
507     buffers_.push_back(empty_buffer_);
508     if (emit_builtin_tflite_ops) {
509       enabled_op_types_.emplace(OpType::kTfliteBuiltin);
510     }
511     if (emit_select_tf_ops) {
512       enabled_op_types_.emplace(OpType::kSelectTf);
513     }
514     if (emit_custom_ops) {
515       enabled_op_types_.emplace(OpType::kCustomOp);
516     }
517     tf_dialect_ =
518         module.getContext()->getOrLoadDialect<mlir::TF::TensorFlowDialect>();
519     tfl_dialect_ = module.getContext()
520                        ->getOrLoadDialect<mlir::TFL::TensorFlowLiteDialect>();
521     // Right now the TF executor dialect is still needed to build NodeDef.
522     module.getContext()
523         ->getOrLoadDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
524   }
525 
526   Optional<std::string> TranslateInternal();
527 
528   // Returns TFLite buffer populated with constant value if the operation is
529   // TFLite constant operation. Otherwise, returns an empty buffer. Emits error
530   // and returns llvm::None on failure.
531   Optional<BufferOffset<tflite::Buffer>> BuildBuffer(Operation* inst);
532 
533   // Build TFLite tensor from the given type. This function is for tfl.lstm
534   // intermediates, which should have UniformQuantizedType.
535   Optional<BufferOffset<tflite::Tensor>> BuildTensorFromType(
536       mlir::Type type, const std::string& name);
537 
538   // Builds TFLite tensor from the given value. `buffer_idx` is index of the
539   // corresponding buffer. Emits error and returns llvm::None on failure.
540   Optional<BufferOffset<tflite::Tensor>> BuildTensor(
541       Value value, const std::string& name, unsigned buffer_idx,
542       const Optional<BufferOffset<tflite::QuantizationParameters>>&
543           quant_parameters);
544 
545   // TODO(b/137395003): Legalize tf.IfOp to TFLite dialect, and change the
546   // following method to handle TFL::IfOp.
547   BufferOffset<tflite::Operator> BuildIfOperator(
548       mlir::TF::IfOp op, const std::vector<int32_t>& operands,
549       const std::vector<int32_t>& results);
550 
551   // Build while operator where cond & body are regions.
552   Optional<BufferOffset<tflite::Operator>> BuildWhileOperator(
553       mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
554       const std::vector<int32_t>& results);
555 
556   // Build call once operator.
557   BufferOffset<tflite::Operator> BuildCallOnceOperator(
558       mlir::TFL::CallOnceOp op, const std::vector<int32_t>& operands,
559       const std::vector<int32_t>& results);
560 
561   BufferOffset<tflite::Operator> BuildNumericVerifyOperator(
562       mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
563       const std::vector<int32_t>& results);
564 
565   BufferOffset<tflite::Operator> BuildCustomOperator(
566       Operation* inst, mlir::TFL::CustomOp op,
567       const std::vector<int32_t>& operands,
568       const std::vector<int32_t>& results);
569 
570   Optional<CustomOptionsOffset> CreateFlexOpCustomOptions(
571       const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
572 
573   Optional<CustomOptionsOffset> CreateCustomOpCustomOptions(
574       const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
575 
576   std::unique_ptr<flexbuffers::Builder> CreateFlexBuilderWithNodeAttrs(
577       const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
578 
579   // Returns opcode index for op identified by the op_name, if already
580   // available. Otherwise, creates a new OperatorCode using the given `builtin`
581   // operator and associates it with `op_name`.
582   uint32_t GetOpcodeIndex(const std::string& op_name,
583                           tflite::BuiltinOperator builtin);
584 
585   // Builds operator for the given operation with specified operand and result
586   // tensor indices. Emits an error and returns llvm::None on failure.
587   Optional<BufferOffset<tflite::Operator>> BuildOperator(
588       Operation* inst, std::vector<int32_t> operands,
589       const std::vector<int32_t>& results,
590       const std::vector<int32_t>& intermediates);
591 
592   // Returns the quantization parameters for output value of "quant.stats" op.
593   BufferOffset<tflite::QuantizationParameters>
594   GetQuantizationForQuantStatsOpOutput(mlir::quant::StatisticsOp stats_op);
595 
596   // Builds a subgraph with the given name out of the region corresponding
597   // either to a function's body or to a while op.
598   Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(
599       const std::string& name, Region* region, const int index);
600 
601   // Builds Metadata with the given `name` and buffer `content`.
602   BufferOffset<tflite::Metadata> BuildMetadata(StringRef name,
603                                                StringRef content);
604 
605   // Encodes the `tfl.metadata` dictionary attribute of the module to the
606   // metadata section in the final model.
607   Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
608   CreateMetadataVector();
609 
610   // Builds and returns list of tfl.SignatureDef sections in the model.
611   Optional<VectorBufferOffset<BufferOffset<tflite::SignatureDef>>>
612   CreateSignatureDefs(const std::vector<SignatureDefData>& signature_defs);
613 
614   // Returns list of offsets for the passed 'items' in TensorMap structure
615   // inside the flatbuffer.
616   // 'items' is a map from tensor name in signatureDef to tensor name in
617   // the subgraph, specified by the 'subgraph_index' argument.
618   std::vector<BufferOffset<tflite::TensorMap>> GetList(
619       const int subgraph_index,
620       const std::map<std::string, std::string>& items);
621 
622   // Uses the tf.entry_function attribute (if set) to initialize the op to name
623   // mapping.
624   void InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr);
625 
626   // Determines if the specified operation op's operand at operand_index
627   // is marked as a stateful operand.
628   bool IsStatefulOperand(mlir::Operation* op, int operand_index);
629 
630   // Returns a unique name for `val`.
631   std::string UniqueName(mlir::Value val);
632 
633   BufferOffset<tflite::SparsityParameters> BuildSparsityParameters(
634       const mlir::TFL::SparsityParameterAttr& s_attr);
635 
636   bool EstimateArithmeticCount(int64_t* count);
637 
638   ModuleOp module_;
639 
640   tensorflow::OpOrArgNameMapper& name_mapper_;
641 
642   flatbuffers::FlatBufferBuilder builder_;
643   BufferOffset<tflite::Buffer> empty_buffer_;
644 
645   std::vector<BufferOffset<tflite::Buffer>> buffers_;
646   // Maps subgraph index and tensor name in the graph to the tensor index.
647   absl::flat_hash_map<int, absl::flat_hash_map<std::string, int>>
648       tensor_index_map_;
649 
650   // Maps op name to index of the corresponding OperatorCode in opcodes_ vector.
651   absl::flat_hash_map<std::string, uint32_t> opcode_index_map_;
652   std::vector<BufferOffset<tflite::OperatorCode>> opcodes_;
653 
654   // Maps function name to index of the corresponding subgraph in the FlatBuffer
655   // model.
656   absl::flat_hash_map<std::string, int> subgraph_index_map_;
657   absl::flat_hash_set<OpType> enabled_op_types_;
658 
659   // Points to TensorFlow and TFLite dialects, respectively. nullptr if the
660   // dialect is not registered.
661   const Dialect* tf_dialect_;
662   const Dialect* tfl_dialect_;
663 
664   // The failed ops during legalization.
665   std::map<std::string, std::set<std::string>> failed_flex_ops_;
666   std::map<std::string, std::set<std::string>> failed_custom_ops_;
667 
668   // Ops to provide warning messages.
669   std::map<std::string, std::set<std::string>> custom_ops_;
670   std::map<std::string, std::set<std::string>> flex_ops_;
671 
672   // Resource ops to provide warning messages.
673   std::map<std::string, std::set<std::string>> resource_ops_;
674 
675   // Set of saved model tags, if any.
676   const std::unordered_set<std::string> saved_model_tags_;
677   // Allows automatic pass-through of TF ops as select TensorFlow ops.
678   const bool allow_all_select_tf_ops_;
679   // User's defined ops allowed with Flex.
680   const std::unordered_set<std::string> select_user_tf_ops_;
681   // Map of key value pairs of metadata to export.
682   const std::map<std::string, std::string> metadata_;
683 };
684 
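// Walks the module and accumulates the per-op MAC counts reported through
// TflArithmeticCountOpInterface; returns false if any count is undetermined.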
685 bool Translator::EstimateArithmeticCount(int64_t* count) {
686   int64_t result = 0;
687   bool encounter_undetermined_mac = false;
688   module_->walk([&](mlir::TFL::TflArithmeticCountOpInterface op) {
689     int64_t mac_count = op.GetArithmeticCount(op);
690     if (mac_count < 0) {
691       encounter_undetermined_mac = true;
692       return;
693     }
694     result += mac_count;
695   });
696 
697   *count = result;
698   return !encounter_undetermined_mac;
699 }
700 
701 std::string Translator::UniqueName(mlir::Value val) {
702   return std::string(name_mapper_.GetUniqueName(val));
703 }
704 
705 Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
706     Operation* inst) {
707   ElementsAttr attr;
708   if (auto cst = dyn_cast<mlir::ConstantOp>(inst)) {
709     // ConstantOp has an ElementsAttr at this point due to validation of the
710     // TFLite module.
711     attr = cst.getValue().cast<ElementsAttr>();
712   } else if (auto cst = dyn_cast<mlir::TF::ConstOp>(inst)) {
713     attr = cst.value();
714   } else if (auto cst = dyn_cast<tfl::ConstOp>(inst)) {
715     attr = cst.value();
716   } else if (auto cst = dyn_cast<tfl::QConstOp>(inst)) {
717     attr = cst.value();
718   } else if (auto cst = dyn_cast<tfl::SparseConstOp>(inst)) {
719     attr = cst.compressed_data();
720   } else if (auto cst = dyn_cast<tfl::SparseQConstOp>(inst)) {
721     attr = cst.compressed_data();
722   } else {
723     return empty_buffer_;
724   }
725 
726   tensorflow::Tensor tensor;
727   auto status = tensorflow::ConvertToTensor(attr, &tensor);
728   if (!status.ok()) {
729     inst->emitError(
730         Twine("failed to convert value attribute to tensor with error: " +
731               status.ToString()));
732     return llvm::None;
733   }
734 
735   // TensorFlow and TensorFlow Lite use different string encoding formats.
736   // Convert to the TensorFlow Lite format if it's a constant string tensor.
737   if (tensor.dtype() == tensorflow::DT_STRING) {
738     ::tflite::DynamicBuffer dynamic_buffer;
739     auto flat = tensor.flat<::tensorflow::tstring>();
740     for (int i = 0; i < flat.size(); ++i) {
741       const auto& str = flat(i);
742       dynamic_buffer.AddString(str.c_str(), str.length());
743     }
744     char* tensor_buffer;
745     int bytes = dynamic_buffer.WriteToBuffer(&tensor_buffer);
746     auto buffer_data =
747         builder_.CreateVector(reinterpret_cast<uint8_t*>(tensor_buffer), bytes);
748     free(tensor_buffer);
749     return tflite::CreateBuffer(builder_, buffer_data);
750   }
751 
752   absl::string_view tensor_data = tensor.tensor_data();
753   auto buffer_data = builder_.CreateVector(
754       reinterpret_cast<const uint8_t*>(tensor_data.data()), tensor_data.size());
755   return tflite::CreateBuffer(builder_, buffer_data);
756 }
757 
758 Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensorFromType(
759     mlir::Type type, const std::string& name) {
760   auto tensor_type = type.cast<TensorType>();
761 
762   if (!tensor_type.hasStaticShape()) {
763     return llvm::None;
764   }
765   llvm::ArrayRef<int64_t> shape_ref = tensor_type.getShape();
766   std::vector<int32_t> shape(shape_ref.begin(), shape_ref.end());
767 
768   auto element_type = tensor_type.getElementType();
769   tflite::TensorType tflite_element_type =
770       GetTFLiteType(tensor_type.getElementType()).ValueOrDie();
771   BufferOffset<tflite::QuantizationParameters> q_params = 0;
772   if (auto qtype = element_type.dyn_cast<mlir::quant::UniformQuantizedType>()) {
773     q_params = tflite::CreateQuantizationParameters(
774         builder_, /*min=*/0, /*max=*/0,
775         builder_.CreateVector<float>({static_cast<float>(qtype.getScale())}),
776         builder_.CreateVector<int64_t>({qtype.getZeroPoint()}));
777   } else if (auto qtype =
778                  element_type
779                      .dyn_cast<mlir::quant::CalibratedQuantizedType>()) {
780     q_params = tflite::CreateQuantizationParameters(
781         builder_,
782         builder_.CreateVector<float>({static_cast<float>(qtype.getMin())}),
783         builder_.CreateVector<float>({static_cast<float>(qtype.getMax())}));
784   }
785   return tflite::CreateTensor(
786       builder_, builder_.CreateVector(shape), tflite_element_type,
787       /*buffer=*/0, builder_.CreateString(name), q_params,
788       /*is_variable=*/false);
789 }
790 
791 Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
792     Value value, const std::string& name, unsigned buffer_idx,
793     const Optional<BufferOffset<tflite::QuantizationParameters>>&
794         quant_parameters) {
795   auto type = value.getType().cast<TensorType>();
796 
797   // TFLite requires tensor shapes only for inputs and constants.
798   // However, we output all known shapes for better round-tripping.
799   auto check_shape =
800       [&](llvm::ArrayRef<int64_t> shape_ref) -> mlir::LogicalResult {
801     auto is_out_of_range = [](int64_t dim) {
802       return dim > std::numeric_limits<int32_t>::max();
803     };
804 
805     if (std::any_of(shape_ref.begin(), shape_ref.end(), is_out_of_range))
806       return mlir::emitError(
807           value.getLoc(),
808           "result shape dimensions out of 32 bit int type range");
809 
810     return mlir::success();
811   };
812 
813   std::vector<int32_t> shape;
814   std::vector<int32_t> shape_signature;
815   auto* inst = value.getDefiningOp();
816   if (type.hasStaticShape()) {
817     llvm::ArrayRef<int64_t> shape_ref = type.getShape();
818     if (mlir::failed(check_shape(shape_ref))) return llvm::None;
819 
820     shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
821   } else if (inst && IsConst(inst)) {
822     // A const op can have a result of dynamically shaped type (e.g. due to
823     // constant folding), but we can still derive the shape of the constant
824     // tensor from its attribute type.
825     mlir::Attribute tensor_attr = inst->getAttr("value");
826     llvm::ArrayRef<int64_t> shape_ref =
827         tensor_attr.getType().cast<TensorType>().getShape();
828     if (mlir::failed(check_shape(shape_ref))) return llvm::None;
829 
830     shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
831   } else if (type.hasRank()) {
832     llvm::ArrayRef<int64_t> shape_ref = type.getShape();
833     if (mlir::failed(check_shape(shape_ref))) return llvm::None;
834 
835     shape.reserve(shape_ref.size());
836     for (auto& dim : shape_ref) {
837       shape.push_back(dim == -1 ? 1 : dim);
838     }
839     shape_signature = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
840   }
841 
842   BufferOffset<tflite::SparsityParameters> s_params = 0;
843   if (auto* inst = value.getDefiningOp()) {
844     if (auto cst = dyn_cast<tfl::SparseConstOp>(inst)) {
845       s_params = BuildSparsityParameters(cst.s_param());
846     } else if (auto cst = dyn_cast<tfl::SparseQConstOp>(inst)) {
847       s_params = BuildSparsityParameters(cst.s_param());
848     }
849   }
850 
851   Type element_type = type.getElementType();
852   tflite::TensorType tflite_element_type =
853       GetTFLiteType(type.getElementType()).ValueOrDie();
854 
855   BufferOffset<tflite::QuantizationParameters> q_params;
856   if (auto qtype = element_type.dyn_cast<mlir::quant::UniformQuantizedType>()) {
857     q_params = tflite::CreateQuantizationParameters(
858         // min and max values are not stored in the quantized type from MLIR, so
859         // both are set to 0 in the flatbuffer when they are exported.
860         builder_, /*min=*/0, /*max=*/0,
861         builder_.CreateVector<float>({static_cast<float>(qtype.getScale())}),
862         builder_.CreateVector<int64_t>({qtype.getZeroPoint()}));
863   } else if (auto qtype =
864                  element_type
865                      .dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
866     std::vector<float> scales(qtype.getScales().begin(),
867                               qtype.getScales().end());
868     q_params = tflite::CreateQuantizationParameters(
869         builder_, /*min=*/0, /*max=*/0, builder_.CreateVector<float>(scales),
870         builder_.CreateVector<int64_t>(qtype.getZeroPoints()),
871         tflite::QuantizationDetails_NONE, /*details=*/0,
872         qtype.getQuantizedDimension());
873   } else if (quant_parameters.hasValue()) {
874     q_params = quant_parameters.getValue();
875   } else {
876     q_params = tflite::CreateQuantizationParameters(builder_);
877   }
878   // Check whether any of the value's uses is at an operand index marked as
879   // stateful. If so, set the tensor's is_variable to true.
880   // This matches v1 ref-variable semantics in the TFLite runtime.
881   bool is_variable = false;
882   for (auto& use : value.getUses()) {
883     is_variable = IsStatefulOperand(use.getOwner(), use.getOperandNumber());
884     if (is_variable) {
885       break;
886     }
887   }
888 
889   if (shape_signature.empty()) {
890     return tflite::CreateTensor(
891         builder_, builder_.CreateVector(shape), tflite_element_type,
892         (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
893         /*is_variable=*/is_variable, s_params);
894   } else {
895     return tflite::CreateTensor(
896         builder_, builder_.CreateVector(shape), tflite_element_type,
897         (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
898         /*is_variable=*/is_variable, s_params,
899         /*shape_signature=*/builder_.CreateVector(shape_signature));
900   }
901 }
902 
903 BufferOffset<tflite::Operator> Translator::BuildIfOperator(
904     mlir::TF::IfOp op, const std::vector<int32_t>& operands,
905     const std::vector<int32_t>& results) {
906   auto opcode_index = GetOpcodeIndex("if", tflite::BuiltinOperator_IF);
907   int then_subgraph_index = subgraph_index_map_.at(op.then_branch().str());
908   int else_subgraph_index = subgraph_index_map_.at(op.else_branch().str());
909   auto builtin_options = tflite::CreateIfOptions(builder_, then_subgraph_index,
910                                                  else_subgraph_index)
911                              .Union();
912   auto inputs = builder_.CreateVector(operands);
913   auto outputs = builder_.CreateVector(results);
914   return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
915                                 tflite::BuiltinOptions_IfOptions,
916                                 builtin_options);
917 }
918 
919 BufferOffset<tflite::Operator> Translator::BuildCallOnceOperator(
920     mlir::TFL::CallOnceOp op, const std::vector<int32_t>& operands,
921     const std::vector<int32_t>& results) {
922   auto opcode_index =
923       GetOpcodeIndex("call_once", tflite::BuiltinOperator_CALL_ONCE);
924   int init_subgraph_index =
925       subgraph_index_map_.at(op.session_init_function().str());
926   auto builtin_options =
927       tflite::CreateCallOnceOptions(builder_, init_subgraph_index).Union();
928   auto inputs = builder_.CreateVector(operands);
929   auto outputs = builder_.CreateVector(results);
930   return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
931                                 tflite::BuiltinOptions_CallOnceOptions,
932                                 builtin_options);
933 }
934 
935 Optional<BufferOffset<tflite::Operator>> Translator::BuildWhileOperator(
936     mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
937     const std::vector<int32_t>& results) {
938   auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE);
939   auto get_call_index = [&](mlir::Block& b) -> Optional<int> {
940     if (b.getOperations().size() != 2) return llvm::None;
941     if (auto call_op = dyn_cast<mlir::CallOp>(b.front()))
942       return subgraph_index_map_.at(call_op.callee().str());
943     return llvm::None;
944   };
945   auto body_subgraph_index = get_call_index(op.body().front());
946   auto cond_subgraph_index = get_call_index(op.cond().front());
947   if (!body_subgraph_index || !cond_subgraph_index)
948     return op.emitOpError("only single call cond/body while export supported"),
949            llvm::None;
950   auto builtin_options =
951       tflite::CreateWhileOptions(builder_, *cond_subgraph_index,
952                                  *body_subgraph_index)
953           .Union();
954   auto inputs = builder_.CreateVector(operands);
955   auto outputs = builder_.CreateVector(results);
956   return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
957                                 tflite::BuiltinOptions_WhileOptions,
958                                 builtin_options);
959 }
960 
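// Builds a custom NumericVerify operator; its tolerance and log_if_failed
// attributes are encoded as a FlexBuffer map in the custom options.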
961 BufferOffset<tflite::Operator> Translator::BuildNumericVerifyOperator(
962     mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
963     const std::vector<int32_t>& results) {
964   float tolerance = op.tolerance().convertToFloat();
965   bool log_if_failed = op.log_if_failed();
966   auto fbb = absl::make_unique<flexbuffers::Builder>();
967   fbb->Map([&]() {
968     fbb->Float("tolerance", tolerance);
969     fbb->Bool("log_if_failed", log_if_failed);
970   });
971   fbb->Finish();
972   auto f = std::unique_ptr<flexbuffers::Builder>(fbb.release());
973   auto custom_option = f->GetBuffer();
974   auto opcode_index =
975       GetOpcodeIndex("NumericVerify", tflite::BuiltinOperator_CUSTOM);
976   return tflite::CreateOperator(
977       builder_, opcode_index, builder_.CreateVector(operands),
978       builder_.CreateVector(results), tflite::BuiltinOptions_NONE,
979       /*builtin_options=*/0, builder_.CreateVector<uint8_t>(custom_option),
980       tflite::CustomOptionsFormat_FLEXBUFFERS);
981 }
982 
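// Builds a TFLite custom operator from a tfl.custom op, copying the opaque
// custom_option bytes verbatim into the flatbuffer.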
983 BufferOffset<tflite::Operator> Translator::BuildCustomOperator(
984     Operation* inst, mlir::TFL::CustomOp op,
985     const std::vector<int32_t>& operands, const std::vector<int32_t>& results) {
986   const std::string attrs =
987       op.custom_option().cast<mlir::OpaqueElementsAttr>().getValue().str();
988   std::vector<uint8_t> custom_option_vector(attrs.size());
989   memcpy(custom_option_vector.data(), attrs.data(), attrs.size());
990   auto opcode_index =
991       GetOpcodeIndex(op.custom_code().str(), tflite::BuiltinOperator_CUSTOM);
992   return tflite::CreateOperator(
993       builder_, opcode_index, builder_.CreateVector(operands),
994       builder_.CreateVector(results), tflite::BuiltinOptions_NONE,
995       /*builtin_options=*/0,
996       builder_.CreateVector<uint8_t>(custom_option_vector),
997       tflite::CustomOptionsFormat_FLEXBUFFERS);
998 }
999 
1000 Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions(
1001     const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
1002   std::string node_def_str;
1003   if (!node_def.SerializeToString(&node_def_str)) {
1004     return emitError(loc, "failed to serialize tensorflow node_def"),
1005            llvm::None;
1006   }
1007 
1008   auto flex_builder = absl::make_unique<flexbuffers::Builder>();
1009   flex_builder->Vector([&]() {
1010     flex_builder->String(node_def.op());
1011     flex_builder->String(node_def_str);
1012   });
1013   flex_builder->Finish();
1014   return builder_.CreateVector(flex_builder->GetBuffer());
1015 }
1016 
1017 Optional<CustomOptionsOffset> Translator::CreateCustomOpCustomOptions(
1018     const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
1019   auto flex_builder = CreateFlexBuilderWithNodeAttrs(node_def, loc);
1020   return builder_.CreateVector(flex_builder->GetBuffer());
1021 }
1022 
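// Serializes the attributes of `node_def` into a FlexBuffer map, sorted by key;
// attribute types that cannot be represented are skipped with a warning.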
1023 std::unique_ptr<flexbuffers::Builder>
1024 Translator::CreateFlexBuilderWithNodeAttrs(
1025     const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
1026   auto flex_builder = absl::make_unique<flexbuffers::Builder>();
1027   size_t map_start = flex_builder->StartMap();
1028   using Item = std::pair<std::string, ::tensorflow::AttrValue>;
1029   std::vector<Item> attrs(node_def.attr().begin(), node_def.attr().end());
1030   std::sort(attrs.begin(), attrs.end(),
1031             [](Item& p1, Item& p2) -> bool { return p1.first < p2.first; });
1032   for (const Item& pair : attrs) {
1033     const char* key = pair.first.c_str();
1034     const ::tensorflow::AttrValue& attr = pair.second;
1035     switch (attr.value_case()) {
1036       case ::tensorflow::AttrValue::kS:
1037         flex_builder->String(key, attr.s());
1038         break;
1039       case ::tensorflow::AttrValue::kType: {
1040         auto status_or_tfl_type = tflite::TfTypeToTflType(attr.type());
1041         if (status_or_tfl_type.ok()) {
1042           flex_builder->Int(key, status_or_tfl_type.ValueOrDie());
1043         } else {
1044           emitWarning(loc, "ignoring unsupported tensorflow type: ")
1045               << std::to_string(attr.type());
1046         }
1047         break;
1048       }
1049       case ::tensorflow::AttrValue::kI:
1050         flex_builder->Int(key, attr.i());
1051         break;
1052       case ::tensorflow::AttrValue::kF:
1053         flex_builder->Float(key, attr.f());
1054         break;
1055       case ::tensorflow::AttrValue::kB:
1056         flex_builder->Bool(key, attr.b());
1057         break;
1058       case tensorflow::AttrValue::kList:
1059         if (attr.list().s_size() > 0) {
1060           auto start = flex_builder->StartVector(key);
1061           for (const std::string& v : attr.list().s()) {
1062             flex_builder->Add(v);
1063           }
1064           flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false);
1065         } else if (attr.list().i_size() > 0) {
1066           auto start = flex_builder->StartVector(key);
1067           for (const int64_t v : attr.list().i()) {
1068             flex_builder->Add(v);
1069           }
1070           flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false);
1071         } else if (attr.list().f_size() > 0) {
1072           auto start = flex_builder->StartVector(key);
1073           for (const float v : attr.list().f()) {
1074             flex_builder->Add(v);
1075           }
1076           flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false);
1077         } else {
1078           emitWarning(loc,
1079                       "ignoring unsupported type in list attribute with key: ")
1080               << key;
1081         }
1082         break;
1083       default:
1084         emitWarning(loc, "ignoring unsupported attribute type with key: ")
1085             << key;
1086         break;
1087     }
1088   }
1089   flex_builder->EndMap(map_start);
1090   flex_builder->Finish();
1091   return flex_builder;
1092 }
1093 
1094 uint32_t Translator::GetOpcodeIndex(const std::string& op_name,
1095                                     tflite::BuiltinOperator builtin) {
1096   auto it = opcode_index_map_.insert({op_name, 0});
1097 
1098   // If the insert succeeded, the opcode has not been created already. Create a
1099   // new operator code and update its index value in the map.
1100   if (it.second) {
1101     it.first->second = opcodes_.size();
1102     auto custom_code = builtin == tflite::BuiltinOperator_CUSTOM
1103                            ? builder_.CreateString(op_name)
1104                            : BufferOffset<flatbuffers::String>();
1105     // Use version 0 for builtin op. This is a way to serialize version field to
1106     // flatbuffer (since 0 is non default) and it will be corrected later.
1107     int32_t op_version = builtin != tflite::BuiltinOperator_CUSTOM ? 0 : 1;
1108     opcodes_.push_back(CreateOperatorCode(builder_, /*builtin_code=*/builtin,
1109                                           custom_code, op_version));
1110   }
1111   return it.first->second;
1112 }
1113 
1114 Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
1115     Operation* inst, std::vector<int32_t> operands,
1116     const std::vector<int32_t>& results,
1117     const std::vector<int32_t>& intermediates) {
1118   const auto* dialect = inst->getDialect();
1119   if (!dialect) {
1120     inst->emitOpError("dialect is not registered");
1121     return llvm::None;
1122   }
1123 
1124   // If TFLite built in op, create operator as a builtin op.
1125   if (dialect == tfl_dialect_) {
1126     // Legalization would only have converted TF ops to TFL if built-in TFLite
1127     // op emission is enabled.
1128     if (!enabled_op_types_.contains(OpType::kTfliteBuiltin)) {
1129       return inst->emitOpError(
1130                  "is a TFLite builtin op but builtin emission is not enabled"),
1131              llvm::None;
1132     }
1133 
1134     auto builtin_code = GetBuiltinOpCode(inst);
1135     if (!builtin_code) {
1136       if (auto verify_op = dyn_cast<mlir::TFL::NumericVerifyOp>(inst)) {
1137         return BuildNumericVerifyOperator(verify_op, operands, results);
1138       }
1139       if (auto custom_op = dyn_cast<mlir::TFL::CustomOp>(inst)) {
1140         return BuildCustomOperator(inst, custom_op, operands, results);
1141       }
1142       if (auto whileOp = dyn_cast<mlir::TFL::WhileOp>(inst)) {
1143         if (inst->getNumOperands() != inst->getNumResults()) {
1144           inst->emitOpError(
1145               "number of operands and results don't match, only canonical "
1146               "TFL While supported");
1147           return llvm::None;
1148         }
1149         return BuildWhileOperator(whileOp, operands, results);
1150       }
1151 
1152       inst->emitOpError("is not a supported TFLite op");
1153       return llvm::None;
1154     }
1155 
1156     if (*builtin_code == tflite::BuiltinOperator_CALL_ONCE) {
1157       if (auto initOp = dyn_cast<mlir::TFL::CallOnceOp>(inst)) {
1158         return BuildCallOnceOperator(initOp, operands, results);
1159       }
1160     }
1161 
1162     std::string op_name = inst->getName().getStringRef().str();
1163     uint32_t opcode_index = GetOpcodeIndex(op_name, *builtin_code);
1164 
1165     // If this is TransposeConv we need to do a special case of ignoring the
1166     // optional tensor, to allow newly created models to run on old runtimes.
1167     if (*builtin_code == tflite::BuiltinOperator_TRANSPOSE_CONV) {
1168       if (operands.size() == 4 && operands.at(3) == -1) {
1169         operands.pop_back();
1170       }
1171     }
1172 
1173     auto offset = CreateFlatBufferOperator(inst, opcode_index, operands,
1174                                            results, intermediates, &builder_);
1175     if (!offset) {
1176       inst->emitOpError("is not a supported TFLite op");
1177     }
1178     return offset;
1179   }
1180 
1181   if (dialect == tf_dialect_) {
1182     if (auto ifOp = dyn_cast<mlir::TF::IfOp>(inst)) {
1183       return BuildIfOperator(ifOp, operands, results);
1184     }
1185 
1186     CustomOptionsOffset custom_options;
1187 
1188     // Ops in TF dialect can either be custom ops or flex ops.
1189     // The reason we go directly from TensorFlow dialect MLIR to tensorflow
1190     // node instead of going to TF table gen'd ops via generated code is that
1191     // we do not want to restrict custom and flex op conversion support to
1192     // only those TF ops that are currently registered in MLIR. The current
1193     // model is of an open op system.
1194     //
1195     //  The following algorithm is followed:
1196     //   if flex is enabled and the op is allowlisted as flex
1197     //     we emit op as flex.
1198     //   if custom is enabled
1199     //    we emit the op as custom.
1200     auto node_def = GetTensorFlowNodeDef(inst);
1201     if (!node_def) {
1202       return llvm::None;
1203     }
1204 
1205     std::string op_name = node_def->op();
1206     std::string op_desc = GetOpDescriptionForDebug(inst);
1207 
1208     if (IsTFResourceOp(inst)) {
1209       resource_ops_[op_name].insert(op_desc);
1210     }
1211 
1212     const bool is_allowed_flex_op =
1213         IsAllowlistedFlexOp(node_def->op()) ||
1214         (((select_user_tf_ops_.count(node_def->op()) != 0) ||
1215           allow_all_select_tf_ops_) &&
1216          (tensorflow::OpRegistry::Global()->LookUp(node_def->op()) != nullptr));
1217     // Flex op case
1218     // Eventually, the allowlist will go away and we will rely on some TF op
1219     // trait (e.g. No side effect) to determine if it is a supported "Flex"
1220     // op or not.
1221     if (is_allowed_flex_op && enabled_op_types_.contains(OpType::kSelectTf)) {
1222       // Construct ops as flex op encoding TensorFlow node definition
1223       // as custom options.
1224       // Flex ops are named with the kFlexOpNamePrefix prefix to the actual
1225       // TF op name.
1226       op_name = std::string(kFlexOpNamePrefix) + node_def->op();
1227       if (auto options = CreateFlexOpCustomOptions(*node_def, inst->getLoc())) {
1228         custom_options = *options;
1229       } else {
1230         return llvm::None;
1231       }
1232 
1233       // Gather flex ops.
1234       flex_ops_[op_name].insert(op_desc);
1235     } else if (enabled_op_types_.contains(OpType::kCustomOp)) {
1236       // Generic case of custom ops - write using flex buffers since that
1237       // is the only custom options supported by TFLite today.
1238       op_name = node_def->op();
1239       if (auto options =
1240               CreateCustomOpCustomOptions(*node_def, inst->getLoc())) {
1241         custom_options = *options;
1242       } else {
1243         return llvm::None;
1244       }
1245 
1246       // Gather custom ops.
1247       custom_ops_[op_name].insert(op_desc);
1248     } else {
1249       // Record the failed op in `failed_flex_ops_` or `failed_custom_ops_`.
1250       if (is_allowed_flex_op) {
1251         failed_flex_ops_[op_name].insert(op_desc);
1252         tfl::AttachErrorCode(
1253             inst->emitOpError("is neither a custom op nor a flex op"),
1254             tflite::metrics::ConverterErrorData::ERROR_NEEDS_FLEX_OPS);
1255       } else {
1256         failed_custom_ops_[op_name].insert(op_desc);
1257         tfl::AttachErrorCode(
1258             inst->emitOpError("is neither a custom op nor a flex op"),
1259             tflite::metrics::ConverterErrorData::ERROR_NEEDS_CUSTOM_OPS);
1260       }
1261       return llvm::None;
1262     }
1263 
1264     uint32_t opcode_index =
1265         GetOpcodeIndex(op_name, tflite::BuiltinOperator_CUSTOM);
1266     auto inputs = builder_.CreateVector(operands);
1267     auto outputs = builder_.CreateVector(results);
1268 
1269     return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
1270                                   tflite::BuiltinOptions_NONE,
1271                                   /*builtin_options=*/0,
1272                                   /*custom_options=*/custom_options,
1273                                   tflite::CustomOptionsFormat_FLEXBUFFERS,
1274                                   /*mutating_variable_inputs=*/0);
1275   }
1276 
1277   return inst->emitOpError(
1278              "is not any of a builtin TFLite op, a flex TensorFlow op or a "
1279              "custom TensorFlow op"),
1280          llvm::None;
1281 }
1282 
1283 void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) {
1284   auto dict_attr = fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
1285   if (!dict_attr) return;
1286 
1287   llvm::SmallVector<llvm::StringRef, 2> input_names;
1288   llvm::SmallVector<llvm::StringRef, 2> output_names;
1289   if (auto str = dict_attr.get("inputs").dyn_cast_or_null<mlir::StringAttr>()) {
1290     str.getValue().split(input_names, ',', /*MaxSplit=*/-1,
1291                          /*KeepEmpty=*/false);
1292     if (input_names.size() != fn.getNumArguments()) {
1293       fn.emitWarning() << "invalid entry function specification";
1294       return;
1295     }
1296     for (auto it : llvm::enumerate(fn.getArguments())) {
1297       name_mapper_.InitOpName(it.value(), input_names[it.index()].trim());
1298     }
1299     *has_input_attr = true;
1300   }
1301 
1302   if (auto str =
1303           dict_attr.get("outputs").dyn_cast_or_null<mlir::StringAttr>()) {
1304     str.getValue().split(output_names, ',', /*MaxSplit=*/-1,
1305                          /*KeepEmpty=*/false);
1306     auto term = fn.back().getTerminator();
1307     if (output_names.size() != term->getNumOperands()) {
1308       fn.emitWarning() << "output names (" << output_names.size()
1309                        << ") != terminator operands (" << term->getNumOperands()
1310                        << ")";
1311       return;
1312     }
1313     for (const auto& it : llvm::enumerate(term->getOperands())) {
1314       name_mapper_.InitOpName(it.value(), output_names[it.index()].trim());
1315     }
1316   }
1317 }
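// Illustrative example (not part of the original source) of the attribute the
// helper above consumes:
//
//   func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32>
//       attributes {tf.entry_function = {inputs = "x,y", outputs = "z"}} {...}
//
// would map the arguments to the names "x" and "y" and the terminator operand
// to the name "z".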
1318 
1319 bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) {
1320   std::vector<int> operand_indices;
1321   if (!mlir::TFL::IsStatefulOp(op, &operand_indices)) return false;
1322   return absl::c_find(operand_indices, operand_index) != operand_indices.end();
1323 }
1324 
1325 BufferOffset<tflite::QuantizationParameters>
1326 Translator::GetQuantizationForQuantStatsOpOutput(
1327     mlir::quant::StatisticsOp stats_op) {
1328   auto layer_stats = stats_op.layerStats().cast<mlir::DenseFPElementsAttr>();
1329   Optional<mlir::ElementsAttr> axis_stats = stats_op.axisStats();
1330   Optional<uint64_t> axis = stats_op.axis();
1331   std::vector<float> mins, maxs;
1332   mlir::DenseFPElementsAttr min_max_attr =
1333       axis_stats.hasValue()
1334           ? axis_stats.getValue().cast<mlir::DenseFPElementsAttr>()
1335           : layer_stats;
1336 
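  // Illustrative layout (not part of the original source): the selected stats
  // attribute interleaves minima and maxima, e.g. [min0, max0, min1, max1, ...],
  // so even indices feed `mins` and odd indices feed `maxs` below.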
1337   for (auto index_and_value : llvm::enumerate(min_max_attr.getFloatValues())) {
1338     const llvm::APFloat value = index_and_value.value();
1339     if (index_and_value.index() % 2 == 0) {
1340       mins.push_back(value.convertToFloat());
1341     } else {
1342       maxs.push_back(value.convertToFloat());
1343     }
1344   }
1345 
1346   return tflite::CreateQuantizationParameters(
1347       builder_, builder_.CreateVector<float>(mins),
1348       builder_.CreateVector<float>(maxs), /*scale=*/0, /*zero_point=*/0,
1349       tflite::QuantizationDetails_NONE, /*details=*/0,
1350       /*quantized_dimension=*/axis.hasValue() ? axis.getValue() : 0);
1351 }
1352 
1353 Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
1354     const std::string& name, Region* region, const int index) {
1355   bool has_input_attr = false;
1356   if (auto fn = dyn_cast<FuncOp>(region->getParentOp())) {
1357     InitializeNamesFromAttribute(fn, &has_input_attr);
1358   }
1359   std::vector<BufferOffset<tflite::Tensor>> tensors;
1360   llvm::DenseMap<Value, int> tensor_index_map;
1361 
1362   // Builds tensor and buffer for argument or operation result. Returns false
1363   // on failure.
1364   auto build_tensor_and_buffer = [&](Value value, const int subgraph_index,
1365                                      const std::string& tensor_name) {
1366     // NoneType represents an omitted optional operand and may be skipped here.
1367     if (value.getType().isa<NoneType>()) {
1368       return true;
1369     }
1370 
1371     tensor_index_map.insert({value, tensors.size()});
1372     tensor_index_map_[subgraph_index][tensor_name] = tensors.size();
1373     Optional<BufferOffset<tflite::QuantizationParameters>> quant_parameters;
1374     if (value.hasOneUse()) {
1375       auto stats_op =
1376           llvm::dyn_cast<mlir::quant::StatisticsOp>(*value.user_begin());
1377       if (stats_op) {
1378         quant_parameters = GetQuantizationForQuantStatsOpOutput(stats_op);
1379       }
1380     }
1381     auto tensor_or =
1382         BuildTensor(value, tensor_name, buffers_.size(), quant_parameters);
1383     if (!tensor_or) return false;
1384     tensors.push_back(*tensor_or);
1385 
1386     // TODO(ashwinm): Check whether, for stateful tensors, the Buffer also needs
1387     // to be empty apart from setting buffer_idx=0 in the Tensor. This does not
1388     // seem to affect runtime behavior for RNN/LSTM, but would be good for
1389     // reducing the memory footprint.
1390     if (auto* inst = value.getDefiningOp()) {
1391       auto buffer_or = BuildBuffer(inst);
1392       if (!buffer_or) return false;
1393       buffers_.push_back(*buffer_or);
1394     } else {
1395       buffers_.push_back(empty_buffer_);
1396     }
1397     return true;
1398   };
1399 
1400   std::vector<BufferOffset<tflite::Operator>> operators;
1401   auto& bb = region->front();
1402 
1403   // The main function's arguments are first passed to an `input` op, so they
1404   // don't have an associated tensor and buffer. Build FlatBuffer tensors and
1405   // buffers for other functions.
1406   for (unsigned i = 0, e = bb.getNumArguments(); i < e; ++i) {
1407     mlir::BlockArgument arg = bb.getArgument(i);
1408     std::string tensor_name;
1409     if (has_input_attr)
1410       tensor_name = std::string(name_mapper_.GetUniqueName(arg));
1411     if (tensor_name.empty()) tensor_name = absl::StrCat("arg", i);
1412     if (!build_tensor_and_buffer(arg, index, tensor_name)) return llvm::None;
1413   }
1414 
1415   bool failed_once = false;
1416   for (auto& inst : bb) {
1417     if (inst.hasTrait<mlir::OpTrait::IsTerminator>()) break;
1418     // The "quant.stats" op only stores quantization parameter info; its output
1419     // is replaced by its input value, so no operator is emitted for it.
1420     if (auto quant_stats_op = llvm::dyn_cast<mlir::quant::StatisticsOp>(inst)) {
1421       continue;
1422     }
1423     std::vector<int32_t> intermediates;
1424     // Build intermediate tensors for tfl.lstm and insert these tensors into
1425     // flatbuffer.
1426     if (llvm::isa<mlir::TFL::LSTMOp, mlir::TFL::UnidirectionalSequenceLSTMOp>(
1427             inst)) {
1428       std::vector<std::string> intermediate_names = {
1429           "input_to_input_intermediate", "input_to_forget_intermediate",
1430           "input_to_cell_intermediate", "input_to_output_intermediate",
1431           "effective_hidden_scale_intermediate"};
1432       for (const std::string& intermediate : intermediate_names) {
1433         auto intermediate_attr = inst.getAttr(intermediate);
1434         if (auto attr = intermediate_attr.dyn_cast_or_null<mlir::TypeAttr>()) {
1435           Type qtype = attr.getValue();
1436           auto tensor_or = BuildTensorFromType(
1437               qtype, name_mapper_.GetUniqueName(intermediate).str());
1438           if (!tensor_or.hasValue()) {
1439             continue;
1440           } else {
1441             intermediates.push_back(tensors.size());
1442             tensors.push_back(tensor_or.getValue());
1443           }
1444         }
1445       }
1446     }
1447 
1448     for (auto val : inst.getResults()) {
1449       std::string tensor_name = UniqueName(val);
1450       // For the "tfl.numeric_verify" op, the tensor name encodes the original
1451       // activation tensor rather than the op's own unique name, so that
1452       // visualization and debugging tools can locate it.
1453       auto builtin_code = GetBuiltinOpCode(&inst);
1454       if (!builtin_code && dyn_cast<mlir::TFL::NumericVerifyOp>(&inst)) {
1455         // The first operand is the quantized activation, the target of this
1456         // NumericVerify op.
1457         auto quantized_op_val = inst.getOperands().front();
1458         tensor_name = "NumericVerify/" + UniqueName(quantized_op_val) + ":" +
1459                       std::to_string(tensor_index_map[quantized_op_val]);
1460       }
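      // Illustrative example (not part of the original source): if the verified
      // activation has the unique name "conv1" and tensor index 5, the tensor is
      // named "NumericVerify/conv1:5".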
1461       if (!build_tensor_and_buffer(val, index, tensor_name)) return llvm::None;
1462     }
1463 
1464     // Skip constant ops as they don't represent a TFLite operator.
1465     if (IsConst(&inst)) continue;
1466 
1467     // Fetch operand and result tensor indices.
1468     std::vector<int32_t> results;
1469     results.reserve(inst.getNumResults());
1470     for (auto result : inst.getResults()) {
1471       results.push_back(tensor_index_map.lookup(result));
1472     }
1473     Operation* real_inst = &inst;
1474     std::vector<int32_t> operands;
1475     operands.reserve(real_inst->getNumOperands());
1476     for (auto operand : real_inst->getOperands()) {
1477       if (operand.getType().isa<NoneType>())
1478         operands.push_back(kTfLiteOptionalTensor);
1479       else if (auto stats_op =
1480                    llvm::dyn_cast_or_null<mlir::quant::StatisticsOp>(
1481                        operand.getDefiningOp()))
1482         operands.push_back(tensor_index_map.lookup(stats_op.arg()));
1483       else
1484         operands.push_back(tensor_index_map.lookup(operand));
1485     }
1486 
1487     // CustomTfOp is just a wrapper around a TF op; we export the wrapped TF op,
1488     // not the wrapper, so we fetch the op from the region.
1489     if (auto custom_op = dyn_cast<mlir::TFL::CustomTfOp>(inst)) {
1490       // If the custom op has a non-empty region, use the first op in that
1491       // region; otherwise just use the wrapper op's own parameters.
1492       if (!custom_op.body().empty()) {
1493         real_inst = &custom_op.body().front().front();
1494       } else {
1495         module_.emitError(
1496             "Invalid CustomTfOp: Custom TF Op has an empty region.");
1497       }
1498     }
1499 
1500     if (auto tfl_operator =
1501             BuildOperator(real_inst, operands, results, intermediates))
1502       operators.push_back(*tfl_operator);
1503     else
1504       failed_once = true;
1505   }
1506 
1507   if (failed_once) return llvm::None;
1508 
1509   // Get input and output tensor indices for the subgraph.
1510   std::vector<int32_t> inputs, outputs;
1511   for (auto arg : bb.getArguments()) {
1512     inputs.push_back(tensor_index_map[arg]);
1513   }
1514   for (auto result : bb.getTerminator()->getOperands()) {
1515     outputs.push_back(tensor_index_map[result]);
1516   }
1517 
1518   return tflite::CreateSubGraph(
1519       builder_, builder_.CreateVector(tensors), builder_.CreateVector(inputs),
1520       builder_.CreateVector(outputs), builder_.CreateVector(operators),
1521       /*name=*/builder_.CreateString(name));
1522 }
1523 
1524 BufferOffset<tflite::Metadata> Translator::BuildMetadata(StringRef name,
1525                                                          StringRef content) {
1526   auto buffer_index = buffers_.size();
1527   auto buffer_data = builder_.CreateVector(
1528       reinterpret_cast<const uint8_t*>(content.data()), content.size());
1529   buffers_.push_back(tflite::CreateBuffer(builder_, buffer_data));
1530   return tflite::CreateMetadataDirect(builder_, name.data(), buffer_index);
1531 }
1532 
1533 Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
1534 Translator::CreateMetadataVector() {
1535   auto dict_attr = module_->getAttrOfType<mlir::DictionaryAttr>("tfl.metadata");
1536   std::vector<BufferOffset<tflite::Metadata>> metadata;
1537   if (dict_attr) {
1538     for (const auto& named_attr : dict_attr) {
1539       StringRef name = named_attr.first;
1540       mlir::Attribute attr = named_attr.second;
1541       if (auto content = attr.dyn_cast<StringAttr>()) {
1542         metadata.push_back(BuildMetadata(name, content.getValue()));
1543       } else {
1544         module_.emitError(
1545             "all values in tfl.metadata's dictionary key-value pairs should be "
1546             "string attributes");
1547         return llvm::None;
1548       }
1549     }
1550   }
1551   // The runtime version string is generated after we update the op versions,
1552   // so here we put a 16-byte dummy string as a placeholder. We choose 16 bytes
1553   // because that is the buffer alignment in FlatBuffers, so no space is wasted
1554   // if the actual string is shorter than 16 bytes.
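  // Illustrative example (not part of the original source): a version string
  // such as "1.14.0" would later overwrite the start of this placeholder,
  // leaving the remaining bytes as '\0' padding.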
1555   constexpr std::size_t kByteStringSize = 16;
1556   metadata.push_back(
1557       BuildMetadata("min_runtime_version", std::string(kByteStringSize, '\0')));
1558   for (const auto& kv : metadata_) {
1559     const std::string& val = kv.second;
1560     // Only take the first kByteStringSize bytes.
1561     const int count = std::min(kByteStringSize, val.length());
1562     std::string value = std::string(kByteStringSize, '\0')
1563                             .assign(val.begin(), val.begin() + count);
1564     metadata.push_back(BuildMetadata(kv.first, value));
1565   }
1566   return builder_.CreateVector(metadata);
1567 }
1568 
1569 // Helper method that returns the list of strings in the StringAttr identified
1570 // by 'attr_key', where the values are separated by commas.
1571 llvm::SmallVector<llvm::StringRef, 2> GetStringsFromAttrWithSeparator(
1572     mlir::DictionaryAttr attr, const std::string& attr_key) {
1573   llvm::SmallVector<llvm::StringRef, 2> result;
1574   if (auto str = attr.get(attr_key).dyn_cast_or_null<mlir::StringAttr>()) {
1575     str.getValue().split(result, ',', /*MaxSplit=*/-1,
1576                          /*KeepEmpty=*/false);
1577   }
1578   return result;
1579 }
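// Illustrative usage (not part of the original source): given a dictionary
// attribute {inputs = "a,b,c"}, GetStringsFromAttrWithSeparator(attr, "inputs")
// returns {"a", "b", "c"}.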
1580 
1581 // Helper method that returns the strings collected from the attribute
1582 // identified by 'attr_name' in each of the given dictionary attributes.
1583 std::vector<std::string> GetStringsFromDictionaryAttr(
1584     const llvm::SmallVector<mlir::DictionaryAttr, 4>& dict_attrs,
1585     const std::string& attr_name) {
1586   std::vector<std::string> result;
1587   for (const auto& arg_attr : dict_attrs) {
1588     if (!arg_attr) continue;
1589 
1590     auto attrs = arg_attr.getValue();
1591     for (const auto attr : attrs) {
1592       if (attr.first.str() == attr_name) {
1593         auto array_attr = attr.second.dyn_cast_or_null<mlir::ArrayAttr>();
1594         if (!array_attr || array_attr.empty()) continue;
1595         auto string_attr = array_attr[0].dyn_cast_or_null<mlir::StringAttr>();
1596         if (!string_attr) continue;
1597         result.push_back(string_attr.getValue().str());
1598       }
1599     }
1600   }
1601   return result;
1602 }
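// Illustrative example (not part of the original source): an argument
// dictionary attribute such as {tf_saved_model.index_path = ["x"]} contributes
// the string "x" when queried with that attribute name.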
1603 
1604 std::vector<SignatureDefData> BuildSignaturedef(
1605     FuncOp main_op, const std::string& saved_model_tag,
1606     const uint32_t subgraph_index, tensorflow::OpOrArgNameMapper& name_mapper) {
1607   static const char kSignatureDefIndexPath[] = "tf_saved_model.index_path";
1608   static const char kEntryFunctionAttributes[] = "tf.entry_function";
1609 
1610   // Fetch inputs and outputs from the signature.
1611   llvm::SmallVector<mlir::DictionaryAttr, 4> arg_attrs, res_attrs;
1612   main_op.getAllArgAttrs(arg_attrs);
1613   main_op.getAllResultAttrs(res_attrs);
1614   std::vector<std::string> sig_def_inputs =
1615       GetStringsFromDictionaryAttr(arg_attrs, kSignatureDefIndexPath);
1616   std::vector<std::string> sig_def_outputs =
1617       GetStringsFromDictionaryAttr(res_attrs, kSignatureDefIndexPath);
1618 
1619   // If there is no SavedModel signature defined, return an empty list.
1620   // This can happen when the model being converted is not from a SavedModel.
1621   if (sig_def_inputs.empty() && sig_def_outputs.empty()) return {};
1622 
1623   // Fetch function inputs and outputs tensor names.
1624   auto dict_attr =
1625       main_op->getAttrOfType<mlir::DictionaryAttr>(kEntryFunctionAttributes);
1626   if (!dict_attr) return {};
1627 
1628   // Get Input and output tensor names from attribute.
1629   llvm::SmallVector<llvm::StringRef, 2> input_names =
1630       GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"inputs");
1631   llvm::SmallVector<llvm::StringRef, 2> output_names =
1632       GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"outputs");
1633 
1634   // Verify the number of input names matches the number of arguments.
1635   if (input_names.size() != main_op.getNumArguments()) {
1636     main_op.emitWarning() << "invalid entry function specification";
1637     return {};
1638   }
1639   // Verify the number of output names matches the number of terminator operands.
1640   auto term = main_op.back().getTerminator();
1641   if (output_names.size() != term->getNumOperands()) {
1642     main_op.emitWarning() << "output names (" << output_names.size()
1643                           << ") != terminator operands ("
1644                           << term->getNumOperands() << ")";
1645     return {};
1646   }
1647   // Verify that the number of input and output tensors matches the size of
1648   // the lists in the signature def.
1649   if (input_names.size() != sig_def_inputs.size() ||
1650       output_names.size() != sig_def_outputs.size()) {
1651     main_op.emitWarning(
1652         "Mismatch between signature def inputs/outputs and main function "
1653         "arguments.");
1654     return {};
1655   }
1656   // Exported method name.
1657   auto exported_name =
1658       main_op->getAttrOfType<mlir::ArrayAttr>("tf_saved_model.exported_names");
1659   if (exported_name.empty()) {
1660     main_op.emitError("Empty exported names for main Function");
1661     return {};
1662   }
1663   // Fill the SignatureDefData container.
1664   // We create a vector of size 1 as TFLite currently supports only one SignatureDef.
1665   std::vector<SignatureDefData> result(1);
1666   for (int i = 0; i < input_names.size(); ++i) {
1667     result[0].inputs[sig_def_inputs[i]] = input_names[i].str();
1668   }
1669   for (int i = 0; i < output_names.size(); ++i) {
1670     // Fetch the name from the actual operand rather than relying on the names
1671     // from `outputs`, as deduping can make them invalid after conversion.
1672     auto& operand = term->getOpOperand(i);
1673     auto unique_name = std::string(name_mapper.GetUniqueName(operand.get()));
1674     result[0].outputs[sig_def_outputs[i]] = unique_name;
1675   }
1676   if (auto name_attr = exported_name[0].dyn_cast_or_null<StringAttr>())
1677     result[0].signature_key = name_attr.getValue().str();
1678   result[0].subgraph_index = subgraph_index;
1679   return result;
1680 }
1681 
1682 std::vector<BufferOffset<tflite::TensorMap>> Translator::GetList(
1683     const int subgraph_index, const std::map<std::string, std::string>& items) {
1684   std::vector<BufferOffset<tflite::TensorMap>> result;
1685   for (const auto& item : items) {
1686     auto name_buf = builder_.CreateString(item.first);
1687     tflite::TensorMapBuilder tensor_map_builder(builder_);
1688     tensor_map_builder.add_name(name_buf);
1689     tensor_map_builder.add_tensor_index(
1690         tensor_index_map_[subgraph_index][item.second]);
1691     result.push_back(tensor_map_builder.Finish());
1692   }
1693   return result;
1694 }
1695 
1696 Optional<VectorBufferOffset<BufferOffset<tflite::SignatureDef>>>
1697 Translator::CreateSignatureDefs(
1698     const std::vector<SignatureDefData>& signature_defs) {
1699   std::vector<BufferOffset<tflite::SignatureDef>> signature_defs_buffer;
1700   // When we export the functions in the module op, we intentionally place the
1701   // entry functions at the beginning of the subgraph list, so subgraph_index is
1702   // both the index into the entry functions and the index into the subgraph
1703   // list.
1704   int subgraph_index = 0;
1705   for (const auto& signature_def_data : signature_defs) {
1706     auto inputs = GetList(subgraph_index, signature_def_data.inputs);
1707     auto outputs = GetList(subgraph_index, signature_def_data.outputs);
1708     auto inputs_buf = builder_.CreateVector(inputs);
1709     auto outputs_buf = builder_.CreateVector(outputs);
1710     auto signature_key_buf =
1711         builder_.CreateString(signature_def_data.signature_key);
1712     tflite::SignatureDefBuilder sig_def_builder(builder_);
1713     sig_def_builder.add_inputs(inputs_buf);
1714     sig_def_builder.add_outputs(outputs_buf);
1715     sig_def_builder.add_signature_key(signature_key_buf);
1716     sig_def_builder.add_subgraph_index(signature_def_data.subgraph_index);
1717     signature_defs_buffer.push_back(sig_def_builder.Finish());
1718     ++subgraph_index;
1719   }
1720 
1721   return builder_.CreateVector(signature_defs_buffer);
1722 }
1723 
1724 bool UpdateEntryFunction(ModuleOp module) {
1725   if (module.lookupSymbol<FuncOp>("main") != nullptr) {
1726     // We already have an entry function.
1727     return true;
1728   }
1729 
1730   int entry_func_count = 0;
1731   FuncOp entry_func = nullptr;
1732   for (auto fn : module.getOps<FuncOp>()) {
1733     auto attrs = fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
1734     if (!attrs || attrs.empty()) continue;
1735     ++entry_func_count;
1736     entry_func = fn;
1737   }
1738 
1739   // We should have at least one entry function.
1740   if (entry_func_count == 0) return false;
1741 
1742   if (entry_func_count == 1) {
1743     // Rename the entry function to "main" when it is the one and only entry function.
1744     entry_func.setName("main");
1745   }
1746   return true;
1747 }
1748 
1749 Optional<std::string> Translator::Translate(
1750     ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
1751     bool emit_custom_ops, bool allow_all_select_tf_ops,
1752     const std::unordered_set<std::string>& select_user_tf_ops,
1753     const std::unordered_set<std::string>& tags,
1754     OpOrArgNameMapper* op_or_arg_name_mapper,
1755     const std::map<std::string, std::string>& metadata) {
1756   OpOrArgLocNameMapper default_op_or_arg_name_mapper;
1757   if (!op_or_arg_name_mapper)
1758     op_or_arg_name_mapper = &default_op_or_arg_name_mapper;
1759   if (!UpdateEntryFunction(module)) return llvm::None;
1760   if (!IsValidTFLiteMlirModule(module)) return llvm::None;
1761   Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops,
1762                         emit_custom_ops, allow_all_select_tf_ops,
1763                         select_user_tf_ops, tags, op_or_arg_name_mapper,
1764                         metadata);
1765   return translator.TranslateInternal();
1766 }
1767 
1768 Optional<std::string> Translator::TranslateInternal() {
1769   // A list of named regions in the module, with the main function first in the
1770   // list. The main function is required because the first subgraph in the model
1771   // is the entry point of the model.
1772   std::vector<std::pair<std::string, Region*>> named_regions;
1773   named_regions.reserve(std::distance(module_.begin(), module_.end()));
1774 
1775   int subgraph_idx = 0;
1776 
1777   // Entry functions for signature defs.
1778   std::vector<FuncOp> entry_functions;
1779   std::vector<FuncOp> non_entry_functions;
1780   FuncOp main_fn = module_.lookupSymbol<FuncOp>("main");
1781   if (main_fn != nullptr) {
1782     // Treat the main function as a signature def when the given main function
1783     // carries the tf.entry_function attribute.
1784     auto attrs =
1785         main_fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
1786     if (attrs && !attrs.empty()) {
1787       entry_functions.push_back(main_fn);
1788     } else {
1789       non_entry_functions.push_back(main_fn);
1790     }
1791   }
1792 
1793   // Walk over the module, collecting the remaining functions, e.g. those
1794   module_.walk([&](FuncOp fn) {
1795     if (main_fn == fn) return WalkResult::advance();
1796     auto attrs = fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
1797     if (attrs && !attrs.empty()) {
1798       entry_functions.push_back(fn);
1799     } else {
1800       non_entry_functions.push_back(fn);
1801     }
1802     return WalkResult::advance();
1803   });
1804 
1805   // Assign the subgraph indices, placing the entry functions at the beginning
1806   // of the list of subgraphs.
1807   for (auto fn : entry_functions) {
1808     subgraph_index_map_[fn.getName().str()] = subgraph_idx++;
1809     named_regions.emplace_back(fn.getName().str(), &fn.getBody());
1810   }
1811   for (auto fn : non_entry_functions) {
1812     subgraph_index_map_[fn.getName().str()] = subgraph_idx++;
1813     named_regions.emplace_back(fn.getName().str(), &fn.getBody());
1814   }
1815 
1816   // Build subgraph for each of the named regions.
1817   std::vector<BufferOffset<tflite::SubGraph>> subgraphs;
1818   subgraphs.reserve(named_regions.size());
1819   int first_failed_func = -1;
1820 
1821   // When we export the functions in the module op, we intentionally place the
1822   // entry functions at the beginning of the subgraph list, so subgraph_index is
1823   // both the index into the entry functions and the index into the subgraph
1824   // list.
1825   int subgraph_index = 0;
1826   for (auto it : llvm::enumerate(named_regions)) {
1827     auto subgraph_or =
1828         BuildSubGraph(it.value().first, it.value().second, subgraph_index);
1829     if (!subgraph_or) {
1830       if (first_failed_func == -1)
1831         // Record the index of the first region that cannot be converted.
1832         // Keep looping through all subgraphs in the module to make sure that
1833         // we collect the list of missing ops from the entire module.
1834         first_failed_func = it.index();
1835     } else {
1836       subgraphs.push_back(*subgraph_or);
1837       ++subgraph_index;
1838     }
1839   }
1840 
1841   if (!resource_ops_.empty()) {
1842     std::string resource_ops_summary =
1843         GetOpsSummary(resource_ops_, /*summary_title=*/"Resource");
1844     LOG(WARNING) << "Graph contains the following resource op(s) that use(s) "
1845                     "the resource type. Currently, the "
1846                     "resource type is not natively supported in TFLite. Please "
1847                     "consider not using the resource type if there are issues "
1848                     "with either the TFLite converter or the TFLite runtime:\n"
1849                  << resource_ops_summary;
1850   }
1851 
1852   if (!flex_ops_.empty()) {
1853     std::string flex_ops_summary =
1854         GetOpsSummary(flex_ops_, /*summary_title=*/"Flex");
1855     LOG(WARNING) << "TFLite interpreter needs to link Flex delegate in order "
1856                     "to run the model since it contains the following Select TF "
1857                     "op(s):\n"
1858                  << flex_ops_summary
1859                  << "\nSee instructions: "
1860                     "https://www.tensorflow.org/lite/guide/ops_select";
1861   }
1862 
1863   if (!custom_ops_.empty()) {
1864     std::string custom_ops_summary =
1865         GetOpsSummary(custom_ops_, /*summary_title=*/"Custom");
1866     LOG(WARNING) << "The following operation(s) need TFLite custom op "
1867                     "implementation(s):\n"
1868                  << custom_ops_summary
1869                  << "\nSee instructions: "
1870                     "https://www.tensorflow.org/lite/guide/ops_custom";
1871   }
1872 
1873   if (first_failed_func != -1) {
1874     std::string failed_flex_ops_summary =
1875         GetOpsSummary(failed_flex_ops_, /*summary_title=*/"TF Select");
1876     std::string failed_custom_ops_summary =
1877         GetOpsSummary(failed_custom_ops_, /*summary_title=*/"Custom");
1878     std::string err;
1879     if (!failed_flex_ops_.empty())
1880       err +=
1881           "\nSome ops are not supported by the native TFLite runtime, you can "
1882           "enable TF kernels fallback using TF Select. See instructions: "
1883           "https://www.tensorflow.org/lite/guide/ops_select \n" +
1884           failed_flex_ops_summary + "\n";
1885     if (!failed_custom_ops_.empty())
1886       err +=
1887           "\nSome ops in the model are custom ops, "
1888           "See instructions to implement "
1889           "custom ops: https://www.tensorflow.org/lite/guide/ops_custom \n" +
1890           failed_custom_ops_summary + "\n";
1891 
1892     auto& failed_region = named_regions[first_failed_func];
1893     return failed_region.second->getParentOp()->emitError()
1894                << "failed while converting: '" << failed_region.first
1895                << "': " << err,
1896            llvm::None;
1897   }
1898 
1899   // Log MAC count.
1900   int64_t ops_count;
1901   if (EstimateArithmeticCount(&ops_count)) {
1902     const int64_t million = 1e6;
1903     const int64_t billion = 1e9;
1904     std::string flops_str;
1905     std::string mac_str;
1906     if (ops_count < 10000) {
1907       flops_str = absl::StrFormat("%ld ", ops_count);
1908       mac_str = absl::StrFormat("%ld ", ops_count / 2);
1909     } else if (ops_count < billion) {
1910       flops_str =
1911           absl::StrFormat("%.3f M ", static_cast<double>(ops_count) / million);
1912       mac_str = absl::StrFormat("%.3f M ",
1913                                 static_cast<double>(ops_count / 2) / million);
1914     } else {
1915       flops_str =
1916           absl::StrFormat("%.3f G ", static_cast<double>(ops_count) / billion);
1917       mac_str = absl::StrFormat("%.3f G ",
1918                                 static_cast<double>(ops_count / 2) / billion);
1919     }
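    // Illustrative example (not part of the original source): an ops_count of
    // 2.5e9 would be reported as "2.500 G " ops and "1.250 G " MACs.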
1920     std::string mac_out_str;
1921     llvm::raw_string_ostream os(mac_out_str);
1922     os << "Estimated count of arithmetic ops: " << flops_str
1923        << " ops, equivalently " << mac_str << " MACs"
1924        << "\n";
1925     os.flush();
1926     LOG(INFO) << mac_out_str;
1927     std::cout << mac_out_str;
1928   }
1929 
1930   std::string model_description;
1931   if (auto attr = module_->getAttrOfType<StringAttr>("tfl.description")) {
1932     model_description = attr.getValue().str();
1933   } else {
1934     model_description = "MLIR Converted.";
1935   }
1936 
1937   // Build the model and finish the model building process.
1938   auto description = builder_.CreateString(model_description.data());
1939   VectorBufferOffset<int32_t> metadata_buffer = 0;  // Deprecated
1940   auto metadata = CreateMetadataVector();
1941   if (!metadata) return llvm::None;
1942 
1943   std::vector<SignatureDefData> signature_defs_vec;
1944   subgraph_index = 0;
1945   // Build SignatureDefs for the tf.entry_function based func ops.
1946   for (auto fn : entry_functions) {
1947     auto signature_defs = BuildSignaturedef(
1948         fn, saved_model_tags_.empty() ? "" : *saved_model_tags_.begin(),
1949         subgraph_index, name_mapper_);
1950     for (const auto& signature_def : signature_defs) {
1951       signature_defs_vec.push_back(signature_def);
1952     }
1953     // When we export the functions in the module op, we intentionally place
1954     // the entry functions at the beginning of the subgraph list, so
1955     // subgraph_index is both the index into the entry functions and the index
1956     // into the subgraph list.
1957     ++subgraph_index;
1958   }
1959   auto signature_defs = CreateSignatureDefs(signature_defs_vec);
1960 
1961   auto model = tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION,
1962                                    builder_.CreateVector(opcodes_),
1963                                    builder_.CreateVector(subgraphs),
1964                                    description, builder_.CreateVector(buffers_),
1965                                    metadata_buffer, *metadata, *signature_defs);
1966   tflite::FinishModelBuffer(builder_, model);
1967   tflite::UpdateOpVersion(builder_.GetBufferPointer());
1968   tflite::UpdateMinimumRuntimeVersionForModel(builder_.GetBufferPointer());
1969 
1970   // Return serialized string for the built FlatBuffer.
1971   return std::string(reinterpret_cast<const char*>(builder_.GetBufferPointer()),
1972                      builder_.GetSize());
1973 }
1974 
1975 BufferOffset<tflite::SparsityParameters> Translator::BuildSparsityParameters(
1976     const mlir::TFL::SparsityParameterAttr& s_attr) {
1977   const int dim_size = s_attr.dim_metadata().size();
1978   std::vector<flatbuffers::Offset<tflite::DimensionMetadata>> fb_dim_metadata(
1979       dim_size);
1980   for (int i = 0; i < dim_size; i++) {
1981     const auto dim_metadata =
1982         s_attr.dim_metadata()[i].dyn_cast<mlir::TFL::DimensionMetadataAttr>();
1983     if (dim_metadata.format().getValue() == "DENSE") {
1984       fb_dim_metadata[i] =
1985           tflite::CreateDimensionMetadata(builder_, tflite::DimensionType_DENSE,
1986                                           dim_metadata.dense_size().getInt());
1987 
1988     } else {
1989       auto segments = dim_metadata.segments();
1990       std::vector<int> vector_segments(segments.size(), 0);
1991       for (int j = 0, end = segments.size(); j < end; j++) {
1992         vector_segments[j] = segments[j].dyn_cast<mlir::IntegerAttr>().getInt();
1993       }
1994       tflite::SparseIndexVector segments_type;
1995       BufferOffset<void> array_segments;
1996       // The segment array is sorted.
1997       // TODO(b/147449640): Clean this up with util functions.
1998       int max_of_segments = vector_segments[segments.size() - 1];
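      // Illustrative example (not part of the original source): a largest
      // segment value of 200 fits in a Uint8Vector, 40000 needs a Uint16Vector,
      // and anything larger falls back to an Int32Vector.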
1999       if (max_of_segments <= UINT8_MAX) {
2000         segments_type = tflite::SparseIndexVector_Uint8Vector;
2001         std::vector<uint8_t> uint8_vector(vector_segments.begin(),
2002                                           vector_segments.end());
2003         array_segments = tflite::CreateUint8Vector(
2004                              builder_, builder_.CreateVector(uint8_vector))
2005                              .Union();
2006       } else if (max_of_segments <= UINT16_MAX) {
2007         segments_type = tflite::SparseIndexVector_Uint16Vector;
2008         std::vector<uint16_t> uint16_vector(vector_segments.begin(),
2009                                             vector_segments.end());
2010         array_segments = tflite::CreateUint16Vector(
2011                              builder_, builder_.CreateVector(uint16_vector))
2012                              .Union();
2013       } else {
2014         segments_type = tflite::SparseIndexVector_Int32Vector;
2015         array_segments = tflite::CreateInt32Vector(
2016                              builder_, builder_.CreateVector(vector_segments))
2017                              .Union();
2018       }
2019 
2020       auto indices = dim_metadata.indices();
2021       std::vector<int> vector_indices(indices.size(), 0);
2022       int max_of_indices = 0;
2023       for (int j = 0, end = indices.size(); j < end; j++) {
2024         vector_indices[j] = indices[j].dyn_cast<mlir::IntegerAttr>().getInt();
2025         if (vector_indices[j] > max_of_indices) {
2026           max_of_indices = vector_indices[j];
2027         }
2028       }
2029       tflite::SparseIndexVector indices_type;
2030       BufferOffset<void> array_indices;
2031       if (max_of_indices <= UINT8_MAX) {
2032         indices_type = tflite::SparseIndexVector_Uint8Vector;
2033         std::vector<uint8_t> uint8_vector(vector_indices.begin(),
2034                                           vector_indices.end());
2035         array_indices = tflite::CreateUint8Vector(
2036                             builder_, builder_.CreateVector(uint8_vector))
2037                             .Union();
2038       } else if (max_of_indices <= UINT16_MAX) {
2039         indices_type = tflite::SparseIndexVector_Uint16Vector;
2040         std::vector<uint16_t> uint16_vector(vector_indices.begin(),
2041                                             vector_indices.end());
2042         array_indices = tflite::CreateUint16Vector(
2043                             builder_, builder_.CreateVector(uint16_vector))
2044                             .Union();
2045       } else {
2046         indices_type = tflite::SparseIndexVector_Int32Vector;
2047         array_indices = tflite::CreateInt32Vector(
2048                             builder_, builder_.CreateVector(vector_indices))
2049                             .Union();
2050       }
2051 
2052       fb_dim_metadata[i] = tflite::CreateDimensionMetadata(
2053           builder_, tflite::DimensionType_SPARSE_CSR, 0, segments_type,
2054           array_segments, indices_type, array_indices);
2055     }
2056   }
2057 
2058   std::vector<int> traversal_order(dim_size);
2059   for (int i = 0; i < dim_size; i++) {
2060     traversal_order[i] =
2061         s_attr.traversal_order()[i].dyn_cast<mlir::IntegerAttr>().getInt();
2062   }
2063   const int block_map_size = s_attr.block_map().size();
2064   std::vector<int> block_map(block_map_size);
2065   for (int i = 0; i < block_map_size; i++) {
2066     block_map[i] = s_attr.block_map()[i].dyn_cast<mlir::IntegerAttr>().getInt();
2067   }
2068 
2069   return tflite::CreateSparsityParameters(
2070       builder_, builder_.CreateVector(traversal_order),
2071       builder_.CreateVector(block_map), builder_.CreateVector(fb_dim_metadata));
2072 }
2073 
2074 }  // namespace
2075 
2076 namespace tflite {
2077 // TODO(hinsu): Support all valid MLIR modules in TFLite dialect by supporting
2078 // the following:
2079 //
2080 // * Quantization
2081 // * Ops with variable tensors
2082 //
2083 bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,
2084                                        const FlatbufferExportOptions& options,
2085                                        std::string* serialized_flatbuffer) {
2086   auto maybe_translated = Translator::Translate(
2087       module, options.emit_builtin_tflite_ops, options.emit_select_tf_ops,
2088       options.emit_custom_ops, options.allow_all_select_tf_ops,
2089       options.select_user_tf_ops, options.saved_model_tags,
2090       options.op_or_arg_name_mapper, options.metadata);
2091   if (!maybe_translated) return false;
2092   *serialized_flatbuffer = std::move(*maybe_translated);
2093   return true;
2094 }
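// Illustrative usage sketch (not part of the original source); only option
// fields referenced above are shown, and default-constructibility of
// FlatbufferExportOptions is assumed:
//
//   tflite::FlatbufferExportOptions options;
//   options.emit_builtin_tflite_ops = true;
//   options.emit_select_tf_ops = false;
//   options.emit_custom_ops = false;
//   std::string serialized_model;
//   if (tflite::MlirToFlatBufferTranslateFunction(module, options,
//                                                 &serialized_model)) {
//     // `serialized_model` now holds the finished TFLite FlatBuffer bytes.
//   }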
2095 
2096 }  // namespace tflite
2097