1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "ModelAccuracyChecker.hpp"
6
7 #include <boost/test/unit_test.hpp>
8
9 #include <iostream>
10 #include <string>
11
12 using namespace armnnUtils;
13
14 namespace {
15 struct TestHelper
16 {
GetValidationLabelSet__anonf4b0a63b0111::TestHelper17 const std::map<std::string, std::string> GetValidationLabelSet()
18 {
19 std::map<std::string, std::string> validationLabelSet;
20 validationLabelSet.insert(std::make_pair("val_01.JPEG", "goldfinch"));
21 validationLabelSet.insert(std::make_pair("val_02.JPEG", "magpie"));
22 validationLabelSet.insert(std::make_pair("val_03.JPEG", "brambling"));
23 validationLabelSet.insert(std::make_pair("val_04.JPEG", "robin"));
24 validationLabelSet.insert(std::make_pair("val_05.JPEG", "indigo bird"));
25 validationLabelSet.insert(std::make_pair("val_06.JPEG", "ostrich"));
26 validationLabelSet.insert(std::make_pair("val_07.JPEG", "jay"));
27 validationLabelSet.insert(std::make_pair("val_08.JPEG", "snowbird"));
28 validationLabelSet.insert(std::make_pair("val_09.JPEG", "house finch"));
29 validationLabelSet.insert(std::make_pair("val_09.JPEG", "bulbul"));
30
31 return validationLabelSet;
32 }
GetModelOutputLabels__anonf4b0a63b0111::TestHelper33 const std::vector<armnnUtils::LabelCategoryNames> GetModelOutputLabels()
34 {
35 const std::vector<armnnUtils::LabelCategoryNames> modelOutputLabels =
36 {
37 {"ostrich", "Struthio camelus"},
38 {"brambling", "Fringilla montifringilla"},
39 {"goldfinch", "Carduelis carduelis"},
40 {"house finch", "linnet", "Carpodacus mexicanus"},
41 {"junco", "snowbird"},
42 {"indigo bunting", "indigo finch", "indigo bird", "Passerina cyanea"},
43 {"robin", "American robin", "Turdus migratorius"},
44 {"bulbul"},
45 {"jay"},
46 {"magpie"}
47 };
48 return modelOutputLabels;
49 }
50 };
51 }
52
53 BOOST_AUTO_TEST_SUITE(ModelAccuracyCheckerTest)
54
55 using TContainer = mapbox::util::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
56
BOOST_FIXTURE_TEST_CASE(TestFloat32OutputTensorAccuracy,TestHelper)57 BOOST_FIXTURE_TEST_CASE(TestFloat32OutputTensorAccuracy, TestHelper)
58 {
59 ModelAccuracyChecker checker(GetValidationLabelSet(), GetModelOutputLabels());
60
61 // Add image 1 and check accuracy
62 std::vector<float> inferenceOutputVector1 = {0.05f, 0.10f, 0.70f, 0.15f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
63 TContainer inference1Container(inferenceOutputVector1);
64 std::vector<TContainer> outputTensor1;
65 outputTensor1.push_back(inference1Container);
66
67 std::string imageName = "val_01.JPEG";
68 checker.AddImageResult<TContainer>(imageName, outputTensor1);
69
70 // Top 1 Accuracy
71 float totalAccuracy = checker.GetAccuracy(1);
72 BOOST_CHECK(totalAccuracy == 100.0f);
73
74 // Add image 2 and check accuracy
75 std::vector<float> inferenceOutputVector2 = {0.10f, 0.0f, 0.0f, 0.0f, 0.05f, 0.70f, 0.0f, 0.0f, 0.0f, 0.15f};
76 TContainer inference2Container(inferenceOutputVector2);
77 std::vector<TContainer> outputTensor2;
78 outputTensor2.push_back(inference2Container);
79
80 imageName = "val_02.JPEG";
81 checker.AddImageResult<TContainer>(imageName, outputTensor2);
82
83 // Top 1 Accuracy
84 totalAccuracy = checker.GetAccuracy(1);
85 BOOST_CHECK(totalAccuracy == 50.0f);
86
87 // Top 2 Accuracy
88 totalAccuracy = checker.GetAccuracy(2);
89 BOOST_CHECK(totalAccuracy == 100.0f);
90
91 // Add image 3 and check accuracy
92 std::vector<float> inferenceOutputVector3 = {0.0f, 0.10f, 0.0f, 0.0f, 0.05f, 0.70f, 0.0f, 0.0f, 0.0f, 0.15f};
93 TContainer inference3Container(inferenceOutputVector3);
94 std::vector<TContainer> outputTensor3;
95 outputTensor3.push_back(inference3Container);
96
97 imageName = "val_03.JPEG";
98 checker.AddImageResult<TContainer>(imageName, outputTensor3);
99
100 // Top 1 Accuracy
101 totalAccuracy = checker.GetAccuracy(1);
102 BOOST_CHECK(totalAccuracy == 33.3333321f);
103
104 // Top 2 Accuracy
105 totalAccuracy = checker.GetAccuracy(2);
106 BOOST_CHECK(totalAccuracy == 66.6666641f);
107
108 // Top 3 Accuracy
109 totalAccuracy = checker.GetAccuracy(3);
110 BOOST_CHECK(totalAccuracy == 100.0f);
111 }
112
113 BOOST_AUTO_TEST_SUITE_END()
114