/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_MATMUL_OP_H_
#define TENSORFLOW_CORE_KERNELS_SPARSE_MATMUL_OP_H_

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/types.h"

#if defined(PLATFORM_WINDOWS)
#include "tensorflow/core/platform/windows/intrinsics_port.h"
#endif

namespace Eigen {
namespace internal {

// Return the float representation of the bfloat16 value
// in the lower 16 bits of the input
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pexpand_bf16_l(const Packet& from) {
  tensorflow::uint32 tmp;
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  tmp = (reinterpret_cast<const tensorflow::uint32&>(from)) & 0xffff0000;
#else
  tmp = (reinterpret_cast<const tensorflow::uint32&>(from) << 16) & 0xffff0000;
#endif
  return reinterpret_cast<const float&>(tmp);
}
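// Illustrative example (little-endian): if the input float's bit pattern is
// 0x3f804049, the lower 16 bits hold bfloat16 0x4049; shifting left by 16
// yields 0x40490000, i.e. 3.140625f (pi truncated to bfloat16 precision).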

// Return the float representation of the bfloat16 value
// in the upper 16 bits of the input
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pexpand_bf16_u(const Packet& from) {
  tensorflow::uint32 tmp;
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  tmp = (reinterpret_cast<const tensorflow::uint32&>(from) << 16) & 0xffff0000;
#else
  tmp = (reinterpret_cast<const tensorflow::uint32&>(from)) & 0xffff0000;
#endif
  return reinterpret_cast<const float&>(tmp);
}
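// Continuing the example above: for input bits 0x3f804049, the upper 16 bits
// hold bfloat16 0x3f80, which expands to 0x3f800000, i.e. 1.0f.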

// Specializations of the non-scalar versions above for non-SSE targets.
// Enables vectorization on z13 and higher.
#if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX) || \
    defined(EIGEN_VECTORIZE_NEON) || defined(EIGEN_VECTORIZE_ZVECTOR)
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_l(const Packet4f& from) {
  float r[4];
  tensorflow::uint32 p[4];
  pstoreu(r, from);
  tensorflow::uint32* ir = reinterpret_cast<tensorflow::uint32*>(r);
  p[0] = (ir[0] << 16) & 0xffff0000;
  p[1] = ir[0] & 0xffff0000;
  p[2] = (ir[1] << 16) & 0xffff0000;
  p[3] = ir[1] & 0xffff0000;
  return ploadu<Packet4f>(reinterpret_cast<float*>(p));
}

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_u(const Packet4f& from) {
  float r[4];
  tensorflow::uint32 p[4];
  pstoreu(r, from);
  tensorflow::uint32* ir = reinterpret_cast<tensorflow::uint32*>(r);
  p[0] = (ir[2] << 16) & 0xffff0000;
  p[1] = ir[2] & 0xffff0000;
  p[2] = (ir[3] << 16) & 0xffff0000;
  p[3] = ir[3] & 0xffff0000;
  return ploadu<Packet4f>(reinterpret_cast<float*>(p));
}
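// Layout note for the two functions above: a Packet4f viewed this way carries
// 8 bfloat16 values b0..b7 (two per 32-bit word, low half first on
// little-endian). pexpand_bf16_l expands b0..b3 from words 0-1 and
// pexpand_bf16_u expands b4..b7 from words 2-3.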
#endif

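// Generic scalar fallbacks. pinterleave4x64 and pbroadcast_first are
// well-defined identities for scalars; the remaining overloads assert if they
// are ever instantiated, since they only make sense for multi-element packets.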
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pinterleave4x64(const Packet& from) {
  return from;
}

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pbroadcast_first(const Packet& a) {
  return a;
}

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pbroadcast_second(const Packet& a) {
  assert(false && "Not applicable to Scalar Values");
  return a;
}

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pbroadcast_third(const Packet& a) {
  assert(false && "Not applicable to Scalar Values");
  return a;
}

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pbroadcast_fourth(const Packet& a) {
  assert(false && "Not applicable to Scalar Values");
  return a;
}

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pload4bf16(
    const typename unpacket_traits<Packet>::type* from) {
  assert(false && "Not applicable to Scalar Values");
  return Packet();
}

template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pload2bf16(
    const typename unpacket_traits<Packet>::type* from) {
  assert(false && "Not applicable to Scalar Values");
  return Packet();
}
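
// Note on the pload*bf16 contract: `from` points at packed bfloat16 data that
// is merely typed as float*. pload4bf16 reads 64 bits (four bfloat16 values)
// and expands them into four floats in the low 128 bits of the packet;
// pload2bf16 reads 32 bits (two bfloat16 values) and expands them likewise,
// with the pair duplicated. Upper lanes of wider packets are unspecified.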

// Specializations of pload4bf16 and pload2bf16 for non-SSE targets.
// Enables vectorization on z13 and higher.
#if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX) || \
    defined(EIGEN_VECTORIZE_NEON) || defined(EIGEN_VECTORIZE_ZVECTOR)
template <>
EIGEN_STRONG_INLINE Packet4f pload4bf16<Packet4f>(const float* from) {
  tensorflow::uint32 p[4];
  const tensorflow::uint32* ir =
      reinterpret_cast<const tensorflow::uint32*>(from);
  p[0] = (ir[0] << 16) & 0xffff0000;
  p[1] = ir[0] & 0xffff0000;
  p[2] = (ir[1] << 16) & 0xffff0000;
  p[3] = ir[1] & 0xffff0000;
  return ploadu<Packet4f>(reinterpret_cast<float*>(p));
}

template <>
EIGEN_STRONG_INLINE Packet4f pload2bf16<Packet4f>(const float* from) {
  tensorflow::uint32 p[4];
  const tensorflow::uint32* ir =
      reinterpret_cast<const tensorflow::uint32*>(from);
  p[0] = (ir[0] << 16) & 0xffff0000;
  p[1] = ir[0] & 0xffff0000;
  p[2] = (ir[0] << 16) & 0xffff0000;
  p[3] = ir[0] & 0xffff0000;
  return ploadu<Packet4f>(reinterpret_cast<float*>(p));
}
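// Example (little-endian): for memory holding bfloat16 values b0, b1,
// pload2bf16 returns the packet [f(b0), f(b1), f(b0), f(b1)], where f(b)
// denotes the float whose bit pattern is b << 16.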
#endif

#if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX)
// Return a packet with the first value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_first<Packet4f>(const Packet4f& a) {
  return vec_splat(a, 0);
}

// Return a packet with the second value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_second<Packet4f>(const Packet4f& a) {
  return vec_splat(a, 1);
}

// Return a packet with the third value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_third<Packet4f>(const Packet4f& a) {
  return vec_splat(a, 2);
}

// Return a packet with the fourth value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_fourth<Packet4f>(const Packet4f& a) {
  return vec_splat(a, 3);
}
#endif

#ifdef EIGEN_VECTORIZE_SSE2
// For PacketSize of 4 floats the Packet is not modified
template <>
EIGEN_STRONG_INLINE Packet4f pinterleave4x64<Packet4f>(const Packet4f& from) {
  return from;
}

// Return a Packet with 4 floats loaded from 4 bfloat16 values
template <>
EIGEN_STRONG_INLINE Packet4f pload4bf16<Packet4f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from));
  return _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp));
}

// Return a Packet with 2 floats loaded from 2 bfloat16 values
template <>
EIGEN_STRONG_INLINE Packet4f pload2bf16<Packet4f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castps_si128(_mm_load_ps1(from));
  return _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp));
}
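// The two loads above share one trick: _mm_load_pd1 (resp. _mm_load_ps1)
// splats the 64-bit (resp. 32-bit) group of bfloat16 values across the
// register, and _mm_unpacklo_epi16 with zero then places each 16-bit value
// into the high half of a 32-bit lane, which is exactly bfloat16-to-float.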

// Return a Packet with 4 floats expanded from 4 bfloat16 values
// in the lower half of the 128-bit lane
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_l(const Packet4f& from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castps_si128(from);
  return _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp));
}

// Return a Packet with 4 floats expanded from 4 bfloat16 values
// in the upper half of the 128-bit lane
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_u(const Packet4f& from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castps_si128(from);
  return _mm_castsi128_ps(_mm_unpackhi_epi16(zero, tmp));
}

// Return a packet with the first value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_first<Packet4f>(const Packet4f& a) {
  return _mm_set1_ps(pfirst<Packet4f>(a));
}

// Return a packet with the second value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_second<Packet4f>(const Packet4f& a) {
  return _mm_set1_ps(_mm_cvtss_f32(_mm_shuffle_ps(a, a, 1)));
}

// Return a packet with the third value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_third<Packet4f>(const Packet4f& a) {
  return _mm_set1_ps(_mm_cvtss_f32(_mm_shuffle_ps(a, a, 2)));
}

// Return a packet with the fourth value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_fourth<Packet4f>(const Packet4f& a) {
  return _mm_set1_ps(_mm_cvtss_f32(_mm_shuffle_ps(a, a, 3)));
}
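
// Example for the four broadcasts above: for a = [a0 a1 a2 a3],
// pbroadcast_third(a) returns [a2 a2 a2 a2].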

#endif

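// AVX512 broadcast specializations: each replicates one of the first four
// scalars of the input packet across the entire 512-bit register.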
#ifdef EIGEN_VECTORIZE_AVX512
template <>
EIGEN_STRONG_INLINE Packet16f
pbroadcast_first<Packet16f>(const Packet16f& a_in) {
  Packet4f a = _mm512_castps512_ps128(a_in);
  return _mm512_broadcastss_ps(a);
}
template <>
EIGEN_STRONG_INLINE Packet16f
pbroadcast_second<Packet16f>(const Packet16f& a_in) {
  Packet4f a = _mm512_castps512_ps128(a_in);
  return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(1, 1, 1, 1)));
}
template <>
EIGEN_STRONG_INLINE Packet16f
pbroadcast_third<Packet16f>(const Packet16f& a_in) {
  Packet4f a = _mm512_castps512_ps128(a_in);
  return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(2, 2, 2, 2)));
}
template <>
EIGEN_STRONG_INLINE Packet16f
pbroadcast_fourth<Packet16f>(const Packet16f& a_in) {
  Packet4f a = _mm512_castps512_ps128(a_in);
  return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(3, 3, 3, 3)));
}
template <>
EIGEN_STRONG_INLINE Packet8d pbroadcast_first<Packet8d>(const Packet8d& a_in) {
  Packet2d a = _mm512_castpd512_pd128(a_in);
  return _mm512_broadcastsd_pd(a);
}
template <>
EIGEN_STRONG_INLINE Packet8d pbroadcast_second<Packet8d>(const Packet8d& a_in) {
  Packet2d a = _mm_permute_pd(_mm512_castpd512_pd128(a_in), 3);
  return _mm512_broadcastsd_pd(a);
}
template <>
EIGEN_STRONG_INLINE Packet8d pbroadcast_third<Packet8d>(const Packet8d& a_in) {
  Packet2d a = _mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1);
  return _mm512_broadcastsd_pd(a);
}
template <>
EIGEN_STRONG_INLINE Packet8d pbroadcast_fourth<Packet8d>(const Packet8d& a_in) {
  Packet2d a =
      _mm_permute_pd(_mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1), 3);
  return _mm512_broadcastsd_pd(a);
}
template <>
EIGEN_STRONG_INLINE Packet16i
pbroadcast_first<Packet16i>(const Packet16i& a_in) {
  Packet4i a = _mm512_castsi512_si128(a_in);
  return _mm512_broadcastd_epi32(a);
}
template <>
EIGEN_STRONG_INLINE Packet16i
pbroadcast_second<Packet16i>(const Packet16i& a_in) {
  Packet4i a = _mm512_castsi512_si128(a_in);
  return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(1, 1, 1, 1)));
}
template <>
EIGEN_STRONG_INLINE Packet16i
pbroadcast_third<Packet16i>(const Packet16i& a_in) {
  Packet4i a = _mm512_castsi512_si128(a_in);
  return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(2, 2, 2, 2)));
}
template <>
EIGEN_STRONG_INLINE Packet16i
pbroadcast_fourth<Packet16i>(const Packet16i& a_in) {
  Packet4i a = _mm512_castsi512_si128(a_in);
  return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(3, 3, 3, 3)));
}
#endif

#ifdef EIGEN_VECTORIZE_AVX
// For a Packet of 8 floats (256 bits), swap the 2nd and 3rd quadwords
template <>
EIGEN_STRONG_INLINE Packet8f pinterleave4x64<Packet8f>(const Packet8f& from) {
#ifdef EIGEN_VECTORIZE_AVX2
  return _mm256_castsi256_ps(_mm256_permute4x64_epi64(_mm256_castps_si256(from),
                                                      _MM_SHUFFLE(3, 1, 2, 0)));
#else
  auto tmp1 = _mm256_extract_epi32(_mm256_castps_si256(from), 2);
  auto tmp2 = _mm256_extract_epi32(_mm256_castps_si256(from), 3);
  auto tmp3 = _mm256_extract_epi32(_mm256_castps_si256(from), 4);
  auto tmp4 = _mm256_extract_epi32(_mm256_castps_si256(from), 5);
  auto tmp5 = _mm256_insert_epi32(_mm256_castps_si256(from), tmp1, 4);
  tmp5 = _mm256_insert_epi32(tmp5, tmp2, 5);
  tmp5 = _mm256_insert_epi32(tmp5, tmp3, 2);
  tmp5 = _mm256_insert_epi32(tmp5, tmp4, 3);
  return _mm256_castsi256_ps(tmp5);
#endif
}
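// Example: for from = [f0 f1 f2 f3 f4 f5 f6 f7], pinterleave4x64 returns
// [f0 f1 f4 f5 f2 f3 f6 f7]; both the AVX2 permute and the insert/extract
// fallback produce this same quadword swap.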
// Return a Packet with 4 floats loaded from 4 bfloat16 values
template <>
EIGEN_STRONG_INLINE Packet8f pload4bf16<Packet8f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from));
  return _mm256_castps128_ps256(
      _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
}
// Return a Packet with 2 floats loaded from 2 bfloat16 values
template <>
EIGEN_STRONG_INLINE Packet8f pload2bf16<Packet8f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castps_si128(_mm_load_ps1(from));
  return _mm256_castps128_ps256(
      _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
}

#ifdef EIGEN_VECTORIZE_AVX512
// Return a Packet with 4 floats loaded from 4 bfloat16 values
template <>
EIGEN_STRONG_INLINE Packet16f pload4bf16<Packet16f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from));
  return _mm512_castps128_ps512(
      _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
}
// Return a Packet with 2 floats loaded from 2 bfloat16 values
template <>
EIGEN_STRONG_INLINE Packet16f pload2bf16<Packet16f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castps_si128(_mm_load_ps1(from));
  return _mm512_castps128_ps512(
      _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
}
#endif

// For each 128-bit lane, convert the 4 bfloat16 values in the lower half of
// the lane to 4 float values
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet8f pexpand_bf16_l(const Packet8f& from) {
#ifdef EIGEN_VECTORIZE_AVX2
  __m256i zero = _mm256_setzero_si256();
  __m256i tmp = _mm256_castps_si256(from);
  return _mm256_castsi256_ps(_mm256_unpacklo_epi16(zero, tmp));
#else
  __m128i zero = _mm_setzero_si128();
  __m128i low = _mm_castps_si128(_mm256_extractf128_ps(from, 0));
  __m128i res_l = _mm_unpacklo_epi16(zero, low);
  __m128i high = _mm_castps_si128(_mm256_extractf128_ps(from, 1));
  __m128i res_h = _mm_unpacklo_epi16(zero, high);
  __m256 res = _mm256_castps128_ps256(_mm_castsi128_ps(res_l));
  res = _mm256_insertf128_ps(res, _mm_castsi128_ps(res_h), 1);
  return res;
#endif
}

// For each 128-bit lane, convert the 4 bfloat16 values in the upper half of
// the lane to 4 float values
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet8f pexpand_bf16_u(const Packet8f& from) {
#ifdef EIGEN_VECTORIZE_AVX2
  __m256i zero = _mm256_setzero_si256();
  __m256i tmp = _mm256_castps_si256(from);
  return _mm256_castsi256_ps(_mm256_unpackhi_epi16(zero, tmp));
#else
  __m128i zero = _mm_setzero_si128();
  __m128i low = _mm_castps_si128(_mm256_extractf128_ps(from, 0));
  __m128i res_l = _mm_unpackhi_epi16(zero, low);
  __m128i high = _mm_castps_si128(_mm256_extractf128_ps(from, 1));
  __m128i res_h = _mm_unpackhi_epi16(zero, high);
  __m256 res = _mm256_castps128_ps256(_mm_castsi128_ps(res_l));
  res = _mm256_insertf128_ps(res, _mm_castsi128_ps(res_h), 1);
  return res;
#endif
}
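// Lane layout note for the two expansions above: a Packet8f holds 16 bfloat16
// values b0..b15 (eight per 128-bit lane). pexpand_bf16_l expands b0..b3 and
// b8..b11; pexpand_bf16_u expands b4..b7 and b12..b15.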

// Return a packet with the first value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet8f pbroadcast_first<Packet8f>(const Packet8f& a) {
  return _mm256_set1_ps(pfirst<Packet8f>(a));
}

// Return a packet with the second value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet8f pbroadcast_second<Packet8f>(const Packet8f& a) {
  return _mm256_set1_ps(
      _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_permute_ps(a, 1))));
}

// Return a packet with the third value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet8f pbroadcast_third<Packet8f>(const Packet8f& a) {
  return _mm256_set1_ps(
      _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_permute_ps(a, 2))));
}

// Return a packet with the fourth value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet8f pbroadcast_fourth<Packet8f>(const Packet8f& a) {
  return _mm256_set1_ps(
      _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_permute_ps(a, 3))));
}

#endif

#ifdef EIGEN_VECTORIZE_AVX512

// Return a Packet with 16 floats expanded from the 16 bfloat16 values in the
// lower 256 bits of the input
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_l(const Packet16f& from) {
  return _mm512_castsi512_ps(_mm512_slli_epi32(
      _mm512_cvtepu16_epi32(_mm512_castsi512_si256(_mm512_castps_si512(from))),
      16));
}

// Return a Packet with 16 floats expanded from the 16 bfloat16 values in the
// upper 256 bits of the input
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_u(const Packet16f& from) {
  Packet16i tmp = _mm512_castps_si512(from);
  Packet16i tmp2 = _mm512_alignr_epi32(tmp, tmp, 8);
  return _mm512_castsi512_ps(_mm512_slli_epi32(
      _mm512_cvtepu16_epi32(_mm512_castsi512_si256(tmp2)), 16));
}

#endif
}  // namespace internal
}  // namespace Eigen
#endif  // TENSORFLOW_CORE_KERNELS_SPARSE_MATMUL_OP_H_