• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "RefLayerSupport.hpp"
7 
8 #include <armnn/TypesUtils.hpp>
9 #include <armnn/Types.hpp>
10 #include <armnn/Descriptors.hpp>
11 #include <armnn/utility/IgnoreUnused.hpp>
12 #include <armnn/utility/NumericCast.hpp>
13 
14 #include <LayerSupportCommon.hpp>
15 #include <backendsCommon/LayerSupportRules.hpp>
16 
17 #include <vector>
18 #include <array>
19 
20 namespace armnn
21 {
22 
23 namespace
24 {
25 
26 template<typename Float32Func, typename Uint8Func, typename ... Params>
IsSupportedForDataTypeRef(Optional<std::string &> reasonIfUnsupported,DataType dataType,Float32Func floatFuncPtr,Uint8Func uint8FuncPtr,Params &&...params)27 bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
28                                DataType dataType,
29                                Float32Func floatFuncPtr,
30                                Uint8Func uint8FuncPtr,
31                                Params&&... params)
32 {
33     return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
34                                          dataType,
35                                          &FalseFunc<Params...>,
36                                          floatFuncPtr,
37                                          uint8FuncPtr,
38                                          &FalseFunc<Params...>,
39                                          &FalseFunc<Params...>,
40                                          std::forward<Params>(params)...);
41 }
42 
43 } // anonymous namespace
44 
namespace
{

// Builds the diagnostic used when a tensor has the wrong rank, e.g.
//   "Reference batchToSpaceNd: Expected 4 dimensions but got 3 dimensions instead, for the 'output' tensor."
// expected   - number of dimensions the layer requires.
// actual     - number of dimensions the tensor really has.
// layerStr   - layer name inserted after "Reference ".
// tensorName - tensor name, quoted in the message.
// The string arguments are only read, so they are taken by const reference; this also lets
// callers pass const strings or temporaries.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) +
                           " dimensions but got " + std::to_string(actual) +
                           " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace
60 
IsAbsSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const61 bool RefLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output,
62                                      Optional<std::string&> reasonIfUnsupported) const
63 {
64     return IsElementwiseUnarySupported(input,
65                                        output,
66                                        ElementwiseUnaryDescriptor(UnaryOperation::Abs),
67                                        reasonIfUnsupported);
68 }
69 
IsActivationSupported(const TensorInfo & input,const TensorInfo & output,const ActivationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const70 bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
71                                             const TensorInfo& output,
72                                             const ActivationDescriptor& descriptor,
73                                             Optional<std::string&> reasonIfUnsupported) const
74 {
75    bool supported = true;
76 
77     // Define supported types.
78     std::array<DataType,6> supportedTypes = {
79         DataType::BFloat16,
80         DataType::Float32,
81         DataType::Float16,
82         DataType::QAsymmS8,
83         DataType::QAsymmU8,
84         DataType::QSymmS16
85     };
86 
87     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
88                                   "Reference activation: input type not supported.");
89 
90     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
91                                   "Reference activation: output type not supported.");
92 
93     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
94                                   "Reference activation: input and output types mismatched.");
95 
96     supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
97                                   "Reference activation: input and output shapes are of different rank.");
98 
99 
100     struct ActivationFunctionSupported : public Rule
101     {
102         ActivationFunctionSupported(const ActivationDescriptor& desc)
103         {
104             switch(desc.m_Function)
105             {
106                 case ActivationFunction::Abs:
107                 case ActivationFunction::BoundedReLu:
108                 case ActivationFunction::Elu:
109                 case ActivationFunction::HardSwish:
110                 case ActivationFunction::LeakyReLu:
111                 case ActivationFunction::Linear:
112                 case ActivationFunction::ReLu:
113                 case ActivationFunction::Sigmoid:
114                 case ActivationFunction::SoftReLu:
115                 case ActivationFunction::Sqrt:
116                 case ActivationFunction::Square:
117                 case ActivationFunction::TanH:
118                 {
119                     m_Res = true;
120                     break;
121                 }
122                 default:
123                 {
124                     m_Res = false;
125                     break;
126                 }
127             }
128         }
129     };
130 
131     // Function is supported
132     supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
133                                   "Reference activation: function not supported.");
134 
135     return supported;
136 }
137 
IsAdditionSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const138 bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
139                                           const TensorInfo& input1,
140                                           const TensorInfo& output,
141                                           Optional<std::string&> reasonIfUnsupported) const
142 {
143     bool supported = true;
144 
145     std::array<DataType,7> supportedTypes = {
146         DataType::BFloat16,
147         DataType::Float32,
148         DataType::Float16,
149         DataType::QAsymmS8,
150         DataType::QAsymmU8,
151         DataType::QSymmS16,
152         DataType::Signed32
153     };
154 
155     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
156                                   "Reference addition: input 0 is not a supported type.");
157 
158     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
159                                   "Reference addition: input 1 is not a supported type.");
160 
161     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
162                                   "Reference addition: output is not a supported type.");
163 
164     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
165                                   "Reference addition: input 0 and Input 1 types are mismatched");
166 
167     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
168                                   "Reference addition: input and output types are mismatched");
169 
170     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
171                                   "Reference addition: shapes are not suitable for implicit broadcast.");
172 
173     return supported;
174 }
175 
IsArgMinMaxSupported(const armnn::TensorInfo & input,const armnn::TensorInfo & output,const armnn::ArgMinMaxDescriptor & descriptor,armnn::Optional<std::string &> reasonIfUnsupported) const176 bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
177                                            const armnn::ArgMinMaxDescriptor &descriptor,
178                                            armnn::Optional<std::string &> reasonIfUnsupported) const
179 {
180     IgnoreUnused(descriptor);
181 
182     std::array<DataType, 7> supportedTypes =
183     {
184         DataType::BFloat16,
185         DataType::Float16,
186         DataType::Float32,
187         DataType::QAsymmS8,
188         DataType::QAsymmU8,
189         DataType::QSymmS16,
190         DataType::Signed32
191     };
192 
193     bool supported = true;
194 
195     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
196                                   "Reference ArgMinMax: input is not a supported type.");
197     supported &= CheckSupportRule(TypeIs(output, DataType::Signed32), reasonIfUnsupported,
198                                   "Reference ArgMinMax: output type not supported");
199 
200     return supported;
201 }
202 
IsBatchNormalizationSupported(const TensorInfo & input,const TensorInfo & output,const TensorInfo & mean,const TensorInfo & variance,const TensorInfo & beta,const TensorInfo & gamma,const BatchNormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const203 bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
204                                                     const TensorInfo& output,
205                                                     const TensorInfo& mean,
206                                                     const TensorInfo& variance,
207                                                     const TensorInfo& beta,
208                                                     const TensorInfo& gamma,
209                                                     const BatchNormalizationDescriptor& descriptor,
210                                                     Optional<std::string&> reasonIfUnsupported) const
211 {
212     IgnoreUnused(descriptor);
213 
214     std::array<DataType, 6> supportedTypes =
215     {
216         DataType::BFloat16,
217         DataType::Float32,
218         DataType::Float16,
219         DataType::QAsymmS8,
220         DataType::QAsymmU8,
221         DataType::QSymmS16
222     };
223 
224     bool supported = true;
225 
226     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
227                                   "Reference batch normalization: input is not a supported type.");
228 
229     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
230                                   "Reference batch normalization: output is not a supported type.");
231 
232     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
233                                   "Reference batch normalization: input and output types are mismatched");
234 
235     supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
236                                   "Reference batch normalization: mean is not a supported type.");
237 
238     supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
239                                   "Reference batch normalization: variance is not a supported type.");
240 
241     supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
242                                   "Reference batch normalization: beta is not a supported type.");
243 
244     supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
245                                   "Reference batch normalization: gamma is not a supported type.");
246 
247     return supported;
248 }
249 
IsBatchToSpaceNdSupported(const TensorInfo & input,const TensorInfo & output,const BatchToSpaceNdDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const250 bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
251                                                 const TensorInfo& output,
252                                                 const BatchToSpaceNdDescriptor& descriptor,
253                                                 Optional<std::string&> reasonIfUnsupported) const
254 {
255     IgnoreUnused(descriptor);
256 
257     bool supported = true;
258 
259     std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
260     std::string inputTensorStr = "input";
261     std::string outputTensorStr = "output";
262 
263     // Define supported types.
264     std::array<DataType,6> supportedTypes =
265     {
266         DataType::BFloat16,
267         DataType::Float32,
268         DataType::Float16,
269         DataType::QAsymmS8,
270         DataType::QAsymmU8,
271         DataType::QSymmS16
272     };
273 
274     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
275                                   "Reference BatchToSpaceNd: input type not supported.");
276 
277     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
278                                   "Reference BatchToSpaceNd: output type not supported.");
279 
280     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
281                                   "Reference BatchToSpaceNd: input and output types mismatched.");
282 
283     supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
284                                   reasonIfUnsupported,
285                                   CreateIncorrectDimensionsErrorMsg(4,
286                                                                     output.GetNumDimensions(),
287                                                                     batchToSpaceNdLayerStr,
288                                                                     outputTensorStr).data());
289 
290     supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
291                                   reasonIfUnsupported,
292                                   CreateIncorrectDimensionsErrorMsg(4,
293                                                                     input.GetNumDimensions(),
294                                                                     batchToSpaceNdLayerStr,
295                                                                     inputTensorStr).data());
296 
297     return supported;
298 }
299 
IsComparisonSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,const ComparisonDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const300 bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
301                                             const TensorInfo& input1,
302                                             const TensorInfo& output,
303                                             const ComparisonDescriptor& descriptor,
304                                             Optional<std::string&> reasonIfUnsupported) const
305 {
306     IgnoreUnused(descriptor);
307     std::array<DataType, 8> supportedInputTypes =
308     {
309         DataType::Boolean,
310         DataType::BFloat16,
311         DataType::Float32,
312         DataType::Float16,
313         DataType::QAsymmS8,
314         DataType::QAsymmU8,
315         DataType::QSymmS16,
316         DataType::Signed32
317     };
318 
319     bool supported = true;
320     supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
321                                   "Reference comparison: input 0 is not a supported type");
322 
323     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
324                                   "Reference comparison: input 0 and Input 1 types are mismatched");
325 
326     supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
327                                   "Reference comparison: output is not of type Boolean");
328 
329     return supported;
330 }
331 
IsConcatSupported(const std::vector<const TensorInfo * > inputs,const TensorInfo & output,const ConcatDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const332 bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
333                                         const TensorInfo& output,
334                                         const ConcatDescriptor& descriptor,
335                                         Optional<std::string&> reasonIfUnsupported) const
336 {
337     IgnoreUnused(descriptor);
338 
339     bool supported = true;
340     std::array<DataType,6> supportedTypes =
341     {
342         DataType::BFloat16,
343         DataType::Float32,
344         DataType::Float16,
345         DataType::QAsymmS8,
346         DataType::QAsymmU8,
347         DataType::QSymmS16
348     };
349 
350     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
351                                   "Reference concatenation: output type not supported");
352     for (const TensorInfo* input : inputs)
353     {
354         ARMNN_ASSERT(input != nullptr);
355         supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
356             "Reference concatenation: input type not supported");
357 
358         supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
359             "Reference concatenation: input and output types mismatched.");
360     }
361 
362     return supported;
363 }
364 
IsConstantSupported(const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const365 bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
366                                           Optional<std::string&> reasonIfUnsupported) const
367 {
368     std::array<DataType,8> supportedTypes =
369     {
370         DataType::BFloat16,
371         DataType::Float16,
372         DataType::Float32,
373         DataType::QAsymmS8,
374         DataType::QAsymmU8,
375         DataType::QSymmS8,
376         DataType::QSymmS16,
377         DataType::Signed32
378     };
379 
380     return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
381                                   "Reference constant: output is not a supported type.");
382 }
383 
IsConvertBf16ToFp32Supported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const384 bool RefLayerSupport::IsConvertBf16ToFp32Supported(const TensorInfo& input,
385                                                    const TensorInfo& output,
386                                                    Optional<std::string&> reasonIfUnsupported) const
387 {
388     bool supported = true;
389 
390     supported &= CheckSupportRule(TypeIs(input, DataType::BFloat16), reasonIfUnsupported,
391                                   "Reference for ConvertBf16ToFp32 layer: input type not supported");
392 
393     supported &= CheckSupportRule(TypeIs(output, DataType::Float32), reasonIfUnsupported,
394                                   "Reference for ConvertBf16ToFp32 layer: output type not supported");
395 
396     return supported;
397 }
398 
IsConvertFp16ToFp32Supported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const399 bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
400                                                    const TensorInfo& output,
401                                                    Optional<std::string&> reasonIfUnsupported) const
402 {
403     return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
404                                           input.GetDataType(),
405                                           &TrueFunc<>,
406                                           &FalseInputFuncF32<>,
407                                           &FalseFuncU8<>,
408                                           &FalseFuncI32<>,
409                                           &FalseFuncU8<>) &&
410             IsSupportedForDataTypeGeneric(reasonIfUnsupported,
411                                           output.GetDataType(),
412                                           &FalseOutputFuncF16<>,
413                                           &TrueFunc<>,
414                                           &FalseFuncU8<>,
415                                           &FalseFuncI32<>,
416                                           &FalseFuncU8<>));
417 }
418 
IsConvertFp32ToBf16Supported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const419 bool RefLayerSupport::IsConvertFp32ToBf16Supported(const TensorInfo& input,
420                                                    const TensorInfo& output,
421                                                    Optional<std::string&> reasonIfUnsupported) const
422 {
423     bool supported = true;
424 
425     supported &= CheckSupportRule(TypeIs(input, DataType::Float32), reasonIfUnsupported,
426                                   "Reference for ConvertFp32ToBf16 layer: input type not supported");
427 
428     supported &= CheckSupportRule(TypeIs(output, DataType::BFloat16), reasonIfUnsupported,
429                                   "Reference for ConvertFp32ToBf16 layer: output type not supported");
430 
431     return supported;
432 }
433 
IsConvertFp32ToFp16Supported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const434 bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
435                                                    const TensorInfo& output,
436                                                    Optional<std::string&> reasonIfUnsupported) const
437 {
438     return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
439                                           input.GetDataType(),
440                                           &FalseInputFuncF16<>,
441                                           &TrueFunc<>,
442                                           &FalseFuncU8<>,
443                                           &FalseFuncI32<>,
444                                           &FalseFuncU8<>) &&
445             IsSupportedForDataTypeGeneric(reasonIfUnsupported,
446                                           output.GetDataType(),
447                                           &TrueFunc<>,
448                                           &FalseOutputFuncF32<>,
449                                           &FalseFuncU8<>,
450                                           &FalseFuncI32<>,
451                                           &FalseFuncU8<>));
452 }
453 
IsConvolution2dSupported(const TensorInfo & input,const TensorInfo & output,const Convolution2dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const454 bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
455                                                const TensorInfo& output,
456                                                const Convolution2dDescriptor& descriptor,
457                                                const TensorInfo& weights,
458                                                const Optional<TensorInfo>& biases,
459                                                Optional<std::string&> reasonIfUnsupported) const
460 {
461     bool supported = true;
462 
463     // Define supported types.
464     std::array<DataType,7> supportedTypes =
465     {
466         DataType::BFloat16,
467         DataType::Float32,
468         DataType::Float16,
469         DataType::QAsymmS8,
470         DataType::QAsymmU8,
471         DataType::QSymmS8,
472         DataType::QSymmS16
473     };
474 
475     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
476                                   "Reference Convolution2d: input is not a supported type.");
477 
478     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
479                                   "Reference Convolution2d: output is not a supported type.");
480 
481     // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
482     if (input.GetDataType() == DataType::BFloat16)
483     {
484         if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
485         {
486             reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
487             supported = false;
488         }
489     }
490     else
491     {
492         supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
493                                   "Reference Convolution2d: input and output types mismatched.");
494     }
495 
496     const DataType inputType = input.GetDataType();
497     if (IsQuantized8BitType(inputType))
498     {
499         ARMNN_NO_DEPRECATE_WARN_BEGIN
500         std::array<DataType, 4> supportedWeightTypes =
501         {
502             DataType::QAsymmS8,
503             DataType::QAsymmU8,
504             DataType::QSymmS8,
505             DataType::QuantizedSymm8PerAxis // deprecated
506         };
507         ARMNN_NO_DEPRECATE_WARN_END
508 
509         supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
510                                       "Reference Convolution2d: weights type not supported for quantized input.");
511     }
512     else
513     {
514         supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
515                                       "Reference Convolution2d: weights is not a supported type.");
516 
517         supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
518                                       "Reference Convolution2d: input and weights types mismatched.");
519     }
520 
521     if (biases.has_value())
522     {
523         std::array<DataType,4> biasesSupportedTypes =
524         {
525             DataType::BFloat16,
526             DataType::Float32,
527             DataType::Float16,
528             DataType::Signed32
529         };
530 
531         supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
532                                       "Reference Convolution2d: biases is not a supported type.");
533     }
534     IgnoreUnused(descriptor);
535 
536     return supported;
537 }
538 
IsDebugSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const539 bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
540                                        const TensorInfo& output,
541                                        Optional<std::string&> reasonIfUnsupported) const
542 {
543     bool supported = true;
544 
545     std::array<DataType, 8> supportedTypes =
546     {
547         DataType::BFloat16,
548         DataType::Float16,
549         DataType::Float32,
550         DataType::QAsymmS8,
551         DataType::QAsymmU8,
552         DataType::QSymmS8,
553         DataType::QSymmS16,
554         DataType::Signed32
555     };
556 
557     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
558                                   "Reference for Debug layer: input type not supported");
559 
560     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
561                                   "Reference for Debug layer: output type not supported");
562 
563     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
564                                   "Reference for Debug layer: input and output types are mismatched");
565 
566     return supported;
567 }
568 
IsDepthToSpaceSupported(const TensorInfo & input,const TensorInfo & output,const DepthToSpaceDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const569 bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
570                                               const TensorInfo& output,
571                                               const DepthToSpaceDescriptor& descriptor,
572                                               Optional<std::string&> reasonIfUnsupported) const
573 {
574     IgnoreUnused(descriptor);
575     bool supported = true;
576 
577     std::array<DataType,6> supportedTypes =
578     {
579         DataType::BFloat16,
580         DataType::Float32,
581         DataType::Float16,
582         DataType::QAsymmS8,
583         DataType::QAsymmU8,
584         DataType::QSymmS16
585     };
586 
587     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
588         "Reference DepthToSpace: input type not supported");
589 
590     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
591         "Reference DepthToSpace: output type not supported");
592 
593     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
594         "Reference DepthToSpace: input and output types are mismatched");
595 
596     return supported;
597 }
598 
IsDepthwiseConvolutionSupported(const TensorInfo & input,const TensorInfo & output,const DepthwiseConvolution2dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const599 bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
600                                                       const TensorInfo& output,
601                                                       const DepthwiseConvolution2dDescriptor& descriptor,
602                                                       const TensorInfo& weights,
603                                                       const Optional<TensorInfo>& biases,
604                                                       Optional<std::string&> reasonIfUnsupported) const
605 {
606     IgnoreUnused(descriptor);
607     bool supported = true;
608 
609     // Define supported types.
610     std::array<DataType,7> supportedTypes =
611     {
612         DataType::BFloat16,
613         DataType::Float32,
614         DataType::Float16,
615         DataType::QAsymmS8,
616         DataType::QAsymmU8,
617         DataType::QSymmS8,
618         DataType::QSymmS16
619     };
620 
621     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
622                                   "Reference DepthwiseConvolution2d: input is not a supported type.");
623 
624     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
625                                   "Reference DepthwiseConvolution2d: output is not a supported type.");
626 
627     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
628                                   "Reference DepthwiseConvolution2d: input and output types mismatched.");
629 
630     const DataType inputType = input.GetDataType();
631     if (IsQuantized8BitType(inputType))
632     {
633         ARMNN_NO_DEPRECATE_WARN_BEGIN
634         std::array<DataType, 4> supportedWeightTypes =
635                 {
636                         DataType::QAsymmS8,
637                         DataType::QAsymmU8,
638                         DataType::QSymmS8,
639                         DataType::QuantizedSymm8PerAxis // deprecated
640                 };
641         ARMNN_NO_DEPRECATE_WARN_END
642 
643         supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
644                                        "Reference DepthwiseConvolution2d: weights type not supported for "
645                                        "quantized input.");
646     }
647     else
648     {
649         supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
650                                       "Reference DepthwiseConvolution2d: weights is not a supported type.");
651 
652         supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
653                                       "Reference DepthwiseConvolution2d: input and weights types mismatched.");
654     }
655 
656     if (biases.has_value())
657     {
658         std::array<DataType,4> biasesSupportedTypes =
659         {
660             DataType::BFloat16,
661             DataType::Float32,
662             DataType::Float16,
663             DataType::Signed32
664         };
665         supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
666                                       "Reference DepthwiseConvolution2d: biases is not a supported type.");
667     }
668 
669     return supported;
670 
671 }
672 
IsDequantizeSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const673 bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
674                                             const TensorInfo& output,
675                                             Optional<std::string&> reasonIfUnsupported) const
676 {
677    bool supported = true;
678 
679     std::array<DataType,4> supportedInputTypes = {
680         DataType::QAsymmS8,
681         DataType::QAsymmU8,
682         DataType::QSymmS8,
683         DataType::QSymmS16
684     };
685 
686     supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
687                                   "Reference for Dequantize layer: input type not supported.");
688 
689     supported &= CheckSupportRule( TypeNotPerAxisQuantized(input), reasonIfUnsupported,
690                                     "Reference for Dequantize layer: per-axis quantized input not support .");
691 
692     supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
693                                   "Reference dequantize: per-axis quantized input not support .");
694 
695     std::array<DataType,3> supportedOutputTypes = {
696         DataType::BFloat16,
697         DataType::Float32,
698         DataType::Float16
699     };
700 
701     supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
702                                   "Reference for Dequantize layer: output type not supported.");
703 
704     supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
705                                   "Reference for Dequantize layer: input/output shapes have different num total "
706                                   "elements.");
707 
708     return supported;
709 }
710 
IsDetectionPostProcessSupported(const TensorInfo & boxEncodings,const TensorInfo & scores,const TensorInfo & anchors,const TensorInfo & detectionBoxes,const TensorInfo & detectionClasses,const TensorInfo & detectionScores,const TensorInfo & numDetections,const DetectionPostProcessDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const711 bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
712                                                       const TensorInfo& scores,
713                                                       const TensorInfo& anchors,
714                                                       const TensorInfo& detectionBoxes,
715                                                       const TensorInfo& detectionClasses,
716                                                       const TensorInfo& detectionScores,
717                                                       const TensorInfo& numDetections,
718                                                       const DetectionPostProcessDescriptor& descriptor,
719                                                       Optional<std::string&> reasonIfUnsupported) const
720 {
721     IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
722 
723     bool supported = true;
724 
725     std::array<DataType,6> supportedInputTypes =
726     {
727         DataType::BFloat16,
728         DataType::Float32,
729         DataType::Float16,
730         DataType::QAsymmS8,
731         DataType::QAsymmU8,
732         DataType::QSymmS16
733     };
734 
735     supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
736                                   "Reference DetectionPostProcess: input 0 is not a supported type.");
737 
738     supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
739                                   "Reference DetectionPostProcess: input 1 is not a supported type.");
740 
741     return supported;
742 }
743 
IsDilatedDepthwiseConvolutionSupported(const TensorInfo & input,const TensorInfo & output,const DepthwiseConvolution2dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const744 bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
745                                                              const TensorInfo& output,
746                                                              const DepthwiseConvolution2dDescriptor& descriptor,
747                                                              const TensorInfo& weights,
748                                                              const Optional<TensorInfo>& biases,
749                                                              Optional<std::string&> reasonIfUnsupported) const
750 {
751     return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
752 }
753 
IsDivisionSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const754 bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
755                                           const TensorInfo& input1,
756                                           const TensorInfo& output,
757                                           Optional<std::string&> reasonIfUnsupported) const
758 {
759     bool supported = true;
760 
761     std::array<DataType,7> supportedTypes = {
762         DataType::BFloat16,
763         DataType::Float32,
764         DataType::Float16,
765         DataType::QAsymmS8,
766         DataType::QAsymmU8,
767         DataType::QSymmS16,
768         DataType::Signed32
769     };
770 
771     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
772                                   "Reference division: input 0 is not a supported type.");
773 
774     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
775                                   "Reference division: input 1 is not a supported type.");
776 
777     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
778                                   "Reference division: output is not a supported type.");
779 
780     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
781                                   "Reference division: input 0 and Input 1 types are mismatched");
782 
783     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
784                                   "Reference division: input and output types are mismatched");
785 
786     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
787                                   "Reference division: shapes are not suitable for implicit broadcast.");
788 
789     return supported;
790 }
791 
IsElementwiseUnarySupported(const TensorInfo & input,const TensorInfo & output,const ElementwiseUnaryDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const792 bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
793                                                   const TensorInfo& output,
794                                                   const ElementwiseUnaryDescriptor& descriptor,
795                                                   Optional<std::string&> reasonIfUnsupported) const
796 {
797     IgnoreUnused(descriptor);
798 
799     std::array<DataType, 7> supportedTypes =
800     {
801         DataType::BFloat16,
802         DataType::Float32,
803         DataType::Float16,
804         DataType::QAsymmS8,
805         DataType::QAsymmU8,
806         DataType::QSymmS16,
807         DataType::Signed32
808     };
809 
810     std::array<DataType, 1> logicalSupportedTypes =
811     {
812         DataType::Boolean
813     };
814 
815     bool supported = true;
816 
817     if (descriptor.m_Operation == UnaryOperation::LogicalNot)
818     {
819         supported &= CheckSupportRule(TypeAnyOf(input, logicalSupportedTypes), reasonIfUnsupported,
820                                       "Reference elementwise unary: input type not supported");
821 
822         supported &= CheckSupportRule(TypeAnyOf(output, logicalSupportedTypes), reasonIfUnsupported,
823                                       "Reference elementwise unary: output type not supported");
824     }
825     else
826     {
827         supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
828                                       "Reference elementwise unary: input type not supported");
829 
830         supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
831                                       "Reference elementwise unary: output type not supported");
832     }
833 
834     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
835                                   "Reference elementwise unary: input and output types not matching");
836 
837     supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
838                                   "Reference elementwise unary: input and output shapes"
839                                   "have different number of total elements");
840 
841     return supported;
842 }
843 
IsEqualSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const844 bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
845                                        const TensorInfo& input1,
846                                        const TensorInfo& output,
847                                        Optional<std::string&> reasonIfUnsupported) const
848 {
849     return IsComparisonSupported(input0,
850                                  input1,
851                                  output,
852                                  ComparisonDescriptor(ComparisonOperation::Equal),
853                                  reasonIfUnsupported);
854 }
855 
IsFakeQuantizationSupported(const TensorInfo & input,const FakeQuantizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const856 bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
857                                                   const FakeQuantizationDescriptor& descriptor,
858                                                   Optional<std::string&> reasonIfUnsupported) const
859 {
860     IgnoreUnused(descriptor);
861     bool supported = true;
862 
863     std::array<DataType,1> supportedTypes =
864     {
865         DataType::Float32
866     };
867 
868     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
869                                   "Reference fake quantization: input type not supported.");
870 
871     return supported;
872 }
873 
IsFillSupported(const TensorInfo & input,const TensorInfo & output,const FillDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const874 bool RefLayerSupport::IsFillSupported(const TensorInfo& input,
875                                       const TensorInfo& output,
876                                       const FillDescriptor& descriptor,
877                                       Optional<std::string&> reasonIfUnsupported) const
878 {
879     IgnoreUnused(descriptor);
880     IgnoreUnused(output);
881 
882     bool supported = true;
883 
884     std::array<DataType,3> supportedTypes =
885     {
886         DataType::Float32,
887         DataType::Float16,
888         DataType::Signed32
889     };
890 
891     supported &= CheckSupportRule(TypeIs(input, DataType::Signed32), reasonIfUnsupported,
892                                   "Reference Fill: input type not supported.");
893 
894     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
895                                   "Reference Fill: output type not supported.");
896     return supported;
897 }
898 
IsFloorSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const899 bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
900                                        const TensorInfo& output,
901                                        Optional<std::string&> reasonIfUnsupported) const
902 {
903     IgnoreUnused(output);
904     bool supported = true;
905 
906     std::array<DataType,3> supportedTypes =
907     {
908         DataType::BFloat16,
909         DataType::Float32,
910         DataType::Float16
911     };
912 
913     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
914                                   "Reference Floor: input type not supported.");
915 
916     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
917                                   "Reference Floor: output type not supported.");
918 
919     return supported;
920 }
921 
IsFullyConnectedSupported(const TensorInfo & input,const TensorInfo & output,const TensorInfo & weights,const TensorInfo & biases,const FullyConnectedDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const922 bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
923                                                 const TensorInfo& output,
924                                                 const TensorInfo& weights,
925                                                 const TensorInfo& biases,
926                                                 const FullyConnectedDescriptor& descriptor,
927                                                 Optional<std::string&> reasonIfUnsupported) const
928 {
929     bool supported = true;
930 
931     // Define supported types.
932     std::array<DataType,6> supportedTypes =
933     {
934         DataType::BFloat16,
935         DataType::Float32,
936         DataType::Float16,
937         DataType::QAsymmS8,
938         DataType::QAsymmU8,
939         DataType::QSymmS16
940     };
941 
942     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
943                                   "Reference Fully Connected: input type not supported.");
944 
945     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
946                                   "Reference Fully Connected: output type not supported.");
947 
948     supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
949                                   "Reference Fully Connected: weights type not supported.");
950 
951     // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
952     if (input.GetDataType() == DataType::BFloat16)
953     {
954         if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
955         {
956             reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
957             supported = false;
958         }
959     }
960     else
961     {
962         supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
963                                   "Reference Fully Connected: input and output types mismatched.");
964     }
965 
966     supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
967                                   "Reference Fully Connected: weights is not a supported type.");
968 
969     supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
970                                   "Reference Fully Connected: input and weights types mismatched.");
971 
972     if (descriptor.m_BiasEnabled)
973     {
974         // Defined supported types for bias
975         std::array<DataType, 5>
976         supportedBiasTypes =
977         {
978             DataType::BFloat16,
979             DataType::Float32,
980             DataType::Float16,
981             DataType::Signed32,
982             DataType::QAsymmS8
983         };
984 
985         supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
986                                       "Reference Fully Connected: bias type not supported.");
987 
988         supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
989                                       "Reference Fully Connected: bias and weight types mismatch.");
990 
991         supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
992                                       "Reference Fully Connected: bias type inferred from weights is incompatible.");
993 
994         supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(biases, 1U), reasonIfUnsupported,
995                                       "Reference Fully Connected: bias must have 1 dimension.");
996 
997     }
998 
999     return supported;
1000 }
1001 
IsGatherSupported(const armnn::TensorInfo & input0,const armnn::TensorInfo & input1,const armnn::TensorInfo & output,const GatherDescriptor & descriptor,armnn::Optional<std::string &> reasonIfUnsupported) const1002 bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
1003                                         const armnn::TensorInfo& input1,
1004                                         const armnn::TensorInfo& output,
1005                                         const GatherDescriptor& descriptor,
1006                                         armnn::Optional<std::string&> reasonIfUnsupported) const
1007 {
1008     bool supported = true;
1009     std::array<DataType,7> supportedTypes =
1010     {
1011         DataType::BFloat16,
1012         DataType::Float32,
1013         DataType::Float16,
1014         DataType::QAsymmS8,
1015         DataType::QAsymmU8,
1016         DataType::QSymmS16,
1017         DataType::Signed32
1018     };
1019 
1020     if (descriptor.m_Axis != 0)
1021     {
1022         reasonIfUnsupported.value() += std::string("Reference Gather: axis not supported\n");
1023         supported &= false;
1024     }
1025     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1026                                   "Reference Gather: input type not supported");
1027 
1028     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1029                                   "Reference Gather: output type not supported");
1030 
1031     supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1032                                   "Reference Gather: indices (input1) type not supported");
1033 
1034     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1035                                   "Reference Gather: input and output types not matching");
1036 
1037     return supported;
1038 }
1039 
IsGreaterSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1040 bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
1041                                          const TensorInfo& input1,
1042                                          const TensorInfo& output,
1043                                          Optional<std::string&> reasonIfUnsupported) const
1044 {
1045     return IsComparisonSupported(input0,
1046                                  input1,
1047                                  output,
1048                                  ComparisonDescriptor(ComparisonOperation::Greater),
1049                                  reasonIfUnsupported);
1050 }
1051 
IsInputSupported(const TensorInfo &,Optional<std::string &>) const1052 bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
1053                                        Optional<std::string&> /*reasonIfUnsupported*/) const
1054 {
1055     return true;
1056 }
1057 
IsInstanceNormalizationSupported(const TensorInfo & input,const TensorInfo & output,const InstanceNormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1058 bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
1059                                                        const TensorInfo& output,
1060                                                        const InstanceNormalizationDescriptor& descriptor,
1061                                                        Optional<std::string&> reasonIfUnsupported) const
1062 {
1063     IgnoreUnused(descriptor);
1064     // Define supported types
1065     std::array<DataType, 3> supportedTypes =
1066         {
1067             DataType::BFloat16,
1068             DataType::Float32,
1069             DataType::Float16
1070         };
1071 
1072     bool supported = true;
1073 
1074     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1075                                   "Reference Instance Normalization: input type not supported.");
1076 
1077     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1078                                   "Reference Instance Normalization: output type not supported.");
1079 
1080     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1081                                   "Reference Instance Normalization: input and output types mismatched.");
1082 
1083     supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1084                                   "Reference Instance Normalization: input and output shapes have different "
1085                                   "num total elements.");
1086 
1087     return supported;
1088 }
1089 
IsL2NormalizationSupported(const TensorInfo & input,const TensorInfo & output,const L2NormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1090 bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1091                                                  const TensorInfo& output,
1092                                                  const L2NormalizationDescriptor& descriptor,
1093                                                  Optional<std::string&> reasonIfUnsupported) const
1094 {
1095     IgnoreUnused(descriptor);
1096     // Define supported types
1097     std::array<DataType, 6> supportedTypes =
1098     {
1099         DataType::BFloat16,
1100         DataType::Float32,
1101         DataType::Float16,
1102         DataType::QAsymmS8,
1103         DataType::QAsymmU8,
1104         DataType::QSymmS16
1105     };
1106 
1107     bool supported = true;
1108 
1109     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1110                                   "Reference L2normalization: input type not supported.");
1111 
1112     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1113                                   "Reference L2normalization: output type not supported.");
1114 
1115     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1116                                   "Reference L2normalization: input and output types mismatched.");
1117 
1118     supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1119                                   "Reference L2normalization: input and output shapes have different "
1120                                   "num total elements.");
1121 
1122     return supported;
1123 }
1124 
IsLogicalBinarySupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,const LogicalBinaryDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1125 bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
1126                                                const TensorInfo& input1,
1127                                                const TensorInfo& output,
1128                                                const LogicalBinaryDescriptor& descriptor,
1129                                                Optional<std::string&> reasonIfUnsupported) const
1130 {
1131     IgnoreUnused(descriptor);
1132 
1133     std::array<DataType, 1> supportedTypes =
1134     {
1135         DataType::Boolean
1136     };
1137 
1138     bool supported = true;
1139     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1140                                   "Reference LogicalBinary: input 0 type not supported");
1141     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1142                                   "Reference LogicalBinary: input 1 type not supported");
1143 
1144     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1145                                   "Reference LogicalBinary: input and output types do not match");
1146 
1147     return supported;
1148 }
1149 
IsLogSoftmaxSupported(const TensorInfo & input,const TensorInfo & output,const LogSoftmaxDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1150 bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1151                                             const TensorInfo& output,
1152                                             const LogSoftmaxDescriptor& descriptor,
1153                                             Optional<std::string&> reasonIfUnsupported) const
1154 {
1155     IgnoreUnused(descriptor);
1156 
1157     std::array<DataType, 3> supportedTypes =
1158     {
1159         DataType::BFloat16,
1160         DataType::Float32,
1161         DataType::Float16
1162     };
1163 
1164     bool supported = true;
1165     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1166                                   "Reference LogSoftmax: input type not supported");
1167 
1168     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1169                                   "Reference LogSoftmax: output type not supported");
1170 
1171     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1172                                   "Reference LogSoftmax: input and output types do not match");
1173 
1174     return supported;
1175 }
1176 
IsLstmSupported(const TensorInfo & input,const TensorInfo & outputStateIn,const TensorInfo & cellStateIn,const TensorInfo & scratchBuffer,const TensorInfo & outputStateOut,const TensorInfo & cellStateOut,const TensorInfo & output,const LstmDescriptor & descriptor,const LstmInputParamsInfo & paramsInfo,Optional<std::string &> reasonIfUnsupported) const1177 bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
1178                                       const TensorInfo& outputStateIn,
1179                                       const TensorInfo& cellStateIn,
1180                                       const TensorInfo& scratchBuffer,
1181                                       const TensorInfo& outputStateOut,
1182                                       const TensorInfo& cellStateOut,
1183                                       const TensorInfo& output,
1184                                       const LstmDescriptor& descriptor,
1185                                       const LstmInputParamsInfo& paramsInfo,
1186                                       Optional<std::string&> reasonIfUnsupported) const
1187 {
1188     IgnoreUnused(descriptor);
1189     IgnoreUnused(paramsInfo);
1190 
1191     bool supported = true;
1192 
1193     std::array<DataType,3> supportedTypes = {
1194         DataType::BFloat16,
1195         DataType::Float32,
1196         DataType::QSymmS16
1197     };
1198 
1199     // check inputs and outputs
1200     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1201                                   "Reference Lstm: input is not a supported type.");
1202     supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
1203                                   "Reference Lstm: input and outputStateIn types are mismatched");
1204     supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
1205                                   "Reference Lstm: input and cellStateIn types are mismatched");
1206     supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
1207                                   "Reference Lstm: input and scratchBuffer types are mismatched");
1208     supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
1209                                   "Reference Lstm: input and outputStateOut types are mismatched");
1210     supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
1211                                   "Reference Lstm: input and cellStateOut types are mismatched");
1212     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1213                                   "Reference Lstm: input and output types are mismatched");
1214     // check layer parameters
1215     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
1216                                   "Reference Lstm: input and InputToForgetWeights types are mismatched");
1217     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
1218                                   "Reference Lstm: input and InputToCellWeights types are mismatched");
1219     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
1220                                   "Reference Lstm: input and InputToOutputWeights types are mismatched");
1221     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
1222                                   "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
1223     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
1224                                   "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
1225     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
1226                                   "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
1227     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
1228                                   "Reference Lstm: input and ForgetGateBias types are mismatched");
1229     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
1230                                   "Reference Lstm: input and CellBias types are mismatched");
1231     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
1232                                   "Reference Lstm: input and OutputGateBias types are mismatched");
1233     if (!descriptor.m_CifgEnabled)
1234     {
1235         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
1236                                       "Reference Lstm: input and InputToInputWeights types are mismatched");
1237         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
1238                                       reasonIfUnsupported,
1239                                       "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
1240         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
1241                                       "Reference Lstm: input and InputGateBias types are mismatched");
1242         if (descriptor.m_PeepholeEnabled)
1243         {
1244             supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
1245                                           reasonIfUnsupported,
1246                                           "Reference Lstm: input and CellToInputWeights types are mismatched");
1247         }
1248     }
1249     if (descriptor.m_PeepholeEnabled)
1250     {
1251         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
1252                                       "Reference Lstm: input and CellToForgetWeights types are mismatched");
1253         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
1254                                       "Reference Lstm: input and CellToOutputWeights types are mismatched");
1255     }
1256     if (descriptor.m_ProjectionEnabled)
1257     {
1258         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
1259                                       "Reference Lstm: input and mProjectionWeights types are mismatched");
1260         if (paramsInfo.m_ProjectionBias != nullptr)
1261         {
1262             supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
1263                                           "Reference Lstm: input and ProjectionBias types are mismatched");
1264         }
1265     }
1266     if (descriptor.m_LayerNormEnabled)
1267     {
1268         if (!descriptor.m_CifgEnabled)
1269         {
1270             supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
1271                                           reasonIfUnsupported,
1272                                           "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1273         }
1274         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
1275                                       reasonIfUnsupported,
1276                                       "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
1277         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
1278                                       reasonIfUnsupported,
1279                                       "Reference Lstm: input and CellLayerNormWeights types are mismatched");
1280         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
1281                                       reasonIfUnsupported,
1282                                       "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1283     }
1284 
1285     return supported;
1286 }
1287 
IsMaximumSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1288 bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1289                                          const TensorInfo& input1,
1290                                          const TensorInfo& output,
1291                                          Optional<std::string&> reasonIfUnsupported) const
1292 {
1293     bool supported = true;
1294 
1295     std::array<DataType,7> supportedTypes = {
1296         DataType::BFloat16,
1297         DataType::Float32,
1298         DataType::Float16,
1299         DataType::QAsymmS8,
1300         DataType::QAsymmU8,
1301         DataType::QSymmS16,
1302         DataType::Signed32
1303     };
1304 
1305     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1306                                   "Reference maximum: input 0 is not a supported type.");
1307 
1308     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1309                                   "Reference maximum: input 1 is not a supported type.");
1310 
1311     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1312                                   "Reference maximum: output is not a supported type.");
1313 
1314     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1315                                   "Reference maximum: input 0 and Input 1 types are mismatched");
1316 
1317     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1318                                   "Reference maximum: input and output types are mismatched");
1319 
1320     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1321                                   "Reference maximum: shapes are not suitable for implicit broadcast.");
1322 
1323     return supported;
1324 }
1325 
IsMeanSupported(const TensorInfo & input,const TensorInfo & output,const MeanDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1326 bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1327                                       const TensorInfo& output,
1328                                       const MeanDescriptor& descriptor,
1329                                       Optional<std::string&> reasonIfUnsupported) const
1330 {
1331     bool supported = true;
1332     std::string meanLayerStr = "Mean";
1333     std::string outputTensorStr = "output";
1334 
1335     std::array<DataType,6> supportedTypes =
1336     {
1337         DataType::BFloat16,
1338         DataType::Float32,
1339         DataType::Float16,
1340         DataType::QAsymmS8,
1341         DataType::QAsymmU8,
1342         DataType::QSymmS16
1343     };
1344 
1345     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1346                                   "Reference Mean: input type not supported.");
1347 
1348     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1349                                   "Reference Mean: input and output types are mismatched");
1350 
1351     if (descriptor.m_KeepDims)
1352     {
1353         supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1354                                       reasonIfUnsupported,
1355                                       CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1356                                                                         output.GetNumDimensions(),
1357                                                                         meanLayerStr, outputTensorStr).data());
1358     }
1359     else if (descriptor.m_Axis.empty())
1360     {
1361         supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1362                                       reasonIfUnsupported,
1363                                       CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1364                                                                         meanLayerStr, outputTensorStr).data());
1365     }
1366     else
1367     {
1368         auto outputDim = input.GetNumDimensions() - armnn::numeric_cast<unsigned int>(descriptor.m_Axis.size());
1369 
1370         if (outputDim > 0)
1371         {
1372             supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1373                                           reasonIfUnsupported,
1374                                           CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1375                                                                             meanLayerStr, outputTensorStr).data());
1376         }
1377         else
1378         {
1379             supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1380                                           reasonIfUnsupported,
1381                                           CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1382                                                                             meanLayerStr, outputTensorStr).data());
1383         }
1384     }
1385 
1386     return supported;
1387 }
1388 
IsMergerSupported(const std::vector<const TensorInfo * > inputs,const TensorInfo & output,const MergerDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1389 bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
1390                                         const TensorInfo& output,
1391                                         const MergerDescriptor& descriptor,
1392                                         Optional<std::string&> reasonIfUnsupported) const
1393 {
1394     return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
1395 }
1396 
IsMemCopySupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1397 bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1398                                          const TensorInfo &output,
1399                                          Optional<std::string &> reasonIfUnsupported) const
1400 {
1401     bool supported = true;
1402 
1403     std::array<DataType,7> supportedTypes =
1404     {
1405         DataType::BFloat16,
1406         DataType::Float32,
1407         DataType::Float16,
1408         DataType::QAsymmS8,
1409         DataType::QAsymmU8,
1410         DataType::QSymmS16,
1411         DataType::Boolean
1412     };
1413 
1414     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1415                                   "Reference MemCopy: input type not supported");
1416 
1417     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1418                                   "Reference MemCopy: output type not supported");
1419 
1420     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1421                                   "Reference MemCopy: input and output types are mismatched");
1422 
1423     return supported;
1424 }
1425 
IsMinimumSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1426 bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1427                                          const TensorInfo& input1,
1428                                          const TensorInfo& output,
1429                                          Optional<std::string&> reasonIfUnsupported) const
1430 {
1431     bool supported = true;
1432 
1433     std::array<DataType,7> supportedTypes = {
1434         DataType::BFloat16,
1435         DataType::Float32,
1436         DataType::Float16,
1437         DataType::QAsymmS8,
1438         DataType::QAsymmU8,
1439         DataType::QSymmS16,
1440         DataType::Signed32
1441     };
1442 
1443     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1444                                   "Reference minimum: input 0 is not a supported type.");
1445 
1446     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1447                                   "Reference minimum: input 1 is not a supported type.");
1448 
1449     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1450                                   "Reference minimum: output is not a supported type.");
1451 
1452     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1453                                   "Reference minimum: input 0 and Input 1 types are mismatched");
1454 
1455     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1456                                   "Reference minimum: input and output types are mismatched");
1457 
1458     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1459                                   "Reference minimum: shapes are not suitable for implicit broadcast.");
1460 
1461     return supported;
1462 }
1463 
IsMultiplicationSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1464 bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1465                                                 const TensorInfo& input1,
1466                                                 const TensorInfo& output,
1467                                                 Optional<std::string&> reasonIfUnsupported) const
1468 {
1469     bool supported = true;
1470 
1471     std::array<DataType,7> supportedTypes = {
1472         DataType::BFloat16,
1473         DataType::Float32,
1474         DataType::Float16,
1475         DataType::QAsymmS8,
1476         DataType::QAsymmU8,
1477         DataType::QSymmS16,
1478         DataType::Signed32
1479     };
1480 
1481     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1482                                   "Reference multiplication: input 0 is not a supported type.");
1483 
1484     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1485                                   "Reference multiplication: input 1 is not a supported type.");
1486 
1487     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1488                                   "Reference multiplication: output is not a supported type.");
1489 
1490     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1491                                   "Reference multiplication: input 0 and Input 1 types are mismatched");
1492 
1493     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1494                                   "Reference multiplication: input and output types are mismatched");
1495 
1496     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1497                                   "Reference multiplication: shapes are not suitable for implicit broadcast.");
1498 
1499     return supported;
1500 }
1501 
IsNormalizationSupported(const TensorInfo & input,const TensorInfo & output,const NormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1502 bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1503                                                const TensorInfo& output,
1504                                                const NormalizationDescriptor& descriptor,
1505                                                Optional<std::string&> reasonIfUnsupported) const
1506 {
1507     IgnoreUnused(descriptor);
1508 
1509     // Define supported types
1510     std::array<DataType, 6> supportedTypes =
1511     {
1512         DataType::BFloat16,
1513         DataType::Float16,
1514         DataType::Float32,
1515         DataType::QAsymmS8,
1516         DataType::QAsymmU8,
1517         DataType::QSymmS16
1518     };
1519 
1520     bool supported = true;
1521 
1522     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1523                                   "Reference normalization: input type not supported.");
1524 
1525     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1526                                   "Reference normalization: output type not supported.");
1527 
1528     supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1529                                   "Reference normalization: input and output shapes have different "
1530                                   "num total elements.");
1531 
1532     return supported;
1533 }
1534 
IsOutputSupported(const TensorInfo &,Optional<std::string &>) const1535 bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
1536                                         Optional<std::string&> /*reasonIfUnsupported*/) const
1537 {
1538     return true;
1539 }
1540 
IsPadSupported(const TensorInfo & input,const TensorInfo & output,const PadDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1541 bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1542                                      const TensorInfo& output,
1543                                      const PadDescriptor& descriptor,
1544                                      Optional<std::string&> reasonIfUnsupported) const
1545 {
1546     IgnoreUnused(descriptor);
1547     bool supported = true;
1548 
1549     // Define supported output and inputs types.
1550     std::array<DataType,6> supportedTypes =
1551     {
1552         DataType::BFloat16,
1553         DataType::Float32,
1554         DataType::Float16,
1555         DataType::QAsymmS8,
1556         DataType::QAsymmU8,
1557         DataType::QSymmS16
1558     };
1559 
1560     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1561                                   "Reference pad: input is not a supported type.");
1562 
1563     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1564                                   "Reference pad: output is not a supported type.");
1565 
1566     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1567                                   "Reference pad: input and output types are mismatched.");
1568 
1569     return supported;
1570 }
1571 
IsPermuteSupported(const TensorInfo & input,const TensorInfo & output,const PermuteDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1572 bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1573                                          const TensorInfo& output,
1574                                          const PermuteDescriptor& descriptor,
1575                                          Optional<std::string&> reasonIfUnsupported) const
1576 {
1577     IgnoreUnused(descriptor);
1578     bool supported = true;
1579 
1580     // Define supported output and inputs types.
1581     std::array<DataType, 6> supportedTypes =
1582     {
1583         DataType::BFloat16,
1584         DataType::Float32,
1585         DataType::Float16,
1586         DataType::QAsymmS8,
1587         DataType::QAsymmU8,
1588         DataType::QSymmS16
1589     };
1590 
1591     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1592                                   "Reference permute: input is not a supported type.");
1593 
1594     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1595                                   "Reference permute: output is not a supported type.");
1596 
1597     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1598                                   "Reference permute: input and output types are mismatched.");
1599 
1600     return supported;
1601 }
1602 
IsPooling2dSupported(const TensorInfo & input,const TensorInfo & output,const Pooling2dDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1603 bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1604                                            const TensorInfo& output,
1605                                            const Pooling2dDescriptor& descriptor,
1606                                            Optional<std::string&> reasonIfUnsupported) const
1607 {
1608     IgnoreUnused(descriptor);
1609     bool supported = true;
1610 
1611     // Define supported output and inputs types.
1612     std::array<DataType,6> supportedTypes =
1613     {
1614         DataType::BFloat16,
1615         DataType::Float32,
1616         DataType::Float16,
1617         DataType::QAsymmS8,
1618         DataType::QAsymmU8,
1619         DataType::QSymmS16
1620     };
1621 
1622     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1623                                   "Reference poolind2d: input is not a supported type.");
1624 
1625     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1626                                   "Reference poolind2d: output is not a supported type.");
1627 
1628     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1629                                   "Reference poolind2d: input and output types are mismatched.");
1630 
1631     return supported;
1632 }
1633 
IsQLstmSupported(const TensorInfo & input,const TensorInfo & previousOutputIn,const TensorInfo & previousCellStateIn,const TensorInfo & outputStateOut,const TensorInfo & cellStateOut,const TensorInfo & output,const QLstmDescriptor & descriptor,const LstmInputParamsInfo & paramsInfo,Optional<std::string &> reasonIfUnsupported) const1634 bool RefLayerSupport::IsQLstmSupported(const TensorInfo& input,
1635                                        const TensorInfo& previousOutputIn,
1636                                        const TensorInfo& previousCellStateIn,
1637                                        const TensorInfo& outputStateOut,
1638                                        const TensorInfo& cellStateOut,
1639                                        const TensorInfo& output,
1640                                        const QLstmDescriptor& descriptor,
1641                                        const LstmInputParamsInfo& paramsInfo,
1642                                        Optional<std::string&> reasonIfUnsupported) const
1643 {
1644     IgnoreUnused(input);
1645     IgnoreUnused(previousOutputIn);
1646     IgnoreUnused(previousCellStateIn);
1647     IgnoreUnused(outputStateOut);
1648     IgnoreUnused(cellStateOut);
1649     IgnoreUnused(output);
1650     IgnoreUnused(descriptor);
1651     IgnoreUnused(paramsInfo);
1652 
1653     IgnoreUnused(reasonIfUnsupported);
1654 
1655     return true;
1656 }
1657 
IsQuantizeSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1658 bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1659                                           const TensorInfo& output,
1660                                           Optional<std::string&> reasonIfUnsupported) const
1661 {
1662    bool supported = true;
1663 
1664     // Define supported input types.
1665     std::array<DataType,7> supportedInputTypes = {
1666         DataType::BFloat16,
1667         DataType::Float32,
1668         DataType::Float16,
1669         DataType::QAsymmS8,
1670         DataType::QAsymmU8,
1671         DataType::QSymmS8,
1672         DataType::QSymmS16
1673     };
1674 
1675     supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1676                                   "Reference quantize: input type not supported.");
1677 
1678     // Define supported output types.
1679     std::array<DataType,4> supportedOutputTypes = {
1680         DataType::QAsymmS8,
1681         DataType::QAsymmU8,
1682         DataType::QSymmS8,
1683         DataType::QSymmS16
1684     };
1685     supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1686                                   "Reference quantize: output type not supported.");
1687 
1688     supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1689                                   "Reference quantize: input and output shapes have different num total elements.");
1690 
1691     return supported;
1692 }
1693 
IsRankSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1694 bool RefLayerSupport::IsRankSupported(const TensorInfo& input,
1695                                       const TensorInfo& output,
1696                                       Optional<std::string&> reasonIfUnsupported) const
1697 {
1698     IgnoreUnused(input);
1699     // Define supported output types.
1700     std::array<DataType,1> supportedOutputTypes =
1701     {
1702         DataType::Signed32,
1703     };
1704 
1705     return CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1706            "Reference rank: input type not supported.");
1707 }
1708 
IsReshapeSupported(const TensorInfo & input,const TensorInfo & output,const ReshapeDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1709 bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
1710                                          const TensorInfo& output,
1711                                          const ReshapeDescriptor& descriptor,
1712                                          Optional<std::string&> reasonIfUnsupported) const
1713 {
1714     IgnoreUnused(output);
1715     IgnoreUnused(descriptor);
1716     // Define supported output types.
1717     std::array<DataType,8> supportedOutputTypes =
1718     {
1719         DataType::BFloat16,
1720         DataType::Float32,
1721         DataType::Float16,
1722         DataType::Signed32,
1723         DataType::QAsymmS8,
1724         DataType::QAsymmU8,
1725         DataType::QSymmS16,
1726         DataType::Boolean
1727     };
1728 
1729     return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1730         "Reference reshape: input type not supported.");
1731 }
1732 
IsResizeBilinearSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1733 bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
1734                                                 const TensorInfo& output,
1735                                                 Optional<std::string&> reasonIfUnsupported) const
1736 {
1737     bool supported = true;
1738     std::array<DataType,6> supportedTypes =
1739     {
1740         DataType::BFloat16,
1741         DataType::Float32,
1742         DataType::Float16,
1743         DataType::QAsymmS8,
1744         DataType::QAsymmU8,
1745         DataType::QSymmS16
1746     };
1747 
1748     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1749                                   "Reference ResizeBilinear: input type not supported");
1750 
1751     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1752                                   "Reference ResizeBilinear: output type not supported");
1753 
1754     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1755                                   "Reference ResizeBilinear: input and output types not matching");
1756 
1757     return supported;
1758 }
1759 
IsResizeSupported(const TensorInfo & input,const TensorInfo & output,const ResizeDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1760 bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1761                                         const TensorInfo& output,
1762                                         const ResizeDescriptor& descriptor,
1763                                         Optional<std::string&> reasonIfUnsupported) const
1764 {
1765     IgnoreUnused(descriptor);
1766     bool supported = true;
1767     std::array<DataType,6> supportedTypes =
1768     {
1769         DataType::BFloat16,
1770         DataType::Float32,
1771         DataType::Float16,
1772         DataType::QAsymmS8,
1773         DataType::QAsymmU8,
1774         DataType::QSymmS16
1775     };
1776 
1777     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1778                                   "Reference Resize: input type not supported");
1779 
1780     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1781                                   "Reference Resize: output type not supported");
1782 
1783     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1784                                   "Reference Resize: input and output types not matching");
1785 
1786     return supported;
1787 }
1788 
IsRsqrtSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1789 bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1790                                        const TensorInfo& output,
1791                                        Optional<std::string&> reasonIfUnsupported) const
1792 {
1793     return IsElementwiseUnarySupported(input,
1794                                        output,
1795                                        ElementwiseUnaryDescriptor(UnaryOperation::Rsqrt),
1796                                        reasonIfUnsupported);
1797 }
1798 
IsSliceSupported(const TensorInfo & input,const TensorInfo & output,const SliceDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1799 bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
1800                                        const TensorInfo& output,
1801                                        const SliceDescriptor& descriptor,
1802                                        Optional<std::string&> reasonIfUnsupported) const
1803 {
1804     IgnoreUnused(descriptor);
1805     bool supported = true;
1806 
1807     std::array<DataType, 5> supportedTypes =
1808     {
1809         DataType::BFloat16,
1810         DataType::Float32,
1811         DataType::QAsymmS8,
1812         DataType::QAsymmU8,
1813         DataType::QSymmS16
1814     };
1815 
1816     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1817                                   "Reference Slice: input type not supported");
1818 
1819     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1820                                   "Reference Slice: output type not supported");
1821 
1822     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1823                                   "Reference Slice: input and output types are mismatched");
1824 
1825     return supported;
1826 }
1827 
IsSoftmaxSupported(const TensorInfo & input,const TensorInfo & output,const SoftmaxDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1828 bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1829                                          const TensorInfo& output,
1830                                          const SoftmaxDescriptor& descriptor,
1831                                          Optional<std::string&> reasonIfUnsupported) const
1832 {
1833     IgnoreUnused(descriptor);
1834     bool supported = true;
1835     std::array<DataType,7> supportedTypes =
1836     {
1837         DataType::BFloat16,
1838         DataType::Float32,
1839         DataType::Float16,
1840         DataType::QSymmS8,
1841         DataType::QAsymmS8,
1842         DataType::QAsymmU8,
1843         DataType::QSymmS16
1844     };
1845 
1846     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1847                                   "Reference Softmax: output type not supported");
1848 
1849     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1850                                   "Reference Softmax: input type not supported");
1851 
1852     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1853                                   "Reference Softmax: input type not supported");
1854 
1855     return supported;
1856 }
1857 
IsSpaceToBatchNdSupported(const TensorInfo & input,const TensorInfo & output,const SpaceToBatchNdDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1858 bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1859                                                 const TensorInfo& output,
1860                                                 const SpaceToBatchNdDescriptor& descriptor,
1861                                                 Optional<std::string&> reasonIfUnsupported) const
1862 {
1863     IgnoreUnused(descriptor);
1864     bool supported = true;
1865     std::array<DataType,6> supportedTypes =
1866     {
1867         DataType::BFloat16,
1868         DataType::Float32,
1869         DataType::Float16,
1870         DataType::QAsymmS8,
1871         DataType::QAsymmU8,
1872         DataType::QSymmS16
1873     };
1874 
1875     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1876                                   "Reference SpaceToBatchNd: input type not supported");
1877 
1878     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1879                                   "Reference SpaceToBatchNd: output type not supported");
1880 
1881     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1882                                   "Reference SpaceToBatchNd: input and output types are mismatched");
1883 
1884     return supported;
1885 }
1886 
IsSpaceToDepthSupported(const TensorInfo & input,const TensorInfo & output,const SpaceToDepthDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1887 bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
1888                                               const TensorInfo& output,
1889                                               const SpaceToDepthDescriptor& descriptor,
1890                                               Optional<std::string&> reasonIfUnsupported) const
1891 {
1892 
1893     IgnoreUnused(descriptor);
1894     bool supported = true;
1895 
1896     std::array<DataType,6> supportedTypes =
1897     {
1898         DataType::BFloat16,
1899         DataType::Float32,
1900         DataType::Float16,
1901         DataType::QAsymmS8,
1902         DataType::QAsymmU8,
1903         DataType::QSymmS16
1904     };
1905 
1906     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1907         "Reference SpaceToDepth: input type not supported");
1908 
1909     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1910         "Reference SpaceToDepth: output type not supported");
1911 
1912     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1913         "Reference SpaceToDepth: input and output types are mismatched");
1914 
1915     return supported;
1916 }
1917 
IsSplitterSupported(const TensorInfo & input,const ViewsDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1918 bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1919                                           const ViewsDescriptor& descriptor,
1920                                           Optional<std::string&> reasonIfUnsupported) const
1921 {
1922     IgnoreUnused(descriptor);
1923     bool supported = true;
1924     std::array<DataType,6> supportedTypes =
1925     {
1926         DataType::BFloat16,
1927         DataType::Float32,
1928         DataType::Float16,
1929         DataType::QAsymmS8,
1930         DataType::QAsymmU8,
1931         DataType::QSymmS16
1932     };
1933 
1934     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1935                                   "Reference splitter: input type not supported");
1936 
1937     return supported;
1938 }
1939 
IsSplitterSupported(const TensorInfo & input,const std::vector<std::reference_wrapper<TensorInfo>> & outputs,const ViewsDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1940 bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1941                                           const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1942                                           const ViewsDescriptor& descriptor,
1943                                           Optional<std::string&> reasonIfUnsupported) const
1944 {
1945     IgnoreUnused(descriptor);
1946     bool supported = true;
1947     std::array<DataType,6> supportedTypes =
1948     {
1949         DataType::BFloat16,
1950         DataType::Float32,
1951         DataType::Float16,
1952         DataType::QAsymmS8,
1953         DataType::QAsymmU8,
1954         DataType::QSymmS16
1955     };
1956 
1957     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1958                                   "Reference splitter: output type not supported");
1959     for (const TensorInfo& output : outputs)
1960     {
1961         supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1962                                       "Reference splitter: input type not supported");
1963 
1964         supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1965                                       "Reference splitter: input and output types mismatched.");
1966     }
1967 
1968     return supported;
1969 }
1970 
IsStackSupported(const std::vector<const TensorInfo * > & inputs,const TensorInfo & output,const StackDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1971 bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
1972                                        const TensorInfo& output,
1973                                        const StackDescriptor& descriptor,
1974                                        Optional<std::string&> reasonIfUnsupported) const
1975 {
1976     IgnoreUnused(descriptor);
1977 
1978     bool supported = true;
1979     std::array<DataType,6> supportedTypes =
1980     {
1981         DataType::BFloat16,
1982         DataType::Float32,
1983         DataType::Float16,
1984         DataType::QAsymmS8,
1985         DataType::QAsymmU8,
1986         DataType::QSymmS16
1987     };
1988 
1989     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1990                                   "Reference stack: output type not supported");
1991     for (const TensorInfo* input : inputs)
1992     {
1993         ARMNN_ASSERT(input != nullptr);
1994         supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
1995             "Reference stack: input type not supported");
1996 
1997         supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
1998             "Reference stack: input and output types mismatched.");
1999     }
2000 
2001     return supported;
2002 }
2003 
IsStridedSliceSupported(const TensorInfo & input,const TensorInfo & output,const StridedSliceDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2004 bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
2005                                               const TensorInfo& output,
2006                                               const StridedSliceDescriptor& descriptor,
2007                                               Optional<std::string&> reasonIfUnsupported) const
2008 {
2009     IgnoreUnused(descriptor);
2010     bool supported = true;
2011 
2012     std::array<DataType,5> supportedTypes =
2013     {
2014         DataType::BFloat16,
2015         DataType::Float32,
2016         DataType::QAsymmS8,
2017         DataType::QAsymmU8,
2018         DataType::QSymmS16
2019     };
2020 
2021     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2022                                   "Reference StridedSlice: input type not supported");
2023 
2024     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2025                                   "Reference StridedSlice: output type not supported");
2026 
2027     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2028                                   "Reference StridedSlice: input and output types are mismatched");
2029 
2030     return supported;
2031 }
2032 
IsSubtractionSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const2033 bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
2034                                              const TensorInfo& input1,
2035                                              const TensorInfo& output,
2036                                              Optional<std::string&> reasonIfUnsupported) const
2037 {
2038     bool supported = true;
2039 
2040     std::array<DataType,7> supportedTypes = {
2041         DataType::BFloat16,
2042         DataType::Float32,
2043         DataType::Float16,
2044         DataType::QAsymmS8,
2045         DataType::QAsymmU8,
2046         DataType::QSymmS16,
2047         DataType::Signed32
2048     };
2049 
2050     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2051                                   "Reference subtraction: input 0 is not a supported type.");
2052 
2053     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2054                                   "Reference subtraction: input 1 is not a supported type.");
2055 
2056     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2057                                   "Reference subtraction: output is not a supported type.");
2058 
2059     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2060                                   "Reference subtraction: input 0 and Input 1 types are mismatched");
2061 
2062     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2063                                   "Reference subtraction: input and output types are mismatched");
2064 
2065     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2066                                   "Reference subtraction: shapes are not suitable for implicit broadcast.");
2067 
2068     return supported;
2069 }
2070 
IsPreluSupported(const TensorInfo & input,const TensorInfo & alpha,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const2071 bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
2072                                        const TensorInfo& alpha,
2073                                        const TensorInfo& output,
2074                                        Optional<std::string&> reasonIfUnsupported) const
2075 {
2076     bool supported = true;
2077 
2078     std::array<DataType, 6> supportedTypes
2079     {
2080         DataType::BFloat16,
2081         DataType::Float32,
2082         DataType::Float16,
2083         DataType::QAsymmS8,
2084         DataType::QAsymmU8,
2085         DataType::QSymmS16
2086     };
2087 
2088     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2089                                   "PReLU: input is not a supported type.");
2090 
2091     supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
2092                                   "PReLU: alpha is not a supported type.");
2093 
2094     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2095                                   "PReLU: output is not a supported type.");
2096 
2097     supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
2098                                   "PReLU: input, alpha and output types are mismatched");
2099 
2100     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
2101                                   "PReLU: shapes are not suitable for implicit broadcast");
2102 
2103     return supported;
2104 }
2105 
IsTransposeConvolution2dSupported(const TensorInfo & input,const TensorInfo & output,const TransposeConvolution2dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const2106 bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
2107                                                         const TensorInfo& output,
2108                                                         const TransposeConvolution2dDescriptor& descriptor,
2109                                                         const TensorInfo& weights,
2110                                                         const Optional<TensorInfo>& biases,
2111                                                         Optional<std::string&> reasonIfUnsupported) const
2112 {
2113     IgnoreUnused(descriptor);
2114     bool supported = true;
2115 
2116     std::array<DataType,7> supportedTypes =
2117     {
2118         DataType::BFloat16,
2119         DataType::Float32,
2120         DataType::Float16,
2121         DataType::QAsymmS8,
2122         DataType::QAsymmU8,
2123         DataType::QSymmS8,
2124         DataType::QSymmS16
2125     };
2126 
2127     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2128                                   "Reference TransposeConvolution2d: input is not a supported type.");
2129 
2130     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2131                                   "Reference TransposeConvolution2d: output is not a supported type.");
2132 
2133     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2134                                   "Reference TransposeConvolution2d: input and output types mismatched.");
2135 
2136 
2137     const DataType inputType = input.GetDataType();
2138     if (IsQuantized8BitType(inputType))
2139     {
2140         ARMNN_NO_DEPRECATE_WARN_BEGIN
2141         std::array<DataType, 4> supportedWeightTypes =
2142         {
2143             DataType::QAsymmS8,
2144             DataType::QAsymmU8,
2145             DataType::QSymmS8,
2146             DataType::QuantizedSymm8PerAxis //Deprecated
2147         };
2148         ARMNN_NO_DEPRECATE_WARN_END
2149 
2150         supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
2151                                       "Reference TransposeConvolution2d: weights type not supported for "
2152                                       "quantized input.");
2153     }
2154     else
2155     {
2156         supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
2157                                     "Reference TransposeConvolution2d: weights is not a supported type.");
2158 
2159         supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
2160                                     "Reference TransposeConvolution2d: input and weights types mismatched.");
2161     }
2162 
2163     if (biases.has_value())
2164     {
2165         std::array<DataType,4> biasesSupportedTypes =
2166         {
2167             DataType::BFloat16,
2168             DataType::Float32,
2169             DataType::Float16,
2170             DataType::Signed32
2171         };
2172         supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
2173                                       "Reference TransposeConvolution2d: biases is not a supported type.");
2174     }
2175 
2176     return supported;
2177 }
2178 
IsTransposeSupported(const TensorInfo & input,const TensorInfo & output,const TransposeDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const2179 bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
2180                                            const TensorInfo& output,
2181                                            const TransposeDescriptor& descriptor,
2182                                            Optional<std::string&> reasonIfUnsupported) const
2183 {
2184     IgnoreUnused(descriptor);
2185     bool supported = true;
2186 
2187     // Define supported output and inputs types.
2188     std::array<DataType, 6> supportedTypes =
2189     {
2190         DataType::BFloat16,
2191         DataType::Float32,
2192         DataType::Float16,
2193         DataType::QAsymmS8,
2194         DataType::QAsymmU8,
2195         DataType::QSymmS16
2196     };
2197 
2198     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2199                                   "Reference transpose: input is not a supported type.");
2200 
2201     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2202                                   "Reference transpose: output is not a supported type.");
2203 
2204     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2205                                   "Reference transpose: input and output types are mismatched.");
2206 
2207     return supported;
2208 }
2209 
2210 } // namespace armnn
2211