• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
17 
18 #include <memory>
19 #include <vector>
20 
21 #include "absl/strings/str_cat.h"
22 #include "llvm/IR/BasicBlock.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/Module.h"
25 #include "llvm/IR/Value.h"
26 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
27 #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
28 #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
29 #include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h"
30 #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
31 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
32 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
33 #include "tensorflow/compiler/xla/service/hlo_module.h"
34 #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
35 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
36 #include "tensorflow/compiler/xla/shape_util.h"
37 #include "tensorflow/compiler/xla/status_macros.h"
38 #include "tensorflow/compiler/xla/util.h"
39 #include "tensorflow/compiler/xla/xla_data.pb.h"
40 #include "tensorflow/core/platform/logging.h"
41 
42 namespace xla {
43 
44 using llvm_ir::SetToFirstInsertPoint;
45 
46 namespace cpu {
47 namespace {
48 // Returns true if we should call into multi-threaded Eigen routines.
ShouldUseMultiThreadedEigen(const HloModuleConfig & config)49 bool ShouldUseMultiThreadedEigen(const HloModuleConfig& config) {
50   return config.debug_options().xla_cpu_multi_thread_eigen();
51 }
52 
53 // Represents a dot operation.  We use this in lieu of an `HloInstruction`
54 // because we want to be able to create this for the "inner" dot operation in a
55 // batch dot, for which there is no separate HLO instruction.
56 struct DotInfo {
57   Shape lhs_shape;
58   Shape rhs_shape;
59   Shape result_shape;
60   DotDimensionNumbers dim_nums;
61 
62   DotInfo() = default;
63 
DotInfoxla::cpu::__anond13f39450111::DotInfo64   explicit DotInfo(const HloInstruction& instr) {
65     CHECK_EQ(instr.opcode(), HloOpcode::kDot);
66     lhs_shape = instr.operand(0)->shape();
67     rhs_shape = instr.operand(1)->shape();
68     result_shape = instr.shape();
69     dim_nums = instr.dot_dimension_numbers();
70   }
71 };
72 
// The ways in which a dot operation can be lowered.
enum class DotImplementationStrategy {
  // Naive nested loops in LLVM IR that compute one result element at a time.
  // This is the fallback; it is not meant to kick in for any non-trivial dot.
  kNaiveLlvmIr,

  // A tiled Matrix*Vector kernel in LLVM IR. A bias add may be fused into the
  // dot, and the matrix may be either row major or column major.
  kTiledLlvmIrGemv,

  // A tiled Matrix*Matrix kernel in LLVM IR. No fusions are supported; the two
  // inputs and the output are expected to be row major.
  // NOTE(review): EmitTiledLlvmIrGemm also handles column-major operands by
  // swapping them -- confirm which layouts the strategy selection admits.
  kTiledLlvmIrGemm,

  // A call into an Eigen runtime routine. No fusions are supported today and
  // the two inputs and the output are expected to be row major. Transposing
  // either the LHS or the RHS as part of the GEMM is allowed; this is exposed
  // as flexibility in the contraction dimensions, which can equally be seen as
  // flexibility in the input layouts.
  kEigen
};
99 
100 // Returns the implementation strategy for a dot with the configuration
101 // `dot_info`.
102 DotImplementationStrategy GetDotImplementationStrategy(
103     const HloModuleConfig& config, const DotInfo& dot_info,
104     const TargetMachineFeatures& target_machine_features);
105 
106 // Helper class for emitting LLVM IR to perform the dot operation.
107 class DotOpEmitter {
108  public:
109   explicit DotOpEmitter(DotInfo dot_info, string dot_hlo_name,
110                         const llvm_ir::IrArray& target_array,
111                         const llvm_ir::IrArray& lhs_array,
112                         const llvm_ir::IrArray& rhs_array,
113                         const llvm_ir::IrArray* addend_array,
114                         llvm::Value* executable_run_options_value,
115                         llvm::IRBuilder<>* b,
116                         const HloModuleConfig& hlo_module_config,
117                         const TargetMachineFeatures& target_machine_features);
118 
119   // Emits the IR to perform the dot operation.
120   Status Emit();
121 
122  private:
123   // Emits instructions to perform a scalar dot product (a multiply of the
124   // LHS and RHS) and store the results in the target.
125   Status EmitScalarDot();
126 
127   // Emits a call to the CPU runtime to perform the matrix multiply.
128   Status EmitCallToRuntime();
129 
130   // Represents the dimensions of a matrix-matrix multiply operation.
131   struct MatMultDims {
132     // The number of rows in the LHS.
133     int64 m;
134 
135     // The number of columns in the LHS, which is also must be equal to the
136     // number of rows in the RHS.
137     int64 k;
138 
139     // The number of columns on the RHS.
140     int64 n;
141 
142     // True if the LHS matrix is column major.
143     bool lhs_column_major;
144 
145     // True if the LHS contraction dimension is not 1.
146     bool lhs_non_canonical;
147 
148     // True if the RHS matrix is column major.
149     bool rhs_column_major;
150 
151     // True if the RHS contraction dimension is not 0.
152     bool rhs_non_canonical;
153 
154     // True if the result matrix is column major.
155     bool target_column_major;
156   };
157 
158   // Get the MatMultDims instance for the dot product this DotOpEmitter
159   // represents.  Precondition: the dot is of rank 2 (and thus its operands are
160   // of rank 2 as well).
161   MatMultDims GetMatMultDims() const;
162 
163   // Lowers the dot operation as a tiled Matrix*Vector loop.
164   void EmitTiledLlvmIrGemv();
165 
166   // Lowers the dot operation as a tiled Matrix*Matrix loop.
167   void EmitTiledLlvmIrGemm();
168 
169   // Lowers the dot operation as a naive nested loop that computes the result
170   // one element at a time.
171   void EmitNaiveLlvmIrGemm();
172 
173   // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector
174   // registers.
GetGemvTilingFactor() const175   int64 GetGemvTilingFactor() const {
176     const int64 kDefaultTilingFactor = 8;
177     return options::LlvmIrGemvTilingFactor(hlo_module_config_)
178         .value_or(kDefaultTilingFactor);
179   }
180 
GetGemmTileSize() const181   std::tuple<int64, int64, int64> GetGemmTileSize() const {
182     // Tuned for broadwell - Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz
183     //
184     // TODO(b/80093688): Tune for other architectures and centralize this
185     // information in one place.
186     const std::tuple<int64, int64, int64> kDefaultTileSize =
187         std::tuple<int64, int64, int64>(11, 9, 1);
188     return options::LlvmIrGemmTileSize(hlo_module_config_)
189         .value_or(kDefaultTileSize);
190   }
191 
192   DotInfo dot_info_;
193   string dot_hlo_name_;
194   const llvm_ir::IrArray& target_array_;
195   const llvm_ir::IrArray& lhs_array_;
196   const llvm_ir::IrArray& rhs_array_;
197   const llvm_ir::IrArray* addend_array_;
198   llvm::Value* executable_run_options_value_;
199   llvm::IRBuilder<>* b_;
200   const HloModuleConfig& hlo_module_config_;
201   const TargetMachineFeatures& target_machine_features_;
202 };
203 }  // namespace
204 
DotOpEmitter(DotInfo dot_info,string dot_hlo_name,const llvm_ir::IrArray & target_array,const llvm_ir::IrArray & lhs_array,const llvm_ir::IrArray & rhs_array,const llvm_ir::IrArray * addend_array,llvm::Value * executable_run_options_value,llvm::IRBuilder<> * b,const HloModuleConfig & hlo_module_config,const TargetMachineFeatures & target_machine_features)205 DotOpEmitter::DotOpEmitter(DotInfo dot_info, string dot_hlo_name,
206                            const llvm_ir::IrArray& target_array,
207                            const llvm_ir::IrArray& lhs_array,
208                            const llvm_ir::IrArray& rhs_array,
209                            const llvm_ir::IrArray* addend_array,
210                            llvm::Value* executable_run_options_value,
211                            llvm::IRBuilder<>* b,
212                            const HloModuleConfig& hlo_module_config,
213                            const TargetMachineFeatures& target_machine_features)
214     : dot_info_(std::move(dot_info)),
215       dot_hlo_name_(std::move(dot_hlo_name)),
216       target_array_(target_array),
217       lhs_array_(lhs_array),
218       rhs_array_(rhs_array),
219       addend_array_(addend_array),
220       executable_run_options_value_(executable_run_options_value),
221       b_(b),
222       hlo_module_config_(hlo_module_config),
223       target_machine_features_(target_machine_features) {}
224 
EmitTiledLlvmIrGemm()225 void DotOpEmitter::EmitTiledLlvmIrGemm() {
226   PrimitiveType primitive_type = dot_info_.result_shape.element_type();
227   MatMultDims mat_mult_dims = GetMatMultDims();
228 
229   llvm::Value* lhs = lhs_array_.GetBasePointer();
230   llvm::Value* rhs = rhs_array_.GetBasePointer();
231   llvm::Value* target = target_array_.GetBasePointer();
232   int64 m = mat_mult_dims.m;
233   int64 k = mat_mult_dims.k;
234   int64 n = mat_mult_dims.n;
235 
236   if (mat_mult_dims.lhs_column_major) {
237     std::swap(lhs, rhs);
238     std::swap(m, n);
239   }
240 
241   int64 size_bytes = m * n * ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
242   b_->CreateMemSet(target, b_->getInt8(0), /*Size=*/size_bytes,
243                    /*Align=*/1);
244 
245   int64 max_target_vector_width =
246       target_machine_features_.vector_register_num_elements(
247           *b_->GetInsertBlock()->getParent(), primitive_type);
248 
249   int64 tile_size_m, tile_size_k, tile_size_n_in_vector_width;
250   std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) =
251       GetGemmTileSize();
252 
253   EmitSmallGemm(
254       /*scalar_type=*/primitive_type,
255       /*m=*/m, /*k=*/k, /*n=*/n,
256       /*max_vectorization_width=*/max_target_vector_width,
257       /*max_vector_count=*/tile_size_n_in_vector_width,
258       /*min_vectorization_width=*/std::min<int64>(4, max_target_vector_width),
259       /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k, /*lhs=*/lhs,
260       /*rhs=*/rhs, /*result=*/target, b_, hlo_module_config_);
261 }
262 
EmitTiledLlvmIrGemv()263 void DotOpEmitter::EmitTiledLlvmIrGemv() {
264   PrimitiveType primitive_type = dot_info_.result_shape.element_type();
265 
266   CHECK(primitive_util::IsFloatingPointType(primitive_type) ||
267         primitive_util::IsIntegralType(primitive_type));
268 
269   MatMultDims mat_mult_dims = GetMatMultDims();
270   bool is_column_major_matrix_vector = false;
271   bool is_row_major_matrix_vector = false;
272 
273   int64 m, k;
274   bool swap_operands;
275 
276   if (mat_mult_dims.m == 1) {
277     bool rhs_effectively_row_major =
278         mat_mult_dims.rhs_non_canonical ^ !mat_mult_dims.rhs_column_major;
279     if (rhs_effectively_row_major) {
280       k = mat_mult_dims.k;
281       m = mat_mult_dims.n;
282       is_column_major_matrix_vector = true;
283       swap_operands = true;
284     } else {
285       k = mat_mult_dims.k;
286       m = mat_mult_dims.n;
287       is_row_major_matrix_vector = true;
288       swap_operands = true;
289     }
290   }
291 
292   if (mat_mult_dims.n == 1) {
293     bool lhs_effectively_column_major =
294         mat_mult_dims.lhs_non_canonical ^ mat_mult_dims.lhs_column_major;
295     if (lhs_effectively_column_major) {
296       m = mat_mult_dims.m;
297       k = mat_mult_dims.k;
298       is_column_major_matrix_vector = true;
299       swap_operands = false;
300     } else {
301       m = mat_mult_dims.m;
302       k = mat_mult_dims.k;
303       is_row_major_matrix_vector = true;
304       swap_operands = false;
305     }
306   }
307 
308   CHECK(is_column_major_matrix_vector || is_row_major_matrix_vector);
309 
310   int64 tiling_factor = GetGemvTilingFactor();
311   CHECK_GT(tiling_factor, 0);
312 
313   llvm::Value* result_op = target_array_.GetBasePointer();
314   llvm::Value* lhs_op =
315       swap_operands ? rhs_array_.GetBasePointer() : lhs_array_.GetBasePointer();
316   llvm::Value* rhs_op =
317       swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer();
318 
319   const int target_vector_register_element_size =
320       target_machine_features_.vector_register_num_elements(
321           *b_->GetInsertBlock()->getParent(), primitive_type);
322 
323   // We may not always know the vector register size for the target we're
324   // compiling against, in which case target_vector_register_element_size is 0.
325   // In these cases we choose a default LLVM IR register size.
326   const int kUnknownTargetVectorRegisterSize = 4;
327   const int vector_register_element_size =
328       target_vector_register_element_size == 0
329           ? kUnknownTargetVectorRegisterSize
330           : target_vector_register_element_size;
331 
332   if (is_column_major_matrix_vector) {
333     VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m
334             << " and k = " << k;
335     EmitColumnMajorGemv(
336         /*scalar_type=*/primitive_type,
337         /*tile_rows=*/vector_register_element_size, /*tile_cols=*/tiling_factor,
338         /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op,
339         /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr,
340         /*result=*/result_op, b_, hlo_module_config_);
341   } else {
342     VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m
343             << " and k = " << k;
344     EmitRowMajorGemv(
345         /*scalar_type=*/primitive_type,
346         /*tile_rows=*/tiling_factor,
347         /*tile_cols=*/vector_register_element_size,
348         /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op,
349         /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr,
350         /*result=*/result_op, b_, hlo_module_config_);
351   }
352 }
353 
Emit()354 Status DotOpEmitter::Emit() {
355   // The dot operation performs a sum of products over dimension 0 of the left
356   // hand side operand and dimension 1 of the right hand side operand.
357   //
358   // Let the shapes of lhs and rhs be defined as below:
359   //
360   //   lhs = [L{n-1} x L{n-2} x ... L{0}]
361   //   rhs = [R{m-1} x R{m-2} x ... R{0}]
362   //
363   // The sum-of-products dimension in the lhs has size L{0} and the dimension in
364   // the rhs has size R{1}. Necessarily, then:
365   //
366   //   L{0} == R{1}
367   //
368   // The output of the operation has the following shape:
369   //
370   //   output = [L{n-1} x L{n-2} x ... L{1} x R{m-1} x R{m-2} x ... R{2} x R{0}]
371   //
372   // To perform the operation we construct a loop nest with one for-loop for
373   // each dimension of the output. Inside this loop nest is another for-loop
374   // which performs the sum-of-products (the reduction loop) before storing
375   // the result in the output buffer.
376 
377   const Shape& lhs_shape = lhs_array_.GetShape();
378   const Shape& rhs_shape = rhs_array_.GetShape();
379 
380   if (ShapeUtil::IsScalar(lhs_shape) || ShapeUtil::IsScalar(rhs_shape)) {
381     // If the operands are scalar, don't emit any loops.
382     TF_RET_CHECK(ShapeUtil::IsScalar(lhs_shape) &&
383                  ShapeUtil::IsScalar(rhs_shape));
384     return EmitScalarDot();
385   }
386 
387   switch (GetDotImplementationStrategy(hlo_module_config_, dot_info_,
388                                        target_machine_features_)) {
389     case DotImplementationStrategy::kNaiveLlvmIr:
390       EmitNaiveLlvmIrGemm();
391       return Status::OK();
392 
393     case DotImplementationStrategy::kTiledLlvmIrGemv:
394       EmitTiledLlvmIrGemv();
395       return Status::OK();
396 
397     case DotImplementationStrategy::kTiledLlvmIrGemm:
398       EmitTiledLlvmIrGemm();
399       return Status::OK();
400 
401     case DotImplementationStrategy::kEigen:
402       return EmitCallToRuntime();
403   }
404 }
405 
EmitNaiveLlvmIrGemm()406 void DotOpEmitter::EmitNaiveLlvmIrGemm() {
407   CHECK_EQ(addend_array_, nullptr);
408 
409   const Shape& lhs_shape = lhs_array_.GetShape();
410   const Shape& rhs_shape = rhs_array_.GetShape();
411   const DotDimensionNumbers& dim_nums = dot_info_.dim_nums;
412 
413   // Reduce along dimension 0 of the LHS and 1 of the RHS. Vectors are a special
414   // case where the reduction dimension is 0 for both LHS and RHS. This results
415   // in a vector dot product producing a scalar.
416   int64 lhs_reduction_dimension = dim_nums.lhs_contracting_dimensions(0);
417   int64 rhs_reduction_dimension = dim_nums.rhs_contracting_dimensions(0);
418 
419   // Verify the reduction dimension in the two operands are the same size.
420   CHECK_EQ(lhs_shape.dimensions(lhs_reduction_dimension),
421            rhs_shape.dimensions(rhs_reduction_dimension));
422 
423   bool lhs_reduction_along_minor_dimension =
424       lhs_reduction_dimension == LayoutUtil::Minor(lhs_shape.layout(), 0);
425   bool rhs_reduction_along_minor_dimension =
426       rhs_reduction_dimension == LayoutUtil::Minor(rhs_shape.layout(), 0);
427 
428   // Create loop nests which loop through the LHS operand dimensions and the RHS
429   // operand dimensions. The reduction dimension of the LHS and RHS are handled
430   // in a separate innermost loop which performs the sum of products.
431   llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(dot_hlo_name_), b_);
432   std::vector<llvm::Value*> lhs_multi_index =
433       loop_nest.EmitOperandArrayLoopNest(
434           lhs_array_, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs");
435   std::vector<llvm::Value*> rhs_multi_index =
436       loop_nest.EmitOperandArrayLoopNest(
437           rhs_array_, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs");
438 
439   // Create the loop which does the sum of products reduction.
440   //
441   // The prevent_unrolling bit is working around a deficiency in LLVM's loop
442   // vectorization pipeline, wherein in some cases unrolling a loop can prevent
443   // effective vectorization.  Since we know that the IR we generate when
444   // reducing across the minor dimension in both LHS and RHS is vectorized well
445   // by the loop vectorizer, we block unrolling in that case to stop loop unroll
446   // from messing up the vectorization.
447   std::unique_ptr<llvm_ir::ForLoop> reduction_loop = loop_nest.AddLoop(
448       0, lhs_shape.dimensions(lhs_reduction_dimension), "reduction",
449       /*unroll_mode=*/
450       (lhs_reduction_along_minor_dimension &&
451        rhs_reduction_along_minor_dimension)
452           ? xla::llvm_ir::UnrollMode::kNoUnroll
453           : xla::llvm_ir::UnrollMode::kDefaultUnroll);
454 
455   // The final entry in the rhs and lhs indexes is the indvar of the
456   // reduction loop.
457   lhs_multi_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue();
458   llvm_ir::IrArray::Index lhs_index(lhs_multi_index, lhs_shape,
459                                     b_->getInt64Ty());
460   rhs_multi_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue();
461   llvm_ir::IrArray::Index rhs_index(rhs_multi_index, rhs_shape,
462                                     b_->getInt64Ty());
463 
464   // For computing the sum of products we alloca a single location to store the
465   // dot product result as we accumulate it within the reduction loop. After the
466   // reduction loop we load the result and store into the output array.
467 
468   // Function entry basic block.
469   // - Emit alloca for accumulator
470   llvm::Function* func = reduction_loop->GetPreheaderBasicBlock()->getParent();
471   SetToFirstInsertPoint(&func->getEntryBlock(), b_);
472   llvm::Type* accum_type = target_array_.GetElementLlvmType();
473   llvm::Value* accum_address =
474       b_->CreateAlloca(accum_type, /*ArraySize=*/nullptr, "accum_address");
475 
476   // Preheader basic block of reduction loop:
477   // - Initialize accumulator to zero.
478   llvm::BasicBlock* preheader_bb = reduction_loop->GetPreheaderBasicBlock();
479   b_->SetInsertPoint(preheader_bb->getTerminator());
480 
481   b_->CreateStore(llvm::Constant::getNullValue(accum_type), accum_address);
482 
483   // Body basic block of reduction loop:
484   // - Load elements from lhs and rhs array.
485   // - Multiply lhs-element and rhs-element.
486   // - Load accumulator and add to product.
487   // - Store sum back into accumulator.
488   SetToFirstInsertPoint(reduction_loop->GetBodyBasicBlock(), b_);
489 
490   llvm::Value* lhs_element = lhs_array_.EmitReadArrayElement(lhs_index, b_);
491   llvm::Value* rhs_element = rhs_array_.EmitReadArrayElement(rhs_index, b_);
492 
493   llvm::Value* accum = b_->CreateLoad(accum_address);
494   llvm::Value* updated_accum;
495   if (ShapeUtil::ElementIsComplex(lhs_shape)) {
496     auto real = [&](llvm::Value* x) { return b_->CreateExtractValue(x, {0}); };
497     auto imag = [&](llvm::Value* x) { return b_->CreateExtractValue(x, {1}); };
498     llvm::Value* product_real =
499         b_->CreateFSub(b_->CreateFMul(real(lhs_element), real(rhs_element)),
500                        b_->CreateFMul(imag(lhs_element), imag(rhs_element)));
501     llvm::Value* product_imag =
502         b_->CreateFAdd(b_->CreateFMul(real(lhs_element), imag(rhs_element)),
503                        b_->CreateFMul(imag(lhs_element), real(rhs_element)));
504     updated_accum = b_->CreateInsertValue(
505         accum, b_->CreateFAdd(real(accum), product_real), {0});
506     updated_accum = b_->CreateInsertValue(
507         updated_accum, b_->CreateFAdd(imag(accum), product_imag), {1});
508   } else {
509     llvm::Value* product = b_->CreateFMul(lhs_element, rhs_element);
510     updated_accum = b_->CreateFAdd(accum, product);
511   }
512   b_->CreateStore(updated_accum, accum_address);
513 
514   // Exit basic block of reduction loop.
515   // - Load accumulator value (the result).
516   // - Store into output array.
517   SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), b_);
518 
519   llvm::Value* result = b_->CreateLoad(accum_address);
520 
521   // Create index into target address. The target index is the concatenation of
522   // the rhs and lhs indexes with the reduction dimensions removed. The terms
523   // from the rhs index are the lower dimensions in the index so we add them
524   // first.
525   std::vector<llvm::Value*> target_multi_index;
526   for (int dimension = 0; dimension < lhs_index.size(); ++dimension) {
527     if (dimension != lhs_reduction_dimension) {
528       target_multi_index.push_back(lhs_index[dimension]);
529     }
530   }
531   for (int dimension = 0; dimension < rhs_index.size(); ++dimension) {
532     if (dimension != rhs_reduction_dimension) {
533       target_multi_index.push_back(rhs_index[dimension]);
534     }
535   }
536 
537   llvm_ir::IrArray::Index target_index(
538       target_multi_index, target_array_.GetShape(), lhs_index.GetType());
539   target_array_.EmitWriteArrayElement(target_index, result, b_);
540 
541   // Set the IR builder insert point to the exit basic block of the outer most
542   // loop.
543   b_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
544 }
545 
EmitScalarDot()546 Status DotOpEmitter::EmitScalarDot() {
547   // A scalar dot is just a scalar multiply.
548   llvm::Value* result;
549   // Use the same index_type for all tensor accesses in the same kernel.
550   llvm::Type* index_type = b_->getInt64Ty();
551   llvm_ir::IrArray::Index element_index(index_type);
552   llvm::Value* lhs_value =
553       lhs_array_.EmitReadArrayElement(/*index=*/element_index, b_);
554   llvm::Value* rhs_value =
555       rhs_array_.EmitReadArrayElement(/*index=*/element_index, b_);
556   if (ShapeUtil::ElementIsComplex(lhs_array_.GetShape())) {
557     auto get_real = [&](llvm::Value* x) {
558       return b_->CreateExtractValue(x, {0});
559     };
560 
561     auto get_imag = [&](llvm::Value* x) {
562       return b_->CreateExtractValue(x, {1});
563     };
564 
565     llvm::Value* real = b_->CreateFSub(
566         b_->CreateFMul(get_real(lhs_value), get_real(rhs_value)),
567         b_->CreateFMul(get_imag(lhs_value), get_imag(rhs_value)));
568     llvm::Value* imag = b_->CreateFAdd(
569         b_->CreateFMul(get_real(lhs_value), get_imag(rhs_value)),
570         b_->CreateFMul(get_imag(lhs_value), get_real(rhs_value)));
571     result = llvm::ConstantAggregateZero::get(lhs_array_.GetElementLlvmType());
572     result = b_->CreateInsertValue(result, real, {0});
573     result = b_->CreateInsertValue(result, imag, {1});
574   } else {
575     result = b_->CreateFMul(lhs_value, rhs_value);
576   }
577   target_array_.EmitWriteArrayElement(/*index=*/element_index, result, b_);
578   return Status::OK();
579 }
580 
EmitCallToRuntime()581 Status DotOpEmitter::EmitCallToRuntime() {
582   // The signature of the Eigen runtime matmul function is:
583   //
584   //   (void)(void* run_options, float* out, float* lhs, float* rhs,
585   //          int64 m, int64 n, int64 k, int32 transpose_lhs,
586   //          int32 transpose_rhs);
587   // The two transpose_... parameters are actually booleans, but we use int32
588   // to avoid target-dependent calling convention details.
589 
590   bool multi_threaded = ShouldUseMultiThreadedEigen(hlo_module_config_);
591   bool use_mkl_dnn = hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn();
592   PrimitiveType type = target_array_.GetShape().element_type();
593   llvm::Type* float_type;
594   const char* fn_name;
595   switch (type) {
596     case F16:
597       fn_name = multi_threaded
598                     ? runtime::kEigenMatMulF16SymbolName
599                     : runtime::kEigenSingleThreadedMatMulF16SymbolName;
600       float_type = b_->getHalfTy();
601       break;
602     case F32:
603       fn_name = multi_threaded
604                     ? (use_mkl_dnn ? runtime::kMKLMatMulF32SymbolName
605                                    : runtime::kEigenMatMulF32SymbolName)
606                     : (use_mkl_dnn
607                            ? runtime::kMKLSingleThreadedMatMulF32SymbolName
608                            : runtime::kEigenSingleThreadedMatMulF32SymbolName);
609       float_type = b_->getFloatTy();
610       break;
611     case F64:
612       fn_name = multi_threaded
613                     ? (use_mkl_dnn ? runtime::kMKLMatMulF64SymbolName
614                                    : runtime::kEigenMatMulF64SymbolName)
615                     : (use_mkl_dnn
616                            ? runtime::kMKLSingleThreadedMatMulF64SymbolName
617                            : runtime::kEigenSingleThreadedMatMulF64SymbolName);
618       float_type = b_->getDoubleTy();
619       break;
620     default:
621       return Unimplemented("Invalid type %s for dot operation",
622                            PrimitiveType_Name(type));
623   }
624 
625   llvm::Type* float_ptr_type = float_type->getPointerTo();
626   llvm::Type* int64_type = b_->getInt64Ty();
627   llvm::Type* int32_type = b_->getInt32Ty();
628   llvm::Type* int8_ptr_type = b_->getInt8Ty()->getPointerTo();
629   llvm::FunctionType* matmul_type = llvm::FunctionType::get(
630       b_->getVoidTy(),
631       {int8_ptr_type, float_ptr_type, float_ptr_type, float_ptr_type,
632        int64_type, int64_type, int64_type, int32_type, int32_type},
633       /*isVarArg=*/false);
634 
635   llvm::Function* function = b_->GetInsertBlock()->getParent();
636   llvm::Module* module = function->getParent();
637 
638   llvm::FunctionCallee matmul_func =
639       module->getOrInsertFunction(fn_name, matmul_type);
640   if (auto* fn = llvm::dyn_cast<llvm::Function>(matmul_func.getCallee())) {
641     fn->setCallingConv(llvm::CallingConv::C);
642     fn->setDoesNotThrow();
643     fn->setOnlyAccessesArgMemory();
644   }
645 
646   // The Eigen runtime function expects column-major layout. If the matrices are
647   // row major, then use the following identity to compute the product:
648   //
649   //   (A x B)^T = B^T x A^T
650   //
651   // The connection between this identity and memory layout is that the
652   // transpose operation can also be considered as an operation that changes the
653   // memory layout of a matrix from row-major to column-major or vice versa.
654   //
655   // Effectively this involves swapping the 'lhs' with 'rhs' and 'm' with 'n'.
656 
657   MatMultDims mat_mult_dims = GetMatMultDims();
658 
659   CHECK_EQ(mat_mult_dims.lhs_column_major, mat_mult_dims.rhs_column_major);
660 
661   const llvm_ir::IrArray* lhs = &lhs_array_;
662   const llvm_ir::IrArray* rhs = &rhs_array_;
663   bool transpose_lhs = mat_mult_dims.lhs_non_canonical;
664   bool transpose_rhs = mat_mult_dims.rhs_non_canonical;
665 
666   if (!mat_mult_dims.lhs_column_major) {
667     std::swap(mat_mult_dims.m, mat_mult_dims.n);
668     std::swap(lhs, rhs);
669     std::swap(transpose_lhs, transpose_rhs);
670   }
671 
672   b_->CreateCall(
673       matmul_func,
674       {b_->CreateBitCast(executable_run_options_value_, int8_ptr_type),
675        b_->CreateBitCast(target_array_.GetBasePointer(), float_ptr_type),
676        b_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type),
677        b_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type),
678        b_->getInt64(mat_mult_dims.m), b_->getInt64(mat_mult_dims.n),
679        b_->getInt64(mat_mult_dims.k), b_->getInt32(transpose_lhs),
680        b_->getInt32(transpose_rhs)});
681   return Status::OK();
682 }
683 
GetMatMultDims() const684 DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
685   CHECK_EQ(dot_info_.result_shape.dimensions_size(), 2);
686 
687   const Shape& lhs_shape = lhs_array_.GetShape();
688   const Shape& rhs_shape = rhs_array_.GetShape();
689   const DotDimensionNumbers& dim_nums = dot_info_.dim_nums;
690 
691   return {
692       /*m=*/lhs_shape.dimensions(1 - dim_nums.lhs_contracting_dimensions(0)),
693       /*k=*/lhs_shape.dimensions(dim_nums.lhs_contracting_dimensions(0)),
694       /*n=*/rhs_shape.dimensions(1 - dim_nums.rhs_contracting_dimensions(0)),
695       /*lhs_column_major=*/LayoutUtil::Minor(lhs_shape.layout(), 0) == 0,
696       /*lhs_non_canonical=*/dim_nums.lhs_contracting_dimensions(0) == 0,
697       /*rhs_column_major=*/LayoutUtil::Minor(rhs_shape.layout(), 0) == 0,
698       /*rhs_non_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 1,
699       /*target_column_major=*/
700       LayoutUtil::Minor(target_array_.GetShape().layout(), 0) == 0};
701 }
702 
703 // For vector-matrix dot products, it is always profitable to make the Rhs
704 // column major.
ProfitableToMakeDotOperandColumnMajor(const HloInstruction & hlo)705 absl::optional<int64> ProfitableToMakeDotOperandColumnMajor(
706     const HloInstruction& hlo) {
707   if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() == 2 &&
708       hlo.shape().dimensions(0) == 1) {
709     if (hlo.dot_dimension_numbers().rhs_contracting_dimensions(0) == 0) {
710       return 1;
711     }
712     return {};
713   }
714 
715   if (hlo.opcode() == HloOpcode::kFusion &&
716       hlo.fusion_kind() == HloInstruction::FusionKind::kOutput) {
717     auto* fusion_root =
718         hlo.fused_instructions_computation()->root_instruction();
719     if (fusion_root->opcode() != HloOpcode::kAdd) {
720       return {};
721     }
722 
723     for (auto* fusion_root_op : fusion_root->operands()) {
724       if (fusion_root_op->opcode() != HloOpcode::kDot) {
725         continue;
726       }
727       if (auto operand_num =
728               ProfitableToMakeDotOperandColumnMajor(*fusion_root_op)) {
729         auto* operand = fusion_root_op->operand(*operand_num);
730         if (operand->opcode() == HloOpcode::kParameter &&
731             operand->user_count() == 1) {
732           return operand->parameter_number();
733         }
734       }
735     }
736   }
737 
738   return {};
739 }
740 
741 namespace {
742 // Return whether the given shape is rank 2.
IsRank2(const Shape & shape)743 bool IsRank2(const Shape& shape) { return shape.rank() == 2; }
744 
IsSimpleLayout(const Layout & layout)745 bool IsSimpleLayout(const Layout& layout) {
746   return layout.tiles().empty() && layout.format() == DENSE;
747 }
748 
749 // In a gemm operation where output = lhs * rhs, check whether the given shapes
750 // are valid for the operation.
AreGemmShapes(const Shape & lhs_shape,const Shape & rhs_shape,const Shape & output_shape,const TargetMachineFeatures & target_machine_features)751 bool AreGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
752                    const Shape& output_shape,
753                    const TargetMachineFeatures& target_machine_features) {
754   CHECK(!lhs_shape.has_layout() || IsSimpleLayout(lhs_shape.layout()))
755       << lhs_shape.DebugString();
756   CHECK(!rhs_shape.has_layout() || IsSimpleLayout(rhs_shape.layout()))
757       << rhs_shape.DebugString();
758   CHECK(!output_shape.has_layout() || IsSimpleLayout(output_shape.layout()))
759       << output_shape.DebugString();
760 
761   switch (output_shape.element_type()) {
762     case F64:
763     case F32:
764     case F16:
765       return IsRank2(lhs_shape) && IsRank2(rhs_shape) && IsRank2(output_shape);
766     default:
767       return false;
768   }
769 }
770 
IsAlignedGemm(const DotInfo & dot_info,const TargetMachineFeatures & target_machine_features)771 bool IsAlignedGemm(const DotInfo& dot_info,
772                    const TargetMachineFeatures& target_machine_features) {
773   if (ShapeUtil::IsZeroElementArray(dot_info.lhs_shape) ||
774       ShapeUtil::IsZeroElementArray(dot_info.rhs_shape)) {
775     return false;
776   }
777 
778   return AreGemmShapes(dot_info.lhs_shape, dot_info.rhs_shape,
779                        dot_info.result_shape, target_machine_features);
780 }
781 
CanEmitTiledLlvmIrGemm(const HloModuleConfig & config,const DotInfo & dot_info,const TargetMachineFeatures & target_machine_features)782 bool CanEmitTiledLlvmIrGemm(
783     const HloModuleConfig& config, const DotInfo& dot_info,
784     const TargetMachineFeatures& target_machine_features) {
785   CHECK(IsAlignedGemm(dot_info, target_machine_features));
786 
787   if (ShouldUseMultiThreadedEigen(config)) {
788     return false;
789   }
790 
791   int m = dot_info.result_shape.dimensions(0);
792   int k = dot_info.lhs_shape.dimensions(
793       dot_info.dim_nums.lhs_contracting_dimensions(0));
794   int n = dot_info.result_shape.dimensions(1);
795 
796   if (!options::ForceEnableExperimentalLlvmIrGemm(config)) {
797     // TODO(sanjoy):  We should make these numbers micro-arch specific.
798     bool small_gemm =
799         k <= 128 && ((m <= 32 && n <= 128) || (m <= 128 && n <= 32));
800     if (!small_gemm) {
801       return false;
802     }
803   }
804 
805   bool lhs_non_canonical = dot_info.dim_nums.lhs_contracting_dimensions(0) == 0;
806   bool rhs_non_canonical = dot_info.dim_nums.rhs_contracting_dimensions(0) == 1;
807 
808   if (lhs_non_canonical || rhs_non_canonical) {
809     return false;
810   }
811 
812   if (dot_info.result_shape.element_type() == F16) {
813     // TODO(sanjoy): This is probably easy to fix, but I want to keep the CL
814     // adding this comment NFC.
815     return false;
816   }
817 
818   return true;
819 }
820 
GetDotImplementationStrategy(const HloModuleConfig & config,const DotInfo & dot_info,const TargetMachineFeatures & target_machine_features)821 DotImplementationStrategy GetDotImplementationStrategy(
822     const HloModuleConfig& config, const DotInfo& dot_info,
823     const TargetMachineFeatures& target_machine_features) {
824   PrimitiveType element_type = dot_info.result_shape.element_type();
825   // Any Matrix-Vector product of floating point or integral type, or
826   // a transpose-dot fusion of the same can be lowered to a tiled LLVM
827   // IR implementation.
828   if (dot_info.result_shape.dimensions_size() == 2 &&
829       (dot_info.result_shape.dimensions(0) == 1 ||
830        dot_info.result_shape.dimensions(1) == 1) &&
831       (primitive_util::IsFloatingPointType(element_type) ||
832        primitive_util::IsIntegralType(element_type))) {
833     return DotImplementationStrategy::kTiledLlvmIrGemv;
834   }
835 
836   if (IsAlignedGemm(dot_info, target_machine_features)) {
837     return CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features)
838                ? DotImplementationStrategy::kTiledLlvmIrGemm
839                : DotImplementationStrategy::kEigen;
840   }
841 
842   return DotImplementationStrategy::kNaiveLlvmIr;
843 }
844 
EmitNonBatchDotOperation(DotInfo dot_info,string hlo_name,const llvm_ir::IrArray & target_array,const llvm_ir::IrArray & lhs_array,const llvm_ir::IrArray & rhs_array,const llvm_ir::IrArray * addend_array,llvm::Value * executable_run_options_value,llvm::IRBuilder<> * b,const HloModuleConfig & hlo_module_config,const TargetMachineFeatures & target_machine_features)845 Status EmitNonBatchDotOperation(
846     DotInfo dot_info, string hlo_name, const llvm_ir::IrArray& target_array,
847     const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
848     const llvm_ir::IrArray* addend_array,
849     llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
850     const HloModuleConfig& hlo_module_config,
851     const TargetMachineFeatures& target_machine_features) {
852   PrimitiveType type = target_array.GetShape().element_type();
853   TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type ||
854                C128 == type);
855   DotOpEmitter dot_emitter(std::move(dot_info), std::move(hlo_name),
856                            target_array, lhs_array, rhs_array, addend_array,
857                            executable_run_options_value, b, hlo_module_config,
858                            target_machine_features);
859   return dot_emitter.Emit();
860 }
861 
// Returns a shape with `shape`'s element type and all of its dimensions except
// the first (index 0), using a descending (row-major) layout.
// NOTE(review): `shape` presumably must have rank >= 1. Removing a prefix from
// an empty span is invalid.
Shape DropFirstDim(const Shape& shape) {
  absl::Span<const int64> trailing_dims = shape.dimensions();
  trailing_dims.remove_prefix(1);
  const PrimitiveType element_type = shape.element_type();
  return ShapeUtil::MakeShapeWithDescendingLayout(element_type, trailing_dims);
}
868 
// Returns a row-major (descending layout) shape whose first dimension is the
// product of the first `n` dimensions of `shape`, followed by the remaining
// dimensions of `shape` unchanged. For example, collapsing n = 2 of
// [2, 3, 4, 5] gives [6, 4, 5].
Shape CollapseFirstNDims(const Shape& shape, int64 n) {
  absl::Span<const int64> dims = shape.dimensions();

  int64 collapsed_extent = 1;
  for (int64 i = 0; i < n; ++i) {
    collapsed_extent *= dims[i];
  }

  DimensionVector collapsed_dims;
  collapsed_dims.push_back(collapsed_extent);
  collapsed_dims.insert(collapsed_dims.end(), dims.begin() + n, dims.end());

  return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
                                                  collapsed_dims);
}
881 
CollapseFirstNDims(llvm::IRBuilder<> * b,const llvm_ir::IrArray & array,int64 n)882 llvm_ir::IrArray CollapseFirstNDims(llvm::IRBuilder<>* b,
883                                     const llvm_ir::IrArray& array, int64 n) {
884   llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
885   const Shape& shape = array.GetShape();
886   CHECK(shape.has_layout() &&
887         LayoutUtil::IsMonotonicWithDim0Major(shape.layout()));
888   CHECK_GE(shape.dimensions_size(), n);
889   Shape new_shape = CollapseFirstNDims(shape, n);
890   llvm::Value* new_value = b->CreateBitCast(
891       array.GetBasePointer(),
892       llvm_ir::ShapeToIrType(new_shape, module)->getPointerTo());
893   return llvm_ir::IrArray(new_value, std::move(new_shape));
894 }
895 
ValidateDotDimensionNumbers(const DotDimensionNumbers & dim_numbers)896 Status ValidateDotDimensionNumbers(const DotDimensionNumbers& dim_numbers) {
897   // Checks some invariants that do not hold in general, but DotDecomposer
898   // should have established for us.  This is just a debugging aid.
899   TF_RET_CHECK(dim_numbers.lhs_contracting_dimensions_size() == 1);
900   std::vector<int64> batch_dim_numbers(dim_numbers.lhs_batch_dimensions_size());
901   absl::c_iota(batch_dim_numbers, 0);
902   TF_RET_CHECK(
903       absl::c_equal(batch_dim_numbers, dim_numbers.lhs_batch_dimensions()));
904   TF_RET_CHECK(
905       absl::c_equal(batch_dim_numbers, dim_numbers.rhs_batch_dimensions()));
906   return Status::OK();
907 }
908 
909 // Slice out the inner array at batch index `batch_index` from `outer_array`.
SliceOutInnerArray(llvm_ir::IrArray outer_array,llvm::Value * batch_index,llvm::IRBuilder<> * b)910 llvm_ir::IrArray SliceOutInnerArray(llvm_ir::IrArray outer_array,
911                                     llvm::Value* batch_index,
912                                     llvm::IRBuilder<>* b) {
913   llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
914 
915   Shape inner_shape = DropFirstDim(outer_array.GetShape());
916   std::vector<llvm::Value*> multidim_index(inner_shape.rank() + 1,
917                                            b->getInt64(0));
918   multidim_index[0] = batch_index;
919   llvm_ir::IrArray::Index slice_index(multidim_index, outer_array.GetShape(),
920                                       batch_index->getType());
921   llvm::Value* slice_ptr = outer_array.EmitArrayElementAddress(slice_index, b);
922   llvm::Type* slice_ptr_type =
923       llvm_ir::ShapeToIrType(inner_shape, module)->getPointerTo();
924   return llvm_ir::IrArray(b->CreateBitCast(slice_ptr, slice_ptr_type),
925                           std::move(inner_shape));
926 }
927 
EmitBatchDotOperation(const HloInstruction & dot,const llvm_ir::IrArray & target_array,const llvm_ir::IrArray & lhs_array,const llvm_ir::IrArray & rhs_array,llvm::Value * executable_run_options_value,llvm::IRBuilder<> * b,const HloModuleConfig & hlo_module_config,const TargetMachineFeatures & target_machine_features)928 Status EmitBatchDotOperation(
929     const HloInstruction& dot, const llvm_ir::IrArray& target_array,
930     const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
931     llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
932     const HloModuleConfig& hlo_module_config,
933     const TargetMachineFeatures& target_machine_features) {
934   TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(dot.dot_dimension_numbers()));
935 
936   // Lower a batch dot into a sequence of non-batch dot operations.
937 
938   int64 num_batch_dims =
939       dot.dot_dimension_numbers().lhs_batch_dimensions_size();
940 
941   // First reshape the inputs to make sure we only have one batch dimension.
942   // This is a no-op bitcast because the operands have to be in row-major layout
943   // (enforced in CpuLayoutAssignment), and the batch dimensions are the leading
944   // dimensions (established by DotDecomposer and checked by
945   // ValidateDotDimensionNumbers above).
946   llvm_ir::IrArray lhs_array_reshaped =
947       CollapseFirstNDims(b, lhs_array, num_batch_dims);
948   llvm_ir::IrArray rhs_array_reshaped =
949       CollapseFirstNDims(b, rhs_array, num_batch_dims);
950   llvm_ir::IrArray target_array_reshaped =
951       CollapseFirstNDims(b, target_array, num_batch_dims);
952 
953   int64 batch_count = lhs_array_reshaped.GetShape().dimensions(0);
954 
955   KernelSupportLibrary ksl(b);
956 
957   return ksl.ForWithStatus(
958       llvm_ir::IrName(&dot, "bdot"), /*start=*/0, /*end=*/batch_count,
959       /*step=*/1, [&](llvm::Value* indvar) {
960         DotDimensionNumbers adjusted_dim_numbers = dot.dot_dimension_numbers();
961         adjusted_dim_numbers.clear_lhs_batch_dimensions();
962         adjusted_dim_numbers.clear_rhs_batch_dimensions();
963 
964         // Create a DotInfo representing the "inner" non-batch dot operation.
965         DotInfo dot_info;
966         dot_info.lhs_shape = DropFirstDim(lhs_array_reshaped.GetShape());
967         dot_info.rhs_shape = DropFirstDim(rhs_array_reshaped.GetShape());
968         dot_info.result_shape = DropFirstDim(target_array_reshaped.GetShape());
969         dot_info.dim_nums = dot.dot_dimension_numbers();
970         dot_info.dim_nums.clear_lhs_batch_dimensions();
971         dot_info.dim_nums.clear_rhs_batch_dimensions();
972 
973         dot_info.dim_nums.set_lhs_contracting_dimensions(
974             0,
975             dot_info.dim_nums.lhs_contracting_dimensions(0) - num_batch_dims);
976         dot_info.dim_nums.set_rhs_contracting_dimensions(
977             0,
978             dot_info.dim_nums.rhs_contracting_dimensions(0) - num_batch_dims);
979 
980         llvm_ir::IrArray lhs_slice =
981             SliceOutInnerArray(lhs_array_reshaped, /*batch_index=*/indvar, b);
982         llvm_ir::IrArray rhs_slice =
983             SliceOutInnerArray(rhs_array_reshaped, /*batch_index=*/indvar, b);
984         llvm_ir::IrArray target_slice = SliceOutInnerArray(
985             target_array_reshaped, /*batch_index=*/indvar, b);
986 
987         // Emit the inner non-batch dot operation.
988         return EmitNonBatchDotOperation(
989             dot_info, dot.name(), target_slice, lhs_slice, rhs_slice, nullptr,
990             executable_run_options_value, b, hlo_module_config,
991             target_machine_features);
992       });
993 }
994 
IsBatchDot(const HloInstruction & instr)995 bool IsBatchDot(const HloInstruction& instr) {
996   if (auto* dot_instr = DynCast<HloDotInstruction>(&instr)) {
997     return dot_instr->dot_dimension_numbers().lhs_batch_dimensions_size() > 0;
998   }
999 
1000   return false;
1001 }
1002 }  // namespace
1003 
DotImplementationCanHandleTranspose(const HloInstruction & dot_instr,const TargetMachineFeatures & target_machine_features)1004 bool DotImplementationCanHandleTranspose(
1005     const HloInstruction& dot_instr,
1006     const TargetMachineFeatures& target_machine_features) {
1007   DotImplementationStrategy impl_strategy =
1008       GetDotImplementationStrategy(dot_instr.parent()->parent()->config(),
1009                                    DotInfo(dot_instr), target_machine_features);
1010 
1011   // TODO(sanjoy): This is not quite right, it should be `impl_strategy ==
1012   // kEigen || impl_strategy == kTiledLlvmIrGemv || impl_strategy ==
1013   // kNaiveLlvmIr` but I'll fix this in a later CL in the interest of keeping
1014   // the CL adding this comment NFC.
1015   return impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemm ||
1016          impl_strategy == DotImplementationStrategy::kEigen;
1017 }
1018 
DotOperandsAndResultMustHaveRowMajorLayout(const HloInstruction & dot_instr,const TargetMachineFeatures & target_machine_features)1019 bool DotOperandsAndResultMustHaveRowMajorLayout(
1020     const HloInstruction& dot_instr,
1021     const TargetMachineFeatures& target_machine_features) {
1022   DotImplementationStrategy impl_strategy =
1023       GetDotImplementationStrategy(dot_instr.parent()->parent()->config(),
1024                                    DotInfo(dot_instr), target_machine_features);
1025 
1026   return impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemm ||
1027          impl_strategy == DotImplementationStrategy::kEigen;
1028 }
1029 
EmitDotOperation(const HloInstruction & dot,const llvm_ir::IrArray & target_array,const llvm_ir::IrArray & lhs_array,const llvm_ir::IrArray & rhs_array,const llvm_ir::IrArray * addend_array,llvm::Value * executable_run_options_value,llvm::IRBuilder<> * b,const HloModuleConfig & hlo_module_config,const TargetMachineFeatures & target_machine_features)1030 Status EmitDotOperation(const HloInstruction& dot,
1031                         const llvm_ir::IrArray& target_array,
1032                         const llvm_ir::IrArray& lhs_array,
1033                         const llvm_ir::IrArray& rhs_array,
1034                         const llvm_ir::IrArray* addend_array,
1035                         llvm::Value* executable_run_options_value,
1036                         llvm::IRBuilder<>* b,
1037                         const HloModuleConfig& hlo_module_config,
1038                         const TargetMachineFeatures& target_machine_features) {
1039   // This routine assumes that the dot operation is not in a parallelized
1040   // enclosing computation.
1041   CHECK(dot.parent()->root_instruction()->outer_dimension_partitions().empty());
1042 
1043   if (IsBatchDot(dot)) {
1044     TF_RET_CHECK(addend_array == nullptr);
1045     return EmitBatchDotOperation(dot, target_array, lhs_array, rhs_array,
1046                                  executable_run_options_value, b,
1047                                  hlo_module_config, target_machine_features);
1048   }
1049 
1050   return EmitNonBatchDotOperation(DotInfo(dot), dot.name(), target_array,
1051                                   lhs_array, rhs_array, addend_array,
1052                                   executable_run_options_value, b,
1053                                   hlo_module_config, target_machine_features);
1054 }
1055 }  // namespace cpu
1056 }  // namespace xla
1057