1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h"
17
18 #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
19 #include "tensorflow/compiler/xla/service/hlo_module.h"
20 #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
21 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
22
23 namespace xla {
24 namespace cpu {
25 namespace {
26
27 using tensorflow::int64;
28
29 // Provides tiled access to an in-memory rank 2 array.
30 class MemoryTile {
31 public:
32 // Constructs a MemoryTile that can operate on tiles consisting of
33 // `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at
34 // `major_dim_offset` in the major dimension. The tile size along the minor
35 // dimension is the vector size, and that is implicitly determined by `vsl`.
MemoryTile(VectorSupportLibrary * vsl,llvm::IRBuilder<> * b,llvm::Value * matrix,int64 matrix_size_along_minor_dim,llvm::Value * major_dim_offset,int64 tile_size_along_major_dim)36 MemoryTile(VectorSupportLibrary* vsl, llvm::IRBuilder<>* b,
37 llvm::Value* matrix, int64 matrix_size_along_minor_dim,
38 llvm::Value* major_dim_offset, int64 tile_size_along_major_dim)
39 : vsl_(vsl), b_(b) {
40 pointers_.reserve(tile_size_along_major_dim);
41 for (int64 i = 0; i < tile_size_along_major_dim; i++) {
42 llvm::Value* total_offset =
43 b->CreateMul(b->getInt64(matrix_size_along_minor_dim),
44 b->CreateAdd(b->getInt64(i), major_dim_offset));
45 pointers_.push_back(vsl_->ComputeOffsetPointer(matrix, total_offset));
46 }
47 }
48
49 // Load a tile consisting of `tile_size_along_major_dim` vectors from position
50 // {major: `major_dim_offset`, minor: `minor_dim_offset`}.
51 //
52 // Note: `major_dim_offset` is a parameter to the constructor.
LoadTile(llvm::Value * minor_dim_offset) const53 std::vector<llvm::Value*> LoadTile(llvm::Value* minor_dim_offset) const {
54 std::vector<llvm::Value*> result;
55 result.reserve(pointers_.size());
56 for (const auto& pointer : pointers_) {
57 result.push_back(vsl_->LoadVector(pointer, minor_dim_offset));
58 }
59 return result;
60 }
61
62 // Stores `tile` to position {major: `major_dim_offset`, minor:
63 // `minor_dim_offset`}.
64 //
65 // Note: `major_dim_offset` is a parameter to the constructor.
StoreTile(absl::Span<llvm::Value * const> tile,llvm::Value * minor_dim_offset) const66 void StoreTile(absl::Span<llvm::Value* const> tile,
67 llvm::Value* minor_dim_offset) const {
68 CHECK_EQ(tile.size(), pointers_.size());
69 for (int64 i = 0; i < pointers_.size(); i++) {
70 vsl_->StoreVector(tile[i], pointers_[i], minor_dim_offset);
71 }
72 }
73
74 // Loads a tile of size [`tile_size_along_major_dim`,
75 // `tile_size_along_middle_dim`] from position {major: `major_dim_offset`,
76 // minor: `minor_dim_offset`} and then broadcasts each element into a vector
77 // of size vsl_.vector_size(). The (i,j)'th element of the return value is
78 // the (i,j)'th element in the tile broadcasted into an LLVM vector.
79 //
80 // Note: `major_dim_offset` is a parameter to the constructor.
LoadBroadcastTile(llvm::Value * minor_dim_offset,int64 tile_size_along_middle_dim) const81 std::vector<std::vector<llvm::Value*>> LoadBroadcastTile(
82 llvm::Value* minor_dim_offset, int64 tile_size_along_middle_dim) const {
83 std::vector<std::vector<llvm::Value*>> result;
84 result.resize(pointers_.size());
85 for (int64 i = 0; i < pointers_.size(); i++) {
86 for (int64 j = 0; j < tile_size_along_middle_dim; j++) {
87 result[i].push_back(vsl_->LoadBroadcast(
88 pointers_[i], b_->CreateAdd(minor_dim_offset, b_->getInt64(j))));
89 }
90 }
91 return result;
92 }
93
94 private:
95 VectorSupportLibrary* vsl_;
96 llvm::IRBuilder<>* b_;
97 std::vector<llvm::Value*> pointers_;
98 };
99
100 // The base class for the classes representing the GEMV emitter configurations.
101 //
102 // The IR emitted (modulo the LLVM values representing the input and output
103 // buffers) by the row major and column major GEMV emitters should be a function
104 // of their configuration. This is important because their configuration is
105 // used as a key to cache the generated IR.
106 class GemvConfig {
107 public:
108 // Mixin for convenience.
109 template <typename T>
110 struct User {
111 public:
scalar_typexla::cpu::__anon016ac5180111::GemvConfig::User112 PrimitiveType scalar_type() const {
113 return derived().config().scalar_type();
114 }
tile_rowsxla::cpu::__anon016ac5180111::GemvConfig::User115 int64 tile_rows() const { return derived().config().tile_rows(); }
tile_colsxla::cpu::__anon016ac5180111::GemvConfig::User116 int64 tile_cols() const { return derived().config().tile_cols(); }
mxla::cpu::__anon016ac5180111::GemvConfig::User117 int64 m() const { return derived().config().m(); }
kxla::cpu::__anon016ac5180111::GemvConfig::User118 int64 k() const { return derived().config().k(); }
has_addendxla::cpu::__anon016ac5180111::GemvConfig::User119 int64 has_addend() const { return derived().config().has_addend(); }
120
121 private:
derivedxla::cpu::__anon016ac5180111::GemvConfig::User122 const T& derived() const { return *static_cast<const T*>(this); }
123 };
124
scalar_type() const125 PrimitiveType scalar_type() const { return scalar_type_; }
tile_rows() const126 int64 tile_rows() const { return tile_rows_; }
tile_cols() const127 int64 tile_cols() const { return tile_cols_; }
m() const128 int64 m() const { return m_; }
k() const129 int64 k() const { return k_; }
has_addend() const130 bool has_addend() const { return has_addend_; }
131
GetCacheKey() const132 string GetCacheKey() const {
133 return absl::StrCat(name_, "_", PrimitiveType_Name(scalar_type()), "_",
134 tile_rows(), "_", tile_cols(), "_", m(), "_", k(),
135 has_addend() ? "_with_addend" : "");
136 }
137
138 protected:
GemvConfig(string name,PrimitiveType scalar_type,int64 tile_rows,int64 tile_cols,int64 m,int64 k,bool has_addend)139 explicit GemvConfig(string name, PrimitiveType scalar_type, int64 tile_rows,
140 int64 tile_cols, int64 m, int64 k, bool has_addend)
141 : name_(std::move(name)),
142 scalar_type_(scalar_type),
143 tile_rows_(tile_rows),
144 tile_cols_(tile_cols),
145 m_(m),
146 k_(k),
147 has_addend_(has_addend) {}
148
149 private:
150 string name_;
151 PrimitiveType scalar_type_;
152 int64 tile_rows_;
153 int64 tile_cols_;
154 int64 m_;
155 int64 k_;
156 bool has_addend_;
157 };
158
159 // Computes a dot product between "[M,K]{0,1} lhs" with a [K,1] vector (the
160 // layout of the vector does not matter). This implementation uses a tiling
161 // scheme to improve performance.
162 //
163 // We logically separate the LHS matrix into four segments:
164 //
165 // +----------------------+---+
166 // | | |
167 // | | |
168 // | A | B |
169 // | | |
170 // | | |
171 // | | |
172 // +----------------------+---+
173 // | C | D |
174 // +----------------------+---+
175 //
// where A is the largest submatrix of the LHS that can be evenly divided into
// tiles. For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have:
178 //
179 // +---+---+---+---+ +--+--+--+--+
180 // |M00|M10|M20|M30| |V0|V1|V2|V3|
181 // +---+---+---+---+ +--+--+--+--+
182 // |M01|M11|M21|M31| and |V0|V1|V2|V3|
183 // +---+---+---+---+ +--+--+--+--+
184 // |M02|M12|M22|M32| |V0|V1|V2|V3|
185 // +---+---+---+---+ +--+--+--+--+
186 // |M03|M13|M23|M33| |V0|V1|V2|V3|
187 // +---+---+---+---+ +--+--+--+--+
188 //
189 // (Legend: rows are horizontal and columns are vertical; and each column is one
190 // llvm::Value of a vector type)
191 //
192 // where:
193 //
194 // a. The left tile is from the column major left matrix.
195 // b. The right tile is an elementwise broadcast of a [V0, V1, V2, V3]
196 // vector loaded from the RHS vector.
197 //
198 // As we iterate through the column dimension, we compute the change to the
199 // result vector by an elementwise multiplication between the two tiles above
200 // followed by a reduction along the major dimension:
201 //
202 // +-----------------------------------+
203 // | M00*V0 + M10*V1 + M20*V2 + M30*V3 |
204 // +-----------------------------------+
205 // | M01*V0 + M11*V1 + M21*V2 + M31*V3 |
206 // Result[R:R+4] += +-----------------------------------+
207 // | M02*V0 + M12*V1 + M22*V2 + M32*V3 |
208 // +-----------------------------------+
209 // | M03*V0 + M13*V1 + M23*V2 + M33*V3 |
210 // +-----------------------------------+
211 //
212 // Where R is the starting row for the tile.
213 //
// We have an inner epilogue loop to deal with the "C" submatrix and an outer
// epilogue loop to deal with the B,D submatrix.
//
// TODO(sanjoy): We should investigate if using gather loads and scatter stores
// can be used here to have the same inner loop for both column-major and
// row-major matrix-vector products.
220 class ColumnMajorMatrixVectorProductEmitter
221 : public GemvConfig::User<ColumnMajorMatrixVectorProductEmitter> {
222 public:
223 class Config : public GemvConfig {
224 public:
Config(PrimitiveType scalar_type,int64 tile_rows,int64 tile_cols,int64 m,int64 k,bool has_addend)225 explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols,
226 int64 m, int64 k, bool has_addend)
227 : GemvConfig(/*name=*/"col_major_gemv", scalar_type,
228 /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m,
229 /*k=*/k, /*has_addend=*/has_addend) {}
230 };
231
ColumnMajorMatrixVectorProductEmitter(const Config & config,llvm::Value * lhs,llvm::Value * rhs,llvm::Value * addend,llvm::Value * result,llvm::IRBuilder<> * b)232 ColumnMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs,
233 llvm::Value* rhs, llvm::Value* addend,
234 llvm::Value* result,
235 llvm::IRBuilder<>* b)
236 : config_(config),
237 lhs_(lhs),
238 rhs_(rhs),
239 addend_(addend),
240 result_(result),
241 b_(b),
242 ksl_(b_),
243 vsl_(config.scalar_type(), /*vector_size=*/config.tile_rows(), b_, "") {
244 CHECK(tile_rows() > 0 && IsPowerOfTwo(static_cast<uint64>(tile_rows())));
245 CHECK(!has_addend() || addend != nullptr);
246 }
247
248 void Emit();
249
config() const250 const Config& config() const { return config_; }
251
252 private:
253 void EmitOuterLoopBody(llvm::Value* column, int64 column_count,
254 bool is_first_column);
255
GetLhsMemoryTile(llvm::Value * column_start,int64 column_count)256 MemoryTile GetLhsMemoryTile(llvm::Value* column_start, int64 column_count) {
257 return MemoryTile(&vsl_, b_, /*matrix=*/lhs_,
258 /*matrix_size_along_minor_dim=*/m(),
259 /*major_dim_offset=*/column_start,
260 /*tile_size_along_major_dim=*/column_count);
261 }
262
263 // Load a tile of values from the RHS. For the RHS a "tile" is a contiguous
264 // sequence of `count` values, each one broadcasted to the vector width.
LoadRhsTile(llvm::Value * offset,int64 count)265 std::vector<llvm::Value*> LoadRhsTile(llvm::Value* offset, int64 count) {
266 llvm::Value* base_pointer = vsl_.ComputeOffsetPointer(rhs_, offset);
267 std::vector<llvm::Value*> result;
268 result.reserve(count);
269 for (int64 i = 0; i < count; i++) {
270 result.push_back(vsl_.LoadBroadcast(base_pointer, i));
271 }
272 return result;
273 }
274
275 void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile,
276 const std::vector<llvm::Value*>& rhs_tile,
277 int64 columns, bool is_first_column);
278
279 void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64 columns,
280 bool is_first_tiled_column);
281
282 Config config_;
283 llvm::Value* lhs_;
284 llvm::Value* rhs_;
285 llvm::Value* addend_;
286 llvm::Value* result_;
287 llvm::IRBuilder<>* b_;
288 KernelSupportLibrary ksl_;
289 VectorSupportLibrary vsl_;
290 };
291
EmitOuterLoopBody(llvm::Value * column,int64 column_count,bool is_first_column)292 void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody(
293 llvm::Value* column, int64 column_count, bool is_first_column) {
294 MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*column_start=*/column,
295 /*column_count=*/column_count);
296
297 std::vector<llvm::Value*> rhs_tile =
298 LoadRhsTile(column, /*count=*/column_count);
299 EmitInnerLoopTiled(&lhs_memory_tile, rhs_tile,
300 /*columns=*/column_count, is_first_column);
301 EmitInnerLoopEpilogue(column, /*columns=*/column_count, is_first_column);
302 }
303
Emit()304 void ColumnMajorMatrixVectorProductEmitter::Emit() {
305 // See the comment on the class declaration for the algorithm used here.
306 int64 column_remainder = k() % tile_cols();
307 int64 column_limit = k() - column_remainder;
308
309 ksl_.For("dot.outer.tiled",
310 /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(),
311 [&](llvm::Value* column, bool is_first_column) {
312 EmitOuterLoopBody(column, tile_cols(), is_first_column);
313 });
314
315 if (column_remainder != 0) {
316 EmitOuterLoopBody(b_->getInt64(column_limit), column_remainder,
317 column_limit == 0);
318 }
319 }
320
EmitInnerLoopTiled(MemoryTile * lhs_memory_tile,const std::vector<llvm::Value * > & rhs_tile,int64 columns,bool is_first_column)321 void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
322 MemoryTile* lhs_memory_tile, const std::vector<llvm::Value*>& rhs_tile,
323 int64 columns, bool is_first_column) {
324 int64 row_limit = m() - (m() % tile_rows());
325
326 ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit,
327 /*step=*/tile_rows(), [&](llvm::Value* row) {
328 std::vector<llvm::Value*> lhs_tile =
329 lhs_memory_tile->LoadTile(/*minor_dim_offset=*/row);
330 llvm::Value* accumulator =
331 is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row)
332 : vsl_.GetZeroVector())
333 : vsl_.LoadVector(result_, row);
334 for (int i = 0; i < columns; i++) {
335 accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator);
336 }
337 vsl_.StoreVector(accumulator, result_, row);
338 });
339 }
340
EmitInnerLoopEpilogue(llvm::Value * current_tile_col,int64 columns,bool is_first_tiled_column)341 void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
342 llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column) {
343 int64 row_start = m() - (m() % tile_rows());
344 if (row_start == m()) {
345 return;
346 }
347
348 llvm::Value* columns_llvm = b_->getInt64(columns);
349
350 // for (col = current_tile_col; col < (columns + current_tile_col); col++)
351 // for (row = row_start, row < m_; row++) {
352 // result[row] += lhs[row, col] * rhs[col]
353 // // Also take into account that if col is 0 then result[row] is not
354 // // initialized.
355 // }
356
357 ksl_.For(
358 "dot.inner.epilg.outer", /*start=*/current_tile_col,
359 /*end=*/b_->CreateAdd(columns_llvm, current_tile_col),
360 /*step=*/1, /*peel_first_iteration=*/false,
361 [&](llvm::Value* col, llvm::Value* is_first_scalar_col) {
362 llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col);
363 llvm::Value* total_offset = b_->CreateMul(col, b_->getInt64(m()));
364 llvm::Value* lhs_base_pointer =
365 vsl_.ComputeOffsetPointer(lhs_, total_offset);
366 ksl_.For(
367 "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m(),
368 /*step=*/1, [&](llvm::Value* scalar_row) {
369 llvm::Value* product = vsl_.Mul(
370 vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element);
371 llvm::Value* setting_result_first_time = b_->CreateAnd(
372 is_first_scalar_col, b_->getInt1(is_first_tiled_column));
373 ksl_.If(
374 setting_result_first_time,
375 /*true_block_generator=*/
376 [&]() {
377 if (addend_) {
378 vsl_.StoreScalar(
379 vsl_.Add(vsl_.LoadScalar(addend_, scalar_row),
380 product),
381 result_, scalar_row);
382 } else {
383 vsl_.StoreScalar(product, result_, scalar_row);
384 }
385 },
386 /*false_block_generator=*/
387 [&]() {
388 vsl_.StoreScalar(
389 vsl_.Add(vsl_.LoadScalar(result_, scalar_row), product),
390 result_, scalar_row);
391 });
392 });
393 });
394 }
395
396 // Computes a dot product between "[M,K]{1,0} lhs" with a [K,1] vector (the
397 // layout of the vector does not matter). This implementation uses a tiling
398 // scheme to improve performance.
399 //
400 // We logically separate the LHS matrix into four segments:
401 //
402 // +----------------------+---+
403 // | | |
404 // | | |
405 // | A | B |
406 // | | |
407 // | | |
408 // | | |
409 // +----------------------+---+
410 // | C | D |
411 // +----------------------+---+
412 //
// where A is the largest submatrix of the LHS that can be evenly divided into
// tiles. For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have:
415 //
416 // +---+---+---+---+
417 // |M00|M10|M20|M30|
418 // +---+---+---+---+ +--+--+--+--+
419 // |M01|M11|M21|M31| and |V0|V1|V2|V3|
420 // +---+---+---+---+ +--+--+--+--+
421 // |M02|M12|M22|M32|
422 // +---+---+---+---+
423 // |M03|M13|M23|M33|
424 // +---+---+---+---+
425 //
426 // (Legend: rows are horizontal and columns are vertical; and each row is one
427 // llvm::Value of a vector type)
428 //
429 // where:
430 //
431 // a. The left tile is loaded from the row major left matrix.
432 // b. The right vector is loaded from the RHS vector.
433 //
434 // We keep 4 vector accumulators accumulating the following four vector
435 // expressions as we iterate over the row dimension:
436 //
437 // +------+------+------+------+
438 // |M0I*V0|M1I*V1|M2I*V2|M3I*V3| for I in [0,4)
439 // +------+------+------+------+
440 //
441 // In the end we do a horizontal reduction over these 4 vector accumulators to
442 // get 4 values in the result vector.
443 //
444 // We have an inner epilogue loop to deal with the "B" sub-matrix and an outer
445 // epilogue loop to deal with the C,D submatrix.
446 class RowMajorMatrixVectorProductEmitter
447 : public GemvConfig::User<RowMajorMatrixVectorProductEmitter> {
448 public:
449 class Config : public GemvConfig {
450 public:
Config(PrimitiveType scalar_type,int64 tile_rows,int64 tile_cols,int64 m,int64 k,bool has_addend)451 explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols,
452 int64 m, int64 k, bool has_addend)
453 : GemvConfig(/*name=*/"row_major_gemv", scalar_type,
454 /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m,
455 /*k=*/k, /*has_addend=*/has_addend) {}
456 };
457
RowMajorMatrixVectorProductEmitter(const Config & config,llvm::Value * lhs,llvm::Value * rhs,llvm::Value * addend,llvm::Value * result,llvm::IRBuilder<> * b)458 RowMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs,
459 llvm::Value* rhs, llvm::Value* addend,
460 llvm::Value* result, llvm::IRBuilder<>* b)
461 : config_(config),
462 lhs_(lhs),
463 rhs_(rhs),
464 addend_(addend),
465 result_(result),
466 b_(b),
467 ksl_(b_),
468 vsl_(scalar_type(), /*vector_size=*/tile_cols(), b_, "") {
469 CHECK(tile_cols() > 0 && IsPowerOfTwo(static_cast<uint64>(tile_cols())));
470 CHECK(!has_addend() || addend != nullptr);
471 }
472
473 void Emit();
474
config() const475 const Config& config() const { return config_; }
476
477 private:
GetLhsMemoryTile(llvm::Value * row_start,int64 row_count)478 MemoryTile GetLhsMemoryTile(llvm::Value* row_start, int64 row_count) {
479 return MemoryTile(&vsl_, b_, /*matrix=*/lhs_,
480 /*matrix_size_along_minor_dim=*/k(),
481 /*major_dim_offset=*/row_start,
482 /*tile_size_along_major_dim=*/row_count);
483 }
484
485 void EmitOuterLoopBody(llvm::Value* row, int64 row_count);
486
487 void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, int64 rows,
488 std::vector<VectorVariable>* vector_accumulators);
489
490 void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64 rows,
491 std::vector<ScalarVariable>* scalar_accumulators);
492
493 Config config_;
494 llvm::Value* lhs_;
495 llvm::Value* rhs_;
496 llvm::Value* addend_;
497 llvm::Value* result_;
498 llvm::IRBuilder<>* b_;
499 KernelSupportLibrary ksl_;
500 VectorSupportLibrary vsl_;
501 };
502
EmitOuterLoopBody(llvm::Value * row,int64 row_count)503 void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row,
504 int64 row_count) {
505 MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*row_start=*/row,
506 /*row_count=*/row_count);
507 std::vector<VectorVariable> vector_accumulators;
508 std::vector<ScalarVariable> scalar_accumulators;
509 for (int i = 0; i < row_count; i++) {
510 vector_accumulators.emplace_back(&vsl_, vsl_.GetZeroVector());
511 scalar_accumulators.emplace_back(&vsl_, vsl_.GetZeroScalar());
512 }
513 EmitInnerLoopTiled(&lhs_memory_tile, /*rows=*/row_count,
514 &vector_accumulators);
515 EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count,
516 &scalar_accumulators);
517
518 std::vector<llvm::Value*> accumulator_values;
519 std::transform(
520 vector_accumulators.begin(), vector_accumulators.end(),
521 std::back_inserter(accumulator_values),
522 [](const VectorVariable& vector_var) { return vector_var.Get(); });
523
524 std::vector<llvm::Value*> horizontal_sums;
525 if (row_count == vsl_.vector_size()) {
526 if (addend_) {
527 horizontal_sums = vsl_.ComputeHorizontalSums(
528 std::move(accumulator_values), vsl_.LoadVector(addend_, row));
529 } else {
530 horizontal_sums =
531 vsl_.ComputeHorizontalSums(std::move(accumulator_values));
532 }
533 } else {
534 horizontal_sums = vsl_.ComputeHorizontalSums(std::move(accumulator_values));
535 }
536
537 for (int i = 0; i < row_count; i++) {
538 llvm::Value* result_value =
539 vsl_.Add(horizontal_sums[i], scalar_accumulators[i].Get());
540 llvm::Value* offset = b_->CreateAdd(b_->getInt64(i), row);
541 if (addend_ && row_count != vsl_.vector_size()) {
542 result_value = vsl_.Add(vsl_.LoadScalar(addend_, offset), result_value);
543 }
544 vsl_.StoreScalar(result_value, result_, offset);
545 }
546 }
547
Emit()548 void RowMajorMatrixVectorProductEmitter::Emit() {
549 // See the comment on the class declaration for the algorithm used here.
550 int64 row_remainder = m() % tile_rows();
551 int64 row_limit = m() - row_remainder;
552
553 ksl_.For("dot.outer.tiled",
554 /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(),
555 [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); });
556
557 if (row_remainder != 0) {
558 EmitOuterLoopBody(b_->getInt64(row_limit), row_remainder);
559 }
560 }
561
EmitInnerLoopTiled(MemoryTile * lhs_memory_tile,int64 rows,std::vector<VectorVariable> * vector_accumulators)562 void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
563 MemoryTile* lhs_memory_tile, int64 rows,
564 std::vector<VectorVariable>* vector_accumulators) {
565 int64 column_limit = k() - (k() % tile_cols());
566
567 ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit,
568 /*step=*/tile_cols(), [&](llvm::Value* col) {
569 std::vector<llvm::Value*> lhs_tile =
570 lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col);
571 llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col);
572 for (int i = 0; i < rows; i++) {
573 llvm::Value* old_sum = (*vector_accumulators)[i].Get();
574 (*vector_accumulators)[i].Set(
575 vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i])));
576 }
577 });
578 }
579
EmitInnerLoopEpilogue(llvm::Value * current_tile_row,int64 rows,std::vector<ScalarVariable> * scalar_accumulators)580 void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
581 llvm::Value* current_tile_row, int64 rows,
582 std::vector<ScalarVariable>* scalar_accumulators) {
583 int64 column_start = k() - (k() % tile_cols());
584 if (column_start == k()) {
585 return;
586 }
587
588 for (int r = 0; r < rows; r++) {
589 llvm::Value* total_offset = b_->CreateMul(
590 b_->CreateAdd(b_->getInt64(r), current_tile_row), b_->getInt64(k()));
591 llvm::Value* lhs_base_pointer =
592 vsl_.ComputeOffsetPointer(lhs_, total_offset);
593 ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(),
594 /*step=*/1, [&](llvm::Value* scalar_col) {
595 llvm::Value* product =
596 vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col),
597 vsl_.LoadScalar(rhs_, scalar_col));
598 llvm::Value* old_value = (*scalar_accumulators)[r].Get();
599 (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product));
600 });
601 }
602 }
603
604 // This class implements a tiled matrix multiplication algorithm, intended for
605 // multiplying small matrices that don't need cache tiling.
606 //
607 // In the future this can be used as the innermost GEBP loop in a GEMM kernel as
608 // described in "Goto, Kazushige, and Robert A. Geijn. "Anatomy of
609 // high-performance matrix multiplication." ACM Transactions on Mathematical
610 // Software (TOMS) 34.3 (2008): 12.".
611 //
612 // This only supports canonical dot operations (i.e. where the lhs contraction
613 // dimension is 1 and the rhs contraction dimension is 0) over row major
614 // matrices.
615 class TiledSmallGemmEmitter {
616 public:
617 // Describe the dimensions of the kernel.
618 class Dimensions {
619 public:
Dimensions(int64 m,int64 k,int64 n)620 explicit Dimensions(int64 m, int64 k, int64 n) : m_(m), k_(k), n_(n) {}
621
m() const622 int64 m() const { return m_; }
k() const623 int64 k() const { return k_; }
n() const624 int64 n() const { return n_; }
625
ToString() const626 string ToString() const { return absl::StrCat(m(), "x", k(), "x", n()); }
627
628 private:
629 const int64 m_;
630 const int64 k_;
631 const int64 n_;
632 };
633
634 // Represents the configuration of the emitter. The LLVM IR emitted by the
635 // emitter, modulo the LLVM values holding the input and output buffers, must
636 // be a function of the instance of `Config` passed to it.
637 //
638 // `dims` holds the matrix multiplication dimensions.
639 //
640 // `max_vectorization_width` is the maximum vector width (i.e. the width of
641 // the largest vector register we will use). This can be larger than the
642 // largest vector register supported by the machine -- LLVM will legalize
643 // these large vector widths into legally sized vectors.
644 //
645 // `max_vector_count` is the maximum number of vectors of size
646 // `max_vectorization_width` that we will attempt to process at once.
647 //
648 // `min_vectorization_width` is the smallest vector width the emitter will use
649 // -- below that it will devolve to using a scalar loop.
650 //
651 // The innermost reduction loop executes the matrix multiply in tiles of size
652 // [`tile_size_m`, `tile_size_k`] from the LHS and [`tile_size_k`,
653 // <vectorization width>] in the RHS.
654 class Config {
655 public:
Config(PrimitiveType scalar_type,Dimensions dims,int64 max_vectorization_width,int64 max_vector_count,int64 min_vectorization_width,int64 tile_size_m,int64 tile_size_k)656 explicit Config(PrimitiveType scalar_type, Dimensions dims,
657 int64 max_vectorization_width, int64 max_vector_count,
658 int64 min_vectorization_width, int64 tile_size_m,
659 int64 tile_size_k)
660 : scalar_type_(scalar_type),
661 dims_(dims),
662 max_vectorization_width_(max_vectorization_width),
663 max_vector_count_(max_vector_count),
664 min_vectorization_width_(min_vectorization_width),
665 tile_size_m_(tile_size_m),
666 tile_size_k_(tile_size_k) {}
667
GetCacheKey() const668 string GetCacheKey() const {
669 return absl::StrCat("gemm_", PrimitiveType_Name(scalar_type()), "_",
670 dims().ToString(), "_", max_vectorization_width(),
671 "_", min_vectorization_width(), "_", tile_size_m(),
672 "_", tile_size_k());
673 }
674
scalar_type() const675 PrimitiveType scalar_type() const { return scalar_type_; }
dims() const676 Dimensions dims() const { return dims_; }
max_vectorization_width() const677 int64 max_vectorization_width() const { return max_vectorization_width_; }
max_vector_count() const678 int64 max_vector_count() const { return max_vector_count_; }
min_vectorization_width() const679 int64 min_vectorization_width() const { return min_vectorization_width_; }
680
tile_size_m() const681 int64 tile_size_m() const { return tile_size_m_; }
tile_size_k() const682 int64 tile_size_k() const { return tile_size_k_; }
683
684 private:
685 PrimitiveType scalar_type_;
686 Dimensions dims_;
687 int64 max_vectorization_width_;
688 int64 max_vector_count_;
689 int64 min_vectorization_width_;
690 int64 tile_size_m_;
691 int64 tile_size_k_;
692 };
693
694 // Creates an instance of TiledSmallGemmEmitter that matrix-multiplies
695 // `lhs` with `rhs` and stores the result in `result`.
TiledSmallGemmEmitter(Config config,llvm::Value * lhs,llvm::Value * rhs,llvm::Value * result,llvm::IRBuilder<> * b)696 explicit TiledSmallGemmEmitter(Config config, llvm::Value* lhs,
697 llvm::Value* rhs, llvm::Value* result,
698 llvm::IRBuilder<>* b)
699 : lhs_(lhs),
700 rhs_(rhs),
701 result_(result),
702 config_(config),
703 b_(b),
704 ksl_(b_) {
705 CHECK(max_vectorization_width() > 0 &&
706 IsPowerOfTwo(static_cast<uint64>(max_vectorization_width())));
707 CHECK_GT(max_vector_count(), 0);
708 CHECK(min_vectorization_width() > 0 &&
709 IsPowerOfTwo(static_cast<uint64>(min_vectorization_width())));
710 CHECK_GE(max_vectorization_width(), min_vectorization_width());
711 CHECK_GT(tile_size_k(), 0);
712 }
713
714 void Emit();
715
716 private:
717 // The HandleResiduesOnX helpers split the iteration space for dimension X
718 // into a multiple of the tile size on dimension X and an epilogue. These
719 // helpers ultimately call into `EmitTiledGemm` for emitting the
720 // tiled GEMM kernel.
721
722 void HandleResiduesOnN();
723 void HandleResiduesOnK(VectorSupportLibrary* vsl, llvm::Value* n_start,
724 llvm::Value* n_end);
725 void HandleResiduesOnM(VectorSupportLibrary* vsl, int64 tile_size_k,
726 llvm::Value* k_start, llvm::Value* k_end,
727 llvm::Value* n_start, llvm::Value* n_end);
728
729 // This emits a tiled GEMM kernel. For a detailed description see the comment
730 // on the implementation.
731 void EmitTiledGemm(VectorSupportLibrary* vsl, int64 tile_size_k,
732 llvm::Value* k_start, llvm::Value* k_end,
733 llvm::Value* n_start, llvm::Value* n_end,
734 int64 tile_size_m, llvm::Value* m_start,
735 llvm::Value* m_end);
736
GetInt64(int64 value)737 llvm::Value* GetInt64(int64 value) { return b_->getInt64(value); }
738
config() const739 Config config() const { return config_; }
dims() const740 Dimensions dims() const { return config().dims(); }
741
max_vectorization_width() const742 int64 max_vectorization_width() const {
743 return config().max_vectorization_width();
744 }
max_vector_count() const745 int64 max_vector_count() const { return config().max_vector_count(); }
min_vectorization_width() const746 int64 min_vectorization_width() const {
747 return config().min_vectorization_width();
748 }
tile_size_m() const749 int64 tile_size_m() const { return config().tile_size_m(); }
tile_size_k() const750 int64 tile_size_k() const { return config().tile_size_k(); }
scalar_type() const751 PrimitiveType scalar_type() const { return config().scalar_type(); }
752
753 llvm::Value* lhs_;
754 llvm::Value* rhs_;
755 llvm::Value* result_;
756 Config config_;
757
758 llvm::IRBuilder<>* b_;
759 KernelSupportLibrary ksl_;
760 };
761
Emit()762 void TiledSmallGemmEmitter::Emit() { HandleResiduesOnN(); }
763
HandleResiduesOnN()764 void TiledSmallGemmEmitter::HandleResiduesOnN() {
765 // We can only iterate the `n` dimension for an extent that is divisible by
766 // the vectorization width. So we emit an outer loop that first processes the
767 // largest extent in `n` that is divisible by max_vectorization_width, then
768 // the largest remaining extent that is divisible by max_vectorization_width /
769 // 2 etc.
770
771 int64 current_vectorization_width =
772 max_vector_count() * max_vectorization_width();
773 int64 current_vector_count = max_vector_count();
774
775 int64 n_start = 0;
776 while (n_start != dims().n() &&
777 current_vectorization_width >= min_vectorization_width()) {
778 int64 n_end = dims().n() - (dims().n() % current_vectorization_width);
779 if (n_start != n_end) {
780 VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, b_,
781 "gemm");
782 HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end));
783 n_start = n_end;
784 }
785 if (current_vector_count == 1) {
786 current_vectorization_width /= 2;
787 } else {
788 current_vector_count--;
789 current_vectorization_width =
790 current_vector_count * max_vectorization_width();
791 }
792 }
793
794 if (n_start != dims().n()) {
795 VectorSupportLibrary vsl(scalar_type(), 1, b_, "gemm");
796 ksl_.For("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) {
797 llvm::Value* n_i_next = b_->CreateAdd(n_i, b_->getInt64(1));
798 HandleResiduesOnK(&vsl, n_i, n_i_next);
799 });
800 }
801 }
802
HandleResiduesOnK(VectorSupportLibrary * vsl,llvm::Value * n_start,llvm::Value * n_end)803 void TiledSmallGemmEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl,
804 llvm::Value* n_start,
805 llvm::Value* n_end) {
806 int64 k_start = 0;
807 int64 k_end = dims().k() - (dims().k() % tile_size_k());
808 if (k_end != k_start) {
809 HandleResiduesOnM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end),
810 n_start, n_end);
811 k_start = k_end;
812 }
813
814 if (k_start != dims().k()) {
815 HandleResiduesOnM(vsl, dims().k() - k_start, GetInt64(k_start),
816 GetInt64(dims().k()), n_start, n_end);
817 }
818 }
819
HandleResiduesOnM(VectorSupportLibrary * vsl,int64 tile_size_k,llvm::Value * k_start,llvm::Value * k_end,llvm::Value * n_start,llvm::Value * n_end)820 void TiledSmallGemmEmitter::HandleResiduesOnM(
821 VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start,
822 llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) {
823 const int64 m_end = dims().m() - dims().m() % tile_size_m();
824 EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, tile_size_m(),
825 GetInt64(0), GetInt64(m_end));
826
827 if (m_end != dims().m()) {
828 EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end,
829 dims().m() - m_end, GetInt64(m_end), GetInt64(dims().m()));
830 }
831 }
832
833 // The loop structure is:
834 //
835 // Iterate over dimension M as m:
836 // Iterate over dimension N as n:
837 // Iterate over dimension K as k:
838 // OutputTile[m,n] += Dot(LhsTile[m,k], RhsTile[k,n])
839 //
// I.e. just a tiled version of a "naive" GEMM.
841 //
842 // The tiling scheme is as follows:
843 //
844 // Let the LHS be:
845 //
846 // +----+----+----+
847 // | a0 | b0 | c0 | .
848 // +----+----+----+ .
849 // | a1 | b1 | c1 | .
850 // +----+----+----+
851 // .. ..
852 //
853 // and the RHS be:
854 //
855 // +----+----+----+----+
856 // | p0 | p1 | p2 | p3 | .
857 // +----+----+----+----+ .
858 // | q0 | q1 | q2 | q3 | .
859 // +----+----+----+----+
860 // | r0 | r1 | r2 | r3 | .
861 // +----+----+----+----+ .
862 // ...... ......
863 //
864 // and let tile_size_m=2, tile_size_k=3 and the vector width (implicitly denoted
865 // by `vsl`) be 4. Then we want to matrix multiply this tile to get a [2,4]
866 // matrix that we can increment the result matrix by.
867 //
// First broadcast each element of the two LHS tile rows to a vector of width 4,
// giving us a rank 3 array, L, of dimension [2,3,4]:
870 //
871 // L[0,_,_] * L[1,_,_]
872 // *
873 // +----+----+----+----+ * +----+----+----+----+
874 // | a0 | a0 | a0 | a0 | * | a1 | a1 | a1 | a1 |
875 // +----+----+----+----+ * +----+----+----+----+
876 // | b0 | b0 | b0 | b0 | * | b1 | b1 | b1 | b1 |
877 // +----+----+----+----+ * +----+----+----+----+
878 // | c0 | c0 | c0 | c0 | * | c1 | c1 | c1 | c1 |
879 // +----+----+----+----+ * +----+----+----+----+
880 //
881 //
// Then we FMA L[0,_,_] with the RHS to get the first row of the result and
// L[1,_,_] with the RHS to get the second row of the result. For example, the
// first row of the result is computed from L[0,_,_] as:
885 //
886 // +----+----+----+----+ +----+----+----+----+
887 // | a0 | a0 | a0 | a0 | * | p0 | p1 | p2 | p3 | +
888 // +----+----+----+----+ +----+----+----+----+
889 //
890 // +----+----+----+----+ +----+----+----+----+
891 // | b0 | b0 | b0 | b0 | * | q0 | q1 | q2 | q3 | +
892 // +----+----+----+----+ +----+----+----+----+
893 //
894 // +----+----+----+----+ +----+----+----+----+
895 // | c0 | c0 | c0 | c0 | * | r0 | r1 | r2 | r3 |
896 // +----+----+----+----+ +----+----+----+----+
897 //
898 // to get:
899 //
900 // +-------------------+-------------------+-------------------+---------
901 // | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 | ...
902 // +-------------------+-------------------+-------------------+---------
EmitTiledGemm(VectorSupportLibrary * vsl,int64 tile_size_k,llvm::Value * k_start,llvm::Value * k_end,llvm::Value * n_start,llvm::Value * n_end,int64 tile_size_m,llvm::Value * m_start,llvm::Value * m_end)903 void TiledSmallGemmEmitter::EmitTiledGemm(
904 VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start,
905 llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end,
906 int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) {
907 ksl_.For("dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) {
908 MemoryTile result_memory_tile(vsl, b_, /*matrix=*/result_,
909 /*matrix_size_along_minor_dim=*/dims().n(),
910 /*major_dim_offset=*/m_i,
911 /*tile_size_along_major_dim=*/tile_size_m);
912 MemoryTile lhs_memory_tile(vsl, b_, /*matrix=*/lhs_,
913 /*matrix_size_along_minor_dim=*/dims().k(),
914 /*major_dim_offset=*/m_i,
915 /*tile_size_along_major_dim=*/tile_size_m);
916 ksl_.For(
917 "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) {
918 TileVariable result_tile_var(vsl, result_memory_tile.LoadTile(n_i));
919 ksl_.For("dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) {
920 MemoryTile rhs_memory_tile(vsl, b_, rhs_, dims().n(), k_i,
921 tile_size_k);
922 std::vector<std::vector<llvm::Value*>> lhs_tile =
923 lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k);
924 std::vector<llvm::Value*> rhs_tile = rhs_memory_tile.LoadTile(n_i);
925 std::vector<llvm::Value*> result_tile = result_tile_var.Get();
926 for (int64 r_m_i = 0; r_m_i < tile_size_m; r_m_i++) {
927 for (int64 r_k_i = 0; r_k_i < tile_size_k; r_k_i++) {
928 result_tile[r_m_i] =
929 vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i],
930 result_tile[r_m_i]);
931 }
932 }
933 result_tile_var.Set(result_tile);
934 });
935
936 result_memory_tile.StoreTile(result_tile_var.Get(), n_i);
937 });
938 });
939 }
940
GetPointerToElementType(llvm::Type * pointer_type)941 llvm::Type* GetPointerToElementType(llvm::Type* pointer_type) {
942 llvm::Type* type =
943 llvm::cast<llvm::PointerType>(pointer_type)->getElementType();
944 while (auto* array_type = llvm::dyn_cast<llvm::ArrayType>(type)) {
945 type = array_type->getElementType();
946 }
947
948 return type->getPointerTo();
949 }
950
951 struct GemvBuffersWithCanonicalType {
952 llvm::Value* lhs_canonicalized;
953 llvm::Value* rhs_canonicalized;
954 llvm::Value* addend_canonicalized;
955 llvm::Value* result_canonicalized;
956 };
957
GetGemvBuffersWithCanonicalType(llvm::Value * lhs,llvm::Value * rhs,llvm::Value * addend,llvm::Value * result,llvm::IRBuilder<> * b)958 GemvBuffersWithCanonicalType GetGemvBuffersWithCanonicalType(
959 llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend,
960 llvm::Value* result, llvm::IRBuilder<>* b) {
961 // We characterize a GEMV operation via M and K, since N is implicitly 1.
962 // This means the GEMV that multiplies (say) [5,6] with [6,1] is implemented
963 // by the same GEMV that multiplies [5,6] with [1,6]. However, the
964 // `llvm::Types` for the inputs to the two GEMVs don't match (in a trivial
965 // sense -- the in memory representations are the same) since they're computed
966 // from the `xla::Shape`s. Since we want to be able to call the same
967 // `llvm::Function` for the two GEMVs we canonicalize the types of the GEMV
968 // inputs here into the same type.
969 GemvBuffersWithCanonicalType buffers_with_canonical_type;
970 llvm::Type* lhs_type = lhs->getType();
971 llvm::Type* rhs_type = rhs->getType();
972 llvm::Type* addend_type = addend ? addend->getType() : nullptr;
973 llvm::Type* result_type = result->getType();
974
975 buffers_with_canonical_type.lhs_canonicalized =
976 b->CreateBitCast(lhs, GetPointerToElementType(lhs_type));
977 buffers_with_canonical_type.rhs_canonicalized =
978 b->CreateBitCast(rhs, GetPointerToElementType(rhs_type));
979 buffers_with_canonical_type.addend_canonicalized =
980 addend ? b->CreateBitCast(addend, GetPointerToElementType(addend_type))
981 : nullptr;
982 buffers_with_canonical_type.result_canonicalized =
983 b->CreateBitCast(result, GetPointerToElementType(result_type));
984
985 return buffers_with_canonical_type;
986 }
987
988 } // namespace
989
EmitRowMajorGemv(PrimitiveType scalar_type,int64 tile_rows,int64 tile_cols,int64 m,int64 k,llvm::Value * lhs,llvm::Value * rhs,llvm::Value * addend,llvm::Value * result,llvm::IRBuilder<> * b,const HloModuleConfig & module_config)990 void EmitRowMajorGemv(PrimitiveType scalar_type, int64 tile_rows,
991 int64 tile_cols, int64 m, int64 k, llvm::Value* lhs,
992 llvm::Value* rhs, llvm::Value* addend,
993 llvm::Value* result, llvm::IRBuilder<>* b,
994 const HloModuleConfig& module_config) {
995 RowMajorMatrixVectorProductEmitter::Config config(
996 /*scalar_type=*/scalar_type,
997 /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols,
998 /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr);
999
1000 GemvBuffersWithCanonicalType canonical_inputs =
1001 GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b);
1002
1003 KernelSupportLibrary::EmitAndCallOutlinedKernel(
1004 module_config, b, config.GetCacheKey(),
1005 canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized,
1006 canonical_inputs.addend_canonicalized,
1007 canonical_inputs.result_canonicalized,
1008 [&config, b, &canonical_inputs](llvm::Value* lhs, llvm::Value* rhs,
1009 llvm::Value* addend,
1010 llvm::Value* result) {
1011 RowMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend,
1012 result, b);
1013 emitter.Emit();
1014 });
1015 }
1016
EmitColumnMajorGemv(PrimitiveType scalar_type,int64 tile_rows,int64 tile_cols,int64 m,int64 k,llvm::Value * lhs,llvm::Value * rhs,llvm::Value * addend,llvm::Value * result,llvm::IRBuilder<> * b,const HloModuleConfig & module_config)1017 void EmitColumnMajorGemv(PrimitiveType scalar_type, int64 tile_rows,
1018 int64 tile_cols, int64 m, int64 k, llvm::Value* lhs,
1019 llvm::Value* rhs, llvm::Value* addend,
1020 llvm::Value* result, llvm::IRBuilder<>* b,
1021 const HloModuleConfig& module_config) {
1022 ColumnMajorMatrixVectorProductEmitter::Config config(
1023 /*scalar_type=*/scalar_type,
1024 /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols,
1025 /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr);
1026
1027 GemvBuffersWithCanonicalType canonical_inputs =
1028 GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b);
1029
1030 KernelSupportLibrary::EmitAndCallOutlinedKernel(
1031 module_config, b, config.GetCacheKey(),
1032 canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized,
1033 canonical_inputs.addend_canonicalized,
1034 canonical_inputs.result_canonicalized,
1035 [&config, b, &canonical_inputs](llvm::Value* lhs, llvm::Value* rhs,
1036 llvm::Value* addend,
1037 llvm::Value* result) {
1038 ColumnMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend,
1039 result, b);
1040 emitter.Emit();
1041 });
1042 }
1043
EmitSmallGemm(PrimitiveType scalar_type,int64 m,int64 k,int64 n,int64 max_vectorization_width,int64 max_vector_count,int64 min_vectorization_width,int64 tile_size_m,int64 tile_size_k,llvm::Value * lhs,llvm::Value * rhs,llvm::Value * result,llvm::IRBuilder<> * b,const HloModuleConfig & module_config)1044 void EmitSmallGemm(PrimitiveType scalar_type, int64 m, int64 k, int64 n,
1045 int64 max_vectorization_width, int64 max_vector_count,
1046 int64 min_vectorization_width, int64 tile_size_m,
1047 int64 tile_size_k, llvm::Value* lhs, llvm::Value* rhs,
1048 llvm::Value* result, llvm::IRBuilder<>* b,
1049 const HloModuleConfig& module_config) {
1050 TiledSmallGemmEmitter::Config config(
1051 /*scalar_type=*/scalar_type,
1052 TiledSmallGemmEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n},
1053 /*max_vectorization_width=*/max_vectorization_width,
1054 /*max_vector_count=*/max_vector_count,
1055 /*min_vectorization_width=*/min_vectorization_width,
1056 /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k);
1057
1058 KernelSupportLibrary::EmitAndCallOutlinedKernel(
1059 module_config, b, config.GetCacheKey(), lhs, rhs, result,
1060 [&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result) {
1061 TiledSmallGemmEmitter small_gemm_emitter(config, /*lhs=*/lhs,
1062 /*rhs=*/rhs,
1063 /*result=*/result, b);
1064 small_gemm_emitter.Emit();
1065 });
1066 }
1067
1068 } // namespace cpu
1069 } // namespace xla
1070