• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_
18 
19 #include <functional>
20 #include <map>
21 #include <memory>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/strings/string_view.h"
26 #include "absl/types/span.h"
27 #include "llvm/IR/Function.h"
28 #include "llvm/IR/IRBuilder.h"
29 #include "llvm/IR/Value.h"
30 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
31 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
32 #include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"
33 #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
34 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
35 #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
36 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
37 #include "tensorflow/compiler/xla/service/hlo_computation.h"
38 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
39 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
40 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
41 #include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
42 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
43 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
44 #include "tensorflow/compiler/xla/statusor.h"
45 #include "tensorflow/compiler/xla/types.h"
46 #include "tensorflow/compiler/xla/xla_data.pb.h"
47 #include "tensorflow/core/platform/types.h"
48 
49 namespace xla {
50 namespace gpu {
51 
52 // Abstract base class for translating HLO graphs to LLVM IR for a GPU.
53 //
54 // There are two concrete subclasses of IrEmitter: IrEmitterNested and
55 // IrEmitterUnnested.  In the unnested variety, each HLO gets its own kernel
56 // function, whereas in the nested version the whole computation is emitted as
57 // one *non-kernel* function.
58 //
59 // In XLA, kernel functions never call other kernel functions.  This means that
60 // if we have a kernel -- e.g. implementing a kReduce HLO -- that wants to use
61 // an HLO computation as a "subroutine" -- e.g. the HLO computation that
62 // specifies how to reduce two elements -- then the subroutine computation must
63 // be emitted using IrEmitterNested.
64 //
65 // Fusion nodes are a special case.  A fusion node is emitted using
66 // IrEmitterUnnested, but the code is generated using FusedIrEmitter, which is
67 // not a subclass of gpu::IrEmitter, and in fact is better understood as an IR
68 // generator generator.  See comments on that class.
69 class IrEmitter : public DfsHloVisitorWithDefault,
70                   public IrBuilderMixin<IrEmitter> {
71  public:
     // Type of a callable that takes no arguments and returns a vector of
     // IrArrays.
72   using GeneratorForOperandIrArrays =
73       std::function<std::vector<llvm_ir::IrArray>()>;
74 
     // IrEmitter is not copyable: copy construction and copy assignment are
     // deleted.
75   IrEmitter(const IrEmitter&) = delete;
76   IrEmitter& operator=(const IrEmitter&) = delete;
77 
     // Visitor handler overrides. They are only declared here; their definitions
     // are not part of this header.
78   Status DefaultAction(HloInstruction* hlo) override;
79   Status HandleConstant(HloInstruction* constant) override;
80   Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
81   Status HandleConvolution(HloInstruction* convolution) override;
82   Status HandleFft(HloInstruction* fft) override;
83   Status HandleAllReduce(HloInstruction* crs) override;
84   Status HandleInfeed(HloInstruction* infeed) override;
85   Status HandleOutfeed(HloInstruction* outfeed) override;
86   Status HandleSend(HloInstruction* send) override;
87   Status HandleSendDone(HloInstruction* send_done) override;
88   Status HandleRecv(HloInstruction* recv) override;
89   Status HandleRecvDone(HloInstruction* recv_done) override;
90   Status HandleParameter(HloInstruction* parameter) override;
91   Status HandleTuple(HloInstruction* tuple) override;
92   Status HandleScatter(HloInstruction* scatter) override;
93   Status HandleTupleSelect(HloInstruction* tuple_select) override;
94   Status HandleFusion(HloInstruction* fusion) override;
95   Status HandleCall(HloInstruction* call) override;
96   Status HandleCustomCall(HloInstruction* custom_call) override;
97   Status HandleBatchNormInference(HloInstruction* batch_norm) override;
98   Status HandleBatchNormTraining(HloInstruction* batch_norm) override;
99   Status HandleBatchNormGrad(HloInstruction* batch_norm) override;
100   Status HandleAddDependency(HloInstruction* add_dependency) override;
101 
      // No-op: ignores `root` and unconditionally returns Status::OK().
FinishVisit(HloInstruction * root)102   Status FinishVisit(HloInstruction* root) override { return Status::OK(); }
103 
      // Returns a pointer to the IRBuilder member `b_`. The pointer refers to a
      // member of this object, so it is only valid while this IrEmitter is alive.
builder()104   llvm::IRBuilder<>* builder() { return &b_; }
105 
106   // Emits constants to generated LLVM IR, and also populates related
107   // information to ir_emitter_context for large-constant initializations. If
108   // `lookup_indices` is true, the allocation index associated with the constant
109   // is also populated.
110   Status EmitConstants(const HloComputation& computation, bool lookup_indices);
111 
112  protected:
113   // Constructs an IrEmitter with the given IrEmitter context.
114   // ir_emitter_context is owned by the caller and should outlive the IrEmitter
115   // object.
116   explicit IrEmitter(const HloModuleConfig& hlo_module_config,
117                      IrEmitterContext* ir_emitter_context, bool is_nested);
118 
119   // Helper for calling HloToIrBindings::GetIrArray.
120   //
121   // Gets the IrArray which contains inst.  This array has metadata that makes
122   // it valid only within the IR that implements consumer.  If you are
123   // implementing an HLO and want to get its own output buffer, call
124   // GetIrArray(hlo, hlo).
      //
      // `shape_index` defaults to the empty index and is forwarded unchanged to
      // bindings_.
125   llvm_ir::IrArray GetIrArray(const HloInstruction& inst,
126                               const HloInstruction& consumer,
127                               const ShapeIndex& shape_index = {}) {
128     return bindings_.GetIrArray(inst, consumer, shape_index);
129   }
130   // A convenient helper for calling HloToIrBindings::GetBasePointer. Forwards
      // `inst` and `shape_index` (empty by default) to bindings_.
131   llvm::Value* GetBasePointer(const HloInstruction& inst,
132                               ShapeIndexView shape_index = {}) const {
133     return bindings_.GetBasePointer(inst, shape_index);
134   }
135 
136   // Generates the IrArray for each output of an hlo instruction and returns
137   // a vector containing such IrArrays.
138   std::vector<llvm_ir::IrArray> ConstructIrArrayForOutputs(
139       const HloInstruction& hlo);
140 
141   // Emit a singlethreaded or multithreaded loop that computes every element in
142   // the result of the given HLO instruction. This produces a series of nested
143   // loops (e.g. one for each dimension of the `hlo`'s shape). The body of the
144   // inner-most loop is provided by the body_emitter function.
      //
      // Pure virtual: every concrete subclass must provide the implementation.
145   virtual Status EmitTargetElementLoop(
146       const HloInstruction& hlo,
147       const llvm_ir::ElementGenerator& body_emitter) = 0;
148 
149   // Emits a call in IR to the given nested computation with the given operands
150   // and output. If no IR function has been previously emitted for the
151   // computation, also emits such a function.
152   Status EmitCallToNestedComputation(const HloComputation& nested_computation,
153                                      absl::Span<llvm::Value* const> operands,
154                                      llvm::Value* output);
155 
156   // Emits an atomic operation that implements `nested_computation` in the
157   // sequentially consistent memory model. `output_address` and `source_address`
158   // are the arguments of the nested computation. For example,
159   // atomicAdd(output_address, *source_address).
160   Status EmitAtomicOperationForNestedComputation(
161       const HloComputation& nested_computation, llvm::Value* output_address,
162       llvm::Value* source_address);
163 
      // Returns a NestedComputer that forwards its two arguments to
      // ComputeNestedElement on this emitter. The bound callable holds the raw
      // `this` pointer, so it must not be invoked after this IrEmitter is
      // destroyed.
GetNestedComputer()164   GpuElementalIrEmitter::NestedComputer GetNestedComputer() {
165     return std::bind(&IrEmitter::ComputeNestedElement, this,
166                      std::placeholders::_1, std::placeholders::_2);
167   }
168 
      // Raw pointers to the emitter context and the LLVM module.
      // NOTE(review): presumably both are non-owning, with ir_emitter_context_
      // being the caller-owned context passed to the constructor -- confirm
      // against the constructor definition.
169   IrEmitterContext* ir_emitter_context_;
170   llvm::Module* module_;
171 
172   // The following fields track the IR emission state. According to LLVM memory
173   // management rules, their memory is owned by the module.
174   llvm::IRBuilder<> b_;
175 
176   // Mapping from HLO to its underlying LLVM value.
177   HloToIrBindings bindings_;
178 
179   // Hlo configuration data used during code generation. Held by reference, so
      // the referenced HloModuleConfig must outlive this object.
180   const HloModuleConfig& hlo_module_config_;
181 
      // NOTE(review): this access specifier repeats the `protected:` that is
      // already in effect from the section above.
182  protected:
183   // Bind all argument IrArrays of `fusion` to `fused_emitter`.
184   void BindFusionArguments(const HloInstruction* fusion,
185                            FusedIrEmitter* fused_emitter);
186 
187  private:
188   // A helper method for EmitAtomicOperationForNestedComputation. Certain
189   // computations, such as floating-point addition and integer maximization, can
190   // be simply implemented using an LLVM atomic instruction. If "computation" is
191   // one of this kind, emits code to do that and returns true; otherwise,
192   // returns false.
193   bool MaybeEmitDirectAtomicOperation(const HloComputation& computation,
194                                       llvm::Value* output_address,
195                                       llvm::Value* source_address);
196 
197   // A helper method for EmitAtomicOperationForNestedComputation. It implements
198   // binary atomic operations using atomicCAS with special handling to support
199   // small data types.
200   Status EmitAtomicOperationUsingCAS(const HloComputation& computation,
201                                      llvm::Value* output_address,
202                                      llvm::Value* source_address);
203 
204   // A helper method for HandleSort(). It adds the inner comparison loop where
205   // we compare elements pointed to by 'keys_index' and 'compare_keys_index'.
      // NOTE(review): no HandleSort() is declared in this class; confirm which
      // caller this helper still serves.
206   void EmitCompareLoop(int64 dimension_to_sort,
207                        const llvm_ir::IrArray::Index& keys_index,
208                        const llvm_ir::IrArray::Index& compare_keys_index,
209                        const llvm_ir::IrArray& keys_array);
210 
      // Target of the callable returned by GetNestedComputer().
      // NOTE(review): presumably emits IR that evaluates `computation` on
      // `parameter_elements` and returns the resulting values -- confirm against
      // the definition.
211   StatusOr<std::vector<llvm::Value*>> ComputeNestedElement(
212       const HloComputation& computation,
213       absl::Span<llvm::Value* const> parameter_elements);
214 
215   // Produces an llvm::Function for `nested_computation`, given the LLVM type
216   // of the elements (`element_ir_type`); errors are reported via the StatusOr.
217   // NOTE(review): unlike EmitAtomicOperationForNestedComputation, this takes no
218   // output/source addresses; its exact role should be confirmed in the .cc.
219   StatusOr<llvm::Function*> EmitAtomicFunctionForNestedComputation(
220       const HloComputation& nested_computation, llvm::Type* element_ir_type);
221 
222   // Map nested computations to emitted IR functions. This serves as a cache so
223   // that IrEmitter does not emit multiple functions for the same
224   // HloComputation.
225   std::map<const HloComputation*, llvm::Function*> computation_to_ir_function_;
226 };
227 
228 }  // namespace gpu
229 }  // namespace xla
230 
231 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_
232