• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
17 
18 #include <algorithm>
19 #include <memory>
20 #include <string>
21 #include <vector>
22 
23 // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
24 #include "absl/algorithm/container.h"
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/strings/str_cat.h"
27 #include "llvm/IR/BasicBlock.h"
28 #include "llvm/IR/Instructions.h"
29 #include "llvm/IR/Intrinsics.h"
30 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
31 #include "tensorflow/compiler/xla/primitive_util.h"
32 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
33 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
34 #include "tensorflow/compiler/xla/service/hlo_module.h"
35 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
36 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
37 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
38 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
39 #include "tensorflow/compiler/xla/shape_util.h"
40 #include "tensorflow/compiler/xla/status_macros.h"
41 #include "tensorflow/compiler/xla/statusor.h"
42 #include "tensorflow/compiler/xla/types.h"
43 #include "tensorflow/compiler/xla/util.h"
44 #include "tensorflow/compiler/xla/xla_data.pb.h"
45 #include "tensorflow/core/lib/random/random.h"
46 #include "tensorflow/core/platform/logging.h"
47 #include "tensorflow/core/platform/types.h"
48 
49 namespace xla {
50 
51 using absl::StrCat;
52 using llvm_ir::IrArray;
53 using llvm_ir::IrName;
54 using llvm_ir::SetToFirstInsertPoint;
55 
56 namespace {
57 
GlobalRandomValue()58 int64 GlobalRandomValue() {
59   static auto* mu = new tensorflow::mutex();
60   static std::mt19937_64 rng{42};
61   tensorflow::mutex_lock l(*mu);
62   return rng();
63 }
64 
// Emits IR that rounds the 32-bit float `x` to the precision of a narrower
// floating-point format with `exponent_bits` exponent bits and `mantissa_bits`
// mantissa bits. The result is still an f32 value of the same LLVM type as `x`.
//
// Behavior of the emitted IR:
// - Mantissa rounding (only when mantissa_bits < 23) is round-to-nearest with
//   ties to even.
// - Exponent clamping (only when exponent_bits < 8) maps values whose exponent
//   overflows the reduced range to a signed infinity. Values whose exponent
//   underflows it, including all denormals, are flushed to a signed zero.
// - NaN inputs are returned unchanged when mantissa_bits > 0, and as +infinity
//   when mantissa_bits == 0. That last-resort correction is skipped when the
//   builder's fast-math flags assume no NaNs.
//
// NOTE(review): the bit manipulation hard-codes the f32 layout (23 mantissa
// bits, 8 exponent bits, i32 bitcast), so `x` must be an f32. This function
// does not validate `exponent_bits`/`mantissa_bits`; callers are expected to
// pass exponent_bits >= 1 and 0 <= mantissa_bits.
llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
                                      int64 mantissa_bits,
                                      llvm::IRBuilder<>* b) {
  // Integer and float types for casting and constant generation.
  llvm::Type* float_type = x->getType();
  llvm::IntegerType* int_type = b->getInt32Ty();

  // Cast the input value to an integer for bitwise manipulation.
  llvm::Value* x_as_int = b->CreateBitCast(x, int_type);

  if (mantissa_bits < 23) {
    // Last remaining mantissa bit.
    const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits);

    // Compute rounding bias for round-to-nearest with ties to even.  This is
    // equal to a base value of 0111... plus one bit if the last remaining
    // mantissa bit is 1.
    const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1;
    llvm::Value* x_last_mantissa_bit = b->CreateLShr(
        b->CreateAnd(x_as_int,
                     llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)),
        (23 - mantissa_bits));
    llvm::Value* x_rounding_bias =
        b->CreateAdd(x_last_mantissa_bit,
                     llvm::ConstantInt::get(int_type, base_rounding_bias));

    // Add rounding bias, and mask out truncated bits.  Note that the case
    // where adding the rounding bias overflows into the exponent bits is
    // correct; the non-masked mantissa bits will all be zero, and the
    // exponent will be incremented by one.
    const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1);
    x_as_int = b->CreateAdd(x_as_int, x_rounding_bias);
    x_as_int = b->CreateAnd(x_as_int,
                            llvm::ConstantInt::get(int_type, truncation_mask));
  }

  if (exponent_bits < 8) {
    // Masks for f32 values.
    const uint32_t f32_sign_bit_mask = 1u << 31;
    const uint32_t f32_exp_bits_mask = 0xffu << 23;

    // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most-
    // significant bit -- is equal to 1.0f for all exponent sizes.  Adding
    // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit-
    // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest
    // exponent (corresponding to 0.0f).
    //
    // Thus, the f32 exponent corresponding to the highest non-infinite
    // exponent for a bit size of n is (2^7-1) + (2^(n-1)-1), and the f32
    // exponent corresponding to the lowest exponent for a bit size of n is
    // (2^7-1) - (2^(n-1)-1).
    //
    // NOTE(review): this function does not itself check exponent_bits >= 1;
    // callers are assumed to have validated it. With exponent_bits == 0 the
    // shift below would be by -1, which is undefined behavior.
    const uint32_t f32_exponent_bias = (1 << 7) - 1;
    const uint32_t reduced_exponent_bias = (1 << (exponent_bits - 1)) - 1;
    const uint32_t reduced_max_exponent =
        f32_exponent_bias + reduced_exponent_bias;
    const uint32_t reduced_min_exponent =
        f32_exponent_bias - reduced_exponent_bias;

    // Do we overflow or underflow?
    // (Both comparisons are made on the exponent field left in place at bit
    // 23, hence the `<< 23` on the thresholds.)
    llvm::Value* x_exponent = b->CreateAnd(
        x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));
    llvm::Value* x_overflows = b->CreateICmpUGT(
        x_exponent,
        llvm::ConstantInt::get(int_type, reduced_max_exponent << 23));
    llvm::Value* x_underflows = b->CreateICmpULE(
        x_exponent,
        llvm::ConstantInt::get(int_type, reduced_min_exponent << 23));

    // Compute appropriately-signed values of zero and infinity.
    llvm::Value* x_signed_zero = b->CreateAnd(
        x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask));
    llvm::Value* x_signed_inf = b->CreateOr(
        x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));

    // Force to zero or infinity if overflow or underflow.  (Note that this
    // truncates all denormal values to zero, rather than rounding them.)
    x_as_int = b->CreateSelect(x_overflows, x_signed_inf, x_as_int);
    x_as_int = b->CreateSelect(x_underflows, x_signed_zero, x_as_int);
  }

  // Cast the result back to a floating-point type.
  llvm::Value* result = b->CreateBitCast(x_as_int, float_type);

  // Correct result for NaN inputs.
  //
  // The exponent handling will "normalize" NaN values to infinities, which is
  // undesirable (except in the case with no mantissa bits, in which case it
  // is mandatory).  This logic also handles cases where mantissa-rounding
  // causes a NaN's mantissa to overflow into the exponent bits, which would
  // otherwise create an erroneous zero value.
  //
  // If the fast-math flags are set to assume no NaNs, the comparison is likely
  // to be optimized away, so there's no point in even emitting it.
  if (!b->getFastMathFlags().noNaNs()) {
    llvm::Value* x_is_nan = b->CreateFCmpUNO(x, x);

    if (mantissa_bits > 0) {
      result = b->CreateSelect(x_is_nan, x, result);
    } else {
      result = b->CreateSelect(
          x_is_nan, llvm::ConstantFP::getInfinity(float_type), result);
    }
  }
  return result;
}
172 
EmitF32ToBF16(llvm::Value * f32_value,llvm::IRBuilder<> * b)173 llvm::Value* EmitF32ToBF16(llvm::Value* f32_value, llvm::IRBuilder<>* b) {
174   auto reduced_precision = EmitReducePrecisionFloat(
175       f32_value,
176       /*exponent_bits=*/primitive_util::kBFloat16ExponentBits,
177       /*mantissa_bits=*/primitive_util::kBFloat16MantissaBits, b);
178   auto as_int32 = b->CreateBitCast(reduced_precision, b->getInt32Ty());
179   auto shifted = b->CreateLShr(as_int32, 16);
180   auto truncated = b->CreateTrunc(shifted, b->getInt16Ty());
181   return b->CreateBitCast(truncated, b->getInt16Ty());
182 }
183 
EmitBF16ToF32(llvm::Value * bf16_value,llvm::IRBuilder<> * b)184 llvm::Value* EmitBF16ToF32(llvm::Value* bf16_value, llvm::IRBuilder<>* b) {
185   auto as_int16 = b->CreateBitCast(bf16_value, b->getInt16Ty());
186   auto as_int32 = b->CreateZExt(as_int16, b->getInt32Ty());
187   auto shifted = b->CreateShl(as_int32, 16);
188   return b->CreateBitCast(shifted, b->getFloatTy());
189 }
190 
EmitIntegralToFloating(llvm::Value * integer_value,PrimitiveType from_type,PrimitiveType to_type,llvm::Module * module,llvm::IRBuilder<> * b)191 llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value,
192                                     PrimitiveType from_type,
193                                     PrimitiveType to_type, llvm::Module* module,
194                                     llvm::IRBuilder<>* b) {
195   if (primitive_util::IsSignedIntegralType(from_type)) {
196     return b->CreateSIToFP(integer_value,
197                            llvm_ir::PrimitiveTypeToIrType(to_type, module));
198   } else {
199     CHECK(primitive_util::IsUnsignedIntegralType(from_type) ||
200           from_type == PRED);
201     return b->CreateUIToFP(integer_value,
202                            llvm_ir::PrimitiveTypeToIrType(to_type, module));
203   }
204 }
205 
206 }  // namespace
207 
EmitUnaryOp(const HloInstruction * op,llvm::Value * operand_value)208 StatusOr<llvm::Value*> ElementalIrEmitter::EmitUnaryOp(
209     const HloInstruction* op, llvm::Value* operand_value) {
210   if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) ||
211       op->operand(0)->shape().element_type() == PRED) {
212     return EmitIntegerUnaryOp(op, operand_value);
213   } else if (ShapeUtil::ElementIsComplex(op->operand(0)->shape())) {
214     return EmitComplexUnaryOp(op, operand_value);
215   } else {
216     return EmitFloatUnaryOp(op, operand_value);
217   }
218 }
219 
EmitIntegerUnaryOp(const HloInstruction * op,llvm::Value * operand_value)220 StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
221     const HloInstruction* op, llvm::Value* operand_value) {
222   switch (op->opcode()) {
223     case HloOpcode::kConvert: {
224       PrimitiveType from_type = op->operand(0)->shape().element_type();
225       PrimitiveType to_type = op->shape().element_type();
226       CHECK(primitive_util::IsIntegralType(from_type) || from_type == PRED)
227           << from_type;
228       if (from_type == to_type) {
229         return operand_value;
230       }
231       if (to_type == PRED) {
232         return b_->CreateZExt(
233             ICmpNE(operand_value,
234                    llvm::ConstantInt::get(operand_value->getType(), 0)),
235             llvm_ir::PrimitiveTypeToIrType(PRED, module_));
236       }
237       if (primitive_util::IsIntegralType(to_type)) {
238         return IntCast(operand_value,
239                        llvm_ir::PrimitiveTypeToIrType(to_type, module_),
240                        primitive_util::IsSignedIntegralType(from_type));
241       }
242       if (primitive_util::IsFloatingPointType(to_type)) {
243         if (to_type == BF16) {
244           return EmitF32ToBF16(EmitIntegralToFloating(operand_value, from_type,
245                                                       F32, module_, b_),
246                                b_);
247         }
248         return EmitIntegralToFloating(operand_value, from_type, to_type,
249                                       module_, b_);
250       }
251       if (primitive_util::IsComplexType(to_type)) {
252         auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(
253             primitive_util::ComplexComponentType(to_type), module_);
254         if (primitive_util::IsSignedIntegralType(from_type)) {
255           return EmitComposeComplex(
256               op, SIToFP(operand_value, to_ir_component_type), nullptr);
257         }
258         if (primitive_util::IsUnsignedIntegralType(from_type) ||
259             from_type == PRED) {
260           return EmitComposeComplex(
261               op, UIToFP(operand_value, to_ir_component_type), nullptr);
262         }
263       }
264       return Unimplemented("conversion from primitive type %s to %s",
265                            PrimitiveType_Name(from_type),
266                            PrimitiveType_Name(to_type));
267     }
268     case HloOpcode::kBitcastConvert: {
269       PrimitiveType from_type = op->operand(0)->shape().element_type();
270       PrimitiveType to_type = op->shape().element_type();
271       CHECK(primitive_util::IsIntegralType(from_type));
272       if (from_type == to_type) {
273         return operand_value;
274       }
275       if (primitive_util::BitWidth(from_type) ==
276           primitive_util::BitWidth(to_type)) {
277         return BitCast(operand_value,
278                        llvm_ir::PrimitiveTypeToIrType(to_type, module_));
279       }
280       return InvalidArgument(
281           "bitcast conversion from primitive type %s to %s with unequal "
282           "bit-widths (%u versus %u) ",
283           PrimitiveType_Name(from_type), PrimitiveType_Name(to_type),
284           primitive_util::BitWidth(from_type),
285           primitive_util::BitWidth(to_type));
286     }
287     case HloOpcode::kAbs: {
288       bool is_signed =
289           primitive_util::IsSignedIntegralType(op->shape().element_type());
290       if (is_signed) {
291         auto type =
292             llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
293         auto cmp = ICmpSGE(operand_value, GetZero(type));
294         return Select(cmp, operand_value, Neg(operand_value));
295       } else {
296         return operand_value;
297       }
298     }
299     case HloOpcode::kClz: {
300       auto is_zero_undef = b_->getFalse();
301       return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ctlz,
302                                           {operand_value, is_zero_undef},
303                                           {operand_value->getType()}, b_);
304     }
305     case HloOpcode::kSign: {
306       CHECK(primitive_util::IsSignedIntegralType(op->shape().element_type()))
307           << op->shape().element_type();
308       auto type =
309           llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
310       auto cmp = ICmpEQ(operand_value, GetZero(type));
311       auto ashr = AShr(operand_value, type->getIntegerBitWidth() - 1);
312       return Select(cmp, GetZero(type), Or(ashr, 1));
313     }
314     case HloOpcode::kNegate:
315       return Neg(operand_value);
316     case HloOpcode::kNot: {
317       auto type = op->shape().element_type();
318       if (type == PRED) {
319         // It is not sufficient to just call CreateNot() here because a PRED
320         // is represented as an i8 and the truth value is stored only in the
321         // bottom bit.
322         return b_->CreateZExt(Not(Trunc(operand_value, b_->getInt1Ty())),
323                               llvm_ir::PrimitiveTypeToIrType(PRED, module_));
324       } else if (primitive_util::IsIntegralType(type)) {
325         return Not(operand_value);
326       }
327       return Unimplemented("unary op Not is not defined for type '%d'", type);
328     }
329     default:
330       return Unimplemented("unary integer op '%s'",
331                            HloOpcodeString(op->opcode()));
332   }
333 }
334 
EmitFloatUnaryOp(const HloInstruction * op,llvm::Value * operand_value)335 StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
336     const HloInstruction* op, llvm::Value* operand_value) {
337   switch (op->opcode()) {
338     case HloOpcode::kConvert: {
339       PrimitiveType from_type = op->operand(0)->shape().element_type();
340       PrimitiveType to_type = op->shape().element_type();
341       CHECK(primitive_util::IsFloatingPointType(from_type)) << from_type;
342       if (from_type == to_type) {
343         return operand_value;
344       }
345       if (primitive_util::IsComplexType(to_type)) {
346         PrimitiveType to_component_type =
347             primitive_util::ComplexComponentType(to_type);
348         if (from_type == to_component_type) {
349           return EmitComposeComplex(op, operand_value, nullptr);
350         }
351         return EmitComposeComplex(
352             op,
353             FPCast(operand_value,
354                    llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)),
355             nullptr);
356       }
357       if (from_type == BF16) {
358         TF_RET_CHECK(to_type != BF16);
359         operand_value = EmitBF16ToF32(operand_value, b_);
360         from_type = F32;
361         if (from_type == to_type) {
362           return operand_value;
363         }
364       }
365       if (from_type == F32 && to_type == BF16) {
366         return EmitF32ToBF16(operand_value, b_);
367       }
368       if (to_type == PRED) {
369         return b_->CreateZExt(
370             FCmpUNE(operand_value,
371                     llvm::ConstantFP::get(operand_value->getType(), 0.0)),
372             llvm_ir::PrimitiveTypeToIrType(PRED, module_));
373       }
374       if (primitive_util::IsFloatingPointType(to_type)) {
375         return FPCast(operand_value,
376                       llvm_ir::PrimitiveTypeToIrType(to_type, module_));
377       }
378       if (primitive_util::IsSignedIntegralType(to_type)) {
379         return FPToSI(operand_value,
380                       llvm_ir::PrimitiveTypeToIrType(to_type, module_));
381       }
382       if (primitive_util::IsUnsignedIntegralType(to_type)) {
383         return FPToUI(operand_value,
384                       llvm_ir::PrimitiveTypeToIrType(to_type, module_));
385       }
386       return Unimplemented("unhandled conversion operation: %s => %s",
387                            PrimitiveType_Name(from_type),
388                            PrimitiveType_Name(to_type));
389     }
390     case HloOpcode::kBitcastConvert: {
391       PrimitiveType from_type = op->operand(0)->shape().element_type();
392       PrimitiveType to_type = op->shape().element_type();
393       CHECK(primitive_util::IsFloatingPointType(from_type));
394       if (from_type == to_type) {
395         return operand_value;
396       }
397       if (primitive_util::BitWidth(from_type) ==
398           primitive_util::BitWidth(to_type)) {
399         return BitCast(operand_value,
400                        llvm_ir::PrimitiveTypeToIrType(to_type, module_));
401       }
402       return InvalidArgument(
403           "bitcast conversion from primitive type %s to %s with unequal "
404           "bit-widths (%u versus %u) ",
405           PrimitiveType_Name(from_type), PrimitiveType_Name(to_type),
406           primitive_util::BitWidth(from_type),
407           primitive_util::BitWidth(to_type));
408     }
409     case HloOpcode::kExp:
410       return EmitExp(op->shape().element_type(), operand_value);
411     case HloOpcode::kExpm1:
412       return EmitExpm1(op->shape().element_type(), operand_value);
413     case HloOpcode::kLog:
414       return EmitLog(op->shape().element_type(), operand_value);
415     case HloOpcode::kLog1p:
416       return EmitLog1p(op->shape().element_type(), operand_value);
417     case HloOpcode::kCos:
418       return EmitCos(op->shape().element_type(), operand_value);
419     case HloOpcode::kSin:
420       return EmitSin(op->shape().element_type(), operand_value);
421     case HloOpcode::kTanh:
422       return EmitTanh(op->shape().element_type(), operand_value);
423     case HloOpcode::kSqrt:
424       return EmitSqrt(op->shape().element_type(), operand_value);
425     case HloOpcode::kRsqrt:
426       return EmitRsqrt(op->shape().element_type(), operand_value);
427     case HloOpcode::kFloor:
428       return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor,
429                                           {operand_value},
430                                           {operand_value->getType()}, b_);
431     case HloOpcode::kCeil:
432       return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ceil,
433                                           {operand_value},
434                                           {operand_value->getType()}, b_);
435     case HloOpcode::kAbs:
436       return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
437                                           {operand_value},
438                                           {operand_value->getType()}, b_);
439     case HloOpcode::kRoundNearestAfz:
440       return EmitRoundNearestAfz(op->shape().element_type(), operand_value);
441     case HloOpcode::kSign: {
442       auto type = operand_value->getType();
443       auto zero = llvm::ConstantFP::get(type, 0.0);
444       auto ne0_i1 = FCmpONE(operand_value, zero);
445       auto ne0_float = UIToFP(ne0_i1, type);
446       llvm::Value* result = llvm_ir::EmitCallToIntrinsic(
447           llvm::Intrinsic::copysign, {ne0_float, operand_value},
448           {operand_value->getType()}, b_);
449       auto is_nan = FCmpUNO(operand_value, operand_value);
450       result = Select(is_nan, operand_value, result);
451       return result;
452     }
453     case HloOpcode::kIsFinite: {
454       // abs(x) o!= inf, this works because the comparison returns false if
455       // either operand is NaN.
456       auto type = operand_value->getType();
457       auto abs_value = llvm_ir::EmitCallToIntrinsic(
458           llvm::Intrinsic::fabs, {operand_value}, {type}, b_);
459       auto infinity = llvm::ConstantFP::getInfinity(type);
460       auto not_infinite = FCmpONE(abs_value, infinity);
461       return b_->CreateZExt(not_infinite,
462                             llvm_ir::PrimitiveTypeToIrType(PRED, module_));
463     }
464     case HloOpcode::kNegate:
465       return FNeg(operand_value);
466     case HloOpcode::kReal:
467       return operand_value;
468     case HloOpcode::kImag:
469       return llvm::ConstantFP::get(operand_value->getType(), 0.0);
470     default:
471       return Unimplemented("unary floating-point op '%s'",
472                            HloOpcodeString(op->opcode()));
473   }
474 }
475 
EmitComplexUnaryOp(const HloInstruction * op,llvm::Value * operand_value)476 StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
477     const HloInstruction* op, llvm::Value* operand_value) {
478   PrimitiveType input_type = op->operand(0)->shape().element_type();
479   PrimitiveType component_type =
480       primitive_util::IsComplexType(input_type)
481           ? primitive_util::ComplexComponentType(input_type)
482           : input_type;
483   switch (op->opcode()) {
484     case HloOpcode::kLog: {
485       // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a)
486       auto a = EmitExtractReal(operand_value);
487       auto b = EmitExtractImag(operand_value);
488       llvm::Type* llvm_ty = a->getType();
489       auto sum_sq = FAdd(FMul(a, a), FMul(b, b));
490       TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq));
491       TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a));
492       auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
493       return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle);
494     }
495     case HloOpcode::kLog1p: {
496       // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
497       auto a = EmitExtractReal(operand_value);
498       auto b = EmitExtractImag(operand_value);
499       llvm::Type* llvm_ty = a->getType();
500       auto one = llvm::ConstantFP::get(llvm_ty, 1.0);
501       auto a_plus_one = FAdd(a, one);
502       auto sum_sq = FAdd(FMul(a_plus_one, a_plus_one), FMul(b, b));
503       TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq));
504       TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a_plus_one));
505       auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
506       return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle);
507     }
508     case HloOpcode::kConvert: {
509       PrimitiveType from_type = op->operand(0)->shape().element_type();
510       TF_RET_CHECK(primitive_util::IsComplexType(from_type));
511       PrimitiveType to_type = op->shape().element_type();
512       TF_RET_CHECK(primitive_util::IsComplexType(to_type));
513       if (from_type == to_type) {
514         return operand_value;
515       }
516       PrimitiveType to_component_type =
517           primitive_util::ComplexComponentType(to_type);
518       auto to_ir_component_type =
519           llvm_ir::PrimitiveTypeToIrType(to_component_type, module_);
520       return EmitComposeComplex(
521           op, FPCast(EmitExtractReal(operand_value), to_ir_component_type),
522           FPCast(EmitExtractImag(operand_value), to_ir_component_type));
523     }
524     case HloOpcode::kExp: {
525       // e^(a+bi) = e^a*(cos(b)+sin(b)i)
526       TF_ASSIGN_OR_RETURN(
527           auto exp_a, EmitExp(component_type, EmitExtractReal(operand_value)));
528       TF_ASSIGN_OR_RETURN(
529           auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
530       TF_ASSIGN_OR_RETURN(
531           auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
532       return EmitComposeComplex(op, FMul(exp_a, cos_b), FMul(exp_a, sin_b));
533     }
534     case HloOpcode::kExpm1: {
535       // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
536       TF_ASSIGN_OR_RETURN(
537           auto exp_a, EmitExp(component_type, EmitExtractReal(operand_value)));
538       TF_ASSIGN_OR_RETURN(
539           auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
540       TF_ASSIGN_OR_RETURN(
541           auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
542       auto one = llvm::ConstantFP::get(exp_a->getType(), 1.0);
543       auto real_result = FSub(FMul(exp_a, cos_b), one);
544       auto imag_result = FMul(exp_a, sin_b);
545       return EmitComposeComplex(op, real_result, imag_result);
546     }
547     case HloOpcode::kCos: {
548       // cos(z) = .5(e^(iz) + e^(-iz))
549       // cos(a+bi) = .5(e^(-b+ai) + e^(b-ai))
550       // now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
551       // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(-a)+sin(-a)i))
552       // cos(-x) = cos(x) and sin(-x) = -sin(x), so
553       // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(a)-sin(a)i))
554       //           = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b))
555       auto a = EmitExtractReal(operand_value);
556       auto b = EmitExtractImag(operand_value);
557       auto type = a->getType();
558       TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b));
559       auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b);
560       auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
561       TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
562       TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
563       return EmitComposeComplex(op,
564                                 FMul(cos_a, FAdd(half_exp_neg_b, half_exp_b)),
565                                 FMul(sin_a, FSub(half_exp_neg_b, half_exp_b)));
566     }
567     case HloOpcode::kSin: {
568       // sin(z) = .5i(e^(-iz) - e^(iz))
569       // sin(a+bi) = .5i(e^(-i(a+bi)) - e^(i(a+bi)))
570       //           = .5i(e^(b-ai) - e^(-b+ai))
571       // now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
572       // sin(a+bi) = 0.5i(e^b*(cos(-a)+sin(-a)i) - e^-b*(cos(a)+sin(a)i))
573       //           = 0.5(e^b*(cos(-a)i-sin(-a)) - e^-b*(cos(a)i-sin(a)))
574       // cos(-x) = cos(x) and sin(-x) = -sin(x), so
575       //           = 0.5(e^b*(cos(a)i+sin(a)) - e^-b*(cos(a)i-sin(a)))
576       //           = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b)
577       auto a = EmitExtractReal(operand_value);
578       auto b = EmitExtractImag(operand_value);
579       auto type = a->getType();
580       TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b));
581       auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b);
582       auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
583       TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
584       TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
585       return EmitComposeComplex(op,
586                                 FMul(sin_a, FAdd(half_exp_b, half_exp_neg_b)),
587                                 FMul(cos_a, FSub(half_exp_b, half_exp_neg_b)));
588     }
589     case HloOpcode::kTanh: {
590       /*
591       tanh=(exp(x)-exp(-x)) / (exp(x)+exp(-x))
592       e^(a+bi) = e^a*(cos(b)+sin(b)i)
593       so tanh=(((cos(b)+sin(b)i)e^a - (cos(-b)+sin(-b)i)e^-a)) /
594               (((cos(b)+sin(b)i)e^a + (cos(-b)+sin(-b)i)e^-a))
595       cos(b)=cos(-b), sin(-b)=-sin(b)
596       so tanh=(((cos(b)+sin(b)i)e^a - (cos(b)-sin(b)i)e^-a)) /
597               (((cos(b)+sin(b)i)e^a + (cos(b)-sin(b)i)e^-a))
598              =(cos(b)e^a+i*sin(b)e^a + cos(b)(-e^-a)+i*sin(b)e^-a) /
599               (cos(b)e^a+i*sin(b)e^a + cos(b)e^-a+i*sin(b)(-e^-a))
600              =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) /
601               (cos(b)(e^a+e^-a) + i*sin(b)(e^a-e^-a))
602       This is a complex division, so we can multiply by denom_conj/denom_conj
603              =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) *
604               (cos(b)(e^a+e^-a) - i*sin(b)(e^a-e^-a)) /
605               ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
606              =(cos(b)^2(e^(2a)-e^(-2a)) + sin(b)^2(e^(2a)-e^(-2a)) +
607                i*(cos(b)sin(b)(e^a+e^-a)^2 - cos(b)sin(b)(e^a-e^-a)^2)) /
608               ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
609       */
610       auto a = EmitExtractReal(operand_value);
611       auto b = EmitExtractImag(operand_value);
612       TF_ASSIGN_OR_RETURN(auto exp_a, EmitExp(component_type, a));
613       TF_ASSIGN_OR_RETURN(auto cos_b, EmitCos(component_type, b));
614       TF_ASSIGN_OR_RETURN(auto sin_b, EmitSin(component_type, b));
615       auto exp_neg_a = FDiv(llvm::ConstantFP::get(exp_a->getType(), 1), exp_a);
616       auto exp_2a_minus_exp_neg_2a =
617           FSub(FMul(exp_a, exp_a), FMul(exp_neg_a, exp_neg_a));
618       auto cos_b_sq = FMul(cos_b, cos_b);
619       auto sin_b_sq = FMul(sin_b, sin_b);
620       auto real_num = FAdd(FMul(cos_b_sq, exp_2a_minus_exp_neg_2a),
621                            FMul(sin_b_sq, exp_2a_minus_exp_neg_2a));
622       auto cos_b_sin_b = FMul(cos_b, sin_b);
623       auto exp_a_plus_exp_neg_a = FAdd(exp_a, exp_neg_a);
624       auto exp_a_plus_exp_neg_a_sq =
625           FMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a);
626       auto exp_a_minus_exp_neg_a = FSub(exp_a, exp_neg_a);
627       auto exp_a_minus_exp_neg_a_sq =
628           FMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a);
629       auto imag_num = FMul(
630           cos_b_sin_b, FSub(exp_a_plus_exp_neg_a_sq, exp_a_minus_exp_neg_a_sq));
631       auto denom = FAdd(FMul(cos_b_sq, exp_a_plus_exp_neg_a_sq),
632                         FMul(sin_b_sq, exp_a_minus_exp_neg_a_sq));
633       return EmitComposeComplex(op, FDiv(real_num, denom),
634                                 FDiv(imag_num, denom));
635     }
636     case HloOpcode::kAbs: {
637       auto sum_sq = FAdd(
638           FMul(EmitExtractReal(operand_value), EmitExtractReal(operand_value)),
639           FMul(EmitExtractImag(operand_value), EmitExtractImag(operand_value)));
640       return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {sum_sq},
641                                           {sum_sq->getType()}, b_);
642     }
643     case HloOpcode::kSign: {  // Sign(c) = c / |c|
644       auto sum_sq = FAdd(
645           FMul(EmitExtractReal(operand_value), EmitExtractReal(operand_value)),
646           FMul(EmitExtractImag(operand_value), EmitExtractImag(operand_value)));
647       auto cplx_abs = llvm_ir::EmitCallToIntrinsic(
648           llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, b_);
649       auto type = cplx_abs->getType();
650       auto zero = llvm::ConstantFP::get(type, 0.0);
651       auto oeq = FCmpOEQ(cplx_abs, zero);
652       return Select(
653           oeq, EmitComposeComplex(op, zero, zero),
654           EmitComposeComplex(op, FDiv(EmitExtractReal(operand_value), cplx_abs),
655                              FDiv(EmitExtractImag(operand_value), cplx_abs)));
656     }
657     case HloOpcode::kSqrt: {
658       auto a = EmitExtractReal(operand_value);
659       auto b = EmitExtractImag(operand_value);
660       auto c = llvm::ConstantFP::get(a->getType(), 0.5);
661       auto d = llvm::ConstantFP::get(b->getType(), 0.0);
662       return EmitComplexPower(op, a, b, c, d);
663     }
664     case HloOpcode::kRsqrt: {
665       auto a = EmitExtractReal(operand_value);
666       auto b = EmitExtractImag(operand_value);
667       auto c = llvm::ConstantFP::get(a->getType(), -0.5);
668       auto d = llvm::ConstantFP::get(b->getType(), 0.0);
669       return EmitComplexPower(op, a, b, c, d);
670     }
671     case HloOpcode::kNegate:
672       return EmitComposeComplex(op, FNeg(EmitExtractReal(operand_value)),
673                                 FNeg(EmitExtractImag(operand_value)));
674     case HloOpcode::kReal:
675       return EmitExtractReal(operand_value);
676     case HloOpcode::kImag:
677       return EmitExtractImag(operand_value);
678     default:
679       return Unimplemented("unary complex op '%s'",
680                            HloOpcodeString(op->opcode()));
681   }
682 }
683 
EmitBinaryOp(const HloInstruction * op,llvm::Value * lhs_value,llvm::Value * rhs_value)684 StatusOr<llvm::Value*> ElementalIrEmitter::EmitBinaryOp(
685     const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
686   PrimitiveType operand_type = op->operand(0)->shape().element_type();
687   if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) ||
688       operand_type == PRED) {
689     return EmitIntegerBinaryOp(
690         op, lhs_value, rhs_value,
691         primitive_util::IsSignedIntegralType(operand_type));
692   } else if (primitive_util::IsComplexType(operand_type)) {
693     return EmitComplexBinaryOp(op, lhs_value, rhs_value);
694   } else {
695     return EmitFloatBinaryOp(op, lhs_value, rhs_value);
696   }
697 }
698 
EmitFloatBinaryOp(const HloInstruction * op,llvm::Value * lhs_value,llvm::Value * rhs_value)699 StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
700     const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
701   switch (op->opcode()) {
702     case HloOpcode::kComplex:
703       return EmitComposeComplex(op, lhs_value, rhs_value);
704     case HloOpcode::kAdd:
705       return FAdd(lhs_value, rhs_value);
706     case HloOpcode::kSubtract:
707       return FSub(lhs_value, rhs_value);
708     case HloOpcode::kMultiply:
709       return FMul(lhs_value, rhs_value);
710     case HloOpcode::kDivide:
711       return FDiv(lhs_value, rhs_value);
712     case HloOpcode::kRemainder:
713       return FRem(lhs_value, rhs_value);
714     // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
715     // comparisons always return false when one of the operands is NaN, whereas
716     // unordered comparisons return true.
717     //
718     // We use ordered comparisons for everything except kNe, where we use an
719     // unordered comparison.  This makes x != y equivalent to !(x == y), and
720     // matches C++'s semantics.
721     case HloOpcode::kCompare: {
722       switch (op->comparison_direction()) {
723         case ComparisonDirection::kEq:
724           return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value,
725                                          rhs_value, b_);
726         case ComparisonDirection::kNe:
727           return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value,
728                                          rhs_value, b_);
729         case ComparisonDirection::kLt:
730           return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value,
731                                          rhs_value, b_);
732         case ComparisonDirection::kGt:
733           return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value,
734                                          rhs_value, b_);
735         case ComparisonDirection::kLe:
736           return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value,
737                                          rhs_value, b_);
738         case ComparisonDirection::kGe:
739           return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value,
740                                          rhs_value, b_);
741       }
742     }
743     case HloOpcode::kMaximum:
744       return EmitFloatMax(lhs_value, rhs_value);
745     case HloOpcode::kMinimum:
746       return EmitFloatMin(lhs_value, rhs_value);
747     case HloOpcode::kPower:
748       return EmitPow(op->shape().element_type(), lhs_value, rhs_value);
749     case HloOpcode::kAtan2:
750       return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value);
751     default:
752       return Unimplemented("binary floating point op '%s'",
753                            HloOpcodeString(op->opcode()));
754   }
755 }
756 
757 // (a+bi)^(c+di) =
758 //    (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
759 //    where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
EmitComplexPower(const HloInstruction * op,llvm::Value * a,llvm::Value * b,llvm::Value * c,llvm::Value * d)760 StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexPower(
761     const HloInstruction* op, llvm::Value* a, llvm::Value* b, llvm::Value* c,
762     llvm::Value* d) {
763   PrimitiveType component_type =
764       primitive_util::ComplexComponentType(op->shape().element_type());
765   auto aa_p_bb = FAdd(FMul(a, a), FMul(b, b));
766   auto zero = llvm::ConstantFP::get(a->getType(), 0);
767   auto one_half = llvm::ConstantFP::get(a->getType(), 0.5);
768   auto one = llvm::ConstantFP::get(a->getType(), 1);
769   auto half_c = FMul(one_half, c);
770 
771   TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c,
772                       EmitPow(component_type, aa_p_bb, half_c));
773 
774   auto neg_d = FNeg(d);
775   TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a));
776   auto neg_d_arg_lhs = FMul(neg_d, arg_lhs);
777   TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs,
778                       EmitExp(component_type, neg_d_arg_lhs));
779   auto coeff = FMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs);
780   TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb));
781   auto half_d = FMul(one_half, d);
782   auto q = FAdd(FMul(c, arg_lhs), FMul(half_d, ln_aa_p_bb));
783   TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q));
784   TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q));
785   // d^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see
786   // Branch Cuts for Complex Elementary Functions or Much Ado About
787   // Nothing's Sign Bit, W. Kahan, Section 10.
788   return Select(
789       And(And(FCmpOEQ(aa_p_bb, zero), FCmpOEQ(d, zero)), FCmpOLE(zero, c)),
790       EmitComposeComplex(op, Select(FCmpOEQ(zero, c), one, zero), zero),
791       EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q)));
792 }
793 
EmitComplexBinaryOp(const HloInstruction * op,llvm::Value * lhs_value,llvm::Value * rhs_value)794 StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
795     const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
796   switch (op->opcode()) {
797     case HloOpcode::kAdd:
798       return EmitComposeComplex(
799           op, FAdd(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
800           FAdd(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
801     case HloOpcode::kSubtract:
802       return EmitComposeComplex(
803           op, FSub(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
804           FSub(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
805     case HloOpcode::kMultiply:
806       return EmitComposeComplex(
807           op,
808           FSub(FMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
809                FMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))),
810           FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)),
811                FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value))));
812     case HloOpcode::kDivide: {
813       // (a+bi) / (c+di) = ((a+bi)(c-di)) / ((c+di)(c-di))
814       // = ((ac + bd) + (bc - ad)i) / (c^2 + d^2)
815       auto rhs_sum_sq =
816           FAdd(FMul(EmitExtractReal(rhs_value), EmitExtractReal(rhs_value)),
817                FMul(EmitExtractImag(rhs_value), EmitExtractImag(rhs_value)));
818       auto type = rhs_sum_sq->getType();
819       auto zero = llvm::ConstantFP::get(type, 0.0);
820       auto oeq = FCmpOEQ(rhs_sum_sq, zero);
821       auto real_inf_or_nan = FDiv(EmitExtractReal(lhs_value), zero);
822       auto imag_inf_or_nan = FDiv(EmitExtractImag(lhs_value), zero);
823       return Select(
824           oeq, EmitComposeComplex(op, real_inf_or_nan, imag_inf_or_nan),
825           EmitComposeComplex(op,
826                              FDiv(FAdd(FMul(EmitExtractReal(lhs_value),
827                                             EmitExtractReal(rhs_value)),
828                                        FMul(EmitExtractImag(lhs_value),
829                                             EmitExtractImag(rhs_value))),
830                                   rhs_sum_sq),
831                              FDiv(FSub(FMul(EmitExtractImag(lhs_value),
832                                             EmitExtractReal(rhs_value)),
833                                        FMul(EmitExtractReal(lhs_value),
834                                             EmitExtractImag(rhs_value))),
835                                   rhs_sum_sq)));
836     }
837     // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
838     // comparisons always return false when one of the operands is NaN, whereas
839     // unordered comparisons return true.
840     //
841     // We use ordered comparisons for everything except kNe, where we use an
842     // unordered comparison.  This makes x != y equivalent to !(x == y), and
843     // matches C++'s semantics.
844     case HloOpcode::kCompare: {
845       switch (op->comparison_direction()) {
846         case ComparisonDirection::kEq:
847           return And(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
848                                              EmitExtractReal(lhs_value),
849                                              EmitExtractReal(rhs_value), b_),
850                      llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
851                                              EmitExtractImag(lhs_value),
852                                              EmitExtractImag(rhs_value), b_));
853         case ComparisonDirection::kNe:
854           return Or(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
855                                             EmitExtractReal(lhs_value),
856                                             EmitExtractReal(rhs_value), b_),
857                     llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
858                                             EmitExtractImag(lhs_value),
859                                             EmitExtractImag(rhs_value), b_));
860         default:
861           return Unimplemented(
862               "complex comparison '%s'",
863               ComparisonDirectionToString(op->comparison_direction()));
864       }
865     }
866     case HloOpcode::kPower: {
867       auto a = EmitExtractReal(lhs_value);
868       auto b = EmitExtractImag(lhs_value);
869       auto c = EmitExtractReal(rhs_value);
870       auto d = EmitExtractImag(rhs_value);
871       return EmitComplexPower(op, a, b, c, d);
872     }
873     default:
874       return Unimplemented("binary complex op '%s'",
875                            HloOpcodeString(op->opcode()));
876   }
877 }
878 
EmitFloatMax(llvm::Value * lhs_value,llvm::Value * rhs_value)879 llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
880                                               llvm::Value* rhs_value) {
881   return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_);
882 }
883 
EmitFloatMin(llvm::Value * lhs_value,llvm::Value * rhs_value)884 llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
885                                               llvm::Value* rhs_value) {
886   return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_);
887 }
888 
889 // TODO(b/123355973): We have an implementation of erfinv in math.cc.  We
890 // shouldn't have two implementations, especially since this one isn't testable
891 // (it's only observable via a normally-distributed RNG).
EmitErfInv(PrimitiveType prim_type,llvm::Value * x)892 StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
893                                                       llvm::Value* x) {
894   if (prim_type != F16 && prim_type != F32 && prim_type != F64) {
895     return Unimplemented(
896         "Inverse erf is only implemented for element "
897         "types F16, F32 and F64.");
898   }
899 
900   // Upcast half to float.
901   if (prim_type == F16) {
902     x = b_->CreateFPExt(x, b_->getFloatTy());
903   }
904 
905   auto get_float = [&](const double f) {
906     return llvm::ConstantFP::get(x->getType(), f);
907   };
908   auto multiply_add = [&](absl::Span<const double> coefficients,
909                           llvm::Value* w) {
910     llvm::Value* p = get_float(coefficients.front());
911     coefficients.remove_prefix(1);
912     for (float coefficient : coefficients) {
913       p = FAdd(FMul(p, w), get_float(coefficient));
914     }
915     return p;
916   };
917 
918   // Approximation for inverse error function from
919   //   Giles, M., "Approximating the erfinv function".
920   // The approximation has the form (float version):
921   //   w = -log((1-x)*(1+x))
922   //   if ( w < 5 ) {
923   //     w = w - 2.5
924   //     p = sum_{i=1}^n lq[i]*w^i
925   //   } else {
926   //     w = sqrt(w) - 3
927   //     p = sum_{i=1}^n gq[i]*w^i
928   //   }
929   //   return p*x
930   llvm::Function* logf_fn = llvm::Intrinsic::getDeclaration(
931       module_, llvm::Intrinsic::log, {x->getType()});
932 
933   llvm::Value* w = FNeg(Call(
934       logf_fn, {FMul(FSub(get_float(1.0f), x), FAdd(get_float(1.0f), x))}));
935 
936   llvm::Value* p_addr =
937       llvm_ir::EmitAllocaAtFunctionEntry(x->getType(), "p.addr", b_);
938 
939   if (prim_type == F16 || prim_type == F32) {
940     llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
941         FCmpOLT(w, get_float(5.0f)), "w_less_than_five", b_);
942     // Handle true BB.
943     SetToFirstInsertPoint(if_data.true_block, b_);
944     {
945       llvm::Value* lw = FSub(w, get_float(2.5f));
946       absl::Span<const double> lq{
947           2.81022636e-08f,  3.43273939e-07f, -3.5233877e-06f,
948           -4.39150654e-06f, 0.00021858087f,  -0.00125372503f,
949           -0.00417768164f,  0.246640727f,    1.50140941f};
950       llvm::Value* p = multiply_add(lq, lw);
951       Store(p, p_addr);
952     }
953 
954     // Handle false BB.
955     SetToFirstInsertPoint(if_data.false_block, b_);
956     {
957       llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
958           module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()});
959 
960       llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(3.0f));
961       absl::Span<const double> gq{
962           -0.000200214257f, 0.000100950558f, 0.00134934322f,
963           -0.00367342844f,  0.00573950773f,  -0.0076224613f,
964           0.00943887047f,   1.00167406f,     2.83297682f};
965       llvm::Value* p = multiply_add(gq, gw);
966       Store(p, p_addr);
967     }
968 
969     SetToFirstInsertPoint(if_data.after_block, b_);
970   } else {
971     DCHECK(prim_type == F64);
972 
973     llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
974         FCmpOLT(w, get_float(6.25)), "w_less_than_6.25", b_);
975 
976     SetToFirstInsertPoint(if_data.true_block, b_);
977     {
978       llvm::Value* lw = FSub(w, get_float(3.125));
979       absl::Span<const double> c{
980           -3.6444120640178196996e-21, -1.685059138182016589e-19,
981           1.2858480715256400167e-18,  1.115787767802518096e-17,
982           -1.333171662854620906e-16,  2.0972767875968561637e-17,
983           6.6376381343583238325e-15,  -4.0545662729752068639e-14,
984           -8.1519341976054721522e-14, 2.6335093153082322977e-12,
985           -1.2975133253453532498e-11, -5.4154120542946279317e-11,
986           1.051212273321532285e-09,   -4.1126339803469836976e-09,
987           -2.9070369957882005086e-08, 4.2347877827932403518e-07,
988           -1.3654692000834678645e-06, -1.3882523362786468719e-05,
989           0.0001867342080340571352,   -0.00074070253416626697512,
990           -0.0060336708714301490533,  0.24015818242558961693,
991           1.6536545626831027356};
992       llvm::Value* p = multiply_add(c, lw);
993       Store(p, p_addr);
994     }
995 
996     SetToFirstInsertPoint(if_data.false_block, b_);
997     llvm_ir::LlvmIfData if_data_second = llvm_ir::EmitIfThenElse(
998         FCmpOLT(w, get_float(16.0)), "w_less_than_16", b_);
999     SetToFirstInsertPoint(if_data_second.true_block, b_);
1000     {
1001       llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
1002           module_, llvm::Intrinsic::sqrt, {b_->getDoubleTy()});
1003 
1004       llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(3.25));
1005       absl::Span<const double> t1{
1006           2.2137376921775787049e-09,  9.0756561938885390979e-08,
1007           -2.7517406297064545428e-07, 1.8239629214389227755e-08,
1008           1.5027403968909827627e-06,  -4.013867526981545969e-06,
1009           2.9234449089955446044e-06,  1.2475304481671778723e-05,
1010           -4.7318229009055733981e-05, 6.8284851459573175448e-05,
1011           2.4031110387097893999e-05,  -0.0003550375203628474796,
1012           0.00095328937973738049703,  -0.0016882755560235047313,
1013           0.0024914420961078508066,   -0.0037512085075692412107,
1014           0.005370914553590063617,    1.0052589676941592334,
1015           3.0838856104922207635};
1016       llvm::Value* p = multiply_add(t1, gw);
1017       Store(p, p_addr);
1018     }
1019 
1020     SetToFirstInsertPoint(if_data_second.false_block, b_);
1021     {
1022       llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
1023           module_, llvm::Intrinsic::sqrt, {b_->getDoubleTy()});
1024 
1025       llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(5.0));
1026       absl::Span<const double> t2{
1027           -2.7109920616438573243e-11, -2.5556418169965252055e-10,
1028           1.5076572693500548083e-09,  -3.7894654401267369937e-09,
1029           7.6157012080783393804e-09,  -1.4960026627149240478e-08,
1030           2.9147953450901080826e-08,  -6.7711997758452339498e-08,
1031           2.2900482228026654717e-07,  -9.9298272942317002539e-07,
1032           4.5260625972231537039e-06,  -1.9681778105531670567e-05,
1033           7.5995277030017761139e-05,  -0.00021503011930044477347,
1034           -0.00013871931833623122026, 1.0103004648645343977,
1035           4.8499064014085844221};
1036       llvm::Value* p = multiply_add(t2, gw);
1037       Store(p, p_addr);
1038     }
1039 
1040     SetToFirstInsertPoint(if_data.after_block, b_);
1041   }
1042   llvm::Value* p = Load(p_addr);
1043   x = FMul(p, x);
1044   // Trunc back to half if needed.
1045   if (prim_type == F16) {
1046     x = b_->CreateFPTrunc(x, b_->getHalfTy());
1047   }
1048   return x;
1049 }
1050 
EmitErfcInv(PrimitiveType prim_type,llvm::Value * value)1051 StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(PrimitiveType prim_type,
1052                                                        llvm::Value* value) {
1053   // Compute erfcinv(value) by calculating erfinv(1.0 - value).
1054   auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
1055   auto one = llvm::ConstantFP::get(type, 1.0);
1056   return EmitErfInv(prim_type, FSub(one, value));
1057 }
1058 
EmitLog(PrimitiveType prim_type,llvm::Value * value)1059 StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type,
1060                                                    llvm::Value* value) {
1061   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {value},
1062                                       {value->getType()}, b_);
1063 }
1064 
EmitLog1p(PrimitiveType prim_type,llvm::Value * value)1065 StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type,
1066                                                      llvm::Value* value) {
1067   auto x = value;
1068   auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
1069   auto one = llvm::ConstantFP::get(type, 1.0);
1070   auto negative_half = llvm::ConstantFP::get(type, -0.5);
1071   // When x is large, the naive evaluation of ln(x + 1) is more
1072   // accurate than the Taylor series.
1073   TF_ASSIGN_OR_RETURN(auto for_large_x, EmitLog(prim_type, FAdd(x, one)));
1074   // The Taylor series for ln(x+1) is x - x^2/2 - x^3/3 + ….
1075   auto for_small_x = FMul(FAdd(FMul(negative_half, x), one), x);
1076   const auto kAntilogarithmIsSmallThreshold = 1e-4;
1077   auto abs_x =
1078       llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
1079   auto x_is_small = FCmpOLT(
1080       abs_x, llvm::ConstantFP::get(type, kAntilogarithmIsSmallThreshold));
1081   return Select(x_is_small, for_small_x, for_large_x);
1082 }
1083 
EmitSqrt(PrimitiveType prim_type,llvm::Value * value)1084 StatusOr<llvm::Value*> ElementalIrEmitter::EmitSqrt(PrimitiveType prim_type,
1085                                                     llvm::Value* value) {
1086   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {value},
1087                                       {value->getType()}, b_);
1088 }
1089 
EmitRsqrt(PrimitiveType prim_type,llvm::Value * value)1090 StatusOr<llvm::Value*> ElementalIrEmitter::EmitRsqrt(PrimitiveType prim_type,
1091                                                      llvm::Value* value) {
1092   TF_ASSIGN_OR_RETURN(auto sqrt, EmitSqrt(prim_type, value));
1093   return FDiv(llvm::ConstantFP::get(sqrt->getType(), 1.0), sqrt);
1094 }
1095 
EmitSin(PrimitiveType prim_type,llvm::Value * value)1096 StatusOr<llvm::Value*> ElementalIrEmitter::EmitSin(PrimitiveType prim_type,
1097                                                    llvm::Value* value) {
1098   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value},
1099                                       {value->getType()}, b_);
1100 }
1101 
EmitCos(PrimitiveType prim_type,llvm::Value * value)1102 StatusOr<llvm::Value*> ElementalIrEmitter::EmitCos(PrimitiveType prim_type,
1103                                                    llvm::Value* value) {
1104   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value},
1105                                       {value->getType()}, b_);
1106 }
1107 
EmitExp(PrimitiveType prim_type,llvm::Value * value)1108 StatusOr<llvm::Value*> ElementalIrEmitter::EmitExp(PrimitiveType prim_type,
1109                                                    llvm::Value* value) {
1110   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value},
1111                                       {value->getType()}, b_);
1112 }
1113 
EmitExpm1(PrimitiveType prim_type,llvm::Value * value)1114 StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
1115                                                      llvm::Value* value) {
1116   auto x = value;
1117   auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
1118   auto one = llvm::ConstantFP::get(type, 1.0);
1119   auto half = llvm::ConstantFP::get(type, 0.5);
1120   // When the exponent is large, the naive evaluation of e^(x) - 1 is more
1121   // accurate than the Taylor series.
1122   TF_ASSIGN_OR_RETURN(auto exp_x, EmitExp(prim_type, value));
1123   auto for_large_x = FSub(exp_x, one);
1124   // The Taylor series for exp(x) is 1 + x + x^2/2 + x^3/6 + ….
1125   // We want exp(x)-1 which is x + x^2/2 + x^3/6 + ….
1126   auto x_squared = FAdd(x, x);
1127   auto x_squared_over_two = FMul(x_squared, half);
1128   auto for_small_x = FAdd(x, x_squared_over_two);
1129   const auto kExponentIsSmallThreshold = 1e-5;
1130   auto abs_x =
1131       llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
1132   auto x_is_small =
1133       FCmpOLT(abs_x, llvm::ConstantFP::get(type, kExponentIsSmallThreshold));
1134   return Select(x_is_small, for_small_x, for_large_x);
1135 }
1136 
EmitRoundNearestAfz(PrimitiveType,llvm::Value * value)1137 StatusOr<llvm::Value*> ElementalIrEmitter::EmitRoundNearestAfz(
1138     PrimitiveType /*prim_type*/, llvm::Value* value) {
1139   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::round, {value},
1140                                       {value->getType()}, b_);
1141 }
1142 
EmitPow(PrimitiveType prim_type,llvm::Value * lhs,llvm::Value * rhs)1143 StatusOr<llvm::Value*> ElementalIrEmitter::EmitPow(PrimitiveType prim_type,
1144                                                    llvm::Value* lhs,
1145                                                    llvm::Value* rhs) {
1146   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs},
1147                                       {lhs->getType()}, b_);
1148 }
1149 
EmitAtan2(PrimitiveType prim_type,llvm::Value * lhs,llvm::Value * rhs)1150 StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
1151                                                      llvm::Value* lhs,
1152                                                      llvm::Value* rhs) {
1153   return Unimplemented("atan2");
1154 }
1155 
EmitTanh(PrimitiveType prim_type,llvm::Value * value)1156 StatusOr<llvm::Value*> ElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
1157                                                     llvm::Value* value) {
1158   return Unimplemented("tanh");
1159 }
1160 
EmitReducePrecision(const HloInstruction * hlo,llvm::Value * x)1161 StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
1162     const HloInstruction* hlo, llvm::Value* x) {
1163   if (hlo->operand(0)->shape().element_type() != F32) {
1164     return Unimplemented("reduce-precision only implemented for F32");
1165   }
1166   return EmitReducePrecisionFloat(x, /*exponent_bits=*/hlo->exponent_bits(),
1167                                   /*mantissa_bits=*/hlo->mantissa_bits(), b_);
1168 }
1169 
SaturateShiftIfNecessary(llvm::IRBuilder<> * b,llvm::Value * lhs,llvm::Value * rhs,llvm::Value * shift_result,bool saturate_to_sign_bit)1170 static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* b,
1171                                              llvm::Value* lhs, llvm::Value* rhs,
1172                                              llvm::Value* shift_result,
1173                                              bool saturate_to_sign_bit) {
1174   llvm::IntegerType* integer_type =
1175       llvm::cast<llvm::IntegerType>(lhs->getType());
1176   unsigned integer_bitsize = integer_type->getBitWidth();
1177   llvm::ConstantInt* integer_bitsize_constant =
1178       llvm::ConstantInt::get(integer_type, integer_bitsize);
1179   llvm::ConstantInt* zero = llvm::ConstantInt::get(integer_type, 0);
1180   llvm::ConstantInt* minus_one = llvm::ConstantInt::get(integer_type, -1);
1181   llvm::Value* saturated_value;
1182   if (saturate_to_sign_bit) {
1183     saturated_value =
1184         b->CreateSelect(b->CreateICmpSLT(lhs, zero), minus_one, zero);
1185   } else {
1186     saturated_value = zero;
1187   }
1188   llvm::Value* shift_amt_in_range =
1189       b->CreateICmpULT(rhs, integer_bitsize_constant, "shft.chk");
1190   return b->CreateSelect(shift_amt_in_range, shift_result, saturated_value);
1191 }
1192 
GetOne(llvm::Type * type)1193 llvm::Value* ElementalIrEmitter::GetOne(llvm::Type* type) {
1194   return llvm::ConstantInt::get(llvm::cast<llvm::IntegerType>(type), 1);
1195 }
1196 
GetZero(llvm::Type * type)1197 llvm::Value* ElementalIrEmitter::GetZero(llvm::Type* type) {
1198   return llvm::ConstantInt::get(llvm::cast<llvm::IntegerType>(type), 0);
1199 }
1200 
GetIntSMin(llvm::Type * type)1201 llvm::Value* ElementalIrEmitter::GetIntSMin(llvm::Type* type) {
1202   auto* integer_type = llvm::cast<llvm::IntegerType>(type);
1203   return llvm::ConstantInt::get(integer_type, llvm::APInt::getSignedMinValue(
1204                                                   integer_type->getBitWidth()));
1205 }
1206 
GetMinusOne(llvm::Type * type)1207 llvm::Value* ElementalIrEmitter::GetMinusOne(llvm::Type* type) {
1208   auto* integer_type = llvm::cast<llvm::IntegerType>(type);
1209   return llvm::ConstantInt::get(
1210       integer_type, llvm::APInt::getAllOnesValue(integer_type->getBitWidth()));
1211 }
1212 
IsZero(llvm::Value * v)1213 llvm::Value* ElementalIrEmitter::IsZero(llvm::Value* v) {
1214   return ICmpEQ(v, llvm::ConstantInt::get(v->getType(), 0));
1215 }
1216 
IsIntMinDivisionOverflow(llvm::Value * lhs,llvm::Value * rhs)1217 llvm::Value* ElementalIrEmitter::IsIntMinDivisionOverflow(llvm::Value* lhs,
1218                                                           llvm::Value* rhs) {
1219   return And(ICmpEQ(lhs, GetIntSMin(lhs->getType())),
1220              ICmpEQ(rhs, GetMinusOne(rhs->getType())));
1221 }
1222 
EmitIntegerDivide(llvm::Value * lhs,llvm::Value * rhs,bool is_signed)1223 llvm::Value* ElementalIrEmitter::EmitIntegerDivide(llvm::Value* lhs,
1224                                                    llvm::Value* rhs,
1225                                                    bool is_signed) {
1226   // Integer division overflow behavior:
1227   //
1228   // X / 0 == -1
1229   // INT_SMIN /s -1 = INT_SMIN
1230 
1231   if (!is_signed) {
1232     llvm::Value* udiv_is_unsafe = IsZero(rhs);
1233     llvm::Value* safe_rhs = Select(udiv_is_unsafe, GetOne(lhs->getType()), rhs);
1234     llvm::Value* safe_div = UDiv(lhs, safe_rhs);
1235     return Select(udiv_is_unsafe, GetMinusOne(lhs->getType()), safe_div);
1236   }
1237 
1238   llvm::Value* has_zero_divisor = IsZero(rhs);
1239   llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs);
1240   llvm::Value* sdiv_is_unsafe = Or(has_int_min_overflow, has_zero_divisor);
1241   llvm::Value* safe_rhs = Select(sdiv_is_unsafe, GetOne(lhs->getType()), rhs);
1242   llvm::Value* safe_div = SDiv(lhs, safe_rhs);
1243 
1244   return Select(
1245       has_zero_divisor, GetMinusOne(lhs->getType()),
1246       Select(has_int_min_overflow, GetIntSMin(lhs->getType()), safe_div));
1247 }
1248 
EmitIntegerRemainder(llvm::Value * lhs,llvm::Value * rhs,bool is_signed)1249 llvm::Value* ElementalIrEmitter::EmitIntegerRemainder(llvm::Value* lhs,
1250                                                       llvm::Value* rhs,
1251                                                       bool is_signed) {
1252   // Integer remainder overflow behavior:
1253   //
1254   // X % 0 == X
1255   // INT_SMIN %s -1 = 0
1256 
1257   if (!is_signed) {
1258     llvm::Value* urem_is_unsafe = IsZero(rhs);
1259     llvm::Value* safe_rhs = Select(urem_is_unsafe, GetOne(lhs->getType()), rhs);
1260     llvm::Value* safe_rem = URem(lhs, safe_rhs);
1261     return Select(urem_is_unsafe, lhs, safe_rem);
1262   }
1263 
1264   llvm::Value* has_zero_divisor = IsZero(rhs);
1265   llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs);
1266   llvm::Value* srem_is_unsafe = Or(has_int_min_overflow, has_zero_divisor);
1267   llvm::Value* safe_rhs = Select(srem_is_unsafe, GetOne(lhs->getType()), rhs);
1268   llvm::Value* safe_rem = SRem(lhs, safe_rhs);
1269 
1270   return Select(
1271       has_zero_divisor, lhs,
1272       Select(has_int_min_overflow, GetZero(lhs->getType()), safe_rem));
1273 }
1274 
EmitIntegerBinaryOp(const HloInstruction * op,llvm::Value * lhs_value,llvm::Value * rhs_value,bool is_signed)1275 StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
1276     const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value,
1277     bool is_signed) {
1278   switch (op->opcode()) {
1279     // TODO(jingyue): add the "nsw" attribute for signed types.
1280     case HloOpcode::kAdd:
1281       return Add(lhs_value, rhs_value);
1282     case HloOpcode::kSubtract:
1283       return Sub(lhs_value, rhs_value);
1284     case HloOpcode::kMultiply:
1285       return Mul(lhs_value, rhs_value);
1286     case HloOpcode::kDivide:
1287       return EmitIntegerDivide(lhs_value, rhs_value, is_signed);
1288     case HloOpcode::kRemainder:
1289       return EmitIntegerRemainder(lhs_value, rhs_value, is_signed);
1290     case HloOpcode::kCompare: {
1291       switch (op->comparison_direction()) {
1292         case ComparisonDirection::kEq:
1293           return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value,
1294                                          rhs_value, b_);
1295         case ComparisonDirection::kNe:
1296           return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value,
1297                                          rhs_value, b_);
1298         case ComparisonDirection::kLt:
1299           return llvm_ir::EmitComparison(
1300               is_signed ? llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT,
1301               lhs_value, rhs_value, b_);
1302         case ComparisonDirection::kGt:
1303           return llvm_ir::EmitComparison(
1304               is_signed ? llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT,
1305               lhs_value, rhs_value, b_);
1306         case ComparisonDirection::kLe:
1307           return llvm_ir::EmitComparison(
1308               is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE,
1309               lhs_value, rhs_value, b_);
1310         case ComparisonDirection::kGe:
1311           return llvm_ir::EmitComparison(
1312               is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE,
1313               lhs_value, rhs_value, b_);
1314       }
1315     }
1316     case HloOpcode::kMinimum:
1317       return EmitIntegralMin(lhs_value, rhs_value, is_signed);
1318     case HloOpcode::kMaximum:
1319       return EmitIntegralMax(lhs_value, rhs_value, is_signed);
1320     case HloOpcode::kAnd:
1321       return And(lhs_value, rhs_value);
1322     case HloOpcode::kOr:
1323       return Or(lhs_value, rhs_value);
1324     case HloOpcode::kXor:
1325       return Xor(lhs_value, rhs_value);
1326 
1327     // Shifting out bits >= the number of bits in the type being shifted
1328     // produces a poison value in LLVM which is basically "deferred undefined
1329     // behavior" -- doing something observable with such a value precipitates
1330     // UB.  We replace the poison value with a constant to avoid this deferred
1331     // UB.
1332     case HloOpcode::kShiftRightArithmetic:
1333       return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
1334                                       AShr(lhs_value, rhs_value),
1335                                       /*saturate_to_sign_bit=*/true);
1336     case HloOpcode::kShiftLeft:
1337       return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
1338                                       Shl(lhs_value, rhs_value),
1339                                       /*saturate_to_sign_bit=*/false);
1340     case HloOpcode::kShiftRightLogical:
1341       return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
1342                                       LShr(lhs_value, rhs_value),
1343                                       /*saturate_to_sign_bit=*/false);
1344     default:
1345       return Unimplemented("binary integer op '%s'",
1346                            HloOpcodeString(op->opcode()));
1347   }
1348 }
1349 
EmitIntegralMax(llvm::Value * lhs_value,llvm::Value * rhs_value,bool is_signed)1350 llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value,
1351                                                  llvm::Value* rhs_value,
1352                                                  bool is_signed) {
1353   return Select(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE
1354                                          : llvm::ICmpInst::ICMP_UGE,
1355                                lhs_value, rhs_value),
1356                 lhs_value, rhs_value);
1357 }
1358 
EmitIntegralMin(llvm::Value * lhs_value,llvm::Value * rhs_value,bool is_signed)1359 llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value,
1360                                                  llvm::Value* rhs_value,
1361                                                  bool is_signed) {
1362   return Select(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE
1363                                          : llvm::ICmpInst::ICMP_ULE,
1364                                lhs_value, rhs_value),
1365                 lhs_value, rhs_value);
1366 }
1367 
ConvertValueForDistribution(const HloInstruction * hlo,const ElementalIrEmitter::HloToElementGeneratorMap & operand_to_generator,const llvm_ir::IrArray::Index & index,llvm::Value * raw_value)1368 StatusOr<llvm::Value*> ElementalIrEmitter::ConvertValueForDistribution(
1369     const HloInstruction* hlo,
1370     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
1371     const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) {
1372   TF_ASSIGN_OR_RETURN(llvm::Value * a_or_mean,
1373                       operand_to_generator.at(hlo->operand(0))(index));
1374   TF_ASSIGN_OR_RETURN(llvm::Value * b_or_sigma,
1375                       operand_to_generator.at(hlo->operand(1))(index));
1376   PrimitiveType elem_prim_ty = hlo->shape().element_type();
1377   llvm::Type* elem_ir_ty =
1378       llvm_ir::PrimitiveTypeToIrType(elem_prim_ty, module_);
1379   llvm::Type* raw_value_ty = raw_value->getType();
1380 
1381   // If we're generating a floating-point value, convert the raw integer R (i.e.
1382   // `raw_value`) to a float in the range [0, 1).
1383   //
1384   // The basic approach is to choose a significand and exponent such that the
1385   // significand is uniformly distributed and the exponent is distributed, well,
1386   // exponentially (it's more likely to be close to 0 than far from 0).
1387   //
1388   // An easy way to do this is to say that the significand is the first S bits
1389   // of R, and the exponent is determined by the number of trailing zeroes in R,
1390   // exp = 2^-(cttz(R) + 1).  (+1 because the largest exponent should be -1;
1391   // this way the largest value we can return is 1.999... * 2^-1 = 1-ε.)
1392   //
1393   // This results in a small bias.  Namely, if R has enough trailing zeroes, the
1394   // significand and exponent will "overlap".  As a concrete example, consider
1395   //
1396   //         20 X's                 12 zeroes
1397   //   R = 0bXXXXXXXXXXXXXXXXXXXX000000000000
1398   //
1399   // Here the exponent is 2^-13 because R has 12 trailing zeroes.  The
1400   // significand is made up of the first 23 most-significant bits of R, which we
1401   // observe contain 3 zeroes.  This is biased because any random value with
1402   // exponent 2^-12 will have a significand which ends in `000`.
1403   //
1404   // For f32s, this problem occurs only when there are more than 32-23 = 9
1405   // trailing zeros, which happens with probability 0.5^10 = ~0.1%. Moreover the
1406   // probability of a large bias (i.e. many trailing 0s in the significand) is
1407   // exponentially low.  So we deem this acceptable.
1408   llvm::Value* elem_value = raw_value;
1409   if (elem_ir_ty->isFloatingPointTy()) {
1410     const auto& dest_flt_semantics = elem_ir_ty->getFltSemantics();
1411     const int bits = raw_value_ty->getPrimitiveSizeInBits();
1412     CHECK_GE(bits, llvm::APFloat::semanticsSizeInBits(dest_flt_semantics));
1413 
1414     // Subtract 1 because semanticsPrecision includes the "hidden bit", i.e. the
1415     // implicit "1." at the beginning of the significand.
1416     const int significand_bits =
1417         llvm::APFloat::semanticsPrecision(dest_flt_semantics) - 1;
1418 
1419     llvm::Value* cttz = llvm_ir::EmitCallToIntrinsic(
1420         llvm::Intrinsic::cttz, {raw_value, /*is_zero_undef=*/b_->getFalse()},
1421         {raw_value->getType()}, b_);
1422     llvm::Value* significand = LShr(raw_value, bits - significand_bits);
1423 
1424     // Exponent bias is -127 for f32, meaning that if the exponent is E and the
1425     // significand is S, then the value of the number is 2^(E - 127) * (1.S).
1426     //
1427     // We want cttz == 0 to correspond to 2^-1, so our exponent is computed as
1428     // E = 126 - cttz.
1429     //
1430     // For f64, this is all the same, except the bias is -1023.
1431     //
1432     // In IEEE floating point, the absolute value of the exponent bias equals
1433     // the value of the largest possible exponent.
1434     const int bias = -llvm::APFloat::semanticsMaxExponent(dest_flt_semantics);
1435     llvm::Value* exponent =
1436         Sub(llvm::ConstantInt::get(cttz->getType(), -bias - 1), cttz);
1437 
1438     // Now just slot everything into place!  The `Trunc` is here because
1439     // raw_value may be larger than our float destination.
1440     elem_value =
1441         BitCast(Trunc(Or(Shl(exponent, significand_bits), significand),
1442                       b_->getIntNTy(elem_ir_ty->getPrimitiveSizeInBits())),
1443                 elem_ir_ty);
1444   }
1445 
1446   // Convert the value for the requested distribution.
1447   switch (hlo->random_distribution()) {
1448     case RNG_UNIFORM: {
1449       if (elem_ir_ty->isFloatingPointTy()) {
1450         return FAdd(FMul(FSub(b_or_sigma, a_or_mean), elem_value), a_or_mean);
1451       } else {
1452         // To generate a uniform random value in [a, b) from a raw random sample
1453         // in range [0, 2^N), we let range = b - a and return
1454         // (a + raw_value % range). If range is not a power of 2, raw values
1455         // larger than (2^N - 2^N % range) are biased toward results in
1456         // [a, a + (limit % range)). An unbiased algorithm would need to drop
1457         // raw values and re-sample, but we don't do this because re-sampling in
1458         // an efficient way is complex, and it's not clear that users need it.
1459         // In particular, if one thread in a GPU warp needs to re-sample, we pay
1460         // the same cost as if the whole warp were to re-sample.  So an
1461         // efficient re-sampling implementation on GPU would need to do
1462         // nontrivial work to share entropy between threads in the warp.
1463         auto range = Sub(b_or_sigma, a_or_mean);
1464         return Add(a_or_mean, URem(elem_value, range));
1465       }
1466     }
1467     case RNG_NORMAL: {
1468       TF_ASSIGN_OR_RETURN(
1469           llvm::Value * r,
1470           EmitErfcInv(elem_prim_ty, FMul(llvm::ConstantFP::get(elem_ir_ty, 2.0),
1471                                          elem_value)));
1472       return FAdd(FMul(r, b_or_sigma), a_or_mean);
1473     }
1474     default:
1475       return InvalidArgument(
1476           "unhandled distribution %s",
1477           RandomDistribution_Name(hlo->random_distribution()));
1478   }
1479 }
1480 
1481 namespace {
1482 
1483 // Checks that the primitive type is supported by the elemental IR emitter for
1484 // Philox RNG and returns the number of elements in each 128 bit sample of the
1485 // Philox RNG algorithm.
GetNumberOfElementsPerPhiloxRngSample(PrimitiveType elem_prim_ty)1486 int32 GetNumberOfElementsPerPhiloxRngSample(PrimitiveType elem_prim_ty) {
1487   // Calculate the number of elements, that is the number of random numbers, in
1488   // a 128 bit sample.
1489   switch (elem_prim_ty) {
1490     case U32:
1491     case S32:
1492     case F32:
1493     // The algorithm uses 32 bits to generate values for F16.
1494     case F16:
1495       return 4;
1496     case U64:
1497     case S64:
1498     case F64:
1499       return 2;
1500     default:
1501       // BF16 is converted to F16 by the hlo pass HloElementTypeConverter.
1502       // Other data types are not supported by XLA random operation.
1503       LOG(FATAL) << "Unrecognized primitive type for RNG " << elem_prim_ty;
1504   }
1505   return 0;
1506 }
1507 
// Calculates the four uint32 values for the 128-bit Philox sample.
//
// Emits (through `b`) ten Philox-4x32 rounds over a four-word counter:
//   - counter[0..1] come from `sample_idx` (a 32-bit index is zero-padded in
//     counter[1]; a 64-bit index is split into two words),
//   - counter[2..3] come from `rng_state ^ global_random_number`.
// The two-word key comes from `hlo_random_value`. Returns the four counter
// words after the final round; these are the sample's random bits.
//
// Presumably `rng_state`, `global_random_number` and `hlo_random_value` are
// 64-bit integers (they are passed to SplitInt64ToInt32s) -- confirm at the
// call sites.
std::array<llvm::Value*, 4> CalculateSampleValues(
    llvm::Value* sample_idx, llvm::Value* hlo_random_value,
    llvm::Value* global_random_number, llvm::Value* rng_state,
    llvm::IRBuilder<>* b) {
  llvm::Type* index_ty = sample_idx->getType();

  std::array<llvm::Value*, 4> counter_values;

  // Use the sample index to initialize counter[0] and counter[1].
  // Only 32- and 64-bit index types are supported.
  unsigned index_ty_size_in_bits = index_ty->getPrimitiveSizeInBits();
  CHECK(index_ty_size_in_bits == 32 || index_ty_size_in_bits == 64);
  if (index_ty_size_in_bits == 32) {
    counter_values[0] = sample_idx;
    counter_values[1] = b->getInt32(0);
  } else {
    std::tie(counter_values[0], counter_values[1]) =
        llvm_ir::SplitInt64ToInt32s(b, sample_idx);
  }

  // Xor the global state variable with the global random number seed and use
  // the result to initialize counter[2] and counter[3].
  std::tie(counter_values[2], counter_values[3]) = llvm_ir::SplitInt64ToInt32s(
      b, b->CreateXor(rng_state, global_random_number));

  // The algorithm uses a 64 bit key, which is also interpreted as two uint32
  // values.
  llvm::Value* key_values[2];

  // Use a module random number to initialize the key.
  std::tie(key_values[0], key_values[1]) =
      llvm_ir::SplitInt64ToInt32s(b, hlo_random_value);

  // Prepare the constants used in the Philox RNG Algorithm.
  // philoxW32A/B are added to the key after every round; philoxM4xW32A/B are
  // the per-round multipliers.
  llvm::Value* philoxW32A = b->getInt32(0x9E3779B9);
  llvm::Value* philoxW32B = b->getInt32(0xBB67AE85);
  llvm::Value* philoxM4xW32A = b->getInt32(0xD2511F53);
  llvm::Value* philoxM4xW32B = b->getInt32(0xCD9E8D57);

  // Compute the 128 bit value for the current sample by repeating the
  // single round computation and key raising computation for ten times.
  for (int round = 0; round < 10; ++round) {
    // A single round of computation of the counter values is as follows:
    //  MultiplyHighLow(kPhiloxM4x32A, counter[0], &lo0, &hi0);
    //  MultiplyHighLow(kPhiloxM4x32B, counter[2], &lo1, &hi1);
    //  counter[0] = hi1 ^ counter[1] ^ key[0];
    //  counter[1] = lo1;
    //  counter[2] = hi0 ^ counter[3] ^ key[1];
    //  counter[3] = lo0;
    llvm::Value* lo0;
    llvm::Value* hi0;
    std::tie(lo0, hi0) =
        llvm_ir::UMulLowHigh32(b, philoxM4xW32A, counter_values[0]);
    llvm::Value* lo1;
    llvm::Value* hi1;
    std::tie(lo1, hi1) =
        llvm_ir::UMulLowHigh32(b, philoxM4xW32B, counter_values[2]);
    // counter[1] and counter[3] are read here before being overwritten below.
    counter_values[0] =
        b->CreateXor(hi1, b->CreateXor(counter_values[1], key_values[0]));
    counter_values[1] = lo1;
    counter_values[2] =
        b->CreateXor(hi0, b->CreateXor(counter_values[3], key_values[1]));
    counter_values[3] = lo0;
    // Raise the key for the next round.
    key_values[0] = b->CreateAdd(key_values[0], philoxW32A);
    key_values[1] = b->CreateAdd(key_values[1], philoxW32B);
  }

  return counter_values;
}
1577 
1578 }  // namespace
1579 
// Implements the Philox algorithm to generate random numbers in parallel.
// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
//   http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
//
// The paper presents a few variants of the Philox algorithm; we picked the
// 4x32_10 version of the algorithm for the following reasons:
//   . 4x32 uses 32-bit multiplication which is fast on GPUs.
//   . The authors recommend the 10-round variant, and TensorFlow also uses it.
//
// Returns an element generator that, for a given output index, computes the
// 128-bit Philox sample containing that element and converts the element's
// raw bits with ConvertValueForDistribution.
//
// Precondition: the RNG instruction is not fused.
llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) {
  VLOG(3) << "Using philox RNG algorithm";
  CHECK(!hlo->IsFused());
  // A random number generated by the per module random number generator.
  // This ensures that each RNG HLO generates a different random sequence.
  llvm::Value* hlo_random_value = b_->getInt64(hlo->GetModule()->RandomNew64());
  // A value specified by the configuration or generated by a global random
  // number generator.
  llvm::Value* global_random_number =
      b_->getInt64(hlo_module_config_.seed() != 0 ? hlo_module_config_.seed()
                                                  : GlobalRandomValue());

  int elems_per_sample =
      GetNumberOfElementsPerPhiloxRngSample(hlo->shape().element_type());

  // Allocate stack storage for the 128 bit sample as four int32.
  // The same alloca is captured by, and reused in, every invocation of the
  // generator returned below.
  llvm::Type* int32_ty = b_->getInt32Ty();
  llvm::Value* sample_address = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
      int32_ty, /*element_count=*/b_->getInt32(4), "sample", b_);

  // Load the global state variable for the Philox RNG algorithm.
  // The load is emitted once, at the builder's current insert point, when the
  // generator is created -- not per element.
  llvm::GlobalVariable* rng_state_ptr =
      llvm_ir::GetOrCreateVariableForPhiloxRngState(module_, b_);
  llvm::Value* rng_state = Load(rng_state_ptr, "rng_state_value");

  // Build and return the elemental IR generator to generate a random value for
  // the element corresponding to the current thread.
  //
  // This elemental IR generator computes one sample with multiple random
  // numbers but only returns one random number. As a result, neighboring
  // threads may calculate the same sample unnecessarily. However, if the
  // kernel containing the RNG hlo is unrolled, LLVM is able to optimize away
  // the duplicated computation of the same sample. In particular, if the unroll
  // factor is a multiple of elems_per_sample, LLVM is able to completely
  // remove such duplicated computation. If the unroll factor is a non-trivial
  // factor of elems_per_sample, LLVM can only partially remove such duplicated
  // computation.
  return [=](const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
    llvm::Type* index_ty = index.GetType();
    // Calculate the linear element index.
    llvm::Value* elem_idx = index.linear();
    if (elem_idx == nullptr) {
      elem_idx = index.Linearize(AsInt64Slice(hlo->shape().dimensions()), b_);
    }

    // Calculate the index for the 128 bit sample and the offset of the current
    // element within the sample.
    llvm::Value* elems_per_sample_value =
        llvm::ConstantInt::get(index_ty, elems_per_sample);
    llvm::Value* sample_idx = UDiv(elem_idx, elems_per_sample_value);
    llvm::Value* elem_offset = URem(elem_idx, elems_per_sample_value);

    std::array<llvm::Value*, 4> counter_values = CalculateSampleValues(
        sample_idx, hlo_random_value, global_random_number, rng_state, b_);

    // Store the four counter_values into the sample_address alloca so we can
    // load the elem_offset'th one below.
    for (int idx = 0; idx < 4; ++idx) {
      Store(counter_values[idx],
            InBoundsGEP(sample_address, b_->getInt32(idx)));
    }

    // Two elements per sample means 64-bit raw values; four means 32-bit.
    llvm::Type* int64_ty = b_->getInt64Ty();
    CHECK(elems_per_sample == 2 || elems_per_sample == 4);
    llvm::Type* raw_value_ty = elems_per_sample == 2 ? int64_ty : int32_ty;
    // Retrieve the raw value for the current element from the current sample.
    // NOTE(review): the alloca was created with element type i32; when
    // raw_value_ty is i64 it is reinterpreted via PointerCast -- confirm the
    // alloca's alignment is sufficient for the i64 load.
    llvm::Value* raw_elem_value = Load(
        InBoundsGEP(PointerCast(sample_address, raw_value_ty->getPointerTo()),
                    elem_offset),
        "raw_elem_value");

    return ConvertValueForDistribution(hlo, operand_to_generator, index,
                                       raw_elem_value);
  };
}
1667 
EmitElementalSelect(const HloInstruction * hlo,const ElementalIrEmitter::HloToElementGeneratorMap & operand_to_generator,const llvm_ir::IrArray::Index & index)1668 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalSelect(
1669     const HloInstruction* hlo,
1670     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
1671     const llvm_ir::IrArray::Index& index) {
1672   TF_ASSIGN_OR_RETURN(llvm::Value * pred_value,
1673                       operand_to_generator.at(hlo->operand(0))(index));
1674   TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value,
1675                       operand_to_generator.at(hlo->operand(1))(index));
1676   TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value,
1677                       operand_to_generator.at(hlo->operand(2))(index));
1678   return Select(Trunc(pred_value, b_->getInt1Ty()), on_true_value,
1679                 on_false_value);
1680 }
1681 
EmitElementalClamp(const HloInstruction * hlo,const ElementalIrEmitter::HloToElementGeneratorMap & operand_to_generator,const llvm_ir::IrArray::Index & index)1682 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalClamp(
1683     const HloInstruction* hlo,
1684     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
1685     const llvm_ir::IrArray::Index& index) {
1686   TF_ASSIGN_OR_RETURN(llvm::Value * min_value,
1687                       operand_to_generator.at(hlo->operand(0))(index));
1688   TF_ASSIGN_OR_RETURN(llvm::Value * arg_value,
1689                       operand_to_generator.at(hlo->operand(1))(index));
1690   TF_ASSIGN_OR_RETURN(llvm::Value * max_value,
1691                       operand_to_generator.at(hlo->operand(2))(index));
1692   PrimitiveType prim_type = hlo->shape().element_type();
1693   if (primitive_util::IsFloatingPointType(prim_type)) {
1694     return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value));
1695   } else if (primitive_util::IsIntegralType(prim_type)) {
1696     bool is_signed = primitive_util::IsSignedIntegralType(prim_type);
1697     return EmitIntegralMin(
1698         max_value, EmitIntegralMax(min_value, arg_value, is_signed), is_signed);
1699   } else {
1700     return Unimplemented("Clamp unimplemented for %s",
1701                          PrimitiveType_Name(prim_type));
1702   }
1703 }
1704 
// Emits IR that produces the element of the concatenation `hlo` at
// `target_index`.
//
// Control flow emitted:
//   init_block: a chain of comparisons walks the operands in order. For each
//     operand it checks whether the index along `concat_dim` falls inside that
//     operand; if not, it subtracts the operand's extent and moves on.
//   one block per *distinct* operand: a PHI receives the adjusted index along
//     `concat_dim`, the operand's generator is invoked, and control branches
//     to the merge block.
//   exit ("merge") block: a PHI (`output`) collects the generated value.
// After all operands, the final fall-through block is marked unreachable.
// On return the builder is positioned in the merge block.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& target_index) {
  const int64 concat_dim = hlo->dimensions(0);
  auto source_index = target_index;

  llvm::BasicBlock* init_block = b_->GetInsertBlock();

  // A terminator should be present iff we're emitting code
  // into the middle (as opposed to the end) of a basic block.
  CHECK_EQ(b_->GetInsertPoint() == init_block->end(),
           init_block->getTerminator() == nullptr);

  // Create the merge block. When emitting mid-block, the current block is
  // split at the insert point, and the unconditional branch that
  // splitBasicBlock adds is removed so the comparison chain below can
  // terminate init_block instead.
  llvm::BasicBlock* exit_block;
  if (b_->GetInsertPoint() == init_block->end()) {
    exit_block = llvm_ir::CreateBasicBlock(
        /*insert_before=*/nullptr, IrName(hlo, "merge"), b_);
  } else {
    exit_block =
        init_block->splitBasicBlock(b_->GetInsertPoint(), IrName(hlo, "merge"));
    init_block->getTerminator()->eraseFromParent();
  }

  // The result PHI lives at the top of the merge block; remember the position
  // after it so the builder can be restored there at the end.
  llvm_ir::SetToFirstInsertPoint(exit_block, b_);
  llvm::PHINode* output =
      PHI(llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
          hlo->operands().size());
  auto prior_insert_point = b_->GetInsertPoint();

  b_->SetInsertPoint(init_block);

  // Assign a unique id for each *different* operand, and count how often each
  // operand is used. If all operands are different, the usage count will be 1
  // for each operand.
  absl::flat_hash_map<const HloInstruction*, int64> to_unique_operand_id;
  std::vector<int64> operand_usage_count;
  for (const auto* operand : hlo->operands()) {
    if (to_unique_operand_id.contains(operand)) {
      ++operand_usage_count[to_unique_operand_id[operand]];
    } else {
      int64 unique_operand_id = to_unique_operand_id.size();
      to_unique_operand_id[operand] = unique_operand_id;
      operand_usage_count.push_back(1);
    }
  }

  // To avoid that we emit the same operand more than once, we create one basic
  // block for each *different* operand with a PHI node for the different source
  // index inputs.
  std::vector<llvm::BasicBlock*> emit_operand_blocks(
      to_unique_operand_id.size(), nullptr);
  std::vector<llvm::PHINode*> source_index_phis(to_unique_operand_id.size(),
                                                nullptr);
  for (const auto* operand : hlo->operands()) {
    int64 operand_id = to_unique_operand_id[operand];
    if (emit_operand_blocks[operand_id] != nullptr) {
      continue;
    }

    emit_operand_blocks[operand_id] = llvm_ir::CreateBasicBlock(
        exit_block, StrCat("concat_index_from_operand_id", operand_id), b_);
    auto saved_insert_point = b_->GetInsertPoint();
    llvm_ir::SetToFirstInsertPoint(emit_operand_blocks[operand_id], b_);
    // One incoming edge per use of this operand in the concatenation.
    source_index_phis[operand_id] =
        PHI(source_index.GetType(), operand_usage_count[operand_id]);
    std::vector<llvm::Value*> operand_multi_index = source_index.multidim();
    operand_multi_index[concat_dim] = source_index_phis[operand_id];

    // Create the terminator of the block before calling operand generators,
    // because they require non-degenerate basic blocks.
    b_->SetInsertPoint(llvm::BranchInst::Create(
        exit_block, /*InsertAtEnd=*/emit_operand_blocks[operand_id]));
    llvm_ir::IrArray::Index operand_index(operand_multi_index, operand->shape(),
                                          source_index.GetType());
    TF_ASSIGN_OR_RETURN(llvm::Value * value,
                        operand_to_generator.at(operand)(operand_index));
    // The generator may have created new blocks, so the incoming edge is from
    // whichever block the builder ended up in.
    output->addIncoming(value, b_->GetInsertBlock());
    b_->SetInsertPoint(init_block, saved_insert_point);
  }

  // Emit the comparison chain: for each operand (in concatenation order),
  // branch to its block if the running index along concat_dim is in range,
  // otherwise subtract the operand's extent and test the next operand.
  std::vector<llvm::Value*> source_multi_index = source_index.multidim();
  for (int64 operand_idx = 0; operand_idx < hlo->operand_count();
       ++operand_idx) {
    const HloInstruction* operand = hlo->operand(operand_idx);
    auto false_block = llvm_ir::CreateBasicBlock(
        exit_block, StrCat("concat_index_not_from_operand", operand_idx), b_);
    auto concat_dim_size = source_index.GetConstantWithIndexType(
        operand->shape().dimensions(concat_dim));
    int64 operand_id = to_unique_operand_id[operand];
    source_index_phis[operand_id]->addIncoming(source_multi_index[concat_dim],
                                               b_->GetInsertBlock());
    CondBr(ICmpULT(source_multi_index[concat_dim], concat_dim_size),
           emit_operand_blocks[operand_id], false_block);

    // Subtract the size of the concat dimension of the current operand
    // from the source index.
    b_->SetInsertPoint(false_block);
    source_multi_index[concat_dim] =
        Sub(source_multi_index[concat_dim], concat_dim_size);
  }

  // An in-bounds target index always matches some operand above.
  Unreachable();
  b_->SetInsertPoint(exit_block, prior_insert_point);
  return output;
}
1811 
EmitElementalDynamicSlice(const HloInstruction * hlo,const ElementalIrEmitter::HloToElementGeneratorMap & operand_to_generator,const llvm_ir::IrArray::Index & index)1812 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
1813     const HloInstruction* hlo,
1814     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
1815     const llvm_ir::IrArray::Index& index) {
1816   // Emit IR to read dynamic start indices from hlo->operand(1).
1817   const HloInstruction* input_hlo = hlo->operand(0);
1818   const int64 rank = input_hlo->shape().rank();
1819   // Use the same index type for all tensor accesses in the same kernel.
1820   llvm::Type* index_type = index.GetType();
1821   std::vector<llvm::Value*> slice_start_multi_index(rank);
1822   for (int64 i = 0; i < rank; ++i) {
1823     auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
1824       return llvm::ConstantInt::get(index_type, c);
1825     };
1826     llvm_ir::IrArray::Index zero_index(index_type);
1827     TF_ASSIGN_OR_RETURN(
1828         llvm::Value * start_index_value,
1829         operand_to_generator.at(hlo->operand(1 + i))(zero_index));
1830 
1831     // Clamp the start index so that the sliced portion fits in the operand:
1832     // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size)
1833     start_index_value = SExtOrTrunc(start_index_value, index_type);
1834     int64 largest_valid_start_index =
1835         input_hlo->shape().dimensions(i) - hlo->shape().dimensions(i);
1836     CHECK_GE(largest_valid_start_index, 0);
1837 
1838     bool is_signed = ShapeUtil::ElementIsSigned(hlo->operand(1)->shape());
1839     start_index_value = EmitIntegralMin(
1840         index_typed_const(largest_valid_start_index),
1841         EmitIntegralMax(index_typed_const(0), start_index_value, is_signed),
1842         is_signed);
1843 
1844     start_index_value->setName(IrName(hlo, StrCat("start_idx", i)));
1845     slice_start_multi_index[i] = start_index_value;
1846   }
1847 
1848   std::vector<llvm::Value*> input_multi_index(rank);
1849   for (int64 i = 0; i < rank; ++i) {
1850     // Emit IR which computes:
1851     //   input_index = start_index + offset_index
1852     input_multi_index[i] = Add(slice_start_multi_index[i], index[i]);
1853   }
1854   llvm_ir::IrArray::Index input_index(input_multi_index, input_hlo->shape(),
1855                                       index_type);
1856   return operand_to_generator.at(input_hlo)(input_index);
1857 }
1858 
EmitElementalGather(const HloInstruction * hlo,const ElementalIrEmitter::HloToElementGeneratorMap & operand_to_generator,const llvm_ir::IrArray::Index & index)1859 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
1860     const HloInstruction* hlo,
1861     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
1862     const llvm_ir::IrArray::Index& index) {
1863   const Shape& operand_shape = hlo->operand(0)->shape();
1864   const Shape& indices_shape = hlo->operand(1)->shape();
1865   const Shape& output_shape = hlo->shape();
1866 
1867   const GatherDimensionNumbers& dim_numbers = hlo->gather_dimension_numbers();
1868 
1869   const llvm_ir::ElementGenerator& operand_generator =
1870       operand_to_generator.at(hlo->operand(0));
1871   const llvm_ir::ElementGenerator& indices_generator =
1872       operand_to_generator.at(hlo->operand(1));
1873 
1874   llvm::Type* index_type = index.GetType();
1875   // This is the index into `operand` that holds the element we want to
1876   // generate.
1877   std::vector<llvm::Value*> operand_multi_index;
1878 
1879   // First copy in the window indices to operand_index. Also collect a mapping
1880   // from operand dimension to output window dimension. Elided window dimensions
1881   // map to -1.
1882   std::vector<int64> operand_to_output_dim(operand_shape.dimensions_size(), -1);
1883   for (int64 i = 0, e = operand_shape.dimensions_size(), operand_index_dim = 0;
1884        i < e; i++) {
1885     if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
1886       operand_multi_index.push_back(index.GetConstantWithIndexType(0));
1887     } else {
1888       int64 output_window_dim = dim_numbers.offset_dims(operand_index_dim++);
1889       operand_to_output_dim[i] = output_window_dim;
1890       operand_multi_index.push_back(index[output_window_dim]);
1891     }
1892   }
1893 
1894   // This is the index of the index vector in the start_indices tensor.
1895   std::vector<llvm::Value*> gather_index_index_components;
1896   {
1897     for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) {
1898       if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) {
1899         gather_index_index_components.push_back(index[i]);
1900       }
1901     }
1902 
1903     if (gather_index_index_components.size() !=
1904         indices_shape.dimensions_size()) {
1905       gather_index_index_components.insert(
1906           gather_index_index_components.begin() +
1907               dim_numbers.index_vector_dim(),
1908           nullptr);
1909     }
1910   }
1911 
1912   auto add_to_operand_index = [&](llvm::Value* index_component, int64 dim) {
1913     llvm::Value* gather_dim_component_extended =
1914         SExtOrTrunc(index_component, index_type);
1915     int64 operand_dim = dim_numbers.start_index_map(dim);
1916     int64 output_dim = operand_to_output_dim[operand_dim];
1917     // If 'output_dim' is -1, it means 'operand_dim' is an elided window dim.
1918     // This means we set the iteration index to 0, so for the purpose of the
1919     // following calculations we can consider the output dimension size to be 1.
1920     int64 output_dim_size =
1921         output_dim == -1 ? 1 : output_shape.dimensions(output_dim);
1922     int64 largest_valid_start_index =
1923         operand_shape.dimensions(operand_dim) - output_dim_size;
1924     CHECK_GE(largest_valid_start_index, 0);
1925 
1926     // Clamp the gather index so that the gather region fits in the operand.
1927     // gather_dim_component_extended_inbound =
1928     //     clamp(gather_dim_component_extended, 0, largest_valid_start_index);
1929     bool is_signed = ShapeUtil::ElementIsSigned(indices_shape);
1930     auto gather_dim_component_extended_inbound = EmitIntegralMin(
1931         index.GetConstantWithIndexType(largest_valid_start_index),
1932         EmitIntegralMax(index.GetConstantWithIndexType(0),
1933                         gather_dim_component_extended, is_signed),
1934         is_signed);
1935 
1936     operand_multi_index[operand_dim] =
1937         Add(operand_multi_index[operand_dim],
1938             gather_dim_component_extended_inbound);
1939   };
1940 
1941   if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) {
1942     IrArray::Index gather_index_index(gather_index_index_components,
1943                                       indices_shape, index_type);
1944     TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
1945                         indices_generator(gather_index_index));
1946     add_to_operand_index(gather_dim_component, 0);
1947   } else {
1948     int64 index_vector_size =
1949         indices_shape.dimensions(dim_numbers.index_vector_dim());
1950     for (int64 i = 0; i < index_vector_size; i++) {
1951       gather_index_index_components[dim_numbers.index_vector_dim()] =
1952           index.GetConstantWithIndexType(i);
1953       IrArray::Index gather_index_index(gather_index_index_components,
1954                                         indices_shape, index_type);
1955       TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
1956                           indices_generator(gather_index_index));
1957       add_to_operand_index(gather_dim_component, i);
1958     }
1959   }
1960   IrArray::Index operand_index(operand_multi_index, operand_shape, index_type);
1961   return operand_generator(operand_index);
1962 }
1963 
EmitElementalDynamicUpdateSlice(const HloInstruction * hlo,const ElementalIrEmitter::HloToElementGeneratorMap & operand_to_generator,const llvm_ir::IrArray::Index & index)1964 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
1965     const HloInstruction* hlo,
1966     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
1967     const llvm_ir::IrArray::Index& index) {
1968   const HloInstruction* input_hlo = hlo->operand(0);
1969   const HloInstruction* update_hlo = hlo->operand(1);
1970   const HloInstruction* start_hlo = hlo->operand(2);
1971   // Calculate slice start/end indices.
1972   const int64 rank = input_hlo->shape().rank();
1973   std::vector<llvm::Value*> slice_start_multi_index(rank);
1974   std::vector<llvm::Value*> slice_limit_multi_index(rank);
1975   // Slice intersection gathers (ANDs) conditions on all ranks for which
1976   // 'input' is set to 'update'
1977   llvm::Value* slice_intersection = b_->getTrue();
1978 
1979   for (int64 i = 0; i < rank; ++i) {
1980     llvm::Type* index_type = index[0]->getType();
1981     auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
1982       return llvm::ConstantInt::get(index_type, c);
1983     };
1984 
1985     llvm_ir::IrArray::Index zero_index(index_type);
1986     TF_ASSIGN_OR_RETURN(
1987         llvm::Value * start_index_value,
1988         operand_to_generator.at(hlo->operand(2 + i))(zero_index));
1989 
1990     // Clamp the start index so that the update region fits in the operand.
1991     // start_index = clamp(start_index, 0, input_dim_size - update_dim_size)
1992     start_index_value = SExtOrTrunc(start_index_value, index_type);
1993     llvm::Value* update_dim_size =
1994         index_typed_const(update_hlo->shape().dimensions(i));
1995     int64 largest_valid_start_index =
1996         input_hlo->shape().dimensions(i) - update_hlo->shape().dimensions(i);
1997     CHECK_GE(largest_valid_start_index, 0);
1998 
1999     bool is_signed = ShapeUtil::ElementIsSigned(start_hlo->shape());
2000     start_index_value = EmitIntegralMin(
2001         index_typed_const(largest_valid_start_index),
2002         EmitIntegralMax(index_typed_const(0), start_index_value, is_signed),
2003         is_signed);
2004 
2005     start_index_value->setName(IrName(hlo, StrCat("start_idx", i)));
2006     slice_start_multi_index[i] = start_index_value;
2007     slice_limit_multi_index[i] =
2008         Add(slice_start_multi_index[i], update_dim_size);
2009 
2010     slice_intersection =
2011         And(slice_intersection, ICmpSGE(index[i], slice_start_multi_index[i]),
2012             "slice_intersection");
2013     slice_intersection =
2014         And(slice_intersection, ICmpSLT(index[i], slice_limit_multi_index[i]),
2015             "slice_intersection");
2016   }
2017 
2018   // Emit:
2019   // if (slice_intersection) -> return data from 'update'.
2020   // else                    -> return data from 'input'.
2021   llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
2022       llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
2023       "ret_value_addr", b_);
2024   llvm_ir::LlvmIfData if_data =
2025       llvm_ir::EmitIfThenElse(slice_intersection, "slice_intersection", b_);
2026 
2027   // Handle true BB (return data from 'update')
2028   SetToFirstInsertPoint(if_data.true_block, b_);
2029   // Compute update index for intersection case.
2030   std::vector<llvm::Value*> update_multi_index(rank);
2031   for (int64 i = 0; i < rank; ++i) {
2032     update_multi_index[i] = Sub(index[i], slice_start_multi_index[i]);
2033   }
2034   llvm_ir::IrArray::Index update_index(update_multi_index, update_hlo->shape(),
2035                                        index.GetType());
2036   TF_ASSIGN_OR_RETURN(llvm::Value * true_value,
2037                       operand_to_generator.at(update_hlo)(update_index));
2038   Store(true_value, ret_value_addr);
2039 
2040   // Handle false BB (return data from 'input')
2041   SetToFirstInsertPoint(if_data.false_block, b_);
2042   TF_ASSIGN_OR_RETURN(llvm::Value * false_value,
2043                       operand_to_generator.at(input_hlo)(index));
2044   Store(false_value, ret_value_addr);
2045 
2046   SetToFirstInsertPoint(if_data.after_block, b_);
2047   return Load(ret_value_addr);
2048 }
2049 
EmitElementalPad(const HloInstruction * hlo,const ElementalIrEmitter::HloToElementGeneratorMap & operand_to_generator,const llvm_ir::IrArray::Index & padded_index)2050 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalPad(
2051     const HloInstruction* hlo,
2052     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
2053     const llvm_ir::IrArray::Index& padded_index) {
2054   std::vector<llvm::Value*> multi_index = padded_index.multidim();
2055   llvm::Value* in_bounds = b_->getTrue();
2056   for (size_t i = 0; i < multi_index.size(); ++i) {
2057     auto index_typed_const = [=](int64 n) {
2058       return padded_index.GetConstantWithIndexType(n);
2059     };
2060     const auto& pad_dim = hlo->padding_config().dimensions(i);
2061     multi_index[i] =
2062         Sub(multi_index[i], index_typed_const(pad_dim.edge_padding_low()));
2063     in_bounds = And(in_bounds, ICmpSGE(multi_index[i], index_typed_const(0)),
2064                     "in_bounds");
2065     in_bounds =
2066         And(in_bounds,
2067             ICmpEQ(index_typed_const(0),
2068                    URem(multi_index[i],
2069                         index_typed_const(pad_dim.interior_padding() + 1))),
2070             "in_bounds");
2071     multi_index[i] =
2072         SDiv(multi_index[i], index_typed_const(pad_dim.interior_padding() + 1));
2073     in_bounds =
2074         And(in_bounds,
2075             ICmpSLT(multi_index[i],
2076                     index_typed_const(hlo->operand(0)->shape().dimensions(i))),
2077             "in_bounds");
2078   }
2079 
2080   // if (in_bounds) {
2081   //   ret_value = operand0[index];  // source
2082   // } else {
2083   //   ret_value = *operand1;        // padding
2084   // }
2085   llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
2086       llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
2087       "pad_result_addr", b_);
2088   llvm_ir::LlvmIfData if_data =
2089       llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_);
2090   SetToFirstInsertPoint(if_data.true_block, b_);
2091   llvm_ir::IrArray::Index index(multi_index, hlo->operand(0)->shape(),
2092                                 padded_index.GetType());
2093   TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
2094                       operand_to_generator.at(hlo->operand(0))(index));
2095   Store(operand_value, ret_value_addr);
2096 
2097   SetToFirstInsertPoint(if_data.false_block, b_);
2098   TF_ASSIGN_OR_RETURN(llvm::Value * padding_value,
2099                       operand_to_generator.at(hlo->operand(1))(
2100                           IrArray::Index(index.GetType())));
2101   Store(padding_value, ret_value_addr);
2102 
2103   SetToFirstInsertPoint(if_data.after_block, b_);
2104   // Don't create phi(operand_value, padding_value) here, because invoking
2105   // operand_to_generator may create new basic blocks, making the parent
2106   // of operand_value or padding_value no longer a predecessor of
2107   // if_data.after_block.
2108   return Load(ret_value_addr);
2109 }
2110 
EmitElementalDot(const HloInstruction * hlo,const ElementalIrEmitter::HloToElementGeneratorMap & operand_to_generator,const llvm_ir::IrArray::Index & dot_result_index)2111 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
2112     const HloInstruction* hlo,
2113     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
2114     const llvm_ir::IrArray::Index& dot_result_index) {
2115   auto lhs_generator = operand_to_generator.at(hlo->operand(0));
2116   auto rhs_generator = operand_to_generator.at(hlo->operand(1));
2117 
2118   const DotDimensionNumbers& dim_numbers = hlo->dot_dimension_numbers();
2119   int64 lhs_contracting_dim = dim_numbers.lhs_contracting_dimensions(0);
2120   int64 rhs_contracting_dim = dim_numbers.rhs_contracting_dimensions(0);
2121 
2122   int64 contracted_dim_size =
2123       hlo->operand(0)->shape().dimensions(lhs_contracting_dim);
2124   int64 lhs_dims = hlo->operand(0)->shape().dimensions_size();
2125   int64 rhs_dims = hlo->operand(1)->shape().dimensions_size();
2126 
2127   llvm::Type* index_type = dot_result_index[0]->getType();
2128   auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
2129     return llvm::ConstantInt::get(index_type, c);
2130   };
2131 
2132   std::unique_ptr<llvm_ir::ForLoop> inner_loop = llvm_ir::ForLoop::EmitForLoop(
2133       IrName(hlo, "inner"), index_typed_const(0),
2134       index_typed_const(contracted_dim_size), index_typed_const(1), b_);
2135 
2136   SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), b_);
2137   PrimitiveType primitive_type = hlo->shape().element_type();
2138   llvm::Type* primitive_type_llvm =
2139       llvm_ir::PrimitiveTypeToIrType(primitive_type, module_);
2140   llvm::Value* accumulator_alloca =
2141       llvm_ir::EmitAllocaAtFunctionEntry(primitive_type_llvm, "dot_acc", b_);
2142   Store(llvm::Constant::getNullValue(primitive_type_llvm), accumulator_alloca);
2143 
2144   SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), b_);
2145 
2146   // This is the inner reduction loop for a dot operation that produces
2147   // one element in the output.  If the operands to the dot operation have
2148   // shapes [A,B,C,T] and [D,T,E], the result has a shape [A,B,C,D,E].
2149   // Given an output index [a,b,c,d,e] in the result, we compute:
2150   //   sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T))
2151 
2152   std::vector<llvm::Value*> lhs_multi_index, rhs_multi_index;
2153   for (int64 i = 0; i < lhs_dims - 1; i++) {
2154     lhs_multi_index.push_back(dot_result_index[i]);
2155   }
2156   lhs_multi_index.insert(lhs_multi_index.begin() + lhs_contracting_dim,
2157                          inner_loop->GetIndVarValue());
2158   IrArray::Index lhs_index(lhs_multi_index, hlo->operand(0)->shape(),
2159                            index_type);
2160 
2161   int64 num_batch_dims = dim_numbers.rhs_batch_dimensions_size();
2162   for (int64 i = 0; i < num_batch_dims; i++) {
2163     rhs_multi_index.push_back(
2164         dot_result_index[dim_numbers.rhs_batch_dimensions(i)]);
2165   }
2166   for (int64 i = 0; i < rhs_dims - 1 - num_batch_dims; i++) {
2167     rhs_multi_index.push_back(dot_result_index[lhs_dims - 1 + i]);
2168   }
2169   rhs_multi_index.insert(rhs_multi_index.begin() + rhs_contracting_dim,
2170                          inner_loop->GetIndVarValue());
2171   IrArray::Index rhs_index(rhs_multi_index, hlo->operand(1)->shape(),
2172                            index_type);
2173 
2174   llvm::Value* current_accumulator = Load(accumulator_alloca);
2175   TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index));
2176   TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index));
2177   llvm::Value* next_accumulator;
2178   if (primitive_util::IsComplexType(primitive_type)) {
2179     llvm::Value* product_real =
2180         FSub(FMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
2181              FMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
2182     llvm::Value* product_imag =
2183         FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)),
2184              FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value)));
2185     next_accumulator = InsertValue(
2186         current_accumulator,
2187         FAdd(EmitExtractReal(current_accumulator), product_real), {0});
2188     next_accumulator = InsertValue(
2189         next_accumulator,
2190         FAdd(EmitExtractImag(current_accumulator), product_imag), {1});
2191   } else if (primitive_util::IsFloatingPointType(primitive_type)) {
2192     next_accumulator = FAdd(current_accumulator, FMul(lhs_value, rhs_value));
2193   } else {
2194     next_accumulator = Add(current_accumulator, Mul(lhs_value, rhs_value));
2195   }
2196   Store(next_accumulator, accumulator_alloca);
2197 
2198   SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), b_);
2199   return Load(accumulator_alloca);
2200 }
2201 
MakeElementGenerator(const HloInstruction * hlo,const ElementalIrEmitter::HloToElementGeneratorMap & operand_to_generator)2202 llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
2203     const HloInstruction* hlo,
2204     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) {
2205   switch (hlo->opcode()) {
2206     case HloOpcode::kAbs:
2207     case HloOpcode::kRoundNearestAfz:
2208     case HloOpcode::kCeil:
2209     case HloOpcode::kClz:
2210     case HloOpcode::kConvert:
2211     case HloOpcode::kBitcastConvert:
2212     case HloOpcode::kCos:
2213     case HloOpcode::kExp:
2214     case HloOpcode::kExpm1:
2215     case HloOpcode::kFloor:
2216     case HloOpcode::kImag:
2217     case HloOpcode::kIsFinite:
2218     case HloOpcode::kLog:
2219     case HloOpcode::kLog1p:
2220     case HloOpcode::kNegate:
2221     case HloOpcode::kNot:
2222     case HloOpcode::kReal:
2223     case HloOpcode::kRsqrt:
2224     case HloOpcode::kSign:
2225     case HloOpcode::kSin:
2226     case HloOpcode::kSqrt:
2227     case HloOpcode::kTanh:
2228       return [this, hlo, &operand_to_generator](
2229                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2230         TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
2231                             operand_to_generator.at(hlo->operand(0))(index));
2232         return EmitUnaryOp(hlo, operand_value);
2233       };
2234     case HloOpcode::kAdd:
2235     case HloOpcode::kAnd:
2236     case HloOpcode::kAtan2:
2237     case HloOpcode::kCompare:
2238     case HloOpcode::kComplex:
2239     case HloOpcode::kDivide:
2240     case HloOpcode::kMaximum:
2241     case HloOpcode::kMinimum:
2242     case HloOpcode::kMultiply:
2243     case HloOpcode::kOr:
2244     case HloOpcode::kXor:
2245     case HloOpcode::kPower:
2246     case HloOpcode::kRemainder:
2247     case HloOpcode::kShiftLeft:
2248     case HloOpcode::kShiftRightArithmetic:
2249     case HloOpcode::kShiftRightLogical:
2250     case HloOpcode::kSubtract:
2251       return [this, hlo, &operand_to_generator](
2252                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2253         const HloInstruction* lhs = hlo->operand(0);
2254         const HloInstruction* rhs = hlo->operand(1);
2255         TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value,
2256                             operand_to_generator.at(lhs)(index));
2257         TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value,
2258                             operand_to_generator.at(rhs)(index));
2259         return EmitBinaryOp(hlo, lhs_value, rhs_value);
2260       };
2261     case HloOpcode::kSelect:
2262       return [this, hlo, &operand_to_generator](
2263                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2264         return EmitElementalSelect(hlo, operand_to_generator, index);
2265       };
2266     case HloOpcode::kClamp:
2267       return [this, hlo, &operand_to_generator](
2268                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2269         return EmitElementalClamp(hlo, operand_to_generator, index);
2270       };
2271     case HloOpcode::kReducePrecision:
2272       return [this, hlo, &operand_to_generator](
2273                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2274         TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
2275                             operand_to_generator.at(hlo->operand(0))(index));
2276         return EmitReducePrecision(hlo, operand_value);
2277       };
2278     case HloOpcode::kConcatenate:
2279       return [this, hlo, &operand_to_generator](
2280                  const IrArray::Index target_index) -> StatusOr<llvm::Value*> {
2281         return EmitElementalConcatenate(hlo, operand_to_generator,
2282                                         target_index);
2283       };
2284     case HloOpcode::kReverse:
2285       return [this, hlo, &operand_to_generator](
2286                  const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
2287         const HloInstruction* operand = hlo->operand(0);
2288         std::vector<llvm::Value*> source_multi_index = target_index.multidim();
2289         for (int64 dim : hlo->dimensions()) {
2290           source_multi_index[dim] = Sub(target_index.GetConstantWithIndexType(
2291                                             hlo->shape().dimensions(dim) - 1),
2292                                         target_index[dim]);
2293         }
2294         llvm_ir::IrArray::Index source_index(
2295             source_multi_index, operand->shape(), target_index.GetType());
2296         return operand_to_generator.at(operand)(source_index);
2297       };
2298     case HloOpcode::kBroadcast:
2299       return [this, hlo, &operand_to_generator](
2300                  const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
2301         const HloInstruction* operand = hlo->operand(0);
2302         // The `dimensions` member of the broadcast instruction maps from
2303         // input dimensions to output dimensions.
2304         return operand_to_generator.at(operand)(
2305             target_index.SourceIndexOfBroadcast(hlo->shape(), operand->shape(),
2306                                                 hlo->dimensions(), b_));
2307       };
2308     case HloOpcode::kIota:
2309       return [this, hlo](
2310                  const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
2311         auto* iota = Cast<HloIotaInstruction>(hlo);
2312         PrimitiveType element_type = iota->shape().element_type();
2313         IrArray::Index elem_index =
2314             iota->shape().rank() > 1
2315                 ? target_index.SourceIndexOfBroadcast(
2316                       iota->shape(),
2317                       ShapeUtil::MakeShapeWithDescendingLayout(
2318                           element_type,
2319                           {iota->shape().dimensions(iota->iota_dimension())}),
2320                       {iota->iota_dimension()}, b_)
2321                 : target_index;
2322         llvm::Value* elem_index_linear = elem_index.linear();
2323         if (elem_index_linear == nullptr) {
2324           std::vector<int64> iota_bound = {
2325               iota->shape().dimensions(iota->iota_dimension())};
2326           elem_index_linear = elem_index.Linearize(iota_bound, b_);
2327         }
2328         Shape component_shape =
2329             ShapeUtil::ElementIsComplex(iota->shape())
2330                 ? ShapeUtil::ComplexComponentShape(iota->shape())
2331                 : iota->shape();
2332         PrimitiveType component_element_type = component_shape.element_type();
2333         llvm::Value* iota_result;
2334         if (primitive_util::IsIntegralType(component_element_type) ||
2335             component_element_type == PRED) {
2336           iota_result = b_->CreateIntCast(
2337               elem_index_linear,
2338               llvm_ir::PrimitiveTypeToIrType(component_element_type, module_),
2339               /*isSigned=*/false);
2340         } else {
2341           TF_RET_CHECK(
2342               primitive_util::IsFloatingPointType(component_element_type))
2343               << component_element_type;
2344           llvm::Type* float_ir_type;
2345           if (component_element_type == BF16) {
2346             float_ir_type = llvm_ir::PrimitiveTypeToIrType(F32, module_);
2347           } else {
2348             float_ir_type =
2349                 llvm_ir::PrimitiveTypeToIrType(component_element_type, module_);
2350           }
2351           llvm::Value* float_val =
2352               b_->CreateUIToFP(elem_index_linear, float_ir_type);
2353           if (component_element_type == BF16) {
2354             iota_result = EmitF32ToBF16(float_val, b_);
2355           } else {
2356             iota_result = float_val;
2357           }
2358         }
2359         if (ShapeUtil::ElementIsComplex(iota->shape())) {
2360           return EmitComposeComplex(iota, iota_result, nullptr);
2361         } else {
2362           return iota_result;
2363         }
2364       };
2365     case HloOpcode::kSlice:
2366       return [this, hlo, &operand_to_generator](
2367                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2368         IrArray::Index sliced_index = index.SourceIndexOfSlice(
2369             /*operand_shape=*/hlo->operand(0)->shape(),
2370             /*starts=*/hlo->slice_starts(),
2371             /*strides=*/hlo->slice_strides(), /*builder=*/b_);
2372         return operand_to_generator.at(hlo->operand(0))(sliced_index);
2373       };
2374     case HloOpcode::kDynamicSlice:
2375       return [this, hlo, &operand_to_generator](
2376                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2377         return EmitElementalDynamicSlice(hlo, operand_to_generator, index);
2378       };
2379 
2380     case HloOpcode::kGather:
2381       return [this, hlo, &operand_to_generator](
2382                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2383         return EmitElementalGather(hlo, operand_to_generator, index);
2384       };
2385     case HloOpcode::kDynamicUpdateSlice:
2386       return [this, hlo, &operand_to_generator](
2387                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2388         return EmitElementalDynamicUpdateSlice(hlo, operand_to_generator,
2389                                                index);
2390       };
2391     case HloOpcode::kBitcast:
2392       CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()),
2393                ShapeUtil::ElementsIn(hlo->operand(0)->shape()));
2394       return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
2395         const HloInstruction* operand = hlo->operand(0);
2396         return operand_to_generator.at(operand)(
2397             index.SourceIndexOfBitcast(hlo->shape(), operand->shape(), b_));
2398       };
2399     case HloOpcode::kReshape:
2400       CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()),
2401                ShapeUtil::ElementsIn(hlo->operand(0)->shape()));
2402       return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
2403         const HloInstruction* operand = hlo->operand(0);
2404         return operand_to_generator.at(operand)(
2405             index.SourceIndexOfReshape(hlo->shape(), operand->shape(), b_));
2406       };
2407     case HloOpcode::kCopy:
2408       return [hlo, &operand_to_generator](
2409                  const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
2410         IrArray::Index source_index(target_index.multidim(),
2411                                     hlo->operand(0)->shape(),
2412                                     target_index.GetType());
2413         TF_ASSIGN_OR_RETURN(
2414             llvm::Value * operand_value,
2415             operand_to_generator.at(hlo->operand(0))(source_index));
2416         return operand_value;
2417       };
2418     case HloOpcode::kTranspose:
2419       return [this, hlo,
2420               &operand_to_generator](const IrArray::Index& target_index) {
2421         return operand_to_generator.at(hlo->operand(0))(
2422             target_index.SourceIndexOfTranspose(
2423                 hlo->shape(), hlo->operand(0)->shape(), hlo->dimensions(), b_));
2424       };
2425     case HloOpcode::kRng:
2426       return MakePhiloxRngElementGenerator(hlo, operand_to_generator);
2427     case HloOpcode::kPad:
2428       return [this, hlo, &operand_to_generator](
2429                  const IrArray::Index& padded_index) -> StatusOr<llvm::Value*> {
2430         return EmitElementalPad(hlo, operand_to_generator, padded_index);
2431       };
2432 
2433     case HloOpcode::kDot:
2434       return [this, hlo,
2435               &operand_to_generator](const IrArray::Index& dot_result_index)
2436                  -> StatusOr<llvm::Value*> {
2437         return EmitElementalDot(hlo, operand_to_generator, dot_result_index);
2438       };
2439     case HloOpcode::kReplicaId:
2440       return [this, hlo](const IrArray::Index&) -> StatusOr<llvm::Value*> {
2441         if (hlo_module_config_.replica_count() != 1) {
2442           return Unimplemented("Replication is not implemented on CPU/GPU.");
2443         }
2444         llvm::Type* type = llvm_ir::PrimitiveTypeToIrType(
2445             hlo->shape().element_type(), module_);
2446         return llvm::ConstantInt::getNullValue(type);
2447       };
2448     default:
2449       return [hlo](const IrArray::Index& index) {
2450         return Unimplemented("Unhandled opcode for elemental IR emission: %s",
2451                              HloOpcodeString(hlo->opcode()));
2452       };
2453   }
2454 }
2455 
EmitExtractReal(llvm::Value * value)2456 llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) {
2457   return ExtractValue(value, {0});
2458 }
2459 
EmitExtractImag(llvm::Value * value)2460 llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) {
2461   return ExtractValue(value, {1});
2462 }
2463 
EmitComposeComplex(const HloInstruction * op,llvm::Value * real,llvm::Value * imag)2464 llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op,
2465                                                     llvm::Value* real,
2466                                                     llvm::Value* imag) {
2467   auto cplx_type =
2468       llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
2469   auto complex =
2470       InsertValue(llvm::ConstantAggregateZero::get(cplx_type), real, {0});
2471   if (imag != nullptr) {
2472     complex = InsertValue(complex, imag, {1});
2473   }
2474   return complex;
2475 }
2476 
2477 }  // namespace xla
2478