• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
#include "WorkloadData.hpp"
#include "WorkloadInfo.hpp"

#include <armnn/backends/IWorkload.hpp>
#include <Profiling.hpp>
#include <ProfilingService.hpp>

#include <algorithm>
#include <initializer_list>
#include <iterator>
15 
16 namespace armnn
17 {
18 
19 // NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template
20 // in the various workload factories.
21 // There should never be an instantiation of a NullWorkload.
22 class NullWorkload : public IWorkload
23 {
24     NullWorkload()=delete;
25 };
26 
27 template <typename QueueDescriptor>
28 class BaseWorkload : public IWorkload
29 {
30 public:
31 
BaseWorkload(const QueueDescriptor & descriptor,const WorkloadInfo & info)32     BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
33         : m_Data(descriptor),
34           m_Guid(profiling::ProfilingService::GetNextGuid())
35     {
36         m_Data.Validate(info);
37     }
38 
PostAllocationConfigure()39     void PostAllocationConfigure() override {}
40 
GetData() const41     const QueueDescriptor& GetData() const { return m_Data; }
42 
GetGuid() const43     profiling::ProfilingGuid GetGuid() const final { return m_Guid; }
44 
45 protected:
46     const QueueDescriptor m_Data;
47     const profiling::ProfilingGuid m_Guid;
48 };
49 
50 // TypedWorkload used
51 template <typename QueueDescriptor, armnn::DataType... DataTypes>
52 class TypedWorkload : public BaseWorkload<QueueDescriptor>
53 {
54 public:
55 
TypedWorkload(const QueueDescriptor & descriptor,const WorkloadInfo & info)56     TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
57         : BaseWorkload<QueueDescriptor>(descriptor, info)
58     {
59         std::vector<armnn::DataType> dataTypes = {DataTypes...};
60         armnn::DataType expectedInputType;
61 
62         if (!info.m_InputTensorInfos.empty())
63         {
64             expectedInputType = info.m_InputTensorInfos.front().GetDataType();
65 
66             if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end())
67             {
68                 ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
69             }
70             ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_InputTensorInfos.begin()),
71                                          info.m_InputTensorInfos.end(),
72                                          [&](auto it){
73                                              return it.GetDataType() == expectedInputType;
74                                          }),
75                              "Trying to create workload with incorrect type");
76         }
77         armnn::DataType expectedOutputType;
78 
79         if (!info.m_OutputTensorInfos.empty())
80         {
81             expectedOutputType = info.m_OutputTensorInfos.front().GetDataType();
82 
83             if (!info.m_InputTensorInfos.empty())
84             {
85                 if (expectedOutputType != expectedInputType)
86                 {
87                     ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
88                 }
89             }
90             else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end())
91             {
92                 ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
93             }
94             ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_OutputTensorInfos.begin()),
95                                          info.m_OutputTensorInfos.end(),
96                                          [&](auto it){
97                                              return it.GetDataType() == expectedOutputType;
98                                          }),
99                              "Trying to create workload with incorrect type");
100         }
101     }
102 };
103 
104 template <typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
105 class MultiTypedWorkload : public BaseWorkload<QueueDescriptor>
106 {
107 public:
108 
MultiTypedWorkload(const QueueDescriptor & descriptor,const WorkloadInfo & info)109     MultiTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
110         : BaseWorkload<QueueDescriptor>(descriptor, info)
111     {
112         ARMNN_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(),
113                                      info.m_InputTensorInfos.end(),
114                                      [&](auto it){
115                                          return it.GetDataType() == InputDataType;
116                                      }),
117                          "Trying to create workload with incorrect type");
118 
119         ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
120                                      info.m_OutputTensorInfos.end(),
121                                      [&](auto it){
122                                          return it.GetDataType() == OutputDataType;
123                                      }),
124                          "Trying to create workload with incorrect type");
125     }
126 };
127 
128 // FirstInputTypedWorkload used to check type of the first input
129 template <typename QueueDescriptor, armnn::DataType DataType>
130 class FirstInputTypedWorkload : public BaseWorkload<QueueDescriptor>
131 {
132 public:
133 
FirstInputTypedWorkload(const QueueDescriptor & descriptor,const WorkloadInfo & info)134     FirstInputTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
135         : BaseWorkload<QueueDescriptor>(descriptor, info)
136     {
137         if (!info.m_InputTensorInfos.empty())
138         {
139             ARMNN_ASSERT_MSG(info.m_InputTensorInfos.front().GetDataType() == DataType,
140                                  "Trying to create workload with incorrect type");
141         }
142 
143         ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
144                                      info.m_OutputTensorInfos.end(),
145                                      [&](auto it){
146                                          return it.GetDataType() == DataType;
147                                      }),
148                          "Trying to create workload with incorrect type");
149     }
150 };
151 
152 template <typename QueueDescriptor>
153 using FloatWorkload = TypedWorkload<QueueDescriptor,
154                                     armnn::DataType::Float16,
155                                     armnn::DataType::Float32>;
156 
157 template <typename QueueDescriptor>
158 using Float32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Float32>;
159 
160 template <typename QueueDescriptor>
161 using Uint8Workload = TypedWorkload<QueueDescriptor, armnn::DataType::QAsymmU8>;
162 
163 template <typename QueueDescriptor>
164 using Int32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Signed32>;
165 
166 template <typename QueueDescriptor>
167 using BooleanWorkload = TypedWorkload<QueueDescriptor, armnn::DataType::Boolean>;
168 
169 template <typename QueueDescriptor>
170 using BaseFloat32ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
171                                                          armnn::DataType::Float32,
172                                                          armnn::DataType::Boolean>;
173 
174 template <typename QueueDescriptor>
175 using BaseUint8ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
176                                                        armnn::DataType::QAsymmU8,
177                                                        armnn::DataType::Boolean>;
178 
179 template <typename QueueDescriptor>
180 using BFloat16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
181                                                      armnn::DataType::BFloat16,
182                                                      armnn::DataType::Float32>;
183 
184 template <typename QueueDescriptor>
185 using Float32ToBFloat16Workload = MultiTypedWorkload<QueueDescriptor,
186                                                      armnn::DataType::Float32,
187                                                      armnn::DataType::BFloat16>;
188 
189 template <typename QueueDescriptor>
190 using Float16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
191                                                     armnn::DataType::Float16,
192                                                     armnn::DataType::Float32>;
193 
194 template <typename QueueDescriptor>
195 using Float32ToFloat16Workload = MultiTypedWorkload<QueueDescriptor,
196                                                     armnn::DataType::Float32,
197                                                     armnn::DataType::Float16>;
198 
199 template <typename QueueDescriptor>
200 using Uint8ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
201                                                   armnn::DataType::QAsymmU8,
202                                                   armnn::DataType::Float32>;
203 
204 } //namespace armnn
205