/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"

#include <string>
#include <unordered_map>
#include <utility>

#include "tensorflow/core/platform/logging.h"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/algorithm/container.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/errors.h"

// Convenience function that casts the provided llvm::Value* to the default
// address space using IRBuilder. This is useful in particular when generating
// IR for the AMDGPU target, whose kernel variables live in address space 5
// instead of the default address space.
static llvm::Value* AddrCastToDefault(llvm::Value* arg, llvm::IRBuilder<>& b) {
  llvm::Type* arg_type = arg->getType();
  CHECK(arg_type->isPointerTy());
  if (arg_type->getPointerAddressSpace() != 0) {
    llvm::Type* generic_arg_type =
        arg_type->getPointerElementType()->getPointerTo(0);
    llvm::Value* addrspacecast_arg =
        b.CreateAddrSpaceCast(arg, generic_arg_type);
    return addrspacecast_arg;
  }
  return arg;
}

namespace xla {

using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;

namespace gpu {

IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config,
                     IrEmitterContext* ir_emitter_context, bool is_nested)
    : ir_emitter_context_(ir_emitter_context),
      module_(ir_emitter_context->llvm_module()),
      b_(module_->getContext()),
      bindings_(&b_, module_, is_nested),
      hlo_module_config_(hlo_module_config) {}

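// Fallback for elementwise HLOs without a dedicated handler: emit a loop that
// computes each output element from the corresponding operand elements.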
Status IrEmitter::DefaultAction(HloInstruction* hlo) {
  ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
  for (const HloInstruction* operand : hlo->operands()) {
    operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
      return GetIrArray(*operand, *hlo)
          .EmitReadArrayElement(index, &b_, operand->name());
    };
  }
  return EmitTargetElementLoop(
      *hlo, GpuElementalIrEmitter(hlo_module_config_, module_, &b_,
                                  GetNestedComputer())
                .MakeElementGenerator(hlo, operand_to_generator));
}

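// Emits each kConstant in `computation` as a byte-array global in the LLVM
// module and records it in a GpuExecutable::ConstantInfo. Constants that
// ShouldEmitLiteralInLlvmIr declines get a zero-initialized global instead,
// and their raw bytes are attached to the ConstantInfo.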
Status IrEmitter::EmitConstants(const HloComputation& computation) {
  for (HloInstruction* instr : computation.instructions()) {
    if (instr->opcode() != HloOpcode::kConstant) {
      continue;
    }
    Literal& literal = *Cast<HloConstantInstruction>(instr)->mutable_literal();
    const bool should_emit_initializer = ShouldEmitLiteralInLlvmIr(literal);
    llvm::ArrayType* global_type =
        llvm::ArrayType::get(b_.getInt8Ty(), literal.size_bytes());
    llvm::Constant* initializer =
        should_emit_initializer
            ? llvm_ir::ConvertLiteralToIrConstant(literal, module_)
            : llvm::ConstantAggregateZero::get(global_type);
    if (should_emit_initializer) {
      VLOG(3) << "Emitted initializer for constant with shape "
              << ShapeUtil::HumanString(literal.shape());
    }

    // These globals will be looked up by name by GpuExecutable, so we need to
    // give them external linkage.  Not all of their uses are visible in the
    // LLVM IR (e.g. TupleThunk), so we can't give them a linkage that merely
    // preserves their names (like available_externally); we also need to
    // ensure that they stick around even if they're "unused".
    //
    // We may have to be more clever here in the future if we notice that we're
    // keeping around too many globals because of their linkage.
    std::string global_name = llvm_ir::ConstantHloToGlobalName(*instr);

    llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
        global_type, /*isConstant=*/should_emit_initializer,
        llvm::GlobalValue::ExternalLinkage,
        /*Initializer=*/initializer, global_name,
        /*TLMode=*/llvm::GlobalValue::NotThreadLocal,
        /*AddressSpace=*/0,
        /*isExternallyInitialized=*/false);
    global_for_const->setAlignment(llvm::Align(kConstantBufferAlignBytes));
    ir_emitter_context_->llvm_module()->getGlobalList().push_back(
        global_for_const);

    GpuExecutable::ConstantInfo info;
    info.symbol_name = global_name;

    if (!should_emit_initializer) {
      auto base = static_cast<const uint8*>(literal.untyped_data());
      info.content.assign(base, base + literal.size_bytes());
    }
    ir_emitter_context_->constants().push_back(std::move(info));
  }
  return Status::OK();
}

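// Constants are materialized up front as module-level globals by
// EmitConstants, so there is nothing to emit at the point of use.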
Status IrEmitter::HandleConstant(HloInstruction* constant) {
  return Status::OK();
}

Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) {
  VLOG(2) << "HandleAddDependency: " << add_dependency->ToString();
  const HloInstruction* operand = add_dependency->operand(0);
  // AddDependency is a no-op, but we still want to bind it to an llvm::Value
  // sometimes, e.g. when its operand is a constant or a bitcast of a constant.
  if (bindings_.BoundToIrValue(*operand)) {
    bindings_.BindHloToIrValue(*add_dependency, GetBasePointer(*operand));
  }
  return Status::OK();
}

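// A tuple is stored as an array of pointers, so a GetTupleElement only loads
// the element's base pointer from the tuple buffer; no element data is copied.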
Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
  auto operand = get_tuple_element->operand(0);
  CHECK(bindings_.BoundToIrValue(*operand));
  bindings_.BindHloToIrValue(
      *get_tuple_element,
      llvm_ir::EmitGetTupleElement(
          get_tuple_element->shape(), get_tuple_element->tuple_index(),
          // TODO(b/26344050): tighten the alignment here
          // based on the real element type.
          /*alignment=*/1, GetBasePointer(*operand), &b_));
  return Status::OK();
}

Status IrEmitter::HandleSend(HloInstruction*) {
  return Unimplemented("Send is not implemented on GPU");
}

Status IrEmitter::HandleSendDone(HloInstruction*) {
  return Unimplemented("Send-Done is not implemented on GPU");
}

Status IrEmitter::HandleRecv(HloInstruction*) {
  return Unimplemented("Recv is not implemented on GPU");
}

Status IrEmitter::HandleRecvDone(HloInstruction*) {
  return Unimplemented("Recv-done is not implemented on GPU");
}

Status IrEmitter::HandleScatter(HloInstruction*) {
  return Unimplemented("Scatter is not implemented on GPUs.");
}

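// A tuple output is emitted as an array of pointers: one pointer per operand,
// each pointing at that operand's existing buffer.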
Status IrEmitter::HandleTuple(HloInstruction* tuple) {
  std::vector<llvm::Value*> base_ptrs;
  for (const HloInstruction* operand : tuple->operands()) {
    base_ptrs.push_back(GetBasePointer(*operand));
  }
  llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &b_);
  return Status::OK();
}

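// Emits (and memoizes) an LLVM function for `nested_computation`, then calls
// it with the operand pointers followed by the output pointer.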
Status IrEmitter::EmitCallToNestedComputation(
    const HloComputation& nested_computation,
    absl::Span<llvm::Value* const> operands, llvm::Value* output) {
  TF_RET_CHECK(nested_computation.num_parameters() > 0);
  llvm::Function*& emitted_function =
      computation_to_ir_function_[&nested_computation];
  if (emitted_function == nullptr) {
    TF_ASSIGN_OR_RETURN(
        auto ir_emitter_nested,
        IrEmitterNested::Create(hlo_module_config_, nested_computation,
                                ir_emitter_context_));
    TF_RETURN_IF_ERROR(ir_emitter_nested->CodegenNestedComputation());
    emitted_function = ir_emitter_nested->GetEmittedFunction();
  }

  // Operands are already in the default address space for non-AMDGPU targets.
  // For the AMDGPU target, however, alloca variables live in address space 5
  // and need to be addrspacecast to address space 0.
  std::vector<llvm::Value*> arguments;
  absl::c_transform(
      operands, std::back_inserter(arguments),
      [this](llvm::Value* arg) { return AddrCastToDefault(arg, b_); });

  llvm::Value* casted_output = AddrCastToDefault(output, b_);
  arguments.push_back(casted_output);

  Call(emitted_function, arguments);

  return Status::OK();
}

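// Tries to lower the two-parameter `computation` to a single hardware atomic
// on `output_address` (an atomic store, atomicAdd, atomicMax or atomicMin,
// depending on the root opcode and element type) and returns true on success.
// For example, a computation `(p0, p1) -> add(p0, p1)` over S32 lowers to a
// single atomicrmw add. If this returns false, the caller falls back to the
// atomicCAS loop below.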
bool IrEmitter::MaybeEmitDirectAtomicOperation(
    const HloComputation& computation, llvm::Value* output_address,
    llvm::Value* source_address) {
  CHECK_EQ(2, computation.num_parameters());

  HloOpcode root_opcode = computation.root_instruction()->opcode();
  PrimitiveType element_type =
      computation.root_instruction()->shape().element_type();
  bool is_atomic_integral = element_type == S32 || element_type == U32 ||
                            element_type == S64 || element_type == U64;
  llvm::Value* source = Load(source_address, "source");

  // Just passing along RHS -> atomic store.
  if (computation.instruction_count() == 2 &&
      root_opcode == HloOpcode::kParameter &&
      (element_type == F32 || is_atomic_integral) &&
      computation.root_instruction()->parameter_number() == 1) {
    llvm::StoreInst* store = Store(source, output_address);
    store->setAtomic(llvm::AtomicOrdering::Unordered);
    // Derive a minimum alignment from the type. The optimizer can increase it
    // later.
    store->setAlignment(
        llvm::Align(ShapeUtil::ByteSizeOfPrimitiveType(element_type)));
    return true;
  }

  if (computation.instruction_count() != 3) {
    // We special-case only computations with a single computing instruction
    // for now. Such a computation has exactly three instructions, given that
    // it has two parameters.
    return false;
  }

  if (root_opcode == HloOpcode::kAdd) {
    llvm::Triple target_triple = llvm::Triple(module_->getTargetTriple());
    // NVPTX supports atomicAdd on F32 and integer types.
    if (target_triple.isNVPTX()) {
      // "atom.add.f64 requires sm_60 or higher."
      // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom
      bool f64_atomic_add_supported =
          ir_emitter_context_->cuda_compute_capability().IsAtLeast(6);
      bool atomic_add_supported =
          element_type == F32 ||
          (f64_atomic_add_supported && element_type == F64);
      if (atomic_add_supported) {
        AtomicRMW(llvm::AtomicRMWInst::FAdd, output_address, source,
                  llvm::MaybeAlign(),
                  llvm::AtomicOrdering::SequentiallyConsistent);
        return true;
      }
    }

    if (is_atomic_integral) {
      // integral + integral
      AtomicRMW(llvm::AtomicRMWInst::Add, output_address, source,
                llvm::MaybeAlign(),
                llvm::AtomicOrdering::SequentiallyConsistent);
      return true;
    }
  }

  // NVPTX supports atomicMax and atomicMin only on integer types.
  if (root_opcode == HloOpcode::kMaximum && is_atomic_integral) {
    // max(integral, integral)
    auto opcode = primitive_util::IsSignedIntegralType(element_type)
                      ? llvm::AtomicRMWInst::Max
                      : llvm::AtomicRMWInst::UMax;
    AtomicRMW(opcode, output_address, source, llvm::MaybeAlign(),
              llvm::AtomicOrdering::SequentiallyConsistent);
    return true;
  }

  if (root_opcode == HloOpcode::kMinimum && is_atomic_integral) {
    // min(integral, integral)
    auto opcode = primitive_util::IsSignedIntegralType(element_type)
                      ? llvm::AtomicRMWInst::Min
                      : llvm::AtomicRMWInst::UMin;
    AtomicRMW(opcode, output_address, source, llvm::MaybeAlign(),
              llvm::AtomicOrdering::SequentiallyConsistent);
    return true;
  }

  return false;
}

// Implements atomic binary operations using atomic compare-and-swap
// (atomicCAS) as follows:
//   1. Reads the value from the memory pointed to by output_address and
//     records it as old_output.
//   2. Uses old_output as one of the source operands to perform the binary
//     operation and stores the result in new_output.
//   3. Calls atomicCAS which implements compare-and-swap as an atomic
//     operation. In particular, atomicCAS reads the value from the memory
//     pointed to by output_address, and compares the value with old_output. If
//     the two values are equal, new_output is written to the same memory
//     location and true is returned to indicate that the atomic operation
//     succeeded. Otherwise, the new value read from the memory is returned. In
//     this case, the new value is copied to old_output, and steps 2. and 3.
//     are repeated until atomicCAS succeeds.
//
// On Nvidia GPUs, atomicCAS can only operate on 32-bit and 64-bit integers. If
// the element type of the binary operation is 32 bits or 64 bits, the integer
// type of the same size is used for the atomicCAS operation. On the other
// hand, if the element type is smaller than 32 bits, int32 is used for the
// atomicCAS operation. In this case, atomicCAS reads and writes 32-bit values
// from the memory, which is larger than the memory size required by the
// original atomic binary operation. We mask off the last two bits of the
// output_address and use the result as an address to read the 32-bit values
// from the memory. This avoids out-of-bounds memory accesses as long as tensor
// buffers are 4-byte aligned and have a size of 4N, an assumption that the
// runtime can guarantee.
//
// The pseudo code is shown below. Variables *_address are pointers to a memory
// region with a size equal to the size of the atomicCAS operation, with the
// exception that new_output_address is a pointer to a memory region with a
// size equal to the element size of the binary operation.
//
//   element_size = sizeof(element_type);
//   atomic_size = max(32, element_size);
//   cas_new_output_address = alloca(atomic_size);
//   cas_old_output_address = alloca(atomic_size);
//   if (atomic_size != element_size) {
//     atomic_address = output_address & ((int64)(-4));
//     new_output_address = cas_new_output_address + (output_address & 3);
//   } else {
//     atomic_address = output_address;
//     new_output_address = cas_new_output_address;
//   }
//
//   *cas_old_output_address = *atomic_address;
//   do {
//     *cas_new_output_address = *cas_old_output_address;
//     *new_output_address = operation(*new_output_address, *source_address);
//     (*cas_old_output_address, success) =
//       atomicCAS(atomic_address, *cas_old_output_address,
//       *cas_new_output_address);
//   } while (!success);
//
Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
                                              llvm::Value* output_address,
                                              llvm::Value* source_address) {
  llvm::PointerType* output_address_type =
      llvm::dyn_cast<llvm::PointerType>(output_address->getType());
  CHECK_NE(output_address_type, nullptr);

  // element_type is the data type for the binary operation.
  llvm::Type* element_type = output_address_type->getPointerElementType();
  int element_size = llvm_ir::GetSizeInBits(element_type);

  int atomic_size = (element_size < 32) ? 32 : element_size;
  llvm::Type* atomic_type = b_.getIntNTy(atomic_size);
  llvm::Type* atomic_address_type =
      atomic_type->getPointerTo(output_address_type->getPointerAddressSpace());

  // cas_old_output_address and cas_new_output_address point to the scratch
  // memory where we store the old and new values for the repeated atomicCAS
  // operations.
  llvm::Value* cas_old_output_address = llvm_ir::EmitAllocaAtFunctionEntry(
      atomic_type, "cas_old_output_address", &b_);
  llvm::Value* cas_new_output_address = llvm_ir::EmitAllocaAtFunctionEntry(
      atomic_type, "cas_new_output_address", &b_);

  // Emit preparation code to the preheader.
  llvm::BasicBlock* loop_preheader_bb = b_.GetInsertBlock();

  llvm::Value* atomic_memory_address;
  // binop_output_address points to the scratch memory that stores the
  // result of the binary operation.
  llvm::Value* binop_output_address;
  if (element_size < 32) {
    // Assume the element size is an integer number of bytes.
    CHECK_EQ((element_size % sizeof(char)), 0);
    llvm::Type* address_int_type =
        module_->getDataLayout().getIntPtrType(output_address_type);
    atomic_memory_address = PtrToInt(output_address, address_int_type);
    llvm::Value* mask = llvm::ConstantInt::get(address_int_type, 3);
    llvm::Value* offset = And(atomic_memory_address, mask);
    mask = llvm::ConstantInt::get(address_int_type, -4);
    atomic_memory_address = And(atomic_memory_address, mask);
    atomic_memory_address =
        IntToPtr(atomic_memory_address, atomic_address_type);
    binop_output_address =
        Add(PtrToInt(cas_new_output_address, address_int_type), offset);
    binop_output_address = IntToPtr(
        binop_output_address,
        llvm::PointerType::get(
            element_type,
            cas_new_output_address->getType()->getPointerAddressSpace()));
  } else {
    atomic_memory_address = b_.CreatePointerBitCastOrAddrSpaceCast(
        output_address, atomic_address_type);
    binop_output_address = b_.CreatePointerBitCastOrAddrSpaceCast(
        cas_new_output_address,
        llvm::PointerType::get(
            element_type,
            cas_new_output_address->getType()->getPointerAddressSpace()));
  }

  // Use the value from the memory that atomicCAS operates on to initialize
  // cas_old_output.
  llvm::Value* cas_old_output = Load(atomic_memory_address, "cas_old_output");
  Store(cas_old_output, cas_old_output_address);

  llvm::BasicBlock* loop_exit_bb = loop_preheader_bb->splitBasicBlock(
      b_.GetInsertPoint(), "atomic_op_loop_exit");
  llvm::BasicBlock* loop_body_bb = llvm::BasicBlock::Create(
      b_.getContext(), "atomic_op_loop_body", b_.GetInsertBlock()->getParent());
  b_.SetInsertPoint(loop_body_bb);
  // Change preheader's successor from loop_exit_bb to loop_body_bb.
  loop_preheader_bb->getTerminator()->setSuccessor(0, loop_body_bb);

  // Emit the body of the loop that repeatedly invokes atomicCAS.
  //
  // Use cas_old_output to initialize cas_new_output.
  cas_old_output = Load(cas_old_output_address, "cas_old_output");
  Store(cas_old_output, cas_new_output_address);
  // Emits code to calculate new_output = operation(old_output, source);
  TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
      computation, {binop_output_address, source_address},
      binop_output_address));

  llvm::Value* cas_new_output = Load(cas_new_output_address, "cas_new_output");

  // If cas_new_output == cas_old_output, we're not asking for anything to
  // change, so we're done here!
  llvm::Value* old_eq_new = ICmpEQ(cas_old_output, cas_new_output);
  llvm::BasicBlock* loop_cas_bb = llvm::BasicBlock::Create(
      b_.getContext(), "atomic_op_loop_cas", b_.GetInsertBlock()->getParent());
  CondBr(old_eq_new, loop_exit_bb, loop_cas_bb);
  b_.SetInsertPoint(loop_cas_bb);

  // Emit code to perform the atomicCAS operation
  // (cas_old_output, success) = atomicCAS(memory_address, cas_old_output,
  //                                       cas_new_output);
  llvm::Value* ret_value = AtomicCmpXchg(
      atomic_memory_address, cas_old_output, cas_new_output, llvm::MaybeAlign(),
      llvm::AtomicOrdering::SequentiallyConsistent,
      llvm::AtomicOrdering::SequentiallyConsistent);

  // Extract the memory value returned from atomicCAS and store it as
  // cas_old_output.
  Store(ExtractValue(ret_value, 0, "cas_old_output"), cas_old_output_address);
  // Extract the success bit returned from atomicCAS and generate a
  // conditional branch on the success bit.
  CondBr(ExtractValue(ret_value, 1, "success"), loop_exit_bb, loop_body_bb);

  // Set the insertion point to the exit basic block so that the caller of
  // this method can continue emitting code to the right place.
  SetToFirstInsertPoint(loop_exit_bb, &b_);
  return Status::OK();
}

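// Emits an atomic application of the binary `computation` to *output_address,
// with *source_address as the other operand: first try a single hardware
// atomic via MaybeEmitDirectAtomicOperation, otherwise fall back to the
// compare-and-swap loop emitted by EmitAtomicOperationUsingCAS.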
Status IrEmitter::EmitAtomicOperationForNestedComputation(
    const HloComputation& computation, llvm::Value* output_address,
    llvm::Value* source_address) {
  if (computation.num_parameters() != 2) {
    // TODO(b/30258929): We only accept binary computations so far.
    return Unimplemented(
        "We only support atomic functions with exactly two parameters, but "
        "computation %s has %d.",
        computation.name(), computation.num_parameters());
  }

  if (MaybeEmitDirectAtomicOperation(computation, output_address,
                                     source_address)) {
    return Status::OK();
  }

  return EmitAtomicOperationUsingCAS(computation, output_address,
                                     source_address);
}

Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) {
  return InternalError(
      "Dynamic selection of tuples is not supported. Please file a bug against "
      "XLA/GPU if you need it");
}

namespace {
llvm::Value* Real(llvm::Value* x, llvm::IRBuilder<>* b) {
  return b->CreateExtractValue(x, {0});
}

llvm::Value* Imag(llvm::Value* x, llvm::IRBuilder<>* b) {
  return b->CreateExtractValue(x, {1});
}

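// Multiplies two complex values held as {real, imag} structs using the
// expansion (a + bi)(c + di) = (ac - bd) + (ad + bc)i.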
std::pair<llvm::Value*, llvm::Value*> MultiplyComplex(llvm::Value* lhs_value,
                                                      llvm::Value* rhs_value,
                                                      llvm::IRBuilder<>* b) {
  llvm::Value* lhs_real = Real(lhs_value, b);
  llvm::Value* lhs_imag = Imag(lhs_value, b);
  llvm::Value* rhs_real = Real(rhs_value, b);
  llvm::Value* rhs_imag = Imag(rhs_value, b);
  llvm::Value* real_result1 = b->CreateFMul(lhs_real, rhs_real);
  llvm::Value* real_result2 = b->CreateFMul(lhs_imag, rhs_imag);
  llvm::Value* real_result = b->CreateFSub(real_result1, real_result2);
  llvm::Value* imag_result1 = b->CreateFMul(lhs_real, rhs_imag);
  llvm::Value* imag_result2 = b->CreateFMul(lhs_imag, rhs_real);
  llvm::Value* imag_result = b->CreateFAdd(imag_result1, imag_result2);
  return {real_result, imag_result};
}
}  // namespace

Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
  if (ShapeUtil::IsZeroElementArray(convolution->shape())) {
    // Emit no code for an empty output.
    return Status::OK();
  }
  // TODO(b/31409998): Support convolution with dilation.
  return Unimplemented(
      "Hit a case for convolution that is not implemented on GPU.");
}

Status IrEmitter::HandleFft(HloInstruction* fft) {
  if (ShapeUtil::IsZeroElementArray(fft->shape())) {
    // Emit no code for an empty output.
    return Status::OK();
  }
  return Unimplemented("Hit a case for fft that is not implemented on GPU.");
}

Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
  return Unimplemented(
      "AllReduce cannot be nested inside of fusion, map, etc.");
}

Status IrEmitter::HandleParameter(HloInstruction* parameter) {
  return Status::OK();
}

Status IrEmitter::HandleFusion(HloInstruction* fusion) {
  // kFusion for library calls should be handled by
  // IrEmitterUnnested::HandleFusion.
  CHECK_EQ(HloInstruction::FusionKind::kLoop, fusion->fusion_kind());
  GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
                                          GetNestedComputer());
  FusedIrEmitter fused_emitter(&elemental_emitter);
  BindFusionArguments(fusion, &fused_emitter);
  TF_ASSIGN_OR_RETURN(auto generator, fused_emitter.GetGenerator(
                                          fusion->fused_expression_root()));
  return EmitTargetElementLoop(*fusion, generator);
}

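// A kCall is lowered to a call to the LLVM function emitted for the callee
// computation, passing the operands' buffers followed by the output buffer.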
Status IrEmitter::HandleCall(HloInstruction* call) {
  std::vector<llvm::Value*> operand_addresses;
  for (HloInstruction* operand : call->operands()) {
    operand_addresses.push_back(GetBasePointer(*operand));
  }
  return EmitCallToNestedComputation(*call->to_apply(), operand_addresses,
                                     GetBasePointer(*call));
}

Status IrEmitter::HandleCustomCall(HloInstruction*) {
  return Unimplemented("custom-call");
}

Status IrEmitter::HandleInfeed(HloInstruction*) {
  // TODO(b/30467474): Implement infeed on GPU.
  return Unimplemented("Infeed is not supported on GPU.");
}

Status IrEmitter::HandleOutfeed(HloInstruction*) {
  // TODO(b/34359662): Implement outfeed on GPU.
  return Unimplemented("Outfeed is not supported on GPU.");
}

Status IrEmitter::HandleBatchNormInference(HloInstruction*) {
  return Unimplemented(
      "The GPU backend does not implement BatchNormInference directly.  It "
      "should be lowered before IR emission to HLO-soup using "
      "BatchNormRewriter or to a cudnn CustomCall using "
      "CudnnBatchNormRewriter.");
}

Status IrEmitter::HandleBatchNormTraining(HloInstruction*) {
  return Unimplemented(
      "The GPU backend does not implement BatchNormTraining directly.  It "
      "should be lowered before IR emission to HLO-soup using "
      "BatchNormRewriter or to a cudnn CustomCall using "
      "CudnnBatchNormRewriter.");
}

Status IrEmitter::HandleBatchNormGrad(HloInstruction*) {
  return Unimplemented(
      "The GPU backend does not implement BatchNormGrad directly.  It should "
      "be lowered before IR emission to HLO-soup (using BatchNormRewriter) or "
      "to a cudnn CustomCall using CudnnBatchNormRewriter.");
}

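// Evaluates `computation` on the scalar `parameter_elements` by spilling each
// element to a stack slot, calling the emitted nested computation on those
// slots, and loading the result(s) back; a tuple-shaped result yields one
// value per top-level tuple element.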
StatusOr<std::vector<llvm::Value*>> IrEmitter::ComputeNestedElement(
    const HloComputation& computation,
    absl::Span<llvm::Value* const> parameter_elements) {
  const Shape& return_shape = computation.root_instruction()->shape();
  llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
      llvm_ir::ShapeToIrType(return_shape, module_), "return_buffer", &b_);
  std::vector<llvm::Value*> parameter_buffers;
  for (llvm::Value* parameter_element : parameter_elements) {
    parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
        parameter_element->getType(), "parameter_buffer", &b_));
    Store(parameter_element, parameter_buffers.back());
  }

  std::vector<llvm::Value*> allocas_for_returned_scalars;
  if (!return_shape.IsTuple()) {
    allocas_for_returned_scalars.push_back(return_buffer);
  } else {
    allocas_for_returned_scalars =
        llvm_ir::EmitTupleAllocasAtFunctionEntry(return_shape, &b_);
    llvm_ir::IrArray tuple_array(return_buffer, return_shape);

    EmitTuple(tuple_array, allocas_for_returned_scalars, &b_);
  }

  TF_RETURN_IF_ERROR(EmitCallToNestedComputation(computation, parameter_buffers,
                                                 return_buffer));

  std::vector<llvm::Value*> returned_scalars;
  returned_scalars.reserve(allocas_for_returned_scalars.size());
  for (llvm::Value* addr : allocas_for_returned_scalars) {
    returned_scalars.push_back(Load(addr));
  }
  return returned_scalars;
}

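// Returns one IrArray per output buffer of `hlo`: a single array for
// non-tuple shapes, or one per top-level tuple element otherwise.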
std::vector<llvm_ir::IrArray> IrEmitter::ConstructIrArrayForOutputs(
    const HloInstruction& hlo) {
  std::vector<llvm_ir::IrArray> output_arrays;
  if (hlo.shape().IsTuple()) {
    int64_t num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
    output_arrays.reserve(num_outputs);
    for (int64_t i = 0; i < num_outputs; ++i) {
      output_arrays.push_back(GetIrArray(hlo, hlo, {i}));
    }
  } else {
    output_arrays.push_back(GetIrArray(hlo, hlo));
  }
  return output_arrays;
}

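// Binds each fused parameter to a generator that reads the corresponding
// element of the matching fusion operand, so FusedIrEmitter can emit the
// fused expression inline.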
void IrEmitter::BindFusionArguments(const HloInstruction* fusion,
                                    FusedIrEmitter* fused_emitter) {
  for (int i = 0; i < fusion->operand_count(); i++) {
    const HloInstruction* operand = fusion->operand(i);
    fused_emitter->BindGenerator(
        fusion->fused_parameter(i),
        [this, operand, fusion](llvm_ir::IrArray::Index index) {
          return GetIrArray(*operand, *fusion)
              .EmitReadArrayElement(index, &b_, operand->name());
        });
  }
}

}  // namespace gpu
}  // namespace xla