• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h"
17 
18 #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
19 #include "tensorflow/compiler/xla/service/hlo_module.h"
20 #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
21 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
22 
23 namespace xla {
24 namespace cpu {
25 namespace {
26 
27 using tensorflow::int64;
28 
29 // Provides tiled access to an in-memory rank 2 array.
30 class MemoryTile {
31  public:
32   // Constructs a MemoryTile that can operate on tiles consisting of
33   // `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at
34   // `major_dim_offset` in the major dimension.  The tile size along the minor
35   // dimension is the vector size, and that is implicitly determined by `vsl`.
MemoryTile(VectorSupportLibrary * vsl,llvm::IRBuilder<> * b,llvm::Value * matrix,int64 matrix_size_along_minor_dim,llvm::Value * major_dim_offset,int64 tile_size_along_major_dim)36   MemoryTile(VectorSupportLibrary* vsl, llvm::IRBuilder<>* b,
37              llvm::Value* matrix, int64 matrix_size_along_minor_dim,
38              llvm::Value* major_dim_offset, int64 tile_size_along_major_dim)
39       : vsl_(vsl), b_(b) {
40     pointers_.reserve(tile_size_along_major_dim);
41     for (int64 i = 0; i < tile_size_along_major_dim; i++) {
42       llvm::Value* total_offset =
43           b->CreateMul(b->getInt64(matrix_size_along_minor_dim),
44                        b->CreateAdd(b->getInt64(i), major_dim_offset));
45       pointers_.push_back(vsl_->ComputeOffsetPointer(matrix, total_offset));
46     }
47   }
48 
49   // Load a tile consisting of `tile_size_along_major_dim` vectors from position
50   // {major: `major_dim_offset`, minor: `minor_dim_offset`}.
51   //
52   // Note: `major_dim_offset` is a parameter to the constructor.
LoadTile(llvm::Value * minor_dim_offset) const53   std::vector<llvm::Value*> LoadTile(llvm::Value* minor_dim_offset) const {
54     std::vector<llvm::Value*> result;
55     result.reserve(pointers_.size());
56     for (const auto& pointer : pointers_) {
57       result.push_back(vsl_->LoadVector(pointer, minor_dim_offset));
58     }
59     return result;
60   }
61 
62   // Stores `tile` to position {major: `major_dim_offset`, minor:
63   // `minor_dim_offset`}.
64   //
65   // Note: `major_dim_offset` is a parameter to the constructor.
StoreTile(absl::Span<llvm::Value * const> tile,llvm::Value * minor_dim_offset) const66   void StoreTile(absl::Span<llvm::Value* const> tile,
67                  llvm::Value* minor_dim_offset) const {
68     CHECK_EQ(tile.size(), pointers_.size());
69     for (int64 i = 0; i < pointers_.size(); i++) {
70       vsl_->StoreVector(tile[i], pointers_[i], minor_dim_offset);
71     }
72   }
73 
74   // Loads a tile of size [`tile_size_along_major_dim`,
75   // `tile_size_along_middle_dim`] from position {major: `major_dim_offset`,
76   // minor: `minor_dim_offset`} and then broadcasts each element into a vector
77   // of size vsl_.vector_size().  The (i,j)'th element of the return value is
78   // the (i,j)'th element in the tile broadcasted into an LLVM vector.
79   //
80   // Note: `major_dim_offset` is a parameter to the constructor.
LoadBroadcastTile(llvm::Value * minor_dim_offset,int64 tile_size_along_middle_dim) const81   std::vector<std::vector<llvm::Value*>> LoadBroadcastTile(
82       llvm::Value* minor_dim_offset, int64 tile_size_along_middle_dim) const {
83     std::vector<std::vector<llvm::Value*>> result;
84     result.resize(pointers_.size());
85     for (int64 i = 0; i < pointers_.size(); i++) {
86       for (int64 j = 0; j < tile_size_along_middle_dim; j++) {
87         result[i].push_back(vsl_->LoadBroadcast(
88             pointers_[i], b_->CreateAdd(minor_dim_offset, b_->getInt64(j))));
89       }
90     }
91     return result;
92   }
93 
94  private:
95   VectorSupportLibrary* vsl_;
96   llvm::IRBuilder<>* b_;
97   std::vector<llvm::Value*> pointers_;
98 };
99 
100 // The base class for the classes representing the GEMV emitter configurations.
101 //
102 // The IR emitted (modulo the LLVM values representing the input and output
103 // buffers) by the row major and column major GEMV emitters should be a function
104 // of their configuration.  This is important because their configuration is
105 // used as a key to cache the generated IR.
106 class GemvConfig {
107  public:
108   // Mixin for convenience.
109   template <typename T>
110   struct User {
111    public:
scalar_typexla::cpu::__anon016ac5180111::GemvConfig::User112     PrimitiveType scalar_type() const {
113       return derived().config().scalar_type();
114     }
tile_rowsxla::cpu::__anon016ac5180111::GemvConfig::User115     int64 tile_rows() const { return derived().config().tile_rows(); }
tile_colsxla::cpu::__anon016ac5180111::GemvConfig::User116     int64 tile_cols() const { return derived().config().tile_cols(); }
mxla::cpu::__anon016ac5180111::GemvConfig::User117     int64 m() const { return derived().config().m(); }
kxla::cpu::__anon016ac5180111::GemvConfig::User118     int64 k() const { return derived().config().k(); }
has_addendxla::cpu::__anon016ac5180111::GemvConfig::User119     int64 has_addend() const { return derived().config().has_addend(); }
120 
121    private:
derivedxla::cpu::__anon016ac5180111::GemvConfig::User122     const T& derived() const { return *static_cast<const T*>(this); }
123   };
124 
scalar_type() const125   PrimitiveType scalar_type() const { return scalar_type_; }
tile_rows() const126   int64 tile_rows() const { return tile_rows_; }
tile_cols() const127   int64 tile_cols() const { return tile_cols_; }
m() const128   int64 m() const { return m_; }
k() const129   int64 k() const { return k_; }
has_addend() const130   bool has_addend() const { return has_addend_; }
131 
GetCacheKey() const132   string GetCacheKey() const {
133     return absl::StrCat(name_, "_", PrimitiveType_Name(scalar_type()), "_",
134                         tile_rows(), "_", tile_cols(), "_", m(), "_", k(),
135                         has_addend() ? "_with_addend" : "");
136   }
137 
138  protected:
GemvConfig(string name,PrimitiveType scalar_type,int64 tile_rows,int64 tile_cols,int64 m,int64 k,bool has_addend)139   explicit GemvConfig(string name, PrimitiveType scalar_type, int64 tile_rows,
140                       int64 tile_cols, int64 m, int64 k, bool has_addend)
141       : name_(std::move(name)),
142         scalar_type_(scalar_type),
143         tile_rows_(tile_rows),
144         tile_cols_(tile_cols),
145         m_(m),
146         k_(k),
147         has_addend_(has_addend) {}
148 
149  private:
150   string name_;
151   PrimitiveType scalar_type_;
152   int64 tile_rows_;
153   int64 tile_cols_;
154   int64 m_;
155   int64 k_;
156   bool has_addend_;
157 };
158 
159 // Computes a dot product between "[M,K]{0,1} lhs" with a [K,1] vector (the
160 // layout of the vector does not matter).  This implementation uses a tiling
161 // scheme to improve performance.
162 //
163 // We logically separate the LHS matrix into four segments:
164 //
165 //   +----------------------+---+
166 //   |                      |   |
167 //   |                      |   |
168 //   |         A            | B |
169 //   |                      |   |
170 //   |                      |   |
171 //   |                      |   |
172 //   +----------------------+---+
173 //   |         C            | D |
174 //   +----------------------+---+
175 //
176 // where A is the largest submatrix of the LHS that can be evenly dividied into
177 // tiles.  For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have:
178 //
179 //   +---+---+---+---+       +--+--+--+--+
180 //   |M00|M10|M20|M30|       |V0|V1|V2|V3|
181 //   +---+---+---+---+       +--+--+--+--+
182 //   |M01|M11|M21|M31| and   |V0|V1|V2|V3|
183 //   +---+---+---+---+       +--+--+--+--+
184 //   |M02|M12|M22|M32|       |V0|V1|V2|V3|
185 //   +---+---+---+---+       +--+--+--+--+
186 //   |M03|M13|M23|M33|       |V0|V1|V2|V3|
187 //   +---+---+---+---+       +--+--+--+--+
188 //
189 // (Legend: rows are horizontal and columns are vertical; and each column is one
190 // llvm::Value of a vector type)
191 //
192 // where:
193 //
194 //   a. The left tile is from the column major left matrix.
195 //   b. The right tile is an elementwise broadcast of a [V0, V1, V2, V3]
196 //      vector loaded from the RHS vector.
197 //
198 // As we iterate through the column dimension, we compute the change to the
199 // result vector by an elementwise multiplication between the two tiles above
200 // followed by a reduction along the major dimension:
201 //
202 //                     +-----------------------------------+
203 //                     | M00*V0 + M10*V1 + M20*V2 + M30*V3 |
204 //                     +-----------------------------------+
205 //                     | M01*V0 + M11*V1 + M21*V2 + M31*V3 |
206 // Result[R:R+4] +=    +-----------------------------------+
207 //                     | M02*V0 + M12*V1 + M22*V2 + M32*V3 |
208 //                     +-----------------------------------+
209 //                     | M03*V0 + M13*V1 + M23*V2 + M33*V3 |
210 //                     +-----------------------------------+
211 //
212 // Where R is the starting row for the tile.
213 //
214 // We have an inner epilogue loop to deal with the "C" submatrix and an outer
215 // epilogue loop to deal with the B,D submarix.
216 //
217 // TODO(sanjoy): We should investigate if using gather loads and scatter stores
218 // can be used here have the same inner loop for both column-major and row-major
219 // matrix-vector products.
220 class ColumnMajorMatrixVectorProductEmitter
221     : public GemvConfig::User<ColumnMajorMatrixVectorProductEmitter> {
222  public:
223   class Config : public GemvConfig {
224    public:
Config(PrimitiveType scalar_type,int64 tile_rows,int64 tile_cols,int64 m,int64 k,bool has_addend)225     explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols,
226                     int64 m, int64 k, bool has_addend)
227         : GemvConfig(/*name=*/"col_major_gemv", scalar_type,
228                      /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m,
229                      /*k=*/k, /*has_addend=*/has_addend) {}
230   };
231 
ColumnMajorMatrixVectorProductEmitter(const Config & config,llvm::Value * lhs,llvm::Value * rhs,llvm::Value * addend,llvm::Value * result,llvm::IRBuilder<> * b)232   ColumnMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs,
233                                         llvm::Value* rhs, llvm::Value* addend,
234                                         llvm::Value* result,
235                                         llvm::IRBuilder<>* b)
236       : config_(config),
237         lhs_(lhs),
238         rhs_(rhs),
239         addend_(addend),
240         result_(result),
241         b_(b),
242         ksl_(b_),
243         vsl_(config.scalar_type(), /*vector_size=*/config.tile_rows(), b_, "") {
244     CHECK(tile_rows() > 0 && IsPowerOfTwo(static_cast<uint64>(tile_rows())));
245     CHECK(!has_addend() || addend != nullptr);
246   }
247 
248   void Emit();
249 
config() const250   const Config& config() const { return config_; }
251 
252  private:
253   void EmitOuterLoopBody(llvm::Value* column, int64 column_count,
254                          bool is_first_column);
255 
GetLhsMemoryTile(llvm::Value * column_start,int64 column_count)256   MemoryTile GetLhsMemoryTile(llvm::Value* column_start, int64 column_count) {
257     return MemoryTile(&vsl_, b_, /*matrix=*/lhs_,
258                       /*matrix_size_along_minor_dim=*/m(),
259                       /*major_dim_offset=*/column_start,
260                       /*tile_size_along_major_dim=*/column_count);
261   }
262 
263   // Load a tile of values from the RHS.  For the RHS a "tile" is a contiguous
264   // sequence of `count` values, each one broadcasted to the vector width.
LoadRhsTile(llvm::Value * offset,int64 count)265   std::vector<llvm::Value*> LoadRhsTile(llvm::Value* offset, int64 count) {
266     llvm::Value* base_pointer = vsl_.ComputeOffsetPointer(rhs_, offset);
267     std::vector<llvm::Value*> result;
268     result.reserve(count);
269     for (int64 i = 0; i < count; i++) {
270       result.push_back(vsl_.LoadBroadcast(base_pointer, i));
271     }
272     return result;
273   }
274 
275   void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile,
276                           const std::vector<llvm::Value*>& rhs_tile,
277                           int64 columns, bool is_first_column);
278 
279   void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64 columns,
280                              bool is_first_tiled_column);
281 
282   Config config_;
283   llvm::Value* lhs_;
284   llvm::Value* rhs_;
285   llvm::Value* addend_;
286   llvm::Value* result_;
287   llvm::IRBuilder<>* b_;
288   KernelSupportLibrary ksl_;
289   VectorSupportLibrary vsl_;
290 };
291 
EmitOuterLoopBody(llvm::Value * column,int64 column_count,bool is_first_column)292 void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody(
293     llvm::Value* column, int64 column_count, bool is_first_column) {
294   MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*column_start=*/column,
295                                                 /*column_count=*/column_count);
296 
297   std::vector<llvm::Value*> rhs_tile =
298       LoadRhsTile(column, /*count=*/column_count);
299   EmitInnerLoopTiled(&lhs_memory_tile, rhs_tile,
300                      /*columns=*/column_count, is_first_column);
301   EmitInnerLoopEpilogue(column, /*columns=*/column_count, is_first_column);
302 }
303 
Emit()304 void ColumnMajorMatrixVectorProductEmitter::Emit() {
305   // See the comment on the class declaration for the algorithm used here.
306   int64 column_remainder = k() % tile_cols();
307   int64 column_limit = k() - column_remainder;
308 
309   ksl_.For("dot.outer.tiled",
310            /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(),
311            [&](llvm::Value* column, bool is_first_column) {
312              EmitOuterLoopBody(column, tile_cols(), is_first_column);
313            });
314 
315   if (column_remainder != 0) {
316     EmitOuterLoopBody(b_->getInt64(column_limit), column_remainder,
317                       column_limit == 0);
318   }
319 }
320 
EmitInnerLoopTiled(MemoryTile * lhs_memory_tile,const std::vector<llvm::Value * > & rhs_tile,int64 columns,bool is_first_column)321 void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
322     MemoryTile* lhs_memory_tile, const std::vector<llvm::Value*>& rhs_tile,
323     int64 columns, bool is_first_column) {
324   int64 row_limit = m() - (m() % tile_rows());
325 
326   ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit,
327            /*step=*/tile_rows(), [&](llvm::Value* row) {
328              std::vector<llvm::Value*> lhs_tile =
329                  lhs_memory_tile->LoadTile(/*minor_dim_offset=*/row);
330              llvm::Value* accumulator =
331                  is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row)
332                                             : vsl_.GetZeroVector())
333                                  : vsl_.LoadVector(result_, row);
334              for (int i = 0; i < columns; i++) {
335                accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator);
336              }
337              vsl_.StoreVector(accumulator, result_, row);
338            });
339 }
340 
EmitInnerLoopEpilogue(llvm::Value * current_tile_col,int64 columns,bool is_first_tiled_column)341 void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
342     llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column) {
343   int64 row_start = m() - (m() % tile_rows());
344   if (row_start == m()) {
345     return;
346   }
347 
348   llvm::Value* columns_llvm = b_->getInt64(columns);
349 
350   // for (col = current_tile_col; col < (columns + current_tile_col); col++)
351   //   for (row = row_start, row < m_; row++) {
352   //     result[row] += lhs[row, col] * rhs[col]
353   //     // Also take into account that if col is 0 then result[row] is not
354   //     // initialized.
355   //   }
356 
357   ksl_.For(
358       "dot.inner.epilg.outer", /*start=*/current_tile_col,
359       /*end=*/b_->CreateAdd(columns_llvm, current_tile_col),
360       /*step=*/1, /*peel_first_iteration=*/false,
361       [&](llvm::Value* col, llvm::Value* is_first_scalar_col) {
362         llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col);
363         llvm::Value* total_offset = b_->CreateMul(col, b_->getInt64(m()));
364         llvm::Value* lhs_base_pointer =
365             vsl_.ComputeOffsetPointer(lhs_, total_offset);
366         ksl_.For(
367             "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m(),
368             /*step=*/1, [&](llvm::Value* scalar_row) {
369               llvm::Value* product = vsl_.Mul(
370                   vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element);
371               llvm::Value* setting_result_first_time = b_->CreateAnd(
372                   is_first_scalar_col, b_->getInt1(is_first_tiled_column));
373               ksl_.If(
374                   setting_result_first_time,
375                   /*true_block_generator=*/
376                   [&]() {
377                     if (addend_) {
378                       vsl_.StoreScalar(
379                           vsl_.Add(vsl_.LoadScalar(addend_, scalar_row),
380                                    product),
381                           result_, scalar_row);
382                     } else {
383                       vsl_.StoreScalar(product, result_, scalar_row);
384                     }
385                   },
386                   /*false_block_generator=*/
387                   [&]() {
388                     vsl_.StoreScalar(
389                         vsl_.Add(vsl_.LoadScalar(result_, scalar_row), product),
390                         result_, scalar_row);
391                   });
392             });
393       });
394 }
395 
396 // Computes a dot product between "[M,K]{1,0} lhs" with a [K,1] vector (the
397 // layout of the vector does not matter).  This implementation uses a tiling
398 // scheme to improve performance.
399 //
400 // We logically separate the LHS matrix into four segments:
401 //
402 //   +----------------------+---+
403 //   |                      |   |
404 //   |                      |   |
405 //   |         A            | B |
406 //   |                      |   |
407 //   |                      |   |
408 //   |                      |   |
409 //   +----------------------+---+
410 //   |         C            | D |
411 //   +----------------------+---+
412 //
413 // where A is the largest submatrix of the LHS that can be evenly dividied into
414 // tiles.  For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have:
415 //
416 //   +---+---+---+---+
417 //   |M00|M10|M20|M30|
418 //   +---+---+---+---+       +--+--+--+--+
419 //   |M01|M11|M21|M31| and   |V0|V1|V2|V3|
420 //   +---+---+---+---+       +--+--+--+--+
421 //   |M02|M12|M22|M32|
422 //   +---+---+---+---+
423 //   |M03|M13|M23|M33|
424 //   +---+---+---+---+
425 //
426 // (Legend: rows are horizontal and columns are vertical; and each row is one
427 // llvm::Value of a vector type)
428 //
429 // where:
430 //
431 //   a. The left tile is loaded from the row major left matrix.
432 //   b. The right vector is loaded from the RHS vector.
433 //
434 // We keep 4 vector accumulators accumulating the following four vector
435 // expressions as we iterate over the row dimension:
436 //
437 //   +------+------+------+------+
438 //   |M0I*V0|M1I*V1|M2I*V2|M3I*V3|  for I in [0,4)
439 //   +------+------+------+------+
440 //
441 // In the end we do a horizontal reduction over these 4 vector accumulators to
442 // get 4 values in the result vector.
443 //
444 // We have an inner epilogue loop to deal with the "B" sub-matrix and an outer
445 // epilogue loop to deal with the C,D submatrix.
446 class RowMajorMatrixVectorProductEmitter
447     : public GemvConfig::User<RowMajorMatrixVectorProductEmitter> {
448  public:
449   class Config : public GemvConfig {
450    public:
Config(PrimitiveType scalar_type,int64 tile_rows,int64 tile_cols,int64 m,int64 k,bool has_addend)451     explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols,
452                     int64 m, int64 k, bool has_addend)
453         : GemvConfig(/*name=*/"row_major_gemv", scalar_type,
454                      /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m,
455                      /*k=*/k, /*has_addend=*/has_addend) {}
456   };
457 
RowMajorMatrixVectorProductEmitter(const Config & config,llvm::Value * lhs,llvm::Value * rhs,llvm::Value * addend,llvm::Value * result,llvm::IRBuilder<> * b)458   RowMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs,
459                                      llvm::Value* rhs, llvm::Value* addend,
460                                      llvm::Value* result, llvm::IRBuilder<>* b)
461       : config_(config),
462         lhs_(lhs),
463         rhs_(rhs),
464         addend_(addend),
465         result_(result),
466         b_(b),
467         ksl_(b_),
468         vsl_(scalar_type(), /*vector_size=*/tile_cols(), b_, "") {
469     CHECK(tile_cols() > 0 && IsPowerOfTwo(static_cast<uint64>(tile_cols())));
470     CHECK(!has_addend() || addend != nullptr);
471   }
472 
473   void Emit();
474 
config() const475   const Config& config() const { return config_; }
476 
477  private:
GetLhsMemoryTile(llvm::Value * row_start,int64 row_count)478   MemoryTile GetLhsMemoryTile(llvm::Value* row_start, int64 row_count) {
479     return MemoryTile(&vsl_, b_, /*matrix=*/lhs_,
480                       /*matrix_size_along_minor_dim=*/k(),
481                       /*major_dim_offset=*/row_start,
482                       /*tile_size_along_major_dim=*/row_count);
483   }
484 
485   void EmitOuterLoopBody(llvm::Value* row, int64 row_count);
486 
487   void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, int64 rows,
488                           std::vector<VectorVariable>* vector_accumulators);
489 
490   void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64 rows,
491                              std::vector<ScalarVariable>* scalar_accumulators);
492 
493   Config config_;
494   llvm::Value* lhs_;
495   llvm::Value* rhs_;
496   llvm::Value* addend_;
497   llvm::Value* result_;
498   llvm::IRBuilder<>* b_;
499   KernelSupportLibrary ksl_;
500   VectorSupportLibrary vsl_;
501 };
502 
EmitOuterLoopBody(llvm::Value * row,int64 row_count)503 void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row,
504                                                            int64 row_count) {
505   MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*row_start=*/row,
506                                                 /*row_count=*/row_count);
507   std::vector<VectorVariable> vector_accumulators;
508   std::vector<ScalarVariable> scalar_accumulators;
509   for (int i = 0; i < row_count; i++) {
510     vector_accumulators.emplace_back(&vsl_, vsl_.GetZeroVector());
511     scalar_accumulators.emplace_back(&vsl_, vsl_.GetZeroScalar());
512   }
513   EmitInnerLoopTiled(&lhs_memory_tile, /*rows=*/row_count,
514                      &vector_accumulators);
515   EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count,
516                         &scalar_accumulators);
517 
518   std::vector<llvm::Value*> accumulator_values;
519   std::transform(
520       vector_accumulators.begin(), vector_accumulators.end(),
521       std::back_inserter(accumulator_values),
522       [](const VectorVariable& vector_var) { return vector_var.Get(); });
523 
524   std::vector<llvm::Value*> horizontal_sums;
525   if (row_count == vsl_.vector_size()) {
526     if (addend_) {
527       horizontal_sums = vsl_.ComputeHorizontalSums(
528           std::move(accumulator_values), vsl_.LoadVector(addend_, row));
529     } else {
530       horizontal_sums =
531           vsl_.ComputeHorizontalSums(std::move(accumulator_values));
532     }
533   } else {
534     horizontal_sums = vsl_.ComputeHorizontalSums(std::move(accumulator_values));
535   }
536 
537   for (int i = 0; i < row_count; i++) {
538     llvm::Value* result_value =
539         vsl_.Add(horizontal_sums[i], scalar_accumulators[i].Get());
540     llvm::Value* offset = b_->CreateAdd(b_->getInt64(i), row);
541     if (addend_ && row_count != vsl_.vector_size()) {
542       result_value = vsl_.Add(vsl_.LoadScalar(addend_, offset), result_value);
543     }
544     vsl_.StoreScalar(result_value, result_, offset);
545   }
546 }
547 
Emit()548 void RowMajorMatrixVectorProductEmitter::Emit() {
549   // See the comment on the class declaration for the algorithm used here.
550   int64 row_remainder = m() % tile_rows();
551   int64 row_limit = m() - row_remainder;
552 
553   ksl_.For("dot.outer.tiled",
554            /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(),
555            [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); });
556 
557   if (row_remainder != 0) {
558     EmitOuterLoopBody(b_->getInt64(row_limit), row_remainder);
559   }
560 }
561 
EmitInnerLoopTiled(MemoryTile * lhs_memory_tile,int64 rows,std::vector<VectorVariable> * vector_accumulators)562 void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
563     MemoryTile* lhs_memory_tile, int64 rows,
564     std::vector<VectorVariable>* vector_accumulators) {
565   int64 column_limit = k() - (k() % tile_cols());
566 
567   ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit,
568            /*step=*/tile_cols(), [&](llvm::Value* col) {
569              std::vector<llvm::Value*> lhs_tile =
570                  lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col);
571              llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col);
572              for (int i = 0; i < rows; i++) {
573                llvm::Value* old_sum = (*vector_accumulators)[i].Get();
574                (*vector_accumulators)[i].Set(
575                    vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i])));
576              }
577            });
578 }
579 
EmitInnerLoopEpilogue(llvm::Value * current_tile_row,int64 rows,std::vector<ScalarVariable> * scalar_accumulators)580 void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
581     llvm::Value* current_tile_row, int64 rows,
582     std::vector<ScalarVariable>* scalar_accumulators) {
583   int64 column_start = k() - (k() % tile_cols());
584   if (column_start == k()) {
585     return;
586   }
587 
588   for (int r = 0; r < rows; r++) {
589     llvm::Value* total_offset = b_->CreateMul(
590         b_->CreateAdd(b_->getInt64(r), current_tile_row), b_->getInt64(k()));
591     llvm::Value* lhs_base_pointer =
592         vsl_.ComputeOffsetPointer(lhs_, total_offset);
593     ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(),
594              /*step=*/1, [&](llvm::Value* scalar_col) {
595                llvm::Value* product =
596                    vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col),
597                             vsl_.LoadScalar(rhs_, scalar_col));
598                llvm::Value* old_value = (*scalar_accumulators)[r].Get();
599                (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product));
600              });
601   }
602 }
603 
604 // This class implements a tiled matrix multiplication algorithm, intended for
605 // multiplying small matrices that don't need cache tiling.
606 //
// In the future this can be used as the innermost GEBP loop in a GEMM kernel as
// described in Kazushige Goto and Robert A. van de Geijn, "Anatomy of
// high-performance matrix multiplication", ACM Transactions on Mathematical
// Software (TOMS) 34.3 (2008): 12.
611 //
612 // This only supports canonical dot operations (i.e. where the lhs contraction
613 // dimension is 1 and the rhs contraction dimension is 0) over row major
614 // matrices.
615 class TiledSmallGemmEmitter {
616  public:
617   // Describe the dimensions of the kernel.
618   class Dimensions {
619    public:
Dimensions(int64 m,int64 k,int64 n)620     explicit Dimensions(int64 m, int64 k, int64 n) : m_(m), k_(k), n_(n) {}
621 
m() const622     int64 m() const { return m_; }
k() const623     int64 k() const { return k_; }
n() const624     int64 n() const { return n_; }
625 
ToString() const626     string ToString() const { return absl::StrCat(m(), "x", k(), "x", n()); }
627 
628    private:
629     const int64 m_;
630     const int64 k_;
631     const int64 n_;
632   };
633 
634   // Represents the configuration of the emitter.  The LLVM IR emitted by the
635   // emitter, modulo the LLVM values holding the input and output buffers, must
636   // be a function of the instance of `Config` passed to it.
637   //
638   // `dims` holds the matrix multiplication dimensions.
639   //
640   // `max_vectorization_width` is the maximum vector width (i.e. the width of
641   // the largest vector register we will use).  This can be larger than the
642   // largest vector register supported by the machine -- LLVM will legalize
643   // these large vector widths into legally sized vectors.
644   //
645   // `max_vector_count` is the maximum number of vectors of size
646   // `max_vectorization_width` that we will attempt to process at once.
647   //
648   // `min_vectorization_width` is the smallest vector width the emitter will use
649   // -- below that it will devolve to using a scalar loop.
650   //
651   // The innermost reduction loop executes the matrix multiply in tiles of size
652   // [`tile_size_m`, `tile_size_k`] from the LHS and [`tile_size_k`,
653   // <vectorization width>] in the RHS.
654   class Config {
655    public:
Config(PrimitiveType scalar_type,Dimensions dims,int64 max_vectorization_width,int64 max_vector_count,int64 min_vectorization_width,int64 tile_size_m,int64 tile_size_k)656     explicit Config(PrimitiveType scalar_type, Dimensions dims,
657                     int64 max_vectorization_width, int64 max_vector_count,
658                     int64 min_vectorization_width, int64 tile_size_m,
659                     int64 tile_size_k)
660         : scalar_type_(scalar_type),
661           dims_(dims),
662           max_vectorization_width_(max_vectorization_width),
663           max_vector_count_(max_vector_count),
664           min_vectorization_width_(min_vectorization_width),
665           tile_size_m_(tile_size_m),
666           tile_size_k_(tile_size_k) {}
667 
GetCacheKey() const668     string GetCacheKey() const {
669       return absl::StrCat("gemm_", PrimitiveType_Name(scalar_type()), "_",
670                           dims().ToString(), "_", max_vectorization_width(),
671                           "_", min_vectorization_width(), "_", tile_size_m(),
672                           "_", tile_size_k());
673     }
674 
scalar_type() const675     PrimitiveType scalar_type() const { return scalar_type_; }
dims() const676     Dimensions dims() const { return dims_; }
max_vectorization_width() const677     int64 max_vectorization_width() const { return max_vectorization_width_; }
max_vector_count() const678     int64 max_vector_count() const { return max_vector_count_; }
min_vectorization_width() const679     int64 min_vectorization_width() const { return min_vectorization_width_; }
680 
tile_size_m() const681     int64 tile_size_m() const { return tile_size_m_; }
tile_size_k() const682     int64 tile_size_k() const { return tile_size_k_; }
683 
684    private:
685     PrimitiveType scalar_type_;
686     Dimensions dims_;
687     int64 max_vectorization_width_;
688     int64 max_vector_count_;
689     int64 min_vectorization_width_;
690     int64 tile_size_m_;
691     int64 tile_size_k_;
692   };
693 
694   // Creates an instance of TiledSmallGemmEmitter that matrix-multiplies
695   // `lhs` with `rhs` and stores the result in `result`.
TiledSmallGemmEmitter(Config config,llvm::Value * lhs,llvm::Value * rhs,llvm::Value * result,llvm::IRBuilder<> * b)696   explicit TiledSmallGemmEmitter(Config config, llvm::Value* lhs,
697                                  llvm::Value* rhs, llvm::Value* result,
698                                  llvm::IRBuilder<>* b)
699       : lhs_(lhs),
700         rhs_(rhs),
701         result_(result),
702         config_(config),
703         b_(b),
704         ksl_(b_) {
705     CHECK(max_vectorization_width() > 0 &&
706           IsPowerOfTwo(static_cast<uint64>(max_vectorization_width())));
707     CHECK_GT(max_vector_count(), 0);
708     CHECK(min_vectorization_width() > 0 &&
709           IsPowerOfTwo(static_cast<uint64>(min_vectorization_width())));
710     CHECK_GE(max_vectorization_width(), min_vectorization_width());
711     CHECK_GT(tile_size_k(), 0);
712   }
713 
714   void Emit();
715 
716  private:
717   // The HandleResiduesOnX helpers split the iteration space for dimension X
718   // into a multiple of the tile size on dimension X and an epilogue.  These
719   // helpers ultimately call into `EmitTiledGemm` for emitting the
720   // tiled GEMM kernel.
721 
722   void HandleResiduesOnN();
723   void HandleResiduesOnK(VectorSupportLibrary* vsl, llvm::Value* n_start,
724                          llvm::Value* n_end);
725   void HandleResiduesOnM(VectorSupportLibrary* vsl, int64 tile_size_k,
726                          llvm::Value* k_start, llvm::Value* k_end,
727                          llvm::Value* n_start, llvm::Value* n_end);
728 
729   // This emits a tiled GEMM kernel.  For a detailed description see the comment
730   // on the implementation.
731   void EmitTiledGemm(VectorSupportLibrary* vsl, int64 tile_size_k,
732                      llvm::Value* k_start, llvm::Value* k_end,
733                      llvm::Value* n_start, llvm::Value* n_end,
734                      int64 tile_size_m, llvm::Value* m_start,
735                      llvm::Value* m_end);
736 
GetInt64(int64 value)737   llvm::Value* GetInt64(int64 value) { return b_->getInt64(value); }
738 
config() const739   Config config() const { return config_; }
dims() const740   Dimensions dims() const { return config().dims(); }
741 
max_vectorization_width() const742   int64 max_vectorization_width() const {
743     return config().max_vectorization_width();
744   }
max_vector_count() const745   int64 max_vector_count() const { return config().max_vector_count(); }
min_vectorization_width() const746   int64 min_vectorization_width() const {
747     return config().min_vectorization_width();
748   }
tile_size_m() const749   int64 tile_size_m() const { return config().tile_size_m(); }
tile_size_k() const750   int64 tile_size_k() const { return config().tile_size_k(); }
scalar_type() const751   PrimitiveType scalar_type() const { return config().scalar_type(); }
752 
753   llvm::Value* lhs_;
754   llvm::Value* rhs_;
755   llvm::Value* result_;
756   Config config_;
757 
758   llvm::IRBuilder<>* b_;
759   KernelSupportLibrary ksl_;
760 };
761 
Emit()762 void TiledSmallGemmEmitter::Emit() { HandleResiduesOnN(); }
763 
HandleResiduesOnN()764 void TiledSmallGemmEmitter::HandleResiduesOnN() {
765   // We can only iterate the `n` dimension for an extent that is divisible by
766   // the vectorization width.  So we emit an outer loop that first processes the
767   // largest extent in `n` that is divisible by max_vectorization_width, then
768   // the largest remaining extent that is divisible by max_vectorization_width /
769   // 2 etc.
770 
771   int64 current_vectorization_width =
772       max_vector_count() * max_vectorization_width();
773   int64 current_vector_count = max_vector_count();
774 
775   int64 n_start = 0;
776   while (n_start != dims().n() &&
777          current_vectorization_width >= min_vectorization_width()) {
778     int64 n_end = dims().n() - (dims().n() % current_vectorization_width);
779     if (n_start != n_end) {
780       VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, b_,
781                                "gemm");
782       HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end));
783       n_start = n_end;
784     }
785     if (current_vector_count == 1) {
786       current_vectorization_width /= 2;
787     } else {
788       current_vector_count--;
789       current_vectorization_width =
790           current_vector_count * max_vectorization_width();
791     }
792   }
793 
794   if (n_start != dims().n()) {
795     VectorSupportLibrary vsl(scalar_type(), 1, b_, "gemm");
796     ksl_.For("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) {
797       llvm::Value* n_i_next = b_->CreateAdd(n_i, b_->getInt64(1));
798       HandleResiduesOnK(&vsl, n_i, n_i_next);
799     });
800   }
801 }
802 
HandleResiduesOnK(VectorSupportLibrary * vsl,llvm::Value * n_start,llvm::Value * n_end)803 void TiledSmallGemmEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl,
804                                               llvm::Value* n_start,
805                                               llvm::Value* n_end) {
806   int64 k_start = 0;
807   int64 k_end = dims().k() - (dims().k() % tile_size_k());
808   if (k_end != k_start) {
809     HandleResiduesOnM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end),
810                       n_start, n_end);
811     k_start = k_end;
812   }
813 
814   if (k_start != dims().k()) {
815     HandleResiduesOnM(vsl, dims().k() - k_start, GetInt64(k_start),
816                       GetInt64(dims().k()), n_start, n_end);
817   }
818 }
819 
HandleResiduesOnM(VectorSupportLibrary * vsl,int64 tile_size_k,llvm::Value * k_start,llvm::Value * k_end,llvm::Value * n_start,llvm::Value * n_end)820 void TiledSmallGemmEmitter::HandleResiduesOnM(
821     VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start,
822     llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) {
823   const int64 m_end = dims().m() - dims().m() % tile_size_m();
824   EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, tile_size_m(),
825                 GetInt64(0), GetInt64(m_end));
826 
827   if (m_end != dims().m()) {
828     EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end,
829                   dims().m() - m_end, GetInt64(m_end), GetInt64(dims().m()));
830   }
831 }
832 
833 // The loop structure is:
834 //
835 // Iterate over dimension M as m:
836 //   Iterate over dimension N as n:
837 //     Iterate over dimension K as k:
838 //       OutputTile[m,n] += Dot(LhsTile[m,k], RhsTile[k,n])
839 //
840 // I.e. a just a tiled version of a "naive" GEMM.
841 //
842 // The tiling scheme is as follows:
843 //
844 // Let the LHS be:
845 //
846 //   +----+----+----+
847 //   | a0 | b0 | c0 | .
848 //   +----+----+----+ .
849 //   | a1 | b1 | c1 | .
850 //   +----+----+----+
851 //     ..     ..
852 //
853 // and the RHS be:
854 //
855 //   +----+----+----+----+
856 //   | p0 | p1 | p2 | p3 | .
857 //   +----+----+----+----+ .
858 //   | q0 | q1 | q2 | q3 | .
859 //   +----+----+----+----+
860 //   | r0 | r1 | r2 | r3 | .
861 //   +----+----+----+----+ .
862 //     ......    ......
863 //
864 // and let tile_size_m=2, tile_size_k=3 and the vector width (implicitly denoted
865 // by `vsl`) be 4.  Then we want to matrix multiply this tile to get a [2,4]
866 // matrix that we can increment the result matrix by.
867 //
868 // First broadcast the rows row in LHS to 3 vectors of width 4, giving us a rank
869 // 3 array, L, of dimension [2,3,4]:
870 //
871 //       L[0,_,_]           *      L[1,_,_]
872 //                          *
873 //   +----+----+----+----+  *  +----+----+----+----+
874 //   | a0 | a0 | a0 | a0 |  *  | a1 | a1 | a1 | a1 |
875 //   +----+----+----+----+  *  +----+----+----+----+
876 //   | b0 | b0 | b0 | b0 |  *  | b1 | b1 | b1 | b1 |
877 //   +----+----+----+----+  *  +----+----+----+----+
878 //   | c0 | c0 | c0 | c0 |  *  | c1 | c1 | c1 | c1 |
879 //   +----+----+----+----+  *  +----+----+----+----+
880 //
881 //
882 // Then we FMA L[0,_,_] with the RHS to get the first row of the result and
883 // L[1,_,_] with the RHS to get the second row of the result.  For example,
884 // L[0,_,_] is computed as:
885 //
886 //   +----+----+----+----+   +----+----+----+----+
887 //   | a0 | a0 | a0 | a0 | * | p0 | p1 | p2 | p3 |   +
888 //   +----+----+----+----+   +----+----+----+----+
889 //
890 //   +----+----+----+----+   +----+----+----+----+
891 //   | b0 | b0 | b0 | b0 | * | q0 | q1 | q2 | q3 |   +
892 //   +----+----+----+----+   +----+----+----+----+
893 //
894 //   +----+----+----+----+   +----+----+----+----+
895 //   | c0 | c0 | c0 | c0 | * | r0 | r1 | r2 | r3 |
896 //   +----+----+----+----+   +----+----+----+----+
897 //
898 // to get:
899 //
900 //   +-------------------+-------------------+-------------------+---------
901 //   | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 |  ...
902 //   +-------------------+-------------------+-------------------+---------
EmitTiledGemm(VectorSupportLibrary * vsl,int64 tile_size_k,llvm::Value * k_start,llvm::Value * k_end,llvm::Value * n_start,llvm::Value * n_end,int64 tile_size_m,llvm::Value * m_start,llvm::Value * m_end)903 void TiledSmallGemmEmitter::EmitTiledGemm(
904     VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start,
905     llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end,
906     int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) {
907   ksl_.For("dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) {
908     MemoryTile result_memory_tile(vsl, b_, /*matrix=*/result_,
909                                   /*matrix_size_along_minor_dim=*/dims().n(),
910                                   /*major_dim_offset=*/m_i,
911                                   /*tile_size_along_major_dim=*/tile_size_m);
912     MemoryTile lhs_memory_tile(vsl, b_, /*matrix=*/lhs_,
913                                /*matrix_size_along_minor_dim=*/dims().k(),
914                                /*major_dim_offset=*/m_i,
915                                /*tile_size_along_major_dim=*/tile_size_m);
916     ksl_.For(
917         "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) {
918           TileVariable result_tile_var(vsl, result_memory_tile.LoadTile(n_i));
919           ksl_.For("dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) {
920             MemoryTile rhs_memory_tile(vsl, b_, rhs_, dims().n(), k_i,
921                                        tile_size_k);
922             std::vector<std::vector<llvm::Value*>> lhs_tile =
923                 lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k);
924             std::vector<llvm::Value*> rhs_tile = rhs_memory_tile.LoadTile(n_i);
925             std::vector<llvm::Value*> result_tile = result_tile_var.Get();
926             for (int64 r_m_i = 0; r_m_i < tile_size_m; r_m_i++) {
927               for (int64 r_k_i = 0; r_k_i < tile_size_k; r_k_i++) {
928                 result_tile[r_m_i] =
929                     vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i],
930                                 result_tile[r_m_i]);
931               }
932             }
933             result_tile_var.Set(result_tile);
934           });
935 
936           result_memory_tile.StoreTile(result_tile_var.Get(), n_i);
937         });
938   });
939 }
940 
GetPointerToElementType(llvm::Type * pointer_type)941 llvm::Type* GetPointerToElementType(llvm::Type* pointer_type) {
942   llvm::Type* type =
943       llvm::cast<llvm::PointerType>(pointer_type)->getElementType();
944   while (auto* array_type = llvm::dyn_cast<llvm::ArrayType>(type)) {
945     type = array_type->getElementType();
946   }
947 
948   return type->getPointerTo();
949 }
950 
951 struct GemvBuffersWithCanonicalType {
952   llvm::Value* lhs_canonicalized;
953   llvm::Value* rhs_canonicalized;
954   llvm::Value* addend_canonicalized;
955   llvm::Value* result_canonicalized;
956 };
957 
GetGemvBuffersWithCanonicalType(llvm::Value * lhs,llvm::Value * rhs,llvm::Value * addend,llvm::Value * result,llvm::IRBuilder<> * b)958 GemvBuffersWithCanonicalType GetGemvBuffersWithCanonicalType(
959     llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend,
960     llvm::Value* result, llvm::IRBuilder<>* b) {
961   // We characterize a GEMV operation via M and K, since N is implicitly 1.
962   // This means the GEMV that multiplies (say) [5,6] with [6,1] is implemented
963   // by the same GEMV that multiplies [5,6] with [1,6].  However, the
964   // `llvm::Types` for the inputs to the two GEMVs don't match (in a trivial
965   // sense -- the in memory representations are the same) since they're computed
966   // from the `xla::Shape`s.  Since we want to be able to call the same
967   // `llvm::Function` for the two GEMVs we canonicalize the types of the GEMV
968   // inputs here into the same type.
969   GemvBuffersWithCanonicalType buffers_with_canonical_type;
970   llvm::Type* lhs_type = lhs->getType();
971   llvm::Type* rhs_type = rhs->getType();
972   llvm::Type* addend_type = addend ? addend->getType() : nullptr;
973   llvm::Type* result_type = result->getType();
974 
975   buffers_with_canonical_type.lhs_canonicalized =
976       b->CreateBitCast(lhs, GetPointerToElementType(lhs_type));
977   buffers_with_canonical_type.rhs_canonicalized =
978       b->CreateBitCast(rhs, GetPointerToElementType(rhs_type));
979   buffers_with_canonical_type.addend_canonicalized =
980       addend ? b->CreateBitCast(addend, GetPointerToElementType(addend_type))
981              : nullptr;
982   buffers_with_canonical_type.result_canonicalized =
983       b->CreateBitCast(result, GetPointerToElementType(result_type));
984 
985   return buffers_with_canonical_type;
986 }
987 
988 }  // namespace
989 
EmitRowMajorGemv(PrimitiveType scalar_type,int64 tile_rows,int64 tile_cols,int64 m,int64 k,llvm::Value * lhs,llvm::Value * rhs,llvm::Value * addend,llvm::Value * result,llvm::IRBuilder<> * b,const HloModuleConfig & module_config)990 void EmitRowMajorGemv(PrimitiveType scalar_type, int64 tile_rows,
991                       int64 tile_cols, int64 m, int64 k, llvm::Value* lhs,
992                       llvm::Value* rhs, llvm::Value* addend,
993                       llvm::Value* result, llvm::IRBuilder<>* b,
994                       const HloModuleConfig& module_config) {
995   RowMajorMatrixVectorProductEmitter::Config config(
996       /*scalar_type=*/scalar_type,
997       /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols,
998       /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr);
999 
1000   GemvBuffersWithCanonicalType canonical_inputs =
1001       GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b);
1002 
1003   KernelSupportLibrary::EmitAndCallOutlinedKernel(
1004       module_config, b, config.GetCacheKey(),
1005       canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized,
1006       canonical_inputs.addend_canonicalized,
1007       canonical_inputs.result_canonicalized,
1008       [&config, b, &canonical_inputs](llvm::Value* lhs, llvm::Value* rhs,
1009                                       llvm::Value* addend,
1010                                       llvm::Value* result) {
1011         RowMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend,
1012                                                    result, b);
1013         emitter.Emit();
1014       });
1015 }
1016 
EmitColumnMajorGemv(PrimitiveType scalar_type,int64 tile_rows,int64 tile_cols,int64 m,int64 k,llvm::Value * lhs,llvm::Value * rhs,llvm::Value * addend,llvm::Value * result,llvm::IRBuilder<> * b,const HloModuleConfig & module_config)1017 void EmitColumnMajorGemv(PrimitiveType scalar_type, int64 tile_rows,
1018                          int64 tile_cols, int64 m, int64 k, llvm::Value* lhs,
1019                          llvm::Value* rhs, llvm::Value* addend,
1020                          llvm::Value* result, llvm::IRBuilder<>* b,
1021                          const HloModuleConfig& module_config) {
1022   ColumnMajorMatrixVectorProductEmitter::Config config(
1023       /*scalar_type=*/scalar_type,
1024       /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols,
1025       /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr);
1026 
1027   GemvBuffersWithCanonicalType canonical_inputs =
1028       GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b);
1029 
1030   KernelSupportLibrary::EmitAndCallOutlinedKernel(
1031       module_config, b, config.GetCacheKey(),
1032       canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized,
1033       canonical_inputs.addend_canonicalized,
1034       canonical_inputs.result_canonicalized,
1035       [&config, b, &canonical_inputs](llvm::Value* lhs, llvm::Value* rhs,
1036                                       llvm::Value* addend,
1037                                       llvm::Value* result) {
1038         ColumnMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend,
1039                                                       result, b);
1040         emitter.Emit();
1041       });
1042 }
1043 
EmitSmallGemm(PrimitiveType scalar_type,int64 m,int64 k,int64 n,int64 max_vectorization_width,int64 max_vector_count,int64 min_vectorization_width,int64 tile_size_m,int64 tile_size_k,llvm::Value * lhs,llvm::Value * rhs,llvm::Value * result,llvm::IRBuilder<> * b,const HloModuleConfig & module_config)1044 void EmitSmallGemm(PrimitiveType scalar_type, int64 m, int64 k, int64 n,
1045                    int64 max_vectorization_width, int64 max_vector_count,
1046                    int64 min_vectorization_width, int64 tile_size_m,
1047                    int64 tile_size_k, llvm::Value* lhs, llvm::Value* rhs,
1048                    llvm::Value* result, llvm::IRBuilder<>* b,
1049                    const HloModuleConfig& module_config) {
1050   TiledSmallGemmEmitter::Config config(
1051       /*scalar_type=*/scalar_type,
1052       TiledSmallGemmEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n},
1053       /*max_vectorization_width=*/max_vectorization_width,
1054       /*max_vector_count=*/max_vector_count,
1055       /*min_vectorization_width=*/min_vectorization_width,
1056       /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k);
1057 
1058   KernelSupportLibrary::EmitAndCallOutlinedKernel(
1059       module_config, b, config.GetCacheKey(), lhs, rhs, result,
1060       [&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result) {
1061         TiledSmallGemmEmitter small_gemm_emitter(config, /*lhs=*/lhs,
1062                                                  /*rhs=*/rhs,
1063                                                  /*result=*/result, b);
1064         small_gemm_emitter.Emit();
1065       });
1066 }
1067 
1068 }  // namespace cpu
1069 }  // namespace xla
1070