• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include <armnn/StrategyBase.hpp>
8 #include <armnn/Descriptors.hpp>
9 #include <armnn/backends/TensorHandle.hpp>
10 
11 namespace armnn
12 {
13 // Abstract base class with do nothing implementations for all layers
14 class TestLayerVisitor : public StrategyBase<NoThrowStrategy>
15 {
16 protected:
~TestLayerVisitor()17     virtual ~TestLayerVisitor() {}
18 
19     void CheckLayerName(const char* name);
20 
21     void CheckLayerPointer(const IConnectableLayer* layer);
22 
23     void CheckConstTensors(const ConstTensor& expected,
24                            const ConstTensor& actual);
25     void CheckConstTensors(const ConstTensor& expected,
26                            const ConstTensorHandle& actual);
27 
28     void CheckConstTensorPtrs(const std::string& name,
29                               const ConstTensor* expected,
30                               const ConstTensor* actual);
31     void CheckConstTensorPtrs(const std::string& name,
32                               const ConstTensor* expected,
33                               const std::shared_ptr<ConstTensorHandle> actual);
34 
35     void CheckOptionalConstTensors(const Optional<ConstTensor>& expected, const Optional<ConstTensor>& actual);
36 
37 private:
38     const char* m_LayerName;
39 
40 public:
TestLayerVisitor(const char * name)41     explicit TestLayerVisitor(const char* name) : m_LayerName(name)
42     {
43         if (name == nullptr)
44         {
45             m_LayerName = "";
46         }
47     }
48 };
49 
50 } //namespace armnn
51