/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"

#include "mlir/Dialect/Tosa/IR/TosaOps.h"  // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"

// Implements legalization and post-legalization optimization helper functions.

namespace mlir {
namespace tosa {
// Create a TOSA rescale op from TFLite scaling, zero points, and rounding mode
Value buildRescale(PatternRewriter& rewriter, Operation* op,
                   ShapedType output_type, Value input_val, double scale,
                   int64_t input_zp, int64_t output_zp, bool double_round,
                   bool scale32) {
  int32_t multiplier;
  int32_t shift;

  int32_t scale_width = scale32 ? 32 : 16;

  computeMultiplierAndShift(scale, multiplier, shift, scale_width);

  auto rescale_op = rewriter.create<tosa::RescaleOp>(
      op->getLoc(), output_type, input_val,
      rewriter.getI32IntegerAttr(static_cast<int32_t>(input_zp)),
      rewriter.getI32IntegerAttr(static_cast<int32_t>(output_zp)),
      rewriter.getI32ArrayAttr({multiplier}), rewriter.getI32ArrayAttr({shift}),
      rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(double_round),
      rewriter.getBoolAttr(false));

  return rescale_op.getResult();
}
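
// A hedged worked example of the decomposition computeMultiplierAndShift
// performs above (exact rounding is up to that helper): in scale32 mode the
// real-valued scale is approximated as
//   scale ~= multiplier * 2^-shift, multiplier normalized to [2^30, 2^31).
// For scale = 0.5 this yields multiplier = 2^30 and shift = 31, since
//   2^30 * 2^-31 = 0.5.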

// Creates TOSA rescale op with int32 output
Value buildRescaleToInt32(PatternRewriter& rewriter, Operation* op,
                          Value input_val, double input_scale,
                          int64_t input_zp) {
  // Output is always int32 type
  auto input_type = input_val.getType().dyn_cast<mlir::ShapedType>();
  assert(input_type);
  auto output_type = input_type.clone(rewriter.getI32Type());

  return buildRescale(rewriter, op, output_type, input_val, input_scale,
                      input_zp, 0, false, true);
}

// Creates TOSA rescale op with int32 input
Value buildRescaleFromInt32(PatternRewriter& rewriter, Operation* op,
                            ShapedType output_type, Value input_val,
                            double output_scale, int64_t output_zp) {
  // Input should be int32 type
  auto input_type = input_val.getType().dyn_cast<mlir::ShapedType>();
  (void)input_type;
  assert(input_type && input_type.getElementType().isInteger(32) &&
         "expected rescale input element type to be i32");

  // Potentially check input_shape == output_shape here
  return buildRescale(rewriter, op, output_type, input_val, output_scale, 0,
                      output_zp, true, true);
}
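
// A common usage pattern for the two wrappers above (a sketch, not taken from
// a specific caller; lhs/rhs/int32_type/out_scale/out_zp are placeholder
// names): rescale quantized operands into int32, apply the integer op, then
// rescale back to the quantized output type.
//
//   Value lhs32 = buildRescaleToInt32(rewriter, op, lhs, lhs_scale, lhs_zp);
//   Value rhs32 = buildRescaleToInt32(rewriter, op, rhs, rhs_scale, rhs_zp);
//   auto add = rewriter.create<tosa::AddOp>(op->getLoc(), int32_type,
//                                           lhs32, rhs32);
//   Value result = buildRescaleFromInt32(rewriter, op, output_type,
//                                        add.getResult(), out_scale, out_zp);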

// Creates a TOSA rescale op based on conv2d parameters.
Value buildRescaleOpConvOutput(PatternRewriter& rewriter, Operation* op,
                               Value conv_val, ShapedType input_type,
                               ShapedType weight_type, ShapedType output_type) {
  auto input_qtype =
      input_type.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
  auto output_qtype = output_type.getElementType()
                          .dyn_cast<mlir::quant::UniformQuantizedType>();

  double input_scale = input_qtype.getScale();

  int64_t output_zp = output_qtype.getZeroPoint();
  double output_scale = output_qtype.getScale();

  bool scale32 = isScale32(output_qtype);
  int32_t scale_width = scale32 ? 32 : 16;

  if (auto weight_per_tensor_qtype =
          weight_type.getElementType()
              .dyn_cast<mlir::quant::UniformQuantizedType>()) {
    // Per-tensor quantization
    double weight_scale = weight_per_tensor_qtype.getScale();

    int32_t multiplier;
    int32_t shift;

    double op_tensor_scale = (input_scale * weight_scale) / output_scale;

    computeMultiplierAndShift(op_tensor_scale, multiplier, shift, scale_width);

    auto rescale_op = rewriter.create<tosa::RescaleOp>(
        op->getLoc(), output_type, conv_val, rewriter.getI32IntegerAttr(0),
        rewriter.getI32IntegerAttr(output_zp),
        rewriter.getI32ArrayAttr({multiplier}),
        rewriter.getI32ArrayAttr({shift}), rewriter.getBoolAttr(scale32),
        rewriter.getBoolAttr(true), rewriter.getBoolAttr(false));

    return rescale_op.getResult();

  } else if (auto weight_per_channel_qtype =
                 weight_type.getElementType()
                     .dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
    // Per-channel quantization
    auto output_last_axis = output_type.getShape().size() - 1;
    uint32_t output_channels = output_type.getShape()[output_last_axis];

    SmallVector<int32_t> multiplier_arr;
    SmallVector<int32_t> shift_arr;

    SmallVector<double> weight_scale_arr(
        weight_per_channel_qtype.getScales().begin(),
        weight_per_channel_qtype.getScales().end());

    for (uint32_t oc = 0; oc < output_channels; oc++) {
      double weight_scale = weight_scale_arr[oc];

      int32_t multiplier;
      int32_t shift;

      double op_channel_scale = (input_scale * weight_scale) / output_scale;

      computeMultiplierAndShift(op_channel_scale, multiplier, shift,
                                scale_width);

      multiplier_arr.push_back(multiplier);
      shift_arr.push_back(shift);
    }

    auto rescale_op = rewriter.create<tosa::RescaleOp>(
        op->getLoc(), output_type, conv_val, rewriter.getI32IntegerAttr(0),
        rewriter.getI32IntegerAttr(output_zp),
        rewriter.getI32ArrayAttr(multiplier_arr),
        rewriter.getI32ArrayAttr(shift_arr), rewriter.getBoolAttr(scale32),
        rewriter.getBoolAttr(true), rewriter.getBoolAttr(true));

    return rescale_op.getResult();

  } else {
    op->emitOpError("buildRescaleOpConvOutput: unknown weight quantized type");
    return nullptr;
  }
}
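
// A hedged numeric example of the combined rescale factor computed above:
// with input_scale = 0.5, weight_scale = 0.25 and output_scale = 0.125, the
// per-tensor factor is
//   op_tensor_scale = (0.5 * 0.25) / 0.125 = 1.0,
// i.e. the int32 accumulator passes through unscaled apart from the output
// zero point. Per-channel quantization repeats this computation with one
// weight_scale (and thus one multiplier/shift pair) per output channel.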

// Create an 8-bit TOSA TABLE constant tensor backed by an int8[256] array.
// Follows PopulateLookupTable() in tensorflow/lite/kernels/activations.cc
Value getTosaConst8bitTable(PatternRewriter& rewriter, Operation* op,
                            double input_scale, int32_t input_zp,
                            double output_scale, int32_t output_zp,
                            std::function<double(double)> func) {
  SmallVector<int8_t, 256> table;

  for (int32_t i = -128; i < 128; i++) {
    double dequantized = input_scale * (i - input_zp);
    double transformed = func(dequantized);
    int32_t rescaled = std::llround(transformed / output_scale);
    int32_t quantized = static_cast<int32_t>(rescaled + output_zp);
    table.push_back(
        static_cast<int8_t>(std::min(std::max(quantized, -128), 127)));
  }

  auto element_qtype =
      UniformQuantizedType::get(true, rewriter.getIntegerType(8),
                                rewriter.getF32Type(), 1.0f, 0, -128, 127);
  auto const_type = RankedTensorType::get({256}, element_qtype);
  auto storage_type =
      RankedTensorType::get({256}, element_qtype.getStorageType());
  auto const_attr =
      DenseElementsAttr::get(storage_type, llvm::makeArrayRef(table));

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}
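
// A hedged worked example of one table entry above: with func = tanh,
// input_scale = 0.1, input_zp = 0, output_scale = 1.0 / 128 and
// output_zp = 0, the entry for i = 10 is
//   dequantized = 0.1 * 10            = 1.0
//   transformed = tanh(1.0)           ~= 0.7616
//   rescaled    = round(0.7616 * 128) = 97
// which already lies inside [-128, 127], so 97 is stored.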

// Create a 16-bit TOSA TABLE constant tensor backed by an int16[513] array.
// Output is restricted to [-1.0, 1.0].
// Follows gen_lut() in tensorflow/lite/kernels/internal/common.h
Value getTosaConst16bitTable(PatternRewriter& rewriter, Operation* op,
                             std::function<double(double)> func, double min,
                             double max) {
  SmallVector<int16_t, 513> table;

  double step = (max - min) / 512.0;
  double half_step = step / 2.0;
  for (int32_t i = 0; i < 512; i++) {
    int32_t sample_val = std::llround(func(min + (i * step)) * 32768.0);
    double midpoint_interp_val =
        std::round(((func(min + (i + 1) * step) * 32768.0) +
                    std::round(func(min + (i * step)) * 32768.0)) /
                   2.0);
    double midpoint_val =
        std::round(func(min + (i * step) + half_step) * 32768.0);
    double midpoint_err = midpoint_interp_val - midpoint_val;
    int32_t bias = std::llround(midpoint_err / 2.0);

    table.push_back(static_cast<int16_t>(
        std::min(std::max(sample_val - bias, -32768), 32767)));
  }

  int32_t max_val = std::llround(func(max) * 32768.0);
  table.push_back(
      static_cast<int16_t>(std::min(std::max(max_val, -32768), 32767)));

  auto element_qtype =
      UniformQuantizedType::get(true, rewriter.getIntegerType(16),
                                rewriter.getF32Type(), 1.0f, 0, -32768, 32767);
  auto const_type = RankedTensorType::get({513}, element_qtype);
  auto storage_type =
      RankedTensorType::get({513}, element_qtype.getStorageType());
  auto const_attr =
      DenseElementsAttr::get(storage_type, llvm::makeArrayRef(table));

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}
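
// In Q0.15 terms, each of the 512 interval entries above stores
//   entry_i = round(func(min + i * step) * 32768) - bias_i
// where bias_i is half the error between the linear interpolation at the
// interval midpoint and the true func value there. Biasing the endpoint
// samples this way (the scheme TFLite's gen_lut uses, which this mirrors)
// reduces the midpoint interpolation error of the 513-entry table.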

// Create a pair of int16[513] TOSA TABLE constant tensors that together
// encode a 32-bit result: upper_const holds the high 16 bits and lower_const
// the offset low 16 bits of each entry.
// Output is restricted to [-1.0, 1.0] as s0.31 format.
void getTosaConst32bitTable(PatternRewriter& rewriter, Operation* op,
                            double input_scale, int32_t input_zp,
                            std::function<double(double)> func,
                            Value& upper_const, Value& lower_const) {
  SmallVector<int16_t, 513> upper_table, lower_table;

  double output_inv_scale = static_cast<double>(1L << 31);

  for (int32_t i = -256; i <= 256; i++) {
    double dequantized = input_scale * (i - input_zp);
    double transformed = func(dequantized);
    double truncated = std::min(std::max(transformed, -1.0), 1.0);
    int64_t rescaled =
        static_cast<int64_t>(std::round(truncated * output_inv_scale));

    // 2^31 is not representable in int32_t, so store as 2^31 - 1 instead
    if (rescaled == static_cast<int64_t>(1L << 31)) {
      rescaled = static_cast<int64_t>(1L << 31) - 1;
    }

    int32_t upper = (rescaled >> 16) & 0xFFFF;
    // TABLE output is signed 16 bits with range [-32768, 32767].
    // The lower 16 bits are unsigned and range over [0, 65535], so adjust
    // them with an offset of 0x8000 during table generation.
    // Legalization should add this offset back before recovering the 32-bit
    // value.
    int32_t lower = (rescaled & 0xFFFF) - 0x8000;

    upper_table.push_back(upper);
    lower_table.push_back(lower);
  }

  auto element_qtype =
      UniformQuantizedType::get(true, rewriter.getIntegerType(16),
                                rewriter.getF32Type(), 1.0f, 0, -32768, 32767);
  auto const_type = RankedTensorType::get({513}, element_qtype);
  auto storage_type =
      RankedTensorType::get({513}, element_qtype.getStorageType());

  auto upper_const_attr =
      DenseElementsAttr::get(storage_type, llvm::makeArrayRef(upper_table));
  auto lower_const_attr =
      DenseElementsAttr::get(storage_type, llvm::makeArrayRef(lower_table));

  upper_const =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, upper_const_attr)
          .getResult();
  lower_const =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, lower_const_attr)
          .getResult();
}
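
// A hedged sketch of how a consumer recovers the 32-bit value from the two
// tables above (the legalization is expected to add the 0x8000 offset back,
// per the comment in the loop). For rescaled = 0x12345678:
//   upper = (0x12345678 >> 16) & 0xFFFF    = 0x1234
//   lower = (0x12345678 & 0xFFFF) - 0x8000 = 0x5678 - 0x8000
// and the reconstruction is
//   value32 = (upper << 16) + (lower + 0x8000) = 0x12345678.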

// Create a 32-bit float constant operator from a float
Value getTosaConstTensorSingleF32(PatternRewriter& rewriter, Operation* op,
                                  float val) {
  auto const_type = RankedTensorType::get({}, rewriter.getF32Type());
  auto const_attr = DenseElementsAttr::get(const_type, val);

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Create a 32-bit integer constant operator from an int
Value getTosaConstTensorSingleI32(PatternRewriter& rewriter, Operation* op,
                                  int32_t val) {
  auto const_type = RankedTensorType::get({}, rewriter.getIntegerType(32));
  auto const_attr = DenseElementsAttr::get(const_type, val);

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Populates a vector from a 32-bit integer constant tensor. Returns success()
// and fills `vec` if the value matches a constant, failure() otherwise.
LogicalResult getVectorFromValue32(Value val, SmallVectorImpl<int32_t>& vec) {
  ElementsAttr elems;

  vec.clear();

  if (!matchPattern(val, m_Constant(&elems))) return failure();

  for (auto idx : elems.getValues<IntegerAttr>()) {
    vec.push_back(idx.getInt());
  }

  return success();
}
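
// Usage sketch for the helper above (perm_operand is a hypothetical
// placeholder, not a specific TFLite accessor):
//   SmallVector<int32_t> perm;
//   if (failed(getVectorFromValue32(perm_operand, perm)))
//     return failure();  // operand was not a constant
//   // perm now holds the constant's elements as int32_t.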

// Calculates the TOSA padding values based on TF operators padded with
// SAME/VALID.
//
// This could pass tensorflow::FilterTensorFormat and call
// GetFilterTensorSpatialDimIndex, but the current TF core libs do not support
// FORMAT_OHWI parsing by that function in core/util/tensor_format.h
bool getPaddingValuesFromPadType(tensorflow::Padding tf_pad,
                                 tensorflow::TensorFormat data_format_tf,
                                 uint32_t first_filter_spatial_dim,
                                 ShapedType input_type, ShapedType filter_type,
                                 ArrayAttr strides, ArrayAttr dilations,
                                 PatternRewriter& rewriter,
                                 ArrayAttr& explicit_padding) {
  assert(tf_pad != tensorflow::Padding::EXPLICIT);
  if (!input_type.hasRank() || !filter_type.hasRank()) return false;

  // Storing the numeric padding values is useful for TOSA codegen, as opposed
  // to holding the padding regime mnemonic, i.e. SAME, VALID, FULL, ...
  SmallVector<int64_t> computed_paddings;

  int64_t pad_before, pad_after;
  for (int i = 0; i < 2; i++) {  // Two spatial dimensions X&Y
    int64_t ifm_dim = GetTensorSpatialDimIndex(
        4, data_format_tf, i);  // 4D tensor, NHWC/NCHW format
    int64_t filter_dim = first_filter_spatial_dim + i;

    int64_t dim_dilation = dilations[i].template cast<IntegerAttr>().getInt();
    int64_t dim_stride = strides[i].template cast<IntegerAttr>().getInt();

    int64_t ip_size = input_type.getDimSize(ifm_dim);
    int64_t f_size = filter_type.getDimSize(filter_dim);
    // If we have a dynamic shape we should assume it is wide enough.
    ip_size = ip_size < 0 ? f_size * dim_dilation : ip_size;
    int64_t op_size, pad_before_tf,
        pad_after_tf;  // Types chosen to match the TF helper's signature
    tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
        ip_size, f_size, dim_dilation, dim_stride, tf_pad, &op_size,
        &pad_before_tf, &pad_after_tf);
    if (!status.ok()) return false;

    pad_before = pad_before_tf;
    pad_after = pad_after_tf;
    computed_paddings.push_back(pad_before);
    computed_paddings.push_back(pad_after);
  }

  explicit_padding = rewriter.getI64ArrayAttr(computed_paddings);
  return true;
}
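
// A worked SAME-padding example for the helper above (values chosen for
// illustration): ip_size = 7, f_size = 3, dim_stride = 2, dim_dilation = 1.
// TF computes op_size = ceil(7 / 2) = 4 and
//   total_pad  = max((4 - 1) * 2 + 3 - 7, 0) = 2
//   pad_before = 1, pad_after = 1
// so this spatial dimension contributes {1, 1} to explicit_padding.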

// Calculates the TOSA padding values for explicit-padded TF operators.
//
// This function only handles the TF padding array explicit_padding, which is
// only present in certain TF ops. All others encode padding using the
// SAME/VALID padding type, which is handled by getPaddingValuesFromPadType
// above.
//
// The explicit padding array in TF holds 2 pad values for every
// dimension, even those that are not the 2 spatial ones. Just extract the
// 2x pad values for the XY dims.
ArrayAttr getPaddingValuesFromExplicitPadAttr(
    ArrayAttr explicit_pad, tensorflow::TensorFormat data_format_tf,
    PatternRewriter& rewriter) {
  SmallVector<int64_t> computed_paddings;

  int64_t pad_before, pad_after;
  for (int i = 0; i < 2; i++) {  // Two spatial dimensions X&Y
    int64_t dim = GetTensorSpatialDimIndex(4, data_format_tf,
                                           i);  // 4D tensor, NHWC/NCHW format

    pad_before = explicit_pad[dim * 2].template cast<IntegerAttr>().getInt();
    pad_after = explicit_pad[dim * 2 + 1].template cast<IntegerAttr>().getInt();
    computed_paddings.push_back(pad_before);
    computed_paddings.push_back(pad_after);
  }

  return rewriter.getI64ArrayAttr(computed_paddings);
}
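
// A worked example for the extraction above: with NHWC data format the
// spatial dims are 1 (H) and 2 (W), so for
//   explicit_pad = {0, 0, 1, 2, 3, 4, 0, 0}  // {N, N, H, H, W, W, C, C}
// the returned attribute holds {1, 2, 3, 4}.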

// Calculates the TOSA padding values for transpose_conv2d
bool getTransposeConv2dPaddingValues(
    tensorflow::Padding tf_pad, tensorflow::TensorFormat data_format_tf,
    uint32_t first_filter_spatial_dim, ShapedType input_type,
    ShapedType filter_type, ShapedType output_type, ArrayAttr strides,
    ArrayAttr dilations, PatternRewriter& rewriter,
    ArrayAttr& explicit_padding) {
  assert(tf_pad != tensorflow::Padding::EXPLICIT);
  if (!input_type.hasRank() || !filter_type.hasRank() || !output_type.hasRank())
    return false;

  // Storing the numeric padding values is useful for TOSA codegen, as opposed
  // to holding the padding regime mnemonic, i.e. SAME, VALID, FULL, ...

  SmallVector<int64_t> computed_paddings;

  int64_t pad_before, pad_after;
  for (int i = 0; i < 2; i++) {  // Two spatial dimensions X&Y
    int64_t ifm_dim = GetTensorSpatialDimIndex(
        4, data_format_tf, i);  // 4D tensor, NHWC/NCHW format
    int64_t ofm_dim = GetTensorSpatialDimIndex(
        4, data_format_tf, i);  // Input and output share the same format
    int64_t filter_dim = first_filter_spatial_dim + i;

    int64_t ifm_size = input_type.getDimSize(ifm_dim);
    int64_t filter_size = filter_type.getDimSize(filter_dim);
    int64_t ofm_size = output_type.getDimSize(ofm_dim);
    int64_t dim_dilation = dilations[i].template cast<IntegerAttr>().getInt();
    int64_t dim_stride = strides[i].template cast<IntegerAttr>().getInt();

    int effective_filter_size = (filter_size - 1) * dim_dilation + 1;
    int total_padding =
        ((ifm_size - 1) * dim_stride + effective_filter_size - ofm_size);
    total_padding = total_padding > 0 ? total_padding : 0;

    pad_before = total_padding / 2;
    pad_after = total_padding - pad_before;

    // Note: only the leading (top/left) pad value is recorded, so
    // explicit_padding ends up with one entry per spatial dimension.
    computed_paddings.push_back(pad_before);
  }

  explicit_padding = rewriter.getI64ArrayAttr(computed_paddings);
  return true;
}
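
// A worked example for the computation above: ifm_size = 4, dim_stride = 2,
// filter_size = 3, dim_dilation = 1, ofm_size = 8 gives
//   effective_filter_size = (3 - 1) * 1 + 1 = 3
//   total_padding = (4 - 1) * 2 + 3 - 8 = 1
//   pad_before = 0 (pad_after would be 1)
// so this spatial dimension contributes {0} to explicit_padding.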

// Templated function to create a constant op for given type and shape.
// T: storage C type.
// Default template creates a constant tensor in T.
template <typename T>
llvm::Optional<Value> getConstTensor(PatternRewriter& rewriter, Operation* op,
                                     ArrayRef<T> vec, ArrayRef<int64_t> shape) {
  int64_t num_total_elements = 1;
  for (int64_t a : shape) {
    num_total_elements *= a;
  }

  if (vec.size() != num_total_elements) {
    op->emitOpError("getConstTensor(): number of elements mismatch.");
    return llvm::None;
  }

  auto const_type =
      RankedTensorType::get(shape, rewriter.getIntegerType(sizeof(T) * 8));
  auto const_attr = DenseElementsAttr::get(const_type, vec);

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Template specialization for APInt
template <>
llvm::Optional<Value> getConstTensor<APInt>(PatternRewriter& rewriter,
                                            Operation* op, ArrayRef<APInt> vec,
                                            ArrayRef<int64_t> shape) {
  int64_t num_total_elements = 1;
  for (int64_t a : shape) {
    num_total_elements *= a;
  }

  if (vec.size() != num_total_elements) {
    op->emitOpError("getConstTensor(): number of elements mismatch.");
    return llvm::None;
  }

  auto const_type = RankedTensorType::get(
      shape, rewriter.getIntegerType(vec[0].getBitWidth()));
  auto const_attr = DenseElementsAttr::get(const_type, vec);

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Template specialization for float
template <>
llvm::Optional<Value> getConstTensor<float>(PatternRewriter& rewriter,
                                            Operation* op, ArrayRef<float> vec,
                                            ArrayRef<int64_t> shape) {
  int64_t num_total_elements = 1;
  for (int64_t a : shape) {
    num_total_elements *= a;
  }

  if (vec.size() != num_total_elements) {
    op->emitOpError("getConstTensor(): number of elements mismatch.");
    return llvm::None;
  }

  auto const_type = RankedTensorType::get(shape, rewriter.getF32Type());
  auto const_attr = DenseElementsAttr::get(const_type, vec);

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}
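
// Usage sketch for the templates above (a minimal example, not taken from a
// specific legalization): build a 2x2 i32 constant and bail out on mismatch.
//   SmallVector<int32_t> vals{1, 2, 3, 4};
//   llvm::Optional<Value> c =
//       getConstTensor<int32_t>(rewriter, op, vals, {2, 2});
//   if (!c) return llvm::None;
//   Value const_val = c.getValue();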

// Template instantiation
template llvm::Optional<Value> getConstTensor<int32_t>(PatternRewriter&,
                                                       Operation*,
                                                       ArrayRef<int32_t> vec,
                                                       ArrayRef<int64_t> shape);

// Check if scale32 mode is used for given output_element_type
bool isScale32(mlir::quant::UniformQuantizedType output_element_type) {
  return (output_element_type.getStorageTypeIntegralWidth() == 8);
}

}  // namespace tosa
}  // namespace mlir