1// 2// Copyright © 2017 Arm Ltd. All rights reserved. 3// SPDX-License-Identifier: MIT 4// 5#include "InferenceTest.hpp" 6 7#include <armnn/utility/Assert.hpp> 8#include <armnn/utility/NumericCast.hpp> 9#include "CxxoptsUtils.hpp" 10 11#include <cxxopts/cxxopts.hpp> 12#include <fmt/format.h> 13 14#include <fstream> 15#include <iostream> 16#include <iomanip> 17#include <array> 18#include <chrono> 19 20using namespace std; 21using namespace std::chrono; 22using namespace armnn::test; 23 24namespace armnn 25{ 26namespace test 27{ 28 29using TContainer = mapbox::util::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>; 30 31template <typename TTestCaseDatabase, typename TModel> 32ClassifierTestCase<TTestCaseDatabase, TModel>::ClassifierTestCase( 33 int& numInferencesRef, 34 int& numCorrectInferencesRef, 35 const std::vector<unsigned int>& validationPredictions, 36 std::vector<unsigned int>* validationPredictionsOut, 37 TModel& model, 38 unsigned int testCaseId, 39 unsigned int label, 40 std::vector<typename TModel::DataType> modelInput) 41 : InferenceModelTestCase<TModel>( 42 model, testCaseId, std::vector<TContainer>{ modelInput }, { model.GetOutputSize() }) 43 , m_Label(label) 44 , m_QuantizationParams(model.GetQuantizationParams()) 45 , m_NumInferencesRef(numInferencesRef) 46 , m_NumCorrectInferencesRef(numCorrectInferencesRef) 47 , m_ValidationPredictions(validationPredictions) 48 , m_ValidationPredictionsOut(validationPredictionsOut) 49{ 50} 51 52struct ClassifierResultProcessor 53{ 54 using ResultMap = std::map<float,int>; 55 56 ClassifierResultProcessor(float scale, int offset) 57 : m_Scale(scale) 58 , m_Offset(offset) 59 {} 60 61 void operator()(const std::vector<float>& values) 62 { 63 SortPredictions(values, [](float value) 64 { 65 return value; 66 }); 67 } 68 69 void operator()(const std::vector<uint8_t>& values) 70 { 71 auto& scale = m_Scale; 72 auto& offset = m_Offset; 73 SortPredictions(values, [&scale, &offset](uint8_t value) 74 { 75 return armnn::Dequantize(value, scale, offset); 76 }); 77 } 78 79 void operator()(const std::vector<int>& values) 80 { 81 IgnoreUnused(values); 82 ARMNN_ASSERT_MSG(false, "Non-float predictions output not supported."); 83 } 84 85 ResultMap& GetResultMap() { return m_ResultMap; } 86 87private: 88 template<typename Container, typename Delegate> 89 void SortPredictions(const Container& c, Delegate delegate) 90 { 91 int index = 0; 92 for (const auto& value : c) 93 { 94 int classification = index++; 95 // Take the first class with each probability 96 // This avoids strange results when looping over batched results produced 97 // with identical test data. 98 ResultMap::iterator lb = m_ResultMap.lower_bound(value); 99 100 if (lb == m_ResultMap.end() || !m_ResultMap.key_comp()(value, lb->first)) 101 { 102 // If the key is not already in the map, insert it. 103 m_ResultMap.insert(lb, ResultMap::value_type(delegate(value), classification)); 104 } 105 } 106 } 107 108 ResultMap m_ResultMap; 109 110 float m_Scale=0.0f; 111 int m_Offset=0; 112}; 113 114template <typename TTestCaseDatabase, typename TModel> 115TestCaseResult ClassifierTestCase<TTestCaseDatabase, TModel>::ProcessResult(const InferenceTestOptions& params) 116{ 117 auto& output = this->GetOutputs()[0]; 118 const auto testCaseId = this->GetTestCaseId(); 119 120 ClassifierResultProcessor resultProcessor(m_QuantizationParams.first, m_QuantizationParams.second); 121 mapbox::util::apply_visitor(resultProcessor, output); 122 123 ARMNN_LOG(info) << "= Prediction values for test #" << testCaseId; 124 auto it = resultProcessor.GetResultMap().rbegin(); 125 for (int i=0; i<5 && it != resultProcessor.GetResultMap().rend(); ++i) 126 { 127 ARMNN_LOG(info) << "Top(" << (i+1) << ") prediction is " << it->second << 128 " with value: " << (it->first); 129 ++it; 130 } 131 132 unsigned int prediction = 0; 133 mapbox::util::apply_visitor([&](auto&& value) 134 { 135 prediction = armnn::numeric_cast<unsigned int>( 136 std::distance(value.begin(), std::max_element(value.begin(), value.end()))); 137 }, 138 output); 139 140 // If we're just running the defaultTestCaseIds, each one must be classified correctly. 141 if (params.m_IterationCount == 0 && prediction != m_Label) 142 { 143 ARMNN_LOG(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" << 144 " is incorrect (should be " << m_Label << ")"; 145 return TestCaseResult::Failed; 146 } 147 148 // If a validation file was provided as input, it checks that the prediction matches. 149 if (!m_ValidationPredictions.empty() && prediction != m_ValidationPredictions[testCaseId]) 150 { 151 ARMNN_LOG(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" << 152 " doesn't match the prediction in the validation file (" << m_ValidationPredictions[testCaseId] << ")"; 153 return TestCaseResult::Failed; 154 } 155 156 // If a validation file was requested as output, it stores the predictions. 157 if (m_ValidationPredictionsOut) 158 { 159 m_ValidationPredictionsOut->push_back(prediction); 160 } 161 162 // Updates accuracy stats. 163 m_NumInferencesRef++; 164 if (prediction == m_Label) 165 { 166 m_NumCorrectInferencesRef++; 167 } 168 169 return TestCaseResult::Ok; 170} 171 172template <typename TDatabase, typename InferenceModel> 173template <typename TConstructDatabaseCallable, typename TConstructModelCallable> 174ClassifierTestCaseProvider<TDatabase, InferenceModel>::ClassifierTestCaseProvider( 175 TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel) 176 : m_ConstructModel(constructModel) 177 , m_ConstructDatabase(constructDatabase) 178 , m_NumInferences(0) 179 , m_NumCorrectInferences(0) 180{ 181} 182 183template <typename TDatabase, typename InferenceModel> 184void ClassifierTestCaseProvider<TDatabase, InferenceModel>::AddCommandLineOptions( 185 cxxopts::Options& options, std::vector<std::string>& required) 186{ 187 options 188 .allow_unrecognised_options() 189 .add_options() 190 ("validation-file-in", 191 "Reads expected predictions from the given file and confirms they match the actual predictions.", 192 cxxopts::value<std::string>(m_ValidationFileIn)->default_value("")) 193 ("validation-file-out", "Predictions are saved to the given file for later use via --validation-file-in.", 194 cxxopts::value<std::string>(m_ValidationFileOut)->default_value("")) 195 ("d,data-dir", "Path to directory containing test data", cxxopts::value<std::string>(m_DataDir)); 196 197 required.emplace_back("data-dir"); //add to required arguments to check 198 199 InferenceModel::AddCommandLineOptions(options, m_ModelCommandLineOptions, required); 200} 201 202template <typename TDatabase, typename InferenceModel> 203bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::ProcessCommandLineOptions( 204 const InferenceTestOptions& commonOptions) 205{ 206 if (!ValidateDirectory(m_DataDir)) 207 { 208 return false; 209 } 210 211 ReadPredictions(); 212 213 m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions); 214 if (!m_Model) 215 { 216 return false; 217 } 218 219 m_Database = std::make_unique<TDatabase>(m_ConstructDatabase(m_DataDir.c_str(), *m_Model)); 220 if (!m_Database) 221 { 222 return false; 223 } 224 225 return true; 226} 227 228template <typename TDatabase, typename InferenceModel> 229std::unique_ptr<IInferenceTestCase> 230ClassifierTestCaseProvider<TDatabase, InferenceModel>::GetTestCase(unsigned int testCaseId) 231{ 232 std::unique_ptr<typename TDatabase::TTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId); 233 if (testCaseData == nullptr) 234 { 235 return nullptr; 236 } 237 238 return std::make_unique<ClassifierTestCase<TDatabase, InferenceModel>>( 239 m_NumInferences, 240 m_NumCorrectInferences, 241 m_ValidationPredictions, 242 m_ValidationFileOut.empty() ? nullptr : &m_ValidationPredictionsOut, 243 *m_Model, 244 testCaseId, 245 testCaseData->m_Label, 246 std::move(testCaseData->m_InputImage)); 247} 248 249template <typename TDatabase, typename InferenceModel> 250bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::OnInferenceTestFinished() 251{ 252 const double accuracy = armnn::numeric_cast<double>(m_NumCorrectInferences) / 253 armnn::numeric_cast<double>(m_NumInferences); 254 ARMNN_LOG(info) << std::fixed << std::setprecision(3) << "Overall accuracy: " << accuracy; 255 256 // If a validation file was requested as output, the predictions are saved to it. 257 if (!m_ValidationFileOut.empty()) 258 { 259 std::ofstream validationFileOut(m_ValidationFileOut.c_str(), std::ios_base::trunc | std::ios_base::out); 260 if (validationFileOut.good()) 261 { 262 for (const unsigned int prediction : m_ValidationPredictionsOut) 263 { 264 validationFileOut << prediction << std::endl; 265 } 266 } 267 else 268 { 269 ARMNN_LOG(error) << "Failed to open output validation file: " << m_ValidationFileOut; 270 return false; 271 } 272 } 273 274 return true; 275} 276 277template <typename TDatabase, typename InferenceModel> 278void ClassifierTestCaseProvider<TDatabase, InferenceModel>::ReadPredictions() 279{ 280 // Reads the expected predictions from the input validation file (if provided). 281 if (!m_ValidationFileIn.empty()) 282 { 283 std::ifstream validationFileIn(m_ValidationFileIn.c_str(), std::ios_base::in); 284 if (validationFileIn.good()) 285 { 286 while (!validationFileIn.eof()) 287 { 288 unsigned int i; 289 validationFileIn >> i; 290 m_ValidationPredictions.emplace_back(i); 291 } 292 } 293 else 294 { 295 throw armnn::Exception(fmt::format("Failed to open input validation file: {}" 296 , m_ValidationFileIn)); 297 } 298 } 299} 300 301template<typename TConstructTestCaseProvider> 302int InferenceTestMain(int argc, 303 char* argv[], 304 const std::vector<unsigned int>& defaultTestCaseIds, 305 TConstructTestCaseProvider constructTestCaseProvider) 306{ 307 // Configures logging for both the ARMNN library and this test program. 308#ifdef NDEBUG 309 armnn::LogSeverity level = armnn::LogSeverity::Info; 310#else 311 armnn::LogSeverity level = armnn::LogSeverity::Debug; 312#endif 313 armnn::ConfigureLogging(true, true, level); 314 315 try 316 { 317 std::unique_ptr<IInferenceTestCaseProvider> testCaseProvider = constructTestCaseProvider(); 318 if (!testCaseProvider) 319 { 320 return 1; 321 } 322 323 InferenceTestOptions inferenceTestOptions; 324 if (!ParseCommandLine(argc, argv, *testCaseProvider, inferenceTestOptions)) 325 { 326 return 1; 327 } 328 329 const bool success = InferenceTest(inferenceTestOptions, defaultTestCaseIds, *testCaseProvider); 330 return success ? 0 : 1; 331 } 332 catch (armnn::Exception const& e) 333 { 334 ARMNN_LOG(fatal) << "Armnn Error: " << e.what(); 335 return 1; 336 } 337} 338 339// 340// This function allows us to create a classifier inference test based on: 341// - a model file name 342// - which can be a binary or a text file for protobuf formats 343// - an input tensor name 344// - an output tensor name 345// - a set of test case ids 346// - a callback method which creates an object that can return images 347// called 'Database' in these tests 348// - and an input tensor shape 349// 350template<typename TDatabase, 351 typename TParser, 352 typename TConstructDatabaseCallable> 353int ClassifierInferenceTestMain(int argc, 354 char* argv[], 355 const char* modelFilename, 356 bool isModelBinary, 357 const char* inputBindingName, 358 const char* outputBindingName, 359 const std::vector<unsigned int>& defaultTestCaseIds, 360 TConstructDatabaseCallable constructDatabase, 361 const armnn::TensorShape* inputTensorShape) 362 363{ 364 ARMNN_ASSERT(modelFilename); 365 ARMNN_ASSERT(inputBindingName); 366 ARMNN_ASSERT(outputBindingName); 367 368 return InferenceTestMain(argc, argv, defaultTestCaseIds, 369 [=] 370 () 371 { 372 using InferenceModel = InferenceModel<TParser, typename TDatabase::DataType>; 373 using TestCaseProvider = ClassifierTestCaseProvider<TDatabase, InferenceModel>; 374 375 return make_unique<TestCaseProvider>(constructDatabase, 376 [&] 377 (const InferenceTestOptions &commonOptions, 378 typename InferenceModel::CommandLineOptions modelOptions) 379 { 380 if (!ValidateDirectory(modelOptions.m_ModelDir)) 381 { 382 return std::unique_ptr<InferenceModel>(); 383 } 384 385 typename InferenceModel::Params modelParams; 386 modelParams.m_ModelPath = modelOptions.m_ModelDir + modelFilename; 387 modelParams.m_InputBindings = { inputBindingName }; 388 modelParams.m_OutputBindings = { outputBindingName }; 389 390 if (inputTensorShape) 391 { 392 modelParams.m_InputShapes.push_back(*inputTensorShape); 393 } 394 395 modelParams.m_IsModelBinary = isModelBinary; 396 modelParams.m_ComputeDevices = modelOptions.GetComputeDevicesAsBackendIds(); 397 modelParams.m_VisualizePostOptimizationModel = modelOptions.m_VisualizePostOptimizationModel; 398 modelParams.m_EnableFp16TurboMode = modelOptions.m_EnableFp16TurboMode; 399 400 return std::make_unique<InferenceModel>(modelParams, 401 commonOptions.m_EnableProfiling, 402 commonOptions.m_DynamicBackendsPath); 403 }); 404 }); 405} 406 407} // namespace test 408} // namespace armnn 409