• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2017-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_TEST_POOLING_LAYER_FIXTURE
25 #define ARM_COMPUTE_TEST_POOLING_LAYER_FIXTURE
26 
27 #include "arm_compute/core/TensorShape.h"
28 #include "arm_compute/core/Types.h"
29 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
30 #include "arm_compute/runtime/Tensor.h"
31 #include "tests/AssetsLibrary.h"
32 #include "tests/Globals.h"
33 #include "tests/IAccessor.h"
34 #include "tests/framework/Asserts.h"
35 #include "tests/framework/Fixture.h"
36 #include "tests/validation/reference/PoolingLayer.h"
37 #include <random>
38 namespace arm_compute
39 {
40 namespace test
41 {
42 namespace validation
43 {
44 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
45 class PoolingLayerValidationGenericFixture : public framework::Fixture
46 {
47 public:
48     template <typename...>
49     void setup(TensorShape shape, PoolingLayerInfo pool_info, DataType data_type, DataLayout data_layout, bool indices = false)
50     {
51         std::mt19937                    gen(library->seed());
52         std::uniform_int_distribution<> offset_dis(0, 20);
53         const float                     scale     = data_type == DataType::QASYMM8_SIGNED ? 1.f / 127.f : 1.f / 255.f;
54         const int                       scale_in  = data_type == DataType::QASYMM8_SIGNED ? -offset_dis(gen) : offset_dis(gen);
55         const int                       scale_out = data_type == DataType::QASYMM8_SIGNED ? -offset_dis(gen) : offset_dis(gen);
56         const QuantizationInfo          input_qinfo(scale, scale_in);
57         const QuantizationInfo          output_qinfo(scale, scale_out);
58 
59         _pool_info = pool_info;
60         _target    = compute_target(shape, pool_info, data_type, data_layout, input_qinfo, output_qinfo, indices);
61         _reference = compute_reference(shape, pool_info, data_type, data_layout, input_qinfo, output_qinfo, indices);
62     }
63 
64 protected:
65     template <typename U>
fill(U && tensor)66     void fill(U &&tensor)
67     {
68         if(!is_data_type_quantized(tensor.data_type()))
69         {
70             std::uniform_real_distribution<> distribution(-1.f, 1.f);
71             library->fill(tensor, distribution, 0);
72         }
73         else // data type is quantized_asymmetric
74         {
75             library->fill_tensor_uniform(tensor, 0);
76         }
77     }
78 
compute_target(TensorShape shape,PoolingLayerInfo info,DataType data_type,DataLayout data_layout,QuantizationInfo input_qinfo,QuantizationInfo output_qinfo,bool indices)79     TensorType compute_target(TensorShape shape, PoolingLayerInfo info,
80                               DataType data_type, DataLayout data_layout,
81                               QuantizationInfo input_qinfo, QuantizationInfo output_qinfo,
82                               bool indices)
83     {
84         // Change shape in case of NHWC.
85         if(data_layout == DataLayout::NHWC)
86         {
87             permute(shape, PermutationVector(2U, 0U, 1U));
88         }
89         // Create tensors
90         TensorType        src       = create_tensor<TensorType>(shape, data_type, 1, input_qinfo, data_layout);
91         const TensorShape dst_shape = misc::shape_calculator::compute_pool_shape(*(src.info()), info);
92         TensorType        dst       = create_tensor<TensorType>(dst_shape, data_type, 1, output_qinfo, data_layout);
93         _target_indices             = create_tensor<TensorType>(dst_shape, DataType::U32, 1, output_qinfo, data_layout);
94 
95         // Create and configure function
96         FunctionType pool_layer;
97         pool_layer.configure(&src, &dst, info, (indices) ? &_target_indices : nullptr);
98 
99         ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
100         ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
101         ARM_COMPUTE_EXPECT(_target_indices.info()->is_resizable(), framework::LogLevel::ERRORS);
102 
103         // Allocate tensors
104         src.allocator()->allocate();
105         dst.allocator()->allocate();
106         _target_indices.allocator()->allocate();
107 
108         ARM_COMPUTE_EXPECT(!src.info()->is_resizable(), framework::LogLevel::ERRORS);
109         ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
110         ARM_COMPUTE_EXPECT(!_target_indices.info()->is_resizable(), framework::LogLevel::ERRORS);
111 
112         // Fill tensors
113         fill(AccessorType(src));
114 
115         // Compute function
116         pool_layer.run();
117 
118         return dst;
119     }
120 
compute_reference(TensorShape shape,PoolingLayerInfo info,DataType data_type,DataLayout data_layout,QuantizationInfo input_qinfo,QuantizationInfo output_qinfo,bool indices)121     SimpleTensor<T> compute_reference(TensorShape shape, PoolingLayerInfo info, DataType data_type, DataLayout data_layout,
122                                       QuantizationInfo input_qinfo, QuantizationInfo output_qinfo, bool indices)
123     {
124         // Create reference
125         SimpleTensor<T> src(shape, data_type, 1, input_qinfo);
126         // Fill reference
127         fill(src);
128         return reference::pooling_layer<T>(src, info, output_qinfo, indices ? &_ref_indices : nullptr, data_layout);
129     }
130 
131     TensorType             _target{};
132     SimpleTensor<T>        _reference{};
133     PoolingLayerInfo       _pool_info{};
134     TensorType             _target_indices{};
135     SimpleTensor<uint32_t> _ref_indices{};
136 };
137 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
138 class PoolingLayerIndicesValidationFixture : public PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
139 {
140 public:
141     template <typename...>
setup(TensorShape shape,PoolingType pool_type,Size2D pool_size,PadStrideInfo pad_stride_info,bool exclude_padding,DataType data_type,DataLayout data_layout)142     void setup(TensorShape shape, PoolingType pool_type, Size2D pool_size, PadStrideInfo pad_stride_info, bool exclude_padding, DataType data_type, DataLayout data_layout)
143     {
144         PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type, pool_size, data_layout, pad_stride_info, exclude_padding),
145                                                                                                data_type, data_layout, true);
146     }
147 };
148 
149 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
150 class PoolingLayerValidationFixture : public PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
151 {
152 public:
153     template <typename...>
setup(TensorShape shape,PoolingType pool_type,Size2D pool_size,PadStrideInfo pad_stride_info,bool exclude_padding,DataType data_type,DataLayout data_layout)154     void setup(TensorShape shape, PoolingType pool_type, Size2D pool_size, PadStrideInfo pad_stride_info, bool exclude_padding, DataType data_type, DataLayout data_layout)
155     {
156         PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type, pool_size, data_layout, pad_stride_info, exclude_padding),
157                                                                                                data_type, data_layout);
158     }
159 };
160 
161 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
162 class PoolingLayerValidationMixedPrecisionFixture : public PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
163 {
164 public:
165     template <typename...>
166     void setup(TensorShape shape, PoolingType pool_type, Size2D pool_size, PadStrideInfo pad_stride_info, bool exclude_padding, DataType data_type, DataLayout data_layout, bool fp_mixed_precision = false)
167     {
168         PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type, pool_size, data_layout, pad_stride_info, exclude_padding, fp_mixed_precision),
169                                                                                                data_type, data_layout);
170     }
171 };
172 
173 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
174 class PoolingLayerValidationQuantizedFixture : public PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
175 {
176 public:
177     template <typename...>
178     void setup(TensorShape shape, PoolingType pool_type, Size2D pool_size, PadStrideInfo pad_stride_info, bool exclude_padding, DataType data_type, DataLayout data_layout = DataLayout::NCHW)
179     {
180         PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type, pool_size, data_layout, pad_stride_info, exclude_padding),
181                                                                                                data_type, data_layout);
182     }
183 };
184 
185 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
186 class SpecialPoolingLayerValidationFixture : public PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
187 {
188 public:
189     template <typename...>
setup(TensorShape src_shape,PoolingLayerInfo pool_info,DataType data_type)190     void setup(TensorShape src_shape, PoolingLayerInfo pool_info, DataType data_type)
191     {
192         PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, pool_info, data_type, DataLayout::NCHW);
193     }
194 };
195 
196 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
197 class GlobalPoolingLayerValidationFixture : public PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
198 {
199 public:
200     template <typename...>
201     void setup(TensorShape shape, PoolingType pool_type, DataType data_type, DataLayout data_layout = DataLayout::NCHW)
202     {
203         PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type, data_layout), data_type, data_layout);
204     }
205 };
206 
207 } // namespace validation
208 } // namespace test
209 } // namespace arm_compute
210 #endif /* ARM_COMPUTE_TEST_POOLING_LAYER_FIXTURE */
211