/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
#include "mlir/Dialect/Tosa/IR/TosaOps.h"  // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"

// Implements legalization and post-legalization optimization helper functions.

namespace mlir {
namespace tosa {

// Create a TOSA rescale op from TFLite scaling, zero points and rounding mode.
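// The rescale computes, roughly, output = (input - input_zp) * scale +
// output_zp in fixed point, with scale expressed as multiplier * 2^-shift
// (see computeMultiplierAndShift for the exact normalization). Illustrative
// example: scale = 0.25 in scale32 mode gives multiplier = 1 << 30 and
// shift = 32, since 0.25 == (1 << 30) * 2^-32.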
Value buildRescale(PatternRewriter& rewriter, Operation* op,
                   ShapedType output_type, Value input_val, double scale,
                   int64_t input_zp, int64_t output_zp, bool double_round,
                   bool scale32) {
  int32_t multiplier;
  int32_t shift;

  int32_t scale_width = scale32 ? 32 : 16;

  computeMultiplierAndShift(scale, multiplier, shift, scale_width);

  auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
      rewriter, op->getLoc(), output_type, input_val,
      rewriter.getI32IntegerAttr(static_cast<int32_t>(input_zp)),
      rewriter.getI32IntegerAttr(static_cast<int32_t>(output_zp)),
      rewriter.getI32ArrayAttr({multiplier}), rewriter.getI32ArrayAttr({shift}),
      rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(double_round),
      rewriter.getBoolAttr(false));

  return rescale_op.getResult();
}

// Creates a TOSA rescale op with int32 output.
Value buildRescaleToInt32(PatternRewriter& rewriter, Operation* op,
                          Value input_val, double input_scale,
                          int64_t input_zp) {
  // Output is always int32 type.
  auto input_type = input_val.getType().dyn_cast<mlir::ShapedType>();
  assert(input_type);
  auto output_type = input_type.clone(rewriter.getI32Type());

  return buildRescale(rewriter, op, output_type, input_val, input_scale,
                      input_zp, 0, false, true);
}

// Creates a TOSA rescale op with int32 input.
Value buildRescaleFromInt32(PatternRewriter& rewriter, Operation* op,
                            ShapedType output_type, Value input_val,
                            double output_scale, int64_t output_zp) {
  // Input should be int32 type.
  auto input_type = input_val.getType().dyn_cast<mlir::ShapedType>();
  (void)input_type;
  assert(input_type && input_type.getElementType().isInteger(32) &&
         "expected rescale input element type to be i32");

  // Potentially check input_shape == output_shape here.
  return buildRescale(rewriter, op, output_type, input_val, output_scale, 0,
                      output_zp, true, true);
}
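
// A typical (illustrative) usage pairs the two helpers above when legalizing
// a quantized elementwise op: rescale each operand into a shared int32
// domain, apply the op, then rescale back to the quantized output type. The
// names below (lhs, rhs, sum32, ...) are placeholders, not APIs in this file:
//
//   Value a32 = buildRescaleToInt32(rewriter, op, lhs, lhs_scale, lhs_zp);
//   Value b32 = buildRescaleToInt32(rewriter, op, rhs, rhs_scale, rhs_zp);
//   // ... int32 elementwise computation producing sum32 ...
//   Value out = buildRescaleFromInt32(rewriter, op, output_type, sum32,
//                                     output_scale, output_zp);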

// Creates a TOSA rescale op based on conv2d parameters.
Value buildRescaleOpConvOutput(PatternRewriter& rewriter, Operation* op,
                               Value conv_val, ShapedType input_type,
                               ShapedType weight_type, ShapedType output_type) {
  auto input_qtype =
      input_type.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
  auto output_qtype = output_type.getElementType()
                          .dyn_cast<mlir::quant::UniformQuantizedType>();

  double input_scale = input_qtype.getScale();

  int64_t output_zp = output_qtype.getZeroPoint();
  double output_scale = output_qtype.getScale();

  bool scale32 = isScale32(output_qtype);
  int32_t scale_width = scale32 ? 32 : 16;
  // Only use double rounding if we are doing 32-bit scaling.
  bool double_round = scale32;
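
  // The int32 conv accumulator carries an effective scale of
  // input_scale * weight_scale, so mapping it onto the output quantization
  // divides by output_scale. Illustrative example: input_scale = 0.5,
  // weight_scale = 0.25, output_scale = 0.125 gives a combined rescale of
  // (0.5 * 0.25) / 0.125 = 1.0.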

  if (auto weight_per_tensor_qtype =
          weight_type.getElementType()
              .dyn_cast<mlir::quant::UniformQuantizedType>()) {
    // Per-tensor quantization.
    double weight_scale = weight_per_tensor_qtype.getScale();

    int32_t multiplier;
    int32_t shift;

    double op_tensor_scale = (input_scale * weight_scale) / output_scale;

    computeMultiplierAndShift(op_tensor_scale, multiplier, shift, scale_width);

    auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
        rewriter, op->getLoc(), output_type, conv_val,
        rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(output_zp),
        rewriter.getI32ArrayAttr({multiplier}),
        rewriter.getI32ArrayAttr({shift}), rewriter.getBoolAttr(scale32),
        rewriter.getBoolAttr(double_round), rewriter.getBoolAttr(false));

    return rescale_op.getResult();

  } else if (auto weight_per_channel_qtype =
                 weight_type.getElementType()
                     .dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
    // Per-channel quantization.
    SmallVector<int32_t> multiplier_arr;
    SmallVector<int32_t> shift_arr;

    SmallVector<double> weight_scale_arr(
        weight_per_channel_qtype.getScales().begin(),
        weight_per_channel_qtype.getScales().end());

    for (double weight_scale : weight_scale_arr) {
      int32_t multiplier;
      int32_t shift;

      double op_channel_scale = (input_scale * weight_scale) / output_scale;

      computeMultiplierAndShift(op_channel_scale, multiplier, shift,
                                scale_width);

      multiplier_arr.push_back(multiplier);
      shift_arr.push_back(shift);
    }

    auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
        rewriter, op->getLoc(), output_type, conv_val,
        rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(output_zp),
        rewriter.getI32ArrayAttr(multiplier_arr),
        rewriter.getI32ArrayAttr(shift_arr), rewriter.getBoolAttr(scale32),
        rewriter.getBoolAttr(double_round), rewriter.getBoolAttr(true));

    return rescale_op.getResult();

  } else {
    op->emitOpError("buildRescaleOpConvOutput: unknown weight quantized type");
    return nullptr;
  }
}

// Create an 8-bit TOSA TABLE constant tensor with an int8[256] array.
// Follows PopulateLookupTable() in tensorflow/lite/kernels/activations.cc.
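// Each entry dequantizes the int8 input, applies func, and requantizes into
// the output space: table[i + 128] = clamp(round(func(input_scale *
// (i - input_zp)) / output_scale) + output_zp, -128, 127). Illustrative
// example with func = tanh, input_scale = 1/16, input_zp = 0,
// output_scale = 1/128, output_zp = 0: the entry for i = 16 is
// round(tanh(1.0) * 128) = 97.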
Value getTosaConst8bitTable(PatternRewriter& rewriter, Operation* op,
                            double input_scale, int32_t input_zp,
                            double output_scale, int32_t output_zp,
                            std::function<double(double)> func) {
  SmallVector<int8_t, 256> table;

  for (int32_t i = -128; i < 128; i++) {
    double dequantized = input_scale * (i - input_zp);
    double transformed = func(dequantized);
    int32_t rescaled = std::llround(transformed / output_scale);
    int32_t quantized = static_cast<int32_t>(rescaled + output_zp);
    table.push_back(
        static_cast<int8_t>(std::min(std::max(quantized, -128), 127)));
  }

  auto element_qtype =
      UniformQuantizedType::get(true, rewriter.getIntegerType(8),
                                rewriter.getF32Type(), 1.0f, 0, -128, 127);
  auto const_type = RankedTensorType::get({256}, element_qtype);
  auto storage_type =
      RankedTensorType::get({256}, element_qtype.getStorageType());
  auto const_attr =
      DenseElementsAttr::get(storage_type, llvm::makeArrayRef(table));

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Create a 16-bit TOSA TABLE constant tensor with an int16[513] array.
// Output is restricted to [-1.0, 1.0].
// Follows gen_lut() in tensorflow/lite/kernels/internal/common.h.
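// The 513 base samples span [min, max]. Because the consumer interpolates
// linearly between adjacent entries, each sample is biased by half the
// midpoint interpolation error: informally,
//   bias = ((sample[i] + sample[i+1]) / 2 - 32768 * func(midpoint)) / 2
// which spreads the error evenly across each segment (sketch of the loop
// below; rounding details omitted).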
Value getTosaConst16bitTable(PatternRewriter& rewriter, Operation* op,
                             std::function<double(double)> func, double min,
                             double max) {
  SmallVector<int16_t, 513> table;

  double step = (max - min) / 512.0;
  double half_step = step / 2.0;
  for (int32_t i = 0; i < 512; i++) {
    int32_t sample_val = std::llround(func(min + (i * step)) * 32768.0);
    double midpoint_interp_val =
        std::round(((func(min + (i + 1) * step) * 32768.0) +
                    std::round(func(min + (i * step)) * 32768.0)) /
                   2.0);
    double midpoint_val =
        std::round(func(min + (i * step) + half_step) * 32768.0);
    double midpoint_err = midpoint_interp_val - midpoint_val;
    int32_t bias = std::llround(midpoint_err / 2.0);

    table.push_back(static_cast<int16_t>(
        std::min(std::max(sample_val - bias, -32768), 32767)));
  }

  int32_t max_val = std::llround(func(max) * 32768.0);
  table.push_back(
      static_cast<int16_t>(std::min(std::max(max_val, -32768), 32767)));

  auto element_qtype =
      UniformQuantizedType::get(true, rewriter.getIntegerType(16),
                                rewriter.getF32Type(), 1.0f, 0, -32768, 32767);
  auto const_type = RankedTensorType::get({513}, element_qtype);
  auto storage_type =
      RankedTensorType::get({513}, element_qtype.getStorageType());
  auto const_attr =
      DenseElementsAttr::get(storage_type, llvm::makeArrayRef(table));

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Create a 32-bit TOSA TABLE as four 8-bit slices, each stored in an
// int16[513] constant tensor.
// Output is restricted to [-1.0, 1.0] as s0.31 format.
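// Each s0.31 sample is split into four bytes, one per output table, so a
// caller can reassemble the 32-bit value from four 8-bit TABLE lookups,
// roughly: value = (first << 24) | (second << 16) | (third << 8) | fourth
// (informal sketch; the exact recombination is up to the caller).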
void getTosaConst32bitTable(PatternRewriter& rewriter, Operation* op,
                            double input_scale, int32_t input_zp,
                            std::function<double(double)> func,
                            Value& first_const, Value& second_const,
                            Value& third_const, Value& fourth_const) {
  SmallVector<int16_t, 513> first_table, second_table, third_table,
      fourth_table;

  double output_inv_scale = static_cast<double>(1L << 31);

  for (int32_t i = -256; i <= 256; i++) {
    double dequantized = input_scale * (i - input_zp);
    double transformed = func(dequantized);
    double truncated = std::min(std::max(transformed, -1.0), 1.0);
    int64_t rescaled =
        static_cast<int64_t>(std::round(truncated * output_inv_scale));

    // 2^31 is not representable in int32_t, so store as 2^31 - 1 instead.
    if (rescaled == static_cast<int64_t>(1L << 31)) {
      rescaled = static_cast<int64_t>(1L << 31) - 1;
    }

    // Only copy the 8-bit groups.
    int32_t first = (rescaled >> 24) & 0xFF;
    int32_t second = (rescaled >> 16) & 0xFF;
    int32_t third = (rescaled >> 8) & 0xFF;
    int32_t fourth = rescaled & 0xFF;

    first_table.push_back(first);
    second_table.push_back(second);
    third_table.push_back(third);
    fourth_table.push_back(fourth);
  }

  auto element_qtype =
      UniformQuantizedType::get(true, rewriter.getIntegerType(16),
                                rewriter.getF32Type(), 1.0f, 0, -32768, 32767);
  auto const_type = RankedTensorType::get({513}, element_qtype);
  auto storage_type =
      RankedTensorType::get({513}, element_qtype.getStorageType());

  auto first_const_attr =
      DenseElementsAttr::get(storage_type, llvm::makeArrayRef(first_table));
  auto second_const_attr =
      DenseElementsAttr::get(storage_type, llvm::makeArrayRef(second_table));
  auto third_const_attr =
      DenseElementsAttr::get(storage_type, llvm::makeArrayRef(third_table));
  auto fourth_const_attr =
      DenseElementsAttr::get(storage_type, llvm::makeArrayRef(fourth_table));

  first_const =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, first_const_attr)
          .getResult();
  second_const =
      rewriter
          .create<tosa::ConstOp>(op->getLoc(), const_type, second_const_attr)
          .getResult();
  third_const =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, third_const_attr)
          .getResult();
  fourth_const =
      rewriter
          .create<tosa::ConstOp>(op->getLoc(), const_type, fourth_const_attr)
          .getResult();
}

// Create a 32-bit float constant operator from a float.
Value getTosaConstTensorSingleF32(PatternRewriter& rewriter, Operation* op,
                                  float val) {
  auto const_type = RankedTensorType::get({}, rewriter.getF32Type());
  auto const_attr = DenseElementsAttr::get(const_type, val);

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Create a 32-bit integer constant operator from an int.
Value getTosaConstTensorSingleI32(PatternRewriter& rewriter, Operation* op,
                                  int32_t val) {
  auto const_type = RankedTensorType::get({}, rewriter.getIntegerType(32));
  auto const_attr = DenseElementsAttr::get(const_type, val);

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Populate a vector from a 32-bit integer constant tensor. Returns failure
// if the value is not a constant.
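// Typically used to read a small constant operand (e.g. a shape or
// permutation tensor) into a vector of int32 values.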
LogicalResult getVectorFromValue32(Value val, SmallVectorImpl<int32_t>& vec) {
  ElementsAttr elems;

  vec.clear();

  if (!matchPattern(val, m_Constant(&elems))) return failure();

  for (auto idx : elems.getValues<IntegerAttr>()) {
    vec.push_back(idx.getInt());
  }

  return success();
}

// Calculates the TOSA padding values based on TF operators padded with
// SAME/VALID.
//
// This could pass tensorflow::FilterTensorFormat and do
// GetFilterTensorSpatialDimIndex but the current TF core libs do not support
// FORMAT_OHWI parsing by that function in core/util/tensor_format.h.
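//
// For SAME padding the TF helper used below computes, in effect:
//   effective_filter = (filter - 1) * dilation + 1
//   out = ceil(in / stride)
//   total_pad = max((out - 1) * stride + effective_filter - in, 0)
//   pad_before = total_pad / 2;  pad_after = total_pad - pad_before
// Illustrative example: in = 5, filter = 3, stride = 2, dilation = 1 gives
// out = 3 and total_pad = 2, so pad_before = pad_after = 1. VALID padding is
// all zeros.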
bool getPaddingValuesFromPadType(tensorflow::Padding tf_pad,
                                 tensorflow::TensorFormat data_format_tf,
                                 uint32_t first_filter_spatial_dim,
                                 ShapedType input_type, ShapedType filter_type,
                                 ArrayAttr strides, ArrayAttr dilations,
                                 PatternRewriter& rewriter,
                                 ArrayAttr& explicit_padding) {
  assert(tf_pad != tensorflow::Padding::EXPLICIT);
  if (!input_type.hasRank() || !filter_type.hasRank()) return false;

  // Storing the numeric padding values is useful for TOSA codegen, as opposed
  // to holding the padding regime mnemonic, i.e. SAME, VALID, FULL, ...
  SmallVector<int64_t> computed_paddings;

  int64_t pad_before, pad_after;
  for (int i = 0; i < 2; i++) {  // Two spatial dimensions X&Y
    int64_t ifm_dim = GetTensorSpatialDimIndex(
        4, data_format_tf, i);  // 4D tensor, NHWC/NCHW format
    int64_t filter_dim = first_filter_spatial_dim + i;

    int64_t dim_dilation = dilations[i].template cast<IntegerAttr>().getInt();
    int64_t dim_stride = strides[i].template cast<IntegerAttr>().getInt();

    int64_t ip_size = input_type.getDimSize(ifm_dim);
    int64_t f_size = filter_type.getDimSize(filter_dim);
    // If we have a dynamic shape we should assume it is wide enough.
    ip_size = ip_size < 0 ? f_size * dim_dilation : ip_size;
    // The TF helper below requires int64_t output arguments.
    int64_t op_size, pad_before_tf, pad_after_tf;
    tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
        ip_size, f_size, dim_dilation, dim_stride, tf_pad, &op_size,
        &pad_before_tf, &pad_after_tf);
    if (!status.ok()) return false;

    pad_before = pad_before_tf;
    pad_after = pad_after_tf;
    computed_paddings.push_back(pad_before);
    computed_paddings.push_back(pad_after);
  }

  explicit_padding = rewriter.getI64ArrayAttr(computed_paddings);
  return true;
}

// Calculates the TOSA padding values for explicit-padded TF operators.
//
// This function only handles the TF padding array explicit_padding, which is
// only present in certain TF ops. All others encode padding using the
// SAME/VALID padding regime, which is interpreted by the
// getPaddingValuesFromPadType function above.

// The explicit padding array in TF holds 2 pad values for every
// dimension, even those that are not the 2 spatial ones. Just extract the
// 2x pad values for the XY dims.
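//
// Illustrative example: for NHWC, explicit_pad is laid out as
// [n_before, n_after, h_before, h_after, w_before, w_after, c_before, c_after]
// and this function returns [h_before, h_after, w_before, w_after].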
ArrayAttr getPaddingValuesFromExplicitPadAttr(
    ArrayAttr explicit_pad, tensorflow::TensorFormat data_format_tf,
    PatternRewriter& rewriter) {
  SmallVector<int64_t> computed_paddings;

  int64_t pad_before, pad_after;
  for (int i = 0; i < 2; i++) {  // Two spatial dimensions X&Y
    int64_t dim = GetTensorSpatialDimIndex(4, data_format_tf,
                                           i);  // 4D tensor, NHWC/NCHW format

    pad_before = explicit_pad[dim * 2].template cast<IntegerAttr>().getInt();
    pad_after = explicit_pad[dim * 2 + 1].template cast<IntegerAttr>().getInt();
    computed_paddings.push_back(pad_before);
    computed_paddings.push_back(pad_after);
  }

  return rewriter.getI64ArrayAttr(computed_paddings);
}

// Calculates the TOSA padding values for transpose conv2d.
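//
// The padding here is the amount by which the strided input plus the filter
// extent overshoots the requested output size, in effect:
//   total_pad = max((ifm - 1) * stride + filter - ofm, 0)
// Illustrative example: ifm = 4, stride = 2, filter = 3, ofm = 8 gives
// total_pad = 1, split as pad_before = 0 and pad_after = 1.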
bool getTransposeConv2dPaddingValues(
    tensorflow::Padding tf_pad, tensorflow::TensorFormat data_format_tf,
    uint32_t first_filter_spatial_dim, ShapedType input_type,
    ShapedType filter_type, ShapedType output_type, ArrayAttr strides,
    PatternRewriter& rewriter, ArrayAttr& explicit_padding) {
  assert(tf_pad != tensorflow::Padding::EXPLICIT);
  if (!input_type.hasRank() || !filter_type.hasRank() || !output_type.hasRank())
    return false;

  // Storing the numeric padding values is useful for TOSA codegen, as opposed
  // to holding the padding regime mnemonic, i.e. SAME, VALID, FULL, ...

  SmallVector<int64_t> computed_paddings;

  int64_t pad_before, pad_after;
  for (int i = 0; i < 2; i++) {  // Two spatial dimensions X&Y
    int64_t ifm_dim = GetTensorSpatialDimIndex(
        4, data_format_tf, i);  // 4D tensor, NHWC/NCHW format
    int64_t ofm_dim = GetTensorSpatialDimIndex(
        4, data_format_tf, i);  // 4D tensor, NHWC/NCHW format
    int64_t filter_dim = first_filter_spatial_dim + i;

    int64_t ifm_size = input_type.getDimSize(ifm_dim);
    int64_t filter_size = filter_type.getDimSize(filter_dim);
    int64_t ofm_size = output_type.getDimSize(ofm_dim);
    int64_t dim_stride = strides[i].template cast<IntegerAttr>().getInt();

    // These dimensions need to be static to legalize.
    if (ShapedType::isDynamic(filter_size) || ShapedType::isDynamic(ifm_size) ||
        ShapedType::isDynamic(ofm_size)) {
      return false;
    }

    int64_t total_padding =
        ((ifm_size - 1) * dim_stride + filter_size - ofm_size);
    total_padding = total_padding > 0 ? total_padding : 0;

    pad_before = total_padding / 2;
    pad_after = total_padding - pad_before;

    computed_paddings.push_back(pad_before);
    computed_paddings.push_back(pad_after);
  }

  explicit_padding = rewriter.getI64ArrayAttr(computed_paddings);
  return true;
}

// Templated function to create a constant op for given type and shape.
// T: storage C type.
// Default template creates a constant tensor in T.
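//
// Illustrative usage (values are arbitrary) creating a 2x2 i32 constant:
//   SmallVector<int32_t> vals = {1, 2, 3, 4};
//   llvm::Optional<Value> c =
//       getConstTensor<int32_t>(rewriter, op, vals, {2, 2});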
template <typename T>
llvm::Optional<Value> getConstTensor(PatternRewriter& rewriter, Operation* op,
                                     ArrayRef<T> vec, ArrayRef<int64_t> shape) {
  int64_t num_total_elements = 1;
  for (int64_t a : shape) {
    num_total_elements *= a;
  }

  if (vec.size() != num_total_elements) {
    op->emitOpError("getConstTensor(): number of elements mismatch.");
    return llvm::None;
  }

  auto const_type =
      RankedTensorType::get(shape, rewriter.getIntegerType(sizeof(T) * 8));
  auto const_attr = DenseElementsAttr::get(const_type, vec);

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Template specialization for APInt.
template <>
llvm::Optional<Value> getConstTensor<APInt>(PatternRewriter& rewriter,
                                            Operation* op, ArrayRef<APInt> vec,
                                            ArrayRef<int64_t> shape) {
  int64_t num_total_elements = 1;
  for (int64_t a : shape) {
    num_total_elements *= a;
  }

  if (vec.size() != num_total_elements) {
    op->emitOpError("getConstTensor(): number of elements mismatch.");
    return llvm::None;
  }

  auto const_type = RankedTensorType::get(
      shape, rewriter.getIntegerType(vec[0].getBitWidth()));
  auto const_attr = DenseElementsAttr::get(const_type, vec);

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Template specialization for float.
template <>
llvm::Optional<Value> getConstTensor<float>(PatternRewriter& rewriter,
                                            Operation* op, ArrayRef<float> vec,
                                            ArrayRef<int64_t> shape) {
  int64_t num_total_elements = 1;
  for (int64_t a : shape) {
    num_total_elements *= a;
  }

  if (vec.size() != num_total_elements) {
    op->emitOpError("getConstTensor(): number of elements mismatch.");
    return llvm::None;
  }

  auto const_type = RankedTensorType::get(shape, rewriter.getF32Type());
  auto const_attr = DenseElementsAttr::get(const_type, vec);

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Template instantiation.
template llvm::Optional<Value> getConstTensor<int32_t>(PatternRewriter&,
                                                       Operation*,
                                                       ArrayRef<int32_t> vec,
                                                       ArrayRef<int64_t> shape);

// Check if scale32 mode is used for given output_element_type.
bool isScale32(mlir::quant::UniformQuantizedType output_element_type) {
  return (output_element_type.getStorageTypeIntegralWidth() == 8);
}

LogicalResult ApplyPatternsWithShapeResolution(
    func::FuncOp func, const FrozenRewritePatternSet& patterns) {
  // We use top-down traversal so that shape inference can fully infer types
  // during pattern rewrite.
  GreedyRewriteConfig config;
  config.useTopDownTraversal = true;
  if (failed(applyPatternsAndFoldGreedily(func, patterns, config))) {
    return failure();
  }

  // Check that constant attribute types and op types match up. If the lowering
  // needs to change a type (e.g. fp16 -> fp32) it's possible the return type
  // could be incorrect.
  //
  // This should be investigated to determine whether it is still necessary
  // given the changes to quant type stripping.
  func.walk([&](tosa::ConstOp op) {
    auto ety = op.getValue().getType().getElementType();
    auto new_ty = op.getType().cast<ShapedType>().clone(ety);
    op.getResult().setType(new_ty);
  });

  auto returnOp = cast<func::ReturnOp>(func.getBody().front().getTerminator());
  llvm::SmallVector<Type> result_tys(returnOp.getOperandTypes());

  func.setType(FunctionType::get(
      func.getContext(), func.getFunctionType().getInputs(), result_tys));

  return success();
}

}  // namespace tosa
}  // namespace mlir