/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
#include "mlir/Dialect/Tosa/IR/TosaOps.h"  // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"

// Implements legalization and post-legalization optimization helper functions.

namespace mlir {
namespace tosa {

// Create a TOSA rescale op from TFLite scaling, zero points and rounding mode.
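//
// Conceptually, the rescale computes
//   output = round((input - input_zp) * scale) + output_zp
// where computeMultiplierAndShift() decomposes the real-valued `scale` into a
// fixed-point pair with scale ~= multiplier * 2^-shift. Illustrative worked
// example (values not from this code): scale = 0.5 in scale32 mode becomes
// multiplier = 2^30 and shift = 31, since 2^30 * 2^-31 = 0.5.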
Value buildRescale(PatternRewriter& rewriter, Operation* op,
                   ShapedType output_type, Value input_val, double scale,
                   int64_t input_zp, int64_t output_zp, bool double_round,
                   bool scale32) {
  int32_t multiplier;
  int32_t shift;

  int32_t scale_width = scale32 ? 32 : 16;

  computeMultiplierAndShift(scale, multiplier, shift, scale_width);

  auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
      rewriter, op->getLoc(), output_type, input_val,
      rewriter.getI32IntegerAttr(static_cast<int32_t>(input_zp)),
      rewriter.getI32IntegerAttr(static_cast<int32_t>(output_zp)),
      rewriter.getI32ArrayAttr({multiplier}), rewriter.getI32ArrayAttr({shift}),
      rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(double_round),
      rewriter.getBoolAttr(false));

  return rescale_op.getResult();
}

// Creates a TOSA rescale op with int32 output.
Value buildRescaleToInt32(PatternRewriter& rewriter, Operation* op,
                          Value input_val, double input_scale,
                          int64_t input_zp) {
  // Output is always int32 type
  auto input_type = input_val.getType().dyn_cast<mlir::ShapedType>();
  assert(input_type);
  auto output_type = input_type.clone(rewriter.getI32Type());

  return buildRescale(rewriter, op, output_type, input_val, input_scale,
                      input_zp, 0, false, true);
}

// Creates a TOSA rescale op with int32 input.
Value buildRescaleFromInt32(PatternRewriter& rewriter, Operation* op,
                            ShapedType output_type, Value input_val,
                            double output_scale, int64_t output_zp) {
  // Input should be int32 type
  auto input_type = input_val.getType().dyn_cast<mlir::ShapedType>();
  (void)input_type;
  assert(input_type && input_type.getElementType().isInteger(32) &&
         "expected rescale input element type to be i32");

  // Potentially check input_shape == output_shape here
  return buildRescale(rewriter, op, output_type, input_val, output_scale, 0,
                      output_zp, true, true);
}

// Creates a TOSA rescale op based on conv2d parameters.
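//
// The int32 convolution accumulator carries scale input_scale * weight_scale,
// so requantizing it to the output type uses
//   scale = (input_scale * weight_scale) / output_scale
// per tensor, or per channel for per-channel weights. Illustrative example
// (values not from this code): input_scale = 0.02, weight_scale = 0.01 and
// output_scale = 0.05 give scale = (0.02 * 0.01) / 0.05 = 0.004, which is
// then decomposed into a (multiplier, shift) pair as in buildRescale() above.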
Value buildRescaleOpConvOutput(PatternRewriter& rewriter, Operation* op,
                               Value conv_val, ShapedType input_type,
                               ShapedType weight_type, ShapedType output_type) {
  auto input_qtype =
      input_type.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
  auto output_qtype = output_type.getElementType()
                          .dyn_cast<mlir::quant::UniformQuantizedType>();

  double input_scale = input_qtype.getScale();

  int64_t output_zp = output_qtype.getZeroPoint();
  double output_scale = output_qtype.getScale();

  bool scale32 = isScale32(output_qtype);
  int32_t scale_width = scale32 ? 32 : 16;
  // Only use double round if we are doing 32-bit scaling
  bool double_round = scale32;

  if (auto weight_per_tensor_qtype =
          weight_type.getElementType()
              .dyn_cast<mlir::quant::UniformQuantizedType>()) {
    // Per-tensor quantization
    double weight_scale = weight_per_tensor_qtype.getScale();

    int32_t multiplier;
    int32_t shift;

    double op_tensor_scale = (input_scale * weight_scale) / output_scale;

    computeMultiplierAndShift(op_tensor_scale, multiplier, shift, scale_width);

    auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
        rewriter, op->getLoc(), output_type, conv_val,
        rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(output_zp),
        rewriter.getI32ArrayAttr({multiplier}),
        rewriter.getI32ArrayAttr({shift}), rewriter.getBoolAttr(scale32),
        rewriter.getBoolAttr(double_round), rewriter.getBoolAttr(false));

    return rescale_op.getResult();

  } else if (auto weight_per_channel_qtype =
                 weight_type.getElementType()
                     .dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
    // Per-channel quantization: one (multiplier, shift) pair per channel
    SmallVector<int32_t> multiplier_arr;
    SmallVector<int32_t> shift_arr;

    SmallVector<double> weight_scale_arr(
        weight_per_channel_qtype.getScales().begin(),
        weight_per_channel_qtype.getScales().end());

    for (double weight_scale : weight_scale_arr) {
      int32_t multiplier;
      int32_t shift;

      double op_channel_scale = (input_scale * weight_scale) / output_scale;

      computeMultiplierAndShift(op_channel_scale, multiplier, shift,
                                scale_width);

      multiplier_arr.push_back(multiplier);
      shift_arr.push_back(shift);
    }

    auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
        rewriter, op->getLoc(), output_type, conv_val,
        rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(output_zp),
        rewriter.getI32ArrayAttr(multiplier_arr),
        rewriter.getI32ArrayAttr(shift_arr), rewriter.getBoolAttr(scale32),
        rewriter.getBoolAttr(double_round), rewriter.getBoolAttr(true));

    return rescale_op.getResult();

  } else {
    op->emitOpError("buildRescaleOpConvOutput: unknown weight quantized type");
    return nullptr;
  }
}

// Create an 8-bit TOSA TABLE constant tensor with an int8[256] array.
// Follows PopulateLookupTable() in tensorflow/lite/kernels/activations.cc
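//
// Each table entry maps an int8 input value v to
//   clamp(round(func(input_scale * (v - input_zp)) / output_scale)
//         + output_zp, -128, 127).
// Illustrative example (values not from this code): with input_scale = 0.1,
// input_zp = 0, output_scale = 1.0 / 128, output_zp = 0 and func = tanh, the
// entry for v = 10 is round(tanh(1.0) * 128) = 97.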
Value getTosaConst8bitTable(PatternRewriter& rewriter, Operation* op,
                            double input_scale, int32_t input_zp,
                            double output_scale, int32_t output_zp,
                            std::function<double(double)> func) {
  SmallVector<int8_t, 256> table;

  for (int32_t i = -128; i < 128; i++) {
    double dequantized = input_scale * (i - input_zp);
    double transformed = func(dequantized);
    int32_t rescaled = std::llround(transformed / output_scale);
    int32_t quantized = static_cast<int32_t>(rescaled + output_zp);
    table.push_back(
        static_cast<int8_t>(std::min(std::max(quantized, -128), 127)));
  }

  auto element_qtype =
      UniformQuantizedType::get(true, rewriter.getIntegerType(8),
                                rewriter.getF32Type(), 1.0f, 0, -128, 127);
  auto const_type = RankedTensorType::get({256}, element_qtype);
  auto storage_type =
      RankedTensorType::get({256}, element_qtype.getStorageType());
  auto const_attr =
      DenseElementsAttr::get(storage_type, llvm::makeArrayRef(table));

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Create a 16-bit TOSA TABLE constant tensor with an int16[513] array.
// Output is restricted to [-1.0, 1.0].
// Follows gen_lut() in tensorflow/lite/kernels/internal/common.h
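//
// The table holds 512 evenly spaced base samples over [min, max] plus a final
// sample at max (513 entries total), each stored as a Q0.15 value (scaled by
// 32768). Because the 16-bit TABLE lookup interpolates linearly between
// adjacent entries, each base sample is adjusted by half of the measured
// midpoint interpolation error to reduce the worst-case error (a summary of
// the gen_lut() scheme referenced above, not a normative statement).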
Value getTosaConst16bitTable(PatternRewriter& rewriter, Operation* op,
                             std::function<double(double)> func, double min,
                             double max) {
  SmallVector<int16_t, 513> table;

  double step = (max - min) / 512.0;
  double half_step = step / 2.0;
  for (int32_t i = 0; i < 512; i++) {
    int32_t sample_val = std::llround(func(min + (i * step)) * 32768.0);
    double midpoint_interp_val =
        std::round(((func(min + (i + 1) * step) * 32768.0) +
                    std::round(func(min + (i * step)) * 32768.0)) /
                   2.0);
    double midpoint_val =
        std::round(func(min + (i * step) + half_step) * 32768.0);
    double midpoint_err = midpoint_interp_val - midpoint_val;
    int32_t bias = std::llround(midpoint_err / 2.0);

    table.push_back(static_cast<int16_t>(
        std::min(std::max(sample_val - bias, -32768), 32767)));
  }

  int32_t max_val = std::llround(func(max) * 32768.0);
  table.push_back(
      static_cast<int16_t>(std::min(std::max(max_val, -32768), 32767)));

  auto element_qtype =
      UniformQuantizedType::get(true, rewriter.getIntegerType(16),
                                rewriter.getF32Type(), 1.0f, 0, -32768, 32767);
  auto const_type = RankedTensorType::get({513}, element_qtype);
  auto storage_type =
      RankedTensorType::get({513}, element_qtype.getStorageType());
  auto const_attr =
      DenseElementsAttr::get(storage_type, llvm::makeArrayRef(table));

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Create a 32-bit TOSA TABLE constant tensor, split across four int16[513]
// arrays (one per 8-bit group of each 32-bit entry).
// Output is restricted to [-1.0, 1.0] in s0.31 format.
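//
// Each of the 513 sampled outputs is an s0.31 value whose four bytes are
// distributed across the four tables (first = bits 31..24, second = bits
// 23..16, third = bits 15..8, fourth = bits 7..0), so the 32-bit result can
// be reassembled from four 16-bit TABLE lookups. Illustrative example (values
// not from this code): func(x) = 0.5 rescales to 2^30 = 0x40000000, stored as
// first = 0x40 and second = third = fourth = 0x00.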
void getTosaConst32bitTable(PatternRewriter& rewriter, Operation* op,
                            double input_scale, int32_t input_zp,
                            std::function<double(double)> func,
                            Value& first_const, Value& second_const,
                            Value& third_const, Value& fourth_const) {
  SmallVector<int16_t, 513> first_table, second_table, third_table,
      fourth_table;

  double output_inv_scale = static_cast<double>(1L << 31);

  for (int32_t i = -256; i <= 256; i++) {
    double dequantized = input_scale * (i - input_zp);
    double transformed = func(dequantized);
    double truncated = std::min(std::max(transformed, -1.0), 1.0);
    int64_t rescaled =
        static_cast<int64_t>(std::round(truncated * output_inv_scale));

    // 2^31 is not representable in int32_t, so store as 2^31 - 1 instead
    if (rescaled == static_cast<int64_t>(1L << 31)) {
      rescaled = static_cast<int64_t>(1L << 31) - 1;
    }

    // Only copy the 8-bit groups
    int32_t first = (rescaled >> 24) & 0xFF;
    int32_t second = (rescaled >> 16) & 0xFF;
    int32_t third = (rescaled >> 8) & 0xFF;
    int32_t fourth = rescaled & 0xFF;

    first_table.push_back(first);
    second_table.push_back(second);
    third_table.push_back(third);
    fourth_table.push_back(fourth);
  }

  auto element_qtype =
      UniformQuantizedType::get(true, rewriter.getIntegerType(16),
                                rewriter.getF32Type(), 1.0f, 0, -32768, 32767);
  auto const_type = RankedTensorType::get({513}, element_qtype);
  auto storage_type =
      RankedTensorType::get({513}, element_qtype.getStorageType());

  auto first_const_attr =
      DenseElementsAttr::get(storage_type, llvm::makeArrayRef(first_table));
  auto second_const_attr =
      DenseElementsAttr::get(storage_type, llvm::makeArrayRef(second_table));
  auto third_const_attr =
      DenseElementsAttr::get(storage_type, llvm::makeArrayRef(third_table));
  auto fourth_const_attr =
      DenseElementsAttr::get(storage_type, llvm::makeArrayRef(fourth_table));

  first_const =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, first_const_attr)
          .getResult();
  second_const =
      rewriter
          .create<tosa::ConstOp>(op->getLoc(), const_type, second_const_attr)
          .getResult();
  third_const =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, third_const_attr)
          .getResult();
  fourth_const =
      rewriter
          .create<tosa::ConstOp>(op->getLoc(), const_type, fourth_const_attr)
          .getResult();
}

// Create a 32-bit float constant operator from a float
Value getTosaConstTensorSingleF32(PatternRewriter& rewriter, Operation* op,
                                  float val) {
  auto const_type = RankedTensorType::get({}, rewriter.getF32Type());
  auto const_attr = DenseElementsAttr::get(const_type, val);

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Create a 32-bit integer constant operator from an int
Value getTosaConstTensorSingleI32(PatternRewriter& rewriter, Operation* op,
                                  int32_t val) {
  auto const_type = RankedTensorType::get({}, rewriter.getIntegerType(32));
  auto const_attr = DenseElementsAttr::get(const_type, val);

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Create a vector of int32_t values from a 32-bit integer constant tensor.
// Returns success() and fills `vec` if the value is a constant, failure()
// otherwise.
LogicalResult getVectorFromValue32(Value val, SmallVectorImpl<int32_t>& vec) {
  ElementsAttr elems;

  vec.clear();

  if (!matchPattern(val, m_Constant(&elems))) return failure();

  for (auto idx : elems.getValues<IntegerAttr>()) {
    vec.push_back(idx.getInt());
  }

  return success();
}

// Calculates the TOSA padding values based on TF operators padded with
// SAME/VALID.
//
// This could pass tensorflow::FilterTensorFormat and do
// GetFilterTensorSpatialDimIndex but the current TF core libs do not support
// FORMAT_OHWI parsing by that function in core/util/tensor_format.h
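//
// For SAME padding, TF computes (per spatial dimension)
//   out_size = ceil(in_size / stride)
//   total_pad = max((out_size - 1) * stride
//                   + (filter_size - 1) * dilation + 1 - in_size, 0)
// and splits it into pad_before = total_pad / 2, pad_after = total_pad -
// pad_before; VALID padding is all zeros. Illustrative example (values not
// from this code): in_size = 5, filter_size = 3, stride = 2, dilation = 1
// gives out_size = 3 and total_pad = 4 + 3 - 5 = 2, i.e. pads (1, 1).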
bool getPaddingValuesFromPadType(tensorflow::Padding tf_pad,
                                 tensorflow::TensorFormat data_format_tf,
                                 uint32_t first_filter_spatial_dim,
                                 ShapedType input_type, ShapedType filter_type,
                                 ArrayAttr strides, ArrayAttr dilations,
                                 PatternRewriter& rewriter,
                                 ArrayAttr& explicit_padding) {
  assert(tf_pad != tensorflow::Padding::EXPLICIT);
  if (!input_type.hasRank() || !filter_type.hasRank()) return false;

  // Storing the numeric padding values is useful for TOSA codegen, as opposed
  // to holding the padding regime mnemonic, i.e. SAME, VALID, FULL, ...
  SmallVector<int64_t> computed_paddings;

  int64_t pad_before, pad_after;
  for (int i = 0; i < 2; i++) {  // Two spatial dimensions X&Y
    int64_t ifm_dim = GetTensorSpatialDimIndex(
        4, data_format_tf, i);  // 4D tensor, NHWC/NCHW format
    int64_t filter_dim = first_filter_spatial_dim + i;

    int64_t dim_dilation = dilations[i].template cast<IntegerAttr>().getInt();
    int64_t dim_stride = strides[i].template cast<IntegerAttr>().getInt();

    int64_t ip_size = input_type.getDimSize(ifm_dim);
    int64_t f_size = filter_type.getDimSize(filter_dim);
    // If we have a dynamic shape we should assume it is wide enough.
    ip_size = ip_size < 0 ? f_size * dim_dilation : ip_size;
    int64_t op_size, pad_before_tf, pad_after_tf;
    tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
        ip_size, f_size, dim_dilation, dim_stride, tf_pad, &op_size,
        &pad_before_tf, &pad_after_tf);
    if (!status.ok()) return false;

    pad_before = pad_before_tf;
    pad_after = pad_after_tf;
    computed_paddings.push_back(pad_before);
    computed_paddings.push_back(pad_after);
  }

  explicit_padding = rewriter.getI64ArrayAttr(computed_paddings);
  return true;
}

// Calculates the TOSA padding values for explicit-padded TF operators.
//
// This function only handles the TF padding array explicit_padding, which is
// only present in certain TF ops. All others encode padding using the string
// SAME/VALID, which is interpreted using the getPaddingValuesFromPadType
// function above.
//
// The explicit padding array in TF holds 2 pad values for every
// dimension, even those that are not the 2 spatial ones. Just extract the
// 2x pad values for the XY dims.
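//
// Illustrative example (values not from this code): for NHWC the TF array
// [0, 0, 1, 2, 3, 4, 0, 0] holds pads (1, 2) for H and (3, 4) for W; only
// those two pairs are returned, as [1, 2, 3, 4].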
ArrayAttr getPaddingValuesFromExplicitPadAttr(
    ArrayAttr explicit_pad, tensorflow::TensorFormat data_format_tf,
    PatternRewriter& rewriter) {
  SmallVector<int64_t> computed_paddings;

  int64_t pad_before, pad_after;
  for (int i = 0; i < 2; i++) {  // Two spatial dimensions X&Y
    int64_t dim = GetTensorSpatialDimIndex(4, data_format_tf,
                                           i);  // 4D tensor, NHWC/NCHW format

    pad_before = explicit_pad[dim * 2].template cast<IntegerAttr>().getInt();
    pad_after = explicit_pad[dim * 2 + 1].template cast<IntegerAttr>().getInt();
    computed_paddings.push_back(pad_before);
    computed_paddings.push_back(pad_after);
  }

  return rewriter.getI64ArrayAttr(computed_paddings);
}

// Calculates the TOSA padding values for transposeConv2d
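//
// The desired output size fixes the total padding directly:
//   total_padding = (ifm_size - 1) * stride + filter_size - ofm_size
// clamped at zero and split with pad_before = total_padding / 2.
// Illustrative example (values not from this code): ifm_size = 4, stride = 2,
// filter_size = 3, ofm_size = 8 gives total_padding = 6 + 3 - 8 = 1, i.e.
// pads (0, 1).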
bool getTransposeConv2dPaddingValues(
    tensorflow::Padding tf_pad, tensorflow::TensorFormat data_format_tf,
    uint32_t first_filter_spatial_dim, ShapedType input_type,
    ShapedType filter_type, ShapedType output_type, ArrayAttr strides,
    PatternRewriter& rewriter, ArrayAttr& explicit_padding) {
  assert(tf_pad != tensorflow::Padding::EXPLICIT);
  if (!input_type.hasRank() || !filter_type.hasRank() || !output_type.hasRank())
    return false;

  // Storing the numeric padding values is useful for TOSA codegen, as opposed
  // to holding the padding regime mnemonic, i.e. SAME, VALID, FULL, ...

  SmallVector<int64_t> computed_paddings;

  int64_t pad_before, pad_after;
  for (int i = 0; i < 2; i++) {  // Two spatial dimensions X&Y
    int64_t ifm_dim = GetTensorSpatialDimIndex(
        4, data_format_tf, i);  // 4D tensor, NHWC/NCHW format
    int64_t ofm_dim = GetTensorSpatialDimIndex(
        4, data_format_tf, i);  // 4D tensor, NHWC/NCHW format
    int64_t filter_dim = first_filter_spatial_dim + i;

    int64_t ifm_size = input_type.getDimSize(ifm_dim);
    int64_t filter_size = filter_type.getDimSize(filter_dim);
    int64_t ofm_size = output_type.getDimSize(ofm_dim);
    int64_t dim_stride = strides[i].template cast<IntegerAttr>().getInt();

    // These dimensions need to be static to legalize.
    if (ShapedType::isDynamic(filter_size) || ShapedType::isDynamic(ifm_size) ||
        ShapedType::isDynamic(ofm_size)) {
      return false;
    }

    int total_padding = ((ifm_size - 1) * dim_stride + filter_size - ofm_size);
    total_padding = total_padding > 0 ? total_padding : 0;

    pad_before = total_padding / 2;
    pad_after = total_padding - pad_before;

    computed_paddings.push_back(pad_before);
    computed_paddings.push_back(pad_after);
  }

  explicit_padding = rewriter.getI64ArrayAttr(computed_paddings);
  return true;
}

// Templated function to create a constant op for a given type and shape.
// T: storage C type.
// Default template creates a constant tensor in T.
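//
// Illustrative usage (not from this code):
//   auto c = getConstTensor<int32_t>(rewriter, op, {1, 2, 3, 4}, {2, 2});
// materializes a tosa.const of type tensor<2x2xi32>, or returns llvm::None
// if the element count does not match the shape.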
template <typename T>
llvm::Optional<Value> getConstTensor(PatternRewriter& rewriter, Operation* op,
                                     ArrayRef<T> vec, ArrayRef<int64_t> shape) {
  int64_t num_total_elements = 1;
  for (int64_t a : shape) {
    num_total_elements *= a;
  }

  if (vec.size() != num_total_elements) {
    op->emitOpError("getConstTensor(): number of elements mismatch.");
    return llvm::None;
  }

  auto const_type =
      RankedTensorType::get(shape, rewriter.getIntegerType(sizeof(T) * 8));
  auto const_attr = DenseElementsAttr::get(const_type, vec);

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Template specialization for APInt
template <>
llvm::Optional<Value> getConstTensor<APInt>(PatternRewriter& rewriter,
                                            Operation* op, ArrayRef<APInt> vec,
                                            ArrayRef<int64_t> shape) {
  int64_t num_total_elements = 1;
  for (int64_t a : shape) {
    num_total_elements *= a;
  }

  if (vec.size() != num_total_elements) {
    op->emitOpError("getConstTensor(): number of elements mismatch.");
    return llvm::None;
  }

  auto const_type = RankedTensorType::get(
      shape, rewriter.getIntegerType(vec[0].getBitWidth()));
  auto const_attr = DenseElementsAttr::get(const_type, vec);

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Template specialization for float
template <>
llvm::Optional<Value> getConstTensor<float>(PatternRewriter& rewriter,
                                            Operation* op, ArrayRef<float> vec,
                                            ArrayRef<int64_t> shape) {
  int64_t num_total_elements = 1;
  for (int64_t a : shape) {
    num_total_elements *= a;
  }

  if (vec.size() != num_total_elements) {
    op->emitOpError("getConstTensor(): number of elements mismatch.");
    return llvm::None;
  }

  auto const_type = RankedTensorType::get(shape, rewriter.getF32Type());
  auto const_attr = DenseElementsAttr::get(const_type, vec);

  auto const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Explicit template instantiation
template llvm::Optional<Value> getConstTensor<int32_t>(PatternRewriter&,
                                                       Operation*,
                                                       ArrayRef<int32_t> vec,
                                                       ArrayRef<int64_t> shape);

// Check if scale32 mode is used for given output_element_type
bool isScale32(mlir::quant::UniformQuantizedType output_element_type) {
  return (output_element_type.getStorageTypeIntegralWidth() == 8);
}

LogicalResult ApplyPatternsWithShapeResolution(
    func::FuncOp func, const FrozenRewritePatternSet& patterns) {
  // We use top-down traversal so that shape inference can fully infer types
  // during pattern rewrite.
  GreedyRewriteConfig config;
  config.useTopDownTraversal = true;
  if (failed(applyPatternsAndFoldGreedily(func, patterns, config))) {
    return failure();
  }

  // Check that constant attribute types and op types match up. If the
  // lowering needs to change a type (e.g. fp16 -> fp32) it's possible the
  // return type could be incorrect.
  //
  // This should be investigated for whether it is still necessary due to
  // quant type stripping changing.
  func.walk([&](tosa::ConstOp op) {
    auto ety = op.getValue().getType().getElementType();
    auto new_ty = op.getType().cast<ShapedType>().clone(ety);
    op.getResult().setType(new_ty);
  });

  auto returnOp = cast<func::ReturnOp>(func.getBody().front().getTerminator());
  llvm::SmallVector<Type> result_tys(returnOp.getOperandTypes());

  func.setType(FunctionType::get(
      func.getContext(), func.getFunctionType().getInputs(), result_tys));

  return success();
}

}  // namespace tosa
}  // namespace mlir