/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"

#include <memory>
#include <vector>

#include "absl/strings/str_cat.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/Utils/Utils.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

using llvm_ir::SetToFirstInsertPoint;

namespace cpu {
namespace {
// Returns true if we should call into multi-threaded Eigen routines.
bool ShouldUseMultiThreadedEigen(const HloModuleConfig& config) {
  return config.debug_options().xla_cpu_multi_thread_eigen();
}

// Represents a dot operation. We use this in lieu of an `HloInstruction`
// because we want to be able to create this for the "inner" dot operation in a
// batch dot, for which there is no separate HLO instruction.
struct DotInfo {
  Shape lhs_shape;
  Shape rhs_shape;
  Shape result_shape;
  DotDimensionNumbers dim_nums;

  DotInfo() = default;

  explicit DotInfo(const HloInstruction& instr) {
    CHECK_EQ(instr.opcode(), HloOpcode::kDot);
    lhs_shape = instr.operand(0)->shape();
    rhs_shape = instr.operand(1)->shape();
    result_shape = instr.shape();
    dim_nums = instr.dot_dimension_numbers();
  }
};

// Dictates how a dot operation is implemented.
enum class DotImplementationStrategy {
  // The dot operation is lowered into LLVM IR that implements a naive nested
  // loop that computes the result one element at a time. This is our
  // "fallback"; we don't really want this to kick in for any non-trivial dot
  // operation.
  kNaiveLlvmIr,

  // The dot operation is lowered into LLVM IR that implements a tiled
  // Matrix*Vector operation. This strategy also allows fusing in a bias add
  // into the dot. The matrix can be row major or column major; both are
  // supported.
  kTiledLlvmIrGemv,

  // The dot operation is lowered into LLVM IR that implements a tiled
  // Matrix*Matrix operation. No fusions are supported. The two inputs
  // and the output have to be row major.
  kTiledLlvmIrGemm,

  // The dot operation is lowered into a linalg.matmul op, which is then
  // lowered to LLVM IR.
  kLinalgMatmul,

  // The dot operation is lowered into a call into an Eigen routine. No fusions
  // are supported today. The two inputs and the output have to be row major.
  // However, we do allow transposing either the LHS or the RHS as part of the
  // GEMM -- we expose this flexibility as flexibility in the contraction
  // dimensions, but we can also see this as flexibility in the input layouts.
  kEigen,
};

// Returns the implementation strategy for a dot with the configuration
// `dot_info`.
DotImplementationStrategy GetDotImplementationStrategy(
    const HloModuleConfig& config, const DotInfo& dot_info,
    const TargetMachineFeatures& target_machine_features);

// Helper class for emitting LLVM IR to perform the dot operation.
class DotOpEmitter {
 public:
  explicit DotOpEmitter(DotInfo dot_info, string dot_hlo_name,
                        const llvm_ir::IrArray& target_array,
                        const llvm_ir::IrArray& lhs_array,
                        const llvm_ir::IrArray& rhs_array,
                        const llvm_ir::IrArray* addend_array,
                        llvm::Value* executable_run_options_value,
                        llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context,
                        const HloModuleConfig& hlo_module_config,
                        const TargetMachineFeatures& target_machine_features);

  // Emits the IR to perform the dot operation.
  Status Emit();

 private:
  // Emits instructions to perform a scalar dot product (a multiply of the
  // LHS and RHS) and store the results in the target.
  Status EmitScalarDot();

  // Emits a call to the CPU runtime to perform the matrix multiply.
  Status EmitCallToRuntime();

  // Represents the dimensions of a matrix-matrix multiply operation.
  struct MatMultDims {
    // The number of rows in the LHS.
    int64 m;

    // The number of columns in the LHS, which must also be equal to the
    // number of rows in the RHS.
    int64 k;

    // The number of columns on the RHS.
    int64 n;

    // True if the LHS matrix is column major.
    bool lhs_column_major;

    // True if the LHS contraction dimension is 1.
    bool lhs_canonical;

    // True if the RHS matrix is column major.
    bool rhs_column_major;

    // True if the RHS contraction dimension is 0.
    bool rhs_canonical;
  };

  // Get the MatMultDims instance for the dot product this DotOpEmitter
  // represents. Precondition: the dot is of rank 2 (and thus its operands are
  // of rank 2 as well).
  MatMultDims GetMatMultDims() const;

  // Lowers the dot operation as a tiled Matrix*Vector loop.
  void EmitTiledLlvmIrGemv();

  // Lowers the dot operation as a tiled Matrix*Matrix loop.
  void EmitTiledLlvmIrGemm();

  // Lowers the dot operation through MLIR's linalg.matmul.
  Status EmitLinalgMatmul();

  // Lowers the dot operation as a naive nested loop that computes the result
  // one element at a time.
  void EmitNaiveLlvmIrGemm();

  // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector
  // registers.
  int64 GetGemvTilingFactor() const {
    const int64_t kDefaultTilingFactor = 8;
    return options::LlvmIrGemvTilingFactor(hlo_module_config_)
        .value_or(kDefaultTilingFactor);
  }

  std::tuple<int64, int64, int64> GetGemmTileSize() const {
    // Tuned for broadwell - Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz
    //
    // TODO(b/80093688): Tune for other architectures and centralize this
    // information in one place.
    const std::tuple<int64, int64, int64> kDefaultTileSize =
        std::tuple<int64, int64, int64>(11, 9, 1);
    return options::LlvmIrGemmTileSize(hlo_module_config_)
        .value_or(kDefaultTileSize);
  }

  std::array<int64_t, 3> GetMlirGemmTileSize() const {
    // Tile by 4 x registers x register size. This was picked by running
    // small matmuls on Haswell and Skylake. There's a lot of room for
    // improvement here.
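    //
    // Illustrative example (the register counts below are hypothetical, not
    // measured on any particular target): on a machine reporting 16 vector
    // registers of 8 f32 elements each, this returns {4, 16, 8}, i.e. an
    // m-tile of 4, an n-tile of 16 vector registers, and 8 elements per
    // register.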
    constexpr int64_t kDefaultTileSizeForM = 4;
    int64_t elements_per_register =
        target_machine_features_.vector_register_num_elements(
            *b_->GetInsertBlock()->getParent(),
            dot_info_.result_shape.element_type());
    int64_t num_registers = target_machine_features_.vector_register_count(
        *b_->GetInsertBlock()->getParent());
    return {{kDefaultTileSizeForM, num_registers, elements_per_register}};
  }

  DotInfo dot_info_;
  string dot_hlo_name_;
  const llvm_ir::IrArray& target_array_;
  const llvm_ir::IrArray& lhs_array_;
  const llvm_ir::IrArray& rhs_array_;
  const llvm_ir::IrArray* addend_array_;
  llvm::Value* executable_run_options_value_;
  llvm::IRBuilder<>* b_;
  mlir::MLIRContext* mlir_context_;
  const HloModuleConfig& hlo_module_config_;
  const TargetMachineFeatures& target_machine_features_;
};
}  // namespace

DotOpEmitter::DotOpEmitter(
    DotInfo dot_info, string dot_hlo_name, const llvm_ir::IrArray& target_array,
    const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
    const llvm_ir::IrArray* addend_array,
    llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
    mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
    const TargetMachineFeatures& target_machine_features)
    : dot_info_(std::move(dot_info)),
      dot_hlo_name_(std::move(dot_hlo_name)),
      target_array_(target_array),
      lhs_array_(lhs_array),
      rhs_array_(rhs_array),
      addend_array_(addend_array),
      executable_run_options_value_(executable_run_options_value),
      b_(b),
      mlir_context_(mlir_context),
      hlo_module_config_(hlo_module_config),
      target_machine_features_(target_machine_features) {}

Status DotOpEmitter::EmitLinalgMatmul() {
  Shape operand_shapes[] = {dot_info_.lhs_shape, dot_info_.rhs_shape};
  llvm::Value* operand_ptrs[] = {lhs_array_.GetBasePointer(),
                                 rhs_array_.GetBasePointer()};
  llvm::Value* target_ptr = target_array_.GetBasePointer();

  // Zero out the output buffer.
  int64_t size_bytes = ShapeUtil::ByteSizeOf(dot_info_.result_shape);
  b_->CreateMemSet(target_ptr, b_->getInt8(0), /*Size=*/size_bytes,
                   /*Align=*/llvm::MaybeAlign(1));

  std::string name =
      absl::StrCat("linalgMatMul_", dot_info_.result_shape.ToString(true), "_",
                   dot_info_.lhs_shape.ToString(true), "_",
                   dot_info_.rhs_shape.ToString(true));

  return EmitMlirFuncAndCall(
      mlir_context_, b_, dot_info_.result_shape, operand_shapes, target_ptr,
      operand_ptrs, name, [&](mlir::OpBuilder* builder, mlir::FuncOp function) {
        CHECK_EQ(dot_info_.dim_nums.lhs_contracting_dimensions_size(), 1);
        CHECK_EQ(dot_info_.dim_nums.rhs_contracting_dimensions_size(), 1);
        mlir::MLIRContext* context = builder->getContext();
        mlir::Value a = function.getArgument(0), b = function.getArgument(1),
                    c = function.getArgument(2);

        llvm::SmallVector<mlir::AffineExpr, 2> b_exprs(
            dot_info_.lhs_shape.rank());
        llvm::SmallVector<mlir::AffineExpr, 2> c_exprs(
            dot_info_.rhs_shape.rank());

        llvm::SmallVector<mlir::AffineExpr, 2> parallel_exprs;
        mlir::AffineExpr reduce_expr;
        for (int i = 0; i != dot_info_.result_shape.rank(); ++i) {
          parallel_exprs.push_back(mlir::getAffineDimExpr(i, context));
        }
        reduce_expr =
            mlir::getAffineDimExpr(dot_info_.result_shape.rank(), context);

        // The reduction expr is shared for both inputs.
        b_exprs[dot_info_.dim_nums.lhs_contracting_dimensions(0)] = reduce_expr;
        c_exprs[dot_info_.dim_nums.rhs_contracting_dimensions(0)] = reduce_expr;

        // Fill in the remaining parallel exprs.
        int par_expr_num = 0;
        for (auto* v : {&b_exprs, &c_exprs}) {
          for (auto& e : *v) {
            if (!e) {
              e = parallel_exprs[par_expr_num++];
            }
          }
        }

        llvm::SmallVector<llvm::StringRef, 4> iteratorTypes(
            parallel_exprs.size(), toString(mlir::IteratorType::Parallel));
        iteratorTypes.push_back(toString(mlir::IteratorType::Reduction));
        builder->create<mlir::linalg::GenericOp>(
            function.getLoc(),
            /*inputs=*/mlir::ValueRange{b, c},
            /*outputs=*/mlir::ValueRange{a},
            /*indexingMaps=*/
            mlir::AffineMap::inferFromExprList(
                {b_exprs, c_exprs, parallel_exprs}),
            /*iteratorTypes=*/iteratorTypes,
            [](mlir::OpBuilder& b, mlir::Location loc, mlir::ValueRange args) {
              mlir::ArithBuilder ab(b, loc);
              mlir::Value mul = ab.mul(args[0], args[1]);
              mlir::Value add = ab.add(mul, args[2]);
              b.create<mlir::linalg::YieldOp>(loc, add);
            });
        builder->create<mlir::ReturnOp>(function.getLoc());

        mlir::linalg::LinalgTilingOptions tilingOptions;
        tilingOptions = tilingOptions.setTileSizes(GetMlirGemmTileSize());
        int64_t alignment =
            target_machine_features_.minimum_alignment_for_allocation(
                ShapeUtil::ByteSizeOf(dot_info_.result_shape));
        mlir::linalg::CodegenStrategy strategy;
        strategy.tile<mlir::linalg::GenericOp>(tilingOptions)
            .promote<mlir::linalg::GenericOp>(
                mlir::linalg::LinalgPromotionOptions()
                    .setAlignment(alignment)
                    .setUseFullTileBuffersByDefault(true)
                    .setUseAlloca(true))
            .vectorize<mlir::linalg::GenericOp>()
            .setVectorTransformsOptions(
                mlir::vector::VectorTransformsOptions()
                    .setVectorTransformsOptions(
                        mlir::vector::VectorContractLowering::OuterProduct))
            .setVectorTransferToSCFOptions(
                mlir::VectorTransferToSCFOptions().setUnroll(true));
        strategy.transform(function);
      });
}

void DotOpEmitter::EmitTiledLlvmIrGemm() {
  PrimitiveType primitive_type = dot_info_.result_shape.element_type();
  MatMultDims mat_mult_dims = GetMatMultDims();

  llvm::Value* lhs = lhs_array_.GetBasePointer();
  llvm::Value* rhs = rhs_array_.GetBasePointer();
  llvm::Value* target = target_array_.GetBasePointer();
  int64_t m = mat_mult_dims.m;
  int64_t k = mat_mult_dims.k;
  int64_t n = mat_mult_dims.n;

  if (mat_mult_dims.lhs_column_major) {
    std::swap(lhs, rhs);
    std::swap(m, n);
  }

  int64_t size_bytes =
      m * n * ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
  b_->CreateMemSet(target, b_->getInt8(0), /*Size=*/size_bytes,
                   /*Align=*/llvm::MaybeAlign(1));

  int64_t max_target_vector_width =
      target_machine_features_.vector_register_num_elements(
          *b_->GetInsertBlock()->getParent(), primitive_type);

  int64_t tile_size_m, tile_size_k, tile_size_n_in_vector_width;
  std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) =
      GetGemmTileSize();

  EmitSmallGemm(
      /*scalar_type=*/primitive_type,
      /*m=*/m, /*k=*/k, /*n=*/n,
      /*max_vectorization_width=*/max_target_vector_width,
      /*max_vector_count=*/tile_size_n_in_vector_width,
      /*min_vectorization_width=*/std::min<int64>(4, max_target_vector_width),
      /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k, /*lhs=*/lhs,
      /*rhs=*/rhs, /*result=*/target, b_, hlo_module_config_);
}

void DotOpEmitter::EmitTiledLlvmIrGemv() {
  PrimitiveType primitive_type = dot_info_.result_shape.element_type();

  CHECK(primitive_util::IsFloatingPointType(primitive_type) ||
        primitive_util::IsIntegralType(primitive_type));

  MatMultDims mat_mult_dims = GetMatMultDims();
  bool is_column_major_matrix_vector_gemv = false;
  bool is_row_major_matrix_vector_gemv = false;

  int64_t m, k;
  bool swap_operands;

  if (mat_mult_dims.m == 1) {
    // Our emitters can only do Matrix*Vector (abbreviated as M*V) but when M=1
    // we actually want V*M. We implement V*M as follows (Tr(X) = Transpose of
    // X):
    //
    //   V*M = Tr(Tr(V*M))         // Tr(Tr(X)) == X
    //       = Tr(Tr(M) * Tr(V))   // Tr(A * B) == Tr(B) * Tr(A)
    //
    // Since transposing a vector is physically a no-op, this is really
    // equivalent to `Tr(M) * V`. We further implement Tr(M) by pretending that
    // M is row major if it is actually column major and vice-versa.
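    //
    // Illustrative example (hypothetical shapes): for V of shape [1 x 8] and M
    // of shape [8 x 5], instead of emitting a [1 x 8] * [8 x 5] product we
    // emit a [5 x 8] matrix-vector multiply against V, by flipping the
    // row-major/column-major interpretation of M and swapping the operands
    // below.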

    bool rhs_effectively_column_major = mat_mult_dims.rhs_canonical
                                            ? mat_mult_dims.rhs_column_major
                                            : !mat_mult_dims.rhs_column_major;

    if (rhs_effectively_column_major) {
      k = mat_mult_dims.k;
      m = mat_mult_dims.n;

      // We set is_row_major_matrix_vector_gemv and not
      // is_column_major_matrix_vector_gemv to implement the Transpose trick
      // mentioned above.
      is_row_major_matrix_vector_gemv = true;
      swap_operands = true;
    } else {
      k = mat_mult_dims.k;
      m = mat_mult_dims.n;

      // We set is_column_major_matrix_vector_gemv and not
      // is_row_major_matrix_vector_gemv to implement the Transpose trick
      // mentioned above.
      is_column_major_matrix_vector_gemv = true;
      swap_operands = true;
    }
  }

  if (mat_mult_dims.n == 1) {
    bool lhs_effectively_column_major = mat_mult_dims.lhs_canonical
                                            ? mat_mult_dims.lhs_column_major
                                            : !mat_mult_dims.lhs_column_major;

    if (lhs_effectively_column_major) {
      m = mat_mult_dims.m;
      k = mat_mult_dims.k;
      is_column_major_matrix_vector_gemv = true;
      swap_operands = false;
    } else {
      m = mat_mult_dims.m;
      k = mat_mult_dims.k;
      is_row_major_matrix_vector_gemv = true;
      swap_operands = false;
    }
  }

  CHECK(is_column_major_matrix_vector_gemv || is_row_major_matrix_vector_gemv);

  int64_t tiling_factor = GetGemvTilingFactor();
  CHECK_GT(tiling_factor, 0);

  llvm::Value* result_op = target_array_.GetBasePointer();
  llvm::Value* lhs_op =
      swap_operands ? rhs_array_.GetBasePointer() : lhs_array_.GetBasePointer();
  llvm::Value* rhs_op =
      swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer();

  const int target_vector_register_element_size =
      target_machine_features_.vector_register_num_elements(
          *b_->GetInsertBlock()->getParent(), primitive_type);

  // We may not always know the vector register size for the target we're
  // compiling against, in which case target_vector_register_element_size is 0.
  // In these cases we choose a default LLVM IR register size.
  const int kUnknownTargetVectorRegisterSize = 4;
  const int vector_register_element_size =
      target_vector_register_element_size == 0
          ? kUnknownTargetVectorRegisterSize
          : target_vector_register_element_size;

  if (is_column_major_matrix_vector_gemv) {
    VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m
            << " and k = " << k;
    EmitColumnMajorGemv(
        /*scalar_type=*/primitive_type,
        /*tile_rows=*/vector_register_element_size, /*tile_cols=*/tiling_factor,
        /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op,
        /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr,
        /*result=*/result_op, b_, hlo_module_config_);
  } else {
    VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m
            << " and k = " << k;
    EmitRowMajorGemv(
        /*scalar_type=*/primitive_type,
        /*tile_rows=*/tiling_factor,
        /*tile_cols=*/vector_register_element_size,
        /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op,
        /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr,
        /*result=*/result_op, b_, hlo_module_config_);
  }
}

Status DotOpEmitter::Emit() {
  // The dot operation performs a sum of products over dimension 0 of the left
  // hand side operand and dimension 1 of the right hand side operand.
  //
  // Let the shapes of lhs and rhs be defined as below:
  //
  //   lhs = [L{n-1} x L{n-2} x ... L{0}]
  //   rhs = [R{m-1} x R{m-2} x ... R{0}]
  //
  // The sum-of-products dimension in the lhs has size L{0} and the dimension
  // in the rhs has size R{1}. Necessarily, then:
  //
  //   L{0} == R{1}
  //
  // The output of the operation has the following shape:
  //
  //   output = [L{n-1} x L{n-2} x ... L{1} x R{m-1} x R{m-2} x ... R{2} x R{0}]
  //
  // To perform the operation we construct a loop nest with one for-loop for
  // each dimension of the output. Inside this loop nest is another for-loop
  // which performs the sum-of-products (the reduction loop) before storing
  // the result in the output buffer.
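  //
  // Illustrative example (hypothetical shapes, following the convention
  // above): for lhs = [4 x 3] (so L{1} = 4, L{0} = 3) and rhs = [3 x 5]
  // (so R{1} = 3, R{0} = 5), the contracted sizes satisfy L{0} == R{1} == 3
  // and the output shape is [L{1} x R{0}] = [4 x 5].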

  const Shape& lhs_shape = lhs_array_.GetShape();
  const Shape& rhs_shape = rhs_array_.GetShape();

  if (ShapeUtil::IsScalar(lhs_shape) || ShapeUtil::IsScalar(rhs_shape)) {
    // If the operands are scalar, don't emit any loops.
    TF_RET_CHECK(ShapeUtil::IsScalar(lhs_shape) &&
                 ShapeUtil::IsScalar(rhs_shape));
    return EmitScalarDot();
  }

  switch (GetDotImplementationStrategy(hlo_module_config_, dot_info_,
                                       target_machine_features_)) {
    case DotImplementationStrategy::kNaiveLlvmIr:
      EmitNaiveLlvmIrGemm();
      return Status::OK();

    case DotImplementationStrategy::kTiledLlvmIrGemv:
      EmitTiledLlvmIrGemv();
      return Status::OK();

    case DotImplementationStrategy::kTiledLlvmIrGemm:
      EmitTiledLlvmIrGemm();
      return Status::OK();

    case DotImplementationStrategy::kLinalgMatmul:
      return EmitLinalgMatmul();

    case DotImplementationStrategy::kEigen:
      return EmitCallToRuntime();
  }
}

void DotOpEmitter::EmitNaiveLlvmIrGemm() {
  CHECK_EQ(addend_array_, nullptr);

  const Shape& lhs_shape = lhs_array_.GetShape();
  const Shape& rhs_shape = rhs_array_.GetShape();
  const DotDimensionNumbers& dim_nums = dot_info_.dim_nums;

  // Reduce along dimension 0 of the LHS and 1 of the RHS. Vectors are a
  // special case where the reduction dimension is 0 for both LHS and RHS.
  // This results in a vector dot product producing a scalar.
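  //
  // Illustrative example (hypothetical shapes): for lhs = f32[3] and
  // rhs = f32[3], both contracting dimensions are 0 and the loop nest below
  // degenerates into a single reduction loop that produces a scalar.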
  int64_t lhs_reduction_dimension = dim_nums.lhs_contracting_dimensions(0);
  int64_t rhs_reduction_dimension = dim_nums.rhs_contracting_dimensions(0);

  // Verify the reduction dimension in the two operands are the same size.
  CHECK_EQ(lhs_shape.dimensions(lhs_reduction_dimension),
           rhs_shape.dimensions(rhs_reduction_dimension));

  bool lhs_reduction_along_minor_dimension =
      lhs_reduction_dimension == LayoutUtil::Minor(lhs_shape.layout(), 0);
  bool rhs_reduction_along_minor_dimension =
      rhs_reduction_dimension == LayoutUtil::Minor(rhs_shape.layout(), 0);

  // Create loop nests which loop through the LHS operand dimensions and the
  // RHS operand dimensions. The reduction dimension of the LHS and RHS are
  // handled in a separate innermost loop which performs the sum of products.
  llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(dot_hlo_name_), b_);
  std::vector<llvm::Value*> lhs_multi_index =
      loop_nest.EmitOperandArrayLoopNest(
          lhs_array_, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs");
  std::vector<llvm::Value*> rhs_multi_index =
      loop_nest.EmitOperandArrayLoopNest(
          rhs_array_, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs");

  // Create the loop which does the sum of products reduction.
  //
  // The prevent_unrolling bit is working around a deficiency in LLVM's loop
  // vectorization pipeline, wherein in some cases unrolling a loop can prevent
  // effective vectorization. Since we know that the IR we generate when
  // reducing across the minor dimension in both LHS and RHS is vectorized well
  // by the loop vectorizer, we block unrolling in that case to stop loop
  // unroll from messing up the vectorization.
  std::unique_ptr<llvm_ir::ForLoop> reduction_loop = loop_nest.AddLoop(
      0, lhs_shape.dimensions(lhs_reduction_dimension), "reduction",
      /*unroll_mode=*/
      (lhs_reduction_along_minor_dimension &&
       rhs_reduction_along_minor_dimension)
          ? xla::llvm_ir::UnrollMode::kNoUnroll
          : xla::llvm_ir::UnrollMode::kDefaultUnroll);

  // The final entry in the rhs and lhs indexes is the indvar of the
  // reduction loop.
  lhs_multi_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue();
  llvm_ir::IrArray::Index lhs_index(lhs_multi_index, lhs_shape,
                                    b_->getInt64Ty());
  rhs_multi_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue();
  llvm_ir::IrArray::Index rhs_index(rhs_multi_index, rhs_shape,
                                    b_->getInt64Ty());

  // For computing the sum of products we alloca a single location to store the
  // dot product result as we accumulate it within the reduction loop. After
  // the reduction loop we load the result and store into the output array.

  // Function entry basic block.
  // - Emit alloca for accumulator
  llvm::Function* func = reduction_loop->GetPreheaderBasicBlock()->getParent();
  SetToFirstInsertPoint(&func->getEntryBlock(), b_);
  llvm::Type* accum_type = target_array_.GetElementLlvmType();
  llvm::Value* accum_address =
      b_->CreateAlloca(accum_type, /*ArraySize=*/nullptr, "accum_address");

  // Preheader basic block of reduction loop:
  // - Initialize accumulator to zero.
  llvm::BasicBlock* preheader_bb = reduction_loop->GetPreheaderBasicBlock();
  b_->SetInsertPoint(preheader_bb->getTerminator());

  b_->CreateStore(llvm::Constant::getNullValue(accum_type), accum_address);

  // Body basic block of reduction loop:
  // - Load elements from lhs and rhs array.
  // - Multiply lhs-element and rhs-element.
  // - Load accumulator and add to product.
  // - Store sum back into accumulator.
  SetToFirstInsertPoint(reduction_loop->GetBodyBasicBlock(), b_);

  llvm::Value* lhs_element = lhs_array_.EmitReadArrayElement(lhs_index, b_);
  llvm::Value* rhs_element = rhs_array_.EmitReadArrayElement(rhs_index, b_);

  llvm::Value* accum = b_->CreateLoad(accum_address);
  llvm::Value* updated_accum;
  if (ShapeUtil::ElementIsComplex(lhs_shape)) {
    auto real = [&](llvm::Value* x) { return b_->CreateExtractValue(x, {0}); };
    auto imag = [&](llvm::Value* x) { return b_->CreateExtractValue(x, {1}); };
    llvm::Value* product_real =
        b_->CreateFSub(b_->CreateFMul(real(lhs_element), real(rhs_element)),
                       b_->CreateFMul(imag(lhs_element), imag(rhs_element)));
    llvm::Value* product_imag =
        b_->CreateFAdd(b_->CreateFMul(real(lhs_element), imag(rhs_element)),
                       b_->CreateFMul(imag(lhs_element), real(rhs_element)));
    updated_accum = b_->CreateInsertValue(
        accum, b_->CreateFAdd(real(accum), product_real), {0});
    updated_accum = b_->CreateInsertValue(
        updated_accum, b_->CreateFAdd(imag(accum), product_imag), {1});
  } else if (ShapeUtil::ElementIsIntegral(lhs_shape)) {
    llvm::Value* product = b_->CreateMul(lhs_element, rhs_element);
    updated_accum = b_->CreateAdd(accum, product);
  } else if (lhs_shape.element_type() == PRED) {
    llvm::Value* product = b_->CreateAnd(lhs_element, rhs_element);
    updated_accum = b_->CreateOr(accum, product);
  } else {
    llvm::Value* product = b_->CreateFMul(lhs_element, rhs_element);
    updated_accum = b_->CreateFAdd(accum, product);
  }
  b_->CreateStore(updated_accum, accum_address);

  // Exit basic block of reduction loop.
  // - Load accumulator value (the result).
  // - Store into output array.
  SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), b_);

  llvm::Value* result = b_->CreateLoad(accum_address);

  // Create index into target address. The target index is the concatenation
  // of the rhs and lhs indexes with the reduction dimensions removed. The
  // terms from the rhs index are the lower dimensions in the index so we add
  // them first.
  std::vector<llvm::Value*> target_multi_index;
  for (int dimension = 0; dimension < lhs_index.size(); ++dimension) {
    if (dimension != lhs_reduction_dimension) {
      target_multi_index.push_back(lhs_index[dimension]);
    }
  }
  for (int dimension = 0; dimension < rhs_index.size(); ++dimension) {
    if (dimension != rhs_reduction_dimension) {
      target_multi_index.push_back(rhs_index[dimension]);
    }
  }

  llvm_ir::IrArray::Index target_index(
      target_multi_index, target_array_.GetShape(), lhs_index.GetType());
  target_array_.EmitWriteArrayElement(target_index, result, b_);

  // Set the IR builder insert point to the exit basic block of the outer most
  // loop.
  b_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
}

Status DotOpEmitter::EmitScalarDot() {
  // A scalar dot is just a scalar multiply.
  llvm::Value* result;
  // Use the same index_type for all tensor accesses in the same kernel.
  llvm::Type* index_type = b_->getInt64Ty();
  llvm_ir::IrArray::Index element_index(index_type);
  llvm::Value* lhs_value =
      lhs_array_.EmitReadArrayElement(/*index=*/element_index, b_);
  llvm::Value* rhs_value =
      rhs_array_.EmitReadArrayElement(/*index=*/element_index, b_);
  if (ShapeUtil::ElementIsComplex(lhs_array_.GetShape())) {
    auto get_real = [&](llvm::Value* x) {
      return b_->CreateExtractValue(x, {0});
    };

    auto get_imag = [&](llvm::Value* x) {
      return b_->CreateExtractValue(x, {1});
    };

    llvm::Value* real = b_->CreateFSub(
        b_->CreateFMul(get_real(lhs_value), get_real(rhs_value)),
        b_->CreateFMul(get_imag(lhs_value), get_imag(rhs_value)));
    llvm::Value* imag = b_->CreateFAdd(
        b_->CreateFMul(get_real(lhs_value), get_imag(rhs_value)),
        b_->CreateFMul(get_imag(lhs_value), get_real(rhs_value)));
    result = llvm::ConstantAggregateZero::get(lhs_array_.GetElementLlvmType());
    result = b_->CreateInsertValue(result, real, {0});
    result = b_->CreateInsertValue(result, imag, {1});
  } else {
    result = b_->CreateFMul(lhs_value, rhs_value);
  }
  target_array_.EmitWriteArrayElement(/*index=*/element_index, result, b_);
  return Status::OK();
}

Status DotOpEmitter::EmitCallToRuntime() {
  // The signature of the Eigen runtime matmul function is:
  //
  //   (void)(void* run_options, float* out, float* lhs, float* rhs,
  //          int64 m, int64 n, int64 k, int32 transpose_lhs,
  //          int32 transpose_rhs);
  //
  // The two transpose_... parameters are actually booleans, but we use int32
  // to avoid target-dependent calling convention details.

  bool multi_threaded = ShouldUseMultiThreadedEigen(hlo_module_config_);
  bool use_mkl_dnn = hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn();
  PrimitiveType type = target_array_.GetShape().element_type();
  llvm::Function* function = b_->GetInsertBlock()->getParent();
  llvm::Module* module = function->getParent();
  llvm::Type* float_type;
  const char* fn_name;
  switch (type) {
    case F16:
      fn_name = multi_threaded
                    ? runtime::kEigenMatMulF16SymbolName
                    : runtime::kEigenSingleThreadedMatMulF16SymbolName;
      float_type = b_->getHalfTy();
      break;
    case F32:
      fn_name = multi_threaded
                    ? (use_mkl_dnn ? runtime::kMKLMatMulF32SymbolName
                                   : runtime::kEigenMatMulF32SymbolName)
                    : (use_mkl_dnn
                           ? runtime::kMKLSingleThreadedMatMulF32SymbolName
                           : runtime::kEigenSingleThreadedMatMulF32SymbolName);
      float_type = b_->getFloatTy();
      break;
    case F64:
      fn_name = multi_threaded
                    ? (use_mkl_dnn ? runtime::kMKLMatMulF64SymbolName
                                   : runtime::kEigenMatMulF64SymbolName)
                    : (use_mkl_dnn
                           ? runtime::kMKLSingleThreadedMatMulF64SymbolName
                           : runtime::kEigenSingleThreadedMatMulF64SymbolName);
      float_type = b_->getDoubleTy();
      break;
    case C64:
      fn_name = multi_threaded
                    ? runtime::kEigenMatMulC64SymbolName
                    : runtime::kEigenSingleThreadedMatMulC64SymbolName;
      float_type = llvm_ir::PrimitiveTypeToIrType(C64, module);
      break;
    case C128:
      fn_name = multi_threaded
                    ? runtime::kEigenMatMulC128SymbolName
                    : runtime::kEigenSingleThreadedMatMulC128SymbolName;
      float_type = llvm_ir::PrimitiveTypeToIrType(C128, module);
      break;
    case S32:
      fn_name = multi_threaded
                    ? runtime::kEigenMatMulS32SymbolName
                    : runtime::kEigenSingleThreadedMatMulS32SymbolName;
      float_type = b_->getInt32Ty();
      break;
    default:
      return Unimplemented("Invalid type %s for dot operation",
                           PrimitiveType_Name(type));
  }

  llvm::Type* float_ptr_type = float_type->getPointerTo();
  llvm::Type* int64_type = b_->getInt64Ty();
  llvm::Type* int32_type = b_->getInt32Ty();
  llvm::Type* int8_ptr_type = b_->getInt8Ty()->getPointerTo();
  llvm::FunctionType* matmul_type = llvm::FunctionType::get(
      b_->getVoidTy(),
      {int8_ptr_type, float_ptr_type, float_ptr_type, float_ptr_type,
       int64_type, int64_type, int64_type, int32_type, int32_type},
      /*isVarArg=*/false);

  llvm::FunctionCallee matmul_func =
      module->getOrInsertFunction(fn_name, matmul_type);
  if (auto* fn = llvm::dyn_cast<llvm::Function>(matmul_func.getCallee())) {
    fn->setCallingConv(llvm::CallingConv::C);
    fn->setDoesNotThrow();
    fn->setOnlyAccessesArgMemory();
  }

  // The Eigen runtime function expects column-major layout. If the matrices
  // are row major, then use the following identity to compute the product:
  //
  //   (A x B)^T = B^T x A^T
  //
  // The connection between this identity and memory layout is that the
  // transpose operation can also be considered as an operation that changes
  // the memory layout of a matrix from row-major to column-major or vice
  // versa.
  //
  // Effectively this involves swapping the 'lhs' with 'rhs' and 'm' with 'n'.
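  //
  // Illustrative example (hypothetical shapes): for row-major A of shape
  // [m=2 x k=3] and B of shape [k=3 x n=4], we compute C^T = B^T x A^T by
  // passing B as 'lhs', A as 'rhs', and swapping m and n (so the runtime sees
  // m=4, n=2); the column-major result C^T written by the runtime is exactly
  // the row-major C we want.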

  MatMultDims mat_mult_dims = GetMatMultDims();

  CHECK_EQ(mat_mult_dims.lhs_column_major, mat_mult_dims.rhs_column_major);

  const llvm_ir::IrArray* lhs = &lhs_array_;
  const llvm_ir::IrArray* rhs = &rhs_array_;
  bool transpose_lhs = !mat_mult_dims.lhs_canonical;
  bool transpose_rhs = !mat_mult_dims.rhs_canonical;

  if (!mat_mult_dims.lhs_column_major) {
    std::swap(mat_mult_dims.m, mat_mult_dims.n);
    std::swap(lhs, rhs);
    std::swap(transpose_lhs, transpose_rhs);
  }

  b_->CreateCall(
      matmul_func,
      {b_->CreateBitCast(executable_run_options_value_, int8_ptr_type),
       b_->CreateBitCast(target_array_.GetBasePointer(), float_ptr_type),
       b_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type),
       b_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type),
       b_->getInt64(mat_mult_dims.m), b_->getInt64(mat_mult_dims.n),
       b_->getInt64(mat_mult_dims.k), b_->getInt32(transpose_lhs),
       b_->getInt32(transpose_rhs)});
  return Status::OK();
}

DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
  CHECK_LE(dot_info_.result_shape.dimensions_size(), 2);

  const Shape& lhs_shape = lhs_array_.GetShape();
  const Shape& rhs_shape = rhs_array_.GetShape();
  const DotDimensionNumbers& dim_nums = dot_info_.dim_nums;

  auto is_column_major = [](const Shape& shape) {
    return shape.rank() > 1 && LayoutUtil::Minor(shape.layout(), 0) == 0;
  };

  // Non-contracting dots should never make it here.
  CHECK_GE(dim_nums.lhs_contracting_dimensions_size(), 0);
  CHECK_GE(dim_nums.rhs_contracting_dimensions_size(), 0);

  return {
      /*m=*/lhs_shape.rank() <= 1
          ? 1LL
          : lhs_shape.dimensions(1LL - dim_nums.lhs_contracting_dimensions(0)),
      /*k=*/lhs_shape.dimensions(dim_nums.lhs_contracting_dimensions(0)),
      /*n=*/rhs_shape.rank() <= 1
          ? 1LL
          : rhs_shape.dimensions(1LL - dim_nums.rhs_contracting_dimensions(0)),
      /*lhs_column_major=*/is_column_major(lhs_shape),
      /*lhs_canonical=*/lhs_shape.rank() <= 1 ||
          dim_nums.lhs_contracting_dimensions(0) == 1,
      /*rhs_column_major=*/is_column_major(rhs_shape),
      /*rhs_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 0};
}

// For vector-matrix dot products, it is always profitable to make the Rhs
// column major.
absl::optional<int64> ProfitableToMakeDotOperandColumnMajor(
    const HloInstruction& hlo) {
  if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() <= 1) {
    if (hlo.operand(0)->shape().rank() != 1 ||
        hlo.dot_dimension_numbers().rhs_contracting_dimensions(0) != 0) {
      return {};
    }

    // Don't bother if the other operand is tiny, switching to column major
    // wouldn't use tiling.
    constexpr int kColumnMajorThresholdInBytes = 32;
    int64_t lhs_size =
        ShapeUtil::ByteSizeOfPrimitiveType(hlo.shape().element_type()) *
        ShapeUtil::ElementsIn(hlo.operand(0)->shape());
    if (lhs_size < kColumnMajorThresholdInBytes) {
      return {};
    }

    return 1;
  }

  if (hlo.IsOutputFusion()) {
    auto* fusion_root =
        hlo.fused_instructions_computation()->root_instruction();
    if (fusion_root->opcode() != HloOpcode::kAdd) {
      return {};
    }

    for (auto* fusion_root_op : fusion_root->operands()) {
      if (fusion_root_op->opcode() != HloOpcode::kDot) {
        continue;
      }
      if (auto operand_num =
              ProfitableToMakeDotOperandColumnMajor(*fusion_root_op)) {
        auto* operand = fusion_root_op->operand(*operand_num);
        if (operand->opcode() == HloOpcode::kParameter &&
            operand->user_count() == 1) {
          return operand->parameter_number();
        }
      }
    }
  }

  return {};
}

namespace {
// Return whether the given shape is rank 2.
bool IsRank2(const Shape& shape) { return shape.rank() == 2; }

bool IsSimpleLayout(const Layout& layout) {
  return layout.tiles().empty() && layout.format() == DENSE;
}

// In a gemm operation where output = lhs * rhs, check whether the given shapes
// are valid for the operation.
bool AreGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
                   const Shape& output_shape,
                   const TargetMachineFeatures& target_machine_features) {
  CHECK(!lhs_shape.has_layout() || IsSimpleLayout(lhs_shape.layout()))
      << lhs_shape.DebugString();
  CHECK(!rhs_shape.has_layout() || IsSimpleLayout(rhs_shape.layout()))
      << rhs_shape.DebugString();
  CHECK(!output_shape.has_layout() || IsSimpleLayout(output_shape.layout()))
      << output_shape.DebugString();

  switch (output_shape.element_type()) {
    case F16:
    case F32:
    case F64:
    case C64:
    case C128:
    case S32:
      return IsRank2(lhs_shape) && IsRank2(rhs_shape) && IsRank2(output_shape);
    default:
      return false;
  }
}

bool IsAlignedGemm(const DotInfo& dot_info,
                   const TargetMachineFeatures& target_machine_features) {
  if (ShapeUtil::IsZeroElementArray(dot_info.lhs_shape) ||
      ShapeUtil::IsZeroElementArray(dot_info.rhs_shape)) {
    return false;
  }

  return AreGemmShapes(dot_info.lhs_shape, dot_info.rhs_shape,
                       dot_info.result_shape, target_machine_features);
}

bool CanEmitTiledLlvmIrGemm(
    const HloModuleConfig& config, const DotInfo& dot_info,
    const TargetMachineFeatures& target_machine_features) {
  CHECK(IsAlignedGemm(dot_info, target_machine_features));

  if (ShouldUseMultiThreadedEigen(config)) {
    return false;
  }

  int m = dot_info.result_shape.dimensions(0);
  int k = dot_info.lhs_shape.dimensions(
      dot_info.dim_nums.lhs_contracting_dimensions(0));
  int n = dot_info.result_shape.dimensions(1);

  if (!options::ForceEnableExperimentalLlvmIrGemm(config)) {
    // TODO(sanjoy): We should make these numbers micro-arch specific.
    bool small_gemm =
        k <= 128 && ((m <= 32 && n <= 128) || (m <= 128 && n <= 32));
    if (!small_gemm) {
      return false;
    }
  }

  bool lhs_canonical = dot_info.dim_nums.lhs_contracting_dimensions(0) == 1;
  bool rhs_canonical = dot_info.dim_nums.rhs_contracting_dimensions(0) == 0;

  if (!(lhs_canonical && rhs_canonical)) {
    return false;
  }

  if (dot_info.result_shape.element_type() == F16 ||
      dot_info.result_shape.element_type() == C64 ||
      dot_info.result_shape.element_type() == C128) {
    // TODO(sanjoy): This is probably easy to fix, but I want to keep the CL
    // adding this comment NFC.
    return false;
  }

  return true;
}

DotImplementationStrategy GetDotImplementationStrategy(
    const HloModuleConfig& config, const DotInfo& dot_info,
    const TargetMachineFeatures& target_machine_features) {
  PrimitiveType element_type = dot_info.result_shape.element_type();
  // Any Matrix-Vector product of floating point or integral type, or
  // a transpose-dot fusion of the same, can be lowered to a tiled LLVM
  // IR implementation.
  if ((dot_info.result_shape.dimensions_size() <= 1 ||
       (dot_info.result_shape.dimensions_size() == 2 &&
        (dot_info.result_shape.dimensions(0) == 1 ||
         dot_info.result_shape.dimensions(1) == 1))) &&
      (primitive_util::IsFloatingPointType(element_type) ||
       primitive_util::IsIntegralType(element_type))) {
    return DotImplementationStrategy::kTiledLlvmIrGemv;
  }

  // MatMuls where each operand is a vector or has a dimension of size 3 or
  // smaller should use the naive nested loop.
  if ((dot_info.lhs_shape.dimensions_size() <= 1 ||
       (dot_info.lhs_shape.dimensions_size() == 2 &&
        (dot_info.lhs_shape.dimensions(0) <= 3 ||
         dot_info.lhs_shape.dimensions(1) <= 3))) &&
      (dot_info.rhs_shape.dimensions_size() <= 1 ||
       (dot_info.rhs_shape.dimensions_size() == 2 &&
        (dot_info.rhs_shape.dimensions(0) <= 3 ||
         dot_info.rhs_shape.dimensions(1) <= 3))) &&
      (primitive_util::IsFloatingPointType(element_type) ||
       primitive_util::IsIntegralType(element_type))) {
    return DotImplementationStrategy::kNaiveLlvmIr;
  }

  if (IsAlignedGemm(dot_info, target_machine_features)) {
    if (CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features)) {
      return DotImplementationStrategy::kTiledLlvmIrGemm;
    }
    return DotImplementationStrategy::kEigen;
  }

  return DotImplementationStrategy::kNaiveLlvmIr;
}

Status EmitNonBatchDotOperation(
    DotInfo dot_info, string hlo_name, const llvm_ir::IrArray& target_array,
    const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
    const llvm_ir::IrArray* addend_array,
    llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
    mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
    const TargetMachineFeatures& target_machine_features) {
  PrimitiveType type = target_array.GetShape().element_type();
  TF_RET_CHECK(PRED == type || S8 == type || U8 == type || S16 == type ||
               U16 == type || S32 == type || U32 == type || S64 == type ||
               U64 == type || F16 == type || F32 == type || F64 == type ||
               C64 == type || C128 == type);
  DotOpEmitter dot_emitter(std::move(dot_info), std::move(hlo_name),
                           target_array, lhs_array, rhs_array, addend_array,
                           executable_run_options_value, b, mlir_context,
                           hlo_module_config, target_machine_features);
  return dot_emitter.Emit();
}

Shape DropFirstDim(const Shape& shape) {
  absl::Span<int64 const> array_shape_dims(shape.dimensions());
  array_shape_dims.remove_prefix(1);
  return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
                                                  array_shape_dims);
}

Shape CollapseFirstNDims(const Shape& shape, int64_t n) {
  absl::Span<int64 const> input_shape_dims(shape.dimensions());
  int64_t prefix_dim =
      std::accumulate(input_shape_dims.begin(), input_shape_dims.begin() + n,
                      1ll, std::multiplies<int64>());
  DimensionVector result_dims;
  result_dims.push_back(prefix_dim);
  std::copy(input_shape_dims.begin() + n, input_shape_dims.end(),
            std::back_inserter(result_dims));
  return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
                                                  result_dims);
}

llvm_ir::IrArray CollapseFirstNDims(llvm::IRBuilder<>* b,
                                    const llvm_ir::IrArray& array, int64_t n) {
  llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
  const Shape& shape = array.GetShape();
  CHECK(shape.has_layout() &&
        LayoutUtil::IsMonotonicWithDim0Major(shape.layout()));
  CHECK_GE(shape.dimensions_size(), n);
  Shape new_shape = CollapseFirstNDims(shape, n);
  llvm::Value* new_value = b->CreateBitCast(
      array.GetBasePointer(),
      llvm_ir::ShapeToIrType(new_shape, module)->getPointerTo());
  return llvm_ir::IrArray(new_value, std::move(new_shape));
}

Status ValidateDotDimensionNumbers(const DotDimensionNumbers& dim_numbers) {
  // Checks some invariants that do not hold in general, but DotDecomposer
  // should have established for us. This is just a debugging aid.
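  //
  // Concretely (illustrative restatement of the checks below): there must be
  // exactly one contracting dimension on the LHS, and the batch dimensions on
  // both sides must be exactly [0, 1, ..., num_batch_dims - 1].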
  TF_RET_CHECK(dim_numbers.lhs_contracting_dimensions_size() == 1);
  std::vector<int64> batch_dim_numbers(dim_numbers.lhs_batch_dimensions_size());
  absl::c_iota(batch_dim_numbers, 0);
  TF_RET_CHECK(
      absl::c_equal(batch_dim_numbers, dim_numbers.lhs_batch_dimensions()));
  TF_RET_CHECK(
      absl::c_equal(batch_dim_numbers, dim_numbers.rhs_batch_dimensions()));
  return Status::OK();
}

// Slice out the inner array at batch index `batch_index` from `outer_array`.
llvm_ir::IrArray SliceOutInnerArray(llvm_ir::IrArray outer_array,
                                    llvm::Value* batch_index,
                                    llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();

  Shape inner_shape = DropFirstDim(outer_array.GetShape());
  std::vector<llvm::Value*> multidim_index(inner_shape.rank() + 1,
                                           b->getInt64(0));
  multidim_index[0] = batch_index;
  llvm_ir::IrArray::Index slice_index(multidim_index, outer_array.GetShape(),
                                      batch_index->getType());
  llvm::Value* slice_ptr = outer_array.EmitArrayElementAddress(slice_index, b);
  llvm::Type* slice_ptr_type =
      llvm_ir::ShapeToIrType(inner_shape, module)->getPointerTo();
  return llvm_ir::IrArray(b->CreateBitCast(slice_ptr, slice_ptr_type),
                          std::move(inner_shape));
}

Status EmitBatchDotOperation(
    const HloInstruction& dot, const llvm_ir::IrArray& target_array,
    const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
    llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
    mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
    const TargetMachineFeatures& target_machine_features) {
  TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(dot.dot_dimension_numbers()));

  // Lower a batch dot into a sequence of non-batch dot operations.

  int64_t num_batch_dims =
      dot.dot_dimension_numbers().lhs_batch_dimensions_size();

  // First reshape the inputs to make sure we only have one batch dimension.
  // This is a no-op bitcast because the operands have to be in row-major
  // layout (enforced in CpuLayoutAssignment), and the batch dimensions are
  // the leading dimensions (established by DotDecomposer and checked by
  // ValidateDotDimensionNumbers above).
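  //
  // Illustrative example (hypothetical shapes): an f32[2,3,4,5] x f32[2,3,5,6]
  // batch dot with two batch dimensions is collapsed below to
  // f32[6,4,5] x f32[6,5,6] and then lowered as batch_count = 6 non-batch
  // [4,5] x [5,6] dots.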
  llvm_ir::IrArray lhs_array_reshaped =
      CollapseFirstNDims(b, lhs_array, num_batch_dims);
  llvm_ir::IrArray rhs_array_reshaped =
      CollapseFirstNDims(b, rhs_array, num_batch_dims);
  llvm_ir::IrArray target_array_reshaped =
      CollapseFirstNDims(b, target_array, num_batch_dims);

  int64_t batch_count = lhs_array_reshaped.GetShape().dimensions(0);

  KernelSupportLibrary ksl(b);

  return ksl.ForWithStatus(
      llvm_ir::IrName(&dot, "bdot"), /*start=*/0, /*end=*/batch_count,
      /*step=*/1, [&](llvm::Value* indvar) {
        DotDimensionNumbers adjusted_dim_numbers = dot.dot_dimension_numbers();
        adjusted_dim_numbers.clear_lhs_batch_dimensions();
        adjusted_dim_numbers.clear_rhs_batch_dimensions();

        // Create a DotInfo representing the "inner" non-batch dot operation.
        DotInfo dot_info;
        dot_info.lhs_shape = DropFirstDim(lhs_array_reshaped.GetShape());
        dot_info.rhs_shape = DropFirstDim(rhs_array_reshaped.GetShape());
        dot_info.result_shape = DropFirstDim(target_array_reshaped.GetShape());
        dot_info.dim_nums = dot.dot_dimension_numbers();
        dot_info.dim_nums.clear_lhs_batch_dimensions();
        dot_info.dim_nums.clear_rhs_batch_dimensions();

        dot_info.dim_nums.set_lhs_contracting_dimensions(
            0,
            dot_info.dim_nums.lhs_contracting_dimensions(0) - num_batch_dims);
        dot_info.dim_nums.set_rhs_contracting_dimensions(
            0,
            dot_info.dim_nums.rhs_contracting_dimensions(0) - num_batch_dims);

        llvm_ir::IrArray lhs_slice =
            SliceOutInnerArray(lhs_array_reshaped, /*batch_index=*/indvar, b);
        llvm_ir::IrArray rhs_slice =
            SliceOutInnerArray(rhs_array_reshaped, /*batch_index=*/indvar, b);
        llvm_ir::IrArray target_slice = SliceOutInnerArray(
            target_array_reshaped, /*batch_index=*/indvar, b);

        // Emit the inner non-batch dot operation.
        return EmitNonBatchDotOperation(
            dot_info, dot.name(), target_slice, lhs_slice, rhs_slice, nullptr,
            executable_run_options_value, b, mlir_context, hlo_module_config,
            target_machine_features);
      });
}

bool IsBatchDot(const HloInstruction& instr) {
  if (auto* dot_instr = DynCast<HloDotInstruction>(&instr)) {
    return dot_instr->dot_dimension_numbers().lhs_batch_dimensions_size() > 0;
  }

  return false;
}
}  // namespace

bool DotImplementationCanHandleTranspose(
    const HloInstruction& dot_instr,
    const TargetMachineFeatures& target_machine_features) {
  DotImplementationStrategy impl_strategy =
      GetDotImplementationStrategy(dot_instr.parent()->parent()->config(),
                                   DotInfo(dot_instr), target_machine_features);

  return impl_strategy == DotImplementationStrategy::kNaiveLlvmIr ||
         impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemv ||
         impl_strategy == DotImplementationStrategy::kEigen;
}

bool DotOperandsAndResultMustHaveRowMajorLayout(
    const HloInstruction& dot_instr,
    const TargetMachineFeatures& target_machine_features) {
  // Batched dots require the batch dimensions to be major. DotDecomposer
  // always moves batch dimensions to the front of the shape, so force a
  // row-major layout.
  if (IsBatchDot(dot_instr)) {
    return true;
  }

  DotImplementationStrategy impl_strategy =
      GetDotImplementationStrategy(dot_instr.parent()->parent()->config(),
                                   DotInfo(dot_instr), target_machine_features);

  return impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemm ||
         impl_strategy == DotImplementationStrategy::kEigen;
}

Status EmitDotOperation(const HloInstruction& dot,
                        const llvm_ir::IrArray& target_array,
                        const llvm_ir::IrArray& lhs_array,
                        const llvm_ir::IrArray& rhs_array,
                        const llvm_ir::IrArray* addend_array,
                        llvm::Value* executable_run_options_value,
                        llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context,
                        const HloModuleConfig& hlo_module_config,
                        const TargetMachineFeatures& target_machine_features) {
  // This routine assumes that the dot operation is not in a parallelized
  // enclosing computation.
  CHECK(dot.parent()->root_instruction()->outer_dimension_partitions().empty());

  if (IsBatchDot(dot)) {
    TF_RET_CHECK(addend_array == nullptr);
    return EmitBatchDotOperation(dot, target_array, lhs_array, rhs_array,
                                 executable_run_options_value, b, mlir_context,
                                 hlo_module_config, target_machine_features);
  }

  return EmitNonBatchDotOperation(DotInfo(dot), dot.name(), target_array,
                                  lhs_array, rhs_array, addend_array,
                                  executable_run_options_value, b, mlir_context,
                                  hlo_module_config, target_machine_features);
}
}  // namespace cpu
}  // namespace xla