/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"

#include "mlir/Dialect/Tosa/IR/TosaOps.h"  // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"

// Implements legalization and post-legalization optimization helper
// functions.

namespace mlir {
namespace tosa {

// Create a TOSA rescale op from TFLite scaling, zero points and rounding mode.
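//
// The RESCALE op approximates the real-valued transform
//   output = (input - input_zp) * scale + output_zp
// in fixed point, with scale decomposed as multiplier * 2^-shift via
// computeMultiplierAndShift().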
Value buildRescale(PatternRewriter& rewriter, Operation* op,
                   ShapedType output_type, Value input_val, double scale,
                   int64_t input_zp, int64_t output_zp, bool double_round,
                   bool scale32) {
  int32_t multiplier;
  int32_t shift;

  int32_t scale_width = scale32 ? 32 : 16;

  computeMultiplierAndShift(scale, multiplier, shift, scale_width);

  auto rescale_op = rewriter.create<tosa::RescaleOp>(
      op->getLoc(), output_type, input_val,
      rewriter.getI32IntegerAttr(static_cast<int32_t>(input_zp)),
      rewriter.getI32IntegerAttr(static_cast<int32_t>(output_zp)),
      rewriter.getI32ArrayAttr({multiplier}), rewriter.getI32ArrayAttr({shift}),
      rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(double_round),
      rewriter.getBoolAttr(false));

  return rescale_op.getResult();
}

// Creates a TOSA rescale op with int32 output.
Value buildRescaleToInt32(PatternRewriter& rewriter, Operation* op,
                          Value input_val, double input_scale,
                          int64_t input_zp) {
  // Output is always int32 type
  auto input_type = input_val.getType().dyn_cast<mlir::ShapedType>();
  assert(input_type);
  auto output_type = input_type.clone(rewriter.getI32Type());

  return buildRescale(rewriter, op, output_type, input_val, input_scale,
                      input_zp, 0, false, true);
}

// Creates a TOSA rescale op with int32 input.
Value buildRescaleFromInt32(PatternRewriter& rewriter, Operation* op,
                            ShapedType output_type, Value input_val,
                            double output_scale, int64_t output_zp) {
  // Input should be int32 type
  auto input_type = input_val.getType().dyn_cast<mlir::ShapedType>();
  (void)input_type;
  assert(input_type && input_type.getElementType().isInteger(32) &&
         "expected rescale input element type to be i32");

  // Potentially check input_shape == output_shape here
  return buildRescale(rewriter, op, output_type, input_val, output_scale, 0,
                      output_zp, true, true);
}
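
// Note: buildRescaleToInt32 and buildRescaleFromInt32 are typically used as
// a pair, bracketing elementwise arithmetic that is carried out in a common
// int32 intermediate domain.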

// Creates a TOSA rescale op based on conv2d parameters.
Value buildRescaleOpConvOutput(PatternRewriter& rewriter, Operation* op,
                               Value conv_val, ShapedType input_type,
                               ShapedType weight_type, ShapedType output_type) {
  auto input_qtype =
      input_type.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
  auto output_qtype = output_type.getElementType()
                          .dyn_cast<mlir::quant::UniformQuantizedType>();

  double input_scale = input_qtype.getScale();

  int64_t output_zp = output_qtype.getZeroPoint();
  double output_scale = output_qtype.getScale();

  bool scale32 = isScale32(output_qtype);
  int32_t scale_width = scale32 ? 32 : 16;

  if (auto weight_per_tensor_qtype =
          weight_type.getElementType()
              .dyn_cast<mlir::quant::UniformQuantizedType>()) {
    // Per-tensor quantization
    double weight_scale = weight_per_tensor_qtype.getScale();

    int32_t multiplier;
    int32_t shift;

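    // The conv accumulator is in units of (input_scale * weight_scale);
    // dividing by output_scale yields the fixed-point scale that maps the
    // int32 accumulator onto the quantized output type.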
    double op_tensor_scale = (input_scale * weight_scale) / output_scale;

    computeMultiplierAndShift(op_tensor_scale, multiplier, shift, scale_width);

    auto rescale_op = rewriter.create<tosa::RescaleOp>(
        op->getLoc(), output_type, conv_val, rewriter.getI32IntegerAttr(0),
        rewriter.getI32IntegerAttr(output_zp),
        rewriter.getI32ArrayAttr({multiplier}),
        rewriter.getI32ArrayAttr({shift}), rewriter.getBoolAttr(scale32),
        rewriter.getBoolAttr(true), rewriter.getBoolAttr(false));

    return rescale_op.getResult();

  } else if (auto weight_per_channel_qtype =
                 weight_type.getElementType()
                     .dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
    // Per-channel quantization: compute a multiplier/shift pair for each
    // output channel.
    auto output_last_axis = output_type.getShape().size() - 1;
    uint32_t output_channels = output_type.getShape()[output_last_axis];

    SmallVector<int32_t> multiplier_arr;
    SmallVector<int32_t> shift_arr;

    SmallVector<double> weight_scale_arr(
        weight_per_channel_qtype.getScales().begin(),
        weight_per_channel_qtype.getScales().end());

    for (uint32_t oc = 0; oc < output_channels; oc++) {
      double weight_scale = weight_scale_arr[oc];

      int32_t multiplier;
      int32_t shift;

      double op_channel_scale = (input_scale * weight_scale) / output_scale;

      computeMultiplierAndShift(op_channel_scale, multiplier, shift,
                                scale_width);

      multiplier_arr.push_back(multiplier);
      shift_arr.push_back(shift);
    }

    auto rescale_op = rewriter.create<tosa::RescaleOp>(
        op->getLoc(), output_type, conv_val, rewriter.getI32IntegerAttr(0),
        rewriter.getI32IntegerAttr(output_zp),
        rewriter.getI32ArrayAttr(multiplier_arr),
        rewriter.getI32ArrayAttr(shift_arr), rewriter.getBoolAttr(scale32),
        rewriter.getBoolAttr(true), rewriter.getBoolAttr(true));

    return rescale_op.getResult();

  } else {
    op->emitOpError("buildRescaleOpConvOutput: unknown weight quantized type");
    return nullptr;
  }
}

// Create an 8-bit TOSA TABLE constant tensor from an int8[256] array.
// Follows PopulateLookupTable() in tensorflow/lite/kernels/activations.cc.
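//
// The 256-entry table enumerates func() for every possible int8 input
// value, so a single TOSA TABLE op can implement an arbitrary elementwise
// transform of an int8 tensor.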
Value getTosaConst8bitTable(PatternRewriter& rewriter, Operation* op,
                            double input_scale, int32_t input_zp,
                            double output_scale, int32_t output_zp,
                            std::function<double(double)> func) {
  SmallVector<int8_t, 256> table;

  for (int32_t i = -128; i < 128; i++) {
    // Dequantize the int8 input, apply func(), then requantize and clamp
    // to the int8 range.
    double dequantized = input_scale * (i - input_zp);
    double transformed = func(dequantized);
    int32_t rescaled = std::llround(transformed / output_scale);
    int32_t quantized = static_cast<int32_t>(rescaled + output_zp);
    table.push_back(
        static_cast<int8_t>(std::min(std::max(quantized, -128), 127)));
  }

  auto element_qtype =
      UniformQuantizedType::get(true, rewriter.getIntegerType(8),
                                rewriter.getF32Type(), 1.0f, 0, -128, 127);
  auto const_type = RankedTensorType::get({256}, element_qtype);
  auto storage_type =
      RankedTensorType::get({256}, element_qtype.getStorageType());
  auto const_attr =
      DenseElementsAttr::get(storage_type, llvm::makeArrayRef(table));

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Create a 16-bit TOSA TABLE constant tensor from an int16[513] array.
// Output is restricted to [-1.0, 1.0].
// Follows gen_lut() in tensorflow/lite/kernels/internal/common.h.
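//
// func() is sampled at 512 evenly spaced points plus the interval end;
// each base sample is then nudged by half the midpoint interpolation error
// so that linear interpolation between adjacent entries tracks func() more
// closely.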
Value getTosaConst16bitTable(PatternRewriter& rewriter, Operation* op,
                             std::function<double(double)> func, double min,
                             double max) {
  SmallVector<int16_t, 513> table;

  double step = (max - min) / 512.0;
  double half_step = step / 2.0;
  for (int32_t i = 0; i < 512; i++) {
    int32_t sample_val = std::llround(func(min + (i * step)) * 32768.0);
    double midpoint_interp_val =
        std::round(((func(min + (i + 1) * step) * 32768.0) +
                    std::round(func(min + (i * step)) * 32768.0)) /
                   2.0);
    double midpoint_val =
        std::round(func(min + (i * step) + half_step) * 32768.0);
    double midpoint_err = midpoint_interp_val - midpoint_val;
    int32_t bias = std::llround(midpoint_err / 2.0);

    table.push_back(static_cast<int16_t>(
        std::min(std::max(sample_val - bias, -32768), 32767)));
  }

  int32_t max_val = std::llround(func(max) * 32768.0);
  table.push_back(
      static_cast<int16_t>(std::min(std::max(max_val, -32768), 32767)));

  auto element_qtype =
      UniformQuantizedType::get(true, rewriter.getIntegerType(16),
                                rewriter.getF32Type(), 1.0f, 0, -32768, 32767);
  auto const_type = RankedTensorType::get({513}, element_qtype);
  auto storage_type =
      RankedTensorType::get({513}, element_qtype.getStorageType());
  auto const_attr =
      DenseElementsAttr::get(storage_type, llvm::makeArrayRef(table));

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Create a 32-bit TOSA TABLE as a pair of 16-bit constant tensors
// (int16[513] each) holding the upper and lower half of each entry.
// Output is restricted to [-1.0, 1.0], represented in s0.31 format.
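//
// A TOSA TABLE entry is only 16 bits wide, so each s0.31 result is split
// across two tables: upper_const holds the top 16 bits of each entry and
// lower_const the (0x8000-offset) bottom 16 bits. Legalization recombines
// them as (upper << 16) + (lower + 0x8000).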
void getTosaConst32bitTable(PatternRewriter& rewriter, Operation* op,
                            double input_scale, int32_t input_zp,
                            std::function<double(double)> func,
                            Value& upper_const, Value& lower_const) {
  SmallVector<int16_t, 513> upper_table, lower_table;

  double output_inv_scale = static_cast<double>(1L << 31);

  for (int32_t i = -256; i <= 256; i++) {
    double dequantized = input_scale * (i - input_zp);
    double transformed = func(dequantized);
    double truncated = std::min(std::max(transformed, -1.0), 1.0);
    int64_t rescaled =
        static_cast<int64_t>(std::round(truncated * output_inv_scale));

    // 2^31 is not representable in int32_t, so store as 2^31 - 1 instead
    if (rescaled == static_cast<int64_t>(1L << 31)) {
      rescaled = static_cast<int64_t>(1L << 31) - 1;
    }

    int32_t upper = (rescaled >> 16) & 0xFFFF;
    // TABLE output is signed 16 bits with range [-32768, 32767].
    // The lower 16 bits are unsigned with range [0, 65535], so offset them
    // by 0x8000 during table generation; legalization must add this back
    // before recovering the 32-bit value.
    int32_t lower = (rescaled & 0xFFFF) - 0x8000;

    upper_table.push_back(upper);
    lower_table.push_back(lower);
  }

  auto element_qtype =
      UniformQuantizedType::get(true, rewriter.getIntegerType(16),
                                rewriter.getF32Type(), 1.0f, 0, -32768, 32767);
  auto const_type = RankedTensorType::get({513}, element_qtype);
  auto storage_type =
      RankedTensorType::get({513}, element_qtype.getStorageType());

  auto upper_const_attr =
      DenseElementsAttr::get(storage_type, llvm::makeArrayRef(upper_table));
  auto lower_const_attr =
      DenseElementsAttr::get(storage_type, llvm::makeArrayRef(lower_table));

  upper_const =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, upper_const_attr)
          .getResult();
  lower_const =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, lower_const_attr)
          .getResult();
}

// Create a 32-bit float constant operator from a float.
Value getTosaConstTensorSingleF32(PatternRewriter& rewriter, Operation* op,
                                  float val) {
  auto const_type = RankedTensorType::get({}, rewriter.getF32Type());
  auto const_attr = DenseElementsAttr::get(const_type, val);

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Create a 32-bit integer constant operator from an int.
Value getTosaConstTensorSingleI32(PatternRewriter& rewriter, Operation* op,
                                  int32_t val) {
  auto const_type = RankedTensorType::get({}, rewriter.getIntegerType(32));
  auto const_attr = DenseElementsAttr::get(const_type, val);

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Populates a vector from a 32-bit integer constant tensor value.
// Returns failure() if the value is not defined by a constant.
LogicalResult getVectorFromValue32(Value val, SmallVectorImpl<int32_t>& vec) {
  ElementsAttr elems;

  vec.clear();

  if (!matchPattern(val, m_Constant(&elems))) return failure();

  for (auto idx : elems.getValues<IntegerAttr>()) {
    vec.push_back(idx.getInt());
  }

  return success();
}

// Calculates the TOSA padding values based on TF operators padded with
// SAME/VALID.
//
// This could pass tensorflow::FilterTensorFormat and do
// GetFilterTensorSpatialDimIndex, but the current TF core libs do not support
// FORMAT_OHWI parsing by that function in core/util/tensor_format.h.
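//
// For example, with SAME padding, stride 1 and dilation 1, an input spatial
// size of 5 and a filter size of 3 give a total padding of
// (5 - 1) * 1 + 3 - 5 = 2, split as pad_before = 1, pad_after = 1.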
bool getPaddingValuesFromPadType(tensorflow::Padding tf_pad,
                                 tensorflow::TensorFormat data_format_tf,
                                 uint32_t first_filter_spatial_dim,
                                 ShapedType input_type, ShapedType filter_type,
                                 ArrayAttr strides, ArrayAttr dilations,
                                 PatternRewriter& rewriter,
                                 ArrayAttr& explicit_padding) {
  assert(tf_pad != tensorflow::Padding::EXPLICIT);
  if (!input_type.hasRank() || !filter_type.hasRank()) return false;

  // Storing the numeric padding values is useful for TOSA codegen, as opposed
  // to holding the padding regime mnemonic, i.e. SAME, VALID, FULL, ...
  SmallVector<int64_t> computed_paddings;

  int64_t pad_before, pad_after;
  for (int i = 0; i < 2; i++) {  // Two spatial dimensions X&Y
    int64_t ifm_dim = GetTensorSpatialDimIndex(
        4, data_format_tf, i);  // 4D tensor, NHWC/NCHW format
    int64_t filter_dim = first_filter_spatial_dim + i;

    int64_t dim_dilation = dilations[i].template cast<IntegerAttr>().getInt();
    int64_t dim_stride = strides[i].template cast<IntegerAttr>().getInt();

    int64_t ip_size = input_type.getDimSize(ifm_dim);
    int64_t f_size = filter_type.getDimSize(filter_dim);
    // If we have a dynamic shape we should assume it is wide enough.
    ip_size = ip_size < 0 ? f_size * dim_dilation : ip_size;
    int64_t op_size, pad_before_tf,
        pad_after_tf;  // Out-params for the TF helper below
    tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
        ip_size, f_size, dim_dilation, dim_stride, tf_pad, &op_size,
        &pad_before_tf, &pad_after_tf);
    if (!status.ok()) return false;

    pad_before = pad_before_tf;
    pad_after = pad_after_tf;
    computed_paddings.push_back(pad_before);
    computed_paddings.push_back(pad_after);
  }

  explicit_padding = rewriter.getI64ArrayAttr(computed_paddings);
  return true;
}

// Calculates the TOSA padding values for explicit-padded TF operators.
//
// This function only handles the TF padding array explicit_padding, which is
// only present in certain TF ops. All others encode padding using the string
// SAME/VALID, which is interpreted using the getPaddingValuesFromPadType
// function above.
//
// The explicit padding array in TF holds 2 pad values for every
// dimension, even those that are not the 2 spatial ones. Just extract the
// 2x pad values for the XY dims.
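//
// For example, with NHWC data format and
// explicit_pad = {0, 0, 1, 2, 3, 4, 0, 0}, the extracted (H, W) paddings
// are {1, 2, 3, 4}.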
ArrayAttr getPaddingValuesFromExplicitPadAttr(
    ArrayAttr explicit_pad, tensorflow::TensorFormat data_format_tf,
    PatternRewriter& rewriter) {
  SmallVector<int64_t> computed_paddings;

  int64_t pad_before, pad_after;
  for (int i = 0; i < 2; i++) {  // Two spatial dimensions X&Y
    int64_t dim = GetTensorSpatialDimIndex(4, data_format_tf,
                                           i);  // 4D tensor, NHWC/NCHW format

    pad_before = explicit_pad[dim * 2].template cast<IntegerAttr>().getInt();
    pad_after = explicit_pad[dim * 2 + 1].template cast<IntegerAttr>().getInt();
    computed_paddings.push_back(pad_before);
    computed_paddings.push_back(pad_after);
  }

  return rewriter.getI64ArrayAttr(computed_paddings);
}

// Calculates the TOSA padding values for transpose_conv2d.
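//
// For example, with stride 2, dilation 1, ifm_size 16, filter_size 3 and
// ofm_size 32, total_padding = (16 - 1) * 2 + 3 - 32 = 1, giving
// pad_before = 0.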
bool getTransposeConv2dPaddingValues(
    tensorflow::Padding tf_pad, tensorflow::TensorFormat data_format_tf,
    uint32_t first_filter_spatial_dim, ShapedType input_type,
    ShapedType filter_type, ShapedType output_type, ArrayAttr strides,
    ArrayAttr dilations, PatternRewriter& rewriter,
    ArrayAttr& explicit_padding) {
  assert(tf_pad != tensorflow::Padding::EXPLICIT);
  if (!input_type.hasRank() || !filter_type.hasRank() || !output_type.hasRank())
    return false;

  // Storing the numeric padding values is useful for TOSA codegen, as opposed
  // to holding the padding regime mnemonic, i.e. SAME, VALID, FULL, ...

  SmallVector<int64_t> computed_paddings;

  int64_t pad_before, pad_after;
  for (int i = 0; i < 2; i++) {  // Two spatial dimensions X&Y
    int64_t ifm_dim = GetTensorSpatialDimIndex(
        4, data_format_tf, i);  // 4D tensor, NHWC/NCHW format
    int64_t ofm_dim = GetTensorSpatialDimIndex(
        4, data_format_tf, i);  // 4D tensor, NHWC/NCHW format
    int64_t filter_dim = first_filter_spatial_dim + i;

    int64_t ifm_size = input_type.getDimSize(ifm_dim);
    int64_t filter_size = filter_type.getDimSize(filter_dim);
    int64_t ofm_size = output_type.getDimSize(ofm_dim);
    int64_t dim_dilation = dilations[i].template cast<IntegerAttr>().getInt();
    int64_t dim_stride = strides[i].template cast<IntegerAttr>().getInt();

    // Invert the forward-conv output size relation to recover the total
    // padding, then clamp it at zero.
    int64_t effective_filter_size = (filter_size - 1) * dim_dilation + 1;
    int64_t total_padding =
        ((ifm_size - 1) * dim_stride + effective_filter_size - ofm_size);
    total_padding = total_padding > 0 ? total_padding : 0;

    pad_before = total_padding / 2;
    pad_after = total_padding - pad_before;

    // Note: only the leading (pad_before) value is emitted per spatial
    // dimension here.
    computed_paddings.push_back(pad_before);
  }

  explicit_padding = rewriter.getI64ArrayAttr(computed_paddings);
  return true;
}

// Templated function to create a constant op for a given type and shape.
// T: storage C type.
// The default template creates a constant tensor whose integer element
// bit width matches T.
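//
// A typical use (illustrative only) builds a small permutation constant:
//   SmallVector<int32_t> perm = {0, 2, 3, 1};
//   llvm::Optional<Value> perm_const =
//       getConstTensor<int32_t>(rewriter, op, perm, {4});
//   if (!perm_const) return failure();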
template <typename T>
llvm::Optional<Value> getConstTensor(PatternRewriter& rewriter, Operation* op,
                                     ArrayRef<T> vec, ArrayRef<int64_t> shape) {
  int64_t num_total_elements = 1;
  for (int64_t a : shape) {
    num_total_elements *= a;
  }

  if (vec.size() != static_cast<size_t>(num_total_elements)) {
    op->emitOpError("getConstTensor(): number of elements mismatch.");
    return llvm::None;
  }

  auto const_type =
      RankedTensorType::get(shape, rewriter.getIntegerType(sizeof(T) * 8));
  auto const_attr = DenseElementsAttr::get(const_type, vec);

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Template specialization for APInt
template <>
llvm::Optional<Value> getConstTensor<APInt>(PatternRewriter& rewriter,
                                            Operation* op, ArrayRef<APInt> vec,
                                            ArrayRef<int64_t> shape) {
  int64_t num_total_elements = 1;
  for (int64_t a : shape) {
    num_total_elements *= a;
  }

  if (vec.size() != static_cast<size_t>(num_total_elements)) {
    op->emitOpError("getConstTensor(): number of elements mismatch.");
    return llvm::None;
  }

  // Element bit width is taken from the APInt values themselves.
  auto const_type = RankedTensorType::get(
      shape, rewriter.getIntegerType(vec[0].getBitWidth()));
  auto const_attr = DenseElementsAttr::get(const_type, vec);

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Template specialization for float
template <>
llvm::Optional<Value> getConstTensor<float>(PatternRewriter& rewriter,
                                            Operation* op, ArrayRef<float> vec,
                                            ArrayRef<int64_t> shape) {
  int64_t num_total_elements = 1;
  for (int64_t a : shape) {
    num_total_elements *= a;
  }

  if (vec.size() != static_cast<size_t>(num_total_elements)) {
    op->emitOpError("getConstTensor(): number of elements mismatch.");
    return llvm::None;
  }

  auto const_type = RankedTensorType::get(shape, rewriter.getF32Type());
  auto const_attr = DenseElementsAttr::get(const_type, vec);

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Template instantiation
template llvm::Optional<Value> getConstTensor<int32_t>(PatternRewriter&,
                                                       Operation*,
                                                       ArrayRef<int32_t> vec,
                                                       ArrayRef<int64_t> shape);

// Check if scale32 mode is used for given output_element_type
bool isScale32(mlir::quant::UniformQuantizedType output_element_type) {
  return (output_element_type.getStorageTypeIntegralWidth() == 8);
}

}  // namespace tosa
}  // namespace mlir