1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
17
18 #include <algorithm>
19 #include <cctype>
20 #include <climits>
21 #include <cstdint>
22 #include <iostream>
23 #include <sstream>
24 #include <string>
25 #include <vector>
26
27 #include "absl/base/casts.h"
28 #include "absl/container/flat_hash_map.h"
29 #include "absl/container/flat_hash_set.h"
30 #include "absl/strings/string_view.h"
31 #include "llvm/ADT/APFloat.h"
32 #include "llvm/ADT/APInt.h"
33 #include "llvm/ADT/ArrayRef.h"
34 #include "llvm/ADT/DenseMap.h"
35 #include "llvm/ADT/None.h"
36 #include "llvm/ADT/Optional.h"
37 #include "llvm/ADT/STLExtras.h"
38 #include "llvm/ADT/SmallVector.h"
39 #include "llvm/ADT/StringExtras.h"
40 #include "llvm/ADT/StringRef.h"
41 #include "llvm/ADT/Twine.h"
42 #include "llvm/Support/Casting.h"
43 #include "llvm/Support/CommandLine.h"
44 #include "llvm/Support/Endian.h"
45 #include "llvm/Support/FormatVariadic.h"
46 #include "llvm/Support/MemoryBuffer.h"
47 #include "llvm/Support/SourceMgr.h"
48 #include "llvm/Support/raw_ostream.h"
49 #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
50 #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
51 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
52 #include "mlir/IR/Attributes.h" // from @llvm-project
53 #include "mlir/IR/Builders.h" // from @llvm-project
54 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
55 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
56 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
57 #include "mlir/IR/Diagnostics.h" // from @llvm-project
58 #include "mlir/IR/Location.h" // from @llvm-project
59 #include "mlir/IR/MLIRContext.h" // from @llvm-project
60 #include "mlir/IR/Operation.h" // from @llvm-project
61 #include "mlir/IR/OperationSupport.h" // from @llvm-project
62 #include "mlir/IR/Types.h" // from @llvm-project
63 #include "mlir/IR/Value.h" // from @llvm-project
64 #include "mlir/Support/LLVM.h" // from @llvm-project
65 #include "mlir/Translation.h" // from @llvm-project
66 #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
67 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
68 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
69 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
70 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
71 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
72 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
73 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
74 #include "tensorflow/compiler/xla/statusor.h"
75 #include "tensorflow/core/framework/tensor.pb.h"
76 #include "tensorflow/core/framework/tensor_shape.pb.h"
77 #include "tensorflow/core/platform/errors.h"
78 #include "tensorflow/core/platform/status.h"
79 #include "tensorflow/lite/model.h"
80 #include "tensorflow/lite/schema/schema_generated.h"
81 #include "tensorflow/lite/schema/schema_utils.h"
82 #include "tensorflow/lite/string_util.h"
83
84 using llvm::ArrayRef;
85 using mlir::Builder;
86 using mlir::DenseElementsAttr;
87 using mlir::FuncOp;
88 using mlir::Location;
89 using mlir::MLIRContext;
90 using mlir::OpBuilder;
91 using mlir::Operation;
92 using mlir::OperationState;
93 using mlir::OwningModuleRef;
94 using mlir::RankedTensorType;
95 using mlir::UnrankedTensorType;
96 using mlir::Value;
97 using mlir::quant::QuantizedType;
98 using tflite::OperatorT;
99 using tflite::TensorT;
100 using xla::Status;
101 using xla::StatusOr;
102
103 namespace errors = tensorflow::errors;
104 namespace tfl = mlir::TFL;
105
106 namespace {
107 bool IsScalar(const TensorT& tensor) {
108 // TODO(b/138222071) We can't distinguish scalars from unranked tensors.
109 // Work out a way to handle this and stub out the code until then.
110 return tensor.shape.empty() && false;
111 }
112
113 bool IsQuantized(const TensorT& tensor) {
114 return (tensor.quantization != nullptr) &&
115 !tensor.quantization->zero_point.empty();
116 }
117
118 // Create the MLIR NamedLoc location corresponding to a given tensor
119 Location TensorLoc(const TensorT& tensor, Builder builder, Location base) {
120 if (tensor.name.empty()) {
121 return base;
122 }
123 return mlir::NameLoc::get(builder.getIdentifier(tensor.name), base);
124 }
125
126 // Create the MLIR Location corresponding to a given op. This is an
127 // experimental/debugging feature and production code should not rely on names
128 // of intermediate tensors, since the importer doesn't guarantee to preserve
129 // tensor names except for output tensors.
130 Location OpLoc(const OperatorT& op,
131 const std::vector<std::unique_ptr<tflite::TensorT>>& tensors,
132 Builder builder, Location base) {
133 if (op.outputs.empty()) return base;
134
135 llvm::SmallVector<Location, 4> locations;
136 locations.reserve(op.outputs.size());
137 for (auto tensor_index : op.outputs) {
138 locations.push_back(TensorLoc(*tensors[tensor_index], builder, base));
139 }
140 return mlir::FusedLoc::get(builder.getContext(), locations);
141 }
142
143 // Returns the correct type for a quantized tensor
144 // We have a special case for constants since they have a higher minimum value.
145 StatusOr<QuantizedType> GetQuantizedType(const TensorT& tensor, Builder builder,
146 bool is_constant = false) {
147 tflite::QuantizationParametersT& quant_params = *tensor.quantization;
148 if (quant_params.details.AsCustomQuantization()) {
149 return errors::Unimplemented("Cannot handle experimental quantization");
150 }
151
152 bool is_signed = true;
153 mlir::IntegerType storage_type;
154 if (tensor.type == tflite::TensorType_UINT8) {
155 is_signed = false;
156 storage_type = builder.getIntegerType(8);
157 } else {
158 auto raw_elem_type = ConvertElementType(tensor.type, builder);
159 if (!raw_elem_type.isa<mlir::IntegerType>()) {
160 return errors::InvalidArgument(
161 "Quantized tensors must be stored as integers");
162 }
163 storage_type = raw_elem_type.cast<mlir::IntegerType>();
164 }
165
166 // TFLite uses narrow-range [u]int8 for constant buffers of quantized weights.
167 // Since we don't know which ones are weights, we represent this optimization
168 // as a change in the storage bounds for the type for all constants of this
169 // type.
170 bool is_weight_buffer = is_constant && (storage_type.getWidth() == 8);
171
172 int64_t storage_min = QuantizedType::getDefaultMinimumForInteger(
173 is_signed, storage_type.getWidth()) +
174 static_cast<int>(is_weight_buffer);
175 int64_t storage_max = QuantizedType::getDefaultMaximumForInteger(
176 is_signed, storage_type.getWidth());
177 uint32_t flags =
178 is_signed ? mlir::quant::QuantizationFlags::FlagValue::Signed : 0;
179
180 // Rejects if quantized tensors have zero scales.
181 for (float scale : quant_params.scale) {
182 if (scale == 0) {
183 return errors::InvalidArgument(
184 "Quantized tensors must have non-zero scales");
185 }
186 }
187
188 // Scale size can't be zero as it is checked before.
189 if (quant_params.scale.size() != 1) {
190 llvm::SmallVector<double, 4> scales(quant_params.scale.begin(),
191 quant_params.scale.end());
192 return mlir::quant::UniformQuantizedPerAxisType::get(
193 flags, storage_type, builder.getF32Type(), scales,
194 quant_params.zero_point, quant_params.quantized_dimension, storage_min,
195 storage_max);
196 }
197 return mlir::quant::UniformQuantizedType::get(
198 flags, storage_type, builder.getF32Type(), quant_params.scale.at(0),
199 quant_params.zero_point.at(0), storage_min, storage_max);
200 }
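// Example (illustrative): for a signed int8 constant weight tensor,
// is_weight_buffer is true, so the storage bounds become the narrow range
// [-127, 127] rather than the default [-128, 127] used for non-constant int8
// tensors; the scale/zero_point values are taken verbatim from the flatbuffer
// quantization parameters.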
201
202 // Imports a float tensor with calibration values into a calibrated quantized type.
203 StatusOr<QuantizedType> GetCalibratedQuantizedType(const TensorT& tensor,
204 Builder builder) {
205 if (tensor.quantization == nullptr) {
206 return errors::InvalidArgument("The tensor is not quantized.");
207 }
208 auto raw_elem_type = ConvertElementType(tensor.type, builder);
209 float min = tensor.quantization->min[0];
210 float max = tensor.quantization->max[0];
211 return mlir::quant::CalibratedQuantizedType::get(raw_elem_type, min, max);
212 }
213
214 // TODO(b/138222071) Remove shapeless_are_scalars once we can reliably
215 // make that distinction and don't have to rely on context
216 // (input to main and constants must have static shape)
217 StatusOr<mlir::TensorType> GetTensorType(const TensorT& tensor, Builder builder,
218 bool shapeless_are_scalars = false,
219 bool is_constant = false,
220 bool is_intermediate = false) {
221 mlir::Type elem_type = ConvertElementType(tensor.type, builder);
222 if (IsQuantized(tensor)) {
223 TF_ASSIGN_OR_RETURN(elem_type,
224 GetQuantizedType(tensor, builder, is_constant));
225 }
226
227 // Intermediate tensors with calibration values (but not scale and zero points)
228 // should return a calibrated quantized type.
229 if (is_intermediate && tensor.quantization != nullptr &&
230 !IsQuantized(tensor)) {
231 TF_ASSIGN_OR_RETURN(elem_type, GetCalibratedQuantizedType(tensor, builder));
232 }
233
234 if (IsScalar(tensor) || (shapeless_are_scalars && tensor.shape.empty())) {
235 return RankedTensorType::get({}, elem_type);
236 }
237
238 if (!tensor.shape_signature.empty()) {
239 llvm::SmallVector<int64_t, 4> shape(tensor.shape_signature.begin(),
240 tensor.shape_signature.end());
241 return RankedTensorType::get(shape, elem_type);
242 }
243
244 if (!tensor.shape.empty()) {
245 llvm::SmallVector<int64_t, 4> shape(tensor.shape.begin(),
246 tensor.shape.end());
247 return RankedTensorType::get(shape, elem_type);
248 }
249
250 return UnrankedTensorType::get(elem_type);
251 }
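// Example (illustrative): a float32 tensor with shape = {2, 3} becomes
// tensor<2x3xf32>; an empty shape with shapeless_are_scalars == true becomes
// the scalar type tensor<f32>; an empty shape otherwise falls through to the
// unranked type tensor<*xf32>. When shape_signature is present, it takes
// precedence over shape.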
252
253 // Extracts the min/max information from the tensor and creates the quant stats op.
254 // If the input `tensor` has scale/zero_point, `res` should have quantized
255 // type, so no stats op is required and nullptr is returned.
256 // If the min/max information is invalid, nullptr is returned.
257 mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b,
258 Value res) {
259 // If the `tensor` has scale/zero_point, it must have been quantized, so the
260 // min/max stats are only informational; ignore them.
261 if (!tensor.quantization || IsQuantized(tensor)) return nullptr;
262 // If the result isn't float, it can't be quantized, so the min/max is ignored.
263 if (!res.getType()
264 .cast<mlir::ShapedType>()
265 .getElementType()
266 .isa<mlir::FloatType>()) {
267 return nullptr;
268 }
269 auto mins = tensor.quantization->min;
270 auto maxs = tensor.quantization->max;
271 if (mins.size() != maxs.size() || mins.empty()) return nullptr;
272
273 llvm::SmallVector<llvm::APFloat, 4> min_maxs;
274 min_maxs.reserve(mins.size() * 2);
275 for (int i = 0, end = mins.size(); i < end; ++i) {
276 llvm::APFloat min(mins[i]);
277 llvm::APFloat max(maxs[i]);
278 min_maxs.push_back(min);
279 min_maxs.push_back(max);
280 }
281 // The layer stats contain only the first min/max pair.
282 mlir::ElementsAttr layer_stats = mlir::DenseFPElementsAttr::get(
283 mlir::RankedTensorType::get({2}, b.getF32Type()),
284 {min_maxs[0], min_maxs[1]});
285 mlir::ElementsAttr axis_stats;
286 mlir::IntegerAttr axis;
287 if (mins.size() > 1) {
288 llvm::SmallVector<int64_t, 4> axis_stats_shape{
289 static_cast<int64_t>(mins.size()), 2};
290 axis_stats = mlir::DenseFPElementsAttr::get(
291 mlir::RankedTensorType::get(axis_stats_shape, b.getF32Type()),
292 min_maxs);
293 // TODO(fengliuai): this quantization dimension isn't correct.
294 axis = b.getI64IntegerAttr(tensor.quantization->quantized_dimension);
295 }
296 return b.create<mlir::quant::StatisticsOp>(b.getUnknownLoc(), res,
297 layer_stats, axis_stats, axis);
298 }
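// Example (illustrative): a float tensor with min = {-1.0} and max = {1.0}
// produces a "quant.stats" op whose layer statistics hold [-1.0, 1.0] as a
// tensor<2xf32>; with per-channel min/max lists of length N, the axis
// statistics additionally hold an N x 2 tensor.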
299
300 // Returns true if this is a basic LSTM op.
301 bool IsBasicLSTMOp(tflite::BuiltinOptionsUnion op_union) {
302 if (const auto* op = op_union.AsLSTMOptions()) {
303 return op->kernel_type == tflite::LSTMKernelType_BASIC;
304 } else {
305 return false;
306 }
307 }
308
309 // Gets the MLIR op name with the dialect name for the flatbuffer operator.
310 StatusOr<std::string> GetMlirOpName(const tflite::OperatorT& op,
311 const tflite::OperatorCodeT& op_code) {
312 if (IsBasicLSTMOp(op.builtin_options)) {
313 return std::string("tfl.basic_lstm");
314 }
315
316 auto builtin_code = tflite::GetBuiltinCode(&op_code);
317 if (builtin_code == tflite::BuiltinOperator_CUSTOM) {
318 return std::string("tfl.custom");
319 }
320 if (builtin_code == tflite::BuiltinOperator_IF) {
321 return std::string("tf.If");
322 }
323 if (builtin_code == tflite::BuiltinOperator_WHILE) {
324 return std::string("tfl.while");
325 }
326
327 llvm::StringRef op_name(tflite::EnumNameBuiltinOperator(builtin_code));
328 return llvm::Twine("tfl.", op_name.lower()).str();
329 }
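// Example (illustrative): BuiltinOperator_CONV_2D maps to "tfl.conv_2d" via
// the generic lower-cased-enum rule, while the special cases above map custom
// ops to "tfl.custom" and the IF operator to "tf.If".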
330
331 // The buffers in TFLite flatbuffers have their contents stored as a vector of
332 // bytes that represent little-endian values.
333 // The element width is taken from sizeof(T), which allows reading both
334 // float16 and float32 buffers without a case split.
335 template <typename T>
336 std::vector<T> ReadAsLittleEndian(ArrayRef<uint8_t> bytes) {
337 std::vector<T> ret;
338 size_t read_size = sizeof(T);
339 int bytes_len = bytes.size();
340 assert(bytes_len % read_size == 0);
341
342 int elem_count = bytes_len / read_size;
343 ret.reserve(elem_count);
344
345 const char* data_ptr = reinterpret_cast<const char*>(bytes.data());
346 for (int i = 0; i < elem_count; i++) {
347 ret.push_back(
348 llvm::support::endian::readNext<T, llvm::support::little,
349 llvm::support::unaligned>(data_ptr));
350 }
351 return ret;
352 }
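// Example usage (illustrative): the little-endian buffer
//   std::vector<uint8_t> buf = {0xD2, 0x04, 0x00, 0x00};
//   auto vals = ReadAsLittleEndian<uint32_t>(buf);
// yields a single element equal to 1234 (0x000004D2).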
353
354 tensorflow::TensorProto ConvertTfliteConstTensor(
355 const tflite::TensorT& tensor, const std::vector<uint8_t>& buffer) {
356 tensorflow::TensorProto ret;
357 ret.set_dtype(TflTypeToTfType(tensor.type));
358
359 tensorflow::TensorShapeProto* shape = ret.mutable_tensor_shape();
360 shape->set_unknown_rank(false);
361 for (auto dim : tensor.shape) {
362 shape->add_dim()->set_size(int64_t{dim});
363 }
364 // TensorFlow Lite uses tflite::DynamicBuffer to encode a vector of strings.
365 if (tensor.type == tflite::TensorType_STRING) {
366 for (int i = 0; i < tflite::GetStringCount(buffer.data()); ++i) {
367 tflite::StringRef str = tflite::GetString(buffer.data(), i);
368 ret.add_string_val(str.str, str.len);
369 }
370 return ret;
371 }
372 std::string content;
373 content.assign(reinterpret_cast<const char*>(buffer.data()), buffer.size());
374 ret.set_tensor_content(content);
375 return ret;
376 }
377
378 StatusOr<mlir::ElementsAttr> ConvertFloatBuffer(
379 mlir::RankedTensorType shaped_type, mlir::FloatType elem_type,
380 const std::vector<uint8_t>& buffer) {
381 size_t bytes_len = buffer.size();
382
383 // The bytes of floats are stored little-endian.
384 switch (elem_type.getWidth()) {
385 case 16: {
386 assert(bytes_len % 2 == 0);
387 int elem_count = bytes_len / 2;
388 std::vector<llvm::APFloat> values;
389 values.reserve(elem_count);
390
391 const char* data = reinterpret_cast<const char*>(buffer.data());
392 auto& semantics = elem_type.getFloatSemantics();
393
394 for (int i = 0; i < elem_count; i++) {
395 uint16_t bit_repr =
396 llvm::support::endian::readNext<uint16_t, llvm::support::little,
397 llvm::support::unaligned>(data);
398 llvm::APInt int_repr(16, bit_repr);
399 values.emplace_back(semantics, int_repr);
400 }
401
402 return DenseElementsAttr::get(shaped_type, values);
403 }
404 case 32: {
405 assert(bytes_len % 4 == 0);
406 int elem_count = bytes_len / 4;
407 std::vector<float> values;
408 values.reserve(elem_count);
409
410 const char* data = reinterpret_cast<const char*>(buffer.data());
411
412 for (int i = 0; i < elem_count; i++) {
413 uint32_t bit_repr =
414 llvm::support::endian::readNext<uint32_t, llvm::support::little,
415 llvm::support::unaligned>(data);
416 values.push_back(absl::bit_cast<float>(bit_repr));
417 }
418 return DenseElementsAttr::get(shaped_type, ArrayRef<float>(values));
419 }
420 case 64: {
421 assert(bytes_len % 8 == 0);
422 int elem_count = bytes_len / 8;
423 std::vector<double> values;
424 values.reserve(elem_count);
425
426 const char* data = reinterpret_cast<const char*>(buffer.data());
427
428 for (int i = 0; i < elem_count; i++) {
429 uint64_t bit_repr =
430 llvm::support::endian::readNext<uint64_t, llvm::support::little,
431 llvm::support::unaligned>(data);
432 values.push_back(absl::bit_cast<double>(bit_repr));
433 }
434 return DenseElementsAttr::get(shaped_type, ArrayRef<double>(values));
435 }
436 }
437 return errors::InvalidArgument("unsupported bit width", elem_type.getWidth());
438 }
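// Example (illustrative): a float32 buffer containing the little-endian bytes
// {0x00, 0x00, 0x80, 0x3F} (bit pattern 0x3F800000) converts to a dense
// elements attribute holding the single value 1.0f.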
439
440 StatusOr<mlir::ElementsAttr> ConvertIntBuffer(
441 mlir::RankedTensorType shaped_type, mlir::Type elem_type,
442 const std::vector<uint8_t>& buffer) {
443 unsigned bit_width;
444 if (auto itype = elem_type.dyn_cast<mlir::IntegerType>()) {
445 bit_width = itype.getWidth();
446 } else if (auto qtype = elem_type.dyn_cast<QuantizedType>()) {
447 bit_width = qtype.getStorageTypeIntegralWidth();
448 shaped_type = mlir::RankedTensorType::get(shaped_type.getShape(),
449 qtype.getStorageType());
450 } else {
451 return errors::InvalidArgument("unsupported integer constant type");
452 }
453
454 switch (bit_width) {
455 case 1: {
456 // vector<bool> doesn't convert to an ArrayRef
457 llvm::SmallVector<bool, 8> values;
458 values.reserve(buffer.size());
459 for (auto b : buffer) {
460 values.emplace_back(b != 0);
461 }
462 return DenseElementsAttr::get(shaped_type, ArrayRef<bool>(values));
463 }
464 case 8: {
465 return DenseElementsAttr::get(shaped_type, ArrayRef<uint8_t>(buffer));
466 }
467 case 16: {
468 auto values = ReadAsLittleEndian<uint16_t>(buffer);
469 return DenseElementsAttr::get(shaped_type, ArrayRef<uint16_t>(values));
470 }
471 case 32: {
472 auto values = ReadAsLittleEndian<uint32_t>(buffer);
473 return DenseElementsAttr::get(shaped_type, ArrayRef<uint32_t>(values));
474 }
475 case 64: {
476 auto values = ReadAsLittleEndian<uint64_t>(buffer);
477 return DenseElementsAttr::get(shaped_type, ArrayRef<uint64_t>(values));
478 }
479 default:
480 return errors::Unimplemented("Cannot handle bit width ", bit_width);
481 }
482 }
483
484 StatusOr<Operation*> BuildExternalConstOp(const tflite::TensorT& tensor,
485 int32_t buffer_index,
486 OpBuilder builder, Location loc) {
487 TF_ASSIGN_OR_RETURN(auto type, GetTensorType(tensor, builder,
488 /*shapeless_are_scalars=*/true,
489 /*is_constant=*/true));
490 auto shaped_type = type.dyn_cast<mlir::RankedTensorType>();
491 if (!shaped_type) {
492 return errors::Internal("Constant doesn't have a shape");
493 }
494 auto op = builder.create<tfl::ExternalConstOp>(
495 loc, shaped_type, builder.getI32IntegerAttr(buffer_index));
496 return op.getOperation();
497 }
498
499 // Gets a constant splat attribute of the given type. Requires the type to be a
500 // statically shaped RankedTensorType. `unique_index` is used as the unique
501 // value for the attribute.
502 static mlir::ElementsAttr GetSplat(RankedTensorType type, int unique_index,
503 OpBuilder builder) {
504 mlir::Type element_ty = getElementTypeOrSelf(type);
505
506 if (element_ty.isSignlessInteger())
507 return DenseElementsAttr::get(
508 type, builder.getIntegerAttr(element_ty, unique_index));
509
510 if (element_ty.isa<mlir::FloatType>())
511 return DenseElementsAttr::get(
512 type, builder.getFloatAttr(element_ty, unique_index));
513
514 if (auto qtype = element_ty.dyn_cast<QuantizedType>()) {
515 mlir::RankedTensorType new_type =
516 RankedTensorType::get(type.getShape(), qtype.getStorageType());
517 return DenseElementsAttr::get(
518 new_type, builder.getIntegerAttr(qtype.getStorageType(), unique_index));
519 }
520 llvm_unreachable("unhandled element type");
521 }
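// Example (illustrative): GetSplat(tensor<2x2xi32>, /*unique_index=*/3, b)
// yields the splat attribute dense<3> : tensor<2x2xi32>; for a quantized
// element type the splat is built over the storage type instead, e.g.
// dense<3> : tensor<2x2xi8>.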
522
523 // TODO(b/172664358): Create a new op instead of reusing the constant op.
524 // Creates a constant op to represent a stateful variable. The function-static
525 // variable `stateful_variable_idx` is used as a unique value for each constant
526 // so they don't get CSE'd away. `tensor` is the flatbuffer tensor data
527 // structure. `shaped_type` is the ShapedType for the const op.
528 Operation* BuildVariableOp(const tflite::TensorT& tensor,
529 mlir::RankedTensorType shaped_type,
530 OpBuilder builder, Location loc) {
531 static int stateful_variable_idx = 0;
532 mlir::ElementsAttr value =
533 GetSplat(shaped_type, stateful_variable_idx++, builder);
534 if (IsQuantized(tensor)) {
535 auto op = builder.create<tfl::QConstOp>(
536 loc, mlir::TypeAttr::get(shaped_type), value);
537 return op.getOperation();
538 }
539 auto op = builder.create<tfl::ConstOp>(loc, value);
540 if (tensor.quantization && !tensor.quantization->min.empty()) {
541 if (auto stats_op =
542 ConvertMinMaxToStatsOp(tensor, builder, op.getResult())) {
543 return stats_op;
544 }
545 }
546 return op.getOperation();
547 }
548
549 static StatusOr<mlir::ArrayAttr> ConvertSparseIndexVector(
550 const tflite::SparseIndexVectorUnion& sparse_index_vector,
551 OpBuilder& builder) {
552 if (sparse_index_vector.type == tflite::SparseIndexVector_Int32Vector) {
553 return builder.getI32ArrayAttr(sparse_index_vector.AsInt32Vector()->values);
554 } else if (sparse_index_vector.type ==
555 tflite::SparseIndexVector_Uint16Vector) {
556 const auto& inputs = sparse_index_vector.AsUint16Vector()->values;
557 std::vector<int32_t> outputs(inputs.size());
558 std::transform(inputs.begin(), inputs.end(), outputs.begin(),
559 [](auto x) { return static_cast<int32_t>(x); });
560 return builder.getI32ArrayAttr(outputs);
561 } else if (sparse_index_vector.type ==
562 tflite::SparseIndexVector_Uint8Vector) {
563 const auto& inputs = sparse_index_vector.AsUint8Vector()->values;
564 std::vector<int32_t> outputs(inputs.size());
565 std::transform(inputs.begin(), inputs.end(), outputs.begin(),
566 [](auto x) { return static_cast<int32_t>(x); });
567 return builder.getI32ArrayAttr(outputs);
568 } else {
569 return errors::Unimplemented("Unsupported SparseIndexVector type");
570 }
571 }
572
573 static StatusOr<Operation*> BuildSparseConstOp(
574 const tflite::TensorT& tensor, const std::vector<uint8_t>& buffer,
575 const mlir::RankedTensorType shaped_type, OpBuilder& builder,
576 Location loc) {
577 tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer);
578 repr.clear_tensor_shape();
579 if (IsQuantized(tensor)) {
580 repr.mutable_tensor_shape()->add_dim()->set_size(buffer.size());
581 repr.set_dtype(tensorflow::DT_INT8);
582 } else {
583 repr.mutable_tensor_shape()->add_dim()->set_size(
584 buffer.size() / (shaped_type.getElementTypeBitWidth() / CHAR_BIT));
585 }
586 TF_ASSIGN_OR_RETURN(mlir::ElementsAttr compressed_data,
587 tensorflow::ConvertTensorProto(repr, &builder));
588
589 const int dim_metadata_size = tensor.sparsity->dim_metadata.size();
590 std::vector<mlir::Attribute> dim_metadata(dim_metadata_size);
591 for (int i = 0; i < dim_metadata_size; i++) {
592 if (tensor.sparsity->dim_metadata[i]->format ==
593 tflite::DimensionType_DENSE) {
594 dim_metadata[i] = tfl::DimensionMetadataAttr::get(
595 builder.getStringAttr("DENSE"),
596 builder.getI32IntegerAttr(
597 tensor.sparsity->dim_metadata[i]->dense_size),
598 builder.getI32ArrayAttr({}), builder.getI32ArrayAttr({}),
599 builder.getContext());
600 } else if (tensor.sparsity->dim_metadata[i]->format ==
601 tflite::DimensionType_SPARSE_CSR) {
602 TF_ASSIGN_OR_RETURN(
603 mlir::ArrayAttr segments,
604 ConvertSparseIndexVector(
605 tensor.sparsity->dim_metadata[i]->array_segments, builder));
606 TF_ASSIGN_OR_RETURN(
607 mlir::ArrayAttr indices,
608 ConvertSparseIndexVector(
609 tensor.sparsity->dim_metadata[i]->array_indices, builder));
610 dim_metadata[i] = tfl::DimensionMetadataAttr::get(
611 builder.getStringAttr("SPARSE_CSR"), builder.getI32IntegerAttr(0),
612 segments, indices, builder.getContext());
613 } else {
614 return errors::Unimplemented("Unsupported dimension metadata type");
615 }
616 }
617 auto s_param = tfl::SparsityParameterAttr::get(
618 builder.getI32ArrayAttr(tensor.sparsity->traversal_order),
619 builder.getI32ArrayAttr(tensor.sparsity->block_map),
620 builder.getArrayAttr(dim_metadata), builder.getContext());
621
622 auto value_type = shaped_type;
623 if (IsQuantized(tensor)) {
624 value_type = RankedTensorType::get(
625 shaped_type.getShape(), shaped_type.getElementType()
626 .dyn_cast<mlir::quant::QuantizedType>()
627 .getStorageType());
628 }
629 std::vector<char> dense_buffer(
630 value_type.getElementType().getIntOrFloatBitWidth() / CHAR_BIT);
631 mlir::Attribute dummy_value =
632 mlir::DenseIntOrFPElementsAttr::getFromRawBuffer(value_type, dense_buffer,
633 /*isSplatBuffer=*/true);
634
635 if (IsQuantized(tensor)) {
636 return builder
637 .create<tfl::SparseQConstOp>(loc, mlir::TypeAttr::get(shaped_type),
638 dummy_value, s_param, compressed_data)
639 .getOperation();
640 }
641 return builder
642 .create<tfl::SparseConstOp>(loc, dummy_value, s_param, compressed_data)
643 .getOperation();
644 }
645
646 StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
647 const std::vector<uint8_t>& buffer,
648 bool is_variable, OpBuilder builder,
649 Location loc) {
650 TF_ASSIGN_OR_RETURN(auto type, GetTensorType(tensor, builder,
651 /*shapeless_are_scalars=*/true,
652 /*is_constant=*/true));
653 auto shaped_type = type.dyn_cast<mlir::RankedTensorType>();
654 if (!shaped_type) {
655 return errors::Internal("Constant doesn't have a shape");
656 }
657
658 if (tensor.sparsity != nullptr) {
659 return BuildSparseConstOp(tensor, buffer, shaped_type, builder, loc);
660 }
661
662 auto elem_type = shaped_type.getElementType();
663
664 mlir::ElementsAttr value;
665 if (is_variable) {
666 return BuildVariableOp(tensor, shaped_type, builder, loc);
667 } else if (auto float_type = elem_type.dyn_cast<mlir::FloatType>()) {
668 TF_ASSIGN_OR_RETURN(value,
669 ConvertFloatBuffer(shaped_type, float_type, buffer));
670 } else if (elem_type.isa<mlir::IntegerType, QuantizedType>()) {
671 TF_ASSIGN_OR_RETURN(value,
672 ConvertIntBuffer(shaped_type, elem_type, buffer));
673 } else if (elem_type.isa<mlir::TF::StringType>()) {
674 tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer);
675 std::vector<llvm::StringRef> refs;
676 refs.reserve(repr.string_val_size());
677
678 for (const auto& ref : repr.string_val())
679 refs.push_back({ref.data(), ref.size()});
680
681 value = mlir::DenseStringElementsAttr::get(shaped_type, refs);
682 } else if (elem_type.isa<mlir::ComplexType, mlir::TF::TensorFlowType>()) {
683 auto dialect = elem_type.getContext()->getLoadedDialect("tf");
684 tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer);
685 std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
686
687 value = mlir::OpaqueElementsAttr::get(dialect, shaped_type, mangled);
688 } else {
689 return errors::Unimplemented("Constant of unsupported type");
690 }
691
692 if (IsQuantized(tensor)) {
693 auto op = builder.create<tfl::QConstOp>(
694 loc, mlir::TypeAttr::get(shaped_type), value);
695 return op.getOperation();
696 }
697 auto op = builder.create<tfl::ConstOp>(loc, value);
698 return op.getOperation();
699 }
700
701 llvm::SmallVector<mlir::NamedAttribute, 4> ConvertSubgraphIdxsToFunctionAttrs(
702 tflite::BuiltinOptionsUnion options,
703 const std::vector<std::string>& func_names, Builder builder) {
704 if (auto* opts = options.AsCallOnceOptions()) {
705 uint32_t init_idx = opts->init_subgraph_index;
706 auto init_attr = builder.getStringAttr(func_names.at(init_idx));
707
708 return {builder.getNamedAttr("session_init_function", init_attr)};
709 }
710 if (auto* opts = options.AsIfOptions()) {
711 uint32_t then_idx = opts->then_subgraph_index;
712 auto then_attr = builder.getSymbolRefAttr(func_names.at(then_idx));
713 uint32_t else_idx = opts->else_subgraph_index;
714 auto else_attr = builder.getSymbolRefAttr(func_names.at(else_idx));
715
716 return {builder.getNamedAttr("then_branch", then_attr),
717 builder.getNamedAttr("else_branch", else_attr),
718 // TODO(b/139667752): Analyze statelessness correctly
719 builder.getNamedAttr("is_stateless", builder.getBoolAttr(false))};
720 }
721 if (auto* opts = options.AsWhileOptions()) {
722 uint32_t cond_idx = opts->cond_subgraph_index;
723 auto cond_attr = builder.getSymbolRefAttr(func_names.at(cond_idx));
724 uint32_t body_idx = opts->body_subgraph_index;
725 auto body_attr = builder.getSymbolRefAttr(func_names.at(body_idx));
726
727 return {builder.getNamedAttr("cond", cond_attr),
728 builder.getNamedAttr("body", body_attr)};
729 }
730 return {};
731 }
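// Example (illustrative): WhileOptions with cond_subgraph_index = 1 and
// body_subgraph_index = 2 become "cond" and "body" symbol-reference
// attributes pointing at func_names[1] and func_names[2], which ConvertOp
// attaches to the generated "tfl.while" operation.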
732
733 Status AddOpIntermediatesForLstm(
734 const tflite::OperatorT& op,
735 const std::vector<mlir::TensorType>& intermediate_types,
736 OperationState& op_state, Location loc, OpBuilder& builder) {
737 if (!op.intermediates.empty()) {
738 if (op.intermediates.size() != 5) {
739 auto err = errors::InvalidArgument(
740 "operator has intermediate tensors but the number of them is not "
741 "five.");
742 return emitError(loc, err.ToString()), err;
743 }
744 // Create intermediate value
745
746 const llvm::SmallVector<llvm::StringRef, 5> kIntermediateNames = {
747 "input_to_input_intermediate", "input_to_forget_intermediate",
748 "input_to_cell_intermediate", "input_to_output_intermediate",
749 "effective_hidden_scale_intermediate"};
750 for (auto type_and_name :
751 llvm::zip(intermediate_types, kIntermediateNames)) {
752 mlir::TypeAttr type_attr =
753 mlir::TypeAttr::get(std::get<0>(type_and_name));
754 auto named_attr =
755 builder.getNamedAttr(std::get<1>(type_and_name), type_attr);
756 op_state.addAttribute(named_attr.first, named_attr.second);
757 }
758 }
759 return Status::OK();
760 }
761
762 // TODO(krzysd) Handle function calls
763 StatusOr<Operation*> ConvertOp(
764 const tflite::OperatorT& op, const std::vector<Value>& vals_map,
765 const std::vector<mlir::TensorType>& intermediate_types,
766 Value optional_arg_marker,
767 const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& op_codes,
768 const std::vector<std::string>& func_names,
769 const std::vector<std::unique_ptr<tflite::TensorT>>& tensors, Location loc,
770 OpBuilder builder) {
771 llvm::SmallVector<Value, 4> operands;
772 llvm::SmallVector<mlir::Type, 2> outputTypes;
773
774 const tflite::OperatorCodeT& op_code = *op_codes.at(op.opcode_index);
775
776 TF_ASSIGN_OR_RETURN(const std::string op_name, GetMlirOpName(op, op_code));
777
778 OperationState op_state(loc, op_name);
779
780 for (auto input_num : op.inputs) {
781 if (input_num == -1) {
782 assert(optional_arg_marker != nullptr);
783 op_state.addOperands({optional_arg_marker});
784 } else {
785 op_state.addOperands({vals_map.at(input_num)});
786 }
787 }
788
789 for (auto output_num : op.outputs) {
790 auto& tensor = *tensors.at(output_num);
791 auto type_or_err = GetTensorType(tensor, builder);
792 if (!type_or_err.ok()) {
793 return emitError(loc, type_or_err.status().ToString()),
794 type_or_err.status();
795 }
796 auto type = type_or_err.ConsumeValueOrDie();
797
798 if (op_name == "tfl.quantize") {
799 // Special case for quantize: return type must also be in qtype attribute
800 op_state.addAttribute("qtype", mlir::TypeAttr::get(type));
801 } else if (op_name == "tfl.reshape" && op_state.operands.size() == 1) {
802 // Special case for reshape: the second operand is optional in the old
803 // converter and kernel, so we create the second operand, which is
804 // required by the new converter, from the reshape op's options.
805 auto new_shape = op.builtin_options.AsReshapeOptions()->new_shape;
806 auto shape_type = RankedTensorType::get(
807 {static_cast<int64_t>(new_shape.size())}, builder.getIntegerType(32));
808
809 mlir::SmallVector<mlir::Attribute, 4> shape;
810 for (auto s : new_shape) {
811 shape.push_back(builder.getI32IntegerAttr(static_cast<int32_t>(s)));
812 }
813 auto output_shape = DenseElementsAttr::get(shape_type, shape);
814 auto shape_op = builder.create<tfl::ConstOp>(loc, output_shape);
815 op_state.addOperands({shape_op});
816 }
817
818 op_state.addTypes({type});
819 }
820
821 // Since the last several tensors could be optional tensors for a tfl op, the
822 // number of input operands can vary. Gets the min/max number of
823 // operands from the tflite op name.
824 // Also, since the above code special-handles the `tfl.reshape` op and adds an
825 // additional input, we put this block of logic here.
826 llvm::MinMax input_min_max = mlir::OperandNumbersMinMax(op_name);
827 int input_max_num = input_min_max.Max;
828 int op_input_num = op_state.operands.size();
829 if (input_max_num != 0 && input_max_num > op_input_num) {
830 // If the number of current inputs is less than the op definition requires,
831 // fill in with `none` values.
832 llvm::SmallVector<Value, 4> none_operands(
833 input_max_num - op_input_num,
834 builder.create<mlir::ConstantOp>(loc, builder.getNoneType(),
835 builder.getUnitAttr()));
836 op_state.addOperands(ArrayRef<Value>(none_operands));
837 }
838
839 if (op_name == "tfl.lstm") {
840 // TODO(b/147587779): add the right region if region is empty.
841 op_state.addRegion();
842 TF_CHECK_OK(AddOpIntermediatesForLstm(op, intermediate_types, op_state, loc,
843 builder));
844 }
845 if (op_name == "tfl.while") {
846 // Adds two empty regions for "tfl.while". We will fill the regions after
847 // creating the callee functions because the "tfl.while" input/output types
848 // may differ from the callee function types, and the call ops need to sync
849 // with the callee function types.
850 op_state.addRegion();
851 op_state.addRegion();
852 }
853 if (op_name == "tfl.unidirectional_sequence_lstm") {
854 TF_CHECK_OK(AddOpIntermediatesForLstm(op, intermediate_types, op_state, loc,
855 builder));
856 }
857 if (op_name == "tfl.reshape") {
858 // Flatten the shape operand of reshape ops when a shape with more than one dimension is given.
859 mlir::DenseIntElementsAttr shape_attr;
860 if (matchPattern(op_state.operands[1], m_Constant(&shape_attr))) {
861 auto shape_ty =
862 op_state.operands[1].getType().dyn_cast<RankedTensorType>();
863 if (shape_ty != nullptr && shape_ty.hasRank() && shape_ty.getRank() > 1) {
864 llvm::SmallVector<mlir::Attribute, 4> shape;
865 int32_t dim_size = 0;
866 for (const auto& dim : llvm::enumerate(shape_attr.getIntValues())) {
867 const int64_t size = dim.value().getSExtValue();
868 shape.push_back(
869 builder.getI32IntegerAttr(static_cast<int32_t>(size)));
870 ++dim_size;
871 }
872 auto shape_type = RankedTensorType::get(
873 {static_cast<int32_t>(dim_size)}, builder.getIntegerType(32));
874 auto output_shape = mlir::DenseElementsAttr::get(shape_type, shape);
875 auto shape_op = builder.create<tfl::ConstOp>(loc, output_shape);
876 op_state.operands[1] = shape_op;
877 }
878 }
879 }
880
881 llvm::SmallVector<mlir::NamedAttribute, 2> attrs;
882 auto builtin_code = tflite::GetBuiltinCode(&op_code);
883 if (builtin_code == tflite::BuiltinOperator_CUSTOM) {
884 auto status = mlir::CustomOptionsToAttributes(
885 op_code.custom_code, op.custom_options, builder, loc, &attrs);
886 if (!status.ok()) {
887 return emitError(loc, status.ToString()), status;
888 }
889 } else {
890 mlir::BuiltinOptionsToAttributes(op.builtin_options, builder, attrs);
891 }
892 op_state.addAttributes(attrs);
893
894 // Handle the conversion from subgraph index to functions for If and While. We
895 // will add CallOps in the region to call the functions later for While.
896 auto function_ref_attrs = ConvertSubgraphIdxsToFunctionAttrs(
897 op.builtin_options, func_names, builder);
898 op_state.addAttributes(function_ref_attrs);
899
900 return builder.createOperation(op_state);
901 }
902
903 // Returns indices of the given tensors in the subgraph. Returns error if a
904 // tensor name cannot be found in the subgraph.
905 StatusOr<std::vector<int>> GetTensorIndices(
906 const tflite::SubGraphT& subgraph,
907 const std::vector<std::string>& tensor_names) {
908 absl::flat_hash_map<std::string, int> name_to_index;
909 for (auto index_and_tensor : llvm::enumerate(subgraph.tensors)) {
910 name_to_index[index_and_tensor.value()->name] = index_and_tensor.index();
911 }
912
913 std::vector<int> indices;
914 indices.reserve(tensor_names.size());
915
916 for (const auto& name : tensor_names) {
917 auto found = name_to_index.find(name);
918 if (found != name_to_index.end()) {
919 indices.push_back(found->second);
920 } else {
921 return errors::InvalidArgument("could not find tensor in subgraph: ",
922 name);
923 }
924 }
925
926 return indices;
927 }
928
929 // Given a list of tensor indices, returns a string of concatenated tensor names
930 // wrapped in a NamedAttribute.
931 template <typename ContainerType>
932 mlir::NamedAttribute BuildTFEntryFunctionAttribute(
933 const tflite::SubGraphT& subgraph, Builder* builder, const std::string name,
934 const ContainerType indices) {
935 auto tensor_names = llvm::map_range(
936 indices, [&](int i) { return subgraph.tensors.at(i)->name; });
937 return builder->getNamedAttr(
938 name, builder->getStringAttr(llvm::join(tensor_names, ",")));
939 }
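// Example (illustrative): for indices whose tensors are named "a" and "c",
// BuildTFEntryFunctionAttribute(subgraph, &builder, "inputs", indices)
// produces the attribute inputs = "a,c", which later becomes part of the
// tf.entry_function dictionary attribute.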
940
941 // Traverses the subgraph from output_indices to input_indices and returns the
942 // set of ops that are visited.
943 StatusOr<absl::flat_hash_set<const tflite::OperatorT*>> PruneSubgraph(
944 const tflite::SubGraphT& subgraph, ArrayRef<int32_t> input_indices,
945 ArrayRef<int32_t> output_indices) {
946 // Create a map from tensor index to defining op.
947 absl::flat_hash_map<int32_t, const tflite::OperatorT*> defining_op;
948 for (const auto& op : subgraph.operators) {
949 for (int32_t output : op->outputs) {
950 if (!llvm::is_contained(input_indices, output)) {
951 defining_op[output] = op.get();
952 }
953 }
954 }
955
956 std::vector<const tflite::OperatorT*> queue;
957 for (int32_t output : output_indices) {
958 if (auto& op = defining_op[output]) {
959 queue.push_back(op);
960 }
961 }
962
963 // Traverse the graph towards inputs.
964 absl::flat_hash_set<const tflite::OperatorT*> visited;
965 while (!queue.empty()) {
966 const tflite::OperatorT* op = queue.back();
967 queue.pop_back();
968 if (!visited.insert(op).second) {
969 // The node has already been visited.
970 continue;
971 }
972
973 for (int32_t input : op->inputs) {
974 // Input tensor may not have a defining op in case it is a subgraph input
975 // or a constant tensor.
976 if (auto& op = defining_op[input]) {
977 queue.push_back(op);
978 }
979 }
980 }
981
982 return visited;
983 }
984
985 // Adjusts the func op according to information that spans multiple ops.
986 static StatusOr<FuncOp> PostProcessFuncOp(FuncOp func) {
987 OpBuilder builder(func);
988 // When a quantized constant is imported, its quantization parameters are set
989 // to narrow range. Here we revert them to full range if the user doesn't
990 // require narrow range.
991 func.walk([&](tfl::QConstOp cst) {
992 Value value = cst.getResult();
993 Value full_range_const = value;
994 auto qtype = mlir::quant::UniformQuantizedType::getQuantizedElementType(
995 value.getType());
996 // Only the 8-bit constants are imported with narrow range.
997 if (!qtype || qtype.getStorageTypeIntegralWidth() != 8 ||
998 !(qtype.isa<mlir::quant::UniformQuantizedType>() ||
999 qtype.isa<mlir::quant::UniformQuantizedPerAxisType>())) {
1000 return;
1001 }
1002 for (auto& use : value.getUses()) {
1003 Operation* user = use.getOwner();
1004 if (user->hasTrait<mlir::OpTrait::IsTerminator>()) continue;
1005
1006 auto affine_user = llvm::dyn_cast<mlir::AffineQuantizedOpInterface>(user);
1007 if (affine_user &&
1008 affine_user.GetAffineOperandIndex() == use.getOperandNumber() &&
1009 affine_user.RequiredNarrowRangeAffineOperand())
1010 continue;
1011 // Create a full-range quantized constant.
1012 if (full_range_const == value) {
1013 mlir::quant::QuantizedType new_qtype;
1014 if (auto per_axis =
1015 qtype.dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
1016 new_qtype = mlir::quant::UniformQuantizedPerAxisType::get(
1017 per_axis.getFlags(), per_axis.getStorageType(),
1018 per_axis.getExpressedType(), per_axis.getScales(),
1019 per_axis.getZeroPoints(), per_axis.getQuantizedDimension(),
1020 per_axis.getStorageTypeMin() - 1, per_axis.getStorageTypeMax());
1021 } else if (auto per_tensor =
1022 qtype.dyn_cast<mlir::quant::UniformQuantizedType>()) {
1023 new_qtype = mlir::quant::UniformQuantizedType::get(
1024 per_tensor.getFlags(), per_tensor.getStorageType(),
1025 per_tensor.getExpressedType(), per_tensor.getScale(),
1026 per_tensor.getZeroPoint(), per_tensor.getStorageTypeMin() - 1,
1027 per_tensor.getStorageTypeMax());
1028 } else {
1029 return; // Should not reach here, as it's already checked.
1030 }
1031 auto new_output_type = new_qtype.castFromExpressedType(
1032 mlir::quant::UniformQuantizedType::castToExpressedType(
1033 value.getType()));
1034 builder.setInsertionPointAfter(cst.getOperation());
1035 auto new_op = builder.create<tfl::QConstOp>(
1036 cst.getLoc(), new_output_type, mlir::TypeAttr::get(new_output_type),
1037 cst.valueAttr());
1038 full_range_const = new_op.output();
1039 }
1040 use.set(full_range_const);
1041 }
1042 if (cst.use_empty()) cst.erase();
1043 });
1044 return func;
1045 }
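// Example (illustrative): an imported int8 weight constant typed with the
// narrow storage range [-127, 127] keeps that type for affine users that
// require a narrow-range operand, while other users are rewired to a cloned
// constant whose storage minimum is lowered by one, i.e. [-128, 127].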
1046
1047 // Helper method that returns the index of the tensor with name 'tensor_name'
1048 // in the list of tensor names 'tensors'.
1049 int GetTensorIndex(const std::string& tensor_name,
1050 llvm::SmallVector<llvm::StringRef, 2> tensors) {
1051 for (const auto& tensor_index_pair : llvm::enumerate(tensors)) {
1052 if (tensor_index_pair.value() == tensor_name)
1053 return tensor_index_pair.index();
1054 }
1055 return -1;
1056 }
1057
1058 // Helper method that returns the list of all strings in the StringAttr
1059 // identified by 'attr_key', where the values are separated by commas.
1060 llvm::SmallVector<llvm::StringRef, 2> GetStringsFromAttrWithSeparator(
1061 mlir::DictionaryAttr attr, const std::string& attr_key) {
1062 llvm::SmallVector<llvm::StringRef, 2> result;
1063 if (auto str = attr.get(attr_key).dyn_cast_or_null<mlir::StringAttr>()) {
1064 str.getValue().split(result, ',', /*MaxSplit=*/-1,
1065 /*KeepEmpty=*/false);
1066 }
1067 return result;
1068 }
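// Example (illustrative): with a dictionary attribute containing
// inputs = "a,b,c", GetStringsFromAttrWithSeparator(attr, "inputs") returns
// {"a", "b", "c"}; empty fragments are dropped since KeepEmpty is false.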
1069
1070 // Sets signature attributes on the function.
1071 void SetSignature(
1072 FuncOp func, const tflite::SignatureDefT* signature,
1073 const std::vector<std::unique_ptr<tflite::TensorT>>& tensors) {
1074 auto* context = func->getContext();
1075 static const char kSignatureDefIndexPath[] = "tf_saved_model.index_path";
1076 static const char kExportedNameAttr[] = "tf_saved_model.exported_names";
1077 static const char kEntryFunctionAttributes[] = "tf.entry_function";
1078
1079 auto dict_attr =
1080 func->getAttrOfType<mlir::DictionaryAttr>(kEntryFunctionAttributes);
1081 if (!dict_attr) return;
1082
1083 // Get Input and output tensor names from attribute.
1084 llvm::SmallVector<llvm::StringRef, 2> input_names =
1085 GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"inputs");
1086 llvm::SmallVector<llvm::StringRef, 2> output_names =
1087 GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"outputs");
1088
1089 for (auto input_pair : llvm::enumerate(signature->inputs)) {
1090 const int arg_index = GetTensorIndex(
1091 tensors[input_pair.value()->tensor_index]->name, input_names);
1092 if (arg_index == -1) {
1093 func->emitWarning("Invalid signature tensors specified.");
1094 return;
1095 }
1096 func.setArgAttr(
1097 arg_index, kSignatureDefIndexPath,
1098 mlir::ArrayAttr::get(context, {mlir::StringAttr::get(
1099 context, input_pair.value()->name)}));
1100 }
1101 for (auto output_pair : llvm::enumerate(signature->outputs)) {
1102 const int arg_index = GetTensorIndex(
1103 tensors[output_pair.value()->tensor_index]->name, output_names);
1104 if (arg_index == -1) {
1105 func->emitWarning("Invalid signature tensors specified.");
1106 return;
1107 }
1108 func.setResultAttr(arg_index, kSignatureDefIndexPath,
1109 mlir::ArrayAttr::get(
1110 context, {mlir::StringAttr::get(
1111 context, output_pair.value()->name)}));
1112 }
1113 func->setAttr(
1114 kExportedNameAttr,
1115 mlir::ArrayAttr::get(
1116 context, {mlir::StringAttr::get(context, signature->signature_key)}));
1117 }
1118
1119 // Builds a FuncOp from a tflite SubGraph.
1120 // The buffers are directly taken
1121 // from the deserialized flatbuffer as we do not have the type information to
1122 // interpret them until this point. The base_loc parameter is the location of
1123 // the flatbuffer as a whole (usually a file). The is_entry_point flag
1124 // controls whether shapeless types are treated as scalars. If
1125 // ordered_output_arrays is not empty, then the imported mlir function will only
1126 // return nodes in ordered_output_arrays in the same order.
1127 // If signature is not null, then the inputs/outputs in signature will be
1128 // attached to the FuncOp.
1129 StatusOr<FuncOp> ConvertSubgraph(
1130 const tflite::SubGraphT& subgraph, llvm::StringRef name,
1131 const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& op_codes,
1132 const std::vector<std::string>& func_names,
1133 const std::vector<std::unique_ptr<tflite::BufferT>>& buffers,
1134 Location base_loc, Builder builder, bool is_entry_point,
1135 bool use_external_constant,
1136 const std::vector<std::string>& ordered_input_arrays,
1137 const std::vector<std::string>& ordered_output_arrays,
1138 bool experimental_prune_unreachable_nodes_unconditionally,
1139 const tflite::SignatureDefT* signature) {
1140 llvm::SmallVector<mlir::Type, 2> ret_types;
1141 llvm::SmallVector<mlir::Type, 4> input_types;
1142
1143 auto func_loc = mlir::NameLoc::get(builder.getIdentifier(name), base_loc);
1144
1145 std::vector<int> func_inputs = subgraph.inputs;
1146 if (is_entry_point && !ordered_input_arrays.empty()) {
1147 if (!experimental_prune_unreachable_nodes_unconditionally) {
1148 // TODO(b/149922113): Resolve input-arrays/pruning flags interaction.
1149 return errors::InvalidArgument(
1150 "input-arrays should be used with experimental pruning flag");
1151 }
1152 TF_ASSIGN_OR_RETURN(func_inputs,
1153 GetTensorIndices(subgraph, ordered_input_arrays));
1154 }
1155
1156 for (int input : func_inputs) {
1157 auto& tensor = *subgraph.tensors.at(input);
1158 // TODO(b/138222071) Graph inputs must have static shape per the exporter,
1159 // but we cannot differentiate scalars from unranked tensors.
1160 // Here we reverse the default assumption that shape = [] means unranked
1161 // when processing main().
1162 auto type_or_err = GetTensorType(tensor, builder,
1163 /*shapeless_are_scalars=*/is_entry_point,
1164 /*is_constant=*/false);
1165 if (!type_or_err.ok()) {
1166 emitError(func_loc, "error reading argument types")
1167 << type_or_err.status().ToString();
1168 return type_or_err.status();
1169 }
1170 auto type = type_or_err.ConsumeValueOrDie();
1171 input_types.push_back(type);
1172 }
1173
1174 llvm::SmallVector<bool, 16> is_op_output(subgraph.tensors.size(), false);
1175 for (auto& op : subgraph.operators) {
1176 for (auto output : op->outputs) {
1177 is_op_output[output] = true;
1178 }
1179 }
1180
1181 std::vector<int> func_outputs = subgraph.outputs;
1182 if (is_entry_point && !ordered_output_arrays.empty()) {
1183 TF_ASSIGN_OR_RETURN(func_outputs,
1184 GetTensorIndices(subgraph, ordered_output_arrays));
1185 }
1186
1187 for (auto output : func_outputs) {
1188 const bool is_func_input = std::find(func_inputs.begin(), func_inputs.end(),
1189 output) != func_inputs.end();
1190 bool is_constant = !is_op_output[output] && !is_func_input;
1191 // There are 2 cases where a tensor is a scalar when it doesn't have a shape
1192 // in the flatbuffer:
1193 // 1. `is_constant` = true, means this tensor is created from a constant op.
1194 // 2. `is_func_input` = true and `is_entry_point` = true, which means this
1195 // tensor is function input and function input type is a scalar tensor.
1196 const bool shapeless_is_scalar =
1197 is_constant || (is_func_input && is_entry_point);
1198 auto type_or_err = GetTensorType(*subgraph.tensors.at(output), builder,
1199 shapeless_is_scalar,
1200 /*is_constant=*/is_constant);
1201 if (!type_or_err.ok()) {
1202 emitError(func_loc, "error reading return types")
1203 << type_or_err.status().ToString();
1204 return type_or_err.status();
1205 }
1206 auto type = type_or_err.ConsumeValueOrDie();
1207 ret_types.push_back(type);
1208 }
1209 auto func_type = builder.getFunctionType(input_types, ret_types);
1210
1211 // Construct function object
1212 auto func = FuncOp::create(func_loc, name, func_type, /* attrs= */ {});
1213 func.addEntryBlock();
1214 auto& body = func.getBody();
1215 OpBuilder op_builder{body};
1216
1217 std::vector<Value> vals_map(subgraph.tensors.size(), nullptr);
1218 Value maybe_optional_arg_marker = nullptr;
1219
1220 // Get or construct MLIR values for each input
1221 for (int i = 0, e = func_inputs.size(); i < e; i++) {
1222 auto input_tensor = func_inputs[i];
1223 const auto& tensor = *subgraph.tensors.at(input_tensor);
1224 auto loc = TensorLoc(tensor, builder, base_loc);
1225 if (vals_map[input_tensor]) {
1226 auto err = errors::FailedPrecondition("duplicate input arguments");
1227 return emitError(loc, err.ToString()), err;
1228 }
1229 Value input_value = func.getArgument(i);
1230
1231 // If the `tensor` has min/max and doesn't have scale/zero_point
1232 // information, a stats op is created to use the input_value, then the
1233 // `tensor` should be mapped to the result of this new stats op.
1234 if (auto stats_op =
1235 ConvertMinMaxToStatsOp(tensor, op_builder, input_value)) {
1236 vals_map[input_tensor] = stats_op->getResult(0);
1237 } else {
1238 vals_map[input_tensor] = input_value;
1239 }
1240 }
1241
1242 // Set tf.entry_function attribute
1243 if (is_entry_point) {
1244 llvm::SmallVector<mlir::NamedAttribute, 2> attributes;
1245 if (!func_inputs.empty()) {
1246 attributes.push_back(BuildTFEntryFunctionAttribute(
1247 subgraph, &builder, "inputs", func_inputs));
1248 }
1249 if (!func_outputs.empty()) {
1250 attributes.push_back(BuildTFEntryFunctionAttribute(
1251 subgraph, &builder, "outputs", func_outputs));
1252 }
1253 func->setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
1254 } else {
1255 func.setPrivate();
1256 }
1257
1258 // Set signature on function.
1259 if (signature) {
1260 SetSignature(func, signature, subgraph.tensors);
1261 }
1262
1263 absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
1264 if (experimental_prune_unreachable_nodes_unconditionally) {
1265 TF_ASSIGN_OR_RETURN(pruned_subgraph_ops,
1266 PruneSubgraph(subgraph, func_inputs, func_outputs));
1267 }
1268
1269 // Construct MLIR operators from TFLite operators
1270 for (auto& op : subgraph.operators) {
1271 if (experimental_prune_unreachable_nodes_unconditionally &&
1272 !pruned_subgraph_ops.contains(op)) {
1273 continue;
1274 }
1275
1276 for (auto input_num : op->inputs) {
1277 // The operators in a graph are topologically sorted
1278 // and so if no previous operation has produced a tensor
1279 // it must be a constant.
1280 if (input_num == -1) {
1281 if (maybe_optional_arg_marker == nullptr) {
1282 maybe_optional_arg_marker =
1283 op_builder
1284 .create<mlir::ConstantOp>(base_loc, builder.getNoneType(),
1285 builder.getUnitAttr())
1286 .getResult();
1287 }
1288 } else if (!vals_map.at(input_num)) {
1289 auto& const_tensor = *subgraph.tensors[input_num];
1290 auto const_loc = TensorLoc(const_tensor, builder, base_loc);
1291 auto op_or_err =
1292 use_external_constant
1293 ? BuildExternalConstOp(const_tensor, const_tensor.buffer,
1294 op_builder, const_loc)
1295 : BuildConstOp(const_tensor, buffers[const_tensor.buffer]->data,
1296 const_tensor.is_variable, op_builder, const_loc);
1297 if (!op_or_err.ok()) {
1298 return emitError(const_loc, op_or_err.status().ToString()),
1299 op_or_err.status();
1300 }
1301 vals_map[input_num] = op_or_err.ValueOrDie()->getResult(0);
1302 }
1303 }
1304
1305 // Intermediate tensors for LSTMs are used to carry quantization range
1306 // in their types, so we only need to extract their types.
1307 std::vector<mlir::TensorType> intermediate_types;
1308 intermediate_types.reserve(5);
1309 for (auto intermediate : op->intermediates) {
1310 TF_ASSIGN_OR_RETURN(
1311 auto type, GetTensorType(*subgraph.tensors[intermediate], builder,
1312 /*shapeless_are_scalars=*/true,
1313 /*is_constant=*/false,
1314 /*is_intermediate=*/true));
1315 intermediate_types.emplace_back(type);
1316 }
1317
1318 auto op_loc = OpLoc(*op, subgraph.tensors, builder, base_loc);
1319
1320 // If there's an optional argument, maybe_optional_arg_marker has been set
1321 // to a valid Value
1322 TF_ASSIGN_OR_RETURN(
1323 auto* mlir_op,
1324 ConvertOp(*op, vals_map, intermediate_types, maybe_optional_arg_marker,
1325 op_codes, func_names, subgraph.tensors, op_loc, op_builder));
1326
1327 // Add the results to the value maps. There are two cases: 1. the result
1328 // tensor does not have min/max values, the original op result is used
1329 // directly; 2. the result tensor has some min/max values, a stats op is
1330 // created, then the result of the stats op is used.
1331 for (auto pair : llvm::enumerate(mlir_op->getResults())) {
1332 int output_tensor_index = op->outputs[pair.index()];
1333 auto& tensor = *subgraph.tensors[output_tensor_index];
1334 if (auto stats_op =
1335 ConvertMinMaxToStatsOp(tensor, op_builder, pair.value())) {
1336 vals_map[output_tensor_index] = stats_op->getResult(0);
1337 } else {
1338 vals_map[output_tensor_index] = pair.value();
1339 }
1340 }
1341 }
1342
1343 // Construct return values
1344 llvm::SmallVector<Value, 4> return_operands;
1345 for (auto index : func_outputs) {
1346 if (!vals_map.at(index)) {
1347 auto& const_tensor = *subgraph.tensors[index];
1348 auto const_loc = TensorLoc(const_tensor, builder, base_loc);
1349 auto op_or_err =
1350 use_external_constant
1351 ? BuildExternalConstOp(const_tensor, const_tensor.buffer,
1352 op_builder, const_loc)
1353 : BuildConstOp(const_tensor, buffers[const_tensor.buffer]->data,
1354 const_tensor.is_variable, op_builder, const_loc);
1355 if (!op_or_err.ok()) {
1356 return emitError(const_loc, op_or_err.status().ToString()),
1357 op_or_err.status();
1358 }
1359 vals_map[index] = op_or_err.ValueOrDie()->getResult(0);
1360 }
1361 return_operands.push_back(vals_map[index]);
1362 }
1363
1364 op_builder.create<mlir::ReturnOp>(base_loc, return_operands);
1365
1366 return PostProcessFuncOp(func);
1367 }
1368
1369 // TFLite subgraphs do not necessarily have names, though MLIR functions must
1370 // have them, so we generate a name for subgraphs that are missing one here.
1371 // Note: in TFLite, the first subgraph is the entry point, and in the MLIR
1372 // that represents TFLite, this entry point must be called "main".
1373 // TODO(b/131175224,b/132239787) Support multiple entry points
1374 std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) {
1375 if (index == 0) {
1376 return "main";
1377 }
1378 if (subgraph.name.empty()) {
1379 return llvm::formatv("fn_{0}", index).str();
1380 }
1381 return subgraph.name;
1382 }
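// Example (illustrative): subgraph 0 is always named "main"; an unnamed
// subgraph at index 2 becomes "fn_2"; a named subgraph keeps its flatbuffer
// name.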
1383
1384 // Adds a CallOp in `region` to call `func` and forwards the results of the
1385 // CallOp through a TFL::YieldOp.
1386 void AddCallOpInWhileOpRegion(mlir::Region& region, mlir::FuncOp func) {
1387 OpBuilder op_builder{region};
1388 region.push_back(new mlir::Block());
1389 region.addArguments(func.getType().getInputs());
1390 op_builder.setInsertionPointToStart(®ion.front());
1391 auto call_op = op_builder.create<mlir::CallOp>(
1392 region.getLoc(), func.getType().getResults(), func.sym_name(),
1393 region.getArguments());
1394 op_builder.create<mlir::TFL::YieldOp>(region.getLoc(), call_op.getResults());
1395 }
1396
1397 // TFL::WhileOp has regions, so if we have while ops we add CallOps in the
1398 // regions to call the referenced FuncOps.
1399 void AddRegionsForTflWhileOp(mlir::ModuleOp module) {
1400 mlir::SymbolTable symbol_table(module);
1401 module.walk([&](mlir::TFL::WhileOp while_op) {
1402 auto cond = symbol_table.lookup<mlir::FuncOp>(
1403 while_op->getAttr("cond").cast<mlir::FlatSymbolRefAttr>().getValue());
1404 AddCallOpInWhileOpRegion(while_op.cond(), cond);
1405 while_op->removeAttr("cond");
1406 auto body = symbol_table.lookup<mlir::FuncOp>(
1407 while_op->getAttr("body").cast<mlir::FlatSymbolRefAttr>().getValue());
1408 AddCallOpInWhileOpRegion(while_op.body(), body);
1409 while_op->removeAttr("body");
1410 });
1411 }
1412 } // namespace
1413
1414 OwningModuleRef tflite::FlatBufferToMlir(
1415 absl::string_view buffer, MLIRContext* context, Location base_loc,
1416 bool use_external_constant,
1417 const std::vector<std::string>& ordered_input_arrays,
1418 const std::vector<std::string>& ordered_output_arrays,
1419 bool experimental_prune_unreachable_nodes_unconditionally) {
1420 context->loadDialect<
1421 mlir::StandardOpsDialect, mlir::quant::QuantizationDialect,
1422 mlir::TFL::TensorFlowLiteDialect, mlir::TF::TensorFlowDialect>();
1423
1424 auto model_ptr =
1425 FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
1426 if (nullptr == model_ptr) {
1427 return emitError(base_loc, "couldn't parse flatbuffer"), nullptr;
1428 }
1429
1430 std::unique_ptr<ModelT> model(model_ptr->GetModel()->UnPack());
1431
1432 auto builder = Builder(context);
1433
1434 std::vector<std::string> func_names;
1435 for (auto& subgraph : model->subgraphs) {
1436 func_names.push_back(subgraph->name);
1437 }
1438
1439 auto module = mlir::ModuleOp::create(base_loc);
1440 // We currently don't use this to make decisions, but we could
1441 // use it in exports or if there are breaking changes
1442 module->setAttr("tfl.schema_version",
1443 builder.getI32IntegerAttr(model->version));
1444 if (!model->description.empty()) {
1445 module->setAttr("tfl.description",
1446 builder.getStringAttr(model->description));
1447 }
1448
1449 // TODO(b/184697652): Update to handle multiple entry points.
1450 tflite::SignatureDefT* signature_def = nullptr;
1451 if (!model->signature_defs.empty()) {
1452 signature_def = model->signature_defs[0].get();
1453 }
1454
1455 for (auto e : llvm::enumerate(model->subgraphs)) {
1456 auto& subgraph = e.value();
1457 std::string name = SubgraphName(e.index(), *subgraph);
1458 auto func_or_error = ConvertSubgraph(
1459 *subgraph, name, model->operator_codes, func_names, model->buffers,
1460 base_loc, builder,
1461 // TODO(b/131175224,b/132239787) Support multiple entry points
1462 /*is_entry_point=*/e.index() == 0,
1463 /*use_external_constant=*/use_external_constant, ordered_input_arrays,
1464 ordered_output_arrays,
1465 experimental_prune_unreachable_nodes_unconditionally,
1466 e.index() == 0 ? signature_def : nullptr);
1467 if (!func_or_error.ok()) {
1468 return emitError(base_loc, "could not translate function ")
1469 << subgraph->name << ": "
1470 << func_or_error.status().error_message(),
1471 nullptr;
1472 }
1473 module.push_back(func_or_error.ConsumeValueOrDie());
1474 }
1475 AddRegionsForTflWhileOp(module);
1476 return OwningModuleRef(module);
1477 }
1478