• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2020 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
7$assert CHANNEL_TILE % 8 == 0
8$assert CHANNEL_TILE >= 8
9$assert KERNEL_TILE >= 2
#include <assert.h>
#include <string.h>

#include <immintrin.h>

#include <xnnpack/dwconv.h>
15
16
17void xnn_qs8_dwconv_minmax_ukernel_up${CHANNEL_TILE}x${KERNEL_TILE}__avx2_mul32(
18    size_t channels,
19    size_t output_width,
20    const int8_t** input,
21    const void* weights,
22    int8_t* output,
23    size_t input_stride,
24    size_t output_increment,
25    size_t input_offset,
26    const int8_t* zero,
27    const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
28{
29  assert(channels != 0);
30  assert(output_width != 0);
31
32  do {
33    $for K in range(KERNEL_TILE):
34      const int8_t* i${K} = input[${K}];
35      assert(i${K} != NULL);
36      if XNN_UNPREDICTABLE(i${K} != zero) {
37        i${K} = (const int8_t*) ((uintptr_t) i${K} + input_offset);
38      }
39    input = (const int8_t**) ((uintptr_t) input + input_stride);
40
41    size_t c = channels;
42    const void* w = weights;
43    for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) {
44      __m256i vacc${ABC[0:8]} = _mm256_loadu_si256((const __m256i*) w);
45      $for C in range(8, CHANNEL_TILE, 8):
46        __m256i vacc${ABC[C:C+8]} = _mm256_loadu_si256((const __m256i*) ((uintptr_t) w + ${C} * sizeof(int32_t)));
47
48      $for K in range(KERNEL_TILE):
49
50        $for C in range(0, CHANNEL_TILE, 8):
51          $if C == 0:
52            const __m256i vi${K}x${ABC[0:8]} = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i${K}));
53          $else:
54            const __m256i vi${K}x${ABC[C:C+8]} = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i${K} + ${C})));
55          const __m256i vk${K}x${ABC[C:C+8]} = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${K * CHANNEL_TILE + C} * sizeof(int8_t))));
56        i${K} += ${CHANNEL_TILE};
57
58        $for C in range(0, CHANNEL_TILE, 8):
59          vacc${ABC[C:C+8]} = _mm256_add_epi32(vacc${ABC[C:C+8]}, _mm256_mullo_epi32(vi${K}x${ABC[C:C+8]}, vk${K}x${ABC[C:C+8]}));
60
61      w = (const void*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${KERNEL_TILE * CHANNEL_TILE} * sizeof(int8_t));
62
63      const __m256i vmultiplier = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) params->sse2.multiplier));
64      const __m256i vrounding = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) params->sse2.rounding));
65
66      $for C in range(0, CHANNEL_TILE, 8):
67        const __m256i vacc${ABC[C+1:C+8:2]} = _mm256_shuffle_epi32(vacc${ABC[C:C+8]}, _MM_SHUFFLE(3, 3, 1, 1));
68
69      $for C in range(0, CHANNEL_TILE, 8):
70        const __m256i vprod${ABC[C:C+8:2]} = _mm256_add_epi64(_mm256_mul_epi32(vacc${ABC[C:C+8]}, vmultiplier), vrounding);
71        const __m256i vprod${ABC[C+1:C+8:2]} = _mm256_add_epi64(_mm256_mul_epi32(vacc${ABC[C+1:C+8:2]}, vmultiplier), vrounding);
72
73      $for C in range(0, CHANNEL_TILE, 8):
74        const __m256i vq31prod${ABC[C:C+8:2]} = _mm256_srli_epi64(vprod${ABC[C:C+8:2]}, 31);
75        const __m256i vq31prod${ABC[C+1:C+8:2]} = _mm256_add_epi64(vprod${ABC[C+1:C+8:2]}, vprod${ABC[C+1:C+8:2]});
76
77      $for C in range(0, CHANNEL_TILE, 8):
78        const __m256i vq31prod${ABC[C:C+8]} = _mm256_blend_epi16(vq31prod${ABC[C:C+8:2]}, vq31prod${ABC[C+1:C+8:2]}, 0xCC);
79
80      const __m256i vremainder_mask = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) params->sse2.remainder_mask));
81      $for C in range(0, CHANNEL_TILE, 8):
82        const __m256i vrem${ABC[C:C+8]} =
83          _mm256_add_epi32(_mm256_and_si256(vq31prod${ABC[C:C+8]}, vremainder_mask), _mm256_cmpgt_epi32(_mm256_setzero_si256(), vq31prod${ABC[C:C+8]}));
84
85      const __m256i vremainder_threshold = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) params->sse2.remainder_threshold));
86      const __m128i vshift = _mm_load_si128((const __m128i*) params->sse2.shift);
87      $for C in range(0, CHANNEL_TILE, 8):
88        vacc${ABC[C:C+8]} =
89          _mm256_sub_epi32(_mm256_sra_epi32(vq31prod${ABC[C:C+8]}, vshift), _mm256_cmpgt_epi32(vrem${ABC[C:C+8]}, vremainder_threshold));
90
91      $if CHANNEL_TILE > 8:
92        const __m256i voutput_zero_point = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) params->sse2.output_zero_point));
93      $else:
94        const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->sse2.output_zero_point);
95      $for C in range(0, CHANNEL_TILE, 16):
96        $if C + 8 < CHANNEL_TILE:
97          __m256i vout${ABC[C:C+4]}${ABC[C+8:C+12]}${ABC[C+4:C+8]}${ABC[C+12:C+16]} = _mm256_adds_epi16(_mm256_packs_epi32(vacc${ABC[C:C+8]}, vacc${ABC[C+8:C+16]}), voutput_zero_point);
98        $elif CHANNEL_TILE > 8:
99          __m128i vout${ABC[C:C+8]} = _mm_adds_epi16(_mm_packs_epi32(_mm256_castsi256_si128(vacc${ABC[C:C+8]}), _mm256_extracti128_si256(vacc${ABC[C:C+8]}, 1)), _mm256_castsi256_si128(voutput_zero_point));
100        $else:
101          __m128i vout${ABC[C:C+8]} = _mm_adds_epi16(_mm_packs_epi32(_mm256_castsi256_si128(vacc${ABC[C:C+8]}), _mm256_extracti128_si256(vacc${ABC[C:C+8]}, 1)), voutput_zero_point);
102
103      $if CHANNEL_TILE > 8:
104        const __m256i voutput_min = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) params->sse2.output_min));
105        const __m256i voutput_max = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) params->sse2.output_max));
106      $else:
107        const __m128i voutput_min = _mm_load_si128((const __m128i*) params->sse2.output_min);
108        const __m128i voutput_max = _mm_load_si128((const __m128i*) params->sse2.output_max);
109      $for C in range(0, CHANNEL_TILE, 16):
110        $if C + 8 < CHANNEL_TILE:
111          vout${ABC[C:C+4]}${ABC[C+8:C+12]}${ABC[C+4:C+8]}${ABC[C+12:C+16]} = _mm256_min_epi16(_mm256_max_epi16(vout${ABC[C:C+4]}${ABC[C+8:C+12]}${ABC[C+4:C+8]}${ABC[C+12:C+16]}, voutput_min), voutput_max);
112        $elif CHANNEL_TILE > 8:
113          vout${ABC[C:C+8]} = _mm_min_epi16(_mm_max_epi16(vout${ABC[C:C+8]}, _mm256_castsi256_si128(voutput_min)), _mm256_castsi256_si128(voutput_max));
114        $else:
115          vout${ABC[C:C+8]} = _mm_min_epi16(_mm_max_epi16(vout${ABC[C:C+8]}, voutput_min), voutput_max);
116
117      $for C in range(0, CHANNEL_TILE, 16):
118        $if C + 8 < CHANNEL_TILE:
119          __m128i vout${ABC[C:C+16]} = _mm_shuffle_epi32(_mm_packs_epi16(_mm256_castsi256_si128(vout${ABC[C:C+4]}${ABC[C+8:C+12]}${ABC[C+4:C+8]}${ABC[C+12:C+16]}), _mm256_extracti128_si256(vout${ABC[C:C+4]}${ABC[C+8:C+12]}${ABC[C+4:C+8]}${ABC[C+12:C+16]}, 1)), _MM_SHUFFLE(3, 1, 2, 0));
120        $else:
121          __m128i vout${ABC[C:C+8]}${ABC[C:C+8]} = _mm_packs_epi16(vout${ABC[C:C+8]}, vout${ABC[C:C+8]});
122
123      $if CHANNEL_TILE > 8:
124        _mm_storeu_si128((__m128i*) output, vout${ABC[0:16]});
125      $else:
126        _mm_storel_epi64((__m128i*) output, vout${ABC[0:8]}${ABC[C:C+8]});
127      $for C in range(16, CHANNEL_TILE, 16):
128        $if C + 8 < CHANNEL_TILE:
129          _mm_storeu_si128((__m128i*) (output + ${C}), vout${ABC[C:C+16]});
130        $else:
131          _mm_storel_epi64((__m128i*) (output + ${C}), vout${ABC[C:C+8]}${ABC[C:C+8]});
132      output += ${CHANNEL_TILE};
133    }
134    if XNN_UNLIKELY(c != 0) {
135      $if CHANNEL_TILE > 8:
136        const int8_t* k = (const int8_t*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t));
137      ${"do " if CHANNEL_TILE > 8 else ""}{
138        __m256i vacc${ABC[0:8]} = _mm256_loadu_si256((const __m256i*) w);
139
140        $for K in range(KERNEL_TILE):
141
142          const __m256i vi${K}x${ABC[0:8]} = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i${K}));
143          $if CHANNEL_TILE > 8:
144            $if K == 0:
145              const __m256i vk${K}x${ABC[0:8]} = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) k));
146            $else:
147              const __m256i vk${K}x${ABC[0:8]} = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + ${K * CHANNEL_TILE})));
148          $else:
149            const __m256i vk${K}x${ABC[0:8]} = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${K * CHANNEL_TILE} * sizeof(int8_t))));
150          $if CHANNEL_TILE > 8:
151            i${K} += 8;
152
153          vacc${ABC[0:8]} = _mm256_add_epi32(vacc${ABC[0:8]}, _mm256_mullo_epi32(vi${K}x${ABC[0:8]}, vk${K}x${ABC[0:8]}));
154
155        $if CHANNEL_TILE > 8:
156          w = (const void*) ((uintptr_t) w + 8 * sizeof(int32_t));
157          k += 8;
158
159        const __m256i vmultiplier = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) params->sse2.multiplier));
160        const __m256i vrounding = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) params->sse2.rounding));
161
162        const __m256i vacc${ABC[1:8:2]} = _mm256_shuffle_epi32(vacc${ABC[0:8]}, _MM_SHUFFLE(3, 3, 1, 1));
163
164        const __m256i vprod${ABC[0:8:2]} = _mm256_add_epi64(_mm256_mul_epi32(vacc${ABC[0:8]}, vmultiplier), vrounding);
165        const __m256i vprod${ABC[1:8:2]} = _mm256_add_epi64(_mm256_mul_epi32(vacc${ABC[1:8:2]}, vmultiplier), vrounding);
166
167        const __m256i vq31prod${ABC[0:8:2]} = _mm256_srli_epi64(vprod${ABC[0:8:2]}, 31);
168        const __m256i vq31prod${ABC[1:8:2]} = _mm256_add_epi64(vprod${ABC[1:8:2]}, vprod${ABC[1:8:2]});
169
170        const __m256i vq31prod${ABC[0:8]} = _mm256_blend_epi16(vq31prod${ABC[0:8:2]}, vq31prod${ABC[1:8:2]}, 0xCC);
171
172        const __m256i vremainder_mask = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) params->sse2.remainder_mask));
173        const __m256i vrem${ABC[0:8]} =
174          _mm256_add_epi32(_mm256_and_si256(vq31prod${ABC[0:8]}, vremainder_mask), _mm256_cmpgt_epi32(_mm256_setzero_si256(), vq31prod${ABC[0:8]}));
175
176        const __m256i vremainder_threshold = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) params->sse2.remainder_threshold));
177        const __m128i vshift = _mm_load_si128((const __m128i*) params->sse2.shift);
178        vacc${ABC[0:8]} =
179          _mm256_sub_epi32(_mm256_sra_epi32(vq31prod${ABC[0:8]}, vshift), _mm256_cmpgt_epi32(vrem${ABC[0:8]}, vremainder_threshold));
180
181        const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->sse2.output_zero_point);
182        __m128i vout${ABC[0:8]} = _mm_adds_epi16(_mm_packs_epi32(_mm256_castsi256_si128(vacc${ABC[0:8]}), _mm256_extracti128_si256(vacc${ABC[0:8]}, 1)), voutput_zero_point);
183
184        const __m128i voutput_min = _mm_load_si128((const __m128i*) params->sse2.output_min);
185        const __m128i voutput_max = _mm_load_si128((const __m128i*) params->sse2.output_max);
186        vout${ABC[0:8]} = _mm_min_epi16(_mm_max_epi16(vout${ABC[0:8]}, voutput_min), voutput_max);
187
188        __m128i vout${ABC[0:8]}${ABC[0:8]} = _mm_packs_epi16(vout${ABC[0:8]}, vout${ABC[0:8]});
189
190        $if CHANNEL_TILE > 8:
191          if XNN_LIKELY(c >= 8) {
192            _mm_storel_epi64((__m128i*) output, vout${ABC[0:8]}${ABC[0:8]});
193            output += 8;
194            c -= 8;
195          } else {
196            if (c & 4) {
197              *((uint32_t*) output) = (uint32_t) _mm_cvtsi128_si32(vout${ABC[0:8]}${ABC[0:8]});
198              vout${ABC[0:8]}${ABC[0:8]} = _mm_srli_epi64(vout${ABC[0:8]}${ABC[0:8]}, 32);
199              output += 4;
200            }
201            if (c & 2) {
202              *((uint16_t*) output) = (uint16_t) _mm_extract_epi16(vout${ABC[0:8]}${ABC[0:8]}, 0);
203              vout${ABC[0:8]}${ABC[0:8]} = _mm_srli_epi32(vout${ABC[0:8]}${ABC[0:8]}, 16);
204              output += 2;
205            }
206            if (c & 1) {
207              *output = (int8_t) _mm_extract_epi8(vout${ABC[0:8]}${ABC[0:8]}, 0);
208              output += 1;
209            }
210            c = 0;
211          }
212        $else:
213          if (c & 4) {
214            *((uint32_t*) output) = (uint32_t) _mm_cvtsi128_si32(vout${ABC[0:8]}${ABC[0:8]});
215            vout${ABC[0:8]}${ABC[0:8]} = _mm_srli_epi64(vout${ABC[0:8]}${ABC[0:8]}, 32);
216            output += 4;
217          }
218          if (c & 2) {
219            *((uint16_t*) output) = (uint16_t) _mm_extract_epi16(vout${ABC[0:8]}${ABC[0:8]}, 0);
220            vout${ABC[0:8]}${ABC[0:8]} = _mm_srli_epi32(vout${ABC[0:8]}${ABC[0:8]}, 16);
221            output += 2;
222          }
223          if (c & 1) {
224            *output = (int8_t) _mm_extract_epi8(vout${ABC[0:8]}${ABC[0:8]}, 0);
225            output += 1;
226          }
227      }${" while (c != 0);" if CHANNEL_TILE > 8 else ""}
228    }
229
230    output = (int8_t*) ((uintptr_t) output + output_increment);
231  } while (--output_width != 0);
232}
233