• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // output_neon.h: optimized NEON specializations of the templates in output.h.
16 
17 #ifndef GEMMLOWP_INTERNAL_OUTPUT_NEON_H_
18 #define GEMMLOWP_INTERNAL_OUTPUT_NEON_H_
19 
20 #include "output.h"
21 
22 #include <arm_neon.h>
23 
24 namespace gemmlowp {
25 
26 template <>
27 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
28                                  RegBufferInt32<4>> {
29   typedef RegBufferInt32<4> InputType;
30   typedef RegBufferUint8<4> OutputType;
31 
32   typedef OutputStageSaturatingCastToUint8 OutputStage;
33 
34   OutputStageEvalBufferImpl(const OutputStage&) {}
35 
36   OutputType Eval(InputType input) const {
37     OutputType output;
38     int16x4_t res_16 = vqmovn_s32(input.reg[0]);
39     uint8x8_t res_8 = vqmovun_s16(vcombine_s16(res_16, res_16));
40     output.reg[0] = vget_lane_u32(vreinterpret_u32_u8(res_8), 0);
41     return output;
42   }
43 };
44 
45 template <>
46 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
47                                  RegBufferInt32<8>> {
48   typedef RegBufferInt32<8> InputType;
49   typedef RegBufferUint8<8> OutputType;
50 
51   typedef OutputStageSaturatingCastToUint8 OutputStage;
52 
53   OutputStageEvalBufferImpl(const OutputStage&) {}
54 
55   OutputType Eval(InputType input) const {
56     OutputType output;
57     int16x8_t res_16 =
58         vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
59     output.reg[0] = vqmovun_s16(res_16);
60     return output;
61   }
62 };
63 
64 template <>
65 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
66                                  RegBufferInt32<16>> {
67   typedef RegBufferInt32<16> InputType;
68   typedef RegBufferUint8<16> OutputType;
69 
70   typedef OutputStageSaturatingCastToUint8 OutputStage;
71 
72   OutputStageEvalBufferImpl(const OutputStage&) {}
73 
74   OutputType Eval(InputType input) const {
75     OutputType output;
76     int16x8_t res_16_0 =
77         vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
78     int16x8_t res_16_1 =
79         vcombine_s16(vqmovn_s32(input.reg[2]), vqmovn_s32(input.reg[3]));
80     output.reg[0] = vqmovun_s16(res_16_0);
81     output.reg[1] = vqmovun_s16(res_16_1);
82     return output;
83   }
84 };
85 
86 template <>
87 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
88                                  RegBufferInt32<32>> {
89   typedef RegBufferInt32<32> InputType;
90   typedef RegBufferUint8<32> OutputType;
91 
92   typedef OutputStageSaturatingCastToUint8 OutputStage;
93 
94   OutputStageEvalBufferImpl(const OutputStage&) {}
95 
96   OutputType Eval(InputType input) const {
97     OutputType output;
98     int16x8_t res_16[4];
99     for (int i = 0; i < 4; i++) {
100       res_16[i] = vcombine_s16(vqmovn_s32(input.reg[2 * i]),
101                                vqmovn_s32(input.reg[2 * i + 1]));
102     }
103     for (int i = 0; i < 4; i++) {
104       output.reg[i] = vqmovun_s16(res_16[i]);
105     }
106     return output;
107   }
108 };
109 
110 template <>
111 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
112                                  RegBufferInt32<4>> {
113   typedef RegBufferInt32<4> InputType;
114   typedef RegBufferInt16<4> OutputType;
115 
116   typedef OutputStageSaturatingCastToInt16 OutputStage;
117 
118   OutputStageEvalBufferImpl(const OutputStage&) {}
119 
120   OutputType Eval(InputType input) const {
121     OutputType output;
122     output.reg[0] = vqmovn_s32(input.reg[0]);
123     return output;
124   }
125 };
126 
127 template <>
128 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
129                                  RegBufferInt32<8>> {
130   typedef RegBufferInt32<8> InputType;
131   typedef RegBufferInt16<8> OutputType;
132 
133   typedef OutputStageSaturatingCastToInt16 OutputStage;
134 
135   OutputStageEvalBufferImpl(const OutputStage&) {}
136 
137   OutputType Eval(InputType input) const {
138     OutputType output;
139     output.reg[0] =
140         vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
141     return output;
142   }
143 };
144 
145 template <>
146 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
147                                  RegBufferInt32<16>> {
148   typedef RegBufferInt32<16> InputType;
149   typedef RegBufferInt16<16> OutputType;
150 
151   typedef OutputStageSaturatingCastToInt16 OutputStage;
152 
153   OutputStageEvalBufferImpl(const OutputStage&) {}
154 
155   OutputType Eval(InputType input) const {
156     OutputType output;
157     output.reg[0] =
158         vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
159     output.reg[1] =
160         vcombine_s16(vqmovn_s32(input.reg[2]), vqmovn_s32(input.reg[3]));
161     return output;
162   }
163 };
164 
165 template <>
166 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
167                                  RegBufferInt32<32>> {
168   typedef RegBufferInt32<32> InputType;
169   typedef RegBufferInt16<32> OutputType;
170 
171   typedef OutputStageSaturatingCastToInt16 OutputStage;
172 
173   OutputStageEvalBufferImpl(const OutputStage&) {}
174 
175   OutputType Eval(InputType input) const {
176     OutputType output;
177     output.reg[0] =
178         vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
179     output.reg[1] =
180         vcombine_s16(vqmovn_s32(input.reg[2]), vqmovn_s32(input.reg[3]));
181     output.reg[2] =
182         vcombine_s16(vqmovn_s32(input.reg[4]), vqmovn_s32(input.reg[5]));
183     output.reg[3] =
184         vcombine_s16(vqmovn_s32(input.reg[6]), vqmovn_s32(input.reg[7]));
185     return output;
186   }
187 };
188 
189 template <typename DstType>
190 struct StoreFinalOutputImpl<RegBlockInt32<8, 1>, DstType> {
191   static void Run(const RegBlockInt32<8, 1>& src, DstType* dst, int row,
192                   int col) {
193     if (DstType::kOrder == MapOrder::ColMajor) {
194       StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
195       StoreInt32x4(dst->data(row + 4, col), src.buf.reg[1]);
196     } else {
197       vst1q_lane_s32(dst->data(row + 0, col), src.buf.reg[0], 0);
198       vst1q_lane_s32(dst->data(row + 1, col), src.buf.reg[0], 1);
199       vst1q_lane_s32(dst->data(row + 2, col), src.buf.reg[0], 2);
200       vst1q_lane_s32(dst->data(row + 3, col), src.buf.reg[0], 3);
201       vst1q_lane_s32(dst->data(row + 4, col), src.buf.reg[1], 0);
202       vst1q_lane_s32(dst->data(row + 5, col), src.buf.reg[1], 1);
203       vst1q_lane_s32(dst->data(row + 6, col), src.buf.reg[1], 2);
204       vst1q_lane_s32(dst->data(row + 7, col), src.buf.reg[1], 3);
205     }
206   }
207 };
208 
209 template <typename DstType>
210 struct StoreFinalOutputImpl<RegBlockInt16<4, 1>, DstType> {
211   static void Run(const RegBlockInt16<4, 1>& src, DstType* dst, int row,
212                   int col) {
213     if (DstType::kOrder == MapOrder::ColMajor) {
214       StoreInt16x4(dst->data(row, col), src.buf.reg[0]);
215     } else {
216       vst1_lane_s16(dst->data(row + 0, col), src.buf.reg[0], 0);
217       vst1_lane_s16(dst->data(row + 1, col), src.buf.reg[0], 1);
218       vst1_lane_s16(dst->data(row + 2, col), src.buf.reg[0], 2);
219       vst1_lane_s16(dst->data(row + 3, col), src.buf.reg[0], 3);
220     }
221   }
222 };
223 
224 template <typename DstType>
225 struct StoreFinalOutputImpl<RegBlockInt16<8, 1>, DstType> {
226   static void Run(const RegBlockInt16<8, 1>& src, DstType* dst, int row,
227                   int col) {
228     if (DstType::kOrder == MapOrder::ColMajor) {
229       StoreInt16x8(dst->data(row, col), src.buf.reg[0]);
230     } else {
231       vst1q_lane_s16(dst->data(row + 0, col), src.buf.reg[0], 0);
232       vst1q_lane_s16(dst->data(row + 1, col), src.buf.reg[0], 1);
233       vst1q_lane_s16(dst->data(row + 2, col), src.buf.reg[0], 2);
234       vst1q_lane_s16(dst->data(row + 3, col), src.buf.reg[0], 3);
235       vst1q_lane_s16(dst->data(row + 4, col), src.buf.reg[0], 4);
236       vst1q_lane_s16(dst->data(row + 5, col), src.buf.reg[0], 5);
237       vst1q_lane_s16(dst->data(row + 6, col), src.buf.reg[0], 6);
238       vst1q_lane_s16(dst->data(row + 7, col), src.buf.reg[0], 7);
239     }
240   }
241 };
242 
243 inline RegBlockInt32<4, 4> Transpose(const RegBlockInt32<4, 4>& src) {
244   const int32x4x2_t t0 = vtrnq_s32(src.buf.reg[0], src.buf.reg[1]);
245   const int32x4x2_t t1 = vtrnq_s32(src.buf.reg[2], src.buf.reg[3]);
246   RegBlockInt32<4, 4> result;
247   result.buf.reg[0] =
248       vcombine_s32(vget_low_s32(t0.val[0]), vget_low_s32(t1.val[0]));
249   result.buf.reg[1] =
250       vcombine_s32(vget_low_s32(t0.val[1]), vget_low_s32(t1.val[1]));
251   result.buf.reg[2] =
252       vcombine_s32(vget_high_s32(t0.val[0]), vget_high_s32(t1.val[0]));
253   result.buf.reg[3] =
254       vcombine_s32(vget_high_s32(t0.val[1]), vget_high_s32(t1.val[1]));
255   return result;
256 }
257 
258 template <typename DstType>
259 struct StoreFinalOutputImpl<RegBlockInt32<4, 4>, DstType> {
260   static void Run(const RegBlockInt32<4, 4>& src, DstType* dst, int row,
261                   int col) {
262     const auto& block =
263         DstType::kOrder == MapOrder::ColMajor ? src : Transpose(src);
264     std::int32_t* dst_ptr = dst->data(row, col);
265     int stride = dst->stride();
266     for (int i = 0; i < 4; i++) {
267       vst1q_s32(dst_ptr + i * stride, block.buf.reg[i]);
268     }
269   }
270 };
271 
272 template <typename DstType>
273 struct StoreFinalOutputImpl<RegBlockInt16<4, 4>, DstType> {
274   static void Run(const RegBlockInt16<4, 4>& src, DstType* dst, int row,
275                   int col) {
276     if (DstType::kOrder == MapOrder::ColMajor) {
277       vst1_s16(dst->data(row, col + 0), vget_low_s16(src.buf.reg[0]));
278       vst1_s16(dst->data(row, col + 1), vget_high_s16(src.buf.reg[0]));
279       vst1_s16(dst->data(row, col + 2), vget_low_s16(src.buf.reg[1]));
280       vst1_s16(dst->data(row, col + 3), vget_high_s16(src.buf.reg[1]));
281     } else {
282       const int16x4x2_t t0 =
283           vtrn_s16(vget_low_s16(src.buf.reg[0]), vget_high_s16(src.buf.reg[0]));
284       const int16x4x2_t t1 =
285           vtrn_s16(vget_low_s16(src.buf.reg[1]), vget_high_s16(src.buf.reg[1]));
286       const int32x4x2_t t =
287           vtrnq_s32(vreinterpretq_s32_s16(vcombine_s16(t0.val[0], t0.val[1])),
288                     vreinterpretq_s32_s16(vcombine_s16(t1.val[0], t1.val[1])));
289       vst1_s16(dst->data(row + 0, col),
290                vget_low_s16(vreinterpretq_s16_s32(t.val[0])));
291       vst1_s16(dst->data(row + 1, col),
292                vget_high_s16(vreinterpretq_s16_s32(t.val[0])));
293       vst1_s16(dst->data(row + 2, col),
294                vget_low_s16(vreinterpretq_s16_s32(t.val[1])));
295       vst1_s16(dst->data(row + 3, col),
296                vget_high_s16(vreinterpretq_s16_s32(t.val[1])));
297     }
298   }
299 };
300 
301 template <typename DstType>
302 struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> {
303   static void Run(const RegBlockInt32<8, 4>& src, DstType* dst, int row,
304                   int col) {
305     std::int32_t* dst_ptr = dst->data(row, col);
306     if (DstType::kOrder == MapOrder::ColMajor) {
307       int col_stride = dst->cols_stride();
308       for (int i = 0; i < 4; i++) {
309         vst1q_s32(dst_ptr + i * col_stride + 0, src.buf.reg[2 * i + 0]);
310         vst1q_s32(dst_ptr + i * col_stride + 4, src.buf.reg[2 * i + 1]);
311       }
312     } else {
313       int row_stride = dst->rows_stride();
314       RegBlockInt32<4, 4> top;
315       top.buf.reg[0] = src.buf.reg[0];
316       top.buf.reg[1] = src.buf.reg[2];
317       top.buf.reg[2] = src.buf.reg[4];
318       top.buf.reg[3] = src.buf.reg[6];
319       const auto transpose_top = Transpose(top);
320       for (int i = 0; i < 4; i++) {
321         vst1q_s32(dst_ptr + i * row_stride, transpose_top.buf.reg[i]);
322       }
323       RegBlockInt32<4, 4> bottom;
324       bottom.buf.reg[0] = src.buf.reg[1];
325       bottom.buf.reg[1] = src.buf.reg[3];
326       bottom.buf.reg[2] = src.buf.reg[5];
327       bottom.buf.reg[3] = src.buf.reg[7];
328       const auto transpose_bottom = Transpose(bottom);
329       for (int i = 0; i < 4; i++) {
330         vst1q_s32(dst_ptr + (i + 4) * row_stride, transpose_bottom.buf.reg[i]);
331       }
332     }
333   }
334 };
335 
336 template <typename DstType>
337 struct StoreFinalOutputImpl<RegBlockInt16<8, 4>, DstType> {
338   static void Run(const RegBlockInt16<8, 4>& src, DstType* dst, int row,
339                   int col) {
340     if (DstType::kOrder == MapOrder::ColMajor) {
341       vst1q_s16(dst->data(row, col + 0), src.buf.reg[0]);
342       vst1q_s16(dst->data(row, col + 1), src.buf.reg[1]);
343       vst1q_s16(dst->data(row, col + 2), src.buf.reg[2]);
344       vst1q_s16(dst->data(row, col + 3), src.buf.reg[3]);
345     } else {
346       const int16x8x2_t t0 = vtrnq_s16(src.buf.reg[0], src.buf.reg[1]);
347       const int16x8x2_t t1 = vtrnq_s16(src.buf.reg[2], src.buf.reg[3]);
348       const int32x4x2_t u0 = vtrnq_s32(vreinterpretq_s32_s16(t0.val[0]),
349                                        vreinterpretq_s32_s16(t1.val[0]));
350       const int32x4x2_t u1 = vtrnq_s32(vreinterpretq_s32_s16(t0.val[1]),
351                                        vreinterpretq_s32_s16(t1.val[1]));
352       vst1_s16(dst->data(row + 0, col),
353                vget_low_s16(vreinterpretq_s16_s32(u0.val[0])));
354       vst1_s16(dst->data(row + 1, col),
355                vget_low_s16(vreinterpretq_s16_s32(u1.val[0])));
356       vst1_s16(dst->data(row + 2, col),
357                vget_low_s16(vreinterpretq_s16_s32(u0.val[1])));
358       vst1_s16(dst->data(row + 3, col),
359                vget_low_s16(vreinterpretq_s16_s32(u1.val[1])));
360       vst1_s16(dst->data(row + 4, col),
361                vget_high_s16(vreinterpretq_s16_s32(u0.val[0])));
362       vst1_s16(dst->data(row + 5, col),
363                vget_high_s16(vreinterpretq_s16_s32(u1.val[0])));
364       vst1_s16(dst->data(row + 6, col),
365                vget_high_s16(vreinterpretq_s16_s32(u0.val[1])));
366       vst1_s16(dst->data(row + 7, col),
367                vget_high_s16(vreinterpretq_s16_s32(u1.val[1])));
368     }
369   }
370 };
371 
372 template <typename DstType>
373 struct StoreFinalOutputImpl<RegBlockInt32<8, 8>, DstType> {
374   static void Run(const RegBlockInt32<8, 8>& src, DstType* dst, int row,
375                   int col) {
376     std::int32_t* dst_ptr = dst->data(row, col);
377     if (DstType::kOrder == MapOrder::ColMajor) {
378       int col_stride = dst->cols_stride();
379       for (int i = 0; i < 8; i++) {
380         vst1q_s32(dst_ptr + i * col_stride, src.buf.reg[2 * i]);
381         vst1q_s32(dst_ptr + i * col_stride + 4, src.buf.reg[2 * i + 1]);
382       }
383     } else {
384       int row_stride = dst->rows_stride();
385       RegBlockInt32<4, 4> top_left;
386       top_left.buf.reg[0] = src.buf.reg[0];
387       top_left.buf.reg[1] = src.buf.reg[2];
388       top_left.buf.reg[2] = src.buf.reg[4];
389       top_left.buf.reg[3] = src.buf.reg[6];
390       const auto transpose_top_left = Transpose(top_left);
391       for (int i = 0; i < 4; i++) {
392         vst1q_s32(dst_ptr + i * row_stride, transpose_top_left.buf.reg[i]);
393       }
394       RegBlockInt32<4, 4> bottom_left;
395       bottom_left.buf.reg[0] = src.buf.reg[1];
396       bottom_left.buf.reg[1] = src.buf.reg[3];
397       bottom_left.buf.reg[2] = src.buf.reg[5];
398       bottom_left.buf.reg[3] = src.buf.reg[7];
399       const auto transpose_bottom_left = Transpose(bottom_left);
400       for (int i = 0; i < 4; i++) {
401         vst1q_s32(dst_ptr + (i + 4) * row_stride,
402                   transpose_bottom_left.buf.reg[i]);
403       }
404       RegBlockInt32<4, 4> top_right;
405       top_right.buf.reg[0] = src.buf.reg[8];
406       top_right.buf.reg[1] = src.buf.reg[10];
407       top_right.buf.reg[2] = src.buf.reg[12];
408       top_right.buf.reg[3] = src.buf.reg[14];
409       const auto transpose_top_right = Transpose(top_right);
410       for (int i = 0; i < 4; i++) {
411         vst1q_s32(dst_ptr + i * row_stride + 4, transpose_top_right.buf.reg[i]);
412       }
413       RegBlockInt32<4, 4> bottom_right;
414       bottom_right.buf.reg[0] = src.buf.reg[9];
415       bottom_right.buf.reg[1] = src.buf.reg[11];
416       bottom_right.buf.reg[2] = src.buf.reg[13];
417       bottom_right.buf.reg[3] = src.buf.reg[15];
418       const auto transpose_bottom_right = Transpose(bottom_right);
419       for (int i = 0; i < 4; i++) {
420         vst1q_s32(dst_ptr + (i + 4) * row_stride + 4,
421                   transpose_bottom_right.buf.reg[i]);
422       }
423     }
424   }
425 };
426 
427 template <typename DstType>
428 struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> {
429   static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row,
430                   int col) {
431     std::int32_t* dst_ptr = dst->data(row, col);
432     if (DstType::kOrder == MapOrder::ColMajor) {
433       vst1q_s32(dst_ptr, src.buf.reg[0]);
434     } else {
435       int row_stride = dst->rows_stride();
436       vst1q_lane_s32(dst_ptr + 0 * row_stride, src.buf.reg[0], 0);
437       vst1q_lane_s32(dst_ptr + 1 * row_stride, src.buf.reg[0], 1);
438       vst1q_lane_s32(dst_ptr + 2 * row_stride, src.buf.reg[0], 2);
439       vst1q_lane_s32(dst_ptr + 3 * row_stride, src.buf.reg[0], 3);
440     }
441   }
442 };
443 
444 template <typename DstType>
445 struct StoreFinalOutputImpl<RegBlockInt32<1, 4>, DstType> {
446   static void Run(const RegBlockInt32<1, 4>& src, DstType* dst, int row,
447                   int col) {
448     std::int32_t* dst_ptr = dst->data(row, col);
449     if (DstType::kOrder == MapOrder::RowMajor) {
450       vst1q_s32(dst_ptr, src.buf.reg[0]);
451     } else {
452       int col_stride = dst->cols_stride();
453       vst1q_lane_s32(dst_ptr + 0 * col_stride, src.buf.reg[0], 0);
454       vst1q_lane_s32(dst_ptr + 1 * col_stride, src.buf.reg[0], 1);
455       vst1q_lane_s32(dst_ptr + 2 * col_stride, src.buf.reg[0], 2);
456       vst1q_lane_s32(dst_ptr + 3 * col_stride, src.buf.reg[0], 3);
457     }
458   }
459 };
460 
461 template <typename DstType>
462 struct StoreFinalOutputImpl<RegBlockInt16<1, 4>, DstType> {
463   static void Run(const RegBlockInt16<1, 4>& src, DstType* dst, int row,
464                   int col) {
465     std::int16_t* dst_ptr = dst->data(row, col);
466     if (DstType::kOrder == MapOrder::RowMajor) {
467       vst1_s16(dst_ptr, src.buf.reg[0]);
468     } else {
469       int col_stride = dst->cols_stride();
470       vst1_lane_s16(dst_ptr + 0 * col_stride, src.buf.reg[0], 0);
471       vst1_lane_s16(dst_ptr + 1 * col_stride, src.buf.reg[0], 1);
472       vst1_lane_s16(dst_ptr + 2 * col_stride, src.buf.reg[0], 2);
473       vst1_lane_s16(dst_ptr + 3 * col_stride, src.buf.reg[0], 3);
474     }
475   }
476 };
477 
478 template <typename DstType>
479 struct StoreFinalOutputImpl<RegBlockUint8<4, 1>, DstType> {
480   static void Run(const RegBlockUint8<4, 1>& src, DstType* dst, int row,
481                   int col) {
482     const std::uint32_t src_reg = src.buf.reg[0];
483     for (int i = 0; i < 4; i++) {
484       *dst->data(row + i, col) = (src_reg >> (8 * i));
485     }
486   }
487 };
488 
489 template <typename DstType>
490 struct StoreFinalOutputImpl<RegBlockUint8<1, 4>, DstType> {
491   static void Run(const RegBlockUint8<1, 4>& src, DstType* dst, int row,
492                   int col) {
493     for (int i = 0; i < 4; i++) {
494       *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i));
495     }
496   }
497 };
498 
499 template <typename DstType>
500 struct StoreFinalOutputImpl<RegBlockUint8<8, 1>, DstType> {
501   static void Run(const RegBlockUint8<8, 1>& src, DstType* dst, int row,
502                   int col) {
503     std::uint8_t* dst_ptr = dst->data(row, col);
504     if (DstType::kOrder == MapOrder::ColMajor) {
505       vst1_u8(dst_ptr, src.buf.reg[0]);
506     } else {
507       const int row_stride = dst->rows_stride();
508       vst1_lane_u8(dst_ptr + 0 * row_stride, src.buf.reg[0], 0);
509       vst1_lane_u8(dst_ptr + 1 * row_stride, src.buf.reg[0], 1);
510       vst1_lane_u8(dst_ptr + 2 * row_stride, src.buf.reg[0], 2);
511       vst1_lane_u8(dst_ptr + 3 * row_stride, src.buf.reg[0], 3);
512       vst1_lane_u8(dst_ptr + 4 * row_stride, src.buf.reg[0], 4);
513       vst1_lane_u8(dst_ptr + 5 * row_stride, src.buf.reg[0], 5);
514       vst1_lane_u8(dst_ptr + 6 * row_stride, src.buf.reg[0], 6);
515       vst1_lane_u8(dst_ptr + 7 * row_stride, src.buf.reg[0], 7);
516     }
517   }
518 };
519 
520 template <typename DstType>
521 struct StoreFinalOutputImpl<RegBlockUint8<4, 4>, DstType> {
522   static void Run(const RegBlockUint8<4, 4>& src, DstType* dst, int row,
523                   int col) {
524     std::uint8_t* dst_ptr = dst->data(row, col);
525     const int row_stride = dst->rows_stride();
526     const int col_stride = dst->cols_stride();
527     for (int i = 0; i < 2; i++) {
528       vst1_lane_u8(dst_ptr + 0 * row_stride + (2 * i + 0) * col_stride,
529                    src.buf.reg[i], 0);
530       vst1_lane_u8(dst_ptr + 1 * row_stride + (2 * i + 0) * col_stride,
531                    src.buf.reg[i], 1);
532       vst1_lane_u8(dst_ptr + 2 * row_stride + (2 * i + 0) * col_stride,
533                    src.buf.reg[i], 2);
534       vst1_lane_u8(dst_ptr + 3 * row_stride + (2 * i + 0) * col_stride,
535                    src.buf.reg[i], 3);
536       vst1_lane_u8(dst_ptr + 0 * row_stride + (2 * i + 1) * col_stride,
537                    src.buf.reg[i], 4);
538       vst1_lane_u8(dst_ptr + 1 * row_stride + (2 * i + 1) * col_stride,
539                    src.buf.reg[i], 5);
540       vst1_lane_u8(dst_ptr + 2 * row_stride + (2 * i + 1) * col_stride,
541                    src.buf.reg[i], 6);
542       vst1_lane_u8(dst_ptr + 3 * row_stride + (2 * i + 1) * col_stride,
543                    src.buf.reg[i], 7);
544     }
545   }
546 };
547 
548 template <typename DstType>
549 struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> {
550   static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row,
551                   int col) {
552     std::uint8_t* dst_ptr = dst->data(row, col);
553     if (DstType::kOrder == MapOrder::ColMajor) {
554       int col_stride = dst->cols_stride();
555       for (int i = 0; i < 4; i++) {
556         vst1_u8(dst_ptr + i * col_stride, src.buf.reg[i]);
557       }
558     } else {
559       for (int i = 0; i < 4; i++) {
560         int row_stride = dst->rows_stride();
561         std::uint8_t* col_ptr = dst_ptr + i;
562         vst1_lane_u8(col_ptr + 0 * row_stride, src.buf.reg[i], 0);
563         vst1_lane_u8(col_ptr + 1 * row_stride, src.buf.reg[i], 1);
564         vst1_lane_u8(col_ptr + 2 * row_stride, src.buf.reg[i], 2);
565         vst1_lane_u8(col_ptr + 3 * row_stride, src.buf.reg[i], 3);
566         vst1_lane_u8(col_ptr + 4 * row_stride, src.buf.reg[i], 4);
567         vst1_lane_u8(col_ptr + 5 * row_stride, src.buf.reg[i], 5);
568         vst1_lane_u8(col_ptr + 6 * row_stride, src.buf.reg[i], 6);
569         vst1_lane_u8(col_ptr + 7 * row_stride, src.buf.reg[i], 7);
570       }
571     }
572   }
573 };
574 
575 inline RegBlockUint8<8, 8> Transpose(const RegBlockUint8<8, 8>& src) {
576   uint8x8x2_t a[4];
577   a[0] = vtrn_u8(src.buf.reg[0], src.buf.reg[1]);
578   a[1] = vtrn_u8(src.buf.reg[2], src.buf.reg[3]);
579   a[2] = vtrn_u8(src.buf.reg[4], src.buf.reg[5]);
580   a[3] = vtrn_u8(src.buf.reg[6], src.buf.reg[7]);
581   uint16x4x2_t b[4];
582   b[0] = vtrn_u16(vreinterpret_u16_u8(a[0].val[0]),
583                   vreinterpret_u16_u8(a[1].val[0]));
584   b[1] = vtrn_u16(vreinterpret_u16_u8(a[0].val[1]),
585                   vreinterpret_u16_u8(a[1].val[1]));
586   b[2] = vtrn_u16(vreinterpret_u16_u8(a[2].val[0]),
587                   vreinterpret_u16_u8(a[3].val[0]));
588   b[3] = vtrn_u16(vreinterpret_u16_u8(a[2].val[1]),
589                   vreinterpret_u16_u8(a[3].val[1]));
590   uint32x2x2_t c[4];
591   c[0] = vtrn_u32(vreinterpret_u32_u16(b[0].val[0]),
592                   vreinterpret_u32_u16(b[2].val[0]));
593   c[1] = vtrn_u32(vreinterpret_u32_u16(b[1].val[0]),
594                   vreinterpret_u32_u16(b[3].val[0]));
595   c[2] = vtrn_u32(vreinterpret_u32_u16(b[0].val[1]),
596                   vreinterpret_u32_u16(b[2].val[1]));
597   c[3] = vtrn_u32(vreinterpret_u32_u16(b[1].val[1]),
598                   vreinterpret_u32_u16(b[3].val[1]));
599   RegBlockUint8<8, 8> result;
600   result.buf.reg[0] = vreinterpret_u8_u32(c[0].val[0]);
601   result.buf.reg[1] = vreinterpret_u8_u32(c[1].val[0]);
602   result.buf.reg[2] = vreinterpret_u8_u32(c[2].val[0]);
603   result.buf.reg[3] = vreinterpret_u8_u32(c[3].val[0]);
604   result.buf.reg[4] = vreinterpret_u8_u32(c[0].val[1]);
605   result.buf.reg[5] = vreinterpret_u8_u32(c[1].val[1]);
606   result.buf.reg[6] = vreinterpret_u8_u32(c[2].val[1]);
607   result.buf.reg[7] = vreinterpret_u8_u32(c[3].val[1]);
608   return result;
609 }
610 
611 template <typename DstType>
612 struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> {
613   static void Run(const RegBlockUint8<8, 8>& src, DstType* dst, int row,
614                   int col) {
615     const auto& block =
616         DstType::kOrder == MapOrder::ColMajor ? src : Transpose(src);
617     std::uint8_t* dst_ptr = dst->data(row, col);
618     int stride = dst->stride();
619     for (int i = 0; i < 8; i++) {
620       vst1_u8(dst_ptr + i * stride, block.buf.reg[i]);
621     }
622   }
623 };
624 
625 template <typename DstType>
626 struct StoreFinalOutputImpl<RegBlockInt16<8, 8>, DstType> {
627   static void Run(const RegBlockInt16<8, 8>& src, DstType* dst, int row,
628                   int col) {
629     if (DstType::kOrder == MapOrder::ColMajor) {
630       vst1q_s16(dst->data(row, col + 0), src.buf.reg[0]);
631       vst1q_s16(dst->data(row, col + 1), src.buf.reg[1]);
632       vst1q_s16(dst->data(row, col + 2), src.buf.reg[2]);
633       vst1q_s16(dst->data(row, col + 3), src.buf.reg[3]);
634       vst1q_s16(dst->data(row, col + 4), src.buf.reg[4]);
635       vst1q_s16(dst->data(row, col + 5), src.buf.reg[5]);
636       vst1q_s16(dst->data(row, col + 6), src.buf.reg[6]);
637       vst1q_s16(dst->data(row, col + 7), src.buf.reg[7]);
638     } else {
639       int16x8x2_t a[4];
640       a[0] = vtrnq_s16(src.buf.reg[0], src.buf.reg[1]);
641       a[1] = vtrnq_s16(src.buf.reg[2], src.buf.reg[3]);
642       a[2] = vtrnq_s16(src.buf.reg[4], src.buf.reg[5]);
643       a[3] = vtrnq_s16(src.buf.reg[6], src.buf.reg[7]);
644       int32x4x2_t b[4];
645       b[0] = vtrnq_s32(vreinterpretq_s32_s16(a[0].val[0]),
646                        vreinterpretq_s32_s16(a[1].val[0]));
647       b[1] = vtrnq_s32(vreinterpretq_s32_s16(a[0].val[1]),
648                        vreinterpretq_s32_s16(a[1].val[1]));
649       b[2] = vtrnq_s32(vreinterpretq_s32_s16(a[2].val[0]),
650                        vreinterpretq_s32_s16(a[3].val[0]));
651       b[3] = vtrnq_s32(vreinterpretq_s32_s16(a[2].val[1]),
652                        vreinterpretq_s32_s16(a[3].val[1]));
653       vst1_s16(dst->data(row + 0, col + 0),
654                vget_low_s16(vreinterpretq_s16_s32(b[0].val[0])));
655       vst1_s16(dst->data(row + 0, col + 4),
656                vget_low_s16(vreinterpretq_s16_s32(b[2].val[0])));
657       vst1_s16(dst->data(row + 1, col + 0),
658                vget_low_s16(vreinterpretq_s16_s32(b[1].val[0])));
659       vst1_s16(dst->data(row + 1, col + 4),
660                vget_low_s16(vreinterpretq_s16_s32(b[3].val[0])));
661       vst1_s16(dst->data(row + 2, col + 0),
662                vget_low_s16(vreinterpretq_s16_s32(b[0].val[1])));
663       vst1_s16(dst->data(row + 2, col + 4),
664                vget_low_s16(vreinterpretq_s16_s32(b[2].val[1])));
665       vst1_s16(dst->data(row + 3, col + 0),
666                vget_low_s16(vreinterpretq_s16_s32(b[1].val[1])));
667       vst1_s16(dst->data(row + 3, col + 4),
668                vget_low_s16(vreinterpretq_s16_s32(b[3].val[1])));
669       vst1_s16(dst->data(row + 4, col + 0),
670                vget_high_s16(vreinterpretq_s16_s32(b[0].val[0])));
671       vst1_s16(dst->data(row + 4, col + 4),
672                vget_high_s16(vreinterpretq_s16_s32(b[2].val[0])));
673       vst1_s16(dst->data(row + 5, col + 0),
674                vget_high_s16(vreinterpretq_s16_s32(b[1].val[0])));
675       vst1_s16(dst->data(row + 5, col + 4),
676                vget_high_s16(vreinterpretq_s16_s32(b[3].val[0])));
677       vst1_s16(dst->data(row + 6, col + 0),
678                vget_high_s16(vreinterpretq_s16_s32(b[0].val[1])));
679       vst1_s16(dst->data(row + 6, col + 4),
680                vget_high_s16(vreinterpretq_s16_s32(b[2].val[1])));
681       vst1_s16(dst->data(row + 7, col + 0),
682                vget_high_s16(vreinterpretq_s16_s32(b[1].val[1])));
683       vst1_s16(dst->data(row + 7, col + 4),
684                vget_high_s16(vreinterpretq_s16_s32(b[3].val[1])));
685     }
686   }
687 };
688 
689 }  // namespace gemmlowp
690 
691 #endif  // GEMMLOWP_INTERNAL_OUTPUT_NEON_H_
692