1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include "TestLayerVisitor.hpp"
8 #include "LayersFwd.hpp"
9 #include <armnn/Descriptors.hpp>
10 #include <armnn/LstmParams.hpp>
11 #include <armnn/QuantizedLstmParams.hpp>
12 #include <armnn/utility/PolymorphicDowncast.hpp>
13 #include <armnn/backends/TensorHandle.hpp>
14
15 #include <doctest/doctest.h>
16
17 namespace armnn
18 {
19
20 class TestConvolution2dLayerVisitor : public TestLayerVisitor
21 {
22 public:
TestConvolution2dLayerVisitor(const Convolution2dDescriptor & convolution2dDescriptor,const char * name=nullptr)23 explicit TestConvolution2dLayerVisitor(const Convolution2dDescriptor& convolution2dDescriptor,
24 const char* name = nullptr)
25 : TestLayerVisitor(name)
26 , m_Descriptor(convolution2dDescriptor)
27 {}
28
~TestConvolution2dLayerVisitor()29 virtual ~TestConvolution2dLayerVisitor() {}
30
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)31 void ExecuteStrategy(const armnn::IConnectableLayer* layer,
32 const armnn::BaseDescriptor& descriptor,
33 const std::vector<armnn::ConstTensor>& constants,
34 const char* name,
35 const armnn::LayerBindingId id = 0) override
36 {
37 armnn::IgnoreUnused(descriptor, constants, id);
38 switch (layer->GetType())
39 {
40 case armnn::LayerType::Convolution2d:
41 {
42 CheckLayerPointer(layer);
43 CheckLayerName(name);
44 CheckDescriptor(static_cast<const armnn::Convolution2dDescriptor&>(descriptor));
45 break;
46 }
47 default:
48 {
49 m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
50 }
51 }
52 }
53
54 protected:
55 void CheckDescriptor(const Convolution2dDescriptor& convolution2dDescriptor);
56
57 private:
58 Convolution2dDescriptor m_Descriptor;
59 };
60
61 class TestDepthwiseConvolution2dLayerVisitor : public TestLayerVisitor
62 {
63 public:
TestDepthwiseConvolution2dLayerVisitor(const DepthwiseConvolution2dDescriptor & descriptor,const char * name=nullptr)64 explicit TestDepthwiseConvolution2dLayerVisitor(const DepthwiseConvolution2dDescriptor& descriptor,
65 const char* name = nullptr)
66 : TestLayerVisitor(name)
67 , m_Descriptor(descriptor)
68 {}
69
~TestDepthwiseConvolution2dLayerVisitor()70 virtual ~TestDepthwiseConvolution2dLayerVisitor() {}
71
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)72 void ExecuteStrategy(const armnn::IConnectableLayer* layer,
73 const armnn::BaseDescriptor& descriptor,
74 const std::vector<armnn::ConstTensor>& constants,
75 const char* name,
76 const armnn::LayerBindingId id = 0) override
77 {
78 armnn::IgnoreUnused(descriptor, constants, id);
79 switch (layer->GetType())
80 {
81 case armnn::LayerType::DepthwiseConvolution2d:
82 {
83 CheckLayerPointer(layer);
84 CheckLayerName(name);
85 CheckDescriptor(static_cast<const armnn::DepthwiseConvolution2dDescriptor&>(descriptor));
86 break;
87 }
88 default:
89 {
90 m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
91 }
92 }
93 }
94
95 protected:
96 void CheckDescriptor(const DepthwiseConvolution2dDescriptor& convolution2dDescriptor);
97
98 private:
99 DepthwiseConvolution2dDescriptor m_Descriptor;
100 };
101
102 class TestFullyConnectedLayerVistor : public TestLayerVisitor
103 {
104 public:
TestFullyConnectedLayerVistor(const FullyConnectedDescriptor & descriptor,const char * name=nullptr)105 explicit TestFullyConnectedLayerVistor(const FullyConnectedDescriptor& descriptor,
106 const char* name = nullptr)
107 : TestLayerVisitor(name)
108 , m_Descriptor(descriptor)
109 {}
110
~TestFullyConnectedLayerVistor()111 virtual ~TestFullyConnectedLayerVistor() {}
112
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)113 void ExecuteStrategy(const armnn::IConnectableLayer* layer,
114 const armnn::BaseDescriptor& descriptor,
115 const std::vector<armnn::ConstTensor>& constants,
116 const char* name,
117 const armnn::LayerBindingId id = 0) override
118 {
119 armnn::IgnoreUnused(descriptor, constants, id);
120 switch (layer->GetType())
121 {
122 case armnn::LayerType::FullyConnected:
123 {
124 CheckLayerPointer(layer);
125 CheckLayerName(name);
126 CheckDescriptor(static_cast<const armnn::FullyConnectedDescriptor&>(descriptor));
127 break;
128 }
129 default:
130 {
131 m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
132 }
133 }
134 }
135
136 protected:
137 void CheckDescriptor(const FullyConnectedDescriptor& descriptor);
138 private:
139 FullyConnectedDescriptor m_Descriptor;
140 };
141
142 class TestBatchNormalizationLayerVisitor : public TestLayerVisitor
143 {
144 public:
TestBatchNormalizationLayerVisitor(const BatchNormalizationDescriptor & descriptor,const ConstTensor & mean,const ConstTensor & variance,const ConstTensor & beta,const ConstTensor & gamma,const char * name=nullptr)145 TestBatchNormalizationLayerVisitor(const BatchNormalizationDescriptor& descriptor,
146 const ConstTensor& mean,
147 const ConstTensor& variance,
148 const ConstTensor& beta,
149 const ConstTensor& gamma,
150 const char* name = nullptr)
151 : TestLayerVisitor(name)
152 , m_Descriptor(descriptor)
153 , m_Mean(mean)
154 , m_Variance(variance)
155 , m_Beta(beta)
156 , m_Gamma(gamma)
157 {}
158
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)159 void ExecuteStrategy(const armnn::IConnectableLayer* layer,
160 const armnn::BaseDescriptor& descriptor,
161 const std::vector<armnn::ConstTensor>& constants,
162 const char* name,
163 const armnn::LayerBindingId id = 0) override
164 {
165 armnn::IgnoreUnused(descriptor, constants, id);
166 switch (layer->GetType())
167 {
168 case armnn::LayerType::BatchNormalization:
169 {
170 CheckLayerPointer(layer);
171 CheckLayerName(name);
172 CheckDescriptor(static_cast<const armnn::BatchNormalizationDescriptor&>(descriptor));
173 CheckConstTensors(m_Mean, constants[0]);
174 CheckConstTensors(m_Variance, constants[1]);
175 CheckConstTensors(m_Beta, constants[2]);
176 CheckConstTensors(m_Gamma, constants[3]);
177 break;
178 }
179 default:
180 {
181 m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
182 }
183 }
184 }
185
186 protected:
187 void CheckDescriptor(const BatchNormalizationDescriptor& descriptor);
188
189 private:
190 BatchNormalizationDescriptor m_Descriptor;
191 ConstTensor m_Mean;
192 ConstTensor m_Variance;
193 ConstTensor m_Beta;
194 ConstTensor m_Gamma;
195 };
196
197 class TestConstantLayerVisitor : public TestLayerVisitor
198 {
199 public:
TestConstantLayerVisitor(const ConstTensor & input,const char * name=nullptr)200 explicit TestConstantLayerVisitor(const ConstTensor& input,
201 const char* name = nullptr)
202 : TestLayerVisitor(name)
203 , m_Input(input)
204 {}
205
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)206 void ExecuteStrategy(const armnn::IConnectableLayer* layer,
207 const armnn::BaseDescriptor& descriptor,
208 const std::vector<armnn::ConstTensor>& constants,
209 const char* name,
210 const armnn::LayerBindingId id = 0) override
211 {
212 armnn::IgnoreUnused(descriptor, constants, id);
213 switch (layer->GetType())
214 {
215 case armnn::LayerType::Constant:
216 {
217 CheckLayerPointer(layer);
218 CheckLayerName(name);
219 CheckConstTensors(m_Input, constants[0]);
220 break;
221 }
222 default:
223 {
224 m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
225 }
226 }
227 }
228
229 private:
230 ConstTensor m_Input;
231 };
232
233 // Used to supply utility functions to the actual lstm test visitors
234 class LstmVisitor : public TestLayerVisitor
235 {
236 public:
LstmVisitor(const LstmInputParams & params,const char * name=nullptr)237 explicit LstmVisitor(const LstmInputParams& params,
238 const char* name = nullptr)
239 : TestLayerVisitor(name)
240 , m_InputParams(params) {}
241
242 protected:
243 template<typename LayerType>
244 void CheckInputParameters(const LayerType* layer, const LstmInputParams& inputParams);
245
246 LstmInputParams m_InputParams;
247 };
248
249 template<typename LayerType>
CheckInputParameters(const LayerType * layer,const LstmInputParams & inputParams)250 void LstmVisitor::CheckInputParameters(const LayerType* layer, const LstmInputParams& inputParams)
251 {
252 CheckConstTensorPtrs("OutputGateBias",
253 inputParams.m_OutputGateBias,
254 layer->m_BasicParameters.m_OutputGateBias);
255 CheckConstTensorPtrs("InputToForgetWeights",
256 inputParams.m_InputToForgetWeights,
257 layer->m_BasicParameters.m_InputToForgetWeights);
258 CheckConstTensorPtrs("InputToCellWeights",
259 inputParams.m_InputToCellWeights,
260 layer->m_BasicParameters.m_InputToCellWeights);
261 CheckConstTensorPtrs("InputToOutputWeights",
262 inputParams.m_InputToOutputWeights,
263 layer->m_BasicParameters.m_InputToOutputWeights);
264 CheckConstTensorPtrs("RecurrentToForgetWeights",
265 inputParams.m_RecurrentToForgetWeights,
266 layer->m_BasicParameters.m_RecurrentToForgetWeights);
267 CheckConstTensorPtrs("RecurrentToCellWeights",
268 inputParams.m_RecurrentToCellWeights,
269 layer->m_BasicParameters.m_RecurrentToCellWeights);
270 CheckConstTensorPtrs("RecurrentToOutputWeights",
271 inputParams.m_RecurrentToOutputWeights,
272 layer->m_BasicParameters.m_RecurrentToOutputWeights);
273 CheckConstTensorPtrs("ForgetGateBias",
274 inputParams.m_ForgetGateBias,
275 layer->m_BasicParameters.m_ForgetGateBias);
276 CheckConstTensorPtrs("CellBias",
277 inputParams.m_CellBias,
278 layer->m_BasicParameters.m_CellBias);
279
280 CheckConstTensorPtrs("InputToInputWeights",
281 inputParams.m_InputToInputWeights,
282 layer->m_CifgParameters.m_InputToInputWeights);
283 CheckConstTensorPtrs("RecurrentToInputWeights",
284 inputParams.m_RecurrentToInputWeights,
285 layer->m_CifgParameters.m_RecurrentToInputWeights);
286 CheckConstTensorPtrs("InputGateBias",
287 inputParams.m_InputGateBias,
288 layer->m_CifgParameters.m_InputGateBias);
289
290 CheckConstTensorPtrs("ProjectionBias",
291 inputParams.m_ProjectionBias,
292 layer->m_ProjectionParameters.m_ProjectionBias);
293 CheckConstTensorPtrs("ProjectionWeights",
294 inputParams.m_ProjectionWeights,
295 layer->m_ProjectionParameters.m_ProjectionWeights);
296
297 CheckConstTensorPtrs("CellToInputWeights",
298 inputParams.m_CellToInputWeights,
299 layer->m_PeepholeParameters.m_CellToInputWeights);
300 CheckConstTensorPtrs("CellToForgetWeights",
301 inputParams.m_CellToForgetWeights,
302 layer->m_PeepholeParameters.m_CellToForgetWeights);
303 CheckConstTensorPtrs("CellToOutputWeights",
304 inputParams.m_CellToOutputWeights,
305 layer->m_PeepholeParameters.m_CellToOutputWeights);
306
307 CheckConstTensorPtrs("InputLayerNormWeights",
308 inputParams.m_InputLayerNormWeights,
309 layer->m_LayerNormParameters.m_InputLayerNormWeights);
310 CheckConstTensorPtrs("ForgetLayerNormWeights",
311 inputParams.m_ForgetLayerNormWeights,
312 layer->m_LayerNormParameters.m_ForgetLayerNormWeights);
313 CheckConstTensorPtrs("CellLayerNormWeights",
314 inputParams.m_CellLayerNormWeights,
315 layer->m_LayerNormParameters.m_CellLayerNormWeights);
316 CheckConstTensorPtrs("OutputLayerNormWeights",
317 inputParams.m_OutputLayerNormWeights,
318 layer->m_LayerNormParameters.m_OutputLayerNormWeights);
319 }
320
321 class TestLstmLayerVisitor : public LstmVisitor
322 {
323 public:
TestLstmLayerVisitor(const LstmDescriptor & descriptor,const LstmInputParams & params,const char * name=nullptr)324 explicit TestLstmLayerVisitor(const LstmDescriptor& descriptor,
325 const LstmInputParams& params,
326 const char* name = nullptr)
327 : LstmVisitor(params, name)
328 , m_Descriptor(descriptor)
329 {}
330
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)331 void ExecuteStrategy(const armnn::IConnectableLayer* layer,
332 const armnn::BaseDescriptor& descriptor,
333 const std::vector<armnn::ConstTensor>& constants,
334 const char* name,
335 const armnn::LayerBindingId id = 0) override
336 {
337 armnn::IgnoreUnused(descriptor, constants, id);
338 switch (layer->GetType())
339 {
340 case armnn::LayerType::Lstm:
341 {
342 CheckLayerPointer(layer);
343 CheckLayerName(name);
344 CheckDescriptor(static_cast<const armnn::LstmDescriptor&>(descriptor));
345 CheckInputParameters<const LstmLayer>(PolymorphicDowncast<const LstmLayer*>(layer), m_InputParams);
346 break;
347 }
348 default:
349 {
350 m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
351 }
352 }
353 }
354
355 protected:
356 void CheckDescriptor(const LstmDescriptor& descriptor);
357
358 private:
359 LstmDescriptor m_Descriptor;
360 };
361
362 class TestQLstmLayerVisitor : public LstmVisitor
363 {
364 public:
TestQLstmLayerVisitor(const QLstmDescriptor & descriptor,const LstmInputParams & params,const char * name=nullptr)365 explicit TestQLstmLayerVisitor(const QLstmDescriptor& descriptor,
366 const LstmInputParams& params,
367 const char* name = nullptr)
368 : LstmVisitor(params, name)
369 , m_Descriptor(descriptor)
370 {}
371
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)372 void ExecuteStrategy(const armnn::IConnectableLayer* layer,
373 const armnn::BaseDescriptor& descriptor,
374 const std::vector<armnn::ConstTensor>& constants,
375 const char* name,
376 const armnn::LayerBindingId id = 0) override
377 {
378 armnn::IgnoreUnused(descriptor, constants, id);
379 switch (layer->GetType())
380 {
381 case armnn::LayerType::QLstm:
382 {
383 CheckLayerPointer(layer);
384 CheckLayerName(name);
385 CheckDescriptor(static_cast<const armnn::QLstmDescriptor&>(descriptor));
386 CheckInputParameters<const QLstmLayer>(PolymorphicDowncast<const QLstmLayer*>(layer), m_InputParams);
387 break;
388 }
389 default:
390 {
391 m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
392 }
393 }
394 }
395
396 protected:
397 void CheckDescriptor(const QLstmDescriptor& descriptor);
398
399 private:
400 QLstmDescriptor m_Descriptor;
401 };
402
403
404 class TestQuantizedLstmLayerVisitor : public TestLayerVisitor
405 {
406 public:
TestQuantizedLstmLayerVisitor(const QuantizedLstmInputParams & params,const char * name=nullptr)407 explicit TestQuantizedLstmLayerVisitor(const QuantizedLstmInputParams& params,
408 const char* name = nullptr)
409 : TestLayerVisitor(name)
410 , m_InputParams(params)
411 {}
412
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)413 void ExecuteStrategy(const armnn::IConnectableLayer* layer,
414 const armnn::BaseDescriptor& descriptor,
415 const std::vector<armnn::ConstTensor>& constants,
416 const char* name,
417 const armnn::LayerBindingId id = 0) override
418 {
419 armnn::IgnoreUnused(descriptor, constants, id);
420 switch (layer->GetType())
421 {
422 case armnn::LayerType::QuantizedLstm:
423 {
424 CheckLayerPointer(layer);
425 CheckLayerName(name);
426 CheckInputParameters(m_InputParams);
427 break;
428 }
429 default:
430 {
431 m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
432 }
433 }
434 }
435
436 protected:
437 void CheckInputParameters(const QuantizedLstmInputParams& params);
438
439 private:
440 QuantizedLstmInputParams m_InputParams;
441 };
442
443
444 } // namespace armnn
445