• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5
6$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
7$assert REQUANTIZATION == "FP32"
8$assert DATATYPE in ["QC8", "QS8"]
9$assert CHANNEL_TILE % 16 == 0
10$assert CHANNEL_TILE >= 16
11$assert KERNEL_TILE >= 2
12#include <assert.h>
13
14#include <immintrin.h>
15
16#include <xnnpack/dwconv.h>
17#include <xnnpack/unaligned.h>
18
19
20$PARAMS_STRUCT = REQUANTIZATION.lower() + "_avx2"
21$PARAMS_UNION = "xnn_%s_conv_minmax_params" % DATATYPE.lower()
22void xnn_${DATATYPE.lower()}_dwconv_minmax_${REQUANTIZATION.lower()}_ukernel_up${CHANNEL_TILE}x${KERNEL_TILE}__avx2_mul16${"_add16" if ADD16 else ""}_vpunpck(
23    size_t channels,
24    size_t output_width,
25    const int8_t** input,
26    const void* weights,
27    int8_t* output,
28    size_t input_stride,
29    size_t output_increment,
30    size_t input_offset,
31    const int8_t* zero,
32    const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
33{
34  assert(channels != 0);
35  assert(output_width != 0);
36
37  do {
38    $for K in range(KERNEL_TILE):
39      const int8_t* i${K} = input[${K}];
40      assert(i${K} != NULL);
41      if XNN_UNPREDICTABLE(i${K} != zero) {
42        i${K} = (const int8_t*) ((uintptr_t) i${K} + input_offset);
43      }
44    input = (const int8_t**) ((uintptr_t) input + input_stride);
45
46    size_t c = channels;
47    const void* w = weights;
48    for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) {
49      __m256i vacc${ABC[0:8]} = _mm256_loadu_si256((const __m256i*) w);
50      $for C in range(8, CHANNEL_TILE, 8):
51        __m256i vacc${ABC[C:C+8]} = _mm256_loadu_si256((const __m256i*) ((uintptr_t) w + ${C} * sizeof(int32_t)));
52
53      $for C in range(0, CHANNEL_TILE, 16):
54        __m256i vacc${ABC[C:C+4]}${ABC[C+8:C+12]} = _mm256_inserti128_si256(vacc${ABC[C:C+8]}, _mm256_castsi256_si128(vacc${ABC[C+8:C+16]}), 1);
55        __m256i vacc${ABC[C+4:C+8]}${ABC[C+12:C+16]} = _mm256_permute2x128_si256(vacc${ABC[C:C+8]}, vacc${ABC[C+8:C+16]}, 0x31);
56
57      $for K in range(KERNEL_TILE):
58
59        $for C in range(0, CHANNEL_TILE, 16):
60          $if C == 0:
61            const __m256i vi${K}x${ABC[0:16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) i${K}));
62          $else:
63            const __m256i vi${K}x${ABC[C:C+16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) (i${K} + ${C})));
64          const __m256i vk${K}x${ABC[C:C+16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${K * CHANNEL_TILE + C} * sizeof(int8_t))));
65        i${K} += ${CHANNEL_TILE};
66
67        $if ADD16:
68          $for C in range(0, CHANNEL_TILE, 16):
69            $if K == 0:
70              __m256i vacc${ABC[C:C+16]} = _mm256_mullo_epi16(vi${K}x${ABC[C:C+16]}, vk${K}x${ABC[C:C+16]});
71            $elif K % 2 == 0 or K + 1 == KERNEL_TILE:
72              vacc${ABC[C:C+16]} = _mm256_mullo_epi16(vi${K}x${ABC[C:C+16]}, vk${K}x${ABC[C:C+16]});
73            $else:
74              vacc${ABC[C:C+16]} = _mm256_add_epi16(vacc${ABC[C:C+16]}, _mm256_mullo_epi16(vi${K}x${ABC[C:C+16]}, vk${K}x${ABC[C:C+16]}));
75
76          $if K % 2 == 1 or K + 1 == KERNEL_TILE:
77            $for C in range(0, CHANNEL_TILE, 16):
78              $if K == 1:
79                __m256i vsignacc${ABC[C:C+16]} = _mm256_srai_epi16(vacc${ABC[C:C+16]}, 15);
80              $else:
81                vsignacc${ABC[C:C+16]} = _mm256_srai_epi16(vacc${ABC[C:C+16]}, 15);
82              vacc${ABC[C:C+4]}${ABC[C+8:C+12]} = _mm256_add_epi32(vacc${ABC[C:C+4]}${ABC[C+8:C+12]}, _mm256_unpacklo_epi16(vacc${ABC[C:C+16]}, vsignacc${ABC[C:C+16]}));
83              vacc${ABC[C+4:C+8]}${ABC[C+12:C+16]} = _mm256_add_epi32(vacc${ABC[C+4:C+8]}${ABC[C+12:C+16]}, _mm256_unpackhi_epi16(vacc${ABC[C:C+16]}, vsignacc${ABC[C:C+16]}));
84        $else:
85          $for C in range(0, CHANNEL_TILE, 16):
86            const __m256i vprod${K}x${ABC[C:C+16]}lo =  _mm256_mullo_epi16(vi${K}x${ABC[C:C+16]}, vk${K}x${ABC[C:C+16]});
87            const __m256i vprod${K}x${ABC[C:C+16]}hi = _mm256_srai_epi16(vprod${K}x${ABC[C:C+16]}lo, 15);
88
89          $for C in range(0, CHANNEL_TILE, 16):
90            vacc${ABC[C:C+4]}${ABC[C+8:C+12]} = _mm256_add_epi32(vacc${ABC[C:C+4]}${ABC[C+8:C+12]}, _mm256_unpacklo_epi16(vprod${K}x${ABC[C:C+16]}lo, vprod${K}x${ABC[C:C+16]}hi));
91            vacc${ABC[C+4:C+8]}${ABC[C+12:C+16]} = _mm256_add_epi32(vacc${ABC[C+4:C+8]}${ABC[C+12:C+16]}, _mm256_unpackhi_epi16(vprod${K}x${ABC[C:C+16]}lo, vprod${K}x${ABC[C:C+16]}hi));
92
93      w = (const void*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${KERNEL_TILE * CHANNEL_TILE} * sizeof(int8_t));
94
95      $for C in range(0, CHANNEL_TILE, 16):
96        vacc${ABC[C:C+8]} = _mm256_inserti128_si256(vacc${ABC[C:C+4]}${ABC[C+8:C+12]}, _mm256_castsi256_si128(vacc${ABC[C+4:C+8]}${ABC[C+12:C+16]}), 1);
97        vacc${ABC[C+8:C+16]} = _mm256_permute2x128_si256(vacc${ABC[C:C+4]}${ABC[C+8:C+12]}, vacc${ABC[C+4:C+8]}${ABC[C+12:C+16]}, 0x31);
98
99      $for C in range(0, CHANNEL_TILE, 8):
100        __m256 vfpacc${ABC[C:C+8]} = _mm256_cvtepi32_ps(vacc${ABC[C:C+8]});
101
102      $if DATATYPE == "QC8":
103        const __m256 vscale${ABC[0:8]} = _mm256_loadu_ps((const float*) w);
104        $for C in range(8, CHANNEL_TILE, 8):
105          const __m256 vscale${ABC[C:C+8]} = _mm256_loadu_ps((const float*) ((uintptr_t) w + ${C} * sizeof(float)));
106        w = (const void*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(float));
107        $for C in range(0, CHANNEL_TILE, 8):
108          vfpacc${ABC[C:C+8]} = _mm256_mul_ps(vfpacc${ABC[C:C+8]}, vscale${ABC[C:C+8]});
109      $else:
110        const __m256 vscale = _mm256_load_ps(params->fp32_avx2.scale);
111        $for C in range(0, CHANNEL_TILE, 8):
112          vfpacc${ABC[C:C+8]} = _mm256_mul_ps(vfpacc${ABC[C:C+8]}, vscale);
113
114      const __m256 voutput_max_less_zero_point = _mm256_load_ps(params->${PARAMS_STRUCT}.output_max_less_zero_point);
115      $for C in range(0, CHANNEL_TILE, 8):
116        vfpacc${ABC[C:C+8]} = _mm256_min_ps(vfpacc${ABC[C:C+8]}, voutput_max_less_zero_point);
117
118      $for C in range(0, CHANNEL_TILE, 8):
119        vacc${ABC[C:C+8]} = _mm256_cvtps_epi32(vfpacc${ABC[C:C+8]});
120
121      const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->${PARAMS_STRUCT}.output_zero_point);
122      $for C in range(0, CHANNEL_TILE, 16):
123        const __m256i vout${ABC[C:C+4]}${ABC[C+8:C+12]}${ABC[C+4:C+8]}${ABC[C+12:C+16]} = _mm256_adds_epi16(_mm256_packs_epi32(vacc${ABC[C:C+8]}, vacc${ABC[C+8:C+16]}), voutput_zero_point);
124
125      $for C in range(0, CHANNEL_TILE, 16):
126        __m128i vout${ABC[C:C+16]} = _mm_shuffle_epi32(_mm_packs_epi16(_mm256_castsi256_si128(vout${ABC[C:C+4]}${ABC[C+8:C+12]}${ABC[C+4:C+8]}${ABC[C+12:C+16]}), _mm256_extracti128_si256(vout${ABC[C:C+4]}${ABC[C+8:C+12]}${ABC[C+4:C+8]}${ABC[C+12:C+16]}, 1)), _MM_SHUFFLE(3, 1, 2, 0));
127
128      const __m128i voutput_min = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min);
129      $for C in range(0, CHANNEL_TILE, 16):
130        vout${ABC[C:C+16]} = _mm_max_epi8(vout${ABC[C:C+16]}, voutput_min);
131
132      _mm_storeu_si128((__m128i*) output, vout${ABC[0:16]});
133      $for C in range(16, CHANNEL_TILE, 16):
134        _mm_storeu_si128((__m128i*) (output + ${C}), vout${ABC[C:C+16]});
135      output += ${CHANNEL_TILE};
136    }
137    if XNN_UNLIKELY(c != 0) {
138      $if CHANNEL_TILE > 16:
139        const int8_t* k = (const int8_t*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t));
140      ${"do " if CHANNEL_TILE > 16 else ""}{
141        __m256i vacc${ABC[0:8]} = _mm256_loadu_si256((const __m256i*) w);
142        __m256i vacc${ABC[8:16]} = _mm256_loadu_si256((const __m256i*) ((uintptr_t) w + 8 * sizeof(int32_t)));
143
144        __m256i vacc${ABC[0:4]}${ABC[8:12]} = _mm256_inserti128_si256(vacc${ABC[0:8]}, _mm256_castsi256_si128(vacc${ABC[8:16]}), 1);
145        __m256i vacc${ABC[4:8]}${ABC[12:16]} = _mm256_permute2x128_si256(vacc${ABC[0:8]}, vacc${ABC[8:16]}, 0x31);
146
147        $for K in range(KERNEL_TILE):
148
149          const __m256i vi${K}x${ABC[0:16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) i${K}));
150          $if CHANNEL_TILE > 16:
151            $if K == 0:
152              const __m256i vk${K}x${ABC[0:16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) k));
153            $else:
154              const __m256i vk${K}x${ABC[0:16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) (k + ${K * CHANNEL_TILE})));
155          $else:
156            const __m256i vk${K}x${ABC[0:16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${K * CHANNEL_TILE} * sizeof(int8_t))));
157          $if CHANNEL_TILE > 16:
158            i${K} += 16;
159
160          const __m256i vprod${K}x${ABC[0:16]}lo =  _mm256_mullo_epi16(vi${K}x${ABC[0:16]}, vk${K}x${ABC[0:16]});
161          const __m256i vprod${K}x${ABC[0:16]}hi = _mm256_srai_epi16(vprod${K}x${ABC[0:16]}lo, 15);
162
163          vacc${ABC[0:4]}${ABC[8:12]} = _mm256_add_epi32(vacc${ABC[0:4]}${ABC[8:12]}, _mm256_unpacklo_epi16(vprod${K}x${ABC[0:16]}lo, vprod${K}x${ABC[0:16]}hi));
164          vacc${ABC[4:8]}${ABC[12:16]} = _mm256_add_epi32(vacc${ABC[4:8]}${ABC[12:16]}, _mm256_unpackhi_epi16(vprod${K}x${ABC[0:16]}lo, vprod${K}x${ABC[0:16]}hi));
165
166        vacc${ABC[0:8]} = _mm256_inserti128_si256(vacc${ABC[0:4]}${ABC[8:12]}, _mm256_castsi256_si128(vacc${ABC[4:8]}${ABC[12:16]}), 1);
167        vacc${ABC[8:16]} = _mm256_permute2x128_si256(vacc${ABC[0:4]}${ABC[8:12]}, vacc${ABC[4:8]}${ABC[12:16]}, 0x31);
168
169        $if CHANNEL_TILE > 16:
170          k += 16;
171
172        __m256 vfpacc${ABC[0:8]} = _mm256_cvtepi32_ps(vacc${ABC[0:8]});
173        __m256 vfpacc${ABC[8:16]} = _mm256_cvtepi32_ps(vacc${ABC[8:16]});
174
175        $if DATATYPE == "QC8":
176          const __m256 vscale${ABC[0:8]} = _mm256_loadu_ps((const float*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${CHANNEL_TILE * KERNEL_TILE} * sizeof(int8_t)));
177          const __m256 vscale${ABC[8:16]} = _mm256_loadu_ps((const float*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${CHANNEL_TILE * KERNEL_TILE} * sizeof(int8_t) + 8 * sizeof(float)));
178          vfpacc${ABC[0:8]} = _mm256_mul_ps(vfpacc${ABC[0:8]}, vscale${ABC[0:8]});
179          vfpacc${ABC[8:16]} = _mm256_mul_ps(vfpacc${ABC[8:16]}, vscale${ABC[8:16]});
180        $else:
181          const __m256 vscale = _mm256_load_ps(params->fp32_avx2.scale);
182          vfpacc${ABC[0:8]} = _mm256_mul_ps(vfpacc${ABC[0:8]}, vscale);
183          vfpacc${ABC[8:16]} = _mm256_mul_ps(vfpacc${ABC[8:16]}, vscale);
184
185        const __m256 voutput_max_less_zero_point = _mm256_load_ps(params->${PARAMS_STRUCT}.output_max_less_zero_point);
186        vfpacc${ABC[0:8]} = _mm256_min_ps(vfpacc${ABC[0:8]}, voutput_max_less_zero_point);
187        vfpacc${ABC[8:16]} = _mm256_min_ps(vfpacc${ABC[8:16]}, voutput_max_less_zero_point);
188
189        vacc${ABC[0:8]} = _mm256_cvtps_epi32(vfpacc${ABC[0:8]});
190        vacc${ABC[8:16]} = _mm256_cvtps_epi32(vfpacc${ABC[8:16]});
191
192        $if CHANNEL_TILE > 16:
193          w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
194
195        const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_zero_point);
196        __m128i vout${ABC[0:8]} = _mm_adds_epi16(_mm_packs_epi32(_mm256_castsi256_si128(vacc${ABC[0:8]}), _mm256_extracti128_si256(vacc${ABC[0:8]}, 1)), voutput_zero_point);
197        __m128i vout${ABC[8:16]} = _mm_adds_epi16(_mm_packs_epi32(_mm256_castsi256_si128(vacc${ABC[8:16]}), _mm256_extracti128_si256(vacc${ABC[8:16]}, 1)), voutput_zero_point);
198
199        const __m128i voutput_min = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min);
200
201        __m128i vout${ABC[0:16]} = _mm_packs_epi16(vout${ABC[0:8]}, vout${ABC[8:16]});
202        vout${ABC[0:16]} = _mm_max_epi8(vout${ABC[0:16]}, voutput_min);
203
204        $if CHANNEL_TILE > 16:
205          if XNN_LIKELY(c >= 16) {
206            _mm_storeu_si128((__m128i*) output, vout${ABC[0:16]});
207            output += 16;
208            c -= 16;
209          } else {
210            if (c & 8) {
211              _mm_storel_epi64((__m128i*) output, vout${ABC[0:16]});
212              vout${ABC[0:16]} = _mm_unpackhi_epi64(vout${ABC[0:16]}, vout${ABC[0:16]});
213              output += 8;
214            }
215            if (c & 4) {
216              unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vout${ABC[0:16]}));
217              vout${ABC[0:16]} = _mm_srli_epi64(vout${ABC[0:16]}, 32);
218              output += 4;
219            }
220            if (c & 2) {
221              unaligned_store_u16(output, (uint16_t) _mm_extract_epi16(vout${ABC[0:16]}, 0));
222              vout${ABC[0:16]} = _mm_srli_epi32(vout${ABC[0:16]}, 16);
223              output += 2;
224            }
225            if (c & 1) {
226              *output = (int8_t) _mm_extract_epi8(vout${ABC[0:16]}, 0);
227              output += 1;
228            }
229            c = 0;
230          }
231        $else:
232          if (c & 8) {
233            _mm_storel_epi64((__m128i*) output, vout${ABC[0:16]});
234            vout${ABC[0:16]} = _mm_unpackhi_epi64(vout${ABC[0:16]}, vout${ABC[0:16]});
235            output += 8;
236          }
237          if (c & 4) {
238            unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vout${ABC[0:16]}));
239            vout${ABC[0:16]} = _mm_srli_epi64(vout${ABC[0:16]}, 32);
240            output += 4;
241          }
242          if (c & 2) {
243            unaligned_store_u16(output, (uint16_t) _mm_extract_epi16(vout${ABC[0:16]}, 0));
244            vout${ABC[0:16]} = _mm_srli_epi32(vout${ABC[0:16]}, 16);
245            output += 2;
246          }
247          if (c & 1) {
248            *output = (int8_t) _mm_extract_epi8(vout${ABC[0:16]}, 0);
249            output += 1;
250          }
251      }${" while (c != 0);" if CHANNEL_TILE > 16 else ""}
252    }
253
254    output = (int8_t*) ((uintptr_t) output + output_increment);
255  } while (--output_width != 0);
256}
257