• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_FUSED_IR_EMITTER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_FUSED_IR_EMITTER_H_
18 
19 #include <map>
20 #include <unordered_map>
21 
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/types/optional.h"
24 #include "absl/types/span.h"
25 #include "llvm/IR/IRBuilder.h"
26 #include "llvm/IR/Value.h"
27 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
28 #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
30 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
31 #include "tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h"
32 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
33 #include "tensorflow/compiler/xla/statusor.h"
34 #include "tensorflow/compiler/xla/xla_data.pb.h"
35 
36 namespace xla {
37 
38 // FusedIrEmitter is used to generate code for fusion nodes.
39 //
40 // Unlike IrEmitter and its ilk, which directly create LLVM IR in an LLVM
41 // Module, FusedIrEmitter is better understood as "IR generator generator".
42 // FusedIrEmitter recursively creates a generator (a host function) which the
43 // compiler can invoke at a later time.  Invoking the generator emits LLVM IR
44 // that, when run, produces the value at a particular index of the output.
45 //
46 // After building this generator, the compiler creates a loop (or its moral
47 // equivalent, e.g. a GPU kernel) and calls the generator from within the loop.
48 // This generates code that produces each element of the output.
49 //
50 // This class handles both vanilla fusion and multi-output fusion.  In the MOF
51 // case, the fusion node ends with a kTuple instruction, and the generator
52 // created produces an LLVM struct with N elements, one for each element of the
53 // arrays in the tuple.  It follows that the arrays in the tuple must have the
54 // same length.
55 class FusedIrEmitter : public DfsHloVisitorWithDefault {
56  public:
57   using IndexedGenerator = llvm_ir::ElementGenerator;
58   using NonIndexedGenerator = std::function<StatusOr<llvm::Value*>()>;
59   using GeneratorForOperandIrArrays =
60       std::function<std::vector<llvm_ir::IrArray>()>;
61 
FusedIrEmitter(GeneratorForOperandIrArrays operand_arrays_generator,ElementalIrEmitter * elemental_emitter)62   FusedIrEmitter(GeneratorForOperandIrArrays operand_arrays_generator,
63                  ElementalIrEmitter* elemental_emitter)
64       : operand_arrays_(),
65         operand_arrays_generator_(std::move(operand_arrays_generator)),
66         tiled_parameter_info_(nullptr),
67         elemental_emitter_(elemental_emitter),
68         b_(elemental_emitter->b()),
69         module_(elemental_emitter->module()) {}
70 
71   Status DefaultAction(HloInstruction* hlo) override;
72 
73   Status HandleConstant(HloInstruction* constant) override;
74 
75   Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
76 
77   Status HandleParameter(HloInstruction* parameter) override;
78 
79   // Emits the ir value for each element in the tuple.
80   Status HandleTuple(HloInstruction* tuple) override;
81 
82   Status FinishVisit(HloInstruction* root) override;
83 
84   // Returns the generator function for the root of the fused computation.
85   IndexedGenerator GetRootGenerator() const;
86 
87   // Returns the generator function for the given instruction.
88   IndexedGenerator GetGenerator(const HloInstruction* instruction) const;
89 
SetTiledParameterInfo(const llvm_ir::TiledParameterInfo * info)90   void SetTiledParameterInfo(const llvm_ir::TiledParameterInfo* info) {
91     tiled_parameter_info_ = info;
92   }
93 
94   // Evaluates whether fusing 'producer' into 'consumer' might cause exponential
95   // behavior in FusedIrEmitter. We currently can have exponential time/memory
96   // requirements for emitting certain fusion kernels, in which case we don't
97   // want to fuse.
98   // TODO(b/119692968): Remove this once we have fixed our fusion emitter.
99   static bool IsFusedIrEmitterInefficient(const HloInstruction* consumer,
100                                           const HloInstruction* producer);
101 
102  protected:
103   // Returns the IrArrays for the fusion instruction operands.
GetIrArrayForFusedParameter(int64 parameter_number)104   llvm_ir::IrArray& GetIrArrayForFusedParameter(int64 parameter_number) {
105     if (!operand_arrays_.has_value()) {
106       operand_arrays_ = operand_arrays_generator_();
107     }
108     return operand_arrays_.value()[parameter_number];
109   }
110 
GetBasePointerForFusedParameter(int64 parameter_number)111   llvm::Value* GetBasePointerForFusedParameter(int64 parameter_number) {
112     return GetIrArrayForFusedParameter(parameter_number).GetBasePointer();
113   }
114 
115  private:
116   // IrArrays for the fusion instruction operands, whose base addresses are the
117   // base address of the corresponding parameters in the fused computation.
118   absl::optional<std::vector<llvm_ir::IrArray>> operand_arrays_;
119   GeneratorForOperandIrArrays operand_arrays_generator_;
120 
121   const llvm_ir::TiledParameterInfo* tiled_parameter_info_;
122 
123   ElementalIrEmitter* elemental_emitter_;
124 
125   // This member will be set by FinishVisit and used in GetRootGenerator.
126   const HloInstruction* fused_root_ = nullptr;
127 
128   // Borrowed
129   llvm::IRBuilder<>* b_;
130   llvm::Module* module_;
131 
132   // Map from instructions to functions that generate code for the output
133   // elements. If an instruction is a GetTupleElement instruction, the
134   // instruction produces non-tuple result.
135   std::unordered_map<const HloInstruction*, IndexedGenerator>
136       indexed_generators_;
137 
138   // Map from tuple-result-producing GetTupleELement instructions to functions
139   // that generate the base pointers for the output elements. This is used to
140   // support the translation of nested GetTupleElement instructions.
141   std::unordered_map<const HloInstruction*, NonIndexedGenerator>
142       non_indexed_generators_;
143 
144   // Cache of generated values, lest we regenerate an element of a node with
145   // multiple outgoing edges
146   absl::flat_hash_map<
147       const HloInstruction*,
148       absl::flat_hash_map<std::vector<llvm::Value*>, llvm::Value*>>
149       generated_value_cache_;
150 };
151 
152 }  // namespace xla
153 
154 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_FUSED_IR_EMITTER_H_
155