1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "TestNameAndDescriptorLayerVisitor.hpp"
6 #include "Network.hpp"
7
8 #include <armnn/Exceptions.hpp>
9
10 namespace
11 {
12
// Expands to the layer name used by the named-layer test case, e.g. "ActivationLayer" for
// Activation. Macro parameters are never substituted (or pasted with ##) inside a string literal,
// so the name must be built with the stringizing operator (#) plus adjacent-literal concatenation.
#define TEST_LAYER_VISITOR_LAYER_NAME(name) #name "Layer"

// Defines a Boost test case named Check<name>LayerVisitorNameAndDescriptor. The test:
//   1. obtains a fixed descriptor from GetDescriptor<armnn::<name>Descriptor>(),
//   2. builds a Test<name>LayerVisitor from that descriptor and the layer name "<name>Layer",
//   3. adds a <name> layer with the same descriptor and name to an armnn::Network, and
//   4. visits that layer.
// The test case makes no assertions of its own. Whatever verification occurs is done inside the
// visitor, which is declared elsewhere (presumably TestNameAndDescriptorLayerVisitor.hpp).
#define TEST_CASE_CHECK_LAYER_VISITOR_NAME_AND_DESCRIPTOR(name) \
BOOST_AUTO_TEST_CASE(Check##name##LayerVisitorNameAndDescriptor) \
{ \
    const char* layerName = TEST_LAYER_VISITOR_LAYER_NAME(name); \
    armnn::name##Descriptor descriptor = GetDescriptor<armnn::name##Descriptor>(); \
    Test##name##LayerVisitor visitor(descriptor, layerName); \
    armnn::Network net; \
    armnn::IConnectableLayer *const layer = net.Add##name##Layer(descriptor, layerName); \
    layer->Accept(visitor); \
}
23
// Defines a Boost test case named Check<name>LayerVisitorNameNullptrAndDescriptor. It mirrors the
// named-layer case, but omits the layer-name argument both when constructing the
// Test<name>LayerVisitor and when calling armnn::Network::Add<name>Layer. Both calls therefore
// rely on their default name argument (presumably nullptr, as the test name suggests; confirm
// against the visitor header and Network API). As with the named case, any checking is performed
// by the visitor, not by this test body.
#define TEST_CASE_CHECK_LAYER_VISITOR_NAME_NULLPTR_AND_DESCRIPTOR(name) \
BOOST_AUTO_TEST_CASE(Check##name##LayerVisitorNameNullptrAndDescriptor) \
{ \
    armnn::name##Descriptor layerDescriptor = GetDescriptor<armnn::name##Descriptor>(); \
    Test##name##LayerVisitor unnamedVisitor(layerDescriptor); \
    armnn::Network network; \
    network.Add##name##Layer(layerDescriptor)->Accept(unnamedVisitor); \
}
33
// Registers the pair of visitor test cases for one layer type: the explicitly-named variant
// followed by the defaulted-name variant. Despite the macro's name, it expands to two
// BOOST_AUTO_TEST_CASEs, not to a Boost test suite.
#define TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(name)       \
    TEST_CASE_CHECK_LAYER_VISITOR_NAME_AND_DESCRIPTOR(name)      \
    TEST_CASE_CHECK_LAYER_VISITOR_NAME_NULLPTR_AND_DESCRIPTOR(name)
37
// Primary template: declared but intentionally never defined. Explicit specializations in this
// file provide the fixture descriptor for each layer type under test. Requesting a type with no
// specialization fails at link time.
template <typename DescriptorT>
DescriptorT GetDescriptor();
39
40 template<>
GetDescriptor()41 armnn::ActivationDescriptor GetDescriptor<armnn::ActivationDescriptor>()
42 {
43 armnn::ActivationDescriptor descriptor;
44 descriptor.m_Function = armnn::ActivationFunction::Linear;
45 descriptor.m_A = 2.0f;
46 descriptor.m_B = 2.0f;
47
48 return descriptor;
49 }
50
51 template<>
GetDescriptor()52 armnn::ArgMinMaxDescriptor GetDescriptor<armnn::ArgMinMaxDescriptor>()
53 {
54 armnn::ArgMinMaxDescriptor descriptor;
55 descriptor.m_Function = armnn::ArgMinMaxFunction::Max;
56 descriptor.m_Axis = 1;
57
58 return descriptor;
59 }
60
61 template<>
GetDescriptor()62 armnn::BatchToSpaceNdDescriptor GetDescriptor<armnn::BatchToSpaceNdDescriptor>()
63 {
64 return armnn::BatchToSpaceNdDescriptor({ 1, 1 }, {{ 0, 0 }, { 0, 0 }});
65 }
66
67 template<>
GetDescriptor()68 armnn::ComparisonDescriptor GetDescriptor<armnn::ComparisonDescriptor>()
69 {
70 return armnn::ComparisonDescriptor(armnn::ComparisonOperation::GreaterOrEqual);
71 }
72
73 template<>
GetDescriptor()74 armnn::ConcatDescriptor GetDescriptor<armnn::ConcatDescriptor>()
75 {
76 armnn::ConcatDescriptor descriptor(2, 2);
77 for (unsigned int i = 0u; i < descriptor.GetNumViews(); ++i)
78 {
79 for (unsigned int j = 0u; j < descriptor.GetNumDimensions(); ++j)
80 {
81 descriptor.SetViewOriginCoord(i, j, i);
82 }
83 }
84
85 return descriptor;
86 }
87
88 template<>
GetDescriptor()89 armnn::ElementwiseUnaryDescriptor GetDescriptor<armnn::ElementwiseUnaryDescriptor>()
90 {
91 return armnn::ElementwiseUnaryDescriptor(armnn::UnaryOperation::Abs);
92 }
93
94 template<>
GetDescriptor()95 armnn::FillDescriptor GetDescriptor<armnn::FillDescriptor>()
96 {
97 return armnn::FillDescriptor(1);
98 }
99
100 template<>
GetDescriptor()101 armnn::GatherDescriptor GetDescriptor<armnn::GatherDescriptor>()
102 {
103 return armnn::GatherDescriptor();
104 }
105
106 template<>
GetDescriptor()107 armnn::InstanceNormalizationDescriptor GetDescriptor<armnn::InstanceNormalizationDescriptor>()
108 {
109 armnn::InstanceNormalizationDescriptor descriptor;
110 descriptor.m_Gamma = 1.0f;
111 descriptor.m_Beta = 2.0f;
112 descriptor.m_Eps = 0.0001f;
113 descriptor.m_DataLayout = armnn::DataLayout::NHWC;
114
115 return descriptor;
116 }
117
118 template<>
GetDescriptor()119 armnn::L2NormalizationDescriptor GetDescriptor<armnn::L2NormalizationDescriptor>()
120 {
121 armnn::L2NormalizationDescriptor descriptor;
122 descriptor.m_Eps = 0.0001f;
123 descriptor.m_DataLayout = armnn::DataLayout::NHWC;
124
125 return descriptor;
126 }
127
128 template<>
GetDescriptor()129 armnn::LogicalBinaryDescriptor GetDescriptor<armnn::LogicalBinaryDescriptor>()
130 {
131 return armnn::LogicalBinaryDescriptor(armnn::LogicalBinaryOperation::LogicalOr);
132 }
133
134 template<>
GetDescriptor()135 armnn::MeanDescriptor GetDescriptor<armnn::MeanDescriptor>()
136 {
137 return armnn::MeanDescriptor({ 1, 2, }, true);
138 }
139
140 template<>
GetDescriptor()141 armnn::NormalizationDescriptor GetDescriptor<armnn::NormalizationDescriptor>()
142 {
143 armnn::NormalizationDescriptor descriptor;
144 descriptor.m_NormChannelType = armnn::NormalizationAlgorithmChannel::Within;
145 descriptor.m_NormMethodType = armnn::NormalizationAlgorithmMethod::LocalContrast;
146 descriptor.m_NormSize = 1u;
147 descriptor.m_Alpha = 1.0f;
148 descriptor.m_Beta = 1.0f;
149 descriptor.m_K = 1.0f;
150 descriptor.m_DataLayout = armnn::DataLayout::NHWC;
151
152 return descriptor;
153 }
154
155 template<>
GetDescriptor()156 armnn::PadDescriptor GetDescriptor<armnn::PadDescriptor>()
157 {
158 return armnn::PadDescriptor({{ 1, 2 }, { 3, 4 }});
159 }
160
161 template<>
GetDescriptor()162 armnn::PermuteDescriptor GetDescriptor<armnn::PermuteDescriptor>()
163 {
164 return armnn::PermuteDescriptor({ 0, 1, 2, 3 });
165 }
166
167 template<>
GetDescriptor()168 armnn::Pooling2dDescriptor GetDescriptor<armnn::Pooling2dDescriptor>()
169 {
170 armnn::Pooling2dDescriptor descriptor;
171 descriptor.m_PoolType = armnn::PoolingAlgorithm::Max;
172 descriptor.m_PadLeft = 1u;
173 descriptor.m_PadRight = 1u;
174 descriptor.m_PadTop = 1u;
175 descriptor.m_PadBottom = 1u;
176 descriptor.m_PoolWidth = 1u;
177 descriptor.m_PoolHeight = 1u;
178 descriptor.m_StrideX = 1u;
179 descriptor.m_StrideY = 1u;
180 descriptor.m_OutputShapeRounding = armnn::OutputShapeRounding::Ceiling;
181 descriptor.m_PaddingMethod = armnn::PaddingMethod::IgnoreValue;
182 descriptor.m_DataLayout = armnn::DataLayout::NHWC;
183
184 return descriptor;
185 }
186
187 template<>
GetDescriptor()188 armnn::ReshapeDescriptor GetDescriptor<armnn::ReshapeDescriptor>()
189 {
190 return armnn::ReshapeDescriptor({ 1, 2, 3, 4 });
191 }
192
193 template<>
GetDescriptor()194 armnn::ResizeDescriptor GetDescriptor<armnn::ResizeDescriptor>()
195 {
196 armnn::ResizeDescriptor descriptor;
197 descriptor.m_TargetHeight = 1u;
198 descriptor.m_TargetWidth = 1u;
199 descriptor.m_DataLayout = armnn::DataLayout::NHWC;
200
201 return descriptor;
202 }
203
204 template<>
GetDescriptor()205 armnn::SliceDescriptor GetDescriptor<armnn::SliceDescriptor>()
206 {
207 return armnn::SliceDescriptor({ 1, 1 }, { 2, 2 });
208 }
209
210 template<>
GetDescriptor()211 armnn::SoftmaxDescriptor GetDescriptor<armnn::SoftmaxDescriptor>()
212 {
213 armnn::SoftmaxDescriptor descriptor;
214 descriptor.m_Beta = 2.0f;
215 descriptor.m_Axis = -1;
216
217 return descriptor;
218 }
219
220 template<>
GetDescriptor()221 armnn::SpaceToBatchNdDescriptor GetDescriptor<armnn::SpaceToBatchNdDescriptor>()
222 {
223 return armnn::SpaceToBatchNdDescriptor({ 2, 2 }, {{ 1, 1 } , { 1, 1 }});
224 }
225
226 template<>
GetDescriptor()227 armnn::SpaceToDepthDescriptor GetDescriptor<armnn::SpaceToDepthDescriptor>()
228 {
229 return armnn::SpaceToDepthDescriptor(2, armnn::DataLayout::NHWC);
230 }
231
232 template<>
GetDescriptor()233 armnn::SplitterDescriptor GetDescriptor<armnn::SplitterDescriptor>()
234 {
235 armnn::SplitterDescriptor descriptor(2, 2);
236 for (unsigned int i = 0u; i < descriptor.GetNumViews(); ++i)
237 {
238 for (unsigned int j = 0u; j < descriptor.GetNumDimensions(); ++j)
239 {
240 descriptor.SetViewOriginCoord(i, j, i);
241 descriptor.SetViewSize(i, j, 1);
242 }
243 }
244
245 return descriptor;
246 }
247
248 template<>
GetDescriptor()249 armnn::StackDescriptor GetDescriptor<armnn::StackDescriptor>()
250 {
251 return armnn::StackDescriptor(1, 2, { 2, 2 });
252 }
253
254 template<>
GetDescriptor()255 armnn::StridedSliceDescriptor GetDescriptor<armnn::StridedSliceDescriptor>()
256 {
257 armnn::StridedSliceDescriptor descriptor({ 1, 2 }, { 3, 4 }, { 3, 4 });
258 descriptor.m_BeginMask = 1;
259 descriptor.m_EndMask = 1;
260 descriptor.m_ShrinkAxisMask = 1;
261 descriptor.m_EllipsisMask = 1;
262 descriptor.m_NewAxisMask = 1;
263 descriptor.m_DataLayout = armnn::DataLayout::NHWC;
264
265 return descriptor;
266 }
267
268 template<>
GetDescriptor()269 armnn::TransposeDescriptor GetDescriptor<armnn::TransposeDescriptor>()
270 {
271 return armnn::TransposeDescriptor({ 0, 1, 2, 3 });
272 }
273
274 } // anonymous namespace
275
276 BOOST_AUTO_TEST_SUITE(TestNameAndDescriptorLayerVisitor)
277
278 TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Activation)
279 TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(ArgMinMax)
280 TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(DepthToSpace)
281 TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(BatchToSpaceNd)
282 TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Comparison)
283 TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Concat)
284 TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(ElementwiseUnary)
285 TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Fill)
286 TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Gather)
287 TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(InstanceNormalization)
288 TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(L2Normalization)
289 TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(LogicalBinary)
290 TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(LogSoftmax)
291 TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Mean)
292 TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Normalization)
293 TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Pad)
294 TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Permute)
295 TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Pooling2d)
296 TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Reshape)
297 TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Resize)
298 TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Slice)
299 TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Softmax)
300 TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(SpaceToBatchNd)
301 TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(SpaceToDepth)
302 TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Splitter)
303 TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Stack)
304 TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(StridedSlice)
305 TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Transpose)
306
307 BOOST_AUTO_TEST_SUITE_END()
308