• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
7 
8 #include "../QuantizationDataSet.hpp"
9 
10 #include <armnn/Optional.hpp>
11 #include <Filesystem.hpp>
12 #include <iostream>
13 #include <fstream>
14 #include <vector>
15 #include <map>
16 
17 
18 using namespace armnnQuantizer;
19 
20 struct CsvTestHelper {
21 
CsvTestHelperCsvTestHelper22     CsvTestHelper()
23     {
24         BOOST_TEST_MESSAGE("setup fixture");
25     }
26 
~CsvTestHelperCsvTestHelper27     ~CsvTestHelper()
28     {
29         BOOST_TEST_MESSAGE("teardown fixture");
30         TearDown();
31     }
32 
CreateTempCsvFileCsvTestHelper33     std::string CreateTempCsvFile(std::map<int, std::vector<float>> csvData)
34     {
35         fs::path fileDir = fs::temp_directory_path();
36         fs::path p = armnnUtils::Filesystem::NamedTempFile("Armnn-QuantizationCreateTempCsvFileTest-TempFile.csv");
37 
38         fs::path tensorInput1{fileDir / "input_0_0.raw"};
39         fs::path tensorInput2{fileDir / "input_1_0.raw"};
40         fs::path tensorInput3{fileDir / "input_2_0.raw"};
41 
42         try
43         {
44             std::ofstream ofs{p};
45 
46             std::ofstream ofs1{tensorInput1};
47             std::ofstream ofs2{tensorInput2};
48             std::ofstream ofs3{tensorInput3};
49 
50 
51             for(auto entry : csvData.at(0))
52             {
53                 ofs1 << entry << " ";
54             }
55             for(auto entry : csvData.at(1))
56             {
57                 ofs2 << entry << " ";
58             }
59             for(auto entry : csvData.at(2))
60             {
61                 ofs3 << entry << " ";
62             }
63 
64             ofs << "0, 0, " << tensorInput1.c_str() << std::endl;
65             ofs << "2, 0, " << tensorInput3.c_str() << std::endl;
66             ofs << "1, 0, " << tensorInput2.c_str() << std::endl;
67 
68             ofs.close();
69             ofs1.close();
70             ofs2.close();
71             ofs3.close();
72         }
73         catch (std::exception &e)
74         {
75             std::cerr << "Unable to write to file at location [" << p.c_str() << "] : " << e.what() << std::endl;
76             BOOST_TEST(false);
77         }
78 
79         m_CsvFile = p;
80         return p.string();
81     }
82 
TearDownCsvTestHelper83     void TearDown()
84     {
85        RemoveCsvFile();
86     }
87 
RemoveCsvFileCsvTestHelper88     void RemoveCsvFile()
89     {
90         if (m_CsvFile)
91         {
92             try
93             {
94                 fs::remove(m_CsvFile.value());
95             }
96             catch (std::exception &e)
97             {
98                 std::cerr << "Unable to delete file [" << m_CsvFile.value() << "] : " << e.what() << std::endl;
99                 BOOST_TEST(false);
100             }
101         }
102     }
103 
104     armnn::Optional<fs::path> m_CsvFile;
105 };
106 
107 
108 BOOST_AUTO_TEST_SUITE(QuantizationDataSetTests)
109 
// Builds a three-pass data set on disk through the fixture, loads it with
// QuantizationDataSet and checks that each pass exposes exactly the values written.
BOOST_FIXTURE_TEST_CASE(CheckDataSet, CsvTestHelper)
{
    // Expected tensor values, keyed by the pass they belong to.
    const std::map<int, std::vector<float>> csvData =
    {
        { 0, { 0.111111f, 0.222222f, 0.333333f } },
        { 1, { 0.444444f, 0.555555f, 0.666666f } },
        { 2, { 0.777777f, 0.888888f, 0.999999f } }
    };

    std::string myCsvFile = CsvTestHelper::CreateTempCsvFile(csvData);
    QuantizationDataSet dataSet(myCsvFile);
    BOOST_TEST(!dataSet.IsEmpty());

    int csvRow = 0;
    for(armnnQuantizer::QuantizationInput input : dataSet)
    {
        // The assertions below expect the passes to be visited in ascending pass ID order.
        BOOST_TEST(input.GetPassId() == csvRow);

        BOOST_TEST(input.GetLayerBindingIds().size() == 1);
        BOOST_TEST(input.GetLayerBindingIds()[0] == 0);

        // Fetch the tensor data once instead of once per assertion.
        const auto entryData = input.GetDataForEntry(0);
        BOOST_TEST(entryData.size() == 3);

        // Check that QuantizationInput data for binding ID 0 corresponds to float values
        // used for populating the CSV file used by QuantizationDataSet
        const std::vector<float>& expected = csvData.at(csvRow);
        BOOST_TEST(entryData.at(0) == expected.at(0));
        BOOST_TEST(entryData.at(1) == expected.at(1));
        BOOST_TEST(entryData.at(2) == expected.at(2));
        ++csvRow;
    }

    // Guard against the loop body never (or only partially) running: every pass
    // described in csvData must have been visited.
    BOOST_TEST(csvRow == static_cast<int>(csvData.size()));
}
139 
140 BOOST_AUTO_TEST_SUITE_END();