• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2019-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_TEST_FUSEBATCHNORMALIZATION_FIXTURE
25 #define ARM_COMPUTE_TEST_FUSEBATCHNORMALIZATION_FIXTURE
26 
27 #include "arm_compute/core/TensorShape.h"
28 #include "arm_compute/core/Types.h"
29 #include "tests/AssetsLibrary.h"
30 #include "tests/Globals.h"
31 #include "tests/IAccessor.h"
32 #include "tests/framework/Asserts.h"
33 #include "tests/framework/Fixture.h"
34 #include "tests/validation/Helpers.h"
35 #include "tests/validation/reference/FuseBatchNormalization.h"
36 
37 #include <tuple>
38 #include <utility>
39 
40 namespace arm_compute
41 {
42 namespace test
43 {
44 namespace validation
45 {
46 template <typename TensorType, typename AccessorType, typename FunctionType, int dims_weights, typename T>
47 class FuseBatchNormalizationFixture : public framework::Fixture
48 {
49 public:
50     template <typename...>
setup(TensorShape shape_w,DataType data_type,DataLayout data_layout,bool in_place,bool with_bias,bool with_gamma,bool with_beta)51     void setup(TensorShape shape_w, DataType data_type, DataLayout data_layout, bool in_place, bool with_bias, bool with_gamma, bool with_beta)
52     {
53         std::tie(_target_w, _target_b)       = compute_target(shape_w, data_type, data_layout, in_place, with_bias, with_gamma, with_beta);
54         std::tie(_reference_w, _reference_b) = compute_reference(shape_w, data_type, with_bias, with_gamma, with_beta);
55     }
56 
57 protected:
58     template <typename U>
fill(U && tensor,int i,float min,float max)59     void fill(U &&tensor, int i, float min, float max)
60     {
61         library->fill_tensor_uniform(tensor, i, min, max);
62     }
63 
compute_target(TensorShape shape_w,DataType data_type,DataLayout data_layout,bool in_place,bool with_bias,bool with_gamma,bool with_beta)64     std::pair<TensorType, TensorType> compute_target(TensorShape shape_w, DataType data_type, DataLayout data_layout, bool in_place, bool with_bias, bool with_gamma, bool with_beta)
65     {
66         const TensorShape shape_v(shape_w[dims_weights - 1]);
67 
68         if(data_layout == DataLayout::NHWC)
69         {
70             permute(shape_w, PermutationVector(2U, 0U, 1U));
71         }
72 
73         const bool in_place_w = in_place;
74         const bool in_place_b = with_bias ? in_place : false;
75 
76         // Create tensors
77         TensorType w       = create_tensor<TensorType>(shape_w, data_type, 1, QuantizationInfo(), data_layout);
78         TensorType b       = create_tensor<TensorType>(shape_v, data_type);
79         TensorType mean    = create_tensor<TensorType>(shape_v, data_type);
80         TensorType var     = create_tensor<TensorType>(shape_v, data_type);
81         TensorType w_fused = create_tensor<TensorType>(shape_w, data_type, 1, QuantizationInfo(), data_layout);
82         TensorType b_fused = create_tensor<TensorType>(shape_v, data_type);
83         TensorType beta    = create_tensor<TensorType>(shape_v, data_type);
84         TensorType gamma   = create_tensor<TensorType>(shape_v, data_type);
85 
86         auto b_to_use       = with_bias ? &b : nullptr;
87         auto gamma_to_use   = with_gamma ? &gamma : nullptr;
88         auto beta_to_use    = with_beta ? &beta : nullptr;
89         auto w_fused_to_use = in_place_w ? nullptr : &w_fused;
90         auto b_fused_to_use = in_place_b ? nullptr : &b_fused;
91 
92         const FuseBatchNormalizationType fuse_bn_type = dims_weights == 3 ?
93                                                         FuseBatchNormalizationType::DEPTHWISECONVOLUTION :
94                                                         FuseBatchNormalizationType::CONVOLUTION;
95         // Create and configure function
96         FunctionType fuse_batch_normalization;
97         fuse_batch_normalization.configure(&w, &mean, &var, w_fused_to_use, b_fused_to_use, b_to_use, beta_to_use, gamma_to_use, _epsilon, fuse_bn_type);
98 
99         ARM_COMPUTE_ASSERT(w.info()->is_resizable());
100         ARM_COMPUTE_ASSERT(b.info()->is_resizable());
101         ARM_COMPUTE_ASSERT(mean.info()->is_resizable());
102         ARM_COMPUTE_ASSERT(var.info()->is_resizable());
103         ARM_COMPUTE_ASSERT(w_fused.info()->is_resizable());
104         ARM_COMPUTE_ASSERT(b_fused.info()->is_resizable());
105         ARM_COMPUTE_ASSERT(beta.info()->is_resizable());
106         ARM_COMPUTE_ASSERT(gamma.info()->is_resizable());
107 
108         // Allocate tensors
109         w.allocator()->allocate();
110         b.allocator()->allocate();
111         mean.allocator()->allocate();
112         var.allocator()->allocate();
113         w_fused.allocator()->allocate();
114         b_fused.allocator()->allocate();
115         beta.allocator()->allocate();
116         gamma.allocator()->allocate();
117 
118         ARM_COMPUTE_ASSERT(!w.info()->is_resizable());
119         ARM_COMPUTE_ASSERT(!b.info()->is_resizable());
120         ARM_COMPUTE_ASSERT(!mean.info()->is_resizable());
121         ARM_COMPUTE_ASSERT(!var.info()->is_resizable());
122         ARM_COMPUTE_ASSERT(!w_fused.info()->is_resizable());
123         ARM_COMPUTE_ASSERT(!b_fused.info()->is_resizable());
124         ARM_COMPUTE_ASSERT(!beta.info()->is_resizable());
125         ARM_COMPUTE_ASSERT(!gamma.info()->is_resizable());
126 
127         // Fill tensors
128         fill(AccessorType(w), 0U, -1.0f, 1.0f);
129         fill(AccessorType(b), 1U, -1.0f, 1.0f);
130         fill(AccessorType(mean), 2U, -1.0f, 1.0f);
131         fill(AccessorType(var), 3U, 0.0f, 1.0f);
132         fill(AccessorType(beta), 4U, -1.0f, 1.0f);
133         fill(AccessorType(gamma), 5U, -1.0f, 1.0f);
134 
135         // Compute function
136         fuse_batch_normalization.run();
137 
138         return std::make_pair(std::move(in_place_w ? w : w_fused), std::move(in_place_b ? b : b_fused));
139     }
140 
compute_reference(TensorShape shape_w,DataType data_type,bool with_bias,bool with_gamma,bool with_beta)141     std::pair<SimpleTensor<T>, SimpleTensor<T>> compute_reference(TensorShape shape_w, DataType data_type, bool with_bias, bool with_gamma, bool with_beta)
142     {
143         const TensorShape shape_v(shape_w[dims_weights - 1]);
144 
145         SimpleTensor<T> w{ shape_w, data_type };
146         SimpleTensor<T> b{ shape_v, data_type };
147         SimpleTensor<T> mean{ shape_v, data_type };
148         SimpleTensor<T> var{ shape_v, data_type };
149         SimpleTensor<T> w_fused{ shape_w, data_type };
150         SimpleTensor<T> b_fused{ shape_v, data_type };
151         SimpleTensor<T> beta{ shape_v, data_type };
152         SimpleTensor<T> gamma{ shape_v, data_type };
153 
154         // Fill reference tensor
155         fill(w, 0U, -1.0f, 1.0f);
156         fill(b, 1U, -1.0f, 1.0f);
157         fill(mean, 2U, -1.0f, 1.0f);
158         fill(var, 3U, 0.0f, 1.0f);
159         fill(beta, 4U, -1.0f, 1.0f);
160         fill(gamma, 5U, -1.0f, 1.0f);
161 
162         if(!with_bias)
163         {
164             // Fill with zeros
165             fill(b, 0U, 0.0f, 0.0f);
166         }
167 
168         if(!with_gamma)
169         {
170             // Fill with ones
171             fill(gamma, 0U, 1.0f, 1.0f);
172         }
173 
174         if(!with_beta)
175         {
176             // Fill with zeros
177             fill(beta, 0U, 0.0f, 0.0f);
178         }
179 
180         switch(dims_weights)
181         {
182             case 3:
183                 // Weights for depth wise convolution layer
184                 reference::fuse_batch_normalization_dwc_layer(w, mean, var, w_fused, b_fused, b, beta, gamma, _epsilon);
185                 break;
186             case 4:
187                 // Weights for convolution layer
188                 reference::fuse_batch_normalization_conv_layer(w, mean, var, w_fused, b_fused, b, beta, gamma, _epsilon);
189                 break;
190             default:
191                 ARM_COMPUTE_ERROR("Not supported number of dimensions for the input weights tensor");
192         }
193 
194         return std::make_pair(std::move(w_fused), std::move(b_fused));
195     }
196 
197     const float     _epsilon{ 0.0001f };
198     TensorType      _target_w{};
199     TensorType      _target_b{};
200     SimpleTensor<T> _reference_w{};
201     SimpleTensor<T> _reference_b{};
202 };
203 } // namespace validation
204 } // namespace test
205 } // namespace arm_compute
206 #endif /* ARM_COMPUTE_TEST_FUSEBATCHNORMALIZATION_FIXTURE */
207