/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h"

#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"

namespace xla {
namespace cpu {
namespace {

using ::int64_t;

// Provides tiled access to an in-memory rank 2 array.
class MemoryTile {
 public:
  // Constructs a MemoryTile that can operate on tiles consisting of
  // `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at
  // `major_dim_offset` in the major dimension.  The tile size along the minor
  // dimension is the vector size, and that is implicitly determined by `vsl`.
  MemoryTile(VectorSupportLibrary* vsl, llvm::IRBuilder<>* b,
             llvm::Value* matrix, int64_t matrix_size_along_minor_dim,
             llvm::Value* major_dim_offset, int64_t tile_size_along_major_dim)
      : vsl_(vsl), b_(b) {
    pointers_.reserve(tile_size_along_major_dim);
    for (int64_t i = 0; i < tile_size_along_major_dim; i++) {
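      // The i-th pointer below points at flat offset
      // (major_dim_offset + i) * matrix_size_along_minor_dim, i.e. the start
      // of major-dimension index (major_dim_offset + i) in the backing buffer.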
      llvm::Value* total_offset =
          b->CreateMul(b->getInt64(matrix_size_along_minor_dim),
                       b->CreateAdd(b->getInt64(i), major_dim_offset));
      pointers_.push_back(vsl_->ComputeOffsetPointer(matrix, total_offset));
    }
  }

  // Load a tile consisting of `tile_size_along_major_dim` vectors from position
  // {major: `major_dim_offset`, minor: `minor_dim_offset`}.
  //
  // Note: `major_dim_offset` is a parameter to the constructor.
  std::vector<llvm::Value*> LoadTile(llvm::Value* minor_dim_offset) const {
    std::vector<llvm::Value*> result;
    result.reserve(pointers_.size());
    for (const auto& pointer : pointers_) {
      result.push_back(vsl_->LoadVector(pointer, minor_dim_offset));
    }
    return result;
  }

  // Stores `tile` to position {major: `major_dim_offset`, minor:
  // `minor_dim_offset`}.
  //
  // Note: `major_dim_offset` is a parameter to the constructor.
  void StoreTile(absl::Span<llvm::Value* const> tile,
                 llvm::Value* minor_dim_offset) const {
    CHECK_EQ(tile.size(), pointers_.size());
    for (int64_t i = 0; i < pointers_.size(); i++) {
      vsl_->StoreVector(tile[i], pointers_[i], minor_dim_offset);
    }
  }

  // Loads a tile of size [`tile_size_along_major_dim`,
  // `tile_size_along_middle_dim`] from position {major: `major_dim_offset`,
  // minor: `minor_dim_offset`} and then broadcasts each element into a vector
  // of size vsl_.vector_size().  The (i,j)'th element of the return value is
  // the (i,j)'th element in the tile broadcasted into an LLVM vector.
  //
  // Note: `major_dim_offset` is a parameter to the constructor.
  std::vector<std::vector<llvm::Value*>> LoadBroadcastTile(
      llvm::Value* minor_dim_offset, int64_t tile_size_along_middle_dim) const {
    std::vector<std::vector<llvm::Value*>> result;
    result.resize(pointers_.size());
    for (int64_t i = 0; i < pointers_.size(); i++) {
      for (int64_t j = 0; j < tile_size_along_middle_dim; j++) {
        result[i].push_back(vsl_->LoadBroadcast(
            pointers_[i], b_->CreateAdd(minor_dim_offset, b_->getInt64(j))));
      }
    }
    return result;
  }

 private:
  VectorSupportLibrary* vsl_;
  llvm::IRBuilder<>* b_;
  std::vector<llvm::Value*> pointers_;
};

// The base class for the classes representing the GEMV emitter configurations.
//
// The IR emitted (modulo the LLVM values representing the input and output
// buffers) by the row major and column major GEMV emitters should be a function
// of their configuration.  This is important because their configuration is
// used as a key to cache the generated IR.
class GemvConfig {
 public:
  // Mixin for convenience.
  template <typename T>
  struct User {
   public:
    PrimitiveType scalar_type() const {
      return derived().config().scalar_type();
    }
    int64_t tile_rows() const { return derived().config().tile_rows(); }
    int64_t tile_cols() const { return derived().config().tile_cols(); }
    int64_t m() const { return derived().config().m(); }
    int64_t k() const { return derived().config().k(); }
    int64_t has_addend() const { return derived().config().has_addend(); }

   private:
    const T& derived() const { return *static_cast<const T*>(this); }
  };

  PrimitiveType scalar_type() const { return scalar_type_; }
  int64_t tile_rows() const { return tile_rows_; }
  int64_t tile_cols() const { return tile_cols_; }
  int64_t m() const { return m_; }
  int64_t k() const { return k_; }
  bool has_addend() const { return has_addend_; }

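  // For example, a column major GEMV config with F32 scalars, 8x8 tiles,
  // m = 128, k = 256 and an addend yields the cache key
  // "col_major_gemv_F32_8_8_128_256_with_addend".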
  std::string GetCacheKey() const {
    return absl::StrCat(name_, "_", PrimitiveType_Name(scalar_type()), "_",
                        tile_rows(), "_", tile_cols(), "_", m(), "_", k(),
                        has_addend() ? "_with_addend" : "");
  }

 protected:
  explicit GemvConfig(std::string name, PrimitiveType scalar_type,
                      int64_t tile_rows, int64_t tile_cols, int64_t m,
                      int64_t k, bool has_addend)
      : name_(std::move(name)),
        scalar_type_(scalar_type),
        tile_rows_(tile_rows),
        tile_cols_(tile_cols),
        m_(m),
        k_(k),
        has_addend_(has_addend) {}

 private:
  std::string name_;
  PrimitiveType scalar_type_;
  int64_t tile_rows_;
  int64_t tile_cols_;
  int64_t m_;
  int64_t k_;
  bool has_addend_;
};

// Computes a dot product of a "[M,K]{0,1} lhs" with a [K,1] vector (the
// layout of the vector does not matter).  This implementation uses a tiling
// scheme to improve performance.
//
// We logically separate the LHS matrix into four segments:
//
//   +----------------------+---+
//   |                      |   |
//   |                      |   |
//   |         A            | B |
//   |                      |   |
//   |                      |   |
//   |                      |   |
//   +----------------------+---+
//   |         C            | D |
//   +----------------------+---+
//
// where A is the largest submatrix of the LHS that can be evenly divided into
// tiles.  For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have:
//
//   +---+---+---+---+       +--+--+--+--+
//   |M00|M10|M20|M30|       |V0|V1|V2|V3|
//   +---+---+---+---+       +--+--+--+--+
//   |M01|M11|M21|M31| and   |V0|V1|V2|V3|
//   +---+---+---+---+       +--+--+--+--+
//   |M02|M12|M22|M32|       |V0|V1|V2|V3|
//   +---+---+---+---+       +--+--+--+--+
//   |M03|M13|M23|M33|       |V0|V1|V2|V3|
//   +---+---+---+---+       +--+--+--+--+
//
// (Legend: rows are horizontal and columns are vertical; and each column is one
// llvm::Value of a vector type)
//
// where:
//
//   a. The left tile is from the column major left matrix.
//   b. The right tile is an elementwise broadcast of a [V0, V1, V2, V3]
//      vector loaded from the RHS vector.
//
// As we iterate through the column dimension, we compute the change to the
// result vector by an elementwise multiplication between the two tiles above
// followed by a reduction along the major dimension:
//
//                     +-----------------------------------+
//                     | M00*V0 + M10*V1 + M20*V2 + M30*V3 |
//                     +-----------------------------------+
//                     | M01*V0 + M11*V1 + M21*V2 + M31*V3 |
// Result[R:R+4] +=    +-----------------------------------+
//                     | M02*V0 + M12*V1 + M22*V2 + M32*V3 |
//                     +-----------------------------------+
//                     | M03*V0 + M13*V1 + M23*V2 + M33*V3 |
//                     +-----------------------------------+
//
// Where R is the starting row for the tile.
//
// We have an inner epilogue loop to deal with the "C" submatrix and an outer
// epilogue loop to deal with the B,D submatrices.
//
// TODO(sanjoy): We should investigate whether gather loads and scatter stores
// can be used here to share the same inner loop for both column-major and
// row-major matrix-vector products.
class ColumnMajorMatrixVectorProductEmitter
    : public GemvConfig::User<ColumnMajorMatrixVectorProductEmitter> {
 public:
  class Config : public GemvConfig {
   public:
    explicit Config(PrimitiveType scalar_type, int64_t tile_rows,
                    int64_t tile_cols, int64_t m, int64_t k, bool has_addend)
        : GemvConfig(/*name=*/"col_major_gemv", scalar_type,
                     /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m,
                     /*k=*/k, /*has_addend=*/has_addend) {}
  };

  ColumnMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs,
                                        llvm::Value* rhs, llvm::Value* addend,
                                        llvm::Value* result,
                                        llvm::IRBuilder<>* b)
      : config_(config),
        lhs_(lhs),
        rhs_(rhs),
        addend_(addend),
        result_(result),
        b_(b),
        ksl_(b_),
        vsl_(config.scalar_type(), /*vector_size=*/config.tile_rows(), b_, "") {
    CHECK(tile_rows() > 0 &&
          absl::has_single_bit(static_cast<uint64_t>(tile_rows())));
    CHECK(!has_addend() || addend != nullptr);
  }

  void Emit();

  const Config& config() const { return config_; }

 private:
  void EmitOuterLoopBody(llvm::Value* column, int64_t column_count,
                         bool is_first_column);

  MemoryTile GetLhsMemoryTile(llvm::Value* column_start, int64_t column_count) {
    return MemoryTile(&vsl_, b_, /*matrix=*/lhs_,
                      /*matrix_size_along_minor_dim=*/m(),
                      /*major_dim_offset=*/column_start,
                      /*tile_size_along_major_dim=*/column_count);
  }

  // Load a tile of values from the RHS.  For the RHS a "tile" is a contiguous
  // sequence of `count` values, each one broadcasted to the vector width.
  std::vector<llvm::Value*> LoadRhsTile(llvm::Value* offset, int64_t count) {
    llvm::Value* base_pointer = vsl_.ComputeOffsetPointer(rhs_, offset);
    std::vector<llvm::Value*> result;
    result.reserve(count);
    for (int64_t i = 0; i < count; i++) {
      result.push_back(vsl_.LoadBroadcast(base_pointer, i));
    }
    return result;
  }

  void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile,
                          const std::vector<llvm::Value*>& rhs_tile,
                          int64_t columns, bool is_first_column);

  void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64_t columns,
                             bool is_first_tiled_column);

  Config config_;
  llvm::Value* lhs_;
  llvm::Value* rhs_;
  llvm::Value* addend_;
  llvm::Value* result_;
  llvm::IRBuilder<>* b_;
  KernelSupportLibrary ksl_;
  VectorSupportLibrary vsl_;
};

void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody(
    llvm::Value* column, int64_t column_count, bool is_first_column) {
  MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*column_start=*/column,
                                                /*column_count=*/column_count);

  std::vector<llvm::Value*> rhs_tile =
      LoadRhsTile(column, /*count=*/column_count);
  EmitInnerLoopTiled(&lhs_memory_tile, rhs_tile,
                     /*columns=*/column_count, is_first_column);
  EmitInnerLoopEpilogue(column, /*columns=*/column_count, is_first_column);
}

void ColumnMajorMatrixVectorProductEmitter::Emit() {
  // See the comment on the class declaration for the algorithm used here.
  int64_t column_remainder = k() % tile_cols();
  int64_t column_limit = k() - column_remainder;
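  // For instance, with k() = 10 and tile_cols() = 4 the tiled loop below
  // covers columns [0, 8) and the trailing EmitOuterLoopBody call handles the
  // final 2 columns.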

  ksl_.For("dot.outer.tiled",
           /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(),
           [&](llvm::Value* column, bool is_first_column) {
             EmitOuterLoopBody(column, tile_cols(), is_first_column);
           });

  if (column_remainder != 0) {
    EmitOuterLoopBody(b_->getInt64(column_limit), column_remainder,
                      column_limit == 0);
  }
}

void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
    MemoryTile* lhs_memory_tile, const std::vector<llvm::Value*>& rhs_tile,
    int64_t columns, bool is_first_column) {
  int64_t row_limit = m() - (m() % tile_rows());

  ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit,
           /*step=*/tile_rows(), [&](llvm::Value* row) {
             std::vector<llvm::Value*> lhs_tile =
                 lhs_memory_tile->LoadTile(/*minor_dim_offset=*/row);
             llvm::Value* accumulator =
                 is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row)
                                            : vsl_.GetZeroVector())
                                 : vsl_.LoadVector(result_, row);
             for (int i = 0; i < columns; i++) {
               accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator);
             }
             vsl_.StoreVector(accumulator, result_, row);
           });
}

void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
    llvm::Value* current_tile_col, int64_t columns,
    bool is_first_tiled_column) {
  int64_t row_start = m() - (m() % tile_rows());
  if (row_start == m()) {
    return;
  }

  llvm::Value* columns_llvm = b_->getInt64(columns);

  // for (col = current_tile_col; col < (columns + current_tile_col); col++)
  //   for (row = row_start; row < m_; row++) {
  //     result[row] += lhs[row, col] * rhs[col]
  //     // Also take into account that if col is 0 then result[row] is not
  //     // initialized.
  //   }

  ksl_.For(
      "dot.inner.epilg.outer", /*start=*/current_tile_col,
      /*end=*/b_->CreateAdd(columns_llvm, current_tile_col),
      /*step=*/1, /*peel_first_iteration=*/false,
      [&](llvm::Value* col, llvm::Value* is_first_scalar_col) {
        llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col);
        llvm::Value* total_offset = b_->CreateMul(col, b_->getInt64(m()));
        llvm::Value* lhs_base_pointer =
            vsl_.ComputeOffsetPointer(lhs_, total_offset);
        ksl_.For(
            "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m(),
            /*step=*/1, [&](llvm::Value* scalar_row) {
              llvm::Value* product = vsl_.Mul(
                  vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element);
              llvm::Value* setting_result_first_time = b_->CreateAnd(
                  is_first_scalar_col, b_->getInt1(is_first_tiled_column));
              ksl_.If(
                  setting_result_first_time,
                  /*true_block_generator=*/
                  [&]() {
                    if (addend_) {
                      vsl_.StoreScalar(
                          vsl_.Add(vsl_.LoadScalar(addend_, scalar_row),
                                   product),
                          result_, scalar_row);
                    } else {
                      vsl_.StoreScalar(product, result_, scalar_row);
                    }
                  },
                  /*false_block_generator=*/
                  [&]() {
                    vsl_.StoreScalar(
                        vsl_.Add(vsl_.LoadScalar(result_, scalar_row), product),
                        result_, scalar_row);
                  });
            });
      });
}

// Computes a dot product of a "[M,K]{1,0} lhs" with a [K,1] vector (the
// layout of the vector does not matter).  This implementation uses a tiling
// scheme to improve performance.
//
// We logically separate the LHS matrix into four segments:
//
//   +----------------------+---+
//   |                      |   |
//   |                      |   |
//   |         A            | B |
//   |                      |   |
//   |                      |   |
//   |                      |   |
//   +----------------------+---+
//   |         C            | D |
//   +----------------------+---+
//
// where A is the largest submatrix of the LHS that can be evenly divided into
// tiles.  For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have:
//
//   +---+---+---+---+
//   |M00|M10|M20|M30|
//   +---+---+---+---+       +--+--+--+--+
//   |M01|M11|M21|M31| and   |V0|V1|V2|V3|
//   +---+---+---+---+       +--+--+--+--+
//   |M02|M12|M22|M32|
//   +---+---+---+---+
//   |M03|M13|M23|M33|
//   +---+---+---+---+
//
// (Legend: rows are horizontal and columns are vertical; and each row is one
// llvm::Value of a vector type)
//
// where:
//
//   a. The left tile is loaded from the row major left matrix.
//   b. The right vector is loaded from the RHS vector.
//
// We keep 4 vector accumulators accumulating the following four vector
// expressions as we iterate over the row dimension:
//
//   +------+------+------+------+
//   |M0I*V0|M1I*V1|M2I*V2|M3I*V3|  for I in [0,4)
//   +------+------+------+------+
//
// In the end we do a horizontal reduction over these 4 vector accumulators to
// get 4 values in the result vector.
//
// We have an inner epilogue loop to deal with the "B" submatrix and an outer
// epilogue loop to deal with the C,D submatrices.
class RowMajorMatrixVectorProductEmitter
    : public GemvConfig::User<RowMajorMatrixVectorProductEmitter> {
 public:
  class Config : public GemvConfig {
   public:
    explicit Config(PrimitiveType scalar_type, int64_t tile_rows,
                    int64_t tile_cols, int64_t m, int64_t k, bool has_addend)
        : GemvConfig(/*name=*/"row_major_gemv", scalar_type,
                     /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m,
                     /*k=*/k, /*has_addend=*/has_addend) {}
  };

  RowMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs,
                                     llvm::Value* rhs, llvm::Value* addend,
                                     llvm::Value* result, llvm::IRBuilder<>* b)
      : config_(config),
        lhs_(lhs),
        rhs_(rhs),
        addend_(addend),
        result_(result),
        b_(b),
        ksl_(b_),
        vsl_(scalar_type(), /*vector_size=*/tile_cols(), b_, "") {
    CHECK(tile_cols() > 0 &&
          absl::has_single_bit(static_cast<uint64_t>(tile_cols())));
    CHECK(!has_addend() || addend != nullptr);
  }

  void Emit();

  const Config& config() const { return config_; }

 private:
  MemoryTile GetLhsMemoryTile(llvm::Value* row_start, int64_t row_count) {
    return MemoryTile(&vsl_, b_, /*matrix=*/lhs_,
                      /*matrix_size_along_minor_dim=*/k(),
                      /*major_dim_offset=*/row_start,
                      /*tile_size_along_major_dim=*/row_count);
  }

  void EmitOuterLoopBody(llvm::Value* row, int64_t row_count);

  void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, int64_t rows,
                          std::vector<VectorVariable>* vector_accumulators);

  void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64_t rows,
                             std::vector<ScalarVariable>* scalar_accumulators);

  Config config_;
  llvm::Value* lhs_;
  llvm::Value* rhs_;
  llvm::Value* addend_;
  llvm::Value* result_;
  llvm::IRBuilder<>* b_;
  KernelSupportLibrary ksl_;
  VectorSupportLibrary vsl_;
};

void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row,
                                                           int64_t row_count) {
  MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*row_start=*/row,
                                                /*row_count=*/row_count);
  std::vector<VectorVariable> vector_accumulators;
  std::vector<ScalarVariable> scalar_accumulators;
  vector_accumulators.reserve(row_count);
  scalar_accumulators.reserve(row_count);
  for (int64_t i = 0; i < row_count; i++) {
    vector_accumulators.emplace_back(&vsl_, vsl_.GetZeroVector());
    scalar_accumulators.emplace_back(&vsl_, vsl_.GetZeroScalar());
  }
  EmitInnerLoopTiled(&lhs_memory_tile, /*rows=*/row_count,
                     &vector_accumulators);
  EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count,
                        &scalar_accumulators);

  std::vector<llvm::Value*> accumulator_values;
  std::transform(
      vector_accumulators.begin(), vector_accumulators.end(),
      std::back_inserter(accumulator_values),
      [](const VectorVariable& vector_var) { return vector_var.Get(); });

  std::vector<llvm::Value*> horizontal_sums;
  if (row_count == vsl_.vector_size()) {
    if (addend_) {
      horizontal_sums = vsl_.ComputeHorizontalSums(
          std::move(accumulator_values), vsl_.LoadVector(addend_, row));
    } else {
      horizontal_sums =
          vsl_.ComputeHorizontalSums(std::move(accumulator_values));
    }
  } else {
    horizontal_sums = vsl_.ComputeHorizontalSums(std::move(accumulator_values));
  }

  for (int i = 0; i < row_count; i++) {
    llvm::Value* result_value =
        vsl_.Add(horizontal_sums[i], scalar_accumulators[i].Get());
    llvm::Value* offset = b_->CreateAdd(b_->getInt64(i), row);
    if (addend_ && row_count != vsl_.vector_size()) {
      result_value = vsl_.Add(vsl_.LoadScalar(addend_, offset), result_value);
    }
    vsl_.StoreScalar(result_value, result_, offset);
  }
}

void RowMajorMatrixVectorProductEmitter::Emit() {
  // See the comment on the class declaration for the algorithm used here.
  int64_t row_remainder = m() % tile_rows();
  int64_t row_limit = m() - row_remainder;
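  // As in the column major case, rows [0, row_limit) are handled by the tiled
  // loop below and any remaining rows by a final EmitOuterLoopBody call.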

  ksl_.For("dot.outer.tiled",
           /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(),
           [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); });

  if (row_remainder != 0) {
    EmitOuterLoopBody(b_->getInt64(row_limit), row_remainder);
  }
}

void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
    MemoryTile* lhs_memory_tile, int64_t rows,
    std::vector<VectorVariable>* vector_accumulators) {
  int64_t column_limit = k() - (k() % tile_cols());

  ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit,
           /*step=*/tile_cols(), [&](llvm::Value* col) {
             std::vector<llvm::Value*> lhs_tile =
                 lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col);
             llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col);
             for (int i = 0; i < rows; i++) {
               llvm::Value* old_sum = (*vector_accumulators)[i].Get();
               (*vector_accumulators)[i].Set(
                   vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i])));
             }
           });
}

void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
    llvm::Value* current_tile_row, int64_t rows,
    std::vector<ScalarVariable>* scalar_accumulators) {
  int64_t column_start = k() - (k() % tile_cols());
  if (column_start == k()) {
    return;
  }

  for (int r = 0; r < rows; r++) {
    llvm::Value* total_offset = b_->CreateMul(
        b_->CreateAdd(b_->getInt64(r), current_tile_row), b_->getInt64(k()));
    llvm::Value* lhs_base_pointer =
        vsl_.ComputeOffsetPointer(lhs_, total_offset);
    ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(),
             /*step=*/1, [&](llvm::Value* scalar_col) {
               llvm::Value* product =
                   vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col),
                            vsl_.LoadScalar(rhs_, scalar_col));
               llvm::Value* old_value = (*scalar_accumulators)[r].Get();
               (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product));
             });
  }
}

// This class implements a tiled matrix multiplication algorithm, intended for
// multiplying small matrices that don't need cache tiling.
//
// In the future this can be used as the innermost GEBP loop in a GEMM kernel as
// described in "Goto, Kazushige, and Robert A. van de Geijn. "Anatomy of
// high-performance matrix multiplication." ACM Transactions on Mathematical
// Software (TOMS) 34.3 (2008): 12.".
//
// This only supports canonical dot operations (i.e. where the lhs contraction
// dimension is 1 and the rhs contraction dimension is 0) over row major
// matrices.
class TiledSmallGemmEmitter {
 public:
  // Describe the dimensions of the kernel.
  class Dimensions {
   public:
    explicit Dimensions(int64_t m, int64_t k, int64_t n)
        : m_(m), k_(k), n_(n) {}

    int64_t m() const { return m_; }
    int64_t k() const { return k_; }
    int64_t n() const { return n_; }

    std::string ToString() const {
      return absl::StrCat(m(), "x", k(), "x", n());
    }

   private:
    const int64_t m_;
    const int64_t k_;
    const int64_t n_;
  };

  // Represents the configuration of the emitter.  The LLVM IR emitted by the
  // emitter, modulo the LLVM values holding the input and output buffers, must
  // be a function of the instance of `Config` passed to it.
  //
  // `dims` holds the matrix multiplication dimensions.
  //
  // `max_vectorization_width` is the maximum vector width (i.e. the width of
  // the largest vector register we will use).  This can be larger than the
  // largest vector register supported by the machine -- LLVM will legalize
  // these large vector widths into legally sized vectors.
  //
  // `max_vector_count` is the maximum number of vectors of size
  // `max_vectorization_width` that we will attempt to process at once.
  //
  // `min_vectorization_width` is the smallest vector width the emitter will use
  // -- below that it will devolve to using a scalar loop.
  //
  // The innermost reduction loop executes the matrix multiply in tiles of size
  // [`tile_size_m`, `tile_size_k`] from the LHS and [`tile_size_k`,
  // <vectorization width>] in the RHS.
  class Config {
   public:
    explicit Config(PrimitiveType scalar_type, Dimensions dims,
                    int64_t max_vectorization_width, int64_t max_vector_count,
                    int64_t min_vectorization_width, int64_t tile_size_m,
                    int64_t tile_size_k)
        : scalar_type_(scalar_type),
          dims_(dims),
          max_vectorization_width_(max_vectorization_width),
          max_vector_count_(max_vector_count),
          min_vectorization_width_(min_vectorization_width),
          tile_size_m_(tile_size_m),
          tile_size_k_(tile_size_k) {}

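    // As an illustration, F32 scalars with dims 64x32x16,
    // max_vectorization_width = 8, min_vectorization_width = 4,
    // tile_size_m = 3 and tile_size_k = 5 produce the cache key
    // "gemm_F32_64x32x16_8_4_3_5".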
    std::string GetCacheKey() const {
      return absl::StrCat("gemm_", PrimitiveType_Name(scalar_type()), "_",
                          dims().ToString(), "_", max_vectorization_width(),
                          "_", min_vectorization_width(), "_", tile_size_m(),
                          "_", tile_size_k());
    }

    PrimitiveType scalar_type() const { return scalar_type_; }
    Dimensions dims() const { return dims_; }
    int64_t max_vectorization_width() const { return max_vectorization_width_; }
    int64_t max_vector_count() const { return max_vector_count_; }
    int64_t min_vectorization_width() const { return min_vectorization_width_; }

    int64_t tile_size_m() const { return tile_size_m_; }
    int64_t tile_size_k() const { return tile_size_k_; }

   private:
    PrimitiveType scalar_type_;
    Dimensions dims_;
    int64_t max_vectorization_width_;
    int64_t max_vector_count_;
    int64_t min_vectorization_width_;
    int64_t tile_size_m_;
    int64_t tile_size_k_;
  };

  // Creates an instance of TiledSmallGemmEmitter that matrix-multiplies
  // `lhs` with `rhs` and stores the result in `result`.
  explicit TiledSmallGemmEmitter(Config config, llvm::Value* lhs,
                                 llvm::Value* rhs, llvm::Value* result,
                                 llvm::IRBuilder<>* b)
      : lhs_(lhs),
        rhs_(rhs),
        result_(result),
        config_(config),
        b_(b),
        ksl_(b_) {
    CHECK(
        max_vectorization_width() > 0 &&
        absl::has_single_bit(static_cast<uint64_t>(max_vectorization_width())));
    CHECK_GT(max_vector_count(), 0);
    CHECK(
        min_vectorization_width() > 0 &&
        absl::has_single_bit(static_cast<uint64_t>(min_vectorization_width())));
    CHECK_GE(max_vectorization_width(), min_vectorization_width());
    CHECK_GT(tile_size_k(), 0);
  }

  void Emit();

 private:
  // The HandleResiduesOnX helpers split the iteration space for dimension X
  // into a multiple of the tile size on dimension X and an epilogue.  These
  // helpers ultimately call into `EmitTiledGemm` for emitting the
  // tiled GEMM kernel.

  void HandleResiduesOnN();
  void HandleResiduesOnK(VectorSupportLibrary* vsl, llvm::Value* n_start,
                         llvm::Value* n_end);
  void HandleResiduesOnM(VectorSupportLibrary* vsl, int64_t tile_size_k,
                         llvm::Value* k_start, llvm::Value* k_end,
                         llvm::Value* n_start, llvm::Value* n_end);

  // This emits a tiled GEMM kernel.  For a detailed description see the comment
  // on the implementation.
  void EmitTiledGemm(VectorSupportLibrary* vsl, int64_t tile_size_k,
                     llvm::Value* k_start, llvm::Value* k_end,
                     llvm::Value* n_start, llvm::Value* n_end,
                     int64_t tile_size_m, llvm::Value* m_start,
                     llvm::Value* m_end);

  llvm::Value* GetInt64(int64_t value) { return b_->getInt64(value); }

  Config config() const { return config_; }
  Dimensions dims() const { return config().dims(); }

  int64_t max_vectorization_width() const {
    return config().max_vectorization_width();
  }
  int64_t max_vector_count() const { return config().max_vector_count(); }
  int64_t min_vectorization_width() const {
    return config().min_vectorization_width();
  }
  int64_t tile_size_m() const { return config().tile_size_m(); }
  int64_t tile_size_k() const { return config().tile_size_k(); }
  PrimitiveType scalar_type() const { return config().scalar_type(); }

  llvm::Value* lhs_;
  llvm::Value* rhs_;
  llvm::Value* result_;
  Config config_;

  llvm::IRBuilder<>* b_;
  KernelSupportLibrary ksl_;
};

void TiledSmallGemmEmitter::Emit() { HandleResiduesOnN(); }

void TiledSmallGemmEmitter::HandleResiduesOnN() {
  // We can only iterate the `n` dimension for an extent that is divisible by
  // the vectorization width.  So we emit an outer loop that first processes the
  // largest extent in `n` that is divisible by max_vectorization_width, then
  // the largest remaining extent that is divisible by max_vectorization_width /
  // 2 etc.
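  //
  // For example (illustrative numbers), with n = 19, max_vectorization_width
  // = 8, max_vector_count = 1 and min_vectorization_width = 4 this emits a
  // width-8 vectorized loop over [0, 16) and leaves [16, 19) to the scalar
  // "epi.n" loop at the end of this function.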

  int64_t current_vectorization_width =
      max_vector_count() * max_vectorization_width();
  int64_t current_vector_count = max_vector_count();

  int64_t n_start = 0;
  while (n_start != dims().n() &&
         current_vectorization_width >= min_vectorization_width()) {
    int64_t n_end = dims().n() - (dims().n() % current_vectorization_width);
    if (n_start != n_end) {
      VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, b_,
                               "gemm");
      HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end));
      n_start = n_end;
    }
    if (current_vector_count == 1) {
      current_vectorization_width /= 2;
    } else {
      current_vector_count--;
      current_vectorization_width =
          current_vector_count * max_vectorization_width();
    }
  }

  if (n_start != dims().n()) {
    VectorSupportLibrary vsl(scalar_type(), 1, b_, "gemm");
    ksl_.For("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) {
      llvm::Value* n_i_next = b_->CreateAdd(n_i, b_->getInt64(1));
      HandleResiduesOnK(&vsl, n_i, n_i_next);
    });
  }
}

void TiledSmallGemmEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl,
                                              llvm::Value* n_start,
                                              llvm::Value* n_end) {
  int64_t k_start = 0;
  int64_t k_end = dims().k() - (dims().k() % tile_size_k());
  if (k_end != k_start) {
    HandleResiduesOnM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end),
                      n_start, n_end);
    k_start = k_end;
  }

  if (k_start != dims().k()) {
    HandleResiduesOnM(vsl, dims().k() - k_start, GetInt64(k_start),
                      GetInt64(dims().k()), n_start, n_end);
  }
}

void TiledSmallGemmEmitter::HandleResiduesOnM(
    VectorSupportLibrary* vsl, int64_t tile_size_k, llvm::Value* k_start,
    llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) {
  const int64_t m_end = dims().m() - dims().m() % tile_size_m();
  EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, tile_size_m(),
                GetInt64(0), GetInt64(m_end));

  if (m_end != dims().m()) {
    EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end,
                  dims().m() - m_end, GetInt64(m_end), GetInt64(dims().m()));
  }
}

// The loop structure is:
//
// Iterate over dimension M as m:
//   Iterate over dimension N as n:
//     Iterate over dimension K as k:
//       OutputTile[m,n] += Dot(LhsTile[m,k], RhsTile[k,n])
//
// I.e. just a tiled version of a "naive" GEMM.
//
// The tiling scheme is as follows:
//
// Let the LHS be:
//
//   +----+----+----+
//   | a0 | b0 | c0 | .
//   +----+----+----+ .
//   | a1 | b1 | c1 | .
//   +----+----+----+
//     ..     ..
//
// and the RHS be:
//
//   +----+----+----+----+
//   | p0 | p1 | p2 | p3 | .
//   +----+----+----+----+ .
//   | q0 | q1 | q2 | q3 | .
//   +----+----+----+----+
//   | r0 | r1 | r2 | r3 | .
//   +----+----+----+----+ .
//     ......    ......
//
// and let tile_size_m=2, tile_size_k=3 and the vector width (implicitly denoted
// by `vsl`) be 4.  Then we want to matrix multiply this tile to get a [2,4]
// matrix that we can increment the result matrix by.
//
// First broadcast each element of the two LHS rows into a vector of width 4,
// giving us a rank 3 array, L, of dimension [2,3,4]:
//
//       L[0,_,_]           *      L[1,_,_]
//                          *
//   +----+----+----+----+  *  +----+----+----+----+
//   | a0 | a0 | a0 | a0 |  *  | a1 | a1 | a1 | a1 |
//   +----+----+----+----+  *  +----+----+----+----+
//   | b0 | b0 | b0 | b0 |  *  | b1 | b1 | b1 | b1 |
//   +----+----+----+----+  *  +----+----+----+----+
//   | c0 | c0 | c0 | c0 |  *  | c1 | c1 | c1 | c1 |
//   +----+----+----+----+  *  +----+----+----+----+
//
//
// Then we FMA L[0,_,_] with the RHS to get the first row of the result and
// L[1,_,_] with the RHS to get the second row of the result.  For example, the
// first row of the result is computed as:
//
//   +----+----+----+----+   +----+----+----+----+
//   | a0 | a0 | a0 | a0 | * | p0 | p1 | p2 | p3 |   +
//   +----+----+----+----+   +----+----+----+----+
//
//   +----+----+----+----+   +----+----+----+----+
//   | b0 | b0 | b0 | b0 | * | q0 | q1 | q2 | q3 |   +
//   +----+----+----+----+   +----+----+----+----+
//
//   +----+----+----+----+   +----+----+----+----+
//   | c0 | c0 | c0 | c0 | * | r0 | r1 | r2 | r3 |
//   +----+----+----+----+   +----+----+----+----+
//
// to get:
//
//   +-------------------+-------------------+-------------------+---------
//   | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 |  ...
//   +-------------------+-------------------+-------------------+---------
void TiledSmallGemmEmitter::EmitTiledGemm(
    VectorSupportLibrary* vsl, int64_t tile_size_k, llvm::Value* k_start,
    llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end,
    int64_t tile_size_m, llvm::Value* m_start, llvm::Value* m_end) {
  ksl_.For("dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) {
    MemoryTile result_memory_tile(vsl, b_, /*matrix=*/result_,
                                  /*matrix_size_along_minor_dim=*/dims().n(),
                                  /*major_dim_offset=*/m_i,
                                  /*tile_size_along_major_dim=*/tile_size_m);
    MemoryTile lhs_memory_tile(vsl, b_, /*matrix=*/lhs_,
                               /*matrix_size_along_minor_dim=*/dims().k(),
                               /*major_dim_offset=*/m_i,
                               /*tile_size_along_major_dim=*/tile_size_m);
    ksl_.For(
        "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) {
          TileVariable result_tile_var(vsl, result_memory_tile.LoadTile(n_i));
          ksl_.For("dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) {
            MemoryTile rhs_memory_tile(vsl, b_, rhs_, dims().n(), k_i,
                                       tile_size_k);
            std::vector<std::vector<llvm::Value*>> lhs_tile =
                lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k);
            std::vector<llvm::Value*> rhs_tile = rhs_memory_tile.LoadTile(n_i);
            std::vector<llvm::Value*> result_tile = result_tile_var.Get();
            for (int64_t r_m_i = 0; r_m_i < tile_size_m; r_m_i++) {
              for (int64_t r_k_i = 0; r_k_i < tile_size_k; r_k_i++) {
                result_tile[r_m_i] =
                    vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i],
                                result_tile[r_m_i]);
              }
            }
            result_tile_var.Set(result_tile);
          });

          result_memory_tile.StoreTile(result_tile_var.Get(), n_i);
        });
  });
}

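// Strips array levels from a (typed) pointer type, so that e.g.
// [4 x [8 x float]]* canonicalizes to float*; opaque pointers are returned
// unchanged.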
llvm::Type* GetPointerToElementType(llvm::Type* pointer_type) {
  if (pointer_type->isOpaquePointerTy()) return pointer_type;

  llvm::Type* type = pointer_type->getNonOpaquePointerElementType();
  while (auto* array_type = llvm::dyn_cast<llvm::ArrayType>(type)) {
    type = array_type->getElementType();
  }

  return type->getPointerTo();
}

struct GemvBuffersWithCanonicalType {
  llvm::Value* lhs_canonicalized;
  llvm::Value* rhs_canonicalized;
  llvm::Value* addend_canonicalized;
  llvm::Value* result_canonicalized;
};

GemvBuffersWithCanonicalType GetGemvBuffersWithCanonicalType(
    llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend,
    llvm::Value* result, llvm::IRBuilder<>* b) {
  // We characterize a GEMV operation via M and K, since N is implicitly 1.
  // This means the GEMV that multiplies (say) [5,6] with [6,1] is implemented
  // by the same GEMV that multiplies [5,6] with [1,6].  However, the
  // `llvm::Types` for the inputs to the two GEMVs don't match (in a trivial
  // sense -- the in memory representations are the same) since they're computed
  // from the `xla::Shape`s.  Since we want to be able to call the same
  // `llvm::Function` for the two GEMVs we canonicalize the types of the GEMV
  // inputs here into the same type.
  GemvBuffersWithCanonicalType buffers_with_canonical_type;
  llvm::Type* lhs_type = lhs->getType();
  llvm::Type* rhs_type = rhs->getType();
  llvm::Type* addend_type = addend ? addend->getType() : nullptr;
  llvm::Type* result_type = result->getType();

  buffers_with_canonical_type.lhs_canonicalized =
      b->CreateBitCast(lhs, GetPointerToElementType(lhs_type));
  buffers_with_canonical_type.rhs_canonicalized =
      b->CreateBitCast(rhs, GetPointerToElementType(rhs_type));
  buffers_with_canonical_type.addend_canonicalized =
      addend ? b->CreateBitCast(addend, GetPointerToElementType(addend_type))
             : nullptr;
  buffers_with_canonical_type.result_canonicalized =
      b->CreateBitCast(result, GetPointerToElementType(result_type));

  return buffers_with_canonical_type;
}

}  // namespace

void EmitRowMajorGemv(PrimitiveType scalar_type, int64_t tile_rows,
                      int64_t tile_cols, int64_t m, int64_t k, llvm::Value* lhs,
                      llvm::Value* rhs, llvm::Value* addend,
                      llvm::Value* result, llvm::IRBuilder<>* b,
                      const HloModuleConfig& module_config) {
  RowMajorMatrixVectorProductEmitter::Config config(
      /*scalar_type=*/scalar_type,
      /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols,
      /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr);

  GemvBuffersWithCanonicalType canonical_inputs =
      GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b);

  KernelSupportLibrary::EmitAndCallOutlinedKernel(
      module_config, b, config.GetCacheKey(),
      canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized,
      canonical_inputs.addend_canonicalized,
      canonical_inputs.result_canonicalized,
      [&config, b, &canonical_inputs](llvm::Value* lhs, llvm::Value* rhs,
                                      llvm::Value* addend,
                                      llvm::Value* result) {
        RowMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend,
                                                   result, b);
        emitter.Emit();
      });
}

void EmitColumnMajorGemv(PrimitiveType scalar_type, int64_t tile_rows,
                         int64_t tile_cols, int64_t m, int64_t k,
                         llvm::Value* lhs, llvm::Value* rhs,
                         llvm::Value* addend, llvm::Value* result,
                         llvm::IRBuilder<>* b,
                         const HloModuleConfig& module_config) {
  ColumnMajorMatrixVectorProductEmitter::Config config(
      /*scalar_type=*/scalar_type,
      /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols,
      /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr);

  GemvBuffersWithCanonicalType canonical_inputs =
      GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b);

  KernelSupportLibrary::EmitAndCallOutlinedKernel(
      module_config, b, config.GetCacheKey(),
      canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized,
      canonical_inputs.addend_canonicalized,
      canonical_inputs.result_canonicalized,
      [&config, b, &canonical_inputs](llvm::Value* lhs, llvm::Value* rhs,
                                      llvm::Value* addend,
                                      llvm::Value* result) {
        ColumnMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend,
                                                      result, b);
        emitter.Emit();
      });
}

void EmitSmallGemm(PrimitiveType scalar_type, int64_t m, int64_t k, int64_t n,
                   int64_t max_vectorization_width, int64_t max_vector_count,
                   int64_t min_vectorization_width, int64_t tile_size_m,
                   int64_t tile_size_k, llvm::Value* lhs, llvm::Value* rhs,
                   llvm::Value* result, llvm::IRBuilder<>* b,
                   const HloModuleConfig& module_config) {
  TiledSmallGemmEmitter::Config config(
      /*scalar_type=*/scalar_type,
      TiledSmallGemmEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n},
      /*max_vectorization_width=*/max_vectorization_width,
      /*max_vector_count=*/max_vector_count,
      /*min_vectorization_width=*/min_vectorization_width,
      /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k);

  KernelSupportLibrary::EmitAndCallOutlinedKernel(
      module_config, b, config.GetCacheKey(), lhs, rhs, result,
      [&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result) {
        TiledSmallGemmEmitter small_gemm_emitter(config, /*lhs=*/lhs,
                                                 /*rhs=*/rhs,
                                                 /*result=*/result, b);
        small_gemm_emitter.Emit();
      });
}

}  // namespace cpu
}  // namespace xla