1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
17 
18 #include <string>
19 #include <unordered_map>
20 #include <utility>
21 
22 #include "tensorflow/core/platform/logging.h"
23 // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
24 #include "absl/algorithm/container.h"
25 #include "llvm/IR/BasicBlock.h"
26 #include "llvm/IR/Constants.h"
27 #include "llvm/IR/Instructions.h"
28 #include "llvm/IR/Module.h"
29 #include "tensorflow/compiler/xla/primitive_util.h"
30 #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
31 #include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"
32 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h"
33 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
34 #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
35 #include "tensorflow/compiler/xla/service/hlo_computation.h"
36 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
37 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
38 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
39 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
40 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
41 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
42 #include "tensorflow/compiler/xla/service/name_uniquer.h"
43 #include "tensorflow/compiler/xla/shape_util.h"
44 #include "tensorflow/compiler/xla/status_macros.h"
45 #include "tensorflow/compiler/xla/types.h"
46 #include "tensorflow/compiler/xla/util.h"
47 #include "tensorflow/compiler/xla/window_util.h"
48 #include "tensorflow/core/lib/core/errors.h"
49 
50 namespace xla {
51 
52 using llvm_ir::IrName;
53 using llvm_ir::SetToFirstInsertPoint;
54 
55 namespace gpu {
56 
57 IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config,
58                      IrEmitterContext* ir_emitter_context, bool is_nested)
59     : ir_emitter_context_(ir_emitter_context),
60       module_(ir_emitter_context->llvm_module()),
61       b_(module_->getContext()),
62       bindings_(ir_emitter_context->hlo_module(),
63                 &ir_emitter_context->buffer_assignment(), &b_, module_,
64                 is_nested),
65       hlo_module_config_(hlo_module_config) {
66 }
67 
68 Status IrEmitter::DefaultAction(HloInstruction* hlo) {
69   ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
70   for (const HloInstruction* operand : hlo->operands()) {
71     operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
72       return GetIrArray(*operand, *hlo).EmitReadArrayElement(index, &b_);
73     };
74   }
75   return EmitTargetElementLoop(
76       *hlo, GpuElementalIrEmitter(hlo_module_config_, module_, &b_,
77                                   GetNestedComputer())
78                 .MakeElementGenerator(hlo, operand_to_generator));
79 }
80 
81 Status IrEmitter::HandleConstant(HloInstruction* constant) {
82   return Status::OK();
83 }
84 
85 Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
86   VLOG(2) << "HandleBitcast: " << bitcast->ToString();
87   const HloInstruction* operand = bitcast->operand(0);
88   // Bitcast is a no-op, but we still want to bind it to an llvm::Value
89   // sometimes, e.g., when its operand is a constant or a bitcast of a
90   // constant.
91   if (bindings_.BoundToIrValue(*operand)) {
92     bindings_.BindHloToIrValue(*bitcast, GetBasePointer(*operand));
93   }
94   return Status::OK();
95 }
96 
97 Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) {
98   VLOG(2) << "HandleAddDependency: " << add_dependency->ToString();
99   const HloInstruction* operand = add_dependency->operand(0);
100   // AddDependency is a no-op, but we still want to bind it to an llvm::Value
101   // sometimes, e.g., when its operand is a constant or a bitcast of a
102   // constant.
103   if (bindings_.BoundToIrValue(*operand)) {
104     bindings_.BindHloToIrValue(*add_dependency, GetBasePointer(*operand));
105   }
106   return Status::OK();
107 }
108 
109 Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
110   auto operand = get_tuple_element->operand(0);
111   CHECK(bindings_.BoundToIrValue(*operand));
112   bindings_.BindHloToIrValue(
113       *get_tuple_element,
114       llvm_ir::EmitGetTupleElement(
115           get_tuple_element->shape(), get_tuple_element->tuple_index(),
116           // TODO(b/26344050): tighten the alignment here
117           // based on the real element type.
118           /*alignment=*/1, GetBasePointer(*operand), &b_));
119   return Status::OK();
120 }
121 
122 Status IrEmitter::HandleSend(HloInstruction*) {
123   return Unimplemented("Send is not implemented on GPU");
124 }
125 
126 Status IrEmitter::HandleSendDone(HloInstruction*) {
127   return Unimplemented("Send-Done is not implemented on GPU");
128 }
129 
130 Status IrEmitter::HandleRecv(HloInstruction*) {
131   return Unimplemented("Recv is not implemented on GPU");
132 }
133 
134 Status IrEmitter::HandleRecvDone(HloInstruction*) {
135   return Unimplemented("Recv-done is not implemented on GPU");
136 }
137 
138 Status IrEmitter::HandleScatter(HloInstruction*) {
139   return Unimplemented("Scatter is not implemented on GPUs.");
140 }
141 
142 Status IrEmitter::HandleTuple(HloInstruction* tuple) {
143   std::vector<llvm::Value*> base_ptrs;
144   for (const HloInstruction* operand : tuple->operands()) {
145     base_ptrs.push_back(GetBasePointer(*operand));
146   }
147   llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &b_);
148   return Status::OK();
149 }
150 
151 Status IrEmitter::EmitCallToNestedComputation(
152     const HloComputation& nested_computation,
153     absl::Span<llvm::Value* const> operands, llvm::Value* output) {
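  // The emitted function expects one pointer argument per computation
  // parameter, followed by a pointer to the output buffer and, finally, the
  // base of the temporary-buffer allocation.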
154   TF_RET_CHECK(nested_computation.num_parameters() > 0);
155   llvm::Function*& emitted_function =
156       computation_to_ir_function_[&nested_computation];
157   if (emitted_function == nullptr) {
158     IrEmitterNested ir_emitter_nested(hlo_module_config_, nested_computation,
159                                       ir_emitter_context_);
160     TF_RETURN_IF_ERROR(
161         nested_computation.root_instruction()->Accept(&ir_emitter_nested));
162     emitted_function = ir_emitter_nested.GetEmittedFunction();
163   }
164 
165   std::vector<llvm::Value*> arguments(operands.begin(), operands.end());
166   arguments.push_back(output);
167   arguments.push_back(bindings_.GetTempBufferBase());
168   Call(emitted_function, arguments);
169 
170   return Status::OK();
171 }
172 
173 bool IrEmitter::MaybeEmitDirectAtomicOperation(
174     const HloComputation& computation, llvm::Value* output_address,
175     llvm::Value* source_address) {
176   CHECK_EQ(2, computation.num_parameters());
177 
178   if (computation.instruction_count() != 3) {
179     // We special-case only computations with a single computing instruction
180     // for now. Such a computation has exactly three instructions, given its
181     // two parameters.
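    // For example, an f32 add combiner,
    //   { p0 = f32[] parameter(0)
    //     p1 = f32[] parameter(1)
    //     ROOT add = f32[] add(p0, p1) },
    // has exactly three instructions; larger combiners fall back to the
    // CAS-based path (EmitAtomicOperationUsingCAS).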
182     return false;
183   }
184 
185   HloOpcode root_opcode = computation.root_instruction()->opcode();
186   PrimitiveType element_type =
187       computation.root_instruction()->shape().element_type();
188   bool is_atomic_integral = element_type == S32 || element_type == U32 ||
189                             element_type == S64 || element_type == U64;
190   llvm::Value* source = Load(source_address, "source");
191 
192   // kCopy of RHS -> atomic store.
193   if (root_opcode == HloOpcode::kCopy &&
194       (element_type == F32 || is_atomic_integral) &&
195       computation.root_instruction()->operand(0)->opcode() ==
196           HloOpcode::kParameter &&
197       computation.root_instruction()->operand(0)->parameter_number() == 1) {
198     llvm::StoreInst* store = Store(source, output_address);
199     store->setAtomic(llvm::AtomicOrdering::Unordered);
200     // Derive a minimum alignment from the type. The optimizer can increase it
201     // later.
202     store->setAlignment(ShapeUtil::ByteSizeOfPrimitiveType(element_type));
203     return true;
204   }
205 
206   if (root_opcode == HloOpcode::kAdd) {
207     // NVPTX supports atomicAdd on F32 and integer types.
208     if (element_type == F32) {
209       // F32 + F32
210       llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_atomic_load_add_f32,
211                                    {output_address, source},
212                                    {output_address->getType()}, &b_);
213       return true;
214     }
215     if (is_atomic_integral) {
216       // integral + integral
217       AtomicRMW(llvm::AtomicRMWInst::Add, output_address, source,
218                 llvm::AtomicOrdering::SequentiallyConsistent);
219       return true;
220     }
221   }
222 
223   // NVPTX supports atomicMax and atomicMin only on integer types.
224   if (root_opcode == HloOpcode::kMaximum && is_atomic_integral) {
225     // max(integral, integral)
226     auto opcode = primitive_util::IsSignedIntegralType(element_type)
227                       ? llvm::AtomicRMWInst::Max
228                       : llvm::AtomicRMWInst::UMax;
229     AtomicRMW(opcode, output_address, source,
230               llvm::AtomicOrdering::SequentiallyConsistent);
231     return true;
232   }
233 
234   if (root_opcode == HloOpcode::kMinimum && is_atomic_integral) {
235     // min(integral, integral)
236     auto opcode = primitive_util::IsSignedIntegralType(element_type)
237                       ? llvm::AtomicRMWInst::Min
238                       : llvm::AtomicRMWInst::UMin;
239     AtomicRMW(opcode, output_address, source,
240               llvm::AtomicOrdering::SequentiallyConsistent);
241     return true;
242   }
243 
244   return false;
245 }
246 
247 // Implements atomic binary operations using atomic compare-and-swap
248 // (atomicCAS) as follows:
249 //   1. Reads the value from the memory pointed to by output_address and
250 //     records it as old_output.
251 //   2. Uses old_output as one of the source operands to perform the binary
252 //     operation and stores the result in new_output.
253 //   3. Calls atomicCAS which implements compare-and-swap as an atomic
254 //     operation. In particular, atomicCAS reads the value from the memory
255 //     pointed to by output_address, and compares the value with old_output. If
256 //     the two values are equal, new_output is written to the same memory
257 //     location and true is returned to indicate that the atomic operation succeeded.
258 //     Otherwise, the new value read from the memory is returned. In this case,
259 //     the new value is copied to old_output, and steps 2. and 3. are repeated
260 //     until atomicCAS succeeds.
261 //
262 // On Nvidia GPUs, atomicCAS can only operate on 32 bit and 64 bit integers. If
263 // the element type of the binary operation is 32 bits or 64 bits, the integer
264 // type of the same size is used for the atomicCAS operation. On the other hand,
265 // if the element type is smaller than 32 bits, int32 is used for the atomicCAS
266 // operation. In this case, atomicCAS reads and writes 32 bit values from
267 // the memory, which is larger than the memory size required by the original
268 // atomic binary operation. We mask off the last two bits of the output_address
269 // and use the result as an address to read the 32 bit values from the memory.
270 // This avoids out-of-bounds memory accesses provided tensor buffers are 4-byte
271 // aligned and have a size of 4N bytes, an assumption the runtime can guarantee.
272 //
273 // The pseudo code is shown below. Variables *_address are pointers to a memory
274 // region with a size equal to the size of the atomicCAS operation, with the
275 // exception that new_output_address is a pointer to a memory region with a size
276 // equal to the element size of the binary operation.
277 //
278 //   element_size = sizeof(element_type);
279 //   atomic_size = max(32, element_size);
280 //   cas_new_output_address = alloca(atomic_size);
281 //   cas_old_output_address = alloca(atomic_size);
282 //   if (atomic_size != element_size) {
283 //     atomic_address = output_address & ((int64)(-4));
284 //     new_output_address = cas_new_output_address + (output_address & 3);
285 //   } else {
286 //     atomic_address = output_address;
287 //     new_output_address = cas_new_output_address;
288 //   }
289 //
290 //   *cas_old_output_address = *atomic_address;
291 //   do {
292 //     *cas_new_output_address = *cas_old_output_address;
293 //     *new_output_address = operation(*new_output_address, *source_address);
294 //     (*cas_old_output_address, success) =
295 //       atomicCAS(atomic_address, *cas_old_output_address,
296 //       *cas_new_output_address);
297 //   } while (!success);
298 //
299 Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
300                                               llvm::Value* output_address,
301                                               llvm::Value* source_address) {
302   llvm::PointerType* output_address_type =
303       llvm::dyn_cast<llvm::PointerType>(output_address->getType());
304   CHECK_NE(output_address_type, nullptr);
305 
306   // element_type is the data type for the binary operation.
307   llvm::Type* element_type = output_address_type->getPointerElementType();
308   int element_size = llvm_ir::GetSizeInBits(element_type);
309   llvm::Type* element_address_type = element_type->getPointerTo();
310 
311   int atomic_size = (element_size < 32) ? 32 : element_size;
312   llvm::Type* atomic_type = b_.getIntNTy(atomic_size);
313   llvm::Type* atomic_address_type =
314       atomic_type->getPointerTo(output_address_type->getPointerAddressSpace());
315 
316   // cas_old_output_address and cas_new_output_address point to the scratch
317   // memory where we store the old and new values for the repeated atomicCAS
318   // operations.
319   llvm::Value* cas_old_output_address =
320       Alloca(atomic_type, /*ArraySize=*/nullptr, "cas_old_output_address");
321   llvm::Value* cas_new_output_address =
322       Alloca(atomic_type, /*ArraySize=*/nullptr, "cas_new_output_address");
323 
324   // Emit preparation code to the preheader.
325   llvm::BasicBlock* loop_preheader_bb = b_.GetInsertBlock();
326 
327   llvm::Value* atomic_memory_address;
328   // binop_output_address points to the scratch memory that stores the
329   // result of the binary operation.
330   llvm::Value* binop_output_address;
331   if (element_size < 32) {
332     // Assume the element size is an integer number of bytes.
333     CHECK_EQ(element_size % 8, 0);  // element_size is in bits.
334     llvm::Type* address_int_type =
335         module_->getDataLayout().getIntPtrType(output_address_type);
336     atomic_memory_address = PtrToInt(output_address, address_int_type);
337     llvm::Value* mask = llvm::ConstantInt::get(address_int_type, 3);
338     llvm::Value* offset = And(atomic_memory_address, mask);
339     mask = llvm::ConstantInt::get(address_int_type, -4);
340     atomic_memory_address = And(atomic_memory_address, mask);
341     atomic_memory_address =
342         IntToPtr(atomic_memory_address, atomic_address_type);
343     binop_output_address =
344         Add(PtrToInt(cas_new_output_address, address_int_type), offset);
345     binop_output_address = IntToPtr(binop_output_address, element_address_type);
346   } else {
347     atomic_memory_address = BitCast(output_address, atomic_address_type);
348     binop_output_address =
349         BitCast(cas_new_output_address, element_address_type);
350   }
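  // Illustration for a 16-bit element whose byte address satisfies
  // (address & 3) == 2:
  //   offset                = 2            (byte position within the 32-bit word)
  //   atomic_memory_address = address & -4 (the containing 4-byte-aligned word)
  //   binop_output_address  = cas_new_output_address + 2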
351 
352   // Use the value from the memory that atomicCAS operates on to initialize
353   // cas_old_output.
354   llvm::Value* cas_old_output = Load(atomic_memory_address, "cas_old_output");
355   Store(cas_old_output, cas_old_output_address);
356 
357   llvm::BasicBlock* loop_exit_bb = loop_preheader_bb->splitBasicBlock(
358       b_.GetInsertPoint(), "atomic_op_loop_exit");
359   llvm::BasicBlock* loop_body_bb = llvm::BasicBlock::Create(
360       b_.getContext(), "atomic_op_loop_body", b_.GetInsertBlock()->getParent());
361   b_.SetInsertPoint(loop_body_bb);
362   // Change preheader's successor from loop_exit_bb to loop_body_bb.
363   loop_preheader_bb->getTerminator()->setSuccessor(0, loop_body_bb);
364 
365   // Emit the body of the loop that repeatedly invokes atomicCAS.
366   //
367   // Use cas_old_output to initialize cas_new_output.
368   cas_old_output = Load(cas_old_output_address, "cas_old_output");
369   Store(cas_old_output, cas_new_output_address);
370   // Emit code to calculate new_output = operation(old_output, source);
371   TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
372       computation, {binop_output_address, source_address},
373       binop_output_address));
374 
375   llvm::Value* cas_new_output = Load(cas_new_output_address, "cas_new_output");
376 
377   // Emit code to perform the atomicCAS operation
378   // (cas_old_output, success) = atomicCAS(memory_address, cas_old_output,
379   //                                       cas_new_output);
380   llvm::Value* ret_value =
381       AtomicCmpXchg(atomic_memory_address, cas_old_output, cas_new_output,
382                     llvm::AtomicOrdering::SequentiallyConsistent,
383                     llvm::AtomicOrdering::SequentiallyConsistent);
384 
385   // Extract the memory value returned from atomicCAS and store it as
386   // cas_old_output.
387   Store(ExtractValue(ret_value, 0, "cas_old_output"), cas_old_output_address);
388   // Extract the success bit returned from atomicCAS and generate a
389   // conditional branch on the success bit.
390   CondBr(ExtractValue(ret_value, 1, "success"), loop_exit_bb, loop_body_bb);
391 
392   // Set the insertion point to the exit basic block so that the caller of
393   // this method can continue emitting code to the right place.
394   SetToFirstInsertPoint(loop_exit_bb, &b_);
395   return Status::OK();
396 }
397 
398 Status IrEmitter::EmitAtomicOperationForNestedComputation(
399     const HloComputation& computation, llvm::Value* output_address,
400     llvm::Value* source_address) {
401   if (computation.num_parameters() != 2) {
402     // TODO(b/30258929): We only accept binary computations so far.
403     return Unimplemented(
404         "We only support atomic functions with exactly two parameters, but "
405         "computation %s has %d.",
406         computation.name(), computation.num_parameters());
407   }
408 
409   if (MaybeEmitDirectAtomicOperation(computation, output_address,
410                                      source_address)) {
411     return Status::OK();
412   }
413 
414   return EmitAtomicOperationUsingCAS(computation, output_address,
415                                      source_address);
416 }
417 
418 Status IrEmitter::HandleSelect(HloInstruction* select) {
419   auto pred = select->operand(0);
420   TF_RET_CHECK(pred->shape().element_type() == PRED);
421   // We must not call the subclass `DefaultAction` method, lest its
422   // `HandleSelect` call `IrEmitter::HandleSelect` and its `DefaultAction`
423   // assume no handler has already been called.
424   return IrEmitter::DefaultAction(select);
425 }
426 
427 Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) {
428   auto pred = tuple_select->operand(0);
429   auto on_true = tuple_select->operand(1);
430   auto on_false = tuple_select->operand(2);
431   TF_RET_CHECK(pred->shape().element_type() == PRED);
432   TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()));
433   TF_RET_CHECK(tuple_select->shape().IsTuple());
434   llvm_ir::EmitTupleSelect(GetIrArray(*tuple_select, *tuple_select),
435                            GetIrArray(*pred, *tuple_select),
436                            GetBasePointer(*on_true), GetBasePointer(*on_false),
437                            &b_);
438   return Status::OK();
439 }
440 
441 namespace {
442 llvm::Value* Real(llvm::Value* x, llvm::IRBuilder<>* b) {
443   return b->CreateExtractValue(x, {0});
444 }
445 
446 llvm::Value* Imag(llvm::Value* x, llvm::IRBuilder<>* b) {
447   return b->CreateExtractValue(x, {1});
448 }
449 
450 std::pair<llvm::Value*, llvm::Value*> MultiplyComplex(llvm::Value* lhs_value,
451                                                       llvm::Value* rhs_value,
452                                                       llvm::IRBuilder<>* b) {
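  // Computes the complex product (a + bi)(c + di) = (ac - bd) + (ad + bc)i,
  // where struct index 0 holds the real part and index 1 the imaginary part.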
453   llvm::Value* lhs_real = Real(lhs_value, b);
454   llvm::Value* lhs_imag = Imag(lhs_value, b);
455   llvm::Value* rhs_real = Real(rhs_value, b);
456   llvm::Value* rhs_imag = Imag(rhs_value, b);
457   llvm::Value* real_result1 = b->CreateFMul(lhs_real, rhs_real);
458   llvm::Value* real_result2 = b->CreateFMul(lhs_imag, rhs_imag);
459   llvm::Value* real_result = b->CreateFSub(real_result1, real_result2);
460   llvm::Value* imag_result1 = b->CreateFMul(lhs_real, rhs_imag);
461   llvm::Value* imag_result2 = b->CreateFMul(lhs_imag, rhs_real);
462   llvm::Value* imag_result = b->CreateFAdd(imag_result1, imag_result2);
463   return {real_result, imag_result};
464 }
465 }  // namespace
466 
467 Status IrEmitter::HandleDot(HloInstruction* dot) {
468   auto lhs_instruction = dot->operand(0);
469   auto rhs_instruction = dot->operand(1);
470   const llvm_ir::IrArray& target_array = GetIrArray(*dot, *dot);
471   const llvm_ir::IrArray& lhs_array = GetIrArray(*lhs_instruction, *dot);
472   const llvm_ir::IrArray& rhs_array = GetIrArray(*rhs_instruction, *dot);
473 
474   const Shape& lhs_shape = lhs_instruction->shape();
475   const Shape& rhs_shape = rhs_instruction->shape();
476   const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
477   CHECK_EQ(dnums.lhs_batch_dimensions_size(),
478            dnums.rhs_batch_dimensions_size());
479 
480   // TODO(b/110211620): Convert to use i32 index_type when it is possible.
481   llvm::Type* index_type = b_.getInt64Ty();
482   llvm_ir::IrArray::Index element_index(index_type);
483   if (ShapeUtil::IsScalar(lhs_shape) && ShapeUtil::IsScalar(rhs_shape)) {
484     // If the operands are scalar, don't emit any loops.
485     llvm::Value* lhs_value =
486         lhs_array.EmitReadArrayElement(/*index=*/element_index, &b_);
487     llvm::Value* rhs_value =
488         rhs_array.EmitReadArrayElement(/*index=*/element_index, &b_);
489     llvm::Value* result;
490     if (ShapeUtil::ElementIsComplex(lhs_shape)) {
491       auto value = MultiplyComplex(lhs_value, rhs_value, &b_);
492       result = llvm::ConstantAggregateZero::get(lhs_array.GetElementLlvmType());
493       result = InsertValue(result, value.first, {0});
494       result = InsertValue(result, value.second, {1});
495     } else if (ShapeUtil::ElementIsFloating(lhs_shape)) {
496       result = FMul(lhs_value, rhs_value);
497     } else {
498       TF_RET_CHECK(ShapeUtil::ElementIsIntegral(lhs_shape));
499       result = Mul(lhs_value, rhs_value);
500     }
501     target_array.EmitWriteArrayElement(/*index=*/element_index, result, &b_);
502     return Status::OK();
503   }
504 
505   // "Scalar dot non-scalar" or "non-scalar dot scalar" is invalid. See
506   // the semantics of Dot in the XLA documentation for details.
507   TF_RET_CHECK(!ShapeUtil::IsScalar(lhs_shape) &&
508                !ShapeUtil::IsScalar(rhs_shape));
509 
510   const int64 lhs_reduction_dimension = dnums.lhs_contracting_dimensions(0);
511   const int64 rhs_reduction_dimension = dnums.rhs_contracting_dimensions(0);
512 
513   // Check that the batch dims don't cover the reduction dimensions.
514   for (int64 batch_dim : dnums.lhs_batch_dimensions()) {
515     CHECK_NE(lhs_reduction_dimension, batch_dim);
516     CHECK_NE(rhs_reduction_dimension, batch_dim);
517   }
518 
519   // Verify the reduction dimensions in the two operands are the same size.
520   TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) ==
521                rhs_shape.dimensions(rhs_reduction_dimension))
522       << "lhs_shape.dimensions(" << lhs_reduction_dimension
523       << ") = " << lhs_shape.dimensions(lhs_reduction_dimension)
524       << ", and rhs_shape.dimensions(" << rhs_reduction_dimension
525       << ") = " << rhs_shape.dimensions(rhs_reduction_dimension);
526 
527   // Create loop nests which loop through the LHS operand dimensions and the RHS
528   // operand dimensions. The reduction dimensions of the LHS and RHS are handled
529   // in a separate innermost loop which performs the sum of products.
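  // Conceptually, for a non-batched rank-2 dot the emitted nest is:
  //   for (i : lhs non-contracting dimensions)
  //     for (j : rhs non-contracting dimensions) {
  //       accum = 0;
  //       for (k : contracting dimension) accum += lhs[i, k] * rhs[k, j];
  //       target[i, j] = accum;
  //     }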
530   llvm_ir::ForLoopNest loop_nest(IrName(dot), &b_);
531   std::vector<llvm::Value*> lhs_multi_index =
532       loop_nest.EmitOperandArrayLoopNest(
533           lhs_array, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs");
534   std::vector<llvm::Value*> rhs_multi_index =
535       loop_nest.EmitOperandArrayLoopNest(
536           rhs_array, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs");
537 
538   // We don't have to iterate over the batch dimensions in both arrays, so we
539   // simplify the loop nest of the rhs.
540   for (int i = 0; i != dnums.lhs_batch_dimensions_size(); ++i) {
541     DCHECK(absl::c_linear_search(dnums.lhs_batch_dimensions(), i));
542     rhs_multi_index[i] = lhs_multi_index[i];
543   }
544 
545   // Create the reduction loop which does the sum of products reduction.
546   std::unique_ptr<llvm_ir::ForLoop> reduction_loop = loop_nest.AddLoop(
547       /*start_index=*/0,
548       /*end_index=*/lhs_shape.dimensions(lhs_reduction_dimension),
549       /*suffix=*/"reduction");
550 
551   // The final entry in the rhs and lhs indexes is the indvar of the reduction
552   // loop.
553   lhs_multi_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue();
554   rhs_multi_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue();
555 
556   // For computing the sum of products we alloca a single location to store the
557   // dot product result as we accumulate it within the reduction loop. After the
558   // reduction loop we load the result and store it into the output array.
559   llvm::Type* accum_type = target_array.GetElementLlvmType();
560   llvm::Value* accum_address = llvm_ir::EmitAllocaAtFunctionEntry(
561       accum_type,       // The pointee type of the alloca instruction.
562       "accum_address",  // The name of the alloca instruction.
563       &b_);
564 
565   // Initialize the accumulator in the preheader to zero.
566   new llvm::StoreInst(
567       llvm::Constant::getNullValue(lhs_array.GetElementLlvmType()),  // init 0
568       accum_address,  // The address.
569       reduction_loop->GetPreheaderBasicBlock()
570           ->getTerminator());  // The instruction this store is inserted before.
571 
572   // Emit the body of the reduction loop:
573   //   accum = *accum_address
574   //   updated_accum = accum + lhs_element * rhs_element
575   //   *accum_address = updated_accum
576   TF_RET_CHECK(!reduction_loop->GetBodyBasicBlock()->empty());
577   b_.SetInsertPoint(
578       &*reduction_loop->GetBodyBasicBlock()->getFirstInsertionPt());
579   llvm_ir::IrArray::Index lhs_index(lhs_multi_index, lhs_array.GetShape(),
580                                     b_.getInt64Ty());
581   llvm::Value* lhs_element = lhs_array.EmitReadArrayElement(lhs_index, &b_);
582   llvm_ir::IrArray::Index rhs_index(rhs_multi_index, rhs_array.GetShape(),
583                                     b_.getInt64Ty());
584   llvm::Value* rhs_element = rhs_array.EmitReadArrayElement(rhs_index, &b_);
585   llvm::Value* accum = Load(accum_address);
586   llvm::Value* updated_accum;
587   if (ShapeUtil::ElementIsComplex(lhs_shape)) {
588     auto value = MultiplyComplex(lhs_element, rhs_element, &b_);
589     llvm::Value* accum_real = Real(accum, &b_);
590     llvm::Value* real_sum = FAdd(accum_real, value.first);
591     updated_accum = InsertValue(accum, real_sum, {0});
592     llvm::Value* accum_imag = Imag(accum, &b_);
593     llvm::Value* imag_sum = FAdd(accum_imag, value.second);
594     updated_accum = InsertValue(updated_accum, imag_sum, {1});
595   } else if (ShapeUtil::ElementIsFloating(lhs_shape)) {
596     llvm::Value* product = FMul(lhs_element, rhs_element);
597     updated_accum = FAdd(accum, product);
598   } else {
599     TF_RET_CHECK(ShapeUtil::ElementIsIntegral(lhs_shape));
600     llvm::Value* product = Mul(lhs_element, rhs_element);
601     updated_accum = Add(accum, product);
602   }
603   Store(updated_accum, accum_address);
604 
605   // After the reduction loop exits, store the accumulator into the target
606   // address. The index into the target address is the concatenation of the lhs
607   // and rhs indexes with the reduction dimensions removed. The terms from the
608   // lhs index are the leading dimensions in the index, so we add them first.
609   std::vector<llvm::Value*> target_multi_index;
610   for (size_t dimension = 0; dimension < lhs_index.size(); ++dimension) {
611     if (dimension != lhs_reduction_dimension) {
612       target_multi_index.push_back(lhs_index[dimension]);
613     }
614   }
615   // Skip over the batch dimensions to not have them in the index twice.
616   for (size_t dimension = dnums.lhs_batch_dimensions_size();
617        dimension < rhs_index.size(); ++dimension) {
618     if (dimension != rhs_reduction_dimension) {
619       target_multi_index.push_back(rhs_index[dimension]);
620     }
621   }
622   SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), &b_);
623   llvm_ir::IrArray::Index target_index(target_multi_index,
624                                        target_array.GetShape(), index_type);
625   target_array.EmitWriteArrayElement(
626       target_index,
627       Load(accum_address),  // The value written to the target array.
628       &b_);
629 
630   // Set the IR builder insert point to the exit basic block of the outer most
631   // loop. This ensures later instructions are inserted after this loop nest.
632   b_.SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
633 
634   return Status::OK();
635 }
636 
637 Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
638   if (ShapeUtil::IsZeroElementArray(convolution->shape())) {
639     // Emit no code for an empty output.
640     return Status::OK();
641   }
642   // TODO(b/31409998): Support convolution with dilation.
643   return Unimplemented(
644       "Hit a case for convolution that is not implemented on GPU.");
645 }
646 
647 Status IrEmitter::HandleFft(HloInstruction* fft) {
648   if (ShapeUtil::IsZeroElementArray(fft->shape())) {
649     // Emit no code for an empty output.
650     return Status::OK();
651   }
652   return Unimplemented("Hit a case for fft that is not implemented on GPU.");
653 }
654 
655 Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
656   // TODO(b/33011107): Support cross replica sum on GPU.
657   return Unimplemented("AllReduce is not implemented on GPU.");
658 }
659 
660 Status IrEmitter::HandleParameter(HloInstruction* parameter) {
661   return Status::OK();
662 }
663 
664 Status IrEmitter::HandleReduce(HloInstruction* reduce) {
665   // TODO(b/118332391): Support variadic reduce.
666   if (!reduce->shape().IsArray()) {
667     return Unimplemented("Variadic reduce is not supported on GPU");
668   }
669   auto arg = reduce->operand(0);
670   auto init_value = reduce->operand(1);
671   absl::Span<const int64> dimensions(reduce->dimensions());
672   HloComputation* function = reduce->to_apply();
673   return EmitTargetElementLoop(
674       *reduce,
675       [=](const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
676         // Initialize an accumulator with init_value.
677         llvm::AllocaInst* accumulator_addr =
678             Alloca(llvm_ir::PrimitiveTypeToIrType(
679                 reduce->shape().element_type(), module_));
680         Store(Load(GetBasePointer(*init_value)), accumulator_addr);
681 
682         // The enclosing loops go over all the target elements. Now we have to
683         // compute the actual target element. For this, we build a new loop nest
684         // to iterate over all the reduction dimensions in the argument.
685         // AddLoopsForShapeOnDimensions will return an Index where induction
686         // Value*s are placed for each dimension in dimensions, and all the rest
687         // are nullptrs.
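        // Conceptually, for each target element `index` this emits:
        //   accum = *init_value;
        //   for (each combination r of the reduced coordinates)
        //     accum = function(accum, arg[merge(index, r)]);
        //   return accum;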
688         llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_);
689         std::vector<llvm::Value*> input_multi_index =
690             loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
691                                                "reduction_dim");
692 
693         SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
694 
695         // Build a full index for the input argument, using input_multi_index
696         // as the base. In input_multi_index only the reduction dimensions are
697         // filled in. We fill in the rest of the dimensions with induction
698         // Value*s taken from 'index' which iterates over the target array.
699         // See the high-level description in the XLA documentation for details.
700         llvm_ir::IrArray::Index::const_iterator it = index.begin();
701 
702         for (auto& i : input_multi_index) {
703           if (i == nullptr) {
704             i = *it++;
705           }
706         }
707         CHECK(index.end() == it);
708 
709         // Apply the reduction function to the loaded value.
710         llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
711                                             b_.getInt64Ty());
712         llvm::Value* input_address =
713             GetIrArray(*arg, *reduce).EmitArrayElementAddress(input_index, &b_);
714         TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
715             *function, {accumulator_addr, input_address}, accumulator_addr));
716 
717         SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
718         return Load(accumulator_addr);
719       });
720 }
721 
722 Status IrEmitter::HandleFusion(HloInstruction* fusion) {
723   // kFusion for library calls should be handled by
724   // IrEmitterUnnested::HandleFusion.
725   CHECK_EQ(HloInstruction::FusionKind::kLoop, fusion->fusion_kind());
726   GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
727                                           GetNestedComputer());
728   FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion),
729                                &elemental_emitter);
730   TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));
731 
732   return EmitTargetElementLoop(*fusion, fused_emitter.GetRootGenerator());
733 }
734 
735 Status IrEmitter::HandleCall(HloInstruction* call) {
736   std::vector<llvm::Value*> operand_addresses;
737   for (HloInstruction* operand : call->operands()) {
738     operand_addresses.push_back(GetBasePointer(*operand));
739   }
740   return EmitCallToNestedComputation(*call->to_apply(), operand_addresses,
741                                      GetBasePointer(*call));
742 }
743 
744 Status IrEmitter::HandleCustomCall(HloInstruction*) {
745   return Unimplemented("custom-call");
746 }
747 
748 Status IrEmitter::HandleInfeed(HloInstruction*) {
749   // TODO(b/30467474): Implement infeed on GPU.
750   return Unimplemented("Infeed is not supported on GPU.");
751 }
752 
753 Status IrEmitter::HandleOutfeed(HloInstruction*) {
754   // TODO(b/34359662): Implement outfeed on GPU.
755   return Unimplemented("Outfeed is not supported on GPU.");
756 }
757 
758 Status IrEmitter::HandleBatchNormInference(HloInstruction*) {
759   return Unimplemented(
760       "The GPU backend does not implement BatchNormInference directly.  It "
761       "should be lowered before IR emission to HLO-soup using "
762       "BatchNormRewriter or to a cudnn CustomCall using "
763       "CudnnBatchNormRewriter.");
764 }
765 
766 Status IrEmitter::HandleBatchNormTraining(HloInstruction*) {
767   return Unimplemented(
768       "The GPU backend does not implement BatchNormTraining directly.  It "
769       "should be lowered before IR emission to HLO-soup using "
770       "BatchNormRewriter or to a cudnn CustomCall using "
771       "CudnnBatchNormRewriter.");
772 }
773 
774 Status IrEmitter::HandleBatchNormGrad(HloInstruction*) {
775   return Unimplemented(
776       "The GPU backend does not implement BatchNormGrad directly.  It should "
777       "be lowered before IR emission to HLO-soup (using BatchNormRewriter) or "
778       "to a cudnn CustomCall using CudnnBatchNormRewriter.");
779 }
780 
781 StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement(
782     const HloComputation& computation,
783     absl::Span<llvm::Value* const> parameter_elements) {
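  // Spills each scalar parameter into its own stack slot, calls the nested
  // computation on those slots plus a freshly allocated return slot, and loads
  // the scalar result back for the caller.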
784   llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
785       llvm_ir::PrimitiveTypeToIrType(
786           computation.root_instruction()->shape().element_type(), module_),
787       "return_buffer", &b_);
788   std::vector<llvm::Value*> parameter_buffers;
789   for (llvm::Value* parameter_element : parameter_elements) {
790     parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
791         parameter_element->getType(), "parameter_buffer", &b_));
792     Store(parameter_element, parameter_buffers.back());
793   }
794   TF_RETURN_IF_ERROR(EmitCallToNestedComputation(computation, parameter_buffers,
795                                                  return_buffer));
796   return Load(return_buffer);
797 }
798 
799 std::vector<llvm_ir::IrArray> IrEmitter::ConstructIrArrayForOutputs(
800     const HloInstruction& hlo) {
801   std::vector<llvm_ir::IrArray> output_arrays;
802   if (hlo.shape().IsTuple()) {
803     int64 num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
804     output_arrays.reserve(num_outputs);
805     for (int64 i = 0; i < num_outputs; ++i) {
806       output_arrays.push_back(GetIrArray(hlo, hlo, {i}));
807     }
808   } else {
809     output_arrays.push_back(GetIrArray(hlo, hlo));
810   }
811   return output_arrays;
812 }
813 
814 }  // namespace gpu
815 }  // namespace xla
816