1 // 2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include <armnn/Optional.hpp> 8 #include <armnn/backends/IBackendInternal.hpp> 9 10 #include <backendsCommon/WorkloadFactoryBase.hpp> 11 #include <aclCommon/BaseMemoryManager.hpp> 12 #include <armnn/utility/IgnoreUnused.hpp> 13 14 namespace armnn 15 { 16 17 // Neon workload factory. 18 class NeonWorkloadFactory : public WorkloadFactoryBase 19 { 20 public: 21 NeonWorkloadFactory(const std::shared_ptr<NeonMemoryManager>& memoryManager); 22 23 NeonWorkloadFactory(const std::shared_ptr<NeonMemoryManager>& memoryManager, 24 const IBackendInternal::IBackendSpecificModelContextPtr& modelContextPtr); 25 26 const BackendId& GetBackendId() const override; 27 28 static bool IsLayerSupported(const Layer& layer, 29 Optional<DataType> dataType, 30 std::string& outReasonIfUnsupported); 31 32 static bool IsLayerSupported(const IConnectableLayer& layer, 33 Optional<DataType> dataType, 34 std::string& outReasonIfUnsupported, 35 const ModelOptions& modelOptions); 36 SupportsSubTensors() const37 bool SupportsSubTensors() const override { return true; } 38 39 ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateSubTensorHandle instead") 40 std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent, 41 TensorShape const& subTensorShape, 42 unsigned int const* subTensorOrigin) const override; 43 44 ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead") 45 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, 46 const bool IsMemoryManaged = true) const override; 47 48 ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead") 49 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, 50 DataLayout dataLayout, 51 const bool IsMemoryManaged = true) const override; 52 53 ARMNN_DEPRECATED_MSG("Use CreateElementwiseUnary instead") 54 std::unique_ptr<IWorkload> CreateAbs(const AbsQueueDescriptor& descriptor, 55 const WorkloadInfo& info) const override; 56 57 std::unique_ptr<IWorkload> CreateActivation(const ActivationQueueDescriptor& descriptor, 58 const WorkloadInfo& info) const override; 59 60 std::unique_ptr<IWorkload> CreateAddition(const AdditionQueueDescriptor& descriptor, 61 const WorkloadInfo& info) const override; 62 63 std::unique_ptr<IWorkload> CreateArgMinMax(const ArgMinMaxQueueDescriptor& descriptor, 64 const WorkloadInfo& info) const override; 65 66 std::unique_ptr<IWorkload> CreateBatchNormalization(const BatchNormalizationQueueDescriptor& descriptor, 67 const WorkloadInfo& info) const override; 68 69 std::unique_ptr<IWorkload> CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor, 70 const WorkloadInfo& Info) const override; 71 72 std::unique_ptr<IWorkload> CreateComparison(const ComparisonQueueDescriptor& descriptor, 73 const WorkloadInfo& Info) const override; 74 75 std::unique_ptr<IWorkload> CreateConcat(const ConcatQueueDescriptor& descriptor, 76 const WorkloadInfo& info) const override; 77 78 std::unique_ptr<IWorkload> CreateConstant(const ConstantQueueDescriptor& descriptor, 79 const WorkloadInfo& info) const override; 80 81 std::unique_ptr<IWorkload> CreateConvertBf16ToFp32(const ConvertBf16ToFp32QueueDescriptor& descriptor, 82 const WorkloadInfo& info) const override; 83 84 std::unique_ptr<IWorkload> CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& descriptor, 85 const WorkloadInfo& info) const override; 86 87 std::unique_ptr<IWorkload> CreateConvertFp32ToBf16(const ConvertFp32ToBf16QueueDescriptor& descriptor, 88 const WorkloadInfo& info) const override; 89 90 std::unique_ptr<IWorkload> CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& descriptor, 91 const WorkloadInfo& info) const override; 92 93 std::unique_ptr<IWorkload> CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor, 94 const WorkloadInfo& info) const override; 95 96 std::unique_ptr<IWorkload> CreateDebug(const DebugQueueDescriptor& descriptor, 97 const WorkloadInfo& info) const override; 98 99 std::unique_ptr<IWorkload> CreateDepthToSpace(const DepthToSpaceQueueDescriptor& descriptor, 100 const WorkloadInfo& info) const override; 101 102 std::unique_ptr<IWorkload> CreateDepthwiseConvolution2d(const DepthwiseConvolution2dQueueDescriptor& descriptor, 103 const WorkloadInfo& info) const override; 104 105 std::unique_ptr<IWorkload> CreateDequantize(const DequantizeQueueDescriptor& descriptor, 106 const WorkloadInfo& info) const override; 107 108 std::unique_ptr<IWorkload> CreateDetectionPostProcess(const DetectionPostProcessQueueDescriptor& descriptor, 109 const WorkloadInfo& info) const override; 110 111 std::unique_ptr<IWorkload> CreateDivision(const DivisionQueueDescriptor& descriptor, 112 const WorkloadInfo& info) const override; 113 114 std::unique_ptr<IWorkload> CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& descriptor, 115 const WorkloadInfo& Info) const override; 116 117 ARMNN_DEPRECATED_MSG("Use CreateComparison instead") 118 std::unique_ptr<IWorkload> CreateEqual(const EqualQueueDescriptor& descriptor, 119 const WorkloadInfo& info) const override; 120 121 std::unique_ptr<IWorkload> CreateFill(const FillQueueDescriptor& descriptor, 122 const WorkloadInfo& info) const override; 123 124 std::unique_ptr<IWorkload> CreateFloor(const FloorQueueDescriptor& descriptor, 125 const WorkloadInfo& info) const override; 126 127 std::unique_ptr<IWorkload> CreateFullyConnected(const FullyConnectedQueueDescriptor& descriptor, 128 const WorkloadInfo& info) const override; 129 130 std::unique_ptr<IWorkload> CreateGather(const GatherQueueDescriptor& descriptor, 131 const WorkloadInfo& info) const override; 132 133 ARMNN_DEPRECATED_MSG("Use CreateComparison instead") 134 std::unique_ptr<IWorkload> CreateGreater(const GreaterQueueDescriptor& descriptor, 135 const WorkloadInfo& info) const override; 136 137 std::unique_ptr<IWorkload> CreateInput(const InputQueueDescriptor& descriptor, 138 const WorkloadInfo& info) const override; 139 140 std::unique_ptr<IWorkload> CreateInstanceNormalization(const InstanceNormalizationQueueDescriptor& descriptor, 141 const WorkloadInfo& info) const override; 142 143 std::unique_ptr<IWorkload> CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor, 144 const WorkloadInfo& info) const override; 145 146 std::unique_ptr<IWorkload> CreateLogicalBinary(const LogicalBinaryQueueDescriptor& descriptor, 147 const WorkloadInfo& info) const override; 148 149 std::unique_ptr<IWorkload> CreateLogSoftmax(const LogSoftmaxQueueDescriptor& descriptor, 150 const WorkloadInfo& info) const override; 151 152 std::unique_ptr<IWorkload> CreateLstm(const LstmQueueDescriptor& descriptor, 153 const WorkloadInfo& info) const override; 154 155 std::unique_ptr<IWorkload> CreateMaximum(const MaximumQueueDescriptor& descriptor, 156 const WorkloadInfo& info) const override; 157 158 std::unique_ptr<IWorkload> CreateMean(const MeanQueueDescriptor& descriptor, 159 const WorkloadInfo& Info) const override; 160 161 std::unique_ptr<IWorkload> CreateMemCopy(const MemCopyQueueDescriptor& descriptor, 162 const WorkloadInfo& info) const override; 163 164 std::unique_ptr<IWorkload> CreateMemImport(const MemImportQueueDescriptor& descriptor, 165 const WorkloadInfo& info) const override; 166 167 ARMNN_DEPRECATED_MSG("Use CreateConcat instead") 168 std::unique_ptr<IWorkload> CreateMerger(const MergerQueueDescriptor& descriptor, 169 const WorkloadInfo& info) const override; 170 171 std::unique_ptr<IWorkload> CreateMinimum(const MinimumQueueDescriptor& descriptor, 172 const WorkloadInfo& info) const override; 173 174 std::unique_ptr<IWorkload> CreateMultiplication(const MultiplicationQueueDescriptor& descriptor, 175 const WorkloadInfo& info) const override; 176 177 std::unique_ptr<IWorkload> CreateNormalization(const NormalizationQueueDescriptor& descriptor, 178 const WorkloadInfo& info) const override; 179 180 std::unique_ptr<IWorkload> CreateOutput(const OutputQueueDescriptor& descriptor, 181 const WorkloadInfo& info) const override; 182 183 std::unique_ptr<IWorkload> CreatePad(const PadQueueDescriptor& descriptor, 184 const WorkloadInfo& info) const override; 185 186 std::unique_ptr<IWorkload> CreatePermute(const PermuteQueueDescriptor& descriptor, 187 const WorkloadInfo& info) const override; 188 189 std::unique_ptr<IWorkload> CreatePooling2d(const Pooling2dQueueDescriptor& descriptor, 190 const WorkloadInfo& info) const override; 191 192 std::unique_ptr<IWorkload> CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor, 193 const WorkloadInfo& info) const override; 194 195 std::unique_ptr<IWorkload> CreatePrelu(const PreluQueueDescriptor& descriptor, 196 const WorkloadInfo& info) const override; 197 198 std::unique_ptr<IWorkload> CreateQLstm(const QLstmQueueDescriptor& descriptor, 199 const WorkloadInfo& info) const override; 200 201 std::unique_ptr<IWorkload> CreateQuantize(const QuantizeQueueDescriptor& descriptor, 202 const WorkloadInfo& info) const override; 203 204 std::unique_ptr<IWorkload> CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& descriptor, 205 const WorkloadInfo& info) const override; 206 207 std::unique_ptr<IWorkload> CreateReshape(const ReshapeQueueDescriptor& descriptor, 208 const WorkloadInfo& info) const override; 209 210 std::unique_ptr<IWorkload> CreateResize(const ResizeQueueDescriptor& descriptor, 211 const WorkloadInfo& info) const override; 212 213 ARMNN_DEPRECATED_MSG("Use CreateResize instead") 214 std::unique_ptr<IWorkload> CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor, 215 const WorkloadInfo& info) const override; 216 217 ARMNN_DEPRECATED_MSG("Use CreateElementwiseUnary instead") 218 std::unique_ptr<IWorkload> CreateRsqrt(const RsqrtQueueDescriptor& descriptor, 219 const WorkloadInfo& info) const override; 220 221 std::unique_ptr<IWorkload> CreateSlice(const SliceQueueDescriptor& descriptor, 222 const WorkloadInfo& info) const override; 223 224 std::unique_ptr<IWorkload> CreateSoftmax(const SoftmaxQueueDescriptor& descriptor, 225 const WorkloadInfo& info) const override; 226 227 std::unique_ptr<IWorkload> CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor, 228 const WorkloadInfo& info) const override; 229 230 std::unique_ptr<IWorkload> CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& descriptor, 231 const WorkloadInfo& info) const override; 232 233 std::unique_ptr<IWorkload> CreateSplitter(const SplitterQueueDescriptor& descriptor, 234 const WorkloadInfo& info) const override; 235 236 std::unique_ptr<IWorkload> CreateStack(const StackQueueDescriptor& descriptor, 237 const WorkloadInfo& info) const override; 238 239 std::unique_ptr<IWorkload> CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor, 240 const WorkloadInfo& info) const override; 241 242 std::unique_ptr<IWorkload> CreateSubtraction(const SubtractionQueueDescriptor& descriptor, 243 const WorkloadInfo& info) const override; 244 245 std::unique_ptr<IWorkload> CreateTranspose(const TransposeQueueDescriptor& descriptor, 246 const WorkloadInfo& info) const override; 247 248 std::unique_ptr<IWorkload> CreateTransposeConvolution2d(const TransposeConvolution2dQueueDescriptor& descriptor, 249 const WorkloadInfo& info) const override; 250 251 private: 252 mutable std::shared_ptr<NeonMemoryManager> m_MemoryManager; 253 const IBackendInternal::IBackendSpecificModelContextPtr m_ModelContextPtr; 254 }; 255 256 } // namespace armnn 257