1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_ 18 19 #include <unordered_map> 20 21 #include "llvm/IR/IRBuilder.h" 22 #include "llvm/IR/Module.h" 23 #include "llvm/IR/Value.h" 24 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 25 #include "tensorflow/compiler/xla/service/hlo_module_config.h" 26 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" 27 #include "tensorflow/compiler/xla/statusor.h" 28 29 namespace xla { 30 31 class ElementalIrEmitter { 32 public: 33 using HloToElementGeneratorMap = 34 std::unordered_map<const HloInstruction*, llvm_ir::ElementGenerator>; 35 ElementalIrEmitter(const HloModuleConfig & hlo_module_config,llvm::Module * module,llvm::IRBuilder<> * ir_builder)36 ElementalIrEmitter(const HloModuleConfig& hlo_module_config, 37 llvm::Module* module, llvm::IRBuilder<>* ir_builder) 38 : ir_builder_(ir_builder), 39 module_(module), 40 hlo_module_config_(hlo_module_config) {} 41 42 virtual ~ElementalIrEmitter() = default; 43 44 virtual StatusOr<llvm::Value*> EmitUnaryOp(const HloInstruction* op, 45 llvm::Value* operand_value) const; 46 47 virtual StatusOr<llvm::Value*> EmitBinaryOp(const HloInstruction* op, 48 llvm::Value* lhs_value, 49 llvm::Value* rhs_value) const; 50 51 // Returns a function to generate an element of the output of `hlo`, given a 52 // map of functions to generate elements of its operands. 53 virtual llvm_ir::ElementGenerator MakeElementGenerator( 54 const HloInstruction* hlo, 55 const HloToElementGeneratorMap& operand_to_generator) const; 56 ir_builder()57 llvm::IRBuilder<>* ir_builder() const { return ir_builder_; } module()58 llvm::Module* module() const { return module_; } 59 60 protected: 61 virtual StatusOr<llvm::Value*> EmitIntegerUnaryOp( 62 const HloInstruction* op, llvm::Value* operand_value) const; 63 64 virtual StatusOr<llvm::Value*> EmitFloatUnaryOp( 65 const HloInstruction* op, llvm::Value* operand_value) const; 66 67 virtual StatusOr<llvm::Value*> EmitComplexUnaryOp( 68 const HloInstruction* op, llvm::Value* operand_value) const; 69 70 virtual StatusOr<llvm::Value*> EmitIntegerBinaryOp(const HloInstruction* op, 71 llvm::Value* lhs_value, 72 llvm::Value* rhs_value, 73 bool is_signed) const; 74 75 virtual StatusOr<llvm::Value*> EmitFloatBinaryOp( 76 const HloInstruction* op, llvm::Value* lhs_value, 77 llvm::Value* rhs_value) const; 78 79 virtual StatusOr<llvm::Value*> EmitComplexBinaryOp( 80 const HloInstruction* op, llvm::Value* lhs_value, 81 llvm::Value* rhs_value) const; 82 83 virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value, 84 llvm::Value* rhs_value) const; 85 86 virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value, 87 llvm::Value* rhs_value) const; 88 89 llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value, 90 bool is_signed) const; 91 92 llvm::Value* EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value, 93 bool is_signed) const; 94 95 virtual StatusOr<llvm::Value*> EmitErfInv(PrimitiveType prim_type, 96 llvm::Value* value) const; 97 98 virtual StatusOr<llvm::Value*> EmitErfcInv(PrimitiveType prim_type, 99 llvm::Value* value) const; 100 101 virtual StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, 102 llvm::Value* lhs, 103 llvm::Value* rhs) const; 104 105 virtual StatusOr<llvm::Value*> EmitLog(PrimitiveType prim_type, 106 llvm::Value* value) const; 107 108 virtual StatusOr<llvm::Value*> EmitSin(PrimitiveType prim_type, 109 llvm::Value* value) const; 110 111 virtual StatusOr<llvm::Value*> EmitCos(PrimitiveType prim_type, 112 llvm::Value* value) const; 113 114 virtual StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type, 115 llvm::Value* value) const; 116 117 virtual StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type, 118 llvm::Value* lhs, 119 llvm::Value* rhs) const; 120 121 virtual StatusOr<llvm::Value*> EmitReducePrecision(const HloInstruction* hlo, 122 llvm::Value* x) const; 123 124 virtual llvm::Value* EmitExtractReal(llvm::Value* value) const; 125 virtual llvm::Value* EmitExtractImag(llvm::Value* value) const; 126 127 // Composes a complex struct. imag may be nullptr for simple cast operations. 128 llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real, 129 llvm::Value* imag) const; 130 131 // A helper method for MakeElementGenerator. Given an elementwise op `hlo` and 132 // the target array index, computes the source array index of its 133 // `operand_no`-th operand. 134 // 135 // Precondition: `hlo` is an elementwise op. 136 llvm_ir::IrArray::Index ElementwiseSourceIndex( 137 const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo, 138 int64 operand_no) const; 139 140 // Identifier of the thread unique among all threads on the device EmitThreadId()141 virtual llvm::Value* EmitThreadId() const { 142 return ir_builder_->getIntN(128, 0); 143 } 144 145 llvm::IRBuilder<>* const ir_builder_; 146 147 llvm::Module* module_; 148 149 // The HloModuleConfig which gathers all settings and values which affect the 150 // compiled executable outside of the HLO code itself. 151 const HloModuleConfig& hlo_module_config_; 152 153 private: 154 // Returns a ElementGenerator for a RNG HloInstruction. 155 llvm_ir::ElementGenerator MakeRngElementGenerator( 156 const HloInstruction* hlo, 157 const HloToElementGeneratorMap& operand_to_generator) const; 158 }; 159 160 } // namespace xla 161 162 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_ 163