/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h"

#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"

namespace xla {
namespace cpu {
namespace {

using ::int64_t;

// Provides tiled access to an in-memory rank 2 array.
class MemoryTile {
 public:
  // Constructs a MemoryTile that can operate on tiles consisting of
  // `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at
  // `major_dim_offset` in the major dimension. The tile size along the minor
  // dimension is the vector size, and that is implicitly determined by `vsl`.
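  //
  // (Concretely, the constructor below precomputes one base pointer per major
  // position: pointer i points at element [major_dim_offset + i, 0] of
  // `matrix`, viewed as a row-major [major, minor] array.)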
  MemoryTile(VectorSupportLibrary* vsl, llvm::IRBuilder<>* b,
             llvm::Value* matrix, int64_t matrix_size_along_minor_dim,
             llvm::Value* major_dim_offset, int64_t tile_size_along_major_dim)
      : vsl_(vsl), b_(b) {
    pointers_.reserve(tile_size_along_major_dim);
    for (int64_t i = 0; i < tile_size_along_major_dim; i++) {
      llvm::Value* total_offset =
          b->CreateMul(b->getInt64(matrix_size_along_minor_dim),
                       b->CreateAdd(b->getInt64(i), major_dim_offset));
      pointers_.push_back(vsl_->ComputeOffsetPointer(matrix, total_offset));
    }
  }

  // Load a tile consisting of `tile_size_along_major_dim` vectors from
  // position {major: `major_dim_offset`, minor: `minor_dim_offset`}.
  //
  // Note: `major_dim_offset` is a parameter to the constructor.
  std::vector<llvm::Value*> LoadTile(llvm::Value* minor_dim_offset) const {
    std::vector<llvm::Value*> result;
    result.reserve(pointers_.size());
    for (const auto& pointer : pointers_) {
      result.push_back(vsl_->LoadVector(pointer, minor_dim_offset));
    }
    return result;
  }

  // Stores `tile` to position {major: `major_dim_offset`, minor:
  // `minor_dim_offset`}.
  //
  // Note: `major_dim_offset` is a parameter to the constructor.
  void StoreTile(absl::Span<llvm::Value* const> tile,
                 llvm::Value* minor_dim_offset) const {
    CHECK_EQ(tile.size(), pointers_.size());
    for (int64_t i = 0; i < pointers_.size(); i++) {
      vsl_->StoreVector(tile[i], pointers_[i], minor_dim_offset);
    }
  }

  // Loads a tile of size [`tile_size_along_major_dim`,
  // `tile_size_along_middle_dim`] from position {major: `major_dim_offset`,
  // minor: `minor_dim_offset`} and then broadcasts each element into a vector
  // of size vsl_.vector_size(). The (i,j)'th element of the return value is
  // the (i,j)'th element in the tile broadcasted into an LLVM vector.
  //
  // Note: `major_dim_offset` is a parameter to the constructor.
  std::vector<std::vector<llvm::Value*>> LoadBroadcastTile(
      llvm::Value* minor_dim_offset, int64_t tile_size_along_middle_dim) const {
    std::vector<std::vector<llvm::Value*>> result;
    result.resize(pointers_.size());
    for (int64_t i = 0; i < pointers_.size(); i++) {
      for (int64_t j = 0; j < tile_size_along_middle_dim; j++) {
        result[i].push_back(vsl_->LoadBroadcast(
            pointers_[i], b_->CreateAdd(minor_dim_offset, b_->getInt64(j))));
      }
    }
    return result;
  }

 private:
  VectorSupportLibrary* vsl_;
  llvm::IRBuilder<>* b_;
  std::vector<llvm::Value*> pointers_;
};

// The base class for the classes representing the GEMV emitter configurations.
//
// The IR emitted (modulo the LLVM values representing the input and output
// buffers) by the row major and column major GEMV emitters should be a
// function of their configuration. This is important because their
// configuration is used as a key to cache the generated IR.
class GemvConfig {
 public:
  // Mixin for convenience.
  template <typename T>
  struct User {
   public:
    PrimitiveType scalar_type() const {
      return derived().config().scalar_type();
    }
    int64_t tile_rows() const { return derived().config().tile_rows(); }
    int64_t tile_cols() const { return derived().config().tile_cols(); }
    int64_t m() const { return derived().config().m(); }
    int64_t k() const { return derived().config().k(); }
    bool has_addend() const { return derived().config().has_addend(); }

   private:
    const T& derived() const { return *static_cast<const T*>(this); }
  };

  PrimitiveType scalar_type() const { return scalar_type_; }
  int64_t tile_rows() const { return tile_rows_; }
  int64_t tile_cols() const { return tile_cols_; }
  int64_t m() const { return m_; }
  int64_t k() const { return k_; }
  bool has_addend() const { return has_addend_; }

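  // The cache key is a plain concatenation of the config fields. For example,
  // an F32 column-major GEMV config with tile_rows=8, tile_cols=8, m=1024,
  // k=1024 and an addend would produce
  // "col_major_gemv_F32_8_8_1024_1024_with_addend" (values are illustrative).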
  std::string GetCacheKey() const {
    return absl::StrCat(name_, "_", PrimitiveType_Name(scalar_type()), "_",
                        tile_rows(), "_", tile_cols(), "_", m(), "_", k(),
                        has_addend() ? "_with_addend" : "");
  }

 protected:
  explicit GemvConfig(std::string name, PrimitiveType scalar_type,
                      int64_t tile_rows, int64_t tile_cols, int64_t m,
                      int64_t k, bool has_addend)
      : name_(std::move(name)),
        scalar_type_(scalar_type),
        tile_rows_(tile_rows),
        tile_cols_(tile_cols),
        m_(m),
        k_(k),
        has_addend_(has_addend) {}

 private:
  std::string name_;
  PrimitiveType scalar_type_;
  int64_t tile_rows_;
  int64_t tile_cols_;
  int64_t m_;
  int64_t k_;
  bool has_addend_;
};

// Computes a dot product of an "[M,K]{0,1} lhs" with a [K,1] vector (the
// layout of the vector does not matter). This implementation uses a tiling
// scheme to improve performance.
//
// We logically separate the LHS matrix into four segments:
//
//   +----------------------+---+
//   |                      |   |
//   |                      |   |
//   |          A           | B |
//   |                      |   |
//   |                      |   |
//   |                      |   |
//   +----------------------+---+
//   |          C           | D |
//   +----------------------+---+
//
// where A is the largest submatrix of the LHS that can be evenly divided into
// tiles. For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have:
//
//   +---+---+---+---+       +--+--+--+--+
//   |M00|M10|M20|M30|       |V0|V1|V2|V3|
//   +---+---+---+---+       +--+--+--+--+
//   |M01|M11|M21|M31| and   |V0|V1|V2|V3|
//   +---+---+---+---+       +--+--+--+--+
//   |M02|M12|M22|M32|       |V0|V1|V2|V3|
//   +---+---+---+---+       +--+--+--+--+
//   |M03|M13|M23|M33|       |V0|V1|V2|V3|
//   +---+---+---+---+       +--+--+--+--+
//
// (Legend: rows are horizontal and columns are vertical; and each column is one
// llvm::Value of a vector type)
//
// where:
//
//   a. The left tile is from the column major left matrix.
//   b. The right tile is an elementwise broadcast of a [V0, V1, V2, V3]
//      vector loaded from the RHS vector.
//
// As we iterate through the column dimension, we compute the change to the
// result vector by an elementwise multiplication between the two tiles above
// followed by a reduction along the major dimension:
//
//                       +-----------------------------------+
//                       | M00*V0 + M10*V1 + M20*V2 + M30*V3 |
//                       +-----------------------------------+
//                       | M01*V0 + M11*V1 + M21*V2 + M31*V3 |
//   Result[R:R+4] +=    +-----------------------------------+
//                       | M02*V0 + M12*V1 + M22*V2 + M32*V3 |
//                       +-----------------------------------+
//                       | M03*V0 + M13*V1 + M23*V2 + M33*V3 |
//                       +-----------------------------------+
//
// Where R is the starting row for the tile.
//
// We have an inner epilogue loop to deal with the "C" submatrix and an outer
// epilogue loop to deal with the B and D submatrices.
//
// TODO(sanjoy): We should investigate whether gather loads and scatter stores
// can be used here to share the same inner loop for both column-major and
// row-major matrix-vector products.
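//
// As a rough sketch of the emitted loop nest (pseudo-C, for illustration only;
// the real loops below are emitted via KernelSupportLibrary and
// VectorSupportLibrary):
//
//   for (col = 0; col < k - k % tile_cols; col += tile_cols)    // outer tiled
//     for (row = 0; row < m - m % tile_rows; row += tile_rows)  // inner tiled
//       result[row : row+tile_rows] +=
//           lhs[row : row+tile_rows, col : col+tile_cols] *
//           rhs[col : col+tile_cols];
//     // ...scalar epilogue over the remaining rows (submatrix C)...
//   // ...one more outer iteration over the remaining columns (B and D)...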
class ColumnMajorMatrixVectorProductEmitter
    : public GemvConfig::User<ColumnMajorMatrixVectorProductEmitter> {
 public:
  class Config : public GemvConfig {
   public:
    explicit Config(PrimitiveType scalar_type, int64_t tile_rows,
                    int64_t tile_cols, int64_t m, int64_t k, bool has_addend)
        : GemvConfig(/*name=*/"col_major_gemv", scalar_type,
                     /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m,
                     /*k=*/k, /*has_addend=*/has_addend) {}
  };

  ColumnMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs,
                                        llvm::Value* rhs, llvm::Value* addend,
                                        llvm::Value* result,
                                        llvm::IRBuilder<>* b)
      : config_(config),
        lhs_(lhs),
        rhs_(rhs),
        addend_(addend),
        result_(result),
        b_(b),
        ksl_(b_),
        vsl_(config.scalar_type(), /*vector_size=*/config.tile_rows(), b_, "") {
    CHECK(tile_rows() > 0 &&
          absl::has_single_bit(static_cast<uint64_t>(tile_rows())));
    CHECK(!has_addend() || addend != nullptr);
  }

  void Emit();

  const Config& config() const { return config_; }

 private:
  void EmitOuterLoopBody(llvm::Value* column, int64_t column_count,
                         bool is_first_column);

  MemoryTile GetLhsMemoryTile(llvm::Value* column_start, int64_t column_count) {
    return MemoryTile(&vsl_, b_, /*matrix=*/lhs_,
                      /*matrix_size_along_minor_dim=*/m(),
                      /*major_dim_offset=*/column_start,
                      /*tile_size_along_major_dim=*/column_count);
  }

  // Load a tile of values from the RHS. For the RHS a "tile" is a contiguous
  // sequence of `count` values, each one broadcasted to the vector width.
  std::vector<llvm::Value*> LoadRhsTile(llvm::Value* offset, int64_t count) {
    llvm::Value* base_pointer = vsl_.ComputeOffsetPointer(rhs_, offset);
    std::vector<llvm::Value*> result;
    result.reserve(count);
    for (int64_t i = 0; i < count; i++) {
      result.push_back(vsl_.LoadBroadcast(base_pointer, i));
    }
    return result;
  }

  void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile,
                          const std::vector<llvm::Value*>& rhs_tile,
                          int64_t columns, bool is_first_column);

  void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64_t columns,
                             bool is_first_tiled_column);

  Config config_;
  llvm::Value* lhs_;
  llvm::Value* rhs_;
  llvm::Value* addend_;
  llvm::Value* result_;
  llvm::IRBuilder<>* b_;
  KernelSupportLibrary ksl_;
  VectorSupportLibrary vsl_;
};

void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody(
    llvm::Value* column, int64_t column_count, bool is_first_column) {
  MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*column_start=*/column,
                                                /*column_count=*/column_count);

  std::vector<llvm::Value*> rhs_tile =
      LoadRhsTile(column, /*count=*/column_count);
  EmitInnerLoopTiled(&lhs_memory_tile, rhs_tile,
                     /*columns=*/column_count, is_first_column);
  EmitInnerLoopEpilogue(column, /*columns=*/column_count, is_first_column);
}

void ColumnMajorMatrixVectorProductEmitter::Emit() {
  // See the comment on the class declaration for the algorithm used here.
  int64_t column_remainder = k() % tile_cols();
  int64_t column_limit = k() - column_remainder;

  ksl_.For("dot.outer.tiled",
           /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(),
           [&](llvm::Value* column, bool is_first_column) {
             EmitOuterLoopBody(column, tile_cols(), is_first_column);
           });

  if (column_remainder != 0) {
    EmitOuterLoopBody(b_->getInt64(column_limit), column_remainder,
                      column_limit == 0);
  }
}

void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
    MemoryTile* lhs_memory_tile, const std::vector<llvm::Value*>& rhs_tile,
    int64_t columns, bool is_first_column) {
  int64_t row_limit = m() - (m() % tile_rows());

  ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit,
           /*step=*/tile_rows(), [&](llvm::Value* row) {
             std::vector<llvm::Value*> lhs_tile =
                 lhs_memory_tile->LoadTile(/*minor_dim_offset=*/row);
             llvm::Value* accumulator =
                 is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row)
                                            : vsl_.GetZeroVector())
                                 : vsl_.LoadVector(result_, row);
             for (int i = 0; i < columns; i++) {
               accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator);
             }
             vsl_.StoreVector(accumulator, result_, row);
           });
}

void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
    llvm::Value* current_tile_col, int64_t columns,
    bool is_first_tiled_column) {
  int64_t row_start = m() - (m() % tile_rows());
  if (row_start == m()) {
    return;
  }

  llvm::Value* columns_llvm = b_->getInt64(columns);

  // for (col = current_tile_col; col < (columns + current_tile_col); col++)
  //   for (row = row_start, row < m_; row++) {
  //     result[row] += lhs[row, col] * rhs[col]
  //     // Also take into account that if col is 0 then result[row] is not
  //     // initialized.
  //   }

  ksl_.For(
      "dot.inner.epilg.outer", /*start=*/current_tile_col,
      /*end=*/b_->CreateAdd(columns_llvm, current_tile_col),
      /*step=*/1, /*peel_first_iteration=*/false,
      [&](llvm::Value* col, llvm::Value* is_first_scalar_col) {
        llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col);
        llvm::Value* total_offset = b_->CreateMul(col, b_->getInt64(m()));
        llvm::Value* lhs_base_pointer =
            vsl_.ComputeOffsetPointer(lhs_, total_offset);
        ksl_.For(
            "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m(),
            /*step=*/1, [&](llvm::Value* scalar_row) {
              llvm::Value* product = vsl_.Mul(
                  vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element);
              llvm::Value* setting_result_first_time = b_->CreateAnd(
                  is_first_scalar_col, b_->getInt1(is_first_tiled_column));
              ksl_.If(
                  setting_result_first_time,
                  /*true_block_generator=*/
                  [&]() {
                    if (addend_) {
                      vsl_.StoreScalar(
                          vsl_.Add(vsl_.LoadScalar(addend_, scalar_row),
                                   product),
                          result_, scalar_row);
                    } else {
                      vsl_.StoreScalar(product, result_, scalar_row);
                    }
                  },
                  /*false_block_generator=*/
                  [&]() {
                    vsl_.StoreScalar(
                        vsl_.Add(vsl_.LoadScalar(result_, scalar_row), product),
                        result_, scalar_row);
                  });
            });
      });
}

// Computes a dot product of an "[M,K]{1,0} lhs" with a [K,1] vector (the
// layout of the vector does not matter). This implementation uses a tiling
// scheme to improve performance.
//
// We logically separate the LHS matrix into four segments:
//
//   +----------------------+---+
//   |                      |   |
//   |                      |   |
//   |          A           | B |
//   |                      |   |
//   |                      |   |
//   |                      |   |
//   +----------------------+---+
//   |          C           | D |
//   +----------------------+---+
//
// where A is the largest submatrix of the LHS that can be evenly divided into
// tiles. For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have:
//
//   +---+---+---+---+
//   |M00|M10|M20|M30|
//   +---+---+---+---+       +--+--+--+--+
//   |M01|M11|M21|M31| and   |V0|V1|V2|V3|
//   +---+---+---+---+       +--+--+--+--+
//   |M02|M12|M22|M32|
//   +---+---+---+---+
//   |M03|M13|M23|M33|
//   +---+---+---+---+
//
// (Legend: rows are horizontal and columns are vertical; and each row is one
// llvm::Value of a vector type)
//
// where:
//
//   a. The left tile is loaded from the row major left matrix.
//   b. The right vector is loaded from the RHS vector.
//
// We keep 4 vector accumulators accumulating the following four vector
// expressions as we iterate over the row dimension:
//
//   +------+------+------+------+
//   |M0I*V0|M1I*V1|M2I*V2|M3I*V3|  for I in [0,4)
//   +------+------+------+------+
//
// In the end we do a horizontal reduction over these 4 vector accumulators to
// get 4 values in the result vector.
//
// We have an inner epilogue loop to deal with the "B" submatrix and an outer
// epilogue loop to deal with the C and D submatrices.
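//
// A rough sketch of the emitted computation for one row tile (pseudo-C, for
// illustration only; the real loops below are emitted via
// KernelSupportLibrary and VectorSupportLibrary):
//
//   acc[0 .. tile_rows) = zero vectors;
//   for (col = 0; col < k - k % tile_cols; col += tile_cols)  // inner tiled
//     for (i = 0; i < tile_rows; i++)
//       acc[i] += lhs[row+i, col : col+tile_cols] * rhs[col : col+tile_cols];
//   // ...scalar epilogue over the remaining columns (submatrix B)...
//   result[row : row+tile_rows] =
//       horizontal_sum(acc) + scalar epilogue sums (+ addend, if present);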
class RowMajorMatrixVectorProductEmitter
    : public GemvConfig::User<RowMajorMatrixVectorProductEmitter> {
 public:
  class Config : public GemvConfig {
   public:
    explicit Config(PrimitiveType scalar_type, int64_t tile_rows,
                    int64_t tile_cols, int64_t m, int64_t k, bool has_addend)
        : GemvConfig(/*name=*/"row_major_gemv", scalar_type,
                     /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m,
                     /*k=*/k, /*has_addend=*/has_addend) {}
  };

  RowMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs,
                                     llvm::Value* rhs, llvm::Value* addend,
                                     llvm::Value* result, llvm::IRBuilder<>* b)
      : config_(config),
        lhs_(lhs),
        rhs_(rhs),
        addend_(addend),
        result_(result),
        b_(b),
        ksl_(b_),
        vsl_(scalar_type(), /*vector_size=*/tile_cols(), b_, "") {
    CHECK(tile_cols() > 0 &&
          absl::has_single_bit(static_cast<uint64_t>(tile_cols())));
    CHECK(!has_addend() || addend != nullptr);
  }

  void Emit();

  const Config& config() const { return config_; }

 private:
  MemoryTile GetLhsMemoryTile(llvm::Value* row_start, int64_t row_count) {
    return MemoryTile(&vsl_, b_, /*matrix=*/lhs_,
                      /*matrix_size_along_minor_dim=*/k(),
                      /*major_dim_offset=*/row_start,
                      /*tile_size_along_major_dim=*/row_count);
  }

  void EmitOuterLoopBody(llvm::Value* row, int64_t row_count);

  void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, int64_t rows,
                          std::vector<VectorVariable>* vector_accumulators);

  void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64_t rows,
                             std::vector<ScalarVariable>* scalar_accumulators);

  Config config_;
  llvm::Value* lhs_;
  llvm::Value* rhs_;
  llvm::Value* addend_;
  llvm::Value* result_;
  llvm::IRBuilder<>* b_;
  KernelSupportLibrary ksl_;
  VectorSupportLibrary vsl_;
};

void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row,
                                                           int64_t row_count) {
  MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*row_start=*/row,
                                                /*row_count=*/row_count);
  std::vector<VectorVariable> vector_accumulators;
  std::vector<ScalarVariable> scalar_accumulators;
  vector_accumulators.reserve(row_count);
  scalar_accumulators.reserve(row_count);
  for (int64_t i = 0; i < row_count; i++) {
    vector_accumulators.emplace_back(&vsl_, vsl_.GetZeroVector());
    scalar_accumulators.emplace_back(&vsl_, vsl_.GetZeroScalar());
  }
  EmitInnerLoopTiled(&lhs_memory_tile, /*rows=*/row_count,
                     &vector_accumulators);
  EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count,
                        &scalar_accumulators);

  std::vector<llvm::Value*> accumulator_values;
  std::transform(
      vector_accumulators.begin(), vector_accumulators.end(),
      std::back_inserter(accumulator_values),
      [](const VectorVariable& vector_var) { return vector_var.Get(); });

  std::vector<llvm::Value*> horizontal_sums;
  if (row_count == vsl_.vector_size()) {
    if (addend_) {
      horizontal_sums = vsl_.ComputeHorizontalSums(
          std::move(accumulator_values), vsl_.LoadVector(addend_, row));
    } else {
      horizontal_sums =
          vsl_.ComputeHorizontalSums(std::move(accumulator_values));
    }
  } else {
    horizontal_sums = vsl_.ComputeHorizontalSums(std::move(accumulator_values));
  }

  for (int i = 0; i < row_count; i++) {
    llvm::Value* result_value =
        vsl_.Add(horizontal_sums[i], scalar_accumulators[i].Get());
    llvm::Value* offset = b_->CreateAdd(b_->getInt64(i), row);
    if (addend_ && row_count != vsl_.vector_size()) {
      result_value = vsl_.Add(vsl_.LoadScalar(addend_, offset), result_value);
    }
    vsl_.StoreScalar(result_value, result_, offset);
  }
}

void RowMajorMatrixVectorProductEmitter::Emit() {
  // See the comment on the class declaration for the algorithm used here.
  int64_t row_remainder = m() % tile_rows();
  int64_t row_limit = m() - row_remainder;

  ksl_.For("dot.outer.tiled",
           /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(),
           [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); });

  if (row_remainder != 0) {
    EmitOuterLoopBody(b_->getInt64(row_limit), row_remainder);
  }
}

void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
    MemoryTile* lhs_memory_tile, int64_t rows,
    std::vector<VectorVariable>* vector_accumulators) {
  int64_t column_limit = k() - (k() % tile_cols());

  ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit,
           /*step=*/tile_cols(), [&](llvm::Value* col) {
             std::vector<llvm::Value*> lhs_tile =
                 lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col);
             llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col);
             for (int i = 0; i < rows; i++) {
               llvm::Value* old_sum = (*vector_accumulators)[i].Get();
               (*vector_accumulators)[i].Set(
                   vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i])));
             }
           });
}

void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
    llvm::Value* current_tile_row, int64_t rows,
    std::vector<ScalarVariable>* scalar_accumulators) {
  int64_t column_start = k() - (k() % tile_cols());
  if (column_start == k()) {
    return;
  }

  for (int r = 0; r < rows; r++) {
    llvm::Value* total_offset = b_->CreateMul(
        b_->CreateAdd(b_->getInt64(r), current_tile_row), b_->getInt64(k()));
    llvm::Value* lhs_base_pointer =
        vsl_.ComputeOffsetPointer(lhs_, total_offset);
    ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(),
             /*step=*/1, [&](llvm::Value* scalar_col) {
               llvm::Value* product =
                   vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col),
                            vsl_.LoadScalar(rhs_, scalar_col));
               llvm::Value* old_value = (*scalar_accumulators)[r].Get();
               (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product));
             });
  }
}

// This class implements a tiled matrix multiplication algorithm, intended for
// multiplying small matrices that don't need cache tiling.
//
// In the future this can be used as the innermost GEBP loop in a GEMM kernel
// as described in Goto, Kazushige, and Robert A. van de Geijn, "Anatomy of
// high-performance matrix multiplication," ACM Transactions on Mathematical
// Software (TOMS) 34.3 (2008): 12.
//
// This only supports canonical dot operations (i.e. where the lhs contraction
// dimension is 1 and the rhs contraction dimension is 0) over row major
// matrices.
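// (For example, such a canonical dot multiplies a row-major [m, k] lhs with a
// row-major [k, n] rhs to produce a row-major [m, n] result.)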
class TiledSmallGemmEmitter {
 public:
  // Describe the dimensions of the kernel.
  class Dimensions {
   public:
    explicit Dimensions(int64_t m, int64_t k, int64_t n)
        : m_(m), k_(k), n_(n) {}

    int64_t m() const { return m_; }
    int64_t k() const { return k_; }
    int64_t n() const { return n_; }

    std::string ToString() const {
      return absl::StrCat(m(), "x", k(), "x", n());
    }

   private:
    const int64_t m_;
    const int64_t k_;
    const int64_t n_;
  };

  // Represents the configuration of the emitter. The LLVM IR emitted by the
  // emitter, modulo the LLVM values holding the input and output buffers, must
  // be a function of the instance of `Config` passed to it.
  //
  // `dims` holds the matrix multiplication dimensions.
  //
  // `max_vectorization_width` is the maximum vector width (i.e. the width of
  // the largest vector register we will use). This can be larger than the
  // largest vector register supported by the machine -- LLVM will legalize
  // these large vector widths into legally sized vectors.
  //
  // `max_vector_count` is the maximum number of vectors of size
  // `max_vectorization_width` that we will attempt to process at once.
  //
  // `min_vectorization_width` is the smallest vector width the emitter will
  // use -- below that it will devolve to using a scalar loop.
  //
  // The innermost reduction loop executes the matrix multiply in tiles of size
  // [`tile_size_m`, `tile_size_k`] from the LHS and [`tile_size_k`,
  // <vectorization width>] in the RHS.
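  //
  // For example (values here are purely illustrative, not the defaults picked
  // by the CPU backend):
  //
  //   Config config(F32, Dimensions(/*m=*/64, /*k=*/64, /*n=*/64),
  //                 /*max_vectorization_width=*/8, /*max_vector_count=*/1,
  //                 /*min_vectorization_width=*/4, /*tile_size_m=*/3,
  //                 /*tile_size_k=*/5);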
  class Config {
   public:
    explicit Config(PrimitiveType scalar_type, Dimensions dims,
                    int64_t max_vectorization_width, int64_t max_vector_count,
                    int64_t min_vectorization_width, int64_t tile_size_m,
                    int64_t tile_size_k)
        : scalar_type_(scalar_type),
          dims_(dims),
          max_vectorization_width_(max_vectorization_width),
          max_vector_count_(max_vector_count),
          min_vectorization_width_(min_vectorization_width),
          tile_size_m_(tile_size_m),
          tile_size_k_(tile_size_k) {}

    std::string GetCacheKey() const {
      return absl::StrCat("gemm_", PrimitiveType_Name(scalar_type()), "_",
                          dims().ToString(), "_", max_vectorization_width(),
                          "_", min_vectorization_width(), "_", tile_size_m(),
                          "_", tile_size_k());
    }

    PrimitiveType scalar_type() const { return scalar_type_; }
    Dimensions dims() const { return dims_; }
    int64_t max_vectorization_width() const {
      return max_vectorization_width_;
    }
    int64_t max_vector_count() const { return max_vector_count_; }
    int64_t min_vectorization_width() const {
      return min_vectorization_width_;
    }

    int64_t tile_size_m() const { return tile_size_m_; }
    int64_t tile_size_k() const { return tile_size_k_; }

   private:
    PrimitiveType scalar_type_;
    Dimensions dims_;
    int64_t max_vectorization_width_;
    int64_t max_vector_count_;
    int64_t min_vectorization_width_;
    int64_t tile_size_m_;
    int64_t tile_size_k_;
  };

  // Creates an instance of TiledSmallGemmEmitter that matrix-multiplies
  // `lhs` with `rhs` and stores the result in `result`.
  explicit TiledSmallGemmEmitter(Config config, llvm::Value* lhs,
                                 llvm::Value* rhs, llvm::Value* result,
                                 llvm::IRBuilder<>* b)
      : lhs_(lhs),
        rhs_(rhs),
        result_(result),
        config_(config),
        b_(b),
        ksl_(b_) {
    CHECK(
        max_vectorization_width() > 0 &&
        absl::has_single_bit(static_cast<uint64_t>(max_vectorization_width())));
    CHECK_GT(max_vector_count(), 0);
    CHECK(
        min_vectorization_width() > 0 &&
        absl::has_single_bit(static_cast<uint64_t>(min_vectorization_width())));
    CHECK_GE(max_vectorization_width(), min_vectorization_width());
    CHECK_GT(tile_size_k(), 0);
  }

  void Emit();

 private:
  // The HandleResiduesOnX helpers split the iteration space for dimension X
  // into a multiple of the tile size on dimension X and an epilogue. These
  // helpers ultimately call into `EmitTiledGemm` for emitting the
  // tiled GEMM kernel.

  void HandleResiduesOnN();
  void HandleResiduesOnK(VectorSupportLibrary* vsl, llvm::Value* n_start,
                         llvm::Value* n_end);
  void HandleResiduesOnM(VectorSupportLibrary* vsl, int64_t tile_size_k,
                         llvm::Value* k_start, llvm::Value* k_end,
                         llvm::Value* n_start, llvm::Value* n_end);

  // This emits a tiled GEMM kernel. For a detailed description see the comment
  // on the implementation.
  void EmitTiledGemm(VectorSupportLibrary* vsl, int64_t tile_size_k,
                     llvm::Value* k_start, llvm::Value* k_end,
                     llvm::Value* n_start, llvm::Value* n_end,
                     int64_t tile_size_m, llvm::Value* m_start,
                     llvm::Value* m_end);

  llvm::Value* GetInt64(int64_t value) { return b_->getInt64(value); }

  Config config() const { return config_; }
  Dimensions dims() const { return config().dims(); }

  int64_t max_vectorization_width() const {
    return config().max_vectorization_width();
  }
  int64_t max_vector_count() const { return config().max_vector_count(); }
  int64_t min_vectorization_width() const {
    return config().min_vectorization_width();
  }
  int64_t tile_size_m() const { return config().tile_size_m(); }
  int64_t tile_size_k() const { return config().tile_size_k(); }
  PrimitiveType scalar_type() const { return config().scalar_type(); }

  llvm::Value* lhs_;
  llvm::Value* rhs_;
  llvm::Value* result_;
  Config config_;

  llvm::IRBuilder<>* b_;
  KernelSupportLibrary ksl_;
};

void TiledSmallGemmEmitter::Emit() { HandleResiduesOnN(); }

void TiledSmallGemmEmitter::HandleResiduesOnN() {
  // We can only iterate the `n` dimension for an extent that is divisible by
  // the vectorization width. So we emit an outer loop that first processes the
  // largest extent in `n` that is divisible by max_vectorization_width, then
  // the largest remaining extent that is divisible by max_vectorization_width /
  // 2 etc.
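  //
  // For example (illustrative numbers only): with n = 19,
  // max_vectorization_width = 8, max_vector_count = 1 and
  // min_vectorization_width = 4, the tiled loop below covers n in [0, 16)
  // with width-8 vectors, and the scalar "epi.n" loop handles the remaining
  // n in [16, 19).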

  int64_t current_vectorization_width =
      max_vector_count() * max_vectorization_width();
  int64_t current_vector_count = max_vector_count();

  int64_t n_start = 0;
  while (n_start != dims().n() &&
         current_vectorization_width >= min_vectorization_width()) {
    int64_t n_end = dims().n() - (dims().n() % current_vectorization_width);
    if (n_start != n_end) {
      VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, b_,
                               "gemm");
      HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end));
      n_start = n_end;
    }
    if (current_vector_count == 1) {
      current_vectorization_width /= 2;
    } else {
      current_vector_count--;
      current_vectorization_width =
          current_vector_count * max_vectorization_width();
    }
  }

  if (n_start != dims().n()) {
    VectorSupportLibrary vsl(scalar_type(), 1, b_, "gemm");
    ksl_.For("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) {
      llvm::Value* n_i_next = b_->CreateAdd(n_i, b_->getInt64(1));
      HandleResiduesOnK(&vsl, n_i, n_i_next);
    });
  }
}

void TiledSmallGemmEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl,
                                              llvm::Value* n_start,
                                              llvm::Value* n_end) {
  int64_t k_start = 0;
  int64_t k_end = dims().k() - (dims().k() % tile_size_k());
  if (k_end != k_start) {
    HandleResiduesOnM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end),
                      n_start, n_end);
    k_start = k_end;
  }

  if (k_start != dims().k()) {
    HandleResiduesOnM(vsl, dims().k() - k_start, GetInt64(k_start),
                      GetInt64(dims().k()), n_start, n_end);
  }
}

void TiledSmallGemmEmitter::HandleResiduesOnM(
    VectorSupportLibrary* vsl, int64_t tile_size_k, llvm::Value* k_start,
    llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) {
  const int64_t m_end = dims().m() - dims().m() % tile_size_m();
  EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, tile_size_m(),
                GetInt64(0), GetInt64(m_end));

  if (m_end != dims().m()) {
    EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end,
                  dims().m() - m_end, GetInt64(m_end), GetInt64(dims().m()));
  }
}

// The loop structure is:
//
//   Iterate over dimension M as m:
//     Iterate over dimension N as n:
//       Iterate over dimension K as k:
//         OutputTile[m,n] += Dot(LhsTile[m,k], RhsTile[k,n])
//
// I.e. just a tiled version of a "naive" GEMM.
//
// The tiling scheme is as follows:
//
// Let the LHS be:
//
//   +----+----+----+
//   | a0 | b0 | c0 | .
//   +----+----+----+ .
//   | a1 | b1 | c1 | .
//   +----+----+----+
//     ..     ..
//
// and the RHS be:
//
//   +----+----+----+----+
//   | p0 | p1 | p2 | p3 | .
//   +----+----+----+----+ .
//   | q0 | q1 | q2 | q3 | .
//   +----+----+----+----+
//   | r0 | r1 | r2 | r3 | .
//   +----+----+----+----+ .
//   ......    ......
//
// and let tile_size_m=2, tile_size_k=3 and the vector width (implicitly
// denoted by `vsl`) be 4. Then we want to matrix multiply these tiles to get a
// [2,4] tile that we can increment the result matrix by.
//
// First broadcast each scalar in the LHS rows to a vector of width 4, giving
// us a rank 3 array, L, of dimension [2,3,4]:
//
//   L[0,_,_]              *  L[1,_,_]
//                         *
//   +----+----+----+----+ * +----+----+----+----+
//   | a0 | a0 | a0 | a0 | * | a1 | a1 | a1 | a1 |
//   +----+----+----+----+ * +----+----+----+----+
//   | b0 | b0 | b0 | b0 | * | b1 | b1 | b1 | b1 |
//   +----+----+----+----+ * +----+----+----+----+
//   | c0 | c0 | c0 | c0 | * | c1 | c1 | c1 | c1 |
//   +----+----+----+----+ * +----+----+----+----+
//
//
// Then we FMA L[0,_,_] with the RHS to get the first row of the result and
// L[1,_,_] with the RHS to get the second row of the result. For example, the
// first row of the result is computed as:
//
//   +----+----+----+----+   +----+----+----+----+
//   | a0 | a0 | a0 | a0 | * | p0 | p1 | p2 | p3 |   +
//   +----+----+----+----+   +----+----+----+----+
//
//   +----+----+----+----+   +----+----+----+----+
//   | b0 | b0 | b0 | b0 | * | q0 | q1 | q2 | q3 |   +
//   +----+----+----+----+   +----+----+----+----+
//
//   +----+----+----+----+   +----+----+----+----+
//   | c0 | c0 | c0 | c0 | * | r0 | r1 | r2 | r3 |
//   +----+----+----+----+   +----+----+----+----+
//
// to get:
//
//   +-------------------+-------------------+-------------------+---------
//   | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 | ...
//   +-------------------+-------------------+-------------------+---------
void TiledSmallGemmEmitter::EmitTiledGemm(
    VectorSupportLibrary* vsl, int64_t tile_size_k, llvm::Value* k_start,
    llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end,
    int64_t tile_size_m, llvm::Value* m_start, llvm::Value* m_end) {
  ksl_.For("dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) {
    MemoryTile result_memory_tile(vsl, b_, /*matrix=*/result_,
                                  /*matrix_size_along_minor_dim=*/dims().n(),
                                  /*major_dim_offset=*/m_i,
                                  /*tile_size_along_major_dim=*/tile_size_m);
    MemoryTile lhs_memory_tile(vsl, b_, /*matrix=*/lhs_,
                               /*matrix_size_along_minor_dim=*/dims().k(),
                               /*major_dim_offset=*/m_i,
                               /*tile_size_along_major_dim=*/tile_size_m);
    ksl_.For(
        "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) {
          TileVariable result_tile_var(vsl, result_memory_tile.LoadTile(n_i));
          ksl_.For("dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) {
            MemoryTile rhs_memory_tile(vsl, b_, rhs_, dims().n(), k_i,
                                       tile_size_k);
            std::vector<std::vector<llvm::Value*>> lhs_tile =
                lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k);
            std::vector<llvm::Value*> rhs_tile = rhs_memory_tile.LoadTile(n_i);
            std::vector<llvm::Value*> result_tile = result_tile_var.Get();
            for (int64_t r_m_i = 0; r_m_i < tile_size_m; r_m_i++) {
              for (int64_t r_k_i = 0; r_k_i < tile_size_k; r_k_i++) {
                result_tile[r_m_i] =
                    vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i],
                                result_tile[r_m_i]);
              }
            }
            result_tile_var.Set(result_tile);
          });

          result_memory_tile.StoreTile(result_tile_var.Get(), n_i);
        });
  });
}

llvm::Type* GetPointerToElementType(llvm::Type* pointer_type) {
  if (pointer_type->isOpaquePointerTy()) return pointer_type;

  llvm::Type* type = pointer_type->getNonOpaquePointerElementType();
  while (auto* array_type = llvm::dyn_cast<llvm::ArrayType>(type)) {
    type = array_type->getElementType();
  }

  return type->getPointerTo();
}

struct GemvBuffersWithCanonicalType {
  llvm::Value* lhs_canonicalized;
  llvm::Value* rhs_canonicalized;
  llvm::Value* addend_canonicalized;
  llvm::Value* result_canonicalized;
};

GemvBuffersWithCanonicalType GetGemvBuffersWithCanonicalType(
    llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend,
    llvm::Value* result, llvm::IRBuilder<>* b) {
  // We characterize a GEMV operation via M and K, since N is implicitly 1.
  // This means the GEMV that multiplies (say) [5,6] with [6,1] is implemented
  // by the same GEMV that multiplies [5,6] with [1,6]. However, the
  // `llvm::Types` for the inputs to the two GEMVs don't match (in a trivial
  // sense -- the in-memory representations are the same) since they're
  // computed from the `xla::Shape`s. Since we want to be able to call the
  // same `llvm::Function` for the two GEMVs we canonicalize the types of the
  // GEMV inputs here into the same type.
  GemvBuffersWithCanonicalType buffers_with_canonical_type;
  llvm::Type* lhs_type = lhs->getType();
  llvm::Type* rhs_type = rhs->getType();
  llvm::Type* addend_type = addend ? addend->getType() : nullptr;
  llvm::Type* result_type = result->getType();

  buffers_with_canonical_type.lhs_canonicalized =
      b->CreateBitCast(lhs, GetPointerToElementType(lhs_type));
  buffers_with_canonical_type.rhs_canonicalized =
      b->CreateBitCast(rhs, GetPointerToElementType(rhs_type));
  buffers_with_canonical_type.addend_canonicalized =
      addend ? b->CreateBitCast(addend, GetPointerToElementType(addend_type))
             : nullptr;
  buffers_with_canonical_type.result_canonicalized =
      b->CreateBitCast(result, GetPointerToElementType(result_type));

  return buffers_with_canonical_type;
}

}  // namespace

void EmitRowMajorGemv(PrimitiveType scalar_type, int64_t tile_rows,
                      int64_t tile_cols, int64_t m, int64_t k, llvm::Value* lhs,
                      llvm::Value* rhs, llvm::Value* addend,
                      llvm::Value* result, llvm::IRBuilder<>* b,
                      const HloModuleConfig& module_config) {
  RowMajorMatrixVectorProductEmitter::Config config(
      /*scalar_type=*/scalar_type,
      /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols,
      /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr);

  GemvBuffersWithCanonicalType canonical_inputs =
      GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b);

  KernelSupportLibrary::EmitAndCallOutlinedKernel(
      module_config, b, config.GetCacheKey(),
      canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized,
      canonical_inputs.addend_canonicalized,
      canonical_inputs.result_canonicalized,
      [&config, b, &canonical_inputs](llvm::Value* lhs, llvm::Value* rhs,
                                      llvm::Value* addend,
                                      llvm::Value* result) {
        RowMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend,
                                                   result, b);
        emitter.Emit();
      });
}

void EmitColumnMajorGemv(PrimitiveType scalar_type, int64_t tile_rows,
                         int64_t tile_cols, int64_t m, int64_t k,
                         llvm::Value* lhs, llvm::Value* rhs,
                         llvm::Value* addend, llvm::Value* result,
                         llvm::IRBuilder<>* b,
                         const HloModuleConfig& module_config) {
  ColumnMajorMatrixVectorProductEmitter::Config config(
      /*scalar_type=*/scalar_type,
      /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols,
      /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr);

  GemvBuffersWithCanonicalType canonical_inputs =
      GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b);

  KernelSupportLibrary::EmitAndCallOutlinedKernel(
      module_config, b, config.GetCacheKey(),
      canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized,
      canonical_inputs.addend_canonicalized,
      canonical_inputs.result_canonicalized,
      [&config, b, &canonical_inputs](llvm::Value* lhs, llvm::Value* rhs,
                                      llvm::Value* addend,
                                      llvm::Value* result) {
        ColumnMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend,
                                                      result, b);
        emitter.Emit();
      });
}

void EmitSmallGemm(PrimitiveType scalar_type, int64_t m, int64_t k, int64_t n,
                   int64_t max_vectorization_width, int64_t max_vector_count,
                   int64_t min_vectorization_width, int64_t tile_size_m,
                   int64_t tile_size_k, llvm::Value* lhs, llvm::Value* rhs,
                   llvm::Value* result, llvm::IRBuilder<>* b,
                   const HloModuleConfig& module_config) {
  TiledSmallGemmEmitter::Config config(
      /*scalar_type=*/scalar_type,
      TiledSmallGemmEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n},
      /*max_vectorization_width=*/max_vectorization_width,
      /*max_vector_count=*/max_vector_count,
      /*min_vectorization_width=*/min_vectorization_width,
      /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k);

  KernelSupportLibrary::EmitAndCallOutlinedKernel(
      module_config, b, config.GetCacheKey(), lhs, rhs, result,
      [&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result) {
        TiledSmallGemmEmitter small_gemm_emitter(config, /*lhs=*/lhs,
                                                 /*rhs=*/rhs,
                                                 /*result=*/result, b);
        small_gemm_emitter.Emit();
      });
}

}  // namespace cpu
}  // namespace xla