1 // 2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include <armnn/IRuntime.hpp> 8 #include <armnn/Optional.hpp> 9 10 #include <armnn/backends/IBackendInternal.hpp> 11 12 #include <backendsCommon/WorkloadFactoryBase.hpp> 13 #include <aclCommon/BaseMemoryManager.hpp> 14 15 #include <arm_compute/core/CL/CLCompileContext.h> 16 17 namespace armnn 18 { 19 20 // ARM Compute OpenCL workload factory. 21 class ClWorkloadFactory : public WorkloadFactoryBase 22 { 23 public: 24 ClWorkloadFactory(const std::shared_ptr<ClMemoryManager>& memoryManager); 25 26 ClWorkloadFactory(const std::shared_ptr<ClMemoryManager>& memoryManager, 27 const IBackendInternal::IBackendSpecificModelContextPtr& modelContextPtr); 28 29 void AfterWorkloadsCreated() override; 30 31 const BackendId& GetBackendId() const override; 32 33 static bool IsLayerSupported(const Layer& layer, 34 Optional<DataType> dataType, 35 std::string& outReasonIfUnsupported); 36 37 static bool IsLayerSupported(const IConnectableLayer& layer, 38 Optional<DataType> dataType, 39 std::string& outReasonIfUnsupported, 40 const ModelOptions& modelOptions); 41 SupportsSubTensors() const42 bool SupportsSubTensors() const override { return true; } 43 44 ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateSubTensorHandle instead") 45 std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent, 46 TensorShape const& subTensorShape, 47 unsigned int const* subTensorOrigin) const override; 48 49 ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead") 50 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, 51 const bool IsMemoryManaged = true) const override; 52 53 ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead") 54 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, 55 DataLayout dataLayout, 56 const bool IsMemoryManaged = true) const override; 57 58 ARMNN_DEPRECATED_MSG("Use CreateElementwiseUnary instead") 59 std::unique_ptr<IWorkload> CreateAbs(const AbsQueueDescriptor& descriptor, 60 const WorkloadInfo& info) const override; 61 62 std::unique_ptr<IWorkload> CreateActivation(const ActivationQueueDescriptor& descriptor, 63 const WorkloadInfo& info) const override; 64 65 std::unique_ptr<IWorkload> CreateAddition(const AdditionQueueDescriptor& descriptor, 66 const WorkloadInfo& info) const override; 67 68 std::unique_ptr<IWorkload> CreateArgMinMax(const ArgMinMaxQueueDescriptor& descriptor, 69 const WorkloadInfo& info) const override; 70 71 std::unique_ptr<IWorkload> CreateBatchNormalization(const BatchNormalizationQueueDescriptor& descriptor, 72 const WorkloadInfo& info) const override; 73 74 std::unique_ptr<IWorkload> CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor, 75 const WorkloadInfo& info) const override; 76 77 std::unique_ptr<IWorkload> CreateComparison(const ComparisonQueueDescriptor& descriptor, 78 const WorkloadInfo& info) const override; 79 80 std::unique_ptr<IWorkload> CreateConcat(const ConcatQueueDescriptor& descriptor, 81 const WorkloadInfo& info) const override; 82 83 std::unique_ptr<IWorkload> CreateConstant(const ConstantQueueDescriptor& descriptor, 84 const WorkloadInfo& info) const override; 85 86 std::unique_ptr<IWorkload> CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& descriptor, 87 const WorkloadInfo& info) const override; 88 89 std::unique_ptr<IWorkload> CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& descriptor, 90 const WorkloadInfo& info) const override; 91 92 std::unique_ptr<IWorkload> CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor, 93 const WorkloadInfo& info) const override; 94 95 std::unique_ptr<IWorkload> CreateDebug(const DebugQueueDescriptor& descriptor, 96 const WorkloadInfo& info) const override; 97 98 std::unique_ptr<IWorkload> CreateDepthToSpace(const DepthToSpaceQueueDescriptor& descriptor, 99 const WorkloadInfo& info) const override; 100 101 std::unique_ptr<IWorkload> CreateDepthwiseConvolution2d(const DepthwiseConvolution2dQueueDescriptor& descriptor, 102 const WorkloadInfo& info) const override; 103 104 std::unique_ptr<IWorkload> CreateDequantize(const DequantizeQueueDescriptor& descriptor, 105 const WorkloadInfo& info) const override; 106 107 std::unique_ptr<IWorkload> CreateDetectionPostProcess(const DetectionPostProcessQueueDescriptor& descriptor, 108 const WorkloadInfo& info) const override; 109 110 std::unique_ptr<IWorkload> CreateDivision(const DivisionQueueDescriptor& descriptor, 111 const WorkloadInfo& info) const override; 112 113 std::unique_ptr<IWorkload> CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& descriptor, 114 const WorkloadInfo& info) const override; 115 116 ARMNN_DEPRECATED_MSG("Use CreateComparison instead") 117 std::unique_ptr<IWorkload> CreateEqual(const EqualQueueDescriptor& descriptor, 118 const WorkloadInfo& info) const override; 119 120 std::unique_ptr<IWorkload> CreateFill(const FillQueueDescriptor& descriptor, 121 const WorkloadInfo& info) const override; 122 123 std::unique_ptr<IWorkload> CreateFloor(const FloorQueueDescriptor& descriptor, 124 const WorkloadInfo& info) const override; 125 126 std::unique_ptr<IWorkload> CreateFullyConnected(const FullyConnectedQueueDescriptor& descriptor, 127 const WorkloadInfo& info) const override; 128 129 std::unique_ptr<IWorkload> CreateGather(const GatherQueueDescriptor& descriptor, 130 const WorkloadInfo& info) const override; 131 132 ARMNN_DEPRECATED_MSG("Use CreateComparison instead") 133 std::unique_ptr<IWorkload> CreateGreater(const GreaterQueueDescriptor& descriptor, 134 const WorkloadInfo& info) const override; 135 136 std::unique_ptr<IWorkload> CreateInput(const InputQueueDescriptor& descriptor, 137 const WorkloadInfo& info) const override; 138 139 std::unique_ptr<IWorkload> CreateInstanceNormalization(const InstanceNormalizationQueueDescriptor& descriptor, 140 const WorkloadInfo& info) const override; 141 142 std::unique_ptr<IWorkload> CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor, 143 const WorkloadInfo& info) const override; 144 145 std::unique_ptr<IWorkload> CreateLogicalBinary(const LogicalBinaryQueueDescriptor& descriptor, 146 const WorkloadInfo& info) const override; 147 148 std::unique_ptr<IWorkload> CreateLogSoftmax(const LogSoftmaxQueueDescriptor& descriptor, 149 const WorkloadInfo& info) const override; 150 151 std::unique_ptr<IWorkload> CreateLstm(const LstmQueueDescriptor& descriptor, 152 const WorkloadInfo& info) const override; 153 154 std::unique_ptr<IWorkload> CreateMaximum(const MaximumQueueDescriptor& descriptor, 155 const WorkloadInfo& info) const override; 156 157 std::unique_ptr<IWorkload> CreateMean(const MeanQueueDescriptor& descriptor, 158 const WorkloadInfo& Info) const override; 159 160 std::unique_ptr<IWorkload> CreateMemCopy(const MemCopyQueueDescriptor& descriptor, 161 const WorkloadInfo& info) const override; 162 163 std::unique_ptr<IWorkload> CreateMemImport(const MemImportQueueDescriptor& descriptor, 164 const WorkloadInfo& info) const override; 165 166 ARMNN_DEPRECATED_MSG("Use CreateConcat instead") 167 std::unique_ptr<IWorkload> CreateMerger(const MergerQueueDescriptor& descriptor, 168 const WorkloadInfo& info) const override; 169 170 std::unique_ptr<IWorkload> CreateMinimum(const MinimumQueueDescriptor& descriptor, 171 const WorkloadInfo& info) const override; 172 173 std::unique_ptr<IWorkload> CreateMultiplication(const MultiplicationQueueDescriptor& descriptor, 174 const WorkloadInfo& info) const override; 175 176 std::unique_ptr<IWorkload> CreateNormalization(const NormalizationQueueDescriptor& descriptor, 177 const WorkloadInfo& info) const override; 178 179 std::unique_ptr<IWorkload> CreateOutput(const OutputQueueDescriptor& descriptor, 180 const WorkloadInfo& info) const override; 181 182 std::unique_ptr<IWorkload> CreatePad(const PadQueueDescriptor& descriptor, 183 const WorkloadInfo& info) const override; 184 185 std::unique_ptr<IWorkload> CreatePermute(const PermuteQueueDescriptor& descriptor, 186 const WorkloadInfo& info) const override; 187 188 std::unique_ptr<IWorkload> CreatePooling2d(const Pooling2dQueueDescriptor& descriptor, 189 const WorkloadInfo& info) const override; 190 191 std::unique_ptr<IWorkload> CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor, 192 const WorkloadInfo& info) const override; 193 194 std::unique_ptr<IWorkload> CreatePrelu(const PreluQueueDescriptor& descriptor, 195 const WorkloadInfo& info) const override; 196 197 std::unique_ptr<IWorkload> CreateQLstm(const QLstmQueueDescriptor& descriptor, 198 const WorkloadInfo& info) const override; 199 200 std::unique_ptr<IWorkload> CreateQuantize(const QuantizeQueueDescriptor& descriptor, 201 const WorkloadInfo& info) const override; 202 203 std::unique_ptr<IWorkload> CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& descriptor, 204 const WorkloadInfo& info) const override; 205 206 std::unique_ptr<IWorkload> CreateReshape(const ReshapeQueueDescriptor& descriptor, 207 const WorkloadInfo& info) const override; 208 209 std::unique_ptr<IWorkload> CreateResize(const ResizeQueueDescriptor& descriptor, 210 const WorkloadInfo& info) const override; 211 212 ARMNN_DEPRECATED_MSG("Use CreateResize instead") 213 std::unique_ptr<IWorkload> CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor, 214 const WorkloadInfo& info) const override; 215 216 ARMNN_DEPRECATED_MSG("Use CreateElementwiseUnary instead") 217 std::unique_ptr<IWorkload> CreateRsqrt(const RsqrtQueueDescriptor& descriptor, 218 const WorkloadInfo& info) const override; 219 220 std::unique_ptr<IWorkload> CreateSlice(const SliceQueueDescriptor& descriptor, 221 const WorkloadInfo& info) const override; 222 223 std::unique_ptr<IWorkload> CreateSoftmax(const SoftmaxQueueDescriptor& descriptor, 224 const WorkloadInfo& info) const override; 225 226 std::unique_ptr<IWorkload> CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor, 227 const WorkloadInfo& info) const override; 228 229 std::unique_ptr<IWorkload> CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& descriptor, 230 const WorkloadInfo& info) const override; 231 232 std::unique_ptr<IWorkload> CreateSplitter(const SplitterQueueDescriptor& descriptor, 233 const WorkloadInfo& info) const override; 234 235 std::unique_ptr<IWorkload> CreateStack(const StackQueueDescriptor& descriptor, 236 const WorkloadInfo& info) const override; 237 238 std::unique_ptr<IWorkload> CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor, 239 const WorkloadInfo& info) const override; 240 241 std::unique_ptr<IWorkload> CreateSubtraction(const SubtractionQueueDescriptor& descriptor, 242 const WorkloadInfo& info) const override; 243 244 std::unique_ptr<IWorkload> CreateTranspose(const TransposeQueueDescriptor& descriptor, 245 const WorkloadInfo& info) const override; 246 247 std::unique_ptr<IWorkload> CreateTransposeConvolution2d(const TransposeConvolution2dQueueDescriptor& descriptor, 248 const WorkloadInfo& info) const override; 249 250 private: 251 template<typename FloatWorkload, typename Uint8Workload, typename QueueDescriptorType, typename... Args> 252 static std::unique_ptr<IWorkload> MakeWorkload(const QueueDescriptorType& descriptor, 253 const WorkloadInfo& info, 254 Args&&... args); 255 256 template <typename Workload, typename QueueDescriptorType, typename... Args> 257 static std::unique_ptr<IWorkload> MakeWorkload(const QueueDescriptorType& descriptor, 258 const WorkloadInfo& info, 259 Args&&... args); 260 261 void InitializeCLCompileContext(); 262 263 mutable std::shared_ptr<ClMemoryManager> m_MemoryManager; 264 const IBackendInternal::IBackendSpecificModelContextPtr m_ModelContextPtr; 265 arm_compute::CLCompileContext m_CLCompileContext; 266 }; 267 268 } // namespace armnn 269