/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file contains legalizations common to mapping both TensorFlow and
// TensorFlow Lite to TOSA. It operates generically on ops and does not have
// a hard reference on either dialect.
//
// Conversion functions return llvm::None on a legalization failure or a
// legalized value on success. Callers must check for presence of an
// llvm::Optional value after each call.
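//
// For illustration only (not part of the original header): a typical caller
// in a rewrite pattern might consume the returned llvm::Optional as sketched
// below, where convertFooOp stands in for any conversion function in this
// file.
//
//   llvm::Optional<Value> result = convertFooOp(rewriter, op, ...);
//   if (!result) return failure();  // legalization failed
//   rewriter.replaceOp(op, {result.getValue()});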

#include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"

#include <climits>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <numeric>

#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/Dialect/Tosa/IR/TosaOps.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"

namespace mlir {
namespace tosa {

// Lowers the Pack operator to TOSA.
llvm::Optional<Value> convertPackOp(PatternRewriter& rewriter, Operation* op,
                                    Value result_value,
                                    SmallVector<Value, 8>& inputs,
                                    int32_t axis) {
  //////////////////////////////////////////////////
  // Operator: output = Pack([values], axis) or output = Stack([values], axis)
  // Lowering:
  //
  // This operator is lowered into a series of pairwise tosa.concat()
  // operators and a reshape.
  // Depending on the inputs, a transpose operator is also generated:
  //
  // Step 1: concatenate the tensors
  //   a1_concat = tosa.concat(input[0], input[1], axis)
  //   for (i = 2; i < len(input); i++)
  //     a1_concat = tosa.concat(a1_concat, input[i], axis)
  //
  // Step 2: reshape to N+1 dimensions
  //   a2_reshape = tosa.reshape(a1_concat, new_rank)
  //
  // Step 3: Transpose if a new dimension is being added:
  //   if (axis == rank(values[0])):
  //     // perm will be [1, 2, 3, 0]
  //     a3_transpose = tosa.transpose(a2_reshape, perm)
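  //
  // Worked example (illustrative values only): packing N = 3 tensors of
  // shape [2, 3] along axis = 1 concatenates along axis 1 to produce a
  // [2, 9] tensor, which is then reshaped to the final [2, 3, 3] result.
  // No transpose is needed because axis (1) < rank(values[0]) (2).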

  // Sanity check 1: make sure all input tensors have the same shape.
  // If input[0] has shape [A, B, C], input[1] to input[N-1] should also have
  // shape [A, B, C].
  RankedTensorType result_type =
      result_value.getType().dyn_cast<RankedTensorType>();

  // Check for ranked tensor type.
  if (!result_type) {
    op->emitOpError("PackOp: result type not ranked tensor");
    return llvm::None;
  }

  // Valid axis in TF is [-rank(input), rank(input)).
  // Valid axis in TOSA is [0, rank(input)).
  // Add rank(input) once if axis is negative.
  RankedTensorType input_type =
      op->getOperand(0).getType().dyn_cast<RankedTensorType>();
  if (!input_type) {
    op->emitOpError("PackOp: input type not ranked tensor");
    return llvm::None;
  }

  int32_t input_rank = input_type.getShape().size();
  if (axis < 0) axis += input_rank;

  input_type = inputs[0].getType().dyn_cast<RankedTensorType>();
  if (!input_type) {
    op->emitOpError("Input 0 type not ranked tensor.");
    return llvm::None;
  }
  ArrayRef<int64_t> input0_tensor_shape = input_type.getShape();
  int input_tensor_rank = input0_tensor_shape.size();

  for (int i = 1; i < inputs.size(); i++) {
    input_type = inputs[i].getType().dyn_cast<RankedTensorType>();
    if (!input_type) {
      op->emitOpError(
          llvm::formatv("PackOp: input {0} type not ranked tensor", i));
      return llvm::None;
    }
    ArrayRef<int64_t> next_tensor_shape = input_type.getShape();
    if (next_tensor_shape.size() != input_tensor_rank) {
      op->emitOpError("PackOp: input tensor rank mismatch.");
      return llvm::None;
    }
    for (int d = 0; d < input0_tensor_shape.size(); d++) {
      if (input0_tensor_shape[d] != next_tensor_shape[d]) {
        op->emitOpError("PackOp: input tensor shape mismatch.");
        return llvm::None;
      }
    }
  }

  // If the input tensors are rank 0, reshape them to rank 1, size 1 tensors
  // before performing the concat.
  if (input_tensor_rank == 0) {
    SmallVector<int64_t, 8> reshape_rank1_size1_shape{1};
    RankedTensorType reshape_rank1_size1_type =
        RankedTensorType::get(ArrayRef<int64_t>(reshape_rank1_size1_shape),
                              result_type.getElementType());
    ArrayAttr shape_rank1_size1_attr =
        rewriter.getI64ArrayAttr(reshape_rank1_size1_shape);
    for (int i = 0; i < inputs.size(); i++) {
      auto a0_reshape_op = rewriter.create<tosa::ReshapeOp>(
          op->getLoc(), reshape_rank1_size1_type, inputs[i],
          shape_rank1_size1_attr);
      inputs[i] = a0_reshape_op.getResult();
    }
  }

  // Sanity check 2: axis can be from [0, rank(input) + 1],
  // where rank(input) + 1 means creating a new dimension.
  // Negative values are also allowed up to -(rank(input) + 1),
  // in which case the axis "wraps around".
  if (axis < 0) axis += input_rank;

  if (axis > (input_tensor_rank + 1)) {
    op->emitOpError("PackOp: axis out of valid range.");
    return llvm::None;
  }

  // Sanity check 3: if the input shape is [A, B, C], the output shape should
  // be [N, A, B, C].
  // 3.a check output rank is rank(input) + 1
  SmallVector<int64_t, 8> output_shape_vals(result_type.getShape().begin(),
                                            result_type.getShape().end());
  if (output_shape_vals.size() != (input_tensor_rank + 1)) {
    op->emitOpError("PackOp: output tensor rank mismatch.");
    return llvm::None;
  }
  // 3.b check the output dimension on 'axis' is N
  if (output_shape_vals[axis] != inputs.size()) {
    op->emitOpError("PackOp: output tensor shape mismatch.");
    return llvm::None;
  }
  // In most cases PackOp.axis() is within [0, rank(input) - 1], and we can
  // directly concatenate along that axis and then reshape.
  // For example, stacking N [A, B, C] input tensors along axis = 1:
  // after concatenation the output is [A, N * B, C],
  // which is then reshaped into [A, N, B, C].
  // The special case is PackOp.axis() equal to rank(input), in which case
  // we can't directly concatenate along PackOp.axis(). Instead,
  // we concat along axis = 0, reshape into [N, A, B, C],
  // and then need an extra transpose to [A, B, C, N].
  int64_t concat_axis;
  SmallVector<int32_t, 8> perm;
  SmallVector<int64_t, 8> reshape_output_shape;
  if (axis == 0 && input_tensor_rank == 0) {
    concat_axis = 0;
    // No reshape or perm is needed, since the inputs have been reshaped into
    // rank 1, size 1 tensors. The output will be rank 1, size N.
  } else if (axis == input_tensor_rank) {
    concat_axis = 0;

    // A special case when the stack axis is equal to the input tensor rank:
    // the output shape is [A, B, C, N],
    // so the reshape output will be [N, A, B, C]
    // and perm will be [1, 2, 3, 0].
    reshape_output_shape.push_back(output_shape_vals[axis]);
    for (int d = 0; d < input_tensor_rank; d++) {
      perm.push_back(d + 1);
      reshape_output_shape.push_back(output_shape_vals[d]);
    }
    perm.push_back(0);
  } else {
    // General case, doesn't need a perm vector.
    concat_axis = axis;
    reshape_output_shape.assign(output_shape_vals.begin(),
                                output_shape_vals.end());
  }
  IntegerAttr concat_axis_attr = rewriter.getI64IntegerAttr(concat_axis);
  ArrayAttr shape_attr = rewriter.getI64ArrayAttr(reshape_output_shape);

  // For each concat output, the shape will be different.
  // If the input shape is [A, B, C] and concat_axis = 0, the 1st concat output
  // will be [2 * A, B, C].
  int orig_input_dim_on_axis;
  SmallVector<int64_t, 4> concat_output_shape;
  if (input_tensor_rank == 0) {
    concat_output_shape.push_back(1);
    orig_input_dim_on_axis = 1;
  } else {
    for (int i = 0; i < input_tensor_rank; i++) {
      concat_output_shape.push_back(input0_tensor_shape[i]);
    }
    orig_input_dim_on_axis = input0_tensor_shape[concat_axis];
  }

  concat_output_shape[concat_axis] = orig_input_dim_on_axis * 2;
  RankedTensorType concat_type = RankedTensorType::get(
      ArrayRef<int64_t>(concat_output_shape), result_type.getElementType());
  auto a1_concat_op = rewriter.create<tosa::ConcatOp>(
      op->getLoc(), concat_type, inputs[0], inputs[1], concat_axis_attr);

  // The k-th concat output will be [(k+1) * A, B, C]; the last output will be
  // [N * A, B, C].
  for (int i = 2; i < inputs.size(); i++) {
    concat_output_shape[concat_axis] = orig_input_dim_on_axis * (i + 1);
    concat_type = RankedTensorType::get(ArrayRef<int64_t>(concat_output_shape),
                                        result_type.getElementType());
    a1_concat_op = rewriter.create<tosa::ConcatOp>(op->getLoc(), concat_type,
                                                   a1_concat_op.getResult(),
                                                   inputs[i], concat_axis_attr);
  }

  // No reshape or transpose is needed if the input tensors are rank 0, since
  // the inputs were reshaped beforehand.
  if (input_tensor_rank == 0) return a1_concat_op.getResult();

  // Reshape [N * A, B, C] to [N, A, B, C].
  RankedTensorType reshape_output_type = RankedTensorType::get(
      ArrayRef<int64_t>(reshape_output_shape), result_type.getElementType());

  auto a2_reshape_op = rewriter.create<tosa::ReshapeOp>(
      op->getLoc(), reshape_output_type, a1_concat_op.getResult(), shape_attr);

  // If axis is equal to the input tensor rank, we need an extra transpose
  // from [N, A, B, C] to [A, B, C, N].
  if (axis == input_tensor_rank) {
    Value a3_transpose_perm =
        get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, perm);

    return rewriter
        .create<tosa::TransposeOp>(op->getLoc(), result_type,
                                   a2_reshape_op.getResult(), a3_transpose_perm)
        .getResult();
  }

  return a2_reshape_op.getResult();
}

// Lowers the Unpack operator to TOSA.
llvm::Optional<ValueRange> convertUnpackOp(PatternRewriter& rewriter,
                                           Operation* op, Value input_value,
                                           int32_t axis) {
  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  if (!input_type) return llvm::None;

  auto input_shape = input_type.getShape();
  int64_t input_rank = input_shape.size();

  SmallVector<Value, 4> results_vec;

  // Negative axis allowed as long as it's within [-input_rank, input_rank).
  if (axis < 0) axis += input_rank;

  assert(axis >= 0 && axis < input_shape.size());

  // A list of the output types for each slice op
  SmallVector<Type, 4> outs_type_vec;

  // Step 1: transpose 'axis' to leftmost dimension.
  Value transposed_input_value;
  if (axis != 0) {
    SmallVector<int32_t, 8> perm_vec;
    SmallVector<int64_t, 2> a1_transpose_shape(input_rank);

    perm_vec.push_back(axis);
    for (int i = 0; i < input_rank; i++) {
      if (i == axis) continue;
      perm_vec.push_back(i);
    }

    Value a1_transpose_perm =
        get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, perm_vec);

    for (int i = 0; i < input_rank; i++) {
      a1_transpose_shape[i] = input_shape[perm_vec[i]];
    }

    auto a1_transpose_op = rewriter.create<tosa::TransposeOp>(
        op->getLoc(),
        RankedTensorType::get(ArrayRef<int64_t>(a1_transpose_shape),
                              input_type.getElementType()),
        input_value, a1_transpose_perm);

    transposed_input_value = a1_transpose_op.getResult();
  } else {
    // Do nothing if axis is already at leftmost dimension.
    transposed_input_value = input_value;
  }

  // Step 2: slice [N, A, B, C] into N [A, B, C].
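  //
  // Illustrative example: with axis already at the front, an input of shape
  // [N, A, B, C] yields N tosa.slice results of shape [1, A, B, C], each of
  // which is reshaped to [A, B, C] below.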
  RankedTensorType transposed_input_type =
      transposed_input_value.getType().dyn_cast<RankedTensorType>();
  if (!transposed_input_type) return llvm::None;

  auto transposed_input_shape = transposed_input_type.getShape();
  int64_t transposed_input_rank = transposed_input_shape.size();

  for (int i = 0; i < transposed_input_shape[0]; i++) {
    SmallVector<int64_t, 4> begin_vals, size_vals, shape_vals;

    for (int j = 0; j < transposed_input_rank; j++) {
      if (j == 0) {
        begin_vals.push_back(i);
        size_vals.push_back(1);
      } else {
        begin_vals.push_back(0);
        size_vals.push_back(transposed_input_shape[j]);
        shape_vals.push_back(transposed_input_shape[j]);
      }
    }

    ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
    ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);

    auto a2_slice_op = rewriter.create<tosa::SliceOp>(
        op->getLoc(),
        RankedTensorType::get(ArrayRef<int64_t>(size_vals),
                              transposed_input_type.getElementType()),
        transposed_input_value, begin, size);

    auto a3_reshape_op = rewriter.create<tosa::ReshapeOp>(
        op->getLoc(),
        RankedTensorType::get(ArrayRef<int64_t>(shape_vals),
                              transposed_input_type.getElementType()),
        a2_slice_op.getResult(), rewriter.getI64ArrayAttr(shape_vals));

    outs_type_vec.push_back(RankedTensorType::get(
        ArrayRef<int64_t>(shape_vals), transposed_input_type.getElementType()));

    results_vec.push_back(a3_reshape_op.getResult());
  }

  // Combine the sequence of tosa.slice() ops into a list
  // using the IdentityN operator.
  return rewriter
      .create<tosa::IdentityNOp>(op->getLoc(), ArrayRef<Type>(outs_type_vec),
                                 results_vec)
      .getResults();
}

// Lowers the Select operator to TOSA.
llvm::Optional<Value> convertSelectOp(PatternRewriter& rewriter, Operation* op,
                                      Value result_value, Value condition_value,
                                      Value x_value, Value y_value) {
  RankedTensorType result_type =
      result_value.getType().dyn_cast<RankedTensorType>();
  RankedTensorType condition_type =
      condition_value.getType().dyn_cast<RankedTensorType>();
  RankedTensorType x_type = x_value.getType().dyn_cast<RankedTensorType>();
  RankedTensorType y_type = y_value.getType().dyn_cast<RankedTensorType>();

  if (!result_type || !condition_type || !x_type || !y_type) {
    op->emitOpError("Select: failed ranked tensor type check");
    return llvm::None;
  }

  // First check whether we need to reshape the condition to match
  // the same rank as the then/else clauses.
  if (result_type.getRank() == condition_type.getRank()) {
    // Nothing to reshape.
    return rewriter
        .create<tosa::SelectOp>(op->getLoc(), result_type, condition_value,
                                x_value, y_value)
        .getResult();
  }

  // Need to reshape the condition.
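  // Leading size-1 dimensions are prepended so the condition broadcasts
  // against the then/else operands. For example (illustrative shapes only),
  // a rank-1 condition of shape [C] with a rank-3 result becomes [1, 1, C].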
  SmallVector<int64_t, 8> new_cond_dims(
      result_type.getRank() - condition_type.getRank(), 1);

  for (int i = 0; i < condition_type.getRank(); i++) {
    new_cond_dims.push_back(condition_type.getShape()[i]);
  }

  auto reshape_op = rewriter.create<tosa::ReshapeOp>(
      op->getLoc(),
      RankedTensorType::get(ArrayRef<int64_t>(new_cond_dims),
                            condition_type.getElementType()),
      condition_value, rewriter.getI64ArrayAttr(new_cond_dims));

  return rewriter
      .create<tosa::SelectOp>(op->getLoc(), result_type, reshape_op, x_value,
                              y_value)
      .getResult();
}

// Lowers the ZerosLike operator to TOSA by creating a constant
// of the desired type and shape.
llvm::Optional<Value> convertZerosLikeOp(PatternRewriter& rewriter,
                                         Operation* op, Value result,
                                         Value input) {
  RankedTensorType result_type = result.getType().dyn_cast<RankedTensorType>();
  if (!result_type) {
    op->emitOpError("Zeroslike: result not ranked tensor type");
    return llvm::None;
  }

  RankedTensorType input_type = input.getType().dyn_cast<RankedTensorType>();
  if (!input_type) {
    op->emitOpError("Zeroslike: input not ranked tensor type");
    return llvm::None;
  }

  auto input_shape = input_type.getShape();

  ShapedType zero_type =
      RankedTensorType::get(input_shape, input_type.getElementType());
  Attribute zero_attr = rewriter.getZeroAttr(zero_type);

  return rewriter
      .create<tosa::ConstOp>(op->getLoc(), zero_type,
                             zero_attr.cast<ElementsAttr>())
      .getResult();
}

// Lowers the Mul operator to TOSA. For quantized types, this requires
// inserting rescale operators before and after the operation.
llvm::Optional<Value> convertMultiplyOp(PatternRewriter& rewriter,
                                        Operation* op, Value output_val,
                                        Value input_lhs_val,
                                        Value input_rhs_val) {
  RankedTensorType input_lhs_type =
      input_lhs_val.getType().dyn_cast<RankedTensorType>();
  RankedTensorType input_rhs_type =
      input_rhs_val.getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      output_val.getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_lhs_type || !input_rhs_type || !output_type) return llvm::None;

  bool input_lhs_is_qtype =
      input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool input_rhs_is_qtype =
      input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_lhs_is_qtype != output_is_qtype ||
      input_rhs_is_qtype != output_is_qtype) {
    op->emitOpError(
        "ConvertMultiplyOp: input/output tensor should "
        "be all quantized or all floating-point");
    return llvm::None;
  }

  Value output;
  if (output_is_qtype) {
    RankedTensorType rescale_type =
        RankedTensorType::get(output_type.getShape(), rewriter.getI32Type());
    auto input_lhs_qtype = input_lhs_type.getElementType()
                               .cast<mlir::quant::UniformQuantizedType>();
    auto input_rhs_qtype = input_rhs_type.getElementType()
                               .cast<mlir::quant::UniformQuantizedType>();
    auto output_qtype =
        output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
    double in_lhs_scale = input_lhs_qtype.getScale();
    double in_rhs_scale = input_rhs_qtype.getScale();
    double output_scale = output_qtype.getScale();

    double output_rescale_scale = in_lhs_scale * in_rhs_scale / output_scale;
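    // Why this scale: with real values lhs = s_lhs * (q_lhs - z_lhs) and
    // rhs = s_rhs * (q_rhs - z_rhs), the integer product
    // (q_lhs - z_lhs) * (q_rhs - z_rhs) represents the real product divided
    // by (s_lhs * s_rhs). Requantizing to an output with scale s_out
    // therefore multiplies by s_lhs * s_rhs / s_out, which is
    // output_rescale_scale above.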

    Value op1_rescale_lhs = buildRescaleToInt32(
        rewriter, op, input_lhs_val, 1.0f, input_lhs_qtype.getZeroPoint());
    Value op2_rescale_rhs = buildRescaleToInt32(
        rewriter, op, input_rhs_val, 1.0f, input_rhs_qtype.getZeroPoint());
    auto op3_mul_op1_op2 = rewriter.create<tosa::MulOp>(
        op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs, 0);
    return buildRescaleFromInt32(
        rewriter, op, output_type, op3_mul_op1_op2.getResult(),
        output_rescale_scale, output_qtype.getZeroPoint());
  }

  return rewriter
      .create<tosa::MulOp>(op->getLoc(), output_type, input_lhs_val,
                           input_rhs_val, 0)
      .getResult();
}

// Lowers the SquaredDifference operator to TOSA.
llvm::Optional<Value> convertSquaredDifferenceOp(PatternRewriter& rewriter,
                                                 Operation* op, Value result,
                                                 Value x, Value y) {
  // Squared-difference is (x-y)*(x-y).
  // This lowering calculates the difference and multiplies.
  RankedTensorType result_type = result.getType().dyn_cast<RankedTensorType>();
  if (!result_type) {
    op->emitOpError("SquaredDifference: result not ranked tensor type");
    return llvm::None;
  }

  RankedTensorType x_type = x.getType().dyn_cast<RankedTensorType>();
  RankedTensorType y_type = y.getType().dyn_cast<RankedTensorType>();
  if (!x_type || !y_type) {
    op->emitOpError("SquaredDifference: inputs not ranked tensor type");
    return llvm::None;
  }

  auto sub_op = rewriter.create<tosa::SubOp>(op->getLoc(), result_type, x, y);
  return rewriter
      .create<tosa::MulOp>(op->getLoc(), result_type, sub_op.getResult(),
                           sub_op.getResult(), 0)
      .getResult();
}

// Lowers the Round operator to TOSA.
llvm::Optional<Value> convertRoundOp(PatternRewriter& rewriter, Operation* op,
                                     Value result, Value input) {
  // Rounds to the nearest value by calculating floor(input + 0.5).
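  // For example, floor(2.3 + 0.5) = 2.0 and floor(2.7 + 0.5) = 3.0; halfway
  // values such as 2.5 round upward to 3.0 rather than to the nearest even
  // value.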
  RankedTensorType result_type = result.getType().dyn_cast<RankedTensorType>();
  if (!result_type) {
    op->emitOpError("Round: result not ranked tensor type");
    return llvm::None;
  }

  RankedTensorType input_type = input.getType().dyn_cast<RankedTensorType>();
  if (!input_type) {
    op->emitOpError("Round: input not ranked tensor type");
    return llvm::None;
  }

  auto add_op = rewriter.create<tosa::AddOp>(
      op->getLoc(), result_type, input,
      getTosaConstTensorSingleF32(rewriter, op, 0.5));

  return rewriter
      .create<tosa::FloorOp>(op->getLoc(), result_type, add_op.getResult())
      .getResult();
}

// Lowers ConcatV2 to TOSA.
llvm::Optional<Value> convertConcatV2Op(PatternRewriter& rewriter,
                                        Operation* op, Value result_value,
                                        SmallVector<Value, 8>& values,
                                        int32_t axis) {
  // ConcatV2 becomes a series of TOSA Concat operators that take pairs of
  // tensors as arguments. Rank-0 tensors are reshaped to rank-1,
  // shape (1,) tensors.
  RankedTensorType result_type =
      result_value.getType().dyn_cast<RankedTensorType>();
  if (!result_type) {
    op->emitOpError("ConcatV2Op: result type not ranked tensor.");
    return llvm::None;
  }

  // Valid axis in TF is [-rank(input), rank(input)).
  // Valid axis in TOSA is [0, rank(input)).
  // Add rank(input) once if axis is negative.
  RankedTensorType input_type =
      op->getOperand(0).getType().dyn_cast<RankedTensorType>();
  if (!input_type) {
    op->emitOpError("ConcatV2Op: input type not ranked tensor.");
    return llvm::None;
  }

  auto input_rank = input_type.getShape().size();

  if (axis < 0) axis += input_rank;

  assert(values.size() >= 2);

  if (!values[0].getType().dyn_cast<RankedTensorType>() ||
      !values[1].getType().dyn_cast<RankedTensorType>()) {
    op->emitOpError("ConcatV2Op: value type not ranked tensor.");
    return llvm::None;
  }

  Value lhs_val = values[0];
  Value rhs_val = values[1];
  RankedTensorType lhs_type = lhs_val.getType().cast<RankedTensorType>();
  RankedTensorType rhs_type = rhs_val.getType().cast<RankedTensorType>();
  ArrayRef<int64_t> lhs_tensor_shape = lhs_type.getShape();
  ArrayRef<int64_t> rhs_tensor_shape = rhs_type.getShape();
  int input_tensor_rank = lhs_tensor_shape.size();

  // For each concat output, the shape will be different.
  // If the input tensors are rank 0, they are reshaped to rank 1, size 1
  // before performing the concat. Otherwise, every dimension keeps the same
  // size as the input except the concatenated axis.
  //
  // If the inputs are [A0, B, C] and [A1, B, C] and axis = 0,
  // this concat output will be [A0 + A1, B, C].
  SmallVector<int64_t, 4> concat_result_shape;
  if (input_tensor_rank == 0) {
    if (axis != 0) {
      op->emitOpError("ConcatV2Op: axis invalid.");
      return llvm::None;
    }
    SmallVector<int64_t, 8> reshape_rank1_size1_shape{1};
    RankedTensorType reshape_rank1_size1_type =
        RankedTensorType::get(ArrayRef<int64_t>(reshape_rank1_size1_shape),
                              result_type.getElementType());
    ArrayAttr shape_rank1_size1_attr =
        rewriter.getI64ArrayAttr(reshape_rank1_size1_shape);
    for (int i = 0; i < values.size(); i++) {
      auto a0_reshape_op = rewriter.create<tosa::ReshapeOp>(
          op->getLoc(), reshape_rank1_size1_type, values[i],
          shape_rank1_size1_attr);
      values[i] = a0_reshape_op.getResult();
    }
    concat_result_shape.push_back(2);
  } else {
    if (axis < 0 || axis >= input_tensor_rank) {
      op->emitOpError("ConcatV2Op: axis invalid.");
      return llvm::None;
    }
    for (int i = 0; i < input_tensor_rank; i++) {
      concat_result_shape.push_back(lhs_tensor_shape[i]);
    }
    concat_result_shape[axis] = lhs_tensor_shape[axis] + rhs_tensor_shape[axis];
  }

  RankedTensorType concat_type = RankedTensorType::get(
      ArrayRef<int64_t>(concat_result_shape), result_type.getElementType());

  mlir::quant::UniformQuantizedType lhs_quant_type =
      lhs_type.getElementType()
          .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
  mlir::quant::UniformQuantizedType rhs_quant_type =
      rhs_type.getElementType()
          .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
  mlir::quant::UniformQuantizedType result_quant_type =
      result_type.getElementType()
          .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();

  double lhs_scale, rhs_scale, result_scale;
  int32_t lhs_zeropoint, rhs_zeropoint, result_zeropoint;

  // tfl.concat currently allows different scales for each input tensor, which
  // the TFLite team will fix in:
  // https://github.com/tensorflow/tensorflow/issues/39658
  //
  // For backward compatibility, we still need to support this artifact by
  // rescaling the inputs so they all have the same scale as the output.
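  //
  // Rescale example (illustrative numbers): if lhs_scale = 0.1 and
  // result_scale = 0.05, the lhs is rescaled by lhs_scale / result_scale =
  // 2.0 so that its integer values are expressed on the output's scale
  // before the concat.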
  if (result_quant_type && lhs_quant_type && rhs_quant_type) {
    lhs_scale = static_cast<double>(lhs_quant_type.getScale());
    lhs_zeropoint = lhs_quant_type.getZeroPoint();
    rhs_scale = static_cast<double>(rhs_quant_type.getScale());
    rhs_zeropoint = rhs_quant_type.getZeroPoint();
    result_scale = static_cast<double>(result_quant_type.getScale());
    result_zeropoint = result_quant_type.getZeroPoint();

    // Rescale input if scale is not equal to output tensor scale.
    if (lhs_scale != result_scale) {
      RankedTensorType rescale_type =
          RankedTensorType::get(lhs_type.getShape(), result_quant_type);

      Value rescale_op = buildRescale(rewriter, op, rescale_type, lhs_val,
                                      lhs_scale / result_scale, lhs_zeropoint,
                                      result_zeropoint);

      lhs_val = rescale_op;
    }
    if (rhs_scale != result_scale) {
      RankedTensorType rescale_type =
          RankedTensorType::get(rhs_type.getShape(), result_quant_type);

      Value rescale_op = buildRescale(rewriter, op, rescale_type, rhs_val,
                                      rhs_scale / result_scale, rhs_zeropoint,
                                      result_zeropoint);

      rhs_val = rescale_op;
    }
  }

  auto concat_op = rewriter.create<tosa::ConcatOp>(
      op->getLoc(), concat_type, lhs_val, rhs_val,
      rewriter.getI64IntegerAttr(axis));
  for (int i = 2; i < values.size(); i++) {
    rhs_val = values[i];
    rhs_type = rhs_val.getType().dyn_cast<RankedTensorType>();
    rhs_tensor_shape = rhs_type.getShape();
    rhs_quant_type = rhs_type.getElementType()
                         .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();

    if (input_tensor_rank == 0) {
      concat_result_shape[axis] = concat_result_shape[axis] + 1;
    } else {
      concat_result_shape[axis] =
          concat_result_shape[axis] + rhs_tensor_shape[axis];
    }
    concat_type = RankedTensorType::get(ArrayRef<int64_t>(concat_result_shape),
                                        result_type.getElementType());

    if (rhs_quant_type && result_quant_type) {
      // Use double here for consistency with the scales computed above.
      rhs_scale = static_cast<double>(rhs_quant_type.getScale());
      rhs_zeropoint = rhs_quant_type.getZeroPoint();

      if (rhs_scale != result_scale) {
        RankedTensorType rescale_type =
            RankedTensorType::get(rhs_type.getShape(), result_quant_type);

        Value rescale_op = buildRescale(rewriter, op, rescale_type, rhs_val,
                                        rhs_scale / result_scale, rhs_zeropoint,
                                        result_zeropoint);

        rhs_val = rescale_op;
      }
    }

    concat_op = rewriter.create<tosa::ConcatOp>(
        op->getLoc(), concat_type, concat_op.getResult(), rhs_val,
        rewriter.getI64IntegerAttr(axis));
  }

  return concat_op.getResult();
}

// Lowers SpaceToBatchND to TOSA.
llvm::Optional<Value> convertSpaceToBatchNDOp(PatternRewriter& rewriter,
                                              Operation* op, Value result_value,
                                              Value input_value,
                                              Value block_shape_value,
                                              Value paddings_value) {
  /////////////////////////////////////////////////
  // Operator: output = SpaceToBatchND(input, block_shape, paddings)
  // Lowering:
  //
  // SpaceToBatch input tensors are broken into three pieces:
  //   (a) batch dimension (N in NHWC)
  //   (b) input being transformed to batch dimension (typically H, W in NHWC)
  //   (c) remainder of input (typically C in NHWC)
  //
  // Step 0. Generate padding constant for the first reshape.
  //   No padding on the batch dimension
  //   The input paddings array is addressed as [input_rank][2]
  //   No padding on the remaining dimensions
  //
  //   a0_pad_const = tosa.const(input=Tensor<input_rank, 2>)
  //
  // Step 1. Pad the input tensor
  //
  //   a1_pad_input_op = tosa.pad(input=input, shape=a0_pad_const_op)
  //
  // Step 2. Reshape the padded structure of shape padded_shape to
  //   [batch,
  //    padded_shape[1] / block_shape[0], block_shape[0],
  //    ...
  //    padded_shape[M] / block_shape[M-1], block_shape[M-1]] +
  //   remaining_shape
  //
  //   block_rank = M (number of elements in block_shape)
  //   New rank: input_rank + block_rank
  //
  //   a2_reshape_a1_op = tosa.reshape(input=a1_pad_input_op, shape=a2_shape)
  //
  // Step 3. Transpose dimensions to:
  //   block-shape +
  //   [batch] +
  //   [padded_shape[1] / block_shape[0],
  //    ...
  //    padded_shape[M] / block_shape[M-1]] +
  //   remaining_shape
  //
  //   a3_transpose_a2_op = tosa.transpose(input=a2_reshape_a1_op,
  //                                       perms=a3_perm)
  //
  // Step 4. Reshape the transposed tensor to flatten block_shape
  //   into the batch dimension with the following shape:
  //   [batch * prod(block_shape)] +
  //   [padded_shape[1] / block_shape[0],
  //    ...,
  //    padded_shape[M] / block_shape[M-1]] +
  //   remaining_shape
  //
  //   a4_reshape_a3_op = tosa.reshape(input=a3_transpose_a2_op,
  //                                   shape=a4_shape)
  //

  RankedTensorType result_type =
      result_value.getType().dyn_cast<RankedTensorType>();
  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  RankedTensorType block_shape_type =
      block_shape_value.getType().dyn_cast<RankedTensorType>();
  RankedTensorType paddings_type =
      paddings_value.getType().dyn_cast<RankedTensorType>();

  // Not a ranked tensor output.
  if (!result_type) {
    op->emitOpError("SpaceToBatchND: result type not ranked tensor");
    return llvm::None;
  }
  if (!input_type) {
    op->emitOpError("SpaceToBatchND: input type not ranked tensor");
    return llvm::None;
  }
  if (!block_shape_type) {
    op->emitOpError("SpaceToBatchND: block shape type not ranked tensor");
    return llvm::None;
  }
  if (!paddings_type) {
    op->emitOpError("SpaceToBatchND: paddings type not ranked tensor");
    return llvm::None;
  }

  // Follow implementation in
  // tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc

  // So, to figure out the spatial_shape, remove the batch dimension and
  // then use the next block_rank dimensions. The remaining dimensions are
  // remaining_shape.

  auto block_shape = block_shape_type.getShape();
  auto input_shape = input_type.getShape();

  int block_rank = block_shape[0];
  int batch_size = input_shape[0];
  int input_rank = input_type.getRank();
  int remaining_shape_rank = input_rank - block_rank - 1;
  int block_num_elems = 1;
  int padding_sum = 0;

  ElementsAttr block_shape_elems;
  ElementsAttr paddings_elems;

  if (!matchPattern(block_shape_value, m_Constant(&block_shape_elems)))
    return llvm::None;

  if (!matchPattern(paddings_value, m_Constant(&paddings_elems)))
    return llvm::None;

  SmallVector<int32_t, 2> a0_pad_const(2 * (input_rank));
  SmallVector<int64_t, 2> padded_shape(input_rank);

  // 1. Pad based on paddings operand. No padding on the batch dimension.
  // The a0_pad_const array is addressed as [input_rank][2], but
  // it is flattened to a 1D array because LLVM appears to only accept 1D.
  //
  // padded_shape[] is the shape of the padded output of step a1.
  // The name is retained for consistency with the TF reference code.
  padded_shape[0] = input_shape[0];

  // Batch dimension padding
  a0_pad_const[0] = 0;
  a0_pad_const[1] = 0;

  // This iterator seems to be the only reliable way to get
  // int values out of a multi-dimensional ElementsAttr.
  int idx = 0;

  for (auto i : paddings_elems.getValues<IntegerAttr>()) {
    a0_pad_const[idx + 2] = i.getInt();
    padding_sum += i.getInt();
    idx++;
  }

  // Insert padding on the spatial shape dimensions
  for (int i = 0; i < block_rank; i++) {
    int32_t lo_pad = a0_pad_const[2 * (i + 1) + 0];
    int32_t hi_pad = a0_pad_const[2 * (i + 1) + 1];

    padded_shape[i + 1] = input_shape[i + 1] + lo_pad + hi_pad;
  }

  // No padding on the remaining_shape dimensions
  for (int i = 0; i < remaining_shape_rank; i++) {
    a0_pad_const[2 * (i + block_rank + 1) + 0] = 0;
    a0_pad_const[2 * (i + block_rank + 1) + 1] = 0;
    padded_shape[i + block_rank + 1] = input_shape[i + block_rank + 1];
  }

  RankedTensorType a0_pad_const_attr_type =
      RankedTensorType::get({(input_rank), 2}, rewriter.getIntegerType(32));

  // Create a const op to generate the tensor type for the input padding array
  auto a0_pad_const_op = rewriter.create<tosa::ConstOp>(
      op->getLoc(), a0_pad_const_attr_type,
      DenseElementsAttr::get(a0_pad_const_attr_type,
                             llvm::makeArrayRef<int32_t>(a0_pad_const)));

  auto a1_pad_input_op = rewriter.create<tosa::PadOp>(
      op->getLoc(),
      RankedTensorType::get(ArrayRef<int64_t>(padded_shape),
                            result_type.getElementType()),
      input_value, a0_pad_const_op.getResult());

  // 2. Reshape the padded structure of shape padded_shape to
  //    [batch,
  //     padded_shape[1] / block_shape[0], block_shape[0],
  //     ...
  //     padded_shape[M] / block_shape[M-1], block_shape[M-1]] +
  //    remaining_shape

  // block_rank = M (number of elements in block_shape)
  // New rank: input_rank + block_rank
  SmallVector<int64_t, 2> a2_shape(1 + block_rank * 2 + remaining_shape_rank);

  // First dimension is batch.
  a2_shape[0] = input_type.getShape()[0];
  for (int i = 0; i < block_rank; i++) {
    int32_t block_shape_val =
        rewriter
            .getI32IntegerAttr(
                block_shape_elems.getValue<IntegerAttr>(i).getInt())
            .getInt();
    a2_shape[1 + i * 2 + 0] = padded_shape[1 + i] / block_shape_val;
    a2_shape[1 + i * 2 + 1] = block_shape_val;
    block_num_elems *= block_shape_val;
  }

  // Copy in the remaining (non-block) input dimensions.
  for (int i = 0; i < remaining_shape_rank; i++) {
    a2_shape[1 + block_rank * 2 + i] = input_shape[1 + block_rank + i];
  }

  auto a2_reshape_a1_op = rewriter.create<tosa::ReshapeOp>(
      op->getLoc(),
      RankedTensorType::get(ArrayRef<int64_t>(a2_shape),
                            result_type.getElementType()),
      a1_pad_input_op.getResult(), rewriter.getI64ArrayAttr(a2_shape));

  // 3. Transpose dimensions to:
  //    block-shape +
  //    [batch] +
  //    [padded_shape[1] / block_shape[0],
  //     ...
  //     padded_shape[M] / block_shape[M-1]] +
  //    remaining_shape
  int32_t a2_reshape_a1_rank =
      a2_reshape_a1_op.getResult().getType().cast<RankedTensorType>().getRank();
  SmallVector<int32_t, 8> a3_perm(a2_reshape_a1_rank);
  SmallVector<int64_t, 2> a3_transpose_shape(a2_reshape_a1_rank);

  for (int i = 0; i < block_rank; i++) {
    a3_perm[i] = 1 + 2 * i + 1;
    a3_perm[block_rank + 1 + i] = 1 + 2 * i;
  }
  a3_perm[block_rank] = 0;
  for (int i = 1 + block_rank * 2; i < a2_reshape_a1_rank; i++) {
    a3_perm[i] = i;
  }

  for (int i = 0; i < a3_transpose_shape.size(); i++) {
    a3_transpose_shape[i] = a2_shape[a3_perm[i]];
  }

  Value a3_transpose_const =
      get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, a3_perm);

  auto a3_transpose_a2_op = rewriter.create<tosa::TransposeOp>(
      op->getLoc(),
      RankedTensorType::get(ArrayRef<int64_t>(a3_transpose_shape),
                            result_type.getElementType()),
      a2_reshape_a1_op.getResult(), a3_transpose_const);

  // 4. Reshape the transposed tensor to flatten block_shape
  //    into the batch dimension with the following shape:
  //    [batch * prod(block_shape)] +
  //    [padded_shape[1] / block_shape[0],
  //     ...,
  //     padded_shape[M] / block_shape[M-1]] +
  //    remaining_shape
  SmallVector<int64_t, 2> a4_reshape_shape(input_rank);

  // Batch
  a4_reshape_shape[0] = batch_size * block_num_elems;

  // padded shape / block_shape.
  for (int i = 0; i < block_rank; i++) {
    int32_t block_shape_val =
        rewriter
            .getI32IntegerAttr(
                block_shape_elems.getValue<IntegerAttr>(i).getInt())
            .getInt();
    a4_reshape_shape[i + 1] = padded_shape[i + 1] / block_shape_val;
  }

  // Copy in remainder shape.
  for (int i = 0; i < remaining_shape_rank; i++) {
    a4_reshape_shape[1 + block_rank + i] = input_shape[1 + block_rank + i];
  }

  return rewriter
      .create<tosa::ReshapeOp>(op->getLoc(), result_type,
                               a3_transpose_a2_op.getResult(),
                               rewriter.getI64ArrayAttr(a4_reshape_shape))
      .getResult();
}

// Lowers BatchToSpaceND to TOSA.
llvm::Optional<Value> convertBatchToSpaceNDOp(PatternRewriter& rewriter,
                                              Operation* op, Value result_value,
                                              Value input_value,
                                              Value block_shape_value,
                                              Value crops_value) {
  /////////////////////////////////////////////////
  // Operator: output = BatchToSpaceND(input, block_shape, crops)
  // Lowering:
  //
  // BatchToSpace input tensors are broken into three pieces:
  //   (a) batch dimension (N in NHWC)
  //   (b) input being transformed from batch dimension (typically H, W in
  //       NHWC)
  //   (c) remainder of input (typically C in NHWC)
  //
  // Step 1. Reshape input to:
  //   [block_shape[0],
  //    ...,
  //    block_shape[M-1],
  //    batch / prod(block_shape),
  //    input_shape[1],
  //    ...,
  //    input_shape[N-1]]
  //
  //   a1_reshape_input_op = tosa.reshape(input=input, shape=a1_shape)
  //
  // Step 2. Permute to shape
  //   [ batch / prod(block_shape) ],
  //   [ input_shape[1] ], [ block_shape[0] ],
  //   ...
  //   [ input_shape[M] ], [ block_shape[M-1] ]
  //   + remaining_input_shapes input_shape[M+1 .. N-1]
  //
  //   a2_transpose_a1 = tosa.transpose(input=a1_reshape_input_op,
  //                                    shape=a2_shape)
  //
  // Step 3. Reshape to:
  //   [ batch / prod(block_shape) ],
  //   [ input_shape[1] * block_shape[0] ],
  //   ...
  //   [ input_shape[M] * block_shape[M-1] ]
  //   + remaining input shapes [input_shape[M+1 .. N-1]]
  //
  //   a3_reshape_a2 = tosa.reshape(input=a2_transpose_a1, shape=a3_shape)
  //
  // Step 4. Crop the start/end dimensions according to crops of the
  //   a3_reshape_a2 shape
  //
  //   a4_slice_a3 = tosa.slice(input=a3_reshape_a2, start=a4_start,
  //                            size=a4_size)

  RankedTensorType result_type =
      result_value.getType().dyn_cast<RankedTensorType>();
  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  RankedTensorType block_shape_type =
      block_shape_value.getType().dyn_cast<RankedTensorType>();
  RankedTensorType crops_type =
      crops_value.getType().dyn_cast<RankedTensorType>();

  if (!result_type) {
    op->emitOpError("BatchToSpaceND: result type not ranked tensor");
    return llvm::None;
  }
  if (!input_type) {
    op->emitOpError("BatchToSpaceND: input type not ranked tensor");
    return llvm::None;
  }
  if (!block_shape_type) {
    op->emitOpError("BatchToSpaceND: block shape type not ranked tensor");
    return llvm::None;
  }
  if (!crops_type) {
    op->emitOpError("BatchToSpaceND: crops type not ranked tensor");
    return llvm::None;
  }

  // Another 4-step process
  int block_rank = block_shape_type.getShape()[0];
  int input_rank = input_type.getRank();
  int crops_dims = crops_type.getShape()[0];
  int remaining_shape_rank = input_rank - block_rank - 1;
  auto input_shape = input_type.getShape();

  ElementsAttr block_shape_elems;
  ElementsAttr crops_elems;

  if (!matchPattern(block_shape_value, m_Constant(&block_shape_elems))) {
    op->emitOpError("BatchToSpaceND: block_shape not a constant");
    return llvm::None;
  }

  if (!matchPattern(crops_value, m_Constant(&crops_elems))) {
    op->emitOpError("BatchToSpaceND: crops not a constant");
    return llvm::None;
  }

  SmallVector<int64_t, 4> block_shape(block_rank);
  SmallVector<std::pair<int64_t, int64_t>, 4> crops(crops_dims);

  // Extract values for block_shape and crops now.
  int block_num_elems = 1;
  for (int i = 0; i < block_rank; i++) {
    int block_shape_val =
        rewriter
            .getI32IntegerAttr(
                block_shape_elems.getValue<IntegerAttr>(i).getInt())
            .getInt();
    block_num_elems *= block_shape_val;
    block_shape[i] = block_shape_val;
  }

  // This iterator seems to be the only reliable way to get
  // int values out of a multi-dimensional ElementsAttr
  SmallVector<int32_t, 2> crops_const(2 * (crops_dims));
  int idx = 0;
  for (auto i : crops_elems.getValues<IntegerAttr>()) {
    crops_const[idx++] = i.getInt();
  }

  // crops_const is laid out as [crops_dims][2], so row i starts at 2 * i.
  for (int i = 0; i < crops_dims; i++) {
    int crops_lo = crops_const[i * 2 + 0];
    int crops_hi = crops_const[i * 2 + 1];
    crops[i] = std::make_pair(crops_lo, crops_hi);
  }

  // Step 1. Reshape input to:
  //   [block_shape[0],
  //    ...,
  //    block_shape[M-1],
  //    batch / prod(block_shape),
  //    input_shape[1],
  //    ...,
  //    input_shape[N-1]]
  SmallVector<int64_t, 2> a1_shape(block_rank + input_rank);

  for (int i = 0; i < block_rank; i++) a1_shape[i] = block_shape[i];

  a1_shape[block_rank] = input_shape[0] / block_num_elems;

  for (int i = 0; i < input_rank - 1; i++)
    a1_shape[i + block_rank + 1] = input_shape[i + 1];

  auto a1_reshape_input_op = rewriter.create<tosa::ReshapeOp>(
      op->getLoc(),
      RankedTensorType::get(ArrayRef<int64_t>(a1_shape),
                            result_type.getElementType()),
      input_value, rewriter.getI64ArrayAttr(a1_shape));

  // 2. Permute to shape
  //    [ batch / prod(block_shape) ],
  //    [ input_shape[1] ], [ block_shape[0] ],
  //    ...
  //    [ input_shape[M] ], [ block_shape[M-1] ]
  //    + remaining_input_shapes input_shape[M+1 .. N-1]

  // 2a. calculate the permutation
  SmallVector<int32_t, 8> a2_perm(block_rank + input_rank);
  SmallVector<int64_t, 2> a2_transpose_shape(block_rank + input_rank);

  a2_perm[0] = block_rank;
  for (int i = 0; i < block_rank; i++) {
    a2_perm[1 + i * 2 + 0] = block_rank + 1 + i;
    a2_perm[1 + i * 2 + 1] = i;
  }

  for (int i = 0; i < remaining_shape_rank; i++) {
    a2_perm[1 + 2 * block_rank + i] = 1 + 2 * block_rank + i;
  }

  // 2b. calculate the a2_permuted shape
  for (int i = 0; i < (block_rank + input_rank); i++) {
    a2_transpose_shape[i] = a1_shape[a2_perm[i]];
  }

  Value a2_transpose_perm =
      get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, a2_perm);
  auto a2_transpose_a1_op = rewriter.create<tosa::TransposeOp>(
      op->getLoc(),
      RankedTensorType::get(ArrayRef<int64_t>(a2_transpose_shape),
                            result_type.getElementType()),
      a1_reshape_input_op.getResult(), a2_transpose_perm);

  // Step 3. Reshape to:
  //   [ batch / prod(block_shape) ],
  //   [ input_shape[1] * block_shape[0] ],
  //   ...
  //   [ input_shape[M] * block_shape[M-1] ]
  //   + remaining input shapes [input_shape[M+1 .. N-1]]
  SmallVector<int64_t, 2> a4_shape(input_rank);

  a4_shape[0] = input_shape[0] / block_num_elems;
  for (int i = 0; i < block_rank; i++) {
    a4_shape[1 + i] = input_shape[i + 1] * block_shape[i];
  }
  for (int i = 0; i < remaining_shape_rank; i++) {
    a4_shape[1 + block_rank + i] = input_shape[block_rank + 1 + i];
  }

  auto a3_reshape_a2 = rewriter.create<tosa::ReshapeOp>(
      op->getLoc(),
      RankedTensorType::get(ArrayRef<int64_t>(a4_shape),
                            result_type.getElementType()),
      a2_transpose_a1_op.getResult(), rewriter.getI64ArrayAttr(a4_shape));

  // 4. Crop the start/end dimensions on the spatial dimensions according to
  //    crops. Use a slice operator to do the cropping.
  //
  // Calculate a beginning point and a size:
  // - Begin is the origin, offset by the lo crop amount in each dimension
  // - Size is the reshaped tensor size, minus the quantity (lo + hi) for each
  //   dimension
  SmallVector<int64_t, 4> a4_begin_vals(input_rank), a4_size_vals(input_rank);

  for (int i = 0; i < input_rank; i++) {
    // Batch dimension and remaining dimensions.
    if (i == 0 || i > crops_dims) {
      a4_begin_vals[i] = 0;
      a4_size_vals[i] = result_type.getShape()[i];
    } else {
      // Spatial dimension.
      assert(i - 1 >= 0 && i - 1 < crops_dims);
      a4_begin_vals[i] = crops[i - 1].first;
      a4_size_vals[i] = a4_shape[i] - crops[i - 1].first - crops[i - 1].second;
    }
  }

  return rewriter
      .create<tosa::SliceOp>(
          op->getLoc(),
          RankedTensorType::get(ArrayRef<int64_t>(a4_size_vals),
                                result_type.getElementType()),
          a3_reshape_a2.getResult(), rewriter.getI64ArrayAttr(a4_begin_vals),
          rewriter.getI64ArrayAttr(a4_size_vals))
      .getResult();
}

// Lowers ExpandDims to TOSA.
llvm::Optional<Value> convertExpandDimsOp(PatternRewriter& rewriter,
                                          Operation* op, Value result_value,
                                          Value input_value, Value dim_value) {
  // Lowers to a reshape op with 1's inserted in the appropriate dimensions.
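  // For example (illustrative shapes only), an input of shape [2, 3] with
  // dim = 1 is reshaped to [2, 1, 3]; a negative or out-of-range dim appends
  // the new dimension at the end, giving [2, 3, 1].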
  RankedTensorType output_type =
      result_value.getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) {
    op->emitOpError("ExpandDims: output type not ranked tensor");
    return llvm::None;
  }

  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  if (!input_type) {
    op->emitOpError("ExpandDims: input type not ranked tensor");
    return llvm::None;
  }

  auto input_shape = input_type.getShape();

  ElementsAttr dim_elem;
  if (!matchPattern(dim_value, m_Constant(&dim_elem))) return llvm::None;

  assert(dim_elem.getType().getRank() == 0 && "expected scalar tensor");
  int32_t dim = dim_elem.getValue<IntegerAttr>({}).getInt();

  SmallVector<int64_t, 4> reshape_dims;
  if (dim < 0 || dim >= input_shape.size()) {  // add dim at end of tensor
    dim = input_shape.size();
    for (int i = 0; i < input_shape.size(); i++) {
      reshape_dims.emplace_back(input_shape[i]);
    }
    reshape_dims.emplace_back(1);
  } else {
    for (int i = 0; i < input_shape.size(); i++) {
      if (i == dim) {
        reshape_dims.emplace_back(1);
      }
      reshape_dims.emplace_back(input_shape[i]);
    }
  }

  ArrayAttr shape_attr = rewriter.getI64ArrayAttr(reshape_dims);

  return rewriter
      .create<tosa::ReshapeOp>(op->getLoc(), output_type, input_value,
                               shape_attr)
      .getResult();
}

// Lowers Squeeze to TOSA.
llvm::Optional<Value> convertSqueezeOp(PatternRewriter& rewriter, Operation* op,
                                       Value result_value, Value input_value,
                                       SmallVector<int32_t, 8>& squeeze_dims) {
  // Lowers to a reshape op where dimensions in squeeze_dims with size=1
  // are removed.
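  // For example (illustrative shapes only), an input of shape [1, 2, 1, 3]
  // with empty squeeze_dims reshapes to [2, 3], while squeeze_dims = {0}
  // reshapes to [2, 1, 3].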
  RankedTensorType output_type =
      result_value.getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) {
    op->emitOpError("Squeeze: output type not ranked tensor");
    return llvm::None;
  }

  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  if (!input_type) {
    op->emitOpError("Squeeze: input type not ranked tensor");
    return llvm::None;
  }

  auto input_shape = input_type.getShape();

  SmallVector<int64_t, 8> reshape_dims;

  if (squeeze_dims.empty()) {  // remove all 1-dims
    for (int i = 0; i < input_shape.size(); i++) {
      if (input_shape[i] != 1) {
        reshape_dims.emplace_back(input_shape[i]);
      }
    }
  } else {
    // Remove only the specified dims.
    // First sort the array so they can be picked off in sequence.
    std::sort(squeeze_dims.begin(), squeeze_dims.end(),
              [](const int32_t& a, const int32_t& b) { return a < b; });

    int pos = 0;
    auto dim = squeeze_dims[pos];
    for (int i = 0; i < input_shape.size(); i++) {
      if (i == dim) {
        pos = pos + 1;
        if (pos < squeeze_dims.size())
          dim = squeeze_dims[pos];
        else
          dim = -1;  // Invalid
      } else {
        reshape_dims.emplace_back(input_shape[i]);
      }
    }
  }

  ArrayAttr shape_attr = rewriter.getI64ArrayAttr(reshape_dims);

  return rewriter
      .create<tosa::ReshapeOp>(op->getLoc(), output_type, input_value,
                               shape_attr)
      .getResult();
}

// Lowers ELU to a sequence of TOSA ops.
llvm::Optional<Value> convertEluOp(PatternRewriter& rewriter, Operation* op,
                                   Value result_value, Value features_value) {
  // Lowers Elu using the following formula:
  //   elu(x) = x < 0 ? (exp(x) - 1) : x
  //   one = const({1});
  //   zero = const({0});
  //   one_bcast = reshape(one, [1] * rank(x))
  //   zero_bcast = reshape(zero, [1] * rank(x))
  //   a1 = exp(x);
  //   a2 = sub(a1, one_bcast)
  //   a3 = ge(x, zero_bcast)
  //   a4 = select(a3, x, a2)
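  //
  // For example, elu(-1.0) = exp(-1.0) - 1, approximately -0.632 (the a2
  // branch), while elu(2.0) = 2.0 (the x branch selected by a3/a4).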
  RankedTensorType output_type =
      result_value.getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) {
    op->emitOpError("Elu: output type not ranked tensor");
    return llvm::None;
  }

  int32_t input_rank = output_type.getShape().size();
  SmallVector<int64_t, 4> bcast_shape(input_rank, 1);

  // Can't directly create a size=1, rank=rank(input) tensor because
  // it will be optimized out. Instead, create a rank-0 tensor and rely on
  // broadcasting.
  Value one_const_op = getTosaConstTensorSingleF32(rewriter, op, 1.0);

  Value zero_const_op = getTosaConstTensorSingleF32(rewriter, op, 0.0);

  auto a1_exp_in_op =
      rewriter.create<tosa::ExpOp>(op->getLoc(), output_type, features_value);

  auto a2_sub_a1_one_op = rewriter.create<tosa::SubOp>(
      op->getLoc(), output_type, a1_exp_in_op.getResult(), one_const_op);

  auto a3_ge_in_zero_op = rewriter.create<tosa::GreaterEqualOp>(
      op->getLoc(),
      RankedTensorType::get(output_type.getShape(), rewriter.getIntegerType(1)),
      features_value, zero_const_op);

  return rewriter
      .create<tosa::SelectOp>(op->getLoc(), output_type,
                              a3_ge_in_zero_op.getResult(), features_value,
                              a2_sub_a1_one_op.getResult())
      .getResult();
}

// Lowers Softmax to a sequence of TOSA ops.
llvm::Optional<Value> convertSoftmaxOp(PatternRewriter& rewriter, Operation* op,
                                       Value result_value, Value logits_value) {
  // softmax = exp(logits) / reduce_sum(exp(logits), -1)
  //
  // Or, equivalently, multiplying both the numerator and denominator by
  // exp(-max(logits)):
  //
  //   softmax = exp(logits - max(logits)) /
  //             reduce_sum(exp(logits - max(logits)), -1)
  //
  // We use the first version for the direct floating-point lowering, and the
  // second version for the quantized lowering, since in the second form the
  // input to exp() is guaranteed to be non-positive, so the LUT output can
  // always stay within [0.0, 1.0].
1411 RankedTensorType output_type =
1412 result_value.getType().dyn_cast<RankedTensorType>();
1413 RankedTensorType input_type =
1414 logits_value.getType().dyn_cast<RankedTensorType>();
1415
1416 // Not a ranked tensor input/output
1417 if (!output_type || !input_type) {
1418 op->emitOpError("Softmax: input and result not ranked tensors");
1419 return llvm::None;
1420 }
1421
1422 // reduce_sum on last dimension
1423 int32_t input_rank = input_type.getShape().size();
1424 ArrayRef<int64_t> logits_shape = output_type.getShape();
1425
1426 if (input_type.getElementType().isa<mlir::quant::QuantizedType>() &&
1427 output_type.getElementType().isa<mlir::quant::QuantizedType>()) {
1428 SmallVector<int64_t, 4> rsum_shape_v(input_type.getShape().begin(),
1429 input_type.getShape().end() - 1);
1430 rsum_shape_v.push_back(1);
1431 ArrayRef<int64_t> rsum_shape(rsum_shape_v);
1432 // The if condition already checks if these are UQTs
1433 mlir::quant::UniformQuantizedType in_quant_type =
1434 input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
1435 mlir::quant::UniformQuantizedType out_quant_type =
1436 output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
1437
1438 auto int16_element_qtype = mlir::quant::UniformQuantizedType::get(
1439 true, rewriter.getIntegerType(16), rewriter.getF32Type(), 1.0f, 0,
1440 -32768, 32767);
1441 RankedTensorType int16_logits_type =
1442 RankedTensorType::get(logits_shape, int16_element_qtype);
1443 RankedTensorType int32_logits_type =
1444 RankedTensorType::get(logits_shape, rewriter.getIntegerType(32));
1445 RankedTensorType int16_rsum_type =
1446 RankedTensorType::get(rsum_shape, int16_element_qtype);
1447 RankedTensorType int32_rsum_type =
1448 RankedTensorType::get(rsum_shape, rewriter.getIntegerType(32));
1449
1450 // Step 1. get x - max(x)
1451 Value op1_rescale_in =
1452 buildRescale(rewriter, op, int32_logits_type, logits_value, 1.0f,
1453 in_quant_type.getZeroPoint(), 0);
1454
1455 auto op2_reducemax_op1 = rewriter.create<tosa::ReduceMaxOp>(
1456 op->getLoc(), int32_rsum_type, op1_rescale_in,
1457 rewriter.getI64IntegerAttr(input_rank - 1));
1458
1459 auto op3_sub_op1_op2 = rewriter.create<tosa::SubOp>(
1460 op->getLoc(), int32_logits_type, op1_rescale_in,
1461 op2_reducemax_op1.getResult());
1462
1463 // Table input range from -16.0 to 16.0, input below -16.0 treated as
1464 // exp(-16.0), which is 0 in 0.16
1465 const double exp_sample_grain = 1.0 / 16.0;
1466 auto exp_func = [exp_sample_grain](int32_t x) -> int32_t {
1467 double v = static_cast<double>(x) * exp_sample_grain;
1468 v = v < 0.0 ? std::exp(v) : 1.0;
1469 return std::lround(32768.0 * v);
1470 };
1471
1472 Value exp_table_const = getTosa1DConstTensorTable(rewriter, op, exp_func);
1473
1474 // Step 2. rescale input
1475 Value op4_rescale_op3 = buildRescale(
1476 rewriter, op, int16_logits_type, op3_sub_op1_op2.getResult(),
1477 in_quant_type.getScale() * 128.0 / exp_sample_grain, 0, 0);
1478
1479 // Step 3. get exp() result
1480 // Since we already make sure input x < 0 in step 1,
1481 // we can utilize full output 0.16 range.
1482
1483 // Output is 0.23
1484 auto op5_table_op4 = rewriter.create<tosa::TableOp>(
1485 op->getLoc(), int32_logits_type, op4_rescale_op3, exp_table_const);
1486
1487 // Right shift 3 bits. output 0.20
1488 auto op6_rshift_op5 = rewriter.create<tosa::ArithmeticRightShiftOp>(
1489 op->getLoc(), int32_logits_type, op5_table_op4.getResult(),
1490 getTosaConstTensorSingleI32(rewriter, op, 3), true);
1491
1492 // Step 4. get sum(exp()). output 12.20
1493 auto op7_reducesum_op6 = rewriter.create<tosa::ReduceSumOp>(
1494 op->getLoc(), int32_rsum_type, op6_rshift_op5.getResult(),
1495 rewriter.getI64IntegerAttr(input_rank - 1));
1496
1497 // Step 5. calculate reciprocal(sum(exp()))
1498 auto op8_clz_op7 = rewriter.create<tosa::ClzOp>(
1499 op->getLoc(), int32_rsum_type, op7_reducesum_op6.getResult());
1500
1501 // rshift amount of reciprocal(sum(exp()))
1502 // 12 from the integer bits of 12.20 accumulator
1503 // 30 from output of multiply 0.15 x 0.15
1504 // -8 to keep additional 8 bits before output rescaling
1505 auto op9_sub_op8 = rewriter.create<tosa::SubOp>(
1506 op->getLoc(), int32_rsum_type,
1507 getTosaConstTensorSingleI32(rewriter, op, 12 + 30 - 8),
1508 op8_clz_op7.getResult());
1509
1510 // Left shift to get 1.31 format
1511 auto op10_lshift_op7_op8 = rewriter.create<tosa::LogicalLeftShiftOp>(
1512 op->getLoc(), int32_rsum_type, op7_reducesum_op6.getResult(),
1513 op8_clz_op7.getResult());
1514
1515 // Subtract (1 << 31) to make 0 <= x <= 1
1516 auto op11_sub_op10 = rewriter.create<tosa::SubOp>(
1517 op->getLoc(), int32_rsum_type, op10_lshift_op7_op8.getResult(),
1518 getTosaConstTensorSingleI32(rewriter, op, (1u << 31)));
1519
1520 // Right shift 16 bits to get 16 bits index
1521 auto op12_rshift_op11 = rewriter.create<tosa::ArithmeticRightShiftOp>(
1522 op->getLoc(), int32_rsum_type, op11_sub_op10.getResult(),
1523 getTosaConstTensorSingleI32(rewriter, op, 16), true);
1524
1525 // cast to 16 bits to index TABLE op
1526 auto op13_cast_op12 = rewriter.create<tosa::CastOp>(
1527 op->getLoc(), int16_rsum_type, op12_rshift_op11.getResult());
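    // Illustrative walk-through (hypothetical value, for clarity only): if the
    // 12.20 accumulator sum(exp()) were exactly 1.0 (0x00100000), then clz = 11,
    // the left shift gives 0x80000000 (1.0 in unsigned 1.31), subtracting
    // (1 << 31) yields 0, and the resulting 16-bit table index is 0, so the
    // 1/(1 + x) lookup below returns 1.0. The matching right-shift amount from
    // op9_sub_op8 would be 12 + 30 - 8 - 11 = 23.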
1528
1529 // Generate table for 1 / (1 + x), for 0 <= x <= 1
1530 const double one_over_one_plus_x_sample_grain = 1.0 / 256.0;
1531 auto one_over_one_plus_x_func =
1532 [one_over_one_plus_x_sample_grain](int32_t x) -> int32_t {
1533 double v = static_cast<double>(x) * one_over_one_plus_x_sample_grain;
1534 v = v < 0 ? 1.0 : 1.0 / (1.0 + v);
1535 return std::lround(32768.0 * v);
1536 };
1537
1538 Value one_over_one_plus_x_table_const =
1539 getTosa1DConstTensorTable(rewriter, op, one_over_one_plus_x_func);
1540
1541 auto op14_table_op13 = rewriter.create<tosa::TableOp>(
1542 op->getLoc(), int32_rsum_type, op13_cast_op12.getResult(),
1543 one_over_one_plus_x_table_const);
1544
1545 // Rescale sum(exp(x)) from 0.23 back to 0.16
1546 Value op15_rescale_op14 = buildRescale(rewriter, op, int32_rsum_type,
1547 op14_table_op13, 1.0 / 128.0, 0, 0);
1548
1549 // Rescale exp(x) from 0.23 back to 0.16
1550 Value op16_rescale_op5 =
1551 buildRescale(rewriter, op, int32_logits_type, op5_table_op4.getResult(),
1552 1.0 / 128.0, 0, 0);
1553
1554     // Step 6. apply the scales we just computed, explicitly in i32 space
1555 // lhs: 0.16, rhs: 0.16, output: 0.32
1556 auto op17_mul_op15_op16 =
1557 rewriter.create<tosa::MulOp>(op->getLoc(), int32_logits_type,
1558 op15_rescale_op14, op16_rescale_op5, 0);
1559
1560 // Apply right shift from clz
1561 auto op18_rshift_op17_op9 = rewriter.create<tosa::ArithmeticRightShiftOp>(
1562 op->getLoc(), int32_logits_type, op17_mul_op15_op16.getResult(),
1563 op9_sub_op8.getResult(), true);
1564
1565 // Step 7. output scaling, extra 1.0 / 256.0 since we keep extra 8 bits
1566 // in op9_sub_op8
1567 return buildRescale(rewriter, op, output_type,
1568 op18_rshift_op17_op9.getResult(),
1569 1.0 / (out_quant_type.getScale() * 256.0), 0,
1570 out_quant_type.getZeroPoint());
1571
1572 } else {
1573 SmallVector<int64_t, 4> rsum_shape_v(input_type.getShape().begin(),
1574 input_type.getShape().end());
1575 rsum_shape_v[input_rank - 1] = 1;
1576 ArrayRef<int64_t> rsum_shape(rsum_shape_v);
1577
1578     // Floating-point lowering is more direct:
1579 //
1580 // op1 = exp(logits)
1581 // op2 = reduce_sum(op1, -1)
1582 // op3 = reciprocal(op2)
1583 // op4 = mul(op1, op3)
1584 auto op1_exp_in =
1585 rewriter.create<tosa::ExpOp>(op->getLoc(), output_type, logits_value);
1586 RankedTensorType rsum_type =
1587 RankedTensorType::get(rsum_shape, output_type.getElementType());
1588
1589 // Keep dims so we don't need to reshape later
1590 auto op2_reducesum_op1 = rewriter.create<tosa::ReduceSumOp>(
1591 op->getLoc(), rsum_type, op1_exp_in.getResult(),
1592 rewriter.getI64IntegerAttr(input_rank - 1));
1593 auto op3_reciprocal_op2 = rewriter.create<tosa::ReciprocalOp>(
1594 op->getLoc(), rsum_type, op2_reducesum_op1.getResult());
1595
1596 return rewriter
1597 .create<tosa::MulOp>(op->getLoc(), output_type, op1_exp_in.getResult(),
1598 op3_reciprocal_op2.getResult(), 0)
1599 .getResult();
1600 }
1601 }
1602
1603 // Lowers LogSoftmax to a sequence of TOSA ops.
1604 llvm::Optional<Value> convertLogSoftmaxOp(PatternRewriter& rewriter,
1605 Operation* op, Value result_value,
1606 Value logits_value) {
1607 // log_softmax = log(exp(logits) / reduce_sum(exp(logits), -1))
1608 // op1 = exp(logits)
1609 // op2 = reduce_sum(op1, -1)
1610 // op3 = reciprocal(op2)
1611 // op4 = mul(op1, op3)
1612 // op5 = log(op4)
1613
1614 RankedTensorType output_type =
1615 result_value.getType().dyn_cast<RankedTensorType>();
1616 // Not a ranked tensor output
1617 if (!output_type) {
1618 op->emitOpError("LogSoftmax: output type not ranked tensor.");
1619 return llvm::None;
1620 }
1621
1622 RankedTensorType input_type =
1623 op->getOperand(0).getType().dyn_cast<RankedTensorType>();
1624 if (!input_type) {
1625 op->emitOpError("LogSoftmax: input type not ranked tensor.");
1626 return llvm::None;
1627 }
1628
1629 mlir::quant::UniformQuantizedType in_quant_type =
1630 input_type.getElementType()
1631 .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
1632 mlir::quant::UniformQuantizedType out_quant_type =
1633 output_type.getElementType()
1634 .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
1635 if (in_quant_type || out_quant_type) {
1636 op->emitOpError("Quantized log_softmax lowering not implemented yet");
1637 return llvm::None;
1638 }
1639
1640 auto op1_exp_in =
1641 rewriter.create<tosa::ExpOp>(op->getLoc(), output_type, logits_value);
1642
1643 // reduce_sum on last dimension
1644 int32_t input_rank = input_type.getShape().size();
1645 SmallVector<int64_t, 4> rsum_shape(output_type.getShape().begin(),
1646 output_type.getShape().end());
1647 rsum_shape[input_rank - 1] = 1;
1648 RankedTensorType rsum_type = RankedTensorType::get(
1649 ArrayRef<int64_t>(rsum_shape), output_type.getElementType());
1650 // Keep dims so we don't need to reshape later
1651 auto op2_reducesum_op1 = rewriter.create<tosa::ReduceSumOp>(
1652 op->getLoc(), rsum_type, op1_exp_in.getResult(),
1653 rewriter.getI64IntegerAttr(input_rank - 1));
1654 auto op3_reciprocal_op2 = rewriter.create<tosa::ReciprocalOp>(
1655 op->getLoc(), rsum_type, op2_reducesum_op1.getResult());
1656
1657 auto op4_mul_op1_op3 = rewriter.create<tosa::MulOp>(
1658 op->getLoc(), output_type, op1_exp_in.getResult(),
1659 op3_reciprocal_op2.getResult(), 0);
1660
1661 return rewriter
1662 .create<tosa::LogOp>(op->getLoc(), output_type,
1663 op4_mul_op1_op3.getResult())
1664 .getResult();
1665 }
1666
1667 // Lowers SpaceToDepth to a sequence of TOSA ops. Supports NHWC.
1668 llvm::Optional<Value> convertSpaceToDepthOp(PatternRewriter& rewriter,
1669 Operation* op, Value result_value,
1670 Value input_value,
1671 IntegerAttr block_size_attr,
1672 StringAttr data_format) {
1673 // NHWC lowering version:
1674 // a2 = tf.reshape(a, [orig_shape[0], orig_shape[1]//b, b, orig_shape[2]//b,
1675 // b, orig_shape[3]])
1676 // a3 = tf.transpose(a2, [0, 1, 3, 2, 4, 5])
1677 // a4 = tf.reshape(a3, [orig_shape[0], orig_shape[1]//b, orig_shape[2]//b,
1678 // orig_shape[3]*b*b])
1679 // return a4
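  // Illustrative shape example (hypothetical sizes): for a = [1, 4, 4, 3] and
  // block size b = 2, a2 has shape [1, 2, 2, 2, 2, 3] (N, H/b, b, W/b, b, C),
  // the transpose reorders it to (N, H/b, W/b, b, b, C), and a4 collapses the
  // two block dims into channels, giving [1, 2, 2, 12].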
1680 RankedTensorType output_type =
1681 result_value.getType().dyn_cast<RankedTensorType>();
1682
1683 // Not a ranked tensor output.
1684 if (!output_type) {
1685 op->emitOpError("SpaceToDepth: output type not ranked tensor.");
1686 return llvm::None;
1687 }
1688
1689 RankedTensorType input_type =
1690 input_value.getType().dyn_cast<RankedTensorType>();
1691 if (!input_type) {
1692 op->emitOpError("SpaceToDepth: input type not ranked tensor.");
1693 return llvm::None;
1694 }
1695
1696 if (input_type.getRank() != 4) {
1697 op->emitOpError("SpaceToDepth: input rank not 4.");
1698 return llvm::None;
1699 }
1700
1701 auto input_shape = input_type.getShape();
1702
1703 if (!block_size_attr) { // This is a required parameter
1704 op->emitOpError("SpaceToDepth: block size attribute not set.");
1705 return llvm::None;
1706 }
1707
1708 SmallVector<int64_t, 2> block_size;
1709 block_size.assign(2, block_size_attr.getInt());
1710
1711 if (!data_format) data_format = rewriter.getStringAttr("NHWC");
1712
1713 if (data_format.getValue().str() != "NHWC") {
1714 op->emitOpError("SpaceToDepth: data format not NHWC.");
1715 return llvm::None;
1716 }
1717
1718 assert(block_size[0] * block_size[1] != 0);
1719
1720 SmallVector<int64_t, 4> a_reshape_dims;
1721 a_reshape_dims.push_back(input_shape[0]);
1722 a_reshape_dims.push_back(input_shape[1] / block_size[0]);
1723 a_reshape_dims.push_back(block_size[0]);
1724 a_reshape_dims.push_back(input_shape[2] / block_size[1]);
1725 a_reshape_dims.push_back(block_size[1]);
1726 a_reshape_dims.push_back(input_shape[3]);
1727
1728 RankedTensorType a_reshape_output_type = RankedTensorType::get(
1729 ArrayRef<int64_t>(a_reshape_dims), output_type.getElementType());
1730 auto a2_reshape_a_op = rewriter.create<tosa::ReshapeOp>(
1731 op->getLoc(), a_reshape_output_type, input_value,
1732 rewriter.getI64ArrayAttr(a_reshape_dims));
1733
1734 Value a3_transpose_perm = get1DConstTensor<tosa::ConstOp, int32_t>(
1735 rewriter, op, {0, 1, 3, 2, 4, 5});
1736
1737 auto a3_transpose_a2_op = rewriter.create<tosa::TransposeOp>(
1738 op->getLoc(), a_reshape_output_type, a2_reshape_a_op.getResult(),
1739 a3_transpose_perm);
1740
1741 SmallVector<int64_t, 4> a3_reshape_dims;
1742 a3_reshape_dims.push_back(input_shape[0]);
1743 a3_reshape_dims.push_back(input_shape[1] / block_size[0]);
1744 a3_reshape_dims.push_back(input_shape[2] / block_size[1]);
1745 a3_reshape_dims.push_back(input_shape[3] * block_size[0] * block_size[1]);
1746
1747 RankedTensorType a3_reshape_output_type = RankedTensorType::get(
1748 ArrayRef<int64_t>(a3_reshape_dims), output_type.getElementType());
1749 return rewriter
1750 .create<tosa::ReshapeOp>(op->getLoc(), a3_reshape_output_type,
1751 a3_transpose_a2_op.getResult(),
1752 rewriter.getI64ArrayAttr(a3_reshape_dims))
1753 .getResult();
1754 }
1755
1756 // Lowers DepthToSpace to a sequence of TOSA ops. Supports NHWC.
1757 llvm::Optional<Value> convertDepthToSpaceOp(PatternRewriter& rewriter,
1758 Operation* op, Value result_value,
1759 Value input_value,
1760 IntegerAttr block_size_attr,
1761 StringAttr data_format) {
1762 // NHWC version
1763 // a2 = tf.reshape(a, [orig_shape[0], orig_shape[1], orig_shape[2], b, b,
1764 // orig_shape[3] // (b*b)])
1765 // a3 = tf.transpose(a2, [0, 1, 3, 2, 4, 5])
1766 // a4 = tf.reshape(a3, [orig_shape[0], orig_shape[1] * b, orig_shape[2] * b,
1767 // orig_shape[3] // (b*b)])
1768 // return a4
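  // Illustrative shape example (hypothetical sizes): for a = [1, 2, 2, 12] and
  // block size b = 2, a2 has shape [1, 2, 2, 2, 2, 3] (N, H, W, b, b, C/(b*b)),
  // the transpose interleaves the block dims with H and W, and a4 yields
  // [1, 4, 4, 3].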
1769
1770 RankedTensorType output_type =
1771 result_value.getType().dyn_cast<RankedTensorType>();
1772
1773 // Not a ranked tensor output
1774 if (!output_type) {
1775 op->emitOpError("DepthToSpace: output type not ranked tensor.");
1776 return llvm::None;
1777 }
1778
1779 RankedTensorType input_type =
1780 input_value.getType().dyn_cast<RankedTensorType>();
1781 if (!input_type) {
1782 op->emitOpError("DepthToSpace: input type not ranked tensor.");
1783 return llvm::None;
1784 }
1785
1786 if (input_type.getRank() != 4) return llvm::None;
1787 auto input_shape = input_type.getShape();
1788
1789 if (!block_size_attr) { // This is a required parameter
1790 op->emitOpError("DepthToSpace: block size attribute not set.");
1791 return llvm::None;
1792 }
1793
1794 SmallVector<int64_t, 2> block_size;
1795 block_size.assign(2, block_size_attr.getInt());
1796
1797 if (!data_format) data_format = rewriter.getStringAttr("NHWC");
1798 if (data_format.getValue().str() != "NHWC") {
1799 op->emitOpError("DepthToSpace: data format not NHWC.");
1800 return llvm::None;
1801 }
1802
1803 assert(block_size[0] * block_size[1] != 0);
1804
1805 SmallVector<int64_t, 4> a_reshape_dims;
1806 a_reshape_dims.push_back(input_shape[0]);
1807 a_reshape_dims.push_back(input_shape[1]);
1808 a_reshape_dims.push_back(input_shape[2]);
1809 a_reshape_dims.push_back(block_size[0]);
1810 a_reshape_dims.push_back(block_size[1]);
1811 a_reshape_dims.push_back(input_shape[3] / (block_size[0] * block_size[1]));
1812
1813 RankedTensorType a_reshape_output_type = RankedTensorType::get(
1814 ArrayRef<int64_t>(a_reshape_dims), output_type.getElementType());
1815 auto a2_reshape_a_op = rewriter.create<tosa::ReshapeOp>(
1816 op->getLoc(), a_reshape_output_type, input_value,
1817 rewriter.getI64ArrayAttr(a_reshape_dims));
1818
1819 Value a3_transpose_perm = get1DConstTensor<tosa::ConstOp, int32_t>(
1820 rewriter, op, {0, 1, 3, 2, 4, 5});
1821
1822 auto a3_transpose_a2_op = rewriter.create<tosa::TransposeOp>(
1823 op->getLoc(), a_reshape_output_type, a2_reshape_a_op.getResult(),
1824 a3_transpose_perm);
1825
1826 SmallVector<int64_t, 4> a3_reshape_dims;
1827 a3_reshape_dims.push_back(input_shape[0]);
1828 a3_reshape_dims.push_back(input_shape[1] * block_size[0]);
1829 a3_reshape_dims.push_back(input_shape[2] * block_size[1]);
1830 a3_reshape_dims.push_back(input_shape[3] / (block_size[0] * block_size[1]));
1831
1832 RankedTensorType a3_reshape_output_type = RankedTensorType::get(
1833 ArrayRef<int64_t>(a3_reshape_dims), output_type.getElementType());
1834 return rewriter
1835 .create<tosa::ReshapeOp>(op->getLoc(), a3_reshape_output_type,
1836 a3_transpose_a2_op.getResult(),
1837 rewriter.getI64ArrayAttr(a3_reshape_dims))
1838 .getResult();
1839 }
1840
1841 // Lowers Split to a sequence of TOSA ops.
1842 llvm::Optional<ValueRange> convertSplitOp(PatternRewriter& rewriter,
1843 Operation* op, Value result_value,
1844 Value input_value, int32_t num_split,
1845 int32_t axis) {
1846 // This lowering creates num_split slice ops and ties them together
1847 // with IdentityN to get from an array of Operations to a single Operation
1848 // with a list of result tensors.
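  // Illustrative example (hypothetical shapes): splitting a [4, 6] tensor with
  // num_split = 3 along axis 1 creates three tosa.slice ops with begin
  // {0, 0}, {0, 2}, {0, 4} and size {4, 2} each, whose results are then
  // gathered into a single IdentityN.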
1849 RankedTensorType result_type =
1850 result_value.getType().dyn_cast<RankedTensorType>();
1851 // Not a ranked tensor output
1852 if (!result_type) {
1853 op->emitOpError("Split: output type not ranked tensor.");
1854 return llvm::None;
1855 }
1856
1857 RankedTensorType input_type =
1858 input_value.getType().dyn_cast<RankedTensorType>();
1859 if (!input_type) {
1860 op->emitOpError("Split: input type not ranked tensor.");
1861 return llvm::None;
1862 }
1863
1864 auto input_shape = input_type.getShape();
1865
1866 SmallVector<Value, 4> results_vec;
1867
1868 assert(axis > 0 && axis < input_shape.size());
1869 assert((input_shape[axis] % num_split) == 0);
1870 assert(num_split > 0);
1871
1872 int64_t slice_size = input_shape[axis] / num_split;
1873
1874 SmallVector<Type, 4>
1875 outs_type_vec; // A list of the output types for each slice op
1876
1877 for (int i = 0; i < num_split; i++) {
1878     // Each slice has a different beginning point.
1879     // The slice size is the same for each op.
1880 SmallVector<int64_t, 4> begin_vals, size_vals;
1881
1882 for (int j = 0; j < input_shape.size(); j++) {
1883 if (j == axis) {
1884 begin_vals.push_back(slice_size * i);
1885 size_vals.push_back(slice_size);
1886 } else {
1887 begin_vals.push_back(0);
1888 size_vals.push_back(input_shape[j]);
1889 }
1890 }
1891
1892 ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
1893 ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);
1894
1895 outs_type_vec.push_back(RankedTensorType::get(
1896 ArrayRef<int64_t>(size_vals), result_type.getElementType()));
1897
1898 auto slice_op = rewriter.create<tosa::SliceOp>(
1899 op->getLoc(),
1900 RankedTensorType::get(ArrayRef<int64_t>(size_vals),
1901 result_type.getElementType()),
1902 input_value, begin, size);
1903
1904 results_vec.push_back(slice_op.getResult());
1905 }
1906
1907 // Combine the sequence of tosa.slice() ops into a list
1908 // using the IdentityN operator
1909 return rewriter
1910 .create<tosa::IdentityNOp>(op->getLoc(), ArrayRef<Type>(outs_type_vec),
1911 results_vec)
1912 .getResults();
1913 }
1914
1915 // Lowers SplitV to a sequence of TOSA ops.
1916 llvm::Optional<ValueRange> convertSplitVOp(PatternRewriter& rewriter,
1917 Operation* op, Value result_value,
1918 Value input_value,
1919 SmallVector<int32_t, 4>& size_split,
1920 int32_t axis) {
1921 // This lowering creates num_split slice ops and ties them together
1922 // with IdentityN to get from an array of Operations to a single Operation
1923 // with a list of result tensors.
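  // Illustrative example (hypothetical shapes): for a [4, 6] input split along
  // axis 1 with size_split = {1, 2, 3}, the slices begin at columns 0, 1 and 3
  // with sizes {4, 1}, {4, 2} and {4, 3} respectively.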
1924 RankedTensorType result_type =
1925 result_value.getType().dyn_cast<RankedTensorType>();
1926 // Not a ranked tensor output
1927 if (!result_type) {
1928 op->emitOpError("SplitV: output type not ranked tensor.");
1929 return llvm::None;
1930 }
1931
1932 RankedTensorType input_type =
1933 input_value.getType().dyn_cast<RankedTensorType>();
1934 if (!input_type) {
1935 op->emitOpError("SplitV: input type not ranked tensor.");
1936 return llvm::None;
1937 }
1938
1939 auto input_shape = input_type.getShape();
1940
1941 SmallVector<Value, 4> results_vec;
1942
1943 assert(axis > 0 && axis < input_shape.size());
1944 int32_t size_split_sum = 0;
1945 for (int i = 0; i < size_split.size(); i++) {
1946 size_split_sum += size_split[i];
1947 }
1948
1949 // The split sizes must sum up to the size of the axis being split
1950 assert(size_split_sum == input_shape[axis]);
1951
1952 // Create num_split slice ops:
1953 SmallVector<Type, 4>
1954 outs_type_vec; // A list of the output types for each slice op
1955
1956 int32_t curr_split_start = 0;
1957 for (int i = 0; i < size_split.size(); i++) {
1958     // Each slice has a different beginning point.
1959 // The slice size is different for each op.
1960 SmallVector<int64_t, 4> begin_vals, size_vals;
1961
1962 for (int j = 0; j < input_shape.size(); j++) {
1963 if (j == axis) {
1964 begin_vals.push_back(curr_split_start);
1965 size_vals.push_back(size_split[i]);
1966 } else {
1967 begin_vals.push_back(0);
1968 size_vals.push_back(input_shape[j]);
1969 }
1970 }
1971
1972 ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
1973 ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);
1974
1975 outs_type_vec.push_back(RankedTensorType::get(
1976 ArrayRef<int64_t>(size_vals), result_type.getElementType()));
1977
1978 auto slice_op = rewriter.create<tosa::SliceOp>(
1979 op->getLoc(),
1980 RankedTensorType::get(ArrayRef<int64_t>(size_vals),
1981 result_type.getElementType()),
1982 input_value, begin, size);
1983
1984 results_vec.push_back(slice_op.getResult());
1985
1986 // Next start position
1987 curr_split_start += size_split[i];
1988 }
1989
1990 // Combine the sequence of tosa.slice() ops into a list
1991 // using the IdentityN operator
1992 return rewriter
1993 .create<tosa::IdentityNOp>(op->getLoc(), ArrayRef<Type>(outs_type_vec),
1994 results_vec)
1995 .getResults();
1996 }
1997
1998 // Lowers StridedSlice to a sequence of TOSA ops.
1999 llvm::Optional<Value> convertStridedSliceOp(
2000 PatternRewriter& rewriter, Operation* op, Value result_value,
2001 Value input_value, Value begin_value, Value end_value, Value strides_value,
2002 int32_t begin_mask, int32_t end_mask, int32_t ellipsis_mask,
2003 int32_t new_axis_mask, int32_t shrink_axis_mask) {
2004 // The mask arguments are bitmasks where bit [i] applies to
2005 // dimension [i] of the input tensor.
2006 //
2007 // The rough algorithm for lowering strided slice is as follows:
2008 //
2009 // 0. Process begin/end masks, since they are basically syntactic sugar
2010 // on top of the begin_value/end_value arrays
2011 //
2012 // 1. Slice1: Ignoring stride, slice the interesting range from the input
2013 // tensor
2014 //
2015 // 2. Reshape2: Reshape the tensor from (1) such that each dimension with
2016 // stride is split into two dimensions of size_i/stride_i, stride_i. A naive
2017 // implementation doubles the input tensor rank, but only dimensions being
2018 // strided actually need to be doubled.
2019 //
2020 // 3. Slice3: Slice the tensor from (2) such that we select index [0] from
2021 // each of the stride_i dimensions in (2)
2022 //
2023 // 4. Reshape4: Reshape the tensor to eliminate the stride_i dimensions, add
2024 // any dimensions in new_axis_mask and remove any dimensions in the
2025 // shrink_axis_mask
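  // Illustrative example (hypothetical values): for a rank-1 input of shape
  // [10] with begin = [1], end = [7], strides = [2] and all masks zero, Slice1
  // takes 6 elements starting at index 1, Reshape2 produces shape [3, 2],
  // Slice3 keeps [3, 1] (element 0 of each stride group), and Reshape4 yields
  // a [3] result holding input elements 1, 3 and 5.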
2026
2027 // Limitations:
2028 // This implementation only supports ellipsis_mask=0 for now
2029 // This implementation does not support reverse stride yet. Will need
2030 // to insert tosa.Reverse operators for this.
2031 assert(ellipsis_mask == 0);
2032
2033 RankedTensorType input_type =
2034 input_value.getType().dyn_cast<RankedTensorType>();
2035 RankedTensorType result_type =
2036 result_value.getType().dyn_cast<RankedTensorType>();
2037
2038 if (!result_type) {
2039 op->emitOpError("StridedSlice: output type not ranked tensor.");
2040 return llvm::None;
2041 }
2042
2043 if (!input_type) {
2044 op->emitOpError("StridedSlice: input type not ranked tensor.");
2045 return llvm::None;
2046 }
2047
2048 int32_t input_rank = input_type.getRank();
2049 auto input_shape = input_type.getShape();
2050
2051 // Extract the begin/end/stride tensors
2052 SmallVector<int32_t, 4> begin, end, strides;
2053
2054 if (getVectorFromValue32(begin_value, begin) != input_rank) {
2055 op->emitOpError("StridedSlice: begin doesn't match input_rank.");
2056 return llvm::None;
2057 }
2058 if (getVectorFromValue32(end_value, end) != input_rank) {
2059 op->emitOpError("StridedSlice: end doesn't match input_rank.");
2060 return llvm::None;
2061 }
2062 if (getVectorFromValue32(strides_value, strides) != input_rank) {
2063 op->emitOpError("StridedSlice: strides doesn't match input_rank.");
2064 return llvm::None;
2065 }
2066
2067 SmallVector<int64_t, 2> a1_begin(input_rank), a1_size(input_rank);
2068 SmallVector<int64_t, 2> a2_shape(input_rank * 2);
2069 SmallVector<int64_t, 2> a3_begin(input_rank * 2), a3_size(input_rank * 2);
2070 SmallVector<int64_t, 2> a4_shape;
2071
2072 // Step 0: Process the begin/end masks and build the begin/sizes for the
2073 // first slice
2074 int residual = 1;
2075 (void)residual;
2076 for (int i = 0; i < input_rank; i++) {
2077 if (begin_mask & (1 << i)) begin[i] = 0;
2078
2079 if (end_mask & (1 << i)) end[i] = input_shape[i];
2080
2081     // Wrap around the index if begin or end is negative
2082 if (begin[i] < 0) begin[i] += input_shape[i];
2083
2084 if (end[i] < 0) end[i] += input_shape[i];
2085
2086 // TODO: support reverse stride
2087 a1_begin[i] = begin[i];
2088 a1_size[i] = end[i] - begin[i];
2089
2090 a2_shape[i * 2 + 0] = a1_size[i] / strides[i];
2091 a2_shape[i * 2 + 1] = strides[i];
2092
2093 a3_begin[i * 2 + 0] = 0;
2094 a3_begin[i * 2 + 1] = 0;
2095
2096 if (shrink_axis_mask & (1 << i)) {
2097 a3_size[i * 2 + 0] = 1;
2098 } else {
2099 a3_size[i * 2 + 0] = a1_size[i] / strides[i];
2100 }
2101 a3_size[i * 2 + 1] = 1;
2102
2103 if (!(shrink_axis_mask & (1 << i))) {
2104 if (new_axis_mask & (1 << i)) a4_shape.push_back(1);
2105 a4_shape.push_back((a1_size[i] / strides[i]));
2106 }
2107 }
2108
2109 // Make sure we didn't lose any dimensions from the shrink_axis_mask
2110 assert(residual == 1);
2111
2112 // Step 1: Slice the input array
2113 auto a1_slice_op = rewriter.create<tosa::SliceOp>(
2114 op->getLoc(),
2115 RankedTensorType::get(ArrayRef<int64_t>(a1_size),
2116 input_type.getElementType()),
2117 input_value, rewriter.getI64ArrayAttr(a1_begin),
2118 rewriter.getI64ArrayAttr(a1_size));
2119
2120 // Step 2: reshape the sliced array
2121 auto a2_reshape_op = rewriter.create<tosa::ReshapeOp>(
2122 op->getLoc(),
2123 RankedTensorType::get(ArrayRef<int64_t>(a2_shape),
2124 input_type.getElementType()),
2125 a1_slice_op.getResult(), rewriter.getI64ArrayAttr(a2_shape));
2126
2127 // Step 3: take a slice along the strides
2128 auto a3_slice_op = rewriter.create<tosa::SliceOp>(
2129 op->getLoc(),
2130 RankedTensorType::get(ArrayRef<int64_t>(a3_size),
2131 input_type.getElementType()),
2132 a2_reshape_op.getResult(), rewriter.getI64ArrayAttr(a3_begin),
2133 rewriter.getI64ArrayAttr(a3_size));
2134
2135 // Step 4: reshape the now-strided tensor
2136 return rewriter
2137 .create<tosa::ReshapeOp>(
2138 op->getLoc(),
2139 RankedTensorType::get(ArrayRef<int64_t>(a4_shape),
2140 input_type.getElementType()),
2141 a3_slice_op.getResult(), rewriter.getI64ArrayAttr(a4_shape))
2142 .getResult();
2143 }
2144
2145 // Lowers FloorDiv to a sequence of TOSA operators.
2146 llvm::Optional<Value> convertFloorDivOp(PatternRewriter& rewriter,
2147 Operation* op, Value result_value,
2148 Value lhs_value, Value rhs_value) {
2149 // FloorDiv lowering:
2150 // floor(1/rhs * lhs)
2151 //
2152 // a1 = reciprocal(rhs);
2153 // a2 = mul(lhs, a1);
2154 // a3 = floor(a2);
2155 // return a3;
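  // Illustrative example (hypothetical values): for lhs = 7.0, rhs = 2.0 this
  // computes 7.0 * 0.5 = 3.5 and floor(3.5) = 3.0; for lhs = -7.0 it computes
  // -3.5 and floor(-3.5) = -4.0, matching floor-division semantics.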
2156 RankedTensorType output_type =
2157 result_value.getType().dyn_cast<RankedTensorType>();
2158 // Not a ranked tensor output
2159 if (!output_type) return llvm::None;
2160
2161 auto a1_reciprocal_rhs_op =
2162 rewriter.create<tosa::ReciprocalOp>(op->getLoc(), output_type, rhs_value);
2163 auto a2_mul_lhs_a1_op =
2164 rewriter.create<tosa::MulOp>(op->getLoc(), output_type, lhs_value,
2165 a1_reciprocal_rhs_op.getResult(), 0);
2166 return rewriter
2167 .create<tosa::FloorOp>(op->getLoc(), output_type,
2168 a2_mul_lhs_a1_op.getResult())
2169 .getResult();
2170 }
2171
2172 // Lowers FloorMod to a sequence of TOSA operators.
2173 llvm::Optional<Value> convertFloorModOp(PatternRewriter& rewriter,
2174 Operation* op, Value result_value,
2175 Value lhs_value, Value rhs_value) {
2176 // FloorMod lowering:
2177 // (1/rhs * lhs) - floor(1/rhs * lhs)
2178 // a1 = reciprocal(rhs);
2179 // a2 = mul(lhs, a1);
2180 // a3 = floor(a2);
2181 // a4 = sub(a2, a3);
2182 // return a4;
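  // Illustrative example of the formula as written above (hypothetical
  // values): for lhs = 7.0, rhs = 2.0 this computes a2 = 3.5, a3 = 3.0 and
  // returns a4 = 0.5, the fractional part of lhs / rhs.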
2183
2184 RankedTensorType output_type =
2185 result_value.getType().dyn_cast<RankedTensorType>();
2186 // Not a ranked tensor output
2187 if (!output_type) return llvm::None;
2188
2189 auto a1_reciprocal_rhs_op =
2190 rewriter.create<tosa::ReciprocalOp>(op->getLoc(), output_type, rhs_value);
2191 auto a2_mul_lhs_a1_op =
2192 rewriter.create<tosa::MulOp>(op->getLoc(), output_type, lhs_value,
2193 a1_reciprocal_rhs_op.getResult(), 0);
2194 auto a3_floor_a2_op = rewriter.create<tosa::FloorOp>(
2195 op->getLoc(), output_type, a2_mul_lhs_a1_op.getResult());
2196 return rewriter
2197 .create<tosa::SubOp>(op->getLoc(), output_type,
2198 a2_mul_lhs_a1_op.getResult(),
2199 a3_floor_a2_op.getResult())
2200 .getResult();
2201 }
2202
2203 // Lowers FusedActivation to a sequence of TOSA ops.
2204 llvm::Optional<Value> convertFusedActivation(PatternRewriter& rewriter,
2205 Operation* op, Value input_value,
2206 StringAttr fused_activation_fn) {
2207 RankedTensorType input_type =
2208 input_value.getType().dyn_cast<RankedTensorType>();
2209 if (!input_type) return llvm::None;
2210
2211 bool input_is_qtype =
2212 input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2213
2214 if (input_is_qtype) {
2215 auto input_qtype =
2216 input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2217
2218 if (fused_activation_fn.getValue() == "TANH") {
2219 // TODO: implement with TABLE
2220 op->emitWarning("Quantized TANH lowering TBD!");
2221 return llvm::None;
2222 } else {
2223 RankedTensorType rescale_type = RankedTensorType::get(
2224 input_type.getShape(), rewriter.getIntegerType(32));
2225
2226 Value op1_rescale_in = buildRescaleToInt32(
2227 rewriter, op, input_value, 1.0f, input_qtype.getZeroPoint());
2228
2229 Value op2_relu_op1;
2230 if (fused_activation_fn.getValue() == "NONE") {
2231 return input_value;
2232 } else if (fused_activation_fn.getValue() == "RELU") {
2233 auto relu_op = rewriter.create<tosa::ReluNOp>(
2234 op->getLoc(), rescale_type, op1_rescale_in,
2235 rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
2236 rewriter.getF32FloatAttr(0));
2237
2238 op2_relu_op1 = relu_op.getResult();
2239
2240 } else if (fused_activation_fn.getValue() == "RELU6") {
2241 int64_t rescaled_6 = std::llround(6.0f / input_qtype.getScale()) +
2242 input_qtype.getZeroPoint();
2243
2244 auto relu_op = rewriter.create<tosa::ReluNOp>(
2245 op->getLoc(), rescale_type, op1_rescale_in,
2246 rewriter.getI64IntegerAttr(rescaled_6),
2247 rewriter.getF32FloatAttr(0.0f));
2248
2249 op2_relu_op1 = relu_op.getResult();
2250
2251 } else if (fused_activation_fn.getValue() == "RELU_N1_TO_1") {
2252 int64_t rescaled_n1 = std::llround(-1.0f / input_qtype.getScale()) +
2253 input_qtype.getZeroPoint();
2254 int64_t rescaled_1 = std::llround(1.0f / input_qtype.getScale()) +
2255 input_qtype.getZeroPoint();
2256
2257 auto relu_op = rewriter.create<tosa::ClampOp>(
2258 op->getLoc(), rescale_type, op1_rescale_in,
2259 rewriter.getI64IntegerAttr(rescaled_n1),
2260 rewriter.getI64IntegerAttr(rescaled_1),
2261 rewriter.getF32FloatAttr(0.0f), rewriter.getF32FloatAttr(0.0f));
2262
2263 op2_relu_op1 = relu_op.getResult();
2264 } else {
2265 return llvm::None;
2266 }
2267
2268 return buildRescaleFromInt32(rewriter, op, input_type, op2_relu_op1, 1.0f,
2269 input_qtype.getZeroPoint());
2270 }
2271 } else {
2272 if (fused_activation_fn.getValue() == "NONE") {
2273 return input_value;
2274 } else if (fused_activation_fn.getValue() == "RELU") {
2275 return rewriter
2276 .create<tosa::ReluNOp>(
2277 op->getLoc(), input_type, input_value,
2278 rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
2279 rewriter.getF32FloatAttr(std::numeric_limits<float>::max()))
2280 .getResult();
2281 } else if (fused_activation_fn.getValue() == "RELU6") {
2282 return rewriter
2283 .create<tosa::ReluNOp>(op->getLoc(), input_type, input_value,
2284 rewriter.getI64IntegerAttr(6),
2285 rewriter.getF32FloatAttr(6.0))
2286 .getResult();
2287 } else if (fused_activation_fn.getValue() == "RELU_N1_TO_1") {
2288 return rewriter
2289 .create<tosa::ClampOp>(
2290 op->getLoc(), input_type, input_value,
2291 rewriter.getI64IntegerAttr(-1), rewriter.getI64IntegerAttr(1),
2292 rewriter.getF32FloatAttr(-1.0), rewriter.getF32FloatAttr(1.0))
2293 .getResult();
2294 } else if (fused_activation_fn.getValue() == "TANH") {
2295 return rewriter
2296 .create<tosa::TanhOp>(op->getLoc(), input_type, input_value)
2297 .getResult();
2298 } else {
2299 // Unsupported activation type. Bail out.
2300 return llvm::None;
2301 }
2302 }
2303
2304 return llvm::None;
2305 }
2306
2307 // Common function for lowering reduce operations to TOSA ops.
2308 template <typename T>
2309 llvm::Optional<Value> convertReduceOpCommon(
2310 PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2311 Value input_value, ElementsAttr axes_elems, bool keep_dims,
2312 Type reduce_element_type, bool is_quantized, double input_scale,
2313 int64_t input_zp, double output_scale, int64_t output_zp) {
2314 RankedTensorType input_type =
2315 input_value.getType().dyn_cast<RankedTensorType>();
2316 if (!input_type) return llvm::None;
2317
2318 ArrayRef<int64_t> input_shape = input_type.getShape();
2319 ArrayRef<int64_t> output_shape = output_type.getShape();
2320 auto input_rank = input_shape.size();
2321 Value val = input_value;
2322
2323 if (axes_elems.getNumElements() == 0) {
2324 // No axes means return the original tensor.
2325 auto identity_op =
2326 rewriter.create<tosa::IdentityOp>(op->getLoc(), output_type, val);
2327 val = identity_op.getResult();
2328 } else {
2329 // Reduce along each axis
2330 SmallVector<int64_t, 4> shape_vec(input_shape.begin(), input_shape.end());
2331
2332 if (is_quantized) {
2333 val = buildRescaleToInt32(rewriter, op, val, input_scale, input_zp);
2334 }
2335
2336 for (int i = 0; i < axes_elems.getNumElements(); i++) {
2337 int64_t axis_val = axes_elems.getValue<IntegerAttr>(i).getInt();
2338 if (axis_val < 0) axis_val += input_rank;
2339 auto axis_attr = rewriter.getI64IntegerAttr(axis_val);
2340
2341 shape_vec[axis_val] = 1;
2342 RankedTensorType reduce_type = RankedTensorType::get(
2343 llvm::makeArrayRef<int64_t>(shape_vec), reduce_element_type);
2344
2345 auto reduce_op =
2346 rewriter.create<T>(op->getLoc(), reduce_type, val, axis_attr);
2347
2348 val = reduce_op.getResult();
2349 }
2350
2351 if (is_quantized) {
2352 RankedTensorType output_rescale_type = RankedTensorType::get(
2353 llvm::makeArrayRef<int64_t>(shape_vec), output_type.getElementType());
2354 val = buildRescaleFromInt32(rewriter, op, output_rescale_type, val,
2355 output_scale, output_zp);
2356 }
2357
2358 // Optionally squeeze out the reduced axes.
2359 if (!keep_dims) {
2360 auto reshape_op = rewriter.create<tosa::ReshapeOp>(
2361 op->getLoc(), output_type, val,
2362 rewriter.getI64ArrayAttr(output_shape));
2363 val = reshape_op.getResult();
2364 }
2365 }
2366
2367 return val;
2368 }
2369
2370 // Lowers ReduceAll to a sequence of TOSA ops.
2371 llvm::Optional<Value> convertReduceAllOp(
2372 PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2373 Value input_value, ElementsAttr axes_elems, bool keep_dims) {
2374 RankedTensorType input_type =
2375 input_value.getType().dyn_cast<RankedTensorType>();
2376 if (!input_type) return llvm::None;
2377
2378 return convertReduceOpCommon<tosa::ReduceAllOp>(
2379 rewriter, op, output_type, input_value, axes_elems, keep_dims,
2380 output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
2381 }
2382
2383 // Lowers ReduceAny to a sequence of TOSA ops.
2384 llvm::Optional<Value> convertReduceAnyOp(
2385 PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2386 Value input_value, ElementsAttr axes_elems, bool keep_dims) {
2387 RankedTensorType input_type =
2388 input_value.getType().dyn_cast<RankedTensorType>();
2389 if (!input_type) return llvm::None;
2390
2391 return convertReduceOpCommon<tosa::ReduceAnyOp>(
2392 rewriter, op, output_type, input_value, axes_elems, keep_dims,
2393 output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
2394 }
2395
2396 // Lowers ReduceMin to a sequence of TOSA ops.
2397 llvm::Optional<Value> convertReduceMinOp(
2398 PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2399 Value input_value, ElementsAttr axes_elems, bool keep_dims) {
2400 RankedTensorType input_type =
2401 input_value.getType().dyn_cast<RankedTensorType>();
2402 if (!input_type) return llvm::None;
2403
2404 return convertReduceOpCommon<tosa::ReduceMinOp>(
2405 rewriter, op, output_type, input_value, axes_elems, keep_dims,
2406 output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
2407 }
2408
2409 // Lowers ReduceMax to a sequence of TOSA ops.
2410 llvm::Optional<Value> convertReduceMaxOp(
2411 PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2412 Value input_value, ElementsAttr axes_elems, bool keep_dims) {
2413 RankedTensorType input_type =
2414 input_value.getType().dyn_cast<RankedTensorType>();
2415 if (!input_type) return llvm::None;
2416
2417 return convertReduceOpCommon<tosa::ReduceMaxOp>(
2418 rewriter, op, output_type, input_value, axes_elems, keep_dims,
2419 output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
2420 }
2421
2422 // Lowers ReduceProd to a sequence of TOSA ops.
2423 llvm::Optional<Value> convertReduceProdOp(
2424 PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2425 Value input_value, ElementsAttr axes_elems, bool keep_dims) {
2426 RankedTensorType input_type =
2427 input_value.getType().dyn_cast<RankedTensorType>();
2428 if (!input_type) return llvm::None;
2429
2430 bool input_is_qtype =
2431 input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2432 bool output_is_qtype =
2433 output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2434
2435 if (input_is_qtype || output_is_qtype) {
2436 op->emitOpError(
2437 "ConvertReduceProdOp: input/output tensor should "
2438 "be all floating-point.");
2439 return llvm::None;
2440 }
2441
2442 return convertReduceOpCommon<tosa::ReduceProdOp>(
2443 rewriter, op, output_type, input_value, axes_elems, keep_dims,
2444 output_type.getElementType(), false, 1.0f, 0, 1.0f, 0);
2445 }
2446
2447 // Lowers ReduceSum to a sequence of TOSA ops.
2448 llvm::Optional<Value> convertReduceSumOp(
2449 PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2450 Value input_value, ElementsAttr axes_elems, bool keep_dims) {
2451 RankedTensorType input_type =
2452 input_value.getType().dyn_cast<RankedTensorType>();
2453 if (!input_type) return llvm::None;
2454
2455 bool input_is_qtype =
2456 input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2457 bool output_is_qtype =
2458 output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2459
2460 if (input_is_qtype != output_is_qtype) {
2461 op->emitOpError(
2462 "ConvertReduceSumOp: input/output tensor should "
2463 "be all quantized or all floating-point.");
2464 return llvm::None;
2465 }
2466
2467 double input_scale = 1.0f;
2468 double output_scale = 1.0f;
2469 int64_t input_zp = 0;
2470 int64_t output_zp = 0;
2471 Type reduce_element_type = input_type.getElementType();
2472
2473 if (input_is_qtype) {
2474 auto input_qtype =
2475 input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2476 auto output_qtype =
2477 output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2478
2479 int32_t input_shift = 20;
2480
2481 input_scale =
2482 static_cast<double>(1 << input_shift) * input_qtype.getScale();
2483 output_scale =
2484 1.0 / (output_qtype.getScale() * static_cast<double>(1 << input_shift));
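    // Illustrative effect of input_shift (hypothetical scales): with an input
    // scale of 0.5 and an output scale of 0.25, the input is rescaled by
    // 0.5 * 2^20 into the i32 accumulator and the result by 1 / (0.25 * 2^20),
    // so the net scaling is 0.5 / 0.25 = 2.0 while the sum keeps 20 extra
    // fractional bits of precision.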
2485
2486 input_zp = input_qtype.getZeroPoint();
2487 output_zp = output_qtype.getZeroPoint();
2488 reduce_element_type = rewriter.getI32Type();
2489 }
2490
2491 return convertReduceOpCommon<tosa::ReduceSumOp>(
2492 rewriter, op, output_type, input_value, axes_elems, keep_dims,
2493 reduce_element_type, input_is_qtype, input_scale, input_zp, output_scale,
2494 output_zp);
2495 }
2496
2497 // Lowers ReduceMean to a sequence of TOSA ops.
2498 llvm::Optional<Value> convertReduceMeanOp(
2499 PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2500 Value input_value, ElementsAttr axes_elems, bool keep_dims) {
2501   // reduce_mean is lowered as follows:
2502 // op1 = reduce_sum(input)
2503 // op2 = mul(op1, 1.0 / num_elements_on_reduced_axis)
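  // Illustrative example (hypothetical shapes): reducing a [2, 3, 4] input over
  // axes {1, 2} gives num_elems_on_reduced_axis = 3 * 4 = 12, so the summed
  // result is scaled by 1.0 / 12 (folded into the rescale on the quantized
  // path, or applied as an explicit mul for floating point).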
2504
2505 RankedTensorType input_type =
2506 input_value.getType().dyn_cast<RankedTensorType>();
2507 if (!input_type) return llvm::None;
2508
2509 bool input_is_qtype =
2510 input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2511 bool output_is_qtype =
2512 output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2513
2514 if (input_is_qtype != output_is_qtype) {
2515 op->emitOpError(
2516 "ConvertReduceSumOp: input/output tensor should "
2517 "be all quantized or all floating-point.");
2518 return llvm::None;
2519 }
2520
2521 // Only supports float type mean() if it's non-quantized
2522 if (!input_is_qtype && !output_type.getElementType().isa<mlir::FloatType>()) {
2523 op->emitWarning(
2524 "Failed convertReduceMean: input unquantized type but output element "
2525 "not FloatType!");
2526 return llvm::None;
2527 }
2528
2529 int64_t input_rank = input_type.getRank();
2530 int64_t num_elems_on_reduced_axis = 1;
2531 for (int i = 0; i < axes_elems.getNumElements(); i++) {
2532 int64_t axis_val = axes_elems.getValue<IntegerAttr>(i).getInt();
2533 if (axis_val < 0) axis_val += input_rank;
2534 num_elems_on_reduced_axis *= input_type.getShape()[axis_val];
2535 }
2536 double div_scale = 1.0 / static_cast<double>(num_elems_on_reduced_axis);
2537
2538 double input_scale = 1.0f;
2539 double output_scale = 1.0f;
2540 int64_t input_zp = 0;
2541 int64_t output_zp = 0;
2542 Type reduce_element_type = input_type.getElementType();
2543
2544 if (input_is_qtype) {
2545 auto input_qtype =
2546 input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2547 auto output_qtype =
2548 output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
2549
2550 int32_t input_shift = 20;
2551
2552 input_scale =
2553 static_cast<double>(1 << input_shift) * input_qtype.getScale();
2554 output_scale = div_scale / (output_qtype.getScale() *
2555 static_cast<double>(1 << input_shift));
2556
2557 input_zp = input_qtype.getZeroPoint();
2558 output_zp = output_qtype.getZeroPoint();
2559 reduce_element_type = rewriter.getI32Type();
2560 }
2561
2562 auto val = convertReduceOpCommon<tosa::ReduceSumOp>(
2563 rewriter, op, output_type, input_value, axes_elems, keep_dims,
2564 reduce_element_type, input_is_qtype, input_scale, input_zp, output_scale,
2565 output_zp);
2566
2567 if (!val.hasValue()) return llvm::None;
2568
2569 if (!input_is_qtype) {
2570 Value div_const = getTosaConstTensorSingleF32(rewriter, op, div_scale);
2571 return rewriter
2572 .create<tosa::MulOp>(op->getLoc(), output_type, val.getValue(),
2573 div_const, 0)
2574 .getResult();
2575 }
2576
2577 return val;
2578 }
2579
2580 // Lowers ResizeBilinear and ResizeNearestNeighbor to TOSA resize.
2581 llvm::Optional<Value> convertResizeOp(PatternRewriter& rewriter, Operation* op,
2582 RankedTensorType output_type,
2583 Value input_value, StringRef mode) {
2584 RankedTensorType input_type =
2585 input_value.getType().dyn_cast<RankedTensorType>();
2586 if (!input_type) return llvm::None;
2587
2588 auto input_shape = input_type.getShape();
2589 auto output_shape = output_type.getShape();
2590
2591 size_t input_height = input_shape[1];
2592 size_t input_width = input_shape[2];
2593 size_t output_height = output_shape[1];
2594 size_t output_width = output_shape[2];
2595
2596 bool input_is_qtype =
2597 input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2598 bool output_is_qtype =
2599 output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2600
2601 if (input_is_qtype != output_is_qtype) {
2602 op->emitOpError(
2603 "ConvertResizeOp: input/output tensor should "
2604 "be all quantized or all floating-point.");
2605 return llvm::None;
2606 }
2607
2608 if (!input_is_qtype) {
2609 // TODO: support float type
2610 op->emitOpError("ConvertResizeOp: floating-point type not supported yet ");
2611 return llvm::None;
2612 }
2613
2614 int32_t shift = 11; // Set default shift to maximum allowed
2615
2616 double frac_y =
2617 static_cast<double>(output_height) / static_cast<double>(input_height);
2618 double frac_x =
2619 static_cast<double>(output_width) / static_cast<double>(input_width);
2620 int32_t stride_y = std::lround(frac_y * static_cast<double>(1 << shift));
2621 int32_t stride_x = std::lround(frac_x * static_cast<double>(1 << shift));
2622
2623 // Stride is int16
2624 while (stride_y >= 32768 || stride_x >= 32768) {
2625 shift--;
2626 stride_y = std::lround(frac_y * static_cast<double>(1 << shift));
2627 stride_x = std::lround(frac_x * static_cast<double>(1 << shift));
2628 }
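  // Illustrative example (hypothetical sizes): upscaling 10x10 to 20x20 gives
  // frac = 2.0 and, at shift = 11, stride = lround(2.0 * 2048) = 4096. For an
  // extreme ratio such as frac = 32.0 the initial stride (65536) would not fit
  // in int16, so the loop reduces shift to 9, giving stride = 16384.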
2629
2630 ArrayAttr output_size =
2631 rewriter.getI64ArrayAttr({static_cast<int64_t>(output_height),
2632 static_cast<int64_t>(output_width)});
2633 ArrayAttr stride = rewriter.getI64ArrayAttr({stride_y, stride_x});
2634 ArrayAttr offset = rewriter.getI64ArrayAttr({0, 0});
2635 IntegerAttr shift_attr = rewriter.getI32IntegerAttr(shift);
2636 StringAttr resize_mode = rewriter.getStringAttr(mode.str());
2637
2638 return rewriter
2639 .create<tosa::ResizeOp>(op->getLoc(), output_type, input_value,
2640 output_size, stride, offset, shift_attr,
2641 resize_mode)
2642 .getResult();
2643 }
2644
2645 // Lowers Quantize to a sequence of TOSA quantization ops.
2646 llvm::Optional<Value> convertQuantizeOp(PatternRewriter& rewriter,
2647 Operation* op,
2648 RankedTensorType output_type,
2649 Value input_value, double scale,
2650 int64_t zeropoint) {
2651 RankedTensorType input_type =
2652 input_value.getType().dyn_cast<RankedTensorType>();
2653 if (!input_type) return llvm::None;
2654
2655 auto output_shape = output_type.getShape();
2656 auto output_element_type = output_type.getElementType();
2657
2658 // output element type could only be quantized integer
2659 if (!output_element_type.isa<mlir::quant::QuantizedType>()) {
2660 op->emitWarning(
2661 "Lowering quantizeOp but output element type not quantized!");
2662 return llvm::None;
2663 }
2664
2665 RankedTensorType output_fp_type =
2666 RankedTensorType::get(output_shape, rewriter.getF32Type());
2667
2668 Value zp_val =
2669 getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(zeropoint));
2670
2671 auto op1_mul_in = rewriter.create<tosa::MulOp>(
2672 op->getLoc(), output_fp_type, input_value,
2673 getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(scale)), 0);
2674
2675 auto op2_add_op1 = rewriter.create<tosa::AddOp>(
2676 op->getLoc(), output_fp_type, op1_mul_in.getResult(), zp_val);
2677
2678 // TOSA doesn't support CAST FLOAT->AINT8, need to CAST to INT32
2679 // followed by a RESCALE
2680 RankedTensorType output_int32_type =
2681 RankedTensorType::get(output_shape, rewriter.getI32Type());
2682
2683 auto op3_cast_op2 = rewriter.create<tosa::CastOp>(
2684 op->getLoc(), output_int32_type, op2_add_op1.getResult());
2685
2686 return buildRescale(rewriter, op, output_type, op3_cast_op2.getResult(), 1.0,
2687 0, 0);
2688 }
2689
2690 // Lowers Dequantize to a sequence of TOSA dequantization ops.
2691 llvm::Optional<Value> convertDequantizeOp(PatternRewriter& rewriter,
2692 Operation* op,
2693 RankedTensorType output_type,
2694 Value input_value, double scale,
2695 int64_t zeropoint) {
2696 RankedTensorType input_type =
2697 input_value.getType().dyn_cast<RankedTensorType>();
2698 if (!input_type) return llvm::None;
2699
2700 // input element type could only be quantized integer
2701 if (!input_type.getElementType().isa<mlir::quant::QuantizedType>())
2702 return llvm::None;
2703
2704 auto output_shape = output_type.getShape();
2705
2706 RankedTensorType output_int32_type =
2707 RankedTensorType::get(output_shape, rewriter.getI32Type());
2708
2709 Value zp_val =
2710 getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(zeropoint));
2711
2712 // TOSA doesn't support CAST AINT8 -> FLOAT, need to RESCALE to INT32
2713 // followed by a CAST
2714 Value op1_rescale_in =
2715 buildRescale(rewriter, op, output_int32_type, input_value, 1.0, 0, 0);
2716
2717 auto op2_cast_op1 =
2718 rewriter.create<tosa::CastOp>(op->getLoc(), output_type, op1_rescale_in);
2719
2720 auto op3_sub_op2 = rewriter.create<tosa::SubOp>(
2721 op->getLoc(), output_type, op2_cast_op1.getResult(), zp_val);
2722
2723 return rewriter
2724 .create<tosa::MulOp>(
2725 op->getLoc(), output_type, op3_sub_op2.getResult(),
2726 getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(scale)),
2727 0)
2728 .getResult();
2729 }
2730
2731 // Lowers FakeQuant to a sequence of TOSA quantization ops.
2732 llvm::Optional<Value> convertFakeQuantOp(PatternRewriter& rewriter,
2733 Operation* op,
2734 RankedTensorType output_type,
2735 Value input_value, double min,
2736 double max, int64_t num_bits,
2737 bool narrow_range) {
2738   // FakeQuant is lowered as follows:
2739 // op1 = quantize(input)
2740 // op2 = dequantize(op1)
2741
2742 RankedTensorType input_type =
2743 input_value.getType().dyn_cast<RankedTensorType>();
2744 if (!input_type) return llvm::None;
2745
2746 // quantized as INT<num_bits>, where num_bits can only be 8, 16
2747 if (num_bits != 8 && num_bits != 16) {
2748 op->emitWarning("FakeQuantOp lowering handles only 8 and 16 for num_bits!");
2749 return llvm::None;
2750 }
2751
2752 auto output_shape = output_type.getShape();
2753
2754 int64_t qmax = (1L << (num_bits - 1)) - 1;
2755 int64_t qmin = -(1L << (num_bits - 1));
2756 if (narrow_range) {
2757 qmin += 1;
2758 }
2759
2760 auto int_element_qtype = mlir::quant::UniformQuantizedType::get(
2761 true, rewriter.getIntegerType(num_bits), rewriter.getF32Type(), 1.0f, 0,
2762 qmin, qmax);
2763 RankedTensorType output_int_type =
2764 RankedTensorType::get(output_shape, int_element_qtype);
2765
2766 double scale = (max - min) / static_cast<double>(qmax - qmin);
2767 int64_t zeropoint = std::llround((-min) / scale + static_cast<double>(qmin));
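  // Illustrative example (hypothetical parameters): with num_bits = 8,
  // narrow_range = false, min = 0.0 and max = 6.0, we get qmin = -128,
  // qmax = 127, scale = 6.0 / 255 ~= 0.0235, and zeropoint =
  // llround(0.0 / scale - 128) = -128.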
2768
2769 // Quantize: round(x / scale + zeropoint)
2770 auto quantized_val = convertQuantizeOp(rewriter, op, output_int_type,
2771 input_value, 1.0 / scale, zeropoint);
2772
2773 if (!quantized_val.hasValue()) return llvm::None;
2774
2775 // Dequantize: ((float)x - zeropoint) * scale
2776 return convertDequantizeOp(rewriter, op, output_type,
2777 quantized_val.getValue(), scale, zeropoint);
2778 }
2779
2780 llvm::Optional<Value> convertTFConv2DCommon(
2781 PatternRewriter& rewriter, Operation* op, RankedTensorType output_type,
2782 Value input, Value filter, Value bias, ArrayAttr strides_attr,
2783 ArrayAttr dilations_attr, ArrayAttr explicit_padding_attr,
2784 StringRef padding_ref, StringRef data_format_ref) {
2785 RankedTensorType input_type = input.getType().dyn_cast<RankedTensorType>();
2786 RankedTensorType filter_type = filter.getType().dyn_cast<RankedTensorType>();
2787   // Not a ranked tensor input or filter
2788 if (!input_type) return llvm::None;
2789 if (!filter_type) return llvm::None;
2790
2791 // Transpose [H, W, I, O] to [O, H, W, I]
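  // Illustrative example (hypothetical filter): a 3x3 kernel with 16 input and
  // 32 output channels has TF filter shape [3, 3, 16, 32] (HWIO); the transpose
  // below with perm {3, 0, 1, 2} produces [32, 3, 3, 16] (OHWI) for
  // tosa.conv2d.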
2792 auto filter_shape = filter_type.getShape();
2793 SmallVector<int64_t, 4> a1_transpose_dims;
2794 a1_transpose_dims.push_back(filter_shape[3]);
2795 a1_transpose_dims.push_back(filter_shape[0]);
2796 a1_transpose_dims.push_back(filter_shape[1]);
2797 a1_transpose_dims.push_back(filter_shape[2]);
2798 Value a1_filter_transpose_perm =
2799 get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, {3, 0, 1, 2});
2800 auto a1_filter_transpose_op = rewriter.create<tosa::TransposeOp>(
2801 op->getLoc(),
2802 RankedTensorType::get(ArrayRef<int64_t>(a1_transpose_dims),
2803 filter_type.getElementType()),
2804 filter, a1_filter_transpose_perm);
2805
2806 // Only support NHWC now.
2807 if (data_format_ref.str() != "NHWC") {
2808 op->emitWarning("convertTDConv2DCommon only supports NHWC!");
2809 return llvm::None;
2810 }
2811
2812 ArrayAttr stride;
2813 ArrayAttr dilation;
2814 ArrayAttr pad;
2815 {
2816 if (!strides_attr) {
2817 stride = rewriter.getI64ArrayAttr({1, 1});
2818 } else {
2819 // Note: hardcoded to NHWC for now
2820 int64_t stride_h = strides_attr[1].cast<IntegerAttr>().getInt();
2821 int64_t stride_w = strides_attr[2].cast<IntegerAttr>().getInt();
2822 stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
2823 }
2824 }
2825 {
2826 if (!dilations_attr) {
2827 dilation = rewriter.getI64ArrayAttr({1, 1});
2828 } else {
2829 // Note: hardcoded to NHWC for now
2830 int64_t dilation_h = dilations_attr[1].cast<IntegerAttr>().getInt();
2831 int64_t dilation_w = dilations_attr[2].cast<IntegerAttr>().getInt();
2832 dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w});
2833 }
2834 }
2835 {
2836 tensorflow::Padding tf_pad;
2837 if (!GetPaddingFromString(padding_ref.str(), &tf_pad).ok()) {
2838 op->emitWarning("Could not get padding data from padding string term!");
2839 return llvm::None;
2840 }
2841
2842 tensorflow::TensorFormat data_format_tf;
2843 if (!FormatFromString(data_format_ref.str(), &data_format_tf))
2844 return llvm::None;
2845
2846 if (tf_pad == tensorflow::Padding::EXPLICIT) {
2847 pad = getPaddingValuesFromExplicitPadAttr(explicit_padding_attr,
2848 data_format_tf, rewriter);
2849 } else {
2850 if (!getPaddingValuesFromPadType(tf_pad, data_format_tf,
2851 0, // tensorflow::FORMAT_HWIO
2852 input_type, filter_type, stride,
2853 dilation, rewriter, pad))
2854 return llvm::None;
2855 }
2856 }
2857
2858 return rewriter
2859 .create<tosa::Conv2DOp>(op->getLoc(), output_type, input,
2860 a1_filter_transpose_op.getResult(), bias, pad,
2861 stride, dilation)
2862 .getResult();
2863 }
2864
2865 }; // namespace tosa
2866 }; // namespace mlir
2867