• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "InferenceModel.hpp"
8 
9 #include <armnn/ArmNN.hpp>
10 #include <armnn/Logging.hpp>
11 #include <armnn/TypesUtils.hpp>
12 #include <armnn/utility/IgnoreUnused.hpp>
13 
14 #include <cxxopts/cxxopts.hpp>
15 #include <fmt/format.h>
16 
17 
18 namespace armnn
19 {
20 
operator >>(std::istream & in,armnn::Compute & compute)21 inline std::istream& operator>>(std::istream& in, armnn::Compute& compute)
22 {
23     std::string token;
24     in >> token;
25     compute = armnn::ParseComputeDevice(token.c_str());
26     if (compute == armnn::Compute::Undefined)
27     {
28         in.setstate(std::ios_base::failbit);
29         throw cxxopts::OptionException(fmt::format("Unrecognised compute device: {}", token));
30     }
31     return in;
32 }
33 
operator >>(std::istream & in,armnn::BackendId & backend)34 inline std::istream& operator>>(std::istream& in, armnn::BackendId& backend)
35 {
36     std::string token;
37     in >> token;
38     armnn::Compute compute = armnn::ParseComputeDevice(token.c_str());
39     if (compute == armnn::Compute::Undefined)
40     {
41         in.setstate(std::ios_base::failbit);
42         throw cxxopts::OptionException(fmt::format("Unrecognised compute device: {}", token));
43     }
44     backend = compute;
45     return in;
46 }
47 
48 namespace test
49 {
50 
51 class TestFrameworkException : public Exception
52 {
53 public:
54     using Exception::Exception;
55 };
56 
/// Options common to all inference tests.
///
/// Every field starts out zero/false/empty. How each field is interpreted is decided by the
/// code that consumes these options (not visible in this header).
struct InferenceTestOptions
{
    /// Iteration count requested for the test run (0 by default). NOTE(review): the exact
    /// meaning of 0 is defined by the consumer — confirm.
    unsigned int m_IterationCount;
    /// Path of a file for inference timings; empty by default.
    std::string  m_InferenceTimesFile;
    /// Whether profiling is enabled; disabled by default.
    bool         m_EnableProfiling;
    /// Path used to look up dynamic backends; empty by default.
    std::string  m_DynamicBackendsPath;

    // User-provided constructor kept deliberately: the struct stays a non-aggregate,
    // exactly as before, so existing initialisation syntax is unaffected.
    InferenceTestOptions()
        : m_IterationCount(0)
        , m_InferenceTimesFile()
        , m_EnableProfiling(false) // was the int literal 0; use a bool literal for a bool member
        , m_DynamicBackendsPath()
    {}
};
70 
/// Outcome of processing a single inference test case.
enum class TestCaseResult
{
    Ok,     ///< The test completed without any errors.
    Failed, ///< The test failed (e.g. the prediction didn't match the validation file).
            ///< This will eventually fail the whole program but the remaining test cases will still be run.
    Abort   ///< The test failed with a fatal error. The remaining tests will not be run.
};
81 
82 class IInferenceTestCase
83 {
84 public:
~IInferenceTestCase()85     virtual ~IInferenceTestCase() {}
86 
87     virtual void Run() = 0;
88     virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) = 0;
89 };
90 
91 class IInferenceTestCaseProvider
92 {
93 public:
~IInferenceTestCaseProvider()94     virtual ~IInferenceTestCaseProvider() {}
95 
AddCommandLineOptions(cxxopts::Options & options,std::vector<std::string> & required)96     virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required)
97     {
98         IgnoreUnused(options, required);
99     };
ProcessCommandLineOptions(const InferenceTestOptions & commonOptions)100     virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions)
101     {
102         IgnoreUnused(commonOptions);
103         return true;
104     };
105     virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) = 0;
OnInferenceTestFinished()106     virtual bool OnInferenceTestFinished() { return true; };
107 };
108 
109 template <typename TModel>
110 class InferenceModelTestCase : public IInferenceTestCase
111 {
112 public:
113     using TContainer = mapbox::util::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
114 
InferenceModelTestCase(TModel & model,unsigned int testCaseId,const std::vector<TContainer> & inputs,const std::vector<unsigned int> & outputSizes)115     InferenceModelTestCase(TModel& model,
116                            unsigned int testCaseId,
117                            const std::vector<TContainer>& inputs,
118                            const std::vector<unsigned int>& outputSizes)
119         : m_Model(model)
120         , m_TestCaseId(testCaseId)
121         , m_Inputs(std::move(inputs))
122     {
123         // Initialize output vector
124         const size_t numOutputs = outputSizes.size();
125         m_Outputs.reserve(numOutputs);
126 
127         for (size_t i = 0; i < numOutputs; i++)
128         {
129             m_Outputs.push_back(std::vector<typename TModel::DataType>(outputSizes[i]));
130         }
131     }
132 
Run()133     virtual void Run() override
134     {
135         m_Model.Run(m_Inputs, m_Outputs);
136     }
137 
138 protected:
GetTestCaseId() const139     unsigned int GetTestCaseId() const { return m_TestCaseId; }
GetOutputs() const140     const std::vector<TContainer>& GetOutputs() const { return m_Outputs; }
141 
142 private:
143     TModel&                 m_Model;
144     unsigned int            m_TestCaseId;
145     std::vector<TContainer> m_Inputs;
146     std::vector<TContainer> m_Outputs;
147 };
148 
149 template <typename TTestCaseDatabase, typename TModel>
150 class ClassifierTestCase : public InferenceModelTestCase<TModel>
151 {
152 public:
153     ClassifierTestCase(int& numInferencesRef,
154         int& numCorrectInferencesRef,
155         const std::vector<unsigned int>& validationPredictions,
156         std::vector<unsigned int>* validationPredictionsOut,
157         TModel& model,
158         unsigned int testCaseId,
159         unsigned int label,
160         std::vector<typename TModel::DataType> modelInput);
161 
162     virtual TestCaseResult ProcessResult(const InferenceTestOptions& params) override;
163 
164 private:
165     unsigned int m_Label;
166     InferenceModelInternal::QuantizationParams m_QuantizationParams;
167 
168     /// These fields reference the corresponding member in the ClassifierTestCaseProvider.
169     /// @{
170     int& m_NumInferencesRef;
171     int& m_NumCorrectInferencesRef;
172     const std::vector<unsigned int>& m_ValidationPredictions;
173     std::vector<unsigned int>* m_ValidationPredictionsOut;
174     /// @}
175 };
176 
177 template <typename TDatabase, typename InferenceModel>
178 class ClassifierTestCaseProvider : public IInferenceTestCaseProvider
179 {
180 public:
181     template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
182     ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel);
183 
184     virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required) override;
185     virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override;
186     virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override;
187     virtual bool OnInferenceTestFinished() override;
188 
189 private:
190     void ReadPredictions();
191 
192     typename InferenceModel::CommandLineOptions m_ModelCommandLineOptions;
193     std::function<std::unique_ptr<InferenceModel>(const InferenceTestOptions& commonOptions,
194                                                   typename InferenceModel::CommandLineOptions)> m_ConstructModel;
195     std::unique_ptr<InferenceModel> m_Model;
196 
197     std::string m_DataDir;
198     std::function<TDatabase(const char*, const InferenceModel&)> m_ConstructDatabase;
199     std::unique_ptr<TDatabase> m_Database;
200 
201     int m_NumInferences; // Referenced by test cases.
202     int m_NumCorrectInferences; // Referenced by test cases.
203 
204     std::string m_ValidationFileIn;
205     std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases.
206 
207     std::string m_ValidationFileOut;
208     std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases.
209 };
210 
211 bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
212     InferenceTestOptions& outParams);
213 
214 bool ValidateDirectory(std::string& dir);
215 
216 bool InferenceTest(const InferenceTestOptions& params,
217     const std::vector<unsigned int>& defaultTestCaseIds,
218     IInferenceTestCaseProvider& testCaseProvider);
219 
220 template<typename TConstructTestCaseProvider>
221 int InferenceTestMain(int argc,
222     char* argv[],
223     const std::vector<unsigned int>& defaultTestCaseIds,
224     TConstructTestCaseProvider constructTestCaseProvider);
225 
226 template<typename TDatabase,
227     typename TParser,
228     typename TConstructDatabaseCallable>
229 int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary,
230     const char* inputBindingName, const char* outputBindingName,
231     const std::vector<unsigned int>& defaultTestCaseIds,
232     TConstructDatabaseCallable constructDatabase,
233     const armnn::TensorShape* inputTensorShape = nullptr);
234 
235 } // namespace test
236 } // namespace armnn
237 
238 #include "InferenceTest.inl"
239