1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_ 18 19 #include <unordered_map> 20 21 #include "llvm/IR/IRBuilder.h" 22 #include "llvm/IR/Module.h" 23 #include "llvm/IR/Value.h" 24 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 25 #include "tensorflow/compiler/xla/service/hlo_module_config.h" 26 #include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h" 27 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" 28 #include "tensorflow/compiler/xla/statusor.h" 29 30 namespace xla { 31 32 class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> { 33 public: 34 using HloToElementGeneratorMap = 35 std::unordered_map<const HloInstruction*, llvm_ir::ElementGenerator>; 36 ElementalIrEmitter(const HloModuleConfig & hlo_module_config,llvm::Module * module,llvm::IRBuilder<> * b)37 ElementalIrEmitter(const HloModuleConfig& hlo_module_config, 38 llvm::Module* module, llvm::IRBuilder<>* b) 39 : b_(b), module_(module), hlo_module_config_(hlo_module_config) {} 40 41 virtual ~ElementalIrEmitter() = default; 42 43 virtual StatusOr<llvm::Value*> EmitUnaryOp(const HloInstruction* op, 44 llvm::Value* operand_value); 45 46 virtual StatusOr<llvm::Value*> EmitBinaryOp(const HloInstruction* op, 47 llvm::Value* lhs_value, 48 llvm::Value* rhs_value); 49 50 // Returns a function to generate an element of the output of `hlo`, given a 51 // map of functions to generate elements of its operands. 52 virtual llvm_ir::ElementGenerator MakeElementGenerator( 53 const HloInstruction* hlo, 54 const HloToElementGeneratorMap& operand_to_generator); 55 b()56 llvm::IRBuilder<>* b() { return b_; } 57 58 // builder() is for IrBuilderMixin. builder()59 llvm::IRBuilder<>* builder() { return b_; } 60 module()61 llvm::Module* module() { return module_; } 62 63 protected: 64 virtual StatusOr<llvm::Value*> EmitIntegerUnaryOp(const HloInstruction* op, 65 llvm::Value* operand_value); 66 67 virtual StatusOr<llvm::Value*> EmitFloatUnaryOp(const HloInstruction* op, 68 llvm::Value* operand_value); 69 70 virtual StatusOr<llvm::Value*> EmitComplexUnaryOp(const HloInstruction* op, 71 llvm::Value* operand_value); 72 73 llvm::Value* IsZero(llvm::Value* v); 74 llvm::Value* IsIntMinDivisionOverflow(llvm::Value* lhs, llvm::Value* rhs); 75 llvm::Value* GetZero(llvm::Type* type); 76 llvm::Value* GetOne(llvm::Type* type); 77 llvm::Value* GetIntSMin(llvm::Type* type); 78 llvm::Value* GetMinusOne(llvm::Type* type); 79 80 llvm::Value* EmitIntegerDivide(llvm::Value* lhs, llvm::Value* rhs, 81 bool is_signed); 82 llvm::Value* EmitIntegerRemainder(llvm::Value* lhs, llvm::Value* rhs, 83 bool is_signed); 84 85 virtual StatusOr<llvm::Value*> EmitIntegerBinaryOp(const HloInstruction* op, 86 llvm::Value* lhs_value, 87 llvm::Value* rhs_value, 88 bool is_signed); 89 90 virtual StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op, 91 llvm::Value* lhs_value, 92 llvm::Value* rhs_value); 93 94 virtual StatusOr<llvm::Value*> EmitComplexBinaryOp(const HloInstruction* op, 95 llvm::Value* lhs_value, 96 llvm::Value* rhs_value); 97 98 virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value, 99 llvm::Value* rhs_value); 100 101 virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value, 102 llvm::Value* rhs_value); 103 104 llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value, 105 bool is_signed); 106 107 llvm::Value* EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value, 108 bool is_signed); 109 110 virtual StatusOr<llvm::Value*> EmitErfInv(PrimitiveType prim_type, 111 llvm::Value* value); 112 113 virtual StatusOr<llvm::Value*> EmitErfcInv(PrimitiveType prim_type, 114 llvm::Value* value); 115 116 virtual StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, 117 llvm::Value* lhs, llvm::Value* rhs); 118 119 virtual StatusOr<llvm::Value*> EmitLog(PrimitiveType prim_type, 120 llvm::Value* value); 121 122 virtual StatusOr<llvm::Value*> EmitSqrt(PrimitiveType prim_type, 123 llvm::Value* value); 124 125 virtual StatusOr<llvm::Value*> EmitRsqrt(PrimitiveType prim_type, 126 llvm::Value* value); 127 128 virtual StatusOr<llvm::Value*> EmitLog1p(PrimitiveType prim_type, 129 llvm::Value* value); 130 131 virtual StatusOr<llvm::Value*> EmitSin(PrimitiveType prim_type, 132 llvm::Value* value); 133 134 virtual StatusOr<llvm::Value*> EmitCos(PrimitiveType prim_type, 135 llvm::Value* value); 136 137 virtual StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type, 138 llvm::Value* value); 139 140 virtual StatusOr<llvm::Value*> EmitExpm1(PrimitiveType prim_type, 141 llvm::Value* value); 142 143 virtual StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type, 144 llvm::Value* lhs, llvm::Value* rhs); 145 146 virtual StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type, 147 llvm::Value* value); 148 149 virtual StatusOr<llvm::Value*> EmitRoundNearestAfz(PrimitiveType prim_type, 150 llvm::Value* value); 151 152 virtual StatusOr<llvm::Value*> EmitReducePrecision(const HloInstruction* hlo, 153 llvm::Value* x); 154 155 virtual llvm::Value* EmitExtractReal(llvm::Value* value); 156 virtual llvm::Value* EmitExtractImag(llvm::Value* value); 157 158 // Composes a complex struct. imag may be nullptr for simple cast operations. 159 llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real, 160 llvm::Value* imag); 161 162 // Identifier of the thread unique among all threads on the device EmitThreadId()163 virtual llvm::Value* EmitThreadId() { return b_->getIntN(128, 0); } 164 165 StatusOr<llvm::Value*> EmitElementalSelect( 166 const HloInstruction* hlo, 167 const HloToElementGeneratorMap& operand_to_generator, 168 const llvm_ir::IrArray::Index& index); 169 170 StatusOr<llvm::Value*> EmitElementalClamp( 171 const HloInstruction* hlo, 172 const HloToElementGeneratorMap& operand_to_generator, 173 const llvm_ir::IrArray::Index& index); 174 175 StatusOr<llvm::Value*> EmitElementalConcatenate( 176 const HloInstruction* hlo, 177 const HloToElementGeneratorMap& operand_to_generator, 178 const llvm_ir::IrArray::Index& target_index); 179 180 StatusOr<llvm::Value*> EmitElementalDynamicSlice( 181 const HloInstruction* hlo, 182 const HloToElementGeneratorMap& operand_to_generator, 183 const llvm_ir::IrArray::Index& index); 184 185 StatusOr<llvm::Value*> EmitElementalGather( 186 const HloInstruction* hlo, 187 const HloToElementGeneratorMap& operand_to_generator, 188 const llvm_ir::IrArray::Index& index); 189 190 StatusOr<llvm::Value*> EmitElementalDynamicUpdateSlice( 191 const HloInstruction* hlo, 192 const HloToElementGeneratorMap& operand_to_generator, 193 const llvm_ir::IrArray::Index& index); 194 195 StatusOr<llvm::Value*> EmitElementalPad( 196 const HloInstruction* hlo, 197 const HloToElementGeneratorMap& operand_to_generator, 198 const llvm_ir::IrArray::Index& padded_index); 199 200 StatusOr<llvm::Value*> EmitElementalDot( 201 const HloInstruction* hlo, 202 const HloToElementGeneratorMap& operand_to_generator, 203 const llvm_ir::IrArray::Index& dot_result_index); 204 205 llvm::IRBuilder<>* const b_; 206 207 llvm::Module* module_; 208 209 // The HloModuleConfig which gathers all settings and values which affect the 210 // compiled executable outside of the HLO code itself. 211 const HloModuleConfig& hlo_module_config_; 212 213 private: 214 // Computes the complex power function, returns (a + i*b)^(c + i*d). 215 StatusOr<llvm::Value*> EmitComplexPower(const HloInstruction* op, 216 llvm::Value* a, llvm::Value* b, 217 llvm::Value* c, llvm::Value* d); 218 219 // Returns a ElementGenerator for an RNG HloInstruction using the Philox 220 // random number generation algorithm. 221 llvm_ir::ElementGenerator MakePhiloxRngElementGenerator( 222 const HloInstruction* hlo, 223 const HloToElementGeneratorMap& operand_to_generator); 224 225 // Converts the raw value generated by a random number generation algorithm 226 // to the distribution requested by the RNG HloInstruction. 227 // 228 // Precondition: raw_value has at least as many bits as hlo's element type. 229 StatusOr<llvm::Value*> ConvertValueForDistribution( 230 const HloInstruction* hlo, 231 const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, 232 const llvm_ir::IrArray::Index& index, llvm::Value* raw_value); 233 }; 234 235 } // namespace xla 236 237 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_ 238