1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file contains legalizations common to mapping both TensorFlow and
17 // TensorFlow Lite to TOSA. It operates generically on ops and does not have
18 // a hard reference on either dialect.
19 //
20 // Conversion functions return llvm::None on a legalization failure or a
21 // legalized value on success.  Callers must check for presence of an
22 // llvm::Optional value after each call.
23 
24 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
25 
26 #include <climits>
27 #include <cstddef>
28 #include <cstdint>
29 #include <iterator>
30 #include <numeric>
31 
32 #include "llvm/ADT/ArrayRef.h"
33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/Support/FormatVariadic.h"
35 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
36 #include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
37 #include "mlir/Dialect/Tosa/IR/TosaOps.h"  // from @llvm-project
38 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
39 #include "mlir/IR/Matchers.h"  // from @llvm-project
40 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
41 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
42 
43 namespace mlir {
44 namespace tosa {
45 
46 static int64_t multiply_dims(int64_t a, int64_t b) {
47   if (a == ShapedType::kDynamicSize || b == ShapedType::kDynamicSize) {
48     return ShapedType::kDynamicSize;
49   }
50   return a * b;
51 }
52 
53 static int64_t multiply_dims(llvm::ArrayRef<int64_t> dims, int64_t res = 1) {
54   for (auto dim : dims) {
55     if (ShapedType::isDynamic(dim)) {
56       return ShapedType::kDynamicSize;
57     }
58     res = res * dim;
59   }
60   return res;
61 }
62 
63 static int64_t count_dynamic_dims(llvm::ArrayRef<int64_t> dims) {
64   int64_t count = 0;
65   for (auto dim : dims)
66     if (ShapedType::isDynamic(dim)) ++count;
67   return count;
68 }
69 
70 namespace {
71 // Given an axis that can be a positive or negative value and the tensor size,
72 // return the adjusted axis value wrapped around the tensor size.
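// For example (illustrative, added commentary): with tensor_size = 4,
// adjust_axis(1, 4) == 1 and adjust_axis(-1, 4) == 3, i.e. -1 wraps around
// to the last axis.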
73 int32_t adjust_axis(int32_t axis, int32_t tensor_size) {
74   return axis >= 0 ? axis % tensor_size : (axis + tensor_size) % tensor_size;
75 }
76 };  // namespace
77 
78 // Copied Nudge implementation from
79 // tensorflow/core/kernels/fake_quant_ops_functor.h.
80 // Suggested approach to avoid significant TensorFlow
81 // build dependency.
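// Worked example (illustrative values, added commentary): for min = -1.0f,
// max = 1.0f, quant_min = 0, quant_max = 255 this computes
//   scale               ~= 2.0f / 255 ~= 0.0078431f
//   zero_point_from_min  = 0 - (-1.0f / scale) = 127.5f -> nudged to 128
//   nudged_min          ~= (0 - 128) * scale   ~= -1.0039f
//   nudged_max          ~= (255 - 128) * scale ~=  0.9961f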
82 void tensorflow_nudge(const float min, const float max, const int quant_min,
83                       const int quant_max, float* nudged_min, float* nudged_max,
84                       float* scale) {
85   const float quant_min_float = static_cast<float>(quant_min);
86   const float quant_max_float = static_cast<float>(quant_max);
87   *scale = (max - min) / (quant_max_float - quant_min_float);
88   const float zero_point_from_min = quant_min_float - min / *scale;
89   const uint16_t nudged_zero_point = [zero_point_from_min, quant_min,
90                                       quant_min_float, quant_max,
91                                       quant_max_float] {
92     if (zero_point_from_min < quant_min_float) {
93       return static_cast<uint16_t>(quant_min);
94     }
95     if (zero_point_from_min > quant_max_float) {
96       return static_cast<uint16_t>(quant_max);
97     }
98     return static_cast<uint16_t>(std::round(zero_point_from_min));
99   }();
100   *nudged_min = (quant_min_float - nudged_zero_point) * (*scale);
101   *nudged_max = (quant_max_float - nudged_zero_point) * (*scale);
102 }
103 
104 // Lowers the Pack operator to TOSA.
105 llvm::Optional<Value> convertPackOp(PatternRewriter& rewriter, Operation* op,
106                                     Value result_value,
107                                     SmallVectorImpl<Value>& inputs,
108                                     int32_t axis) {
109   //////////////////////////////////////////////////
110   // Operator: output = Pack([values], axis) or output = Stack([values], axis)
111   // Lowering:
112   //
113   // This operator is lowered into a series of pairwise tosa.concat()
114   // operators and a reshape.
115   // Depending on the inputs, a transpose operator is also generated:
116   //
117   // Step 1: concatenate the tensors
118   // a1_concat = tosa.concat(input[0], input[1], axis)
119   // for (i = 2; i < len(input); i++)
120   //   a1_concat = tosa.concat(a1_concat, input[i], axis)
121   //
122   // Step 2: reshape to N+1 dimensions
123   // a2_reshape = tosa.reshape(a1_concat, new_rank)
124   //
125   // Step 3: Transpose if a new dimension is being added:
126   // if (axis == rank(values[0]):
127   //   // perm will be [1, 2, 3, 0]
128   //   a3_transpose = tosa.transpose(a2_reshape, perm)
129 
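  // Illustrative example (added commentary, shapes chosen arbitrarily):
  // packing three [2, 3] tensors along axis=1 yields output shape [2, 3, 3]:
  //   concat along axis=1 -> [2, 9]
  //   reshape             -> [2, 3, 3]
  //   axis (1) != rank(input) (2), so no transpose is needed.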
130   // Sanity check 1: make sure all input tensors have the same shape
131   // if input[0] has shape [A, B, C], input[1] to input[N-1] should also have
132   // shape[A, B, C]
133   RankedTensorType result_type =
134       result_value.getType().dyn_cast<RankedTensorType>();
135 
136   // Check for ranked tensor type.
137   if (!result_type) {
138     op->emitOpError("PackOp: result type not ranked tensor");
139     return llvm::None;
140   }
141 
142   // Valid axis in TF is [-rank(input), rank(input))
143   // Valid axis in TOSA is [0, rank(input))
144   // Plus rank(input) once if axis is negative.
145   RankedTensorType input_type =
146       op->getOperand(0).getType().dyn_cast<RankedTensorType>();
147   if (!input_type) {
148     op->emitOpError("PackOp: input type not ranked tensor");
149     return llvm::None;
150   }
151 
152   input_type = inputs[0].getType().dyn_cast<RankedTensorType>();
153   if (!input_type) {
154     op->emitOpError("Input 0 type not ranked tensor.");
155     return llvm::None;
156   }
157   ArrayRef<int64_t> input0_tensor_shape = input_type.getShape();
158   int input_tensor_rank = input0_tensor_shape.size();
159 
160   for (int i = 1; i < inputs.size(); i++) {
161     input_type = inputs[i].getType().dyn_cast<RankedTensorType>();
162     if (!input_type) {
163       op->emitOpError(llvm::formatv(
164           "PackOp: input {0} type not ranked tensor",
165           i));
166       return llvm::None;
167     }
168     ArrayRef<int64_t> next_tensor_shape = input_type.getShape();
169     if (next_tensor_shape.size() != input_tensor_rank) {
170       op->emitOpError("PackOp: input tensor rank mismatch.");
171       return llvm::None;
172     }
173     for (int d = 0; d < input0_tensor_shape.size(); d++) {
174       if (input0_tensor_shape[d] != next_tensor_shape[d]) {
175         op->emitOpError("PackOp: input tensor shape mismatch.");
176         return llvm::None;
177       }
178     }
179   }
180 
181   // If input tensors are rank 0, should reshape them to rank 1 size 1 before
182   // performing concat.
183   if (input_tensor_rank == 0) {
184     SmallVector<int64_t, 1> reshape_rank1_size1_shape({1});
185     RankedTensorType reshape_rank1_size1_type = RankedTensorType::get(
186         reshape_rank1_size1_shape, result_type.getElementType());
187     ArrayAttr shape_rank1_size1_attr =
188         rewriter.getI64ArrayAttr(reshape_rank1_size1_shape);
189     for (int i = 0; i < inputs.size(); i++) {
190       auto a0_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
191           rewriter, op->getLoc(), reshape_rank1_size1_type, inputs[i],
192           shape_rank1_size1_attr);
193       inputs[i] = a0_reshape_op.getResult();
194     }
195   }
196 
197   // Sanity check 2: axis can be from [0, rank(input)+1]
198   // Where rank(input)+1 means create a new dimension
199   // Negative values are also allowed up to -(rank(input)+1)
200   // where the axis "wraps around".
201   if (axis < 0) axis += input_tensor_rank;
202   if ((axis < 0) || (axis > (input_tensor_rank + 1))) {
203     op->emitOpError("PackOp: axis out of valid range.");
204     return llvm::None;
205   }
206 
207   // Sanity check 2: if input shape is [A, B, C], output shape should be [N,
208   // A, B, C]
209   // 2.a check output is rank(input) + 1
210   SmallVector<int64_t> output_shape_vals(result_type.getShape().begin(),
211                                          result_type.getShape().end());
212   if (output_shape_vals.size() != (input_tensor_rank + 1)) {
213     op->emitOpError("PackOp: output tensor rank mismatch.");
214     return llvm::None;
215   }
216   // 2.b check output rank 0 is N
217   if (output_shape_vals[axis] != inputs.size()) {
218     op->emitOpError("PackOp: output tensor shape mismatch.");
219     return llvm::None;
220   }
221   // In most cases, PackOp.axis() is within [0, rank(input) - 1].
222   // We can directly concatenate along that axis and perform the reshape.
223   // For example, stack N [A, B, C] input tensor ranks along axis = 1
224   // after concatenation, output will be [A, N * B, C]
225   // and then reshape it into [A, N, B, C]
226   // a special case would be PackOp.axis() equal to rank(input), in which case
227   // we can't directly concatenate along the PackOp.axis(), instead
228   // we concat along axis=0, and reshape into [N, A, B, C]
229   // and then we need an extra transpose to [A, B, C, N].
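  // Illustrative example (added commentary): stacking four [2, 3, 5] tensors
  // with axis == rank(input) == 3 (output shape [2, 3, 5, 4]):
  //   concat along axis=0 -> [8, 3, 5]
  //   reshape             -> [4, 2, 3, 5]
  //   transpose with perm [1, 2, 3, 0] -> [2, 3, 5, 4]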
230   int64_t concat_axis;
231   SmallVector<int32_t> perm;
232   SmallVector<int64_t> reshape_output_shape;
233   if (axis == 0 && input_tensor_rank == 0) {
234     concat_axis = 0;
235   } else if (axis == input_tensor_rank) {
236     concat_axis = 0;
237 
238     // A special case when stack axis is equal to input tensor rank:
239     // Output shape is [A, B, C, N]
240     // so reshape output will be [N, A, B, C]
241     // and perm will be [1, 2, 3, 0].
242     reshape_output_shape.push_back(output_shape_vals[axis]);
243     for (int d = 0; d < input_tensor_rank; d++) {
244       perm.push_back(d + 1);
245       reshape_output_shape.push_back(output_shape_vals[d]);
246     }
247     perm.push_back(0);
248   } else {
249     // General case, doesn't need perm vector.
250     concat_axis = axis;
251     reshape_output_shape.assign(output_shape_vals.begin(),
252                                 output_shape_vals.end());
253   }
254   IntegerAttr concat_axis_attr = rewriter.getI64IntegerAttr(concat_axis);
255   ArrayAttr shape_attr = rewriter.getI64ArrayAttr(reshape_output_shape);
256 
257   // Concat output shape will depend on concat_axis. E.g. [N * A, B, C]
258   SmallVector<int64_t> concat_output_shape;
259   if (input_tensor_rank == 0) {
260     concat_output_shape.push_back(1);
261   } else {
262     for (int i = 0; i < input_tensor_rank; i++) {
263       concat_output_shape.push_back(input0_tensor_shape[i]);
264     }
265   }
266 
267   concat_output_shape[concat_axis] =
268       concat_output_shape[concat_axis] * inputs.size();
269   RankedTensorType concat_type = RankedTensorType::get(
270       ArrayRef<int64_t>(concat_output_shape), result_type.getElementType());
271 
272   SmallVector<Value> inputs_0;
273   for (int i = 0; i < inputs.size(); i++) {
274     inputs_0.push_back(inputs[i]);
275   }
276   auto a1_concat_op = CreateOpAndInfer<tosa::ConcatOp>(
277       rewriter, op->getLoc(), concat_type, inputs_0, concat_axis_attr);
278 
279   // Doesn't need reshape or transpose if input tensor is rank 0, since inputs
280   // are reshaped beforehand.
281   if (input_tensor_rank == 0) return a1_concat_op.getResult();
282 
283   // Reshape [N * A, B, C] to [N, A, B, C].
284   RankedTensorType reshape_output_type =
285       RankedTensorType::get(reshape_output_shape, result_type.getElementType());
286 
287   auto a2_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
288       rewriter, op->getLoc(), reshape_output_type, a1_concat_op.getResult(),
289       shape_attr);
290 
291   // If axis is equal to the input tensor rank, then we need an extra transpose
292   // [N, A, B, C] to [A, B, C, N]
293   if (axis == input_tensor_rank) {
294     llvm::Optional<Value> a3_transpose_perm = getConstTensor<int32_t>(
295         rewriter, op, perm, {static_cast<int64_t>(perm.size())});
296 
297     if (!a3_transpose_perm) return llvm::None;
298 
299     return CreateOpAndInfer<tosa::TransposeOp>(
300                rewriter, op->getLoc(), result_type, a2_reshape_op.getResult(),
301                a3_transpose_perm.getValue())
302         .getResult();
303   }
304 
305   return a2_reshape_op.getResult();
306 }
307 
308 // Lowers the Unpack operator to TOSA
309 llvm::Optional<SmallVector<Value>> convertUnpackOp(PatternRewriter& rewriter,
310                                                    Operation* op,
311                                                    Value input_value,
312                                                    int32_t axis) {
313   RankedTensorType input_type =
314       input_value.getType().dyn_cast<RankedTensorType>();
315   if (!input_type) return llvm::None;
316 
317   auto input_shape = input_type.getShape();
318   int64_t input_rank = input_shape.size();
319 
320   SmallVector<Value> results_vec;
321 
322   // Negative axis allowed as long as it's within [-input_rank, input_rank).
323   if (axis < 0) axis += input_rank;
324   if ((axis < 0) || (axis > input_rank)) {
325     op->emitOpError("UnpackOp: axis out of valid range.");
326     return llvm::None;
327   }
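  // Illustrative example (added commentary): unpacking a [2, 3, 4] tensor
  // along axis=1 produces 3 results of shape [2, 4]:
  //   transpose with perm [1, 0, 2] -> [3, 2, 4]
  //   slice i in 0..2: begin [i, 0, 0], size [1, 2, 4], then reshape -> [2, 4]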
328 
329   // Step 1: transpose 'axis' to leftmost dimension.
330   Value transposed_input_value;
331   if (axis != 0) {
332     SmallVector<int32_t> perm;
333     SmallVector<int64_t> a1_transpose_shape(input_rank);
334 
335     perm.push_back(axis);
336     for (int i = 0; i < input_rank; i++) {
337       if (i == axis) continue;
338       perm.push_back(i);
339     }
340 
341     llvm::Optional<Value> a1_transpose_perm = getConstTensor<int32_t>(
342         rewriter, op, perm, {static_cast<int64_t>(perm.size())});
343 
344     if (!a1_transpose_perm) return llvm::None;
345 
346     for (int i = 0; i < input_rank; i++) {
347       a1_transpose_shape[i] = input_shape[perm[i]];
348     }
349 
350     auto a1_transpose_op = CreateOpAndInfer<tosa::TransposeOp>(
351         rewriter, op->getLoc(),
352         RankedTensorType::get(a1_transpose_shape, input_type.getElementType()),
353         input_value, a1_transpose_perm.getValue());
354 
355     transposed_input_value = a1_transpose_op.getResult();
356   } else {
357     // Do nothing if axis is already at leftmost dimension.
358     transposed_input_value = input_value;
359   }
360 
361   // Step 2: slice [N, A, B, C] into N [A, B, C].
362   RankedTensorType transposed_input_type =
363       transposed_input_value.getType().dyn_cast<RankedTensorType>();
364   if (!transposed_input_type) return llvm::None;
365 
366   auto transposed_input_shape = transposed_input_type.getShape();
367   int64_t transposed_input_rank = transposed_input_shape.size();
368 
369   for (int i = 0; i < transposed_input_shape[0]; i++) {
370     SmallVector<int64_t> begin_vals, size_vals, shape_vals;
371 
372     for (int j = 0; j < transposed_input_rank; j++) {
373       if (j == 0) {
374         begin_vals.push_back(i);
375         size_vals.push_back(1);
376       } else {
377         begin_vals.push_back(0);
378         size_vals.push_back(transposed_input_shape[j]);
379         shape_vals.push_back(transposed_input_shape[j]);
380       }
381     }
382 
383     ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
384     ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);
385 
386     auto a2_slice_op = CreateOpAndInfer<tosa::SliceOp>(
387         rewriter, op->getLoc(),
388         RankedTensorType::get(size_vals,
389                               transposed_input_type.getElementType()),
390         transposed_input_value, begin, size);
391 
392     auto a3_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
393         rewriter, op->getLoc(),
394         RankedTensorType::get(shape_vals,
395                               transposed_input_type.getElementType()),
396         a2_slice_op.getResult(), rewriter.getI64ArrayAttr(shape_vals));
397 
398     results_vec.push_back(a3_reshape_op.getResult());
399   }
400 
401   return results_vec;
402 }
403 
404 // Lowers the Select operator to TOSA.
405 llvm::Optional<Value> convertSelectOp(PatternRewriter& rewriter, Operation* op,
406                                       Value result_value, Value condition_value,
407                                       Value x_value, Value y_value) {
408   RankedTensorType result_type =
409       result_value.getType().dyn_cast<RankedTensorType>();
410   RankedTensorType condition_type =
411       condition_value.getType().dyn_cast<RankedTensorType>();
412   RankedTensorType x_type = x_value.getType().dyn_cast<RankedTensorType>();
413   RankedTensorType y_type = y_value.getType().dyn_cast<RankedTensorType>();
414 
415   if (!result_type || !condition_type || !x_type || !y_type) {
416     op->emitOpError("Select: failed ranked tensor type check");
417     return llvm::None;
418   }
419 
420   // First check whether we need to reshape the condition to match
421   // the same rank as the then/else clauses.
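  // Illustrative example (added commentary): for a condition of shape [4] and
  // then/else values of shape [2, 3, 4], the condition is reshaped to
  // [1, 1, 4] so it broadcasts against the [2, 3, 4] operands.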
422   if (result_type.getRank() == condition_type.getRank()) {
423     // Nothing to reshape.
424     return CreateOpAndInfer<tosa::SelectOp>(rewriter, op->getLoc(), result_type,
425                                             condition_value, x_value, y_value)
426         .getResult();
427   }
428 
429   // Need to reshape the condition.
430   SmallVector<int64_t> new_cond_dims(
431       result_type.getRank() - condition_type.getRank(), 1);
432 
433   for (int i = 0; i < condition_type.getRank(); i++) {
434     new_cond_dims.push_back(condition_type.getShape()[i]);
435   }
436 
437   auto reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
438       rewriter, op->getLoc(),
439       RankedTensorType::get(new_cond_dims, condition_type.getElementType()),
440       condition_value, rewriter.getI64ArrayAttr(new_cond_dims));
441 
442   return CreateOpAndInfer<tosa::SelectOp>(rewriter, op->getLoc(), result_type,
443                                           reshape_op, x_value, y_value)
444       .getResult();
445 }
446 
447 // Lowers the ZerosLike operator to TOSA by creating a constant
448 // of the desired type and shape.
449 llvm::Optional<Value> convertZerosLikeOp(PatternRewriter& rewriter,
450                                          Operation* op, Value result,
451                                          Value input) {
452   RankedTensorType result_type = result.getType().dyn_cast<RankedTensorType>();
453   if (!result_type) {
454     op->emitOpError("Zeroslike: result not ranked tensor type");
455     return llvm::None;
456   }
457 
458   RankedTensorType input_type = input.getType().dyn_cast<RankedTensorType>();
459   if (!input_type) {
460     op->emitOpError("Zeroslike: input not ranked tensor type");
461     return llvm::None;
462   }
463 
464   auto input_shape = input_type.getShape();
465 
466   ShapedType zero_type =
467       RankedTensorType::get(input_shape, input_type.getElementType());
468   Attribute zero_attr = rewriter.getZeroAttr(zero_type);
469 
470   return CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(), zero_type,
471                                          zero_attr.cast<ElementsAttr>())
472       .getResult();
473 }
474 
475 // Lowers the Mul operator to TOSA.  For quantized types, this requires
476 // inserting rescale operators before and after the operation.
477 llvm::Optional<Value> convertMultiplyOp(PatternRewriter& rewriter,
478                                         Operation* op, Value output_val,
479                                         Value input_lhs_val,
480                                         Value input_rhs_val) {
481   ShapedType input_lhs_type = input_lhs_val.getType().dyn_cast<ShapedType>();
482   ShapedType input_rhs_type = input_rhs_val.getType().dyn_cast<ShapedType>();
483   ShapedType output_type = output_val.getType().dyn_cast<ShapedType>();
484   // Not a shaped tensor output
485   if (!input_lhs_type || !input_rhs_type || !output_type) return llvm::None;
486 
487   bool input_lhs_is_qtype =
488       input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
489   bool input_rhs_is_qtype =
490       input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
491   bool output_is_qtype =
492       output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
493 
494   if (input_lhs_is_qtype != output_is_qtype ||
495       input_rhs_is_qtype != output_is_qtype) {
496     op->emitOpError(
497         "ConvertMultiplyOp: input/output tensor should "
498         "be all quantized or all floating-point");
499     return llvm::None;
500   }
501 
502   if (output_is_qtype) {
503     ShapedType rescale_type = output_type.clone(rewriter.getI32Type());
504     auto input_lhs_qtype = input_lhs_type.getElementType()
505                                .cast<mlir::quant::UniformQuantizedType>();
506     auto input_rhs_qtype = input_rhs_type.getElementType()
507                                .cast<mlir::quant::UniformQuantizedType>();
508     auto output_qtype =
509         output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
510 
511     // MLIR stores the scale as a double, but TFLite stores it as a float.
512     // Downcast from double to float to match TFLite behavior.
513     float in_lhs_scale = input_lhs_qtype.getScale();
514     float in_rhs_scale = input_rhs_qtype.getScale();
515     float output_scale = output_qtype.getScale();
516 
517     double output_rescale_scale = in_lhs_scale * in_rhs_scale / output_scale;
518 
519     // 16bits x 16bits -> 32bits
520     // 32bits can be rescaled with 32bits quantize multiplier back to 16bits
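    // Illustrative example (added commentary, arbitrary scales): with lhs
    // scale 0.02, rhs scale 0.05 and output scale 0.1, both inputs are
    // rescaled to int32 with scale 1.0 (zero points removed), multiplied in
    // int32, and the product is rescaled by 0.02 * 0.05 / 0.1 = 0.01 back to
    // the quantized output type.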
521     bool scale32 = true;
522 
523     Value op1_rescale_lhs = buildRescaleToInt32(
524         rewriter, op, input_lhs_val, 1.0f, input_lhs_qtype.getZeroPoint());
525     Value op2_rescale_rhs = buildRescaleToInt32(
526         rewriter, op, input_rhs_val, 1.0f, input_rhs_qtype.getZeroPoint());
527     auto op3_mul_op1_op2 =
528         CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), rescale_type,
529                                       op1_rescale_lhs, op2_rescale_rhs, 0);
530     return buildRescale(rewriter, op, output_type, op3_mul_op1_op2.getResult(),
531                         output_rescale_scale, 0, output_qtype.getZeroPoint(),
532                         true, scale32);
533   }
534 
535   return CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
536                                        input_lhs_val, input_rhs_val, 0)
537       .getResult();
538 }
539 
540 // Lowers the SquaredDifference operator to TOSA.
541 llvm::Optional<Value> convertSquaredDifferenceOp(PatternRewriter& rewriter,
542                                                  Operation* op, Value result,
543                                                  Value x, Value y) {
544   // Squared-difference is (x-y)*(x-y).
545   // This lowering calculates the difference and multiplies.
546   ShapedType result_type = result.getType().dyn_cast<ShapedType>();
547   if (!result_type) {
548     op->emitOpError("SquaredDifference: result not ranked tensor type");
549     return llvm::None;
550   }
551 
552   ShapedType x_type = x.getType().dyn_cast<ShapedType>();
553   ShapedType y_type = y.getType().dyn_cast<ShapedType>();
554   if (!x_type || !y_type) {
555     op->emitOpError("SquaredDifference: inputs not ranked tensor type");
556     return llvm::None;
557   }
558 
559   auto sub_op =
560       CreateOpAndInfer<tosa::SubOp>(rewriter, op->getLoc(), result_type, x, y);
561   return CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), result_type,
562                                        sub_op.getResult(), sub_op.getResult(),
563                                        0)
564       .getResult();
565 }
566 
567 // Lowers the Round operator to TOSA.
568 llvm::Optional<Value> convertRoundOp(PatternRewriter& rewriter, Operation* op,
569                                      Value result, Value input) {
570   // Rounds by calculating floor(input + 0.5); this matches banker's rounding except at exact halfway values.
571   ShapedType result_type = result.getType().dyn_cast<ShapedType>();
572   if (!result_type) {
573     op->emitOpError("Round: result not shaped tensor type");
574     return llvm::None;
575   }
576 
577   ShapedType input_type = input.getType().dyn_cast<ShapedType>();
578   if (!input_type) {
579     op->emitOpError("Round: input not shaped tensor type");
580     return llvm::None;
581   }
582 
583   auto add_op = CreateOpAndInfer<tosa::AddOp>(
584       rewriter, op->getLoc(), result_type, input,
585       getTosaConstTensorSingleF32(rewriter, op, 0.5));
586 
587   return CreateOpAndInfer<tosa::FloorOp>(rewriter, op->getLoc(), result_type,
588                                          add_op.getResult())
589       .getResult();
590 }
591 
592 // Lowers ConcatV2 to TOSA Concat.
593 llvm::Optional<Value> convertConcatV2Op(PatternRewriter& rewriter,
594                                         Operation* op, Value result_value,
595                                         SmallVectorImpl<Value>& values,
596                                         int32_t axis) {
597   // Check all inputs are RankedTensorType
598   for (auto v : values) {
599     if (!v.getType().dyn_cast<RankedTensorType>()) {
600       op->emitOpError("ConcatV2Op: value type not ranked tensor.");
601       return llvm::None;
602     }
603   }
604 
605   // Check output is Ranked tensor type
606   if (!result_value.getType().dyn_cast<RankedTensorType>()) {
607     op->emitOpError("ConcatV2Op: output value type not ranked tensor.");
608     return llvm::None;
609   }
610 
611   RankedTensorType result_type =
612       result_value.getType().dyn_cast<RankedTensorType>();
613   mlir::quant::UniformQuantizedType result_quant_type =
614       result_type.getElementType()
615           .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
616 
617   SmallVector<Value> values_rescaled;
618 
619   for (auto v : values) {
620     RankedTensorType operand_type = v.getType().dyn_cast<RankedTensorType>();
621     mlir::quant::UniformQuantizedType operand_quant_type =
622         operand_type.getElementType()
623             .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
624 
625     // tfl.concat currently allows different scales for each input tensor,
626     // which the TFLite team will fix in:
627     // https://github.com/tensorflow/tensorflow/issues/39658
628     // For backward compatibility, we still need to support this artifact by
629     // scaling inputs to let them have the same scales.
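    // Illustrative example (added commentary, arbitrary values): an operand
    // with scale 0.5 and zero point 3 feeding a result with scale 0.25 and
    // zero point 0 is rescaled by 0.5 / 0.25 = 2.0, mapping zero point 3 to 0.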
630     if (result_quant_type && operand_quant_type) {
631       double operand_scale = static_cast<double>(operand_quant_type.getScale());
632       int32_t operand_zeropoint = operand_quant_type.getZeroPoint();
633 
634       double result_scale = static_cast<double>(result_quant_type.getScale());
635       int32_t result_zeropoint = result_quant_type.getZeroPoint();
636 
637       // Rescale input if scale is not equal to output tensor scale.
638       if (operand_scale != result_scale) {
639         RankedTensorType rescale_type =
640             RankedTensorType::get(operand_type.getShape(), result_quant_type);
641         Value rescale_op = buildRescale(
642             rewriter, op, rescale_type, v, operand_scale / result_scale,
643             operand_zeropoint, result_zeropoint, false, true);
644         values_rescaled.push_back(rescale_op);
645       } else {
646         values_rescaled.push_back(v);
647       }
648     } else {
649       values_rescaled.push_back(v);
650     }
651   }
652 
653   int32_t tensor_rank = result_type.getShape().size();
654 
655   if (axis < 0) axis += tensor_rank;
656   if ((axis < 0) || (axis > tensor_rank)) {
657     op->emitOpError("ConcatV2Op: axis out of valid range.");
658     return llvm::None;
659   }
660 
661   auto concat_op = CreateOpAndInfer<tosa::ConcatOp>(
662       rewriter, op->getLoc(), result_value.getType(), values_rescaled,
663       rewriter.getI64IntegerAttr(axis));
664 
665   return concat_op.getResult();
666 }
667 
668 // Lowers SpaceToBatchND to TOSA.
669 llvm::Optional<Value> convertSpaceToBatchNDOp(PatternRewriter& rewriter,
670                                               Operation* op, Value result_value,
671                                               Value input_value,
672                                               Value block_shape_value,
673                                               Value paddings_value) {
674   /////////////////////////////////////////////////
675   // Operator: output = SpaceToBatchND(input, block_shape, paddings)
676   // Lowering:
677   //
678   // SpaceToBatch input tensors are broken into three pieces:
679   //   (a) batch dimension (N in NHWC)
680   //   (b) input being transformed to batch dimension (typically H, W in NHWC)
681   //   (c) remainder of input (typically C in NHWC)
682   //
683   // Step 0. Generate padding constant for the first reshape.
684   //   No padding on the batch dimension
685   //   The input paddings array is addressed as [input_rank][2]
686   //   No padding on the remaining dimensions
687   //
688   //  a0_pad_const = tosa.const(input=Tensor<input_rank, 2>)
689   //
690   // Step 1. Pad the input tensor
691   //
692   //  a1_pad_input_op = tosa.pad(input=input, shape=a0_pad_const_op)
693   //
694   // Step 2. Reshape the padded structure of shape padded_shape to
695   // [batch + padded_shape[1] / block_shape[0], block_shape[0], ...
696   //    padded_shape[M] / block_shape[M-1], block_shape[M-1]] +
697   //    remaining_shape
698   //
699   // block_rank = M (number of elements in block_shape)
700   // New rank: input_rank + block_rank
701   //
702   //  a2_reshape_a1_op = tosa.reshape(input=a1_pad_input_op, shape=a2_shape)
703   //
704   // Step 3. Transpose dimensions to:
705   //  block-shape +
706   //  [batch] +
707   //  [padded_shape[1] / block_shape[0],
708   // ...
709   //  [padded_shape[M] / block_shape[M-1]] +
710   //  remaining_shape
711   //
712   // a3_transpose_a2_op = tosa.transpose(input=a2_reshape_a1_op,
713   // perms=a3_perm)
714   //
715   // Step 4. Reshape the transposed tensor to flatten block_shape stuff
716   // into the batch dimension with the following shape:
717   // [ batch * prod(block_shape)] +
718   // [ padded_shape[1] / block_shape[0],
719   //   ...,
720   // padded_shape[M] / block_shape[M-1]] +
721   // remaining_shape
722   //
723   //  a4_reshape_a3_op = tosa.reshape(input=a3_transpose_a2_op,
724   //  shape=a3_shape)
725   //
726 
727   RankedTensorType result_type =
728       result_value.getType().dyn_cast<RankedTensorType>();
729   RankedTensorType input_type =
730       input_value.getType().dyn_cast<RankedTensorType>();
731   RankedTensorType block_shape_type =
732       block_shape_value.getType().dyn_cast<RankedTensorType>();
733   RankedTensorType paddings_type =
734       paddings_value.getType().dyn_cast<RankedTensorType>();
735 
736   // Not a ranked tensor output.
737   if (!result_type) {
738     op->emitOpError("SpaceToBatchND: result type not ranked tensor");
739     return llvm::None;
740   }
741   if (!input_type) {
742     op->emitOpError("SpaceToBatchND: input type not ranked tensor");
743     return llvm::None;
744   }
745   if (!block_shape_type) {
746     op->emitOpError("SpaceToBatchND: block shape type not ranked tensor");
747     return llvm::None;
748   }
749   if (!paddings_type) {
750     op->emitOpError("SpaceToBatchND: paddings type not ranked tensor");
751     return llvm::None;
752   }
753 
754   // Follow implementation in
755   // tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
756 
757   // So, to figure out the spatial_shape, remove the batch dimension and
758   // then use the next block_rank dimensions.  The remaining dimensions are
759   // remaining_shape.
760 
761   auto block_shape = block_shape_type.getShape();
762   auto input_shape = input_type.getShape();
763 
764   int block_rank = block_shape[0];
765   int batch_size = input_shape[0];
766   int input_rank = input_type.getRank();
767   int remaining_shape_rank = input_rank - block_rank - 1;
768   int block_num_elems = 1;
769   int padding_sum = 0;
770 
771   ElementsAttr block_shape_elems;
772   ElementsAttr paddings_elems;
773 
774   if (!matchPattern(block_shape_value, m_Constant(&block_shape_elems)))
775     return llvm::None;
776 
777   if (!matchPattern(paddings_value, m_Constant(&paddings_elems)))
778     return llvm::None;
779 
780   SmallVector<int32_t> a0_pad_const(2 * (input_rank));
781   SmallVector<int64_t> padded_shape(input_rank);
782 
783   // 1. Pad based on paddings operand.  No padding on the batch dimension.
784   // The a0_pad_const array is addressed as [input_rank][2], but
785   // it is flattened to a 1D array because LLVM appears to only accept 1D.
786   //
787   // padded_shape[] is the shape of the padded output of step a1.
788   // The name is retained for consistency with the TF reference code.
789   padded_shape[0] = input_shape[0];
790 
791   // Batch dimension padding
792   a0_pad_const[0] = 0;
793   a0_pad_const[1] = 0;
794 
795   // This iterator seems to be the only reliable way to get
796   // int values out of a multi-dimensional ElementsAttr.
797   int idx = 0;
798 
799   for (auto i : paddings_elems.getValues<IntegerAttr>()) {
800     a0_pad_const[idx + 2] = i.getInt();
801     padding_sum += i.getInt();
802     idx++;
803   }
804 
805   // Insert padding on the spatial shape dimensions
806   for (int i = 0; i < block_rank; i++) {
807     padded_shape[i + 1] = input_shape[i + 1];
808     if (!ShapedType::isDynamic(padded_shape[i + 1])) {
809       int32_t lo_pad = a0_pad_const[2 * (i + 1) + 0];
810       int32_t hi_pad = a0_pad_const[2 * (i + 1) + 1];
811       padded_shape[i + 1] += lo_pad + hi_pad;
812     }
813   }
814 
815   // No padding on the remaining_shape dimensions
816   for (int i = 0; i < remaining_shape_rank; i++) {
817     a0_pad_const[2 * (i + block_rank + 1) + 0] = 0;
818     a0_pad_const[2 * (i + block_rank + 1) + 1] = 0;
819     padded_shape[i + block_rank + 1] = input_shape[i + block_rank + 1];
820   }
821 
822   RankedTensorType a0_pad_const_attr_type =
823       RankedTensorType::get({(input_rank), 2}, rewriter.getIntegerType(32));
824 
825   // Create a const op to generate the tensor type for the input padding array
826   auto a0_pad_const_op = rewriter.create<tosa::ConstOp>(
827       op->getLoc(), a0_pad_const_attr_type,
828       DenseElementsAttr::get(a0_pad_const_attr_type,
829                              llvm::makeArrayRef(a0_pad_const)));
830 
831   auto a1_pad_input_op = CreateOpAndInfer<tosa::PadOp>(
832       rewriter, op->getLoc(),
833       RankedTensorType::get(padded_shape, result_type.getElementType()),
834       input_value, a0_pad_const_op.getResult());
835 
836   // 2. Reshape the padded structure of shape padded_shape to
837   // [batch + padded_shape[1] / block_shape[0], block_shape[0], ...
838   //    padded_shape[M] / block_shape[M-1], block_shape[M-1]] +
839   //    remaining_shape
840 
841   // block_rank = M (number of elements in block_shape)
842   // New rank: input_rank + block_rank
843   SmallVector<int64_t> a2_shape(1 + block_rank * 2 + remaining_shape_rank);
844 
845   // First dimension is batch.
846   a2_shape[0] = input_type.getShape()[0];
847   for (int i = 0; i < block_rank; i++) {
848     int32_t block_shape_val =
849         block_shape_elems.getValues<IntegerAttr>()[i].getInt();
850     a2_shape[1 + i * 2 + 0] = padded_shape[1 + i];
851     if (a2_shape[1 + i * 2 + 0] != ShapedType::kDynamicSize) {
852       a2_shape[1 + i * 2 + 0] /= block_shape_val;
853     }
854 
855     a2_shape[1 + i * 2 + 1] = block_shape_val;
856     block_num_elems *= block_shape_val;
857   }
858 
859   // Copy in the remaining block shape.
860   for (int i = 0; i < remaining_shape_rank; i++) {
861     a2_shape[1 + block_rank * 2 + i] = input_shape[1 + block_rank + i];
862   }
863 
864   auto a2_reshape_a1_op = CreateOpAndInfer<tosa::ReshapeOp>(
865       rewriter, op->getLoc(),
866       RankedTensorType::get(a2_shape, result_type.getElementType()),
867       a1_pad_input_op.getResult(), rewriter.getI64ArrayAttr(a2_shape));
868 
869   // 3. Transpose dimensions to:
870   //  block-shape +
871   //  [batch] +
872   //  [padded_shape[1] / block_shape[0],
873   // ...
874   //  [padded_shape[M] / block_shape[M-1]] +
875   //  remaining_shape
876   int32_t a2_reshape_a1_rank =
877       a2_reshape_a1_op.getResult().getType().cast<RankedTensorType>().getRank();
878   SmallVector<int32_t> a3_perm(a2_reshape_a1_rank);
879   SmallVector<int64_t> a3_transpose_shape(a2_reshape_a1_rank);
880 
881   for (int i = 0; i < block_rank; i++) {
882     a3_perm[i] = 1 + 2 * i + 1;
883     a3_perm[block_rank + 1 + i] = 1 + 2 * i;
884   }
885   a3_perm[block_rank] = 0;
886   for (int i = 1 + block_rank * 2; i < a2_reshape_a1_rank; i++) {
887     a3_perm[i] = i;
888   }
889 
890   for (int i = 0; i < a3_transpose_shape.size(); i++) {
891     a3_transpose_shape[i] = a2_shape[a3_perm[i]];
892   }
893 
894   llvm::Optional<Value> a3_transpose_const = getConstTensor<int32_t>(
895       rewriter, op, a3_perm, {static_cast<int64_t>(a3_perm.size())});
896 
897   if (!a3_transpose_const) return llvm::None;
898 
899   auto a3_transpose_a2_op = CreateOpAndInfer<tosa::TransposeOp>(
900       rewriter, op->getLoc(),
901       RankedTensorType::get(a3_transpose_shape, result_type.getElementType()),
902       a2_reshape_a1_op.getResult(), a3_transpose_const.getValue());
903 
904   // 4. Reshape the transposed tensor to flatten block_shape
905   // into the batch dimension with the following shape:
906   // [ batch * prod(block_shape)] +
907   // [ padded_shape[1] / block_shape[0],
908   //   ...,
909   // padded_shape[M] / block_shape[M-1]] +
910   // remaining_shape
911   SmallVector<int64_t> a4_reshape_shape(input_rank);
912 
913   // Batch
914   a4_reshape_shape[0] = multiply_dims(batch_size, block_num_elems);
915 
916   // padded shape / block_shape.
917   for (int i = 0; i < block_rank; i++) {
918     int32_t block_shape_val =
919         block_shape_elems.getValues<IntegerAttr>()[i].getInt();
920     a4_reshape_shape[i + 1] = padded_shape[i + 1];
921     if (a4_reshape_shape[i + 1] != ShapedType::kDynamicSize) {
922       a4_reshape_shape[i + 1] /= block_shape_val;
923     }
924   }
925 
926   // Copy in remainder shape.
927   for (int i = 0; i < remaining_shape_rank; i++) {
928     a4_reshape_shape[1 + block_rank + i] = input_shape[1 + block_rank + i];
929   }
930 
931   return CreateOpAndInfer<tosa::ReshapeOp>(
932              rewriter, op->getLoc(), result_type,
933              a3_transpose_a2_op.getResult(),
934              rewriter.getI64ArrayAttr(a4_reshape_shape))
935       .getResult();
936 }
937 
938 // Lowers BatchToSpaceND to TOSA.
939 llvm::Optional<Value> convertBatchToSpaceNDOp(PatternRewriter& rewriter,
940                                               Operation* op, Value result_value,
941                                               Value input_value,
942                                               Value block_shape_value,
943                                               Value crops_value) {
944   /////////////////////////////////////////////////
945   // Operator: output = BatchToSpaceND(input, block_shape, crops)
946   // Lowering:
947   //
948   // BatchToSpace input tensors are broken into three pieces:
949   //   (a) batch dimension (N in NHWC)
950   //   (b) input being transformed from batch dimension (typically H, W in
951   //   NHWC)
952   //   (c) remainder of input (typically C in NHWC)
953   //
954   // Step 1. Reshape input to:
955   // [block_shape[0],
956   // ...
957   // [block_shape[M-1],
958   // [batch / prod(block_shape)]
959   // [input_shape[1],
960   // ...
961   // [input_shape[N-1]
962   //
963   // a1_reshape_input_op = tosa.reshape(input=input, shape=a1_shape)
964   //
965   // Step 2. Permute to shape
966   // [ batch / prod(block_shape) ],
967   // [ input_shape[1] ], [ block_shape[1] ]
968   //  ...
969   // [ input_shape[M] ], [ block_shape[M-1]
970   // + remaining_input_shapes input_shape[M .. N-1]
971   //
972   // a2_transpose_a1 = tosa.transpose(input=a1_reshape_input_op,
973   // shape=a2_shape)
974   //
975   // Step 3. Reshape to:
976   // [ batch / prod(block_shape) ],
977   // [input_shape[1] * block_shape[0] ],
978   //    ..
979   // [input_shape[M * block_shape[M-1],
980   // + remaining input shapes [input_shape[M+1.. N-1]]
981   //
982   // a3_reshape_a2 = tosa.reshape(input=a2_transpose_a1, shape=a3_shape)
983   //
984   // Step 4. Crop the start/end dimensions according to crops of the
985   // a3_reshape_a2 shape
986   //
987   // a4_slice_a3 = tosa.slice(input=a3_reshape_a2, start=a4_start,
988   // size=a4_size)
989 
990   RankedTensorType result_type =
991       result_value.getType().dyn_cast<RankedTensorType>();
992   RankedTensorType input_type =
993       input_value.getType().dyn_cast<RankedTensorType>();
994   RankedTensorType block_shape_type =
995       block_shape_value.getType().dyn_cast<RankedTensorType>();
996   RankedTensorType crops_type =
997       crops_value.getType().dyn_cast<RankedTensorType>();
998 
999   if (!result_type) {
1000     op->emitOpError("BatchToSpaceND: result type not ranked tensor");
1001     return llvm::None;
1002   }
1003   if (!input_type) {
1004     op->emitOpError("BatchToSpaceND: input type not ranked tensor");
1005     return llvm::None;
1006   }
1007   if (!block_shape_type) {
1008     op->emitOpError("BatchToSpaceND: block shape type not ranked tensor");
1009     return llvm::None;
1010   }
1011   if (!crops_type) {
1012     op->emitOpError("BatchToSpaceND: crops type not ranked tensor");
1013     return llvm::None;
1014   }
1015 
1016   // Another 4-step process
1017   int block_rank = block_shape_type.getShape()[0];
1018   int input_rank = input_type.getRank();
1019   int crops_dims = crops_type.getShape()[0];
1020   int remaining_shape_rank = input_rank - block_rank - 1;
1021   auto input_shape = input_type.getShape();
1022 
1023   ElementsAttr block_shape_elems;
1024   ElementsAttr crops_elems;
1025 
1026   if (!matchPattern(block_shape_value, m_Constant(&block_shape_elems))) {
1027     op->emitOpError("BatchToSpaceND: block_shape not a constant");
1028     return llvm::None;
1029   }
1030 
1031   if (!matchPattern(crops_value, m_Constant(&crops_elems))) {
1032     op->emitOpError("BatchToSpaceND: crops not a constant");
1033     return llvm::None;
1034   }
1035 
1036   SmallVector<int64_t> block_shape(block_rank);
1037   SmallVector<std::pair<int64_t, int64_t>> crops(crops_dims);
1038 
1039   // Extract values for block_shape and crops now.
1040   int block_num_elems = 1;
1041   for (int i = 0; i < block_rank; i++) {
1042     int block_shape_val =
1043         rewriter
1044             .getI32IntegerAttr(
1045                 block_shape_elems.getValues<IntegerAttr>()[i].getInt())
1046             .getInt();
1047     block_num_elems *= block_shape_val;
1048     block_shape[i] = block_shape_val;
1049   }
1050 
1051   // This iterator seems to be the only reliable way to get
1052   // int values out of a multi-dimensional ElementsAttr
1053   SmallVector<int32_t> crops_const(2 * (crops_dims));
1054   int idx = 0;
1055   for (auto i : crops_elems.getValues<IntegerAttr>()) {
1056     crops_const[idx++] = i.getInt();
1057   }
1058 
1059   for (int i = 0; i < crops_dims; i++) {
1060     int crops_lo = crops_const[i * crops_dims + 0];
1061     int crops_hi = crops_const[i * crops_dims + 1];
1062     crops[i] = std::make_pair(crops_lo, crops_hi);
1063   }
1064 
1065   // Step 1. Reshape input to:
1066   // [block_shape[0],
1067   // ...
1068   // [block_shape[M-1],
1069   // [batch / prod(block_shape)]
1070   // [input_shape[1],
1071   // ...
1072   // [input_shape[N-1]
1073   SmallVector<int64_t> a1_shape(block_rank + input_rank);
1074 
1075   for (int i = 0; i < block_rank; i++) a1_shape[i] = block_shape[i];
1076 
1077   a1_shape[block_rank] = (input_shape[0] == ShapedType::kDynamicSize)
1078                              ? ShapedType::kDynamicSize
1079                              : input_shape[0] / block_num_elems;
1080 
1081   for (int i = 0; i < input_rank - 1; i++)
1082     a1_shape[i + block_rank + 1] = input_shape[i + 1];
1083 
1084   auto a1_reshape_input_op = CreateOpAndInfer<tosa::ReshapeOp>(
1085       rewriter, op->getLoc(),
1086       RankedTensorType::get(a1_shape, result_type.getElementType()),
1087       input_value, rewriter.getI64ArrayAttr(a1_shape));
1088 
1089   // 2. Permute to shape
1090   // [ batch / prod(block_shape) ],
1091   // [ input_shape[1] ], [ block_shape[0] ]
1092   //  ...
1093   // [ input_shape[M] ], [ block_shape[M-1]
1094   // + remaining_input_shapes input_shape[M+1 .. N-1]
1095 
1096   // 2a. calculate the permutation
1097   SmallVector<int32_t> a2_perm(block_rank + input_rank);
1098   SmallVector<int64_t> a2_transpose_shape(block_rank + input_rank);
1099 
1100   a2_perm[0] = block_rank;
1101   for (int i = 0; i < block_rank; i++) {
1102     a2_perm[1 + i * 2 + 0] = block_rank + 1 + i;
1103     a2_perm[1 + i * 2 + 1] = i;
1104   }
1105 
1106   for (int i = 0; i < remaining_shape_rank; i++) {
1107     a2_perm[1 + 2 * block_rank + i] = 1 + 2 * block_rank + i;
1108   }
1109 
1110   // 2b. calculate the a2_permuted shape
1111   for (int i = 0; i < (block_rank + input_rank); i++) {
1112     a2_transpose_shape[i] = a1_shape[a2_perm[i]];
1113   }
1114 
1115   llvm::Optional<Value> a2_transpose_perm = getConstTensor<int32_t>(
1116       rewriter, op, a2_perm, {static_cast<int64_t>(a2_perm.size())});
1117 
1118   if (!a2_transpose_perm) return llvm::None;
1119 
1120   auto a2_transpose_a1_op = CreateOpAndInfer<tosa::TransposeOp>(
1121       rewriter, op->getLoc(),
1122       RankedTensorType::get(a2_transpose_shape, result_type.getElementType()),
1123       a1_reshape_input_op.getResult(), a2_transpose_perm.getValue());
1124 
1125   // Step 3. Reshape to:
1126   // [ batch / prod(block_shape) ],
1127   // [input_shape[1] * block_shape[0] ],
1128   //    ..
1129   // [input_shape[M * block_shape[M-1],
1130   // + remaining input shapes [input_shape[M+1.. N-1]]
1131   SmallVector<int64_t> a4_shape(input_rank);
1132 
1133   a4_shape[0] = input_shape[0];
1134   if (a4_shape[0] != ShapedType::kDynamicSize) {
1135     a4_shape[0] /= block_num_elems;
1136   }
1137   for (int i = 0; i < block_rank; i++) {
1138     a4_shape[1 + i] = multiply_dims(input_shape[i + 1], block_shape[i]);
1139   }
1140   for (int i = 0; i < remaining_shape_rank; i++) {
1141     a4_shape[1 + block_rank + i] = input_shape[block_rank + 1 + i];
1142   }
1143 
1144   auto a3_reshape_a2 = CreateOpAndInfer<tosa::ReshapeOp>(
1145       rewriter, op->getLoc(),
1146       RankedTensorType::get(a4_shape, result_type.getElementType()),
1147       a2_transpose_a1_op.getResult(), rewriter.getI64ArrayAttr(a4_shape));
1148 
1149   // 4. Crop the start/end dimensions on 'spatial dimension' according to
1150   // crops
1151   // Use a slice operator to do the cropping.
1152   //
1153   // Calculate a beginning point and a size:
1154   // - Begin is the origin, offset by the lo crop amount in each dimension
1155   // - Size is the reshaped tensor size, minus the quantity (lo + hi) for each
1156   // dimension
1157   SmallVector<int64_t> a4_begin_vals(input_rank), a4_size_vals(input_rank);
1158 
1159   for (int i = 0; i < input_rank; i++) {
1160     // Batch dimension and remaining dimensions.
1161     if (i == 0 || i > crops_dims) {
1162       a4_begin_vals[i] = 0;
1163       a4_size_vals[i] = result_type.getShape()[i];
1164     } else {
1165       // Spatial dimension.
1166       assert(i - 1 >= 0 && i - 1 < crops_dims);
1167       a4_begin_vals[i] = crops[i - 1].first;
1168       a4_size_vals[i] = a4_shape[i] - crops[i - 1].first - crops[i - 1].second;
1169     }
1170   }
1171 
1172   return CreateOpAndInfer<tosa::SliceOp>(
1173              rewriter, op->getLoc(),
1174              RankedTensorType::get(a4_size_vals, result_type.getElementType()),
1175              a3_reshape_a2.getResult(), rewriter.getI64ArrayAttr(a4_begin_vals),
1176              rewriter.getI64ArrayAttr(a4_size_vals))
1177       .getResult();
1178 }
1179 
1180 // Lowers ExpandDims to TOSA.
1181 llvm::Optional<Value> convertExpandDimsOp(PatternRewriter& rewriter,
1182                                           Operation* op, Value result_value,
1183                                           Value input_value, Value dim_value) {
1184   // Lowers to a reshape op with 1's inserted in the appropriate dimensions.
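  // Illustrative example (added commentary): expanding a [2, 3] input with
  // dim=1 reshapes it to [2, 1, 3]; with dim=2 (== input rank) the new
  // dimension is appended instead, giving [2, 3, 1].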
1185   RankedTensorType output_type =
1186       result_value.getType().dyn_cast<RankedTensorType>();
1187   // Not a ranked tensor output
1188   if (!output_type) {
1189     op->emitOpError("ExpandDims: output type not ranked tensor");
1190     return llvm::None;
1191   }
1192 
1193   RankedTensorType input_type =
1194       input_value.getType().dyn_cast<RankedTensorType>();
1195   if (!input_type) {
1196     op->emitOpError("ExpandDims: input type not ranked tensor");
1197     return llvm::None;
1198   }
1199 
1200   auto input_shape = input_type.getShape();
1201 
1202   ElementsAttr dim_elem;
1203   if (!matchPattern(dim_value, m_Constant(&dim_elem))) return llvm::None;
1204 
1205   if (dim_elem.getNumElements() > 1) {
1206     op->emitOpError("ExpandDims: expected single dimension to expand");
1207     return llvm::None;
1208   }
1209   int32_t dim = dim_elem.getValues<IntegerAttr>()[0].getInt();
1210   int32_t input_size = input_shape.size();
1211   SmallVector<int64_t> reshape_dims;
1212   if (dim >= input_size) {  // add dim at end of tensor
1213     dim = input_size;
1214     for (int i = 0; i < input_shape.size(); i++) {
1215       reshape_dims.emplace_back(input_shape[i]);
1216     }
1217     reshape_dims.emplace_back(1);
1218   } else {
1219     if (dim < 0) {
1220       dim += input_size;
1221       if (dim < 0) {
1222         op->emitOpError(
1223             "ExpandDims: dimension to expand + size of input shape < 0");
1224         return llvm::None;
1225       }
1226     }
1227     for (int i = 0; i < input_size; i++) {
1228       if (i == dim) {
1229         reshape_dims.emplace_back(1);
1230       }
1231       reshape_dims.emplace_back(input_shape[i]);
1232     }
1233   }
1234 
1235   ArrayAttr shape_attr = rewriter.getI64ArrayAttr(reshape_dims);
1236 
1237   return CreateOpAndInfer<tosa::ReshapeOp>(rewriter, op->getLoc(), output_type,
1238                                            input_value, shape_attr)
1239       .getResult();
1240 }
1241 
1242 // Lowers Squeeze to TOSA.
1243 llvm::Optional<Value> convertSqueezeOp(PatternRewriter& rewriter, Operation* op,
1244                                        Value result_value, Value input_value,
1245                                        SmallVectorImpl<int32_t>& squeeze_dims) {
1246   // Lowers to a reshape op where dimensions in squeeze_dims with size=1
1247   // are removed.
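  // Illustrative example (added commentary): for an input of shape
  // [1, 2, 1, 3], an empty squeeze_dims removes every size-1 dimension,
  // giving [2, 3]; squeeze_dims = {2} removes only that one, giving [1, 2, 3].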
1248   RankedTensorType output_type =
1249       result_value.getType().dyn_cast<RankedTensorType>();
1250   // Not a ranked tensor output
1251   if (!output_type) {
1252     op->emitOpError("Squeeze: output type not ranked tensor");
1253     return llvm::None;
1254   }
1255 
1256   RankedTensorType input_type =
1257       input_value.getType().dyn_cast<RankedTensorType>();
1258   if (!input_type) {
1259     op->emitOpError("Squeeze: input type not ranked tensor");
1260     return llvm::None;
1261   }
1262 
1263   auto input_shape = input_type.getShape();
1264 
1265   SmallVector<int64_t> reshape_dims;
1266 
1267   if (squeeze_dims.empty()) {  // remove all 1-dims
1268     for (int i = 0; i < input_shape.size(); i++) {
1269       if (input_shape[i] != 1) {
1270         reshape_dims.emplace_back(input_shape[i]);
1271       }
1272     }
1273   } else {
1274     for (auto& dim : squeeze_dims) {
1275       dim = dim < 0 ? dim + input_shape.size() : dim;
1276     }
1277 
1278     // Remove only specified dims.
1279     // First sort the array so they can be picked off in sequence.
1280     std::sort(squeeze_dims.begin(), squeeze_dims.end(),
1281               [](const int32_t a, const int32_t b) { return a < b; });
1282 
1283     int pos = 0;
1284     auto dim = squeeze_dims[pos];
1285     for (int i = 0; i < input_shape.size(); i++) {
1286       if (i == dim) {
1287         pos = pos + 1;
1288         if (pos < squeeze_dims.size())
1289           dim = squeeze_dims[pos];
1290         else
1291           dim = -1;  // Invalid
1292       } else {
1293         reshape_dims.emplace_back(input_shape[i]);
1294       }
1295     }
1296   }
1297 
1298   ArrayAttr shape_attr = rewriter.getI64ArrayAttr(reshape_dims);
1299 
1300   return CreateOpAndInfer<tosa::ReshapeOp>(rewriter, op->getLoc(), output_type,
1301                                            input_value, shape_attr)
1302       .getResult();
1303 }
1304 
1305 // Lowers ELU to a sequence of TOSA ops.
1306 llvm::Optional<Value> convertEluOp(PatternRewriter& rewriter, Operation* op,
1307                                    Value result_value, Value features_value) {
1308   // Lowers Elu using the following formula:
1309   // elu(x) = x < 0 ? (exp(x) - 1) : x
1310   // one = const({1});
1311   // zero = const({0});
1312   // one_bcast = reshape(one, [1, ..., rank(x) - 1])
1313   // zero_bcast = reshape(zero, [1, ..., rank(x) - 1])
1314   // a1 = exp(x);
1315   // a2 = sub(a1, one_bcast)
1316   // a3 = ge(x, zero_bcast)
1317   // a4 = select(a3, x, a2)
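  // Illustrative example (added commentary): elu(-1.0) = exp(-1.0) - 1
  // ~= -0.632, while elu(2.0) = 2.0 since the input is non-negative.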
1318   RankedTensorType output_type =
1319       result_value.getType().dyn_cast<RankedTensorType>();
1320   // Not a ranked tensor output
1321   if (!output_type) {
1322     op->emitOpError("Elu: output type not ranked tensor");
1323     return llvm::None;
1324   }
1325 
1326   int32_t input_rank = output_type.getShape().size();
1327   SmallVector<int64_t> bcast_shape(input_rank, 1);
1328 
1329   // Can't directly create size=1, rank=rank(input) tensor because
1330   // it will be optimized out.  Instead, create rank0 tensor and reshape later.
1331   Value one_const_op = getTosaConstTensorSingleF32(rewriter, op, 1.0);
1332 
1333   Value zero_const_op = getTosaConstTensorSingleF32(rewriter, op, 0.0);
1334 
1335   auto a1_exp_in_op = CreateOpAndInfer<tosa::ExpOp>(
1336       rewriter, op->getLoc(), output_type, features_value);
1337 
1338   auto a2_sub_a1_one_op =
1339       CreateOpAndInfer<tosa::SubOp>(rewriter, op->getLoc(), output_type,
1340                                     a1_exp_in_op.getResult(), one_const_op);
1341 
1342   auto a3_ge_in_zero_op = CreateOpAndInfer<tosa::GreaterEqualOp>(
1343       rewriter, op->getLoc(),
1344       RankedTensorType::get(output_type.getShape(), rewriter.getIntegerType(1)),
1345       features_value, zero_const_op);
1346 
1347   return CreateOpAndInfer<tosa::SelectOp>(
1348              rewriter, op->getLoc(), output_type, a3_ge_in_zero_op.getResult(),
1349              features_value, a2_sub_a1_one_op.getResult())
1350       .getResult();
1351 }
1352 
1353 // Lowers Softmax to a sequence of TOSA ops.
1354 llvm::Optional<Value> convertSoftmaxOp(PatternRewriter& rewriter, Operation* op,
1355                                        Value result_value, Value logits_value,
1356                                        double beta) {
1357   // softmax = exp(logits) / reduce_sum(exp(logits), -1)
1358   //
1359   // or, equivalently, multiplying both the numerator and the denominator by
1360   // exp(-max(logits)), we get:
1361   //
1362   // softmax = exp(logits - max(logits)) / reduce_sum(exp(logits -
1363   // max(logits)), -1)
1364   //
1365   // We'll use the first version for the direct floating-point lowering, and
1366   // the second version for the quantized lowering, since the latter restricts
1367   // the input to exp() to be negative, so the LUT can always be within [0.0, 1.0].
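  // Illustrative floating-point example (beta = 1, values assumed): for
  // logits = [1.0, 2.0, 3.0], exp(logits) ~= [2.718, 7.389, 20.086],
  // reduce_sum ~= 30.19, and softmax ~= [0.090, 0.245, 0.665].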
1368   RankedTensorType output_type =
1369       result_value.getType().dyn_cast<RankedTensorType>();
1370   RankedTensorType input_type =
1371       logits_value.getType().dyn_cast<RankedTensorType>();
1372 
1373   // Not a ranked tensor input/output
1374   if (!output_type || !input_type) {
1375     op->emitOpError("Softmax: input and result not ranked tensors");
1376     return llvm::None;
1377   }
1378 
1379   // reduce_sum on last dimension
1380   int32_t input_rank = input_type.getShape().size();
1381   ArrayRef<int64_t> logits_shape = output_type.getShape();
1382 
1383   if (input_type.getElementType().isa<mlir::quant::QuantizedType>() &&
1384       output_type.getElementType().isa<mlir::quant::QuantizedType>()) {
1385     SmallVector<int64_t> rsum_shape_v(input_type.getShape().begin(),
1386                                       input_type.getShape().end() - 1);
1387     rsum_shape_v.push_back(1);
1388     ArrayRef<int64_t> rsum_shape(rsum_shape_v);
1389     // The if condition already checks if these are UQTs
1390     mlir::quant::UniformQuantizedType in_quant_type =
1391         input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
1392     mlir::quant::UniformQuantizedType out_quant_type =
1393         output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
1394 
1395     auto int16_element_qtype = mlir::quant::UniformQuantizedType::get(
1396         true, rewriter.getIntegerType(16), rewriter.getF32Type(), 1.0f, 0,
1397         -32768, 32767);
1398     RankedTensorType int16_logits_type =
1399         RankedTensorType::get(logits_shape, int16_element_qtype);
1400     RankedTensorType int32_logits_type =
1401         RankedTensorType::get(logits_shape, rewriter.getIntegerType(32));
1402     RankedTensorType int16_rsum_type =
1403         RankedTensorType::get(rsum_shape, int16_element_qtype);
1404     RankedTensorType int32_rsum_type =
1405         RankedTensorType::get(rsum_shape, rewriter.getIntegerType(32));
1406 
1407     if (in_quant_type.getStorageTypeIntegralWidth() == 8) {
1408       // Step 1. get x - max(x)
1409       Value op1_rescale_in =
1410           buildRescale(rewriter, op, int32_logits_type, logits_value, 1.0f,
1411                        in_quant_type.getZeroPoint(), 0, false, true);
1412 
1413       auto op2_reducemax_op1 = CreateOpAndInfer<tosa::ReduceMaxOp>(
1414           rewriter, op->getLoc(), int32_rsum_type, op1_rescale_in,
1415           rewriter.getI64IntegerAttr(input_rank - 1));
1416 
1417       auto op3_sub_op1_op2 = CreateOpAndInfer<tosa::SubOp>(
1418           rewriter, op->getLoc(), int32_logits_type, op1_rescale_in,
1419           op2_reducemax_op1.getResult());
1420 
1421       // Step 2. get exp() result
1422       // Implemented with four 16-bit table lookups.
1423       // We use a 16-bit table lookup, as the result of x - max(x) for
1424       // 8-bit values is a 9-bit value. We only use the bottom 8 bits of each
1425       // table to avoid having the slope between two 16-bit table entries be
1426       // greater than 16 bits, causing potential interpolation errors.
1427       auto exp_func = [](double x) -> double { return std::exp(x); };
1428 
1429       Value exp_table_const_01, exp_table_const_02, exp_table_const_03,
1430           exp_table_const_04;
1431       getTosaConst32bitTable(rewriter, op, beta * in_quant_type.getScale(), 0,
1432                              exp_func, exp_table_const_01, exp_table_const_02,
1433                              exp_table_const_03, exp_table_const_04);
1434 
1435       Value op4_rescale_op3 =
1436           buildRescale(rewriter, op, int16_logits_type,
1437                        op3_sub_op1_op2.getResult(), 128.0, 0, 0, false, true);
1438 
1439       // Input is 9.7, where lower 7 bits are all zeros.
1440       // Output is 23 bits, where lower 7 bits should be all zeros as well,
1441       // since there's no interpolation here.
1442       auto op5_table_op4_bits_31_24 = CreateOpAndInfer<tosa::TableOp>(
1443           rewriter, op->getLoc(), int32_logits_type, op4_rescale_op3,
1444           exp_table_const_01);
1445 
1446       auto op6_table_op4_bits_23_16 = CreateOpAndInfer<tosa::TableOp>(
1447           rewriter, op->getLoc(), int32_logits_type, op4_rescale_op3,
1448           exp_table_const_02);
1449 
1450       auto op7_table_op4_bits_15_8 = CreateOpAndInfer<tosa::TableOp>(
1451           rewriter, op->getLoc(), int32_logits_type, op4_rescale_op3,
1452           exp_table_const_03);
1453 
1454       auto op8_table_op4_bits_7_0 = CreateOpAndInfer<tosa::TableOp>(
1455           rewriter, op->getLoc(), int32_logits_type, op4_rescale_op3,
1456           exp_table_const_04);
1457 
1458       // To get the 16-bit upper/lower values, we need to right shift by 7 bits,
1459       // and then reconstruct the 32-bit value we need as (upper << 24) + lower.
1460       // So effectively we left shift the 1st group of 8 bits by 17 bits.
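      // Illustrative bit arithmetic (assumed example): each table output holds
      // 8 useful bits sitting above 7 zero fraction bits, so (t >> 7) << 24 is
      // the same as t << 17 and places those 8 bits into bits [31:24]; the same
      // reasoning yields the << 9, << 1 and >> 7 shifts used below.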
1461       auto op7_lshift_op5 = CreateOpAndInfer<tosa::LogicalLeftShiftOp>(
1462           rewriter, op->getLoc(), int32_logits_type,
1463           op5_table_op4_bits_31_24.getResult(),
1464           getTosaConstTensorSingleI32(rewriter, op, 17));
1465 
1466       // For the 2nd group of 8-bit, we need to >> 7 AND << 16 ==> << 9
1467       auto op8_lshift_op6 = CreateOpAndInfer<tosa::LogicalLeftShiftOp>(
1468           rewriter, op->getLoc(), int32_logits_type,
1469           op6_table_op4_bits_23_16.getResult(),
1470           getTosaConstTensorSingleI32(rewriter, op, 9));
1471 
1472       // For the 3rd 8-bit, we need to >> 7 AND << 8 ==> << 1
1473       auto op9_lshift_op7 = CreateOpAndInfer<tosa::LogicalLeftShiftOp>(
1474           rewriter, op->getLoc(), int32_logits_type,
1475           op7_table_op4_bits_15_8.getResult(),
1476           getTosaConstTensorSingleI32(rewriter, op, 1));
1477 
1478       // For the last 8-bit, we only need to >> 7
1479       auto op10_rshift_op8 = CreateOpAndInfer<tosa::ArithmeticRightShiftOp>(
1480           rewriter, op->getLoc(), int32_logits_type,
1481           op8_table_op4_bits_7_0.getResult(),
1482           getTosaConstTensorSingleI32(rewriter, op, 7), true);
1483 
1484       // Add together all the 8-bit groups
1485       // Add [1+2]
1486       auto op11_add_op7_op8 = CreateOpAndInfer<tosa::AddOp>(
1487           rewriter, op->getLoc(), int32_logits_type, op7_lshift_op5.getResult(),
1488           op8_lshift_op6.getResult());
1489 
1490       // Add [1+2+3]
1491       auto op12_add_op11_op9 = CreateOpAndInfer<tosa::AddOp>(
1492           rewriter, op->getLoc(), int32_logits_type,
1493           op11_add_op7_op8.getResult(), op9_lshift_op7.getResult());
1494 
1495       // Add [1+2+3+4]
1496       auto op13_add_op12_op10 = CreateOpAndInfer<tosa::AddOp>(
1497           rewriter, op->getLoc(), int32_logits_type,
1498           op12_add_op11_op9.getResult(), op10_rshift_op8.getResult());
1499 
1500       // Step 3. get sum(exp()). output 12.19
1501       auto op14_rshift_op13_12 = CreateOpAndInfer<tosa::ArithmeticRightShiftOp>(
1502           rewriter, op->getLoc(), int32_logits_type,
1503           op13_add_op12_op10.getResult(),
1504           getTosaConstTensorSingleI32(rewriter, op, 12), true);
1505 
1506       auto op15_reducesum_op14 = CreateOpAndInfer<tosa::ReduceSumOp>(
1507           rewriter, op->getLoc(), int32_rsum_type,
1508           op14_rshift_op13_12.getResult(),
1509           rewriter.getI64IntegerAttr(input_rank - 1));
1510 
1511       // Step 4. calculate reciprocal(sum(exp()))
1512       // CLZ returns the number of leading zeros, which equals headroom + 1
1513       auto op16_clz_op15 =
1514           CreateOpAndInfer<tosa::ClzOp>(rewriter, op->getLoc(), int32_rsum_type,
1515                                         op15_reducesum_op14.getResult());
1516 
1517       // minus one to get headroom
1518       auto op17_sub_op16 = CreateOpAndInfer<tosa::SubOp>(
1519           rewriter, op->getLoc(), int32_rsum_type, op16_clz_op15.getResult(),
1520           getTosaConstTensorSingleI32(rewriter, op, 1));
1521 
1522       // Left shift to get s1.30 format
1523       auto op18_lshift_op15_op17 = CreateOpAndInfer<tosa::LogicalLeftShiftOp>(
1524           rewriter, op->getLoc(), int32_rsum_type,
1525           op15_reducesum_op14.getResult(), op17_sub_op16.getResult());
1526 
1527       // Step 5. Calculate one_over_one_plus_x() with Newton-Raphson division
1528       // with 3 iterations.
1529       // Need two magic constants 48/17 and -32/17 from Newton-Raphson algorithm
1530       // We need to operate in s2.29 since 48/17 is > 2.0
1531       // Reference: gemmlowp/fixedpoint/fixedpoint.h
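      // Sketch of the recurrence the loop below evaluates (real-number view;
      // the fixed-point code rescales with shifts, and the multiply by `four`
      // implements the Rescale<2> step noted in the loop body):
      //   x_0     = 48/17 - (32/17) * half_denominator
      //   x_{n+1} = x_n + x_n * (1 - half_denominator * x_n)
      // which converges towards 1 / half_denominator.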
1532       Value half_denominator = op18_lshift_op15_op17.getResult();
1533       Value four = getTosaConstTensorSingleI32(rewriter, op, 4);
1534       Value F2_one = getTosaConstTensorSingleI32(rewriter, op, (1U << 29));
1535       Value constant_48_over_17 =
1536           getTosaConstTensorSingleI32(rewriter, op, 1515870810);
1537       Value constant_neg_32_over_17 =
1538           getTosaConstTensorSingleI32(rewriter, op, -1010580540);
1539 
1540       // F2 x = constant_48_over_17 + half_denominator *
1541       // constant_neg_32_over_17;
1542       auto op19_mul_half_denominator = CreateOpAndInfer<tosa::MulOp>(
1543           rewriter, op->getLoc(), int32_rsum_type, half_denominator,
1544           constant_neg_32_over_17, 31);
1545 
1546       auto op20_add_op19 = CreateOpAndInfer<tosa::AddOp>(
1547           rewriter, op->getLoc(), int32_rsum_type,
1548           op19_mul_half_denominator.getResult(), constant_48_over_17);
1549 
1550       // Newton-Raphson 3x iteration
1551       Value nr_x = op20_add_op19.getResult();
1552       for (int i = 0; i < 3; i++) {
1553         // half_denominator_times_x =
1554         // SaturatingRoundingDoublingHighMul(half_denominator, x)
1555         auto op21_mul_x_half_denominator = CreateOpAndInfer<tosa::MulOp>(
1556             rewriter, op->getLoc(), int32_rsum_type, nr_x, half_denominator,
1557             31);
1558 
1559         // F2 one_minus_half_denominator_times_x = F2::One() -
1560         // half_denominator_times_x
1561         auto op22_sub_one_op21 = CreateOpAndInfer<tosa::SubOp>(
1562             rewriter, op->getLoc(), int32_rsum_type, F2_one,
1563             op21_mul_x_half_denominator.getResult());
1564 
1565         // SaturatingRoundingDoublingHighMul(x,
1566         // one_minus_half_denominator_times_x)
1567         auto op23_mul_x_op22 = CreateOpAndInfer<tosa::MulOp>(
1568             rewriter, op->getLoc(), int32_rsum_type, nr_x,
1569             op22_sub_one_op21.getResult(), 31);
1570 
1571         // x + Rescale<2>(x * one_minus_half_denominator_times_x)
1572         auto op24_mul_op23_four = CreateOpAndInfer<tosa::MulOp>(
1573             rewriter, op->getLoc(), int32_rsum_type,
1574             op23_mul_x_op22.getResult(), four, 0);
1575 
1576         auto op25_add_x_op24 = CreateOpAndInfer<tosa::AddOp>(
1577             rewriter, op->getLoc(), int32_rsum_type, nr_x,
1578             op24_mul_op23_four.getResult());
1579 
1580         nr_x = op25_add_x_op24.getResult();
1581       }
1582 
1583       // Step 6. multiply exp(x) with 1 / sum(exp(x))
1584       // combined with Rescale<0>(ExactMulByPot<-1>(x))
1585       // so shift 30 instead of 31
1586       auto op26_mul_op13_x = CreateOpAndInfer<tosa::MulOp>(
1587           rewriter, op->getLoc(), int32_logits_type,
1588           op13_add_op12_op10.getResult(), nr_x, 31 - 1);
1589 
1590       // Right shift amount is
1591       // num_bits_over_unit + 31 - (sizeof(OutputT) * 8) =
1592       // (12 - headroom_plus_one) + 31 - 8 =
1593       // (12 + 31 - 8) - headroom_plus_one
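      // Illustrative values (assumed): if op16_clz_op15 yields 5, then
      // headroom_plus_one = 5 and the right shift amount computed below is
      // (12 + 31 - 8) - 5 = 30.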
1594       auto op27_sub_op16 = CreateOpAndInfer<tosa::SubOp>(
1595           rewriter, op->getLoc(), int32_rsum_type,
1596           getTosaConstTensorSingleI32(rewriter, op, 12 + 31 - 8),
1597           op16_clz_op15.getResult());
1598 
1599       auto op28_rshift_op26_op27 =
1600           CreateOpAndInfer<tosa::ArithmeticRightShiftOp>(
1601               rewriter, op->getLoc(), int32_logits_type,
1602               op26_mul_op13_x.getResult(), op27_sub_op16.getResult(), true);
1603 
1604       return buildRescale(rewriter, op, output_type,
1605                           op28_rshift_op26_op27.getResult(), 1.0, 0,
1606                           out_quant_type.getZeroPoint(), false, true);
1607 
1608     } else if (in_quant_type.getStorageTypeIntegralWidth() == 16) {
1609       // Step 1. get x - max(x)
1610       Value op1_rescale_in =
1611           buildRescale(rewriter, op, int32_logits_type, logits_value, 1.0f,
1612                        in_quant_type.getZeroPoint(), 0, false, true);
1613 
1614       auto op2_reducemax_op1 = CreateOpAndInfer<tosa::ReduceMaxOp>(
1615           rewriter, op->getLoc(), int32_rsum_type, op1_rescale_in,
1616           rewriter.getI64IntegerAttr(input_rank - 1));
1617 
1618       // output range is [-65535, 0]
1619       auto op3_sub_op1_op2 = CreateOpAndInfer<tosa::SubOp>(
1620           rewriter, op->getLoc(), int32_logits_type, op1_rescale_in,
1621           op2_reducemax_op1.getResult());
1622 
1623       auto exp_func = [](double x) -> double { return std::exp(x); };
1624 
1625       // Follow TFLite reference: tensorflow/lite/kernels/activations.cc
1626       Value exp_table_const =
1627           getTosaConst16bitTable(rewriter, op, exp_func, -10.0, 0);
1628 
1629       double input_diff_scale = in_quant_type.getScale() / (10.0 / 65535.0);
1630 
1631       // Step 2. rescale input from [-65535, 0] to [-32768, 32767] for LUT input
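      // Illustrative mapping (assuming the rescaled difference spans the full
      // [-65535, 0] range): -65535 + 32767 = -32768 and 0 + 32767 = 32767, so
      // the rescale plus the AddOp below recenters the value into the signed
      // 16-bit LUT input range.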
1632       Value op4_rescale_op3 = buildRescale(
1633           rewriter, op, int32_logits_type, op3_sub_op1_op2.getResult(),
1634           /*scale=*/input_diff_scale, /*input_zp=*/0, /*output_zp=*/0,
1635           /*double_round=*/true, /*scale32=*/true);
1636       auto op5_add_op4 = CreateOpAndInfer<tosa::AddOp>(
1637           rewriter, op->getLoc(), int32_logits_type, op4_rescale_op3,
1638           getTosaConstTensorSingleI32(rewriter, op, 32767));
1639 
1640       auto op6_cast_op5 = CreateOpAndInfer<tosa::CastOp>(
1641           rewriter, op->getLoc(), int16_logits_type, op5_add_op4.getResult());
1642 
1643       // Step 3. get exp() result
1644       // Output is 15.7.
1645       // In 8-bit case, no interpolation here, since input should be right on
1646       // table entry.
1647       auto op7_table_op6 = CreateOpAndInfer<tosa::TableOp>(
1648           rewriter, op->getLoc(), int32_logits_type, op6_cast_op5,
1649           exp_table_const);
1650 
1651       // Right shift 7 bits. output 15. Shouldn't lose any precision since last
1652       // 7 bits should be all 0.
1653       auto op8_rshift_op7 = CreateOpAndInfer<tosa::ArithmeticRightShiftOp>(
1654           rewriter, op->getLoc(), int32_logits_type, op7_table_op6.getResult(),
1655           getTosaConstTensorSingleI32(rewriter, op, 7), true);
1656 
1657       // Step 4. get sum(exp()). output 16.15
1658       auto op9_reducesum_op8 = CreateOpAndInfer<tosa::ReduceSumOp>(
1659           rewriter, op->getLoc(), int32_rsum_type, op8_rshift_op7.getResult(),
1660           rewriter.getI64IntegerAttr(input_rank - 1));
1661 
1662       // Step 5. calculate reciprocal(sum(exp()))
1663       // CLZ returns 32 minus the (1-based) position of the most significant set bit
1664       auto op10_clz_op9 =
1665           CreateOpAndInfer<tosa::ClzOp>(rewriter, op->getLoc(), int32_rsum_type,
1666                                         op9_reducesum_op8.getResult());
1667 
1668       auto op11_sub_op10 = CreateOpAndInfer<tosa::SubOp>(
1669           rewriter, op->getLoc(), int32_rsum_type, op10_clz_op9.getResult(),
1670           getTosaConstTensorSingleI32(rewriter, op, 1));
1671 
1672       // Left shift to get 1.30 format
1673       auto op12_lshift_op9_op11 = CreateOpAndInfer<tosa::LogicalLeftShiftOp>(
1674           rewriter, op->getLoc(), int32_rsum_type,
1675           op9_reducesum_op8.getResult(), op11_sub_op10.getResult());
1676 
1677       // Subtract (1 << 30) to make 0 <= x <= 1 under 0.30 format
1678       auto op13_sub_op12 = CreateOpAndInfer<tosa::SubOp>(
1679           rewriter, op->getLoc(), int32_rsum_type,
1680           op12_lshift_op9_op11.getResult(),
1681           getTosaConstTensorSingleI32(rewriter, op, (1u << 30)));
1682 
1683       // Right shift 14 bits to get output range [0, 65535]
1684       auto op14_rshift_op13 = CreateOpAndInfer<tosa::ArithmeticRightShiftOp>(
1685           rewriter, op->getLoc(), int32_rsum_type, op13_sub_op12.getResult(),
1686           getTosaConstTensorSingleI32(rewriter, op, 14), true);
1687 
1688       // Remap input to [-32768, 32767] for LUT input
1689       auto op15_sub_op14 = CreateOpAndInfer<tosa::SubOp>(
1690           rewriter, op->getLoc(), int32_rsum_type, op14_rshift_op13.getResult(),
1691           getTosaConstTensorSingleI32(rewriter, op, 32768));
1692       auto op16_cast_op15 = CreateOpAndInfer<tosa::CastOp>(
1693           rewriter, op->getLoc(), int16_rsum_type, op15_sub_op14.getResult());
1694 
1695       // Generate table for 1 / (1 + x), for 0 <= x <= 1
1696       auto one_over_one_plus_x_func = [](double x) -> double {
1697         return 1.0 / (1.0 + x);
1698       };
1699 
1700       Value one_over_one_plus_x_table_const = getTosaConst16bitTable(
1701           rewriter, op, one_over_one_plus_x_func, 0.0, 1.0);
1702 
1703       // Get (1 / sum(exp(x))) result as 23 bits (including sign bit)
1704       auto op17_table_op16 = CreateOpAndInfer<tosa::TableOp>(
1705           rewriter, op->getLoc(), int32_rsum_type, op16_cast_op15,
1706           one_over_one_plus_x_table_const);
1707 
1708       // Right shift 7 bits back to 0.15
1709       auto op18_rshift_op17 = CreateOpAndInfer<tosa::ArithmeticRightShiftOp>(
1710           rewriter, op->getLoc(), int32_rsum_type, op17_table_op16.getResult(),
1711           getTosaConstTensorSingleI32(rewriter, op, 7), true);
1712 
1713       // Step 6. multiply exp(x - max(x)) with 1 / sum(exp(x - max(x)))
1714       // lhs: 0.15, rhs: 0.15, output: 0.30
1715       auto op19_mul_op18_op8 = CreateOpAndInfer<tosa::MulOp>(
1716           rewriter, op->getLoc(), int32_logits_type, op18_rshift_op17,
1717           op8_rshift_op7, 0);
1718 
1719       auto op20_sub_op10 = CreateOpAndInfer<tosa::SubOp>(
1720           rewriter, op->getLoc(), int32_rsum_type,
1721           getTosaConstTensorSingleI32(rewriter, op, 31),
1722           op10_clz_op9.getResult());
1723 
1724       // Apply the clz back, we get 0.15 output
1725       // [0, 32767] corresponding to [0.0, 1.0]
1726       auto op21_rshift_op19_op20 =
1727           CreateOpAndInfer<tosa::ArithmeticRightShiftOp>(
1728               rewriter, op->getLoc(), int32_logits_type,
1729               op19_mul_op18_op8.getResult(), op20_sub_op10.getResult(), true);
1730 
1731       return buildRescale(rewriter, op, output_type,
1732                           op21_rshift_op19_op20.getResult(),
1733                           (1.0 / out_quant_type.getScale()) * (1.0 / 32768.0),
1734                           0, out_quant_type.getZeroPoint(), false, true);
1735     } else {
1736       op->emitOpError("Softmax: unknown quantization bitwidth");
1737       return llvm::None;
1738     }
1739   } else {
1740     SmallVector<int64_t> rsum_shape_v(input_type.getShape().begin(),
1741                                       input_type.getShape().end());
1742     rsum_shape_v[input_rank - 1] = 1;
1743     ArrayRef<int64_t> rsum_shape(rsum_shape_v);
1744 
1745     // Floating-point lowering is more direct:
1746     //
1747     // op1 = exp(logits)
1748     // op2 = reduce_sum(op1, -1)
1749     // op3 = reciprocal(op2)
1750     // op4 = mul(op1, op3)
1751     auto op1_exp_in = CreateOpAndInfer<tosa::ExpOp>(rewriter, op->getLoc(),
1752                                                     output_type, logits_value);
1753     RankedTensorType rsum_type =
1754         RankedTensorType::get(rsum_shape, output_type.getElementType());
1755 
1756     // Keep dims so we don't need to reshape later
1757     auto op2_reducesum_op1 = CreateOpAndInfer<tosa::ReduceSumOp>(
1758         rewriter, op->getLoc(), rsum_type, op1_exp_in.getResult(),
1759         rewriter.getI64IntegerAttr(input_rank - 1));
1760     auto op3_reciprocal_op2 = CreateOpAndInfer<tosa::ReciprocalOp>(
1761         rewriter, op->getLoc(), op2_reducesum_op1.getType(),
1762         op2_reducesum_op1.getResult());
1763 
1764     return CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
1765                                          op1_exp_in.getResult(),
1766                                          op3_reciprocal_op2.getResult(), 0)
1767         .getResult();
1768   }
1769 }
1770 
1771 // Lowers LogSoftmax to a sequence of TOSA ops.
1772 llvm::Optional<Value> convertLogSoftmaxOp(PatternRewriter& rewriter,
1773                                           Operation* op, Value result_value,
1774                                           Value logits_value) {
1775   // log_softmax = log(exp(logits) / reduce_sum(exp(logits), -1))
1776   // op1 = exp(logits)
1777   // op2 = reduce_sum(op1, -1)
1778   // op3 = reciprocal(op2)
1779   // op4 = mul(op1, op3)
1780   // op5 = log(op4)
1781 
1782   TensorType output_type = result_value.getType().dyn_cast<TensorType>();
1783   // Not a tensor output
1784   if (!output_type) {
1785     op->emitOpError("LogSoftmax: output type not tensor.");
1786     return llvm::None;
1787   }
1788 
1789   RankedTensorType input_type =
1790       op->getOperand(0).getType().dyn_cast<RankedTensorType>();
1791   if (!input_type) {
1792     op->emitOpError("LogSoftmax: input type not ranked tensor.");
1793     return llvm::None;
1794   }
1795 
1796   mlir::quant::UniformQuantizedType in_quant_type =
1797       input_type.getElementType()
1798           .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
1799   mlir::quant::UniformQuantizedType out_quant_type =
1800       output_type.getElementType()
1801           .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
1802   if (in_quant_type || out_quant_type) {
1803     op->emitOpError("Quantized log_softmax lowering not implemented yet");
1804     return llvm::None;
1805   }
1806 
1807   auto op1_exp_in = CreateOpAndInfer<tosa::ExpOp>(rewriter, op->getLoc(),
1808                                                   output_type, logits_value);
1809 
1810   // reduce_sum on last dimension
1811   int32_t input_rank = input_type.getShape().size();
1812   // Keep dims so we don't need to reshape later
1813   auto op2_reducesum_op1 = CreateOpAndInfer<tosa::ReduceSumOp>(
1814       rewriter, op->getLoc(),
1815       UnrankedTensorType::get(output_type.getElementType()),
1816       op1_exp_in.getResult(), rewriter.getI64IntegerAttr(input_rank - 1));
1817   auto op3_reciprocal_op2 = CreateOpAndInfer<tosa::ReciprocalOp>(
1818       rewriter, op->getLoc(), op2_reducesum_op1.getType(),
1819       op2_reducesum_op1.getResult());
1820 
1821   auto op4_mul_op1_op3 = CreateOpAndInfer<tosa::MulOp>(
1822       rewriter, op->getLoc(), output_type, op1_exp_in.getResult(),
1823       op3_reciprocal_op2.getResult(), 0);
1824 
1825   return CreateOpAndInfer<tosa::LogOp>(rewriter, op->getLoc(), output_type,
1826                                        op4_mul_op1_op3.getResult())
1827       .getResult();
1828 }
1829 
1830 // Lowers SpaceToDepth to a sequence of TOSA ops.  Supports NHWC.
1831 llvm::Optional<Value> convertSpaceToDepthOp(PatternRewriter& rewriter,
1832                                             Operation* op, Value result_value,
1833                                             Value input_value,
1834                                             IntegerAttr block_size_attr,
1835                                             StringAttr data_format) {
1836   // NHWC lowering version:
1837   // a2 = tf.reshape(a, [orig_shape[0], orig_shape[1]//b, b, orig_shape[2]//b,
1838   // b, orig_shape[3]])
1839   // a3 = tf.transpose(a2, [0, 1, 3, 2, 4, 5])
1840   // a4 = tf.reshape(a3, [orig_shape[0], orig_shape[1]//b, orig_shape[2]//b,
1841   // orig_shape[3]*b*b])
1842   // return a4
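  // Illustrative example (shapes assumed): input shape [1, 4, 6, 3] with
  // block size b = 2 gives a2 of shape [1, 2, 2, 3, 2, 3], a3 of shape
  // [1, 2, 3, 2, 2, 3] after the [0, 1, 3, 2, 4, 5] transpose, and an output
  // (a4) of shape [1, 2, 3, 12].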
1843   RankedTensorType output_type =
1844       result_value.getType().dyn_cast<RankedTensorType>();
1845 
1846   // Not a ranked tensor output.
1847   if (!output_type) {
1848     op->emitOpError("SpaceToDepth: output type not ranked tensor.");
1849     return llvm::None;
1850   }
1851 
1852   RankedTensorType input_type =
1853       input_value.getType().dyn_cast<RankedTensorType>();
1854   if (!input_type) {
1855     op->emitOpError("SpaceToDepth: input type not ranked tensor.");
1856     return llvm::None;
1857   }
1858 
1859   if (input_type.getRank() != 4) {
1860     op->emitOpError("SpaceToDepth: input rank not 4.");
1861     return llvm::None;
1862   }
1863 
1864   auto input_shape = input_type.getShape();
1865 
1866   if (!block_size_attr) {  // This is a required parameter
1867     op->emitOpError("SpaceToDepth: block size attribute not set.");
1868     return llvm::None;
1869   }
1870 
1871   SmallVector<int64_t, 2> block_size;
1872   block_size.assign(2, block_size_attr.getInt());
1873 
1874   if (!data_format) data_format = rewriter.getStringAttr("NHWC");
1875 
1876   if (data_format.getValue().str() != "NHWC") {
1877     op->emitOpError("SpaceToDepth: data format not NHWC.");
1878     return llvm::None;
1879   }
1880 
1881   assert(block_size[0] * block_size[1] != 0);
1882 
1883   SmallVector<int64_t, 6> a_reshape_dims;
1884   a_reshape_dims.push_back(input_shape[0]);
1885   a_reshape_dims.push_back(input_shape[1] / block_size[0]);
1886   a_reshape_dims.push_back(block_size[0]);
1887   a_reshape_dims.push_back(input_shape[2] / block_size[1]);
1888   a_reshape_dims.push_back(block_size[1]);
1889   a_reshape_dims.push_back(input_shape[3]);
1890 
1891   RankedTensorType a_reshape_output_type =
1892       RankedTensorType::get(a_reshape_dims, output_type.getElementType());
1893   auto a2_reshape_a_op = CreateOpAndInfer<tosa::ReshapeOp>(
1894       rewriter, op->getLoc(), a_reshape_output_type, input_value,
1895       rewriter.getI64ArrayAttr(a_reshape_dims));
1896 
1897   llvm::Optional<Value> a3_transpose_perm = getConstTensor<int32_t>(
1898       rewriter, op, /*vec=*/{0, 1, 3, 2, 4, 5}, /*shape=*/{6});
1899 
1900   if (!a3_transpose_perm) return llvm::None;
1901 
1902   auto a3_transpose_a2_op = CreateOpAndInfer<tosa::TransposeOp>(
1903       rewriter, op->getLoc(), a_reshape_output_type,
1904       a2_reshape_a_op.getResult(), a3_transpose_perm.getValue());
1905 
1906   SmallVector<int64_t, 4> a3_reshape_dims;
1907   a3_reshape_dims.push_back(input_shape[0]);
1908   a3_reshape_dims.push_back(input_shape[1] / block_size[0]);
1909   a3_reshape_dims.push_back(input_shape[2] / block_size[1]);
1910   a3_reshape_dims.push_back(input_shape[3] * block_size[0] * block_size[1]);
1911 
1912   RankedTensorType a3_reshape_output_type =
1913       RankedTensorType::get(a3_reshape_dims, output_type.getElementType());
1914   return CreateOpAndInfer<tosa::ReshapeOp>(
1915              rewriter, op->getLoc(), a3_reshape_output_type,
1916              a3_transpose_a2_op.getResult(),
1917              rewriter.getI64ArrayAttr(a3_reshape_dims))
1918       .getResult();
1919 }
1920 
1921 // Lowers DepthToSpace to a sequence of TOSA ops.  Supports NHWC.
1922 llvm::Optional<Value> convertDepthToSpaceOp(PatternRewriter& rewriter,
1923                                             Operation* op, Value result_value,
1924                                             Value input_value,
1925                                             IntegerAttr block_size_attr,
1926                                             StringAttr data_format) {
1927   // NHWC version
1928   // a2 = tf.reshape(a, [orig_shape[0], orig_shape[1], orig_shape[2], b, b,
1929   // orig_shape[3] // (b*b)])
1930   // a3 = tf.transpose(a2, [0, 1, 3, 2, 4, 5])
1931   // a4 = tf.reshape(a3, [orig_shape[0], orig_shape[1] * b, orig_shape[2] * b,
1932   // orig_shape[3] // (b*b)])
1933   // return a4
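  // Illustrative example (shapes assumed): input shape [1, 2, 3, 12] with
  // block size b = 2 gives a2 of shape [1, 2, 3, 2, 2, 3], a3 of shape
  // [1, 2, 2, 3, 2, 3] after the [0, 1, 3, 2, 4, 5] transpose, and an output
  // (a4) of shape [1, 4, 6, 3], i.e. the inverse of the SpaceToDepth example above.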
1934 
1935   RankedTensorType output_type =
1936       result_value.getType().dyn_cast<RankedTensorType>();
1937 
1938   // Not a ranked tensor output
1939   if (!output_type) {
1940     op->emitOpError("DepthToSpace: output type not ranked tensor.");
1941     return llvm::None;
1942   }
1943 
1944   RankedTensorType input_type =
1945       input_value.getType().dyn_cast<RankedTensorType>();
1946   if (!input_type) {
1947     op->emitOpError("DepthToSpace: input type not ranked tensor.");
1948     return llvm::None;
1949   }
1950 
1951   if (input_type.getRank() != 4) return llvm::None;
1952   auto input_shape = input_type.getShape();
1953 
1954   if (!block_size_attr) {  // This is a required parameter
1955     op->emitOpError("DepthToSpace: block size attribute not set.");
1956     return llvm::None;
1957   }
1958 
1959   SmallVector<int64_t, 2> block_size;
1960   block_size.assign(2, block_size_attr.getInt());
1961 
1962   if (!data_format) data_format = rewriter.getStringAttr("NHWC");
1963   if (data_format.getValue().str() != "NHWC") {
1964     op->emitOpError("DepthToSpace: data format not NHWC.");
1965     return llvm::None;
1966   }
1967 
1968   assert(block_size[0] * block_size[1] != 0);
1969 
1970   SmallVector<int64_t, 6> a_reshape_dims;
1971   a_reshape_dims.push_back(input_shape[0]);
1972   a_reshape_dims.push_back(input_shape[1]);
1973   a_reshape_dims.push_back(input_shape[2]);
1974   a_reshape_dims.push_back(block_size[0]);
1975   a_reshape_dims.push_back(block_size[1]);
1976   a_reshape_dims.push_back(input_shape[3] / (block_size[0] * block_size[1]));
1977 
1978   RankedTensorType a_reshape_output_type =
1979       RankedTensorType::get(a_reshape_dims, output_type.getElementType());
1980   auto a2_reshape_a_op = CreateOpAndInfer<tosa::ReshapeOp>(
1981       rewriter, op->getLoc(), a_reshape_output_type, input_value,
1982       rewriter.getI64ArrayAttr(a_reshape_dims));
1983 
1984   llvm::Optional<Value> a3_transpose_perm = getConstTensor<int32_t>(
1985       rewriter, op, /*vec=*/{0, 1, 3, 2, 4, 5}, /*shape=*/{6});
1986 
1987   if (!a3_transpose_perm) return llvm::None;
1988 
1989   auto a3_transpose_a2_op = CreateOpAndInfer<tosa::TransposeOp>(
1990       rewriter, op->getLoc(), a_reshape_output_type,
1991       a2_reshape_a_op.getResult(), a3_transpose_perm.getValue());
1992 
1993   SmallVector<int64_t, 4> a3_reshape_dims;
1994   a3_reshape_dims.push_back(input_shape[0]);
1995   a3_reshape_dims.push_back(input_shape[1] * block_size[0]);
1996   a3_reshape_dims.push_back(input_shape[2] * block_size[1]);
1997   a3_reshape_dims.push_back(input_shape[3] / (block_size[0] * block_size[1]));
1998 
1999   RankedTensorType a3_reshape_output_type =
2000       RankedTensorType::get(a3_reshape_dims, output_type.getElementType());
2001   return CreateOpAndInfer<tosa::ReshapeOp>(
2002              rewriter, op->getLoc(), a3_reshape_output_type,
2003              a3_transpose_a2_op.getResult(),
2004              rewriter.getI64ArrayAttr(a3_reshape_dims))
2005       .getResult();
2006 }
2007 
2008 // Lowers Split to a sequence of TOSA ops.
2009 llvm::Optional<SmallVector<Value>> convertSplitOp(
2010     PatternRewriter& rewriter, Operation* op, Value result_value,
2011     Value input_value, int32_t num_split, int32_t axis) {
2012   // This lowering creates num_split slice ops, one per output, and returns
2013   // them as a list of result tensors. A dynamic dimension along the split
2014   // axis is first reshaped into (num_split, -1) so it can be sliced evenly.
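  // Illustrative example (values assumed): splitting a [2, 9] tensor with
  // num_split = 3 along axis = 1 produces three tosa.slice ops, each of size
  // [2, 3], with begin values [0, 0], [0, 3] and [0, 6].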
2015   RankedTensorType result_type =
2016       result_value.getType().dyn_cast<RankedTensorType>();
2017   // Not a ranked tensor output
2018   if (!result_type) {
2019     op->emitOpError("Split: output type not ranked tensor.");
2020     return llvm::None;
2021   }
2022 
2023   RankedTensorType input_type =
2024       input_value.getType().dyn_cast<RankedTensorType>();
2025   if (!input_type) {
2026     op->emitOpError("Split: input type not ranked tensor.");
2027     return llvm::None;
2028   }
2029 
2030   Type etype = input_type.getElementType();
2031 
2032   if (axis < 0) axis += input_type.getRank();
2033   assert(num_split > 0);
2034   assert(axis >= 0 && axis < input_type.getRank());
2035 
2036   Value slice_value = input_value;
2037   bool is_dyn_split = input_type.getDimSize(axis) == -1;
2038   if (is_dyn_split) {
2039     SmallVector<int64_t> new_shape;
2040     for (int i = 0, s = input_type.getRank(); i < s; i++) {
2041       if (i != axis) {
2042         new_shape.push_back(input_type.getDimSize(i));
2043         continue;
2044       }
2045 
2046       new_shape.push_back(num_split);
2047       new_shape.push_back(-1);
2048     }
2049     slice_value = CreateOpAndInfer<tosa::ReshapeOp>(
2050         rewriter, op->getLoc(), RankedTensorType::get(new_shape, etype),
2051         input_value, rewriter.getI64ArrayAttr(new_shape));
2052   }
2053 
2054   RankedTensorType slice_type = slice_value.getType().cast<RankedTensorType>();
2055   assert((slice_type.getDimSize(axis) % num_split) == 0);
2056 
2057   // Each slice has a different beginning point.
2058   // The slice size is actually the same each op.
2059   SmallVector<int64_t> begin_vals, size_vals;
2060   for (int j = 0, s = slice_type.getRank(); j < s; j++) {
2061     begin_vals.push_back(0);
2062     size_vals.push_back(slice_type.getDimSize(j));
2063   }
2064   size_vals[axis] = size_vals[axis] / num_split;
2065 
2066   SmallVector<Value> results_vec;
2067   for (int i = 0; i < num_split; i++) {
2068     begin_vals[axis] = i * size_vals[axis];
2069     ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
2070     ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);
2071 
2072     Value result = CreateOpAndInfer<tosa::SliceOp>(
2073         rewriter, op->getLoc(),
2074         RankedTensorType::get(size_vals, result_type.getElementType()),
2075         slice_value, begin, size);
2076 
2077     if (is_dyn_split) {
2078       SmallVector<int64_t> out_reshape_shape;
2079       for (int i = 0, s = size_vals.size(); i < s; i++)
2080         if (i != axis) out_reshape_shape.push_back(size_vals[i]);
2081 
2082       result = CreateOpAndInfer<tosa::ReshapeOp>(
2083                    rewriter, op->getLoc(),
2084                    RankedTensorType::get(out_reshape_shape, etype), result,
2085                    rewriter.getI64ArrayAttr(out_reshape_shape))
2086                    .getResult();
2087     }
2088 
2089     results_vec.push_back(result);
2090   }
2091 
2092   return results_vec;
2093 }
2094 
2095 // Lowers SplitV to a sequence of TOSA ops.
2096 llvm::Optional<SmallVector<Value>> convertSplitVOp(
2097     PatternRewriter& rewriter, Operation* op, Value result_value,
2098     Value input_value, SmallVectorImpl<int32_t>& size_split, int32_t axis) {
2099   // This lowering creates one slice op per entry in size_split and returns
2100   // them as a list of result tensors. Each slice starts where the previous
2101   // one ended along the split axis.
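  // Illustrative example (values assumed): size_split = {2, 3, 5} along axis 0
  // of a [10, 4] tensor produces slices of shapes [2, 4], [3, 4] and [5, 4]
  // with begin values [0, 0], [2, 0] and [5, 0].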
2102   RankedTensorType result_type =
2103       result_value.getType().dyn_cast<RankedTensorType>();
2104   // Not a ranked tensor output
2105   if (!result_type) {
2106     op->emitOpError("SplitV: output type not ranked tensor.");
2107     return llvm::None;
2108   }
2109 
2110   RankedTensorType input_type =
2111       input_value.getType().dyn_cast<RankedTensorType>();
2112   if (!input_type) {
2113     op->emitOpError("SplitV: input type not ranked tensor.");
2114     return llvm::None;
2115   }
2116 
2117   auto input_shape = input_type.getShape();
2118 
2119   SmallVector<Value> results_vec;
2120 
2121   if ((axis >= 0 && axis >= static_cast<int32_t>(input_shape.size())) ||
2122       (axis < 0 && axis < -static_cast<int32_t>(input_shape.size()))) {
2123     op->emitOpError("SplitV: invalid axis value.");
2124     return llvm::None;
2125   }
2126   axis = adjust_axis(axis, input_shape.size());
2127   int32_t size_split_sum = 0;
2128   for (int i = 0; i < size_split.size(); i++) {
2129     size_split_sum += size_split[i];
2130   }
2131 
2132   // The split sizes must sum up to the size of the axis being split
2133   assert(size_split_sum == input_shape[axis]);
2134 
2135   int32_t curr_split_start = 0;
2136   for (int i = 0; i < size_split.size(); i++) {
2137     // Each slice has a different beginning point.
2138     // The slice size is different for each op.
2139     SmallVector<int64_t> begin_vals, size_vals;
2140 
2141     for (int j = 0; j < input_shape.size(); j++) {
2142       if (j == axis) {
2143         begin_vals.push_back(curr_split_start);
2144         size_vals.push_back(size_split[i]);
2145       } else {
2146         begin_vals.push_back(0);
2147         size_vals.push_back(input_shape[j]);
2148       }
2149     }
2150 
2151     ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
2152     ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);
2153 
2154     auto slice_op = CreateOpAndInfer<tosa::SliceOp>(
2155         rewriter, op->getLoc(),
2156         RankedTensorType::get(size_vals, result_type.getElementType()),
2157         input_value, begin, size);
2158 
2159     results_vec.push_back(slice_op.getResult());
2160 
2161     // Next start position
2162     curr_split_start += size_split[i];
2163   }
2164 
2165   return results_vec;
2166 }
2167 
2168 // Helper function to reverse negative striding. Only checks for -1 as that is
2169 // the only legal negative stride.
2170 static Value reverseNegativeStride(PatternRewriter& rewriter, Operation* op,
2171                                    Value input, ArrayRef<int32_t> strides) {
2172   Type reverse_ty = UnrankedTensorType::get(
2173       input.getType().cast<ShapedType>().getElementType());
2174   for (auto it : llvm::enumerate(strides)) {
2175     auto axis = it.index();
2176     auto stride = it.value();
2177     if (stride != -1) continue;
2178 
2179     input = CreateOpAndInfer<tosa::ReverseOp>(rewriter, op->getLoc(),
2180                                               reverse_ty, input,
2181                                               rewriter.getI64IntegerAttr(axis))
2182                 .getResult();
2183   }
2184 
2185   return input;
2186 }
2187 
2188 // Lowers StridedSlice to a sequence of TOSA ops.
2189 llvm::Optional<Value> convertStridedSliceOp(
2190     PatternRewriter& rewriter, Operation* op, Value result_value,
2191     Value input_value, Value begin_value, Value end_value, Value strides_value,
2192     int32_t begin_mask, int32_t end_mask, int32_t ellipsis_mask,
2193     int32_t new_axis_mask, int32_t shrink_axis_mask) {
2194   // The mask arguments are bitmasks where bit [i] applies to
2195   // dimension [i] of the input tensor.
2196   //
2197   // The rough algorithm for lowering strided slice is as follows:
2198   //
2199   // 0. Process begin/end masks, since they are basically syntactic sugar
2200   // on top of the begin_value/end_value arrays
2201   //
2202   // 1. Slice1: Ignoring stride, slice the interesting range from the input
2203   // tensor
2204   //
2205   // 2. Reshape2: Reshape the tensor from (1) such that each dimension with
2206   // stride is split into two dimensions of size_i/stride_i, stride_i. A naive
2207   // implementation doubles the input tensor rank, but only dimensions being
2208   // strided actually need to be doubled.
2209   //
2210   // 3. Slice3: Slice the tensor from (2) such that we select index [0] from
2211   // each of the stride_i dimensions in (2)
2212   //
2213   // 4. Reshape4: Reshape the tensor to eliminate the stride_i dimensions, add
2214   // any dimensions in new_axis_mask and remove any dimensions in the
2215   // shrink_axis_mask
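  // Illustrative walk-through (values assumed): slicing a [7]-element tensor
  // with begin = [1], end = [7], strides = [2]:
  //   Slice1   -> begin [1], size [6]        (keeps elements 1..6)
  //   Reshape2 -> shape [3, 2]               (size/stride, stride)
  //   Slice3   -> begin [0, 0], size [3, 1]  (keeps index 0 of each pair)
  //   Reshape4 -> shape [3]                  (elements 1, 3 and 5 of the input)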
2216 
2217   // Limitations:
2218   // * This implementation only supports ellipsis_mask=0 for now.
2219   // * Negative strides are only supported for stride == -1; these are handled
2220   // by inserting tosa.Reverse operators (see reverseNegativeStride above).
2221   if (ellipsis_mask != 0) {
2222     (void)rewriter.notifyMatchFailure(op, "ellipsis mask not supported yet");
         return llvm::None;
2223   }
2224 
2225   ShapedType input_type = input_value.getType().cast<ShapedType>();
2226   ShapedType result_type = result_value.getType().cast<ShapedType>();
2227 
2228   // Extract the begin/end/stride tensors
2229   SmallVector<int32_t> begin, end, strides;
2230 
2231   DenseIntElementsAttr strides_attr;
2232 
2233   if (!matchPattern(strides_value, m_Constant(&strides_attr))) {
2234     (void)rewriter.notifyMatchFailure(op, "strides is not a constant");
2235     return llvm::None;
2236   }
2237 
2238   if (failed(getVectorFromValue32(strides_value, strides))) {
2239     (void)rewriter.notifyMatchFailure(op, "strides isn't a constant");
2240     return llvm::None;
2241   }
2242 
2243   // The current configuration does not support strides less than -1.
2244   // Bail out for now (fix if this proves to be legal).
2245   for (auto stride : strides)
2246     if (stride < -1) return llvm::None;
2247 
2248   bool all_strides_one = true;
2249   for (auto stride : strides) all_strides_one &= abs(stride) == 1;
2250 
2251   int32_t strides_size = strides_attr.getNumElements();
2252 
2253   // If all of the masks are set we can just bypass the entire thing.
2254   const int32_t all_masks_one = (1 << strides_size) - 1;
2255 
2256   if (failed(getVectorFromValue32(begin_value, begin)) &&
2257       begin_mask != all_masks_one) {
2258     (void)rewriter.notifyMatchFailure(op, "begin isn't a constant");
2259     return llvm::None;
2260   }
2261 
2262   if (failed(getVectorFromValue32(end_value, end)) &&
2263       end_mask != all_masks_one) {
2264     (void)rewriter.notifyMatchFailure(op, "end isn't a constant");
2265     return llvm::None;
2266   }
2267 
2268   // If begin value is 0 we can set the begin_mask instead.
2269   for (auto val : llvm::enumerate(begin)) {
2270     if (val.value() == 0) begin_mask |= (0x1 << val.index());
2271   }
2272 
2273   // If end value is -1 we can set the end_mask instead.
2274   for (const auto& val : llvm::enumerate(end)) {
2275     if (val.value() == -1) end_mask |= (0x1 << val.index());
2276   }
2277 
2278   if (all_strides_one && begin_mask == all_masks_one &&
2279       end_mask == all_masks_one) {
2280     return reverseNegativeStride(rewriter, op, input_value, strides);
2281   }
2282 
2283   if (!input_type.hasRank()) {
2284     return (void)rewriter.notifyMatchFailure(op,
2285                                              "input type not ranked tensor."),
2286            llvm::None;
2287   }
2288 
2289   int32_t input_rank = input_type.getRank();
2290 
2291   // If strides is incomplete, pad out to the full size.
2292   while (strides.size() < input_rank) strides.push_back(1);
2293 
2294   // Unspecified begins should set the begin mask.
2295   while (begin.size() < input_rank) {
2296     begin_mask = begin_mask | (1 << begin.size());
2297     begin.push_back(0);
2298   }
2299 
2300   // Unspecified ends should set the end mask.
2301   while (end.size() < input_rank) {
2302     end_mask = end_mask | (1 << end.size());
2303     end.push_back(-1);
2304   }
2305 
2306   auto input_shape = input_type.getShape();
2307 
2308   // Step 0: Process the begin/end masks and build the begin/sizes for the
2309   // first slice
2310   SmallVector<int64_t> a1_begin(input_rank), a1_size(input_rank);
2311   for (int i = 0; i < input_rank; ++i) {
2312     // Mask-bit overrides begin/end values.
2313     if (begin_mask & (1 << i)) begin[i] = 0;
2314     if (end_mask & (1 << i)) end[i] = input_shape[i];
2315 
2316     if (begin[i] < 0 && ShapedType::isDynamic(input_shape[i])) {
2317       (void)rewriter.notifyMatchFailure(
2318           op, "begin offset is negative on dynamic size.");
2319       return llvm::None;
2320     }
2321 
2322     // Wrap around index if begin and end is negative
2323     a1_begin[i] = begin[i];
2324     if (a1_begin[i] < 0) a1_begin[i] += input_shape[i];
2325 
2326     if (ShapedType::isDynamic(end[i]) &&
2327         ShapedType::isDynamic(input_shape[i])) {
2328       // Slice using -1 as TOSA's sentinel value.
2329       a1_size[i] = -1;
2330     } else if (end[i] < 0 && ShapedType::isDynamic(input_shape[i])) {
2331       // Other dynamic cases cannot be handled.
2332       (void)rewriter.notifyMatchFailure(
2333           op, "input dim is dynamic and slice end depends on the length.");
2334       return llvm::None;
2335     } else {
2336       a1_size[i] = end[i] - a1_begin[i];
2337       if (end[i] < 0) a1_size[i] += input_shape[i];
2338     }
2339 
2340     // Shrink axis mask means we know the size and stride are 1.
2341     if (shrink_axis_mask & (1 << i)) {
2342       a1_size[i] = 1;
2343       strides[i] = 1;
2344     }
2345   }
2346 
2347   // Step 1: Slice the input array
2348   auto a1_slice_op = CreateOpAndInfer<tosa::SliceOp>(
2349       rewriter, op->getLoc(),
2350       RankedTensorType::get(a1_size, input_type.getElementType()), input_value,
2351       rewriter.getI64ArrayAttr(a1_begin), rewriter.getI64ArrayAttr(a1_size));
2352 
2353   if (all_strides_one) {
2354     auto reversed =
2355         reverseNegativeStride(rewriter, op, a1_slice_op.getResult(), strides);
2356     auto shape = reversed.getType().cast<RankedTensorType>().getShape();
2357 
2358     SmallVector<int64_t> new_shape;
2359     for (int i = 0; i < input_rank; ++i) {
2360       if (!(shrink_axis_mask & (1 << i))) {
2361         if (new_axis_mask & (1 << i)) new_shape.push_back(1);
2362         new_shape.push_back((shape[i]));
2363       }
2364     }
2365 
2366     return CreateOpAndInfer<tosa::ReshapeOp>(
2367                rewriter, op->getLoc(), result_type, reversed,
2368                rewriter.getI64ArrayAttr(new_shape))
2369         .getResult();
2370   }
2371 
2372   // Step 2: reshape the sliced array
2373   SmallVector<int64_t> a2_shape(input_rank * 2);
2374   for (int i = 0; i < input_rank; ++i) {
2375     a2_shape[i * 2 + 0] = a1_size[i] == -1 ? -1 : a1_size[i] / abs(strides[i]);
2376     a2_shape[i * 2 + 1] = abs(strides[i]);
2377   }
2378 
2379   auto a2_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
2380       rewriter, op->getLoc(),
2381       RankedTensorType::get(a2_shape, input_type.getElementType()),
2382       a1_slice_op.getResult(), rewriter.getI64ArrayAttr(a2_shape));
2383 
2384   // Step 3: take a slice along the strides
2385   SmallVector<int64_t> a3_begin(input_rank * 2), a3_size(input_rank * 2);
2386   for (int i = 0; i < input_rank; ++i) {
2387     a3_begin[i * 2 + 0] = 0;
2388     a3_begin[i * 2 + 1] = 0;
2389 
2390     if (shrink_axis_mask & (1 << i)) {
2391       a3_size[i * 2 + 0] = 1;
2392     } else {
2393       a3_size[i * 2 + 0] =
2394           (a1_size[i] == -1) ? -1 : (a1_size[i] / abs(strides[i]));
2395     }
2396     a3_size[i * 2 + 1] = 1;
2397   }
2398 
2399   auto a3_slice_op = CreateOpAndInfer<tosa::SliceOp>(
2400       rewriter, op->getLoc(),
2401       RankedTensorType::get(a3_size, input_type.getElementType()),
2402       a2_reshape_op.getResult(), rewriter.getI64ArrayAttr(a3_begin),
2403       rewriter.getI64ArrayAttr(a3_size));
2404 
2405   // Step 4: reshape the now-strided tensor
2406   SmallVector<int64_t> a4_shape;
2407   for (int i = 0; i < input_rank; ++i) {
2408     if (!(shrink_axis_mask & (1 << i))) {
2409       if (new_axis_mask & (1 << i)) a4_shape.push_back(1);
2410       a4_shape.push_back(
2411           ((a1_size[i] == -1) ? -1 : (a1_size[i] / abs(strides[i]))));
2412     }
2413   }
2414 
2415   auto a4_reshape_op =
2416       CreateOpAndInfer<tosa::ReshapeOp>(rewriter, op->getLoc(), result_type,
2417                                         a3_slice_op.getResult(),
2418                                         rewriter.getI64ArrayAttr(a4_shape))
2419           .getResult();
2420 
2421   return reverseNegativeStride(rewriter, op, a4_reshape_op, strides);
2422 }
2423 
2424 // Lowers FloorDiv to a sequence of TOSA operators.
2425 llvm::Optional<Value> convertFloorDivOp(PatternRewriter& rewriter,
2426                                         Operation* op, Value result_value,
2427                                         Value lhs_value, Value rhs_value) {
2428   // FloorDiv lowering:
2429   // floor(1/rhs * lhs)
2430   //
2431   // a1 = reciprocal(rhs);
2432   // a2 = mul(lhs, a1);
2433   // a3 = floor(a2);
2434   // return a3;
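  // Illustrative values (assumed): lhs = 7.0, rhs = 2.0 gives a1 = 0.5,
  // a2 = 3.5 and a3 = floor(3.5) = 3.0. Integer inputs skip this sequence and
  // lower directly to tosa.div (see below).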
2435   ShapedType output_type = result_value.getType().dyn_cast<ShapedType>();
2436   // Not a shaped tensor output
2437   if (!output_type) return llvm::None;
2438 
2439   Type element_type = output_type.getElementType();
2440 
2441   if (element_type.isa<IntegerType>()) {
2442     return CreateOpAndInfer<tosa::DivOp>(rewriter, op->getLoc(), output_type,
2443                                          lhs_value, rhs_value)
2444         .getResult();
2445   }
2446 
2447   auto a1_reciprocal_rhs_op = CreateOpAndInfer<tosa::ReciprocalOp>(
2448       rewriter, op->getLoc(), rhs_value.getType(), rhs_value);
2449   auto a2_mul_lhs_a1_op = CreateOpAndInfer<tosa::MulOp>(
2450       rewriter, op->getLoc(), output_type, lhs_value,
2451       a1_reciprocal_rhs_op.getResult(), 0);
2452   return CreateOpAndInfer<tosa::FloorOp>(rewriter, op->getLoc(), output_type,
2453                                          a2_mul_lhs_a1_op.getResult())
2454       .getResult();
2455 }
2456 
2457 // Lowers FloorMod to a sequence of TOSA operators.
2458 llvm::Optional<Value> convertFloorModOp(PatternRewriter& rewriter,
2459                                         Operation* op, Value result_value,
2460                                         Value lhs_value, Value rhs_value) {
2461   // FloorMod lowering:
2462   // (1/rhs * lhs) - floor(1/rhs * lhs)
2463   // a1 = reciprocal(rhs);
2464   // a2 = mul(lhs, a1);
2465   // a3 = floor(a2);
2466   // a4 = sub(a2, a3);
2467   // return a4;
2468 
2469   RankedTensorType output_type =
2470       result_value.getType().dyn_cast<RankedTensorType>();
2471   // Not a ranked tensor output
2472   if (!output_type) return llvm::None;
2473 
2474   auto a1_reciprocal_rhs_op = CreateOpAndInfer<tosa::ReciprocalOp>(
2475       rewriter, op->getLoc(), rhs_value.getType(), rhs_value);
2476   auto a2_mul_lhs_a1_op = CreateOpAndInfer<tosa::MulOp>(
2477       rewriter, op->getLoc(), output_type, lhs_value,
2478       a1_reciprocal_rhs_op.getResult(), 0);
2479   auto a3_floor_a2_op = CreateOpAndInfer<tosa::FloorOp>(
2480       rewriter, op->getLoc(), output_type, a2_mul_lhs_a1_op.getResult());
2481   return CreateOpAndInfer<tosa::SubOp>(rewriter, op->getLoc(), output_type,
2482                                        a2_mul_lhs_a1_op.getResult(),
2483                                        a3_floor_a2_op.getResult())
2484       .getResult();
2485 }
2486 
2487 // Lowers FusedActivation to a sequence of TOSA ops.
2488 llvm::Optional<Value> convertFusedActivation(PatternRewriter& rewriter,
2489                                              Operation* op, Value input_value,
2490                                              StringAttr fused_activation_fn) {
2491   ShapedType input_type = input_value.getType().dyn_cast<ShapedType>();
2492   if (!input_type) return llvm::None;
2493 
2494   bool input_is_qtype =
2495       input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2496 
2497   if (input_is_qtype) {
2498     // We can always make the output tensor's scale/zp the same as the input's
2499     // when legalizing fused_activation_function, as it's generated during
2500     // legalization.
2501     auto input_qtype =
2502         input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2503 
2504     if (fused_activation_fn.getValue() == "NONE") {
2505       return input_value;
2506     } else if (fused_activation_fn.getValue() == "RELU") {
2507       int32_t quantized_0 = input_qtype.getZeroPoint();
2508       int32_t quantized_max = input_qtype.getStorageTypeMax();
2509 
2510       auto clamp_op = CreateOpAndInfer<tosa::ClampOp>(
2511           rewriter, op->getLoc(), input_type, input_value,
2512           rewriter.getI64IntegerAttr(quantized_0),
2513           rewriter.getI64IntegerAttr(quantized_max),
2514           rewriter.getF32FloatAttr(0), rewriter.getF32FloatAttr(0));
2515 
2516       return clamp_op.getResult();
2517     } else if (fused_activation_fn.getValue() == "RELU6") {
2518       int64_t quantized_0 = input_qtype.getZeroPoint();
2519       int64_t quantized_6 = std::llround((6.0f / input_qtype.getScale()) +
2520                                          input_qtype.getZeroPoint());
2521       int64_t quantized_min = input_qtype.getStorageTypeMin();
2522       int64_t quantized_max = input_qtype.getStorageTypeMax();
2523 
2524       int64_t clamp_min = std::max(quantized_0, quantized_min);
2525       int64_t clamp_max = std::min(quantized_6, quantized_max);
2526 
2527       auto clamp_op = CreateOpAndInfer<tosa::ClampOp>(
2528           rewriter, op->getLoc(), input_type, input_value,
2529           rewriter.getI64IntegerAttr(clamp_min),
2530           rewriter.getI64IntegerAttr(clamp_max), rewriter.getF32FloatAttr(0),
2531           rewriter.getF32FloatAttr(0));
2532 
2533       return clamp_op.getResult();
2534     } else if (fused_activation_fn.getValue() == "RELU_N1_TO_1") {
2535       int64_t quantized_max = input_qtype.getStorageTypeMax();
2536       int64_t quantized_min = input_qtype.getStorageTypeMin();
2537       int64_t quantized_n1 = std::llround((-1.0f / input_qtype.getScale()) +
2538                                           input_qtype.getZeroPoint());
2539       int64_t quantized_1 = std::llround((1.0f / input_qtype.getScale()) +
2540                                          input_qtype.getZeroPoint());
2541 
2542       int64_t clamp_min = std::max(quantized_n1, quantized_min);
2543       int64_t clamp_max = std::min(quantized_1, quantized_max);
2544 
2545       auto clamp_op = CreateOpAndInfer<tosa::ClampOp>(
2546           rewriter, op->getLoc(), input_type, input_value,
2547           rewriter.getI64IntegerAttr(clamp_min),
2548           rewriter.getI64IntegerAttr(clamp_max), rewriter.getF32FloatAttr(0),
2549           rewriter.getF32FloatAttr(0));
2550 
2551       return clamp_op.getResult();
2552     } else {
2553       op->emitWarning("convertFusedActivation: Not implemented yet");
2554       return llvm::None;
2555     }
2556   } else {
2557     if (fused_activation_fn.getValue() == "NONE") {
2558       return input_value;
2559     } else {
2560       // For non-quantized type, only support F32.
2561       if (!input_type.getElementType().isF32()) {
2562         op->emitOpError("convertFusedActivation: only F32 is supported");
2563         return llvm::None;
2564       }
2565 
2566       if (fused_activation_fn.getValue() == "RELU") {
2567         return CreateOpAndInfer<tosa::ClampOp>(
2568                    rewriter, op->getLoc(), input_type, input_value,
2569                    rewriter.getI64IntegerAttr(0),
2570                    rewriter.getI64IntegerAttr(
2571                        std::numeric_limits<int32_t>::max()),
2572                    rewriter.getF32FloatAttr(0.0f),
2573                    rewriter.getF32FloatAttr(std::numeric_limits<float>::max()))
2574             .getResult();
2575       } else if (fused_activation_fn.getValue() == "RELU6") {
2576         return CreateOpAndInfer<tosa::ClampOp>(
2577                    rewriter, op->getLoc(), input_type, input_value,
2578                    rewriter.getI64IntegerAttr(0), rewriter.getI64IntegerAttr(6),
2579                    rewriter.getF32FloatAttr(0.0f),
2580                    rewriter.getF32FloatAttr(6.0f))
2581             .getResult();
2582       } else if (fused_activation_fn.getValue() == "RELU_N1_TO_1") {
2583         return CreateOpAndInfer<tosa::ClampOp>(rewriter, op->getLoc(),
2584                                                input_type, input_value,
2585                                                rewriter.getI64IntegerAttr(-1),
2586                                                rewriter.getI64IntegerAttr(1),
2587                                                rewriter.getF32FloatAttr(-1.0),
2588                                                rewriter.getF32FloatAttr(1.0))
2589             .getResult();
2590       } else if (fused_activation_fn.getValue() == "TANH") {
2591         return CreateOpAndInfer<tosa::TanhOp>(rewriter, op->getLoc(),
2592                                               input_type, input_value)
2593             .getResult();
2594       } else {
2595         // Unsupported activation type. Bail out.
2596         return llvm::None;
2597       }
2598     }
2599   }
2600 
2601   return llvm::None;
2602 }
2603 
2604 template <typename T>
2605 static Value convertGenericReduceOp(PatternRewriter& rewriter, Operation* op,
2606                                     Value input, Type input_etype,
2607                                     Type reduce_etype,
2608                                     ArrayRef<int64_t> input_shape,
2609                                     ArrayRef<int64_t> axes) {
2610   Location loc = op->getLoc();
2611   int64_t input_rank = input_shape.size();
2612   llvm::SmallVector<int32_t> perms;
2613   llvm::SmallVector<int64_t> reshape_shape;
2614   perms.reserve(input_rank);
2615   reshape_shape.resize(2, 1);
2616 
2617   // First insert all non-reduction axes.
2618   for (int i = 0; i < input_rank; i++) {
2619     auto it = std::find(axes.begin(), axes.end(), i);
2620     if (it == axes.end()) {
2621       perms.push_back(i);
2622       reshape_shape[0] *= input_shape[i];
2623     }
2624   }
2625 
2626   // Then insert all reduction axes.
2627   for (auto axis : axes) {
2628     perms.push_back(axis);
2629     reshape_shape[1] *= input_shape[axis];
2630   }
2631 
2632   Value perms_value =
2633       getConstTensor<int32_t>(rewriter, op, perms,
2634                               {static_cast<int64_t>(perms.size())})
2635           .getValue();
2636 
2637   auto transpose_op = CreateOpAndInfer<tosa::TransposeOp>(
2638       rewriter, loc, UnrankedTensorType::get(rewriter.getI32Type()), input,
2639       perms_value);
2640 
2641   auto reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
2642       rewriter, loc, RankedTensorType::get(reshape_shape, input_etype),
2643       transpose_op, rewriter.getI64ArrayAttr(reshape_shape));
2644 
2645   return CreateOpAndInfer<T>(rewriter, loc,
2646                              UnrankedTensorType::get(reduce_etype), reshape_op,
2647                              rewriter.getI64IntegerAttr(1));
2648 }
2649 
2650 // Common function for lowering reduce operations to TOSA ops.
2651 template <typename T>
2652 llvm::Optional<Value> convertReduceOpCommon(
2653     PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2654     Value input_value, ElementsAttr axes_elems, Type reduce_element_type,
2655     bool is_quantized, double input_scale, int64_t input_zp,
2656     double output_scale, int64_t output_zp) {
2657   RankedTensorType input_type =
2658       input_value.getType().dyn_cast<RankedTensorType>();
2659   if (!input_type) return llvm::None;
2660 
2661   ArrayRef<int64_t> input_shape = input_type.getShape();
2662   ArrayRef<int64_t> output_shape = output_type.getShape();
2663   auto input_rank = input_shape.size();
2664   Location loc = op->getLoc();
2665 
2666   if (axes_elems.getNumElements() == 0) {
2667     // No axes means return the original tensor.
2668     auto identity_op = CreateOpAndInfer<tosa::IdentityOp>(
2669         rewriter, loc, output_type, input_value);
2670     return identity_op.getResult();
2671   }
2672 
2673   // Handle negative axis and guarantee in increasing order.
2674   llvm::SmallVector<int64_t> axes;
2675   for (auto axis : axes_elems.getValues<IntegerAttr>()) {
2676     auto axis_val = axis.getInt();
2677     if (axis_val < 0) axis_val += input_rank;
2678     axes.push_back(axis_val);
2679   }
2680 
2681   // Reduction operations are limited to 4D tensors; restructure to obey this
2682   // restriction. We transpose all reduction axes to the RHS, then reshape
2683   // such that all reduction and non-reduction values are grouped. This forms
2684   // a 2D matrix in all cases.
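  //
  // For example (hypothetical shapes), a rank-5 input of shape [2, 3, 4, 5, 6]
  // reduced over axes {1, 3} is transposed with perms [0, 2, 4, 1, 3] and
  // reshaped to [2 * 4 * 6, 3 * 5] = [48, 15]; the reduction then runs along
  // axis 1 of that 2D matrix.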
2685   Value val = input_value;
2686   if (input_rank > 4) {
2687     val = convertGenericReduceOp<T>(rewriter, op, val,
2688                                     input_type.getElementType(),
2689                                     reduce_element_type, input_shape, axes);
2690   } else {
2691     // Reduce along each axis
2692     SmallVector<int64_t> shape_vec(input_shape.begin(), input_shape.end());
2693 
2694     if (is_quantized) {
2695       val = buildRescaleToInt32(rewriter, op, val, input_scale, input_zp);
2696     }
2697 
2698     for (auto axis_val : axes) {
2699       if (axis_val < 0) axis_val += input_rank;
2700       auto axis_attr = rewriter.getI64IntegerAttr(axis_val);
2701 
2702       shape_vec[axis_val] = 1;
2703       RankedTensorType reduce_type =
2704           RankedTensorType::get(shape_vec, reduce_element_type);
2705 
2706       auto reduce_op = CreateOpAndInfer<T>(rewriter, op->getLoc(), reduce_type,
2707                                            val, axis_attr);
2708 
2709       val = reduce_op.getResult();
2710     }
2711   }
2712 
2713   if (is_quantized) {
2714     UnrankedTensorType output_rescale_type =
2715         UnrankedTensorType::get(output_type.getElementType());
2716     val = buildRescale(rewriter, op, output_rescale_type, val, output_scale, 0,
2717                        output_zp, false, true);
2718   }
2719 
2720   // Squeeze out the reduced axes.
2721   auto reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
2722       rewriter, op->getLoc(), output_type, val,
2723       rewriter.getI64ArrayAttr(output_shape));
2724   return reshape_op.getResult();
2725 }
2726 
2727 // Lowers ReduceAll to a sequence of TOSA ops.
2728 llvm::Optional<Value> convertReduceAllOp(PatternRewriter& rewriter,
2729                                          Operation* op,
2730                                          RankedTensorType output_type,
2731                                          Value input_value,
2732                                          ElementsAttr axes_elems) {
2733   RankedTensorType input_type =
2734       input_value.getType().dyn_cast<RankedTensorType>();
2735   if (!input_type) return llvm::None;
2736 
2737   return convertReduceOpCommon<tosa::ReduceAllOp>(
2738       rewriter, op, output_type, input_value, axes_elems,
2739       output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
2740 }
2741 
2742 // Lowers ReduceAny to a sequence of TOSA ops.
2743 llvm::Optional<Value> convertReduceAnyOp(PatternRewriter& rewriter,
2744                                          Operation* op,
2745                                          RankedTensorType output_type,
2746                                          Value input_value,
2747                                          ElementsAttr axes_elems) {
2748   RankedTensorType input_type =
2749       input_value.getType().dyn_cast<RankedTensorType>();
2750   if (!input_type) return llvm::None;
2751 
2752   return convertReduceOpCommon<tosa::ReduceAnyOp>(
2753       rewriter, op, output_type, input_value, axes_elems,
2754       output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
2755 }
2756 
2757 // Lowers ReduceMin to a sequence of TOSA ops.
2758 llvm::Optional<Value> convertReduceMinOp(PatternRewriter& rewriter,
2759                                          Operation* op,
2760                                          RankedTensorType output_type,
2761                                          Value input_value,
2762                                          ElementsAttr axes_elems) {
2763   RankedTensorType input_type =
2764       input_value.getType().dyn_cast<RankedTensorType>();
2765   if (!input_type) return llvm::None;
2766 
2767   return convertReduceOpCommon<tosa::ReduceMinOp>(
2768       rewriter, op, output_type, input_value, axes_elems,
2769       output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
2770 }
2771 
2772 // Lowers ReduceMax to a sequence of TOSA ops.
2773 llvm::Optional<Value> convertReduceMaxOp(PatternRewriter& rewriter,
2774                                          Operation* op,
2775                                          RankedTensorType output_type,
2776                                          Value input_value,
2777                                          ElementsAttr axes_elems) {
2778   RankedTensorType input_type =
2779       input_value.getType().dyn_cast<RankedTensorType>();
2780   if (!input_type) return llvm::None;
2781 
2782   return convertReduceOpCommon<tosa::ReduceMaxOp>(
2783       rewriter, op, output_type, input_value, axes_elems,
2784       output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
2785 }
2786 
2787 // Lowers ReduceProd to a sequence of TOSA ops.
2788 llvm::Optional<Value> convertReduceProdOp(PatternRewriter& rewriter,
2789                                           Operation* op,
2790                                           RankedTensorType output_type,
2791                                           Value input_value,
2792                                           ElementsAttr axes_elems) {
2793   RankedTensorType input_type =
2794       input_value.getType().dyn_cast<RankedTensorType>();
2795   if (!input_type) return llvm::None;
2796 
2797   bool input_is_qtype =
2798       input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2799   bool output_is_qtype =
2800       output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2801 
2802   if (input_is_qtype || output_is_qtype) {
2803     op->emitOpError(
2804         "ConvertReduceProdOp: input/output tensor should "
2805         "be all floating-point.");
2806     return llvm::None;
2807   }
2808 
2809   return convertReduceOpCommon<tosa::ReduceProdOp>(
2810       rewriter, op, output_type, input_value, axes_elems,
2811       output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
2812 }
2813 
2814 // Lowers ReduceSum to a sequence of TOSA ops.
2815 llvm::Optional<Value> convertReduceSumOp(PatternRewriter& rewriter,
2816                                          Operation* op,
2817                                          RankedTensorType output_type,
2818                                          Value input_value,
2819                                          ElementsAttr axes_elems) {
2820   RankedTensorType input_type =
2821       input_value.getType().dyn_cast<RankedTensorType>();
2822   if (!input_type) return llvm::None;
2823 
2824   bool input_is_qtype =
2825       input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2826   bool output_is_qtype =
2827       output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2828 
2829   if (input_is_qtype != output_is_qtype) {
2830     op->emitOpError(
2831         "ConvertReduceSumOp: input/output tensor should "
2832         "be all quantized or all floating-point.");
2833     return llvm::None;
2834   }
2835 
2836   double input_scale = 1.0f;
2837   double output_scale = 1.0f;
2838   int64_t input_zp = 0;
2839   int64_t output_zp = 0;
2840   Type reduce_element_type = input_type.getElementType();
2841 
2842   if (input_is_qtype) {
2843     auto input_qtype =
2844         input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2845     auto output_qtype =
2846         output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2847 
2848     int32_t input_shift = 20;
2849 
2850     input_scale =
2851         static_cast<double>(1 << input_shift) * input_qtype.getScale();
2852     output_scale =
2853         1.0 / (output_qtype.getScale() * static_cast<double>(1 << input_shift));
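    // Note: input_scale * output_scale reduces to
    // input_qtype.scale / output_qtype.scale, so the (1 << input_shift)
    // factor only adds fixed-point precision to the intermediate int32
    // accumulation and cancels out in the final rescale.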
2854 
2855     input_zp = input_qtype.getZeroPoint();
2856     output_zp = output_qtype.getZeroPoint();
2857     reduce_element_type = rewriter.getI32Type();
2858   }
2859 
2860   return convertReduceOpCommon<tosa::ReduceSumOp>(
2861       rewriter, op, output_type, input_value, axes_elems, reduce_element_type,
2862       input_is_qtype, input_scale, input_zp, output_scale, output_zp);
2863 }
2864 
2865 // Lowers ReduceMean to a sequence of TOSA ops.
2866 llvm::Optional<Value> convertReduceMeanOp(PatternRewriter& rewriter,
2867                                           Operation* op,
2868                                           RankedTensorType output_type,
2869                                           Value input_value,
2870                                           ElementsAttr axes_elems) {
2871   // reduce_mean is lowered as follows:
2872   // op1 = reduce_sum(input)
2873   // op2 = mul(op1, 1.0 / num_elements_on_reduced_axis)
2874 
2875   RankedTensorType input_type =
2876       input_value.getType().dyn_cast<RankedTensorType>();
2877   if (!input_type) return llvm::None;
2878 
2879   bool input_is_qtype =
2880       input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2881   bool output_is_qtype =
2882       output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2883 
2884   if (input_is_qtype != output_is_qtype) {
2885     op->emitOpError(
2886         "ConvertReduceSumOp: input/output tensor should "
2887         "be all quantized or all floating-point.");
2888     return llvm::None;
2889   }
2890 
2891   // Only supports float type mean() if it's non-quantized
2892   if (!input_is_qtype && !output_type.getElementType().isa<mlir::FloatType>()) {
2893     op->emitWarning(
2894         "Failed convertReduceMean: input unquantized type but output element "
2895         "not FloatType!");
2896     return llvm::None;
2897   }
2898 
2899   int64_t input_rank = input_type.getRank();
2900   int64_t num_elems_on_reduced_axis = 1;
2901   for (int i = 0; i < axes_elems.getNumElements(); i++) {
2902     int64_t axis_val = axes_elems.getValues<IntegerAttr>()[i].getInt();
2903     if (axis_val < 0) axis_val += input_rank;
2904     num_elems_on_reduced_axis *= input_type.getShape()[axis_val];
2905   }
2906   double div_scale = 1.0 / static_cast<double>(num_elems_on_reduced_axis);
2907 
2908   double input_scale = 1.0f;
2909   double output_scale = 1.0f;
2910   int64_t input_zp = 0;
2911   int64_t output_zp = 0;
2912   Type reduce_element_type = input_type.getElementType();
2913 
2914   if (input_is_qtype) {
2915     auto input_qtype =
2916         input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2917     auto output_qtype =
2918         output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2919 
2920     // Combine 'div_scale' as part of output rescale
2921     output_scale = div_scale * input_qtype.getScale() / output_qtype.getScale();
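    // i.e. output_scale = (input_qtype.scale / output_qtype.scale) *
    // (1 / num_elems_on_reduced_axis), so a single rescale performs both the
    // requantization and the division by the element count.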
2922 
2923     input_zp = input_qtype.getZeroPoint();
2924     output_zp = output_qtype.getZeroPoint();
2925     reduce_element_type = rewriter.getI32Type();
2926   }
2927 
2928   auto val = convertReduceOpCommon<tosa::ReduceSumOp>(
2929       rewriter, op, output_type, input_value, axes_elems, reduce_element_type,
2930       input_is_qtype, input_scale, input_zp, output_scale, output_zp);
2931 
2932   if (!val.has_value()) return llvm::None;
2933 
2934   if (!input_is_qtype) {
2935     Value div_const = getTosaConstTensorSingleF32(rewriter, op, div_scale);
2936     return CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
2937                                          val.getValue(), div_const, 0)
2938         .getResult();
2939   }
2940 
2941   return val;
2942 }
2943 
2944 // Lowers ResizeBilinear and ResizeNearestNeighbor to TOSA resize.
2945 llvm::Optional<Value> convertResizeOp(PatternRewriter& rewriter, Operation* op,
2946                                       RankedTensorType output_type,
2947                                       Value input_value, StringRef mode,
2948                                       bool align_corners,
2949                                       bool half_pixel_centers) {
2950   RankedTensorType input_type =
2951       input_value.getType().dyn_cast<RankedTensorType>();
2952   if (!input_type) return llvm::None;
2953 
2954   if (input_type.getRank() != 4 || output_type.getRank() != 4) {
2955     op->emitOpError("convertResizeOp: input/output must be rank 4");
2956     return llvm::None;
2957   }
2958 
2959   bool input_is_qtype =
2960       input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2961   bool output_is_qtype =
2962       output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2963 
2964   if (input_is_qtype != output_is_qtype) {
2965     op->emitOpError(
2966         "ConvertResizeOp: input/output tensor should "
2967         "be all quantized or all floating-point.");
2968     return llvm::None;
2969   }
2970 
2971   if (!input_is_qtype) {
2972     if (!input_type.getElementType().isa<mlir::FloatType>()) {
2973       op->emitOpError(
2974           "ConvertResizeOp: only quantized or float types supported.");
2975       return llvm::None;
2976     }
2977   }
2978 
2979   auto input_shape = input_type.getShape();
2980   auto output_shape = output_type.getShape();
2981 
2982   if (input_type.isDynamicDim(1) || input_type.isDynamicDim(2)) {
2983     op->emitOpError("ConvertResizeOp: resize dynamic input not supported.");
2984     return llvm::None;
2985   }
2986 
2987   if (output_type.isDynamicDim(1) || output_type.isDynamicDim(2)) {
2988     op->emitOpError("ConvertResizeOp: resize dynamic output not supported.");
2989     return llvm::None;
2990   }
2991 
2992   size_t input_height = input_shape[1];
2993   size_t input_width = input_shape[2];
2994   size_t output_height = output_shape[1];
2995   size_t output_width = output_shape[2];
2996 
2997   double fp_stride_y =
2998       static_cast<double>(input_height) / static_cast<double>(output_height);
2999   double fp_stride_x =
3000       static_cast<double>(input_width) / static_cast<double>(output_width);
3001   if (align_corners && output_height > 1) {
3002     fp_stride_y = static_cast<double>(input_height - 1) /
3003                   static_cast<double>(output_height - 1);
3004   }
3005   if (align_corners && output_width > 1) {
3006     fp_stride_x = static_cast<double>(input_width - 1) /
3007                   static_cast<double>(output_width - 1);
3008   }
3009 
3010   double fp_offset_y, fp_offset_x;
3011   if (half_pixel_centers) {
3012     fp_offset_y = fp_stride_y * 0.5f - 0.5f;
3013     fp_offset_x = fp_stride_x * 0.5f - 0.5f;
3014   } else {
3015     fp_offset_y = 0.0f;
3016     fp_offset_x = 0.0f;
3017   }
3018 
3019   // Output position oy samples input coordinate iy = oy * fp_stride_y + fp_offset_y.
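  // With half_pixel_centers, the sampled coordinate is
  // (oy + 0.5) * fp_stride_y - 0.5 = oy * fp_stride_y + (0.5 * fp_stride_y - 0.5),
  // which is why the offset above is fp_stride_y * 0.5 - 0.5 (and likewise for x).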
3020 
3021   ArrayAttr output_size =
3022       rewriter.getI64ArrayAttr({static_cast<int64_t>(output_height),
3023                                 static_cast<int64_t>(output_width)});
3024   StringAttr resize_mode = rewriter.getStringAttr(mode);
3025 
3026   if (input_is_qtype) {
3027     // Magic shift number that TFLite resize bilinear uses.
3028     // reference: tensorflow/lite/kernels/internal/reference/reference_ops.h
3029     int32_t shift = 10;
3030 
3031     // 1.0 is equivalent to (1 << shift) in quantized space.
3032     // Here we denote unit = (1 << shift).
3033     double unit = static_cast<double>(1 << shift);
3034 
3035     // Stride and offset are int16.
3036     int32_t stride_y = std::lround(fp_stride_y * unit);
3037     int32_t stride_x = std::lround(fp_stride_x * unit);
3038     int32_t offset_y = std::lround(fp_offset_y * unit);
3039     int32_t offset_x = std::lround(fp_offset_x * unit);
3040 
3041     // Numerically we could decrement shift to make these numbers fit within
3042     // 16 bits, but that's not commonly seen and won't match the TFLite reference.
3043     if (stride_y > std::numeric_limits<int16_t>::max() ||
3044         stride_x > std::numeric_limits<int16_t>::max() ||
3045         stride_y < std::numeric_limits<int16_t>::min() ||
3046         stride_x < std::numeric_limits<int16_t>::min() ||
3047         offset_y > std::numeric_limits<int16_t>::max() ||
3048         offset_x > std::numeric_limits<int16_t>::max() ||
3049         offset_y < std::numeric_limits<int16_t>::min() ||
3050         offset_x < std::numeric_limits<int16_t>::min()) {
3051       op->emitOpError("OpResize: stride or offset out of 16 bits");
3052       return llvm::None;
3053     }
3054 
3055     ArrayAttr stride = rewriter.getI64ArrayAttr({stride_y, stride_x});
3056     ArrayAttr offset = rewriter.getI64ArrayAttr({offset_y, offset_x});
3057     IntegerAttr shift_attr = rewriter.getI32IntegerAttr(shift);
3058 
3059     // If quantized bilinear mode, need to lower to RESIZE + RESCALE pair.
3060     if (mode == "BILINEAR") {
3061       RankedTensorType output_acc_type;
3062       auto input_element_qtype =
3063           input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
3064 
3065       bool scale32;
3066 
3067       // TOSA RESIZE: 16 bit input -> 48 bit output, or 8 bit input -> 32 bit
3068       // output.
3069       if (input_element_qtype.getStorageTypeIntegralWidth() == 16) {
3070         scale32 = false;
3071         output_acc_type = RankedTensorType::get(output_type.getShape(),
3072                                                 rewriter.getIntegerType(48));
3073       } else if (input_element_qtype.getStorageTypeIntegralWidth() == 8) {
3074         scale32 = true;
3075         output_acc_type = RankedTensorType::get(output_type.getShape(),
3076                                                 rewriter.getI32Type());
3077       } else {
3078         op->emitOpError("OpResize: support 16-bit and 8-bit quantized input");
3079         return llvm::None;
3080       }
3081 
3082       auto resize_op = CreateOpAndInfer<tosa::ResizeOp>(
3083           rewriter, op->getLoc(), output_acc_type, input_value, output_size,
3084           stride, offset, shift_attr, rewriter.getF32ArrayAttr({0.0, 0.0}),
3085           rewriter.getF32ArrayAttr({0.0, 0.0}), resize_mode);
3086 
3087 #ifdef RESIZE_BILINEAR_LOWER_SYMMETRIC_ROUNDING
3088       // TFLite resize_bilinear always assumes the input and output tensors
3089       // have the same scale, so we only need an arithmetic right shift by
3090       // (2 * shift).
3091       // TODO(suderman): Align TFLite rounding behavior
3092       // TFLite also uses symmetric rounding by doing 'x / (1 << 20)', while
3093       // TOSA arithmetic right shift does standard rounding.
3094       // Right now it's legalized using GreaterEqualOp + SelectOp to conform
3095       // to the TFLite reference, but this should eventually be fixed in the
3096       // TFLite reference.
3097       Value cst_zero = getTosaConstTensorSingleI32(rewriter, op, 0);
3098       Value cst_twenty = getTosaConstTensorSingleI32(rewriter, op, 20);
3099 
3100       auto ge_op = CreateOpAndInfer<tosa::GreaterEqualOp>(
3101           rewriter, op->getLoc(), output_bool_type, resize_op.getResult(),
3102           cst_zero);
3103 
3104       auto abs_op = CreateOpAndInfer<tosa::AbsOp>(
3105           rewriter, op->getLoc(), output_acc_type, resize_op.getResult());
3106 
3107       auto rshift_op = CreateOpAndInfer<tosa::ArithmeticRightShiftOp>(
3108           rewriter, op->getLoc(), output_acc_type, abs_op.getResult(),
3109           cst_twenty, true);
3110 
3111       auto negate_op = CreateOpAndInfer<tosa::NegateOp>(
3112           rewriter, op->getLoc(), output_acc_type, rshift_op.getResult());
3113 
3114       auto select_op = CreateOpAndInfer<tosa::SelectOp>(
3115           rewriter, op->getLoc(), output_acc_type, ge_op.getResult(),
3116           rshift_op.getResult(), negate_op.getResult());
3117 
3118       auto cast_op = CreateOpAndInfer<tosa::CastOp>(
3119           rewriter, op->getLoc(), output_type, select_op.getResult());
3120 
3121       return cast_op.getResult();
3122 #else
3123       // This should be the expected lowering, but is within +-1 of the
3124       // TFLite reference.
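      // The bilinear accumulator carries one (1 << shift) factor per spatial
      // dimension, so the combined (1 << 20) factor (shift == 10) is removed
      // by this rescale.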
3125       return buildRescale(rewriter, op, output_type, resize_op.getResult(),
3126                           1.0 / (1 << 20), 0, 0, false, scale32);
3127 #endif
3128 
3129     } else if (mode == "NEAREST_NEIGHBOR") {
3130       auto resize_op = CreateOpAndInfer<tosa::ResizeOp>(
3131           rewriter, op->getLoc(), output_type, input_value, output_size, stride,
3132           offset, shift_attr, rewriter.getF32ArrayAttr({0.0, 0.0}),
3133           rewriter.getF32ArrayAttr({0.0, 0.0}), resize_mode);
3134       return resize_op.getResult();
3135     } else {
3136       op->emitOpError(
3137           "OpResize: only support BILINEAR or NEAREST_NEIGHBOR mode");
3138       return llvm::None;
3139     }
3140   } else {
3141     auto resize_op = CreateOpAndInfer<tosa::ResizeOp>(
3142         rewriter, op->getLoc(), output_type, input_value, output_size,
3143         rewriter.getI64ArrayAttr({0, 0}), rewriter.getI64ArrayAttr({0, 0}),
3144         rewriter.getI32IntegerAttr(0),
3145         rewriter.getF32ArrayAttr(
3146             {static_cast<float>(fp_stride_y), static_cast<float>(fp_stride_x)}),
3147         rewriter.getF32ArrayAttr(
3148             {static_cast<float>(fp_offset_y), static_cast<float>(fp_offset_x)}),
3149         resize_mode);
3150     return resize_op.getResult();
3151   }
3152 }
3153 
3154 // Lowers Quantize to a sequence of TOSA quantization ops.
3155 llvm::Optional<Value> convertQuantizeOp(PatternRewriter& rewriter,
3156                                         Operation* op, ShapedType output_type,
3157                                         Value input_value, double scale,
3158                                         int64_t zeropoint) {
3159   RankedTensorType input_type =
3160       input_value.getType().dyn_cast<RankedTensorType>();
3161   if (!input_type) return llvm::None;
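  // Quantize is lowered as follows:
  // op1 = mul(input, scale)
  // op2 = add(op1, zeropoint)
  // op3 = cast(op2) to the quantized output type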
3162 
3163   auto output_element_type = output_type.getElementType();
3164 
3165   // The output element type can only be a quantized integer type.
3166   if (!output_element_type.isa<mlir::quant::QuantizedType>()) {
3167     op->emitWarning(
3168         "Lowering quantizeOp but output element type not quantized!");
3169     return llvm::None;
3170   }
3171 
3172   ShapedType output_fp_type = output_type.clone(rewriter.getF32Type());
3173 
3174   Value zp_val =
3175       getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(zeropoint));
3176 
3177   auto op1_mul_in = CreateOpAndInfer<tosa::MulOp>(
3178       rewriter, op->getLoc(), output_fp_type, input_value,
3179       getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(scale)), 0);
3180 
3181   auto op2_add_op1 = CreateOpAndInfer<tosa::AddOp>(
3182       rewriter, op->getLoc(), output_fp_type, op1_mul_in.getResult(), zp_val);
3183 
3184   auto op3_cast_op2 = CreateOpAndInfer<tosa::CastOp>(
3185       rewriter, op->getLoc(), output_type, op2_add_op1.getResult());
3186 
3187   return op3_cast_op2.getResult();
3188 }
3189 
3190 // Lowers Dequantize to a sequence of TOSA dequantization ops.
3191 llvm::Optional<Value> convertDequantizeOp(PatternRewriter& rewriter,
3192                                           Operation* op, ShapedType output_type,
3193                                           Value input_value,
3194                                           ArrayRef<float> scale,
3195                                           ArrayRef<float> zeropoint,
3196                                           int64_t dim) {
3197   RankedTensorType input_type =
3198       input_value.getType().dyn_cast<RankedTensorType>();
3199   if (!input_type) return llvm::None;
3200 
3201   // The input element type can only be a quantized integer type.
3202   if (!input_type.getElementType().isa<mlir::quant::QuantizedType>())
3203     return llvm::None;
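  // Dequantize is lowered as follows:
  // op1 = cast(input) to the float output type
  // op2 = sub(op1, zeropoint)
  // op3 = mul(op2, scale)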
3204 
3205   Optional<Value> zp_val;
3206   if (zeropoint.size() == 1) {
3207     zp_val = getTosaConstTensorSingleF32(rewriter, op,
3208                                          static_cast<float>(zeropoint[0]));
3209   } else {
3210     SmallVector<int64_t> shape;
3211     shape.resize(input_type.getRank(), 1);
3212     shape[dim] = zeropoint.size();
3213     zp_val = getConstTensor(rewriter, op, zeropoint, shape);
3214   }
3215 
3216   Optional<Value> scale_val;
3217   if (scale.size() == 1) {
3218     scale_val =
3219         getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(scale[0]));
3220   } else {
3221     SmallVector<int64_t> shape;
3222     shape.resize(input_type.getRank(), 1);
3223     shape[dim] = scale.size();
3224     scale_val = getConstTensor(rewriter, op, scale, shape);
3225   }
3226 
3227   if (!zp_val || !scale_val) return llvm::None;
3228 
3229   auto op1_cast_in = CreateOpAndInfer<tosa::CastOp>(rewriter, op->getLoc(),
3230                                                     output_type, input_value);
3231 
3232   auto op2_sub_op1 =
3233       CreateOpAndInfer<tosa::SubOp>(rewriter, op->getLoc(), output_type,
3234                                     op1_cast_in.getResult(), zp_val.getValue());
3235 
3236   return CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
3237                                        op2_sub_op1.getResult(),
3238                                        scale_val.getValue(), 0)
3239       .getResult();
3240 }
3241 
3242 // Lowers FakeQuant to a sequence of TOSA quantization ops.
3243 llvm::Optional<Value> convertFakeQuantOp(PatternRewriter& rewriter,
3244                                          Operation* op, ShapedType output_type,
3245                                          Value input_value, double min,
3246                                          double max, int64_t num_bits,
3247                                          bool narrow_range) {
3248   // FakeQuant is lowered as follows:
3249   // op1 = quantize(input)
3250   // op2 = dequantize(op1)
3251 
3252   RankedTensorType input_type =
3253       input_value.getType().dyn_cast<RankedTensorType>();
3254   if (!input_type) return llvm::None;
3255 
3256   // quantized as INT<num_bits>, where num_bits can only be 8, 16
3257   if (num_bits != 8 && num_bits != 16) {
3258     op->emitWarning("FakeQuantOp lowering handles only 8 and 16 for num_bits!");
3259     return llvm::None;
3260   }
3261 
3262   // This code originates from
3263   // tensorflow/core/kernels/fake_quant_ops_functor.h.
3264   int32_t qmax = (1 << (num_bits)) - 1;
3265   int32_t qmin = narrow_range ? 1 : 0;
3266 
3267   float nudged_min, nudged_max, nudged_scale;
3268   tensorflow_nudge(min, max, qmin, qmax, &nudged_min, &nudged_max,
3269                    &nudged_scale);
3270 
3271   Value cst_min = getTosaConstTensorSingleF32(rewriter, op, nudged_min);
3272   Value cst_max = getTosaConstTensorSingleF32(rewriter, op, nudged_max);
3273   Value cst_scale = getTosaConstTensorSingleF32(rewriter, op, nudged_scale);
3274   Value cst_inv_scale =
3275       getTosaConstTensorSingleF32(rewriter, op, 1.0f / nudged_scale);
3276   Value cst_half = getTosaConstTensorSingleF32(rewriter, op, 0.5f);
3277 
3278   // This code originates from
3279   // tensorflow/core/kernels/fake_quant_ops_functor.h.
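  // The op sequence below computes:
  //   floor((clamp(x, nudged_min, nudged_max) - nudged_min) / nudged_scale
  //         + 0.5) * nudged_scale + nudged_min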
3280   auto op1_min_in = CreateOpAndInfer<tosa::MinimumOp>(
3281       rewriter, op->getLoc(), output_type, input_value, cst_max);
3282 
3283   auto op2_max_op1 = CreateOpAndInfer<tosa::MaximumOp>(
3284       rewriter, op->getLoc(), output_type, op1_min_in.getResult(), cst_min);
3285 
3286   auto op3_sub_op2 = CreateOpAndInfer<tosa::SubOp>(
3287       rewriter, op->getLoc(), output_type, op2_max_op1.getResult(), cst_min);
3288 
3289   auto op4_mul_op3 =
3290       CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
3291                                     op3_sub_op2.getResult(), cst_inv_scale, 0);
3292 
3293   auto op5_add_op4 = CreateOpAndInfer<tosa::AddOp>(
3294       rewriter, op->getLoc(), output_type, op4_mul_op3.getResult(), cst_half);
3295 
3296   auto op6_floor_op5 = CreateOpAndInfer<tosa::FloorOp>(
3297       rewriter, op->getLoc(), output_type, op5_add_op4.getResult());
3298 
3299   auto op7_mul_op6 =
3300       CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
3301                                     op6_floor_op5.getResult(), cst_scale, 0);
3302 
3303   return CreateOpAndInfer<tosa::AddOp>(rewriter, op->getLoc(), output_type,
3304                                        op7_mul_op6.getResult(), cst_min)
3305       .getResult();
3306 }
3307 
3308 llvm::Optional<Value> convertTFConv2DCommon(
3309     PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
3310     Value input, Value filter, Value bias, ArrayAttr strides_attr,
3311     ArrayAttr dilations_attr, ArrayAttr explicit_padding_attr,
3312     StringRef padding_ref, StringRef data_format_ref) {
3313   RankedTensorType input_type = input.getType().dyn_cast<RankedTensorType>();
3314   RankedTensorType filter_type = filter.getType().dyn_cast<RankedTensorType>();
3315   // Bail out if the input or filter is not a ranked tensor.
3316   if (!input_type) return llvm::None;
3317   if (!filter_type) return llvm::None;
3318 
3319   // Transpose [H, W, I, O] to [O, H, W, I]
3320   auto filter_shape = filter_type.getShape();
3321   SmallVector<int64_t, 4> a1_transpose_dims;
3322   a1_transpose_dims.push_back(filter_shape[3]);
3323   a1_transpose_dims.push_back(filter_shape[0]);
3324   a1_transpose_dims.push_back(filter_shape[1]);
3325   a1_transpose_dims.push_back(filter_shape[2]);
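  // e.g. a hypothetical 3x3 filter with 16 input and 8 output channels,
  // shape [3, 3, 16, 8], becomes [8, 3, 3, 16].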
3326   llvm::Optional<Value> a1_filter_transpose_perm = getConstTensor<int32_t>(
3327       rewriter, op, /*vec=*/{3, 0, 1, 2}, /*shape=*/{4});
3328 
3329   if (!a1_filter_transpose_perm) return llvm::None;
3330 
3331   auto a1_filter_transpose_op = CreateOpAndInfer<tosa::TransposeOp>(
3332       rewriter, op->getLoc(),
3333       RankedTensorType::get(a1_transpose_dims, filter_type.getElementType()),
3334       filter, a1_filter_transpose_perm.getValue());
3335 
3336   // Only support NHWC now.
3337   if (data_format_ref.str() != "NHWC") {
3338     op->emitWarning("convertTDConv2DCommon only supports NHWC!");
3339     return llvm::None;
3340   }
3341 
3342   ArrayAttr stride;
3343   ArrayAttr dilation;
3344   ArrayAttr pad;
3345   {
3346     if (!strides_attr) {
3347       stride = rewriter.getI64ArrayAttr({1, 1});
3348     } else {
3349       // Note: hardcoded to NHWC for now
3350       int64_t stride_h = strides_attr[1].cast<IntegerAttr>().getInt();
3351       int64_t stride_w = strides_attr[2].cast<IntegerAttr>().getInt();
3352       stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
3353     }
3354   }
3355   {
3356     if (!dilations_attr) {
3357       dilation = rewriter.getI64ArrayAttr({1, 1});
3358     } else {
3359       // Note: hardcoded to NHWC for now
3360       int64_t dilation_h = dilations_attr[1].cast<IntegerAttr>().getInt();
3361       int64_t dilation_w = dilations_attr[2].cast<IntegerAttr>().getInt();
3362       dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w});
3363     }
3364   }
3365   {
3366     tensorflow::Padding tf_pad;
3367     if (!GetPaddingFromString(padding_ref.str(), &tf_pad).ok()) {
3368       op->emitWarning("Could not get padding data from padding string term!");
3369       return llvm::None;
3370     }
3371 
3372     tensorflow::TensorFormat data_format_tf;
3373     if (!FormatFromString(data_format_ref.str(), &data_format_tf))
3374       return llvm::None;
3375 
3376     if (tf_pad == tensorflow::Padding::EXPLICIT) {
3377       pad = getPaddingValuesFromExplicitPadAttr(explicit_padding_attr,
3378                                                 data_format_tf, rewriter);
3379     } else {
3380       if (!getPaddingValuesFromPadType(tf_pad, data_format_tf,
3381                                        0,  // tensorflow::FORMAT_HWIO
3382                                        input_type, filter_type, stride,
3383                                        dilation, rewriter, pad))
3384         return llvm::None;
3385     }
3386   }
3387 
3388   return CreateOpAndInfer<tosa::Conv2DOp>(
3389              rewriter, op->getLoc(), output_type, input,
3390              a1_filter_transpose_op.getResult(), bias, pad, stride, dilation)
3391       .getResult();
3392 }
3393 
3394 // Lowers Gather operators to a sequence of TOSA ops.
3395 llvm::Optional<Value> convertGatherOp(PatternRewriter& rewriter, Operation* op,
3396                                       Value result_value, Value params_value,
3397                                       Value indices_value, int32_t batch_dims,
3398                                       int32_t axis) {
3399   auto result_type = result_value.getType().dyn_cast<ShapedType>();
3400   auto params_type = params_value.getType().dyn_cast<RankedTensorType>();
3401   auto indices_type = indices_value.getType().dyn_cast<RankedTensorType>();
3402 
3403   if (!result_type || !params_type || !indices_type) return llvm::None;
3404 
3405   // batch_dims indicates the number of batch dimensions in params and
3406   // indices; axis indicates the axis at which the gather indexing is
3407   // applied.  axis must be >= batch_dims.  When axis is equal to
3408   // batch_dims, the right-most batch dimension disappears.
3409   //
3410   // N: number of batches
3411   // Computed as product of params.shape[0:batch_dims-1]
3412   //
3413   // W: number of indices in each batch
3414   // Computed as product of indices.shape[batch_dims:]
3415   //
3416   // K: range of each index
3417   // Computed as  params.shape[axis:axis+rank(indices)-1]
3418   //
3419   // C: number of channels for each index
3420   // Computed as:  LeftChannels * RightChannels:
3421   // product(params.shape[batch_dims:axis]) * product(params.shape[axis+1:])
3422   //
3423   // The params tensor needs to be transposed, then reshaped to move the
3424   // dimensions into [N, K, C] order.
3425   //
3426   // The dimensions of the input params[] tensor are grouped in the following
3427   // order to begin with:
3428   //
3429   //  [Batch, LeftChannels, Indices, RightChannels]
3430   //  |-----||------------||-------||-------------|
3431   //     N         C_l         K          C_r
3432   //
3433   // Where Batch (N), Indices (K) can be one or more dimensions in size,
3434   // while LeftChannels and RightChannels represent the group of data channels
3435   // (C) to the left and right (C_l, C_r) of the indices; the sum of these two
3436   // is one or more dimensions in size, but either one may be zero depending
3437   // on how axis was specified by the caller.
3438   //
3439   // The resulting tensor will look like:
3440   //
3441   //  [Batch, Indices, LeftChannels, RightChannels]
3442   //  |-----||-------||---------------------------|
3443   //     N       K                 C
3444   //
3445   // The indices tensor simply needs a reshape to flatten all of the
3446   // batch dimensions (N) together and flatten all of the indices (W)
3447   // together.
3448   //
3449   // Then do the tosa.GATHER
3450   //
3451   // output[N,W,C] = tosa.GATHER(values[N,K,C], indices[N,W])
3452   //
3453   // Finally, the resulting tensor will have shape [N, W, C], where C is a
3454   // flattened version of [LeftChannels, RightChannels].  We need to reshape
3455   // to unflatten to:
3456   //
3457   //  [N, W, LeftChannels, RightChannels]
3458   //
3459   // and finally transpose back to the output shape
3460   //
3461   //  [Batch, LeftChannels, Non-Batch-Indices, RightChannels]
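  //
  // For example (hypothetical shapes), with params of shape [2, 5, 7, 3],
  // indices of shape [2, 4], batch_dims = 1 and axis = 2:
  //   N = 2, W = 4, K = 7, C = 5 * 3 = 15.
  // params is transposed to [2, 7, 5, 3] and reshaped to [2, 7, 15], indices
  // is reshaped to [2, 4], tosa.GATHER produces [2, 4, 15], which is reshaped
  // to [2, 4, 5, 3] and finally transposed to [2, 5, 4, 3].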
3462 
3463   int params_rank = params_type.getShape().size();
3464   int indices_rank = indices_type.getShape().size();
3465 
3466   if (!(batch_dims <= indices_rank)) {
3467     op->emitOpError("Batch_dims must be <= indices_rank for a valid gather op");
3468     return llvm::None;
3469   }
3470 
3471   if (!(axis >= batch_dims)) {
3472     op->emitOpError("axis must be >= batch_dims for a valid gather op");
3473     return llvm::None;
3474   }
3475 
3476   // Sizes for each of these fields.
3477   SmallVector<int64_t> params_batch, params_indices, params_left_channels,
3478       params_right_channels;
3479 
3480   // Dimension indices for each of these fields.
3481   SmallVector<int64_t> params_idx_batch, params_idx_indices,
3482       params_idx_left_channels, params_idx_right_channels;
3483 
3484   // Read through the params tensor dimensions left-to-right and extract the
3485   // different fields.
3486   for (int i = 0; i < params_rank; i++) {
3487     // When batch_dims == axis, the batch dimension gets replaced.
3488     if (i < batch_dims && i < axis) {
3489       params_batch.push_back(params_type.getShape()[i]);
3490       params_idx_batch.push_back(i);
3491     } else if (i < axis) {
3492       params_left_channels.push_back(params_type.getShape()[i]);
3493       params_idx_left_channels.push_back(i);
3494     } else if (i < (axis + 1)) {
3495       params_indices.push_back(params_type.getShape()[i]);
3496       params_idx_indices.push_back(i);
3497     } else {
3498       params_right_channels.push_back(params_type.getShape()[i]);
3499       params_idx_right_channels.push_back(i);
3500     }
3501   }
3502 
3503   // Calculate N, K, W, C
3504   int64_t N = multiply_dims(params_type.getShape().take_front(batch_dims));
3505   int64_t W = multiply_dims(
3506       indices_type.getShape().slice(batch_dims, indices_rank - batch_dims));
3507   int64_t K = params_type.getShape()[axis];
3508 
3509   int64_t C = multiply_dims(
3510       params_type.getShape().slice(batch_dims, axis - batch_dims));
3511   C = multiply_dims(
3512       params_type.getShape().slice(axis + 1, params_rank - axis - 1), C);
3513 
3514   /////////////////////////////////////////////
3515   // Build up the params transpose operator
3516   SmallVector<int32_t> params_transpose_perm;
3517   SmallVector<int64_t> params_transpose_shape;
3518 
3519   // Batch
3520   for (int i = 0; i < params_batch.size(); i++) {
3521     params_transpose_perm.push_back(params_idx_batch[i]);
3522     params_transpose_shape.push_back(params_batch[i]);
3523   }
3524 
3525   // Indices
3526   for (int i = 0; i < params_indices.size(); i++) {
3527     params_transpose_perm.push_back(params_idx_indices[i]);
3528     params_transpose_shape.push_back(params_indices[i]);
3529   }
3530 
3531   // LeftChannels
3532   for (int i = 0; i < params_left_channels.size(); i++) {
3533     params_transpose_perm.push_back(params_idx_left_channels[i]);
3534     params_transpose_shape.push_back(params_left_channels[i]);
3535   }
3536 
3537   // RightChannels
3538   for (int i = 0; i < params_right_channels.size(); i++) {
3539     params_transpose_perm.push_back(params_idx_right_channels[i]);
3540     params_transpose_shape.push_back(params_right_channels[i]);
3541   }
3542 
3543   /////////////////////////////////////////////
3544   // Build up the result reshape, in preparation for transpose.
3545   // [N, W, C] -> [ Batch, Indices, LeftChannels, RightChannels ]
3546   SmallVector<int64_t> result_reshape_shape;
3547 
3548   // Indices
3549   for (int i = 0; i < indices_type.getShape().size(); i++) {
3550     result_reshape_shape.push_back(indices_type.getShape()[i]);
3551   }
3552 
3553   // Left channels
3554   for (int i = 0; i < params_left_channels.size(); i++) {
3555     result_reshape_shape.push_back(params_left_channels[i]);
3556   }
3557 
3558   // Right channels.  But remove the axis dimension.
3559   for (int i = 0; i < params_right_channels.size(); i++) {
3560     result_reshape_shape.push_back(params_right_channels[i]);
3561   }
3562 
3563   /////////////////////////////////////////////
3564   // Build up the result transpose operator.
3565   SmallVector<int32_t> result_transpose_perm;
3566 
3567   // Batch dimensions
3568   for (int i = 0; i < batch_dims; i++) {
3569     result_transpose_perm.push_back(i);
3570   }
3571 
3572   // LeftChannels
3573   for (int i = 0; i < params_left_channels.size(); i++) {
3574     result_transpose_perm.push_back(i + indices_type.getShape().size());
3575   }
3576 
3577   // Indices (remainder of dimensions after batch).
3578   for (int i = batch_dims; i < (indices_type.getShape().size()); i++) {
3579     result_transpose_perm.push_back(i);
3580   }
3581 
3582   // RightChannels, coming from after both the Indices and LeftChannels.
3583   for (int i = 0; i < params_right_channels.size(); i++) {
3584     result_transpose_perm.push_back(i + indices_type.getShape().size() +
3585                                     params_left_channels.size());
3586   }
3587 
3588   SmallVector<int64_t> tosa_values_shape = {N, K, C};
3589   SmallVector<int64_t> tosa_indices_shape = {N, W};
3590   SmallVector<int64_t> tosa_gather_result_shape = {N, W, C};
3591 
3592   llvm::Optional<Value> params_transpose_perm_val = getConstTensor<int32_t>(
3593       rewriter, op, params_transpose_perm,
3594       {static_cast<int64_t>(params_transpose_perm.size())});
3595 
3596   llvm::Optional<Value> result_transpose_perm_val = getConstTensor<int32_t>(
3597       rewriter, op, result_transpose_perm,
3598       {static_cast<int64_t>(result_transpose_perm.size())});
3599 
3600   if (!params_transpose_perm_val || !result_transpose_perm_val)
3601     return llvm::None;
3602 
3603   auto params_transpose_op = CreateOpAndInfer<tosa::TransposeOp>(
3604       rewriter, op->getLoc(),
3605       RankedTensorType::get(params_transpose_shape,
3606                             params_type.getElementType()),
3607       params_value, params_transpose_perm_val.getValue());
3608 
3609   if (count_dynamic_dims(tosa_values_shape) > 1) {
3610     return (void)rewriter.notifyMatchFailure(
3611                op,
3612                "multiply dynamic shapes when reshaping values down to "
3613                "tosa.gather"),
3614            llvm::None;
3615   }
3616 
3617   auto tosa_values_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
3618       rewriter, op->getLoc(),
3619       RankedTensorType::get(tosa_values_shape, params_type.getElementType()),
3620       params_transpose_op.getResult(),
3621       rewriter.getI64ArrayAttr(tosa_values_shape));
3622 
3623   if (count_dynamic_dims(tosa_indices_shape) > 1) {
3624     return (void)rewriter.notifyMatchFailure(
3625                op,
3626                "multiply dynamic shapes when reshaping indices down to "
3627                "tosa.gather"),
3628            llvm::None;
3629   }
3630 
3631   auto tosa_indices_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
3632       rewriter, op->getLoc(),
3633       RankedTensorType::get(tosa_indices_shape, indices_type.getElementType()),
3634       indices_value, rewriter.getI64ArrayAttr(tosa_indices_shape));
3635 
3636   auto tosa_gather_op = CreateOpAndInfer<tosa::GatherOp>(
3637       rewriter, op->getLoc(),
3638       RankedTensorType::get(tosa_gather_result_shape,
3639                             result_type.getElementType()),
3640       tosa_values_reshape_op.getResult(), tosa_indices_reshape_op.getResult());
3641 
3642   if (count_dynamic_dims(result_reshape_shape) > 1) {
3643     return (void)rewriter.notifyMatchFailure(
3644                op,
3645                "multiply dynamic shapes when reshaping tosa.gather result up"),
3646            llvm::None;
3647   }
3648 
3649   auto tosa_result_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
3650       rewriter, op->getLoc(),
3651       RankedTensorType::get(result_reshape_shape, params_type.getElementType()),
3652       tosa_gather_op.getResult(),
3653       rewriter.getI64ArrayAttr(result_reshape_shape));
3654 
3655   return CreateOpAndInfer<tosa::TransposeOp>(
3656              rewriter, op->getLoc(), result_type,
3657              tosa_result_reshape_op.getResult(),
3658              result_transpose_perm_val.getValue())
3659       .getResult();
3660 }
3661 
3662 // Lowers GatherNd operators to a sequence of TOSA ops.
3663 llvm::Optional<Value> convertGatherNdOp(PatternRewriter& rewriter,
3664                                         Operation* op, Value result_value,
3665                                         Value params_value, Value indices_value)
3666 
3667 {
3668   auto result_type = result_value.getType().dyn_cast<RankedTensorType>();
3669   auto params_type = params_value.getType().dyn_cast<RankedTensorType>();
3670   auto indices_type = indices_value.getType().dyn_cast<RankedTensorType>();
3671 
3672   if (!result_type || !params_type || !indices_type) return llvm::None;
3673 
3674   // N: number of batches
3675   // Always 1 for GatherND
3676   //
3677   // Because TOSA's GATHER operator already uses the symbol 'N' for
3678   // the number of batches, we will use the symbol 'ND' to specify the
3679   // number of dimensions that are sliced from params instead of 'N' in
3680   // the TF MLIR documentation.
3681   //
3682   // ND: indices.shape[-1]
3683   //
3684   // W: number of indices in each batch
3685   // Computed as:
3686   // product(indices.shape[0:-1]) (all but the last dimension)
3687   //
3688   // K: range of each index
3689   // Computed as:
3690   // product(params.shape[0:ND-1])
3691   //
3692   // C: number of channels for each index
3693   // Computed as:
3694   // product(params.shape[ND:])
3695   //
3696   // The params tensor needs to be reshaped, but not transposed, to move the
3697   // dimensions into [N, K, C] order.
3698   //
3699   // The dimensions of the input params[] tensor are grouped in the following
3700   // order to begin with:
3701   //
3702   //  [ParamIndices, ParamChannels]
3703   //  |------------||-------------|
3704   //         K              C
3705   //
3706   // The reshape simply flattens the params tensor into a 2D [K, C] shape.
3707   //
3708   // Indices needs to be put in the form of [N, W], but a simple flattening
3709   // will not suffice, because the indices need to index into a [K]-shaped
3710   // vector instead of the params.shape[0:ND-1] tensor that we had before.
3711   //
3712   // To flatten the coordinates, first reshape indices to a [W, ND] matrix,
3713   // where the matrix now represents W ND-dimensional coordinates into the
3714   // params tensor.
3715   //
3716   // From here, we take each of the ND dimensions and multiply it with
3717   // the size of the next params dimension (or 1 for the last
3718   // dimension), then sum all these together with a reduce_sum
3719   // operator.  This is exactly the same mathematics as one would use to
3720   // flatten the indices of an N-dimensional row-major array into a
3721   // 1-D array in C.
3722   //
3723   // More precisely, do an element-wise multiply with [params.shape[1
3724   // .. ND], 1] in axis 1, then reduce_sum in axis 1 to flatten to a
3725   // [W]-shaped tensor, then trivially reshape to [N=1, W] to be
3726   // compatible with the GATHER operator's shape.
3727   //
3728   // Then perform the tosa.GATHER() operation.
3729   //
3730   // Now we have result = [N, K, C].
3731   //
3732   // Reshape with a single, simple reshape to the final output shape of:
3733   //  [Indices, ParamChannels]
3734   //
3735   // Where, Indices is indices.shape[0:ND-1]
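  //
  // For example (hypothetical shapes), with params of shape [3, 4, 5] and
  // indices of shape [6, 2]: ND = 2, N = 1, W = 6, K = 3 * 4 = 12, C = 5.
  // The flattening coefficients are [4, 1], so an index pair (a, b) maps to
  // flat index a * 4 + b; tosa.GATHER produces [1, 6, 5], which is reshaped
  // to the final output shape [6, 5].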
3736 
3737   int N = 1, W = 1, K = 1, C = 1, ND = 1;
3738 
3739   int params_rank = params_type.getShape().size();
3740   int indices_rank = indices_type.getShape().size();
3741 
3742   ND = indices_type.getShape()[indices_rank - 1];
3743 
3744   if (ND > params_rank) {
3745     op->emitOpError("Size of last dimension on indices must be <= params rank");
3746     return llvm::None;
3747   }
3748 
3749   // Calculate N, K, W, C.  (N is always 1)
3750   for (int i = 0; i < (indices_rank - 1); i++) {
3751     W *= indices_type.getShape()[i];
3752   }
3753 
3754   for (int i = 0; i < ND; i++) {
3755     K *= params_type.getShape()[i];
3756   }
3757 
3758   for (int i = ND; i < params_rank; i++) {
3759     C *= params_type.getShape()[i];
3760   }
3761 
3762   SmallVector<int64_t, 3> tosa_values_shape({N, K, C});
3763   SmallVector<int64_t, 2> tosa_indices_shape({N, W});
3764   SmallVector<int64_t, 2> indices_matrix_shape({W, ND});
3765   SmallVector<int64_t, 3> tosa_gather_result_shape({N, W, C});
3766 
3767   auto tosa_values_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
3768       rewriter, op->getLoc(),
3769       RankedTensorType::get(tosa_values_shape, params_type.getElementType()),
3770       params_value, rewriter.getI64ArrayAttr(tosa_values_shape));
3771 
3772   // Flatten the input indices tensor to an [W, ND] matrix.
3773   auto indices_matrix_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
3774       rewriter, op->getLoc(),
3775       RankedTensorType::get(indices_matrix_shape,
3776                             indices_type.getElementType()),
3777       indices_value, rewriter.getI64ArrayAttr(indices_matrix_shape));
3778 
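  // Build the coefficients used to flatten each ND-dimensional coordinate
  // into a single index in [0, K).  Start from the vector
  // [params.shape[1], ..., params.shape[ND-1], 1].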
3779   SmallVector<int32_t> flattened_coeff_vec;
3780   for (int i = 1; i < ND; i++) {
3781     flattened_coeff_vec.push_back(params_type.getShape()[i]);
3782   }
3783   flattened_coeff_vec.push_back(1);
3784 
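  // Turn the vector into suffix products, so entry i holds the row-major
  // stride product(params.shape[i+1:ND]).  For example (illustrative shapes
  // only), if the first ND = 3 params dimensions are {3, 4, 5}, the initial
  // {4, 5, 1} becomes the strides {20, 5, 1}, and the coordinate (i, j, k)
  // flattens to 20*i + 5*j + k.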
3785   for (int i = ND - 1; i > 0; i--) {
3786     flattened_coeff_vec[i - 1] *= flattened_coeff_vec[i];
3787   }
3788 
3789   llvm::Optional<Value> flattened_coeff_value = getConstTensor<int32_t>(
3790       rewriter, op, flattened_coeff_vec,
3791       {static_cast<int64_t>(flattened_coeff_vec.size())});
3792 
3793   if (!flattened_coeff_value) return llvm::None;
3794 
3795   // Multiply the coefficients by the coordinates
3796   auto flattened_indices_mul_op = CreateOpAndInfer<tosa::MulOp>(
3797       rewriter, op->getLoc(),
3798       RankedTensorType::get(indices_matrix_shape,
3799                             indices_type.getElementType()),
3800       indices_matrix_reshape_op.getResult(), flattened_coeff_value.getValue(),
3801       0);
3802 
3803   // Sum up the products of the coefficients and coordinates
3804   auto flattened_indices_reduce_op = CreateOpAndInfer<tosa::ReduceSumOp>(
3805       rewriter, op->getLoc(),
3806       RankedTensorType::get(tosa_indices_shape, indices_type.getElementType()),
3807       flattened_indices_mul_op.getResult(), rewriter.getI64IntegerAttr(1));
3808 
3809   // And reshape to [N, W]
3810   auto tosa_indices_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
3811       rewriter, op->getLoc(),
3812       RankedTensorType::get(tosa_indices_shape, indices_type.getElementType()),
3813       flattened_indices_reduce_op.getResult(),
3814       rewriter.getI64ArrayAttr(tosa_indices_shape));
3815 
3816   // Now the gather op itself
3817   auto tosa_gather_op = CreateOpAndInfer<tosa::GatherOp>(
3818       rewriter, op->getLoc(),
3819       RankedTensorType::get(tosa_gather_result_shape,
3820                             result_type.getElementType()),
3821       tosa_values_reshape_op.getResult(), tosa_indices_reshape_op.getResult());
3822 
3823   // Finally, reshape back to the original output shape of [Indices,
3824   // ParamChannels].
3825   return CreateOpAndInfer<tosa::ReshapeOp>(
3826              rewriter, op->getLoc(), result_type, tosa_gather_op.getResult(),
3827              rewriter.getI64ArrayAttr(result_type.getShape()))
3828       .getResult();
3829 }
3830 
3831 // Lowers OneHot operator to a sequence of TOSA ops.
3832 llvm::Optional<Value> convertOneHotOp(PatternRewriter& rewriter, Operation* op,
3833                                       Value result_value, Value indices_value,
3834                                       Value on_value, Value off_value,
3835                                       int32_t depth, int32_t axis) {
3836   auto result_type = result_value.getType().dyn_cast<RankedTensorType>();
3837   auto indices_type = indices_value.getType().dyn_cast<RankedTensorType>();
3838   auto on_value_type = on_value.getType().dyn_cast<RankedTensorType>();
3839   auto off_value_type = off_value.getType().dyn_cast<RankedTensorType>();
3840 
3841   if (!result_type || !indices_type || !on_value_type || !off_value_type)
3842     return llvm::None;
3843 
3844   // The OneHot operator creates a new tensor with shape indices.shape[:axis]
3845   // + [depth] + indices.shape[axis:].  Each index in 'indices' must be in the
3846   // range [0, depth - 1], and output[..., k, ...] = on_value if k == index,
3847   // or off_value if k != index.
3848   //
3849   // The lowering below assumes depth is always known at compile time.
3850   // Support for a depth resolved at run time is TBD.
3851   //
3852   // OneHot can be lowered to a TOSA Scatter, where off_value is mapped to
3853   // 'values_in', on_value is mapped to 'input', and indices maps naturally
3854   // to 'indices'.  The dimensions of the TOSA scatter (N, W, K, C) also
3855   // need to be picked.
3856   //
3857   // N: number of elements of input indices
3858   // Computed as:
3859   // product(indices.shape[:])
3860   //
3861   // K: newly added dimension
3862   // K = depth
3863   //
3864   // W, C: dummy dimensions for now
3865   // W = C = 1
3866   //
3867   // High level description of lowering looks like:
3868   // 1. off_value is reshaped/tiled into [N, K, C]
3869   // 2. on_value is reshaped/tiled into [N, W, C]
3870   // 3. indices is reshaped into [N, W]
3871   // 4. scatter into [N, K, C]
3872   // 5. reshaped into [LeftDims, RightDims, K]
3873   // 6. transpose into [LeftDims, K, RightDims]
3874   // 7. reshaped to result.shape
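  //
  // As a concrete, purely illustrative example: indices.shape = [2, 3],
  // depth = 4 and axis = 1 give N = 6, K = 4, W = C = 1, left_dim = 2 and
  // right_dim = 3.  The scatter produces [6, 4, 1], which is reshaped to
  // [2, 3, 4], transposed to [2, 4, 3], and reshaped to the result shape
  // [2, 4, 3].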
3875 
3876   if (on_value_type.getRank() != 0 || off_value_type.getRank() != 0) {
3877     op->emitOpError("OneHotOp: on_value/off_value needs to be scalar");
3878     return llvm::None;
3879   }
3880 
3881   if (axis < -1 || axis > indices_type.getRank()) {
3882     op->emitOpError("OneHotOp: axis out of valie range [-1, indices.rank]");
3883     return llvm::None;
3884   }
3885 
3886   // axis = -1 is equivalent to axis = indices.rank
3887   if (axis == -1) {
3888     axis = indices_type.getRank();
3889   }
3890 
3891   int N = 1, W = 1, C = 1;
3892   int K = depth;
3893   int left_dim = 1, right_dim = 1;
3894 
3895   for (int32_t i = 0; i < indices_type.getRank(); i++) {
3896     int32_t dim = indices_type.getShape()[i];
3897     N *= dim;
3898     if (i >= axis) {
3899       right_dim *= dim;
3900     } else {
3901       left_dim *= dim;
3902     }
3903   }
3904 
3905   // Reshape on_value to [1, 1, 1]
3906   auto op1_reshape_on_value = CreateOpAndInfer<tosa::ReshapeOp>(
3907       rewriter, op->getLoc(),
3908       RankedTensorType::get({1, 1, 1}, on_value_type.getElementType()),
3909       on_value, rewriter.getI64ArrayAttr({1, 1, 1}));
3910 
3911   // And tile to [N, W, C]
3912   auto op2_tile_op1 = CreateOpAndInfer<tosa::TileOp>(
3913       rewriter, op->getLoc(),
3914       RankedTensorType::get({N, W, C}, on_value_type.getElementType()),
3915       op1_reshape_on_value.getResult(), rewriter.getI64ArrayAttr({N, W, C}));
3916 
3917   // Reshape off_value to [1, 1, 1]
3918   auto op3_reshape_off_value = CreateOpAndInfer<tosa::ReshapeOp>(
3919       rewriter, op->getLoc(),
3920       RankedTensorType::get({1, 1, 1}, off_value_type.getElementType()),
3921       off_value, rewriter.getI64ArrayAttr({1, 1, 1}));
3922 
3923   // And tile to [N, K, C]
3924   auto op4_tile_op3 = CreateOpAndInfer<tosa::TileOp>(
3925       rewriter, op->getLoc(),
3926       RankedTensorType::get({N, K, C}, on_value_type.getElementType()),
3927       op3_reshape_off_value.getResult(), rewriter.getI64ArrayAttr({N, K, C}));
3928 
3929   // Reshape indices to [N, W]
3930   auto op5_reshape_indices = CreateOpAndInfer<tosa::ReshapeOp>(
3931       rewriter, op->getLoc(),
3932       RankedTensorType::get({N, W}, indices_type.getElementType()),
3933       indices_value, rewriter.getI64ArrayAttr({N, W}));
3934 
3935   // Scatter to [N, K, C]
3936   auto op6_scatter_op4_op5_op2 = CreateOpAndInfer<tosa::ScatterOp>(
3937       rewriter, op->getLoc(),
3938       RankedTensorType::get({N, K, C}, result_type.getElementType()),
3939       op4_tile_op3.getResult(), op5_reshape_indices.getResult(),
3940       op2_tile_op1.getResult());
3941 
3942   // Reshape to [LeftDims, RightDims, K]; C is squeezed out since it is 1.
3943   auto op7_reshape_op6 = CreateOpAndInfer<tosa::ReshapeOp>(
3944       rewriter, op->getLoc(),
3945       RankedTensorType::get({left_dim, right_dim, K},
3946                             result_type.getElementType()),
3947       op6_scatter_op4_op5_op2.getResult(),
3948       rewriter.getI64ArrayAttr({left_dim, right_dim, K}));
3949 
3950   // Transposed to [LeftDims, K, RightDims].
3951   llvm::Optional<Value> perm_const =
3952       getConstTensor<int32_t>(rewriter, op, /*vec=*/{0, 2, 1}, /*shape=*/{3});
3953 
3954   if (!perm_const) return llvm::None;
3955 
3956   auto op8_transpose_op7 = CreateOpAndInfer<tosa::TransposeOp>(
3957       rewriter, op->getLoc(),
3958       RankedTensorType::get({left_dim, K, right_dim},
3959                             result_type.getElementType()),
3960       op7_reshape_op6.getResult(), perm_const.getValue());
3961 
3962   // Reshaped to result.shape.
3963   return CreateOpAndInfer<tosa::ReshapeOp>(
3964              rewriter, op->getLoc(), result_type, op8_transpose_op7.getResult(),
3965              rewriter.getI64ArrayAttr(result_type.getShape()))
3966       .getResult();
3967 }
3968 
3969 };  // namespace tosa
3970 };  // namespace mlir
3971