1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
17 
18 #include <algorithm>
19 #include <cctype>
20 #include <climits>
21 #include <cstdint>
22 #include <iostream>
23 #include <sstream>
24 #include <string>
25 #include <vector>
26 
27 #include "absl/base/casts.h"
28 #include "absl/container/flat_hash_map.h"
29 #include "absl/container/flat_hash_set.h"
30 #include "absl/strings/string_view.h"
31 #include "llvm/ADT/APFloat.h"
32 #include "llvm/ADT/APInt.h"
33 #include "llvm/ADT/ArrayRef.h"
34 #include "llvm/ADT/DenseMap.h"
35 #include "llvm/ADT/None.h"
36 #include "llvm/ADT/Optional.h"
37 #include "llvm/ADT/STLExtras.h"
38 #include "llvm/ADT/SmallVector.h"
39 #include "llvm/ADT/StringExtras.h"
40 #include "llvm/ADT/StringRef.h"
41 #include "llvm/ADT/Twine.h"
42 #include "llvm/Support/Casting.h"
43 #include "llvm/Support/CommandLine.h"
44 #include "llvm/Support/Endian.h"
45 #include "llvm/Support/FormatVariadic.h"
46 #include "llvm/Support/MemoryBuffer.h"
47 #include "llvm/Support/SourceMgr.h"
48 #include "llvm/Support/raw_ostream.h"
49 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
50 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
51 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
52 #include "mlir/IR/Attributes.h"  // from @llvm-project
53 #include "mlir/IR/Builders.h"  // from @llvm-project
54 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
55 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
56 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
57 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
58 #include "mlir/IR/Location.h"  // from @llvm-project
59 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
60 #include "mlir/IR/Operation.h"  // from @llvm-project
61 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
62 #include "mlir/IR/Types.h"  // from @llvm-project
63 #include "mlir/IR/Value.h"  // from @llvm-project
64 #include "mlir/Support/LLVM.h"  // from @llvm-project
65 #include "mlir/Translation.h"  // from @llvm-project
66 #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
67 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
68 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
69 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
70 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
71 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
72 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
73 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
74 #include "tensorflow/compiler/xla/statusor.h"
75 #include "tensorflow/core/framework/tensor.pb.h"
76 #include "tensorflow/core/framework/tensor_shape.pb.h"
77 #include "tensorflow/core/platform/errors.h"
78 #include "tensorflow/core/platform/status.h"
79 #include "tensorflow/lite/model.h"
80 #include "tensorflow/lite/schema/schema_generated.h"
81 #include "tensorflow/lite/schema/schema_utils.h"
82 #include "tensorflow/lite/string_util.h"
83 
84 using llvm::ArrayRef;
85 using mlir::Builder;
86 using mlir::DenseElementsAttr;
87 using mlir::FuncOp;
88 using mlir::Location;
89 using mlir::MLIRContext;
90 using mlir::OpBuilder;
91 using mlir::Operation;
92 using mlir::OperationState;
93 using mlir::OwningModuleRef;
94 using mlir::RankedTensorType;
95 using mlir::UnrankedTensorType;
96 using mlir::Value;
97 using mlir::quant::QuantizedType;
98 using tflite::OperatorT;
99 using tflite::TensorT;
100 using xla::Status;
101 using xla::StatusOr;
102 
103 namespace errors = tensorflow::errors;
104 namespace tfl = mlir::TFL;
105 
106 namespace {
107 bool IsScalar(const TensorT& tensor) {
108   // TODO(b/138222071) We can't distinguish scalars and unranked tensors
109   // Work out a way to handle this and stub out the code until then
110   return tensor.shape.empty() && false;
111 }
112 
113 bool IsQuantized(const TensorT& tensor) {
114   return (tensor.quantization != nullptr) &&
115          !tensor.quantization->zero_point.empty();
116 }
117 
118 // Create the MLIR NamedLoc location corresponding to a given tensor
119 Location TensorLoc(const TensorT& tensor, Builder builder, Location base) {
120   if (tensor.name.empty()) {
121     return base;
122   }
123   return mlir::NameLoc::get(builder.getIdentifier(tensor.name), base);
124 }
125 
126 // Create the MLIR Location corresponding to a given op. This is an
127 // experimental/debugging feature and production code should not rely on names
128 // of intermediate tensors, since the importer doesn't guarantee to preserve
129 // tensor names except for output tensors.
130 Location OpLoc(const OperatorT& op,
131                const std::vector<std::unique_ptr<tflite::TensorT>>& tensors,
132                Builder builder, Location base) {
133   if (op.outputs.empty()) return base;
134 
135   llvm::SmallVector<Location, 4> locations;
136   locations.reserve(op.outputs.size());
137   for (auto tensor_index : op.outputs) {
138     locations.push_back(TensorLoc(*tensors[tensor_index], builder, base));
139   }
140   return mlir::FusedLoc::get(builder.getContext(), locations);
141 }
142 
143 // Returns the correct type for a quantized tensor
144 // We have a special case for constants since they have a higher minimum value.
145 StatusOr<QuantizedType> GetQuantizedType(const TensorT& tensor, Builder builder,
146                                          bool is_constant = false) {
147   tflite::QuantizationParametersT& quant_params = *tensor.quantization;
148   if (quant_params.details.AsCustomQuantization()) {
149     return errors::Unimplemented("Cannot handle experimental quantization");
150   }
151 
152   bool is_signed = true;
153   mlir::IntegerType storage_type;
154   if (tensor.type == tflite::TensorType_UINT8) {
155     is_signed = false;
156     storage_type = builder.getIntegerType(8);
157   } else {
158     auto raw_elem_type = ConvertElementType(tensor.type, builder);
159     if (!raw_elem_type.isa<mlir::IntegerType>()) {
160       return errors::InvalidArgument(
161           "Quantized tensors must be stored as integers");
162     }
163     storage_type = raw_elem_type.cast<mlir::IntegerType>();
164   }
165 
166   // TFLite uses narrow-range [u]int8 for constant buffers of quantized weights.
167   // Since we don't know which ones are weights, we represent this optimization
168   // as a change in the storage bounds for the type for all constants of this
169   // type.
170   bool is_weight_buffer = is_constant && (storage_type.getWidth() == 8);
171 
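  // Narrow range drops the lowest storage value, so the storage minimum is
  // shifted up by one for 8-bit constant weights (e.g. int8 weights use
  // [-127, 127] instead of [-128, 127]).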
172   int64_t storage_min = QuantizedType::getDefaultMinimumForInteger(
173                             is_signed, storage_type.getWidth()) +
174                         static_cast<int>(is_weight_buffer);
175   int64_t storage_max = QuantizedType::getDefaultMaximumForInteger(
176       is_signed, storage_type.getWidth());
177   uint32_t flags =
178       is_signed ? mlir::quant::QuantizationFlags::FlagValue::Signed : 0;
179 
180   // Reject quantized tensors that have zero scales.
181   for (float scale : quant_params.scale) {
182     if (scale == 0) {
183       return errors::InvalidArgument(
184           "Quantized tensors must have non-zero scales");
185     }
186   }
187 
188   // Scale size can't be zero, as checked above.
189   if (quant_params.scale.size() != 1) {
190     llvm::SmallVector<double, 4> scales(quant_params.scale.begin(),
191                                         quant_params.scale.end());
192     return mlir::quant::UniformQuantizedPerAxisType::get(
193         flags, storage_type, builder.getF32Type(), scales,
194         quant_params.zero_point, quant_params.quantized_dimension, storage_min,
195         storage_max);
196   }
197   return mlir::quant::UniformQuantizedType::get(
198       flags, storage_type, builder.getF32Type(), quant_params.scale.at(0),
199       quant_params.zero_point.at(0), storage_min, storage_max);
200 }
201 
202 // Imports a float tensor with calibration values into a calibrated quantized type.
203 StatusOr<QuantizedType> GetCalibratedQuantizedType(const TensorT& tensor,
204                                                    Builder builder) {
205   if (tensor.quantization == nullptr) {
206     return errors::InvalidArgument("The tensor is not quantized.");
207   }
208   auto raw_elem_type = ConvertElementType(tensor.type, builder);
209   float min = tensor.quantization->min[0];
210   float max = tensor.quantization->max[0];
211   return mlir::quant::CalibratedQuantizedType::get(raw_elem_type, min, max);
212 }
213 
214 // TODO(b/138222071) Remove shapeless_are_scalars once we can reliably
215 // make that distinction and don't have to rely on context
216 // (input to main and constants must have static shape)
217 StatusOr<mlir::TensorType> GetTensorType(const TensorT& tensor, Builder builder,
218                                          bool shapeless_are_scalars = false,
219                                          bool is_constant = false,
220                                          bool is_intermediate = false) {
221   mlir::Type elem_type = ConvertElementType(tensor.type, builder);
222   if (IsQuantized(tensor)) {
223     TF_ASSIGN_OR_RETURN(elem_type,
224                         GetQuantizedType(tensor, builder, is_constant));
225   }
226 
227   // Intermediate tensors with calibration values (but not scale and zero points)
228   // should return a calibrated quantized type.
229   if (is_intermediate && tensor.quantization != nullptr &&
230       !IsQuantized(tensor)) {
231     TF_ASSIGN_OR_RETURN(elem_type, GetCalibratedQuantizedType(tensor, builder));
232   }
233 
234   if (IsScalar(tensor) || (shapeless_are_scalars && tensor.shape.empty())) {
235     return RankedTensorType::get({}, elem_type);
236   }
237 
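  // Note: shape_signature takes precedence over shape; it may contain -1 for
  // dimensions whose size is only known at runtime.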
238   if (!tensor.shape_signature.empty()) {
239     llvm::SmallVector<int64_t, 4> shape(tensor.shape_signature.begin(),
240                                         tensor.shape_signature.end());
241     return RankedTensorType::get(shape, elem_type);
242   }
243 
244   if (!tensor.shape.empty()) {
245     llvm::SmallVector<int64_t, 4> shape(tensor.shape.begin(),
246                                         tensor.shape.end());
247     return RankedTensorType::get(shape, elem_type);
248   }
249 
250   return UnrankedTensorType::get(elem_type);
251 }
252 
253 // Extract the min max information in the tensor and create the quant stats op.
254 // If the input `tensor` has scale/zero_point, `res` should have a quantized
255 // type; no stats op is required and nullptr is returned.
256 // If the min/max information is invalid, nullptr is returned.
257 mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b,
258                                         Value res) {
259   // If the `tensor` has scale/zero_point, it must have been quantized, so the
260   // min/max stats are only informational; ignore them.
261   if (!tensor.quantization || IsQuantized(tensor)) return nullptr;
262   // If the result type isn't float (and thus not quantizable), the min/max is ignored.
263   if (!res.getType()
264            .cast<mlir::ShapedType>()
265            .getElementType()
266            .isa<mlir::FloatType>()) {
267     return nullptr;
268   }
269   auto mins = tensor.quantization->min;
270   auto maxs = tensor.quantization->max;
271   if (mins.size() != maxs.size() || mins.empty()) return nullptr;
272 
273   llvm::SmallVector<llvm::APFloat, 4> min_maxs;
274   min_maxs.reserve(mins.size() * 2);
275   for (int i = 0, end = mins.size(); i < end; ++i) {
276     llvm::APFloat min(mins[i]);
277     llvm::APFloat max(maxs[i]);
278     min_maxs.push_back(min);
279     min_maxs.push_back(max);
280   }
281   // The layer stats contain only the first min/max pair.
282   mlir::ElementsAttr layer_stats = mlir::DenseFPElementsAttr::get(
283       mlir::RankedTensorType::get({2}, b.getF32Type()),
284       {min_maxs[0], min_maxs[1]});
285   mlir::ElementsAttr axis_stats;
286   mlir::IntegerAttr axis;
287   if (mins.size() > 1) {
288     llvm::SmallVector<int64_t, 4> axis_stats_shape{
289         static_cast<int64_t>(mins.size()), 2};
290     axis_stats = mlir::DenseFPElementsAttr::get(
291         mlir::RankedTensorType::get(axis_stats_shape, b.getF32Type()),
292         min_maxs);
293     // TODO(fengliuai): this quantization dimension isn't correct.
294     axis = b.getI64IntegerAttr(tensor.quantization->quantized_dimension);
295   }
296   return b.create<mlir::quant::StatisticsOp>(b.getUnknownLoc(), res,
297                                              layer_stats, axis_stats, axis);
298 }
299 
300 // Returns true if this is a basic LSTM op.
301 bool IsBasicLSTMOp(tflite::BuiltinOptionsUnion op_union) {
302   if (const auto* op = op_union.AsLSTMOptions()) {
303     return op->kernel_type == tflite::LSTMKernelType_BASIC;
304   } else {
305     return false;
306   }
307 }
308 
309 // Gets the MLIR op name with the dialect name for the flatbuffer operator.
310 StatusOr<std::string> GetMlirOpName(const tflite::OperatorT& op,
311                                     const tflite::OperatorCodeT& op_code) {
312   if (IsBasicLSTMOp(op.builtin_options)) {
313     return std::string("tfl.basic_lstm");
314   }
315 
316   auto builtin_code = tflite::GetBuiltinCode(&op_code);
317   if (builtin_code == tflite::BuiltinOperator_CUSTOM) {
318     return std::string("tfl.custom");
319   }
320   if (builtin_code == tflite::BuiltinOperator_IF) {
321     return std::string("tf.If");
322   }
323   if (builtin_code == tflite::BuiltinOperator_WHILE) {
324     return std::string("tfl.while");
325   }
326 
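  // All other builtin operators map to the lower-cased TFL dialect name,
  // e.g. BuiltinOperator_AVERAGE_POOL_2D -> "tfl.average_pool_2d".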
327   llvm::StringRef op_name(tflite::EnumNameBuiltinOperator(builtin_code));
328   return llvm::Twine("tfl.", op_name.lower()).str();
329 }
330 
331 // The buffers in TFLite flatbuffers have their contents stored as a vector of
332 // bytes that represent little-endian values.
333 // The read size is taken from sizeof(T), so values of different widths can be
334 // read without a case split.
335 template <typename T>
336 std::vector<T> ReadAsLittleEndian(ArrayRef<uint8_t> bytes) {
337   std::vector<T> ret;
338   size_t read_size = sizeof(T);
339   int bytes_len = bytes.size();
340   assert(bytes_len % read_size == 0);
341 
342   int elem_count = bytes_len / read_size;
343   ret.reserve(elem_count);
344 
345   const char* data_ptr = reinterpret_cast<const char*>(bytes.data());
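  // readNext advances data_ptr past each element as it decodes it.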
346   for (int i = 0; i < elem_count; i++) {
347     ret.push_back(
348         llvm::support::endian::readNext<T, llvm::support::little,
349                                         llvm::support::unaligned>(data_ptr));
350   }
351   return ret;
352 }
353 
354 tensorflow::TensorProto ConvertTfliteConstTensor(
355     const tflite::TensorT& tensor, const std::vector<uint8_t>& buffer) {
356   tensorflow::TensorProto ret;
357   ret.set_dtype(TflTypeToTfType(tensor.type));
358 
359   tensorflow::TensorShapeProto* shape = ret.mutable_tensor_shape();
360   shape->set_unknown_rank(false);
361   for (auto dim : tensor.shape) {
362     shape->add_dim()->set_size(int64_t{dim});
363   }
364   // TensorFlow Lite uses tflite::DynamicBuffer to encode a vector of strings.
365   if (tensor.type == tflite::TensorType_STRING) {
366     for (int i = 0; i < tflite::GetStringCount(buffer.data()); ++i) {
367       tflite::StringRef str = tflite::GetString(buffer.data(), i);
368       ret.add_string_val(str.str, str.len);
369     }
370     return ret;
371   }
372   std::string content;
373   content.assign(reinterpret_cast<const char*>(buffer.data()), buffer.size());
374   ret.set_tensor_content(content);
375   return ret;
376 }
377 
378 StatusOr<mlir::ElementsAttr> ConvertFloatBuffer(
379     mlir::RankedTensorType shaped_type, mlir::FloatType elem_type,
380     const std::vector<uint8_t>& buffer) {
381   size_t bytes_len = buffer.size();
382 
383   // The bytes of floats are stored little-endian.
384   switch (elem_type.getWidth()) {
385     case 16: {
386       assert(bytes_len % 2 == 0);
387       int elem_count = bytes_len / 2;
388       std::vector<llvm::APFloat> values;
389       values.reserve(elem_count);
390 
391       const char* data = reinterpret_cast<const char*>(buffer.data());
392       auto& semantics = elem_type.getFloatSemantics();
393 
394       for (int i = 0; i < elem_count; i++) {
395         uint16_t bit_repr =
396             llvm::support::endian::readNext<uint16_t, llvm::support::little,
397                                             llvm::support::unaligned>(data);
398         llvm::APInt int_repr(16, bit_repr);
399         values.emplace_back(semantics, int_repr);
400       }
401 
402       return DenseElementsAttr::get(shaped_type, values);
403     }
404     case 32: {
405       assert(bytes_len % 4 == 0);
406       int elem_count = bytes_len / 4;
407       std::vector<float> values;
408       values.reserve(elem_count);
409 
410       const char* data = reinterpret_cast<const char*>(buffer.data());
411 
412       for (int i = 0; i < elem_count; i++) {
413         uint32_t bit_repr =
414             llvm::support::endian::readNext<uint32_t, llvm::support::little,
415                                             llvm::support::unaligned>(data);
416         values.push_back(absl::bit_cast<float>(bit_repr));
417       }
418       return DenseElementsAttr::get(shaped_type, ArrayRef<float>(values));
419     }
420     case 64: {
421       assert(bytes_len % 8 == 0);
422       int elem_count = bytes_len / 8;
423       std::vector<double> values;
424       values.reserve(elem_count);
425 
426       const char* data = reinterpret_cast<const char*>(buffer.data());
427 
428       for (int i = 0; i < elem_count; i++) {
429         uint64_t bit_repr =
430             llvm::support::endian::readNext<uint64_t, llvm::support::little,
431                                             llvm::support::unaligned>(data);
432         values.push_back(absl::bit_cast<double>(bit_repr));
433       }
434       return DenseElementsAttr::get(shaped_type, ArrayRef<double>(values));
435     }
436   }
437   return errors::InvalidArgument("unsupported bit width", elem_type.getWidth());
438 }
439 
440 StatusOr<mlir::ElementsAttr> ConvertIntBuffer(
441     mlir::RankedTensorType shaped_type, mlir::Type elem_type,
442     const std::vector<uint8_t>& buffer) {
443   unsigned bit_width;
444   if (auto itype = elem_type.dyn_cast<mlir::IntegerType>()) {
445     bit_width = itype.getWidth();
446   } else if (auto qtype = elem_type.dyn_cast<QuantizedType>()) {
447     bit_width = qtype.getStorageTypeIntegralWidth();
448     shaped_type = mlir::RankedTensorType::get(shaped_type.getShape(),
449                                               qtype.getStorageType());
450   } else {
451     return errors::InvalidArgument("unsupported integer constant type");
452   }
453 
454   switch (bit_width) {
455     case 1: {
456       // vector<bool> doesn't convert to an ArrayRef
457       llvm::SmallVector<bool, 8> values;
458       values.reserve(buffer.size());
459       for (auto b : buffer) {
460         values.emplace_back(b != 0);
461       }
462       return DenseElementsAttr::get(shaped_type, ArrayRef<bool>(values));
463     }
464     case 8: {
465       return DenseElementsAttr::get(shaped_type, ArrayRef<uint8_t>(buffer));
466     }
467     case 16: {
468       auto values = ReadAsLittleEndian<uint16_t>(buffer);
469       return DenseElementsAttr::get(shaped_type, ArrayRef<uint16_t>(values));
470     }
471     case 32: {
472       auto values = ReadAsLittleEndian<uint32_t>(buffer);
473       return DenseElementsAttr::get(shaped_type, ArrayRef<uint32_t>(values));
474     }
475     case 64: {
476       auto values = ReadAsLittleEndian<uint64_t>(buffer);
477       return DenseElementsAttr::get(shaped_type, ArrayRef<uint64_t>(values));
478     }
479     default:
480       return errors::Unimplemented("Cannot handle bit width ", bit_width);
481   }
482 }
483 
484 StatusOr<Operation*> BuildExternalConstOp(const tflite::TensorT& tensor,
485                                           int32_t buffer_index,
486                                           OpBuilder builder, Location loc) {
487   TF_ASSIGN_OR_RETURN(auto type, GetTensorType(tensor, builder,
488                                                /*shapeless_are_scalars=*/true,
489                                                /*is_constant=*/true));
490   auto shaped_type = type.dyn_cast<mlir::RankedTensorType>();
491   if (!shaped_type) {
492     return errors::Internal("Constant doesn't have a shape");
493   }
494   auto op = builder.create<tfl::ExternalConstOp>(
495       loc, shaped_type, builder.getI32IntegerAttr(buffer_index));
496   return op.getOperation();
497 }
498 
499 // Gets a constant splat for the given type, using `unique_index` as the splat
500 // value so the attribute is unique. Requires the type to be a statically
501 // shaped RankedTensorType.
502 static mlir::ElementsAttr GetSplat(RankedTensorType type, int unique_index,
503                                    OpBuilder builder) {
504   mlir::Type element_ty = getElementTypeOrSelf(type);
505 
506   if (element_ty.isSignlessInteger())
507     return DenseElementsAttr::get(
508         type, builder.getIntegerAttr(element_ty, unique_index));
509 
510   if (element_ty.isa<mlir::FloatType>())
511     return DenseElementsAttr::get(
512         type, builder.getFloatAttr(element_ty, unique_index));
513 
514   if (auto qtype = element_ty.dyn_cast<QuantizedType>()) {
515     mlir::RankedTensorType new_type =
516         RankedTensorType::get(type.getShape(), qtype.getStorageType());
517     return DenseElementsAttr::get(
518         new_type, builder.getIntegerAttr(qtype.getStorageType(), unique_index));
519   }
520   llvm_unreachable("unhandled element type");
521 }
522 
523 // TODO(b/172664358): Create a new op instead of reusing the constant op.
524 // Creates a constant op to represent a stateful variable. The function-static
525 // variable `stateful_variable_idx` is used as a unique value for each constant
526 // so they are not CSE'd away. `tensor` is the flatbuffer tensor data structure.
527 // `shaped_type` is the ShapedType for the const op.
528 Operation* BuildVariableOp(const tflite::TensorT& tensor,
529                            mlir::RankedTensorType shaped_type,
530                            OpBuilder builder, Location loc) {
531   static int stateful_variable_idx = 0;
532   mlir::ElementsAttr value =
533       GetSplat(shaped_type, stateful_variable_idx++, builder);
534   if (IsQuantized(tensor)) {
535     auto op = builder.create<tfl::QConstOp>(
536         loc, mlir::TypeAttr::get(shaped_type), value);
537     return op.getOperation();
538   }
539   auto op = builder.create<tfl::ConstOp>(loc, value);
540   if (tensor.quantization && !tensor.quantization->min.empty()) {
541     if (auto stats_op =
542             ConvertMinMaxToStatsOp(tensor, builder, op.getResult())) {
543       return stats_op;
544     }
545   }
546   return op.getOperation();
547 }
548 
549 static StatusOr<mlir::ArrayAttr> ConvertSparseIndexVector(
550     const tflite::SparseIndexVectorUnion& sparse_index_vector,
551     OpBuilder& builder) {
552   if (sparse_index_vector.type == tflite::SparseIndexVector_Int32Vector) {
553     return builder.getI32ArrayAttr(sparse_index_vector.AsInt32Vector()->values);
554   } else if (sparse_index_vector.type ==
555              tflite::SparseIndexVector_Uint16Vector) {
556     const auto& inputs = sparse_index_vector.AsUint16Vector()->values;
557     std::vector<int32_t> outputs(inputs.size());
558     std::transform(inputs.begin(), inputs.end(), outputs.begin(),
559                    [](auto x) { return static_cast<int32_t>(x); });
560     return builder.getI32ArrayAttr(outputs);
561   } else if (sparse_index_vector.type ==
562              tflite::SparseIndexVector_Uint8Vector) {
563     const auto& inputs = sparse_index_vector.AsUint8Vector()->values;
564     std::vector<int32_t> outputs(inputs.size());
565     std::transform(inputs.begin(), inputs.end(), outputs.begin(),
566                    [](auto x) { return static_cast<int32_t>(x); });
567     return builder.getI32ArrayAttr(outputs);
568   } else {
569     return errors::Unimplemented("Unsupported SparseIndexVector type");
570   }
571 }
572 
573 static StatusOr<Operation*> BuildSparseConstOp(
574     const tflite::TensorT& tensor, const std::vector<uint8_t>& buffer,
575     const mlir::RankedTensorType shaped_type, OpBuilder& builder,
576     Location loc) {
577   tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer);
578   repr.clear_tensor_shape();
579   if (IsQuantized(tensor)) {
580     repr.mutable_tensor_shape()->add_dim()->set_size(buffer.size());
581     repr.set_dtype(tensorflow::DT_INT8);
582   } else {
583     repr.mutable_tensor_shape()->add_dim()->set_size(
584         buffer.size() / (shaped_type.getElementTypeBitWidth() / CHAR_BIT));
585   }
586   TF_ASSIGN_OR_RETURN(mlir::ElementsAttr compressed_data,
587                       tensorflow::ConvertTensorProto(repr, &builder));
588 
589   const int dim_metadata_size = tensor.sparsity->dim_metadata.size();
590   std::vector<mlir::Attribute> dim_metadata(dim_metadata_size);
591   for (int i = 0; i < dim_metadata_size; i++) {
592     if (tensor.sparsity->dim_metadata[i]->format ==
593         tflite::DimensionType_DENSE) {
594       dim_metadata[i] = tfl::DimensionMetadataAttr::get(
595           builder.getStringAttr("DENSE"),
596           builder.getI32IntegerAttr(
597               tensor.sparsity->dim_metadata[i]->dense_size),
598           builder.getI32ArrayAttr({}), builder.getI32ArrayAttr({}),
599           builder.getContext());
600     } else if (tensor.sparsity->dim_metadata[i]->format ==
601                tflite::DimensionType_SPARSE_CSR) {
602       TF_ASSIGN_OR_RETURN(
603           mlir::ArrayAttr segments,
604           ConvertSparseIndexVector(
605               tensor.sparsity->dim_metadata[i]->array_segments, builder));
606       TF_ASSIGN_OR_RETURN(
607           mlir::ArrayAttr indices,
608           ConvertSparseIndexVector(
609               tensor.sparsity->dim_metadata[i]->array_indices, builder));
610       dim_metadata[i] = tfl::DimensionMetadataAttr::get(
611           builder.getStringAttr("SPARSE_CSR"), builder.getI32IntegerAttr(0),
612           segments, indices, builder.getContext());
613     } else {
614       return errors::Unimplemented("Unsupported dimension metadata type");
615     }
616   }
617   auto s_param = tfl::SparsityParameterAttr::get(
618       builder.getI32ArrayAttr(tensor.sparsity->traversal_order),
619       builder.getI32ArrayAttr(tensor.sparsity->block_map),
620       builder.getArrayAttr(dim_metadata), builder.getContext());
621 
622   auto value_type = shaped_type;
623   if (IsQuantized(tensor)) {
624     value_type = RankedTensorType::get(
625         shaped_type.getShape(), shaped_type.getElementType()
626                                     .dyn_cast<mlir::quant::QuantizedType>()
627                                     .getStorageType());
628   }
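  // The value attribute below is a zero splat placeholder; the real data is
  // carried by the compressed_data attribute.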
629   std::vector<char> dense_buffer(
630       value_type.getElementType().getIntOrFloatBitWidth() / CHAR_BIT);
631   mlir::Attribute dummy_value =
632       mlir::DenseIntOrFPElementsAttr::getFromRawBuffer(value_type, dense_buffer,
633                                                        /*isSplatBuffer=*/true);
634 
635   if (IsQuantized(tensor)) {
636     return builder
637         .create<tfl::SparseQConstOp>(loc, mlir::TypeAttr::get(shaped_type),
638                                      dummy_value, s_param, compressed_data)
639         .getOperation();
640   }
641   return builder
642       .create<tfl::SparseConstOp>(loc, dummy_value, s_param, compressed_data)
643       .getOperation();
644 }
645 
646 StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
647                                   const std::vector<uint8_t>& buffer,
648                                   bool is_variable, OpBuilder builder,
649                                   Location loc) {
650   TF_ASSIGN_OR_RETURN(auto type, GetTensorType(tensor, builder,
651                                                /*shapeless_are_scalars=*/true,
652                                                /*is_constant=*/true));
653   auto shaped_type = type.dyn_cast<mlir::RankedTensorType>();
654   if (!shaped_type) {
655     return errors::Internal("Constant doesn't have a shape");
656   }
657 
658   if (tensor.sparsity != nullptr) {
659     return BuildSparseConstOp(tensor, buffer, shaped_type, builder, loc);
660   }
661 
662   auto elem_type = shaped_type.getElementType();
663 
664   mlir::ElementsAttr value;
665   if (is_variable) {
666     return BuildVariableOp(tensor, shaped_type, builder, loc);
667   } else if (auto float_type = elem_type.dyn_cast<mlir::FloatType>()) {
668     TF_ASSIGN_OR_RETURN(value,
669                         ConvertFloatBuffer(shaped_type, float_type, buffer));
670   } else if (elem_type.isa<mlir::IntegerType, QuantizedType>()) {
671     TF_ASSIGN_OR_RETURN(value,
672                         ConvertIntBuffer(shaped_type, elem_type, buffer));
673   } else if (elem_type.isa<mlir::TF::StringType>()) {
674     tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer);
675     std::vector<llvm::StringRef> refs;
676     refs.reserve(repr.string_val_size());
677 
678     for (const auto& ref : repr.string_val())
679       refs.push_back({ref.data(), ref.size()});
680 
681     value = mlir::DenseStringElementsAttr::get(shaped_type, refs);
682   } else if (elem_type.isa<mlir::ComplexType, mlir::TF::TensorFlowType>()) {
683     auto dialect = elem_type.getContext()->getLoadedDialect("tf");
684     tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer);
685     std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
686 
687     value = mlir::OpaqueElementsAttr::get(dialect, shaped_type, mangled);
688   } else {
689     return errors::Unimplemented("Constant of unsupported type");
690   }
691 
692   if (IsQuantized(tensor)) {
693     auto op = builder.create<tfl::QConstOp>(
694         loc, mlir::TypeAttr::get(shaped_type), value);
695     return op.getOperation();
696   }
697   auto op = builder.create<tfl::ConstOp>(loc, value);
698   return op.getOperation();
699 }
700 
701 llvm::SmallVector<mlir::NamedAttribute, 4> ConvertSubgraphIdxsToFunctionAttrs(
702     tflite::BuiltinOptionsUnion options,
703     const std::vector<std::string>& func_names, Builder builder) {
704   if (auto* opts = options.AsCallOnceOptions()) {
705     uint32_t init_idx = opts->init_subgraph_index;
706     auto init_attr = builder.getStringAttr(func_names.at(init_idx));
707 
708     return {builder.getNamedAttr("session_init_function", init_attr)};
709   }
710   if (auto* opts = options.AsIfOptions()) {
711     uint32_t then_idx = opts->then_subgraph_index;
712     auto then_attr = builder.getSymbolRefAttr(func_names.at(then_idx));
713     uint32_t else_idx = opts->else_subgraph_index;
714     auto else_attr = builder.getSymbolRefAttr(func_names.at(else_idx));
715 
716     return {builder.getNamedAttr("then_branch", then_attr),
717             builder.getNamedAttr("else_branch", else_attr),
718             // TODO(b/139667752): Analyze statelessness correctly
719             builder.getNamedAttr("is_stateless", builder.getBoolAttr(false))};
720   }
721   if (auto* opts = options.AsWhileOptions()) {
722     uint32_t cond_idx = opts->cond_subgraph_index;
723     auto cond_attr = builder.getSymbolRefAttr(func_names.at(cond_idx));
724     uint32_t body_idx = opts->body_subgraph_index;
725     auto body_attr = builder.getSymbolRefAttr(func_names.at(body_idx));
726 
727     return {builder.getNamedAttr("cond", cond_attr),
728             builder.getNamedAttr("body", body_attr)};
729   }
730   return {};
731 }
732 
733 Status AddOpIntermediatesForLstm(
734     const tflite::OperatorT& op,
735     const std::vector<mlir::TensorType>& intermediate_types,
736     OperationState& op_state, Location loc, OpBuilder& builder) {
737   if (!op.intermediates.empty()) {
738     if (op.intermediates.size() != 5) {
739       auto err = errors::InvalidArgument(
740           "operator has intermediate tensors but the number of them is not "
741           "five.");
742       return emitError(loc, err.ToString()), err;
743     }
744     // Create intermediate value
745 
746     const llvm::SmallVector<llvm::StringRef, 5> kIntermediateNames = {
747         "input_to_input_intermediate", "input_to_forget_intermediate",
748         "input_to_cell_intermediate", "input_to_output_intermediate",
749         "effective_hidden_scale_intermediate"};
750     for (auto type_and_name :
751          llvm::zip(intermediate_types, kIntermediateNames)) {
752       mlir::TypeAttr type_attr =
753           mlir::TypeAttr::get(std::get<0>(type_and_name));
754       auto named_attr =
755           builder.getNamedAttr(std::get<1>(type_and_name), type_attr);
756       op_state.addAttribute(named_attr.first, named_attr.second);
757     }
758   }
759   return Status::OK();
760 }
761 
762 // TODO(krzysd) Handle function calls
763 StatusOr<Operation*> ConvertOp(
764     const tflite::OperatorT& op, const std::vector<Value>& vals_map,
765     const std::vector<mlir::TensorType>& intermediate_types,
766     Value optional_arg_marker,
767     const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& op_codes,
768     const std::vector<std::string>& func_names,
769     const std::vector<std::unique_ptr<tflite::TensorT>>& tensors, Location loc,
770     OpBuilder builder) {
771   llvm::SmallVector<Value, 4> operands;
772   llvm::SmallVector<mlir::Type, 2> outputTypes;
773 
774   const tflite::OperatorCodeT& op_code = *op_codes.at(op.opcode_index);
775 
776   TF_ASSIGN_OR_RETURN(const std::string op_name, GetMlirOpName(op, op_code));
777 
778   OperationState op_state(loc, op_name);
779 
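  // A tensor index of -1 marks an omitted optional input; it is wired to the
  // optional-argument marker provided by the caller.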
780   for (auto input_num : op.inputs) {
781     if (input_num == -1) {
782       assert(optional_arg_marker != nullptr);
783       op_state.addOperands({optional_arg_marker});
784     } else {
785       op_state.addOperands({vals_map.at(input_num)});
786     }
787   }
788 
789   for (auto output_num : op.outputs) {
790     auto& tensor = *tensors.at(output_num);
791     auto type_or_err = GetTensorType(tensor, builder);
792     if (!type_or_err.ok()) {
793       return emitError(loc, type_or_err.status().ToString()),
794              type_or_err.status();
795     }
796     auto type = type_or_err.ConsumeValueOrDie();
797 
798     if (op_name == "tfl.quantize") {
799       // Special case for quantize: return type must also be in qtype attribute
800       op_state.addAttribute("qtype", mlir::TypeAttr::get(type));
801     } else if (op_name == "tfl.reshape" && op_state.operands.size() == 1) {
802       // Special case for reshape: the second operand is optional in the old
803       // converter and kernel, so we create the second operand, which is
804       // required by the new converter, from the reshape op's option.
805       auto new_shape = op.builtin_options.AsReshapeOptions()->new_shape;
806       auto shape_type = RankedTensorType::get(
807           {static_cast<int64_t>(new_shape.size())}, builder.getIntegerType(32));
808 
809       mlir::SmallVector<mlir::Attribute, 4> shape;
810       for (auto s : new_shape) {
811         shape.push_back(builder.getI32IntegerAttr(static_cast<int32_t>(s)));
812       }
813       auto output_shape = DenseElementsAttr::get(shape_type, shape);
814       auto shape_op = builder.create<tfl::ConstOp>(loc, output_shape);
815       op_state.addOperands({shape_op});
816     }
817 
818     op_state.addTypes({type});
819   }
820 
821   // Since the last several tensors could be optional tensors for a tfl op, the
822   // number of input operands can vary. Get the min/max number of operands from
823   // the tflite op name.
824   // Also, since the code above special-handles the `tfl.reshape` op and adds an
825   // additional input, this block must run after that.
826   llvm::MinMax input_min_max = mlir::OperandNumbersMinMax(op_name);
827   int input_max_num = input_min_max.Max;
828   int op_input_num = op_state.operands.size();
829   if (input_max_num != 0 && input_max_num > op_input_num) {
830     // If the number of current inputs is less than the op definition requires,
831     // fill in with `none` values.
832     llvm::SmallVector<Value, 4> none_operands(
833         input_max_num - op_input_num,
834         builder.create<mlir::ConstantOp>(loc, builder.getNoneType(),
835                                          builder.getUnitAttr()));
836     op_state.addOperands(ArrayRef<Value>(none_operands));
837   }
838 
839   if (op_name == "tfl.lstm") {
840     // TODO(b/147587779): add the right region if region is empty.
841     op_state.addRegion();
842     TF_CHECK_OK(AddOpIntermediatesForLstm(op, intermediate_types, op_state, loc,
843                                           builder));
844   }
845   if (op_name == "tfl.while") {
846     // Adds two empty regions for "tfl.while". We will fill the regions after
847     // creating the callee functions because the "tfl.while" input/output types
848     // may differ from the callee functions', and the call ops need to sync
849     // with callee function types.
850     op_state.addRegion();
851     op_state.addRegion();
852   }
853   if (op_name == "tfl.unidirectional_sequence_lstm") {
854     TF_CHECK_OK(AddOpIntermediatesForLstm(op, intermediate_types, op_state, loc,
855                                           builder));
856   }
857   if (op_name == "tfl.reshape") {
858     // Flatten reshape ops when a shape operand with more than one dimension is given.
859     mlir::DenseIntElementsAttr shape_attr;
860     if (matchPattern(op_state.operands[1], m_Constant(&shape_attr))) {
861       auto shape_ty =
862           op_state.operands[1].getType().dyn_cast<RankedTensorType>();
863       if (shape_ty != nullptr && shape_ty.hasRank() && shape_ty.getRank() > 1) {
864         llvm::SmallVector<mlir::Attribute, 4> shape;
865         int32_t dim_size = 0;
866         for (const auto& dim : llvm::enumerate(shape_attr.getIntValues())) {
867           const int64_t size = dim.value().getSExtValue();
868           shape.push_back(
869               builder.getI32IntegerAttr(static_cast<int32_t>(size)));
870           ++dim_size;
871         }
872         auto shape_type = RankedTensorType::get(
873             {static_cast<int32_t>(dim_size)}, builder.getIntegerType(32));
874         auto output_shape = mlir::DenseElementsAttr::get(shape_type, shape);
875         auto shape_op = builder.create<tfl::ConstOp>(loc, output_shape);
876         op_state.operands[1] = shape_op;
877       }
878     }
879   }
880 
881   llvm::SmallVector<mlir::NamedAttribute, 2> attrs;
882   auto builtin_code = tflite::GetBuiltinCode(&op_code);
883   if (builtin_code == tflite::BuiltinOperator_CUSTOM) {
884     auto status = mlir::CustomOptionsToAttributes(
885         op_code.custom_code, op.custom_options, builder, loc, &attrs);
886     if (!status.ok()) {
887       return emitError(loc, status.ToString()), status;
888     }
889   } else {
890     mlir::BuiltinOptionsToAttributes(op.builtin_options, builder, attrs);
891   }
892   op_state.addAttributes(attrs);
893 
894   // Handle the conversion from subgraph index to functions for If and While. We
895   // will add CallOps in the region to call the functions later for While.
896   auto function_ref_attrs = ConvertSubgraphIdxsToFunctionAttrs(
897       op.builtin_options, func_names, builder);
898   op_state.addAttributes(function_ref_attrs);
899 
900   return builder.createOperation(op_state);
901 }
902 
903 // Returns indices of the given tensors in the subgraph. Returns an error if a
904 // tensor name cannot be found in the subgraph.
905 StatusOr<std::vector<int>> GetTensorIndices(
906     const tflite::SubGraphT& subgraph,
907     const std::vector<std::string>& tensor_names) {
908   absl::flat_hash_map<std::string, int> name_to_index;
909   for (auto index_and_tensor : llvm::enumerate(subgraph.tensors)) {
910     name_to_index[index_and_tensor.value()->name] = index_and_tensor.index();
911   }
912 
913   std::vector<int> indices;
914   indices.reserve(tensor_names.size());
915 
916   for (const auto& name : tensor_names) {
917     auto found = name_to_index.find(name);
918     if (found != name_to_index.end()) {
919       indices.push_back(found->second);
920     } else {
921       return errors::InvalidArgument("could not find tensor in subgraph: ",
922                                      name);
923     }
924   }
925 
926   return indices;
927 }
928 
929 // Given a list of tensor indices, returns a string of concatenated tensor names
930 // wrapped in a NamedAttribute.
931 template <typename ContainerType>
932 mlir::NamedAttribute BuildTFEntryFunctionAttribute(
933     const tflite::SubGraphT& subgraph, Builder* builder, const std::string name,
934     const ContainerType indices) {
935   auto tensor_names = llvm::map_range(
936       indices, [&](int i) { return subgraph.tensors.at(i)->name; });
937   return builder->getNamedAttr(
938       name, builder->getStringAttr(llvm::join(tensor_names, ",")));
939 }
940 
941 // Traverses the subgraph from output_indices to input_indices and returns the
942 // set of ops that are visited.
943 StatusOr<absl::flat_hash_set<const tflite::OperatorT*>> PruneSubgraph(
944     const tflite::SubGraphT& subgraph, ArrayRef<int32_t> input_indices,
945     ArrayRef<int32_t> output_indices) {
946   // Create a map from tensor index to defining op.
947   absl::flat_hash_map<int32_t, const tflite::OperatorT*> defining_op;
948   for (const auto& op : subgraph.operators) {
949     for (int32_t output : op->outputs) {
950       if (!llvm::is_contained(input_indices, output)) {
951         defining_op[output] = op.get();
952       }
953     }
954   }
955 
956   std::vector<const tflite::OperatorT*> queue;
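  // Note: defining_op[output] default-inserts a nullptr entry for tensors with
  // no defining op (subgraph inputs or constants), which the condition skips.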
957   for (int32_t output : output_indices) {
958     if (auto& op = defining_op[output]) {
959       queue.push_back(op);
960     }
961   }
962 
963   // Traverse the graph towards inputs.
964   absl::flat_hash_set<const tflite::OperatorT*> visited;
965   while (!queue.empty()) {
966     const tflite::OperatorT* op = queue.back();
967     queue.pop_back();
968     if (!visited.insert(op).second) {
969       // The node has already been visited.
970       continue;
971     }
972 
973     for (int32_t input : op->inputs) {
974       // Input tensor may not have a defining op in case it is a subgraph input
975       // or a constant tensor.
976       if (auto& op = defining_op[input]) {
977         queue.push_back(op);
978       }
979     }
980   }
981 
982   return visited;
983 }
984 
985 // Adjusts the func op according to information gathered across ops.
986 static StatusOr<FuncOp> PostProcessFuncOp(FuncOp func) {
987   OpBuilder builder(func);
988   // When a quantized constant is imported, its quantization parameters are set
989   // to narrow range. Revert to full range here if the user doesn't require
990   // narrow range.
991   func.walk([&](tfl::QConstOp cst) {
992     Value value = cst.getResult();
993     Value full_range_const = value;
994     auto qtype = mlir::quant::UniformQuantizedType::getQuantizedElementType(
995         value.getType());
996     // Only the 8-bit constants are imported with narrow range.
997     if (!qtype || qtype.getStorageTypeIntegralWidth() != 8 ||
998         !(qtype.isa<mlir::quant::UniformQuantizedType>() ||
999           qtype.isa<mlir::quant::UniformQuantizedPerAxisType>())) {
1000       return;
1001     }
1002     for (auto& use : value.getUses()) {
1003       Operation* user = use.getOwner();
1004       if (user->hasTrait<mlir::OpTrait::IsTerminator>()) continue;
1005 
1006       auto affine_user = llvm::dyn_cast<mlir::AffineQuantizedOpInterface>(user);
1007       if (affine_user &&
1008           affine_user.GetAffineOperandIndex() == use.getOperandNumber() &&
1009           affine_user.RequiredNarrowRangeAffineOperand())
1010         continue;
1011       // Create a full-range quantized constant.
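      // Widening the storage minimum by one (e.g. -127 back to -128 for int8)
      // undoes the narrow-range import.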
1012       if (full_range_const == value) {
1013         mlir::quant::QuantizedType new_qtype;
1014         if (auto per_axis =
1015                 qtype.dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
1016           new_qtype = mlir::quant::UniformQuantizedPerAxisType::get(
1017               per_axis.getFlags(), per_axis.getStorageType(),
1018               per_axis.getExpressedType(), per_axis.getScales(),
1019               per_axis.getZeroPoints(), per_axis.getQuantizedDimension(),
1020               per_axis.getStorageTypeMin() - 1, per_axis.getStorageTypeMax());
1021         } else if (auto per_tensor =
1022                        qtype.dyn_cast<mlir::quant::UniformQuantizedType>()) {
1023           new_qtype = mlir::quant::UniformQuantizedType::get(
1024               per_tensor.getFlags(), per_tensor.getStorageType(),
1025               per_tensor.getExpressedType(), per_tensor.getScale(),
1026               per_tensor.getZeroPoint(), per_tensor.getStorageTypeMin() - 1,
1027               per_tensor.getStorageTypeMax());
1028         } else {
1029           return;  // Should not reach here, as it's already checked.
1030         }
1031         auto new_output_type = new_qtype.castFromExpressedType(
1032             mlir::quant::UniformQuantizedType::castToExpressedType(
1033                 value.getType()));
1034         builder.setInsertionPointAfter(cst.getOperation());
1035         auto new_op = builder.create<tfl::QConstOp>(
1036             cst.getLoc(), new_output_type, mlir::TypeAttr::get(new_output_type),
1037             cst.valueAttr());
1038         full_range_const = new_op.output();
1039       }
1040       use.set(full_range_const);
1041     }
1042     if (cst.use_empty()) cst.erase();
1043   });
1044   return func;
1045 }
1046 
1047 // Helper method that returns the index of the tensor with name 'tensor_name'
1048 // in the list of tensor names 'tensors'.
1049 int GetTensorIndex(const std::string& tensor_name,
1050                    llvm::SmallVector<llvm::StringRef, 2> tensors) {
1051   for (const auto& tensor_index_pair : llvm::enumerate(tensors)) {
1052     if (tensor_index_pair.value() == tensor_name)
1053       return tensor_index_pair.index();
1054   }
1055   return -1;
1056 }
1057 
1058 // Helper method that returns the list of all strings in a StringAttr
1059 // identified by 'attr_key', where values are separated by commas.
1060 llvm::SmallVector<llvm::StringRef, 2> GetStringsFromAttrWithSeparator(
1061     mlir::DictionaryAttr attr, const std::string& attr_key) {
1062   llvm::SmallVector<llvm::StringRef, 2> result;
1063   if (auto str = attr.get(attr_key).dyn_cast_or_null<mlir::StringAttr>()) {
1064     str.getValue().split(result, ',', /*MaxSplit=*/-1,
1065                          /*KeepEmpty=*/false);
1066   }
1067   return result;
1068 }
1069 
1070 // Sets signature attributes on the function.
1071 void SetSignature(
1072     FuncOp func, const tflite::SignatureDefT* signature,
1073     const std::vector<std::unique_ptr<tflite::TensorT>>& tensors) {
1074   auto* context = func->getContext();
1075   static const char kSignatureDefIndexPath[] = "tf_saved_model.index_path";
1076   static const char kExportedNameAttr[] = "tf_saved_model.exported_names";
1077   static const char kEntryFunctionAttributes[] = "tf.entry_function";
1078 
1079   auto dict_attr =
1080       func->getAttrOfType<mlir::DictionaryAttr>(kEntryFunctionAttributes);
1081   if (!dict_attr) return;
1082 
1083   // Get Input and output tensor names from attribute.
1084   llvm::SmallVector<llvm::StringRef, 2> input_names =
1085       GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"inputs");
1086   llvm::SmallVector<llvm::StringRef, 2> output_names =
1087       GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"outputs");
1088 
1089   for (auto input_pair : llvm::enumerate(signature->inputs)) {
1090     const int arg_index = GetTensorIndex(
1091         tensors[input_pair.value()->tensor_index]->name, input_names);
1092     if (arg_index == -1) {
1093       func->emitWarning("Invalid signature tensors specified.");
1094       return;
1095     }
1096     func.setArgAttr(
1097         arg_index, kSignatureDefIndexPath,
1098         mlir::ArrayAttr::get(context, {mlir::StringAttr::get(
1099                                           context, input_pair.value()->name)}));
1100   }
1101   for (auto output_pair : llvm::enumerate(signature->outputs)) {
1102     const int arg_index = GetTensorIndex(
1103         tensors[output_pair.value()->tensor_index]->name, output_names);
1104     if (arg_index == -1) {
1105       func->emitWarning("Invalid signature tensors specified.");
1106       return;
1107     }
1108     func.setResultAttr(arg_index, kSignatureDefIndexPath,
1109                        mlir::ArrayAttr::get(
1110                            context, {mlir::StringAttr::get(
1111                                         context, output_pair.value()->name)}));
1112   }
1113   func->setAttr(
1114       kExportedNameAttr,
1115       mlir::ArrayAttr::get(
1116           context, {mlir::StringAttr::get(context, signature->signature_key)}));
1117 }
1118 
1119 // Build a FuncOp from a tflite SubGraph
1120 // The buffers are directly taken
1121 // from the deserialized flatbuffer as we do not have the type information to
1122 // interpret them until this point. The base_loc parameter is the location of
1123 // the flatbuffer as a whole (usually a file). The is_entry_point flag
1124 // controls whether shapeless types are treated as scalars. If
1125 // ordered_output_arrays is not empty, then the imported mlir function will only
1126 // return nodes in ordered_output_arrays in the same order.
1127 // If signature is not null, then the inputs/outputs in signature will be
1128 // attached to the FuncOp.
1129 StatusOr<FuncOp> ConvertSubgraph(
1130     const tflite::SubGraphT& subgraph, llvm::StringRef name,
1131     const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& op_codes,
1132     const std::vector<std::string>& func_names,
1133     const std::vector<std::unique_ptr<tflite::BufferT>>& buffers,
1134     Location base_loc, Builder builder, bool is_entry_point,
1135     bool use_external_constant,
1136     const std::vector<std::string>& ordered_input_arrays,
1137     const std::vector<std::string>& ordered_output_arrays,
1138     bool experimental_prune_unreachable_nodes_unconditionally,
1139     const tflite::SignatureDefT* signature) {
1140   llvm::SmallVector<mlir::Type, 2> ret_types;
1141   llvm::SmallVector<mlir::Type, 4> input_types;
1142 
1143   auto func_loc = mlir::NameLoc::get(builder.getIdentifier(name), base_loc);
1144 
1145   std::vector<int> func_inputs = subgraph.inputs;
1146   if (is_entry_point && !ordered_input_arrays.empty()) {
1147     if (!experimental_prune_unreachable_nodes_unconditionally) {
1148       // TODO(b/149922113): Resolve input-arrays/pruning flags interaction.
1149       return errors::InvalidArgument(
1150           "input-arrays should be used with experimental pruning flag");
1151     }
1152     TF_ASSIGN_OR_RETURN(func_inputs,
1153                         GetTensorIndices(subgraph, ordered_input_arrays));
1154   }
1155 
1156   for (int input : func_inputs) {
1157     auto& tensor = *subgraph.tensors.at(input);
1158     // TODO(b/138222071) Graph inputs must have static shape per the exporter,
1159     // but we cannot differentiate scalars from unranked tensors.
1160     // Here we reverse the default assumption that shape = [] means unranked
1161     // when processing main().
1162     auto type_or_err = GetTensorType(tensor, builder,
1163                                      /*shapeless_are_scalars=*/is_entry_point,
1164                                      /*is_constant=*/false);
1165     if (!type_or_err.ok()) {
1166       emitError(func_loc, "error reading argument types")
1167           << type_or_err.status().ToString();
1168       return type_or_err.status();
1169     }
1170     auto type = type_or_err.ConsumeValueOrDie();
1171     input_types.push_back(type);
1172   }
1173 
1174   llvm::SmallVector<bool, 16> is_op_output(subgraph.tensors.size(), false);
1175   for (auto& op : subgraph.operators) {
1176     for (auto output : op->outputs) {
1177       is_op_output[output] = true;
1178     }
1179   }
1180 
1181   std::vector<int> func_outputs = subgraph.outputs;
1182   if (is_entry_point && !ordered_output_arrays.empty()) {
1183     TF_ASSIGN_OR_RETURN(func_outputs,
1184                         GetTensorIndices(subgraph, ordered_output_arrays));
1185   }
1186 
1187   for (auto output : func_outputs) {
1188     const bool is_func_input = std::find(func_inputs.begin(), func_inputs.end(),
1189                                          output) != func_inputs.end();
1190     bool is_constant = !is_op_output[output] && !is_func_input;
1191     // There are two cases in which a shapeless tensor in the flatbuffer is
1192     // treated as a scalar:
1193     // 1. `is_constant` = true: the tensor is produced by a constant op.
1194     // 2. `is_func_input` = true and `is_entry_point` = true: the tensor is a
1195     // function input, and shapeless entry-point inputs are scalars.
1196     const bool shapeless_is_scalar =
1197         is_constant || (is_func_input && is_entry_point);
1198     auto type_or_err = GetTensorType(*subgraph.tensors.at(output), builder,
1199                                      shapeless_is_scalar,
1200                                      /*is_constant=*/is_constant);
1201     if (!type_or_err.ok()) {
1202       emitError(func_loc, "error reading return types")
1203           << type_or_err.status().ToString();
1204       return type_or_err.status();
1205     }
1206     auto type = type_or_err.ConsumeValueOrDie();
1207     ret_types.push_back(type);
1208   }
1209   auto func_type = builder.getFunctionType(input_types, ret_types);
1210 
1211   // Construct function object
1212   auto func = FuncOp::create(func_loc, name, func_type, /* attrs= */ {});
1213   func.addEntryBlock();
1214   auto& body = func.getBody();
1215   OpBuilder op_builder{body};
1216 
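  // vals_map maps each tensor index in the subgraph to the MLIR Value that
  // defines it; entries stay null until the producing op (or a materialized
  // constant) has been created. maybe_optional_arg_marker lazily holds a
  // single None value shared by every omitted (-1) operand slot.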
1217   std::vector<Value> vals_map(subgraph.tensors.size(), nullptr);
1218   Value maybe_optional_arg_marker = nullptr;
1219 
1220   // Get or construct MLIR values for each input
1221   for (int i = 0, e = func_inputs.size(); i < e; i++) {
1222     auto input_tensor = func_inputs[i];
1223     const auto& tensor = *subgraph.tensors.at(input_tensor);
1224     auto loc = TensorLoc(tensor, builder, base_loc);
1225     if (vals_map[input_tensor]) {
1226       auto err = errors::FailedPrecondition("duplicate input arguments");
1227       return emitError(loc, err.ToString()), err;
1228     }
1229     Value input_value = func.getArgument(i);
1230 
1231     // If the `tensor` has min/max and doesn't have scale/zero_point
1232     // information, a stats op is created to use the input_value, then the
1233     // `tensor` should be mapped to the result of this new stats op.
1234     if (auto stats_op =
1235             ConvertMinMaxToStatsOp(tensor, op_builder, input_value)) {
1236       vals_map[input_tensor] = stats_op->getResult(0);
1237     } else {
1238       vals_map[input_tensor] = input_value;
1239     }
1240   }
1241 
1242   // Set tf.entry_function attribute
1243   if (is_entry_point) {
1244     llvm::SmallVector<mlir::NamedAttribute, 2> attributes;
1245     if (!func_inputs.empty()) {
1246       attributes.push_back(BuildTFEntryFunctionAttribute(
1247           subgraph, &builder, "inputs", func_inputs));
1248     }
1249     if (!func_outputs.empty()) {
1250       attributes.push_back(BuildTFEntryFunctionAttribute(
1251           subgraph, &builder, "outputs", func_outputs));
1252     }
1253     func->setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
1254   } else {
1255     func.setPrivate();
1256   }
1257 
1258   // Set signature on function.
1259   if (signature) {
1260     SetSignature(func, signature, subgraph.tensors);
1261   }
1262 
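  // When unconditional pruning is requested, only the operators that
  // PruneSubgraph keeps for the requested inputs/outputs are imported; all
  // other operators are skipped in the conversion loop below.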
1263   absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
1264   if (experimental_prune_unreachable_nodes_unconditionally) {
1265     TF_ASSIGN_OR_RETURN(pruned_subgraph_ops,
1266                         PruneSubgraph(subgraph, func_inputs, func_outputs));
1267   }
1268 
1269   // Construct MLIR operators from TFLite operators
1270   for (auto& op : subgraph.operators) {
1271     if (experimental_prune_unreachable_nodes_unconditionally &&
1272         !pruned_subgraph_ops.contains(op)) {
1273       continue;
1274     }
1275 
1276     for (auto input_num : op->inputs) {
1277       // The operators in a graph are topologically sorted
1278       // and so if no previous operation has produced a tensor
1279       // it must be a constant.
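      // An input index of -1 marks an omitted optional operand; a single None
      // constant is materialized and shared by all such operands.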
1280       if (input_num == -1) {
1281         if (maybe_optional_arg_marker == nullptr) {
1282           maybe_optional_arg_marker =
1283               op_builder
1284                   .create<mlir::ConstantOp>(base_loc, builder.getNoneType(),
1285                                             builder.getUnitAttr())
1286                   .getResult();
1287         }
1288       } else if (!vals_map.at(input_num)) {
1289         auto& const_tensor = *subgraph.tensors[input_num];
1290         auto const_loc = TensorLoc(const_tensor, builder, base_loc);
1291         auto op_or_err =
1292             use_external_constant
1293                 ? BuildExternalConstOp(const_tensor, const_tensor.buffer,
1294                                        op_builder, const_loc)
1295                 : BuildConstOp(const_tensor, buffers[const_tensor.buffer]->data,
1296                                const_tensor.is_variable, op_builder, const_loc);
1297         if (!op_or_err.ok()) {
1298           return emitError(const_loc, op_or_err.status().ToString()),
1299                  op_or_err.status();
1300         }
1301         vals_map[input_num] = op_or_err.ValueOrDie()->getResult(0);
1302       }
1303     }
1304 
1305     // Intermediate tensors for LSTMs are used only to carry the quantization
1306     // range in their types, so we just extract their types here.
1307     std::vector<mlir::TensorType> intermediate_types;
1308     intermediate_types.reserve(5);
1309     for (auto intermediate : op->intermediates) {
1310       TF_ASSIGN_OR_RETURN(
1311           auto type, GetTensorType(*subgraph.tensors[intermediate], builder,
1312                                    /*shapeless_are_scalars=*/true,
1313                                    /*is_constant=*/false,
1314                                    /*is_intermediate=*/true));
1315       intermediate_types.emplace_back(type);
1316     }
1317 
1318     auto op_loc = OpLoc(*op, subgraph.tensors, builder, base_loc);
1319 
1320     // If there's an optional argument, maybe_optional_arg_marker has been set
1321     // to a valid Value
1322     TF_ASSIGN_OR_RETURN(
1323         auto* mlir_op,
1324         ConvertOp(*op, vals_map, intermediate_types, maybe_optional_arg_marker,
1325                   op_codes, func_names, subgraph.tensors, op_loc, op_builder));
1326 
1327     // Add the results to the value maps. There are two cases: 1. if the
1328     // result tensor does not have min/max values, the original op result is
1329     // used directly; 2. if the result tensor has min/max values, a stats op is
1330     // created and the result of that stats op is used instead.
1331     for (auto pair : llvm::enumerate(mlir_op->getResults())) {
1332       int output_tensor_index = op->outputs[pair.index()];
1333       auto& tensor = *subgraph.tensors[output_tensor_index];
1334       if (auto stats_op =
1335               ConvertMinMaxToStatsOp(tensor, op_builder, pair.value())) {
1336         vals_map[output_tensor_index] = stats_op->getResult(0);
1337       } else {
1338         vals_map[output_tensor_index] = pair.value();
1339       }
1340     }
1341   }
1342 
1343   // Construct return values
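  // Outputs that were never produced by an operator (e.g. constant tensors
  // that are returned directly) are materialized as const ops here.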
1344   llvm::SmallVector<Value, 4> return_operands;
1345   for (auto index : func_outputs) {
1346     if (!vals_map.at(index)) {
1347       auto& const_tensor = *subgraph.tensors[index];
1348       auto const_loc = TensorLoc(const_tensor, builder, base_loc);
1349       auto op_or_err =
1350           use_external_constant
1351               ? BuildExternalConstOp(const_tensor, const_tensor.buffer,
1352                                      op_builder, const_loc)
1353               : BuildConstOp(const_tensor, buffers[const_tensor.buffer]->data,
1354                              const_tensor.is_variable, op_builder, const_loc);
1355       if (!op_or_err.ok()) {
1356         return emitError(const_loc, op_or_err.status().ToString()),
1357                op_or_err.status();
1358       }
1359       vals_map[index] = op_or_err.ValueOrDie()->getResult(0);
1360     }
1361     return_operands.push_back(vals_map[index]);
1362   }
1363 
1364   op_builder.create<mlir::ReturnOp>(base_loc, return_operands);
1365 
1366   return PostProcessFuncOp(func);
1367 }
1368 
1369 // TFLite subgraphs do not necessarily have names, though MLIR functions must
1370 // have them, so we generate a name for subgraphs that are missing one here.
1371 // Note: in TFLite, the first subgraph is the entry point, and in the MLIR
1372 // that represents it, this entry point must be called "main".
1373 // TODO(b/131175224,b/132239787) Support multiple entry points
1374 std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) {
1375   if (index == 0) {
1376     return "main";
1377   }
1378   if (subgraph.name.empty()) {
1379     return llvm::formatv("fn_{0}", index).str();
1380   }
1381   return subgraph.name;
1382 }
1383 
1384 // Adds a CallOp in `region` that calls `func` and yields the results of the
1385 // CallOp through a TFL::YieldOp.
1386 void AddCallOpInWhileOpRegion(mlir::Region& region, mlir::FuncOp func) {
1387   OpBuilder op_builder{region};
1388   region.push_back(new mlir::Block());
1389   region.addArguments(func.getType().getInputs());
1390   op_builder.setInsertionPointToStart(&region.front());
1391   auto call_op = op_builder.create<mlir::CallOp>(
1392       region.getLoc(), func.getType().getResults(), func.sym_name(),
1393       region.getArguments());
1394   op_builder.create<mlir::TFL::YieldOp>(region.getLoc(), call_op.getResults());
1395 }
1396 
1397 // TFL::WhileOp has regions, so for every while op in the module we add a
1398 // CallOp in its cond and body regions that calls the corresponding FuncOp.
1399 void AddRegionsForTflWhileOp(mlir::ModuleOp module) {
1400   mlir::SymbolTable symbol_table(module);
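  // The cond/body symbol references recorded during import are resolved to
  // functions, turned into calls inside the regions, and then dropped from
  // the while op.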
1401   module.walk([&](mlir::TFL::WhileOp while_op) {
1402     auto cond = symbol_table.lookup<mlir::FuncOp>(
1403         while_op->getAttr("cond").cast<mlir::FlatSymbolRefAttr>().getValue());
1404     AddCallOpInWhileOpRegion(while_op.cond(), cond);
1405     while_op->removeAttr("cond");
1406     auto body = symbol_table.lookup<mlir::FuncOp>(
1407         while_op->getAttr("body").cast<mlir::FlatSymbolRefAttr>().getValue());
1408     AddCallOpInWhileOpRegion(while_op.body(), body);
1409     while_op->removeAttr("body");
1410   });
1411 }
1412 }  // namespace
1413 
1414 OwningModuleRef tflite::FlatBufferToMlir(
1415     absl::string_view buffer, MLIRContext* context, Location base_loc,
1416     bool use_external_constant,
1417     const std::vector<std::string>& ordered_input_arrays,
1418     const std::vector<std::string>& ordered_output_arrays,
1419     bool experimental_prune_unreachable_nodes_unconditionally) {
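  // Load every dialect the importer may create operations for, so the ops can
  // be built and verified in this context.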
1420   context->loadDialect<
1421       mlir::StandardOpsDialect, mlir::quant::QuantizationDialect,
1422       mlir::TFL::TensorFlowLiteDialect, mlir::TF::TensorFlowDialect>();
1423 
1424   auto model_ptr =
1425       FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
1426   if (nullptr == model_ptr) {
1427     return emitError(base_loc, "couldn't parse flatbuffer"), nullptr;
1428   }
1429 
1430   std::unique_ptr<ModelT> model(model_ptr->GetModel()->UnPack());
1431 
1432   auto builder = Builder(context);
1433 
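  // Collect all subgraph names up front so that ops referring to other
  // subgraphs by index (e.g. the cond/body of a while op) can be mapped to
  // function names during conversion.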
1434   std::vector<std::string> func_names;
1435   for (auto& subgraph : model->subgraphs) {
1436     func_names.push_back(subgraph->name);
1437   }
1438 
1439   auto module = mlir::ModuleOp::create(base_loc);
1440   // We currently don't use the schema version to make decisions, but we
1441   // could use it when exporting or if there are breaking schema changes.
1442   module->setAttr("tfl.schema_version",
1443                   builder.getI32IntegerAttr(model->version));
1444   if (!model->description.empty()) {
1445     module->setAttr("tfl.description",
1446                     builder.getStringAttr(model->description));
1447   }
1448 
1449   // TODO(b/184697652): Update to handle multiple entry points.
1450   tflite::SignatureDefT* signature_def = nullptr;
1451   if (!model->signature_defs.empty()) {
1452     signature_def = model->signature_defs[0].get();
1453   }
1454 
1455   for (auto e : llvm::enumerate(model->subgraphs)) {
1456     auto& subgraph = e.value();
1457     std::string name = SubgraphName(e.index(), *subgraph);
1458     auto func_or_error = ConvertSubgraph(
1459         *subgraph, name, model->operator_codes, func_names, model->buffers,
1460         base_loc, builder,
1461         // TODO(b/131175224,b/132239787) Support multiple entry points
1462         /*is_entry_point=*/e.index() == 0,
1463         /*use_external_constant=*/use_external_constant, ordered_input_arrays,
1464         ordered_output_arrays,
1465         experimental_prune_unreachable_nodes_unconditionally,
1466         e.index() == 0 ? signature_def : nullptr);
1467     if (!func_or_error.ok()) {
1468       return emitError(base_loc, "could not translate function ")
1469                  << subgraph->name << ": "
1470                  << func_or_error.status().error_message(),
1471              nullptr;
1472     }
1473     module.push_back(func_or_error.ConsumeValueOrDie());
1474   }
1475   AddRegionsForTflWhileOp(module);
1476   return OwningModuleRef(module);
1477 }
1478