1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include "TestLayerVisitor.hpp"
8 #include <doctest/doctest.h>
9
10 namespace armnn
11 {
12
CheckLayerBindingId(LayerBindingId visitorId,LayerBindingId id)13 void CheckLayerBindingId(LayerBindingId visitorId, LayerBindingId id)
14 {
15 CHECK_EQ(visitorId, id);
16 }
17
// Concrete TestLayerVisitor subclasses for layers that take a LayerBindingId argument, each overriding ExecuteStrategy
19 class TestInputLayerVisitor : public TestLayerVisitor
20 {
21 private:
22 LayerBindingId visitorId;
23
24 public:
TestInputLayerVisitor(LayerBindingId id,const char * name=nullptr)25 explicit TestInputLayerVisitor(LayerBindingId id, const char* name = nullptr)
26 : TestLayerVisitor(name)
27 , visitorId(id)
28 {};
29
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)30 void ExecuteStrategy(const armnn::IConnectableLayer* layer,
31 const armnn::BaseDescriptor& descriptor,
32 const std::vector<armnn::ConstTensor>& constants,
33 const char* name,
34 const armnn::LayerBindingId id = 0) override
35 {
36 armnn::IgnoreUnused(descriptor, constants, id);
37 switch (layer->GetType())
38 {
39 case armnn::LayerType::Input:
40 {
41 CheckLayerPointer(layer);
42 CheckLayerBindingId(visitorId, id);
43 CheckLayerName(name);
44 break;
45 }
46 default:
47 {
48 m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
49 }
50 }
51 }
52 };
53
54 class TestOutputLayerVisitor : public TestLayerVisitor
55 {
56 private:
57 LayerBindingId visitorId;
58
59 public:
TestOutputLayerVisitor(LayerBindingId id,const char * name=nullptr)60 explicit TestOutputLayerVisitor(LayerBindingId id, const char* name = nullptr)
61 : TestLayerVisitor(name)
62 , visitorId(id)
63 {};
64
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)65 void ExecuteStrategy(const armnn::IConnectableLayer* layer,
66 const armnn::BaseDescriptor& descriptor,
67 const std::vector<armnn::ConstTensor>& constants,
68 const char* name,
69 const armnn::LayerBindingId id = 0) override
70 {
71 armnn::IgnoreUnused(descriptor, constants, id);
72 switch (layer->GetType())
73 {
74 case armnn::LayerType::Output:
75 {
76 CheckLayerPointer(layer);
77 CheckLayerBindingId(visitorId, id);
78 CheckLayerName(name);
79 break;
80 }
81 default:
82 {
83 m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
84 }
85 }
86 }
87 };
88
89 } //namespace armnn
90