• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "armnnTfParser/ITfParser.hpp"
7 
8 #include "ParserPrototxtFixture.hpp"
9 #include <PrototxtConversions.hpp>
10 
11 #include <boost/test/unit_test.hpp>
12 
13 BOOST_AUTO_TEST_SUITE(TensorflowParser)
14 
15 namespace {
16 // helper for setting the dimensions in prototxt
dimsHelper(const std::vector<int> & dims,std::string & text)17 void dimsHelper(const std::vector<int>& dims, std::string& text){
18     for(unsigned int i = 0; i < dims.size(); ++i) {
19         text.append(R"(dim {
20       size: )");
21         text.append(std::to_string(dims[i]));
22         text.append(R"(
23     })");
24     }
25 }
26 
27 // helper for converting from integer to octal representation
octalHelper(const std::vector<int> & indicesContent,std::string & text)28 void octalHelper(const std::vector<int>& indicesContent, std::string& text){
29     for(unsigned int i = 0; i < indicesContent.size(); ++i) {
30         text.append(armnnUtils::ConvertInt32ToOctalString(static_cast<int>(indicesContent[i])));
31     }
32 }
33 } // namespace
34 
35 struct GatherFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
36 {
GatherFixtureGatherFixture37     GatherFixture(const armnn::TensorShape& inputShape0,
38                   const armnn::TensorShape& inputShape1,
39                   const std::vector<int>& input1Content,
40                   const std::vector<int>& input0Dims,
41                   const std::vector<int>& input1Dims,
42                   int axis = 0)
43     {
44         m_Prototext = R"(
45 node {
46   name: "input0"
47   op: "Placeholder"
48   attr {
49     key: "dtype"
50     value {
51       type: DT_FLOAT
52     }
53   }
54   attr {
55     key: "shape"
56     value {
57       shape {
58 )";
59         dimsHelper(input0Dims, m_Prototext);
60 
61         m_Prototext.append(R"(
62       }
63     }
64   }
65 }
66 node {
67   name: "input1"
68   op: "Const"
69   attr {
70     key: "dtype"
71     value {
72       type: DT_INT32
73     }
74   }
75   attr {
76     key: "value"
77     value {
78      tensor {
79       dtype: DT_INT32
80         tensor_shape {
81 )");
82         dimsHelper(input1Dims, m_Prototext);
83 
84         m_Prototext.append(R"(
85         }
86         tensor_content: ")");
87         octalHelper(input1Content, m_Prototext);
88         m_Prototext.append(R"("
89       }
90     }
91   }
92 }
93 node {
94   name: "output"
95   op: "Gather"
96   input: "input0"
97   input: "input1"
98   attr {
99     key: "Tindices"
100     value {
101       type: DT_INT32
102     }
103   }
104   attr {
105     key: "Tparams"
106     value {
107       type: DT_FLOAT
108     }
109   }
110   attr {
111     key: "axis"
112     value {
113       i:  )");
114         m_Prototext += std::to_string(axis);
115 
116         m_Prototext.append(R"(
117     }
118   }
119 }
120         )");
121 
122         Setup({ { "input0", inputShape0 },
123                 { "input1", inputShape1 } },
124               { "output" });
125 
126     }
127 };
128 
129 
130 struct GatherFixture1DParams1DIndices : public GatherFixture
131 {
GatherFixture1DParams1DIndicesGatherFixture1DParams1DIndices132     GatherFixture1DParams1DIndices() : GatherFixture(
133             { 4, 1, 1, 1 },
134             { 4, 0, 0, 0 },
135             { 0, 2, 1, 3 },
136             { 4 },
137             { 4 },
138             0) {}
139 };
140 
141 struct GatherFixture1DParamsMultiDimIndices : public GatherFixture
142 {
GatherFixture1DParamsMultiDimIndicesGatherFixture1DParamsMultiDimIndices143     GatherFixture1DParamsMultiDimIndices() : GatherFixture(
144             { 4, 1, 1 },
145             { 2, 2, 1, 1 },
146             { 0, 1, 1, 3 },
147             { 4 },
148             { 2, 2 },
149             0) {}
150 };
151 
152 struct GatherFixtureMultiDimParamMultiDimIndices : public GatherFixture
153 {
GatherFixtureMultiDimParamMultiDimIndicesGatherFixtureMultiDimParamMultiDimIndices154     GatherFixtureMultiDimParamMultiDimIndices() : GatherFixture(
155             { 5, 2, 1 },
156             { 2, 1, 4 },
157             { 1, 3, 0, 2 },
158             { 5, 2 },
159             { 2, 2 },
160             0) {}
161 };
162 
// Indices {0, 2, 1, 3} pick params elements 0, 2, 1 and 3 from {1, 2, 3, 4},
// which gives {1, 3, 2, 4}.
BOOST_FIXTURE_TEST_CASE(ParseGather1DParams1DIndices, GatherFixture1DParams1DIndices)
{
    RunTest<4>({ { "input0", { 1, 2, 3, 4 } } },
               { { "output", { 1, 3, 2, 4 } } });
}
169 
// Indices {0, 1, 1, 3} (laid out 2x2) pick params elements 0, 1, 1 and 3 from
// {1, 2, 3, 4}, which gives {1, 2, 2, 4}.
BOOST_FIXTURE_TEST_CASE(ParseGather1DParamsMultiDimIndices, GatherFixture1DParamsMultiDimIndices)
{
    RunTest<4>({ { "input0", { 1, 2, 3, 4 } } },
               { { "output", { 1, 2, 2, 4 } } });
}
176 
// The params are 5 rows of 2 values: {1,2}, {3,4}, {5,6}, {7,8}, {9,10}.
// Gathering rows 1, 3, 0 and 2 along axis 0 gives {3,4, 7,8, 1,2, 5,6}.
BOOST_FIXTURE_TEST_CASE(ParseGatherMultiDimParamMultiDimIndices, GatherFixtureMultiDimParamMultiDimIndices)
{
    RunTest<4>({ { "input0", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 } } },
               { { "output", { 3, 4, 7, 8, 1, 2, 5, 6 } } });
}
183 
184 BOOST_AUTO_TEST_SUITE_END()
185