• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

6 #include <boost/test/unit_test.hpp>
7 #include "armnnTfParser/ITfParser.hpp"
8 #include "ParserPrototxtFixture.hpp"
9 
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
11 
12 struct MinimumFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
13 {
MinimumFixtureMinimumFixture14     MinimumFixture()
15     {
16         m_Prototext = R"(
17             node {
18               name: "input0"
19               op: "Placeholder"
20               attr {
21                 key: "dtype"
22                 value {
23                   type: DT_FLOAT
24                 }
25               }
26               attr {
27                 key: "shape"
28                 value {
29                   shape {
30                   }
31                 }
32               }
33             }
34             node {
35               name: "input1"
36               op: "Placeholder"
37               attr {
38                 key: "dtype"
39                 value {
40                   type: DT_FLOAT
41                 }
42               }
43               attr {
44                 key: "shape"
45                 value {
46                   shape {
47                   }
48                 }
49               }
50             }
51             node {
52               name: "output"
53               op: "Minimum"
54               input: "input0"
55               input: "input1"
56               attr {
57                 key: "T"
58                 value {
59                   type: DT_FLOAT
60                 }
61               }
62             }
63         )";
64     }
65 };
66 
// Setup() with input shapes {2, 3} and {1, 2, 2, 3} must throw
// armnn::ParseException.
// These shapes are compatible under NumPy-style broadcasting rules, so this
// pins a parser limitation rather than an invalid graph.
// NOTE(review): presumably the parser's broadcast support does not cover this
// rank combination; confirm against the TfParser implementation.
// NOTE(review): the test name misspells "Minimum". It is kept as-is in case
// CI scripts or --run_test filters refer to it by name.
BOOST_FIXTURE_TEST_CASE(ParseMininumUnsupportedBroadcast, MinimumFixture)
{
    BOOST_REQUIRE_THROW(Setup({ { "input0", {2, 3} },
                                { "input1", {1, 2, 2, 3} } },
                              { "output" }),
                        armnn::ParseException);
}

75 struct MinimumFixtureAutoSetup : public MinimumFixture
76 {
MinimumFixtureAutoSetupMinimumFixtureAutoSetup77     MinimumFixtureAutoSetup(const armnn::TensorShape& input0Shape,
78                             const armnn::TensorShape& input1Shape)
79     : MinimumFixture()
80     {
81         Setup({ { "input0", input0Shape },
82                 { "input1", input1Shape } },
83               { "output" });
84     }
85 };
86 
87 struct MinimumFixture4D : public MinimumFixtureAutoSetup
88 {
MinimumFixture4DMinimumFixture4D89     MinimumFixture4D()
90     : MinimumFixtureAutoSetup({1, 2, 2, 3}, {1, 2, 2, 3}) {}
91 };
92 
// Same-shape inputs: each expected output value is the minimum of the
// input0 and input1 values at the same position.
BOOST_FIXTURE_TEST_CASE(ParseMinimum4D, MinimumFixture4D)
{
    RunTest<4>({ { "input0", { 0.0f,  1.0f,  2.0f,
                               3.0f,  4.0f,  5.0f,
                               6.0f,  7.0f,  8.0f,
                               9.0f, 10.0f, 11.0f } },
                 { "input1", { 0.0f, 0.0f, 0.0f,
                               5.0f, 5.0f, 5.0f,
                               7.0f, 7.0f, 7.0f,
                               9.0f, 9.0f, 9.0f } } },
               { { "output", { 0.0f, 0.0f, 0.0f,
                               3.0f, 4.0f, 5.0f,
                               6.0f, 7.0f, 7.0f,
                               9.0f, 9.0f, 9.0f } } });
}

109 struct MinimumBroadcastFixture4D : public MinimumFixtureAutoSetup
110 {
MinimumBroadcastFixture4DMinimumBroadcastFixture4D111     MinimumBroadcastFixture4D()
112     : MinimumFixtureAutoSetup({1, 1, 2, 1}, {1, 2, 1, 3}) {}
113 };
114 
// Expected output[0][i][j][k] = min(input0[0][0][j][0], input1[0][i][0][k]).
// For example, the first row is min(2, {1, 2, 3}) = {1, 2, 2}.
BOOST_FIXTURE_TEST_CASE(ParseMinimumBroadcast4D, MinimumBroadcastFixture4D)
{
    RunTest<4>({ { "input0", { 2.0f,
                               4.0f } },
                 { "input1", { 1.0f, 2.0f, 3.0f,
                               4.0f, 5.0f, 6.0f } } },
               { { "output", { 1.0f, 2.0f, 2.0f,
                               1.0f, 2.0f, 3.0f,
                               2.0f, 2.0f, 2.0f,
                               4.0f, 4.0f, 4.0f } } });
}

127 struct MinimumBroadcastFixture4D1D : public MinimumFixtureAutoSetup
128 {
MinimumBroadcastFixture4D1DMinimumBroadcastFixture4D1D129     MinimumBroadcastFixture4D1D()
130     : MinimumFixtureAutoSetup({1, 2, 2, 3}, {1}) {}
131 };
132 
// The single input1 value (5) is compared with every input0 element, so the
// expected output is input0 capped at 5.
BOOST_FIXTURE_TEST_CASE(ParseMinimumBroadcast4D1D, MinimumBroadcastFixture4D1D)
{
    RunTest<4>({ { "input0", { 0.0f,  1.0f,  2.0f,
                               3.0f,  4.0f,  5.0f,
                               6.0f,  7.0f,  8.0f,
                               9.0f, 10.0f, 11.0f } },
                 { "input1", { 5.0f } } },
               { { "output", { 0.0f, 1.0f, 2.0f,
                               3.0f, 4.0f, 5.0f,
                               5.0f, 5.0f, 5.0f,
                               5.0f, 5.0f, 5.0f } } });
}

146 struct MinimumBroadcastFixture1D4D : public MinimumFixtureAutoSetup
147 {
MinimumBroadcastFixture1D4DMinimumBroadcastFixture1D4D148     MinimumBroadcastFixture1D4D()
149     : MinimumFixtureAutoSetup({3}, {1, 2, 2, 3}) {}
150 };
151 
// input0 {5, 6, 7} is compared element-wise against each row of three values
// in input1. Rows already below it pass through; larger rows become {5, 6, 7}.
BOOST_FIXTURE_TEST_CASE(ParseMinimumBroadcast1D4D, MinimumBroadcastFixture1D4D)
{
    RunTest<4>({ { "input0", { 5.0f,  6.0f,  7.0f } },
                 { "input1", { 0.0f,  1.0f,  2.0f,
                               3.0f,  4.0f,  5.0f,
                               6.0f,  7.0f,  8.0f,
                               9.0f, 10.0f, 11.0f } } },
               { { "output", { 0.0f, 1.0f, 2.0f,
                               3.0f, 4.0f, 5.0f,
                               5.0f, 6.0f, 7.0f,
                               5.0f, 6.0f, 7.0f } } });
}

165 BOOST_AUTO_TEST_SUITE_END()
166