• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include <armnn/Logging.hpp>
8 #include <armnn/Utils.hpp>
9 #include <reference/RefWorkloadFactory.hpp>
10 #include <reference/test/RefWorkloadFactoryHelper.hpp>
11 
12 #include <backendsCommon/test/LayerTests.hpp>
13 #include <backendsCommon/test/WorkloadFactoryHelper.hpp>
14 #include "TensorHelpers.hpp"
15 #include <boost/test/unit_test.hpp>
16 
ConfigureLoggingTest()17 inline void ConfigureLoggingTest()
18 {
19     // Configures logging for both the ARMNN library and this test program.
20     armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal);
21 }
22 
// The ARMNN_AUTO_TEST_CASE* and ARMNN_COMPARE_REF_* macros below require the including file to define
// FactoryType before they are used (ARMNN_SIMPLE_TEST_CASE does not), via one of these alias declarations:
//
//      using FactoryType = armnn::RefWorkloadFactory;
//      using FactoryType = armnn::ClWorkloadFactory;
//      using FactoryType = armnn::NeonWorkloadFactory;
28 
29 /// Executes BOOST_TEST on CompareTensors() return value so that the predicate_result message is reported.
30 /// If the test reports itself as not supported then the tensors are not compared.
31 /// Additionally this checks that the supportedness reported by the test matches the name of the test.
32 /// Unsupported tests must be 'tagged' by including "UNSUPPORTED" in their name.
33 /// This is useful because it clarifies that the feature being tested is not actually supported
34 /// (a passed test with the name of a feature would imply that feature was supported).
35 /// If support is added for a feature, the test case will fail because the name incorrectly contains UNSUPPORTED.
36 /// If support is removed for a feature, the test case will fail because the name doesn't contain UNSUPPORTED.
37 template <typename T, std::size_t n>
CompareTestResultIfSupported(const std::string & testName,const LayerTestResult<T,n> & testResult)38 void CompareTestResultIfSupported(const std::string& testName, const LayerTestResult<T, n>& testResult)
39 {
40     bool testNameIndicatesUnsupported = testName.find("UNSUPPORTED") != std::string::npos;
41     BOOST_CHECK_MESSAGE(testNameIndicatesUnsupported != testResult.supported,
42         "The test name does not match the supportedness it is reporting");
43     if (testResult.supported)
44     {
45         BOOST_TEST(CompareTensors(testResult.output, testResult.outputExpected, testResult.compareBoolean));
46     }
47 }
48 
49 template <typename T, std::size_t n>
CompareTestResultIfSupported(const std::string & testName,const std::vector<LayerTestResult<T,n>> & testResult)50 void CompareTestResultIfSupported(const std::string& testName, const std::vector<LayerTestResult<T, n>>& testResult)
51 {
52     bool testNameIndicatesUnsupported = testName.find("UNSUPPORTED") != std::string::npos;
53     for (unsigned int i = 0; i < testResult.size(); ++i)
54     {
55         BOOST_CHECK_MESSAGE(testNameIndicatesUnsupported != testResult[i].supported,
56             "The test name does not match the supportedness it is reporting");
57         if (testResult[i].supported)
58         {
59             BOOST_TEST(CompareTensors(testResult[i].output, testResult[i].outputExpected));
60         }
61     }
62 }
63 
64 template<typename FactoryType, typename TFuncPtr, typename... Args>
RunTestFunction(const char * testName,TFuncPtr testFunction,Args...args)65 void RunTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
66 {
67     std::unique_ptr<armnn::Profiler> profiler = std::make_unique<armnn::Profiler>();
68     armnn::ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
69 
70     auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
71     FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
72 
73     auto testResult = (*testFunction)(workloadFactory, memoryManager, args...);
74     CompareTestResultIfSupported(testName, testResult);
75 
76     armnn::ProfilerManager::GetInstance().RegisterProfiler(nullptr);
77 }
78 
79 
80 template<typename FactoryType, typename TFuncPtr, typename... Args>
RunTestFunctionUsingTensorHandleFactory(const char * testName,TFuncPtr testFunction,Args...args)81 void RunTestFunctionUsingTensorHandleFactory(const char* testName, TFuncPtr testFunction, Args... args)
82 {
83     std::unique_ptr<armnn::Profiler> profiler = std::make_unique<armnn::Profiler>();
84     armnn::ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
85 
86     auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
87     FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
88 
89     auto tensorHandleFactory = WorkloadFactoryHelper<FactoryType>::GetTensorHandleFactory(memoryManager);
90 
91     auto testResult = (*testFunction)(workloadFactory, memoryManager, tensorHandleFactory, args...);
92     CompareTestResultIfSupported(testName, testResult);
93 
94     armnn::ProfilerManager::GetInstance().RegisterProfiler(nullptr);
95 }
96 
/// Defines a Boost test case named TestName whose body just calls TestFunction() with no arguments.
/// This macro does not use FactoryType.
#define ARMNN_SIMPLE_TEST_CASE(TestName, TestFunction) \
    BOOST_AUTO_TEST_CASE(TestName) { TestFunction(); }

/// Defines a Boost test case named TestName that runs TestFunction via RunTestFunction<FactoryType>.
/// Any extra macro arguments are passed to TestFunction after the workload factory and memory manager.
/// `##__VA_ARGS__` drops the preceding comma when no extra arguments are supplied.
#define ARMNN_AUTO_TEST_CASE(TestName, TestFunction, ...) \
    BOOST_AUTO_TEST_CASE(TestName) \
    { RunTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); }

/// Like ARMNN_AUTO_TEST_CASE, but runs TestFunction via RunTestFunctionUsingTensorHandleFactory<FactoryType>.
/// That function also passes a tensor handle factory to TestFunction.
#define ARMNN_AUTO_TEST_CASE_WITH_THF(TestName, TestFunction, ...) \
    BOOST_AUTO_TEST_CASE(TestName) \
    { RunTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); }
114 
115 template<typename FactoryType, typename TFuncPtr, typename... Args>
CompareRefTestFunction(const char * testName,TFuncPtr testFunction,Args...args)116 void CompareRefTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
117 {
118     auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
119     FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
120 
121     armnn::RefWorkloadFactory refWorkloadFactory;
122 
123     auto testResult = (*testFunction)(workloadFactory, memoryManager, refWorkloadFactory, args...);
124     CompareTestResultIfSupported(testName, testResult);
125 }
126 
127 template<typename FactoryType, typename TFuncPtr, typename... Args>
CompareRefTestFunctionUsingTensorHandleFactory(const char * testName,TFuncPtr testFunction,Args...args)128 void CompareRefTestFunctionUsingTensorHandleFactory(const char* testName, TFuncPtr testFunction, Args... args)
129 {
130     auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
131     FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
132 
133     armnn::RefWorkloadFactory refWorkloadFactory;
134     auto tensorHandleFactory = WorkloadFactoryHelper<FactoryType>::GetTensorHandleFactory(memoryManager);
135     auto refTensorHandleFactory =
136         RefWorkloadFactoryHelper::GetTensorHandleFactory(memoryManager);
137 
138     auto testResult = (*testFunction)(
139         workloadFactory, memoryManager, refWorkloadFactory, tensorHandleFactory, refTensorHandleFactory, args...);
140     CompareTestResultIfSupported(testName, testResult);
141 }
142 
/// Defines a Boost test case named TestName that runs TestFunction via CompareRefTestFunction<FactoryType>.
/// Any extra macro arguments are forwarded to TestFunction after the factories.
#define ARMNN_COMPARE_REF_AUTO_TEST_CASE(TestName, TestFunction, ...) \
    BOOST_AUTO_TEST_CASE(TestName) \
    { CompareRefTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); }

/// Like ARMNN_COMPARE_REF_AUTO_TEST_CASE, but runs TestFunction via
/// CompareRefTestFunctionUsingTensorHandleFactory<FactoryType>, which also passes tensor handle factories.
#define ARMNN_COMPARE_REF_AUTO_TEST_CASE_WITH_THF(TestName, TestFunction, ...) \
    BOOST_AUTO_TEST_CASE(TestName) \
    { CompareRefTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); }

/// Like ARMNN_COMPARE_REF_AUTO_TEST_CASE, but declares the test with BOOST_FIXTURE_TEST_CASE using Fixture.
#define ARMNN_COMPARE_REF_FIXTURE_TEST_CASE(TestName, Fixture, TestFunction, ...) \
    BOOST_FIXTURE_TEST_CASE(TestName, Fixture) \
    { CompareRefTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); }

/// Like ARMNN_COMPARE_REF_AUTO_TEST_CASE_WITH_THF, but declares the test with BOOST_FIXTURE_TEST_CASE using Fixture.
#define ARMNN_COMPARE_REF_FIXTURE_TEST_CASE_WITH_THF(TestName, Fixture, TestFunction, ...) \
    BOOST_FIXTURE_TEST_CASE(TestName, Fixture) \
    { CompareRefTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); }
166