• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include <boost/test/unit_test.hpp>
6 #include "../OnnxParser.hpp"
7 #include  "ParserPrototxtFixture.hpp"
8 #include <onnx/onnx.pb.h>
9 #include "google/protobuf/stubs/logging.h"
10 
11 
12 using ModelPtr = std::unique_ptr<onnx::ModelProto>;
13 
14 BOOST_AUTO_TEST_SUITE(OnnxParser)
15 
16 struct GetInputsOutputsMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
17 {
GetInputsOutputsMainFixtureGetInputsOutputsMainFixture18     explicit GetInputsOutputsMainFixture()
19     {
20         m_Prototext = R"(
21                    ir_version: 3
22                    producer_name:  "CNTK"
23                    producer_version:  "2.5.1"
24                    domain:  "ai.cntk"
25                    model_version: 1
26                    graph {
27                      name:  "CNTKGraph"
28                      input {
29                         name: "Input"
30                         type {
31                           tensor_type {
32                             elem_type: 1
33                             shape {
34                               dim {
35                                 dim_value: 4
36                               }
37                             }
38                           }
39                         }
40                       }
41                      node {
42                          input: "Input"
43                          output: "Output"
44                          name: "ActivationLayer"
45                          op_type: "Relu"
46                     }
47                       output {
48                           name: "Output"
49                           type {
50                              tensor_type {
51                                elem_type: 1
52                                shape {
53                                    dim {
54                                        dim_value: 4
55                                    }
56                                }
57                             }
58                          }
59                       }
60                     }
61                    opset_import {
62                       version: 7
63                     })";
64         Setup();
65     }
66 };
67 
68 
/// GetInputs on the single-Relu model must report exactly one input, "Input".
BOOST_FIXTURE_TEST_CASE(GetInput, GetInputsOutputsMainFixture)
{
    // Re-load the fixture's text as a standalone ModelProto so the static
    // GetInputs helper can be queried directly.
    ModelPtr model = armnnOnnxParser::OnnxParser::LoadModelFromString(m_Prototext.c_str());
    const std::vector<std::string> tensors = armnnOnnxParser::OnnxParser::GetInputs(model);

    // REQUIRE (not CHECK): a size mismatch must abort the test before
    // tensors[0] is read, since indexing an empty vector is undefined behaviour.
    // The unsigned literal avoids a signed/unsigned comparison with size().
    BOOST_REQUIRE_EQUAL(1u, tensors.size());
    BOOST_CHECK_EQUAL("Input", tensors[0]);
}
77 
/// GetOutputs on the single-Relu model must report exactly one output, "Output".
BOOST_FIXTURE_TEST_CASE(GetOutput, GetInputsOutputsMainFixture)
{
    // Re-load the fixture's text as a standalone ModelProto so the static
    // GetOutputs helper can be queried directly.
    ModelPtr model = armnnOnnxParser::OnnxParser::LoadModelFromString(m_Prototext.c_str());
    const std::vector<std::string> tensors = armnnOnnxParser::OnnxParser::GetOutputs(model);

    // REQUIRE (not CHECK): a size mismatch must abort the test before
    // tensors[0] is read, since indexing an empty vector is undefined behaviour.
    // The unsigned literal avoids a signed/unsigned comparison with size().
    BOOST_REQUIRE_EQUAL(1u, tensors.size());
    BOOST_CHECK_EQUAL("Output", tensors[0]);
}
85 
86 struct GetEmptyInputsOutputsFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
87 {
GetEmptyInputsOutputsFixtureGetEmptyInputsOutputsFixture88     GetEmptyInputsOutputsFixture()
89     {
90         m_Prototext = R"(
91                    ir_version: 3
92                    producer_name:  "CNTK "
93                    producer_version:  "2.5.1 "
94                    domain:  "ai.cntk "
95                    model_version: 1
96                    graph {
97                      name:  "CNTKGraph "
98                      node {
99                         output:  "Output"
100                         attribute {
101                           name: "value"
102                           t {
103                               dims: 7
104                               data_type: 1
105                               float_data: 0.0
106                               float_data: 1.0
107                               float_data: 2.0
108                               float_data: 3.0
109                               float_data: 4.0
110                               float_data: 5.0
111                               float_data: 6.0
112 
113                           }
114                           type: 1
115                         }
116                         name:  "constantNode"
117                         op_type:  "Constant"
118                       }
119                       output {
120                           name:  "Output"
121                           type {
122                              tensor_type {
123                                elem_type: 1
124                                shape {
125                                  dim {
126                                     dim_value: 7
127                                  }
128                                }
129                              }
130                           }
131                       }
132                    }
133                    opset_import {
134                       version: 7
135                     })";
136         Setup();
137     }
138 };
139 
/// A model whose graph declares no inputs must yield an empty GetInputs result.
BOOST_FIXTURE_TEST_CASE(GetEmptyInputs, GetEmptyInputsOutputsFixture)
{
    ModelPtr model = armnnOnnxParser::OnnxParser::LoadModelFromString(m_Prototext.c_str());
    const std::vector<std::string> tensors = armnnOnnxParser::OnnxParser::GetInputs(model);

    // Unsigned literal so the comparison with size() is not signed vs. unsigned.
    BOOST_CHECK_EQUAL(0u, tensors.size());
}
146 
/// Loading from an empty prototext string is expected to throw
/// armnn::InvalidArgumentException.
BOOST_AUTO_TEST_CASE(GetInputsNullModel)
{
    const char* const emptyPrototext = "";
    BOOST_CHECK_THROW(armnnOnnxParser::OnnxParser::LoadModelFromString(emptyPrototext),
                      armnn::InvalidArgumentException);
}
151 
/// Loading from text that is not valid prototext is expected to throw
/// armnn::ParseException.
BOOST_AUTO_TEST_CASE(GetOutputsNullModel)
{
    // Silence protobuf's own error logging while the malformed text is parsed.
    // The guard is declared directly rather than copy-initialised from a
    // temporary: LogSilencer only has an implicit copy constructor, so a
    // non-elided copy would destroy the temporary (undoing the silencing)
    // before the parse and then decrement the silencer count a second time.
    google::protobuf::LogSilencer silencer;
    BOOST_CHECK_THROW(armnnOnnxParser::OnnxParser::LoadModelFromString("nknnk"), armnn::ParseException);
}
157 
158 struct GetInputsMultipleFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
159 {
GetInputsMultipleFixtureGetInputsMultipleFixture160     GetInputsMultipleFixture() {
161 
162         m_Prototext = R"(
163                    ir_version: 3
164                    producer_name:  "CNTK"
165                    producer_version:  "2.5.1"
166                    domain:  "ai.cntk"
167                    model_version: 1
168                    graph {
169                      name:  "CNTKGraph"
170                      input {
171                         name: "Input0"
172                         type {
173                           tensor_type {
174                             elem_type: 1
175                             shape {
176                               dim {
177                                 dim_value: 1
178                               }
179                               dim {
180                                 dim_value: 1
181                               }
182                               dim {
183                                 dim_value: 1
184                               }
185                               dim {
186                                 dim_value: 4
187                               }
188                             }
189                           }
190                         }
191                       }
192                       input {
193                          name: "Input1"
194                          type {
195                            tensor_type {
196                              elem_type: 1
197                              shape {
198                                  dim {
199                                    dim_value: 4
200                                  }
201                              }
202                            }
203                          }
204                        }
205                        node {
206                             input: "Input0"
207                             input: "Input1"
208                             output: "Output"
209                             name: "addition"
210                             op_type: "Add"
211                             doc_string: ""
212                             domain: ""
213                           }
214                           output {
215                               name: "Output"
216                               type {
217                                  tensor_type {
218                                    elem_type: 1
219                                    shape {
220                                        dim {
221                                            dim_value: 1
222                                        }
223                                        dim {
224                                            dim_value: 1
225                                        }
226                                        dim {
227                                            dim_value: 1
228                                        }
229                                        dim {
230                                            dim_value: 4
231                                        }
232                                    }
233                                 }
234                             }
235                         }
236                     }
237                    opset_import {
238                       version: 7
239                     })";
240         Setup();
241     }
242 };
243 
/// GetInputs on the two-input Add model must report both inputs, in
/// declaration order: "Input0" then "Input1".
BOOST_FIXTURE_TEST_CASE(GetInputsMultipleInputs, GetInputsMultipleFixture)
{
    ModelPtr model = armnnOnnxParser::OnnxParser::LoadModelFromString(m_Prototext.c_str());
    const std::vector<std::string> tensors = armnnOnnxParser::OnnxParser::GetInputs(model);

    // REQUIRE (not CHECK): a size mismatch must abort the test before
    // tensors[0]/tensors[1] are read, which would be out of bounds otherwise.
    // The unsigned literal avoids a signed/unsigned comparison with size().
    BOOST_REQUIRE_EQUAL(2u, tensors.size());
    BOOST_CHECK_EQUAL("Input0", tensors[0]);
    BOOST_CHECK_EQUAL("Input1", tensors[1]);
}
252 
253 
254 
255 BOOST_AUTO_TEST_SUITE_END()
256