/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"

#include <string>
#include <unordered_map>
#include <utility>

#include "tensorflow/core/platform/logging.h"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/algorithm/container.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/errors.h"

// Convenience function that casts the provided llvm::Value* to the default
// address space using IRBuilder. This is useful in particular when generating
// IR for the AMDGPU target, as its kernel variables are in address space 5
// instead of the default address space.
static llvm::Value* AddrCastToDefault(llvm::Value* arg, llvm::IRBuilder<>& b) {
  llvm::Type* arg_type = arg->getType();
  CHECK(arg_type->isPointerTy());
  if (arg_type->getPointerAddressSpace() != 0) {
    llvm::Type* generic_arg_type =
        arg_type->getPointerElementType()->getPointerTo(0);
    llvm::Value* addrspacecast_arg =
        b.CreateAddrSpaceCast(arg, generic_arg_type);
    return addrspacecast_arg;
  }
  return arg;
}

namespace xla {

using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;

namespace gpu {

IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config,
                     IrEmitterContext* ir_emitter_context, bool is_nested)
    : ir_emitter_context_(ir_emitter_context),
      module_(ir_emitter_context->llvm_module()),
      b_(module_->getContext()),
      bindings_(ir_emitter_context->hlo_module(),
                &ir_emitter_context->buffer_assignment(), &b_, module_,
                is_nested),
      hlo_module_config_(hlo_module_config) {
}

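// Fallback handler: emits an elementwise loop whose body computes `hlo` from
// generators that read the corresponding elements of its operands.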
Status IrEmitter::DefaultAction(HloInstruction* hlo) {
  ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
  for (const HloInstruction* operand : hlo->operands()) {
    operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
      return GetIrArray(*operand, *hlo)
          .EmitReadArrayElement(index, &b_, operand->name());
    };
  }
  return EmitTargetElementLoop(
      *hlo, GpuElementalIrEmitter(hlo_module_config_, module_, &b_,
                                  GetNestedComputer())
                .MakeElementGenerator(hlo, operand_to_generator));
}

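// Emits one LLVM global per kConstant instruction in `computation` and records
// it in the IrEmitterContext so that GpuExecutable can later find the constant
// by its symbol name (and, if requested, by its buffer allocation index).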
Status IrEmitter::EmitConstants(const HloComputation& computation,
                                bool lookup_indices) {
  for (HloInstruction* instr : computation.instructions()) {
    if (instr->opcode() != HloOpcode::kConstant) {
      continue;
    }
    Literal& literal = *Cast<HloConstantInstruction>(instr)->mutable_literal();
    const bool should_emit_initializer = ShouldEmitLiteralInLlvmIr(literal);
    llvm::ArrayType* global_type =
        llvm::ArrayType::get(b_.getInt8Ty(), literal.size_bytes());
    llvm::Constant* initializer =
        should_emit_initializer
            ? llvm_ir::ConvertLiteralToIrConstant(literal, module_)
            : llvm::ConstantAggregateZero::get(global_type);
    if (should_emit_initializer) {
      VLOG(3) << "Emitted initializer for constant with shape "
              << ShapeUtil::HumanString(literal.shape());
    }

    // These globals will be looked up by name by GpuExecutable, so we need to
    // give them external linkage.  Not all of their uses are visible in the
    // LLVM IR (e.g. TupleThunk), so we can't give them a linkage that merely
    // preserves their names (like available_externally); we also need to
    // ensure that they stick around even if they're "unused".
    //
    // We may have to be more clever here in the future if we notice that we're
    // keeping around too many globals because of their linkage.
    unsigned global_address_space = llvm_ir::GetGlobalMemoryAddressSpace(
        *ir_emitter_context_->llvm_module());

    std::string global_name = llvm_ir::ConstantHloToGlobalName(*instr);

    llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
        global_type, /*isConstant=*/should_emit_initializer,
        llvm::GlobalValue::ExternalLinkage,
        /*Initializer=*/initializer, global_name,
        /*TLMode=*/llvm::GlobalValue::NotThreadLocal,
        /*AddressSpace=*/global_address_space,
        /*isExternallyInitialized=*/false);
    global_for_const->setAlignment(llvm::Align(kConstantBufferAlignBytes));
    ir_emitter_context_->llvm_module()->getGlobalList().push_back(
        global_for_const);

    GpuExecutable::ConstantInfo info;
    info.symbol_name = global_name;

    if (!should_emit_initializer) {
      auto base = static_cast<const uint8*>(literal.untyped_data());
      info.content.assign(base, base + literal.size_bytes());
    }
    if (lookup_indices) {
      auto maybe_slice =
          ir_emitter_context_->buffer_assignment().GetUniqueSlice(instr, {});
      if (maybe_slice.ok()) {
        info.allocation_index = maybe_slice.ValueOrDie().index();
      }
    }
    ir_emitter_context_->constants().push_back(std::move(info));
  }
  return Status::OK();
}

Status IrEmitter::HandleConstant(HloInstruction* constant) {
  return Status::OK();
}

Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) {
  VLOG(2) << "HandleAddDependency: " << add_dependency->ToString();
  const HloInstruction* operand = add_dependency->operand(0);
  // AddDependency is a no-op, but we still want to bind it to an llvm::Value
  // sometimes, e.g. when its operand is a constant or a bitcast of a
  // constant.
  if (bindings_.BoundToIrValue(*operand)) {
    bindings_.BindHloToIrValue(*add_dependency, GetBasePointer(*operand));
  }
  return Status::OK();
}

Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
  auto operand = get_tuple_element->operand(0);
  CHECK(bindings_.BoundToIrValue(*operand));
  bindings_.BindHloToIrValue(
      *get_tuple_element,
      llvm_ir::EmitGetTupleElement(
          get_tuple_element->shape(), get_tuple_element->tuple_index(),
          // TODO(b/26344050): tighten the alignment here
          // based on the real element type.
          /*alignment=*/1, GetBasePointer(*operand), &b_));
  return Status::OK();
}

Status IrEmitter::HandleSend(HloInstruction*) {
  return Unimplemented("Send is not implemented on GPU");
}

Status IrEmitter::HandleSendDone(HloInstruction*) {
  return Unimplemented("Send-Done is not implemented on GPU");
}

Status IrEmitter::HandleRecv(HloInstruction*) {
  return Unimplemented("Recv is not implemented on GPU");
}

Status IrEmitter::HandleRecvDone(HloInstruction*) {
  return Unimplemented("Recv-done is not implemented on GPU");
}

Status IrEmitter::HandleScatter(HloInstruction*) {
  return Unimplemented("Scatter is not implemented on GPUs.");
}

Status IrEmitter::HandleTuple(HloInstruction* tuple) {
  std::vector<llvm::Value*> base_ptrs;
  for (const HloInstruction* operand : tuple->operands()) {
    base_ptrs.push_back(GetBasePointer(*operand));
  }
  llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &b_);
  return Status::OK();
}

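// Emits (and caches) an LLVM function for `nested_computation`, then calls it
// with the given operand pointers plus `output` as the destination buffer.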
Status IrEmitter::EmitCallToNestedComputation(
    const HloComputation& nested_computation,
    absl::Span<llvm::Value* const> operands, llvm::Value* output) {
  TF_RET_CHECK(nested_computation.num_parameters() > 0);
  llvm::Function*& emitted_function =
      computation_to_ir_function_[&nested_computation];
  if (emitted_function == nullptr) {
    TF_ASSIGN_OR_RETURN(
        auto ir_emitter_nested,
        IrEmitterNested::Create(hlo_module_config_, nested_computation,
                                ir_emitter_context_));
    TF_RETURN_IF_ERROR(ir_emitter_nested->CodegenNestedComputation());
    emitted_function = ir_emitter_nested->GetEmittedFunction();
  }

  // Operands are already in the default address space for non-AMDGPU targets.
  // For the AMDGPU target, however, alloca variables live in address space 5
  // and must be addrspacecast to address space 0 before the call.
  std::vector<llvm::Value*> arguments;
  absl::c_transform(
      operands, std::back_inserter(arguments),
      [this](llvm::Value* arg) { return AddrCastToDefault(arg, b_); });

  llvm::Value* casted_output = AddrCastToDefault(output, b_);
  arguments.push_back(casted_output);

  Call(emitted_function, arguments);

  return Status::OK();
}

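// Tries to lower the two-parameter `computation` to a single hardware atomic
// (atomic store, atomic add, or atomic max/min).  Returns true on success;
// otherwise the caller falls back to the compare-and-swap loop below.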
bool IrEmitter::MaybeEmitDirectAtomicOperation(
    const HloComputation& computation, llvm::Value* output_address,
    llvm::Value* source_address) {
  CHECK_EQ(2, computation.num_parameters());

  HloOpcode root_opcode = computation.root_instruction()->opcode();
  PrimitiveType element_type =
      computation.root_instruction()->shape().element_type();
  bool is_atomic_integral = element_type == S32 || element_type == U32 ||
                            element_type == S64 || element_type == U64;
  llvm::Value* source = Load(source_address, "source");

  // Just passing along RHS -> atomic store.
  if (computation.instruction_count() == 2 &&
      root_opcode == HloOpcode::kParameter &&
      (element_type == F32 || is_atomic_integral) &&
      computation.root_instruction()->parameter_number() == 1) {
    llvm::StoreInst* store = Store(source, output_address);
    store->setAtomic(llvm::AtomicOrdering::Unordered);
    // Derive a minimum alignment from the type. The optimizer can increase it
    // later.
    store->setAlignment(
        llvm::Align(ShapeUtil::ByteSizeOfPrimitiveType(element_type)));
    return true;
  }

  if (computation.instruction_count() != 3) {
    // We special-case only computations with a single computing instruction
    // for now.  Such a computation has exactly three instructions, given that
    // it has two parameters.
    return false;
  }

  if (root_opcode == HloOpcode::kAdd) {
    llvm::Triple target_triple = llvm::Triple(module_->getTargetTriple());
    // NVPTX supports atomicAdd on F32 and integer types.
    if (target_triple.isNVPTX()) {
      // "atom.add.f64 requires sm_60 or higher."
      // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom
      bool f64_atomic_add_supported =
          ir_emitter_context_->cuda_compute_capability()->cc_major >= 6;

      bool atomic_add_supported =
          element_type == F32 ||
          (f64_atomic_add_supported && element_type == F64);
      if (atomic_add_supported) {
        AtomicRMW(llvm::AtomicRMWInst::FAdd, output_address, source,
                  llvm::AtomicOrdering::SequentiallyConsistent);
        return true;
      }
    }

    if (is_atomic_integral) {
      // integral + integral
      AtomicRMW(llvm::AtomicRMWInst::Add, output_address, source,
                llvm::AtomicOrdering::SequentiallyConsistent);
      return true;
    }
  }

  // NVPTX supports atomicMax and atomicMin only on integer types.
  if (root_opcode == HloOpcode::kMaximum && is_atomic_integral) {
    // max(integral, integral)
    auto opcode = primitive_util::IsSignedIntegralType(element_type)
                      ? llvm::AtomicRMWInst::Max
                      : llvm::AtomicRMWInst::UMax;
    AtomicRMW(opcode, output_address, source,
              llvm::AtomicOrdering::SequentiallyConsistent);
    return true;
  }

  if (root_opcode == HloOpcode::kMinimum && is_atomic_integral) {
    // min(integral, integral)
    auto opcode = primitive_util::IsSignedIntegralType(element_type)
                      ? llvm::AtomicRMWInst::Min
                      : llvm::AtomicRMWInst::UMin;
    AtomicRMW(opcode, output_address, source,
              llvm::AtomicOrdering::SequentiallyConsistent);
    return true;
  }

  return false;
}

// Implements atomic binary operations using atomic compare-and-swap
// (atomicCAS) as follows:
//   1. Reads the value from the memory pointed to by output_address and
//     records it as old_output.
//   2. Uses old_output as one of the source operands to perform the binary
//     operation and stores the result in new_output.
//   3. Calls atomicCAS which implements compare-and-swap as an atomic
//     operation. In particular, atomicCAS reads the value from the memory
//     pointed to by output_address, and compares the value with old_output. If
//     the two values are equal, new_output is written to the same memory
//     location and true is returned to indicate that the atomic operation
//     succeeded. Otherwise, the new value read from the memory is returned. In
//     this case, the new value is copied to old_output, and steps 2. and 3.
//     are repeated until atomicCAS succeeds.
//
// On Nvidia GPUs, atomicCAS can only operate on 32 bit and 64 bit integers. If
// the element type of the binary operation is 32 bits or 64 bits, the integer
// type of the same size is used for the atomicCAS operation. On the other
// hand, if the element type is smaller than 32 bits, int32 is used for the
// atomicCAS operation. In this case, atomicCAS reads and writes 32 bit values
// from the memory, which is larger than the memory size required by the
// original atomic binary operation. We mask off the last two bits of the
// output_address and use the result as an address to read the 32 bit values
// from the memory. This avoids out-of-bound memory accesses if tensor buffers
// are 4 byte aligned and have a size of 4N, an assumption that the runtime can
// guarantee.
//
// The pseudo code is shown below. Variables *_address are pointers to a memory
// region with a size equal to the size of the atomicCAS operation, with the
// exception that new_output_address is a pointer to a memory region with a
// size equal to the element size of the binary operation.
//
//   element_size = sizeof(element_type);
//   atomic_size = max(32, element_size);
//   cas_new_output_address = alloca(atomic_size);
//   cas_old_output_address = alloca(atomic_size);
//   if (atomic_size != element_size) {
//     atomic_address = output_address & ((int64)(-4));
//     new_output_address = cas_new_output_address + (output_address & 3);
//   } else {
//     atomic_address = output_address;
//     new_output_address = cas_new_output_address;
//   }
//
//   *cas_old_output_address = *atomic_address;
//   do {
//     *cas_new_output_address = *cas_old_output_address;
//     *new_output_address = operation(*new_output_address, *source_address);
//     (*cas_old_output_address, success) =
//       atomicCAS(atomic_address, *cas_old_output_address,
//       *cas_new_output_address);
//   } while (!success);
//
Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
                                              llvm::Value* output_address,
                                              llvm::Value* source_address) {
  llvm::PointerType* output_address_type =
      llvm::dyn_cast<llvm::PointerType>(output_address->getType());
  CHECK_NE(output_address_type, nullptr);

  // element_type is the data type for the binary operation.
  llvm::Type* element_type = output_address_type->getPointerElementType();
  int element_size = llvm_ir::GetSizeInBits(element_type);

  int atomic_size = (element_size < 32) ? 32 : element_size;
  llvm::Type* atomic_type = b_.getIntNTy(atomic_size);
  llvm::Type* atomic_address_type =
      atomic_type->getPointerTo(output_address_type->getPointerAddressSpace());

  // cas_old_output_address and cas_new_output_address point to the scratch
  // memory where we store the old and new values for the repeated atomicCAS
  // operations.
  llvm::Value* cas_old_output_address = llvm_ir::EmitAllocaAtFunctionEntry(
      atomic_type, "cas_old_output_address", &b_);
  llvm::Value* cas_new_output_address = llvm_ir::EmitAllocaAtFunctionEntry(
      atomic_type, "cas_new_output_address", &b_);

  // Emit preparation code to the preheader.
  llvm::BasicBlock* loop_preheader_bb = b_.GetInsertBlock();

  llvm::Value* atomic_memory_address;
  // binop_output_address points to the scratch memory that stores the
  // result of the binary operation.
  llvm::Value* binop_output_address;
  if (element_size < 32) {
    // Assume the element size is an integer number of bytes.
    CHECK_EQ((element_size % sizeof(char)), 0);
    llvm::Type* address_int_type =
        module_->getDataLayout().getIntPtrType(output_address_type);
    atomic_memory_address = PtrToInt(output_address, address_int_type);
    llvm::Value* mask = llvm::ConstantInt::get(address_int_type, 3);
    llvm::Value* offset = And(atomic_memory_address, mask);
    mask = llvm::ConstantInt::get(address_int_type, -4);
    atomic_memory_address = And(atomic_memory_address, mask);
    atomic_memory_address =
        IntToPtr(atomic_memory_address, atomic_address_type);
    binop_output_address =
        Add(PtrToInt(cas_new_output_address, address_int_type), offset);
    binop_output_address = IntToPtr(
        binop_output_address,
        llvm::PointerType::get(
            element_type,
            cas_new_output_address->getType()->getPointerAddressSpace()));
  } else {
    atomic_memory_address = b_.CreatePointerBitCastOrAddrSpaceCast(
        output_address, atomic_address_type);
    binop_output_address = b_.CreatePointerBitCastOrAddrSpaceCast(
        cas_new_output_address,
        llvm::PointerType::get(
            element_type,
            cas_new_output_address->getType()->getPointerAddressSpace()));
  }

  // Use the value from the memory that atomicCAS operates on to initialize
  // cas_old_output.
  llvm::Value* cas_old_output = Load(atomic_memory_address, "cas_old_output");
  Store(cas_old_output, cas_old_output_address);

  llvm::BasicBlock* loop_exit_bb = loop_preheader_bb->splitBasicBlock(
      b_.GetInsertPoint(), "atomic_op_loop_exit");
  llvm::BasicBlock* loop_body_bb = llvm::BasicBlock::Create(
      b_.getContext(), "atomic_op_loop_body", b_.GetInsertBlock()->getParent());
  b_.SetInsertPoint(loop_body_bb);
  // Change preheader's successor from loop_exit_bb to loop_body_bb.
  loop_preheader_bb->getTerminator()->setSuccessor(0, loop_body_bb);

  // Emit the body of the loop that repeatedly invokes atomicCAS.
  //
  // Use cas_old_output to initialize cas_new_output.
  cas_old_output = Load(cas_old_output_address, "cas_old_output");
  Store(cas_old_output, cas_new_output_address);
  // Emits code to calculate new_output = operation(old_output, source);
  TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
      computation, {binop_output_address, source_address},
      binop_output_address));

  llvm::Value* cas_new_output = Load(cas_new_output_address, "cas_new_output");

  // Emit code to perform the atomicCAS operation
  // (cas_old_output, success) = atomicCAS(memory_address, cas_old_output,
  //                                       cas_new_output);
  llvm::Value* ret_value =
      AtomicCmpXchg(atomic_memory_address, cas_old_output, cas_new_output,
                    llvm::AtomicOrdering::SequentiallyConsistent,
                    llvm::AtomicOrdering::SequentiallyConsistent);

  // Extract the memory value returned from atomicCAS and store it as
  // cas_old_output.
  Store(ExtractValue(ret_value, 0, "cas_old_output"), cas_old_output_address);
  // Extract the success bit returned from atomicCAS and generate a
  // conditional branch on the success bit.
  CondBr(ExtractValue(ret_value, 1, "success"), loop_exit_bb, loop_body_bb);

  // Set the insertion point to the exit basic block so that the caller of
  // this method can continue emitting code to the right place.
  SetToFirstInsertPoint(loop_exit_bb, &b_);
  return Status::OK();
}

Status IrEmitter::EmitAtomicOperationForNestedComputation(
    const HloComputation& computation, llvm::Value* output_address,
    llvm::Value* source_address) {
  if (computation.num_parameters() != 2) {
    // TODO(b/30258929): We only accept binary computations so far.
    return Unimplemented(
        "We only support atomic functions with exactly two parameters, but "
        "computation %s has %d.",
        computation.name(), computation.num_parameters());
  }

  if (MaybeEmitDirectAtomicOperation(computation, output_address,
                                     source_address)) {
    return Status::OK();
  }

  return EmitAtomicOperationUsingCAS(computation, output_address,
                                     source_address);
}

Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) {
  return InternalError(
      "Dynamic selection of tuples is not supported. Please file a bug against "
      "XLA/GPU if you need it");
}

namespace {
llvm::Value* Real(llvm::Value* x, llvm::IRBuilder<>* b) {
  return b->CreateExtractValue(x, {0});
}

llvm::Value* Imag(llvm::Value* x, llvm::IRBuilder<>* b) {
  return b->CreateExtractValue(x, {1});
}

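// Multiplies two complex values represented as {real, imag} pairs:
//   (a + bi) * (c + di) = (ac - bd) + (ad + bc)i.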
std::pair<llvm::Value*, llvm::Value*> MultiplyComplex(llvm::Value* lhs_value,
                                                      llvm::Value* rhs_value,
                                                      llvm::IRBuilder<>* b) {
  llvm::Value* lhs_real = Real(lhs_value, b);
  llvm::Value* lhs_imag = Imag(lhs_value, b);
  llvm::Value* rhs_real = Real(rhs_value, b);
  llvm::Value* rhs_imag = Imag(rhs_value, b);
  llvm::Value* real_result1 = b->CreateFMul(lhs_real, rhs_real);
  llvm::Value* real_result2 = b->CreateFMul(lhs_imag, rhs_imag);
  llvm::Value* real_result = b->CreateFSub(real_result1, real_result2);
  llvm::Value* imag_result1 = b->CreateFMul(lhs_real, rhs_imag);
  llvm::Value* imag_result2 = b->CreateFMul(lhs_imag, rhs_real);
  llvm::Value* imag_result = b->CreateFAdd(imag_result1, imag_result2);
  return {real_result, imag_result};
}
}  // namespace

Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
  if (ShapeUtil::IsZeroElementArray(convolution->shape())) {
    // Emit no code for an empty output.
    return Status::OK();
  }
  // TODO(b/31409998): Support convolution with dilation.
  return Unimplemented(
      "Hit a case for convolution that is not implemented on GPU.");
}

Status IrEmitter::HandleFft(HloInstruction* fft) {
  if (ShapeUtil::IsZeroElementArray(fft->shape())) {
    // Emit no code for an empty output.
    return Status::OK();
  }
  return Unimplemented("Hit a case for fft that is not implemented on GPU.");
}

Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
  return Unimplemented(
      "AllReduce cannot be nested inside of fusion, map, etc.");
}

Status IrEmitter::HandleParameter(HloInstruction* parameter) {
  return Status::OK();
}

Status IrEmitter::HandleFusion(HloInstruction* fusion) {
  // kFusion for library calls should be handled by
  // IrEmitterUnnested::HandleFusion.
  CHECK_EQ(HloInstruction::FusionKind::kLoop, fusion->fusion_kind());
  GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
                                          GetNestedComputer());
  FusedIrEmitter fused_emitter(&elemental_emitter);
  BindFusionArguments(fusion, &fused_emitter);
  TF_ASSIGN_OR_RETURN(auto generator, fused_emitter.GetGenerator(
                                          fusion->fused_expression_root()));
  return EmitTargetElementLoop(*fusion, generator);
}

Status IrEmitter::HandleCall(HloInstruction* call) {
  std::vector<llvm::Value*> operand_addresses;
  for (HloInstruction* operand : call->operands()) {
    operand_addresses.push_back(GetBasePointer(*operand));
  }
  return EmitCallToNestedComputation(*call->to_apply(), operand_addresses,
                                     GetBasePointer(*call));
}

Status IrEmitter::HandleCustomCall(HloInstruction*) {
  return Unimplemented("custom-call");
}

Status IrEmitter::HandleInfeed(HloInstruction*) {
  // TODO(b/30467474): Implement infeed on GPU.
  return Unimplemented("Infeed is not supported on GPU.");
}

Status IrEmitter::HandleOutfeed(HloInstruction*) {
  // TODO(b/34359662): Implement outfeed on GPU.
  return Unimplemented("Outfeed is not supported on GPU.");
}

Status IrEmitter::HandleBatchNormInference(HloInstruction*) {
  return Unimplemented(
      "The GPU backend does not implement BatchNormInference directly.  It "
      "should be lowered before IR emission to HLO-soup using "
      "BatchNormRewriter or to a cudnn CustomCall using "
      "CudnnBatchNormRewriter.");
}

Status IrEmitter::HandleBatchNormTraining(HloInstruction*) {
  return Unimplemented(
      "The GPU backend does not implement BatchNormTraining directly.  It "
      "should be lowered before IR emission to HLO-soup using "
      "BatchNormRewriter or to a cudnn CustomCall using "
      "CudnnBatchNormRewriter.");
}

Status IrEmitter::HandleBatchNormGrad(HloInstruction*) {
  return Unimplemented(
      "The GPU backend does not implement BatchNormGrad directly.  It should "
      "be lowered before IR emission to HLO-soup (using BatchNormRewriter) or "
      "to a cudnn CustomCall using CudnnBatchNormRewriter.");
}

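// Evaluates `computation` on the given scalar operands by spilling them to
// allocas, calling the emitted nested function, and reloading the scalar
// result(s) (one per tuple element if the root is tuple-shaped).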
StatusOr<std::vector<llvm::Value*>> IrEmitter::ComputeNestedElement(
    const HloComputation& computation,
    absl::Span<llvm::Value* const> parameter_elements) {
  const Shape& return_shape = computation.root_instruction()->shape();
  llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
      llvm_ir::ShapeToIrType(return_shape, module_), "return_buffer", &b_);
  std::vector<llvm::Value*> parameter_buffers;
  for (llvm::Value* parameter_element : parameter_elements) {
    parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
        parameter_element->getType(), "parameter_buffer", &b_));
    Store(parameter_element, parameter_buffers.back());
  }

  std::vector<llvm::Value*> allocas_for_returned_scalars;
  if (!return_shape.IsTuple()) {
    allocas_for_returned_scalars.push_back(return_buffer);
  } else {
    allocas_for_returned_scalars =
        llvm_ir::EmitTupleAllocasAtFunctionEntry(return_shape, &b_);
    llvm_ir::IrArray tuple_array(return_buffer, return_shape);

    EmitTuple(tuple_array, allocas_for_returned_scalars, &b_);
  }

  TF_RETURN_IF_ERROR(EmitCallToNestedComputation(computation, parameter_buffers,
                                                 return_buffer));

  std::vector<llvm::Value*> returned_scalars;
  returned_scalars.reserve(allocas_for_returned_scalars.size());
  for (llvm::Value* addr : allocas_for_returned_scalars) {
    returned_scalars.push_back(Load(addr));
  }
  return returned_scalars;
}

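// Returns one IrArray per output of `hlo`: one per tuple element if the output
// is tuple-shaped, otherwise a single array for the whole result.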
std::vector<llvm_ir::IrArray> IrEmitter::ConstructIrArrayForOutputs(
    const HloInstruction& hlo) {
  std::vector<llvm_ir::IrArray> output_arrays;
  if (hlo.shape().IsTuple()) {
    int64 num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
    output_arrays.reserve(num_outputs);
    for (int64 i = 0; i < num_outputs; ++i) {
      output_arrays.push_back(GetIrArray(hlo, hlo, {i}));
    }
  } else {
    output_arrays.push_back(GetIrArray(hlo, hlo));
  }
  return output_arrays;
}

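// Binds each fused parameter of `fusion` to a generator that reads the
// corresponding element of the matching fusion operand.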
void IrEmitter::BindFusionArguments(const HloInstruction* fusion,
                                    FusedIrEmitter* fused_emitter) {
  for (int i = 0; i < fusion->operand_count(); i++) {
    const HloInstruction* operand = fusion->operand(i);
    fused_emitter->BindGenerator(
        fusion->fused_parameter(i),
        [this, operand, fusion](llvm_ir::IrArray::Index index) {
          return GetIrArray(*operand, *fusion)
              .EmitReadArrayElement(index, &b_, operand->name());
        });
  }
}

}  // namespace gpu
}  // namespace xla