1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
17 
18 #include <memory>
19 #include <vector>
20 
21 #include "absl/strings/str_cat.h"
22 #include "llvm/IR/BasicBlock.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/Module.h"
25 #include "llvm/IR/Value.h"
26 #include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"  // from @llvm-project
27 #include "mlir/Dialect/StandardOps/Utils/Utils.h"  // from @llvm-project
28 #include "mlir/IR/Builders.h"  // from @llvm-project
29 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
30 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
31 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
32 #include "mlir/IR/Value.h"  // from @llvm-project
33 #include "tensorflow/compiler/xla/primitive_util.h"
34 #include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
35 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
36 #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
37 #include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h"
38 #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
39 #include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h"
40 #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
41 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
42 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
43 #include "tensorflow/compiler/xla/service/hlo_module.h"
44 #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
45 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
46 #include "tensorflow/compiler/xla/shape_util.h"
47 #include "tensorflow/compiler/xla/status_macros.h"
48 #include "tensorflow/compiler/xla/util.h"
49 #include "tensorflow/compiler/xla/xla_data.pb.h"
50 #include "tensorflow/core/platform/logging.h"
51 
52 namespace xla {
53 
54 using llvm_ir::SetToFirstInsertPoint;
55 
56 namespace cpu {
57 namespace {
58 // Returns true if we should call into multi-threaded Eigen routines.
59 bool ShouldUseMultiThreadedEigen(const HloModuleConfig& config) {
60   return config.debug_options().xla_cpu_multi_thread_eigen();
61 }
62 
63 // Represents a dot operation.  We use this in lieu of an `HloInstruction`
64 // because we want to be able to create this for the "inner" dot operation in a
65 // batch dot, for which there is no separate HLO instruction.
66 struct DotInfo {
67   Shape lhs_shape;
68   Shape rhs_shape;
69   Shape result_shape;
70   DotDimensionNumbers dim_nums;
71 
72   DotInfo() = default;
73 
74   explicit DotInfo(const HloInstruction& instr) {
75     CHECK_EQ(instr.opcode(), HloOpcode::kDot);
76     lhs_shape = instr.operand(0)->shape();
77     rhs_shape = instr.operand(1)->shape();
78     result_shape = instr.shape();
79     dim_nums = instr.dot_dimension_numbers();
80   }
81 };
82 
83 // Dictates how a dot operation is implemented.
84 enum class DotImplementationStrategy {
85   // The dot operation is lowered into LLVM IR that implements a naive nested
86   // loop that computes the result one element at a time.  This is our
87   // "fallback"; we don't really want this to kick in for any non-trival dot
88   // operation.
89   kNaiveLlvmIr,
90 
91   // The dot operation is lowered into LLVM IR that implements a tiled
92   // Matrix*Vector operation.  This strategy also allows fusing in a bias add
93   // into the dot.  The matrix can be row major or column major, both are
94   // supported.
95   kTiledLlvmIrGemv,
96 
97   // The dot operation is lowered into LLVM IR that implements a tiled
98   // Matrix*Matrix operation.  No fusions are supported.  The two inputs
99   // and the output have to be row major.
100   kTiledLlvmIrGemm,
101 
102   // The dot operation is lowered into a linalg.matmul op, which is then lowered to LLVM IR.
103   kLinalgMatmul,
104 
105   // The dot operation is lowered into a call into an Eigen routine.  No fusions
106   // are supported today.  The two inputs and the output have to be row major.
107   // However, we do allow transposing either the LHS or the RHS as part of the
108   // GEMM -- we expose this flexibility as flexibility in the contraction
109   // dimensions, but we can also see this as flexibility in the input layouts.
110   kEigen,
111 };
112 
113 // Returns the implementation strategy for a dot with the configuration
114 // `dot_info`.
115 DotImplementationStrategy GetDotImplementationStrategy(
116     const HloModuleConfig& config, const DotInfo& dot_info,
117     const TargetMachineFeatures& target_machine_features);
118 
119 // Helper class for emitting LLVM IR to perform the dot operation.
120 class DotOpEmitter {
121  public:
122   explicit DotOpEmitter(DotInfo dot_info, string dot_hlo_name,
123                         const llvm_ir::IrArray& target_array,
124                         const llvm_ir::IrArray& lhs_array,
125                         const llvm_ir::IrArray& rhs_array,
126                         const llvm_ir::IrArray* addend_array,
127                         llvm::Value* executable_run_options_value,
128                         llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context,
129                         const HloModuleConfig& hlo_module_config,
130                         const TargetMachineFeatures& target_machine_features);
131 
132   // Emits the IR to perform the dot operation.
133   Status Emit();
134 
135  private:
136   // Emits instructions to perform a scalar dot product (a multiply of the
137   // LHS and RHS) and store the results in the target.
138   Status EmitScalarDot();
139 
140   // Emits a call to the CPU runtime to perform the matrix multiply.
141   Status EmitCallToRuntime();
142 
143   // Represents the dimensions of a matrix-matrix multiply operation.
144   struct MatMultDims {
145     // The number of rows in the LHS.
146     int64 m;
147 
148     // The number of columns in the LHS, which must also be equal to the
149     // number of rows in the RHS.
150     int64 k;
151 
152     // The number of columns on the RHS.
153     int64 n;
154 
155     // True if the LHS matrix is column major.
156     bool lhs_column_major;
157 
158     // True if the LHS contraction dimension is 1.
159     bool lhs_canonical;
160 
161     // True if the RHS matrix is column major.
162     bool rhs_column_major;
163 
164     // True if the RHS contraction dimension is 0.
165     bool rhs_canonical;
166   };
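  // Illustrative example (added, not from the original source): for a
  // canonical row-major product lhs[16 x 32] * rhs[32 x 8] = result[16 x 8],
  // MatMultDims would hold m = 16, k = 32, n = 8, with lhs_canonical and
  // rhs_canonical true and lhs_column_major / rhs_column_major false.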
167 
168   // Get the MatMultDims instance for the dot product this DotOpEmitter
169   // represents.  Precondition: the dot is of rank 2 (and thus its operands are
170   // of rank 2 as well).
171   MatMultDims GetMatMultDims() const;
172 
173   // Lowers the dot operation as a tiled Matrix*Vector loop.
174   void EmitTiledLlvmIrGemv();
175 
176   // Lowers the dot operation as a tiled Matrix*Matrix loop.
177   void EmitTiledLlvmIrGemm();
178 
179   // Lowers the dot operation through MLIR's linalg.matmul.
180   Status EmitLinalgMatmul();
181 
182   // Lowers the dot operation as a naive nested loop that computes the result
183   // one element at a time.
184   void EmitNaiveLlvmIrGemm();
185 
186   // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector
187   // registers.
188   int64 GetGemvTilingFactor() const {
189     const int64_t kDefaultTilingFactor = 8;
190     return options::LlvmIrGemvTilingFactor(hlo_module_config_)
191         .value_or(kDefaultTilingFactor);
192   }
193 
194   std::tuple<int64, int64, int64> GetGemmTileSize() const {
195     // Tuned for broadwell - Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz
196     //
197     // TODO(b/80093688): Tune for other architectures and centralize this
198     // information in one place.
199     const std::tuple<int64, int64, int64> kDefaultTileSize =
200         std::tuple<int64, int64, int64>(11, 9, 1);
201     return options::LlvmIrGemmTileSize(hlo_module_config_)
202         .value_or(kDefaultTileSize);
203   }
204 
205   std::array<int64_t, 3> GetMlirGemmTileSize() const {
206     // Tile by 4 x registers x register size. This was picked by running
207     // small matmuls on Haswell and Skylake. There's a lot of room for
208     // improvement here.
209     constexpr int64_t kDefaultTileSizeForM = 4;
210     int64_t elements_per_register =
211         target_machine_features_.vector_register_num_elements(
212             *b_->GetInsertBlock()->getParent(),
213             dot_info_.result_shape.element_type());
214     int64_t num_registers = target_machine_features_.vector_register_count(
215         *b_->GetInsertBlock()->getParent());
216     return {{kDefaultTileSizeForM, num_registers, elements_per_register}};
217   }
218 
219   DotInfo dot_info_;
220   string dot_hlo_name_;
221   const llvm_ir::IrArray& target_array_;
222   const llvm_ir::IrArray& lhs_array_;
223   const llvm_ir::IrArray& rhs_array_;
224   const llvm_ir::IrArray* addend_array_;
225   llvm::Value* executable_run_options_value_;
226   llvm::IRBuilder<>* b_;
227   mlir::MLIRContext* mlir_context_;
228   const HloModuleConfig& hlo_module_config_;
229   const TargetMachineFeatures& target_machine_features_;
230 };
231 }  // namespace
232 
233 DotOpEmitter::DotOpEmitter(
234     DotInfo dot_info, string dot_hlo_name, const llvm_ir::IrArray& target_array,
235     const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
236     const llvm_ir::IrArray* addend_array,
237     llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
238     mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
239     const TargetMachineFeatures& target_machine_features)
240     : dot_info_(std::move(dot_info)),
241       dot_hlo_name_(std::move(dot_hlo_name)),
242       target_array_(target_array),
243       lhs_array_(lhs_array),
244       rhs_array_(rhs_array),
245       addend_array_(addend_array),
246       executable_run_options_value_(executable_run_options_value),
247       b_(b),
248       mlir_context_(mlir_context),
249       hlo_module_config_(hlo_module_config),
250       target_machine_features_(target_machine_features) {}
251 
252 Status DotOpEmitter::EmitLinalgMatmul() {
253   Shape operand_shapes[] = {dot_info_.lhs_shape, dot_info_.rhs_shape};
254   llvm::Value* operand_ptrs[] = {lhs_array_.GetBasePointer(),
255                                  rhs_array_.GetBasePointer()};
256   llvm::Value* target_ptr = target_array_.GetBasePointer();
257 
258   // Zero out the output buffer.
259   int64_t size_bytes = ShapeUtil::ByteSizeOf(dot_info_.result_shape);
260   b_->CreateMemSet(target_ptr, b_->getInt8(0), /*Size=*/size_bytes,
261                    /*Align=*/llvm::MaybeAlign(1));
262 
263   std::string name =
264       absl::StrCat("linalgMatMul_", dot_info_.result_shape.ToString(true), "_",
265                    dot_info_.lhs_shape.ToString(true), "_",
266                    dot_info_.rhs_shape.ToString(true));
267 
268   return EmitMlirFuncAndCall(
269       mlir_context_, b_, dot_info_.result_shape, operand_shapes, target_ptr,
270       operand_ptrs, name, [&](mlir::OpBuilder* builder, mlir::FuncOp function) {
271         CHECK_EQ(dot_info_.dim_nums.lhs_contracting_dimensions_size(), 1);
272         CHECK_EQ(dot_info_.dim_nums.rhs_contracting_dimensions_size(), 1);
273         mlir::MLIRContext* context = builder->getContext();
274         mlir::Value a = function.getArgument(0), b = function.getArgument(1),
275                     c = function.getArgument(2);
276 
277         llvm::SmallVector<mlir::AffineExpr, 2> b_exprs(
278             dot_info_.lhs_shape.rank());
279         llvm::SmallVector<mlir::AffineExpr, 2> c_exprs(
280             dot_info_.rhs_shape.rank());
281 
282         llvm::SmallVector<mlir::AffineExpr, 2> parallel_exprs;
283         mlir::AffineExpr reduce_expr;
284         for (int i = 0; i != dot_info_.result_shape.rank(); ++i) {
285           parallel_exprs.push_back(mlir::getAffineDimExpr(i, context));
286         }
287         reduce_expr =
288             mlir::getAffineDimExpr(dot_info_.result_shape.rank(), context);
289 
290         // The reduction expr is shared for both inputs.
291         b_exprs[dot_info_.dim_nums.lhs_contracting_dimensions(0)] = reduce_expr;
292         c_exprs[dot_info_.dim_nums.rhs_contracting_dimensions(0)] = reduce_expr;
293 
294         // Fill in the remaining parallel exprs.
295         int par_expr_num = 0;
296         for (auto* v : {&b_exprs, &c_exprs}) {
297           for (auto& e : *v) {
298             if (!e) {
299               e = parallel_exprs[par_expr_num++];
300             }
301           }
302         }
303 
304         llvm::SmallVector<llvm::StringRef, 4> iteratorTypes(
305             parallel_exprs.size(), toString(mlir::IteratorType::Parallel));
306         iteratorTypes.push_back(toString(mlir::IteratorType::Reduction));
307         builder->create<mlir::linalg::GenericOp>(
308             function.getLoc(),
309             /*inputs=*/mlir::ValueRange{b, c},
310             /*outputs=*/mlir::ValueRange{a},
311             /*indexingMaps=*/
312             mlir::AffineMap::inferFromExprList(
313                 {b_exprs, c_exprs, parallel_exprs}),
314             /*iteratorTypes=*/iteratorTypes,
315             [](mlir::OpBuilder& b, mlir::Location loc, mlir::ValueRange args) {
316               mlir::ArithBuilder ab(b, loc);
317               mlir::Value mul = ab.mul(args[0], args[1]);
318               mlir::Value add = ab.add(mul, args[2]);
319               b.create<mlir::linalg::YieldOp>(loc, add);
320             });
321         builder->create<mlir::ReturnOp>(function.getLoc());
322 
323         mlir::linalg::LinalgTilingOptions tilingOptions;
324         tilingOptions = tilingOptions.setTileSizes(GetMlirGemmTileSize());
325         int64_t alignment =
326             target_machine_features_.minimum_alignment_for_allocation(
327                 ShapeUtil::ByteSizeOf(dot_info_.result_shape));
328         mlir::linalg::CodegenStrategy strategy;
329         strategy.tile<mlir::linalg::GenericOp>(tilingOptions)
330             .promote<mlir::linalg::GenericOp>(
331                 mlir::linalg::LinalgPromotionOptions()
332                     .setAlignment(alignment)
333                     .setUseFullTileBuffersByDefault(true)
334                     .setUseAlloca(true))
335             .vectorize<mlir::linalg::GenericOp>()
336             .setVectorTransformsOptions(
337                 mlir::vector::VectorTransformsOptions()
338                     .setVectorTransformsOptions(
339                         mlir::vector::VectorContractLowering::OuterProduct))
340             .setVectorTransferToSCFOptions(
341                 mlir::VectorTransferToSCFOptions().setUnroll(true));
342         strategy.transform(function);
343       });
344 }
345 
346 void DotOpEmitter::EmitTiledLlvmIrGemm() {
347   PrimitiveType primitive_type = dot_info_.result_shape.element_type();
348   MatMultDims mat_mult_dims = GetMatMultDims();
349 
350   llvm::Value* lhs = lhs_array_.GetBasePointer();
351   llvm::Value* rhs = rhs_array_.GetBasePointer();
352   llvm::Value* target = target_array_.GetBasePointer();
353   int64_t m = mat_mult_dims.m;
354   int64_t k = mat_mult_dims.k;
355   int64_t n = mat_mult_dims.n;
356 
357   if (mat_mult_dims.lhs_column_major) {
358     std::swap(lhs, rhs);
359     std::swap(m, n);
360   }
361 
362   int64_t size_bytes =
363       m * n * ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
364   b_->CreateMemSet(target, b_->getInt8(0), /*Size=*/size_bytes,
365                    /*Align=*/llvm::MaybeAlign(1));
366 
367   int64_t max_target_vector_width =
368       target_machine_features_.vector_register_num_elements(
369           *b_->GetInsertBlock()->getParent(), primitive_type);
370 
371   int64_t tile_size_m, tile_size_k, tile_size_n_in_vector_width;
372   std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) =
373       GetGemmTileSize();
374 
375   EmitSmallGemm(
376       /*scalar_type=*/primitive_type,
377       /*m=*/m, /*k=*/k, /*n=*/n,
378       /*max_vectorization_width=*/max_target_vector_width,
379       /*max_vector_count=*/tile_size_n_in_vector_width,
380       /*min_vectorization_width=*/std::min<int64>(4, max_target_vector_width),
381       /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k, /*lhs=*/lhs,
382       /*rhs=*/rhs, /*result=*/target, b_, hlo_module_config_);
383 }
384 
385 void DotOpEmitter::EmitTiledLlvmIrGemv() {
386   PrimitiveType primitive_type = dot_info_.result_shape.element_type();
387 
388   CHECK(primitive_util::IsFloatingPointType(primitive_type) ||
389         primitive_util::IsIntegralType(primitive_type));
390 
391   MatMultDims mat_mult_dims = GetMatMultDims();
392   bool is_column_major_matrix_vector_gemv = false;
393   bool is_row_major_matrix_vector_gemv = false;
394 
395   int64_t m, k;
396   bool swap_operands;
397 
398   if (mat_mult_dims.m == 1) {
399     // Our emitters can only do Matrix*Vector (abbreviated as M*V) but when M=1
400     // we actually want V*M.  We implement V*M as follows (Tr(X) = Transpose of
401     // X):
402     //
403     //   V*M = Tr(Tr(V*M))  // Tr(Tr(X)) == X
404     //       = Tr(Tr(M) * Tr(V))  // Tr(A * B) == Tr(B) * Tr(A)
405     //
406     // Since transposing a vector is physically a no-op, this is really
407     // equivalent to `Tr(M) * V`.  We further implement Tr(M) by pretending that
408     // M is row major if it is actually column major and vice-versa.
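    // Concrete illustration (added comment, not in the original): with
    // V = [1 x k] and M = [k x n], V*M is a [1 x n] row vector.  Reading M's
    // buffer with the opposite layout yields Tr(M) = [n x k] for free, and the
    // emitter computes the ordinary Matrix*Vector product Tr(M) * V, whose
    // [n x 1] result occupies exactly the memory expected for the [1 x n]
    // output.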
409 
410     bool rhs_effectively_column_major = mat_mult_dims.rhs_canonical
411                                             ? mat_mult_dims.rhs_column_major
412                                             : !mat_mult_dims.rhs_column_major;
413 
414     if (rhs_effectively_column_major) {
415       k = mat_mult_dims.k;
416       m = mat_mult_dims.n;
417 
418       // We set is_row_major_matrix_vector_gemv and not
419       // is_column_major_matrix_vector_gemv to implement the Transpose trick
420       // mentioned above.
421       is_row_major_matrix_vector_gemv = true;
422       swap_operands = true;
423     } else {
424       k = mat_mult_dims.k;
425       m = mat_mult_dims.n;
426 
427       // We set is_column_major_matrix_vector_gemv and not
428       // is_row_major_matrix_vector_gemv to implement the Transpose trick
429       // mentioned above.
430       is_column_major_matrix_vector_gemv = true;
431       swap_operands = true;
432     }
433   }
434 
435   if (mat_mult_dims.n == 1) {
436     bool lhs_effectively_column_major = mat_mult_dims.lhs_canonical
437                                             ? mat_mult_dims.lhs_column_major
438                                             : !mat_mult_dims.lhs_column_major;
439 
440     if (lhs_effectively_column_major) {
441       m = mat_mult_dims.m;
442       k = mat_mult_dims.k;
443       is_column_major_matrix_vector_gemv = true;
444       swap_operands = false;
445     } else {
446       m = mat_mult_dims.m;
447       k = mat_mult_dims.k;
448       is_row_major_matrix_vector_gemv = true;
449       swap_operands = false;
450     }
451   }
452 
453   CHECK(is_column_major_matrix_vector_gemv || is_row_major_matrix_vector_gemv);
454 
455   int64_t tiling_factor = GetGemvTilingFactor();
456   CHECK_GT(tiling_factor, 0);
457 
458   llvm::Value* result_op = target_array_.GetBasePointer();
459   llvm::Value* lhs_op =
460       swap_operands ? rhs_array_.GetBasePointer() : lhs_array_.GetBasePointer();
461   llvm::Value* rhs_op =
462       swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer();
463 
464   const int target_vector_register_element_size =
465       target_machine_features_.vector_register_num_elements(
466           *b_->GetInsertBlock()->getParent(), primitive_type);
467 
468   // We may not always know the vector register size for the target we're
469   // compiling against, in which case target_vector_register_element_size is 0.
470   // In these cases we choose a default LLVM IR register size.
471   const int kUnknownTargetVectorRegisterSize = 4;
472   const int vector_register_element_size =
473       target_vector_register_element_size == 0
474           ? kUnknownTargetVectorRegisterSize
475           : target_vector_register_element_size;
476 
477   if (is_column_major_matrix_vector_gemv) {
478     VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m
479             << " and k = " << k;
480     EmitColumnMajorGemv(
481         /*scalar_type=*/primitive_type,
482         /*tile_rows=*/vector_register_element_size, /*tile_cols=*/tiling_factor,
483         /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op,
484         /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr,
485         /*result=*/result_op, b_, hlo_module_config_);
486   } else {
487     VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m
488             << " and k = " << k;
489     EmitRowMajorGemv(
490         /*scalar_type=*/primitive_type,
491         /*tile_rows=*/tiling_factor,
492         /*tile_cols=*/vector_register_element_size,
493         /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op,
494         /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr,
495         /*result=*/result_op, b_, hlo_module_config_);
496   }
497 }
498 
499 Status DotOpEmitter::Emit() {
500   // The dot operation performs a sum of products over dimension 0 of the left
501   // hand side operand and dimension 1 of the right hand side operand.
502   //
503   // Let the shapes of lhs and rhs be defined as below:
504   //
505   //   lhs = [L{n-1} x L{n-2} x ... L{0}]
506   //   rhs = [R{m-1} x R{m-2} x ... R{0}]
507   //
508   // The sum-of-products dimension in the lhs has size L{0} and the dimension in
509   // the rhs has size R{1}. Necessarily, then:
510   //
511   //   L{0} == R{1}
512   //
513   // The output of the operation has the following shape:
514   //
515   //   output = [L{n-1} x L{n-2} x ... L{1} x R{m-1} x R{m-2} x ... R{2} x R{0}]
516   //
517   // To perform the operation we construct a loop nest with one for-loop for
518   // each dimension of the output. Inside this loop nest is another for-loop
519   // which performs the sum-of-products (the reduction loop) before storing
520   // the result in the output buffer.
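  // Small worked instance (added for illustration): for lhs = [2 x 3] and
  // rhs = [3 x 4] we have L{1} = 2, L{0} = 3, R{1} = 3, R{0} = 4; the
  // requirement L{0} == R{1} holds and the output shape is
  // [L{1} x R{0}] = [2 x 4].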
521 
522   const Shape& lhs_shape = lhs_array_.GetShape();
523   const Shape& rhs_shape = rhs_array_.GetShape();
524 
525   if (ShapeUtil::IsScalar(lhs_shape) || ShapeUtil::IsScalar(rhs_shape)) {
526     // If the operands are scalar, don't emit any loops.
527     TF_RET_CHECK(ShapeUtil::IsScalar(lhs_shape) &&
528                  ShapeUtil::IsScalar(rhs_shape));
529     return EmitScalarDot();
530   }
531 
532   switch (GetDotImplementationStrategy(hlo_module_config_, dot_info_,
533                                        target_machine_features_)) {
534     case DotImplementationStrategy::kNaiveLlvmIr:
535       EmitNaiveLlvmIrGemm();
536       return Status::OK();
537 
538     case DotImplementationStrategy::kTiledLlvmIrGemv:
539       EmitTiledLlvmIrGemv();
540       return Status::OK();
541 
542     case DotImplementationStrategy::kTiledLlvmIrGemm:
543       EmitTiledLlvmIrGemm();
544       return Status::OK();
545 
546     case DotImplementationStrategy::kLinalgMatmul:
547       return EmitLinalgMatmul();
548 
549     case DotImplementationStrategy::kEigen:
550       return EmitCallToRuntime();
551   }
552 }
553 
554 void DotOpEmitter::EmitNaiveLlvmIrGemm() {
555   CHECK_EQ(addend_array_, nullptr);
556 
557   const Shape& lhs_shape = lhs_array_.GetShape();
558   const Shape& rhs_shape = rhs_array_.GetShape();
559   const DotDimensionNumbers& dim_nums = dot_info_.dim_nums;
560 
561   // Reduce along dimension 0 of the LHS and 1 of the RHS. Vectors are a special
562   // case where the reduction dimension is 0 for both LHS and RHS. This results
563   // in a vector dot product producing a scalar.
564   int64_t lhs_reduction_dimension = dim_nums.lhs_contracting_dimensions(0);
565   int64_t rhs_reduction_dimension = dim_nums.rhs_contracting_dimensions(0);
566 
567   // Verify that the reduction dimensions in the two operands are the same size.
568   CHECK_EQ(lhs_shape.dimensions(lhs_reduction_dimension),
569            rhs_shape.dimensions(rhs_reduction_dimension));
570 
571   bool lhs_reduction_along_minor_dimension =
572       lhs_reduction_dimension == LayoutUtil::Minor(lhs_shape.layout(), 0);
573   bool rhs_reduction_along_minor_dimension =
574       rhs_reduction_dimension == LayoutUtil::Minor(rhs_shape.layout(), 0);
575 
576   // Create loop nests which loop through the LHS operand dimensions and the RHS
577   // operand dimensions. The reduction dimensions of the LHS and RHS are handled
578   // in a separate innermost loop which performs the sum of products.
579   llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(dot_hlo_name_), b_);
580   std::vector<llvm::Value*> lhs_multi_index =
581       loop_nest.EmitOperandArrayLoopNest(
582           lhs_array_, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs");
583   std::vector<llvm::Value*> rhs_multi_index =
584       loop_nest.EmitOperandArrayLoopNest(
585           rhs_array_, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs");
586 
587   // Create the loop which does the sum of products reduction.
588   //
589   // The prevent_unrolling bit is working around a deficiency in LLVM's loop
590   // vectorization pipeline, wherein in some cases unrolling a loop can prevent
591   // effective vectorization.  Since we know that the IR we generate when
592   // reducing across the minor dimension in both LHS and RHS is vectorized well
593   // by the loop vectorizer, we block unrolling in that case to stop loop unroll
594   // from messing up the vectorization.
595   std::unique_ptr<llvm_ir::ForLoop> reduction_loop = loop_nest.AddLoop(
596       0, lhs_shape.dimensions(lhs_reduction_dimension), "reduction",
597       /*unroll_mode=*/
598       (lhs_reduction_along_minor_dimension &&
599        rhs_reduction_along_minor_dimension)
600           ? xla::llvm_ir::UnrollMode::kNoUnroll
601           : xla::llvm_ir::UnrollMode::kDefaultUnroll);
602 
603   // The final entry in the rhs and lhs indexes is the indvar of the
604   // reduction loop.
605   lhs_multi_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue();
606   llvm_ir::IrArray::Index lhs_index(lhs_multi_index, lhs_shape,
607                                     b_->getInt64Ty());
608   rhs_multi_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue();
609   llvm_ir::IrArray::Index rhs_index(rhs_multi_index, rhs_shape,
610                                     b_->getInt64Ty());
611 
612   // For computing the sum of products we alloca a single location to store the
613   // dot product result as we accumulate it within the reduction loop. After the
614   // reduction loop we load the result and store it into the output array.
615 
616   // Function entry basic block.
617   // - Emit alloca for accumulator
618   llvm::Function* func = reduction_loop->GetPreheaderBasicBlock()->getParent();
619   SetToFirstInsertPoint(&func->getEntryBlock(), b_);
620   llvm::Type* accum_type = target_array_.GetElementLlvmType();
621   llvm::Value* accum_address =
622       b_->CreateAlloca(accum_type, /*ArraySize=*/nullptr, "accum_address");
623 
624   // Preheader basic block of reduction loop:
625   // - Initialize accumulator to zero.
626   llvm::BasicBlock* preheader_bb = reduction_loop->GetPreheaderBasicBlock();
627   b_->SetInsertPoint(preheader_bb->getTerminator());
628 
629   b_->CreateStore(llvm::Constant::getNullValue(accum_type), accum_address);
630 
631   // Body basic block of reduction loop:
632   // - Load elements from lhs and rhs array.
633   // - Multiply lhs-element and rhs-element.
634   // - Load accumulator and add to product.
635   // - Store sum back into accumulator.
636   SetToFirstInsertPoint(reduction_loop->GetBodyBasicBlock(), b_);
637 
638   llvm::Value* lhs_element = lhs_array_.EmitReadArrayElement(lhs_index, b_);
639   llvm::Value* rhs_element = rhs_array_.EmitReadArrayElement(rhs_index, b_);
640 
641   llvm::Value* accum = b_->CreateLoad(accum_address);
642   llvm::Value* updated_accum;
643   if (ShapeUtil::ElementIsComplex(lhs_shape)) {
644     auto real = [&](llvm::Value* x) { return b_->CreateExtractValue(x, {0}); };
645     auto imag = [&](llvm::Value* x) { return b_->CreateExtractValue(x, {1}); };
646     llvm::Value* product_real =
647         b_->CreateFSub(b_->CreateFMul(real(lhs_element), real(rhs_element)),
648                        b_->CreateFMul(imag(lhs_element), imag(rhs_element)));
649     llvm::Value* product_imag =
650         b_->CreateFAdd(b_->CreateFMul(real(lhs_element), imag(rhs_element)),
651                        b_->CreateFMul(imag(lhs_element), real(rhs_element)));
652     updated_accum = b_->CreateInsertValue(
653         accum, b_->CreateFAdd(real(accum), product_real), {0});
654     updated_accum = b_->CreateInsertValue(
655         updated_accum, b_->CreateFAdd(imag(accum), product_imag), {1});
656   } else if (ShapeUtil::ElementIsIntegral(lhs_shape)) {
657     llvm::Value* product = b_->CreateMul(lhs_element, rhs_element);
658     updated_accum = b_->CreateAdd(accum, product);
659   } else if (lhs_shape.element_type() == PRED) {
660     llvm::Value* product = b_->CreateAnd(lhs_element, rhs_element);
661     updated_accum = b_->CreateOr(accum, product);
662   } else {
663     llvm::Value* product = b_->CreateFMul(lhs_element, rhs_element);
664     updated_accum = b_->CreateFAdd(accum, product);
665   }
666   b_->CreateStore(updated_accum, accum_address);
667 
668   // Exit basic block of reduction loop.
669   // - Load accumulator value (the result).
670   // - Store into output array.
671   SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), b_);
672 
673   llvm::Value* result = b_->CreateLoad(accum_address);
674 
675   // Create index into target address. The target index is the concatenation of
676   // the lhs and rhs indexes with the reduction dimensions removed. The terms
677   // from the lhs index are the major dimensions in the target index, so we add
678   // them first.
679   std::vector<llvm::Value*> target_multi_index;
680   for (int dimension = 0; dimension < lhs_index.size(); ++dimension) {
681     if (dimension != lhs_reduction_dimension) {
682       target_multi_index.push_back(lhs_index[dimension]);
683     }
684   }
685   for (int dimension = 0; dimension < rhs_index.size(); ++dimension) {
686     if (dimension != rhs_reduction_dimension) {
687       target_multi_index.push_back(rhs_index[dimension]);
688     }
689   }
690 
691   llvm_ir::IrArray::Index target_index(
692       target_multi_index, target_array_.GetShape(), lhs_index.GetType());
693   target_array_.EmitWriteArrayElement(target_index, result, b_);
694 
695   // Set the IR builder insert point to the exit basic block of the outer most
696   // loop.
697   b_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
698 }
699 
700 Status DotOpEmitter::EmitScalarDot() {
701   // A scalar dot is just a scalar multiply.
702   llvm::Value* result;
703   // Use the same index_type for all tensor accesses in the same kernel.
704   llvm::Type* index_type = b_->getInt64Ty();
705   llvm_ir::IrArray::Index element_index(index_type);
706   llvm::Value* lhs_value =
707       lhs_array_.EmitReadArrayElement(/*index=*/element_index, b_);
708   llvm::Value* rhs_value =
709       rhs_array_.EmitReadArrayElement(/*index=*/element_index, b_);
710   if (ShapeUtil::ElementIsComplex(lhs_array_.GetShape())) {
711     auto get_real = [&](llvm::Value* x) {
712       return b_->CreateExtractValue(x, {0});
713     };
714 
715     auto get_imag = [&](llvm::Value* x) {
716       return b_->CreateExtractValue(x, {1});
717     };
718 
719     llvm::Value* real = b_->CreateFSub(
720         b_->CreateFMul(get_real(lhs_value), get_real(rhs_value)),
721         b_->CreateFMul(get_imag(lhs_value), get_imag(rhs_value)));
722     llvm::Value* imag = b_->CreateFAdd(
723         b_->CreateFMul(get_real(lhs_value), get_imag(rhs_value)),
724         b_->CreateFMul(get_imag(lhs_value), get_real(rhs_value)));
725     result = llvm::ConstantAggregateZero::get(lhs_array_.GetElementLlvmType());
726     result = b_->CreateInsertValue(result, real, {0});
727     result = b_->CreateInsertValue(result, imag, {1});
728   } else {
729     result = b_->CreateFMul(lhs_value, rhs_value);
730   }
731   target_array_.EmitWriteArrayElement(/*index=*/element_index, result, b_);
732   return Status::OK();
733 }
734 
735 Status DotOpEmitter::EmitCallToRuntime() {
736   // The signature of the Eigen runtime matmul function is:
737   //
738   //   (void)(void* run_options, float* out, float* lhs, float* rhs,
739   //          int64 m, int64 n, int64 k, int32 transpose_lhs,
740   //          int32 transpose_rhs);
741   // The two transpose_... parameters are actually booleans, but we use int32
742   // to avoid target-dependent calling convention details.
743 
744   bool multi_threaded = ShouldUseMultiThreadedEigen(hlo_module_config_);
745   bool use_mkl_dnn = hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn();
746   PrimitiveType type = target_array_.GetShape().element_type();
747   llvm::Function* function = b_->GetInsertBlock()->getParent();
748   llvm::Module* module = function->getParent();
749   llvm::Type* float_type;
750   const char* fn_name;
751   switch (type) {
752     case F16:
753       fn_name = multi_threaded
754                     ? runtime::kEigenMatMulF16SymbolName
755                     : runtime::kEigenSingleThreadedMatMulF16SymbolName;
756       float_type = b_->getHalfTy();
757       break;
758     case F32:
759       fn_name = multi_threaded
760                     ? (use_mkl_dnn ? runtime::kMKLMatMulF32SymbolName
761                                    : runtime::kEigenMatMulF32SymbolName)
762                     : (use_mkl_dnn
763                            ? runtime::kMKLSingleThreadedMatMulF32SymbolName
764                            : runtime::kEigenSingleThreadedMatMulF32SymbolName);
765       float_type = b_->getFloatTy();
766       break;
767     case F64:
768       fn_name = multi_threaded
769                     ? (use_mkl_dnn ? runtime::kMKLMatMulF64SymbolName
770                                    : runtime::kEigenMatMulF64SymbolName)
771                     : (use_mkl_dnn
772                            ? runtime::kMKLSingleThreadedMatMulF64SymbolName
773                            : runtime::kEigenSingleThreadedMatMulF64SymbolName);
774       float_type = b_->getDoubleTy();
775       break;
776     case C64:
777       fn_name = multi_threaded
778                     ? runtime::kEigenMatMulC64SymbolName
779                     : runtime::kEigenSingleThreadedMatMulC64SymbolName;
780       float_type = llvm_ir::PrimitiveTypeToIrType(C64, module);
781       break;
782     case C128:
783       fn_name = multi_threaded
784                     ? runtime::kEigenMatMulC128SymbolName
785                     : runtime::kEigenSingleThreadedMatMulC128SymbolName;
786       float_type = llvm_ir::PrimitiveTypeToIrType(C128, module);
787       break;
788     case S32:
789       fn_name = multi_threaded
790                     ? runtime::kEigenMatMulS32SymbolName
791                     : runtime::kEigenSingleThreadedMatMulS32SymbolName;
792       float_type = b_->getInt32Ty();
793       break;
794     default:
795       return Unimplemented("Invalid type %s for dot operation",
796                            PrimitiveType_Name(type));
797   }
798 
799   llvm::Type* float_ptr_type = float_type->getPointerTo();
800   llvm::Type* int64_type = b_->getInt64Ty();
801   llvm::Type* int32_type = b_->getInt32Ty();
802   llvm::Type* int8_ptr_type = b_->getInt8Ty()->getPointerTo();
803   llvm::FunctionType* matmul_type = llvm::FunctionType::get(
804       b_->getVoidTy(),
805       {int8_ptr_type, float_ptr_type, float_ptr_type, float_ptr_type,
806        int64_type, int64_type, int64_type, int32_type, int32_type},
807       /*isVarArg=*/false);
808 
809   llvm::FunctionCallee matmul_func =
810       module->getOrInsertFunction(fn_name, matmul_type);
811   if (auto* fn = llvm::dyn_cast<llvm::Function>(matmul_func.getCallee())) {
812     fn->setCallingConv(llvm::CallingConv::C);
813     fn->setDoesNotThrow();
814     fn->setOnlyAccessesArgMemory();
815   }
816 
817   // The Eigen runtime function expects column-major layout. If the matrices are
818   // row major, then use the following identity to compute the product:
819   //
820   //   (A x B)^T = B^T x A^T
821   //
822   // The connection between this identity and memory layout is that the
823   // transpose operation can also be considered as an operation that changes the
824   // memory layout of a matrix from row-major to column-major or vice versa.
825   //
826   // Effectively this involves swapping the 'lhs' with 'rhs' and 'm' with 'n'.
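  // Sketch of the swap (added comment, illustration only): for row-major
  // lhs[m x k] and rhs[k x n], the bytes of the row-major product lhs*rhs are
  // exactly the bytes of the column-major product Tr(rhs) * Tr(lhs), and a
  // row-major buffer reinterpreted as column-major is already the transpose.
  // So when the operands are row major we pass rhs as the first matrix, lhs as
  // the second, exchange m and n, and leave k unchanged.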
827 
828   MatMultDims mat_mult_dims = GetMatMultDims();
829 
830   CHECK_EQ(mat_mult_dims.lhs_column_major, mat_mult_dims.rhs_column_major);
831 
832   const llvm_ir::IrArray* lhs = &lhs_array_;
833   const llvm_ir::IrArray* rhs = &rhs_array_;
834   bool transpose_lhs = !mat_mult_dims.lhs_canonical;
835   bool transpose_rhs = !mat_mult_dims.rhs_canonical;
836 
837   if (!mat_mult_dims.lhs_column_major) {
838     std::swap(mat_mult_dims.m, mat_mult_dims.n);
839     std::swap(lhs, rhs);
840     std::swap(transpose_lhs, transpose_rhs);
841   }
842 
843   b_->CreateCall(
844       matmul_func,
845       {b_->CreateBitCast(executable_run_options_value_, int8_ptr_type),
846        b_->CreateBitCast(target_array_.GetBasePointer(), float_ptr_type),
847        b_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type),
848        b_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type),
849        b_->getInt64(mat_mult_dims.m), b_->getInt64(mat_mult_dims.n),
850        b_->getInt64(mat_mult_dims.k), b_->getInt32(transpose_lhs),
851        b_->getInt32(transpose_rhs)});
852   return Status::OK();
853 }
854 
855 DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
856   CHECK_LE(dot_info_.result_shape.dimensions_size(), 2);
857 
858   const Shape& lhs_shape = lhs_array_.GetShape();
859   const Shape& rhs_shape = rhs_array_.GetShape();
860   const DotDimensionNumbers& dim_nums = dot_info_.dim_nums;
861 
862   auto is_column_major = [](const Shape& shape) {
863     return shape.rank() > 1 && LayoutUtil::Minor(shape.layout(), 0) == 0;
864   };
865 
866   // Non-contracting dots should never make it here.
867   CHECK_GE(dim_nums.lhs_contracting_dimensions_size(), 0);
868   CHECK_GE(dim_nums.rhs_contracting_dimensions_size(), 0);
869 
870   return {
871       /*m=*/lhs_shape.rank() <= 1
872           ? 1LL
873           : lhs_shape.dimensions(1LL - dim_nums.lhs_contracting_dimensions(0)),
874       /*k=*/lhs_shape.dimensions(dim_nums.lhs_contracting_dimensions(0)),
875       /*n=*/rhs_shape.rank() <= 1
876           ? 1LL
877           : rhs_shape.dimensions(1LL - dim_nums.rhs_contracting_dimensions(0)),
878       /*lhs_column_major=*/is_column_major(lhs_shape),
879       /*lhs_canonical=*/lhs_shape.rank() <= 1 ||
880           dim_nums.lhs_contracting_dimensions(0) == 1,
881       /*rhs_column_major=*/is_column_major(rhs_shape),
882       /*rhs_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 0};
883 }
884 
885 // For vector-matrix dot products, it is always profitable to make the Rhs
886 // column major.
887 absl::optional<int64> ProfitableToMakeDotOperandColumnMajor(
888     const HloInstruction& hlo) {
889   if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() <= 1) {
890     if (hlo.operand(0)->shape().rank() != 1 ||
891         hlo.dot_dimension_numbers().rhs_contracting_dimensions(0) != 0) {
892       return {};
893     }
894 
895     // Don't bother if the other operand is tiny, switching to column major
896     // wouldn't use tiling.
897     constexpr int kColumnMajorThresholdInBytes = 32;
898     int64_t lhs_size =
899         ShapeUtil::ByteSizeOfPrimitiveType(hlo.shape().element_type()) *
900         ShapeUtil::ElementsIn(hlo.operand(0)->shape());
901     if (lhs_size < kColumnMajorThresholdInBytes) {
902       return {};
903     }
904 
905     return 1;
906   }
907 
908   if (hlo.IsOutputFusion()) {
909     auto* fusion_root =
910         hlo.fused_instructions_computation()->root_instruction();
911     if (fusion_root->opcode() != HloOpcode::kAdd) {
912       return {};
913     }
914 
915     for (auto* fusion_root_op : fusion_root->operands()) {
916       if (fusion_root_op->opcode() != HloOpcode::kDot) {
917         continue;
918       }
919       if (auto operand_num =
920               ProfitableToMakeDotOperandColumnMajor(*fusion_root_op)) {
921         auto* operand = fusion_root_op->operand(*operand_num);
922         if (operand->opcode() == HloOpcode::kParameter &&
923             operand->user_count() == 1) {
924           return operand->parameter_number();
925         }
926       }
927     }
928   }
929 
930   return {};
931 }
932 
933 namespace {
934 // Return whether the given shape is rank 2.
935 bool IsRank2(const Shape& shape) { return shape.rank() == 2; }
936 
937 bool IsSimpleLayout(const Layout& layout) {
938   return layout.tiles().empty() && layout.format() == DENSE;
939 }
940 
941 // In a gemm operation where output = lhs * rhs, check whether the given shapes
942 // are valid for the operation.
943 bool AreGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
944                    const Shape& output_shape,
945                    const TargetMachineFeatures& target_machine_features) {
946   CHECK(!lhs_shape.has_layout() || IsSimpleLayout(lhs_shape.layout()))
947       << lhs_shape.DebugString();
948   CHECK(!rhs_shape.has_layout() || IsSimpleLayout(rhs_shape.layout()))
949       << rhs_shape.DebugString();
950   CHECK(!output_shape.has_layout() || IsSimpleLayout(output_shape.layout()))
951       << output_shape.DebugString();
952 
953   switch (output_shape.element_type()) {
954     case F16:
955     case F32:
956     case F64:
957     case C64:
958     case C128:
959     case S32:
960       return IsRank2(lhs_shape) && IsRank2(rhs_shape) && IsRank2(output_shape);
961     default:
962       return false;
963   }
964 }
965 
966 bool IsAlignedGemm(const DotInfo& dot_info,
967                    const TargetMachineFeatures& target_machine_features) {
968   if (ShapeUtil::IsZeroElementArray(dot_info.lhs_shape) ||
969       ShapeUtil::IsZeroElementArray(dot_info.rhs_shape)) {
970     return false;
971   }
972 
973   return AreGemmShapes(dot_info.lhs_shape, dot_info.rhs_shape,
974                        dot_info.result_shape, target_machine_features);
975 }
976 
977 bool CanEmitTiledLlvmIrGemm(
978     const HloModuleConfig& config, const DotInfo& dot_info,
979     const TargetMachineFeatures& target_machine_features) {
980   CHECK(IsAlignedGemm(dot_info, target_machine_features));
981 
982   if (ShouldUseMultiThreadedEigen(config)) {
983     return false;
984   }
985 
986   int m = dot_info.result_shape.dimensions(0);
987   int k = dot_info.lhs_shape.dimensions(
988       dot_info.dim_nums.lhs_contracting_dimensions(0));
989   int n = dot_info.result_shape.dimensions(1);
990 
991   if (!options::ForceEnableExperimentalLlvmIrGemm(config)) {
992     // TODO(sanjoy):  We should make these numbers micro-arch specific.
993     bool small_gemm =
994         k <= 128 && ((m <= 32 && n <= 128) || (m <= 128 && n <= 32));
995     if (!small_gemm) {
996       return false;
997     }
998   }
999 
1000   bool lhs_canonical = dot_info.dim_nums.lhs_contracting_dimensions(0) == 1;
1001   bool rhs_canonical = dot_info.dim_nums.rhs_contracting_dimensions(0) == 0;
1002 
1003   if (!(lhs_canonical && rhs_canonical)) {
1004     return false;
1005   }
1006 
1007   if (dot_info.result_shape.element_type() == F16 ||
1008       dot_info.result_shape.element_type() == C64 ||
1009       dot_info.result_shape.element_type() == C128) {
1010     // TODO(sanjoy): This is probably easy to fix, but I want to keep the CL
1011     // adding this comment NFC.
1012     return false;
1013   }
1014 
1015   return true;
1016 }
1017 
1018 DotImplementationStrategy GetDotImplementationStrategy(
1019     const HloModuleConfig& config, const DotInfo& dot_info,
1020     const TargetMachineFeatures& target_machine_features) {
1021   PrimitiveType element_type = dot_info.result_shape.element_type();
1022   // Any Matrix-Vector product of floating point or integral type, or
1023   // a transpose-dot fusion of the same can be lowered to a tiled LLVM
1024   // IR implementation.
1025   if ((dot_info.result_shape.dimensions_size() <= 1 ||
1026        (dot_info.result_shape.dimensions_size() == 2 &&
1027         (dot_info.result_shape.dimensions(0) == 1 ||
1028          dot_info.result_shape.dimensions(1) == 1))) &&
1029       (primitive_util::IsFloatingPointType(element_type) ||
1030        primitive_util::IsIntegralType(element_type))) {
1031     return DotImplementationStrategy::kTiledLlvmIrGemv;
1032   }
1033 
1034   // MatMuls smaller than 3x3 should use a naive nested loop.
1035   if ((dot_info.lhs_shape.dimensions_size() <= 1 ||
1036        (dot_info.lhs_shape.dimensions_size() == 2 &&
1037         (dot_info.lhs_shape.dimensions(0) <= 3 ||
1038          dot_info.lhs_shape.dimensions(1) <= 3))) &&
1039       (dot_info.rhs_shape.dimensions_size() <= 1 ||
1040        (dot_info.rhs_shape.dimensions_size() == 2 &&
1041         (dot_info.rhs_shape.dimensions(0) <= 3 ||
1042          dot_info.rhs_shape.dimensions(1) <= 3))) &&
1043       (primitive_util::IsFloatingPointType(element_type) ||
1044        primitive_util::IsIntegralType(element_type))) {
1045     return DotImplementationStrategy::kNaiveLlvmIr;
1046   }
1047 
1048   if (IsAlignedGemm(dot_info, target_machine_features)) {
1049     if (CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features)) {
1050       return DotImplementationStrategy::kTiledLlvmIrGemm;
1051     }
1052     return DotImplementationStrategy::kEigen;
1053   }
1054 
1055   return DotImplementationStrategy::kNaiveLlvmIr;
1056 }
1057 
1058 Status EmitNonBatchDotOperation(
1059     DotInfo dot_info, string hlo_name, const llvm_ir::IrArray& target_array,
1060     const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
1061     const llvm_ir::IrArray* addend_array,
1062     llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
1063     mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
1064     const TargetMachineFeatures& target_machine_features) {
1065   PrimitiveType type = target_array.GetShape().element_type();
1066   TF_RET_CHECK(PRED == type || S8 == type || U8 == type || S16 == type ||
1067                U16 == type || S32 == type || U32 == type || S64 == type ||
1068                U64 == type || F16 == type || F32 == type || F64 == type ||
1069                C64 == type || C128 == type);
1070   DotOpEmitter dot_emitter(std::move(dot_info), std::move(hlo_name),
1071                            target_array, lhs_array, rhs_array, addend_array,
1072                            executable_run_options_value, b, mlir_context,
1073                            hlo_module_config, target_machine_features);
1074   return dot_emitter.Emit();
1075 }
1076 
1077 Shape DropFirstDim(const Shape& shape) {
1078   absl::Span<int64 const> array_shape_dims(shape.dimensions());
1079   array_shape_dims.remove_prefix(1);
1080   return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
1081                                                   array_shape_dims);
1082 }
1083 
1084 Shape CollapseFirstNDims(const Shape& shape, int64_t n) {
1085   absl::Span<int64 const> input_shape_dims(shape.dimensions());
1086   int64_t prefix_dim =
1087       std::accumulate(input_shape_dims.begin(), input_shape_dims.begin() + n,
1088                       1ll, std::multiplies<int64>());
1089   DimensionVector result_dims;
1090   result_dims.push_back(prefix_dim);
1091   std::copy(input_shape_dims.begin() + n, input_shape_dims.end(),
1092             std::back_inserter(result_dims));
1093   return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
1094                                                   result_dims);
1095 }
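// Illustration (added, assumed shapes): DropFirstDim turns f32[2,3,4] into
// f32[3,4], and CollapseFirstNDims(f32[2,3,4,5], 2) produces f32[6,4,5] by
// folding the leading two dimensions into a single 2*3 = 6 sized dimension.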
1096 
1097 llvm_ir::IrArray CollapseFirstNDims(llvm::IRBuilder<>* b,
1098                                     const llvm_ir::IrArray& array, int64_t n) {
1099   llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
1100   const Shape& shape = array.GetShape();
1101   CHECK(shape.has_layout() &&
1102         LayoutUtil::IsMonotonicWithDim0Major(shape.layout()));
1103   CHECK_GE(shape.dimensions_size(), n);
1104   Shape new_shape = CollapseFirstNDims(shape, n);
1105   llvm::Value* new_value = b->CreateBitCast(
1106       array.GetBasePointer(),
1107       llvm_ir::ShapeToIrType(new_shape, module)->getPointerTo());
1108   return llvm_ir::IrArray(new_value, std::move(new_shape));
1109 }
1110 
1111 Status ValidateDotDimensionNumbers(const DotDimensionNumbers& dim_numbers) {
1112   // Checks some invariants that do not hold in general, but DotDecomposer
1113   // should have established for us.  This is just a debugging aid.
1114   TF_RET_CHECK(dim_numbers.lhs_contracting_dimensions_size() == 1);
1115   std::vector<int64> batch_dim_numbers(dim_numbers.lhs_batch_dimensions_size());
1116   absl::c_iota(batch_dim_numbers, 0);
1117   TF_RET_CHECK(
1118       absl::c_equal(batch_dim_numbers, dim_numbers.lhs_batch_dimensions()));
1119   TF_RET_CHECK(
1120       absl::c_equal(batch_dim_numbers, dim_numbers.rhs_batch_dimensions()));
1121   return Status::OK();
1122 }
1123 
1124 // Slice out the inner array at batch index `batch_index` from `outer_array`.
1125 llvm_ir::IrArray SliceOutInnerArray(llvm_ir::IrArray outer_array,
1126                                     llvm::Value* batch_index,
1127                                     llvm::IRBuilder<>* b) {
1128   llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
1129 
1130   Shape inner_shape = DropFirstDim(outer_array.GetShape());
1131   std::vector<llvm::Value*> multidim_index(inner_shape.rank() + 1,
1132                                            b->getInt64(0));
1133   multidim_index[0] = batch_index;
1134   llvm_ir::IrArray::Index slice_index(multidim_index, outer_array.GetShape(),
1135                                       batch_index->getType());
1136   llvm::Value* slice_ptr = outer_array.EmitArrayElementAddress(slice_index, b);
1137   llvm::Type* slice_ptr_type =
1138       llvm_ir::ShapeToIrType(inner_shape, module)->getPointerTo();
1139   return llvm_ir::IrArray(b->CreateBitCast(slice_ptr, slice_ptr_type),
1140                           std::move(inner_shape));
1141 }
1142 
1143 Status EmitBatchDotOperation(
1144     const HloInstruction& dot, const llvm_ir::IrArray& target_array,
1145     const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
1146     llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
1147     mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
1148     const TargetMachineFeatures& target_machine_features) {
1149   TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(dot.dot_dimension_numbers()));
1150 
1151   // Lower a batch dot into a sequence of non-batch dot operations.
1152 
1153   int64_t num_batch_dims =
1154       dot.dot_dimension_numbers().lhs_batch_dimensions_size();
1155 
1156   // First reshape the inputs to make sure we only have one batch dimension.
1157   // This is a no-op bitcast because the operands have to be in row-major layout
1158   // (enforced in CpuLayoutAssignment), and the batch dimensions are the leading
1159   // dimensions (established by DotDecomposer and checked by
1160   // ValidateDotDimensionNumbers above).
1161   llvm_ir::IrArray lhs_array_reshaped =
1162       CollapseFirstNDims(b, lhs_array, num_batch_dims);
1163   llvm_ir::IrArray rhs_array_reshaped =
1164       CollapseFirstNDims(b, rhs_array, num_batch_dims);
1165   llvm_ir::IrArray target_array_reshaped =
1166       CollapseFirstNDims(b, target_array, num_batch_dims);
1167 
1168   int64_t batch_count = lhs_array_reshaped.GetShape().dimensions(0);
1169 
1170   KernelSupportLibrary ksl(b);
1171 
1172   return ksl.ForWithStatus(
1173       llvm_ir::IrName(&dot, "bdot"), /*start=*/0, /*end=*/batch_count,
1174       /*step=*/1, [&](llvm::Value* indvar) {
1175         DotDimensionNumbers adjusted_dim_numbers = dot.dot_dimension_numbers();
1176         adjusted_dim_numbers.clear_lhs_batch_dimensions();
1177         adjusted_dim_numbers.clear_rhs_batch_dimensions();
1178 
1179         // Create a DotInfo representing the "inner" non-batch dot operation.
1180         DotInfo dot_info;
1181         dot_info.lhs_shape = DropFirstDim(lhs_array_reshaped.GetShape());
1182         dot_info.rhs_shape = DropFirstDim(rhs_array_reshaped.GetShape());
1183         dot_info.result_shape = DropFirstDim(target_array_reshaped.GetShape());
1184         dot_info.dim_nums = dot.dot_dimension_numbers();
1185         dot_info.dim_nums.clear_lhs_batch_dimensions();
1186         dot_info.dim_nums.clear_rhs_batch_dimensions();
1187 
1188         dot_info.dim_nums.set_lhs_contracting_dimensions(
1189             0,
1190             dot_info.dim_nums.lhs_contracting_dimensions(0) - num_batch_dims);
1191         dot_info.dim_nums.set_rhs_contracting_dimensions(
1192             0,
1193             dot_info.dim_nums.rhs_contracting_dimensions(0) - num_batch_dims);
1194 
1195         llvm_ir::IrArray lhs_slice =
1196             SliceOutInnerArray(lhs_array_reshaped, /*batch_index=*/indvar, b);
1197         llvm_ir::IrArray rhs_slice =
1198             SliceOutInnerArray(rhs_array_reshaped, /*batch_index=*/indvar, b);
1199         llvm_ir::IrArray target_slice = SliceOutInnerArray(
1200             target_array_reshaped, /*batch_index=*/indvar, b);
1201 
1202         // Emit the inner non-batch dot operation.
1203         return EmitNonBatchDotOperation(
1204             dot_info, dot.name(), target_slice, lhs_slice, rhs_slice, nullptr,
1205             executable_run_options_value, b, mlir_context, hlo_module_config,
1206             target_machine_features);
1207       });
1208 }
1209 
1210 bool IsBatchDot(const HloInstruction& instr) {
1211   if (auto* dot_instr = DynCast<HloDotInstruction>(&instr)) {
1212     return dot_instr->dot_dimension_numbers().lhs_batch_dimensions_size() > 0;
1213   }
1214 
1215   return false;
1216 }
1217 }  // namespace
1218 
1219 bool DotImplementationCanHandleTranspose(
1220     const HloInstruction& dot_instr,
1221     const TargetMachineFeatures& target_machine_features) {
1222   DotImplementationStrategy impl_strategy =
1223       GetDotImplementationStrategy(dot_instr.parent()->parent()->config(),
1224                                    DotInfo(dot_instr), target_machine_features);
1225 
1226   return impl_strategy == DotImplementationStrategy::kNaiveLlvmIr ||
1227          impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemv ||
1228          impl_strategy == DotImplementationStrategy::kEigen;
1229 }
1230 
1231 bool DotOperandsAndResultMustHaveRowMajorLayout(
1232     const HloInstruction& dot_instr,
1233     const TargetMachineFeatures& target_machine_features) {
1234   // Batched dots require the batch dimensions to be major. DotDecomposer always
1235   // moves batch dimensions to the front of the shape, so force a row-major
1236   // layout.
1237   if (IsBatchDot(dot_instr)) {
1238     return true;
1239   }
1240 
1241   DotImplementationStrategy impl_strategy =
1242       GetDotImplementationStrategy(dot_instr.parent()->parent()->config(),
1243                                    DotInfo(dot_instr), target_machine_features);
1244 
1245   return impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemm ||
1246          impl_strategy == DotImplementationStrategy::kEigen;
1247 }
1248 
1249 Status EmitDotOperation(const HloInstruction& dot,
1250                         const llvm_ir::IrArray& target_array,
1251                         const llvm_ir::IrArray& lhs_array,
1252                         const llvm_ir::IrArray& rhs_array,
1253                         const llvm_ir::IrArray* addend_array,
1254                         llvm::Value* executable_run_options_value,
1255                         llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context,
1256                         const HloModuleConfig& hlo_module_config,
1257                         const TargetMachineFeatures& target_machine_features) {
1258   // This routine assumes that the dot operation is not in a parallelized
1259   // enclosing computation.
1260   CHECK(dot.parent()->root_instruction()->outer_dimension_partitions().empty());
1261 
1262   if (IsBatchDot(dot)) {
1263     TF_RET_CHECK(addend_array == nullptr);
1264     return EmitBatchDotOperation(dot, target_array, lhs_array, rhs_array,
1265                                  executable_run_options_value, b, mlir_context,
1266                                  hlo_module_config, target_machine_features);
1267   }
1268 
1269   return EmitNonBatchDotOperation(DotInfo(dot), dot.name(), target_array,
1270                                   lhs_array, rhs_array, addend_array,
1271                                   executable_run_options_value, b, mlir_context,
1272                                   hlo_module_config, target_machine_features);
1273 }
1274 }  // namespace cpu
1275 }  // namespace xla
1276