/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"

#include <memory>
#include <vector>

#include "absl/strings/str_cat.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

using llvm_ir::SetToFirstInsertPoint;

namespace cpu {
namespace {
// Returns true if we should call into multi-threaded Eigen routines.
bool ShouldUseMultiThreadedEigen(const HloModuleConfig& config) {
  return config.debug_options().xla_cpu_multi_thread_eigen();
}

// Represents a dot operation.  We use this in lieu of an `HloInstruction`
// because we want to be able to create this for the "inner" dot operation in a
// batch dot, for which there is no separate HLO instruction.
struct DotInfo {
  Shape lhs_shape;
  Shape rhs_shape;
  Shape result_shape;
  DotDimensionNumbers dim_nums;

  DotInfo() = default;

  explicit DotInfo(const HloInstruction& instr) {
    CHECK_EQ(instr.opcode(), HloOpcode::kDot);
    lhs_shape = instr.operand(0)->shape();
    rhs_shape = instr.operand(1)->shape();
    result_shape = instr.shape();
    dim_nums = instr.dot_dimension_numbers();
  }
};

// Dictates how a dot operation is implemented.
enum class DotImplementationStrategy {
  // The dot operation is lowered into LLVM IR that implements a naive nested
  // loop that computes the result one element at a time.  This is our
  // "fallback"; we don't really want this to kick in for any non-trivial dot
  // operation.
  kNaiveLlvmIr,

  // The dot operation is lowered into LLVM IR that implements a tiled
  // Matrix*Vector operation.  This strategy also allows fusing a bias add
  // into the dot.  The matrix can be row major or column major; both are
  // supported.
  kTiledLlvmIrGemv,

  // The dot operation is lowered into LLVM IR that implements a tiled
  // Matrix*Matrix operation.  No fusions are supported.  The two inputs
  // and the output have to be row major.
  kTiledLlvmIrGemm,

  // The dot operation is lowered into a linalg.matmul op and then to LLVM IR.
  kLinalgMatmul,

  // The dot operation is lowered into a call into an Eigen routine.  No
  // fusions are supported today.  The two inputs and the output have to be
  // row major.  However, we do allow transposing either the LHS or the RHS as
  // part of the GEMM -- we expose this flexibility as flexibility in the
  // contraction dimensions, but we can also see this as flexibility in the
  // input layouts.
  kEigen,
};

// Returns the implementation strategy for a dot with the configuration
// `dot_info`.
DotImplementationStrategy GetDotImplementationStrategy(
    const HloModuleConfig& config, const DotInfo& dot_info,
    const TargetMachineFeatures& target_machine_features);

// Helper class for emitting LLVM IR to perform the dot operation.
class DotOpEmitter {
 public:
  explicit DotOpEmitter(DotInfo dot_info, std::string dot_hlo_name,
                        const llvm_ir::IrArray& target_array,
                        const llvm_ir::IrArray& lhs_array,
                        const llvm_ir::IrArray& rhs_array,
                        const llvm_ir::IrArray* addend_array,
                        llvm::Value* executable_run_options_value,
                        llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context,
                        const HloModuleConfig& hlo_module_config,
                        const TargetMachineFeatures& target_machine_features);

  // Emits the IR to perform the dot operation.
  Status Emit();

  // Emits the IR to perform the batch dot operation.
  Status EmitBatch();

 private:
  // Emits instructions to perform a scalar dot product (a multiply of the
  // LHS and RHS) and store the results in the target.
  Status EmitScalarDot();

  // Emits a call to the CPU runtime to perform the matrix multiply.
  Status EmitCallToRuntime();

  // Emits a call to the CPU runtime to perform the batch matrix multiply.
  Status EmitCallToBatchRuntime();

  // Represents the dimensions of a matrix-matrix multiply operation.
  struct MatMultDims {
    // The number of rows in the LHS.
    int64_t m;

    // The number of columns in the LHS, which must also be equal to the
    // number of rows in the RHS.
    int64_t k;

    // The number of columns in the RHS.
    int64_t n;

    // True if the LHS matrix is column major.
    bool lhs_column_major;

    // True if the LHS contraction dimension is 1.
    bool lhs_canonical;

    // True if the RHS matrix is column major.
    bool rhs_column_major;

    // True if the RHS contraction dimension is 0.
    bool rhs_canonical;
  };
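
  // Worked example (illustrative only, not from the original comments): a
  // canonical row major f32[16,32] x f32[32,8] matmul yields m = 16, k = 32,
  // n = 8, with lhs_column_major = rhs_column_major = false and
  // lhs_canonical = rhs_canonical = true.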

  // Get the MatMultDims instance for the dot product this DotOpEmitter
  // represents.  Precondition: the dot is of rank 2 (and thus its operands are
  // of rank 2 as well).
  MatMultDims GetMatMultDims() const;

  // Get the MatMultDims instance for the dot product this DotOpEmitter
  // represents.  Precondition: the dot is of rank 3 (and thus its operands are
  // of rank 3 as well).
  MatMultDims GetBatchMatMultDims() const;

  // Lowers the dot operation as a tiled Matrix*Vector loop.
  void EmitTiledLlvmIrGemv();

  // Lowers the dot operation as a tiled Matrix*Matrix loop.
  void EmitTiledLlvmIrGemm();

  // Lowers the dot operation through MLIR's linalg.matmul.
  Status EmitLinalgMatmul();

  // Lowers the dot operation as a naive nested loop that computes the result
  // one element at a time.
  void EmitNaiveLlvmIrGemm();

  // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector
  // registers.
  int64_t GetGemvTilingFactor() const {
    const int64_t kDefaultTilingFactor = 8;
    return options::LlvmIrGemvTilingFactor(hlo_module_config_)
        .value_or(kDefaultTilingFactor);
  }

  std::tuple<int64_t, int64_t, int64_t> GetGemmTileSize() const {
    // Tuned for broadwell - Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz
    //
    // TODO(b/80093688): Tune for other architectures and centralize this
    // information in one place.
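    //
    // Note (added for clarity, not in the original source): the tuple is
    // ordered (tile_size_m, tile_size_k, tile_size_n_in_vector_width),
    // matching how it is unpacked in EmitTiledLlvmIrGemm() below.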
    const std::tuple<int64_t, int64_t, int64_t> kDefaultTileSize =
        std::tuple<int64_t, int64_t, int64_t>(11, 9, 1);
    return options::LlvmIrGemmTileSize(hlo_module_config_)
        .value_or(kDefaultTileSize);
  }

  std::array<int64_t, 3> GetMlirGemmTileSize() const {
    // Tile by 4 x registers x register size. This was picked by running
    // small matmuls on Haswell and Skylake. There's a lot of room for
    // improvement here.
    constexpr int64_t kDefaultTileSizeForM = 4;
    int64_t elements_per_register =
        target_machine_features_.vector_register_num_elements(
            *b_->GetInsertBlock()->getParent(),
            dot_info_.result_shape.element_type());
    int64_t num_registers = target_machine_features_.vector_register_count(
        *b_->GetInsertBlock()->getParent());
    return {{kDefaultTileSizeForM, num_registers, elements_per_register}};
  }

  DotInfo dot_info_;
  std::string dot_hlo_name_;
  const llvm_ir::IrArray& target_array_;
  const llvm_ir::IrArray& lhs_array_;
  const llvm_ir::IrArray& rhs_array_;
  const llvm_ir::IrArray* addend_array_;
  llvm::Value* executable_run_options_value_;
  llvm::IRBuilder<>* b_;
  mlir::MLIRContext* mlir_context_;
  const HloModuleConfig& hlo_module_config_;
  const TargetMachineFeatures& target_machine_features_;
};
}  // namespace

DotOpEmitter::DotOpEmitter(
    DotInfo dot_info, std::string dot_hlo_name,
    const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array,
    const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array,
    llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
    mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
    const TargetMachineFeatures& target_machine_features)
    : dot_info_(std::move(dot_info)),
      dot_hlo_name_(std::move(dot_hlo_name)),
      target_array_(target_array),
      lhs_array_(lhs_array),
      rhs_array_(rhs_array),
      addend_array_(addend_array),
      executable_run_options_value_(executable_run_options_value),
      b_(b),
      mlir_context_(mlir_context),
      hlo_module_config_(hlo_module_config),
      target_machine_features_(target_machine_features) {}

Status DotOpEmitter::EmitLinalgMatmul() {
  Shape operand_shapes[] = {dot_info_.lhs_shape, dot_info_.rhs_shape};
  llvm::Value* operand_ptrs[] = {lhs_array_.GetBasePointer(),
                                 rhs_array_.GetBasePointer()};
  llvm::Value* target_ptr = target_array_.GetBasePointer();

  // Zero out the output buffer.
  int64_t size_bytes = ShapeUtil::ByteSizeOf(dot_info_.result_shape);
  b_->CreateMemSet(target_ptr, b_->getInt8(0), /*Size=*/size_bytes,
                   /*Align=*/llvm::MaybeAlign(1));

  std::string name =
      absl::StrCat("linalgMatMul_", dot_info_.result_shape.ToString(true), "_",
                   dot_info_.lhs_shape.ToString(true), "_",
                   dot_info_.rhs_shape.ToString(true));

  return EmitMlirFuncAndCall(
      mlir_context_, b_, dot_info_.result_shape, operand_shapes, target_ptr,
      operand_ptrs, name,
      [&](mlir::OpBuilder* builder, mlir::func::FuncOp function) {
        CHECK_EQ(dot_info_.dim_nums.lhs_contracting_dimensions_size(), 1);
        CHECK_EQ(dot_info_.dim_nums.rhs_contracting_dimensions_size(), 1);
        mlir::MLIRContext* context = builder->getContext();
        mlir::Value a = function.getArgument(0), b = function.getArgument(1),
                    c = function.getArgument(2);

        llvm::SmallVector<mlir::AffineExpr, 2> b_exprs(
            dot_info_.lhs_shape.rank());
        llvm::SmallVector<mlir::AffineExpr, 2> c_exprs(
            dot_info_.rhs_shape.rank());

        llvm::SmallVector<mlir::AffineExpr, 2> parallel_exprs;
        mlir::AffineExpr reduce_expr;
        for (int i = 0; i != dot_info_.result_shape.rank(); ++i) {
          parallel_exprs.push_back(mlir::getAffineDimExpr(i, context));
        }
        reduce_expr =
            mlir::getAffineDimExpr(dot_info_.result_shape.rank(), context);

        // The reduction expr is shared for both inputs.
        b_exprs[dot_info_.dim_nums.lhs_contracting_dimensions(0)] = reduce_expr;
        c_exprs[dot_info_.dim_nums.rhs_contracting_dimensions(0)] = reduce_expr;

        // Fill in the remaining parallel exprs.
        int par_expr_num = 0;
        for (auto* v : {&b_exprs, &c_exprs}) {
          for (auto& e : *v) {
            if (!e) {
              e = parallel_exprs[par_expr_num++];
            }
          }
        }
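
        // Illustrative example (added for clarity, not in the original
        // source): for a canonical rank-2 matmul the result contributes
        // parallel dims d0 and d1 and the reduction dim is d2, so the
        // indexing maps below come out as (d0, d2) for the LHS, (d2, d1) for
        // the RHS, and (d0, d1) for the output -- the usual matmul indexing.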

        llvm::SmallVector<llvm::StringRef, 4> iteratorTypes(
            parallel_exprs.size(), toString(mlir::IteratorType::Parallel));
        iteratorTypes.push_back(toString(mlir::IteratorType::Reduction));
        builder->create<mlir::linalg::GenericOp>(
            function.getLoc(),
            /*inputs=*/mlir::ValueRange{b, c},
            /*outputs=*/mlir::ValueRange{a},
            /*indexingMaps=*/
            mlir::AffineMap::inferFromExprList(
                {b_exprs, c_exprs, parallel_exprs}),
            /*iteratorTypes=*/iteratorTypes,
            [](mlir::OpBuilder& b, mlir::Location loc, mlir::ValueRange args) {
              mlir::ArithBuilder ab(b, loc);
              mlir::Value mul = ab.mul(args[0], args[1]);
              mlir::Value add = ab.add(mul, args[2]);
              b.create<mlir::linalg::YieldOp>(loc, add);
            });
        builder->create<mlir::func::ReturnOp>(function.getLoc());

        mlir::linalg::LinalgTilingOptions tilingOptions;
        tilingOptions = tilingOptions.setTileSizes(GetMlirGemmTileSize());
        // TODO: this has been retired upstream; reevaluate whether this
        // path really needs it or if it is even relevant anymore.
        // int64_t alignment =
        //     target_machine_features_.minimum_alignment_for_allocation(
        //         ShapeUtil::ByteSizeOf(dot_info_.result_shape));
        mlir::linalg::CodegenStrategy strategy;
        strategy
            .tile(mlir::linalg::GenericOp::getOperationName(), tilingOptions)
            // TODO: this has been retired upstream; reevaluate whether this
            // path really needs it or if it is even relevant anymore.
            // .promote(mlir::linalg::GenericOp::getOperationName(),
            //          mlir::linalg::LinalgPromotionOptions()
            //              .setAlignment(alignment)
            //              .setUseFullTileBuffersByDefault(true)
            //              .setUseAlloca(true))
            .vectorize(mlir::linalg::GenericOp::getOperationName())
            .vectorLowering(
                mlir::linalg::LinalgVectorLoweringOptions()
                    .setVectorTransformsOptions(
                        mlir::vector::VectorTransformsOptions()
                            .setVectorTransformsOptions(
                                mlir::vector::VectorContractLowering::
                                    OuterProduct))
                    .setVectorTransferToSCFOptions(
                        mlir::VectorTransferToSCFOptions().enableFullUnroll()));
        // TODO: this should be within a pass and we should be able to create a
        // nested OpPassManager.
        // Create a nested OpPassManager, populate the strategy and run:
        // mlir::OpPassManager dynamicPM("func.func");
        // strategy.configurePassPipeline(dynamicPM, function.getContext());
        // Propagate pass failure?
        // (void)mlir::runPipeline(dynamicPM, function);
        mlir::PassManager pm(function.getContext(),
                             function.getOperationName());
        strategy.configurePassPipeline(pm, function.getContext());
        // Propagate pass failure?
        (void)pm.run(function);
      });
}

void DotOpEmitter::EmitTiledLlvmIrGemm() {
  PrimitiveType primitive_type = dot_info_.result_shape.element_type();
  MatMultDims mat_mult_dims = GetMatMultDims();

  llvm::Value* lhs = lhs_array_.GetBasePointer();
  llvm::Value* rhs = rhs_array_.GetBasePointer();
  llvm::Value* target = target_array_.GetBasePointer();
  int64_t m = mat_mult_dims.m;
  int64_t k = mat_mult_dims.k;
  int64_t n = mat_mult_dims.n;

  if (mat_mult_dims.lhs_column_major) {
    std::swap(lhs, rhs);
    std::swap(m, n);
  }

  int64_t size_bytes =
      m * n * ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
  b_->CreateMemSet(target, b_->getInt8(0), /*Size=*/size_bytes,
                   /*Align=*/llvm::MaybeAlign(1));

  int64_t max_target_vector_width =
      target_machine_features_.vector_register_num_elements(
          *b_->GetInsertBlock()->getParent(), primitive_type);

  int64_t tile_size_m, tile_size_k, tile_size_n_in_vector_width;
  std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) =
      GetGemmTileSize();

  EmitSmallGemm(
      /*scalar_type=*/primitive_type,
      /*m=*/m, /*k=*/k, /*n=*/n,
      /*max_vectorization_width=*/max_target_vector_width,
      /*max_vector_count=*/tile_size_n_in_vector_width,
      /*min_vectorization_width=*/std::min<int64_t>(4, max_target_vector_width),
      /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k, /*lhs=*/lhs,
      /*rhs=*/rhs, /*result=*/target, b_, hlo_module_config_);
}

void DotOpEmitter::EmitTiledLlvmIrGemv() {
  PrimitiveType primitive_type = dot_info_.result_shape.element_type();

  CHECK(primitive_util::IsFloatingPointType(primitive_type) ||
        primitive_util::IsIntegralType(primitive_type));

  MatMultDims mat_mult_dims = GetMatMultDims();
  bool is_column_major_matrix_vector_gemv = false;
  bool is_row_major_matrix_vector_gemv = false;

  int64_t m, k;
  bool swap_operands;

  if (mat_mult_dims.m == 1) {
    // Our emitters can only do Matrix*Vector (abbreviated as M*V) but when M=1
    // we actually want V*M.  We implement V*M as follows (Tr(X) = Transpose of
    // X):
    //
    //   V*M = Tr(Tr(V*M))  // Tr(Tr(X)) == X
    //       = Tr(Tr(M) * Tr(V))  // Tr(A * B) == Tr(B) * Tr(A)
    //
    // Since transposing a vector is physically a no-op, this is really
    // equivalent to `Tr(M) * V`.  We further implement Tr(M) by pretending that
    // M is row major if it is actually column major and vice-versa.

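    // For example (illustrative only, not from the original comments): with
    // V = f32[1,8] and M = f32[8,4] we emit Tr(M) * V, i.e. a GEMV with
    // m = 4 (taken from mat_mult_dims.n) and k = 8, and swap_operands = true
    // so the matrix is passed as the LHS.
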
    bool rhs_effectively_column_major = mat_mult_dims.rhs_canonical
                                            ? mat_mult_dims.rhs_column_major
                                            : !mat_mult_dims.rhs_column_major;

    if (rhs_effectively_column_major) {
      k = mat_mult_dims.k;
      m = mat_mult_dims.n;

      // We set is_row_major_matrix_vector_gemv and not
      // is_column_major_matrix_vector_gemv to implement the Transpose trick
      // mentioned above.
      is_row_major_matrix_vector_gemv = true;
      swap_operands = true;
    } else {
      k = mat_mult_dims.k;
      m = mat_mult_dims.n;

      // We set is_column_major_matrix_vector_gemv and not
      // is_row_major_matrix_vector_gemv to implement the Transpose trick
      // mentioned above.
      is_column_major_matrix_vector_gemv = true;
      swap_operands = true;
    }
  }

  if (mat_mult_dims.n == 1) {
    bool lhs_effectively_column_major = mat_mult_dims.lhs_canonical
                                            ? mat_mult_dims.lhs_column_major
                                            : !mat_mult_dims.lhs_column_major;

    if (lhs_effectively_column_major) {
      m = mat_mult_dims.m;
      k = mat_mult_dims.k;
      is_column_major_matrix_vector_gemv = true;
      swap_operands = false;
    } else {
      m = mat_mult_dims.m;
      k = mat_mult_dims.k;
      is_row_major_matrix_vector_gemv = true;
      swap_operands = false;
    }
  }

  CHECK(is_column_major_matrix_vector_gemv || is_row_major_matrix_vector_gemv);

  int64_t tiling_factor = GetGemvTilingFactor();
  CHECK_GT(tiling_factor, 0);

  llvm::Value* result_op = target_array_.GetBasePointer();
  llvm::Value* lhs_op =
      swap_operands ? rhs_array_.GetBasePointer() : lhs_array_.GetBasePointer();
  llvm::Value* rhs_op =
      swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer();

  const int target_vector_register_element_size =
      target_machine_features_.vector_register_num_elements(
          *b_->GetInsertBlock()->getParent(), primitive_type);

  // We may not always know the vector register size for the target we're
  // compiling against, in which case target_vector_register_element_size is 0.
  // In these cases we choose a default LLVM IR register size.
  const int kUnknownTargetVectorRegisterSize = 4;
  const int vector_register_element_size =
      target_vector_register_element_size == 0
          ? kUnknownTargetVectorRegisterSize
          : target_vector_register_element_size;

  if (is_column_major_matrix_vector_gemv) {
    VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m
            << " and k = " << k;
    EmitColumnMajorGemv(
        /*scalar_type=*/primitive_type,
        /*tile_rows=*/vector_register_element_size, /*tile_cols=*/tiling_factor,
        /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op,
        /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr,
        /*result=*/result_op, b_, hlo_module_config_);
  } else {
    VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m
            << " and k = " << k;
    EmitRowMajorGemv(
        /*scalar_type=*/primitive_type,
        /*tile_rows=*/tiling_factor,
        /*tile_cols=*/vector_register_element_size,
        /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op,
        /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr,
        /*result=*/result_op, b_, hlo_module_config_);
  }
}

Status DotOpEmitter::Emit() {
  // The dot operation performs a sum of products over dimension 0 of the left
  // hand side operand and dimension 1 of the right hand side operand.
  //
  // Let the shapes of lhs and rhs be defined as below:
  //
  //   lhs = [L{n-1} x L{n-2} x ... L{0}]
  //   rhs = [R{m-1} x R{m-2} x ... R{0}]
  //
  // The sum-of-products dimension in the lhs has size L{0} and the dimension in
  // the rhs has size R{1}. Necessarily, then:
  //
  //   L{0} == R{1}
  //
  // The output of the operation has the following shape:
  //
  //   output = [L{n-1} x L{n-2} x ... L{1} x R{m-1} x R{m-2} x ... R{2} x R{0}]
  //
  // To perform the operation we construct a loop nest with one for-loop for
  // each dimension of the output. Inside this loop nest is another for-loop
  // which performs the sum-of-products (the reduction loop) before storing
  // the result in the output buffer.
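  //
  // For example (illustrative only): lhs = f32[2,3] (so L{1} = 2, L{0} = 3)
  // and rhs = f32[3,4] (R{1} = 3, R{0} = 4) satisfy L{0} == R{1} and produce
  // output = f32[2,4].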

  const Shape& lhs_shape = lhs_array_.GetShape();
  const Shape& rhs_shape = rhs_array_.GetShape();

  if (ShapeUtil::IsScalar(lhs_shape) || ShapeUtil::IsScalar(rhs_shape)) {
    // If the operands are scalar, don't emit any loops.
    TF_RET_CHECK(ShapeUtil::IsScalar(lhs_shape) &&
                 ShapeUtil::IsScalar(rhs_shape));
    return EmitScalarDot();
  }

  switch (GetDotImplementationStrategy(hlo_module_config_, dot_info_,
                                       target_machine_features_)) {
    case DotImplementationStrategy::kNaiveLlvmIr:
      EmitNaiveLlvmIrGemm();
      return OkStatus();

    case DotImplementationStrategy::kTiledLlvmIrGemv:
      EmitTiledLlvmIrGemv();
      return OkStatus();

    case DotImplementationStrategy::kTiledLlvmIrGemm:
      EmitTiledLlvmIrGemm();
      return OkStatus();

    case DotImplementationStrategy::kLinalgMatmul:
      return EmitLinalgMatmul();

    case DotImplementationStrategy::kEigen:
      return EmitCallToRuntime();
  }
}

Status DotOpEmitter::EmitBatch() {
  // The dot operation performs a sum of products over dimension 0 of the left
  // hand side operand and dimension 1 of the right hand side operand.
  //
  // Let the shapes of lhs and rhs be defined as below:
  //
  //   lhs = [L{n-1} x L{n-2} x ... L{0}]
  //   rhs = [R{m-1} x R{m-2} x ... R{0}]
  //
  // The sum-of-products dimension in the lhs has size L{0} and the dimension in
  // the rhs has size R{1}. Necessarily, then:
  //
  //   L{0} == R{1}
  //
  // The output of the operation has the following shape:
  //
  //   output = [L{n-1} x L{n-2} x ... L{1} x R{m-1} x R{m-2} x ... R{2} x R{0}]
  //
  // To perform the operation we construct a loop nest with one for-loop for
  // each dimension of the output. Inside this loop nest is another for-loop
  // which performs the sum-of-products (the reduction loop) before storing
  // the result in the output buffer.

  return EmitCallToBatchRuntime();
}

void DotOpEmitter::EmitNaiveLlvmIrGemm() {
  CHECK_EQ(addend_array_, nullptr);

  const Shape& lhs_shape = lhs_array_.GetShape();
  const Shape& rhs_shape = rhs_array_.GetShape();
  const DotDimensionNumbers& dim_nums = dot_info_.dim_nums;

  // Reduce along dimension 0 of the LHS and 1 of the RHS. Vectors are a special
  // case where the reduction dimension is 0 for both LHS and RHS. This results
  // in a vector dot product producing a scalar.
  int64_t lhs_reduction_dimension = dim_nums.lhs_contracting_dimensions(0);
  int64_t rhs_reduction_dimension = dim_nums.rhs_contracting_dimensions(0);

  // Verify that the reduction dimensions in the two operands have the same
  // size.
  CHECK_EQ(lhs_shape.dimensions(lhs_reduction_dimension),
           rhs_shape.dimensions(rhs_reduction_dimension));

  bool lhs_reduction_along_minor_dimension =
      lhs_reduction_dimension == LayoutUtil::Minor(lhs_shape.layout(), 0);
  bool rhs_reduction_along_minor_dimension =
      rhs_reduction_dimension == LayoutUtil::Minor(rhs_shape.layout(), 0);

  // Create loop nests which loop through the LHS operand dimensions and the RHS
  // operand dimensions. The reduction dimensions of the LHS and RHS are handled
  // in a separate innermost loop which performs the sum of products.
  llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(dot_hlo_name_), b_);
  std::vector<llvm::Value*> lhs_multi_index =
      loop_nest.EmitOperandArrayLoopNest(
          lhs_array_, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs");
  std::vector<llvm::Value*> rhs_multi_index =
      loop_nest.EmitOperandArrayLoopNest(
          rhs_array_, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs");

  // Create the loop which does the sum of products reduction.
  //
  // The prevent_unrolling bit works around a deficiency in LLVM's loop
  // vectorization pipeline, wherein in some cases unrolling a loop can prevent
  // effective vectorization.  Since we know that the IR we generate when
  // reducing across the minor dimension in both LHS and RHS is vectorized well
  // by the loop vectorizer, we block unrolling in that case to stop loop unroll
  // from messing up the vectorization.
  std::unique_ptr<llvm_ir::ForLoop> reduction_loop = loop_nest.AddLoop(
      0, lhs_shape.dimensions(lhs_reduction_dimension), "reduction",
      /*unroll_mode=*/
      (lhs_reduction_along_minor_dimension &&
       rhs_reduction_along_minor_dimension)
          ? xla::llvm_ir::UnrollMode::kNoUnroll
          : xla::llvm_ir::UnrollMode::kDefaultUnroll);

  // The final entry in the rhs and lhs indexes is the indvar of the
  // reduction loop.
  lhs_multi_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue();
  llvm_ir::IrArray::Index lhs_index(lhs_multi_index, lhs_shape,
                                    b_->getInt64Ty());
  rhs_multi_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue();
  llvm_ir::IrArray::Index rhs_index(rhs_multi_index, rhs_shape,
                                    b_->getInt64Ty());

  // For computing the sum of products we alloca a single location to store the
  // dot product result as we accumulate it within the reduction loop. After the
  // reduction loop we load the result and store into the output array.

  // Function entry basic block.
  // - Emit alloca for accumulator
  llvm::Function* func = reduction_loop->GetPreheaderBasicBlock()->getParent();
  SetToFirstInsertPoint(&func->getEntryBlock(), b_);
  llvm::Type* accum_type = target_array_.GetElementLlvmType();
  llvm::Value* accum_address =
      b_->CreateAlloca(accum_type, /*ArraySize=*/nullptr, "accum_address");

  // Preheader basic block of reduction loop:
  // - Initialize accumulator to zero.
  llvm::BasicBlock* preheader_bb = reduction_loop->GetPreheaderBasicBlock();
  b_->SetInsertPoint(preheader_bb->getTerminator());

  b_->CreateStore(llvm::Constant::getNullValue(accum_type), accum_address);

  // Body basic block of reduction loop:
  // - Load elements from lhs and rhs array.
  // - Multiply lhs-element and rhs-element.
  // - Load accumulator and add to product.
  // - Store sum back into accumulator.
  SetToFirstInsertPoint(reduction_loop->GetBodyBasicBlock(), b_);

  llvm::Value* lhs_element = lhs_array_.EmitReadArrayElement(lhs_index, b_);
  llvm::Value* rhs_element = rhs_array_.EmitReadArrayElement(rhs_index, b_);

  llvm::Value* accum = b_->CreateLoad(accum_type, accum_address);
  llvm::Value* updated_accum;
  if (ShapeUtil::ElementIsComplex(lhs_shape)) {
    auto real = [&](llvm::Value* x) { return b_->CreateExtractValue(x, {0}); };
    auto imag = [&](llvm::Value* x) { return b_->CreateExtractValue(x, {1}); };
    llvm::Value* product_real =
        b_->CreateFSub(b_->CreateFMul(real(lhs_element), real(rhs_element)),
                       b_->CreateFMul(imag(lhs_element), imag(rhs_element)));
    llvm::Value* product_imag =
        b_->CreateFAdd(b_->CreateFMul(real(lhs_element), imag(rhs_element)),
                       b_->CreateFMul(imag(lhs_element), real(rhs_element)));
    updated_accum = b_->CreateInsertValue(
        accum, b_->CreateFAdd(real(accum), product_real), {0});
    updated_accum = b_->CreateInsertValue(
        updated_accum, b_->CreateFAdd(imag(accum), product_imag), {1});
  } else if (ShapeUtil::ElementIsIntegral(lhs_shape)) {
    llvm::Value* product = b_->CreateMul(lhs_element, rhs_element);
    updated_accum = b_->CreateAdd(accum, product);
  } else if (lhs_shape.element_type() == PRED) {
    llvm::Value* product = b_->CreateAnd(lhs_element, rhs_element);
    updated_accum = b_->CreateOr(accum, product);
  } else {
    llvm::Value* product = b_->CreateFMul(lhs_element, rhs_element);
    updated_accum = b_->CreateFAdd(accum, product);
  }
  b_->CreateStore(updated_accum, accum_address);

  // Exit basic block of reduction loop.
  // - Load accumulator value (the result).
  // - Store into output array.
  SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), b_);

  llvm::Value* result = b_->CreateLoad(accum_type, accum_address);

  // Create index into target address. The target index is the concatenation of
  // the rhs and lhs indexes with the reduction dimensions removed. The terms
  // from the rhs index are the lower dimensions in the index so we add them
  // first.
  std::vector<llvm::Value*> target_multi_index;
  for (int dimension = 0; dimension < lhs_index.size(); ++dimension) {
    if (dimension != lhs_reduction_dimension) {
      target_multi_index.push_back(lhs_index[dimension]);
    }
  }
  for (int dimension = 0; dimension < rhs_index.size(); ++dimension) {
    if (dimension != rhs_reduction_dimension) {
      target_multi_index.push_back(rhs_index[dimension]);
    }
  }

  llvm_ir::IrArray::Index target_index(
      target_multi_index, target_array_.GetShape(), lhs_index.GetType());
  target_array_.EmitWriteArrayElement(target_index, result, b_);

  // Set the IR builder insert point to the exit basic block of the outermost
  // loop.
  b_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
}

Status DotOpEmitter::EmitScalarDot() {
  // A scalar dot is just a scalar multiply.
  llvm::Value* result;
  // Use the same index_type for all tensor accesses in the same kernel.
  llvm::Type* index_type = b_->getInt64Ty();
  llvm_ir::IrArray::Index element_index(index_type);
  llvm::Value* lhs_value =
      lhs_array_.EmitReadArrayElement(/*index=*/element_index, b_);
  llvm::Value* rhs_value =
      rhs_array_.EmitReadArrayElement(/*index=*/element_index, b_);
  if (ShapeUtil::ElementIsComplex(lhs_array_.GetShape())) {
    auto get_real = [&](llvm::Value* x) {
      return b_->CreateExtractValue(x, {0});
    };

    auto get_imag = [&](llvm::Value* x) {
      return b_->CreateExtractValue(x, {1});
    };

    llvm::Value* real = b_->CreateFSub(
        b_->CreateFMul(get_real(lhs_value), get_real(rhs_value)),
        b_->CreateFMul(get_imag(lhs_value), get_imag(rhs_value)));
    llvm::Value* imag = b_->CreateFAdd(
        b_->CreateFMul(get_real(lhs_value), get_imag(rhs_value)),
        b_->CreateFMul(get_imag(lhs_value), get_real(rhs_value)));
    result = llvm::ConstantAggregateZero::get(lhs_array_.GetElementLlvmType());
    result = b_->CreateInsertValue(result, real, {0});
    result = b_->CreateInsertValue(result, imag, {1});
  } else {
    result = b_->CreateFMul(lhs_value, rhs_value);
  }
  target_array_.EmitWriteArrayElement(/*index=*/element_index, result, b_);
  return OkStatus();
}

Status DotOpEmitter::EmitCallToRuntime() {
  // The signature of the Eigen runtime matmul function is:
  //
  //   (void)(void* run_options, float* out, float* lhs, float* rhs,
  //          int64_t m, int64_t n, int64_t k, int32_t transpose_lhs,
  //          int32_t transpose_rhs);
  //
  // The two transpose_... parameters are actually booleans, but we use int32_t
  // to avoid target-dependent calling convention details.

  bool multi_threaded = ShouldUseMultiThreadedEigen(hlo_module_config_);
  bool use_mkl_dnn = hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn();
  bool use_acl = hlo_module_config_.debug_options().xla_cpu_use_acl();
  PrimitiveType type = target_array_.GetShape().element_type();
  llvm::Function* function = b_->GetInsertBlock()->getParent();
  llvm::Module* module = function->getParent();
  llvm::Type* float_type;
  const char* fn_name;
  switch (type) {
    case F16:
      fn_name = multi_threaded
                    ? runtime::kEigenMatMulF16SymbolName
                    : runtime::kEigenSingleThreadedMatMulF16SymbolName;
      float_type = b_->getHalfTy();
      break;
    case F32:
      fn_name =
          multi_threaded
              ? (use_mkl_dnn ? runtime::kMKLMatMulF32SymbolName
                             : (use_acl ? runtime::kACLMatMulF32SymbolName
                                        : runtime::kEigenMatMulF32SymbolName))
              : (use_mkl_dnn
                     ? runtime::kMKLSingleThreadedMatMulF32SymbolName
                     : runtime::kEigenSingleThreadedMatMulF32SymbolName);
      float_type = b_->getFloatTy();
      break;
    case F64:
      fn_name = multi_threaded
                    ? (use_mkl_dnn ? runtime::kMKLMatMulF64SymbolName
                                   : runtime::kEigenMatMulF64SymbolName)
                    : (use_mkl_dnn
                           ? runtime::kMKLSingleThreadedMatMulF64SymbolName
                           : runtime::kEigenSingleThreadedMatMulF64SymbolName);
      float_type = b_->getDoubleTy();
      break;
    case C64:
      fn_name = multi_threaded
                    ? runtime::kEigenMatMulC64SymbolName
                    : runtime::kEigenSingleThreadedMatMulC64SymbolName;
      float_type = llvm_ir::PrimitiveTypeToIrType(C64, module);
      break;
    case C128:
      fn_name = multi_threaded
                    ? runtime::kEigenMatMulC128SymbolName
                    : runtime::kEigenSingleThreadedMatMulC128SymbolName;
      float_type = llvm_ir::PrimitiveTypeToIrType(C128, module);
      break;
    case S32:
      fn_name = multi_threaded
                    ? runtime::kEigenMatMulS32SymbolName
                    : runtime::kEigenSingleThreadedMatMulS32SymbolName;
      float_type = b_->getInt32Ty();
      break;
    default:
      return Unimplemented("Invalid type %s for dot operation",
                           PrimitiveType_Name(type));
  }

  llvm::Type* float_ptr_type = float_type->getPointerTo();
  llvm::Type* int64_type = b_->getInt64Ty();
  llvm::Type* int32_type = b_->getInt32Ty();
  llvm::Type* int8_ptr_type = b_->getInt8Ty()->getPointerTo();
  llvm::FunctionType* matmul_type = llvm::FunctionType::get(
      b_->getVoidTy(),
      {int8_ptr_type, float_ptr_type, float_ptr_type, float_ptr_type,
       int64_type, int64_type, int64_type, int32_type, int32_type},
      /*isVarArg=*/false);

  llvm::FunctionCallee matmul_func =
      module->getOrInsertFunction(fn_name, matmul_type);
  if (auto* fn = llvm::dyn_cast<llvm::Function>(matmul_func.getCallee())) {
    fn->setCallingConv(llvm::CallingConv::C);
    fn->setDoesNotThrow();
    fn->setOnlyAccessesArgMemory();
  }

  // The Eigen runtime function expects column-major layout. If the matrices are
  // row major, then use the following identity to compute the product:
  //
  //   (A x B)^T = B^T x A^T
  //
  // The connection between this identity and memory layout is that the
  // transpose operation can also be considered as an operation that changes the
  // memory layout of a matrix from row-major to column-major or vice versa.
  //
  // Effectively this involves swapping the 'lhs' with 'rhs' and 'm' with 'n'.
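  //
  // For example (illustrative only, not from the original comments): for row
  // major A = f32[16,32] and B = f32[32,8] we call the runtime with lhs = B,
  // rhs = A, m = 8, n = 16, k = 32.  The column major result the runtime
  // writes is (A x B)^T, which has exactly the memory layout of the row major
  // A x B we want.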

  MatMultDims mat_mult_dims = GetMatMultDims();

  CHECK_EQ(mat_mult_dims.lhs_column_major, mat_mult_dims.rhs_column_major);

  const llvm_ir::IrArray* lhs = &lhs_array_;
  const llvm_ir::IrArray* rhs = &rhs_array_;
  bool transpose_lhs = !mat_mult_dims.lhs_canonical;
  bool transpose_rhs = !mat_mult_dims.rhs_canonical;

  if (!mat_mult_dims.lhs_column_major) {
    std::swap(mat_mult_dims.m, mat_mult_dims.n);
    std::swap(lhs, rhs);
    std::swap(transpose_lhs, transpose_rhs);
  }

  b_->CreateCall(
      matmul_func,
      {b_->CreateBitCast(executable_run_options_value_, int8_ptr_type),
       b_->CreateBitCast(target_array_.GetBasePointer(), float_ptr_type),
       b_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type),
       b_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type),
       b_->getInt64(mat_mult_dims.m), b_->getInt64(mat_mult_dims.n),
       b_->getInt64(mat_mult_dims.k), b_->getInt32(transpose_lhs),
       b_->getInt32(transpose_rhs)});
  return OkStatus();
}

Status DotOpEmitter::EmitCallToBatchRuntime() {
  // The signature of the runtime batch matmul function is:
  //
  //   (void)(void* run_options, float* out, float* lhs, float* rhs,
  //          int64_t m, int64_t n, int64_t k, int64_t batch_size,
  //          int32_t transpose_lhs, int32_t transpose_rhs);
  //
  // The two transpose_... parameters are actually booleans, but we use int32_t
  // to avoid target-dependent calling convention details.

  PrimitiveType type = target_array_.GetShape().element_type();
  bool use_acl = hlo_module_config_.debug_options().xla_cpu_use_acl();
  llvm::Function* function = b_->GetInsertBlock()->getParent();
  llvm::Module* module = function->getParent();
  llvm::Type* float_type;
  const char* fn_name;
  switch (type) {
    case F32:
      fn_name = use_acl ? runtime::kACLBatchMatMulF32SymbolName
                        : runtime::kEigenBatchMatMulF32SymbolName;
      float_type = b_->getFloatTy();
      break;
    default:
      return Unimplemented("Invalid type %s for dot operation",
                           PrimitiveType_Name(type));
  }

  llvm::Type* float_ptr_type = float_type->getPointerTo();
  llvm::Type* int64_type = b_->getInt64Ty();
  llvm::Type* int32_type = b_->getInt32Ty();
  llvm::Type* int8_ptr_type = b_->getInt8Ty()->getPointerTo();
  llvm::FunctionType* matmul_type = llvm::FunctionType::get(
      b_->getVoidTy(),
      {int8_ptr_type, float_ptr_type, float_ptr_type, float_ptr_type,
       int64_type, int64_type, int64_type, int64_type, int32_type, int32_type},
      /*isVarArg=*/false);

  llvm::FunctionCallee matmul_func =
      module->getOrInsertFunction(fn_name, matmul_type);
  if (auto* fn = llvm::dyn_cast<llvm::Function>(matmul_func.getCallee())) {
    fn->setCallingConv(llvm::CallingConv::C);
    fn->setDoesNotThrow();
    fn->setOnlyAccessesArgMemory();
  }

  // The ACL runtime function expects column-major layout. If the matrices are
  // row major, then use the following identity to compute the product:
  //
  //   (A x B)^T = B^T x A^T
  //
  // The connection between this identity and memory layout is that the
  // transpose operation can also be considered as an operation that changes the
  // memory layout of a matrix from row-major to column-major or vice versa.
  //
  // Effectively this involves swapping the 'lhs' with 'rhs' and 'm' with 'n'.

  MatMultDims mat_mult_dims = GetBatchMatMultDims();
  CHECK_EQ(mat_mult_dims.lhs_column_major, mat_mult_dims.rhs_column_major);

  const llvm_ir::IrArray* lhs = &lhs_array_;
  const llvm_ir::IrArray* rhs = &rhs_array_;
  bool transpose_lhs = !mat_mult_dims.lhs_canonical;
  bool transpose_rhs = !mat_mult_dims.rhs_canonical;
  const Shape& lhs_shape = lhs_array_.GetShape();

  if (!mat_mult_dims.lhs_column_major) {
    std::swap(mat_mult_dims.m, mat_mult_dims.n);
    std::swap(lhs, rhs);
    std::swap(transpose_lhs, transpose_rhs);
  }

  VLOG(1) << "Batch dot emitted with runtime: " << fn_name;

  b_->CreateCall(
      matmul_func,
      {b_->CreateBitCast(executable_run_options_value_, int8_ptr_type),
       b_->CreateBitCast(target_array_.GetBasePointer(), float_ptr_type),
       b_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type),
       b_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type),
       b_->getInt64(mat_mult_dims.m), b_->getInt64(mat_mult_dims.n),
       b_->getInt64(mat_mult_dims.k), b_->getInt64(lhs_shape.dimensions(0)),
       b_->getInt32(static_cast<uint32_t>(transpose_lhs)),
       b_->getInt32(static_cast<uint32_t>(transpose_rhs))});
  return OkStatus();
}

DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
  CHECK_LE(dot_info_.result_shape.dimensions_size(), 2);

  const Shape& lhs_shape = lhs_array_.GetShape();
  const Shape& rhs_shape = rhs_array_.GetShape();
  const DotDimensionNumbers& dim_nums = dot_info_.dim_nums;

  auto is_column_major = [](const Shape& shape) {
    return shape.rank() > 1 && LayoutUtil::Minor(shape.layout(), 0) == 0;
  };

  // Non-contracting dots should never make it here.
  CHECK_GE(dim_nums.lhs_contracting_dimensions_size(), 0);
  CHECK_GE(dim_nums.rhs_contracting_dimensions_size(), 0);

  return {
      /*m=*/lhs_shape.rank() <= 1
          ? 1LL
          : lhs_shape.dimensions(1LL - dim_nums.lhs_contracting_dimensions(0)),
      /*k=*/lhs_shape.dimensions(dim_nums.lhs_contracting_dimensions(0)),
      /*n=*/rhs_shape.rank() <= 1
          ? 1LL
          : rhs_shape.dimensions(1LL - dim_nums.rhs_contracting_dimensions(0)),
      /*lhs_column_major=*/is_column_major(lhs_shape),
      /*lhs_canonical=*/lhs_shape.rank() <= 1 ||
          dim_nums.lhs_contracting_dimensions(0) == 1,
      /*rhs_column_major=*/is_column_major(rhs_shape),
      /*rhs_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 0};
}

DotOpEmitter::MatMultDims DotOpEmitter::GetBatchMatMultDims() const {
  CHECK_LE(dot_info_.result_shape.dimensions_size(), 2);

  const Shape& lhs_shape = lhs_array_.GetShape();
  const Shape& rhs_shape = rhs_array_.GetShape();
  const DotDimensionNumbers& dim_nums = dot_info_.dim_nums;

  auto is_column_major = [](const Shape& shape) {
    return shape.rank() > 1 && LayoutUtil::Minor(shape.layout(), 0) == 0;
  };

  // Non-contracting dots should never make it here.
  CHECK_GE(dim_nums.lhs_contracting_dimensions_size(), 0);
  CHECK_GE(dim_nums.rhs_contracting_dimensions_size(), 0);

  return {
      /*m=*/lhs_shape.rank() <= 1
          ? 1LL
          : lhs_shape.dimensions(2LL - dim_nums.lhs_contracting_dimensions(0)),
      /*k=*/lhs_shape.dimensions(1LL + dim_nums.lhs_contracting_dimensions(0)),
      /*n=*/rhs_shape.rank() <= 1
          ? 1LL
          : rhs_shape.dimensions(2LL - dim_nums.rhs_contracting_dimensions(0)),
      /*lhs_column_major=*/is_column_major(lhs_shape),
      /*lhs_canonical=*/lhs_shape.rank() <= 1 ||
          dim_nums.lhs_contracting_dimensions(0) == 1,
      /*rhs_column_major=*/is_column_major(rhs_shape),
      /*rhs_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 0};
}

// For vector-matrix dot products, it is always profitable to make the Rhs
// column major.
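//
// For example (illustrative only, not from the original comments): for
// v = f32[256] dot M = f32[256,128], which produces a rank 1 result, this
// returns operand index 1, requesting a column major layout for M.  A tiny
// LHS (below kColumnMajorThresholdInBytes) makes it return {} instead.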
std::optional<int64_t> ProfitableToMakeDotOperandColumnMajor(
    const HloInstruction& hlo) {
  if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() <= 1) {
    if (hlo.operand(0)->shape().rank() != 1 ||
        hlo.dot_dimension_numbers().rhs_contracting_dimensions(0) != 0) {
      return {};
    }

    // Don't bother if the other operand is tiny; switching to column major
    // wouldn't use tiling.
    constexpr int kColumnMajorThresholdInBytes = 32;
    int64_t lhs_size =
        ShapeUtil::ByteSizeOfPrimitiveType(hlo.shape().element_type()) *
        ShapeUtil::ElementsIn(hlo.operand(0)->shape());
    if (lhs_size < kColumnMajorThresholdInBytes) {
      return {};
    }

    return 1;
  }

  if (hlo.IsOutputFusion()) {
    auto* fusion_root =
        hlo.fused_instructions_computation()->root_instruction();
    if (fusion_root->opcode() != HloOpcode::kAdd) {
      return {};
    }

    for (auto* fusion_root_op : fusion_root->operands()) {
      if (fusion_root_op->opcode() != HloOpcode::kDot) {
        continue;
      }
      if (auto operand_num =
              ProfitableToMakeDotOperandColumnMajor(*fusion_root_op)) {
        auto* operand = fusion_root_op->operand(*operand_num);
        if (operand->opcode() == HloOpcode::kParameter &&
            operand->user_count() == 1) {
          return operand->parameter_number();
        }
      }
    }
  }

  return {};
}

namespace {
// Return whether the given shape is rank 2.
bool IsRank2(const Shape& shape) { return shape.rank() == 2; }

bool IsSimpleLayout(const Layout& layout) {
  return layout.tiles().empty() && LayoutUtil::IsDense(layout);
}

// In a gemm operation where output = lhs * rhs, check whether the given shapes
// are valid for the operation.
bool AreGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
                   const Shape& output_shape,
                   const TargetMachineFeatures& target_machine_features) {
  CHECK(!lhs_shape.has_layout() || IsSimpleLayout(lhs_shape.layout()))
      << lhs_shape.DebugString();
  CHECK(!rhs_shape.has_layout() || IsSimpleLayout(rhs_shape.layout()))
      << rhs_shape.DebugString();
  CHECK(!output_shape.has_layout() || IsSimpleLayout(output_shape.layout()))
      << output_shape.DebugString();

  switch (output_shape.element_type()) {
    case F16:
    case F32:
    case F64:
    case C64:
    case C128:
    case S32:
      return IsRank2(lhs_shape) && IsRank2(rhs_shape) && IsRank2(output_shape);
    default:
      return false;
  }
}

bool IsAlignedGemm(const DotInfo& dot_info,
                   const TargetMachineFeatures& target_machine_features) {
  if (ShapeUtil::IsZeroElementArray(dot_info.lhs_shape) ||
      ShapeUtil::IsZeroElementArray(dot_info.rhs_shape)) {
    return false;
  }

  return AreGemmShapes(dot_info.lhs_shape, dot_info.rhs_shape,
                       dot_info.result_shape, target_machine_features);
}

bool CanEmitTiledLlvmIrGemm(
    const HloModuleConfig& config, const DotInfo& dot_info,
    const TargetMachineFeatures& target_machine_features) {
  CHECK(IsAlignedGemm(dot_info, target_machine_features));

  if (ShouldUseMultiThreadedEigen(config)) {
    return false;
  }

  int m = dot_info.result_shape.dimensions(0);
  int k = dot_info.lhs_shape.dimensions(
      dot_info.dim_nums.lhs_contracting_dimensions(0));
  int n = dot_info.result_shape.dimensions(1);

  if (!options::ForceEnableExperimentalLlvmIrGemm(config)) {
    // TODO(sanjoy): We should make these numbers micro-arch specific.
    bool small_gemm =
        k <= 128 && ((m <= 32 && n <= 128) || (m <= 128 && n <= 32));
    if (!small_gemm) {
      return false;
    }
  }

  bool lhs_canonical = dot_info.dim_nums.lhs_contracting_dimensions(0) == 1;
  bool rhs_canonical = dot_info.dim_nums.rhs_contracting_dimensions(0) == 0;

  if (!(lhs_canonical && rhs_canonical)) {
    return false;
  }

  if (dot_info.result_shape.element_type() == F16 ||
      dot_info.result_shape.element_type() == C64 ||
      dot_info.result_shape.element_type() == C128) {
    // TODO(sanjoy): This is probably easy to fix, but I want to keep the CL
    // adding this comment NFC.
    return false;
  }

  return true;
}

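// Illustrative examples (added for clarity, not in the original source): an
// f32[1,256] x f32[256,128] dot has a rank 2 result with dimension 0 equal
// to 1, so it lowers to kTiledLlvmIrGemv; a small canonical row major
// f32[16,16] x f32[16,16] dot that passes IsAlignedGemm() and
// CanEmitTiledLlvmIrGemm() lowers to kTiledLlvmIrGemm.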
DotImplementationStrategy GetDotImplementationStrategy(
    const HloModuleConfig& config, const DotInfo& dot_info,
    const TargetMachineFeatures& target_machine_features) {
  PrimitiveType element_type = dot_info.result_shape.element_type();
  // Any Matrix-Vector product of floating point or integral type, or
  // a transpose-dot fusion of the same, can be lowered to a tiled LLVM
  // IR implementation.
  if ((dot_info.result_shape.dimensions_size() <= 1 ||
       (dot_info.result_shape.dimensions_size() == 2 &&
        (dot_info.result_shape.dimensions(0) == 1 ||
         dot_info.result_shape.dimensions(1) == 1))) &&
      (primitive_util::IsFloatingPointType(element_type) ||
       primitive_util::IsIntegralType(element_type))) {
    return DotImplementationStrategy::kTiledLlvmIrGemv;
  }

  // A MatMul where each operand has a dimension of size 3 or smaller should
  // use the naive nested loop.
  if ((dot_info.lhs_shape.dimensions_size() <= 1 ||
       (dot_info.lhs_shape.dimensions_size() == 2 &&
        (dot_info.lhs_shape.dimensions(0) <= 3 ||
         dot_info.lhs_shape.dimensions(1) <= 3))) &&
      (dot_info.rhs_shape.dimensions_size() <= 1 ||
       (dot_info.rhs_shape.dimensions_size() == 2 &&
        (dot_info.rhs_shape.dimensions(0) <= 3 ||
         dot_info.rhs_shape.dimensions(1) <= 3))) &&
      (primitive_util::IsFloatingPointType(element_type) ||
       primitive_util::IsIntegralType(element_type))) {
    return DotImplementationStrategy::kNaiveLlvmIr;
  }

  if (IsAlignedGemm(dot_info, target_machine_features)) {
    if (CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features)) {
      return DotImplementationStrategy::kTiledLlvmIrGemm;
    }
    return DotImplementationStrategy::kEigen;
  }

  return DotImplementationStrategy::kNaiveLlvmIr;
}
1235 
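// Emits code for a single non-batch dot operation, after checking that the
// element type is one the emitter supports.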
Status EmitNonBatchDotOperation(
    DotInfo dot_info, std::string hlo_name,
    const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array,
    const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array,
    llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
    mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
    const TargetMachineFeatures& target_machine_features) {
  PrimitiveType type = target_array.GetShape().element_type();
  TF_RET_CHECK(PRED == type || S8 == type || U8 == type || S16 == type ||
               U16 == type || S32 == type || U32 == type || S64 == type ||
               U64 == type || F16 == type || F32 == type || F64 == type ||
               C64 == type || C128 == type);
  DotOpEmitter dot_emitter(std::move(dot_info), std::move(hlo_name),
                           target_array, lhs_array, rhs_array, addend_array,
                           executable_run_options_value, b, mlir_context,
                           hlo_module_config, target_machine_features);
  return dot_emitter.Emit();
}

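// Returns `shape` with its major-most (first) dimension dropped, e.g.
// f32[8,3,4] becomes f32[3,4] (with a descending layout).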
Shape DropFirstDim(const Shape& shape) {
  absl::Span<int64_t const> array_shape_dims(shape.dimensions());
  array_shape_dims.remove_prefix(1);
  return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
                                                  array_shape_dims);
}

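// Collapses the first `n` dimensions of `shape` into one dimension equal to
// their product, e.g. CollapseFirstNDims(f32[2,3,4,5], 2) yields f32[6,4,5].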
Shape CollapseFirstNDims(const Shape& shape, int64_t n) {
  absl::Span<int64_t const> input_shape_dims(shape.dimensions());
  int64_t prefix_dim =
      std::accumulate(input_shape_dims.begin(), input_shape_dims.begin() + n,
                      1ll, std::multiplies<int64_t>());
  DimensionVector result_dims;
  result_dims.push_back(prefix_dim);
  std::copy(input_shape_dims.begin() + n, input_shape_dims.end(),
            std::back_inserter(result_dims));
  return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
                                                  result_dims);
}

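// Returns a bitcast view of `array` with its first `n` dimensions collapsed
// into one. The layout must be monotonic dim0-major, so the bitcast does not
// move any data.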
llvm_ir::IrArray CollapseFirstNDims(llvm::IRBuilder<>* b,
                                    const llvm_ir::IrArray& array, int64_t n) {
  llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
  const Shape& shape = array.GetShape();
  CHECK(shape.has_layout() &&
        LayoutUtil::IsMonotonicWithDim0Major(shape.layout()));
  CHECK_GE(shape.dimensions_size(), n);
  Shape new_shape = CollapseFirstNDims(shape, n);
  llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, module);
  llvm::Value* new_value =
      b->CreateBitCast(array.GetBasePointer(), new_ir_type->getPointerTo());
  return llvm_ir::IrArray(new_value, new_ir_type, std::move(new_shape));
}

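// Checks that `dim_numbers` has exactly one LHS contracting dimension and
// that the batch dimensions are the leading dimensions [0, 1, ..., n-1] on
// both operands.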
Status ValidateDotDimensionNumbers(const DotDimensionNumbers& dim_numbers) {
  // Checks some invariants that do not hold in general, but DotDecomposer
  // should have established for us.  This is just a debugging aid.
  TF_RET_CHECK(dim_numbers.lhs_contracting_dimensions_size() == 1);
  std::vector<int64_t> batch_dim_numbers(
      dim_numbers.lhs_batch_dimensions_size());
  absl::c_iota(batch_dim_numbers, 0);
  TF_RET_CHECK(
      absl::c_equal(batch_dim_numbers, dim_numbers.lhs_batch_dimensions()));
  TF_RET_CHECK(
      absl::c_equal(batch_dim_numbers, dim_numbers.rhs_batch_dimensions()));
  return OkStatus();
}

// Slice out the inner array at batch index `batch_index` from `outer_array`.
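// E.g. for an outer f32[B,M,N] array, this returns an f32[M,N] view of the
// slice at the given batch index.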
llvm_ir::IrArray SliceOutInnerArray(llvm_ir::IrArray outer_array,
                                    llvm::Value* batch_index,
                                    llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();

  Shape inner_shape = DropFirstDim(outer_array.GetShape());
  std::vector<llvm::Value*> multidim_index(inner_shape.rank() + 1,
                                           b->getInt64(0));
  multidim_index[0] = batch_index;
  llvm_ir::IrArray::Index slice_index(multidim_index, outer_array.GetShape(),
                                      batch_index->getType());
  llvm::Value* slice_ptr = outer_array.EmitArrayElementAddress(slice_index, b);
  llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(inner_shape, module);
  llvm::Type* slice_ptr_type = new_ir_type->getPointerTo();
  return llvm_ir::IrArray(b->CreateBitCast(slice_ptr, slice_ptr_type),
                          new_ir_type, std::move(inner_shape));
}

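// Returns true if the batch dot `dot` can be implemented as a single Eigen
// batch-matmul runtime call. On success, `dot_info` is filled in with a
// description of the inner (per-batch) dot operation.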
bool PotentiallyImplementedAsEigenMatmul(
    const HloInstruction& dot, const llvm_ir::IrArray& target_array,
    const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
    llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
    mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
    const TargetMachineFeatures& target_machine_features, DotInfo& dot_info) {
  int64_t num_batch_dims =
      dot.dot_dimension_numbers().lhs_batch_dimensions_size();

  // TODO(kramerb): Remove this limitation.
  if (num_batch_dims > 1) return false;

  // First reshape the inputs to make sure we only have one batch dimension.
  // This is a no-op bitcast because the operands have to be in row-major layout
  // (enforced in CpuLayoutAssignment), and the batch dimensions are the leading
  // dimensions (established by DotDecomposer and checked by
  // ValidateDotDimensionNumbers above).
  llvm_ir::IrArray lhs_array_reshaped =
      CollapseFirstNDims(b, lhs_array, num_batch_dims);
  llvm_ir::IrArray rhs_array_reshaped =
      CollapseFirstNDims(b, rhs_array, num_batch_dims);
  llvm_ir::IrArray target_array_reshaped =
      CollapseFirstNDims(b, target_array, num_batch_dims);

  DotDimensionNumbers adjusted_dim_numbers = dot.dot_dimension_numbers();
  adjusted_dim_numbers.clear_lhs_batch_dimensions();
  adjusted_dim_numbers.clear_rhs_batch_dimensions();

  // Create a DotInfo representing the batch of "inner" dot operations.
  dot_info.lhs_shape = DropFirstDim(lhs_array_reshaped.GetShape());
  dot_info.rhs_shape = DropFirstDim(rhs_array_reshaped.GetShape());
  dot_info.result_shape = DropFirstDim(target_array_reshaped.GetShape());
  dot_info.dim_nums = dot.dot_dimension_numbers();
  dot_info.dim_nums.clear_lhs_batch_dimensions();
  dot_info.dim_nums.clear_rhs_batch_dimensions();

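  // Shift the contracting dimensions down by the number of batch dimensions
  // that were dropped. E.g. a [B,M,K] x [B,K,N] batch dot contracts over LHS
  // dimension 2; the inner [M,K] x [K,N] dot contracts over dimension 1.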
  dot_info.dim_nums.set_lhs_contracting_dimensions(
      0, dot_info.dim_nums.lhs_contracting_dimensions(0) - num_batch_dims);
  dot_info.dim_nums.set_rhs_contracting_dimensions(
      0, dot_info.dim_nums.rhs_contracting_dimensions(0) - num_batch_dims);

  PrimitiveType type = target_array.GetShape().element_type();
  if (F32 != type) return false;

  if (ShapeUtil::IsScalar(dot_info.lhs_shape) ||
      ShapeUtil::IsScalar(dot_info.rhs_shape)) {
    // If the operands are scalar, don't emit any loops.
    return false;
  }

  DotImplementationStrategy impl_strategy = GetDotImplementationStrategy(
      dot.parent()->parent()->config(), dot_info, target_machine_features);

  return impl_strategy == DotImplementationStrategy::kEigen;
}

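// Emits code for a batch dot: either a single multi-threaded Eigen batch
// matmul, or a loop over the batch dimension that emits one non-batch dot per
// iteration. For example (shapes chosen for illustration), an f32[4,16,32] x
// f32[4,32,8] batch dot with batch dimensions {0} lowers to four [16,32] x
// [32,8] dots when the Eigen path is not taken.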
Status EmitBatchDotOperation(
    const HloInstruction& dot, const llvm_ir::IrArray& target_array,
    const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
    llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
    mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
    const TargetMachineFeatures& target_machine_features) {
  TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(dot.dot_dimension_numbers()));

  // First check whether the batch dot can be handled directly by the runtime;
  // otherwise lower it to a sequence of non-batch dot operations.
  DotInfo dot_info;
  if (ShouldUseMultiThreadedEigen(hlo_module_config) &&
      PotentiallyImplementedAsEigenMatmul(
          dot, target_array, lhs_array, rhs_array, executable_run_options_value,
          b, mlir_context, hlo_module_config, target_machine_features,
          dot_info)) {
    DotOpEmitter dot_emitter(dot_info, dot.name(), target_array, lhs_array,
                             rhs_array, /*addend_array=*/nullptr,
                             executable_run_options_value, b, mlir_context,
                             hlo_module_config, target_machine_features);

    return dot_emitter.EmitBatch();
  } else {
    // Lower a batch dot into a sequence of non-batch dot operations.

    int64_t num_batch_dims =
        dot.dot_dimension_numbers().lhs_batch_dimensions_size();

    // First reshape the inputs to make sure we only have one batch dimension.
    // This is a no-op bitcast because the operands have to be in row-major
    // layout (enforced in CpuLayoutAssignment), and the batch dimensions are
    // the leading dimensions (established by DotDecomposer and checked by
    // ValidateDotDimensionNumbers above).
    llvm_ir::IrArray lhs_array_reshaped =
        CollapseFirstNDims(b, lhs_array, num_batch_dims);
    llvm_ir::IrArray rhs_array_reshaped =
        CollapseFirstNDims(b, rhs_array, num_batch_dims);
    llvm_ir::IrArray target_array_reshaped =
        CollapseFirstNDims(b, target_array, num_batch_dims);

    int64_t batch_count = lhs_array_reshaped.GetShape().dimensions(0);

    KernelSupportLibrary ksl(b);

    return ksl.ForWithStatus(
        llvm_ir::IrName(&dot, "bdot"), /*start=*/0, /*end=*/batch_count,
        /*step=*/1, [&](llvm::Value* indvar) {
          DotDimensionNumbers adjusted_dim_numbers =
              dot.dot_dimension_numbers();
          adjusted_dim_numbers.clear_lhs_batch_dimensions();
          adjusted_dim_numbers.clear_rhs_batch_dimensions();

          // Create a DotInfo representing the "inner" non-batch dot operation.
          DotInfo dot_info;
          dot_info.lhs_shape = DropFirstDim(lhs_array_reshaped.GetShape());
          dot_info.rhs_shape = DropFirstDim(rhs_array_reshaped.GetShape());
          dot_info.result_shape =
              DropFirstDim(target_array_reshaped.GetShape());
          dot_info.dim_nums = dot.dot_dimension_numbers();
          dot_info.dim_nums.clear_lhs_batch_dimensions();
          dot_info.dim_nums.clear_rhs_batch_dimensions();

          dot_info.dim_nums.set_lhs_contracting_dimensions(
              0,
              dot_info.dim_nums.lhs_contracting_dimensions(0) - num_batch_dims);
          dot_info.dim_nums.set_rhs_contracting_dimensions(
              0,
              dot_info.dim_nums.rhs_contracting_dimensions(0) - num_batch_dims);

          llvm_ir::IrArray lhs_slice =
              SliceOutInnerArray(lhs_array_reshaped, /*batch_index=*/indvar, b);
          llvm_ir::IrArray rhs_slice =
              SliceOutInnerArray(rhs_array_reshaped, /*batch_index=*/indvar, b);
          llvm_ir::IrArray target_slice = SliceOutInnerArray(
              target_array_reshaped, /*batch_index=*/indvar, b);

          // Emit the inner non-batch dot operation.
          return EmitNonBatchDotOperation(
              dot_info, dot.name(), target_slice, lhs_slice, rhs_slice, nullptr,
              executable_run_options_value, b, mlir_context, hlo_module_config,
              target_machine_features);
        });
  }
}

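// Returns true if `instr` is a dot with at least one batch dimension.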
bool IsBatchDot(const HloInstruction& instr) {
  if (auto* dot_instr = DynCast<HloDotInstruction>(&instr)) {
    return dot_instr->dot_dimension_numbers().lhs_batch_dimensions_size() > 0;
  }

  return false;
}
}  // namespace

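// Returns true if the lowering strategy chosen for `dot_instr` can absorb a
// transposed operand directly, i.e., of the current strategies, everything
// except the tiled LLVM IR GEMM.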
bool DotImplementationCanHandleTranspose(
    const HloInstruction& dot_instr,
    const TargetMachineFeatures& target_machine_features) {
  DotImplementationStrategy impl_strategy =
      GetDotImplementationStrategy(dot_instr.parent()->parent()->config(),
                                   DotInfo(dot_instr), target_machine_features);

  return impl_strategy == DotImplementationStrategy::kNaiveLlvmIr ||
         impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemv ||
         impl_strategy == DotImplementationStrategy::kEigen;
}

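// Returns true if the operands and result of `dot_instr` must be assigned a
// row-major layout: always the case for batch dots, and for non-batch dots
// whenever the tiled LLVM IR GEMM or Eigen strategy is selected.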
bool DotOperandsAndResultMustHaveRowMajorLayout(
    const HloInstruction& dot_instr,
    const TargetMachineFeatures& target_machine_features) {
  // Batched dots require the batch dimensions to be major. DotDecomposer always
  // moves batch dimensions to the front of the shape, so force a row-major
  // layout.
  if (IsBatchDot(dot_instr)) {
    return true;
  }

  DotImplementationStrategy impl_strategy =
      GetDotImplementationStrategy(dot_instr.parent()->parent()->config(),
                                   DotInfo(dot_instr), target_machine_features);

  return impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemm ||
         impl_strategy == DotImplementationStrategy::kEigen;
}

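// Top-level entry point: emits code for `dot`, dispatching to the batch or
// non-batch lowering path.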
Status EmitDotOperation(const HloInstruction& dot,
                        const llvm_ir::IrArray& target_array,
                        const llvm_ir::IrArray& lhs_array,
                        const llvm_ir::IrArray& rhs_array,
                        const llvm_ir::IrArray* addend_array,
                        llvm::Value* executable_run_options_value,
                        llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context,
                        const HloModuleConfig& hlo_module_config,
                        const TargetMachineFeatures& target_machine_features) {
  // This routine assumes that the dot operation is not in a parallelized
  // enclosing computation.
  CHECK(dot.parent()->root_instruction()->outer_dimension_partitions().empty());

  if (IsBatchDot(dot)) {
    TF_RET_CHECK(addend_array == nullptr);
    return EmitBatchDotOperation(dot, target_array, lhs_array, rhs_array,
                                 executable_run_options_value, b, mlir_context,
                                 hlo_module_config, target_machine_features);
  }

  return EmitNonBatchDotOperation(DotInfo(dot), dot.name(), target_array,
                                  lhs_array, rhs_array, addend_array,
                                  executable_run_options_value, b, mlir_context,
                                  hlo_module_config, target_machine_features);
}
}  // namespace cpu
}  // namespace xla