1 //
2 // Copyright © 2020 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
#include "armnnTfParser/ITfParser.hpp"
#include "ParserPrototxtFixture.hpp"

#include <PrototxtConversions.hpp>

#include <boost/test/unit_test.hpp>

#include <map>
#include <string>
#include <vector>
12
13 BOOST_AUTO_TEST_SUITE(TensorflowParser)
14
15 struct StackFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
16 {
StackFixtureStackFixture17 explicit StackFixture(const armnn::TensorShape& inputShape0,
18 const armnn::TensorShape& inputShape1,
19 int axis = 0)
20 {
21 m_Prototext = R"(
22 node {
23 name: "input0"
24 op: "Placeholder"
25 attr {
26 key: "dtype"
27 value {
28 type: DT_FLOAT
29 }
30 }
31 attr {
32 key: "shape"
33 value {
34 shape {
35 }
36 }
37 }
38 }
39 node {
40 name: "input1"
41 op: "Placeholder"
42 attr {
43 key: "dtype"
44 value {
45 type: DT_FLOAT
46 }
47 }
48 attr {
49 key: "shape"
50 value {
51 shape {
52 }
53 }
54 }
55 }
56 node {
57 name: "output"
58 op: "Stack"
59 input: "input0"
60 input: "input1"
61 attr {
62 key: "axis"
63 value {
64 i: )";
65 m_Prototext += std::to_string(axis);
66 m_Prototext += R"(
67 }
68 }
69 })";
70
71 Setup({{"input0", inputShape0 },
72 {"input1", inputShape1 }}, {"output"});
73 }
74 };
75
76 struct Stack3DFixture : StackFixture
77 {
Stack3DFixtureStack3DFixture78 Stack3DFixture() : StackFixture({ 3, 2, 3 }, { 3, 2, 3 }, 3 ) {}
79 };
80
BOOST_FIXTURE_TEST_CASE(Stack3D, Stack3DFixture)
{
    // Both inputs are 3x2x3; values listed in row-major order.
    const std::vector<float> input0 = {  1,  2,  3,    4,  5,  6,
                                         7,  8,  9,   10, 11, 12,
                                        13, 14, 15,   16, 17, 18 };
    const std::vector<float> input1 = { 19, 20, 21,   22, 23, 24,
                                        25, 26, 27,   28, 29, 30,
                                        31, 32, 33,   34, 35, 36 };

    // Rank-4 output: every element of input0 is immediately followed by the
    // element of input1 at the same position.
    const std::vector<float> expectedOutput = {  1, 19,   2, 20,   3, 21,
                                                 4, 22,   5, 23,   6, 24,
                                                 7, 25,   8, 26,   9, 27,
                                                10, 28,  11, 29,  12, 30,
                                                13, 31,  14, 32,  15, 33,
                                                16, 34,  17, 35,  18, 36 };

    RunTest<4>({ { "input0", input0 }, { "input1", input1 } },
               { { "output", expectedOutput } });
}
124
125 struct Stack3DNegativeAxisFixture : StackFixture
126 {
Stack3DNegativeAxisFixtureStack3DNegativeAxisFixture127 Stack3DNegativeAxisFixture() : StackFixture({ 3, 2, 3 }, { 3, 2, 3 }, -1 ) {}
128 };
129
BOOST_FIXTURE_TEST_CASE(Stack3DNegativeAxis, Stack3DNegativeAxisFixture)
{
    // Both inputs are 3x2x3; values listed in row-major order.
    const std::vector<float> input0 = {  1,  2,  3,    4,  5,  6,
                                         7,  8,  9,   10, 11, 12,
                                        13, 14, 15,   16, 17, 18 };
    const std::vector<float> input1 = { 19, 20, 21,   22, 23, 24,
                                        25, 26, 27,   28, 29, 30,
                                        31, 32, 33,   34, 35, 36 };

    // With axis -1 the expected rank-4 output interleaves the inputs element by
    // element, i.e. the new dimension is the innermost one.
    const std::vector<float> expectedOutput = {  1, 19,   2, 20,   3, 21,
                                                 4, 22,   5, 23,   6, 24,
                                                 7, 25,   8, 26,   9, 27,
                                                10, 28,  11, 29,  12, 30,
                                                13, 31,  14, 32,  15, 33,
                                                16, 34,  17, 35,  18, 36 };

    RunTest<4>({ { "input0", input0 }, { "input1", input1 } },
               { { "output", expectedOutput } });
}
173
174 BOOST_AUTO_TEST_SUITE_END()
175