1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file contains legalizations common to mapping both TensorFlow and
17 // TensorFlow Lite to TOSA. It operates generically on ops and does not have
18 // a hard reference on either dialect.
19 //
20 // Conversion functions return llvm::None on a legalization failure or a
21 // legalized value on success. Callers must check for presence of an
22 // llvm::Optional value after each call.
23
24 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
25
26 #include <climits>
27 #include <cstddef>
28 #include <cstdint>
29 #include <iterator>
30 #include <numeric>
31
32 #include "llvm/ADT/ArrayRef.h"
33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/Support/FormatVariadic.h"
35 #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
36 #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
37 #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
38 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
39 #include "mlir/IR/Matchers.h" // from @llvm-project
40 #include "mlir/IR/PatternMatch.h" // from @llvm-project
41 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
42
43 namespace mlir {
44 namespace tosa {
45
46 static int64_t multiply_dims(int64_t a, int64_t b) {
47 if (a == ShapedType::kDynamicSize || b == ShapedType::kDynamicSize) {
48 return ShapedType::kDynamicSize;
49 }
50 return a * b;
51 }
52
53 static int64_t multiply_dims(llvm::ArrayRef<int64_t> dims, int64_t res = 1) {
54 for (auto dim : dims) {
55 if (ShapedType::isDynamic(dim)) {
56 return ShapedType::kDynamicSize;
57 }
58 res = res * dim;
59 }
60 return res;
61 }
62
63 static int64_t count_dynamic_dims(llvm::ArrayRef<int64_t> dims) {
64 int64_t count = 0;
65 for (auto dim : dims)
66 if (ShapedType::isDynamic(dim)) ++count;
67 return count;
68 }
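// For illustration only (values are hypothetical, not from the original
// comments): any dynamic dimension "poisons" the product, e.g.
//   multiply_dims(2, 3)                                -> 6
//   multiply_dims({2, ShapedType::kDynamicSize, 3})    -> kDynamicSize
//   count_dynamic_dims({2, ShapedType::kDynamicSize})  -> 1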
69
70 namespace {
71 // Given an axis that can be a positive or negative value and the tensor size,
72 // return the adjusted axis value wrapped around the tensor size.
73 int32_t adjust_axis(int32_t axis, int32_t tensor_size) {
74 return axis >= 0 ? axis % tensor_size : (axis + tensor_size) % tensor_size;
75 }
76 }  // namespace
77
78 // Copied Nudge implementation from
79 // tensorflow/core/kernels/fake_quant_ops_functor.h.
80 // Suggested approach to avoid significant TensorFlow
81 // build dependency.
82 void tensorflow_nudge(const float min, const float max, const int quant_min,
83 const int quant_max, float* nudged_min, float* nudged_max,
84 float* scale) {
85 const float quant_min_float = static_cast<float>(quant_min);
86 const float quant_max_float = static_cast<float>(quant_max);
87 *scale = (max - min) / (quant_max_float - quant_min_float);
88 const float zero_point_from_min = quant_min_float - min / *scale;
89 const uint16_t nudged_zero_point = [zero_point_from_min, quant_min,
90 quant_min_float, quant_max,
91 quant_max_float] {
92 if (zero_point_from_min < quant_min_float) {
93 return static_cast<uint16_t>(quant_min);
94 }
95 if (zero_point_from_min > quant_max_float) {
96 return static_cast<uint16_t>(quant_max);
97 }
98 return static_cast<uint16_t>(std::round(zero_point_from_min));
99 }();
100 *nudged_min = (quant_min_float - nudged_zero_point) * (*scale);
101 *nudged_max = (quant_max_float - nudged_zero_point) * (*scale);
102 }
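// A hedged numeric sketch of the nudge above (illustrative values, not taken
// from the original source): with min = -1.0f, max = 1.0f, quant_min = 0,
// quant_max = 255:
//   scale               = (1.0 - (-1.0)) / 255  ~= 0.007843
//   zero_point_from_min = 0 - (-1.0 / scale)     = 127.5
//   nudged_zero_point   = round(127.5)           = 128
//   nudged_min          = (0 - 128) * scale     ~= -1.0039
//   nudged_max          = (255 - 128) * scale   ~= 0.9961
// so the nudged range snaps onto representable quantized values.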
103
104 // Lowers the Pack operator to TOSA.
105 llvm::Optional<Value> convertPackOp(PatternRewriter& rewriter, Operation* op,
106 Value result_value,
107 SmallVectorImpl<Value>& inputs,
108 int32_t axis) {
109 //////////////////////////////////////////////////
110 // Operator: output = Pack([values], axis) or output = Stack([values], axis)
111 // Lowering:
112 //
113 // This operator is lowered into a series of pairwise tosa.concat()
114 // operators and a reshape.
115 // Depending on the inputs, a transpose operator is also generated:
116 //
117 // Step 1: concatenate the tensors
118 // a1_concat = tosa.concat(input[0], input[1], axis)
119 // for (i = 2; i < len(input); i++)
120 // a1_concat = tosa.concat(a1_concat, input[i], axis)
121 //
122 // Step 2: reshape to N+1 dimensions
123 // a2_reshape = tosa.reshape(a1_concat, new_rank)
124 //
125 // Step 3: Transpose if a new dimension is being added:
126 //   if (axis == rank(values[0])):
127 // // perm will be [1, 2, 3, 0]
128 // a3_transpose = tosa.transpose(a2_reshape, perm)
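//
// For illustration only (hypothetical shapes, not from the original
// comments): packing three [4, 5] tensors,
//   axis = 1: concat along axis 1 -> [4, 15], reshape -> [4, 3, 5]
//   axis = 2 (== rank(input)): concat along axis 0 -> [12, 5],
//     reshape -> [3, 4, 5], transpose with perm [1, 2, 0] -> [4, 5, 3]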
129
130 // Sanity check 1: make sure all input tensors have the same shape
131 // if input[0] has shape [A, B, C], input[1] to input[N-1] should also have
132 // shape[A, B, C]
133 RankedTensorType result_type =
134 result_value.getType().dyn_cast<RankedTensorType>();
135
136 // Check for ranked tensor type.
137 if (!result_type) {
138 op->emitOpError("PackOp: result type not ranked tensor");
139 return llvm::None;
140 }
141
142 // Valid axis in TF is [-rank(input), rank(input))
143 // Valid axis in TOSA is [0, rank(input))
144 // Plus rank(input) once if axis is negative.
145 RankedTensorType input_type =
146 op->getOperand(0).getType().dyn_cast<RankedTensorType>();
147 if (!input_type) {
148 op->emitOpError("PackOp: input type not ranked tensor");
149 return llvm::None;
150 }
151
152 input_type = inputs[0].getType().dyn_cast<RankedTensorType>();
153 if (!input_type) {
154 op->emitOpError("Input 0 type not ranked tensor.");
155 return llvm::None;
156 }
157 ArrayRef<int64_t> input0_tensor_shape = input_type.getShape();
158 int input_tensor_rank = input0_tensor_shape.size();
159
160 for (int i = 1; i < inputs.size(); i++) {
161     input_type = inputs[i].getType().dyn_cast<RankedTensorType>();
162     if (!input_type) {
163       op->emitOpError(llvm::formatv(
164           "PackOp: input {0} type not ranked tensor",
165           i));
166 return llvm::None;
167 }
168 ArrayRef<int64_t> next_tensor_shape = input_type.getShape();
169 if (next_tensor_shape.size() != input_tensor_rank) {
170 op->emitOpError("PackOp: input tensor rank mismatch.");
171 return llvm::None;
172 }
173 for (int d = 0; d < input0_tensor_shape.size(); d++) {
174 if (input0_tensor_shape[d] != next_tensor_shape[d]) {
175 op->emitOpError("PackOp: input tensor shape mismatch.");
176 return llvm::None;
177 }
178 }
179 }
180
181 // If input tensors are rank 0, should reshape them to rank 1 size 1 before
182 // performing concat.
183 if (input_tensor_rank == 0) {
184 SmallVector<int64_t, 1> reshape_rank1_size1_shape({1});
185 RankedTensorType reshape_rank1_size1_type = RankedTensorType::get(
186 reshape_rank1_size1_shape, result_type.getElementType());
187 ArrayAttr shape_rank1_size1_attr =
188 rewriter.getI64ArrayAttr(reshape_rank1_size1_shape);
189 for (int i = 0; i < inputs.size(); i++) {
190 auto a0_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
191 rewriter, op->getLoc(), reshape_rank1_size1_type, inputs[i],
192 shape_rank1_size1_attr);
193 inputs[i] = a0_reshape_op.getResult();
194 }
195 }
196
197 // Sanity check 2: axis can be from [0, rank(input)+1]
198 // Where rank(input)+1 means create a new dimension
199 // Negative values are also allowed up to -(rank(input)+1)
200 // where the axis "wraps around".
201 if (axis < 0) axis += input_tensor_rank;
202 if ((axis < 0) || (axis > (input_tensor_rank + 1))) {
203 op->emitOpError("PackOp: axis out of valid range.");
204 return llvm::None;
205 }
206
207   // Sanity check 3: if input shape is [A, B, C], output shape should be [N,
208   // A, B, C]
209 // 2.a check output is rank(input) + 1
210 SmallVector<int64_t> output_shape_vals(result_type.getShape().begin(),
211 result_type.getShape().end());
212 if (output_shape_vals.size() != (input_tensor_rank + 1)) {
213 op->emitOpError("PackOp: output tensor rank mismatch.");
214 return llvm::None;
215 }
216 // 2.b check output rank 0 is N
217 if (output_shape_vals[axis] != inputs.size()) {
218 op->emitOpError("PackOp: output tensor shape mismatch.");
219 return llvm::None;
220 }
221   // In most cases, PackOp.axis() is within [0, rank(input) - 1].
222   // We can directly concatenate along that axis and perform the reshape.
223   // For example, stacking N [A, B, C] input tensors along axis = 1
224   // gives [A, N * B, C] after concatenation,
225   // which is then reshaped into [A, N, B, C].
226   // A special case is PackOp.axis() equal to rank(input), in which case
227   // we can't directly concatenate along PackOp.axis(); instead
228   // we concatenate along axis=0, reshape into [N, A, B, C],
229   // and then add an extra transpose to [A, B, C, N].
230 int64_t concat_axis;
231 SmallVector<int32_t> perm;
232 SmallVector<int64_t> reshape_output_shape;
233 if (axis == 0 && input_tensor_rank == 0) {
234 concat_axis = 0;
235 } else if (axis == input_tensor_rank) {
236 concat_axis = 0;
237
238 // A special case when stack axis is equal to input tensor rank:
239 // Output shape is [A, B, C, N]
240 // so reshape output will be [N, A, B, C]
241 // and perm will be [1, 2, 3, 0].
242 reshape_output_shape.push_back(output_shape_vals[axis]);
243 for (int d = 0; d < input_tensor_rank; d++) {
244 perm.push_back(d + 1);
245 reshape_output_shape.push_back(output_shape_vals[d]);
246 }
247 perm.push_back(0);
248 } else {
249 // General case, doesn't need perm vector.
250 concat_axis = axis;
251 reshape_output_shape.assign(output_shape_vals.begin(),
252 output_shape_vals.end());
253 }
254 IntegerAttr concat_axis_attr = rewriter.getI64IntegerAttr(concat_axis);
255 ArrayAttr shape_attr = rewriter.getI64ArrayAttr(reshape_output_shape);
256
257 // Concat output shape will depend on concat_axis. E.g. [N * A, B, C]
258 SmallVector<int64_t> concat_output_shape;
259 if (input_tensor_rank == 0) {
260 concat_output_shape.push_back(1);
261 } else {
262 for (int i = 0; i < input_tensor_rank; i++) {
263 concat_output_shape.push_back(input0_tensor_shape[i]);
264 }
265 }
266
267 concat_output_shape[concat_axis] =
268 concat_output_shape[concat_axis] * inputs.size();
269 RankedTensorType concat_type = RankedTensorType::get(
270 ArrayRef<int64_t>(concat_output_shape), result_type.getElementType());
271
272 SmallVector<Value> inputs_0;
273 for (int i = 0; i < inputs.size(); i++) {
274 inputs_0.push_back(inputs[i]);
275 }
276 auto a1_concat_op = CreateOpAndInfer<tosa::ConcatOp>(
277 rewriter, op->getLoc(), concat_type, inputs_0, concat_axis_attr);
278
279 // Doesn't need reshape or transpose if input tensor is rank 0, since inputs
280 // are reshaped beforehand.
281 if (input_tensor_rank == 0) return a1_concat_op.getResult();
282
283 // Reshape [N * A, B, C] to [N, A, B, C].
284 RankedTensorType reshape_output_type =
285 RankedTensorType::get(reshape_output_shape, result_type.getElementType());
286
287 auto a2_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
288 rewriter, op->getLoc(), reshape_output_type, a1_concat_op.getResult(),
289 shape_attr);
290
291 // If axis is equal to input tensor rank, then we need extra transpose
292 // [N, A, B, C] to [A, B, C, N]
293 if (axis == input_tensor_rank) {
294 llvm::Optional<Value> a3_transpose_perm = getConstTensor<int32_t>(
295 rewriter, op, perm, {static_cast<int64_t>(perm.size())});
296
297 if (!a3_transpose_perm) return llvm::None;
298
299 return CreateOpAndInfer<tosa::TransposeOp>(
300 rewriter, op->getLoc(), result_type, a2_reshape_op.getResult(),
301 a3_transpose_perm.getValue())
302 .getResult();
303 }
304
305 return a2_reshape_op.getResult();
306 }
307
308 // Lowers the Unpack operator to TOSA
309 llvm::Optional<SmallVector<Value>> convertUnpackOp(PatternRewriter& rewriter,
310 Operation* op,
311 Value input_value,
312 int32_t axis) {
313 RankedTensorType input_type =
314 input_value.getType().dyn_cast<RankedTensorType>();
315 if (!input_type) return llvm::None;
316
317 auto input_shape = input_type.getShape();
318 int64_t input_rank = input_shape.size();
319
320 SmallVector<Value> results_vec;
321
322 // Negative axis allowed as long as it's within [-input_rank, input_rank).
323 if (axis < 0) axis += input_rank;
324 if ((axis < 0) || (axis > input_rank)) {
325 op->emitOpError("UnpackOp: axis out of valid range.");
326 return llvm::None;
327 }
328
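  // For illustration only (hypothetical shapes): unpacking a [2, 3, 4] tensor
  // along axis = 1 first transposes with perm {1, 0, 2} to [3, 2, 4], then
  // takes 3 slices of size [1, 2, 4] and reshapes each to [2, 4].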
329 // Step 1: transpose 'axis' to leftmost dimension.
330 Value transposed_input_value;
331 if (axis != 0) {
332 SmallVector<int32_t> perm;
333 SmallVector<int64_t> a1_transpose_shape(input_rank);
334
335 perm.push_back(axis);
336 for (int i = 0; i < input_rank; i++) {
337 if (i == axis) continue;
338 perm.push_back(i);
339 }
340
341 llvm::Optional<Value> a1_transpose_perm = getConstTensor<int32_t>(
342 rewriter, op, perm, {static_cast<int64_t>(perm.size())});
343
344 if (!a1_transpose_perm) return llvm::None;
345
346 for (int i = 0; i < input_rank; i++) {
347 a1_transpose_shape[i] = input_shape[perm[i]];
348 }
349
350 auto a1_transpose_op = CreateOpAndInfer<tosa::TransposeOp>(
351 rewriter, op->getLoc(),
352 RankedTensorType::get(a1_transpose_shape, input_type.getElementType()),
353 input_value, a1_transpose_perm.getValue());
354
355 transposed_input_value = a1_transpose_op.getResult();
356 } else {
357 // Do nothing if axis is already at leftmost dimension.
358 transposed_input_value = input_value;
359 }
360
361 // Step 2: slice [N, A, B, C] into N [A, B, C].
362 RankedTensorType transposed_input_type =
363 transposed_input_value.getType().dyn_cast<RankedTensorType>();
364 if (!transposed_input_type) return llvm::None;
365
366 auto transposed_input_shape = transposed_input_type.getShape();
367 int64_t transposed_input_rank = transposed_input_shape.size();
368
369 for (int i = 0; i < transposed_input_shape[0]; i++) {
370 SmallVector<int64_t> begin_vals, size_vals, shape_vals;
371
372 for (int j = 0; j < transposed_input_rank; j++) {
373 if (j == 0) {
374 begin_vals.push_back(i);
375 size_vals.push_back(1);
376 } else {
377 begin_vals.push_back(0);
378 size_vals.push_back(transposed_input_shape[j]);
379 shape_vals.push_back(transposed_input_shape[j]);
380 }
381 }
382
383 ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
384 ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);
385
386 auto a2_slice_op = CreateOpAndInfer<tosa::SliceOp>(
387 rewriter, op->getLoc(),
388 RankedTensorType::get(size_vals,
389 transposed_input_type.getElementType()),
390 transposed_input_value, begin, size);
391
392 auto a3_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
393 rewriter, op->getLoc(),
394 RankedTensorType::get(shape_vals,
395 transposed_input_type.getElementType()),
396 a2_slice_op.getResult(), rewriter.getI64ArrayAttr(shape_vals));
397
398 results_vec.push_back(a3_reshape_op.getResult());
399 }
400
401 return results_vec;
402 }
403
404 // Lowers the Select operator to TOSA.
405 llvm::Optional<Value> convertSelectOp(PatternRewriter& rewriter, Operation* op,
406 Value result_value, Value condition_value,
407 Value x_value, Value y_value) {
408 RankedTensorType result_type =
409 result_value.getType().dyn_cast<RankedTensorType>();
410 RankedTensorType condition_type =
411 condition_value.getType().dyn_cast<RankedTensorType>();
412 RankedTensorType x_type = x_value.getType().dyn_cast<RankedTensorType>();
413 RankedTensorType y_type = y_value.getType().dyn_cast<RankedTensorType>();
414
415 if (!result_type || !condition_type || !x_type || !y_type) {
416 op->emitOpError("Select: failed ranked tensor type check");
417 return llvm::None;
418 }
419
420 // First check whether we need to reshape the condition to match
421 // the same rank as the then/else clauses.
422 if (result_type.getRank() == condition_type.getRank()) {
423 // Nothing to reshape.
424 return CreateOpAndInfer<tosa::SelectOp>(rewriter, op->getLoc(), result_type,
425 condition_value, x_value, y_value)
426 .getResult();
427 }
428
429 // Need to reshape the condition.
430 SmallVector<int64_t> new_cond_dims(
431 result_type.getRank() - condition_type.getRank(), 1);
432
433 for (int i = 0; i < condition_type.getRank(); i++) {
434 new_cond_dims.push_back(condition_type.getShape()[i]);
435 }
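  // For illustration only (hypothetical shapes): with a [2, 3, 4] result and a
  // rank-1 [4] condition, new_cond_dims becomes [1, 1, 4], so the reshaped
  // condition broadcasts against the then/else operands.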
436
437 auto reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
438 rewriter, op->getLoc(),
439 RankedTensorType::get(new_cond_dims, condition_type.getElementType()),
440 condition_value, rewriter.getI64ArrayAttr(new_cond_dims));
441
442 return CreateOpAndInfer<tosa::SelectOp>(rewriter, op->getLoc(), result_type,
443 reshape_op, x_value, y_value)
444 .getResult();
445 }
446
447 // Lowers the ZerosLike operator to TOSA by creating a constant
448 // of the desired type and shape.
449 llvm::Optional<Value> convertZerosLikeOp(PatternRewriter& rewriter,
450 Operation* op, Value result,
451 Value input) {
452 RankedTensorType result_type = result.getType().dyn_cast<RankedTensorType>();
453 if (!result_type) {
454 op->emitOpError("Zeroslike: result not ranked tensor type");
455 return llvm::None;
456 }
457
458 RankedTensorType input_type = input.getType().dyn_cast<RankedTensorType>();
459 if (!input_type) {
460 op->emitOpError("Zeroslike: input not ranked tensor type");
461 return llvm::None;
462 }
463
464 auto input_shape = input_type.getShape();
465
466 ShapedType zero_type =
467 RankedTensorType::get(input_shape, input_type.getElementType());
468 Attribute zero_attr = rewriter.getZeroAttr(zero_type);
469
470 return CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(), zero_type,
471 zero_attr.cast<ElementsAttr>())
472 .getResult();
473 }
474
475 // Lowers the Mul operator to TOSA. For quantized types, this requires
476 // inserting rescale operators before and after the operation.
477 llvm::Optional<Value> convertMultiplyOp(PatternRewriter& rewriter,
478 Operation* op, Value output_val,
479 Value input_lhs_val,
480 Value input_rhs_val) {
481 ShapedType input_lhs_type = input_lhs_val.getType().dyn_cast<ShapedType>();
482 ShapedType input_rhs_type = input_rhs_val.getType().dyn_cast<ShapedType>();
483 ShapedType output_type = output_val.getType().dyn_cast<ShapedType>();
484 // Not a shaped tensor output
485 if (!input_lhs_type || !input_rhs_type || !output_type) return llvm::None;
486
487 bool input_lhs_is_qtype =
488 input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
489 bool input_rhs_is_qtype =
490 input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
491 bool output_is_qtype =
492 output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
493
494 if (input_lhs_is_qtype != output_is_qtype ||
495 input_rhs_is_qtype != output_is_qtype) {
496 op->emitOpError(
497 "ConvertMultiplyOp: input/output tensor should "
498 "be all quantized or all floating-point");
499 return llvm::None;
500 }
501
502 if (output_is_qtype) {
503 ShapedType rescale_type = output_type.clone(rewriter.getI32Type());
504 auto input_lhs_qtype = input_lhs_type.getElementType()
505 .cast<mlir::quant::UniformQuantizedType>();
506 auto input_rhs_qtype = input_rhs_type.getElementType()
507 .cast<mlir::quant::UniformQuantizedType>();
508 auto output_qtype =
509 output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
510
511     // MLIR stores the scale as a double, but TFLite stores it as a float.
512     // Downcast from double to float to match TFLite behavior.
513 float in_lhs_scale = input_lhs_qtype.getScale();
514 float in_rhs_scale = input_rhs_qtype.getScale();
515 float output_scale = output_qtype.getScale();
516
517 double output_rescale_scale = in_lhs_scale * in_rhs_scale / output_scale;
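    // A hedged numeric sketch (illustrative values, not from the original
    // code): with lhs scale 0.02, rhs scale 0.05 and output scale 0.002,
    // output_rescale_scale = 0.02 * 0.05 / 0.002 = 0.5, i.e. the int32
    // product is halved when rescaling back to the quantized output type.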
518
519 // 16bits x 16bits -> 32bits
520 // 32bits can be rescaled with 32bits quantize multiplier back to 16bits
521 bool scale32 = true;
522
523 Value op1_rescale_lhs = buildRescaleToInt32(
524 rewriter, op, input_lhs_val, 1.0f, input_lhs_qtype.getZeroPoint());
525 Value op2_rescale_rhs = buildRescaleToInt32(
526 rewriter, op, input_rhs_val, 1.0f, input_rhs_qtype.getZeroPoint());
527 auto op3_mul_op1_op2 =
528 CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), rescale_type,
529 op1_rescale_lhs, op2_rescale_rhs, 0);
530 return buildRescale(rewriter, op, output_type, op3_mul_op1_op2.getResult(),
531 output_rescale_scale, 0, output_qtype.getZeroPoint(),
532 true, scale32);
533 }
534
535 return CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
536 input_lhs_val, input_rhs_val, 0)
537 .getResult();
538 }
539
540 // Lowers the SquaredDifference operator to TOSA.
541 llvm::Optional<Value> convertSquaredDifferenceOp(PatternRewriter& rewriter,
542 Operation* op, Value result,
543 Value x, Value y) {
544 // Squared-difference is (x-y)*(x-y).
545 // This lowering calculates the difference and multiplies.
546 ShapedType result_type = result.getType().dyn_cast<ShapedType>();
547 if (!result_type) {
548 op->emitOpError("SquaredDifference: result not ranked tensor type");
549 return llvm::None;
550 }
551
552 ShapedType x_type = x.getType().dyn_cast<ShapedType>();
553 ShapedType y_type = y.getType().dyn_cast<ShapedType>();
554 if (!x_type || !y_type) {
555 op->emitOpError("SquaredDifference: inputs not ranked tensor type");
556 return llvm::None;
557 }
558
559 auto sub_op =
560 CreateOpAndInfer<tosa::SubOp>(rewriter, op->getLoc(), result_type, x, y);
561 return CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), result_type,
562 sub_op.getResult(), sub_op.getResult(),
563 0)
564 .getResult();
565 }
566
567 // Lowers the Round operator to TOSA.
568 llvm::Optional<Value> convertRoundOp(PatternRewriter& rewriter, Operation* op,
569 Value result, Value input) {
570   // Rounds to nearest by calculating floor(input + 0.5) (round half up).
571 ShapedType result_type = result.getType().dyn_cast<ShapedType>();
572 if (!result_type) {
573 op->emitOpError("Round: result not shaped tensor type");
574 return llvm::None;
575 }
576
577 ShapedType input_type = input.getType().dyn_cast<ShapedType>();
578 if (!input_type) {
579 op->emitOpError("Round: input not shaped tensor type");
580 return llvm::None;
581 }
582
583 auto add_op = CreateOpAndInfer<tosa::AddOp>(
584 rewriter, op->getLoc(), result_type, input,
585 getTosaConstTensorSingleF32(rewriter, op, 0.5));
586
587 return CreateOpAndInfer<tosa::FloorOp>(rewriter, op->getLoc(), result_type,
588 add_op.getResult())
589 .getResult();
590 }
591
592 // Lowers ConcatV2 to TOSA Concat.
593 llvm::Optional<Value> convertConcatV2Op(PatternRewriter& rewriter,
594 Operation* op, Value result_value,
595 SmallVectorImpl<Value>& values,
596 int32_t axis) {
597 // Check all inputs are RankedTensorType
598 for (auto v : values) {
599 if (!v.getType().dyn_cast<RankedTensorType>()) {
600 op->emitOpError("ConcatV2Op: value type not ranked tensor.");
601 return llvm::None;
602 }
603 }
604
605 // Check output is Ranked tensor type
606 if (!result_value.getType().dyn_cast<RankedTensorType>()) {
607 op->emitOpError("ConcatV2Op: output value type not ranked tensor.");
608 return llvm::None;
609 }
610
611 RankedTensorType result_type =
612 result_value.getType().dyn_cast<RankedTensorType>();
613 mlir::quant::UniformQuantizedType result_quant_type =
614 result_type.getElementType()
615 .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
616
617 SmallVector<Value> values_rescaled;
618
619 for (auto v : values) {
620 RankedTensorType operand_type = v.getType().dyn_cast<RankedTensorType>();
621 mlir::quant::UniformQuantizedType operand_quant_type =
622 operand_type.getElementType()
623 .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
624
625     // tfl.concat currently allows different scales for each input tensor,
626     // which the TFLite team will fix in:
627     // https://github.com/tensorflow/tensorflow/issues/39658
628     // For backward compatibility, we still need to support this behavior by
629     // rescaling the inputs so that they share the same scale.
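    // For illustration only (hypothetical values): an operand with scale 0.1
    // rescaled into a result with scale 0.05 uses the ratio 0.1 / 0.05 = 2.0,
    // together with the operand/result zero points, in the rescale below.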
630 if (result_quant_type && operand_quant_type) {
631 double operand_scale = static_cast<double>(operand_quant_type.getScale());
632 int32_t operand_zeropoint = operand_quant_type.getZeroPoint();
633
634 double result_scale = static_cast<double>(result_quant_type.getScale());
635 int32_t result_zeropoint = result_quant_type.getZeroPoint();
636
637 // Rescale input if scale is not equal to output tensor scale.
638 if (operand_scale != result_scale) {
639 RankedTensorType rescale_type =
640 RankedTensorType::get(operand_type.getShape(), result_quant_type);
641 Value rescale_op = buildRescale(
642 rewriter, op, rescale_type, v, operand_scale / result_scale,
643 operand_zeropoint, result_zeropoint, false, true);
644 values_rescaled.push_back(rescale_op);
645 } else {
646 values_rescaled.push_back(v);
647 }
648 } else {
649 values_rescaled.push_back(v);
650 }
651 }
652
653 int32_t tensor_rank = result_type.getShape().size();
654
655 if (axis < 0) axis += tensor_rank;
656 if ((axis < 0) || (axis > tensor_rank)) {
657 op->emitOpError("ConcatV2Op: axis out of valid range.");
658 return llvm::None;
659 }
660
661 auto concat_op = CreateOpAndInfer<tosa::ConcatOp>(
662 rewriter, op->getLoc(), result_value.getType(), values_rescaled,
663 rewriter.getI64IntegerAttr(axis));
664
665 return concat_op.getResult();
666 }
667
668 // Lowers SpaceToBatchND to TOSA.
669 llvm::Optional<Value> convertSpaceToBatchNDOp(PatternRewriter& rewriter,
670 Operation* op, Value result_value,
671 Value input_value,
672 Value block_shape_value,
673 Value paddings_value) {
674 /////////////////////////////////////////////////
675 // Operator: output = SpaceToBatchND(input, block_shape, paddings)
676 // Lowering:
677 //
678 // SpaceToBatch input tensors are broken into three pieces:
679 // (a) batch dimension (N in NHWC)
680 // (b) input being transformed to batch dimension (typically H, W in NHWC)
681 // (c) remainder of input (typically C in NHWC)
682 //
683 // Step 0. Generate padding constant for the first reshape.
684 // No padding on the batch dimension
685 // The input paddings array is addressed as [input_rank][2]
686 // No padding on the remaining dimensions
687 //
688 // a0_pad_const = tosa.const(input=Tensor<input_rank, 2>)
689 //
690 // Step 1. Pad the input tensor
691 //
692 // a1_pad_input_op = tosa.pad(input=input, shape=a0_pad_const_op)
693 //
694 // Step 2. Reshape the padded structure of shape padded_shape to
695 // [batch + padded_shape[1] / block_shape[0], block_shape[0], ...
696 // padded_shape[M] / block_shape[M-1], block_shape[M-1]] +
697 // remaining_shape
698 //
699 // block_rank = M (number of elements in block_shape)
700 // New rank: input_rank + block_rank
701 //
702 // a2_reshape_a1_op = tosa.reshape(input=a1_pad_input_op, shape=a2_shape)
703 //
704 // Step 3. Transpose dimensions to:
705 // block-shape +
706 // [batch] +
707 // [padded_shape[1] / block_shape[0],
708 // ...
709 // [padded_shape[M] / block_shape[M-1]] +
710 // remaining_shape
711 //
712 //    a3_transpose_a2_op = tosa.transpose(input=a2_reshape_a1_op,
713 // perms=a3_perm)
714 //
715 // Step 4. Reshape the transposed tensor to flatten block_shape stuff
716 // into the batch dimension with the following shape:
717 // [ batch * prod(block_shape)] +
718 // [ padded_shape[1] / block_shape[0],
719 // ...,
720 // padded_shape[M] / block_shape[M-1]] +
721 // remaining_shape
722 //
723 //    a4_reshape_a3_op = tosa.reshape(input=a3_transpose_a2_op,
724 // shape=a3_shape)
725 //
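//
// For illustration only (hypothetical NHWC shapes, not part of the original
// comments): input [1, 4, 4, 1], block_shape = [2, 2], zero paddings:
//   padded_shape = [1, 4, 4, 1]
//   a2_shape     = [1, 2, 2, 2, 2, 1]
//   a3_perm      = [2, 4, 0, 1, 3, 5] -> a3 shape [2, 2, 1, 2, 2, 1]
//   a4_reshape   = [4, 2, 2, 1]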
726
727 RankedTensorType result_type =
728 result_value.getType().dyn_cast<RankedTensorType>();
729 RankedTensorType input_type =
730 input_value.getType().dyn_cast<RankedTensorType>();
731 RankedTensorType block_shape_type =
732 block_shape_value.getType().dyn_cast<RankedTensorType>();
733 RankedTensorType paddings_type =
734 paddings_value.getType().dyn_cast<RankedTensorType>();
735
736 // Not a ranked tensor output.
737 if (!result_type) {
738 op->emitOpError("SpaceToBatchND: result type not ranked tensor");
739 return llvm::None;
740 }
741 if (!input_type) {
742 op->emitOpError("SpaceToBatchND: input type not ranked tensor");
743 return llvm::None;
744 }
745 if (!block_shape_type) {
746 op->emitOpError("SpaceToBatchND: block shape type not ranked tensor");
747 return llvm::None;
748 }
749 if (!paddings_type) {
750 op->emitOpError("SpaceToBatchND: paddings type not ranked tensor");
751 return llvm::None;
752 }
753
754 // Follow implementation in
755 // tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
756
757 // So, to figure out the spatial_shape, remove the batch dimension and
758 // then use the next block_rank dimensions. The remaining dimensions are
759 // remaining_shape.
760
761 auto block_shape = block_shape_type.getShape();
762 auto input_shape = input_type.getShape();
763
764 int block_rank = block_shape[0];
765 int batch_size = input_shape[0];
766 int input_rank = input_type.getRank();
767 int remaining_shape_rank = input_rank - block_rank - 1;
768 int block_num_elems = 1;
769 int padding_sum = 0;
770
771 ElementsAttr block_shape_elems;
772 ElementsAttr paddings_elems;
773
774 if (!matchPattern(block_shape_value, m_Constant(&block_shape_elems)))
775 return llvm::None;
776
777 if (!matchPattern(paddings_value, m_Constant(&paddings_elems)))
778 return llvm::None;
779
780 SmallVector<int32_t> a0_pad_const(2 * (input_rank));
781 SmallVector<int64_t> padded_shape(input_rank);
782
783 // 1. Pad based on paddings operand. No padding on the batch dimension.
784 // The a0_pad_const array is addressed as [input_rank][2], but
785 // it is flattened to a 1D array because LLVM appears to only accept 1D.
786 //
787 // padded_shape[] is the shape of the padded output of step a1.
788 // The name is retained for consistency with the TF reference code.
789 padded_shape[0] = input_shape[0];
790
791 // Batch dimension padding
792 a0_pad_const[0] = 0;
793 a0_pad_const[1] = 0;
794
795 // This iterator seems to be the only reliable way to get
796 // int values out of a multi-dimensional ElementsAttr.
797 int idx = 0;
798
799 for (auto i : paddings_elems.getValues<IntegerAttr>()) {
800 a0_pad_const[idx + 2] = i.getInt();
801 padding_sum += i.getInt();
802 idx++;
803 }
804
805 // Insert padding on the spatial shape dimensions
806 for (int i = 0; i < block_rank; i++) {
807 padded_shape[i + 1] = input_shape[i + 1];
808 if (!ShapedType::isDynamic(padded_shape[i + 1])) {
809 int32_t lo_pad = a0_pad_const[2 * (i + 1) + 0];
810 int32_t hi_pad = a0_pad_const[2 * (i + 1) + 1];
811 padded_shape[i + 1] += lo_pad + hi_pad;
812 }
813 }
814
815 // No padding on the remaining_shape dimensions
816 for (int i = 0; i < remaining_shape_rank; i++) {
817 a0_pad_const[2 * (i + block_rank + 1) + 0] = 0;
818 a0_pad_const[2 * (i + block_rank + 1) + 1] = 0;
819 padded_shape[i + block_rank + 1] = input_shape[i + block_rank + 1];
820 }
821
822 RankedTensorType a0_pad_const_attr_type =
823 RankedTensorType::get({(input_rank), 2}, rewriter.getIntegerType(32));
824
825 // Create a const op to generate the tensor type for the input padding array
826 auto a0_pad_const_op = rewriter.create<tosa::ConstOp>(
827 op->getLoc(), a0_pad_const_attr_type,
828 DenseElementsAttr::get(a0_pad_const_attr_type,
829 llvm::makeArrayRef(a0_pad_const)));
830
831 auto a1_pad_input_op = CreateOpAndInfer<tosa::PadOp>(
832 rewriter, op->getLoc(),
833 RankedTensorType::get(padded_shape, result_type.getElementType()),
834 input_value, a0_pad_const_op.getResult());
835
836 // 2. Reshape the padded structure of shape padded_shape to
837 // [batch + padded_shape[1] / block_shape[0], block_shape[0], ...
838 // padded_shape[M] / block_shape[M-1], block_shape[M-1]] +
839 // remaining_shape
840
841 // block_rank = M (number of elements in block_shape)
842 // New rank: input_rank + block_rank
843 SmallVector<int64_t> a2_shape(1 + block_rank * 2 + remaining_shape_rank);
844
845 // First dimension is batch.
846 a2_shape[0] = input_type.getShape()[0];
847 for (int i = 0; i < block_rank; i++) {
848 int32_t block_shape_val =
849 block_shape_elems.getValues<IntegerAttr>()[i].getInt();
850 a2_shape[1 + i * 2 + 0] = padded_shape[1 + i];
851 if (a2_shape[1 + i * 2 + 0] != ShapedType::kDynamicSize) {
852 a2_shape[1 + i * 2 + 0] /= block_shape_val;
853 }
854
855 a2_shape[1 + i * 2 + 1] = block_shape_val;
856 block_num_elems *= block_shape_val;
857 }
858
859 // Copy in the remaining block shape.
860 for (int i = 0; i < remaining_shape_rank; i++) {
861 a2_shape[1 + block_rank * 2 + i] = input_shape[1 + block_rank + i];
862 }
863
864 auto a2_reshape_a1_op = CreateOpAndInfer<tosa::ReshapeOp>(
865 rewriter, op->getLoc(),
866 RankedTensorType::get(a2_shape, result_type.getElementType()),
867 a1_pad_input_op.getResult(), rewriter.getI64ArrayAttr(a2_shape));
868
869 // 3. Transpose dimensions to:
870 // block-shape +
871 // [batch] +
872 // [padded_shape[1] / block_shape[0],
873 // ...
874 // [padded_shape[M] / block_shape[M-1]] +
875 // remaining_shape
876 int32_t a2_reshape_a1_rank =
877 a2_reshape_a1_op.getResult().getType().cast<RankedTensorType>().getRank();
878 SmallVector<int32_t> a3_perm(a2_reshape_a1_rank);
879 SmallVector<int64_t> a3_transpose_shape(a2_reshape_a1_rank);
880
881 for (int i = 0; i < block_rank; i++) {
882 a3_perm[i] = 1 + 2 * i + 1;
883 a3_perm[block_rank + 1 + i] = 1 + 2 * i;
884 }
885 a3_perm[block_rank] = 0;
886 for (int i = 1 + block_rank * 2; i < a2_reshape_a1_rank; i++) {
887 a3_perm[i] = i;
888 }
889
890 for (int i = 0; i < a3_transpose_shape.size(); i++) {
891 a3_transpose_shape[i] = a2_shape[a3_perm[i]];
892 }
893
894 llvm::Optional<Value> a3_transpose_const = getConstTensor<int32_t>(
895 rewriter, op, a3_perm, {static_cast<int64_t>(a3_perm.size())});
896
897 if (!a3_transpose_const) return llvm::None;
898
899 auto a3_transpose_a2_op = CreateOpAndInfer<tosa::TransposeOp>(
900 rewriter, op->getLoc(),
901 RankedTensorType::get(a3_transpose_shape, result_type.getElementType()),
902 a2_reshape_a1_op.getResult(), a3_transpose_const.getValue());
903
904 // 4. Reshape the transposed tensor to flatten block_shape
905 // into the batch dimension with the following shape:
906 // [ batch * prod(block_shape)] +
907 // [ padded_shape[1] / block_shape[0],
908 // ...,
909 // padded_shape[M] / block_shape[M-1]] +
910 // remaining_shape
911 SmallVector<int64_t> a4_reshape_shape(input_rank);
912
913 // Batch
914 a4_reshape_shape[0] = multiply_dims(batch_size, block_num_elems);
915
916 // padded shape / block_shape.
917 for (int i = 0; i < block_rank; i++) {
918 int32_t block_shape_val =
919 block_shape_elems.getValues<IntegerAttr>()[i].getInt();
920 a4_reshape_shape[i + 1] = padded_shape[i + 1];
921 if (a4_reshape_shape[i + 1] != ShapedType::kDynamicSize) {
922 a4_reshape_shape[i + 1] /= block_shape_val;
923 }
924 }
925
926 // Copy in remainder shape.
927 for (int i = 0; i < remaining_shape_rank; i++) {
928 a4_reshape_shape[1 + block_rank + i] = input_shape[1 + block_rank + i];
929 }
930
931 return CreateOpAndInfer<tosa::ReshapeOp>(
932 rewriter, op->getLoc(), result_type,
933 a3_transpose_a2_op.getResult(),
934 rewriter.getI64ArrayAttr(a4_reshape_shape))
935 .getResult();
936 }
937
938 // Lowers BatchToSpaceND to TOSA.
939 llvm::Optional<Value> convertBatchToSpaceNDOp(PatternRewriter& rewriter,
940 Operation* op, Value result_value,
941 Value input_value,
942 Value block_shape_value,
943 Value crops_value) {
944 /////////////////////////////////////////////////
945 // Operator: output = BatchToSpaceND(input, block_shape, crops)
946 // Lowering:
947 //
948 // BatchToSpace input tensors are broken into three pieces:
949 // (a) batch dimension (N in NHWC)
950 // (b) input being transformed from batch dimension (typically H, W in
951 // NHWC)
952 // (c) remainder of input (typically C in NHWC)
953 //
954 // Step 1. Reshape input to:
955 // [block_shape[0],
956 // ...
957 // [block_shape[M-1],
958 // [batch / prod(block_shape)]
959 // [input_shape[1],
960 // ...
961 // [input_shape[N-1]
962 //
963 // a1_reshape_input_op = tosa.reshape(input=input, shape=a1_shape)
964 //
965 // Step 2. Permute to shape
966 // [ batch / prod(block_shape) ],
967 //      [ input_shape[1] ], [ block_shape[0] ]
968 //      ...
969 //      [ input_shape[M] ], [ block_shape[M-1] ]
970 //      + remaining_input_shapes input_shape[M+1 .. N-1]
971 //
972 // a2_transpose_a1 = tosa.transpose(input=a1_reshape_input_op,
973 // shape=a2_shape)
974 //
975 // Step 3. Reshape to:
976 // [ batch / prod(block_shape) ],
977 // [input_shape[1] * block_shape[0] ],
978 // ..
979 //      [input_shape[M] * block_shape[M-1] ],
980 // + remaining input shapes [input_shape[M+1.. N-1]]
981 //
982 // a3_reshape_a2 = tosa.reshape(input=a2_transpose_a1, shape=a3_shape)
983 //
984 // Step 4. Crop the start/end dimensions according to crops of the
985 // a3_reshape_a2 shape
986 //
987 // a4_slice_a3 = tosa.slice(input=a3_reshape_a2, start=a4_start,
988 // size=a4_size)
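//
// For illustration only (hypothetical NHWC shapes): input [4, 2, 2, 1],
// block_shape = [2, 2], zero crops:
//   a1_shape         = [2, 2, 1, 2, 2, 1]
//   a2_perm          = [2, 3, 0, 4, 1, 5] -> a2 shape [1, 2, 2, 2, 2, 1]
//   a3 (a4_shape)    = [1, 4, 4, 1]
//   slice begin/size = [0, 0, 0, 0] / [1, 4, 4, 1]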
989
990 RankedTensorType result_type =
991 result_value.getType().dyn_cast<RankedTensorType>();
992 RankedTensorType input_type =
993 input_value.getType().dyn_cast<RankedTensorType>();
994 RankedTensorType block_shape_type =
995 block_shape_value.getType().dyn_cast<RankedTensorType>();
996 RankedTensorType crops_type =
997 crops_value.getType().dyn_cast<RankedTensorType>();
998
999 if (!result_type) {
1000 op->emitOpError("BatchToSpaceND: result type not ranked tensor");
1001 return llvm::None;
1002 }
1003 if (!input_type) {
1004 op->emitOpError("BatchToSpaceND: input type not ranked tensor");
1005 return llvm::None;
1006 }
1007 if (!block_shape_type) {
1008 op->emitOpError("BatchToSpaceND: block shape type not ranked tensor");
1009 return llvm::None;
1010 }
1011 if (!crops_type) {
1012 op->emitOpError("BatchToSpaceND: crops type not ranked tensor");
1013 return llvm::None;
1014 }
1015
1016 // Another 4-step process
1017 int block_rank = block_shape_type.getShape()[0];
1018 int input_rank = input_type.getRank();
1019 int crops_dims = crops_type.getShape()[0];
1020 int remaining_shape_rank = input_rank - block_rank - 1;
1021 auto input_shape = input_type.getShape();
1022
1023 ElementsAttr block_shape_elems;
1024 ElementsAttr crops_elems;
1025
1026 if (!matchPattern(block_shape_value, m_Constant(&block_shape_elems))) {
1027 op->emitOpError("BatchToSpaceND: block_shape not a constant");
1028 return llvm::None;
1029 }
1030
1031 if (!matchPattern(crops_value, m_Constant(&crops_elems))) {
1032 op->emitOpError("BatchToSpaceND: crops not a constant");
1033 return llvm::None;
1034 }
1035
1036 SmallVector<int64_t> block_shape(block_rank);
1037 SmallVector<std::pair<int64_t, int64_t>> crops(crops_dims);
1038
1039 // Extract values for block_shape and crops now.
1040 int block_num_elems = 1;
1041 for (int i = 0; i < block_rank; i++) {
1042 int block_shape_val =
1043 rewriter
1044 .getI32IntegerAttr(
1045 block_shape_elems.getValues<IntegerAttr>()[i].getInt())
1046 .getInt();
1047 block_num_elems *= block_shape_val;
1048 block_shape[i] = block_shape_val;
1049 }
1050
1051 // This iterator seems to be the only reliable way to get
1052 // int values out of a multi-dimensional ElementsAttr
1053 SmallVector<int32_t> crops_const(2 * (crops_dims));
1054 int idx = 0;
1055 for (auto i : crops_elems.getValues<IntegerAttr>()) {
1056 crops_const[idx++] = i.getInt();
1057 }
1058
1059 for (int i = 0; i < crops_dims; i++) {
1060     int crops_lo = crops_const[i * 2 + 0];
1061     int crops_hi = crops_const[i * 2 + 1];
1062 crops[i] = std::make_pair(crops_lo, crops_hi);
1063 }
1064
1065 // Step 1. Reshape input to:
1066 // [block_shape[0],
1067 // ...
1068 // [block_shape[M-1],
1069 // [batch / prod(block_shape)]
1070 // [input_shape[1],
1071 // ...
1072 // [input_shape[N-1]
1073 SmallVector<int64_t> a1_shape(block_rank + input_rank);
1074
1075 for (int i = 0; i < block_rank; i++) a1_shape[i] = block_shape[i];
1076
1077 a1_shape[block_rank] = (input_shape[0] == ShapedType::kDynamicSize)
1078 ? ShapedType::kDynamicSize
1079 : input_shape[0] / block_num_elems;
1080
1081 for (int i = 0; i < input_rank - 1; i++)
1082 a1_shape[i + block_rank + 1] = input_shape[i + 1];
1083
1084 auto a1_reshape_input_op = CreateOpAndInfer<tosa::ReshapeOp>(
1085 rewriter, op->getLoc(),
1086 RankedTensorType::get(a1_shape, result_type.getElementType()),
1087 input_value, rewriter.getI64ArrayAttr(a1_shape));
1088
1089 // 2. Permute to shape
1090 // [ batch / prod(block_shape) ],
1091 // [ input_shape[1] ], [ block_shape[0] ]
1092 // ...
1093 // [ input_shape[M] ], [ block_shape[M-1]
1094 // + remaining_input_shapes input_shape[M+1 .. N-1]
1095
1096 // 2a. calculate the permutation
1097 SmallVector<int32_t> a2_perm(block_rank + input_rank);
1098 SmallVector<int64_t> a2_transpose_shape(block_rank + input_rank);
1099
1100 a2_perm[0] = block_rank;
1101 for (int i = 0; i < block_rank; i++) {
1102 a2_perm[1 + i * 2 + 0] = block_rank + 1 + i;
1103 a2_perm[1 + i * 2 + 1] = i;
1104 }
1105
1106 for (int i = 0; i < remaining_shape_rank; i++) {
1107 a2_perm[1 + 2 * block_rank + i] = 1 + 2 * block_rank + i;
1108 }
1109
1110 // 2b. calculate the a2_permuted shape
1111 for (int i = 0; i < (block_rank + input_rank); i++) {
1112 a2_transpose_shape[i] = a1_shape[a2_perm[i]];
1113 }
1114
1115 llvm::Optional<Value> a2_transpose_perm = getConstTensor<int32_t>(
1116 rewriter, op, a2_perm, {static_cast<int64_t>(a2_perm.size())});
1117
1118 if (!a2_transpose_perm) return llvm::None;
1119
1120 auto a2_transpose_a1_op = CreateOpAndInfer<tosa::TransposeOp>(
1121 rewriter, op->getLoc(),
1122 RankedTensorType::get(a2_transpose_shape, result_type.getElementType()),
1123 a1_reshape_input_op.getResult(), a2_transpose_perm.getValue());
1124
1125 // Step 3. Reshape to:
1126 // [ batch / prod(block_shape) ],
1127 // [input_shape[1] * block_shape[0] ],
1128 // ..
1129   //    [input_shape[M] * block_shape[M-1] ],
1130 // + remaining input shapes [input_shape[M+1.. N-1]]
1131 SmallVector<int64_t> a4_shape(input_rank);
1132
1133 a4_shape[0] = input_shape[0];
1134 if (a4_shape[0] != ShapedType::kDynamicSize) {
1135 a4_shape[0] /= block_num_elems;
1136 }
1137 for (int i = 0; i < block_rank; i++) {
1138 a4_shape[1 + i] = multiply_dims(input_shape[i + 1], block_shape[i]);
1139 }
1140 for (int i = 0; i < remaining_shape_rank; i++) {
1141 a4_shape[1 + block_rank + i] = input_shape[block_rank + 1 + i];
1142 }
1143
1144 auto a3_reshape_a2 = CreateOpAndInfer<tosa::ReshapeOp>(
1145 rewriter, op->getLoc(),
1146 RankedTensorType::get(a4_shape, result_type.getElementType()),
1147 a2_transpose_a1_op.getResult(), rewriter.getI64ArrayAttr(a4_shape));
1148
1149 // 4. Crop the start/end dimensions on 'spatial dimension' according to
1150 // crops
1151 // Use a slice operator to do the cropping.
1152 //
1153 // Calculate a beginning point and a size:
1154 // - Begin is the origin, offset by the lo crop amount in each dimension
1155 // - Size is the reshaped tensor size, minus the quantity (lo + hi) for each
1156 // dimension
1157 SmallVector<int64_t> a4_begin_vals(input_rank), a4_size_vals(input_rank);
1158
1159 for (int i = 0; i < input_rank; i++) {
1160 // Batch dimension and remaining dimensions.
1161 if (i == 0 || i > crops_dims) {
1162 a4_begin_vals[i] = 0;
1163 a4_size_vals[i] = result_type.getShape()[i];
1164 } else {
1165 // Spatial dimension.
1166 assert(i - 1 >= 0 && i - 1 < crops_dims);
1167 a4_begin_vals[i] = crops[i - 1].first;
1168 a4_size_vals[i] = a4_shape[i] - crops[i - 1].first - crops[i - 1].second;
1169 }
1170 }
1171
1172 return CreateOpAndInfer<tosa::SliceOp>(
1173 rewriter, op->getLoc(),
1174 RankedTensorType::get(a4_size_vals, result_type.getElementType()),
1175 a3_reshape_a2.getResult(), rewriter.getI64ArrayAttr(a4_begin_vals),
1176 rewriter.getI64ArrayAttr(a4_size_vals))
1177 .getResult();
1178 }
1179
1180 // Lowers ExpandDims to TOSA.
1181 llvm::Optional<Value> convertExpandDimsOp(PatternRewriter& rewriter,
1182 Operation* op, Value result_value,
1183 Value input_value, Value dim_value) {
1184 // Lowers to a reshape op with 1's inserted in the appropriate dimensions.
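  // For illustration only (hypothetical shapes): a [2, 3] input with dim = 1
  // is reshaped to [2, 1, 3], and dim = 2 appends the new axis: [2, 3, 1].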
1185 RankedTensorType output_type =
1186 result_value.getType().dyn_cast<RankedTensorType>();
1187 // Not a ranked tensor output
1188 if (!output_type) {
1189 op->emitOpError("ExpandDims: output type not ranked tensor");
1190 return llvm::None;
1191 }
1192
1193 RankedTensorType input_type =
1194 input_value.getType().dyn_cast<RankedTensorType>();
1195 if (!input_type) {
1196 op->emitOpError("ExpandDims: input type not ranked tensor");
1197 return llvm::None;
1198 }
1199
1200 auto input_shape = input_type.getShape();
1201
1202 ElementsAttr dim_elem;
1203 if (!matchPattern(dim_value, m_Constant(&dim_elem))) return llvm::None;
1204
1205 if (dim_elem.getNumElements() > 1) {
1206 op->emitOpError("ExpandDims: expected single dimension to expand");
1207 return llvm::None;
1208 }
1209 int32_t dim = dim_elem.getValues<IntegerAttr>()[0].getInt();
1210 int32_t input_size = input_shape.size();
1211 SmallVector<int64_t> reshape_dims;
1212 if (dim >= input_size) { // add dim at end of tensor
1213 dim = input_size;
1214 for (int i = 0; i < input_shape.size(); i++) {
1215 reshape_dims.emplace_back(input_shape[i]);
1216 }
1217 reshape_dims.emplace_back(1);
1218 } else {
1219 if (dim < 0) {
1220 dim += input_size;
1221 if (dim < 0) {
1222 op->emitOpError(
1223 "ExpandDims: dimension to expand + size of input shape < 0");
1224 return llvm::None;
1225 }
1226 }
1227 for (int i = 0; i < input_size; i++) {
1228 if (i == dim) {
1229 reshape_dims.emplace_back(1);
1230 }
1231 reshape_dims.emplace_back(input_shape[i]);
1232 }
1233 }
1234
1235 ArrayAttr shape_attr = rewriter.getI64ArrayAttr(reshape_dims);
1236
1237 return CreateOpAndInfer<tosa::ReshapeOp>(rewriter, op->getLoc(), output_type,
1238 input_value, shape_attr)
1239 .getResult();
1240 }
1241
1242 // Lowers Squeeze to TOSA.
1243 llvm::Optional<Value> convertSqueezeOp(PatternRewriter& rewriter, Operation* op,
1244 Value result_value, Value input_value,
1245 SmallVectorImpl<int32_t>& squeeze_dims) {
1246 // Lowers to a reshape op where dimensions in squeeze_dims with size=1
1247 // are removed.
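  // For illustration only (hypothetical shapes): a [1, 2, 1, 3] input with
  // empty squeeze_dims reshapes to [2, 3]; with squeeze_dims = {0} it
  // reshapes to [2, 1, 3].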
1248 RankedTensorType output_type =
1249 result_value.getType().dyn_cast<RankedTensorType>();
1250 // Not a ranked tensor output
1251 if (!output_type) {
1252 op->emitOpError("Squeeze: output type not ranked tensor");
1253 return llvm::None;
1254 }
1255
1256 RankedTensorType input_type =
1257 input_value.getType().dyn_cast<RankedTensorType>();
1258 if (!input_type) {
1259 op->emitOpError("Squeeze: input type not ranked tensor");
1260 return llvm::None;
1261 }
1262
1263 auto input_shape = input_type.getShape();
1264
1265 SmallVector<int64_t> reshape_dims;
1266
1267 if (squeeze_dims.empty()) { // remove all 1-dims
1268 for (int i = 0; i < input_shape.size(); i++) {
1269 if (input_shape[i] != 1) {
1270 reshape_dims.emplace_back(input_shape[i]);
1271 }
1272 }
1273 } else {
1274 for (auto& dim : squeeze_dims) {
1275 dim = dim < 0 ? dim + input_shape.size() : dim;
1276 }
1277
1278 // Remove only specified dims.
1279 // First sort the array so they can be picked off in sequence.
1280 std::sort(squeeze_dims.begin(), squeeze_dims.end(),
1281 [](const int32_t a, const int32_t b) { return a < b; });
1282
1283 int pos = 0;
1284 auto dim = squeeze_dims[pos];
1285 for (int i = 0; i < input_shape.size(); i++) {
1286 if (i == dim) {
1287 pos = pos + 1;
1288 if (pos < squeeze_dims.size())
1289 dim = squeeze_dims[pos];
1290 else
1291 dim = -1; // Invalid
1292 } else {
1293 reshape_dims.emplace_back(input_shape[i]);
1294 }
1295 }
1296 }
1297
1298 ArrayAttr shape_attr = rewriter.getI64ArrayAttr(reshape_dims);
1299
1300 return CreateOpAndInfer<tosa::ReshapeOp>(rewriter, op->getLoc(), output_type,
1301 input_value, shape_attr)
1302 .getResult();
1303 }
1304
1305 // Lowers ELU to a sequence of TOSA ops.
1306 llvm::Optional<Value> convertEluOp(PatternRewriter& rewriter, Operation* op,
1307 Value result_value, Value features_value) {
1308 // Lowers Elu using the following formula:
1309 // elu(x) = x < 0 ? (exp(x) - 1) : x
1310 // one = const({1});
1311 // zero = const({0});
1312 // one_bcast = reshape(one, [1, ..., rank(x) - 1])
1313 // zero_bcast = reshape(zero, [1, ..., rank(x) - 1])
1314 // a1 = exp(x);
1315 // a2 = sub(a1, one_bcast)
1316 // a3 = ge(x, zero_bcast)
1317 // a4 = select(a3, x, a2)
1318 RankedTensorType output_type =
1319 result_value.getType().dyn_cast<RankedTensorType>();
1320 // Not a ranked tensor output
1321 if (!output_type) {
1322 op->emitOpError("Elu: output type not ranked tensor");
1323 return llvm::None;
1324 }
1325
1326 int32_t input_rank = output_type.getShape().size();
1327 SmallVector<int64_t> bcast_shape(input_rank, 1);
1328
1329 // Can't directly create size=1, rank=rank(input) tensor because
1330 // it will be optimized out. Instead, create rank0 tensor and reshape later.
1331 Value one_const_op = getTosaConstTensorSingleF32(rewriter, op, 1.0);
1332
1333 Value zero_const_op = getTosaConstTensorSingleF32(rewriter, op, 0.0);
1334
1335 auto a1_exp_in_op = CreateOpAndInfer<tosa::ExpOp>(
1336 rewriter, op->getLoc(), output_type, features_value);
1337
1338 auto a2_sub_a1_one_op =
1339 CreateOpAndInfer<tosa::SubOp>(rewriter, op->getLoc(), output_type,
1340 a1_exp_in_op.getResult(), one_const_op);
1341
1342 auto a3_ge_in_zero_op = CreateOpAndInfer<tosa::GreaterEqualOp>(
1343 rewriter, op->getLoc(),
1344 RankedTensorType::get(output_type.getShape(), rewriter.getIntegerType(1)),
1345 features_value, zero_const_op);
1346
1347 return CreateOpAndInfer<tosa::SelectOp>(
1348 rewriter, op->getLoc(), output_type, a3_ge_in_zero_op.getResult(),
1349 features_value, a2_sub_a1_one_op.getResult())
1350 .getResult();
1351 }
1352
1353 // Lowers Softmax to a sequence of TOSA ops.
1354 llvm::Optional<Value> convertSoftmaxOp(PatternRewriter& rewriter, Operation* op,
1355 Value result_value, Value logits_value,
1356 double beta) {
1357 // softmax = exp(logits) / reduce_sum(exp(logits), -1)
1358 //
1359   // or equivalently, multiplying both numerator and denominator by
1360   // exp(-max(logits)), we get:
1361 //
1362 // softmax = exp(logits - max(logits)) / reduce_sum(exp(logits -
1363 // max(logits)), -1)
1364 //
1365   // We'll use the first version for direct floating-point lowering, and the
1366   // second version for quantized lowering, since the latter restricts the
1367   // input to exp() to be non-positive, so the LUT output stays within [0.0, 1.0].
1368 RankedTensorType output_type =
1369 result_value.getType().dyn_cast<RankedTensorType>();
1370 RankedTensorType input_type =
1371 logits_value.getType().dyn_cast<RankedTensorType>();
1372
1373 // Not a ranked tensor input/output
1374 if (!output_type || !input_type) {
1375 op->emitOpError("Softmax: input and result not ranked tensors");
1376 return llvm::None;
1377 }
1378
1379 // reduce_sum on last dimension
1380 int32_t input_rank = input_type.getShape().size();
1381 ArrayRef<int64_t> logits_shape = output_type.getShape();
1382
1383 if (input_type.getElementType().isa<mlir::quant::QuantizedType>() &&
1384 output_type.getElementType().isa<mlir::quant::QuantizedType>()) {
1385 SmallVector<int64_t> rsum_shape_v(input_type.getShape().begin(),
1386 input_type.getShape().end() - 1);
1387 rsum_shape_v.push_back(1);
1388 ArrayRef<int64_t> rsum_shape(rsum_shape_v);
1389 // The if condition already checks if these are UQTs
1390 mlir::quant::UniformQuantizedType in_quant_type =
1391 input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
1392 mlir::quant::UniformQuantizedType out_quant_type =
1393 output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
1394
1395 auto int16_element_qtype = mlir::quant::UniformQuantizedType::get(
1396 true, rewriter.getIntegerType(16), rewriter.getF32Type(), 1.0f, 0,
1397 -32768, 32767);
1398 RankedTensorType int16_logits_type =
1399 RankedTensorType::get(logits_shape, int16_element_qtype);
1400 RankedTensorType int32_logits_type =
1401 RankedTensorType::get(logits_shape, rewriter.getIntegerType(32));
1402 RankedTensorType int16_rsum_type =
1403 RankedTensorType::get(rsum_shape, int16_element_qtype);
1404 RankedTensorType int32_rsum_type =
1405 RankedTensorType::get(rsum_shape, rewriter.getIntegerType(32));
1406
1407 if (in_quant_type.getStorageTypeIntegralWidth() == 8) {
1408 // Step 1. get x - max(x)
1409 Value op1_rescale_in =
1410 buildRescale(rewriter, op, int32_logits_type, logits_value, 1.0f,
1411 in_quant_type.getZeroPoint(), 0, false, true);
1412
1413 auto op2_reducemax_op1 = CreateOpAndInfer<tosa::ReduceMaxOp>(
1414 rewriter, op->getLoc(), int32_rsum_type, op1_rescale_in,
1415 rewriter.getI64IntegerAttr(input_rank - 1));
1416
1417 auto op3_sub_op1_op2 = CreateOpAndInfer<tosa::SubOp>(
1418 rewriter, op->getLoc(), int32_logits_type, op1_rescale_in,
1419 op2_reducemax_op1.getResult());
1420
1421 // Step 2. get exp() result
1422       // Implemented with four 16-bit table lookups.
1423 // We use a 16-bit table lookup, as the result of x - max(x) for
1424 // 8-bit values is a 9-bit value. We only use the bottom 8 bits of each
1425 // table to avoid having the slope between two 16-bit table entries be
1426 // greater than 16 bits, causing potential interpolation errors
1427 auto exp_func = [](double x) -> double { return std::exp(x); };
1428
1429 Value exp_table_const_01, exp_table_const_02, exp_table_const_03,
1430 exp_table_const_04;
1431 getTosaConst32bitTable(rewriter, op, beta * in_quant_type.getScale(), 0,
1432 exp_func, exp_table_const_01, exp_table_const_02,
1433 exp_table_const_03, exp_table_const_04);
1434
1435 Value op4_rescale_op3 =
1436 buildRescale(rewriter, op, int16_logits_type,
1437 op3_sub_op1_op2.getResult(), 128.0, 0, 0, false, true);
1438
1439 // Input is 9.7, where lower 7 bits are all zeros.
1440 // Output is 23 bits, where lower 7 bits should be all zeros as well,
1441 // since there's no interpolation here.
1442 auto op5_table_op4_bits_31_24 = CreateOpAndInfer<tosa::TableOp>(
1443 rewriter, op->getLoc(), int32_logits_type, op4_rescale_op3,
1444 exp_table_const_01);
1445
1446 auto op6_table_op4_bits_23_16 = CreateOpAndInfer<tosa::TableOp>(
1447 rewriter, op->getLoc(), int32_logits_type, op4_rescale_op3,
1448 exp_table_const_02);
1449
1450 auto op7_table_op4_bits_15_8 = CreateOpAndInfer<tosa::TableOp>(
1451 rewriter, op->getLoc(), int32_logits_type, op4_rescale_op3,
1452 exp_table_const_03);
1453
1454 auto op8_table_op4_bits_7_0 = CreateOpAndInfer<tosa::TableOp>(
1455 rewriter, op->getLoc(), int32_logits_type, op4_rescale_op3,
1456 exp_table_const_04);
1457
1458       // To get the 16-bit upper/lower values, we need to right shift 7 bits,
1459       // and then reconstruct the 32-bit value we need: (upper << 24) + lower.
1460       // So effectively we left shift the 1st group of 8 bits by 17 bits.
1461 auto op7_lshift_op5 = CreateOpAndInfer<tosa::LogicalLeftShiftOp>(
1462 rewriter, op->getLoc(), int32_logits_type,
1463 op5_table_op4_bits_31_24.getResult(),
1464 getTosaConstTensorSingleI32(rewriter, op, 17));
1465
1466 // For the 2nd group of 8-bit, we need to >> 7 AND << 16 ==> << 9
1467 auto op8_lshift_op6 = CreateOpAndInfer<tosa::LogicalLeftShiftOp>(
1468 rewriter, op->getLoc(), int32_logits_type,
1469 op6_table_op4_bits_23_16.getResult(),
1470 getTosaConstTensorSingleI32(rewriter, op, 9));
1471
1472 // For the 3rd 8-bit, we need to >> 7 AND << 8 ==> << 1
1473 auto op9_lshift_op7 = CreateOpAndInfer<tosa::LogicalLeftShiftOp>(
1474 rewriter, op->getLoc(), int32_logits_type,
1475 op7_table_op4_bits_15_8.getResult(),
1476 getTosaConstTensorSingleI32(rewriter, op, 1));
1477
1478 // For the last 8-bit, we only need to >> 7
1479 auto op10_rshift_op8 = CreateOpAndInfer<tosa::ArithmeticRightShiftOp>(
1480 rewriter, op->getLoc(), int32_logits_type,
1481 op8_table_op4_bits_7_0.getResult(),
1482 getTosaConstTensorSingleI32(rewriter, op, 7), true);
1483
1484 // Add together all the 8-bit groups
1485 // Add [1+2]
1486 auto op11_add_op7_op8 = CreateOpAndInfer<tosa::AddOp>(
1487 rewriter, op->getLoc(), int32_logits_type, op7_lshift_op5.getResult(),
1488 op8_lshift_op6.getResult());
1489
1490 // Add [1+2+3]
1491 auto op12_add_op11_op9 = CreateOpAndInfer<tosa::AddOp>(
1492 rewriter, op->getLoc(), int32_logits_type,
1493 op11_add_op7_op8.getResult(), op9_lshift_op7.getResult());
1494
1495 // Add [1+2+3+4]
1496 auto op13_add_op12_op10 = CreateOpAndInfer<tosa::AddOp>(
1497 rewriter, op->getLoc(), int32_logits_type,
1498 op12_add_op11_op9.getResult(), op10_rshift_op8.getResult());
1499
1500 // Step 3. get sum(exp()). output 12.19
1501 auto op14_rshift_op13_12 = CreateOpAndInfer<tosa::ArithmeticRightShiftOp>(
1502 rewriter, op->getLoc(), int32_logits_type,
1503 op13_add_op12_op10.getResult(),
1504 getTosaConstTensorSingleI32(rewriter, op, 12), true);
1505
1506 auto op15_reducesum_op14 = CreateOpAndInfer<tosa::ReduceSumOp>(
1507 rewriter, op->getLoc(), int32_rsum_type,
1508 op14_rshift_op13_12.getResult(),
1509 rewriter.getI64IntegerAttr(input_rank - 1));
1510
1511 // Step 4. calculate reciprocal(sum(exp()))
1512      // CLZ returns the number of leading zeros, which equals headroom + 1
1513 auto op16_clz_op15 =
1514 CreateOpAndInfer<tosa::ClzOp>(rewriter, op->getLoc(), int32_rsum_type,
1515 op15_reducesum_op14.getResult());
1516
1517 // minus one to get headroom
1518 auto op17_sub_op16 = CreateOpAndInfer<tosa::SubOp>(
1519 rewriter, op->getLoc(), int32_rsum_type, op16_clz_op15.getResult(),
1520 getTosaConstTensorSingleI32(rewriter, op, 1));
1521
1522 // Left shift to get s1.30 format
1523 auto op18_lshift_op15_op17 = CreateOpAndInfer<tosa::LogicalLeftShiftOp>(
1524 rewriter, op->getLoc(), int32_rsum_type,
1525 op15_reducesum_op14.getResult(), op17_sub_op16.getResult());
1526
1527 // Step 5. Calculate one_over_one_plus_x() with Newton-Raphson division
1528 // with 3 iterations.
1529 // Need two magic constants 48/17 and -32/17 from Newton-Raphson algorithm
1530 // We need to operate in s2.29 since 48/17 is > 2.0
1531 // Reference: gemmlowp/fixedpoint/fixedpoint.h
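    // Sketch of the math used below: with d = half_denominator in [0.5, 1.0),
    // 1/d is refined by the Newton-Raphson step x_{n+1} = x_n + x_n*(1 - d*x_n),
    // starting from the linear estimate x_0 = 48/17 - (32/17)*d. The constants
    // 1515870810 and -1010580540 are round(48/17 * 2^29) and
    // round(-32/17 * 2^29) in the s2.29 format used here.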
1532 Value half_denominator = op18_lshift_op15_op17.getResult();
1533 Value four = getTosaConstTensorSingleI32(rewriter, op, 4);
1534 Value F2_one = getTosaConstTensorSingleI32(rewriter, op, (1U << 29));
1535 Value constant_48_over_17 =
1536 getTosaConstTensorSingleI32(rewriter, op, 1515870810);
1537 Value constant_neg_32_over_17 =
1538 getTosaConstTensorSingleI32(rewriter, op, -1010580540);
1539
1540 // F2 x = constant_48_over_17 + half_denominator *
1541 // constant_neg_32_over_17;
1542 auto op19_mul_half_denominator = CreateOpAndInfer<tosa::MulOp>(
1543 rewriter, op->getLoc(), int32_rsum_type, half_denominator,
1544 constant_neg_32_over_17, 31);
1545
1546 auto op20_add_op19 = CreateOpAndInfer<tosa::AddOp>(
1547 rewriter, op->getLoc(), int32_rsum_type,
1548 op19_mul_half_denominator.getResult(), constant_48_over_17);
1549
1550 // Newton-Raphson 3x iteration
1551 Value nr_x = op20_add_op19.getResult();
1552 for (int i = 0; i < 3; i++) {
1553 // half_denominator_times_x =
1554 // SaturatingRoundingDoublingHighMul(half_denominator, x)
1555 auto op21_mul_x_half_denominator = CreateOpAndInfer<tosa::MulOp>(
1556 rewriter, op->getLoc(), int32_rsum_type, nr_x, half_denominator,
1557 31);
1558
1559 // F2 one_minus_half_denominator_times_x = F2::One() -
1560 // half_denominator_times_x
1561 auto op22_sub_one_op21 = CreateOpAndInfer<tosa::SubOp>(
1562 rewriter, op->getLoc(), int32_rsum_type, F2_one,
1563 op21_mul_x_half_denominator.getResult());
1564
1565 // SaturatingRoundingDoublingHighMul(x,
1566 // one_minus_half_denominator_times_x)
1567 auto op23_mul_x_op22 = CreateOpAndInfer<tosa::MulOp>(
1568 rewriter, op->getLoc(), int32_rsum_type, nr_x,
1569 op22_sub_one_op21.getResult(), 31);
1570
1571 // x + Rescale<2>(x * one_minus_half_denominator_times_x)
1572 auto op24_mul_op23_four = CreateOpAndInfer<tosa::MulOp>(
1573 rewriter, op->getLoc(), int32_rsum_type,
1574 op23_mul_x_op22.getResult(), four, 0);
1575
1576 auto op25_add_x_op24 = CreateOpAndInfer<tosa::AddOp>(
1577 rewriter, op->getLoc(), int32_rsum_type, nr_x,
1578 op24_mul_op23_four.getResult());
1579
1580 nr_x = op25_add_x_op24.getResult();
1581 }
1582
1583 // Step 6. multiply exp(x) with 1 / sum(exp(x))
1584 // combined with Rescale<0>(ExactMulByPot<-1>(x))
1585 // so shift 30 instead of 31
1586 auto op26_mul_op13_x = CreateOpAndInfer<tosa::MulOp>(
1587 rewriter, op->getLoc(), int32_logits_type,
1588 op13_add_op12_op10.getResult(), nr_x, 31 - 1);
1589
1590      // Right shift amount is
1591      // num_bits_over_unit + 31 - (sizeof(OutputT) * 8) =
1592      // (12 - headroom_plus_one) + 31 - 8 =
1593      // (12 + 31 - 8) - headroom_plus_one
1594 auto op27_sub_op16 = CreateOpAndInfer<tosa::SubOp>(
1595 rewriter, op->getLoc(), int32_rsum_type,
1596 getTosaConstTensorSingleI32(rewriter, op, 12 + 31 - 8),
1597 op16_clz_op15.getResult());
1598
1599 auto op28_rshift_op26_op27 =
1600 CreateOpAndInfer<tosa::ArithmeticRightShiftOp>(
1601 rewriter, op->getLoc(), int32_logits_type,
1602 op26_mul_op13_x.getResult(), op27_sub_op16.getResult(), true);
1603
1604 return buildRescale(rewriter, op, output_type,
1605 op28_rshift_op26_op27.getResult(), 1.0, 0,
1606 out_quant_type.getZeroPoint(), false, true);
1607
1608 } else if (in_quant_type.getStorageTypeIntegralWidth() == 16) {
1609 // Step 1. get x - max(x)
1610 Value op1_rescale_in =
1611 buildRescale(rewriter, op, int32_logits_type, logits_value, 1.0f,
1612 in_quant_type.getZeroPoint(), 0, false, true);
1613
1614 auto op2_reducemax_op1 = CreateOpAndInfer<tosa::ReduceMaxOp>(
1615 rewriter, op->getLoc(), int32_rsum_type, op1_rescale_in,
1616 rewriter.getI64IntegerAttr(input_rank - 1));
1617
1618 // output range is [-65535, 0]
1619 auto op3_sub_op1_op2 = CreateOpAndInfer<tosa::SubOp>(
1620 rewriter, op->getLoc(), int32_logits_type, op1_rescale_in,
1621 op2_reducemax_op1.getResult());
1622
1623 auto exp_func = [](double x) -> double { return std::exp(x); };
1624
1625 // Follow TFLite reference: tensorflow/lite/kernels/activations.cc
1626 Value exp_table_const =
1627 getTosaConst16bitTable(rewriter, op, exp_func, -10.0, 0);
1628
1629 double input_diff_scale = in_quant_type.getScale() / (10.0 / 65535.0);
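    // The exp table above maps the real input range [-10.0, 0.0] onto the full
    // int16 domain, so one int16 step corresponds to 10/65535 of real input.
    // input_diff_scale converts the raw (x - max(x)) difference, whose step is
    // the input scale, into those units before the shift by 32767 below.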
1630
1631 // Step 2. rescale input from [-65535, 0] to [-32768, 32767] for LUT input
1632 Value op4_rescale_op3 = buildRescale(
1633 rewriter, op, int32_logits_type, op3_sub_op1_op2.getResult(),
1634 /*scale=*/input_diff_scale, /*input_zp=*/0, /*output_zp=*/0,
1635 /*double_round=*/true, /*scale32=*/true);
1636 auto op5_add_op4 = CreateOpAndInfer<tosa::AddOp>(
1637 rewriter, op->getLoc(), int32_logits_type, op4_rescale_op3,
1638 getTosaConstTensorSingleI32(rewriter, op, 32767));
1639
1640 auto op6_cast_op5 = CreateOpAndInfer<tosa::CastOp>(
1641 rewriter, op->getLoc(), int16_logits_type, op5_add_op4.getResult());
1642
1643 // Step 3. get exp() result
1644 // Output is 15.7.
1645 // In 8-bit case, no interpolation here, since input should be right on
1646 // table entry.
1647 auto op7_table_op6 = CreateOpAndInfer<tosa::TableOp>(
1648 rewriter, op->getLoc(), int32_logits_type, op6_cast_op5,
1649 exp_table_const);
1650
1651 // Right shift 7 bits. output 15. Shouldn't lose any precision since last
1652 // 7 bits should be all 0.
1653 auto op8_rshift_op7 = CreateOpAndInfer<tosa::ArithmeticRightShiftOp>(
1654 rewriter, op->getLoc(), int32_logits_type, op7_table_op6.getResult(),
1655 getTosaConstTensorSingleI32(rewriter, op, 7), true);
1656
1657 // Step 4. get sum(exp()). output 16.15
1658 auto op9_reducesum_op8 = CreateOpAndInfer<tosa::ReduceSumOp>(
1659 rewriter, op->getLoc(), int32_rsum_type, op8_rshift_op7.getResult(),
1660 rewriter.getI64IntegerAttr(input_rank - 1));
1661
1662 // Step 5. calculate reciprocal(sum(exp()))
1663      // CLZ returns the number of leading zeros: 32 - position of first set bit
1664 auto op10_clz_op9 =
1665 CreateOpAndInfer<tosa::ClzOp>(rewriter, op->getLoc(), int32_rsum_type,
1666 op9_reducesum_op8.getResult());
1667
1668 auto op11_sub_op10 = CreateOpAndInfer<tosa::SubOp>(
1669 rewriter, op->getLoc(), int32_rsum_type, op10_clz_op9.getResult(),
1670 getTosaConstTensorSingleI32(rewriter, op, 1));
1671
1672 // Left shift to get 1.30 format
1673 auto op12_lshift_op9_op11 = CreateOpAndInfer<tosa::LogicalLeftShiftOp>(
1674 rewriter, op->getLoc(), int32_rsum_type,
1675 op9_reducesum_op8.getResult(), op11_sub_op10.getResult());
1676
1677 // Subtract (1 << 30) to make 0 <= x <= 1 under 0.30 format
1678 auto op13_sub_op12 = CreateOpAndInfer<tosa::SubOp>(
1679 rewriter, op->getLoc(), int32_rsum_type,
1680 op12_lshift_op9_op11.getResult(),
1681 getTosaConstTensorSingleI32(rewriter, op, (1u << 30)));
1682
1683 // Right shift 14 bits to get output range [0, 65535]
1684 auto op14_rshift_op13 = CreateOpAndInfer<tosa::ArithmeticRightShiftOp>(
1685 rewriter, op->getLoc(), int32_rsum_type, op13_sub_op12.getResult(),
1686 getTosaConstTensorSingleI32(rewriter, op, 14), true);
1687
1688 // Remap input to [-32768, 32767] for LUT input
1689      auto op15_sub_op14 = CreateOpAndInfer<tosa::SubOp>(
1690          rewriter, op->getLoc(), int32_rsum_type, op14_rshift_op13.getResult(),
1691          getTosaConstTensorSingleI32(rewriter, op, 32768));
1692      auto op16_cast_op15 = CreateOpAndInfer<tosa::CastOp>(
1693          rewriter, op->getLoc(), int16_rsum_type, op15_sub_op14.getResult());
1694
1695 // Generate table for 1 / (1 + x), for 0 <= x <= 1
1696 auto one_over_one_plus_x_func = [](double x) -> double {
1697 return 1.0 / (1.0 + x);
1698 };
1699
1700 Value one_over_one_plus_x_table_const = getTosaConst16bitTable(
1701 rewriter, op, one_over_one_plus_x_func, 0.0, 1.0);
1702
1703 // Get (1 / sum(exp(x))) result as 23 bits (including sign bit)
1704 auto op17_table_op16 = CreateOpAndInfer<tosa::TableOp>(
1705 rewriter, op->getLoc(), int32_rsum_type, op16_cast_op15,
1706 one_over_one_plus_x_table_const);
1707
1708 // Right shift 7 bits back to 0.15
1709 auto op18_rshift_op17 = CreateOpAndInfer<tosa::ArithmeticRightShiftOp>(
1710 rewriter, op->getLoc(), int32_rsum_type, op17_table_op16.getResult(),
1711 getTosaConstTensorSingleI32(rewriter, op, 7), true);
1712
1713 // Step 6. multiply exp(max-x) with 1 / sum(exp(max-x))
1714 // lhs: 0.15, rhs: 0.15, output: 0.30
1715 auto op19_mul_op18_op8 = CreateOpAndInfer<tosa::MulOp>(
1716 rewriter, op->getLoc(), int32_logits_type, op18_rshift_op17,
1717 op8_rshift_op7, 0);
1718
1719 auto op20_sub_op10 = CreateOpAndInfer<tosa::SubOp>(
1720 rewriter, op->getLoc(), int32_rsum_type,
1721 getTosaConstTensorSingleI32(rewriter, op, 31),
1722 op10_clz_op9.getResult());
1723
1724 // Apply the clz back, we get 0.15 output
1725 // [0, 32767] corresponding to [0.0, 1.0]
1726 auto op21_rshift_op19_op20 =
1727 CreateOpAndInfer<tosa::ArithmeticRightShiftOp>(
1728 rewriter, op->getLoc(), int32_logits_type,
1729 op19_mul_op18_op8.getResult(), op20_sub_op10.getResult(), true);
1730
1731 return buildRescale(rewriter, op, output_type,
1732 op21_rshift_op19_op20.getResult(),
1733 (1.0 / out_quant_type.getScale()) * (1.0 / 32768.0),
1734 0, out_quant_type.getZeroPoint(), false, true);
1735 } else {
1736 op->emitOpError("Softmax: unknown quantization bitwidth");
1737 return llvm::None;
1738 }
1739 } else {
1740 SmallVector<int64_t> rsum_shape_v(input_type.getShape().begin(),
1741 input_type.getShape().end());
1742 rsum_shape_v[input_rank - 1] = 1;
1743 ArrayRef<int64_t> rsum_shape(rsum_shape_v);
1744
1745     // Floating-point lowering is more direct:
1746 //
1747 // op1 = exp(logits)
1748 // op2 = reduce_sum(op1, -1)
1749 // op3 = reciprocal(op2)
1750 // op4 = mul(op1, op3)
1751 auto op1_exp_in = CreateOpAndInfer<tosa::ExpOp>(rewriter, op->getLoc(),
1752 output_type, logits_value);
1753 RankedTensorType rsum_type =
1754 RankedTensorType::get(rsum_shape, output_type.getElementType());
1755
1756 // Keep dims so we don't need to reshape later
1757 auto op2_reducesum_op1 = CreateOpAndInfer<tosa::ReduceSumOp>(
1758 rewriter, op->getLoc(), rsum_type, op1_exp_in.getResult(),
1759 rewriter.getI64IntegerAttr(input_rank - 1));
1760 auto op3_reciprocal_op2 = CreateOpAndInfer<tosa::ReciprocalOp>(
1761 rewriter, op->getLoc(), op2_reducesum_op1.getType(),
1762 op2_reducesum_op1.getResult());
1763
1764 return CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
1765 op1_exp_in.getResult(),
1766 op3_reciprocal_op2.getResult(), 0)
1767 .getResult();
1768 }
1769 }
1770
1771 // Lowers LogSoftmax to a sequence of TOSA ops.
1772 llvm::Optional<Value> convertLogSoftmaxOp(PatternRewriter& rewriter,
1773 Operation* op, Value result_value,
1774 Value logits_value) {
1775 // log_softmax = log(exp(logits) / reduce_sum(exp(logits), -1))
1776 // op1 = exp(logits)
1777 // op2 = reduce_sum(op1, -1)
1778 // op3 = reciprocal(op2)
1779 // op4 = mul(op1, op3)
1780 // op5 = log(op4)
1781
1782 TensorType output_type = result_value.getType().dyn_cast<TensorType>();
1783 // Not a tensor output
1784 if (!output_type) {
1785 op->emitOpError("LogSoftmax: output type not tensor.");
1786 return llvm::None;
1787 }
1788
1789 RankedTensorType input_type =
1790 op->getOperand(0).getType().dyn_cast<RankedTensorType>();
1791 if (!input_type) {
1792 op->emitOpError("LogSoftmax: input type not ranked tensor.");
1793 return llvm::None;
1794 }
1795
1796 mlir::quant::UniformQuantizedType in_quant_type =
1797 input_type.getElementType()
1798 .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
1799 mlir::quant::UniformQuantizedType out_quant_type =
1800 output_type.getElementType()
1801 .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
1802 if (in_quant_type || out_quant_type) {
1803 op->emitOpError("Quantized log_softmax lowering not implemented yet");
1804 return llvm::None;
1805 }
1806
1807 auto op1_exp_in = CreateOpAndInfer<tosa::ExpOp>(rewriter, op->getLoc(),
1808 output_type, logits_value);
1809
1810 // reduce_sum on last dimension
1811 int32_t input_rank = input_type.getShape().size();
1812 // Keep dims so we don't need to reshape later
1813 auto op2_reducesum_op1 = CreateOpAndInfer<tosa::ReduceSumOp>(
1814 rewriter, op->getLoc(),
1815 UnrankedTensorType::get(output_type.getElementType()),
1816 op1_exp_in.getResult(), rewriter.getI64IntegerAttr(input_rank - 1));
1817 auto op3_reciprocal_op2 = CreateOpAndInfer<tosa::ReciprocalOp>(
1818 rewriter, op->getLoc(), op2_reducesum_op1.getType(),
1819 op2_reducesum_op1.getResult());
1820
1821 auto op4_mul_op1_op3 = CreateOpAndInfer<tosa::MulOp>(
1822 rewriter, op->getLoc(), output_type, op1_exp_in.getResult(),
1823 op3_reciprocal_op2.getResult(), 0);
1824
1825 return CreateOpAndInfer<tosa::LogOp>(rewriter, op->getLoc(), output_type,
1826 op4_mul_op1_op3.getResult())
1827 .getResult();
1828 }
1829
1830 // Lowers SpaceToDepth to a sequence of TOSA ops. Supports NHWC.
1831 llvm::Optional<Value> convertSpaceToDepthOp(PatternRewriter& rewriter,
1832 Operation* op, Value result_value,
1833 Value input_value,
1834 IntegerAttr block_size_attr,
1835 StringAttr data_format) {
1836 // NHWC lowering version:
1837 // a2 = tf.reshape(a, [orig_shape[0], orig_shape[1]//b, b, orig_shape[2]//b,
1838 // b, orig_shape[3]])
1839 // a3 = tf.transpose(a2, [0, 1, 3, 2, 4, 5])
1840 // a4 = tf.reshape(a3, [orig_shape[0], orig_shape[1]//b, orig_shape[2]//b,
1841 // orig_shape[3]*b*b])
1842 // return a4
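  // Worked example with hypothetical shapes: for a = [1, 4, 4, 3] and b = 2,
  // a2 = [1, 2, 2, 2, 2, 3], a3 reorders the axes with [0, 1, 3, 2, 4, 5], and
  // a4 = [1, 2, 2, 12].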
1843 RankedTensorType output_type =
1844 result_value.getType().dyn_cast<RankedTensorType>();
1845
1846 // Not a ranked tensor output.
1847 if (!output_type) {
1848 op->emitOpError("SpaceToDepth: output type not ranked tensor.");
1849 return llvm::None;
1850 }
1851
1852 RankedTensorType input_type =
1853 input_value.getType().dyn_cast<RankedTensorType>();
1854 if (!input_type) {
1855 op->emitOpError("SpaceToDepth: input type not ranked tensor.");
1856 return llvm::None;
1857 }
1858
1859 if (input_type.getRank() != 4) {
1860 op->emitOpError("SpaceToDepth: input rank not 4.");
1861 return llvm::None;
1862 }
1863
1864 auto input_shape = input_type.getShape();
1865
1866 if (!block_size_attr) { // This is a required parameter
1867 op->emitOpError("SpaceToDepth: block size attribute not set.");
1868 return llvm::None;
1869 }
1870
1871 SmallVector<int64_t, 2> block_size;
1872 block_size.assign(2, block_size_attr.getInt());
1873
1874 if (!data_format) data_format = rewriter.getStringAttr("NHWC");
1875
1876 if (data_format.getValue().str() != "NHWC") {
1877 op->emitOpError("SpaceToDepth: data format not NHWC.");
1878 return llvm::None;
1879 }
1880
1881 assert(block_size[0] * block_size[1] != 0);
1882
1883 SmallVector<int64_t, 6> a_reshape_dims;
1884 a_reshape_dims.push_back(input_shape[0]);
1885 a_reshape_dims.push_back(input_shape[1] / block_size[0]);
1886 a_reshape_dims.push_back(block_size[0]);
1887 a_reshape_dims.push_back(input_shape[2] / block_size[1]);
1888 a_reshape_dims.push_back(block_size[1]);
1889 a_reshape_dims.push_back(input_shape[3]);
1890
1891 RankedTensorType a_reshape_output_type =
1892 RankedTensorType::get(a_reshape_dims, output_type.getElementType());
1893 auto a2_reshape_a_op = CreateOpAndInfer<tosa::ReshapeOp>(
1894 rewriter, op->getLoc(), a_reshape_output_type, input_value,
1895 rewriter.getI64ArrayAttr(a_reshape_dims));
1896
1897 llvm::Optional<Value> a3_transpose_perm = getConstTensor<int32_t>(
1898 rewriter, op, /*vec=*/{0, 1, 3, 2, 4, 5}, /*shape=*/{6});
1899
1900 if (!a3_transpose_perm) return llvm::None;
1901
1902 auto a3_transpose_a2_op = CreateOpAndInfer<tosa::TransposeOp>(
1903 rewriter, op->getLoc(), a_reshape_output_type,
1904 a2_reshape_a_op.getResult(), a3_transpose_perm.getValue());
1905
1906 SmallVector<int64_t, 4> a3_reshape_dims;
1907 a3_reshape_dims.push_back(input_shape[0]);
1908 a3_reshape_dims.push_back(input_shape[1] / block_size[0]);
1909 a3_reshape_dims.push_back(input_shape[2] / block_size[1]);
1910 a3_reshape_dims.push_back(input_shape[3] * block_size[0] * block_size[1]);
1911
1912 RankedTensorType a3_reshape_output_type =
1913 RankedTensorType::get(a3_reshape_dims, output_type.getElementType());
1914 return CreateOpAndInfer<tosa::ReshapeOp>(
1915 rewriter, op->getLoc(), a3_reshape_output_type,
1916 a3_transpose_a2_op.getResult(),
1917 rewriter.getI64ArrayAttr(a3_reshape_dims))
1918 .getResult();
1919 }
1920
1921 // Lowers DepthToSpace to a sequence of TOSA ops. Supports NHWC.
1922 llvm::Optional<Value> convertDepthToSpaceOp(PatternRewriter& rewriter,
1923 Operation* op, Value result_value,
1924 Value input_value,
1925 IntegerAttr block_size_attr,
1926 StringAttr data_format) {
1927 // NHWC version
1928 // a2 = tf.reshape(a, [orig_shape[0], orig_shape[1], orig_shape[2], b, b,
1929 // orig_shape[3] // (b*b)])
1930 // a3 = tf.transpose(a2, [0, 1, 3, 2, 4, 5])
1931 // a4 = tf.reshape(a3, [orig_shape[0], orig_shape[1] * b, orig_shape[2] * b,
1932 // orig_shape[3] // (b*b)])
1933 // return a4
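  // Worked example with hypothetical shapes: for a = [1, 2, 2, 12] and b = 2,
  // a2 = [1, 2, 2, 2, 2, 3], a3 reorders the axes with [0, 1, 3, 2, 4, 5], and
  // a4 = [1, 4, 4, 3].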
1934
1935 RankedTensorType output_type =
1936 result_value.getType().dyn_cast<RankedTensorType>();
1937
1938 // Not a ranked tensor output
1939 if (!output_type) {
1940 op->emitOpError("DepthToSpace: output type not ranked tensor.");
1941 return llvm::None;
1942 }
1943
1944 RankedTensorType input_type =
1945 input_value.getType().dyn_cast<RankedTensorType>();
1946 if (!input_type) {
1947 op->emitOpError("DepthToSpace: input type not ranked tensor.");
1948 return llvm::None;
1949 }
1950
1951 if (input_type.getRank() != 4) return llvm::None;
1952 auto input_shape = input_type.getShape();
1953
1954 if (!block_size_attr) { // This is a required parameter
1955 op->emitOpError("DepthToSpace: block size attribute not set.");
1956 return llvm::None;
1957 }
1958
1959 SmallVector<int64_t, 2> block_size;
1960 block_size.assign(2, block_size_attr.getInt());
1961
1962 if (!data_format) data_format = rewriter.getStringAttr("NHWC");
1963 if (data_format.getValue().str() != "NHWC") {
1964 op->emitOpError("DepthToSpace: data format not NHWC.");
1965 return llvm::None;
1966 }
1967
1968 assert(block_size[0] * block_size[1] != 0);
1969
1970 SmallVector<int64_t, 6> a_reshape_dims;
1971 a_reshape_dims.push_back(input_shape[0]);
1972 a_reshape_dims.push_back(input_shape[1]);
1973 a_reshape_dims.push_back(input_shape[2]);
1974 a_reshape_dims.push_back(block_size[0]);
1975 a_reshape_dims.push_back(block_size[1]);
1976 a_reshape_dims.push_back(input_shape[3] / (block_size[0] * block_size[1]));
1977
1978 RankedTensorType a_reshape_output_type =
1979 RankedTensorType::get(a_reshape_dims, output_type.getElementType());
1980 auto a2_reshape_a_op = CreateOpAndInfer<tosa::ReshapeOp>(
1981 rewriter, op->getLoc(), a_reshape_output_type, input_value,
1982 rewriter.getI64ArrayAttr(a_reshape_dims));
1983
1984 llvm::Optional<Value> a3_transpose_perm = getConstTensor<int32_t>(
1985 rewriter, op, /*vec=*/{0, 1, 3, 2, 4, 5}, /*shape=*/{6});
1986
1987 if (!a3_transpose_perm) return llvm::None;
1988
1989 auto a3_transpose_a2_op = CreateOpAndInfer<tosa::TransposeOp>(
1990 rewriter, op->getLoc(), a_reshape_output_type,
1991 a2_reshape_a_op.getResult(), a3_transpose_perm.getValue());
1992
1993 SmallVector<int64_t, 4> a3_reshape_dims;
1994 a3_reshape_dims.push_back(input_shape[0]);
1995 a3_reshape_dims.push_back(input_shape[1] * block_size[0]);
1996 a3_reshape_dims.push_back(input_shape[2] * block_size[1]);
1997 a3_reshape_dims.push_back(input_shape[3] / (block_size[0] * block_size[1]));
1998
1999 RankedTensorType a3_reshape_output_type =
2000 RankedTensorType::get(a3_reshape_dims, output_type.getElementType());
2001 return CreateOpAndInfer<tosa::ReshapeOp>(
2002 rewriter, op->getLoc(), a3_reshape_output_type,
2003 a3_transpose_a2_op.getResult(),
2004 rewriter.getI64ArrayAttr(a3_reshape_dims))
2005 .getResult();
2006 }
2007
2008 // Lowers Split to a sequence of TOSA ops.
2009 llvm::Optional<SmallVector<Value>> convertSplitOp(
2010 PatternRewriter& rewriter, Operation* op, Value result_value,
2011 Value input_value, int32_t num_split, int32_t axis) {
2012 // This lowering creates num_split slice ops and ties them together
2013 // with IdentityN to get from an array of Operations to a single Operation
2014 // with a list of result tensors.
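  // For example, with hypothetical shapes, splitting a [6, 4] tensor into
  // num_split = 3 along axis 0 produces three [2, 4] slices whose begin
  // offsets along axis 0 are 0, 2 and 4.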
2015 RankedTensorType result_type =
2016 result_value.getType().dyn_cast<RankedTensorType>();
2017 // Not a ranked tensor output
2018 if (!result_type) {
2019 op->emitOpError("Split: output type not ranked tensor.");
2020 return llvm::None;
2021 }
2022
2023 RankedTensorType input_type =
2024 input_value.getType().dyn_cast<RankedTensorType>();
2025 if (!input_type) {
2026 op->emitOpError("Split: input type not ranked tensor.");
2027 return llvm::None;
2028 }
2029
2030 Type etype = input_type.getElementType();
2031
2032 if (axis < 0) axis += input_type.getRank();
2033 assert(num_split > 0);
2034 assert(axis >= 0 && axis < input_type.getRank());
2035
2036 Value slice_value = input_value;
2037 bool is_dyn_split = input_type.getDimSize(axis) == -1;
2038 if (is_dyn_split) {
2039 SmallVector<int64_t> new_shape;
2040 for (int i = 0, s = input_type.getRank(); i < s; i++) {
2041 if (i != axis) {
2042 new_shape.push_back(input_type.getDimSize(i));
2043 continue;
2044 }
2045
2046 new_shape.push_back(num_split);
2047 new_shape.push_back(-1);
2048 }
2049 slice_value = CreateOpAndInfer<tosa::ReshapeOp>(
2050 rewriter, op->getLoc(), RankedTensorType::get(new_shape, etype),
2051 input_value, rewriter.getI64ArrayAttr(new_shape));
2052 }
2053
2054 RankedTensorType slice_type = slice_value.getType().cast<RankedTensorType>();
2055 assert((slice_type.getDimSize(axis) % num_split) == 0);
2056
2057 // Each slice has a different beginning point.
2058    // The slice size is the same for each op.
2059 SmallVector<int64_t> begin_vals, size_vals;
2060 for (int j = 0, s = slice_type.getRank(); j < s; j++) {
2061 begin_vals.push_back(0);
2062 size_vals.push_back(slice_type.getDimSize(j));
2063 }
2064 size_vals[axis] = size_vals[axis] / num_split;
2065
2066 SmallVector<Value> results_vec;
2067 for (int i = 0; i < num_split; i++) {
2068 begin_vals[axis] = i * size_vals[axis];
2069 ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
2070 ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);
2071
2072 Value result = CreateOpAndInfer<tosa::SliceOp>(
2073 rewriter, op->getLoc(),
2074 RankedTensorType::get(size_vals, result_type.getElementType()),
2075 slice_value, begin, size);
2076
2077 if (is_dyn_split) {
2078 SmallVector<int64_t> out_reshape_shape;
2079 for (int i = 0, s = size_vals.size(); i < s; i++)
2080 if (i != axis) out_reshape_shape.push_back(size_vals[i]);
2081
2082 result = CreateOpAndInfer<tosa::ReshapeOp>(
2083 rewriter, op->getLoc(),
2084 RankedTensorType::get(out_reshape_shape, etype), result,
2085 rewriter.getI64ArrayAttr(out_reshape_shape))
2086 .getResult();
2087 }
2088
2089 results_vec.push_back(result);
2090 }
2091
2092 return results_vec;
2093 }
2094
2095 // Lowers SplitV to a sequence of TOSA ops.
2096 llvm::Optional<SmallVector<Value>> convertSplitVOp(
2097 PatternRewriter& rewriter, Operation* op, Value result_value,
2098 Value input_value, SmallVectorImpl<int32_t>& size_split, int32_t axis) {
2099 // This lowering creates num_split slice ops and ties them together
2100 // with IdentityN to get from an array of Operations to a single Operation
2101 // with a list of result tensors.
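  // For example, with hypothetical sizes, size_split = {1, 4, 3} on axis 1 of
  // a [2, 8] tensor yields slices of shape [2, 1], [2, 4] and [2, 3] starting
  // at offsets 0, 1 and 5 along that axis.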
2102 RankedTensorType result_type =
2103 result_value.getType().dyn_cast<RankedTensorType>();
2104 // Not a ranked tensor output
2105 if (!result_type) {
2106 op->emitOpError("SplitV: output type not ranked tensor.");
2107 return llvm::None;
2108 }
2109
2110 RankedTensorType input_type =
2111 input_value.getType().dyn_cast<RankedTensorType>();
2112 if (!input_type) {
2113 op->emitOpError("SplitV: input type not ranked tensor.");
2114 return llvm::None;
2115 }
2116
2117 auto input_shape = input_type.getShape();
2118
2119 SmallVector<Value> results_vec;
2120
2121 if ((axis >= 0 && axis >= input_shape.size()) ||
2122 (axis < 0 && axis < -input_shape.size())) {
2123 op->emitOpError("SplitV: invalid axis value.");
2124 return llvm::None;
2125 }
2126 axis = adjust_axis(axis, input_shape.size());
2127 int32_t size_split_sum = 0;
2128 for (int i = 0; i < size_split.size(); i++) {
2129 size_split_sum += size_split[i];
2130 }
2131
2132 // The split sizes must sum up to the size of the axis being split
2133 assert(size_split_sum == input_shape[axis]);
2134
2135 int32_t curr_split_start = 0;
2136 for (int i = 0; i < size_split.size(); i++) {
2137 // Each slice has a different beginning point.
2138 // The slice size is different for each op.
2139 SmallVector<int64_t> begin_vals, size_vals;
2140
2141 for (int j = 0; j < input_shape.size(); j++) {
2142 if (j == axis) {
2143 begin_vals.push_back(curr_split_start);
2144 size_vals.push_back(size_split[i]);
2145 } else {
2146 begin_vals.push_back(0);
2147 size_vals.push_back(input_shape[j]);
2148 }
2149 }
2150
2151 ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
2152 ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);
2153
2154 auto slice_op = CreateOpAndInfer<tosa::SliceOp>(
2155 rewriter, op->getLoc(),
2156 RankedTensorType::get(size_vals, result_type.getElementType()),
2157 input_value, begin, size);
2158
2159 results_vec.push_back(slice_op.getResult());
2160
2161 // Next start position
2162 curr_split_start += size_split[i];
2163 }
2164
2165 return results_vec;
2166 }
2167
2168 // Helper function to reverse negative striding. Only checks for -1 as that is
2169 // the only legal negative stride.
2170 static Value reverseNegativeStride(PatternRewriter& rewriter, Operation* op,
2171 Value input, ArrayRef<int32_t> strides) {
2172 Type reverse_ty = UnrankedTensorType::get(
2173 input.getType().cast<ShapedType>().getElementType());
2174 for (auto it : llvm::enumerate(strides)) {
2175 auto axis = it.index();
2176 auto stride = it.value();
2177 if (stride != -1) continue;
2178
2179 input = CreateOpAndInfer<tosa::ReverseOp>(rewriter, op->getLoc(),
2180 reverse_ty, input,
2181 rewriter.getI64IntegerAttr(axis))
2182 .getResult();
2183 }
2184
2185 return input;
2186 }
2187
2188 // Lowers StridedSlice to a sequence of TOSA ops.
2189 llvm::Optional<Value> convertStridedSliceOp(
2190 PatternRewriter& rewriter, Operation* op, Value result_value,
2191 Value input_value, Value begin_value, Value end_value, Value strides_value,
2192 int32_t begin_mask, int32_t end_mask, int32_t ellipsis_mask,
2193 int32_t new_axis_mask, int32_t shrink_axis_mask) {
2194 // The mask arguments are bitmasks where bit [i] applies to
2195 // dimension [i] of the input tensor.
2196 //
2197 // The rough algorithm for lowering strided slice is as follows:
2198 //
2199 // 0. Process begin/end masks, since they are basically syntactic sugar
2200 // on top of the begin_value/end_value arrays
2201 //
2202 // 1. Slice1: Ignoring stride, slice the interesting range from the input
2203 // tensor
2204 //
2205 // 2. Reshape2: Reshape the tensor from (1) such that each dimension with
2206 // stride is split into two dimensions of size_i/stride_i, stride_i. A naive
2207 // implementation doubles the input tensor rank, but only dimensions being
2208 // strided actually need to be doubled.
2209 //
2210 // 3. Slice3: Slice the tensor from (2) such that we select index [0] from
2211 // each of the stride_i dimensions in (2)
2212 //
2213 // 4. Reshape4: Reshape the tensor to eliminate the stride_i dimensions, add
2214 // any dimensions in new_axis_mask and remove any dimensions in the
2215 // shrink_axis_mask
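  //
  // Worked example with hypothetical values: input shape [10], begin = [2],
  // end = [10], strides = [2]. Slice1 takes begin 2 and size 8, Reshape2
  // produces shape [4, 2], Slice3 keeps index 0 of each stride dimension
  // giving [4, 1], and Reshape4 yields [4], i.e. the elements at input
  // indices 2, 4, 6 and 8.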
2216
2217    // Limitations:
2218    // * This implementation only supports ellipsis_mask=0 for now.
2219    // * Negative strides are only supported when the stride is -1; these are
2220    //   handled by inserting tosa.Reverse operators.
2221    if (ellipsis_mask != 0) {
2222      (void)rewriter.notifyMatchFailure(op, "ellipsis mask not supported yet");
        return llvm::None;
2223    }
2224
2225 ShapedType input_type = input_value.getType().cast<ShapedType>();
2226 ShapedType result_type = result_value.getType().cast<ShapedType>();
2227
2228 // Extract the begin/end/stride tensors
2229 SmallVector<int32_t> begin, end, strides;
2230
2231 DenseIntElementsAttr strides_attr;
2232
2233 if (!matchPattern(strides_value, m_Constant(&strides_attr))) {
2234 (void)rewriter.notifyMatchFailure(op, "strides is not a constant");
2235 return llvm::None;
2236 }
2237
2238 if (failed(getVectorFromValue32(strides_value, strides))) {
2239 (void)rewriter.notifyMatchFailure(op, "strides isn't a constant");
2240 return llvm::None;
2241 }
2242
2243    // Negative strides other than -1 are not supported.
2244    // Bail out for now (revisit if this proves to be legal).
2245 for (auto stride : strides)
2246 if (stride < -1) return llvm::None;
2247
2248 bool all_strides_one = true;
2249 for (auto stride : strides) all_strides_one &= abs(stride) == 1;
2250
2251 int32_t strides_size = strides_attr.getNumElements();
2252
2253 // If all of the masks are set we can just bypass the entire thing.
2254 const int32_t all_masks_one = (1 << strides_size) - 1;
2255
2256 if (failed(getVectorFromValue32(begin_value, begin)) &&
2257 begin_mask != all_masks_one) {
2258 (void)rewriter.notifyMatchFailure(op, "begin isn't a constant");
2259 return llvm::None;
2260 }
2261
2262 if (failed(getVectorFromValue32(end_value, end)) &&
2263 end_mask != all_masks_one) {
2264 (void)rewriter.notifyMatchFailure(op, "end isn't a constant");
2265 return llvm::None;
2266 }
2267
2268    // If the begin value is 0 we can set the begin_mask instead.
2269 for (auto val : llvm::enumerate(begin)) {
2270 if (val.value() == 0) begin_mask |= (0x1 << val.index());
2271 }
2272
2273 // If end value is -1 we can set the end_mask instead.
2274 for (const auto& val : llvm::enumerate(end)) {
2275 if (val.value() == -1) end_mask |= (0x1 << val.index());
2276 }
2277
2278 if (all_strides_one && begin_mask == all_masks_one &&
2279 end_mask == all_masks_one) {
2280 return reverseNegativeStride(rewriter, op, input_value, strides);
2281 }
2282
2283 if (!input_type.hasRank()) {
2284 return (void)rewriter.notifyMatchFailure(op,
2285 "input type not ranked tensor."),
2286 llvm::None;
2287 }
2288
2289 int32_t input_rank = input_type.getRank();
2290
2291 // If strides is incomplete, pad out to the full size.
2292 while (strides.size() < input_rank) strides.push_back(1);
2293
2294 // Unspecified begins should set the begin mask.
2295 while (begin.size() < input_rank) {
2296 begin_mask = begin_mask | (1 << begin.size());
2297 begin.push_back(0);
2298 }
2299
2300 // Unspecified ends should set the end mask.
2301 while (end.size() < input_rank) {
2302 end_mask = end_mask | (1 << end.size());
2303 end.push_back(-1);
2304 }
2305
2306 auto input_shape = input_type.getShape();
2307
2308 // Step 0: Process the begin/end masks and build the begin/sizes for the
2309 // first slice
2310 SmallVector<int64_t> a1_begin(input_rank), a1_size(input_rank);
2311 for (int i = 0; i < input_rank; ++i) {
2312 // Mask-bit overrides begin/end values.
2313 if (begin_mask & (1 << i)) begin[i] = 0;
2314 if (end_mask & (1 << i)) end[i] = input_shape[i];
2315
2316      if (begin[i] < 0 && ShapedType::isDynamic(input_shape[i])) {
2317 (void)rewriter.notifyMatchFailure(
2318 op, "begin offset is negative on dynamic size.");
2319 return llvm::None;
2320 }
2321
2322 // Wrap around index if begin and end is negative
2323 a1_begin[i] = begin[i];
2324 if (a1_begin[i] < 0) a1_begin[i] += input_shape[i];
2325
2326 if (ShapedType::isDynamic(end[i]) &&
2327 ShapedType::isDynamic(input_shape[i])) {
2328        // Slice using -1 as TOSA's sentinel value.
2329 a1_size[i] = -1;
2330 } else if (end[i] < 0 && ShapedType::isDynamic(input_shape[i])) {
2331 // Other dynamic cases cannot be handled.
2332 (void)rewriter.notifyMatchFailure(
2333 op, "input dim is dynamic and slice end depends on the length.");
2334 return llvm::None;
2335 } else {
2336 a1_size[i] = end[i] - a1_begin[i];
2337 if (end[i] < 0) a1_size[i] += input_shape[i];
2338 }
2339
2340 // Shrink axis mask means we know the size and stride are 1.
2341 if (shrink_axis_mask & (1 << i)) {
2342 a1_size[i] = 1;
2343 strides[i] = 1;
2344 }
2345 }
2346
2347 // Step 1: Slice the input array
2348 auto a1_slice_op = CreateOpAndInfer<tosa::SliceOp>(
2349 rewriter, op->getLoc(),
2350 RankedTensorType::get(a1_size, input_type.getElementType()), input_value,
2351 rewriter.getI64ArrayAttr(a1_begin), rewriter.getI64ArrayAttr(a1_size));
2352
2353 if (all_strides_one) {
2354 auto reversed =
2355 reverseNegativeStride(rewriter, op, a1_slice_op.getResult(), strides);
2356 auto shape = reversed.getType().cast<RankedTensorType>().getShape();
2357
2358 SmallVector<int64_t> new_shape;
2359 for (int i = 0; i < input_rank; ++i) {
2360 if (!(shrink_axis_mask & (1 << i))) {
2361 if (new_axis_mask & (1 << i)) new_shape.push_back(1);
2362 new_shape.push_back((shape[i]));
2363 }
2364 }
2365
2366 return CreateOpAndInfer<tosa::ReshapeOp>(
2367 rewriter, op->getLoc(), result_type, reversed,
2368 rewriter.getI64ArrayAttr(new_shape))
2369 .getResult();
2370 }
2371
2372 // Step 2: reshape the sliced array
2373 SmallVector<int64_t> a2_shape(input_rank * 2);
2374 for (int i = 0; i < input_rank; ++i) {
2375 a2_shape[i * 2 + 0] = a1_size[i] == -1 ? -1 : a1_size[i] / abs(strides[i]);
2376 a2_shape[i * 2 + 1] = abs(strides[i]);
2377 }
2378
2379 auto a2_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
2380 rewriter, op->getLoc(),
2381 RankedTensorType::get(a2_shape, input_type.getElementType()),
2382 a1_slice_op.getResult(), rewriter.getI64ArrayAttr(a2_shape));
2383
2384 // Step 3: take a slice along the strides
2385 SmallVector<int64_t> a3_begin(input_rank * 2), a3_size(input_rank * 2);
2386 for (int i = 0; i < input_rank; ++i) {
2387 a3_begin[i * 2 + 0] = 0;
2388 a3_begin[i * 2 + 1] = 0;
2389
2390 if (shrink_axis_mask & (1 << i)) {
2391 a3_size[i * 2 + 0] = 1;
2392 } else {
2393 a3_size[i * 2 + 0] =
2394 (a1_size[i] == -1) ? -1 : (a1_size[i] / abs(strides[i]));
2395 }
2396 a3_size[i * 2 + 1] = 1;
2397 }
2398
2399 auto a3_slice_op = CreateOpAndInfer<tosa::SliceOp>(
2400 rewriter, op->getLoc(),
2401 RankedTensorType::get(a3_size, input_type.getElementType()),
2402 a2_reshape_op.getResult(), rewriter.getI64ArrayAttr(a3_begin),
2403 rewriter.getI64ArrayAttr(a3_size));
2404
2405 // Step 4: reshape the now-strided tensor
2406 SmallVector<int64_t> a4_shape;
2407 for (int i = 0; i < input_rank; ++i) {
2408 if (!(shrink_axis_mask & (1 << i))) {
2409 if (new_axis_mask & (1 << i)) a4_shape.push_back(1);
2410 a4_shape.push_back(
2411 ((a1_size[i] == -1) ? -1 : (a1_size[i] / abs(strides[i]))));
2412 }
2413 }
2414
2415 auto a4_reshape_op =
2416 CreateOpAndInfer<tosa::ReshapeOp>(rewriter, op->getLoc(), result_type,
2417 a3_slice_op.getResult(),
2418 rewriter.getI64ArrayAttr(a4_shape))
2419 .getResult();
2420
2421 return reverseNegativeStride(rewriter, op, a4_reshape_op, strides);
2422 }
2423
2424 // Lowers FloorDiv to a sequence of TOSA operators.
2425 llvm::Optional<Value> convertFloorDivOp(PatternRewriter& rewriter,
2426 Operation* op, Value result_value,
2427 Value lhs_value, Value rhs_value) {
2428 // FloorDiv lowering:
2429 // floor(1/rhs * lhs)
2430 //
2431 // a1 = reciprocal(rhs);
2432 // a2 = mul(lhs, a1);
2433 // a3 = floor(a2);
2434 // return a3;
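  //
  // For example, with lhs = -7.0 and rhs = 2.0: a1 = 0.5, a2 = -3.5 and
  // a3 = floor(-3.5) = -4.0, i.e. floor-division semantics rather than
  // truncation toward zero.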
2435 ShapedType output_type = result_value.getType().dyn_cast<ShapedType>();
2436 // Not a shaped tensor output
2437 if (!output_type) return llvm::None;
2438
2439 Type element_type = output_type.getElementType();
2440
2441 if (element_type.isa<IntegerType>()) {
2442 return CreateOpAndInfer<tosa::DivOp>(rewriter, op->getLoc(), output_type,
2443 lhs_value, rhs_value)
2444 .getResult();
2445 }
2446
2447 auto a1_reciprocal_rhs_op = CreateOpAndInfer<tosa::ReciprocalOp>(
2448 rewriter, op->getLoc(), rhs_value.getType(), rhs_value);
2449 auto a2_mul_lhs_a1_op = CreateOpAndInfer<tosa::MulOp>(
2450 rewriter, op->getLoc(), output_type, lhs_value,
2451 a1_reciprocal_rhs_op.getResult(), 0);
2452 return CreateOpAndInfer<tosa::FloorOp>(rewriter, op->getLoc(), output_type,
2453 a2_mul_lhs_a1_op.getResult())
2454 .getResult();
2455 }
2456
2457 // Lowers FloorMod to a sequence of TOSA operators.
2458 llvm::Optional<Value> convertFloorModOp(PatternRewriter& rewriter,
2459 Operation* op, Value result_value,
2460 Value lhs_value, Value rhs_value) {
2461 // FloorMod lowering:
2462 // (1/rhs * lhs) - floor(1/rhs * lhs)
2463 // a1 = reciprocal(rhs);
2464 // a2 = mul(lhs, a1);
2465 // a3 = floor(a2);
2466 // a4 = sub(a2, a3);
2467 // return a4;
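  // Tracing the formula above with lhs = -7.0 and rhs = 2.0: a2 = -3.5,
  // a3 = -4.0 and a4 = 0.5, i.e. the non-negative fractional part of the
  // quotient.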
2468
2469 RankedTensorType output_type =
2470 result_value.getType().dyn_cast<RankedTensorType>();
2471 // Not a ranked tensor output
2472 if (!output_type) return llvm::None;
2473
2474 auto a1_reciprocal_rhs_op = CreateOpAndInfer<tosa::ReciprocalOp>(
2475 rewriter, op->getLoc(), rhs_value.getType(), rhs_value);
2476 auto a2_mul_lhs_a1_op = CreateOpAndInfer<tosa::MulOp>(
2477 rewriter, op->getLoc(), output_type, lhs_value,
2478 a1_reciprocal_rhs_op.getResult(), 0);
2479 auto a3_floor_a2_op = CreateOpAndInfer<tosa::FloorOp>(
2480 rewriter, op->getLoc(), output_type, a2_mul_lhs_a1_op.getResult());
2481 return CreateOpAndInfer<tosa::SubOp>(rewriter, op->getLoc(), output_type,
2482 a2_mul_lhs_a1_op.getResult(),
2483 a3_floor_a2_op.getResult())
2484 .getResult();
2485 }
2486
2487 // Lowers FusedActivation to a sequence of TOSA ops.
2488 llvm::Optional<Value> convertFusedActivation(PatternRewriter& rewriter,
2489 Operation* op, Value input_value,
2490 StringAttr fused_activation_fn) {
2491 ShapedType input_type = input_value.getType().dyn_cast<ShapedType>();
2492 if (!input_type) return llvm::None;
2493
2494 bool input_is_qtype =
2495 input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2496
2497 if (input_is_qtype) {
2498      // The output tensor's scale/zp can be made identical to the input's when
2499      // legalizing fused_activation_function, since the activation op is
2500      // generated during legalization.
2501 auto input_qtype =
2502 input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2503
2504 if (fused_activation_fn.getValue() == "NONE") {
2505 return input_value;
2506 } else if (fused_activation_fn.getValue() == "RELU") {
2507 int32_t quantized_0 = input_qtype.getZeroPoint();
2508 int32_t quantized_max = input_qtype.getStorageTypeMax();
2509
2510 auto clamp_op = CreateOpAndInfer<tosa::ClampOp>(
2511 rewriter, op->getLoc(), input_type, input_value,
2512 rewriter.getI64IntegerAttr(quantized_0),
2513 rewriter.getI64IntegerAttr(quantized_max),
2514 rewriter.getF32FloatAttr(0), rewriter.getF32FloatAttr(0));
2515
2516 return clamp_op.getResult();
2517 } else if (fused_activation_fn.getValue() == "RELU6") {
2518 int64_t quantized_0 = input_qtype.getZeroPoint();
2519 int64_t quantized_6 = std::llround((6.0f / input_qtype.getScale()) +
2520 input_qtype.getZeroPoint());
2521 int64_t quantized_min = input_qtype.getStorageTypeMin();
2522 int64_t quantized_max = input_qtype.getStorageTypeMax();
2523
2524 int64_t clamp_min = std::max(quantized_0, quantized_min);
2525 int64_t clamp_max = std::min(quantized_6, quantized_max);
2526
2527 auto clamp_op = CreateOpAndInfer<tosa::ClampOp>(
2528 rewriter, op->getLoc(), input_type, input_value,
2529 rewriter.getI64IntegerAttr(clamp_min),
2530 rewriter.getI64IntegerAttr(clamp_max), rewriter.getF32FloatAttr(0),
2531 rewriter.getF32FloatAttr(0));
2532
2533 return clamp_op.getResult();
2534 } else if (fused_activation_fn.getValue() == "RELU_N1_TO_1") {
2535 int64_t quantized_max = input_qtype.getStorageTypeMax();
2536 int64_t quantized_min = input_qtype.getStorageTypeMin();
2537 int64_t quantized_n1 = std::llround((-1.0f / input_qtype.getScale()) +
2538 input_qtype.getZeroPoint());
2539 int64_t quantized_1 = std::llround((1.0f / input_qtype.getScale()) +
2540 input_qtype.getZeroPoint());
2541
2542 int64_t clamp_min = std::max(quantized_n1, quantized_min);
2543 int64_t clamp_max = std::min(quantized_1, quantized_max);
2544
2545 auto clamp_op = CreateOpAndInfer<tosa::ClampOp>(
2546 rewriter, op->getLoc(), input_type, input_value,
2547 rewriter.getI64IntegerAttr(clamp_min),
2548 rewriter.getI64IntegerAttr(clamp_max), rewriter.getF32FloatAttr(0),
2549 rewriter.getF32FloatAttr(0));
2550
2551 return clamp_op.getResult();
2552 } else {
2553 op->emitWarning("convertFusedActivation: Not implemented yet");
2554 return llvm::None;
2555 }
2556 } else {
2557 if (fused_activation_fn.getValue() == "NONE") {
2558 return input_value;
2559 } else {
2560        // For non-quantized types, only F32 is supported.
2561        if (!input_type.getElementType().isF32()) {
2562          op->emitOpError("convertFusedActivation: only F32 is supported");
2563 return llvm::None;
2564 }
2565
2566 if (fused_activation_fn.getValue() == "RELU") {
2567 return CreateOpAndInfer<tosa::ClampOp>(
2568 rewriter, op->getLoc(), input_type, input_value,
2569 rewriter.getI64IntegerAttr(0),
2570 rewriter.getI64IntegerAttr(
2571 std::numeric_limits<int32_t>::max()),
2572 rewriter.getF32FloatAttr(0.0f),
2573 rewriter.getF32FloatAttr(std::numeric_limits<float>::max()))
2574 .getResult();
2575 } else if (fused_activation_fn.getValue() == "RELU6") {
2576 return CreateOpAndInfer<tosa::ClampOp>(
2577 rewriter, op->getLoc(), input_type, input_value,
2578 rewriter.getI64IntegerAttr(0), rewriter.getI64IntegerAttr(6),
2579 rewriter.getF32FloatAttr(0.0f),
2580 rewriter.getF32FloatAttr(6.0f))
2581 .getResult();
2582 } else if (fused_activation_fn.getValue() == "RELU_N1_TO_1") {
2583 return CreateOpAndInfer<tosa::ClampOp>(rewriter, op->getLoc(),
2584 input_type, input_value,
2585 rewriter.getI64IntegerAttr(-1),
2586 rewriter.getI64IntegerAttr(1),
2587 rewriter.getF32FloatAttr(-1.0),
2588 rewriter.getF32FloatAttr(1.0))
2589 .getResult();
2590 } else if (fused_activation_fn.getValue() == "TANH") {
2591 return CreateOpAndInfer<tosa::TanhOp>(rewriter, op->getLoc(),
2592 input_type, input_value)
2593 .getResult();
2594 } else {
2595 // Unsupported activation type. Bail out.
2596 return llvm::None;
2597 }
2598 }
2599 }
2600
2601 return llvm::None;
2602 }
2603
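// Restructures a reduction over an arbitrary set of axes as a transpose that
// moves the non-reduced axes to the front, a reshape into a 2D
// [non-reduced, reduced] matrix, and a single reduction over axis 1. Used by
// convertReduceOpCommon when the input rank exceeds the 4D limit it handles
// directly.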
2604 template <typename T>
2605 static Value convertGenericReduceOp(PatternRewriter& rewriter, Operation* op,
2606 Value input, Type input_etype,
2607 Type reduce_etype,
2608 ArrayRef<int64_t> input_shape,
2609 ArrayRef<int64_t> axes) {
2610 Location loc = op->getLoc();
2611 int64_t input_rank = input_shape.size();
2612 llvm::SmallVector<int32_t> perms;
2613 llvm::SmallVector<int64_t> reshape_shape;
2614 perms.reserve(input_rank);
2615 reshape_shape.resize(2, 1);
2616
2617 // First insert all non-reduction axes.
2618 for (int i = 0; i < input_rank; i++) {
2619 auto it = std::find(axes.begin(), axes.end(), i);
2620 if (it == axes.end()) {
2621 perms.push_back(i);
2622 reshape_shape[0] *= input_shape[i];
2623 }
2624 }
2625
2626    // Then append all reduction axes.
2627 for (auto axis : axes) {
2628 perms.push_back(axis);
2629 reshape_shape[1] *= input_shape[axis];
2630 }
2631
2632 Value perms_value =
2633 getConstTensor<int32_t>(rewriter, op, perms,
2634 {static_cast<int64_t>(perms.size())})
2635 .getValue();
2636
2637 auto transpose_op = CreateOpAndInfer<tosa::TransposeOp>(
2638 rewriter, loc, UnrankedTensorType::get(rewriter.getI32Type()), input,
2639 perms_value);
2640
2641 auto reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
2642 rewriter, loc, RankedTensorType::get(reshape_shape, input_etype),
2643 transpose_op, rewriter.getI64ArrayAttr(reshape_shape));
2644
2645 return CreateOpAndInfer<T>(rewriter, loc,
2646 UnrankedTensorType::get(reduce_etype), reshape_op,
2647 rewriter.getI64IntegerAttr(1));
2648 }
2649
2650 // Common function for lowering reduce operations to TOSA ops.
2651 template <typename T>
2652 llvm::Optional<Value> convertReduceOpCommon(
2653 PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2654 Value input_value, ElementsAttr axes_elems, Type reduce_element_type,
2655 bool is_quantized, double input_scale, int64_t input_zp,
2656 double output_scale, int64_t output_zp) {
2657 RankedTensorType input_type =
2658 input_value.getType().dyn_cast<RankedTensorType>();
2659 if (!input_type) return llvm::None;
2660
2661 ArrayRef<int64_t> input_shape = input_type.getShape();
2662 ArrayRef<int64_t> output_shape = output_type.getShape();
2663 auto input_rank = input_shape.size();
2664 Location loc = op->getLoc();
2665
2666 if (axes_elems.getNumElements() == 0) {
2667 // No axes means return the original tensor.
2668 auto identity_op = CreateOpAndInfer<tosa::IdentityOp>(
2669 rewriter, loc, output_type, input_value);
2670 return identity_op.getResult();
2671 }
2672
2673 // Handle negative axis and guarantee in increasing order.
2674 llvm::SmallVector<int64_t> axes;
2675 for (auto axis : axes_elems.getValues<IntegerAttr>()) {
2676 auto axis_val = axis.getInt();
2677 if (axis_val < 0) axis_val += input_rank;
2678 axes.push_back(axis_val);
2679 }
2680
2681 // Reduction operations are limited to 4D tensors, restructure to obey this
2682    // restriction. We transpose all reduction axes to the RHS, then reshape
2683 // such that all reduction and non-reduction values are grouped. This forms a
2684 // 2D matrix in all cases.
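  // For example, with a hypothetical [2, 3, 4, 5, 6] input reduced over axes
  // {1, 3}, the transpose uses perms [0, 2, 4, 1, 3], the reshape produces
  // [48, 15], and the reduction over axis 1 leaves [48, 1].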
2685 Value val = input_value;
2686 if (input_rank > 4) {
2687 val = convertGenericReduceOp<T>(rewriter, op, val,
2688 input_type.getElementType(),
2689 reduce_element_type, input_shape, axes);
2690 } else {
2691 // Reduce along each axis
2692 SmallVector<int64_t> shape_vec(input_shape.begin(), input_shape.end());
2693
2694 if (is_quantized) {
2695 val = buildRescaleToInt32(rewriter, op, val, input_scale, input_zp);
2696 }
2697
2698 for (auto axis_val : axes) {
2699 if (axis_val < 0) axis_val += input_rank;
2700 auto axis_attr = rewriter.getI64IntegerAttr(axis_val);
2701
2702 shape_vec[axis_val] = 1;
2703 RankedTensorType reduce_type =
2704 RankedTensorType::get(shape_vec, reduce_element_type);
2705
2706 auto reduce_op = CreateOpAndInfer<T>(rewriter, op->getLoc(), reduce_type,
2707 val, axis_attr);
2708
2709 val = reduce_op.getResult();
2710 }
2711 }
2712
2713 if (is_quantized) {
2714 UnrankedTensorType output_rescale_type =
2715 UnrankedTensorType::get(output_type.getElementType());
2716 val = buildRescale(rewriter, op, output_rescale_type, val, output_scale, 0,
2717 output_zp, false, true);
2718 }
2719
2720 // Squeeze out the reduced axes.
2721 auto reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
2722 rewriter, op->getLoc(), output_type, val,
2723 rewriter.getI64ArrayAttr(output_shape));
2724 return reshape_op.getResult();
2725 }
2726
2727 // Lowers ReduceAll to a sequence of TOSA ops.
2728 llvm::Optional<Value> convertReduceAllOp(PatternRewriter& rewriter,
2729 Operation* op,
2730 RankedTensorType output_type,
2731 Value input_value,
2732 ElementsAttr axes_elems) {
2733 RankedTensorType input_type =
2734 input_value.getType().dyn_cast<RankedTensorType>();
2735 if (!input_type) return llvm::None;
2736
2737 return convertReduceOpCommon<tosa::ReduceAllOp>(
2738 rewriter, op, output_type, input_value, axes_elems,
2739 output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
2740 }
2741
2742 // Lowers ReduceAny to a sequence of TOSA ops.
2743 llvm::Optional<Value> convertReduceAnyOp(PatternRewriter& rewriter,
2744 Operation* op,
2745 RankedTensorType output_type,
2746 Value input_value,
2747 ElementsAttr axes_elems) {
2748 RankedTensorType input_type =
2749 input_value.getType().dyn_cast<RankedTensorType>();
2750 if (!input_type) return llvm::None;
2751
2752 return convertReduceOpCommon<tosa::ReduceAnyOp>(
2753 rewriter, op, output_type, input_value, axes_elems,
2754 output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
2755 }
2756
2757 // Lowers ReduceMin to a sequence of TOSA ops.
2758 llvm::Optional<Value> convertReduceMinOp(PatternRewriter& rewriter,
2759 Operation* op,
2760 RankedTensorType output_type,
2761 Value input_value,
2762 ElementsAttr axes_elems) {
2763 RankedTensorType input_type =
2764 input_value.getType().dyn_cast<RankedTensorType>();
2765 if (!input_type) return llvm::None;
2766
2767 return convertReduceOpCommon<tosa::ReduceMinOp>(
2768 rewriter, op, output_type, input_value, axes_elems,
2769 output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
2770 }
2771
2772 // Lowers ReduceMax to a sequence of TOSA ops.
2773 llvm::Optional<Value> convertReduceMaxOp(PatternRewriter& rewriter,
2774 Operation* op,
2775 RankedTensorType output_type,
2776 Value input_value,
2777 ElementsAttr axes_elems) {
2778 RankedTensorType input_type =
2779 input_value.getType().dyn_cast<RankedTensorType>();
2780 if (!input_type) return llvm::None;
2781
2782 return convertReduceOpCommon<tosa::ReduceMaxOp>(
2783 rewriter, op, output_type, input_value, axes_elems,
2784 output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
2785 }
2786
2787 // Lowers ReduceProd to a sequence of TOSA ops.
2788 llvm::Optional<Value> convertReduceProdOp(PatternRewriter& rewriter,
2789 Operation* op,
2790 RankedTensorType output_type,
2791 Value input_value,
2792 ElementsAttr axes_elems) {
2793 RankedTensorType input_type =
2794 input_value.getType().dyn_cast<RankedTensorType>();
2795 if (!input_type) return llvm::None;
2796
2797 bool input_is_qtype =
2798 input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2799 bool output_is_qtype =
2800 output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2801
2802 if (input_is_qtype || output_is_qtype) {
2803 op->emitOpError(
2804 "ConvertReduceProdOp: input/output tensor should "
2805 "be all floating-point.");
2806 return llvm::None;
2807 }
2808
2809 return convertReduceOpCommon<tosa::ReduceProdOp>(
2810 rewriter, op, output_type, input_value, axes_elems,
2811 output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
2812 }
2813
2814 // Lowers ReduceSum to a sequence of TOSA ops.
2815 llvm::Optional<Value> convertReduceSumOp(PatternRewriter& rewriter,
2816 Operation* op,
2817 RankedTensorType output_type,
2818 Value input_value,
2819 ElementsAttr axes_elems) {
2820 RankedTensorType input_type =
2821 input_value.getType().dyn_cast<RankedTensorType>();
2822 if (!input_type) return llvm::None;
2823
2824 bool input_is_qtype =
2825 input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2826 bool output_is_qtype =
2827 output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2828
2829 if (input_is_qtype != output_is_qtype) {
2830 op->emitOpError(
2831 "ConvertReduceSumOp: input/output tensor should "
2832 "be all quantized or all floating-point.");
2833 return llvm::None;
2834 }
2835
2836 double input_scale = 1.0f;
2837 double output_scale = 1.0f;
2838 int64_t input_zp = 0;
2839 int64_t output_zp = 0;
2840 Type reduce_element_type = input_type.getElementType();
2841
2842 if (input_is_qtype) {
2843 auto input_qtype =
2844 input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2845 auto output_qtype =
2846 output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2847
2848 int32_t input_shift = 20;
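    // The shift only affects intermediate precision: the input rescale
    // multiplies by (2^20 * input_scale) and the output rescale divides by
    // (2^20 * output_scale), so the end-to-end factor is still
    // input_scale / output_scale while the int32 sum keeps extra fractional
    // bits.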
2849
2850 input_scale =
2851 static_cast<double>(1 << input_shift) * input_qtype.getScale();
2852 output_scale =
2853 1.0 / (output_qtype.getScale() * static_cast<double>(1 << input_shift));
2854
2855 input_zp = input_qtype.getZeroPoint();
2856 output_zp = output_qtype.getZeroPoint();
2857 reduce_element_type = rewriter.getI32Type();
2858 }
2859
2860 return convertReduceOpCommon<tosa::ReduceSumOp>(
2861 rewriter, op, output_type, input_value, axes_elems, reduce_element_type,
2862 input_is_qtype, input_scale, input_zp, output_scale, output_zp);
2863 }
2864
2865 // Lowers ReduceMean to a sequence of TOSA ops.
2866 llvm::Optional<Value> convertReduceMeanOp(PatternRewriter& rewriter,
2867 Operation* op,
2868 RankedTensorType output_type,
2869 Value input_value,
2870 ElementsAttr axes_elems) {
2871   // reduce_mean is lowered as follows:
2872 // op1 = reduce_sum(input)
2873 // op2 = mul(op1, 1.0 / num_elements_on_reduced_axis)
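  //
  // For example (values chosen for illustration): reducing a [2, 3, 4] input
  // over axes {1, 2} gives num_elements_on_reduced_axis = 3 * 4 = 12, so the
  // reduce_sum result is scaled by 1.0 / 12.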
2874
2875 RankedTensorType input_type =
2876 input_value.getType().dyn_cast<RankedTensorType>();
2877 if (!input_type) return llvm::None;
2878
2879 bool input_is_qtype =
2880 input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2881 bool output_is_qtype =
2882 output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2883
2884 if (input_is_qtype != output_is_qtype) {
2885 op->emitOpError(
2886         "ConvertReduceMeanOp: input/output tensor should "
2887 "be all quantized or all floating-point.");
2888 return llvm::None;
2889 }
2890
2891   // Non-quantized mean() is only supported for floating-point element types.
2892 if (!input_is_qtype && !output_type.getElementType().isa<mlir::FloatType>()) {
2893 op->emitWarning(
2894 "Failed convertReduceMean: input unquantized type but output element "
2895 "not FloatType!");
2896 return llvm::None;
2897 }
2898
2899 int64_t input_rank = input_type.getRank();
2900 int64_t num_elems_on_reduced_axis = 1;
2901 for (int i = 0; i < axes_elems.getNumElements(); i++) {
2902 int64_t axis_val = axes_elems.getValues<IntegerAttr>()[i].getInt();
2903 if (axis_val < 0) axis_val += input_rank;
2904 num_elems_on_reduced_axis *= input_type.getShape()[axis_val];
2905 }
2906 double div_scale = 1.0 / static_cast<double>(num_elems_on_reduced_axis);
2907
2908 double input_scale = 1.0f;
2909 double output_scale = 1.0f;
2910 int64_t input_zp = 0;
2911 int64_t output_zp = 0;
2912 Type reduce_element_type = input_type.getElementType();
2913
2914 if (input_is_qtype) {
2915 auto input_qtype =
2916 input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2917 auto output_qtype =
2918 output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2919
2920 // Combine 'div_scale' as part of output rescale
2921 output_scale = div_scale * input_qtype.getScale() / output_qtype.getScale();
2922
2923 input_zp = input_qtype.getZeroPoint();
2924 output_zp = output_qtype.getZeroPoint();
2925 reduce_element_type = rewriter.getI32Type();
2926 }
2927
2928 auto val = convertReduceOpCommon<tosa::ReduceSumOp>(
2929 rewriter, op, output_type, input_value, axes_elems, reduce_element_type,
2930 input_is_qtype, input_scale, input_zp, output_scale, output_zp);
2931
2932 if (!val.has_value()) return llvm::None;
2933
2934 if (!input_is_qtype) {
2935 Value div_const = getTosaConstTensorSingleF32(rewriter, op, div_scale);
2936 return CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
2937 val.getValue(), div_const, 0)
2938 .getResult();
2939 }
2940
2941 return val;
2942 }
2943
2944 // Lowers ResizeBilinear and ResizeNearestNeighbor to TOSA resize.
2945 llvm::Optional<Value> convertResizeOp(PatternRewriter& rewriter, Operation* op,
2946 RankedTensorType output_type,
2947 Value input_value, StringRef mode,
2948 bool align_corners,
2949 bool half_pixel_centers) {
2950 RankedTensorType input_type =
2951 input_value.getType().dyn_cast<RankedTensorType>();
2952 if (!input_type) return llvm::None;
2953
2954 if (input_type.getRank() != 4 || output_type.getRank() != 4) {
2955 op->emitOpError("convertResizeOp: input/output must be rank 4");
2956 return llvm::None;
2957 }
2958
2959 bool input_is_qtype =
2960 input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2961 bool output_is_qtype =
2962 output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2963
2964 if (input_is_qtype != output_is_qtype) {
2965 op->emitOpError(
2966 "ConvertResizeOp: input/output tensor should "
2967 "be all quantized or all floating-point.");
2968 return llvm::None;
2969 }
2970
2971 if (!input_is_qtype) {
2972 if (!input_type.getElementType().isa<mlir::FloatType>()) {
2973 op->emitOpError(
2974 "ConvertResizeOp: only quantized or float types supported.");
2975 return llvm::None;
2976 }
2977 }
2978
2979 auto input_shape = input_type.getShape();
2980 auto output_shape = output_type.getShape();
2981
2982 if (input_type.isDynamicDim(1) || input_type.isDynamicDim(2)) {
2983 op->emitOpError("ConvertResizeOp: resize dynamic input not supported.");
2984 return llvm::None;
2985 }
2986
2987 if (output_type.isDynamicDim(1) || output_type.isDynamicDim(2)) {
2988 op->emitOpError("ConvertResizeOp: resize dynamic output not supported.");
2989 return llvm::None;
2990 }
2991
2992 size_t input_height = input_shape[1];
2993 size_t input_width = input_shape[2];
2994 size_t output_height = output_shape[1];
2995 size_t output_width = output_shape[2];
2996
2997 double fp_stride_y =
2998 static_cast<double>(input_height) / static_cast<double>(output_height);
2999 double fp_stride_x =
3000 static_cast<double>(input_width) / static_cast<double>(output_width);
3001 if (align_corners && output_height > 1) {
3002 fp_stride_y = static_cast<double>(input_height - 1) /
3003 static_cast<double>(output_height - 1);
3004 }
3005 if (align_corners && output_width > 1) {
3006 fp_stride_x = static_cast<double>(input_width - 1) /
3007 static_cast<double>(output_width - 1);
3008 }
3009
3010 double fp_offset_y, fp_offset_x;
3011 if (half_pixel_centers) {
3012 fp_offset_y = fp_stride_y * 0.5f - 0.5f;
3013 fp_offset_x = fp_stride_x * 0.5f - 0.5f;
3014 } else {
3015 fp_offset_y = 0.0f;
3016 fp_offset_x = 0.0f;
3017 }
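  // For example (values chosen for illustration): resizing a dimension from 2
  // to 4 with half_pixel_centers = true and align_corners = false gives
  // fp_stride = 2.0 / 4.0 = 0.5 and fp_offset = 0.5 * 0.5 - 0.5 = -0.25.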
3018
3019   // Sampled input coordinate: iy = oy * fp_stride_y + fp_offset_y (same for x).
3020
3021 ArrayAttr output_size =
3022 rewriter.getI64ArrayAttr({static_cast<int64_t>(output_height),
3023 static_cast<int64_t>(output_width)});
3024 StringAttr resize_mode = rewriter.getStringAttr(mode);
3025
3026 if (input_is_qtype) {
3027     // Magic shift number that TFLite resize bilinear uses
3028 // reference: tensorflow/lite/kernels/internal/reference/reference_ops.h
3029 int32_t shift = 10;
3030
3031 // 1.0 is equivalent to (1 << shift) in quantized space.
3032     // Here we denote unit = (1 << shift).
3033 double unit = static_cast<double>(1 << shift);
3034
3035     // Stride and offset are int16.
3036 int32_t stride_y = std::lround(fp_stride_y * unit);
3037 int32_t stride_x = std::lround(fp_stride_x * unit);
3038 int32_t offset_y = std::lround(fp_offset_y * unit);
3039 int32_t offset_x = std::lround(fp_offset_x * unit);
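    // Continuing the illustrative example above with shift = 10 (unit = 1024):
    // a fp_stride of 0.5 becomes stride = lround(0.5 * 1024) = 512 and a
    // fp_offset of -0.25 becomes offset = lround(-0.25 * 1024) = -256, both of
    // which fit comfortably in int16.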
3040
3041     // Numerically we could decrement the shift so these numbers fit within 16
3042     // bits, but that is uncommon and would not match the TFLite reference.
3043 if (stride_y > std::numeric_limits<int16_t>::max() ||
3044 stride_x > std::numeric_limits<int16_t>::max() ||
3045 stride_y < std::numeric_limits<int16_t>::min() ||
3046 stride_x < std::numeric_limits<int16_t>::min() ||
3047 offset_y > std::numeric_limits<int16_t>::max() ||
3048 offset_x > std::numeric_limits<int16_t>::max() ||
3049 offset_y < std::numeric_limits<int16_t>::min() ||
3050 offset_x < std::numeric_limits<int16_t>::min()) {
3051 op->emitOpError("OpResize: stride or offset out of 16 bits");
3052 return llvm::None;
3053 }
3054
3055 ArrayAttr stride = rewriter.getI64ArrayAttr({stride_y, stride_x});
3056 ArrayAttr offset = rewriter.getI64ArrayAttr({offset_y, offset_x});
3057 IntegerAttr shift_attr = rewriter.getI32IntegerAttr(shift);
3058
3059 // If quantized bilinear mode, need to lower to RESIZE + RESCALE pair.
3060 if (mode == "BILINEAR") {
3061 RankedTensorType output_acc_type;
3062 auto input_element_qtype =
3063 input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
3064
3065 bool scale32;
3066
3067 // TOSA RESIZE: 16 bit input -> 48 bit output, or 8 bit input -> 32 bit
3068 // output.
3069 if (input_element_qtype.getStorageTypeIntegralWidth() == 16) {
3070 scale32 = false;
3071 output_acc_type = RankedTensorType::get(output_type.getShape(),
3072 rewriter.getIntegerType(48));
3073 } else if (input_element_qtype.getStorageTypeIntegralWidth() == 8) {
3074 scale32 = true;
3075 output_acc_type = RankedTensorType::get(output_type.getShape(),
3076 rewriter.getI32Type());
3077 } else {
3078         op->emitOpError("OpResize: only 8-bit and 16-bit quantized input is supported");
3079 return llvm::None;
3080 }
3081
3082 auto resize_op = CreateOpAndInfer<tosa::ResizeOp>(
3083 rewriter, op->getLoc(), output_acc_type, input_value, output_size,
3084 stride, offset, shift_attr, rewriter.getF32ArrayAttr({0.0, 0.0}),
3085 rewriter.getF32ArrayAttr({0.0, 0.0}), resize_mode);
3086
3087 #ifdef RESIZE_BILINEAR_LOWER_SYMMETRIC_ROUNDING
3088       // TFLite resize_bilinear always assumes the input and output tensors have
3089       // the same scale, which means we only need an arithmetic right shift by
3090       // (2 * shift).
3091       // TODO(suderman): Align TFLite rounding behavior
3092       // TFLite also uses symmetric rounding by doing 'x / (1 << 20)', while
3093       // TOSA arithmetic right shift does standard rounding.
3094       // Right now it's legalized using GreaterEqualOp + SelectOp to conform
3095       // to the TFLite reference, but this should eventually be fixed in the
3096       // TFLite reference.
3097       Value cst_zero = getTosaConstTensorSingleI32(rewriter, op, 0);
3098       Value cst_twenty = getTosaConstTensorSingleI32(rewriter, op, 20);
3099
      // Boolean (i1) tensor type with the accumulator shape; needed for the
      // GreaterEqualOp result below.
      RankedTensorType output_bool_type = RankedTensorType::get(
          output_acc_type.getShape(), rewriter.getI1Type());

3100       auto ge_op = CreateOpAndInfer<tosa::GreaterEqualOp>(
3101           rewriter, op->getLoc(), output_bool_type, resize_op.getResult(),
3102           cst_zero);
3103
3104 auto abs_op = CreateOpAndInfer<tosa::AbsOp>(
3105 rewriter, op->getLoc(), output_acc_type, resize_op.getResult());
3106
3107 auto rshift_op = CreateOpAndInfer<tosa::ArithmeticRightShiftOp>(
3108 rewriter, op->getLoc(), output_acc_type, abs_op.getResult(),
3109 cst_twenty, true);
3110
3111 auto negate_op = CreateOpAndInfer<tosa::NegateOp>(
3112 rewriter, op->getLoc(), output_acc_type, rshift_op.getResult());
3113
3114 auto select_op = CreateOpAndInfer<tosa::SelectOp>(
3115 rewriter, op->getLoc(), output_acc_type, ge_op.getResult(),
3116 rshift_op.getResult(), negate_op.getResult());
3117
3118 auto cast_op = CreateOpAndInfer<tosa::CastOp>(
3119 rewriter, op->getLoc(), output_type, select_op.getResult());
3120
3121 return cast_op.getResult();
3122 #else
3123       // This should be the expected lowering, but it is within +/-1 of the
3124       // TFLite reference.
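      // The bilinear accumulator carries two interpolation weights, each
      // scaled by 2^shift, so its values are 2^(2 * shift) = 2^20 times too
      // large; the rescale below removes that factor.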
3125 return buildRescale(rewriter, op, output_type, resize_op.getResult(),
3126 1.0 / (1 << 20), 0, 0, false, scale32);
3127 #endif
3128
3129 } else if (mode == "NEAREST_NEIGHBOR") {
3130 auto resize_op = CreateOpAndInfer<tosa::ResizeOp>(
3131 rewriter, op->getLoc(), output_type, input_value, output_size, stride,
3132 offset, shift_attr, rewriter.getF32ArrayAttr({0.0, 0.0}),
3133 rewriter.getF32ArrayAttr({0.0, 0.0}), resize_mode);
3134 return resize_op.getResult();
3135 } else {
3136 op->emitOpError(
3137 "OpResize: only support BILINEAR or NEAREST_NEIGHBOR mode");
3138 return llvm::None;
3139 }
3140 } else {
3141 auto resize_op = CreateOpAndInfer<tosa::ResizeOp>(
3142 rewriter, op->getLoc(), output_type, input_value, output_size,
3143 rewriter.getI64ArrayAttr({0, 0}), rewriter.getI64ArrayAttr({0, 0}),
3144 rewriter.getI32IntegerAttr(0),
3145 rewriter.getF32ArrayAttr(
3146 {static_cast<float>(fp_stride_y), static_cast<float>(fp_stride_x)}),
3147 rewriter.getF32ArrayAttr(
3148 {static_cast<float>(fp_offset_y), static_cast<float>(fp_offset_x)}),
3149 resize_mode);
3150 return resize_op.getResult();
3151 }
3152 }
3153
3154 // Lowers Quantize to a sequence of TOSA quantization ops.
3155 llvm::Optional<Value> convertQuantizeOp(PatternRewriter& rewriter,
3156 Operation* op, ShapedType output_type,
3157 Value input_value, double scale,
3158 int64_t zeropoint) {
3159 RankedTensorType input_type =
3160 input_value.getType().dyn_cast<RankedTensorType>();
3161 if (!input_type) return llvm::None;
3162
3163 auto output_element_type = output_type.getElementType();
3164
3165   // The output element type can only be a quantized type.
3166 if (!output_element_type.isa<mlir::quant::QuantizedType>()) {
3167 op->emitWarning(
3168 "Lowering quantizeOp but output element type not quantized!");
3169 return llvm::None;
3170 }
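  // The ops below compute out = cast(in * scale + zeropoint). For example
  // (values chosen for illustration): with scale = 2.0 and zeropoint = 128,
  // an input of 1.5 maps to cast(1.5 * 2.0 + 128) = cast(131.0).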
3171
3172 ShapedType output_fp_type = output_type.clone(rewriter.getF32Type());
3173
3174 Value zp_val =
3175 getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(zeropoint));
3176
3177 auto op1_mul_in = CreateOpAndInfer<tosa::MulOp>(
3178 rewriter, op->getLoc(), output_fp_type, input_value,
3179 getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(scale)), 0);
3180
3181 auto op2_add_op1 = CreateOpAndInfer<tosa::AddOp>(
3182 rewriter, op->getLoc(), output_fp_type, op1_mul_in.getResult(), zp_val);
3183
3184 auto op3_cast_op2 = CreateOpAndInfer<tosa::CastOp>(
3185 rewriter, op->getLoc(), output_type, op2_add_op1.getResult());
3186
3187 return op3_cast_op2.getResult();
3188 }
3189
3190 // Lowers Dequantize to a sequence of TOSA dequantization ops.
3191 llvm::Optional<Value> convertDequantizeOp(PatternRewriter& rewriter,
3192 Operation* op, ShapedType output_type,
3193 Value input_value,
3194 ArrayRef<float> scale,
3195 ArrayRef<float> zeropoint,
3196 int64_t dim) {
3197 RankedTensorType input_type =
3198 input_value.getType().dyn_cast<RankedTensorType>();
3199 if (!input_type) return llvm::None;
3200
3201   // The input element type can only be a quantized type.
3202 if (!input_type.getElementType().isa<mlir::quant::QuantizedType>())
3203 return llvm::None;
3204
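  // The ops below compute out = (cast(in) - zeropoint) * scale. When scale or
  // zeropoint hold more than one value, they are reshaped so they broadcast
  // per-channel along 'dim'. For example (values chosen for illustration):
  // with zeropoint = 128 and scale = 0.5, a stored value of 131 dequantizes to
  // (131 - 128) * 0.5 = 1.5.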
3205 Optional<Value> zp_val;
3206 if (zeropoint.size() == 1) {
3207 zp_val = getTosaConstTensorSingleF32(rewriter, op,
3208 static_cast<float>(zeropoint[0]));
3209 } else {
3210 SmallVector<int64_t> shape;
3211 shape.resize(input_type.getRank(), 1);
3212 shape[dim] = zeropoint.size();
3213 zp_val = getConstTensor(rewriter, op, zeropoint, shape);
3214 }
3215
3216 Optional<Value> scale_val;
3217 if (scale.size() == 1) {
3218 scale_val =
3219 getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(scale[0]));
3220 } else {
3221 SmallVector<int64_t> shape;
3222 shape.resize(input_type.getRank(), 1);
3223 shape[dim] = scale.size();
3224 scale_val = getConstTensor(rewriter, op, scale, shape);
3225 }
3226
3227 if (!zp_val || !scale_val) return llvm::None;
3228
3229 auto op1_cast_in = CreateOpAndInfer<tosa::CastOp>(rewriter, op->getLoc(),
3230 output_type, input_value);
3231
3232 auto op2_sub_op1 =
3233 CreateOpAndInfer<tosa::SubOp>(rewriter, op->getLoc(), output_type,
3234 op1_cast_in.getResult(), zp_val.getValue());
3235
3236 return CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
3237 op2_sub_op1.getResult(),
3238 scale_val.getValue(), 0)
3239 .getResult();
3240 }
3241
3242 // Lowers FakeQuant to a sequence of TOSA quantization ops.
3243 llvm::Optional<Value> convertFakeQuantOp(PatternRewriter& rewriter,
3244 Operation* op, ShapedType output_type,
3245 Value input_value, double min,
3246 double max, int64_t num_bits,
3247 bool narrow_range) {
3248   // FakeQuant is lowered as follows:
3249 // op1 = quantize(input)
3250 // op2 = dequantize(op1)
3251
3252 RankedTensorType input_type =
3253 input_value.getType().dyn_cast<RankedTensorType>();
3254 if (!input_type) return llvm::None;
3255
3256   // Quantized as INT<num_bits>, where num_bits can only be 8 or 16.
3257 if (num_bits != 8 && num_bits != 16) {
3258 op->emitWarning("FakeQuantOp lowering handles only 8 and 16 for num_bits!");
3259 return llvm::None;
3260 }
3261
3262 // This code originates from
3263 // tensorflow/core/kernels/fake_quant_ops_functor.h.
3264 int32_t qmax = (1 << (num_bits)) - 1;
3265 int32_t qmin = narrow_range ? 1 : 0;
3266
3267 float nudged_min, nudged_max, nudged_scale;
3268 tensorflow_nudge(min, max, qmin, qmax, &nudged_min, &nudged_max,
3269 &nudged_scale);
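  // Worked example (values chosen for illustration): with num_bits = 8 and
  // narrow_range = false, qmin = 0 and qmax = 255. For min = -1.0 and
  // max = 1.0 the nudge gives nudged_scale = 2.0 / 255 (~0.00784),
  // nudged_min ~ -1.0039 and nudged_max ~ 0.9961, so that zero is exactly
  // representable.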
3270
3271 Value cst_min = getTosaConstTensorSingleF32(rewriter, op, nudged_min);
3272 Value cst_max = getTosaConstTensorSingleF32(rewriter, op, nudged_max);
3273 Value cst_scale = getTosaConstTensorSingleF32(rewriter, op, nudged_scale);
3274 Value cst_inv_scale =
3275 getTosaConstTensorSingleF32(rewriter, op, 1.0f / nudged_scale);
3276 Value cst_half = getTosaConstTensorSingleF32(rewriter, op, 0.5f);
3277
3278 // This code originates from
3279 // tensorflow/core/kernels/fake_quant_ops_functor.h.
3280 auto op1_min_in = CreateOpAndInfer<tosa::MinimumOp>(
3281 rewriter, op->getLoc(), output_type, input_value, cst_max);
3282
3283 auto op2_max_op1 = CreateOpAndInfer<tosa::MaximumOp>(
3284 rewriter, op->getLoc(), output_type, op1_min_in.getResult(), cst_min);
3285
3286 auto op3_sub_op2 = CreateOpAndInfer<tosa::SubOp>(
3287 rewriter, op->getLoc(), output_type, op2_max_op1.getResult(), cst_min);
3288
3289 auto op4_mul_op3 =
3290 CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
3291 op3_sub_op2.getResult(), cst_inv_scale, 0);
3292
3293 auto op5_add_op4 = CreateOpAndInfer<tosa::AddOp>(
3294 rewriter, op->getLoc(), output_type, op4_mul_op3.getResult(), cst_half);
3295
3296 auto op6_floor_op5 = CreateOpAndInfer<tosa::FloorOp>(
3297 rewriter, op->getLoc(), output_type, op5_add_op4.getResult());
3298
3299 auto op7_mul_op6 =
3300 CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
3301 op6_floor_op5.getResult(), cst_scale, 0);
3302
3303 return CreateOpAndInfer<tosa::AddOp>(rewriter, op->getLoc(), output_type,
3304 op7_mul_op6.getResult(), cst_min)
3305 .getResult();
3306 }
3307
3308 llvm::Optional<Value> convertTFConv2DCommon(
3309 PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
3310 Value input, Value filter, Value bias, ArrayAttr strides_attr,
3311 ArrayAttr dilations_attr, ArrayAttr explicit_padding_attr,
3312 StringRef padding_ref, StringRef data_format_ref) {
3313 RankedTensorType input_type = input.getType().dyn_cast<RankedTensorType>();
3314 RankedTensorType filter_type = filter.getType().dyn_cast<RankedTensorType>();
3315   // Not ranked tensor input or filter
3316 if (!input_type) return llvm::None;
3317 if (!filter_type) return llvm::None;
3318
3319 // Transpose [H, W, I, O] to [O, H, W, I]
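  // For example, a [3, 3, 16, 32] HWIO filter becomes a [32, 3, 3, 16] OHWI
  // filter via the {3, 0, 1, 2} permutation built below.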
3320 auto filter_shape = filter_type.getShape();
3321 SmallVector<int64_t, 4> a1_transpose_dims;
3322 a1_transpose_dims.push_back(filter_shape[3]);
3323 a1_transpose_dims.push_back(filter_shape[0]);
3324 a1_transpose_dims.push_back(filter_shape[1]);
3325 a1_transpose_dims.push_back(filter_shape[2]);
3326 llvm::Optional<Value> a1_filter_transpose_perm = getConstTensor<int32_t>(
3327 rewriter, op, /*vec=*/{3, 0, 1, 2}, /*shape=*/{4});
3328
3329 if (!a1_filter_transpose_perm) return llvm::None;
3330
3331 auto a1_filter_transpose_op = CreateOpAndInfer<tosa::TransposeOp>(
3332 rewriter, op->getLoc(),
3333 RankedTensorType::get(a1_transpose_dims, filter_type.getElementType()),
3334 filter, a1_filter_transpose_perm.getValue());
3335
3336 // Only support NHWC now.
3337 if (data_format_ref.str() != "NHWC") {
3338     op->emitWarning("convertTFConv2DCommon only supports NHWC!");
3339 return llvm::None;
3340 }
3341
3342 ArrayAttr stride;
3343 ArrayAttr dilation;
3344 ArrayAttr pad;
3345 {
3346 if (!strides_attr) {
3347 stride = rewriter.getI64ArrayAttr({1, 1});
3348 } else {
3349 // Note: hardcoded to NHWC for now
3350 int64_t stride_h = strides_attr[1].cast<IntegerAttr>().getInt();
3351 int64_t stride_w = strides_attr[2].cast<IntegerAttr>().getInt();
3352 stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
3353 }
3354 }
3355 {
3356 if (!dilations_attr) {
3357 dilation = rewriter.getI64ArrayAttr({1, 1});
3358 } else {
3359 // Note: hardcoded to NHWC for now
3360 int64_t dilation_h = dilations_attr[1].cast<IntegerAttr>().getInt();
3361 int64_t dilation_w = dilations_attr[2].cast<IntegerAttr>().getInt();
3362 dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w});
3363 }
3364 }
3365 {
3366 tensorflow::Padding tf_pad;
3367 if (!GetPaddingFromString(padding_ref.str(), &tf_pad).ok()) {
3368 op->emitWarning("Could not get padding data from padding string term!");
3369 return llvm::None;
3370 }
3371
3372 tensorflow::TensorFormat data_format_tf;
3373 if (!FormatFromString(data_format_ref.str(), &data_format_tf))
3374 return llvm::None;
3375
3376 if (tf_pad == tensorflow::Padding::EXPLICIT) {
3377 pad = getPaddingValuesFromExplicitPadAttr(explicit_padding_attr,
3378 data_format_tf, rewriter);
3379 } else {
3380 if (!getPaddingValuesFromPadType(tf_pad, data_format_tf,
3381 0, // tensorflow::FORMAT_HWIO
3382 input_type, filter_type, stride,
3383 dilation, rewriter, pad))
3384 return llvm::None;
3385 }
3386 }
3387
3388 return CreateOpAndInfer<tosa::Conv2DOp>(
3389 rewriter, op->getLoc(), output_type, input,
3390 a1_filter_transpose_op.getResult(), bias, pad, stride, dilation)
3391 .getResult();
3392 }
3393
3394 // Lowers Gather operators to a sequence of TOSA ops.
3395 llvm::Optional<Value> convertGatherOp(PatternRewriter& rewriter, Operation* op,
3396 Value result_value, Value params_value,
3397 Value indices_value, int32_t batch_dims,
3398 int32_t axis) {
3399 auto result_type = result_value.getType().dyn_cast<ShapedType>();
3400 auto params_type = params_value.getType().dyn_cast<RankedTensorType>();
3401 auto indices_type = indices_value.getType().dyn_cast<RankedTensorType>();
3402
3403 if (!result_type || !params_type || !indices_type) return llvm::None;
3404
3405 // batch_dims indicates the number of batch dimensions in params and
3406   // indices. axis indicates the axis at which the gather indexing is
3407 // applied. axis must be >= batch_dims. When axis is equal to
3408 // batch_dims, the right-most batch dimension disappears.
3409 //
3410 // N: number of batches
3411 // Computed as product of params.shape[0:batch_dims-1]
3412 //
3413 // W: number of indices in each batch
3414 // Computed as product of indices.shape[batch_dims:]
3415 //
3416 // K: range of each index
3417   //   Computed as params.shape[axis]
3418 //
3419 // C: number of channels for each index
3420 // Computed as: LeftChannels * RightChannels:
3421 // product(params.shape[batch_dims:axis]) * product(params.shape[axis+1:])
3422 //
3423 // The params tensor needs to be transposed, then reshaped to move the
3424 // dimensions into [N, K, C] order.
3425 //
3426 // The dimensions of the input params[] tensor are grouped in the following
3427 // order to begin with:
3428 //
3429 // [Batch, LeftChannels, Indices, RightChannels]
3430 // |-----||------------||-------||-------------|
3431 // N C_l K C_r
3432 //
3433 // Where Batch (N), Indices (K) can be one or more dimensions in size,
3434 // while LeftChannels and RightChannels represent the group of data channels
3435 // (C) to the left and right (C_l, C_r) of the indices; the sum of these two
3436 // is one or more dimensions in size, but either one may be zero depending
3437 // on how axis was specified by the caller.
3438 //
3439 // The resulting tensor will look like:
3440 //
3441 // [Batch, Indices, LeftChannels, RightChannels]
3442 // |-----||-------||---------------------------|
3443 // N K C
3444 //
3445 // The indices tensor simply needs a reshape to flatten all of the
3446 // batch dimensions (N) together and flatten all of the indices (W)
3447 // together.
3448 //
3449 // Then do the tosa.GATHER
3450 //
3451 // output[N,W,C] = tosa.GATHER(values[N,K,C], indices[N,W])
3452 //
3453 // Finally, the resulting tensor will have shape [N, W, C], where C is a
3454 // flattened version of [LeftChannels, RightChannels]. We need to reshape
3455 // to unflatten to:
3456 //
3457 // [N, W, LeftChannels, RightChannels]
3458 //
3459 // and finally transpose back to the output shape
3460 //
3461 // [Batch, LeftChannels, Non-Batch-Indices, RightChannels]
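  //
  // Worked example (shapes chosen for illustration): params of shape
  // [2, 3, 5, 7] and indices of shape [2, 4] with batch_dims = 1 and axis = 2
  // give N = 2, W = 4, K = 5 and C = 3 * 7 = 21. The tosa.GATHER result is
  // [2, 4, 21], which is reshaped to [2, 4, 3, 7] and transposed to the final
  // output shape [2, 3, 4, 7].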
3462
3463 int params_rank = params_type.getShape().size();
3464 int indices_rank = indices_type.getShape().size();
3465
3466 if (!(batch_dims <= indices_rank)) {
3467 op->emitOpError("Batch_dims must be <= indices_rank for a valid gather op");
3468 return llvm::None;
3469 }
3470
3471 if (!(axis >= batch_dims)) {
3472 op->emitOpError("axis must be >= batch_dims for a valid gather op");
3473 return llvm::None;
3474 }
3475
3476 // Sizes for each of these fields.
3477 SmallVector<int64_t> params_batch, params_indices, params_left_channels,
3478 params_right_channels;
3479
3480 // Dimension indices for each of these fields.
3481 SmallVector<int64_t> params_idx_batch, params_idx_indices,
3482 params_idx_left_channels, params_idx_right_channels;
3483
3484 // Read through the params tensor dimensions left-to-right and extract the
3485 // different fields.
3486 for (int i = 0; i < params_rank; i++) {
3487 // When batch_dims == axis, the batch dimension gets replaced.
3488 if (i < batch_dims && i < axis) {
3489 params_batch.push_back(params_type.getShape()[i]);
3490 params_idx_batch.push_back(i);
3491 } else if (i < axis) {
3492 params_left_channels.push_back(params_type.getShape()[i]);
3493 params_idx_left_channels.push_back(i);
3494 } else if (i < (axis + 1)) {
3495 params_indices.push_back(params_type.getShape()[i]);
3496 params_idx_indices.push_back(i);
3497 } else {
3498 params_right_channels.push_back(params_type.getShape()[i]);
3499 params_idx_right_channels.push_back(i);
3500 }
3501 }
3502
3503 // Calculate N, K, W, C
3504 int64_t N = multiply_dims(params_type.getShape().take_front(batch_dims));
3505 int64_t W = multiply_dims(
3506 indices_type.getShape().slice(batch_dims, indices_rank - batch_dims));
3507 int64_t K = params_type.getShape()[axis];
3508
3509 int64_t C = multiply_dims(
3510 params_type.getShape().slice(batch_dims, axis - batch_dims));
3511 C = multiply_dims(
3512 params_type.getShape().slice(axis + 1, params_rank - axis - 1), C);
3513
3514 /////////////////////////////////////////////
3515 // Build up the params transpose operator
3516 SmallVector<int32_t> params_transpose_perm;
3517 SmallVector<int64_t> params_transpose_shape;
3518
3519 // Batch
3520 for (int i = 0; i < params_batch.size(); i++) {
3521 params_transpose_perm.push_back(params_idx_batch[i]);
3522 params_transpose_shape.push_back(params_batch[i]);
3523 }
3524
3525 // Indices
3526 for (int i = 0; i < params_indices.size(); i++) {
3527 params_transpose_perm.push_back(params_idx_indices[i]);
3528 params_transpose_shape.push_back(params_indices[i]);
3529 }
3530
3531 // LeftChannels
3532 for (int i = 0; i < params_left_channels.size(); i++) {
3533 params_transpose_perm.push_back(params_idx_left_channels[i]);
3534 params_transpose_shape.push_back(params_left_channels[i]);
3535 }
3536
3537 // RightChannels
3538 for (int i = 0; i < params_right_channels.size(); i++) {
3539 params_transpose_perm.push_back(params_idx_right_channels[i]);
3540 params_transpose_shape.push_back(params_right_channels[i]);
3541 }
3542
3543 /////////////////////////////////////////////
3544   // Build up the result reshape, in preparation for the transpose
3545 // [N, W, C] -> [ Batch, Indices, LeftChannels, RightChannels ]
3546 SmallVector<int64_t> result_reshape_shape;
3547
3548 // Indices
3549 for (int i = 0; i < indices_type.getShape().size(); i++) {
3550 result_reshape_shape.push_back(indices_type.getShape()[i]);
3551 }
3552
3553 // Left channels
3554 for (int i = 0; i < params_left_channels.size(); i++) {
3555 result_reshape_shape.push_back(params_left_channels[i]);
3556 }
3557
3558   // Right channels (the gathered axis dimension is not part of this group).
3559 for (int i = 0; i < params_right_channels.size(); i++) {
3560 result_reshape_shape.push_back(params_right_channels[i]);
3561 }
3562
3563 /////////////////////////////////////////////
3564 // Build up the result transpose operator.
3565 SmallVector<int32_t> result_transpose_perm;
3566
3567 // Batch dimensions
3568 for (int i = 0; i < batch_dims; i++) {
3569 result_transpose_perm.push_back(i);
3570 }
3571
3572 // LeftChannels
3573 for (int i = 0; i < params_left_channels.size(); i++) {
3574 result_transpose_perm.push_back(i + indices_type.getShape().size());
3575 }
3576
3577 // Indices (remainder of dimensions after batch).
3578 for (int i = batch_dims; i < (indices_type.getShape().size()); i++) {
3579 result_transpose_perm.push_back(i);
3580 }
3581
3582 // RightChannels, coming from after both the Indices and LeftChannels.
3583 for (int i = 0; i < params_right_channels.size(); i++) {
3584 result_transpose_perm.push_back(i + indices_type.getShape().size() +
3585 params_left_channels.size());
3586 }
3587
3588 SmallVector<int64_t> tosa_values_shape = {N, K, C};
3589 SmallVector<int64_t> tosa_indices_shape = {N, W};
3590 SmallVector<int64_t> tosa_gather_result_shape = {N, W, C};
3591
3592 llvm::Optional<Value> params_transpose_perm_val = getConstTensor<int32_t>(
3593 rewriter, op, params_transpose_perm,
3594 {static_cast<int64_t>(params_transpose_perm.size())});
3595
3596 llvm::Optional<Value> result_transpose_perm_val = getConstTensor<int32_t>(
3597 rewriter, op, result_transpose_perm,
3598 {static_cast<int64_t>(result_transpose_perm.size())});
3599
3600 if (!params_transpose_perm_val || !result_transpose_perm_val)
3601 return llvm::None;
3602
3603 auto params_transpose_op = CreateOpAndInfer<tosa::TransposeOp>(
3604 rewriter, op->getLoc(),
3605 RankedTensorType::get(params_transpose_shape,
3606 params_type.getElementType()),
3607 params_value, params_transpose_perm_val.getValue());
3608
3609 if (count_dynamic_dims(tosa_values_shape) > 1) {
3610 return (void)rewriter.notifyMatchFailure(
3611 op,
3612 "multiply dynamic shapes when reshaping values down to "
3613 "tosa.gather"),
3614 llvm::None;
3615 }
3616
3617 auto tosa_values_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
3618 rewriter, op->getLoc(),
3619 RankedTensorType::get(tosa_values_shape, params_type.getElementType()),
3620 params_transpose_op.getResult(),
3621 rewriter.getI64ArrayAttr(tosa_values_shape));
3622
3623 if (count_dynamic_dims(tosa_indices_shape) > 1) {
3624 return (void)rewriter.notifyMatchFailure(
3625 op,
3626 "multiply dynamic shapes when reshaping indices down to "
3627 "tosa.gather"),
3628 llvm::None;
3629 }
3630
3631 auto tosa_indices_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
3632 rewriter, op->getLoc(),
3633 RankedTensorType::get(tosa_indices_shape, indices_type.getElementType()),
3634 indices_value, rewriter.getI64ArrayAttr(tosa_indices_shape));
3635
3636 auto tosa_gather_op = CreateOpAndInfer<tosa::GatherOp>(
3637 rewriter, op->getLoc(),
3638 RankedTensorType::get(tosa_gather_result_shape,
3639 result_type.getElementType()),
3640 tosa_values_reshape_op.getResult(), tosa_indices_reshape_op.getResult());
3641
3642 if (count_dynamic_dims(result_reshape_shape) > 1) {
3643 return (void)rewriter.notifyMatchFailure(
3644 op,
3645 "multiply dynamic shapes when reshaping tosa.gather result up"),
3646 llvm::None;
3647 }
3648
3649 auto tosa_result_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
3650 rewriter, op->getLoc(),
3651 RankedTensorType::get(result_reshape_shape, params_type.getElementType()),
3652 tosa_gather_op.getResult(),
3653 rewriter.getI64ArrayAttr(result_reshape_shape));
3654
3655 return CreateOpAndInfer<tosa::TransposeOp>(
3656 rewriter, op->getLoc(), result_type,
3657 tosa_result_reshape_op.getResult(),
3658 result_transpose_perm_val.getValue())
3659 .getResult();
3660 }
3661
3662 // Lowers GatherNd operators to a sequence of TOSA ops.
3663 llvm::Optional<Value> convertGatherNdOp(PatternRewriter& rewriter,
3664 Operation* op, Value result_value,
3665 Value params_value, Value indices_value)
3666
3667 {
3668 auto result_type = result_value.getType().dyn_cast<RankedTensorType>();
3669 auto params_type = params_value.getType().dyn_cast<RankedTensorType>();
3670 auto indices_type = indices_value.getType().dyn_cast<RankedTensorType>();
3671
3672 if (!result_type || !params_type || !indices_type) return llvm::None;
3673
3674 // N: number of batches
3675 // Always 1 for GatherND
3676 //
3677 // Because TOSA's GATHER operator already uses the symbol 'N' for
3678 // the number of batches, we will use the symbol 'ND' to specify the
3679 // number of dimensions that are sliced from params instead of'N' in
3680   // number of dimensions that are sliced from params instead of 'N' in
3681 //
3682 // ND: indices.shape[-1]
3683 //
3684 // W: number of indices in each batch
3685 // Computed as:
3686 // product(indices.shape[0:-1]) (all but the last dimension)
3687 //
3688 // K: range of each index
3689 // Computed as:
3690 // product(params.shape[0:ND-1])
3691 //
3692 // C: number of channels for each index
3693 // Computed as:
3694 // product(params.shape[ND:])
3695 //
3696 // The params tensor needs to be reshaped, but not transposed, to move the
3697 // dimensions into [N, K, C] order.
3698 //
3699 // The dimensions of the input params[] tensor are grouped in the following
3700 // order to begin with:
3701 //
3702 // [ParamIndices, ParamChannels]
3703 // |------------||-------------|
3704 // K C
3705 //
3706 // The reshape simply flattens the params tensor into a 2D [K, C] shape.
3707 //
3708 // Indices needs to be put in the form of [N, W], but a simple flattening
3709 // will not suffice, because the indices need to index into a [W]-shape
3710 // vector instead of the params.shape[0:ND-1] tensor that we had before.
3711 //
3712 // To flatten the coordinates, first reshape indices to a [W, ND] matrix,
3713 // where the matrix now represents W ND-dimensional coordinates into the
3714 // params tensor.
3715 //
3716 // From here, we take each of the ND dimensions and multiply it with
3717 // the size of the next params dimension (or 1 for the last
3718 // dimension), then sum all these together with a reduce_sum
3719   // operator. This is exactly the same mathematics as one would use to
3720 // flatten the indices of an N-dimensional row-major array into a
3721 // 1-D array in C.
3722 //
3723 // More precisely, do an element-wise multiply with [params.shape[1
3724 // .. ND], 1] in axis 1, then reduce_sum in axis 1 to flatten to a
3725 // [W]-shaped tensor, then trivially reshape to [N=1, W] to be
3726 // compatible with the GATHER operator's shape.
3727 //
3728 // Then perform the tosa.GATHER() operation.
3729 //
3730 // Now we have result = [N, K, C].
3731 //
3732 // Reshape with a single, simple reshape to the final output shape of:
3733 // [Indices, ParamChannels]
3734 //
3735 // Where, Indices is indices.shape[0:ND-1]
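  //
  //  Worked example (shapes chosen for illustration): params of shape
  //  [3, 4, 5] and indices of shape [2, 2] give ND = 2, N = 1, W = 2,
  //  K = 3 * 4 = 12 and C = 5. The flattening coefficients are [4, 1]
  //  (row-major), and the final result shape is [2, 5].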
3736
3737 int N = 1, W = 1, K = 1, C = 1, ND = 1;
3738
3739 int params_rank = params_type.getShape().size();
3740 int indices_rank = indices_type.getShape().size();
3741
3742 ND = indices_type.getShape()[indices_rank - 1];
3743
3744 if (ND > params_rank) {
3745 op->emitOpError("Size of last dimension on indices must be <= params rank");
3746 return llvm::None;
3747 }
3748
3749 // Calculate N, K, W, C. (N is always 1)
3750 for (int i = 0; i < (indices_rank - 1); i++) {
3751 W *= indices_type.getShape()[i];
3752 }
3753
3754 for (int i = 0; i < ND; i++) {
3755 K *= params_type.getShape()[i];
3756 }
3757
3758 for (int i = ND; i < params_rank; i++) {
3759 C *= params_type.getShape()[i];
3760 }
3761
3762 SmallVector<int64_t, 3> tosa_values_shape({N, K, C});
3763 SmallVector<int64_t, 2> tosa_indices_shape({N, W});
3764 SmallVector<int64_t, 2> indices_matrix_shape({W, ND});
3765 SmallVector<int64_t, 3> tosa_gather_result_shape({N, W, C});
3766
3767 auto tosa_values_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
3768 rewriter, op->getLoc(),
3769 RankedTensorType::get(tosa_values_shape, params_type.getElementType()),
3770 params_value, rewriter.getI64ArrayAttr(tosa_values_shape));
3771
3772 // Flatten the input indices tensor to an [W, ND] matrix.
3773 auto indices_matrix_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
3774 rewriter, op->getLoc(),
3775 RankedTensorType::get(indices_matrix_shape,
3776 indices_type.getElementType()),
3777 indices_value, rewriter.getI64ArrayAttr(indices_matrix_shape));
3778
3779 SmallVector<int32_t> flattened_coeff_vec;
3780 for (int i = 1; i < ND; i++) {
3781 flattened_coeff_vec.push_back(params_type.getShape()[i]);
3782 }
3783 flattened_coeff_vec.push_back(1);
3784
3785 for (int i = ND - 1; i > 0; i--) {
3786 flattened_coeff_vec[i - 1] *= flattened_coeff_vec[i];
3787 }
3788
3789 llvm::Optional<Value> flattened_coeff_value = getConstTensor<int32_t>(
3790 rewriter, op, flattened_coeff_vec,
3791 {static_cast<int64_t>(flattened_coeff_vec.size())});
3792
3793 if (!flattened_coeff_value) return llvm::None;
3794
3795 // Multiply the coefficients by the coordinates
3796 auto flattened_indices_mul_op = CreateOpAndInfer<tosa::MulOp>(
3797 rewriter, op->getLoc(),
3798 RankedTensorType::get(indices_matrix_shape,
3799 indices_type.getElementType()),
3800 indices_matrix_reshape_op.getResult(), flattened_coeff_value.getValue(),
3801 0);
3802
3803 // Sum up the products of the coefficients and coordinates
3804 auto flattened_indices_reduce_op = CreateOpAndInfer<tosa::ReduceSumOp>(
3805 rewriter, op->getLoc(),
3806 RankedTensorType::get(tosa_indices_shape, indices_type.getElementType()),
3807 flattened_indices_mul_op.getResult(), rewriter.getI64IntegerAttr(1));
3808
3809 // And reshape to [N, W]
3810 auto tosa_indices_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
3811 rewriter, op->getLoc(),
3812 RankedTensorType::get(tosa_indices_shape, indices_type.getElementType()),
3813 flattened_indices_reduce_op.getResult(),
3814 rewriter.getI64ArrayAttr(tosa_indices_shape));
3815
3816 // Now the gather op itself
3817 auto tosa_gather_op = CreateOpAndInfer<tosa::GatherOp>(
3818 rewriter, op->getLoc(),
3819 RankedTensorType::get(tosa_gather_result_shape,
3820 result_type.getElementType()),
3821 tosa_values_reshape_op.getResult(), tosa_indices_reshape_op.getResult());
3822
3823 // Finally, reshape back to the original output shape of [Indices,
3824 // ParamChannels].
3825 return CreateOpAndInfer<tosa::ReshapeOp>(
3826 rewriter, op->getLoc(), result_type, tosa_gather_op.getResult(),
3827 rewriter.getI64ArrayAttr(result_type.getShape()))
3828 .getResult();
3829 }
3830
3831 // Lowers OneHot operator to a sequence of TOSA ops.
3832 llvm::Optional<Value> convertOneHotOp(PatternRewriter& rewriter, Operation* op,
3833 Value result_value, Value indices_value,
3834 Value on_value, Value off_value,
3835 int32_t depth, int32_t axis) {
3836 auto result_type = result_value.getType().dyn_cast<RankedTensorType>();
3837 auto indices_type = indices_value.getType().dyn_cast<RankedTensorType>();
3838 auto on_value_type = on_value.getType().dyn_cast<RankedTensorType>();
3839 auto off_value_type = off_value.getType().dyn_cast<RankedTensorType>();
3840
3841 if (!result_type || !indices_type || !on_value_type || !off_value_type)
3842 return llvm::None;
3843
3844   // The OneHot operator creates a new tensor with shape indices.shape[:axis]
3845   // + [depth] + indices.shape[axis:]. Each index in 'indices' needs to be
3846   // within the range [0, depth - 1]; the output is on_value at [..., k, ...]
3847   // if k == index, and off_value otherwise.
3848 //
3849 // The lowering below assumes depth is always known at compile time.
3850   // Handling a depth resolved at run time is TBD.
3851 //
3852   // OneHot can be lowered to TOSA Scatter, with off_value mapped to
3853   // 'values_in', on_value mapped to 'input', and indices mapped naturally
3854   // to 'indices'. The dimensions of TOSA scatter (N, W, K, C) also need
3855   // to be picked.
3856 //
3857 // N: number of elements of input indices
3858 // Computed as:
3859 // product(indices.shape[:])
3860 //
3861 // K: newly added dimension
3862 // K = depth
3863 //
3864   //  W, C: dummy dimensions for now
3865 // W = C = 1
3866 //
3867 // High level description of lowering looks like:
3868 // 1. off_value is reshaped/tiled into [N, K, C]
3869 // 2. on_value is reshaped/tiled into [N, W, C]
3870 // 3. indices is reshaped into [N, W]
3871 // 4. scatter into [N, K, C]
3872 // 5. reshaped into [LeftDims, RightDims, K]
3873 // 6. transpose into [LeftDims, K, RightDims]
3874 // 7. reshaped to result.shape
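  //
  //  Worked example (shapes chosen for illustration): indices of shape [2, 3]
  //  with depth = 4 and axis = -1 (i.e. 2) give N = 6, K = 4, W = C = 1,
  //  left_dim = 6 and right_dim = 1. The scatter produces [6, 4, 1], which is
  //  reshaped to [6, 1, 4], transposed to [6, 4, 1], and reshaped to the final
  //  result shape [2, 3, 4].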
3875
3876 if (on_value_type.getRank() != 0 || off_value_type.getRank() != 0) {
3877     op->emitOpError("OneHotOp: on_value/off_value need to be scalar");
3878 return llvm::None;
3879 }
3880
3881 if (axis < -1 || axis > indices_type.getRank()) {
3882     op->emitOpError("OneHotOp: axis out of valid range [-1, indices.rank]");
3883 return llvm::None;
3884 }
3885
3886 // axis = -1 is equivalent to axis = indices.rank
3887 if (axis == -1) {
3888 axis = indices_type.getRank();
3889 }
3890
3891 int N = 1, W = 1, C = 1;
3892 int K = depth;
3893 int left_dim = 1, right_dim = 1;
3894
3895 for (int32_t i = 0; i < indices_type.getRank(); i++) {
3896 int32_t dim = indices_type.getShape()[i];
3897 N *= dim;
3898 if (i >= axis) {
3899 right_dim *= dim;
3900 } else {
3901 left_dim *= dim;
3902 }
3903 }
3904
3905 // Reshape on_value to [1, 1, 1]
3906 auto op1_reshape_on_value = CreateOpAndInfer<tosa::ReshapeOp>(
3907 rewriter, op->getLoc(),
3908 RankedTensorType::get({1, 1, 1}, on_value_type.getElementType()),
3909 on_value, rewriter.getI64ArrayAttr({1, 1, 1}));
3910
3911 // And tile to [N, W, C]
3912 auto op2_tile_op1 = CreateOpAndInfer<tosa::TileOp>(
3913 rewriter, op->getLoc(),
3914 RankedTensorType::get({N, W, C}, on_value_type.getElementType()),
3915 op1_reshape_on_value.getResult(), rewriter.getI64ArrayAttr({N, W, C}));
3916
3917 // Reshape off_value to [1, 1, 1]
3918 auto op3_reshape_off_value = CreateOpAndInfer<tosa::ReshapeOp>(
3919 rewriter, op->getLoc(),
3920 RankedTensorType::get({1, 1, 1}, off_value_type.getElementType()),
3921 off_value, rewriter.getI64ArrayAttr({1, 1, 1}));
3922
3923 // And tile to [N, K, C]
3924 auto op4_tile_op3 = CreateOpAndInfer<tosa::TileOp>(
3925 rewriter, op->getLoc(),
3926 RankedTensorType::get({N, K, C}, on_value_type.getElementType()),
3927 op3_reshape_off_value.getResult(), rewriter.getI64ArrayAttr({N, K, C}));
3928
3929 // Reshape indices to [N, W]
3930 auto op5_reshape_indices = CreateOpAndInfer<tosa::ReshapeOp>(
3931 rewriter, op->getLoc(),
3932 RankedTensorType::get({N, W}, indices_type.getElementType()),
3933 indices_value, rewriter.getI64ArrayAttr({N, W}));
3934
3935 // Scatter to [N, K, C]
3936 auto op6_scatter_op4_op5_op2 = CreateOpAndInfer<tosa::ScatterOp>(
3937 rewriter, op->getLoc(),
3938 RankedTensorType::get({N, K, C}, result_type.getElementType()),
3939 op4_tile_op3.getResult(), op5_reshape_indices.getResult(),
3940 op2_tile_op1.getResult());
3941
3942 // Reshaped to [LeftDims, RightDims, K]. C being squeezed out since it's 1.
3943 auto op7_reshape_op6 = CreateOpAndInfer<tosa::ReshapeOp>(
3944 rewriter, op->getLoc(),
3945 RankedTensorType::get({left_dim, right_dim, K},
3946 result_type.getElementType()),
3947 op6_scatter_op4_op5_op2.getResult(),
3948 rewriter.getI64ArrayAttr({left_dim, right_dim, K}));
3949
3950 // Transposed to [LeftDims, K, RightDims].
3951 llvm::Optional<Value> perm_const =
3952 getConstTensor<int32_t>(rewriter, op, /*vec=*/{0, 2, 1}, /*shape=*/{3});
3953
3954 if (!perm_const) return llvm::None;
3955
3956 auto op8_transpose_op7 = CreateOpAndInfer<tosa::TransposeOp>(
3957 rewriter, op->getLoc(),
3958 RankedTensorType::get({left_dim, K, right_dim},
3959 result_type.getElementType()),
3960 op7_reshape_op6.getResult(), perm_const.getValue());
3961
3962 // Reshaped to result.shape.
3963 return CreateOpAndInfer<tosa::ReshapeOp>(
3964 rewriter, op->getLoc(), result_type, op8_transpose_op7.getResult(),
3965 rewriter.getI64ArrayAttr(result_type.getShape()))
3966 .getResult();
3967 }
3968
3969 }; // namespace tosa
3970 }; // namespace mlir
3971