1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "RefLayerSupport.hpp"
7
8 #include <armnn/TypesUtils.hpp>
9 #include <armnn/Types.hpp>
10 #include <armnn/Descriptors.hpp>
11 #include <armnn/utility/IgnoreUnused.hpp>
12 #include <armnn/utility/NumericCast.hpp>
13
14 #include <LayerSupportCommon.hpp>
15 #include <backendsCommon/LayerSupportRules.hpp>
16
17 #include <vector>
18 #include <array>
19
20 namespace armnn
21 {
22
23 namespace
24 {
25
26 template<typename Float32Func, typename Uint8Func, typename ... Params>
IsSupportedForDataTypeRef(Optional<std::string &> reasonIfUnsupported,DataType dataType,Float32Func floatFuncPtr,Uint8Func uint8FuncPtr,Params &&...params)27 bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
28 DataType dataType,
29 Float32Func floatFuncPtr,
30 Uint8Func uint8FuncPtr,
31 Params&&... params)
32 {
33 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
34 dataType,
35 &FalseFunc<Params...>,
36 floatFuncPtr,
37 uint8FuncPtr,
38 &FalseFunc<Params...>,
39 &FalseFunc<Params...>,
40 std::forward<Params>(params)...);
41 }
42
43 } // anonymous namespace
44
45 namespace
46 {
47
// Builds the standard "wrong number of dimensions" diagnostic used by the
// reference layer-support checks.
//
// @param expected   Number of dimensions the layer requires.
// @param actual     Number of dimensions the tensor actually has.
// @param layerStr   Human-readable layer name (e.g. "batchToSpaceNd").
// @param tensorName Role of the offending tensor (e.g. "input", "output").
// @return Formatted error message.
//
// Parameters are taken by const reference: the function never mutates them,
// and non-const references would also reject temporaries at call sites.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) +
                           " dimensions but got " + std::to_string(actual) +
                           " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}
58
59 } // anonymous namespace
60
IsAbsSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const61 bool RefLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output,
62 Optional<std::string&> reasonIfUnsupported) const
63 {
64 return IsElementwiseUnarySupported(input,
65 output,
66 ElementwiseUnaryDescriptor(UnaryOperation::Abs),
67 reasonIfUnsupported);
68 }
69
IsActivationSupported(const TensorInfo & input,const TensorInfo & output,const ActivationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const70 bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
71 const TensorInfo& output,
72 const ActivationDescriptor& descriptor,
73 Optional<std::string&> reasonIfUnsupported) const
74 {
75 bool supported = true;
76
77 // Define supported types.
78 std::array<DataType,6> supportedTypes = {
79 DataType::BFloat16,
80 DataType::Float32,
81 DataType::Float16,
82 DataType::QAsymmS8,
83 DataType::QAsymmU8,
84 DataType::QSymmS16
85 };
86
87 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
88 "Reference activation: input type not supported.");
89
90 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
91 "Reference activation: output type not supported.");
92
93 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
94 "Reference activation: input and output types mismatched.");
95
96 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
97 "Reference activation: input and output shapes are of different rank.");
98
99
100 struct ActivationFunctionSupported : public Rule
101 {
102 ActivationFunctionSupported(const ActivationDescriptor& desc)
103 {
104 switch(desc.m_Function)
105 {
106 case ActivationFunction::Abs:
107 case ActivationFunction::BoundedReLu:
108 case ActivationFunction::Elu:
109 case ActivationFunction::HardSwish:
110 case ActivationFunction::LeakyReLu:
111 case ActivationFunction::Linear:
112 case ActivationFunction::ReLu:
113 case ActivationFunction::Sigmoid:
114 case ActivationFunction::SoftReLu:
115 case ActivationFunction::Sqrt:
116 case ActivationFunction::Square:
117 case ActivationFunction::TanH:
118 {
119 m_Res = true;
120 break;
121 }
122 default:
123 {
124 m_Res = false;
125 break;
126 }
127 }
128 }
129 };
130
131 // Function is supported
132 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
133 "Reference activation: function not supported.");
134
135 return supported;
136 }
137
IsAdditionSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const138 bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
139 const TensorInfo& input1,
140 const TensorInfo& output,
141 Optional<std::string&> reasonIfUnsupported) const
142 {
143 bool supported = true;
144
145 std::array<DataType,7> supportedTypes = {
146 DataType::BFloat16,
147 DataType::Float32,
148 DataType::Float16,
149 DataType::QAsymmS8,
150 DataType::QAsymmU8,
151 DataType::QSymmS16,
152 DataType::Signed32
153 };
154
155 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
156 "Reference addition: input 0 is not a supported type.");
157
158 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
159 "Reference addition: input 1 is not a supported type.");
160
161 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
162 "Reference addition: output is not a supported type.");
163
164 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
165 "Reference addition: input 0 and Input 1 types are mismatched");
166
167 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
168 "Reference addition: input and output types are mismatched");
169
170 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
171 "Reference addition: shapes are not suitable for implicit broadcast.");
172
173 return supported;
174 }
175
IsArgMinMaxSupported(const armnn::TensorInfo & input,const armnn::TensorInfo & output,const armnn::ArgMinMaxDescriptor & descriptor,armnn::Optional<std::string &> reasonIfUnsupported) const176 bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
177 const armnn::ArgMinMaxDescriptor &descriptor,
178 armnn::Optional<std::string &> reasonIfUnsupported) const
179 {
180 IgnoreUnused(descriptor);
181
182 std::array<DataType, 7> supportedTypes =
183 {
184 DataType::BFloat16,
185 DataType::Float16,
186 DataType::Float32,
187 DataType::QAsymmS8,
188 DataType::QAsymmU8,
189 DataType::QSymmS16,
190 DataType::Signed32
191 };
192
193 bool supported = true;
194
195 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
196 "Reference ArgMinMax: input is not a supported type.");
197 supported &= CheckSupportRule(TypeIs(output, DataType::Signed32), reasonIfUnsupported,
198 "Reference ArgMinMax: output type not supported");
199
200 return supported;
201 }
202
IsBatchNormalizationSupported(const TensorInfo & input,const TensorInfo & output,const TensorInfo & mean,const TensorInfo & variance,const TensorInfo & beta,const TensorInfo & gamma,const BatchNormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const203 bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
204 const TensorInfo& output,
205 const TensorInfo& mean,
206 const TensorInfo& variance,
207 const TensorInfo& beta,
208 const TensorInfo& gamma,
209 const BatchNormalizationDescriptor& descriptor,
210 Optional<std::string&> reasonIfUnsupported) const
211 {
212 IgnoreUnused(descriptor);
213
214 std::array<DataType, 6> supportedTypes =
215 {
216 DataType::BFloat16,
217 DataType::Float32,
218 DataType::Float16,
219 DataType::QAsymmS8,
220 DataType::QAsymmU8,
221 DataType::QSymmS16
222 };
223
224 bool supported = true;
225
226 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
227 "Reference batch normalization: input is not a supported type.");
228
229 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
230 "Reference batch normalization: output is not a supported type.");
231
232 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
233 "Reference batch normalization: input and output types are mismatched");
234
235 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
236 "Reference batch normalization: mean is not a supported type.");
237
238 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
239 "Reference batch normalization: variance is not a supported type.");
240
241 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
242 "Reference batch normalization: beta is not a supported type.");
243
244 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
245 "Reference batch normalization: gamma is not a supported type.");
246
247 return supported;
248 }
249
IsBatchToSpaceNdSupported(const TensorInfo & input,const TensorInfo & output,const BatchToSpaceNdDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const250 bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
251 const TensorInfo& output,
252 const BatchToSpaceNdDescriptor& descriptor,
253 Optional<std::string&> reasonIfUnsupported) const
254 {
255 IgnoreUnused(descriptor);
256
257 bool supported = true;
258
259 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
260 std::string inputTensorStr = "input";
261 std::string outputTensorStr = "output";
262
263 // Define supported types.
264 std::array<DataType,6> supportedTypes =
265 {
266 DataType::BFloat16,
267 DataType::Float32,
268 DataType::Float16,
269 DataType::QAsymmS8,
270 DataType::QAsymmU8,
271 DataType::QSymmS16
272 };
273
274 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
275 "Reference BatchToSpaceNd: input type not supported.");
276
277 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
278 "Reference BatchToSpaceNd: output type not supported.");
279
280 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
281 "Reference BatchToSpaceNd: input and output types mismatched.");
282
283 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
284 reasonIfUnsupported,
285 CreateIncorrectDimensionsErrorMsg(4,
286 output.GetNumDimensions(),
287 batchToSpaceNdLayerStr,
288 outputTensorStr).data());
289
290 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
291 reasonIfUnsupported,
292 CreateIncorrectDimensionsErrorMsg(4,
293 input.GetNumDimensions(),
294 batchToSpaceNdLayerStr,
295 inputTensorStr).data());
296
297 return supported;
298 }
299
IsComparisonSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,const ComparisonDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const300 bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
301 const TensorInfo& input1,
302 const TensorInfo& output,
303 const ComparisonDescriptor& descriptor,
304 Optional<std::string&> reasonIfUnsupported) const
305 {
306 IgnoreUnused(descriptor);
307 std::array<DataType, 8> supportedInputTypes =
308 {
309 DataType::Boolean,
310 DataType::BFloat16,
311 DataType::Float32,
312 DataType::Float16,
313 DataType::QAsymmS8,
314 DataType::QAsymmU8,
315 DataType::QSymmS16,
316 DataType::Signed32
317 };
318
319 bool supported = true;
320 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
321 "Reference comparison: input 0 is not a supported type");
322
323 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
324 "Reference comparison: input 0 and Input 1 types are mismatched");
325
326 supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
327 "Reference comparison: output is not of type Boolean");
328
329 return supported;
330 }
331
IsConcatSupported(const std::vector<const TensorInfo * > inputs,const TensorInfo & output,const ConcatDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const332 bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
333 const TensorInfo& output,
334 const ConcatDescriptor& descriptor,
335 Optional<std::string&> reasonIfUnsupported) const
336 {
337 IgnoreUnused(descriptor);
338
339 bool supported = true;
340 std::array<DataType,6> supportedTypes =
341 {
342 DataType::BFloat16,
343 DataType::Float32,
344 DataType::Float16,
345 DataType::QAsymmS8,
346 DataType::QAsymmU8,
347 DataType::QSymmS16
348 };
349
350 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
351 "Reference concatenation: output type not supported");
352 for (const TensorInfo* input : inputs)
353 {
354 ARMNN_ASSERT(input != nullptr);
355 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
356 "Reference concatenation: input type not supported");
357
358 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
359 "Reference concatenation: input and output types mismatched.");
360 }
361
362 return supported;
363 }
364
IsConstantSupported(const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const365 bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
366 Optional<std::string&> reasonIfUnsupported) const
367 {
368 std::array<DataType,8> supportedTypes =
369 {
370 DataType::BFloat16,
371 DataType::Float16,
372 DataType::Float32,
373 DataType::QAsymmS8,
374 DataType::QAsymmU8,
375 DataType::QSymmS8,
376 DataType::QSymmS16,
377 DataType::Signed32
378 };
379
380 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
381 "Reference constant: output is not a supported type.");
382 }
383
IsConvertBf16ToFp32Supported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const384 bool RefLayerSupport::IsConvertBf16ToFp32Supported(const TensorInfo& input,
385 const TensorInfo& output,
386 Optional<std::string&> reasonIfUnsupported) const
387 {
388 bool supported = true;
389
390 supported &= CheckSupportRule(TypeIs(input, DataType::BFloat16), reasonIfUnsupported,
391 "Reference for ConvertBf16ToFp32 layer: input type not supported");
392
393 supported &= CheckSupportRule(TypeIs(output, DataType::Float32), reasonIfUnsupported,
394 "Reference for ConvertBf16ToFp32 layer: output type not supported");
395
396 return supported;
397 }
398
IsConvertFp16ToFp32Supported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const399 bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
400 const TensorInfo& output,
401 Optional<std::string&> reasonIfUnsupported) const
402 {
403 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
404 input.GetDataType(),
405 &TrueFunc<>,
406 &FalseInputFuncF32<>,
407 &FalseFuncU8<>,
408 &FalseFuncI32<>,
409 &FalseFuncU8<>) &&
410 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
411 output.GetDataType(),
412 &FalseOutputFuncF16<>,
413 &TrueFunc<>,
414 &FalseFuncU8<>,
415 &FalseFuncI32<>,
416 &FalseFuncU8<>));
417 }
418
IsConvertFp32ToBf16Supported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const419 bool RefLayerSupport::IsConvertFp32ToBf16Supported(const TensorInfo& input,
420 const TensorInfo& output,
421 Optional<std::string&> reasonIfUnsupported) const
422 {
423 bool supported = true;
424
425 supported &= CheckSupportRule(TypeIs(input, DataType::Float32), reasonIfUnsupported,
426 "Reference for ConvertFp32ToBf16 layer: input type not supported");
427
428 supported &= CheckSupportRule(TypeIs(output, DataType::BFloat16), reasonIfUnsupported,
429 "Reference for ConvertFp32ToBf16 layer: output type not supported");
430
431 return supported;
432 }
433
IsConvertFp32ToFp16Supported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const434 bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
435 const TensorInfo& output,
436 Optional<std::string&> reasonIfUnsupported) const
437 {
438 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
439 input.GetDataType(),
440 &FalseInputFuncF16<>,
441 &TrueFunc<>,
442 &FalseFuncU8<>,
443 &FalseFuncI32<>,
444 &FalseFuncU8<>) &&
445 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
446 output.GetDataType(),
447 &TrueFunc<>,
448 &FalseOutputFuncF32<>,
449 &FalseFuncU8<>,
450 &FalseFuncI32<>,
451 &FalseFuncU8<>));
452 }
453
IsConvolution2dSupported(const TensorInfo & input,const TensorInfo & output,const Convolution2dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const454 bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
455 const TensorInfo& output,
456 const Convolution2dDescriptor& descriptor,
457 const TensorInfo& weights,
458 const Optional<TensorInfo>& biases,
459 Optional<std::string&> reasonIfUnsupported) const
460 {
461 bool supported = true;
462
463 // Define supported types.
464 std::array<DataType,7> supportedTypes =
465 {
466 DataType::BFloat16,
467 DataType::Float32,
468 DataType::Float16,
469 DataType::QAsymmS8,
470 DataType::QAsymmU8,
471 DataType::QSymmS8,
472 DataType::QSymmS16
473 };
474
475 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
476 "Reference Convolution2d: input is not a supported type.");
477
478 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
479 "Reference Convolution2d: output is not a supported type.");
480
481 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
482 if (input.GetDataType() == DataType::BFloat16)
483 {
484 if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
485 {
486 reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
487 supported = false;
488 }
489 }
490 else
491 {
492 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
493 "Reference Convolution2d: input and output types mismatched.");
494 }
495
496 const DataType inputType = input.GetDataType();
497 if (IsQuantized8BitType(inputType))
498 {
499 ARMNN_NO_DEPRECATE_WARN_BEGIN
500 std::array<DataType, 4> supportedWeightTypes =
501 {
502 DataType::QAsymmS8,
503 DataType::QAsymmU8,
504 DataType::QSymmS8,
505 DataType::QuantizedSymm8PerAxis // deprecated
506 };
507 ARMNN_NO_DEPRECATE_WARN_END
508
509 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
510 "Reference Convolution2d: weights type not supported for quantized input.");
511 }
512 else
513 {
514 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
515 "Reference Convolution2d: weights is not a supported type.");
516
517 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
518 "Reference Convolution2d: input and weights types mismatched.");
519 }
520
521 if (biases.has_value())
522 {
523 std::array<DataType,4> biasesSupportedTypes =
524 {
525 DataType::BFloat16,
526 DataType::Float32,
527 DataType::Float16,
528 DataType::Signed32
529 };
530
531 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
532 "Reference Convolution2d: biases is not a supported type.");
533 }
534 IgnoreUnused(descriptor);
535
536 return supported;
537 }
538
IsDebugSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const539 bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
540 const TensorInfo& output,
541 Optional<std::string&> reasonIfUnsupported) const
542 {
543 bool supported = true;
544
545 std::array<DataType, 8> supportedTypes =
546 {
547 DataType::BFloat16,
548 DataType::Float16,
549 DataType::Float32,
550 DataType::QAsymmS8,
551 DataType::QAsymmU8,
552 DataType::QSymmS8,
553 DataType::QSymmS16,
554 DataType::Signed32
555 };
556
557 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
558 "Reference for Debug layer: input type not supported");
559
560 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
561 "Reference for Debug layer: output type not supported");
562
563 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
564 "Reference for Debug layer: input and output types are mismatched");
565
566 return supported;
567 }
568
IsDepthToSpaceSupported(const TensorInfo & input,const TensorInfo & output,const DepthToSpaceDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const569 bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
570 const TensorInfo& output,
571 const DepthToSpaceDescriptor& descriptor,
572 Optional<std::string&> reasonIfUnsupported) const
573 {
574 IgnoreUnused(descriptor);
575 bool supported = true;
576
577 std::array<DataType,6> supportedTypes =
578 {
579 DataType::BFloat16,
580 DataType::Float32,
581 DataType::Float16,
582 DataType::QAsymmS8,
583 DataType::QAsymmU8,
584 DataType::QSymmS16
585 };
586
587 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
588 "Reference DepthToSpace: input type not supported");
589
590 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
591 "Reference DepthToSpace: output type not supported");
592
593 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
594 "Reference DepthToSpace: input and output types are mismatched");
595
596 return supported;
597 }
598
IsDepthwiseConvolutionSupported(const TensorInfo & input,const TensorInfo & output,const DepthwiseConvolution2dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const599 bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
600 const TensorInfo& output,
601 const DepthwiseConvolution2dDescriptor& descriptor,
602 const TensorInfo& weights,
603 const Optional<TensorInfo>& biases,
604 Optional<std::string&> reasonIfUnsupported) const
605 {
606 IgnoreUnused(descriptor);
607 bool supported = true;
608
609 // Define supported types.
610 std::array<DataType,7> supportedTypes =
611 {
612 DataType::BFloat16,
613 DataType::Float32,
614 DataType::Float16,
615 DataType::QAsymmS8,
616 DataType::QAsymmU8,
617 DataType::QSymmS8,
618 DataType::QSymmS16
619 };
620
621 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
622 "Reference DepthwiseConvolution2d: input is not a supported type.");
623
624 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
625 "Reference DepthwiseConvolution2d: output is not a supported type.");
626
627 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
628 "Reference DepthwiseConvolution2d: input and output types mismatched.");
629
630 const DataType inputType = input.GetDataType();
631 if (IsQuantized8BitType(inputType))
632 {
633 ARMNN_NO_DEPRECATE_WARN_BEGIN
634 std::array<DataType, 4> supportedWeightTypes =
635 {
636 DataType::QAsymmS8,
637 DataType::QAsymmU8,
638 DataType::QSymmS8,
639 DataType::QuantizedSymm8PerAxis // deprecated
640 };
641 ARMNN_NO_DEPRECATE_WARN_END
642
643 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
644 "Reference DepthwiseConvolution2d: weights type not supported for "
645 "quantized input.");
646 }
647 else
648 {
649 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
650 "Reference DepthwiseConvolution2d: weights is not a supported type.");
651
652 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
653 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
654 }
655
656 if (biases.has_value())
657 {
658 std::array<DataType,4> biasesSupportedTypes =
659 {
660 DataType::BFloat16,
661 DataType::Float32,
662 DataType::Float16,
663 DataType::Signed32
664 };
665 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
666 "Reference DepthwiseConvolution2d: biases is not a supported type.");
667 }
668
669 return supported;
670
671 }
672
IsDequantizeSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const673 bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
674 const TensorInfo& output,
675 Optional<std::string&> reasonIfUnsupported) const
676 {
677 bool supported = true;
678
679 std::array<DataType,4> supportedInputTypes = {
680 DataType::QAsymmS8,
681 DataType::QAsymmU8,
682 DataType::QSymmS8,
683 DataType::QSymmS16
684 };
685
686 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
687 "Reference for Dequantize layer: input type not supported.");
688
689 supported &= CheckSupportRule( TypeNotPerAxisQuantized(input), reasonIfUnsupported,
690 "Reference for Dequantize layer: per-axis quantized input not support .");
691
692 supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
693 "Reference dequantize: per-axis quantized input not support .");
694
695 std::array<DataType,3> supportedOutputTypes = {
696 DataType::BFloat16,
697 DataType::Float32,
698 DataType::Float16
699 };
700
701 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
702 "Reference for Dequantize layer: output type not supported.");
703
704 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
705 "Reference for Dequantize layer: input/output shapes have different num total "
706 "elements.");
707
708 return supported;
709 }
710
IsDetectionPostProcessSupported(const TensorInfo & boxEncodings,const TensorInfo & scores,const TensorInfo & anchors,const TensorInfo & detectionBoxes,const TensorInfo & detectionClasses,const TensorInfo & detectionScores,const TensorInfo & numDetections,const DetectionPostProcessDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const711 bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
712 const TensorInfo& scores,
713 const TensorInfo& anchors,
714 const TensorInfo& detectionBoxes,
715 const TensorInfo& detectionClasses,
716 const TensorInfo& detectionScores,
717 const TensorInfo& numDetections,
718 const DetectionPostProcessDescriptor& descriptor,
719 Optional<std::string&> reasonIfUnsupported) const
720 {
721 IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
722
723 bool supported = true;
724
725 std::array<DataType,6> supportedInputTypes =
726 {
727 DataType::BFloat16,
728 DataType::Float32,
729 DataType::Float16,
730 DataType::QAsymmS8,
731 DataType::QAsymmU8,
732 DataType::QSymmS16
733 };
734
735 supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
736 "Reference DetectionPostProcess: input 0 is not a supported type.");
737
738 supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
739 "Reference DetectionPostProcess: input 1 is not a supported type.");
740
741 return supported;
742 }
743
IsDilatedDepthwiseConvolutionSupported(const TensorInfo & input,const TensorInfo & output,const DepthwiseConvolution2dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const744 bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
745 const TensorInfo& output,
746 const DepthwiseConvolution2dDescriptor& descriptor,
747 const TensorInfo& weights,
748 const Optional<TensorInfo>& biases,
749 Optional<std::string&> reasonIfUnsupported) const
750 {
751 return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
752 }
753
IsDivisionSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const754 bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
755 const TensorInfo& input1,
756 const TensorInfo& output,
757 Optional<std::string&> reasonIfUnsupported) const
758 {
759 bool supported = true;
760
761 std::array<DataType,7> supportedTypes = {
762 DataType::BFloat16,
763 DataType::Float32,
764 DataType::Float16,
765 DataType::QAsymmS8,
766 DataType::QAsymmU8,
767 DataType::QSymmS16,
768 DataType::Signed32
769 };
770
771 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
772 "Reference division: input 0 is not a supported type.");
773
774 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
775 "Reference division: input 1 is not a supported type.");
776
777 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
778 "Reference division: output is not a supported type.");
779
780 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
781 "Reference division: input 0 and Input 1 types are mismatched");
782
783 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
784 "Reference division: input and output types are mismatched");
785
786 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
787 "Reference division: shapes are not suitable for implicit broadcast.");
788
789 return supported;
790 }
791
IsElementwiseUnarySupported(const TensorInfo & input,const TensorInfo & output,const ElementwiseUnaryDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const792 bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
793 const TensorInfo& output,
794 const ElementwiseUnaryDescriptor& descriptor,
795 Optional<std::string&> reasonIfUnsupported) const
796 {
797 IgnoreUnused(descriptor);
798
799 std::array<DataType, 7> supportedTypes =
800 {
801 DataType::BFloat16,
802 DataType::Float32,
803 DataType::Float16,
804 DataType::QAsymmS8,
805 DataType::QAsymmU8,
806 DataType::QSymmS16,
807 DataType::Signed32
808 };
809
810 std::array<DataType, 1> logicalSupportedTypes =
811 {
812 DataType::Boolean
813 };
814
815 bool supported = true;
816
817 if (descriptor.m_Operation == UnaryOperation::LogicalNot)
818 {
819 supported &= CheckSupportRule(TypeAnyOf(input, logicalSupportedTypes), reasonIfUnsupported,
820 "Reference elementwise unary: input type not supported");
821
822 supported &= CheckSupportRule(TypeAnyOf(output, logicalSupportedTypes), reasonIfUnsupported,
823 "Reference elementwise unary: output type not supported");
824 }
825 else
826 {
827 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
828 "Reference elementwise unary: input type not supported");
829
830 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
831 "Reference elementwise unary: output type not supported");
832 }
833
834 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
835 "Reference elementwise unary: input and output types not matching");
836
837 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
838 "Reference elementwise unary: input and output shapes"
839 "have different number of total elements");
840
841 return supported;
842 }
843
IsEqualSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const844 bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
845 const TensorInfo& input1,
846 const TensorInfo& output,
847 Optional<std::string&> reasonIfUnsupported) const
848 {
849 return IsComparisonSupported(input0,
850 input1,
851 output,
852 ComparisonDescriptor(ComparisonOperation::Equal),
853 reasonIfUnsupported);
854 }
855
IsFakeQuantizationSupported(const TensorInfo & input,const FakeQuantizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const856 bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
857 const FakeQuantizationDescriptor& descriptor,
858 Optional<std::string&> reasonIfUnsupported) const
859 {
860 IgnoreUnused(descriptor);
861 bool supported = true;
862
863 std::array<DataType,1> supportedTypes =
864 {
865 DataType::Float32
866 };
867
868 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
869 "Reference fake quantization: input type not supported.");
870
871 return supported;
872 }
873
IsFillSupported(const TensorInfo & input,const TensorInfo & output,const FillDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const874 bool RefLayerSupport::IsFillSupported(const TensorInfo& input,
875 const TensorInfo& output,
876 const FillDescriptor& descriptor,
877 Optional<std::string&> reasonIfUnsupported) const
878 {
879 IgnoreUnused(descriptor);
880 IgnoreUnused(output);
881
882 bool supported = true;
883
884 std::array<DataType,3> supportedTypes =
885 {
886 DataType::Float32,
887 DataType::Float16,
888 DataType::Signed32
889 };
890
891 supported &= CheckSupportRule(TypeIs(input, DataType::Signed32), reasonIfUnsupported,
892 "Reference Fill: input type not supported.");
893
894 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
895 "Reference Fill: output type not supported.");
896 return supported;
897 }
898
IsFloorSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const899 bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
900 const TensorInfo& output,
901 Optional<std::string&> reasonIfUnsupported) const
902 {
903 IgnoreUnused(output);
904 bool supported = true;
905
906 std::array<DataType,3> supportedTypes =
907 {
908 DataType::BFloat16,
909 DataType::Float32,
910 DataType::Float16
911 };
912
913 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
914 "Reference Floor: input type not supported.");
915
916 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
917 "Reference Floor: output type not supported.");
918
919 return supported;
920 }
921
IsFullyConnectedSupported(const TensorInfo & input,const TensorInfo & output,const TensorInfo & weights,const TensorInfo & biases,const FullyConnectedDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const922 bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
923 const TensorInfo& output,
924 const TensorInfo& weights,
925 const TensorInfo& biases,
926 const FullyConnectedDescriptor& descriptor,
927 Optional<std::string&> reasonIfUnsupported) const
928 {
929 bool supported = true;
930
931 // Define supported types.
932 std::array<DataType,6> supportedTypes =
933 {
934 DataType::BFloat16,
935 DataType::Float32,
936 DataType::Float16,
937 DataType::QAsymmS8,
938 DataType::QAsymmU8,
939 DataType::QSymmS16
940 };
941
942 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
943 "Reference Fully Connected: input type not supported.");
944
945 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
946 "Reference Fully Connected: output type not supported.");
947
948 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
949 "Reference Fully Connected: weights type not supported.");
950
951 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
952 if (input.GetDataType() == DataType::BFloat16)
953 {
954 if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
955 {
956 reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
957 supported = false;
958 }
959 }
960 else
961 {
962 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
963 "Reference Fully Connected: input and output types mismatched.");
964 }
965
966 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
967 "Reference Fully Connected: weights is not a supported type.");
968
969 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
970 "Reference Fully Connected: input and weights types mismatched.");
971
972 if (descriptor.m_BiasEnabled)
973 {
974 // Defined supported types for bias
975 std::array<DataType, 5>
976 supportedBiasTypes =
977 {
978 DataType::BFloat16,
979 DataType::Float32,
980 DataType::Float16,
981 DataType::Signed32,
982 DataType::QAsymmS8
983 };
984
985 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
986 "Reference Fully Connected: bias type not supported.");
987
988 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
989 "Reference Fully Connected: bias and weight types mismatch.");
990
991 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
992 "Reference Fully Connected: bias type inferred from weights is incompatible.");
993
994 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(biases, 1U), reasonIfUnsupported,
995 "Reference Fully Connected: bias must have 1 dimension.");
996
997 }
998
999 return supported;
1000 }
1001
IsGatherSupported(const armnn::TensorInfo & input0,const armnn::TensorInfo & input1,const armnn::TensorInfo & output,const GatherDescriptor & descriptor,armnn::Optional<std::string &> reasonIfUnsupported) const1002 bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
1003 const armnn::TensorInfo& input1,
1004 const armnn::TensorInfo& output,
1005 const GatherDescriptor& descriptor,
1006 armnn::Optional<std::string&> reasonIfUnsupported) const
1007 {
1008 bool supported = true;
1009 std::array<DataType,7> supportedTypes =
1010 {
1011 DataType::BFloat16,
1012 DataType::Float32,
1013 DataType::Float16,
1014 DataType::QAsymmS8,
1015 DataType::QAsymmU8,
1016 DataType::QSymmS16,
1017 DataType::Signed32
1018 };
1019
1020 if (descriptor.m_Axis != 0)
1021 {
1022 reasonIfUnsupported.value() += std::string("Reference Gather: axis not supported\n");
1023 supported &= false;
1024 }
1025 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1026 "Reference Gather: input type not supported");
1027
1028 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1029 "Reference Gather: output type not supported");
1030
1031 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1032 "Reference Gather: indices (input1) type not supported");
1033
1034 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1035 "Reference Gather: input and output types not matching");
1036
1037 return supported;
1038 }
1039
IsGreaterSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1040 bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
1041 const TensorInfo& input1,
1042 const TensorInfo& output,
1043 Optional<std::string&> reasonIfUnsupported) const
1044 {
1045 return IsComparisonSupported(input0,
1046 input1,
1047 output,
1048 ComparisonDescriptor(ComparisonOperation::Greater),
1049 reasonIfUnsupported);
1050 }
1051
// Input layers perform no computation, so every tensor configuration is supported.
bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
                                       Optional<std::string&> /*reasonIfUnsupported*/) const
{
    return true;
}
1057
IsInstanceNormalizationSupported(const TensorInfo & input,const TensorInfo & output,const InstanceNormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1058 bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
1059 const TensorInfo& output,
1060 const InstanceNormalizationDescriptor& descriptor,
1061 Optional<std::string&> reasonIfUnsupported) const
1062 {
1063 IgnoreUnused(descriptor);
1064 // Define supported types
1065 std::array<DataType, 3> supportedTypes =
1066 {
1067 DataType::BFloat16,
1068 DataType::Float32,
1069 DataType::Float16
1070 };
1071
1072 bool supported = true;
1073
1074 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1075 "Reference Instance Normalization: input type not supported.");
1076
1077 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1078 "Reference Instance Normalization: output type not supported.");
1079
1080 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1081 "Reference Instance Normalization: input and output types mismatched.");
1082
1083 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1084 "Reference Instance Normalization: input and output shapes have different "
1085 "num total elements.");
1086
1087 return supported;
1088 }
1089
IsL2NormalizationSupported(const TensorInfo & input,const TensorInfo & output,const L2NormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1090 bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1091 const TensorInfo& output,
1092 const L2NormalizationDescriptor& descriptor,
1093 Optional<std::string&> reasonIfUnsupported) const
1094 {
1095 IgnoreUnused(descriptor);
1096 // Define supported types
1097 std::array<DataType, 6> supportedTypes =
1098 {
1099 DataType::BFloat16,
1100 DataType::Float32,
1101 DataType::Float16,
1102 DataType::QAsymmS8,
1103 DataType::QAsymmU8,
1104 DataType::QSymmS16
1105 };
1106
1107 bool supported = true;
1108
1109 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1110 "Reference L2normalization: input type not supported.");
1111
1112 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1113 "Reference L2normalization: output type not supported.");
1114
1115 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1116 "Reference L2normalization: input and output types mismatched.");
1117
1118 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1119 "Reference L2normalization: input and output shapes have different "
1120 "num total elements.");
1121
1122 return supported;
1123 }
1124
IsLogicalBinarySupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,const LogicalBinaryDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1125 bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
1126 const TensorInfo& input1,
1127 const TensorInfo& output,
1128 const LogicalBinaryDescriptor& descriptor,
1129 Optional<std::string&> reasonIfUnsupported) const
1130 {
1131 IgnoreUnused(descriptor);
1132
1133 std::array<DataType, 1> supportedTypes =
1134 {
1135 DataType::Boolean
1136 };
1137
1138 bool supported = true;
1139 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1140 "Reference LogicalBinary: input 0 type not supported");
1141 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1142 "Reference LogicalBinary: input 1 type not supported");
1143
1144 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1145 "Reference LogicalBinary: input and output types do not match");
1146
1147 return supported;
1148 }
1149
IsLogSoftmaxSupported(const TensorInfo & input,const TensorInfo & output,const LogSoftmaxDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1150 bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1151 const TensorInfo& output,
1152 const LogSoftmaxDescriptor& descriptor,
1153 Optional<std::string&> reasonIfUnsupported) const
1154 {
1155 IgnoreUnused(descriptor);
1156
1157 std::array<DataType, 3> supportedTypes =
1158 {
1159 DataType::BFloat16,
1160 DataType::Float32,
1161 DataType::Float16
1162 };
1163
1164 bool supported = true;
1165 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1166 "Reference LogSoftmax: input type not supported");
1167
1168 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1169 "Reference LogSoftmax: output type not supported");
1170
1171 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1172 "Reference LogSoftmax: input and output types do not match");
1173
1174 return supported;
1175 }
1176
IsLstmSupported(const TensorInfo & input,const TensorInfo & outputStateIn,const TensorInfo & cellStateIn,const TensorInfo & scratchBuffer,const TensorInfo & outputStateOut,const TensorInfo & cellStateOut,const TensorInfo & output,const LstmDescriptor & descriptor,const LstmInputParamsInfo & paramsInfo,Optional<std::string &> reasonIfUnsupported) const1177 bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
1178 const TensorInfo& outputStateIn,
1179 const TensorInfo& cellStateIn,
1180 const TensorInfo& scratchBuffer,
1181 const TensorInfo& outputStateOut,
1182 const TensorInfo& cellStateOut,
1183 const TensorInfo& output,
1184 const LstmDescriptor& descriptor,
1185 const LstmInputParamsInfo& paramsInfo,
1186 Optional<std::string&> reasonIfUnsupported) const
1187 {
1188 IgnoreUnused(descriptor);
1189 IgnoreUnused(paramsInfo);
1190
1191 bool supported = true;
1192
1193 std::array<DataType,3> supportedTypes = {
1194 DataType::BFloat16,
1195 DataType::Float32,
1196 DataType::QSymmS16
1197 };
1198
1199 // check inputs and outputs
1200 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1201 "Reference Lstm: input is not a supported type.");
1202 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
1203 "Reference Lstm: input and outputStateIn types are mismatched");
1204 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
1205 "Reference Lstm: input and cellStateIn types are mismatched");
1206 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
1207 "Reference Lstm: input and scratchBuffer types are mismatched");
1208 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
1209 "Reference Lstm: input and outputStateOut types are mismatched");
1210 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
1211 "Reference Lstm: input and cellStateOut types are mismatched");
1212 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1213 "Reference Lstm: input and output types are mismatched");
1214 // check layer parameters
1215 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
1216 "Reference Lstm: input and InputToForgetWeights types are mismatched");
1217 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
1218 "Reference Lstm: input and InputToCellWeights types are mismatched");
1219 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
1220 "Reference Lstm: input and InputToOutputWeights types are mismatched");
1221 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
1222 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
1223 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
1224 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
1225 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
1226 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
1227 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
1228 "Reference Lstm: input and ForgetGateBias types are mismatched");
1229 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
1230 "Reference Lstm: input and CellBias types are mismatched");
1231 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
1232 "Reference Lstm: input and OutputGateBias types are mismatched");
1233 if (!descriptor.m_CifgEnabled)
1234 {
1235 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
1236 "Reference Lstm: input and InputToInputWeights types are mismatched");
1237 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
1238 reasonIfUnsupported,
1239 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
1240 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
1241 "Reference Lstm: input and InputGateBias types are mismatched");
1242 if (descriptor.m_PeepholeEnabled)
1243 {
1244 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
1245 reasonIfUnsupported,
1246 "Reference Lstm: input and CellToInputWeights types are mismatched");
1247 }
1248 }
1249 if (descriptor.m_PeepholeEnabled)
1250 {
1251 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
1252 "Reference Lstm: input and CellToForgetWeights types are mismatched");
1253 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
1254 "Reference Lstm: input and CellToOutputWeights types are mismatched");
1255 }
1256 if (descriptor.m_ProjectionEnabled)
1257 {
1258 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
1259 "Reference Lstm: input and mProjectionWeights types are mismatched");
1260 if (paramsInfo.m_ProjectionBias != nullptr)
1261 {
1262 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
1263 "Reference Lstm: input and ProjectionBias types are mismatched");
1264 }
1265 }
1266 if (descriptor.m_LayerNormEnabled)
1267 {
1268 if (!descriptor.m_CifgEnabled)
1269 {
1270 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
1271 reasonIfUnsupported,
1272 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1273 }
1274 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
1275 reasonIfUnsupported,
1276 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
1277 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
1278 reasonIfUnsupported,
1279 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
1280 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
1281 reasonIfUnsupported,
1282 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1283 }
1284
1285 return supported;
1286 }
1287
IsMaximumSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1288 bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1289 const TensorInfo& input1,
1290 const TensorInfo& output,
1291 Optional<std::string&> reasonIfUnsupported) const
1292 {
1293 bool supported = true;
1294
1295 std::array<DataType,7> supportedTypes = {
1296 DataType::BFloat16,
1297 DataType::Float32,
1298 DataType::Float16,
1299 DataType::QAsymmS8,
1300 DataType::QAsymmU8,
1301 DataType::QSymmS16,
1302 DataType::Signed32
1303 };
1304
1305 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1306 "Reference maximum: input 0 is not a supported type.");
1307
1308 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1309 "Reference maximum: input 1 is not a supported type.");
1310
1311 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1312 "Reference maximum: output is not a supported type.");
1313
1314 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1315 "Reference maximum: input 0 and Input 1 types are mismatched");
1316
1317 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1318 "Reference maximum: input and output types are mismatched");
1319
1320 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1321 "Reference maximum: shapes are not suitable for implicit broadcast.");
1322
1323 return supported;
1324 }
1325
IsMeanSupported(const TensorInfo & input,const TensorInfo & output,const MeanDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1326 bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1327 const TensorInfo& output,
1328 const MeanDescriptor& descriptor,
1329 Optional<std::string&> reasonIfUnsupported) const
1330 {
1331 bool supported = true;
1332 std::string meanLayerStr = "Mean";
1333 std::string outputTensorStr = "output";
1334
1335 std::array<DataType,6> supportedTypes =
1336 {
1337 DataType::BFloat16,
1338 DataType::Float32,
1339 DataType::Float16,
1340 DataType::QAsymmS8,
1341 DataType::QAsymmU8,
1342 DataType::QSymmS16
1343 };
1344
1345 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1346 "Reference Mean: input type not supported.");
1347
1348 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1349 "Reference Mean: input and output types are mismatched");
1350
1351 if (descriptor.m_KeepDims)
1352 {
1353 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1354 reasonIfUnsupported,
1355 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1356 output.GetNumDimensions(),
1357 meanLayerStr, outputTensorStr).data());
1358 }
1359 else if (descriptor.m_Axis.empty())
1360 {
1361 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1362 reasonIfUnsupported,
1363 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1364 meanLayerStr, outputTensorStr).data());
1365 }
1366 else
1367 {
1368 auto outputDim = input.GetNumDimensions() - armnn::numeric_cast<unsigned int>(descriptor.m_Axis.size());
1369
1370 if (outputDim > 0)
1371 {
1372 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1373 reasonIfUnsupported,
1374 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1375 meanLayerStr, outputTensorStr).data());
1376 }
1377 else
1378 {
1379 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1380 reasonIfUnsupported,
1381 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1382 meanLayerStr, outputTensorStr).data());
1383 }
1384 }
1385
1386 return supported;
1387 }
1388
IsMergerSupported(const std::vector<const TensorInfo * > inputs,const TensorInfo & output,const MergerDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1389 bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
1390 const TensorInfo& output,
1391 const MergerDescriptor& descriptor,
1392 Optional<std::string&> reasonIfUnsupported) const
1393 {
1394 return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
1395 }
1396
IsMemCopySupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1397 bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1398 const TensorInfo &output,
1399 Optional<std::string &> reasonIfUnsupported) const
1400 {
1401 bool supported = true;
1402
1403 std::array<DataType,7> supportedTypes =
1404 {
1405 DataType::BFloat16,
1406 DataType::Float32,
1407 DataType::Float16,
1408 DataType::QAsymmS8,
1409 DataType::QAsymmU8,
1410 DataType::QSymmS16,
1411 DataType::Boolean
1412 };
1413
1414 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1415 "Reference MemCopy: input type not supported");
1416
1417 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1418 "Reference MemCopy: output type not supported");
1419
1420 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1421 "Reference MemCopy: input and output types are mismatched");
1422
1423 return supported;
1424 }
1425
IsMinimumSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1426 bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1427 const TensorInfo& input1,
1428 const TensorInfo& output,
1429 Optional<std::string&> reasonIfUnsupported) const
1430 {
1431 bool supported = true;
1432
1433 std::array<DataType,7> supportedTypes = {
1434 DataType::BFloat16,
1435 DataType::Float32,
1436 DataType::Float16,
1437 DataType::QAsymmS8,
1438 DataType::QAsymmU8,
1439 DataType::QSymmS16,
1440 DataType::Signed32
1441 };
1442
1443 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1444 "Reference minimum: input 0 is not a supported type.");
1445
1446 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1447 "Reference minimum: input 1 is not a supported type.");
1448
1449 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1450 "Reference minimum: output is not a supported type.");
1451
1452 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1453 "Reference minimum: input 0 and Input 1 types are mismatched");
1454
1455 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1456 "Reference minimum: input and output types are mismatched");
1457
1458 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1459 "Reference minimum: shapes are not suitable for implicit broadcast.");
1460
1461 return supported;
1462 }
1463
IsMultiplicationSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1464 bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1465 const TensorInfo& input1,
1466 const TensorInfo& output,
1467 Optional<std::string&> reasonIfUnsupported) const
1468 {
1469 bool supported = true;
1470
1471 std::array<DataType,7> supportedTypes = {
1472 DataType::BFloat16,
1473 DataType::Float32,
1474 DataType::Float16,
1475 DataType::QAsymmS8,
1476 DataType::QAsymmU8,
1477 DataType::QSymmS16,
1478 DataType::Signed32
1479 };
1480
1481 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1482 "Reference multiplication: input 0 is not a supported type.");
1483
1484 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1485 "Reference multiplication: input 1 is not a supported type.");
1486
1487 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1488 "Reference multiplication: output is not a supported type.");
1489
1490 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1491 "Reference multiplication: input 0 and Input 1 types are mismatched");
1492
1493 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1494 "Reference multiplication: input and output types are mismatched");
1495
1496 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1497 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1498
1499 return supported;
1500 }
1501
IsNormalizationSupported(const TensorInfo & input,const TensorInfo & output,const NormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1502 bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1503 const TensorInfo& output,
1504 const NormalizationDescriptor& descriptor,
1505 Optional<std::string&> reasonIfUnsupported) const
1506 {
1507 IgnoreUnused(descriptor);
1508
1509 // Define supported types
1510 std::array<DataType, 6> supportedTypes =
1511 {
1512 DataType::BFloat16,
1513 DataType::Float16,
1514 DataType::Float32,
1515 DataType::QAsymmS8,
1516 DataType::QAsymmU8,
1517 DataType::QSymmS16
1518 };
1519
1520 bool supported = true;
1521
1522 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1523 "Reference normalization: input type not supported.");
1524
1525 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1526 "Reference normalization: output type not supported.");
1527
1528 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1529 "Reference normalization: input and output shapes have different "
1530 "num total elements.");
1531
1532 return supported;
1533 }
1534
// Output layers perform no computation, so every tensor configuration is supported.
bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
                                        Optional<std::string&> /*reasonIfUnsupported*/) const
{
    return true;
}
1540
IsPadSupported(const TensorInfo & input,const TensorInfo & output,const PadDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1541 bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1542 const TensorInfo& output,
1543 const PadDescriptor& descriptor,
1544 Optional<std::string&> reasonIfUnsupported) const
1545 {
1546 IgnoreUnused(descriptor);
1547 bool supported = true;
1548
1549 // Define supported output and inputs types.
1550 std::array<DataType,6> supportedTypes =
1551 {
1552 DataType::BFloat16,
1553 DataType::Float32,
1554 DataType::Float16,
1555 DataType::QAsymmS8,
1556 DataType::QAsymmU8,
1557 DataType::QSymmS16
1558 };
1559
1560 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1561 "Reference pad: input is not a supported type.");
1562
1563 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1564 "Reference pad: output is not a supported type.");
1565
1566 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1567 "Reference pad: input and output types are mismatched.");
1568
1569 return supported;
1570 }
1571
IsPermuteSupported(const TensorInfo & input,const TensorInfo & output,const PermuteDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1572 bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1573 const TensorInfo& output,
1574 const PermuteDescriptor& descriptor,
1575 Optional<std::string&> reasonIfUnsupported) const
1576 {
1577 IgnoreUnused(descriptor);
1578 bool supported = true;
1579
1580 // Define supported output and inputs types.
1581 std::array<DataType, 6> supportedTypes =
1582 {
1583 DataType::BFloat16,
1584 DataType::Float32,
1585 DataType::Float16,
1586 DataType::QAsymmS8,
1587 DataType::QAsymmU8,
1588 DataType::QSymmS16
1589 };
1590
1591 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1592 "Reference permute: input is not a supported type.");
1593
1594 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1595 "Reference permute: output is not a supported type.");
1596
1597 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1598 "Reference permute: input and output types are mismatched.");
1599
1600 return supported;
1601 }
1602
IsPooling2dSupported(const TensorInfo & input,const TensorInfo & output,const Pooling2dDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1603 bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1604 const TensorInfo& output,
1605 const Pooling2dDescriptor& descriptor,
1606 Optional<std::string&> reasonIfUnsupported) const
1607 {
1608 IgnoreUnused(descriptor);
1609 bool supported = true;
1610
1611 // Define supported output and inputs types.
1612 std::array<DataType,6> supportedTypes =
1613 {
1614 DataType::BFloat16,
1615 DataType::Float32,
1616 DataType::Float16,
1617 DataType::QAsymmS8,
1618 DataType::QAsymmU8,
1619 DataType::QSymmS16
1620 };
1621
1622 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1623 "Reference poolind2d: input is not a supported type.");
1624
1625 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1626 "Reference poolind2d: output is not a supported type.");
1627
1628 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1629 "Reference poolind2d: input and output types are mismatched.");
1630
1631 return supported;
1632 }
1633
IsQLstmSupported(const TensorInfo & input,const TensorInfo & previousOutputIn,const TensorInfo & previousCellStateIn,const TensorInfo & outputStateOut,const TensorInfo & cellStateOut,const TensorInfo & output,const QLstmDescriptor & descriptor,const LstmInputParamsInfo & paramsInfo,Optional<std::string &> reasonIfUnsupported) const1634 bool RefLayerSupport::IsQLstmSupported(const TensorInfo& input,
1635 const TensorInfo& previousOutputIn,
1636 const TensorInfo& previousCellStateIn,
1637 const TensorInfo& outputStateOut,
1638 const TensorInfo& cellStateOut,
1639 const TensorInfo& output,
1640 const QLstmDescriptor& descriptor,
1641 const LstmInputParamsInfo& paramsInfo,
1642 Optional<std::string&> reasonIfUnsupported) const
1643 {
1644 IgnoreUnused(input);
1645 IgnoreUnused(previousOutputIn);
1646 IgnoreUnused(previousCellStateIn);
1647 IgnoreUnused(outputStateOut);
1648 IgnoreUnused(cellStateOut);
1649 IgnoreUnused(output);
1650 IgnoreUnused(descriptor);
1651 IgnoreUnused(paramsInfo);
1652
1653 IgnoreUnused(reasonIfUnsupported);
1654
1655 return true;
1656 }
1657
IsQuantizeSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1658 bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1659 const TensorInfo& output,
1660 Optional<std::string&> reasonIfUnsupported) const
1661 {
1662 bool supported = true;
1663
1664 // Define supported input types.
1665 std::array<DataType,7> supportedInputTypes = {
1666 DataType::BFloat16,
1667 DataType::Float32,
1668 DataType::Float16,
1669 DataType::QAsymmS8,
1670 DataType::QAsymmU8,
1671 DataType::QSymmS8,
1672 DataType::QSymmS16
1673 };
1674
1675 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1676 "Reference quantize: input type not supported.");
1677
1678 // Define supported output types.
1679 std::array<DataType,4> supportedOutputTypes = {
1680 DataType::QAsymmS8,
1681 DataType::QAsymmU8,
1682 DataType::QSymmS8,
1683 DataType::QSymmS16
1684 };
1685 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1686 "Reference quantize: output type not supported.");
1687
1688 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1689 "Reference quantize: input and output shapes have different num total elements.");
1690
1691 return supported;
1692 }
1693
IsRankSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1694 bool RefLayerSupport::IsRankSupported(const TensorInfo& input,
1695 const TensorInfo& output,
1696 Optional<std::string&> reasonIfUnsupported) const
1697 {
1698 IgnoreUnused(input);
1699 // Define supported output types.
1700 std::array<DataType,1> supportedOutputTypes =
1701 {
1702 DataType::Signed32,
1703 };
1704
1705 return CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1706 "Reference rank: input type not supported.");
1707 }
1708
IsReshapeSupported(const TensorInfo & input,const TensorInfo & output,const ReshapeDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1709 bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
1710 const TensorInfo& output,
1711 const ReshapeDescriptor& descriptor,
1712 Optional<std::string&> reasonIfUnsupported) const
1713 {
1714 IgnoreUnused(output);
1715 IgnoreUnused(descriptor);
1716 // Define supported output types.
1717 std::array<DataType,8> supportedOutputTypes =
1718 {
1719 DataType::BFloat16,
1720 DataType::Float32,
1721 DataType::Float16,
1722 DataType::Signed32,
1723 DataType::QAsymmS8,
1724 DataType::QAsymmU8,
1725 DataType::QSymmS16,
1726 DataType::Boolean
1727 };
1728
1729 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1730 "Reference reshape: input type not supported.");
1731 }
1732
IsResizeBilinearSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1733 bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
1734 const TensorInfo& output,
1735 Optional<std::string&> reasonIfUnsupported) const
1736 {
1737 bool supported = true;
1738 std::array<DataType,6> supportedTypes =
1739 {
1740 DataType::BFloat16,
1741 DataType::Float32,
1742 DataType::Float16,
1743 DataType::QAsymmS8,
1744 DataType::QAsymmU8,
1745 DataType::QSymmS16
1746 };
1747
1748 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1749 "Reference ResizeBilinear: input type not supported");
1750
1751 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1752 "Reference ResizeBilinear: output type not supported");
1753
1754 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1755 "Reference ResizeBilinear: input and output types not matching");
1756
1757 return supported;
1758 }
1759
IsResizeSupported(const TensorInfo & input,const TensorInfo & output,const ResizeDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1760 bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1761 const TensorInfo& output,
1762 const ResizeDescriptor& descriptor,
1763 Optional<std::string&> reasonIfUnsupported) const
1764 {
1765 IgnoreUnused(descriptor);
1766 bool supported = true;
1767 std::array<DataType,6> supportedTypes =
1768 {
1769 DataType::BFloat16,
1770 DataType::Float32,
1771 DataType::Float16,
1772 DataType::QAsymmS8,
1773 DataType::QAsymmU8,
1774 DataType::QSymmS16
1775 };
1776
1777 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1778 "Reference Resize: input type not supported");
1779
1780 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1781 "Reference Resize: output type not supported");
1782
1783 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1784 "Reference Resize: input and output types not matching");
1785
1786 return supported;
1787 }
1788
IsRsqrtSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1789 bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1790 const TensorInfo& output,
1791 Optional<std::string&> reasonIfUnsupported) const
1792 {
1793 return IsElementwiseUnarySupported(input,
1794 output,
1795 ElementwiseUnaryDescriptor(UnaryOperation::Rsqrt),
1796 reasonIfUnsupported);
1797 }
1798
IsSliceSupported(const TensorInfo & input,const TensorInfo & output,const SliceDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1799 bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
1800 const TensorInfo& output,
1801 const SliceDescriptor& descriptor,
1802 Optional<std::string&> reasonIfUnsupported) const
1803 {
1804 IgnoreUnused(descriptor);
1805 bool supported = true;
1806
1807 std::array<DataType, 5> supportedTypes =
1808 {
1809 DataType::BFloat16,
1810 DataType::Float32,
1811 DataType::QAsymmS8,
1812 DataType::QAsymmU8,
1813 DataType::QSymmS16
1814 };
1815
1816 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1817 "Reference Slice: input type not supported");
1818
1819 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1820 "Reference Slice: output type not supported");
1821
1822 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1823 "Reference Slice: input and output types are mismatched");
1824
1825 return supported;
1826 }
1827
IsSoftmaxSupported(const TensorInfo & input,const TensorInfo & output,const SoftmaxDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1828 bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1829 const TensorInfo& output,
1830 const SoftmaxDescriptor& descriptor,
1831 Optional<std::string&> reasonIfUnsupported) const
1832 {
1833 IgnoreUnused(descriptor);
1834 bool supported = true;
1835 std::array<DataType,7> supportedTypes =
1836 {
1837 DataType::BFloat16,
1838 DataType::Float32,
1839 DataType::Float16,
1840 DataType::QSymmS8,
1841 DataType::QAsymmS8,
1842 DataType::QAsymmU8,
1843 DataType::QSymmS16
1844 };
1845
1846 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1847 "Reference Softmax: output type not supported");
1848
1849 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1850 "Reference Softmax: input type not supported");
1851
1852 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1853 "Reference Softmax: input type not supported");
1854
1855 return supported;
1856 }
1857
IsSpaceToBatchNdSupported(const TensorInfo & input,const TensorInfo & output,const SpaceToBatchNdDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1858 bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1859 const TensorInfo& output,
1860 const SpaceToBatchNdDescriptor& descriptor,
1861 Optional<std::string&> reasonIfUnsupported) const
1862 {
1863 IgnoreUnused(descriptor);
1864 bool supported = true;
1865 std::array<DataType,6> supportedTypes =
1866 {
1867 DataType::BFloat16,
1868 DataType::Float32,
1869 DataType::Float16,
1870 DataType::QAsymmS8,
1871 DataType::QAsymmU8,
1872 DataType::QSymmS16
1873 };
1874
1875 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1876 "Reference SpaceToBatchNd: input type not supported");
1877
1878 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1879 "Reference SpaceToBatchNd: output type not supported");
1880
1881 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1882 "Reference SpaceToBatchNd: input and output types are mismatched");
1883
1884 return supported;
1885 }
1886
IsSpaceToDepthSupported(const TensorInfo & input,const TensorInfo & output,const SpaceToDepthDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1887 bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
1888 const TensorInfo& output,
1889 const SpaceToDepthDescriptor& descriptor,
1890 Optional<std::string&> reasonIfUnsupported) const
1891 {
1892
1893 IgnoreUnused(descriptor);
1894 bool supported = true;
1895
1896 std::array<DataType,6> supportedTypes =
1897 {
1898 DataType::BFloat16,
1899 DataType::Float32,
1900 DataType::Float16,
1901 DataType::QAsymmS8,
1902 DataType::QAsymmU8,
1903 DataType::QSymmS16
1904 };
1905
1906 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1907 "Reference SpaceToDepth: input type not supported");
1908
1909 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1910 "Reference SpaceToDepth: output type not supported");
1911
1912 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1913 "Reference SpaceToDepth: input and output types are mismatched");
1914
1915 return supported;
1916 }
1917
IsSplitterSupported(const TensorInfo & input,const ViewsDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1918 bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1919 const ViewsDescriptor& descriptor,
1920 Optional<std::string&> reasonIfUnsupported) const
1921 {
1922 IgnoreUnused(descriptor);
1923 bool supported = true;
1924 std::array<DataType,6> supportedTypes =
1925 {
1926 DataType::BFloat16,
1927 DataType::Float32,
1928 DataType::Float16,
1929 DataType::QAsymmS8,
1930 DataType::QAsymmU8,
1931 DataType::QSymmS16
1932 };
1933
1934 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1935 "Reference splitter: input type not supported");
1936
1937 return supported;
1938 }
1939
IsSplitterSupported(const TensorInfo & input,const std::vector<std::reference_wrapper<TensorInfo>> & outputs,const ViewsDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1940 bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1941 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1942 const ViewsDescriptor& descriptor,
1943 Optional<std::string&> reasonIfUnsupported) const
1944 {
1945 IgnoreUnused(descriptor);
1946 bool supported = true;
1947 std::array<DataType,6> supportedTypes =
1948 {
1949 DataType::BFloat16,
1950 DataType::Float32,
1951 DataType::Float16,
1952 DataType::QAsymmS8,
1953 DataType::QAsymmU8,
1954 DataType::QSymmS16
1955 };
1956
1957 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1958 "Reference splitter: output type not supported");
1959 for (const TensorInfo& output : outputs)
1960 {
1961 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1962 "Reference splitter: input type not supported");
1963
1964 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1965 "Reference splitter: input and output types mismatched.");
1966 }
1967
1968 return supported;
1969 }
1970
IsStackSupported(const std::vector<const TensorInfo * > & inputs,const TensorInfo & output,const StackDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1971 bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
1972 const TensorInfo& output,
1973 const StackDescriptor& descriptor,
1974 Optional<std::string&> reasonIfUnsupported) const
1975 {
1976 IgnoreUnused(descriptor);
1977
1978 bool supported = true;
1979 std::array<DataType,6> supportedTypes =
1980 {
1981 DataType::BFloat16,
1982 DataType::Float32,
1983 DataType::Float16,
1984 DataType::QAsymmS8,
1985 DataType::QAsymmU8,
1986 DataType::QSymmS16
1987 };
1988
1989 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1990 "Reference stack: output type not supported");
1991 for (const TensorInfo* input : inputs)
1992 {
1993 ARMNN_ASSERT(input != nullptr);
1994 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
1995 "Reference stack: input type not supported");
1996
1997 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
1998 "Reference stack: input and output types mismatched.");
1999 }
2000
2001 return supported;
2002 }
2003
IsStridedSliceSupported(const TensorInfo & input,const TensorInfo & output,const StridedSliceDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2004 bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
2005 const TensorInfo& output,
2006 const StridedSliceDescriptor& descriptor,
2007 Optional<std::string&> reasonIfUnsupported) const
2008 {
2009 IgnoreUnused(descriptor);
2010 bool supported = true;
2011
2012 std::array<DataType,5> supportedTypes =
2013 {
2014 DataType::BFloat16,
2015 DataType::Float32,
2016 DataType::QAsymmS8,
2017 DataType::QAsymmU8,
2018 DataType::QSymmS16
2019 };
2020
2021 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2022 "Reference StridedSlice: input type not supported");
2023
2024 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2025 "Reference StridedSlice: output type not supported");
2026
2027 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2028 "Reference StridedSlice: input and output types are mismatched");
2029
2030 return supported;
2031 }
2032
IsSubtractionSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const2033 bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
2034 const TensorInfo& input1,
2035 const TensorInfo& output,
2036 Optional<std::string&> reasonIfUnsupported) const
2037 {
2038 bool supported = true;
2039
2040 std::array<DataType,7> supportedTypes = {
2041 DataType::BFloat16,
2042 DataType::Float32,
2043 DataType::Float16,
2044 DataType::QAsymmS8,
2045 DataType::QAsymmU8,
2046 DataType::QSymmS16,
2047 DataType::Signed32
2048 };
2049
2050 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2051 "Reference subtraction: input 0 is not a supported type.");
2052
2053 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2054 "Reference subtraction: input 1 is not a supported type.");
2055
2056 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2057 "Reference subtraction: output is not a supported type.");
2058
2059 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2060 "Reference subtraction: input 0 and Input 1 types are mismatched");
2061
2062 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2063 "Reference subtraction: input and output types are mismatched");
2064
2065 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2066 "Reference subtraction: shapes are not suitable for implicit broadcast.");
2067
2068 return supported;
2069 }
2070
IsPreluSupported(const TensorInfo & input,const TensorInfo & alpha,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const2071 bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
2072 const TensorInfo& alpha,
2073 const TensorInfo& output,
2074 Optional<std::string&> reasonIfUnsupported) const
2075 {
2076 bool supported = true;
2077
2078 std::array<DataType, 6> supportedTypes
2079 {
2080 DataType::BFloat16,
2081 DataType::Float32,
2082 DataType::Float16,
2083 DataType::QAsymmS8,
2084 DataType::QAsymmU8,
2085 DataType::QSymmS16
2086 };
2087
2088 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2089 "PReLU: input is not a supported type.");
2090
2091 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
2092 "PReLU: alpha is not a supported type.");
2093
2094 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2095 "PReLU: output is not a supported type.");
2096
2097 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
2098 "PReLU: input, alpha and output types are mismatched");
2099
2100 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
2101 "PReLU: shapes are not suitable for implicit broadcast");
2102
2103 return supported;
2104 }
2105
bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
                                                        const TensorInfo& output,
                                                        const TransposeConvolution2dDescriptor& descriptor,
                                                        const TensorInfo& weights,
                                                        const Optional<TensorInfo>& biases,
                                                        Optional<std::string&> reasonIfUnsupported) const
{
    // The descriptor (stride/padding) places no type constraints, so it is not
    // inspected here.
    IgnoreUnused(descriptor);
    bool supported = true;

    // Types accepted for the input and output tensors.
    std::array<DataType,7> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference TransposeConvolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference TransposeConvolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference TransposeConvolution2d: input and output types mismatched.");


    // Quantized (8-bit) inputs allow a different set of weight types than
    // float inputs, including per-axis quantized weights.
    const DataType inputType = input.GetDataType();
    if (IsQuantized8BitType(inputType))
    {
        // QuantizedSymm8PerAxis is deprecated but still accepted; suppress the
        // deprecation warning for this table only.
        ARMNN_NO_DEPRECATE_WARN_BEGIN
        std::array<DataType, 4> supportedWeightTypes =
        {
            DataType::QAsymmS8,
            DataType::QAsymmU8,
            DataType::QSymmS8,
            DataType::QuantizedSymm8PerAxis //Deprecated
        };
        ARMNN_NO_DEPRECATE_WARN_END

        supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
                                      "Reference TransposeConvolution2d: weights type not supported for "
                                      "quantized input.");
    }
    else
    {
        // Non-quantized input: weights must be a supported type and must match
        // the input type exactly.
        supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                      "Reference TransposeConvolution2d: weights is not a supported type.");

        supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                      "Reference TransposeConvolution2d: input and weights types mismatched.");
    }

    // Bias is optional; when present it must be a float type or Signed32
    // (the accumulator type used with quantized inputs).
    if (biases.has_value())
    {
        std::array<DataType,4> biasesSupportedTypes =
        {
            DataType::BFloat16,
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };
        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference TransposeConvolution2d: biases is not a supported type.");
    }

    return supported;
}
2178
IsTransposeSupported(const TensorInfo & input,const TensorInfo & output,const TransposeDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2179 bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
2180 const TensorInfo& output,
2181 const TransposeDescriptor& descriptor,
2182 Optional<std::string&> reasonIfUnsupported) const
2183 {
2184 IgnoreUnused(descriptor);
2185 bool supported = true;
2186
2187 // Define supported output and inputs types.
2188 std::array<DataType, 6> supportedTypes =
2189 {
2190 DataType::BFloat16,
2191 DataType::Float32,
2192 DataType::Float16,
2193 DataType::QAsymmS8,
2194 DataType::QAsymmU8,
2195 DataType::QSymmS16
2196 };
2197
2198 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2199 "Reference transpose: input is not a supported type.");
2200
2201 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2202 "Reference transpose: output is not a supported type.");
2203
2204 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2205 "Reference transpose: input and output types are mismatched.");
2206
2207 return supported;
2208 }
2209
2210 } // namespace armnn
2211