// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <stddef.h>
#include <stdint.h>

#include <arm_neon.h>

#include <xnnpack/zip.h>


// Interleaves m byte streams of n bytes each into n groups of m bytes.
//
// input holds the m streams back to back: stream j occupies bytes
// [j * n, (j + 1) * n). output receives n * m bytes such that
//   output[i * m + j] == input[j * n + i]   for 0 <= i < n, 0 <= j < m.
//
// Parameters:
//   n      - number of bytes in each input stream.
//   m      - number of input streams (bytes per output group).
//   input  - the m * n source bytes.
//   output - destination for the m * n interleaved bytes.
//
// Preconditions implied by the code below (none are checked here):
//   - n != 0 and m != 0: the scalar path uses do/while loops that decrement
//     their counters before testing them.
//   - m >= 4 whenever n >= 8: the vector path always consumes four streams
//     and stores four bytes per output group, and (m - 4) would wrap.
void xnn_x8_zip_xm_ukernel__neon(
    size_t n,
    size_t m,
    const uint8_t* input,
    uint8_t* output)
{
  const uint8_t* w = input;
  // After finishing a group of four streams, w has advanced by n and points
  // at the first stream of the next group; 3 * n more makes it the fourth.
  const size_t input_increment = n * 3;
  // Deliberately wraps: the stores of one group move output forward by
  // m * n in total, so adding this leaves output 4 bytes past where the
  // group started.
  const size_t output_increment = 4 - m * n;
  // Clamp targets for the final group when m is not a multiple of 4: the last
  // group is slid back so it ends on the last stream / last output column,
  // re-writing some already-written bytes with the same values.
  const uint8_t* last_input = w + n * (m - 1);
  uint8_t* last_output = (uint8_t*) ((uintptr_t) output + (m - 4));

  if (n >= 8) {
    for (size_t i = 0; i < m; i += 4) {
      size_t k = n;
      // Integer arithmetic is used because w may step past the buffer before
      // it is clamped.
      w = (const uint8_t*) ((uintptr_t) w + input_increment);
      if (w >= last_input) {
        w = last_input;
      }
      // x, y, z, w are four consecutive streams, w being the last of them.
      const uint8_t* z = (const uint8_t*) ((uintptr_t) w - n);
      const uint8_t* y = (const uint8_t*) ((uintptr_t) z - n);
      const uint8_t* x = (const uint8_t*) ((uintptr_t) y - n);
      while (k >= 8) {
        const uint8x8_t vx = vld1_u8(x); x += 8;
        const uint8x8_t vy = vld1_u8(y); y += 8;
        const uint8x8_t vz = vld1_u8(z); z += 8;
        const uint8x8_t vw = vld1_u8(w); w += 8;

        // Two rounds of zipping (bytes, then byte pairs) leave one byte from
        // each of x, y, z, w in every 32-bit lane.
        const uint8x8x2_t vxy = vzip_u8(vx, vy);
        const uint8x8x2_t vzw = vzip_u8(vz, vw);
        const uint16x4x2_t vxyzw_lo = vzip_u16(vreinterpret_u16_u8(vxy.val[0]), vreinterpret_u16_u8(vzw.val[0]));
        const uint16x4x2_t vxyzw_hi = vzip_u16(vreinterpret_u16_u8(vxy.val[1]), vreinterpret_u16_u8(vzw.val[1]));

        // One 4-byte store per output group; consecutive groups are m bytes
        // apart. output can run past the buffer after the final store, hence
        // the integer arithmetic.
        vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_lo.val[0]), 0);
        output = (uint8_t*) ((uintptr_t) output + m);

        vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_lo.val[0]), 1);
        output = (uint8_t*) ((uintptr_t) output + m);

        vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_lo.val[1]), 0);
        output = (uint8_t*) ((uintptr_t) output + m);

        vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_lo.val[1]), 1);
        output = (uint8_t*) ((uintptr_t) output + m);

        vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_hi.val[0]), 0);
        output = (uint8_t*) ((uintptr_t) output + m);

        vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_hi.val[0]), 1);
        output = (uint8_t*) ((uintptr_t) output + m);

        vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_hi.val[1]), 0);
        output = (uint8_t*) ((uintptr_t) output + m);

        vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_hi.val[1]), 1);
        output = (uint8_t*) ((uintptr_t) output + m);

        k -= 8;
      }
      if (k != 0) {
        // 1..7 bytes remain per stream. Step back so that an 8-byte load ends
        // exactly at the end of each stream; this stays inside the stream
        // because n >= 8 on this path.
        const size_t rewind = 8 - k;
        x -= rewind;
        y -= rewind;
        z -= rewind;
        w -= rewind;
        // A negative count makes vshl_u64 shift right, discarding the
        // `rewind` already-processed bytes so the k new bytes occupy the
        // lowest lanes (assuming little-endian lane order).
        const int64x1_t vshift = vmov_n_s64(-(int64_t) (8 * rewind));

        const uint64x1_t vx = vshl_u64(vreinterpret_u64_u8(vld1_u8(x)), vshift);
        const uint64x1_t vy = vshl_u64(vreinterpret_u64_u8(vld1_u8(y)), vshift);
        const uint64x1_t vz = vshl_u64(vreinterpret_u64_u8(vld1_u8(z)), vshift);
        // w must end up at the start of the next stream for the next group.
        const uint64x1_t vw = vshl_u64(vreinterpret_u64_u8(vld1_u8(w)), vshift); w += 8;
        const uint8x8x2_t vxy = vzip_u8(vreinterpret_u8_u64(vx), vreinterpret_u8_u64(vy));
        const uint8x8x2_t vzw = vzip_u8(vreinterpret_u8_u64(vz), vreinterpret_u8_u64(vw));
        const uint16x4x2_t vxyzw_lo = vzip_u16(vreinterpret_u16_u8(vxy.val[0]), vreinterpret_u16_u8(vzw.val[0]));
        const uint16x4x2_t vxyzw_hi = vzip_u16(vreinterpret_u16_u8(vxy.val[1]), vreinterpret_u16_u8(vzw.val[1]));

        uint32x2_t vxyzw0 = vreinterpret_u32_u16(vxyzw_lo.val[0]);
        uint32x2_t vxyzw1 = vreinterpret_u32_u16(vxyzw_lo.val[1]);
        uint32x2_t vxyzw2 = vreinterpret_u32_u16(vxyzw_hi.val[0]);
        uint32x2_t vxyzw3 = vreinterpret_u32_u16(vxyzw_hi.val[1]);

        // Emit k groups by the bits of k (4, then 2, then 1), sliding the
        // not-yet-stored vectors down after each step.
        if (k & 4) {
          vst1_lane_u32((void*) output, vxyzw0, 0);
          output = (uint8_t*) ((uintptr_t) output + m);

          vst1_lane_u32((void*) output, vxyzw0, 1);
          output = (uint8_t*) ((uintptr_t) output + m);

          vst1_lane_u32((void*) output, vxyzw1, 0);
          output = (uint8_t*) ((uintptr_t) output + m);

          vst1_lane_u32((void*) output, vxyzw1, 1);
          output = (uint8_t*) ((uintptr_t) output + m);

          vxyzw0 = vxyzw2;
          vxyzw1 = vxyzw3;
        }

        if (k & 2) {
          vst1_lane_u32((void*) output, vxyzw0, 0);
          output = (uint8_t*) ((uintptr_t) output + m);

          vst1_lane_u32((void*) output, vxyzw0, 1);
          output = (uint8_t*) ((uintptr_t) output + m);

          vxyzw0 = vxyzw1;
        }
        if (k & 1) {
          vst1_lane_u32((void*) output, vxyzw0, 0);
          output = (uint8_t*) ((uintptr_t) output + m);
        }
      }
      // Return to the top group and move four columns to the right, clamped
      // so the final group of four columns ends at column m - 1.
      output = (uint8_t*) ((uintptr_t) output + output_increment);
      if (output > last_output) {
        output = last_output;
      }
    }
  } else {
    // Streams shorter than 8 bytes: plain scalar gather, one output byte at a
    // time, reading position i of every stream in turn.
    const uint8_t* i = input;
    uint8_t* o = output;
    size_t k = n;
    do {
      size_t l = m;
      const uint8_t* ii = i++;
      do {
        *o++ = *ii;
        ii += n;
      } while (--l != 0);
    } while (--k != 0);
  }
}
145