/****************************************************************************
* Copyright (C) 2017 Intel Corporation.   All Rights Reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice (including the next
* paragraph) shall be included in all copies or substantial portions of the
* Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
****************************************************************************/
#if !defined(__SIMD_LIB_AVX512_HPP__)
#error Do not include this file directly, use "simdlib.hpp" instead.
#endif

//============================================================================
// SIMD128 AVX (512) implementation
//
// Since this implementation inherits from the AVX (2) implementation,
// the only operations defined below are those that replace the AVX (2)
// versions. These use native AVX512 instructions with masking to
// enable a larger register set.
//============================================================================

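// The __conv helpers below reinterpret between the 128-bit SIMD types and the
// native 512-bit registers via the cast intrinsics. The casts compile to no
// instructions; widening leaves the upper lanes undefined, which is harmless
// because every operation here masks down to the low four 32-bit (or two
// 64-bit) lanes.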
private:
    static SIMDINLINE __m512  __conv(Float r) { return _mm512_castps128_ps512(r.v); }
    static SIMDINLINE __m512d __conv(Double r) { return _mm512_castpd128_pd512(r.v); }
    static SIMDINLINE __m512i __conv(Integer r) { return _mm512_castsi128_si512(r.v); }
    static SIMDINLINE Float   __conv(__m512 r) { return _mm512_castps512_ps128(r); }
    static SIMDINLINE Double  __conv(__m512d r) { return _mm512_castpd512_pd128(r); }
    static SIMDINLINE Integer __conv(__m512i r) { return _mm512_castsi512_si128(r); }
public:

#define SIMD_WRAPPER_1_(op, intrin, mask)  \
    static SIMDINLINE Float SIMDCALL op(Float a)   \
    {\
        return __conv(_mm512_maskz_##intrin((mask), __conv(a)));\
    }
#define SIMD_WRAPPER_1(op)  SIMD_WRAPPER_1_(op, op, __mmask16(0xf))
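
// For illustration only: a sketch of what the wrapper expands to. For example,
// SIMD_WRAPPER_1_(rcp_ps, rcp14_ps, __mmask16(0xf)) (used below) produces
// approximately:
//
//     static SIMDINLINE Float SIMDCALL rcp_ps(Float a)
//     {
//         return __conv(_mm512_maskz_rcp14_ps(__mmask16(0xf), __conv(a)));
//     }
//
// i.e. the 128-bit value is widened to 512 bits, the zero-masked AVX512
// intrinsic runs on the low four lanes, and the result is narrowed back.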

#define SIMD_WRAPPER_1I_(op, intrin, mask)  \
    template<int ImmT> \
    static SIMDINLINE Float SIMDCALL op(Float a)   \
    {\
        return __conv(_mm512_maskz_##intrin((mask), __conv(a), ImmT));\
    }
#define SIMD_WRAPPER_1I(op)  SIMD_WRAPPER_1I_(op, op, __mmask16(0xf))

#define SIMD_WRAPPER_2_(op, intrin, mask)  \
    static SIMDINLINE Float SIMDCALL op(Float a, Float b)   \
    {\
        return __conv(_mm512_maskz_##intrin((mask), __conv(a), __conv(b)));\
    }
#define SIMD_WRAPPER_2(op)  SIMD_WRAPPER_2_(op, op, __mmask16(0xf))

#define SIMD_WRAPPER_2I(op)  \
    template<int ImmT>\
    static SIMDINLINE Float SIMDCALL op(Float a, Float b)   \
    {\
        return __conv(_mm512_maskz_##op(0xf, __conv(a), __conv(b), ImmT));\
    }

#define SIMD_WRAPPER_3_(op, intrin, mask)  \
    static SIMDINLINE Float SIMDCALL op(Float a, Float b, Float c)   \
    {\
        return __conv(_mm512_maskz_##intrin((mask), __conv(a), __conv(b), __conv(c)));\
    }
#define SIMD_WRAPPER_3(op)  SIMD_WRAPPER_3_(op, op, __mmask16(0xf))

#define SIMD_DWRAPPER_2I(op)  \
    template<int ImmT>\
    static SIMDINLINE Double SIMDCALL op(Double a, Double b)   \
    {\
        return __conv(_mm512_maskz_##op(0x3, __conv(a), __conv(b), ImmT));\
    }

#define SIMD_IWRAPPER_1_(op, intrin, mask)  \
    static SIMDINLINE Integer SIMDCALL op(Integer a)   \
    {\
        return __conv(_mm512_maskz_##intrin((mask), __conv(a)));\
    }
#define SIMD_IWRAPPER_1_32(op)  SIMD_IWRAPPER_1_(op, op, __mmask16(0xf))

#define SIMD_IWRAPPER_1I_(op, intrin, mask)  \
    template<int ImmT> \
    static SIMDINLINE Integer SIMDCALL op(Integer a)   \
    {\
        return __conv(_mm512_maskz_##intrin((mask), __conv(a), ImmT));\
    }
#define SIMD_IWRAPPER_1I_32(op)  SIMD_IWRAPPER_1I_(op, op, __mmask16(0xf))

#define SIMD_IWRAPPER_2_(op, intrin, mask)  \
    static SIMDINLINE Integer SIMDCALL op(Integer a, Integer b)   \
    {\
        return __conv(_mm512_maskz_##intrin((mask), __conv(a), __conv(b)));\
    }
#define SIMD_IWRAPPER_2_32(op)  SIMD_IWRAPPER_2_(op, op, __mmask16(0xf))

#define SIMD_IWRAPPER_2I(op)  \
    template<int ImmT>\
    static SIMDINLINE Integer SIMDCALL op(Integer a, Integer b)   \
    {\
        return __conv(_mm512_maskz_##op(0xf, __conv(a), __conv(b), ImmT));\
    }

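// Note on the masks: the SIMD128 float/integer types occupy the low four
// 32-bit lanes of a 16-lane zmm register, hence __mmask16(0xf); the double
// wrappers use 0x3 to select the low two 64-bit lanes.
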
//-----------------------------------------------------------------------
// Single precision floating point arithmetic operations
//-----------------------------------------------------------------------
SIMD_WRAPPER_2(add_ps);     // return a + b
SIMD_WRAPPER_2(div_ps);     // return a / b
SIMD_WRAPPER_3(fmadd_ps);   // return (a * b) + c
SIMD_WRAPPER_3(fmsub_ps);   // return (a * b) - c
SIMD_WRAPPER_2(max_ps);     // return (a > b) ? a : b
SIMD_WRAPPER_2(min_ps);     // return (a < b) ? a : b
SIMD_WRAPPER_2(mul_ps);     // return a * b
SIMD_WRAPPER_1_(rcp_ps, rcp14_ps, __mmask16(0xf));       // return 1.0f / a        (approximate, relative error <= 2^-14)
SIMD_WRAPPER_1_(rsqrt_ps, rsqrt14_ps, __mmask16(0xf));   // return 1.0f / sqrt(a)  (approximate, relative error <= 2^-14)
SIMD_WRAPPER_2(sub_ps);     // return a - b

//-----------------------------------------------------------------------
// Integer (various width) arithmetic operations
//-----------------------------------------------------------------------
SIMD_IWRAPPER_1_32(abs_epi32);  // return absolute_value(a) (int32)
SIMD_IWRAPPER_2_32(add_epi32);  // return a + b (int32)
SIMD_IWRAPPER_2_32(max_epi32);  // return (a > b) ? a : b (int32)
SIMD_IWRAPPER_2_32(max_epu32);  // return (a > b) ? a : b (uint32)
SIMD_IWRAPPER_2_32(min_epi32);  // return (a < b) ? a : b (int32)
SIMD_IWRAPPER_2_32(min_epu32);  // return (a < b) ? a : b (uint32)
SIMD_IWRAPPER_2_32(mul_epi32);  // return a * b (signed 32x32 -> 64-bit product per qword lane)

// SIMD_IWRAPPER_2_8(add_epi8);    // return a + b (int8)
// SIMD_IWRAPPER_2_8(adds_epu8);   // return ((a + b) > 0xff) ? 0xff : (a + b) (uint8)

// return (a * b) & 0xFFFFFFFF
//
// Multiply the packed 32-bit integers in a and b, producing intermediate 64-bit integers,
// and store the low 32 bits of the intermediate integers in dst.
SIMD_IWRAPPER_2_32(mullo_epi32);
SIMD_IWRAPPER_2_32(sub_epi32);  // return a - b (int32)

// SIMD_IWRAPPER_2_64(sub_epi64);  // return a - b (int64)
// SIMD_IWRAPPER_2_8(subs_epu8);   // return (b > a) ? 0 : (a - b) (uint8)

//-----------------------------------------------------------------------
// Logical operations
//-----------------------------------------------------------------------
SIMD_IWRAPPER_2_(and_si,    and_epi32, __mmask16(0xf));    // return a & b       (int)
SIMD_IWRAPPER_2_(andnot_si, andnot_epi32, __mmask16(0xf)); // return (~a) & b    (int)
SIMD_IWRAPPER_2_(or_si,     or_epi32, __mmask16(0xf));     // return a | b       (int)
SIMD_IWRAPPER_2_(xor_si,    xor_epi32, __mmask16(0xf));    // return a ^ b       (int)


//-----------------------------------------------------------------------
// Shift operations
//-----------------------------------------------------------------------
SIMD_IWRAPPER_1I_32(slli_epi32);               // return a << ImmT
SIMD_IWRAPPER_2_32(sllv_epi32);                // return a << b      (uint32)
SIMD_IWRAPPER_1I_32(srai_epi32);               // return a >> ImmT   (int32)
SIMD_IWRAPPER_1I_32(srli_epi32);               // return a >> ImmT   (uint32)
SIMD_IWRAPPER_2_32(srlv_epi32);                // return a >> b      (uint32)

// use AVX2 version
//SIMD_IWRAPPER_1I_(srli_si, srli_si256);     // return a >> (ImmT*8) (uint)

//-----------------------------------------------------------------------
// Conversion operations (Use AVX2 versions)
//-----------------------------------------------------------------------
// SIMD_IWRAPPER_1L(cvtepu8_epi16, 0xffff);    // return (int16)a    (uint8 --> int16)
// SIMD_IWRAPPER_1L(cvtepu8_epi32, 0xff);      // return (int32)a    (uint8 --> int32)
// SIMD_IWRAPPER_1L(cvtepu16_epi32, 0xff);     // return (int32)a    (uint16 --> int32)
// SIMD_IWRAPPER_1L(cvtepu16_epi64, 0xf);      // return (int64)a    (uint16 --> int64)
// SIMD_IWRAPPER_1L(cvtepu32_epi64, 0xf);      // return (int64)a    (uint32 --> int64)

//-----------------------------------------------------------------------
// Comparison operations (Use AVX2 versions)
//-----------------------------------------------------------------------
//SIMD_IWRAPPER_2_CMP(cmpeq_epi8);    // return a == b (int8)
//SIMD_IWRAPPER_2_CMP(cmpeq_epi16);   // return a == b (int16)
//SIMD_IWRAPPER_2_CMP(cmpeq_epi32);   // return a == b (int32)
//SIMD_IWRAPPER_2_CMP(cmpeq_epi64);   // return a == b (int64)
//SIMD_IWRAPPER_2_CMP(cmpgt_epi8);    // return a > b (int8)
//SIMD_IWRAPPER_2_CMP(cmpgt_epi16);   // return a > b (int16)
//SIMD_IWRAPPER_2_CMP(cmpgt_epi32);   // return a > b (int32)
//SIMD_IWRAPPER_2_CMP(cmpgt_epi64);   // return a > b (int64)
//
//static SIMDINLINE Integer SIMDCALL cmplt_epi32(Integer a, Integer b)   // return a < b (int32)
//{
//    return cmpgt_epi32(b, a);
//}

//-----------------------------------------------------------------------
// Blend / shuffle / permute operations
//-----------------------------------------------------------------------
// SIMD_IWRAPPER_2_8(packs_epi16);     // int16 --> int8    See documentation for _mm256_packs_epi16 and _mm512_packs_epi16
// SIMD_IWRAPPER_2_16(packs_epi32);    // int32 --> int16   See documentation for _mm256_packs_epi32 and _mm512_packs_epi32
// SIMD_IWRAPPER_2_8(packus_epi16);    // uint16 --> uint8  See documentation for _mm256_packus_epi16 and _mm512_packus_epi16
// SIMD_IWRAPPER_2_16(packus_epi32);   // uint32 --> uint16 See documentation for _mm256_packus_epi32 and _mm512_packus_epi32
// SIMD_IWRAPPER_2_(permute_epi32, permutevar8x32_epi32);

//static SIMDINLINE Float SIMDCALL permute_ps(Float a, Integer swiz)    // return a[swiz[i]] for each 32-bit lane i (float)
//{
//    return _mm256_permutevar8x32_ps(a, swiz);
//}

SIMD_IWRAPPER_1I_32(shuffle_epi32);
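// (lane i of the result is lane ((ImmT >> (2*i)) & 3) of a; e.g.
// shuffle_epi32<0x1B>(a) reverses the four 32-bit lanes)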
//template<int ImmT>
//static SIMDINLINE Integer SIMDCALL shuffle_epi64(Integer a, Integer b)
//{
//    return castpd_si(shuffle_pd<ImmT>(castsi_pd(a), castsi_pd(b)));
//}
//SIMD_IWRAPPER_2(shuffle_epi8);
SIMD_IWRAPPER_2_32(unpackhi_epi32);
SIMD_IWRAPPER_2_32(unpacklo_epi32);

// SIMD_IWRAPPER_2_16(unpackhi_epi16);
// SIMD_IWRAPPER_2_64(unpackhi_epi64);
// SIMD_IWRAPPER_2_8(unpackhi_epi8);
// SIMD_IWRAPPER_2_16(unpacklo_epi16);
// SIMD_IWRAPPER_2_64(unpacklo_epi64);
// SIMD_IWRAPPER_2_8(unpacklo_epi8);

//-----------------------------------------------------------------------
// Load / store operations
//-----------------------------------------------------------------------
static SIMDINLINE Float SIMDCALL load_ps(float const *p)   // return *p    (loads SIMD width elements from memory)
{
    return __conv(_mm512_maskz_loadu_ps(__mmask16(0xf), p));
}

static SIMDINLINE Integer SIMDCALL load_si(Integer const *p)  // return *p
{
    return __conv(_mm512_maskz_loadu_epi32(__mmask16(0xf), p));
}

static SIMDINLINE Float SIMDCALL loadu_ps(float const *p)  // return *p    (same as load_ps but allows for unaligned mem)
{
    return __conv(_mm512_maskz_loadu_ps(__mmask16(0xf), p));
}

static SIMDINLINE Integer SIMDCALL loadu_si(Integer const *p) // return *p    (same as load_si but allows for unaligned mem)
{
    return __conv(_mm512_maskz_loadu_epi32(__mmask16(0xf), p));
}

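// Note: load_ps/load_si above use the unaligned masked-load intrinsics, so in
// this implementation the aligned and unaligned variants are identical.
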
template<ScaleFactor ScaleT>
static SIMDINLINE Float SIMDCALL i32gather_ps(float const* p, Integer idx) // return *(float*)(((int8*)p) + (idx * ScaleT))
{
    return __conv(_mm512_mask_i32gather_ps(
                    _mm512_setzero_ps(),
                    __mmask16(0xf),
                    __conv(idx),
                    p,
                    static_cast<int>(ScaleT)));
}

// for each element: (mask & (1 << 31)) ? (i32gather_ps<ScaleT>(p, idx), mask = 0) : old
template<ScaleFactor ScaleT>
static SIMDINLINE Float SIMDCALL mask_i32gather_ps(Float old, float const* p, Integer idx, Float mask)
{
    // Convert the vector mask to a kmask by testing the sign bit of each of
    // the low four lanes.
    __mmask16 m = 0xf;
    m = _mm512_mask_test_epi32_mask(m, _mm512_castps_si512(__conv(mask)),
                                    _mm512_set1_epi32(0x80000000));
    return __conv(_mm512_mask_i32gather_ps(
                    __conv(old),
                    m,
                    __conv(idx),
                    p,
                    static_cast<int>(ScaleT)));
}

// static SIMDINLINE uint32_t SIMDCALL movemask_epi8(Integer a)
// {
//     __mmask64 m = 0xffffull;
//     return static_cast<uint32_t>(
//         _mm512_mask_test_epi8_mask(m, __conv(a), _mm512_set1_epi8(0x80)));
// }

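// *p = src for each lane whose sign bit is set in mask (the vector mask is
// converted to a kmask the same way as in mask_i32gather_ps above)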
static SIMDINLINE void SIMDCALL maskstore_ps(float *p, Integer mask, Float src)
{
    __mmask16 m = 0xf;
    m = _mm512_mask_test_epi32_mask(m, __conv(mask), _mm512_set1_epi32(0x80000000));
    _mm512_mask_storeu_ps(p, m, __conv(src));
}

static SIMDINLINE void SIMDCALL store_ps(float *p, Float a)    // *p = a   (stores all elements contiguously in memory)
{
    _mm512_mask_storeu_ps(p, __mmask16(0xf), __conv(a));
}

static SIMDINLINE void SIMDCALL store_si(Integer *p, Integer a)   // *p = a
{
    _mm512_mask_storeu_epi32(p, __mmask16(0xf), __conv(a));
}

static SIMDINLINE Float SIMDCALL vmask_ps(int32_t mask)   // lane i = (mask & (1 << i)) ? all ones : 0
{
    return castsi_ps(__conv(_mm512_maskz_set1_epi32(__mmask16(mask & 0xf), -1)));
}
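
// Illustrative sketch only (assumes the inherited castps_si cast, the
// counterpart of castsi_ps above): writing just the first three lanes by
// combining vmask_ps with maskstore_ps.
//
//     Float v    = load_ps(src);
//     Float keep = vmask_ps(0x7);              // lanes 0..2 all ones, sign bits set
//     maskstore_ps(dst, castps_si(keep), v);   // stores only lanes 0..2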

//=======================================================================
// Legacy interface (available only in SIMD256 width)
//=======================================================================

#undef SIMD_WRAPPER_1_
#undef SIMD_WRAPPER_1
#undef SIMD_WRAPPER_1I_
#undef SIMD_WRAPPER_1I
#undef SIMD_WRAPPER_2_
#undef SIMD_WRAPPER_2
#undef SIMD_WRAPPER_2I
#undef SIMD_WRAPPER_3_
#undef SIMD_WRAPPER_3
#undef SIMD_DWRAPPER_1_
#undef SIMD_DWRAPPER_1
#undef SIMD_DWRAPPER_1I_
#undef SIMD_DWRAPPER_1I
#undef SIMD_DWRAPPER_2_
#undef SIMD_DWRAPPER_2
#undef SIMD_DWRAPPER_2I
#undef SIMD_IWRAPPER_1_
#undef SIMD_IWRAPPER_1_8
#undef SIMD_IWRAPPER_1_16
#undef SIMD_IWRAPPER_1_32
#undef SIMD_IWRAPPER_1_64
#undef SIMD_IWRAPPER_1I_
#undef SIMD_IWRAPPER_1I_8
#undef SIMD_IWRAPPER_1I_16
#undef SIMD_IWRAPPER_1I_32
#undef SIMD_IWRAPPER_1I_64
#undef SIMD_IWRAPPER_2_
#undef SIMD_IWRAPPER_2_8
#undef SIMD_IWRAPPER_2_16
#undef SIMD_IWRAPPER_2_32
#undef SIMD_IWRAPPER_2_64
#undef SIMD_IWRAPPER_2I
//#undef SIMD_IWRAPPER_2I_8
//#undef SIMD_IWRAPPER_2I_16
//#undef SIMD_IWRAPPER_2I_32
//#undef SIMD_IWRAPPER_2I_64