/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"

#include <climits>
#include <string>
#include <unordered_map>
#include <utility>

#include "tensorflow/core/platform/logging.h"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/algorithm/container.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/errors.h"

// Convenience function that casts the provided llvm::Value* to the default
// address space using the IRBuilder. This is useful in particular when
// generating IR for the AMDGPU target, as its kernel variables are in
// address space 5 instead of the default address space.
static llvm::Value* AddrCastToDefault(llvm::Value* arg, llvm::IRBuilder<>& b) {
  llvm::Type* arg_type = arg->getType();
  CHECK(arg_type->isPointerTy());
  if (arg_type->getPointerAddressSpace() != 0) {
    llvm::Type* generic_arg_type =
        arg_type->getPointerElementType()->getPointerTo(0);
    llvm::Value* addrspacecast_arg =
        b.CreateAddrSpaceCast(arg, generic_arg_type);
    return addrspacecast_arg;
  }
  return arg;
}

namespace xla {

using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;

namespace gpu {

IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config,
                     IrEmitterContext* ir_emitter_context, bool is_nested)
    : ir_emitter_context_(ir_emitter_context),
      module_(ir_emitter_context->llvm_module()),
      b_(module_->getContext()),
      bindings_(ir_emitter_context->hlo_module(),
                &ir_emitter_context->buffer_assignment(), &b_, module_,
                is_nested),
      hlo_module_config_(hlo_module_config) {}

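// Default (elementwise) lowering path: each operand is read element by element
// through its IrArray, and the element generator produced for `hlo` by
// GpuElementalIrEmitter is driven by EmitTargetElementLoop over the output
// shape.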
Status IrEmitter::DefaultAction(HloInstruction* hlo) {
  ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
  for (const HloInstruction* operand : hlo->operands()) {
    operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
      return GetIrArray(*operand, *hlo)
          .EmitReadArrayElement(index, &b_, operand->name());
    };
  }
  return EmitTargetElementLoop(
      *hlo, GpuElementalIrEmitter(hlo_module_config_, module_, &b_,
                                  GetNestedComputer())
                .MakeElementGenerator(hlo, operand_to_generator));
}

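// Emits every kConstant in `computation` as a module-level i8-array global.
// Literals for which ShouldEmitLiteralInLlvmIr returns true get a real LLVM
// initializer; the rest are zero-initialized here and their raw bytes are
// recorded in GpuExecutable::ConstantInfo so the runtime can supply them.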
Status IrEmitter::EmitConstants(const HloComputation& computation,
                                bool lookup_indices) {
  for (HloInstruction* instr : computation.instructions()) {
    if (instr->opcode() != HloOpcode::kConstant) {
      continue;
    }
    Literal& literal = *Cast<HloConstantInstruction>(instr)->mutable_literal();
    const bool should_emit_initializer = ShouldEmitLiteralInLlvmIr(literal);
    llvm::ArrayType* global_type =
        llvm::ArrayType::get(b_.getInt8Ty(), literal.size_bytes());
    llvm::Constant* initializer =
        should_emit_initializer
            ? llvm_ir::ConvertLiteralToIrConstant(literal, module_)
            : llvm::ConstantAggregateZero::get(global_type);
    if (should_emit_initializer) {
      VLOG(3) << "Emitted initializer for constant with shape "
              << ShapeUtil::HumanString(literal.shape());
    }

    // These globals will be looked up by name by GpuExecutable, so we need to
    // give them an external linkage. Not all of their uses are visible in the
    // LLVM IR (e.g. TupleThunk), so we can't give them a linkage that merely
    // preserves their names (like available_externally); we also need to
    // ensure that they stick around even if they're "unused".
    //
    // We may have to be more clever here in the future if we notice that we're
    // keeping around too many globals because of their linkage.
    unsigned global_address_space = llvm_ir::GetGlobalMemoryAddressSpace(
        *ir_emitter_context_->llvm_module());

    std::string global_name = llvm_ir::ConstantHloToGlobalName(*instr);

    llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
        global_type, /*isConstant=*/should_emit_initializer,
        llvm::GlobalValue::ExternalLinkage,
        /*Initializer=*/initializer, global_name,
        /*TLMode=*/llvm::GlobalValue::NotThreadLocal,
        /*AddressSpace=*/global_address_space,
        /*isExternallyInitialized=*/false);
    global_for_const->setAlignment(llvm::Align(kConstantBufferAlignBytes));
    ir_emitter_context_->llvm_module()->getGlobalList().push_back(
        global_for_const);

    GpuExecutable::ConstantInfo info;
    info.symbol_name = global_name;

    if (!should_emit_initializer) {
      auto base = static_cast<const uint8*>(literal.untyped_data());
      info.content.assign(base, base + literal.size_bytes());
    }
    if (lookup_indices) {
      auto maybe_slice =
          ir_emitter_context_->buffer_assignment().GetUniqueSlice(instr, {});
      if (maybe_slice.ok()) {
        info.allocation_index = maybe_slice.ValueOrDie().index();
      }
    }
    ir_emitter_context_->constants().push_back(std::move(info));
  }
  return Status::OK();
}

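// Constant instructions are materialized as module-level globals by
// EmitConstants above, so there is no code to emit at their point of use.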
Status IrEmitter::HandleConstant(HloInstruction* constant) {
  return Status::OK();
}

Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) {
  VLOG(2) << "HandleAddDependency: " << add_dependency->ToString();
  const HloInstruction* operand = add_dependency->operand(0);
  // AddDependency is a no-op, but we still want to bind it to an llvm::Value
  // sometimes, e.g., when its operand is a constant or a bitcast of a
  // constant.
  if (bindings_.BoundToIrValue(*operand)) {
    bindings_.BindHloToIrValue(*add_dependency, GetBasePointer(*operand));
  }
  return Status::OK();
}

Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
  auto operand = get_tuple_element->operand(0);
  CHECK(bindings_.BoundToIrValue(*operand));
  bindings_.BindHloToIrValue(
      *get_tuple_element,
      llvm_ir::EmitGetTupleElement(
          get_tuple_element->shape(), get_tuple_element->tuple_index(),
          // TODO(b/26344050): tighten the alignment here
          // based on the real element type.
          /*alignment=*/1, GetBasePointer(*operand), &b_));
  return Status::OK();
}

Status IrEmitter::HandleSend(HloInstruction*) {
  return Unimplemented("Send is not implemented on GPU");
}

Status IrEmitter::HandleSendDone(HloInstruction*) {
  return Unimplemented("Send-Done is not implemented on GPU");
}

Status IrEmitter::HandleRecv(HloInstruction*) {
  return Unimplemented("Recv is not implemented on GPU");
}

Status IrEmitter::HandleRecvDone(HloInstruction*) {
  return Unimplemented("Recv-done is not implemented on GPU");
}

Status IrEmitter::HandleScatter(HloInstruction*) {
  return Unimplemented("Scatter is not implemented on GPUs.");
}

Status IrEmitter::HandleTuple(HloInstruction* tuple) {
  std::vector<llvm::Value*> base_ptrs;
  for (const HloInstruction* operand : tuple->operands()) {
    base_ptrs.push_back(GetBasePointer(*operand));
  }
  llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &b_);
  return Status::OK();
}

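// Lowers a call to `nested_computation` as a call to an LLVM function. The
// function is code-generated at most once per IrEmitter and cached in
// computation_to_ir_function_. Operand pointers plus the output pointer are
// passed as arguments, addrspacecast to the default address space when needed
// (e.g. AMDGPU allocas live in address space 5).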
Status IrEmitter::EmitCallToNestedComputation(
    const HloComputation& nested_computation,
    absl::Span<llvm::Value* const> operands, llvm::Value* output) {
  TF_RET_CHECK(nested_computation.num_parameters() > 0);
  llvm::Function*& emitted_function =
      computation_to_ir_function_[&nested_computation];
  if (emitted_function == nullptr) {
    TF_ASSIGN_OR_RETURN(
        auto ir_emitter_nested,
        IrEmitterNested::Create(hlo_module_config_, nested_computation,
                                ir_emitter_context_));
    TF_RETURN_IF_ERROR(ir_emitter_nested->CodegenNestedComputation());
    emitted_function = ir_emitter_nested->GetEmittedFunction();
  }

  // For non-AMDGPU targets the operands are already in the default address
  // space. For the AMDGPU target, however, alloca variables live in address
  // space 5 and must first be addrspacecast to address space 0.
  std::vector<llvm::Value*> arguments;
  absl::c_transform(
      operands, std::back_inserter(arguments),
      [this](llvm::Value* arg) { return AddrCastToDefault(arg, b_); });

  llvm::Value* casted_output = AddrCastToDefault(output, b_);
  arguments.push_back(casted_output);

  Call(emitted_function, arguments);

  return Status::OK();
}

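// Tries to lower an atomic update with `computation` to a single hardware
// atomic instruction: an unordered atomic store for a pass-through
// computation, or an atomicrmw for add/max/min on supported types. Returns
// true on success; returns false when the caller must fall back to the
// compare-and-swap loop in EmitAtomicOperationUsingCAS.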
bool IrEmitter::MaybeEmitDirectAtomicOperation(
    const HloComputation& computation, llvm::Value* output_address,
    llvm::Value* source_address) {
  CHECK_EQ(2, computation.num_parameters());

  HloOpcode root_opcode = computation.root_instruction()->opcode();
  PrimitiveType element_type =
      computation.root_instruction()->shape().element_type();
  bool is_atomic_integral = element_type == S32 || element_type == U32 ||
                            element_type == S64 || element_type == U64;
  llvm::Value* source = Load(source_address, "source");

  // Just passing along RHS -> atomic store.
  if (computation.instruction_count() == 2 &&
      root_opcode == HloOpcode::kParameter &&
      (element_type == F32 || is_atomic_integral) &&
      computation.root_instruction()->parameter_number() == 1) {
    llvm::StoreInst* store = Store(source, output_address);
    store->setAtomic(llvm::AtomicOrdering::Unordered);
    // Derive a minimum alignment from the type. The optimizer can increase it
    // later.
    store->setAlignment(
        llvm::Align(ShapeUtil::ByteSizeOfPrimitiveType(element_type)));
    return true;
  }

  if (computation.instruction_count() != 3) {
    // We special-case only computations with one computing instruction for
    // now. Such a computation has exactly three instructions, given that it
    // has two parameters.
    return false;
  }

  if (root_opcode == HloOpcode::kAdd) {
    llvm::Triple target_triple = llvm::Triple(module_->getTargetTriple());
    // NVPTX supports atomicAdd on F32 and integer types.
    if (target_triple.isNVPTX()) {
      // "atom.add.f64 requires sm_60 or higher."
      // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom
      bool f64_atomic_add_supported =
          ir_emitter_context_->cuda_compute_capability()->cc_major >= 6;

      bool atomic_add_supported =
          element_type == F32 ||
          (f64_atomic_add_supported && element_type == F64);
      if (atomic_add_supported) {
        AtomicRMW(llvm::AtomicRMWInst::FAdd, output_address, source,
                  llvm::AtomicOrdering::SequentiallyConsistent);
        return true;
      }
    }

    if (is_atomic_integral) {
      // integral + integral
      AtomicRMW(llvm::AtomicRMWInst::Add, output_address, source,
                llvm::AtomicOrdering::SequentiallyConsistent);
      return true;
    }
  }

  // NVPTX supports atomicMax and atomicMin only on integer types.
  if (root_opcode == HloOpcode::kMaximum && is_atomic_integral) {
    // max(integral, integral)
    auto opcode = primitive_util::IsSignedIntegralType(element_type)
                      ? llvm::AtomicRMWInst::Max
                      : llvm::AtomicRMWInst::UMax;
    AtomicRMW(opcode, output_address, source,
              llvm::AtomicOrdering::SequentiallyConsistent);
    return true;
  }

  if (root_opcode == HloOpcode::kMinimum && is_atomic_integral) {
    // min(integral, integral)
    auto opcode = primitive_util::IsSignedIntegralType(element_type)
                      ? llvm::AtomicRMWInst::Min
                      : llvm::AtomicRMWInst::UMin;
    AtomicRMW(opcode, output_address, source,
              llvm::AtomicOrdering::SequentiallyConsistent);
    return true;
  }

  return false;
}

// Implements atomic binary operations using atomic compare-and-swap
// (atomicCAS) as follows:
//   1. Reads the value from the memory pointed to by output_address and
//      records it as old_output.
//   2. Uses old_output as one of the source operands to perform the binary
//      operation and stores the result in new_output.
//   3. Calls atomicCAS which implements compare-and-swap as an atomic
//      operation. In particular, atomicCAS reads the value from the memory
//      pointed to by output_address, and compares the value with old_output.
//      If the two values are equal, new_output is written to the same memory
//      location and true is returned to indicate that the atomic operation
//      succeeded. Otherwise, the new value read from the memory is returned.
//      In this case, the new value is copied to old_output, and steps 2 and 3
//      are repeated until atomicCAS succeeds.
//
// On Nvidia GPUs, atomicCAS can only operate on 32 bit and 64 bit integers. If
// the element type of the binary operation is 32 bits or 64 bits, the integer
// type of the same size is used for the atomicCAS operation. On the other
// hand, if the element type is smaller than 32 bits, int32 is used for the
// atomicCAS operation. In this case, atomicCAS reads and writes 32 bit values
// from the memory, which is larger than the memory size required by the
// original atomic binary operation. We mask off the last two bits of
// output_address and use the result as an address to read the 32 bit values
// from the memory. This avoids out-of-bounds memory accesses as long as tensor
// buffers are 4-byte aligned and have a size of 4N, an assumption that the
// runtime can guarantee.
//
// The pseudo code is shown below. Variables *_address are pointers to a memory
// region with a size equal to the size of the atomicCAS operation, with the
// exception that new_output_address is a pointer to a memory region with a
// size equal to the element size of the binary operation.
//
//   element_size = sizeof(element_type);
//   atomic_size = max(32, element_size);
//   cas_new_output_address = alloca(atomic_size);
//   cas_old_output_address = alloca(atomic_size);
//   if (atomic_size != element_size) {
//     atomic_address = output_address & ((int64)(-4));
//     new_output_address = cas_new_output_address + (output_address & 3);
//   } else {
//     atomic_address = output_address;
//     new_output_address = cas_new_output_address;
//   }
//
//   *cas_old_output_address = *atomic_address;
//   do {
//     *cas_new_output_address = *cas_old_output_address;
//     *new_output_address = operation(*new_output_address, *source_address);
//     (*cas_old_output_address, success) =
//         atomicCAS(atomic_address, *cas_old_output_address,
//                   *cas_new_output_address);
//   } while (!success);
//
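// For example, for a 16-bit element stored at byte address 0x1006,
// atomic_address = 0x1006 & ~3 = 0x1004 and offset = 0x1006 & 3 = 2, so the
// atomicCAS loop operates on the aligned 32-bit word at 0x1004 while the
// binary operation reads and writes only the two bytes at offset 2 within
// cas_new_output_address.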
Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
                                              llvm::Value* output_address,
                                              llvm::Value* source_address) {
  llvm::PointerType* output_address_type =
      llvm::dyn_cast<llvm::PointerType>(output_address->getType());
  CHECK_NE(output_address_type, nullptr);

  // element_type is the data type for the binary operation.
  llvm::Type* element_type = output_address_type->getPointerElementType();
  int element_size = llvm_ir::GetSizeInBits(element_type);

  int atomic_size = (element_size < 32) ? 32 : element_size;
  llvm::Type* atomic_type = b_.getIntNTy(atomic_size);
  llvm::Type* atomic_address_type =
      atomic_type->getPointerTo(output_address_type->getPointerAddressSpace());

  // cas_old_output_address and cas_new_output_address point to the scratch
  // memory where we store the old and new values for the repeated atomicCAS
  // operations.
  llvm::Value* cas_old_output_address = llvm_ir::EmitAllocaAtFunctionEntry(
      atomic_type, "cas_old_output_address", &b_);
  llvm::Value* cas_new_output_address = llvm_ir::EmitAllocaAtFunctionEntry(
      atomic_type, "cas_new_output_address", &b_);

  // Emit preparation code to the preheader.
  llvm::BasicBlock* loop_preheader_bb = b_.GetInsertBlock();

  llvm::Value* atomic_memory_address;
  // binop_output_address points to the scratch memory that stores the
  // result of the binary operation.
  llvm::Value* binop_output_address;
  if (element_size < 32) {
    // Assume the element size is an integer number of bytes.
    CHECK_EQ(element_size % CHAR_BIT, 0);
    llvm::Type* address_int_type =
        module_->getDataLayout().getIntPtrType(output_address_type);
    atomic_memory_address = PtrToInt(output_address, address_int_type);
    llvm::Value* mask = llvm::ConstantInt::get(address_int_type, 3);
    llvm::Value* offset = And(atomic_memory_address, mask);
    mask = llvm::ConstantInt::get(address_int_type, -4);
    atomic_memory_address = And(atomic_memory_address, mask);
    atomic_memory_address =
        IntToPtr(atomic_memory_address, atomic_address_type);
    binop_output_address =
        Add(PtrToInt(cas_new_output_address, address_int_type), offset);
    binop_output_address = IntToPtr(
        binop_output_address,
        llvm::PointerType::get(
            element_type,
            cas_new_output_address->getType()->getPointerAddressSpace()));
  } else {
    atomic_memory_address = b_.CreatePointerBitCastOrAddrSpaceCast(
        output_address, atomic_address_type);
    binop_output_address = b_.CreatePointerBitCastOrAddrSpaceCast(
        cas_new_output_address,
        llvm::PointerType::get(
            element_type,
            cas_new_output_address->getType()->getPointerAddressSpace()));
  }

  // Use the value from the memory that atomicCAS operates on to initialize
  // cas_old_output.
  llvm::Value* cas_old_output = Load(atomic_memory_address, "cas_old_output");
  Store(cas_old_output, cas_old_output_address);

  llvm::BasicBlock* loop_exit_bb = loop_preheader_bb->splitBasicBlock(
      b_.GetInsertPoint(), "atomic_op_loop_exit");
  llvm::BasicBlock* loop_body_bb = llvm::BasicBlock::Create(
      b_.getContext(), "atomic_op_loop_body", b_.GetInsertBlock()->getParent());
  b_.SetInsertPoint(loop_body_bb);
  // Change the preheader's successor from loop_exit_bb to loop_body_bb.
  loop_preheader_bb->getTerminator()->setSuccessor(0, loop_body_bb);

  // Emit the body of the loop that repeatedly invokes atomicCAS.
  //
  // Use cas_old_output to initialize cas_new_output.
  cas_old_output = Load(cas_old_output_address, "cas_old_output");
  Store(cas_old_output, cas_new_output_address);
  // Emit code to calculate new_output = operation(old_output, source).
  TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
      computation, {binop_output_address, source_address},
      binop_output_address));

  llvm::Value* cas_new_output = Load(cas_new_output_address, "cas_new_output");

  // Emit code to perform the atomicCAS operation:
  //   (cas_old_output, success) = atomicCAS(memory_address, cas_old_output,
  //                                         cas_new_output);
  llvm::Value* ret_value =
      AtomicCmpXchg(atomic_memory_address, cas_old_output, cas_new_output,
                    llvm::AtomicOrdering::SequentiallyConsistent,
                    llvm::AtomicOrdering::SequentiallyConsistent);

  // Extract the memory value returned from atomicCAS and store it as
  // cas_old_output.
  Store(ExtractValue(ret_value, 0, "cas_old_output"), cas_old_output_address);
  // Extract the success bit returned from atomicCAS and generate a
  // conditional branch on the success bit.
  CondBr(ExtractValue(ret_value, 1, "success"), loop_exit_bb, loop_body_bb);

  // Set the insertion point to the exit basic block so that the caller of
  // this method can continue emitting code to the right place.
  SetToFirstInsertPoint(loop_exit_bb, &b_);
  return Status::OK();
}

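// Entry point for atomic updates: rejects non-binary computations, then tries
// the direct hardware atomics above and falls back to the atomicCAS loop when
// no direct lowering applies.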
Status IrEmitter::EmitAtomicOperationForNestedComputation(
    const HloComputation& computation, llvm::Value* output_address,
    llvm::Value* source_address) {
  if (computation.num_parameters() != 2) {
    // TODO(b/30258929): We only accept binary computations so far.
    return Unimplemented(
        "We only support atomic functions with exactly two parameters, but "
        "computation %s has %d.",
        computation.name(), computation.num_parameters());
  }

  if (MaybeEmitDirectAtomicOperation(computation, output_address,
                                     source_address)) {
    return Status::OK();
  }

  return EmitAtomicOperationUsingCAS(computation, output_address,
                                     source_address);
}

Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) {
  return InternalError(
      "Dynamic selection of tuples is not supported. Please file a bug against "
      "XLA/GPU if you need it");
}

namespace {
llvm::Value* Real(llvm::Value* x, llvm::IRBuilder<>* b) {
  return b->CreateExtractValue(x, {0});
}

llvm::Value* Imag(llvm::Value* x, llvm::IRBuilder<>* b) {
  return b->CreateExtractValue(x, {1});
}

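// Emits the real and imaginary parts of a complex product using the textbook
// identity (a + bi)(c + di) = (ac - bd) + (ad + bc)i.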
std::pair<llvm::Value*, llvm::Value*> MultiplyComplex(llvm::Value* lhs_value,
                                                      llvm::Value* rhs_value,
                                                      llvm::IRBuilder<>* b) {
  llvm::Value* lhs_real = Real(lhs_value, b);
  llvm::Value* lhs_imag = Imag(lhs_value, b);
  llvm::Value* rhs_real = Real(rhs_value, b);
  llvm::Value* rhs_imag = Imag(rhs_value, b);
  llvm::Value* real_result1 = b->CreateFMul(lhs_real, rhs_real);
  llvm::Value* real_result2 = b->CreateFMul(lhs_imag, rhs_imag);
  llvm::Value* real_result = b->CreateFSub(real_result1, real_result2);
  llvm::Value* imag_result1 = b->CreateFMul(lhs_real, rhs_imag);
  llvm::Value* imag_result2 = b->CreateFMul(lhs_imag, rhs_real);
  llvm::Value* imag_result = b->CreateFAdd(imag_result1, imag_result2);
  return {real_result, imag_result};
}
}  // namespace

Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
  if (ShapeUtil::IsZeroElementArray(convolution->shape())) {
    // Emit no code for an empty output.
    return Status::OK();
  }
  // TODO(b/31409998): Support convolution with dilation.
  return Unimplemented(
      "Hit a case for convolution that is not implemented on GPU.");
}

Status IrEmitter::HandleFft(HloInstruction* fft) {
  if (ShapeUtil::IsZeroElementArray(fft->shape())) {
    // Emit no code for an empty output.
    return Status::OK();
  }
  return Unimplemented("Hit a case for fft that is not implemented on GPU.");
}

Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
  return Unimplemented(
      "AllReduce cannot be nested inside of fusion, map, etc.");
}

Status IrEmitter::HandleParameter(HloInstruction* parameter) {
  return Status::OK();
}

Status IrEmitter::HandleFusion(HloInstruction* fusion) {
  // kFusion for library calls should be handled by
  // IrEmitterUnnested::HandleFusion.
  CHECK_EQ(HloInstruction::FusionKind::kLoop, fusion->fusion_kind());
  GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
                                          GetNestedComputer());
  FusedIrEmitter fused_emitter(&elemental_emitter);
  BindFusionArguments(fusion, &fused_emitter);
  TF_ASSIGN_OR_RETURN(auto generator, fused_emitter.GetGenerator(
                                          fusion->fused_expression_root()));
  return EmitTargetElementLoop(*fusion, generator);
}

Status IrEmitter::HandleCall(HloInstruction* call) {
  std::vector<llvm::Value*> operand_addresses;
  for (HloInstruction* operand : call->operands()) {
    operand_addresses.push_back(GetBasePointer(*operand));
  }
  return EmitCallToNestedComputation(*call->to_apply(), operand_addresses,
                                     GetBasePointer(*call));
}

Status IrEmitter::HandleCustomCall(HloInstruction*) {
  return Unimplemented("custom-call");
}

Status IrEmitter::HandleInfeed(HloInstruction*) {
  // TODO(b/30467474): Implement infeed on GPU.
  return Unimplemented("Infeed is not supported on GPU.");
}

Status IrEmitter::HandleOutfeed(HloInstruction*) {
  // TODO(b/34359662): Implement outfeed on GPU.
  return Unimplemented("Outfeed is not supported on GPU.");
}

Status IrEmitter::HandleBatchNormInference(HloInstruction*) {
  return Unimplemented(
      "The GPU backend does not implement BatchNormInference directly. It "
      "should be lowered before IR emission to HLO-soup using "
      "BatchNormRewriter or to a cudnn CustomCall using "
      "CudnnBatchNormRewriter.");
}

Status IrEmitter::HandleBatchNormTraining(HloInstruction*) {
  return Unimplemented(
      "The GPU backend does not implement BatchNormTraining directly. It "
      "should be lowered before IR emission to HLO-soup using "
      "BatchNormRewriter or to a cudnn CustomCall using "
      "CudnnBatchNormRewriter.");
}

Status IrEmitter::HandleBatchNormGrad(HloInstruction*) {
  return Unimplemented(
      "The GPU backend does not implement BatchNormGrad directly. It should "
      "be lowered before IR emission to HLO-soup (using BatchNormRewriter) or "
      "to a cudnn CustomCall using CudnnBatchNormRewriter.");
}

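// Evaluates `computation` on scalar inputs: each parameter element is spilled
// to a function-entry alloca, the nested computation is called on those
// buffers, and the resulting scalars (one per tuple element if the root is a
// tuple) are loaded back and returned.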
StatusOr<std::vector<llvm::Value*>> IrEmitter::ComputeNestedElement(
    const HloComputation& computation,
    absl::Span<llvm::Value* const> parameter_elements) {
  const Shape& return_shape = computation.root_instruction()->shape();
  llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
      llvm_ir::ShapeToIrType(return_shape, module_), "return_buffer", &b_);
  std::vector<llvm::Value*> parameter_buffers;
  for (llvm::Value* parameter_element : parameter_elements) {
    parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
        parameter_element->getType(), "parameter_buffer", &b_));
    Store(parameter_element, parameter_buffers.back());
  }

  std::vector<llvm::Value*> allocas_for_returned_scalars;
  if (!return_shape.IsTuple()) {
    allocas_for_returned_scalars.push_back(return_buffer);
  } else {
    allocas_for_returned_scalars =
        llvm_ir::EmitTupleAllocasAtFunctionEntry(return_shape, &b_);
    llvm_ir::IrArray tuple_array(return_buffer, return_shape);

    EmitTuple(tuple_array, allocas_for_returned_scalars, &b_);
  }

  TF_RETURN_IF_ERROR(EmitCallToNestedComputation(computation, parameter_buffers,
                                                 return_buffer));

  std::vector<llvm::Value*> returned_scalars;
  returned_scalars.reserve(allocas_for_returned_scalars.size());
  for (llvm::Value* addr : allocas_for_returned_scalars) {
    returned_scalars.push_back(Load(addr));
  }
  return returned_scalars;
}

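// Builds one IrArray per output of `hlo`: one per top-level tuple element for
// tuple-shaped instructions, otherwise a single IrArray for the whole result.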
std::vector<llvm_ir::IrArray> IrEmitter::ConstructIrArrayForOutputs(
    const HloInstruction& hlo) {
  std::vector<llvm_ir::IrArray> output_arrays;
  if (hlo.shape().IsTuple()) {
    int64 num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
    output_arrays.reserve(num_outputs);
    for (int64 i = 0; i < num_outputs; ++i) {
      output_arrays.push_back(GetIrArray(hlo, hlo, {i}));
    }
  } else {
    output_arrays.push_back(GetIrArray(hlo, hlo));
  }
  return output_arrays;
}

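// Binds each fused parameter to a generator that reads the corresponding
// fusion operand element by element, so FusedIrEmitter can emit the fused
// expression inline.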
void IrEmitter::BindFusionArguments(const HloInstruction* fusion,
                                    FusedIrEmitter* fused_emitter) {
  for (int i = 0; i < fusion->operand_count(); i++) {
    const HloInstruction* operand = fusion->operand(i);
    fused_emitter->BindGenerator(
        fusion->fused_parameter(i),
        [this, operand, fusion](llvm_ir::IrArray::Index index) {
          return GetIrArray(*operand, *fusion)
              .EmitReadArrayElement(index, &b_, operand->name());
        });
  }
}

}  // namespace gpu
}  // namespace xla