1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "armnnTfParser/ITfParser.hpp"
7
8 #include <ParserPrototxtFixture.hpp>
9 #include <PrototxtConversions.hpp>
10
11 #include <boost/test/unit_test.hpp>
12
13 BOOST_AUTO_TEST_SUITE(TensorflowParser)
14
15 struct MeanFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
16 {
MeanFixtureMeanFixture17 explicit MeanFixture(const armnn::TensorShape& inputShape, const armnn::TensorShape& outputShape,
18 const std::vector<unsigned int>& axis, bool keepDims)
19 {
20 std::string protobufAxisString;
21 std::vector<unsigned int> protobufAxis(axis);
22
23 // If no axis range is specified, the reduction is applied to
24 // all dimensions of the input tensor
25 if (protobufAxis.size() == 0)
26 {
27 for (unsigned int i = 0; i < inputShape.GetNumDimensions(); ++i)
28 {
29 protobufAxis.push_back(i);
30 }
31 }
32
33 for (unsigned int i = 0; i < protobufAxis.size(); ++i)
34 {
35 protobufAxisString.append(armnnUtils::ConvertInt32ToOctalString(static_cast<int>(protobufAxis[i])));
36 }
37
38 m_Prototext = R"(node {
39 name: "input"
40 op: "Placeholder"
41 attr {
42 key: "dtype"
43 value {
44 type: DT_FLOAT
45 }
46 }
47 attr {
48 key: "shape"
49 value {
50 shape {
51 }
52 }
53 }
54 }
55 node {
56 name: "Const"
57 op: "Const"
58 attr {
59 key: "dtype"
60 value {
61 type: DT_INT32
62 }
63 }
64 attr {
65 key: "value"
66 value { )";
67
68 if (axis.size() == 1)
69 {
70 m_Prototext.append(R"( tensor {
71 dtype: DT_INT32
72 tensor_shape {
73 }
74 int_val: )").append(std::to_string(protobufAxis[0])).append(R"(
75 } )");
76 }
77 else
78 {
79 m_Prototext.append(R"( tensor {
80 dtype: DT_INT32
81 tensor_shape {
82 dim {
83 size: 2
84 }
85 }
86 tensor_content: ")").append(protobufAxisString).append(R"("
87 } )");
88 }
89
90 m_Prototext.append(R"( }
91 }
92 }
93 node {
94 name: "output"
95 op: "Mean"
96 input: "input"
97 input: "Const"
98 attr {
99 key: "T"
100 value {
101 type: DT_FLOAT
102 }
103 }
104 attr {
105 key: "Tidx"
106 value {
107 type: DT_INT32
108 }
109 }
110 attr {
111 key: "keep_dims"
112 value {
113 b: )").append(keepDims ? "true" : "false").append(R"(
114 }
115 }
116 })");
117
118 SetupSingleInputSingleOutput(inputShape, outputShape, "input", "output");
119 }
120 };
121
122 struct MeanNoAxisNoKeepDimsFixture: MeanFixture
123 {
MeanNoAxisNoKeepDimsFixtureMeanNoAxisNoKeepDimsFixture124 MeanNoAxisNoKeepDimsFixture() : MeanFixture({ 2, 3 }, { 1 }, {}, false) {}
125 };
126
127 struct MeanWithAxis0NoKeepDimsFixture: MeanFixture
128 {
MeanWithAxis0NoKeepDimsFixtureMeanWithAxis0NoKeepDimsFixture129 MeanWithAxis0NoKeepDimsFixture() : MeanFixture({ 2, 3 }, { 3 }, { 0 }, false) {}
130 };
131
132 struct MeanWithAxis1NoKeepDimsFixture: MeanFixture
133 {
MeanWithAxis1NoKeepDimsFixtureMeanWithAxis1NoKeepDimsFixture134 MeanWithAxis1NoKeepDimsFixture() : MeanFixture({ 2, 3 }, { 2 }, { 1 }, false) {}
135 };
136
137 struct MeanWithAxis0KeepDimsFixture: MeanFixture
138 {
MeanWithAxis0KeepDimsFixtureMeanWithAxis0KeepDimsFixture139 MeanWithAxis0KeepDimsFixture() : MeanFixture({ 2, 3 }, { 1, 3 }, { 0 }, true) {}
140 };
141
142 struct MeanWithAxis1KeepDimsFixture: MeanFixture
143 {
MeanWithAxis1KeepDimsFixtureMeanWithAxis1KeepDimsFixture144 MeanWithAxis1KeepDimsFixture() : MeanFixture({ 2, 3 }, { 2, 1 }, { 1 }, true) {}
145 };
146
147
// Empty axis list: the fixture lists every input dimension, so a single value is expected;
// 1.5 is the arithmetic mean of the six inputs.
BOOST_FIXTURE_TEST_CASE(MeanNoAxisNoKeepDims, MeanNoAxisNoKeepDimsFixture)
{
    RunTest<1>(
        {
            { "input", { 1.0f, 1.0f, 1.0f,
                         2.0f, 2.0f, 2.0f } }
        },
        {
            { "output", { 1.5f } }
        });
}
153
// Axis 0 without keep_dims: three output values are expected for the { 3 } output shape.
// NOTE(review): the expected values correspond to a row-major { 2, 3 } input,
// i.e. rows {1,1,1} and {2,2,2} averaged element-wise.
BOOST_FIXTURE_TEST_CASE(MeanWithAxis0NoKeepDims, MeanWithAxis0NoKeepDimsFixture)
{
    RunTest<1>(
        {
            { "input", { 1.0f, 1.0f, 1.0f,
                         2.0f, 2.0f, 2.0f } }
        },
        {
            { "output", { 1.5f, 1.5f, 1.5f } }
        });
}
159
// Axis 1 without keep_dims: two output values are expected for the { 2 } output shape.
// NOTE(review): the expected values correspond to a row-major { 2, 3 } input,
// i.e. one mean per row {1,1,1} and {2,2,2}.
BOOST_FIXTURE_TEST_CASE(MeanWithAxis1NoKeepDims, MeanWithAxis1NoKeepDimsFixture)
{
    RunTest<1>(
        {
            { "input", { 1.0f, 1.0f, 1.0f,
                         2.0f, 2.0f, 2.0f } }
        },
        {
            { "output", { 1.0f, 2.0f } }
        });
}
165
// Axis 0 with keep_dims: same three values as the no-keep-dims case, but checked through
// RunTest<2> because the fixture declares a two-dimensional { 1, 3 } output shape.
BOOST_FIXTURE_TEST_CASE(MeanWithAxis0KeepDims, MeanWithAxis0KeepDimsFixture)
{
    RunTest<2>(
        {
            { "input", { 1.0f, 1.0f, 1.0f,
                         2.0f, 2.0f, 2.0f } }
        },
        {
            { "output", { 1.5f, 1.5f, 1.5f } }
        });
}
171
// Axis 1 with keep_dims: same two values as the no-keep-dims case, but checked through
// RunTest<2> because the fixture declares a two-dimensional { 2, 1 } output shape.
BOOST_FIXTURE_TEST_CASE(MeanWithAxis1KeepDims, MeanWithAxis1KeepDimsFixture)
{
    RunTest<2>(
        {
            { "input", { 1.0f, 1.0f, 1.0f,
                         2.0f, 2.0f, 2.0f } }
        },
        {
            { "output", { 1.0f, 2.0f } }
        });
}
177
178 BOOST_AUTO_TEST_SUITE_END()
179