• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/Window.h"
#include "src/core/NEON/NEAsymm.h"
#include "src/core/NEON/NEMath.h"
#include "src/core/NEON/wrapper/wrapper.h"

#include <algorithm>
#include <arm_neon.h>
31 namespace
32 {
// Clamp every lane of @p values to the inclusive range [@p lower, @p upper].
inline float32x4_t clamp_v4f32(float32x4_t values, float32x4_t lower, float32x4_t upper)
{
    const float32x4_t floored = vmaxq_f32(values, lower);
    return vminq_f32(floored, upper);
}
// Convert two f32x4 vectors to u32, narrow each to u16x4 and combine them into
// a single u16x8 (fb1 fills the low half, fb2 the high half).
// NOTE: vcvtq_u32_f32 rounds toward zero (truncation), per the NEON intrinsic
// definition.
inline uint16x8_t fuse_words_f32(float32x4_t fb1, float32x4_t fb2)
{
    const uint16x4_t low_half  = vmovn_u32(vcvtq_u32_f32(fb1));
    const uint16x4_t high_half = vmovn_u32(vcvtq_u32_f32(fb2));
    return vcombine_u16(low_half, high_half);
}
// Narrow two u16x8 vectors to u8x8 (keeping the low byte of each lane) and
// combine them into one u8x16 (sb1 fills the low half, sb2 the high half).
inline uint8x16_t fuse_shorts_u16(uint16x8_t sb1, uint16x8_t sb2)
{
    const uint8x8_t narrowed_low  = vmovn_u16(sb1);
    const uint8x8_t narrowed_high = vmovn_u16(sb2);
    return vcombine_u8(narrowed_low, narrowed_high);
}
45 } // namespace
46 
47 namespace arm_compute
48 {
49 namespace cpu
50 {
// Per-row mean/stddev normalization of a QASYMM8 tensor:
//   out = quantize((in - mean(row)) / sqrt(var(row) + epsilon))
// The mean/variance are computed from the raw u8 codes, and the result is
// re-quantized with the output tensor's uniform quantization info.
//
// @param[in]  input   Source tensor (QASYMM8). Row length is dimension 0.
// @param[out] output  Destination tensor (QASYMM8), same shape as @p input.
// @param[in]  epsilon Small constant added to the variance to avoid division by zero.
// @param[in]  window  Region to process; its X extent covers a full row.
void neon_qasymm8_meanstddevnorm(ITensor *input, ITensor *output, float epsilon, const Window &window)
{
    // Collapse the X dimension so each window-loop iteration handles one full
    // row: mean and variance are per-row statistics.
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int window_step_x  = 16; // 16 u8 lanes per vector iteration
    const int window_start_x = static_cast<int>(window.x().start());
    const int window_end_x   = static_cast<int>(window.x().end());

    const UniformQuantizationInfo qi_out        = output->info()->quantization_info().uniform();
    const float                   output_scale  = qi_out.scale;
    const int                     output_offset = qi_out.offset;

    Iterator input_itr(input, win);
    Iterator output_itr(output, win);

    const float       output_inv_scale = 1.0f / output_scale;
    const float32x4_t quant_max_vec    = vdupq_n_f32(255.0f);
    const float32x4_t quant_min_vec    = vdupq_n_f32(0.0f);

    // Row length is loop-invariant: hoist it out of the per-row lambda.
    const float row_len = static_cast<float>(input->info()->dimension(0));

    execute_window_loop(
        win, [&](const Coordinates &)
    {
        int  x       = window_start_x;
        auto in_ptr  = reinterpret_cast<const uint8_t *>(input_itr.ptr());
        auto out_ptr = reinterpret_cast<uint8_t *>(output_itr.ptr());

        // Pass 1: accumulate sum and sum-of-squares of the raw u8 codes.
        // NOTE(review): sum_sq accumulates in u32; rows of ~66k elements at
        // value 255 would overflow it — confirm expected row sizes.
        uint32x4_t sum_vec    = vdupq_n_u32(0);
        uint32x4_t sum_sq_vec = vdupq_n_u32(0);

        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const uint8x16_t data         = vld1q_u8(in_ptr + x);
            sum_vec                       = vaddq_u32(sum_vec, vpaddlq_u16(vpaddlq_u8(data)));
            const uint16x8_t squares_low  = vmull_u8(vget_low_u8(data), vget_low_u8(data));
            const uint16x8_t squares_high = vmull_u8(vget_high_u8(data), vget_high_u8(data));
            sum_sq_vec                    = vaddq_u32(sum_sq_vec, vaddq_u32(vpaddlq_u16(squares_low), vpaddlq_u16(squares_high)));
        }

#ifdef __aarch64__
        // Horizontal reduction via pairwise adds (vpaddq is A64-only).
        sum_vec         = vpaddq_u32(sum_vec, sum_vec);
        sum_vec         = vpaddq_u32(sum_vec, sum_vec);
        uint32_t sum    = vgetq_lane_u32(sum_vec, 0);
        sum_sq_vec      = vpaddq_u32(sum_sq_vec, sum_sq_vec);
        sum_sq_vec      = vpaddq_u32(sum_sq_vec, sum_sq_vec);
        uint32_t sum_sq = vgetq_lane_u32(sum_sq_vec, 0);
#elif __arm__ // #ifdef __aarch64__
        // A32: reduce lane by lane.
        uint32_t sum =  vgetq_lane_u32(sum_vec, 0) +
                        vgetq_lane_u32(sum_vec, 1) +
                        vgetq_lane_u32(sum_vec, 2) +
                        vgetq_lane_u32(sum_vec, 3);

        uint32_t sum_sq =   vgetq_lane_u32(sum_sq_vec, 0) +
                            vgetq_lane_u32(sum_sq_vec, 1) +
                            vgetq_lane_u32(sum_sq_vec, 2) +
                            vgetq_lane_u32(sum_sq_vec, 3);
#endif        // #ifdef __aarch64__
        // Scalar tail of the accumulation pass.
        for(; x < window_end_x; ++x)
        {
            auto data = static_cast<uint32_t>(*(in_ptr + x));
            sum += data;
            sum_sq += (data * data);
        }

        const float       mean      = static_cast<float>(sum) / row_len;
        const float       var       = (static_cast<float>(sum_sq) / row_len) - (mean * mean);
        const float       stdev_inv = 1.0f / sqrtf(var + epsilon);
        // Fused requantization: out = (in - mean) * stdev_inv / out_scale + out_offset.
        const float32x4_t v_scale   = vdupq_n_f32(stdev_inv * output_inv_scale);
        const float32x4_t v_offset  = vdupq_n_f32(-mean * stdev_inv * output_inv_scale + output_offset);

        // Pass 2 (vectorized): widen u8 -> f32, normalize, clamp to [0, 255],
        // narrow back down to u8.
        for(x = window_start_x; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const uint8x16_t data = vld1q_u8(in_ptr + x);
            float32x4_t      db1  = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(data)))));
            float32x4_t      db2  = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_low_u8(data)))));
            float32x4_t      db3  = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_high_u8(data)))));
            float32x4_t      db4  = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_high_u8(data)))));
            db1                   = clamp_v4f32(vaddq_f32(vmulq_f32(db1, v_scale), v_offset), quant_min_vec, quant_max_vec);
            db2                   = clamp_v4f32(vaddq_f32(vmulq_f32(db2, v_scale), v_offset), quant_min_vec, quant_max_vec);
            db3                   = clamp_v4f32(vaddq_f32(vmulq_f32(db3, v_scale), v_offset), quant_min_vec, quant_max_vec);
            db4                   = clamp_v4f32(vaddq_f32(vmulq_f32(db4, v_scale), v_offset), quant_min_vec, quant_max_vec);
            const uint8x16_t out  = fuse_shorts_u16(fuse_words_f32(db1, db2), fuse_words_f32(db3, db4));
            vst1q_u8(out_ptr + x, out);
        }

        // Scalar tail of the normalization pass. Clamp to [0, 255] exactly as
        // the vectorized path does: converting an out-of-range float to
        // uint8_t is undefined behaviour and previously the tail skipped the
        // clamp.
        for(; x < window_end_x; ++x)
        {
            const float data = static_cast<float32_t>(*(in_ptr + x));
            float       res  = data * (stdev_inv * output_inv_scale) + (-mean * stdev_inv * output_inv_scale + output_offset);
            res              = std::min(std::max(res, 0.0f), 255.0f);
            *(out_ptr + x)   = static_cast<uint8_t>(res);
        }
    },
    input_itr, output_itr);
}
144 } // namespace cpu
145 } // namespace arm_compute
146