• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2018-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "CommonGraphOptions.h"
25 
26 #include "arm_compute/core/Utils.h"
27 #include "arm_compute/graph/TypeLoader.h"
28 #include "arm_compute/graph/TypePrinter.h"
29 
30 #include "support/StringSupport.h"
31 
32 #include <map>
33 
34 using namespace arm_compute::graph;
35 
36 namespace
37 {
parse_validation_range(const std::string & validation_range)38 std::pair<unsigned int, unsigned int> parse_validation_range(const std::string &validation_range)
39 {
40     std::pair<unsigned int /* start */, unsigned int /* end */> range = { 0, std::numeric_limits<unsigned int>::max() };
41     if(!validation_range.empty())
42     {
43         std::string       str;
44         std::stringstream stream(validation_range);
45 
46         // Get first value
47         std::getline(stream, str, ',');
48         if(stream.fail())
49         {
50             return range;
51         }
52         else
53         {
54             range.first = arm_compute::support::cpp11::stoi(str);
55         }
56 
57         // Get second value
58         std::getline(stream, str);
59         if(stream.fail())
60         {
61             range.second = range.first;
62             return range;
63         }
64         else
65         {
66             range.second = arm_compute::support::cpp11::stoi(str);
67         }
68     }
69     return range;
70 }
71 } // namespace
72 
73 namespace arm_compute
74 {
75 namespace utils
76 {
operator <<(::std::ostream & os,const CommonGraphParams & common_params)77 ::std::ostream &operator<<(::std::ostream &os, const CommonGraphParams &common_params)
78 {
79     std::string false_str = std::string("false");
80     std::string true_str  = std::string("true");
81 
82     os << "Threads : " << common_params.threads << std::endl;
83     os << "Target : " << common_params.target << std::endl;
84     os << "Data type : " << common_params.data_type << std::endl;
85     os << "Data layout : " << common_params.data_layout << std::endl;
86     os << "Tuner enabled? : " << (common_params.enable_tuner ? true_str : false_str) << std::endl;
87     os << "Cache enabled? : " << (common_params.enable_cl_cache ? true_str : false_str) << std::endl;
88     os << "Tuner mode : " << common_params.tuner_mode << std::endl;
89     os << "Tuner file : " << common_params.tuner_file << std::endl;
90     os << "Fast math enabled? : " << (common_params.fast_math_hint == FastMathHint::Enabled ? true_str : false_str) << std::endl;
91     if(!common_params.data_path.empty())
92     {
93         os << "Data path : " << common_params.data_path << std::endl;
94     }
95     if(!common_params.image.empty())
96     {
97         os << "Image file : " << common_params.image << std::endl;
98     }
99     if(!common_params.labels.empty())
100     {
101         os << "Labels file : " << common_params.labels << std::endl;
102     }
103     if(!common_params.validation_file.empty())
104     {
105         os << "Validation range : " << common_params.validation_range_start << "-" << common_params.validation_range_end << std::endl;
106         os << "Validation file : " << common_params.validation_file << std::endl;
107         if(!common_params.validation_path.empty())
108         {
109             os << "Validation path : " << common_params.validation_path << std::endl;
110         }
111     }
112 
113     return os;
114 }
115 
CommonGraphOptions(CommandLineParser & parser)116 CommonGraphOptions::CommonGraphOptions(CommandLineParser &parser)
117     : help(parser.add_option<ToggleOption>("help")),
118       threads(parser.add_option<SimpleOption<int>>("threads", 1)),
119       target(),
120       data_type(),
121       data_layout(),
122       enable_tuner(parser.add_option<ToggleOption>("enable-tuner")),
123       enable_cl_cache(parser.add_option<ToggleOption>("enable-cl-cache")),
124       tuner_mode(),
125       fast_math_hint(parser.add_option<ToggleOption>("fast-math")),
126       data_path(parser.add_option<SimpleOption<std::string>>("data")),
127       image(parser.add_option<SimpleOption<std::string>>("image")),
128       labels(parser.add_option<SimpleOption<std::string>>("labels")),
129       validation_file(parser.add_option<SimpleOption<std::string>>("validation-file")),
130       validation_path(parser.add_option<SimpleOption<std::string>>("validation-path")),
131       validation_range(parser.add_option<SimpleOption<std::string>>("validation-range")),
132       tuner_file(parser.add_option<SimpleOption<std::string>>("tuner-file"))
133 {
134     std::set<arm_compute::graph::Target> supported_targets
135     {
136         Target::NEON,
137         Target::CL,
138         Target::GC,
139     };
140 
141     std::set<arm_compute::DataType> supported_data_types
142     {
143         DataType::F16,
144         DataType::F32,
145         DataType::QASYMM8,
146     };
147 
148     std::set<DataLayout> supported_data_layouts
149     {
150         DataLayout::NHWC,
151         DataLayout::NCHW,
152     };
153 
154     const std::set<CLTunerMode> supported_tuner_modes
155     {
156         CLTunerMode::EXHAUSTIVE,
157         CLTunerMode::NORMAL,
158         CLTunerMode::RAPID
159     };
160 
161     target      = parser.add_option<EnumOption<Target>>("target", supported_targets, Target::NEON);
162     data_type   = parser.add_option<EnumOption<DataType>>("type", supported_data_types, DataType::F32);
163     data_layout = parser.add_option<EnumOption<DataLayout>>("layout", supported_data_layouts);
164     tuner_mode  = parser.add_option<EnumOption<CLTunerMode>>("tuner-mode", supported_tuner_modes, CLTunerMode::NORMAL);
165 
166     help->set_help("Show this help message");
167     threads->set_help("Number of threads to use");
168     target->set_help("Target to execute on");
169     data_type->set_help("Data type to use");
170     data_layout->set_help("Data layout to use");
171     enable_tuner->set_help("Enable OpenCL dynamic tuner");
172     enable_cl_cache->set_help("Enable OpenCL program caches");
173     tuner_mode->set_help(
174         "Configures the time taken by the tuner to tune. "
175         "Exhaustive: slowest but produces the most performant LWS configuration. "
176         "Normal: slow but produces the LWS configurations on par with Exhaustive most of the time. "
177         "Rapid: fast but produces less performant LWS configurations");
178     fast_math_hint->set_help("Enable fast math");
179     data_path->set_help("Path where graph parameters reside");
180     image->set_help("Input image for the graph");
181     labels->set_help("File containing the output labels");
182     validation_file->set_help("File used to validate the graph");
183     validation_path->set_help("Path to the validation data");
184     validation_range->set_help("Range of the images to validate for (Format : start,end)");
185     tuner_file->set_help("File to load/save CLTuner values");
186 }
187 
consume_common_graph_parameters(CommonGraphOptions & options)188 CommonGraphParams consume_common_graph_parameters(CommonGraphOptions &options)
189 {
190     FastMathHint fast_math_hint_value = options.fast_math_hint->value() ? FastMathHint::Enabled : FastMathHint::Disabled;
191     auto         validation_range     = parse_validation_range(options.validation_range->value());
192 
193     CommonGraphParams common_params;
194     common_params.help      = options.help->is_set() ? options.help->value() : false;
195     common_params.threads   = options.threads->value();
196     common_params.target    = options.target->value();
197     common_params.data_type = options.data_type->value();
198     if(options.data_layout->is_set())
199     {
200         common_params.data_layout = options.data_layout->value();
201     }
202     common_params.enable_tuner           = options.enable_tuner->is_set() ? options.enable_tuner->value() : false;
203     common_params.enable_cl_cache        = common_params.target == arm_compute::graph::Target::CL ? (options.enable_cl_cache->is_set() ? options.enable_cl_cache->value() : true) : false;
204     common_params.tuner_mode             = options.tuner_mode->value();
205     common_params.fast_math_hint         = options.fast_math_hint->is_set() ? fast_math_hint_value : FastMathHint::Disabled;
206     common_params.data_path              = options.data_path->value();
207     common_params.image                  = options.image->value();
208     common_params.labels                 = options.labels->value();
209     common_params.validation_file        = options.validation_file->value();
210     common_params.validation_path        = options.validation_path->value();
211     common_params.validation_range_start = validation_range.first;
212     common_params.validation_range_end   = validation_range.second;
213     common_params.tuner_file             = options.tuner_file->value();
214 
215     return common_params;
216 }
217 } // namespace utils
218 } // namespace arm_compute
219