• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
7 #include "armnnOnnxParser/IOnnxParser.hpp"
8 #include  "ParserPrototxtFixture.hpp"
9 
10 BOOST_AUTO_TEST_SUITE(OnnxParser)
11 
12 struct ReshapeMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
13 {
ReshapeMainFixtureReshapeMainFixture14     ReshapeMainFixture(const std::string& dataType)
15     {
16         m_Prototext = R"(
17                    ir_version: 3
18                    producer_name:  "CNTK"
19                    producer_version:  "2.5.1"
20                    domain:  "ai.cntk"
21                    model_version: 1
22                    graph {
23                      name:  "CNTKGraph"
24                      input {
25                         name: "Input"
26                         type {
27                           tensor_type {
28                             elem_type: )" + dataType + R"(
29                             shape {
30                               dim {
31                                 dim_value: 4
32                               }
33                             }
34                           }
35                         }
36                       }
37                       input {
38                          name: "Shape"
39                          type {
40                            tensor_type {
41                              elem_type: 7
42                              shape {
43                                dim {
44                                  dim_value: 2
45                                }
46                              }
47                            }
48                          }
49                        }
50                      node {
51                          input: "Input"
52                          input: "Shape"
53                          output: "Output"
54                          name: "reshape"
55                          op_type: "Reshape"
56 
57                       }
58                       initializer {
59                         dims: 2
60                         data_type: 7
61                         int64_data: 2
62                         int64_data: 2
63                         name: "Shape"
64                      }
65                       output {
66                           name: "Output"
67                           type {
68                              tensor_type {
69                                elem_type: 1
70                                shape {
71                                    dim {
72                                        dim_value: 2
73                                    }
74                                    dim {
75                                        dim_value: 2
76                                    }
77                                }
78                             }
79                           }
80                        }
81                     }
82                    opset_import {
83                       version: 7
84                     })";
85     }
86 };
87 
88 struct ReshapeRank4Fixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
89 {
ReshapeRank4FixtureReshapeRank4Fixture90     ReshapeRank4Fixture(const std::string& dataType)
91     {
92         m_Prototext = R"(
93                    ir_version: 3
94                    producer_name:  "CNTK"
95                    producer_version:  "2.5.1"
96                    domain:  "ai.cntk"
97                    model_version: 1
98                    graph {
99                      name:  "CNTKGraph"
100                      input {
101                         name: "Input"
102                         type {
103                           tensor_type {
104                             elem_type: )" + dataType + R"(
105                             shape {
106                               dim {
107                                 dim_value: 2
108                               }
109                               dim {
110                                 dim_value: 2
111                               }
112                               dim {
113                                 dim_value: 3
114                               }
115                               dim {
116                                 dim_value: 3
117                               }
118                             }
119                           }
120                         }
121                       }
122                       input {
123                          name: "Shape"
124                          type {
125                            tensor_type {
126                              elem_type: 7
127                              shape {
128                                dim {
129                                  dim_value: 2
130                                }
131                              }
132                            }
133                          }
134                        }
135                      node {
136                          input: "Input"
137                          input: "Shape"
138                          output: "Output"
139                          name: "reshape"
140                          op_type: "Reshape"
141 
142                       }
143                       initializer {
144                         dims: 2
145                         data_type: 7
146                         int64_data: 2
147                         int64_data: 2
148                         name: "Shape"
149                      }
150                       output {
151                           name: "Output"
152                           type {
153                              tensor_type {
154                                elem_type: 1
155                                shape {
156                                    dim {
157                                        dim_value: 6
158                                    }
159                                    dim {
160                                        dim_value: 6
161                                    }
162                                }
163                             }
164                           }
165                        }
166                     }
167                    opset_import {
168                       version: 7
169                     })";
170     }
171 };
172 
173 struct ReshapeValidFixture : ReshapeMainFixture
174 {
ReshapeValidFixtureReshapeValidFixture175     ReshapeValidFixture() : ReshapeMainFixture("1") {
176         Setup();
177     }
178 };
179 
180 struct ReshapeValidRank4Fixture : ReshapeRank4Fixture
181 {
ReshapeValidRank4FixtureReshapeValidRank4Fixture182     ReshapeValidRank4Fixture() : ReshapeRank4Fixture("1") {
183         Setup();
184     }
185 };
186 
187 struct ReshapeInvalidFixture : ReshapeMainFixture
188 {
ReshapeInvalidFixtureReshapeInvalidFixture189     ReshapeInvalidFixture() : ReshapeMainFixture("10") { }
190 };
191 
// Reshape must not alter the data: the four input values are expected back
// unchanged, in the same order, from the rank-2 "Output" tensor.
BOOST_FIXTURE_TEST_CASE(ValidReshapeTest, ReshapeValidFixture)
{
    RunTest<2>(
        {{"Input",  { 0.0f, 1.0f, 2.0f, 3.0f }}},
        {{"Output", { 0.0f, 1.0f, 2.0f, 3.0f }}});
}
196 
// Rank-4 input (36 values) reshaped to rank 2: the expected output is the very
// same sequence of 36 values, i.e. nine repetitions of 1, 2, 3, 4.
BOOST_FIXTURE_TEST_CASE(ValidRank4ReshapeTest, ReshapeValidRank4Fixture)
{
    RunTest<2>(
        {{"Input",
          {1.0f, 2.0f, 3.0f, 4.0f,
           1.0f, 2.0f, 3.0f, 4.0f,
           1.0f, 2.0f, 3.0f, 4.0f,
           1.0f, 2.0f, 3.0f, 4.0f,
           1.0f, 2.0f, 3.0f, 4.0f,
           1.0f, 2.0f, 3.0f, 4.0f,
           1.0f, 2.0f, 3.0f, 4.0f,
           1.0f, 2.0f, 3.0f, 4.0f,
           1.0f, 2.0f, 3.0f, 4.0f}}},
        {{"Output",
          {1.0f, 2.0f, 3.0f, 4.0f,
           1.0f, 2.0f, 3.0f, 4.0f,
           1.0f, 2.0f, 3.0f, 4.0f,
           1.0f, 2.0f, 3.0f, 4.0f,
           1.0f, 2.0f, 3.0f, 4.0f,
           1.0f, 2.0f, 3.0f, 4.0f,
           1.0f, 2.0f, 3.0f, 4.0f,
           1.0f, 2.0f, 3.0f, 4.0f,
           1.0f, 2.0f, 3.0f, 4.0f}}});
}
209 
// With the "Input" element type set to "10" by ReshapeInvalidFixture, Setup()
// is expected to fail by throwing armnn::ParseException.
BOOST_FIXTURE_TEST_CASE(IncorrectDataTypeReshape, ReshapeInvalidFixture)
{
    BOOST_CHECK_THROW(
        Setup(),
        armnn::ParseException);
}
214 
215 BOOST_AUTO_TEST_SUITE_END()
216