1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "armnnOnnxParser/IOnnxParser.hpp"
8 #include "ParserPrototxtFixture.hpp"
9
10 BOOST_AUTO_TEST_SUITE(OnnxParser)
11
12 struct ReshapeMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
13 {
ReshapeMainFixtureReshapeMainFixture14 ReshapeMainFixture(const std::string& dataType)
15 {
16 m_Prototext = R"(
17 ir_version: 3
18 producer_name: "CNTK"
19 producer_version: "2.5.1"
20 domain: "ai.cntk"
21 model_version: 1
22 graph {
23 name: "CNTKGraph"
24 input {
25 name: "Input"
26 type {
27 tensor_type {
28 elem_type: )" + dataType + R"(
29 shape {
30 dim {
31 dim_value: 4
32 }
33 }
34 }
35 }
36 }
37 input {
38 name: "Shape"
39 type {
40 tensor_type {
41 elem_type: 7
42 shape {
43 dim {
44 dim_value: 2
45 }
46 }
47 }
48 }
49 }
50 node {
51 input: "Input"
52 input: "Shape"
53 output: "Output"
54 name: "reshape"
55 op_type: "Reshape"
56
57 }
58 initializer {
59 dims: 2
60 data_type: 7
61 int64_data: 2
62 int64_data: 2
63 name: "Shape"
64 }
65 output {
66 name: "Output"
67 type {
68 tensor_type {
69 elem_type: 1
70 shape {
71 dim {
72 dim_value: 2
73 }
74 dim {
75 dim_value: 2
76 }
77 }
78 }
79 }
80 }
81 }
82 opset_import {
83 version: 7
84 })";
85 }
86 };
87
88 struct ReshapeRank4Fixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
89 {
ReshapeRank4FixtureReshapeRank4Fixture90 ReshapeRank4Fixture(const std::string& dataType)
91 {
92 m_Prototext = R"(
93 ir_version: 3
94 producer_name: "CNTK"
95 producer_version: "2.5.1"
96 domain: "ai.cntk"
97 model_version: 1
98 graph {
99 name: "CNTKGraph"
100 input {
101 name: "Input"
102 type {
103 tensor_type {
104 elem_type: )" + dataType + R"(
105 shape {
106 dim {
107 dim_value: 2
108 }
109 dim {
110 dim_value: 2
111 }
112 dim {
113 dim_value: 3
114 }
115 dim {
116 dim_value: 3
117 }
118 }
119 }
120 }
121 }
122 input {
123 name: "Shape"
124 type {
125 tensor_type {
126 elem_type: 7
127 shape {
128 dim {
129 dim_value: 2
130 }
131 }
132 }
133 }
134 }
135 node {
136 input: "Input"
137 input: "Shape"
138 output: "Output"
139 name: "reshape"
140 op_type: "Reshape"
141
142 }
143 initializer {
144 dims: 2
145 data_type: 7
146 int64_data: 2
147 int64_data: 2
148 name: "Shape"
149 }
150 output {
151 name: "Output"
152 type {
153 tensor_type {
154 elem_type: 1
155 shape {
156 dim {
157 dim_value: 6
158 }
159 dim {
160 dim_value: 6
161 }
162 }
163 }
164 }
165 }
166 }
167 opset_import {
168 version: 7
169 })";
170 }
171 };
172
173 struct ReshapeValidFixture : ReshapeMainFixture
174 {
ReshapeValidFixtureReshapeValidFixture175 ReshapeValidFixture() : ReshapeMainFixture("1") {
176 Setup();
177 }
178 };
179
180 struct ReshapeValidRank4Fixture : ReshapeRank4Fixture
181 {
ReshapeValidRank4FixtureReshapeValidRank4Fixture182 ReshapeValidRank4Fixture() : ReshapeRank4Fixture("1") {
183 Setup();
184 }
185 };
186
187 struct ReshapeInvalidFixture : ReshapeMainFixture
188 {
ReshapeInvalidFixtureReshapeInvalidFixture189 ReshapeInvalidFixture() : ReshapeMainFixture("10") { }
190 };
191
// Rank-1 -> 2x2 reshape: the four input values must appear in the 2x2 output
// unchanged and in the same flat order.
BOOST_FIXTURE_TEST_CASE(ValidReshapeTest, ReshapeValidFixture)
{
    RunTest<2>({ { "Input",  { 0.0f, 1.0f, 2.0f, 3.0f } } },
               { { "Output", { 0.0f, 1.0f, 2.0f, 3.0f } } });
}
196
// Rank-4 (2x2x3x3) -> 6x6 reshape: reshaping must not reorder data, so the
// 36 input values must come out of the 6x6 output in the same flat order.
// Every value is distinct. A repeating pattern (e.g. 1,2,3,4,...) would make
// whole 12-element chunks identical and hide a reshape that permuted them.
BOOST_FIXTURE_TEST_CASE(ValidRank4ReshapeTest, ReshapeValidRank4Fixture)
{
    RunTest<2>(
        {{"Input",
          {  0.0f,  1.0f,  2.0f,  3.0f,  4.0f,  5.0f,
             6.0f,  7.0f,  8.0f,  9.0f, 10.0f, 11.0f,
            12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f,
            18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f,
            24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f,
            30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f }}},
        {{"Output",
          {  0.0f,  1.0f,  2.0f,  3.0f,  4.0f,  5.0f,
             6.0f,  7.0f,  8.0f,  9.0f, 10.0f, 11.0f,
            12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f,
            18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f,
            24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f,
            30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f }}});
}
209
// The Reshape input is declared as ONNX data type 10 (FLOAT16). Parsing the
// model is expected to be rejected with armnn::ParseException.
BOOST_FIXTURE_TEST_CASE(IncorrectDataTypeReshape, ReshapeInvalidFixture)
{
    BOOST_CHECK_THROW(this->Setup(), armnn::ParseException);
}
214
215 BOOST_AUTO_TEST_SUITE_END()
216