/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_

#include <functional>
#include <map>
#include <memory>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace gpu {

// Abstract base class for translating HLO graphs to LLVM IR for a GPU.
//
// There are two concrete subclasses of IrEmitter: IrEmitterNested and
// IrEmitterUnnested.  In the unnested variety, each HLO gets its own kernel
// function, whereas in the nested version the whole computation is emitted as
// one *non-kernel* function.
//
// In XLA, kernel functions never call other kernel functions.  This means that
// if we have a kernel -- e.g. implementing a kReduce HLO -- that wants to use
// an HLO computation as a "subroutine" -- e.g. the HLO computation that
// specifies how to reduce two elements -- then the subroutine computation must
// be emitted using IrEmitterNested.
//
// Fusion nodes are a special case.  A fusion node is emitted using
// IrEmitterUnnested, but the code is generated using FusedIrEmitter, which is
// not a subclass of gpu::IrEmitter, and in fact is better understood as an IR
// generator generator.  See comments on that class.
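//
// To illustrate the kernel/subroutine split, a rough sketch of the emitted
// LLVM IR (not its exact form) for a kReduce HLO whose reducer adds two
// values might look like:
//
//   define void @reduce_kernel(...) {     ; from IrEmitterUnnested
//     ...
//     call void @add_computation(...)     ; plain (non-kernel) function,
//     ...                                 ; from IrEmitterNested
//   }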
class IrEmitter : public DfsHloVisitorWithDefault,
                  public IrBuilderMixin<IrEmitter> {
 public:
  using GeneratorForOperandIrArrays =
      std::function<std::vector<llvm_ir::IrArray>()>;

  IrEmitter(const IrEmitter&) = delete;
  IrEmitter& operator=(const IrEmitter&) = delete;

  Status DefaultAction(HloInstruction* hlo) override;
  Status HandleConstant(HloInstruction* constant) override;
  Status HandleBitcast(HloInstruction* bitcast) override;
  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
  Status HandleDot(HloInstruction* dot) override;
  Status HandleConvolution(HloInstruction* convolution) override;
  Status HandleFft(HloInstruction* fft) override;
  Status HandleAllReduce(HloInstruction* crs) override;
  Status HandleInfeed(HloInstruction* infeed) override;
  Status HandleOutfeed(HloInstruction* outfeed) override;
  Status HandleSend(HloInstruction* send) override;
  Status HandleSendDone(HloInstruction* send_done) override;
  Status HandleRecv(HloInstruction* recv) override;
  Status HandleRecvDone(HloInstruction* recv_done) override;
  Status HandleParameter(HloInstruction* parameter) override;
  Status HandleReduce(HloInstruction* reduce) override;
  Status HandleTuple(HloInstruction* tuple) override;
  Status HandleScatter(HloInstruction* scatter) override;
  Status HandleSelect(HloInstruction* select) override;
  Status HandleTupleSelect(HloInstruction* tuple_select) override;
  Status HandleFusion(HloInstruction* fusion) override;
  Status HandleCall(HloInstruction* call) override;
  Status HandleCustomCall(HloInstruction* custom_call) override;
  Status HandleBatchNormInference(HloInstruction* batch_norm) override;
  Status HandleBatchNormTraining(HloInstruction* batch_norm) override;
  Status HandleBatchNormGrad(HloInstruction* batch_norm) override;
  Status HandleAddDependency(HloInstruction* add_dependency) override;

  Status FinishVisit(HloInstruction* root) override { return Status::OK(); }

  llvm::IRBuilder<>* builder() { return &b_; }

 protected:
  // Constructs an IrEmitter with the given IrEmitter context.
  // ir_emitter_context is owned by the caller and should outlive the IrEmitter
  // object.
  explicit IrEmitter(const HloModuleConfig& hlo_module_config,
                     IrEmitterContext* ir_emitter_context, bool is_nested);

  // Helper for calling HloToIrBindings::GetIrArray.
  //
  // Gets the IrArray which contains inst.  This array has metadata that makes
  // it valid only within the IR that implements consumer.  If you are
  // implementing an HLO and want to get its own output buffer, call
  // GetIrArray(hlo, hlo).
  llvm_ir::IrArray GetIrArray(const HloInstruction& inst,
                              const HloInstruction& consumer,
                              const ShapeIndex& shape_index = {}) {
    return bindings_.GetIrArray(inst, consumer, shape_index);
  }
  // A convenient helper for calling HloToIrBindings::GetBasePointer.
  llvm::Value* GetBasePointer(const HloInstruction& inst) const {
    return bindings_.GetBasePointer(inst);
  }

  // Generates the IrArray for each output of an hlo instruction and returns
  // a vector containing such IrArrays.
  std::vector<llvm_ir::IrArray> ConstructIrArrayForOutputs(
      const HloInstruction& hlo);

  // A convenient helper for calling BufferAssignment::GetUniqueSlice.
  BufferAllocation::Slice GetAllocationSlice(
      const HloInstruction& hlo, const ShapeIndex& index = {}) const {
    return ir_emitter_context_->buffer_assignment()
        .GetUniqueSlice(&hlo, index)
        .ConsumeValueOrDie();
  }

  // Emits a single-threaded or multi-threaded loop that computes every element
  // in the result of the given HLO instruction. This produces a series of
  // nested loops (e.g. one for each dimension of the `hlo`'s shape). The body
  // of the inner-most loop is provided by the body_emitter function.
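  //
  // Illustrative use from a subclass (a sketch under the assumption that
  // `hlo` is a binary elementwise op whose operands are already bound in
  // bindings_; not code from this file):
  //
  //   EmitTargetElementLoop(*hlo, [&](const llvm_ir::IrArray::Index& index)
  //                                    -> StatusOr<llvm::Value*> {
  //     llvm::Value* lhs = GetIrArray(*hlo->operand(0), *hlo)
  //                            .EmitReadArrayElement(index, &b_);
  //     llvm::Value* rhs = GetIrArray(*hlo->operand(1), *hlo)
  //                            .EmitReadArrayElement(index, &b_);
  //     return FAdd(lhs, rhs);  // FAdd comes from IrBuilderMixin.
  //   });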
  virtual Status EmitTargetElementLoop(
      const HloInstruction& hlo,
      const llvm_ir::ElementGenerator& body_emitter) = 0;

  // Emits a call in IR to the given nested computation with the given operands
  // and output. If no IR function has been previously emitted for the
  // computation, also emits such a function.
  Status EmitCallToNestedComputation(const HloComputation& nested_computation,
                                     absl::Span<llvm::Value* const> operands,
                                     llvm::Value* output);

  // Emits an atomic operation that implements `nested_computation` in the
  // sequentially consistent memory model. `output_address` and
  // `source_address` are the arguments of the nested computation. For example,
  // atomicAdd(output_address, *source_address).
  Status EmitAtomicOperationForNestedComputation(
      const HloComputation& nested_computation, llvm::Value* output_address,
      llvm::Value* source_address);

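  // Returns a callback that evaluates a nested computation on scalar inputs
  // by delegating to ComputeNestedElement; used when constructing a
  // GpuElementalIrEmitter.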
  GpuElementalIrEmitter::NestedComputer GetNestedComputer() {
    return std::bind(&IrEmitter::ComputeNestedElement, this,
                     std::placeholders::_1, std::placeholders::_2);
  }

  IrEmitterContext* ir_emitter_context_;
  llvm::Module* module_;

  // The following fields track the IR emission state. According to LLVM memory
  // management rules, their memory is owned by the module.
  llvm::IRBuilder<> b_;

  // Mapping from HLO to its underlying LLVM value.
  HloToIrBindings bindings_;

  // HLO configuration data used during code generation.
  const HloModuleConfig& hlo_module_config_;

 protected:
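  // Returns a callback that lazily constructs the IrArrays for `fusion`'s
  // operands (via GetIrArray) at the point the callback is invoked, rather
  // than eagerly at the call site.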
  GeneratorForOperandIrArrays GetGeneratorForOperandIrArrays(
      HloInstruction* fusion) {
    return [=]() {
      std::vector<llvm_ir::IrArray> ir_arrays;
      ir_arrays.reserve(fusion->operand_count());
      absl::c_transform(fusion->operands(), std::back_inserter(ir_arrays),
                        [&](const HloInstruction* operand) {
                          return GetIrArray(*operand, *fusion);
                        });
      return ir_arrays;
    };
  }

 private:
  // A helper method for EmitAtomicOperationForNestedComputation. Certain
  // computations, such as floating-point addition and integer maximization,
  // can be implemented directly with a single LLVM atomic instruction. If
  // `computation` is of this kind, emits code to do that and returns true;
  // otherwise, returns false.
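  //
  // For example (a sketch of the idea, not the exact IR this method emits),
  // an integer-add computation can be lowered to a single instruction:
  //
  //   atomicrmw add i32* %output_address, i32 %source
  //
  // instead of the CAS loop required in the general case.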
  bool MaybeEmitDirectAtomicOperation(const HloComputation& computation,
                                      llvm::Value* output_address,
                                      llvm::Value* source_address);

  // A helper method for EmitAtomicOperationForNestedComputation. It implements
  // binary atomic operations using atomicCAS with special handling to support
  // small data types.
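  //
  // The emitted code follows the classic compare-and-swap loop, roughly (a
  // sketch that elides the widening needed for types narrower than 32 bits):
  //
  //   old = *output_address;
  //   do {
  //     expected = old;
  //     desired = computation(expected, *source_address);
  //     old = atomicCAS(output_address, expected, desired);
  //   } while (old != expected);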
  Status EmitAtomicOperationUsingCAS(const HloComputation& computation,
                                     llvm::Value* output_address,
                                     llvm::Value* source_address);

  // A helper method for HandleSort(). It adds the inner comparison loop where
  // we compare elements pointed to by 'keys_index' and 'compare_keys_index'.
  void EmitCompareLoop(int64 dimension_to_sort,
                       const llvm_ir::IrArray::Index& keys_index,
                       const llvm_ir::IrArray::Index& compare_keys_index,
                       const llvm_ir::IrArray& keys_array);

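  // Emits code that evaluates `computation` on the scalar values
  // `parameter_elements` and returns the computed scalar; the callback
  // returned by GetNestedComputer() forwards here.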
  StatusOr<llvm::Value*> ComputeNestedElement(
      const HloComputation& computation,
      absl::Span<llvm::Value* const> parameter_elements);

  // Emits an IR function that implements `nested_computation` on operands of
  // type `element_ir_type`, for use when lowering the computation to an atomic
  // operation, and returns the emitted function.
  StatusOr<llvm::Function*> EmitAtomicFunctionForNestedComputation(
      const HloComputation& nested_computation, llvm::Type* element_ir_type);

  // Map nested computations to emitted IR functions. This serves as a cache so
  // that IrEmitter does not emit multiple functions for the same
  // HloComputation.
  std::map<const HloComputation*, llvm::Function*> computation_to_ir_function_;
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_