/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"

#include <string>
#include <unordered_map>
#include <utility>

#include "tensorflow/core/platform/logging.h"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/algorithm/container.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/errors.h"

// Convenience function that casts the provided llvm::Value* to the default
// address space using the given IRBuilder. This is useful in particular when
// generating IR for the AMDGPU target, as its kernel variables live in
// address space 5 instead of the default address space.
static llvm::Value* AddrCastToDefault(llvm::Value* arg, llvm::IRBuilder<>& b) {
  llvm::Type* arg_type = arg->getType();
  CHECK(arg_type->isPointerTy());
  if (arg_type->getPointerAddressSpace() != 0) {
    llvm::Type* generic_arg_type =
        arg_type->getPointerElementType()->getPointerTo(0);
    llvm::Value* addrspacecast_arg =
        b.CreateAddrSpaceCast(arg, generic_arg_type);
    return addrspacecast_arg;
  }
  return arg;
}

namespace xla {

using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;

namespace gpu {

IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config,
                     IrEmitterContext* ir_emitter_context, bool is_nested)
    : ir_emitter_context_(ir_emitter_context),
      module_(ir_emitter_context->llvm_module()),
      b_(module_->getContext()),
      bindings_(&b_, module_, is_nested),
      hlo_module_config_(hlo_module_config) {}

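// Fallback for HLOs without a dedicated handler: wraps each operand in a
// generator that reads the operand element at a given index, then emits a
// loop over the output elements using the elemental IR emitter's generator
// for `hlo`.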
Status IrEmitter::DefaultAction(HloInstruction* hlo) {
  ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
  for (const HloInstruction* operand : hlo->operands()) {
    operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
      return GetIrArray(*operand, *hlo)
          .EmitReadArrayElement(index, &b_, operand->name());
    };
  }
  return EmitTargetElementLoop(
      *hlo, GpuElementalIrEmitter(hlo_module_config_, module_, &b_,
                                  GetNestedComputer())
                .MakeElementGenerator(hlo, operand_to_generator));
}

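// Emits one LLVM global per kConstant instruction in `computation` and
// records its symbol name in the IrEmitterContext so that GpuExecutable can
// look it up later. Literals that ShouldEmitLiteralInLlvmIr rejects get a
// zero initializer; their raw bytes are carried in ConstantInfo::content
// instead of being embedded in the IR.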
Status IrEmitter::EmitConstants(const HloComputation& computation) {
  for (HloInstruction* instr : computation.instructions()) {
    if (instr->opcode() != HloOpcode::kConstant) {
      continue;
    }
    Literal& literal = *Cast<HloConstantInstruction>(instr)->mutable_literal();
    const bool should_emit_initializer = ShouldEmitLiteralInLlvmIr(literal);
    llvm::ArrayType* global_type =
        llvm::ArrayType::get(b_.getInt8Ty(), literal.size_bytes());
    llvm::Constant* initializer =
        should_emit_initializer
            ? llvm_ir::ConvertLiteralToIrConstant(literal, module_)
            : llvm::ConstantAggregateZero::get(global_type);
    if (should_emit_initializer) {
      VLOG(3) << "Emitted initializer for constant with shape "
              << ShapeUtil::HumanString(literal.shape());
    }

    // These globals will be looked up by name by GpuExecutable, so we need to
    // give them external linkage. Not all of their uses are visible in the
    // LLVM IR (e.g. TupleThunk), so we can't give them a linkage that merely
    // preserves their names (like available_externally); we also need to
    // ensure that they stick around even if they're "unused".
    //
    // We may have to be more clever here in the future if we notice that we're
    // keeping around too many globals because of their linkage.
    std::string global_name = llvm_ir::ConstantHloToGlobalName(*instr);

    llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
        global_type, /*isConstant=*/should_emit_initializer,
        llvm::GlobalValue::ExternalLinkage,
        /*Initializer=*/initializer, global_name,
        /*TLMode=*/llvm::GlobalValue::NotThreadLocal,
        /*AddressSpace=*/0,
        /*isExternallyInitialized=*/false);
    global_for_const->setAlignment(llvm::Align(kConstantBufferAlignBytes));
    ir_emitter_context_->llvm_module()->getGlobalList().push_back(
        global_for_const);

    GpuExecutable::ConstantInfo info;
    info.symbol_name = global_name;

    if (!should_emit_initializer) {
      auto base = static_cast<const uint8*>(literal.untyped_data());
      info.content.assign(base, base + literal.size_bytes());
    }
    ir_emitter_context_->constants().push_back(std::move(info));
  }
  return Status::OK();
}

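// Constants are materialized as module-level globals by EmitConstants above,
// so there is nothing to emit at their point of use.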
Status IrEmitter::HandleConstant(HloInstruction* constant) {
  return Status::OK();
}

Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) {
  VLOG(2) << "HandleAddDependency: " << add_dependency->ToString();
  const HloInstruction* operand = add_dependency->operand(0);
  // AddDependency is a no-op, but we still want to bind it to an llvm::Value
  // sometimes, e.g., when its operand is a constant or a bitcast of a
  // constant.
  if (bindings_.BoundToIrValue(*operand)) {
    bindings_.BindHloToIrValue(*add_dependency, GetBasePointer(*operand));
  }
  return Status::OK();
}

Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
  auto operand = get_tuple_element->operand(0);
  CHECK(bindings_.BoundToIrValue(*operand));
  bindings_.BindHloToIrValue(
      *get_tuple_element,
      llvm_ir::EmitGetTupleElement(
          get_tuple_element->shape(), get_tuple_element->tuple_index(),
          // TODO(b/26344050): tighten the alignment here
          // based on the real element type.
          /*alignment=*/1, GetBasePointer(*operand), &b_));
  return Status::OK();
}

Status IrEmitter::HandleSend(HloInstruction*) {
  return Unimplemented("Send is not implemented on GPU");
}

Status IrEmitter::HandleSendDone(HloInstruction*) {
  return Unimplemented("Send-Done is not implemented on GPU");
}

Status IrEmitter::HandleRecv(HloInstruction*) {
  return Unimplemented("Recv is not implemented on GPU");
}

Status IrEmitter::HandleRecvDone(HloInstruction*) {
  return Unimplemented("Recv-done is not implemented on GPU");
}

Status IrEmitter::HandleScatter(HloInstruction*) {
  return Unimplemented("Scatter is not implemented on GPUs.");
}

Status IrEmitter::HandleTuple(HloInstruction* tuple) {
  std::vector<llvm::Value*> base_ptrs;
  for (const HloInstruction* operand : tuple->operands()) {
    base_ptrs.push_back(GetBasePointer(*operand));
  }
  llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &b_);
  return Status::OK();
}

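// Emits (and caches, per computation) an LLVM function for
// `nested_computation`, then calls it with `operands` followed by `output` as
// the final argument. Pointer arguments are cast back to the default address
// space first, which is a no-op everywhere except on AMDGPU.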
Status IrEmitter::EmitCallToNestedComputation(
    const HloComputation& nested_computation,
    absl::Span<llvm::Value* const> operands, llvm::Value* output) {
  TF_RET_CHECK(nested_computation.num_parameters() > 0);
  llvm::Function*& emitted_function =
      computation_to_ir_function_[&nested_computation];
  if (emitted_function == nullptr) {
    TF_ASSIGN_OR_RETURN(
        auto ir_emitter_nested,
        IrEmitterNested::Create(hlo_module_config_, nested_computation,
                                ir_emitter_context_));
    TF_RETURN_IF_ERROR(ir_emitter_nested->CodegenNestedComputation());
    emitted_function = ir_emitter_nested->GetEmittedFunction();
  }

  // Operands are in the default address space for non-AMDGPU targets. For the
  // AMDGPU target, however, alloca variables must be addrspacecast from
  // addrspace 5 to addrspace 0.
  std::vector<llvm::Value*> arguments;
  absl::c_transform(
      operands, std::back_inserter(arguments),
      [this](llvm::Value* arg) { return AddrCastToDefault(arg, b_); });

  llvm::Value* casted_output = AddrCastToDefault(output, b_);
  arguments.push_back(casted_output);

  Call(emitted_function, arguments);

  return Status::OK();
}

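// Tries to lower `computation` applied to *output_address and *source_address
// to a single hardware atomic instruction (atomic store, add, max, or min).
// Returns true if such an instruction was emitted, and false if the caller
// must fall back to the atomicCAS loop.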
bool IrEmitter::MaybeEmitDirectAtomicOperation(
    const HloComputation& computation, llvm::Value* output_address,
    llvm::Value* source_address) {
  CHECK_EQ(2, computation.num_parameters());

  HloOpcode root_opcode = computation.root_instruction()->opcode();
  PrimitiveType element_type =
      computation.root_instruction()->shape().element_type();
  bool is_atomic_integral = element_type == S32 || element_type == U32 ||
                            element_type == S64 || element_type == U64;
  llvm::Value* source = Load(source_address, "source");

  // Just passing along the RHS -> atomic store.
  if (computation.instruction_count() == 2 &&
      root_opcode == HloOpcode::kParameter &&
      (element_type == F32 || is_atomic_integral) &&
      computation.root_instruction()->parameter_number() == 1) {
    llvm::StoreInst* store = Store(source, output_address);
    store->setAtomic(llvm::AtomicOrdering::Unordered);
    // Derive a minimum alignment from the type. The optimizer can increase it
    // later.
    store->setAlignment(
        llvm::Align(ShapeUtil::ByteSizeOfPrimitiveType(element_type)));
    return true;
  }

  if (computation.instruction_count() != 3) {
    // We special-case only computations with a single computing instruction
    // for now. Such a computation has exactly three instructions, given that
    // it has two parameters.
    return false;
  }

  if (root_opcode == HloOpcode::kAdd) {
    llvm::Triple target_triple = llvm::Triple(module_->getTargetTriple());
    // NVPTX supports atomicAdd on F32 and integer types.
    if (target_triple.isNVPTX()) {
      // "atom.add.f64 requires sm_60 or higher."
      // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom
      bool f64_atomic_add_supported =
          ir_emitter_context_->cuda_compute_capability().IsAtLeast(6);
      bool atomic_add_supported =
          element_type == F32 ||
          (f64_atomic_add_supported && element_type == F64);
      if (atomic_add_supported) {
        AtomicRMW(llvm::AtomicRMWInst::FAdd, output_address, source,
                  llvm::MaybeAlign(),
                  llvm::AtomicOrdering::SequentiallyConsistent);
        return true;
      }
    }

    if (is_atomic_integral) {
      // integral + integral
      AtomicRMW(llvm::AtomicRMWInst::Add, output_address, source,
                llvm::MaybeAlign(),
                llvm::AtomicOrdering::SequentiallyConsistent);
      return true;
    }
  }

  // NVPTX supports atomicMax and atomicMin only on integer types.
  if (root_opcode == HloOpcode::kMaximum && is_atomic_integral) {
    // max(integral, integral)
    auto opcode = primitive_util::IsSignedIntegralType(element_type)
                      ? llvm::AtomicRMWInst::Max
                      : llvm::AtomicRMWInst::UMax;
    AtomicRMW(opcode, output_address, source, llvm::MaybeAlign(),
              llvm::AtomicOrdering::SequentiallyConsistent);
    return true;
  }

  if (root_opcode == HloOpcode::kMinimum && is_atomic_integral) {
    // min(integral, integral)
    auto opcode = primitive_util::IsSignedIntegralType(element_type)
                      ? llvm::AtomicRMWInst::Min
                      : llvm::AtomicRMWInst::UMin;
    AtomicRMW(opcode, output_address, source, llvm::MaybeAlign(),
              llvm::AtomicOrdering::SequentiallyConsistent);
    return true;
  }

  return false;
}

// Implements atomic binary operations using atomic compare-and-swap
// (atomicCAS) as follows:
//   1. Reads the value from the memory pointed to by output_address and
//      records it as old_output.
//   2. Uses old_output as one of the source operands to perform the binary
//      operation and stores the result in new_output.
//   3. Calls atomicCAS, which implements compare-and-swap as an atomic
//      operation. In particular, atomicCAS reads the value from the memory
//      pointed to by output_address and compares it with old_output. If the
//      two values are equal, new_output is written to the same memory location
//      and true is returned to indicate that the atomic operation succeeded.
//      Otherwise, the new value read from the memory is returned. In this
//      case, the new value is copied to old_output, and steps 2 and 3 are
//      repeated until atomicCAS succeeds.
//
// On Nvidia GPUs, atomicCAS can only operate on 32-bit and 64-bit integers. If
// the element type of the binary operation is 32 or 64 bits wide, the integer
// type of the same size is used for the atomicCAS operation. If the element
// type is smaller than 32 bits, int32 is used for the atomicCAS operation; in
// that case atomicCAS reads and writes 32-bit values, which is more memory
// than the original atomic binary operation requires. We mask off the last two
// bits of output_address and use the result as the address from which to read
// the 32-bit value. This avoids out-of-bounds memory accesses as long as
// tensor buffers are 4-byte aligned and have a size of 4N, an assumption that
// the runtime can guarantee.
//
// The pseudo code is shown below. Variables *_address are pointers to a memory
// region with a size equal to the size of the atomicCAS operation, with the
// exception that new_output_address is a pointer to a memory region with a
// size equal to the element size of the binary operation.
//
//   element_size = sizeof(element_type);
//   atomic_size = max(32, element_size);
//   cas_new_output_address = alloca(atomic_size);
//   cas_old_output_address = alloca(atomic_size);
//   if (atomic_size != element_size) {
//     atomic_address = output_address & ((int64)(-4));
//     new_output_address = cas_new_output_address + (output_address & 3);
//   } else {
//     atomic_address = output_address;
//     new_output_address = cas_new_output_address;
//   }
//
//   *cas_old_output_address = *atomic_address;
//   do {
//     *cas_new_output_address = *cas_old_output_address;
//     *new_output_address = operation(*new_output_address, *source_address);
//     (*cas_old_output_address, success) =
//         atomicCAS(atomic_address, *cas_old_output_address,
//                   *cas_new_output_address);
//   } while (!success);
//
Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
                                              llvm::Value* output_address,
                                              llvm::Value* source_address) {
  llvm::PointerType* output_address_type =
      llvm::dyn_cast<llvm::PointerType>(output_address->getType());
  CHECK_NE(output_address_type, nullptr);

  // element_type is the data type for the binary operation.
  llvm::Type* element_type = output_address_type->getPointerElementType();
  int element_size = llvm_ir::GetSizeInBits(element_type);

  int atomic_size = (element_size < 32) ? 32 : element_size;
  llvm::Type* atomic_type = b_.getIntNTy(atomic_size);
  llvm::Type* atomic_address_type =
      atomic_type->getPointerTo(output_address_type->getPointerAddressSpace());

  // cas_old_output_address and cas_new_output_address point to the scratch
  // memory where we store the old and new values for the repeated atomicCAS
  // operations.
  llvm::Value* cas_old_output_address = llvm_ir::EmitAllocaAtFunctionEntry(
      atomic_type, "cas_old_output_address", &b_);
  llvm::Value* cas_new_output_address = llvm_ir::EmitAllocaAtFunctionEntry(
      atomic_type, "cas_new_output_address", &b_);

  // Emit preparation code to the preheader.
  llvm::BasicBlock* loop_preheader_bb = b_.GetInsertBlock();

  llvm::Value* atomic_memory_address;
  // binop_output_address points to the scratch memory that stores the
  // result of the binary operation.
  llvm::Value* binop_output_address;
  if (element_size < 32) {
    // The element size (in bits) must be a whole number of bytes.
    CHECK_EQ(element_size % 8, 0);
    llvm::Type* address_int_type =
        module_->getDataLayout().getIntPtrType(output_address_type);
    atomic_memory_address = PtrToInt(output_address, address_int_type);
    llvm::Value* mask = llvm::ConstantInt::get(address_int_type, 3);
    llvm::Value* offset = And(atomic_memory_address, mask);
    mask = llvm::ConstantInt::get(address_int_type, -4);
    atomic_memory_address = And(atomic_memory_address, mask);
    atomic_memory_address =
        IntToPtr(atomic_memory_address, atomic_address_type);
    binop_output_address =
        Add(PtrToInt(cas_new_output_address, address_int_type), offset);
    binop_output_address = IntToPtr(
        binop_output_address,
        llvm::PointerType::get(
            element_type,
            cas_new_output_address->getType()->getPointerAddressSpace()));
  } else {
    atomic_memory_address = b_.CreatePointerBitCastOrAddrSpaceCast(
        output_address, atomic_address_type);
    binop_output_address = b_.CreatePointerBitCastOrAddrSpaceCast(
        cas_new_output_address,
        llvm::PointerType::get(
            element_type,
            cas_new_output_address->getType()->getPointerAddressSpace()));
  }

  // Use the value from the memory that atomicCAS operates on to initialize
  // cas_old_output.
  llvm::Value* cas_old_output = Load(atomic_memory_address, "cas_old_output");
  Store(cas_old_output, cas_old_output_address);

  llvm::BasicBlock* loop_exit_bb = loop_preheader_bb->splitBasicBlock(
      b_.GetInsertPoint(), "atomic_op_loop_exit");
  llvm::BasicBlock* loop_body_bb = llvm::BasicBlock::Create(
      b_.getContext(), "atomic_op_loop_body", b_.GetInsertBlock()->getParent());
  b_.SetInsertPoint(loop_body_bb);
  // Change preheader's successor from loop_exit_bb to loop_body_bb.
  loop_preheader_bb->getTerminator()->setSuccessor(0, loop_body_bb);

  // Emit the body of the loop that repeatedly invokes atomicCAS.
  //
  // Use cas_old_output to initialize cas_new_output.
  cas_old_output = Load(cas_old_output_address, "cas_old_output");
  Store(cas_old_output, cas_new_output_address);
  // Emits code to calculate new_output = operation(old_output, source);
  TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
      computation, {binop_output_address, source_address},
      binop_output_address));

  llvm::Value* cas_new_output = Load(cas_new_output_address, "cas_new_output");

  // If cas_new_output == cas_old_output, we're not asking for anything to
  // change, so we're done here!
  llvm::Value* old_eq_new = ICmpEQ(cas_old_output, cas_new_output);
  llvm::BasicBlock* loop_cas_bb = llvm::BasicBlock::Create(
      b_.getContext(), "atomic_op_loop_cas", b_.GetInsertBlock()->getParent());
  CondBr(old_eq_new, loop_exit_bb, loop_cas_bb);
  b_.SetInsertPoint(loop_cas_bb);

  // Emit code to perform the atomicCAS operation
  // (cas_old_output, success) = atomicCAS(memory_address, cas_old_output,
  //                                       cas_new_output);
  llvm::Value* ret_value = AtomicCmpXchg(
      atomic_memory_address, cas_old_output, cas_new_output, llvm::MaybeAlign(),
      llvm::AtomicOrdering::SequentiallyConsistent,
      llvm::AtomicOrdering::SequentiallyConsistent);

  // Extract the memory value returned from atomicCAS and store it as
  // cas_old_output.
  Store(ExtractValue(ret_value, 0, "cas_old_output"), cas_old_output_address);
  // Extract the success bit returned from atomicCAS and generate a
  // conditional branch on the success bit.
  CondBr(ExtractValue(ret_value, 1, "success"), loop_exit_bb, loop_body_bb);

  // Set the insertion point to the exit basic block so that the caller of
  // this method can continue emitting code to the right place.
  SetToFirstInsertPoint(loop_exit_bb, &b_);
  return Status::OK();
}

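// Emits an atomic read-modify-write of *output_address with the value at
// *source_address, using `computation` as the combiner: first try a direct
// hardware atomic, otherwise fall back to the compare-and-swap loop above.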
Status IrEmitter::EmitAtomicOperationForNestedComputation(
    const HloComputation& computation, llvm::Value* output_address,
    llvm::Value* source_address) {
  if (computation.num_parameters() != 2) {
    // TODO(b/30258929): We only accept binary computations so far.
    return Unimplemented(
        "We only support atomic functions with exactly two parameters, but "
        "computation %s has %d.",
        computation.name(), computation.num_parameters());
  }

  if (MaybeEmitDirectAtomicOperation(computation, output_address,
                                     source_address)) {
    return Status::OK();
  }

  return EmitAtomicOperationUsingCAS(computation, output_address,
                                     source_address);
}

Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) {
  return InternalError(
      "Dynamic selection of tuples is not supported. Please file a bug against "
      "XLA/GPU if you need it");
}

namespace {
llvm::Value* Real(llvm::Value* x, llvm::IRBuilder<>* b) {
  return b->CreateExtractValue(x, {0});
}

llvm::Value* Imag(llvm::Value* x, llvm::IRBuilder<>* b) {
  return b->CreateExtractValue(x, {1});
}

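// Multiplies two complex values given as {real, imag} pairs, using
// (a + bi) * (c + di) = (ac - bd) + (ad + bc)i.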
std::pair<llvm::Value*, llvm::Value*> MultiplyComplex(llvm::Value* lhs_value,
                                                      llvm::Value* rhs_value,
                                                      llvm::IRBuilder<>* b) {
  llvm::Value* lhs_real = Real(lhs_value, b);
  llvm::Value* lhs_imag = Imag(lhs_value, b);
  llvm::Value* rhs_real = Real(rhs_value, b);
  llvm::Value* rhs_imag = Imag(rhs_value, b);
  llvm::Value* real_result1 = b->CreateFMul(lhs_real, rhs_real);
  llvm::Value* real_result2 = b->CreateFMul(lhs_imag, rhs_imag);
  llvm::Value* real_result = b->CreateFSub(real_result1, real_result2);
  llvm::Value* imag_result1 = b->CreateFMul(lhs_real, rhs_imag);
  llvm::Value* imag_result2 = b->CreateFMul(lhs_imag, rhs_real);
  llvm::Value* imag_result = b->CreateFAdd(imag_result1, imag_result2);
  return {real_result, imag_result};
}
}  // namespace

Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
  if (ShapeUtil::IsZeroElementArray(convolution->shape())) {
    // Emit no code for an empty output.
    return Status::OK();
  }
  // TODO(b/31409998): Support convolution with dilation.
  return Unimplemented(
      "Hit a case for convolution that is not implemented on GPU.");
}

Status IrEmitter::HandleFft(HloInstruction* fft) {
  if (ShapeUtil::IsZeroElementArray(fft->shape())) {
    // Emit no code for an empty output.
    return Status::OK();
  }
  return Unimplemented("Hit a case for fft that is not implemented on GPU.");
}

Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
  return Unimplemented(
      "AllReduce cannot be nested inside of fusion, map, etc.");
}

Status IrEmitter::HandleParameter(HloInstruction* parameter) {
  return Status::OK();
}

Status IrEmitter::HandleFusion(HloInstruction* fusion) {
  // kFusion for library calls should be handled by
  // IrEmitterUnnested::HandleFusion.
  CHECK_EQ(HloInstruction::FusionKind::kLoop, fusion->fusion_kind());
  GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
                                          GetNestedComputer());
  FusedIrEmitter fused_emitter(&elemental_emitter);
  BindFusionArguments(fusion, &fused_emitter);
  TF_ASSIGN_OR_RETURN(auto generator, fused_emitter.GetGenerator(
                                          fusion->fused_expression_root()));
  return EmitTargetElementLoop(*fusion, generator);
}

Status IrEmitter::HandleCall(HloInstruction* call) {
  std::vector<llvm::Value*> operand_addresses;
  for (HloInstruction* operand : call->operands()) {
    operand_addresses.push_back(GetBasePointer(*operand));
  }
  return EmitCallToNestedComputation(*call->to_apply(), operand_addresses,
                                     GetBasePointer(*call));
}

Status IrEmitter::HandleCustomCall(HloInstruction*) {
  return Unimplemented("custom-call");
}

Status IrEmitter::HandleInfeed(HloInstruction*) {
  // TODO(b/30467474): Implement infeed on GPU.
  return Unimplemented("Infeed is not supported on GPU.");
}

Status IrEmitter::HandleOutfeed(HloInstruction*) {
  // TODO(b/34359662): Implement outfeed on GPU.
  return Unimplemented("Outfeed is not supported on GPU.");
}

Status IrEmitter::HandleBatchNormInference(HloInstruction*) {
  return Unimplemented(
      "The GPU backend does not implement BatchNormInference directly. It "
      "should be lowered before IR emission to HLO-soup using "
      "BatchNormRewriter or to a cudnn CustomCall using "
      "CudnnBatchNormRewriter.");
}

Status IrEmitter::HandleBatchNormTraining(HloInstruction*) {
  return Unimplemented(
      "The GPU backend does not implement BatchNormTraining directly. It "
      "should be lowered before IR emission to HLO-soup using "
      "BatchNormRewriter or to a cudnn CustomCall using "
      "CudnnBatchNormRewriter.");
}

Status IrEmitter::HandleBatchNormGrad(HloInstruction*) {
  return Unimplemented(
      "The GPU backend does not implement BatchNormGrad directly. It should "
      "be lowered before IR emission to HLO-soup (using BatchNormRewriter) or "
      "to a cudnn CustomCall using CudnnBatchNormRewriter.");
}

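// Evaluates `computation` on scalar operands: each parameter element is
// spilled to a stack slot, the nested computation is called on those slots,
// and the results are loaded back as scalars (one per tuple element when the
// root returns a tuple).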
StatusOr<std::vector<llvm::Value*>> IrEmitter::ComputeNestedElement(
    const HloComputation& computation,
    absl::Span<llvm::Value* const> parameter_elements) {
  const Shape& return_shape = computation.root_instruction()->shape();
  llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
      llvm_ir::ShapeToIrType(return_shape, module_), "return_buffer", &b_);
  std::vector<llvm::Value*> parameter_buffers;
  for (llvm::Value* parameter_element : parameter_elements) {
    parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
        parameter_element->getType(), "parameter_buffer", &b_));
    Store(parameter_element, parameter_buffers.back());
  }

  std::vector<llvm::Value*> allocas_for_returned_scalars;
  if (!return_shape.IsTuple()) {
    allocas_for_returned_scalars.push_back(return_buffer);
  } else {
    allocas_for_returned_scalars =
        llvm_ir::EmitTupleAllocasAtFunctionEntry(return_shape, &b_);
    llvm_ir::IrArray tuple_array(return_buffer, return_shape);

    EmitTuple(tuple_array, allocas_for_returned_scalars, &b_);
  }

  TF_RETURN_IF_ERROR(EmitCallToNestedComputation(computation, parameter_buffers,
                                                 return_buffer));

  std::vector<llvm::Value*> returned_scalars;
  returned_scalars.reserve(allocas_for_returned_scalars.size());
  for (llvm::Value* addr : allocas_for_returned_scalars) {
    returned_scalars.push_back(Load(addr));
  }
  return returned_scalars;
}

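// Returns one IrArray per output buffer of `hlo`: one per top-level tuple
// element if the result shape is a tuple, otherwise a single array for the
// whole result.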
std::vector<llvm_ir::IrArray> IrEmitter::ConstructIrArrayForOutputs(
    const HloInstruction& hlo) {
  std::vector<llvm_ir::IrArray> output_arrays;
  if (hlo.shape().IsTuple()) {
    int64_t num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
    output_arrays.reserve(num_outputs);
    for (int64_t i = 0; i < num_outputs; ++i) {
      output_arrays.push_back(GetIrArray(hlo, hlo, {i}));
    }
  } else {
    output_arrays.push_back(GetIrArray(hlo, hlo));
  }
  return output_arrays;
}

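// Binds each fused parameter to a generator that reads the corresponding
// element of the matching fusion operand, so the fused expression can be
// emitted as if its parameters were ordinary input arrays.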
void IrEmitter::BindFusionArguments(const HloInstruction* fusion,
                                    FusedIrEmitter* fused_emitter) {
  for (int i = 0; i < fusion->operand_count(); i++) {
    const HloInstruction* operand = fusion->operand(i);
    fused_emitter->BindGenerator(
        fusion->fused_parameter(i),
        [this, operand, fusion](llvm_ir::IrArray::Index index) {
          return GetIrArray(*operand, *fusion)
              .EmitReadArrayElement(index, &b_, operand->name());
        });
  }
}

}  // namespace gpu
}  // namespace xla