1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <armnn/TypesUtils.hpp> 9 10 #include <backendsCommon/Workload.hpp> 11 12 namespace armnn 13 { 14 15 template <armnn::DataType DataType> 16 class RefDebugWorkload : public TypedWorkload<DebugQueueDescriptor, DataType> 17 { 18 public: RefDebugWorkload(const DebugQueueDescriptor & descriptor,const WorkloadInfo & info)19 RefDebugWorkload(const DebugQueueDescriptor& descriptor, const WorkloadInfo& info) 20 : TypedWorkload<DebugQueueDescriptor, DataType>(descriptor, info) 21 , m_Callback(nullptr) {} 22 GetName()23 static const std::string& GetName() 24 { 25 static const std::string name = std::string("RefDebug") + GetDataTypeName(DataType) + "Workload"; 26 return name; 27 } 28 29 using TypedWorkload<DebugQueueDescriptor, DataType>::m_Data; 30 using TypedWorkload<DebugQueueDescriptor, DataType>::TypedWorkload; 31 32 void Execute() const override; 33 34 void RegisterDebugCallback(const DebugCallbackFunction& func) override; 35 36 private: 37 DebugCallbackFunction m_Callback; 38 }; 39 40 using RefDebugBFloat16Workload = RefDebugWorkload<DataType::BFloat16>; 41 using RefDebugFloat16Workload = RefDebugWorkload<DataType::Float16>; 42 using RefDebugFloat32Workload = RefDebugWorkload<DataType::Float32>; 43 using RefDebugQAsymmU8Workload = RefDebugWorkload<DataType::QAsymmU8>; 44 using RefDebugQAsymmS8Workload = RefDebugWorkload<DataType::QAsymmS8>; 45 using RefDebugQSymmS16Workload = RefDebugWorkload<DataType::QSymmS16>; 46 using RefDebugQSymmS8Workload = RefDebugWorkload<DataType::QSymmS8>; 47 using RefDebugSigned32Workload = RefDebugWorkload<DataType::Signed32>; 48 49 } // namespace armnn 50