1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
17 
18 #include <algorithm>
19 #include <cctype>
20 #include <cstdint>
21 #include <iostream>
22 #include <sstream>
23 #include <string>
24 #include <vector>
25 
26 #include "absl/base/casts.h"
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/container/flat_hash_set.h"
29 #include "absl/strings/string_view.h"
30 #include "llvm/ADT/APFloat.h"
31 #include "llvm/ADT/APInt.h"
32 #include "llvm/ADT/ArrayRef.h"
33 #include "llvm/ADT/DenseMap.h"
34 #include "llvm/ADT/None.h"
35 #include "llvm/ADT/Optional.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/SmallVector.h"
38 #include "llvm/ADT/StringExtras.h"
39 #include "llvm/ADT/StringRef.h"
40 #include "llvm/ADT/Twine.h"
41 #include "llvm/Support/Casting.h"
42 #include "llvm/Support/CommandLine.h"
43 #include "llvm/Support/Endian.h"
44 #include "llvm/Support/FormatVariadic.h"
45 #include "llvm/Support/MemoryBuffer.h"
46 #include "llvm/Support/SourceMgr.h"
47 #include "llvm/Support/raw_ostream.h"
48 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
49 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
50 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
51 #include "mlir/IR/Attributes.h"  // from @llvm-project
52 #include "mlir/IR/Builders.h"  // from @llvm-project
53 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
54 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
55 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
56 #include "mlir/IR/Location.h"  // from @llvm-project
57 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
58 #include "mlir/IR/Operation.h"  // from @llvm-project
59 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
60 #include "mlir/IR/Types.h"  // from @llvm-project
61 #include "mlir/IR/Value.h"  // from @llvm-project
62 #include "mlir/Support/LLVM.h"  // from @llvm-project
63 #include "mlir/Translation.h"  // from @llvm-project
64 #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
65 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
66 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
67 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
68 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
69 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
70 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
71 #include "tensorflow/compiler/xla/statusor.h"
72 #include "tensorflow/core/framework/tensor.pb.h"
73 #include "tensorflow/core/framework/tensor_shape.pb.h"
74 #include "tensorflow/core/platform/errors.h"
75 #include "tensorflow/core/platform/status.h"
76 #include "tensorflow/lite/model.h"
77 #include "tensorflow/lite/schema/schema_generated.h"
78 #include "tensorflow/lite/schema/schema_utils.h"
79 
80 using llvm::ArrayRef;
81 using mlir::Builder;
82 using mlir::DenseElementsAttr;
83 using mlir::FuncOp;
84 using mlir::Location;
85 using mlir::MLIRContext;
86 using mlir::OpBuilder;
87 using mlir::Operation;
88 using mlir::OperationState;
89 using mlir::OwningModuleRef;
90 using mlir::RankedTensorType;
91 using mlir::UnrankedTensorType;
92 using mlir::Value;
93 using mlir::quant::QuantizedType;
94 using tflite::TensorT;
95 using xla::Status;
96 using xla::StatusOr;
97 
98 namespace errors = tensorflow::errors;
99 namespace tfl = mlir::TFL;
100 
101 namespace {
102 bool IsScalar(const TensorT& tensor) {
103   // TODO(b/138222071) We can't distinguish scalars and unranked tensors
104   // Work out a way to handle this and stub out the code until then
105   return tensor.shape.empty() && false;
106 }
107 
108 bool IsQuantized(const TensorT& tensor) {
109   return (tensor.quantization != nullptr) &&
110          !tensor.quantization->zero_point.empty();
111 }
112 
113 // Create the MLIR NamedLoc location corresponding to a given tensor
114 Location TensorLoc(const TensorT& tensor, Builder builder, Location base) {
115   if (tensor.name.empty()) {
116     return base;
117   }
118   return mlir::NameLoc::get(builder.getIdentifier(tensor.name), base);
119 }
120 
121 // Returns the correct type for a quantized tensor
122 // We have a special case for constants since they have a higher minimum value.
123 StatusOr<QuantizedType> GetQuantizedType(const TensorT& tensor, Builder builder,
124                                          bool is_constant = false) {
125   tflite::QuantizationParametersT& quant_params = *tensor.quantization;
126   if (quant_params.details.AsCustomQuantization()) {
127     return errors::Unimplemented("Cannot handle experimental quantization");
128   }
129 
130   bool is_signed = true;
131   mlir::IntegerType storage_type;
132   if (tensor.type == tflite::TensorType_UINT8) {
133     is_signed = false;
134     storage_type = builder.getIntegerType(8);
135   } else {
136     auto raw_elem_type = ConvertElementType(tensor.type, builder);
137     if (!raw_elem_type.isa<mlir::IntegerType>()) {
138       return errors::InvalidArgument(
139           "Quantized tensors must be stored as integers");
140     }
141     storage_type = raw_elem_type.cast<mlir::IntegerType>();
142   }
143 
144   // TFLite uses narrow-range [u]int8 for constant buffers of quantized weights.
145   // Since we don't know which ones are weights, we represent this optimization
146   // as a change in the storage bounds for the type for all constants of this
147   // type.
148   bool is_weight_buffer = is_constant && (storage_type.getWidth() == 8);
149 
150   int64_t storage_min = QuantizedType::getDefaultMinimumForInteger(
151                             is_signed, storage_type.getWidth()) +
152                         static_cast<int>(is_weight_buffer);
153   int64_t storage_max = QuantizedType::getDefaultMaximumForInteger(
154       is_signed, storage_type.getWidth());
155   uint32_t flags =
156       is_signed ? mlir::quant::QuantizationFlags::FlagValue::Signed : 0;
157 
158   // Reject quantized tensors that have zero scales.
159   for (float scale : quant_params.scale) {
160     if (scale == 0) {
161       return errors::InvalidArgument(
162           "Quantized tensors must have non-zero scales");
163     }
164   }
165 
166   // Scale size can't be zero, as checked above.
167   if (quant_params.scale.size() != 1) {
168     llvm::SmallVector<double, 4> scales(quant_params.scale.begin(),
169                                         quant_params.scale.end());
170     return mlir::quant::UniformQuantizedPerAxisType::get(
171         flags, storage_type, builder.getF32Type(), scales,
172         quant_params.zero_point, quant_params.quantized_dimension, storage_min,
173         storage_max);
174   }
175   return mlir::quant::UniformQuantizedType::get(
176       flags, storage_type, builder.getF32Type(), quant_params.scale.at(0),
177       quant_params.zero_point.at(0), storage_min, storage_max);
178 }
179 
180 // Imports a float tensor with calibration values into a calibrated quantized type.
181 StatusOr<QuantizedType> GetCalibratedQuantizedType(const TensorT& tensor,
182                                                    Builder builder) {
183   if (tensor.quantization == nullptr) {
184     return errors::InvalidArgument("The tensor is not quantized.");
185   }
186   auto raw_elem_type = ConvertElementType(tensor.type, builder);
187   float min = tensor.quantization->min[0];
188   float max = tensor.quantization->max[0];
189   return mlir::quant::CalibratedQuantizedType::get(raw_elem_type, min, max);
190 }
191 
192 // TODO(b/138222071) Remove shapeless_are_scalars once we can reliably
193 // make that distinction and don't have to rely on context
194 // (input to main and constants must have static shape)
195 StatusOr<mlir::TensorType> GetTensorType(const TensorT& tensor, Builder builder,
196                                          bool shapeless_are_scalars = false,
197                                          bool is_constant = false,
198                                          bool is_intermediate = false) {
199   mlir::Type elem_type = ConvertElementType(tensor.type, builder);
200   // TODO(b/139554398) Store min/max (even for non-quantized tensors) somewhere
201   // if it's set
202   if (IsQuantized(tensor)) {
203     TF_ASSIGN_OR_RETURN(elem_type,
204                         GetQuantizedType(tensor, builder, is_constant));
205   }
206 
207   // Intermediate tensors with calibration values (but not scale and zero point)
208   // should return a calibrated quantized type.
209   if (is_intermediate && tensor.quantization != nullptr &&
210       !IsQuantized(tensor)) {
211     TF_ASSIGN_OR_RETURN(elem_type, GetCalibratedQuantizedType(tensor, builder));
212   }
213 
214   if (IsScalar(tensor) || (shapeless_are_scalars && tensor.shape.empty())) {
215     return RankedTensorType::get({}, elem_type);
216   }
217 
218   if (!tensor.shape_signature.empty()) {
219     llvm::SmallVector<int64_t, 4> shape(tensor.shape_signature.begin(),
220                                         tensor.shape_signature.end());
221     return RankedTensorType::get(shape, elem_type);
222   }
223 
224   if (!tensor.shape.empty()) {
225     llvm::SmallVector<int64_t, 4> shape(tensor.shape.begin(),
226                                         tensor.shape.end());
227     return RankedTensorType::get(shape, elem_type);
228   }
229 
230   return UnrankedTensorType::get(elem_type);
231 }
232 
233 // Extracts the min/max information from the tensor and creates the quant stats op.
234 // If the input `tensor` has scale/zero_point, `res` should have quantized
235 // type, so no stats op is required and nullptr is returned.
236 // If the min/max information is invalid, nullptr is returned.
237 mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b,
238                                         Value res) {
239   // If the `tensor` has scale/zero_point, it must have been quantized, so the
240   // min/max stats are only informational; ignore them.
241   if (!tensor.quantization || IsQuantized(tensor)) return nullptr;
242   // If the result isn't a float type (and thus not quantizable), the min/max is ignored.
243   if (!res.getType()
244            .cast<mlir::ShapedType>()
245            .getElementType()
246            .isa<mlir::FloatType>()) {
247     return nullptr;
248   }
249   auto mins = tensor.quantization->min;
250   auto maxs = tensor.quantization->max;
251   if (mins.size() != maxs.size() || mins.empty()) return nullptr;
252 
253   llvm::SmallVector<llvm::APFloat, 4> min_maxs;
254   min_maxs.reserve(mins.size() * 2);
255   for (int i = 0, end = mins.size(); i < end; ++i) {
256     llvm::APFloat min(mins[i]);
257     llvm::APFloat max(maxs[i]);
258     min_maxs.push_back(min);
259     min_maxs.push_back(max);
260   }
261   // The layer stats contain only the first min/max pairs.
262   mlir::ElementsAttr layer_stats = mlir::DenseFPElementsAttr::get(
263       mlir::RankedTensorType::get({2}, b.getF32Type()),
264       {min_maxs[0], min_maxs[1]});
265   mlir::ElementsAttr axis_stats;
266   mlir::IntegerAttr axis;
267   if (mins.size() > 1) {
268     llvm::SmallVector<int64_t, 4> axis_stats_shape{
269         static_cast<int64_t>(mins.size()), 2};
270     axis_stats = mlir::DenseFPElementsAttr::get(
271         mlir::RankedTensorType::get(axis_stats_shape, b.getF32Type()),
272         min_maxs);
273     // TODO(fengliuai): this quantization dimension isn't correct.
274     axis = b.getI64IntegerAttr(tensor.quantization->quantized_dimension);
275   }
276   return b.create<mlir::quant::StatisticsOp>(b.getUnknownLoc(), res,
277                                              layer_stats, axis_stats, axis);
278 }
279 
280 // Returns true if this is a basic LSTM op.
281 bool IsBasicLSTMOp(tflite::BuiltinOptionsUnion op_union) {
282   if (const auto* op = op_union.AsLSTMOptions()) {
283     return op->kernel_type == tflite::LSTMKernelType_BASIC;
284   } else {
285     return false;
286   }
287 }
288 
289 // Gets the MLIR op name with the dialect name for the flatbuffer operator.
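// For example, BuiltinOperator_ADD maps to the MLIR op name "tfl.add".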
290 StatusOr<std::string> GetMlirOpName(const tflite::OperatorT& op,
291                                     const tflite::OperatorCodeT& op_code) {
292   if (IsBasicLSTMOp(op.builtin_options)) {
293     return std::string("tfl.basic_lstm");
294   }
295 
296   auto builtin_code = tflite::GetBuiltinCode(&op_code);
297   if (builtin_code == tflite::BuiltinOperator_CUSTOM) {
298     return std::string("tfl.custom");
299   }
300   if (builtin_code == tflite::BuiltinOperator_IF) {
301     return std::string("tf.If");
302   }
303   if (builtin_code == tflite::BuiltinOperator_WHILE) {
304     return std::string("tfl.while");
305   }
306 
307   llvm::StringRef op_name(tflite::EnumNameBuiltinOperator(builtin_code));
308   return llvm::Twine("tfl.", op_name.lower()).str();
309 }
310 
311 // The buffers in TFLite flatbuffers have their contents stored as a vector of
312 // bytes that represent little-endian values.
313 // The element read size is derived from T, which allows reading both float16
314 // and float32 values without a case split.
315 template <typename T>
316 std::vector<T> ReadAsLittleEndian(ArrayRef<uint8_t> bytes) {
317   std::vector<T> ret;
318   size_t read_size = sizeof(T);
319   int bytes_len = bytes.size();
320   assert(bytes_len % read_size == 0);
321 
322   int elem_count = bytes_len / read_size;
323   ret.reserve(elem_count);
324 
325   const char* data_ptr = reinterpret_cast<const char*>(bytes.data());
326   for (int i = 0; i < elem_count; i++) {
327     ret.push_back(
328         llvm::support::endian::readNext<T, llvm::support::little,
329                                         llvm::support::unaligned>(data_ptr));
330   }
331   return ret;
332 }
333 
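// Converts a TFLite constant tensor and its backing buffer into a
// tensorflow::TensorProto carrying the dtype, shape, and raw tensor content.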
334 tensorflow::TensorProto ConvertTfliteConstTensor(
335     const tflite::TensorT& tensor, const std::vector<uint8_t>& buffer) {
336   tensorflow::TensorProto ret;
337   ret.set_dtype(TflTypeToTfType(tensor.type));
338 
339   tensorflow::TensorShapeProto* shape = ret.mutable_tensor_shape();
340   shape->set_unknown_rank(false);
341   for (auto dim : tensor.shape) {
342     shape->add_dim()->set_size(int64_t{dim});
343   }
344   std::string content;
345   content.assign(reinterpret_cast<const char*>(buffer.data()), buffer.size());
346   ret.set_tensor_content(content);
347   return ret;
348 }
349 
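// Converts a little-endian buffer of 16-, 32-, or 64-bit floats into a
// DenseElementsAttr of the given shaped type.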
350 StatusOr<mlir::ElementsAttr> ConvertFloatBuffer(
351     mlir::RankedTensorType shaped_type, mlir::FloatType elem_type,
352     const std::vector<uint8_t>& buffer) {
353   size_t bytes_len = buffer.size();
354 
355   // The bytes of floats are stored little-endian.
356   switch (elem_type.getWidth()) {
357     case 16: {
358       assert(bytes_len % 2 == 0);
359       int elem_count = bytes_len / 2;
360       std::vector<llvm::APFloat> values;
361       values.reserve(elem_count);
362 
363       const char* data = reinterpret_cast<const char*>(buffer.data());
364       auto& semantics = elem_type.getFloatSemantics();
365 
366       for (int i = 0; i < elem_count; i++) {
367         uint16_t bit_repr =
368             llvm::support::endian::readNext<uint16_t, llvm::support::little,
369                                             llvm::support::unaligned>(data);
370         llvm::APInt int_repr(16, bit_repr);
371         values.emplace_back(semantics, int_repr);
372       }
373 
374       return DenseElementsAttr::get(shaped_type, values);
375     }
376     case 32: {
377       assert(bytes_len % 4 == 0);
378       int elem_count = bytes_len / 4;
379       std::vector<float> values;
380       values.reserve(elem_count);
381 
382       const char* data = reinterpret_cast<const char*>(buffer.data());
383 
384       for (int i = 0; i < elem_count; i++) {
385         uint32_t bit_repr =
386             llvm::support::endian::readNext<uint32_t, llvm::support::little,
387                                             llvm::support::unaligned>(data);
388         values.push_back(absl::bit_cast<float>(bit_repr));
389       }
390       return DenseElementsAttr::get(shaped_type, ArrayRef<float>(values));
391     }
392     case 64: {
393       assert(bytes_len % 8 == 0);
394       int elem_count = bytes_len / 8;
395       std::vector<double> values;
396       values.reserve(elem_count);
397 
398       const char* data = reinterpret_cast<const char*>(buffer.data());
399 
400       for (int i = 0; i < elem_count; i++) {
401         uint64_t bit_repr =
402             llvm::support::endian::readNext<uint64_t, llvm::support::little,
403                                             llvm::support::unaligned>(data);
404         values.push_back(absl::bit_cast<double>(bit_repr));
405       }
406       return DenseElementsAttr::get(shaped_type, ArrayRef<double>(values));
407     }
408   }
409   return errors::InvalidArgument("unsupported bit width", elem_type.getWidth());
410 }
411 
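// Converts a little-endian buffer of integer (or quantized storage) values into
// a DenseElementsAttr; quantized element types are read using their storage type.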
412 StatusOr<mlir::ElementsAttr> ConvertIntBuffer(
413     mlir::RankedTensorType shaped_type, mlir::Type elem_type,
414     const std::vector<uint8_t>& buffer) {
415   unsigned bit_width;
416   if (auto itype = elem_type.dyn_cast<mlir::IntegerType>()) {
417     bit_width = itype.getWidth();
418   } else if (auto qtype = elem_type.dyn_cast<QuantizedType>()) {
419     bit_width = qtype.getStorageTypeIntegralWidth();
420     shaped_type = mlir::RankedTensorType::get(shaped_type.getShape(),
421                                               qtype.getStorageType());
422   } else {
423     return errors::InvalidArgument("unsupported integer constant type");
424   }
425 
426   switch (bit_width) {
427     case 1: {
428       // vector<bool> doesn't convert to an ArrayRef
429       llvm::SmallVector<bool, 8> values;
430       values.reserve(buffer.size());
431       for (auto b : buffer) {
432         values.emplace_back(b != 0);
433       }
434       return DenseElementsAttr::get(shaped_type, ArrayRef<bool>(values));
435     }
436     case 8: {
437       return DenseElementsAttr::get(shaped_type, ArrayRef<uint8_t>(buffer));
438     }
439     case 16: {
440       auto values = ReadAsLittleEndian<uint16_t>(buffer);
441       return DenseElementsAttr::get(shaped_type, ArrayRef<uint16_t>(values));
442     }
443     case 32: {
444       auto values = ReadAsLittleEndian<uint32_t>(buffer);
445       return DenseElementsAttr::get(shaped_type, ArrayRef<uint32_t>(values));
446     }
447     case 64: {
448       auto values = ReadAsLittleEndian<uint64_t>(buffer);
449       return DenseElementsAttr::get(shaped_type, ArrayRef<uint64_t>(values));
450     }
451     default:
452       return errors::Unimplemented("Cannot handle bit width ", bit_width);
453   }
454 }
455 
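// Builds a TFL external const op that records the flatbuffer buffer index
// instead of embedding the constant data into the IR.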
456 StatusOr<Operation*> BuildExternalConstOp(const tflite::TensorT& tensor,
457                                           int32_t buffer_index,
458                                           OpBuilder builder, Location loc) {
459   TF_ASSIGN_OR_RETURN(auto type, GetTensorType(tensor, builder,
460                                                /*shapeless_are_scalars=*/true,
461                                                /*is_constant=*/true));
462   auto shaped_type = type.dyn_cast<mlir::RankedTensorType>();
463   if (!shaped_type) {
464     return errors::Internal("Constant doesn't have a shape");
465   }
466   auto op = builder.create<tfl::ExternalConstOp>(
467       loc, shaped_type, builder.getI32IntegerAttr(buffer_index));
468   return op.getOperation();
469 }
470 
471 // Gets a constant splat for the given type. Requires the type to be a statically
472 // shaped RankedTensorType. `unique_index` is used as the unique splat value for
473 // the attribute.
474 static mlir::ElementsAttr GetSplat(RankedTensorType type, int unique_index,
475                                    OpBuilder builder) {
476   mlir::Type element_ty = getElementTypeOrSelf(type);
477 
478   if (element_ty.isSignlessInteger())
479     return DenseElementsAttr::get(
480         type, builder.getIntegerAttr(element_ty, unique_index));
481 
482   if (element_ty.isa<mlir::FloatType>())
483     return DenseElementsAttr::get(
484         type, builder.getFloatAttr(element_ty, unique_index));
485 
486   if (auto qtype = element_ty.dyn_cast<QuantizedType>()) {
487     mlir::RankedTensorType new_type =
488         RankedTensorType::get(type.getShape(), qtype.getStorageType());
489     return DenseElementsAttr::get(
490         new_type, builder.getIntegerAttr(qtype.getStorageType(), unique_index));
491   }
492   llvm_unreachable("unhandled element type");
493 }
494 
495 // TODO(b/172664358): Create a new op instead of reusing the constant op.
496 // Creates a constant op to represent a stateful variable. The function-static
497 // variable `stateful_variable_idx` is used as a unique value for each constant
498 // so that they are not CSE'd. `tensor` is the flatbuffer tensor data structure.
499 // `shaped_type` is the ShapedType for the const op.
500 Operation* BuildVariableOp(const tflite::TensorT& tensor,
501                            mlir::RankedTensorType shaped_type,
502                            OpBuilder builder, Location loc) {
503   static int stateful_variable_idx = 0;
504   mlir::ElementsAttr value =
505       GetSplat(shaped_type, stateful_variable_idx++, builder);
506   if (IsQuantized(tensor)) {
507     auto op = builder.create<tfl::QConstOp>(
508         loc, mlir::TypeAttr::get(shaped_type), value);
509     return op.getOperation();
510   }
511   auto op = builder.create<tfl::ConstOp>(loc, value);
512   if (tensor.quantization && !tensor.quantization->min.empty()) {
513     if (auto stats_op =
514             ConvertMinMaxToStatsOp(tensor, builder, op.getResult())) {
515       return stats_op;
516     }
517   }
518   return op.getOperation();
519 }
520 
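// Builds a constant op from the tensor and its buffer contents, dispatching on
// the element type (float, integer/quantized, string, or complex/TF types);
// stateful variables are handled by BuildVariableOp.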
521 StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
522                                   const std::vector<uint8_t>& buffer,
523                                   bool is_variable, OpBuilder builder,
524                                   Location loc) {
525   TF_ASSIGN_OR_RETURN(auto type, GetTensorType(tensor, builder,
526                                                /*shapeless_are_scalars=*/true,
527                                                /*is_constant=*/true));
528   auto shaped_type = type.dyn_cast<mlir::RankedTensorType>();
529   if (!shaped_type) {
530     return errors::Internal("Constant doesn't have a shape");
531   }
532 
533   auto elem_type = shaped_type.getElementType();
534 
535   mlir::ElementsAttr value;
536   if (is_variable) {
537     return BuildVariableOp(tensor, shaped_type, builder, loc);
538   } else if (auto float_type = elem_type.dyn_cast<mlir::FloatType>()) {
539     TF_ASSIGN_OR_RETURN(value,
540                         ConvertFloatBuffer(shaped_type, float_type, buffer));
541   } else if (elem_type.isa<mlir::IntegerType, QuantizedType>()) {
542     TF_ASSIGN_OR_RETURN(value,
543                         ConvertIntBuffer(shaped_type, elem_type, buffer));
544   } else if (elem_type.isa<mlir::TF::StringType>()) {
545     tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer);
546     std::vector<llvm::StringRef> refs;
547     refs.reserve(repr.string_val_size());
548 
549     for (const auto& ref : repr.string_val())
550       refs.push_back({ref.data(), ref.size()});
551 
552     value = mlir::DenseStringElementsAttr::get(shaped_type, refs);
553   } else if (elem_type.isa<mlir::ComplexType, mlir::TF::TensorFlowType>()) {
554     auto dialect = elem_type.getContext()->getLoadedDialect("tf");
555     tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer);
556     std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
557 
558     value = mlir::OpaqueElementsAttr::get(dialect, shaped_type, mangled);
559   } else {
560     return errors::Unimplemented("Constant of unsupported type");
561   }
562 
563   if (IsQuantized(tensor)) {
564     auto op = builder.create<tfl::QConstOp>(
565         loc, mlir::TypeAttr::get(shaped_type), value);
566     return op.getOperation();
567   }
568   auto op = builder.create<tfl::ConstOp>(loc, value);
569   return op.getOperation();
570 }
571 
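// Converts the subgraph indices in If/While builtin options into symbol
// reference attributes ("then_branch"/"else_branch" for If, "cond"/"body" for While).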
572 llvm::SmallVector<mlir::NamedAttribute, 4> ConvertSubgraphIdxsToFunctionAttrs(
573     tflite::BuiltinOptionsUnion options,
574     const std::vector<std::string>& func_names, Builder builder) {
575   if (auto* opts = options.AsIfOptions()) {
576     uint32_t then_idx = opts->then_subgraph_index;
577     auto then_attr = builder.getSymbolRefAttr(func_names.at(then_idx));
578     uint32_t else_idx = opts->else_subgraph_index;
579     auto else_attr = builder.getSymbolRefAttr(func_names.at(else_idx));
580 
581     return {builder.getNamedAttr("then_branch", then_attr),
582             builder.getNamedAttr("else_branch", else_attr),
583             // TODO(b/139667752): Analyze statelessness correctly
584             builder.getNamedAttr("is_stateless", builder.getBoolAttr(false))};
585   }
586   if (auto* opts = options.AsWhileOptions()) {
587     uint32_t cond_idx = opts->cond_subgraph_index;
588     auto cond_attr = builder.getSymbolRefAttr(func_names.at(cond_idx));
589     uint32_t body_idx = opts->body_subgraph_index;
590     auto body_attr = builder.getSymbolRefAttr(func_names.at(body_idx));
591 
592     return {builder.getNamedAttr("cond", cond_attr),
593             builder.getNamedAttr("body", body_attr)};
594   }
595   return {};
596 }
597 
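// Adds the five LSTM intermediate tensor types as named type attributes on the
// op state; these carry quantization information for the LSTM internals.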
598 Status AddOpIntermediatesForLstm(
599     const tflite::OperatorT& op,
600     const std::vector<mlir::TensorType>& intermediate_types,
601     OperationState& op_state, Location loc, OpBuilder& builder) {
602   if (!op.intermediates.empty()) {
603     if (op.intermediates.size() != 5) {
604       auto err = errors::InvalidArgument(
605           "operator has intermediate tensors but the number of them is not "
606           "five.");
607       return emitError(loc, err.ToString()), err;
608     }
609     // Create intermediate value
610 
611     const llvm::SmallVector<llvm::StringRef, 5> kIntermediateNames = {
612         "input_to_input_intermediate", "input_to_forget_intermediate",
613         "input_to_cell_intermediate", "input_to_output_intermediate",
614         "effective_hidden_scale_intermediate"};
615     for (auto type_and_name :
616          llvm::zip(intermediate_types, kIntermediateNames)) {
617       mlir::TypeAttr type_attr =
618           mlir::TypeAttr::get(std::get<0>(type_and_name));
619       auto named_attr =
620           builder.getNamedAttr(std::get<1>(type_and_name), type_attr);
621       op_state.addAttribute(named_attr.first, named_attr.second);
622     }
623   }
624   return Status::OK();
625 }
626 
627 // TODO(krzysd) Handle function calls
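// Converts a single TFLite operator into an MLIR operation: resolves operands
// from `vals_map`, computes result types, special-cases quantize/reshape/LSTM/
// while, and attaches attributes converted from the builtin or custom options.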
628 StatusOr<Operation*> ConvertOp(
629     const tflite::OperatorT& op, const std::vector<Value>& vals_map,
630     const std::vector<mlir::TensorType>& intermediate_types,
631     Value optional_arg_marker,
632     const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& op_codes,
633     const std::vector<std::string>& func_names,
634     const std::vector<std::unique_ptr<tflite::TensorT>>& tensors, Location loc,
635     OpBuilder builder) {
636   llvm::SmallVector<Value, 4> operands;
637   llvm::SmallVector<mlir::Type, 2> outputTypes;
638 
639   if (op.outputs.empty()) {
640     auto err = errors::InvalidArgument("operator with no outputs");
641     return emitError(loc, err.ToString()), err;
642   }
643 
644   const tflite::OperatorCodeT& op_code = *op_codes.at(op.opcode_index);
645 
646   TF_ASSIGN_OR_RETURN(const std::string op_name, GetMlirOpName(op, op_code));
647 
648   OperationState op_state(loc, op_name);
649 
650   for (auto input_num : op.inputs) {
651     if (input_num == -1) {
652       assert(optional_arg_marker != nullptr);
653       op_state.addOperands({optional_arg_marker});
654     } else {
655       op_state.addOperands({vals_map.at(input_num)});
656     }
657   }
658 
659   for (auto output_num : op.outputs) {
660     auto& tensor = *tensors.at(output_num);
661     auto type_or_err = GetTensorType(tensor, builder);
662     if (!type_or_err.ok()) {
663       return emitError(loc, type_or_err.status().ToString()),
664              type_or_err.status();
665     }
666     auto type = type_or_err.ConsumeValueOrDie();
667 
668     if (op_name == "tfl.quantize") {
669       // Special case for quantize: return type must also be in qtype attribute
670       op_state.addAttribute("qtype", mlir::TypeAttr::get(type));
671     } else if (op_name == "tfl.reshape" && op_state.operands.size() == 1) {
672       // Special case for reshape: the second operand is optional in the old
673       // converter and kernel, so we create the second operand, which is
674       // required by the new converter, from the reshape op's options.
675       auto new_shape = op.builtin_options.AsReshapeOptions()->new_shape;
676       auto shape_type = RankedTensorType::get(
677           {static_cast<int64_t>(new_shape.size())}, builder.getIntegerType(32));
678 
679       mlir::SmallVector<mlir::Attribute, 4> shape;
680       for (auto s : new_shape) {
681         shape.push_back(builder.getI32IntegerAttr(static_cast<int32_t>(s)));
682       }
683       auto output_shape = DenseElementsAttr::get(shape_type, shape);
684       auto shape_op = builder.create<tfl::ConstOp>(loc, output_shape);
685       op_state.addOperands({shape_op});
686     }
687 
688     op_state.addTypes({type});
689   }
690 
691   // Since the last several tensors of a TFL op could be optional, the number
692   // of input operands can vary. Get the min/max number of operands from the
693   // TFLite op name.
694   // Also, since the code above special-cases the `tfl.reshape` op and adds an
695   // additional input, this block must come after it.
696   llvm::MinMax input_min_max = mlir::OperandNumbersMinMax(op_name);
697   int input_max_num = input_min_max.Max;
698   int op_input_num = op_state.operands.size();
699   if (input_max_num != 0 && input_max_num > op_input_num) {
700     // If the number of current inputs is less than what the op definition
701     // requires, fill in the rest with `none` values.
702     llvm::SmallVector<Value, 4> none_operands(
703         input_max_num - op_input_num,
704         builder.create<mlir::ConstantOp>(loc, builder.getNoneType(),
705                                          builder.getUnitAttr()));
706     op_state.addOperands(ArrayRef<Value>(none_operands));
707   }
708 
709   if (op_name == "tfl.lstm") {
710     // TODO(b/147587779): add the right region if region is empty.
711     op_state.addRegion();
712     TF_CHECK_OK(AddOpIntermediatesForLstm(op, intermediate_types, op_state, loc,
713                                           builder));
714   }
715   if (op_name == "tfl.while") {
716     // Adds two empty regions for "tfl.while". We will fill the regions after
717     // creating the callee functions because the "tfl.while" input/output types
718     // may differ from the callee function types, and the call ops need to
719     // match the callee function types.
720     op_state.addRegion();
721     op_state.addRegion();
722   }
723   if (op_name == "tfl.unidirectional_sequence_lstm") {
724     TF_CHECK_OK(AddOpIntermediatesForLstm(op, intermediate_types, op_state, loc,
725                                           builder));
726   }
727   if (op_name == "tfl.reshape") {
728     // Flatten the shape operand of reshape ops when it has more than one dimension.
729     mlir::DenseIntElementsAttr shape_attr;
730     if (matchPattern(op_state.operands[1], m_Constant(&shape_attr))) {
731       auto shape_ty =
732           op_state.operands[1].getType().dyn_cast<RankedTensorType>();
733       if (shape_ty != nullptr && shape_ty.hasRank() && shape_ty.getRank() > 1) {
734         llvm::SmallVector<mlir::Attribute, 4> shape;
735         int32_t dim_size = 0;
736         for (const auto& dim : llvm::enumerate(shape_attr.getIntValues())) {
737           const int64_t size = dim.value().getSExtValue();
738           shape.push_back(
739               builder.getI32IntegerAttr(static_cast<int32_t>(size)));
740           ++dim_size;
741         }
742         auto shape_type = RankedTensorType::get(
743             {static_cast<int32_t>(dim_size)}, builder.getIntegerType(32));
744         auto output_shape = mlir::DenseElementsAttr::get(shape_type, shape);
745         auto shape_op = builder.create<tfl::ConstOp>(loc, output_shape);
746         op_state.operands[1] = shape_op;
747       }
748     }
749   }
750 
751   llvm::SmallVector<mlir::NamedAttribute, 2> attrs;
752   auto builtin_code = tflite::GetBuiltinCode(&op_code);
753   if (builtin_code == tflite::BuiltinOperator_CUSTOM) {
754     auto status = mlir::CustomOptionsToAttributes(
755         op_code.custom_code, op.custom_options, builder, loc, &attrs);
756     if (!status.ok()) {
757       return emitError(loc, status.ToString()), status;
758     }
759   } else {
760     mlir::BuiltinOptionsToAttributes(op.builtin_options, builder, attrs);
761   }
762   op_state.addAttributes(attrs);
763 
764   // Handle the conversion from subgraph index to functions for If and While. We
765   // will add CallOps in the region to call the functions later for While.
766   auto function_ref_attrs = ConvertSubgraphIdxsToFunctionAttrs(
767       op.builtin_options, func_names, builder);
768   op_state.addAttributes(function_ref_attrs);
769 
770   return builder.createOperation(op_state);
771 }
772 
773 // Returns indices of the given tensors in the subgraph. Returns error if a
774 // tensor name cannot be found in the subgraph.
775 StatusOr<std::vector<int>> GetTensorIndices(
776     const tflite::SubGraphT& subgraph,
777     const std::vector<std::string>& tensor_names) {
778   absl::flat_hash_map<std::string, int> name_to_index;
779   for (auto index_and_tensor : llvm::enumerate(subgraph.tensors)) {
780     name_to_index[index_and_tensor.value()->name] = index_and_tensor.index();
781   }
782 
783   std::vector<int> indices;
784   indices.reserve(tensor_names.size());
785 
786   for (const auto& name : tensor_names) {
787     auto found = name_to_index.find(name);
788     if (found != name_to_index.end()) {
789       indices.push_back(found->second);
790     } else {
791       return errors::InvalidArgument("could not find tensor in subgraph: ",
792                                      name);
793     }
794   }
795 
796   return indices;
797 }
798 
799 // Given a list of tensor indices, returns a string of concatenated tensor names
800 // wrapped in a NamedAttribute.
801 template <typename ContainerType>
802 mlir::NamedAttribute BuildTFEntryFunctionAttribute(
803     const tflite::SubGraphT& subgraph, Builder* builder, const std::string name,
804     const ContainerType indices) {
805   auto tensor_names = llvm::map_range(
806       indices, [&](int i) { return subgraph.tensors.at(i)->name; });
807   return builder->getNamedAttr(
808       name, builder->getStringAttr(llvm::join(tensor_names, ",")));
809 }
810 
811 // Traverses the subgraph from output_indices to input_indices and returns the
812 // set of ops that are visited.
813 StatusOr<absl::flat_hash_set<const tflite::OperatorT*>> PruneSubgraph(
814     const tflite::SubGraphT& subgraph, ArrayRef<int32_t> input_indices,
815     ArrayRef<int32_t> output_indices) {
816   // Create a map from tensor index to defining op.
817   absl::flat_hash_map<int32_t, const tflite::OperatorT*> defining_op;
818   for (const auto& op : subgraph.operators) {
819     for (int32_t output : op->outputs) {
820       if (!llvm::is_contained(input_indices, output)) {
821         defining_op[output] = op.get();
822       }
823     }
824   }
825 
826   std::vector<const tflite::OperatorT*> queue;
827   for (int32_t output : output_indices) {
828     if (auto& op = defining_op[output]) {
829       queue.push_back(op);
830     }
831   }
832 
833   // Traverse the graph towards inputs.
834   absl::flat_hash_set<const tflite::OperatorT*> visited;
835   while (!queue.empty()) {
836     const tflite::OperatorT* op = queue.back();
837     queue.pop_back();
838     if (!visited.insert(op).second) {
839       // The node has already been visited.
840       continue;
841     }
842 
843     for (int32_t input : op->inputs) {
844       // Input tensor may not have a defining op in case it is a subgraph input
845       // or a constant tensor.
846       if (auto& op = defining_op[input]) {
847         queue.push_back(op);
848       }
849     }
850   }
851 
852   return visited;
853 }
854 
855 // Adjusts the func op according to information gathered across multiple ops.
856 static StatusOr<FuncOp> PostProcessFuncOp(FuncOp func) {
857   OpBuilder builder(func);
858   // When a quantized constant is imported, its quantization parameter is set
859   // to narrow range. Here we revert it to the full range if the user doesn't
860   // require narrow range.
861   func.walk([&](tfl::QConstOp cst) {
862     Value value = cst.getResult();
863     Value full_range_const = value;
864     for (auto& use : value.getUses()) {
865       Operation* user = use.getOwner();
866       if (user->hasTrait<mlir::OpTrait::IsTerminator>()) return;
867       auto qtype = mlir::quant::UniformQuantizedType::getQuantizedElementType(
868           value.getType());
869       // Only the 8-bit constants are imported with narrow range.
870       if (!qtype || qtype.getStorageTypeIntegralWidth() != 8) return;
871 
872       auto affine_user = llvm::dyn_cast<mlir::AffineQuantizedOpInterface>(user);
873       if (affine_user &&
874           affine_user.GetAffineOperandIndex() == use.getOperandNumber() &&
875           affine_user.RequiredNarrowRangeAffineOperand())
876         return;
877 
878       // Create a full-range quantized constant.
879       if (full_range_const == value) {
880         mlir::quant::QuantizedType new_qtype;
881         if (auto per_axis =
882                 qtype.dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
883           new_qtype = mlir::quant::UniformQuantizedPerAxisType::get(
884               per_axis.getFlags(), per_axis.getStorageType(),
885               per_axis.getExpressedType(), per_axis.getScales(),
886               per_axis.getZeroPoints(), per_axis.getQuantizedDimension(),
887               per_axis.getStorageTypeMin() - 1, per_axis.getStorageTypeMax());
888         } else if (auto per_tensor =
889                        qtype.dyn_cast<mlir::quant::UniformQuantizedType>()) {
890           new_qtype = mlir::quant::UniformQuantizedType::get(
891               per_tensor.getFlags(), per_tensor.getStorageType(),
892               per_tensor.getExpressedType(), per_tensor.getScale(),
893               per_tensor.getZeroPoint(), per_tensor.getStorageTypeMin() - 1,
894               per_tensor.getStorageTypeMax());
895         } else {
896           return;
897         }
898         auto new_output_type = new_qtype.castFromExpressedType(
899             mlir::quant::UniformQuantizedType::castToExpressedType(
900                 value.getType()));
901         builder.setInsertionPointAfter(cst.getOperation());
902         auto new_op = builder.create<tfl::QConstOp>(
903             cst.getLoc(), new_output_type, mlir::TypeAttr::get(new_output_type),
904             cst.valueAttr());
905         full_range_const = new_op.output();
906       }
907       use.set(full_range_const);
908     }
909     if (cst.use_empty()) cst.erase();
910   });
911   return func;
912 }
913 
914 // Builds a FuncOp from a tflite SubGraph.
915 // The buffers are taken directly
916 // from the deserialized flatbuffer as we do not have the type information to
917 // interpret them until this point. The base_loc parameter is the location of
918 // the flatbuffer as a whole (usually a file). The is_entry_point flag
919 // controls whether shapeless types are treated as scalars. If
920 // ordered_output_arrays is not empty, then the imported MLIR function will only
921 // return nodes in ordered_output_arrays, in the same order.
922 StatusOr<FuncOp> ConvertSubgraph(
923     const tflite::SubGraphT& subgraph, llvm::StringRef name,
924     const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& op_codes,
925     const std::vector<std::string>& func_names,
926     const std::vector<std::unique_ptr<tflite::BufferT>>& buffers,
927     Location base_loc, Builder builder, bool is_entry_point,
928     bool use_external_constant,
929     const std::vector<std::string>& ordered_input_arrays,
930     const std::vector<std::string>& ordered_output_arrays,
931     bool experimental_prune_unreachable_nodes_unconditionally) {
932   llvm::SmallVector<mlir::Type, 2> ret_types;
933   llvm::SmallVector<mlir::Type, 4> input_types;
934 
935   auto func_loc = mlir::NameLoc::get(builder.getIdentifier(name), base_loc);
936 
937   std::vector<int> func_inputs = subgraph.inputs;
938   if (is_entry_point && !ordered_input_arrays.empty()) {
939     if (!experimental_prune_unreachable_nodes_unconditionally) {
940       // TODO(b/149922113): Resolve input-arrays/pruning flags interaction.
941       return errors::InvalidArgument(
942           "input-arrays should be used with experimental pruning flag");
943     }
944     TF_ASSIGN_OR_RETURN(func_inputs,
945                         GetTensorIndices(subgraph, ordered_input_arrays));
946   }
947 
948   for (int input : func_inputs) {
949     auto& tensor = *subgraph.tensors.at(input);
950     // TODO(b/138222071) Graph inputs must have static shape per the exporter,
951     // but we cannot differentiate scalars from unranked tensors.
952     // Here we reverse the default assumption that shape = [] means unranked
953     // when processing main().
954     auto type_or_err = GetTensorType(tensor, builder,
955                                      /*shapeless_are_scalars=*/is_entry_point,
956                                      /*is_constant=*/false);
957     if (!type_or_err.ok()) {
958       emitError(func_loc, "error reading argument types")
959           << type_or_err.status().ToString();
960       return type_or_err.status();
961     }
962     auto type = type_or_err.ConsumeValueOrDie();
963     input_types.push_back(type);
964   }
965 
966   llvm::SmallVector<bool, 16> is_op_output(subgraph.tensors.size(), false);
967   for (auto& op : subgraph.operators) {
968     for (auto output : op->outputs) {
969       is_op_output[output] = true;
970     }
971   }
972 
973   std::vector<int> func_outputs = subgraph.outputs;
974   if (is_entry_point && !ordered_output_arrays.empty()) {
975     TF_ASSIGN_OR_RETURN(func_outputs,
976                         GetTensorIndices(subgraph, ordered_output_arrays));
977   }
978 
979   for (auto output : func_outputs) {
980     const bool is_func_input = std::find(func_inputs.begin(), func_inputs.end(),
981                                          output) != func_inputs.end();
982     bool is_constant = !is_op_output[output] && !is_func_input;
983     // There are 2 cases in which a tensor is a scalar when it doesn't have a
984     // shape in the flatbuffer:
985     // 1. `is_constant` = true, meaning this tensor is created from a constant op.
986     // 2. `is_func_input` = true and `is_entry_point` = true, which means this
987     // tensor is a function input and the function input type is a scalar tensor.
988     const bool shapeless_is_scalar =
989         is_constant || (is_func_input && is_entry_point);
990     auto type_or_err = GetTensorType(*subgraph.tensors.at(output), builder,
991                                      shapeless_is_scalar,
992                                      /*is_constant=*/is_constant);
993     if (!type_or_err.ok()) {
994       emitError(func_loc, "error reading return types")
995           << type_or_err.status().ToString();
996       return type_or_err.status();
997     }
998     auto type = type_or_err.ConsumeValueOrDie();
999     ret_types.push_back(type);
1000   }
1001   auto func_type = builder.getFunctionType(input_types, ret_types);
1002 
1003   // Construct function object
1004   auto func = FuncOp::create(func_loc, name, func_type, /* attrs= */ {});
1005   func.addEntryBlock();
1006   auto& body = func.getBody();
1007   OpBuilder op_builder{body};
1008 
1009   std::vector<Value> vals_map(subgraph.tensors.size(), nullptr);
1010   Value maybe_optional_arg_marker = nullptr;
1011 
1012   // Get or construct MLIR values for each input
1013   for (int i = 0, e = func_inputs.size(); i < e; i++) {
1014     auto input_tensor = func_inputs[i];
1015     const auto& tensor = *subgraph.tensors.at(input_tensor);
1016     auto loc = TensorLoc(tensor, builder, base_loc);
1017     if (vals_map[input_tensor]) {
1018       auto err = errors::FailedPrecondition("duplicate input arguments");
1019       return emitError(loc, err.ToString()), err;
1020     }
1021     Value input_value = func.getArgument(i);
1022 
1023     // If the `tensor` has min/max and doesn't have scale/zero_point
1024     // information, a stats op is created to use the input_value, then the
1025     // `tensor` should be mapped to the result of this new stats op.
1026     if (auto stats_op =
1027             ConvertMinMaxToStatsOp(tensor, op_builder, input_value)) {
1028       vals_map[input_tensor] = stats_op->getResult(0);
1029     } else {
1030       vals_map[input_tensor] = input_value;
1031     }
1032   }
1033 
1034   // Set tf.entry_function attribute
1035   if (is_entry_point) {
1036     llvm::SmallVector<mlir::NamedAttribute, 2> attributes;
1037     if (!func_inputs.empty()) {
1038       attributes.push_back(BuildTFEntryFunctionAttribute(
1039           subgraph, &builder, "inputs", func_inputs));
1040     }
1041     if (!func_outputs.empty()) {
1042       attributes.push_back(BuildTFEntryFunctionAttribute(
1043           subgraph, &builder, "outputs", func_outputs));
1044     }
1045     func->setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
1046   } else {
1047     func.setPrivate();
1048   }
1049 
1050   absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
1051   if (experimental_prune_unreachable_nodes_unconditionally) {
1052     TF_ASSIGN_OR_RETURN(pruned_subgraph_ops,
1053                         PruneSubgraph(subgraph, func_inputs, func_outputs));
1054   }
1055 
1056   // Construct MLIR operators from TFLite operators
1057   for (auto& op : subgraph.operators) {
1058     if (experimental_prune_unreachable_nodes_unconditionally &&
1059         !pruned_subgraph_ops.contains(op)) {
1060       continue;
1061     }
1062 
1063     for (auto input_num : op->inputs) {
1064       // The operators in a graph are topologically sorted
1065       // and so if no previous operation has produced a tensor
1066       // it must be a constant.
1067       if (input_num == -1) {
1068         if (maybe_optional_arg_marker == nullptr) {
1069           maybe_optional_arg_marker =
1070               op_builder
1071                   .create<mlir::ConstantOp>(base_loc, builder.getNoneType(),
1072                                             builder.getUnitAttr())
1073                   .getResult();
1074         }
1075       } else if (!vals_map.at(input_num)) {
1076         auto& const_tensor = *subgraph.tensors[input_num];
1077         auto const_loc = TensorLoc(const_tensor, builder, base_loc);
1078         auto op_or_err =
1079             use_external_constant
1080                 ? BuildExternalConstOp(const_tensor, const_tensor.buffer,
1081                                        op_builder, const_loc)
1082                 : BuildConstOp(const_tensor, buffers[const_tensor.buffer]->data,
1083                                const_tensor.is_variable, op_builder, const_loc);
1084         if (!op_or_err.ok()) {
1085           return emitError(const_loc, op_or_err.status().ToString()),
1086                  op_or_err.status();
1087         }
1088         vals_map[input_num] = op_or_err.ValueOrDie()->getResult(0);
1089       }
1090     }
1091 
1092     // Intermediate tensors for LSTMs are used to carry quantization range
1093     // in their types, so we only need to extract their types.
1094     std::vector<mlir::TensorType> intermediate_types;
1095     intermediate_types.reserve(5);
1096     for (auto intermediate : op->intermediates) {
1097       TF_ASSIGN_OR_RETURN(
1098           auto type, GetTensorType(*subgraph.tensors[intermediate], builder,
1099                                    /*shapeless_are_scalars=*/true,
1100                                    /*is_constant=*/false,
1101                                    /*is_intermediate=*/true));
1102       intermediate_types.emplace_back(type);
1103     }
1104 
1105     // The NameLoc corresponding to the name of the first output tensor
1106     auto op_loc =
1107         op->outputs.empty()
1108             ? base_loc
1109             : TensorLoc(*subgraph.tensors[op->outputs[0]], builder, base_loc);
1110     // If there's an optional argument, maybe_optional_arg_marker has been set
1111     // to a valid Value
1112     TF_ASSIGN_OR_RETURN(
1113         auto* mlir_op,
1114         ConvertOp(*op, vals_map, intermediate_types, maybe_optional_arg_marker,
1115                   op_codes, func_names, subgraph.tensors, op_loc, op_builder));
1116 
1117     // Add the results to the value maps. There are two cases: 1. the result
1118     // tensor does not have min/max values, the original op result is used
1119     // directly; 2. the result tensor has some min/max values, a stats op is
1120     // created, then the result of the stats op is used.
1121     for (auto pair : llvm::enumerate(mlir_op->getResults())) {
1122       int output_tensor_index = op->outputs[pair.index()];
1123       auto& tensor = *subgraph.tensors[output_tensor_index];
1124       if (auto stats_op =
1125               ConvertMinMaxToStatsOp(tensor, op_builder, pair.value())) {
1126         vals_map[output_tensor_index] = stats_op->getResult(0);
1127       } else {
1128         vals_map[output_tensor_index] = pair.value();
1129       }
1130     }
1131   }
1132 
1133   // Construct return values
1134   llvm::SmallVector<Value, 4> return_operands;
1135   for (auto index : func_outputs) {
1136     if (!vals_map.at(index)) {
1137       auto& const_tensor = *subgraph.tensors[index];
1138       auto const_loc = TensorLoc(const_tensor, builder, base_loc);
1139       auto op_or_err =
1140           use_external_constant
1141               ? BuildExternalConstOp(const_tensor, const_tensor.buffer,
1142                                      op_builder, const_loc)
1143               : BuildConstOp(const_tensor, buffers[const_tensor.buffer]->data,
1144                              const_tensor.is_variable, op_builder, const_loc);
1145       if (!op_or_err.ok()) {
1146         return emitError(const_loc, op_or_err.status().ToString()),
1147                op_or_err.status();
1148       }
1149       vals_map[index] = op_or_err.ValueOrDie()->getResult(0);
1150     }
1151     return_operands.push_back(vals_map[index]);
1152   }
1153 
1154   op_builder.create<mlir::ReturnOp>(base_loc, return_operands);
1155 
1156   return PostProcessFuncOp(func);
1157 }
1158 
1159 // TFLite subgraphs do not necessarily have names, though MLIR functions must
1160 // have them, so we generate a name for subgraphs that are missing one here.
1161 // Note: in TFLite, the first subgraph is the entry point, and in MLIR
1162 // representing TFLite, this entry point must be called "main".
1163 // TODO(b/131175224,b/132239787) Support multiple entry points
1164 std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) {
1165   if (index == 0) {
1166     return "main";
1167   }
1168   if (subgraph.name.empty()) {
1169     return llvm::formatv("fn_{0}", index).str();
1170   }
1171   return subgraph.name;
1172 }
1173 
1174 // Adds a CallOp in `region` to call the `func` and returns the results of
1175 // CallOp.
1176 void AddCallOpInWhileOpRegion(mlir::Region& region, mlir::FuncOp func) {
1177   OpBuilder op_builder{region};
1178   region.push_back(new mlir::Block());
1179   region.addArguments(func.getType().getInputs());
1180   op_builder.setInsertionPointToStart(&region.front());
1181   auto call_op = op_builder.create<mlir::CallOp>(
1182       region.getLoc(), func.getType().getResults(), func.sym_name(),
1183       region.getArguments());
1184   op_builder.create<mlir::TFL::YieldOp>(region.getLoc(), call_op.getResults());
1185 }
1186 
1187 // TFL::WhileOp has regions, so we add CallOp to call the FuncOp in the regions
1188 // if we have while ops.
1189 void AddRegionsForTflWhileOp(mlir::ModuleOp module) {
1190   mlir::SymbolTable symbol_table(module);
1191   module.walk([&](mlir::TFL::WhileOp while_op) {
1192     auto cond = symbol_table.lookup<mlir::FuncOp>(
1193         while_op->getAttr("cond").cast<mlir::FlatSymbolRefAttr>().getValue());
1194     AddCallOpInWhileOpRegion(while_op.cond(), cond);
1195     while_op.removeAttr("cond");
1196     auto body = symbol_table.lookup<mlir::FuncOp>(
1197         while_op->getAttr("body").cast<mlir::FlatSymbolRefAttr>().getValue());
1198     AddCallOpInWhileOpRegion(while_op.body(), body);
1199     while_op.removeAttr("body");
1200   });
1201 }
1202 }  // namespace
1203 
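// Translates a serialized TFLite flatbuffer held in `buffer` into an MLIR module:
// loads the required dialects, unpacks the model, converts each subgraph to a
// FuncOp (the first one becomes "main"), and patches up tfl.while regions.
//
// A minimal call sketch (hypothetical caller-side setup; `contents` and `loc`
// are assumed to be created by the caller):
//   mlir::MLIRContext context;
//   mlir::OwningModuleRef module = tflite::FlatBufferToMlir(
//       contents, &context, loc,
//       /*use_external_constant=*/false,
//       /*ordered_input_arrays=*/{}, /*ordered_output_arrays=*/{},
//       /*experimental_prune_unreachable_nodes_unconditionally=*/false);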
1204 OwningModuleRef tflite::FlatBufferToMlir(
1205     absl::string_view buffer, MLIRContext* context, Location base_loc,
1206     bool use_external_constant,
1207     const std::vector<std::string>& ordered_input_arrays,
1208     const std::vector<std::string>& ordered_output_arrays,
1209     bool experimental_prune_unreachable_nodes_unconditionally) {
1210   context->loadDialect<
1211       mlir::StandardOpsDialect, mlir::quant::QuantizationDialect,
1212       mlir::TFL::TensorFlowLiteDialect, mlir::TF::TensorFlowDialect>();
1213 
1214   auto model_ptr =
1215       FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
1216   if (nullptr == model_ptr) {
1217     return emitError(base_loc, "couldn't parse flatbuffer"), nullptr;
1218   }
1219 
1220   std::unique_ptr<ModelT> model(model_ptr->GetModel()->UnPack());
1221 
1222   auto builder = Builder(context);
1223 
1224   std::vector<std::string> func_names;
1225   for (auto& subgraph : model->subgraphs) {
1226     func_names.push_back(subgraph->name);
1227   }
1228 
1229   auto module = mlir::ModuleOp::create(base_loc);
1230   // We currently don't use this to make decisions, but we could
1231   // use it in exports or if there are breaking changes
1232   module->setAttr("tfl.schema_version",
1233                   builder.getI32IntegerAttr(model->version));
1234   if (!model->description.empty()) {
1235     module->setAttr("tfl.description",
1236                     builder.getStringAttr(model->description));
1237   }
1238 
1239   for (auto e : llvm::enumerate(model->subgraphs)) {
1240     auto& subgraph = e.value();
1241     std::string name = SubgraphName(e.index(), *subgraph);
1242     auto func_or_error = ConvertSubgraph(
1243         *subgraph, name, model->operator_codes, func_names, model->buffers,
1244         base_loc, builder,
1245         // TODO(b/131175224,b/132239787) Support multiple entry points
1246         /*is_entry_point=*/e.index() == 0,
1247         /*use_external_constant=*/use_external_constant, ordered_input_arrays,
1248         ordered_output_arrays,
1249         experimental_prune_unreachable_nodes_unconditionally);
1250     if (!func_or_error.ok()) {
1251       return emitError(base_loc, "could not translate function ")
1252                  << subgraph->name << ": "
1253                  << func_or_error.status().error_message(),
1254              nullptr;
1255     }
1256     module.push_back(func_or_error.ConsumeValueOrDie());
1257   }
1258   AddRegionsForTflWhileOp(module);
1259   return OwningModuleRef(module);
1260 }
1261