1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
#include <armnn/utility/Assert.hpp>
#include <boost/test/unit_test.hpp>

#include "armnnTfParser/ITfParser.hpp"
#include "ParserPrototxtFixture.hpp"

#include <map>
#include <stdexcept>
#include <string>
#include <vector>
14
15
16 BOOST_AUTO_TEST_SUITE(TensorflowParser)
17
18 struct AddNFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
19 {
AddNFixtureAddNFixture20 AddNFixture(const std::vector<armnn::TensorShape> inputShapes, unsigned int numberOfInputs)
21 {
22 ARMNN_ASSERT(inputShapes.size() == numberOfInputs);
23 m_Prototext = "";
24 for (unsigned int i = 0; i < numberOfInputs; i++)
25 {
26 m_Prototext.append("node { \n");
27 m_Prototext.append(" name: \"input").append(std::to_string(i)).append("\"\n");
28 m_Prototext += R"( op: "Placeholder"
29 attr {
30 key: "dtype"
31 value {
32 type: DT_FLOAT
33 }
34 }
35 attr {
36 key: "shape"
37 value {
38 shape {
39 }
40 }
41 }
42 }
43 )";
44 }
45 m_Prototext += R"(node {
46 name: "output"
47 op: "AddN"
48 )";
49 for (unsigned int i = 0; i < numberOfInputs; i++)
50 {
51 m_Prototext.append(" input: \"input").append(std::to_string(i)).append("\"\n");
52 }
53 m_Prototext += R"( attr {
54 key: "N"
55 value {
56 )";
57 m_Prototext.append(" i: ").append(std::to_string(numberOfInputs)).append("\n");
58 m_Prototext += R"( }
59 }
60 attr {
61 key: "T"
62 value {
63 type: DT_FLOAT
64 }
65 }
66 })";
67
68 std::map<std::string, armnn::TensorShape> inputs;
69 for (unsigned int i = 0; i < numberOfInputs; i++)
70 {
71 std::string name("input");
72 name.append(std::to_string(i));
73 inputs.emplace(std::make_pair(name, inputShapes[i]));
74 }
75 Setup(inputs, {"output"});
76 }
77
78 };
79
// Exercise AddN with 2, 3, 5 and 8 equally shaped 2-D inputs, then with three inputs where one 1-D input is broadcast.
81 struct FiveTwoDimInputsFixture : AddNFixture
82 {
FiveTwoDimInputsFixtureFiveTwoDimInputsFixture83 FiveTwoDimInputsFixture() : AddNFixture({ { 2, 2 }, { 2, 2 }, { 2, 2 }, { 2, 2 }, { 2, 2 } }, 5) {}
84 };
85
86
BOOST_FIXTURE_TEST_CASE(FiveTwoDimInputs, FiveTwoDimInputsFixture)
{
    // Element-wise sum of five 2x2 tensors.
    const std::map<std::string, std::vector<float>> inputData =
    {
        { "input0", { 1.0f, 2.0f,  3.0f,  4.0f } },
        { "input1", { 1.0f, 5.0f,  2.0f,  2.0f } },
        { "input2", { 1.0f, 1.0f,  2.0f,  2.0f } },
        { "input3", { 3.0f, 7.0f,  1.0f,  2.0f } },
        { "input4", { 8.0f, 0.0f, -2.0f, -3.0f } },
    };
    const std::map<std::string, std::vector<float>> expectedOutput =
    {
        { "output", { 14.0f, 15.0f, 6.0f, 7.0f } },
    };
    RunTest<2>(inputData, expectedOutput);
}
96
97 struct TwoTwoDimInputsFixture : AddNFixture
98 {
TwoTwoDimInputsFixtureTwoTwoDimInputsFixture99 TwoTwoDimInputsFixture() : AddNFixture({ { 2, 2 }, { 2, 2 } }, 2) {}
100 };
101
BOOST_FIXTURE_TEST_CASE(TwoTwoDimInputs, TwoTwoDimInputsFixture)
{
    // Element-wise sum of two 2x2 tensors.
    const std::map<std::string, std::vector<float>> inputData =
    {
        { "input0", { 1.0f, 2.0f, 3.0f, 4.0f } },
        { "input1", { 1.0f, 5.0f, 2.0f, 2.0f } },
    };
    const std::map<std::string, std::vector<float>> expectedOutput =
    {
        { "output", { 2.0f, 7.0f, 5.0f, 6.0f } },
    };
    RunTest<2>(inputData, expectedOutput);
}
108
109 struct ThreeTwoDimInputsFixture : AddNFixture
110 {
ThreeTwoDimInputsFixtureThreeTwoDimInputsFixture111 ThreeTwoDimInputsFixture() : AddNFixture({ { 2, 2 }, { 2, 2 }, { 2, 2 } }, 3) {}
112 };
113
BOOST_FIXTURE_TEST_CASE(ThreeTwoDimInputs, ThreeTwoDimInputsFixture)
{
    // Element-wise sum of three 2x2 tensors.
    const std::map<std::string, std::vector<float>> inputData =
    {
        { "input0", { 1.0f, 2.0f, 3.0f, 4.0f } },
        { "input1", { 1.0f, 5.0f, 2.0f, 2.0f } },
        { "input2", { 1.0f, 1.0f, 2.0f, 2.0f } },
    };
    const std::map<std::string, std::vector<float>> expectedOutput =
    {
        { "output", { 3.0f, 8.0f, 7.0f, 8.0f } },
    };
    RunTest<2>(inputData, expectedOutput);
}
121
122 struct EightTwoDimInputsFixture : AddNFixture
123 {
EightTwoDimInputsFixtureEightTwoDimInputsFixture124 EightTwoDimInputsFixture() : AddNFixture({ { 2, 2 }, { 2, 2 }, { 2, 2 }, { 2, 2 },
125 { 2, 2 }, { 2, 2 }, { 2, 2 }, { 2, 2 } }, 8) {}
126 };
127
BOOST_FIXTURE_TEST_CASE(EightTwoDimInputs, EightTwoDimInputsFixture)
{
    // Element-wise sum of eight 2x2 tensors, including negative values.
    const std::map<std::string, std::vector<float>> inputData =
    {
        { "input0", {   1.0f, 2.0f,  3.0f,   4.0f } },
        { "input1", {   1.0f, 5.0f,  2.0f,   2.0f } },
        { "input2", {   1.0f, 1.0f,  2.0f,   2.0f } },
        { "input3", {   3.0f, 7.0f,  1.0f,   2.0f } },
        { "input4", {   8.0f, 0.0f, -2.0f,  -3.0f } },
        { "input5", {  -3.0f, 2.0f, -1.0f,  -5.0f } },
        { "input6", {   1.0f, 6.0f,  2.0f,   2.0f } },
        { "input7", { -19.0f, 7.0f,  1.0f, -10.0f } },
    };
    const std::map<std::string, std::vector<float>> expectedOutput =
    {
        { "output", { -7.0f, 30.0f, 8.0f, -6.0f } },
    };
    RunTest<2>(inputData, expectedOutput);
}
140
141 struct ThreeInputBroadcast1D4D4DInputsFixture : AddNFixture
142 {
ThreeInputBroadcast1D4D4DInputsFixtureThreeInputBroadcast1D4D4DInputsFixture143 ThreeInputBroadcast1D4D4DInputsFixture() : AddNFixture({ { 1 }, { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, 3) {}
144 };
145
BOOST_FIXTURE_TEST_CASE(ThreeInputBroadcast1D4D4DInputs, ThreeInputBroadcast1D4D4DInputsFixture)
{
    // input0 holds a single value that is added to every element of the other two inputs.
    const std::map<std::string, std::vector<float>> inputData =
    {
        { "input0", { 1.0f } },
        { "input1", { 1.0f, 5.0f, 2.0f, 2.0f } },
        { "input2", { 1.0f, 1.0f, 2.0f, 2.0f } },
    };
    const std::map<std::string, std::vector<float>> expectedOutput =
    {
        { "output", { 3.0f, 7.0f, 5.0f, 5.0f } },
    };
    RunTest<4>(inputData, expectedOutput);
}
153
154 struct ThreeInputBroadcast4D1D4DInputsFixture : AddNFixture
155 {
ThreeInputBroadcast4D1D4DInputsFixtureThreeInputBroadcast4D1D4DInputsFixture156 ThreeInputBroadcast4D1D4DInputsFixture() : AddNFixture({ { 1, 1, 2, 2 }, { 1 }, { 1, 1, 2, 2 } }, 3) {}
157 };
158
BOOST_FIXTURE_TEST_CASE(ThreeInputBroadcast4D1D4DInputs, ThreeInputBroadcast4D1D4DInputsFixture)
{
    // input1 holds a single (negative) value that is added to every element of the other two inputs.
    const std::map<std::string, std::vector<float>> inputData =
    {
        { "input0", { 1.0f, 3.0f, 9.0f, 4.0f } },
        { "input1", { -2.0f } },
        { "input2", { 1.0f, 1.0f, 2.0f, 2.0f } },
    };
    const std::map<std::string, std::vector<float>> expectedOutput =
    {
        { "output", { 0.0f, 2.0f, 9.0f, 4.0f } },
    };
    RunTest<4>(inputData, expectedOutput);
}
166
167 struct ThreeInputBroadcast4D4D1DInputsFixture : AddNFixture
168 {
ThreeInputBroadcast4D4D1DInputsFixtureThreeInputBroadcast4D4D1DInputsFixture169 ThreeInputBroadcast4D4D1DInputsFixture() : AddNFixture({ { 1, 1, 2, 2 }, { 1, 1, 2, 2 }, { 1 } }, 3) {}
170 };
171
BOOST_FIXTURE_TEST_CASE(ThreeInputBroadcast4D4D1DInputs, ThreeInputBroadcast4D4D1DInputsFixture)
{
    // input2 holds a single value that is added to every element of the other two inputs.
    const std::map<std::string, std::vector<float>> inputData =
    {
        { "input0", { 1.0f, 5.0f, 2.0f, 2.0f } },
        { "input1", { 1.0f, 1.0f, 2.0f, 2.0f } },
        { "input2", { 1.0f } },
    };
    const std::map<std::string, std::vector<float>> expectedOutput =
    {
        { "output", { 3.0f, 7.0f, 5.0f, 5.0f } },
    };
    RunTest<4>(inputData, expectedOutput);
}
179
180 BOOST_AUTO_TEST_SUITE_END()
181