/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"

#include <memory>
#include <vector>

#include "absl/strings/str_cat.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

using llvm_ir::SetToFirstInsertPoint;

namespace cpu {
namespace {
// Returns true if we should call into multi-threaded Eigen routines.
bool ShouldUseMultiThreadedEigen(const HloModuleConfig& config) {
  return config.debug_options().xla_cpu_multi_thread_eigen();
}

// Represents a dot operation. We use this in lieu of an `HloInstruction`
// because we want to be able to create this for the "inner" dot operation in a
// batch dot, for which there is no separate HLO instruction.
struct DotInfo {
  Shape lhs_shape;
  Shape rhs_shape;
  Shape result_shape;
  DotDimensionNumbers dim_nums;

  DotInfo() = default;

  explicit DotInfo(const HloInstruction& instr) {
    CHECK_EQ(instr.opcode(), HloOpcode::kDot);
    lhs_shape = instr.operand(0)->shape();
    rhs_shape = instr.operand(1)->shape();
    result_shape = instr.shape();
    dim_nums = instr.dot_dimension_numbers();
  }
};

// Dictates how a dot operation is implemented.
enum class DotImplementationStrategy {
  // The dot operation is lowered into LLVM IR that implements a naive nested
  // loop that computes the result one element at a time. This is our
  // "fallback"; we don't really want this to kick in for any non-trivial dot
  // operation.
  kNaiveLlvmIr,

  // The dot operation is lowered into LLVM IR that implements a tiled
  // Matrix*Vector operation. This strategy also allows fusing in a bias add
  // into the dot. The matrix can be row major or column major, both are
  // supported.
  kTiledLlvmIrGemv,

  // The dot operation is lowered into LLVM IR that implements a tiled
  // Matrix*Matrix operation. No fusions are supported. The two inputs
  // and the output have to be row major.
  kTiledLlvmIrGemm,

  // The dot operation is lowered to a linalg.matmul op, which is then lowered
  // to LLVM IR.
  kLinalgMatmul,

  // The dot operation is lowered into a call into an Eigen routine. No fusions
  // are supported today. The two inputs and the output have to be row major.
  // However, we do allow transposing either the LHS or the RHS as part of the
  // GEMM -- we expose this flexibility as flexibility in the contraction
  // dimensions, but we can also see this as flexibility in the input layouts.
  kEigen,
};

// Returns the implementation strategy for a dot with the configuration
// `dot_info`.
DotImplementationStrategy GetDotImplementationStrategy(
    const HloModuleConfig& config, const DotInfo& dot_info,
    const TargetMachineFeatures& target_machine_features);

// Helper class for emitting LLVM IR to perform the dot operation.
class DotOpEmitter {
 public:
  explicit DotOpEmitter(DotInfo dot_info, std::string dot_hlo_name,
                        const llvm_ir::IrArray& target_array,
                        const llvm_ir::IrArray& lhs_array,
                        const llvm_ir::IrArray& rhs_array,
                        const llvm_ir::IrArray* addend_array,
                        llvm::Value* executable_run_options_value,
                        llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context,
                        const HloModuleConfig& hlo_module_config,
                        const TargetMachineFeatures& target_machine_features);

  // Emits the IR to perform the dot operation.
  Status Emit();

  // Emits the IR to perform the batch dot operation.
  Status EmitBatch();

 private:
  // Emits instructions to perform a scalar dot product (a multiply of the
  // LHS and RHS) and store the results in the target.
  Status EmitScalarDot();

  // Emits a call to the CPU runtime to perform the matrix multiply.
  Status EmitCallToRuntime();

  // Emits a call to the CPU runtime to perform the batch matrix multiply.
  Status EmitCallToBatchRuntime();

  // Represents the dimensions of a matrix-matrix multiply operation.
  struct MatMultDims {
    // The number of rows in the LHS.
    int64_t m;

    // The number of columns in the LHS, which must also be equal to the
    // number of rows in the RHS.
    int64_t k;

    // The number of columns on the RHS.
    int64_t n;

    // True if the LHS matrix is column major.
    bool lhs_column_major;

    // True if the LHS contraction dimension is 1.
    bool lhs_canonical;

    // True if the RHS matrix is column major.
    bool rhs_column_major;

    // True if the RHS contraction dimension is 0.
    bool rhs_canonical;
  };

  // Get the MatMultDims instance for the dot product this DotOpEmitter
  // represents. Precondition: the dot is of rank 2 (and thus its operands are
  // of rank 2 as well).
  MatMultDims GetMatMultDims() const;

  // Get the MatMultDims instance for the dot product this DotOpEmitter
  // represents. Precondition: the dot is of rank 3 (and thus its operands are
  // of rank 3 as well).
  MatMultDims GetBatchMatMultDims() const;

  // Lowers the dot operation as a tiled Matrix*Vector loop.
  void EmitTiledLlvmIrGemv();

  // Lowers the dot operation as a tiled Matrix*Matrix loop.
  void EmitTiledLlvmIrGemm();

  // Lowers the dot operation through MLIR's linalg.matmul.
  Status EmitLinalgMatmul();

  // Lowers the dot operation as a naive nested loop that computes the result
  // one element at a time.
  void EmitNaiveLlvmIrGemm();

  // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector
  // registers.
  int64_t GetGemvTilingFactor() const {
    const int64_t kDefaultTilingFactor = 8;
    return options::LlvmIrGemvTilingFactor(hlo_module_config_)
        .value_or(kDefaultTilingFactor);
  }

  std::tuple<int64_t, int64_t, int64_t> GetGemmTileSize() const {
    // Tuned for broadwell - Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz
    //
    // TODO(b/80093688): Tune for other architectures and centralize this
    // information in one place.
    const std::tuple<int64_t, int64_t, int64_t> kDefaultTileSize =
        std::tuple<int64_t, int64_t, int64_t>(11, 9, 1);
    return options::LlvmIrGemmTileSize(hlo_module_config_)
        .value_or(kDefaultTileSize);
  }

  std::array<int64_t, 3> GetMlirGemmTileSize() const {
    // Tile by 4 x registers x register size. This was picked by running
    // small matmuls on Haswell and Skylake. There's a lot of room for
    // improvement here.
    constexpr int64_t kDefaultTileSizeForM = 4;
    int64_t elements_per_register =
        target_machine_features_.vector_register_num_elements(
            *b_->GetInsertBlock()->getParent(),
            dot_info_.result_shape.element_type());
    int64_t num_registers = target_machine_features_.vector_register_count(
        *b_->GetInsertBlock()->getParent());
    return {{kDefaultTileSizeForM, num_registers, elements_per_register}};
  }

  DotInfo dot_info_;
  std::string dot_hlo_name_;
  const llvm_ir::IrArray& target_array_;
  const llvm_ir::IrArray& lhs_array_;
  const llvm_ir::IrArray& rhs_array_;
  const llvm_ir::IrArray* addend_array_;
  llvm::Value* executable_run_options_value_;
  llvm::IRBuilder<>* b_;
  mlir::MLIRContext* mlir_context_;
  const HloModuleConfig& hlo_module_config_;
  const TargetMachineFeatures& target_machine_features_;
};
}  // namespace

DotOpEmitter::DotOpEmitter(
    DotInfo dot_info, std::string dot_hlo_name,
    const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array,
    const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array,
    llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
    mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
    const TargetMachineFeatures& target_machine_features)
    : dot_info_(std::move(dot_info)),
      dot_hlo_name_(std::move(dot_hlo_name)),
      target_array_(target_array),
      lhs_array_(lhs_array),
      rhs_array_(rhs_array),
      addend_array_(addend_array),
      executable_run_options_value_(executable_run_options_value),
      b_(b),
      mlir_context_(mlir_context),
      hlo_module_config_(hlo_module_config),
      target_machine_features_(target_machine_features) {}

Status DotOpEmitter::EmitLinalgMatmul() {
  Shape operand_shapes[] = {dot_info_.lhs_shape, dot_info_.rhs_shape};
  llvm::Value* operand_ptrs[] = {lhs_array_.GetBasePointer(),
                                 rhs_array_.GetBasePointer()};
  llvm::Value* target_ptr = target_array_.GetBasePointer();

  // Zero out the output buffer.
  int64_t size_bytes = ShapeUtil::ByteSizeOf(dot_info_.result_shape);
  b_->CreateMemSet(target_ptr, b_->getInt8(0), /*Size=*/size_bytes,
                   /*Align=*/llvm::MaybeAlign(1));

  std::string name =
      absl::StrCat("linalgMatMul_", dot_info_.result_shape.ToString(true), "_",
                   dot_info_.lhs_shape.ToString(true), "_",
                   dot_info_.rhs_shape.ToString(true));

  return EmitMlirFuncAndCall(
      mlir_context_, b_, dot_info_.result_shape, operand_shapes, target_ptr,
      operand_ptrs, name,
      [&](mlir::OpBuilder* builder, mlir::func::FuncOp function) {
        CHECK_EQ(dot_info_.dim_nums.lhs_contracting_dimensions_size(), 1);
        CHECK_EQ(dot_info_.dim_nums.rhs_contracting_dimensions_size(), 1);
        mlir::MLIRContext* context = builder->getContext();
        mlir::Value a = function.getArgument(0), b = function.getArgument(1),
                    c = function.getArgument(2);

        llvm::SmallVector<mlir::AffineExpr, 2> b_exprs(
            dot_info_.lhs_shape.rank());
        llvm::SmallVector<mlir::AffineExpr, 2> c_exprs(
            dot_info_.rhs_shape.rank());

        llvm::SmallVector<mlir::AffineExpr, 2> parallel_exprs;
        mlir::AffineExpr reduce_expr;
        for (int i = 0; i != dot_info_.result_shape.rank(); ++i) {
          parallel_exprs.push_back(mlir::getAffineDimExpr(i, context));
        }
        reduce_expr =
            mlir::getAffineDimExpr(dot_info_.result_shape.rank(), context);

        // The reduction expr is shared for both inputs.
        b_exprs[dot_info_.dim_nums.lhs_contracting_dimensions(0)] = reduce_expr;
        c_exprs[dot_info_.dim_nums.rhs_contracting_dimensions(0)] = reduce_expr;

        // Fill in the remaining parallel exprs.
        int par_expr_num = 0;
        for (auto* v : {&b_exprs, &c_exprs}) {
          for (auto& e : *v) {
            if (!e) {
              e = parallel_exprs[par_expr_num++];
            }
          }
        }

        llvm::SmallVector<llvm::StringRef, 4> iteratorTypes(
            parallel_exprs.size(), toString(mlir::IteratorType::Parallel));
        iteratorTypes.push_back(toString(mlir::IteratorType::Reduction));
        builder->create<mlir::linalg::GenericOp>(
            function.getLoc(),
            /*inputs=*/mlir::ValueRange{b, c},
            /*outputs=*/mlir::ValueRange{a},
            /*indexingMaps=*/
            mlir::AffineMap::inferFromExprList(
                {b_exprs, c_exprs, parallel_exprs}),
            /*iteratorTypes=*/iteratorTypes,
            [](mlir::OpBuilder& b, mlir::Location loc, mlir::ValueRange args) {
              mlir::ArithBuilder ab(b, loc);
              mlir::Value mul = ab.mul(args[0], args[1]);
              mlir::Value add = ab.add(mul, args[2]);
              b.create<mlir::linalg::YieldOp>(loc, add);
            });
        builder->create<mlir::func::ReturnOp>(function.getLoc());

        mlir::linalg::LinalgTilingOptions tilingOptions;
        tilingOptions = tilingOptions.setTileSizes(GetMlirGemmTileSize());
        // TODO: this has been retired upstream, reevaluate whether this
        // path really needs it or if it is even relevant anymore.
        // int64_t alignment =
        //     target_machine_features_.minimum_alignment_for_allocation(
        //         ShapeUtil::ByteSizeOf(dot_info_.result_shape));
        mlir::linalg::CodegenStrategy strategy;
        strategy
            .tile(mlir::linalg::GenericOp::getOperationName(), tilingOptions)
            // TODO: this has been retired upstream, reevaluate whether this
            // path really needs it or if it is even relevant anymore.
            // .promote(mlir::linalg::GenericOp::getOperationName(),
            //          mlir::linalg::LinalgPromotionOptions()
            //              .setAlignment(alignment)
            //              .setUseFullTileBuffersByDefault(true)
            //              .setUseAlloca(true))
            .vectorize(mlir::linalg::GenericOp::getOperationName())
            .vectorLowering(
                mlir::linalg::LinalgVectorLoweringOptions()
                    .setVectorTransformsOptions(
                        mlir::vector::VectorTransformsOptions()
                            .setVectorTransformsOptions(
                                mlir::vector::VectorContractLowering::
                                    OuterProduct))
                    .setVectorTransferToSCFOptions(
                        mlir::VectorTransferToSCFOptions().enableFullUnroll()));
        // TODO: this should be within a pass and we should be able to create a
        // nested OpPassManager.
        // Create a nested OpPassManager, populate the strategy and run.
        // mlir::OpPassManager dynamicPM("func.func");
        // strategy.configurePassPipeline(dynamicPM, function.getContext());
        // Propagate pass failure?
        // (void)mlir::runPipeline(dynamicPM, function);
        mlir::PassManager pm(function.getContext(),
                             function.getOperationName());
        strategy.configurePassPipeline(pm, function.getContext());
        // Propagate pass failure?
        (void)pm.run(function);
      });
}

void DotOpEmitter::EmitTiledLlvmIrGemm() {
  PrimitiveType primitive_type = dot_info_.result_shape.element_type();
  MatMultDims mat_mult_dims = GetMatMultDims();

  llvm::Value* lhs = lhs_array_.GetBasePointer();
  llvm::Value* rhs = rhs_array_.GetBasePointer();
  llvm::Value* target = target_array_.GetBasePointer();
  int64_t m = mat_mult_dims.m;
  int64_t k = mat_mult_dims.k;
  int64_t n = mat_mult_dims.n;

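  // If the LHS is stored column-major, fall back to the (A x B)^T = B^T x A^T
  // identity also used in EmitCallToRuntime below: swapping the operands and
  // the m/n extents emits the equivalent transposed product.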
  if (mat_mult_dims.lhs_column_major) {
    std::swap(lhs, rhs);
    std::swap(m, n);
  }

  int64_t size_bytes =
      m * n * ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
  b_->CreateMemSet(target, b_->getInt8(0), /*Size=*/size_bytes,
                   /*Align=*/llvm::MaybeAlign(1));

  int64_t max_target_vector_width =
      target_machine_features_.vector_register_num_elements(
          *b_->GetInsertBlock()->getParent(), primitive_type);

  int64_t tile_size_m, tile_size_k, tile_size_n_in_vector_width;
  std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) =
      GetGemmTileSize();

  EmitSmallGemm(
      /*scalar_type=*/primitive_type,
      /*m=*/m, /*k=*/k, /*n=*/n,
      /*max_vectorization_width=*/max_target_vector_width,
      /*max_vector_count=*/tile_size_n_in_vector_width,
      /*min_vectorization_width=*/std::min<int64_t>(4, max_target_vector_width),
      /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k, /*lhs=*/lhs,
      /*rhs=*/rhs, /*result=*/target, b_, hlo_module_config_);
}

void DotOpEmitter::EmitTiledLlvmIrGemv() {
  PrimitiveType primitive_type = dot_info_.result_shape.element_type();

  CHECK(primitive_util::IsFloatingPointType(primitive_type) ||
        primitive_util::IsIntegralType(primitive_type));

  MatMultDims mat_mult_dims = GetMatMultDims();
  bool is_column_major_matrix_vector_gemv = false;
  bool is_row_major_matrix_vector_gemv = false;

  int64_t m, k;
  bool swap_operands;

  if (mat_mult_dims.m == 1) {
    // Our emitters can only do Matrix*Vector (abbreviated as M*V) but when M=1
    // we actually want V*M. We implement V*M as follows (Tr(X) = Transpose of
    // X):
    //
    //   V*M = Tr(Tr(V*M))        // Tr(Tr(X)) == X
    //       = Tr(Tr(M) * Tr(V))  // Tr(A * B) == Tr(B) * Tr(A)
    //
    // Since transposing a vector is physically a no-op, this is really
    // equivalent to `Tr(M) * V`. We further implement Tr(M) by pretending that
    // M is row major if it is actually column major and vice-versa.
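    //
    // For example, a [1 x 8] * [8 x 5] product is handled as a GEMV with
    // k = 8 and m = 5, passing the 8 x 5 matrix as the "lhs" of the emitter
    // (hence swap_operands below).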

    bool rhs_effectively_column_major = mat_mult_dims.rhs_canonical
                                            ? mat_mult_dims.rhs_column_major
                                            : !mat_mult_dims.rhs_column_major;

    if (rhs_effectively_column_major) {
      k = mat_mult_dims.k;
      m = mat_mult_dims.n;

      // We set is_row_major_matrix_vector_gemv and not
      // is_column_major_matrix_vector_gemv to implement the Transpose trick
      // mentioned above.
      is_row_major_matrix_vector_gemv = true;
      swap_operands = true;
    } else {
      k = mat_mult_dims.k;
      m = mat_mult_dims.n;

      // We set is_column_major_matrix_vector_gemv and not
      // is_row_major_matrix_vector_gemv to implement the Transpose trick
      // mentioned above.
      is_column_major_matrix_vector_gemv = true;
      swap_operands = true;
    }
  }

  if (mat_mult_dims.n == 1) {
    bool lhs_effectively_column_major = mat_mult_dims.lhs_canonical
                                            ? mat_mult_dims.lhs_column_major
                                            : !mat_mult_dims.lhs_column_major;

    if (lhs_effectively_column_major) {
      m = mat_mult_dims.m;
      k = mat_mult_dims.k;
      is_column_major_matrix_vector_gemv = true;
      swap_operands = false;
    } else {
      m = mat_mult_dims.m;
      k = mat_mult_dims.k;
      is_row_major_matrix_vector_gemv = true;
      swap_operands = false;
    }
  }

  CHECK(is_column_major_matrix_vector_gemv || is_row_major_matrix_vector_gemv);

  int64_t tiling_factor = GetGemvTilingFactor();
  CHECK_GT(tiling_factor, 0);

  llvm::Value* result_op = target_array_.GetBasePointer();
  llvm::Value* lhs_op =
      swap_operands ? rhs_array_.GetBasePointer() : lhs_array_.GetBasePointer();
  llvm::Value* rhs_op =
      swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer();

  const int target_vector_register_element_size =
      target_machine_features_.vector_register_num_elements(
          *b_->GetInsertBlock()->getParent(), primitive_type);

  // We may not always know the vector register size for the target we're
  // compiling against, in which case target_vector_register_element_size is 0.
  // In these cases we choose a default LLVM IR register size.
  const int kUnknownTargetVectorRegisterSize = 4;
  const int vector_register_element_size =
      target_vector_register_element_size == 0
          ? kUnknownTargetVectorRegisterSize
          : target_vector_register_element_size;

  if (is_column_major_matrix_vector_gemv) {
    VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m
            << " and k = " << k;
    EmitColumnMajorGemv(
        /*scalar_type=*/primitive_type,
        /*tile_rows=*/vector_register_element_size, /*tile_cols=*/tiling_factor,
        /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op,
        /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr,
        /*result=*/result_op, b_, hlo_module_config_);
  } else {
    VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m
            << " and k = " << k;
    EmitRowMajorGemv(
        /*scalar_type=*/primitive_type,
        /*tile_rows=*/tiling_factor,
        /*tile_cols=*/vector_register_element_size,
        /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op,
        /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr,
        /*result=*/result_op, b_, hlo_module_config_);
  }
}

Status DotOpEmitter::Emit() {
  // The dot operation performs a sum of products over dimension 0 of the left
  // hand side operand and dimension 1 of the right hand side operand.
  //
  // Let the shapes of lhs and rhs be defined as below:
  //
  //   lhs = [L{n-1} x L{n-2} x ... L{0}]
  //   rhs = [R{m-1} x R{m-2} x ... R{0}]
  //
  // The sum-of-products dimension in the lhs has size L{0} and the dimension in
  // the rhs has size R{1}. Necessarily, then:
  //
  //   L{0} == R{1}
  //
  // The output of the operation has the following shape:
  //
  //   output = [L{n-1} x L{n-2} x ... L{1} x R{m-1} x R{m-2} x ... R{2} x R{0}]
  //
  // To perform the operation we construct a loop nest with one for-loop for
  // each dimension of the output. Inside this loop nest is another for-loop
  // which performs the sum-of-products (the reduction loop) before storing
  // the result in the output buffer.

  const Shape& lhs_shape = lhs_array_.GetShape();
  const Shape& rhs_shape = rhs_array_.GetShape();

  if (ShapeUtil::IsScalar(lhs_shape) || ShapeUtil::IsScalar(rhs_shape)) {
    // If the operands are scalar, don't emit any loops.
    TF_RET_CHECK(ShapeUtil::IsScalar(lhs_shape) &&
                 ShapeUtil::IsScalar(rhs_shape));
    return EmitScalarDot();
  }

  switch (GetDotImplementationStrategy(hlo_module_config_, dot_info_,
                                       target_machine_features_)) {
    case DotImplementationStrategy::kNaiveLlvmIr:
      EmitNaiveLlvmIrGemm();
      return OkStatus();

    case DotImplementationStrategy::kTiledLlvmIrGemv:
      EmitTiledLlvmIrGemv();
      return OkStatus();

    case DotImplementationStrategy::kTiledLlvmIrGemm:
      EmitTiledLlvmIrGemm();
      return OkStatus();

    case DotImplementationStrategy::kLinalgMatmul:
      return EmitLinalgMatmul();

    case DotImplementationStrategy::kEigen:
      return EmitCallToRuntime();
  }
}

Status DotOpEmitter::EmitBatch() {
  // The dot operation performs a sum of products over dimension 0 of the left
  // hand side operand and dimension 1 of the right hand side operand.
  //
  // Let the shapes of lhs and rhs be defined as below:
  //
  //   lhs = [L{n-1} x L{n-2} x ... L{0}]
  //   rhs = [R{m-1} x R{m-2} x ... R{0}]
  //
  // The sum-of-products dimension in the lhs has size L{0} and the dimension in
  // the rhs has size R{1}. Necessarily, then:
  //
  //   L{0} == R{1}
  //
  // The output of the operation has the following shape:
  //
  //   output = [L{n-1} x L{n-2} x ... L{1} x R{m-1} x R{m-2} x ... R{2} x R{0}]
  //
  // Unlike Emit(), the batch variant does not build a loop nest here; the
  // entire batch is handed to a runtime batch matmul call.

  return EmitCallToBatchRuntime();
}

void DotOpEmitter::EmitNaiveLlvmIrGemm() {
  CHECK_EQ(addend_array_, nullptr);

  const Shape& lhs_shape = lhs_array_.GetShape();
  const Shape& rhs_shape = rhs_array_.GetShape();
  const DotDimensionNumbers& dim_nums = dot_info_.dim_nums;

  // Reduce along dimension 0 of the LHS and 1 of the RHS. Vectors are a special
  // case where the reduction dimension is 0 for both LHS and RHS. This results
  // in a vector dot product producing a scalar.
  int64_t lhs_reduction_dimension = dim_nums.lhs_contracting_dimensions(0);
  int64_t rhs_reduction_dimension = dim_nums.rhs_contracting_dimensions(0);

  // Verify the reduction dimension in the two operands are the same size.
  CHECK_EQ(lhs_shape.dimensions(lhs_reduction_dimension),
           rhs_shape.dimensions(rhs_reduction_dimension));

  bool lhs_reduction_along_minor_dimension =
      lhs_reduction_dimension == LayoutUtil::Minor(lhs_shape.layout(), 0);
  bool rhs_reduction_along_minor_dimension =
      rhs_reduction_dimension == LayoutUtil::Minor(rhs_shape.layout(), 0);

  // Create loop nests which loop through the LHS operand dimensions and the RHS
  // operand dimensions. The reduction dimension of the LHS and RHS are handled
  // in a separate innermost loop which performs the sum of products.
  llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(dot_hlo_name_), b_);
  std::vector<llvm::Value*> lhs_multi_index =
      loop_nest.EmitOperandArrayLoopNest(
          lhs_array_, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs");
  std::vector<llvm::Value*> rhs_multi_index =
      loop_nest.EmitOperandArrayLoopNest(
          rhs_array_, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs");

  // Create the loop which does the sum of products reduction.
  //
  // The prevent_unrolling bit is working around a deficiency in LLVM's loop
  // vectorization pipeline, wherein in some cases unrolling a loop can prevent
  // effective vectorization. Since we know that the IR we generate when
  // reducing across the minor dimension in both LHS and RHS is vectorized well
  // by the loop vectorizer, we block unrolling in that case to stop loop unroll
  // from messing up the vectorization.
  std::unique_ptr<llvm_ir::ForLoop> reduction_loop = loop_nest.AddLoop(
      0, lhs_shape.dimensions(lhs_reduction_dimension), "reduction",
      /*unroll_mode=*/
      (lhs_reduction_along_minor_dimension &&
       rhs_reduction_along_minor_dimension)
          ? xla::llvm_ir::UnrollMode::kNoUnroll
          : xla::llvm_ir::UnrollMode::kDefaultUnroll);

  // The final entry in the rhs and lhs indexes is the indvar of the
  // reduction loop.
  lhs_multi_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue();
  llvm_ir::IrArray::Index lhs_index(lhs_multi_index, lhs_shape,
                                    b_->getInt64Ty());
  rhs_multi_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue();
  llvm_ir::IrArray::Index rhs_index(rhs_multi_index, rhs_shape,
                                    b_->getInt64Ty());

  // For computing the sum of products we alloca a single location to store the
  // dot product result as we accumulate it within the reduction loop. After the
  // reduction loop we load the result and store into the output array.

  // Function entry basic block.
  // - Emit alloca for accumulator
  llvm::Function* func = reduction_loop->GetPreheaderBasicBlock()->getParent();
  SetToFirstInsertPoint(&func->getEntryBlock(), b_);
  llvm::Type* accum_type = target_array_.GetElementLlvmType();
  llvm::Value* accum_address =
      b_->CreateAlloca(accum_type, /*ArraySize=*/nullptr, "accum_address");

  // Preheader basic block of reduction loop:
  // - Initialize accumulator to zero.
  llvm::BasicBlock* preheader_bb = reduction_loop->GetPreheaderBasicBlock();
  b_->SetInsertPoint(preheader_bb->getTerminator());

  b_->CreateStore(llvm::Constant::getNullValue(accum_type), accum_address);

  // Body basic block of reduction loop:
  // - Load elements from lhs and rhs array.
  // - Multiply lhs-element and rhs-element.
  // - Load accumulator and add to product.
  // - Store sum back into accumulator.
  SetToFirstInsertPoint(reduction_loop->GetBodyBasicBlock(), b_);

  llvm::Value* lhs_element = lhs_array_.EmitReadArrayElement(lhs_index, b_);
  llvm::Value* rhs_element = rhs_array_.EmitReadArrayElement(rhs_index, b_);

  llvm::Value* accum = b_->CreateLoad(accum_type, accum_address);
  llvm::Value* updated_accum;
  if (ShapeUtil::ElementIsComplex(lhs_shape)) {
    auto real = [&](llvm::Value* x) { return b_->CreateExtractValue(x, {0}); };
    auto imag = [&](llvm::Value* x) { return b_->CreateExtractValue(x, {1}); };
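    // Complex multiply-accumulate: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i.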
    llvm::Value* product_real =
        b_->CreateFSub(b_->CreateFMul(real(lhs_element), real(rhs_element)),
                       b_->CreateFMul(imag(lhs_element), imag(rhs_element)));
    llvm::Value* product_imag =
        b_->CreateFAdd(b_->CreateFMul(real(lhs_element), imag(rhs_element)),
                       b_->CreateFMul(imag(lhs_element), real(rhs_element)));
    updated_accum = b_->CreateInsertValue(
        accum, b_->CreateFAdd(real(accum), product_real), {0});
    updated_accum = b_->CreateInsertValue(
        updated_accum, b_->CreateFAdd(imag(accum), product_imag), {1});
  } else if (ShapeUtil::ElementIsIntegral(lhs_shape)) {
    llvm::Value* product = b_->CreateMul(lhs_element, rhs_element);
    updated_accum = b_->CreateAdd(accum, product);
  } else if (lhs_shape.element_type() == PRED) {
    llvm::Value* product = b_->CreateAnd(lhs_element, rhs_element);
    updated_accum = b_->CreateOr(accum, product);
  } else {
    llvm::Value* product = b_->CreateFMul(lhs_element, rhs_element);
    updated_accum = b_->CreateFAdd(accum, product);
  }
  b_->CreateStore(updated_accum, accum_address);

  // Exit basic block of reduction loop.
  // - Load accumulator value (the result).
  // - Store into output array.
  SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), b_);

  llvm::Value* result = b_->CreateLoad(accum_type, accum_address);

  // Create index into target address. The target index is the concatenation of
  // the lhs and rhs indexes with the reduction dimensions removed. The terms
  // from the lhs index are the major dimensions of the target index, so we add
  // them first.
  std::vector<llvm::Value*> target_multi_index;
  for (int dimension = 0; dimension < lhs_index.size(); ++dimension) {
    if (dimension != lhs_reduction_dimension) {
      target_multi_index.push_back(lhs_index[dimension]);
    }
  }
  for (int dimension = 0; dimension < rhs_index.size(); ++dimension) {
    if (dimension != rhs_reduction_dimension) {
      target_multi_index.push_back(rhs_index[dimension]);
    }
  }

  llvm_ir::IrArray::Index target_index(
      target_multi_index, target_array_.GetShape(), lhs_index.GetType());
  target_array_.EmitWriteArrayElement(target_index, result, b_);

  // Set the IR builder insert point to the exit basic block of the outer most
  // loop.
  b_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
}

Status DotOpEmitter::EmitScalarDot() {
  // A scalar dot is just a scalar multiply.
  llvm::Value* result;
  // Use the same index_type for all tensor accesses in the same kernel.
  llvm::Type* index_type = b_->getInt64Ty();
  llvm_ir::IrArray::Index element_index(index_type);
  llvm::Value* lhs_value =
      lhs_array_.EmitReadArrayElement(/*index=*/element_index, b_);
  llvm::Value* rhs_value =
      rhs_array_.EmitReadArrayElement(/*index=*/element_index, b_);
  if (ShapeUtil::ElementIsComplex(lhs_array_.GetShape())) {
    auto get_real = [&](llvm::Value* x) {
      return b_->CreateExtractValue(x, {0});
    };

    auto get_imag = [&](llvm::Value* x) {
      return b_->CreateExtractValue(x, {1});
    };

    llvm::Value* real = b_->CreateFSub(
        b_->CreateFMul(get_real(lhs_value), get_real(rhs_value)),
        b_->CreateFMul(get_imag(lhs_value), get_imag(rhs_value)));
    llvm::Value* imag = b_->CreateFAdd(
        b_->CreateFMul(get_real(lhs_value), get_imag(rhs_value)),
        b_->CreateFMul(get_imag(lhs_value), get_real(rhs_value)));
    result = llvm::ConstantAggregateZero::get(lhs_array_.GetElementLlvmType());
    result = b_->CreateInsertValue(result, real, {0});
    result = b_->CreateInsertValue(result, imag, {1});
  } else {
    result = b_->CreateFMul(lhs_value, rhs_value);
  }
  target_array_.EmitWriteArrayElement(/*index=*/element_index, result, b_);
  return OkStatus();
}

Status DotOpEmitter::EmitCallToRuntime() {
  // The signature of the Eigen runtime matmul function is:
  //
  //   (void)(void* run_options, float* out, float* lhs, float* rhs,
  //          int64_t m, int64_t n, int64_t k, int32_t transpose_lhs,
  //          int32_t transpose_rhs);
  // The two transpose_... parameters are actually booleans, but we use int32_t
  // to avoid target-dependent calling convention details.

  bool multi_threaded = ShouldUseMultiThreadedEigen(hlo_module_config_);
  bool use_mkl_dnn = hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn();
  bool use_acl = hlo_module_config_.debug_options().xla_cpu_use_acl();
  PrimitiveType type = target_array_.GetShape().element_type();
  llvm::Function* function = b_->GetInsertBlock()->getParent();
  llvm::Module* module = function->getParent();
  llvm::Type* float_type;
  const char* fn_name;
  switch (type) {
    case F16:
      fn_name = multi_threaded
                    ? runtime::kEigenMatMulF16SymbolName
                    : runtime::kEigenSingleThreadedMatMulF16SymbolName;
      float_type = b_->getHalfTy();
      break;
    case F32:
      fn_name =
          multi_threaded
              ? (use_mkl_dnn ? runtime::kMKLMatMulF32SymbolName
                             : (use_acl ? runtime::kACLMatMulF32SymbolName
                                        : runtime::kEigenMatMulF32SymbolName))
              : (use_mkl_dnn
                     ? runtime::kMKLSingleThreadedMatMulF32SymbolName
                     : runtime::kEigenSingleThreadedMatMulF32SymbolName);
      float_type = b_->getFloatTy();
      break;
    case F64:
      fn_name = multi_threaded
                    ? (use_mkl_dnn ? runtime::kMKLMatMulF64SymbolName
                                   : runtime::kEigenMatMulF64SymbolName)
                    : (use_mkl_dnn
                           ? runtime::kMKLSingleThreadedMatMulF64SymbolName
                           : runtime::kEigenSingleThreadedMatMulF64SymbolName);
      float_type = b_->getDoubleTy();
      break;
    case C64:
      fn_name = multi_threaded
                    ? runtime::kEigenMatMulC64SymbolName
                    : runtime::kEigenSingleThreadedMatMulC64SymbolName;
      float_type = llvm_ir::PrimitiveTypeToIrType(C64, module);
      break;
    case C128:
      fn_name = multi_threaded
                    ? runtime::kEigenMatMulC128SymbolName
                    : runtime::kEigenSingleThreadedMatMulC128SymbolName;
      float_type = llvm_ir::PrimitiveTypeToIrType(C128, module);
      break;
    case S32:
      fn_name = multi_threaded
                    ? runtime::kEigenMatMulS32SymbolName
                    : runtime::kEigenSingleThreadedMatMulS32SymbolName;
      float_type = b_->getInt32Ty();
      break;
    default:
      return Unimplemented("Invalid type %s for dot operation",
                           PrimitiveType_Name(type));
  }

  llvm::Type* float_ptr_type = float_type->getPointerTo();
  llvm::Type* int64_type = b_->getInt64Ty();
  llvm::Type* int32_type = b_->getInt32Ty();
  llvm::Type* int8_ptr_type = b_->getInt8Ty()->getPointerTo();
  llvm::FunctionType* matmul_type = llvm::FunctionType::get(
      b_->getVoidTy(),
      {int8_ptr_type, float_ptr_type, float_ptr_type, float_ptr_type,
       int64_type, int64_type, int64_type, int32_type, int32_type},
      /*isVarArg=*/false);

  llvm::FunctionCallee matmul_func =
      module->getOrInsertFunction(fn_name, matmul_type);
  if (auto* fn = llvm::dyn_cast<llvm::Function>(matmul_func.getCallee())) {
    fn->setCallingConv(llvm::CallingConv::C);
    fn->setDoesNotThrow();
    fn->setOnlyAccessesArgMemory();
  }

  // The Eigen runtime function expects column-major layout. If the matrices are
  // row major, then use the following identity to compute the product:
  //
  //   (A x B)^T = B^T x A^T
  //
  // The connection between this identity and memory layout is that the
  // transpose operation can also be considered as an operation that changes the
  // memory layout of a matrix from row-major to column-major or vice versa.
  //
  // Effectively this involves swapping the 'lhs' with 'rhs' and 'm' with 'n'.

  MatMultDims mat_mult_dims = GetMatMultDims();

  CHECK_EQ(mat_mult_dims.lhs_column_major, mat_mult_dims.rhs_column_major);

  const llvm_ir::IrArray* lhs = &lhs_array_;
  const llvm_ir::IrArray* rhs = &rhs_array_;
  bool transpose_lhs = !mat_mult_dims.lhs_canonical;
  bool transpose_rhs = !mat_mult_dims.rhs_canonical;

  if (!mat_mult_dims.lhs_column_major) {
    std::swap(mat_mult_dims.m, mat_mult_dims.n);
    std::swap(lhs, rhs);
    std::swap(transpose_lhs, transpose_rhs);
  }

  b_->CreateCall(
      matmul_func,
      {b_->CreateBitCast(executable_run_options_value_, int8_ptr_type),
       b_->CreateBitCast(target_array_.GetBasePointer(), float_ptr_type),
       b_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type),
       b_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type),
       b_->getInt64(mat_mult_dims.m), b_->getInt64(mat_mult_dims.n),
       b_->getInt64(mat_mult_dims.k), b_->getInt32(transpose_lhs),
       b_->getInt32(transpose_rhs)});
  return OkStatus();
}

Status DotOpEmitter::EmitCallToBatchRuntime() {
  // The signature of the runtime batch matmul function is:
  //
  //   (void)(void* run_options, float* out, float* lhs, float* rhs,
  //          int64_t m, int64_t n, int64_t k, int64_t batch_size,
  //          int32_t transpose_lhs, int32_t transpose_rhs);
  // The two transpose_... parameters are actually booleans, but we use int32_t
  // to avoid target-dependent calling convention details.

  PrimitiveType type = target_array_.GetShape().element_type();
  bool use_acl = hlo_module_config_.debug_options().xla_cpu_use_acl();
  llvm::Function* function = b_->GetInsertBlock()->getParent();
  llvm::Module* module = function->getParent();
  llvm::Type* float_type;
  const char* fn_name;
  switch (type) {
    case F32:
      fn_name = use_acl ? runtime::kACLBatchMatMulF32SymbolName
                        : runtime::kEigenBatchMatMulF32SymbolName;

      float_type = b_->getFloatTy();
      break;
    default:
      return Unimplemented("Invalid type %s for dot operation",
                           PrimitiveType_Name(type));
  }

  llvm::Type* float_ptr_type = float_type->getPointerTo();
  llvm::Type* int64_type = b_->getInt64Ty();
  llvm::Type* int32_type = b_->getInt32Ty();
  llvm::Type* int8_ptr_type = b_->getInt8Ty()->getPointerTo();
  llvm::FunctionType* matmul_type = llvm::FunctionType::get(
      b_->getVoidTy(),
      {int8_ptr_type, float_ptr_type, float_ptr_type, float_ptr_type,
       int64_type, int64_type, int64_type, int64_type, int32_type, int32_type},
      /*isVarArg=*/false);

  llvm::FunctionCallee matmul_func =
      module->getOrInsertFunction(fn_name, matmul_type);
  if (auto* fn = llvm::dyn_cast<llvm::Function>(matmul_func.getCallee())) {
    fn->setCallingConv(llvm::CallingConv::C);
    fn->setDoesNotThrow();
    fn->setOnlyAccessesArgMemory();
  }

  // The ACL runtime function expects column-major layout. If the matrices are
  // row major, then use the following identity to compute the product:
  //
  //   (A x B)^T = B^T x A^T
  //
  // The connection between this identity and memory layout is that the
  // transpose operation can also be considered as an operation that changes the
  // memory layout of a matrix from row-major to column-major or vice versa.
  //
  // Effectively this involves swapping the 'lhs' with 'rhs' and 'm' with 'n'.

  MatMultDims mat_mult_dims = GetBatchMatMultDims();
  CHECK_EQ(mat_mult_dims.lhs_column_major, mat_mult_dims.rhs_column_major);

  const llvm_ir::IrArray* lhs = &lhs_array_;
  const llvm_ir::IrArray* rhs = &rhs_array_;
  bool transpose_lhs = !mat_mult_dims.lhs_canonical;
  bool transpose_rhs = !mat_mult_dims.rhs_canonical;
  const Shape& lhs_shape = lhs_array_.GetShape();

  if (!mat_mult_dims.lhs_column_major) {
    std::swap(mat_mult_dims.m, mat_mult_dims.n);
    std::swap(lhs, rhs);
    std::swap(transpose_lhs, transpose_rhs);
  }

  VLOG(1) << "Batch dot emitted with runtime:" << fn_name;

  b_->CreateCall(
      matmul_func,
      {b_->CreateBitCast(executable_run_options_value_, int8_ptr_type),
       b_->CreateBitCast(target_array_.GetBasePointer(), float_ptr_type),
       b_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type),
       b_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type),
       b_->getInt64(mat_mult_dims.m), b_->getInt64(mat_mult_dims.n),
       b_->getInt64(mat_mult_dims.k), b_->getInt64(lhs_shape.dimensions(0)),
       b_->getInt32(static_cast<uint32_t>(transpose_lhs)),
       b_->getInt32(static_cast<uint32_t>(transpose_rhs))});
  return OkStatus();
}

DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
  CHECK_LE(dot_info_.result_shape.dimensions_size(), 2);

  const Shape& lhs_shape = lhs_array_.GetShape();
  const Shape& rhs_shape = rhs_array_.GetShape();
  const DotDimensionNumbers& dim_nums = dot_info_.dim_nums;

  auto is_column_major = [](const Shape& shape) {
    return shape.rank() > 1 && LayoutUtil::Minor(shape.layout(), 0) == 0;
  };

  // Non-contracting dots should never make it here.
  CHECK_GE(dim_nums.lhs_contracting_dimensions_size(), 0);
  CHECK_GE(dim_nums.rhs_contracting_dimensions_size(), 0);

  return {
      /*m=*/lhs_shape.rank() <= 1
          ? 1LL
          : lhs_shape.dimensions(1LL - dim_nums.lhs_contracting_dimensions(0)),
      /*k=*/lhs_shape.dimensions(dim_nums.lhs_contracting_dimensions(0)),
      /*n=*/rhs_shape.rank() <= 1
          ? 1LL
          : rhs_shape.dimensions(1LL - dim_nums.rhs_contracting_dimensions(0)),
      /*lhs_column_major=*/is_column_major(lhs_shape),
      /*lhs_canonical=*/lhs_shape.rank() <= 1 ||
          dim_nums.lhs_contracting_dimensions(0) == 1,
      /*rhs_column_major=*/is_column_major(rhs_shape),
      /*rhs_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 0};
}

DotOpEmitter::MatMultDims DotOpEmitter::GetBatchMatMultDims() const {
  CHECK_LE(dot_info_.result_shape.dimensions_size(), 2);

  const Shape& lhs_shape = lhs_array_.GetShape();
  const Shape& rhs_shape = rhs_array_.GetShape();
  const DotDimensionNumbers& dim_nums = dot_info_.dim_nums;

  auto is_column_major = [](const Shape& shape) {
    return shape.rank() > 1 && LayoutUtil::Minor(shape.layout(), 0) == 0;
  };

  // Non-contracting dots should never make it here.
  CHECK_GE(dim_nums.lhs_contracting_dimensions_size(), 0);
  CHECK_GE(dim_nums.rhs_contracting_dimensions_size(), 0);

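  // Note that dim_nums here describe the "inner" (batch-stripped) dot while
  // lhs_shape and rhs_shape are the full batched operand shapes, so the index
  // arithmetic below offsets the contracting dimensions by the single leading
  // batch dimension.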
  return {
      /*m=*/lhs_shape.rank() <= 1
          ? 1LL
          : lhs_shape.dimensions(2LL - dim_nums.lhs_contracting_dimensions(0)),
      /*k=*/lhs_shape.dimensions(1LL + dim_nums.lhs_contracting_dimensions(0)),
      /*n=*/rhs_shape.rank() <= 1
          ? 1LL
          : rhs_shape.dimensions(2LL - dim_nums.rhs_contracting_dimensions(0)),
      /*lhs_column_major=*/is_column_major(lhs_shape),
      /*lhs_canonical=*/lhs_shape.rank() <= 1 ||
          dim_nums.lhs_contracting_dimensions(0) == 1,
      /*rhs_column_major=*/is_column_major(rhs_shape),
      /*rhs_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 0};
}

// For vector-matrix dot products, it is always profitable to make the Rhs
// column major.
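// Returns the operand number of `hlo` that would benefit from a column-major
// layout, or an empty optional if there is none.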
std::optional<int64_t> ProfitableToMakeDotOperandColumnMajor(
    const HloInstruction& hlo) {
  if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() <= 1) {
    if (hlo.operand(0)->shape().rank() != 1 ||
        hlo.dot_dimension_numbers().rhs_contracting_dimensions(0) != 0) {
      return {};
    }

    // Don't bother if the other operand is tiny; switching to column major
    // wouldn't use tiling.
    constexpr int kColumnMajorThresholdInBytes = 32;
    int64_t lhs_size =
        ShapeUtil::ByteSizeOfPrimitiveType(hlo.shape().element_type()) *
        ShapeUtil::ElementsIn(hlo.operand(0)->shape());
    if (lhs_size < kColumnMajorThresholdInBytes) {
      return {};
    }

    return 1;
  }

  if (hlo.IsOutputFusion()) {
    auto* fusion_root =
        hlo.fused_instructions_computation()->root_instruction();
    if (fusion_root->opcode() != HloOpcode::kAdd) {
      return {};
    }

    for (auto* fusion_root_op : fusion_root->operands()) {
      if (fusion_root_op->opcode() != HloOpcode::kDot) {
        continue;
      }
      if (auto operand_num =
              ProfitableToMakeDotOperandColumnMajor(*fusion_root_op)) {
        auto* operand = fusion_root_op->operand(*operand_num);
        if (operand->opcode() == HloOpcode::kParameter &&
            operand->user_count() == 1) {
          return operand->parameter_number();
        }
      }
    }
  }

  return {};
}

namespace {
// Return whether the given shape is rank 2.
bool IsRank2(const Shape& shape) { return shape.rank() == 2; }

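// Returns true if `layout` is dense and has no tiling.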
bool IsSimpleLayout(const Layout& layout) {
  return layout.tiles().empty() && LayoutUtil::IsDense(layout);
}

// In a gemm operation where output = lhs * rhs, check whether the given shapes
// are valid for the operation.
bool AreGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
                   const Shape& output_shape,
                   const TargetMachineFeatures& target_machine_features) {
  CHECK(!lhs_shape.has_layout() || IsSimpleLayout(lhs_shape.layout()))
      << lhs_shape.DebugString();
  CHECK(!rhs_shape.has_layout() || IsSimpleLayout(rhs_shape.layout()))
      << rhs_shape.DebugString();
  CHECK(!output_shape.has_layout() || IsSimpleLayout(output_shape.layout()))
      << output_shape.DebugString();

  switch (output_shape.element_type()) {
    case F16:
    case F32:
    case F64:
    case C64:
    case C128:
    case S32:
      return IsRank2(lhs_shape) && IsRank2(rhs_shape) && IsRank2(output_shape);
    default:
      return false;
  }
}

bool IsAlignedGemm(const DotInfo& dot_info,
                   const TargetMachineFeatures& target_machine_features) {
  if (ShapeUtil::IsZeroElementArray(dot_info.lhs_shape) ||
      ShapeUtil::IsZeroElementArray(dot_info.rhs_shape)) {
    return false;
  }

  return AreGemmShapes(dot_info.lhs_shape, dot_info.rhs_shape,
                       dot_info.result_shape, target_machine_features);
}

bool CanEmitTiledLlvmIrGemm(
    const HloModuleConfig& config, const DotInfo& dot_info,
    const TargetMachineFeatures& target_machine_features) {
  CHECK(IsAlignedGemm(dot_info, target_machine_features));

  if (ShouldUseMultiThreadedEigen(config)) {
    return false;
  }

  int m = dot_info.result_shape.dimensions(0);
  int k = dot_info.lhs_shape.dimensions(
      dot_info.dim_nums.lhs_contracting_dimensions(0));
  int n = dot_info.result_shape.dimensions(1);

  if (!options::ForceEnableExperimentalLlvmIrGemm(config)) {
    // TODO(sanjoy): We should make these numbers micro-arch specific.
    bool small_gemm =
        k <= 128 && ((m <= 32 && n <= 128) || (m <= 128 && n <= 32));
    if (!small_gemm) {
      return false;
    }
  }

  bool lhs_canonical = dot_info.dim_nums.lhs_contracting_dimensions(0) == 1;
  bool rhs_canonical = dot_info.dim_nums.rhs_contracting_dimensions(0) == 0;

  if (!(lhs_canonical && rhs_canonical)) {
    return false;
  }

  if (dot_info.result_shape.element_type() == F16 ||
      dot_info.result_shape.element_type() == C64 ||
      dot_info.result_shape.element_type() == C128) {
    // TODO(sanjoy): This is probably easy to fix, but I want to keep the CL
    // adding this comment NFC.
    return false;
  }

  return true;
}

DotImplementationStrategy GetDotImplementationStrategy(
    const HloModuleConfig& config, const DotInfo& dot_info,
    const TargetMachineFeatures& target_machine_features) {
  PrimitiveType element_type = dot_info.result_shape.element_type();
  // Any Matrix-Vector product of floating point or integral type, or
  // a transpose-dot fusion of the same can be lowered to a tiled LLVM
  // IR implementation.
  if ((dot_info.result_shape.dimensions_size() <= 1 ||
       (dot_info.result_shape.dimensions_size() == 2 &&
        (dot_info.result_shape.dimensions(0) == 1 ||
         dot_info.result_shape.dimensions(1) == 1))) &&
      (primitive_util::IsFloatingPointType(element_type) ||
       primitive_util::IsIntegralType(element_type))) {
    return DotImplementationStrategy::kTiledLlvmIrGemv;
  }

  // MatMul smaller than 3x3 should use naive nested loop.
  if ((dot_info.lhs_shape.dimensions_size() <= 1 ||
       (dot_info.lhs_shape.dimensions_size() == 2 &&
        (dot_info.lhs_shape.dimensions(0) <= 3 ||
         dot_info.lhs_shape.dimensions(1) <= 3))) &&
      (dot_info.rhs_shape.dimensions_size() <= 1 ||
       (dot_info.rhs_shape.dimensions_size() == 2 &&
        (dot_info.rhs_shape.dimensions(0) <= 3 ||
         dot_info.rhs_shape.dimensions(1) <= 3))) &&
      (primitive_util::IsFloatingPointType(element_type) ||
       primitive_util::IsIntegralType(element_type))) {
    return DotImplementationStrategy::kNaiveLlvmIr;
  }

  if (IsAlignedGemm(dot_info, target_machine_features)) {
    if (CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features)) {
      return DotImplementationStrategy::kTiledLlvmIrGemm;
    }
    return DotImplementationStrategy::kEigen;
  }

  return DotImplementationStrategy::kNaiveLlvmIr;
}

Status EmitNonBatchDotOperation(
    DotInfo dot_info, std::string hlo_name,
    const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array,
    const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array,
    llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
    mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
    const TargetMachineFeatures& target_machine_features) {
  PrimitiveType type = target_array.GetShape().element_type();
  TF_RET_CHECK(PRED == type || S8 == type || U8 == type || S16 == type ||
               U16 == type || S32 == type || U32 == type || S64 == type ||
               U64 == type || F16 == type || F32 == type || F64 == type ||
               C64 == type || C128 == type);
  DotOpEmitter dot_emitter(std::move(dot_info), std::move(hlo_name),
                           target_array, lhs_array, rhs_array, addend_array,
                           executable_run_options_value, b, mlir_context,
                           hlo_module_config, target_machine_features);
  return dot_emitter.Emit();
}

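// Returns `shape` with its leading dimension removed, laid out with a
// descending (row-major) layout.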
Shape DropFirstDim(const Shape& shape) {
  absl::Span<int64_t const> array_shape_dims(shape.dimensions());
  array_shape_dims.remove_prefix(1);
  return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
                                                  array_shape_dims);
}

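// Returns `shape` with its first `n` dimensions collapsed into a single
// leading dimension, laid out with a descending (row-major) layout.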
Shape CollapseFirstNDims(const Shape& shape, int64_t n) {
  absl::Span<int64_t const> input_shape_dims(shape.dimensions());
  int64_t prefix_dim =
      std::accumulate(input_shape_dims.begin(), input_shape_dims.begin() + n,
                      1ll, std::multiplies<int64_t>());
  DimensionVector result_dims;
  result_dims.push_back(prefix_dim);
  std::copy(input_shape_dims.begin() + n, input_shape_dims.end(),
            std::back_inserter(result_dims));
  return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
                                                  result_dims);
}

llvm_ir::IrArray CollapseFirstNDims(llvm::IRBuilder<>* b,
                                    const llvm_ir::IrArray& array, int64_t n) {
  llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
  const Shape& shape = array.GetShape();
  CHECK(shape.has_layout() &&
        LayoutUtil::IsMonotonicWithDim0Major(shape.layout()));
  CHECK_GE(shape.dimensions_size(), n);
  Shape new_shape = CollapseFirstNDims(shape, n);
  llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, module);
  llvm::Value* new_value =
      b->CreateBitCast(array.GetBasePointer(), new_ir_type->getPointerTo());
  return llvm_ir::IrArray(new_value, new_ir_type, std::move(new_shape));
}

Status ValidateDotDimensionNumbers(const DotDimensionNumbers& dim_numbers) {
  // Checks some invariants that do not hold in general, but DotDecomposer
  // should have established for us. This is just a debugging aid.
  TF_RET_CHECK(dim_numbers.lhs_contracting_dimensions_size() == 1);
  std::vector<int64_t> batch_dim_numbers(
      dim_numbers.lhs_batch_dimensions_size());
  absl::c_iota(batch_dim_numbers, 0);
  TF_RET_CHECK(
      absl::c_equal(batch_dim_numbers, dim_numbers.lhs_batch_dimensions()));
  TF_RET_CHECK(
      absl::c_equal(batch_dim_numbers, dim_numbers.rhs_batch_dimensions()));
  return OkStatus();
}

// Slice out the inner array at batch index `batch_index` from `outer_array`.
llvm_ir::IrArray SliceOutInnerArray(llvm_ir::IrArray outer_array,
                                    llvm::Value* batch_index,
                                    llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();

  Shape inner_shape = DropFirstDim(outer_array.GetShape());
  std::vector<llvm::Value*> multidim_index(inner_shape.rank() + 1,
                                           b->getInt64(0));
  multidim_index[0] = batch_index;
  llvm_ir::IrArray::Index slice_index(multidim_index, outer_array.GetShape(),
                                      batch_index->getType());
  llvm::Value* slice_ptr = outer_array.EmitArrayElementAddress(slice_index, b);
  llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(inner_shape, module);
  llvm::Type* slice_ptr_type = new_ir_type->getPointerTo();
  return llvm_ir::IrArray(b->CreateBitCast(slice_ptr, slice_ptr_type),
                          new_ir_type, std::move(inner_shape));
}
1321
bool PotentiallyImplementedAsEigenMatmul(
    const HloInstruction& dot, const llvm_ir::IrArray& target_array,
    const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
    llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
    mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
    const TargetMachineFeatures& target_machine_features, DotInfo& dot_info) {
  int64_t num_batch_dims =
      dot.dot_dimension_numbers().lhs_batch_dimensions_size();

  // TODO(kramerb): Remove this limitation.
  if (num_batch_dims > 1) return false;

  // First reshape the inputs to make sure we only have one batch dimension.
  // This is a no-op bitcast because the operands have to be in row-major
  // layout (enforced in CpuLayoutAssignment), and the batch dimensions are
  // the leading dimensions (established by DotDecomposer and checked by
  // ValidateDotDimensionNumbers above).
  llvm_ir::IrArray lhs_array_reshaped =
      CollapseFirstNDims(b, lhs_array, num_batch_dims);
  llvm_ir::IrArray rhs_array_reshaped =
      CollapseFirstNDims(b, rhs_array, num_batch_dims);
  llvm_ir::IrArray target_array_reshaped =
      CollapseFirstNDims(b, target_array, num_batch_dims);

  DotDimensionNumbers adjusted_dim_numbers = dot.dot_dimension_numbers();
  adjusted_dim_numbers.clear_lhs_batch_dimensions();
  adjusted_dim_numbers.clear_rhs_batch_dimensions();

  // Create a DotInfo representing the batch of "inner" dot operations.
  dot_info.lhs_shape = DropFirstDim(lhs_array_reshaped.GetShape());
  dot_info.rhs_shape = DropFirstDim(rhs_array_reshaped.GetShape());
  dot_info.result_shape = DropFirstDim(target_array_reshaped.GetShape());
  dot_info.dim_nums = dot.dot_dimension_numbers();
  dot_info.dim_nums.clear_lhs_batch_dimensions();
  dot_info.dim_nums.clear_rhs_batch_dimensions();

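  // The contracting dimension indices in the original dimension numbers are
  // relative to shapes that still carry `num_batch_dims` leading batch
  // dimensions; the inner shapes above dropped those dimensions, so shift
  // the contracting indices down by `num_batch_dims` to match.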
  dot_info.dim_nums.set_lhs_contracting_dimensions(
      0, dot_info.dim_nums.lhs_contracting_dimensions(0) - num_batch_dims);
  dot_info.dim_nums.set_rhs_contracting_dimensions(
      0, dot_info.dim_nums.rhs_contracting_dimensions(0) - num_batch_dims);

  PrimitiveType type = target_array.GetShape().element_type();
  if (F32 != type) return false;

  if (ShapeUtil::IsScalar(dot_info.lhs_shape) ||
      ShapeUtil::IsScalar(dot_info.rhs_shape)) {
    // If the operands are scalar, don't emit any loops.
    return false;
  }

  DotImplementationStrategy impl_strategy = GetDotImplementationStrategy(
      dot.parent()->parent()->config(), dot_info, target_machine_features);

  return impl_strategy == DotImplementationStrategy::kEigen;
}

Status EmitBatchDotOperation(
    const HloInstruction& dot, const llvm_ir::IrArray& target_array,
    const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
    llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
    mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
    const TargetMachineFeatures& target_machine_features) {
  TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(dot.dot_dimension_numbers()));

  // First check whether the whole batch can be handled directly by the
  // runtime; otherwise lower it to a sequence of non-batch dot operations.
  DotInfo dot_info;
  if (ShouldUseMultiThreadedEigen(hlo_module_config) &&
      PotentiallyImplementedAsEigenMatmul(
          dot, target_array, lhs_array, rhs_array, executable_run_options_value,
          b, mlir_context, hlo_module_config, target_machine_features,
          dot_info)) {
    DotOpEmitter dot_emitter(dot_info, dot.name(), target_array, lhs_array,
                             rhs_array, nullptr /*addend_array*/,
                             executable_run_options_value, b, mlir_context,
                             hlo_module_config, target_machine_features);

    return dot_emitter.EmitBatch();
  } else {
    // Lower a batch dot into a sequence of non-batch dot operations.
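    //
    // Illustrative example: a batch dot of f32[5,2,3] x f32[5,3,4] is emitted
    // as a loop over the collapsed batch dimension that runs the non-batch
    // emitter on f32[2,3] x f32[3,4] operand slices, writing each f32[2,4]
    // result slice of the output.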

    int64_t num_batch_dims =
        dot.dot_dimension_numbers().lhs_batch_dimensions_size();

    // First reshape the inputs to make sure we only have one batch dimension.
    // This is a no-op bitcast because the operands have to be in row-major
    // layout (enforced in CpuLayoutAssignment), and the batch dimensions are
    // the leading dimensions (established by DotDecomposer and checked by
    // ValidateDotDimensionNumbers above).
    llvm_ir::IrArray lhs_array_reshaped =
        CollapseFirstNDims(b, lhs_array, num_batch_dims);
    llvm_ir::IrArray rhs_array_reshaped =
        CollapseFirstNDims(b, rhs_array, num_batch_dims);
    llvm_ir::IrArray target_array_reshaped =
        CollapseFirstNDims(b, target_array, num_batch_dims);

    int64_t batch_count = lhs_array_reshaped.GetShape().dimensions(0);

    KernelSupportLibrary ksl(b);

    return ksl.ForWithStatus(
        llvm_ir::IrName(&dot, "bdot"), /*start=*/0, /*end=*/batch_count,
        /*step=*/1, [&](llvm::Value* indvar) {
          DotDimensionNumbers adjusted_dim_numbers =
              dot.dot_dimension_numbers();
          adjusted_dim_numbers.clear_lhs_batch_dimensions();
          adjusted_dim_numbers.clear_rhs_batch_dimensions();

          // Create a DotInfo representing the "inner" non-batch dot operation.
          DotInfo dot_info;
          dot_info.lhs_shape = DropFirstDim(lhs_array_reshaped.GetShape());
          dot_info.rhs_shape = DropFirstDim(rhs_array_reshaped.GetShape());
          dot_info.result_shape =
              DropFirstDim(target_array_reshaped.GetShape());
          dot_info.dim_nums = dot.dot_dimension_numbers();
          dot_info.dim_nums.clear_lhs_batch_dimensions();
          dot_info.dim_nums.clear_rhs_batch_dimensions();

          dot_info.dim_nums.set_lhs_contracting_dimensions(
              0,
              dot_info.dim_nums.lhs_contracting_dimensions(0) - num_batch_dims);
          dot_info.dim_nums.set_rhs_contracting_dimensions(
              0,
              dot_info.dim_nums.rhs_contracting_dimensions(0) - num_batch_dims);

          llvm_ir::IrArray lhs_slice =
              SliceOutInnerArray(lhs_array_reshaped, /*batch_index=*/indvar, b);
          llvm_ir::IrArray rhs_slice =
              SliceOutInnerArray(rhs_array_reshaped, /*batch_index=*/indvar, b);
          llvm_ir::IrArray target_slice = SliceOutInnerArray(
              target_array_reshaped, /*batch_index=*/indvar, b);

          // Emit the inner non-batch dot operation.
          return EmitNonBatchDotOperation(
              dot_info, dot.name(), target_slice, lhs_slice, rhs_slice, nullptr,
              executable_run_options_value, b, mlir_context, hlo_module_config,
              target_machine_features);
        });
  }
}

bool IsBatchDot(const HloInstruction& instr) {
  if (auto* dot_instr = DynCast<HloDotInstruction>(&instr)) {
    return dot_instr->dot_dimension_numbers().lhs_batch_dimensions_size() > 0;
  }

  return false;
}
}  // namespace

bool DotImplementationCanHandleTranspose(
    const HloInstruction& dot_instr,
    const TargetMachineFeatures& target_machine_features) {
  DotImplementationStrategy impl_strategy =
      GetDotImplementationStrategy(dot_instr.parent()->parent()->config(),
                                   DotInfo(dot_instr), target_machine_features);

  return impl_strategy == DotImplementationStrategy::kNaiveLlvmIr ||
         impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemv ||
         impl_strategy == DotImplementationStrategy::kEigen;
}

bool DotOperandsAndResultMustHaveRowMajorLayout(
    const HloInstruction& dot_instr,
    const TargetMachineFeatures& target_machine_features) {
  // Batched dots require the batch dimensions to be major. DotDecomposer
  // always moves batch dimensions to the front of the shape, so force a
  // row-major layout.
  if (IsBatchDot(dot_instr)) {
    return true;
  }

  DotImplementationStrategy impl_strategy =
      GetDotImplementationStrategy(dot_instr.parent()->parent()->config(),
                                   DotInfo(dot_instr), target_machine_features);

  return impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemm ||
         impl_strategy == DotImplementationStrategy::kEigen;
}

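// Emits LLVM IR for `dot`, dispatching to the batch or non-batch path below.
// The batch path does not support a fused `addend_array` (see the
// TF_RET_CHECK); the non-batch path forwards `addend_array` to
// EmitNonBatchDotOperation.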
Status EmitDotOperation(const HloInstruction& dot,
                        const llvm_ir::IrArray& target_array,
                        const llvm_ir::IrArray& lhs_array,
                        const llvm_ir::IrArray& rhs_array,
                        const llvm_ir::IrArray* addend_array,
                        llvm::Value* executable_run_options_value,
                        llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context,
                        const HloModuleConfig& hlo_module_config,
                        const TargetMachineFeatures& target_machine_features) {
  // This routine assumes that the dot operation is not in a parallelized
  // enclosing computation.
  CHECK(dot.parent()->root_instruction()->outer_dimension_partitions().empty());

  if (IsBatchDot(dot)) {
    TF_RET_CHECK(addend_array == nullptr);
    return EmitBatchDotOperation(dot, target_array, lhs_array, rhs_array,
                                 executable_run_options_value, b, mlir_context,
                                 hlo_module_config, target_machine_features);
  }

  return EmitNonBatchDotOperation(DotInfo(dot), dot.name(), target_array,
                                  lhs_array, rhs_array, addend_array,
                                  executable_run_options_value, b, mlir_context,
                                  hlo_module_config, target_machine_features);
}
}  // namespace cpu
}  // namespace xla