1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_ 18 19 #include <unordered_map> 20 #include <vector> 21 22 #include "absl/strings/string_view.h" 23 #include "absl/types/span.h" 24 #include "llvm/IR/IRBuilder.h" 25 #include "llvm/IR/Module.h" 26 #include "llvm/IR/Value.h" 27 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 28 #include "tensorflow/compiler/xla/service/hlo_instructions.h" 29 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" 30 #include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h" 31 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" 32 #include "tensorflow/compiler/xla/statusor.h" 33 34 namespace xla { 35 36 class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> { 37 public: 38 using HloToElementGeneratorMap = 39 std::unordered_map<const HloInstruction*, llvm_ir::ElementGenerator>; 40 ElementalIrEmitter(llvm::Module * module,llvm::IRBuilder<> * b)41 ElementalIrEmitter(llvm::Module* module, llvm::IRBuilder<>* b) 42 : b_(b), module_(module) {} 43 44 virtual ~ElementalIrEmitter() = default; 45 46 // Returns a function to generate an element of the output of `hlo`, given a 47 // map of functions to generate elements of its operands. 48 llvm_ir::ElementGenerator MakeElementGenerator( 49 const HloInstruction* hlo, 50 const HloToElementGeneratorMap& operand_to_generator); 51 b()52 llvm::IRBuilder<>* b() { return b_; } 53 54 // builder() is for IrBuilderMixin. builder()55 llvm::IRBuilder<>* builder() { return b_; } 56 module()57 llvm::Module* module() { return module_; } 58 59 protected: 60 virtual StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op, 61 llvm::Value* lhs_value, 62 llvm::Value* rhs_value); 63 64 virtual llvm::Value* EmitExtractReal(llvm::Value* value); 65 virtual llvm::Value* EmitExtractImag(llvm::Value* value); 66 67 private: 68 virtual StatusOr<llvm::Value*> EmitUnaryOp(const HloInstruction* op, 69 llvm::Value* operand_value); 70 71 virtual StatusOr<llvm::Value*> EmitBinaryOp(const HloInstruction* op, 72 llvm::Value* lhs_value, 73 llvm::Value* rhs_value); 74 75 virtual StatusOr<llvm::Value*> EmitIntegerUnaryOp(const HloInstruction* op, 76 llvm::Value* operand_value); 77 78 virtual StatusOr<llvm::Value*> EmitFloatUnaryOp(const HloInstruction* op, 79 llvm::Value* operand_value); 80 81 virtual StatusOr<llvm::Value*> EmitComplexUnaryOp(const HloInstruction* op, 82 llvm::Value* operand_value); 83 84 llvm::Value* IsZero(llvm::Value* v); 85 llvm::Value* IsIntMinDivisionOverflow(llvm::Value* lhs, llvm::Value* rhs); 86 llvm::Value* GetZero(llvm::Type* type); 87 llvm::Value* GetOne(llvm::Type* type); 88 llvm::Value* GetIntSMin(llvm::Type* type); 89 llvm::Value* GetMinusOne(llvm::Type* type); 90 91 llvm::Value* EmitIntegerDivide(llvm::Value* lhs, llvm::Value* rhs, 92 bool is_signed); 93 llvm::Value* EmitIntegerRemainder(llvm::Value* lhs, llvm::Value* rhs, 94 bool is_signed); 95 llvm::Value* EmitIntegerPow(llvm::Value* lhs, llvm::Value* rhs, 96 bool is_signed); 97 98 virtual StatusOr<llvm::Value*> EmitIntegerBinaryOp(const HloInstruction* op, 99 llvm::Value* lhs_value, 100 llvm::Value* rhs_value, 101 bool is_signed); 102 103 virtual StatusOr<llvm::Value*> EmitComplexBinaryOp(const HloInstruction* op, 104 llvm::Value* lhs_value, 105 llvm::Value* rhs_value); 106 107 virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value, 108 llvm::Value* rhs_value, 109 absl::string_view name); 110 111 virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value, 112 llvm::Value* rhs_value, 113 absl::string_view name); 114 115 llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value, 116 bool is_signed); 117 118 llvm::Value* EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value, 119 bool is_signed); 120 121 virtual StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, 122 llvm::Value* lhs, llvm::Value* rhs, 123 absl::string_view name); 124 125 virtual StatusOr<llvm::Value*> EmitLog(PrimitiveType prim_type, 126 llvm::Value* value); 127 128 virtual StatusOr<llvm::Value*> EmitSqrt(PrimitiveType prim_type, 129 llvm::Value* value); 130 131 virtual StatusOr<llvm::Value*> EmitCbrt(PrimitiveType prim_type, 132 llvm::Value* value); 133 134 virtual StatusOr<llvm::Value*> EmitRsqrt(PrimitiveType prim_type, 135 llvm::Value* value); 136 137 virtual StatusOr<llvm::Value*> EmitLog1p(PrimitiveType prim_type, 138 llvm::Value* value); 139 140 virtual StatusOr<llvm::Value*> EmitSin(PrimitiveType prim_type, 141 llvm::Value* value); 142 143 virtual StatusOr<llvm::Value*> EmitCos(PrimitiveType prim_type, 144 llvm::Value* value); 145 146 virtual StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type, 147 llvm::Value* value, 148 absl::string_view name); 149 150 virtual StatusOr<llvm::Value*> EmitExpm1(PrimitiveType prim_type, 151 llvm::Value* value); 152 153 virtual StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type, 154 llvm::Value* lhs, llvm::Value* rhs, 155 absl::string_view name); 156 157 virtual StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type, 158 llvm::Value* value); 159 160 virtual StatusOr<llvm::Value*> EmitReducePrecision(const HloInstruction* hlo, 161 llvm::Value* x); 162 163 virtual StatusOr<std::tuple<llvm::Value*, llvm::Value*, llvm::Value*>> 164 EmitComplexAbsHelper(PrimitiveType prim_type, llvm::Value* operand_value, 165 bool return_sqrt); 166 167 virtual StatusOr<llvm::Value*> EmitComplexAbs(PrimitiveType prim_type, 168 llvm::Value* operand_value); 169 170 virtual StatusOr<llvm::Value*> EmitSqrtComplexAbs(PrimitiveType prim_type, 171 llvm::Value* operand_value); 172 virtual StatusOr<llvm::Value*> EmitRsqrtComplexAbs( 173 PrimitiveType prim_type, llvm::Value* operand_value); 174 175 virtual StatusOr<llvm::Value*> EmitComplexSqrt(const HloInstruction* op, 176 PrimitiveType prim_type, 177 llvm::Value* operand_value); 178 179 virtual StatusOr<llvm::Value*> EmitComplexCbrt(const HloInstruction* op, 180 PrimitiveType prim_type, 181 llvm::Value* operand_value); 182 183 virtual StatusOr<llvm::Value*> EmitComplexRsqrt(const HloInstruction* op, 184 PrimitiveType prim_type, 185 llvm::Value* operand_value); 186 187 StatusOr<llvm::Value*> EmitAccumResult( 188 absl::Span<llvm::Value* const> accumulator_addrs, 189 llvm::ArrayRef<llvm::Type*> accumulator_types, bool is_variadic); 190 191 // Composes a complex struct. imag may be nullptr for simple cast operations. 192 llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real, 193 llvm::Value* imag); 194 195 // Emit `accumulator + lhs * rhs` for the given primitive type. 196 llvm::Value* EmitMulAdd(llvm::Value* lhs, llvm::Value* rhs, 197 llvm::Value* accumulator, 198 xla::PrimitiveType primitive_type); 199 200 // Identifier of the thread unique among all threads on the device EmitThreadId()201 virtual llvm::Value* EmitThreadId() { return b_->getIntN(128, 0); } 202 203 StatusOr<llvm::Value*> EmitElementalSelect( 204 const HloInstruction* hlo, 205 const HloToElementGeneratorMap& operand_to_generator, 206 const llvm_ir::IrArray::Index& index); 207 208 StatusOr<llvm::Value*> EmitElementalClamp( 209 const HloInstruction* hlo, 210 const HloToElementGeneratorMap& operand_to_generator, 211 const llvm_ir::IrArray::Index& index); 212 213 StatusOr<llvm::Value*> EmitElementalConcatenate( 214 const HloInstruction* hlo, 215 const HloToElementGeneratorMap& operand_to_generator, 216 const llvm_ir::IrArray::Index& target_index); 217 218 StatusOr<llvm::Value*> EmitElementalDynamicSlice( 219 const HloInstruction* hlo, 220 const HloToElementGeneratorMap& operand_to_generator, 221 const llvm_ir::IrArray::Index& index); 222 223 StatusOr<llvm::Value*> EmitElementalGather( 224 const HloInstruction* hlo, 225 const HloToElementGeneratorMap& operand_to_generator, 226 const llvm_ir::IrArray::Index& index); 227 228 StatusOr<llvm::Value*> EmitElementalDynamicUpdateSlice( 229 const HloInstruction* hlo, 230 const HloToElementGeneratorMap& operand_to_generator, 231 const llvm_ir::IrArray::Index& index); 232 233 StatusOr<llvm::Value*> EmitElementalPad( 234 const HloInstruction* hlo, 235 const HloToElementGeneratorMap& operand_to_generator, 236 const llvm_ir::IrArray::Index& padded_index); 237 238 StatusOr<llvm::Value*> EmitElementalDot( 239 const HloInstruction* hlo, 240 const HloToElementGeneratorMap& operand_to_generator, 241 const llvm_ir::IrArray::Index& dot_result_index); 242 243 virtual StatusOr<std::vector<llvm::Value*>> EmitThreadLocalCall( 244 const HloComputation& callee, absl::Span<llvm::Value* const> parameters, 245 absl::string_view name) = 0; 246 247 StatusOr<llvm::Value*> EmitElementalMap( 248 const HloMapInstruction* map_instr, 249 absl::Span<llvm::Value* const> elemental_operands); 250 251 StatusOr<llvm::Value*> EmitElementalReduceWindow( 252 const HloReduceWindowInstruction* reduce_window, 253 std::vector<llvm_ir::ElementGenerator> input_generators, 254 std::vector<llvm_ir::ElementGenerator> initial_value_generators, 255 const llvm_ir::IrArray::Index& index); 256 257 StatusOr<llvm::Value*> EmitElementalReduce( 258 const HloReduceInstruction* reduce, 259 std::vector<llvm_ir::ElementGenerator> input_generators, 260 std::vector<llvm_ir::ElementGenerator> initial_value_generators, 261 const llvm_ir::IrArray::Index& index); 262 263 virtual StatusOr<llvm::Value*> EmitConvolution( 264 const HloInstruction* hlo, 265 const HloToElementGeneratorMap& operand_to_generator, 266 const llvm_ir::IrArray::Index& index); 267 268 // Computes the complex power function, returns (a + i*b)^(c + i*d). 269 StatusOr<llvm::Value*> EmitComplexPower(const HloInstruction* op, 270 llvm::Value* a, llvm::Value* b, 271 llvm::Value* c, llvm::Value* d); 272 273 // Evaluates a polynomial using Horner's method. 274 StatusOr<llvm::Value*> EvaluatePolynomial( 275 llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients); 276 277 virtual bool fast_min_max() = 0; 278 279 llvm::IRBuilder<>* const b_; 280 281 llvm::Module* module_; 282 }; 283 284 } // namespace xla 285 286 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_ 287