• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "TestLayerVisitor.hpp"
8 #include <doctest/doctest.h>
9 
10 namespace armnn
11 {
12 
CheckLayerBindingId(LayerBindingId visitorId,LayerBindingId id)13 void CheckLayerBindingId(LayerBindingId visitorId, LayerBindingId id)
14 {
15     CHECK_EQ(visitorId, id);
16 }
17 
// Concrete TestLayerVisitor subclasses for layers that take a LayerBindingId argument, with overridden ExecuteStrategy methods
19 class TestInputLayerVisitor : public TestLayerVisitor
20 {
21 private:
22     LayerBindingId visitorId;
23 
24 public:
TestInputLayerVisitor(LayerBindingId id,const char * name=nullptr)25     explicit TestInputLayerVisitor(LayerBindingId id, const char* name = nullptr)
26         : TestLayerVisitor(name)
27         , visitorId(id)
28     {};
29 
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)30     void ExecuteStrategy(const armnn::IConnectableLayer* layer,
31                          const armnn::BaseDescriptor& descriptor,
32                          const std::vector<armnn::ConstTensor>& constants,
33                          const char* name,
34                          const armnn::LayerBindingId id = 0) override
35     {
36         armnn::IgnoreUnused(descriptor, constants, id);
37         switch (layer->GetType())
38         {
39             case armnn::LayerType::Input:
40             {
41                 CheckLayerPointer(layer);
42                 CheckLayerBindingId(visitorId, id);
43                 CheckLayerName(name);
44                 break;
45             }
46             default:
47             {
48                 m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
49             }
50         }
51     }
52 };
53 
54 class TestOutputLayerVisitor : public TestLayerVisitor
55 {
56 private:
57     LayerBindingId visitorId;
58 
59 public:
TestOutputLayerVisitor(LayerBindingId id,const char * name=nullptr)60     explicit TestOutputLayerVisitor(LayerBindingId id, const char* name = nullptr)
61         : TestLayerVisitor(name)
62         , visitorId(id)
63     {};
64 
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)65     void ExecuteStrategy(const armnn::IConnectableLayer* layer,
66                          const armnn::BaseDescriptor& descriptor,
67                          const std::vector<armnn::ConstTensor>& constants,
68                          const char* name,
69                          const armnn::LayerBindingId id = 0) override
70     {
71         armnn::IgnoreUnused(descriptor, constants, id);
72         switch (layer->GetType())
73         {
74             case armnn::LayerType::Output:
75             {
76                 CheckLayerPointer(layer);
77                 CheckLayerBindingId(visitorId, id);
78                 CheckLayerName(name);
79                 break;
80             }
81             default:
82             {
83                 m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
84             }
85         }
86     }
87 };
88 
89 } //namespace armnn
90