/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_

#include <functional>
#include <map>
#include <memory>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace gpu {

// Abstract base class for translating HLO graphs to LLVM IR for a GPU.
//
// There are two concrete subclasses of IrEmitter: IrEmitterNested and
// IrEmitterUnnested. In the unnested variety, each HLO gets its own kernel
// function, whereas in the nested version the whole computation is emitted as
// one *non-kernel* function.
//
// In XLA, kernel functions never call other kernel functions. This means that
// if we have a kernel -- e.g. implementing a kReduce HLO -- that wants to use
// an HLO computation as a "subroutine" -- e.g. the HLO computation that
// specifies how to reduce two elements -- then the subroutine computation must
// be emitted using IrEmitterNested.
//
// Fusion nodes are a special case. A fusion node is emitted using
// IrEmitterUnnested, but the code is generated using FusedIrEmitter, which is
// not a subclass of gpu::IrEmitter, and in fact is better understood as an IR
// generator generator. See comments on that class.
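//
// To illustrate the kernel/subroutine split, a rough sketch of the resulting
// IR shape (illustrative only -- the names and signatures below are made up,
// not what XLA actually emits):
//
//   ; Emitted by IrEmitterNested: an ordinary device function.
//   define void @add_F32(float* %lhs, float* %rhs, float* %out) { ... }
//
//   ; Emitted by IrEmitterUnnested: a kernel that uses the nested
//   ; computation as a subroutine rather than launching another kernel.
//   define void @reduce_kernel(...) {
//     ...
//     call void @add_F32(float* %a, float* %b, float* %acc)
//     ...
//   }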
class IrEmitter : public DfsHloVisitorWithDefault,
                  public IrBuilderMixin<IrEmitter> {
 public:
  using GeneratorForOperandIrArrays =
      std::function<std::vector<llvm_ir::IrArray>()>;

  IrEmitter(const IrEmitter&) = delete;
  IrEmitter& operator=(const IrEmitter&) = delete;

  Status DefaultAction(HloInstruction* hlo) override;
  Status HandleConstant(HloInstruction* constant) override;
  Status HandleBitcast(HloInstruction* bitcast) override;
  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
  Status HandleDot(HloInstruction* dot) override;
  Status HandleConvolution(HloInstruction* convolution) override;
  Status HandleFft(HloInstruction* fft) override;
  Status HandleAllReduce(HloInstruction* crs) override;
  Status HandleInfeed(HloInstruction* infeed) override;
  Status HandleOutfeed(HloInstruction* outfeed) override;
  Status HandleSend(HloInstruction* send) override;
  Status HandleSendDone(HloInstruction* send_done) override;
  Status HandleRecv(HloInstruction* recv) override;
  Status HandleRecvDone(HloInstruction* recv_done) override;
  Status HandleParameter(HloInstruction* parameter) override;
  Status HandleReduce(HloInstruction* reduce) override;
  Status HandleTuple(HloInstruction* tuple) override;
  Status HandleScatter(HloInstruction* scatter) override;
  Status HandleSelect(HloInstruction* select) override;
  Status HandleTupleSelect(HloInstruction* tuple_select) override;
  Status HandleFusion(HloInstruction* fusion) override;
  Status HandleCall(HloInstruction* call) override;
  Status HandleCustomCall(HloInstruction* custom_call) override;
  Status HandleBatchNormInference(HloInstruction* batch_norm) override;
  Status HandleBatchNormTraining(HloInstruction* batch_norm) override;
  Status HandleBatchNormGrad(HloInstruction* batch_norm) override;
  Status HandleAddDependency(HloInstruction* add_dependency) override;

  Status FinishVisit(HloInstruction* root) override { return Status::OK(); }

  llvm::IRBuilder<>* builder() { return &b_; }

 protected:
  // Constructs an IrEmitter with the given IrEmitter context.
  // ir_emitter_context is owned by the caller and should outlive the IrEmitter
  // object.
  explicit IrEmitter(const HloModuleConfig& hlo_module_config,
                     IrEmitterContext* ir_emitter_context, bool is_nested);

  // Helper for calling HloToIrBindings::GetIrArray.
  //
  // Gets the IrArray which contains inst. This array has metadata that makes
  // it valid only within the IR that implements consumer. If you are
  // implementing an HLO and want to get its own output buffer, call
  // GetIrArray(hlo, hlo).
  llvm_ir::IrArray GetIrArray(const HloInstruction& inst,
                              const HloInstruction& consumer,
                              const ShapeIndex& shape_index = {}) {
    return bindings_.GetIrArray(inst, consumer, shape_index);
  }

  // A convenient helper for calling HloToIrBindings::GetBasePointer.
  llvm::Value* GetBasePointer(const HloInstruction& inst) const {
    return bindings_.GetBasePointer(inst);
  }

  // Generates the IrArray for each output of an hlo instruction and returns
  // a vector containing such IrArrays.
  std::vector<llvm_ir::IrArray> ConstructIrArrayForOutputs(
      const HloInstruction& hlo);
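  // Typical use of the accessors above from inside a handler (a sketch; here
  // `hlo` stands for the instruction currently being emitted):
  //
  //   llvm_ir::IrArray output = GetIrArray(*hlo, *hlo);
  //   llvm_ir::IrArray input = GetIrArray(*hlo->operand(0), *hlo);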
  // A convenient helper for calling BufferAssignment::GetUniqueSlice.
  BufferAllocation::Slice GetAllocationSlice(
      const HloInstruction& hlo, const ShapeIndex& index = {}) const {
    return ir_emitter_context_->buffer_assignment()
        .GetUniqueSlice(&hlo, index)
        .ConsumeValueOrDie();
  }

  // Emits a single-threaded or multi-threaded loop that computes every element
  // in the result of the given HLO instruction. This produces a series of
  // nested loops (e.g. one loop for each dimension of the `hlo`'s shape). The
  // body of the inner-most loop is provided by the body_emitter function.
  virtual Status EmitTargetElementLoop(
      const HloInstruction& hlo,
      const llvm_ir::ElementGenerator& body_emitter) = 0;

  // Emits a call in IR to the given nested computation with the given operands
  // and output. If no IR function has been previously emitted for the
  // computation, also emits such a function.
  Status EmitCallToNestedComputation(const HloComputation& nested_computation,
                                     absl::Span<llvm::Value* const> operands,
                                     llvm::Value* output);

  // Emits an atomic operation that implements `nested_computation` in the
  // sequentially consistent memory model. `output_address` and
  // `source_address` are the arguments of the nested computation. For example,
  // atomicAdd(output_address, *source_address).
  Status EmitAtomicOperationForNestedComputation(
      const HloComputation& nested_computation, llvm::Value* output_address,
      llvm::Value* source_address);

  GpuElementalIrEmitter::NestedComputer GetNestedComputer() {
    return std::bind(&IrEmitter::ComputeNestedElement, this,
                     std::placeholders::_1, std::placeholders::_2);
  }

  IrEmitterContext* ir_emitter_context_;
  llvm::Module* module_;

  // The following fields track the IR emission state. According to LLVM memory
  // management rules, their memory is owned by the module.
  llvm::IRBuilder<> b_;

  // Mapping from HLO to its underlying LLVM value.
  HloToIrBindings bindings_;

  // Hlo configuration data used during code generation.
  const HloModuleConfig& hlo_module_config_;

 protected:
  GeneratorForOperandIrArrays GetGeneratorForOperandIrArrays(
      HloInstruction* fusion) {
    return [=]() {
      std::vector<llvm_ir::IrArray> ir_arrays;
      ir_arrays.reserve(fusion->operand_count());
      absl::c_transform(fusion->operands(), std::back_inserter(ir_arrays),
                        [&](const HloInstruction* operand) {
                          return GetIrArray(*operand, *fusion);
                        });
      return ir_arrays;
    };
  }

 private:
  // A helper method for EmitAtomicOperationForNestedComputation. Certain
  // computations, such as floating-point addition and integer maximization,
  // can be simply implemented using an LLVM atomic instruction. If
  // `computation` is one of this kind, emits code to do that and returns true;
  // otherwise, returns false.
  bool MaybeEmitDirectAtomicOperation(const HloComputation& computation,
                                      llvm::Value* output_address,
                                      llvm::Value* source_address);

  // A helper method for EmitAtomicOperationForNestedComputation. It implements
  // binary atomic operations using atomicCAS, with special handling to support
  // small data types.
  Status EmitAtomicOperationUsingCAS(const HloComputation& computation,
                                     llvm::Value* output_address,
                                     llvm::Value* source_address);
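  // For reference, the code emitted by EmitAtomicOperationUsingCAS follows the
  // standard compare-and-swap retry loop, sketched here in CUDA-like
  // pseudocode (the special handling for sub-32-bit types is elided):
  //
  //   T old = *output_address;
  //   T assumed;
  //   do {
  //     assumed = old;
  //     T new_value = computation(assumed, *source_address);
  //     old = atomicCAS(output_address, assumed, new_value);
  //   } while (old != assumed);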
  // A helper method for HandleSort(). It adds the inner comparison loop where
  // we compare elements pointed to by `keys_index` and `compare_keys_index`.
  void EmitCompareLoop(int64 dimension_to_sort,
                       const llvm_ir::IrArray::Index& keys_index,
                       const llvm_ir::IrArray::Index& compare_keys_index,
                       const llvm_ir::IrArray& keys_array);

  StatusOr<llvm::Value*> ComputeNestedElement(
      const HloComputation& computation,
      absl::Span<llvm::Value* const> parameter_elements);

  // A helper method for EmitAtomicOperationForNestedComputation. Emits an IR
  // function that implements `nested_computation` on elements of type
  // `element_ir_type`, for use in the emitted atomic operation.
  StatusOr<llvm::Function*> EmitAtomicFunctionForNestedComputation(
      const HloComputation& nested_computation, llvm::Type* element_ir_type);

  // Map nested computations to emitted IR functions. This serves as a cache so
  // that IrEmitter does not emit multiple functions for the same
  // HloComputation.
  std::map<const HloComputation*, llvm::Function*> computation_to_ir_function_;
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_