/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"

#include <algorithm>
#include <cctype>
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include "absl/base/casts.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Translation.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"

using llvm::ArrayRef;
using mlir::Builder;
using mlir::DenseElementsAttr;
using mlir::FuncOp;
using mlir::Location;
using mlir::MLIRContext;
using mlir::OpBuilder;
using mlir::Operation;
using mlir::OperationState;
using mlir::OwningModuleRef;
using mlir::RankedTensorType;
using mlir::UnrankedTensorType;
using mlir::Value;
using mlir::quant::QuantizedType;
using tflite::TensorT;
using xla::Status;
using xla::StatusOr;

namespace errors = tensorflow::errors;
namespace tfl = mlir::TFL;

namespace {
bool IsScalar(const TensorT& tensor) {
  // TODO(b/138222071) We can't distinguish scalars and unranked tensors.
  // Work out a way to handle this and stub out the code until then.
  return tensor.shape.empty() && false;
}

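// Returns true if the tensor carries quantization parameters (a non-empty
// zero_point), i.e. it has already been quantized.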
bool IsQuantized(const TensorT& tensor) {
  return (tensor.quantization != nullptr) &&
         !tensor.quantization->zero_point.empty();
}

// Creates the MLIR NamedLoc location corresponding to a given tensor.
Location TensorLoc(const TensorT& tensor, Builder builder, Location base) {
  if (tensor.name.empty()) {
    return base;
  }
  return mlir::NameLoc::get(builder.getIdentifier(tensor.name), base);
}

// Returns the correct type for a quantized tensor.
// We have a special case for constants since they have a higher minimum value.
StatusOr<QuantizedType> GetQuantizedType(const TensorT& tensor, Builder builder,
                                         bool is_constant = false) {
  tflite::QuantizationParametersT& quant_params = *tensor.quantization;
  if (quant_params.details.AsCustomQuantization()) {
    return errors::Unimplemented("Cannot handle experimental quantization");
  }

  bool is_signed = true;
  mlir::IntegerType storage_type;
  if (tensor.type == tflite::TensorType_UINT8) {
    is_signed = false;
    storage_type = builder.getIntegerType(8);
  } else {
    auto raw_elem_type = ConvertElementType(tensor.type, builder);
    if (!raw_elem_type.isa<mlir::IntegerType>()) {
      return errors::InvalidArgument(
          "Quantized tensors must be stored as integers");
    }
    storage_type = raw_elem_type.cast<mlir::IntegerType>();
  }

  // TFLite uses narrow-range [u]int8 for constant buffers of quantized
  // weights. Since we don't know which ones are weights, we represent this
  // optimization as a change in the storage bounds for the type for all
  // constants of this type.
  bool is_weight_buffer = is_constant && (storage_type.getWidth() == 8);

  int64_t storage_min = QuantizedType::getDefaultMinimumForInteger(
                            is_signed, storage_type.getWidth()) +
                        static_cast<int>(is_weight_buffer);
  int64_t storage_max = QuantizedType::getDefaultMaximumForInteger(
      is_signed, storage_type.getWidth());
  uint32_t flags =
      is_signed ? mlir::quant::QuantizationFlags::FlagValue::Signed : 0;

  // Rejects quantized tensors that have zero scales.
  for (float scale : quant_params.scale) {
    if (scale == 0) {
      return errors::InvalidArgument(
          "Quantized tensors must have non-zero scales");
    }
  }

  // The scale size can't be zero, as it is checked above.
  if (quant_params.scale.size() != 1) {
    llvm::SmallVector<double, 4> scales(quant_params.scale.begin(),
                                        quant_params.scale.end());
    return mlir::quant::UniformQuantizedPerAxisType::get(
        flags, storage_type, builder.getF32Type(), scales,
        quant_params.zero_point, quant_params.quantized_dimension, storage_min,
        storage_max);
  }
  return mlir::quant::UniformQuantizedType::get(
      flags, storage_type, builder.getF32Type(), quant_params.scale.at(0),
      quant_params.zero_point.at(0), storage_min, storage_max);
}

// Imports a float tensor with calibration values into a calibrated quantized
// type.
StatusOr<QuantizedType> GetCalibratedQuantizedType(const TensorT& tensor,
                                                   Builder builder) {
  if (tensor.quantization == nullptr) {
    return errors::InvalidArgument("The tensor is not quantized.");
  }
  auto raw_elem_type = ConvertElementType(tensor.type, builder);
  float min = tensor.quantization->min[0];
  float max = tensor.quantization->max[0];
  return mlir::quant::CalibratedQuantizedType::get(raw_elem_type, min, max);
}

// TODO(b/138222071) Remove shapeless_are_scalars once we can reliably
// make that distinction and don't have to rely on context
// (input to main and constants must have static shape).
StatusOr<mlir::TensorType> GetTensorType(const TensorT& tensor, Builder builder,
                                         bool shapeless_are_scalars = false,
                                         bool is_constant = false,
                                         bool is_intermediate = false) {
  mlir::Type elem_type = ConvertElementType(tensor.type, builder);
  // TODO(b/139554398) Store min/max (even for non-quantized tensors) somewhere
  // if it's set.
  if (IsQuantized(tensor)) {
    TF_ASSIGN_OR_RETURN(elem_type,
                        GetQuantizedType(tensor, builder, is_constant));
  }

  // Intermediate tensors with calibration values (but no scale and zero point)
  // should return a calibrated quantized type.
  if (is_intermediate && tensor.quantization != nullptr &&
      !IsQuantized(tensor)) {
    TF_ASSIGN_OR_RETURN(elem_type, GetCalibratedQuantizedType(tensor, builder));
  }

  if (IsScalar(tensor) || (shapeless_are_scalars && tensor.shape.empty())) {
    return RankedTensorType::get({}, elem_type);
  }

  if (!tensor.shape_signature.empty()) {
    llvm::SmallVector<int64_t, 4> shape(tensor.shape_signature.begin(),
                                        tensor.shape_signature.end());
    return RankedTensorType::get(shape, elem_type);
  }

  if (!tensor.shape.empty()) {
    llvm::SmallVector<int64_t, 4> shape(tensor.shape.begin(),
                                        tensor.shape.end());
    return RankedTensorType::get(shape, elem_type);
  }

  return UnrankedTensorType::get(elem_type);
}

// Extracts the min/max information from the tensor and creates the quant
// stats op. If the input `tensor` has scale/zero_point, `res` should have
// quantized type, so no stats op is required and nullptr is returned.
// If the min/max information is invalid, nullptr is returned.
mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b,
                                        Value res) {
  // If the `tensor` has scale/zero_point, it must have been quantized, so the
  // min/max stats are informational only; ignore them.
  if (!tensor.quantization || IsQuantized(tensor)) return nullptr;
  // If the result isn't float (and hence not quantizable), the min/max is
  // ignored.
  if (!res.getType()
           .cast<mlir::ShapedType>()
           .getElementType()
           .isa<mlir::FloatType>()) {
    return nullptr;
  }
  auto mins = tensor.quantization->min;
  auto maxs = tensor.quantization->max;
  if (mins.size() != maxs.size() || mins.empty()) return nullptr;

  llvm::SmallVector<llvm::APFloat, 4> min_maxs;
  min_maxs.reserve(mins.size() * 2);
  for (int i = 0, end = mins.size(); i < end; ++i) {
    llvm::APFloat min(mins[i]);
    llvm::APFloat max(maxs[i]);
    min_maxs.push_back(min);
    min_maxs.push_back(max);
  }
  // The layer stats contain only the first min/max pair.
  mlir::ElementsAttr layer_stats = mlir::DenseFPElementsAttr::get(
      mlir::RankedTensorType::get({2}, b.getF32Type()),
      {min_maxs[0], min_maxs[1]});
  mlir::ElementsAttr axis_stats;
  mlir::IntegerAttr axis;
  if (mins.size() > 1) {
    llvm::SmallVector<int64_t, 4> axis_stats_shape{
        static_cast<int64_t>(mins.size()), 2};
    axis_stats = mlir::DenseFPElementsAttr::get(
        mlir::RankedTensorType::get(axis_stats_shape, b.getF32Type()),
        min_maxs);
    // TODO(fengliuai): this quantization dimension isn't correct.
    axis = b.getI64IntegerAttr(tensor.quantization->quantized_dimension);
  }
  return b.create<mlir::quant::StatisticsOp>(b.getUnknownLoc(), res,
                                             layer_stats, axis_stats, axis);
}

// Returns true if this is a basic LSTM op.
bool IsBasicLSTMOp(tflite::BuiltinOptionsUnion op_union) {
  if (const auto* op = op_union.AsLSTMOptions()) {
    return op->kernel_type == tflite::LSTMKernelType_BASIC;
  } else {
    return false;
  }
}

// Gets the MLIR op name with the dialect name for the flatbuffer operator.
StatusOr<std::string> GetMlirOpName(const tflite::OperatorT& op,
                                    const tflite::OperatorCodeT& op_code) {
  if (IsBasicLSTMOp(op.builtin_options)) {
    return std::string("tfl.basic_lstm");
  }

  auto builtin_code = tflite::GetBuiltinCode(&op_code);
  if (builtin_code == tflite::BuiltinOperator_CUSTOM) {
    return std::string("tfl.custom");
  }
  if (builtin_code == tflite::BuiltinOperator_IF) {
    return std::string("tf.If");
  }
  if (builtin_code == tflite::BuiltinOperator_WHILE) {
    return std::string("tfl.while");
  }

  llvm::StringRef op_name(tflite::EnumNameBuiltinOperator(builtin_code));
  return llvm::Twine("tfl.", op_name.lower()).str();
}

// The buffers in TFLite flatbuffers have their contents stored as a vector of
// bytes that represent little-endian values.
// The element read size is derived from the template parameter, which allows
// reading both float16s and float32s without a case split.
template <typename T>
std::vector<T> ReadAsLittleEndian(ArrayRef<uint8_t> bytes) {
  std::vector<T> ret;
  size_t read_size = sizeof(T);
  int bytes_len = bytes.size();
  assert(bytes_len % read_size == 0);

  int elem_count = bytes_len / read_size;
  ret.reserve(elem_count);

  const char* data_ptr = reinterpret_cast<const char*>(bytes.data());
  for (int i = 0; i < elem_count; i++) {
    ret.push_back(
        llvm::support::endian::readNext<T, llvm::support::little,
                                        llvm::support::unaligned>(data_ptr));
  }
  return ret;
}

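// Converts a TFLite constant tensor and the raw bytes of its buffer into a
// TensorFlow TensorProto, copying the buffer contents verbatim.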
tensorflow::TensorProto ConvertTfliteConstTensor(
    const tflite::TensorT& tensor, const std::vector<uint8_t>& buffer) {
  tensorflow::TensorProto ret;
  ret.set_dtype(TflTypeToTfType(tensor.type));

  tensorflow::TensorShapeProto* shape = ret.mutable_tensor_shape();
  shape->set_unknown_rank(false);
  for (auto dim : tensor.shape) {
    shape->add_dim()->set_size(int64_t{dim});
  }
  std::string content;
  content.assign(reinterpret_cast<const char*>(buffer.data()), buffer.size());
  ret.set_tensor_content(content);
  return ret;
}

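// Converts a little-endian buffer of float16/float32/float64 values into a
// DenseElementsAttr of the given ranked tensor type.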
StatusOr<mlir::ElementsAttr> ConvertFloatBuffer(
    mlir::RankedTensorType shaped_type, mlir::FloatType elem_type,
    const std::vector<uint8_t>& buffer) {
  size_t bytes_len = buffer.size();

  // The bytes of floats are stored little-endian.
  switch (elem_type.getWidth()) {
    case 16: {
      assert(bytes_len % 2 == 0);
      int elem_count = bytes_len / 2;
      std::vector<llvm::APFloat> values;
      values.reserve(elem_count);

      const char* data = reinterpret_cast<const char*>(buffer.data());
      auto& semantics = elem_type.getFloatSemantics();

      for (int i = 0; i < elem_count; i++) {
        uint16_t bit_repr =
            llvm::support::endian::readNext<uint16_t, llvm::support::little,
                                            llvm::support::unaligned>(data);
        llvm::APInt int_repr(16, bit_repr);
        values.emplace_back(semantics, int_repr);
      }

      return DenseElementsAttr::get(shaped_type, values);
    }
    case 32: {
      assert(bytes_len % 4 == 0);
      int elem_count = bytes_len / 4;
      std::vector<float> values;
      values.reserve(elem_count);

      const char* data = reinterpret_cast<const char*>(buffer.data());

      for (int i = 0; i < elem_count; i++) {
        uint32_t bit_repr =
            llvm::support::endian::readNext<uint32_t, llvm::support::little,
                                            llvm::support::unaligned>(data);
        values.push_back(absl::bit_cast<float>(bit_repr));
      }
      return DenseElementsAttr::get(shaped_type, ArrayRef<float>(values));
    }
    case 64: {
      assert(bytes_len % 8 == 0);
      int elem_count = bytes_len / 8;
      std::vector<double> values;
      values.reserve(elem_count);

      const char* data = reinterpret_cast<const char*>(buffer.data());

      for (int i = 0; i < elem_count; i++) {
        uint64_t bit_repr =
            llvm::support::endian::readNext<uint64_t, llvm::support::little,
                                            llvm::support::unaligned>(data);
        values.push_back(absl::bit_cast<double>(bit_repr));
      }
      return DenseElementsAttr::get(shaped_type, ArrayRef<double>(values));
    }
  }
  return errors::InvalidArgument("unsupported bit width", elem_type.getWidth());
}

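// Converts a little-endian buffer of integer values into a DenseElementsAttr.
// For quantized element types the storage type (e.g. i8) is used instead of
// the quantized type itself.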
StatusOr<mlir::ElementsAttr> ConvertIntBuffer(
    mlir::RankedTensorType shaped_type, mlir::Type elem_type,
    const std::vector<uint8_t>& buffer) {
  unsigned bit_width;
  if (auto itype = elem_type.dyn_cast<mlir::IntegerType>()) {
    bit_width = itype.getWidth();
  } else if (auto qtype = elem_type.dyn_cast<QuantizedType>()) {
    bit_width = qtype.getStorageTypeIntegralWidth();
    shaped_type = mlir::RankedTensorType::get(shaped_type.getShape(),
                                              qtype.getStorageType());
  } else {
    return errors::InvalidArgument("unsupported integer constant type");
  }

  switch (bit_width) {
    case 1: {
      // vector<bool> doesn't convert to an ArrayRef
      llvm::SmallVector<bool, 8> values;
      values.reserve(buffer.size());
      for (auto b : buffer) {
        values.emplace_back(b != 0);
      }
      return DenseElementsAttr::get(shaped_type, ArrayRef<bool>(values));
    }
    case 8: {
      return DenseElementsAttr::get(shaped_type, ArrayRef<uint8_t>(buffer));
    }
    case 16: {
      auto values = ReadAsLittleEndian<uint16_t>(buffer);
      return DenseElementsAttr::get(shaped_type, ArrayRef<uint16_t>(values));
    }
    case 32: {
      auto values = ReadAsLittleEndian<uint32_t>(buffer);
      return DenseElementsAttr::get(shaped_type, ArrayRef<uint32_t>(values));
    }
    case 64: {
      auto values = ReadAsLittleEndian<uint64_t>(buffer);
      return DenseElementsAttr::get(shaped_type, ArrayRef<uint64_t>(values));
    }
    default:
      return errors::Unimplemented("Cannot handle bit width ", bit_width);
  }
}

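// Builds a tfl.external_const op that refers to the flatbuffer buffer by its
// index rather than embedding the constant data in the IR.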
StatusOr<Operation*> BuildExternalConstOp(const tflite::TensorT& tensor,
                                          int32_t buffer_index,
                                          OpBuilder builder, Location loc) {
  TF_ASSIGN_OR_RETURN(auto type, GetTensorType(tensor, builder,
                                               /*shapeless_are_scalars=*/true,
                                               /*is_constant=*/true));
  auto shaped_type = type.dyn_cast<mlir::RankedTensorType>();
  if (!shaped_type) {
    return errors::Internal("Constant doesn't have a shape");
  }
  auto op = builder.create<tfl::ExternalConstOp>(
      loc, shaped_type, builder.getI32IntegerAttr(buffer_index));
  return op.getOperation();
}

// Gets a constant splat of the given type, filled with `unique_index`.
// Requires the type to be a statically shaped RankedTensorType.
static mlir::ElementsAttr GetSplat(RankedTensorType type, int unique_index,
                                   OpBuilder builder) {
  mlir::Type element_ty = getElementTypeOrSelf(type);

  if (element_ty.isSignlessInteger())
    return DenseElementsAttr::get(
        type, builder.getIntegerAttr(element_ty, unique_index));

  if (element_ty.isa<mlir::FloatType>())
    return DenseElementsAttr::get(
        type, builder.getFloatAttr(element_ty, unique_index));

  if (auto qtype = element_ty.dyn_cast<QuantizedType>()) {
    mlir::RankedTensorType new_type =
        RankedTensorType::get(type.getShape(), qtype.getStorageType());
    return DenseElementsAttr::get(
        new_type, builder.getIntegerAttr(qtype.getStorageType(), unique_index));
  }
  llvm_unreachable("unhandled element type");
}

// TODO(b/172664358): Creates a new op instead of reusing constant op.
// Creates a constant op to represent a stateful variable. The function-static
// variable `stateful_variable_idx` is used as a unique value for each constant
// to avoid the constants being CSE'd. `tensor` is the flatbuffer data
// structure. `shaped_type` is the ShapedType for the const op.
Operation* BuildVariableOp(const tflite::TensorT& tensor,
                           mlir::RankedTensorType shaped_type,
                           OpBuilder builder, Location loc) {
  static int stateful_variable_idx = 0;
  mlir::ElementsAttr value =
      GetSplat(shaped_type, stateful_variable_idx++, builder);
  if (IsQuantized(tensor)) {
    auto op = builder.create<tfl::QConstOp>(
        loc, mlir::TypeAttr::get(shaped_type), value);
    return op.getOperation();
  }
  auto op = builder.create<tfl::ConstOp>(loc, value);
  if (tensor.quantization && !tensor.quantization->min.empty()) {
    if (auto stats_op =
            ConvertMinMaxToStatsOp(tensor, builder, op.getResult())) {
      return stats_op;
    }
  }
  return op.getOperation();
}

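// Builds a constant op (tfl::ConstOp or tfl::QConstOp, or a stateful variable
// constant) from a TFLite tensor and the raw bytes of its buffer.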
StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
                                  const std::vector<uint8_t>& buffer,
                                  bool is_variable, OpBuilder builder,
                                  Location loc) {
  TF_ASSIGN_OR_RETURN(auto type, GetTensorType(tensor, builder,
                                               /*shapeless_are_scalars=*/true,
                                               /*is_constant=*/true));
  auto shaped_type = type.dyn_cast<mlir::RankedTensorType>();
  if (!shaped_type) {
    return errors::Internal("Constant doesn't have a shape");
  }

  auto elem_type = shaped_type.getElementType();

  mlir::ElementsAttr value;
  if (is_variable) {
    return BuildVariableOp(tensor, shaped_type, builder, loc);
  } else if (auto float_type = elem_type.dyn_cast<mlir::FloatType>()) {
    TF_ASSIGN_OR_RETURN(value,
                        ConvertFloatBuffer(shaped_type, float_type, buffer));
  } else if (elem_type.isa<mlir::IntegerType, QuantizedType>()) {
    TF_ASSIGN_OR_RETURN(value,
                        ConvertIntBuffer(shaped_type, elem_type, buffer));
  } else if (elem_type.isa<mlir::TF::StringType>()) {
    tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer);
    std::vector<llvm::StringRef> refs;
    refs.reserve(repr.string_val_size());

    for (const auto& ref : repr.string_val())
      refs.push_back({ref.data(), ref.size()});

    value = mlir::DenseStringElementsAttr::get(shaped_type, refs);
  } else if (elem_type.isa<mlir::ComplexType, mlir::TF::TensorFlowType>()) {
    auto dialect = elem_type.getContext()->getLoadedDialect("tf");
    tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer);
    std::string mangled = tensorflow::mangling_util::MangleTensor(repr);

    value = mlir::OpaqueElementsAttr::get(dialect, shaped_type, mangled);
  } else {
    return errors::Unimplemented("Constant of unsupported type");
  }

  if (IsQuantized(tensor)) {
    auto op = builder.create<tfl::QConstOp>(
        loc, mlir::TypeAttr::get(shaped_type), value);
    return op.getOperation();
  }
  auto op = builder.create<tfl::ConstOp>(loc, value);
  return op.getOperation();
}

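// Converts the subgraph indices stored in If/While builtin options into
// symbol reference attributes (then/else or cond/body) that point at the
// corresponding imported functions.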
llvm::SmallVector<mlir::NamedAttribute, 4> ConvertSubgraphIdxsToFunctionAttrs(
    tflite::BuiltinOptionsUnion options,
    const std::vector<std::string>& func_names, Builder builder) {
  if (auto* opts = options.AsIfOptions()) {
    uint32_t then_idx = opts->then_subgraph_index;
    auto then_attr = builder.getSymbolRefAttr(func_names.at(then_idx));
    uint32_t else_idx = opts->else_subgraph_index;
    auto else_attr = builder.getSymbolRefAttr(func_names.at(else_idx));

    return {builder.getNamedAttr("then_branch", then_attr),
            builder.getNamedAttr("else_branch", else_attr),
            // TODO(b/139667752): Analyze statelessness correctly
            builder.getNamedAttr("is_stateless", builder.getBoolAttr(false))};
  }
  if (auto* opts = options.AsWhileOptions()) {
    uint32_t cond_idx = opts->cond_subgraph_index;
    auto cond_attr = builder.getSymbolRefAttr(func_names.at(cond_idx));
    uint32_t body_idx = opts->body_subgraph_index;
    auto body_attr = builder.getSymbolRefAttr(func_names.at(body_idx));

    return {builder.getNamedAttr("cond", cond_attr),
            builder.getNamedAttr("body", body_attr)};
  }
  return {};
}

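// Adds the types of an LSTM op's five intermediate tensors as named TypeAttr
// attributes on the op being built; these types carry calibration information.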
Status AddOpIntermediatesForLstm(
    const tflite::OperatorT& op,
    const std::vector<mlir::TensorType>& intermediate_types,
    OperationState& op_state, Location loc, OpBuilder& builder) {
  if (!op.intermediates.empty()) {
    if (op.intermediates.size() != 5) {
      auto err = errors::InvalidArgument(
          "operator has intermediate tensors but the number of them is not "
          "five.");
      return emitError(loc, err.ToString()), err;
    }
    // Create the intermediate values.

    const llvm::SmallVector<llvm::StringRef, 5> kIntermediateNames = {
        "input_to_input_intermediate", "input_to_forget_intermediate",
        "input_to_cell_intermediate", "input_to_output_intermediate",
        "effective_hidden_scale_intermediate"};
    for (auto type_and_name :
         llvm::zip(intermediate_types, kIntermediateNames)) {
      mlir::TypeAttr type_attr =
          mlir::TypeAttr::get(std::get<0>(type_and_name));
      auto named_attr =
          builder.getNamedAttr(std::get<1>(type_and_name), type_attr);
      op_state.addAttribute(named_attr.first, named_attr.second);
    }
  }
  return Status::OK();
}

// TODO(krzysd) Handle function calls
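// Converts a single TFLite operator into an MLIR operation: resolves its
// operands from `vals_map`, computes result types from the output tensors,
// and attaches the builtin or custom options as attributes.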
StatusOr<Operation*> ConvertOp(
    const tflite::OperatorT& op, const std::vector<Value>& vals_map,
    const std::vector<mlir::TensorType>& intermediate_types,
    Value optional_arg_marker,
    const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& op_codes,
    const std::vector<std::string>& func_names,
    const std::vector<std::unique_ptr<tflite::TensorT>>& tensors, Location loc,
    OpBuilder builder) {
  llvm::SmallVector<Value, 4> operands;
  llvm::SmallVector<mlir::Type, 2> outputTypes;

  if (op.outputs.empty()) {
    auto err = errors::InvalidArgument("operator with no outputs");
    return emitError(loc, err.ToString()), err;
  }

  const tflite::OperatorCodeT& op_code = *op_codes.at(op.opcode_index);

  TF_ASSIGN_OR_RETURN(const std::string op_name, GetMlirOpName(op, op_code));

  OperationState op_state(loc, op_name);

  for (auto input_num : op.inputs) {
    if (input_num == -1) {
      assert(optional_arg_marker != nullptr);
      op_state.addOperands({optional_arg_marker});
    } else {
      op_state.addOperands({vals_map.at(input_num)});
    }
  }

  for (auto output_num : op.outputs) {
    auto& tensor = *tensors.at(output_num);
    auto type_or_err = GetTensorType(tensor, builder);
    if (!type_or_err.ok()) {
      return emitError(loc, type_or_err.status().ToString()),
             type_or_err.status();
    }
    auto type = type_or_err.ConsumeValueOrDie();

    if (op_name == "tfl.quantize") {
      // Special case for quantize: the return type must also be stored in the
      // qtype attribute.
      op_state.addAttribute("qtype", mlir::TypeAttr::get(type));
    } else if (op_name == "tfl.reshape" && op_state.operands.size() == 1) {
      // Special case for reshape: the second operand is optional in the old
      // converter and kernel, so we create the second operand, which is
      // required by the new converter, from the reshape op's option.
      auto new_shape = op.builtin_options.AsReshapeOptions()->new_shape;
      auto shape_type = RankedTensorType::get(
          {static_cast<int64_t>(new_shape.size())}, builder.getIntegerType(32));

      mlir::SmallVector<mlir::Attribute, 4> shape;
      for (auto s : new_shape) {
        shape.push_back(builder.getI32IntegerAttr(static_cast<int32_t>(s)));
      }
      auto output_shape = DenseElementsAttr::get(shape_type, shape);
      auto shape_op = builder.create<tfl::ConstOp>(loc, output_shape);
      op_state.addOperands({shape_op});
    }

    op_state.addTypes({type});
  }

  // While the last several tensors could be optional tensors for a tfl op, the
  // number of input operands can vary. Get the min/max number of operands from
  // the tflite op name.
  // Also, since the code above special-handles the `tfl.reshape` op and adds
  // an additional input, we put this block here.
  llvm::MinMax input_min_max = mlir::OperandNumbersMinMax(op_name);
  int input_max_num = input_min_max.Max;
  int op_input_num = op_state.operands.size();
  if (input_max_num != 0 && input_max_num > op_input_num) {
    // If the number of current inputs is less than the op definition, fill in
    // with `none` values.
    llvm::SmallVector<Value, 4> none_operands(
        input_max_num - op_input_num,
        builder.create<mlir::ConstantOp>(loc, builder.getNoneType(),
                                         builder.getUnitAttr()));
    op_state.addOperands(ArrayRef<Value>(none_operands));
  }

  if (op_name == "tfl.lstm") {
    // TODO(b/147587779): add the right region if region is empty.
    op_state.addRegion();
    TF_CHECK_OK(AddOpIntermediatesForLstm(op, intermediate_types, op_state, loc,
                                          builder));
  }
  if (op_name == "tfl.while") {
    // Adds two empty regions for "tfl.while". We will fill the regions after
    // creating the callee functions because the "tfl.while" input/output types
    // may differ from the callee functions', and the call ops need to be kept
    // in sync with the callee function types.
    op_state.addRegion();
    op_state.addRegion();
  }
  if (op_name == "tfl.unidirectional_sequence_lstm") {
    TF_CHECK_OK(AddOpIntermediatesForLstm(op, intermediate_types, op_state, loc,
                                          builder));
  }
  if (op_name == "tfl.reshape") {
    // Flatten the shape operand of reshape ops when it has more than one
    // dimension.
    mlir::DenseIntElementsAttr shape_attr;
    if (matchPattern(op_state.operands[1], m_Constant(&shape_attr))) {
      auto shape_ty =
          op_state.operands[1].getType().dyn_cast<RankedTensorType>();
      if (shape_ty != nullptr && shape_ty.hasRank() && shape_ty.getRank() > 1) {
        llvm::SmallVector<mlir::Attribute, 4> shape;
        int32_t dim_size = 0;
        for (const auto& dim : llvm::enumerate(shape_attr.getIntValues())) {
          const int64_t size = dim.value().getSExtValue();
          shape.push_back(
              builder.getI32IntegerAttr(static_cast<int32_t>(size)));
          ++dim_size;
        }
        auto shape_type = RankedTensorType::get(
            {static_cast<int32_t>(dim_size)}, builder.getIntegerType(32));
        auto output_shape = mlir::DenseElementsAttr::get(shape_type, shape);
        auto shape_op = builder.create<tfl::ConstOp>(loc, output_shape);
        op_state.operands[1] = shape_op;
      }
    }
  }

  llvm::SmallVector<mlir::NamedAttribute, 2> attrs;
  auto builtin_code = tflite::GetBuiltinCode(&op_code);
  if (builtin_code == tflite::BuiltinOperator_CUSTOM) {
    auto status = mlir::CustomOptionsToAttributes(
        op_code.custom_code, op.custom_options, builder, loc, &attrs);
    if (!status.ok()) {
      return emitError(loc, status.ToString()), status;
    }
  } else {
    mlir::BuiltinOptionsToAttributes(op.builtin_options, builder, attrs);
  }
  op_state.addAttributes(attrs);

  // Handle the conversion from subgraph indices to functions for If and While.
  // We will add CallOps in the regions to call the functions later for While.
  auto function_ref_attrs = ConvertSubgraphIdxsToFunctionAttrs(
      op.builtin_options, func_names, builder);
  op_state.addAttributes(function_ref_attrs);

  return builder.createOperation(op_state);
}

// Returns indices of the given tensors in the subgraph. Returns an error if a
// tensor name cannot be found in the subgraph.
StatusOr<std::vector<int>> GetTensorIndices(
    const tflite::SubGraphT& subgraph,
    const std::vector<std::string>& tensor_names) {
  absl::flat_hash_map<std::string, int> name_to_index;
  for (auto index_and_tensor : llvm::enumerate(subgraph.tensors)) {
    name_to_index[index_and_tensor.value()->name] = index_and_tensor.index();
  }

  std::vector<int> indices;
  indices.reserve(tensor_names.size());

  for (const auto& name : tensor_names) {
    auto found = name_to_index.find(name);
    if (found != name_to_index.end()) {
      indices.push_back(found->second);
    } else {
      return errors::InvalidArgument("could not find tensor in subgraph: ",
                                     name);
    }
  }

  return indices;
}

// Given a list of tensor indices, returns a string of concatenated tensor
// names wrapped in a NamedAttribute.
template <typename ContainerType>
mlir::NamedAttribute BuildTFEntryFunctionAttribute(
    const tflite::SubGraphT& subgraph, Builder* builder, const std::string name,
    const ContainerType indices) {
  auto tensor_names = llvm::map_range(
      indices, [&](int i) { return subgraph.tensors.at(i)->name; });
  return builder->getNamedAttr(
      name, builder->getStringAttr(llvm::join(tensor_names, ",")));
}

// Traverses the subgraph from output_indices to input_indices and returns the
// set of ops that are visited.
StatusOr<absl::flat_hash_set<const tflite::OperatorT*>> PruneSubgraph(
    const tflite::SubGraphT& subgraph, ArrayRef<int32_t> input_indices,
    ArrayRef<int32_t> output_indices) {
  // Create a map from tensor index to defining op.
  absl::flat_hash_map<int32_t, const tflite::OperatorT*> defining_op;
  for (const auto& op : subgraph.operators) {
    for (int32_t output : op->outputs) {
      if (!llvm::is_contained(input_indices, output)) {
        defining_op[output] = op.get();
      }
    }
  }

  std::vector<const tflite::OperatorT*> queue;
  for (int32_t output : output_indices) {
    if (auto& op = defining_op[output]) {
      queue.push_back(op);
    }
  }

  // Traverse the graph towards inputs.
  absl::flat_hash_set<const tflite::OperatorT*> visited;
  while (!queue.empty()) {
    const tflite::OperatorT* op = queue.back();
    queue.pop_back();
    if (!visited.insert(op).second) {
      // The node has already been visited.
      continue;
    }

    for (int32_t input : op->inputs) {
      // Input tensor may not have a defining op in case it is a subgraph input
      // or a constant tensor.
      if (auto& op = defining_op[input]) {
        queue.push_back(op);
      }
    }
  }

  return visited;
}

// Adjusts the func op according to information that spans multiple ops.
static StatusOr<FuncOp> PostProcessFuncOp(FuncOp func) {
  OpBuilder builder(func);
  // When a quantized constant is imported, its quantization parameters are set
  // to narrow range. Here we revert to full range if the user op doesn't
  // require narrow range.
  func.walk([&](tfl::QConstOp cst) {
    Value value = cst.getResult();
    Value full_range_const = value;
    for (auto& use : value.getUses()) {
      Operation* user = use.getOwner();
      if (user->hasTrait<mlir::OpTrait::IsTerminator>()) return;
      auto qtype = mlir::quant::UniformQuantizedType::getQuantizedElementType(
          value.getType());
      // Only the 8-bit constants are imported with narrow range.
      if (!qtype || qtype.getStorageTypeIntegralWidth() != 8) return;

      auto affine_user = llvm::dyn_cast<mlir::AffineQuantizedOpInterface>(user);
      if (affine_user &&
          affine_user.GetAffineOperandIndex() == use.getOperandNumber() &&
          affine_user.RequiredNarrowRangeAffineOperand())
        return;

      // Create a full-range quantized constant.
      if (full_range_const == value) {
        mlir::quant::QuantizedType new_qtype;
        if (auto per_axis =
                qtype.dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
          new_qtype = mlir::quant::UniformQuantizedPerAxisType::get(
              per_axis.getFlags(), per_axis.getStorageType(),
              per_axis.getExpressedType(), per_axis.getScales(),
              per_axis.getZeroPoints(), per_axis.getQuantizedDimension(),
              per_axis.getStorageTypeMin() - 1, per_axis.getStorageTypeMax());
        } else if (auto per_tensor =
                       qtype.dyn_cast<mlir::quant::UniformQuantizedType>()) {
          new_qtype = mlir::quant::UniformQuantizedType::get(
              per_tensor.getFlags(), per_tensor.getStorageType(),
              per_tensor.getExpressedType(), per_tensor.getScale(),
              per_tensor.getZeroPoint(), per_tensor.getStorageTypeMin() - 1,
              per_tensor.getStorageTypeMax());
        } else {
          return;
        }
        auto new_output_type = new_qtype.castFromExpressedType(
            mlir::quant::UniformQuantizedType::castToExpressedType(
                value.getType()));
        builder.setInsertionPointAfter(cst.getOperation());
        auto new_op = builder.create<tfl::QConstOp>(
            cst.getLoc(), new_output_type, mlir::TypeAttr::get(new_output_type),
            cst.valueAttr());
        full_range_const = new_op.output();
      }
      use.set(full_range_const);
    }
    if (cst.use_empty()) cst.erase();
  });
  return func;
}

// Builds a FuncOp from a tflite SubGraph.
// The buffers are directly taken from the deserialized flatbuffer, as we do
// not have the type information to interpret them until this point. The
// base_loc parameter is the location of the flatbuffer as a whole (usually a
// file). The is_entry_point flag controls whether shapeless types are treated
// as scalars. If ordered_output_arrays is not empty, then the imported mlir
// function will only return nodes in ordered_output_arrays in the same order.
StatusOr<FuncOp> ConvertSubgraph(
    const tflite::SubGraphT& subgraph, llvm::StringRef name,
    const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& op_codes,
    const std::vector<std::string>& func_names,
    const std::vector<std::unique_ptr<tflite::BufferT>>& buffers,
    Location base_loc, Builder builder, bool is_entry_point,
    bool use_external_constant,
    const std::vector<std::string>& ordered_input_arrays,
    const std::vector<std::string>& ordered_output_arrays,
    bool experimental_prune_unreachable_nodes_unconditionally) {
  llvm::SmallVector<mlir::Type, 2> ret_types;
  llvm::SmallVector<mlir::Type, 4> input_types;

  auto func_loc = mlir::NameLoc::get(builder.getIdentifier(name), base_loc);

  std::vector<int> func_inputs = subgraph.inputs;
  if (is_entry_point && !ordered_input_arrays.empty()) {
    if (!experimental_prune_unreachable_nodes_unconditionally) {
      // TODO(b/149922113): Resolve input-arrays/pruning flags interaction.
      return errors::InvalidArgument(
          "input-arrays should be used with experimental pruning flag");
    }
    TF_ASSIGN_OR_RETURN(func_inputs,
                        GetTensorIndices(subgraph, ordered_input_arrays));
  }

  for (int input : func_inputs) {
    auto& tensor = *subgraph.tensors.at(input);
    // TODO(b/138222071) Graph inputs must have static shape per the exporter,
    // but we cannot differentiate scalars from unranked tensors.
    // Here we reverse the default assumption that shape = [] means unranked
    // when processing main().
    auto type_or_err = GetTensorType(tensor, builder,
                                     /*shapeless_are_scalars=*/is_entry_point,
                                     /*is_constant=*/false);
    if (!type_or_err.ok()) {
      emitError(func_loc, "error reading argument types")
          << type_or_err.status().ToString();
      return type_or_err.status();
    }
    auto type = type_or_err.ConsumeValueOrDie();
    input_types.push_back(type);
  }

  llvm::SmallVector<bool, 16> is_op_output(subgraph.tensors.size(), false);
  for (auto& op : subgraph.operators) {
    for (auto output : op->outputs) {
      is_op_output[output] = true;
    }
  }

  std::vector<int> func_outputs = subgraph.outputs;
  if (is_entry_point && !ordered_output_arrays.empty()) {
    TF_ASSIGN_OR_RETURN(func_outputs,
                        GetTensorIndices(subgraph, ordered_output_arrays));
  }

  for (auto output : func_outputs) {
    const bool is_func_input = std::find(func_inputs.begin(), func_inputs.end(),
                                         output) != func_inputs.end();
    bool is_constant = !is_op_output[output] && !is_func_input;
    // There are two cases where a tensor without a shape in the flatbuffer is
    // a scalar:
    // 1. `is_constant` = true, which means this tensor is created from a
    //    constant op.
    // 2. `is_func_input` = true and `is_entry_point` = true, which means this
    //    tensor is a function input and the function input type is a scalar
    //    tensor.
    const bool shapeless_is_scalar =
        is_constant || (is_func_input && is_entry_point);
    auto type_or_err = GetTensorType(*subgraph.tensors.at(output), builder,
                                     shapeless_is_scalar,
                                     /*is_constant=*/is_constant);
    if (!type_or_err.ok()) {
      emitError(func_loc, "error reading return types")
          << type_or_err.status().ToString();
      return type_or_err.status();
    }
    auto type = type_or_err.ConsumeValueOrDie();
    ret_types.push_back(type);
  }
  auto func_type = builder.getFunctionType(input_types, ret_types);

  // Construct function object
  auto func = FuncOp::create(func_loc, name, func_type, /* attrs= */ {});
  func.addEntryBlock();
  auto& body = func.getBody();
  OpBuilder op_builder{body};

  std::vector<Value> vals_map(subgraph.tensors.size(), nullptr);
  Value maybe_optional_arg_marker = nullptr;

  // Get or construct MLIR values for each input
  for (int i = 0, e = func_inputs.size(); i < e; i++) {
    auto input_tensor = func_inputs[i];
    const auto& tensor = *subgraph.tensors.at(input_tensor);
    auto loc = TensorLoc(tensor, builder, base_loc);
    if (vals_map[input_tensor]) {
      auto err = errors::FailedPrecondition("duplicate input arguments");
      return emitError(loc, err.ToString()), err;
    }
    Value input_value = func.getArgument(i);

    // If the `tensor` has min/max and doesn't have scale/zero_point
    // information, a stats op is created to use the input_value, then the
    // `tensor` should be mapped to the result of this new stats op.
    if (auto stats_op =
            ConvertMinMaxToStatsOp(tensor, op_builder, input_value)) {
      vals_map[input_tensor] = stats_op->getResult(0);
    } else {
      vals_map[input_tensor] = input_value;
    }
  }

  // Set tf.entry_function attribute
  if (is_entry_point) {
    llvm::SmallVector<mlir::NamedAttribute, 2> attributes;
    if (!func_inputs.empty()) {
      attributes.push_back(BuildTFEntryFunctionAttribute(
          subgraph, &builder, "inputs", func_inputs));
    }
    if (!func_outputs.empty()) {
      attributes.push_back(BuildTFEntryFunctionAttribute(
          subgraph, &builder, "outputs", func_outputs));
    }
    func->setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
  } else {
    func.setPrivate();
  }

  absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
  if (experimental_prune_unreachable_nodes_unconditionally) {
    TF_ASSIGN_OR_RETURN(pruned_subgraph_ops,
                        PruneSubgraph(subgraph, func_inputs, func_outputs));
  }

  // Construct MLIR operators from TFLite operators
  for (auto& op : subgraph.operators) {
    if (experimental_prune_unreachable_nodes_unconditionally &&
        !pruned_subgraph_ops.contains(op)) {
      continue;
    }

    for (auto input_num : op->inputs) {
      // The operators in a graph are topologically sorted
      // and so if no previous operation has produced a tensor
      // it must be a constant.
      if (input_num == -1) {
        if (maybe_optional_arg_marker == nullptr) {
          maybe_optional_arg_marker =
              op_builder
                  .create<mlir::ConstantOp>(base_loc, builder.getNoneType(),
                                            builder.getUnitAttr())
                  .getResult();
        }
      } else if (!vals_map.at(input_num)) {
        auto& const_tensor = *subgraph.tensors[input_num];
        auto const_loc = TensorLoc(const_tensor, builder, base_loc);
        auto op_or_err =
            use_external_constant
                ? BuildExternalConstOp(const_tensor, const_tensor.buffer,
                                       op_builder, const_loc)
                : BuildConstOp(const_tensor, buffers[const_tensor.buffer]->data,
                               const_tensor.is_variable, op_builder, const_loc);
        if (!op_or_err.ok()) {
          return emitError(const_loc, op_or_err.status().ToString()),
                 op_or_err.status();
        }
        vals_map[input_num] = op_or_err.ValueOrDie()->getResult(0);
      }
    }

    // Intermediate tensors for LSTMs are used to carry quantization range
    // information in their types, so we only need to extract their types.
    std::vector<mlir::TensorType> intermediate_types;
    intermediate_types.reserve(5);
    for (auto intermediate : op->intermediates) {
      TF_ASSIGN_OR_RETURN(
          auto type, GetTensorType(*subgraph.tensors[intermediate], builder,
                                   /*shapeless_are_scalars=*/true,
                                   /*is_constant=*/false,
                                   /*is_intermediate=*/true));
      intermediate_types.emplace_back(type);
    }

    // The NameLoc corresponding to the name of the first output tensor.
    auto op_loc =
        op->outputs.empty()
            ? base_loc
            : TensorLoc(*subgraph.tensors[op->outputs[0]], builder, base_loc);
    // If there's an optional argument, maybe_optional_arg_marker has been set
    // to a valid Value.
    TF_ASSIGN_OR_RETURN(
        auto* mlir_op,
        ConvertOp(*op, vals_map, intermediate_types, maybe_optional_arg_marker,
                  op_codes, func_names, subgraph.tensors, op_loc, op_builder));

    // Add the results to the value maps. There are two cases: 1. the result
    // tensor does not have min/max values, so the original op result is used
    // directly; 2. the result tensor has some min/max values, so a stats op is
    // created, and the result of the stats op is used instead.
    for (auto pair : llvm::enumerate(mlir_op->getResults())) {
      int output_tensor_index = op->outputs[pair.index()];
      auto& tensor = *subgraph.tensors[output_tensor_index];
      if (auto stats_op =
              ConvertMinMaxToStatsOp(tensor, op_builder, pair.value())) {
        vals_map[output_tensor_index] = stats_op->getResult(0);
      } else {
        vals_map[output_tensor_index] = pair.value();
      }
    }
  }

  // Construct return values
  llvm::SmallVector<Value, 4> return_operands;
  for (auto index : func_outputs) {
    if (!vals_map.at(index)) {
      auto& const_tensor = *subgraph.tensors[index];
      auto const_loc = TensorLoc(const_tensor, builder, base_loc);
      auto op_or_err =
          use_external_constant
              ? BuildExternalConstOp(const_tensor, const_tensor.buffer,
                                     op_builder, const_loc)
              : BuildConstOp(const_tensor, buffers[const_tensor.buffer]->data,
                             const_tensor.is_variable, op_builder, const_loc);
      if (!op_or_err.ok()) {
        return emitError(const_loc, op_or_err.status().ToString()),
               op_or_err.status();
      }
      vals_map[index] = op_or_err.ValueOrDie()->getResult(0);
    }
    return_operands.push_back(vals_map[index]);
  }

  op_builder.create<mlir::ReturnOp>(base_loc, return_operands);

  return PostProcessFuncOp(func);
}

// TFLite subgraphs do not necessarily have names, though MLIR functions must
// have them, so we generate a name for subgraphs that are missing one here.
// Note: in TFLite, the first subgraph is the entry point, and in MLIR that
// represents TFLite, this entry point must be called "main".
// TODO(b/131175224,b/132239787) Support multiple entry points
std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) {
  if (index == 0) {
    return "main";
  }
  if (subgraph.name.empty()) {
    return llvm::formatv("fn_{0}", index).str();
  }
  return subgraph.name;
}

// Adds a CallOp in `region` to call `func`, and yields the results of the
// CallOp.
void AddCallOpInWhileOpRegion(mlir::Region& region, mlir::FuncOp func) {
  OpBuilder op_builder{region};
  region.push_back(new mlir::Block());
  region.addArguments(func.getType().getInputs());
  op_builder.setInsertionPointToStart(&region.front());
  auto call_op = op_builder.create<mlir::CallOp>(
      region.getLoc(), func.getType().getResults(), func.sym_name(),
      region.getArguments());
  op_builder.create<mlir::TFL::YieldOp>(region.getLoc(), call_op.getResults());
}

// TFL::WhileOp has regions, so we add CallOps to call the FuncOps in the
// regions if we have while ops.
void AddRegionsForTflWhileOp(mlir::ModuleOp module) {
  mlir::SymbolTable symbol_table(module);
  module.walk([&](mlir::TFL::WhileOp while_op) {
    auto cond = symbol_table.lookup<mlir::FuncOp>(
        while_op->getAttr("cond").cast<mlir::FlatSymbolRefAttr>().getValue());
    AddCallOpInWhileOpRegion(while_op.cond(), cond);
    while_op.removeAttr("cond");
    auto body = symbol_table.lookup<mlir::FuncOp>(
        while_op->getAttr("body").cast<mlir::FlatSymbolRefAttr>().getValue());
    AddCallOpInWhileOpRegion(while_op.body(), body);
    while_op.removeAttr("body");
  });
}
}  // namespace

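// Parses a serialized TFLite flatbuffer held in `buffer` and converts it into
// an MLIR module. A minimal usage sketch (the context/location setup below is
// illustrative, not prescriptive):
//   mlir::MLIRContext context;
//   mlir::Location loc = mlir::UnknownLoc::get(&context);
//   OwningModuleRef module = tflite::FlatBufferToMlir(
//       buffer, &context, loc, /*use_external_constant=*/false,
//       /*ordered_input_arrays=*/{}, /*ordered_output_arrays=*/{},
//       /*experimental_prune_unreachable_nodes_unconditionally=*/false);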
OwningModuleRef tflite::FlatBufferToMlir(
    absl::string_view buffer, MLIRContext* context, Location base_loc,
    bool use_external_constant,
    const std::vector<std::string>& ordered_input_arrays,
    const std::vector<std::string>& ordered_output_arrays,
    bool experimental_prune_unreachable_nodes_unconditionally) {
  context->loadDialect<
      mlir::StandardOpsDialect, mlir::quant::QuantizationDialect,
      mlir::TFL::TensorFlowLiteDialect, mlir::TF::TensorFlowDialect>();

  auto model_ptr =
      FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
  if (nullptr == model_ptr) {
    return emitError(base_loc, "couldn't parse flatbuffer"), nullptr;
  }

  std::unique_ptr<ModelT> model(model_ptr->GetModel()->UnPack());

  auto builder = Builder(context);

  std::vector<std::string> func_names;
  for (auto& subgraph : model->subgraphs) {
    func_names.push_back(subgraph->name);
  }

  auto module = mlir::ModuleOp::create(base_loc);
  // We currently don't use this to make decisions, but we could
  // use it in exports or if there are breaking changes.
  module->setAttr("tfl.schema_version",
                  builder.getI32IntegerAttr(model->version));
  if (!model->description.empty()) {
    module->setAttr("tfl.description",
                    builder.getStringAttr(model->description));
  }

  for (auto e : llvm::enumerate(model->subgraphs)) {
    auto& subgraph = e.value();
    std::string name = SubgraphName(e.index(), *subgraph);
    auto func_or_error = ConvertSubgraph(
        *subgraph, name, model->operator_codes, func_names, model->buffers,
        base_loc, builder,
        // TODO(b/131175224,b/132239787) Support multiple entry points
        /*is_entry_point=*/e.index() == 0,
        /*use_external_constant=*/use_external_constant, ordered_input_arrays,
        ordered_output_arrays,
        experimental_prune_unreachable_nodes_unconditionally);
    if (!func_or_error.ok()) {
      return emitError(base_loc, "could not translate function ")
                 << subgraph->name << ": "
                 << func_or_error.status().error_message(),
             nullptr;
    }
    module.push_back(func_or_error.ConsumeValueOrDie());
  }
  AddRegionsForTflWhileOp(module);
  return OwningModuleRef(module);
}