1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file contains legalizations common to mapping both TensorFlow and
17 // TensorFlow Lite to TOSA. It operates generically on ops and does not have
18 // a hard reference on either dialect.
19 //
20 // Conversion functions return llvm::None on a legalization failure or a
21 // legalized value on success.  Callers must check for presence of an
22 // llvm::Optional value after each call.
23 
24 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
25 
26 #include <climits>
27 #include <cstddef>
28 #include <cstdint>
29 #include <iterator>
30 #include <numeric>
31 
32 #include "llvm/Support/FormatVariadic.h"
33 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
34 #include "mlir/Dialect/Tosa/IR/TosaOps.h"  // from @llvm-project
35 #include "mlir/IR/Matchers.h"  // from @llvm-project
36 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
37 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
38 
39 namespace mlir {
40 namespace tosa {
41 
42 // Lowers the Pack operator to TOSA.
43 llvm::Optional<Value> convertPackOp(PatternRewriter& rewriter, Operation* op,
44                                     Value result_value,
45                                     SmallVector<Value, 8>& inputs,
46                                     int32_t axis) {
47   //////////////////////////////////////////////////
48   // Operator: output = Pack([values], axis) or output = Stack([values], axis)
49   // Lowering:
50   //
51   // This operator is lowered into a series of pairwise tosa.concat()
52   // operators and a reshape.
53   // Depending on the inputs, a transpose operator is also generated:
54   //
55   // Step 1: concatenate the tensors
56   // a1_concat = tosa.concat(input[0], input[1], axis)
57   // for (i = 2; i < len(input); i++)
58   //   a1_concat = tosa.concat(a1_concat, input[i], axis)
59   //
60   // Step 2: reshape to N+1 dimensions
61   // a2_reshape = tosa.reshape(a1_concat, new_rank)
62   //
63   // Step 3: Transpose if a new dimension is being added:
64   //   if (axis == rank(values[0])):
65   //   // perm will be [1, 2, 3, 0]
66   //   a3_transpose = tosa.transpose(a2_reshape, perm)
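  //
  // For example, packing two [2, 3] tensors along axis = 0 concatenates them
  // into [4, 3] and reshapes to [2, 2, 3]; packing them along axis = 2
  // (== rank) concatenates along axis 0 into [4, 3], reshapes to [2, 2, 3],
  // then transposes with perm = [1, 2, 0] to produce [2, 3, 2].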
67 
68   // Sanity check 1: make sure all input tensors have the same shape.
69   // If input[0] has shape [A, B, C], input[1] to input[N-1] should also have
70   // shape [A, B, C].
71   RankedTensorType result_type =
72       result_value.getType().dyn_cast<RankedTensorType>();
73 
74   // Check for ranked tensor type.
75   if (!result_type) {
76     op->emitOpError("PackOp: result type not ranked tensor");
77     return llvm::None;
78   }
79 
80   // Valid axis in TF is [-rank(input), rank(input))
81   // Valid axis in TOSA is [0, rank(input))
82   // Add rank(input) once if axis is negative.
83   RankedTensorType input_type =
84       op->getOperand(0).getType().dyn_cast<RankedTensorType>();
85   if (!input_type) {
86     op->emitOpError("PackOp: input type not ranked tensor");
87     return llvm::None;
88   }
89 
90   int32_t input_rank = input_type.getShape().size();
91   if (axis < 0) axis += input_rank;
92 
93   input_type = inputs[0].getType().dyn_cast<RankedTensorType>();
94   if (!input_type) {
95     op->emitOpError("Input 0 type not ranked tensor.");
96     return llvm::None;
97   }
98   ArrayRef<int64_t> input0_tensor_shape = input_type.getShape();
99   int input_tensor_rank = input0_tensor_shape.size();
100 
101   for (int i = 1; i < inputs.size(); i++) {
102     input_type = inputs[i].getType().dyn_cast<RankedTensorType>();
103     if (!input_type) {
104       op->emitOpError(llvm::formatv(
105           "PackOp: input {0} type is not a ranked tensor",
106           i));
107       return llvm::None;
108     }
109     ArrayRef<int64_t> next_tensor_shape = input_type.getShape();
110     if (next_tensor_shape.size() != input_tensor_rank) {
111       op->emitOpError("PackOp: input tensor rank mismatch.");
112       return llvm::None;
113     }
114     for (int d = 0; d < input0_tensor_shape.size(); d++) {
115       if (input0_tensor_shape[d] != next_tensor_shape[d]) {
116         op->emitOpError("PackOp: input tensor shape mismatch.");
117         return llvm::None;
118       }
119     }
120   }
121 
122   // If the input tensors are rank 0, reshape them to rank 1, size 1 before
123   // performing the concat.
124   if (input_tensor_rank == 0) {
125     SmallVector<int64_t, 8> reshape_rank1_size1_shape{1};
126     RankedTensorType reshape_rank1_size1_type =
127         RankedTensorType::get(ArrayRef<int64_t>(reshape_rank1_size1_shape),
128                               result_type.getElementType());
129     ArrayAttr shape_rank1_size1_attr =
130         rewriter.getI64ArrayAttr(reshape_rank1_size1_shape);
131     for (int i = 0; i < inputs.size(); i++) {
132       auto a0_reshape_op = rewriter.create<tosa::ReshapeOp>(
133           op->getLoc(), reshape_rank1_size1_type, inputs[i],
134           shape_rank1_size1_attr);
135       inputs[i] = a0_reshape_op.getResult();
136     }
137   }
138 
139   // Sanity check 2: axis can be from [0, rank(input)+1]
140   // Where rank(input)+1 means create a new dimension
141   // Negative values are also allowed up to -(rank(input)+1)
142   // where the axis "wraps around".
143   if (axis < 0) axis += input_rank;
144 
145   if (axis > (input_tensor_rank + 1)) {
146     op->emitOpError("PackOp: axis out of valid range.");
147     return llvm::None;
148   }
149 
150   // Sanity check 3: if input shape is [A, B, C], output shape should be [N,
151   // A, B, C] (with N inserted at 'axis').
152   // 3.a check output rank is rank(input) + 1
153   SmallVector<int64_t, 8> output_shape_vals(result_type.getShape().begin(),
154                                             result_type.getShape().end());
155   if (output_shape_vals.size() != (input_tensor_rank + 1)) {
156     op->emitOpError("PackOp: output tensor rank mismatch.");
157     return llvm::None;
158   }
159   // 3.b check output dim on 'axis' is N
160   if (output_shape_vals[axis] != inputs.size()) {
161     op->emitOpError("PackOp: output tensor shape mismatch.");
162     return llvm::None;
163   }
164   // In most cases PackOp.axis() is within [0, rank(input) - 1], and
165   // we can directly concatenate along that axis and perform the reshape.
166   // For example, stacking N [A, B, C] input tensors along axis = 1:
167   // after concatenation the output will be [A, N * B, C],
168   // which is then reshaped into [A, N, B, C].
169   // A special case is PackOp.axis() equal to rank(input), in which case
170   // we can't directly concatenate along PackOp.axis(); instead
171   // we concat along axis = 0, reshape into [N, A, B, C],
172   // and then need an extra transpose to [A, B, C, N].
173   int64_t concat_axis;
174   SmallVector<int32_t, 8> perm;
175   SmallVector<int64_t, 8> reshape_output_shape;
176   if (axis == 0 && input_tensor_rank == 0) {
177     concat_axis = 0;
178     // No reshape or perm is needed, since the inputs were reshaped into
179     // rank 1, size 1.  Output will be rank 1, size N.
180   } else if (axis == input_tensor_rank) {
181     concat_axis = 0;
182 
183     // A special case when stack axis is equal to input tensor rank:
184     // Output shape is [A, B, C, N]
185     // so reshape output will be [N, A, B, C]
186     // and perm will be [1, 2, 3, 0].
187     reshape_output_shape.push_back(output_shape_vals[axis]);
188     for (int d = 0; d < input_tensor_rank; d++) {
189       perm.push_back(d + 1);
190       reshape_output_shape.push_back(output_shape_vals[d]);
191     }
192     perm.push_back(0);
193   } else {
194     // General case, doesn't need perm vector.
195     concat_axis = axis;
196     reshape_output_shape.assign(output_shape_vals.begin(),
197                                 output_shape_vals.end());
198   }
199   IntegerAttr concat_axis_attr = rewriter.getI64IntegerAttr(concat_axis);
200   ArrayAttr shape_attr = rewriter.getI64ArrayAttr(reshape_output_shape);
201 
202   // For each concat output, shape will be different.
203   // If input shape is [A, B, C] and concat_axis = 0, 1st concat output will
204   // be [2 * A, B, C].
205   int orig_input_dim_on_axis;
206   SmallVector<int64_t, 4> concat_output_shape;
207   if (input_tensor_rank == 0) {
208     concat_output_shape.push_back(1);
209     orig_input_dim_on_axis = 1;
210   } else {
211     for (int i = 0; i < input_tensor_rank; i++) {
212       concat_output_shape.push_back(input0_tensor_shape[i]);
213     }
214     orig_input_dim_on_axis = input0_tensor_shape[concat_axis];
215   }
216 
217   concat_output_shape[concat_axis] = orig_input_dim_on_axis * 2;
218   RankedTensorType concat_type = RankedTensorType::get(
219       ArrayRef<int64_t>(concat_output_shape), result_type.getElementType());
220   auto a1_concat_op = rewriter.create<tosa::ConcatOp>(
221       op->getLoc(), concat_type, inputs[0], inputs[1], concat_axis_attr);
222 
223   // K-th concat output will be [(k+1) * A, B, C], last output will be [N * A,
224   // B, C].
225   for (int i = 2; i < inputs.size(); i++) {
226     concat_output_shape[concat_axis] = orig_input_dim_on_axis * (i + 1);
227     concat_type = RankedTensorType::get(ArrayRef<int64_t>(concat_output_shape),
228                                         result_type.getElementType());
229     a1_concat_op = rewriter.create<tosa::ConcatOp>(op->getLoc(), concat_type,
230                                                    a1_concat_op.getResult(),
231                                                    inputs[i], concat_axis_attr);
232   }
233 
234   // Doesn't need reshape or transpose if input tensor is rank 0, since inputs
235   // are reshaped beforehand.
236   if (input_tensor_rank == 0) return a1_concat_op.getResult();
237 
238   // Reshape [N * A, B, C] to [N, A, B, C].
239   RankedTensorType reshape_output_type = RankedTensorType::get(
240       ArrayRef<int64_t>(reshape_output_shape), result_type.getElementType());
241 
242   auto a2_reshape_op = rewriter.create<tosa::ReshapeOp>(
243       op->getLoc(), reshape_output_type, a1_concat_op.getResult(), shape_attr);
244 
245   // If axis is equal to input tensor rank, then we need extra transpose
246   // [N, A, B, C] to [A, B, C, N]
247   if (axis == input_tensor_rank) {
248     Value a3_transpose_perm =
249         get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, perm);
250 
251     return rewriter
252         .create<tosa::TransposeOp>(op->getLoc(), result_type,
253                                    a2_reshape_op.getResult(), a3_transpose_perm)
254         .getResult();
255   }
256 
257   return a2_reshape_op.getResult();
258 }
259 
260 // Lowers the Unpack operator to TOSA
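// For example, unpacking a [2, 3, 4] tensor along axis = 1 produces three
// [2, 4] tensors: the input is transposed to [3, 2, 4], then each [1, 2, 4]
// slice is reshaped to [2, 4].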
261 llvm::Optional<ValueRange> convertUnpackOp(PatternRewriter& rewriter,
262                                            Operation* op, Value input_value,
263                                            int32_t axis) {
264   RankedTensorType input_type =
265       input_value.getType().dyn_cast<RankedTensorType>();
266   if (!input_type) return llvm::None;
267 
268   auto input_shape = input_type.getShape();
269   int64_t input_rank = input_shape.size();
270 
271   SmallVector<Value, 4> results_vec;
272 
273   // Negative axis allowed as long as it's within [-input_rank, input_rank).
274   if (axis < 0) axis += input_rank;
275 
276   assert(axis >= 0 && axis < input_shape.size());
277 
278   // A list of the output types for each slice op
279   SmallVector<Type, 4> outs_type_vec;
280 
281   // Step 1: transpose 'axis' to leftmost dimension.
282   Value transposed_input_value;
283   if (axis != 0) {
284     SmallVector<int32_t, 8> perm_vec;
285     SmallVector<int64_t, 2> a1_transpose_shape(input_rank);
286 
287     perm_vec.push_back(axis);
288     for (int i = 0; i < input_rank; i++) {
289       if (i == axis) continue;
290       perm_vec.push_back(i);
291     }
292 
293     Value a1_transpose_perm =
294         get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, perm_vec);
295 
296     for (int i = 0; i < input_rank; i++) {
297       a1_transpose_shape[i] = input_shape[perm_vec[i]];
298     }
299 
300     auto a1_transpose_op = rewriter.create<tosa::TransposeOp>(
301         op->getLoc(),
302         RankedTensorType::get(ArrayRef<int64_t>(a1_transpose_shape),
303                               input_type.getElementType()),
304         input_value, a1_transpose_perm);
305 
306     transposed_input_value = a1_transpose_op.getResult();
307   } else {
308     // Do nothing if axis is already at leftmost dimension.
309     transposed_input_value = input_value;
310   }
311 
312   // Step 2: slice [N, A, B, C] into N [A, B, C].
313   RankedTensorType transposed_input_type =
314       transposed_input_value.getType().dyn_cast<RankedTensorType>();
315   if (!transposed_input_type) return llvm::None;
316 
317   auto transposed_input_shape = transposed_input_type.getShape();
318   int64_t transposed_input_rank = transposed_input_shape.size();
319 
320   for (int i = 0; i < transposed_input_shape[0]; i++) {
321     SmallVector<int64_t, 4> begin_vals, size_vals, shape_vals;
322 
323     for (int j = 0; j < transposed_input_rank; j++) {
324       if (j == 0) {
325         begin_vals.push_back(i);
326         size_vals.push_back(1);
327       } else {
328         begin_vals.push_back(0);
329         size_vals.push_back(transposed_input_shape[j]);
330         shape_vals.push_back(transposed_input_shape[j]);
331       }
332     }
333 
334     ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
335     ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);
336 
337     auto a2_slice_op = rewriter.create<tosa::SliceOp>(
338         op->getLoc(),
339         RankedTensorType::get(ArrayRef<int64_t>(size_vals),
340                               transposed_input_type.getElementType()),
341         transposed_input_value, begin, size);
342 
343     auto a3_reshape_op = rewriter.create<tosa::ReshapeOp>(
344         op->getLoc(),
345         RankedTensorType::get(ArrayRef<int64_t>(shape_vals),
346                               transposed_input_type.getElementType()),
347         a2_slice_op.getResult(), rewriter.getI64ArrayAttr(shape_vals));
348 
349     outs_type_vec.push_back(RankedTensorType::get(
350         ArrayRef<int64_t>(shape_vals), transposed_input_type.getElementType()));
351 
352     results_vec.push_back(a3_reshape_op.getResult());
353   }
354 
355   // Combine the sequence of tosa.slice() ops into a list
356   // using the IdentityN operator.
357   return rewriter
358       .create<tosa::IdentityNOp>(op->getLoc(), ArrayRef<Type>(outs_type_vec),
359                                  results_vec)
360       .getResults();
361 }
362 
363 // Lowers the Select operator to TOSA.
364 llvm::Optional<Value> convertSelectOp(PatternRewriter& rewriter, Operation* op,
365                                       Value result_value, Value condition_value,
366                                       Value x_value, Value y_value) {
367   RankedTensorType result_type =
368       result_value.getType().dyn_cast<RankedTensorType>();
369   RankedTensorType condition_type =
370       condition_value.getType().dyn_cast<RankedTensorType>();
371   RankedTensorType x_type = x_value.getType().dyn_cast<RankedTensorType>();
372   RankedTensorType y_type = y_value.getType().dyn_cast<RankedTensorType>();
373 
374   if (!result_type || !condition_type || !x_type || !y_type) {
375     op->emitOpError("Select: failed ranked tensor type check");
376     return llvm::None;
377   }
378 
379   // First check whether we need to reshape the condition to match
380   // the same rank as the then/else clauses.
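  // For example, a [3] condition used with [2, 3] then/else operands is
  // reshaped to [1, 3] so it broadcasts against the result shape.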
381   if (result_type.getRank() == condition_type.getRank()) {
382     // Nothing to reshape.
383     return rewriter
384         .create<tosa::SelectOp>(op->getLoc(), result_type, condition_value,
385                                 x_value, y_value)
386         .getResult();
387   }
388 
389   // Need to reshape the condition.
390   SmallVector<int64_t, 8> new_cond_dims(
391       result_type.getRank() - condition_type.getRank(), 1);
392 
393   for (int i = 0; i < condition_type.getRank(); i++) {
394     new_cond_dims.push_back(condition_type.getShape()[i]);
395   }
396 
397   auto reshape_op = rewriter.create<tosa::ReshapeOp>(
398       op->getLoc(),
399       RankedTensorType::get(ArrayRef<int64_t>(new_cond_dims),
400                             condition_type.getElementType()),
401       condition_value, rewriter.getI64ArrayAttr(new_cond_dims));
402 
403   return rewriter
404       .create<tosa::SelectOp>(op->getLoc(), result_type, reshape_op, x_value,
405                               y_value)
406       .getResult();
407 }
408 
409 // Lowers the ZerosLike operator to TOSA by creating a constant
410 // of the desired type and shape.
411 llvm::Optional<Value> convertZerosLikeOp(PatternRewriter& rewriter,
412                                          Operation* op, Value result,
413                                          Value input) {
414   RankedTensorType result_type = result.getType().dyn_cast<RankedTensorType>();
415   if (!result_type) {
416     op->emitOpError("Zeroslike: result not ranked tensor type");
417     return llvm::None;
418   }
419 
420   RankedTensorType input_type = input.getType().dyn_cast<RankedTensorType>();
421   if (!input_type) {
422     op->emitOpError("Zeroslike: input not ranked tensor type");
423     return llvm::None;
424   }
425 
426   auto input_shape = input_type.getShape();
427 
428   ShapedType zero_type =
429       RankedTensorType::get(input_shape, input_type.getElementType());
430   Attribute zero_attr = rewriter.getZeroAttr(zero_type);
431 
432   return rewriter
433       .create<tosa::ConstOp>(op->getLoc(), zero_type,
434                              zero_attr.cast<ElementsAttr>())
435       .getResult();
436 }
437 
438 // Lowers the Mul operator to TOSA.  For quantized types, this requires
439 // inserting rescale operators before and after the operation.
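// The quantized path computes the product in int32: both inputs are rescaled
// to int32 with their zero points removed, multiplied, and the result is
// rescaled back to the output type using the multiplier
// (lhs_scale * rhs_scale) / output_scale. For example, with input scales 0.02
// and 0.5 and an output scale of 0.1, the final rescale multiplier is
// (0.02 * 0.5) / 0.1 = 0.1.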
440 llvm::Optional<Value> convertMultiplyOp(PatternRewriter& rewriter,
441                                         Operation* op, Value output_val,
442                                         Value input_lhs_val,
443                                         Value input_rhs_val) {
444   RankedTensorType input_lhs_type =
445       input_lhs_val.getType().dyn_cast<RankedTensorType>();
446   RankedTensorType input_rhs_type =
447       input_rhs_val.getType().dyn_cast<RankedTensorType>();
448   RankedTensorType output_type =
449       output_val.getType().dyn_cast<RankedTensorType>();
450   // Not a ranked tensor output
451   if (!input_lhs_type || !input_rhs_type || !output_type) return llvm::None;
452 
453   bool input_lhs_is_qtype =
454       input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
455   bool input_rhs_is_qtype =
456       input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
457   bool output_is_qtype =
458       output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
459 
460   if (input_lhs_is_qtype != output_is_qtype ||
461       input_rhs_is_qtype != output_is_qtype) {
462     op->emitOpError(
463         "ConvertMultiplyOp: input/output tensor should "
464         "be all quantized or all floating-point");
465     return llvm::None;
466   }
467 
468   Value output;
469   if (output_is_qtype) {
470     RankedTensorType rescale_type =
471         RankedTensorType::get(output_type.getShape(), rewriter.getI32Type());
472     auto input_lhs_qtype = input_lhs_type.getElementType()
473                                .cast<mlir::quant::UniformQuantizedType>();
474     auto input_rhs_qtype = input_rhs_type.getElementType()
475                                .cast<mlir::quant::UniformQuantizedType>();
476     auto output_qtype =
477         output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
478     double in_lhs_scale = input_lhs_qtype.getScale();
479     double in_rhs_scale = input_rhs_qtype.getScale();
480     double output_scale = output_qtype.getScale();
481 
482     double output_rescale_scale = in_lhs_scale * in_rhs_scale / output_scale;
483 
484     Value op1_rescale_lhs = buildRescaleToInt32(
485         rewriter, op, input_lhs_val, 1.0f, input_lhs_qtype.getZeroPoint());
486     Value op2_rescale_rhs = buildRescaleToInt32(
487         rewriter, op, input_rhs_val, 1.0f, input_rhs_qtype.getZeroPoint());
488     auto op3_mul_op1_op2 = rewriter.create<tosa::MulOp>(
489         op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs, 0);
490     return buildRescaleFromInt32(
491         rewriter, op, output_type, op3_mul_op1_op2.getResult(),
492         output_rescale_scale, output_qtype.getZeroPoint());
493   }
494 
495   return rewriter
496       .create<tosa::MulOp>(op->getLoc(), output_type, input_lhs_val,
497                            input_rhs_val, 0)
498       .getResult();
499 }
500 
501 // Lowers the SquaredDifference operator to TOSA.
502 llvm::Optional<Value> convertSquaredDifferenceOp(PatternRewriter& rewriter,
503                                                  Operation* op, Value result,
504                                                  Value x, Value y) {
505   // Squared-difference is (x-y)*(x-y).
506   // This lowering calculates the difference and multiplies.
507   RankedTensorType result_type = result.getType().dyn_cast<RankedTensorType>();
508   if (!result_type) {
509     op->emitOpError("SquaredDifference: result not ranked tensor type");
510     return llvm::None;
511   }
512 
513   RankedTensorType x_type = x.getType().dyn_cast<RankedTensorType>();
514   RankedTensorType y_type = y.getType().dyn_cast<RankedTensorType>();
515   if (!x_type || !y_type) {
516     op->emitOpError("SquaredDifference: inputs not ranked tensor type");
517     return llvm::None;
518   }
519 
520   auto sub_op = rewriter.create<tosa::SubOp>(op->getLoc(), result_type, x, y);
521   return rewriter
522       .create<tosa::MulOp>(op->getLoc(), result_type, sub_op.getResult(),
523                            sub_op.getResult(), 0)
524       .getResult();
525 }
526 
527 // Lowers the Round operator to TOSA.
528 llvm::Optional<Value> convertRoundOp(PatternRewriter& rewriter, Operation* op,
529                                      Value result, Value input) {
530   // Rounds to the nearest integer by calculating floor(input + 0.5).
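  // For example, round(2.3) = floor(2.8) = 2 and round(2.7) = floor(3.2) = 3.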
531   RankedTensorType result_type = result.getType().dyn_cast<RankedTensorType>();
532   if (!result_type) {
533     op->emitOpError("Round: result not ranked tensor type");
534     return llvm::None;
535   }
536 
537   RankedTensorType input_type = input.getType().dyn_cast<RankedTensorType>();
538   if (!input_type) {
539     op->emitOpError("Round: input not ranked tensor type");
540     return llvm::None;
541   }
542 
543   auto add_op = rewriter.create<tosa::AddOp>(
544       op->getLoc(), result_type, input,
545       getTosaConstTensorSingleF32(rewriter, op, 0.5));
546 
547   return rewriter
548       .create<tosa::FloorOp>(op->getLoc(), result_type, add_op.getResult())
549       .getResult();
550 }
551 
552 // Lowers ConcatV2 to TOSA.
553 llvm::Optional<Value> convertConcatV2Op(PatternRewriter& rewriter,
554                                         Operation* op, Value result_value,
555                                         SmallVector<Value, 8>& values,
556                                         int32_t axis) {
557   // ConcatV2 becomes a series of TOSA Concat operators that take pairs of
558   // tensors as arguments.   Rank-0 tensors are reshaped to Rank-1,
559   // shape (1,) tensors.
560   RankedTensorType result_type =
561       result_value.getType().dyn_cast<RankedTensorType>();
562   if (!result_type) {
563     op->emitOpError("ConcatV2Op: result type not ranked tensor.");
564     return llvm::None;
565   }
566 
567   // Valid axis in TF is [-rank(input), rank(input)).
568   // Valid axis in TOSA is [0, rank(input)).
569   // Add rank(input) once if axis is negative.
570   RankedTensorType input_type =
571       op->getOperand(0).getType().dyn_cast<RankedTensorType>();
572   if (!input_type) {
573     op->emitOpError("ConcatV2Op: input type not ranked tensor.");
574     return llvm::None;
575   }
576 
577   auto input_rank = input_type.getShape().size();
578 
579   if (axis < 0) axis += input_rank;
580 
581   assert(values.size() >= 2);
582 
583   if (!values[0].getType().dyn_cast<RankedTensorType>() ||
584       !values[1].getType().dyn_cast<RankedTensorType>()) {
585     op->emitOpError("ConcatV2Op: value type not ranked tensor.");
586     return llvm::None;
587   }
588 
589   Value lhs_val = values[0];
590   Value rhs_val = values[1];
591   RankedTensorType lhs_type = lhs_val.getType().cast<RankedTensorType>();
592   RankedTensorType rhs_type = rhs_val.getType().cast<RankedTensorType>();
593   ArrayRef<int64_t> lhs_tensor_shape = lhs_type.getShape();
594   ArrayRef<int64_t> rhs_tensor_shape = rhs_type.getShape();
595   int input_tensor_rank = lhs_tensor_shape.size();
596 
597   // For each concat output, the shape will be different.
598   // If the input tensors are rank 0, reshape them to rank 1, size 1 before
599   // performing the concat. Otherwise, most dimensions have the same size as
600   // the input, except along the concat'd axis.
601   //
602   // If input is [A0, B, C] and [A1, B, C] and axis = 0
603   // this concat output will be [A0 + A1, B, C].
604   SmallVector<int64_t, 4> concat_result_shape;
605   if (input_tensor_rank == 0) {
606     if (axis != 0) {
607       op->emitOpError("ConcatV2Op: axis invalid.");
608       return llvm::None;
609     }
610     SmallVector<int64_t, 8> reshape_rank1_size1_shape{1};
611     RankedTensorType reshape_rank1_size1_type =
612         RankedTensorType::get(ArrayRef<int64_t>(reshape_rank1_size1_shape),
613                               result_type.getElementType());
614     ArrayAttr shape_rank1_size1_attr =
615         rewriter.getI64ArrayAttr(reshape_rank1_size1_shape);
616     for (int i = 0; i < values.size(); i++) {
617       auto a0_reshape_op = rewriter.create<tosa::ReshapeOp>(
618           op->getLoc(), reshape_rank1_size1_type, values[i],
619           shape_rank1_size1_attr);
620       values[i] = a0_reshape_op.getResult();
621     }
622     concat_result_shape.push_back(2);
623   } else {
624     if (axis < 0 || axis >= input_tensor_rank) {
625       op->emitOpError("ConcatV2Op: axis invalid.");
626       return llvm::None;
627     }
628     for (int i = 0; i < input_tensor_rank; i++) {
629       concat_result_shape.push_back(lhs_tensor_shape[i]);
630     }
631     concat_result_shape[axis] = lhs_tensor_shape[axis] + rhs_tensor_shape[axis];
632   }
633 
634   RankedTensorType concat_type = RankedTensorType::get(
635       ArrayRef<int64_t>(concat_result_shape), result_type.getElementType());
636 
637   mlir::quant::UniformQuantizedType lhs_quant_type =
638       lhs_type.getElementType()
639           .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
640   mlir::quant::UniformQuantizedType rhs_quant_type =
641       rhs_type.getElementType()
642           .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
643   mlir::quant::UniformQuantizedType result_quant_type =
644       result_type.getElementType()
645           .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
646 
647   double lhs_scale, rhs_scale, result_scale;
648   int32_t lhs_zeropoint, rhs_zeropoint, result_zeropoint;
649 
650   // tfl.concat currently allows different scales for each input tensor, which
651   // the TFLite team will fix in:
652   // https://github.com/tensorflow/tensorflow/issues/39658
653   //
654   // For backward compatibility, we still need to support this artifact by
655   // scaling inputs to let them have the same scales.
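  // For example, an int8 input with scale 0.5 feeding an output with scale
  // 0.25 is rescaled by 0.5 / 0.25 = 2.0 (and shifted to the output zero
  // point) before being concatenated.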
656   if (result_quant_type && lhs_quant_type && rhs_quant_type) {
657     lhs_scale = static_cast<double>(lhs_quant_type.getScale());
658     lhs_zeropoint = lhs_quant_type.getZeroPoint();
659     rhs_scale = static_cast<double>(rhs_quant_type.getScale());
660     rhs_zeropoint = rhs_quant_type.getZeroPoint();
661     result_scale = static_cast<double>(result_quant_type.getScale());
662     result_zeropoint = result_quant_type.getZeroPoint();
663 
664     // Rescale input if scale is not equal to output tensor scale.
665     if (lhs_scale != result_scale) {
666       RankedTensorType rescale_type =
667           RankedTensorType::get(lhs_type.getShape(), result_quant_type);
668 
669       Value rescale_op = buildRescale(rewriter, op, rescale_type, lhs_val,
670                                       lhs_scale / result_scale, lhs_zeropoint,
671                                       result_zeropoint);
672 
673       lhs_val = rescale_op;
674     }
675     if (rhs_scale != result_scale) {
676       RankedTensorType rescale_type =
677           RankedTensorType::get(rhs_type.getShape(), result_quant_type);
678 
679       Value rescale_op = buildRescale(rewriter, op, rescale_type, rhs_val,
680                                       rhs_scale / result_scale, rhs_zeropoint,
681                                       result_zeropoint);
682 
683       rhs_val = rescale_op;
684     }
685   }
686 
687   auto concat_op = rewriter.create<tosa::ConcatOp>(
688       op->getLoc(), concat_type, lhs_val, rhs_val,
689       rewriter.getI64IntegerAttr(axis));
690   for (int i = 2; i < values.size(); i++) {
691     rhs_val = values[i];
692     rhs_type = rhs_val.getType().dyn_cast<RankedTensorType>();
693     rhs_tensor_shape = rhs_type.getShape();
694     rhs_quant_type = rhs_type.getElementType()
695                          .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
696 
697     if (input_tensor_rank == 0) {
698       concat_result_shape[axis] = concat_result_shape[axis] + 1;
699     } else {
700       concat_result_shape[axis] =
701           concat_result_shape[axis] + rhs_tensor_shape[axis];
702     }
703     concat_type = RankedTensorType::get(ArrayRef<int64_t>(concat_result_shape),
704                                         result_type.getElementType());
705 
706     if (rhs_quant_type && result_quant_type) {
707       rhs_scale = static_cast<double>(rhs_quant_type.getScale());
708       rhs_zeropoint = rhs_quant_type.getZeroPoint();
709 
710       if (rhs_scale != result_scale) {
711         RankedTensorType rescale_type =
712             RankedTensorType::get(rhs_type.getShape(), result_quant_type);
713 
714         Value rescale_op = buildRescale(rewriter, op, rescale_type, rhs_val,
715                                         rhs_scale / result_scale, rhs_zeropoint,
716                                         result_zeropoint);
717 
718         rhs_val = rescale_op;
719       }
720     }
721 
722     concat_op = rewriter.create<tosa::ConcatOp>(
723         op->getLoc(), concat_type, concat_op.getResult(), rhs_val,
724         rewriter.getI64IntegerAttr(axis));
725   }
726 
727   return concat_op.getResult();
728 }
729 
730 // Lowers SpaceToBatchND to TOSA.
731 llvm::Optional<Value> convertSpaceToBatchNDOp(PatternRewriter& rewriter,
732                                               Operation* op, Value result_value,
733                                               Value input_value,
734                                               Value block_shape_value,
735                                               Value paddings_value) {
736   /////////////////////////////////////////////////
737   // Operator: output = SpaceToBatchND(input, block_shape, paddings)
738   // Lowering:
739   //
740   // SpaceToBatch input tensors are broken into three pieces:
741   //   (a) batch dimension (N in NHWC)
742   //   (b) input being transformed to batch dimension (typically H, W in NHWC)
743   //   (c) remainder of input (typically C in NHWC)
744   //
745   // Step 0. Generate the padding constant used by the pad in step 1.
746   //   No padding on the batch dimension
747   //   The input paddings array is addressed as [input_rank][2]
748   //   No padding on the remaining dimensions
749   //
750   //  a0_pad_const = tosa.const(input=Tensor<input_rank, 2>)
751   //
752   // Step 1. Pad the input tensor
753   //
754   //  a1_pad_input_op = tosa.pad(input=input, shape=a0_pad_const_op)
755   //
756   // Step 2. Reshape the padded structure of shape padded_shape to
757   // [batch, padded_shape[1] / block_shape[0], block_shape[0], ...
758   //    padded_shape[M] / block_shape[M-1], block_shape[M-1]] +
759   //    remaining_shape
760   //
761   // block_rank = M (number of elements in block_shape)
762   // New rank: input_rank + block_rank
763   //
764   //  a2_reshape_a1_op = tosa.reshape(input=a1_pad_input_op, shape=a2_shape)
765   //
766   // Step 3. Transpose dimensions to:
767   //  block-shape +
768   //  [batch] +
769   //  [padded_shape[1] / block_shape[0],
770   // ...
771   //  [padded_shape[M] / block_shape[M-1]] +
772   //  remaining_shape
773   //
774   // a3_transpose_a2_op = tosa.transpose(input=a2_reshape_a1_op,
775   // perms=a3_perm)
776   //
777   // Step 4. Reshape the transposed tensor to flatten block_shape stuff
778   // into the batch dimension with the following shape:
779   // [ batch * prod(block_shape)] +
780   // [ padded_shape[1] / block_shape[0],
781   //   ...,
782   // padded_shape[M] / block_shape[M-1]] +
783   // remaining_shape
784   //
785   //  a4_reshape_a3_op = tosa.reshape(input=a3_transpose_a2_op,
786   //  shape=a3_shape)
787   //
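  // For example (shapes only): input [1, 4, 4, 1] with block_shape = [2, 2]
  // and zero paddings stays [1, 4, 4, 1] after padding, reshapes to
  // [1, 2, 2, 2, 2, 1], transposes with perm = [2, 4, 0, 1, 3, 5] to
  // [2, 2, 1, 2, 2, 1], and finally reshapes to the [4, 2, 2, 1] result.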
788 
789   RankedTensorType result_type =
790       result_value.getType().dyn_cast<RankedTensorType>();
791   RankedTensorType input_type =
792       input_value.getType().dyn_cast<RankedTensorType>();
793   RankedTensorType block_shape_type =
794       block_shape_value.getType().dyn_cast<RankedTensorType>();
795   RankedTensorType paddings_type =
796       paddings_value.getType().dyn_cast<RankedTensorType>();
797 
798   // Not a ranked tensor output.
799   if (!result_type) {
800     op->emitOpError("SpaceToBatchND: result type not ranked tensor");
801     return llvm::None;
802   }
803   if (!input_type) {
804     op->emitOpError("SpaceToBatchND: input type not ranked tensor");
805     return llvm::None;
806   }
807   if (!block_shape_type) {
808     op->emitOpError("SpaceToBatchND: block shape type not ranked tensor");
809     return llvm::None;
810   }
811   if (!paddings_type) {
812     op->emitOpError("SpaceToBatchND: paddings type not ranked tensor");
813     return llvm::None;
814   }
815 
816   // Follow implementation in
817   // tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
818 
819   // So, to figure out the spatial_shape, remove the batch dimension and
820   // then use the next block_rank dimensions.  The remaining dimensions are
821   // remaining_shape.
822 
823   auto block_shape = block_shape_type.getShape();
824   auto input_shape = input_type.getShape();
825 
826   int block_rank = block_shape[0];
827   int batch_size = input_shape[0];
828   int input_rank = input_type.getRank();
829   int remaining_shape_rank = input_rank - block_rank - 1;
830   int block_num_elems = 1;
831   int padding_sum = 0;
832 
833   ElementsAttr block_shape_elems;
834   ElementsAttr paddings_elems;
835 
836   if (!matchPattern(block_shape_value, m_Constant(&block_shape_elems)))
837     return llvm::None;
838 
839   if (!matchPattern(paddings_value, m_Constant(&paddings_elems)))
840     return llvm::None;
841 
842   SmallVector<int32_t, 2> a0_pad_const(2 * (input_rank));
843   SmallVector<int64_t, 2> padded_shape(input_rank);
844 
845   // 1. Pad based on paddings operand.  No padding on the batch dimension.
846   // The a0_pad_const array is addressed as [input_rank][2], but
847   // it is flattened to a 1D array because LLVM appears to only accept 1D.
848   //
849   // padded_shape[] is the shape of the padded output of step a1.
850   // The name is retained for consistency with the TF reference code.
851   padded_shape[0] = input_shape[0];
852 
853   // Batch dimension padding
854   a0_pad_const[0] = 0;
855   a0_pad_const[1] = 0;
856 
857   // This iterator seems to be the only reliable way to get
858   // int values out of a multi-dimensional ElementsAttr.
859   int idx = 0;
860 
861   for (auto i : paddings_elems.getValues<IntegerAttr>()) {
862     a0_pad_const[idx + 2] = i.getInt();
863     padding_sum += i.getInt();
864     idx++;
865   }
866 
867   // Insert padding on the spatial shape dimensions
868   for (int i = 0; i < block_rank; i++) {
869     int32_t lo_pad = a0_pad_const[2 * (i + 1) + 0];
870     int32_t hi_pad = a0_pad_const[2 * (i + 1) + 1];
871 
872     padded_shape[i + 1] = input_shape[i + 1] + lo_pad + hi_pad;
873   }
874 
875   // No padding on the remaining_shape dimensions
876   for (int i = 0; i < remaining_shape_rank; i++) {
877     a0_pad_const[2 * (i + block_rank + 1) + 0] = 0;
878     a0_pad_const[2 * (i + block_rank + 1) + 1] = 0;
879     padded_shape[i + block_rank + 1] = input_shape[i + block_rank + 1];
880   }
881 
882   RankedTensorType a0_pad_const_attr_type =
883       RankedTensorType::get({(input_rank), 2}, rewriter.getIntegerType(32));
884 
885   // Create a const op to generate the tensor type for the input padding array
886   auto a0_pad_const_op = rewriter.create<tosa::ConstOp>(
887       op->getLoc(), a0_pad_const_attr_type,
888       DenseElementsAttr::get(a0_pad_const_attr_type,
889                              llvm::makeArrayRef<int32_t>(a0_pad_const)));
890 
891   auto a1_pad_input_op = rewriter.create<tosa::PadOp>(
892       op->getLoc(),
893       RankedTensorType::get(ArrayRef<int64_t>(padded_shape),
894                             result_type.getElementType()),
895       input_value, a0_pad_const_op.getResult());
896 
897   // 2. Reshape the padded structure of shape padded_shape to
898   // [batch, padded_shape[1] / block_shape[0], block_shape[0], ...
899   //    padded_shape[M] / block_shape[M-1], block_shape[M-1]] +
900   //    remaining_shape
901 
902   // block_rank = M (number of elements in block_shape)
903   // New rank: input_rank + block_rank
904   SmallVector<int64_t, 2> a2_shape(1 + block_rank * 2 + remaining_shape_rank);
905 
906   // First dimension is batch.
907   a2_shape[0] = input_type.getShape()[0];
908   for (int i = 0; i < block_rank; i++) {
909     int32_t block_shape_val =
910         rewriter
911             .getI32IntegerAttr(
912                 block_shape_elems.getValue<IntegerAttr>(i).getInt())
913             .getInt();
914     a2_shape[1 + i * 2 + 0] = padded_shape[1 + i] / block_shape_val;
915     a2_shape[1 + i * 2 + 1] = block_shape_val;
916     block_num_elems *= block_shape_val;
917   }
918 
919   // Copy in the remaining block shape.
920   for (int i = 0; i < remaining_shape_rank; i++) {
921     a2_shape[1 + block_rank * 2 + i] = input_shape[1 + block_rank + i];
922   }
923 
924   auto a2_reshape_a1_op = rewriter.create<tosa::ReshapeOp>(
925       op->getLoc(),
926       RankedTensorType::get(ArrayRef<int64_t>(a2_shape),
927                             result_type.getElementType()),
928       a1_pad_input_op.getResult(), rewriter.getI64ArrayAttr(a2_shape));
929 
930   // 3. Transpose dimensions to:
931   //  block-shape +
932   //  [batch] +
933   //  [padded_shape[1] / block_shape[0],
934   // ...
935   //  [padded_shape[M] / block_shape[M-1]] +
936   //  remaining_shape
937   int32_t a2_reshape_a1_rank =
938       a2_reshape_a1_op.getResult().getType().cast<RankedTensorType>().getRank();
939   SmallVector<int32_t, 8> a3_perm(a2_reshape_a1_rank);
940   SmallVector<int64_t, 2> a3_transpose_shape(a2_reshape_a1_rank);
941 
942   for (int i = 0; i < block_rank; i++) {
943     a3_perm[i] = 1 + 2 * i + 1;
944     a3_perm[block_rank + 1 + i] = 1 + 2 * i;
945   }
946   a3_perm[block_rank] = 0;
947   for (int i = 1 + block_rank * 2; i < a2_reshape_a1_rank; i++) {
948     a3_perm[i] = i;
949   }
950 
951   for (int i = 0; i < a3_transpose_shape.size(); i++) {
952     a3_transpose_shape[i] = a2_shape[a3_perm[i]];
953   }
954 
955   Value a3_transpose_const =
956       get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, a3_perm);
957 
958   auto a3_transpose_a2_op = rewriter.create<tosa::TransposeOp>(
959       op->getLoc(),
960       RankedTensorType::get(ArrayRef<int64_t>(a3_transpose_shape),
961                             result_type.getElementType()),
962       a2_reshape_a1_op.getResult(), a3_transpose_const);
963 
964   // 4. Reshape the transposed tensor to flatten block_shape
965   // into the batch dimension with the following shape:
966   // [ batch * prod(block_shape)] +
967   // [ padded_shape[1] / block_shape[0],
968   //   ...,
969   // padded_shape[M] / block_shape[M-1]] +
970   // remaining_shape
971   SmallVector<int64_t, 2> a4_reshape_shape(input_rank);
972 
973   // Batch
974   a4_reshape_shape[0] = batch_size * block_num_elems;
975 
976   // padded shape / block_shape.
977   for (int i = 0; i < block_rank; i++) {
978     int32_t block_shape_val =
979         rewriter
980             .getI32IntegerAttr(
981                 block_shape_elems.getValue<IntegerAttr>(i).getInt())
982             .getInt();
983     a4_reshape_shape[i + 1] = padded_shape[i + 1] / block_shape_val;
984   }
985 
986   // Copy in remainder shape.
987   for (int i = 0; i < remaining_shape_rank; i++) {
988     a4_reshape_shape[1 + block_rank + i] = input_shape[1 + block_rank + i];
989   }
990 
991   return rewriter
992       .create<tosa::ReshapeOp>(op->getLoc(), result_type,
993                                a3_transpose_a2_op.getResult(),
994                                rewriter.getI64ArrayAttr(a4_reshape_shape))
995       .getResult();
996 }
997 
998 // Lowers BatchToSpaceND to TOSA.
999 llvm::Optional<Value> convertBatchToSpaceNDOp(PatternRewriter& rewriter,
1000                                               Operation* op, Value result_value,
1001                                               Value input_value,
1002                                               Value block_shape_value,
1003                                               Value crops_value) {
1004   /////////////////////////////////////////////////
1005   // Operator: output = BatchToSpaceND(input, block_shape, crops)
1006   // Lowering:
1007   //
1008   // BatchToSpace input tensors are broken into three pieces:
1009   //   (a) batch dimension (N in NHWC)
1010   //   (b) input being transformed from batch dimension (typically H, W in
1011   //   NHWC)
1012   //   (c) remainder of input (typically C in NHWC)
1013   //
1014   // Step 1. Reshape input to:
1015   // [block_shape[0],
1016   // ...
1017   // [block_shape[M-1],
1018   // [batch / prod(block_shape)]
1019   // [input_shape[1],
1020   // ...
1021   // [input_shape[N-1]
1022   //
1023   // a1_reshape_input_op = tosa.reshape(input=input, shape=a1_shape)
1024   //
1025   // Step 2. Permute to shape
1026   // [ batch / prod(block_shape) ],
1027   // [ input_shape[1] ], [ block_shape[0] ]
1028   //  ...
1029   // [ input_shape[M] ], [ block_shape[M-1] ]
1030   // + remaining_input_shapes input_shape[M .. N-1]
1031   //
1032   // a2_transpose_a1 = tosa.transpose(input=a1_reshape_input_op,
1033   // shape=a2_shape)
1034   //
1035   // Step 3. Reshape to:
1036   // [ batch / prod(block_shape) ],
1037   // [input_shape[1] * block_shape[0] ],
1038   //    ..
1039   // [input_shape[M] * block_shape[M-1] ],
1040   // + remaining input shapes [input_shape[M+1.. N-1]]
1041   //
1042   // a3_reshape_a2 = tosa.reshape(input=a2_transpose_a1, shape=a3_shape)
1043   //
1044   // Step 4. Crop the start/end dimensions according to crops of the
1045   // a3_reshape_a2 shape
1046   //
1047   // a4_slice_a3 = tosa.slice(input=a3_reshape_a2, start=a4_start,
1048   // size=a4_size)
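  //
  // For example (shapes only): input [4, 2, 2, 1] with block_shape = [2, 2]
  // and zero crops reshapes to [2, 2, 1, 2, 2, 1], transposes with
  // perm = [2, 3, 0, 4, 1, 5] to [1, 2, 2, 2, 2, 1], reshapes to [1, 4, 4, 1],
  // and slices the full extent (nothing is cropped).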
1049 
1050   RankedTensorType result_type =
1051       result_value.getType().dyn_cast<RankedTensorType>();
1052   RankedTensorType input_type =
1053       input_value.getType().dyn_cast<RankedTensorType>();
1054   RankedTensorType block_shape_type =
1055       block_shape_value.getType().dyn_cast<RankedTensorType>();
1056   RankedTensorType crops_type =
1057       crops_value.getType().dyn_cast<RankedTensorType>();
1058 
1059   if (!result_type) {
1060     op->emitOpError("BatchToSpaceND: result type not ranked tensor");
1061     return llvm::None;
1062   }
1063   if (!input_type) {
1064     op->emitOpError("BatchToSpaceND: input type not ranked tensor");
1065     return llvm::None;
1066   }
1067   if (!block_shape_type) {
1068     op->emitOpError("BatchToSpaceND: block shape type not ranked tensor");
1069     return llvm::None;
1070   }
1071   if (!crops_type) {
1072     op->emitOpError("BatchToSpaceND: crops type not ranked tensor");
1073     return llvm::None;
1074   }
1075 
1076   // Another 4-step process
1077   int block_rank = block_shape_type.getShape()[0];
1078   int input_rank = input_type.getRank();
1079   int crops_dims = crops_type.getShape()[0];
1080   int remaining_shape_rank = input_rank - block_rank - 1;
1081   auto input_shape = input_type.getShape();
1082 
1083   ElementsAttr block_shape_elems;
1084   ElementsAttr crops_elems;
1085 
1086   if (!matchPattern(block_shape_value, m_Constant(&block_shape_elems))) {
1087     op->emitOpError("BatchToSpaceND: block_shape not a constant");
1088     return llvm::None;
1089   }
1090 
1091   if (!matchPattern(crops_value, m_Constant(&crops_elems))) {
1092     op->emitOpError("BatchToSpaceND: crops not a constant");
1093     return llvm::None;
1094   }
1095 
1096   SmallVector<int64_t, 4> block_shape(block_rank);
1097   SmallVector<std::pair<int64_t, int64_t>, 4> crops(crops_dims);
1098 
1099   // Extract values for block_shape and crops now.
1100   int block_num_elems = 1;
1101   for (int i = 0; i < block_rank; i++) {
1102     int block_shape_val =
1103         rewriter
1104             .getI32IntegerAttr(
1105                 block_shape_elems.getValue<IntegerAttr>(i).getInt())
1106             .getInt();
1107     block_num_elems *= block_shape_val;
1108     block_shape[i] = block_shape_val;
1109   }
1110 
1111   // This iterator seems to be the only reliable way to get
1112   // int values out of a multi-dimensional ElementsAttr
1113   SmallVector<int32_t, 2> crops_const(2 * (crops_dims));
1114   int idx = 0;
1115   for (auto i : crops_elems.getValues<IntegerAttr>()) {
1116     crops_const[idx++] = i.getInt();
1117   }
1118 
1119   for (int i = 0; i < crops_dims; i++) {
1120     int crops_lo = crops_const[i * 2 + 0];
1121     int crops_hi = crops_const[i * 2 + 1];
1122     crops[i] = std::make_pair(crops_lo, crops_hi);
1123   }
1124 
1125   // Step 1. Reshape input to:
1126   // [block_shape[0],
1127   // ...
1128   // [block_shape[M-1],
1129   // [batch / prod(block_shape)]
1130   // [input_shape[1],
1131   // ...
1132   // [input_shape[N-1]
1133   SmallVector<int64_t, 2> a1_shape(block_rank + input_rank);
1134 
1135   for (int i = 0; i < block_rank; i++) a1_shape[i] = block_shape[i];
1136 
1137   a1_shape[block_rank] = input_shape[0] / block_num_elems;
1138 
1139   for (int i = 0; i < input_rank - 1; i++)
1140     a1_shape[i + block_rank + 1] = input_shape[i + 1];
1141 
1142   auto a1_reshape_input_op = rewriter.create<tosa::ReshapeOp>(
1143       op->getLoc(),
1144       RankedTensorType::get(ArrayRef<int64_t>(a1_shape),
1145                             result_type.getElementType()),
1146       input_value, rewriter.getI64ArrayAttr(a1_shape));
1147 
1148   // 2. Permute to shape
1149   // [ batch / prod(block_shape) ],
1150   // [ input_shape[1] ], [ block_shape[0] ]
1151   //  ...
1152   // [ input_shape[M] ], [ block_shape[M-1] ]
1153   // + remaining_input_shapes input_shape[M+1 .. N-1]
1154 
1155   // 2a. calculate the permutation
1156   SmallVector<int32_t, 8> a2_perm(block_rank + input_rank);
1157   SmallVector<int64_t, 2> a2_transpose_shape(block_rank + input_rank);
1158 
1159   a2_perm[0] = block_rank;
1160   for (int i = 0; i < block_rank; i++) {
1161     a2_perm[1 + i * 2 + 0] = block_rank + 1 + i;
1162     a2_perm[1 + i * 2 + 1] = i;
1163   }
1164 
1165   for (int i = 0; i < remaining_shape_rank; i++) {
1166     a2_perm[1 + 2 * block_rank + i] = 1 + 2 * block_rank + i;
1167   }
1168 
1169   // 2b. calculate the a2_permuted shape
1170   for (int i = 0; i < (block_rank + input_rank); i++) {
1171     a2_transpose_shape[i] = a1_shape[a2_perm[i]];
1172   }
1173 
1174   Value a2_transpose_perm =
1175       get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, a2_perm);
1176   auto a2_transpose_a1_op = rewriter.create<tosa::TransposeOp>(
1177       op->getLoc(),
1178       RankedTensorType::get(ArrayRef<int64_t>(a2_transpose_shape),
1179                             result_type.getElementType()),
1180       a1_reshape_input_op.getResult(), a2_transpose_perm);
1181 
1182   // Step 3. Reshape to:
1183   // [ batch / prod(block_shape) ],
1184   // [input_shape[1] * block_shape[0] ],
1185   //    ..
1186   // [input_shape[M] * block_shape[M-1] ],
1187   // + remaining input shapes [input_shape[M+1.. N-1]]
1188   SmallVector<int64_t, 2> a4_shape(input_rank);
1189 
1190   a4_shape[0] = input_shape[0] / block_num_elems;
1191   for (int i = 0; i < block_rank; i++) {
1192     a4_shape[1 + i] = input_shape[i + 1] * block_shape[i];
1193   }
1194   for (int i = 0; i < remaining_shape_rank; i++) {
1195     a4_shape[1 + block_rank + i] = input_shape[block_rank + 1 + i];
1196   }
1197 
1198   auto a3_reshape_a2 = rewriter.create<tosa::ReshapeOp>(
1199       op->getLoc(),
1200       RankedTensorType::get(ArrayRef<int64_t>(a4_shape),
1201                             result_type.getElementType()),
1202       a2_transpose_a1_op.getResult(), rewriter.getI64ArrayAttr(a4_shape));
1203 
1204   // 4. Crop the start/end dimensions on 'spatial dimension' according to
1205   // crops
1206   // Use a slice operator to do the cropping.
1207   //
1208   // Calculate a beginning point and a size:
1209   // - Begin is the origin, offset by the lo crop amount in each dimension
1210   // - Size is the reshaped tensor size, minus the quantity (lo + hi) for each
1211   // dimension
1212   SmallVector<int64_t, 4> a4_begin_vals(input_rank), a4_size_vals(input_rank);
1213 
1214   for (int i = 0; i < input_rank; i++) {
1215     // Batch dimension and remaining dimensions.
1216     if (i == 0 || i > crops_dims) {
1217       a4_begin_vals[i] = 0;
1218       a4_size_vals[i] = result_type.getShape()[i];
1219     } else {
1220       // Spatial dimension.
1221       assert(i - 1 >= 0 && i - 1 < crops_dims);
1222       a4_begin_vals[i] = crops[i - 1].first;
1223       a4_size_vals[i] = a4_shape[i] - crops[i - 1].first - crops[i - 1].second;
1224     }
1225   }
1226 
1227   return rewriter
1228       .create<tosa::SliceOp>(
1229           op->getLoc(),
1230           RankedTensorType::get(ArrayRef<int64_t>(a4_size_vals),
1231                                 result_type.getElementType()),
1232           a3_reshape_a2.getResult(), rewriter.getI64ArrayAttr(a4_begin_vals),
1233           rewriter.getI64ArrayAttr(a4_size_vals))
1234       .getResult();
1235 }
1236 
1237 // Lowers ExpandDims to TOSA.
1238 llvm::Optional<Value> convertExpandDimsOp(PatternRewriter& rewriter,
1239                                           Operation* op, Value result_value,
1240                                           Value input_value, Value dim_value) {
1241   // Lowers to a reshape op with 1's inserted in the appropriate dimensions.
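  // For example, expanding a [2, 3] input at dim = 1 reshapes it to [2, 1, 3];
  // a dim at or beyond the input rank appends the new dimension, giving
  // [2, 3, 1].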
1242   RankedTensorType output_type =
1243       result_value.getType().dyn_cast<RankedTensorType>();
1244   // Not a ranked tensor output
1245   if (!output_type) {
1246     op->emitOpError("ExpandDims: output type not ranked tensor");
1247     return llvm::None;
1248   }
1249 
1250   RankedTensorType input_type =
1251       input_value.getType().dyn_cast<RankedTensorType>();
1252   if (!input_type) {
1253     op->emitOpError("ExpandDims: input type not ranked tensor");
1254     return llvm::None;
1255   }
1256 
1257   auto input_shape = input_type.getShape();
1258 
1259   ElementsAttr dim_elem;
1260   if (!matchPattern(dim_value, m_Constant(&dim_elem))) return llvm::None;
1261 
1262   assert(dim_elem.getType().getRank() == 0 && "expected scalar tensor");
1263   int32_t dim = dim_elem.getValue<IntegerAttr>({}).getInt();
1264 
1265   SmallVector<int64_t, 4> reshape_dims;
1266   if (dim < 0 || dim >= input_shape.size()) {  // add dim at end of tensor
1267     dim = input_shape.size();
1268     for (int i = 0; i < input_shape.size(); i++) {
1269       reshape_dims.emplace_back(input_shape[i]);
1270     }
1271     reshape_dims.emplace_back(1);
1272   } else {
1273     for (int i = 0; i < input_shape.size(); i++) {
1274       if (i == dim) {
1275         reshape_dims.emplace_back(1);
1276       }
1277       reshape_dims.emplace_back(input_shape[i]);
1278     }
1279   }
1280 
1281   ArrayAttr shape_attr = rewriter.getI64ArrayAttr(reshape_dims);
1282 
1283   return rewriter
1284       .create<tosa::ReshapeOp>(op->getLoc(), output_type, input_value,
1285                                shape_attr)
1286       .getResult();
1287 }
1288 
1289 // Lowers Squeeze to TOSA.
1290 llvm::Optional<Value> convertSqueezeOp(PatternRewriter& rewriter, Operation* op,
1291                                        Value result_value, Value input_value,
1292                                        SmallVector<int32_t, 8>& squeeze_dims) {
1293   // Lowers to a reshape op where dimensions in squeeze_dims with size=1
1294   // are removed.
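  // For example, squeezing a [2, 1, 3, 1] input with squeeze_dims = {1}
  // reshapes to [2, 3, 1]; with an empty squeeze_dims all size-1 dimensions
  // are removed, giving [2, 3].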
1295   RankedTensorType output_type =
1296       result_value.getType().dyn_cast<RankedTensorType>();
1297   // Not a ranked tensor output
1298   if (!output_type) {
1299     op->emitOpError("Squeeze: output type not ranked tensor");
1300     return llvm::None;
1301   }
1302 
1303   RankedTensorType input_type =
1304       input_value.getType().dyn_cast<RankedTensorType>();
1305   if (!input_type) {
1306     op->emitOpError("Squeeze: input type not ranked tensor");
1307     return llvm::None;
1308   }
1309 
1310   auto input_shape = input_type.getShape();
1311 
1312   SmallVector<int64_t, 8> reshape_dims;
1313 
1314   if (squeeze_dims.empty()) {  // remove all 1-dims
1315     for (int i = 0; i < input_shape.size(); i++) {
1316       if (input_shape[i] != 1) {
1317         reshape_dims.emplace_back(input_shape[i]);
1318       }
1319     }
1320   } else {
1321     // Remove only specified dims.
1322     // First sort the array so they can be picked off in sequence.
1323     std::sort(squeeze_dims.begin(), squeeze_dims.end(),
1324               [](const int32_t& a, const int32_t& b) { return a < b; });
1325 
1326     int pos = 0;
1327     auto dim = squeeze_dims[pos];
1328     for (int i = 0; i < input_shape.size(); i++) {
1329       if (i == dim) {
1330         pos = pos + 1;
1331         if (pos < squeeze_dims.size())
1332           dim = squeeze_dims[pos];
1333         else
1334           dim = -1;  // Invalid
1335       } else {
1336         reshape_dims.emplace_back(input_shape[i]);
1337       }
1338     }
1339   }
1340 
1341   ArrayAttr shape_attr = rewriter.getI64ArrayAttr(reshape_dims);
1342 
1343   return rewriter
1344       .create<tosa::ReshapeOp>(op->getLoc(), output_type, input_value,
1345                                shape_attr)
1346       .getResult();
1347 }
1348 
1349 // Lowers ELU to a sequence of TOSA ops.
1350 llvm::Optional<Value> convertEluOp(PatternRewriter& rewriter, Operation* op,
1351                                    Value result_value, Value features_value) {
1352   // Lowers Elu using the following formula:
1353   // elu(x) = x < 0 ? (exp(x) - 1) : x
1354   // one = const({1});
1355   // zero = const({0});
1356   // one_bcast = reshape(one, [1, ..., rank(x) - 1])
1357   // zero_bcast = reshape(zero, [1, ..., rank(x) - 1])
1358   // a1 = exp(x);
1359   // a2 = sub(a1, one_bcast)
1360   // a3 = ge(x, zero_bcast)
1361   // a4 = select(a3, x, a2)
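  // Worked example (illustrative only): for x = -1.0 the select takes the
  // exp branch, exp(-1.0) - 1 ≈ -0.632; for x = 2.0 it passes 2.0 through.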
1362   RankedTensorType output_type =
1363       result_value.getType().dyn_cast<RankedTensorType>();
1364   // Not a ranked tensor output
1365   if (!output_type) {
1366     op->emitOpError("Elu: output type not ranked tensor");
1367     return llvm::None;
1368   }
1369 
1370   int32_t input_rank = output_type.getShape().size();
1371   SmallVector<int64_t, 4> bcast_shape(input_rank, 1);
1372 
1373   // Can't directly create a size=1, rank=rank(input) tensor because
1374   // it will be optimized out.  Instead, create a rank-0 tensor and reshape later.
1375   Value one_const_op = getTosaConstTensorSingleF32(rewriter, op, 1.0);
1376 
1377   Value zero_const_op = getTosaConstTensorSingleF32(rewriter, op, 0.0);
1378 
1379   auto a1_exp_in_op =
1380       rewriter.create<tosa::ExpOp>(op->getLoc(), output_type, features_value);
1381 
1382   auto a2_sub_a1_one_op = rewriter.create<tosa::SubOp>(
1383       op->getLoc(), output_type, a1_exp_in_op.getResult(), one_const_op);
1384 
1385   auto a3_ge_in_zero_op = rewriter.create<tosa::GreaterEqualOp>(
1386       op->getLoc(),
1387       RankedTensorType::get(output_type.getShape(), rewriter.getIntegerType(1)),
1388       features_value, zero_const_op);
1389 
1390   return rewriter
1391       .create<tosa::SelectOp>(op->getLoc(), output_type,
1392                               a3_ge_in_zero_op.getResult(), features_value,
1393                               a2_sub_a1_one_op.getResult())
1394       .getResult();
1395 }
1396 
1397 // Lowers Softmax to a sequence of TOSA ops.
1398 llvm::Optional<Value> convertSoftmaxOp(PatternRewriter& rewriter, Operation* op,
1399                                        Value result_value, Value logits_value) {
1400   // softmax = exp(logits) / reduce_sum(exp(logits), -1)
1401   //
1402   // or, equivalently, multiplying both numerator and denominator by
1403   // exp(-max(logits)), we get:
1404   //
1405   // softmax = exp(logits - max(logits)) / reduce_sum(exp(logits -
1406   // max(logits)), -1)
1407   //
1408   // We'll use the first version for the direct floating-point lowering, and the
1409   // second version for the quantized lowering, since the second form makes the
1410   // input to exp() non-positive, so the LUT output always stays within [0.0, 1.0].
1411   RankedTensorType output_type =
1412       result_value.getType().dyn_cast<RankedTensorType>();
1413   RankedTensorType input_type =
1414       logits_value.getType().dyn_cast<RankedTensorType>();
1415 
1416   // Not a ranked tensor input/output
1417   if (!output_type || !input_type) {
1418     op->emitOpError("Softmax: input and result not ranked tensors");
1419     return llvm::None;
1420   }
1421 
1422   // reduce_sum on last dimension
1423   int32_t input_rank = input_type.getShape().size();
1424   ArrayRef<int64_t> logits_shape = output_type.getShape();
1425 
1426   if (input_type.getElementType().isa<mlir::quant::QuantizedType>() &&
1427       output_type.getElementType().isa<mlir::quant::QuantizedType>()) {
1428     SmallVector<int64_t, 4> rsum_shape_v(input_type.getShape().begin(),
1429                                          input_type.getShape().end() - 1);
1430     rsum_shape_v.push_back(1);
1431     ArrayRef<int64_t> rsum_shape(rsum_shape_v);
1432     // The if condition above checks these are quantized; the casts assume uniform quantization.
1433     mlir::quant::UniformQuantizedType in_quant_type =
1434         input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
1435     mlir::quant::UniformQuantizedType out_quant_type =
1436         output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
1437 
1438     auto int16_element_qtype = mlir::quant::UniformQuantizedType::get(
1439         true, rewriter.getIntegerType(16), rewriter.getF32Type(), 1.0f, 0,
1440         -32768, 32767);
1441     RankedTensorType int16_logits_type =
1442         RankedTensorType::get(logits_shape, int16_element_qtype);
1443     RankedTensorType int32_logits_type =
1444         RankedTensorType::get(logits_shape, rewriter.getIntegerType(32));
1445     RankedTensorType int16_rsum_type =
1446         RankedTensorType::get(rsum_shape, int16_element_qtype);
1447     RankedTensorType int32_rsum_type =
1448         RankedTensorType::get(rsum_shape, rewriter.getIntegerType(32));
1449 
1450     // Step 1. get x - max(x)
1451     Value op1_rescale_in =
1452         buildRescale(rewriter, op, int32_logits_type, logits_value, 1.0f,
1453                      in_quant_type.getZeroPoint(), 0);
1454 
1455     auto op2_reducemax_op1 = rewriter.create<tosa::ReduceMaxOp>(
1456         op->getLoc(), int32_rsum_type, op1_rescale_in,
1457         rewriter.getI64IntegerAttr(input_rank - 1));
1458 
1459     auto op3_sub_op1_op2 = rewriter.create<tosa::SubOp>(
1460         op->getLoc(), int32_logits_type, op1_rescale_in,
1461         op2_reducemax_op1.getResult());
1462 
1463     // Table input range from -16.0 to 16.0, input below -16.0 treated as
1464     // exp(-16.0), which is 0 in 0.16
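    // Notation: "a.b" below denotes a fixed-point format with a integer bits and
    // b fractional bits, e.g. 0.16 is a pure 16-bit fraction and 12.20 has 12
    // integer and 20 fractional bits.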
1465     const double exp_sample_grain = 1.0 / 16.0;
1466     auto exp_func = [exp_sample_grain](int32_t x) -> int32_t {
1467       double v = static_cast<double>(x) * exp_sample_grain;
1468       v = v < 0.0 ? std::exp(v) : 1.0;
1469       return std::lround(32768.0 * v);
1470     };
1471 
1472     Value exp_table_const = getTosa1DConstTensorTable(rewriter, op, exp_func);
1473 
1474     // Step 2. rescale input
1475     Value op4_rescale_op3 = buildRescale(
1476         rewriter, op, int16_logits_type, op3_sub_op1_op2.getResult(),
1477         in_quant_type.getScale() * 128.0 / exp_sample_grain, 0, 0);
1478 
1479     // Step 3. get exp() result
1480     // Since we already make sure input x < 0 in step 1,
1481     // we can utilize full output 0.16 range.
1482 
1483     // Output is 0.23
1484     auto op5_table_op4 = rewriter.create<tosa::TableOp>(
1485         op->getLoc(), int32_logits_type, op4_rescale_op3, exp_table_const);
1486 
1487     // Right shift 3 bits. output 0.20
1488     auto op6_rshift_op5 = rewriter.create<tosa::ArithmeticRightShiftOp>(
1489         op->getLoc(), int32_logits_type, op5_table_op4.getResult(),
1490         getTosaConstTensorSingleI32(rewriter, op, 3), true);
1491 
1492     // Step 4. get sum(exp()). output 12.20
1493     auto op7_reducesum_op6 = rewriter.create<tosa::ReduceSumOp>(
1494         op->getLoc(), int32_rsum_type, op6_rshift_op5.getResult(),
1495         rewriter.getI64IntegerAttr(input_rank - 1));
1496 
1497     // Step 5. calculate reciprocal(sum(exp()))
1498     auto op8_clz_op7 = rewriter.create<tosa::ClzOp>(
1499         op->getLoc(), int32_rsum_type, op7_reducesum_op6.getResult());
1500 
1501     // rshift amount of reciprocal(sum(exp()))
1502     // 12 from the integer bits of 12.20 accumulator
1503     // 30 from output of multiply 0.15 x 0.15
1504     // -8 to keep additional 8 bits before output rescaling
1505     auto op9_sub_op8 = rewriter.create<tosa::SubOp>(
1506         op->getLoc(), int32_rsum_type,
1507         getTosaConstTensorSingleI32(rewriter, op, 12 + 30 - 8),
1508         op8_clz_op7.getResult());
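    // Illustrative (assumed) walk-through: a sum of roughly 1.0 in 12.20 format
    // (about 1 << 20) has 11 leading zeros, so the deferred right shift is
    // 12 + 30 - 8 - 11 = 23.  The left shift below by clz normalizes the sum to
    // [1.0, 2.0) in 1.31 format; subtracting 1 << 31 then leaves x in [0, 1)
    // for the 1 / (1 + x) table.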
1509 
1510     // Left shift to get  1.31 format
1511     auto op10_lshift_op7_op8 = rewriter.create<tosa::LogicalLeftShiftOp>(
1512         op->getLoc(), int32_rsum_type, op7_reducesum_op6.getResult(),
1513         op8_clz_op7.getResult());
1514 
1515     // Subtract (1 << 31) to make 0 <= x <= 1
1516     auto op11_sub_op10 = rewriter.create<tosa::SubOp>(
1517         op->getLoc(), int32_rsum_type, op10_lshift_op7_op8.getResult(),
1518         getTosaConstTensorSingleI32(rewriter, op, (1u << 31)));
1519 
1520     // Right shift 16 bits to get 16 bits index
1521     auto op12_rshift_op11 = rewriter.create<tosa::ArithmeticRightShiftOp>(
1522         op->getLoc(), int32_rsum_type, op11_sub_op10.getResult(),
1523         getTosaConstTensorSingleI32(rewriter, op, 16), true);
1524 
1525     // cast to 16 bits to index TABLE op
1526     auto op13_cast_op12 = rewriter.create<tosa::CastOp>(
1527         op->getLoc(), int16_rsum_type, op12_rshift_op11.getResult());
1528 
1529     // Generate table for 1 / (1 + x), for 0 <= x <= 1
1530     const double one_over_one_plus_x_sample_grain = 1.0 / 256.0;
1531     auto one_over_one_plus_x_func =
1532         [one_over_one_plus_x_sample_grain](int32_t x) -> int32_t {
1533       double v = static_cast<double>(x) * one_over_one_plus_x_sample_grain;
1534       v = v < 0 ? 1.0 : 1.0 / (1.0 + v);
1535       return std::lround(32768.0 * v);
1536     };
1537 
1538     Value one_over_one_plus_x_table_const =
1539         getTosa1DConstTensorTable(rewriter, op, one_over_one_plus_x_func);
1540 
1541     auto op14_table_op13 = rewriter.create<tosa::TableOp>(
1542         op->getLoc(), int32_rsum_type, op13_cast_op12.getResult(),
1543         one_over_one_plus_x_table_const);
1544 
1545     // Rescale sum(exp(x)) from 0.23 back to 0.16
1546     Value op15_rescale_op14 = buildRescale(rewriter, op, int32_rsum_type,
1547                                            op14_table_op13, 1.0 / 128.0, 0, 0);
1548 
1549     // Rescale exp(x) from 0.23 back to 0.16
1550     Value op16_rescale_op5 =
1551         buildRescale(rewriter, op, int32_logits_type, op5_table_op4.getResult(),
1552                      1.0 / 128.0, 0, 0);
1553 
1554     // Step 6. apply the scales we just computed, explicitly in i32 space
1555     // lhs: 0.16, rhs: 0.16, output: 0.32
1556     auto op17_mul_op15_op16 =
1557         rewriter.create<tosa::MulOp>(op->getLoc(), int32_logits_type,
1558                                      op15_rescale_op14, op16_rescale_op5, 0);
1559 
1560     // Apply right shift from clz
1561     auto op18_rshift_op17_op9 = rewriter.create<tosa::ArithmeticRightShiftOp>(
1562         op->getLoc(), int32_logits_type, op17_mul_op15_op16.getResult(),
1563         op9_sub_op8.getResult(), true);
1564 
1565     // Step 7. output scaling, extra 1.0 / 256.0 since we keep extra 8 bits
1566     // in op9_sub_op8
1567     return buildRescale(rewriter, op, output_type,
1568                         op18_rshift_op17_op9.getResult(),
1569                         1.0 / (out_quant_type.getScale() * 256.0), 0,
1570                         out_quant_type.getZeroPoint());
1571 
1572   } else {
1573     SmallVector<int64_t, 4> rsum_shape_v(input_type.getShape().begin(),
1574                                          input_type.getShape().end());
1575     rsum_shape_v[input_rank - 1] = 1;
1576     ArrayRef<int64_t> rsum_shape(rsum_shape_v);
1577 
1578     // Floating-point lowering is more direct:
1579     //
1580     // op1 = exp(logits)
1581     // op2 = reduce_sum(op1, -1)
1582     // op3 = reciprocal(op2)
1583     // op4 = mul(op1, op3)
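    // e.g. (illustrative) logits of shape [N, C] produce an rsum of shape
    // [N, 1], which broadcasts against op1 in the final multiply.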
1584     auto op1_exp_in =
1585         rewriter.create<tosa::ExpOp>(op->getLoc(), output_type, logits_value);
1586     RankedTensorType rsum_type =
1587         RankedTensorType::get(rsum_shape, output_type.getElementType());
1588 
1589     // Keep dims so we don't need to reshape later
1590     auto op2_reducesum_op1 = rewriter.create<tosa::ReduceSumOp>(
1591         op->getLoc(), rsum_type, op1_exp_in.getResult(),
1592         rewriter.getI64IntegerAttr(input_rank - 1));
1593     auto op3_reciprocal_op2 = rewriter.create<tosa::ReciprocalOp>(
1594         op->getLoc(), rsum_type, op2_reducesum_op1.getResult());
1595 
1596     return rewriter
1597         .create<tosa::MulOp>(op->getLoc(), output_type, op1_exp_in.getResult(),
1598                              op3_reciprocal_op2.getResult(), 0)
1599         .getResult();
1600   }
1601 }
1602 
1603 // Lowers LogSoftmax to a sequence of TOSA ops.
1604 llvm::Optional<Value> convertLogSoftmaxOp(PatternRewriter& rewriter,
1605                                           Operation* op, Value result_value,
1606                                           Value logits_value) {
1607   // log_softmax = log(exp(logits) / reduce_sum(exp(logits), -1))
1608   // op1 = exp(logits)
1609   // op2 = reduce_sum(op1, -1)
1610   // op3 = reciprocal(op2)
1611   // op4 = mul(op1, op3)
1612   // op5 = log(op4)
1613 
1614   RankedTensorType output_type =
1615       result_value.getType().dyn_cast<RankedTensorType>();
1616   // Not a ranked tensor output
1617   if (!output_type) {
1618     op->emitOpError("LogSoftmax: output type not ranked tensor.");
1619     return llvm::None;
1620   }
1621 
1622   RankedTensorType input_type =
1623       op->getOperand(0).getType().dyn_cast<RankedTensorType>();
1624   if (!input_type) {
1625     op->emitOpError("LogSoftmax: input type not ranked tensor.");
1626     return llvm::None;
1627   }
1628 
1629   mlir::quant::UniformQuantizedType in_quant_type =
1630       input_type.getElementType()
1631           .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
1632   mlir::quant::UniformQuantizedType out_quant_type =
1633       output_type.getElementType()
1634           .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
1635   if (in_quant_type || out_quant_type) {
1636     op->emitOpError("Quantized log_softmax lowering not implemented yet");
1637     return llvm::None;
1638   }
1639 
1640   auto op1_exp_in =
1641       rewriter.create<tosa::ExpOp>(op->getLoc(), output_type, logits_value);
1642 
1643   // reduce_sum on last dimension
1644   int32_t input_rank = input_type.getShape().size();
1645   SmallVector<int64_t, 4> rsum_shape(output_type.getShape().begin(),
1646                                      output_type.getShape().end());
1647   rsum_shape[input_rank - 1] = 1;
1648   RankedTensorType rsum_type = RankedTensorType::get(
1649       ArrayRef<int64_t>(rsum_shape), output_type.getElementType());
1650   // Keep dims so we don't need to reshape later
1651   auto op2_reducesum_op1 = rewriter.create<tosa::ReduceSumOp>(
1652       op->getLoc(), rsum_type, op1_exp_in.getResult(),
1653       rewriter.getI64IntegerAttr(input_rank - 1));
1654   auto op3_reciprocal_op2 = rewriter.create<tosa::ReciprocalOp>(
1655       op->getLoc(), rsum_type, op2_reducesum_op1.getResult());
1656 
1657   auto op4_mul_op1_op3 = rewriter.create<tosa::MulOp>(
1658       op->getLoc(), output_type, op1_exp_in.getResult(),
1659       op3_reciprocal_op2.getResult(), 0);
1660 
1661   return rewriter
1662       .create<tosa::LogOp>(op->getLoc(), output_type,
1663                            op4_mul_op1_op3.getResult())
1664       .getResult();
1665 }
1666 
1667 // Lowers SpaceToDepth to a sequence of TOSA ops.  Supports NHWC.
1668 llvm::Optional<Value> convertSpaceToDepthOp(PatternRewriter& rewriter,
1669                                             Operation* op, Value result_value,
1670                                             Value input_value,
1671                                             IntegerAttr block_size_attr,
1672                                             StringAttr data_format) {
1673   // NHWC lowering version:
1674   // a2 = tf.reshape(a, [orig_shape[0], orig_shape[1]//b, b, orig_shape[2]//b,
1675   // b, orig_shape[3]])
1676   // a3 = tf.transpose(a2, [0, 1, 3, 2, 4, 5])
1677   // a4 = tf.reshape(a3, [orig_shape[0], orig_shape[1]//b, orig_shape[2]//b,
1678   // orig_shape[3]*b*b])
1679   // return a4
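  // Illustrative example (assumed shapes): an input of [1, 4, 4, 3] with
  // block_size 2 reshapes to [1, 2, 2, 2, 2, 3], transposes with perm
  // [0, 1, 3, 2, 4, 5], and reshapes to [1, 2, 2, 12].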
1680   RankedTensorType output_type =
1681       result_value.getType().dyn_cast<RankedTensorType>();
1682 
1683   // Not a ranked tensor output.
1684   if (!output_type) {
1685     op->emitOpError("SpaceToDepth: output type not ranked tensor.");
1686     return llvm::None;
1687   }
1688 
1689   RankedTensorType input_type =
1690       input_value.getType().dyn_cast<RankedTensorType>();
1691   if (!input_type) {
1692     op->emitOpError("SpaceToDepth: input type not ranked tensor.");
1693     return llvm::None;
1694   }
1695 
1696   if (input_type.getRank() != 4) {
1697     op->emitOpError("SpaceToDepth: input rank not 4.");
1698     return llvm::None;
1699   }
1700 
1701   auto input_shape = input_type.getShape();
1702 
1703   if (!block_size_attr) {  // This is a required parameter
1704     op->emitOpError("SpaceToDepth: block size attribute not set.");
1705     return llvm::None;
1706   }
1707 
1708   SmallVector<int64_t, 2> block_size;
1709   block_size.assign(2, block_size_attr.getInt());
1710 
1711   if (!data_format) data_format = rewriter.getStringAttr("NHWC");
1712 
1713   if (data_format.getValue().str() != "NHWC") {
1714     op->emitOpError("SpaceToDepth: data format not NHWC.");
1715     return llvm::None;
1716   }
1717 
1718   assert(block_size[0] * block_size[1] != 0);
1719 
1720   SmallVector<int64_t, 4> a_reshape_dims;
1721   a_reshape_dims.push_back(input_shape[0]);
1722   a_reshape_dims.push_back(input_shape[1] / block_size[0]);
1723   a_reshape_dims.push_back(block_size[0]);
1724   a_reshape_dims.push_back(input_shape[2] / block_size[1]);
1725   a_reshape_dims.push_back(block_size[1]);
1726   a_reshape_dims.push_back(input_shape[3]);
1727 
1728   RankedTensorType a_reshape_output_type = RankedTensorType::get(
1729       ArrayRef<int64_t>(a_reshape_dims), output_type.getElementType());
1730   auto a2_reshape_a_op = rewriter.create<tosa::ReshapeOp>(
1731       op->getLoc(), a_reshape_output_type, input_value,
1732       rewriter.getI64ArrayAttr(a_reshape_dims));
1733 
1734   Value a3_transpose_perm = get1DConstTensor<tosa::ConstOp, int32_t>(
1735       rewriter, op, {0, 1, 3, 2, 4, 5});
1736 
1737   auto a3_transpose_a2_op = rewriter.create<tosa::TransposeOp>(
1738       op->getLoc(), a_reshape_output_type, a2_reshape_a_op.getResult(),
1739       a3_transpose_perm);
1740 
1741   SmallVector<int64_t, 4> a3_reshape_dims;
1742   a3_reshape_dims.push_back(input_shape[0]);
1743   a3_reshape_dims.push_back(input_shape[1] / block_size[0]);
1744   a3_reshape_dims.push_back(input_shape[2] / block_size[1]);
1745   a3_reshape_dims.push_back(input_shape[3] * block_size[0] * block_size[1]);
1746 
1747   RankedTensorType a3_reshape_output_type = RankedTensorType::get(
1748       ArrayRef<int64_t>(a3_reshape_dims), output_type.getElementType());
1749   return rewriter
1750       .create<tosa::ReshapeOp>(op->getLoc(), a3_reshape_output_type,
1751                                a3_transpose_a2_op.getResult(),
1752                                rewriter.getI64ArrayAttr(a3_reshape_dims))
1753       .getResult();
1754 }
1755 
1756 // Lowers DepthToSpace to a sequence of TOSA ops.  Supports NHWC.
1757 llvm::Optional<Value> convertDepthToSpaceOp(PatternRewriter& rewriter,
1758                                             Operation* op, Value result_value,
1759                                             Value input_value,
1760                                             IntegerAttr block_size_attr,
1761                                             StringAttr data_format) {
1762   // NHWC version
1763   // a2 = tf.reshape(a, [orig_shape[0], orig_shape[1], orig_shape[2], b, b,
1764   // orig_shape[3] // (b*b)])
1765   // a3 = tf.transpose(a2, [0, 1, 3, 2, 4, 5])
1766   // a4 = tf.reshape(a3, [orig_shape[0], orig_shape[1] * b, orig_shape[2] * b,
1767   // orig_shape[3] // (b*b)])
1768   // return a4
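  // Illustrative example (assumed shapes): an input of [1, 2, 2, 12] with
  // block_size 2 reshapes to [1, 2, 2, 2, 2, 3], transposes with perm
  // [0, 1, 3, 2, 4, 5], and reshapes to [1, 4, 4, 3], inverting the
  // SpaceToDepth example above.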
1769 
1770   RankedTensorType output_type =
1771       result_value.getType().dyn_cast<RankedTensorType>();
1772 
1773   // Not a ranked tensor output
1774   if (!output_type) {
1775     op->emitOpError("DepthToSpace: output type not ranked tensor.");
1776     return llvm::None;
1777   }
1778 
1779   RankedTensorType input_type =
1780       input_value.getType().dyn_cast<RankedTensorType>();
1781   if (!input_type) {
1782     op->emitOpError("DepthToSpace: input type not ranked tensor.");
1783     return llvm::None;
1784   }
1785 
1786   if (input_type.getRank() != 4) return llvm::None;
1787   auto input_shape = input_type.getShape();
1788 
1789   if (!block_size_attr) {  // This is a required parameter
1790     op->emitOpError("DepthToSpace: block size attribute not set.");
1791     return llvm::None;
1792   }
1793 
1794   SmallVector<int64_t, 2> block_size;
1795   block_size.assign(2, block_size_attr.getInt());
1796 
1797   if (!data_format) data_format = rewriter.getStringAttr("NHWC");
1798   if (data_format.getValue().str() != "NHWC") {
1799     op->emitOpError("DepthToSpace: data format not NHWC.");
1800     return llvm::None;
1801   }
1802 
1803   assert(block_size[0] * block_size[1] != 0);
1804 
1805   SmallVector<int64_t, 4> a_reshape_dims;
1806   a_reshape_dims.push_back(input_shape[0]);
1807   a_reshape_dims.push_back(input_shape[1]);
1808   a_reshape_dims.push_back(input_shape[2]);
1809   a_reshape_dims.push_back(block_size[0]);
1810   a_reshape_dims.push_back(block_size[1]);
1811   a_reshape_dims.push_back(input_shape[3] / (block_size[0] * block_size[1]));
1812 
1813   RankedTensorType a_reshape_output_type = RankedTensorType::get(
1814       ArrayRef<int64_t>(a_reshape_dims), output_type.getElementType());
1815   auto a2_reshape_a_op = rewriter.create<tosa::ReshapeOp>(
1816       op->getLoc(), a_reshape_output_type, input_value,
1817       rewriter.getI64ArrayAttr(a_reshape_dims));
1818 
1819   Value a3_transpose_perm = get1DConstTensor<tosa::ConstOp, int32_t>(
1820       rewriter, op, {0, 1, 3, 2, 4, 5});
1821 
1822   auto a3_transpose_a2_op = rewriter.create<tosa::TransposeOp>(
1823       op->getLoc(), a_reshape_output_type, a2_reshape_a_op.getResult(),
1824       a3_transpose_perm);
1825 
1826   SmallVector<int64_t, 4> a3_reshape_dims;
1827   a3_reshape_dims.push_back(input_shape[0]);
1828   a3_reshape_dims.push_back(input_shape[1] * block_size[0]);
1829   a3_reshape_dims.push_back(input_shape[2] * block_size[1]);
1830   a3_reshape_dims.push_back(input_shape[3] / (block_size[0] * block_size[1]));
1831 
1832   RankedTensorType a3_reshape_output_type = RankedTensorType::get(
1833       ArrayRef<int64_t>(a3_reshape_dims), output_type.getElementType());
1834   return rewriter
1835       .create<tosa::ReshapeOp>(op->getLoc(), a3_reshape_output_type,
1836                                a3_transpose_a2_op.getResult(),
1837                                rewriter.getI64ArrayAttr(a3_reshape_dims))
1838       .getResult();
1839 }
1840 
1841 // Lowers Split to a sequence of TOSA ops.
1842 llvm::Optional<ValueRange> convertSplitOp(PatternRewriter& rewriter,
1843                                           Operation* op, Value result_value,
1844                                           Value input_value, int32_t num_split,
1845                                           int32_t axis) {
1846   // This lowering creates num_split slice ops and ties them together
1847   // with IdentityN to get from an array of Operations to a single Operation
1848   // with a list of result tensors.
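  // Illustrative example (hypothetical): an input of shape [6, 4] with
  // num_split = 3 and axis = 0 produces three slices of shape [2, 4]
  // beginning at rows 0, 2 and 4.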
1849   RankedTensorType result_type =
1850       result_value.getType().dyn_cast<RankedTensorType>();
1851   // Not a ranked tensor output
1852   if (!result_type) {
1853     op->emitOpError("Split: output type not ranked tensor.");
1854     return llvm::None;
1855   }
1856 
1857   RankedTensorType input_type =
1858       input_value.getType().dyn_cast<RankedTensorType>();
1859   if (!input_type) {
1860     op->emitOpError("Split: input type not ranked tensor.");
1861     return llvm::None;
1862   }
1863 
1864   auto input_shape = input_type.getShape();
1865 
1866   SmallVector<Value, 4> results_vec;
1867 
1868   assert(axis >= 0 && axis < input_shape.size());
1869   assert((input_shape[axis] % num_split) == 0);
1870   assert(num_split > 0);
1871 
1872   int64_t slice_size = input_shape[axis] / num_split;
1873 
1874   SmallVector<Type, 4>
1875       outs_type_vec;  // A list of the output types for each slice op
1876 
1877   for (int i = 0; i < num_split; i++) {
1878     // Each slice has a different beginning point.
1879     // The slice size is actually the same for each op.
1880     SmallVector<int64_t, 4> begin_vals, size_vals;
1881 
1882     for (int j = 0; j < input_shape.size(); j++) {
1883       if (j == axis) {
1884         begin_vals.push_back(slice_size * i);
1885         size_vals.push_back(slice_size);
1886       } else {
1887         begin_vals.push_back(0);
1888         size_vals.push_back(input_shape[j]);
1889       }
1890     }
1891 
1892     ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
1893     ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);
1894 
1895     outs_type_vec.push_back(RankedTensorType::get(
1896         ArrayRef<int64_t>(size_vals), result_type.getElementType()));
1897 
1898     auto slice_op = rewriter.create<tosa::SliceOp>(
1899         op->getLoc(),
1900         RankedTensorType::get(ArrayRef<int64_t>(size_vals),
1901                               result_type.getElementType()),
1902         input_value, begin, size);
1903 
1904     results_vec.push_back(slice_op.getResult());
1905   }
1906 
1907   // Combine the sequence of tosa.slice() ops into a list
1908   // using the IdentityN operator
1909   return rewriter
1910       .create<tosa::IdentityNOp>(op->getLoc(), ArrayRef<Type>(outs_type_vec),
1911                                  results_vec)
1912       .getResults();
1913 }
1914 
1915 // Lowers SplitV to a sequence of TOSA ops.
1916 llvm::Optional<ValueRange> convertSplitVOp(PatternRewriter& rewriter,
1917                                            Operation* op, Value result_value,
1918                                            Value input_value,
1919                                            SmallVector<int32_t, 4>& size_split,
1920                                            int32_t axis) {
1921   // This lowering creates num_split slice ops and ties them together
1922   // with IdentityN to get from an array of Operations to a single Operation
1923   // with a list of result tensors.
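  // Illustrative example (hypothetical): an input of shape [7, 4] with
  // size_split = {2, 5} and axis = 0 produces slices of shape [2, 4] and
  // [5, 4] beginning at rows 0 and 2.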
1924   RankedTensorType result_type =
1925       result_value.getType().dyn_cast<RankedTensorType>();
1926   // Not a ranked tensor output
1927   if (!result_type) {
1928     op->emitOpError("SplitV: output type not ranked tensor.");
1929     return llvm::None;
1930   }
1931 
1932   RankedTensorType input_type =
1933       input_value.getType().dyn_cast<RankedTensorType>();
1934   if (!input_type) {
1935     op->emitOpError("SplitV: input type not ranked tensor.");
1936     return llvm::None;
1937   }
1938 
1939   auto input_shape = input_type.getShape();
1940 
1941   SmallVector<Value, 4> results_vec;
1942 
1943   assert(axis >= 0 && axis < input_shape.size());
1944   int32_t size_split_sum = 0;
1945   for (int i = 0; i < size_split.size(); i++) {
1946     size_split_sum += size_split[i];
1947   }
1948 
1949   // The split sizes must sum up to the size of the axis being split
1950   assert(size_split_sum == input_shape[axis]);
1951 
1952   // Create num_split slice ops:
1953   SmallVector<Type, 4>
1954       outs_type_vec;  // A list of the output types for each slice op
1955 
1956   int32_t curr_split_start = 0;
1957   for (int i = 0; i < size_split.size(); i++) {
1958     // Each slice has a different beginning point.
1959     // The slice size is different for each op.
1960     SmallVector<int64_t, 4> begin_vals, size_vals;
1961 
1962     for (int j = 0; j < input_shape.size(); j++) {
1963       if (j == axis) {
1964         begin_vals.push_back(curr_split_start);
1965         size_vals.push_back(size_split[i]);
1966       } else {
1967         begin_vals.push_back(0);
1968         size_vals.push_back(input_shape[j]);
1969       }
1970     }
1971 
1972     ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
1973     ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);
1974 
1975     outs_type_vec.push_back(RankedTensorType::get(
1976         ArrayRef<int64_t>(size_vals), result_type.getElementType()));
1977 
1978     auto slice_op = rewriter.create<tosa::SliceOp>(
1979         op->getLoc(),
1980         RankedTensorType::get(ArrayRef<int64_t>(size_vals),
1981                               result_type.getElementType()),
1982         input_value, begin, size);
1983 
1984     results_vec.push_back(slice_op.getResult());
1985 
1986     // Next start position
1987     curr_split_start += size_split[i];
1988   }
1989 
1990   // Combine the sequence of tosa.slice() ops into a list
1991   // using the IdentityN operator
1992   return rewriter
1993       .create<tosa::IdentityNOp>(op->getLoc(), ArrayRef<Type>(outs_type_vec),
1994                                  results_vec)
1995       .getResults();
1996 }
1997 
1998 // Lowers StridedSlice to a sequence of TOSA ops.
1999 llvm::Optional<Value> convertStridedSliceOp(
2000     PatternRewriter& rewriter, Operation* op, Value result_value,
2001     Value input_value, Value begin_value, Value end_value, Value strides_value,
2002     int32_t begin_mask, int32_t end_mask, int32_t ellipsis_mask,
2003     int32_t new_axis_mask, int32_t shrink_axis_mask) {
2004   // The mask arguments are bitmasks where bit [i] applies to
2005   // dimension [i] of the input tensor.
2006   //
2007   // The rough algorithm for lowering strided slice is as follows:
2008   //
2009   // 0. Process begin/end masks, since they are basically syntactic sugar
2010   // on top of the begin_value/end_value arrays
2011   //
2012   // 1. Slice1: Ignoring stride, slice the interesting range from the input
2013   // tensor
2014   //
2015   // 2. Reshape2: Reshape the tensor from (1) such that each dimension with
2016   // stride is split into two dimensions of size_i/stride_i, stride_i.   A naive
2017   // implementation doubles the input tensor rank, but only dimensions being
2018   // strided actually need to be doubled.
2019   //
2020   // 3. Slice3: Slice the tensor from (2) such that we select index [0] from
2021   // each of the stride_i dimensions in (2)
2022   //
2023   // 4. Reshape4: Reshape the tensor to eliminate the stride_i dimensions, add
2024   // any dimensions in new_axis_mask and remove any dimensions in the
2025   // shrink_axis_mask
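  //
  // Illustrative example (assumed values): input shape [8], begin = [1],
  // end = [7], strides = [2]. Slice1 takes 6 elements starting at index 1,
  // Reshape2 gives [3, 2], Slice3 keeps index 0 of each stride dimension to
  // give [3, 1], and Reshape4 yields [3]: the elements at indices 1, 3, 5.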
2026 
2027   // Limitations:
2028   // This implementation only supports ellipsis_mask=0 for now
2029   // This implementation does not support reverse stride yet.  Will need
2030   // to insert tosa.Reverse operators for this.
2031   assert(ellipsis_mask == 0);
2032 
2033   RankedTensorType input_type =
2034       input_value.getType().dyn_cast<RankedTensorType>();
2035   RankedTensorType result_type =
2036       result_value.getType().dyn_cast<RankedTensorType>();
2037 
2038   if (!result_type) {
2039     op->emitOpError("StridedSlice: output type not ranked tensor.");
2040     return llvm::None;
2041   }
2042 
2043   if (!input_type) {
2044     op->emitOpError("StridedSlice: input type not ranked tensor.");
2045     return llvm::None;
2046   }
2047 
2048   int32_t input_rank = input_type.getRank();
2049   auto input_shape = input_type.getShape();
2050 
2051   // Extract the begin/end/stride tensors
2052   SmallVector<int32_t, 4> begin, end, strides;
2053 
2054   if (getVectorFromValue32(begin_value, begin) != input_rank) {
2055     op->emitOpError("StridedSlice: begin doesn't match input_rank.");
2056     return llvm::None;
2057   }
2058   if (getVectorFromValue32(end_value, end) != input_rank) {
2059     op->emitOpError("StridedSlice: end doesn't match input_rank.");
2060     return llvm::None;
2061   }
2062   if (getVectorFromValue32(strides_value, strides) != input_rank) {
2063     op->emitOpError("StridedSlice: strides doesn't match input_rank.");
2064     return llvm::None;
2065   }
2066 
2067   SmallVector<int64_t, 2> a1_begin(input_rank), a1_size(input_rank);
2068   SmallVector<int64_t, 2> a2_shape(input_rank * 2);
2069   SmallVector<int64_t, 2> a3_begin(input_rank * 2), a3_size(input_rank * 2);
2070   SmallVector<int64_t, 2> a4_shape;
2071 
2072   // Step 0: Process the begin/end masks and build the begin/sizes for the
2073   // first slice
2074   int residual = 1;
2075   (void)residual;
2076   for (int i = 0; i < input_rank; i++) {
2077     if (begin_mask & (1 << i)) begin[i] = 0;
2078 
2079     if (end_mask & (1 << i)) end[i] = input_shape[i];
2080 
2081     // Wrap around the index if begin or end is negative
2082     if (begin[i] < 0) begin[i] += input_shape[i];
2083 
2084     if (end[i] < 0) end[i] += input_shape[i];
2085 
2086     // TODO: support reverse stride
2087     a1_begin[i] = begin[i];
2088     a1_size[i] = end[i] - begin[i];
2089 
2090     a2_shape[i * 2 + 0] = a1_size[i] / strides[i];
2091     a2_shape[i * 2 + 1] = strides[i];
2092 
2093     a3_begin[i * 2 + 0] = 0;
2094     a3_begin[i * 2 + 1] = 0;
2095 
2096     if (shrink_axis_mask & (1 << i)) {
2097       a3_size[i * 2 + 0] = 1;
2098     } else {
2099       a3_size[i * 2 + 0] = a1_size[i] / strides[i];
2100     }
2101     a3_size[i * 2 + 1] = 1;
2102 
2103     if (!(shrink_axis_mask & (1 << i))) {
2104       if (new_axis_mask & (1 << i)) a4_shape.push_back(1);
2105       a4_shape.push_back((a1_size[i] / strides[i]));
2106     }
2107   }
2108 
2109   // Make sure we didn't lose any dimensions from the shrink_axis_mask
2110   assert(residual == 1);
2111 
2112   // Step 1: Slice the input array
2113   auto a1_slice_op = rewriter.create<tosa::SliceOp>(
2114       op->getLoc(),
2115       RankedTensorType::get(ArrayRef<int64_t>(a1_size),
2116                             input_type.getElementType()),
2117       input_value, rewriter.getI64ArrayAttr(a1_begin),
2118       rewriter.getI64ArrayAttr(a1_size));
2119 
2120   // Step 2: reshape the sliced array
2121   auto a2_reshape_op = rewriter.create<tosa::ReshapeOp>(
2122       op->getLoc(),
2123       RankedTensorType::get(ArrayRef<int64_t>(a2_shape),
2124                             input_type.getElementType()),
2125       a1_slice_op.getResult(), rewriter.getI64ArrayAttr(a2_shape));
2126 
2127   // Step 3: take a slice along the strides
2128   auto a3_slice_op = rewriter.create<tosa::SliceOp>(
2129       op->getLoc(),
2130       RankedTensorType::get(ArrayRef<int64_t>(a3_size),
2131                             input_type.getElementType()),
2132       a2_reshape_op.getResult(), rewriter.getI64ArrayAttr(a3_begin),
2133       rewriter.getI64ArrayAttr(a3_size));
2134 
2135   // Step 4: reshape the now-strided tensor
2136   return rewriter
2137       .create<tosa::ReshapeOp>(
2138           op->getLoc(),
2139           RankedTensorType::get(ArrayRef<int64_t>(a4_shape),
2140                                 input_type.getElementType()),
2141           a3_slice_op.getResult(), rewriter.getI64ArrayAttr(a4_shape))
2142       .getResult();
2143 }
2144 
2145 // Lowers FloorDiv to a sequence of TOSA operators.
2146 llvm::Optional<Value> convertFloorDivOp(PatternRewriter& rewriter,
2147                                         Operation* op, Value result_value,
2148                                         Value lhs_value, Value rhs_value) {
2149   // FloorDiv lowering:
2150   // floor(1/rhs * lhs)
2151   //
2152   // a1 = reciprocal(rhs);
2153   // a2 = mul(lhs, a1);
2154   // a3 = floor(a2);
2155   // return a3;
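  // e.g. (float illustration, not from the spec): lhs = 7.0, rhs = 2.0 gives
  // a1 = 0.5, a2 = 3.5, a3 = 3.0.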
2156   RankedTensorType output_type =
2157       result_value.getType().dyn_cast<RankedTensorType>();
2158   // Not a ranked tensor output
2159   if (!output_type) return llvm::None;
2160 
2161   auto a1_reciprocal_rhs_op =
2162       rewriter.create<tosa::ReciprocalOp>(op->getLoc(), output_type, rhs_value);
2163   auto a2_mul_lhs_a1_op =
2164       rewriter.create<tosa::MulOp>(op->getLoc(), output_type, lhs_value,
2165                                    a1_reciprocal_rhs_op.getResult(), 0);
2166   return rewriter
2167       .create<tosa::FloorOp>(op->getLoc(), output_type,
2168                              a2_mul_lhs_a1_op.getResult())
2169       .getResult();
2170 }
2171 
2172 // Lowers FloorMod to a sequence of TOSA operators.
2173 llvm::Optional<Value> convertFloorModOp(PatternRewriter& rewriter,
2174                                         Operation* op, Value result_value,
2175                                         Value lhs_value, Value rhs_value) {
2176   // FloorMod lowering:
2177   // (1/rhs * lhs) - floor(1/rhs * lhs)
2178   // a1 = reciprocal(rhs);
2179   // a2 = mul(lhs, a1);
2180   // a3 = floor(a2);
2181   // a4 = sub(a2, a3);
2182   // return a4;
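  // e.g. (float illustration): lhs = 7.0, rhs = 2.0 gives a2 = 3.5,
  // a3 = 3.0, and a result of 0.5, the fractional part of lhs / rhs.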
2183 
2184   RankedTensorType output_type =
2185       result_value.getType().dyn_cast<RankedTensorType>();
2186   // Not a ranked tensor output
2187   if (!output_type) return llvm::None;
2188 
2189   auto a1_reciprocal_rhs_op =
2190       rewriter.create<tosa::ReciprocalOp>(op->getLoc(), output_type, rhs_value);
2191   auto a2_mul_lhs_a1_op =
2192       rewriter.create<tosa::MulOp>(op->getLoc(), output_type, lhs_value,
2193                                    a1_reciprocal_rhs_op.getResult(), 0);
2194   auto a3_floor_a2_op = rewriter.create<tosa::FloorOp>(
2195       op->getLoc(), output_type, a2_mul_lhs_a1_op.getResult());
2196   return rewriter
2197       .create<tosa::SubOp>(op->getLoc(), output_type,
2198                            a2_mul_lhs_a1_op.getResult(),
2199                            a3_floor_a2_op.getResult())
2200       .getResult();
2201 }
2202 
2203 // Lowers FusedActivation to a sequence of TOSA ops.
2204 llvm::Optional<Value> convertFusedActivation(PatternRewriter& rewriter,
2205                                              Operation* op, Value input_value,
2206                                              StringAttr fused_activation_fn) {
2207   RankedTensorType input_type =
2208       input_value.getType().dyn_cast<RankedTensorType>();
2209   if (!input_type) return llvm::None;
2210 
2211   bool input_is_qtype =
2212       input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2213 
2214   if (input_is_qtype) {
2215     auto input_qtype =
2216         input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2217 
2218     if (fused_activation_fn.getValue() == "TANH") {
2219       // TODO: implement with TABLE
2220       op->emitWarning("Quantized TANH lowering TBD!");
2221       return llvm::None;
2222     } else {
2223       RankedTensorType rescale_type = RankedTensorType::get(
2224           input_type.getShape(), rewriter.getIntegerType(32));
2225 
2226       Value op1_rescale_in = buildRescaleToInt32(
2227           rewriter, op, input_value, 1.0f, input_qtype.getZeroPoint());
2228 
2229       Value op2_relu_op1;
2230       if (fused_activation_fn.getValue() == "NONE") {
2231         return input_value;
2232       } else if (fused_activation_fn.getValue() == "RELU") {
2233         auto relu_op = rewriter.create<tosa::ReluNOp>(
2234             op->getLoc(), rescale_type, op1_rescale_in,
2235             rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
2236             rewriter.getF32FloatAttr(0));
2237 
2238         op2_relu_op1 = relu_op.getResult();
2239 
2240       } else if (fused_activation_fn.getValue() == "RELU6") {
2241         int64_t rescaled_6 = std::llround(6.0f / input_qtype.getScale()) +
2242                              input_qtype.getZeroPoint();
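        // Illustrative (assumed) numbers: scale = 0.05 and zero point = 10
        // give rescaled_6 = 6.0 / 0.05 + 10 = 130, the quantized value of 6.0.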
2243 
2244         auto relu_op = rewriter.create<tosa::ReluNOp>(
2245             op->getLoc(), rescale_type, op1_rescale_in,
2246             rewriter.getI64IntegerAttr(rescaled_6),
2247             rewriter.getF32FloatAttr(0.0f));
2248 
2249         op2_relu_op1 = relu_op.getResult();
2250 
2251       } else if (fused_activation_fn.getValue() == "RELU_N1_TO_1") {
2252         int64_t rescaled_n1 = std::llround(-1.0f / input_qtype.getScale()) +
2253                               input_qtype.getZeroPoint();
2254         int64_t rescaled_1 = std::llround(1.0f / input_qtype.getScale()) +
2255                              input_qtype.getZeroPoint();
2256 
2257         auto relu_op = rewriter.create<tosa::ClampOp>(
2258             op->getLoc(), rescale_type, op1_rescale_in,
2259             rewriter.getI64IntegerAttr(rescaled_n1),
2260             rewriter.getI64IntegerAttr(rescaled_1),
2261             rewriter.getF32FloatAttr(0.0f), rewriter.getF32FloatAttr(0.0f));
2262 
2263         op2_relu_op1 = relu_op.getResult();
2264       } else {
2265         return llvm::None;
2266       }
2267 
2268       return buildRescaleFromInt32(rewriter, op, input_type, op2_relu_op1, 1.0f,
2269                                    input_qtype.getZeroPoint());
2270     }
2271   } else {
2272     if (fused_activation_fn.getValue() == "NONE") {
2273       return input_value;
2274     } else if (fused_activation_fn.getValue() == "RELU") {
2275       return rewriter
2276           .create<tosa::ReluNOp>(
2277               op->getLoc(), input_type, input_value,
2278               rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
2279               rewriter.getF32FloatAttr(std::numeric_limits<float>::max()))
2280           .getResult();
2281     } else if (fused_activation_fn.getValue() == "RELU6") {
2282       return rewriter
2283           .create<tosa::ReluNOp>(op->getLoc(), input_type, input_value,
2284                                  rewriter.getI64IntegerAttr(6),
2285                                  rewriter.getF32FloatAttr(6.0))
2286           .getResult();
2287     } else if (fused_activation_fn.getValue() == "RELU_N1_TO_1") {
2288       return rewriter
2289           .create<tosa::ClampOp>(
2290               op->getLoc(), input_type, input_value,
2291               rewriter.getI64IntegerAttr(-1), rewriter.getI64IntegerAttr(1),
2292               rewriter.getF32FloatAttr(-1.0), rewriter.getF32FloatAttr(1.0))
2293           .getResult();
2294     } else if (fused_activation_fn.getValue() == "TANH") {
2295       return rewriter
2296           .create<tosa::TanhOp>(op->getLoc(), input_type, input_value)
2297           .getResult();
2298     } else {
2299       // Unsupported activation type. Bail out.
2300       return llvm::None;
2301     }
2302   }
2303 
2304   return llvm::None;
2305 }
2306 
2307 // Common function for lowering reduce operations to TOSA ops.
2308 template <typename T>
2309 llvm::Optional<Value> convertReduceOpCommon(
2310     PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2311     Value input_value, ElementsAttr axes_elems, bool keep_dims,
2312     Type reduce_element_type, bool is_quantized, double input_scale,
2313     int64_t input_zp, double output_scale, int64_t output_zp) {
2314   RankedTensorType input_type =
2315       input_value.getType().dyn_cast<RankedTensorType>();
2316   if (!input_type) return llvm::None;
2317 
2318   ArrayRef<int64_t> input_shape = input_type.getShape();
2319   ArrayRef<int64_t> output_shape = output_type.getShape();
2320   auto input_rank = input_shape.size();
2321   Value val = input_value;
2322 
2323   if (axes_elems.getNumElements() == 0) {
2324     // No axes means return the original tensor.
2325     auto identity_op =
2326         rewriter.create<tosa::IdentityOp>(op->getLoc(), output_type, val);
2327     val = identity_op.getResult();
2328   } else {
2329     // Reduce along each axis
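    // Illustrative example (assumed): an input of shape [2, 3, 4] reduced over
    // axes {1, 2} first becomes [2, 1, 4], then [2, 1, 1]; when keep_dims is
    // false the final reshape below collapses it to the output shape [2].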
2330     SmallVector<int64_t, 4> shape_vec(input_shape.begin(), input_shape.end());
2331 
2332     if (is_quantized) {
2333       val = buildRescaleToInt32(rewriter, op, val, input_scale, input_zp);
2334     }
2335 
2336     for (int i = 0; i < axes_elems.getNumElements(); i++) {
2337       int64_t axis_val = axes_elems.getValue<IntegerAttr>(i).getInt();
2338       if (axis_val < 0) axis_val += input_rank;
2339       auto axis_attr = rewriter.getI64IntegerAttr(axis_val);
2340 
2341       shape_vec[axis_val] = 1;
2342       RankedTensorType reduce_type = RankedTensorType::get(
2343           llvm::makeArrayRef<int64_t>(shape_vec), reduce_element_type);
2344 
2345       auto reduce_op =
2346           rewriter.create<T>(op->getLoc(), reduce_type, val, axis_attr);
2347 
2348       val = reduce_op.getResult();
2349     }
2350 
2351     if (is_quantized) {
2352       RankedTensorType output_rescale_type = RankedTensorType::get(
2353           llvm::makeArrayRef<int64_t>(shape_vec), output_type.getElementType());
2354       val = buildRescaleFromInt32(rewriter, op, output_rescale_type, val,
2355                                   output_scale, output_zp);
2356     }
2357 
2358     // Optionally squeeze out the reduced axes.
2359     if (!keep_dims) {
2360       auto reshape_op = rewriter.create<tosa::ReshapeOp>(
2361           op->getLoc(), output_type, val,
2362           rewriter.getI64ArrayAttr(output_shape));
2363       val = reshape_op.getResult();
2364     }
2365   }
2366 
2367   return val;
2368 }
2369 
2370 // Lowers ReduceAll to a sequence of TOSA ops.
2371 llvm::Optional<Value> convertReduceAllOp(
2372     PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2373     Value input_value, ElementsAttr axes_elems, bool keep_dims) {
2374   RankedTensorType input_type =
2375       input_value.getType().dyn_cast<RankedTensorType>();
2376   if (!input_type) return llvm::None;
2377 
2378   return convertReduceOpCommon<tosa::ReduceAllOp>(
2379       rewriter, op, output_type, input_value, axes_elems, keep_dims,
2380       output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
2381 }
2382 
2383 // Lowers ReduceAny to a sequence of TOSA ops.
2384 llvm::Optional<Value> convertReduceAnyOp(
2385     PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2386     Value input_value, ElementsAttr axes_elems, bool keep_dims) {
2387   RankedTensorType input_type =
2388       input_value.getType().dyn_cast<RankedTensorType>();
2389   if (!input_type) return llvm::None;
2390 
2391   return convertReduceOpCommon<tosa::ReduceAnyOp>(
2392       rewriter, op, output_type, input_value, axes_elems, keep_dims,
2393       output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
2394 }
2395 
2396 // Lowers ReduceMin to a sequence of TOSA ops.
2397 llvm::Optional<Value> convertReduceMinOp(
2398     PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2399     Value input_value, ElementsAttr axes_elems, bool keep_dims) {
2400   RankedTensorType input_type =
2401       input_value.getType().dyn_cast<RankedTensorType>();
2402   if (!input_type) return llvm::None;
2403 
2404   return convertReduceOpCommon<tosa::ReduceMinOp>(
2405       rewriter, op, output_type, input_value, axes_elems, keep_dims,
2406       output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
2407 }
2408 
2409 // Lowers ReduceMax to a sequence of TOSA ops.
2410 llvm::Optional<Value> convertReduceMaxOp(
2411     PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2412     Value input_value, ElementsAttr axes_elems, bool keep_dims) {
2413   RankedTensorType input_type =
2414       input_value.getType().dyn_cast<RankedTensorType>();
2415   if (!input_type) return llvm::None;
2416 
2417   return convertReduceOpCommon<tosa::ReduceMaxOp>(
2418       rewriter, op, output_type, input_value, axes_elems, keep_dims,
2419       output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
2420 }
2421 
2422 // Lowers ReduceProd to a sequence of TOSA ops.
2423 llvm::Optional<Value> convertReduceProdOp(
2424     PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2425     Value input_value, ElementsAttr axes_elems, bool keep_dims) {
2426   RankedTensorType input_type =
2427       input_value.getType().dyn_cast<RankedTensorType>();
2428   if (!input_type) return llvm::None;
2429 
2430   bool input_is_qtype =
2431       input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2432   bool output_is_qtype =
2433       output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2434 
2435   if (input_is_qtype || output_is_qtype) {
2436     op->emitOpError(
2437         "ConvertReduceProdOp: input/output tensor should "
2438         "be all floating-point.");
2439     return llvm::None;
2440   }
2441 
2442   return convertReduceOpCommon<tosa::ReduceProdOp>(
2443       rewriter, op, output_type, input_value, axes_elems, keep_dims,
2444       output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
2445 }
2446 
2447 // Lowers ReduceSum to a sequence of TOSA ops.
2448 llvm::Optional<Value> convertReduceSumOp(
2449     PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2450     Value input_value, ElementsAttr axes_elems, bool keep_dims) {
2451   RankedTensorType input_type =
2452       input_value.getType().dyn_cast<RankedTensorType>();
2453   if (!input_type) return llvm::None;
2454 
2455   bool input_is_qtype =
2456       input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2457   bool output_is_qtype =
2458       output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2459 
2460   if (input_is_qtype != output_is_qtype) {
2461     op->emitOpError(
2462         "ConvertReduceSumOp: input/output tensor should "
2463         "be all quantized or all floating-point.");
2464     return llvm::None;
2465   }
2466 
2467   double input_scale = 1.0f;
2468   double output_scale = 1.0f;
2469   int64_t input_zp = 0;
2470   int64_t output_zp = 0;
2471   Type reduce_element_type = input_type.getElementType();
2472 
2473   if (input_is_qtype) {
2474     auto input_qtype =
2475         input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2476     auto output_qtype =
2477         output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2478 
2479     int32_t input_shift = 20;
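    // (Assumed rationale) the 1 << input_shift factor scales values up on the
    // way into the i32 accumulation and is divided back out by output_scale,
    // so it cancels while preserving fractional precision during the sum.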
2480 
2481     input_scale =
2482         static_cast<double>(1 << input_shift) * input_qtype.getScale();
2483     output_scale =
2484         1.0 / (output_qtype.getScale() * static_cast<double>(1 << input_shift));
2485 
2486     input_zp = input_qtype.getZeroPoint();
2487     output_zp = output_qtype.getZeroPoint();
2488     reduce_element_type = rewriter.getI32Type();
2489   }
2490 
2491   return convertReduceOpCommon<tosa::ReduceSumOp>(
2492       rewriter, op, output_type, input_value, axes_elems, keep_dims,
2493       reduce_element_type, input_is_qtype, input_scale, input_zp, output_scale,
2494       output_zp);
2495 }
2496 
2497 // Lowers ReduceMean to a sequence of TOSA ops.
2498 llvm::Optional<Value> convertReduceMeanOp(
2499     PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2500     Value input_value, ElementsAttr axes_elems, bool keep_dims) {
2501   // reduce_mean is lowered as follows:
2502   // op1 = reduce_sum(input)
2503   // op2 = mul(op1, 1.0 / num_elements_on_reduced_axis)
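  // e.g. (illustrative) a [4, 5] input reduced over axis 1 sums 5 elements per
  // row and then multiplies by 1.0 / 5.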
2504 
2505   RankedTensorType input_type =
2506       input_value.getType().dyn_cast<RankedTensorType>();
2507   if (!input_type) return llvm::None;
2508 
2509   bool input_is_qtype =
2510       input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2511   bool output_is_qtype =
2512       output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2513 
2514   if (input_is_qtype != output_is_qtype) {
2515     op->emitOpError(
2516         "ConvertReduceSumOp: input/output tensor should "
2517         "be all quantized or all floating-point.");
2518     return llvm::None;
2519   }
2520 
2521   // Only supports float type mean() if it's non-quantized
2522   if (!input_is_qtype && !output_type.getElementType().isa<mlir::FloatType>()) {
2523     op->emitWarning(
2524         "Failed convertReduceMean: input unquantized type but output element "
2525         "not FloatType!");
2526     return llvm::None;
2527   }
2528 
2529   int64_t input_rank = input_type.getRank();
2530   int64_t num_elems_on_reduced_axis = 1;
2531   for (int i = 0; i < axes_elems.getNumElements(); i++) {
2532     int64_t axis_val = axes_elems.getValue<IntegerAttr>(i).getInt();
2533     if (axis_val < 0) axis_val += input_rank;
2534     num_elems_on_reduced_axis *= input_type.getShape()[axis_val];
2535   }
2536   double div_scale = 1.0 / static_cast<double>(num_elems_on_reduced_axis);
2537 
2538   double input_scale = 1.0f;
2539   double output_scale = 1.0f;
2540   int64_t input_zp = 0;
2541   int64_t output_zp = 0;
2542   Type reduce_element_type = input_type.getElementType();
2543 
2544   if (input_is_qtype) {
2545     auto input_qtype =
2546         input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2547     auto output_qtype =
2548         output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2549 
2550     int32_t input_shift = 20;
2551 
2552     input_scale =
2553         static_cast<double>(1 << input_shift) * input_qtype.getScale();
2554     output_scale = div_scale / (output_qtype.getScale() *
2555                                 static_cast<double>(1 << input_shift));
2556 
2557     input_zp = input_qtype.getZeroPoint();
2558     output_zp = output_qtype.getZeroPoint();
2559     reduce_element_type = rewriter.getI32Type();
2560   }
2561 
2562   auto val = convertReduceOpCommon<tosa::ReduceSumOp>(
2563       rewriter, op, output_type, input_value, axes_elems, keep_dims,
2564       reduce_element_type, input_is_qtype, input_scale, input_zp, output_scale,
2565       output_zp);
2566 
2567   if (!val.hasValue()) return llvm::None;
2568 
2569   if (!input_is_qtype) {
2570     Value div_const = getTosaConstTensorSingleF32(rewriter, op, div_scale);
2571     return rewriter
2572         .create<tosa::MulOp>(op->getLoc(), output_type, val.getValue(),
2573                              div_const, 0)
2574         .getResult();
2575   }
2576 
2577   return val;
2578 }
2579 
2580 // Lowers ResizeBilinear and ResizeNearestNeighbor to TOSA resize.
2581 llvm::Optional<Value> convertResizeOp(PatternRewriter& rewriter, Operation* op,
2582                                       RankedTensorType output_type,
2583                                       Value input_value, StringRef mode) {
2584   RankedTensorType input_type =
2585       input_value.getType().dyn_cast<RankedTensorType>();
2586   if (!input_type) return llvm::None;
2587 
2588   auto input_shape = input_type.getShape();
2589   auto output_shape = output_type.getShape();
2590 
2591   size_t input_height = input_shape[1];
2592   size_t input_width = input_shape[2];
2593   size_t output_height = output_shape[1];
2594   size_t output_width = output_shape[2];
2595 
2596   bool input_is_qtype =
2597       input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2598   bool output_is_qtype =
2599       output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2600 
2601   if (input_is_qtype != output_is_qtype) {
2602     op->emitOpError(
2603         "ConvertResizeOp: input/output tensor should "
2604         "be all quantized or all floating-point.");
2605     return llvm::None;
2606   }
2607 
2608   if (!input_is_qtype) {
2609     // TODO: support float type
2610     op->emitOpError("ConvertResizeOp: floating-point type not supported yet ");
2611     return llvm::None;
2612   }
2613 
2614   int32_t shift = 11;  // Set default shift to maximum allowed
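  // Express the height/width scaling ratios as fixed-point stride values with
  // `shift` fractional bits.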

  double frac_y =
      static_cast<double>(output_height) / static_cast<double>(input_height);
  double frac_x =
      static_cast<double>(output_width) / static_cast<double>(input_width);
  int32_t stride_y = std::lround(frac_y * static_cast<double>(1 << shift));
  int32_t stride_x = std::lround(frac_x * static_cast<double>(1 << shift));

  // The stride values must fit in int16; reduce the shift until they do.
  while (stride_y >= 32768 || stride_x >= 32768) {
    shift--;
    stride_y = std::lround(frac_y * static_cast<double>(1 << shift));
    stride_x = std::lround(frac_x * static_cast<double>(1 << shift));
  }

  ArrayAttr output_size =
      rewriter.getI64ArrayAttr({static_cast<int64_t>(output_height),
                                static_cast<int64_t>(output_width)});
  ArrayAttr stride = rewriter.getI64ArrayAttr({stride_y, stride_x});
  ArrayAttr offset = rewriter.getI64ArrayAttr({0, 0});
  IntegerAttr shift_attr = rewriter.getI32IntegerAttr(shift);
  StringAttr resize_mode = rewriter.getStringAttr(mode.str());

  return rewriter
      .create<tosa::ResizeOp>(op->getLoc(), output_type, input_value,
                              output_size, stride, offset, shift_attr,
                              resize_mode)
      .getResult();
}

// Lowers Quantize to a sequence of TOSA quantization ops.
llvm::Optional<Value> convertQuantizeOp(PatternRewriter& rewriter,
                                        Operation* op,
                                        RankedTensorType output_type,
                                        Value input_value, double scale,
                                        int64_t zeropoint) {
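  // Quantize is lowered as follows:
  // op1 = mul(input, scale)
  // op2 = add(op1, zeropoint)
  // op3 = cast<int32>(op2)
  // op4 = rescale(op3) into the quantized output type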
  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  if (!input_type) return llvm::None;

  auto output_shape = output_type.getShape();
  auto output_element_type = output_type.getElementType();

  // The output element type must be a quantized integer type.
  if (!output_element_type.isa<mlir::quant::QuantizedType>()) {
    op->emitWarning(
        "Lowering quantizeOp but output element type not quantized!");
    return llvm::None;
  }

  RankedTensorType output_fp_type =
      RankedTensorType::get(output_shape, rewriter.getF32Type());

  Value zp_val =
      getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(zeropoint));

  auto op1_mul_in = rewriter.create<tosa::MulOp>(
      op->getLoc(), output_fp_type, input_value,
      getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(scale)), 0);

  auto op2_add_op1 = rewriter.create<tosa::AddOp>(
      op->getLoc(), output_fp_type, op1_mul_in.getResult(), zp_val);

  // TOSA doesn't support CAST FLOAT->AINT8, so CAST to INT32 first and then
  // RESCALE into the quantized output type.
  RankedTensorType output_int32_type =
      RankedTensorType::get(output_shape, rewriter.getI32Type());

  auto op3_cast_op2 = rewriter.create<tosa::CastOp>(
      op->getLoc(), output_int32_type, op2_add_op1.getResult());

  return buildRescale(rewriter, op, output_type, op3_cast_op2.getResult(), 1.0,
                      0, 0);
}

// Lowers Dequantize to a sequence of TOSA dequantization ops.
llvm::Optional<Value> convertDequantizeOp(PatternRewriter& rewriter,
                                          Operation* op,
                                          RankedTensorType output_type,
                                          Value input_value, double scale,
                                          int64_t zeropoint) {
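  // Dequantize is lowered as follows:
  // op1 = rescale(input) to int32
  // op2 = cast<float>(op1)
  // op3 = sub(op2, zeropoint)
  // op4 = mul(op3, scale)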
  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  if (!input_type) return llvm::None;

  // The input element type must be a quantized integer type.
  if (!input_type.getElementType().isa<mlir::quant::QuantizedType>())
    return llvm::None;

  auto output_shape = output_type.getShape();

  RankedTensorType output_int32_type =
      RankedTensorType::get(output_shape, rewriter.getI32Type());

  Value zp_val =
      getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(zeropoint));

  // TOSA doesn't support CAST AINT8->FLOAT, so RESCALE to INT32 first and
  // then CAST to the float output type.
  Value op1_rescale_in =
      buildRescale(rewriter, op, output_int32_type, input_value, 1.0, 0, 0);

  auto op2_cast_op1 =
      rewriter.create<tosa::CastOp>(op->getLoc(), output_type, op1_rescale_in);

  auto op3_sub_op2 = rewriter.create<tosa::SubOp>(
      op->getLoc(), output_type, op2_cast_op1.getResult(), zp_val);

  return rewriter
      .create<tosa::MulOp>(
          op->getLoc(), output_type, op3_sub_op2.getResult(),
          getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(scale)),
          0)
      .getResult();
}

// Lowers FakeQuant to a sequence of TOSA quantization ops.
llvm::Optional<Value> convertFakeQuantOp(PatternRewriter& rewriter,
                                         Operation* op,
                                         RankedTensorType output_type,
                                         Value input_value, double min,
                                         double max, int64_t num_bits,
                                         bool narrow_range) {
  // FakeQuant is lowered as follows:
  // op1 = quantize(input)
  // op2 = dequantize(op1)

  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  if (!input_type) return llvm::None;

  // Quantized as INT<num_bits>, where num_bits can only be 8 or 16.
  if (num_bits != 8 && num_bits != 16) {
    op->emitWarning("FakeQuantOp lowering handles only 8 and 16 for num_bits!");
    return llvm::None;
  }

  auto output_shape = output_type.getShape();

  int64_t qmax = (1L << (num_bits - 1)) - 1;
  int64_t qmin = -(1L << (num_bits - 1));
  if (narrow_range) {
    qmin += 1;
  }

  auto int_element_qtype = mlir::quant::UniformQuantizedType::get(
      true, rewriter.getIntegerType(num_bits), rewriter.getF32Type(), 1.0f, 0,
      qmin, qmax);
  RankedTensorType output_int_type =
      RankedTensorType::get(output_shape, int_element_qtype);

  double scale = (max - min) / static_cast<double>(qmax - qmin);
  int64_t zeropoint = std::llround((-min) / scale + static_cast<double>(qmin));
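  // scale maps the real range [min, max] onto [qmin, qmax]; zeropoint is the
  // quantized value that represents real 0.0.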

  // Quantize: round(x / scale + zeropoint)
  auto quantized_val = convertQuantizeOp(rewriter, op, output_int_type,
                                         input_value, 1.0 / scale, zeropoint);

  if (!quantized_val.hasValue()) return llvm::None;

  // Dequantize: ((float)x - zeropoint) * scale
  return convertDequantizeOp(rewriter, op, output_type,
                             quantized_val.getValue(), scale, zeropoint);
}

llvm::Optional<Value> convertTFConv2DCommon(
    PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
    Value input, Value filter, Value bias, ArrayAttr strides_attr,
    ArrayAttr dilations_attr, ArrayAttr explicit_padding_attr,
    StringRef padding_ref, StringRef data_format_ref) {
  RankedTensorType input_type = input.getType().dyn_cast<RankedTensorType>();
  RankedTensorType filter_type = filter.getType().dyn_cast<RankedTensorType>();
  // Bail out if the input or filter is not a ranked tensor.
  if (!input_type) return llvm::None;
  if (!filter_type) return llvm::None;

  // Transpose [H, W, I, O] to [O, H, W, I]
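  // tosa.conv2d expects the filter as [output_channels, H, W, input_channels],
  // while TF supplies it as [H, W, input_channels, output_channels].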
  auto filter_shape = filter_type.getShape();
  SmallVector<int64_t, 4> a1_transpose_dims;
  a1_transpose_dims.push_back(filter_shape[3]);
  a1_transpose_dims.push_back(filter_shape[0]);
  a1_transpose_dims.push_back(filter_shape[1]);
  a1_transpose_dims.push_back(filter_shape[2]);
  Value a1_filter_transpose_perm =
      get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, {3, 0, 1, 2});
  auto a1_filter_transpose_op = rewriter.create<tosa::TransposeOp>(
      op->getLoc(),
      RankedTensorType::get(ArrayRef<int64_t>(a1_transpose_dims),
                            filter_type.getElementType()),
      filter, a1_filter_transpose_perm);

  // Only NHWC is supported for now.
  if (data_format_ref.str() != "NHWC") {
    op->emitWarning("convertTFConv2DCommon only supports NHWC!");
    return llvm::None;
  }

  ArrayAttr stride;
  ArrayAttr dilation;
  ArrayAttr pad;
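  // Extract the H and W stride/dilation from the NHWC attributes, defaulting
  // to {1, 1} when an attribute is not provided.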
  {
    if (!strides_attr) {
      stride = rewriter.getI64ArrayAttr({1, 1});
    } else {
      // Note: hardcoded to NHWC for now
      int64_t stride_h = strides_attr[1].cast<IntegerAttr>().getInt();
      int64_t stride_w = strides_attr[2].cast<IntegerAttr>().getInt();
      stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
    }
  }
  {
    if (!dilations_attr) {
      dilation = rewriter.getI64ArrayAttr({1, 1});
    } else {
      // Note: hardcoded to NHWC for now
      int64_t dilation_h = dilations_attr[1].cast<IntegerAttr>().getInt();
      int64_t dilation_w = dilations_attr[2].cast<IntegerAttr>().getInt();
      dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w});
    }
  }
  {
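    // Derive the pad values: EXPLICIT padding is read directly from the
    // attribute, while SAME/VALID padding is computed from the input/filter
    // shapes, stride, and dilation.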
    tensorflow::Padding tf_pad;
    if (!GetPaddingFromString(padding_ref.str(), &tf_pad).ok()) {
      op->emitWarning("Could not get padding data from padding string!");
      return llvm::None;
    }

    tensorflow::TensorFormat data_format_tf;
    if (!FormatFromString(data_format_ref.str(), &data_format_tf))
      return llvm::None;

    if (tf_pad == tensorflow::Padding::EXPLICIT) {
      pad = getPaddingValuesFromExplicitPadAttr(explicit_padding_attr,
                                                data_format_tf, rewriter);
    } else {
      if (!getPaddingValuesFromPadType(tf_pad, data_format_tf,
                                       0,  // tensorflow::FORMAT_HWIO
                                       input_type, filter_type, stride,
                                       dilation, rewriter, pad))
        return llvm::None;
    }
  }

  return rewriter
      .create<tosa::Conv2DOp>(op->getLoc(), output_type, input,
                              a1_filter_transpose_op.getResult(), bias, pad,
                              stride, dilation)
      .getResult();
}

}  // namespace tosa
}  // namespace mlir