/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/sparse_matmul_op.h"

#include <map>
#include <memory>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#ifdef TENSORFLOW_USE_LIBXSMM
#include "include/libxsmm_intrinsics_x86.h"
#include "include/libxsmm_malloc.h"
#include "include/libxsmm_spmdm.h"
#endif

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#define ALWAYS_INLINE EIGEN_ALWAYS_INLINE

namespace tensorflow {
namespace {

template <typename T>
using BasicMatrix = Eigen::Tensor<T, 2, Eigen::RowMajor>;

template <typename T>
using BasicMatrixMap =
    Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Aligned>;

using Matrix = BasicMatrix<float>;
using MatrixMap = BasicMatrixMap<float>;
using CPUDevice = Eigen::ThreadPoolDevice;
using DSizes = Eigen::DSizes<Eigen::DenseIndex, 2>;

// Two commonly used static dsizes. We use Eigen::type2index to allow as much
// compile time optimization as possible.
#ifdef EIGEN_HAS_INDEX_LIST
inline Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<0>>
dsizes_00() {
  return Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<0>>();
}
inline Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<0>>
dsizes_10() {
  return Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<0>>();
}
#else
inline DSizes dsizes_00() { return DSizes(0, 0); }
inline DSizes dsizes_10() { return DSizes(1, 0); }
#endif

// Blocksizes
// TODO(agarwal): compute these sizes based on cache sizes.
const int K = 64;
const int M = 64;
const int N = 128;
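// In GEPP below, M bounds the number of output rows handled per block
// (out_ptrs[M]), K bounds the number of rows of the right matrix addressed
// per sparse block (right_ptrs[K]), and N is the number of columns of the
// right matrix packed into each dense slice (see CreateDenseSlices).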

// This stores a sparse representation of a slice of a matrix with size
// (num_rows, num_cols). The slice is represented as a series of blocks of size
// (num_rows, b), where b = block_size for all but the last block, which may
// have fewer columns.
//
// num_rows and block_size are assumed to be <= 256. This allows storing the
// indices as uint8.
//
// For each block, we store all the nonzero entries in the data/data3 vectors
// and the corresponding coordinates of the elements in the index/index3
// vectors. The index3 vector stores the indices of 3 elements in the same row
// so that these elements can share the same row coordinate. Each entry in
// index3 corresponds to 3 entries in data3.
//
// Note that the data/indices of all the blocks are stored in the same vectors
// respectively. To identify block boundaries, we store the block offsets using
// index3_offset/index_offset. If there are n blocks in the slice, index3_offset
// and index_offset have n entries. The index3 entries for the ith block are
// those at positions [index3_offset[i-1], index3_offset[i]) of the index3
// vector (with index3_offset[-1] taken to be 0). Similarly for index_offset.
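//
// For example, if row m of a block has nonzero values a, b, c, d at columns
// 1, 4, 5 and 9, Initialize() below emits one Index3 entry {m, 1, 4, 5} with
// (a, b, c) appended to data3, and the remaining single nonzero is recorded
// in index/data as {m, 9} with value d. Trailing groups of one or two
// nonzeros in a row always spill into index/data in this way.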
template <typename T>
struct SparseSlice {
  using ConstMatrixMap = BasicMatrixMap<const T>;

 public:
  // Indices of three elements on the same row.
  struct Index3 {
    uint8 m;  // row
    // columns
    uint8 k1;
    uint8 k2;
    uint8 k3;
  };

  // Index of one element.
  struct Index {
    uint8 m;
    uint8 k;
  };

  SparseSlice(int nrows, int ncols, int bsize)
      : num_rows(nrows), num_cols(ncols), block_size(bsize) {
    DCHECK_LE(nrows, 256);
    DCHECK_LE(block_size, 256);
  }

  // Initializes the slice with data starting at mat(0, col_offset) and with
  // size (num_rows, num_cols).
  // If Transpose is true, implicitly transposes mat.
  template <bool Transpose = false>
  void Initialize(const ConstMatrixMap& mat, int col_offset);

  void Clear();

  // See comments above.
  std::vector<int> index3_offset;
  std::vector<Index3> index3;
  std::vector<T> data3;

  // See comments above. Similar to "index3" except that each element in
  // "index" corresponds to one element in data.
  std::vector<int> index_offset;
  std::vector<Index> index;
  std::vector<T> data;

  // Number of rows and columns for the slice.
  const int num_rows;
  const int num_cols;

  // Block size used to initialize from a matrix.
  const int block_size;
};

template <typename T>
bool IsZero(T v);

template <>
ALWAYS_INLINE bool IsZero(bfloat16 v) {
  return v.IsZero();
}

template <>
ALWAYS_INLINE bool IsZero(float v) {
  return v == 0.0f;
}

template <typename T>
template <bool Transpose>
void SparseSlice<T>::Initialize(
    const typename SparseSlice<T>::ConstMatrixMap& mat, int col_offset) {
  const int mat_rows = Transpose ? mat.dimension(1) : mat.dimension(0);
  const int mat_cols = Transpose ? mat.dimension(0) : mat.dimension(1);
  DCHECK_LE(num_rows, mat_rows);
  DCHECK_LE(num_cols + col_offset, mat_cols);

  int num_blocks = (num_cols + block_size - 1) / block_size;
  int mat_size = num_rows * num_cols;

  index3_offset.reserve(num_blocks);
  data3.reserve(mat_size);
  index3.reserve(mat_size / 3);

  index_offset.reserve(num_blocks);
  data.reserve(num_blocks * num_rows * 2);
  index.reserve(num_blocks * num_rows * 2);

  Index3 idx3;
  const int stride = Transpose ? mat.dimension(1) : 1;

  for (int i = 0; i < num_blocks; ++i) {
    int num_block_cols = std::min(block_size, num_cols - block_size * i);
    for (int row = 0; row < num_rows; ++row) {
      idx3.m = static_cast<uint8>(row);
      // Safety note: The following code has a race, since it checks whether
      // *curr is nonzero and then reads it again on use. However, the result
      // of the race is only that some of the "nonzeros" in the resulting
      // sparse representation may actually be zero, which is harmless.
      const auto* start =
          Transpose ? &mat(col_offset, row) : &mat(row, col_offset);
      const auto* curr = start;
      const auto* end = start + stride * num_block_cols;
      uint8 k = 0;
#define NEXT_ELEM \
  curr += stride; \
  ++k;
#define EAT_ZEROS                          \
  while (curr < end && IsZero<T>(*curr)) { \
    NEXT_ELEM;                             \
  }
      while (true) {
        EAT_ZEROS
        if (curr >= end) break;
        idx3.k1 = k;
        const T value1 = *curr;
        NEXT_ELEM;

        EAT_ZEROS
        if (curr >= end) {
          data.push_back(value1);
          index.push_back({idx3.m, idx3.k1});
          break;
        }
        idx3.k2 = k;
        const T value2 = *curr;
        NEXT_ELEM;

        EAT_ZEROS
        if (curr >= end) {
          data.push_back(value2);
          index.push_back({idx3.m, idx3.k2});
          data.push_back(value1);
          index.push_back({idx3.m, idx3.k1});
          break;
        }
        idx3.k3 = k;
        data3.push_back(value1);
        data3.push_back(value2);
        data3.push_back(*curr);
        NEXT_ELEM;
        index3.push_back(idx3);
#undef NEXT_ELEM
#undef EAT_ZEROS
      }
    }
    col_offset += block_size;
    index3_offset.push_back(index3.size());
    index_offset.push_back(index.size());
  }
  DCHECK_EQ(index3_offset.size(), num_blocks);
  DCHECK_EQ(index_offset.size(), num_blocks);
  DCHECK_EQ(3 * index3.size(), data3.size());
  DCHECK_EQ(index.size(), data.size());
}

template <typename T>
void SparseSlice<T>::Clear() {
  index3_offset.clear();
  index3.clear();
  data3.clear();
  index_offset.clear();
  index.clear();
  data.clear();
}

using Packet = Eigen::internal::packet_traits<float>::type;
const int kNumOperands = (sizeof(Packet) / sizeof(float));
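// kNumOperands is the number of floats in one SIMD packet, e.g. 4 with SSE,
// 8 with AVX and 16 with AVX-512.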
#define LOAD(x) Eigen::internal::pload<Packet>(x);
#define EXPAND_BFLOAT_L(x, y) \
  const auto y = Eigen::internal::pexpand_bf16_l<Packet>(x);
#define EXPAND_BFLOAT_U(x, y) \
  const auto y = Eigen::internal::pexpand_bf16_u<Packet>(x);
#define STORE(x, y) Eigen::internal::pstore<float>(x, y);
#define FMA(a, b, c, d) d = Eigen::internal::pmadd<Packet>(a, b, c);

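// A bfloat16 holds the upper 16 bits of an IEEE float32, so converting a
// single value to float amounts to placing those 16 bits in the high half of
// a zero-initialized float (element 1 below on little-endian machines).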
ALWAYS_INLINE float ConvertBfloat16ToFloat(const bfloat16* src) {
  float out = 0;
  auto tmp = reinterpret_cast<bfloat16*>(&out);
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  tmp[0] = *src;
#else
  tmp[1] = *src;
#endif
  return out;
}

ALWAYS_INLINE Packet ConvertFourBfloat16ToFloat(const bfloat16* src) {
  return Eigen::internal::pload4bf16<Packet>(
      reinterpret_cast<const float*>(src));
}

ALWAYS_INLINE Packet ConvertTwoBfloat16ToFloat(const bfloat16* src) {
  return Eigen::internal::pload2bf16<Packet>(
      reinterpret_cast<const float*>(src));
}

ALWAYS_INLINE void ScalarMulAdd(const float a, const float** inp, float** out) {
  **out += a * **inp;
  ++*inp;
  ++*out;
}

ALWAYS_INLINE void ScalarMulAdd(const float a, const bfloat16** inp,
                                float** out) {
  float inp_f = ConvertBfloat16ToFloat(*inp);
  **out += a * inp_f;
  ++*inp;
  ++*out;
}

ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2,
                                    const float a3, const bfloat16** inp1,
                                    const bfloat16** inp2,
                                    const bfloat16** inp3, float** out) {
  float inp1_f = ConvertBfloat16ToFloat(*inp1);
  float inp2_f = ConvertBfloat16ToFloat(*inp2);
  float inp3_f = ConvertBfloat16ToFloat(*inp3);
  **out += a1 * inp1_f + a2 * inp2_f + a3 * inp3_f;
  ++*out;
  ++*inp1;
  ++*inp2;
  ++*inp3;
}

ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2,
                                    const float a3, const float** inp1,
                                    const float** inp2, const float** inp3,
                                    float** out) {
  **out += a1 * **inp1 + a2 * **inp2 + a3 * **inp3;
  ++*out;
  ++*inp1;
  ++*inp2;
  ++*inp3;
}

ALWAYS_INLINE void LoadSingleScalar(const bfloat16** data, Packet* l) {
  auto tmp = ConvertBfloat16ToFloat(*data);
  *l = Eigen::internal::pset1<Packet>(tmp);
  ++*data;
}

ALWAYS_INLINE void LoadTwoScalars(const bfloat16** data, Packet* l1,
                                  Packet* l2) {
  if (kNumOperands >= 2) {
    auto tmp = ConvertTwoBfloat16ToFloat(*data);
    *l1 = Eigen::internal::pbroadcast_first<Packet>(tmp);
    *l2 = Eigen::internal::pbroadcast_second<Packet>(tmp);
    *data += 2;
  } else {
    LoadSingleScalar(data, l1);
    LoadSingleScalar(data, l2);
  }
}

ALWAYS_INLINE void LoadFourScalars(const bfloat16** data, Packet* l1,
                                   Packet* l2, Packet* l3, Packet* l4) {
  if (kNumOperands >= 4) {
    auto tmp = ConvertFourBfloat16ToFloat(*data);
    *l1 = Eigen::internal::pbroadcast_first<Packet>(tmp);
    *l2 = Eigen::internal::pbroadcast_second<Packet>(tmp);
    *l3 = Eigen::internal::pbroadcast_third<Packet>(tmp);
    *l4 = Eigen::internal::pbroadcast_fourth<Packet>(tmp);
    *data += 4;
  } else {
    LoadTwoScalars(data, l1, l2);
    LoadTwoScalars(data, l3, l4);
  }
}

ALWAYS_INLINE void LoadSingleScalar(const float** data, Packet* l) {
  *l = Eigen::internal::pload1<Packet>(*data);
  ++(*data);
}

ALWAYS_INLINE void LoadTwoScalars(const float** data, Packet* l1, Packet* l2) {
  LoadSingleScalar(data, l1);
  LoadSingleScalar(data, l2);
}

ALWAYS_INLINE void LoadFourScalars(const float** data, Packet* l1, Packet* l2,
                                   Packet* l3, Packet* l4) {
  LoadTwoScalars(data, l1, l2);
  LoadTwoScalars(data, l3, l4);
}

template <typename T>
ALWAYS_INLINE void LoadThreeScalars(const T** data, Packet* l1, Packet* l2,
                                    Packet* l3) {
  LoadTwoScalars(data, l1, l2);
  LoadSingleScalar(data, l3);
}

template <typename T>
ALWAYS_INLINE void LoadSixScalars(const T** data, Packet* l1, Packet* l2,
                                  Packet* l3, Packet* l4, Packet* l5,
                                  Packet* l6) {
  LoadFourScalars(data, l1, l2, l3, l4);
  LoadTwoScalars(data, l5, l6);
}

// Vectorized version of ScalarMulAdd.
ALWAYS_INLINE void MulAdd(const Packet a, const bfloat16** binp, float** out) {
  auto inp = reinterpret_cast<const float*>(*binp);
  const auto b = LOAD(inp);
  EXPAND_BFLOAT_L(b, b_0);
  EXPAND_BFLOAT_U(b, b_1);
  *binp += 2 * kNumOperands;
  auto c1 = LOAD(*out);
  auto c2 = LOAD(*out + kNumOperands);
  FMA(a, b_0, c1, c1);
  FMA(a, b_1, c2, c2);
  STORE(*out, c1);
  STORE(*out + kNumOperands, c2);
  *out += 2 * kNumOperands;
}

// Vectorized version of ScalarMulAdd3Way.
ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3,
                              const bfloat16** binp1, const bfloat16** binp2,
                              const bfloat16** binp3, float** out) {
  auto inp1 = reinterpret_cast<const float*>(*binp1);
  auto inp2 = reinterpret_cast<const float*>(*binp2);
  auto inp3 = reinterpret_cast<const float*>(*binp3);
  auto c1 = LOAD(*out);
  auto c2 = LOAD(*out + kNumOperands);
  const auto b1 = LOAD(inp1);
  EXPAND_BFLOAT_L(b1, b1_0);
  EXPAND_BFLOAT_U(b1, b1_1);
  *binp1 += 2 * kNumOperands;
  const auto b2 = LOAD(inp2);
  EXPAND_BFLOAT_L(b2, b2_0);
  EXPAND_BFLOAT_U(b2, b2_1);
  *binp2 += 2 * kNumOperands;
  const auto b3 = LOAD(inp3);
  EXPAND_BFLOAT_L(b3, b3_0);
  EXPAND_BFLOAT_U(b3, b3_1);
  *binp3 += 2 * kNumOperands;
  FMA(a1, b1_0, c1, c1);
  FMA(a1, b1_1, c2, c2);
  FMA(a2, b2_0, c1, c1);
  FMA(a2, b2_1, c2, c2);
  FMA(a3, b3_0, c1, c1);
  FMA(a3, b3_1, c2, c2);
  STORE(*out, c1);
  STORE(*out + kNumOperands, c2);
  *out += 2 * kNumOperands;
}

// Unroll MulAdd3Way for two iterations
ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2,
                                 const Packet a3, const bfloat16** binp1,
                                 const bfloat16** binp2, const bfloat16** binp3,
                                 float** out) {
  auto inp1 = reinterpret_cast<const float*>(*binp1);
  auto inp2 = reinterpret_cast<const float*>(*binp2);
  auto inp3 = reinterpret_cast<const float*>(*binp3);
  auto c1 = LOAD(*out);
  auto c2 = LOAD(*out + kNumOperands);
  const auto b1 = LOAD(inp1);
  const auto b2 = LOAD(inp2);
  const auto b3 = LOAD(inp3);

  EXPAND_BFLOAT_L(b1, b1_0);
  EXPAND_BFLOAT_U(b1, b1_1);
  EXPAND_BFLOAT_L(b2, b2_0);
  EXPAND_BFLOAT_U(b2, b2_1);
  EXPAND_BFLOAT_L(b3, b3_0);
  EXPAND_BFLOAT_U(b3, b3_1);
  auto c3 = LOAD(*out + 2 * kNumOperands);
  auto c4 = LOAD(*out + 3 * kNumOperands);
  const auto b4 = LOAD(inp1 + kNumOperands);
  const auto b5 = LOAD(inp2 + kNumOperands);
  const auto b6 = LOAD(inp3 + kNumOperands);

  EXPAND_BFLOAT_L(b4, b4_0);
  EXPAND_BFLOAT_U(b4, b4_1);
  EXPAND_BFLOAT_L(b5, b5_0);
  EXPAND_BFLOAT_U(b5, b5_1);
  EXPAND_BFLOAT_L(b6, b6_0);
  EXPAND_BFLOAT_U(b6, b6_1);

  FMA(a1, b1_0, c1, c1);
  FMA(a1, b1_1, c2, c2);
  FMA(a1, b4_0, c3, c3);
  FMA(a1, b4_1, c4, c4);
  FMA(a2, b2_0, c1, c1);
  FMA(a2, b2_1, c2, c2);
  FMA(a2, b5_0, c3, c3);
  FMA(a2, b5_1, c4, c4);
  FMA(a3, b3_0, c1, c1);
  FMA(a3, b3_1, c2, c2);
  FMA(a3, b6_0, c3, c3);
  FMA(a3, b6_1, c4, c4);
  STORE(*out, c1);
  STORE(*out + kNumOperands, c2);
  STORE(*out + 2 * kNumOperands, c3);
  STORE(*out + 3 * kNumOperands, c4);
  *out += 4 * kNumOperands;
  *binp1 += 4 * kNumOperands;
  *binp2 += 4 * kNumOperands;
  *binp3 += 4 * kNumOperands;
}

// Apply MulAdd3Way on 128 operands.
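// Each TwoMulAdd3Way call above consumes 4 * kNumOperands output columns, so
// the two calls per iteration below cover 8 * kNumOperands of the 128 columns.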
ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2,
                                 const Packet a3, const bfloat16** inp1,
                                 const bfloat16** inp2, const bfloat16** inp3,
                                 float** out) {
  for (int k = 0; k < 128 / (8 * kNumOperands); ++k) {
    TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
    TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
  }
}

// Vectorized version of ScalarMulAdd
ALWAYS_INLINE void MulAdd(const Packet a, const float** inp, float** out) {
  const auto b = LOAD(*inp);
  *inp += kNumOperands;
  auto c = LOAD(*out);
  FMA(a, b, c, c);
  STORE(*out, c);
  *out += kNumOperands;
}

// Vectorized version of ScalarMulAdd3Way
ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3,
                              const float** inp1, const float** inp2,
                              const float** inp3, float** out) {
  auto c = LOAD(*out);
  const auto b1 = LOAD(*inp1);
  *inp1 += kNumOperands;
  const auto b2 = LOAD(*inp2);
  *inp2 += kNumOperands;
  const auto b3 = LOAD(*inp3);
  *inp3 += kNumOperands;
  FMA(a1, b1, c, c);
  FMA(a2, b2, c, c);
  FMA(a3, b3, c, c);
  STORE(*out, c);
  *out += kNumOperands;
}

// Unroll MulAdd3Way for two iterations
ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2,
                                 const Packet a3, const float** inp1,
                                 const float** inp2, const float** inp3,
                                 float** out) {
  auto c1 = LOAD(*out);
  const auto b1 = LOAD(*inp1);
  const auto b2 = LOAD(*inp2);
  const auto b3 = LOAD(*inp3);

  auto c2 = LOAD(*out + kNumOperands);
  const auto b4 = LOAD(*inp1 + kNumOperands);
  const auto b5 = LOAD(*inp2 + kNumOperands);
  const auto b6 = LOAD(*inp3 + kNumOperands);

  FMA(a1, b1, c1, c1);
  FMA(a1, b4, c2, c2);
  FMA(a2, b2, c1, c1);
  FMA(a2, b5, c2, c2);
  FMA(a3, b3, c1, c1);
  FMA(a3, b6, c2, c2);
  STORE(*out, c1);
  STORE(*out + kNumOperands, c2);
  *out += 2 * kNumOperands;
  *inp1 += 2 * kNumOperands;
  *inp2 += 2 * kNumOperands;
  *inp3 += 2 * kNumOperands;
}

// Unroll MulAdd3Way for four iterations
ALWAYS_INLINE void FourMulAdd3Way(const Packet a1, const Packet a2,
                                  const Packet a3, const float** inp1,
                                  const float** inp2, const float** inp3,
                                  float** out) {
  TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
  TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
}

// Apply MulAdd3Way on 128 operands.
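// When kNumOperands == 8 (AVX), the four FourMulAdd3Way calls below cover
// 4 * 4 * 8 = 128 columns exactly; otherwise the generic loop handles
// 4 * kNumOperands columns per iteration.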
ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2,
                                 const Packet a3, const float** inp1,
                                 const float** inp2, const float** inp3,
                                 float** out) {
  if (kNumOperands == 8) {
    FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
    FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
    FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
    FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
  } else {
    DCHECK_LE(4 * kNumOperands, 128);
    for (int i = 0; i < 128 / (4 * kNumOperands); ++i) {
      MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
      MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
      MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
      MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
    }
  }
}

// Computes the product of "left_slices" with "num_cols" columns of "right",
// and stores the output in *"output".
// Note that left_slices is a list of SparseSlices, which are conceptually
// assumed to be concatenated along the column dimension. Also, each
// SparseSlice is encoded as a list of blocks with up to N columns. See
// SparseSlice for more details.
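//
// GEPP processes one SparseSlice block at a time: it first walks the Index3
// triples (unrolled two at a time), then the remaining single-entry Index
// list (unrolled four at a time). The Cols template parameter lets the common
// num_cols == 128 case fold to a compile-time constant so the MulAdd3Way128
// fast path is taken.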
template <typename TL, typename TR, int Cols>
inline void GEPP(
    const std::vector<SparseSlice<TL>*>& left_slices,
    const Eigen::TensorMap<Eigen::Tensor<const TR, 2, Eigen::RowMajor>,
                           Eigen::Aligned>& right,
    const int num_cols, Matrix* output) {
  const int cols = (Cols == -1) ? num_cols : Cols;
  DCHECK_EQ(num_cols, cols);
  const int right_num_cols = right.dimension(1);
  const int output_num_cols = output->dimension(1);
  static const int kNumOperandsR = kNumOperands * sizeof(float) / sizeof(TR);
  const int cols_mod = cols % kNumOperandsR;
  int k_offset = 0;
  // Pre-compute pointers for output matrix.
  float* out_ptrs[M];
  float* const out_start = &(*output)(0, 0);
  for (int j = 0; j < M; ++j) {
    out_ptrs[j] = out_start + output_num_cols * j;
  }
  for (const auto* left_slice : left_slices) {
    const auto& left = *left_slice;
    const auto* data3 = (!left.data3.empty()) ? &left.data3[0] : nullptr;
    const auto* data = (!left.data.empty()) ? &left.data[0] : nullptr;
    const int num_blocks = left.index3_offset.size();
    int begin3 = 0;
    int begin = 0;
    for (int i = 0; i < num_blocks; ++i) {
      // Pre-compute pointers for right matrix
      const TR* right_ptrs[K];
      const auto* const right_start = &right(k_offset, 0);
      DCHECK_LT(k_offset, right.dimension(0));
      for (int j = 0; j < K; ++j) {
        right_ptrs[j] = right_start + right_num_cols * j;
      }

      const int end3 = left.index3_offset[i];
      int j = begin3;
      // Loop unrolled for 2 iterations.
      for (; j + 1 < end3; j += 2) {
        Packet l1, l2, l3, nl1, nl2, nl3;
        LoadSixScalars(&data3, &l1, &l2, &l3, &nl1, &nl2, &nl3);
        const auto& index = left.index3[j];
        const auto& nindex = left.index3[j + 1];
        float* out = out_ptrs[index.m];
        float* nout = out_ptrs[nindex.m];
        const auto* r1 = right_ptrs[index.k1];
        const auto* r2 = right_ptrs[index.k2];
        const auto* r3 = right_ptrs[index.k3];

        const auto* nr1 = right_ptrs[nindex.k1];
        const auto* nr2 = right_ptrs[nindex.k2];
        const auto* nr3 = right_ptrs[nindex.k3];
        if (cols == 128) {
          MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out);
          MulAdd3Way128(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout);
        } else {
          for (int n = 0; n < cols / kNumOperandsR; ++n) {
            MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out);
            MulAdd3Way(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout);
          }

          const float sl1 = Eigen::internal::pfirst<Packet>(l1);
          const float sl2 = Eigen::internal::pfirst<Packet>(l2);
          const float sl3 = Eigen::internal::pfirst<Packet>(l3);
          const float nsl1 = Eigen::internal::pfirst<Packet>(nl1);
          const float nsl2 = Eigen::internal::pfirst<Packet>(nl2);
          const float nsl3 = Eigen::internal::pfirst<Packet>(nl3);
          for (int k = 0; k < cols_mod; ++k) {
            ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out);
            ScalarMulAdd3Way(nsl1, nsl2, nsl3, &nr1, &nr2, &nr3, &nout);
          }
        }
      }
      if (j < end3) {
        Packet l1, l2, l3;
        LoadThreeScalars(&data3, &l1, &l2, &l3);

        const auto& index = left.index3[j];
        float* out = out_ptrs[index.m];
        const auto* r1 = right_ptrs[index.k1];
        const auto* r2 = right_ptrs[index.k2];
        const auto* r3 = right_ptrs[index.k3];
        if (cols == 128) {
          MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out);
        } else {
          for (int n = 0; n < cols / kNumOperandsR; ++n) {
            MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out);
          }
          const float sl1 = Eigen::internal::pfirst<Packet>(l1);
          const float sl2 = Eigen::internal::pfirst<Packet>(l2);
          const float sl3 = Eigen::internal::pfirst<Packet>(l3);
          for (int k = 0; k < cols_mod; ++k) {
            ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out);
          }
        }
      }
      begin3 = end3;
      int end = left.index_offset[i];
      // Loop unrolled for 4 iterations.
      j = begin;
      for (; j + 3 < end; j += 4) {
        Packet l, nl, n2l, n3l;
        LoadFourScalars(&data, &l, &nl, &n2l, &n3l);

        const auto& index = left.index[j];
        const auto& nindex = left.index[j + 1];
        const auto& n2index = left.index[j + 2];
        const auto& n3index = left.index[j + 3];
        const auto* r = right_ptrs[index.k];
        const auto* nr = right_ptrs[nindex.k];
        const auto* n2r = right_ptrs[n2index.k];
        const auto* n3r = right_ptrs[n3index.k];
        float* out = out_ptrs[index.m];
        float* nout = out_ptrs[nindex.m];
        float* n2out = out_ptrs[n2index.m];
        float* n3out = out_ptrs[n3index.m];

        for (int n = 0; n < cols / kNumOperandsR; ++n) {
          MulAdd(l, &r, &out);
          MulAdd(nl, &nr, &nout);
          MulAdd(n2l, &n2r, &n2out);
          MulAdd(n3l, &n3r, &n3out);
        }

        const float sl1 = Eigen::internal::pfirst<Packet>(l);
        const float sl2 = Eigen::internal::pfirst<Packet>(nl);
        const float sl3 = Eigen::internal::pfirst<Packet>(n2l);
        const float sl4 = Eigen::internal::pfirst<Packet>(n3l);
        for (int k = 0; k < cols_mod; ++k) {
          ScalarMulAdd(sl1, &r, &out);
          ScalarMulAdd(sl2, &nr, &nout);
          ScalarMulAdd(sl3, &n2r, &n2out);
          ScalarMulAdd(sl4, &n3r, &n3out);
        }
      }
      while (j < end) {
        Packet l;
        LoadSingleScalar(&data, &l);
        const auto& index = left.index[j];
        const auto* r = right_ptrs[index.k];
        float* out = out_ptrs[index.m];
        for (int n = 0; n < cols / kNumOperandsR; ++n) {
          MulAdd(l, &r, &out);
        }
        const float sl = Eigen::internal::pfirst<Packet>(l);
        for (int k = 0; k < cols_mod; ++k) {
          ScalarMulAdd(sl, &r, &out);
        }
        j++;
      }
      k_offset += left.block_size;
      begin = end;
    }
  }
}

#undef LOAD
#undef EXPAND_BFLOAT_L
#undef EXPAND_BFLOAT_U
#undef STORE
#undef FMA

}  // namespace

template <typename TL, typename TR>
class SparseMatMul {
  using MatrixL = BasicMatrix<TL>;
  using MatrixR = BasicMatrix<TR>;
  using ConstMatrixMapL = BasicMatrixMap<const TL>;
  using ConstMatrixMapR = BasicMatrixMap<const TR>;
  using MatrixMapR = BasicMatrixMap<TR>;

 public:
  // Not used; added to match the interface of LibxsmmSparseMatMul.
  struct TensorInfoCache {};

  // Perform matrix multiplication of "left" and "right", and store the result
  // in *"output".
 public:
  static inline void Compute(TensorInfoCache* cache,
                             const ConstMatrixMapL& left,
                             const ConstMatrixMapR& right, bool transpose_left,
                             const DeviceBase::CpuWorkerThreads* thread_pool,
                             bool transpose_output, MatrixMap* output);

 private:
  // Computes the multiplication of left and num_cols columns of right, and
  // stores the output block in *"output" at offsets "output_row_offset" and
  // "output_col_offset". If assign is true, assigns the value to that block,
  // else adds the values to the existing values.
  static inline void ComputeOutputBlock(
      const std::vector<SparseSlice<TL>*>& left, const ConstMatrixMapR& right,
      int num_cols, int output_row_offset, int output_col_offset, bool assign,
      bool transpose_output, MatrixMap* output);

  // Encodes "mat" using a sparse representation and stores that in
  // "mat_slices". "mat" is broken into a grid with sizes "slice_num_rows" and
  // "slice_num_cols"; each grid element is converted into a SparseSlice and
  // stored in mat_slices. "slice_block_size" is used to perform further column
  // blocking of each slice.
  static inline std::unique_ptr<BlockingCounter> CreateSparseSlices(
      const ConstMatrixMapL& mat, bool transpose, int slice_num_rows,
      int slice_block_size, int slice_num_cols,
      std::vector<std::vector<SparseSlice<TL>*>>* mat_slices,
      const DeviceBase::CpuWorkerThreads* thread_pool);

  // This function chops "mat" along the column dimension into pieces with at
  // most N columns, and concatenates the pieces one after the other in
  // "buffer". It returns the list of the pieces in "slices". It returns a
  // BlockingCounter which should be used to wait for the shuffle operations to
  // complete.
  static inline std::unique_ptr<BlockingCounter> CreateDenseSlices(
      const ConstMatrixMapR& mat, int row_start, int num_rows, int col_start,
      int num_cols, const DeviceBase::CpuWorkerThreads* thread_pool,
      MatrixR* buffer, std::vector<ConstMatrixMapR*>* slices);

  // Helper function for CreateDenseSlices to move the data around. It returns
  // a BlockingCounter which should be used to wait for the shuffle operations
  // to complete.
  static inline BlockingCounter* ShuffleMatrix(
      const ConstMatrixMapR& mat, int slice_row_start, int slice_num_rows,
      int slice_col_start, int slice_num_cols, const int N,
      const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer);

  // Helper function for CreateDenseSlices to create slices.
  static inline void SliceMatrix(const MatrixR& mat, const int num_rows,
                                 const int num_slices,
                                 std::vector<ConstMatrixMapR*>* slices);

  // Heuristics to compute various block sizes.
  // KR, NR: block sizes for "right". We run blocking iterations that operate
  // on matrices with at most this size.
  // KL: grid size along the column dimension used while encoding left.
  // IB, JB: number of left and right slices to multiply together. This is used
  // for ordering different ComputeBlockOutput operations inside each blocking
  // iteration so as to potentially reduce the working set size.
  static inline void ComputeBlockSizes(const ConstMatrixMapL& left,
                                       const ConstMatrixMapR& right,
                                       bool transpose_left, int num_threads,
                                       int* KR, int* NR, int* KL, int* JB,
                                       int* IB);

  TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMul);
};

#ifdef TENSORFLOW_USE_LIBXSMM
template <typename TL, typename TR>
class LibxsmmSparseMatMul {
  using MatrixL = BasicMatrix<TL>;
  using MatrixR = BasicMatrix<TR>;
  using ConstMatrixMapL = BasicMatrixMap<const TL>;
  using ConstMatrixMapR = BasicMatrixMap<const TR>;
  using MatrixMapR = BasicMatrixMap<TR>;

 public:
  // This structure contains a set of libxsmm kernels for sizes that have been
  // encountered previously by this operator so that libxsmm does not need to
  // reallocate its scratchpad memory each time (which hurts performance
  // substantially).
  struct TensorInfoCache {
    struct TensorInfoCacheEntry {
      // Parameters for kernel
      int M;
      int K;
      int N;
      int max_threads;
      // libxsmm handle and matrix data
      libxsmm_spmdm_handle handle;
      libxsmm_CSR_sparseslice* output_csr;
      // Chain to non-libxsmm implementation's cache in case that ever becomes
      // useful (it is an empty struct right now)
      typename SparseMatMul<TL, TR>::TensorInfoCache
          non_libxsmm_cache;  // Currently not used
    };
    // protects entries; invariant: entries is a valid std::multimap
    tensorflow::mutex lock;
    // Because there could be multiple matrix multiplies with the same sizes
    // going on at the same time, we need to allow multiple cache entries for a
    // given set of parameters. Taking and returning entries is used to make
    // sure the same cache entry is not used from two threads at a time.
    std::multimap<std::tuple<int, int, int, int>,
                  std::unique_ptr<TensorInfoCacheEntry>>
        entries GUARDED_BY(lock);

    TensorInfoCache() : lock(), entries() {}
    // Look up and remove first entry with these parameters, creating one if
    // there isn't one
    std::unique_ptr<TensorInfoCacheEntry> take_cache_entry(int M, int K, int N,
                                                           int max_threads)
        LOCKS_EXCLUDED(lock) {
      tensorflow::mutex_lock ml(lock);
      auto key = std::make_tuple(M, K, N, max_threads);
      auto it = entries.find(key);
      if (it != entries.end()) {
        auto val = std::move(it->second);
        entries.erase(it);
        return val;
      } else {
        std::unique_ptr<TensorInfoCacheEntry> e{
            new TensorInfoCacheEntry{M, K, N, max_threads, {}, nullptr}};
        // setup scoped allocator, which uses cpu_allocator() for this scope
        const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator;
        libxsmm_spmdm_init(M, N, K, max_threads, &e->handle, &e->output_csr);
        return e;
      }
    }
    // Add a cache entry with certain parameters
    void return_cache_entry(std::unique_ptr<TensorInfoCacheEntry> e)
        LOCKS_EXCLUDED(lock) {
      tensorflow::mutex_lock ml(lock);
      auto key = std::make_tuple(e->M, e->K, e->N, e->max_threads);
      entries.insert(std::make_pair(key, std::move(e)));
    }
    ~TensorInfoCache() {
      tensorflow::mutex_lock ml(lock);
      for (auto& p : entries) {
        libxsmm_spmdm_destroy(&p.second->handle);
      }
      entries.clear();
    }

   private:
    TF_DISALLOW_COPY_AND_ASSIGN(TensorInfoCache);
  };

  // Perform matrix multiplication of "left" and "right", and store the result
  // in *"output".
 public:
  static inline void Compute(TensorInfoCache* cache,
                             const ConstMatrixMapL& left,
                             const ConstMatrixMapR& right, bool transpose_left,
                             const DeviceBase::CpuWorkerThreads* thread_pool,
                             bool transpose_output, MatrixMap* output);

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(LibxsmmSparseMatMul);
};
#endif

template <typename TL, typename TR,
          template <typename TL2, typename TR2> class DoMatMul>
class SparseMatMulOp : public OpKernel {
  using MatrixR = BasicMatrix<TR>;
  using ConstMatrixMapR = BasicMatrixMap<const TR>;

 public:
  explicit SparseMatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("a_is_sparse", &a_is_sparse_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("b_is_sparse", &b_is_sparse_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& a = ctx->input(0);
    const Tensor& b = ctx->input(1);
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()),
                errors::InvalidArgument("a is not a matrix"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()),
                errors::InvalidArgument("b is not a matrix"));

    const int m = transpose_a_ ? a.dim_size(1) : a.dim_size(0);
    const int k = transpose_a_ ? a.dim_size(0) : a.dim_size(1);
    const int n = transpose_b_ ? b.dim_size(0) : b.dim_size(1);
    const int k2 = transpose_b_ ? b.dim_size(1) : b.dim_size(0);

    OP_REQUIRES(ctx, k == k2,
                errors::InvalidArgument(
                    "Matrix size incompatible: a: ", a.shape().DebugString(),
                    ", b: ", b.shape().DebugString()));
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({m, n}), &output));

    if (k == 0) {
      // If the inner dimension k in the matrix multiplication is zero, we fill
      // the output with zeros.
      functor::SetZeroFunctor<CPUDevice, float> f;
      f(ctx->eigen_device<CPUDevice>(), output->flat<float>());
      return;
    }

    auto out = output->matrix<float>();

    std::unique_ptr<Tensor> a_float;
    std::unique_ptr<Tensor> b_float;
    if (!a_is_sparse_ && !b_is_sparse_) {
      auto left = &a;
      auto right = &b;
      // TODO(agarwal): multi-thread the conversions from bfloat16 to float.
      if (std::is_same<TL, bfloat16>::value) {
        a_float.reset(new Tensor(DT_FLOAT, a.shape()));
        BFloat16ToFloat(a.flat<bfloat16>().data(),
                        a_float->flat<float>().data(), a.NumElements());
        left = a_float.get();
      }
      if (std::is_same<TR, bfloat16>::value) {
        b_float.reset(new Tensor(DT_FLOAT, b.shape()));
        BFloat16ToFloat(b.flat<bfloat16>().data(),
                        b_float->flat<float>().data(), b.NumElements());
        right = b_float.get();
      }
      Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
      dim_pair[0].first = transpose_a_ ? 0 : 1;
      dim_pair[0].second = transpose_b_ ? 1 : 0;

      out.device(ctx->template eigen_device<CPUDevice>()) =
          left->matrix<float>().contract(right->matrix<float>(), dim_pair);
      return;
    }

    auto left = &a;
    auto right = &b;
    bool transpose_output = false;
    bool transpose_a = transpose_a_;
    bool transpose_b = transpose_b_;
    if (!a_is_sparse_) {
      // Swap the order of multiplications using the identity:
      // A * B = (B' * A')'.
      std::swap(left, right);
      std::swap(transpose_a, transpose_b);
      transpose_a = !transpose_a;
      transpose_b = !transpose_b;
      transpose_output = !transpose_output;
    }

    std::unique_ptr<Tensor> right_tr;
    if (transpose_b) {
      // TODO(agarwal): avoid transposing the matrix here and directly handle
      // transpose in CreateDenseSlices.
      right_tr.reset(
          new Tensor(right->dtype(),
                     TensorShape({right->dim_size(1), right->dim_size(0)})));

      const auto perm = dsizes_10();
      if (transpose_output) {
        right_tr->matrix<TL>().device(
            ctx->template eigen_device<CPUDevice>()) =
            right->matrix<TL>().shuffle(perm);
      } else {
        right_tr->matrix<TR>().device(
            ctx->template eigen_device<CPUDevice>()) =
            right->matrix<TR>().shuffle(perm);
      }
      right = right_tr.get();
    }

    if (transpose_output) {
      DoMatMul<TR, TL>::Compute(&this->cache_tr_, left->matrix<TR>(),
                                right->matrix<TL>(), transpose_a,
                                ctx->device()->tensorflow_cpu_worker_threads(),
                                transpose_output, &out);
    } else {
      DoMatMul<TL, TR>::Compute(&this->cache_nt_, left->matrix<TL>(),
                                right->matrix<TR>(), transpose_a,
                                ctx->device()->tensorflow_cpu_worker_threads(),
                                transpose_output, &out);
    }
  }

 private:
  bool transpose_a_;
  bool transpose_b_;
  bool a_is_sparse_;
  bool b_is_sparse_;

  // Cache for non-transposed-output multiply
  typename DoMatMul<TL, TR>::TensorInfoCache cache_nt_;
  // Cache for transposed-output multiply
  typename DoMatMul<TR, TL>::TensorInfoCache cache_tr_;

  TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMulOp);
};

template <typename TL, typename TR>
inline void SparseMatMul<TL, TR>::ComputeOutputBlock(
    const std::vector<SparseSlice<TL>*>& left,
    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right, int num_cols,
    int output_row_offset, int output_col_offset, bool assign,
    bool transpose_output, MatrixMap* output) {
  const auto perm = dsizes_10();
  int num_rows = left[0]->num_rows;
  const int rhs_num_cols = right.dimension(1);
  DCHECK_LE(num_cols, rhs_num_cols);
  Matrix out(num_rows, rhs_num_cols);
  out.setZero();
  if (num_cols == N) {
    GEPP<TL, TR, N>(left, right, num_cols, &out);
  } else {
    GEPP<TL, TR, -1>(left, right, num_cols, &out);
  }
  if (!assign) {
    const DSizes begin(output_row_offset, output_col_offset);
    const DSizes sizes(num_rows, num_cols);
    if (transpose_output) {
      if (num_cols == rhs_num_cols) {
        output->shuffle(perm).slice(begin, sizes) += out;
      } else {
        const auto zero = dsizes_00();
        output->shuffle(perm).slice(begin, sizes) += out.slice(zero, sizes);
      }
    } else {
      if (num_cols == rhs_num_cols) {
        output->slice(begin, sizes) += out;
      } else {
        const auto zero = dsizes_00();
        output->slice(begin, sizes) += out.slice(zero, sizes);
      }
    }
  } else {
    std::unique_ptr<Matrix> out_tr;
    if (transpose_output) {
      out_tr.reset(new Matrix(rhs_num_cols, num_rows));
      *out_tr = out.shuffle(perm);
      std::swap(output_row_offset, output_col_offset);
      std::swap(num_rows, num_cols);
    }
    const Matrix& final_out = transpose_output ? *out_tr : out;
    for (int i = 0; i < num_rows; ++i) {
      memcpy(&(*output)(output_row_offset + i, output_col_offset),
             &final_out(i, 0), num_cols * sizeof(float));
    }
  }
}

template <typename TL, typename TR>
inline std::unique_ptr<BlockingCounter>
SparseMatMul<TL, TR>::CreateSparseSlices(
    const typename SparseMatMul<TL, TR>::ConstMatrixMapL& mat, bool transpose,
    int slice_num_rows, int slice_block_size, int slice_num_cols,
    std::vector<std::vector<SparseSlice<TL>*>>* mat_slices,
    const DeviceBase::CpuWorkerThreads* thread_pool) {
  const int mat_num_rows = transpose ? mat.dimension(1) : mat.dimension(0);
  const int mat_num_cols = transpose ? mat.dimension(0) : mat.dimension(1);
  const int num_slices_dim0 =
      std::max(1, (mat_num_rows + slice_num_rows - 1) / slice_num_rows);
  const int num_slices_dim1 =
      std::max(1, (mat_num_cols + slice_num_cols - 1) / slice_num_cols);
  mat_slices->resize(num_slices_dim0);
  BlockingCounter* counter =
      new BlockingCounter(num_slices_dim0 * num_slices_dim1);
  auto work = [counter, transpose](SparseSlice<TL>* sparse_slice,
                                   SparseMatMul<TL, TR>::ConstMatrixMapL* slice,
                                   int col_offset) {
    if (transpose) {
      sparse_slice->template Initialize<true>(*slice, col_offset);
    } else {
      sparse_slice->template Initialize<false>(*slice, col_offset);
    }
    delete slice;
    counter->DecrementCount();
  };
  for (int i = 0; i < num_slices_dim0; ++i) {
    (*mat_slices)[i].resize(num_slices_dim1);
    int num_rows =
        std::min<int>(slice_num_rows, mat_num_rows - i * slice_num_rows);
    for (int j = 0; j < num_slices_dim1; ++j) {
      int num_cols =
          std::min<int>(slice_num_cols, mat_num_cols - j * slice_num_cols);
      SparseMatMul<TL, TR>::ConstMatrixMapL* slice = nullptr;
      if (transpose) {
        slice = new SparseMatMul<TL, TR>::ConstMatrixMapL(
            &mat(0, i * slice_num_rows), mat.dimensions());
      } else {
        DSizes d(num_rows, mat_num_cols);
        slice = new SparseMatMul<TL, TR>::ConstMatrixMapL(
            &mat(i * slice_num_rows, 0), d);
      }
      auto* sparse_slice =
          new SparseSlice<TL>(num_rows, num_cols, slice_block_size);
      (*mat_slices)[i][j] = sparse_slice;
      thread_pool->workers->Schedule(
          [=]() { work(sparse_slice, slice, slice_num_cols * j); });
    }
  }
  return std::unique_ptr<BlockingCounter>(counter);
}
#define LOAD(x) Eigen::internal::ploadu<Packet>((x));
#define INTERLEAVE(x) Eigen::internal::pinterleave4x64<Packet>(x);
#define STORE(x, y) Eigen::internal::pstoreu<float>(x, y);

template <int NUM_ELEM = -1>
ALWAYS_INLINE void CopyAndMayBeInterleaveBfloat16(void* bdst, const void* bsrc,
                                                  int num_elements) {
  DCHECK_GE(kNumOperands, 8);
  static const int kStep = kNumOperands * sizeof(float) / sizeof(bfloat16);
  const int num = (NUM_ELEM == -1) ? num_elements : NUM_ELEM;
  DCHECK_EQ(num, num_elements);
  const float* src = reinterpret_cast<const float*>(bsrc);
  float* dst = reinterpret_cast<float*>(bdst);
  for (int index = 0; index + kStep <= num; index += kStep) {
    auto in = LOAD(src);
    auto tmp = INTERLEAVE(in);
    STORE(dst, tmp);
    src += kNumOperands;
    dst += kNumOperands;
  }
  if (num % kStep != 0) {
    memcpy(dst, src, (num % kStep) * sizeof(bfloat16));
  }
}

template <typename T>
ALWAYS_INLINE void CopyAndMayBeInterleave(void* dst, const void* src,
                                          int num_elements) {
  if (std::is_same<T, float>::value || kNumOperands < 8) {
    memcpy(dst, src, num_elements * sizeof(T));
  } else if (std::is_same<T, bfloat16>::value) {
    if (num_elements == N) {
      CopyAndMayBeInterleaveBfloat16<N>(dst, src, num_elements);
    } else {
      CopyAndMayBeInterleaveBfloat16<-1>(dst, src, num_elements);
    }
  } else {
    LOG(FATAL) << "Unsupported type";
  }
}

#undef LOAD
#undef INTERLEAVE
#undef STORE

template <typename TL, typename TR>
inline BlockingCounter* SparseMatMul<TL, TR>::ShuffleMatrix(
    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& mat,
    int slice_row_start, int slice_num_rows, int slice_col_start,
    int slice_num_cols, const int N,
    const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer) {
  DCHECK_EQ(N % 2, 0);
  DCHECK_LE(kNumOperands * sizeof(float) / sizeof(TR), N);
  int num_threads = std::min(thread_pool->num_threads, 16);
  BlockingCounter* counter = new BlockingCounter(num_threads);
  DCHECK_EQ(N, buffer->dimension(1));
  auto shuffle_work = [&mat, slice_row_start, slice_num_rows, slice_col_start,
                       slice_num_cols, N, buffer, counter](int s, int e) {
    const int row_start = s % slice_num_rows + slice_row_start;
    const int col_start = s / slice_num_rows * N + slice_col_start;
    auto* out_start = &(*buffer)(s, 0);
    const auto* input_start = &mat(row_start, col_start);
    const auto* input_end = &mat(slice_row_start + slice_num_rows - 1,
                                 slice_col_start + slice_num_cols - 1);
    const int mat_num_cols = mat.dimension(1);
    const int row_slice_size = slice_num_rows * mat_num_cols;

    const int aligned_end = slice_num_cols / N * slice_num_rows;
    const int e1 = std::min(e, aligned_end);
    while (s < e1) {
      CopyAndMayBeInterleave<TR>(out_start, input_start, N);
      out_start += N;
      input_start += mat_num_cols;
      if (input_start > input_end) {
        input_start = input_start - row_slice_size + N;
      }
      ++s;
    }
    int s1 = std::max(s, aligned_end);
    const int copy_num_cols = slice_num_cols % N;
    while (s1 < e) {
      CopyAndMayBeInterleave<TR>(out_start, input_start, copy_num_cols);
      out_start += N;
      input_start += mat_num_cols;
      ++s1;
    }
    if (counter) counter->DecrementCount();
  };

  int start = 0;
  int end = 0;
  int num_out_rows = (slice_num_cols + N - 1) / N * slice_num_rows;
  DCHECK_LE(num_out_rows, buffer->dimension(0));
  for (int i = std::max(1, num_threads); i > 0; --i) {
    end = start + num_out_rows / i;
    thread_pool->workers->Schedule([=]() { shuffle_work(start, end); });
    num_out_rows -= (end - start);
    start = end;
  }
  return counter;
}

template <typename TL, typename TR>
inline void SparseMatMul<TL, TR>::SliceMatrix(
    const MatrixR& mat, const int num_rows, const int num_slices,
    std::vector<typename SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
  slices->resize(num_slices);
  DSizes d(num_rows, mat.dimension(1));
  DCHECK_LE(num_rows * num_slices, mat.dimension(0));
  for (int i = 0; i < num_slices; ++i) {
    (*slices)[i] = new ConstMatrixMapR(&mat(i * num_rows, 0), d);
  }
}

template <typename TL, typename TR>
inline std::unique_ptr<BlockingCounter> SparseMatMul<TL, TR>::CreateDenseSlices(
    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& mat, int row_start,
    int num_rows, int col_start, int num_cols,
    const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer,
    std::vector<typename SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
  std::unique_ptr<BlockingCounter> shuffle_counter(ShuffleMatrix(
      mat, row_start, num_rows, col_start, num_cols, N, thread_pool, buffer));
  const int num_slices = (num_cols + N - 1) / N;
  SliceMatrix(*buffer, num_rows, num_slices, slices);
  return shuffle_counter;
}

template <typename TL, typename TR>
inline void SparseMatMul<TL, TR>::ComputeBlockSizes(
    const typename SparseMatMul<TL, TR>::ConstMatrixMapL& left,
    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right,
    bool transpose_left, int num_threads, int* KR, int* NR, int* KL, int* JB,
    int* IB) {
  // Heuristics for calculating block sizes
  // Assume two hyperthreads per core.
  const int est_num_cores = std::max(1, (num_threads + 1) / 2);
  // Use block of rhs with at most 128K floats per core.
  const int mem = est_num_cores * 128 * 1024;
  *KR = std::min(static_cast<int>(right.dimension(0)), mem / 256);
  *NR = right.dimension(1);
  if (*KR * *NR > mem) {
    // 4096 may be enough to amortize the cost of writes.
    *KR = std::min<int>(*KR, 4096);
  }
  // Use sizes that are multiples of K and 256.
  *KR = std::max(1, *KR / K) * K;
  *NR = std::max(1, *NR / 256) * 256;
  if (*KR * *NR > mem) {
    *NR = mem / *KR;
  }
  *NR = std::max(1, *NR / 256) * 256;
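  // Example (illustrative numbers): with num_threads = 16 and a 10000 x 1024
  // right matrix, est_num_cores = 8 and mem = 8 * 128 * 1024 = 1048576, so
  // *KR starts at mem / 256 = 4096 (already a multiple of K) and *NR is first
  // rounded to 1024, then reduced to mem / *KR = 256 to keep the block within
  // the per-core budget.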
1339
1340 const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0);
1341 const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1);
1342 for (*KL = 1024; *KL > K; *KL /= 2) {
1343 if (*KR % *KL == 0 &&
1344 std::max<int>(1, left_dim0 / 64) * (left_dim1 / *KL) > est_num_cores) {
1345 break;
1346 }
1347 }
1348 DCHECK_EQ(*KL % K, 0);
1349 DCHECK_GE(*KR, *KL);
1350 if (*KR < right.dimension(0)) {
1351 CHECK_EQ(*KR % *KL, 0);
1352 }
1353
1354 *JB = std::max(1, static_cast<int>(sqrt(num_threads) / 2.0));
1355 *IB = 8 * *JB;
1356 DCHECK_EQ(N * sizeof(float) % 64, size_t{0});
1357 }
1358
1359 #ifdef TENSORFLOW_USE_LIBXSMM
1360
1361 template <typename F>
do_on_all_threads(const DeviceBase::CpuWorkerThreads * thread_pool,const F & f)1362 void do_on_all_threads(const DeviceBase::CpuWorkerThreads* thread_pool,
1363 const F& f) {
1364 int num_threads = thread_pool->num_threads;
1365 if (num_threads == 0) {
1366 LOG(FATAL) << "Have 0 threads in thread pool";
1367 } else if (num_threads == 1) {
1368 f(0);
1369 } else {
1370 BlockingCounter counter(num_threads - 1);
1371 for (int i = 1; i < num_threads; ++i) {
1372 thread_pool->workers->Schedule([&, i]() {
1373 f(i);
1374 counter.DecrementCount();
1375 });
1376 }
1377 f(0);
1378 counter.Wait();
1379 }
1380 }
1381
1382 template <typename T>
1383 struct empty_type_wrapper {};
1384
1385 // Copies of interface to libxsmm_spmdm_createSparseSlice_*_notrans_thread to
1386 // allow overloading
wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(empty_type_wrapper<float>,const libxsmm_spmdm_handle * handle,char transA,const float * A,libxsmm_CSR_sparseslice * libxsmm_output_csr_a,int block_id,int tid,int nthreads)1387 void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
1388 empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA,
1389 const float* A, libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id,
1390 int tid, int nthreads) {
1391 return libxsmm_spmdm_createSparseSlice_fp32_thread(
1392 handle, transA, A, libxsmm_output_csr_a, block_id, tid, nthreads);
1393 }
wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(empty_type_wrapper<bfloat16>,const libxsmm_spmdm_handle * handle,char transA,const bfloat16 * A,libxsmm_CSR_sparseslice * libxsmm_output_csr_a,int block_id,int tid,int nthreads)1394 void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
1395 empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle,
1396 char transA, const bfloat16* A,
1397 libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id, int tid,
1398 int nthreads) {
1399 return libxsmm_spmdm_createSparseSlice_bfloat16_thread(
1400 handle, transA, reinterpret_cast<const uint16*>(A), libxsmm_output_csr_a,
1401 block_id, tid, nthreads);
1402 }
1403
void wrapper_libxsmm_spmdm_compute_generic_thread(
    empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle,
    char transA, char transB, const bfloat16* alpha,
    libxsmm_CSR_sparseslice* A_sparse, const bfloat16* B, char transC,
    const bfloat16* beta, float* C, int block_id, int tid, int nthreads) {
  return libxsmm_spmdm_compute_bfloat16_thread(
      handle, transA, transB, reinterpret_cast<const uint16*>(alpha), A_sparse,
      reinterpret_cast<const uint16*>(B), transC,
      reinterpret_cast<const uint16*>(beta), C, block_id, tid, nthreads);
}
void wrapper_libxsmm_spmdm_compute_generic_thread(
    empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA,
    char transB, const float* alpha, libxsmm_CSR_sparseslice* A_sparse,
    const float* B, char transC, const float* beta, float* C, int block_id,
    int tid, int nthreads) {
  return libxsmm_spmdm_compute_fp32_thread(handle, transA, transB, alpha,
                                           A_sparse, B, transC, beta, C,
                                           block_id, tid, nthreads);
}

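// Multiplies the sparse "left" matrix with the dense "right" matrix using
// libxsmm: "left" is first converted to CSR slices, then each compute block is
// multiplied against "right", with work items handed out to the thread pool
// through atomic counters. Small inputs fall back to the generic
// SparseMatMul implementation below.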
template <typename TL, typename TR>
inline void LibxsmmSparseMatMul<TL, TR>::Compute(
    typename LibxsmmSparseMatMul<TL, TR>::TensorInfoCache* cache,
    const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapL& left,
    const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapR& right,
    bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
    bool transpose_output, MatrixMap* output) {
  if (false) {
    // Placeholder: cases not currently handled by libxsmm would fall back to
    // the generic implementation here.
    SparseMatMul<TL, TR>::Compute(
        nullptr /* Assumes no cached data for fallback */, left, right,
        transpose_left, thread_pool, transpose_output, output);
    return;
  }
  const int num_threads = thread_pool->num_threads;
  const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0);
  const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1);
  const int right_dim0 = right.dimension(0);
  const int right_dim1 = right.dimension(1);
  CHECK_EQ(left_dim1, right_dim0);
  CHECK_EQ(left_dim0,
           (transpose_output ? output->dimension(1) : output->dimension(0)));
  CHECK_EQ(right_dim1,
           (transpose_output ? output->dimension(0) : output->dimension(1)));
  if (left_dim0 < 32 || left_dim1 < 32 || right_dim1 < 32) {
    // These small sizes cause problems in libxsmm; fall back to the generic
    // implementation.
    SparseMatMul<TL, TR>::Compute(
        nullptr /* Assumes no cached data for fallback */, left, right,
        transpose_left, thread_pool, transpose_output, output);
    return;
  }
  auto left_data = left.data();
  auto right_data = right.data();
  auto output_data = output->data();
  // Initialize libxsmm for this matrix; make sure another thread doesn't use
  // this handle concurrently.
  auto entry =
      cache->take_cache_entry(left_dim0, right_dim0, right_dim1, num_threads);
  // Convert the left matrix to compressed sparse row (CSR) format.
  ptrdiff_t total_num_creation_blocks =
      libxsmm_spmdm_get_num_createSparseSlice_blocks(&entry->handle);
  std::atomic<int> cur_create_block_number;
  cur_create_block_number.store(0);
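  // Each thread claims creation blocks from the shared atomic counter until
  // all blocks have been processed.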
  do_on_all_threads(thread_pool, [&](int i) {
    while (true) {
      int work_item = cur_create_block_number.fetch_add(1);
      if (work_item >= total_num_creation_blocks) break;
      wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
          empty_type_wrapper<TL>{}, &entry->handle,
          (transpose_left ? 'T' : 'N'), left_data, entry->output_csr, work_item,
          i, num_threads);
    }
  });
  // Do matrix-matrix multiplication
  ptrdiff_t total_num_mult_blocks =
      libxsmm_spmdm_get_num_compute_blocks(&entry->handle);
  std::atomic<int> cur_mult_block_number;
  cur_mult_block_number.store(0);
  do_on_all_threads(thread_pool, [&](int i) {
    while (true) {
      int work_item = cur_mult_block_number.fetch_add(1);
      if (work_item >= total_num_mult_blocks) break;
      const TL alpha(1.0);  // Stored in a variable so we can get a pointer
      const TL beta(0.0);   // Stored in a variable so we can get a pointer
      wrapper_libxsmm_spmdm_compute_generic_thread(
          empty_type_wrapper<TL>{}, &entry->handle,
          (transpose_left ? 'T' : 'N'), 'N', &alpha, entry->output_csr,
          right_data, (transpose_output ? 'T' : 'N'), &beta, output_data,
          work_item, i, num_threads);
    }
  });
  // Put handle + CSR storage back into cache
  cache->return_cache_entry(std::move(entry));
}

#endif  // TENSORFLOW_USE_LIBXSMM

// Here is an overview of the SparseMatMul code. Note that we assume that the
// left matrix is sparse.
//
// The matrix "left" is divided into a grid with blocksize of (M, KL). Each
// block is encoded as a SparseSlice. These grid elements are stored as
// std::vector<std::vector<SparseSlice>>. Each element of the outer vector
// represents M rows of the left matrix. Let's call these elements l_i and
// let's call each element of the inner vector L_mk.
//
// The matrix "right" is divided into a grid with block size KR * NR. Let's
// denote the blocks on the right as R_kn. Note that we ensure that KL divides
// KR so that for each element R_kn, we don't need to multiply it with any
// partial L_mk blocks.
//
// We then multiply each right side block R_kn with the full "left" matrix and
// update the output. These iterations are run sequentially since R_kn are
// packed into the same underlying temporary buffer.
//
// In each iteration we do the following:
// 1. Create slices r_j of R_kn: We split R_kn into vertical blocks with N
//    (=128) columns and then concatenate these slices into a buffer. This is
//    done so that each slice r_j of R_kn is stored contiguously in memory.
//    Note that if R_kn has dimensions (KR, NR), we create NR / N slices, and
//    the buffer has dimensions (KR * NR / N, N) (assuming N divides NR).
// 2. For each (l_i, r_j), we compute the inner product using the GEPP function
//    and update the output block o_ij. These calls are further blocked to
//    reduce the working set size. In each iteration we take IB elements from
//    {l_i} and JB elements from {r_j} and compute the IB * JB inner products.
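//
// As a concrete illustration (numbers chosen only for the example): with
// N = 128, a block R_kn of size (KR, NR) = (256, 1024) is split into
// NR / N = 8 slices r_j of shape (256, 128), packed back to back into a
// (2048, 128) buffer.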
template <typename TL, typename TR>
inline void SparseMatMul<TL, TR>::Compute(
    typename SparseMatMul<TL, TR>::TensorInfoCache* /*cache*/,
    const typename SparseMatMul<TL, TR>::ConstMatrixMapL& left,
    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right,
    bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
    bool transpose_output, MatrixMap* output) {
  const int num_threads = thread_pool->num_threads;
  int KR, NR, KL, JB, IB;
  ComputeBlockSizes(left, right, transpose_left, num_threads, &KR, &NR, &KL,
                    &JB, &IB);
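  // KR x NR is the block size used for the right matrix, KL the column block
  // size for the left slices, and IB / JB the tiling factors for the compute
  // tasks (see the overview comment above).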
  // Slice the left matrix
  std::vector<std::vector<SparseSlice<TL>*>> left_slices;
  std::unique_ptr<BlockingCounter> sparse_slice_counter =
      CreateSparseSlices(ConstMatrixMapL(left.data(), left.dimensions()),
                         transpose_left, M, K, KL, &left_slices, thread_pool);
  const int num_left_slices = left_slices.size();

  const int right_dim0 = right.dimension(0);
  const int right_dim1 = right.dimension(1);
  // Allocate buffer for storing slices of right matrix.
  // Note buffer needs enough space to hold at most a KR * NR matrix since that
  // is the block size per iteration.
  const int buffer_num_rows =
      std::min(KR, right_dim0) * (std::min(NR, right_dim1) + N - 1) / N;
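  // That is, min(KR, right_dim0) rows for each of the
  // ceil(min(NR, right_dim1) / N) packed slices of width N.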
  MatrixR buffer(buffer_num_rows, N);
  std::vector<ConstMatrixMapR*> right_slices;

  std::vector<SparseSlice<TL>*> block_left_slices;
  std::vector<std::function<void(void)>> tasks;
  // Number of blocks based on block sizes of KR * NR.
  const int num_k_blocks = (right_dim0 + KR - 1) / KR;
  const int num_n_blocks = (right_dim1 + NR - 1) / NR;
  std::unique_ptr<BlockingCounter> dense_slice_counter;

  for (int nb = 0; nb < num_n_blocks; ++nb) {
    const int right_num_cols =
        std::min(NR, static_cast<int>(right_dim1 - NR * nb));
    for (int kb = 0; kb < num_k_blocks; ++kb) {
      const int right_num_rows =
          std::min(KR, static_cast<int>(right_dim0 - KR * kb));
      dense_slice_counter = CreateDenseSlices(
          right, kb * KR, right_num_rows, nb * NR, right_num_cols, thread_pool,
          &buffer, &right_slices);
      const int num_right_slices = right_slices.size();
      tasks.reserve(num_left_slices * num_right_slices);
      for (int j_outer = 0; j_outer < num_right_slices; j_outer += JB) {
        for (int i_outer = 0; i_outer < num_left_slices; i_outer += IB) {
          for (int j_inner = j_outer;
               j_inner < std::min(num_right_slices, j_outer + JB); ++j_inner) {
            const int num_cols = std::min(N, right_num_cols - N * j_inner);
            for (int i_inner = i_outer;
                 i_inner < std::min(num_left_slices, i_outer + IB); ++i_inner) {
              block_left_slices.clear();
              int begin = kb * KR / KL;
              int end = std::min<int>((kb + 1) * KR / KL,
                                      (right.dimension(0) + KL - 1) / KL);
              DCHECK_LT(begin, end);
              block_left_slices.insert(block_left_slices.begin(),
                                       left_slices[i_inner].begin() + begin,
                                       left_slices[i_inner].begin() + end);
              tasks.push_back(std::bind(
                  &ComputeOutputBlock, block_left_slices,
                  std::ref(*right_slices[j_inner]), num_cols, M * i_inner,
                  N * j_inner + nb * NR, kb == 0, transpose_output, output));
            }
          }
        }
      }
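      // Make sure the sparse slices of "left" and the packed slices of the
      // current right block are fully materialized before running the
      // multiplication tasks.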
      if (sparse_slice_counter) {
        sparse_slice_counter->Wait();
        sparse_slice_counter.reset(nullptr);
      }
      if (dense_slice_counter) {
        dense_slice_counter->Wait();
        dense_slice_counter.reset(nullptr);
      }
      BlockingCounter bc(tasks.size());
      for (const auto& t : tasks) {
        thread_pool->workers->Schedule([&bc, &t]() {
          t();
          bc.DecrementCount();
        });
      }
      bc.Wait();
      tasks.clear();
      gtl::STLDeleteElements(&right_slices);
      right_slices.clear();
    }
  }
  for (auto& left_slice : left_slices) {
    gtl::STLDeleteElements(&left_slice);
  }
}

#define REGISTER_SPARSE_MATMUL(TA, TB)                   \
  REGISTER_KERNEL_BUILDER(Name("SparseMatMul")           \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<TA>("Ta")  \
                              .TypeConstraint<TB>("Tb"), \
                          SparseMatMulOp<TA, TB, SparseMatMul>);
#ifdef TENSORFLOW_USE_LIBXSMM
#define REGISTER_SPARSE_MATMUL_LIBXSMM(TA, TB)           \
  REGISTER_KERNEL_BUILDER(Name("SparseMatMul")           \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<TA>("Ta")  \
                              .TypeConstraint<TB>("Tb"), \
                          SparseMatMulOp<TA, TB, LibxsmmSparseMatMul>);
#endif

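// Register the CPU kernels for each supported (Ta, Tb) combination. The
// float/float kernel is backed by libxsmm when TENSORFLOW_USE_LIBXSMM is
// defined, and by the generic SparseMatMul implementation otherwise.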
REGISTER_SPARSE_MATMUL(bfloat16, bfloat16);

REGISTER_SPARSE_MATMUL(float, bfloat16);

REGISTER_SPARSE_MATMUL(bfloat16, float);

#ifdef TENSORFLOW_USE_LIBXSMM
REGISTER_SPARSE_MATMUL_LIBXSMM(float, float);
#else
REGISTER_SPARSE_MATMUL(float, float);
#endif

#undef REGISTER_SPARSE_MATMUL

}  // end namespace tensorflow