//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include <boost/test/unit_test.hpp>
#include "armnnTfParser/ITfParser.hpp"
#include "ParserPrototxtFixture.hpp"

#include <array>

BOOST_AUTO_TEST_SUITE(TensorflowParser)

struct FusedBatchNormFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
{
    explicit FusedBatchNormFixture(const std::string& dataLayout)
    {
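        // The graph under test: a float Placeholder ("graphInput") feeding a FusedBatchNorm node.
        // The remaining inputs follow TensorFlow's (x, scale, offset, mean, variance) order,
        // so Const_1 is the scale (1.0), Const_2 the offset (0.0), then mean (5.0) and variance (2.0).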
        m_Prototext = "node { \n"
            "  name: \"graphInput\" \n"
            "  op: \"Placeholder\" \n"
            "  attr { \n"
            "    key: \"dtype\" \n"
            "    value { \n"
            "      type: DT_FLOAT \n"
            "    } \n"
            "  } \n"
            "  attr { \n"
            "    key: \"shape\" \n"
            "    value { \n"
            "      shape { \n"
            "      } \n"
            "    } \n"
            "  } \n"
            "} \n"
            "node { \n"
            "  name: \"Const_1\" \n"
            "  op: \"Const\" \n"
            "  attr { \n"
            "    key: \"dtype\" \n"
            "    value { \n"
            "      type: DT_FLOAT \n"
            "    } \n"
            "  } \n"
            "  attr { \n"
            "    key: \"value\" \n"
            "    value { \n"
            "      tensor { \n"
            "        dtype: DT_FLOAT \n"
            "        tensor_shape { \n"
            "          dim { \n"
            "            size: 1 \n"
            "          } \n"
            "        } \n"
            "        float_val: 1.0 \n"
            "      } \n"
            "    } \n"
            "  } \n"
            "} \n"
            "node { \n"
            "  name: \"Const_2\" \n"
            "  op: \"Const\" \n"
            "  attr { \n"
            "    key: \"dtype\" \n"
            "    value { \n"
            "      type: DT_FLOAT \n"
            "    } \n"
            "  } \n"
            "  attr { \n"
            "    key: \"value\" \n"
            "    value { \n"
            "      tensor { \n"
            "        dtype: DT_FLOAT \n"
            "        tensor_shape { \n"
            "          dim { \n"
            "            size: 1 \n"
            "          } \n"
            "        } \n"
            "        float_val: 0.0 \n"
            "      } \n"
            "    } \n"
            "  } \n"
            "} \n"
            "node { \n"
            "  name: \"FusedBatchNormLayer/mean\" \n"
            "  op: \"Const\" \n"
            "  attr { \n"
            "    key: \"dtype\" \n"
            "    value { \n"
            "      type: DT_FLOAT \n"
            "    } \n"
            "  } \n"
            "  attr { \n"
            "    key: \"value\" \n"
            "    value { \n"
            "      tensor { \n"
            "        dtype: DT_FLOAT \n"
            "        tensor_shape { \n"
            "          dim { \n"
            "            size: 1 \n"
            "          } \n"
            "        } \n"
            "        float_val: 5.0 \n"
            "      } \n"
            "    } \n"
            "  } \n"
            "} \n"
            "node { \n"
            "  name: \"FusedBatchNormLayer/variance\" \n"
            "  op: \"Const\" \n"
            "  attr { \n"
            "    key: \"dtype\" \n"
            "    value { \n"
            "      type: DT_FLOAT \n"
            "    } \n"
            "  } \n"
            "  attr { \n"
            "    key: \"value\" \n"
            "    value { \n"
            "      tensor { \n"
            "        dtype: DT_FLOAT \n"
            "        tensor_shape { \n"
            "          dim { \n"
            "            size: 1 \n"
            "          } \n"
            "        } \n"
            "        float_val: 2.0 \n"
            "      } \n"
            "    } \n"
            "  } \n"
            "} \n"
            "node { \n"
            "  name: \"output\" \n"
            "  op: \"FusedBatchNorm\" \n"
            "  input: \"graphInput\" \n"
            "  input: \"Const_1\" \n"
            "  input: \"Const_2\" \n"
            "  input: \"FusedBatchNormLayer/mean\" \n"
            "  input: \"FusedBatchNormLayer/variance\" \n"
            "  attr { \n"
            "    key: \"T\" \n"
            "    value { \n"
            "      type: DT_FLOAT \n"
            "    } \n"
            "  } \n";

        // NOTE: we only explicitly set data_format when it is not the default NHWC
        if (dataLayout != "NHWC")
        {
            m_Prototext.append("  attr { \n"
                               "    key: \"data_format\" \n"
                               "    value { \n"
                               "      s: \"");
            m_Prototext.append(dataLayout);
            m_Prototext.append("\" \n"
                               "    } \n"
                               "  } \n");
        }

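        // The epsilon below (0.0010000000475) is simply the float32 representation of 0.001.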
        m_Prototext.append("  attr { \n"
                           "    key: \"epsilon\" \n"
                           "    value { \n"
                           "      f: 0.0010000000475 \n"
                           "    } \n"
                           "  } \n"
                           "  attr { \n"
                           "    key: \"is_training\" \n"
                           "    value { \n"
                           "      b: false \n"
                           "    } \n"
                           "  } \n"
                           "} \n");

        // Set the input shape according to the data layout
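        // (NHWC reads as { batch, height, width, channels }, NCHW as { batch, channels, height, width };
        // either way this is a single-batch, single-channel 3x3 input.)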
        std::array<unsigned int, 4> dims;
        if (dataLayout == "NHWC")
        {
            dims = { 1u, 3u, 3u, 1u };
        }
        else // dataLayout == "NCHW"
        {
            dims = { 1u, 1u, 3u, 3u };
        }

        SetupSingleInputSingleOutput(armnn::TensorShape(4, dims.data()), "graphInput", "output");
    }
};

struct FusedBatchNormNhwcFixture : FusedBatchNormFixture
{
    FusedBatchNormNhwcFixture() : FusedBatchNormFixture("NHWC") {}
};
BOOST_FIXTURE_TEST_CASE(ParseFusedBatchNormNhwc, FusedBatchNormNhwcFixture)
{
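    // Expected values follow the inference-mode batch norm formula
    //     y = scale * (x - mean) / sqrt(variance + epsilon) + offset
    // with scale = 1, offset = 0, mean = 5, variance = 2 and epsilon ~= 0.001 from the fixture,
    // e.g. x = 1 gives (1 - 5) / sqrt(2.001) ~= -2.8277204.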
    RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 },               // Input data.
               { -2.8277204f, -2.12079024f, -1.4138602f,
                 -0.7069301f,  0.0f,         0.7069301f,
                  1.4138602f,  2.12079024f,  2.8277204f }); // Expected output data.
}

struct FusedBatchNormNchwFixture : FusedBatchNormFixture
{
    FusedBatchNormNchwFixture() : FusedBatchNormFixture("NCHW") {}
};
BOOST_FIXTURE_TEST_CASE(ParseFusedBatchNormNchw, FusedBatchNormNchwFixture)
{
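    // With only one channel the per-element normalisation is identical to the NHWC case above,
    // so the same input yields the same expected output.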
    RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 },               // Input data.
               { -2.8277204f, -2.12079024f, -1.4138602f,
                 -0.7069301f,  0.0f,         0.7069301f,
                  1.4138602f,  2.12079024f,  2.8277204f }); // Expected output data.
}

BOOST_AUTO_TEST_SUITE_END()