1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file contains legalizations common to mapping both TensorFlow and
17 // TensorFlow Lite to TOSA. It operates generically on ops and does not have
18 // a hard reference on either dialect.
19 //
20 // Conversion functions return llvm::None on a legalization failure or a
21 // legalized value on success. Callers must check for presence of an
22 // llvm::Optional value after each call.
23
24 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
25
26 #include <climits>
27 #include <cstddef>
28 #include <cstdint>
29 #include <iterator>
30 #include <numeric>
31
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/Support/FormatVariadic.h"
34 #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
35 #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
36 #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
37 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
38 #include "mlir/IR/Matchers.h" // from @llvm-project
39 #include "mlir/IR/PatternMatch.h" // from @llvm-project
40 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
41
42 namespace mlir {
43 namespace tosa {
44
// Nudge implementation copied from
// tensorflow/core/kernels/fake_quant_ops_functor.h.
// Copying it is the suggested approach to avoid a significant
// TensorFlow build dependency.
void tensorflow_nudge(const float min, const float max, const int quant_min,
50 const int quant_max, float* nudged_min, float* nudged_max,
51 float* scale) {
52 const float quant_min_float = static_cast<float>(quant_min);
53 const float quant_max_float = static_cast<float>(quant_max);
54 *scale = (max - min) / (quant_max_float - quant_min_float);
55 const float zero_point_from_min = quant_min_float - min / *scale;
56 const uint16_t nudged_zero_point = [zero_point_from_min, quant_min,
57 quant_min_float, quant_max,
58 quant_max_float] {
59 if (zero_point_from_min < quant_min_float) {
60 return static_cast<uint16_t>(quant_min);
61 }
62 if (zero_point_from_min > quant_max_float) {
63 return static_cast<uint16_t>(quant_max);
64 }
65 return static_cast<uint16_t>(std::round(zero_point_from_min));
66 }();
67 *nudged_min = (quant_min_float - nudged_zero_point) * (*scale);
68 *nudged_max = (quant_max_float - nudged_zero_point) * (*scale);
69 }
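
// Worked example (illustrative only): for min = -0.95f, max = 1.0f with an
// 8-bit unsigned range quant_min = 0, quant_max = 255:
//   scale               = (1.0 - (-0.95)) / 255    ~= 0.007647
//   zero_point_from_min = 0 - (-0.95 / 0.007647)   ~= 124.23 -> rounds to 124
//   nudged_min          = (0 - 124) * scale        ~= -0.9482
//   nudged_max          = (255 - 124) * scale      ~= 1.0018
// The range is nudged so that zero is exactly representable.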
70
71 // Lowers the Pack operator to TOSA.
llvm::Optional<Value> convertPackOp(PatternRewriter& rewriter, Operation* op,
73 Value result_value,
74 SmallVectorImpl<Value>& inputs,
75 int32_t axis) {
76 //////////////////////////////////////////////////
77 // Operator: output = Pack([values], axis) or output = Stack([values], axis)
78 // Lowering:
79 //
  // This operator is lowered into a series of pairwise tosa.concat()
  // operators and a reshape.
  // Depending on the inputs, a transpose operator is also generated:
83 //
84 // Step 1: concatenate the tensors
85 // a1_concat = tosa.concat(input[0], input[1], axis)
86 // for (i = 2; i < len(input); i++)
87 // a1_concat = tosa.concat(a1_concat, input[i], axis)
88 //
89 // Step 2: reshape to N+1 dimensions
90 // a2_reshape = tosa.reshape(a1_concat, new_rank)
91 //
92 // Step 3: Transpose if a new dimension is being added:
  // if (axis == rank(values[0])):
94 // // perm will be [1, 2, 3, 0]
95 // a3_transpose = tosa.transpose(a2_reshape, perm)
96
97 // Sanity check 1: make sure all input tensors have the same shape
98 // if input[0] has shape [A, B, C], input[1] to input[N-1] should also have
99 // shape[A, B, C]
100 RankedTensorType result_type =
101 result_value.getType().dyn_cast<RankedTensorType>();
102
103 // Check for ranked tensor type.
104 if (!result_type) {
105 op->emitOpError("PackOp: result type not ranked tensor");
106 return llvm::None;
107 }
108
  // Valid axis in TF is [-rank(input), rank(input)).
  // Valid axis in TOSA is [0, rank(input)).
  // Add rank(input) to axis once if axis is negative.
112 RankedTensorType input_type =
113 op->getOperand(0).getType().dyn_cast<RankedTensorType>();
114 if (!input_type) {
115 op->emitOpError("PackOp: input type not ranked tensor");
116 return llvm::None;
117 }
118
119 input_type = inputs[0].getType().dyn_cast<RankedTensorType>();
120 if (!input_type) {
121 op->emitOpError("Input 0 type not ranked tensor.");
122 return llvm::None;
123 }
124 ArrayRef<int64_t> input0_tensor_shape = input_type.getShape();
125 int input_tensor_rank = input0_tensor_shape.size();
126
127 for (int i = 1; i < inputs.size(); i++) {
    input_type = inputs[i].getType().dyn_cast<RankedTensorType>();
    if (!input_type) {
      op->emitOpError(
          llvm::formatv("PackOp: input {} type not ranked tensor", i));
133 return llvm::None;
134 }
135 ArrayRef<int64_t> next_tensor_shape = input_type.getShape();
136 if (next_tensor_shape.size() != input_tensor_rank) {
137 op->emitOpError("PackOp: input tensor rank mismatch.");
138 return llvm::None;
139 }
140 for (int d = 0; d < input0_tensor_shape.size(); d++) {
141 if (input0_tensor_shape[d] != next_tensor_shape[d]) {
142 op->emitOpError("PackOp: input tensor shape mismatch.");
143 return llvm::None;
144 }
145 }
146 }
147
148 // If input tensors are rank 0, should reshape them to rank 1 size 1 before
149 // performing concat.
150 if (input_tensor_rank == 0) {
151 SmallVector<int64_t, 1> reshape_rank1_size1_shape({1});
152 RankedTensorType reshape_rank1_size1_type = RankedTensorType::get(
153 reshape_rank1_size1_shape, result_type.getElementType());
154 ArrayAttr shape_rank1_size1_attr =
155 rewriter.getI64ArrayAttr(reshape_rank1_size1_shape);
156 for (int i = 0; i < inputs.size(); i++) {
157 auto a0_reshape_op = rewriter.create<tosa::ReshapeOp>(
158 op->getLoc(), reshape_rank1_size1_type, inputs[i],
159 shape_rank1_size1_attr);
160 inputs[i] = a0_reshape_op.getResult();
161 }
162 }
163
164 // Sanity check 2: axis can be from [0, rank(input)+1]
165 // Where rank(input)+1 means create a new dimension
166 // Negative values are also allowed up to -(rank(input)+1)
167 // where the axis "wraps around".
168 if (axis < 0) axis += input_tensor_rank;
169 if ((axis < 0) || (axis > (input_tensor_rank + 1))) {
170 op->emitOpError("PackOp: axis out of valid range.");
171 return llvm::None;
172 }
173
  // Sanity check 3: if input shape is [A, B, C], output shape should be
  // [N, A, B, C].
  // 3.a check that output rank is rank(input) + 1
177 SmallVector<int64_t> output_shape_vals(result_type.getShape().begin(),
178 result_type.getShape().end());
179 if (output_shape_vals.size() != (input_tensor_rank + 1)) {
180 op->emitOpError("PackOp: output tensor rank mismatch.");
181 return llvm::None;
182 }
  // 3.b check that the output dimension at axis is N
184 if (output_shape_vals[axis] != inputs.size()) {
185 op->emitOpError("PackOp: output tensor shape mismatch.");
186 return llvm::None;
187 }
  // In most cases PackOp.axis() is within [0, rank(input) - 1].
  // We can directly concatenate along that axis and perform the reshape.
  // For example, stacking N [A, B, C] input tensors along axis = 1
  // gives [A, N * B, C] after concatenation,
  // which is then reshaped into [A, N, B, C].
  // A special case is PackOp.axis() equal to rank(input), in which case
  // we can't directly concatenate along PackOp.axis(). Instead,
  // we concatenate along axis = 0 and reshape into [N, A, B, C],
  // and then we need an extra transpose to [A, B, C, N].
197 int64_t concat_axis;
198 SmallVector<int32_t> perm;
199 SmallVector<int64_t> reshape_output_shape;
200 if (axis == 0 && input_tensor_rank == 0) {
201 concat_axis = 0;
202 } else if (axis == input_tensor_rank) {
203 concat_axis = 0;
204
205 // A special case when stack axis is equal to input tensor rank:
206 // Output shape is [A, B, C, N]
207 // so reshape output will be [N, A, B, C]
208 // and perm will be [1, 2, 3, 0].
209 reshape_output_shape.push_back(output_shape_vals[axis]);
210 for (int d = 0; d < input_tensor_rank; d++) {
211 perm.push_back(d + 1);
212 reshape_output_shape.push_back(output_shape_vals[d]);
213 }
214 perm.push_back(0);
215 } else {
216 // General case, doesn't need perm vector.
217 concat_axis = axis;
218 reshape_output_shape.assign(output_shape_vals.begin(),
219 output_shape_vals.end());
220 }
221 IntegerAttr concat_axis_attr = rewriter.getI64IntegerAttr(concat_axis);
222 ArrayAttr shape_attr = rewriter.getI64ArrayAttr(reshape_output_shape);
223
224 // Concat output shape will depend on concat_axis. E.g. [N * A, B, C]
225 SmallVector<int64_t> concat_output_shape;
226 if (input_tensor_rank == 0) {
227 concat_output_shape.push_back(1);
228 } else {
229 for (int i = 0; i < input_tensor_rank; i++) {
230 concat_output_shape.push_back(input0_tensor_shape[i]);
231 }
232 }
233
234 concat_output_shape[concat_axis] =
235 concat_output_shape[concat_axis] * inputs.size();
236 RankedTensorType concat_type = RankedTensorType::get(
237 ArrayRef<int64_t>(concat_output_shape), result_type.getElementType());
238
239 SmallVector<Value> inputs_0;
240 for (int i = 0; i < inputs.size(); i++) {
241 inputs_0.push_back(inputs[i]);
242 }
243 auto a1_concat_op = rewriter.create<tosa::ConcatOp>(
244 op->getLoc(), concat_type, inputs_0, concat_axis_attr);
245
246 // Doesn't need reshape or transpose if input tensor is rank 0, since inputs
247 // are reshaped beforehand.
248 if (input_tensor_rank == 0) return a1_concat_op.getResult();
249
250 // Reshape [N * A, B, C] to [N, A, B, C].
251 RankedTensorType reshape_output_type =
252 RankedTensorType::get(reshape_output_shape, result_type.getElementType());
253
254 auto a2_reshape_op = rewriter.create<tosa::ReshapeOp>(
255 op->getLoc(), reshape_output_type, a1_concat_op.getResult(), shape_attr);
256
257 // If axis is equal to input tensor rank, then we need extra transpose
258 // [N, A, B, C] to [A, B, C, N]
259 if (axis == input_tensor_rank) {
260 llvm::Optional<Value> a3_transpose_perm = getConstTensor<int32_t>(
261 rewriter, op, perm, {static_cast<int64_t>(perm.size())});
262
263 if (!a3_transpose_perm) return llvm::None;
264
265 return rewriter
266 .create<tosa::TransposeOp>(op->getLoc(), result_type,
267 a2_reshape_op.getResult(),
268 a3_transpose_perm.getValue())
269 .getResult();
270 }
271
272 return a2_reshape_op.getResult();
273 }
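
// Illustrative (hypothetical) usage sketch: a dialect-specific rewrite pattern
// would typically gather the operands and call the helper above, replacing the
// op only when a value is returned. The pattern and op accessors below are
// placeholders, not the actual TF/TFLite legalization code.
//
//   LogicalResult matchAndRewrite(SomePackOp tf_pack_op,
//                                 PatternRewriter& rewriter) const {
//     SmallVector<Value> inputs(tf_pack_op.values());
//     llvm::Optional<Value> result =
//         convertPackOp(rewriter, tf_pack_op, tf_pack_op.getResult(),
//                       inputs, tf_pack_op.axis());
//     if (!result) return failure();
//     rewriter.replaceOp(tf_pack_op, {result.getValue()});
//     return success();
//   }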
274
275 // Lowers the Unpack operator to TOSA
llvm::Optional<SmallVector<Value>> convertUnpackOp(PatternRewriter& rewriter,
277 Operation* op,
278 Value input_value,
279 int32_t axis) {
280 RankedTensorType input_type =
281 input_value.getType().dyn_cast<RankedTensorType>();
282 if (!input_type) return llvm::None;
283
284 auto input_shape = input_type.getShape();
285 int64_t input_rank = input_shape.size();
286
287 SmallVector<Value> results_vec;
288
289 // Negative axis allowed as long as it's within [-input_rank, input_rank).
290 if (axis < 0) axis += input_rank;
291 if ((axis < 0) || (axis > input_rank)) {
292 op->emitOpError("UnpackOp: axis out of valid range.");
293 return llvm::None;
294 }
295
296 // Step 1: transpose 'axis' to leftmost dimension.
297 Value transposed_input_value;
298 if (axis != 0) {
299 SmallVector<int32_t> perm;
300 SmallVector<int64_t> a1_transpose_shape(input_rank);
301
302 perm.push_back(axis);
303 for (int i = 0; i < input_rank; i++) {
304 if (i == axis) continue;
305 perm.push_back(i);
306 }
307
308 llvm::Optional<Value> a1_transpose_perm = getConstTensor<int32_t>(
309 rewriter, op, perm, {static_cast<int64_t>(perm.size())});
310
311 if (!a1_transpose_perm) return llvm::None;
312
313 for (int i = 0; i < input_rank; i++) {
314 a1_transpose_shape[i] = input_shape[perm[i]];
315 }
316
317 auto a1_transpose_op = rewriter.create<tosa::TransposeOp>(
318 op->getLoc(),
319 RankedTensorType::get(a1_transpose_shape, input_type.getElementType()),
320 input_value, a1_transpose_perm.getValue());
321
322 transposed_input_value = a1_transpose_op.getResult();
323 } else {
324 // Do nothing if axis is already at leftmost dimension.
325 transposed_input_value = input_value;
326 }
327
328 // Step 2: slice [N, A, B, C] into N [A, B, C].
329 RankedTensorType transposed_input_type =
330 transposed_input_value.getType().dyn_cast<RankedTensorType>();
331 if (!transposed_input_type) return llvm::None;
332
333 auto transposed_input_shape = transposed_input_type.getShape();
334 int64_t transposed_input_rank = transposed_input_shape.size();
335
336 for (int i = 0; i < transposed_input_shape[0]; i++) {
337 SmallVector<int64_t> begin_vals, size_vals, shape_vals;
338
339 for (int j = 0; j < transposed_input_rank; j++) {
340 if (j == 0) {
341 begin_vals.push_back(i);
342 size_vals.push_back(1);
343 } else {
344 begin_vals.push_back(0);
345 size_vals.push_back(transposed_input_shape[j]);
346 shape_vals.push_back(transposed_input_shape[j]);
347 }
348 }
349
350 ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
351 ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);
352
353 auto a2_slice_op = rewriter.create<tosa::SliceOp>(
354 op->getLoc(),
355 RankedTensorType::get(size_vals,
356 transposed_input_type.getElementType()),
357 transposed_input_value, begin, size);
358
359 auto a3_reshape_op = rewriter.create<tosa::ReshapeOp>(
360 op->getLoc(),
361 RankedTensorType::get(shape_vals,
362 transposed_input_type.getElementType()),
363 a2_slice_op.getResult(), rewriter.getI64ArrayAttr(shape_vals));
364
365 results_vec.push_back(a3_reshape_op.getResult());
366 }
367
368 return results_vec;
369 }
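
// Worked shape example (illustrative only): unpacking a [3, 4, 5] tensor along
// axis = 1:
//   Step 1: transpose with perm [1, 0, 2] -> [4, 3, 5]
//   Step 2: four slices of shape [1, 3, 5], each reshaped to [3, 5]
// The function returns the four reshaped values in order.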
370
371 // Lowers the Select operator to TOSA.
llvm::Optional<Value> convertSelectOp(PatternRewriter& rewriter, Operation* op,
373 Value result_value, Value condition_value,
374 Value x_value, Value y_value) {
375 RankedTensorType result_type =
376 result_value.getType().dyn_cast<RankedTensorType>();
377 RankedTensorType condition_type =
378 condition_value.getType().dyn_cast<RankedTensorType>();
379 RankedTensorType x_type = x_value.getType().dyn_cast<RankedTensorType>();
380 RankedTensorType y_type = y_value.getType().dyn_cast<RankedTensorType>();
381
382 if (!result_type || !condition_type || !x_type || !y_type) {
383 op->emitOpError("Select: failed ranked tensor type check");
384 return llvm::None;
385 }
386
387 // First check whether we need to reshape the condition to match
388 // the same rank as the then/else clauses.
389 if (result_type.getRank() == condition_type.getRank()) {
390 // Nothing to reshape.
391 return rewriter
392 .create<tosa::SelectOp>(op->getLoc(), result_type, condition_value,
393 x_value, y_value)
394 .getResult();
395 }
396
397 // Need to reshape the condition.
398 SmallVector<int64_t> new_cond_dims(
399 result_type.getRank() - condition_type.getRank(), 1);
400
401 for (int i = 0; i < condition_type.getRank(); i++) {
402 new_cond_dims.push_back(condition_type.getShape()[i]);
403 }
404
405 auto reshape_op = rewriter.create<tosa::ReshapeOp>(
406 op->getLoc(),
407 RankedTensorType::get(new_cond_dims, condition_type.getElementType()),
408 condition_value, rewriter.getI64ArrayAttr(new_cond_dims));
409
410 return rewriter
411 .create<tosa::SelectOp>(op->getLoc(), result_type, reshape_op, x_value,
412 y_value)
413 .getResult();
414 }
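
// Worked shape example (illustrative only): a condition of shape [C] selecting
// between x and y of shape [A, B, C]:
//   new_cond_dims = [1, 1, C]  (rank difference of 2 filled with leading 1s)
// The condition is reshaped to [1, 1, C] so tosa.select can broadcast it
// against x and y.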
415
416 // Lowers the ZerosLike operator to TOSA by creating a constant
417 // of the desired type and shape.
llvm::Optional<Value> convertZerosLikeOp(PatternRewriter& rewriter,
419 Operation* op, Value result,
420 Value input) {
421 RankedTensorType result_type = result.getType().dyn_cast<RankedTensorType>();
422 if (!result_type) {
423 op->emitOpError("Zeroslike: result not ranked tensor type");
424 return llvm::None;
425 }
426
427 RankedTensorType input_type = input.getType().dyn_cast<RankedTensorType>();
428 if (!input_type) {
429 op->emitOpError("Zeroslike: input not ranked tensor type");
430 return llvm::None;
431 }
432
433 auto input_shape = input_type.getShape();
434
435 ShapedType zero_type =
436 RankedTensorType::get(input_shape, input_type.getElementType());
437 Attribute zero_attr = rewriter.getZeroAttr(zero_type);
438
439 return rewriter
440 .create<tosa::ConstOp>(op->getLoc(), zero_type,
441 zero_attr.cast<ElementsAttr>())
442 .getResult();
443 }
444
445 // Lowers the Mul operator to TOSA. For quantized types, this requires
446 // inserting rescale operators before and after the operation.
llvm::Optional<Value> convertMultiplyOp(PatternRewriter& rewriter,
448 Operation* op, Value output_val,
449 Value input_lhs_val,
450 Value input_rhs_val) {
451 ShapedType input_lhs_type = input_lhs_val.getType().dyn_cast<ShapedType>();
452 ShapedType input_rhs_type = input_rhs_val.getType().dyn_cast<ShapedType>();
453 ShapedType output_type = output_val.getType().dyn_cast<ShapedType>();
454 // Not a shaped tensor output
455 if (!input_lhs_type || !input_rhs_type || !output_type) return llvm::None;
456
457 bool input_lhs_is_qtype =
458 input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
459 bool input_rhs_is_qtype =
460 input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
461 bool output_is_qtype =
462 output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
463
464 if (input_lhs_is_qtype != output_is_qtype ||
465 input_rhs_is_qtype != output_is_qtype) {
466 op->emitOpError(
467 "ConvertMultiplyOp: input/output tensor should "
468 "be all quantized or all floating-point");
469 return llvm::None;
470 }
471
472 Value output;
473 if (output_is_qtype) {
474 ShapedType rescale_type = output_type.clone(rewriter.getI32Type());
475 auto input_lhs_qtype = input_lhs_type.getElementType()
476 .cast<mlir::quant::UniformQuantizedType>();
477 auto input_rhs_qtype = input_rhs_type.getElementType()
478 .cast<mlir::quant::UniformQuantizedType>();
479 auto output_qtype =
480 output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
481
    // MLIR stores the scale as a double, but TFLite stores it as a float.
    // Downcast from double to float to match TFLite behavior.
484 float in_lhs_scale = input_lhs_qtype.getScale();
485 float in_rhs_scale = input_rhs_qtype.getScale();
486 float output_scale = output_qtype.getScale();
487
488 double output_rescale_scale = in_lhs_scale * in_rhs_scale / output_scale;
489
    // 16 bits x 16 bits -> 32 bits.
    // 32 bits can be rescaled back to 16 bits with a 32-bit quantized
    // multiplier.
492 bool scale32 = true;
493
494 Value op1_rescale_lhs = buildRescaleToInt32(
495 rewriter, op, input_lhs_val, 1.0f, input_lhs_qtype.getZeroPoint());
496 Value op2_rescale_rhs = buildRescaleToInt32(
497 rewriter, op, input_rhs_val, 1.0f, input_rhs_qtype.getZeroPoint());
498 auto op3_mul_op1_op2 = rewriter.create<tosa::MulOp>(
499 op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs, 0);
500 return buildRescale(rewriter, op, output_type, op3_mul_op1_op2.getResult(),
501 output_rescale_scale, 0, output_qtype.getZeroPoint(),
502 true, scale32);
503 }
504
505 return rewriter
506 .create<tosa::MulOp>(op->getLoc(), output_type, input_lhs_val,
507 input_rhs_val, 0)
508 .getResult();
509 }
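
// Worked rescale example (illustrative only): with int8 operands where
// lhs scale = 0.02, rhs scale = 0.05 and output scale = 0.1:
//   output_rescale_scale = 0.02 * 0.05 / 0.1 = 0.01
// Both inputs are first rescaled to int32 with unit scale (only zero points
// removed), multiplied in int32, then rescaled by 0.01 to the output zero
// point.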
510
511 // Lowers the SquaredDifference operator to TOSA.
llvm::Optional<Value> convertSquaredDifferenceOp(PatternRewriter& rewriter,
513 Operation* op, Value result,
514 Value x, Value y) {
515 // Squared-difference is (x-y)*(x-y).
516 // This lowering calculates the difference and multiplies.
517 ShapedType result_type = result.getType().dyn_cast<ShapedType>();
518 if (!result_type) {
519 op->emitOpError("SquaredDifference: result not ranked tensor type");
520 return llvm::None;
521 }
522
523 ShapedType x_type = x.getType().dyn_cast<ShapedType>();
524 ShapedType y_type = y.getType().dyn_cast<ShapedType>();
525 if (!x_type || !y_type) {
526 op->emitOpError("SquaredDifference: inputs not ranked tensor type");
527 return llvm::None;
528 }
529
530 auto sub_op = rewriter.create<tosa::SubOp>(op->getLoc(), result_type, x, y);
531 return rewriter
532 .create<tosa::MulOp>(op->getLoc(), result_type, sub_op.getResult(),
533 sub_op.getResult(), 0)
534 .getResult();
535 }
536
537 // Lowers the Round operator to TOSA.
llvm::Optional<Value> convertRoundOp(PatternRewriter& rewriter, Operation* op,
539 Value result, Value input) {
  // Rounds to the nearest integer by computing floor(input + 0.5).
541 ShapedType result_type = result.getType().dyn_cast<ShapedType>();
542 if (!result_type) {
543 op->emitOpError("Round: result not shaped tensor type");
544 return llvm::None;
545 }
546
547 ShapedType input_type = input.getType().dyn_cast<ShapedType>();
548 if (!input_type) {
549 op->emitOpError("Round: input not shaped tensor type");
550 return llvm::None;
551 }
552
553 auto add_op = rewriter.create<tosa::AddOp>(
554 op->getLoc(), result_type, input,
555 getTosaConstTensorSingleF32(rewriter, op, 0.5));
556
557 return rewriter
558 .create<tosa::FloorOp>(op->getLoc(), result_type, add_op.getResult())
559 .getResult();
560 }
561
562 // Lowers ConcatV2 to TOSA Concat.
llvm::Optional<Value> convertConcatV2Op(PatternRewriter& rewriter,
564 Operation* op, Value result_value,
565 SmallVectorImpl<Value>& values,
566 int32_t axis) {
567 // Check all inputs are RankedTensorType
568 for (auto v : values) {
569 if (!v.getType().dyn_cast<RankedTensorType>()) {
570 op->emitOpError("ConcatV2Op: value type not ranked tensor.");
571 return llvm::None;
572 }
573 }
574
575 // Check output is Ranked tensor type
576 if (!result_value.getType().dyn_cast<RankedTensorType>()) {
577 op->emitOpError("ConcatV2Op: output value type not ranked tensor.");
578 return llvm::None;
579 }
580
581 RankedTensorType result_type =
582 result_value.getType().dyn_cast<RankedTensorType>();
583 mlir::quant::UniformQuantizedType result_quant_type =
584 result_type.getElementType()
585 .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
586
587 SmallVector<Value> values_rescaled;
588
589 for (auto v : values) {
590 RankedTensorType operand_type = v.getType().dyn_cast<RankedTensorType>();
591 mlir::quant::UniformQuantizedType operand_quant_type =
592 operand_type.getElementType()
593 .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
594
    // tfl.concat currently allows different scales for each input tensor,
    // which the TFLite team will fix in:
    // https://github.com/tensorflow/tensorflow/issues/39658
    // For backward compatibility, we still need to support this artifact by
    // rescaling inputs so that they all share the same scale.
600 if (result_quant_type && operand_quant_type) {
601 double operand_scale = static_cast<double>(operand_quant_type.getScale());
602 int32_t operand_zeropoint = operand_quant_type.getZeroPoint();
603
604 double result_scale = static_cast<double>(result_quant_type.getScale());
605 int32_t result_zeropoint = result_quant_type.getZeroPoint();
606
607 // Rescale input if scale is not equal to output tensor scale.
608 if (operand_scale != result_scale) {
609 RankedTensorType rescale_type =
610 RankedTensorType::get(operand_type.getShape(), result_quant_type);
611 Value rescale_op = buildRescale(
612 rewriter, op, rescale_type, v, operand_scale / result_scale,
613 operand_zeropoint, result_zeropoint, false, true);
614 values_rescaled.push_back(rescale_op);
615 } else {
616 values_rescaled.push_back(v);
617 }
618 } else {
619 values_rescaled.push_back(v);
620 }
621 }
622
623 int32_t tensor_rank = result_type.getShape().size();
624
625 if (axis < 0) axis += tensor_rank;
626 if ((axis < 0) || (axis > tensor_rank)) {
627 op->emitOpError("ConcatV2Op: axis out of valid range.");
628 return llvm::None;
629 }
630
631 auto concat_op = rewriter.create<tosa::ConcatOp>(
632 op->getLoc(), result_value.getType(), values_rescaled,
633 rewriter.getI64IntegerAttr(axis));
634
635 return concat_op.getResult();
636 }
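
// Worked rescale example (illustrative only): concatenating an operand with
// scale 0.1 / zero point 3 into a result with scale 0.2 / zero point 5
// rescales that operand by 0.1 / 0.2 = 0.5, mapping zero point 3 to 5, so all
// concatenated values share the output quantization parameters.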
637
638 // Lowers SpaceToBatchND to TOSA.
llvm::Optional<Value> convertSpaceToBatchNDOp(PatternRewriter& rewriter,
640 Operation* op, Value result_value,
641 Value input_value,
642 Value block_shape_value,
643 Value paddings_value) {
644 /////////////////////////////////////////////////
645 // Operator: output = SpaceToBatchND(input, block_shape, paddings)
646 // Lowering:
647 //
648 // SpaceToBatch input tensors are broken into three pieces:
649 // (a) batch dimension (N in NHWC)
650 // (b) input being transformed to batch dimension (typically H, W in NHWC)
651 // (c) remainder of input (typically C in NHWC)
652 //
653 // Step 0. Generate padding constant for the first reshape.
654 // No padding on the batch dimension
655 // The input paddings array is addressed as [input_rank][2]
656 // No padding on the remaining dimensions
657 //
658 // a0_pad_const = tosa.const(input=Tensor<input_rank, 2>)
659 //
660 // Step 1. Pad the input tensor
661 //
662 // a1_pad_input_op = tosa.pad(input=input, shape=a0_pad_const_op)
663 //
  // Step 2. Reshape the padded structure of shape padded_shape to
  // [batch] +
  // [padded_shape[1] / block_shape[0], block_shape[0],
  //  ...
  //  padded_shape[M] / block_shape[M-1], block_shape[M-1]] +
  // remaining_shape
668 //
669 // block_rank = M (number of elements in block_shape)
670 // New rank: input_rank + block_rank
671 //
672 // a2_reshape_a1_op = tosa.reshape(input=a1_pad_input_op, shape=a2_shape)
673 //
674 // Step 3. Transpose dimensions to:
675 // block-shape +
676 // [batch] +
677 // [padded_shape[1] / block_shape[0],
678 // ...
679 // [padded_shape[M] / block_shape[M-1]] +
680 // remaining_shape
681 //
  // a3_transpose_a2_op = tosa.transpose(input=a2_reshape_a1_op,
  //                                     perms=a3_perm)
684 //
685 // Step 4. Reshape the transposed tensor to flatten block_shape stuff
686 // into the batch dimension with the following shape:
687 // [ batch * prod(block_shape)] +
688 // [ padded_shape[1] / block_shape[0],
689 // ...,
690 // padded_shape[M] / block_shape[M-1]] +
691 // remaining_shape
692 //
  // a4_reshape_a3_op = tosa.reshape(input=a3_transpose_a2_op,
  //                                 shape=a4_shape)
695 //
696
697 RankedTensorType result_type =
698 result_value.getType().dyn_cast<RankedTensorType>();
699 RankedTensorType input_type =
700 input_value.getType().dyn_cast<RankedTensorType>();
701 RankedTensorType block_shape_type =
702 block_shape_value.getType().dyn_cast<RankedTensorType>();
703 RankedTensorType paddings_type =
704 paddings_value.getType().dyn_cast<RankedTensorType>();
705
706 // Not a ranked tensor output.
707 if (!result_type) {
708 op->emitOpError("SpaceToBatchND: result type not ranked tensor");
709 return llvm::None;
710 }
711 if (!input_type) {
712 op->emitOpError("SpaceToBatchND: input type not ranked tensor");
713 return llvm::None;
714 }
715 if (!block_shape_type) {
716 op->emitOpError("SpaceToBatchND: block shape type not ranked tensor");
717 return llvm::None;
718 }
719 if (!paddings_type) {
720 op->emitOpError("SpaceToBatchND: paddings type not ranked tensor");
721 return llvm::None;
722 }
723
724 // Follow implementation in
725 // tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
726
727 // So, to figure out the spatial_shape, remove the batch dimension and
728 // then use the next block_rank dimensions. The remaining dimensions are
729 // remaining_shape.
730
731 auto block_shape = block_shape_type.getShape();
732 auto input_shape = input_type.getShape();
733
734 int block_rank = block_shape[0];
735 int batch_size = input_shape[0];
736 int input_rank = input_type.getRank();
737 int remaining_shape_rank = input_rank - block_rank - 1;
738 int block_num_elems = 1;
739 int padding_sum = 0;
740
741 ElementsAttr block_shape_elems;
742 ElementsAttr paddings_elems;
743
744 if (!matchPattern(block_shape_value, m_Constant(&block_shape_elems)))
745 return llvm::None;
746
747 if (!matchPattern(paddings_value, m_Constant(&paddings_elems)))
748 return llvm::None;
749
750 SmallVector<int32_t> a0_pad_const(2 * (input_rank));
751 SmallVector<int64_t> padded_shape(input_rank);
752
753 // 1. Pad based on paddings operand. No padding on the batch dimension.
754 // The a0_pad_const array is addressed as [input_rank][2], but
755 // it is flattened to a 1D array because LLVM appears to only accept 1D.
756 //
757 // padded_shape[] is the shape of the padded output of step a1.
758 // The name is retained for consistency with the TF reference code.
759 padded_shape[0] = input_shape[0];
760
761 // Batch dimension padding
762 a0_pad_const[0] = 0;
763 a0_pad_const[1] = 0;
764
765 // This iterator seems to be the only reliable way to get
766 // int values out of a multi-dimensional ElementsAttr.
767 int idx = 0;
768
769 for (auto i : paddings_elems.getValues<IntegerAttr>()) {
770 a0_pad_const[idx + 2] = i.getInt();
771 padding_sum += i.getInt();
772 idx++;
773 }
774
775 // Insert padding on the spatial shape dimensions
776 for (int i = 0; i < block_rank; i++) {
777 int32_t lo_pad = a0_pad_const[2 * (i + 1) + 0];
778 int32_t hi_pad = a0_pad_const[2 * (i + 1) + 1];
779
780 padded_shape[i + 1] = input_shape[i + 1] + lo_pad + hi_pad;
781 }
782
783 // No padding on the remaining_shape dimensions
784 for (int i = 0; i < remaining_shape_rank; i++) {
785 a0_pad_const[2 * (i + block_rank + 1) + 0] = 0;
786 a0_pad_const[2 * (i + block_rank + 1) + 1] = 0;
787 padded_shape[i + block_rank + 1] = input_shape[i + block_rank + 1];
788 }
789
790 RankedTensorType a0_pad_const_attr_type =
791 RankedTensorType::get({(input_rank), 2}, rewriter.getIntegerType(32));
792
793 // Create a const op to generate the tensor type for the input padding array
794 auto a0_pad_const_op = rewriter.create<tosa::ConstOp>(
795 op->getLoc(), a0_pad_const_attr_type,
796 DenseElementsAttr::get(a0_pad_const_attr_type,
797 llvm::makeArrayRef(a0_pad_const)));
798
799 auto a1_pad_input_op = rewriter.create<tosa::PadOp>(
800 op->getLoc(),
801 RankedTensorType::get(padded_shape, result_type.getElementType()),
802 input_value, a0_pad_const_op.getResult());
803
  // 2. Reshape the padded structure of shape padded_shape to
  // [batch] +
  // [padded_shape[1] / block_shape[0], block_shape[0],
  //  ...
  //  padded_shape[M] / block_shape[M-1], block_shape[M-1]] +
  // remaining_shape
808
809 // block_rank = M (number of elements in block_shape)
810 // New rank: input_rank + block_rank
811 SmallVector<int64_t> a2_shape(1 + block_rank * 2 + remaining_shape_rank);
812
813 // First dimension is batch.
814 a2_shape[0] = input_type.getShape()[0];
815 for (int i = 0; i < block_rank; i++) {
816 int32_t block_shape_val =
817 rewriter
818 .getI32IntegerAttr(
819 block_shape_elems.getValue<IntegerAttr>(i).getInt())
820 .getInt();
821 a2_shape[1 + i * 2 + 0] = padded_shape[1 + i] / block_shape_val;
822 a2_shape[1 + i * 2 + 1] = block_shape_val;
823 block_num_elems *= block_shape_val;
824 }
825
826 // Copy in the remaining block shape.
827 for (int i = 0; i < remaining_shape_rank; i++) {
828 a2_shape[1 + block_rank * 2 + i] = input_shape[1 + block_rank + i];
829 }
830
831 auto a2_reshape_a1_op = rewriter.create<tosa::ReshapeOp>(
832 op->getLoc(),
833 RankedTensorType::get(a2_shape, result_type.getElementType()),
834 a1_pad_input_op.getResult(), rewriter.getI64ArrayAttr(a2_shape));
835
836 // 3. Transpose dimensions to:
837 // block-shape +
838 // [batch] +
839 // [padded_shape[1] / block_shape[0],
840 // ...
841 // [padded_shape[M] / block_shape[M-1]] +
842 // remaining_shape
843 int32_t a2_reshape_a1_rank =
844 a2_reshape_a1_op.getResult().getType().cast<RankedTensorType>().getRank();
845 SmallVector<int32_t> a3_perm(a2_reshape_a1_rank);
846 SmallVector<int64_t> a3_transpose_shape(a2_reshape_a1_rank);
847
848 for (int i = 0; i < block_rank; i++) {
849 a3_perm[i] = 1 + 2 * i + 1;
850 a3_perm[block_rank + 1 + i] = 1 + 2 * i;
851 }
852 a3_perm[block_rank] = 0;
853 for (int i = 1 + block_rank * 2; i < a2_reshape_a1_rank; i++) {
854 a3_perm[i] = i;
855 }
856
857 for (int i = 0; i < a3_transpose_shape.size(); i++) {
858 a3_transpose_shape[i] = a2_shape[a3_perm[i]];
859 }
860
861 llvm::Optional<Value> a3_transpose_const = getConstTensor<int32_t>(
862 rewriter, op, a3_perm, {static_cast<int64_t>(a3_perm.size())});
863
864 if (!a3_transpose_const) return llvm::None;
865
866 auto a3_transpose_a2_op = rewriter.create<tosa::TransposeOp>(
867 op->getLoc(),
868 RankedTensorType::get(a3_transpose_shape, result_type.getElementType()),
869 a2_reshape_a1_op.getResult(), a3_transpose_const.getValue());
870
871 // 4. Reshape the transposed tensor to flatten block_shape
872 // into the batch dimension with the following shape:
873 // [ batch * prod(block_shape)] +
874 // [ padded_shape[1] / block_shape[0],
875 // ...,
876 // padded_shape[M] / block_shape[M-1]] +
877 // remaining_shape
878 SmallVector<int64_t> a4_reshape_shape(input_rank);
879
880 // Batch
881 a4_reshape_shape[0] = batch_size * block_num_elems;
882
883 // padded shape / block_shape.
884 for (int i = 0; i < block_rank; i++) {
885 int32_t block_shape_val =
886 rewriter
887 .getI32IntegerAttr(
888 block_shape_elems.getValue<IntegerAttr>(i).getInt())
889 .getInt();
890 a4_reshape_shape[i + 1] = padded_shape[i + 1] / block_shape_val;
891 }
892
893 // Copy in remainder shape.
894 for (int i = 0; i < remaining_shape_rank; i++) {
895 a4_reshape_shape[1 + block_rank + i] = input_shape[1 + block_rank + i];
896 }
897
898 return rewriter
899 .create<tosa::ReshapeOp>(op->getLoc(), result_type,
900 a3_transpose_a2_op.getResult(),
901 rewriter.getI64ArrayAttr(a4_reshape_shape))
902 .getResult();
903 }
904
905 // Lowers BatchToSpaceND to TOSA.
llvm::Optional<Value> convertBatchToSpaceNDOp(PatternRewriter& rewriter,
907 Operation* op, Value result_value,
908 Value input_value,
909 Value block_shape_value,
910 Value crops_value) {
911 /////////////////////////////////////////////////
  // Operator: output = BatchToSpaceND(input, block_shape, crops)
913 // Lowering:
914 //
915 // BatchToSpace input tensors are broken into three pieces:
916 // (a) batch dimension (N in NHWC)
917 // (b) input being transformed from batch dimension (typically H, W in
918 // NHWC)
919 // (c) remainder of input (typically C in NHWC)
920 //
  // Step 1. Reshape input to:
  // [ block_shape[0],
  //   ...,
  //   block_shape[M-1],
  //   batch / prod(block_shape),
  //   input_shape[1],
  //   ...,
  //   input_shape[N-1] ]
  //
  // a1_reshape_input_op = tosa.reshape(input=input, shape=a1_shape)
  //
  // Step 2. Permute to shape:
  // [ batch / prod(block_shape) ],
  // [ input_shape[1] ], [ block_shape[0] ],
  // ...
  // [ input_shape[M] ], [ block_shape[M-1] ]
  // + remaining_input_shapes input_shape[M+1 .. N-1]
  //
  // a2_transpose_a1 = tosa.transpose(input=a1_reshape_input_op,
  //                                  perms=a2_perm)
  //
  // Step 3. Reshape to:
  // [ batch / prod(block_shape) ],
  // [ input_shape[1] * block_shape[0] ],
  // ...,
  // [ input_shape[M] * block_shape[M-1] ],
  // + remaining input shapes [ input_shape[M+1 .. N-1] ]
948 //
949 // a3_reshape_a2 = tosa.reshape(input=a2_transpose_a1, shape=a3_shape)
950 //
951 // Step 4. Crop the start/end dimensions according to crops of the
952 // a3_reshape_a2 shape
953 //
954 // a4_slice_a3 = tosa.slice(input=a3_reshape_a2, start=a4_start,
955 // size=a4_size)
956
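  // Worked shape example (illustrative only): input [4, 2, 2, 1] with
  // block_shape = [2, 2] and zero crops:
  //   Step 1: reshape to [2, 2, 1, 2, 2, 1]
  //   Step 2: transpose with perm [2, 3, 0, 4, 1, 5] -> [1, 2, 2, 2, 2, 1]
  //   Step 3: reshape to [1, 4, 4, 1]
  //   Step 4: zero crops, so the slice keeps the full [1, 4, 4, 1] tensor
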
957 RankedTensorType result_type =
958 result_value.getType().dyn_cast<RankedTensorType>();
959 RankedTensorType input_type =
960 input_value.getType().dyn_cast<RankedTensorType>();
961 RankedTensorType block_shape_type =
962 block_shape_value.getType().dyn_cast<RankedTensorType>();
963 RankedTensorType crops_type =
964 crops_value.getType().dyn_cast<RankedTensorType>();
965
966 if (!result_type) {
967 op->emitOpError("BatchToSpaceND: result type not ranked tensor");
968 return llvm::None;
969 }
970 if (!input_type) {
971 op->emitOpError("BatchToSpaceND: input type not ranked tensor");
972 return llvm::None;
973 }
974 if (!block_shape_type) {
975 op->emitOpError("BatchToSpaceND: block shape type not ranked tensor");
976 return llvm::None;
977 }
978 if (!crops_type) {
979 op->emitOpError("BatchToSpaceND: crops type not ranked tensor");
980 return llvm::None;
981 }
982
983 // Another 4-step process
984 int block_rank = block_shape_type.getShape()[0];
985 int input_rank = input_type.getRank();
986 int crops_dims = crops_type.getShape()[0];
987 int remaining_shape_rank = input_rank - block_rank - 1;
988 auto input_shape = input_type.getShape();
989
990 ElementsAttr block_shape_elems;
991 ElementsAttr crops_elems;
992
993 if (!matchPattern(block_shape_value, m_Constant(&block_shape_elems))) {
994 op->emitOpError("BatchToSpaceND: block_shape not a constant");
995 return llvm::None;
996 }
997
998 if (!matchPattern(crops_value, m_Constant(&crops_elems))) {
999 op->emitOpError("BatchToSpaceND: crops not a constant");
1000 return llvm::None;
1001 }
1002
1003 SmallVector<int64_t> block_shape(block_rank);
1004 SmallVector<std::pair<int64_t, int64_t>> crops(crops_dims);
1005
1006 // Extract values for block_shape and crops now.
1007 int block_num_elems = 1;
1008 for (int i = 0; i < block_rank; i++) {
1009 int block_shape_val =
1010 rewriter
1011 .getI32IntegerAttr(
1012 block_shape_elems.getValue<IntegerAttr>(i).getInt())
1013 .getInt();
1014 block_num_elems *= block_shape_val;
1015 block_shape[i] = block_shape_val;
1016 }
1017
1018 // This iterator seems to be the only reliable way to get
1019 // int values out of a multi-dimensional ElementsAttr
1020 SmallVector<int32_t> crops_const(2 * (crops_dims));
1021 int idx = 0;
1022 for (auto i : crops_elems.getValues<IntegerAttr>()) {
1023 crops_const[idx++] = i.getInt();
1024 }
1025
1026 for (int i = 0; i < crops_dims; i++) {
1027 int crops_lo = crops_const[i * crops_dims + 0];
1028 int crops_hi = crops_const[i * crops_dims + 1];
1029 crops[i] = std::make_pair(crops_lo, crops_hi);
1030 }
1031
  // Step 1. Reshape input to:
  // [ block_shape[0],
  //   ...,
  //   block_shape[M-1],
  //   batch / prod(block_shape),
  //   input_shape[1],
  //   ...,
  //   input_shape[N-1] ]
1040 SmallVector<int64_t> a1_shape(block_rank + input_rank);
1041
1042 for (int i = 0; i < block_rank; i++) a1_shape[i] = block_shape[i];
1043
1044 a1_shape[block_rank] = input_shape[0] / block_num_elems;
1045
1046 for (int i = 0; i < input_rank - 1; i++)
1047 a1_shape[i + block_rank + 1] = input_shape[i + 1];
1048
1049 auto a1_reshape_input_op = rewriter.create<tosa::ReshapeOp>(
1050 op->getLoc(),
1051 RankedTensorType::get(a1_shape, result_type.getElementType()),
1052 input_value, rewriter.getI64ArrayAttr(a1_shape));
1053
1054 // 2. Permute to shape
1055 // [ batch / prod(block_shape) ],
1056 // [ input_shape[1] ], [ block_shape[0] ]
1057 // ...
  // [ input_shape[M] ], [ block_shape[M-1] ]
1059 // + remaining_input_shapes input_shape[M+1 .. N-1]
1060
1061 // 2a. calculate the permutation
1062 SmallVector<int32_t> a2_perm(block_rank + input_rank);
1063 SmallVector<int64_t> a2_transpose_shape(block_rank + input_rank);
1064
1065 a2_perm[0] = block_rank;
1066 for (int i = 0; i < block_rank; i++) {
1067 a2_perm[1 + i * 2 + 0] = block_rank + 1 + i;
1068 a2_perm[1 + i * 2 + 1] = i;
1069 }
1070
1071 for (int i = 0; i < remaining_shape_rank; i++) {
1072 a2_perm[1 + 2 * block_rank + i] = 1 + 2 * block_rank + i;
1073 }
1074
1075 // 2b. calculate the a2_permuted shape
1076 for (int i = 0; i < (block_rank + input_rank); i++) {
1077 a2_transpose_shape[i] = a1_shape[a2_perm[i]];
1078 }
1079
1080 llvm::Optional<Value> a2_transpose_perm = getConstTensor<int32_t>(
1081 rewriter, op, a2_perm, {static_cast<int64_t>(a2_perm.size())});
1082
1083 if (!a2_transpose_perm) return llvm::None;
1084
1085 auto a2_transpose_a1_op = rewriter.create<tosa::TransposeOp>(
1086 op->getLoc(),
1087 RankedTensorType::get(a2_transpose_shape, result_type.getElementType()),
1088 a1_reshape_input_op.getResult(), a2_transpose_perm.getValue());
1089
  // Step 3. Reshape to:
  // [ batch / prod(block_shape) ],
  // [ input_shape[1] * block_shape[0] ],
  // ...,
  // [ input_shape[M] * block_shape[M-1] ],
  // + remaining input shapes [ input_shape[M+1 .. N-1] ]
1096 SmallVector<int64_t> a4_shape(input_rank);
1097
1098 a4_shape[0] = input_shape[0] / block_num_elems;
1099 for (int i = 0; i < block_rank; i++) {
1100 a4_shape[1 + i] = input_shape[i + 1] * block_shape[i];
1101 }
1102 for (int i = 0; i < remaining_shape_rank; i++) {
1103 a4_shape[1 + block_rank + i] = input_shape[block_rank + 1 + i];
1104 }
1105
1106 auto a3_reshape_a2 = rewriter.create<tosa::ReshapeOp>(
1107 op->getLoc(),
1108 RankedTensorType::get(a4_shape, result_type.getElementType()),
1109 a2_transpose_a1_op.getResult(), rewriter.getI64ArrayAttr(a4_shape));
1110
1111 // 4. Crop the start/end dimensions on 'spatial dimension' according to
1112 // crops
1113 // Use a slice operator to do the cropping.
1114 //
1115 // Calculate a beginning point and a size:
1116 // - Begin is the origin, offset by the lo crop amount in each dimension
1117 // - Size is the reshaped tensor size, minus the quantity (lo + hi) for each
1118 // dimension
1119 SmallVector<int64_t> a4_begin_vals(input_rank), a4_size_vals(input_rank);
1120
1121 for (int i = 0; i < input_rank; i++) {
1122 // Batch dimension and remaining dimensions.
1123 if (i == 0 || i > crops_dims) {
1124 a4_begin_vals[i] = 0;
1125 a4_size_vals[i] = result_type.getShape()[i];
1126 } else {
1127 // Spatial dimension.
1128 assert(i - 1 >= 0 && i - 1 < crops_dims);
1129 a4_begin_vals[i] = crops[i - 1].first;
1130 a4_size_vals[i] = a4_shape[i] - crops[i - 1].first - crops[i - 1].second;
1131 }
1132 }
1133
1134 return rewriter
1135 .create<tosa::SliceOp>(
1136 op->getLoc(),
1137 RankedTensorType::get(a4_size_vals, result_type.getElementType()),
1138 a3_reshape_a2.getResult(), rewriter.getI64ArrayAttr(a4_begin_vals),
1139 rewriter.getI64ArrayAttr(a4_size_vals))
1140 .getResult();
1141 }
1142
1143 // Lowers ExpandDims to TOSA.
llvm::Optional<Value> convertExpandDimsOp(PatternRewriter& rewriter,
1145 Operation* op, Value result_value,
1146 Value input_value, Value dim_value) {
1147 // Lowers to a reshape op with 1's inserted in the appropriate dimensions.
1148 RankedTensorType output_type =
1149 result_value.getType().dyn_cast<RankedTensorType>();
1150 // Not a ranked tensor output
1151 if (!output_type) {
1152 op->emitOpError("ExpandDims: output type not ranked tensor");
1153 return llvm::None;
1154 }
1155
1156 RankedTensorType input_type =
1157 input_value.getType().dyn_cast<RankedTensorType>();
1158 if (!input_type) {
1159 op->emitOpError("ExpandDims: input type not ranked tensor");
1160 return llvm::None;
1161 }
1162
1163 auto input_shape = input_type.getShape();
1164
1165 ElementsAttr dim_elem;
1166 if (!matchPattern(dim_value, m_Constant(&dim_elem))) return llvm::None;
1167
1168 assert(dim_elem.getType().getRank() == 0 && "expected scalar tensor");
1169 int32_t dim = dim_elem.getValue<IntegerAttr>({}).getInt();
1170
1171 SmallVector<int64_t> reshape_dims;
1172 if (dim < 0 || dim >= input_shape.size()) { // add dim at end of tensor
1173 dim = input_shape.size();
1174 for (int i = 0; i < input_shape.size(); i++) {
1175 reshape_dims.emplace_back(input_shape[i]);
1176 }
1177 reshape_dims.emplace_back(1);
1178 } else {
1179 for (int i = 0; i < input_shape.size(); i++) {
1180 if (i == dim) {
1181 reshape_dims.emplace_back(1);
1182 }
1183 reshape_dims.emplace_back(input_shape[i]);
1184 }
1185 }
1186
1187 ArrayAttr shape_attr = rewriter.getI64ArrayAttr(reshape_dims);
1188
1189 return rewriter
1190 .create<tosa::ReshapeOp>(op->getLoc(), output_type, input_value,
1191 shape_attr)
1192 .getResult();
1193 }
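
// Worked shape example (illustrative only): input of shape [2, 3] with
// dim = 1 reshapes to [2, 1, 3]; dim = -1 (or any out-of-range dim) appends
// the new dimension, giving [2, 3, 1].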
1194
1195 // Lowers Squeeze to TOSA.
llvm::Optional<Value> convertSqueezeOp(PatternRewriter& rewriter, Operation* op,
1197 Value result_value, Value input_value,
1198 SmallVectorImpl<int32_t>& squeeze_dims) {
1199 // Lowers to a reshape op where dimensions in squeeze_dims with size=1
1200 // are removed.
1201 RankedTensorType output_type =
1202 result_value.getType().dyn_cast<RankedTensorType>();
1203 // Not a ranked tensor output
1204 if (!output_type) {
1205 op->emitOpError("Squeeze: output type not ranked tensor");
1206 return llvm::None;
1207 }
1208
1209 RankedTensorType input_type =
1210 input_value.getType().dyn_cast<RankedTensorType>();
1211 if (!input_type) {
1212 op->emitOpError("Squeeze: input type not ranked tensor");
1213 return llvm::None;
1214 }
1215
1216 auto input_shape = input_type.getShape();
1217
1218 SmallVector<int64_t> reshape_dims;
1219
1220 if (squeeze_dims.empty()) { // remove all 1-dims
1221 for (int i = 0; i < input_shape.size(); i++) {
1222 if (input_shape[i] != 1) {
1223 reshape_dims.emplace_back(input_shape[i]);
1224 }
1225 }
1226 } else {
1227 // Remove only specified dims.
1228 // First sort the array so they can be picked off in sequence.
1229 std::sort(squeeze_dims.begin(), squeeze_dims.end(),
1230 [](const int32_t a, const int32_t b) { return a < b; });
1231
1232 int pos = 0;
1233 auto dim = squeeze_dims[pos];
1234 for (int i = 0; i < input_shape.size(); i++) {
1235 if (i == dim) {
1236 pos = pos + 1;
1237 if (pos < squeeze_dims.size())
1238 dim = squeeze_dims[pos];
1239 else
1240 dim = -1; // Invalid
1241 } else {
1242 reshape_dims.emplace_back(input_shape[i]);
1243 }
1244 }
1245 }
1246
1247 ArrayAttr shape_attr = rewriter.getI64ArrayAttr(reshape_dims);
1248
1249 return rewriter
1250 .create<tosa::ReshapeOp>(op->getLoc(), output_type, input_value,
1251 shape_attr)
1252 .getResult();
1253 }
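
// Worked shape example (illustrative only): input of shape [1, 2, 1, 3] with
// empty squeeze_dims reshapes to [2, 3]; with squeeze_dims = {0} only the
// leading dimension is removed, giving [2, 1, 3].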
1254
1255 // Lowers ELU to a sequence of TOSA ops.
llvm::Optional<Value> convertEluOp(PatternRewriter& rewriter, Operation* op,
1257 Value result_value, Value features_value) {
1258 // Lowers Elu using the following formula:
1259 // elu(x) = x < 0 ? (exp(x) - 1) : x
1260 // one = const({1});
1261 // zero = const({0});
1262 // one_bcast = reshape(one, [1, ..., rank(x) - 1])
1263 // zero_bcast = reshape(zero, [1, ..., rank(x) - 1])
1264 // a1 = exp(x);
1265 // a2 = sub(a1, one_bcast)
1266 // a3 = ge(x, zero_bcast)
1267 // a4 = select(a3, x, a2)
1268 RankedTensorType output_type =
1269 result_value.getType().dyn_cast<RankedTensorType>();
1270 // Not a ranked tensor output
1271 if (!output_type) {
1272 op->emitOpError("Elu: output type not ranked tensor");
1273 return llvm::None;
1274 }
1275
1276 int32_t input_rank = output_type.getShape().size();
1277 SmallVector<int64_t> bcast_shape(input_rank, 1);
1278
1279 // Can't directly create size=1, rank=rank(input) tensor because
1280 // it will be optimized out. Instead, create rank0 tensor and reshape later.
1281 Value one_const_op = getTosaConstTensorSingleF32(rewriter, op, 1.0);
1282
1283 Value zero_const_op = getTosaConstTensorSingleF32(rewriter, op, 0.0);
1284
1285 auto a1_exp_in_op =
1286 rewriter.create<tosa::ExpOp>(op->getLoc(), output_type, features_value);
1287
1288 auto a2_sub_a1_one_op = rewriter.create<tosa::SubOp>(
1289 op->getLoc(), output_type, a1_exp_in_op.getResult(), one_const_op);
1290
1291 auto a3_ge_in_zero_op = rewriter.create<tosa::GreaterEqualOp>(
1292 op->getLoc(),
1293 RankedTensorType::get(output_type.getShape(), rewriter.getIntegerType(1)),
1294 features_value, zero_const_op);
1295
1296 return rewriter
1297 .create<tosa::SelectOp>(op->getLoc(), output_type,
1298 a3_ge_in_zero_op.getResult(), features_value,
1299 a2_sub_a1_one_op.getResult())
1300 .getResult();
1301 }
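
// Worked example (illustrative only): for x = -1.0 the lowering selects
// exp(-1.0) - 1.0 ~= -0.632, while for x = 2.0 the greater-equal comparison
// is true and x itself (2.0) is selected.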
1302
1303 // Lowers Softmax to a sequence of TOSA ops.
llvm::Optional<Value> convertSoftmaxOp(PatternRewriter& rewriter, Operation* op,
1305 Value result_value, Value logits_value,
1306 double beta) {
  // softmax = exp(logits) / reduce_sum(exp(logits), -1)
  //
  // or, equivalently, multiplying both the numerator and the denominator by
  // exp(-max(logits)), we get:
  //
  // softmax = exp(logits - max(logits)) /
  //           reduce_sum(exp(logits - max(logits)), -1)
  //
  // We use the first version for the direct floating-point lowering, and the
  // second version for the quantized lowering, since the second restricts the
  // input of exp() to be negative, so the LUT output can always stay within
  // [0.0, 1.0].
1318 RankedTensorType output_type =
1319 result_value.getType().dyn_cast<RankedTensorType>();
1320 RankedTensorType input_type =
1321 logits_value.getType().dyn_cast<RankedTensorType>();
1322
1323 // Not a ranked tensor input/output
1324 if (!output_type || !input_type) {
1325 op->emitOpError("Softmax: input and result not ranked tensors");
1326 return llvm::None;
1327 }
1328
1329 // reduce_sum on last dimension
1330 int32_t input_rank = input_type.getShape().size();
1331 ArrayRef<int64_t> logits_shape = output_type.getShape();
1332
1333 if (input_type.getElementType().isa<mlir::quant::QuantizedType>() &&
1334 output_type.getElementType().isa<mlir::quant::QuantizedType>()) {
1335 SmallVector<int64_t> rsum_shape_v(input_type.getShape().begin(),
1336 input_type.getShape().end() - 1);
1337 rsum_shape_v.push_back(1);
1338 ArrayRef<int64_t> rsum_shape(rsum_shape_v);
1339 // The if condition already checks if these are UQTs
1340 mlir::quant::UniformQuantizedType in_quant_type =
1341 input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
1342 mlir::quant::UniformQuantizedType out_quant_type =
1343 output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
1344
1345 auto int16_element_qtype = mlir::quant::UniformQuantizedType::get(
1346 true, rewriter.getIntegerType(16), rewriter.getF32Type(), 1.0f, 0,
1347 -32768, 32767);
1348 RankedTensorType int16_logits_type =
1349 RankedTensorType::get(logits_shape, int16_element_qtype);
1350 RankedTensorType int32_logits_type =
1351 RankedTensorType::get(logits_shape, rewriter.getIntegerType(32));
1352 RankedTensorType int16_rsum_type =
1353 RankedTensorType::get(rsum_shape, int16_element_qtype);
1354 RankedTensorType int32_rsum_type =
1355 RankedTensorType::get(rsum_shape, rewriter.getIntegerType(32));
1356
1357 if (in_quant_type.getStorageTypeIntegralWidth() == 8) {
1358 // Step 1. get x - max(x)
1359 Value op1_rescale_in =
1360 buildRescale(rewriter, op, int32_logits_type, logits_value, 1.0f,
1361 in_quant_type.getZeroPoint(), 0, false, true);
1362
1363 auto op2_reducemax_op1 = rewriter.create<tosa::ReduceMaxOp>(
1364 op->getLoc(), int32_rsum_type, op1_rescale_in,
1365 rewriter.getI64IntegerAttr(input_rank - 1));
1366
1367 auto op3_sub_op1_op2 = rewriter.create<tosa::SubOp>(
1368 op->getLoc(), int32_logits_type, op1_rescale_in,
1369 op2_reducemax_op1.getResult());
1370
1371 // Step 2. get exp() result
      // Implemented with two 8-bit -> 16-bit table lookups.
      // Since the table output is allowed to be [-32768, 32767],
      // while the lower 16 bits are unsigned and range over [0, 65535],
      // the lower table is generated with an offset of -32768, which needs to
      // be recovered before adding it to the upper 16 bits.
1377 auto exp_func = [](double x) -> double { return std::exp(x); };
1378
1379 Value exp_table_const_upper, exp_table_const_lower;
1380 getTosaConst32bitTable(rewriter, op, beta * in_quant_type.getScale(), 0,
1381 exp_func, exp_table_const_upper,
1382 exp_table_const_lower);
1383
1384 Value op4_rescale_op3 =
1385 buildRescale(rewriter, op, int16_logits_type,
1386 op3_sub_op1_op2.getResult(), 128.0, 0, 0, false, true);
1387
1388 // Input is 9.7, where lower 7 bits are all zeros.
1389 // Output is 23 bits, where lower 7 bits should be all zeros as well,
1390 // since there's no interpolation here.
1391 auto op5_table_op4_upper = rewriter.create<tosa::TableOp>(
1392 op->getLoc(), int32_logits_type, op4_rescale_op3,
1393 exp_table_const_upper);
1394
1395 auto op6_table_op4_lower = rewriter.create<tosa::TableOp>(
1396 op->getLoc(), int32_logits_type, op4_rescale_op3,
1397 exp_table_const_lower);
1398
      // To get the 16-bit upper/lower values, we need to right shift by 7
      // bits, and then reconstruct the 32-bit value as (upper << 16) + lower.
      // So effectively we left shift the upper part by 9 bits.
1402 auto op7_lshift_op5 = rewriter.create<tosa::LogicalLeftShiftOp>(
1403 op->getLoc(), int32_logits_type, op5_table_op4_upper.getResult(),
1404 getTosaConstTensorSingleI32(rewriter, op, 9));
1405
1406 // Right shift 7 bits to get lower 16 bits.
1407 auto op8_rshift_op6 = rewriter.create<tosa::ArithmeticRightShiftOp>(
1408 op->getLoc(), int32_logits_type, op6_table_op4_lower.getResult(),
1409 getTosaConstTensorSingleI32(rewriter, op, 7), true);
1410
1411 // Recover lower bits from [-32768, 32767] back to [0, 65535]
1412 auto op9_add_op8_32768 = rewriter.create<tosa::AddOp>(
1413 op->getLoc(), int32_logits_type, op8_rshift_op6.getResult(),
1414 getTosaConstTensorSingleI32(rewriter, op, 32768));
1415
1416 auto op10_add_op7_op9 = rewriter.create<tosa::AddOp>(
1417 op->getLoc(), int32_logits_type, op7_lshift_op5.getResult(),
1418 op9_add_op8_32768.getResult());
1419
1420 // Step 3. get sum(exp()). output 12.19
1421 auto op11_rshift_op10_12 = rewriter.create<tosa::ArithmeticRightShiftOp>(
1422 op->getLoc(), int32_logits_type, op10_add_op7_op9.getResult(),
1423 getTosaConstTensorSingleI32(rewriter, op, 12), true);
1424
1425 auto op12_reducesum_op11 = rewriter.create<tosa::ReduceSumOp>(
1426 op->getLoc(), int32_rsum_type, op11_rshift_op10_12.getResult(),
1427 rewriter.getI64IntegerAttr(input_rank - 1));
1428
1429 // Step 4. calculate reciprocal(sum(exp()))
1430 // CLZ returns headroom_plus_one
1431 auto op13_clz_op12 = rewriter.create<tosa::ClzOp>(
1432 op->getLoc(), int32_rsum_type, op12_reducesum_op11.getResult());
1433
1434 // minus one to get headroom
1435 auto op14_sub_op13 = rewriter.create<tosa::SubOp>(
1436 op->getLoc(), int32_rsum_type, op13_clz_op12.getResult(),
1437 getTosaConstTensorSingleI32(rewriter, op, 1));
1438
1439 // Left shift to get s1.30 format
1440 auto op15_lshift_op12_op14 = rewriter.create<tosa::LogicalLeftShiftOp>(
1441 op->getLoc(), int32_rsum_type, op12_reducesum_op11.getResult(),
1442 op14_sub_op13.getResult());
1443
      // Step 5. Calculate one_over_one_plus_x() with Newton-Raphson division
      // with 3 iterations.
      // This needs two magic constants, 48/17 and -32/17, from the
      // Newton-Raphson algorithm.
      // We need to operate in s2.29 since 48/17 is > 2.0.
      // Reference: gemmlowp/fixedpoint/fixedpoint.h
1449 Value half_denominator = op15_lshift_op12_op14.getResult();
1450 Value four = getTosaConstTensorSingleI32(rewriter, op, 4);
1451 Value F2_one = getTosaConstTensorSingleI32(rewriter, op, (1U << 29));
1452 Value constant_48_over_17 =
1453 getTosaConstTensorSingleI32(rewriter, op, 1515870810);
1454 Value constant_neg_32_over_17 =
1455 getTosaConstTensorSingleI32(rewriter, op, -1010580540);
1456
1457 // F2 x = constant_48_over_17 + half_denominator *
1458 // constant_neg_32_over_17;
1459 auto op16_mul_half_denominator = rewriter.create<tosa::MulOp>(
1460 op->getLoc(), int32_rsum_type, half_denominator,
1461 constant_neg_32_over_17, 31);
1462
1463 auto op17_add_op16 = rewriter.create<tosa::AddOp>(
1464 op->getLoc(), int32_rsum_type, op16_mul_half_denominator.getResult(),
1465 constant_48_over_17);
1466
1467 // Newton-Raphson 3x iteration
1468 Value nr_x = op17_add_op16.getResult();
1469 for (int i = 0; i < 3; i++) {
1470 // half_denominator_times_x =
1471 // SaturatingRoundingDoublingHighMul(half_denominator, x)
1472 auto op18_mul_x_half_denominator = rewriter.create<tosa::MulOp>(
1473 op->getLoc(), int32_rsum_type, nr_x, half_denominator, 31);
1474
1475 // F2 one_minus_half_denominator_times_x = F2::One() -
1476 // half_denominator_times_x
1477 auto op19_sub_one_op18 = rewriter.create<tosa::SubOp>(
1478 op->getLoc(), int32_rsum_type, F2_one,
1479 op18_mul_x_half_denominator.getResult());
1480
1481 // SaturatingRoundingDoublingHighMul(x,
1482 // one_minus_half_denominator_times_x)
1483 auto op20_mul_x_op19 =
1484 rewriter.create<tosa::MulOp>(op->getLoc(), int32_rsum_type, nr_x,
1485 op19_sub_one_op18.getResult(), 31);
1486
1487 // x + Rescale<2>(x * one_minus_half_denominator_times_x)
1488 auto op21_mul_op20_four =
1489 rewriter.create<tosa::MulOp>(op->getLoc(), int32_rsum_type,
1490 op20_mul_x_op19.getResult(), four, 0);
1491
1492 auto op22_add_x_op21 =
1493 rewriter.create<tosa::AddOp>(op->getLoc(), int32_rsum_type, nr_x,
1494 op21_mul_op20_four.getResult());
1495
1496 nr_x = op22_add_x_op21.getResult();
1497 }
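    // Each Newton-Raphson step roughly doubles the number of correct bits of
    // the reciprocal estimate, so three iterations starting from the linear
    // seed above are enough to saturate 32-bit fixed-point precision (see the
    // gemmlowp reference cited above).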
1498
1499 // Step 6. multiply exp(x) with 1 / sum(exp(x))
1500 // combined with Rescale<0>(ExactMulByPot<-1>(x))
1501 // so shift 30 instead of 31
1502 auto op23_mul_op10_x = rewriter.create<tosa::MulOp>(
1503 op->getLoc(), int32_logits_type, op10_add_op7_op9.getResult(), nr_x,
1504 31 - 1);
1505
1506     // The right shift amount is
1507     //   num_bits_over_unit + 31 - sizeof(OutputT) * 8
1508     //   = (12 - headroom_plus_one) + 31 - 8
1509     //   = (12 + 31 - 8) - headroom_plus_one
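    // For example, if CLZ returned headroom_plus_one == 4, the shift below is
    // (12 + 31 - 8) - 4 == 31.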
1510 auto op24_sub_op13 = rewriter.create<tosa::SubOp>(
1511 op->getLoc(), int32_rsum_type,
1512 getTosaConstTensorSingleI32(rewriter, op, 12 + 31 - 8),
1513 op13_clz_op12.getResult());
1514
1515 auto op25_rshift_op23_op24 =
1516 rewriter.create<tosa::ArithmeticRightShiftOp>(
1517 op->getLoc(), int32_logits_type, op23_mul_op10_x.getResult(),
1518 op24_sub_op13.getResult(), true);
1519
1520 return buildRescale(rewriter, op, output_type,
1521 op25_rshift_op23_op24.getResult(), 1.0, 0,
1522 out_quant_type.getZeroPoint(), false, true);
1523
1524 } else if (in_quant_type.getStorageTypeIntegralWidth() == 16) {
1525 // Step 1. get x - max(x)
1526 Value op1_rescale_in =
1527 buildRescale(rewriter, op, int32_logits_type, logits_value, 1.0f,
1528 in_quant_type.getZeroPoint(), 0, false, true);
1529
1530 auto op2_reducemax_op1 = rewriter.create<tosa::ReduceMaxOp>(
1531 op->getLoc(), int32_rsum_type, op1_rescale_in,
1532 rewriter.getI64IntegerAttr(input_rank - 1));
1533
1534 // output range is [-65535, 0]
1535 auto op3_sub_op1_op2 = rewriter.create<tosa::SubOp>(
1536 op->getLoc(), int32_logits_type, op1_rescale_in,
1537 op2_reducemax_op1.getResult());
1538
1539 auto exp_func = [](double x) -> double { return std::exp(x); };
1540
1541 // Follow TFLite reference: tensorflow/lite/kernels/activations.cc
1542 Value exp_table_const =
1543 getTosaConst16bitTable(rewriter, op, exp_func, -10.0, 0);
1544
1545 double input_diff_scale = in_quant_type.getScale() / (10.0 / 65535.0);
1546
1547 // Step 2. rescale input from [-65535, 0] to [-32768, 32767] for LUT input
1548 Value op4_rescale_op3 = buildRescale(
1549 rewriter, op, int16_logits_type, op3_sub_op1_op2.getResult(),
1550 input_diff_scale, 0, 32767, true, true);
1551
1552 // Step 3. get exp() result
1553 // Output is 15.7.
1554 // In 8-bit case, no interpolation here, since input should be right on
1555 // table entry.
1556 auto op5_table_op4 = rewriter.create<tosa::TableOp>(
1557 op->getLoc(), int32_logits_type, op4_rescale_op3, exp_table_const);
1558
1559 // Right shift 7 bits. output 15. Shouldn't lose any precision since last
1560 // 7 bits should be all 0.
1561 auto op6_rshift_op5 = rewriter.create<tosa::ArithmeticRightShiftOp>(
1562 op->getLoc(), int32_logits_type, op5_table_op4.getResult(),
1563 getTosaConstTensorSingleI32(rewriter, op, 7), true);
1564
1565 // Step 4. get sum(exp()). output 16.15
1566 auto op7_reducesum_op6 = rewriter.create<tosa::ReduceSumOp>(
1567 op->getLoc(), int32_rsum_type, op6_rshift_op5.getResult(),
1568 rewriter.getI64IntegerAttr(input_rank - 1));
1569
1570 // Step 5. calculate reciprocal(sum(exp()))
1571     // CLZ returns the number of leading zero bits
1571     // (32 minus the position of the first set bit)
1572 auto op8_clz_op7 = rewriter.create<tosa::ClzOp>(
1573 op->getLoc(), int32_rsum_type, op7_reducesum_op6.getResult());
1574
1575 auto op9_sub_op8 = rewriter.create<tosa::SubOp>(
1576 op->getLoc(), int32_rsum_type, op8_clz_op7.getResult(),
1577 getTosaConstTensorSingleI32(rewriter, op, 1));
1578
1579 // Left shift to get 1.30 format
1580 auto op10_lshift_op7_op9 = rewriter.create<tosa::LogicalLeftShiftOp>(
1581 op->getLoc(), int32_rsum_type, op7_reducesum_op6.getResult(),
1582 op9_sub_op8.getResult());
1583
1584 // Subtract (1 << 30) to make 0 <= x <= 1 under 0.30 format
1585 auto op11_sub_op10 = rewriter.create<tosa::SubOp>(
1586 op->getLoc(), int32_rsum_type, op10_lshift_op7_op9.getResult(),
1587 getTosaConstTensorSingleI32(rewriter, op, (1u << 30)));
1588
1589 // Right shift 14 bits to get output range [0, 65535]
1590 auto op12_rshift_op11 = rewriter.create<tosa::ArithmeticRightShiftOp>(
1591 op->getLoc(), int32_rsum_type, op11_sub_op10.getResult(),
1592 getTosaConstTensorSingleI32(rewriter, op, 14), true);
1593
1594 // Remap input to [-32768, 32767] for LUT input
1595 auto op13_rescale_op12 = buildRescale(rewriter, op, int16_rsum_type,
1596 op12_rshift_op11.getResult(), 1.0,
1597 32768, 0, false, true);
1598
1599 // Generate table for 1 / (1 + x), for 0 <= x <= 1
1600 auto one_over_one_plus_x_func = [](double x) -> double {
1601 return 1.0 / (1.0 + x);
1602 };
1603
1604 Value one_over_one_plus_x_table_const = getTosaConst16bitTable(
1605 rewriter, op, one_over_one_plus_x_func, 0.0, 1.0);
1606
1607 // Get (1 / sum(exp(x))) result as 23 bits (including sign bit)
1608 auto op14_table_op13 = rewriter.create<tosa::TableOp>(
1609 op->getLoc(), int32_rsum_type, op13_rescale_op12,
1610 one_over_one_plus_x_table_const);
1611
1612 // Right shift 7 bits back to 0.15
1613 auto op15_rshift_op14 = rewriter.create<tosa::ArithmeticRightShiftOp>(
1614 op->getLoc(), int32_rsum_type, op14_table_op13.getResult(),
1615 getTosaConstTensorSingleI32(rewriter, op, 7), true);
1616
1617 // Step 6. multiply exp(max-x) with 1 / sum(exp(max-x))
1618 // lhs: 0.15, rhs: 0.15, output: 0.30
1619 auto op16_mul_op15_op6 = rewriter.create<tosa::MulOp>(
1620 op->getLoc(), int32_logits_type, op15_rshift_op14, op6_rshift_op5, 0);
1621
1622 auto op17_sub_op8 = rewriter.create<tosa::SubOp>(
1623 op->getLoc(), int32_rsum_type,
1624 getTosaConstTensorSingleI32(rewriter, op, 31),
1625 op8_clz_op7.getResult());
1626
1627     // Apply the CLZ shift back; we get a 0.15 output with
1628     // [0, 32767] corresponding to [0.0, 1.0]
1629 auto op18_rshift_op16_op17 =
1630 rewriter.create<tosa::ArithmeticRightShiftOp>(
1631 op->getLoc(), int32_logits_type, op16_mul_op15_op6.getResult(),
1632 op17_sub_op8.getResult(), true);
1633
1634 return buildRescale(rewriter, op, output_type,
1635 op18_rshift_op16_op17.getResult(),
1636 (1.0 / out_quant_type.getScale()) * (1.0 / 32768.0),
1637 0, out_quant_type.getZeroPoint(), false, true);
1638 } else {
1639 op->emitOpError("Softmax: unknown quantization bitwidth");
1640 return llvm::None;
1641 }
1642 } else {
1643 SmallVector<int64_t> rsum_shape_v(input_type.getShape().begin(),
1644 input_type.getShape().end());
1645 rsum_shape_v[input_rank - 1] = 1;
1646 ArrayRef<int64_t> rsum_shape(rsum_shape_v);
1647
1648     // Floating-point lowering is more direct:
1649 //
1650 // op1 = exp(logits)
1651 // op2 = reduce_sum(op1, -1)
1652 // op3 = reciprocal(op2)
1653 // op4 = mul(op1, op3)
1654 auto op1_exp_in =
1655 rewriter.create<tosa::ExpOp>(op->getLoc(), output_type, logits_value);
1656 RankedTensorType rsum_type =
1657 RankedTensorType::get(rsum_shape, output_type.getElementType());
1658
1659 // Keep dims so we don't need to reshape later
1660 auto op2_reducesum_op1 = rewriter.create<tosa::ReduceSumOp>(
1661 op->getLoc(), rsum_type, op1_exp_in.getResult(),
1662 rewriter.getI64IntegerAttr(input_rank - 1));
1663 auto op3_reciprocal_op2 = rewriter.create<tosa::ReciprocalOp>(
1664 op->getLoc(), op2_reducesum_op1.getType(),
1665 op2_reducesum_op1.getResult());
1666
1667 return rewriter
1668 .create<tosa::MulOp>(op->getLoc(), output_type, op1_exp_in.getResult(),
1669 op3_reciprocal_op2.getResult(), 0)
1670 .getResult();
1671 }
1672 }
1673
1674 // Lowers LogSoftmax to a sequence of TOSA ops.
1675 llvm::Optional<Value> convertLogSoftmaxOp(PatternRewriter& rewriter,
1676 Operation* op, Value result_value,
1677 Value logits_value) {
1678 // log_softmax = log(exp(logits) / reduce_sum(exp(logits), -1))
1679 // op1 = exp(logits)
1680 // op2 = reduce_sum(op1, -1)
1681 // op3 = reciprocal(op2)
1682 // op4 = mul(op1, op3)
1683 // op5 = log(op4)
1684
1685 RankedTensorType output_type =
1686 result_value.getType().dyn_cast<RankedTensorType>();
1687 // Not a ranked tensor output
1688 if (!output_type) {
1689 op->emitOpError("LogSoftmax: output type not ranked tensor.");
1690 return llvm::None;
1691 }
1692
1693 RankedTensorType input_type =
1694 op->getOperand(0).getType().dyn_cast<RankedTensorType>();
1695 if (!input_type) {
1696 op->emitOpError("LogSoftmax: input type not ranked tensor.");
1697 return llvm::None;
1698 }
1699
1700 mlir::quant::UniformQuantizedType in_quant_type =
1701 input_type.getElementType()
1702 .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
1703 mlir::quant::UniformQuantizedType out_quant_type =
1704 output_type.getElementType()
1705 .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
1706 if (in_quant_type || out_quant_type) {
1707 op->emitOpError("Quantized log_softmax lowering not implemented yet");
1708 return llvm::None;
1709 }
1710
1711 auto op1_exp_in =
1712 rewriter.create<tosa::ExpOp>(op->getLoc(), output_type, logits_value);
1713
1714 // reduce_sum on last dimension
1715 int32_t input_rank = input_type.getShape().size();
1716 SmallVector<int64_t> rsum_shape(output_type.getShape().begin(),
1717 output_type.getShape().end());
1718 rsum_shape[input_rank - 1] = 1;
1719 RankedTensorType rsum_type =
1720 RankedTensorType::get(rsum_shape, output_type.getElementType());
1721 // Keep dims so we don't need to reshape later
1722 auto op2_reducesum_op1 = rewriter.create<tosa::ReduceSumOp>(
1723 op->getLoc(), rsum_type, op1_exp_in.getResult(),
1724 rewriter.getI64IntegerAttr(input_rank - 1));
1725 auto op3_reciprocal_op2 = rewriter.create<tosa::ReciprocalOp>(
1726 op->getLoc(), op2_reducesum_op1.getType(), op2_reducesum_op1.getResult());
1727
1728 auto op4_mul_op1_op3 = rewriter.create<tosa::MulOp>(
1729 op->getLoc(), output_type, op1_exp_in.getResult(),
1730 op3_reciprocal_op2.getResult(), 0);
1731
1732 return rewriter
1733 .create<tosa::LogOp>(op->getLoc(), output_type,
1734 op4_mul_op1_op3.getResult())
1735 .getResult();
1736 }
1737
1738 // Lowers SpaceToDepth to a sequence of TOSA ops. Supports NHWC.
1739 llvm::Optional<Value> convertSpaceToDepthOp(PatternRewriter& rewriter,
1740 Operation* op, Value result_value,
1741 Value input_value,
1742 IntegerAttr block_size_attr,
1743 StringAttr data_format) {
1744 // NHWC lowering version:
1745 // a2 = tf.reshape(a, [orig_shape[0], orig_shape[1]//b, b, orig_shape[2]//b,
1746 // b, orig_shape[3]])
1747 // a3 = tf.transpose(a2, [0, 1, 3, 2, 4, 5])
1748 // a4 = tf.reshape(a3, [orig_shape[0], orig_shape[1]//b, orig_shape[2]//b,
1749 // orig_shape[3]*b*b])
1750 // return a4
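  // Shape example (illustrative values): for an input of shape [1, 4, 4, 3]
  // and block size b = 2, a2 has shape [1, 2, 2, 2, 2, 3], a3 has the same
  // shape after the transpose, and a4 has shape [1, 2, 2, 12].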
1751 RankedTensorType output_type =
1752 result_value.getType().dyn_cast<RankedTensorType>();
1753
1754 // Not a ranked tensor output.
1755 if (!output_type) {
1756 op->emitOpError("SpaceToDepth: output type not ranked tensor.");
1757 return llvm::None;
1758 }
1759
1760 RankedTensorType input_type =
1761 input_value.getType().dyn_cast<RankedTensorType>();
1762 if (!input_type) {
1763 op->emitOpError("SpaceToDepth: input type not ranked tensor.");
1764 return llvm::None;
1765 }
1766
1767 if (input_type.getRank() != 4) {
1768 op->emitOpError("SpaceToDepth: input rank not 4.");
1769 return llvm::None;
1770 }
1771
1772 auto input_shape = input_type.getShape();
1773
1774 if (!block_size_attr) { // This is a required parameter
1775 op->emitOpError("SpaceToDepth: block size attribute not set.");
1776 return llvm::None;
1777 }
1778
1779 SmallVector<int64_t, 2> block_size;
1780 block_size.assign(2, block_size_attr.getInt());
1781
1782 if (!data_format) data_format = rewriter.getStringAttr("NHWC");
1783
1784 if (data_format.getValue().str() != "NHWC") {
1785 op->emitOpError("SpaceToDepth: data format not NHWC.");
1786 return llvm::None;
1787 }
1788
1789 assert(block_size[0] * block_size[1] != 0);
1790
1791 SmallVector<int64_t, 6> a_reshape_dims;
1792 a_reshape_dims.push_back(input_shape[0]);
1793 a_reshape_dims.push_back(input_shape[1] / block_size[0]);
1794 a_reshape_dims.push_back(block_size[0]);
1795 a_reshape_dims.push_back(input_shape[2] / block_size[1]);
1796 a_reshape_dims.push_back(block_size[1]);
1797 a_reshape_dims.push_back(input_shape[3]);
1798
1799 RankedTensorType a_reshape_output_type =
1800 RankedTensorType::get(a_reshape_dims, output_type.getElementType());
1801 auto a2_reshape_a_op = rewriter.create<tosa::ReshapeOp>(
1802 op->getLoc(), a_reshape_output_type, input_value,
1803 rewriter.getI64ArrayAttr(a_reshape_dims));
1804
1805 llvm::Optional<Value> a3_transpose_perm = getConstTensor<int32_t>(
1806 rewriter, op, /*vec=*/{0, 1, 3, 2, 4, 5}, /*shape=*/{6});
1807
1808 if (!a3_transpose_perm) return llvm::None;
1809
1810 auto a3_transpose_a2_op = rewriter.create<tosa::TransposeOp>(
1811 op->getLoc(), a_reshape_output_type, a2_reshape_a_op.getResult(),
1812 a3_transpose_perm.getValue());
1813
1814 SmallVector<int64_t, 4> a3_reshape_dims;
1815 a3_reshape_dims.push_back(input_shape[0]);
1816 a3_reshape_dims.push_back(input_shape[1] / block_size[0]);
1817 a3_reshape_dims.push_back(input_shape[2] / block_size[1]);
1818 a3_reshape_dims.push_back(input_shape[3] * block_size[0] * block_size[1]);
1819
1820 RankedTensorType a3_reshape_output_type =
1821 RankedTensorType::get(a3_reshape_dims, output_type.getElementType());
1822 return rewriter
1823 .create<tosa::ReshapeOp>(op->getLoc(), a3_reshape_output_type,
1824 a3_transpose_a2_op.getResult(),
1825 rewriter.getI64ArrayAttr(a3_reshape_dims))
1826 .getResult();
1827 }
1828
1829 // Lowers DepthToSpace to a sequence of TOSA ops. Supports NHWC.
1830 llvm::Optional<Value> convertDepthToSpaceOp(PatternRewriter& rewriter,
1831 Operation* op, Value result_value,
1832 Value input_value,
1833 IntegerAttr block_size_attr,
1834 StringAttr data_format) {
1835 // NHWC version
1836 // a2 = tf.reshape(a, [orig_shape[0], orig_shape[1], orig_shape[2], b, b,
1837 // orig_shape[3] // (b*b)])
1838 // a3 = tf.transpose(a2, [0, 1, 3, 2, 4, 5])
1839 // a4 = tf.reshape(a3, [orig_shape[0], orig_shape[1] * b, orig_shape[2] * b,
1840 // orig_shape[3] // (b*b)])
1841 // return a4
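  // Shape example (illustrative values): for an input of shape [1, 2, 2, 12]
  // and block size b = 2, a2 has shape [1, 2, 2, 2, 2, 3] and a4 has shape
  // [1, 4, 4, 3].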
1842
1843 RankedTensorType output_type =
1844 result_value.getType().dyn_cast<RankedTensorType>();
1845
1846 // Not a ranked tensor output
1847 if (!output_type) {
1848 op->emitOpError("DepthToSpace: output type not ranked tensor.");
1849 return llvm::None;
1850 }
1851
1852 RankedTensorType input_type =
1853 input_value.getType().dyn_cast<RankedTensorType>();
1854 if (!input_type) {
1855 op->emitOpError("DepthToSpace: input type not ranked tensor.");
1856 return llvm::None;
1857 }
1858
1859 if (input_type.getRank() != 4) return llvm::None;
1860 auto input_shape = input_type.getShape();
1861
1862 if (!block_size_attr) { // This is a required parameter
1863 op->emitOpError("DepthToSpace: block size attribute not set.");
1864 return llvm::None;
1865 }
1866
1867 SmallVector<int64_t, 2> block_size;
1868 block_size.assign(2, block_size_attr.getInt());
1869
1870 if (!data_format) data_format = rewriter.getStringAttr("NHWC");
1871 if (data_format.getValue().str() != "NHWC") {
1872 op->emitOpError("DepthToSpace: data format not NHWC.");
1873 return llvm::None;
1874 }
1875
1876 assert(block_size[0] * block_size[1] != 0);
1877
1878 SmallVector<int64_t, 6> a_reshape_dims;
1879 a_reshape_dims.push_back(input_shape[0]);
1880 a_reshape_dims.push_back(input_shape[1]);
1881 a_reshape_dims.push_back(input_shape[2]);
1882 a_reshape_dims.push_back(block_size[0]);
1883 a_reshape_dims.push_back(block_size[1]);
1884 a_reshape_dims.push_back(input_shape[3] / (block_size[0] * block_size[1]));
1885
1886 RankedTensorType a_reshape_output_type =
1887 RankedTensorType::get(a_reshape_dims, output_type.getElementType());
1888 auto a2_reshape_a_op = rewriter.create<tosa::ReshapeOp>(
1889 op->getLoc(), a_reshape_output_type, input_value,
1890 rewriter.getI64ArrayAttr(a_reshape_dims));
1891
1892 llvm::Optional<Value> a3_transpose_perm = getConstTensor<int32_t>(
1893 rewriter, op, /*vec=*/{0, 1, 3, 2, 4, 5}, /*shape=*/{6});
1894
1895 if (!a3_transpose_perm) return llvm::None;
1896
1897 auto a3_transpose_a2_op = rewriter.create<tosa::TransposeOp>(
1898 op->getLoc(), a_reshape_output_type, a2_reshape_a_op.getResult(),
1899 a3_transpose_perm.getValue());
1900
1901 SmallVector<int64_t, 4> a3_reshape_dims;
1902 a3_reshape_dims.push_back(input_shape[0]);
1903 a3_reshape_dims.push_back(input_shape[1] * block_size[0]);
1904 a3_reshape_dims.push_back(input_shape[2] * block_size[1]);
1905 a3_reshape_dims.push_back(input_shape[3] / (block_size[0] * block_size[1]));
1906
1907 RankedTensorType a3_reshape_output_type =
1908 RankedTensorType::get(a3_reshape_dims, output_type.getElementType());
1909 return rewriter
1910 .create<tosa::ReshapeOp>(op->getLoc(), a3_reshape_output_type,
1911 a3_transpose_a2_op.getResult(),
1912 rewriter.getI64ArrayAttr(a3_reshape_dims))
1913 .getResult();
1914 }
1915
1916 // Lowers Split to a sequence of TOSA ops.
1917 llvm::Optional<SmallVector<Value>> convertSplitOp(
1918 PatternRewriter& rewriter, Operation* op, Value result_value,
1919 Value input_value, int32_t num_split, int32_t axis) {
1920   // This lowering creates num_split slice ops that all read from the same
1921   // input, each starting at a different offset along the split axis, and
1922   // returns the slice results as a list of values.
1923 RankedTensorType result_type =
1924 result_value.getType().dyn_cast<RankedTensorType>();
1925 // Not a ranked tensor output
1926 if (!result_type) {
1927 op->emitOpError("Split: output type not ranked tensor.");
1928 return llvm::None;
1929 }
1930
1931 RankedTensorType input_type =
1932 input_value.getType().dyn_cast<RankedTensorType>();
1933 if (!input_type) {
1934 op->emitOpError("Split: input type not ranked tensor.");
1935 return llvm::None;
1936 }
1937
1938 auto input_shape = input_type.getShape();
1939
1940 SmallVector<Value> results_vec;
1941
1942 assert(axis > 0 && axis < input_shape.size());
1943 assert((input_shape[axis] % num_split) == 0);
1944 assert(num_split > 0);
1945
1946 int64_t slice_size = input_shape[axis] / num_split;
1947
1948 for (int i = 0; i < num_split; i++) {
1949     // Each slice has a different beginning point.
1950     // The slice size is the same for each op.
1951 SmallVector<int64_t> begin_vals, size_vals;
1952
1953 for (int j = 0; j < input_shape.size(); j++) {
1954 if (j == axis) {
1955 begin_vals.push_back(slice_size * i);
1956 size_vals.push_back(slice_size);
1957 } else {
1958 begin_vals.push_back(0);
1959 size_vals.push_back(input_shape[j]);
1960 }
1961 }
1962
1963 ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
1964 ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);
1965
1966 auto slice_op = rewriter.create<tosa::SliceOp>(
1967 op->getLoc(),
1968 RankedTensorType::get(size_vals, result_type.getElementType()),
1969 input_value, begin, size);
1970
1971 results_vec.push_back(slice_op.getResult());
1972 }
1973
1974 return results_vec;
1975 }
1976
1977 // Lowers SplitV to a sequence of TOSA ops.
1978 llvm::Optional<SmallVector<Value>> convertSplitVOp(
1979 PatternRewriter& rewriter, Operation* op, Value result_value,
1980 Value input_value, SmallVectorImpl<int32_t>& size_split, int32_t axis) {
1981   // This lowering creates one slice op per entry in size_split, each
1982   // reading from the same input at a different offset along the split
1983   // axis, and returns the slice results as a list of values.
1984 RankedTensorType result_type =
1985 result_value.getType().dyn_cast<RankedTensorType>();
1986 // Not a ranked tensor output
1987 if (!result_type) {
1988 op->emitOpError("SplitV: output type not ranked tensor.");
1989 return llvm::None;
1990 }
1991
1992 RankedTensorType input_type =
1993 input_value.getType().dyn_cast<RankedTensorType>();
1994 if (!input_type) {
1995 op->emitOpError("SplitV: input type not ranked tensor.");
1996 return llvm::None;
1997 }
1998
1999 auto input_shape = input_type.getShape();
2000
2001 SmallVector<Value> results_vec;
2002
2003 assert(axis > 0 && axis < input_shape.size());
2004 int32_t size_split_sum = 0;
2005 for (int i = 0; i < size_split.size(); i++) {
2006 size_split_sum += size_split[i];
2007 }
2008
2009 // The split sizes must sum up to the size of the axis being split
2010 assert(size_split_sum == input_shape[axis]);
2011
2012 int32_t curr_split_start = 0;
2013 for (int i = 0; i < size_split.size(); i++) {
2014     // Each slice has a different beginning point.
2015 // The slice size is different for each op.
2016 SmallVector<int64_t> begin_vals, size_vals;
2017
2018 for (int j = 0; j < input_shape.size(); j++) {
2019 if (j == axis) {
2020 begin_vals.push_back(curr_split_start);
2021 size_vals.push_back(size_split[i]);
2022 } else {
2023 begin_vals.push_back(0);
2024 size_vals.push_back(input_shape[j]);
2025 }
2026 }
2027
2028 ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
2029 ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);
2030
2031 auto slice_op = rewriter.create<tosa::SliceOp>(
2032 op->getLoc(),
2033 RankedTensorType::get(size_vals, result_type.getElementType()),
2034 input_value, begin, size);
2035
2036 results_vec.push_back(slice_op.getResult());
2037
2038 // Next start position
2039 curr_split_start += size_split[i];
2040 }
2041
2042 return results_vec;
2043 }
2044
2045 // Lowers StridedSlice to a sequence of TOSA ops.
2046 llvm::Optional<Value> convertStridedSliceOp(
2047 PatternRewriter& rewriter, Operation* op, Value result_value,
2048 Value input_value, Value begin_value, Value end_value, Value strides_value,
2049 int32_t begin_mask, int32_t end_mask, int32_t ellipsis_mask,
2050 int32_t new_axis_mask, int32_t shrink_axis_mask) {
2051 // The mask arguments are bitmasks where bit [i] applies to
2052 // dimension [i] of the input tensor.
2053 //
2054 // The rough algorithm for lowering strided slice is as follows:
2055 //
2056 // 0. Process begin/end masks, since they are basically syntactic sugar
2057 // on top of the begin_value/end_value arrays
2058 //
2059 // 1. Slice1: Ignoring stride, slice the interesting range from the input
2060 // tensor
2061 //
2062 // 2. Reshape2: Reshape the tensor from (1) such that each dimension with
2063 // stride is split into two dimensions of size_i/stride_i, stride_i. A naive
2064 // implementation doubles the input tensor rank, but only dimensions being
2065 // strided actually need to be doubled.
2066 //
2067 // 3. Slice3: Slice the tensor from (2) such that we select index [0] from
2068 // each of the stride_i dimensions in (2)
2069 //
2070 // 4. Reshape4: Reshape the tensor to eliminate the stride_i dimensions, add
2071 // any dimensions in new_axis_mask and remove any dimensions in the
2072 // shrink_axis_mask
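  // Worked example (illustrative values): for a 1-D input of shape [10] with
  // begin = [0], end = [8], strides = [2], Slice1 produces shape [8],
  // Reshape2 produces [4, 2], Slice3 selects index 0 of the stride dimension
  // to produce [4, 1], and Reshape4 yields the final shape [4] containing
  // elements 0, 2, 4 and 6 of the input.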
2073
2074 // Limitations:
2075 // * This implementation only supports ellipsis_mask=0 for now
2076 // * This implementation does not support reverse stride yet. Will need
2077 // to insert tosa.Reverse operators for this.
2078   if (ellipsis_mask != 0) {
2079     return (void)rewriter.notifyMatchFailure(op, "ellipsis mask not supported yet"), llvm::None;
2080   }
2081
2082 ShapedType input_type = input_value.getType().cast<ShapedType>();
2083 ShapedType result_type = result_value.getType().cast<ShapedType>();
2084
2085 // Extract the begin/end/stride tensors
2086 SmallVector<int32_t> begin, end, strides;
2087
2088 DenseIntElementsAttr strides_attr;
2089
2090 if (!matchPattern(strides_value, m_Constant(&strides_attr))) {
2091 (void)rewriter.notifyMatchFailure(op, "strides is not a constant");
2092 return llvm::None;
2093 }
2094
2095 bool all_strides_one =
2096 strides_attr.isSplat() && strides_attr.getSplatValue<int32_t>() == 1;
2097 int32_t strides_size = strides_attr.getNumElements();
2098
2099 // If all of the masks are set we can just bypass the entire thing.
2100 const int32_t all_masks_one = (1 << strides_size) - 1;
2101 if (all_strides_one && begin_mask == all_masks_one &&
2102 end_mask == all_masks_one) {
2103 return rewriter
2104 .create<tensor::CastOp>(op->getLoc(), result_type, input_value)
2105 .getResult();
2106 }
2107
2108 if (failed(getVectorFromValue32(begin_value, begin))) {
2109 (void)rewriter.notifyMatchFailure(op, "begin isn't a constant");
2110 return llvm::None;
2111 }
2112
2113   // If the begin value is a constant, we might still be able to bypass.
2114 for (auto val : llvm::enumerate(begin)) {
2115 if (val.value() == 0) begin_mask |= (0x1 << val.index());
2116 }
2117
2118 if (all_strides_one && begin_mask == all_masks_one &&
2119 end_mask == all_masks_one) {
2120 return rewriter
2121 .create<tensor::CastOp>(op->getLoc(), result_type, input_value)
2122 .getResult();
2123 }
2124
2125 if (failed(getVectorFromValue32(end_value, end))) {
2126 return (void)rewriter.notifyMatchFailure(op, "end isn't a constant"),
2127 llvm::None;
2128 }
2129
2130 if (!input_type.hasRank()) {
2131 return (void)rewriter.notifyMatchFailure(op,
2132 "input type not ranked tensor."),
2133 llvm::None;
2134 }
2135
2136 int32_t input_rank = input_type.getRank();
2137
2138 if (failed(getVectorFromValue32(strides_value, strides))) {
2139 return (void)rewriter.notifyMatchFailure(op, "strides isn't a constant"),
2140 llvm::None;
2141 }
2142
2143 // If strides is incomplete, pad out to the full size.
2144 while (strides.size() < input_rank) strides.push_back(1);
2145
2146 // Unspecified begins should set the begin mask.
2147 while (begin.size() < input_rank) {
2148 begin_mask = begin_mask | (1 << begin.size());
2149 begin.push_back(0);
2150 }
2151
2152 // Unspecified ends should set the end mask.
2153 while (end.size() < input_rank) {
2154 end_mask = end_mask | (1 << end.size());
2155 end.push_back(-1);
2156 }
2157
2158 auto input_shape = input_type.getShape();
2159
2160 SmallVector<int64_t> a1_begin(input_rank), a1_size(input_rank);
2161 SmallVector<int64_t> a2_shape(input_rank * 2);
2162 SmallVector<int64_t> a3_begin(input_rank * 2), a3_size(input_rank * 2);
2163 SmallVector<int64_t> a4_shape;
2164
2165 // Step 0: Process the begin/end masks and build the begin/sizes for the
2166 // first slice
2167 int residual = 1;
2168 (void)residual;
2169 for (int i = 0; i < input_rank; i++) {
2170 if (begin_mask & (1 << i)) begin[i] = 0;
2171
2172 if (end_mask & (1 << i)) end[i] = input_shape[i];
2173
2174     // Wrap around indices if begin or end is negative
2175 if (begin[i] < 0) begin[i] += input_shape[i];
2176
2177 if (end[i] < 0) end[i] += input_shape[i];
2178
2179 // TODO(suderman): support reverse stride
2180 a1_begin[i] = begin[i];
2181 a1_size[i] = end[i] - begin[i];
2182
2183 a2_shape[i * 2 + 0] = a1_size[i] / strides[i];
2184 a2_shape[i * 2 + 1] = strides[i];
2185
2186 a3_begin[i * 2 + 0] = 0;
2187 a3_begin[i * 2 + 1] = 0;
2188
2189 if (shrink_axis_mask & (1 << i)) {
2190 a3_size[i * 2 + 0] = 1;
2191 } else {
2192 a3_size[i * 2 + 0] = a1_size[i] / strides[i];
2193 }
2194 a3_size[i * 2 + 1] = 1;
2195
2196 if (!(shrink_axis_mask & (1 << i))) {
2197 if (new_axis_mask & (1 << i)) a4_shape.push_back(1);
2198 a4_shape.push_back((a1_size[i] / strides[i]));
2199 }
2200 }
2201
2202 // Make sure we didn't lose any dimensions from the shrink_axis_mask
2203 assert(residual == 1);
2204
2205 // Step 1: Slice the input array
2206 auto a1_slice_op = rewriter.create<tosa::SliceOp>(
2207 op->getLoc(), RankedTensorType::get(a1_size, input_type.getElementType()),
2208 input_value, rewriter.getI64ArrayAttr(a1_begin),
2209 rewriter.getI64ArrayAttr(a1_size));
2210
2211 if (all_strides_one) {
2212 return rewriter
2213 .create<tensor::CastOp>(op->getLoc(), result_type, a1_slice_op)
2214 .getResult();
2215 }
2216
2217 // Step 2: reshape the sliced array
2218 auto a2_reshape_op = rewriter.create<tosa::ReshapeOp>(
2219 op->getLoc(),
2220 RankedTensorType::get(a2_shape, input_type.getElementType()),
2221 a1_slice_op.getResult(), rewriter.getI64ArrayAttr(a2_shape));
2222
2223 // Step 3: take a slice along the strides
2224 auto a3_slice_op = rewriter.create<tosa::SliceOp>(
2225 op->getLoc(), RankedTensorType::get(a3_size, input_type.getElementType()),
2226 a2_reshape_op.getResult(), rewriter.getI64ArrayAttr(a3_begin),
2227 rewriter.getI64ArrayAttr(a3_size));
2228
2229 // Step 4: reshape the now-strided tensor
2230 return rewriter
2231 .create<tosa::ReshapeOp>(op->getLoc(), result_type,
2232 a3_slice_op.getResult(),
2233 rewriter.getI64ArrayAttr(a4_shape))
2234 .getResult();
2235 }
2236
2237 // Lowers FloorDiv to a sequence of TOSA operators.
2238 llvm::Optional<Value> convertFloorDivOp(PatternRewriter& rewriter,
2239 Operation* op, Value result_value,
2240 Value lhs_value, Value rhs_value) {
2241 // FloorDiv lowering:
2242 // floor(1/rhs * lhs)
2243 //
2244 // a1 = reciprocal(rhs);
2245 // a2 = mul(lhs, a1);
2246 // a3 = floor(a2);
2247 // return a3;
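  // For example, with lhs = -7.0 and rhs = 2.0: a1 = 0.5, a2 = -3.5 and
  // a3 = floor(-3.5) = -4.0, matching floor-division semantics.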
2248 RankedTensorType output_type =
2249 result_value.getType().dyn_cast<RankedTensorType>();
2250 // Not a ranked tensor output
2251 if (!output_type) return llvm::None;
2252
2253 Type element_type = output_type.getElementType();
2254
2255 if (element_type.isa<IntegerType>()) {
2256 return rewriter
2257 .create<tosa::DivOp>(op->getLoc(), output_type, lhs_value, rhs_value)
2258 .getResult();
2259 }
2260
2261 auto a1_reciprocal_rhs_op = rewriter.create<tosa::ReciprocalOp>(
2262 op->getLoc(), rhs_value.getType(), rhs_value);
2263 auto a2_mul_lhs_a1_op =
2264 rewriter.create<tosa::MulOp>(op->getLoc(), output_type, lhs_value,
2265 a1_reciprocal_rhs_op.getResult(), 0);
2266 return rewriter
2267 .create<tosa::FloorOp>(op->getLoc(), output_type,
2268 a2_mul_lhs_a1_op.getResult())
2269 .getResult();
2270 }
2271
2272 // Lowers FloorMod to a sequence of TOSA operators.
2273 llvm::Optional<Value> convertFloorModOp(PatternRewriter& rewriter,
2274 Operation* op, Value result_value,
2275 Value lhs_value, Value rhs_value) {
2276 // FloorMod lowering:
2277 // (1/rhs * lhs) - floor(1/rhs * lhs)
2278 // a1 = reciprocal(rhs);
2279 // a2 = mul(lhs, a1);
2280 // a3 = floor(a2);
2281 // a4 = sub(a2, a3);
2282 // return a4;
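  // For example, with lhs = 7.0 and rhs = 2.5: a1 = 0.4, a2 = 2.8, a3 = 2.0
  // and a4 = 0.8, i.e. the fractional part of lhs / rhs.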
2283
2284 RankedTensorType output_type =
2285 result_value.getType().dyn_cast<RankedTensorType>();
2286 // Not a ranked tensor output
2287 if (!output_type) return llvm::None;
2288
2289 auto a1_reciprocal_rhs_op = rewriter.create<tosa::ReciprocalOp>(
2290 op->getLoc(), rhs_value.getType(), rhs_value);
2291 auto a2_mul_lhs_a1_op =
2292 rewriter.create<tosa::MulOp>(op->getLoc(), output_type, lhs_value,
2293 a1_reciprocal_rhs_op.getResult(), 0);
2294 auto a3_floor_a2_op = rewriter.create<tosa::FloorOp>(
2295 op->getLoc(), output_type, a2_mul_lhs_a1_op.getResult());
2296 return rewriter
2297 .create<tosa::SubOp>(op->getLoc(), output_type,
2298 a2_mul_lhs_a1_op.getResult(),
2299 a3_floor_a2_op.getResult())
2300 .getResult();
2301 }
2302
2303 // Lowers FusedActivation to a sequence of TOSA ops.
2304 llvm::Optional<Value> convertFusedActivation(PatternRewriter& rewriter,
2305 Operation* op, Value input_value,
2306 StringAttr fused_activation_fn) {
2307 ShapedType input_type = input_value.getType().dyn_cast<ShapedType>();
2308 if (!input_type) return llvm::None;
2309
2310 bool input_is_qtype =
2311 input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2312
2313 if (input_is_qtype) {
2314     // We can always make the output tensor's scale/zp match the input's
2315     // when legalizing fused_activation_function, since the activation is
2316     // generated during legalization.
2317 auto input_qtype =
2318 input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2319
2320 if (fused_activation_fn.getValue() == "NONE") {
2321 return input_value;
2322 } else if (fused_activation_fn.getValue() == "RELU") {
2323 int32_t quantized_0 = input_qtype.getZeroPoint();
2324 int32_t quantized_max = input_qtype.getStorageTypeMax();
2325
2326 auto clamp_op = rewriter.create<tosa::ClampOp>(
2327 op->getLoc(), input_type, input_value,
2328 rewriter.getI64IntegerAttr(quantized_0),
2329 rewriter.getI64IntegerAttr(quantized_max),
2330 rewriter.getF32FloatAttr(0), rewriter.getF32FloatAttr(0));
2331
2332 return clamp_op.getResult();
2333 } else if (fused_activation_fn.getValue() == "RELU6") {
2334 int32_t quantized_0 = input_qtype.getZeroPoint();
2335 int32_t quantized_6 = std::llround((6.0f / input_qtype.getScale()) +
2336 input_qtype.getZeroPoint());
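        // For example, with a hypothetical int8 quantization of scale 6/255
        // and zero point -128: quantized_0 == -128 and quantized_6 == 127.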
2337
2338 auto clamp_op = rewriter.create<tosa::ClampOp>(
2339 op->getLoc(), input_type, input_value,
2340 rewriter.getI64IntegerAttr(quantized_0),
2341 rewriter.getI64IntegerAttr(quantized_6), rewriter.getF32FloatAttr(0),
2342 rewriter.getF32FloatAttr(0));
2343
2344 return clamp_op.getResult();
2345 } else if (fused_activation_fn.getValue() == "RELU_N1_TO_1") {
2346 int32_t quantized_n1 = std::llround((-1.0f / input_qtype.getScale()) +
2347 input_qtype.getZeroPoint());
2348 int32_t quantized_1 = std::llround((1.0f / input_qtype.getScale()) +
2349 input_qtype.getZeroPoint());
2350
2351 auto clamp_op = rewriter.create<tosa::ClampOp>(
2352 op->getLoc(), input_type, input_value,
2353 rewriter.getI64IntegerAttr(quantized_n1),
2354 rewriter.getI64IntegerAttr(quantized_1), rewriter.getF32FloatAttr(0),
2355 rewriter.getF32FloatAttr(0));
2356
2357 return clamp_op.getResult();
2358 } else {
2359 op->emitWarning("convertFusedActivation: Not implemented yet");
2360 return llvm::None;
2361 }
2362 } else {
2363 if (fused_activation_fn.getValue() == "NONE") {
2364 return input_value;
2365 } else {
2366       // For non-quantized types, only F32 is supported.
2367       if (!input_type.getElementType().isF32()) {
2368         op->emitOpError("convertFusedActivation: only F32 is supported");
2369 return llvm::None;
2370 }
2371
2372 if (fused_activation_fn.getValue() == "RELU") {
2373 return rewriter
2374 .create<tosa::ClampOp>(
2375 op->getLoc(), input_type, input_value,
2376 rewriter.getI64IntegerAttr(0),
2377 rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
2378 rewriter.getF32FloatAttr(0.0f),
2379 rewriter.getF32FloatAttr(std::numeric_limits<float>::max()))
2380 .getResult();
2381 } else if (fused_activation_fn.getValue() == "RELU6") {
2382 return rewriter
2383 .create<tosa::ClampOp>(
2384 op->getLoc(), input_type, input_value,
2385 rewriter.getI64IntegerAttr(0), rewriter.getI64IntegerAttr(6),
2386 rewriter.getF32FloatAttr(0.0f), rewriter.getF32FloatAttr(6.0f))
2387 .getResult();
2388 } else if (fused_activation_fn.getValue() == "RELU_N1_TO_1") {
2389 return rewriter
2390 .create<tosa::ClampOp>(
2391 op->getLoc(), input_type, input_value,
2392 rewriter.getI64IntegerAttr(-1), rewriter.getI64IntegerAttr(1),
2393 rewriter.getF32FloatAttr(-1.0), rewriter.getF32FloatAttr(1.0))
2394 .getResult();
2395 } else if (fused_activation_fn.getValue() == "TANH") {
2396 return rewriter
2397 .create<tosa::TanhOp>(op->getLoc(), input_type, input_value)
2398 .getResult();
2399 } else {
2400 // Unsupported activation type. Bail out.
2401 return llvm::None;
2402 }
2403 }
2404 }
2405
2406 return llvm::None;
2407 }
2408
2409 // Common function for lowering reduce operations to TOSA ops.
2410 template <typename T>
2411 llvm::Optional<Value> convertReduceOpCommon(
2412 PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2413 Value input_value, ElementsAttr axes_elems, bool keep_dims,
2414 Type reduce_element_type, bool is_quantized, double input_scale,
2415 int64_t input_zp, double output_scale, int64_t output_zp) {
2416 RankedTensorType input_type =
2417 input_value.getType().dyn_cast<RankedTensorType>();
2418 if (!input_type) return llvm::None;
2419
2420 ArrayRef<int64_t> input_shape = input_type.getShape();
2421 ArrayRef<int64_t> output_shape = output_type.getShape();
2422 auto input_rank = input_shape.size();
2423 Value val = input_value;
2424
2425 if (axes_elems.getNumElements() == 0) {
2426 // No axes means return the original tensor.
2427 auto identity_op =
2428 rewriter.create<tosa::IdentityOp>(op->getLoc(), output_type, val);
2429 val = identity_op.getResult();
2430 } else {
2431 // Reduce along each axis
2432 SmallVector<int64_t> shape_vec(input_shape.begin(), input_shape.end());
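    // For example, reducing a shape [2, 3, 4] tensor over axes {1, 2} emits
    // two reduce ops producing shapes [2, 1, 4] and then [2, 1, 1]; the
    // reshape at the end removes the kept dims when keep_dims is false.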
2433
2434 if (is_quantized) {
2435 val = buildRescaleToInt32(rewriter, op, val, input_scale, input_zp);
2436 }
2437
2438 for (int i = 0; i < axes_elems.getNumElements(); i++) {
2439 int64_t axis_val = axes_elems.getValue<IntegerAttr>(i).getInt();
2440 if (axis_val < 0) axis_val += input_rank;
2441 auto axis_attr = rewriter.getI64IntegerAttr(axis_val);
2442
2443 shape_vec[axis_val] = 1;
2444 RankedTensorType reduce_type =
2445 RankedTensorType::get(shape_vec, reduce_element_type);
2446
2447 auto reduce_op =
2448 rewriter.create<T>(op->getLoc(), reduce_type, val, axis_attr);
2449
2450 val = reduce_op.getResult();
2451 }
2452
2453 if (is_quantized) {
2454 RankedTensorType output_rescale_type =
2455 RankedTensorType::get(shape_vec, output_type.getElementType());
2456 val = buildRescale(rewriter, op, output_rescale_type, val, output_scale,
2457 0, output_zp, false, true);
2458 }
2459
2460 // Optionally squeeze out the reduced axes.
2461 if (!keep_dims) {
2462 auto reshape_op = rewriter.create<tosa::ReshapeOp>(
2463 op->getLoc(), output_type, val,
2464 rewriter.getI64ArrayAttr(output_shape));
2465 val = reshape_op.getResult();
2466 }
2467 }
2468
2469 return val;
2470 }
2471
2472 // Lowers ReduceAll to a sequence of TOSA ops.
2473 llvm::Optional<Value> convertReduceAllOp(
2474 PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2475 Value input_value, ElementsAttr axes_elems, bool keep_dims) {
2476 RankedTensorType input_type =
2477 input_value.getType().dyn_cast<RankedTensorType>();
2478 if (!input_type) return llvm::None;
2479
2480 return convertReduceOpCommon<tosa::ReduceAllOp>(
2481 rewriter, op, output_type, input_value, axes_elems, keep_dims,
2482 output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
2483 }
2484
2485 // Lowers ReduceAny to a sequence of TOSA ops.
2486 llvm::Optional<Value> convertReduceAnyOp(
2487 PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2488 Value input_value, ElementsAttr axes_elems, bool keep_dims) {
2489 RankedTensorType input_type =
2490 input_value.getType().dyn_cast<RankedTensorType>();
2491 if (!input_type) return llvm::None;
2492
2493 return convertReduceOpCommon<tosa::ReduceAnyOp>(
2494 rewriter, op, output_type, input_value, axes_elems, keep_dims,
2495 output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
2496 }
2497
2498 // Lowers ReduceMin to a sequence of TOSA ops.
2499 llvm::Optional<Value> convertReduceMinOp(
2500 PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2501 Value input_value, ElementsAttr axes_elems, bool keep_dims) {
2502 RankedTensorType input_type =
2503 input_value.getType().dyn_cast<RankedTensorType>();
2504 if (!input_type) return llvm::None;
2505
2506 return convertReduceOpCommon<tosa::ReduceMinOp>(
2507 rewriter, op, output_type, input_value, axes_elems, keep_dims,
2508 output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
2509 }
2510
2511 // Lowers ReduceMax to a sequence of TOSA ops.
2512 llvm::Optional<Value> convertReduceMaxOp(
2513 PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2514 Value input_value, ElementsAttr axes_elems, bool keep_dims) {
2515 RankedTensorType input_type =
2516 input_value.getType().dyn_cast<RankedTensorType>();
2517 if (!input_type) return llvm::None;
2518
2519 return convertReduceOpCommon<tosa::ReduceMaxOp>(
2520 rewriter, op, output_type, input_value, axes_elems, keep_dims,
2521 output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
2522 }
2523
2524 // Lowers ReduceProd to a sequence of TOSA ops.
2525 llvm::Optional<Value> convertReduceProdOp(
2526 PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2527 Value input_value, ElementsAttr axes_elems, bool keep_dims) {
2528 RankedTensorType input_type =
2529 input_value.getType().dyn_cast<RankedTensorType>();
2530 if (!input_type) return llvm::None;
2531
2532 bool input_is_qtype =
2533 input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2534 bool output_is_qtype =
2535 output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2536
2537 if (input_is_qtype || output_is_qtype) {
2538 op->emitOpError(
2539 "ConvertReduceProdOp: input/output tensor should "
2540 "be all floating-point.");
2541 return llvm::None;
2542 }
2543
2544 return convertReduceOpCommon<tosa::ReduceProdOp>(
2545 rewriter, op, output_type, input_value, axes_elems, keep_dims,
2546 output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
2547 }
2548
2549 // Lowers ReduceSum to a sequence of TOSA ops.
2550 llvm::Optional<Value> convertReduceSumOp(
2551 PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2552 Value input_value, ElementsAttr axes_elems, bool keep_dims) {
2553 RankedTensorType input_type =
2554 input_value.getType().dyn_cast<RankedTensorType>();
2555 if (!input_type) return llvm::None;
2556
2557 bool input_is_qtype =
2558 input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2559 bool output_is_qtype =
2560 output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2561
2562 if (input_is_qtype != output_is_qtype) {
2563 op->emitOpError(
2564 "ConvertReduceSumOp: input/output tensor should "
2565 "be all quantized or all floating-point.");
2566 return llvm::None;
2567 }
2568
2569 double input_scale = 1.0f;
2570 double output_scale = 1.0f;
2571 int64_t input_zp = 0;
2572 int64_t output_zp = 0;
2573 Type reduce_element_type = input_type.getElementType();
2574
2575 if (input_is_qtype) {
2576 auto input_qtype =
2577 input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2578 auto output_qtype =
2579 output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2580
2581 int32_t input_shift = 20;
2582
2583 input_scale =
2584 static_cast<double>(1 << input_shift) * input_qtype.getScale();
2585 output_scale =
2586 1.0 / (output_qtype.getScale() * static_cast<double>(1 << input_shift));
2587
2588 input_zp = input_qtype.getZeroPoint();
2589 output_zp = output_qtype.getZeroPoint();
2590 reduce_element_type = rewriter.getI32Type();
2591 }
2592
2593 return convertReduceOpCommon<tosa::ReduceSumOp>(
2594 rewriter, op, output_type, input_value, axes_elems, keep_dims,
2595 reduce_element_type, input_is_qtype, input_scale, input_zp, output_scale,
2596 output_zp);
2597 }
2598
2599 // Lowers ReduceMean to a sequence of TOSA ops.
2600 llvm::Optional<Value> convertReduceMeanOp(
2601 PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2602 Value input_value, ElementsAttr axes_elems, bool keep_dims) {
2603   // reduce_mean is lowered as follows:
2604 // op1 = reduce_sum(input)
2605 // op2 = mul(op1, 1.0 / num_elements_on_reduced_axis)
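  // For example, reducing an f32 tensor of shape [2, 3, 4] over axes {1, 2}
  // sums 12 elements per output entry, so op2 multiplies by 1.0 / 12.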
2606
2607 RankedTensorType input_type =
2608 input_value.getType().dyn_cast<RankedTensorType>();
2609 if (!input_type) return llvm::None;
2610
2611 bool input_is_qtype =
2612 input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2613 bool output_is_qtype =
2614 output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2615
2616 if (input_is_qtype != output_is_qtype) {
2617     op->emitOpError(
2618         "ConvertReduceMeanOp: input/output tensor should "
2619         "be all quantized or all floating-point.");
2620 return llvm::None;
2621 }
2622
2623   // Only float-typed mean() is supported in the non-quantized case
2624 if (!input_is_qtype && !output_type.getElementType().isa<mlir::FloatType>()) {
2625 op->emitWarning(
2626 "Failed convertReduceMean: input unquantized type but output element "
2627 "not FloatType!");
2628 return llvm::None;
2629 }
2630
2631 int64_t input_rank = input_type.getRank();
2632 int64_t num_elems_on_reduced_axis = 1;
2633 for (int i = 0; i < axes_elems.getNumElements(); i++) {
2634 int64_t axis_val = axes_elems.getValue<IntegerAttr>(i).getInt();
2635 if (axis_val < 0) axis_val += input_rank;
2636 num_elems_on_reduced_axis *= input_type.getShape()[axis_val];
2637 }
2638 double div_scale = 1.0 / static_cast<double>(num_elems_on_reduced_axis);
2639
2640 double input_scale = 1.0f;
2641 double output_scale = 1.0f;
2642 int64_t input_zp = 0;
2643 int64_t output_zp = 0;
2644 Type reduce_element_type = input_type.getElementType();
2645
2646 if (input_is_qtype) {
2647 auto input_qtype =
2648 input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2649 auto output_qtype =
2650 output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2651
2652 // Combine 'div_scale' as part of output rescale
2653 output_scale = div_scale * input_qtype.getScale() / output_qtype.getScale();
2654
2655 input_zp = input_qtype.getZeroPoint();
2656 output_zp = output_qtype.getZeroPoint();
2657 reduce_element_type = rewriter.getI32Type();
2658 }
2659
2660 auto val = convertReduceOpCommon<tosa::ReduceSumOp>(
2661 rewriter, op, output_type, input_value, axes_elems, keep_dims,
2662 reduce_element_type, input_is_qtype, input_scale, input_zp, output_scale,
2663 output_zp);
2664
2665 if (!val.hasValue()) return llvm::None;
2666
2667 if (!input_is_qtype) {
2668 Value div_const = getTosaConstTensorSingleF32(rewriter, op, div_scale);
2669 return rewriter
2670 .create<tosa::MulOp>(op->getLoc(), output_type, val.getValue(),
2671 div_const, 0)
2672 .getResult();
2673 }
2674
2675 return val;
2676 }
2677
2678 // Lowers ResizeBilinear and ResizeNearestNeighbor to TOSA resize.
2679 llvm::Optional<Value> convertResizeOp(PatternRewriter& rewriter, Operation* op,
2680 RankedTensorType output_type,
2681 Value input_value, StringRef mode,
2682 bool align_corners,
2683 bool half_pixel_centers) {
2684 RankedTensorType input_type =
2685 input_value.getType().dyn_cast<RankedTensorType>();
2686 if (!input_type) return llvm::None;
2687
2688 if (input_type.getRank() != 4 || output_type.getRank() != 4) {
2689 op->emitOpError("convertResizeOp: input/output must be rank 4");
2690 return llvm::None;
2691 }
2692
2693 bool input_is_qtype =
2694 input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2695 bool output_is_qtype =
2696 output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2697
2698 if (input_is_qtype != output_is_qtype) {
2699 op->emitOpError(
2700 "ConvertResizeOp: input/output tensor should "
2701 "be all quantized or all floating-point.");
2702 return llvm::None;
2703 }
2704
2705 if (!input_is_qtype) {
2706 if (!input_type.getElementType().isa<mlir::FloatType>()) {
2707 op->emitOpError(
2708 "ConvertResizeOp: only quantized or float types supported.");
2709 return llvm::None;
2710 }
2711 }
2712
2713 auto input_shape = input_type.getShape();
2714 auto output_shape = output_type.getShape();
2715
2716 size_t input_height = input_shape[1];
2717 size_t input_width = input_shape[2];
2718 size_t output_height = output_shape[1];
2719 size_t output_width = output_shape[2];
2720
2721 double fp_stride_y =
2722 static_cast<double>(input_height) / static_cast<double>(output_height);
2723 double fp_stride_x =
2724 static_cast<double>(input_width) / static_cast<double>(output_width);
2725 if (align_corners && output_height > 1) {
2726 fp_stride_y = static_cast<double>(input_height - 1) /
2727 static_cast<double>(output_height - 1);
2728 }
2729 if (align_corners && output_width > 1) {
2730 fp_stride_x = static_cast<double>(input_width - 1) /
2731 static_cast<double>(output_width - 1);
2732 }
2733
2734 double fp_offset_y, fp_offset_x;
2735 if (half_pixel_centers) {
2736 fp_offset_y = fp_stride_y * 0.5f - 0.5f;
2737 fp_offset_x = fp_stride_x * 0.5f - 0.5f;
2738 } else {
2739 fp_offset_y = 0.0f;
2740 fp_offset_x = 0.0f;
2741 }
2742
2743   // Output row oh samples input position iy = oh * fp_stride_y + fp_offset_y
2743   // (and similarly for x).
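  // For example, resizing height 4 -> 8 with half_pixel_centers:
  // fp_stride_y = 0.5 and fp_offset_y = -0.25, so output row 3 samples input
  // position 3 * 0.5 - 0.25 = 1.25.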
2744
2745 ArrayAttr output_size =
2746 rewriter.getI64ArrayAttr({static_cast<int64_t>(output_height),
2747 static_cast<int64_t>(output_width)});
2748 StringAttr resize_mode = rewriter.getStringAttr(mode);
2749
2750 if (input_is_qtype) {
2751     // Magic shift number that TFLite resize bilinear uses
2752 // reference: tensorflow/lite/kernels/internal/reference/reference_ops.h
2753 int32_t shift = 10;
2754
2755 // 1.0 is equivalent to (1 << shift) in quantized space.
2756 // Here we noted as unit = (1 << shift).
2757 double unit = static_cast<double>(1 << shift);
2758
2759 // Stride and Offset is int16.
2760 int32_t stride_y = std::lround(fp_stride_y * unit);
2761 int32_t stride_x = std::lround(fp_stride_x * unit);
2762 int32_t offset_y = std::lround(fp_offset_y * unit);
2763 int32_t offset_x = std::lround(fp_offset_x * unit);
2764
2765     // Numerically we could decrement the shift to make these numbers fit
2766     // within 16 bits, but that's uncommon and wouldn't match the TFLite reference.
2767 if (stride_y > std::numeric_limits<int16_t>::max() ||
2768 stride_x > std::numeric_limits<int16_t>::max() ||
2769 stride_y < std::numeric_limits<int16_t>::min() ||
2770 stride_x < std::numeric_limits<int16_t>::min() ||
2771 offset_y > std::numeric_limits<int16_t>::max() ||
2772 offset_x > std::numeric_limits<int16_t>::max() ||
2773 offset_y < std::numeric_limits<int16_t>::min() ||
2774 offset_x < std::numeric_limits<int16_t>::min()) {
2775 op->emitOpError("OpResize: stride or offset out of 16 bits");
2776 return llvm::None;
2777 }
2778
2779 ArrayAttr stride = rewriter.getI64ArrayAttr({stride_y, stride_x});
2780 ArrayAttr offset = rewriter.getI64ArrayAttr({offset_y, offset_x});
2781 IntegerAttr shift_attr = rewriter.getI32IntegerAttr(shift);
2782
2783 // If quantized bilinear mode, need to lower to RESIZE + RESCALE pair.
2784 if (mode == "BILINEAR") {
2785 RankedTensorType output_acc_type;
2786 auto input_element_qtype =
2787 input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2788
2789 bool scale32;
2790
2791 // TOSA RESIZE: 16 bit input -> 48 bit output, or 8 bit input -> 32 bit
2792 // output.
2793 if (input_element_qtype.getStorageTypeIntegralWidth() == 16) {
2794 scale32 = false;
2795 output_acc_type = RankedTensorType::get(output_type.getShape(),
2796 rewriter.getIntegerType(48));
2797 } else if (input_element_qtype.getStorageTypeIntegralWidth() == 8) {
2798 scale32 = true;
2799 output_acc_type = RankedTensorType::get(output_type.getShape(),
2800 rewriter.getI32Type());
2801 } else {
2802 op->emitOpError("OpResize: support 16-bit and 8-bit quantized input");
2803 return llvm::None;
2804 }
2805
2806 auto resize_op = rewriter.create<tosa::ResizeOp>(
2807 op->getLoc(), output_acc_type, input_value, output_size, stride,
2808 offset, shift_attr, rewriter.getF32ArrayAttr({0.0, 0.0}),
2809 rewriter.getF32ArrayAttr({0.0, 0.0}), resize_mode);
2810
2811 #ifdef RESIZE_BILINEAR_LOWER_SYMMETRIC_ROUNDING
2812       // TFLite resize_bilinear always assumes the input and output tensors
2813       // have the same scale, so we only need an arithmetic right shift by
2814       // (2 * shift).
2815       // TODO(suderman): Align TFLite rounding behavior
2816       // TFLite uses symmetric rounding by doing 'x / (1 << 20)', while the
2817       // TOSA arithmetic right shift does standard rounding.
2818       // Right now this is legalized using GreaterEqualOp + SelectOp to
2819       // conform to the TFLite reference, but that should eventually be
2820       // fixed in the TFLite reference itself.
2821 Value cst_zero = getTosaConstTensorSingleI32(rewriter, op, 0);
2822 Value cst_twenty = getTosaConstTensorSingleI32(rewriter, op, 20);
2823
2824 auto ge_op = rewriter.create<tosa::GreaterEqualOp>(
2825 op->getLoc(), output_bool_type, resize_op.getResult(), cst_zero);
2826
2827 auto abs_op = rewriter.create<tosa::AbsOp>(op->getLoc(), output_acc_type,
2828 resize_op.getResult());
2829
2830 auto rshift_op = rewriter.create<tosa::ArithmeticRightShiftOp>(
2831 op->getLoc(), output_acc_type, abs_op.getResult(), cst_twenty, true);
2832
2833 auto negate_op = rewriter.create<tosa::NegateOp>(
2834 op->getLoc(), output_acc_type, rshift_op.getResult());
2835
2836 auto select_op = rewriter.create<tosa::SelectOp>(
2837 op->getLoc(), output_acc_type, ge_op.getResult(),
2838 rshift_op.getResult(), negate_op.getResult());
2839
2840 auto cast_op = rewriter.create<tosa::CastOp>(op->getLoc(), output_type,
2841 select_op.getResult());
2842
2843 return cast_op.getResult();
2844 #else
2845       // This should be the expected lowering, but it is within +/-1 of the
2846       // TFLite reference.
2847 return buildRescale(rewriter, op, output_type, resize_op.getResult(),
2848 1.0 / (1 << 20), 0, 0, false, scale32);
2849 #endif
2850
2851 } else if (mode == "NEAREST_NEIGHBOR") {
2852 auto resize_op = rewriter.create<tosa::ResizeOp>(
2853 op->getLoc(), output_type, input_value, output_size, stride, offset,
2854 shift_attr, rewriter.getF32ArrayAttr({0.0, 0.0}),
2855 rewriter.getF32ArrayAttr({0.0, 0.0}), resize_mode);
2856 return resize_op.getResult();
2857 } else {
2858 op->emitOpError(
2859 "OpResize: only support BILINEAR or NEAREST_NEIGHBOR mode");
2860 return llvm::None;
2861 }
2862 } else {
2863 auto resize_op = rewriter.create<tosa::ResizeOp>(
2864 op->getLoc(), output_type, input_value, output_size,
2865 rewriter.getI64ArrayAttr({0, 0}), rewriter.getI64ArrayAttr({0, 0}),
2866 rewriter.getI32IntegerAttr(0),
2867 rewriter.getF32ArrayAttr(
2868 {static_cast<float>(fp_stride_y), static_cast<float>(fp_stride_x)}),
2869 rewriter.getF32ArrayAttr(
2870 {static_cast<float>(fp_offset_y), static_cast<float>(fp_offset_x)}),
2871 resize_mode);
2872 return resize_op.getResult();
2873 }
2874 }
2875
2876 // Lowers Quantize to a sequence of TOSA quantization ops.
2877 llvm::Optional<Value> convertQuantizeOp(PatternRewriter& rewriter,
2878 Operation* op,
2879 RankedTensorType output_type,
2880 Value input_value, double scale,
2881 int64_t zeropoint) {
2882 RankedTensorType input_type =
2883 input_value.getType().dyn_cast<RankedTensorType>();
2884 if (!input_type) return llvm::None;
2885
2886 auto output_shape = output_type.getShape();
2887 auto output_element_type = output_type.getElementType();
2888
2889 // The output element type must be a quantized integer type.
2890 if (!output_element_type.isa<mlir::quant::QuantizedType>()) {
2891 op->emitWarning(
2892 "Lowering quantizeOp but output element type not quantized!");
2893 return llvm::None;
2894 }
2895
2896 RankedTensorType output_fp_type =
2897 RankedTensorType::get(output_shape, rewriter.getF32Type());
2898
2899 Value zp_val =
2900 getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(zeropoint));
2901
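// The sequence below implements, elementwise:
//   output = cast<output_type>(input * scale + zeropoint)
// Illustrative example (hypothetical values, not from the original source):
// with scale = 4.0 and zeropoint = 128, an input element of 0.25 becomes
// 0.25 * 4.0 + 128 = 129.0 before the final cast to the quantized type.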
2902 auto op1_mul_in = rewriter.create<tosa::MulOp>(
2903 op->getLoc(), output_fp_type, input_value,
2904 getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(scale)), 0);
2905
2906 auto op2_add_op1 = rewriter.create<tosa::AddOp>(
2907 op->getLoc(), output_fp_type, op1_mul_in.getResult(), zp_val);
2908
2909 auto op3_cast_op2 = rewriter.create<tosa::CastOp>(op->getLoc(), output_type,
2910 op2_add_op1.getResult());
2911
2912 return op3_cast_op2.getResult();
2913 }
2914
2915 // Lowers Dequantize to a sequence of TOSA dequantization ops.
2916 llvm::Optional<Value> convertDequantizeOp(
2917 PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2918 Value input_value, ArrayRef<float> scale, ArrayRef<float> zeropoint,
2919 int64_t dim) {
2920 RankedTensorType input_type =
2921 input_value.getType().dyn_cast<RankedTensorType>();
2922 if (!input_type) return llvm::None;
2923
2924 // The input element type must be a quantized integer type.
2925 if (!input_type.getElementType().isa<mlir::quant::QuantizedType>())
2926 return llvm::None;
2927
2928 Optional<Value> zp_val;
2929 if (zeropoint.size() == 1) {
2930 zp_val = getTosaConstTensorSingleF32(rewriter, op,
2931 static_cast<float>(zeropoint[0]));
2932 } else {
2933 SmallVector<int64_t> shape;
2934 shape.resize(input_type.getRank(), 1);
2935 shape[dim] = zeropoint.size();
2936 zp_val = getConstTensor(rewriter, op, zeropoint, shape);
2937 }
2938
2939 Optional<Value> scale_val;
2940 if (scale.size() == 1) {
2941 scale_val =
2942 getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(scale[0]));
2943 } else {
2944 SmallVector<int64_t> shape;
2945 shape.resize(input_type.getRank(), 1);
2946 shape[dim] = scale.size();
2947 scale_val = getConstTensor(rewriter, op, scale, shape);
2948 }
2949
2950 if (!zp_val || !scale_val) return llvm::None;
2951
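// The sequence below implements, elementwise:
//   output = (cast<float>(input) - zeropoint) * scale
// When scale/zeropoint hold per-channel values, the constants are built with
// a shape of all 1s except at 'dim' (e.g. a hypothetical [1, 1, 1, 16] for a
// rank-4 input with 16 channels on dim = 3) so they broadcast along that axis.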
2952 auto op1_cast_in =
2953 rewriter.create<tosa::CastOp>(op->getLoc(), output_type, input_value);
2954
2955 auto op2_sub_op1 = rewriter.create<tosa::SubOp>(
2956 op->getLoc(), output_type, op1_cast_in.getResult(), zp_val.getValue());
2957
2958 return rewriter
2959 .create<tosa::MulOp>(op->getLoc(), output_type, op2_sub_op1.getResult(),
2960 scale_val.getValue(), 0)
2961 .getResult();
2962 }
2963
2964 // Lowers FakeQuant to a sequence of TOSA quantization ops.
2965 llvm::Optional<Value> convertFakeQuantOp(PatternRewriter& rewriter,
2966 Operation* op,
2967 RankedTensorType output_type,
2968 Value input_value, double min,
2969 double max, int64_t num_bits,
2970 bool narrow_range) {
2971 // FakeQuant is lowered as follows:
2972 // op1 = quantize(input)
2973 // op2 = dequantize(op1)
2974
2975 RankedTensorType input_type =
2976 input_value.getType().dyn_cast<RankedTensorType>();
2977 if (!input_type) return llvm::None;
2978
2979 // Quantized as INT<num_bits>, where num_bits can only be 8 or 16.
2980 if (num_bits != 8 && num_bits != 16) {
2981 op->emitWarning("FakeQuantOp lowering handles only 8 and 16 for num_bits!");
2982 return llvm::None;
2983 }
2984
2985 // This code originates from
2986 // tensorflow/core/kernels/fake_quant_ops_functor.h.
2987 int32_t qmax = (1 << (num_bits)) - 1;
2988 int32_t qmin = narrow_range ? 1 : 0;
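// For example, num_bits = 8 with narrow_range = false gives the range
// [qmin, qmax] = [0, 255]; with narrow_range = true it gives [1, 255].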
2989
2990 float nudged_min, nudged_max, nudged_scale;
2991 tensorflow_nudge(min, max, qmin, qmax, &nudged_min, &nudged_max,
2992 &nudged_scale);
2993
2994 Value cst_min = getTosaConstTensorSingleF32(rewriter, op, nudged_min);
2995 Value cst_max = getTosaConstTensorSingleF32(rewriter, op, nudged_max);
2996 Value cst_scale = getTosaConstTensorSingleF32(rewriter, op, nudged_scale);
2997 Value cst_inv_scale =
2998 getTosaConstTensorSingleF32(rewriter, op, 1.0f / nudged_scale);
2999 Value cst_half = getTosaConstTensorSingleF32(rewriter, op, 0.5f);
3000
3001 // This code originates from
3002 // tensorflow/core/kernels/fake_quant_ops_functor.h.
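// The op sequence below implements, elementwise:
//   clamped = clamp(input, nudged_min, nudged_max)
//   output  = floor((clamped - nudged_min) / nudged_scale + 0.5) * nudged_scale
//             + nudged_min
// i.e. round-to-nearest quantization followed by dequantization.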
3003 auto op1_min_in = rewriter.create<tosa::MinimumOp>(op->getLoc(), output_type,
3004 input_value, cst_max);
3005
3006 auto op2_max_op1 = rewriter.create<tosa::MaximumOp>(
3007 op->getLoc(), output_type, op1_min_in.getResult(), cst_min);
3008
3009 auto op3_sub_op2 = rewriter.create<tosa::SubOp>(
3010 op->getLoc(), output_type, op2_max_op1.getResult(), cst_min);
3011
3012 auto op4_mul_op3 = rewriter.create<tosa::MulOp>(
3013 op->getLoc(), output_type, op3_sub_op2.getResult(), cst_inv_scale, 0);
3014
3015 auto op5_add_op4 = rewriter.create<tosa::AddOp>(
3016 op->getLoc(), output_type, op4_mul_op3.getResult(), cst_half);
3017
3018 auto op6_floor_op5 = rewriter.create<tosa::FloorOp>(op->getLoc(), output_type,
3019 op5_add_op4.getResult());
3020
3021 auto op7_mul_op6 = rewriter.create<tosa::MulOp>(
3022 op->getLoc(), output_type, op6_floor_op5.getResult(), cst_scale, 0);
3023
3024 return rewriter
3025 .create<tosa::AddOp>(op->getLoc(), output_type, op7_mul_op6.getResult(),
3026 cst_min)
3027 .getResult();
3028 }
3029
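// Lowers the TensorFlow Conv2D operator to TOSA.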
3030 llvm::Optional<Value> convertTFConv2DCommon(
3031 PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
3032 Value input, Value filter, Value bias, ArrayAttr strides_attr,
3033 ArrayAttr dilations_attr, ArrayAttr explicit_padding_attr,
3034 StringRef padding_ref, StringRef data_format_ref) {
3035 RankedTensorType input_type = input.getType().dyn_cast<RankedTensorType>();
3036 RankedTensorType filter_type = filter.getType().dyn_cast<RankedTensorType>();
3037 // Bail out if the input or filter is not a ranked tensor.
3038 if (!input_type) return llvm::None;
3039 if (!filter_type) return llvm::None;
3040
3041 // Transpose [H, W, I, O] to [O, H, W, I]
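// e.g. a hypothetical 3x3x8x16 HWIO filter becomes a 16x3x3x8 OHWI filter
// after applying the permutation {3, 0, 1, 2} below.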
3042 auto filter_shape = filter_type.getShape();
3043 SmallVector<int64_t, 4> a1_transpose_dims;
3044 a1_transpose_dims.push_back(filter_shape[3]);
3045 a1_transpose_dims.push_back(filter_shape[0]);
3046 a1_transpose_dims.push_back(filter_shape[1]);
3047 a1_transpose_dims.push_back(filter_shape[2]);
3048 llvm::Optional<Value> a1_filter_transpose_perm = getConstTensor<int32_t>(
3049 rewriter, op, /*vec=*/{3, 0, 1, 2}, /*shape=*/{4});
3050
3051 if (!a1_filter_transpose_perm) return llvm::None;
3052
3053 auto a1_filter_transpose_op = rewriter.create<tosa::TransposeOp>(
3054 op->getLoc(),
3055 RankedTensorType::get(a1_transpose_dims, filter_type.getElementType()),
3056 filter, a1_filter_transpose_perm.getValue());
3057
3058 // Only NHWC is supported for now.
3059 if (data_format_ref.str() != "NHWC") {
3060 op->emitWarning("convertTFConv2DCommon only supports NHWC!");
3061 return llvm::None;
3062 }
3063
3064 ArrayAttr stride;
3065 ArrayAttr dilation;
3066 ArrayAttr pad;
3067 {
3068 if (!strides_attr) {
3069 stride = rewriter.getI64ArrayAttr({1, 1});
3070 } else {
3071 // Note: hardcoded to NHWC for now
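// e.g. a hypothetical TF strides attribute of {1, 2, 2, 1} (NHWC) yields a
// TOSA stride of {2, 2}.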
3072 int64_t stride_h = strides_attr[1].cast<IntegerAttr>().getInt();
3073 int64_t stride_w = strides_attr[2].cast<IntegerAttr>().getInt();
3074 stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
3075 }
3076 }
3077 {
3078 if (!dilations_attr) {
3079 dilation = rewriter.getI64ArrayAttr({1, 1});
3080 } else {
3081 // Note: hardcoded to NHWC for now
3082 int64_t dilation_h = dilations_attr[1].cast<IntegerAttr>().getInt();
3083 int64_t dilation_w = dilations_attr[2].cast<IntegerAttr>().getInt();
3084 dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w});
3085 }
3086 }
3087 {
3088 tensorflow::Padding tf_pad;
3089 if (!GetPaddingFromString(padding_ref.str(), &tf_pad).ok()) {
3090 op->emitWarning("Could not get padding data from padding string term!");
3091 return llvm::None;
3092 }
3093
3094 tensorflow::TensorFormat data_format_tf;
3095 if (!FormatFromString(data_format_ref.str(), &data_format_tf))
3096 return llvm::None;
3097
3098 if (tf_pad == tensorflow::Padding::EXPLICIT) {
3099 pad = getPaddingValuesFromExplicitPadAttr(explicit_padding_attr,
3100 data_format_tf, rewriter);
3101 } else {
3102 if (!getPaddingValuesFromPadType(tf_pad, data_format_tf,
3103 0, // tensorflow::FORMAT_HWIO
3104 input_type, filter_type, stride,
3105 dilation, rewriter, pad))
3106 return llvm::None;
3107 }
3108 }
3109
3110 return rewriter
3111 .create<tosa::Conv2DOp>(op->getLoc(), output_type, input,
3112 a1_filter_transpose_op.getResult(), bias, pad,
3113 stride, dilation)
3114 .getResult();
3115 }
3116
3117 // Lowers Gather operators to a sequence of TOSA ops.
3118 llvm::Optional<Value> convertGatherOp(PatternRewriter& rewriter, Operation* op,
3119 Value result_value, Value params_value,
3120 Value indices_value, int32_t batch_dims,
3121 int32_t axis) {
3122 auto result_type = result_value.getType().dyn_cast<RankedTensorType>();
3123 auto params_type = params_value.getType().dyn_cast<RankedTensorType>();
3124 auto indices_type = indices_value.getType().dyn_cast<RankedTensorType>();
3125
3126 if (!result_type || !params_type || !indices_type) return llvm::None;
3127
3128 // batch_dims indicates the number of batch dimensions in params and
3129 // indices. axis indicates the axis at which the gather indexing is
3130 // applied; axis must be >= batch_dims. When axis is equal to
3131 // batch_dims, the right-most batch dimension disappears.
3132 //
3133 // N: number of batches
3134 // Computed as product of params.shape[0:batch_dims-1]
3135 //
3136 // W: number of indices in each batch
3137 // Computed as product of indices.shape[batch_dims:]
3138 //
3139 // K: range of each index
3140 // Computed as params.shape[axis:axis+rank(indices)-1]
3141 //
3142 // C: number of channels for each index
3143 // Computed as: LeftChannels * RightChannels:
3144 // product(params.shape[batch_dims:axis]) * product(params.shape[axis+1:])
3145 //
3146 // The params tensor needs to be transposed, then reshaped to move the
3147 // dimensions into [N, K, C] order.
3148 //
3149 // The dimensions of the input params[] tensor are grouped in the following
3150 // order to begin with:
3151 //
3152 // [Batch, LeftChannels, Indices, RightChannels]
3153 // |-----||------------||-------||-------------|
3154 // N C_l K C_r
3155 //
3156 // Where Batch (N), Indices (K) can be one or more dimensions in size,
3157 // while LeftChannels and RightChannels represent the group of data channels
3158 // (C) to the left and right (C_l, C_r) of the indices; the sum of these two
3159 // is one or more dimensions in size, but either one may be zero depending
3160 // on how axis was specified by the caller.
3161 //
3162 // The resulting tensor will look like:
3163 //
3164 // [Batch, Indices, LeftChannels, RightChannels]
3165 // |-----||-------||---------------------------|
3166 // N K C
3167 //
3168 // The indices tensor simply needs a reshape to flatten all of the
3169 // batch dimensions (N) together and flatten all of the indices (W)
3170 // together.
3171 //
3172 // Then do the tosa.GATHER
3173 //
3174 // output[N,W,C] = tosa.GATHER(values[N,K,C], indices[N,W])
3175 //
3176 // Finally, the resulting tensor will have shape [N, W, C], where C is a
3177 // flattened version of [LeftChannels, RightChannels]. We need to reshape
3178 // to unflatten to:
3179 //
3180 // [N, W, LeftChannels, RightChannels]
3181 //
3182 // and finally transpose back to the output shape
3183 //
3184 // [Batch, LeftChannels, Non-Batch-Indices, RightChannels]
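//
// Illustrative example with hypothetical shapes (not from the source):
// params.shape = [2, 3, 5, 7], indices.shape = [2, 4], batch_dims = 1,
// axis = 2. Then N = 2, W = 4, K = 5, and C = 3 * 7 = 21, so the
// tosa.GATHER operates on values[2, 5, 21] and indices[2, 4], producing
// [2, 4, 21], which is reshaped and transposed back to the final output
// shape [2, 3, 4, 7].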
3185
3186 int N = 1, W = 1, K = 1, C = 1;
3187
3188 int params_rank = params_type.getShape().size();
3189 int indices_rank = indices_type.getShape().size();
3190
3191 if (!(batch_dims <= indices_rank)) {
3192 op->emitOpError("Batch_dims must be <= indices_rank for a valid gather op");
3193 return llvm::None;
3194 }
3195
3196 if (!(axis >= batch_dims)) {
3197 op->emitOpError("axis must be >= batch_dims for a valid gather op");
3198 return llvm::None;
3199 }
3200
3201 // Sizes for each of these fields.
3202 SmallVector<int64_t> params_batch, params_indices, params_left_channels,
3203 params_right_channels;
3204
3205 // Dimension indices for each of these fields.
3206 SmallVector<int64_t> params_idx_batch, params_idx_indices,
3207 params_idx_left_channels, params_idx_right_channels;
3208
3209 // Read through the params tensor dimensions left-to-right and extract the
3210 // different fields.
3211 for (int i = 0; i < params_rank; i++) {
3212 // When batch_dims == axis, the batch dimension gets replaced.
3213 if (i < batch_dims && i < axis) {
3214 params_batch.push_back(params_type.getShape()[i]);
3215 params_idx_batch.push_back(i);
3216 } else if (i < axis) {
3217 params_left_channels.push_back(params_type.getShape()[i]);
3218 params_idx_left_channels.push_back(i);
3219 } else if (i < (axis + 1)) {
3220 params_indices.push_back(params_type.getShape()[i]);
3221 params_idx_indices.push_back(i);
3222 } else {
3223 params_right_channels.push_back(params_type.getShape()[i]);
3224 params_idx_right_channels.push_back(i);
3225 }
3226 }
3227
3228 // Calculate N, K, W, C
3229 for (int i = 0; i < batch_dims; i++) N *= params_type.getShape()[i];
3230
3231 for (int i = batch_dims; i < indices_rank; i++)
3232 W *= indices_type.getShape()[i];
3233
3234 K = params_type.getShape()[axis];
3235
3236 for (int i = batch_dims; i < axis; i++) C *= params_type.getShape()[i];
3237 for (int i = (axis + 1); i < params_rank; i++) C *= params_type.getShape()[i];
3238
3239 // Check for obviously invalid values before doing a divide.
3240 if (N <= 0 || K <= 0 || W <= 0 || C <= 0) {
3241 op->emitOpError(
3242 "N, K, W, or C was calculated as <= zero. Invalid dimensions for "
3243 "Gather");
3244 return llvm::None;
3245 }
3246
3247 /////////////////////////////////////////////
3248 // Build up the params transpose operator
3249 SmallVector<int32_t> params_transpose_perm;
3250 SmallVector<int64_t> params_transpose_shape;
3251
3252 // Batch
3253 for (int i = 0; i < params_batch.size(); i++) {
3254 params_transpose_perm.push_back(params_idx_batch[i]);
3255 params_transpose_shape.push_back(params_batch[i]);
3256 }
3257
3258 // Indices
3259 for (int i = 0; i < params_indices.size(); i++) {
3260 params_transpose_perm.push_back(params_idx_indices[i]);
3261 params_transpose_shape.push_back(params_indices[i]);
3262 }
3263
3264 // LeftChannels
3265 for (int i = 0; i < params_left_channels.size(); i++) {
3266 params_transpose_perm.push_back(params_idx_left_channels[i]);
3267 params_transpose_shape.push_back(params_left_channels[i]);
3268 }
3269
3270 // RightChannels
3271 for (int i = 0; i < params_right_channels.size(); i++) {
3272 params_transpose_perm.push_back(params_idx_right_channels[i]);
3273 params_transpose_shape.push_back(params_right_channels[i]);
3274 }
3275
3276 /////////////////////////////////////////////
3277 // Build up the result reshape, in preparation for transpose
3278 // [N, W, C] -> [ Batch, Indices, LeftChannels, RightChannels ]
3279 SmallVector<int64_t> result_reshape_shape;
3280
3281 // Indices
3282 for (int i = 0; i < indices_type.getShape().size(); i++) {
3283 result_reshape_shape.push_back(indices_type.getShape()[i]);
3284 }
3285
3286 // Left channels
3287 for (int i = 0; i < params_left_channels.size(); i++) {
3288 result_reshape_shape.push_back(params_left_channels[i]);
3289 }
3290
3291 // Right channels (the axis dimension itself was consumed by the gather).
3292 for (int i = 0; i < params_right_channels.size(); i++) {
3293 result_reshape_shape.push_back(params_right_channels[i]);
3294 }
3295
3296 /////////////////////////////////////////////
3297 // Build up the result transpose operator.
3298 SmallVector<int32_t> result_transpose_perm;
3299
3300 // Batch dimensions
3301 for (int i = 0; i < batch_dims; i++) {
3302 result_transpose_perm.push_back(i);
3303 }
3304
3305 // LeftChannels
3306 for (int i = 0; i < params_left_channels.size(); i++) {
3307 result_transpose_perm.push_back(i + indices_type.getShape().size());
3308 }
3309
3310 // Indices (remainder of dimensions after batch).
3311 for (int i = batch_dims; i < (indices_type.getShape().size()); i++) {
3312 result_transpose_perm.push_back(i);
3313 }
3314
3315 // RightChannels, coming from after both the Indices and LeftChannels.
3316 for (int i = 0; i < params_right_channels.size(); i++) {
3317 result_transpose_perm.push_back(i + indices_type.getShape().size() +
3318 params_left_channels.size());
3319 }
3320
3321 SmallVector<int64_t> tosa_values_shape = {N, K, C};
3322 SmallVector<int64_t> tosa_indices_shape = {N, W};
3323 SmallVector<int64_t> tosa_gather_result_shape = {N, W, C};
3324
3325 llvm::Optional<Value> params_transpose_perm_val = getConstTensor<int32_t>(
3326 rewriter, op, params_transpose_perm,
3327 {static_cast<int64_t>(params_transpose_perm.size())});
3328
3329 llvm::Optional<Value> result_transpose_perm_val = getConstTensor<int32_t>(
3330 rewriter, op, result_transpose_perm,
3331 {static_cast<int64_t>(result_transpose_perm.size())});
3332
3333 if (!params_transpose_perm_val || !result_transpose_perm_val)
3334 return llvm::None;
3335
3336 auto params_transpose_op = rewriter.create<tosa::TransposeOp>(
3337 op->getLoc(),
3338 RankedTensorType::get(params_transpose_shape,
3339 params_type.getElementType()),
3340 params_value, params_transpose_perm_val.getValue());
3341
3342 auto tosa_values_reshape_op = rewriter.create<tosa::ReshapeOp>(
3343 op->getLoc(),
3344 RankedTensorType::get(tosa_values_shape, params_type.getElementType()),
3345 params_transpose_op.getResult(),
3346 rewriter.getI64ArrayAttr(tosa_values_shape));
3347
3348 auto tosa_indices_reshape_op = rewriter.create<tosa::ReshapeOp>(
3349 op->getLoc(),
3350 RankedTensorType::get(tosa_indices_shape, indices_type.getElementType()),
3351 indices_value, rewriter.getI64ArrayAttr(tosa_indices_shape));
3352
3353 auto tosa_gather_op = rewriter.create<tosa::GatherOp>(
3354 op->getLoc(),
3355 RankedTensorType::get(tosa_gather_result_shape,
3356 result_type.getElementType()),
3357 tosa_values_reshape_op.getResult(), tosa_indices_reshape_op.getResult());
3358
3359 auto tosa_result_reshape_op = rewriter.create<tosa::ReshapeOp>(
3360 op->getLoc(),
3361 RankedTensorType::get(result_reshape_shape, params_type.getElementType()),
3362 tosa_gather_op.getResult(),
3363 rewriter.getI64ArrayAttr(result_reshape_shape));
3364
3365 return rewriter
3366 .create<tosa::TransposeOp>(op->getLoc(), result_type,
3367 tosa_result_reshape_op.getResult(),
3368 result_transpose_perm_val.getValue())
3369 .getResult();
3370 }
3371
3372 // Lowers GatherNd operators to a sequence of TOSA ops.
3373 llvm::Optional<Value> convertGatherNdOp(PatternRewriter& rewriter,
3374 Operation* op, Value result_value,
3375 Value params_value, Value indices_value)
3376
3377 {
3378 auto result_type = result_value.getType().dyn_cast<RankedTensorType>();
3379 auto params_type = params_value.getType().dyn_cast<RankedTensorType>();
3380 auto indices_type = indices_value.getType().dyn_cast<RankedTensorType>();
3381
3382 if (!result_type || !params_type || !indices_type) return llvm::None;
3383
3384 // N: number of batches
3385 // Always 1 for GatherND
3386 //
3387 // Because TOSA's GATHER operator already uses the symbol 'N' for
3388 // the number of batches, we will use the symbol 'ND' to specify the
3389 // number of dimensions that are sliced from params, instead of 'N' as
3390 // in the TF MLIR documentation.
3391 //
3392 // ND: indices.shape[-1]
3393 //
3394 // W: number of indices in each batch
3395 // Computed as:
3396 // product(indices.shape[0:-1]) (all but the last dimension)
3397 //
3398 // K: range of each index
3399 // Computed as:
3400 // product(params.shape[0:ND-1])
3401 //
3402 // C: number of channels for each index
3403 // Computed as:
3404 // product(params.shape[ND:])
3405 //
3406 // The params tensor needs to be reshaped, but not transposed, to move the
3407 // dimensions into [N, K, C] order.
3408 //
3409 // The dimensions of the input params[] tensor are grouped in the following
3410 // order to begin with:
3411 //
3412 // [ParamIndices, ParamChannels]
3413 // |------------||-------------|
3414 // K C
3415 //
3416 // The reshape simply flattens the params tensor into a 2D [K, C] shape.
3417 //
3418 // Indices needs to be put in the form of [N, W], but a simple flattening
3419 // will not suffice, because the indices need to index into a [W]-shape
3420 // vector instead of the params.shape[0:ND-1] tensor that we had before.
3421 //
3422 // To flatten the coordinates, first reshape indices to a [W, ND] matrix,
3423 // where the matrix now represents W ND-dimensional coordinates into the
3424 // params tensor.
3425 //
3426 // From here, we take each of the ND dimensions and multiply it with
3427 // the size of the next params dimension (or 1 for the last
3428 // dimension), then sum all these together with a reduce_sum
3429 // operator. This is exactly the same mathematics one would use to
3430 // flatten the indices of an N-dimensional row-major array into a
3431 // 1-D array in C.
3432 //
3433 // More precisely, do an element-wise multiply with [params.shape[1
3434 // .. ND], 1] in axis 1, then reduce_sum in axis 1 to flatten to a
3435 // [W]-shaped tensor, then trivially reshape to [N=1, W] to be
3436 // compatible with the GATHER operator's shape.
3437 //
3438 // Then perform the tosa.GATHER() operation.
3439 //
3440 // Now we have result = [N, K, C].
3441 //
3442 // Reshape with a single, simple reshape to the final output shape of:
3443 // [Indices, ParamChannels]
3444 //
3445 // Where, Indices is indices.shape[0:ND-1]
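//
// Illustrative example with hypothetical shapes (not from the source):
// params.shape = [3, 4, 5], indices.shape = [6, 2]. Then ND = 2, N = 1,
// W = 6, K = 3 * 4 = 12 and C = 5, so the tosa.GATHER operates on
// values[1, 12, 5] and indices[1, 6], and the result is reshaped to the
// final output shape [6, 5].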
3446
3447 int N = 1, W = 1, K = 1, C = 1, ND = 1;
3448
3449 int params_rank = params_type.getShape().size();
3450 int indices_rank = indices_type.getShape().size();
3451
3452 ND = indices_type.getShape()[indices_rank - 1];
3453
3454 if (ND > params_rank) {
3455 op->emitOpError("Size of last dimension on indices must be <= params rank");
3456 return llvm::None;
3457 }
3458
3459 // Calculate N, K, W, C. (N is always 1)
3460 for (int i = 0; i < (indices_rank - 1); i++) {
3461 W *= indices_type.getShape()[i];
3462 }
3463
3464 for (int i = 0; i < ND; i++) {
3465 K *= params_type.getShape()[i];
3466 }
3467
3468 for (int i = ND; i < params_rank; i++) {
3469 C *= params_type.getShape()[i];
3470 }
3471
3472 SmallVector<int64_t, 3> tosa_values_shape({N, K, C});
3473 SmallVector<int64_t, 2> tosa_indices_shape({N, W});
3474 SmallVector<int64_t, 2> indices_matrix_shape({W, ND});
3475 SmallVector<int64_t, 3> tosa_gather_result_shape({N, W, C});
3476
3477 auto tosa_values_reshape_op = rewriter.create<tosa::ReshapeOp>(
3478 op->getLoc(),
3479 RankedTensorType::get(tosa_values_shape, params_type.getElementType()),
3480 params_value, rewriter.getI64ArrayAttr(tosa_values_shape));
3481
3482 // Flatten the input indices tensor to an [W, ND] matrix.
3483 auto indices_matrix_reshape_op = rewriter.create<tosa::ReshapeOp>(
3484 op->getLoc(),
3485 RankedTensorType::get(indices_matrix_shape,
3486 indices_type.getElementType()),
3487 indices_value, rewriter.getI64ArrayAttr(indices_matrix_shape));
3488
3489 SmallVector<int32_t> flattened_coeff_vec;
3490 for (int i = 1; i < ND; i++) {
3491 flattened_coeff_vec.push_back(params_type.getShape()[i]);
3492 }
3493 flattened_coeff_vec.push_back(1);
3494
3495 for (int i = ND - 1; i > 0; i--) {
3496 flattened_coeff_vec[i - 1] *= flattened_coeff_vec[i];
3497 }
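// For illustration (hypothetical shapes): with ND = 3 and
// params.shape[0:ND] = [3, 4, 5], the vector starts as [4, 5, 1] and the
// suffix products above turn it into [20, 5, 1], so a coordinate (a, b, c)
// flattens to a * 20 + b * 5 + c, a row-major index into the K = 60 rows.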
3498
3499 llvm::Optional<Value> flattened_coeff_value = getConstTensor<int32_t>(
3500 rewriter, op, flattened_coeff_vec,
3501 {static_cast<int64_t>(flattened_coeff_vec.size())});
3502
3503 if (!flattened_coeff_value) return llvm::None;
3504
3505 // Multiply the coefficients by the coordinates
3506 auto flattened_indices_mul_op = rewriter.create<tosa::MulOp>(
3507 op->getLoc(),
3508 RankedTensorType::get(indices_matrix_shape,
3509 indices_type.getElementType()),
3510 indices_matrix_reshape_op.getResult(), flattened_coeff_value.getValue(),
3511 0);
3512
3513 // Sum up the products of the coefficients and coordinates
3514 auto flattened_indices_reduce_op = rewriter.create<tosa::ReduceSumOp>(
3515 op->getLoc(),
3516 RankedTensorType::get(tosa_indices_shape, indices_type.getElementType()),
3517 flattened_indices_mul_op.getResult(), rewriter.getI64IntegerAttr(1));
3518
3519 // And reshape to [N, W]
3520 auto tosa_indices_reshape_op = rewriter.create<tosa::ReshapeOp>(
3521 op->getLoc(),
3522 RankedTensorType::get(tosa_indices_shape, indices_type.getElementType()),
3523 flattened_indices_reduce_op.getResult(),
3524 rewriter.getI64ArrayAttr(tosa_indices_shape));
3525
3526 // Now the gather op itself
3527 auto tosa_gather_op = rewriter.create<tosa::GatherOp>(
3528 op->getLoc(),
3529 RankedTensorType::get(tosa_gather_result_shape,
3530 result_type.getElementType()),
3531 tosa_values_reshape_op.getResult(), tosa_indices_reshape_op.getResult());
3532
3533 // Finally, reshape back to the original output shape of [Indices,
3534 // ParamChannels].
3535 return rewriter
3536 .create<tosa::ReshapeOp>(op->getLoc(), result_type,
3537 tosa_gather_op.getResult(),
3538 rewriter.getI64ArrayAttr(result_type.getShape()))
3539 .getResult();
3540 }
3541
3542 // Lowers OneHot operator to a sequence of TOSA ops.
3543 llvm::Optional<Value> convertOneHotOp(PatternRewriter& rewriter, Operation* op,
3544 Value result_value, Value indices_value,
3545 Value on_value, Value off_value,
3546 int32_t depth, int32_t axis) {
3547 auto result_type = result_value.getType().dyn_cast<RankedTensorType>();
3548 auto indices_type = indices_value.getType().dyn_cast<RankedTensorType>();
3549 auto on_value_type = on_value.getType().dyn_cast<RankedTensorType>();
3550 auto off_value_type = off_value.getType().dyn_cast<RankedTensorType>();
3551
3552 if (!result_type || !indices_type || !on_value_type || !off_value_type)
3553 return llvm::None;
3554
3555 // The OneHot operator creates a new tensor with shape indices.shape[:axis] +
3556 // [depth] + indices.shape[axis:]. Each index in 'indices' must be within
3557 // the range [0, depth - 1], and the output gets [..., k, ...] = on_value
3558 // if k == index, or [..., k, ...] = off_value if k != index.
3559 //
3560 // The lowering below assumes depth is always known at compile time;
3561 // handling a depth resolved at run time is TBD.
3562 //
3563 // OneHot can be lowered to a TOSA Scatter, with off_value mapped to
3564 // 'values_in', on_value mapped to 'input', and indices naturally
3565 // mapped to 'indices'. The dimensions of the TOSA scatter (N, W, K, C)
3566 // also need to be picked.
3567 //
3568 // N: number of elements of input indices
3569 // Computed as:
3570 // product(indices.shape[:])
3571 //
3572 // K: newly added dimension
3573 // K = depth
3574 //
3575 // W, C: dummy dimensions for now
3576 // W = C = 1
3577 //
3578 // High level description of lowering looks like:
3579 // 1. off_value is reshaped/tiled into [N, K, C]
3580 // 2. on_value is reshaped/tiled into [N, W, C]
3581 // 3. indices is reshaped into [N, W]
3582 // 4. scatter into [N, K, C]
3583 // 5. reshaped into [LeftDims, RightDims, K]
3584 // 6. transpose into [LeftDims, K, RightDims]
3585 // 7. reshaped to result.shape
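//
// Illustrative example with hypothetical shapes (not from the source):
// indices.shape = [2, 3], depth = 4, axis = -1 (i.e. axis = 2). Then
// N = 6, K = 4, W = C = 1, left_dim = 6, right_dim = 1; the scatter
// produces [6, 4, 1], which is reshaped to [6, 1, 4], transposed to
// [6, 4, 1] and finally reshaped to the result shape [2, 3, 4].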
3586
3587 if (on_value_type.getRank() != 0 || off_value_type.getRank() != 0) {
3588 op->emitOpError("OneHotOp: on_value/off_value needs to be scalar");
3589 return llvm::None;
3590 }
3591
3592 if (axis < -1 || axis > indices_type.getRank()) {
3593 op->emitOpError("OneHotOp: axis out of valie range [-1, indices.rank]");
3594 return llvm::None;
3595 }
3596
3597 // axis = -1 is equivalent to axis = indices.rank
3598 if (axis == -1) {
3599 axis = indices_type.getRank();
3600 }
3601
3602 int N = 1, W = 1, C = 1;
3603 int K = depth;
3604 int left_dim = 1, right_dim = 1;
3605
3606 for (int32_t i = 0; i < indices_type.getRank(); i++) {
3607 int32_t dim = indices_type.getShape()[i];
3608 N *= dim;
3609 if (i >= axis) {
3610 right_dim *= dim;
3611 } else {
3612 left_dim *= dim;
3613 }
3614 }
3615
3616 // Reshape on_value to [1, 1, 1]
3617 auto op1_reshape_on_value = rewriter.create<tosa::ReshapeOp>(
3618 op->getLoc(),
3619 RankedTensorType::get({1, 1, 1}, on_value_type.getElementType()),
3620 on_value, rewriter.getI64ArrayAttr({1, 1, 1}));
3621
3622 // And tile to [N, W, C]
3623 auto op2_tile_op1 = rewriter.create<tosa::TileOp>(
3624 op->getLoc(),
3625 RankedTensorType::get({N, W, C}, on_value_type.getElementType()),
3626 op1_reshape_on_value.getResult(), rewriter.getI64ArrayAttr({N, W, C}));
3627
3628 // Reshape off_value to [1, 1, 1]
3629 auto op3_reshape_off_value = rewriter.create<tosa::ReshapeOp>(
3630 op->getLoc(),
3631 RankedTensorType::get({1, 1, 1}, off_value_type.getElementType()),
3632 off_value, rewriter.getI64ArrayAttr({1, 1, 1}));
3633
3634 // And tile to [N, K, C]
3635 auto op4_tile_op3 = rewriter.create<tosa::TileOp>(
3636 op->getLoc(),
3637 RankedTensorType::get({N, K, C}, on_value_type.getElementType()),
3638 op3_reshape_off_value.getResult(), rewriter.getI64ArrayAttr({N, K, C}));
3639
3640 // Reshape indices to [N, W]
3641 auto op5_reshape_indices = rewriter.create<tosa::ReshapeOp>(
3642 op->getLoc(),
3643 RankedTensorType::get({N, W}, indices_type.getElementType()),
3644 indices_value, rewriter.getI64ArrayAttr({N, W}));
3645
3646 // Scatter to [N, K, C]
3647 auto op6_scatter_op4_op5_op2 = rewriter.create<tosa::ScatterOp>(
3648 op->getLoc(),
3649 RankedTensorType::get({N, K, C}, result_type.getElementType()),
3650 op4_tile_op3.getResult(), op5_reshape_indices.getResult(),
3651 op2_tile_op1.getResult());
3652
3653 // Reshape to [LeftDims, RightDims, K]; C is squeezed out since it's 1.
3654 auto op7_reshape_op6 = rewriter.create<tosa::ReshapeOp>(
3655 op->getLoc(),
3656 RankedTensorType::get({left_dim, right_dim, K},
3657 result_type.getElementType()),
3658 op6_scatter_op4_op5_op2.getResult(),
3659 rewriter.getI64ArrayAttr({left_dim, right_dim, K}));
3660
3661 // Transposed to [LeftDims, K, RightDims].
3662 llvm::Optional<Value> perm_const =
3663 getConstTensor<int32_t>(rewriter, op, /*vec=*/{0, 2, 1}, /*shape=*/{3});
3664
3665 if (!perm_const) return llvm::None;
3666
3667 auto op8_transpose_op7 = rewriter.create<tosa::TransposeOp>(
3668 op->getLoc(),
3669 RankedTensorType::get({left_dim, K, right_dim},
3670 result_type.getElementType()),
3671 op7_reshape_op6.getResult(), perm_const.getValue());
3672
3673 // Reshaped to result.shape.
3674 return rewriter
3675 .create<tosa::ReshapeOp>(op->getLoc(), result_type,
3676 op8_transpose_op7.getResult(),
3677 rewriter.getI64ArrayAttr(result_type.getShape()))
3678 .getResult();
3679 }
3680
3681 }; // namespace tosa
3682 }; // namespace mlir
3683