1 /*
2  * Copyright (c) 2016-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h"
25 
26 #include "arm_compute/core/TensorInfo.h"
27 #include "src/core/CPP/Validate.h"
28 #include "src/core/NEON/NEAsymm.h"
29 #include "src/core/NEON/NESymm.h"
30 #include "src/core/NEON/wrapper/wrapper.h"
31 #include "src/core/helpers/AutoConfiguration.h"
32 #include "src/core/helpers/WindowHelpers.h"
33 
34 #include <arm_neon.h>
35 
36 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
37 #include <arm_fp16.h> // needed for float16_t
38 #endif                /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
39 
40 namespace arm_compute
41 {
42 namespace
43 {
44 const float       scale255_constant      = 1.f / 255.f;
45 const float32x4_t scale255_constant_f32q = vdupq_n_f32(scale255_constant);
46 const float32x4_t positive_round_f32q    = vdupq_n_f32(0.5f);
47 
48 inline Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
49 {
50     ARM_COMPUTE_UNUSED(overflow_policy);
51     ARM_COMPUTE_UNUSED(rounding_policy);
52 
53     ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input1);
54     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::S32, DataType::QSYMM16, DataType::F16,
55                                                          DataType::F32);
56     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::S32, DataType::QSYMM16, DataType::F16,
57                                                          DataType::F32);
58     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
59                                                          DataType::S16, DataType::QSYMM16,
60                                                          DataType::S32, DataType::F16, DataType::F32);
61     if(is_data_type_quantized(input1->data_type()) || is_data_type_quantized(input2->data_type()))
62     {
63         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2);
64         ARM_COMPUTE_RETURN_ERROR_ON_MSG(overflow_policy == ConvertPolicy::WRAP, "ConvertPolicy cannot be WRAP if datatype is quantized");
65     }
66 
67     if(output->total_size() > 0)
68     {
69         const TensorShape &out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape());
70         ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output");
71         ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
72         // clang-format off
73         ARM_COMPUTE_RETURN_ERROR_ON_MSG(
74             !(input1->data_type() == input2->data_type() && input2->data_type() == output->data_type()) &&
75             !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16) &&
76             !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::S16 && output->data_type() == DataType::S16) &&
77             !(input1->data_type() == DataType::S16 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16) &&
79             !(input1->data_type() == DataType::QSYMM16 && input2->data_type() == DataType::QSYMM16 && output->data_type() == DataType::S32)
80             , "Invalid data type combination");
81         // clang-format on
82         ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->data_type() == DataType::QSYMM16 && output->data_type() == DataType::S32 && scale != 1.f, "Unsupported scale for QSYMM16 inputs and S32 output");
83     }
84 
85     if(std::abs(scale - scale255_constant) < 0.00001f)
86     {
87         ARM_COMPUTE_RETURN_ERROR_ON(rounding_policy != RoundingPolicy::TO_NEAREST_UP && rounding_policy != RoundingPolicy::TO_NEAREST_EVEN);
88         ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->data_type() == DataType::S32 && input2->data_type() == DataType::S32 && output->data_type() == DataType::S32,
89                                         "Scale == 1/255 is not supported if input and output are of data type S32");
90     }
91     else
92     {
93         ARM_COMPUTE_RETURN_ERROR_ON(rounding_policy != RoundingPolicy::TO_ZERO);
94 
95         int         exponent            = 0;
96         const float normalized_mantissa = std::frexp(scale, &exponent);
97 
98         // Use int scaling if factor is equal to 1/2^n for 0 <= n <= 15
99         // frexp returns 0.5 as mantissa, which means the exponent is in the range -14 <= e <= 1
100         // (exponent = 1 - n), so it is non-positive for every n >= 1 since we deal with 1/2^n
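        // Worked example: scale = 1/8 gives frexp(0.125) = 0.5 * 2^-2, i.e. exponent = -2 and n = 1 - (-2) = 3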
101         ARM_COMPUTE_RETURN_ERROR_ON_MSG(!((normalized_mantissa == 0.5f) && (-14 <= exponent) && (exponent <= 1)), "Scale value not supported (Should be 1/(2^n) or 1/255)");
102     }
103 
104     return Status{};
105 }
106 
107 /** Scales a given vector by 1/255.
108  *
109  * @note This does not work for all cases, e.g. for a value of 0.49999999999999994 or for very large inputs.
110  *
111  * @param in Input vector to scale.
112  * @return   Scaled output rounded to nearest (round half up).
113  */
114 inline int32x4_t scale255_S32_S32(int32x4_t in)
115 {
116     // Scale
117     const float32x4_t tmp = vmulq_f32(vcvtq_f32_s32(in), scale255_constant_f32q);
118     // Round to nearest (round half up)
119     // Add +0.5 for all values
120     // Afterwards vcvt rounds toward zero
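    // e.g. an input lane of 383: 383 * (1/255) = 1.502..., adding 0.5 gives 2.002..., and truncation yields 2 (round half up)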
121     return vcvtq_s32_f32(vaddq_f32(tmp, positive_round_f32q));
122 }
123 
124 inline uint16x8_t scale255_U16_U16(uint16x8_t in)
125 {
126     const int32x4_t tmp_s1 = scale255_S32_S32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(in))));
127     const int32x4_t tmp_s2 = scale255_S32_S32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(in))));
128     return vreinterpretq_u16_s16(vcombine_s16(vmovn_s32(tmp_s2), vmovn_s32(tmp_s1)));
129 }
130 
131 template <typename T>
132 inline typename std::enable_if<std::is_same<T, int8_t>::value, int8x16_t>::type
133 vquantize(float32x4x4_t val, const UniformQuantizationInfo &info)
134 {
135     return vquantize_signed(val, info);
136 }
137 
138 template <typename T>
139 inline typename std::enable_if<std::is_same<T, uint8_t>::value, uint8x16_t>::type
140 vquantize(float32x4x4_t val, const UniformQuantizationInfo &info)
141 {
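    // Note: T cannot be deduced from the arguments here, so this unqualified call does not recurse into
    // these helpers; it resolves via argument-dependent lookup on UniformQuantizationInfo to
    // arm_compute::vquantize declared in NEAsymm.h.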
142     return vquantize(val, info);
143 }
144 
145 template <typename T>
146 void mul_saturate_quantized_8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, float scale)
147 {
148     // Create input windows
149     Window win        = window;
150     Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
151     Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
152 
153     // Clear X Dimension on execution window as we handle manually
154     win.set(Window::DimX, Window::Dimension(0, 1, 1));
155 
156     const int  window_step_x         = 16 / sizeof(T);
157     const auto window_start_x        = static_cast<int>(window.x().start());
158     const auto window_end_x          = static_cast<int>(window.x().end());
159     const bool is_broadcast_across_x = in1->info()->tensor_shape().x() != in2->info()->tensor_shape().x();
160 
161     const UniformQuantizationInfo output_qua_info = out->info()->quantization_info().uniform();
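    // Fold the multiplication scale into the requantization step: quantizing x with scale (output_scale / scale)
    // is equivalent to quantizing (x * scale) with output_scale, so no explicit multiply by scale is needed.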
162     const UniformQuantizationInfo tmp_qua_info    = { output_qua_info.scale / scale, output_qua_info.offset };
163 
164     if(is_broadcast_across_x)
165     {
166         const bool                    is_broadcast_input_2 = input2_win.x().step() == 0;
167         Window                        broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
168         Window                        non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
169         const ITensor                *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
170         const ITensor                *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
171         const UniformQuantizationInfo broadcast_qinfo      = broadcast_tensor->info()->quantization_info().uniform();
172         const UniformQuantizationInfo non_broadcast_qinfo  = non_broadcast_tensor->info()->quantization_info().uniform();
173 
174         // Clear X Dimension on execution window as we handle manually
175         non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));
176 
177         Iterator broadcast_input(broadcast_tensor, broadcast_win);
178         Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
179         Iterator output(out, win);
180 
181         using ExactTagType = typename wrapper::traits::neon_vector<T, window_step_x>::tag_type;
182 
183         execute_window_loop(win, [&](const Coordinates &)
184         {
185             const auto non_broadcast_input_ptr = reinterpret_cast<const T *>(non_broadcast_input.ptr());
186             const auto output_ptr              = reinterpret_cast<T *>(output.ptr());
187 
188             const auto broadcast_value     = *reinterpret_cast<const T *>(broadcast_input.ptr());
189             const auto broadcast_value_vec = wrapper::vdup_n(broadcast_value, ExactTagType{});
190 
191             // Compute window_step_x elements per iteration
192             int x = window_start_x;
193             for(; x <= (window_end_x - window_step_x); x += window_step_x)
194             {
195                 const auto non_broadcast_v = wrapper::vloadq(non_broadcast_input_ptr + x);
196 
197                 // Dequantize inputs
198                 const float32x4x4_t in1_f32x4x4 = vdequantize(non_broadcast_v, non_broadcast_qinfo);
199                 const float32x4x4_t in2_f32x4x4 = vdequantize(broadcast_value_vec, broadcast_qinfo);
200 
201                 const float32x4x4_t out_f32x4x4 =
202                 {
203                     vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]),
204                     vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]),
205                     vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]),
206                     vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]),
207                 };
208 
209                 // Quantize output
210                 const auto result = vquantize<T>(out_f32x4x4, tmp_qua_info);
211                 wrapper::vstore(output_ptr + x, result);
212             }
213 
214             // Compute left-over elements
215             for(; x < window_end_x; ++x)
216             {
217                 // Dequantize inputs
218                 const T     in1     = *(non_broadcast_input_ptr + x);
219                 const float tmp_in1 = Qasymm8QuantizationHelper<T>::dequantize(in1, non_broadcast_qinfo);
220                 const float tmp_in2 = Qasymm8QuantizationHelper<T>::dequantize(broadcast_value, broadcast_qinfo);
221                 const float tmp_f   = tmp_in1 * tmp_in2;
222 
223                 // Quantize output
224                 const auto tmp_qua = Qasymm8QuantizationHelper<T>::quantize(tmp_f, tmp_qua_info);
225                 *(output_ptr + x)  = tmp_qua;
226             }
227         },
228         broadcast_input, non_broadcast_input, output);
229     }
230     else
231     {
232         const UniformQuantizationInfo input1_qua_info = in1->info()->quantization_info().uniform();
233         const UniformQuantizationInfo input2_qua_info = in2->info()->quantization_info().uniform();
234 
235         // Clear X Dimension on execution window as we handle manually
236         input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
237         input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
238 
239         Iterator input1(in1, input1_win);
240         Iterator input2(in2, input2_win);
241         Iterator output(out, win);
242 
243         execute_window_loop(win, [&](const Coordinates &)
244         {
245             const auto input1_ptr = reinterpret_cast<const T *>(input1.ptr());
246             const auto input2_ptr = reinterpret_cast<const T *>(input2.ptr());
247             const auto output_ptr = reinterpret_cast<T *>(output.ptr());
248 
249             // Compute window_step_x elements per iteration
250             int x = window_start_x;
251             for(; x <= (window_end_x - window_step_x); x += window_step_x)
252             {
253                 const auto input1_q = wrapper::vloadq(input1_ptr + x);
254                 const auto input2_q = wrapper::vloadq(input2_ptr + x);
255 
256                 // Dequantize inputs
257                 const float32x4x4_t in1_f32x4x4 = vdequantize(input1_q, input1_qua_info);
258                 const float32x4x4_t in2_f32x4x4 = vdequantize(input2_q, input2_qua_info);
259 
260                 const float32x4x4_t out_f32x4x4 =
261                 {
262                     vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]),
263                     vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]),
264                     vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]),
265                     vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]),
266                 };
267 
268                 // Quantize output
269                 const auto result = vquantize<T>(out_f32x4x4, tmp_qua_info);
270                 wrapper::vstore(output_ptr + x, result);
271             }
272 
273             // Compute left-over elements
274             for(; x < window_end_x; ++x)
275             {
276                 // Dequantize inputs
277                 const T     in1     = *(input1_ptr + x);
278                 const T     in2     = *(input2_ptr + x);
279                 const float tmp_in1 = Qasymm8QuantizationHelper<T>::dequantize(in1, input1_qua_info);
280                 const float tmp_in2 = Qasymm8QuantizationHelper<T>::dequantize(in2, input2_qua_info);
281                 const float tmp_f   = tmp_in1 * tmp_in2;
282 
283                 // Quantize output
284                 const auto tmp_qua = Qasymm8QuantizationHelper<T>::quantize(tmp_f, tmp_qua_info);
285                 *(output_ptr + x)  = tmp_qua;
286             }
287         },
288         input1, input2, output);
289     }
290 }
291 
292 void mul_saturate_QSYMM16_QSYMM16_QSYMM16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, float scale)
293 {
294     const UniformQuantizationInfo input1_qua_info = in1->info()->quantization_info().uniform();
295     const UniformQuantizationInfo input2_qua_info = in2->info()->quantization_info().uniform();
296     const UniformQuantizationInfo output_qua_info = out->info()->quantization_info().uniform();
297 
298     // Create input windows
299     Window win        = window;
300     Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
301     Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
302 
303     // Clear X Dimension on execution window as we handle manually
304     win.set(Window::DimX, Window::Dimension(0, 1, 1));
305     input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
306     input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
307 
308     Iterator input1(in1, input1_win);
309     Iterator input2(in2, input2_win);
310     Iterator output(out, win);
311 
312     const int  window_step_x  = 16;
313     const auto window_start_x = static_cast<int>(window.x().start());
314     const auto window_end_x   = static_cast<int>(window.x().end());
315 
316     const UniformQuantizationInfo tmp_qua_info = { output_qua_info.scale / scale, output_qua_info.offset };
317 
318     execute_window_loop(win, [&](const Coordinates &)
319     {
320         const auto input1_ptr = reinterpret_cast<const qsymm16_t *>(input1.ptr());
321         const auto input2_ptr = reinterpret_cast<const qsymm16_t *>(input2.ptr());
322         const auto output_ptr = reinterpret_cast<qsymm16_t *>(output.ptr());
323 
324         // Compute window_step_x elements per iteration
325         int x = window_start_x;
326         for(; x <= (window_end_x - window_step_x); x += window_step_x)
327         {
328             const qsymm16x8x2_t input1_q =
329             {
330                 {
331                     vld1q_s16(input1_ptr + x),
332                     vld1q_s16(input1_ptr + x + 8),
333                 }
334             };
335             const qsymm16x8x2_t input2_q =
336             {
337                 {
338                     vld1q_s16(input2_ptr + x),
339                     vld1q_s16(input2_ptr + x + 8),
340                 }
341             };
342 
343             // Dequantize inputs
344             const float32x4x4_t in1_f32x4x4 = vdequantize(input1_q, input1_qua_info);
345             const float32x4x4_t in2_f32x4x4 = vdequantize(input2_q, input2_qua_info);
346 
347             const float32x4x4_t out_f32x4x4 =
348             {
349                 vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]),
350                 vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]),
351                 vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]),
352                 vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]),
353             };
354 
355             const qsymm16x8x2_t result = vquantize_qsymm16(out_f32x4x4, tmp_qua_info);
356             vst1q_s16(output_ptr + x, result.val[0]);
357             vst1q_s16(output_ptr + x + 8, result.val[1]);
358         }
359 
360         // Compute left-over elements
361         for(; x < window_end_x; ++x)
362         {
363             // Dequantize inputs
364             float tmp_in1 = static_cast<float>(*(input1_ptr + x)) * input1_qua_info.scale;
365             float tmp_in2 = static_cast<float>(*(input2_ptr + x)) * input2_qua_info.scale;
366             float tmp_f   = tmp_in1 * tmp_in2;
367 
368             // Quantize output; lrintf() rounds to nearest, matching the vectorized quantization above
369             int32_t   tmp     = lrintf(tmp_f / tmp_qua_info.scale);
370             qsymm16_t tmp_qua = static_cast<qsymm16_t>((tmp > SHRT_MAX) ? SHRT_MAX : ((tmp < SHRT_MIN) ? SHRT_MIN : tmp));
371             *(output_ptr + x) = tmp_qua;
372         }
373     },
374     input1, input2, output);
375 }
376 
377 void mul_QSYMM16_QSYMM16_S32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int scale)
378 {
379     ARM_COMPUTE_UNUSED(scale);
380 
381     // Create input windows
382     Window win        = window;
383     Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
384     Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
385 
386     // Clear X Dimension on execution window as we handle manually
387     win.set(Window::DimX, Window::Dimension(0, 1, 1));
388     input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
389     input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
390 
391     Iterator input1(in1, input1_win);
392     Iterator input2(in2, input2_win);
393     Iterator output(out, win);
394 
395     const int  window_step_x  = 16;
396     const auto window_start_x = static_cast<int>(window.x().start());
397     const auto window_end_x   = static_cast<int>(window.x().end());
398 
399     execute_window_loop(win, [&](const Coordinates &)
400     {
401         const auto input1_ptr = reinterpret_cast<const qsymm16_t *>(input1.ptr());
402         const auto input2_ptr = reinterpret_cast<const qsymm16_t *>(input2.ptr());
403         const auto output_ptr = reinterpret_cast<int32_t *>(output.ptr());
404 
405         // Compute window_step_x elements per iteration
406         int x = window_start_x;
407         for(; x <= (window_end_x - window_step_x); x += window_step_x)
408         {
409             const qsymm16x8x2_t input1_q =
410             {
411                 {
412                     vld1q_s16(input1_ptr + x),
413                     vld1q_s16(input1_ptr + x + 8),
414                 }
415             };
416             const qsymm16x8x2_t input2_q =
417             {
418                 {
419                     vld1q_s16(input2_ptr + x),
420                     vld1q_s16(input2_ptr + x + 8),
421                 }
422             };
423 
424             const int32x4x4_t in1_s32 =
425             {
426                 {
427                     vmovl_s16(vget_low_s16(input1_q.val[0])),
428                     vmovl_s16(vget_high_s16(input1_q.val[0])),
429                     vmovl_s16(vget_low_s16(input1_q.val[1])),
430                     vmovl_s16(vget_high_s16(input1_q.val[1])),
431                 }
432             };
433             const int32x4x4_t in2_s32 =
434             {
435                 {
436                     vmovl_s16(vget_low_s16(input2_q.val[0])),
437                     vmovl_s16(vget_high_s16(input2_q.val[0])),
438                     vmovl_s16(vget_low_s16(input2_q.val[1])),
439                     vmovl_s16(vget_high_s16(input2_q.val[1])),
440                 }
441             };
442 
443             const int32x4x4_t result =
444             {
445                 {
446                     vmulq_s32(in1_s32.val[0], in2_s32.val[0]),
447                     vmulq_s32(in1_s32.val[1], in2_s32.val[1]),
448                     vmulq_s32(in1_s32.val[2], in2_s32.val[2]),
449                     vmulq_s32(in1_s32.val[3], in2_s32.val[3]),
450                 }
451             };
452 
453             vst1q_s32(output_ptr + x, result.val[0]);
454             vst1q_s32(output_ptr + x + 4, result.val[1]);
455             vst1q_s32(output_ptr + x + 8, result.val[2]);
456             vst1q_s32(output_ptr + x + 12, result.val[3]);
457         }
458 
459         // Compute left-over elements
460         for(; x < window_end_x; ++x)
461         {
462             int32_t tmp       = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));
463             *(output_ptr + x) = tmp;
464         }
465     },
466     input1, input2, output);
467 }
468 
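/* Multiplies two U8 tensors and writes a U8 output.
 *
 * @note When is_scale255 is true the widened product is scaled by 1/255 with round-half-up rounding;
 *       otherwise the product is shifted right by n (scale == 1/2^n), saturating when is_sat is true.
 */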
469 template <bool is_scale255, bool is_sat>
470 void mul_U8_U8_U8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
471 {
472     // Create input windows
473     Window win        = window;
474     Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
475     Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
476 
477     // Clear X Dimension on execution window as we handle manually
478     win.set(Window::DimX, Window::Dimension(0, 1, 1));
479     input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
480     input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
481 
482     Iterator input1(in1, input1_win);
483     Iterator input2(in2, input2_win);
484     Iterator output(out, win);
485 
486     const int  window_step_x  = 16 / sizeof(uint8_t);
487     const auto window_start_x = static_cast<int>(window.x().start());
488     const auto window_end_x   = static_cast<int>(window.x().end());
489 
490     execute_window_loop(win, [&](const Coordinates &)
491     {
492         const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr());
493         const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
494         const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());
495 
496         // Compute window_step_x elements per iteration
497         int x = window_start_x;
498         for(; x <= (window_end_x - window_step_x); x += window_step_x)
499         {
500             const uint8x16_t ta1 = wrapper::vloadq(input1_ptr + x);
501             const uint8x16_t ta2 = wrapper::vloadq(input2_ptr + x);
502 
503             uint16x8_t       tmp1_high = vmovl_u8(vget_high_u8(ta1));
504             const uint16x8_t tmp2_high = vmovl_u8(vget_high_u8(ta2));
505             uint16x8_t       tmp1_low  = vmovl_u8(vget_low_u8(ta1));
506             const uint16x8_t tmp2_low  = vmovl_u8(vget_low_u8(ta2));
507 
508             tmp1_high = vmulq_u16(tmp1_high, tmp2_high);
509             tmp1_low  = vmulq_u16(tmp1_low, tmp2_low);
510 
511             if(is_scale255)
512             {
513                 tmp1_high = scale255_U16_U16(tmp1_high);
514                 tmp1_low  = scale255_U16_U16(tmp1_low);
515             }
516             else
517             {
518                 const int16x8_t vn = vdupq_n_s16(-n);
519 
520                 if(is_sat)
521                 {
522                     tmp1_high = vqshlq_u16(tmp1_high, vn);
523                     tmp1_low  = vqshlq_u16(tmp1_low, vn);
524                 }
525                 else
526                 {
527                     tmp1_high = vshlq_u16(tmp1_high, vn);
528                     tmp1_low  = vshlq_u16(tmp1_low, vn);
529                 }
530             }
531             if(is_sat)
532             {
533                 vst1q_u8(output_ptr + x, vcombine_u8(vqmovn_u16(tmp1_low), vqmovn_u16(tmp1_high)));
534             }
535             else
536             {
537                 vst1q_u8(output_ptr + x, vcombine_u8(vmovn_u16(tmp1_low), vmovn_u16(tmp1_high)));
538             }
539         }
540 
541         // Compute left-over elements
542         for(; x < window_end_x; ++x)
543         {
544             uint16_t tmp = static_cast<uint16_t>(*(input1_ptr + x)) * static_cast<uint16_t>(*(input2_ptr + x));
545 
546             if(is_scale255)
547             {
548                 float tmp_f = static_cast<float>(tmp) * scale255_constant;
549                 tmp         = static_cast<uint16_t>(tmp_f + 0.5f);
550             }
551             else
552             {
553                 tmp >>= n;
554             }
555             if(is_sat && tmp > 255)
556             {
557                 tmp = 255;
558             }
559             *(output_ptr + x) = static_cast<uint8_t>(tmp);
560         }
561     },
562     input1, input2, output);
563 }
564 
565 template <bool is_scale255, bool is_sat>
566 inline int16x8_t mul_S16_S16_S16_n_loop(const int16x8_t &input1, const int16x8_t &input2, int n)
567 {
568     int32x4_t       tmp1_high = vmovl_s16(vget_high_s16(input1));
569     const int32x4_t tmp2_high = vmovl_s16(vget_high_s16(input2));
570     int32x4_t       tmp1_low  = vmovl_s16(vget_low_s16(input1));
571     const int32x4_t tmp2_low  = vmovl_s16(vget_low_s16(input2));
572 
573     tmp1_high = vmulq_s32(tmp1_high, tmp2_high);
574     tmp1_low  = vmulq_s32(tmp1_low, tmp2_low);
575 
576     if(is_scale255)
577     {
578         tmp1_high = scale255_S32_S32(tmp1_high);
579         tmp1_low  = scale255_S32_S32(tmp1_low);
580     }
581     else
582     {
583         // Right shift amount
584         const int32x4_t vn = vdupq_n_s32(-n);
585         // Left shift amount
586         const int32x4_t vnl = vdupq_n_s32(n);
587         // Calculate conversion bit
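        // For negative products an arithmetic shift right rounds towards -inf; adding
        // (sign ? (1 << n) - 1 : 0) before the shift turns this into rounding towards zero.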
588         const uint32x4_t tmp1_high_u  = vreinterpretq_u32_s32(tmp1_high);
589         const uint32x4_t tmp1_low_u   = vreinterpretq_u32_s32(tmp1_low);
590         const uint32x4_t sign_high    = vshrq_n_u32(tmp1_high_u, 31);
591         const uint32x4_t sign_low     = vshrq_n_u32(tmp1_low_u, 31);
592         const int32x4_t  sign_high_s  = vreinterpretq_s32_u32(sign_high);
593         const int32x4_t  sign_low_s   = vreinterpretq_s32_u32(sign_low);
594         const int32x4_t  convert_high = vsubq_s32(vshlq_s32(sign_high_s, vnl), sign_high_s);
595         const int32x4_t  convert_low  = vsubq_s32(vshlq_s32(sign_low_s, vnl), sign_low_s);
596         if(is_sat)
597         {
598             tmp1_high = vqshlq_s32(vaddq_s32(tmp1_high, convert_high), vn);
599             tmp1_low  = vqshlq_s32(vaddq_s32(tmp1_low, convert_low), vn);
600         }
601         else
602         {
603             tmp1_high = vshlq_s32(vaddq_s32(tmp1_high, convert_high), vn);
604             tmp1_low  = vshlq_s32(vaddq_s32(tmp1_low, convert_low), vn);
605         }
606     }
607 
608     if(is_sat)
609     {
610         return vcombine_s16(vqmovn_s32(tmp1_low), vqmovn_s32(tmp1_high));
611     }
612     else
613     {
614         return vcombine_s16(vmovn_s32(tmp1_low), vmovn_s32(tmp1_high));
615     }
616 }
617 
618 template <bool is_scale255, bool is_sat>
619 inline int16x8x2_t mul_S16_S16_S16_n_k(const int16x8x2_t &input1, const int16x8x2_t &input2, int n)
620 {
621     const int16x8x2_t result =
622     {
623         {
624             // First 8 elements
625             mul_S16_S16_S16_n_loop<is_scale255, is_sat>(input1.val[0], input2.val[0], n),
626             // Second 8 elements
627             mul_S16_S16_S16_n_loop<is_scale255, is_sat>(input1.val[1], input2.val[1], n)
628         }
629     };
630 
631     return result;
632 }
633 
634 template <bool is_scale255, bool is_sat>
635 void mul_S16_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
636 {
637     // Create input windows
638     Window win        = window;
639     Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
640     Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
641 
642     // Clear X Dimension on execution window as we handle manually
643     win.set(Window::DimX, Window::Dimension(0, 1, 1));
644     input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
645     input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
646 
647     Iterator input1(in1, input1_win);
648     Iterator input2(in2, input2_win);
649     Iterator output(out, win);
650 
651     const int  window_step_x  = 16;
652     const auto window_start_x = static_cast<int>(window.x().start());
653     const auto window_end_x   = static_cast<int>(window.x().end());
654 
655     execute_window_loop(win, [&](const Coordinates &)
656     {
657         const auto input1_ptr = reinterpret_cast<const int16_t *>(input1.ptr());
658         const auto input2_ptr = reinterpret_cast<const int16_t *>(input2.ptr());
659         const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());
660 
661         // Compute window_step_x elements per iteration
662         int x = window_start_x;
663         for(; x <= (window_end_x - window_step_x); x += window_step_x)
664         {
665             const int16x8x2_t ta1 =
666             {
667                 {
668                     vld1q_s16(input1_ptr + x),
669                     vld1q_s16(input1_ptr + x + 8),
670                 }
671             };
672             const int16x8x2_t ta2 =
673             {
674                 {
675                     vld1q_s16(input2_ptr + x),
676                     vld1q_s16(input2_ptr + x + 8),
677                 }
678             };
679             const int16x8x2_t result = mul_S16_S16_S16_n_k<is_scale255, is_sat>(ta1, ta2, n);
680 
681             vst1q_s16(output_ptr + x, result.val[0]);
682             vst1q_s16(output_ptr + x + 8, result.val[1]);
683         }
684 
685         // Compute left-over elements
686         for(; x < window_end_x; ++x)
687         {
688             int32_t tmp = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));
689 
690             if(is_scale255)
691             {
692                 float tmp_f = static_cast<float>(tmp) * scale255_constant;
693 
694                 tmp = static_cast<int32_t>(tmp_f + 0.5f);
695             }
696             else
697             {
698                 if(tmp >= 0)
699                 {
700                     tmp >>= n;
701                 }
702                 else
703                 {
704                     uint32_t mask = (1u << n) - 1;
705                     tmp           = (tmp + static_cast<int32_t>(mask)) >> n;
706                 }
707             }
708             if(is_sat)
709             {
710                 tmp = (tmp > SHRT_MAX) ? SHRT_MAX : ((tmp < SHRT_MIN) ? SHRT_MIN : tmp);
711             }
712             *(output_ptr + x) = static_cast<int16_t>(tmp);
713         }
714     },
715     input1, input2, output);
716 }
717 
718 template <bool   is_sat>
719 inline int32x4_t mul_S32_S32_S32_n_loop(const int32x4_t &input1, const int32x4_t &input2, int n)
720 {
721     const int32x2_t input1_1 = vget_low_s32(input1);
722     const int32x2_t input2_1 = vget_low_s32(input2);
723     const int32x2_t input1_2 = vget_high_s32(input1);
724     const int32x2_t input2_2 = vget_high_s32(input2);
725 
726     int64x2_t tmp_1 = vmull_s32(input1_1, input2_1);
727     int64x2_t tmp_2 = vmull_s32(input1_2, input2_2);
728 
729     // Apply scaling, conversion and rounding (round to zero)
730     // Right shift amount
731     const int64x2_t vn = vdupq_n_s64(-n);
732     // Left shift amount
733     const int64x2_t vnl = vdupq_n_s64(n);
734     // Calculate conversion bit
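    // Same sign-based adjustment as in mul_S16_S16_S16_n_loop, applied here to the 64-bit intermediate products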
735     const uint64x2_t tmp_1_u   = vreinterpretq_u64_s64(tmp_1);
736     const uint64x2_t sign_1    = vshrq_n_u64(tmp_1_u, 63);
737     const int64x2_t  sign_1_s  = vreinterpretq_s64_u64(sign_1);
738     const int64x2_t  convert_1 = vsubq_s64(vshlq_s64(sign_1_s, vnl), sign_1_s);
739 
740     const uint64x2_t tmp_2_u   = vreinterpretq_u64_s64(tmp_2);
741     const uint64x2_t sign_2    = vshrq_n_u64(tmp_2_u, 63);
742     const int64x2_t  sign_2_s  = vreinterpretq_s64_u64(sign_2);
743     const int64x2_t  convert_2 = vsubq_s64(vshlq_s64(sign_2_s, vnl), sign_2_s);
744     if(is_sat)
745     {
746         tmp_1 = vqshlq_s64(vaddq_s64(tmp_1, convert_1), vn);
747         tmp_2 = vqshlq_s64(vaddq_s64(tmp_2, convert_2), vn);
748         return vcombine_s32(vqmovn_s64(tmp_1), vqmovn_s64(tmp_2));
749     }
750     else
751     {
752         tmp_1 = vshlq_s64(vaddq_s64(tmp_1, convert_1), vn);
753         tmp_2 = vshlq_s64(vaddq_s64(tmp_2, convert_2), vn);
754         return vcombine_s32(vmovn_s64(tmp_1), vmovn_s64(tmp_2));
755     }
756 }
757 
758 template <bool     is_sat>
759 inline int32x4x2_t mul_S32_S32_S32_n_k(const int32x4x2_t &input1, const int32x4x2_t &input2, int n)
760 {
761     const int32x4x2_t result =
762     {
763         {
764             // First 4 elements
765             mul_S32_S32_S32_n_loop<is_sat>(input1.val[0], input2.val[0], n),
766             // Second 4 elements
767             mul_S32_S32_S32_n_loop<is_sat>(input1.val[1], input2.val[1], n)
768         }
769     };
770 
771     return result;
772 }
773 
774 template <bool is_sat>
775 void mul_S32_S32_S32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
776 {
777     // Create input windows
778     Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
779     Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
780 
781     // Clear X Dimension on execution window as we handle manually
782     Window win = window;
783     win.set(Window::DimX, Window::Dimension(0, 1, 1));
784 
785     const int  window_step_x         = 8;
786     const auto window_start_x        = static_cast<int>(window.x().start());
787     const auto window_end_x          = static_cast<int>(window.x().end());
788     const bool is_broadcast_across_x = in1->info()->tensor_shape().x() != in2->info()->tensor_shape().x();
789 
790     if(is_broadcast_across_x)
791     {
792         const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
793         Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
794         Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
795         const ITensor *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
796         const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
797 
798         // Clear X Dimension on execution window as we handle manually
799         non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));
800 
801         Iterator broadcast_input(broadcast_tensor, broadcast_win);
802         Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
803         Iterator output(out, win);
804 
805         execute_window_loop(win, [&](const Coordinates &)
806         {
807             const auto non_broadcast_input_ptr = reinterpret_cast<const int32_t *>(non_broadcast_input.ptr());
808             const auto output_ptr              = reinterpret_cast<int32_t *>(output.ptr());
809 
810             const int32_t broadcast_value     = *reinterpret_cast<const int32_t *>(broadcast_input.ptr());
811             const auto    broadcast_value_vec = vdupq_n_s32(broadcast_value);
812 
813             // Compute window_step_x elements per iteration
814             int x = window_start_x;
815             for(; x <= (window_end_x - window_step_x); x += window_step_x)
816             {
817                 const int32x4x2_t broadcast_v =
818                 {
819                     {
820                         broadcast_value_vec,
821                         broadcast_value_vec,
822                     }
823                 };
824                 const int32x4x2_t non_broadcast_v =
825                 {
826                     {
827                         vld1q_s32(non_broadcast_input_ptr + x),
828                         vld1q_s32(non_broadcast_input_ptr + x + 4),
829                     }
830                 };
831                 const int32x4x2_t result = mul_S32_S32_S32_n_k<is_sat>(broadcast_v, non_broadcast_v, n);
832 
833                 vst1q_s32(output_ptr + x, result.val[0]);
834                 vst1q_s32(output_ptr + x + 4, result.val[1]);
835             }
836 
837             // Compute left-over elements
838             for(; x < window_end_x; ++x)
839             {
840                 int64_t tmp = static_cast<int64_t>(broadcast_value) * static_cast<int64_t>(*(non_broadcast_input_ptr + x));
841 
842                 if(tmp >= 0)
843                 {
844                     tmp >>= n;
845                 }
846                 else
847                 {
848                     uint64_t mask = (1u << n) - 1;
849                     tmp           = (tmp + static_cast<int64_t>(mask)) >> n;
850                 }
851                 if(is_sat)
852                 {
853                     tmp = utility::clamp<int64_t, int32_t>(tmp);
854                 }
855                 *(output_ptr + x) = static_cast<int32_t>(tmp);
856             }
857         },
858         broadcast_input, non_broadcast_input, output);
859     }
860     else
861     {
862         // Clear X Dimension on execution window as we handle manually
863         input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
864         input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
865 
866         Iterator input1(in1, input1_win);
867         Iterator input2(in2, input2_win);
868         Iterator output(out, win);
869 
870         execute_window_loop(win, [&](const Coordinates &)
871         {
872             const auto input1_ptr = reinterpret_cast<const int32_t *>(input1.ptr());
873             const auto input2_ptr = reinterpret_cast<const int32_t *>(input2.ptr());
874             const auto output_ptr = reinterpret_cast<int32_t *>(output.ptr());
875 
876             // Compute window_step_x elements per iteration
877             int x = window_start_x;
878             for(; x <= (window_end_x - window_step_x); x += window_step_x)
879             {
880                 const int32x4x2_t ta1 =
881                 {
882                     {
883                         vld1q_s32(input1_ptr + x),
884                         vld1q_s32(input1_ptr + x + 4),
885                     }
886                 };
887                 const int32x4x2_t ta2 =
888                 {
889                     {
890                         vld1q_s32(input2_ptr + x),
891                         vld1q_s32(input2_ptr + x + 4),
892                     }
893                 };
894                 const int32x4x2_t result = mul_S32_S32_S32_n_k<is_sat>(ta1, ta2, n);
895 
896                 vst1q_s32(output_ptr + x, result.val[0]);
897                 vst1q_s32(output_ptr + x + 4, result.val[1]);
898             }
899 
900             // Compute left-over elements
901             for(; x < window_end_x; ++x)
902             {
903                 int64_t tmp = static_cast<int64_t>(*(input1_ptr + x)) * static_cast<int64_t>(*(input2_ptr + x));
904 
905                 if(tmp >= 0)
906                 {
907                     tmp >>= n;
908                 }
909                 else
910                 {
911                     uint64_t mask = (1u << n) - 1;
912                     tmp           = (tmp + static_cast<int64_t>(mask)) >> n;
913                 }
914                 if(is_sat)
915                 {
916                     tmp = utility::clamp<int64_t, int32_t>(tmp);
917                 }
918                 *(output_ptr + x) = static_cast<int32_t>(tmp);
919             }
920         },
921         input1, input2, output);
922     }
923 }
924 
925 void mul_F32_F32_F32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, float scale)
926 {
927     // Create input windows
928     Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
929     Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
930 
931     // Clear X Dimension on execution window as we handle manually
932     Window win = window;
933     win.set(Window::DimX, Window::Dimension(0, 1, 1));
934 
935     constexpr int window_step_x         = 16 / sizeof(float);
936     const auto    window_start_x        = static_cast<int>(window.x().start());
937     const auto    window_end_x          = static_cast<int>(window.x().end());
938     const bool    is_broadcast_across_x = in1->info()->tensor_shape().x() != in2->info()->tensor_shape().x();
939 
940     using ExactTagType = typename wrapper::traits::neon_vector<float, window_step_x>::tag_type;
941 
942     if(is_broadcast_across_x)
943     {
944         const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
945         Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
946         Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
947         const ITensor *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
948         const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
949 
950         // Clear X Dimension on execution window as we handle manually
951         non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));
952 
953         Iterator broadcast_input(broadcast_tensor, broadcast_win);
954         Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
955         Iterator output(out, win);
956 
957         execute_window_loop(win, [&](const Coordinates &)
958         {
959             const auto non_broadcast_input_ptr = reinterpret_cast<const float *>(non_broadcast_input.ptr());
960             const auto output_ptr              = reinterpret_cast<float *>(output.ptr());
961 
962             const float broadcast_value     = *reinterpret_cast<const float *>(broadcast_input.ptr());
963             const auto  broadcast_value_vec = wrapper::vdup_n(broadcast_value, ExactTagType{});
964             const auto  scale_vec           = wrapper::vdup_n(scale, ExactTagType{});
965 
966             // Compute window_step_x elements per iteration
967             int x = window_start_x;
968             for(; x <= (window_end_x - window_step_x); x += window_step_x)
969             {
970                 const auto non_broadcast_v = wrapper::vloadq(non_broadcast_input_ptr + x);
971                 auto       res             = wrapper::vmul(wrapper::vmul(broadcast_value_vec, non_broadcast_v), scale_vec);
972                 wrapper::vstore(output_ptr + x, res);
973             }
974 
975             // Compute left-over elements
976             for(; x < window_end_x; ++x)
977             {
978                 const auto non_broadcast_v = *(non_broadcast_input_ptr + x);
979                 *(output_ptr + x)          = broadcast_value * non_broadcast_v * scale;
980             }
981         },
982         broadcast_input, non_broadcast_input, output);
983     }
984     else
985     {
986         // Clear X Dimension on execution window as we handle manually
987         input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
988         input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
989 
990         Iterator input1(in1, input1_win);
991         Iterator input2(in2, input2_win);
992         Iterator output(out, win);
993 
994         execute_window_loop(win, [&](const Coordinates &)
995         {
996             const auto input1_ptr = reinterpret_cast<const float *>(input1.ptr());
997             const auto input2_ptr = reinterpret_cast<const float *>(input2.ptr());
998             const auto output_ptr = reinterpret_cast<float *>(output.ptr());
999 
1000             // Compute window_step_x elements per iteration
1001             int x = window_start_x;
1002             for(; x <= (window_end_x - window_step_x); x += window_step_x)
1003             {
1004                 const auto ta1       = wrapper::vloadq(input1_ptr + x);
1005                 const auto ta2       = wrapper::vloadq(input2_ptr + x);
1006                 const auto scale_vec = wrapper::vdup_n(scale, ExactTagType{});
1007                 const auto res       = wrapper::vmul(wrapper::vmul(ta1, ta2), scale_vec);
1008                 wrapper::vstore(output_ptr + x, res);
1009             }
1010 
1011             // Compute left-over elements
1012             for(; x < window_end_x; ++x)
1013             {
1014                 const auto ta1    = *(input1_ptr + x);
1015                 const auto ta2    = *(input2_ptr + x);
1016                 *(output_ptr + x) = ta1 * ta2 * scale;
1017             }
1018         },
1019         input1, input2, output);
1020     }
1021 }
1022 
1023 void c_mul_F32_F32_F32_n(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
1024 {
1025     // Create input windows
1026     Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
1027     Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
1028 
1029     // Clear X Dimension on execution window as we handle manually
1030     Window win = window;
1031     win.set(Window::DimX, Window::Dimension(0, 1, 1));
1032 
1033     constexpr int window_step_x         = 8 / sizeof(float);
1034     const auto    window_start_x        = static_cast<int>(window.x().start());
1035     const auto    window_end_x          = static_cast<int>(window.x().end());
1036     const bool    is_broadcast_across_x = in1->info()->tensor_shape().x() != in2->info()->tensor_shape().x();
1037 
1038     using ExactTagType = typename wrapper::traits::neon_vector<float, 2>::tag_type;
1039 
1040     if(is_broadcast_across_x)
1041     {
1042         const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
1043         Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
1044         Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
1045         const ITensor *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
1046         const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
1047 
1048         // Clear X Dimension on execution window as we handle manually
1049         non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1050 
1051         Iterator broadcast_input(broadcast_tensor, broadcast_win);
1052         Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
1053         Iterator output(out, win);
1054 
1055         execute_window_loop(win, [&](const Coordinates &)
1056         {
1057             const auto non_broadcast_input_ptr = reinterpret_cast<const float *>(non_broadcast_input.ptr());
1058             const auto output_ptr              = reinterpret_cast<float *>(output.ptr());
1059 
1060             const float broadcast_value = *reinterpret_cast<const float *>(broadcast_input.ptr());
1061 
1062             // Compute window_step_x elements per iteration
1063             int x = window_start_x;
1064             for(; x <= (window_end_x - window_step_x); x += window_step_x)
1065             {
1066                 const auto  a = wrapper::vloadq(non_broadcast_input_ptr + 2 * x);
1067                 float32x4_t b = vdupq_n_f32(broadcast_value);
1068 
1069                 const float32x4_t mask  = { -1.0f, 1.0f, -1.0f, 1.0f };
1070                 const float32x2_t tmp00 = wrapper::vdup_n(wrapper::vgetlane(a, 0), ExactTagType{});
1071                 const float32x2_t tmp01 = wrapper::vdup_n(wrapper::vgetlane(a, 1), ExactTagType{});
1072                 const float32x2_t tmp10 = wrapper::vdup_n(wrapper::vgetlane(a, 2), ExactTagType{});
1073                 const float32x2_t tmp11 = wrapper::vdup_n(wrapper::vgetlane(a, 3), ExactTagType{});
1074 
1075                 const float32x4_t tmp0 = wrapper::vcombine(tmp00, tmp10);
1076                 const float32x4_t tmp1 = wrapper::vcombine(tmp01, tmp11);
1077 
1078                 float32x4_t res = wrapper::vmul(tmp0, b);
1079                 b               = wrapper::vmul(b, mask);
1080 
1081                 res = wrapper::vmla(res, tmp1, b);
1082                 wrapper::vstore(output_ptr + 2 * x, res);
1083             }
1084 
1085             // Compute left-over elements
1086             for(; x < window_end_x; ++x)
1087             {
1088                 const auto non_broadcast_value0 = *(non_broadcast_input_ptr + 2 * x);
1089                 const auto non_broadcast_value1 = *(non_broadcast_input_ptr + 2 * x + 1);
1090                 auto       res1                 = broadcast_value * (non_broadcast_value0 - non_broadcast_value1);
1091                 auto       res2                 = broadcast_value * (non_broadcast_value1 + non_broadcast_value0);
1092                 *(output_ptr + 2 * x)           = res1;
1093                 *(output_ptr + 2 * x + 1)       = res2;
1094             }
1095         },
1096         broadcast_input, non_broadcast_input, output);
1097     }
1098     else
1099     {
1100         // Clear X Dimension on execution window as we handle manually
1101         input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1102         input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1103 
1104         Iterator input1(in1, input1_win);
1105         Iterator input2(in2, input2_win);
1106         Iterator output(out, win);
1107 
1108         execute_window_loop(win, [&](const Coordinates &)
1109         {
1110             const auto input1_ptr = reinterpret_cast<const float *>(input1.ptr());
1111             const auto input2_ptr = reinterpret_cast<const float *>(input2.ptr());
1112             const auto output_ptr = reinterpret_cast<float *>(output.ptr());
1113 
1114             // Compute window_step_x elements per iteration
1115             int x = window_start_x;
1116             for(; x <= (window_end_x - window_step_x); x += window_step_x)
1117             {
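                // Complex multiply on interleaved (real, imag) pairs: (a0 + i*a1) * (b0 + i*b1)
                // = (a0*b0 - a1*b1) + i*(a0*b1 + a1*b0). tmp0/tmp1 broadcast the real/imaginary parts
                // of 'a'; vrev64 swaps each (real, imag) pair of 'b' and the mask negates the terms to be subtracted.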
1118                 const float32x4_t a = wrapper::vloadq(input1_ptr + 2 * x);
1119                 float32x4_t       b = wrapper::vloadq(input2_ptr + 2 * x);
1120 
1121                 const float32x4_t mask  = { -1.0f, 1.0f, -1.0f, 1.0f };
1122                 const float32x2_t tmp00 = wrapper::vdup_n(wrapper::vgetlane(a, 0), ExactTagType{});
1123                 const float32x2_t tmp01 = wrapper::vdup_n(wrapper::vgetlane(a, 1), ExactTagType{});
1124                 const float32x2_t tmp10 = wrapper::vdup_n(wrapper::vgetlane(a, 2), ExactTagType{});
1125                 const float32x2_t tmp11 = wrapper::vdup_n(wrapper::vgetlane(a, 3), ExactTagType{});
1126 
1127                 const float32x4_t tmp0 = wrapper::vcombine(tmp00, tmp10);
1128                 const float32x4_t tmp1 = wrapper::vcombine(tmp01, tmp11);
1129 
1130                 float32x4_t res = wrapper::vmul(tmp0, b);
1131 
1132                 b = wrapper::vrev64(b);
1133                 b = wrapper::vmul(b, mask);
1134 
1135                 res = wrapper::vmla(res, tmp1, b);
1136                 wrapper::vstore(output_ptr + 2 * x, res);
1137             }
1138 
1139             // Compute left-over elements
1140             for(; x < window_end_x; ++x)
1141             {
1142                 const auto a0             = *(input1_ptr + 2 * x);
1143                 const auto a1             = *(input1_ptr + 2 * x + 1);
1144                 const auto b0             = *(input2_ptr + 2 * x);
1145                 const auto b1             = *(input2_ptr + 2 * x + 1);
1146                 auto       res1           = a0 * b0 - a1 * b1;
1147                 auto       res2           = a0 * b1 + a1 * b0;
1148                 *(output_ptr + 2 * x)     = res1;
1149                 *(output_ptr + 2 * x + 1) = res2;
1150             }
1151         },
1152         input1, input2, output);
1153     }
1154 }
1155 
1156 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1157 void mul_F16_F16_F16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, float scale)
1158 {
1159     // Create input windows
1160     Window win        = window;
1161     Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
1162     Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
1163 
1164     // Clear X Dimension on execution window as we handle manually
1165     win.set(Window::DimX, Window::Dimension(0, 1, 1));
1166     input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1167     input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1168 
1169     Iterator input1(in1, input1_win);
1170     Iterator input2(in2, input2_win);
1171     Iterator output(out, win);
1172 
1173     const int  window_step_x  = 16;
1174     const auto window_start_x = static_cast<int>(window.x().start());
1175     const auto window_end_x   = static_cast<int>(window.x().end());
1176 
1177     execute_window_loop(win, [&](const Coordinates &)
1178     {
1179         const auto input1_ptr = reinterpret_cast<const float16_t *>(input1.ptr());
1180         const auto input2_ptr = reinterpret_cast<const float16_t *>(input2.ptr());
1181         const auto output_ptr = reinterpret_cast<float16_t *>(output.ptr());
1182 
1183         // Compute window_step_x elements per iteration
1184         int x = window_start_x;
1185         for(; x <= (window_end_x - window_step_x); x += window_step_x)
1186         {
1187             const float16x8x2_t ta1 =
1188             {
1189                 {
1190                     vld1q_f16(input1_ptr + x),
1191                     vld1q_f16(input1_ptr + x + 8),
1192                 }
1193             };
1194             const float16x8x2_t ta2 =
1195             {
1196                 {
1197                     vld1q_f16(input2_ptr + x),
1198                     vld1q_f16(input2_ptr + x + 8),
1199                 }
1200             };
1201             const float16x8_t   scale_vec = vdupq_n_f16(scale);
1202             const float16x8x2_t result =
1203             {
1204                 {
1205                     vmulq_f16(vmulq_f16(ta1.val[0], ta2.val[0]), scale_vec),
1206                     vmulq_f16(vmulq_f16(ta1.val[1], ta2.val[1]), scale_vec),
1207                 }
1208             };
1209             vst1q_f16(output_ptr + x, result.val[0]);
1210             vst1q_f16(output_ptr + x + 8, result.val[1]);
1211         }
1212 
1213         // Compute left-over elements
1214         for(; x < window_end_x; ++x)
1215         {
1216             const auto ta1    = *(input1_ptr + x);
1217             const auto ta2    = *(input2_ptr + x);
1218             *(output_ptr + x) = ta1 * ta2 * scale;
1219         }
1220     },
1221     input1, input2, output);
1222 }
1223 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
1224 
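     // U8 * U8 -> S16: both inputs are widened to U16 and multiplied; the product is then
     // either rescaled by 1/255 with rounding (is_scale255) or shifted right by n, and
     // clamped to SHRT_MAX when saturation is requested, before being stored as signed 16-bit.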
1225 template <bool is_scale255, bool is_sat>
1226 void mul_U8_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
1227 {
1228     // Create input windows
1229     Window win        = window;
1230     Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
1231     Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
1232 
1233     // Clear X Dimension on execution window as we handle manually
1234     win.set(Window::DimX, Window::Dimension(0, 1, 1));
1235     input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1236     input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1237 
1238     Iterator input1(in1, input1_win);
1239     Iterator input2(in2, input2_win);
1240     Iterator output(out, win);
1241 
1242     const int  window_step_x  = 16 / sizeof(uint8_t);
1243     const auto window_start_x = static_cast<int>(window.x().start());
1244     const auto window_end_x   = static_cast<int>(window.x().end());
1245 
1246     execute_window_loop(win, [&](const Coordinates &)
1247     {
1248         const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr());
1249         const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
1250         const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());
1251 
1252         // Compute window_step_x elements per iteration
1253         int x = window_start_x;
1254         for(; x <= (window_end_x - window_step_x); x += window_step_x)
1255         {
1256             const uint8x16_t bv = wrapper::vloadq(input2_ptr + x);
1257             const uint8x16_t av = wrapper::vloadq(input1_ptr + x);
1258 
1259             uint16x8_t tmp_low  = vmovl_u8(vget_low_u8(av));
1260             uint16x8_t tmp_high = vmovl_u8(vget_high_u8(av));
1261             tmp_low             = vmulq_u16(tmp_low, vmovl_u8(vget_low_u8(bv)));
1262             tmp_high            = vmulq_u16(tmp_high, vmovl_u8(vget_high_u8(bv)));
1263 
1264             if(is_scale255)
1265             {
1266                 tmp_low  = scale255_U16_U16(tmp_low);
1267                 tmp_high = scale255_U16_U16(tmp_high);
1268             }
1269             else
1270             {
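                     // Vector shifts by a runtime amount exist only as left shifts;
                     // shifting by -n performs the required right shift by n.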
1271                 const int16x8_t vn = vdupq_n_s16(-n);
1272 
1273                 if(is_sat)
1274                 {
1275                     tmp_low  = vqshlq_u16(tmp_low, vn);
1276                     tmp_high = vqshlq_u16(tmp_high, vn);
1277                 }
1278                 else
1279                 {
1280                     tmp_low  = vshlq_u16(tmp_low, vn);
1281                     tmp_high = vshlq_u16(tmp_high, vn);
1282                 }
1283             }
1284 
1285             if(is_sat)
1286             {
1287                 static const uint16x8_t max = vdupq_n_u16(SHRT_MAX);
1288 
1289                 tmp_low  = vminq_u16(tmp_low, max);
1290                 tmp_high = vminq_u16(tmp_high, max);
1291             }
1292 
1293             vst1q_s16(output_ptr + x, vreinterpretq_s16_u16(tmp_low));
1294             vst1q_s16(output_ptr + x + 8, vreinterpretq_s16_u16(tmp_high));
1295         }
1296 
1297         // Compute left-over elements
1298         for(; x < window_end_x; ++x)
1299         {
1300             int32_t tmp = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));
1301 
1302             if(is_scale255)
1303             {
1304                 float tmp_f = static_cast<float>(tmp) * scale255_constant;
1305                 tmp         = static_cast<int32_t>(tmp_f + 0.5f);
1306             }
1307             else
1308             {
1309                 tmp >>= n;
1310             }
1311 
1312             if(is_sat)
1313             {
1314                 tmp = (tmp > SHRT_MAX) ? SHRT_MAX : tmp;
1315             }
1316 
1317             *(output_ptr + x) = static_cast<int16_t>(tmp);
1318         }
1319     },
1320     input1, input2, output);
1321 }
1322 
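     // S16 * U8 -> S16: the U8 operand is widened to S16 and the work is delegated to the
     // S16 * S16 helper mul_S16_S16_S16_n_k.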
1323 template <bool is_scale255, bool is_sat>
1324 void mul_S16_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
1325 {
1326     // Create input windows
1327     Window win        = window;
1328     Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
1329     Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
1330 
1331     // Clear X Dimension on execution window as we handle manually
1332     win.set(Window::DimX, Window::Dimension(0, 1, 1));
1333     input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1334     input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1335 
1336     Iterator input1(in1, input1_win);
1337     Iterator input2(in2, input2_win);
1338     Iterator output(out, win);
1339 
1340     const int  window_step_x  = 16;
1341     const auto window_start_x = static_cast<int>(window.x().start());
1342     const auto window_end_x   = static_cast<int>(window.x().end());
1343 
1344     execute_window_loop(win, [&](const Coordinates &)
1345     {
1346         const auto input1_ptr = reinterpret_cast<const int16_t *>(input1.ptr());
1347         const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
1348         const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());
1349 
1350         // Compute window_step_x elements per iteration
1351         int x = window_start_x;
1352         for(; x <= (window_end_x - window_step_x); x += window_step_x)
1353         {
1354             const int16x8x2_t ta1 =
1355             {
1356                 {
1357                     vld1q_s16(input1_ptr + x),
1358                     vld1q_s16(input1_ptr + x + 8),
1359                 }
1360             };
1361             const uint8x8x2_t ta2u =
1362             {
1363                 {
1364                     vld1_u8(input2_ptr + x),
1365                     vld1_u8(input2_ptr + x + 8),
1366                 }
1367             };
1368             const int16x8x2_t ta2 =
1369             {
1370                 {
1371                     vreinterpretq_s16_u16(vmovl_u8(ta2u.val[0])),
1372                     vreinterpretq_s16_u16(vmovl_u8(ta2u.val[1]))
1373                 }
1374             };
1375 
1376             const int16x8x2_t result = mul_S16_S16_S16_n_k<is_scale255, is_sat>(ta1, ta2, n);
1377 
1378             vst1q_s16(output_ptr + x, result.val[0]);
1379             vst1q_s16(output_ptr + x + 8, result.val[1]);
1380         }
1381 
1382         // Compute left-over elements
1383         for(; x < window_end_x; ++x)
1384         {
1385             int32_t tmp = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));
1386 
1387             if(is_scale255)
1388             {
1389                 float tmp_f = static_cast<float>(tmp) * scale255_constant;
1390 
1391                 tmp = static_cast<int32_t>(tmp_f + 0.5f);
1392             }
1393             else
1394             {
1395                 if(tmp >= 0)
1396                 {
1397                     tmp >>= n;
1398                 }
1399                 else
1400                 {
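                     // For negative products, adding (1 << n) - 1 before the arithmetic shift
                     // makes the division by 2^n round toward zero instead of toward negative infinity.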
1401                     uint32_t mask = (1u << n) - 1;
1402                     tmp           = (tmp + static_cast<int32_t>(mask)) >> n;
1403                 }
1404             }
1405             if(is_sat)
1406             {
1407                 tmp = (tmp > SHRT_MAX) ? SHRT_MAX : ((tmp < SHRT_MIN) ? SHRT_MIN : tmp);
1408             }
1409             *(output_ptr + x) = static_cast<int16_t>(tmp);
1410         }
1411     },
1412     input1, input2, output);
1413 }
1414 
1415 template <bool is_scale255, bool is_sat>
1416 void mul_U8_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n)
1417 {
1418     // Simply swap the two input buffers
1419     mul_S16_U8_S16<is_scale255, is_sat>(in2, in1, out, window, n);
1420 }
1421 } // namespace
1422 
1423 NEPixelWiseMultiplicationKernel::NEPixelWiseMultiplicationKernel()
1424     : _func_float(nullptr), _func_int(nullptr), _func_quantized(nullptr), _scale{ 0 }, _scale_exponent{ 0 }
1425 {
1426 }
1427 
1428 void NEPixelWiseMultiplicationKernel::configure(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
1429 {
1430     ARM_COMPUTE_UNUSED(rounding_policy);
1431     ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
1432 
1433     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input1, input2, output, scale, overflow_policy, rounding_policy));
1434 
1435     const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(*input1, *input2);
1436     const TensorShape &out_shape    = broadcast_pair.first;
1437     const ValidRegion &valid_region = broadcast_pair.second;
1438 
1439     // Auto initialize output if not initialized
1440     set_shape_if_empty(*output, out_shape);
1441 
1442     _scale          = scale;
1443     _scale_exponent = 0;
1444     _func_quantized = nullptr;
1445     _func_int       = nullptr;
1446     _func_float     = nullptr;
1447 
1448     bool is_scale_255 = false;
1449     // Check and validate scaling factor
1450     if(std::abs(scale - scale255_constant) < 0.00001f)
1451     {
1452         is_scale_255 = true;
1453     }
1454     else
1455     {
1456         int exponent = 0;
1457 
1458         std::frexp(scale, &exponent);
1459 
1460         // Store the positive exponent. We know that we compute 1/2^n
1461         // Additionally, subtract 1 to compensate for frexp() normalising the mantissa to 0.5
1462         _scale_exponent = std::abs(exponent - 1);
1463     }
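         // Example: scale = 1/8 -> frexp() yields mantissa 0.5 and exponent -2, so
         // _scale_exponent = |-2 - 1| = 3 and the integer kernels shift the product right by 3.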
1464 
1465     const DataType dt_input1 = input1->data_type();
1466     const DataType dt_input2 = input2->data_type();
1467     const DataType dt_output = output->data_type();
1468     const bool     is_sat    = (overflow_policy == ConvertPolicy::SATURATE);
1469 
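         // Pick the kernel specialisation from the (input1, input2, output) data-type triplet;
         // the template parameters encode the 1/255 scaling and saturation choices.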
1470     switch(dt_input1)
1471     {
1472         case DataType::QASYMM8:
1473             if(dt_input2 == DataType::QASYMM8 && dt_output == DataType::QASYMM8)
1474             {
1475                 _func_quantized = &mul_saturate_quantized_8<uint8_t>;
1476             }
1477             break;
1478         case DataType::QASYMM8_SIGNED:
1479             if(dt_input2 == DataType::QASYMM8_SIGNED)
1480             {
1481                 _func_quantized = &mul_saturate_quantized_8<int8_t>;
1483             }
1484             break;
1485         case DataType::QSYMM16:
1486             if(dt_input2 == DataType::QSYMM16 && dt_output == DataType::QSYMM16)
1487             {
1488                 _func_quantized = &mul_saturate_QSYMM16_QSYMM16_QSYMM16;
1489             }
1490             else if(dt_input2 == DataType::QSYMM16 && dt_output == DataType::S32)
1491             {
1492                 _func_int = &mul_QSYMM16_QSYMM16_S32;
1493             }
1494             break;
1495         case DataType::S16:
1496             if(DataType::U8 == dt_input2 && DataType::S16 == dt_output)
1497             {
1498                 if(is_scale_255)
1499                 {
1500                     _func_int = is_sat ? &mul_S16_U8_S16<true, true> : &mul_S16_U8_S16<true, false>;
1501                 }
1502                 else
1503                 {
1504                     _func_int = is_sat ? &mul_S16_U8_S16<false, true> : &mul_S16_U8_S16<false, false>;
1505                 }
1506             }
1507             if(DataType::S16 == dt_input2 && DataType::S16 == dt_output)
1508             {
1509                 if(is_scale_255)
1510                 {
1511                     _func_int = is_sat ? &mul_S16_S16_S16<true, true> : &mul_S16_S16_S16<true, false>;
1512                 }
1513                 else
1514                 {
1515                     _func_int = is_sat ? &mul_S16_S16_S16<false, true> : &mul_S16_S16_S16<false, false>;
1516                 }
1517             }
1518             break;
1519         case DataType::S32:
1520             if(DataType::S32 == dt_input2 && DataType::S32 == dt_output)
1521             {
1522                 _func_int = is_sat ? &mul_S32_S32_S32<true> : &mul_S32_S32_S32<false>;
1523             }
1524             break;
1525         case DataType::U8:
1526             if(DataType::U8 == dt_input2 && DataType::U8 == dt_output)
1527             {
1528                 if(is_scale_255)
1529                 {
1530                     _func_int = is_sat ? &mul_U8_U8_U8<true, true> : &mul_U8_U8_U8<true, false>;
1531                 }
1532                 else
1533                 {
1534                     _func_int = is_sat ? &mul_U8_U8_U8<false, true> : &mul_U8_U8_U8<false, false>;
1535                 }
1536             }
1537             else if(DataType::U8 == dt_input2 && DataType::S16 == dt_output)
1538             {
1539                 if(is_scale_255)
1540                 {
1541                     _func_int = is_sat ? &mul_U8_U8_S16<true, true> : &mul_U8_U8_S16<true, false>;
1542                 }
1543                 else
1544                 {
1545                     _func_int = is_sat ? &mul_U8_U8_S16<false, true> : &mul_U8_U8_S16<false, false>;
1546                 }
1547             }
1548             else if(DataType::S16 == dt_input2 && DataType::S16 == dt_output)
1549             {
1550                 if(is_scale_255)
1551                 {
1552                     _func_int = is_sat ? &mul_U8_S16_S16<true, true> : &mul_U8_S16_S16<true, false>;
1553                 }
1554                 else
1555                 {
1556                     _func_int = is_sat ? &mul_U8_S16_S16<false, true> : &mul_U8_S16_S16<false, false>;
1557                 }
1558             }
1559             break;
1560 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1561         case DataType::F16:
1562             _func_float = &mul_F16_F16_F16;
1563             break;
1564 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
1565         case DataType::F32:
1566             _func_float = &mul_F32_F32_F32;
1567             break;
1568         default:
1569             ARM_COMPUTE_ERROR("Unsupported combination of data types");
1570     }
1571 
1572     // Configure kernel window
1573     Coordinates coord;
1574     coord.set_num_dimensions(output->num_dimensions());
1575     output->set_valid_region(valid_region);
1576     Window win = calculate_max_window(valid_region, Steps());
1577 
1578     INEKernel::configure(win);
1579 }
1580 
1581 Status NEPixelWiseMultiplicationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale, ConvertPolicy overflow_policy,
1582                                                  RoundingPolicy rounding_policy)
1583 {
1584     ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
1585     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input1, input2, output, scale, overflow_policy, rounding_policy));
1586 
1587     return Status{};
1588 }
1589 
1590 void NEPixelWiseMultiplicationKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
1591 {
1592     ARM_COMPUTE_UNUSED(info);
1593     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1594     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1595 
1596     auto input1 = tensors.get_const_tensor(TensorType::ACL_SRC_0);
1597     auto input2 = tensors.get_const_tensor(TensorType::ACL_SRC_1);
1598     auto output = tensors.get_tensor(TensorType::ACL_DST);
1599 
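         // Dispatch to the path selected at configure() time: quantized and float kernels take
         // the raw scale, integer kernels take the precomputed shift exponent.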
1600     if(_func_quantized != nullptr)
1601     {
1602         (*_func_quantized)(input1, input2, output, window, _scale);
1603     }
1604     else if(_func_int != nullptr)
1605     {
1606         (*_func_int)(input1, input2, output, window, _scale_exponent);
1607     }
1608     else
1609     {
1610         ARM_COMPUTE_ERROR_ON(_func_float == nullptr);
1611         (*_func_float)(input1, input2, output, window, _scale);
1612     }
1613 }
1614 namespace
1615 {
1616 Status validate_arguments_complex(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
1617 {
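         // Complex tensors are represented as 2-channel F32, i.e. interleaved real/imaginary parts.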
1618     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 2, DataType::F32);
1619     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 2, DataType::F32);
1620 
1621     const TensorShape &out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape());
1622 
1623     ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
1624 
1625     // Validate in case of configured output
1626     if(output->total_size() > 0)
1627     {
1628         ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 2, DataType::F32);
1629         ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output");
1630     }
1631 
1632     return Status{};
1633 }
1634 } // namespace
1635 
1636 void NEComplexPixelWiseMultiplicationKernel::configure(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output)
1637 {
1638     ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
1639     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_complex(input1, input2, output));
1640 
1641     const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(*input1, *input2);
1642     const TensorShape &out_shape    = broadcast_pair.first;
1643     const ValidRegion &valid_region = broadcast_pair.second;
1644 
1645     // Auto initialize output if not initialized
1646     const TensorInfo out_info(out_shape, input1->num_channels(), input1->data_type());
1647     auto_init_if_empty(*output, out_info);
1648 
1649     // Configure kernel window
1650     Coordinates coord;
1651     coord.set_num_dimensions(output->num_dimensions());
1652     output->set_valid_region(valid_region);
1653     Window win = calculate_max_window(valid_region, Steps());
1654 
1655     INEKernel::configure(win);
1656 }
1657 
1658 Status NEComplexPixelWiseMultiplicationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
1659 {
1660     ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
1661     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_complex(input1, input2, output));
1662 
1663     return Status{};
1664 }
1665 
1666 void NEComplexPixelWiseMultiplicationKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
1667 {
1668     ARM_COMPUTE_UNUSED(info);
1669     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1670     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1671 
1672     auto input1 = tensors.get_const_tensor(TensorType::ACL_SRC_0);
1673     auto input2 = tensors.get_const_tensor(TensorType::ACL_SRC_1);
1674     auto output = tensors.get_tensor(TensorType::ACL_DST);
1675 
1676     c_mul_F32_F32_F32_n(input1, input2, output, window);
1677 }
1678 } // namespace arm_compute
1679