1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file contains legalizations common to mapping both TensorFlow and
17 // TensorFlow Lite to TOSA. It operates generically on ops and does not have
18 // a hard reference on either dialect.
19 //
20 // Conversion functions return llvm::None on a legalization failure or a
21 // legalized value on success.  Callers must check for presence of an
22 // llvm::Optional value after each call.
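//
// A minimal caller sketch (illustrative only, with placeholder argument names;
// the actual callers live in the TF/TFLite legalization passes):
//
//   llvm::Optional<Value> result =
//       convertPackOp(rewriter, op, output_value, inputs, axis);
//   if (!result) return failure();
//   rewriter.replaceOp(op, {result.getValue()});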
23 
24 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
25 
26 #include <climits>
27 #include <cstddef>
28 #include <cstdint>
29 #include <iterator>
30 #include <numeric>
31 
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/Support/FormatVariadic.h"
34 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
35 #include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
36 #include "mlir/Dialect/Tosa/IR/TosaOps.h"  // from @llvm-project
37 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
38 #include "mlir/IR/Matchers.h"  // from @llvm-project
39 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
40 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
41 
42 namespace mlir {
43 namespace tosa {
44 
45 // Copied Nudge implementation from
46 // tensorflow/core/kernels/fake_quant_ops_functor.h.
47 // This copy is the suggested approach to avoid a significant TensorFlow
48 // build dependency.
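// Worked example (illustrative): for min = -1.0f, max = 1.0f, quant_min = 0,
// quant_max = 255 the scale is 2.0 / 255 ~= 0.00784; zero_point_from_min is
// 0 - (-1.0 / 0.00784) = 127.5, which rounds to a nudged zero point of 128,
// giving nudged_min ~= -1.0039 and nudged_max ~= 0.9961.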
49 void tensorflow_nudge(const float min, const float max, const int quant_min,
50                       const int quant_max, float* nudged_min, float* nudged_max,
51                       float* scale) {
52   const float quant_min_float = static_cast<float>(quant_min);
53   const float quant_max_float = static_cast<float>(quant_max);
54   *scale = (max - min) / (quant_max_float - quant_min_float);
55   const float zero_point_from_min = quant_min_float - min / *scale;
56   const uint16_t nudged_zero_point = [zero_point_from_min, quant_min,
57                                       quant_min_float, quant_max,
58                                       quant_max_float] {
59     if (zero_point_from_min < quant_min_float) {
60       return static_cast<uint16_t>(quant_min);
61     }
62     if (zero_point_from_min > quant_max_float) {
63       return static_cast<uint16_t>(quant_max);
64     }
65     return static_cast<uint16_t>(std::round(zero_point_from_min));
66   }();
67   *nudged_min = (quant_min_float - nudged_zero_point) * (*scale);
68   *nudged_max = (quant_max_float - nudged_zero_point) * (*scale);
69 }
70 
71 // Lowers the Pack operator to TOSA.
72 llvm::Optional<Value> convertPackOp(PatternRewriter& rewriter, Operation* op,
73                                     Value result_value,
74                                     SmallVectorImpl<Value>& inputs,
75                                     int32_t axis) {
76   //////////////////////////////////////////////////
77   // Operator: output = Pack([values], axis) or output = Stack([values], axis)
78   // Lowering:
79   //
80   // This operator is lowered into a series of pairwise tosa.concat()
81   // operators and a reshape.
82   // Depending on the inputs, a transpose operator is also generated:
83   //
84   // Step 1: concatenate the tensors
85   // a1_concat = tosa.concat(input[0], input[1], axis)
86   // for (i = 2; i < len(input); i++)
87   //   a1_concat = tosa.concat(a1_concat, input[i], axis)
88   //
89   // Step 2: reshape to N+1 dimensions
90   // a2_reshape = tosa.reshape(a1_concat, new_rank)
91   //
92   // Step 3: Transpose if a new dimension is being added:
93   //   if (axis == rank(values[0])):
94   //   // perm will be [1, 2, 3, 0]
95   //   a3_transpose = tosa.transpose(a2_reshape, perm)
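  //
  // Worked example (illustrative): packing four [2, 3] tensors along axis = 1
  // concatenates along axis 1 into [2, 12], then reshapes to [2, 4, 3]. For
  // axis = 2 (equal to the input rank), the inputs are concatenated along
  // axis 0 into [8, 3], reshaped to [4, 2, 3], and then transposed with
  // perm = [1, 2, 0] to the final [2, 3, 4].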
96 
97   // Sanity check 1: make sure all input tensors have the same shape
98   // if input[0] has shape [A, B, C], input[1] to input[N-1] should also have
99   // shape[A, B, C]
100   RankedTensorType result_type =
101       result_value.getType().dyn_cast<RankedTensorType>();
102 
103   // Check for ranked tensor type.
104   if (!result_type) {
105     op->emitOpError("PackOp: result type not ranked tensor");
106     return llvm::None;
107   }
108 
109   // Valid axis in TF is [-rank(input), rank(input))
110   // Valid axis in TOSA is [0, rank(input))
111   // Add rank(input) once if the axis is negative.
112   RankedTensorType input_type =
113       op->getOperand(0).getType().dyn_cast<RankedTensorType>();
114   if (!input_type) {
115     op->emitOpError("PackOp: input type not ranked tensor");
116     return llvm::None;
117   }
118 
119   input_type = inputs[0].getType().dyn_cast<RankedTensorType>();
120   if (!input_type) {
121     op->emitOpError("Input 0 type not ranked tensor.");
122     return llvm::None;
123   }
124   ArrayRef<int64_t> input0_tensor_shape = input_type.getShape();
125   int input_tensor_rank = input0_tensor_shape.size();
126 
127   for (int i = 1; i < inputs.size(); i++) {
128     input_type = inputs[i].getType().dyn_cast<RankedTensorType>();
129     if (!input_type) {
130       op->emitOpError(llvm::formatv(
131           "PackOp: input {} type not ranked tensor",
132           i));
133       return llvm::None;
134     }
135     ArrayRef<int64_t> next_tensor_shape = input_type.getShape();
136     if (next_tensor_shape.size() != input_tensor_rank) {
137       op->emitOpError("PackOp: input tensor rank mismatch.");
138       return llvm::None;
139     }
140     for (int d = 0; d < input0_tensor_shape.size(); d++) {
141       if (input0_tensor_shape[d] != next_tensor_shape[d]) {
142         op->emitOpError("PackOp: input tensor shape mismatch.");
143         return llvm::None;
144       }
145     }
146   }
147 
148   // If the input tensors are rank 0, reshape them to rank 1, size 1 before
149   // performing the concat.
150   if (input_tensor_rank == 0) {
151     SmallVector<int64_t, 1> reshape_rank1_size1_shape({1});
152     RankedTensorType reshape_rank1_size1_type = RankedTensorType::get(
153         reshape_rank1_size1_shape, result_type.getElementType());
154     ArrayAttr shape_rank1_size1_attr =
155         rewriter.getI64ArrayAttr(reshape_rank1_size1_shape);
156     for (int i = 0; i < inputs.size(); i++) {
157       auto a0_reshape_op = rewriter.create<tosa::ReshapeOp>(
158           op->getLoc(), reshape_rank1_size1_type, inputs[i],
159           shape_rank1_size1_attr);
160       inputs[i] = a0_reshape_op.getResult();
161     }
162   }
163 
164   // Sanity check 2: axis can be from [0, rank(input)+1]
165   // Where rank(input)+1 means create a new dimension
166   // Negative values are also allowed up to -(rank(input)+1)
167   // where the axis "wraps around".
168   if (axis < 0) axis += input_tensor_rank;
169   if ((axis < 0) || (axis > (input_tensor_rank + 1))) {
170     op->emitOpError("PackOp: axis out of valid range.");
171     return llvm::None;
172   }
173 
174   // Sanity check 3: if input shape is [A, B, C], output shape should have N
175   // inserted at the pack axis, e.g. [N, A, B, C] for axis = 0.
176   // 3.a check output rank is rank(input) + 1
177   SmallVector<int64_t> output_shape_vals(result_type.getShape().begin(),
178                                          result_type.getShape().end());
179   if (output_shape_vals.size() != (input_tensor_rank + 1)) {
180     op->emitOpError("PackOp: output tensor rank mismatch.");
181     return llvm::None;
182   }
183   // 3.b check the output dimension at the pack axis is N
184   if (output_shape_vals[axis] != inputs.size()) {
185     op->emitOpError("PackOp: output tensor shape mismatch.");
186     return llvm::None;
187   }
188   // In most cases, PackOp.axis() is within [0, rank(input) - 1]:
189   // we can directly concatenate along that axis and then reshape.
190   // For example, stacking N [A, B, C] input tensors along axis = 1:
191   // after concatenation, the output will be [A, N * B, C],
192   // which is then reshaped into [A, N, B, C].
193   // A special case is PackOp.axis() equal to rank(input), in which case
194   // we can't concatenate directly along PackOp.axis(); instead we
195   // concatenate along axis = 0, reshape into [N, A, B, C],
196   // and then add an extra transpose to [A, B, C, N].
197   int64_t concat_axis;
198   SmallVector<int32_t> perm;
199   SmallVector<int64_t> reshape_output_shape;
200   if (axis == 0 && input_tensor_rank == 0) {
201     concat_axis = 0;
202   } else if (axis == input_tensor_rank) {
203     concat_axis = 0;
204 
205     // A special case when stack axis is equal to input tensor rank:
206     // Output shape is [A, B, C, N]
207     // so reshape output will be [N, A, B, C]
208     // and perm will be [1, 2, 3, 0].
209     reshape_output_shape.push_back(output_shape_vals[axis]);
210     for (int d = 0; d < input_tensor_rank; d++) {
211       perm.push_back(d + 1);
212       reshape_output_shape.push_back(output_shape_vals[d]);
213     }
214     perm.push_back(0);
215   } else {
216     // General case, doesn't need perm vector.
217     concat_axis = axis;
218     reshape_output_shape.assign(output_shape_vals.begin(),
219                                 output_shape_vals.end());
220   }
221   IntegerAttr concat_axis_attr = rewriter.getI64IntegerAttr(concat_axis);
222   ArrayAttr shape_attr = rewriter.getI64ArrayAttr(reshape_output_shape);
223 
224   // Concat output shape will depend on concat_axis. E.g. [N * A, B, C]
225   SmallVector<int64_t> concat_output_shape;
226   if (input_tensor_rank == 0) {
227     concat_output_shape.push_back(1);
228   } else {
229     for (int i = 0; i < input_tensor_rank; i++) {
230       concat_output_shape.push_back(input0_tensor_shape[i]);
231     }
232   }
233 
234   concat_output_shape[concat_axis] =
235       concat_output_shape[concat_axis] * inputs.size();
236   RankedTensorType concat_type = RankedTensorType::get(
237       ArrayRef<int64_t>(concat_output_shape), result_type.getElementType());
238 
239   SmallVector<Value> inputs_0;
240   for (int i = 0; i < inputs.size(); i++) {
241     inputs_0.push_back(inputs[i]);
242   }
243   auto a1_concat_op = rewriter.create<tosa::ConcatOp>(
244       op->getLoc(), concat_type, inputs_0, concat_axis_attr);
245 
246   // Doesn't need reshape or transpose if input tensor is rank 0, since inputs
247   // are reshaped beforehand.
248   if (input_tensor_rank == 0) return a1_concat_op.getResult();
249 
250   // Reshape [N * A, B, C] to [N, A, B, C].
251   RankedTensorType reshape_output_type =
252       RankedTensorType::get(reshape_output_shape, result_type.getElementType());
253 
254   auto a2_reshape_op = rewriter.create<tosa::ReshapeOp>(
255       op->getLoc(), reshape_output_type, a1_concat_op.getResult(), shape_attr);
256 
257   // If axis is equal to input tensor rank, then we need extra transpose
258   // [N, A, B, C] to [A, B, C, N]
259   if (axis == input_tensor_rank) {
260     llvm::Optional<Value> a3_transpose_perm = getConstTensor<int32_t>(
261         rewriter, op, perm, {static_cast<int64_t>(perm.size())});
262 
263     if (!a3_transpose_perm) return llvm::None;
264 
265     return rewriter
266         .create<tosa::TransposeOp>(op->getLoc(), result_type,
267                                    a2_reshape_op.getResult(),
268                                    a3_transpose_perm.getValue())
269         .getResult();
270   }
271 
272   return a2_reshape_op.getResult();
273 }
274 
275 // Lowers the Unpack operator to TOSA
276 llvm::Optional<SmallVector<Value>> convertUnpackOp(PatternRewriter& rewriter,
277                                                    Operation* op,
278                                                    Value input_value,
279                                                    int32_t axis) {
280   RankedTensorType input_type =
281       input_value.getType().dyn_cast<RankedTensorType>();
282   if (!input_type) return llvm::None;
283 
284   auto input_shape = input_type.getShape();
285   int64_t input_rank = input_shape.size();
286 
287   SmallVector<Value> results_vec;
288 
289   // Negative axis allowed as long as it's within [-input_rank, input_rank).
290   if (axis < 0) axis += input_rank;
291   if ((axis < 0) || (axis >= input_rank)) {
292     op->emitOpError("UnpackOp: axis out of valid range.");
293     return llvm::None;
294   }
295 
296   // Step 1: transpose 'axis' to leftmost dimension.
297   Value transposed_input_value;
298   if (axis != 0) {
299     SmallVector<int32_t> perm;
300     SmallVector<int64_t> a1_transpose_shape(input_rank);
301 
302     perm.push_back(axis);
303     for (int i = 0; i < input_rank; i++) {
304       if (i == axis) continue;
305       perm.push_back(i);
306     }
307 
308     llvm::Optional<Value> a1_transpose_perm = getConstTensor<int32_t>(
309         rewriter, op, perm, {static_cast<int64_t>(perm.size())});
310 
311     if (!a1_transpose_perm) return llvm::None;
312 
313     for (int i = 0; i < input_rank; i++) {
314       a1_transpose_shape[i] = input_shape[perm[i]];
315     }
316 
317     auto a1_transpose_op = rewriter.create<tosa::TransposeOp>(
318         op->getLoc(),
319         RankedTensorType::get(a1_transpose_shape, input_type.getElementType()),
320         input_value, a1_transpose_perm.getValue());
321 
322     transposed_input_value = a1_transpose_op.getResult();
323   } else {
324     // Do nothing if axis is already at leftmost dimension.
325     transposed_input_value = input_value;
326   }
327 
328   // Step 2: slice [N, A, B, C] into N [A, B, C].
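  // For example (illustrative), a [3, 4, 5] input unpacked along axis = 0
  // yields three [1, 4, 5] slices, each reshaped to [4, 5].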
329   RankedTensorType transposed_input_type =
330       transposed_input_value.getType().dyn_cast<RankedTensorType>();
331   if (!transposed_input_type) return llvm::None;
332 
333   auto transposed_input_shape = transposed_input_type.getShape();
334   int64_t transposed_input_rank = transposed_input_shape.size();
335 
336   for (int i = 0; i < transposed_input_shape[0]; i++) {
337     SmallVector<int64_t> begin_vals, size_vals, shape_vals;
338 
339     for (int j = 0; j < transposed_input_rank; j++) {
340       if (j == 0) {
341         begin_vals.push_back(i);
342         size_vals.push_back(1);
343       } else {
344         begin_vals.push_back(0);
345         size_vals.push_back(transposed_input_shape[j]);
346         shape_vals.push_back(transposed_input_shape[j]);
347       }
348     }
349 
350     ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
351     ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);
352 
353     auto a2_slice_op = rewriter.create<tosa::SliceOp>(
354         op->getLoc(),
355         RankedTensorType::get(size_vals,
356                               transposed_input_type.getElementType()),
357         transposed_input_value, begin, size);
358 
359     auto a3_reshape_op = rewriter.create<tosa::ReshapeOp>(
360         op->getLoc(),
361         RankedTensorType::get(shape_vals,
362                               transposed_input_type.getElementType()),
363         a2_slice_op.getResult(), rewriter.getI64ArrayAttr(shape_vals));
364 
365     results_vec.push_back(a3_reshape_op.getResult());
366   }
367 
368   return results_vec;
369 }
370 
371 // Lowers the Select operator to TOSA.
372 llvm::Optional<Value> convertSelectOp(PatternRewriter& rewriter, Operation* op,
373                                       Value result_value, Value condition_value,
374                                       Value x_value, Value y_value) {
375   RankedTensorType result_type =
376       result_value.getType().dyn_cast<RankedTensorType>();
377   RankedTensorType condition_type =
378       condition_value.getType().dyn_cast<RankedTensorType>();
379   RankedTensorType x_type = x_value.getType().dyn_cast<RankedTensorType>();
380   RankedTensorType y_type = y_value.getType().dyn_cast<RankedTensorType>();
381 
382   if (!result_type || !condition_type || !x_type || !y_type) {
383     op->emitOpError("Select: failed ranked tensor type check");
384     return llvm::None;
385   }
386 
387   // First check whether we need to reshape the condition to match
388   // the same rank as the then/else clauses.
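  // For example (illustrative), a [C]-shaped condition used with [A, B, C]
  // operands is reshaped to [1, 1, C] so that it broadcasts element-wise.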
389   if (result_type.getRank() == condition_type.getRank()) {
390     // Nothing to reshape.
391     return rewriter
392         .create<tosa::SelectOp>(op->getLoc(), result_type, condition_value,
393                                 x_value, y_value)
394         .getResult();
395   }
396 
397   // Need to reshape the condition.
398   SmallVector<int64_t> new_cond_dims(
399       result_type.getRank() - condition_type.getRank(), 1);
400 
401   for (int i = 0; i < condition_type.getRank(); i++) {
402     new_cond_dims.push_back(condition_type.getShape()[i]);
403   }
404 
405   auto reshape_op = rewriter.create<tosa::ReshapeOp>(
406       op->getLoc(),
407       RankedTensorType::get(new_cond_dims, condition_type.getElementType()),
408       condition_value, rewriter.getI64ArrayAttr(new_cond_dims));
409 
410   return rewriter
411       .create<tosa::SelectOp>(op->getLoc(), result_type, reshape_op, x_value,
412                               y_value)
413       .getResult();
414 }
415 
416 // Lowers the ZerosLike operator to TOSA by creating a constant
417 // of the desired type and shape.
418 llvm::Optional<Value> convertZerosLikeOp(PatternRewriter& rewriter,
419                                          Operation* op, Value result,
420                                          Value input) {
421   RankedTensorType result_type = result.getType().dyn_cast<RankedTensorType>();
422   if (!result_type) {
423     op->emitOpError("Zeroslike: result not ranked tensor type");
424     return llvm::None;
425   }
426 
427   RankedTensorType input_type = input.getType().dyn_cast<RankedTensorType>();
428   if (!input_type) {
429     op->emitOpError("Zeroslike: input not ranked tensor type");
430     return llvm::None;
431   }
432 
433   auto input_shape = input_type.getShape();
434 
435   ShapedType zero_type =
436       RankedTensorType::get(input_shape, input_type.getElementType());
437   Attribute zero_attr = rewriter.getZeroAttr(zero_type);
438 
439   return rewriter
440       .create<tosa::ConstOp>(op->getLoc(), zero_type,
441                              zero_attr.cast<ElementsAttr>())
442       .getResult();
443 }
444 
445 // Lowers the Mul operator to TOSA.  For quantized types, this requires
446 // inserting rescale operators before and after the operation.
447 llvm::Optional<Value> convertMultiplyOp(PatternRewriter& rewriter,
448                                         Operation* op, Value output_val,
449                                         Value input_lhs_val,
450                                         Value input_rhs_val) {
451   ShapedType input_lhs_type = input_lhs_val.getType().dyn_cast<ShapedType>();
452   ShapedType input_rhs_type = input_rhs_val.getType().dyn_cast<ShapedType>();
453   ShapedType output_type = output_val.getType().dyn_cast<ShapedType>();
454   // Not a shaped tensor output
455   if (!input_lhs_type || !input_rhs_type || !output_type) return llvm::None;
456 
457   bool input_lhs_is_qtype =
458       input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
459   bool input_rhs_is_qtype =
460       input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
461   bool output_is_qtype =
462       output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
463 
464   if (input_lhs_is_qtype != output_is_qtype ||
465       input_rhs_is_qtype != output_is_qtype) {
466     op->emitOpError(
467         "ConvertMultiplyOp: input/output tensor should "
468         "be all quantized or all floating-point");
469     return llvm::None;
470   }
471 
472   Value output;
473   if (output_is_qtype) {
474     ShapedType rescale_type = output_type.clone(rewriter.getI32Type());
475     auto input_lhs_qtype = input_lhs_type.getElementType()
476                                .cast<mlir::quant::UniformQuantizedType>();
477     auto input_rhs_qtype = input_rhs_type.getElementType()
478                                .cast<mlir::quant::UniformQuantizedType>();
479     auto output_qtype =
480         output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
481 
482     // MLIR stores the scale as a double, but TFLite stores it as a float.
483     // Downcast from double to float to match TFLite behavior.
484     float in_lhs_scale = input_lhs_qtype.getScale();
485     float in_rhs_scale = input_rhs_qtype.getScale();
486     float output_scale = output_qtype.getScale();
487 
488     double output_rescale_scale = in_lhs_scale * in_rhs_scale / output_scale;
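    // For example (illustrative), with an lhs scale of 0.5, an rhs scale of
    // 0.25 and an output scale of 0.125, the int32 product is rescaled by
    // 0.5 * 0.25 / 0.125 = 1.0 back into the output's quantized domain.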
489 
490     // 16 bits x 16 bits -> 32 bits.
491     // 32 bits can be rescaled back to 16 bits with a 32-bit quantized multiplier.
492     bool scale32 = true;
493 
494     Value op1_rescale_lhs = buildRescaleToInt32(
495         rewriter, op, input_lhs_val, 1.0f, input_lhs_qtype.getZeroPoint());
496     Value op2_rescale_rhs = buildRescaleToInt32(
497         rewriter, op, input_rhs_val, 1.0f, input_rhs_qtype.getZeroPoint());
498     auto op3_mul_op1_op2 = rewriter.create<tosa::MulOp>(
499         op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs, 0);
500     return buildRescale(rewriter, op, output_type, op3_mul_op1_op2.getResult(),
501                         output_rescale_scale, 0, output_qtype.getZeroPoint(),
502                         true, scale32);
503   }
504 
505   return rewriter
506       .create<tosa::MulOp>(op->getLoc(), output_type, input_lhs_val,
507                            input_rhs_val, 0)
508       .getResult();
509 }
510 
511 // Lowers the SquaredDifference operator to TOSA.
512 llvm::Optional<Value> convertSquaredDifferenceOp(PatternRewriter& rewriter,
513                                                  Operation* op, Value result,
514                                                  Value x, Value y) {
515   // Squared-difference is (x-y)*(x-y).
516   // This lowering calculates the difference and multiplies.
517   ShapedType result_type = result.getType().dyn_cast<ShapedType>();
518   if (!result_type) {
519     op->emitOpError("SquaredDifference: result not ranked tensor type");
520     return llvm::None;
521   }
522 
523   ShapedType x_type = x.getType().dyn_cast<ShapedType>();
524   ShapedType y_type = y.getType().dyn_cast<ShapedType>();
525   if (!x_type || !y_type) {
526     op->emitOpError("SquaredDifference: inputs not ranked tensor type");
527     return llvm::None;
528   }
529 
530   auto sub_op = rewriter.create<tosa::SubOp>(op->getLoc(), result_type, x, y);
531   return rewriter
532       .create<tosa::MulOp>(op->getLoc(), result_type, sub_op.getResult(),
533                            sub_op.getResult(), 0)
534       .getResult();
535 }
536 
537 // Lowers the Round operator to TOSA.
538 llvm::Optional<Value> convertRoundOp(PatternRewriter& rewriter, Operation* op,
539                                      Value result, Value input) {
540   // Rounds to nearest by calculating floor(input + 0.5) (ties round up).
541   ShapedType result_type = result.getType().dyn_cast<ShapedType>();
542   if (!result_type) {
543     op->emitOpError("Round: result not shaped tensor type");
544     return llvm::None;
545   }
546 
547   ShapedType input_type = input.getType().dyn_cast<ShapedType>();
548   if (!input_type) {
549     op->emitOpError("Round: input not shaped tensor type");
550     return llvm::None;
551   }
552 
553   auto add_op = rewriter.create<tosa::AddOp>(
554       op->getLoc(), result_type, input,
555       getTosaConstTensorSingleF32(rewriter, op, 0.5));
556 
557   return rewriter
558       .create<tosa::FloorOp>(op->getLoc(), result_type, add_op.getResult())
559       .getResult();
560 }
561 
562 // Lowers ConcatV2 to TOSA Concat.
563 llvm::Optional<Value> convertConcatV2Op(PatternRewriter& rewriter,
564                                         Operation* op, Value result_value,
565                                         SmallVectorImpl<Value>& values,
566                                         int32_t axis) {
567   // Check all inputs are RankedTensorType
568   for (auto v : values) {
569     if (!v.getType().dyn_cast<RankedTensorType>()) {
570       op->emitOpError("ConcatV2Op: value type not ranked tensor.");
571       return llvm::None;
572     }
573   }
574 
575   // Check output is Ranked tensor type
576   if (!result_value.getType().dyn_cast<RankedTensorType>()) {
577     op->emitOpError("ConcatV2Op: output value type not ranked tensor.");
578     return llvm::None;
579   }
580 
581   RankedTensorType result_type =
582       result_value.getType().dyn_cast<RankedTensorType>();
583   mlir::quant::UniformQuantizedType result_quant_type =
584       result_type.getElementType()
585           .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
586 
587   SmallVector<Value> values_rescaled;
588 
589   for (auto v : values) {
590     RankedTensorType operand_type = v.getType().dyn_cast<RankedTensorType>();
591     mlir::quant::UniformQuantizedType operand_quant_type =
592         operand_type.getElementType()
593             .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
594 
595     // tfl.concat currently allows different scales for each input tensor,
596     // which the TFLite team will fix in:
597     // https://github.com/tensorflow/tensorflow/issues/39658
598     // For backward compatibility, we still need to support this artifact by
599     // scaling inputs to let them have the same scales.
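    // For example (illustrative), an operand with scale 0.02 feeding an output
    // with scale 0.04 is rescaled by 0.02 / 0.04 = 0.5 before the concat.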
600     if (result_quant_type && operand_quant_type) {
601       double operand_scale = static_cast<double>(operand_quant_type.getScale());
602       int32_t operand_zeropoint = operand_quant_type.getZeroPoint();
603 
604       double result_scale = static_cast<double>(result_quant_type.getScale());
605       int32_t result_zeropoint = result_quant_type.getZeroPoint();
606 
607       // Rescale input if scale is not equal to output tensor scale.
608       if (operand_scale != result_scale) {
609         RankedTensorType rescale_type =
610             RankedTensorType::get(operand_type.getShape(), result_quant_type);
611         Value rescale_op = buildRescale(
612             rewriter, op, rescale_type, v, operand_scale / result_scale,
613             operand_zeropoint, result_zeropoint, false, true);
614         values_rescaled.push_back(rescale_op);
615       } else {
616         values_rescaled.push_back(v);
617       }
618     } else {
619       values_rescaled.push_back(v);
620     }
621   }
622 
623   int32_t tensor_rank = result_type.getShape().size();
624 
625   if (axis < 0) axis += tensor_rank;
626   if ((axis < 0) || (axis >= tensor_rank)) {
627     op->emitOpError("ConcatV2Op: axis out of valid range.");
628     return llvm::None;
629   }
630 
631   auto concat_op = rewriter.create<tosa::ConcatOp>(
632       op->getLoc(), result_value.getType(), values_rescaled,
633       rewriter.getI64IntegerAttr(axis));
634 
635   return concat_op.getResult();
636 }
637 
638 // Lowers SpaceToBatchND to TOSA.
639 llvm::Optional<Value> convertSpaceToBatchNDOp(PatternRewriter& rewriter,
640                                               Operation* op, Value result_value,
641                                               Value input_value,
642                                               Value block_shape_value,
643                                               Value paddings_value) {
644   /////////////////////////////////////////////////
645   // Operator: output = SpaceToBatchND(input, block_shape, paddings)
646   // Lowering:
647   //
648   // SpaceToBatch input tensors are broken into three pieces:
649   //   (a) batch dimension (N in NHWC)
650   //   (b) input being transformed to batch dimension (typically H, W in NHWC)
651   //   (c) remainder of input (typically C in NHWC)
652   //
653   // Step 0. Generate padding constant for the first reshape.
654   //   No padding on the batch dimension
655   //   The input paddings array is addressed as [input_rank][2]
656   //   No padding on the remaining dimensions
657   //
658   //  a0_pad_const = tosa.const(input=Tensor<input_rank, 2>)
659   //
660   // Step 1. Pad the input tensor
661   //
662   //  a1_pad_input_op = tosa.pad(input=input, shape=a0_pad_const_op)
663   //
664   // Step 2. Reshape the padded structure of shape padded_shape to
665   // [batch] + [padded_shape[1] / block_shape[0], block_shape[0], ...
666   //    padded_shape[M] / block_shape[M-1], block_shape[M-1]] +
667   //    remaining_shape
668   //
669   // block_rank = M (number of elements in block_shape)
670   // New rank: input_rank + block_rank
671   //
672   //  a2_reshape_a1_op = tosa.reshape(input=a1_pad_input_op, shape=a2_shape)
673   //
674   // Step 3. Transpose dimensions to:
675   //  block-shape +
676   //  [batch] +
677   //  [padded_shape[1] / block_shape[0],
678   // ...
679   //  [padded_shape[M] / block_shape[M-1]] +
680   //  remaining_shape
681   //
682   // a3_transpose_a2_op = tosa.transpose(input=a2_reshape_a1_op,
683   // perms=a3_perm)
684   //
685   // Step 4. Reshape the transposed tensor to flatten block_shape stuff
686   // into the batch dimension with the following shape:
687   // [ batch * prod(block_shape)] +
688   // [ padded_shape[1] / block_shape[0],
689   //   ...,
690   // padded_shape[M] / block_shape[M-1]] +
691   // remaining_shape
692   //
693   //  a4_reshape_a3_op = tosa.reshape(input=a3_transpose_a2_op,
694   //  shape=a3_shape)
695   //
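  // Worked example (illustrative): for an NHWC input of shape [1, 4, 4, 1],
  // block_shape = [2, 2] and paddings = [[0, 0], [0, 0]]:
  //   Step 1 pads to [1, 4, 4, 1] (a no-op here),
  //   Step 2 reshapes to [1, 2, 2, 2, 2, 1],
  //   Step 3 transposes with perm = [2, 4, 0, 1, 3, 5] to [2, 2, 1, 2, 2, 1],
  //   Step 4 reshapes to the final [4, 2, 2, 1].
  //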
696 
697   RankedTensorType result_type =
698       result_value.getType().dyn_cast<RankedTensorType>();
699   RankedTensorType input_type =
700       input_value.getType().dyn_cast<RankedTensorType>();
701   RankedTensorType block_shape_type =
702       block_shape_value.getType().dyn_cast<RankedTensorType>();
703   RankedTensorType paddings_type =
704       paddings_value.getType().dyn_cast<RankedTensorType>();
705 
706   // Not a ranked tensor output.
707   if (!result_type) {
708     op->emitOpError("SpaceToBatchND: result type not ranked tensor");
709     return llvm::None;
710   }
711   if (!input_type) {
712     op->emitOpError("SpaceToBatchND: input type not ranked tensor");
713     return llvm::None;
714   }
715   if (!block_shape_type) {
716     op->emitOpError("SpaceToBatchND: block shape type not ranked tensor");
717     return llvm::None;
718   }
719   if (!paddings_type) {
720     op->emitOpError("SpaceToBatchND: paddings type not ranked tensor");
721     return llvm::None;
722   }
723 
724   // Follow implementation in
725   // tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
726 
727   // So, to figure out the spatial_shape, remove the batch dimension and
728   // then use the next block_rank dimensions.  The remaining dimensions are
729   // remaining_shape.
730 
731   auto block_shape = block_shape_type.getShape();
732   auto input_shape = input_type.getShape();
733 
734   int block_rank = block_shape[0];
735   int batch_size = input_shape[0];
736   int input_rank = input_type.getRank();
737   int remaining_shape_rank = input_rank - block_rank - 1;
738   int block_num_elems = 1;
739   int padding_sum = 0;
740 
741   ElementsAttr block_shape_elems;
742   ElementsAttr paddings_elems;
743 
744   if (!matchPattern(block_shape_value, m_Constant(&block_shape_elems)))
745     return llvm::None;
746 
747   if (!matchPattern(paddings_value, m_Constant(&paddings_elems)))
748     return llvm::None;
749 
750   SmallVector<int32_t> a0_pad_const(2 * (input_rank));
751   SmallVector<int64_t> padded_shape(input_rank);
752 
753   // 1. Pad based on paddings operand.  No padding on the batch dimension.
754   // The a0_pad_const array is addressed as [input_rank][2], but
755   // it is flattened to a 1D array because LLVM appears to only accept 1D.
756   //
757   // padded_shape[] is the shape of the padded output of step a1.
758   // The name is retained for consistency with the TF reference code.
759   padded_shape[0] = input_shape[0];
760 
761   // Batch dimension padding
762   a0_pad_const[0] = 0;
763   a0_pad_const[1] = 0;
764 
765   // This iterator seems to be the only reliable way to get
766   // int values out of a multi-dimensional ElementsAttr.
767   int idx = 0;
768 
769   for (auto i : paddings_elems.getValues<IntegerAttr>()) {
770     a0_pad_const[idx + 2] = i.getInt();
771     padding_sum += i.getInt();
772     idx++;
773   }
774 
775   // Insert padding on the spatial shape dimensions
776   for (int i = 0; i < block_rank; i++) {
777     int32_t lo_pad = a0_pad_const[2 * (i + 1) + 0];
778     int32_t hi_pad = a0_pad_const[2 * (i + 1) + 1];
779 
780     padded_shape[i + 1] = input_shape[i + 1] + lo_pad + hi_pad;
781   }
782 
783   // No padding on the remaining_shape dimensions
784   for (int i = 0; i < remaining_shape_rank; i++) {
785     a0_pad_const[2 * (i + block_rank + 1) + 0] = 0;
786     a0_pad_const[2 * (i + block_rank + 1) + 1] = 0;
787     padded_shape[i + block_rank + 1] = input_shape[i + block_rank + 1];
788   }
789 
790   RankedTensorType a0_pad_const_attr_type =
791       RankedTensorType::get({(input_rank), 2}, rewriter.getIntegerType(32));
792 
793   // Create a const op to generate the tensor type for the input padding array
794   auto a0_pad_const_op = rewriter.create<tosa::ConstOp>(
795       op->getLoc(), a0_pad_const_attr_type,
796       DenseElementsAttr::get(a0_pad_const_attr_type,
797                              llvm::makeArrayRef(a0_pad_const)));
798 
799   auto a1_pad_input_op = rewriter.create<tosa::PadOp>(
800       op->getLoc(),
801       RankedTensorType::get(padded_shape, result_type.getElementType()),
802       input_value, a0_pad_const_op.getResult());
803 
804   // 2. Reshape the padded structure of shape padded_shape to
805   // [batch] + [padded_shape[1] / block_shape[0], block_shape[0], ...
806   //    padded_shape[M] / block_shape[M-1], block_shape[M-1]] +
807   //    remaining_shape
808 
809   // block_rank = M (number of elements in block_shape)
810   // New rank: input_rank + block_rank
811   SmallVector<int64_t> a2_shape(1 + block_rank * 2 + remaining_shape_rank);
812 
813   // First dimension is batch.
814   a2_shape[0] = input_type.getShape()[0];
815   for (int i = 0; i < block_rank; i++) {
816     int32_t block_shape_val =
817         rewriter
818             .getI32IntegerAttr(
819                 block_shape_elems.getValue<IntegerAttr>(i).getInt())
820             .getInt();
821     a2_shape[1 + i * 2 + 0] = padded_shape[1 + i] / block_shape_val;
822     a2_shape[1 + i * 2 + 1] = block_shape_val;
823     block_num_elems *= block_shape_val;
824   }
825 
826   // Copy in the remaining block shape.
827   for (int i = 0; i < remaining_shape_rank; i++) {
828     a2_shape[1 + block_rank * 2 + i] = input_shape[1 + block_rank + i];
829   }
830 
831   auto a2_reshape_a1_op = rewriter.create<tosa::ReshapeOp>(
832       op->getLoc(),
833       RankedTensorType::get(a2_shape, result_type.getElementType()),
834       a1_pad_input_op.getResult(), rewriter.getI64ArrayAttr(a2_shape));
835 
836   // 3. Transpose dimensions to:
837   //  block-shape +
838   //  [batch] +
839   //  [padded_shape[1] / block_shape[0],
840   // ...
841   //  [padded_shape[M] / block_shape[M-1]] +
842   //  remaining_shape
843   int32_t a2_reshape_a1_rank =
844       a2_reshape_a1_op.getResult().getType().cast<RankedTensorType>().getRank();
845   SmallVector<int32_t> a3_perm(a2_reshape_a1_rank);
846   SmallVector<int64_t> a3_transpose_shape(a2_reshape_a1_rank);
847 
848   for (int i = 0; i < block_rank; i++) {
849     a3_perm[i] = 1 + 2 * i + 1;
850     a3_perm[block_rank + 1 + i] = 1 + 2 * i;
851   }
852   a3_perm[block_rank] = 0;
853   for (int i = 1 + block_rank * 2; i < a2_reshape_a1_rank; i++) {
854     a3_perm[i] = i;
855   }
856 
857   for (int i = 0; i < a3_transpose_shape.size(); i++) {
858     a3_transpose_shape[i] = a2_shape[a3_perm[i]];
859   }
860 
861   llvm::Optional<Value> a3_transpose_const = getConstTensor<int32_t>(
862       rewriter, op, a3_perm, {static_cast<int64_t>(a3_perm.size())});
863 
864   if (!a3_transpose_const) return llvm::None;
865 
866   auto a3_transpose_a2_op = rewriter.create<tosa::TransposeOp>(
867       op->getLoc(),
868       RankedTensorType::get(a3_transpose_shape, result_type.getElementType()),
869       a2_reshape_a1_op.getResult(), a3_transpose_const.getValue());
870 
871   // 4. Reshape the transposed tensor to flatten block_shape
872   // into the batch dimension with the following shape:
873   // [ batch * prod(block_shape)] +
874   // [ padded_shape[1] / block_shape[0],
875   //   ...,
876   // padded_shape[M] / block_shape[M-1]] +
877   // remaining_shape
878   SmallVector<int64_t> a4_reshape_shape(input_rank);
879 
880   // Batch
881   a4_reshape_shape[0] = batch_size * block_num_elems;
882 
883   // padded shape / block_shape.
884   for (int i = 0; i < block_rank; i++) {
885     int32_t block_shape_val =
886         rewriter
887             .getI32IntegerAttr(
888                 block_shape_elems.getValue<IntegerAttr>(i).getInt())
889             .getInt();
890     a4_reshape_shape[i + 1] = padded_shape[i + 1] / block_shape_val;
891   }
892 
893   // Copy in remainder shape.
894   for (int i = 0; i < remaining_shape_rank; i++) {
895     a4_reshape_shape[1 + block_rank + i] = input_shape[1 + block_rank + i];
896   }
897 
898   return rewriter
899       .create<tosa::ReshapeOp>(op->getLoc(), result_type,
900                                a3_transpose_a2_op.getResult(),
901                                rewriter.getI64ArrayAttr(a4_reshape_shape))
902       .getResult();
903 }
904 
905 // Lowers BatchToSpaceND to TOSA.
906 llvm::Optional<Value> convertBatchToSpaceNDOp(PatternRewriter& rewriter,
907                                               Operation* op, Value result_value,
908                                               Value input_value,
909                                               Value block_shape_value,
910                                               Value crops_value) {
911   /////////////////////////////////////////////////
912   // Operator: output = BatchToSpaceND(input, block_shape, crops)
913   // Lowering:
914   //
915   // BatchToSpace input tensors are broken into three pieces:
916   //   (a) batch dimension (N in NHWC)
917   //   (b) input being transformed from batch dimension (typically H, W in
918   //   NHWC)
919   //   (c) remainder of input (typically C in NHWC)
920   //
921   // Step 1. Reshape input to:
922   // [ block_shape[0] ],
923   // ...
924   // [ block_shape[M-1] ],
925   // [ batch / prod(block_shape) ],
926   // [ input_shape[1] ],
927   // ...
928   // [ input_shape[N-1] ]
929   //
930   // a1_reshape_input_op = tosa.reshape(input=input, shape=a1_shape)
931   //
932   // Step 2. Permute to shape
933   // [ batch / prod(block_shape) ],
934   // [ input_shape[1] ], [ block_shape[0] ],
935   //  ...
936   // [ input_shape[M] ], [ block_shape[M-1] ],
937   // + remaining_input_shapes input_shape[M+1 .. N-1]
938   //
939   // a2_transpose_a1 = tosa.transpose(input=a1_reshape_input_op,
940   // shape=a2_shape)
941   //
942   // Step 3. Reshape to:
943   // [ batch / prod(block_shape) ],
944   // [ input_shape[1] * block_shape[0] ],
945   //    ..
946   // [ input_shape[M] * block_shape[M-1] ],
947   // + remaining input shapes [ input_shape[M+1 .. N-1] ]
948   //
949   // a3_reshape_a2 = tosa.reshape(input=a2_transpose_a1, shape=a3_shape)
950   //
951   // Step 4. Crop the start/end dimensions according to crops of the
952   // a3_reshape_a2 shape
953   //
954   // a4_slice_a3 = tosa.slice(input=a3_reshape_a2, start=a4_start,
955   // size=a4_size)
956 
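  //
  // Worked example (illustrative): for an NHWC input of shape [4, 2, 2, 1],
  // block_shape = [2, 2] and crops = [[0, 0], [0, 0]]:
  //   Step 1 reshapes to [2, 2, 1, 2, 2, 1],
  //   Step 2 transposes with perm = [2, 3, 0, 4, 1, 5] to [1, 2, 2, 2, 2, 1],
  //   Step 3 reshapes to [1, 4, 4, 1],
  //   Step 4 slices with zero crops, leaving the final [1, 4, 4, 1].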
957   RankedTensorType result_type =
958       result_value.getType().dyn_cast<RankedTensorType>();
959   RankedTensorType input_type =
960       input_value.getType().dyn_cast<RankedTensorType>();
961   RankedTensorType block_shape_type =
962       block_shape_value.getType().dyn_cast<RankedTensorType>();
963   RankedTensorType crops_type =
964       crops_value.getType().dyn_cast<RankedTensorType>();
965 
966   if (!result_type) {
967     op->emitOpError("BatchToSpaceND: result type not ranked tensor");
968     return llvm::None;
969   }
970   if (!input_type) {
971     op->emitOpError("BatchToSpaceND: input type not ranked tensor");
972     return llvm::None;
973   }
974   if (!block_shape_type) {
975     op->emitOpError("BatchToSpaceND: block shape type not ranked tensor");
976     return llvm::None;
977   }
978   if (!crops_type) {
979     op->emitOpError("BatchToSpaceND: crops type not ranked tensor");
980     return llvm::None;
981   }
982 
983   // Another 4-step process
984   int block_rank = block_shape_type.getShape()[0];
985   int input_rank = input_type.getRank();
986   int crops_dims = crops_type.getShape()[0];
987   int remaining_shape_rank = input_rank - block_rank - 1;
988   auto input_shape = input_type.getShape();
989 
990   ElementsAttr block_shape_elems;
991   ElementsAttr crops_elems;
992 
993   if (!matchPattern(block_shape_value, m_Constant(&block_shape_elems))) {
994     op->emitOpError("BatchToSpaceND: block_shape not a constant");
995     return llvm::None;
996   }
997 
998   if (!matchPattern(crops_value, m_Constant(&crops_elems))) {
999     op->emitOpError("BatchToSpaceND: crops not a constant");
1000     return llvm::None;
1001   }
1002 
1003   SmallVector<int64_t> block_shape(block_rank);
1004   SmallVector<std::pair<int64_t, int64_t>> crops(crops_dims);
1005 
1006   // Extract values for block_shape and crops now.
1007   int block_num_elems = 1;
1008   for (int i = 0; i < block_rank; i++) {
1009     int block_shape_val =
1010         rewriter
1011             .getI32IntegerAttr(
1012                 block_shape_elems.getValue<IntegerAttr>(i).getInt())
1013             .getInt();
1014     block_num_elems *= block_shape_val;
1015     block_shape[i] = block_shape_val;
1016   }
1017 
1018   // This iterator seems to be the only reliable way to get
1019   // int values out of a multi-dimensional ElementsAttr
1020   SmallVector<int32_t> crops_const(2 * (crops_dims));
1021   int idx = 0;
1022   for (auto i : crops_elems.getValues<IntegerAttr>()) {
1023     crops_const[idx++] = i.getInt();
1024   }
1025 
1026   for (int i = 0; i < crops_dims; i++) {
1027     int crops_lo = crops_const[i * 2 + 0];
1028     int crops_hi = crops_const[i * 2 + 1];
1029     crops[i] = std::make_pair(crops_lo, crops_hi);
1030   }
1031 
1032   // Step 1. Reshape input to:
1033   // [ block_shape[0] ],
1034   // ...
1035   // [ block_shape[M-1] ],
1036   // [ batch / prod(block_shape) ],
1037   // [ input_shape[1] ],
1038   // ...
1039   // [ input_shape[N-1] ]
1040   SmallVector<int64_t> a1_shape(block_rank + input_rank);
1041 
1042   for (int i = 0; i < block_rank; i++) a1_shape[i] = block_shape[i];
1043 
1044   a1_shape[block_rank] = input_shape[0] / block_num_elems;
1045 
1046   for (int i = 0; i < input_rank - 1; i++)
1047     a1_shape[i + block_rank + 1] = input_shape[i + 1];
1048 
1049   auto a1_reshape_input_op = rewriter.create<tosa::ReshapeOp>(
1050       op->getLoc(),
1051       RankedTensorType::get(a1_shape, result_type.getElementType()),
1052       input_value, rewriter.getI64ArrayAttr(a1_shape));
1053 
1054   // 2. Permute to shape
1055   // [ batch / prod(block_shape) ],
1056   // [ input_shape[1] ], [ block_shape[0] ]
1057   //  ...
1058   // [ input_shape[M] ], [ block_shape[M-1] ]
1059   // + remaining_input_shapes input_shape[M+1 .. N-1]
1060 
1061   // 2a. calculate the permutation
1062   SmallVector<int32_t> a2_perm(block_rank + input_rank);
1063   SmallVector<int64_t> a2_transpose_shape(block_rank + input_rank);
1064 
1065   a2_perm[0] = block_rank;
1066   for (int i = 0; i < block_rank; i++) {
1067     a2_perm[1 + i * 2 + 0] = block_rank + 1 + i;
1068     a2_perm[1 + i * 2 + 1] = i;
1069   }
1070 
1071   for (int i = 0; i < remaining_shape_rank; i++) {
1072     a2_perm[1 + 2 * block_rank + i] = 1 + 2 * block_rank + i;
1073   }
1074 
1075   // 2b. calculate the a2_permuted shape
1076   for (int i = 0; i < (block_rank + input_rank); i++) {
1077     a2_transpose_shape[i] = a1_shape[a2_perm[i]];
1078   }
1079 
1080   llvm::Optional<Value> a2_transpose_perm = getConstTensor<int32_t>(
1081       rewriter, op, a2_perm, {static_cast<int64_t>(a2_perm.size())});
1082 
1083   if (!a2_transpose_perm) return llvm::None;
1084 
1085   auto a2_transpose_a1_op = rewriter.create<tosa::TransposeOp>(
1086       op->getLoc(),
1087       RankedTensorType::get(a2_transpose_shape, result_type.getElementType()),
1088       a1_reshape_input_op.getResult(), a2_transpose_perm.getValue());
1089 
1090   // Step 3. Reshape to:
1091   // [ batch / prod(block_shape) ],
1092   // [ input_shape[1] * block_shape[0] ],
1093   //    ..
1094   // [ input_shape[M] * block_shape[M-1] ],
1095   // + remaining input shapes [ input_shape[M+1 .. N-1] ]
1096   SmallVector<int64_t> a4_shape(input_rank);
1097 
1098   a4_shape[0] = input_shape[0] / block_num_elems;
1099   for (int i = 0; i < block_rank; i++) {
1100     a4_shape[1 + i] = input_shape[i + 1] * block_shape[i];
1101   }
1102   for (int i = 0; i < remaining_shape_rank; i++) {
1103     a4_shape[1 + block_rank + i] = input_shape[block_rank + 1 + i];
1104   }
1105 
1106   auto a3_reshape_a2 = rewriter.create<tosa::ReshapeOp>(
1107       op->getLoc(),
1108       RankedTensorType::get(a4_shape, result_type.getElementType()),
1109       a2_transpose_a1_op.getResult(), rewriter.getI64ArrayAttr(a4_shape));
1110 
1111   // 4. Crop the start/end dimensions on 'spatial dimension' according to
1112   // crops
1113   // Use a slice operator to do the cropping.
1114   //
1115   // Calculate a beginning point and a size:
1116   // - Begin is the origin, offset by the lo crop amount in each dimension
1117   // - Size is the reshaped tensor size, minus the quantity (lo + hi) for each
1118   // dimension
1119   SmallVector<int64_t> a4_begin_vals(input_rank), a4_size_vals(input_rank);
1120 
1121   for (int i = 0; i < input_rank; i++) {
1122     // Batch dimension and remaining dimensions.
1123     if (i == 0 || i > crops_dims) {
1124       a4_begin_vals[i] = 0;
1125       a4_size_vals[i] = result_type.getShape()[i];
1126     } else {
1127       // Spatial dimension.
1128       assert(i - 1 >= 0 && i - 1 < crops_dims);
1129       a4_begin_vals[i] = crops[i - 1].first;
1130       a4_size_vals[i] = a4_shape[i] - crops[i - 1].first - crops[i - 1].second;
1131     }
1132   }
1133 
1134   return rewriter
1135       .create<tosa::SliceOp>(
1136           op->getLoc(),
1137           RankedTensorType::get(a4_size_vals, result_type.getElementType()),
1138           a3_reshape_a2.getResult(), rewriter.getI64ArrayAttr(a4_begin_vals),
1139           rewriter.getI64ArrayAttr(a4_size_vals))
1140       .getResult();
1141 }
1142 
1143 // Lowers ExpandDims to TOSA.
1144 llvm::Optional<Value> convertExpandDimsOp(PatternRewriter& rewriter,
1145                                           Operation* op, Value result_value,
1146                                           Value input_value, Value dim_value) {
1147   // Lowers to a reshape op with 1's inserted in the appropriate dimensions.
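  // For example (illustrative), expanding a [2, 3] input at dim = 1 reshapes
  // it to [2, 1, 3]; an out-of-range (or negative) dim appends a trailing 1,
  // giving [2, 3, 1].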
1148   RankedTensorType output_type =
1149       result_value.getType().dyn_cast<RankedTensorType>();
1150   // Not a ranked tensor output
1151   if (!output_type) {
1152     op->emitOpError("ExpandDims: output type not ranked tensor");
1153     return llvm::None;
1154   }
1155 
1156   RankedTensorType input_type =
1157       input_value.getType().dyn_cast<RankedTensorType>();
1158   if (!input_type) {
1159     op->emitOpError("ExpandDims: input type not ranked tensor");
1160     return llvm::None;
1161   }
1162 
1163   auto input_shape = input_type.getShape();
1164 
1165   ElementsAttr dim_elem;
1166   if (!matchPattern(dim_value, m_Constant(&dim_elem))) return llvm::None;
1167 
1168   assert(dim_elem.getType().getRank() == 0 && "expected scalar tensor");
1169   int32_t dim = dim_elem.getValue<IntegerAttr>({}).getInt();
1170 
1171   SmallVector<int64_t> reshape_dims;
1172   if (dim < 0 || dim >= input_shape.size()) {  // add dim at end of tensor
1173     dim = input_shape.size();
1174     for (int i = 0; i < input_shape.size(); i++) {
1175       reshape_dims.emplace_back(input_shape[i]);
1176     }
1177     reshape_dims.emplace_back(1);
1178   } else {
1179     for (int i = 0; i < input_shape.size(); i++) {
1180       if (i == dim) {
1181         reshape_dims.emplace_back(1);
1182       }
1183       reshape_dims.emplace_back(input_shape[i]);
1184     }
1185   }
1186 
1187   ArrayAttr shape_attr = rewriter.getI64ArrayAttr(reshape_dims);
1188 
1189   return rewriter
1190       .create<tosa::ReshapeOp>(op->getLoc(), output_type, input_value,
1191                                shape_attr)
1192       .getResult();
1193 }
1194 
1195 // Lowers Squeeze to TOSA.
1196 llvm::Optional<Value> convertSqueezeOp(PatternRewriter& rewriter, Operation* op,
1197                                        Value result_value, Value input_value,
1198                                        SmallVectorImpl<int32_t>& squeeze_dims) {
1199   // Lowers to a reshape op where dimensions in squeeze_dims with size=1
1200   // are removed.
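  // For example (illustrative), squeezing a [1, 2, 1, 3] input with
  // squeeze_dims = {0, 2}, or with an empty squeeze_dims list, reshapes it
  // to [2, 3].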
1201   RankedTensorType output_type =
1202       result_value.getType().dyn_cast<RankedTensorType>();
1203   // Not a ranked tensor output
1204   if (!output_type) {
1205     op->emitOpError("Squeeze: output type not ranked tensor");
1206     return llvm::None;
1207   }
1208 
1209   RankedTensorType input_type =
1210       input_value.getType().dyn_cast<RankedTensorType>();
1211   if (!input_type) {
1212     op->emitOpError("Squeeze: input type not ranked tensor");
1213     return llvm::None;
1214   }
1215 
1216   auto input_shape = input_type.getShape();
1217 
1218   SmallVector<int64_t> reshape_dims;
1219 
1220   if (squeeze_dims.empty()) {  // remove all 1-dims
1221     for (int i = 0; i < input_shape.size(); i++) {
1222       if (input_shape[i] != 1) {
1223         reshape_dims.emplace_back(input_shape[i]);
1224       }
1225     }
1226   } else {
1227     // Remove only specified dims.
1228     // First sort the array so they can be picked off in sequence.
1229     std::sort(squeeze_dims.begin(), squeeze_dims.end(),
1230               [](const int32_t a, const int32_t b) { return a < b; });
1231 
1232     int pos = 0;
1233     auto dim = squeeze_dims[pos];
1234     for (int i = 0; i < input_shape.size(); i++) {
1235       if (i == dim) {
1236         pos = pos + 1;
1237         if (pos < squeeze_dims.size())
1238           dim = squeeze_dims[pos];
1239         else
1240           dim = -1;  // Invalid
1241       } else {
1242         reshape_dims.emplace_back(input_shape[i]);
1243       }
1244     }
1245   }
1246 
1247   ArrayAttr shape_attr = rewriter.getI64ArrayAttr(reshape_dims);
1248 
1249   return rewriter
1250       .create<tosa::ReshapeOp>(op->getLoc(), output_type, input_value,
1251                                shape_attr)
1252       .getResult();
1253 }
1254 
1255 // Lowers ELU to a sequence of TOSA ops.
1256 llvm::Optional<Value> convertEluOp(PatternRewriter& rewriter, Operation* op,
1257                                    Value result_value, Value features_value) {
1258   // Lowers Elu using the following formula:
1259   // elu(x) = x < 0 ? (exp(x) - 1) : x
1260   // one = const({1});
1261   // zero = const({0});
1262   // one_bcast = reshape(one, [1, ..., rank(x) - 1])
1263   // zero_bcast = reshape(zero, [1, ..., rank(x) - 1])
1264   // a1 = exp(x);
1265   // a2 = sub(a1, one_bcast)
1266   // a3 = ge(x, zero_bcast)
1267   // a4 = select(a3, x, a2)
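  // For example (illustrative), elu(-1.0) = exp(-1.0) - 1 ~= -0.632, while
  // elu(2.0) = 2.0.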
1268   RankedTensorType output_type =
1269       result_value.getType().dyn_cast<RankedTensorType>();
1270   // Not a ranked tensor output
1271   if (!output_type) {
1272     op->emitOpError("Elu: output type not ranked tensor");
1273     return llvm::None;
1274   }
1275 
1276   int32_t input_rank = output_type.getShape().size();
1277   SmallVector<int64_t> bcast_shape(input_rank, 1);
1278 
1279   // Can't directly create size=1, rank=rank(input) tensor because
1280   // it will be optimized out.  Instead, create rank0 tensor and reshape later.
1281   Value one_const_op = getTosaConstTensorSingleF32(rewriter, op, 1.0);
1282 
1283   Value zero_const_op = getTosaConstTensorSingleF32(rewriter, op, 0.0);
1284 
1285   auto a1_exp_in_op =
1286       rewriter.create<tosa::ExpOp>(op->getLoc(), output_type, features_value);
1287 
1288   auto a2_sub_a1_one_op = rewriter.create<tosa::SubOp>(
1289       op->getLoc(), output_type, a1_exp_in_op.getResult(), one_const_op);
1290 
1291   auto a3_ge_in_zero_op = rewriter.create<tosa::GreaterEqualOp>(
1292       op->getLoc(),
1293       RankedTensorType::get(output_type.getShape(), rewriter.getIntegerType(1)),
1294       features_value, zero_const_op);
1295 
1296   return rewriter
1297       .create<tosa::SelectOp>(op->getLoc(), output_type,
1298                               a3_ge_in_zero_op.getResult(), features_value,
1299                               a2_sub_a1_one_op.getResult())
1300       .getResult();
1301 }
1302 
1303 // Lowers Softmax to a sequence of TOSA ops.
1304 llvm::Optional<Value> convertSoftmaxOp(PatternRewriter& rewriter, Operation* op,
1305                                        Value result_value, Value logits_value,
1306                                        double beta) {
1307   // softmax = exp(logits) / reduce_sum(exp(logits), -1)
1308   //
1309   // or equivalently, multiplying both the numerator and the denominator by
1310   // exp(-max(logits)), we get:
1311   //
1312   // softmax = exp(logits - max(logits)) / reduce_sum(exp(logits -
1313   // max(logits)), -1)
1314   //
1315   // We'll use the first version for the direct floating-point lowering, and
1316   // the second for the quantized lowering, since the second restricts the
1317   // input to exp() to be non-positive, so the LUT stays within [0.0, 1.0].
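  //
  // For example (illustrative), softmax([1, 2, 3]) = exp([1, 2, 3]) / sum =
  // [0.090, 0.245, 0.665]; subtracting max(logits) = 3 first gives
  // exp([-2, -1, 0]) / sum, which produces the same result with every exp()
  // input <= 0.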
1318   RankedTensorType output_type =
1319       result_value.getType().dyn_cast<RankedTensorType>();
1320   RankedTensorType input_type =
1321       logits_value.getType().dyn_cast<RankedTensorType>();
1322 
1323   // Not a ranked tensor input/output
1324   if (!output_type || !input_type) {
1325     op->emitOpError("Softmax: input and result not ranked tensors");
1326     return llvm::None;
1327   }
1328 
1329   // reduce_sum on last dimension
1330   int32_t input_rank = input_type.getShape().size();
1331   ArrayRef<int64_t> logits_shape = output_type.getShape();
1332 
1333   if (input_type.getElementType().isa<mlir::quant::QuantizedType>() &&
1334       output_type.getElementType().isa<mlir::quant::QuantizedType>()) {
1335     SmallVector<int64_t> rsum_shape_v(input_type.getShape().begin(),
1336                                       input_type.getShape().end() - 1);
1337     rsum_shape_v.push_back(1);
1338     ArrayRef<int64_t> rsum_shape(rsum_shape_v);
1339     // The if condition already checks if these are UQTs
1340     mlir::quant::UniformQuantizedType in_quant_type =
1341         input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
1342     mlir::quant::UniformQuantizedType out_quant_type =
1343         output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
1344 
1345     auto int16_element_qtype = mlir::quant::UniformQuantizedType::get(
1346         true, rewriter.getIntegerType(16), rewriter.getF32Type(), 1.0f, 0,
1347         -32768, 32767);
1348     RankedTensorType int16_logits_type =
1349         RankedTensorType::get(logits_shape, int16_element_qtype);
1350     RankedTensorType int32_logits_type =
1351         RankedTensorType::get(logits_shape, rewriter.getIntegerType(32));
1352     RankedTensorType int16_rsum_type =
1353         RankedTensorType::get(rsum_shape, int16_element_qtype);
1354     RankedTensorType int32_rsum_type =
1355         RankedTensorType::get(rsum_shape, rewriter.getIntegerType(32));
1356 
1357     if (in_quant_type.getStorageTypeIntegralWidth() == 8) {
1358       // Step 1. get x - max(x)
1359       Value op1_rescale_in =
1360           buildRescale(rewriter, op, int32_logits_type, logits_value, 1.0f,
1361                        in_quant_type.getZeroPoint(), 0, false, true);
1362 
1363       auto op2_reducemax_op1 = rewriter.create<tosa::ReduceMaxOp>(
1364           op->getLoc(), int32_rsum_type, op1_rescale_in,
1365           rewriter.getI64IntegerAttr(input_rank - 1));
1366 
1367       auto op3_sub_op1_op2 = rewriter.create<tosa::SubOp>(
1368           op->getLoc(), int32_logits_type, op1_rescale_in,
1369           op2_reducemax_op1.getResult());
1370 
1371       // Step 2. get exp() result
1372       // Implemented with two 8-bit -> 16-bit table lookups.
1373       // Since the table output is allowed to be [-32768, 32767],
1374       // and the lower 16 bits are unsigned and range over [0, 65535],
1375       // the lower table is generated with offset -32768, and this needs to be
1376       // recovered before adding to the upper 16 bits.
1377       auto exp_func = [](double x) -> double { return std::exp(x); };
1378 
1379       Value exp_table_const_upper, exp_table_const_lower;
1380       getTosaConst32bitTable(rewriter, op, beta * in_quant_type.getScale(), 0,
1381                              exp_func, exp_table_const_upper,
1382                              exp_table_const_lower);
1383 
1384       Value op4_rescale_op3 =
1385           buildRescale(rewriter, op, int16_logits_type,
1386                        op3_sub_op1_op2.getResult(), 128.0, 0, 0, false, true);
1387 
1388       // Input is 9.7, where lower 7 bits are all zeros.
1389       // Output is 23 bits, where lower 7 bits should be all zeros as well,
1390       // since there's no interpolation here.
1391       auto op5_table_op4_upper = rewriter.create<tosa::TableOp>(
1392           op->getLoc(), int32_logits_type, op4_rescale_op3,
1393           exp_table_const_upper);
1394 
1395       auto op6_table_op4_lower = rewriter.create<tosa::TableOp>(
1396           op->getLoc(), int32_logits_type, op4_rescale_op3,
1397           exp_table_const_lower);
1398 
1399       // To get the 16-bit upper/lower values, we need to right shift 7 bits,
1400       // and then reconstruct the 32-bit value we need as (upper << 16) + lower.
1401       // So effectively we left shift the upper value by 9 bits.
1402       auto op7_lshift_op5 = rewriter.create<tosa::LogicalLeftShiftOp>(
1403           op->getLoc(), int32_logits_type, op5_table_op4_upper.getResult(),
1404           getTosaConstTensorSingleI32(rewriter, op, 9));
1405 
1406       // Right shift 7 bits to get lower 16 bits.
1407       auto op8_rshift_op6 = rewriter.create<tosa::ArithmeticRightShiftOp>(
1408           op->getLoc(), int32_logits_type, op6_table_op4_lower.getResult(),
1409           getTosaConstTensorSingleI32(rewriter, op, 7), true);
1410 
1411       // Recover lower bits from [-32768, 32767] back to [0, 65535]
1412       auto op9_add_op8_32768 = rewriter.create<tosa::AddOp>(
1413           op->getLoc(), int32_logits_type, op8_rshift_op6.getResult(),
1414           getTosaConstTensorSingleI32(rewriter, op, 32768));
1415 
1416       auto op10_add_op7_op9 = rewriter.create<tosa::AddOp>(
1417           op->getLoc(), int32_logits_type, op7_lshift_op5.getResult(),
1418           op9_add_op8_32768.getResult());
1419 
1420       // Step 3. get sum(exp()). output 12.19
1421       auto op11_rshift_op10_12 = rewriter.create<tosa::ArithmeticRightShiftOp>(
1422           op->getLoc(), int32_logits_type, op10_add_op7_op9.getResult(),
1423           getTosaConstTensorSingleI32(rewriter, op, 12), true);
1424 
1425       auto op12_reducesum_op11 = rewriter.create<tosa::ReduceSumOp>(
1426           op->getLoc(), int32_rsum_type, op11_rshift_op10_12.getResult(),
1427           rewriter.getI64IntegerAttr(input_rank - 1));
1428 
1429       // Step 4. calculate reciprocal(sum(exp()))
1430       // CLZ returns headroom_plus_one
1431       auto op13_clz_op12 = rewriter.create<tosa::ClzOp>(
1432           op->getLoc(), int32_rsum_type, op12_reducesum_op11.getResult());
1433 
1434       // minus one to get headroom
1435       auto op14_sub_op13 = rewriter.create<tosa::SubOp>(
1436           op->getLoc(), int32_rsum_type, op13_clz_op12.getResult(),
1437           getTosaConstTensorSingleI32(rewriter, op, 1));
1438 
1439       // Left shift to get s1.30 format
1440       auto op15_lshift_op12_op14 = rewriter.create<tosa::LogicalLeftShiftOp>(
1441           op->getLoc(), int32_rsum_type, op12_reducesum_op11.getResult(),
1442           op14_sub_op13.getResult());
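      // Worked example (hypothetical value): if the reduced sum is 0x00004000,
      // CLZ returns 17 (headroom_plus_one), the headroom is 16, and shifting
      // left by 16 moves the leading one to bit 30, normalizing the sum so the
      // Newton-Raphson reciprocal below starts from a value in a fixed range.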
1443 
1444       // Step 5. Calculate one_over_one_plus_x() with Newton-Raphson division
1445       // with 3 iterations.
1446       // We need two magic constants, 48/17 and -32/17, from the Newton-Raphson
1447       // algorithm, and we need to operate in s2.29 since 48/17 is > 2.0.
1448       // Reference: gemmlowp/fixedpoint/fixedpoint.h
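      // In s2.29, 1.0 is represented by 1 << 29 = 536870912, so (hypothetical
      // arithmetic check) 48/17 ~= 2.8235294 maps to round(2.8235294 * 2^29) =
      // 1515870810 and -32/17 ~= -1.8823529 maps to -1010580540, matching the
      // constants below.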
1449       Value half_denominator = op15_lshift_op12_op14.getResult();
1450       Value four = getTosaConstTensorSingleI32(rewriter, op, 4);
1451       Value F2_one = getTosaConstTensorSingleI32(rewriter, op, (1U << 29));
1452       Value constant_48_over_17 =
1453           getTosaConstTensorSingleI32(rewriter, op, 1515870810);
1454       Value constant_neg_32_over_17 =
1455           getTosaConstTensorSingleI32(rewriter, op, -1010580540);
1456 
1457       // F2 x = constant_48_over_17 + half_denominator *
1458       // constant_neg_32_over_17;
1459       auto op16_mul_half_denominator = rewriter.create<tosa::MulOp>(
1460           op->getLoc(), int32_rsum_type, half_denominator,
1461           constant_neg_32_over_17, 31);
1462 
1463       auto op17_add_op16 = rewriter.create<tosa::AddOp>(
1464           op->getLoc(), int32_rsum_type, op16_mul_half_denominator.getResult(),
1465           constant_48_over_17);
1466 
1467       // Newton-Raphson 3x iteration
1468       Value nr_x = op17_add_op16.getResult();
1469       for (int i = 0; i < 3; i++) {
1470         // half_denominator_times_x =
1471         // SaturatingRoundingDoublingHighMul(half_denominator, x)
1472         auto op18_mul_x_half_denominator = rewriter.create<tosa::MulOp>(
1473             op->getLoc(), int32_rsum_type, nr_x, half_denominator, 31);
1474 
1475         // F2 one_minus_half_denominator_times_x = F2::One() -
1476         // half_denominator_times_x
1477         auto op19_sub_one_op18 = rewriter.create<tosa::SubOp>(
1478             op->getLoc(), int32_rsum_type, F2_one,
1479             op18_mul_x_half_denominator.getResult());
1480 
1481         // SaturatingRoundingDoublingHighMul(x,
1482         // one_minus_half_denominator_times_x)
1483         auto op20_mul_x_op19 =
1484             rewriter.create<tosa::MulOp>(op->getLoc(), int32_rsum_type, nr_x,
1485                                          op19_sub_one_op18.getResult(), 31);
1486 
1487         // x + Rescale<2>(x * one_minus_half_denominator_times_x)
1488         auto op21_mul_op20_four =
1489             rewriter.create<tosa::MulOp>(op->getLoc(), int32_rsum_type,
1490                                          op20_mul_x_op19.getResult(), four, 0);
1491 
1492         auto op22_add_x_op21 =
1493             rewriter.create<tosa::AddOp>(op->getLoc(), int32_rsum_type, nr_x,
1494                                          op21_mul_op20_four.getResult());
1495 
1496         nr_x = op22_add_x_op21.getResult();
1497       }
1498 
1499       // Step 6. multiply exp(x) with 1 / sum(exp(x))
1500       // combined with Rescale<0>(ExactMulByPot<-1>(x))
1501       // so shift 30 instead of 31
1502       auto op23_mul_op10_x = rewriter.create<tosa::MulOp>(
1503           op->getLoc(), int32_logits_type, op10_add_op7_op9.getResult(), nr_x,
1504           31 - 1);
1505 
1506       // Right shift amount is
1507       // num_bits_over_unit + 31 - (sizeof(OutputT) * 8) =
1508       // (12 - headroom_plus_one) + 31 - 8 =
1509       // (12 + 31 - 8) - headroom_plus_one
1510       auto op24_sub_op13 = rewriter.create<tosa::SubOp>(
1511           op->getLoc(), int32_rsum_type,
1512           getTosaConstTensorSingleI32(rewriter, op, 12 + 31 - 8),
1513           op13_clz_op12.getResult());
1514 
1515       auto op25_rshift_op23_op24 =
1516           rewriter.create<tosa::ArithmeticRightShiftOp>(
1517               op->getLoc(), int32_logits_type, op23_mul_op10_x.getResult(),
1518               op24_sub_op13.getResult(), true);
1519 
1520       return buildRescale(rewriter, op, output_type,
1521                           op25_rshift_op23_op24.getResult(), 1.0, 0,
1522                           out_quant_type.getZeroPoint(), false, true);
1523 
1524     } else if (in_quant_type.getStorageTypeIntegralWidth() == 16) {
1525       // Step 1. get x - max(x)
1526       Value op1_rescale_in =
1527           buildRescale(rewriter, op, int32_logits_type, logits_value, 1.0f,
1528                        in_quant_type.getZeroPoint(), 0, false, true);
1529 
1530       auto op2_reducemax_op1 = rewriter.create<tosa::ReduceMaxOp>(
1531           op->getLoc(), int32_rsum_type, op1_rescale_in,
1532           rewriter.getI64IntegerAttr(input_rank - 1));
1533 
1534       // output range is [-65535, 0]
1535       auto op3_sub_op1_op2 = rewriter.create<tosa::SubOp>(
1536           op->getLoc(), int32_logits_type, op1_rescale_in,
1537           op2_reducemax_op1.getResult());
1538 
1539       auto exp_func = [](double x) -> double { return std::exp(x); };
1540 
1541       // Follow TFLite reference: tensorflow/lite/kernels/activations.cc
1542       Value exp_table_const =
1543           getTosaConst16bitTable(rewriter, op, exp_func, -10.0, 0);
1544 
1545       double input_diff_scale = in_quant_type.getScale() / (10.0 / 65535.0);
1546 
1547       // Step 2. rescale input from [-65535, 0] to [-32768, 32767] for LUT input
1548       Value op4_rescale_op3 = buildRescale(
1549           rewriter, op, int16_logits_type, op3_sub_op1_op2.getResult(),
1550           input_diff_scale, 0, 32767, true, true);
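      // Sketch of the mapping (assuming the table covers real inputs in
      // [-10.0, 0.0]): a real value r = x_int32 * in_scale should map to the
      // table index t = (r + 10.0) * 65535.0 / 10.0 - 32768
      //               = x_int32 * (in_scale / (10.0 / 65535.0)) + 32767,
      // which is exactly the rescale above with scale input_diff_scale and
      // output offset 32767.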
1551 
1552       // Step 3. get exp() result
1553       // Output is 15.7.
1554       // In 8-bit case, no interpolation here, since input should be right on
1555       // table entry.
1556       auto op5_table_op4 = rewriter.create<tosa::TableOp>(
1557           op->getLoc(), int32_logits_type, op4_rescale_op3, exp_table_const);
1558 
1559       // Right shift 7 bits. output 15. Shouldn't lose any precision since last
1560       // 7 bits should be all 0.
1561       auto op6_rshift_op5 = rewriter.create<tosa::ArithmeticRightShiftOp>(
1562           op->getLoc(), int32_logits_type, op5_table_op4.getResult(),
1563           getTosaConstTensorSingleI32(rewriter, op, 7), true);
1564 
1565       // Step 4. get sum(exp()). output 16.15
1566       auto op7_reducesum_op6 = rewriter.create<tosa::ReduceSumOp>(
1567           op->getLoc(), int32_rsum_type, op6_rshift_op5.getResult(),
1568           rewriter.getI64IntegerAttr(input_rank - 1));
1569 
1570       // Step 5. calculate reciprocal(sum(exp()))
1571       // CLZ returns 32 - first non zero bit
1572       auto op8_clz_op7 = rewriter.create<tosa::ClzOp>(
1573           op->getLoc(), int32_rsum_type, op7_reducesum_op6.getResult());
1574 
1575       auto op9_sub_op8 = rewriter.create<tosa::SubOp>(
1576           op->getLoc(), int32_rsum_type, op8_clz_op7.getResult(),
1577           getTosaConstTensorSingleI32(rewriter, op, 1));
1578 
1579       // Left shift to get  1.30 format
1580       auto op10_lshift_op7_op9 = rewriter.create<tosa::LogicalLeftShiftOp>(
1581           op->getLoc(), int32_rsum_type, op7_reducesum_op6.getResult(),
1582           op9_sub_op8.getResult());
1583 
1584       // Subtract (1 << 30) to make 0 <= x <= 1 under 0.30 format
1585       auto op11_sub_op10 = rewriter.create<tosa::SubOp>(
1586           op->getLoc(), int32_rsum_type, op10_lshift_op7_op9.getResult(),
1587           getTosaConstTensorSingleI32(rewriter, op, (1u << 30)));
1588 
1589       // Right shift 14 bits to get output range [0, 65535]
1590       auto op12_rshift_op11 = rewriter.create<tosa::ArithmeticRightShiftOp>(
1591           op->getLoc(), int32_rsum_type, op11_sub_op10.getResult(),
1592           getTosaConstTensorSingleI32(rewriter, op, 14), true);
1593 
1594       // Remap input to [-32768, 32767] for LUT input
1595       auto op13_rescale_op12 = buildRescale(rewriter, op, int16_rsum_type,
1596                                             op12_rshift_op11.getResult(), 1.0,
1597                                             32768, 0, false, true);
1598 
1599       // Generate table for 1 / (1 + x), for 0 <= x <= 1
1600       auto one_over_one_plus_x_func = [](double x) -> double {
1601         return 1.0 / (1.0 + x);
1602       };
1603 
1604       Value one_over_one_plus_x_table_const = getTosaConst16bitTable(
1605           rewriter, op, one_over_one_plus_x_func, 0.0, 1.0);
1606 
1607       // Get (1 / sum(exp(x))) result as 23 bits (including sign bit)
1608       auto op14_table_op13 = rewriter.create<tosa::TableOp>(
1609           op->getLoc(), int32_rsum_type, op13_rescale_op12,
1610           one_over_one_plus_x_table_const);
1611 
1612       // Right shift 7 bits back to 0.15
1613       auto op15_rshift_op14 = rewriter.create<tosa::ArithmeticRightShiftOp>(
1614           op->getLoc(), int32_rsum_type, op14_table_op13.getResult(),
1615           getTosaConstTensorSingleI32(rewriter, op, 7), true);
1616 
1617       // Step 6. multiply exp(x - max) with 1 / sum(exp(x - max))
1618       // lhs: 0.15, rhs: 0.15, output: 0.30
1619       auto op16_mul_op15_op6 = rewriter.create<tosa::MulOp>(
1620           op->getLoc(), int32_logits_type, op15_rshift_op14, op6_rshift_op5, 0);
1621 
1622       auto op17_sub_op8 = rewriter.create<tosa::SubOp>(
1623           op->getLoc(), int32_rsum_type,
1624           getTosaConstTensorSingleI32(rewriter, op, 31),
1625           op8_clz_op7.getResult());
1626 
1627       // Apply the clz back, we get 0.15 output
1628       // [0, 32767] corresponding to [0.0, 1.0]
1629       auto op18_rshift_op16_op17 =
1630           rewriter.create<tosa::ArithmeticRightShiftOp>(
1631               op->getLoc(), int32_logits_type, op16_mul_op15_op6.getResult(),
1632               op17_sub_op8.getResult(), true);
1633 
1634       return buildRescale(rewriter, op, output_type,
1635                           op18_rshift_op16_op17.getResult(),
1636                           (1.0 / out_quant_type.getScale()) * (1.0 / 32768.0),
1637                           0, out_quant_type.getZeroPoint(), false, true);
1638     } else {
1639       op->emitOpError("Softmax: unknown quantization bitwidth");
1640       return llvm::None;
1641     }
1642   } else {
1643     SmallVector<int64_t> rsum_shape_v(input_type.getShape().begin(),
1644                                       input_type.getShape().end());
1645     rsum_shape_v[input_rank - 1] = 1;
1646     ArrayRef<int64_t> rsum_shape(rsum_shape_v);
1647 
1648     // Floating-point lowering is more direct:
1649     //
1650     // op1 = exp(logits)
1651     // op2 = reduce_sum(op1, -1)
1652     // op3 = reciprocal(op2)
1653     // op4 = mul(op1, op3)
1654     auto op1_exp_in =
1655         rewriter.create<tosa::ExpOp>(op->getLoc(), output_type, logits_value);
1656     RankedTensorType rsum_type =
1657         RankedTensorType::get(rsum_shape, output_type.getElementType());
1658 
1659     // Keep dims so we don't need to reshape later
1660     auto op2_reducesum_op1 = rewriter.create<tosa::ReduceSumOp>(
1661         op->getLoc(), rsum_type, op1_exp_in.getResult(),
1662         rewriter.getI64IntegerAttr(input_rank - 1));
1663     auto op3_reciprocal_op2 = rewriter.create<tosa::ReciprocalOp>(
1664         op->getLoc(), op2_reducesum_op1.getType(),
1665         op2_reducesum_op1.getResult());
1666 
1667     return rewriter
1668         .create<tosa::MulOp>(op->getLoc(), output_type, op1_exp_in.getResult(),
1669                              op3_reciprocal_op2.getResult(), 0)
1670         .getResult();
1671   }
1672 }
1673 
1674 // Lowers LogSoftmax to a sequence of TOSA ops.
1675 llvm::Optional<Value> convertLogSoftmaxOp(PatternRewriter& rewriter,
1676                                           Operation* op, Value result_value,
1677                                           Value logits_value) {
1678   // log_softmax = log(exp(logits) / reduce_sum(exp(logits), -1))
1679   // op1 = exp(logits)
1680   // op2 = reduce_sum(op1, -1)
1681   // op3 = reciprocal(op2)
1682   // op4 = mul(op1, op3)
1683   // op5 = log(op4)
1684 
1685   RankedTensorType output_type =
1686       result_value.getType().dyn_cast<RankedTensorType>();
1687   // Not a ranked tensor output
1688   if (!output_type) {
1689     op->emitOpError("LogSoftmax: output type not ranked tensor.");
1690     return llvm::None;
1691   }
1692 
1693   RankedTensorType input_type =
1694       op->getOperand(0).getType().dyn_cast<RankedTensorType>();
1695   if (!input_type) {
1696     op->emitOpError("LogSoftmax: input type not ranked tensor.");
1697     return llvm::None;
1698   }
1699 
1700   mlir::quant::UniformQuantizedType in_quant_type =
1701       input_type.getElementType()
1702           .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
1703   mlir::quant::UniformQuantizedType out_quant_type =
1704       output_type.getElementType()
1705           .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
1706   if (in_quant_type || out_quant_type) {
1707     op->emitOpError("Quantized log_softmax lowering not implemented yet");
1708     return llvm::None;
1709   }
1710 
1711   auto op1_exp_in =
1712       rewriter.create<tosa::ExpOp>(op->getLoc(), output_type, logits_value);
1713 
1714   // reduce_sum on last dimension
1715   int32_t input_rank = input_type.getShape().size();
1716   SmallVector<int64_t> rsum_shape(output_type.getShape().begin(),
1717                                   output_type.getShape().end());
1718   rsum_shape[input_rank - 1] = 1;
1719   RankedTensorType rsum_type =
1720       RankedTensorType::get(rsum_shape, output_type.getElementType());
1721   // Keep dims so we don't need to reshape later
1722   auto op2_reducesum_op1 = rewriter.create<tosa::ReduceSumOp>(
1723       op->getLoc(), rsum_type, op1_exp_in.getResult(),
1724       rewriter.getI64IntegerAttr(input_rank - 1));
1725   auto op3_reciprocal_op2 = rewriter.create<tosa::ReciprocalOp>(
1726       op->getLoc(), op2_reducesum_op1.getType(), op2_reducesum_op1.getResult());
1727 
1728   auto op4_mul_op1_op3 = rewriter.create<tosa::MulOp>(
1729       op->getLoc(), output_type, op1_exp_in.getResult(),
1730       op3_reciprocal_op2.getResult(), 0);
1731 
1732   return rewriter
1733       .create<tosa::LogOp>(op->getLoc(), output_type,
1734                            op4_mul_op1_op3.getResult())
1735       .getResult();
1736 }
1737 
1738 // Lowers SpaceToDepth to a sequence of TOSA ops.  Supports NHWC.
1739 llvm::Optional<Value> convertSpaceToDepthOp(PatternRewriter& rewriter,
1740                                             Operation* op, Value result_value,
1741                                             Value input_value,
1742                                             IntegerAttr block_size_attr,
1743                                             StringAttr data_format) {
1744   // NHWC lowering version:
1745   // a2 = tf.reshape(a, [orig_shape[0], orig_shape[1]//b, b, orig_shape[2]//b,
1746   // b, orig_shape[3]])
1747   // a3 = tf.transpose(a2, [0, 1, 3, 2, 4, 5])
1748   // a4 = tf.reshape(a3, [orig_shape[0], orig_shape[1]//b, orig_shape[2]//b,
1749   // orig_shape[3]*b*b])
1750   // return a4
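  // For example (hypothetical shapes): with input [1, 4, 4, 3] and block size
  // b = 2, a2 has shape [1, 2, 2, 2, 2, 3], the transpose reorders the block
  // dimensions, and a4 has shape [1, 2, 2, 12].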
1751   RankedTensorType output_type =
1752       result_value.getType().dyn_cast<RankedTensorType>();
1753 
1754   // Not a ranked tensor output.
1755   if (!output_type) {
1756     op->emitOpError("SpaceToDepth: output type not ranked tensor.");
1757     return llvm::None;
1758   }
1759 
1760   RankedTensorType input_type =
1761       input_value.getType().dyn_cast<RankedTensorType>();
1762   if (!input_type) {
1763     op->emitOpError("SpaceToDepth: input type not ranked tensor.");
1764     return llvm::None;
1765   }
1766 
1767   if (input_type.getRank() != 4) {
1768     op->emitOpError("SpaceToDepth: input rank not 4.");
1769     return llvm::None;
1770   }
1771 
1772   auto input_shape = input_type.getShape();
1773 
1774   if (!block_size_attr) {  // This is a required parameter
1775     op->emitOpError("SpaceToDepth: block size attribute not set.");
1776     return llvm::None;
1777   }
1778 
1779   SmallVector<int64_t, 2> block_size;
1780   block_size.assign(2, block_size_attr.getInt());
1781 
1782   if (!data_format) data_format = rewriter.getStringAttr("NHWC");
1783 
1784   if (data_format.getValue().str() != "NHWC") {
1785     op->emitOpError("SpaceToDepth: data format not NHWC.");
1786     return llvm::None;
1787   }
1788 
1789   assert(block_size[0] * block_size[1] != 0);
1790 
1791   SmallVector<int64_t, 6> a_reshape_dims;
1792   a_reshape_dims.push_back(input_shape[0]);
1793   a_reshape_dims.push_back(input_shape[1] / block_size[0]);
1794   a_reshape_dims.push_back(block_size[0]);
1795   a_reshape_dims.push_back(input_shape[2] / block_size[1]);
1796   a_reshape_dims.push_back(block_size[1]);
1797   a_reshape_dims.push_back(input_shape[3]);
1798 
1799   RankedTensorType a_reshape_output_type =
1800       RankedTensorType::get(a_reshape_dims, output_type.getElementType());
1801   auto a2_reshape_a_op = rewriter.create<tosa::ReshapeOp>(
1802       op->getLoc(), a_reshape_output_type, input_value,
1803       rewriter.getI64ArrayAttr(a_reshape_dims));
1804 
1805   llvm::Optional<Value> a3_transpose_perm = getConstTensor<int32_t>(
1806       rewriter, op, /*vec=*/{0, 1, 3, 2, 4, 5}, /*shape=*/{6});
1807 
1808   if (!a3_transpose_perm) return llvm::None;
1809 
1810   auto a3_transpose_a2_op = rewriter.create<tosa::TransposeOp>(
1811       op->getLoc(), a_reshape_output_type, a2_reshape_a_op.getResult(),
1812       a3_transpose_perm.getValue());
1813 
1814   SmallVector<int64_t, 4> a3_reshape_dims;
1815   a3_reshape_dims.push_back(input_shape[0]);
1816   a3_reshape_dims.push_back(input_shape[1] / block_size[0]);
1817   a3_reshape_dims.push_back(input_shape[2] / block_size[1]);
1818   a3_reshape_dims.push_back(input_shape[3] * block_size[0] * block_size[1]);
1819 
1820   RankedTensorType a3_reshape_output_type =
1821       RankedTensorType::get(a3_reshape_dims, output_type.getElementType());
1822   return rewriter
1823       .create<tosa::ReshapeOp>(op->getLoc(), a3_reshape_output_type,
1824                                a3_transpose_a2_op.getResult(),
1825                                rewriter.getI64ArrayAttr(a3_reshape_dims))
1826       .getResult();
1827 }
1828 
1829 // Lowers DepthToSpace to a sequence of TOSA ops.  Supports NHWC.
1830 llvm::Optional<Value> convertDepthToSpaceOp(PatternRewriter& rewriter,
1831                                             Operation* op, Value result_value,
1832                                             Value input_value,
1833                                             IntegerAttr block_size_attr,
1834                                             StringAttr data_format) {
1835   // NHWC version
1836   // a2 = tf.reshape(a, [orig_shape[0], orig_shape[1], orig_shape[2], b, b,
1837   // orig_shape[3] // (b*b)])
1838   // a3 = tf.transpose(a2, [0, 1, 3, 2, 4, 5])
1839   // a4 = tf.reshape(a3, [orig_shape[0], orig_shape[1] * b, orig_shape[2] * b,
1840   // orig_shape[3] // (b*b)])
1841   // return a4
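  // For example (hypothetical shapes): with input [1, 2, 2, 12] and block size
  // b = 2, a2 has shape [1, 2, 2, 2, 2, 3] and a4 has shape [1, 4, 4, 3],
  // inverting the SpaceToDepth example above.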
1842 
1843   RankedTensorType output_type =
1844       result_value.getType().dyn_cast<RankedTensorType>();
1845 
1846   // Not a ranked tensor output
1847   if (!output_type) {
1848     op->emitOpError("DepthToSpace: output type not ranked tensor.");
1849     return llvm::None;
1850   }
1851 
1852   RankedTensorType input_type =
1853       input_value.getType().dyn_cast<RankedTensorType>();
1854   if (!input_type) {
1855     op->emitOpError("DepthToSpace: input type not ranked tensor.");
1856     return llvm::None;
1857   }
1858 
1859   if (input_type.getRank() != 4) return llvm::None;
1860   auto input_shape = input_type.getShape();
1861 
1862   if (!block_size_attr) {  // This is a required parameter
1863     op->emitOpError("DepthToSpace: block size attribute not set.");
1864     return llvm::None;
1865   }
1866 
1867   SmallVector<int64_t, 2> block_size;
1868   block_size.assign(2, block_size_attr.getInt());
1869 
1870   if (!data_format) data_format = rewriter.getStringAttr("NHWC");
1871   if (data_format.getValue().str() != "NHWC") {
1872     op->emitOpError("DepthToSpace: data format not NHWC.");
1873     return llvm::None;
1874   }
1875 
1876   assert(block_size[0] * block_size[1] != 0);
1877 
1878   SmallVector<int64_t, 6> a_reshape_dims;
1879   a_reshape_dims.push_back(input_shape[0]);
1880   a_reshape_dims.push_back(input_shape[1]);
1881   a_reshape_dims.push_back(input_shape[2]);
1882   a_reshape_dims.push_back(block_size[0]);
1883   a_reshape_dims.push_back(block_size[1]);
1884   a_reshape_dims.push_back(input_shape[3] / (block_size[0] * block_size[1]));
1885 
1886   RankedTensorType a_reshape_output_type =
1887       RankedTensorType::get(a_reshape_dims, output_type.getElementType());
1888   auto a2_reshape_a_op = rewriter.create<tosa::ReshapeOp>(
1889       op->getLoc(), a_reshape_output_type, input_value,
1890       rewriter.getI64ArrayAttr(a_reshape_dims));
1891 
1892   llvm::Optional<Value> a3_transpose_perm = getConstTensor<int32_t>(
1893       rewriter, op, /*vec=*/{0, 1, 3, 2, 4, 5}, /*shape=*/{6});
1894 
1895   if (!a3_transpose_perm) return llvm::None;
1896 
1897   auto a3_transpose_a2_op = rewriter.create<tosa::TransposeOp>(
1898       op->getLoc(), a_reshape_output_type, a2_reshape_a_op.getResult(),
1899       a3_transpose_perm.getValue());
1900 
1901   SmallVector<int64_t, 4> a3_reshape_dims;
1902   a3_reshape_dims.push_back(input_shape[0]);
1903   a3_reshape_dims.push_back(input_shape[1] * block_size[0]);
1904   a3_reshape_dims.push_back(input_shape[2] * block_size[1]);
1905   a3_reshape_dims.push_back(input_shape[3] / (block_size[0] * block_size[1]));
1906 
1907   RankedTensorType a3_reshape_output_type =
1908       RankedTensorType::get(a3_reshape_dims, output_type.getElementType());
1909   return rewriter
1910       .create<tosa::ReshapeOp>(op->getLoc(), a3_reshape_output_type,
1911                                a3_transpose_a2_op.getResult(),
1912                                rewriter.getI64ArrayAttr(a3_reshape_dims))
1913       .getResult();
1914 }
1915 
1916 // Lowers Split to a sequence of TOSA ops.
1917 llvm::Optional<SmallVector<Value>> convertSplitOp(
1918     PatternRewriter& rewriter, Operation* op, Value result_value,
1919     Value input_value, int32_t num_split, int32_t axis) {
1920   // This lowering creates num_split slice ops and ties them together
1921   // with IdentityN to get from an array of Operations to a single Operation
1922   // with a list of result tensors.
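  // For example (hypothetical shapes): splitting a [2, 6, 3] tensor with
  // num_split = 3 along axis 1 produces three slices of shape [2, 2, 3] with
  // begin points [0, 0, 0], [0, 2, 0] and [0, 4, 0].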
1923   RankedTensorType result_type =
1924       result_value.getType().dyn_cast<RankedTensorType>();
1925   // Not a ranked tensor output
1926   if (!result_type) {
1927     op->emitOpError("Split: output type not ranked tensor.");
1928     return llvm::None;
1929   }
1930 
1931   RankedTensorType input_type =
1932       input_value.getType().dyn_cast<RankedTensorType>();
1933   if (!input_type) {
1934     op->emitOpError("Split: input type not ranked tensor.");
1935     return llvm::None;
1936   }
1937 
1938   auto input_shape = input_type.getShape();
1939 
1940   SmallVector<Value> results_vec;
1941 
1942   assert(axis >= 0 && axis < input_shape.size());
1943   assert((input_shape[axis] % num_split) == 0);
1944   assert(num_split > 0);
1945 
1946   int64_t slice_size = input_shape[axis] / num_split;
1947 
1948   for (int i = 0; i < num_split; i++) {
1949     // Each slice has a different beginning point.
1950     // The slice size is actually the same for each op.
1951     SmallVector<int64_t> begin_vals, size_vals;
1952 
1953     for (int j = 0; j < input_shape.size(); j++) {
1954       if (j == axis) {
1955         begin_vals.push_back(slice_size * i);
1956         size_vals.push_back(slice_size);
1957       } else {
1958         begin_vals.push_back(0);
1959         size_vals.push_back(input_shape[j]);
1960       }
1961     }
1962 
1963     ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
1964     ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);
1965 
1966     auto slice_op = rewriter.create<tosa::SliceOp>(
1967         op->getLoc(),
1968         RankedTensorType::get(size_vals, result_type.getElementType()),
1969         input_value, begin, size);
1970 
1971     results_vec.push_back(slice_op.getResult());
1972   }
1973 
1974   return results_vec;
1975 }
1976 
1977 // Lowers SplitV to a sequence of TOSA ops.
1978 llvm::Optional<SmallVector<Value>> convertSplitVOp(
1979     PatternRewriter& rewriter, Operation* op, Value result_value,
1980     Value input_value, SmallVectorImpl<int32_t>& size_split, int32_t axis) {
1981   // This lowering creates num_split slice ops and ties them together
1982   // with IdentityN to get from an array of Operations to a single Operation
1983   // with a list of result tensors.
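  // For example (hypothetical shapes): splitting a [2, 6, 3] tensor along
  // axis 1 with size_split = {1, 2, 3} produces slices of shapes [2, 1, 3],
  // [2, 2, 3] and [2, 3, 3] with begin points 0, 1 and 3 along that axis.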
1984   RankedTensorType result_type =
1985       result_value.getType().dyn_cast<RankedTensorType>();
1986   // Not a ranked tensor output
1987   if (!result_type) {
1988     op->emitOpError("SplitV: output type not ranked tensor.");
1989     return llvm::None;
1990   }
1991 
1992   RankedTensorType input_type =
1993       input_value.getType().dyn_cast<RankedTensorType>();
1994   if (!input_type) {
1995     op->emitOpError("SplitV: input type not ranked tensor.");
1996     return llvm::None;
1997   }
1998 
1999   auto input_shape = input_type.getShape();
2000 
2001   SmallVector<Value> results_vec;
2002 
2003   assert(axis >= 0 && axis < input_shape.size());
2004   int32_t size_split_sum = 0;
2005   for (int i = 0; i < size_split.size(); i++) {
2006     size_split_sum += size_split[i];
2007   }
2008 
2009   // The split sizes must sum up to the size of the axis being split
2010   assert(size_split_sum == input_shape[axis]);
2011 
2012   int32_t curr_split_start = 0;
2013   for (int i = 0; i < size_split.size(); i++) {
2014     // Each slice has a different beginning point.
2015     // The slice size is different for each op.
2016     SmallVector<int64_t> begin_vals, size_vals;
2017 
2018     for (int j = 0; j < input_shape.size(); j++) {
2019       if (j == axis) {
2020         begin_vals.push_back(curr_split_start);
2021         size_vals.push_back(size_split[i]);
2022       } else {
2023         begin_vals.push_back(0);
2024         size_vals.push_back(input_shape[j]);
2025       }
2026     }
2027 
2028     ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
2029     ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);
2030 
2031     auto slice_op = rewriter.create<tosa::SliceOp>(
2032         op->getLoc(),
2033         RankedTensorType::get(size_vals, result_type.getElementType()),
2034         input_value, begin, size);
2035 
2036     results_vec.push_back(slice_op.getResult());
2037 
2038     // Next start position
2039     curr_split_start += size_split[i];
2040   }
2041 
2042   return results_vec;
2043 }
2044 
2045 // Lowers StridedSlice to a sequence of TOSA ops.
2046 llvm::Optional<Value> convertStridedSliceOp(
2047     PatternRewriter& rewriter, Operation* op, Value result_value,
2048     Value input_value, Value begin_value, Value end_value, Value strides_value,
2049     int32_t begin_mask, int32_t end_mask, int32_t ellipsis_mask,
2050     int32_t new_axis_mask, int32_t shrink_axis_mask) {
2051   // The mask arguments are bitmasks where bit [i] applies to
2052   // dimension [i] of the input tensor.
2053   //
2054   // The rough algorithm for lowering strided slice is as follows:
2055   //
2056   // 0. Process begin/end masks, since they are basically syntactic sugar
2057   // on top of the begin_value/end_value arrays
2058   //
2059   // 1. Slice1: Ignoring stride, slice the interesting range from the input
2060   // tensor
2061   //
2062   // 2. Reshape2: Reshape the tensor from (1) such that each dimension with
2063   // stride is split into two dimensions of size_i/stride_i, stride_i. A naive
2064   // implementation doubles the input tensor rank, but only dimensions being
2065   // strided actually need to be doubled.
2066   //
2067   // 3. Slice3: Slice the tensor from (2) such that we select index [0] from
2068   // each of the stride_i dimensions in (2)
2069   //
2070   // 4. Reshape4: Reshape the tensor to eliminate the stride_i dimensions, add
2071   // any dimensions in new_axis_mask and remove any dimensions in the
2072   // shrink_axis_mask
2073 
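  // Worked example (hypothetical values): for a 1-D input of shape [7] with
  // begin = [1], end = [7], strides = [2], Slice1 yields 6 elements starting
  // at index 1, Reshape2 gives shape [3, 2], Slice3 keeps index 0 of the
  // stride dimension giving [3, 1], and Reshape4 produces the final [3] result
  // (indices 1, 3, 5 of the input).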
2074   // Limitations:
2075   // * This implementation only supports ellipsis_mask=0 for now
2076   // * This implementation does not support reverse stride yet.  Will need
2077   // to insert tosa.Reverse operators for this.
2078   if (ellipsis_mask != 0) {
2079     (void)rewriter.notifyMatchFailure(op, "ellipsis mask not supported yet");
         return llvm::None;
2080   }
2081 
2082   ShapedType input_type = input_value.getType().cast<ShapedType>();
2083   ShapedType result_type = result_value.getType().cast<ShapedType>();
2084 
2085   // Extract the begin/end/stride tensors
2086   SmallVector<int32_t> begin, end, strides;
2087 
2088   DenseIntElementsAttr strides_attr;
2089 
2090   if (!matchPattern(strides_value, m_Constant(&strides_attr))) {
2091     (void)rewriter.notifyMatchFailure(op, "strides is not a constant");
2092     return llvm::None;
2093   }
2094 
2095   bool all_strides_one =
2096       strides_attr.isSplat() && strides_attr.getSplatValue<int32_t>() == 1;
2097   int32_t strides_size = strides_attr.getNumElements();
2098 
2099   // If all of the masks are set we can just bypass the entire thing.
2100   const int32_t all_masks_one = (1 << strides_size) - 1;
2101   if (all_strides_one && begin_mask == all_masks_one &&
2102       end_mask == all_masks_one) {
2103     return rewriter
2104         .create<tensor::CastOp>(op->getLoc(), result_type, input_value)
2105         .getResult();
2106   }
2107 
2108   if (failed(getVectorFromValue32(begin_value, begin))) {
2109     (void)rewriter.notifyMatchFailure(op, "begin isn't a constant");
2110     return llvm::None;
2111   }
2112 
2113   // If begin value is a constant we might be able to still bypass.
2114   for (auto val : llvm::enumerate(begin)) {
2115     if (val.value() == 0) begin_mask |= (0x1 << val.index());
2116   }
2117 
2118   if (all_strides_one && begin_mask == all_masks_one &&
2119       end_mask == all_masks_one) {
2120     return rewriter
2121         .create<tensor::CastOp>(op->getLoc(), result_type, input_value)
2122         .getResult();
2123   }
2124 
2125   if (failed(getVectorFromValue32(end_value, end))) {
2126     return (void)rewriter.notifyMatchFailure(op, "end isn't a constant"),
2127            llvm::None;
2128   }
2129 
2130   if (!input_type.hasRank()) {
2131     return (void)rewriter.notifyMatchFailure(op,
2132                                              "input type not ranked tensor."),
2133            llvm::None;
2134   }
2135 
2136   int32_t input_rank = input_type.getRank();
2137 
2138   if (failed(getVectorFromValue32(strides_value, strides))) {
2139     return (void)rewriter.notifyMatchFailure(op, "strides isn't a constant"),
2140            llvm::None;
2141   }
2142 
2143   // If strides is incomplete, pad out to the full size.
2144   while (strides.size() < input_rank) strides.push_back(1);
2145 
2146   // Unspecified begins should set the begin mask.
2147   while (begin.size() < input_rank) {
2148     begin_mask = begin_mask | (1 << begin.size());
2149     begin.push_back(0);
2150   }
2151 
2152   // Unspecified ends should set the end mask.
2153   while (end.size() < input_rank) {
2154     end_mask = end_mask | (1 << end.size());
2155     end.push_back(-1);
2156   }
2157 
2158   auto input_shape = input_type.getShape();
2159 
2160   SmallVector<int64_t> a1_begin(input_rank), a1_size(input_rank);
2161   SmallVector<int64_t> a2_shape(input_rank * 2);
2162   SmallVector<int64_t> a3_begin(input_rank * 2), a3_size(input_rank * 2);
2163   SmallVector<int64_t> a4_shape;
2164 
2165   // Step 0: Process the begin/end masks and build the begin/sizes for the
2166   // first slice
2167   int residual = 1;
2168   (void)residual;
2169   for (int i = 0; i < input_rank; i++) {
2170     if (begin_mask & (1 << i)) begin[i] = 0;
2171 
2172     if (end_mask & (1 << i)) end[i] = input_shape[i];
2173 
2174     // Wrap around the index if begin or end is negative
2175     if (begin[i] < 0) begin[i] += input_shape[i];
2176 
2177     if (end[i] < 0) end[i] += input_shape[i];
2178 
2179     // TODO(suderman): support reverse stride
2180     a1_begin[i] = begin[i];
2181     a1_size[i] = end[i] - begin[i];
2182 
2183     a2_shape[i * 2 + 0] = a1_size[i] / strides[i];
2184     a2_shape[i * 2 + 1] = strides[i];
2185 
2186     a3_begin[i * 2 + 0] = 0;
2187     a3_begin[i * 2 + 1] = 0;
2188 
2189     if (shrink_axis_mask & (1 << i)) {
2190       a3_size[i * 2 + 0] = 1;
2191     } else {
2192       a3_size[i * 2 + 0] = a1_size[i] / strides[i];
2193     }
2194     a3_size[i * 2 + 1] = 1;
2195 
2196     if (!(shrink_axis_mask & (1 << i))) {
2197       if (new_axis_mask & (1 << i)) a4_shape.push_back(1);
2198       a4_shape.push_back((a1_size[i] / strides[i]));
2199     }
2200   }
2201 
2202   // Make sure we didn't lose any dimensions from the shrink_axis_mask
2203   assert(residual == 1);
2204 
2205   // Step 1: Slice the input array
2206   auto a1_slice_op = rewriter.create<tosa::SliceOp>(
2207       op->getLoc(), RankedTensorType::get(a1_size, input_type.getElementType()),
2208       input_value, rewriter.getI64ArrayAttr(a1_begin),
2209       rewriter.getI64ArrayAttr(a1_size));
2210 
2211   if (all_strides_one) {
2212     return rewriter
2213         .create<tensor::CastOp>(op->getLoc(), result_type, a1_slice_op)
2214         .getResult();
2215   }
2216 
2217   // Step 2: reshape the sliced array
2218   auto a2_reshape_op = rewriter.create<tosa::ReshapeOp>(
2219       op->getLoc(),
2220       RankedTensorType::get(a2_shape, input_type.getElementType()),
2221       a1_slice_op.getResult(), rewriter.getI64ArrayAttr(a2_shape));
2222 
2223   // Step 3: take a slice along the strides
2224   auto a3_slice_op = rewriter.create<tosa::SliceOp>(
2225       op->getLoc(), RankedTensorType::get(a3_size, input_type.getElementType()),
2226       a2_reshape_op.getResult(), rewriter.getI64ArrayAttr(a3_begin),
2227       rewriter.getI64ArrayAttr(a3_size));
2228 
2229   // Step 4: reshape the now-strided tensor
2230   return rewriter
2231       .create<tosa::ReshapeOp>(op->getLoc(), result_type,
2232                                a3_slice_op.getResult(),
2233                                rewriter.getI64ArrayAttr(a4_shape))
2234       .getResult();
2235 }
2236 
2237 // Lowers FloorDiv to a sequence of TOSA operators.
2238 llvm::Optional<Value> convertFloorDivOp(PatternRewriter& rewriter,
2239                                         Operation* op, Value result_value,
2240                                         Value lhs_value, Value rhs_value) {
2241   // FloorDiv lowering:
2242   // floor(1/rhs * lhs)
2243   //
2244   // a1 = reciprocal(rhs);
2245   // a2 = mul(lhs, a1);
2246   // a3 = floor(a2);
2247   // return a3;
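  // For example (float, hypothetical values): lhs = 7.0, rhs = 2.0 gives
  // a1 = 0.5, a2 = 3.5, a3 = 3.0. Integer inputs take the tosa.DIV path below
  // instead.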
2248   RankedTensorType output_type =
2249       result_value.getType().dyn_cast<RankedTensorType>();
2250   // Not a ranked tensor output
2251   if (!output_type) return llvm::None;
2252 
2253   Type element_type = output_type.getElementType();
2254 
2255   if (element_type.isa<IntegerType>()) {
2256     return rewriter
2257         .create<tosa::DivOp>(op->getLoc(), output_type, lhs_value, rhs_value)
2258         .getResult();
2259   }
2260 
2261   auto a1_reciprocal_rhs_op = rewriter.create<tosa::ReciprocalOp>(
2262       op->getLoc(), rhs_value.getType(), rhs_value);
2263   auto a2_mul_lhs_a1_op =
2264       rewriter.create<tosa::MulOp>(op->getLoc(), output_type, lhs_value,
2265                                    a1_reciprocal_rhs_op.getResult(), 0);
2266   return rewriter
2267       .create<tosa::FloorOp>(op->getLoc(), output_type,
2268                              a2_mul_lhs_a1_op.getResult())
2269       .getResult();
2270 }
2271 
2272 // Lowers FloorMod to a sequence of TOSA operators.
2273 llvm::Optional<Value> convertFloorModOp(PatternRewriter& rewriter,
2274                                         Operation* op, Value result_value,
2275                                         Value lhs_value, Value rhs_value) {
2276   // FloorMod lowering:
2277   // (1/rhs * lhs) - floor(1/rhs * lhs)
2278   // a1 = reciprocal(rhs);
2279   // a2 = mul(lhs, a1);
2280   // a3 = floor(a2);
2281   // a4 = sub(a2, a3);
2282   // return a4;
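  // For example (float, hypothetical values): lhs = 7.0, rhs = 3.0 gives
  // a2 = 7.0 / 3.0 ~= 2.333, a3 = 2.0, a4 ~= 0.333, i.e. the fractional part
  // of the quotient.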
2283 
2284   RankedTensorType output_type =
2285       result_value.getType().dyn_cast<RankedTensorType>();
2286   // Not a ranked tensor output
2287   if (!output_type) return llvm::None;
2288 
2289   auto a1_reciprocal_rhs_op = rewriter.create<tosa::ReciprocalOp>(
2290       op->getLoc(), rhs_value.getType(), rhs_value);
2291   auto a2_mul_lhs_a1_op =
2292       rewriter.create<tosa::MulOp>(op->getLoc(), output_type, lhs_value,
2293                                    a1_reciprocal_rhs_op.getResult(), 0);
2294   auto a3_floor_a2_op = rewriter.create<tosa::FloorOp>(
2295       op->getLoc(), output_type, a2_mul_lhs_a1_op.getResult());
2296   return rewriter
2297       .create<tosa::SubOp>(op->getLoc(), output_type,
2298                            a2_mul_lhs_a1_op.getResult(),
2299                            a3_floor_a2_op.getResult())
2300       .getResult();
2301 }
2302 
2303 // Lowers FusedActivation to a sequence of TOSA ops.
2304 llvm::Optional<Value> convertFusedActivation(PatternRewriter& rewriter,
2305                                              Operation* op, Value input_value,
2306                                              StringAttr fused_activation_fn) {
2307   ShapedType input_type = input_value.getType().dyn_cast<ShapedType>();
2308   if (!input_type) return llvm::None;
2309 
2310   bool input_is_qtype =
2311       input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2312 
2313   if (input_is_qtype) {
2314     // We can always make the output/input tensor's scale/zp the same when
2315     // legalizing fused_activation_function, as it's generated during
2316     // legalization.
2317     auto input_qtype =
2318         input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2319 
2320     if (fused_activation_fn.getValue() == "NONE") {
2321       return input_value;
2322     } else if (fused_activation_fn.getValue() == "RELU") {
2323       int32_t quantized_0 = input_qtype.getZeroPoint();
2324       int32_t quantized_max = input_qtype.getStorageTypeMax();
2325 
2326       auto clamp_op = rewriter.create<tosa::ClampOp>(
2327           op->getLoc(), input_type, input_value,
2328           rewriter.getI64IntegerAttr(quantized_0),
2329           rewriter.getI64IntegerAttr(quantized_max),
2330           rewriter.getF32FloatAttr(0), rewriter.getF32FloatAttr(0));
2331 
2332       return clamp_op.getResult();
2333     } else if (fused_activation_fn.getValue() == "RELU6") {
2334       int32_t quantized_0 = input_qtype.getZeroPoint();
2335       int32_t quantized_6 = std::llround((6.0f / input_qtype.getScale()) +
2336                                          input_qtype.getZeroPoint());
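      // For example (hypothetical quantization parameters): with scale =
      // 6.0 / 255 and zero point = -128, quantized_0 = -128 and
      // quantized_6 = llround(255.0 + (-128.0)) = 127, so the clamp covers the
      // full int8 range for a [0, 6] activation.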
2337 
2338       auto clamp_op = rewriter.create<tosa::ClampOp>(
2339           op->getLoc(), input_type, input_value,
2340           rewriter.getI64IntegerAttr(quantized_0),
2341           rewriter.getI64IntegerAttr(quantized_6), rewriter.getF32FloatAttr(0),
2342           rewriter.getF32FloatAttr(0));
2343 
2344       return clamp_op.getResult();
2345     } else if (fused_activation_fn.getValue() == "RELU_N1_TO_1") {
2346       int32_t quantized_n1 = std::llround((-1.0f / input_qtype.getScale()) +
2347                                           input_qtype.getZeroPoint());
2348       int32_t quantized_1 = std::llround((1.0f / input_qtype.getScale()) +
2349                                          input_qtype.getZeroPoint());
2350 
2351       auto clamp_op = rewriter.create<tosa::ClampOp>(
2352           op->getLoc(), input_type, input_value,
2353           rewriter.getI64IntegerAttr(quantized_n1),
2354           rewriter.getI64IntegerAttr(quantized_1), rewriter.getF32FloatAttr(0),
2355           rewriter.getF32FloatAttr(0));
2356 
2357       return clamp_op.getResult();
2358     } else {
2359       op->emitWarning("convertFusedActivation: Not implemented yet");
2360       return llvm::None;
2361     }
2362   } else {
2363     if (fused_activation_fn.getValue() == "NONE") {
2364       return input_value;
2365     } else {
2366       // For non-quantized type, only support F32.
2367       if (!input_type.getElementType().isF32()) {
2368         op->emitOpError("convertFusedActivation: only supports F32");
2369         return llvm::None;
2370       }
2371 
2372       if (fused_activation_fn.getValue() == "RELU") {
2373         return rewriter
2374             .create<tosa::ClampOp>(
2375                 op->getLoc(), input_type, input_value,
2376                 rewriter.getI64IntegerAttr(0),
2377                 rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
2378                 rewriter.getF32FloatAttr(0.0f),
2379                 rewriter.getF32FloatAttr(std::numeric_limits<float>::max()))
2380             .getResult();
2381       } else if (fused_activation_fn.getValue() == "RELU6") {
2382         return rewriter
2383             .create<tosa::ClampOp>(
2384                 op->getLoc(), input_type, input_value,
2385                 rewriter.getI64IntegerAttr(0), rewriter.getI64IntegerAttr(6),
2386                 rewriter.getF32FloatAttr(0.0f), rewriter.getF32FloatAttr(6.0f))
2387             .getResult();
2388       } else if (fused_activation_fn.getValue() == "RELU_N1_TO_1") {
2389         return rewriter
2390             .create<tosa::ClampOp>(
2391                 op->getLoc(), input_type, input_value,
2392                 rewriter.getI64IntegerAttr(-1), rewriter.getI64IntegerAttr(1),
2393                 rewriter.getF32FloatAttr(-1.0), rewriter.getF32FloatAttr(1.0))
2394             .getResult();
2395       } else if (fused_activation_fn.getValue() == "TANH") {
2396         return rewriter
2397             .create<tosa::TanhOp>(op->getLoc(), input_type, input_value)
2398             .getResult();
2399       } else {
2400         // Unsupported activation type. Bail out.
2401         return llvm::None;
2402       }
2403     }
2404   }
2405 
2406   return llvm::None;
2407 }
2408 
2409 // Common function for lowering reduce operations to TOSA ops.
2410 template <typename T>
2411 llvm::Optional<Value> convertReduceOpCommon(
2412     PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2413     Value input_value, ElementsAttr axes_elems, bool keep_dims,
2414     Type reduce_element_type, bool is_quantized, double input_scale,
2415     int64_t input_zp, double output_scale, int64_t output_zp) {
2416   RankedTensorType input_type =
2417       input_value.getType().dyn_cast<RankedTensorType>();
2418   if (!input_type) return llvm::None;
2419 
2420   ArrayRef<int64_t> input_shape = input_type.getShape();
2421   ArrayRef<int64_t> output_shape = output_type.getShape();
2422   auto input_rank = input_shape.size();
2423   Value val = input_value;
2424 
2425   if (axes_elems.getNumElements() == 0) {
2426     // No axes means return the original tensor.
2427     auto identity_op =
2428         rewriter.create<tosa::IdentityOp>(op->getLoc(), output_type, val);
2429     val = identity_op.getResult();
2430   } else {
2431     // Reduce along each axis
2432     SmallVector<int64_t> shape_vec(input_shape.begin(), input_shape.end());
2433 
2434     if (is_quantized) {
2435       val = buildRescaleToInt32(rewriter, op, val, input_scale, input_zp);
2436     }
2437 
2438     for (int i = 0; i < axes_elems.getNumElements(); i++) {
2439       int64_t axis_val = axes_elems.getValue<IntegerAttr>(i).getInt();
2440       if (axis_val < 0) axis_val += input_rank;
2441       auto axis_attr = rewriter.getI64IntegerAttr(axis_val);
2442 
2443       shape_vec[axis_val] = 1;
2444       RankedTensorType reduce_type =
2445           RankedTensorType::get(shape_vec, reduce_element_type);
2446 
2447       auto reduce_op =
2448           rewriter.create<T>(op->getLoc(), reduce_type, val, axis_attr);
2449 
2450       val = reduce_op.getResult();
2451     }
2452 
2453     if (is_quantized) {
2454       RankedTensorType output_rescale_type =
2455           RankedTensorType::get(shape_vec, output_type.getElementType());
2456       val = buildRescale(rewriter, op, output_rescale_type, val, output_scale,
2457                          0, output_zp, false, true);
2458     }
2459 
2460     // Optionally squeeze out the reduced axes.
2461     if (!keep_dims) {
2462       auto reshape_op = rewriter.create<tosa::ReshapeOp>(
2463           op->getLoc(), output_type, val,
2464           rewriter.getI64ArrayAttr(output_shape));
2465       val = reshape_op.getResult();
2466     }
2467   }
2468 
2469   return val;
2470 }
2471 
2472 // Lowers ReduceAll to a sequence of TOSA ops.
2473 llvm::Optional<Value> convertReduceAllOp(
2474     PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2475     Value input_value, ElementsAttr axes_elems, bool keep_dims) {
2476   RankedTensorType input_type =
2477       input_value.getType().dyn_cast<RankedTensorType>();
2478   if (!input_type) return llvm::None;
2479 
2480   return convertReduceOpCommon<tosa::ReduceAllOp>(
2481       rewriter, op, output_type, input_value, axes_elems, keep_dims,
2482       output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
2483 }
2484 
2485 // Lowers ReduceAny to a sequence of TOSA ops.
2486 llvm::Optional<Value> convertReduceAnyOp(
2487     PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2488     Value input_value, ElementsAttr axes_elems, bool keep_dims) {
2489   RankedTensorType input_type =
2490       input_value.getType().dyn_cast<RankedTensorType>();
2491   if (!input_type) return llvm::None;
2492 
2493   return convertReduceOpCommon<tosa::ReduceAnyOp>(
2494       rewriter, op, output_type, input_value, axes_elems, keep_dims,
2495       output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
2496 }
2497 
2498 // Lowers ReduceMin to a sequence of TOSA ops.
2499 llvm::Optional<Value> convertReduceMinOp(
2500     PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2501     Value input_value, ElementsAttr axes_elems, bool keep_dims) {
2502   RankedTensorType input_type =
2503       input_value.getType().dyn_cast<RankedTensorType>();
2504   if (!input_type) return llvm::None;
2505 
2506   return convertReduceOpCommon<tosa::ReduceMinOp>(
2507       rewriter, op, output_type, input_value, axes_elems, keep_dims,
2508       output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
2509 }
2510 
2511 // Lowers ReduceMax to a sequence of TOSA ops.
2512 llvm::Optional<Value> convertReduceMaxOp(
2513     PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2514     Value input_value, ElementsAttr axes_elems, bool keep_dims) {
2515   RankedTensorType input_type =
2516       input_value.getType().dyn_cast<RankedTensorType>();
2517   if (!input_type) return llvm::None;
2518 
2519   return convertReduceOpCommon<tosa::ReduceMaxOp>(
2520       rewriter, op, output_type, input_value, axes_elems, keep_dims,
2521       output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
2522 }
2523 
2524 // Lowers ReduceProd to a sequence of TOSA ops.
2525 llvm::Optional<Value> convertReduceProdOp(
2526     PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2527     Value input_value, ElementsAttr axes_elems, bool keep_dims) {
2528   RankedTensorType input_type =
2529       input_value.getType().dyn_cast<RankedTensorType>();
2530   if (!input_type) return llvm::None;
2531 
2532   bool input_is_qtype =
2533       input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2534   bool output_is_qtype =
2535       output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2536 
2537   if (input_is_qtype || output_is_qtype) {
2538     op->emitOpError(
2539         "ConvertReduceProdOp: input/output tensor should "
2540         "be all floating-point.");
2541     return llvm::None;
2542   }
2543 
2544   return convertReduceOpCommon<tosa::ReduceProdOp>(
2545       rewriter, op, output_type, input_value, axes_elems, keep_dims,
2546       output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
2547 }
2548 
2549 // Lowers ReduceSum to a sequence of TOSA ops.
2550 llvm::Optional<Value> convertReduceSumOp(
2551     PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2552     Value input_value, ElementsAttr axes_elems, bool keep_dims) {
2553   RankedTensorType input_type =
2554       input_value.getType().dyn_cast<RankedTensorType>();
2555   if (!input_type) return llvm::None;
2556 
2557   bool input_is_qtype =
2558       input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2559   bool output_is_qtype =
2560       output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2561 
2562   if (input_is_qtype != output_is_qtype) {
2563     op->emitOpError(
2564         "ConvertReduceSumOp: input/output tensor should "
2565         "be all quantized or all floating-point.");
2566     return llvm::None;
2567   }
2568 
2569   double input_scale = 1.0f;
2570   double output_scale = 1.0f;
2571   int64_t input_zp = 0;
2572   int64_t output_zp = 0;
2573   Type reduce_element_type = input_type.getElementType();
2574 
2575   if (input_is_qtype) {
2576     auto input_qtype =
2577         input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2578     auto output_qtype =
2579         output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2580 
2581     int32_t input_shift = 20;
2582 
2583     input_scale =
2584         static_cast<double>(1 << input_shift) * input_qtype.getScale();
2585     output_scale =
2586         1.0 / (output_qtype.getScale() * static_cast<double>(1 << input_shift));
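    // Note: the (1 << input_shift) factor is multiplied into the input scale
    // and divided back out of the output scale, so the net scaling is
    // input_qtype.scale / output_qtype.scale; the extra bits just give the
    // int32 accumulator more fractional precision during the summation.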
2587 
2588     input_zp = input_qtype.getZeroPoint();
2589     output_zp = output_qtype.getZeroPoint();
2590     reduce_element_type = rewriter.getI32Type();
2591   }
2592 
2593   return convertReduceOpCommon<tosa::ReduceSumOp>(
2594       rewriter, op, output_type, input_value, axes_elems, keep_dims,
2595       reduce_element_type, input_is_qtype, input_scale, input_zp, output_scale,
2596       output_zp);
2597 }
2598 
2599 // Lowers ReduceMean to a sequence of TOSA ops.
2600 llvm::Optional<Value> convertReduceMeanOp(
2601     PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2602     Value input_value, ElementsAttr axes_elems, bool keep_dims) {
2603   // reduce_mean is lowered as follows:
2604   // op1 = reduce_sum(input)
2605   // op2 = mul(op1, 1.0 / num_elements_on_reduced_axis)
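  // For example (hypothetical shapes): reducing a [2, 3, 4] input over
  // axes {1, 2} gives num_elems_on_reduced_axis = 3 * 4 = 12, so the
  // reduce_sum result is multiplied by 1.0 / 12.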
2606 
2607   RankedTensorType input_type =
2608       input_value.getType().dyn_cast<RankedTensorType>();
2609   if (!input_type) return llvm::None;
2610 
2611   bool input_is_qtype =
2612       input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2613   bool output_is_qtype =
2614       output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2615 
2616   if (input_is_qtype != output_is_qtype) {
2617     op->emitOpError(
2618         "ConvertReduceMeanOp: input/output tensor should "
2619         "be all quantized or all floating-point.");
2620     return llvm::None;
2621   }
2622 
2623   // Only supports float type mean() if it's non-quantized
2624   if (!input_is_qtype && !output_type.getElementType().isa<mlir::FloatType>()) {
2625     op->emitWarning(
2626         "Failed convertReduceMeanOp: input is unquantized but the output "
2627         "element type is not FloatType!");
2628     return llvm::None;
2629   }
2630 
2631   int64_t input_rank = input_type.getRank();
2632   int64_t num_elems_on_reduced_axis = 1;
2633   for (int i = 0; i < axes_elems.getNumElements(); i++) {
2634     int64_t axis_val = axes_elems.getValue<IntegerAttr>(i).getInt();
2635     if (axis_val < 0) axis_val += input_rank;
2636     num_elems_on_reduced_axis *= input_type.getShape()[axis_val];
2637   }
2638   double div_scale = 1.0 / static_cast<double>(num_elems_on_reduced_axis);
2639 
2640   double input_scale = 1.0f;
2641   double output_scale = 1.0f;
2642   int64_t input_zp = 0;
2643   int64_t output_zp = 0;
2644   Type reduce_element_type = input_type.getElementType();
2645 
2646   if (input_is_qtype) {
2647     auto input_qtype =
2648         input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2649     auto output_qtype =
2650         output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2651 
2652     // Combine 'div_scale' as part of output rescale
2653     output_scale = div_scale * input_qtype.getScale() / output_qtype.getScale();
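    // E.g. (hypothetical values): averaging 12 elements with an input scale
    // of 0.5 and an output scale of 0.25 gives
    // output_scale = (1 / 12) * 0.5 / 0.25 = 1 / 6.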
2654 
2655     input_zp = input_qtype.getZeroPoint();
2656     output_zp = output_qtype.getZeroPoint();
2657     reduce_element_type = rewriter.getI32Type();
2658   }
2659 
2660   auto val = convertReduceOpCommon<tosa::ReduceSumOp>(
2661       rewriter, op, output_type, input_value, axes_elems, keep_dims,
2662       reduce_element_type, input_is_qtype, input_scale, input_zp, output_scale,
2663       output_zp);
2664 
2665   if (!val.hasValue()) return llvm::None;
2666 
2667   if (!input_is_qtype) {
2668     Value div_const = getTosaConstTensorSingleF32(rewriter, op, div_scale);
2669     return rewriter
2670         .create<tosa::MulOp>(op->getLoc(), output_type, val.getValue(),
2671                              div_const, 0)
2672         .getResult();
2673   }
2674 
2675   return val;
2676 }
2677 
2678 // Lowers ResizeBilinear and ResizeNearestNeighbor to TOSA resize.
2679 llvm::Optional<Value> convertResizeOp(PatternRewriter& rewriter, Operation* op,
2680                                       RankedTensorType output_type,
2681                                       Value input_value, StringRef mode,
2682                                       bool align_corners,
2683                                       bool half_pixel_centers) {
2684   RankedTensorType input_type =
2685       input_value.getType().dyn_cast<RankedTensorType>();
2686   if (!input_type) return llvm::None;
2687 
2688   if (input_type.getRank() != 4 || output_type.getRank() != 4) {
2689     op->emitOpError("convertResizeOp: input/output must be rank 4");
2690     return llvm::None;
2691   }
2692 
2693   bool input_is_qtype =
2694       input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2695   bool output_is_qtype =
2696       output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2697 
2698   if (input_is_qtype != output_is_qtype) {
2699     op->emitOpError(
2700         "ConvertResizeOp: input/output tensor should "
2701         "be all quantized or all floating-point.");
2702     return llvm::None;
2703   }
2704 
2705   if (!input_is_qtype) {
2706     if (!input_type.getElementType().isa<mlir::FloatType>()) {
2707       op->emitOpError(
2708           "ConvertResizeOp: only quantized or float types supported.");
2709       return llvm::None;
2710     }
2711   }
2712 
2713   auto input_shape = input_type.getShape();
2714   auto output_shape = output_type.getShape();
2715 
2716   size_t input_height = input_shape[1];
2717   size_t input_width = input_shape[2];
2718   size_t output_height = output_shape[1];
2719   size_t output_width = output_shape[2];
2720 
2721   double fp_stride_y =
2722       static_cast<double>(input_height) / static_cast<double>(output_height);
2723   double fp_stride_x =
2724       static_cast<double>(input_width) / static_cast<double>(output_width);
2725   if (align_corners && output_height > 1) {
2726     fp_stride_y = static_cast<double>(input_height - 1) /
2727                   static_cast<double>(output_height - 1);
2728   }
2729   if (align_corners && output_width > 1) {
2730     fp_stride_x = static_cast<double>(input_width - 1) /
2731                   static_cast<double>(output_width - 1);
2732   }
2733 
2734   double fp_offset_y, fp_offset_x;
2735   if (half_pixel_centers) {
2736     fp_offset_y = fp_stride_y * 0.5f - 0.5f;
2737     fp_offset_x = fp_stride_x * 0.5f - 0.5f;
2738   } else {
2739     fp_offset_y = 0.0f;
2740     fp_offset_x = 0.0f;
2741   }
2742 
2743   // Output row oy maps to input row iy = oy * fp_stride_y + fp_offset_y.
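  // For example (hypothetical sizes): input_height = 4 and output_height = 8
  // with half_pixel_centers gives fp_stride_y = 4 / 8 = 0.5 and
  // fp_offset_y = 0.5 * 0.5 - 0.5 = -0.25.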
2744 
2745   ArrayAttr output_size =
2746       rewriter.getI64ArrayAttr({static_cast<int64_t>(output_height),
2747                                 static_cast<int64_t>(output_width)});
2748   StringAttr resize_mode = rewriter.getStringAttr(mode);
2749 
2750   if (input_is_qtype) {
2751     // Magic shift value used by TFLite resize bilinear.
2752     // Reference: tensorflow/lite/kernels/internal/reference/reference_ops.h
2753     int32_t shift = 10;
2754 
2755     // 1.0 is equivalent to (1 << shift) in quantized space.
2756     // Here we denote unit = (1 << shift).
2757     double unit = static_cast<double>(1 << shift);
2758 
2759     // Stride and offset are int16.
2760     int32_t stride_y = std::lround(fp_stride_y * unit);
2761     int32_t stride_x = std::lround(fp_stride_x * unit);
2762     int32_t offset_y = std::lround(fp_offset_y * unit);
2763     int32_t offset_x = std::lround(fp_offset_x * unit);
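    // Continuing the example above (hypothetical values): with shift = 10,
    // unit = 1024, so fp_stride_y = 0.5 becomes stride_y = 512 and
    // fp_offset_y = -0.25 becomes offset_y = -256.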
2764 
2765     // Numerically we could decrement the shift to make these numbers fit within
2766     // 16 bits, but that is uncommon and would not match the TFLite reference.
2767     if (stride_y > std::numeric_limits<int16_t>::max() ||
2768         stride_x > std::numeric_limits<int16_t>::max() ||
2769         stride_y < std::numeric_limits<int16_t>::min() ||
2770         stride_x < std::numeric_limits<int16_t>::min() ||
2771         offset_y > std::numeric_limits<int16_t>::max() ||
2772         offset_x > std::numeric_limits<int16_t>::max() ||
2773         offset_y < std::numeric_limits<int16_t>::min() ||
2774         offset_x < std::numeric_limits<int16_t>::min()) {
2775       op->emitOpError("OpResize: stride or offset out of 16 bits");
2776       return llvm::None;
2777     }
2778 
2779     ArrayAttr stride = rewriter.getI64ArrayAttr({stride_y, stride_x});
2780     ArrayAttr offset = rewriter.getI64ArrayAttr({offset_y, offset_x});
2781     IntegerAttr shift_attr = rewriter.getI32IntegerAttr(shift);
2782 
2783     // If quantized bilinear mode, need to lower to RESIZE + RESCALE pair.
2784     if (mode == "BILINEAR") {
2785       RankedTensorType output_acc_type;
2786       auto input_element_qtype =
2787           input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2788 
2789       bool scale32;
2790 
2791       // TOSA RESIZE: 16 bit input -> 48 bit output, or 8 bit input -> 32 bit
2792       // output.
2793       if (input_element_qtype.getStorageTypeIntegralWidth() == 16) {
2794         scale32 = false;
2795         output_acc_type = RankedTensorType::get(output_type.getShape(),
2796                                                 rewriter.getIntegerType(48));
2797       } else if (input_element_qtype.getStorageTypeIntegralWidth() == 8) {
2798         scale32 = true;
2799         output_acc_type = RankedTensorType::get(output_type.getShape(),
2800                                                 rewriter.getI32Type());
2801       } else {
2802         op->emitOpError("OpResize: only 16-bit and 8-bit quantized input supported");
2803         return llvm::None;
2804       }
2805 
2806       auto resize_op = rewriter.create<tosa::ResizeOp>(
2807           op->getLoc(), output_acc_type, input_value, output_size, stride,
2808           offset, shift_attr, rewriter.getF32ArrayAttr({0.0, 0.0}),
2809           rewriter.getF32ArrayAttr({0.0, 0.0}), resize_mode);
2810 
2811 #ifdef RESIZE_BILINEAR_LOWER_SYMMETRIC_ROUNDING
2812       // TFLite resize_bilinear always assumes the input and output
2813       // tensors have the same scale, so we only need an arithmetic right
2814       // shift by (2 * shift).
2815       // TODO(suderman): Align TFLite rounding behavior
2816       // TFLite uses symmetric rounding by computing 'x / (1 << 20)', while
2817       // the TOSA arithmetic right shift performs standard rounding.
2818       // Right now this is legalized using GreaterEqualOp + SelectOp to
2819       // conform to the TFLite reference, but this should eventually be
2820       // fixed in the TFLite reference.
2821       Value cst_zero = getTosaConstTensorSingleI32(rewriter, op, 0);
2822       Value cst_twenty = getTosaConstTensorSingleI32(rewriter, op, 20);
2823 
2824       auto ge_op = rewriter.create<tosa::GreaterEqualOp>(
2825           op->getLoc(), output_bool_type, resize_op.getResult(), cst_zero);
2826 
2827       auto abs_op = rewriter.create<tosa::AbsOp>(op->getLoc(), output_acc_type,
2828                                                  resize_op.getResult());
2829 
2830       auto rshift_op = rewriter.create<tosa::ArithmeticRightShiftOp>(
2831           op->getLoc(), output_acc_type, abs_op.getResult(), cst_twenty, true);
2832 
2833       auto negate_op = rewriter.create<tosa::NegateOp>(
2834           op->getLoc(), output_acc_type, rshift_op.getResult());
2835 
2836       auto select_op = rewriter.create<tosa::SelectOp>(
2837           op->getLoc(), output_acc_type, ge_op.getResult(),
2838           rshift_op.getResult(), negate_op.getResult());
2839 
2840       auto cast_op = rewriter.create<tosa::CastOp>(op->getLoc(), output_type,
2841                                                    select_op.getResult());
2842 
2843       return cast_op.getResult();
2844 #else
2845       // This should be the expected lowering, but it is within +/-1 of the
2846       // TFLite reference.
2847       return buildRescale(rewriter, op, output_type, resize_op.getResult(),
2848                           1.0 / (1 << 20), 0, 0, false, scale32);
2849 #endif
2850 
2851     } else if (mode == "NEAREST_NEIGHBOR") {
2852       auto resize_op = rewriter.create<tosa::ResizeOp>(
2853           op->getLoc(), output_type, input_value, output_size, stride, offset,
2854           shift_attr, rewriter.getF32ArrayAttr({0.0, 0.0}),
2855           rewriter.getF32ArrayAttr({0.0, 0.0}), resize_mode);
2856       return resize_op.getResult();
2857     } else {
2858       op->emitOpError(
2859           "OpResize: only BILINEAR and NEAREST_NEIGHBOR modes are supported");
2860       return llvm::None;
2861     }
2862   } else {
2863     auto resize_op = rewriter.create<tosa::ResizeOp>(
2864         op->getLoc(), output_type, input_value, output_size,
2865         rewriter.getI64ArrayAttr({0, 0}), rewriter.getI64ArrayAttr({0, 0}),
2866         rewriter.getI32IntegerAttr(0),
2867         rewriter.getF32ArrayAttr(
2868             {static_cast<float>(fp_stride_y), static_cast<float>(fp_stride_x)}),
2869         rewriter.getF32ArrayAttr(
2870             {static_cast<float>(fp_offset_y), static_cast<float>(fp_offset_x)}),
2871         resize_mode);
2872     return resize_op.getResult();
2873   }
2874 }
2875 
2876 // Lowers Quantize to a sequence of TOSA quantization ops.
2877 llvm::Optional<Value> convertQuantizeOp(PatternRewriter& rewriter,
2878                                         Operation* op,
2879                                         RankedTensorType output_type,
2880                                         Value input_value, double scale,
2881                                         int64_t zeropoint) {
2882   RankedTensorType input_type =
2883       input_value.getType().dyn_cast<RankedTensorType>();
2884   if (!input_type) return llvm::None;
2885 
2886   auto output_shape = output_type.getShape();
2887   auto output_element_type = output_type.getElementType();
2888 
2889   // The output element type can only be a quantized type.
2890   if (!output_element_type.isa<mlir::quant::QuantizedType>()) {
2891     op->emitWarning(
2892         "Lowering quantizeOp but output element type not quantized!");
2893     return llvm::None;
2894   }
2895 
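  // The sequence below computes q = cast(x * scale + zeropoint): the multiply
  // and add are performed in float32, and the final cast converts the result
  // to the quantized output type.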
2896   RankedTensorType output_fp_type =
2897       RankedTensorType::get(output_shape, rewriter.getF32Type());
2898 
2899   Value zp_val =
2900       getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(zeropoint));
2901 
2902   auto op1_mul_in = rewriter.create<tosa::MulOp>(
2903       op->getLoc(), output_fp_type, input_value,
2904       getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(scale)), 0);
2905 
2906   auto op2_add_op1 = rewriter.create<tosa::AddOp>(
2907       op->getLoc(), output_fp_type, op1_mul_in.getResult(), zp_val);
2908 
2909   auto op3_cast_op2 = rewriter.create<tosa::CastOp>(op->getLoc(), output_type,
2910                                                     op2_add_op1.getResult());
2911 
2912   return op3_cast_op2.getResult();
2913 }
2914 
2915 // Lowers Dequantize to a sequence of TOSA dequantization ops.
2916 llvm::Optional<Value> convertDequantizeOp(
2917     PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2918     Value input_value, ArrayRef<float> scale, ArrayRef<float> zeropoint,
2919     int64_t dim) {
2920   RankedTensorType input_type =
2921       input_value.getType().dyn_cast<RankedTensorType>();
2922   if (!input_type) return llvm::None;
2923 
2924   // The input element type can only be a quantized type.
2925   if (!input_type.getElementType().isa<mlir::quant::QuantizedType>())
2926     return llvm::None;
2927 
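  // The sequence below computes y = (cast(q) - zeropoint) * scale.  When
  // scale/zeropoint carry more than one value, they are reshaped so they
  // broadcast along dimension 'dim' (per-channel dequantization).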
2928   Optional<Value> zp_val;
2929   if (zeropoint.size() == 1) {
2930     zp_val = getTosaConstTensorSingleF32(rewriter, op,
2931                                          static_cast<float>(zeropoint[0]));
2932   } else {
2933     SmallVector<int64_t> shape;
2934     shape.resize(input_type.getRank(), 1);
2935     shape[dim] = zeropoint.size();
2936     zp_val = getConstTensor(rewriter, op, zeropoint, shape);
2937   }
2938 
2939   Optional<Value> scale_val;
2940   if (scale.size() == 1) {
2941     scale_val =
2942         getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(scale[0]));
2943   } else {
2944     SmallVector<int64_t> shape;
2945     shape.resize(input_type.getRank(), 1);
2946     shape[dim] = scale.size();
2947     scale_val = getConstTensor(rewriter, op, scale, shape);
2948   }
2949 
2950   if (!zp_val || !scale_val) return llvm::None;
2951 
2952   auto op1_cast_in =
2953       rewriter.create<tosa::CastOp>(op->getLoc(), output_type, input_value);
2954 
2955   auto op2_sub_op1 = rewriter.create<tosa::SubOp>(
2956       op->getLoc(), output_type, op1_cast_in.getResult(), zp_val.getValue());
2957 
2958   return rewriter
2959       .create<tosa::MulOp>(op->getLoc(), output_type, op2_sub_op1.getResult(),
2960                            scale_val.getValue(), 0)
2961       .getResult();
2962 }
2963 
2964 // Lowers FakeQuant to a sequence of TOSA quantization ops.
2965 llvm::Optional<Value> convertFakeQuantOp(PatternRewriter& rewriter,
2966                                          Operation* op,
2967                                          RankedTensorType output_type,
2968                                          Value input_value, double min,
2969                                          double max, int64_t num_bits,
2970                                          bool narrow_range) {
2971   // FakeQuant is lowered as follows:
2972   // op1 = quantize(input)
2973   // op2 = dequantize(op1)
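  // In terms of the TOSA ops built below, the overall computation is:
  //   out = floor((clamp(x, nudged_min, nudged_max) - nudged_min)
  //               / nudged_scale + 0.5) * nudged_scale + nudged_min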
2974 
2975   RankedTensorType input_type =
2976       input_value.getType().dyn_cast<RankedTensorType>();
2977   if (!input_type) return llvm::None;
2978 
2979   // Quantized as INT<num_bits>, where num_bits can only be 8 or 16.
2980   if (num_bits != 8 && num_bits != 16) {
2981     op->emitWarning("FakeQuantOp lowering handles only 8 and 16 for num_bits!");
2982     return llvm::None;
2983   }
2984 
2985   // This code originates from
2986   // tensorflow/core/kernels/fake_quant_ops_functor.h.
2987   int32_t qmax = (1 << (num_bits)) - 1;
2988   int32_t qmin = narrow_range ? 1 : 0;
2989 
2990   float nudged_min, nudged_max, nudged_scale;
2991   tensorflow_nudge(min, max, qmin, qmax, &nudged_min, &nudged_max,
2992                    &nudged_scale);
2993 
2994   Value cst_min = getTosaConstTensorSingleF32(rewriter, op, nudged_min);
2995   Value cst_max = getTosaConstTensorSingleF32(rewriter, op, nudged_max);
2996   Value cst_scale = getTosaConstTensorSingleF32(rewriter, op, nudged_scale);
2997   Value cst_inv_scale =
2998       getTosaConstTensorSingleF32(rewriter, op, 1.0f / nudged_scale);
2999   Value cst_half = getTosaConstTensorSingleF32(rewriter, op, 0.5f);
3000 
3001   // This code originates from
3002   // tensorflow/core/kernels/fake_quant_ops_functor.h.
3003   auto op1_min_in = rewriter.create<tosa::MinimumOp>(op->getLoc(), output_type,
3004                                                      input_value, cst_max);
3005 
3006   auto op2_max_op1 = rewriter.create<tosa::MaximumOp>(
3007       op->getLoc(), output_type, op1_min_in.getResult(), cst_min);
3008 
3009   auto op3_sub_op2 = rewriter.create<tosa::SubOp>(
3010       op->getLoc(), output_type, op2_max_op1.getResult(), cst_min);
3011 
3012   auto op4_mul_op3 = rewriter.create<tosa::MulOp>(
3013       op->getLoc(), output_type, op3_sub_op2.getResult(), cst_inv_scale, 0);
3014 
3015   auto op5_add_op4 = rewriter.create<tosa::AddOp>(
3016       op->getLoc(), output_type, op4_mul_op3.getResult(), cst_half);
3017 
3018   auto op6_floor_op5 = rewriter.create<tosa::FloorOp>(op->getLoc(), output_type,
3019                                                       op5_add_op4.getResult());
3020 
3021   auto op7_mul_op6 = rewriter.create<tosa::MulOp>(
3022       op->getLoc(), output_type, op6_floor_op5.getResult(), cst_scale, 0);
3023 
3024   return rewriter
3025       .create<tosa::AddOp>(op->getLoc(), output_type, op7_mul_op6.getResult(),
3026                            cst_min)
3027       .getResult();
3028 }
3029 
3030 llvm::Optional<Value> convertTFConv2DCommon(
3031     PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
3032     Value input, Value filter, Value bias, ArrayAttr strides_attr,
3033     ArrayAttr dilations_attr, ArrayAttr explicit_padding_attr,
3034     StringRef padding_ref, StringRef data_format_ref) {
3035   RankedTensorType input_type = input.getType().dyn_cast<RankedTensorType>();
3036   RankedTensorType filter_type = filter.getType().dyn_cast<RankedTensorType>();
3037   // Bail out if the input or filter is not a ranked tensor.
3038   if (!input_type) return llvm::None;
3039   if (!filter_type) return llvm::None;
3040 
3041   // Transpose [H, W, I, O] to [O, H, W, I]
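  // For example (hypothetical shape): a TF filter of shape [3, 3, 8, 16]
  // (HWIO) is permuted with {3, 0, 1, 2} to [16, 3, 3, 8] (OHWI), the
  // layout tosa.conv2d expects.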
3042   auto filter_shape = filter_type.getShape();
3043   SmallVector<int64_t, 4> a1_transpose_dims;
3044   a1_transpose_dims.push_back(filter_shape[3]);
3045   a1_transpose_dims.push_back(filter_shape[0]);
3046   a1_transpose_dims.push_back(filter_shape[1]);
3047   a1_transpose_dims.push_back(filter_shape[2]);
3048   llvm::Optional<Value> a1_filter_transpose_perm = getConstTensor<int32_t>(
3049       rewriter, op, /*vec=*/{3, 0, 1, 2}, /*shape=*/{4});
3050 
3051   if (!a1_filter_transpose_perm) return llvm::None;
3052 
3053   auto a1_filter_transpose_op = rewriter.create<tosa::TransposeOp>(
3054       op->getLoc(),
3055       RankedTensorType::get(a1_transpose_dims, filter_type.getElementType()),
3056       filter, a1_filter_transpose_perm.getValue());
3057 
3058   // Only NHWC is supported for now.
3059   if (data_format_ref.str() != "NHWC") {
3060     op->emitWarning("convertTFConv2DCommon only supports NHWC!");
3061     return llvm::None;
3062   }
3063 
3064   ArrayAttr stride;
3065   ArrayAttr dilation;
3066   ArrayAttr pad;
3067   {
3068     if (!strides_attr) {
3069       stride = rewriter.getI64ArrayAttr({1, 1});
3070     } else {
3071       // Note: hardcoded to NHWC for now
3072       int64_t stride_h = strides_attr[1].cast<IntegerAttr>().getInt();
3073       int64_t stride_w = strides_attr[2].cast<IntegerAttr>().getInt();
3074       stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
3075     }
3076   }
3077   {
3078     if (!dilations_attr) {
3079       dilation = rewriter.getI64ArrayAttr({1, 1});
3080     } else {
3081       // Note: hardcoded to NHWC for now
3082       int64_t dilation_h = dilations_attr[1].cast<IntegerAttr>().getInt();
3083       int64_t dilation_w = dilations_attr[2].cast<IntegerAttr>().getInt();
3084       dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w});
3085     }
3086   }
3087   {
3088     tensorflow::Padding tf_pad;
3089     if (!GetPaddingFromString(padding_ref.str(), &tf_pad).ok()) {
3090       op->emitWarning("Could not get padding data from padding string term!");
3091       return llvm::None;
3092     }
3093 
3094     tensorflow::TensorFormat data_format_tf;
3095     if (!FormatFromString(data_format_ref.str(), &data_format_tf))
3096       return llvm::None;
3097 
3098     if (tf_pad == tensorflow::Padding::EXPLICIT) {
3099       pad = getPaddingValuesFromExplicitPadAttr(explicit_padding_attr,
3100                                                 data_format_tf, rewriter);
3101     } else {
3102       if (!getPaddingValuesFromPadType(tf_pad, data_format_tf,
3103                                        0,  // tensorflow::FORMAT_HWIO
3104                                        input_type, filter_type, stride,
3105                                        dilation, rewriter, pad))
3106         return llvm::None;
3107     }
3108   }
3109 
3110   return rewriter
3111       .create<tosa::Conv2DOp>(op->getLoc(), output_type, input,
3112                               a1_filter_transpose_op.getResult(), bias, pad,
3113                               stride, dilation)
3114       .getResult();
3115 }
3116 
3117 // Lowers Gather operators to a sequence of TOSA ops.
3118 llvm::Optional<Value> convertGatherOp(PatternRewriter& rewriter, Operation* op,
3119                                       Value result_value, Value params_value,
3120                                       Value indices_value, int32_t batch_dims,
3121                                       int32_t axis) {
3122   auto result_type = result_value.getType().dyn_cast<RankedTensorType>();
3123   auto params_type = params_value.getType().dyn_cast<RankedTensorType>();
3124   auto indices_type = indices_value.getType().dyn_cast<RankedTensorType>();
3125 
3126   if (!result_type || !params_type || !indices_type) return llvm::None;
3127 
3128   // batch_dims indicates the number of batch dimensions in params and
3129   // indices.  axis indicates the axis at which the gather indexing is
3130   // applied; axis must be >= batch_dims.  When axis is equal to
3131   // batch_dims, the right-most batch dimension disappears.
3132   //
3133   // N: number of batches
3134   // Computed as product of params.shape[0:batch_dims-1]
3135   //
3136   // W: number of indices in each batch
3137   // Computed as product of indices.shape[batch_dims:]
3138   //
3139   // K: range of each index
3140   // Computed as params.shape[axis]
3141   //
3142   // C: number of channels for each index
3143   // Computed as:  LeftChannels * RightChannels:
3144   // product(params.shape[batch_dims:axis]) * product(params.shape[axis+1:])
3145   //
3146   // The params tensor needs to be transposed, then reshaped to move the
3147   // dimensions into [N, K, C] order.
3148   //
3149   // The dimensions of the input params[] tensor are grouped in the following
3150   // order to begin with:
3151   //
3152   //  [Batch, LeftChannels, Indices, RightChannels]
3153   //  |-----||------------||-------||-------------|
3154   //     N         C_l         K          C_r
3155   //
3156   // Where Batch (N), Indices (K) can be one or more dimensions in size,
3157   // while LeftChannels and RightChannels represent the group of data channels
3158   // (C) to the left and right (C_l, C_r) of the indices; the sum of these two
3159   // is one or more dimensions in size, but either one may be zero depending
3160   // on how axis was specified by the caller.
3161   //
3162   // The resulting tensor will look like:
3163   //
3164   //  [Batch, Indices, LeftChannels, RightChannels]
3165   //  |-----||-------||---------------------------|
3166   //     N       K                 C
3167   //
3168   // The indices tensor simply needs a reshape to flatten all of the
3169   // batch dimensions (N) together and flatten all of the indices (W)
3170   // together.
3171   //
3172   // Then do the tosa.GATHER
3173   //
3174   // output[N,W,C] = tosa.GATHER(values[N,K,C], indices[N,W])
3175   //
3176   // Finally, the resulting tensor will have shape [N, W, C], where C is a
3177   // flattened version of [LeftChannels, RightChannels].  We need to reshape
3178   // to unflatten to:
3179   //
3180   //  [N, W, LeftChannels, RightChannels]
3181   //
3182   // and finally transpose back to the output shape
3183   //
3184   //  [Batch, LeftChannels, Non-Batch-Indices, RightChannels]
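  //
  // Worked example (hypothetical shapes): params of shape [2, 3, 5, 7] and
  // indices of shape [2, 4] with batch_dims = 1 and axis = 2 give
  // N = 2, W = 4, K = 5, and C = 3 * 7 = 21.  The tosa.GATHER result of
  // shape [2, 4, 21] is then reshaped and transposed back to the final
  // output shape [2, 3, 4, 7].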
3185 
3186   int N = 1, W = 1, K = 1, C = 1;
3187 
3188   int params_rank = params_type.getShape().size();
3189   int indices_rank = indices_type.getShape().size();
3190 
3191   if (!(batch_dims <= indices_rank)) {
3192     op->emitOpError("Batch_dims must be <= indices_rank for a valid gather op");
3193     return llvm::None;
3194   }
3195 
3196   if (!(axis >= batch_dims)) {
3197     op->emitOpError("axis must be >= batch_dims for a valid gather op");
3198     return llvm::None;
3199   }
3200 
3201   // Sizes for each of these fields.
3202   SmallVector<int64_t> params_batch, params_indices, params_left_channels,
3203       params_right_channels;
3204 
3205   // Dimension indices for each of these fields.
3206   SmallVector<int64_t> params_idx_batch, params_idx_indices,
3207       params_idx_left_channels, params_idx_right_channels;
3208 
3209   // Read through the params tensor dimensions left-to-right and extract the
3210   // different fields.
3211   for (int i = 0; i < params_rank; i++) {
3212     // When batch_dims == axis, the batch dimension gets replaced.
3213     if (i < batch_dims && i < axis) {
3214       params_batch.push_back(params_type.getShape()[i]);
3215       params_idx_batch.push_back(i);
3216     } else if (i < axis) {
3217       params_left_channels.push_back(params_type.getShape()[i]);
3218       params_idx_left_channels.push_back(i);
3219     } else if (i < (axis + 1)) {
3220       params_indices.push_back(params_type.getShape()[i]);
3221       params_idx_indices.push_back(i);
3222     } else {
3223       params_right_channels.push_back(params_type.getShape()[i]);
3224       params_idx_right_channels.push_back(i);
3225     }
3226   }
3227 
3228   // Calculate N, K, W, C
3229   for (int i = 0; i < batch_dims; i++) N *= params_type.getShape()[i];
3230 
3231   for (int i = batch_dims; i < indices_rank; i++)
3232     W *= indices_type.getShape()[i];
3233 
3234   K = params_type.getShape()[axis];
3235 
3236   for (int i = batch_dims; i < axis; i++) C *= params_type.getShape()[i];
3237   for (int i = (axis + 1); i < params_rank; i++) C *= params_type.getShape()[i];
3238 
3239   // Check for obviously invalid values before doing a divide.
3240   if (N <= 0 || K <= 0 || W <= 0 || C <= 0) {
3241     op->emitOpError(
3242         "N, K, W, or C was calculated as <= zero.  Invalid dimensions for "
3243         "Gather");
3244     return llvm::None;
3245   }
3246 
3247   /////////////////////////////////////////////
3248   // Build up the params transpose operator
3249   SmallVector<int32_t> params_transpose_perm;
3250   SmallVector<int64_t> params_transpose_shape;
3251 
3252   // Batch
3253   for (int i = 0; i < params_batch.size(); i++) {
3254     params_transpose_perm.push_back(params_idx_batch[i]);
3255     params_transpose_shape.push_back(params_batch[i]);
3256   }
3257 
3258   // Indices
3259   for (int i = 0; i < params_indices.size(); i++) {
3260     params_transpose_perm.push_back(params_idx_indices[i]);
3261     params_transpose_shape.push_back(params_indices[i]);
3262   }
3263 
3264   // LeftChannels
3265   for (int i = 0; i < params_left_channels.size(); i++) {
3266     params_transpose_perm.push_back(params_idx_left_channels[i]);
3267     params_transpose_shape.push_back(params_left_channels[i]);
3268   }
3269 
3270   // RightChannels
3271   for (int i = 0; i < params_right_channels.size(); i++) {
3272     params_transpose_perm.push_back(params_idx_right_channels[i]);
3273     params_transpose_shape.push_back(params_right_channels[i]);
3274   }
3275 
3276   /////////////////////////////////////////////
3277   // Build up the result reshape, in preparation for the transpose:
3278   // [N, W, C] -> [ Batch, Indices, LeftChannels, RightChannels ]
3279   SmallVector<int64_t> result_reshape_shape;
3280 
3281   // Indices
3282   for (int i = 0; i < indices_type.getShape().size(); i++) {
3283     result_reshape_shape.push_back(indices_type.getShape()[i]);
3284   }
3285 
3286   // Left channels
3287   for (int i = 0; i < params_left_channels.size(); i++) {
3288     result_reshape_shape.push_back(params_left_channels[i]);
3289   }
3290 
3291   // Right channels (the gathered axis dimension itself is not included).
3292   for (int i = 0; i < params_right_channels.size(); i++) {
3293     result_reshape_shape.push_back(params_right_channels[i]);
3294   }
3295 
3296   /////////////////////////////////////////////
3297   // Build up the result transpose operator.
3298   SmallVector<int32_t> result_transpose_perm;
3299 
3300   // Batch dimensions
3301   for (int i = 0; i < batch_dims; i++) {
3302     result_transpose_perm.push_back(i);
3303   }
3304 
3305   // LeftChannels
3306   for (int i = 0; i < params_left_channels.size(); i++) {
3307     result_transpose_perm.push_back(i + indices_type.getShape().size());
3308   }
3309 
3310   // Indices (remainder of dimensions after batch).
3311   for (int i = batch_dims; i < (indices_type.getShape().size()); i++) {
3312     result_transpose_perm.push_back(i);
3313   }
3314 
3315   // RightChannels, coming from after both the Indices and LeftChannels.
3316   for (int i = 0; i < params_right_channels.size(); i++) {
3317     result_transpose_perm.push_back(i + indices_type.getShape().size() +
3318                                     params_left_channels.size());
3319   }
3320 
3321   SmallVector<int64_t> tosa_values_shape = {N, K, C};
3322   SmallVector<int64_t> tosa_indices_shape = {N, W};
3323   SmallVector<int64_t> tosa_gather_result_shape = {N, W, C};
3324 
3325   llvm::Optional<Value> params_transpose_perm_val = getConstTensor<int32_t>(
3326       rewriter, op, params_transpose_perm,
3327       {static_cast<int64_t>(params_transpose_perm.size())});
3328 
3329   llvm::Optional<Value> result_transpose_perm_val = getConstTensor<int32_t>(
3330       rewriter, op, result_transpose_perm,
3331       {static_cast<int64_t>(result_transpose_perm.size())});
3332 
3333   if (!params_transpose_perm_val || !result_transpose_perm_val)
3334     return llvm::None;
3335 
3336   auto params_transpose_op = rewriter.create<tosa::TransposeOp>(
3337       op->getLoc(),
3338       RankedTensorType::get(params_transpose_shape,
3339                             params_type.getElementType()),
3340       params_value, params_transpose_perm_val.getValue());
3341 
3342   auto tosa_values_reshape_op = rewriter.create<tosa::ReshapeOp>(
3343       op->getLoc(),
3344       RankedTensorType::get(tosa_values_shape, params_type.getElementType()),
3345       params_transpose_op.getResult(),
3346       rewriter.getI64ArrayAttr(tosa_values_shape));
3347 
3348   auto tosa_indices_reshape_op = rewriter.create<tosa::ReshapeOp>(
3349       op->getLoc(),
3350       RankedTensorType::get(tosa_indices_shape, indices_type.getElementType()),
3351       indices_value, rewriter.getI64ArrayAttr(tosa_indices_shape));
3352 
3353   auto tosa_gather_op = rewriter.create<tosa::GatherOp>(
3354       op->getLoc(),
3355       RankedTensorType::get(tosa_gather_result_shape,
3356                             result_type.getElementType()),
3357       tosa_values_reshape_op.getResult(), tosa_indices_reshape_op.getResult());
3358 
3359   auto tosa_result_reshape_op = rewriter.create<tosa::ReshapeOp>(
3360       op->getLoc(),
3361       RankedTensorType::get(result_reshape_shape, params_type.getElementType()),
3362       tosa_gather_op.getResult(),
3363       rewriter.getI64ArrayAttr(result_reshape_shape));
3364 
3365   return rewriter
3366       .create<tosa::TransposeOp>(op->getLoc(), result_type,
3367                                  tosa_result_reshape_op.getResult(),
3368                                  result_transpose_perm_val.getValue())
3369       .getResult();
3370 }
3371 
3372 // Lowers Gather operators to a sequence of TOSA ops.
3373 llvm::Optional<Value> convertGatherNdOp(PatternRewriter& rewriter,
3374                                         Operation* op, Value result_value,
3375                                         Value params_value, Value indices_value)
3376 
3377 {
3378   auto result_type = result_value.getType().dyn_cast<RankedTensorType>();
3379   auto params_type = params_value.getType().dyn_cast<RankedTensorType>();
3380   auto indices_type = indices_value.getType().dyn_cast<RankedTensorType>();
3381 
3382   if (!result_type || !params_type || !indices_type) return llvm::None;
3383 
3384   // N: number of batches
3385   // Always 1 for GatherND
3386   //
3387   // Because TOSA's GATHER operator already uses the symbol 'N' for
3388   // the number of batches, we will use the symbol 'ND' to specify the
3389   // number of dimensions that are sliced from params, instead of 'N' as in
3390   // the TF MLIR documentation.
3391   //
3392   // ND: indices.shape[-1]
3393   //
3394   // W: number of indices in each batch
3395   // Computed as:
3396   // product(indices.shape[0:-1]) (all but the last dimension)
3397   //
3398   // K: range of each index
3399   // Computed as:
3400   // product(params.shape[0:ND-1])
3401   //
3402   // C: number of channels for each index
3403   // Computed as:
3404   // product(params.shape[ND:])
3405   //
3406   // The params tensor needs to be reshaped, but not transposed, to move the
3407   // dimensions into [N, K, C] order.
3408   //
3409   // The dimensions of the input params[] tensor are grouped in the following
3410   // order to begin with:
3411   //
3412   //  [ParamIndices, ParamChannels]
3413   //  |------------||-------------|
3414   //         K              C
3415   //
3416   // The reshape simply flattens the params tensor into a 2D [K, C] shape.
3417   //
3418   // Indices needs to be put in the form of [N, W], but a simple flattening
3419   // will not suffice, because the indices need to index into a [W]-shape
3420   // vector instead of the params.shape[0:ND-1] tensor that we had before.
3421   //
3422   // To flatten the coordinates, first reshape indices to a [W, ND] matrix,
3423   // where the matrix now represents W ND-dimensional coordinates into the
3424   // params tensor.
3425   //
3426   // From here, we take each of the ND dimensions and multiply it by
3427   // the size of the next params dimension (or 1 for the last
3428   // dimension), then sum all these together with a reduce_sum
3429   // operator.  This is exactly the same mathematics one would use to
3430   // flatten the indices of an N-dimensional row-major array into a
3431   // 1-D array in C.
3432   //
3433   // More precisely, do an element-wise multiply with [params.shape[1
3434   // .. ND], 1] in axis 1, then reduce_sum in axis 1 to flatten to a
3435   // [W]-shaped tensor, then trivially reshape to [N=1, W] to be
3436   // compatible with the GATHER operator's shape.
3437   //
3438   // Then perform the tosa.GATHER() operation.
3439   //
3440   // Now we have result = [N, K, C].
3441   //
3442   // Reshape with a single, simple reshape to the final output shape of:
3443   //  [Indices, ParamChannels]
3444   //
3445   // where Indices is indices.shape[0:-1] (all but the last dimension).
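  //
  // Worked example (hypothetical shapes): params of shape [3, 4, 5] and
  // indices of shape [6, 2] give ND = 2, N = 1, W = 6, K = 3 * 4 = 12 and
  // C = 5, so the tosa.GATHER result of shape [1, 6, 5] is reshaped to the
  // final output shape [6, 5].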
3446 
3447   int N = 1, W = 1, K = 1, C = 1, ND = 1;
3448 
3449   int params_rank = params_type.getShape().size();
3450   int indices_rank = indices_type.getShape().size();
3451 
3452   ND = indices_type.getShape()[indices_rank - 1];
3453 
3454   if (ND > params_rank) {
3455     op->emitOpError("Size of last dimension on indices must be <= params rank");
3456     return llvm::None;
3457   }
3458 
3459   // Calculate N, K, W, C.  (N is always 1)
3460   for (int i = 0; i < (indices_rank - 1); i++) {
3461     W *= indices_type.getShape()[i];
3462   }
3463 
3464   for (int i = 0; i < ND; i++) {
3465     K *= params_type.getShape()[i];
3466   }
3467 
3468   for (int i = ND; i < params_rank; i++) {
3469     C *= params_type.getShape()[i];
3470   }
3471 
3472   SmallVector<int64_t, 3> tosa_values_shape({N, K, C});
3473   SmallVector<int64_t, 2> tosa_indices_shape({N, W});
3474   SmallVector<int64_t, 2> indices_matrix_shape({W, ND});
3475   SmallVector<int64_t, 3> tosa_gather_result_shape({N, W, C});
3476 
3477   auto tosa_values_reshape_op = rewriter.create<tosa::ReshapeOp>(
3478       op->getLoc(),
3479       RankedTensorType::get(tosa_values_shape, params_type.getElementType()),
3480       params_value, rewriter.getI64ArrayAttr(tosa_values_shape));
3481 
3482   // Flatten the input indices tensor to an [W, ND] matrix.
3483   auto indices_matrix_reshape_op = rewriter.create<tosa::ReshapeOp>(
3484       op->getLoc(),
3485       RankedTensorType::get(indices_matrix_shape,
3486                             indices_type.getElementType()),
3487       indices_value, rewriter.getI64ArrayAttr(indices_matrix_shape));
3488 
3489   SmallVector<int32_t> flattened_coeff_vec;
3490   for (int i = 1; i < ND; i++) {
3491     flattened_coeff_vec.push_back(params_type.getShape()[i]);
3492   }
3493   flattened_coeff_vec.push_back(1);
3494 
3495   for (int i = ND - 1; i > 0; i--) {
3496     flattened_coeff_vec[i - 1] *= flattened_coeff_vec[i];
3497   }
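  // With the example shapes above (params of shape [3, 4, 5], ND = 2) this
  // yields flattened_coeff_vec = {4, 1}: a coordinate (i, j) flattens to
  // i * 4 + j.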
3498 
3499   llvm::Optional<Value> flattened_coeff_value = getConstTensor<int32_t>(
3500       rewriter, op, flattened_coeff_vec,
3501       {static_cast<int64_t>(flattened_coeff_vec.size())});
3502 
3503   if (!flattened_coeff_value) return llvm::None;
3504 
3505   // Multiply the coefficients by the coordinates
3506   auto flattened_indices_mul_op = rewriter.create<tosa::MulOp>(
3507       op->getLoc(),
3508       RankedTensorType::get(indices_matrix_shape,
3509                             indices_type.getElementType()),
3510       indices_matrix_reshape_op.getResult(), flattened_coeff_value.getValue(),
3511       0);
3512 
3513   // Sum up the products of the coefficients and coordinates
3514   auto flattened_indices_reduce_op = rewriter.create<tosa::ReduceSumOp>(
3515       op->getLoc(),
3516       RankedTensorType::get(tosa_indices_shape, indices_type.getElementType()),
3517       flattened_indices_mul_op.getResult(), rewriter.getI64IntegerAttr(1));
3518 
3519   // And reshape to [N, W]
3520   auto tosa_indices_reshape_op = rewriter.create<tosa::ReshapeOp>(
3521       op->getLoc(),
3522       RankedTensorType::get(tosa_indices_shape, indices_type.getElementType()),
3523       flattened_indices_reduce_op.getResult(),
3524       rewriter.getI64ArrayAttr(tosa_indices_shape));
3525 
3526   // Now the gather op itself
3527   auto tosa_gather_op = rewriter.create<tosa::GatherOp>(
3528       op->getLoc(),
3529       RankedTensorType::get(tosa_gather_result_shape,
3530                             result_type.getElementType()),
3531       tosa_values_reshape_op.getResult(), tosa_indices_reshape_op.getResult());
3532 
3533   // Finally, reshape back to the original output shape of [Indices,
3534   // ParamChannels].
3535   return rewriter
3536       .create<tosa::ReshapeOp>(op->getLoc(), result_type,
3537                                tosa_gather_op.getResult(),
3538                                rewriter.getI64ArrayAttr(result_type.getShape()))
3539       .getResult();
3540 }
3541 
3542 // Lowers OneHot operator to a sequence of TOSA ops.
3543 llvm::Optional<Value> convertOneHotOp(PatternRewriter& rewriter, Operation* op,
3544                                       Value result_value, Value indices_value,
3545                                       Value on_value, Value off_value,
3546                                       int32_t depth, int32_t axis) {
3547   auto result_type = result_value.getType().dyn_cast<RankedTensorType>();
3548   auto indices_type = indices_value.getType().dyn_cast<RankedTensorType>();
3549   auto on_value_type = on_value.getType().dyn_cast<RankedTensorType>();
3550   auto off_value_type = off_value.getType().dyn_cast<RankedTensorType>();
3551 
3552   if (!result_type || !indices_type || !on_value_type || !off_value_type)
3553     return llvm::None;
3554 
3555   // OneHot operator creates a new tensor with shape indices.shape[:axis] +
3556   // [depth] + indices.shape[axis:].  Each index in 'indices' needs to be
3557   // within the range [0, depth - 1]; then [..., k, ...] = on_value if k ==
3558   // index, and [..., k, ...] = off_value if k != index.
3559   //
3560   // The lowering below assumes depth is always known at compile time.
3561   // Handling a depth resolved at run time is TBD.
3562   //
3563   // OneHot can be lowered as TOSA Scatter, where off_value being mapped to
3564   // 'values_in', on_value being mapped to 'input', and indices naturally
3565   // mapped to 'indices'. Also the dimensions of TOSA scatter (N, W, K, C)
3566   // need to be picked.
3567   //
3568   // N: number of elements of input indices
3569   // Computed as:
3570   // product(indices.shape[:])
3571   //
3572   // K: newly added dimension
3573   // K = depth
3574   //
3575   // W, C: dummy dimensions for now
3576   // W = C = 1
3577   //
3578   // High level description of lowering looks like:
3579   // 1. off_value is reshaped/tiled into [N, K, C]
3580   // 2. on_value is reshaped/tiled into [N, W, C]
3581   // 3. indices is reshaped into [N, W]
3582   // 4. scatter into [N, K, C]
3583   // 5. reshaped into [LeftDims, RightDims, K]
3584   // 6. transpose into [LeftDims, K, RightDims]
3585   // 7. reshaped to result.shape
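  //
  // Worked example (hypothetical shapes): indices of shape [2, 3] with
  // depth = 4 and axis = 1 give N = 6, K = 4, W = C = 1, left_dim = 2 and
  // right_dim = 3.  The scatter result of shape [6, 4, 1] is reshaped to
  // [2, 3, 4], transposed to [2, 4, 3], and finally reshaped to the result
  // shape [2, 4, 3].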
3586 
3587   if (on_value_type.getRank() != 0 || off_value_type.getRank() != 0) {
3588     op->emitOpError("OneHotOp: on_value/off_value need to be scalars");
3589     return llvm::None;
3590   }
3591 
3592   if (axis < -1 || axis > indices_type.getRank()) {
3593     op->emitOpError("OneHotOp: axis out of valid range [-1, indices.rank]");
3594     return llvm::None;
3595   }
3596 
3597   // axis = -1 is equivalent to axis = indices.rank
3598   if (axis == -1) {
3599     axis = indices_type.getRank();
3600   }
3601 
3602   int N = 1, W = 1, C = 1;
3603   int K = depth;
3604   int left_dim = 1, right_dim = 1;
3605 
3606   for (int32_t i = 0; i < indices_type.getRank(); i++) {
3607     int32_t dim = indices_type.getShape()[i];
3608     N *= dim;
3609     if (i >= axis) {
3610       right_dim *= dim;
3611     } else {
3612       left_dim *= dim;
3613     }
3614   }
3615 
3616   // Reshape on_value to [1, 1, 1]
3617   auto op1_reshape_on_value = rewriter.create<tosa::ReshapeOp>(
3618       op->getLoc(),
3619       RankedTensorType::get({1, 1, 1}, on_value_type.getElementType()),
3620       on_value, rewriter.getI64ArrayAttr({1, 1, 1}));
3621 
3622   // And tile to [N, W, C]
3623   auto op2_tile_op1 = rewriter.create<tosa::TileOp>(
3624       op->getLoc(),
3625       RankedTensorType::get({N, W, C}, on_value_type.getElementType()),
3626       op1_reshape_on_value.getResult(), rewriter.getI64ArrayAttr({N, W, C}));
3627 
3628   // Reshape off_value to [1, 1, 1]
3629   auto op3_reshape_off_value = rewriter.create<tosa::ReshapeOp>(
3630       op->getLoc(),
3631       RankedTensorType::get({1, 1, 1}, off_value_type.getElementType()),
3632       off_value, rewriter.getI64ArrayAttr({1, 1, 1}));
3633 
3634   // And tile to [N, K, C]
3635   auto op4_tile_op3 = rewriter.create<tosa::TileOp>(
3636       op->getLoc(),
3637       RankedTensorType::get({N, K, C}, on_value_type.getElementType()),
3638       op3_reshape_off_value.getResult(), rewriter.getI64ArrayAttr({N, K, C}));
3639 
3640   // Reshape indices to [N, W]
3641   auto op5_reshape_indices = rewriter.create<tosa::ReshapeOp>(
3642       op->getLoc(),
3643       RankedTensorType::get({N, W}, indices_type.getElementType()),
3644       indices_value, rewriter.getI64ArrayAttr({N, W}));
3645 
3646   // Scatter to [N, K, C]
3647   auto op6_scatter_op4_op5_op2 = rewriter.create<tosa::ScatterOp>(
3648       op->getLoc(),
3649       RankedTensorType::get({N, K, C}, result_type.getElementType()),
3650       op4_tile_op3.getResult(), op5_reshape_indices.getResult(),
3651       op2_tile_op1.getResult());
3652 
3653   // Reshaped to [LeftDims, RightDims, K]; C is squeezed out since it is 1.
3654   auto op7_reshape_op6 = rewriter.create<tosa::ReshapeOp>(
3655       op->getLoc(),
3656       RankedTensorType::get({left_dim, right_dim, K},
3657                             result_type.getElementType()),
3658       op6_scatter_op4_op5_op2.getResult(),
3659       rewriter.getI64ArrayAttr({left_dim, right_dim, K}));
3660 
3661   // Transposed to [LeftDims, K, RightDims].
3662   llvm::Optional<Value> perm_const =
3663       getConstTensor<int32_t>(rewriter, op, /*vec=*/{0, 2, 1}, /*shape=*/{3});
3664 
3665   if (!perm_const) return llvm::None;
3666 
3667   auto op8_transpose_op7 = rewriter.create<tosa::TransposeOp>(
3668       op->getLoc(),
3669       RankedTensorType::get({left_dim, K, right_dim},
3670                             result_type.getElementType()),
3671       op7_reshape_op6.getResult(), perm_const.getValue());
3672 
3673   // Reshaped to result.shape.
3674   return rewriter
3675       .create<tosa::ReshapeOp>(op->getLoc(), result_type,
3676                                op8_transpose_op7.getResult(),
3677                                rewriter.getI64ArrayAttr(result_type.getShape()))
3678       .getResult();
3679 }
3680 
3681 };  // namespace tosa
3682 };  // namespace mlir
3683