1 // 2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include <armnn/Deprecated.hpp> 8 #include <armnn/DescriptorsFwd.hpp> 9 #include <armnn/LstmParams.hpp> 10 #include <armnn/Optional.hpp> 11 #include <armnn/QuantizedLstmParams.hpp> 12 13 #include <cctype> 14 #include <functional> 15 #include <memory> 16 #include <vector> 17 18 namespace armnn 19 { 20 21 class TensorInfo; 22 23 class ILayerSupport 24 { 25 protected: ILayerSupport()26 ILayerSupport() {} ~ILayerSupport()27 virtual ~ILayerSupport() {} 28 29 public: 30 ARMNN_DEPRECATED_MSG("Use IsElementwiseUnarySupported instead") 31 virtual bool IsAbsSupported(const TensorInfo& input, 32 const TensorInfo& output, 33 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 34 35 virtual bool IsActivationSupported(const TensorInfo& input, 36 const TensorInfo& output, 37 const ActivationDescriptor& descriptor, 38 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 39 40 virtual bool IsAdditionSupported(const TensorInfo& input0, 41 const TensorInfo& input1, 42 const TensorInfo& output, 43 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 44 45 virtual bool IsArgMinMaxSupported(const TensorInfo& input, 46 const TensorInfo& output, 47 const ArgMinMaxDescriptor& descriptor, 48 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 49 50 virtual bool IsBatchNormalizationSupported(const TensorInfo& input, 51 const TensorInfo& output, 52 const TensorInfo& mean, 53 const TensorInfo& var, 54 const TensorInfo& beta, 55 const TensorInfo& gamma, 56 const BatchNormalizationDescriptor& descriptor, 57 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 58 59 virtual bool IsBatchToSpaceNdSupported(const TensorInfo& input, 60 const TensorInfo& output, 61 const BatchToSpaceNdDescriptor& descriptor, 62 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 63 64 virtual bool IsComparisonSupported(const TensorInfo& input0, 65 const TensorInfo& input1, 66 const TensorInfo& output, 67 const ComparisonDescriptor& descriptor, 68 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 69 70 virtual bool IsConcatSupported(const std::vector<const TensorInfo*> inputs, 71 const TensorInfo& output, 72 const OriginsDescriptor& descriptor, 73 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 74 75 virtual bool IsConstantSupported(const TensorInfo& output, 76 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 77 78 virtual bool IsConvertBf16ToFp32Supported(const TensorInfo& input, 79 const TensorInfo& output, 80 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 81 82 virtual bool IsConvertFp32ToBf16Supported(const TensorInfo& input, 83 const TensorInfo& output, 84 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 85 86 virtual bool IsConvertFp16ToFp32Supported(const TensorInfo& input, 87 const TensorInfo& output, 88 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 89 90 virtual bool IsConvertFp32ToFp16Supported(const TensorInfo& input, 91 const TensorInfo& output, 92 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 93 94 virtual bool IsConvolution2dSupported(const TensorInfo& input, 95 const TensorInfo& output, 96 const Convolution2dDescriptor& descriptor, 97 const TensorInfo& weights, 98 const Optional<TensorInfo>& biases, 99 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 100 101 virtual bool IsDebugSupported(const TensorInfo& input, 102 const TensorInfo& output, 103 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 104 105 virtual bool IsDepthToSpaceSupported(const TensorInfo& input, 106 const TensorInfo& output, 107 const DepthToSpaceDescriptor& descriptor, 108 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 109 110 virtual bool IsDepthwiseConvolutionSupported( 111 const TensorInfo& input, 112 const TensorInfo& output, 113 const DepthwiseConvolution2dDescriptor& descriptor, 114 const TensorInfo& weights, 115 const Optional<TensorInfo>& biases, 116 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 117 118 virtual bool IsDequantizeSupported(const TensorInfo& input, 119 const TensorInfo& output, 120 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 121 122 virtual bool IsDetectionPostProcessSupported(const TensorInfo& boxEncodings, 123 const TensorInfo& scores, 124 const TensorInfo& anchors, 125 const TensorInfo& detectionBoxes, 126 const TensorInfo& detectionClasses, 127 const TensorInfo& detectionScores, 128 const TensorInfo& numDetections, 129 const DetectionPostProcessDescriptor& descriptor, 130 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const =0; 131 132 virtual bool IsDilatedDepthwiseConvolutionSupported( 133 const TensorInfo& input, 134 const TensorInfo& output, 135 const DepthwiseConvolution2dDescriptor& descriptor, 136 const TensorInfo& weights, 137 const Optional<TensorInfo>& biases, 138 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 139 140 virtual bool IsDivisionSupported(const TensorInfo& input0, 141 const TensorInfo& input1, 142 const TensorInfo& output, 143 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 144 145 virtual bool IsElementwiseUnarySupported(const TensorInfo& input, 146 const TensorInfo& output, 147 const ElementwiseUnaryDescriptor& descriptor, 148 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 149 150 ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead") 151 virtual bool IsEqualSupported(const TensorInfo& input0, 152 const TensorInfo& input1, 153 const TensorInfo& output, 154 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 155 156 virtual bool IsFakeQuantizationSupported(const TensorInfo& input, 157 const FakeQuantizationDescriptor& descriptor, 158 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 159 160 virtual bool IsFillSupported(const TensorInfo& input, 161 const TensorInfo& output, 162 const FillDescriptor& descriptor, 163 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 164 165 virtual bool IsFloorSupported(const TensorInfo& input, 166 const TensorInfo& output, 167 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 168 169 virtual bool IsFullyConnectedSupported(const TensorInfo& input, 170 const TensorInfo& output, 171 const TensorInfo& weights, 172 const TensorInfo& biases, 173 const FullyConnectedDescriptor& descriptor, 174 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 175 176 ARMNN_DEPRECATED_MSG("Use IsGatherSupported with descriptor instead") 177 virtual bool IsGatherSupported(const TensorInfo& input0, 178 const TensorInfo& input1, 179 const TensorInfo& output, 180 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 181 182 virtual bool IsGatherSupported(const TensorInfo& input0, 183 const TensorInfo& input1, 184 const TensorInfo& output, 185 const GatherDescriptor& descriptor, 186 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 187 188 ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead") 189 virtual bool IsGreaterSupported(const TensorInfo& input0, 190 const TensorInfo& input1, 191 const TensorInfo& ouput, 192 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 193 194 virtual bool IsInputSupported(const TensorInfo& input, 195 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 196 197 virtual bool IsInstanceNormalizationSupported( 198 const TensorInfo& input, 199 const TensorInfo& output, 200 const InstanceNormalizationDescriptor& descriptor, 201 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 202 203 virtual bool IsL2NormalizationSupported(const TensorInfo& input, 204 const TensorInfo& output, 205 const L2NormalizationDescriptor& descriptor, 206 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 207 208 virtual bool IsLogicalBinarySupported(const TensorInfo& input0, 209 const TensorInfo& input1, 210 const TensorInfo& output, 211 const LogicalBinaryDescriptor& descriptor, 212 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 213 214 virtual bool IsLogicalUnarySupported(const TensorInfo& input, 215 const TensorInfo& output, 216 const ElementwiseUnaryDescriptor& descriptor, 217 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 218 219 virtual bool IsLogSoftmaxSupported(const TensorInfo& input, 220 const TensorInfo& output, 221 const LogSoftmaxDescriptor& descriptor, 222 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 223 224 virtual bool IsLstmSupported(const TensorInfo& input, 225 const TensorInfo& outputStateIn, 226 const TensorInfo& cellStateIn, 227 const TensorInfo& scratchBuffer, 228 const TensorInfo& outputStateOut, 229 const TensorInfo& cellStateOut, 230 const TensorInfo& output, 231 const LstmDescriptor& descriptor, 232 const LstmInputParamsInfo& paramsInfo, 233 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 234 235 virtual bool IsMaximumSupported(const TensorInfo& input0, 236 const TensorInfo& input1, 237 const TensorInfo& output, 238 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 239 240 virtual bool IsMeanSupported(const TensorInfo& input, 241 const TensorInfo& output, 242 const MeanDescriptor& descriptor, 243 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 244 245 virtual bool IsMemCopySupported(const TensorInfo& input, 246 const TensorInfo& output, 247 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 248 249 virtual bool IsMemImportSupported(const TensorInfo& input, 250 const TensorInfo& output, 251 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 252 253 virtual bool IsMergeSupported(const TensorInfo& input0, 254 const TensorInfo& input1, 255 const TensorInfo& output, 256 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 257 258 ARMNN_DEPRECATED_MSG("Use IsConcatSupported instead") 259 virtual bool IsMergerSupported(const std::vector<const TensorInfo*> inputs, 260 const TensorInfo& output, 261 const OriginsDescriptor& descriptor, 262 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 263 264 virtual bool IsMinimumSupported(const TensorInfo& input0, 265 const TensorInfo& input1, 266 const TensorInfo& ouput, 267 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 268 269 virtual bool IsMultiplicationSupported(const TensorInfo& input0, 270 const TensorInfo& input1, 271 const TensorInfo& output, 272 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 273 274 virtual bool IsNormalizationSupported(const TensorInfo& input, 275 const TensorInfo& output, 276 const NormalizationDescriptor& descriptor, 277 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 278 279 virtual bool IsOutputSupported(const TensorInfo& output, 280 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 281 282 virtual bool IsPadSupported(const TensorInfo& input, 283 const TensorInfo& output, 284 const PadDescriptor& descriptor, 285 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 286 287 virtual bool IsPermuteSupported(const TensorInfo& input, 288 const TensorInfo& output, 289 const PermuteDescriptor& descriptor, 290 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 291 292 virtual bool IsPooling2dSupported(const TensorInfo& input, 293 const TensorInfo& output, 294 const Pooling2dDescriptor& descriptor, 295 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 296 297 virtual bool IsPreCompiledSupported(const TensorInfo& input, 298 const PreCompiledDescriptor& descriptor, 299 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 300 301 virtual bool IsPreluSupported(const TensorInfo& input, 302 const TensorInfo& alpha, 303 const TensorInfo& output, 304 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 305 306 virtual bool IsQuantizeSupported(const TensorInfo& input, 307 const TensorInfo& output, 308 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 309 310 virtual bool IsQLstmSupported(const TensorInfo& input, 311 const TensorInfo& previousOutputIn, 312 const TensorInfo& previousCellStateIn, 313 const TensorInfo& outputStateOut, 314 const TensorInfo& cellStateOut, 315 const TensorInfo& output, 316 const QLstmDescriptor& descriptor, 317 const LstmInputParamsInfo& paramsInfo, 318 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 319 320 virtual bool IsQuantizedLstmSupported(const TensorInfo& input, 321 const TensorInfo& previousCellStateIn, 322 const TensorInfo& previousOutputIn, 323 const TensorInfo& cellStateOut, 324 const TensorInfo& output, 325 const QuantizedLstmInputParamsInfo& paramsInfo, 326 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 327 328 virtual bool IsRankSupported(const TensorInfo& input, 329 const TensorInfo& output, 330 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 331 332 virtual bool IsReshapeSupported(const TensorInfo& input, 333 const TensorInfo& output, 334 const ReshapeDescriptor& descriptor, 335 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 336 337 ARMNN_DEPRECATED_MSG("Use IsResizeSupported instead") 338 virtual bool IsResizeBilinearSupported(const TensorInfo& input, 339 const TensorInfo& output, 340 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 341 342 virtual bool IsResizeSupported(const TensorInfo& input, 343 const TensorInfo& output, 344 const ResizeDescriptor& descriptor, 345 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 346 347 ARMNN_DEPRECATED_MSG("Use IsElementwiseUnarySupported instead") 348 virtual bool IsRsqrtSupported(const TensorInfo& input, 349 const TensorInfo& output, 350 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 351 352 virtual bool IsSliceSupported(const TensorInfo& input, 353 const TensorInfo& output, 354 const SliceDescriptor& descriptor, 355 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 356 357 virtual bool IsSoftmaxSupported(const TensorInfo& input, 358 const TensorInfo& output, 359 const SoftmaxDescriptor& descriptor, 360 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 361 362 virtual bool IsSpaceToBatchNdSupported(const TensorInfo& input, 363 const TensorInfo& output, 364 const SpaceToBatchNdDescriptor& descriptor, 365 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 366 367 virtual bool IsSpaceToDepthSupported(const TensorInfo& input, 368 const TensorInfo& output, 369 const SpaceToDepthDescriptor& descriptor, 370 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 371 372 ARMNN_DEPRECATED_MSG("Use IsSplitterSupported with outputs instead") 373 virtual bool IsSplitterSupported(const TensorInfo& input, 374 const ViewsDescriptor& descriptor, 375 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 376 377 virtual bool IsSplitterSupported(const TensorInfo& input, 378 const std::vector<std::reference_wrapper<TensorInfo>>& outputs, 379 const ViewsDescriptor& descriptor, 380 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 381 382 virtual bool IsStackSupported(const std::vector<const TensorInfo*>& inputs, 383 const TensorInfo& output, 384 const StackDescriptor& descriptor, 385 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 386 387 virtual bool IsStandInSupported(const std::vector<const TensorInfo*>& inputs, 388 const std::vector<const TensorInfo*>& outputs, 389 const StandInDescriptor& descriptor, 390 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 391 392 393 virtual bool IsStridedSliceSupported(const TensorInfo& input, 394 const TensorInfo& output, 395 const StridedSliceDescriptor& descriptor, 396 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 397 398 virtual bool IsSubtractionSupported(const TensorInfo& input0, 399 const TensorInfo& input1, 400 const TensorInfo& output, 401 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 402 403 virtual bool IsSwitchSupported(const TensorInfo& input0, 404 const TensorInfo& input1, 405 const TensorInfo& output0, 406 const TensorInfo& output1, 407 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 408 409 virtual bool IsTransposeConvolution2dSupported( 410 const TensorInfo& input, 411 const TensorInfo& output, 412 const TransposeConvolution2dDescriptor& descriptor, 413 const TensorInfo& weights, 414 const Optional<TensorInfo>& biases, 415 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 416 417 virtual bool IsTransposeSupported(const TensorInfo& input, 418 const TensorInfo& output, 419 const TransposeDescriptor& descriptor, 420 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; 421 422 }; // class ILayerSupport 423 424 using ILayerSupportSharedPtr = std::shared_ptr<ILayerSupport>; 425 426 } // namespace armnn 427