• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
#include <armnn/utility/Assert.hpp>
#include <boost/test/unit_test.hpp>

#include "armnnTfParser/ITfParser.hpp"
#include "ParserPrototxtFixture.hpp"

#include <map>
#include <stdexcept>
#include <string>
#include <vector>
14 
15 
16 BOOST_AUTO_TEST_SUITE(TensorflowParser)
17 
18 struct AddNFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
19 {
AddNFixtureAddNFixture20     AddNFixture(const std::vector<armnn::TensorShape> inputShapes, unsigned int numberOfInputs)
21     {
22         ARMNN_ASSERT(inputShapes.size() == numberOfInputs);
23         m_Prototext = "";
24         for (unsigned int i = 0; i < numberOfInputs; i++)
25         {
26             m_Prototext.append("node { \n");
27             m_Prototext.append("  name: \"input").append(std::to_string(i)).append("\"\n");
28             m_Prototext += R"(  op: "Placeholder"
29   attr {
30     key: "dtype"
31     value {
32       type: DT_FLOAT
33     }
34   }
35   attr {
36     key: "shape"
37     value {
38       shape {
39       }
40     }
41   }
42 }
43 )";
44         }
45         m_Prototext += R"(node {
46   name:  "output"
47   op: "AddN"
48 )";
49         for (unsigned int i = 0; i < numberOfInputs; i++)
50         {
51             m_Prototext.append("  input: \"input").append(std::to_string(i)).append("\"\n");
52         }
53         m_Prototext += R"(  attr {
54     key: "N"
55     value {
56 )";
57         m_Prototext.append("      i: ").append(std::to_string(numberOfInputs)).append("\n");
58         m_Prototext += R"(    }
59   }
60   attr {
61     key: "T"
62     value {
63       type: DT_FLOAT
64     }
65   }
66 })";
67 
68         std::map<std::string, armnn::TensorShape> inputs;
69         for (unsigned int i = 0; i < numberOfInputs; i++)
70         {
71             std::string name("input");
72             name.append(std::to_string(i));
73             inputs.emplace(std::make_pair(name, inputShapes[i]));
74         }
75         Setup(inputs, {"output"});
76     }
77 
78 };
79 
// Same-shape cases: AddN over 2, 3, 5 and 8 inputs, each of shape [2, 2].
81 struct FiveTwoDimInputsFixture : AddNFixture
82 {
FiveTwoDimInputsFixtureFiveTwoDimInputsFixture83     FiveTwoDimInputsFixture() : AddNFixture({ { 2, 2 }, { 2, 2 }, { 2, 2 }, { 2, 2 }, { 2, 2 } }, 5) {}
84 };
85 
86 
// Expected output is the element-wise sum of the five inputs.
BOOST_FIXTURE_TEST_CASE(FiveTwoDimInputs, FiveTwoDimInputsFixture)
{
    const std::map<std::string, std::vector<float>> inputData =
    {
        { "input0", { 1.0f, 2.0f,  3.0f,  4.0f } },
        { "input1", { 1.0f, 5.0f,  2.0f,  2.0f } },
        { "input2", { 1.0f, 1.0f,  2.0f,  2.0f } },
        { "input3", { 3.0f, 7.0f,  1.0f,  2.0f } },
        { "input4", { 8.0f, 0.0f, -2.0f, -3.0f } },
    };
    const std::map<std::string, std::vector<float>> expectedOutputData =
    {
        { "output", { 14.0f, 15.0f, 6.0f, 7.0f } },
    };

    RunTest<2>(inputData, expectedOutputData);
}
96 
97 struct TwoTwoDimInputsFixture : AddNFixture
98 {
TwoTwoDimInputsFixtureTwoTwoDimInputsFixture99     TwoTwoDimInputsFixture() : AddNFixture({ { 2, 2 }, { 2, 2 } }, 2) {}
100 };
101 
// Expected output is the element-wise sum of the two inputs.
BOOST_FIXTURE_TEST_CASE(TwoTwoDimInputs, TwoTwoDimInputsFixture)
{
    const std::map<std::string, std::vector<float>> inputData =
    {
        { "input0", { 1.0f, 2.0f, 3.0f, 4.0f } },
        { "input1", { 1.0f, 5.0f, 2.0f, 2.0f } },
    };
    const std::map<std::string, std::vector<float>> expectedOutputData =
    {
        { "output", { 2.0f, 7.0f, 5.0f, 6.0f } },
    };

    RunTest<2>(inputData, expectedOutputData);
}
108 
109 struct ThreeTwoDimInputsFixture : AddNFixture
110 {
ThreeTwoDimInputsFixtureThreeTwoDimInputsFixture111     ThreeTwoDimInputsFixture() : AddNFixture({ { 2, 2 }, { 2, 2 }, { 2, 2 } }, 3) {}
112 };
113 
// Expected output is the element-wise sum of the three inputs.
BOOST_FIXTURE_TEST_CASE(ThreeTwoDimInputs, ThreeTwoDimInputsFixture)
{
    const std::map<std::string, std::vector<float>> inputData =
    {
        { "input0", { 1.0f, 2.0f, 3.0f, 4.0f } },
        { "input1", { 1.0f, 5.0f, 2.0f, 2.0f } },
        { "input2", { 1.0f, 1.0f, 2.0f, 2.0f } },
    };
    const std::map<std::string, std::vector<float>> expectedOutputData =
    {
        { "output", { 3.0f, 8.0f, 7.0f, 8.0f } },
    };

    RunTest<2>(inputData, expectedOutputData);
}
121 
122 struct EightTwoDimInputsFixture : AddNFixture
123 {
EightTwoDimInputsFixtureEightTwoDimInputsFixture124     EightTwoDimInputsFixture() : AddNFixture({ { 2, 2 }, { 2, 2 }, { 2, 2 }, { 2, 2 },
125                                                { 2, 2 }, { 2, 2 }, { 2, 2 }, { 2, 2 } }, 8) {}
126 };
127 
// Expected output is the element-wise sum of the eight inputs (mixed signs exercise negative results).
BOOST_FIXTURE_TEST_CASE(EightTwoDimInputs, EightTwoDimInputsFixture)
{
    const std::map<std::string, std::vector<float>> inputData =
    {
        { "input0", {   1.0f, 2.0f,  3.0f,   4.0f } },
        { "input1", {   1.0f, 5.0f,  2.0f,   2.0f } },
        { "input2", {   1.0f, 1.0f,  2.0f,   2.0f } },
        { "input3", {   3.0f, 7.0f,  1.0f,   2.0f } },
        { "input4", {   8.0f, 0.0f, -2.0f,  -3.0f } },
        { "input5", {  -3.0f, 2.0f, -1.0f,  -5.0f } },
        { "input6", {   1.0f, 6.0f,  2.0f,   2.0f } },
        { "input7", { -19.0f, 7.0f,  1.0f, -10.0f } },
    };
    const std::map<std::string, std::vector<float>> expectedOutputData =
    {
        { "output", { -7.0f, 30.0f, 8.0f, -6.0f } },
    };

    RunTest<2>(inputData, expectedOutputData);
}
140 
141 struct ThreeInputBroadcast1D4D4DInputsFixture : AddNFixture
142 {
ThreeInputBroadcast1D4D4DInputsFixtureThreeInputBroadcast1D4D4DInputsFixture143     ThreeInputBroadcast1D4D4DInputsFixture() : AddNFixture({ { 1 }, { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, 3) {}
144 };
145 
// The single-element input0 is added to every element of the sum of input1 and input2.
BOOST_FIXTURE_TEST_CASE(ThreeInputBroadcast1D4D4DInputs, ThreeInputBroadcast1D4D4DInputsFixture)
{
    const std::map<std::string, std::vector<float>> inputData =
    {
        { "input0", { 1.0f } },
        { "input1", { 1.0f, 5.0f, 2.0f, 2.0f } },
        { "input2", { 1.0f, 1.0f, 2.0f, 2.0f } },
    };
    const std::map<std::string, std::vector<float>> expectedOutputData =
    {
        { "output", { 3.0f, 7.0f, 5.0f, 5.0f } },
    };

    RunTest<4>(inputData, expectedOutputData);
}
153 
154 struct ThreeInputBroadcast4D1D4DInputsFixture : AddNFixture
155 {
ThreeInputBroadcast4D1D4DInputsFixtureThreeInputBroadcast4D1D4DInputsFixture156     ThreeInputBroadcast4D1D4DInputsFixture() : AddNFixture({ { 1, 1, 2, 2 }, { 1 }, { 1, 1, 2, 2 } }, 3) {}
157 };
158 
// The single-element input1 (-2) is added to every element of the sum of input0 and input2.
BOOST_FIXTURE_TEST_CASE(ThreeInputBroadcast4D1D4DInputs, ThreeInputBroadcast4D1D4DInputsFixture)
{
    const std::map<std::string, std::vector<float>> inputData =
    {
        { "input0", { 1.0f, 3.0f, 9.0f, 4.0f } },
        { "input1", { -2.0f } },
        { "input2", { 1.0f, 1.0f, 2.0f, 2.0f } },
    };
    const std::map<std::string, std::vector<float>> expectedOutputData =
    {
        { "output", { 0.0f, 2.0f, 9.0f, 4.0f } },
    };

    RunTest<4>(inputData, expectedOutputData);
}
166 
167 struct ThreeInputBroadcast4D4D1DInputsFixture : AddNFixture
168 {
ThreeInputBroadcast4D4D1DInputsFixtureThreeInputBroadcast4D4D1DInputsFixture169     ThreeInputBroadcast4D4D1DInputsFixture() : AddNFixture({ { 1, 1, 2, 2 }, { 1, 1, 2, 2 }, { 1 } }, 3) {}
170 };
171 
// The single-element input2 is added to every element of the sum of input0 and input1.
BOOST_FIXTURE_TEST_CASE(ThreeInputBroadcast4D4D1DInputs, ThreeInputBroadcast4D4D1DInputsFixture)
{
    const std::map<std::string, std::vector<float>> inputData =
    {
        { "input0", { 1.0f, 5.0f, 2.0f, 2.0f } },
        { "input1", { 1.0f, 1.0f, 2.0f, 2.0f } },
        { "input2", { 1.0f } },
    };
    const std::map<std::string, std::vector<float>> expectedOutputData =
    {
        { "output", { 3.0f, 7.0f, 5.0f, 5.0f } },
    };

    RunTest<4>(inputData, expectedOutputData);
}
179 
180 BOOST_AUTO_TEST_SUITE_END()
181