• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
17 
18 #include <algorithm>
19 #include <functional>
20 #include <memory>
21 #include <string>
22 #include <vector>
23 
24 // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
25 #include "absl/algorithm/container.h"
26 #include "absl/container/flat_hash_map.h"
27 #include "absl/strings/str_cat.h"
28 #include "llvm/IR/BasicBlock.h"
29 #include "llvm/IR/Constants.h"
30 #include "llvm/IR/Instructions.h"
31 #include "llvm/IR/Intrinsics.h"
32 #include "llvm/Support/MathExtras.h"
33 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
34 #include "tensorflow/compiler/xla/primitive_util.h"
35 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
36 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
37 #include "tensorflow/compiler/xla/service/hlo_module.h"
38 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
39 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
40 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
41 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
42 #include "tensorflow/compiler/xla/shape_util.h"
43 #include "tensorflow/compiler/xla/status_macros.h"
44 #include "tensorflow/compiler/xla/statusor.h"
45 #include "tensorflow/compiler/xla/types.h"
46 #include "tensorflow/compiler/xla/util.h"
47 #include "tensorflow/compiler/xla/window_util.h"
48 #include "tensorflow/compiler/xla/xla_data.pb.h"
49 #include "tensorflow/core/lib/random/random.h"
50 #include "tensorflow/core/platform/logging.h"
51 #include "tensorflow/core/platform/types.h"
52 
53 namespace xla {
54 
55 using absl::StrCat;
56 using llvm_ir::IrArray;
57 using llvm_ir::IrName;
58 using llvm_ir::SetToFirstInsertPoint;
59 
60 namespace {
61 
GlobalRandomValue()62 int64 GlobalRandomValue() {
63   static auto* mu = new tensorflow::mutex();
64   static std::mt19937_64 rng{42};
65   tensorflow::mutex_lock l(*mu);
66   return rng();
67 }
68 
// Emits IR that rounds the floating-point value `x` (of XLA type `src_ty`)
// to the nearest value representable with `dest_exponent_bits` exponent bits
// and `dest_mantissa_bits` mantissa bits, returning the result in the
// original (wider) IR type.  Mantissa reduction uses round-to-nearest with
// ties to even; exponent reduction forces overflow to +/-inf and underflow
// (including denormals) to +/-0.  If `quiet_nans` is true, NaN inputs have
// their quiet bit set instead of being normalized away.  Returns an
// Unimplemented error for non-floating-point `src_ty`.
StatusOr<llvm::Value*> EmitReducePrecisionIR(
    PrimitiveType src_ty, llvm::Value* x, int64 dest_exponent_bits,
    int64 dest_mantissa_bits, bool quiet_nans, llvm::IRBuilder<>* b) {
  using llvm::APInt;

  if (!primitive_util::IsFloatingPointType(src_ty)) {
    return Unimplemented(
        "ReducePrecision cannot accept non-floating-point type %s.",
        PrimitiveType_Name(src_ty));
  }

  // Integer and float types for casting and constant generation.
  llvm::Type* float_type = x->getType();
  int64 nbits = float_type->getPrimitiveSizeInBits();
  llvm::IntegerType* int_type = b->getIntNTy(nbits);

  // SignificandWidth includes the implicit extra bit.
  int src_mantissa_bits = primitive_util::SignificandWidth(src_ty) - 1;
  int src_exponent_bits = nbits - 1 - src_mantissa_bits;

  // Cast the input value to an integer for bitwise manipulation.
  llvm::Value* x_as_int = b->CreateBitCast(x, int_type);

  // Clear the sign bit, it does not participate in rounding and we will restore
  // it later.
  APInt sign_bit_mask(nbits, 1);
  sign_bit_mask <<= nbits - 1;
  llvm::Value* x_abs_bits =
      b->CreateAnd(x_as_int, llvm::ConstantInt::get(int_type, ~sign_bit_mask));

  // Mask covering the exponent field.  A magnitude strictly greater than the
  // all-ones-exponent/zero-mantissa pattern has a nonzero mantissa with an
  // all-ones exponent, i.e. is NaN.
  APInt exp_bits_mask(nbits, 1);
  exp_bits_mask = ((exp_bits_mask << src_exponent_bits) - 1)
                  << src_mantissa_bits;
  auto x_is_nan = b->CreateICmpUGT(
      x_abs_bits, llvm::ConstantInt::get(int_type, exp_bits_mask));

  if (dest_mantissa_bits < src_mantissa_bits) {
    // Last remaining mantissa bit.
    APInt last_mantissa_bit_mask(nbits, 1);
    last_mantissa_bit_mask <<= src_mantissa_bits - dest_mantissa_bits;

    // Compute rounding bias for round-to-nearest with ties to even.  This is
    // equal to a base value of 0111... plus one bit if the last remaining
    // mantissa bit is 1.
    APInt base_rounding_bias = last_mantissa_bit_mask.lshr(1) - 1;
    llvm::Value* x_last_mantissa_bit = b->CreateLShr(
        b->CreateAnd(x_as_int,
                     llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)),
        (src_mantissa_bits - dest_mantissa_bits));
    llvm::Value* x_rounding_bias =
        b->CreateAdd(x_last_mantissa_bit,
                     llvm::ConstantInt::get(int_type, base_rounding_bias));

    // Add rounding bias, and mask out truncated bits.  Note that the case
    // where adding the rounding bias overflows into the exponent bits is
    // correct; the non-masked mantissa bits will all be zero, and the
    // exponent will be incremented by one.
    APInt truncation_mask = ~(last_mantissa_bit_mask - 1);
    llvm::Value* x_rounded = b->CreateAdd(x_as_int, x_rounding_bias);
    x_rounded = b->CreateAnd(x_rounded,
                             llvm::ConstantInt::get(int_type, truncation_mask));
    if (quiet_nans) {
      // Keep NaN bit patterns untouched here; the quiet bit is set at the end.
      x_as_int = b->CreateSelect(x_is_nan, x_as_int, x_rounded);
    } else {
      x_as_int = x_rounded;
    }
  }

  if (dest_exponent_bits < src_exponent_bits) {
    // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most-
    // significant bit -- is equal to 1.0f for all exponent sizes.  Adding
    // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit-
    // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest'
    // exponent (corresponding to 0.0f).
    //
    // Thus, the f32 exponent corresponding to the highest non-infinite
    // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32
    // exponent corresponding to the lowest exponent for a bit size of n is
    // (2^7-1) - 2^(n-1)-1.
    //
    // Note that we have already checked that exponents_bits >= 1.
    APInt exponent_bias(nbits, 1);
    exponent_bias = (exponent_bias << (src_exponent_bits - 1)) - 1;

    APInt reduced_exponent_bias(nbits, 1);
    reduced_exponent_bias =
        (reduced_exponent_bias << (dest_exponent_bits - 1)) - 1;

    APInt reduced_max_exponent = exponent_bias + reduced_exponent_bias;
    APInt reduced_min_exponent = exponent_bias - reduced_exponent_bias;

    // Do we overflow or underflow?
    llvm::Value* x_exponent =
        b->CreateAnd(x_as_int, llvm::ConstantInt::get(int_type, exp_bits_mask));
    llvm::Value* x_overflows = b->CreateICmpUGT(
        x_exponent, llvm::ConstantInt::get(
                        int_type, reduced_max_exponent << src_mantissa_bits));
    llvm::Value* x_underflows = b->CreateICmpULE(
        x_exponent, llvm::ConstantInt::get(
                        int_type, reduced_min_exponent << src_mantissa_bits));

    // Compute appropriately-signed values of zero and infinity.
    llvm::Value* x_signed_zero =
        b->CreateAnd(x_as_int, llvm::ConstantInt::get(int_type, sign_bit_mask));
    llvm::Value* x_signed_inf = b->CreateOr(
        x_signed_zero, llvm::ConstantInt::get(int_type, exp_bits_mask));

    // Force to zero or infinity if overflow or underflow.  (Note that this
    // truncates all denormal values to zero, rather than rounding them.)
    x_as_int = b->CreateSelect(x_overflows, x_signed_inf, x_as_int);
    x_as_int = b->CreateSelect(x_underflows, x_signed_zero, x_as_int);
  }

  // Cast the result back to a floating-point type.
  llvm::Value* result = b->CreateBitCast(x_as_int, float_type);

  // Correct result for NaN inputs.
  //
  // The exponent handling will "normalize" NaN values to infinities, which is
  // undesirable (except in the case with no mantissa bits, in which case it
  // is mandatory).  This logic also handles cases where mantissa-rounding
  // causes a NaN's mantissa to overflow into the exponent bits, which would
  // otherwise create an erroneous zero value.

  if (dest_mantissa_bits > 0) {
    if (quiet_nans) {
      // Set the most-significant mantissa bit (the quiet bit) on NaNs.
      APInt qnan_mask(nbits, 1);
      qnan_mask <<= src_mantissa_bits - 1;
      llvm::Value* x_with_qnan_bit_set =
          b->CreateOr(x_as_int, llvm::ConstantInt::get(int_type, qnan_mask));
      x_with_qnan_bit_set = b->CreateBitCast(x_with_qnan_bit_set, float_type);
      result = b->CreateSelect(x_is_nan, x_with_qnan_bit_set, result);
    } else {
      result = b->CreateSelect(x_is_nan, x, result);
    }
  } else {
    // With no mantissa bits available, NaN cannot be represented; map it to
    // infinity.
    result = b->CreateSelect(x_is_nan,
                             llvm::ConstantFP::getInfinity(float_type), result);
  }

  return result;
}
211 
EmitF32ToBF16(llvm::Value * f32_value,llvm::IRBuilder<> * b)212 StatusOr<llvm::Value*> EmitF32ToBF16(llvm::Value* f32_value,
213                                      llvm::IRBuilder<>* b) {
214   TF_ASSIGN_OR_RETURN(
215       auto reduced_precision,
216       EmitReducePrecisionIR(
217           /*src_ty=*/F32, f32_value,
218           /*dest_exponent_bits=*/primitive_util::ExponentWidth(BF16),
219           /*dest_mantissa_bits=*/primitive_util::SignificandWidth(BF16) - 1,
220           /*quiet_nans=*/true, b));
221   auto as_int32 = b->CreateBitCast(reduced_precision, b->getInt32Ty());
222   auto shifted = b->CreateLShr(as_int32, 16);
223   auto truncated = b->CreateTrunc(shifted, b->getInt16Ty());
224   return b->CreateBitCast(truncated, b->getInt16Ty());
225 }
226 
EmitBF16ToF32(llvm::Value * bf16_value,llvm::IRBuilder<> * b)227 llvm::Value* EmitBF16ToF32(llvm::Value* bf16_value, llvm::IRBuilder<>* b) {
228   auto as_int16 = b->CreateBitCast(bf16_value, b->getInt16Ty());
229   auto as_int32 = b->CreateZExt(as_int16, b->getInt32Ty());
230   auto shifted = b->CreateShl(as_int32, 16);
231   return b->CreateBitCast(shifted, b->getFloatTy());
232 }
233 
EmitIntegralToFloating(llvm::Value * integer_value,PrimitiveType from_type,PrimitiveType to_type,llvm::Module * module,llvm::IRBuilder<> * b)234 llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value,
235                                     PrimitiveType from_type,
236                                     PrimitiveType to_type, llvm::Module* module,
237                                     llvm::IRBuilder<>* b) {
238   if (primitive_util::IsSignedIntegralType(from_type)) {
239     return b->CreateSIToFP(integer_value,
240                            llvm_ir::PrimitiveTypeToIrType(to_type, module));
241   } else {
242     CHECK(primitive_util::IsUnsignedIntegralType(from_type) ||
243           from_type == PRED);
244     return b->CreateUIToFP(integer_value,
245                            llvm_ir::PrimitiveTypeToIrType(to_type, module));
246   }
247 }
248 
249 }  // namespace
250 
EmitUnaryOp(const HloInstruction * op,llvm::Value * operand_value)251 StatusOr<llvm::Value*> ElementalIrEmitter::EmitUnaryOp(
252     const HloInstruction* op, llvm::Value* operand_value) {
253   if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) ||
254       op->operand(0)->shape().element_type() == PRED) {
255     return EmitIntegerUnaryOp(op, operand_value);
256   } else if (ShapeUtil::ElementIsComplex(op->operand(0)->shape())) {
257     return EmitComplexUnaryOp(op, operand_value);
258   } else {
259     return EmitFloatUnaryOp(op, operand_value);
260   }
261 }
262 
// Emits IR for a unary HLO op whose operand element type is integral or
// PRED.  Handles kConvert, kBitcastConvert, kAbs, kClz, kSign, kNegate,
// kNot, and kPopulationCount; any other opcode returns Unimplemented.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) {
  switch (op->opcode()) {
    case HloOpcode::kConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsIntegralType(from_type) || from_type == PRED)
          << from_type;
      if (from_type == to_type) {
        return operand_value;
      }
      if (to_type == PRED) {
        // PRED is nonzero iff the integer is nonzero.
        return b_->CreateZExt(
            ICmpNE(operand_value,
                   llvm::ConstantInt::get(operand_value->getType(), 0)),
            llvm_ir::PrimitiveTypeToIrType(PRED, module_));
      }
      if (primitive_util::IsIntegralType(to_type)) {
        // Sign-extend only when the source type is signed.
        return IntCast(operand_value,
                       llvm_ir::PrimitiveTypeToIrType(to_type, module_),
                       primitive_util::IsSignedIntegralType(from_type));
      }
      if (primitive_util::IsFloatingPointType(to_type)) {
        if (to_type == BF16) {
          // BF16 has no direct int conversion: go through F32 first.
          return EmitF32ToBF16(EmitIntegralToFloating(operand_value, from_type,
                                                      F32, module_, b_),
                               b_);
        }
        return EmitIntegralToFloating(operand_value, from_type, to_type,
                                      module_, b_);
      }
      if (primitive_util::IsComplexType(to_type)) {
        // Convert to the real component type; imaginary part becomes zero
        // (EmitComposeComplex treats a nullptr imag as zero).
        auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(
            primitive_util::ComplexComponentType(to_type), module_);
        if (primitive_util::IsSignedIntegralType(from_type)) {
          return EmitComposeComplex(
              op, SIToFP(operand_value, to_ir_component_type), nullptr);
        }
        if (primitive_util::IsUnsignedIntegralType(from_type) ||
            from_type == PRED) {
          return EmitComposeComplex(
              op, UIToFP(operand_value, to_ir_component_type), nullptr);
        }
      }
      return Unimplemented("conversion from primitive type %s to %s",
                           PrimitiveType_Name(from_type),
                           PrimitiveType_Name(to_type));
    }
    case HloOpcode::kBitcastConvert: {
      // Reinterpret the bits as the destination type; widths must match.
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsIntegralType(from_type));
      if (from_type == to_type) {
        return operand_value;
      }
      if (primitive_util::BitWidth(from_type) ==
          primitive_util::BitWidth(to_type)) {
        return BitCast(operand_value,
                       llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      return InvalidArgument(
          "bitcast conversion from primitive type %s to %s with unequal "
          "bit-widths (%u versus %u) ",
          PrimitiveType_Name(from_type), PrimitiveType_Name(to_type),
          primitive_util::BitWidth(from_type),
          primitive_util::BitWidth(to_type));
    }
    case HloOpcode::kAbs: {
      // abs(x) = x >= 0 ? x : -x for signed types; unsigned values are
      // already their own absolute value.
      bool is_signed =
          primitive_util::IsSignedIntegralType(op->shape().element_type());
      if (is_signed) {
        auto type =
            llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
        auto cmp = ICmpSGE(operand_value, GetZero(type));
        return Select(cmp, operand_value, Neg(operand_value));
      } else {
        return operand_value;
      }
    }
    case HloOpcode::kClz: {
      // Count leading zeros; is_zero_undef=false so clz(0) is well-defined
      // (returns the bit width).
      auto is_zero_undef = b_->getFalse();
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ctlz,
                                          {operand_value, is_zero_undef},
                                          {operand_value->getType()}, b_);
    }
    case HloOpcode::kSign: {
      CHECK(primitive_util::IsSignedIntegralType(op->shape().element_type()))
          << op->shape().element_type();
      auto type =
          llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
      auto cmp = ICmpEQ(operand_value, GetZero(type));
      // ashr by (width-1) yields 0 for non-negative and -1 for negative;
      // OR-ing with 1 maps those to +1 and -1 respectively.
      auto ashr = AShr(operand_value, type->getIntegerBitWidth() - 1);
      return Select(cmp, GetZero(type), Or(ashr, 1));
    }
    case HloOpcode::kNegate:
      return Neg(operand_value);
    case HloOpcode::kNot: {
      auto type = op->shape().element_type();
      if (type == PRED) {
        // It is not sufficient to just call CreateNot() here because a PRED
        // is represented as an i8 and the truth value is stored only in the
        // bottom bit.
        return b_->CreateZExt(Not(Trunc(operand_value, b_->getInt1Ty())),
                              llvm_ir::PrimitiveTypeToIrType(PRED, module_));
      } else if (primitive_util::IsIntegralType(type)) {
        return Not(operand_value);
      }
      return Unimplemented("unary op Not is not defined for type '%d'", type);
    }
    case HloOpcode::kPopulationCount: {
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ctpop,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    }
    default:
      return Unimplemented("unary integer op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}
382 
// Emits IR for a unary HLO op whose operand element type is floating-point.
// Handles kConvert (BF16 routed through F32; float->int conversions are
// clamped/saturated and NaN-safe), kBitcastConvert, and the elementwise math
// ops below; any other opcode returns Unimplemented.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) {
  switch (op->opcode()) {
    case HloOpcode::kConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsFloatingPointType(from_type)) << from_type;
      if (from_type == to_type) {
        return operand_value;
      }
      if (from_type == BF16) {
        // Widen BF16 to F32 and continue as if the source were F32.
        TF_RET_CHECK(to_type != BF16);
        operand_value = EmitBF16ToF32(operand_value, b_);
        from_type = F32;
        if (from_type == to_type) {
          return operand_value;
        }
      }
      if (primitive_util::IsComplexType(to_type)) {
        // Real part is the (possibly casted) input; imaginary part is zero
        // (EmitComposeComplex treats a nullptr imag as zero).
        PrimitiveType to_component_type =
            primitive_util::ComplexComponentType(to_type);
        if (from_type == to_component_type) {
          return EmitComposeComplex(op, operand_value, nullptr);
        }
        return EmitComposeComplex(
            op,
            FPCast(operand_value,
                   llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)),
            nullptr);
      }
      if (to_type == BF16) {
        // Cast to F32 first. Other floating point formats are not supported by
        // EmitReducePrecisionIR.
        if (from_type != F32) {
          operand_value = b_->CreateFPCast(
              operand_value, llvm_ir::PrimitiveTypeToIrType(F32, module_));
        }
        return EmitF32ToBF16(operand_value, b_);
      }
      if (to_type == PRED) {
        // PRED is true iff the float is not equal to zero (unordered compare,
        // so NaN also yields true).
        return b_->CreateZExt(
            FCmpUNE(operand_value,
                    llvm::ConstantFP::get(operand_value->getType(), 0.0)),
            llvm_ir::PrimitiveTypeToIrType(PRED, module_));
      }
      auto* to_ir_type = llvm_ir::PrimitiveTypeToIrType(to_type, module_);
      if (primitive_util::IsFloatingPointType(to_type)) {
        return FPCast(operand_value, to_ir_type);
      }
      auto* from_ir_type = llvm_ir::PrimitiveTypeToIrType(from_type, module_);
      int to_width = primitive_util::BitWidth(to_type);
      if (primitive_util::IsSignedIntegralType(to_type)) {
        // Saturating float->signed-int conversion: clamp to [INT_MIN,
        // INT_MAX], mapping NaN to 0 (plain FPToSI would be UB out of range).
        int64_t min_int = llvm::minIntN(to_width);
        int64_t max_int = llvm::maxIntN(to_width);
        auto zero_int = llvm::ConstantInt::get(to_ir_type, 0);
        auto min_value_int = llvm::ConstantInt::get(to_ir_type, min_int);
        auto max_value_int = llvm::ConstantInt::get(to_ir_type, max_int);
        auto min_value_float = llvm::ConstantFP::get(from_ir_type, min_int);
        auto max_value_float = llvm::ConstantFP::get(from_ir_type, max_int);
        auto clamped = FPToSI(operand_value,
                              llvm_ir::PrimitiveTypeToIrType(to_type, module_));
        // x <= static_cast<float>(INT_MIN) ? INT_MIN : ...
        clamped = Select(FCmpOLE(operand_value, min_value_float), min_value_int,
                         clamped);
        // x >= static_cast<float>(INT_MAX) ? INT_MAX : ...
        clamped = Select(FCmpOGE(operand_value, max_value_float), max_value_int,
                         clamped);
        // isnan(x) ? 0 : ...
        clamped =
            Select(FCmpUNO(operand_value, operand_value), zero_int, clamped);
        return clamped;
      }
      if (primitive_util::IsUnsignedIntegralType(to_type)) {
        // Saturating float->unsigned-int conversion: clamp to [0, UINT_MAX];
        // the unordered <= also sends NaN to 0.
        uint64_t min_int = 0;
        uint64_t max_int = llvm::maxUIntN(to_width);
        auto min_value_int = llvm::ConstantInt::get(to_ir_type, min_int);
        auto max_value_int = llvm::ConstantInt::get(to_ir_type, max_int);
        auto min_value_float = llvm::ConstantFP::get(from_ir_type, min_int);
        auto max_value_float = llvm::ConstantFP::get(from_ir_type, max_int);
        auto clamped = FPToUI(operand_value,
                              llvm_ir::PrimitiveTypeToIrType(to_type, module_));
        // (x <= 0.0 || isnan(x)) ? 0 : ...
        clamped = Select(FCmpULE(operand_value, min_value_float), min_value_int,
                         clamped);
        // x >= static_cast<float>(UINT_MAX) ? UINT_MAX : ...
        clamped = Select(FCmpOGE(operand_value, max_value_float), max_value_int,
                         clamped);
        return clamped;
      }
      return Unimplemented("unhandled conversion operation: %s => %s",
                           PrimitiveType_Name(from_type),
                           PrimitiveType_Name(to_type));
    }
    case HloOpcode::kBitcastConvert: {
      // Reinterpret the bits as the destination type; widths must match.
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsFloatingPointType(from_type));
      if (from_type == to_type) {
        return operand_value;
      }
      if (primitive_util::BitWidth(from_type) ==
          primitive_util::BitWidth(to_type)) {
        return BitCast(operand_value,
                       llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      return InvalidArgument(
          "bitcast conversion from primitive type %s to %s with unequal "
          "bit-widths (%u versus %u) ",
          PrimitiveType_Name(from_type), PrimitiveType_Name(to_type),
          primitive_util::BitWidth(from_type),
          primitive_util::BitWidth(to_type));
    }
    case HloOpcode::kExp:
      return EmitExp(op->shape().element_type(), operand_value, "");
    case HloOpcode::kExpm1:
      return EmitExpm1(op->shape().element_type(), operand_value);
    case HloOpcode::kLog:
      return EmitLog(op->shape().element_type(), operand_value);
    case HloOpcode::kLog1p:
      return EmitLog1p(op->shape().element_type(), operand_value);
    case HloOpcode::kCos:
      return EmitCos(op->shape().element_type(), operand_value);
    case HloOpcode::kSin:
      return EmitSin(op->shape().element_type(), operand_value);
    case HloOpcode::kTanh:
      return EmitTanh(op->shape().element_type(), operand_value);
    case HloOpcode::kSqrt:
      return EmitSqrt(op->shape().element_type(), operand_value);
    case HloOpcode::kRsqrt:
      return EmitRsqrt(op->shape().element_type(), operand_value);
    case HloOpcode::kCbrt:
      return EmitCbrt(op->shape().element_type(), operand_value);
    case HloOpcode::kFloor:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    case HloOpcode::kCeil:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ceil,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    case HloOpcode::kAbs:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    case HloOpcode::kRoundNearestAfz:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::round,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    case HloOpcode::kSign: {
      // sign(x) = copysign(x != 0 ? 1 : 0, x); NaN inputs pass through
      // unchanged via the final select.
      auto type = operand_value->getType();
      auto zero = llvm::ConstantFP::get(type, 0.0);
      auto ne0_i1 = FCmpONE(operand_value, zero);
      auto ne0_float = UIToFP(ne0_i1, type);
      llvm::Value* result = llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::copysign, {ne0_float, operand_value},
          {operand_value->getType()}, b_);
      auto is_nan = FCmpUNO(operand_value, operand_value);
      result = Select(is_nan, operand_value, result);
      return result;
    }
    case HloOpcode::kIsFinite: {
      // abs(x) o!= inf, this works because the comparison returns false if
      // either operand is NaN.
      auto type = operand_value->getType();
      auto abs_value = llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::fabs, {operand_value}, {type}, b_);
      auto infinity = llvm::ConstantFP::getInfinity(type);
      auto not_infinite = FCmpONE(abs_value, infinity);
      return b_->CreateZExt(not_infinite,
                            llvm_ir::PrimitiveTypeToIrType(PRED, module_));
    }
    case HloOpcode::kNegate:
      return FNeg(operand_value);
    case HloOpcode::kReal:
      // Real part of a real number is the number itself.
      return operand_value;
    case HloOpcode::kImag:
      // Imaginary part of a real number is zero.
      return llvm::ConstantFP::get(operand_value->getType(), 0.0);
    default:
      return Unimplemented("unary floating-point op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}
565 
EmitComplexUnaryOp(const HloInstruction * op,llvm::Value * operand_value)566 StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
567     const HloInstruction* op, llvm::Value* operand_value) {
568   PrimitiveType input_type = op->operand(0)->shape().element_type();
569   PrimitiveType component_type =
570       primitive_util::IsComplexType(input_type)
571           ? primitive_util::ComplexComponentType(input_type)
572           : input_type;
573   switch (op->opcode()) {
574     case HloOpcode::kLog: {
575       return EmitComplexLog(op, operand_value);
576     }
577     case HloOpcode::kLog1p: {
578       // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
579       // log((a+1)+bi) = .5*log(a*a + 2*a + 1 + b*b) + i*atan2(b, a+1)
580       // log((a+1)+bi) = .5*log1p(a*a + 2*a + b*b) + i*atan2(b, a+1)
581       auto a = EmitExtractReal(operand_value);
582       auto b = EmitExtractImag(operand_value);
583       llvm::Type* llvm_ty = a->getType();
584       auto one = llvm::ConstantFP::get(llvm_ty, 1.0);
585       auto two = llvm::ConstantFP::get(llvm_ty, 2.0);
586       auto a_plus_one = FAdd(a, one);
587       auto sum_sq = FAdd(FAdd(FMul(a, a), FMul(two, a)), FMul(b, b));
588       TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog1p(component_type, sum_sq));
589       TF_ASSIGN_OR_RETURN(auto angle,
590                           EmitAtan2(component_type, b, a_plus_one, ""));
591       auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
592       return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle);
593     }
594     case HloOpcode::kConvert: {
595       PrimitiveType from_type = op->operand(0)->shape().element_type();
596       TF_RET_CHECK(primitive_util::IsComplexType(from_type));
597       PrimitiveType to_type = op->shape().element_type();
598       TF_RET_CHECK(primitive_util::IsComplexType(to_type));
599       if (from_type == to_type) {
600         return operand_value;
601       }
602       PrimitiveType to_component_type =
603           primitive_util::ComplexComponentType(to_type);
604       auto to_ir_component_type =
605           llvm_ir::PrimitiveTypeToIrType(to_component_type, module_);
606       return EmitComposeComplex(
607           op, FPCast(EmitExtractReal(operand_value), to_ir_component_type),
608           FPCast(EmitExtractImag(operand_value), to_ir_component_type));
609     }
610     case HloOpcode::kExp: {
611       // e^(a+bi) = e^a*(cos(b)+sin(b)i)
612       TF_ASSIGN_OR_RETURN(
613           auto exp_a,
614           EmitExp(component_type, EmitExtractReal(operand_value), ""));
615       TF_ASSIGN_OR_RETURN(
616           auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
617       TF_ASSIGN_OR_RETURN(
618           auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
619       return EmitComposeComplex(op, FMul(exp_a, cos_b), FMul(exp_a, sin_b));
620     }
621     case HloOpcode::kExpm1: {
622       // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
623       TF_ASSIGN_OR_RETURN(
624           auto exp_a,
625           EmitExp(component_type, EmitExtractReal(operand_value), ""));
626       TF_ASSIGN_OR_RETURN(
627           auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
628       TF_ASSIGN_OR_RETURN(
629           auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
630       auto one = llvm::ConstantFP::get(exp_a->getType(), 1.0);
631       auto real_result = FSub(FMul(exp_a, cos_b), one);
632       auto imag_result = FMul(exp_a, sin_b);
633       return EmitComposeComplex(op, real_result, imag_result);
634     }
635     case HloOpcode::kCos: {
636       // cos(z) = .5(e^(iz) + e^(-iz))
637       // cos(a+bi) = .5(e^(-b+ai) + e^(b-ai))
638       // now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
639       // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(-a)+sin(-a)i))
640       // cos(-x) = cos(x) and sin(-x) = -sin(x), so
641       // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(a)-sin(a)i))
642       //           = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b))
643       auto a = EmitExtractReal(operand_value);
644       auto b = EmitExtractImag(operand_value);
645       auto type = a->getType();
646       TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b, ""));
647       auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b);
648       auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
649       TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
650       TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
651       return EmitComposeComplex(op,
652                                 FMul(cos_a, FAdd(half_exp_neg_b, half_exp_b)),
653                                 FMul(sin_a, FSub(half_exp_neg_b, half_exp_b)));
654     }
655     case HloOpcode::kSin: {
656       // sin(z) = .5i(e^(-iz) - e^(iz))
657       // sin(a+bi) = .5i(e^(-i(a+bi)) - e^(i(a+bi)))
658       //           = .5i(e^(b-ai) - e^(-b+ai))
659       // now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
660       // sin(a+bi) = 0.5i(e^b*(cos(-a)+sin(-a)i) - e^-b*(cos(a)+sin(a)i))
661       //           = 0.5(e^b*(cos(-a)i-sin(-a)) - e^-b*(cos(a)i-sin(a)))
662       // cos(-x) = cos(x) and sin(-x) = -sin(x), so
663       //           = 0.5(e^b*(cos(a)i+sin(a)) - e^-b*(cos(a)i-sin(a)))
664       //           = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b)
665       auto a = EmitExtractReal(operand_value);
666       auto b = EmitExtractImag(operand_value);
667       auto type = a->getType();
668       TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b, ""));
669       auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b);
670       auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
671       TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
672       TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
673       return EmitComposeComplex(op,
674                                 FMul(sin_a, FAdd(half_exp_b, half_exp_neg_b)),
675                                 FMul(cos_a, FSub(half_exp_b, half_exp_neg_b)));
676     }
677     case HloOpcode::kTanh: {
678       /*
679       tanh=(exp(x)-exp(-x)) / (exp(x)+exp(-x))
680       e^(a+bi) = e^a*(cos(b)+sin(b)i)
681       so tanh=(((cos(b)+sin(b)i)e^a - (cos(-b)+sin(-b)i)e^-a)) /
682               (((cos(b)+sin(b)i)e^a + (cos(-b)+sin(-b)i)e^-a))
683       cos(b)=cos(-b), sin(-b)=-sin(b)
684       so tanh=(((cos(b)+sin(b)i)e^a - (cos(b)-sin(b)i)e^-a)) /
685               (((cos(b)+sin(b)i)e^a + (cos(b)-sin(b)i)e^-a))
686              =(cos(b)e^a+i*sin(b)e^a + cos(b)(-e^-a)+i*sin(b)e^-a) /
687               (cos(b)e^a+i*sin(b)e^a + cos(b)e^-a+i*sin(b)(-e^-a))
688              =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) /
689               (cos(b)(e^a+e^-a) + i*sin(b)(e^a-e^-a))
690       This is a complex division, so we can multiply by denom_conj/denom_conj
691              =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) *
692               (cos(b)(e^a+e^-a) - i*sin(b)(e^a-e^-a)) /
693               ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
694              =(cos(b)^2(e^(2a)-e^(-2a)) + sin(b)^2(e^(2a)-e^(-2a)) +
695                i*(cos(b)sin(b)(e^a+e^-a)^2 - cos(b)sin(b)(e^a-e^-a)^2)) /
696               ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
697              =(e^(2a)-e^(-2a) +
698                i*[cos(b)sin(b)(e^(2a)+2+e^(-2a))-cos(b)sin(b)(e^(2a)-2+e^(2a)))]
699                / (cos(b)^2*(e^(2a)+2+e^(-2a)) + sin(b)^2*(e^(2a)-2+e^(2a))
700              =(e^(2a)-e^(-2a) +
701                i*cos(b)sin(b)*[e^(2a)+2+e^(-2a)-e^(2a)+2-e^(-2a)]) /
702                ([cos(b)^2 + sin(b)^2][e^(2a)+e^(-2a)])+2*[cos(b)^2 - sin(b)^2])
703              =(e^(2a)-e^(-2a) + i*cos(b)sin(b)*4) /
704               (e^(2a)+e^(-2a)+2*[cos(b)^2 - sin(b)^2])
705              =(e^(2a)-e^(-2a) + i*[sin(2b)/2]*4) /
706               (e^(2a)+e^(-2a)+2*[cos(2b)])
707              =(e^(2a)-e^(-2a) + i*2*sin(2b)) / (e^(2a) + e^(-2a) + 2*cos(2b))
708       */
709       llvm::Value* a = EmitExtractReal(operand_value);
710       llvm::Value* b = EmitExtractImag(operand_value);
711 
712       llvm::Type* type = a->getType();
713 
714       llvm::Value* neg_one = llvm::ConstantFP::get(type, -1.F);
715       llvm::Value* two_a = FAdd(a, a);
716       llvm::Value* neg_2a = FMul(neg_one, two_a);
717 
718       // When we are calculating the real numerator, e^(2a)-e^(-2a), for small
719       // values of `a`, we will get a ULP of 2^-23 using the exp function. Using
720       // expm1 to calculate e^(2a)-e^(-2a) = [e^(2a)-1] - [e^(-2a)-1] allows our
721       // ULP to be arbitrarily small. For larger values of `a`, calculating the
722       // numerator as Exp(2a)-Exp(-2a) vs Expm1(2a)-Expm1(-2a) return virtually
723       // identical results.
724       TF_ASSIGN_OR_RETURN(llvm::Value * exp_2a_m1,
725                           EmitExpm1(component_type, two_a));
726       TF_ASSIGN_OR_RETURN(llvm::Value * exp_neg_2a_m1,
727                           EmitExpm1(component_type, neg_2a));
728       llvm::Value* real_numerator = FSub(exp_2a_m1, exp_neg_2a_m1);
729 
730       // We can use the identity cos(2b)+1 = cos(b)^2-sin(b)^2+cos(b)^2+sin(b)^2
731       // = 2cos(b)^2. This gives us the ability to be more precise when the
732       // denominator is close to zero.
733       TF_ASSIGN_OR_RETURN(llvm::Value * cos_b, EmitCos(component_type, b));
734       llvm::Value* four = llvm::ConstantFP::get(type, 4.F);
735       llvm::Value* cos_b_sq = FMul(cos_b, cos_b);
736       llvm::Value* two_cos_2b_p2 = FMul(cos_b_sq, four);
737 
738       // Similarly we can compute sin(2b) with the formula sin(2b) =
739       // 2*sin(b)*cos(b).
740       TF_ASSIGN_OR_RETURN(llvm::Value * sin_b, EmitSin(component_type, b));
741       llvm::Value* imag_numerator = FMul(four, FMul(cos_b, sin_b));
742 
743       // Expm1(x) is about x for small values of x, but exp_sum_m2 is about x^2
744       // for small value of x. As a result, due to floating point precision
745       // issues, x^2 is a better approximation than Expm1(x) + Expm1(x) for
746       // small values of x.
747       llvm::Value* a_sqr = FMul(a, a);
748       llvm::Value* use_approx_cutoff = llvm::ConstantFP::get(type, 1e-8);
749       llvm::Value* use_approx = FCmpOLT(a_sqr, use_approx_cutoff);
750 
751       llvm::Value* exp_sum_m2 =
752           Select(use_approx, a_sqr, FAdd(exp_2a_m1, exp_neg_2a_m1));
753       llvm::Value* denom = FAdd(exp_sum_m2, two_cos_2b_p2);
754 
755       // As `a` grows toward +inf and -inf, the real numerator will grow towards
756       // +inf and -inf respectively, while the denominator will always grow
757       // towards +inf. The result is real_numerator/denom = NaN, when it should
758       // equal +1 and -1 respectively. Therefore, if our denominator is +inf,
759       // we just hardcode the limits for the real numbers.
760       llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
761       llvm::Value* is_inf = FCmpOEQ(exp_sum_m2, inf);
762       llvm::Value* real_limit = llvm_ir::EmitCallToIntrinsic(
763           llvm::Intrinsic::copysign, {neg_one, a}, {type}, b_);
764 
765       llvm::Value* real =
766           Select(is_inf, real_limit, FDiv(real_numerator, denom));
767       llvm::Value* imag = FDiv(imag_numerator, denom);
768 
769       // The complex tanh functions have a few corner cases:
770       // 1. (+0, +0) => (+0, +0)        - Handled normally
771       // 2. (x, +Inf) => (NaN, NaN)     - See below
772       // 3. (x, NaN) => (NaN, NaN)      - See below
773       // 4. (+inf, y) => (1, +0)        - Handled normally
774       // 5. (+Inf, +Inf) => (1, +/-0)   - See below
775       // 6. (+Inf, NaN) => (1, +/-0)    - See below
776       // 7. (NaN, +0) => (NaN, +0)      - See below
777       // 8. (NaN, y) => (NaN, NaN)      - Handled normally
778       // 9. (NaN, NaN) => (NaN, NaN)    - Handled normally
779       //
780       // For the cases that aren't handled normally:
781       // 2/3) Part of the calculation we do is that if exp(a) + exp(-a) = +inf,
782       //      then we return (+/-1, +/-0). However, this is only true if we
783       //      assume that a is infinity or b is finite. In the event that both a
784       //      is finite and b is either +/-Inf or NaN, then our normal
785       //      calculation would end up returing (+/-1, NaN), as opposed to (NaN,
786       //      NaN).
787       // 5/6) We always calculate the imaginary value as sin(2b)/denominator.
788       //      When the denominator is infinity, this assures us that the zero is
789       //      the correct sign. However if our imaginary input results in
790       //      sin(2b) = NaN, we calculate our imaginary result as NaN.
791       // 7)   In the event that a is NaN, the denominator will be NaN.
792       //      Therefore, the normal calculation gives (NaN, NaN) while we need
793       //      (NaN, +0).
794       if (!(b_->getFastMathFlags().noNaNs() &&
795             b_->getFastMathFlags().noInfs())) {
796         llvm::Value* abs_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
797                                                           {a}, {type}, b_);
798         llvm::Value* zero = llvm::ConstantFP::get(type, 0.F);
799         llvm::Value* nan = llvm::ConstantFP::getNaN(type);
800 
801         llvm::Value* a_is_inf = FCmpOEQ(abs_a, inf);
802         llvm::Value* b_is_zero = FCmpOEQ(b, zero);
803 
804         // imag_numerator = 2sin(2b), so sin(2b) is NaN if and only if
805         // imag_numerator is NaN.
806         llvm::Value* sin_2b_is_nan =
807             b_->CreateFCmpUNO(imag_numerator, imag_numerator);
808 
809         llvm::Value* real_is_nan =
810             b_->CreateAnd(sin_2b_is_nan, b_->CreateNot(a_is_inf));
811         llvm::Value* imag_is_zero =
812             b_->CreateOr(b_is_zero, b_->CreateAnd(a_is_inf, sin_2b_is_nan));
813 
814         real = Select(real_is_nan, nan, real);
815         imag = Select(imag_is_zero, zero, imag);
816       }
817 
818       return EmitComposeComplex(op, real, imag);
819     }
820     case HloOpcode::kAbs: {
821       return EmitComplexAbs(component_type, operand_value);
822     }
823     case HloOpcode::kSign: {  // Sign(c) = c / |c|
824       TF_ASSIGN_OR_RETURN(auto cplx_abs,
825                           EmitComplexAbs(component_type, operand_value));
826       auto type = cplx_abs->getType();
827       auto zero = llvm::ConstantFP::get(type, 0.0);
828       auto oeq = FCmpOEQ(cplx_abs, zero);
829       return Select(
830           oeq, EmitComposeComplex(op, zero, zero),
831           EmitComposeComplex(op, FDiv(EmitExtractReal(operand_value), cplx_abs),
832                              FDiv(EmitExtractImag(operand_value), cplx_abs)));
833     }
834     case HloOpcode::kSqrt: {
835       return EmitComplexSqrt(op, component_type, operand_value);
836     }
837     case HloOpcode::kRsqrt: {
838       return EmitComplexRsqrt(op, component_type, operand_value);
839     }
840     case HloOpcode::kCbrt: {
841       return EmitComplexCbrt(op, component_type, operand_value);
842     }
843     case HloOpcode::kNegate:
844       return EmitComposeComplex(op, FNeg(EmitExtractReal(operand_value)),
845                                 FNeg(EmitExtractImag(operand_value)));
846     case HloOpcode::kReal:
847       return EmitExtractReal(operand_value);
848     case HloOpcode::kImag:
849       return EmitExtractImag(operand_value);
850     default:
851       return Unimplemented("unary complex op '%s'",
852                            HloOpcodeString(op->opcode()));
853   }
854 }
855 
EmitBinaryOp(const HloInstruction * op,llvm::Value * lhs_value,llvm::Value * rhs_value)856 StatusOr<llvm::Value*> ElementalIrEmitter::EmitBinaryOp(
857     const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
858   PrimitiveType operand_type = op->operand(0)->shape().element_type();
859   if (operand_type == PRED) {
860     return EmitPredBinaryOp(op, lhs_value, rhs_value);
861   } else if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape())) {
862     return EmitIntegerBinaryOp(
863         op, lhs_value, rhs_value,
864         primitive_util::IsSignedIntegralType(operand_type));
865   } else if (primitive_util::IsComplexType(operand_type)) {
866     return EmitComplexBinaryOp(op, lhs_value, rhs_value);
867   } else {
868     return EmitFloatBinaryOp(op, lhs_value, rhs_value);
869   }
870 }
871 
EmitFloatBinaryOp(const HloInstruction * op,llvm::Value * lhs_value,llvm::Value * rhs_value)872 StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
873     const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
874   switch (op->opcode()) {
875     case HloOpcode::kComplex:
876       return EmitComposeComplex(op, lhs_value, rhs_value);
877     case HloOpcode::kAdd:
878       return FAdd(lhs_value, rhs_value, op->name());
879     case HloOpcode::kSubtract:
880       return FSub(lhs_value, rhs_value, op->name());
881     case HloOpcode::kMultiply:
882       return FMul(lhs_value, rhs_value, op->name());
883     case HloOpcode::kDivide:
884       return FDiv(lhs_value, rhs_value, op->name());
885     case HloOpcode::kRemainder:
886       return FRem(lhs_value, rhs_value, op->name());
887     // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
888     // comparisons always return false when one of the operands is NaN, whereas
889     // unordered comparisons return true.
890     //
891     // We use ordered comparisons for everything except kNe, where we use an
892     // unordered comparison.  This makes x != y equivalent to !(x == y), and
893     // matches C++'s semantics.
894     case HloOpcode::kCompare: {
895       switch (op->comparison_direction()) {
896         case ComparisonDirection::kEq:
897           return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value,
898                                          rhs_value, b_, op->name());
899         case ComparisonDirection::kNe:
900           return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value,
901                                          rhs_value, b_, op->name());
902         case ComparisonDirection::kLt:
903           return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value,
904                                          rhs_value, b_, op->name());
905         case ComparisonDirection::kGt:
906           return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value,
907                                          rhs_value, b_, op->name());
908         case ComparisonDirection::kLe:
909           return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value,
910                                          rhs_value, b_, op->name());
911         case ComparisonDirection::kGe:
912           return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value,
913                                          rhs_value, b_, op->name());
914       }
915     }
916     case HloOpcode::kMaximum:
917       return EmitFloatMax(lhs_value, rhs_value, op->name());
918     case HloOpcode::kMinimum:
919       return EmitFloatMin(lhs_value, rhs_value, op->name());
920     case HloOpcode::kPower:
921       return EmitPow(op->shape().element_type(), lhs_value, rhs_value,
922                      op->name());
923     case HloOpcode::kAtan2:
924       return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value,
925                        op->name());
926     default:
927       return Unimplemented("binary floating point op '%s'",
928                            HloOpcodeString(op->opcode()));
929   }
930 }
931 
932 // Using sqrt(a^2 + b^2) can cause overflow errors. Therefore we can use
933 // sqrt(a^2 + b^2) = sqrt(a^2 * (1 + b^2/a^2))
934 //                 = |a| * sqrt(1 + (b/a)^2)
935 // With the assumption that |a| >= |b|.
936 //
937 // This method returns the min, max, and sqrt term for this calculation. This is
938 // done to prevent potential overflow errors that can occur from multiplying the
939 // max with the sqrt term. (i.e. when calculating the sqrt of the absolute
940 // value, we can take the sqrt of the max and the sqrt term before multiplying
941 // them together.) If return_sqrt is false, it returns 1 + (b/a)^2 instead of
942 // sqrt(1 + (b/a)^2).
943 StatusOr<std::tuple<llvm::Value*, llvm::Value*, llvm::Value*>>
EmitComplexAbsHelper(PrimitiveType prim_type,llvm::Value * operand_value,bool return_sqrt)944 ElementalIrEmitter::EmitComplexAbsHelper(PrimitiveType prim_type,
945                                          llvm::Value* operand_value,
946                                          bool return_sqrt) {
947   llvm::Value* real = EmitExtractReal(operand_value);
948   llvm::Value* imag = EmitExtractImag(operand_value);
949   llvm::Value* abs_real = llvm_ir::EmitCallToIntrinsic(
950       llvm::Intrinsic::fabs, {real}, {real->getType()}, b_);
951   llvm::Value* abs_imag = llvm_ir::EmitCallToIntrinsic(
952       llvm::Intrinsic::fabs, {imag}, {imag->getType()}, b_);
953   llvm::Value* max = EmitFloatMax(abs_real, abs_imag, "");
954   llvm::Value* min = EmitFloatMin(abs_real, abs_imag, "");
955 
956   llvm::Value* div = FDiv(min, max);
957   llvm::Value* div_sq = FMul(div, div);
958   llvm::Value* one = llvm::ConstantFP::get(max->getType(), 1);
959   llvm::Value* one_p_div_sq = FAdd(one, div_sq);
960   TF_ASSIGN_OR_RETURN(llvm::Value * sqrt, EmitSqrt(prim_type, one_p_div_sq));
961   return std::make_tuple(min, max, return_sqrt ? sqrt : one_p_div_sq);
962 }
963 
EmitComplexAbs(PrimitiveType prim_type,llvm::Value * operand_value)964 StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexAbs(
965     PrimitiveType prim_type, llvm::Value* operand_value) {
966   llvm::Value* min;
967   llvm::Value* max;
968   llvm::Value* sqrt;
969   TF_ASSIGN_OR_RETURN(
970       std::tie(min, max, sqrt),
971       EmitComplexAbsHelper(prim_type, operand_value, /*return_sqrt=*/true));
972   llvm::Value* result = FMul(max, sqrt);
973   // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN.
974   // In such cases, we return `min` instead of `result`.
975   return Select(FCmpUNO(result, result), min, result);
976 }
977 
978 // Calculates ComplexAbs in the same way, except using:
979 // sqrt(|a| * sqrt(1 + (b/a)^2)) = sqrt(|a|) * pow(1 + (b/a)^2, .25)
EmitSqrtComplexAbs(PrimitiveType prim_type,llvm::Value * operand_value)980 StatusOr<llvm::Value*> ElementalIrEmitter::EmitSqrtComplexAbs(
981     PrimitiveType prim_type, llvm::Value* operand_value) {
982   llvm::Value* min;
983   llvm::Value* max;
984   llvm::Value* one_p_div_sq;
985   TF_ASSIGN_OR_RETURN(
986       std::tie(min, max, one_p_div_sq),
987       EmitComplexAbsHelper(prim_type, operand_value, /*return_sqrt=*/false));
988   TF_ASSIGN_OR_RETURN(llvm::Value * sqrt_max, EmitSqrt(prim_type, max));
989   TF_ASSIGN_OR_RETURN(llvm::Value * pow,
990                       EmitPow(prim_type, one_p_div_sq,
991                               llvm::ConstantFP::get(max->getType(), .25), ""));
992   llvm::Value* result = FMul(sqrt_max, pow);
993   // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN.
994   // In such cases, we return `min` instead of `result`.
995   return Select(FCmpUNO(result, result), min, result);
996 }
997 
998 // Calculates ComplexAbs in the same way, except using:
999 // rsqrt(|a| * sqrt(1 + (b/a)^2)) = rsqrt(|a|) * rsqrt(sqrt(1 + (b/a)^2))
EmitRsqrtComplexAbs(PrimitiveType prim_type,llvm::Value * operand_value)1000 StatusOr<llvm::Value*> ElementalIrEmitter::EmitRsqrtComplexAbs(
1001     PrimitiveType prim_type, llvm::Value* operand_value) {
1002   llvm::Value* min;
1003   llvm::Value* max;
1004   llvm::Value* sqrt;
1005   TF_ASSIGN_OR_RETURN(
1006       std::tie(min, max, sqrt),
1007       EmitComplexAbsHelper(prim_type, operand_value, /*return_sqrt=*/true));
1008   TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_max, EmitRsqrt(prim_type, max));
1009   TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_sqrt, EmitRsqrt(prim_type, sqrt));
1010   llvm::Value* result = FMul(rsqrt_max, rsqrt_sqrt);
1011   TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_min, EmitRsqrt(prim_type, min));
1012   // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN.
1013   // In such cases, we return rsqrt(min) instead of `result`.
1014   return Select(FCmpUNO(result, result), rsqrt_min, result);
1015 }
1016 
EmitComplexAdd(const HloInstruction * op,llvm::Value * lhs_value,llvm::Value * rhs_value)1017 StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexAdd(
1018     const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
1019   return EmitComposeComplex(
1020       op, FAdd(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
1021       FAdd(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
1022 }
1023 
EmitComplexSubtract(const HloInstruction * op,llvm::Value * lhs_value,llvm::Value * rhs_value)1024 StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexSubtract(
1025     const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
1026   return EmitComposeComplex(
1027       op, FSub(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
1028       FSub(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
1029 }
1030 
EmitComplexMultiply(const HloInstruction * op,llvm::Value * lhs_value,llvm::Value * rhs_value)1031 StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexMultiply(
1032     const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
1033   return EmitComposeComplex(
1034       op,
1035       FSub(FMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
1036            FMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))),
1037       FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)),
1038            FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value))));
1039 }
1040 
// Emits IR computing lhs_value / rhs_value for complex operands using
// Smith's scaled formulation, followed by explicit fixups for the cases
// where the naive result would be (NaN, NaN): zero denominator, infinite
// numerator with finite denominator, and finite numerator with infinite
// denominator.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexDivide(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  // Division of complex numbers is implemented here, taking into account
  // over/underflow, NaN and Inf values.
  auto a_r = EmitExtractReal(lhs_value);
  auto a_i = EmitExtractImag(lhs_value);
  auto b_r = EmitExtractReal(rhs_value);
  auto b_i = EmitExtractImag(rhs_value);
  auto type = a_r->getType();

  // Smith's algorithm to divide complex numbers. It is just a bit smarter
  // way to compute the following formula:
  //  (a_r + a_i * i) / (b_r + b_i * i)
  //    = (a_r + a_i * i) (b_r - b_i * i) / ((b_r + b_i * i)(b_r - b_i * i))
  //    = ((a_r * b_r + a_i * b_i) + (a_i * b_r - a_r * b_i) * i) / ||b||^2
  //
  // Depending on whether |b_r| < |b_i| we compute either
  //   b_r_b_i_ratio = b_r / b_i
  //   b_r_b_i_denom = b_i + b_r * b_r_b_i_ratio
  //   c_r = (a_r * b_r_b_i_ratio + a_i ) / b_r_b_i_denom
  //   c_i = (a_i * b_r_b_i_ratio - a_r ) / b_r_b_i_denom
  //
  // or
  //
  //   b_i_b_r_ratio = b_i / b_r
  //   b_i_b_r_denom = b_r + b_i * b_i_b_r_ratio
  //   c_r = (a_r + a_i * b_i_b_r_ratio ) / b_i_b_r_denom
  //   c_i = (a_i - a_r * b_i_b_r_ratio ) / b_i_b_r_denom
  //
  // See https://dl.acm.org/citation.cfm?id=368661 for more details.
  // Both branches of the select below are emitted unconditionally; only the
  // needed one is chosen per element.
  auto b_r_b_i_ratio = FDiv(b_r, b_i);
  auto b_r_b_i_denom = FAdd(b_i, FMul(b_r_b_i_ratio, b_r));
  auto b_i_b_r_ratio = FDiv(b_i, b_r);
  auto b_i_b_r_denom = FAdd(b_r, FMul(b_i_b_r_ratio, b_i));

  auto b_r_abs =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {b_r}, {type}, b_);
  auto b_i_abs =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {b_i}, {type}, b_);
  auto b_r_lt_b_i = FCmpOLT(b_r_abs, b_i_abs);
  auto c_r = Select(b_r_lt_b_i,
                    FDiv(FAdd(FMul(b_r_b_i_ratio, a_r), a_i), b_r_b_i_denom),
                    FDiv(FAdd(FMul(b_i_b_r_ratio, a_i), a_r), b_i_b_r_denom));
  auto c_i = Select(b_r_lt_b_i,
                    FDiv(FSub(FMul(b_r_b_i_ratio, a_i), a_r), b_r_b_i_denom),
                    FDiv(FSub(a_i, FMul(b_i_b_r_ratio, a_r)), b_i_b_r_denom));
  auto result = EmitComposeComplex(op, c_r, c_i);

  // Consider corner cases, if the result is (NaN, NaN).
  auto zero = llvm::ConstantFP::get(type, 0.0);
  auto one = llvm::ConstantFP::get(type, 1.0);
  auto inf = llvm::ConstantFP::getInfinity(type);

  // Case 1. Zero denominator.
  // Requires at least one numerator component to be non-NaN (the
  // Not(FCmpUNO(x, zero)) terms are "x is ordered", i.e. not NaN).
  auto zero_denominator =
      And(And(FCmpOEQ(b_r_abs, zero), FCmpOEQ(b_i_abs, zero)),
          Or(Not(FCmpUNO(a_r, zero)), Not(FCmpUNO(a_i, zero))));
  auto inf_with_sign_of_b_r = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::copysign, {inf, b_r}, {type}, b_);
  // Result is (+/-inf * a_r, +/-inf * a_i), taking the sign of b_r.
  auto zero_denominator_result = EmitComposeComplex(
      op, FMul(inf_with_sign_of_b_r, a_r), FMul(inf_with_sign_of_b_r, a_i));

  // Case 2. Infinite numerator, finite denominator.
  auto b_r_finite = FCmpONE(b_r_abs, inf);
  auto b_i_finite = FCmpONE(b_i_abs, inf);
  auto a_r_abs =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {a_r}, {type}, b_);
  auto a_i_abs =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {a_i}, {type}, b_);
  auto a_r_infinite = FCmpOEQ(a_r_abs, inf);
  auto a_i_infinite = FCmpOEQ(a_i_abs, inf);
  auto inf_num_finite_denom =
      And(Or(a_r_infinite, a_i_infinite), And(b_r_finite, b_i_finite));

  // Replace each infinite numerator component with +/-1 (same sign) and
  // each finite one with +/-0, then scale the usual formula by inf.
  auto a_r_inf_with_sign = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::copysign, {Select(a_r_infinite, one, zero), a_r}, {type},
      b_);
  auto a_i_inf_with_sign = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::copysign, {Select(a_i_infinite, one, zero), a_i}, {type},
      b_);
  auto inf_num_finite_denom_result = EmitComposeComplex(
      op,
      FMul(inf,
           FAdd(FMul(a_r_inf_with_sign, b_r), FMul(a_i_inf_with_sign, b_i))),
      FMul(inf,
           FSub(FMul(a_i_inf_with_sign, b_r), FMul(a_r_inf_with_sign, b_i))));

  // Case 3. Finite numerator, infinite denominator.
  auto a_r_finite = FCmpONE(a_r_abs, inf);
  auto a_i_finite = FCmpONE(a_i_abs, inf);
  auto b_r_infinite = FCmpOEQ(b_r_abs, inf);
  auto b_i_infinite = FCmpOEQ(b_i_abs, inf);
  auto finite_num_inf_denom =
      And(Or(b_r_infinite, b_i_infinite), And(a_r_finite, a_i_finite));

  // Dual of case 2: replace each infinite denominator component with +/-1
  // and scale by zero, yielding a correctly-signed zero result.
  auto b_r_inf_with_sign = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::copysign, {Select(b_r_infinite, one, zero), b_r}, {type},
      b_);
  auto b_i_inf_with_sign = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::copysign, {Select(b_i_infinite, one, zero), b_i}, {type},
      b_);
  auto finite_num_inf_denom_result = EmitComposeComplex(
      op,
      FMul(zero,
           FAdd(FMul(a_r, b_r_inf_with_sign), FMul(a_i, b_i_inf_with_sign))),
      FMul(zero,
           FSub(FMul(a_i, b_r_inf_with_sign), FMul(a_r, b_i_inf_with_sign))));

  // The fixups only apply when the plain Smith result came out (NaN, NaN);
  // otherwise the plain result is kept.
  auto c_nan = And(FCmpUNO(c_r, zero), FCmpUNO(c_i, zero));
  return Select(c_nan,
                Select(zero_denominator, zero_denominator_result,
                       Select(inf_num_finite_denom, inf_num_finite_denom_result,
                              Select(finite_num_inf_denom,
                                     finite_num_inf_denom_result, result))),
                result);
}
1157 
EmitComplexLog(const HloInstruction * op,llvm::Value * operand_value)1158 StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexLog(
1159     const HloInstruction* op, llvm::Value* operand_value) {
1160   // log(a+bi) = log(abs(a+bi)) + i*atan2(b,a)
1161   PrimitiveType component_type =
1162       primitive_util::ComplexComponentType(op->shape().element_type());
1163   auto a = EmitExtractReal(operand_value);
1164   auto b = EmitExtractImag(operand_value);
1165   TF_ASSIGN_OR_RETURN(llvm::Value * angle, EmitAtan2(component_type, b, a, ""));
1166   TF_ASSIGN_OR_RETURN(llvm::Value * abs,
1167                       EmitComplexAbs(component_type, operand_value));
1168   TF_ASSIGN_OR_RETURN(llvm::Value * log_abs, EmitLog(component_type, abs));
1169   return EmitComposeComplex(op, log_abs, angle);
1170 }
1171 
// Using our EmitComplexPower formula, but setting c=0.5 and d=0, we get:
//   e^[ln(r)*c - t*d] * [cos(ln(r)*d + t*c) + i*sin(ln(r)*d + t*c)]
// = e^[ln(r)*0.5] * [cos(t*0.5) + i*sin(t*0.5)]
// = r^0.5 * [cos(t/2) + i*sin(t/2)]
// = sqrt(r) * [cos(t/2) + i*sin(t/2)]
// where r = |a+bi| and t = atan2(b,a)
// TODO(bixia): See doc for implementation without atan2.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexSqrt(
    const HloInstruction* op, PrimitiveType prim_type,
    llvm::Value* operand_value) {
  // The scalar component type of the complex operand, used to build the
  // floating-point constants below.
  llvm::Type* type = static_cast<llvm::StructType*>(operand_value->getType())
                         ->getElementType(0);

  // r = sqrt(|a+bi|), computed via the overflow-avoiding helper.
  TF_ASSIGN_OR_RETURN(llvm::Value * r,
                      EmitSqrtComplexAbs(prim_type, operand_value));

  llvm::Value* a = EmitExtractReal(operand_value);
  llvm::Value* b = EmitExtractImag(operand_value);
  // t = atan2(b, a), the phase angle of the operand.
  TF_ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a, ""));

  llvm::Value* c = llvm::ConstantFP::get(type, 0.5);
  llvm::Value* angle = FMul(t, c);
  TF_ASSIGN_OR_RETURN(llvm::Value * cos, EmitCos(prim_type, angle));
  TF_ASSIGN_OR_RETURN(llvm::Value * sin, EmitSin(prim_type, angle));

  llvm::Value* real_part;
  llvm::Value* imag_part;

  llvm::Value* zero = llvm::ConstantFP::get(type, 0);

  // Only emit the Inf/NaN fixups when fast-math does not promise their
  // absence.
  if (!(b_->getFastMathFlags().noNaNs() && b_->getFastMathFlags().noInfs())) {
    llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
    llvm::Value* neg_inf = llvm::ConstantFP::getInfinity(type, true);
    llvm::Value* nan = llvm::ConstantFP::getNaN(type);
    llvm::Value* abs_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
                                                      {b}, {b->getType()}, b_);

    // Real part: +inf when |b| or a is +inf; +0 when a is -inf with finite b
    // (sqrt of a negative real is purely imaginary); otherwise r*cos(t/2).
    real_part = Select(Or(FCmpOEQ(abs_b, inf), FCmpOEQ(a, inf)), inf,
                       Select(And(FCmpOEQ(a, neg_inf), FCmpONE(abs_b, inf)),
                              zero, FMul(r, cos)));

    // Imag part: inf carrying b's sign when |b| is inf or a is -inf; NaN when
    // r is NaN; otherwise r*sin(t/2), keeping sin's signed zero intact.
    llvm::Value* b_signed_inf = llvm_ir::EmitCallToIntrinsic(
        llvm::Intrinsic::copysign, {inf, b}, {b->getType()}, b_);
    imag_part =
        Select(Or(FCmpOEQ(abs_b, inf), FCmpOEQ(a, neg_inf)), b_signed_inf,
               Select(FCmpUNO(r, r), nan,
                      Select(FCmpOEQ(sin, zero), sin, FMul(r, sin))));
  } else {
    // Fast-math path: no special values possible, use the plain formulas.
    real_part = FMul(r, cos);
    imag_part = Select(FCmpOEQ(sin, zero), sin, FMul(r, sin));
  }

  // sqrt(0) = (0, 0) exactly, bypassing the polar-form computation.
  return Select(FCmpOEQ(r, zero), EmitComposeComplex(op, zero, zero),
                EmitComposeComplex(op, real_part, imag_part));
}
1227 
1228 // Similar to Sqrt, we can use our EmitComplexPower formula, but set
1229 // c=-0.5 and d=0. We get:
1230 //   e^[ln(r)*c - t*d] * [cos(ln(r)*d + t*c) + i*sin(ln(r)*d + t*c)]
1231 // = e^[ln(r)*-0.5] * [cos(t*-0.5) + i*sin(t*-0.5)]
1232 // = r^(-0.5) * [cos(-t/2) + i*sin(-t/2)]
1233 // = rsqrt(r) * [cos(-t/2) + i*sin(-t/2)]
1234 // where r = |a+bi| and t = atan2(b,a).
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexRsqrt(
    const HloInstruction* op, PrimitiveType prim_type,
    llvm::Value* operand_value) {
  // IR type of a single component of the complex struct {real, imag}.
  llvm::Type* type = static_cast<llvm::StructType*>(operand_value->getType())
                         ->getElementType(0);

  // r = rsqrt(|a+bi|); the helper presumably computes this in an
  // overflow/underflow-safe way -- see EmitRsqrtComplexAbs.
  TF_ASSIGN_OR_RETURN(llvm::Value * r,
                      EmitRsqrtComplexAbs(prim_type, operand_value));

  llvm::Value* a = EmitExtractReal(operand_value);
  llvm::Value* b = EmitExtractImag(operand_value);
  // t = atan2(b, a) is the argument (phase) of a+bi.
  TF_ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a, ""));

  // Per the derivation above: result = rsqrt(r) * [cos(-t/2) + i*sin(-t/2)],
  // so the angle is t * -0.5.
  llvm::Value* c = llvm::ConstantFP::get(type, -0.5);
  llvm::Value* angle = FMul(t, c);
  TF_ASSIGN_OR_RETURN(llvm::Value * cos, EmitCos(prim_type, angle));
  TF_ASSIGN_OR_RETURN(llvm::Value * sin, EmitSin(prim_type, angle));

  llvm::Value* real_part = FMul(r, cos);
  llvm::Value* imag_part = FMul(r, sin);

  if (!(b_->getFastMathFlags().noNaNs() && b_->getFastMathFlags().noInfs())) {
    // NaNs and Infs must be honored: patch up the special cases the
    // general formula gets wrong.
    llvm::Value* zero = llvm::ConstantFP::get(type, 0);
    llvm::Value* neg_one = llvm::ConstantFP::get(type, -1);
    llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
    llvm::Value* nan = llvm::ConstantFP::getNaN(type);
    // Zeros carrying the signs of a and b; used for infinite inputs below.
    llvm::Value* a_signed_zero = llvm_ir::EmitCallToIntrinsic(
        llvm::Intrinsic::copysign, {zero, a}, {a->getType()}, b_);
    llvm::Value* b_signed_zero = llvm_ir::EmitCallToIntrinsic(
        llvm::Intrinsic::copysign, {zero, b}, {b->getType()}, b_);
    llvm::Value* neg_b_signed_zero = FMul(b_signed_zero, neg_one);

    llvm::Value* abs_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
                                                      {a}, {a->getType()}, b_);
    llvm::Value* abs_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
                                                      {b}, {b->getType()}, b_);

    // rsqrt(0+0i) -> inf + nan*i.
    llvm::Value* is_zero_zero = And(FCmpOEQ(b, zero), FCmpOEQ(a, zero));
    // If either component is infinite (|b| == inf with a NaN real part, or
    // |a| == inf), the result magnitude is 0: use signed zeros derived from
    // the inputs; otherwise fall through to the general r*cos / r*sin.
    real_part = Select(
        is_zero_zero, inf,
        Select(Or(And(FCmpOEQ(abs_b, inf), FCmpUNO(a, a)), FCmpOEQ(abs_a, inf)),
               a_signed_zero, FMul(r, cos)));
    imag_part = Select(
        is_zero_zero, nan,
        Select(Or(And(FCmpOEQ(abs_b, inf), FCmpUNO(a, a)), FCmpOEQ(abs_a, inf)),
               neg_b_signed_zero, FMul(r, sin)));
  } else {
    // Fast-math: only the zero-input case is patched.
    llvm::Value* zero = llvm::ConstantFP::get(type, 0);
    llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
    llvm::Value* nan = llvm::ConstantFP::getNaN(type);

    llvm::Value* is_zero_zero = And(FCmpOEQ(b, zero), FCmpOEQ(a, zero));
    real_part = Select(is_zero_zero, inf, FMul(r, cos));
    imag_part = Select(is_zero_zero, nan, FMul(r, sin));
  }

  return EmitComposeComplex(op, real_part, imag_part);
}
1294 
1295 //
1296 // Using EmitComplexPower with c=1.0/3.0 and d=0
EmitComplexCbrt(const HloInstruction * op,PrimitiveType prim_type,llvm::Value * operand_value)1297 StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexCbrt(
1298     const HloInstruction* op, PrimitiveType prim_type,
1299     llvm::Value* operand_value) {
1300   auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
1301   auto third = llvm::ConstantFP::get(type, 1.0 / 3.0);
1302   auto zero = llvm::ConstantFP::get(type, 0);
1303   llvm::Value* a = EmitExtractReal(operand_value);
1304   llvm::Value* b = EmitExtractImag(operand_value);
1305   return EmitComplexPower(op, a, b, third, zero);
1306 }
1307 
1308 // (a+bi)^(c+di) =
1309 //    (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
1310 //    where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexPower(
    const HloInstruction* op, llvm::Value* a, llvm::Value* b, llvm::Value* c,
    llvm::Value* d) {
  // Scalar component type of op's complex result.
  PrimitiveType component_type =
      primitive_util::ComplexComponentType(op->shape().element_type());
  // aa_p_bb = a*a + b*b == |a+bi|^2.
  auto aa_p_bb = FAdd(FMul(a, a), FMul(b, b));
  auto zero = llvm::ConstantFP::get(a->getType(), 0);
  auto one_half = llvm::ConstantFP::get(a->getType(), 0.5);
  auto one = llvm::ConstantFP::get(a->getType(), 1);
  auto half_c = FMul(one_half, c);

  // (a*a+b*b)^(0.5c): magnitude contribution from the real exponent part.
  TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c,
                      EmitPow(component_type, aa_p_bb, half_c, ""));

  // exp(-d*atan2(b,a)): magnitude contribution from the imaginary
  // exponent part.
  auto neg_d = FNeg(d);
  TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a, ""));
  auto neg_d_arg_lhs = FMul(neg_d, arg_lhs);
  TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs,
                      EmitExp(component_type, neg_d_arg_lhs, ""));
  // coeff: total magnitude of the result.
  auto coeff = FMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs);
  // q = c*atan2(b,a) + 0.5d*ln(a*a+b*b): the phase of the result.
  TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb));
  auto half_d = FMul(one_half, d);
  auto q = FAdd(FMul(c, arg_lhs), FMul(half_d, ln_aa_p_bb));
  TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q));
  TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q));
  // Special-case a zero base with a real, non-negative exponent:
  // 0^0 is defined to be 1.0 and 0^c is 0 for c > 0, see
  // Branch Cuts for Complex Elementary Functions or Much Ado About
  // Nothing's Sign Bit, W. Kahan, Section 10.
  return Select(
      And(And(FCmpOEQ(aa_p_bb, zero), FCmpOEQ(d, zero)), FCmpOLE(zero, c)),
      EmitComposeComplex(op, Select(FCmpOEQ(zero, c), one, zero), zero),
      EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q)));
}
1344 
// Dispatches a binary HLO op with complex operands to the dedicated
// complex emitters above.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  switch (op->opcode()) {
    case HloOpcode::kAdd:
      return EmitComplexAdd(op, lhs_value, rhs_value);
    case HloOpcode::kSubtract:
      return EmitComplexSubtract(op, lhs_value, rhs_value);
    case HloOpcode::kMultiply:
      return EmitComplexMultiply(op, lhs_value, rhs_value);
    case HloOpcode::kDivide: {
      return EmitComplexDivide(op, lhs_value, rhs_value);
    }
    // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
    // comparisons always return false when one of the operands is NaN, whereas
    // unordered comparisons return true.
    //
    // We use ordered comparisons for everything except kNe, where we use an
    // unordered comparison.  This makes x != y equivalent to !(x == y), and
    // matches C++'s semantics.
    case HloOpcode::kCompare: {
      switch (op->comparison_direction()) {
        // Complex equality: both components compare equal (ordered).
        case ComparisonDirection::kEq:
          return And(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
                                             EmitExtractReal(lhs_value),
                                             EmitExtractReal(rhs_value), b_),
                     llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
                                             EmitExtractImag(lhs_value),
                                             EmitExtractImag(rhs_value), b_));
        // Complex inequality: either component differs (unordered).
        case ComparisonDirection::kNe:
          return Or(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
                                            EmitExtractReal(lhs_value),
                                            EmitExtractReal(rhs_value), b_),
                    llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
                                            EmitExtractImag(lhs_value),
                                            EmitExtractImag(rhs_value), b_));
        default:
          // Ordering comparisons (<, <=, >, >=) are undefined on complex.
          return Unimplemented(
              "complex comparison '%s'",
              ComparisonDirectionToString(op->comparison_direction()));
      }
    }
    // (a+bi)^(c+di): handled by the general complex power emitter.
    case HloOpcode::kPower: {
      auto a = EmitExtractReal(lhs_value);
      auto b = EmitExtractImag(lhs_value);
      auto c = EmitExtractReal(rhs_value);
      auto d = EmitExtractImag(rhs_value);
      return EmitComplexPower(op, a, b, c, d);
    }
    case HloOpcode::kAtan2: {
      // atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
      auto y = lhs_value;
      auto x = rhs_value;
      TF_ASSIGN_OR_RETURN(auto x_squared, EmitComplexMultiply(op, x, x));
      TF_ASSIGN_OR_RETURN(auto y_squared, EmitComplexMultiply(op, y, y));
      TF_ASSIGN_OR_RETURN(auto x_squared_plus_y_squared,
                          EmitComplexAdd(op, x_squared, y_squared));
      auto component_type =
          primitive_util::ComplexComponentType(op->shape().element_type());
      TF_ASSIGN_OR_RETURN(
          auto sqrt_x_squared_plus_y_squared,
          EmitComplexSqrt(op, component_type, x_squared_plus_y_squared));
      auto type = llvm_ir::PrimitiveTypeToIrType(component_type, module_);
      auto zero = llvm::ConstantFP::get(type, 0.0);
      auto one = llvm::ConstantFP::get(type, 1.0);
      auto i = EmitComposeComplex(op, zero, one);
      TF_ASSIGN_OR_RETURN(auto i_times_y, EmitComplexMultiply(op, i, y));
      TF_ASSIGN_OR_RETURN(auto x_plus_iy, EmitComplexAdd(op, x, i_times_y));
      TF_ASSIGN_OR_RETURN(
          auto div_result,
          EmitComplexDivide(op, x_plus_iy, sqrt_x_squared_plus_y_squared));
      TF_ASSIGN_OR_RETURN(auto log_result, EmitComplexLog(op, div_result));
      auto negative_one = llvm::ConstantFP::get(type, -1.0);
      auto negative_i = EmitComposeComplex(op, zero, negative_one);
      return EmitComplexMultiply(op, negative_i, log_result);
    }
    default:
      return Unimplemented("binary complex op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}
1425 
EmitFloatMax(llvm::Value * lhs_value,llvm::Value * rhs_value,absl::string_view name)1426 llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
1427                                               llvm::Value* rhs_value,
1428                                               absl::string_view name) {
1429   return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_, fast_min_max(), name);
1430 }
1431 
EmitFloatMin(llvm::Value * lhs_value,llvm::Value * rhs_value,absl::string_view name)1432 llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
1433                                               llvm::Value* rhs_value,
1434                                               absl::string_view name) {
1435   return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_, fast_min_max(), name);
1436 }
1437 
EmitLog(PrimitiveType prim_type,llvm::Value * value)1438 StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type,
1439                                                    llvm::Value* value) {
1440   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {value},
1441                                       {value->getType()}, b_);
1442 }
1443 
// Emits log(1+x) with full precision for small |x|, where naively computing
// log(x+1) would lose the low-order bits of x.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type,
                                                     llvm::Value* value) {
  auto x = value;
  auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
  auto one = llvm::ConstantFP::get(type, 1.0);
  auto negative_half = llvm::ConstantFP::get(type, -0.5);
  // When x is large, the naive evaluation of ln(x + 1) is more
  // accurate than the Taylor series.
  TF_ASSIGN_OR_RETURN(auto for_large_x, EmitLog(prim_type, FAdd(x, one)));
  // When x is small, (defined to be less than sqrt(2) / 2), use a rational
  // approximation. The approximation below is based on one from the Cephes
  // Mathematical Library.
  //
  // sqrt(2) - 1.
  const auto kAntilogarithmIsSmallThreshold = 0.41421356237309504880;

  // Q(x) coefficients, low order first (Cephes log1p denominator).
  static const std::array<double, 7> kDenominatorCoeffs{
      1.,
      1.5062909083469192043167E1,
      8.3047565967967209469434E1,
      2.2176239823732856465394E2,
      3.0909872225312059774938E2,
      2.1642788614495947685003E2,
      6.0118660497603843919306E1,
  };

  // P(x) coefficients, low order first (Cephes log1p numerator).
  static const std::array<double, 7> kNumeratorCoeffs{
      4.5270000862445199635215E-5, 4.9854102823193375972212E-1,
      6.5787325942061044846969E0,  2.9911919328553073277375E1,
      6.0949667980987787057556E1,  5.7112963590585538103336E1,
      2.0039553499201281259648E1,
  };

  // for_small_x = x - x^2/2 + x^3 * P(x)/Q(x), assembled from the inside out.
  auto x_squared = FMul(x, x);
  TF_ASSIGN_OR_RETURN(auto denominator,
                      EvaluatePolynomial(type, x, kDenominatorCoeffs));
  TF_ASSIGN_OR_RETURN(auto numerator,
                      EvaluatePolynomial(type, x, kNumeratorCoeffs));
  auto for_small_x = FDiv(numerator, denominator);
  for_small_x = FMul(FMul(x, x_squared), for_small_x);
  for_small_x = FAdd(FMul(negative_half, x_squared), for_small_x);
  for_small_x = FAdd(x, for_small_x);

  // Switch between the two evaluations on |x| < sqrt(2) - 1.
  auto abs_x =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
  auto x_is_small = FCmpOLT(
      abs_x, llvm::ConstantFP::get(type, kAntilogarithmIsSmallThreshold));
  return Select(x_is_small, for_small_x, for_large_x);
}
1493 
EmitSqrt(PrimitiveType,llvm::Value * value)1494 StatusOr<llvm::Value*> ElementalIrEmitter::EmitSqrt(PrimitiveType,
1495                                                     llvm::Value* value) {
1496   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {value},
1497                                       {value->getType()}, b_);
1498 }
1499 
EmitRsqrt(PrimitiveType prim_type,llvm::Value * value)1500 StatusOr<llvm::Value*> ElementalIrEmitter::EmitRsqrt(PrimitiveType prim_type,
1501                                                      llvm::Value* value) {
1502   TF_ASSIGN_OR_RETURN(auto sqrt, EmitSqrt(prim_type, value));
1503   return FDiv(llvm::ConstantFP::get(sqrt->getType(), 1.0), sqrt);
1504 }
1505 
EmitSin(PrimitiveType prim_type,llvm::Value * value)1506 StatusOr<llvm::Value*> ElementalIrEmitter::EmitSin(PrimitiveType prim_type,
1507                                                    llvm::Value* value) {
1508   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value},
1509                                       {value->getType()}, b_);
1510 }
1511 
EmitCos(PrimitiveType prim_type,llvm::Value * value)1512 StatusOr<llvm::Value*> ElementalIrEmitter::EmitCos(PrimitiveType prim_type,
1513                                                    llvm::Value* value) {
1514   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value},
1515                                       {value->getType()}, b_);
1516 }
1517 
EmitExp(PrimitiveType prim_type,llvm::Value * value,absl::string_view name)1518 StatusOr<llvm::Value*> ElementalIrEmitter::EmitExp(PrimitiveType prim_type,
1519                                                    llvm::Value* value,
1520                                                    absl::string_view name) {
1521   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value},
1522                                       {value->getType()}, b_, name);
1523 }
1524 
// Emits exp(x) - 1 without the catastrophic cancellation the naive
// formula suffers for small |x|.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
                                                     llvm::Value* value) {
  auto x = value;
  auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
  auto one = llvm::ConstantFP::get(type, 1.0);
  auto half = llvm::ConstantFP::get(type, 0.5);
  auto zero = llvm::ConstantFP::get(type, 0.0);

  // expm1(x) == tanh(x/2)*(exp(x)+1)
  // x/2 can underflow, if it does we approximate expm1 with x.
  auto x_over_two = FMul(x, half);
  auto x_over_two_is_zero = FCmpOEQ(x_over_two, zero);
  auto abs_x =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {x}, {type}, b_);
  // Use a naive exp(x)-1 calculation if |x| is > 0.5; there is no
  // cancellation problem in that regime.
  auto x_magnitude_is_large = FCmpOGT(abs_x, half);
  TF_ASSIGN_OR_RETURN(auto tanh_of_x_over_two, EmitTanh(prim_type, x_over_two));
  TF_ASSIGN_OR_RETURN(auto exp_of_x, EmitExp(prim_type, x, ""));
  auto exp_of_x_plus_one = FAdd(exp_of_x, one);
  auto exp_of_x_minus_one = FSub(exp_of_x, one);
  auto expm1_of_x = FMul(tanh_of_x_over_two, exp_of_x_plus_one);
  // Layered selects: naive formula for large |x|, then the plain x
  // approximation when x/2 underflowed to zero.
  expm1_of_x = Select(x_magnitude_is_large, exp_of_x_minus_one, expm1_of_x);
  expm1_of_x = Select(x_over_two_is_zero, x, expm1_of_x);
  return expm1_of_x;
}
1550 
EmitPow(PrimitiveType prim_type,llvm::Value * lhs,llvm::Value * rhs,absl::string_view name)1551 StatusOr<llvm::Value*> ElementalIrEmitter::EmitPow(PrimitiveType prim_type,
1552                                                    llvm::Value* lhs,
1553                                                    llvm::Value* rhs,
1554                                                    absl::string_view name) {
1555   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs},
1556                                       {lhs->getType()}, b_, name);
1557 }
1558 
EmitCbrt(PrimitiveType prim_type,llvm::Value * value)1559 StatusOr<llvm::Value*> ElementalIrEmitter::EmitCbrt(PrimitiveType prim_type,
1560                                                     llvm::Value* value) {
1561   auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
1562   auto third = llvm::ConstantFP::get(type, 1.0 / 3.0);
1563   auto abs_value =
1564       llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
1565   TF_ASSIGN_OR_RETURN(llvm::Value * abs_res,
1566                       EmitPow(prim_type, abs_value, third, ""));
1567   auto signed_res = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign,
1568                                                  {abs_res, value}, {type}, b_);
1569   return signed_res;
1570 }
1571 
// Base-class fallback: no generic lowering for atan2 is provided here.
// NOTE(review): presumably backend-specific subclasses override this --
// confirm against the derived emitters.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(
    PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* /*rhs*/,
    absl::string_view /*name*/) {
  return Unimplemented("atan2");
}
1577 
// Base-class fallback: no generic lowering for tanh is provided here.
// NOTE(review): presumably backend-specific subclasses override this --
// confirm against the derived emitters.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
                                                    llvm::Value* value) {
  return Unimplemented("tanh");
}
1582 
// Lowers kReducePrecision: the source element type comes from the operand,
// while the destination exponent/mantissa widths come from the instruction
// itself.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
    const HloInstruction* hlo, llvm::Value* x) {
  return EmitReducePrecisionIR(
      /*src_ty=*/hlo->operand(0)->shape().element_type(), x,
      /*dest_exponent_bits=*/hlo->exponent_bits(),
      /*dest_mantissa_bits=*/hlo->mantissa_bits(),
      // quiet_nans=false: NaNs are presumably passed through without being
      // quieted -- see EmitReducePrecisionIR for the exact semantics.
      /*quiet_nans=*/false, b_);
}
1591 
SaturateShiftIfNecessary(llvm::IRBuilder<> * b,llvm::Value * lhs,llvm::Value * rhs,llvm::Value * shift_result,bool saturate_to_sign_bit)1592 static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* b,
1593                                              llvm::Value* lhs, llvm::Value* rhs,
1594                                              llvm::Value* shift_result,
1595                                              bool saturate_to_sign_bit) {
1596   llvm::IntegerType* integer_type =
1597       llvm::cast<llvm::IntegerType>(lhs->getType());
1598   unsigned integer_bitsize = integer_type->getBitWidth();
1599   llvm::ConstantInt* integer_bitsize_constant =
1600       llvm::ConstantInt::get(integer_type, integer_bitsize);
1601   llvm::ConstantInt* zero = llvm::ConstantInt::get(integer_type, 0);
1602   llvm::ConstantInt* minus_one = llvm::ConstantInt::get(integer_type, -1);
1603   llvm::Value* saturated_value;
1604   if (saturate_to_sign_bit) {
1605     saturated_value =
1606         b->CreateSelect(b->CreateICmpSLT(lhs, zero), minus_one, zero);
1607   } else {
1608     saturated_value = zero;
1609   }
1610   llvm::Value* shift_amt_in_range =
1611       b->CreateICmpULT(rhs, integer_bitsize_constant, "shft.chk");
1612   return b->CreateSelect(shift_amt_in_range, shift_result, saturated_value);
1613 }
1614 
GetOne(llvm::Type * type)1615 llvm::Value* ElementalIrEmitter::GetOne(llvm::Type* type) {
1616   return llvm::ConstantInt::get(llvm::cast<llvm::IntegerType>(type), 1);
1617 }
1618 
GetZero(llvm::Type * type)1619 llvm::Value* ElementalIrEmitter::GetZero(llvm::Type* type) {
1620   return llvm::ConstantInt::get(llvm::cast<llvm::IntegerType>(type), 0);
1621 }
1622 
GetIntSMin(llvm::Type * type)1623 llvm::Value* ElementalIrEmitter::GetIntSMin(llvm::Type* type) {
1624   auto* integer_type = llvm::cast<llvm::IntegerType>(type);
1625   return llvm::ConstantInt::get(integer_type, llvm::APInt::getSignedMinValue(
1626                                                   integer_type->getBitWidth()));
1627 }
1628 
GetMinusOne(llvm::Type * type)1629 llvm::Value* ElementalIrEmitter::GetMinusOne(llvm::Type* type) {
1630   auto* integer_type = llvm::cast<llvm::IntegerType>(type);
1631   return llvm::ConstantInt::get(
1632       integer_type, llvm::APInt::getAllOnesValue(integer_type->getBitWidth()));
1633 }
1634 
IsZero(llvm::Value * v)1635 llvm::Value* ElementalIrEmitter::IsZero(llvm::Value* v) {
1636   return ICmpEQ(v, llvm::ConstantInt::get(v->getType(), 0));
1637 }
1638 
IsIntMinDivisionOverflow(llvm::Value * lhs,llvm::Value * rhs)1639 llvm::Value* ElementalIrEmitter::IsIntMinDivisionOverflow(llvm::Value* lhs,
1640                                                           llvm::Value* rhs) {
1641   return And(ICmpEQ(lhs, GetIntSMin(lhs->getType())),
1642              ICmpEQ(rhs, GetMinusOne(rhs->getType())));
1643 }
1644 
EmitIntegerDivide(llvm::Value * lhs,llvm::Value * rhs,bool is_signed)1645 llvm::Value* ElementalIrEmitter::EmitIntegerDivide(llvm::Value* lhs,
1646                                                    llvm::Value* rhs,
1647                                                    bool is_signed) {
1648   // Integer division overflow behavior:
1649   //
1650   // X / 0 == -1
1651   // INT_SMIN /s -1 = INT_SMIN
1652 
1653   if (!is_signed) {
1654     llvm::Value* udiv_is_unsafe = IsZero(rhs);
1655     llvm::Value* safe_rhs = Select(udiv_is_unsafe, GetOne(lhs->getType()), rhs);
1656     llvm::Value* safe_div = UDiv(lhs, safe_rhs);
1657     return Select(udiv_is_unsafe, GetMinusOne(lhs->getType()), safe_div);
1658   }
1659 
1660   llvm::Value* has_zero_divisor = IsZero(rhs);
1661   llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs);
1662   llvm::Value* sdiv_is_unsafe = Or(has_int_min_overflow, has_zero_divisor);
1663   llvm::Value* safe_rhs = Select(sdiv_is_unsafe, GetOne(lhs->getType()), rhs);
1664   llvm::Value* safe_div = SDiv(lhs, safe_rhs);
1665 
1666   return Select(
1667       has_zero_divisor, GetMinusOne(lhs->getType()),
1668       Select(has_int_min_overflow, GetIntSMin(lhs->getType()), safe_div));
1669 }
1670 
EmitIntegerRemainder(llvm::Value * lhs,llvm::Value * rhs,bool is_signed)1671 llvm::Value* ElementalIrEmitter::EmitIntegerRemainder(llvm::Value* lhs,
1672                                                       llvm::Value* rhs,
1673                                                       bool is_signed) {
1674   // Integer remainder overflow behavior:
1675   //
1676   // X % 0 == X
1677   // INT_SMIN %s -1 = 0
1678 
1679   if (!is_signed) {
1680     llvm::Value* urem_is_unsafe = IsZero(rhs);
1681     llvm::Value* safe_rhs = Select(urem_is_unsafe, GetOne(lhs->getType()), rhs);
1682     llvm::Value* safe_rem = URem(lhs, safe_rhs);
1683     return Select(urem_is_unsafe, lhs, safe_rem);
1684   }
1685 
1686   llvm::Value* has_zero_divisor = IsZero(rhs);
1687   llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs);
1688   llvm::Value* srem_is_unsafe = Or(has_int_min_overflow, has_zero_divisor);
1689   llvm::Value* safe_rhs = Select(srem_is_unsafe, GetOne(lhs->getType()), rhs);
1690   llvm::Value* safe_rem = SRem(lhs, safe_rhs);
1691 
1692   return Select(
1693       has_zero_divisor, lhs,
1694       Select(has_int_min_overflow, GetZero(lhs->getType()), safe_rem));
1695 }
1696 
// Emits base**exponent for integers via exponentiation by squaring, with
// the loop unrolled at compile time.
llvm::Value* ElementalIrEmitter::EmitIntegerPow(llvm::Value* base,
                                                llvm::Value* exponent,
                                                bool is_signed) {
  // Exponentiation by squaring:
  // https://en.wikipedia.org/wiki/Exponentiation_by_squaring;
  int bits = 6;  // Examining the 6 low exponent bits suffices: 64 == 1 << 6
                 // is the largest exponent for which any 64-bit integer base
                 // other than 0 and +/-1 does not overflow, so higher bits
                 // cannot change a representable result.
  llvm::Value* accumulator = llvm::ConstantInt::get(base->getType(), 1);
  llvm::Value* one = llvm::ConstantInt::get(exponent->getType(), 1);
  llvm::Value* zero = llvm::ConstantInt::get(exponent->getType(), 0);
  llvm::Value* original_base = base;
  llvm::Value* original_exponent = exponent;

  // Unroll the loop at compile time: multiply the accumulator by the
  // current square of base whenever the corresponding exponent bit is set.
  for (int i = 0; i < bits; i++) {
    accumulator =
        b_->CreateSelect(b_->CreateICmpEQ(b_->CreateAnd(exponent, one), one),
                         b_->CreateMul(accumulator, base), accumulator);
    base = b_->CreateMul(base, base);
    exponent = b_->CreateLShr(exponent, 1);
  }
  // Negative exponents: the truncated result is 1 when base == 1, else 0.
  // NOTE(review): base == -1 with an odd negative exponent would
  // mathematically be -1 but yields 0 here -- confirm this is intended.
  return b_->CreateSelect(
      b_->CreateICmpSGE(original_exponent, zero), accumulator,
      b_->CreateSelect(b_->CreateICmpEQ(original_base, one), one, zero));
}
1723 
// Emits a binary op whose operands are PRED (i1/i8 booleans).
StatusOr<llvm::Value*> ElementalIrEmitter::EmitPredBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  // Per the reference interpreter, pred arithmetic should behave like
  // `int8(x) OP int8(y) != 0`.  For most permitted ops, we can just emit the
  // underlying i8 op to achieve this (e.g. kAnd, kOr, kXor, kMultiply).  In the
  // case of kAdd, we would need to insert a comparison instruction after the
  // addition, but it's both easier and faster to emit a bitwise or instruction
  // instead.
  //
  // For several of these ops, a faster bitwise implementation is available, but
  // LLVM is unlikely to be able to see it, since it gets IR that e.g. loads i8s
  // from memory, multiplies them, and writes the result back, without any
  // indication that the inputs were assumed to be 0 or 1.  So, just in case,
  // help it out by choosing the faster instruction to begin with.
  switch (op->opcode()) {
    // These behave identically to their integer counterparts on 0/1 values.
    case HloOpcode::kCompare:
    case HloOpcode::kXor:
      return EmitIntegerBinaryOp(op, lhs_value, rhs_value, false);

    // zext(i1 x) + zext(i1 y) != 0 === or(x, y)
    // max(zext(i1 x), zext(i1 y)) != 0 === or(x, y)
    case HloOpcode::kAdd:
    case HloOpcode::kMaximum:
    case HloOpcode::kOr:
      return Or(lhs_value, rhs_value);

    // zext(i1 x) * zext(i1 y) != 0 === and(x, y)
    // min(zext(i1 x), zext(i1 y)) != 0 === and(x, y)
    case HloOpcode::kMultiply:
    case HloOpcode::kMinimum:
    case HloOpcode::kAnd:
      return And(lhs_value, rhs_value);

    // These opcodes are rejected by shape-inference for PRED elements; calling
    // them out here serves more as documentation than a necessary check.
    case HloOpcode::kDivide:
    case HloOpcode::kRemainder:
    case HloOpcode::kPower:
    case HloOpcode::kSubtract:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
      return InternalError("Invalid binary op '%s' for pred",
                           HloOpcodeString(op->opcode()));

    default:
      return Unimplemented("binary pred op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}
1774 
// Emits a binary op on integer operands; `is_signed` selects between the
// signed and unsigned LLVM instructions/predicates where they differ.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value,
    bool is_signed) {
  switch (op->opcode()) {
    // TODO(jingyue): add the "nsw" attribute for signed types.
    case HloOpcode::kAdd:
      return Add(lhs_value, rhs_value);
    case HloOpcode::kSubtract:
      return Sub(lhs_value, rhs_value);
    case HloOpcode::kMultiply:
      return Mul(lhs_value, rhs_value);
    case HloOpcode::kDivide:
      return EmitIntegerDivide(lhs_value, rhs_value, is_signed);
    case HloOpcode::kRemainder:
      return EmitIntegerRemainder(lhs_value, rhs_value, is_signed);
    case HloOpcode::kCompare: {
      // NOTE(review): the inner switch has no default and the outer case no
      // terminating return; this relies on every ComparisonDirection value
      // being handled (each returns), making fall-through unreachable.
      switch (op->comparison_direction()) {
        case ComparisonDirection::kEq:
          return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value,
                                         rhs_value, b_);
        case ComparisonDirection::kNe:
          return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value,
                                         rhs_value, b_);
        case ComparisonDirection::kLt:
          return llvm_ir::EmitComparison(
              is_signed ? llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT,
              lhs_value, rhs_value, b_);
        case ComparisonDirection::kGt:
          return llvm_ir::EmitComparison(
              is_signed ? llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT,
              lhs_value, rhs_value, b_);
        case ComparisonDirection::kLe:
          return llvm_ir::EmitComparison(
              is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE,
              lhs_value, rhs_value, b_);
        case ComparisonDirection::kGe:
          return llvm_ir::EmitComparison(
              is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE,
              lhs_value, rhs_value, b_);
      }
    }
    case HloOpcode::kMinimum:
      return EmitIntegralMin(lhs_value, rhs_value, is_signed);
    case HloOpcode::kMaximum:
      return EmitIntegralMax(lhs_value, rhs_value, is_signed);
    case HloOpcode::kAnd:
      return And(lhs_value, rhs_value);
    case HloOpcode::kOr:
      return Or(lhs_value, rhs_value);
    case HloOpcode::kPower:
      return EmitIntegerPow(lhs_value, rhs_value, is_signed);
    case HloOpcode::kXor:
      return Xor(lhs_value, rhs_value);

    // Shifting out bits >= the number of bits in the type being shifted
    // produces a poison value in LLVM which is basically "deferred undefined
    // behavior" -- doing something observable with such a value precipitates
    // UB.  We replace the poison value with a constant to avoid this deferred
    // UB.
    case HloOpcode::kShiftRightArithmetic:
      return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
                                      AShr(lhs_value, rhs_value),
                                      /*saturate_to_sign_bit=*/true);
    case HloOpcode::kShiftLeft:
      return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
                                      Shl(lhs_value, rhs_value),
                                      /*saturate_to_sign_bit=*/false);
    case HloOpcode::kShiftRightLogical:
      return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
                                      LShr(lhs_value, rhs_value),
                                      /*saturate_to_sign_bit=*/false);
    default:
      return Unimplemented("binary integer op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}
1851 
EmitIntegralMax(llvm::Value * lhs_value,llvm::Value * rhs_value,bool is_signed)1852 llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value,
1853                                                  llvm::Value* rhs_value,
1854                                                  bool is_signed) {
1855   return Select(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE
1856                                          : llvm::ICmpInst::ICMP_UGE,
1857                                lhs_value, rhs_value),
1858                 lhs_value, rhs_value);
1859 }
1860 
EmitIntegralMin(llvm::Value * lhs_value,llvm::Value * rhs_value,bool is_signed)1861 llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value,
1862                                                  llvm::Value* rhs_value,
1863                                                  bool is_signed) {
1864   return Select(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE
1865                                          : llvm::ICmpInst::ICMP_ULE,
1866                                lhs_value, rhs_value),
1867                 lhs_value, rhs_value);
1868 }
1869 
EmitElementalSelect(const HloInstruction * hlo,const ElementalIrEmitter::HloToElementGeneratorMap & operand_to_generator,const llvm_ir::IrArray::Index & index)1870 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalSelect(
1871     const HloInstruction* hlo,
1872     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
1873     const llvm_ir::IrArray::Index& index) {
1874   TF_ASSIGN_OR_RETURN(llvm::Value * pred_value,
1875                       operand_to_generator.at(hlo->operand(0))(index));
1876   TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value,
1877                       operand_to_generator.at(hlo->operand(1))(index));
1878   TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value,
1879                       operand_to_generator.at(hlo->operand(2))(index));
1880   return Select(Trunc(pred_value, b_->getInt1Ty()), on_true_value,
1881                 on_false_value);
1882 }
1883 
EmitElementalClamp(const HloInstruction * hlo,const ElementalIrEmitter::HloToElementGeneratorMap & operand_to_generator,const llvm_ir::IrArray::Index & index)1884 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalClamp(
1885     const HloInstruction* hlo,
1886     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
1887     const llvm_ir::IrArray::Index& index) {
1888   TF_ASSIGN_OR_RETURN(llvm::Value * min_value,
1889                       operand_to_generator.at(hlo->operand(0))(index));
1890   TF_ASSIGN_OR_RETURN(llvm::Value * arg_value,
1891                       operand_to_generator.at(hlo->operand(1))(index));
1892   TF_ASSIGN_OR_RETURN(llvm::Value * max_value,
1893                       operand_to_generator.at(hlo->operand(2))(index));
1894   PrimitiveType prim_type = hlo->shape().element_type();
1895   if (primitive_util::IsFloatingPointType(prim_type)) {
1896     return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value, ""), "");
1897   } else if (primitive_util::IsIntegralType(prim_type)) {
1898     bool is_signed = primitive_util::IsSignedIntegralType(prim_type);
1899     return EmitIntegralMin(
1900         max_value, EmitIntegralMax(min_value, arg_value, is_signed), is_signed);
1901   } else {
1902     return Unimplemented("Clamp unimplemented for %s",
1903                          PrimitiveType_Name(prim_type));
1904   }
1905 }
1906 
// Emits one element of a kConcatenate at `source_index`. Control flow is
// built as: a bisection tree over the per-operand offsets along the concat
// dimension selects which operand the index falls into, branches to a
// per-operand block that evaluates that operand's generator at the adjusted
// index, and all per-operand blocks feed a single PHI in a merge block.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& source_index) {
  const int64 concat_dim = hlo->dimensions(0);
  llvm::BasicBlock* init_block = b_->GetInsertBlock();

  // Set up `exit_block` (the merge point) so that whatever followed the
  // current insert point still runs after the concatenate logic.
  llvm::BasicBlock* exit_block;
  if (b_->GetInsertPoint() != init_block->end()) {
    // Inserting into the middle.
    CHECK(init_block->getTerminator());
    exit_block =
        init_block->splitBasicBlock(b_->GetInsertPoint(), IrName(hlo, "merge"));
    // Drop the unconditional branch splitBasicBlock created; the branch to
    // the bisection tree is emitted later (see Br(emit_tree(cases)) below).
    init_block->getTerminator()->eraseFromParent();
  } else {
    // Inserting at the end.
    CHECK(!init_block->getTerminator());
    exit_block = llvm_ir::CreateBasicBlock(
        /*insert_before=*/nullptr, IrName(hlo, "merge"), b_);
  }

  // The result PHI collects one incoming value per operand use.
  llvm_ir::SetToFirstInsertPoint(exit_block, b_);
  llvm::PHINode* output = b_->CreatePHI(
      llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
      hlo->operands().size());
  // Remembered so the builder can be restored to just after the PHI once
  // all blocks are emitted.
  auto prior_insert_point = b_->GetInsertPoint();

  b_->SetInsertPoint(init_block);

  // Assign a unique id for each *different* operand, and count how often each
  // operand is used. If all operands are different, the usage count will be 1
  // for each operand.
  absl::flat_hash_map<const HloInstruction*, int64> to_unique_operand_id;
  std::vector<int64> operand_usage_count;
  for (const HloInstruction* operand : hlo->operands()) {
    if (to_unique_operand_id.contains(operand)) {
      ++operand_usage_count[to_unique_operand_id[operand]];
    } else {
      int64 unique_operand_id = to_unique_operand_id.size();
      to_unique_operand_id[operand] = unique_operand_id;
      operand_usage_count.push_back(1);
    }
  }

  // To avoid that we emit the same operand more than once, we create one basic
  // block for each *different* operand with a PHI node for the different source
  // index inputs.
  std::vector<llvm::BasicBlock*> emit_operand_blocks(
      to_unique_operand_id.size(), nullptr);
  std::vector<llvm::PHINode*> source_index_phis(to_unique_operand_id.size(),
                                                nullptr);
  for (const HloInstruction* operand : hlo->operands()) {
    int64 operand_id = to_unique_operand_id[operand];
    if (emit_operand_blocks[operand_id] != nullptr) {
      continue;
    }

    emit_operand_blocks[operand_id] = llvm_ir::CreateBasicBlock(
        exit_block, StrCat("concat_index_from_operand_id", operand_id), b_);
    auto saved_insert_point = b_->GetInsertPoint();
    llvm_ir::SetToFirstInsertPoint(emit_operand_blocks[operand_id], b_);
    // The PHI receives, from each bisection leaf that targets this operand,
    // the operand's starting offset along the concat dimension; subtracting
    // it converts the output index into this operand's local index.
    source_index_phis[operand_id] =
        b_->CreatePHI(source_index.GetType(), operand_usage_count[operand_id]);
    std::vector<llvm::Value*> operand_multi_index = source_index.multidim();
    operand_multi_index[concat_dim] = b_->CreateNSWSub(
        operand_multi_index[concat_dim], source_index_phis[operand_id]);

    // Create the terminator of the block before calling operand generators,
    // because they require non-degenerate basic blocks.
    b_->SetInsertPoint(llvm::BranchInst::Create(
        exit_block, /*InsertAtEnd=*/emit_operand_blocks[operand_id]));
    llvm_ir::IrArray::Index operand_index(operand_multi_index, operand->shape(),
                                          source_index.GetType());

    TF_ASSIGN_OR_RETURN(llvm::Value * value,
                        operand_to_generator.at(operand)(operand_index));
    // The generator may have created new blocks, so the incoming edge for
    // the output PHI is whatever block the builder now sits in.
    output->addIncoming(value, b_->GetInsertBlock());
    b_->SetInsertPoint(init_block, saved_insert_point);
  }

  // We use bisection to select the input operand.
  int64 current_offset = 0;

  // Offset for every operand.
  std::vector<std::pair<int64, const HloInstruction*>> cases;

  cases.reserve(hlo->operand_count());
  for (const HloInstruction* operand : hlo->operands()) {
    cases.emplace_back(current_offset, operand);
    current_offset += operand->shape().dimensions(concat_dim);
  }
  CHECK_EQ(current_offset, hlo->shape().dimensions(concat_dim));

  // Recursively emits a binary-search tree of blocks over `operands` (each
  // paired with its start offset) and returns the tree's entry block.
  std::function<llvm::BasicBlock*(
      absl::Span<const std::pair<int64, const HloInstruction*>> operands)>
      emit_tree = [&](absl::Span<const std::pair<int64, const HloInstruction*>>
                          operands) {
        llvm::IRBuilder<>::InsertPointGuard guard(*b_);
        size_t mid = operands.size() / 2;
        const std::pair<int64, const HloInstruction*>& pivot = operands[mid];
        llvm::BasicBlock* block = llvm_ir::CreateBasicBlock(
            exit_block, absl::StrCat("concatenate.pivot.", pivot.first, "."),
            b_);
        b_->SetInsertPoint(block);

        // If there's only one element we're done. The range is contiguous so we
        // can just jump to the block for it.
        if (operands.size() == 1) {
          const std::pair<int64, const HloInstruction*>& operand =
              operands.back();
          int64 operand_id = to_unique_operand_id[operand.second];

          source_index_phis[operand_id]->addIncoming(
              source_index.GetConstantWithIndexType(operand.first),
              b_->GetInsertBlock());
          b_->CreateBr(emit_operand_blocks[operand_id]);
          return block;
        }

        // Take the middle element and recurse.
        llvm::Constant* pivot_const = llvm::ConstantInt::get(
            source_index[concat_dim]->getType(), pivot.first);
        llvm::Value* comp =
            b_->CreateICmpULT(source_index[concat_dim], pivot_const);

        llvm::BasicBlock* left_block = emit_tree(operands.subspan(0, mid));
        llvm::BasicBlock* right_block = emit_tree(operands.subspan(mid));

        b_->CreateCondBr(comp, left_block, right_block);
        return block;
      };

  // Branch from init_block into the root of the bisection tree.
  Br(emit_tree(cases));

  // Resume emission in the merge block, just after the output PHI.
  b_->SetInsertPoint(exit_block, prior_insert_point);
  return output;
}
2044 
// Emits one element of a kDynamicSlice:
//   output[index] = input[clamped_start + index]
// where the per-dimension start indices are read from scalar operands
// 1..rank and clamped so the slice region fits inside the input.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& index) {
  // Emit IR to read dynamic start indices from hlo->operand(1).
  const HloInstruction* input_hlo = hlo->operand(0);
  const int64 rank = input_hlo->shape().rank();
  // Use the same index type for all tensor accesses in the same kernel.
  llvm::Type* index_type = index.GetType();
  std::vector<llvm::Value*> slice_start_multi_index(rank);
  for (int64 i = 0; i < rank; ++i) {
    auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
      return llvm::ConstantInt::get(index_type, c);
    };
    // The start index for dimension i is operand(1 + i); it is a scalar, so
    // it is read at the empty (zero-dimensional) index.
    llvm_ir::IrArray::Index zero_index(index_type);
    TF_ASSIGN_OR_RETURN(
        llvm::Value * start_index_value,
        operand_to_generator.at(hlo->operand(1 + i))(zero_index));

    // Clamp the start index so that the sliced portion fits in the operand:
    // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size)
    start_index_value = SExtOrTrunc(start_index_value, index_type);
    int64 largest_valid_start_index =
        input_hlo->shape().dimensions(i) - hlo->shape().dimensions(i);
    CHECK_GE(largest_valid_start_index, 0);

    // Signedness is taken from operand(1) for every dimension — presumably
    // all start-index operands share an element type; confirm at call sites.
    bool is_signed = ShapeUtil::ElementIsSigned(hlo->operand(1)->shape());
    start_index_value = EmitIntegralMin(
        index_typed_const(largest_valid_start_index),
        EmitIntegralMax(index_typed_const(0), start_index_value, is_signed),
        is_signed);

    start_index_value->setName(IrName(hlo, StrCat("start_idx", i)));
    slice_start_multi_index[i] = start_index_value;
  }

  std::vector<llvm::Value*> input_multi_index(rank);
  for (int64 i = 0; i < rank; ++i) {
    // Emit IR which computes:
    //   input_index = start_index + offset_index
    input_multi_index[i] = Add(slice_start_multi_index[i], index[i]);
  }
  llvm_ir::IrArray::Index input_index(input_multi_index, input_hlo->shape(),
                                      index_type);
  return operand_to_generator.at(input_hlo)(input_index);
}
2091 
// Emits one element of a kGather at output `index`: builds the operand index
// by (1) copying the offset-dimension components of `index` into the operand
// index, then (2) reading the gather start indices from the indices operand
// and adding each (clamped) component to the corresponding operand dimension.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& index) {
  const Shape& operand_shape = hlo->operand(0)->shape();
  const Shape& indices_shape = hlo->operand(1)->shape();
  const Shape& output_shape = hlo->shape();

  const GatherDimensionNumbers& dim_numbers = hlo->gather_dimension_numbers();

  const llvm_ir::ElementGenerator& operand_generator =
      operand_to_generator.at(hlo->operand(0));
  const llvm_ir::ElementGenerator& indices_generator =
      operand_to_generator.at(hlo->operand(1));

  llvm::Type* index_type = index.GetType();
  // This is the index into `operand` that holds the element we want to
  // generate.
  std::vector<llvm::Value*> operand_multi_index;

  // First copy in the window indices to operand_index. Also collect a mapping
  // from operand dimension to output window dimension. Elided window dimensions
  // map to -1.
  std::vector<int64> operand_to_output_dim(operand_shape.dimensions_size(), -1);
  for (int64 i = 0, e = operand_shape.dimensions_size(), operand_index_dim = 0;
       i < e; i++) {
    if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
      // Collapsed slice dims have window size 1, so the in-window offset is 0.
      operand_multi_index.push_back(index.GetConstantWithIndexType(0));
    } else {
      int64 output_window_dim = dim_numbers.offset_dims(operand_index_dim++);
      operand_to_output_dim[i] = output_window_dim;
      operand_multi_index.push_back(index[output_window_dim]);
    }
  }

  // This is the index of the index vector in the start_indices tensor.
  std::vector<llvm::Value*> gather_index_index_components;
  {
    // The non-offset (batch) dimensions of the output index select which
    // index vector to read.
    for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) {
      if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) {
        gather_index_index_components.push_back(index[i]);
      }
    }

    // Leave a placeholder slot at index_vector_dim; it is filled in per
    // component in the loop below.
    if (gather_index_index_components.size() !=
        indices_shape.dimensions_size()) {
      gather_index_index_components.insert(
          gather_index_index_components.begin() +
              dim_numbers.index_vector_dim(),
          nullptr);
    }
  }

  // Adds the dim-th component of the current index vector (clamped so the
  // gathered window fits inside the operand) to operand_multi_index.
  auto add_to_operand_index = [&](llvm::Value* index_component, int64 dim) {
    auto index_component_type = index_component->getType();
    auto extended_type = index_component_type->getScalarSizeInBits() >=
                                 index_type->getScalarSizeInBits()
                             ? index_component_type
                             : index_type;
    // Possibly extend the value at the beginning to ensure clamping logic stays
    // in bounds.
    auto maybe_extended_index =
        index_component_type != extended_type
            ? b_->CreateSExt(index_component, extended_type)
            : index_component;
    int64 operand_dim = dim_numbers.start_index_map(dim);
    int64 output_dim = operand_to_output_dim[operand_dim];
    // If 'output_dim' is -1, it means 'operand_dim' is an elided window dim.
    // This means we set the iteration index to 0, so for the purpose of the
    // following calculations we can consider the output dimension size to be 1.
    int64 output_dim_size =
        output_dim == -1 ? 1 : output_shape.dimensions(output_dim);
    int64 largest_valid_start_index =
        operand_shape.dimensions(operand_dim) - output_dim_size;
    CHECK_GE(largest_valid_start_index, 0);

    // Clamp the gather index so that the gather region fits in the operand.
    // clamped_index =
    //     clamp(gather_dim_component_extended, 0, largest_valid_start_index);
    bool is_signed = ShapeUtil::ElementIsSigned(indices_shape);
    auto clamped_index = EmitIntegralMin(
        llvm::ConstantInt::get(extended_type, largest_valid_start_index),
        EmitIntegralMax(llvm::ConstantInt::get(extended_type, 0),
                        maybe_extended_index, is_signed),
        is_signed);
    // Truncate at the end to the optimized index size
    auto maybe_truncated_clamped_index = extended_type != index_type
                                             ? Trunc(clamped_index, index_type)
                                             : clamped_index;

    operand_multi_index[operand_dim] =
        Add(operand_multi_index[operand_dim], maybe_truncated_clamped_index);
  };

  if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) {
    // index_vector_dim is the implicit trailing dimension: each index vector
    // has a single component.
    IrArray::Index gather_index_index(gather_index_index_components,
                                      indices_shape, index_type);
    TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
                        indices_generator(gather_index_index));
    add_to_operand_index(gather_dim_component, 0);
  } else {
    // Read each component of the index vector by iterating over
    // index_vector_dim, filling the placeholder slot left above.
    int64 index_vector_size =
        indices_shape.dimensions(dim_numbers.index_vector_dim());
    for (int64 i = 0; i < index_vector_size; i++) {
      gather_index_index_components[dim_numbers.index_vector_dim()] =
          index.GetConstantWithIndexType(i);
      IrArray::Index gather_index_index(gather_index_index_components,
                                        indices_shape, index_type);
      TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
                          indices_generator(gather_index_index));
      add_to_operand_index(gather_dim_component, i);
    }
  }
  IrArray::Index operand_index(operand_multi_index, operand_shape, index_type);
  return operand_generator(operand_index);
}
2208 
// Emits one element of a kDynamicUpdateSlice: if `index` falls inside the
// (clamped) update window, the element comes from the update operand;
// otherwise it comes from the input operand. The two cases are emitted as an
// if/then/else writing through a function-entry alloca.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& index) {
  const HloInstruction* input_hlo = hlo->operand(0);
  const HloInstruction* update_hlo = hlo->operand(1);
  const HloInstruction* start_hlo = hlo->operand(2);
  // Calculate slice start/end indices.
  const int64 rank = input_hlo->shape().rank();
  std::vector<llvm::Value*> slice_start_multi_index(rank);
  std::vector<llvm::Value*> slice_limit_multi_index(rank);
  // Slice intersection gathers (ANDs) conditions on all ranks for which
  // 'input' is set to 'update'
  llvm::Value* slice_intersection = b_->getTrue();

  for (int64 i = 0; i < rank; ++i) {
    llvm::Type* index_type = index[0]->getType();
    auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
      return llvm::ConstantInt::get(index_type, c);
    };

    // The start index for dimension i is operand(2 + i); it is a scalar, so
    // it is read at the empty (zero-dimensional) index.
    llvm_ir::IrArray::Index zero_index(index_type);
    TF_ASSIGN_OR_RETURN(
        llvm::Value * start_index_value,
        operand_to_generator.at(hlo->operand(2 + i))(zero_index));

    // Clamp the start index so that the update region fits in the operand.
    // start_index = clamp(start_index, 0, input_dim_size - update_dim_size)
    start_index_value = SExtOrTrunc(start_index_value, index_type);
    llvm::Value* update_dim_size =
        index_typed_const(update_hlo->shape().dimensions(i));
    int64 largest_valid_start_index =
        input_hlo->shape().dimensions(i) - update_hlo->shape().dimensions(i);
    CHECK_GE(largest_valid_start_index, 0);

    // Signedness is taken from operand(2) for every dimension — presumably
    // all start-index operands share an element type; confirm at call sites.
    bool is_signed = ShapeUtil::ElementIsSigned(start_hlo->shape());
    start_index_value = EmitIntegralMin(
        index_typed_const(largest_valid_start_index),
        EmitIntegralMax(index_typed_const(0), start_index_value, is_signed),
        is_signed);

    start_index_value->setName(IrName(hlo, StrCat("start_idx", i)));
    slice_start_multi_index[i] = start_index_value;
    slice_limit_multi_index[i] =
        Add(slice_start_multi_index[i], update_dim_size);

    // index[i] must satisfy start <= index[i] < limit for the update to
    // apply in this dimension.
    slice_intersection =
        And(slice_intersection, ICmpSGE(index[i], slice_start_multi_index[i]),
            "slice_intersection");
    slice_intersection =
        And(slice_intersection, ICmpSLT(index[i], slice_limit_multi_index[i]),
            "slice_intersection");
  }

  // Emit:
  // if (slice_intersection) -> return data from 'update'.
  // else                    -> return data from 'input'.
  llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
      llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
      "ret_value_addr", b_);
  llvm_ir::LlvmIfData if_data =
      llvm_ir::EmitIfThenElse(slice_intersection, "slice_intersection", b_);

  // Handle true BB (return data from 'update')
  SetToFirstInsertPoint(if_data.true_block, b_);
  // Compute update index for intersection case.
  std::vector<llvm::Value*> update_multi_index(rank);
  for (int64 i = 0; i < rank; ++i) {
    update_multi_index[i] = Sub(index[i], slice_start_multi_index[i]);
  }
  llvm_ir::IrArray::Index update_index(update_multi_index, update_hlo->shape(),
                                       index.GetType());
  TF_ASSIGN_OR_RETURN(llvm::Value * true_value,
                      operand_to_generator.at(update_hlo)(update_index));
  Store(true_value, ret_value_addr);

  // Handle false BB (return data from 'input')
  SetToFirstInsertPoint(if_data.false_block, b_);
  TF_ASSIGN_OR_RETURN(llvm::Value * false_value,
                      operand_to_generator.at(input_hlo)(index));
  Store(false_value, ret_value_addr);

  SetToFirstInsertPoint(if_data.after_block, b_);
  return Load(ret_value_addr);
}
2294 
// Emits one element of a kPad at `padded_index`: maps the padded index back
// to a source index (undoing edge and interior padding per dimension) and
// returns the source element when that index is in bounds and lands on a
// non-interior-padding position; otherwise returns the scalar padding value
// (operand 1).
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalPad(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& padded_index) {
  std::vector<llvm::Value*> multi_index = padded_index.multidim();
  llvm::Value* in_bounds = b_->getTrue();
  for (size_t i = 0; i < multi_index.size(); ++i) {
    auto index_typed_const = [=](int64 n) {
      return padded_index.GetConstantWithIndexType(n);
    };
    const auto& pad_dim = hlo->padding_config().dimensions(i);
    // Undo the low edge padding; a negative result means we are in the low
    // padding region.
    multi_index[i] =
        Sub(multi_index[i], index_typed_const(pad_dim.edge_padding_low()));
    in_bounds = And(in_bounds, ICmpSGE(multi_index[i], index_typed_const(0)),
                    "in_bounds");
    // A source element exists only where the offset is a multiple of
    // (interior_padding + 1); other positions are interior padding.
    in_bounds =
        And(in_bounds,
            ICmpEQ(index_typed_const(0),
                   URem(multi_index[i],
                        index_typed_const(pad_dim.interior_padding() + 1))),
            "in_bounds");
    // Collapse the interior padding to recover the source index.
    multi_index[i] =
        SDiv(multi_index[i], index_typed_const(pad_dim.interior_padding() + 1));
    // Indices at or past the source extent are high edge padding.
    in_bounds =
        And(in_bounds,
            ICmpSLT(multi_index[i],
                    index_typed_const(hlo->operand(0)->shape().dimensions(i))),
            "in_bounds");
  }

  // if (in_bounds) {
  //   ret_value = operand0[index];  // source
  // } else {
  //   ret_value = *operand1;        // padding
  // }
  llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
      llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
      "pad_result_addr", b_);
  llvm_ir::LlvmIfData if_data =
      llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_);
  SetToFirstInsertPoint(if_data.true_block, b_);
  llvm_ir::IrArray::Index index(multi_index, hlo->operand(0)->shape(),
                                padded_index.GetType());
  TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
                      operand_to_generator.at(hlo->operand(0))(index));
  Store(operand_value, ret_value_addr);

  SetToFirstInsertPoint(if_data.false_block, b_);
  // The padding value is a scalar, read at the empty index.
  TF_ASSIGN_OR_RETURN(llvm::Value * padding_value,
                      operand_to_generator.at(hlo->operand(1))(
                          IrArray::Index(index.GetType())));
  Store(padding_value, ret_value_addr);

  SetToFirstInsertPoint(if_data.after_block, b_);
  // Don't create phi(operand_value, padding_value) here, because invoking
  // operand_to_generator may create new basic blocks, making the parent
  // of operand_value or padding_value no longer a predecessor of
  // if_data.after_block.
  return Load(ret_value_addr);
}
2355 
// Emits one element of a kDot at `dot_result_index` by generating an inner
// reduction loop over the (single) contracted dimension, accumulating
// lhs*rhs products in a function-entry alloca.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& dot_result_index) {
  auto lhs_generator = operand_to_generator.at(hlo->operand(0));
  auto rhs_generator = operand_to_generator.at(hlo->operand(1));

  // Only a single contracting dimension per side is handled here.
  const DotDimensionNumbers& dim_numbers = hlo->dot_dimension_numbers();
  int64 lhs_contracting_dim = dim_numbers.lhs_contracting_dimensions(0);
  int64 rhs_contracting_dim = dim_numbers.rhs_contracting_dimensions(0);

  int64 contracted_dim_size =
      hlo->operand(0)->shape().dimensions(lhs_contracting_dim);
  int64 lhs_dims = hlo->operand(0)->shape().dimensions_size();
  int64 rhs_dims = hlo->operand(1)->shape().dimensions_size();

  llvm::Type* index_type = dot_result_index.GetType();
  auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
    return llvm::ConstantInt::get(index_type, c);
  };

  // Loop t over [0, contracted_dim_size).
  std::unique_ptr<llvm_ir::ForLoop> inner_loop = llvm_ir::ForLoop::EmitForLoop(
      IrName(hlo, "inner"), index_typed_const(0),
      index_typed_const(contracted_dim_size), index_typed_const(1), b_);

  // Zero-initialize the accumulator in the loop preheader.
  SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), b_);
  PrimitiveType primitive_type = hlo->shape().element_type();
  llvm::Type* primitive_type_llvm =
      llvm_ir::PrimitiveTypeToIrType(primitive_type, module_);
  llvm::Value* accumulator_alloca =
      llvm_ir::EmitAllocaAtFunctionEntry(primitive_type_llvm, "dot_acc", b_);
  Store(llvm::Constant::getNullValue(primitive_type_llvm), accumulator_alloca);

  SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), b_);

  // This is the inner reduction loop for a dot operation that produces
  // one element in the output.  If the operands to the dot operation have
  // shapes [A,B,C,T] and [D,T,E], the result has a shape [A,B,C,D,E].
  // Given an output index [a,b,c,d,e] in the result, we compute:
  //   sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T))

  // LHS index: the leading lhs_dims-1 output components, with the loop
  // induction variable spliced in at the contracting dimension.
  std::vector<llvm::Value*> lhs_multi_index, rhs_multi_index;
  for (int64 i = 0; i < lhs_dims - 1; i++) {
    lhs_multi_index.push_back(dot_result_index[i]);
  }
  lhs_multi_index.insert(lhs_multi_index.begin() + lhs_contracting_dim,
                         inner_loop->GetIndVarValue());
  IrArray::Index lhs_index(lhs_multi_index, hlo->operand(0)->shape(),
                           index_type);

  // RHS index: batch components first, then the trailing output components,
  // with the induction variable spliced in at the contracting dimension.
  int64 num_batch_dims = dim_numbers.rhs_batch_dimensions_size();
  for (int64 i = 0; i < num_batch_dims; i++) {
    rhs_multi_index.push_back(
        dot_result_index[dim_numbers.rhs_batch_dimensions(i)]);
  }
  for (int64 i = 0; i < rhs_dims - 1 - num_batch_dims; i++) {
    rhs_multi_index.push_back(dot_result_index[lhs_dims - 1 + i]);
  }
  rhs_multi_index.insert(rhs_multi_index.begin() + rhs_contracting_dim,
                         inner_loop->GetIndVarValue());
  IrArray::Index rhs_index(rhs_multi_index, hlo->operand(1)->shape(),
                           index_type);

  // accumulator += lhs[t] * rhs[t]
  llvm::Value* current_accumulator = Load(accumulator_alloca);
  TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index));
  TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index));
  llvm::Value* next_accumulator =
      EmitMulAdd(lhs_value, rhs_value, current_accumulator, primitive_type);
  Store(next_accumulator, accumulator_alloca);

  SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), b_);
  return Load(accumulator_alloca);
}
2429 
MakeElementGenerator(const HloInstruction * hlo,const ElementalIrEmitter::HloToElementGeneratorMap & operand_to_generator)2430 llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
2431     const HloInstruction* hlo,
2432     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) {
2433   switch (hlo->opcode()) {
2434     case HloOpcode::kAbs:
2435     case HloOpcode::kRoundNearestAfz:
2436     case HloOpcode::kCeil:
2437     case HloOpcode::kClz:
2438     case HloOpcode::kConvert:
2439     case HloOpcode::kBitcastConvert:
2440     case HloOpcode::kCos:
2441     case HloOpcode::kExp:
2442     case HloOpcode::kExpm1:
2443     case HloOpcode::kFloor:
2444     case HloOpcode::kImag:
2445     case HloOpcode::kIsFinite:
2446     case HloOpcode::kLog:
2447     case HloOpcode::kLog1p:
2448     case HloOpcode::kNegate:
2449     case HloOpcode::kNot:
2450     case HloOpcode::kPopulationCount:
2451     case HloOpcode::kReal:
2452     case HloOpcode::kRsqrt:
2453     case HloOpcode::kSign:
2454     case HloOpcode::kSin:
2455     case HloOpcode::kSqrt:
2456     case HloOpcode::kCbrt:
2457     case HloOpcode::kTanh:
2458       return [this, hlo, &operand_to_generator](
2459                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2460         TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
2461                             operand_to_generator.at(hlo->operand(0))(index));
2462         return EmitUnaryOp(hlo, operand_value);
2463       };
2464     case HloOpcode::kAdd:
2465     case HloOpcode::kAnd:
2466     case HloOpcode::kAtan2:
2467     case HloOpcode::kCompare:
2468     case HloOpcode::kComplex:
2469     case HloOpcode::kDivide:
2470     case HloOpcode::kMaximum:
2471     case HloOpcode::kMinimum:
2472     case HloOpcode::kMultiply:
2473     case HloOpcode::kOr:
2474     case HloOpcode::kXor:
2475     case HloOpcode::kPower:
2476     case HloOpcode::kRemainder:
2477     case HloOpcode::kShiftLeft:
2478     case HloOpcode::kShiftRightArithmetic:
2479     case HloOpcode::kShiftRightLogical:
2480     case HloOpcode::kSubtract:
2481       return [this, hlo, &operand_to_generator](
2482                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2483         const HloInstruction* lhs = hlo->operand(0);
2484         const HloInstruction* rhs = hlo->operand(1);
2485         TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value,
2486                             operand_to_generator.at(lhs)(index));
2487         TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value,
2488                             operand_to_generator.at(rhs)(index));
2489         return EmitBinaryOp(hlo, lhs_value, rhs_value);
2490       };
2491     case HloOpcode::kSelect:
2492       return [this, hlo, &operand_to_generator](
2493                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2494         return EmitElementalSelect(hlo, operand_to_generator, index);
2495       };
2496     case HloOpcode::kClamp:
2497       return [this, hlo, &operand_to_generator](
2498                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2499         return EmitElementalClamp(hlo, operand_to_generator, index);
2500       };
2501     case HloOpcode::kReducePrecision:
2502       return [this, hlo, &operand_to_generator](
2503                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2504         TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
2505                             operand_to_generator.at(hlo->operand(0))(index));
2506         return EmitReducePrecision(hlo, operand_value);
2507       };
2508     case HloOpcode::kConcatenate:
2509       return [this, hlo, &operand_to_generator](
2510                  const IrArray::Index target_index) -> StatusOr<llvm::Value*> {
2511         return EmitElementalConcatenate(hlo, operand_to_generator,
2512                                         target_index);
2513       };
2514     case HloOpcode::kReverse:
2515       return [this, hlo, &operand_to_generator](
2516                  const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
2517         const HloInstruction* operand = hlo->operand(0);
2518         std::vector<llvm::Value*> source_multi_index = target_index.multidim();
2519         for (int64 dim : hlo->dimensions()) {
2520           source_multi_index[dim] = Sub(target_index.GetConstantWithIndexType(
2521                                             hlo->shape().dimensions(dim) - 1),
2522                                         target_index[dim]);
2523         }
2524         llvm_ir::IrArray::Index source_index(
2525             source_multi_index, operand->shape(), target_index.GetType());
2526         return operand_to_generator.at(operand)(source_index);
2527       };
2528     case HloOpcode::kBroadcast:
2529       return [this, hlo, &operand_to_generator](
2530                  const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
2531         const HloInstruction* operand = hlo->operand(0);
2532         // The `dimensions` member of the broadcast instruction maps from
2533         // input dimensions to output dimensions.
2534         return operand_to_generator.at(operand)(
2535             target_index.SourceIndexOfBroadcast(hlo->shape(), operand->shape(),
2536                                                 hlo->dimensions(), b_));
2537       };
2538     case HloOpcode::kIota:
2539       return [this, hlo](
2540                  const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
2541         auto* iota = Cast<HloIotaInstruction>(hlo);
2542         PrimitiveType element_type = iota->shape().element_type();
2543         IrArray::Index elem_index =
2544             iota->shape().rank() > 1
2545                 ? target_index.SourceIndexOfBroadcast(
2546                       iota->shape(),
2547                       ShapeUtil::MakeShapeWithDescendingLayout(
2548                           element_type,
2549                           {iota->shape().dimensions(iota->iota_dimension())}),
2550                       {iota->iota_dimension()}, b_)
2551                 : target_index;
2552         llvm::Value* elem_index_linear = elem_index.linear();
2553         if (elem_index_linear == nullptr) {
2554           std::vector<int64> iota_bound = {
2555               iota->shape().dimensions(iota->iota_dimension())};
2556           elem_index_linear = elem_index.Linearize(iota_bound, b_);
2557         }
2558         Shape component_shape =
2559             ShapeUtil::ElementIsComplex(iota->shape())
2560                 ? ShapeUtil::ComplexComponentShape(iota->shape())
2561                 : iota->shape();
2562         PrimitiveType component_element_type = component_shape.element_type();
2563         llvm::Value* iota_result;
2564         if (primitive_util::IsIntegralType(component_element_type)) {
2565           iota_result = b_->CreateIntCast(
2566               elem_index_linear,
2567               llvm_ir::PrimitiveTypeToIrType(component_element_type, module_),
2568               /*isSigned=*/false);
2569         } else {
2570           TF_RET_CHECK(
2571               primitive_util::IsFloatingPointType(component_element_type))
2572               << component_element_type;
2573           llvm::Type* float_ir_type;
2574           if (component_element_type == BF16) {
2575             float_ir_type = llvm_ir::PrimitiveTypeToIrType(F32, module_);
2576           } else {
2577             float_ir_type =
2578                 llvm_ir::PrimitiveTypeToIrType(component_element_type, module_);
2579           }
2580           llvm::Value* float_val =
2581               b_->CreateUIToFP(elem_index_linear, float_ir_type);
2582           if (component_element_type == BF16) {
2583             TF_ASSIGN_OR_RETURN(iota_result, EmitF32ToBF16(float_val, b_));
2584           } else {
2585             iota_result = float_val;
2586           }
2587         }
2588         if (ShapeUtil::ElementIsComplex(iota->shape())) {
2589           return EmitComposeComplex(iota, iota_result, nullptr);
2590         } else {
2591           return iota_result;
2592         }
2593       };
2594     case HloOpcode::kSlice:
2595       return [this, hlo, &operand_to_generator](
2596                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2597         IrArray::Index sliced_index = index.SourceIndexOfSlice(
2598             /*operand_shape=*/hlo->operand(0)->shape(),
2599             /*starts=*/hlo->slice_starts(),
2600             /*strides=*/hlo->slice_strides(), /*builder=*/b_);
2601         return operand_to_generator.at(hlo->operand(0))(sliced_index);
2602       };
2603     case HloOpcode::kDynamicSlice:
2604       return [this, hlo, &operand_to_generator](
2605                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2606         return EmitElementalDynamicSlice(hlo, operand_to_generator, index);
2607       };
2608 
2609     case HloOpcode::kGather:
2610       return [this, hlo, &operand_to_generator](
2611                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2612         return EmitElementalGather(hlo, operand_to_generator, index);
2613       };
2614     case HloOpcode::kDynamicUpdateSlice:
2615       return [this, hlo, &operand_to_generator](
2616                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2617         return EmitElementalDynamicUpdateSlice(hlo, operand_to_generator,
2618                                                index);
2619       };
2620     case HloOpcode::kBitcast:
2621       CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()),
2622                ShapeUtil::ElementsIn(hlo->operand(0)->shape()));
2623       return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
2624         const HloInstruction* operand = hlo->operand(0);
2625         return operand_to_generator.at(operand)(
2626             GetSourceIndexOfBitcast(index, hlo));
2627       };
2628     case HloOpcode::kReshape:
2629       CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()),
2630                ShapeUtil::ElementsIn(hlo->operand(0)->shape()));
2631       return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
2632         const HloInstruction* operand = hlo->operand(0);
2633         return operand_to_generator.at(operand)(
2634             index.SourceIndexOfReshape(hlo->shape(), operand->shape(), b_));
2635       };
2636     case HloOpcode::kCopy:
2637       return [hlo, &operand_to_generator](
2638                  const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
2639         IrArray::Index source_index(target_index.multidim(),
2640                                     hlo->operand(0)->shape(),
2641                                     target_index.GetType());
2642         TF_ASSIGN_OR_RETURN(
2643             llvm::Value * operand_value,
2644             operand_to_generator.at(hlo->operand(0))(source_index));
2645         return operand_value;
2646       };
2647     case HloOpcode::kTranspose:
2648       return [this, hlo,
2649               &operand_to_generator](const IrArray::Index& target_index) {
2650         return operand_to_generator.at(hlo->operand(0))(
2651             target_index.SourceIndexOfTranspose(
2652                 hlo->shape(), hlo->operand(0)->shape(), hlo->dimensions()));
2653       };
2654     case HloOpcode::kPad:
2655       return [this, hlo, &operand_to_generator](
2656                  const IrArray::Index& padded_index) -> StatusOr<llvm::Value*> {
2657         return EmitElementalPad(hlo, operand_to_generator, padded_index);
2658       };
2659 
2660     case HloOpcode::kDot:
2661       return [this, hlo,
2662               &operand_to_generator](const IrArray::Index& dot_result_index)
2663                  -> StatusOr<llvm::Value*> {
2664         return EmitElementalDot(hlo, operand_to_generator, dot_result_index);
2665       };
2666     case HloOpcode::kMap:
2667       return [this, hlo, &operand_to_generator](
2668                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2669         std::vector<llvm::Value*> operands;
2670         for (int i = 0; i < hlo->operand_count(); i++) {
2671           TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
2672                               operand_to_generator.at(hlo->operand(i))(index));
2673           operands.push_back(operand_value);
2674         }
2675         return EmitElementalMap(Cast<HloMapInstruction>(hlo), operands);
2676       };
2677     case HloOpcode::kReduceWindow:
2678       return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
2679         auto reduce_window_instr = Cast<HloReduceWindowInstruction>(hlo);
2680         std::vector<llvm_ir::ElementGenerator> input_generators;
2681         for (const HloInstruction* instr : reduce_window_instr->inputs()) {
2682           input_generators.push_back(operand_to_generator.at(instr));
2683         }
2684 
2685         std::vector<llvm_ir::ElementGenerator> initial_value_generators;
2686         for (const HloInstruction* instr : reduce_window_instr->init_values()) {
2687           initial_value_generators.push_back(operand_to_generator.at(instr));
2688         }
2689         return EmitElementalReduceWindow(
2690             Cast<HloReduceWindowInstruction>(hlo), std::move(input_generators),
2691             std::move(initial_value_generators), index);
2692       };
2693     case HloOpcode::kReduce:
2694       return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
2695         auto reduce_instr = Cast<HloReduceInstruction>(hlo);
2696         std::vector<llvm_ir::ElementGenerator> input_generators;
2697         for (const HloInstruction* instr : reduce_instr->inputs()) {
2698           input_generators.push_back(operand_to_generator.at(instr));
2699         }
2700 
2701         std::vector<llvm_ir::ElementGenerator> initial_value_generators;
2702         for (const HloInstruction* instr : reduce_instr->init_values()) {
2703           initial_value_generators.push_back(operand_to_generator.at(instr));
2704         }
2705         return EmitElementalReduce(reduce_instr, std::move(input_generators),
2706                                    std::move(initial_value_generators), index);
2707       };
2708     case HloOpcode::kConvolution:
2709       return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
2710         return EmitConvolution(hlo, operand_to_generator, index);
2711       };
2712     default:
2713       return [hlo](const IrArray::Index& index) {
2714         return Unimplemented("Unhandled opcode for elemental IR emission: %s",
2715                              HloOpcodeString(hlo->opcode()));
2716       };
2717   }
2718 }
2719 
EmitExtractReal(llvm::Value * value)2720 llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) {
2721   return ExtractValue(value, {0});
2722 }
2723 
EmitExtractImag(llvm::Value * value)2724 llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) {
2725   return ExtractValue(value, {1});
2726 }
2727 
EmitComposeComplex(const HloInstruction * op,llvm::Value * real,llvm::Value * imag)2728 llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op,
2729                                                     llvm::Value* real,
2730                                                     llvm::Value* imag) {
2731   auto cplx_type =
2732       llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
2733   auto complex =
2734       InsertValue(llvm::ConstantAggregateZero::get(cplx_type), real, {0});
2735   if (imag != nullptr) {
2736     complex = InsertValue(complex, imag, {1});
2737   }
2738   return complex;
2739 }
2740 
EmitMulAdd(llvm::Value * lhs,llvm::Value * rhs,llvm::Value * accumulator,xla::PrimitiveType primitive_type)2741 llvm::Value* ElementalIrEmitter::EmitMulAdd(llvm::Value* lhs, llvm::Value* rhs,
2742                                             llvm::Value* accumulator,
2743                                             xla::PrimitiveType primitive_type) {
2744   if (primitive_util::IsComplexType(primitive_type)) {
2745     llvm::Value* product_real =
2746         FSub(FMul(EmitExtractReal(lhs), EmitExtractReal(rhs)),
2747              FMul(EmitExtractImag(lhs), EmitExtractImag(rhs)));
2748     llvm::Value* product_imag =
2749         FAdd(FMul(EmitExtractReal(lhs), EmitExtractImag(rhs)),
2750              FMul(EmitExtractImag(lhs), EmitExtractReal(rhs)));
2751     llvm::Value* next_accumulator = InsertValue(
2752         accumulator, FAdd(EmitExtractReal(accumulator), product_real), {0});
2753     return InsertValue(next_accumulator,
2754                        FAdd(EmitExtractImag(accumulator), product_imag), {1});
2755   } else if (primitive_util::IsFloatingPointType(primitive_type)) {
2756     return FAdd(accumulator, FPCast(FMul(lhs, rhs), accumulator->getType()));
2757   } else if (primitive_type == PRED) {
2758     return Or(accumulator, And(lhs, rhs));
2759   }
2760   return Add(accumulator, Mul(lhs, rhs));
2761 }
2762 
EmitElementalMap(const HloMapInstruction * map_instr,absl::Span<llvm::Value * const> elemental_operands)2763 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalMap(
2764     const HloMapInstruction* map_instr,
2765     absl::Span<llvm::Value* const> elemental_operands) {
2766   TF_ASSIGN_OR_RETURN(
2767       std::vector<llvm::Value*> values,
2768       EmitThreadLocalCall(*map_instr->to_apply(), elemental_operands,
2769                           llvm_ir::IrName(map_instr)));
2770   CHECK_EQ(values.size(), 1);
2771   return values[0];
2772 }
2773 
EmitElementalReduceWindow(const HloReduceWindowInstruction * reduce_window,std::vector<llvm_ir::ElementGenerator> input_generators,std::vector<llvm_ir::ElementGenerator> initial_value_generators,const llvm_ir::IrArray::Index & index)2774 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduceWindow(
2775     const HloReduceWindowInstruction* reduce_window,
2776     std::vector<llvm_ir::ElementGenerator> input_generators,
2777     std::vector<llvm_ir::ElementGenerator> initial_value_generators,
2778     const llvm_ir::IrArray::Index& index) {
2779   // Pseudocode:
2780   // for each index I in output
2781   //   value = init_value
2782   //   for each index W in window
2783   //     for each dimension i from 0 to rank - 1
2784   //       (input index I)[i] = O[i] * stride[i] + W[i] - pad_low[i]
2785   //     if I in bounds of input
2786   //       value = function(value, input[I])
2787   //     output[O] = value
2788   int64 input_count = reduce_window->input_count();
2789   std::vector<PrimitiveType> operand_element_types;
2790   std::vector<llvm::Type*> accum_types;
2791   std::vector<llvm::Value*> accum_ptrs;
2792   for (int64 operand_index = 0; operand_index < input_count; ++operand_index) {
2793     auto operand = reduce_window->inputs()[operand_index];
2794     PrimitiveType operand_element_type = operand->shape().element_type();
2795     operand_element_types.push_back(operand_element_type);
2796     llvm::Type* llvm_type =
2797         llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_);
2798     accum_types.push_back(llvm_type);
2799     llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry(
2800         llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
2801         "reduce_window_accum_ptr", b_);
2802     accum_ptrs.push_back(accum_ptr);
2803     {
2804       auto initial_value_generator = initial_value_generators[operand_index];
2805       TF_ASSIGN_OR_RETURN(
2806           llvm::Value* const init_value,
2807           initial_value_generator(llvm_ir::IrArray::Index(index.GetType())));
2808       Store(init_value, accum_ptr);
2809     }
2810   }
2811 
2812   llvm::Type* index_type = index.GetType();
2813   auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
2814     return index.GetConstantWithIndexType(c);
2815   };
2816 
2817   const Window& window = reduce_window->window();
2818   llvm_ir::ForLoopNest loops(IrName(reduce_window), b_, index_type);
2819   std::vector<int64> window_size;
2820   for (const auto& dim : window.dimensions()) {
2821     window_size.push_back(dim.size());
2822   }
2823   const IrArray::Index window_index = loops.AddLoopsForShape(
2824       ShapeUtil::MakeShape(operand_element_types[0], window_size), "window");
2825   CHECK_EQ(window_index.size(), index.size());
2826 
2827   SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_);
2828 
2829   std::vector<llvm::Value*> input_multi_index(index.size());
2830   llvm::Value* in_bounds = b_->getInt1(true);
2831   for (size_t i = 0; i < index.size(); ++i) {
2832     llvm::Value* stridden_index =
2833         NSWMul(index[i], index_typed_const(window.dimensions(i).stride()));
2834     input_multi_index[i] = NSWSub(
2835         NSWAdd(
2836             stridden_index,
2837             NSWMul(window_index[i],
2838                    index_typed_const(window.dimensions(i).window_dilation()))),
2839         index_typed_const(window.dimensions(i).padding_low()));
2840 
2841     // We need to verify that we are not in the dilated base area.
2842     llvm::Value* dilation_condition =
2843         ICmpEQ(SRem(input_multi_index[i],
2844                     index_typed_const(window.dimensions(i).base_dilation())),
2845                index_typed_const(0));
2846     in_bounds = And(in_bounds, dilation_condition);
2847 
2848     // Apply base dilation to the index.
2849     input_multi_index[i] =
2850         SDiv(input_multi_index[i],
2851              index_typed_const(window.dimensions(i).base_dilation()));
2852 
2853     // We must check whether 0 <= input_multi_index[i] < bound, as
2854     // otherwise we are in the pad and so can skip the computation. This
2855     // comparison is equivalent to the unsigned comparison
2856     // input_multi_index[i] < bound, as a negative value wraps to a large
2857     // positive value.
2858     in_bounds =
2859         And(in_bounds,
2860             ICmpULT(input_multi_index[i],
2861                     index_typed_const(
2862                         reduce_window->inputs()[0]->shape().dimensions(i))));
2863   }
2864 
2865   llvm_ir::LlvmIfData if_data =
2866       llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_);
2867   SetToFirstInsertPoint(if_data.true_block, b_);
2868 
2869   // We are not in pad, so do the computation.
2870   std::vector<llvm::Value*> input_values(reduce_window->operand_count());
2871   IrArray::Index input_index(input_multi_index,
2872                              reduce_window->inputs()[0]->shape(), index_type);
2873   for (int64 operand_idx = 0; operand_idx < input_count; ++operand_idx) {
2874     TF_ASSIGN_OR_RETURN(llvm::Value * input_value,
2875                         input_generators[operand_idx](input_index));
2876     input_values[input_count + operand_idx] = input_value;
2877     input_values[operand_idx] = Load(accum_ptrs[operand_idx]);
2878   }
2879   TF_ASSIGN_OR_RETURN(std::vector<llvm::Value*> accum_values,
2880                       EmitThreadLocalCall(*reduce_window->to_apply(),
2881                                           input_values, "reducer_function"));
2882 
2883   for (int64 operand_idx = 0; operand_idx < accum_values.size();
2884        ++operand_idx) {
2885     Store(accum_values[operand_idx], accum_ptrs[operand_idx]);
2886   }
2887 
2888   SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_);
2889   return EmitAccumResult(accum_ptrs, accum_types,
2890                          reduce_window->shape().IsTuple());
2891 }
2892 
// Emits the element of a (possibly variadic) reduce at output `index`: builds
// one accumulator per result, loops over the reduced dimensions of the input,
// and folds each input element into the accumulators via the reducer
// computation. Returns a scalar, or a struct of scalars for variadic reduce.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduce(
    const HloReduceInstruction* reduce,
    std::vector<llvm_ir::ElementGenerator> input_generators,
    std::vector<llvm_ir::ElementGenerator> initial_value_generators,
    const llvm_ir::IrArray::Index& index) {
  // A non-array (tuple) output shape marks a variadic reduce with one
  // accumulator per tuple element.
  const Shape& out_shape = reduce->shape();
  bool is_variadic = !out_shape.IsArray();
  int accumulators_count = 1;
  if (is_variadic) {
    CHECK(out_shape.IsTuple());
    accumulators_count = out_shape.tuple_shapes_size();
  }

  absl::Span<const int64> reduced_dimensions(reduce->dimensions());

  std::vector<llvm::Value*> accumulator_addrs;
  std::vector<llvm::Type*> accumulator_types;
  llvm::Type* index_type = index.GetType();
  for (int i = 0; i < accumulators_count; i++) {
    const Shape& element_shape =
        is_variadic ? out_shape.tuple_shapes(i) : out_shape;
    PrimitiveType accumulator_type = element_shape.element_type();
    llvm::Type* accumulator_llvm_type =
        llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_);
    accumulator_types.push_back(accumulator_llvm_type);

    // Initialize an accumulator with init_value.
    llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
        accumulator_llvm_type, "accumulator_" + std::to_string(i), b());
    TF_ASSIGN_OR_RETURN(
        llvm::Value* const init_value,
        initial_value_generators[i](llvm_ir::IrArray::Index(index_type)));
    Store(init_value, accumulator_addr);
    accumulator_addrs.push_back(accumulator_addr);
  }

  // The enclosing loops go over all the target elements. Now we have to compute
  // the actual target element. For this, we build a new loop nest to iterate
  // over all the reduction dimensions in the argument.
  // AddLoopsForShapeOnDimensions will return an Index where induction Value*s
  // are placed for each dimension in dimensions, and all the rest are nullptrs.
  llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), b(), index_type);
  const HloInstruction* arg = reduce->operand(0);
  std::vector<llvm::Value*> input_multi_index =
      loops.AddLoopsForShapeOnDimensions(arg->shape(), reduced_dimensions,
                                         "reduction_dim");

  SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b());

  // Build a full index for the input argument, using input_multi_index as the
  // base. In input_multi_index only the reduction dimensions are filled in. We
  // fill in the rest of the dimensions with induction Value*s taken from
  // 'index' which iterates over the target array.  See the high-level
  // description in the XLA documentation for details.
  auto it = index.begin();

  for (auto& i : input_multi_index) {
    if (i == nullptr) {
      i = *it++;
    }
  }
  // Every non-reduced dimension of the output index must have been consumed.
  CHECK(index.end() == it);
  llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
                                      index_type);

  // Reducer arguments are the current accumulator values followed by the
  // current input elements.
  std::vector<llvm::Value*> reduction_operands;
  for (llvm::Value* accum : accumulator_addrs) {
    llvm::Value* accum_value = Load(accum);
    reduction_operands.push_back(accum_value);
  }

  for (int i = 0; i < accumulators_count; i++) {
    TF_ASSIGN_OR_RETURN(llvm::Value* const input_element,
                        input_generators[i](input_index));
    reduction_operands.push_back(input_element);
  }

  TF_ASSIGN_OR_RETURN(
      std::vector<llvm::Value*> results,
      EmitThreadLocalCall(*reduce->to_apply(), reduction_operands,
                          "reduce_function"));

  // Write the reduced values back into the accumulators for the next
  // iteration of the emitted loop nest.
  CHECK(results.size() == accumulators_count);
  for (int i = 0; i < accumulators_count; i++) {
    Store(results[i], accumulator_addrs[i]);
  }
  SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b());
  return EmitAccumResult(accumulator_addrs, accumulator_types, is_variadic);
}
2982 
EmitAccumResult(absl::Span<llvm::Value * const> accumulator_addrs,llvm::ArrayRef<llvm::Type * > accumulator_types,bool is_variadic)2983 StatusOr<llvm::Value*> ElementalIrEmitter::EmitAccumResult(
2984     absl::Span<llvm::Value* const> accumulator_addrs,
2985     llvm::ArrayRef<llvm::Type*> accumulator_types, bool is_variadic) {
2986   TF_RET_CHECK(accumulator_addrs.size() == accumulator_types.size());
2987   if (is_variadic) {
2988     // Emit a structure, as that what the LoopEmitter expects.
2989     llvm::Value* returned_structure = llvm::UndefValue::get(
2990         llvm::StructType::get(b()->getContext(), accumulator_types));
2991     for (int64 i = 0; i < accumulator_addrs.size(); i++) {
2992       llvm::Value* accumulator_value = Load(accumulator_addrs[i]);
2993       returned_structure =
2994           b()->CreateInsertValue(returned_structure, accumulator_value, i);
2995     }
2996     return returned_structure;
2997   } else {
2998     CHECK_EQ(accumulator_addrs.size(), 1);
2999     return Load(accumulator_addrs[0]);
3000   }
3001 }
3002 
EmitConvolution(const HloInstruction * convolution,const ElementalIrEmitter::HloToElementGeneratorMap & operand_to_generator,const llvm_ir::IrArray::Index & index)3003 StatusOr<llvm::Value*> ElementalIrEmitter::EmitConvolution(
3004     const HloInstruction* convolution,
3005     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
3006     const llvm_ir::IrArray::Index& index) {
3007   const HloInstruction* lhs = convolution->operand(0);
3008   const auto& input_generator = operand_to_generator.at(lhs);
3009   const HloInstruction* rhs = convolution->operand(1);
3010   const auto& kernel_generator = operand_to_generator.at(rhs);
3011   const Window& window = convolution->window();
3012 
3013   const ConvolutionDimensionNumbers& dnums =
3014       convolution->convolution_dimension_numbers();
3015   int num_spatial_dims = dnums.output_spatial_dimensions_size();
3016   std::vector<llvm::Value*> output_spatial(num_spatial_dims);
3017   for (int i = 0; i < num_spatial_dims; ++i) {
3018     output_spatial[i] = index[dnums.output_spatial_dimensions(i)];
3019   }
3020   llvm::Value* output_feature = index[dnums.output_feature_dimension()];
3021   llvm::Value* batch = index[dnums.output_batch_dimension()];
3022 
3023   // We will accumulate the products into this sum to calculate the output entry
3024   // at the given index.
3025   PrimitiveType lhs_element_type = lhs->shape().element_type();
3026   llvm::Type* lhs_llvm_type =
3027       llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_);
3028   // Upcast the accumulator to F32 from F16 for increased precision.
3029   llvm::Type* accumulator_type =
3030       lhs_element_type == F16 ? b_->getFloatTy() : lhs_llvm_type;
3031   llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
3032       accumulator_type, "convolution_sum_address", b_);
3033   llvm::Value* constant_zero = llvm::Constant::getNullValue(accumulator_type);
3034   Store(constant_zero, sum_address);
3035 
3036   llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), b_);
3037   std::vector<llvm::Value*> kernel_spatial(num_spatial_dims);
3038   for (int i = 0; i < num_spatial_dims; ++i) {
3039     kernel_spatial[i] =
3040         loops
3041             .AddLoop(
3042                 0, rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)),
3043                 absl::StrCat("k", i))
3044             ->GetIndVarValue();
3045   }
3046   llvm::Value* input_feature =
3047       loops
3048           .AddLoop(0, lhs->shape().dimensions(dnums.input_feature_dimension()),
3049                    "iz")
3050           ->GetIndVarValue();
3051 
3052   SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_);
3053 
3054   // Calculate the spatial index in the input array, taking striding, dilation
3055   // and padding into account. An index in the padding will be out of the bounds
3056   // of the array.
3057   const auto calculate_input_index = [this](llvm::Value* output_index,
3058                                             llvm::Value* kernel_index,
3059                                             const WindowDimension& window_dim) {
3060     llvm::Value* strided_index =
3061         NSWMul(output_index, b_->getInt64(window_dim.stride()));
3062     llvm::Value* dilated_kernel_index =
3063         NSWMul(kernel_index, b_->getInt64(window_dim.window_dilation()));
3064     return NSWSub(NSWAdd(strided_index, dilated_kernel_index),
3065                   b_->getInt64(window_dim.padding_low()));
3066   };
3067   std::vector<llvm::Value*> input_spatial(num_spatial_dims);
3068   for (int i = 0; i < num_spatial_dims; ++i) {
3069     input_spatial[i] = calculate_input_index(
3070         output_spatial[i], kernel_spatial[i], window.dimensions(i));
3071   }
3072 
3073   // We need to check if 0 <= input dim < bound, as otherwise we are in the
3074   // padding so that we can skip the computation. That is equivalent to input
3075   // dim < bound as an *unsigned* comparison, since a negative value will wrap
3076   // to a large positive value. The input dim is dilated, so we need to dilate
3077   // the bound as well to match.
3078 
3079   // Also need to check that the input coordinates are not in one of the
3080   // holes created by base dilation.
3081   const auto not_in_hole = [&](llvm::Value* input_index, int64 base_dilation) {
3082     llvm::Value* remainder = SRem(input_index, b_->getInt64(base_dilation));
3083     return ICmpEQ(remainder, b_->getInt64(0));
3084   };
3085 
3086   llvm::Value* in_bounds_condition = b_->getInt1(true);
3087   for (int i = 0; i < num_spatial_dims; ++i) {
3088     llvm::ConstantInt* input_bound = b_->getInt64(window_util::DilatedBound(
3089         lhs->shape().dimensions(dnums.input_spatial_dimensions(i)),
3090         window.dimensions(i).base_dilation()));
3091     llvm::Value* dim_in_bound = ICmpULT(input_spatial[i], input_bound);
3092     llvm::Value* dim_not_in_hole =
3093         not_in_hole(input_spatial[i], window.dimensions(i).base_dilation());
3094     llvm::Value* dim_ok = And(dim_in_bound, dim_not_in_hole);
3095     in_bounds_condition = And(in_bounds_condition, dim_ok);
3096   }
3097 
3098   // Now we need to map the dilated base coordinates back to the actual
3099   // data indices on the lhs.
3100   const auto undilate = [&](llvm::Value* input_index, int64 base_dilation) {
3101     return SDiv(input_index, b_->getInt64(base_dilation));
3102   };
3103   for (int i = 0; i < num_spatial_dims; ++i) {
3104     input_spatial[i] =
3105         undilate(input_spatial[i], window.dimensions(i).base_dilation());
3106   }
3107 
3108   llvm_ir::LlvmIfData if_data =
3109       llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", b_);
3110   SetToFirstInsertPoint(if_data.true_block, b_);
3111 
3112   // We are not in the padding, so carry out the computation.
3113   int num_dims = num_spatial_dims + 2;
3114   std::vector<llvm::Value*> input_multi_index(num_dims);
3115   for (int i = 0; i < num_spatial_dims; ++i) {
3116     input_multi_index[dnums.input_spatial_dimensions(i)] = input_spatial[i];
3117   }
3118   input_multi_index[dnums.input_feature_dimension()] = input_feature;
3119   input_multi_index[dnums.input_batch_dimension()] = batch;
3120 
3121   std::vector<llvm::Value*> kernel_multi_index(num_dims);
3122   for (int i = 0; i < num_spatial_dims; ++i) {
3123     kernel_multi_index[dnums.kernel_spatial_dimensions(i)] =
3124         window.dimensions(i).window_reversal()
3125             ? NSWSub(b_->getInt64(window.dimensions(i).size() - 1),
3126                      kernel_spatial[i])
3127             : kernel_spatial[i];
3128   }
3129 
3130   kernel_multi_index[dnums.kernel_input_feature_dimension()] = input_feature;
3131   kernel_multi_index[dnums.kernel_output_feature_dimension()] = output_feature;
3132 
3133   llvm_ir::IrArray::Index input_index(input_multi_index, lhs->shape(),
3134                                       b_->getInt64Ty());
3135   TF_ASSIGN_OR_RETURN(llvm::Value* const input_value,
3136                       input_generator(input_index));
3137   llvm_ir::IrArray::Index kernel_index(kernel_multi_index, rhs->shape(),
3138                                        b_->getInt64Ty());
3139   TF_ASSIGN_OR_RETURN(llvm::Value* const kernel_value,
3140                       kernel_generator(kernel_index));
3141   llvm::Value* sum = EmitMulAdd(input_value, kernel_value, Load(sum_address),
3142                                 convolution->shape().element_type());
3143   Store(sum, sum_address);
3144 
3145   SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_);
3146   return FPCast(Load(sum_address), lhs_llvm_type);
3147 }
3148 
3149 // Evaluate polynomial using Horner's method.
EvaluatePolynomial(llvm::Type * type,llvm::Value * x,absl::Span<const double> coefficients)3150 StatusOr<llvm::Value*> ElementalIrEmitter::EvaluatePolynomial(
3151     llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients) {
3152   llvm::Value* poly = llvm::ConstantFP::get(type, 0.0);
3153   for (const double c : coefficients) {
3154     poly = FAdd(FMul(poly, x), llvm::ConstantFP::get(type, c));
3155   }
3156   return poly;
3157 }
3158 
3159 }  // namespace xla
3160