• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
7 #include "armnnTfParser/ITfParser.hpp"
8 #include "ParserPrototxtFixture.hpp"
9 
// Opens the Boost.Test suite "TensorflowParser"; every fixture and test case up to the
// matching BOOST_AUTO_TEST_SUITE_END() belongs to it.
BOOST_AUTO_TEST_SUITE(TensorflowParser)
11 
12     struct EqualFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
13     {
EqualFixtureEqualFixture14         EqualFixture()
15         {
16             m_Prototext = R"(
17 node {
18   name: "input0"
19   op: "Placeholder"
20   attr {
21     key: "dtype"
22     value {
23       type: DT_FLOAT
24     }
25   }
26   attr {
27     key: "shape"
28     value {
29       shape {
30       }
31     }
32   }
33 }
34 node {
35   name: "input1"
36   op: "Placeholder"
37   attr {
38     key: "dtype"
39     value {
40       type: DT_FLOAT
41     }
42   }
43   attr {
44     key: "shape"
45     value {
46       shape {
47       }
48     }
49   }
50 }
51 node {
52   name: "output"
53   op: "Equal"
54   input: "input0"
55   input: "input1"
56   attr {
57     key: "T"
58     value {
59       type: DT_FLOAT
60     }
61   }
62 }
63         )";
64         }
65     };
66 
// Negative test: calling Setup() with input shapes {2, 3} and {1, 2, 2, 3} is required
// to throw armnn::ParseException.
BOOST_FIXTURE_TEST_CASE(ParseEqualUnsupportedBroadcast, EqualFixture)
{
    BOOST_REQUIRE_THROW(
        Setup(
            {
                { "input0", {2, 3} },
                { "input1", {1, 2, 2, 3} }
            },
            { "output" }),
        armnn::ParseException);
}
74 
75 struct EqualFixtureAutoSetup : public EqualFixture
76 {
EqualFixtureAutoSetupEqualFixtureAutoSetup77     EqualFixtureAutoSetup(const armnn::TensorShape& input0Shape,
78                           const armnn::TensorShape& input1Shape)
79                 : EqualFixture()
80     {
81          Setup({ { "input0", input0Shape },
82                  { "input1", input1Shape } },
83                { "output" });
84     }
85 };
86 
87 struct EqualTwoByTwo : public EqualFixtureAutoSetup
88 {
EqualTwoByTwoEqualTwoByTwo89     EqualTwoByTwo() : EqualFixtureAutoSetup({2,2}, {2,2}) {}
90 };
91 
// Element-wise comparison of two 2x2 inputs. In the expected output a 1 marks a
// position where both inputs hold the same value, a 0 where they differ.
BOOST_FIXTURE_TEST_CASE(ParseEqualTwoByTwo, EqualTwoByTwo)
{
    RunComparisonTest<2>(
        {
            { "input0", { 1.0f, 2.0f,
                          3.0f, 2.0f } },
            { "input1", { 1.0f, 5.0f,
                          2.0f, 2.0f } }
        },
        {
            { "output", { 1, 0,
                          0, 1 } }
        });
}
98 
99 struct EqualBroadcast1DAnd4D : public EqualFixtureAutoSetup
100 {
EqualBroadcast1DAnd4DEqualBroadcast1DAnd4D101     EqualBroadcast1DAnd4D() : EqualFixtureAutoSetup({1}, {1,1,2,2}) {}
102 };
103 
// The single value of "input0" (2.0) is compared against each of the four values of
// "input1"; the expected output is 1 exactly where "input1" equals 2.0.
BOOST_FIXTURE_TEST_CASE(ParseEqualBroadcast1DToTwoByTwo, EqualBroadcast1DAnd4D)
{
    RunComparisonTest<4>(
        {
            { "input0", { 2.0f } },
            { "input1", { 1.0f, 2.0f,
                          3.0f, 2.0f } }
        },
        {
            { "output", { 0, 1,
                          0, 1 } }
        });
}
110 
111 struct EqualBroadcast4DAnd1D : public EqualFixtureAutoSetup
112 {
EqualBroadcast4DAnd1DEqualBroadcast4DAnd1D113     EqualBroadcast4DAnd1D() : EqualFixtureAutoSetup({1,1,2,2}, {1}) {}
114 };
115 
// Each of the four values of "input0" is compared against the single value of
// "input1" (3.0); only the third element matches.
BOOST_FIXTURE_TEST_CASE(ParseEqualBroadcast4DAnd1D, EqualBroadcast4DAnd1D)
{
    RunComparisonTest<4>(
        {
            { "input0", { 1.0f, 2.0f,
                          3.0f, 2.0f } },
            { "input1", { 3.0f } }
        },
        {
            { "output", { 0, 0,
                          1, 0 } }
        });
}
122 
123 struct EqualMultiDimBroadcast : public EqualFixtureAutoSetup
124 {
EqualMultiDimBroadcastEqualMultiDimBroadcast125     EqualMultiDimBroadcast() : EqualFixtureAutoSetup({1,1,2,1}, {1,2,1,3}) {}
126 };
127 
// Two values in "input0" and two rows of three values in "input1" give twelve
// expected results. Each expected row below compares one "input0" value against
// one three-element row of "input1".
BOOST_FIXTURE_TEST_CASE(ParseEqualMultiDimBroadcast, EqualMultiDimBroadcast)
{
    RunComparisonTest<4>(
        {
            { "input0", { 1.0f, 2.0f } },
            { "input1", { 1.0f, 2.0f, 3.0f,
                          3.0f, 2.0f, 2.0f } }
        },
        {
            { "output", { 1, 0, 0,      // 1.0 vs { 1, 2, 3 }
                          0, 1, 0,      // 2.0 vs { 1, 2, 3 }
                          0, 0, 0,      // 1.0 vs { 3, 2, 2 }
                          0, 1, 1 } }   // 2.0 vs { 3, 2, 2 }
        });
}
138 
// Closes the "TensorflowParser" test suite opened by BOOST_AUTO_TEST_SUITE above.
BOOST_AUTO_TEST_SUITE_END()
140