• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "armnnTfParser/ITfParser.hpp"
7 
8 #include <ParserPrototxtFixture.hpp>
9 #include <PrototxtConversions.hpp>
10 
11 #include <boost/test/unit_test.hpp>
12 
13 BOOST_AUTO_TEST_SUITE(TensorflowParser)
14 
15 struct MeanFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
16 {
MeanFixtureMeanFixture17     explicit MeanFixture(const armnn::TensorShape& inputShape, const armnn::TensorShape& outputShape,
18                          const std::vector<unsigned int>& axis, bool keepDims)
19     {
20         std::string protobufAxisString;
21         std::vector<unsigned int> protobufAxis(axis);
22 
23         // If no axis range is specified, the reduction is applied to
24         // all dimensions of the input tensor
25         if (protobufAxis.size() == 0)
26         {
27             for (unsigned int i = 0; i < inputShape.GetNumDimensions(); ++i)
28             {
29                 protobufAxis.push_back(i);
30             }
31         }
32 
33         for (unsigned int i = 0; i < protobufAxis.size(); ++i)
34         {
35             protobufAxisString.append(armnnUtils::ConvertInt32ToOctalString(static_cast<int>(protobufAxis[i])));
36         }
37 
38         m_Prototext = R"(node {
39               name: "input"
40               op: "Placeholder"
41               attr {
42               key: "dtype"
43                 value {
44                   type: DT_FLOAT
45                 }
46               }
47               attr {
48                 key: "shape"
49                 value {
50                   shape {
51                   }
52                 }
53               }
54             }
55             node {
56               name: "Const"
57               op: "Const"
58               attr {
59                 key: "dtype"
60                 value {
61                   type: DT_INT32
62                 }
63               }
64               attr {
65                 key: "value"
66                 value { )";
67 
68         if (axis.size() == 1)
69         {
70             m_Prototext.append(R"(      tensor {
71                     dtype: DT_INT32
72                     tensor_shape {
73                     }
74                     int_val: )").append(std::to_string(protobufAxis[0])).append(R"(
75                   } )");
76         }
77         else
78         {
79             m_Prototext.append(R"(      tensor {
80                     dtype: DT_INT32
81                     tensor_shape {
82                       dim {
83                         size: 2
84                       }
85                     }
86                     tensor_content: ")").append(protobufAxisString).append(R"("
87                   } )");
88         }
89 
90         m_Prototext.append(R"(    }
91               }
92             }
93             node {
94               name: "output"
95               op: "Mean"
96               input: "input"
97               input: "Const"
98               attr {
99                 key: "T"
100                 value {
101                   type: DT_FLOAT
102                 }
103               }
104               attr {
105                 key: "Tidx"
106                   value {
107                     type: DT_INT32
108                 }
109              }
110              attr {
111                key: "keep_dims"
112                  value {
113                    b: )").append(keepDims ? "true" : "false").append(R"(
114                }
115              }
116             })");
117 
118         SetupSingleInputSingleOutput(inputShape, outputShape, "input", "output");
119     }
120 };
121 
122 struct MeanNoAxisNoKeepDimsFixture: MeanFixture
123 {
MeanNoAxisNoKeepDimsFixtureMeanNoAxisNoKeepDimsFixture124     MeanNoAxisNoKeepDimsFixture() : MeanFixture({ 2, 3 }, { 1 }, {}, false) {}
125 };
126 
127 struct MeanWithAxis0NoKeepDimsFixture: MeanFixture
128 {
MeanWithAxis0NoKeepDimsFixtureMeanWithAxis0NoKeepDimsFixture129     MeanWithAxis0NoKeepDimsFixture() : MeanFixture({ 2, 3 }, { 3 }, { 0 }, false) {}
130 };
131 
132 struct MeanWithAxis1NoKeepDimsFixture: MeanFixture
133 {
MeanWithAxis1NoKeepDimsFixtureMeanWithAxis1NoKeepDimsFixture134     MeanWithAxis1NoKeepDimsFixture() : MeanFixture({ 2, 3 }, { 2 }, { 1 }, false) {}
135 };
136 
137 struct MeanWithAxis0KeepDimsFixture: MeanFixture
138 {
MeanWithAxis0KeepDimsFixtureMeanWithAxis0KeepDimsFixture139     MeanWithAxis0KeepDimsFixture() : MeanFixture({ 2, 3 }, { 1, 3 }, { 0 }, true) {}
140 };
141 
142 struct MeanWithAxis1KeepDimsFixture: MeanFixture
143 {
MeanWithAxis1KeepDimsFixtureMeanWithAxis1KeepDimsFixture144     MeanWithAxis1KeepDimsFixture() : MeanFixture({ 2, 3 }, { 2, 1 }, { 1 }, true) {}
145 };
146 
147 
// Averaging all six input values: (1 + 1 + 1 + 2 + 2 + 2) / 6 = 1.5.
BOOST_FIXTURE_TEST_CASE(MeanNoAxisNoKeepDims, MeanNoAxisNoKeepDimsFixture)
{
    const std::vector<float> inputValues    = { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f };
    const std::vector<float> expectedValues = { 1.5f };

    RunTest<1>({ { "input", inputValues } }, { { "output", expectedValues } });
}
153 
// With the 2x3 input rows { 1, 1, 1 } and { 2, 2, 2 }, each column averages to 1.5.
BOOST_FIXTURE_TEST_CASE(MeanWithAxis0NoKeepDims, MeanWithAxis0NoKeepDimsFixture)
{
    const std::vector<float> inputValues    = { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f };
    const std::vector<float> expectedValues = { 1.5f, 1.5f, 1.5f };

    RunTest<1>({ { "input", inputValues } }, { { "output", expectedValues } });
}
159 
// With the 2x3 input rows { 1, 1, 1 } and { 2, 2, 2 }, the row means are 1 and 2.
BOOST_FIXTURE_TEST_CASE(MeanWithAxis1NoKeepDims, MeanWithAxis1NoKeepDimsFixture)
{
    const std::vector<float> inputValues    = { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f };
    const std::vector<float> expectedValues = { 1.f, 2.f };

    RunTest<1>({ { "input", inputValues } }, { { "output", expectedValues } });
}
165 
// Column means of the 2x3 input (all 1.5), kept as a 1x3 tensor.
BOOST_FIXTURE_TEST_CASE(MeanWithAxis0KeepDims, MeanWithAxis0KeepDimsFixture)
{
    const std::vector<float> inputValues    = { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f };
    const std::vector<float> expectedValues = { 1.5f, 1.5f, 1.5f };

    RunTest<2>({ { "input", inputValues } }, { { "output", expectedValues } });
}
171 
// Row means of the 2x3 input (1 and 2), kept as a 2x1 tensor.
BOOST_FIXTURE_TEST_CASE(MeanWithAxis1KeepDims, MeanWithAxis1KeepDimsFixture)
{
    const std::vector<float> inputValues    = { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f };
    const std::vector<float> expectedValues = { 1.f, 2.f };

    RunTest<2>({ { "input", inputValues } }, { { "output", expectedValues } });
}
177 
178 BOOST_AUTO_TEST_SUITE_END()
179