• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2019-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 
25 #include "arm_compute/graph.h"
26 #include "support/ToolchainSupport.h"
27 #include "utils/CommonGraphOptions.h"
28 #include "utils/GraphUtils.h"
29 #include "utils/Utils.h"
30 
31 using namespace arm_compute;
32 using namespace arm_compute::utils;
33 using namespace arm_compute::graph::frontend;
34 using namespace arm_compute::graph_utils;
35 
36 /** Example demonstrating how to implement Mnist's network using the Compute Library's graph API */
37 class GraphMnistExample : public Example
38 {
39 public:
GraphMnistExample()40     GraphMnistExample()
41         : cmd_parser(), common_opts(cmd_parser), common_params(), graph(0, "LeNet")
42     {
43     }
do_setup(int argc,char ** argv)44     bool do_setup(int argc, char **argv) override
45     {
46         // Parse arguments
47         cmd_parser.parse(argc, argv);
48         cmd_parser.validate();
49 
50         // Consume common parameters
51         common_params = consume_common_graph_parameters(common_opts);
52 
53         // Return when help menu is requested
54         if(common_params.help)
55         {
56             cmd_parser.print_help(argv[0]);
57             return false;
58         }
59 
60         // Print parameter values
61         std::cout << common_params << std::endl;
62 
63         // Get trainable parameters data path
64         std::string data_path = common_params.data_path;
65 
66         // Add model path to data path
67         if(!data_path.empty() && arm_compute::is_data_type_quantized_asymmetric(common_params.data_type))
68         {
69             data_path += "/cnn_data/mnist_qasymm8_model/";
70         }
71 
72         // Create input descriptor
73         const auto        operation_layout = common_params.data_layout;
74         const TensorShape tensor_shape     = permute_shape(TensorShape(28U, 28U, 1U), DataLayout::NCHW, operation_layout);
75         TensorDescriptor  input_descriptor = TensorDescriptor(tensor_shape, common_params.data_type).set_layout(operation_layout);
76 
77         const QuantizationInfo in_quant_info = QuantizationInfo(0.003921568859368563f, 0);
78 
79         const std::vector<std::pair<QuantizationInfo, QuantizationInfo>> conv_quant_info =
80         {
81             { QuantizationInfo(0.004083447158336639f, 138), QuantizationInfo(0.0046257381327450275f, 0) }, // conv0
82             { QuantizationInfo(0.0048590428195893764f, 149), QuantizationInfo(0.03558270260691643f, 0) },  // conv1
83             { QuantizationInfo(0.004008443560451269f, 146), QuantizationInfo(0.09117382764816284f, 0) },   // conv2
84             { QuantizationInfo(0.004344311077147722f, 160), QuantizationInfo(0.5494495034217834f, 167) },  // fc
85         };
86 
87         // Set weights trained layout
88         const DataLayout        weights_layout = DataLayout::NHWC;
89         FullyConnectedLayerInfo fc_info        = FullyConnectedLayerInfo();
90         fc_info.set_weights_trained_layout(weights_layout);
91 
92         graph << common_params.target
93               << common_params.fast_math_hint
94               << InputLayer(input_descriptor.set_quantization_info(in_quant_info),
95                             get_input_accessor(common_params))
96               << ConvolutionLayer(
97                   3U, 3U, 32U,
98                   get_weights_accessor(data_path, "conv2d_weights_quant_FakeQuantWithMinMaxVars.npy", weights_layout),
99                   get_weights_accessor(data_path, "conv2d_Conv2D_bias.npy"),
100                   PadStrideInfo(1U, 1U, 1U, 1U), 1, conv_quant_info.at(0).first, conv_quant_info.at(0).second)
101               .set_name("Conv0")
102 
103               << ConvolutionLayer(
104                   3U, 3U, 32U,
105                   get_weights_accessor(data_path, "conv2d_1_weights_quant_FakeQuantWithMinMaxVars.npy", weights_layout),
106                   get_weights_accessor(data_path, "conv2d_1_Conv2D_bias.npy"),
107                   PadStrideInfo(1U, 1U, 1U, 1U), 1, conv_quant_info.at(1).first, conv_quant_info.at(1).second)
108               .set_name("conv1")
109 
110               << PoolingLayer(PoolingLayerInfo(PoolingType::MAX, 2, operation_layout, PadStrideInfo(2, 2, 0, 0))).set_name("maxpool1")
111 
112               << ConvolutionLayer(
113                   3U, 3U, 32U,
114                   get_weights_accessor(data_path, "conv2d_2_weights_quant_FakeQuantWithMinMaxVars.npy", weights_layout),
115                   get_weights_accessor(data_path, "conv2d_2_Conv2D_bias.npy"),
116                   PadStrideInfo(1U, 1U, 1U, 1U), 1, conv_quant_info.at(2).first, conv_quant_info.at(2).second)
117               .set_name("conv2")
118 
119               << PoolingLayer(PoolingLayerInfo(PoolingType::MAX, 2, operation_layout, PadStrideInfo(2, 2, 0, 0))).set_name("maxpool2")
120 
121               << FullyConnectedLayer(
122                   10U,
123                   get_weights_accessor(data_path, "dense_weights_quant_FakeQuantWithMinMaxVars_transpose.npy", weights_layout),
124                   get_weights_accessor(data_path, "dense_MatMul_bias.npy"),
125                   fc_info, conv_quant_info.at(3).first, conv_quant_info.at(3).second)
126               .set_name("fc")
127 
128               << SoftmaxLayer().set_name("prob");
129 
130         if(arm_compute::is_data_type_quantized_asymmetric(common_params.data_type))
131         {
132             graph << DequantizationLayer().set_name("dequantize");
133         }
134 
135         graph << OutputLayer(get_output_accessor(common_params, 5));
136 
137         // Finalize graph
138         GraphConfig config;
139         config.num_threads = common_params.threads;
140         config.use_tuner   = common_params.enable_tuner;
141         config.tuner_mode  = common_params.tuner_mode;
142         config.tuner_file  = common_params.tuner_file;
143 
144         graph.finalize(common_params.target, config);
145 
146         return true;
147     }
do_run()148     void do_run() override
149     {
150         // Run graph
151         graph.run();
152     }
153 
154 private:
155     CommandLineParser  cmd_parser;
156     CommonGraphOptions common_opts;
157     CommonGraphParams  common_params;
158     Stream             graph;
159 };
160 
161 /** Main program for Mnist Example
162  *
163  * @note To list all the possible arguments execute the binary appended with the --help option
164  *
165  * @param[in] argc Number of arguments
166  * @param[in] argv Arguments
167  */
main(int argc,char ** argv)168 int main(int argc, char **argv)
169 {
170     return arm_compute::utils::run_example<GraphMnistExample>(argc, argv);
171 }
172