// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert CHANNEL_TILE % 8 == 0
$assert KERNEL_TILE >= 2
$assert ACCUMULATORS >= 1
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>

#include <immintrin.h>

#include <xnnpack/dwconv.h>
#include <xnnpack/intrinsics-polyfill.h>

18void xnn_f16_dwconv_minmax_ukernel_up${CHANNEL_TILE}x${KERNEL_TILE}__fma3${"" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS}(
19    size_t channels,
20    size_t output_width,
21    const void** input,
22    const void* weights,
23    void* output,
24    size_t input_stride,
25    size_t output_increment,
26    size_t input_offset,
27    const void* zero,
28    const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
29{
30  assert(channels != 0);
31  assert(output_width != 0);
32
33  const __m256 vmax = _mm256_load_ps(params->avx.max);
34  const __m256 vmin = _mm256_load_ps(params->avx.min);
35
36  uint16_t* o = (uint16_t*) output;
37  do {
38    $for K in range(KERNEL_TILE):
39      const uint16_t* i${K} = input[${K}];
40      assert(i${K} != NULL);
41      if XNN_UNPREDICTABLE(i${K} != zero) {
42        i${K} = (const uint16_t*) ((uintptr_t) i${K} + input_offset);
43      }
44    input = (const void**) ((uintptr_t) input + input_stride);
45
46    size_t c = channels;
47    const uint16_t* w = weights;
48    for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) {
49      __m256 vacc${ABC[0:8]}p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
50      $for C in range(8, CHANNEL_TILE, 8):
51        __m256 vacc${ABC[C:C+8]}p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + ${C})));
52
53      $for K in range(KERNEL_TILE):
54
55        const __m256 vi${K}x${ABC[0:8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i${K}));
56        $for C in range(8, CHANNEL_TILE, 8):
57          const __m256 vi${K}x${ABC[C:C+8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i${K} + ${C})));
58        i${K} += ${CHANNEL_TILE};
59
60        $for C in range(0, CHANNEL_TILE, 8):
61          const __m256 vk${K}x${ABC[C:C+8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + ${(K + 1) * CHANNEL_TILE + C})));
62        $for C in range(0, CHANNEL_TILE, 8):
63          $if 1 <= K < ACCUMULATORS:
64            __m256 vacc${ABC[C:C+8]}p${K} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(vi${K}x${ABC[C:C+8]}, vk${K}x${ABC[C:C+8]}), _MM_FROUND_NO_EXC));
65          $else:
66            vacc${ABC[C:C+8]}p${K % ACCUMULATORS} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi${K}x${ABC[C:C+8]}, vk${K}x${ABC[C:C+8]}, vacc${ABC[C:C+8]}p${K % ACCUMULATORS}), _MM_FROUND_NO_EXC));
67
68      w += ${(KERNEL_TILE + 1) * CHANNEL_TILE};
69
70      $if ACCUMULATORS > 1:
71        // Add up all accumulators to vacc${ABC[0:CHANNEL_TILE]}p0
72        $ACC_SLICE = 1
73        $while ACC_SLICE < ACCUMULATORS:
74          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
75            $if A + ACC_SLICE < ACCUMULATORS:
76              $for C in range(0, CHANNEL_TILE, 8):
77                vacc${ABC[C:C+8]}p${A} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vacc${ABC[C:C+8]}p${A}, vacc${ABC[C:C+8]}p${A + ACC_SLICE}), _MM_FROUND_NO_EXC));
78          $ACC_SLICE *= 2
79
80      $for C in range(0, CHANNEL_TILE, 8):
81        __m256 vacc${ABC[C:C+8]} = _mm256_max_ps(vacc${ABC[C:C+8]}p0, vmin);
82      $for C in range(0, CHANNEL_TILE, 8):
83        vacc${ABC[C:C+8]} = _mm256_min_ps(vacc${ABC[C:C+8]}, vmax);
84
85      _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vacc${ABC[0:8]}, _MM_FROUND_NO_EXC));
86      $for C in range(8, CHANNEL_TILE, 8):
87        _mm_storeu_si128((__m128i*) (o + ${C}), _mm256_cvtps_ph(vacc${ABC[C:C+8]}, _MM_FROUND_NO_EXC));
88      o += ${CHANNEL_TILE};
89    }
90    $if CHANNEL_TILE > 8:
91      for (; c >= 8; c -= 8) {
92        __m256 vacc01234567p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
93        $for K in range(KERNEL_TILE):
94
95          const __m256 vi${K}x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i${K}));
96          i${K} += 8;
97
98          const __m256 vk${K}x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + ${(K + 1) * CHANNEL_TILE})));
99          $if 1 <= K < ACCUMULATORS:
100            __m256 vacc01234567p${K} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(vi${K}x01234567, vk${K}x01234567), _MM_FROUND_NO_EXC));
101          $else:
102            vacc01234567p${K % ACCUMULATORS} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi${K}x01234567, vk${K}x01234567, vacc01234567p${K % ACCUMULATORS}), _MM_FROUND_NO_EXC));
103
104        w += 8;
105
106        $if ACCUMULATORS > 1:
107          // Add up all accumulators to vacc${ABC[0:8]}p0
108          $ACC_SLICE = 1
109          $while ACC_SLICE < ACCUMULATORS:
110            $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
111              $if A + ACC_SLICE < ACCUMULATORS:
112                vacc01234567p${A} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vacc01234567p${A}, vacc01234567p${A + ACC_SLICE}), _MM_FROUND_NO_EXC));
113            $ACC_SLICE *= 2
114
115        __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
116        vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
117
118        _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vacc01234567, _MM_FROUND_NO_EXC));
119        o += 8;
120      }
121    if XNN_UNLIKELY(c != 0) {
122      assert(c >= 1);
123      assert(c <= 7);
124
125      __m256 vacc01234567p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
126      $for K in range(KERNEL_TILE):
127
128        const __m256 vi${K}x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i${K}));
129
130        const __m256 vk${K}x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + ${(K + 1) * CHANNEL_TILE})));
131        $if 1 <= K < ACCUMULATORS:
132          __m256 vacc01234567p${K} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(vi${K}x01234567, vk${K}x01234567), _MM_FROUND_NO_EXC));
133        $else:
134          vacc01234567p${K % ACCUMULATORS} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi${K}x01234567, vk${K}x01234567, vacc01234567p${K % ACCUMULATORS}), _MM_FROUND_NO_EXC));
135
136      $if ACCUMULATORS > 1:
137        // Add up all accumulators to vacc${ABC[0:8]}p0
138        $ACC_SLICE = 1
139        $while ACC_SLICE < ACCUMULATORS:
140          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
141            $if A + ACC_SLICE < ACCUMULATORS:
142              vacc01234567p${A} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vacc01234567p${A}, vacc01234567p${A + ACC_SLICE}), _MM_FROUND_NO_EXC));
143          $ACC_SLICE *= 2
144
145      __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
146      vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
147
148      __m128i vh01234567 = _mm256_cvtps_ph(vacc01234567, _MM_FROUND_NO_EXC);
149      if (c & 4) {
150        _mm_storel_epi64((__m128i*) o, vh01234567);
151        vh01234567 = _mm_unpackhi_epi64(vh01234567, vh01234567);
152        o += 4;
153      }
154      if (c & 2) {
155        _mm_storeu_si32(o, vh01234567);
156        vh01234567 = _mm_srli_epi64(vh01234567, 32);
157        o += 2;
158      }
159      if (c & 1) {
160        *o = (uint16_t) _mm_extract_epi16(vh01234567, 0);
161        o += 1;
162      }
163    }
164
165    o = (uint16_t*) ((uintptr_t) o + output_increment);
166  } while (--output_width != 0);
167}
168