1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_FUSED_IR_EMITTER_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_FUSED_IR_EMITTER_H_ 18 19 #include <map> 20 #include <unordered_map> 21 22 #include "absl/container/flat_hash_map.h" 23 #include "absl/types/optional.h" 24 #include "absl/types/span.h" 25 #include "llvm/IR/IRBuilder.h" 26 #include "llvm/IR/Value.h" 27 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" 28 #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" 29 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 30 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" 31 #include "tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h" 32 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" 33 #include "tensorflow/compiler/xla/statusor.h" 34 #include "tensorflow/compiler/xla/xla_data.pb.h" 35 36 namespace xla { 37 38 // FusedIrEmitter is used to generate code for fusion nodes. 39 // 40 // Unlike IrEmitter and its ilk, which directly create LLVM IR in an LLVM 41 // Module, FusedIrEmitter is better understood as "IR generator generator". 42 // FusedIrEmitter recursively creates a generator (a host function) which the 43 // compiler can invoke at a later time. Invoking the generator emits LLVM IR 44 // that, when run, produces the value at a particular index of the output. 45 // 46 // After building this generator, the compiler creates a loop (or its moral 47 // equivalent, e.g. a GPU kernel) and calls the generator from within the loop. 48 // This generates code that produces each element of the output. 49 // 50 // This class handles both vanilla fusion and multi-output fusion. In the MOF 51 // case, the fusion node ends with a kTuple instruction, and the generator 52 // created produces an LLVM struct with N elements, one for each element of the 53 // arrays in the tuple. It follows that the arrays in the tuple must have the 54 // same length. 55 class FusedIrEmitter : public DfsHloVisitorWithDefault { 56 public: 57 using IndexedGenerator = llvm_ir::ElementGenerator; 58 using NonIndexedGenerator = std::function<StatusOr<llvm::Value*>()>; 59 using GeneratorForOperandIrArrays = 60 std::function<std::vector<llvm_ir::IrArray>()>; 61 FusedIrEmitter(GeneratorForOperandIrArrays operand_arrays_generator,ElementalIrEmitter * elemental_emitter)62 FusedIrEmitter(GeneratorForOperandIrArrays operand_arrays_generator, 63 ElementalIrEmitter* elemental_emitter) 64 : operand_arrays_(), 65 operand_arrays_generator_(std::move(operand_arrays_generator)), 66 tiled_parameter_info_(nullptr), 67 elemental_emitter_(elemental_emitter), 68 b_(elemental_emitter->b()), 69 module_(elemental_emitter->module()) {} 70 71 Status DefaultAction(HloInstruction* hlo) override; 72 73 Status HandleConstant(HloInstruction* constant) override; 74 75 Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; 76 77 Status HandleParameter(HloInstruction* parameter) override; 78 79 // Emits the ir value for each element in the tuple. 80 Status HandleTuple(HloInstruction* tuple) override; 81 82 Status FinishVisit(HloInstruction* root) override; 83 84 // Returns the generator function for the root of the fused computation. 85 IndexedGenerator GetRootGenerator() const; 86 87 // Returns the generator function for the given instruction. 88 IndexedGenerator GetGenerator(const HloInstruction* instruction) const; 89 SetTiledParameterInfo(const llvm_ir::TiledParameterInfo * info)90 void SetTiledParameterInfo(const llvm_ir::TiledParameterInfo* info) { 91 tiled_parameter_info_ = info; 92 } 93 94 // Evaluates whether fusing 'producer' into 'consumer' might cause exponential 95 // behavior in FusedIrEmitter. We currently can have exponential time/memory 96 // requirements for emitting certain fusion kernels, in which case we don't 97 // want to fuse. 98 // TODO(b/119692968): Remove this once we have fixed our fusion emitter. 99 static bool IsFusedIrEmitterInefficient(const HloInstruction* consumer, 100 const HloInstruction* producer); 101 102 protected: 103 // Returns the IrArrays for the fusion instruction operands. GetIrArrayForFusedParameter(int64 parameter_number)104 llvm_ir::IrArray& GetIrArrayForFusedParameter(int64 parameter_number) { 105 if (!operand_arrays_.has_value()) { 106 operand_arrays_ = operand_arrays_generator_(); 107 } 108 return operand_arrays_.value()[parameter_number]; 109 } 110 GetBasePointerForFusedParameter(int64 parameter_number)111 llvm::Value* GetBasePointerForFusedParameter(int64 parameter_number) { 112 return GetIrArrayForFusedParameter(parameter_number).GetBasePointer(); 113 } 114 115 private: 116 // IrArrays for the fusion instruction operands, whose base addresses are the 117 // base address of the corresponding parameters in the fused computation. 118 absl::optional<std::vector<llvm_ir::IrArray>> operand_arrays_; 119 GeneratorForOperandIrArrays operand_arrays_generator_; 120 121 const llvm_ir::TiledParameterInfo* tiled_parameter_info_; 122 123 ElementalIrEmitter* elemental_emitter_; 124 125 // This member will be set by FinishVisit and used in GetRootGenerator. 126 const HloInstruction* fused_root_ = nullptr; 127 128 // Borrowed 129 llvm::IRBuilder<>* b_; 130 llvm::Module* module_; 131 132 // Map from instructions to functions that generate code for the output 133 // elements. If an instruction is a GetTupleElement instruction, the 134 // instruction produces non-tuple result. 135 std::unordered_map<const HloInstruction*, IndexedGenerator> 136 indexed_generators_; 137 138 // Map from tuple-result-producing GetTupleELement instructions to functions 139 // that generate the base pointers for the output elements. This is used to 140 // support the translation of nested GetTupleElement instructions. 141 std::unordered_map<const HloInstruction*, NonIndexedGenerator> 142 non_indexed_generators_; 143 144 // Cache of generated values, lest we regenerate an element of a node with 145 // multiple outgoing edges 146 absl::flat_hash_map< 147 const HloInstruction*, 148 absl::flat_hash_map<std::vector<llvm::Value*>, llvm::Value*>> 149 generated_value_cache_; 150 }; 151 152 } // namespace xla 153 154 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_FUSED_IR_EMITTER_H_ 155