• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "TestLayerVisitor.hpp"
8 #include "LayersFwd.hpp"
9 #include <armnn/Descriptors.hpp>
10 #include <armnn/LstmParams.hpp>
11 #include <armnn/QuantizedLstmParams.hpp>
12 #include <armnn/utility/PolymorphicDowncast.hpp>
13 #include <armnn/backends/TensorHandle.hpp>
14 
15 #include <doctest/doctest.h>
16 
17 namespace armnn
18 {
19 
20 class TestConvolution2dLayerVisitor : public TestLayerVisitor
21 {
22 public:
TestConvolution2dLayerVisitor(const Convolution2dDescriptor & convolution2dDescriptor,const char * name=nullptr)23     explicit TestConvolution2dLayerVisitor(const Convolution2dDescriptor& convolution2dDescriptor,
24                                            const char* name = nullptr)
25         : TestLayerVisitor(name)
26         , m_Descriptor(convolution2dDescriptor)
27     {}
28 
~TestConvolution2dLayerVisitor()29     virtual ~TestConvolution2dLayerVisitor() {}
30 
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)31     void ExecuteStrategy(const armnn::IConnectableLayer* layer,
32                          const armnn::BaseDescriptor& descriptor,
33                          const std::vector<armnn::ConstTensor>& constants,
34                          const char* name,
35                          const armnn::LayerBindingId id = 0) override
36     {
37         armnn::IgnoreUnused(descriptor, constants, id);
38         switch (layer->GetType())
39         {
40             case armnn::LayerType::Convolution2d:
41             {
42                 CheckLayerPointer(layer);
43                 CheckLayerName(name);
44                 CheckDescriptor(static_cast<const armnn::Convolution2dDescriptor&>(descriptor));
45                 break;
46             }
47             default:
48             {
49                 m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
50             }
51         }
52     }
53 
54 protected:
55     void CheckDescriptor(const Convolution2dDescriptor& convolution2dDescriptor);
56 
57 private:
58     Convolution2dDescriptor m_Descriptor;
59 };
60 
61 class TestDepthwiseConvolution2dLayerVisitor : public TestLayerVisitor
62 {
63 public:
TestDepthwiseConvolution2dLayerVisitor(const DepthwiseConvolution2dDescriptor & descriptor,const char * name=nullptr)64     explicit TestDepthwiseConvolution2dLayerVisitor(const DepthwiseConvolution2dDescriptor& descriptor,
65                                                     const char* name = nullptr)
66         : TestLayerVisitor(name)
67         , m_Descriptor(descriptor)
68     {}
69 
~TestDepthwiseConvolution2dLayerVisitor()70     virtual ~TestDepthwiseConvolution2dLayerVisitor() {}
71 
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)72     void ExecuteStrategy(const armnn::IConnectableLayer* layer,
73                          const armnn::BaseDescriptor& descriptor,
74                          const std::vector<armnn::ConstTensor>& constants,
75                          const char* name,
76                          const armnn::LayerBindingId id = 0) override
77     {
78         armnn::IgnoreUnused(descriptor, constants, id);
79         switch (layer->GetType())
80         {
81             case armnn::LayerType::DepthwiseConvolution2d:
82             {
83                 CheckLayerPointer(layer);
84                 CheckLayerName(name);
85                 CheckDescriptor(static_cast<const armnn::DepthwiseConvolution2dDescriptor&>(descriptor));
86                 break;
87             }
88             default:
89             {
90                 m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
91             }
92         }
93     }
94 
95 protected:
96     void CheckDescriptor(const DepthwiseConvolution2dDescriptor& convolution2dDescriptor);
97 
98 private:
99     DepthwiseConvolution2dDescriptor m_Descriptor;
100 };
101 
102 class TestFullyConnectedLayerVistor : public TestLayerVisitor
103 {
104 public:
TestFullyConnectedLayerVistor(const FullyConnectedDescriptor & descriptor,const char * name=nullptr)105     explicit TestFullyConnectedLayerVistor(const FullyConnectedDescriptor& descriptor,
106                                            const char* name = nullptr)
107         : TestLayerVisitor(name)
108         , m_Descriptor(descriptor)
109     {}
110 
~TestFullyConnectedLayerVistor()111     virtual ~TestFullyConnectedLayerVistor() {}
112 
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)113     void ExecuteStrategy(const armnn::IConnectableLayer* layer,
114                          const armnn::BaseDescriptor& descriptor,
115                          const std::vector<armnn::ConstTensor>& constants,
116                          const char* name,
117                          const armnn::LayerBindingId id = 0) override
118     {
119         armnn::IgnoreUnused(descriptor, constants, id);
120         switch (layer->GetType())
121         {
122             case armnn::LayerType::FullyConnected:
123             {
124                 CheckLayerPointer(layer);
125                 CheckLayerName(name);
126                 CheckDescriptor(static_cast<const armnn::FullyConnectedDescriptor&>(descriptor));
127                 break;
128             }
129             default:
130             {
131                 m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
132             }
133         }
134     }
135 
136 protected:
137     void CheckDescriptor(const FullyConnectedDescriptor& descriptor);
138 private:
139     FullyConnectedDescriptor m_Descriptor;
140 };
141 
142 class TestBatchNormalizationLayerVisitor : public TestLayerVisitor
143 {
144 public:
TestBatchNormalizationLayerVisitor(const BatchNormalizationDescriptor & descriptor,const ConstTensor & mean,const ConstTensor & variance,const ConstTensor & beta,const ConstTensor & gamma,const char * name=nullptr)145     TestBatchNormalizationLayerVisitor(const BatchNormalizationDescriptor& descriptor,
146                                        const ConstTensor& mean,
147                                        const ConstTensor& variance,
148                                        const ConstTensor& beta,
149                                        const ConstTensor& gamma,
150                                        const char* name = nullptr)
151         : TestLayerVisitor(name)
152         , m_Descriptor(descriptor)
153         , m_Mean(mean)
154         , m_Variance(variance)
155         , m_Beta(beta)
156         , m_Gamma(gamma)
157     {}
158 
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)159     void ExecuteStrategy(const armnn::IConnectableLayer* layer,
160                          const armnn::BaseDescriptor& descriptor,
161                          const std::vector<armnn::ConstTensor>& constants,
162                          const char* name,
163                          const armnn::LayerBindingId id = 0) override
164     {
165         armnn::IgnoreUnused(descriptor, constants, id);
166         switch (layer->GetType())
167         {
168             case armnn::LayerType::BatchNormalization:
169             {
170                 CheckLayerPointer(layer);
171                 CheckLayerName(name);
172                 CheckDescriptor(static_cast<const armnn::BatchNormalizationDescriptor&>(descriptor));
173                 CheckConstTensors(m_Mean,     constants[0]);
174                 CheckConstTensors(m_Variance, constants[1]);
175                 CheckConstTensors(m_Beta,     constants[2]);
176                 CheckConstTensors(m_Gamma,    constants[3]);
177                 break;
178             }
179             default:
180             {
181                 m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
182             }
183         }
184     }
185 
186 protected:
187     void CheckDescriptor(const BatchNormalizationDescriptor& descriptor);
188 
189 private:
190     BatchNormalizationDescriptor m_Descriptor;
191     ConstTensor m_Mean;
192     ConstTensor m_Variance;
193     ConstTensor m_Beta;
194     ConstTensor m_Gamma;
195 };
196 
197 class TestConstantLayerVisitor : public TestLayerVisitor
198 {
199 public:
TestConstantLayerVisitor(const ConstTensor & input,const char * name=nullptr)200     explicit TestConstantLayerVisitor(const ConstTensor& input,
201                                       const char* name = nullptr)
202         : TestLayerVisitor(name)
203         , m_Input(input)
204     {}
205 
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)206     void ExecuteStrategy(const armnn::IConnectableLayer* layer,
207                          const armnn::BaseDescriptor& descriptor,
208                          const std::vector<armnn::ConstTensor>& constants,
209                          const char* name,
210                          const armnn::LayerBindingId id = 0) override
211     {
212         armnn::IgnoreUnused(descriptor, constants, id);
213         switch (layer->GetType())
214         {
215             case armnn::LayerType::Constant:
216             {
217                 CheckLayerPointer(layer);
218                 CheckLayerName(name);
219                 CheckConstTensors(m_Input, constants[0]);
220                 break;
221             }
222             default:
223             {
224                 m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
225             }
226         }
227     }
228 
229 private:
230     ConstTensor m_Input;
231 };
232 
233 // Used to supply utility functions to the actual lstm test visitors
234 class LstmVisitor : public TestLayerVisitor
235 {
236 public:
LstmVisitor(const LstmInputParams & params,const char * name=nullptr)237     explicit LstmVisitor(const LstmInputParams& params,
238                          const char* name = nullptr)
239          : TestLayerVisitor(name)
240          , m_InputParams(params) {}
241 
242 protected:
243     template<typename LayerType>
244     void CheckInputParameters(const LayerType* layer, const LstmInputParams& inputParams);
245 
246     LstmInputParams m_InputParams;
247 };
248 
249 template<typename LayerType>
CheckInputParameters(const LayerType * layer,const LstmInputParams & inputParams)250 void LstmVisitor::CheckInputParameters(const LayerType* layer, const LstmInputParams& inputParams)
251 {
252     CheckConstTensorPtrs("OutputGateBias",
253                          inputParams.m_OutputGateBias,
254                          layer->m_BasicParameters.m_OutputGateBias);
255     CheckConstTensorPtrs("InputToForgetWeights",
256                          inputParams.m_InputToForgetWeights,
257                          layer->m_BasicParameters.m_InputToForgetWeights);
258     CheckConstTensorPtrs("InputToCellWeights",
259                          inputParams.m_InputToCellWeights,
260                          layer->m_BasicParameters.m_InputToCellWeights);
261     CheckConstTensorPtrs("InputToOutputWeights",
262                          inputParams.m_InputToOutputWeights,
263                          layer->m_BasicParameters.m_InputToOutputWeights);
264     CheckConstTensorPtrs("RecurrentToForgetWeights",
265                          inputParams.m_RecurrentToForgetWeights,
266                          layer->m_BasicParameters.m_RecurrentToForgetWeights);
267     CheckConstTensorPtrs("RecurrentToCellWeights",
268                          inputParams.m_RecurrentToCellWeights,
269                          layer->m_BasicParameters.m_RecurrentToCellWeights);
270     CheckConstTensorPtrs("RecurrentToOutputWeights",
271                          inputParams.m_RecurrentToOutputWeights,
272                          layer->m_BasicParameters.m_RecurrentToOutputWeights);
273     CheckConstTensorPtrs("ForgetGateBias",
274                          inputParams.m_ForgetGateBias,
275                          layer->m_BasicParameters.m_ForgetGateBias);
276     CheckConstTensorPtrs("CellBias",
277                          inputParams.m_CellBias,
278                          layer->m_BasicParameters.m_CellBias);
279 
280     CheckConstTensorPtrs("InputToInputWeights",
281                          inputParams.m_InputToInputWeights,
282                          layer->m_CifgParameters.m_InputToInputWeights);
283     CheckConstTensorPtrs("RecurrentToInputWeights",
284                          inputParams.m_RecurrentToInputWeights,
285                          layer->m_CifgParameters.m_RecurrentToInputWeights);
286     CheckConstTensorPtrs("InputGateBias",
287                          inputParams.m_InputGateBias,
288                          layer->m_CifgParameters.m_InputGateBias);
289 
290     CheckConstTensorPtrs("ProjectionBias",
291                          inputParams.m_ProjectionBias,
292                          layer->m_ProjectionParameters.m_ProjectionBias);
293     CheckConstTensorPtrs("ProjectionWeights",
294                          inputParams.m_ProjectionWeights,
295                          layer->m_ProjectionParameters.m_ProjectionWeights);
296 
297     CheckConstTensorPtrs("CellToInputWeights",
298                          inputParams.m_CellToInputWeights,
299                          layer->m_PeepholeParameters.m_CellToInputWeights);
300     CheckConstTensorPtrs("CellToForgetWeights",
301                          inputParams.m_CellToForgetWeights,
302                          layer->m_PeepholeParameters.m_CellToForgetWeights);
303     CheckConstTensorPtrs("CellToOutputWeights",
304                          inputParams.m_CellToOutputWeights,
305                          layer->m_PeepholeParameters.m_CellToOutputWeights);
306 
307     CheckConstTensorPtrs("InputLayerNormWeights",
308                          inputParams.m_InputLayerNormWeights,
309                          layer->m_LayerNormParameters.m_InputLayerNormWeights);
310     CheckConstTensorPtrs("ForgetLayerNormWeights",
311                          inputParams.m_ForgetLayerNormWeights,
312                          layer->m_LayerNormParameters.m_ForgetLayerNormWeights);
313     CheckConstTensorPtrs("CellLayerNormWeights",
314                          inputParams.m_CellLayerNormWeights,
315                          layer->m_LayerNormParameters.m_CellLayerNormWeights);
316     CheckConstTensorPtrs("OutputLayerNormWeights",
317                          inputParams.m_OutputLayerNormWeights,
318                          layer->m_LayerNormParameters.m_OutputLayerNormWeights);
319 }
320 
321 class TestLstmLayerVisitor : public LstmVisitor
322 {
323 public:
TestLstmLayerVisitor(const LstmDescriptor & descriptor,const LstmInputParams & params,const char * name=nullptr)324     explicit TestLstmLayerVisitor(const LstmDescriptor& descriptor,
325                                   const LstmInputParams& params,
326                                   const char* name = nullptr)
327         : LstmVisitor(params, name)
328         , m_Descriptor(descriptor)
329     {}
330 
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)331     void ExecuteStrategy(const armnn::IConnectableLayer* layer,
332                          const armnn::BaseDescriptor& descriptor,
333                          const std::vector<armnn::ConstTensor>& constants,
334                          const char* name,
335                          const armnn::LayerBindingId id = 0) override
336     {
337         armnn::IgnoreUnused(descriptor, constants, id);
338         switch (layer->GetType())
339         {
340             case armnn::LayerType::Lstm:
341             {
342                 CheckLayerPointer(layer);
343                 CheckLayerName(name);
344                 CheckDescriptor(static_cast<const armnn::LstmDescriptor&>(descriptor));
345                 CheckInputParameters<const LstmLayer>(PolymorphicDowncast<const LstmLayer*>(layer), m_InputParams);
346                 break;
347             }
348             default:
349             {
350                 m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
351             }
352         }
353     }
354 
355 protected:
356     void CheckDescriptor(const LstmDescriptor& descriptor);
357 
358 private:
359     LstmDescriptor m_Descriptor;
360 };
361 
362 class TestQLstmLayerVisitor : public LstmVisitor
363 {
364 public:
TestQLstmLayerVisitor(const QLstmDescriptor & descriptor,const LstmInputParams & params,const char * name=nullptr)365     explicit TestQLstmLayerVisitor(const QLstmDescriptor& descriptor,
366                                    const LstmInputParams& params,
367                                    const char* name = nullptr)
368             : LstmVisitor(params, name)
369             , m_Descriptor(descriptor)
370     {}
371 
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)372     void ExecuteStrategy(const armnn::IConnectableLayer* layer,
373                          const armnn::BaseDescriptor& descriptor,
374                          const std::vector<armnn::ConstTensor>& constants,
375                          const char* name,
376                          const armnn::LayerBindingId id = 0) override
377     {
378         armnn::IgnoreUnused(descriptor, constants, id);
379         switch (layer->GetType())
380         {
381             case armnn::LayerType::QLstm:
382             {
383                 CheckLayerPointer(layer);
384                 CheckLayerName(name);
385                 CheckDescriptor(static_cast<const armnn::QLstmDescriptor&>(descriptor));
386                 CheckInputParameters<const QLstmLayer>(PolymorphicDowncast<const QLstmLayer*>(layer), m_InputParams);
387                 break;
388             }
389             default:
390             {
391                 m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
392             }
393         }
394     }
395 
396 protected:
397     void CheckDescriptor(const QLstmDescriptor& descriptor);
398 
399 private:
400     QLstmDescriptor m_Descriptor;
401 };
402 
403 
404 class TestQuantizedLstmLayerVisitor : public TestLayerVisitor
405 {
406 public:
TestQuantizedLstmLayerVisitor(const QuantizedLstmInputParams & params,const char * name=nullptr)407     explicit TestQuantizedLstmLayerVisitor(const QuantizedLstmInputParams& params,
408                                            const char* name = nullptr)
409         : TestLayerVisitor(name)
410         , m_InputParams(params)
411     {}
412 
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)413     void ExecuteStrategy(const armnn::IConnectableLayer* layer,
414                          const armnn::BaseDescriptor& descriptor,
415                          const std::vector<armnn::ConstTensor>& constants,
416                          const char* name,
417                          const armnn::LayerBindingId id = 0) override
418     {
419         armnn::IgnoreUnused(descriptor, constants, id);
420         switch (layer->GetType())
421         {
422             case armnn::LayerType::QuantizedLstm:
423             {
424                 CheckLayerPointer(layer);
425                 CheckLayerName(name);
426                 CheckInputParameters(m_InputParams);
427                 break;
428             }
429             default:
430             {
431                 m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
432             }
433         }
434     }
435 
436 protected:
437     void CheckInputParameters(const QuantizedLstmInputParams& params);
438 
439 private:
440     QuantizedLstmInputParams m_InputParams;
441 };
442 
443 
444 } // namespace armnn
445