• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
17 
18 #include <stddef.h>
19 #include <stdlib.h>
20 
21 #include <cstdint>
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include <vector>
26 
27 #include "absl/base/attributes.h"
28 #include "absl/container/flat_hash_map.h"
29 #include "absl/container/flat_hash_set.h"
30 #include "absl/strings/match.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_join.h"
33 #include "absl/strings/string_view.h"
34 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
35 #include "flatbuffers/flexbuffers.h"  // from @flatbuffers
36 #include "llvm/ADT/ArrayRef.h"
37 #include "llvm/ADT/DenseMap.h"
38 #include "llvm/ADT/None.h"
39 #include "llvm/ADT/Optional.h"
40 #include "llvm/ADT/StringRef.h"
41 #include "llvm/Support/Casting.h"
42 #include "llvm/Support/CommandLine.h"
43 #include "llvm/Support/FormatVariadic.h"
44 #include "llvm/Support/ToolOutputFile.h"
45 #include "llvm/Support/raw_ostream.h"
46 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
47 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
48 #include "mlir/IR/Attributes.h"  // from @llvm-project
49 #include "mlir/IR/Builders.h"  // from @llvm-project
50 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
51 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
52 #include "mlir/IR/Location.h"  // from @llvm-project
53 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
54 #include "mlir/IR/Operation.h"  // from @llvm-project
55 #include "mlir/IR/Types.h"  // from @llvm-project
56 #include "mlir/IR/Value.h"  // from @llvm-project
57 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
58 #include "mlir/Translation.h"  // from @llvm-project
59 #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
60 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
61 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
62 #include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h"
63 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
64 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
65 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
66 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
67 #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
68 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
69 #include "tensorflow/compiler/xla/statusor.h"
70 #include "tensorflow/core/framework/attr_value.pb.h"
71 #include "tensorflow/core/framework/node_def.pb.h"
72 #include "tensorflow/core/platform/errors.h"
73 #include "tensorflow/core/platform/logging.h"
74 #include "tensorflow/core/platform/status.h"
75 #include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h"
76 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
77 #include "tensorflow/lite/schema/schema_conversion_utils.h"
78 #include "tensorflow/lite/schema/schema_generated.h"
79 #include "tensorflow/lite/string_util.h"
80 #include "tensorflow/lite/tools/versioning/op_version.h"
81 #include "tensorflow/lite/tools/versioning/runtime_version.h"
82 #include "tensorflow/lite/version.h"
83 
84 using llvm::dyn_cast;
85 using llvm::formatv;
86 using llvm::isa;
87 using llvm::Optional;
88 using llvm::StringRef;
89 using llvm::Twine;
90 using mlir::Dialect;
91 using mlir::ElementsAttr;
92 using mlir::FuncOp;
93 using mlir::MLIRContext;
94 using mlir::ModuleOp;
95 using mlir::NoneType;
96 using mlir::Operation;
97 using mlir::Region;
98 using mlir::StringAttr;
99 using mlir::TensorType;
100 using mlir::Type;
101 using mlir::UnknownLoc;
102 using mlir::Value;
103 using tensorflow::OpOrArgLocNameMapper;
104 using tensorflow::OpOrArgNameMapper;
105 using tensorflow::Status;
106 using tflite::flex::IsAllowlistedFlexOp;
107 using xla::StatusOr;
108 
109 template <typename T>
110 using BufferOffset = flatbuffers::Offset<T>;
111 
112 template <typename T>
113 using VectorBufferOffset = flatbuffers::Offset<flatbuffers::Vector<T>>;
114 
115 using CustomOptionsOffset = VectorBufferOffset<uint8_t>;
116 
117 namespace error = tensorflow::error;
118 namespace tfl = mlir::TFL;
119 
120 ABSL_CONST_INIT const absl::string_view kFlexOpNamePrefix = "Flex";
121 
122 // Use initial buffer size in flatbuffer builder to be same as the initial size
123 // used by the TOCO export. (It does not explain rationale for this choice.)
124 constexpr size_t kInitialBufferSize = 10240;
125 
126 // Set `isSigned` to false if the `type` is an 8-bit unsigned integer type.
127 // Since tflite doesn't support unsigned for other types, returns error if
128 // `isSigned` is set to false for other types.
GetTFLiteType(Type type,bool is_signed=true)129 static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
130                                                   bool is_signed = true) {
131   if (!is_signed && type.isSignlessInteger(8)) {
132     return tflite::TensorType_UINT8;
133   }
134   if (!is_signed) {
135     return Status(error::INVALID_ARGUMENT,
136                   "'isSigned' can only be set for 8-bits integer type");
137   }
138 
139   if (type.isF32()) {
140     return tflite::TensorType_FLOAT32;
141   } else if (type.isF16()) {
142     return tflite::TensorType_FLOAT16;
143   } else if (type.isF64()) {
144     return tflite::TensorType_FLOAT64;
145   } else if (type.isa<mlir::TF::StringType>()) {
146     return tflite::TensorType_STRING;
147   } else if (type.isa<mlir::TF::Quint8Type>()) {
148     return tflite::TensorType_UINT8;
149   } else if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
150     auto ftype = complex_type.getElementType();
151     if (ftype.isF32()) {
152       return tflite::TensorType_COMPLEX64;
153     }
154     if (ftype.isF64()) {
155       return tflite::TensorType_COMPLEX128;
156     }
157     return Status(error::INVALID_ARGUMENT, "Unsupported type");
158   } else if (auto itype = type.dyn_cast<mlir::IntegerType>()) {
159     switch (itype.getWidth()) {
160       case 1:
161         return tflite::TensorType_BOOL;
162       case 8:
163         return itype.isUnsigned() ? tflite::TensorType_UINT8
164                                   : tflite::TensorType_INT8;
165       case 16:
166         return tflite::TensorType_INT16;
167       case 32:
168         return itype.isUnsigned() ? tflite::TensorType_UINT32
169                                   : tflite::TensorType_INT32;
170       case 64:
171         return itype.isUnsigned() ? tflite::TensorType_UINT64
172                                   : tflite::TensorType_INT64;
173     }
174   } else if (auto q_uniform_type =
175                  type.dyn_cast<mlir::quant::UniformQuantizedType>()) {
176     return GetTFLiteType(q_uniform_type.getStorageType(),
177                          q_uniform_type.isSigned());
178   } else if (auto q_peraxis_type =
179                  type.dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
180     return GetTFLiteType(q_peraxis_type.getStorageType(),
181                          q_peraxis_type.isSigned());
182   } else if (auto q_calibrated_type =
183                  type.dyn_cast<mlir::quant::CalibratedQuantizedType>()) {
184     return GetTFLiteType(q_calibrated_type.getExpressedType());
185   } else if (type.isa<mlir::TF::ResourceType>()) {
186     return tflite::TensorType_RESOURCE;
187   } else if (type.isa<mlir::TF::VariantType>()) {
188     return tflite::TensorType_VARIANT;
189   }
190   // TFLite export fills FLOAT32 for unknown data types. Returning an error
191   // for now for safety and this could be revisited when required.
192   return Status(error::INVALID_ARGUMENT, "Unsupported type");
193 }
194 
IsConst(Operation * op)195 static bool IsConst(Operation* op) {
196   return isa<mlir::ConstantOp, mlir::TF::ConstOp, tfl::ConstOp, tfl::QConstOp,
197              tfl::SparseConstOp, tfl::SparseQConstOp>(op);
198 }
199 
IsTFResourceOp(Operation * op)200 static bool IsTFResourceOp(Operation* op) {
201   for (const auto& operand : op->getOperands()) {
202     auto elementType = getElementTypeOrSelf(operand.getType());
203     if (elementType.isa<mlir::TF::ResourceType>()) {
204       return true;
205     }
206   }
207   for (const auto& result : op->getResults()) {
208     auto elementType = getElementTypeOrSelf(result.getType());
209     if (elementType.isa<mlir::TF::ResourceType>()) {
210       return true;
211     }
212   }
213   return false;
214 }
215 
216 // Create description of operation that could not be converted.
GetOpDescriptionForDebug(Operation * inst)217 static std::string GetOpDescriptionForDebug(Operation* inst) {
218   const int kLargeElementsAttr = 16;
219   std::string op_str;
220   llvm::raw_string_ostream os(op_str);
221   inst->getName().print(os);
222   // Print out attributes except for large elementsattributes (which should
223   // rarely be the cause why the legalization didn't happen).
224   if (!inst->getAttrDictionary().empty()) {
225     os << " {";
226     bool first = true;
227     for (auto& named_attr : inst->getAttrDictionary()) {
228       os << (!first ? ", " : "");
229       first = false;
230       named_attr.first.print(os);
231       os << " = ";
232       if (auto element_attr = named_attr.second.dyn_cast<ElementsAttr>()) {
233         if (element_attr.getNumElements() <= kLargeElementsAttr) {
234           element_attr.print(os);
235         } else {
236           os << "<large>";
237         }
238       } else {
239         named_attr.second.print(os);
240       }
241     }
242     os << "}";
243   }
244   return os.str();
245 }
246 
247 // Create a summary with the given information regarding op names and
248 // descriptions.
GetOpsSummary(const std::map<std::string,std::set<std::string>> & ops,const std::string & summary_title)249 static std::string GetOpsSummary(
250     const std::map<std::string, std::set<std::string>>& ops,
251     const std::string& summary_title) {
252   std::string op_str;
253   llvm::raw_string_ostream os(op_str);
254 
255   std::vector<std::string> keys;
256   keys.reserve(ops.size());
257 
258   std::vector<std::string> values;
259   values.reserve(ops.size());
260 
261   for (auto const& op_name_and_details : ops) {
262     keys.push_back(op_name_and_details.first);
263     for (auto const& op_detail : op_name_and_details.second) {
264       values.push_back(op_detail);
265     }
266   }
267 
268   os << summary_title << " ops: " << absl::StrJoin(keys, ", ") << "\n";
269   os << "Details:\n\t" << absl::StrJoin(values, "\n\t");
270 
271   return os.str();
272 }
273 
274 template <typename T>
HasValidTFLiteType(Value value,T & error_handler)275 static bool HasValidTFLiteType(Value value, T& error_handler) {
276   // None type is allowed to represent unspecified operands.
277   if (value.getType().isa<NoneType>()) return true;
278 
279   auto type = value.getType().dyn_cast<TensorType>();
280   if (!type) {
281     if (auto op = value.getDefiningOp()) {
282       error_handler.emitError()
283           << '\'' << op << "' should produce value of tensor type instead of "
284           << value.getType();
285       return false;
286     }
287     error_handler.emitError("expected tensor type, got ") << value.getType();
288     return false;
289   }
290 
291   Type element_type = type.getElementType();
292   auto status = GetTFLiteType(element_type);
293   if (!status.ok()) {
294     return error_handler.emitError(
295                formatv("Failed to convert element type '{0}': {1}",
296                        element_type, status.status().error_message())),
297            false;
298   }
299   return true;
300 }
301 
302 // Returns true if the module holds all the invariants expected by the
303 // Translator class.
304 // TODO(hinsu): Now that translation is done by making a single pass over the
305 // MLIR module, consider inlining these validation checks at the place where
306 // these invariants are assumed instead of checking upfront.
IsValidTFLiteMlirModule(ModuleOp module)307 static bool IsValidTFLiteMlirModule(ModuleOp module) {
308   MLIRContext* context = module.getContext();
309 
310   // Verify that module has a function named main.
311   FuncOp main_fn = module.lookupSymbol<FuncOp>("main");
312   if (!main_fn) {
313     return emitError(UnknownLoc::get(context),
314                      "should have a function named 'main'"),
315            false;
316   }
317 
318   for (auto fn : module.getOps<FuncOp>()) {
319     if (!llvm::hasSingleElement(fn)) {
320       return fn.emitError("should have exactly one basic block"), false;
321     }
322     auto& bb = fn.front();
323 
324     for (auto arg : bb.getArguments()) {
325       if (!HasValidTFLiteType(arg, fn)) {
326         auto elementType = getElementTypeOrSelf(arg.getType());
327         if (elementType.isa<mlir::TF::VariantType>()) {
328           return fn.emitError(
329                      "function argument uses variant type. Currently, the "
330                      "variant type is not natively supported in TFLite. Please "
331                      "consider not using the variant type: ")
332                      << arg.getType(),
333                  false;
334         }
335         return fn.emitError("invalid TFLite type: ") << arg.getType(), false;
336       }
337     }
338 
339     // Verify that all operations except the terminator have exactly one
340     // result of type supported by TFLite.
341     for (auto& inst : bb) {
342       if (inst.hasTrait<mlir::OpTrait::IsTerminator>()) break;
343 
344       for (auto result : inst.getResults()) {
345         if (!HasValidTFLiteType(result, inst)) {
346           auto elementType = getElementTypeOrSelf(result.getType());
347           if (elementType.isa<mlir::TF::VariantType>()) {
348             return inst.emitError(
349                        "operand result uses variant type. Currently, the "
350                        "variant type is not natively supported in TFLite. "
351                        "Please "
352                        "consider not using the variant type: ")
353                        << result.getType(),
354                    false;
355           }
356           return fn.emitError("invalid TFLite type: ") << result.getType(),
357                  false;
358         }
359       }
360     }
361   }
362 
363   return true;
364 }
365 
GetTensorFlowNodeDef(::mlir::Operation * inst)366 static std::unique_ptr<::tensorflow::NodeDef> GetTensorFlowNodeDef(
367     ::mlir::Operation* inst) {
368   // We pass empty string for the original node_def name since Flex runtime
369   // does not care about this being set correctly on node_def. There is no
370   // "easy" (see b/120948529) way yet to get this from MLIR inst.
371   auto status_or_node_def = tensorflow::ConvertTFDialectOpToNodeDef(
372       inst, /*name=*/"", /*ignore_unregistered_attrs=*/true);
373   if (!status_or_node_def.ok()) {
374     inst->emitOpError(
375         Twine("failed to obtain TensorFlow nodedef with status: " +
376               status_or_node_def.status().ToString()));
377     return {};
378   }
379   return std::move(status_or_node_def.ValueOrDie());
380 }
381 
382 // Converts a mlir padding StringRef to TfLitePadding.
383 // Returns llvm::None if conversion fails.
GetTflitePadding(Operation * inst,llvm::StringRef padding)384 static Optional<TfLitePadding> GetTflitePadding(Operation* inst,
385                                                 llvm::StringRef padding) {
386   const tflite::Padding padding_attr =
387       std::move(llvm::StringSwitch<tflite::Padding>(padding)
388                     .Case("SAME", tflite::Padding_SAME)
389                     .Case("VALID", tflite::Padding_VALID));
390   if (padding_attr == tflite::Padding_SAME) {
391     return kTfLitePaddingSame;
392   }
393   if (padding_attr == tflite::Padding_VALID) {
394     return kTfLitePaddingValid;
395   }
396 
397   return inst->emitOpError() << "Invalid padding attribute: " << padding,
398          llvm::None;
399 }
400 
401 // Extracts TfLitePoolParams from a TFL custom op.
402 // Template parameter, TFLOp, should be a TFL custom op containing attributes
403 // generated from TfLitePoolParams.
404 // Returns llvm::None if conversion fails.
405 template <typename TFLOp>
GetTflitePoolParams(Operation * inst,TFLOp op)406 static Optional<TfLitePoolParams> GetTflitePoolParams(Operation* inst,
407                                                       TFLOp op) {
408   TfLitePoolParams pool_params;
409   pool_params.stride_height = op.stride_h().getSExtValue();
410   pool_params.stride_width = op.stride_w().getSExtValue();
411   pool_params.filter_height = op.filter_h().getSExtValue();
412   pool_params.filter_width = op.filter_w().getSExtValue();
413   const auto padding = GetTflitePadding(inst, op.padding());
414   if (padding) {
415     pool_params.padding = *padding;
416     pool_params.activation = kTfLiteActNone;
417     pool_params.computed.padding = TfLitePaddingValues{0, 0, 0, 0};
418     return pool_params;
419   }
420 
421   return llvm::None;
422 }
423 
424 namespace {
425 
// Helper struct that wraps inputs/outputs of a single SignatureDef.
struct SignatureDefData {
  // std::map (rather than an unordered container) keeps iteration order
  // deterministic, which makes testing straightforward.

  // Inputs defined in the signature def, mapped to tensor names.
  std::map<std::string, std::string> inputs;
  // Outputs defined in the signature def, mapped to tensor names.
  std::map<std::string, std::string> outputs;
  // Method name exported by the signature def.
  std::string method_name;
  // SignatureDef key.
  std::string signature_def_key;
};
440 
441 // Translates an MLIR module in TFLite dialect to TFLite FlatBuffer.
442 class Translator {
443  public:
444   // Translates the given MLIR module into TFLite FlatBuffer format and returns
445   // the serialized output. Returns llvm::None on unsupported, invalid inputs or
446   // internal error.
447   static Optional<std::string> Translate(
448       ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
449       bool emit_custom_ops,
450       const std::unordered_set<std::string>& select_user_tf_ops,
451       const std::unordered_set<std::string>& tags,
452       OpOrArgNameMapper* op_or_arg_name_mapper);
453 
454  private:
455   enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp };
Translator(ModuleOp module,bool emit_builtin_tflite_ops,bool emit_select_tf_ops,bool emit_custom_ops,const std::unordered_set<std::string> & select_user_tf_ops,const std::unordered_set<std::string> & saved_model_tags,OpOrArgNameMapper * op_or_arg_name_mapper)456   explicit Translator(ModuleOp module, bool emit_builtin_tflite_ops,
457                       bool emit_select_tf_ops, bool emit_custom_ops,
458                       const std::unordered_set<std::string>& select_user_tf_ops,
459                       const std::unordered_set<std::string>& saved_model_tags,
460                       OpOrArgNameMapper* op_or_arg_name_mapper)
461       : module_(module),
462         name_mapper_(*op_or_arg_name_mapper),
463         builder_(kInitialBufferSize),
464         saved_model_tags_(saved_model_tags),
465         select_user_tf_ops_(select_user_tf_ops) {
466     // The first buffer must be empty according to the schema definition.
467     empty_buffer_ = tflite::CreateBuffer(builder_);
468     buffers_.push_back(empty_buffer_);
469     if (emit_builtin_tflite_ops) {
470       enabled_op_types_.emplace(OpType::kTfliteBuiltin);
471     }
472     if (emit_select_tf_ops) {
473       enabled_op_types_.emplace(OpType::kSelectTf);
474     }
475     if (emit_custom_ops) {
476       enabled_op_types_.emplace(OpType::kCustomOp);
477     }
478     tf_dialect_ =
479         module.getContext()->getOrLoadDialect<mlir::TF::TensorFlowDialect>();
480     tfl_dialect_ = module.getContext()
481                        ->getOrLoadDialect<mlir::TFL::TensorFlowLiteDialect>();
482     // Right now the TF executor dialect is still needed to build NodeDef.
483     module.getContext()
484         ->getOrLoadDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
485   }
486 
487   Optional<std::string> TranslateInternal();
488 
489   // Returns TFLite buffer populated with constant value if the operation is
490   // TFLite constant operation. Otherwise, returns an empty buffer. Emits error
491   // and returns llvm::None on failure.
492   Optional<BufferOffset<tflite::Buffer>> BuildBuffer(Operation* inst);
493 
494   // Build TFLite tensor from the given type. This function is for tfl.lstm
495   // intermediates, which should have UniformQuantizedType.
496   Optional<BufferOffset<tflite::Tensor>> BuildTensorFromType(
497       mlir::Type type, const std::string& name);
498 
499   // Builds TFLite tensor from the given value. `buffer_idx` is index of the
500   // corresponding buffer. Emits error and returns llvm::None on failure.
501   Optional<BufferOffset<tflite::Tensor>> BuildTensor(
502       Value value, const std::string& name, unsigned buffer_idx,
503       const Optional<BufferOffset<tflite::QuantizationParameters>>&
504           quant_parameters);
505 
506   // TODO(b/137395003): Legalize control flow ops to TFLite dialect, and remove
507   // these 2 functions here.
508   BufferOffset<tflite::Operator> BuildIfOperator(
509       mlir::TF::IfOp op, const std::vector<int32_t>& operands,
510       const std::vector<int32_t>& results);
511   BufferOffset<tflite::Operator> BuildWhileOperator(
512       mlir::TF::WhileOp op, const std::vector<int32_t>& operands,
513       const std::vector<int32_t>& results);
514 
515   // Build while operator where cond & body are regions.
516   Optional<BufferOffset<tflite::Operator>> BuildWhileOperator(
517       mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
518       const std::vector<int32_t>& results);
519 
520   // Build call once operator.
521   BufferOffset<tflite::Operator> BuildCallOnceOperator(
522       mlir::TFL::CallOnceOp op, const std::vector<int32_t>& operands,
523       const std::vector<int32_t>& results);
524 
525   BufferOffset<tflite::Operator> BuildNumericVerifyOperator(
526       mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
527       const std::vector<int32_t>& results);
528 
529   // Builds Assign/Read Variable ops.
530   template <typename T>
531   BufferOffset<tflite::Operator> BuildVariableOperator(
532       T op, const std::string& op_name, const std::vector<int32_t>& operands,
533       const std::vector<int32_t>& results);
534 
535   BufferOffset<tflite::Operator> BuildCustomOperator(
536       Operation* inst, mlir::TFL::CustomOp op,
537       const std::vector<int32_t>& operands,
538       const std::vector<int32_t>& results);
539 
540   Optional<CustomOptionsOffset> CreateFlexOpCustomOptions(
541       const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
542 
543   Optional<CustomOptionsOffset> CreateCustomOpCustomOptions(
544       const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
545 
546   std::unique_ptr<flexbuffers::Builder> CreateFlexBuilderWithNodeAttrs(
547       const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
548 
549   // Returns opcode index for op identified by the op_name, if already
550   // available. Otherwise, creates a new OperatorCode using the given `builtin`
551   // operator and associates it with `op_name`.
552   uint32_t GetOpcodeIndex(const std::string& op_name,
553                           tflite::BuiltinOperator builtin);
554 
555   // Builds operator for the given operation with specified operand and result
556   // tensor indices. Emits an error and returns llvm::None on failure.
557   Optional<BufferOffset<tflite::Operator>> BuildOperator(
558       Operation* inst, std::vector<int32_t> operands,
559       const std::vector<int32_t>& results,
560       const std::vector<int32_t>& intermediates);
561 
562   // Returns the quantization parameters for output value of "quant.stats" op.
563   BufferOffset<tflite::QuantizationParameters>
564   GetQuantizationForQuantStatsOpOutput(mlir::quant::StatisticsOp stats_op);
565 
566   // Build a subgraph with a given name out of the region either corresponding
567   // to a function's body or while op.
568   Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(
569       const std::string& name, Region* region);
570 
571   // Builds Metadata with the given `name` and buffer `content`.
572   BufferOffset<tflite::Metadata> BuildMetadata(StringRef name,
573                                                StringRef content);
574 
575   // Encodes the `tfl.metadata` dictionary attribute of the module to the
576   // metadata section in the final model.
577   Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
578   CreateMetadataVector();
579 
580   // Builds and returns list of tfl.SignatureDef sections in the model.
581   Optional<VectorBufferOffset<BufferOffset<tflite::SignatureDef>>>
582   CreateSignatureDefs(const std::vector<SignatureDefData>& signature_defs);
583 
584   // Returns list of offsets for the passed 'items' in TensorMap structure
585   // inside the flatbuffer.
586   // 'items' is a map from tensor name in signatureDef to tensor name in
587   // the model.
588   std::vector<BufferOffset<tflite::TensorMap>> GetList(
589       const std::map<std::string, std::string>& items);
590 
591   // Uses the tf.entry_function attribute (if set) to initialize the op to name
592   // mapping.
593   void InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr);
594 
595   // Determines if the specified operation op's operand at operand_index
596   // is marked as a stateful operand.
597   bool IsStatefulOperand(mlir::Operation* op, int operand_index);
598 
599   // Returns a unique name for `val`.
600   std::string UniqueName(mlir::Value val);
601 
602   BufferOffset<tflite::SparsityParameters> BuildSparsityParameters(
603       const mlir::TFL::SparsityParameterAttr& s_attr);
604 
605   ModuleOp module_;
606 
607   tensorflow::OpOrArgNameMapper& name_mapper_;
608 
609   flatbuffers::FlatBufferBuilder builder_;
610   BufferOffset<tflite::Buffer> empty_buffer_;
611 
612   std::vector<BufferOffset<tflite::Buffer>> buffers_;
613   // Maps tensor name in the graph to the tensor index.
614   absl::flat_hash_map<std::string, int> tensor_index_map_;
615 
616   // Maps op name to index of the corresponding OperatorCode in opcodes_ vector.
617   absl::flat_hash_map<std::string, uint32_t> opcode_index_map_;
618   std::vector<BufferOffset<tflite::OperatorCode>> opcodes_;
619 
620   // Maps function name to index of the corresponding subgraph in the FlatBuffer
621   // model.
622   absl::flat_hash_map<std::string, int> subgraph_index_map_;
623   absl::flat_hash_set<OpType> enabled_op_types_;
624 
625   // Points to TensorFlow and TFLite dialects, respectively. nullptr if the
626   // dialect is not registered.
627   const Dialect* tf_dialect_;
628   const Dialect* tfl_dialect_;
629 
630   // The failed ops during legalization.
631   std::map<std::string, std::set<std::string>> failed_flex_ops_;
632   std::map<std::string, std::set<std::string>> failed_custom_ops_;
633 
634   // Ops to provide warning messages.
635   std::map<std::string, std::set<std::string>> custom_ops_;
636   std::map<std::string, std::set<std::string>> flex_ops_;
637 
638   // Resource ops to provide warning messages.
639   std::map<std::string, std::set<std::string>> resource_ops_;
640 
641   // Set of saved model tags, if any.
642   const std::unordered_set<std::string> saved_model_tags_;
643   // User's defined ops allowed with Flex.
644   const std::unordered_set<std::string> select_user_tf_ops_;
645 };
646 
UniqueName(mlir::Value val)647 std::string Translator::UniqueName(mlir::Value val) {
648   return std::string(name_mapper_.GetUniqueName(val));
649 }
650 
BuildBuffer(Operation * inst)651 Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
652     Operation* inst) {
653   ElementsAttr attr;
654   if (auto cst = dyn_cast<mlir::ConstantOp>(inst)) {
655     // ConstantOp have ElementAttr at this point due to validation of the TFLite
656     // module.
657     attr = cst.getValue().cast<ElementsAttr>();
658   } else if (auto cst = dyn_cast<mlir::TF::ConstOp>(inst)) {
659     attr = cst.value();
660   } else if (auto cst = dyn_cast<tfl::ConstOp>(inst)) {
661     attr = cst.value();
662   } else if (auto cst = dyn_cast<tfl::QConstOp>(inst)) {
663     attr = cst.value();
664   } else if (auto cst = dyn_cast<tfl::SparseConstOp>(inst)) {
665     attr = cst.compressed_data();
666   } else if (auto cst = dyn_cast<tfl::SparseQConstOp>(inst)) {
667     attr = cst.compressed_data();
668   } else {
669     return empty_buffer_;
670   }
671 
672   tensorflow::Tensor tensor;
673   auto status = tensorflow::ConvertToTensor(attr, &tensor);
674   if (!status.ok()) {
675     inst->emitError(
676         Twine("failed to convert value attribute to tensor with error: " +
677               status.ToString()));
678     return llvm::None;
679   }
680 
681   // TensorFlow and TensorFlow Lite use different string encoding formats.
682   // Convert to TensorFlow Lite format is it's a constant string tensor.
683   if (tensor.dtype() == tensorflow::DT_STRING) {
684     ::tflite::DynamicBuffer dynamic_buffer;
685     auto flat = tensor.flat<::tensorflow::tstring>();
686     for (int i = 0; i < flat.size(); ++i) {
687       const auto& str = flat(i);
688       dynamic_buffer.AddString(str.c_str(), str.length());
689     }
690     char* tensor_buffer;
691     int bytes = dynamic_buffer.WriteToBuffer(&tensor_buffer);
692     auto buffer_data =
693         builder_.CreateVector(reinterpret_cast<uint8_t*>(tensor_buffer), bytes);
694     free(tensor_buffer);
695     return tflite::CreateBuffer(builder_, buffer_data);
696   }
697 
698   absl::string_view tensor_data = tensor.tensor_data();
699   auto buffer_data = builder_.CreateVector(
700       reinterpret_cast<const uint8_t*>(tensor_data.data()), tensor_data.size());
701   return tflite::CreateBuffer(builder_, buffer_data);
702 }
703 
BuildTensorFromType(mlir::Type type,const std::string & name)704 Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensorFromType(
705     mlir::Type type, const std::string& name) {
706   auto tensor_type = type.cast<TensorType>();
707 
708   if (!tensor_type.hasStaticShape()) {
709     return llvm::None;
710   }
711   llvm::ArrayRef<int64_t> shape_ref = tensor_type.getShape();
712   std::vector<int32_t> shape(shape_ref.begin(), shape_ref.end());
713 
714   auto element_type = tensor_type.getElementType();
715   tflite::TensorType tflite_element_type =
716       GetTFLiteType(tensor_type.getElementType()).ValueOrDie();
717   BufferOffset<tflite::QuantizationParameters> q_params = 0;
718   if (auto qtype = element_type.dyn_cast<mlir::quant::UniformQuantizedType>()) {
719     q_params = tflite::CreateQuantizationParameters(
720         builder_, /*min=*/0, /*max=*/0,
721         builder_.CreateVector<float>({static_cast<float>(qtype.getScale())}),
722         builder_.CreateVector<int64_t>({qtype.getZeroPoint()}));
723   } else if (auto qtype =
724                  element_type
725                      .dyn_cast<mlir::quant::CalibratedQuantizedType>()) {
726     q_params = tflite::CreateQuantizationParameters(
727         builder_,
728         builder_.CreateVector<float>({static_cast<float>(qtype.getMin())}),
729         builder_.CreateVector<float>({static_cast<float>(qtype.getMax())}));
730   }
731   return tflite::CreateTensor(
732       builder_, builder_.CreateVector(shape), tflite_element_type,
733       /*buffer=*/0, builder_.CreateString(name), q_params,
734       /*is_variable=*/false);
735 }
736 
// Builds a TFLite tensor for `value`: derives its shape (static, from a
// const attribute, or ranked-dynamic with a shape_signature), element type,
// quantization parameters, sparsity parameters, and is_variable flag.
// Returns llvm::None after emitting an error when a dimension does not fit
// in the 32-bit range TFLite requires.
Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
    Value value, const std::string& name, unsigned buffer_idx,
    const Optional<BufferOffset<tflite::QuantizationParameters>>&
        quant_parameters) {
  auto type = value.getType().cast<TensorType>();

  // TFLite requires tensor shape only for the inputs and constants.
  // However, we output all known shapes for better round-tripping
  auto check_shape =
      [&](llvm::ArrayRef<int64_t> shape_ref) -> mlir::LogicalResult {
    // TFLite serializes dimensions as int32, so larger dims cannot be
    // represented.
    auto is_out_of_range = [](int64_t dim) {
      return dim > std::numeric_limits<int32_t>::max();
    };

    if (std::any_of(shape_ref.begin(), shape_ref.end(), is_out_of_range))
      return mlir::emitError(
          value.getLoc(),
          "result shape dimensions out of 32 bit int type range");

    return mlir::success();
  };

  std::vector<int32_t> shape;
  std::vector<int32_t> shape_signature;
  auto* inst = value.getDefiningOp();
  if (type.hasStaticShape()) {
    llvm::ArrayRef<int64_t> shape_ref = type.getShape();
    if (mlir::failed(check_shape(shape_ref))) return llvm::None;

    shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
  } else if (inst && IsConst(inst)) {
    // Const op can have a result of dynamic shaped type (e.g. due to constant
    // folding), but we can still derive the shape of a constant tensor for
    // its attribute type.
    mlir::Attribute tensor_attr = inst->getAttr("value");
    llvm::ArrayRef<int64_t> shape_ref =
        tensor_attr.getType().cast<TensorType>().getShape();
    if (mlir::failed(check_shape(shape_ref))) return llvm::None;

    shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
  } else if (type.hasRank()) {
    // Ranked but dynamic: store 1 for each unknown (-1) dim in `shape` and
    // keep the true dims (including -1) in `shape_signature`.
    llvm::ArrayRef<int64_t> shape_ref = type.getShape();
    if (mlir::failed(check_shape(shape_ref))) return llvm::None;

    shape.reserve(shape_ref.size());
    for (auto& dim : shape_ref) {
      shape.push_back(dim == -1 ? 1 : dim);
    }
    shape_signature = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
  }

  // Sparse constants carry their sparsity encoding in the s_param attribute.
  BufferOffset<tflite::SparsityParameters> s_params = 0;
  if (auto* inst = value.getDefiningOp()) {
    if (auto cst = dyn_cast<tfl::SparseConstOp>(inst)) {
      s_params = BuildSparsityParameters(cst.s_param());
    } else if (auto cst = dyn_cast<tfl::SparseQConstOp>(inst)) {
      s_params = BuildSparsityParameters(cst.s_param());
    }
  }

  Type element_type = type.getElementType();
  tflite::TensorType tflite_element_type =
      GetTFLiteType(type.getElementType()).ValueOrDie();

  // Choose quantization parameters in priority order: per-tensor uniform,
  // per-axis uniform, caller-supplied (from a quant.stats user), else empty.
  BufferOffset<tflite::QuantizationParameters> q_params;
  if (auto qtype = element_type.dyn_cast<mlir::quant::UniformQuantizedType>()) {
    q_params = tflite::CreateQuantizationParameters(
        // TODO(fengliuai): min and max values are not stored in the
        // quantized type, so both are set to 0. The model couldn't be imported
        // to TensorFlow because of this.
        builder_, /*min=*/0, /*max=*/0,
        builder_.CreateVector<float>({static_cast<float>(qtype.getScale())}),
        builder_.CreateVector<int64_t>({qtype.getZeroPoint()}));
  } else if (auto qtype =
                 element_type
                     .dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
    std::vector<float> scales(qtype.getScales().begin(),
                              qtype.getScales().end());
    q_params = tflite::CreateQuantizationParameters(
        builder_, /*min=*/0, /*max=*/0, builder_.CreateVector<float>(scales),
        builder_.CreateVector<int64_t>(qtype.getZeroPoints()),
        tflite::QuantizationDetails_NONE, /*details=*/0,
        qtype.getQuantizedDimension());
  } else if (quant_parameters.hasValue()) {
    q_params = quant_parameters.getValue();
  } else {
    q_params = tflite::CreateQuantizationParameters(builder_);
  }
  // Check if the value's uses includes an op and usage at an operand index
  // marked as a stateful. If so, set the tensor's is_variable as true
  // This is v1 ref variable semantics in the TFLite runtime.
  bool is_variable = false;
  for (auto& use : value.getUses()) {
    is_variable = IsStatefulOperand(use.getOwner(), use.getOperandNumber());
    if (is_variable) {
      break;
    }
  }

  // Variable tensors point at buffer 0 (no constant data). shape_signature
  // is serialized only when the shape contains dynamic dimensions.
  if (shape_signature.empty()) {
    return tflite::CreateTensor(
        builder_, builder_.CreateVector(shape), tflite_element_type,
        (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
        /*is_variable=*/is_variable, s_params);
  } else {
    return tflite::CreateTensor(
        builder_, builder_.CreateVector(shape), tflite_element_type,
        (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
        /*is_variable=*/is_variable, s_params,
        /*shape_signature=*/builder_.CreateVector(shape_signature));
  }
}
849 
BuildIfOperator(mlir::TF::IfOp op,const std::vector<int32_t> & operands,const std::vector<int32_t> & results)850 BufferOffset<tflite::Operator> Translator::BuildIfOperator(
851     mlir::TF::IfOp op, const std::vector<int32_t>& operands,
852     const std::vector<int32_t>& results) {
853   auto opcode_index = GetOpcodeIndex("if", tflite::BuiltinOperator_IF);
854   int then_subgraph_index = subgraph_index_map_.at(op.then_branch().str());
855   int else_subgraph_index = subgraph_index_map_.at(op.else_branch().str());
856   auto builtin_options = tflite::CreateIfOptions(builder_, then_subgraph_index,
857                                                  else_subgraph_index)
858                              .Union();
859   auto inputs = builder_.CreateVector(operands);
860   auto outputs = builder_.CreateVector(results);
861   return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
862                                 tflite::BuiltinOptions_IfOptions,
863                                 builtin_options);
864 }
865 
BuildCallOnceOperator(mlir::TFL::CallOnceOp op,const std::vector<int32_t> & operands,const std::vector<int32_t> & results)866 BufferOffset<tflite::Operator> Translator::BuildCallOnceOperator(
867     mlir::TFL::CallOnceOp op, const std::vector<int32_t>& operands,
868     const std::vector<int32_t>& results) {
869   auto opcode_index =
870       GetOpcodeIndex("call_once", tflite::BuiltinOperator_CALL_ONCE);
871   int init_subgraph_index =
872       subgraph_index_map_.at(op.session_init_function().str());
873   auto builtin_options =
874       tflite::CreateCallOnceOptions(builder_, init_subgraph_index).Union();
875   auto inputs = builder_.CreateVector(operands);
876   auto outputs = builder_.CreateVector(results);
877   return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
878                                 tflite::BuiltinOptions_CallOnceOptions,
879                                 builtin_options);
880 }
881 
BuildWhileOperator(mlir::TF::WhileOp op,const std::vector<int32_t> & operands,const std::vector<int32_t> & results)882 BufferOffset<tflite::Operator> Translator::BuildWhileOperator(
883     mlir::TF::WhileOp op, const std::vector<int32_t>& operands,
884     const std::vector<int32_t>& results) {
885   auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE);
886   int cond_subgraph_index = subgraph_index_map_.at(op.cond().str());
887   int body_subgraph_index = subgraph_index_map_.at(op.body().str());
888   auto builtin_options = tflite::CreateWhileOptions(
889                              builder_, cond_subgraph_index, body_subgraph_index)
890                              .Union();
891   auto inputs = builder_.CreateVector(operands);
892   auto outputs = builder_.CreateVector(results);
893   return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
894                                 tflite::BuiltinOptions_WhileOptions,
895                                 builtin_options);
896 }
897 
// Serializes a tfl.while op as a TFLite WHILE operator. Each region
// (cond/body) must consist of exactly one call op plus its terminator so it
// can be mapped to an already-exported subgraph; otherwise an error is
// emitted and llvm::None is returned.
Optional<BufferOffset<tflite::Operator>> Translator::BuildWhileOperator(
    mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
    const std::vector<int32_t>& results) {
  auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE);
  // Returns the subgraph index of the function called from block `b`, or
  // llvm::None when the block is not exactly {call, terminator}.
  auto get_call_index = [&](mlir::Block& b) -> Optional<int> {
    if (b.getOperations().size() != 2) return llvm::None;
    if (auto call_op = dyn_cast<mlir::CallOp>(b.front()))
      return subgraph_index_map_.at(call_op.callee().str());
    return llvm::None;
  };
  auto body_subgraph_index = get_call_index(op.body().front());
  auto cond_subgraph_index = get_call_index(op.cond().front());
  // Comma operator: emit the error diagnostic, then return llvm::None.
  if (!body_subgraph_index || !cond_subgraph_index)
    return op.emitOpError("only single call cond/body while export supported"),
           llvm::None;
  auto builtin_options =
      tflite::CreateWhileOptions(builder_, *cond_subgraph_index,
                                 *body_subgraph_index)
          .Union();
  auto inputs = builder_.CreateVector(operands);
  auto outputs = builder_.CreateVector(results);
  return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
                                tflite::BuiltinOptions_WhileOptions,
                                builtin_options);
}
923 
BuildNumericVerifyOperator(mlir::TFL::NumericVerifyOp op,const std::vector<int32_t> & operands,const std::vector<int32_t> & results)924 BufferOffset<tflite::Operator> Translator::BuildNumericVerifyOperator(
925     mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
926     const std::vector<int32_t>& results) {
927   float tolerance = op.tolerance().convertToFloat();
928   bool log_if_failed = op.log_if_failed();
929   auto fbb = absl::make_unique<flexbuffers::Builder>();
930   fbb->Map([&]() {
931     fbb->Float("tolerance", tolerance);
932     fbb->Bool("log_if_failed", log_if_failed);
933   });
934   fbb->Finish();
935   auto f = std::unique_ptr<flexbuffers::Builder>(fbb.release());
936   auto custom_option = f->GetBuffer();
937   auto opcode_index =
938       GetOpcodeIndex("NumericVerify", tflite::BuiltinOperator_CUSTOM);
939   return tflite::CreateOperator(
940       builder_, opcode_index, builder_.CreateVector(operands),
941       builder_.CreateVector(results), tflite::BuiltinOptions_NONE,
942       /*builtin_options=*/0, builder_.CreateVector<uint8_t>(custom_option),
943       tflite::CustomOptionsFormat_FLEXBUFFERS);
944 }
945 
946 // Builds Assign/Read Variable ops.
947 template <typename T>
BuildVariableOperator(T op,const std::string & op_name,const std::vector<int32_t> & operands,const std::vector<int32_t> & results)948 BufferOffset<tflite::Operator> Translator::BuildVariableOperator(
949     T op, const std::string& op_name, const std::vector<int32_t>& operands,
950     const std::vector<int32_t>& results) {
951   auto opcode_index = GetOpcodeIndex(op_name, tflite::BuiltinOperator_CUSTOM);
952   return tflite::CreateOperator(
953       builder_, opcode_index, builder_.CreateVector(operands),
954       builder_.CreateVector(results), tflite::BuiltinOptions_NONE);
955 }
956 
BuildCustomOperator(Operation * inst,mlir::TFL::CustomOp op,const std::vector<int32_t> & operands,const std::vector<int32_t> & results)957 BufferOffset<tflite::Operator> Translator::BuildCustomOperator(
958     Operation* inst, mlir::TFL::CustomOp op,
959     const std::vector<int32_t>& operands, const std::vector<int32_t>& results) {
960   const std::string attrs =
961       op.custom_option().cast<mlir::OpaqueElementsAttr>().getValue().str();
962   std::vector<uint8_t> custom_option_vector(attrs.size());
963   memcpy(custom_option_vector.data(), attrs.data(), attrs.size());
964   auto opcode_index =
965       GetOpcodeIndex(op.custom_code().str(), tflite::BuiltinOperator_CUSTOM);
966   return tflite::CreateOperator(
967       builder_, opcode_index, builder_.CreateVector(operands),
968       builder_.CreateVector(results), tflite::BuiltinOptions_NONE,
969       /*builtin_options=*/0,
970       builder_.CreateVector<uint8_t>(custom_option_vector),
971       tflite::CustomOptionsFormat_FLEXBUFFERS);
972 }
973 
CreateFlexOpCustomOptions(const::tensorflow::NodeDef & node_def,const mlir::Location & loc)974 Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions(
975     const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
976   std::string node_def_str;
977   if (!node_def.SerializeToString(&node_def_str)) {
978     return emitError(loc, "failed to serialize tensorflow node_def"),
979            llvm::None;
980   }
981 
982   auto flex_builder = absl::make_unique<flexbuffers::Builder>();
983   flex_builder->Vector([&]() {
984     flex_builder->String(node_def.op());
985     flex_builder->String(node_def_str);
986   });
987   flex_builder->Finish();
988   return builder_.CreateVector(flex_builder->GetBuffer());
989 }
990 
CreateCustomOpCustomOptions(const::tensorflow::NodeDef & node_def,const mlir::Location & loc)991 Optional<CustomOptionsOffset> Translator::CreateCustomOpCustomOptions(
992     const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
993   auto flex_builder = CreateFlexBuilderWithNodeAttrs(node_def, loc);
994   return builder_.CreateVector(flex_builder->GetBuffer());
995 }
996 
997 std::unique_ptr<flexbuffers::Builder>
CreateFlexBuilderWithNodeAttrs(const::tensorflow::NodeDef & node_def,const mlir::Location & loc)998 Translator::CreateFlexBuilderWithNodeAttrs(
999     const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
1000   auto flex_builder = absl::make_unique<flexbuffers::Builder>();
1001   size_t map_start = flex_builder->StartMap();
1002   using Item = std::pair<std::string, ::tensorflow::AttrValue>;
1003   std::vector<Item> attrs(node_def.attr().begin(), node_def.attr().end());
1004   std::sort(attrs.begin(), attrs.end(),
1005             [](Item& p1, Item& p2) -> bool { return p1.first < p2.first; });
1006   for (const Item& pair : attrs) {
1007     const char* key = pair.first.c_str();
1008     const ::tensorflow::AttrValue& attr = pair.second;
1009     switch (attr.value_case()) {
1010       case ::tensorflow::AttrValue::kS:
1011         flex_builder->String(key, attr.s());
1012         break;
1013       case ::tensorflow::AttrValue::kType: {
1014         auto status_or_tfl_type = tflite::TfTypeToTflType(attr.type());
1015         if (status_or_tfl_type.ok()) {
1016           flex_builder->Int(key, status_or_tfl_type.ValueOrDie());
1017         } else {
1018           emitWarning(loc, "ignoring unsupported tensorflow type: ")
1019               << std::to_string(attr.type());
1020         }
1021         break;
1022       }
1023       case ::tensorflow::AttrValue::kI:
1024         flex_builder->Int(key, attr.i());
1025         break;
1026       case ::tensorflow::AttrValue::kF:
1027         flex_builder->Float(key, attr.f());
1028         break;
1029       case ::tensorflow::AttrValue::kB:
1030         flex_builder->Bool(key, attr.b());
1031         break;
1032       case tensorflow::AttrValue::kList:
1033         if (attr.list().s_size() > 0) {
1034           auto start = flex_builder->StartVector(key);
1035           for (const std::string& v : attr.list().s()) {
1036             flex_builder->Add(v);
1037           }
1038           flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false);
1039         } else if (attr.list().i_size() > 0) {
1040           auto start = flex_builder->StartVector(key);
1041           for (const int64_t v : attr.list().i()) {
1042             flex_builder->Add(v);
1043           }
1044           flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false);
1045         } else if (attr.list().f_size() > 0) {
1046           auto start = flex_builder->StartVector(key);
1047           for (const float v : attr.list().f()) {
1048             flex_builder->Add(v);
1049           }
1050           flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false);
1051         } else {
1052           emitWarning(loc,
1053                       "ignoring unsupported type in list attribute with key: ")
1054               << key;
1055         }
1056         break;
1057       default:
1058         emitWarning(loc, "ignoring unsupported attribute type with key: ")
1059             << key;
1060         break;
1061     }
1062   }
1063   flex_builder->EndMap(map_start);
1064   flex_builder->Finish();
1065   return flex_builder;
1066 }
1067 
GetOpcodeIndex(const std::string & op_name,tflite::BuiltinOperator builtin)1068 uint32_t Translator::GetOpcodeIndex(const std::string& op_name,
1069                                     tflite::BuiltinOperator builtin) {
1070   auto it = opcode_index_map_.insert({op_name, 0});
1071 
1072   // If the insert succeeded, the opcode has not been created already. Create a
1073   // new operator code and update its index value in the map.
1074   if (it.second) {
1075     it.first->second = opcodes_.size();
1076     auto custom_code = builtin == tflite::BuiltinOperator_CUSTOM
1077                            ? builder_.CreateString(op_name)
1078                            : BufferOffset<flatbuffers::String>();
1079     // Use version 0 for builtin op. This is a way to serialize version field to
1080     // flatbuffer (since 0 is non default) and it will be corrected later.
1081     int32_t op_version = builtin != tflite::BuiltinOperator_CUSTOM ? 0 : 1;
1082     opcodes_.push_back(CreateOperatorCode(builder_, /*builtin_code=*/builtin,
1083                                           custom_code, op_version));
1084   }
1085   return it.first->second;
1086 }
1087 
// Serializes one MLIR operation into a TFLite Operator. Dispatch order:
//   1. TFL variable ops -> custom ops ("ReadVariable"/"AssignVariable").
//   2. TFL-dialect ops  -> builtin ops (with special cases for
//      NumericVerify, tfl.custom, while, call_once, transpose_conv).
//   3. TF-dialect ops   -> flex ops (serialized NodeDef) or custom ops.
// Returns llvm::None after emitting a diagnostic when the op cannot be
// exported. `operands` is taken by value because it may be trimmed below.
Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
    Operation* inst, std::vector<int32_t> operands,
    const std::vector<int32_t>& results,
    const std::vector<int32_t>& intermediates) {
  const auto* dialect = inst->getDialect();
  if (!dialect) {
    inst->emitOpError("dialect is not registered");
    return llvm::None;
  }

  // TODO(b/149099381): Remove this once the kernels are promoted as
  // builtin TFLite kernels.
  // We export the Assign/Read variable ops as custom ops.
  if (auto read_op = llvm::dyn_cast<mlir::TFL::ReadVariableOp>(inst)) {
    return BuildVariableOperator<mlir::TFL::ReadVariableOp>(
        read_op, "ReadVariable", operands, results);
  } else if (auto assign_op =
                 llvm::dyn_cast<mlir::TFL::AssignVariableOp>(inst)) {
    return BuildVariableOperator<mlir::TFL::AssignVariableOp>(
        assign_op, "AssignVariable", operands, results);
  }

  // If TFLite built in op, create operator as a builtin op.
  if (dialect == tfl_dialect_) {
    // Only if built-in TFLite op emission is enabled, would legalization have
    // converted any TF->TFL.
    if (!enabled_op_types_.contains(OpType::kTfliteBuiltin)) {
      return inst->emitOpError(
                 "is a TFLite builtin op but builtin emission is not enabled"),
             llvm::None;
    }

    auto builtin_code = GetBuiltinOpCode(inst);
    if (!builtin_code) {
      // Not a builtin: try the TFL ops that need bespoke serialization.
      if (auto verify_op = dyn_cast<mlir::TFL::NumericVerifyOp>(inst)) {
        return BuildNumericVerifyOperator(verify_op, operands, results);
      }
      if (auto custom_op = dyn_cast<mlir::TFL::CustomOp>(inst)) {
        return BuildCustomOperator(inst, custom_op, operands, results);
      }
      if (auto whileOp = dyn_cast<mlir::TFL::WhileOp>(inst)) {
        if (inst->getNumOperands() != inst->getNumResults()) {
          inst->emitOpError(
              "number of operands and results don't match, only canonical "
              "TFL While supported");
          return llvm::None;
        }
        return BuildWhileOperator(whileOp, operands, results);
      }

      inst->emitOpError("is not a supported TFLite op");
      return llvm::None;
    }

    if (*builtin_code == tflite::BuiltinOperator_CALL_ONCE) {
      if (auto initOp = dyn_cast<mlir::TFL::CallOnceOp>(inst)) {
        return BuildCallOnceOperator(initOp, operands, results);
      }
    }

    std::string op_name = inst->getName().getStringRef().str();
    uint32_t opcode_index = GetOpcodeIndex(op_name, *builtin_code);

    // If this is TransposeConv we need to do a special case of ignoring the
    // optional tensor, to allow newly created models to run on old runtimes.
    if (*builtin_code == tflite::BuiltinOperator_TRANSPOSE_CONV) {
      if (operands.size() == 4 && operands.at(3) == -1) {
        operands.pop_back();
      }
    }

    auto offset = CreateFlatBufferOperator(inst, opcode_index, operands,
                                           results, intermediates, &builder_);
    if (!offset) {
      inst->emitOpError("is not a supported TFLite op");
    }
    return offset;
  }

  if (dialect == tf_dialect_) {
    // TF control flow ops have dedicated TFLite counterparts.
    if (auto ifOp = dyn_cast<mlir::TF::IfOp>(inst)) {
      return BuildIfOperator(ifOp, operands, results);
    } else if (auto whileOp = dyn_cast<mlir::TF::WhileOp>(inst)) {
      return BuildWhileOperator(whileOp, operands, results);
    }

    CustomOptionsOffset custom_options;

    // Ops in TF dialect can either be custom ops or flex ops.
    // The reason we go directly from TensorFlow dialect MLIR to tensorflow
    // node instead of going to TF table gen'd ops via generated code is that
    // we do not want to restrict custom and flex op conversion support to
    // only those TF ops that are currently registered in MLIR. The current
    // model is of an open op system.
    //
    //  The following algorithm is followed:
    //   if flex is enabled and the op is allowlisted as flex
    //     we emit op as flex.
    //   if custom is enabled
    //    we emit the op as custom.
    auto node_def = GetTensorFlowNodeDef(inst);
    if (!node_def) {
      return llvm::None;
    }

    std::string op_name = node_def->op();
    std::string op_desc = GetOpDescriptionForDebug(inst);

    // Track resource-touching ops for later reporting.
    if (IsTFResourceOp(inst)) {
      resource_ops_[op_name].insert(op_desc);
    }

    // An op qualifies as flex if it is on the built-in allowlist, or the
    // user explicitly selected it AND it is registered in the TF runtime.
    const bool is_allowed_flex_op =
        IsAllowlistedFlexOp(node_def->op()) ||
        ((select_user_tf_ops_.count(node_def->op()) != 0) &&
         (tensorflow::OpRegistry::Global()->LookUp(node_def->op()) != nullptr));
    // Flex op case
    // Eventually, the allowlist will go away and we will rely on some TF op
    // trait (e.g. No side effect) to determine if it is a supported "Flex"
    // op or not.
    if (is_allowed_flex_op && enabled_op_types_.contains(OpType::kSelectTf)) {
      // Construct ops as flex op encoding TensorFlow node definition
      // as custom options.
      // Flex ops are named with the kFlexOpNamePrefix prefix to the actual
      // TF op name.
      op_name = std::string(kFlexOpNamePrefix) + node_def->op();
      if (auto options = CreateFlexOpCustomOptions(*node_def, inst->getLoc())) {
        custom_options = *options;
      } else {
        return llvm::None;
      }

      // Gather flex ops.
      flex_ops_[op_name].insert(op_desc);
    } else if (enabled_op_types_.contains(OpType::kCustomOp)) {
      // Generic case of custom ops - write using flex buffers since that
      // is the only custom options supported by TFLite today.
      op_name = node_def->op();
      if (auto options =
              CreateCustomOpCustomOptions(*node_def, inst->getLoc())) {
        custom_options = *options;
      } else {
        return llvm::None;
      }

      // Gather custom ops.
      custom_ops_[op_name].insert(op_desc);
    } else {
      // Insert failed op to `flex_ops` or `custom_ops`.
      if (is_allowed_flex_op) {
        failed_flex_ops_[op_name].insert(op_desc);
      } else {
        failed_custom_ops_[op_name].insert(op_desc);
      }
      return inst->emitOpError("is neither a custom op nor a flex op"),
             llvm::None;
    }

    uint32_t opcode_index =
        GetOpcodeIndex(op_name, tflite::BuiltinOperator_CUSTOM);
    auto inputs = builder_.CreateVector(operands);
    auto outputs = builder_.CreateVector(results);

    return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
                                  tflite::BuiltinOptions_NONE,
                                  /*builtin_options=*/0,
                                  /*custom_options=*/custom_options,
                                  tflite::CustomOptionsFormat_FLEXBUFFERS,
                                  /*mutating_variable_inputs=*/0);
  }

  return inst->emitOpError(
             "is not any of a builtin TFLite op, a flex TensorFlow op or a "
             "custom TensorFlow op"),
         llvm::None;
}
1264 
// Seeds the name mapper from the function's "tf.entry_function" attribute,
// whose "inputs"/"outputs" entries are comma-separated tensor names. This
// lets exported tensors keep their original graph names. Sets
// *has_input_attr when input names were provided. On a count mismatch a
// warning is emitted and the corresponding names are left as defaults.
void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) {
  auto dict_attr = fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
  if (!dict_attr) return;

  llvm::SmallVector<llvm::StringRef, 2> input_names;
  llvm::SmallVector<llvm::StringRef, 2> output_names;
  if (auto str = dict_attr.get("inputs").dyn_cast_or_null<mlir::StringAttr>()) {
    // KeepEmpty=false drops empty fields from e.g. trailing commas.
    str.getValue().split(input_names, ',', /*MaxSplit=*/-1,
                         /*KeepEmpty=*/false);
    if (input_names.size() != fn.getNumArguments()) {
      fn.emitWarning() << "invalid entry function specification";
      return;
    }
    for (auto it : llvm::enumerate(fn.getArguments())) {
      name_mapper_.InitOpName(it.value(), input_names[it.index()].trim());
    }
    *has_input_attr = true;
  }

  if (auto str =
          dict_attr.get("outputs").dyn_cast_or_null<mlir::StringAttr>()) {
    str.getValue().split(output_names, ',', /*MaxSplit=*/-1,
                         /*KeepEmpty=*/false);
    // Output names map onto the terminator's operands, one per operand.
    auto term = fn.back().getTerminator();
    if (output_names.size() != term->getNumOperands()) {
      fn.emitWarning() << "output names (" << output_names.size()
                       << ") != terminator operands (" << term->getNumOperands()
                       << ")";
      return;
    }
    for (const auto& it : llvm::enumerate(term->getOperands())) {
      name_mapper_.InitOpName(it.value(), output_names[it.index()].trim());
    }
  }
}
1300 
IsStatefulOperand(mlir::Operation * op,int operand_index)1301 bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) {
1302   std::vector<int> operand_indices;
1303   if (!mlir::TFL::IsStatefulOp(op, &operand_indices)) return false;
1304   return absl::c_find(operand_indices, operand_index) != operand_indices.end();
1305 }
1306 
// Converts a quant.stats op into TFLite QuantizationParameters carrying
// only min/max statistics (scale/zero_point are left empty — calibration
// happens later in the pipeline).
BufferOffset<tflite::QuantizationParameters>
Translator::GetQuantizationForQuantStatsOpOutput(
    mlir::quant::StatisticsOp stats_op) {
  auto layer_stats = stats_op.layerStats().cast<mlir::DenseFPElementsAttr>();
  Optional<mlir::ElementsAttr> axis_stats = stats_op.axisStats();
  Optional<uint64_t> axis = stats_op.axis();
  std::vector<float> mins, maxs;
  // Prefer per-axis statistics when present; fall back to per-layer stats.
  mlir::DenseFPElementsAttr min_max_attr =
      axis_stats.hasValue()
          ? axis_stats.getValue().cast<mlir::DenseFPElementsAttr>()
          : layer_stats;

  // The attribute stores values interleaved as [min0, max0, min1, max1, ...];
  // even indices are mins, odd indices are maxs.
  for (auto index_and_value : llvm::enumerate(min_max_attr.getFloatValues())) {
    const llvm::APFloat value = index_and_value.value();
    if (index_and_value.index() % 2 == 0) {
      mins.push_back(value.convertToFloat());
    } else {
      maxs.push_back(value.convertToFloat());
    }
  }

  return tflite::CreateQuantizationParameters(
      builder_, builder_.CreateVector<float>(mins),
      builder_.CreateVector<float>(maxs), /*scale=*/0, /*zero_point=*/0,
      tflite::QuantizationDetails_NONE, /*details=*/0,
      /*quantized_dimension=*/axis.hasValue() ? axis.getValue() : 0);
}
1334 
// Converts one MLIR region (a function body) into a TFLite SubGraph: builds
// tensors/buffers for block arguments and op results, converts each
// non-terminator op into a TFLite operator, and records the subgraph's
// input/output tensor indices. Returns llvm::None if any tensor or operator
// fails to convert.
Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
    const std::string& name, Region* region) {
  bool has_input_attr = false;
  if (auto fn = dyn_cast<FuncOp>(region->getParentOp())) {
    // Seeds tensor names from the function's entry attributes when present.
    InitializeNamesFromAttribute(fn, &has_input_attr);
  }
  std::vector<BufferOffset<tflite::Tensor>> tensors;
  // Maps each MLIR Value to its index in `tensors`.
  llvm::DenseMap<Value, int> tensor_index_map;

  // Builds tensor and buffer for argument or operation result. Returns false
  // on failure.
  auto build_tensor_and_buffer = [&](Value value, const std::string& name) {
    // NoneType represents optional and may be skipped here.
    if (value.getType().isa<NoneType>()) {
      return true;
    }

    tensor_index_map.insert({value, tensors.size()});
    tensor_index_map_[name] = tensors.size();
    Optional<BufferOffset<tflite::QuantizationParameters>> quant_parameters;
    if (value.hasOneUse()) {
      // A sole "quant.stats" consumer supplies the quantization parameters
      // for this tensor.
      auto stats_op =
          llvm::dyn_cast<mlir::quant::StatisticsOp>(*value.user_begin());
      if (stats_op) {
        quant_parameters = GetQuantizationForQuantStatsOpOutput(stats_op);
      }
    }
    // The tensor refers to buffer index buffers_.size(); the matching buffer
    // is pushed just below, so tensors and buffers stay in lockstep.
    auto tensor_or =
        BuildTensor(value, name, buffers_.size(), quant_parameters);
    if (!tensor_or) return false;
    tensors.push_back(*tensor_or);

    // TODO(ashwinm): Check if for stateful tensors, if it is also needed to
    // make the Buffer empty apart from setting the buffer_idx=0 in the
    // Tensor. This does not seem to affect runtime behavior for RNN/LSTM,
    // but would be good for reducing memory footprint.
    if (auto* inst = value.getDefiningOp()) {
      auto buffer_or = BuildBuffer(inst);
      if (!buffer_or) return false;
      buffers_.push_back(*buffer_or);
    } else {
      // Block arguments have no defining op, so they get an empty buffer.
      buffers_.push_back(empty_buffer_);
    }
    return true;
  };

  std::vector<BufferOffset<tflite::Operator>> operators;
  auto& bb = region->front();

  // Main function's arguments are first passed to `input` op so they don't
  // have associated tensor and buffer. Build FlatBuffer tensor and buffer for
  // other functions.
  for (unsigned i = 0, e = bb.getNumArguments(); i < e; ++i) {
    mlir::BlockArgument arg = bb.getArgument(i);
    std::string name;
    if (has_input_attr) name = std::string(name_mapper_.GetUniqueName(arg));
    if (name.empty()) name = absl::StrCat("arg", i);
    if (!build_tensor_and_buffer(arg, name)) return llvm::None;
  }

  bool failed_once = false;
  for (auto& inst : bb) {
    if (inst.hasTrait<mlir::OpTrait::IsTerminator>()) break;
    // For "quant.stats" op, it's used to store the quantization parameters info
    // and its output should be then replaced by its input value.
    if (auto quant_stats_op = llvm::dyn_cast<mlir::quant::StatisticsOp>(inst)) {
      continue;
    }
    std::vector<int32_t> intermediates;
    // Build intermediate tensors for tfl.lstm and insert these tensors into
    // flatbuffer.
    if (llvm::isa<mlir::TFL::LSTMOp, mlir::TFL::UnidirectionalSequenceLSTMOp>(
            inst)) {
      std::vector<std::string> intermediate_names = {
          "input_to_input_intermediate", "input_to_forget_intermediate",
          "input_to_cell_intermediate", "input_to_output_intermediate",
          "effective_hidden_scale_intermediate"};
      for (const std::string& intermediate : intermediate_names) {
        auto intermediate_attr = inst.getAttr(intermediate);
        if (auto attr = intermediate_attr.dyn_cast_or_null<mlir::TypeAttr>()) {
          Type qtype = attr.getValue();
          auto tensor_or = BuildTensorFromType(
              qtype, name_mapper_.GetUniqueName(intermediate).str());
          if (!tensor_or.hasValue()) {
            // Intermediates whose type cannot be converted are skipped, not
            // treated as a failure.
            continue;
          } else {
            intermediates.push_back(tensors.size());
            tensors.push_back(tensor_or.getValue());
          }
        }
      }
    }

    for (auto val : inst.getResults()) {
      std::string name = UniqueName(val);
      // For "tfl.numeric_verify" op, the name is used to find out the original
      // activation tensor rather than its own unique name in the visualization
      // or debugging tools.
      auto builtin_code = GetBuiltinOpCode(&inst);
      if (!builtin_code && dyn_cast<mlir::TFL::NumericVerifyOp>(&inst)) {
        // The first operand is the quantized activation, the target of this
        // NumericVerify op.
        auto quantized_op_val = inst.getOperands().front();
        name = "NumericVerify/" + UniqueName(quantized_op_val) + ":" +
               std::to_string(tensor_index_map[quantized_op_val]);
      }
      if (!build_tensor_and_buffer(val, name)) return llvm::None;
    }

    // Skip constant ops as they don't represent a TFLite operator.
    if (IsConst(&inst)) continue;

    // Fetch operand and result tensor indices.
    std::vector<int32_t> results;
    results.reserve(inst.getNumResults());
    for (auto result : inst.getResults()) {
      results.push_back(tensor_index_map.lookup(result));
    }
    Operation* real_inst = &inst;
    // CustomTfOp is just a wrapper around a TF op, we export the custom Op
    // not the wrapper, so we fetch the op from the region.
    if (auto custom_op = dyn_cast<mlir::TFL::CustomTfOp>(inst)) {
      // If we have custom op with a region, then use the first op in the
      // region, if it exists, otherwise just use params for custom op.
      if (!custom_op.body().empty()) {
        real_inst = &custom_op.body().front().front();
      } else {
        // NOTE(review): only an error is emitted here; conversion continues
        // with the wrapper op itself — confirm this fallback is intentional.
        module_.emitError(
            "Invalid CustomTfOp: Custom TF Op have empty region.");
      }
    }
    std::vector<int32_t> operands;
    operands.reserve(real_inst->getNumOperands());
    for (auto operand : real_inst->getOperands()) {
      if (operand.getType().isa<NoneType>())
        operands.push_back(kTfLiteOptionalTensor);
      else if (auto stats_op =
                   llvm::dyn_cast_or_null<mlir::quant::StatisticsOp>(
                       operand.getDefiningOp()))
        // quant.stats ops were skipped above, so refer to their input tensor.
        operands.push_back(tensor_index_map.lookup(stats_op.arg()));
      else
        operands.push_back(tensor_index_map.lookup(operand));
    }

    if (auto tfl_operator =
            BuildOperator(real_inst, operands, results, intermediates))
      operators.push_back(*tfl_operator);
    else
      // Keep converting the remaining ops so that all unsupported ops in the
      // subgraph get reported, then fail at the end.
      failed_once = true;
  }

  if (failed_once) return llvm::None;

  // Get input and output tensor indices for the subgraph.
  std::vector<int32_t> inputs, outputs;
  for (auto arg : bb.getArguments()) {
    inputs.push_back(tensor_index_map[arg]);
  }
  for (auto result : bb.getTerminator()->getOperands()) {
    outputs.push_back(tensor_index_map[result]);
  }

  return tflite::CreateSubGraph(
      builder_, builder_.CreateVector(tensors), builder_.CreateVector(inputs),
      builder_.CreateVector(outputs), builder_.CreateVector(operators),
      /*name=*/builder_.CreateString(name));
}
1502 
BuildMetadata(StringRef name,StringRef content)1503 BufferOffset<tflite::Metadata> Translator::BuildMetadata(StringRef name,
1504                                                          StringRef content) {
1505   auto buffer_index = buffers_.size();
1506   auto buffer_data = builder_.CreateVector(
1507       reinterpret_cast<const uint8_t*>(content.data()), content.size());
1508   buffers_.push_back(tflite::CreateBuffer(builder_, buffer_data));
1509   return tflite::CreateMetadataDirect(builder_, name.data(), buffer_index);
1510 }
1511 
1512 Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
CreateMetadataVector()1513 Translator::CreateMetadataVector() {
1514   auto dict_attr = module_->getAttrOfType<mlir::DictionaryAttr>("tfl.metadata");
1515   std::vector<BufferOffset<tflite::Metadata>> metadata;
1516   if (dict_attr) {
1517     for (const auto& named_attr : dict_attr) {
1518       StringRef name = named_attr.first;
1519       mlir::Attribute attr = named_attr.second;
1520       if (auto content = attr.dyn_cast<StringAttr>()) {
1521         metadata.push_back(BuildMetadata(name, content.getValue()));
1522       } else {
1523         module_.emitError(
1524             "all values in tfl.metadata's dictionary key-value pairs should be "
1525             "string attributes");
1526         return llvm::None;
1527       }
1528     }
1529   }
1530   // Runtime version string is generated after we update the op
1531   // versions. Here we put a 16-byte dummy string as a placeholder. We choose
1532   // 16-byte because it's the alignment of buffers in flatbuffer, so it won't
1533   // cause any waste of space if the actual string is shorter than 16 bytes.
1534   metadata.push_back(
1535       BuildMetadata("min_runtime_version", std::string(16, '\0')));
1536   return builder_.CreateVector(metadata);
1537 }
1538 
1539 // Helper method that returns list of all strings in a StringAttr identified
1540 // by 'attr_key' and values are separated by a comma.
GetStringsFromAttrWithSeparator(mlir::DictionaryAttr attr,const std::string & attr_key)1541 llvm::SmallVector<llvm::StringRef, 2> GetStringsFromAttrWithSeparator(
1542     mlir::DictionaryAttr attr, const std::string& attr_key) {
1543   llvm::SmallVector<llvm::StringRef, 2> result;
1544   if (auto str = attr.get(attr_key).dyn_cast_or_null<mlir::StringAttr>()) {
1545     str.getValue().split(result, ',', /*MaxSplit=*/-1,
1546                          /*KeepEmpty=*/false);
1547   }
1548   return result;
1549 }
1550 
1551 // Helper method that return list of string for all the StringAttr in the
1552 // Attribute identified by 'attr_name'.
GetStringsFromDictionaryAttr(const llvm::SmallVector<mlir::DictionaryAttr,4> & dict_attrs,const std::string & attr_name)1553 std::vector<std::string> GetStringsFromDictionaryAttr(
1554     const llvm::SmallVector<mlir::DictionaryAttr, 4>& dict_attrs,
1555     const std::string& attr_name) {
1556   std::vector<std::string> result;
1557   for (const auto& arg_attr : dict_attrs) {
1558     if (!arg_attr) continue;
1559 
1560     auto attrs = arg_attr.getValue();
1561     for (const auto attr : attrs) {
1562       if (attr.first.str() == attr_name) {
1563         auto array_attr = attr.second.dyn_cast_or_null<mlir::ArrayAttr>();
1564         if (!array_attr || array_attr.empty()) continue;
1565         auto string_attr = array_attr[0].dyn_cast_or_null<mlir::StringAttr>();
1566         if (!string_attr) continue;
1567         result.push_back(string_attr.getValue().str());
1568       }
1569     }
1570   }
1571   return result;
1572 }
1573 
BuildSignaturedef(FuncOp main_op,const std::string & saved_model_tag)1574 std::vector<SignatureDefData> BuildSignaturedef(
1575     FuncOp main_op, const std::string& saved_model_tag) {
1576   static const char kSignatureDefIndexPath[] = "tf_saved_model.index_path";
1577   static const char kEntryFunctionAttributes[] = "tf.entry_function";
1578 
1579   // Fetch inputs and outputs from the signature.
1580   llvm::SmallVector<mlir::DictionaryAttr, 4> arg_attrs, res_attrs;
1581   main_op.getAllArgAttrs(arg_attrs);
1582   main_op.getAllResultAttrs(res_attrs);
1583   std::vector<std::string> sig_def_inputs =
1584       GetStringsFromDictionaryAttr(arg_attrs, kSignatureDefIndexPath);
1585   std::vector<std::string> sig_def_outputs =
1586       GetStringsFromDictionaryAttr(res_attrs, kSignatureDefIndexPath);
1587 
1588   // If no defined saved model signature, then return empty list.
1589   // This can happen when we are converting model not from SavedModel.
1590   if (sig_def_inputs.empty() || sig_def_outputs.empty()) return {};
1591 
1592   // Fetch function inputs and outputs tensor names.
1593   auto dict_attr =
1594       main_op->getAttrOfType<mlir::DictionaryAttr>(kEntryFunctionAttributes);
1595   if (!dict_attr) return {};
1596 
1597   // Get Input and output tensor names from attribute.
1598   llvm::SmallVector<llvm::StringRef, 2> input_names =
1599       GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"inputs");
1600   llvm::SmallVector<llvm::StringRef, 2> output_names =
1601       GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"outputs");
1602 
1603   // Verify input size match the number of arguments.
1604   if (input_names.size() != main_op.getNumArguments()) {
1605     main_op.emitWarning() << "invalid entry function specification";
1606     return {};
1607   }
1608   // Verify output size match the number of arguments.
1609   auto term = main_op.back().getTerminator();
1610   if (output_names.size() != term->getNumOperands()) {
1611     main_op.emitWarning() << "output names (" << output_names.size()
1612                           << ") != terminator operands ("
1613                           << term->getNumOperands() << ")";
1614     return {};
1615   }
1616   // Verify number of tensors for inputs and outputs matches size
1617   // of the list in the signature def.
1618   if (input_names.size() != sig_def_inputs.size() ||
1619       output_names.size() != sig_def_outputs.size()) {
1620     main_op.emitWarning(
1621         "Mismatch between signature def inputs/outputs and main function "
1622         "arguments.");
1623     return {};
1624   }
1625   // Exported method name.
1626   auto exported_name =
1627       main_op->getAttrOfType<mlir::ArrayAttr>("tf_saved_model.exported_names");
1628   if (exported_name.empty()) {
1629     main_op.emitError("Empty exported names for main Function");
1630     return {};
1631   }
1632   // Fill the SignatureDefData container.
1633   // We create vector of size 1 as TFLite now supports only 1 signatureDef.
1634   std::vector<SignatureDefData> result(1);
1635   for (int i = 0; i < input_names.size(); ++i) {
1636     result[0].inputs[sig_def_inputs[i]] = input_names[i].str();
1637   }
1638   for (int i = 0; i < output_names.size(); ++i) {
1639     result[0].outputs[sig_def_outputs[i]] = output_names[i].str();
1640   }
1641   if (auto name_attr = exported_name[0].dyn_cast_or_null<StringAttr>())
1642     result[0].method_name = name_attr.getValue().str();
1643   result[0].signature_def_key = saved_model_tag;
1644   return result;
1645 }
1646 
GetList(const std::map<std::string,std::string> & items)1647 std::vector<BufferOffset<tflite::TensorMap>> Translator::GetList(
1648     const std::map<std::string, std::string>& items) {
1649   std::vector<BufferOffset<tflite::TensorMap>> result;
1650   for (const auto& item : items) {
1651     auto name_buf = builder_.CreateString(item.first);
1652     tflite::TensorMapBuilder tensor_map_builder(builder_);
1653     tensor_map_builder.add_name(name_buf);
1654     tensor_map_builder.add_tensor_index(tensor_index_map_[item.second]);
1655     result.push_back(tensor_map_builder.Finish());
1656   }
1657   return result;
1658 }
1659 
1660 Optional<VectorBufferOffset<BufferOffset<tflite::SignatureDef>>>
CreateSignatureDefs(const std::vector<SignatureDefData> & signature_defs)1661 Translator::CreateSignatureDefs(
1662     const std::vector<SignatureDefData>& signature_defs) {
1663   std::vector<BufferOffset<tflite::SignatureDef>> signature_defs_buffer;
1664   for (const auto& signature_def_data : signature_defs) {
1665     auto inputs = GetList(signature_def_data.inputs);
1666     auto outputs = GetList(signature_def_data.outputs);
1667     auto inputs_buf = builder_.CreateVector(inputs);
1668     auto outputs_buf = builder_.CreateVector(outputs);
1669     auto method_name_buf =
1670         builder_.CreateString(signature_def_data.method_name);
1671     auto signature_def_key_buf =
1672         builder_.CreateString(signature_def_data.signature_def_key);
1673     tflite::SignatureDefBuilder sig_def_builder(builder_);
1674     sig_def_builder.add_inputs(inputs_buf);
1675     sig_def_builder.add_outputs(outputs_buf);
1676     sig_def_builder.add_method_name(method_name_buf);
1677     sig_def_builder.add_key(signature_def_key_buf);
1678     signature_defs_buffer.push_back(sig_def_builder.Finish());
1679   }
1680 
1681   return builder_.CreateVector(signature_defs_buffer);
1682 }
1683 
UpdateEntryFunction(ModuleOp module)1684 bool UpdateEntryFunction(ModuleOp module) {
1685   if (module.lookupSymbol<FuncOp>("main") != nullptr) {
1686     // We already have an entry function.
1687     return true;
1688   }
1689 
1690   int entry_func_count = 0;
1691   FuncOp entry_func = nullptr;
1692   for (auto fn : module.getOps<FuncOp>()) {
1693     auto attrs = fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
1694     if (attrs && !attrs.empty()) {
1695       entry_func_count++;
1696       entry_func = fn;
1697     }
1698   }
1699 
1700   // We should have one & only have one entry function.
1701   if (entry_func_count != 1) return false;
1702 
1703   // Update the entry func to main.
1704   entry_func.setName("main");
1705   return true;
1706 }
1707 
Translate(ModuleOp module,bool emit_builtin_tflite_ops,bool emit_select_tf_ops,bool emit_custom_ops,const std::unordered_set<std::string> & select_user_tf_ops,const std::unordered_set<std::string> & tags,OpOrArgNameMapper * op_or_arg_name_mapper)1708 Optional<std::string> Translator::Translate(
1709     ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
1710     bool emit_custom_ops,
1711     const std::unordered_set<std::string>& select_user_tf_ops,
1712     const std::unordered_set<std::string>& tags,
1713     OpOrArgNameMapper* op_or_arg_name_mapper) {
1714   OpOrArgLocNameMapper default_op_or_arg_name_mapper;
1715   if (!op_or_arg_name_mapper)
1716     op_or_arg_name_mapper = &default_op_or_arg_name_mapper;
1717   if (!UpdateEntryFunction(module)) return llvm::None;
1718   if (!IsValidTFLiteMlirModule(module)) return llvm::None;
1719   Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops,
1720                         emit_custom_ops, select_user_tf_ops, tags,
1721                         op_or_arg_name_mapper);
1722   return translator.TranslateInternal();
1723 }
1724 
// Converts the whole module to a serialized TFLite model string. Builds one
// subgraph per function with "main" first, logs warnings for resource, flex,
// and custom ops, then assembles the final flatbuffer and patches op/runtime
// versions in place. Returns llvm::None when any subgraph fails to convert.
Optional<std::string> Translator::TranslateInternal() {
  // A list of named regions in the module with main function being the first in
  // the list. The main function is required as the first subgraph in the model
  // is entry point for the model.
  std::vector<std::pair<std::string, Region*>> named_regions;
  named_regions.reserve(std::distance(module_.begin(), module_.end()));

  int subgraph_idx = 0;
  FuncOp main_fn = module_.lookupSymbol<FuncOp>("main");
  subgraph_index_map_[main_fn.getName().str()] = subgraph_idx++;
  named_regions.emplace_back("main", &main_fn.getBody());
  // Walk over the module collection ops with functions and while ops.
  module_.walk([&](FuncOp fn) {
    if (fn != main_fn) {
      subgraph_index_map_[fn.getName().str()] = subgraph_idx++;
      named_regions.emplace_back(fn.getName().str(), &fn.getBody());
    }
  });

  // Build subgraph for each of the named regions.
  std::vector<BufferOffset<tflite::SubGraph>> subgraphs;
  subgraphs.reserve(named_regions.size());
  int first_failed_func = -1;
  for (auto it : llvm::enumerate(named_regions)) {
    auto subgraph_or = BuildSubGraph(it.value().first, it.value().second);
    if (!subgraph_or) {
      if (first_failed_func == -1)
        // Record the index of the first region that cannot be converted.
        // Keep looping through all subgraphs in the module to make sure that
        // we collect the list of missing ops from the entire module.
        first_failed_func = it.index();
    } else {
      subgraphs.push_back(*subgraph_or);
    }
  }

  // Resource/flex/custom op summaries below are advisory warnings; they do
  // not by themselves fail the conversion.
  if (!resource_ops_.empty()) {
    std::string resource_ops_summary =
        GetOpsSummary(resource_ops_, /*summary_title=*/"Resource");
    LOG(WARNING) << "Graph contains the following resource op(s), that use(s) "
                    "resource type. Currently, the "
                    "resource type is not natively supported in TFLite. Please "
                    "consider not using the resource type if there are issues "
                    "with either TFLite converter or TFLite runtime:\n"
                 << resource_ops_summary;
  }

  if (!flex_ops_.empty()) {
    std::string flex_ops_summary =
        GetOpsSummary(flex_ops_, /*summary_title=*/"Flex");
    LOG(WARNING) << "TFLite interpreter needs to link Flex delegate in order "
                    "to run the model since it contains the following flex "
                    "op(s):\n"
                 << flex_ops_summary;
  }

  if (!custom_ops_.empty()) {
    std::string custom_ops_summary =
        GetOpsSummary(custom_ops_, /*summary_title=*/"Custom");
    LOG(WARNING) << "The following operation(s) need TFLite custom op "
                    "implementation(s):\n"
                 << custom_ops_summary;
  }

  if (first_failed_func != -1) {
    std::string failed_flex_ops_summary =
        GetOpsSummary(failed_flex_ops_, /*summary_title=*/"TF Select");
    std::string failed_custom_ops_summary =
        GetOpsSummary(failed_custom_ops_, /*summary_title=*/"Custom");
    std::string err;
    if (!failed_flex_ops_.empty())
      err +=
          "\nSome ops are not supported by the native TFLite runtime, you can "
          "enable TF kernels fallback using TF Select. See instructions: "
          "https://www.tensorflow.org/lite/guide/ops_select \n" +
          failed_flex_ops_summary + "\n";
    if (!failed_custom_ops_.empty())
      err +=
          "\nSome ops in the model are custom ops, "
          "See instructions to implement "
          "custom ops: https://www.tensorflow.org/lite/guide/ops_custom \n" +
          failed_custom_ops_summary + "\n";

    auto& failed_region = named_regions[first_failed_func];
    // Emit the diagnostic on the failing region's parent op; the comma
    // operator then makes llvm::None the function's actual return value.
    return failed_region.second->getParentOp()->emitError()
               << "failed while converting: '" << failed_region.first
               << "': " << err,
           llvm::None;
  }

  std::string model_description;
  if (auto attr = module_->getAttrOfType<StringAttr>("tfl.description")) {
    model_description = attr.getValue().str();
  } else {
    model_description = "MLIR Converted.";
  }

  // Build the model and finish the model building process.
  auto description = builder_.CreateString(model_description.data());
  VectorBufferOffset<int32_t> metadata_buffer = 0;  // Deprecated
  auto metadata = CreateMetadataVector();
  if (!metadata) return llvm::None;

  // Build SignatureDef
  // We only have 1 entry point 'main' function, so build only 1 signature def.
  auto main_fn_signature_def = BuildSignaturedef(
      main_fn, saved_model_tags_.empty() ? "" : *saved_model_tags_.begin());
  auto signature_defs = CreateSignatureDefs(main_fn_signature_def);

  auto model = tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION,
                                   builder_.CreateVector(opcodes_),
                                   builder_.CreateVector(subgraphs),
                                   description, builder_.CreateVector(buffers_),
                                   metadata_buffer, *metadata, *signature_defs);
  tflite::FinishModelBuffer(builder_, model);
  // Post-process the finished buffer in place: compute per-op versions, then
  // overwrite the "min_runtime_version" metadata placeholder.
  tflite::UpdateOpVersion(builder_.GetBufferPointer());
  tflite::UpdateMinimumRuntimeVersionForModel(builder_.GetBufferPointer());

  // Return serialized string for the built FlatBuffer.
  return std::string(reinterpret_cast<const char*>(builder_.GetBufferPointer()),
                     builder_.GetSize());
}
1847 
BuildSparsityParameters(const mlir::TFL::SparsityParameterAttr & s_attr)1848 BufferOffset<tflite::SparsityParameters> Translator::BuildSparsityParameters(
1849     const mlir::TFL::SparsityParameterAttr& s_attr) {
1850   const int dim_size = s_attr.dim_metadata().size();
1851   std::vector<flatbuffers::Offset<tflite::DimensionMetadata>> fb_dim_metadata(
1852       dim_size);
1853   for (int i = 0; i < dim_size; i++) {
1854     const auto dim_metadata =
1855         s_attr.dim_metadata()[i].dyn_cast<mlir::TFL::DimensionMetadataAttr>();
1856     if (dim_metadata.format().getValue() == "DENSE") {
1857       fb_dim_metadata[i] =
1858           tflite::CreateDimensionMetadata(builder_, tflite::DimensionType_DENSE,
1859                                           dim_metadata.dense_size().getInt());
1860 
1861     } else {
1862       auto segments = dim_metadata.segments();
1863       std::vector<int> vector_segments(segments.size(), 0);
1864       for (int j = 0, end = segments.size(); j < end; j++) {
1865         vector_segments[j] = segments[j].dyn_cast<mlir::IntegerAttr>().getInt();
1866       }
1867       tflite::SparseIndexVector segments_type;
1868       BufferOffset<void> array_segments;
1869       // The segment array is sorted.
1870       // TODO(b/147449640): Clean this up with util functions.
1871       int max_of_segments = vector_segments[segments.size() - 1];
1872       if (max_of_segments <= UINT8_MAX) {
1873         segments_type = tflite::SparseIndexVector_Uint8Vector;
1874         std::vector<uint8_t> uint8_vector(vector_segments.begin(),
1875                                           vector_segments.end());
1876         array_segments = tflite::CreateUint8Vector(
1877                              builder_, builder_.CreateVector(uint8_vector))
1878                              .Union();
1879       } else if (max_of_segments <= UINT16_MAX) {
1880         segments_type = tflite::SparseIndexVector_Uint16Vector;
1881         std::vector<uint16_t> uint16_vector(vector_segments.begin(),
1882                                             vector_segments.end());
1883         array_segments = tflite::CreateUint16Vector(
1884                              builder_, builder_.CreateVector(uint16_vector))
1885                              .Union();
1886       } else {
1887         segments_type = tflite::SparseIndexVector_Int32Vector;
1888         array_segments = tflite::CreateInt32Vector(
1889                              builder_, builder_.CreateVector(vector_segments))
1890                              .Union();
1891       }
1892 
1893       auto indices = dim_metadata.indices();
1894       std::vector<int> vector_indices(indices.size(), 0);
1895       int max_of_indices = 0;
1896       for (int j = 0, end = indices.size(); j < end; j++) {
1897         vector_indices[j] = indices[j].dyn_cast<mlir::IntegerAttr>().getInt();
1898         if (vector_indices[j] > max_of_indices) {
1899           max_of_indices = vector_indices[j];
1900         }
1901       }
1902       tflite::SparseIndexVector indices_type;
1903       BufferOffset<void> array_indices;
1904       if (max_of_indices <= UINT8_MAX) {
1905         indices_type = tflite::SparseIndexVector_Uint8Vector;
1906         std::vector<uint8_t> uint8_vector(vector_indices.begin(),
1907                                           vector_indices.end());
1908         array_indices = tflite::CreateUint8Vector(
1909                             builder_, builder_.CreateVector(uint8_vector))
1910                             .Union();
1911       } else if (max_of_indices <= UINT16_MAX) {
1912         indices_type = tflite::SparseIndexVector_Uint16Vector;
1913         std::vector<uint16_t> uint16_vector(vector_indices.begin(),
1914                                             vector_indices.end());
1915         array_indices = tflite::CreateUint16Vector(
1916                             builder_, builder_.CreateVector(uint16_vector))
1917                             .Union();
1918       } else {
1919         indices_type = tflite::SparseIndexVector_Int32Vector;
1920         array_indices = tflite::CreateInt32Vector(
1921                             builder_, builder_.CreateVector(vector_indices))
1922                             .Union();
1923       }
1924 
1925       fb_dim_metadata[i] = tflite::CreateDimensionMetadata(
1926           builder_, tflite::DimensionType_SPARSE_CSR, 0, segments_type,
1927           array_segments, indices_type, array_indices);
1928     }
1929   }
1930 
1931   std::vector<int> traversal_order(dim_size);
1932   for (int i = 0; i < dim_size; i++) {
1933     traversal_order[i] =
1934         s_attr.traversal_order()[i].dyn_cast<mlir::IntegerAttr>().getInt();
1935   }
1936   const int block_map_size = s_attr.block_map().size();
1937   std::vector<int> block_map(block_map_size);
1938   for (int i = 0; i < block_map_size; i++) {
1939     block_map[i] = s_attr.block_map()[i].dyn_cast<mlir::IntegerAttr>().getInt();
1940   }
1941 
1942   return tflite::CreateSparsityParameters(
1943       builder_, builder_.CreateVector(traversal_order),
1944       builder_.CreateVector(block_map), builder_.CreateVector(fb_dim_metadata));
1945 }
1946 
1947 }  // namespace
1948 
1949 namespace tflite {
1950 // TODO(hinsu): Support all valid MLIR modules in TFLite dialect by supporting
1951 // the following:
1952 //
1953 // * Quantization
1954 // * Ops with variable tensors
1955 //
MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,const FlatbufferExportOptions & options,std::string * serialized_flatbuffer)1956 bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,
1957                                        const FlatbufferExportOptions& options,
1958                                        std::string* serialized_flatbuffer) {
1959   auto maybe_translated = Translator::Translate(
1960       module, options.emit_builtin_tflite_ops, options.emit_select_tf_ops,
1961       options.emit_custom_ops, options.select_user_tf_ops,
1962       options.saved_model_tags, options.op_or_arg_name_mapper);
1963   if (!maybe_translated) return false;
1964   *serialized_flatbuffer = std::move(*maybe_translated);
1965   return true;
1966 }
1967 
1968 }  // namespace tflite
1969