/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_

#include <functional>
#include <map>
#include <memory>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace gpu {

// Abstract base class for translating HLO graphs to LLVM IR for a GPU.
//
// There are two concrete subclasses of IrEmitter: IrEmitterNested and
// IrEmitterUnnested. In the unnested variety, each HLO gets its own kernel
// function, whereas in the nested version the whole computation is emitted as
// one *non-kernel* function.
//
// In XLA, kernel functions never call other kernel functions. This means that
// if we have a kernel -- e.g. implementing a kReduce HLO -- that wants to use
// an HLO computation as a "subroutine" -- e.g. the HLO computation that
// specifies how to reduce two elements -- then the subroutine computation must
// be emitted using IrEmitterNested.
//
// Fusion nodes are a special case. A fusion node is emitted using
// IrEmitterUnnested, but the code is generated using FusedIrEmitter, which is
// not a subclass of gpu::IrEmitter, and in fact is better understood as an IR
// generator generator. See comments on that class.
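//
// As an illustrative sketch (the function names here are made up for this
// comment; real emitted names are mangled and the actual IR is more
// involved), a kReduce HLO might lower to something like:
//
//   ; Emitted by IrEmitterUnnested: one kernel per top-level HLO.
//   define void @reduce(...) {
//     ...
//     ; Emitted by IrEmitterNested: an ordinary (non-kernel) device function
//     ; implementing the reduction computation, called from inside the
//     ; kernel.
//     call void @add_computation(float* %lhs, float* %rhs, float* %output)
//     ...
//   }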
class IrEmitter : public DfsHloVisitorWithDefault,
                  public IrBuilderMixin<IrEmitter> {
 public:
  using GeneratorForOperandIrArrays =
      std::function<std::vector<llvm_ir::IrArray>()>;

  IrEmitter(const IrEmitter&) = delete;
  IrEmitter& operator=(const IrEmitter&) = delete;

  Status DefaultAction(HloInstruction* hlo) override;
  Status HandleConstant(HloInstruction* constant) override;
  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
  Status HandleConvolution(HloInstruction* convolution) override;
  Status HandleFft(HloInstruction* fft) override;
  Status HandleAllReduce(HloInstruction* crs) override;
  Status HandleInfeed(HloInstruction* infeed) override;
  Status HandleOutfeed(HloInstruction* outfeed) override;
  Status HandleSend(HloInstruction* send) override;
  Status HandleSendDone(HloInstruction* send_done) override;
  Status HandleRecv(HloInstruction* recv) override;
  Status HandleRecvDone(HloInstruction* recv_done) override;
  Status HandleParameter(HloInstruction* parameter) override;
  Status HandleTuple(HloInstruction* tuple) override;
  Status HandleScatter(HloInstruction* scatter) override;
  Status HandleTupleSelect(HloInstruction* tuple_select) override;
  Status HandleFusion(HloInstruction* fusion) override;
  Status HandleCall(HloInstruction* call) override;
  Status HandleCustomCall(HloInstruction* custom_call) override;
  Status HandleBatchNormInference(HloInstruction* batch_norm) override;
  Status HandleBatchNormTraining(HloInstruction* batch_norm) override;
  Status HandleBatchNormGrad(HloInstruction* batch_norm) override;
  Status HandleAddDependency(HloInstruction* add_dependency) override;

  Status FinishVisit(HloInstruction* root) override { return Status::OK(); }

  llvm::IRBuilder<>* builder() { return &b_; }

  // Emits constants to the generated LLVM IR, and also populates related
  // information in ir_emitter_context for large-constant initializations. If
  // `lookup_indices` is true, the allocation index associated with each
  // constant is also populated.
  Status EmitConstants(const HloComputation& computation, bool lookup_indices);

 protected:
  // Constructs an IrEmitter with the given IrEmitter context.
  // ir_emitter_context is owned by the caller and should outlive the IrEmitter
  // object.
  explicit IrEmitter(const HloModuleConfig& hlo_module_config,
                     IrEmitterContext* ir_emitter_context, bool is_nested);

  // Helper for calling HloToIrBindings::GetIrArray.
  //
  // Gets the IrArray which contains inst. This array has metadata that makes
  // it valid only within the IR that implements consumer. If you are
  // implementing an HLO and want to get its own output buffer, call
  // GetIrArray(hlo, hlo).
  llvm_ir::IrArray GetIrArray(const HloInstruction& inst,
                              const HloInstruction& consumer,
                              const ShapeIndex& shape_index = {}) {
    return bindings_.GetIrArray(inst, consumer, shape_index);
  }
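  // For example (a sketch; assume `hlo` is the instruction a subclass is
  // currently emitting):
  //
  //   llvm_ir::IrArray output_array = GetIrArray(*hlo, *hlo);
  //   llvm_ir::IrArray operand_array = GetIrArray(*hlo->operand(0), *hlo);
  //
  // operand_array's metadata is then valid only inside the IR emitted for
  // `hlo`.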
  // A convenient helper for calling HloToIrBindings::GetBasePointer.
  llvm::Value* GetBasePointer(const HloInstruction& inst,
                              ShapeIndexView shape_index = {}) const {
    return bindings_.GetBasePointer(inst, shape_index);
  }

  // Generates the IrArray for each output of an hlo instruction and returns
  // a vector containing such IrArrays.
  std::vector<llvm_ir::IrArray> ConstructIrArrayForOutputs(
      const HloInstruction& hlo);

  // Emits a single-threaded or multi-threaded loop that computes every element
  // in the result of the given HLO instruction. This produces a series of
  // nested loops (e.g. one for each dimension of the `hlo`'s shape). The body
  // of the inner-most loop is provided by the body_emitter function.
  virtual Status EmitTargetElementLoop(
      const HloInstruction& hlo,
      const llvm_ir::ElementGenerator& body_emitter) = 0;

  // Emits a call in IR to the given nested computation with the given operands
  // and output. If no IR function has been previously emitted for the
  // computation, also emits such a function.
  Status EmitCallToNestedComputation(const HloComputation& nested_computation,
                                     absl::Span<llvm::Value* const> operands,
                                     llvm::Value* output);

  // Emits an atomic operation that implements `nested_computation` in the
  // sequentially consistent memory model. `output_address` and
  // `source_address` are the arguments of the nested computation. For example,
  // atomicAdd(output_address, *source_address).
  Status EmitAtomicOperationForNestedComputation(
      const HloComputation& nested_computation, llvm::Value* output_address,
      llvm::Value* source_address);

  GpuElementalIrEmitter::NestedComputer GetNestedComputer() {
    return std::bind(&IrEmitter::ComputeNestedElement, this,
                     std::placeholders::_1, std::placeholders::_2);
  }

  IrEmitterContext* ir_emitter_context_;
  llvm::Module* module_;

  // The following fields track the IR emission state. According to LLVM memory
  // management rules, their memory is owned by the module.
  llvm::IRBuilder<> b_;

  // Mapping from HLO to its underlying LLVM value.
  HloToIrBindings bindings_;

  // Hlo configuration data used during code generation.
  const HloModuleConfig& hlo_module_config_;

  // Binds all argument IrArrays of `fusion` to `fused_emitter`.
  void BindFusionArguments(const HloInstruction* fusion,
                           FusedIrEmitter* fused_emitter);

 private:
  // A helper method for EmitAtomicOperationForNestedComputation. Certain
  // computations, such as floating-point addition and integer maximization,
  // can be simply implemented using an LLVM atomic instruction. If
  // "computation" is one of this kind, emits code to do that and returns true;
  // otherwise, returns false.
  bool MaybeEmitDirectAtomicOperation(const HloComputation& computation,
                                      llvm::Value* output_address,
                                      llvm::Value* source_address);

  // A helper method for EmitAtomicOperationForNestedComputation. It implements
  // binary atomic operations using atomicCAS with special handling to support
  // small data types.
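  //
  // Roughly, the emitted code follows the classic compare-and-swap loop (a
  // CUDA-flavored sketch of the pattern, not the exact IR; the special
  // handling for types narrower than 32 bits is elided):
  //
  //   T old = *output_address;
  //   T assumed;
  //   do {
  //     assumed = old;
  //     T desired = computation(assumed, *source_address);
  //     old = atomicCAS(output_address, assumed, desired);
  //   } while (old != assumed);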
  Status EmitAtomicOperationUsingCAS(const HloComputation& computation,
                                     llvm::Value* output_address,
                                     llvm::Value* source_address);

  // A helper method for HandleSort(). It adds the inner comparison loop where
  // we compare elements pointed to by 'keys_index' and 'compare_keys_index'.
  void EmitCompareLoop(int64 dimension_to_sort,
                       const llvm_ir::IrArray::Index& keys_index,
                       const llvm_ir::IrArray::Index& compare_keys_index,
                       const llvm_ir::IrArray& keys_array);

  StatusOr<std::vector<llvm::Value*>> ComputeNestedElement(
      const HloComputation& computation,
      absl::Span<llvm::Value* const> parameter_elements);

  // A helper method for EmitAtomicOperationForNestedComputation. Emits an IR
  // function that implements `nested_computation` on values of
  // `element_ir_type`, for use in the atomic operation.
  StatusOr<llvm::Function*> EmitAtomicFunctionForNestedComputation(
      const HloComputation& nested_computation, llvm::Type* element_ir_type);

  // Map nested computations to emitted IR functions. This serves as a cache so
  // that IrEmitter does not emit multiple functions for the same
  // HloComputation.
  std::map<const HloComputation*, llvm::Function*> computation_to_ir_function_;
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_