/*
 * Copyright (c) 2018-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "src/core/NEON/kernels/NEFuseBatchNormalizationKernel.h"

#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include "src/core/CPP/Validate.h"
#include "src/core/NEON/wrapper/wrapper.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"

#include <map>

namespace arm_compute
{
namespace
{
Status validate_arguments(const ITensorInfo *input_weights, const ITensorInfo *bn_mean, const ITensorInfo *bn_var,
                          const ITensorInfo *fused_weights, const ITensorInfo *fused_bias,
                          const ITensorInfo *input_bias, const ITensorInfo *bn_beta, const ITensorInfo *bn_gamma,
                          float epsilon, FuseBatchNormalizationType fbn_type)
{
    ARM_COMPUTE_UNUSED(epsilon);
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input_weights, bn_mean, bn_var);
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input_weights);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input_weights, 1, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(bn_mean, bn_var);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, bn_mean, bn_var);
    ARM_COMPUTE_RETURN_ERROR_ON(input_bias == nullptr && fused_bias == nullptr);
    ARM_COMPUTE_RETURN_ERROR_ON(bn_mean->num_dimensions() > 1);

    if(fbn_type == FuseBatchNormalizationType::CONVOLUTION)
    {
        ARM_COMPUTE_RETURN_ERROR_ON(input_weights->dimension(3) != bn_mean->dimension(0));
    }
    else
    {
        const size_t channel_idx = get_data_layout_dimension_index(input_weights->data_layout(), DataLayoutDimension::CHANNEL);
        ARM_COMPUTE_RETURN_ERROR_ON(input_weights->dimension(channel_idx) != bn_mean->dimension(0));
    }
    // Validate bias
    if(input_bias != nullptr)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(bn_mean, input_bias);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, input_bias);
    }
    // Validate beta
    if(bn_beta != nullptr)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(bn_mean, bn_beta);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, bn_beta);
    }
    // Validate gamma
    if(bn_gamma != nullptr)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(bn_mean, bn_gamma);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, bn_gamma);
    }

    // Validate output weights
    if(fused_weights != nullptr && fused_weights->total_size() != 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input_weights, fused_weights);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input_weights, fused_weights);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, fused_weights);
    }
    // Validate output bias
    if(fused_bias != nullptr && fused_bias->total_size() != 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(bn_mean, fused_bias);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, fused_bias);
    }

    return Status{};
}

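// Fold the batch normalization parameters into the convolution weights and bias.
// For each output channel c (the 4th weights dimension):
//   fused_weights[c] = conv_weights[c] * gamma[c] / sqrt(var[c] + epsilon)
//   fused_bias[c]    = (conv_bias[c] - mean[c]) / sqrt(var[c] + epsilon) * gamma[c] + beta[c]
// A missing gamma defaults to 1, a missing beta to 0, and a missing convolution bias to 0.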
template <typename VectorType>
void fused_batch_normalization_conv(const ITensor *conv_weights, const ITensor *conv_bias, ITensor *fused_weights, ITensor *fused_bias,
                                    const ITensor *bn_mean, const ITensor *bn_var, const ITensor *bn_beta, const ITensor *bn_gamma, float epsilon, const Window &window)
{
    using ScalarType   = typename VectorType::scalar_type;
    const int size     = 16 / conv_weights->info()->element_size();
    using ExactTagType = typename VectorType::tag_type;

    const bool run_in_place_weights = (fused_weights == nullptr) || (fused_weights == conv_weights);
    const bool run_in_place_bias    = (fused_bias == nullptr) || (conv_bias != nullptr && fused_bias == conv_bias);

    // Collapse the window's X dimension; the loop body iterates over X manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int  window_step_x  = size;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    Iterator conv_w_in(conv_weights, win);
    Iterator conv_w_out(run_in_place_weights ? conv_weights : fused_weights, win);

    const auto conv_bias_in  = (conv_bias != nullptr ? reinterpret_cast<ScalarType *>(conv_bias->ptr_to_element(Coordinates(0, 0))) : nullptr);
    auto       conv_bias_out = (run_in_place_bias ? conv_bias_in : reinterpret_cast<ScalarType *>(fused_bias->ptr_to_element(Coordinates(0, 0))));

    const auto input_mean  = reinterpret_cast<const ScalarType *>(bn_mean->ptr_to_element(Coordinates(0, 0)));
    const auto input_var   = reinterpret_cast<const ScalarType *>(bn_var->ptr_to_element(Coordinates(0, 0)));
    const auto input_gamma = (bn_gamma != nullptr) ? reinterpret_cast<const ScalarType *>(bn_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr;
    const auto input_beta  = (bn_beta != nullptr) ? reinterpret_cast<const ScalarType *>(bn_beta->ptr_to_element(Coordinates(0, 0))) : nullptr;

    auto       mean_vec    = wrapper::vdup_n(ScalarType(0), ExactTagType{});
    auto       var_vec     = wrapper::vdup_n(ScalarType(0), ExactTagType{});
    auto       gamma_vec   = wrapper::vdup_n(ScalarType(1), ExactTagType{});
    auto       beta_vec    = wrapper::vdup_n(ScalarType(0), ExactTagType{});
    auto       rvar_vec    = wrapper::vdup_n(ScalarType(0), ExactTagType{});
    const auto epsilon_vec = wrapper::vdup_n(ScalarType(epsilon), ExactTagType{});

    auto mean                = ScalarType(0.0);
    auto var                 = ScalarType(0.0);
    auto gamma               = ScalarType(1.0);
    auto beta                = ScalarType(0.0);
    auto conv_bias_in_scalar = ScalarType(0.0);
    execute_window_loop(win, [&](const Coordinates & id)
    {
        var = input_var[id[3]];
        if(input_gamma != nullptr)
        {
            gamma = input_gamma[id[3]];
        }

        if((id[0] == 0) && (id[1] == 0) && (id[2] == 0))
        {
            if(input_beta != nullptr)
            {
                beta     = input_beta[id[3]];
                beta_vec = wrapper::vdup_n(beta, ExactTagType{});
            }

            // Construct vectors
            mean     = input_mean[id[3]];
            mean_vec = wrapper::vdup_n(mean, ExactTagType{});

            if(conv_bias_in != nullptr)
            {
                conv_bias_in_scalar = conv_bias_in[id[3]];
            }
            auto conv_bias_tmp_scalar = (conv_bias_in_scalar - mean) / std::sqrt(var + ScalarType(epsilon));
            conv_bias_out[id[3]]      = (conv_bias_tmp_scalar * gamma) + beta;
        }

        int  x              = window_start_x;
        auto conv_w_in_ptr  = reinterpret_cast<const ScalarType *>(conv_w_in.ptr());
        auto conv_w_out_ptr = reinterpret_cast<ScalarType *>(conv_w_out.ptr());
        var_vec             = wrapper::vdup_n(var, ExactTagType{});
        gamma_vec           = wrapper::vdup_n(gamma, ExactTagType{});
        rvar_vec            = wrapper::vinvsqrt(wrapper::vadd(var_vec, epsilon_vec));

        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            auto wn = wrapper::vloadq(conv_w_in_ptr + x);
            wn      = wrapper::vmul(wn, rvar_vec);
            wn      = wrapper::vmul(wn, gamma_vec);

            // Store results
            wrapper::vstore(conv_w_out_ptr + x, wn);
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            *(conv_w_out_ptr + x) = *(conv_w_in_ptr + x) / std::sqrt(var + ScalarType(epsilon)) * gamma;
        }
    },
    conv_w_in, conv_w_out);
}

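// Depthwise NHWC variant: the channel dimension is innermost (X), so the batch
// normalization parameters are loaded as vectors and applied lane by lane along X.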
template <typename VectorType>
void fused_batch_normalization_dwc_nhwc(const ITensor *dwc_weights, const ITensor *dwc_bias, ITensor *fused_weights, ITensor *fused_bias,
                                        const ITensor *bn_mean, const ITensor *bn_var, const ITensor *bn_beta, const ITensor *bn_gamma, float epsilon, const Window &window)
{
    using ScalarType   = typename VectorType::scalar_type;
    const int size     = 16 / dwc_weights->info()->element_size();
    using ExactTagType = typename VectorType::tag_type;

    const bool run_in_place_weights = (fused_weights == nullptr) || (fused_weights == dwc_weights);
    const bool run_in_place_bias    = (fused_bias == nullptr) || (dwc_bias != nullptr && fused_bias == dwc_bias);

    // Collapse the window's X dimension; the loop body iterates over X manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int  window_step_x  = size;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    Iterator dwc_w_in(dwc_weights, win);
    Iterator dwc_w_out(run_in_place_weights ? dwc_weights : fused_weights, win);

    const auto dwc_bias_in  = (dwc_bias != nullptr ? reinterpret_cast<ScalarType *>(dwc_bias->ptr_to_element(Coordinates(0, 0))) : nullptr);
    auto       dwc_bias_out = (run_in_place_bias ? dwc_bias_in : reinterpret_cast<ScalarType *>(fused_bias->ptr_to_element(Coordinates(0, 0))));

    const auto input_mean  = reinterpret_cast<const ScalarType *>(bn_mean->ptr_to_element(Coordinates(0, 0)));
    const auto input_var   = reinterpret_cast<const ScalarType *>(bn_var->ptr_to_element(Coordinates(0, 0)));
    const auto input_gamma = (bn_gamma != nullptr) ? reinterpret_cast<const ScalarType *>(bn_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr;
    const auto input_beta  = (bn_beta != nullptr) ? reinterpret_cast<const ScalarType *>(bn_beta->ptr_to_element(Coordinates(0, 0))) : nullptr;

    auto       mean_vec     = wrapper::vdup_n(ScalarType(0), ExactTagType{});
    auto       var_vec      = wrapper::vdup_n(ScalarType(0), ExactTagType{});
    auto       gamma_vec    = wrapper::vdup_n(ScalarType(1), ExactTagType{});
    auto       beta_vec     = wrapper::vdup_n(ScalarType(0), ExactTagType{});
    auto       rvar_vec     = wrapper::vdup_n(ScalarType(0), ExactTagType{});
    auto       dwc_bias_vec = wrapper::vdup_n(ScalarType(0), ExactTagType{});
    const auto epsilon_vec  = wrapper::vdup_n(ScalarType(epsilon), ExactTagType{});

    auto gamma              = ScalarType(1.0);
    auto beta               = ScalarType(0.0);
    auto dwc_bias_in_scalar = ScalarType(0);

    execute_window_loop(win, [&](const Coordinates & id)
    {
        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            var_vec = wrapper::vloadq(input_var + x);
            if(input_gamma != nullptr)
            {
                gamma_vec = wrapper::vloadq(input_gamma + x);
            }

            if((id[2] == 0) && (id[1] == 0))
            {
                mean_vec = wrapper::vloadq(input_mean + x);

                // Construct vectors
                if(input_beta != nullptr)
                {
                    beta_vec = wrapper::vloadq(input_beta + x);
                }

                if(dwc_bias_in != nullptr)
                {
                    dwc_bias_vec = wrapper::vloadq(dwc_bias_in + x);
                }

                auto dwc_bias_tmp_vec = wrapper::vmul(wrapper::vsub(dwc_bias_vec, mean_vec), wrapper::vinvsqrt(wrapper::vadd(var_vec, epsilon_vec)));
                dwc_bias_tmp_vec      = wrapper::vadd(wrapper::vmul(dwc_bias_tmp_vec, gamma_vec), beta_vec);
                wrapper::vstore(dwc_bias_out + x, dwc_bias_tmp_vec);
            }

            auto dwc_w_in_ptr  = reinterpret_cast<const ScalarType *>(dwc_w_in.ptr());
            auto dwc_w_out_ptr = reinterpret_cast<ScalarType *>(dwc_w_out.ptr());

            auto wn  = wrapper::vloadq(dwc_w_in_ptr + x);
            rvar_vec = wrapper::vinvsqrt(wrapper::vadd(var_vec, epsilon_vec));
            wn       = wrapper::vmul(wn, rvar_vec);
            wn       = wrapper::vmul(wn, gamma_vec);

            // Store results
            wrapper::vstore(dwc_w_out_ptr + x, wn);
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            auto var = input_var[x];
            if(input_gamma != nullptr)
            {
                gamma = input_gamma[x];
            }

            if(id[2] == 0 && id[1] == 0)
            {
                auto mean = input_mean[x];
                if(input_beta != nullptr)
                {
                    beta = input_beta[x];
                }
                if(dwc_bias_in != nullptr)
                {
                    dwc_bias_in_scalar = dwc_bias_in[x];
                }

                auto dwc_bias_tmp_scalar = (dwc_bias_in_scalar - mean) / std::sqrt(var + ScalarType(epsilon));
                dwc_bias_out[x]          = (dwc_bias_tmp_scalar * gamma) + beta;
            }

            const auto dwc_w_in_ptr  = reinterpret_cast<const ScalarType *>(dwc_w_in.ptr());
            auto       dwc_w_out_ptr = reinterpret_cast<ScalarType *>(dwc_w_out.ptr());

            *(dwc_w_out_ptr + x) = *(dwc_w_in_ptr + x) / std::sqrt(var + ScalarType(epsilon)) * gamma;
        }
    },
    dwc_w_in, dwc_w_out);
}

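// Depthwise NCHW variant: X spans the spatial width, so the per-channel batch
// normalization parameters (indexed by id[2]) are broadcast into vectors before
// scaling each row of weights.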
template <typename VectorType>
void fused_batch_normalization_dwc_nchw(const ITensor *dwc_weights, const ITensor *dwc_bias, ITensor *fused_weights, ITensor *fused_bias,
                                        const ITensor *bn_mean, const ITensor *bn_var, const ITensor *bn_beta, const ITensor *bn_gamma, float epsilon, const Window &window)
{
    using ScalarType   = typename VectorType::scalar_type;
    const int size     = 16 / dwc_weights->info()->element_size();
    using ExactTagType = typename VectorType::tag_type;

    const bool run_in_place_weights = (fused_weights == nullptr) || (fused_weights == dwc_weights);
    const bool run_in_place_bias    = (fused_bias == nullptr) || (dwc_bias != nullptr && fused_bias == dwc_bias);

    // Collapse the window's X dimension; the loop body iterates over X manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int  window_step_x  = size;
    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());

    Iterator dwc_w_in(dwc_weights, win);
    Iterator dwc_w_out(run_in_place_weights ? dwc_weights : fused_weights, win);

    const auto dwc_bias_in  = (dwc_bias != nullptr ? reinterpret_cast<ScalarType *>(dwc_bias->ptr_to_element(Coordinates(0, 0))) : nullptr);
    auto       dwc_bias_out = (run_in_place_bias ? dwc_bias_in : reinterpret_cast<ScalarType *>(fused_bias->ptr_to_element(Coordinates(0, 0))));

    const auto input_mean  = reinterpret_cast<const ScalarType *>(bn_mean->ptr_to_element(Coordinates(0, 0)));
    const auto input_var   = reinterpret_cast<const ScalarType *>(bn_var->ptr_to_element(Coordinates(0, 0)));
    const auto input_gamma = (bn_gamma != nullptr) ? reinterpret_cast<const ScalarType *>(bn_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr;
    const auto input_beta  = (bn_beta != nullptr) ? reinterpret_cast<const ScalarType *>(bn_beta->ptr_to_element(Coordinates(0, 0))) : nullptr;

    auto       mean_vec    = wrapper::vdup_n(ScalarType(0), ExactTagType{});
    auto       var_vec     = wrapper::vdup_n(ScalarType(0), ExactTagType{});
    auto       gamma_vec   = wrapper::vdup_n(ScalarType(1), ExactTagType{});
    auto       beta_vec    = wrapper::vdup_n(ScalarType(0), ExactTagType{});
    auto       rvar_vec    = wrapper::vdup_n(ScalarType(0), ExactTagType{});
    const auto epsilon_vec = wrapper::vdup_n(ScalarType(epsilon), ExactTagType{});

    auto mean               = ScalarType(0.0);
    auto var                = ScalarType(0.0);
    auto gamma              = ScalarType(1.0);
    auto beta               = ScalarType(0.0);
    auto dwc_bias_in_scalar = ScalarType(0.0);
    execute_window_loop(win, [&](const Coordinates & id)
    {
        var = input_var[id[2]];
        if(input_gamma != nullptr)
        {
            gamma = input_gamma[id[2]];
        }

        if(id[1] == 0)
        {
            mean = input_mean[id[2]];

            // Construct vectors
            mean_vec = wrapper::vdup_n(mean, ExactTagType{});
            if(input_beta != nullptr)
            {
                beta     = input_beta[id[2]];
                beta_vec = wrapper::vdup_n(beta, ExactTagType{});
            }

            if(dwc_bias_in != nullptr)
            {
                dwc_bias_in_scalar = dwc_bias_in[id[2]];
            }

            auto dwc_bias_tmp_scalar = (dwc_bias_in_scalar - mean) / std::sqrt(var + ScalarType(epsilon));
            dwc_bias_out[id[2]]      = (dwc_bias_tmp_scalar * gamma) + beta;
        }

        int  x             = window_start_x;
        auto dwc_w_in_ptr  = reinterpret_cast<const ScalarType *>(dwc_w_in.ptr());
        auto dwc_w_out_ptr = reinterpret_cast<ScalarType *>(dwc_w_out.ptr());
        var_vec            = wrapper::vdup_n(var, ExactTagType{});
        gamma_vec          = wrapper::vdup_n(gamma, ExactTagType{});
        rvar_vec           = wrapper::vinvsqrt(wrapper::vadd(var_vec, epsilon_vec));

        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            auto wn = wrapper::vloadq(dwc_w_in_ptr + x);
            wn      = wrapper::vmul(wn, rvar_vec);
            wn      = wrapper::vmul(wn, gamma_vec);

            // Store results
            wrapper::vstore(dwc_w_out_ptr + x, wn);
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            *(dwc_w_out_ptr + x) = *(dwc_w_in_ptr + x) / std::sqrt(var + ScalarType(epsilon)) * gamma;
        }
    },
    dwc_w_in, dwc_w_out);
}

} // namespace

NEFuseBatchNormalizationKernel::NEFuseBatchNormalizationKernel()
    : _input_weights(nullptr), _input_bias(nullptr), _bn_mean(nullptr), _bn_var(nullptr), _bn_gamma(nullptr), _bn_beta(nullptr), _fused_weights(nullptr), _fused_bias(nullptr), _epsilon(),
      _run_in_place_weights(false), _run_in_place_bias(false), _func(nullptr)
{
}

void NEFuseBatchNormalizationKernel::configure(const ITensor *input_weights, const ITensor *bn_mean, const ITensor *bn_var,
                                               ITensor *fused_weights, ITensor *fused_bias,
                                               const ITensor *input_bias, const ITensor *bn_beta, const ITensor *bn_gamma,
                                               float epsilon, FuseBatchNormalizationType fbn_type)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input_weights, bn_mean, bn_var);

    _input_weights = input_weights;
    _input_bias    = input_bias;
    _bn_mean       = bn_mean;
    _bn_var        = bn_var;
    _bn_beta       = bn_beta;
    _bn_gamma      = bn_gamma;
    _fused_weights = fused_weights;
    _fused_bias    = fused_bias;
    _epsilon       = epsilon;

    _run_in_place_weights = (fused_weights == nullptr) || (fused_weights == input_weights);
    _run_in_place_bias    = (fused_bias == nullptr) || (input_bias != nullptr && fused_bias == input_bias);

    // Auto initialize outputs
    if(_fused_weights != nullptr)
    {
        // Output tensor auto initialization if not yet initialized
        auto_init_if_empty(*_fused_weights->info(), *_input_weights->info()->clone());
        fused_weights->info()->set_valid_region(input_weights->info()->valid_region());
    }
    if(_fused_bias != nullptr)
    {
        // Output tensor auto initialization if not yet initialized
        auto_init_if_empty(*_fused_bias->info(), *_bn_mean->info()->clone());
        _fused_bias->info()->set_valid_region(bn_mean->info()->valid_region());
    }

    // Validate arguments
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input_weights->info(), bn_mean->info(), bn_var->info(),
                                                  (fused_weights != nullptr) ? fused_weights->info() : nullptr,
                                                  (fused_bias != nullptr) ? fused_bias->info() : nullptr,
                                                  (input_bias != nullptr) ? input_bias->info() : nullptr,
                                                  (bn_beta != nullptr) ? bn_beta->info() : nullptr,
                                                  (bn_gamma != nullptr) ? bn_gamma->info() : nullptr,
                                                  epsilon, fbn_type));

    // Configure kernel window
    Window win = calculate_max_window(*input_weights->info());
    INEKernel::configure(win);

    // Configure function
    static std::map<std::string, FuseBatchNormFunction *> map_function =
    {
        { "fused_batch_normalization_conv_NHWC_F32", &fused_batch_normalization_conv<wrapper::traits::neon_vector<float, 4>> },
        { "fused_batch_normalization_conv_NCHW_F32", &fused_batch_normalization_conv<wrapper::traits::neon_vector<float, 4>> },
        { "fused_batch_normalization_dwc_NHWC_F32", &fused_batch_normalization_dwc_nhwc<wrapper::traits::neon_vector<float, 4>> },
        { "fused_batch_normalization_dwc_NCHW_F32", &fused_batch_normalization_dwc_nchw<wrapper::traits::neon_vector<float, 4>> },
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        { "fused_batch_normalization_conv_NHWC_F16", &fused_batch_normalization_conv<wrapper::traits::neon_vector<float16_t, 8>> },
        { "fused_batch_normalization_conv_NCHW_F16", &fused_batch_normalization_conv<wrapper::traits::neon_vector<float16_t, 8>> },
        { "fused_batch_normalization_dwc_NHWC_F16", &fused_batch_normalization_dwc_nhwc<wrapper::traits::neon_vector<float16_t, 8>> },
        { "fused_batch_normalization_dwc_NCHW_F16", &fused_batch_normalization_dwc_nchw<wrapper::traits::neon_vector<float16_t, 8>> },
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
    };

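    // Compose the dispatch key as fused_batch_normalization_<conv|dwc>_<data layout>_<data type>,
    // e.g. "fused_batch_normalization_conv_NHWC_F32"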
    std::string function_to_call("fused_batch_normalization_");
    function_to_call += fbn_type == FuseBatchNormalizationType::CONVOLUTION ? "conv_" : "dwc_";
    function_to_call += string_from_data_layout(_input_weights->info()->data_layout());
    function_to_call += "_";
    function_to_call += string_from_data_type(_input_weights->info()->data_type());

    auto it = map_function.find(function_to_call);

    if(it != map_function.end())
    {
        _func = it->second;
    }
}

Status NEFuseBatchNormalizationKernel::validate(const ITensorInfo *input_weights, const ITensorInfo *bn_mean, const ITensorInfo *bn_var,
                                                const ITensorInfo *fused_weights, const ITensorInfo *fused_bias,
                                                const ITensorInfo *input_bias, const ITensorInfo *bn_beta, const ITensorInfo *bn_gamma,
                                                float epsilon, FuseBatchNormalizationType fbn_type)
{
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input_weights, bn_mean, bn_var, fused_weights, fused_bias, input_bias, bn_beta, bn_gamma, epsilon, fbn_type));
    return Status{};
}

void NEFuseBatchNormalizationKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);
    (*_func)(_input_weights, _input_bias, _fused_weights, _fused_bias, _bn_mean, _bn_var, _bn_beta, _bn_gamma, _epsilon, window);
}
} // namespace arm_compute