• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/TypesUtils.hpp>
9 
10 #include <backendsCommon/Workload.hpp>
11 
12 namespace armnn
13 {
14 
15 template <armnn::DataType DataType>
16 class RefDebugWorkload : public TypedWorkload<DebugQueueDescriptor, DataType>
17 {
18 public:
RefDebugWorkload(const DebugQueueDescriptor & descriptor,const WorkloadInfo & info)19     RefDebugWorkload(const DebugQueueDescriptor& descriptor, const WorkloadInfo& info)
20     : TypedWorkload<DebugQueueDescriptor, DataType>(descriptor, info)
21     , m_Callback(nullptr) {}
22 
GetName()23     static const std::string& GetName()
24     {
25         static const std::string name = std::string("RefDebug") + GetDataTypeName(DataType) + "Workload";
26         return name;
27     }
28 
29     using TypedWorkload<DebugQueueDescriptor, DataType>::m_Data;
30     using TypedWorkload<DebugQueueDescriptor, DataType>::TypedWorkload;
31 
32     void Execute() const override;
33 
34     void RegisterDebugCallback(const DebugCallbackFunction& func) override;
35 
36 private:
37     DebugCallbackFunction m_Callback;
38 };
39 
40 using RefDebugBFloat16Workload   = RefDebugWorkload<DataType::BFloat16>;
41 using RefDebugFloat16Workload   = RefDebugWorkload<DataType::Float16>;
42 using RefDebugFloat32Workload   = RefDebugWorkload<DataType::Float32>;
43 using RefDebugQAsymmU8Workload  = RefDebugWorkload<DataType::QAsymmU8>;
44 using RefDebugQAsymmS8Workload  = RefDebugWorkload<DataType::QAsymmS8>;
45 using RefDebugQSymmS16Workload  = RefDebugWorkload<DataType::QSymmS16>;
46 using RefDebugQSymmS8Workload   = RefDebugWorkload<DataType::QSymmS8>;
47 using RefDebugSigned32Workload  = RefDebugWorkload<DataType::Signed32>;
48 
49 } // namespace armnn
50