1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
#include <boost/test/unit_test.hpp>
#include "armnnTfParser/ITfParser.hpp"
#include "ParserPrototxtFixture.hpp"

#include <array>
#include <stdexcept>
#include <string>
11
12 BOOST_AUTO_TEST_SUITE(TensorflowParser)
13
14 struct FusedBatchNormFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
15 {
FusedBatchNormFixtureFusedBatchNormFixture16 explicit FusedBatchNormFixture(const std::string& dataLayout)
17 {
18 m_Prototext = "node { \n"
19 " name: \"graphInput\" \n"
20 " op: \"Placeholder\" \n"
21 " attr { \n"
22 " key: \"dtype\" \n"
23 " value { \n"
24 " type: DT_FLOAT \n"
25 " } \n"
26 " } \n"
27 " attr { \n"
28 " key: \"shape\" \n"
29 " value { \n"
30 " shape { \n"
31 " } \n"
32 " } \n"
33 " } \n"
34 "} \n"
35 "node { \n"
36 " name: \"Const_1\" \n"
37 " op: \"Const\" \n"
38 " attr { \n"
39 " key: \"dtype\" \n"
40 " value { \n"
41 " type: DT_FLOAT \n"
42 " } \n"
43 " } \n"
44 " attr { \n"
45 " key: \"value\" \n"
46 " value { \n"
47 " tensor { \n"
48 " dtype: DT_FLOAT \n"
49 " tensor_shape { \n"
50 " dim { \n"
51 " size: 1 \n"
52 " } \n"
53 " } \n"
54 " float_val: 1.0 \n"
55 " } \n"
56 " } \n"
57 " } \n"
58 "} \n"
59 "node { \n"
60 " name: \"Const_2\" \n"
61 " op: \"Const\" \n"
62 " attr { \n"
63 " key: \"dtype\" \n"
64 " value { \n"
65 " type: DT_FLOAT \n"
66 " } \n"
67 " } \n"
68 " attr { \n"
69 " key: \"value\" \n"
70 " value { \n"
71 " tensor { \n"
72 " dtype: DT_FLOAT \n"
73 " tensor_shape { \n"
74 " dim { \n"
75 " size: 1 \n"
76 " } \n"
77 " } \n"
78 " float_val: 0.0 \n"
79 " } \n"
80 " } \n"
81 " } \n"
82 "} \n"
83 "node { \n"
84 " name: \"FusedBatchNormLayer/mean\" \n"
85 " op: \"Const\" \n"
86 " attr { \n"
87 " key: \"dtype\" \n"
88 " value { \n"
89 " type: DT_FLOAT \n"
90 " } \n"
91 " } \n"
92 " attr { \n"
93 " key: \"value\" \n"
94 " value { \n"
95 " tensor { \n"
96 " dtype: DT_FLOAT \n"
97 " tensor_shape { \n"
98 " dim { \n"
99 " size: 1 \n"
100 " } \n"
101 " } \n"
102 " float_val: 5.0 \n"
103 " } \n"
104 " } \n"
105 " } \n"
106 "} \n"
107 "node { \n"
108 " name: \"FusedBatchNormLayer/variance\" \n"
109 " op: \"Const\" \n"
110 " attr { \n"
111 " key: \"dtype\" \n"
112 " value { \n"
113 " type: DT_FLOAT \n"
114 " } \n"
115 " } \n"
116 " attr { \n"
117 " key: \"value\" \n"
118 " value { \n"
119 " tensor { \n"
120 " dtype: DT_FLOAT \n"
121 " tensor_shape { \n"
122 " dim { \n"
123 " size: 1 \n"
124 " } \n"
125 " } \n"
126 " float_val: 2.0 \n"
127 " } \n"
128 " } \n"
129 " } \n"
130 "} \n"
131 "node { \n"
132 " name: \"output\" \n"
133 " op: \"FusedBatchNorm\" \n"
134 " input: \"graphInput\" \n"
135 " input: \"Const_1\" \n"
136 " input: \"Const_2\" \n"
137 " input: \"FusedBatchNormLayer/mean\" \n"
138 " input: \"FusedBatchNormLayer/variance\" \n"
139 " attr { \n"
140 " key: \"T\" \n"
141 " value { \n"
142 " type: DT_FLOAT \n"
143 " } \n"
144 " } \n";
145
146 // NOTE: we only explicitly set data_format when it is not the default NHWC
147 if (dataLayout != "NHWC")
148 {
149 m_Prototext.append(" attr { \n"
150 " key: \"data_format\" \n"
151 " value { \n"
152 " s: \"");
153 m_Prototext.append(dataLayout);
154 m_Prototext.append("\" \n"
155 " } \n"
156 " } \n");
157 }
158
159 m_Prototext.append(" attr { \n"
160 " key: \"epsilon\" \n"
161 " value { \n"
162 " f: 0.0010000000475 \n"
163 " } \n"
164 " } \n"
165 " attr { \n"
166 " key: \"is_training\" \n"
167 " value { \n"
168 " b: false \n"
169 " } \n"
170 " } \n"
171 "} \n");
172
173 // Set the input shape according to the data layout
174 std::array<unsigned int, 4> dims;
175 if (dataLayout == "NHWC")
176 {
177 dims = { 1u, 3u, 3u, 1u };
178 }
179 else // dataLayout == "NCHW"
180 {
181 dims = { 1u, 1u, 3u, 3u };
182 }
183
184 SetupSingleInputSingleOutput(armnn::TensorShape(4, dims.data()), "graphInput", "output");
185 }
186 };
187
188 struct FusedBatchNormNhwcFixture : FusedBatchNormFixture
189 {
FusedBatchNormNhwcFixtureFusedBatchNormNhwcFixture190 FusedBatchNormNhwcFixture() : FusedBatchNormFixture("NHWC"){}
191 };
BOOST_FIXTURE_TEST_CASE(ParseFusedBatchNormNhwc, FusedBatchNormNhwcFixture)
{
    // Each expected value is (x - mean) / sqrt(variance + epsilon) * scale + offset, with
    // mean = 5, variance = 2, epsilon ~= 0.001, scale = 1 and offset = 0. The centre
    // element (x = 5) therefore maps to 0, and the results are symmetric around it.
    RunTest<4>(
        // Input data: a 3x3 single-channel image holding 1..9.
        { 1.0f, 2.0f, 3.0f,
          4.0f, 5.0f, 6.0f,
          7.0f, 8.0f, 9.0f },
        // Expected output data.
        { -2.8277204f, -2.12079024f, -1.4138602f,
          -0.7069301f,  0.0f,         0.7069301f,
           1.4138602f,  2.12079024f,  2.8277204f });
}
199
200 struct FusedBatchNormNchwFixture : FusedBatchNormFixture
201 {
FusedBatchNormNchwFixtureFusedBatchNormNchwFixture202 FusedBatchNormNchwFixture() : FusedBatchNormFixture("NCHW"){}
203 };
BOOST_FIXTURE_TEST_CASE(ParseFusedBatchNormNchw, FusedBatchNormNchwFixture)
{
    // Same normalisation as the NHWC case: (x - mean) / sqrt(variance + epsilon) * scale + offset,
    // with mean = 5, variance = 2, epsilon ~= 0.001, scale = 1 and offset = 0.
    // With a single channel, an NCHW [1,1,3,3] buffer and an NHWC [1,3,3,1] buffer share the
    // same element order, so the data below matches the NHWC test.
    // NOTE(review): because C == 1, this may not catch a parser that mishandles data_format.
    // Consider a multi-channel variant.
    RunTest<4>(
        // Input data: a 3x3 single-channel image holding 1..9.
        { 1.0f, 2.0f, 3.0f,
          4.0f, 5.0f, 6.0f,
          7.0f, 8.0f, 9.0f },
        // Expected output data.
        { -2.8277204f, -2.12079024f, -1.4138602f,
          -0.7069301f,  0.0f,         0.7069301f,
           1.4138602f,  2.12079024f,  2.8277204f });
}
211
212 BOOST_AUTO_TEST_SUITE_END()
213