• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "ModelAccuracyChecker.hpp"
6 
7 #include <boost/test/unit_test.hpp>
8 
9 #include <iostream>
10 #include <string>
11 
12 using namespace armnnUtils;
13 
14 namespace {
15 struct TestHelper
16 {
GetValidationLabelSet__anonf4b0a63b0111::TestHelper17     const std::map<std::string, std::string> GetValidationLabelSet()
18     {
19         std::map<std::string, std::string> validationLabelSet;
20         validationLabelSet.insert(std::make_pair("val_01.JPEG", "goldfinch"));
21         validationLabelSet.insert(std::make_pair("val_02.JPEG", "magpie"));
22         validationLabelSet.insert(std::make_pair("val_03.JPEG", "brambling"));
23         validationLabelSet.insert(std::make_pair("val_04.JPEG", "robin"));
24         validationLabelSet.insert(std::make_pair("val_05.JPEG", "indigo bird"));
25         validationLabelSet.insert(std::make_pair("val_06.JPEG", "ostrich"));
26         validationLabelSet.insert(std::make_pair("val_07.JPEG", "jay"));
27         validationLabelSet.insert(std::make_pair("val_08.JPEG", "snowbird"));
28         validationLabelSet.insert(std::make_pair("val_09.JPEG", "house finch"));
29         validationLabelSet.insert(std::make_pair("val_09.JPEG", "bulbul"));
30 
31         return validationLabelSet;
32     }
GetModelOutputLabels__anonf4b0a63b0111::TestHelper33     const std::vector<armnnUtils::LabelCategoryNames> GetModelOutputLabels()
34     {
35         const std::vector<armnnUtils::LabelCategoryNames> modelOutputLabels =
36         {
37             {"ostrich", "Struthio camelus"},
38             {"brambling", "Fringilla montifringilla"},
39             {"goldfinch", "Carduelis carduelis"},
40             {"house finch", "linnet", "Carpodacus mexicanus"},
41             {"junco", "snowbird"},
42             {"indigo bunting", "indigo finch", "indigo bird", "Passerina cyanea"},
43             {"robin", "American robin", "Turdus migratorius"},
44             {"bulbul"},
45             {"jay"},
46             {"magpie"}
47         };
48         return modelOutputLabels;
49     }
50 };
51 }
52 
53 BOOST_AUTO_TEST_SUITE(ModelAccuracyCheckerTest)
54 
55 using TContainer = mapbox::util::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
56 
BOOST_FIXTURE_TEST_CASE(TestFloat32OutputTensorAccuracy,TestHelper)57 BOOST_FIXTURE_TEST_CASE(TestFloat32OutputTensorAccuracy, TestHelper)
58 {
59     ModelAccuracyChecker checker(GetValidationLabelSet(), GetModelOutputLabels());
60 
61     // Add image 1 and check accuracy
62     std::vector<float> inferenceOutputVector1 = {0.05f, 0.10f, 0.70f, 0.15f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
63     TContainer inference1Container(inferenceOutputVector1);
64     std::vector<TContainer> outputTensor1;
65     outputTensor1.push_back(inference1Container);
66 
67     std::string imageName = "val_01.JPEG";
68     checker.AddImageResult<TContainer>(imageName, outputTensor1);
69 
70     // Top 1 Accuracy
71     float totalAccuracy = checker.GetAccuracy(1);
72     BOOST_CHECK(totalAccuracy == 100.0f);
73 
74     // Add image 2 and check accuracy
75     std::vector<float> inferenceOutputVector2 = {0.10f, 0.0f, 0.0f, 0.0f, 0.05f, 0.70f, 0.0f, 0.0f, 0.0f, 0.15f};
76     TContainer inference2Container(inferenceOutputVector2);
77     std::vector<TContainer> outputTensor2;
78     outputTensor2.push_back(inference2Container);
79 
80     imageName = "val_02.JPEG";
81     checker.AddImageResult<TContainer>(imageName, outputTensor2);
82 
83     // Top 1 Accuracy
84     totalAccuracy = checker.GetAccuracy(1);
85     BOOST_CHECK(totalAccuracy == 50.0f);
86 
87     // Top 2 Accuracy
88     totalAccuracy = checker.GetAccuracy(2);
89     BOOST_CHECK(totalAccuracy == 100.0f);
90 
91     // Add image 3 and check accuracy
92     std::vector<float> inferenceOutputVector3 = {0.0f, 0.10f, 0.0f, 0.0f, 0.05f, 0.70f, 0.0f, 0.0f, 0.0f, 0.15f};
93     TContainer inference3Container(inferenceOutputVector3);
94     std::vector<TContainer> outputTensor3;
95     outputTensor3.push_back(inference3Container);
96 
97     imageName = "val_03.JPEG";
98     checker.AddImageResult<TContainer>(imageName, outputTensor3);
99 
100     // Top 1 Accuracy
101     totalAccuracy = checker.GetAccuracy(1);
102     BOOST_CHECK(totalAccuracy == 33.3333321f);
103 
104     // Top 2 Accuracy
105     totalAccuracy = checker.GetAccuracy(2);
106     BOOST_CHECK(totalAccuracy == 66.6666641f);
107 
108     // Top 3 Accuracy
109     totalAccuracy = checker.GetAccuracy(3);
110     BOOST_CHECK(totalAccuracy == 100.0f);
111 }
112 
113 BOOST_AUTO_TEST_SUITE_END()
114