/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_MATMUL_OP_H_
#define TENSORFLOW_CORE_KERNELS_SPARSE_MATMUL_OP_H_

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/types.h"

#if defined(PLATFORM_WINDOWS)
#include "tensorflow/core/platform/windows/cpu_info.h"
#include "tensorflow/core/platform/windows/intrinsics_port.h"
#endif

namespace Eigen {
namespace internal {

// Return the float representation of the bfloat16 value
// in the lower 16-bits of input
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pexpand_bf16_l(const Packet& from) {
  tensorflow::uint32 tmp;
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  tmp = (reinterpret_cast<const tensorflow::uint32&>(from)) & 0xffff0000;
#else
  tmp = (reinterpret_cast<const tensorflow::uint32&>(from) << 16) & 0xffff0000;
#endif
  return reinterpret_cast<const float&>(tmp);
}

// Return the float representation of the bfloat16 value
// in the upper 16-bits of input
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pexpand_bf16_u(const Packet& from) {
  tensorflow::uint32 tmp;
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  tmp = (reinterpret_cast<const tensorflow::uint32&>(from) << 16) & 0xffff0000;
#else
  tmp = (reinterpret_cast<const tensorflow::uint32&>(from)) & 0xffff0000;
#endif
  return reinterpret_cast<const float&>(tmp);
}
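
// Worked example: on a little-endian machine, a 32-bit input with bit
// pattern 0xAAAABBBB holds the bfloat16 value 0xBBBB in its lower 16 bits
// and 0xAAAA in its upper 16 bits, so pexpand_bf16_l yields the float with
// bits 0xBBBB0000 and pexpand_bf16_u the float with bits 0xAAAA0000. Each
// bfloat16 value becomes the upper 16 bits of a float32 whose remaining
// mantissa bits are zero.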

// Specialization of the non-scalar version for non-SSE platforms.
// Enables vectorization on z13 and higher.
#if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX) || \
    defined(EIGEN_VECTORIZE_NEON) || defined(EIGEN_VECTORIZE_ZVECTOR)
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_l(const Packet4f& from) {
  float r[4];
  tensorflow::uint32 p[4];
  pstoreu(r, from);
  tensorflow::uint32* ir = reinterpret_cast<tensorflow::uint32*>(r);
  p[0] = (ir[0] << 16) & 0xffff0000;
  p[1] = ir[0] & 0xffff0000;
  p[2] = (ir[1] << 16) & 0xffff0000;
  p[3] = ir[1] & 0xffff0000;
  return ploadu<Packet4f>(reinterpret_cast<float*>(p));
}

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_u(const Packet4f& from) {
  float r[4];
  tensorflow::uint32 p[4];
  pstoreu(r, from);
  tensorflow::uint32* ir = reinterpret_cast<tensorflow::uint32*>(r);
  p[0] = (ir[2] << 16) & 0xffff0000;
  p[1] = ir[2] & 0xffff0000;
  p[2] = (ir[3] << 16) & 0xffff0000;
  p[3] = ir[3] & 0xffff0000;
  return ploadu<Packet4f>(reinterpret_cast<float*>(p));
}
#endif
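
// Note on the Packet4f layout above: the input packet carries eight bfloat16
// values, two per 32-bit lane. pexpand_bf16_l expands the four values stored
// in lanes 0-1 and pexpand_bf16_u the four stored in lanes 2-3; within each
// lane the low 16 bits come first, and every bfloat16 lands in the upper 16
// bits of a float32 lane.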

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pinterleave4x64(const Packet& from) {
  return from;
}

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pbroadcast_first(const Packet& a) {
  return a;
}

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pbroadcast_second(const Packet& a) {
  assert(false && "Not applicable to Scalar Values");
  return a;
}

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pbroadcast_third(const Packet& a) {
  assert(false && "Not applicable to Scalar Values");
  return a;
}

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pbroadcast_fourth(const Packet& a) {
  assert(false && "Not applicable to Scalar Values");
  return a;
}

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pload4bf16(
    const typename unpacket_traits<Packet>::type* from) {
  assert(false && "Not applicable to Scalar Values");
  return Packet();
}

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pload2bf16(
    const typename unpacket_traits<Packet>::type* from) {
  assert(false && "Not applicable to Scalar Values");
  return Packet();
}
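
// The generic templates above are fallbacks for types without a vector
// specialization: pinterleave4x64 and pbroadcast_first simply return their
// argument, while the remaining broadcasts and the bfloat16 loads have no
// scalar meaning and assert at runtime.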

// Specializations of pload4bf16 and pload2bf16 for non-SSE platforms.
// Enable vectorization on z13 and higher.
#if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX) || \
    defined(EIGEN_VECTORIZE_NEON) || defined(EIGEN_VECTORIZE_ZVECTOR)
template <>
EIGEN_STRONG_INLINE Packet4f pload4bf16<Packet4f>(const float* from) {
  tensorflow::uint32 p[4];
  const tensorflow::uint32* ir =
      reinterpret_cast<const tensorflow::uint32*>(from);
  p[0] = (ir[0] << 16) & 0xffff0000;
  p[1] = ir[0] & 0xffff0000;
  p[2] = (ir[1] << 16) & 0xffff0000;
  p[3] = ir[1] & 0xffff0000;
  return ploadu<Packet4f>(reinterpret_cast<float*>(p));
}

template <>
EIGEN_STRONG_INLINE Packet4f pload2bf16<Packet4f>(const float* from) {
  tensorflow::uint32 p[4];
  const tensorflow::uint32* ir =
      reinterpret_cast<const tensorflow::uint32*>(from);
  p[0] = (ir[0] << 16) & 0xffff0000;
  p[1] = ir[0] & 0xffff0000;
  p[2] = (ir[0] << 16) & 0xffff0000;
  p[3] = ir[0] & 0xffff0000;
  return ploadu<Packet4f>(reinterpret_cast<float*>(p));
}
#endif
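
// Note: pload2bf16 above reads one 32-bit word (two bfloat16 values b0, b1)
// and duplicates the expanded pair, producing {b0, b1, b0, b1} as float32
// lanes.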

#if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX)
// Return a packet with the first value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_first<Packet4f>(const Packet4f& a) {
  return vec_splat(a, 0);
}

// Return a packet with the second value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_second<Packet4f>(const Packet4f& a) {
  return vec_splat(a, 1);
}

// Return a packet with the third value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_third<Packet4f>(const Packet4f& a) {
  return vec_splat(a, 2);
}

// Return a packet with the fourth value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_fourth<Packet4f>(const Packet4f& a) {
  return vec_splat(a, 3);
}
#endif

#ifdef EIGEN_VECTORIZE_SSE2
// For a packet size of 4 floats, the Packet is not modified
template <>
EIGEN_STRONG_INLINE Packet4f pinterleave4x64<Packet4f>(const Packet4f& from) {
  return from;
}

// Return a Packet with 4 floats loaded from 4 bfloat16 values
template <>
EIGEN_STRONG_INLINE Packet4f pload4bf16<Packet4f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from));
  return _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp));
}

// Return a Packet with 2 floats loaded from 2 bfloat16 values
template <>
EIGEN_STRONG_INLINE Packet4f pload2bf16<Packet4f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castps_si128(_mm_load_ps1(from));
  return _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp));
}
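
// How the loads above work: _mm_unpacklo_epi16(zero, tmp) interleaves a zero
// word below each 16-bit element, so {b0, b1, b2, b3, ...} becomes
// {0, b0, 0, b1, 0, b2, 0, b3}; read as 32-bit lanes this is
// {b0 << 16, b1 << 16, b2 << 16, b3 << 16}, i.e. the float32 representations
// of the four bfloat16 inputs.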

// Return a Packet with 4 floats expanded from 4 bfloat16 values
// in the lower half of the 128-bit lane
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_l(const Packet4f& from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castps_si128(from);
  return _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp));
}

// Return a Packet with 4 floats expanded from 4 bfloat16 values
// in the upper half of the 128-bit lane
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_u(const Packet4f& from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castps_si128(from);
  return _mm_castsi128_ps(_mm_unpackhi_epi16(zero, tmp));
}
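
// pexpand_bf16_l/u apply the same zero-interleave trick to the packet's own
// bit pattern: unpacklo expands the four bfloat16 values held in the lower
// 64 bits, unpackhi the four held in the upper 64 bits.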

// Return a packet with the first value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_first<Packet4f>(const Packet4f& a) {
  return _mm_set1_ps(pfirst<Packet4f>(a));
}

// Return a packet with the second value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_second<Packet4f>(const Packet4f& a) {
  return _mm_set1_ps(_mm_cvtss_f32(_mm_shuffle_ps(a, a, 1)));
}

// Return a packet with the third value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_third<Packet4f>(const Packet4f& a) {
  return _mm_set1_ps(_mm_cvtss_f32(_mm_shuffle_ps(a, a, 2)));
}

// Return a packet with the fourth value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_fourth<Packet4f>(const Packet4f& a) {
  return _mm_set1_ps(_mm_cvtss_f32(_mm_shuffle_ps(a, a, 3)));
}

#endif

#ifdef EIGEN_VECTORIZE_AVX512
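// Return packets with the first, second, third, or fourth scalar element of
// the input packet replicated across all lanes.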
template <>
EIGEN_STRONG_INLINE Packet16f
pbroadcast_first<Packet16f>(const Packet16f& a_in) {
  Packet4f a = _mm512_castps512_ps128(a_in);
  return _mm512_broadcastss_ps(a);
}
template <>
EIGEN_STRONG_INLINE Packet16f
pbroadcast_second<Packet16f>(const Packet16f& a_in) {
  Packet4f a = _mm512_castps512_ps128(a_in);
  return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(1, 1, 1, 1)));
}
template <>
EIGEN_STRONG_INLINE Packet16f
pbroadcast_third<Packet16f>(const Packet16f& a_in) {
  Packet4f a = _mm512_castps512_ps128(a_in);
  return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(2, 2, 2, 2)));
}
template <>
EIGEN_STRONG_INLINE Packet16f
pbroadcast_fourth<Packet16f>(const Packet16f& a_in) {
  Packet4f a = _mm512_castps512_ps128(a_in);
  return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(3, 3, 3, 3)));
}
template <>
EIGEN_STRONG_INLINE Packet8d pbroadcast_first<Packet8d>(const Packet8d& a_in) {
  Packet2d a = _mm512_castpd512_pd128(a_in);
  return _mm512_broadcastsd_pd(a);
}
template <>
EIGEN_STRONG_INLINE Packet8d pbroadcast_second<Packet8d>(const Packet8d& a_in) {
  Packet2d a = _mm_permute_pd(_mm512_castpd512_pd128(a_in), 3);
  return _mm512_broadcastsd_pd(a);
}
template <>
EIGEN_STRONG_INLINE Packet8d pbroadcast_third<Packet8d>(const Packet8d& a_in) {
  Packet2d a = _mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1);
  return _mm512_broadcastsd_pd(a);
}
template <>
EIGEN_STRONG_INLINE Packet8d pbroadcast_fourth<Packet8d>(const Packet8d& a_in) {
  Packet2d a =
      _mm_permute_pd(_mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1), 3);
  return _mm512_broadcastsd_pd(a);
}
template <>
EIGEN_STRONG_INLINE Packet16i
pbroadcast_first<Packet16i>(const Packet16i& a_in) {
  Packet4i a = _mm512_castsi512_si128(a_in);
  return _mm512_broadcastd_epi32(a);
}
template <>
EIGEN_STRONG_INLINE Packet16i
pbroadcast_second<Packet16i>(const Packet16i& a_in) {
  Packet4i a = _mm512_castsi512_si128(a_in);
  return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(1, 1, 1, 1)));
}
template <>
EIGEN_STRONG_INLINE Packet16i
pbroadcast_third<Packet16i>(const Packet16i& a_in) {
  Packet4i a = _mm512_castsi512_si128(a_in);
  return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(2, 2, 2, 2)));
}
template <>
EIGEN_STRONG_INLINE Packet16i
pbroadcast_fourth<Packet16i>(const Packet16i& a_in) {
  Packet4i a = _mm512_castsi512_si128(a_in);
  return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(3, 3, 3, 3)));
}
#endif

#ifdef EIGEN_VECTORIZE_AVX
// For a Packet of size 8 floats (256 bits), swap the 2nd and 3rd quadwords
template <>
EIGEN_STRONG_INLINE Packet8f pinterleave4x64<Packet8f>(const Packet8f& from) {
#ifdef EIGEN_VECTORIZE_AVX2
  return _mm256_castsi256_ps(_mm256_permute4x64_epi64(_mm256_castps_si256(from),
                                                      _MM_SHUFFLE(3, 1, 2, 0)));
#else
  auto tmp1 = _mm256_extract_epi32(_mm256_castps_si256(from), 2);
  auto tmp2 = _mm256_extract_epi32(_mm256_castps_si256(from), 3);
  auto tmp3 = _mm256_extract_epi32(_mm256_castps_si256(from), 4);
  auto tmp4 = _mm256_extract_epi32(_mm256_castps_si256(from), 5);
  auto tmp5 = _mm256_insert_epi32(_mm256_castps_si256(from), tmp1, 4);
  tmp5 = _mm256_insert_epi32(tmp5, tmp2, 5);
  tmp5 = _mm256_insert_epi32(tmp5, tmp3, 2);
  tmp5 = _mm256_insert_epi32(tmp5, tmp4, 3);
  return _mm256_castsi256_ps(tmp5);
#endif
}
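// Effect of pinterleave4x64 on the eight 32-bit lanes:
//   {a0, a1, a2, a3, a4, a5, a6, a7} -> {a0, a1, a4, a5, a2, a3, a6, a7}
// i.e. the middle two 64-bit quadwords trade places while the outer two
// stay put.
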
// Return a Packet with 4 floats loaded from 4 bfloat16 values
template <>
EIGEN_STRONG_INLINE Packet8f pload4bf16<Packet8f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from));
  return _mm256_castps128_ps256(
      _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
}
// Return a Packet with 2 floats loaded from 2 bfloat16 values
template <>
EIGEN_STRONG_INLINE Packet8f pload2bf16<Packet8f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castps_si128(_mm_load_ps1(from));
  return _mm256_castps128_ps256(
      _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
}

#ifdef EIGEN_VECTORIZE_AVX512
// Return a Packet with 4 floats loaded from 4 bfloat16 values
template <>
EIGEN_STRONG_INLINE Packet16f pload4bf16<Packet16f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from));
  return _mm512_castps128_ps512(
      _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
}
// Return a Packet with 2 floats loaded from 2 bfloat16 values
template <>
EIGEN_STRONG_INLINE Packet16f pload2bf16<Packet16f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castps_si128(_mm_load_ps1(from));
  return _mm512_castps128_ps512(
      _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
}
#endif

// For each 128-bit lane, convert the 4 bfloat16 values in the lower half of
// the lane to 4 float values
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet8f pexpand_bf16_l(const Packet8f& from) {
#ifdef EIGEN_VECTORIZE_AVX2
  __m256i zero = _mm256_setzero_si256();
  __m256i tmp = _mm256_castps_si256(from);
  return _mm256_castsi256_ps(_mm256_unpacklo_epi16(zero, tmp));
#else
  __m128i zero = _mm_setzero_si128();
  __m128i low = _mm_castps_si128(_mm256_extractf128_ps(from, 0));
  __m128i res_l = _mm_unpacklo_epi16(zero, low);
  __m128i high = _mm_castps_si128(_mm256_extractf128_ps(from, 1));
  __m128i res_h = _mm_unpacklo_epi16(zero, high);
  __m256 res = _mm256_castps128_ps256(_mm_castsi128_ps(res_l));
  res = _mm256_insertf128_ps(res, _mm_castsi128_ps(res_h), 1);
  return res;
#endif
}

// For each 128-bit lane, convert the 4 bfloat16 values in the upper half of
// the lane to 4 float values
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet8f pexpand_bf16_u(const Packet8f& from) {
#ifdef EIGEN_VECTORIZE_AVX2
  __m256i zero = _mm256_setzero_si256();
  __m256i tmp = _mm256_castps_si256(from);
  return _mm256_castsi256_ps(_mm256_unpackhi_epi16(zero, tmp));
#else
  __m128i zero = _mm_setzero_si128();
  __m128i low = _mm_castps_si128(_mm256_extractf128_ps(from, 0));
  __m128i res_l = _mm_unpackhi_epi16(zero, low);
  __m128i high = _mm_castps_si128(_mm256_extractf128_ps(from, 1));
  __m128i res_h = _mm_unpackhi_epi16(zero, high);
  __m256 res = _mm256_castps128_ps256(_mm_castsi128_ps(res_l));
  res = _mm256_insertf128_ps(res, _mm_castsi128_ps(res_h), 1);
  return res;
#endif
}
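
// Note: the AVX2 unpack instructions operate within each 128-bit lane, so
// these expansions are lane-local: pexpand_bf16_l returns floats for the
// lower four bfloat16 values of each lane and pexpand_bf16_u for the upper
// four. The non-AVX2 fallback reproduces the same lane-by-lane behaviour
// with SSE unpacks on each extracted half.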

// Return a packet with the first value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet8f pbroadcast_first<Packet8f>(const Packet8f& a) {
  return _mm256_set1_ps(pfirst<Packet8f>(a));
}

// Return a packet with the second value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet8f pbroadcast_second<Packet8f>(const Packet8f& a) {
  return _mm256_set1_ps(
      _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_permute_ps(a, 1))));
}

// Return a packet with the third value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet8f pbroadcast_third<Packet8f>(const Packet8f& a) {
  return _mm256_set1_ps(
      _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_permute_ps(a, 2))));
}

// Return a packet with the fourth value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet8f pbroadcast_fourth<Packet8f>(const Packet8f& a) {
  return _mm256_set1_ps(
      _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_permute_ps(a, 3))));
}

#endif

#ifdef EIGEN_VECTORIZE_AVX512

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_l(const Packet16f& from) {
  return _mm512_castsi512_ps(_mm512_slli_epi32(
      _mm512_cvtepu16_epi32(_mm512_castsi512_si256(_mm512_castps_si512(from))),
      16));
}

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_u(const Packet16f& from) {
  Packet16i tmp = _mm512_castps_si512(from);
  Packet16i tmp2 = _mm512_alignr_epi32(tmp, tmp, 8);
  return _mm512_castsi512_ps(_mm512_slli_epi32(
      _mm512_cvtepu16_epi32(_mm512_castsi512_si256(tmp2)), 16));
}
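
// The AVX-512 variants avoid unpacks: _mm512_cvtepu16_epi32 zero-extends
// sixteen 16-bit values to 32 bits, and a left shift by 16 moves each
// bfloat16 into the upper half of its 32-bit lane, yielding the float32
// representation. pexpand_bf16_u first rotates the upper 256 bits of the
// input into the lower half with _mm512_alignr_epi32.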

#endif
}  // namespace internal
}  // namespace Eigen
#endif  // TENSORFLOW_CORE_KERNELS_SPARSE_MATMUL_OP_H_