/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <random>
#include <string>
#include <vector>

// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

using absl::StrCat;
using llvm_ir::IrArray;
using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;

namespace {

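// Returns a value from a process-wide pseudo-random generator. The generator
// is deterministically seeded and guarded by a mutex, so calls are
// thread-safe and reproducible within a process.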
GlobalRandomValue()62 int64 GlobalRandomValue() {
63 static auto* mu = new tensorflow::mutex();
64 static std::mt19937_64 rng{42};
65 tensorflow::mutex_lock l(*mu);
66 return rng();
67 }
68
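// Emits IR that rounds the floating-point value `x` (of XLA type `src_ty`) to
// the nearest value representable with the given number of destination
// exponent and mantissa bits, using round-to-nearest-even for the mantissa
// and flushing values outside the reduced exponent range to zero or infinity.
// The result stays in the source type's bit width.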
StatusOr<llvm::Value*> EmitReducePrecisionIR(
    PrimitiveType src_ty, llvm::Value* x, int64 dest_exponent_bits,
    int64 dest_mantissa_bits, bool quiet_nans, llvm::IRBuilder<>* b) {
  using llvm::APInt;

  if (!primitive_util::IsFloatingPointType(src_ty)) {
    return Unimplemented(
        "ReducePrecision cannot accept non-floating-point type %s.",
        PrimitiveType_Name(src_ty));
  }

  // Integer and float types for casting and constant generation.
  llvm::Type* float_type = x->getType();
  int64 nbits = float_type->getPrimitiveSizeInBits();
  llvm::IntegerType* int_type = b->getIntNTy(nbits);

  // SignificandWidth includes the implicit extra bit.
  int src_mantissa_bits = primitive_util::SignificandWidth(src_ty) - 1;
  int src_exponent_bits = nbits - 1 - src_mantissa_bits;

  // Cast the input value to an integer for bitwise manipulation.
  llvm::Value* x_as_int = b->CreateBitCast(x, int_type);

  // Clear the sign bit; it does not participate in rounding, and we will
  // restore it later.
  APInt sign_bit_mask(nbits, 1);
  sign_bit_mask <<= nbits - 1;
  llvm::Value* x_abs_bits =
      b->CreateAnd(x_as_int, llvm::ConstantInt::get(int_type, ~sign_bit_mask));

  APInt exp_bits_mask(nbits, 1);
  exp_bits_mask = ((exp_bits_mask << src_exponent_bits) - 1)
                  << src_mantissa_bits;
  auto x_is_nan = b->CreateICmpUGT(
      x_abs_bits, llvm::ConstantInt::get(int_type, exp_bits_mask));

  if (dest_mantissa_bits < src_mantissa_bits) {
    // Last remaining mantissa bit.
    APInt last_mantissa_bit_mask(nbits, 1);
    last_mantissa_bit_mask <<= src_mantissa_bits - dest_mantissa_bits;

    // Compute rounding bias for round-to-nearest with ties to even. This is
    // equal to a base value of 0111... plus one bit if the last remaining
    // mantissa bit is 1.
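    // For example (illustrative), for f32 -> bf16 we have
    // src_mantissa_bits = 23 and dest_mantissa_bits = 7, so
    // last_mantissa_bit_mask = 0x10000 and the bias is either 0x7fff or
    // 0x8000. Adding the bias and then truncating the low 16 bits rounds a
    // halfway value (trailing bits exactly 0x8000) up only when the last
    // kept mantissa bit is 1, i.e. ties go to even.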
    APInt base_rounding_bias = last_mantissa_bit_mask.lshr(1) - 1;
    llvm::Value* x_last_mantissa_bit = b->CreateLShr(
        b->CreateAnd(x_as_int,
                     llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)),
        (src_mantissa_bits - dest_mantissa_bits));
    llvm::Value* x_rounding_bias =
        b->CreateAdd(x_last_mantissa_bit,
                     llvm::ConstantInt::get(int_type, base_rounding_bias));

    // Add rounding bias, and mask out truncated bits. Note that the case
    // where adding the rounding bias overflows into the exponent bits is
    // correct; the non-masked mantissa bits will all be zero, and the
    // exponent will be incremented by one.
    APInt truncation_mask = ~(last_mantissa_bit_mask - 1);
    llvm::Value* x_rounded = b->CreateAdd(x_as_int, x_rounding_bias);
    x_rounded = b->CreateAnd(x_rounded,
                             llvm::ConstantInt::get(int_type, truncation_mask));
    if (quiet_nans) {
      x_as_int = b->CreateSelect(x_is_nan, x_as_int, x_rounded);
    } else {
      x_as_int = x_rounded;
    }
  }

  if (dest_exponent_bits < src_exponent_bits) {
    // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most-
    // significant bit -- is equal to 1.0f for all exponent sizes. Adding
    // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit-
    // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest
    // exponent (corresponding to 0.0f).
    //
    // Thus, the f32 exponent corresponding to the highest non-infinite
    // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32
    // exponent corresponding to the lowest exponent for a bit size of n is
    // (2^7-1) - 2^(n-1)-1.
    //
    // Note that we have already checked that dest_exponent_bits >= 1.
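    // For example (illustrative), reducing f32 (biased exponent range
    // [0, 254] with bias 127) to a 5-bit exponent gives
    // reduced_exponent_bias = 2^4 - 1 = 15, so the representable biased
    // exponents are [127 - 15, 127 + 15] = [112, 142]; anything above 142
    // overflows to infinity, and anything at or below 112 flushes to zero.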
    APInt exponent_bias(nbits, 1);
    exponent_bias = (exponent_bias << (src_exponent_bits - 1)) - 1;

    APInt reduced_exponent_bias(nbits, 1);
    reduced_exponent_bias =
        (reduced_exponent_bias << (dest_exponent_bits - 1)) - 1;

    APInt reduced_max_exponent = exponent_bias + reduced_exponent_bias;
    APInt reduced_min_exponent = exponent_bias - reduced_exponent_bias;

    // Do we overflow or underflow?
    llvm::Value* x_exponent =
        b->CreateAnd(x_as_int, llvm::ConstantInt::get(int_type, exp_bits_mask));
    llvm::Value* x_overflows = b->CreateICmpUGT(
        x_exponent, llvm::ConstantInt::get(
                        int_type, reduced_max_exponent << src_mantissa_bits));
    llvm::Value* x_underflows = b->CreateICmpULE(
        x_exponent, llvm::ConstantInt::get(
                        int_type, reduced_min_exponent << src_mantissa_bits));

    // Compute appropriately-signed values of zero and infinity.
    llvm::Value* x_signed_zero =
        b->CreateAnd(x_as_int, llvm::ConstantInt::get(int_type, sign_bit_mask));
    llvm::Value* x_signed_inf = b->CreateOr(
        x_signed_zero, llvm::ConstantInt::get(int_type, exp_bits_mask));

    // Force to zero or infinity if overflow or underflow. (Note that this
    // truncates all denormal values to zero, rather than rounding them.)
    x_as_int = b->CreateSelect(x_overflows, x_signed_inf, x_as_int);
    x_as_int = b->CreateSelect(x_underflows, x_signed_zero, x_as_int);
  }

  // Cast the result back to a floating-point type.
  llvm::Value* result = b->CreateBitCast(x_as_int, float_type);

  // Correct result for NaN inputs.
  //
  // The exponent handling will "normalize" NaN values to infinities, which is
  // undesirable (except in the case with no mantissa bits, in which case it
  // is mandatory). This logic also handles cases where mantissa-rounding
  // causes a NaN's mantissa to overflow into the exponent bits, which would
  // otherwise create an erroneous zero value.

  if (dest_mantissa_bits > 0) {
    if (quiet_nans) {
      APInt qnan_mask(nbits, 1);
      qnan_mask <<= src_mantissa_bits - 1;
      llvm::Value* x_with_qnan_bit_set =
          b->CreateOr(x_as_int, llvm::ConstantInt::get(int_type, qnan_mask));
      x_with_qnan_bit_set = b->CreateBitCast(x_with_qnan_bit_set, float_type);
      result = b->CreateSelect(x_is_nan, x_with_qnan_bit_set, result);
    } else {
      result = b->CreateSelect(x_is_nan, x, result);
    }
  } else {
    result = b->CreateSelect(x_is_nan,
                             llvm::ConstantFP::getInfinity(float_type), result);
  }

  return result;
}

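// Converts an f32 value to bf16 by rounding to nearest-even and keeping the
// high 16 bits of the f32 representation. For example (illustrative), 1.0f
// (0x3f800000) becomes bf16 0x3f80, and 0x3f800001 rounds down to 0x3f80.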
StatusOr<llvm::Value*> EmitF32ToBF16(llvm::Value* f32_value,
                                     llvm::IRBuilder<>* b) {
  TF_ASSIGN_OR_RETURN(
      auto reduced_precision,
      EmitReducePrecisionIR(
          /*src_ty=*/F32, f32_value,
          /*dest_exponent_bits=*/primitive_util::ExponentWidth(BF16),
          /*dest_mantissa_bits=*/primitive_util::SignificandWidth(BF16) - 1,
          /*quiet_nans=*/true, b));
  auto as_int32 = b->CreateBitCast(reduced_precision, b->getInt32Ty());
  auto shifted = b->CreateLShr(as_int32, 16);
  auto truncated = b->CreateTrunc(shifted, b->getInt16Ty());
  return b->CreateBitCast(truncated, b->getInt16Ty());
}

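// Widens a bf16 value to f32. The conversion is exact: the 16 bf16 bits are
// placed in the high half of the f32 word and the low 16 bits are zero, e.g.
// (illustrative) bf16 0x3f80 (1.0) becomes f32 0x3f800000 (1.0f).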
llvm::Value* EmitBF16ToF32(llvm::Value* bf16_value, llvm::IRBuilder<>* b) {
  auto as_int16 = b->CreateBitCast(bf16_value, b->getInt16Ty());
  auto as_int32 = b->CreateZExt(as_int16, b->getInt32Ty());
  auto shifted = b->CreateShl(as_int32, 16);
  return b->CreateBitCast(shifted, b->getFloatTy());
}

llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value,
                                    PrimitiveType from_type,
                                    PrimitiveType to_type, llvm::Module* module,
                                    llvm::IRBuilder<>* b) {
  if (primitive_util::IsSignedIntegralType(from_type)) {
    return b->CreateSIToFP(integer_value,
                           llvm_ir::PrimitiveTypeToIrType(to_type, module));
  } else {
    CHECK(primitive_util::IsUnsignedIntegralType(from_type) ||
          from_type == PRED);
    return b->CreateUIToFP(integer_value,
                           llvm_ir::PrimitiveTypeToIrType(to_type, module));
  }
}

}  // namespace

StatusOr<llvm::Value*> ElementalIrEmitter::EmitUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) {
  if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) ||
      op->operand(0)->shape().element_type() == PRED) {
    return EmitIntegerUnaryOp(op, operand_value);
  } else if (ShapeUtil::ElementIsComplex(op->operand(0)->shape())) {
    return EmitComplexUnaryOp(op, operand_value);
  } else {
    return EmitFloatUnaryOp(op, operand_value);
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) {
  switch (op->opcode()) {
    case HloOpcode::kConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsIntegralType(from_type) || from_type == PRED)
          << from_type;
      if (from_type == to_type) {
        return operand_value;
      }
      if (to_type == PRED) {
        return b_->CreateZExt(
            ICmpNE(operand_value,
                   llvm::ConstantInt::get(operand_value->getType(), 0)),
            llvm_ir::PrimitiveTypeToIrType(PRED, module_));
      }
      if (primitive_util::IsIntegralType(to_type)) {
        return IntCast(operand_value,
                       llvm_ir::PrimitiveTypeToIrType(to_type, module_),
                       primitive_util::IsSignedIntegralType(from_type));
      }
      if (primitive_util::IsFloatingPointType(to_type)) {
        if (to_type == BF16) {
          return EmitF32ToBF16(EmitIntegralToFloating(operand_value, from_type,
                                                      F32, module_, b_),
                               b_);
        }
        return EmitIntegralToFloating(operand_value, from_type, to_type,
                                      module_, b_);
      }
      if (primitive_util::IsComplexType(to_type)) {
        auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(
            primitive_util::ComplexComponentType(to_type), module_);
        if (primitive_util::IsSignedIntegralType(from_type)) {
          return EmitComposeComplex(
              op, SIToFP(operand_value, to_ir_component_type), nullptr);
        }
        if (primitive_util::IsUnsignedIntegralType(from_type) ||
            from_type == PRED) {
          return EmitComposeComplex(
              op, UIToFP(operand_value, to_ir_component_type), nullptr);
        }
      }
      return Unimplemented("conversion from primitive type %s to %s",
                           PrimitiveType_Name(from_type),
                           PrimitiveType_Name(to_type));
    }
    case HloOpcode::kBitcastConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsIntegralType(from_type));
      if (from_type == to_type) {
        return operand_value;
      }
      if (primitive_util::BitWidth(from_type) ==
          primitive_util::BitWidth(to_type)) {
        return BitCast(operand_value,
                       llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      return InvalidArgument(
          "bitcast conversion from primitive type %s to %s with unequal "
          "bit-widths (%u versus %u) ",
          PrimitiveType_Name(from_type), PrimitiveType_Name(to_type),
          primitive_util::BitWidth(from_type),
          primitive_util::BitWidth(to_type));
    }
    case HloOpcode::kAbs: {
      bool is_signed =
          primitive_util::IsSignedIntegralType(op->shape().element_type());
      if (is_signed) {
        auto type =
            llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
        auto cmp = ICmpSGE(operand_value, GetZero(type));
        return Select(cmp, operand_value, Neg(operand_value));
      } else {
        return operand_value;
      }
    }
    case HloOpcode::kClz: {
      auto is_zero_undef = b_->getFalse();
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ctlz,
                                          {operand_value, is_zero_undef},
                                          {operand_value->getType()}, b_);
    }
    case HloOpcode::kSign: {
      CHECK(primitive_util::IsSignedIntegralType(op->shape().element_type()))
          << op->shape().element_type();
      auto type =
          llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
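      // sign(x) for a signed integer: 0 when x == 0, otherwise
      // (x >> (bitwidth - 1)) | 1. The arithmetic shift yields 0 for
      // positive x and -1 for negative x, and OR-ing in the low bit maps
      // those to +1 and -1 respectively.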
      auto cmp = ICmpEQ(operand_value, GetZero(type));
      auto ashr = AShr(operand_value, type->getIntegerBitWidth() - 1);
      return Select(cmp, GetZero(type), Or(ashr, 1));
    }
    case HloOpcode::kNegate:
      return Neg(operand_value);
    case HloOpcode::kNot: {
      auto type = op->shape().element_type();
      if (type == PRED) {
        // It is not sufficient to just call CreateNot() here because a PRED
        // is represented as an i8 and the truth value is stored only in the
        // bottom bit.
        return b_->CreateZExt(Not(Trunc(operand_value, b_->getInt1Ty())),
                              llvm_ir::PrimitiveTypeToIrType(PRED, module_));
      } else if (primitive_util::IsIntegralType(type)) {
        return Not(operand_value);
      }
      return Unimplemented("unary op Not is not defined for type '%d'", type);
    }
    case HloOpcode::kPopulationCount: {
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ctpop,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    }
    default:
      return Unimplemented("unary integer op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) {
  switch (op->opcode()) {
    case HloOpcode::kConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsFloatingPointType(from_type)) << from_type;
      if (from_type == to_type) {
        return operand_value;
      }
      if (from_type == BF16) {
        TF_RET_CHECK(to_type != BF16);
        operand_value = EmitBF16ToF32(operand_value, b_);
        from_type = F32;
        if (from_type == to_type) {
          return operand_value;
        }
      }
      if (primitive_util::IsComplexType(to_type)) {
        PrimitiveType to_component_type =
            primitive_util::ComplexComponentType(to_type);
        if (from_type == to_component_type) {
          return EmitComposeComplex(op, operand_value, nullptr);
        }
        return EmitComposeComplex(
            op,
            FPCast(operand_value,
                   llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)),
            nullptr);
      }
      if (to_type == BF16) {
        // Cast to F32 first. Other floating point formats are not supported by
        // EmitReducePrecisionIR.
        if (from_type != F32) {
          operand_value = b_->CreateFPCast(
              operand_value, llvm_ir::PrimitiveTypeToIrType(F32, module_));
        }
        return EmitF32ToBF16(operand_value, b_);
      }
      if (to_type == PRED) {
        return b_->CreateZExt(
            FCmpUNE(operand_value,
                    llvm::ConstantFP::get(operand_value->getType(), 0.0)),
            llvm_ir::PrimitiveTypeToIrType(PRED, module_));
      }
      auto* to_ir_type = llvm_ir::PrimitiveTypeToIrType(to_type, module_);
      if (primitive_util::IsFloatingPointType(to_type)) {
        return FPCast(operand_value, to_ir_type);
      }
      auto* from_ir_type = llvm_ir::PrimitiveTypeToIrType(from_type, module_);
      int to_width = primitive_util::BitWidth(to_type);
      if (primitive_util::IsSignedIntegralType(to_type)) {
        int64_t min_int = llvm::minIntN(to_width);
        int64_t max_int = llvm::maxIntN(to_width);
        auto zero_int = llvm::ConstantInt::get(to_ir_type, 0);
        auto min_value_int = llvm::ConstantInt::get(to_ir_type, min_int);
        auto max_value_int = llvm::ConstantInt::get(to_ir_type, max_int);
        auto min_value_float = llvm::ConstantFP::get(from_ir_type, min_int);
        auto max_value_float = llvm::ConstantFP::get(from_ir_type, max_int);
        auto clamped = FPToSI(operand_value,
                              llvm_ir::PrimitiveTypeToIrType(to_type, module_));
        // x <= static_cast<float>(INT_MIN) ? INT_MIN : ...
        clamped = Select(FCmpOLE(operand_value, min_value_float), min_value_int,
                         clamped);
        // x >= static_cast<float>(INT_MAX) ? INT_MAX : ...
        clamped = Select(FCmpOGE(operand_value, max_value_float), max_value_int,
                         clamped);
        // isnan(x) ? 0 : ...
        clamped =
            Select(FCmpUNO(operand_value, operand_value), zero_int, clamped);
        return clamped;
      }
      if (primitive_util::IsUnsignedIntegralType(to_type)) {
        uint64_t min_int = 0;
        uint64_t max_int = llvm::maxUIntN(to_width);
        auto min_value_int = llvm::ConstantInt::get(to_ir_type, min_int);
        auto max_value_int = llvm::ConstantInt::get(to_ir_type, max_int);
        auto min_value_float = llvm::ConstantFP::get(from_ir_type, min_int);
        auto max_value_float = llvm::ConstantFP::get(from_ir_type, max_int);
        auto clamped = FPToUI(operand_value,
                              llvm_ir::PrimitiveTypeToIrType(to_type, module_));
        // (x <= 0.0 || isnan(x)) ? 0 : ...
        clamped = Select(FCmpULE(operand_value, min_value_float), min_value_int,
                         clamped);
        // x >= static_cast<float>(UINT_MAX) ? UINT_MAX : ...
        clamped = Select(FCmpOGE(operand_value, max_value_float), max_value_int,
                         clamped);
        return clamped;
      }
      return Unimplemented("unhandled conversion operation: %s => %s",
                           PrimitiveType_Name(from_type),
                           PrimitiveType_Name(to_type));
    }
    case HloOpcode::kBitcastConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsFloatingPointType(from_type));
      if (from_type == to_type) {
        return operand_value;
      }
      if (primitive_util::BitWidth(from_type) ==
          primitive_util::BitWidth(to_type)) {
        return BitCast(operand_value,
                       llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      return InvalidArgument(
          "bitcast conversion from primitive type %s to %s with unequal "
          "bit-widths (%u versus %u) ",
          PrimitiveType_Name(from_type), PrimitiveType_Name(to_type),
          primitive_util::BitWidth(from_type),
          primitive_util::BitWidth(to_type));
    }
    case HloOpcode::kExp:
      return EmitExp(op->shape().element_type(), operand_value, "");
    case HloOpcode::kExpm1:
      return EmitExpm1(op->shape().element_type(), operand_value);
    case HloOpcode::kLog:
      return EmitLog(op->shape().element_type(), operand_value);
    case HloOpcode::kLog1p:
      return EmitLog1p(op->shape().element_type(), operand_value);
    case HloOpcode::kCos:
      return EmitCos(op->shape().element_type(), operand_value);
    case HloOpcode::kSin:
      return EmitSin(op->shape().element_type(), operand_value);
    case HloOpcode::kTanh:
      return EmitTanh(op->shape().element_type(), operand_value);
    case HloOpcode::kSqrt:
      return EmitSqrt(op->shape().element_type(), operand_value);
    case HloOpcode::kRsqrt:
      return EmitRsqrt(op->shape().element_type(), operand_value);
    case HloOpcode::kCbrt:
      return EmitCbrt(op->shape().element_type(), operand_value);
    case HloOpcode::kFloor:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    case HloOpcode::kCeil:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ceil,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    case HloOpcode::kAbs:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    case HloOpcode::kRoundNearestAfz:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::round,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    case HloOpcode::kSign: {
      auto type = operand_value->getType();
      auto zero = llvm::ConstantFP::get(type, 0.0);
      auto ne0_i1 = FCmpONE(operand_value, zero);
      auto ne0_float = UIToFP(ne0_i1, type);
      llvm::Value* result = llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::copysign, {ne0_float, operand_value},
          {operand_value->getType()}, b_);
      auto is_nan = FCmpUNO(operand_value, operand_value);
      result = Select(is_nan, operand_value, result);
      return result;
    }
    case HloOpcode::kIsFinite: {
      // abs(x) o!= inf, i.e. an ordered not-equal comparison. This works
      // because an ordered comparison returns false if either operand is NaN,
      // so NaN inputs are correctly reported as not finite.
      auto type = operand_value->getType();
      auto abs_value = llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::fabs, {operand_value}, {type}, b_);
      auto infinity = llvm::ConstantFP::getInfinity(type);
      auto not_infinite = FCmpONE(abs_value, infinity);
      return b_->CreateZExt(not_infinite,
                            llvm_ir::PrimitiveTypeToIrType(PRED, module_));
    }
    case HloOpcode::kNegate:
      return FNeg(operand_value);
    case HloOpcode::kReal:
      return operand_value;
    case HloOpcode::kImag:
      return llvm::ConstantFP::get(operand_value->getType(), 0.0);
    default:
      return Unimplemented("unary floating-point op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) {
  PrimitiveType input_type = op->operand(0)->shape().element_type();
  PrimitiveType component_type =
      primitive_util::IsComplexType(input_type)
          ? primitive_util::ComplexComponentType(input_type)
          : input_type;
  switch (op->opcode()) {
    case HloOpcode::kLog: {
      return EmitComplexLog(op, operand_value);
    }
    case HloOpcode::kLog1p: {
      // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
      // log((a+1)+bi) = .5*log(a*a + 2*a + 1 + b*b) + i*atan2(b, a+1)
      // log((a+1)+bi) = .5*log1p(a*a + 2*a + b*b) + i*atan2(b, a+1)
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      llvm::Type* llvm_ty = a->getType();
      auto one = llvm::ConstantFP::get(llvm_ty, 1.0);
      auto two = llvm::ConstantFP::get(llvm_ty, 2.0);
      auto a_plus_one = FAdd(a, one);
      auto sum_sq = FAdd(FAdd(FMul(a, a), FMul(two, a)), FMul(b, b));
      TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog1p(component_type, sum_sq));
      TF_ASSIGN_OR_RETURN(auto angle,
                          EmitAtan2(component_type, b, a_plus_one, ""));
      auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
      return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle);
    }
    case HloOpcode::kConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      TF_RET_CHECK(primitive_util::IsComplexType(from_type));
      PrimitiveType to_type = op->shape().element_type();
      TF_RET_CHECK(primitive_util::IsComplexType(to_type));
      if (from_type == to_type) {
        return operand_value;
      }
      PrimitiveType to_component_type =
          primitive_util::ComplexComponentType(to_type);
      auto to_ir_component_type =
          llvm_ir::PrimitiveTypeToIrType(to_component_type, module_);
      return EmitComposeComplex(
          op, FPCast(EmitExtractReal(operand_value), to_ir_component_type),
          FPCast(EmitExtractImag(operand_value), to_ir_component_type));
    }
    case HloOpcode::kExp: {
      // e^(a+bi) = e^a*(cos(b)+sin(b)i)
      TF_ASSIGN_OR_RETURN(
          auto exp_a,
          EmitExp(component_type, EmitExtractReal(operand_value), ""));
      TF_ASSIGN_OR_RETURN(
          auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
      TF_ASSIGN_OR_RETURN(
          auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
      return EmitComposeComplex(op, FMul(exp_a, cos_b), FMul(exp_a, sin_b));
    }
    case HloOpcode::kExpm1: {
      // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
      TF_ASSIGN_OR_RETURN(
          auto exp_a,
          EmitExp(component_type, EmitExtractReal(operand_value), ""));
      TF_ASSIGN_OR_RETURN(
          auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
      TF_ASSIGN_OR_RETURN(
          auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
      auto one = llvm::ConstantFP::get(exp_a->getType(), 1.0);
      auto real_result = FSub(FMul(exp_a, cos_b), one);
      auto imag_result = FMul(exp_a, sin_b);
      return EmitComposeComplex(op, real_result, imag_result);
    }
    case HloOpcode::kCos: {
      // cos(z) = .5(e^(iz) + e^(-iz))
      // cos(a+bi) = .5(e^(-b+ai) + e^(b-ai))
      // now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
      // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(-a)+sin(-a)i))
      // cos(-x) = cos(x) and sin(-x) = -sin(x), so
      // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(a)-sin(a)i))
      //           = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b))
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      auto type = a->getType();
      TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b, ""));
      auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b);
      auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
      TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
      TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
      return EmitComposeComplex(op,
                                FMul(cos_a, FAdd(half_exp_neg_b, half_exp_b)),
                                FMul(sin_a, FSub(half_exp_neg_b, half_exp_b)));
    }
    case HloOpcode::kSin: {
      // sin(z) = .5i(e^(-iz) - e^(iz))
      // sin(a+bi) = .5i(e^(-i(a+bi)) - e^(i(a+bi)))
      //           = .5i(e^(b-ai) - e^(-b+ai))
      // now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
      // sin(a+bi) = 0.5i(e^b*(cos(-a)+sin(-a)i) - e^-b*(cos(a)+sin(a)i))
      //           = 0.5(e^b*(cos(-a)i-sin(-a)) - e^-b*(cos(a)i-sin(a)))
      // cos(-x) = cos(x) and sin(-x) = -sin(x), so
      //           = 0.5(e^b*(cos(a)i+sin(a)) - e^-b*(cos(a)i-sin(a)))
      //           = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b))
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      auto type = a->getType();
      TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b, ""));
      auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b);
      auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
      TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
      TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
      return EmitComposeComplex(op,
                                FMul(sin_a, FAdd(half_exp_b, half_exp_neg_b)),
                                FMul(cos_a, FSub(half_exp_b, half_exp_neg_b)));
    }
    case HloOpcode::kTanh: {
      /*
      tanh=(exp(x)-exp(-x)) / (exp(x)+exp(-x))
      e^(a+bi) = e^a*(cos(b)+sin(b)i)
      so tanh=(((cos(b)+sin(b)i)e^a - (cos(-b)+sin(-b)i)e^-a)) /
              (((cos(b)+sin(b)i)e^a + (cos(-b)+sin(-b)i)e^-a))
      cos(b)=cos(-b), sin(-b)=-sin(b)
      so tanh=(((cos(b)+sin(b)i)e^a - (cos(b)-sin(b)i)e^-a)) /
              (((cos(b)+sin(b)i)e^a + (cos(b)-sin(b)i)e^-a))
             =(cos(b)e^a+i*sin(b)e^a + cos(b)(-e^-a)+i*sin(b)e^-a) /
              (cos(b)e^a+i*sin(b)e^a + cos(b)e^-a+i*sin(b)(-e^-a))
             =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) /
              (cos(b)(e^a+e^-a) + i*sin(b)(e^a-e^-a))
      This is a complex division, so we can multiply by denom_conj/denom_conj
             =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) *
              (cos(b)(e^a+e^-a) - i*sin(b)(e^a-e^-a)) /
              ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
             =(cos(b)^2(e^(2a)-e^(-2a)) + sin(b)^2(e^(2a)-e^(-2a)) +
               i*(cos(b)sin(b)(e^a+e^-a)^2 - cos(b)sin(b)(e^a-e^-a)^2)) /
              ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
             =(e^(2a)-e^(-2a) +
               i*[cos(b)sin(b)(e^(2a)+2+e^(-2a)) - cos(b)sin(b)(e^(2a)-2+e^(-2a))]) /
              (cos(b)^2*(e^(2a)+2+e^(-2a)) + sin(b)^2*(e^(2a)-2+e^(-2a)))
             =(e^(2a)-e^(-2a) +
               i*cos(b)sin(b)*[e^(2a)+2+e^(-2a)-e^(2a)+2-e^(-2a)]) /
              ([cos(b)^2 + sin(b)^2]*[e^(2a)+e^(-2a)] + 2*[cos(b)^2 - sin(b)^2])
             =(e^(2a)-e^(-2a) + i*cos(b)sin(b)*4) /
              (e^(2a)+e^(-2a)+2*[cos(b)^2 - sin(b)^2])
             =(e^(2a)-e^(-2a) + i*[sin(2b)/2]*4) /
              (e^(2a)+e^(-2a)+2*[cos(2b)])
             =(e^(2a)-e^(-2a) + i*2*sin(2b)) / (e^(2a) + e^(-2a) + 2*cos(2b))
      */
      llvm::Value* a = EmitExtractReal(operand_value);
      llvm::Value* b = EmitExtractImag(operand_value);

      llvm::Type* type = a->getType();

      llvm::Value* neg_one = llvm::ConstantFP::get(type, -1.F);
      llvm::Value* two_a = FAdd(a, a);
      llvm::Value* neg_2a = FMul(neg_one, two_a);

      // When we are calculating the real numerator, e^(2a)-e^(-2a), for small
      // values of `a`, we will get a ULP of 2^-23 using the exp function. Using
      // expm1 to calculate e^(2a)-e^(-2a) = [e^(2a)-1] - [e^(-2a)-1] allows our
      // ULP to be arbitrarily small. For larger values of `a`, calculating the
      // numerator as Exp(2a)-Exp(-2a) vs Expm1(2a)-Expm1(-2a) returns virtually
      // identical results.
      TF_ASSIGN_OR_RETURN(llvm::Value * exp_2a_m1,
                          EmitExpm1(component_type, two_a));
      TF_ASSIGN_OR_RETURN(llvm::Value * exp_neg_2a_m1,
                          EmitExpm1(component_type, neg_2a));
      llvm::Value* real_numerator = FSub(exp_2a_m1, exp_neg_2a_m1);

      // We can use the identity cos(2b)+1 = cos(b)^2-sin(b)^2+cos(b)^2+sin(b)^2
      // = 2cos(b)^2. This gives us the ability to be more precise when the
      // denominator is close to zero.
      TF_ASSIGN_OR_RETURN(llvm::Value * cos_b, EmitCos(component_type, b));
      llvm::Value* four = llvm::ConstantFP::get(type, 4.F);
      llvm::Value* cos_b_sq = FMul(cos_b, cos_b);
      llvm::Value* two_cos_2b_p2 = FMul(cos_b_sq, four);

      // Similarly we can compute sin(2b) with the formula sin(2b) =
      // 2*sin(b)*cos(b).
      TF_ASSIGN_OR_RETURN(llvm::Value * sin_b, EmitSin(component_type, b));
      llvm::Value* imag_numerator = FMul(four, FMul(cos_b, sin_b));

      // Expm1(x) is about x for small values of x, but exp_sum_m2 is about x^2
      // for small values of x. As a result, due to floating point precision
      // issues, x^2 is a better approximation than Expm1(x) + Expm1(-x) for
      // small values of x.
      llvm::Value* a_sqr = FMul(a, a);
      llvm::Value* use_approx_cutoff = llvm::ConstantFP::get(type, 1e-8);
      llvm::Value* use_approx = FCmpOLT(a_sqr, use_approx_cutoff);

      llvm::Value* exp_sum_m2 =
          Select(use_approx, a_sqr, FAdd(exp_2a_m1, exp_neg_2a_m1));
      llvm::Value* denom = FAdd(exp_sum_m2, two_cos_2b_p2);

      // As `a` grows toward +inf and -inf, the real numerator will grow towards
      // +inf and -inf respectively, while the denominator will always grow
      // towards +inf. The result is real_numerator/denom = NaN, when it should
      // equal +1 and -1 respectively. Therefore, if our denominator is +inf,
      // we just hardcode the limits for the real numbers.
      llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
      llvm::Value* is_inf = FCmpOEQ(exp_sum_m2, inf);
      llvm::Value* real_limit = llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::copysign, {neg_one, a}, {type}, b_);

      llvm::Value* real =
          Select(is_inf, real_limit, FDiv(real_numerator, denom));
      llvm::Value* imag = FDiv(imag_numerator, denom);

      // The complex tanh functions have a few corner cases:
      // 1. (+0, +0) => (+0, +0) - Handled normally
      // 2. (x, +Inf) => (NaN, NaN) - See below
      // 3. (x, NaN) => (NaN, NaN) - See below
      // 4. (+inf, y) => (1, +0) - Handled normally
      // 5. (+Inf, +Inf) => (1, +/-0) - See below
      // 6. (+Inf, NaN) => (1, +/-0) - See below
      // 7. (NaN, +0) => (NaN, +0) - See below
      // 8. (NaN, y) => (NaN, NaN) - Handled normally
      // 9. (NaN, NaN) => (NaN, NaN) - Handled normally
      //
      // For the cases that aren't handled normally:
      // 2/3) Part of the calculation we do is that if exp(a) + exp(-a) = +inf,
      //      then we return (+/-1, +/-0). However, this is only true if we
      //      assume that a is infinity or b is finite. In the event that both a
      //      is finite and b is either +/-Inf or NaN, then our normal
      //      calculation would end up returning (+/-1, NaN), as opposed to
      //      (NaN, NaN).
      // 5/6) We always calculate the imaginary value as sin(2b)/denominator.
      //      When the denominator is infinity, this assures us that the zero is
      //      the correct sign. However if our imaginary input results in
      //      sin(2b) = NaN, we calculate our imaginary result as NaN.
      // 7)   In the event that a is NaN, the denominator will be NaN.
      //      Therefore, the normal calculation gives (NaN, NaN) while we need
      //      (NaN, +0).
      if (!(b_->getFastMathFlags().noNaNs() &&
            b_->getFastMathFlags().noInfs())) {
        llvm::Value* abs_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
                                                          {a}, {type}, b_);
        llvm::Value* zero = llvm::ConstantFP::get(type, 0.F);
        llvm::Value* nan = llvm::ConstantFP::getNaN(type);

        llvm::Value* a_is_inf = FCmpOEQ(abs_a, inf);
        llvm::Value* b_is_zero = FCmpOEQ(b, zero);

        // imag_numerator = 2sin(2b), so sin(2b) is NaN if and only if
        // imag_numerator is NaN.
        llvm::Value* sin_2b_is_nan =
            b_->CreateFCmpUNO(imag_numerator, imag_numerator);

        llvm::Value* real_is_nan =
            b_->CreateAnd(sin_2b_is_nan, b_->CreateNot(a_is_inf));
        llvm::Value* imag_is_zero =
            b_->CreateOr(b_is_zero, b_->CreateAnd(a_is_inf, sin_2b_is_nan));

        real = Select(real_is_nan, nan, real);
        imag = Select(imag_is_zero, zero, imag);
      }

      return EmitComposeComplex(op, real, imag);
    }
    case HloOpcode::kAbs: {
      return EmitComplexAbs(component_type, operand_value);
    }
    case HloOpcode::kSign: {  // Sign(c) = c / |c|
      TF_ASSIGN_OR_RETURN(auto cplx_abs,
                          EmitComplexAbs(component_type, operand_value));
      auto type = cplx_abs->getType();
      auto zero = llvm::ConstantFP::get(type, 0.0);
      auto oeq = FCmpOEQ(cplx_abs, zero);
      return Select(
          oeq, EmitComposeComplex(op, zero, zero),
          EmitComposeComplex(op, FDiv(EmitExtractReal(operand_value), cplx_abs),
                             FDiv(EmitExtractImag(operand_value), cplx_abs)));
    }
    case HloOpcode::kSqrt: {
      return EmitComplexSqrt(op, component_type, operand_value);
    }
    case HloOpcode::kRsqrt: {
      return EmitComplexRsqrt(op, component_type, operand_value);
    }
    case HloOpcode::kCbrt: {
      return EmitComplexCbrt(op, component_type, operand_value);
    }
    case HloOpcode::kNegate:
      return EmitComposeComplex(op, FNeg(EmitExtractReal(operand_value)),
                                FNeg(EmitExtractImag(operand_value)));
    case HloOpcode::kReal:
      return EmitExtractReal(operand_value);
    case HloOpcode::kImag:
      return EmitExtractImag(operand_value);
    default:
      return Unimplemented("unary complex op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  PrimitiveType operand_type = op->operand(0)->shape().element_type();
  if (operand_type == PRED) {
    return EmitPredBinaryOp(op, lhs_value, rhs_value);
  } else if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape())) {
    return EmitIntegerBinaryOp(
        op, lhs_value, rhs_value,
        primitive_util::IsSignedIntegralType(operand_type));
  } else if (primitive_util::IsComplexType(operand_type)) {
    return EmitComplexBinaryOp(op, lhs_value, rhs_value);
  } else {
    return EmitFloatBinaryOp(op, lhs_value, rhs_value);
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  switch (op->opcode()) {
    case HloOpcode::kComplex:
      return EmitComposeComplex(op, lhs_value, rhs_value);
    case HloOpcode::kAdd:
      return FAdd(lhs_value, rhs_value, op->name());
    case HloOpcode::kSubtract:
      return FSub(lhs_value, rhs_value, op->name());
    case HloOpcode::kMultiply:
      return FMul(lhs_value, rhs_value, op->name());
    case HloOpcode::kDivide:
      return FDiv(lhs_value, rhs_value, op->name());
    case HloOpcode::kRemainder:
      return FRem(lhs_value, rhs_value, op->name());
    // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
    // comparisons always return false when one of the operands is NaN, whereas
    // unordered comparisons return true.
    //
    // We use ordered comparisons for everything except kNe, where we use an
    // unordered comparison. This makes x != y equivalent to !(x == y), and
    // matches C++'s semantics.
    case HloOpcode::kCompare: {
      switch (op->comparison_direction()) {
        case ComparisonDirection::kEq:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value,
                                         rhs_value, b_, op->name());
        case ComparisonDirection::kNe:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value,
                                         rhs_value, b_, op->name());
        case ComparisonDirection::kLt:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value,
                                         rhs_value, b_, op->name());
        case ComparisonDirection::kGt:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value,
                                         rhs_value, b_, op->name());
        case ComparisonDirection::kLe:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value,
                                         rhs_value, b_, op->name());
        case ComparisonDirection::kGe:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value,
                                         rhs_value, b_, op->name());
      }
    }
    case HloOpcode::kMaximum:
      return EmitFloatMax(lhs_value, rhs_value, op->name());
    case HloOpcode::kMinimum:
      return EmitFloatMin(lhs_value, rhs_value, op->name());
    case HloOpcode::kPower:
      return EmitPow(op->shape().element_type(), lhs_value, rhs_value,
                     op->name());
    case HloOpcode::kAtan2:
      return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value,
                       op->name());
    default:
      return Unimplemented("binary floating point op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}

// Computing sqrt(a^2 + b^2) directly can overflow even when the result is
// representable. Assuming |a| >= |b|, we instead use
//   sqrt(a^2 + b^2) = sqrt(a^2 * (1 + b^2/a^2))
//                   = |a| * sqrt(1 + (b/a)^2)
//
// This method returns the min, max, and sqrt term for this calculation. This is
// done to prevent potential overflow errors that can occur from multiplying the
// max with the sqrt term. (i.e. when calculating the sqrt of the absolute
// value, we can take the sqrt of the max and the sqrt term before multiplying
// them together.) If return_sqrt is false, it returns 1 + (b/a)^2 instead of
// sqrt(1 + (b/a)^2).
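// For example (illustrative), |3 + 4i|: max = 4, min = 3, and the sqrt term
// is sqrt(1 + (3/4)^2) = 1.25, so the magnitude is 4 * 1.25 = 5 without ever
// squaring the original components.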
StatusOr<std::tuple<llvm::Value*, llvm::Value*, llvm::Value*>>
ElementalIrEmitter::EmitComplexAbsHelper(PrimitiveType prim_type,
                                         llvm::Value* operand_value,
                                         bool return_sqrt) {
  llvm::Value* real = EmitExtractReal(operand_value);
  llvm::Value* imag = EmitExtractImag(operand_value);
  llvm::Value* abs_real = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::fabs, {real}, {real->getType()}, b_);
  llvm::Value* abs_imag = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::fabs, {imag}, {imag->getType()}, b_);
  llvm::Value* max = EmitFloatMax(abs_real, abs_imag, "");
  llvm::Value* min = EmitFloatMin(abs_real, abs_imag, "");

  llvm::Value* div = FDiv(min, max);
  llvm::Value* div_sq = FMul(div, div);
  llvm::Value* one = llvm::ConstantFP::get(max->getType(), 1);
  llvm::Value* one_p_div_sq = FAdd(one, div_sq);
  TF_ASSIGN_OR_RETURN(llvm::Value * sqrt, EmitSqrt(prim_type, one_p_div_sq));
  return std::make_tuple(min, max, return_sqrt ? sqrt : one_p_div_sq);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexAbs(
    PrimitiveType prim_type, llvm::Value* operand_value) {
  llvm::Value* min;
  llvm::Value* max;
  llvm::Value* sqrt;
  TF_ASSIGN_OR_RETURN(
      std::tie(min, max, sqrt),
      EmitComplexAbsHelper(prim_type, operand_value, /*return_sqrt=*/true));
  llvm::Value* result = FMul(max, sqrt);
  // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN.
  // In such cases, we return `min` instead of `result`.
  return Select(FCmpUNO(result, result), min, result);
}

// Calculates ComplexAbs in the same way, except using:
// sqrt(|a| * sqrt(1 + (b/a)^2)) = sqrt(|a|) * pow(1 + (b/a)^2, .25)
StatusOr<llvm::Value*> ElementalIrEmitter::EmitSqrtComplexAbs(
    PrimitiveType prim_type, llvm::Value* operand_value) {
  llvm::Value* min;
  llvm::Value* max;
  llvm::Value* one_p_div_sq;
  TF_ASSIGN_OR_RETURN(
      std::tie(min, max, one_p_div_sq),
      EmitComplexAbsHelper(prim_type, operand_value, /*return_sqrt=*/false));
  TF_ASSIGN_OR_RETURN(llvm::Value * sqrt_max, EmitSqrt(prim_type, max));
  TF_ASSIGN_OR_RETURN(llvm::Value * pow,
                      EmitPow(prim_type, one_p_div_sq,
                              llvm::ConstantFP::get(max->getType(), .25), ""));
  llvm::Value* result = FMul(sqrt_max, pow);
  // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN.
  // In such cases, we return `min` instead of `result`.
  return Select(FCmpUNO(result, result), min, result);
}

// Calculates ComplexAbs in the same way, except using:
// rsqrt(|a| * sqrt(1 + (b/a)^2)) = rsqrt(|a|) * rsqrt(sqrt(1 + (b/a)^2))
StatusOr<llvm::Value*> ElementalIrEmitter::EmitRsqrtComplexAbs(
    PrimitiveType prim_type, llvm::Value* operand_value) {
  llvm::Value* min;
  llvm::Value* max;
  llvm::Value* sqrt;
  TF_ASSIGN_OR_RETURN(
      std::tie(min, max, sqrt),
      EmitComplexAbsHelper(prim_type, operand_value, /*return_sqrt=*/true));
  TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_max, EmitRsqrt(prim_type, max));
  TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_sqrt, EmitRsqrt(prim_type, sqrt));
  llvm::Value* result = FMul(rsqrt_max, rsqrt_sqrt);
  TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_min, EmitRsqrt(prim_type, min));
  // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN.
  // In such cases, we return rsqrt(min) instead of `result`.
  return Select(FCmpUNO(result, result), rsqrt_min, result);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexAdd(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  return EmitComposeComplex(
      op, FAdd(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
      FAdd(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexSubtract(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  return EmitComposeComplex(
      op, FSub(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
      FSub(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexMultiply(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  return EmitComposeComplex(
      op,
      FSub(FMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
           FMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))),
      FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)),
           FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value))));
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexDivide(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  // Division of complex numbers is implemented here, taking into account
  // over/underflow, NaN and Inf values.
  auto a_r = EmitExtractReal(lhs_value);
  auto a_i = EmitExtractImag(lhs_value);
  auto b_r = EmitExtractReal(rhs_value);
  auto b_i = EmitExtractImag(rhs_value);
  auto type = a_r->getType();

  // Smith's algorithm to divide complex numbers. It is just a bit smarter
  // way to compute the following formula:
  //  (a_r + a_i * i) / (b_r + b_i * i)
  //    = (a_r + a_i * i) (b_r - b_i * i) / ((b_r + b_i * i)(b_r - b_i * i))
  //    = ((a_r * b_r + a_i * b_i) + (a_i * b_r - a_r * b_i) * i) / ||b||^2
  //
  // Depending on whether |b_r| < |b_i| we compute either
  //   b_r_b_i_ratio = b_r / b_i
  //   b_r_b_i_denom = b_i + b_r * b_r_b_i_ratio
  //   c_r = (a_r * b_r_b_i_ratio + a_i ) / b_r_b_i_denom
  //   c_i = (a_i * b_r_b_i_ratio - a_r ) / b_r_b_i_denom
  //
  // or
  //
  //   b_i_b_r_ratio = b_i / b_r
  //   b_i_b_r_denom = b_r + b_i * b_i_b_r_ratio
  //   c_r = (a_r + a_i * b_i_b_r_ratio ) / b_i_b_r_denom
  //   c_i = (a_i - a_r * b_i_b_r_ratio ) / b_i_b_r_denom
  //
  // See https://dl.acm.org/citation.cfm?id=368661 for more details.
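  // The ratio trick matters near the extremes of the type's range: e.g.
  // (illustrative) for (1e300 + 1e300i) / (2e300 + 2e300i) in f64, the naive
  // ||b||^2 = b_r^2 + b_i^2 overflows to +inf, whereas the ratios above stay
  // near 1 and the division yields the correct 0.5 + 0i.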
  auto b_r_b_i_ratio = FDiv(b_r, b_i);
  auto b_r_b_i_denom = FAdd(b_i, FMul(b_r_b_i_ratio, b_r));
  auto b_i_b_r_ratio = FDiv(b_i, b_r);
  auto b_i_b_r_denom = FAdd(b_r, FMul(b_i_b_r_ratio, b_i));

  auto b_r_abs =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {b_r}, {type}, b_);
  auto b_i_abs =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {b_i}, {type}, b_);
  auto b_r_lt_b_i = FCmpOLT(b_r_abs, b_i_abs);
  auto c_r = Select(b_r_lt_b_i,
                    FDiv(FAdd(FMul(b_r_b_i_ratio, a_r), a_i), b_r_b_i_denom),
                    FDiv(FAdd(FMul(b_i_b_r_ratio, a_i), a_r), b_i_b_r_denom));
  auto c_i = Select(b_r_lt_b_i,
                    FDiv(FSub(FMul(b_r_b_i_ratio, a_i), a_r), b_r_b_i_denom),
                    FDiv(FSub(a_i, FMul(b_i_b_r_ratio, a_r)), b_i_b_r_denom));
  auto result = EmitComposeComplex(op, c_r, c_i);

  // Consider corner cases, if the result is (NaN, NaN).
  auto zero = llvm::ConstantFP::get(type, 0.0);
  auto one = llvm::ConstantFP::get(type, 1.0);
  auto inf = llvm::ConstantFP::getInfinity(type);

  // Case 1. Zero denominator.
  auto zero_denominator =
      And(And(FCmpOEQ(b_r_abs, zero), FCmpOEQ(b_i_abs, zero)),
          Or(Not(FCmpUNO(a_r, zero)), Not(FCmpUNO(a_i, zero))));
  auto inf_with_sign_of_b_r = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::copysign, {inf, b_r}, {type}, b_);
  auto zero_denominator_result = EmitComposeComplex(
      op, FMul(inf_with_sign_of_b_r, a_r), FMul(inf_with_sign_of_b_r, a_i));

  // Case 2. Infinite numerator, finite denominator.
  auto b_r_finite = FCmpONE(b_r_abs, inf);
  auto b_i_finite = FCmpONE(b_i_abs, inf);
  auto a_r_abs =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {a_r}, {type}, b_);
  auto a_i_abs =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {a_i}, {type}, b_);
  auto a_r_infinite = FCmpOEQ(a_r_abs, inf);
  auto a_i_infinite = FCmpOEQ(a_i_abs, inf);
  auto inf_num_finite_denom =
      And(Or(a_r_infinite, a_i_infinite), And(b_r_finite, b_i_finite));

  auto a_r_inf_with_sign = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::copysign, {Select(a_r_infinite, one, zero), a_r}, {type},
      b_);
  auto a_i_inf_with_sign = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::copysign, {Select(a_i_infinite, one, zero), a_i}, {type},
      b_);
  auto inf_num_finite_denom_result = EmitComposeComplex(
      op,
      FMul(inf,
           FAdd(FMul(a_r_inf_with_sign, b_r), FMul(a_i_inf_with_sign, b_i))),
      FMul(inf,
           FSub(FMul(a_i_inf_with_sign, b_r), FMul(a_r_inf_with_sign, b_i))));

  // Case 3. Finite numerator, infinite denominator.
  auto a_r_finite = FCmpONE(a_r_abs, inf);
  auto a_i_finite = FCmpONE(a_i_abs, inf);
  auto b_r_infinite = FCmpOEQ(b_r_abs, inf);
  auto b_i_infinite = FCmpOEQ(b_i_abs, inf);
  auto finite_num_inf_denom =
      And(Or(b_r_infinite, b_i_infinite), And(a_r_finite, a_i_finite));

  auto b_r_inf_with_sign = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::copysign, {Select(b_r_infinite, one, zero), b_r}, {type},
      b_);
  auto b_i_inf_with_sign = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::copysign, {Select(b_i_infinite, one, zero), b_i}, {type},
      b_);
  auto finite_num_inf_denom_result = EmitComposeComplex(
      op,
      FMul(zero,
           FAdd(FMul(a_r, b_r_inf_with_sign), FMul(a_i, b_i_inf_with_sign))),
      FMul(zero,
           FSub(FMul(a_i, b_r_inf_with_sign), FMul(a_r, b_i_inf_with_sign))));

  auto c_nan = And(FCmpUNO(c_r, zero), FCmpUNO(c_i, zero));
  return Select(c_nan,
                Select(zero_denominator, zero_denominator_result,
                       Select(inf_num_finite_denom, inf_num_finite_denom_result,
                              Select(finite_num_inf_denom,
                                     finite_num_inf_denom_result, result))),
                result);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexLog(
    const HloInstruction* op, llvm::Value* operand_value) {
  // log(a+bi) = log(abs(a+bi)) + i*atan2(b,a)
  PrimitiveType component_type =
      primitive_util::ComplexComponentType(op->shape().element_type());
  auto a = EmitExtractReal(operand_value);
  auto b = EmitExtractImag(operand_value);
  TF_ASSIGN_OR_RETURN(llvm::Value * angle, EmitAtan2(component_type, b, a, ""));
  TF_ASSIGN_OR_RETURN(llvm::Value * abs,
                      EmitComplexAbs(component_type, operand_value));
  TF_ASSIGN_OR_RETURN(llvm::Value * log_abs, EmitLog(component_type, abs));
  return EmitComposeComplex(op, log_abs, angle);
}

// Using our EmitComplexPower formula, but setting c=0.5 and d=0, we get:
// e^[ln(r)*c - t*d] * [cos(ln(r)*d + t*c) + i*sin(ln(r)*d + t*c)]
// = e^[ln(r)*0.5] * [cos(t*0.5) + i*sin(t*0.5)]
// = r^0.5 * [cos(t/2) + i*sin(t/2)]
// = sqrt(r) * [cos(t/2) + i*sin(t/2)]
// where r = |a+bi| and t = atan2(b,a)
// TODO(bixia): See doc for implementation without atan2.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexSqrt(
    const HloInstruction* op, PrimitiveType prim_type,
    llvm::Value* operand_value) {
  llvm::Type* type = static_cast<llvm::StructType*>(operand_value->getType())
                         ->getElementType(0);

  TF_ASSIGN_OR_RETURN(llvm::Value * r,
                      EmitSqrtComplexAbs(prim_type, operand_value));

  llvm::Value* a = EmitExtractReal(operand_value);
  llvm::Value* b = EmitExtractImag(operand_value);
  TF_ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a, ""));

  llvm::Value* c = llvm::ConstantFP::get(type, 0.5);
  llvm::Value* angle = FMul(t, c);
  TF_ASSIGN_OR_RETURN(llvm::Value * cos, EmitCos(prim_type, angle));
  TF_ASSIGN_OR_RETURN(llvm::Value * sin, EmitSin(prim_type, angle));

  llvm::Value* real_part;
  llvm::Value* imag_part;

  llvm::Value* zero = llvm::ConstantFP::get(type, 0);

  if (!(b_->getFastMathFlags().noNaNs() && b_->getFastMathFlags().noInfs())) {
    llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
    llvm::Value* neg_inf = llvm::ConstantFP::getInfinity(type, true);
    llvm::Value* nan = llvm::ConstantFP::getNaN(type);
    llvm::Value* abs_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
                                                      {b}, {b->getType()}, b_);

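    // Corner cases when NaNs and infinities are possible (readable off the
    // selects below): if |b| is inf or a is +inf, the real part is +inf; if
    // a is -inf with finite b, the real part is +0. The imaginary part
    // carries b's sign and becomes inf when |b| is inf or a is -inf, and NaN
    // when r is NaN.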
    real_part = Select(Or(FCmpOEQ(abs_b, inf), FCmpOEQ(a, inf)), inf,
                       Select(And(FCmpOEQ(a, neg_inf), FCmpONE(abs_b, inf)),
                              zero, FMul(r, cos)));

    llvm::Value* b_signed_inf = llvm_ir::EmitCallToIntrinsic(
        llvm::Intrinsic::copysign, {inf, b}, {b->getType()}, b_);
    imag_part =
        Select(Or(FCmpOEQ(abs_b, inf), FCmpOEQ(a, neg_inf)), b_signed_inf,
               Select(FCmpUNO(r, r), nan,
                      Select(FCmpOEQ(sin, zero), sin, FMul(r, sin))));
  } else {
    real_part = FMul(r, cos);
    imag_part = Select(FCmpOEQ(sin, zero), sin, FMul(r, sin));
  }

  return Select(FCmpOEQ(r, zero), EmitComposeComplex(op, zero, zero),
                EmitComposeComplex(op, real_part, imag_part));
}

// Similar to Sqrt, we can use our EmitComplexPower formula, but set
// c=-0.5 and d=0. We get:
// e^[ln(r)*c - t*d] * [cos(ln(r)*d + t*c) + i*sin(ln(r)*d + t*c)]
// = e^[ln(r)*-0.5] * [cos(t*-0.5) + i*sin(t*-0.5)]
// = r^(-0.5) * [cos(-t/2) + i*sin(-t/2)]
// = rsqrt(r) * [cos(-t/2) + i*sin(-t/2)]
// where r = |a+bi| and t = atan2(b,a).
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexRsqrt(
    const HloInstruction* op, PrimitiveType prim_type,
    llvm::Value* operand_value) {
  llvm::Type* type = static_cast<llvm::StructType*>(operand_value->getType())
                         ->getElementType(0);

  TF_ASSIGN_OR_RETURN(llvm::Value * r,
                      EmitRsqrtComplexAbs(prim_type, operand_value));

  llvm::Value* a = EmitExtractReal(operand_value);
  llvm::Value* b = EmitExtractImag(operand_value);
  TF_ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a, ""));

  llvm::Value* c = llvm::ConstantFP::get(type, -0.5);
  llvm::Value* angle = FMul(t, c);
  TF_ASSIGN_OR_RETURN(llvm::Value * cos, EmitCos(prim_type, angle));
  TF_ASSIGN_OR_RETURN(llvm::Value * sin, EmitSin(prim_type, angle));

  llvm::Value* real_part = FMul(r, cos);
  llvm::Value* imag_part = FMul(r, sin);

  if (!(b_->getFastMathFlags().noNaNs() && b_->getFastMathFlags().noInfs())) {
    llvm::Value* zero = llvm::ConstantFP::get(type, 0);
    llvm::Value* neg_one = llvm::ConstantFP::get(type, -1);
    llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
    llvm::Value* nan = llvm::ConstantFP::getNaN(type);
    llvm::Value* a_signed_zero = llvm_ir::EmitCallToIntrinsic(
        llvm::Intrinsic::copysign, {zero, a}, {a->getType()}, b_);
    llvm::Value* b_signed_zero = llvm_ir::EmitCallToIntrinsic(
        llvm::Intrinsic::copysign, {zero, b}, {b->getType()}, b_);
    llvm::Value* neg_b_signed_zero = FMul(b_signed_zero, neg_one);

    llvm::Value* abs_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
                                                      {a}, {a->getType()}, b_);
    llvm::Value* abs_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
                                                      {b}, {b->getType()}, b_);

    llvm::Value* is_zero_zero = And(FCmpOEQ(b, zero), FCmpOEQ(a, zero));
    real_part = Select(
        is_zero_zero, inf,
        Select(Or(And(FCmpOEQ(abs_b, inf), FCmpUNO(a, a)), FCmpOEQ(abs_a, inf)),
               a_signed_zero, FMul(r, cos)));
    imag_part = Select(
        is_zero_zero, nan,
        Select(Or(And(FCmpOEQ(abs_b, inf), FCmpUNO(a, a)), FCmpOEQ(abs_a, inf)),
               neg_b_signed_zero, FMul(r, sin)));
  } else {
    llvm::Value* zero = llvm::ConstantFP::get(type, 0);
    llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
    llvm::Value* nan = llvm::ConstantFP::getNaN(type);

    llvm::Value* is_zero_zero = And(FCmpOEQ(b, zero), FCmpOEQ(a, zero));
    real_part = Select(is_zero_zero, inf, FMul(r, cos));
    imag_part = Select(is_zero_zero, nan, FMul(r, sin));
  }

  return EmitComposeComplex(op, real_part, imag_part);
}

1295 // Complex cbrt(z) = z^(1/3):
1296 // use EmitComplexPower with c=1.0/3.0 and d=0.
1297 StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexCbrt(
1298 const HloInstruction* op, PrimitiveType prim_type,
1299 llvm::Value* operand_value) {
1300 auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
1301 auto third = llvm::ConstantFP::get(type, 1.0 / 3.0);
1302 auto zero = llvm::ConstantFP::get(type, 0);
1303 llvm::Value* a = EmitExtractReal(operand_value);
1304 llvm::Value* b = EmitExtractImag(operand_value);
1305 return EmitComplexPower(op, a, b, third, zero);
1306 }
1307
1308 // (a+bi)^(c+di) =
1309 // (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
1310 // where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
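// Worked example: i^i has a=0, b=1, c=0, d=1. Then a*a+b*b = 1,
// (a*a+b*b)^(0.5c) = 1, atan2(b,a) = pi/2, exp(-d*atan2(b,a)) = e^(-pi/2),
// and q = 0, so the result is e^(-pi/2) + 0i ~= 0.20788, the known value.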
1311 StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexPower(
1312 const HloInstruction* op, llvm::Value* a, llvm::Value* b, llvm::Value* c,
1313 llvm::Value* d) {
1314 PrimitiveType component_type =
1315 primitive_util::ComplexComponentType(op->shape().element_type());
1316 auto aa_p_bb = FAdd(FMul(a, a), FMul(b, b));
1317 auto zero = llvm::ConstantFP::get(a->getType(), 0);
1318 auto one_half = llvm::ConstantFP::get(a->getType(), 0.5);
1319 auto one = llvm::ConstantFP::get(a->getType(), 1);
1320 auto half_c = FMul(one_half, c);
1321
1322 TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c,
1323 EmitPow(component_type, aa_p_bb, half_c, ""));
1324
1325 auto neg_d = FNeg(d);
1326 TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a, ""));
1327 auto neg_d_arg_lhs = FMul(neg_d, arg_lhs);
1328 TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs,
1329 EmitExp(component_type, neg_d_arg_lhs, ""));
1330 auto coeff = FMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs);
1331 TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb));
1332 auto half_d = FMul(one_half, d);
1333 auto q = FAdd(FMul(c, arg_lhs), FMul(half_d, ln_aa_p_bb));
1334 TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q));
1335 TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q));
1336 // 0^c is 0 for c > 0, and 0^0 is defined to be 1.0 (the base a+bi is
1337 // zero iff a*a+b*b == 0, and the exponent is real iff d == 0); see
1338 // Branch Cuts for Complex Elementary Functions or Much Ado About Nothing's Sign Bit, W. Kahan, Section 10.
1339 return Select(
1340 And(And(FCmpOEQ(aa_p_bb, zero), FCmpOEQ(d, zero)), FCmpOLE(zero, c)),
1341 EmitComposeComplex(op, Select(FCmpOEQ(zero, c), one, zero), zero),
1342 EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q)));
1343 }
1344
1345 StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
1346 const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
1347 switch (op->opcode()) {
1348 case HloOpcode::kAdd:
1349 return EmitComplexAdd(op, lhs_value, rhs_value);
1350 case HloOpcode::kSubtract:
1351 return EmitComplexSubtract(op, lhs_value, rhs_value);
1352 case HloOpcode::kMultiply:
1353 return EmitComplexMultiply(op, lhs_value, rhs_value);
1354 case HloOpcode::kDivide: {
1355 return EmitComplexDivide(op, lhs_value, rhs_value);
1356 }
1357 // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
1358 // comparisons always return false when one of the operands is NaN, whereas
1359 // unordered comparisons return true.
1360 //
1361 // We use ordered comparisons for everything except kNe, where we use an
1362 // unordered comparison. This makes x != y equivalent to !(x == y), and
1363 // matches C++'s semantics.
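// For example, if lhs is NaN+0i: FCMP_OEQ(NaN, x) is false, so kEq yields
// false; FCMP_UNE(NaN, x) is true, so kNe yields true. Hence Ne(z, w) is
// always !Eq(z, w), even in the presence of NaNs.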
1364 case HloOpcode::kCompare: {
1365 switch (op->comparison_direction()) {
1366 case ComparisonDirection::kEq:
1367 return And(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
1368 EmitExtractReal(lhs_value),
1369 EmitExtractReal(rhs_value), b_),
1370 llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
1371 EmitExtractImag(lhs_value),
1372 EmitExtractImag(rhs_value), b_));
1373 case ComparisonDirection::kNe:
1374 return Or(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
1375 EmitExtractReal(lhs_value),
1376 EmitExtractReal(rhs_value), b_),
1377 llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
1378 EmitExtractImag(lhs_value),
1379 EmitExtractImag(rhs_value), b_));
1380 default:
1381 return Unimplemented(
1382 "complex comparison '%s'",
1383 ComparisonDirectionToString(op->comparison_direction()));
1384 }
1385 }
1386 case HloOpcode::kPower: {
1387 auto a = EmitExtractReal(lhs_value);
1388 auto b = EmitExtractImag(lhs_value);
1389 auto c = EmitExtractReal(rhs_value);
1390 auto d = EmitExtractImag(rhs_value);
1391 return EmitComplexPower(op, a, b, c, d);
1392 }
1393 case HloOpcode::kAtan2: {
1394 // atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
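// Sanity check for real inputs: (x + i*y)/sqrt(x^2+y^2) = e^(i*theta) with
// theta = atan2(y, x), so log(...) = i*theta and -i * (i*theta) = theta.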
1395 auto y = lhs_value;
1396 auto x = rhs_value;
1397 TF_ASSIGN_OR_RETURN(auto x_squared, EmitComplexMultiply(op, x, x));
1398 TF_ASSIGN_OR_RETURN(auto y_squared, EmitComplexMultiply(op, y, y));
1399 TF_ASSIGN_OR_RETURN(auto x_squared_plus_y_squared,
1400 EmitComplexAdd(op, x_squared, y_squared));
1401 auto component_type =
1402 primitive_util::ComplexComponentType(op->shape().element_type());
1403 TF_ASSIGN_OR_RETURN(
1404 auto sqrt_x_squared_plus_y_squared,
1405 EmitComplexSqrt(op, component_type, x_squared_plus_y_squared));
1406 auto type = llvm_ir::PrimitiveTypeToIrType(component_type, module_);
1407 auto zero = llvm::ConstantFP::get(type, 0.0);
1408 auto one = llvm::ConstantFP::get(type, 1.0);
1409 auto i = EmitComposeComplex(op, zero, one);
1410 TF_ASSIGN_OR_RETURN(auto i_times_y, EmitComplexMultiply(op, i, y));
1411 TF_ASSIGN_OR_RETURN(auto x_plus_iy, EmitComplexAdd(op, x, i_times_y));
1412 TF_ASSIGN_OR_RETURN(
1413 auto div_result,
1414 EmitComplexDivide(op, x_plus_iy, sqrt_x_squared_plus_y_squared));
1415 TF_ASSIGN_OR_RETURN(auto log_result, EmitComplexLog(op, div_result));
1416 auto negative_one = llvm::ConstantFP::get(type, -1.0);
1417 auto negative_i = EmitComposeComplex(op, zero, negative_one);
1418 return EmitComplexMultiply(op, negative_i, log_result);
1419 }
1420 default:
1421 return Unimplemented("binary complex op '%s'",
1422 HloOpcodeString(op->opcode()));
1423 }
1424 }
1425
1426 llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
1427 llvm::Value* rhs_value,
1428 absl::string_view name) {
1429 return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_, fast_min_max(), name);
1430 }
1431
1432 llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
1433 llvm::Value* rhs_value,
1434 absl::string_view name) {
1435 return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_, fast_min_max(), name);
1436 }
1437
1438 StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type,
1439 llvm::Value* value) {
1440 return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {value},
1441 {value->getType()}, b_);
1442 }
1443
1444 StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type,
1445 llvm::Value* value) {
1446 auto x = value;
1447 auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
1448 auto one = llvm::ConstantFP::get(type, 1.0);
1449 auto negative_half = llvm::ConstantFP::get(type, -0.5);
1450 // When x is large, the naive evaluation of ln(x + 1) is more
1451 // accurate than the Taylor series.
1452 TF_ASSIGN_OR_RETURN(auto for_large_x, EmitLog(prim_type, FAdd(x, one)));
1453 // When x is small (|x| less than sqrt(2) - 1, the threshold defined
1454 // below), use a rational approximation based on one from the Cephes
1455 // Mathematical Library.
1456 //
1457 // sqrt(2) - 1.
1458 const auto kAntilogarithmIsSmallThreshold = 0.41421356237309504880;
1459
1460 static const std::array<double, 7> kDenominatorCoeffs{
1461 1.,
1462 1.5062909083469192043167E1,
1463 8.3047565967967209469434E1,
1464 2.2176239823732856465394E2,
1465 3.0909872225312059774938E2,
1466 2.1642788614495947685003E2,
1467 6.0118660497603843919306E1,
1468 };
1469
1470 static const std::array<double, 7> kNumeratorCoeffs{
1471 4.5270000862445199635215E-5, 4.9854102823193375972212E-1,
1472 6.5787325942061044846969E0, 2.9911919328553073277375E1,
1473 6.0949667980987787057556E1, 5.7112963590585538103336E1,
1474 2.0039553499201281259648E1,
1475 };
1476
1477 auto x_squared = FMul(x, x);
1478 TF_ASSIGN_OR_RETURN(auto denominator,
1479 EvaluatePolynomial(type, x, kDenominatorCoeffs));
1480 TF_ASSIGN_OR_RETURN(auto numerator,
1481 EvaluatePolynomial(type, x, kNumeratorCoeffs));
1482 auto for_small_x = FDiv(numerator, denominator);
1483 for_small_x = FMul(FMul(x, x_squared), for_small_x);
1484 for_small_x = FAdd(FMul(negative_half, x_squared), for_small_x);
1485 for_small_x = FAdd(x, for_small_x);
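// In total, for_small_x = x - x^2/2 + x^3 * N(x)/D(x), which matches the
// series ln(1+x) = x - x^2/2 + x^3/3 - ...; the rational term approximates
// (ln(1+x) - x + x^2/2) / x^3.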
1486
1487 auto abs_x =
1488 llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
1489 auto x_is_small = FCmpOLT(
1490 abs_x, llvm::ConstantFP::get(type, kAntilogarithmIsSmallThreshold));
1491 return Select(x_is_small, for_small_x, for_large_x);
1492 }
1493
1494 StatusOr<llvm::Value*> ElementalIrEmitter::EmitSqrt(PrimitiveType,
1495 llvm::Value* value) {
1496 return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {value},
1497 {value->getType()}, b_);
1498 }
1499
1500 StatusOr<llvm::Value*> ElementalIrEmitter::EmitRsqrt(PrimitiveType prim_type,
1501 llvm::Value* value) {
1502 TF_ASSIGN_OR_RETURN(auto sqrt, EmitSqrt(prim_type, value));
1503 return FDiv(llvm::ConstantFP::get(sqrt->getType(), 1.0), sqrt);
1504 }
1505
1506 StatusOr<llvm::Value*> ElementalIrEmitter::EmitSin(PrimitiveType prim_type,
1507 llvm::Value* value) {
1508 return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value},
1509 {value->getType()}, b_);
1510 }
1511
1512 StatusOr<llvm::Value*> ElementalIrEmitter::EmitCos(PrimitiveType prim_type,
1513 llvm::Value* value) {
1514 return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value},
1515 {value->getType()}, b_);
1516 }
1517
1518 StatusOr<llvm::Value*> ElementalIrEmitter::EmitExp(PrimitiveType prim_type,
1519 llvm::Value* value,
1520 absl::string_view name) {
1521 return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value},
1522 {value->getType()}, b_, name);
1523 }
1524
1525 StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
1526 llvm::Value* value) {
1527 auto x = value;
1528 auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
1529 auto one = llvm::ConstantFP::get(type, 1.0);
1530 auto half = llvm::ConstantFP::get(type, 0.5);
1531 auto zero = llvm::ConstantFP::get(type, 0.0);
1532
1533 // expm1(x) == tanh(x/2)*(exp(x)+1)
1534 // x/2 can underflow; if it does, we approximate expm1 with x.
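// This identity holds because tanh(x/2) = (e^x - 1)/(e^x + 1), so
// tanh(x/2) * (e^x + 1) = e^x - 1. For small x, tanh(x/2) ~= x/2 is
// computed without the cancellation that e^x - 1 would suffer directly.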
1535 auto x_over_two = FMul(x, half);
1536 auto x_over_two_is_zero = FCmpOEQ(x_over_two, zero);
1537 auto abs_x =
1538 llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {x}, {type}, b_);
1539 // Use a naive exp(x) - 1 calculation if |x| > 0.5; the subtraction
// suffers at most mild cancellation there.
1540 auto x_magnitude_is_large = FCmpOGT(abs_x, half);
1541 TF_ASSIGN_OR_RETURN(auto tanh_of_x_over_two, EmitTanh(prim_type, x_over_two));
1542 TF_ASSIGN_OR_RETURN(auto exp_of_x, EmitExp(prim_type, x, ""));
1543 auto exp_of_x_plus_one = FAdd(exp_of_x, one);
1544 auto exp_of_x_minus_one = FSub(exp_of_x, one);
1545 auto expm1_of_x = FMul(tanh_of_x_over_two, exp_of_x_plus_one);
1546 expm1_of_x = Select(x_magnitude_is_large, exp_of_x_minus_one, expm1_of_x);
1547 expm1_of_x = Select(x_over_two_is_zero, x, expm1_of_x);
1548 return expm1_of_x;
1549 }
1550
1551 StatusOr<llvm::Value*> ElementalIrEmitter::EmitPow(PrimitiveType prim_type,
1552 llvm::Value* lhs,
1553 llvm::Value* rhs,
1554 absl::string_view name) {
1555 return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs},
1556 {lhs->getType()}, b_, name);
1557 }
1558
1559 StatusOr<llvm::Value*> ElementalIrEmitter::EmitCbrt(PrimitiveType prim_type,
1560 llvm::Value* value) {
1561 auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
1562 auto third = llvm::ConstantFP::get(type, 1.0 / 3.0);
1563 auto abs_value =
1564 llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
1565 TF_ASSIGN_OR_RETURN(llvm::Value * abs_res,
1566 EmitPow(prim_type, abs_value, third, ""));
1567 auto signed_res = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign,
1568 {abs_res, value}, {type}, b_);
1569 return signed_res;
1570 }
1571
1572 StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(
1573 PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* /*rhs*/,
1574 absl::string_view /*name*/) {
1575 return Unimplemented("atan2");
1576 }
1577
1578 StatusOr<llvm::Value*> ElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
1579 llvm::Value* value) {
1580 return Unimplemented("tanh");
1581 }
1582
1583 StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
1584 const HloInstruction* hlo, llvm::Value* x) {
1585 return EmitReducePrecisionIR(
1586 /*src_ty=*/hlo->operand(0)->shape().element_type(), x,
1587 /*dest_exponent_bits=*/hlo->exponent_bits(),
1588 /*dest_mantissa_bits=*/hlo->mantissa_bits(),
1589 /*quiet_nans=*/false, b_);
1590 }
1591
1592 static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* b,
1593 llvm::Value* lhs, llvm::Value* rhs,
1594 llvm::Value* shift_result,
1595 bool saturate_to_sign_bit) {
1596 llvm::IntegerType* integer_type =
1597 llvm::cast<llvm::IntegerType>(lhs->getType());
1598 unsigned integer_bitsize = integer_type->getBitWidth();
1599 llvm::ConstantInt* integer_bitsize_constant =
1600 llvm::ConstantInt::get(integer_type, integer_bitsize);
1601 llvm::ConstantInt* zero = llvm::ConstantInt::get(integer_type, 0);
1602 llvm::ConstantInt* minus_one = llvm::ConstantInt::get(integer_type, -1);
1603 llvm::Value* saturated_value;
1604 if (saturate_to_sign_bit) {
1605 saturated_value =
1606 b->CreateSelect(b->CreateICmpSLT(lhs, zero), minus_one, zero);
1607 } else {
1608 saturated_value = zero;
1609 }
1610 llvm::Value* shift_amt_in_range =
1611 b->CreateICmpULT(rhs, integer_bitsize_constant, "shft.chk");
1612 return b->CreateSelect(shift_amt_in_range, shift_result, saturated_value);
1613 }
1614
1615 llvm::Value* ElementalIrEmitter::GetOne(llvm::Type* type) {
1616 return llvm::ConstantInt::get(llvm::cast<llvm::IntegerType>(type), 1);
1617 }
1618
1619 llvm::Value* ElementalIrEmitter::GetZero(llvm::Type* type) {
1620 return llvm::ConstantInt::get(llvm::cast<llvm::IntegerType>(type), 0);
1621 }
1622
1623 llvm::Value* ElementalIrEmitter::GetIntSMin(llvm::Type* type) {
1624 auto* integer_type = llvm::cast<llvm::IntegerType>(type);
1625 return llvm::ConstantInt::get(integer_type, llvm::APInt::getSignedMinValue(
1626 integer_type->getBitWidth()));
1627 }
1628
1629 llvm::Value* ElementalIrEmitter::GetMinusOne(llvm::Type* type) {
1630 auto* integer_type = llvm::cast<llvm::IntegerType>(type);
1631 return llvm::ConstantInt::get(
1632 integer_type, llvm::APInt::getAllOnesValue(integer_type->getBitWidth()));
1633 }
1634
1635 llvm::Value* ElementalIrEmitter::IsZero(llvm::Value* v) {
1636 return ICmpEQ(v, llvm::ConstantInt::get(v->getType(), 0));
1637 }
1638
1639 llvm::Value* ElementalIrEmitter::IsIntMinDivisionOverflow(llvm::Value* lhs,
1640 llvm::Value* rhs) {
1641 return And(ICmpEQ(lhs, GetIntSMin(lhs->getType())),
1642 ICmpEQ(rhs, GetMinusOne(rhs->getType())));
1643 }
1644
1645 llvm::Value* ElementalIrEmitter::EmitIntegerDivide(llvm::Value* lhs,
1646 llvm::Value* rhs,
1647 bool is_signed) {
1648 // Integer division overflow behavior:
1649 //
1650 //   X / 0 == -1
1651 //   INT_SMIN /s -1 == INT_SMIN
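// For example, with i8 operands: -128 /s -1 is +128 mathematically, which
// is unrepresentable, so we return INT_SMIN (-128); and x / 0 returns -1
// (all bits set) for both signed and unsigned division.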
1652
1653 if (!is_signed) {
1654 llvm::Value* udiv_is_unsafe = IsZero(rhs);
1655 llvm::Value* safe_rhs = Select(udiv_is_unsafe, GetOne(lhs->getType()), rhs);
1656 llvm::Value* safe_div = UDiv(lhs, safe_rhs);
1657 return Select(udiv_is_unsafe, GetMinusOne(lhs->getType()), safe_div);
1658 }
1659
1660 llvm::Value* has_zero_divisor = IsZero(rhs);
1661 llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs);
1662 llvm::Value* sdiv_is_unsafe = Or(has_int_min_overflow, has_zero_divisor);
1663 llvm::Value* safe_rhs = Select(sdiv_is_unsafe, GetOne(lhs->getType()), rhs);
1664 llvm::Value* safe_div = SDiv(lhs, safe_rhs);
1665
1666 return Select(
1667 has_zero_divisor, GetMinusOne(lhs->getType()),
1668 Select(has_int_min_overflow, GetIntSMin(lhs->getType()), safe_div));
1669 }
1670
1671 llvm::Value* ElementalIrEmitter::EmitIntegerRemainder(llvm::Value* lhs,
1672 llvm::Value* rhs,
1673 bool is_signed) {
1674 // Integer remainder overflow behavior:
1675 //
1676 //   X % 0 == X
1677 //   INT_SMIN %s -1 == 0
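// For example, with i8 operands: -128 %s -1 returns 0 (the mathematically
// correct remainder), and x % 0 returns x for both signed and unsigned.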
1678
1679 if (!is_signed) {
1680 llvm::Value* urem_is_unsafe = IsZero(rhs);
1681 llvm::Value* safe_rhs = Select(urem_is_unsafe, GetOne(lhs->getType()), rhs);
1682 llvm::Value* safe_rem = URem(lhs, safe_rhs);
1683 return Select(urem_is_unsafe, lhs, safe_rem);
1684 }
1685
1686 llvm::Value* has_zero_divisor = IsZero(rhs);
1687 llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs);
1688 llvm::Value* srem_is_unsafe = Or(has_int_min_overflow, has_zero_divisor);
1689 llvm::Value* safe_rhs = Select(srem_is_unsafe, GetOne(lhs->getType()), rhs);
1690 llvm::Value* safe_rem = SRem(lhs, safe_rhs);
1691
1692 return Select(
1693 has_zero_divisor, lhs,
1694 Select(has_int_min_overflow, GetZero(lhs->getType()), safe_rem));
1695 }
1696
1697 llvm::Value* ElementalIrEmitter::EmitIntegerPow(llvm::Value* base,
1698 llvm::Value* exponent,
1699 bool is_signed) {
1700 // Exponentiation by squaring:
1701 // https://en.wikipedia.org/wiki/Exponentiation_by_squaring
1702 int bits = 6;  // Unrolling 6 iterations suffices: 64 == 1 << 6, and any
1703                // exponent >= 64 would overflow a 64-bit integer for any
1704                // base with |base| > 1.
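// Unrolled trace for exponent 5 (binary 101): after i=0, acc = base and
// base' = base^2; after i=1, acc is unchanged and base'' = base^4; after
// i=2, acc = base * base^4 = base^5. The remaining iterations are no-ops
// because the shifted exponent is already 0.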
1705 llvm::Value* accumulator = llvm::ConstantInt::get(base->getType(), 1);
1706 llvm::Value* one = llvm::ConstantInt::get(exponent->getType(), 1);
1707 llvm::Value* zero = llvm::ConstantInt::get(exponent->getType(), 0);
1708 llvm::Value* original_base = base;
1709 llvm::Value* original_exponent = exponent;
1710
1711 // Unroll the loop at compile time.
1712 for (int i = 0; i < bits; i++) {
1713 accumulator =
1714 b_->CreateSelect(b_->CreateICmpEQ(b_->CreateAnd(exponent, one), one),
1715 b_->CreateMul(accumulator, base), accumulator);
1716 base = b_->CreateMul(base, base);
1717 exponent = b_->CreateLShr(exponent, 1);
1718 }
1719 return b_->CreateSelect(
1720 b_->CreateICmpSGE(original_exponent, zero), accumulator,
1721 b_->CreateSelect(b_->CreateICmpEQ(original_base, one), one, zero));
1722 }
1723
1724 StatusOr<llvm::Value*> ElementalIrEmitter::EmitPredBinaryOp(
1725 const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
1726 // Per the reference interpreter, pred arithmetic should behave like
1727 // `int8(x) OP int8(y) != 0`. For most permitted ops, we can just emit the
1728 // underlying i8 op to achieve this (e.g. kAnd, kOr, kXor, kMultiply). In the
1729 // case of kAdd, we would need to insert a comparison instruction after the
1730 // addition, but it's both easier and faster to emit a bitwise or instruction
1731 // instead.
1732 //
1733 // For several of these ops, a faster bitwise implementation is available, but
1734 // LLVM is unlikely to be able to see it, since it gets IR that e.g. loads i8s
1735 // from memory, multiplies them, and writes the result back, without any
1736 // indication that the inputs were assumed to be 0 or 1. So, just in case,
1737 // help it out by choosing the faster instruction to begin with.
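// For example, for kAdd on true inputs: int8(1) + int8(1) == 2 != 0, i.e.
// true, and or(1, 1) == 1 gives the same answer without the trailing
// comparison against zero.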
1738 switch (op->opcode()) {
1739 case HloOpcode::kCompare:
1740 case HloOpcode::kXor:
1741 return EmitIntegerBinaryOp(op, lhs_value, rhs_value, false);
1742
1743 // zext(i1 x) + zext(i1 y) != 0 === or(x, y)
1744 // max(zext(i1 x), zext(i1 y)) != 0 === or(x, y)
1745 case HloOpcode::kAdd:
1746 case HloOpcode::kMaximum:
1747 case HloOpcode::kOr:
1748 return Or(lhs_value, rhs_value);
1749
1750 // zext(i1 x) * zext(i1 y) != 0 === and(x, y)
1751 // min(zext(i1 x), zext(i1 y)) != 0 === and(x, y)
1752 case HloOpcode::kMultiply:
1753 case HloOpcode::kMinimum:
1754 case HloOpcode::kAnd:
1755 return And(lhs_value, rhs_value);
1756
1757 // These opcodes are rejected by shape-inference for PRED elements; calling
1758 // them out here serves more as documentation than a necessary check.
1759 case HloOpcode::kDivide:
1760 case HloOpcode::kRemainder:
1761 case HloOpcode::kPower:
1762 case HloOpcode::kSubtract:
1763 case HloOpcode::kShiftLeft:
1764 case HloOpcode::kShiftRightArithmetic:
1765 case HloOpcode::kShiftRightLogical:
1766 return InternalError("Invalid binary op '%s' for pred",
1767 HloOpcodeString(op->opcode()));
1768
1769 default:
1770 return Unimplemented("binary pred op '%s'",
1771 HloOpcodeString(op->opcode()));
1772 }
1773 }
1774
1775 StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
1776 const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value,
1777 bool is_signed) {
1778 switch (op->opcode()) {
1779 // TODO(jingyue): add the "nsw" attribute for signed types.
1780 case HloOpcode::kAdd:
1781 return Add(lhs_value, rhs_value);
1782 case HloOpcode::kSubtract:
1783 return Sub(lhs_value, rhs_value);
1784 case HloOpcode::kMultiply:
1785 return Mul(lhs_value, rhs_value);
1786 case HloOpcode::kDivide:
1787 return EmitIntegerDivide(lhs_value, rhs_value, is_signed);
1788 case HloOpcode::kRemainder:
1789 return EmitIntegerRemainder(lhs_value, rhs_value, is_signed);
1790 case HloOpcode::kCompare: {
1791 switch (op->comparison_direction()) {
1792 case ComparisonDirection::kEq:
1793 return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value,
1794 rhs_value, b_);
1795 case ComparisonDirection::kNe:
1796 return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value,
1797 rhs_value, b_);
1798 case ComparisonDirection::kLt:
1799 return llvm_ir::EmitComparison(
1800 is_signed ? llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT,
1801 lhs_value, rhs_value, b_);
1802 case ComparisonDirection::kGt:
1803 return llvm_ir::EmitComparison(
1804 is_signed ? llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT,
1805 lhs_value, rhs_value, b_);
1806 case ComparisonDirection::kLe:
1807 return llvm_ir::EmitComparison(
1808 is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE,
1809 lhs_value, rhs_value, b_);
1810 case ComparisonDirection::kGe:
1811 return llvm_ir::EmitComparison(
1812 is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE,
1813 lhs_value, rhs_value, b_);
1814 }
1815 }
1816 case HloOpcode::kMinimum:
1817 return EmitIntegralMin(lhs_value, rhs_value, is_signed);
1818 case HloOpcode::kMaximum:
1819 return EmitIntegralMax(lhs_value, rhs_value, is_signed);
1820 case HloOpcode::kAnd:
1821 return And(lhs_value, rhs_value);
1822 case HloOpcode::kOr:
1823 return Or(lhs_value, rhs_value);
1824 case HloOpcode::kPower:
1825 return EmitIntegerPow(lhs_value, rhs_value, is_signed);
1826 case HloOpcode::kXor:
1827 return Xor(lhs_value, rhs_value);
1828
1829 // Shifting out bits >= the number of bits in the type being shifted
1830 // produces a poison value in LLVM which is basically "deferred undefined
1831 // behavior" -- doing something observable with such a value precipitates
1832 // UB. We replace the poison value with a constant to avoid this deferred
1833 // UB.
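// For example, for i32 operands, x << 40 is poison in LLVM IR; the
// saturation below yields 0 instead (and, for arithmetic right shifts,
// 0 or -1 depending on the sign of the input).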
1834 case HloOpcode::kShiftRightArithmetic:
1835 return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
1836 AShr(lhs_value, rhs_value),
1837 /*saturate_to_sign_bit=*/true);
1838 case HloOpcode::kShiftLeft:
1839 return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
1840 Shl(lhs_value, rhs_value),
1841 /*saturate_to_sign_bit=*/false);
1842 case HloOpcode::kShiftRightLogical:
1843 return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
1844 LShr(lhs_value, rhs_value),
1845 /*saturate_to_sign_bit=*/false);
1846 default:
1847 return Unimplemented("binary integer op '%s'",
1848 HloOpcodeString(op->opcode()));
1849 }
1850 }
1851
1852 llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value,
1853 llvm::Value* rhs_value,
1854 bool is_signed) {
1855 return Select(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE
1856 : llvm::ICmpInst::ICMP_UGE,
1857 lhs_value, rhs_value),
1858 lhs_value, rhs_value);
1859 }
1860
1861 llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value,
1862 llvm::Value* rhs_value,
1863 bool is_signed) {
1864 return Select(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE
1865 : llvm::ICmpInst::ICMP_ULE,
1866 lhs_value, rhs_value),
1867 lhs_value, rhs_value);
1868 }
1869
1870 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalSelect(
1871 const HloInstruction* hlo,
1872 const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
1873 const llvm_ir::IrArray::Index& index) {
1874 TF_ASSIGN_OR_RETURN(llvm::Value * pred_value,
1875 operand_to_generator.at(hlo->operand(0))(index));
1876 TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value,
1877 operand_to_generator.at(hlo->operand(1))(index));
1878 TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value,
1879 operand_to_generator.at(hlo->operand(2))(index));
1880 return Select(Trunc(pred_value, b_->getInt1Ty()), on_true_value,
1881 on_false_value);
1882 }
1883
1884 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalClamp(
1885 const HloInstruction* hlo,
1886 const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
1887 const llvm_ir::IrArray::Index& index) {
1888 TF_ASSIGN_OR_RETURN(llvm::Value * min_value,
1889 operand_to_generator.at(hlo->operand(0))(index));
1890 TF_ASSIGN_OR_RETURN(llvm::Value * arg_value,
1891 operand_to_generator.at(hlo->operand(1))(index));
1892 TF_ASSIGN_OR_RETURN(llvm::Value * max_value,
1893 operand_to_generator.at(hlo->operand(2))(index));
1894 PrimitiveType prim_type = hlo->shape().element_type();
1895 if (primitive_util::IsFloatingPointType(prim_type)) {
1896 return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value, ""), "");
1897 } else if (primitive_util::IsIntegralType(prim_type)) {
1898 bool is_signed = primitive_util::IsSignedIntegralType(prim_type);
1899 return EmitIntegralMin(
1900 max_value, EmitIntegralMax(min_value, arg_value, is_signed), is_signed);
1901 } else {
1902 return Unimplemented("Clamp unimplemented for %s",
1903 PrimitiveType_Name(prim_type));
1904 }
1905 }
1906
1907 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate(
1908 const HloInstruction* hlo,
1909 const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
1910 const llvm_ir::IrArray::Index& source_index) {
1911 const int64 concat_dim = hlo->dimensions(0);
1912 llvm::BasicBlock* init_block = b_->GetInsertBlock();
1913
1914 llvm::BasicBlock* exit_block;
1915 if (b_->GetInsertPoint() != init_block->end()) {
1916 // Inserting into the middle.
1917 CHECK(init_block->getTerminator());
1918 exit_block =
1919 init_block->splitBasicBlock(b_->GetInsertPoint(), IrName(hlo, "merge"));
1920 init_block->getTerminator()->eraseFromParent();
1921 } else {
1922 // Inserting at the end.
1923 CHECK(!init_block->getTerminator());
1924 exit_block = llvm_ir::CreateBasicBlock(
1925 /*insert_before=*/nullptr, IrName(hlo, "merge"), b_);
1926 }
1927
1928 llvm_ir::SetToFirstInsertPoint(exit_block, b_);
1929 llvm::PHINode* output = b_->CreatePHI(
1930 llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
1931 hlo->operands().size());
1932 auto prior_insert_point = b_->GetInsertPoint();
1933
1934 b_->SetInsertPoint(init_block);
1935
1936 // Assign a unique id for each *different* operand, and count how often each
1937 // operand is used. If all operands are different, the usage count will be 1
1938 // for each operand.
1939 absl::flat_hash_map<const HloInstruction*, int64> to_unique_operand_id;
1940 std::vector<int64> operand_usage_count;
1941 for (const HloInstruction* operand : hlo->operands()) {
1942 if (to_unique_operand_id.contains(operand)) {
1943 ++operand_usage_count[to_unique_operand_id[operand]];
1944 } else {
1945 int64 unique_operand_id = to_unique_operand_id.size();
1946 to_unique_operand_id[operand] = unique_operand_id;
1947 operand_usage_count.push_back(1);
1948 }
1949 }
1950
1951 // To avoid emitting the same operand more than once, we create one basic
1952 // block for each *different* operand, with a PHI node merging the
1953 // different source index inputs.
1954 std::vector<llvm::BasicBlock*> emit_operand_blocks(
1955 to_unique_operand_id.size(), nullptr);
1956 std::vector<llvm::PHINode*> source_index_phis(to_unique_operand_id.size(),
1957 nullptr);
1958 for (const HloInstruction* operand : hlo->operands()) {
1959 int64 operand_id = to_unique_operand_id[operand];
1960 if (emit_operand_blocks[operand_id] != nullptr) {
1961 continue;
1962 }
1963
1964 emit_operand_blocks[operand_id] = llvm_ir::CreateBasicBlock(
1965 exit_block, StrCat("concat_index_from_operand_id", operand_id), b_);
1966 auto saved_insert_point = b_->GetInsertPoint();
1967 llvm_ir::SetToFirstInsertPoint(emit_operand_blocks[operand_id], b_);
1968 source_index_phis[operand_id] =
1969 b_->CreatePHI(source_index.GetType(), operand_usage_count[operand_id]);
1970 std::vector<llvm::Value*> operand_multi_index = source_index.multidim();
1971 operand_multi_index[concat_dim] = b_->CreateNSWSub(
1972 operand_multi_index[concat_dim], source_index_phis[operand_id]);
1973
1974 // Create the terminator of the block before calling operand generators,
1975 // because they require non-degenerate basic blocks.
1976 b_->SetInsertPoint(llvm::BranchInst::Create(
1977 exit_block, /*InsertAtEnd=*/emit_operand_blocks[operand_id]));
1978 llvm_ir::IrArray::Index operand_index(operand_multi_index, operand->shape(),
1979 source_index.GetType());
1980
1981 TF_ASSIGN_OR_RETURN(llvm::Value * value,
1982 operand_to_generator.at(operand)(operand_index));
1983 output->addIncoming(value, b_->GetInsertBlock());
1984 b_->SetInsertPoint(init_block, saved_insert_point);
1985 }
1986
1987 // We use bisection to select the input operand.
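// For example, concatenating operands of sizes [2, 3, 4] along concat_dim
// gives offsets 0, 2 and 5. For target index 6, bisection compares against
// the pivots and lands on the third operand; its PHI receives offset 5, so
// the operand-local index becomes 6 - 5 = 1.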
1988 int64 current_offset = 0;
1989
1990 // Offset for every operand.
1991 std::vector<std::pair<int64, const HloInstruction*>> cases;
1992
1993 cases.reserve(hlo->operand_count());
1994 for (const HloInstruction* operand : hlo->operands()) {
1995 cases.emplace_back(current_offset, operand);
1996 current_offset += operand->shape().dimensions(concat_dim);
1997 }
1998 CHECK_EQ(current_offset, hlo->shape().dimensions(concat_dim));
1999
2000 std::function<llvm::BasicBlock*(
2001 absl::Span<const std::pair<int64, const HloInstruction*>> operands)>
2002 emit_tree = [&](absl::Span<const std::pair<int64, const HloInstruction*>>
2003 operands) {
2004 llvm::IRBuilder<>::InsertPointGuard guard(*b_);
2005 size_t mid = operands.size() / 2;
2006 const std::pair<int64, const HloInstruction*>& pivot = operands[mid];
2007 llvm::BasicBlock* block = llvm_ir::CreateBasicBlock(
2008 exit_block, absl::StrCat("concatenate.pivot.", pivot.first, "."),
2009 b_);
2010 b_->SetInsertPoint(block);
2011
2012 // If there's only one element, we're done: the range is contiguous, so
2013 // we can just jump to the block for it.
2014 if (operands.size() == 1) {
2015 const std::pair<int64, const HloInstruction*>& operand =
2016 operands.back();
2017 int64 operand_id = to_unique_operand_id[operand.second];
2018
2019 source_index_phis[operand_id]->addIncoming(
2020 source_index.GetConstantWithIndexType(operand.first),
2021 b_->GetInsertBlock());
2022 b_->CreateBr(emit_operand_blocks[operand_id]);
2023 return block;
2024 }
2025
2026 // Take the middle element and recurse.
2027 llvm::Constant* pivot_const = llvm::ConstantInt::get(
2028 source_index[concat_dim]->getType(), pivot.first);
2029 llvm::Value* comp =
2030 b_->CreateICmpULT(source_index[concat_dim], pivot_const);
2031
2032 llvm::BasicBlock* left_block = emit_tree(operands.subspan(0, mid));
2033 llvm::BasicBlock* right_block = emit_tree(operands.subspan(mid));
2034
2035 b_->CreateCondBr(comp, left_block, right_block);
2036 return block;
2037 };
2038
2039 Br(emit_tree(cases));
2040
2041 b_->SetInsertPoint(exit_block, prior_insert_point);
2042 return output;
2043 }
2044
2045 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
2046 const HloInstruction* hlo,
2047 const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
2048 const llvm_ir::IrArray::Index& index) {
2049 // Emit IR to read dynamic start indices from hlo->operand(1).
2050 const HloInstruction* input_hlo = hlo->operand(0);
2051 const int64 rank = input_hlo->shape().rank();
2052 // Use the same index type for all tensor accesses in the same kernel.
2053 llvm::Type* index_type = index.GetType();
2054 std::vector<llvm::Value*> slice_start_multi_index(rank);
2055 for (int64 i = 0; i < rank; ++i) {
2056 auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
2057 return llvm::ConstantInt::get(index_type, c);
2058 };
2059 llvm_ir::IrArray::Index zero_index(index_type);
2060 TF_ASSIGN_OR_RETURN(
2061 llvm::Value * start_index_value,
2062 operand_to_generator.at(hlo->operand(1 + i))(zero_index));
2063
2064 // Clamp the start index so that the sliced portion fits in the operand:
2065 // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size)
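// For example, if the operand dimension has size 10 and the slice output
// has size 4, the largest valid start index is 6; a requested start of 8
// is clamped to 6, so the slice reads elements [6, 10).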
2066 start_index_value = SExtOrTrunc(start_index_value, index_type);
2067 int64 largest_valid_start_index =
2068 input_hlo->shape().dimensions(i) - hlo->shape().dimensions(i);
2069 CHECK_GE(largest_valid_start_index, 0);
2070
2071 bool is_signed = ShapeUtil::ElementIsSigned(hlo->operand(1)->shape());
2072 start_index_value = EmitIntegralMin(
2073 index_typed_const(largest_valid_start_index),
2074 EmitIntegralMax(index_typed_const(0), start_index_value, is_signed),
2075 is_signed);
2076
2077 start_index_value->setName(IrName(hlo, StrCat("start_idx", i)));
2078 slice_start_multi_index[i] = start_index_value;
2079 }
2080
2081 std::vector<llvm::Value*> input_multi_index(rank);
2082 for (int64 i = 0; i < rank; ++i) {
2083 // Emit IR which computes:
2084 // input_index = start_index + offset_index
2085 input_multi_index[i] = Add(slice_start_multi_index[i], index[i]);
2086 }
2087 llvm_ir::IrArray::Index input_index(input_multi_index, input_hlo->shape(),
2088 index_type);
2089 return operand_to_generator.at(input_hlo)(input_index);
2090 }
2091
2092 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
2093 const HloInstruction* hlo,
2094 const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
2095 const llvm_ir::IrArray::Index& index) {
2096 const Shape& operand_shape = hlo->operand(0)->shape();
2097 const Shape& indices_shape = hlo->operand(1)->shape();
2098 const Shape& output_shape = hlo->shape();
2099
2100 const GatherDimensionNumbers& dim_numbers = hlo->gather_dimension_numbers();
2101
2102 const llvm_ir::ElementGenerator& operand_generator =
2103 operand_to_generator.at(hlo->operand(0));
2104 const llvm_ir::ElementGenerator& indices_generator =
2105 operand_to_generator.at(hlo->operand(1));
2106
2107 llvm::Type* index_type = index.GetType();
2108 // This is the index into `operand` that holds the element we want to
2109 // generate.
2110 std::vector<llvm::Value*> operand_multi_index;
2111
2112 // First copy in the window indices to operand_index. Also collect a mapping
2113 // from operand dimension to output window dimension. Elided window dimensions
2114 // map to -1.
2115 std::vector<int64> operand_to_output_dim(operand_shape.dimensions_size(), -1);
2116 for (int64 i = 0, e = operand_shape.dimensions_size(), operand_index_dim = 0;
2117 i < e; i++) {
2118 if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
2119 operand_multi_index.push_back(index.GetConstantWithIndexType(0));
2120 } else {
2121 int64 output_window_dim = dim_numbers.offset_dims(operand_index_dim++);
2122 operand_to_output_dim[i] = output_window_dim;
2123 operand_multi_index.push_back(index[output_window_dim]);
2124 }
2125 }
2126
2127 // This is the index of the index vector in the start_indices tensor.
2128 std::vector<llvm::Value*> gather_index_index_components;
2129 {
2130 for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) {
2131 if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) {
2132 gather_index_index_components.push_back(index[i]);
2133 }
2134 }
2135
2136 if (gather_index_index_components.size() !=
2137 indices_shape.dimensions_size()) {
2138 gather_index_index_components.insert(
2139 gather_index_index_components.begin() +
2140 dim_numbers.index_vector_dim(),
2141 nullptr);
2142 }
2143 }
2144
2145 auto add_to_operand_index = [&](llvm::Value* index_component, int64 dim) {
2146 auto index_component_type = index_component->getType();
2147 auto extended_type = index_component_type->getScalarSizeInBits() >=
2148 index_type->getScalarSizeInBits()
2149 ? index_component_type
2150 : index_type;
2151 // Possibly extend the value at the beginning to ensure clamping logic stays
2152 // in bounds.
2153 auto maybe_extended_index =
2154 index_component_type != extended_type
2155 ? b_->CreateSExt(index_component, extended_type)
2156 : index_component;
2157 int64 operand_dim = dim_numbers.start_index_map(dim);
2158 int64 output_dim = operand_to_output_dim[operand_dim];
2159 // If 'output_dim' is -1, it means 'operand_dim' is an elided window dim.
2160 // This means we set the iteration index to 0, so for the purpose of the
2161 // following calculations we can consider the output dimension size to be 1.
2162 int64 output_dim_size =
2163 output_dim == -1 ? 1 : output_shape.dimensions(output_dim);
2164 int64 largest_valid_start_index =
2165 operand_shape.dimensions(operand_dim) - output_dim_size;
2166 CHECK_GE(largest_valid_start_index, 0);
2167
2168 // Clamp the gather index so that the gather region fits in the operand.
2169 // clamped_index =
2170 // clamp(gather_dim_component_extended, 0, largest_valid_start_index);
2171 bool is_signed = ShapeUtil::ElementIsSigned(indices_shape);
2172 auto clamped_index = EmitIntegralMin(
2173 llvm::ConstantInt::get(extended_type, largest_valid_start_index),
2174 EmitIntegralMax(llvm::ConstantInt::get(extended_type, 0),
2175 maybe_extended_index, is_signed),
2176 is_signed);
2177 // Truncate at the end to the optimized index size.
2178 auto maybe_truncated_clamped_index = extended_type != index_type
2179 ? Trunc(clamped_index, index_type)
2180 : clamped_index;
2181
2182 operand_multi_index[operand_dim] =
2183 Add(operand_multi_index[operand_dim], maybe_truncated_clamped_index);
2184 };
2185
2186 if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) {
2187 IrArray::Index gather_index_index(gather_index_index_components,
2188 indices_shape, index_type);
2189 TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
2190 indices_generator(gather_index_index));
2191 add_to_operand_index(gather_dim_component, 0);
2192 } else {
2193 int64 index_vector_size =
2194 indices_shape.dimensions(dim_numbers.index_vector_dim());
2195 for (int64 i = 0; i < index_vector_size; i++) {
2196 gather_index_index_components[dim_numbers.index_vector_dim()] =
2197 index.GetConstantWithIndexType(i);
2198 IrArray::Index gather_index_index(gather_index_index_components,
2199 indices_shape, index_type);
2200 TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
2201 indices_generator(gather_index_index));
2202 add_to_operand_index(gather_dim_component, i);
2203 }
2204 }
2205 IrArray::Index operand_index(operand_multi_index, operand_shape, index_type);
2206 return operand_generator(operand_index);
2207 }
2208
2209 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
2210 const HloInstruction* hlo,
2211 const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
2212 const llvm_ir::IrArray::Index& index) {
2213 const HloInstruction* input_hlo = hlo->operand(0);
2214 const HloInstruction* update_hlo = hlo->operand(1);
2215 const HloInstruction* start_hlo = hlo->operand(2);
2216 // Calculate slice start/end indices.
2217 const int64 rank = input_hlo->shape().rank();
2218 std::vector<llvm::Value*> slice_start_multi_index(rank);
2219 std::vector<llvm::Value*> slice_limit_multi_index(rank);
2220 // slice_intersection ANDs together, across all dimensions, the
2221 // conditions under which 'input' is overwritten by 'update'.
2222 llvm::Value* slice_intersection = b_->getTrue();
2223
2224 for (int64 i = 0; i < rank; ++i) {
2225 llvm::Type* index_type = index[0]->getType();
2226 auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
2227 return llvm::ConstantInt::get(index_type, c);
2228 };
2229
2230 llvm_ir::IrArray::Index zero_index(index_type);
2231 TF_ASSIGN_OR_RETURN(
2232 llvm::Value * start_index_value,
2233 operand_to_generator.at(hlo->operand(2 + i))(zero_index));
2234
2235 // Clamp the start index so that the update region fits in the operand.
2236 // start_index = clamp(start_index, 0, input_dim_size - update_dim_size)
2237 start_index_value = SExtOrTrunc(start_index_value, index_type);
2238 llvm::Value* update_dim_size =
2239 index_typed_const(update_hlo->shape().dimensions(i));
2240 int64 largest_valid_start_index =
2241 input_hlo->shape().dimensions(i) - update_hlo->shape().dimensions(i);
2242 CHECK_GE(largest_valid_start_index, 0);
2243
2244 bool is_signed = ShapeUtil::ElementIsSigned(start_hlo->shape());
2245 start_index_value = EmitIntegralMin(
2246 index_typed_const(largest_valid_start_index),
2247 EmitIntegralMax(index_typed_const(0), start_index_value, is_signed),
2248 is_signed);
2249
2250 start_index_value->setName(IrName(hlo, StrCat("start_idx", i)));
2251 slice_start_multi_index[i] = start_index_value;
2252 slice_limit_multi_index[i] =
2253 Add(slice_start_multi_index[i], update_dim_size);
2254
2255 slice_intersection =
2256 And(slice_intersection, ICmpSGE(index[i], slice_start_multi_index[i]),
2257 "slice_intersection");
2258 slice_intersection =
2259 And(slice_intersection, ICmpSLT(index[i], slice_limit_multi_index[i]),
2260 "slice_intersection");
2261 }
2262
2263 // Emit:
2264 // if (slice_intersection) -> return data from 'update'.
2265 // else -> return data from 'input'.
2266 llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
2267 llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
2268 "ret_value_addr", b_);
2269 llvm_ir::LlvmIfData if_data =
2270 llvm_ir::EmitIfThenElse(slice_intersection, "slice_intersection", b_);
2271
2272 // Handle true BB (return data from 'update')
2273 SetToFirstInsertPoint(if_data.true_block, b_);
2274 // Compute update index for intersection case.
2275 std::vector<llvm::Value*> update_multi_index(rank);
2276 for (int64 i = 0; i < rank; ++i) {
2277 update_multi_index[i] = Sub(index[i], slice_start_multi_index[i]);
2278 }
2279 llvm_ir::IrArray::Index update_index(update_multi_index, update_hlo->shape(),
2280 index.GetType());
2281 TF_ASSIGN_OR_RETURN(llvm::Value * true_value,
2282 operand_to_generator.at(update_hlo)(update_index));
2283 Store(true_value, ret_value_addr);
2284
2285 // Handle false BB (return data from 'input')
2286 SetToFirstInsertPoint(if_data.false_block, b_);
2287 TF_ASSIGN_OR_RETURN(llvm::Value * false_value,
2288 operand_to_generator.at(input_hlo)(index));
2289 Store(false_value, ret_value_addr);
2290
2291 SetToFirstInsertPoint(if_data.after_block, b_);
2292 return Load(ret_value_addr);
2293 }
2294
2295 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalPad(
2296 const HloInstruction* hlo,
2297 const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
2298 const llvm_ir::IrArray::Index& padded_index) {
2299 std::vector<llvm::Value*> multi_index = padded_index.multidim();
2300 llvm::Value* in_bounds = b_->getTrue();
2301 for (size_t i = 0; i < multi_index.size(); ++i) {
2302 auto index_typed_const = [=](int64 n) {
2303 return padded_index.GetConstantWithIndexType(n);
2304 };
2305 const auto& pad_dim = hlo->padding_config().dimensions(i);
2306 multi_index[i] =
2307 Sub(multi_index[i], index_typed_const(pad_dim.edge_padding_low()));
2308 in_bounds = And(in_bounds, ICmpSGE(multi_index[i], index_typed_const(0)),
2309 "in_bounds");
2310 in_bounds =
2311 And(in_bounds,
2312 ICmpEQ(index_typed_const(0),
2313 URem(multi_index[i],
2314 index_typed_const(pad_dim.interior_padding() + 1))),
2315 "in_bounds");
2316 multi_index[i] =
2317 SDiv(multi_index[i], index_typed_const(pad_dim.interior_padding() + 1));
2318 in_bounds =
2319 And(in_bounds,
2320 ICmpSLT(multi_index[i],
2321 index_typed_const(hlo->operand(0)->shape().dimensions(i))),
2322 "in_bounds");
2323 }
2324
2325 // if (in_bounds) {
2326 // ret_value = operand0[index]; // source
2327 // } else {
2328 // ret_value = *operand1; // padding
2329 // }
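// For example, with edge_padding_low = 1 and interior_padding = 2, padded
// index 7 maps to (7 - 1) = 6, 6 % 3 == 0, 6 / 3 = 2, so it reads source
// element 2; padded index 5 maps to 4, 4 % 3 != 0, so it is padding.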
2330 llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
2331 llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
2332 "pad_result_addr", b_);
2333 llvm_ir::LlvmIfData if_data =
2334 llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_);
2335 SetToFirstInsertPoint(if_data.true_block, b_);
2336 llvm_ir::IrArray::Index index(multi_index, hlo->operand(0)->shape(),
2337 padded_index.GetType());
2338 TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
2339 operand_to_generator.at(hlo->operand(0))(index));
2340 Store(operand_value, ret_value_addr);
2341
2342 SetToFirstInsertPoint(if_data.false_block, b_);
2343 TF_ASSIGN_OR_RETURN(llvm::Value * padding_value,
2344 operand_to_generator.at(hlo->operand(1))(
2345 IrArray::Index(index.GetType())));
2346 Store(padding_value, ret_value_addr);
2347
2348 SetToFirstInsertPoint(if_data.after_block, b_);
2349 // Don't create phi(operand_value, padding_value) here, because invoking
2350 // operand_to_generator may create new basic blocks, making the parent
2351 // of operand_value or padding_value no longer a predecessor of
2352 // if_data.after_block.
2353 return Load(ret_value_addr);
2354 }
2355
2356 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
2357 const HloInstruction* hlo,
2358 const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
2359 const llvm_ir::IrArray::Index& dot_result_index) {
2360 auto lhs_generator = operand_to_generator.at(hlo->operand(0));
2361 auto rhs_generator = operand_to_generator.at(hlo->operand(1));
2362
2363 const DotDimensionNumbers& dim_numbers = hlo->dot_dimension_numbers();
2364 int64 lhs_contracting_dim = dim_numbers.lhs_contracting_dimensions(0);
2365 int64 rhs_contracting_dim = dim_numbers.rhs_contracting_dimensions(0);
2366
2367 int64 contracted_dim_size =
2368 hlo->operand(0)->shape().dimensions(lhs_contracting_dim);
2369 int64 lhs_dims = hlo->operand(0)->shape().dimensions_size();
2370 int64 rhs_dims = hlo->operand(1)->shape().dimensions_size();
2371
2372 llvm::Type* index_type = dot_result_index.GetType();
2373 auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
2374 return llvm::ConstantInt::get(index_type, c);
2375 };
2376
2377 std::unique_ptr<llvm_ir::ForLoop> inner_loop = llvm_ir::ForLoop::EmitForLoop(
2378 IrName(hlo, "inner"), index_typed_const(0),
2379 index_typed_const(contracted_dim_size), index_typed_const(1), b_);
2380
2381 SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), b_);
2382 PrimitiveType primitive_type = hlo->shape().element_type();
2383 llvm::Type* primitive_type_llvm =
2384 llvm_ir::PrimitiveTypeToIrType(primitive_type, module_);
2385 llvm::Value* accumulator_alloca =
2386 llvm_ir::EmitAllocaAtFunctionEntry(primitive_type_llvm, "dot_acc", b_);
2387 Store(llvm::Constant::getNullValue(primitive_type_llvm), accumulator_alloca);
2388
2389 SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), b_);
2390
2391 // This is the inner reduction loop for a dot operation that produces
2392 // one element in the output. If the operands to the dot operation have
2393 // shapes [A,B,C,T] and [D,T,E], the result has a shape [A,B,C,D,E].
2394 // Given an output index [a,b,c,d,e] in the result, we compute:
2395 // sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T))
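// For a plain matmul with shapes [M,T] x [T,N] and output index [m,n],
// this builds lhs_index = [m,t] and rhs_index = [t,n] and accumulates
// lhs[m,t] * rhs[t,n] over t.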
2396
2397 std::vector<llvm::Value*> lhs_multi_index, rhs_multi_index;
2398 for (int64 i = 0; i < lhs_dims - 1; i++) {
2399 lhs_multi_index.push_back(dot_result_index[i]);
2400 }
2401 lhs_multi_index.insert(lhs_multi_index.begin() + lhs_contracting_dim,
2402 inner_loop->GetIndVarValue());
2403 IrArray::Index lhs_index(lhs_multi_index, hlo->operand(0)->shape(),
2404 index_type);
2405
2406 int64 num_batch_dims = dim_numbers.rhs_batch_dimensions_size();
2407 for (int64 i = 0; i < num_batch_dims; i++) {
2408 rhs_multi_index.push_back(
2409 dot_result_index[dim_numbers.rhs_batch_dimensions(i)]);
2410 }
2411 for (int64 i = 0; i < rhs_dims - 1 - num_batch_dims; i++) {
2412 rhs_multi_index.push_back(dot_result_index[lhs_dims - 1 + i]);
2413 }
2414 rhs_multi_index.insert(rhs_multi_index.begin() + rhs_contracting_dim,
2415 inner_loop->GetIndVarValue());
2416 IrArray::Index rhs_index(rhs_multi_index, hlo->operand(1)->shape(),
2417 index_type);
2418
2419 llvm::Value* current_accumulator = Load(accumulator_alloca);
2420 TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index));
2421 TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index));
2422 llvm::Value* next_accumulator =
2423 EmitMulAdd(lhs_value, rhs_value, current_accumulator, primitive_type);
2424 Store(next_accumulator, accumulator_alloca);
2425
2426 SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), b_);
2427 return Load(accumulator_alloca);
2428 }
2429
2430 llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
2431 const HloInstruction* hlo,
2432 const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) {
  switch (hlo->opcode()) {
    case HloOpcode::kAbs:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kCeil:
    case HloOpcode::kClz:
    case HloOpcode::kConvert:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCos:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kNegate:
    case HloOpcode::kNot:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kReal:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSqrt:
    case HloOpcode::kCbrt:
    case HloOpcode::kTanh:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
                            operand_to_generator.at(hlo->operand(0))(index));
        return EmitUnaryOp(hlo, operand_value);
      };
    case HloOpcode::kAdd:
    case HloOpcode::kAnd:
    case HloOpcode::kAtan2:
    case HloOpcode::kCompare:
    case HloOpcode::kComplex:
    case HloOpcode::kDivide:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kPower:
    case HloOpcode::kRemainder:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kSubtract:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        const HloInstruction* lhs = hlo->operand(0);
        const HloInstruction* rhs = hlo->operand(1);
        TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value,
                            operand_to_generator.at(lhs)(index));
        TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value,
                            operand_to_generator.at(rhs)(index));
        return EmitBinaryOp(hlo, lhs_value, rhs_value);
      };
    case HloOpcode::kSelect:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        return EmitElementalSelect(hlo, operand_to_generator, index);
      };
    case HloOpcode::kClamp:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        return EmitElementalClamp(hlo, operand_to_generator, index);
      };
    case HloOpcode::kReducePrecision:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
                            operand_to_generator.at(hlo->operand(0))(index));
        return EmitReducePrecision(hlo, operand_value);
      };
    case HloOpcode::kConcatenate:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
        return EmitElementalConcatenate(hlo, operand_to_generator,
                                        target_index);
      };
    case HloOpcode::kReverse:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
        const HloInstruction* operand = hlo->operand(0);
        std::vector<llvm::Value*> source_multi_index = target_index.multidim();
        for (int64 dim : hlo->dimensions()) {
          source_multi_index[dim] = Sub(target_index.GetConstantWithIndexType(
                                            hlo->shape().dimensions(dim) - 1),
                                        target_index[dim]);
        }
        llvm_ir::IrArray::Index source_index(
            source_multi_index, operand->shape(), target_index.GetType());
        return operand_to_generator.at(operand)(source_index);
      };
    case HloOpcode::kBroadcast:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
        const HloInstruction* operand = hlo->operand(0);
        // The `dimensions` member of the broadcast instruction maps from
        // input dimensions to output dimensions.
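        // For example, broadcasting an operand of shape [3] to an output of
        // shape [2, 3] with dimensions={1} maps operand dimension 0 to output
        // dimension 1, so output element (i, j) reads operand element (j).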
        return operand_to_generator.at(operand)(
            target_index.SourceIndexOfBroadcast(hlo->shape(), operand->shape(),
                                                hlo->dimensions(), b_));
      };
    case HloOpcode::kIota:
      return [this, hlo](
                 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
        auto* iota = Cast<HloIotaInstruction>(hlo);
        PrimitiveType element_type = iota->shape().element_type();
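        // An iota of rank > 1 is equivalent to a rank-1 iota of length
        // dimensions(iota_dimension) broadcast along iota_dimension, so the
        // broadcast source-index mapping below recovers the 1-D element index.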
        IrArray::Index elem_index =
            iota->shape().rank() > 1
                ? target_index.SourceIndexOfBroadcast(
                      iota->shape(),
                      ShapeUtil::MakeShapeWithDescendingLayout(
                          element_type,
                          {iota->shape().dimensions(iota->iota_dimension())}),
                      {iota->iota_dimension()}, b_)
                : target_index;
        llvm::Value* elem_index_linear = elem_index.linear();
        if (elem_index_linear == nullptr) {
          std::vector<int64> iota_bound = {
              iota->shape().dimensions(iota->iota_dimension())};
          elem_index_linear = elem_index.Linearize(iota_bound, b_);
        }
        Shape component_shape =
            ShapeUtil::ElementIsComplex(iota->shape())
                ? ShapeUtil::ComplexComponentShape(iota->shape())
                : iota->shape();
        PrimitiveType component_element_type = component_shape.element_type();
        llvm::Value* iota_result;
        if (primitive_util::IsIntegralType(component_element_type)) {
          iota_result = b_->CreateIntCast(
              elem_index_linear,
              llvm_ir::PrimitiveTypeToIrType(component_element_type, module_),
              /*isSigned=*/false);
        } else {
          TF_RET_CHECK(
              primitive_util::IsFloatingPointType(component_element_type))
              << component_element_type;
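          // BF16 is not handled as a native arithmetic type here, so the
          // value is computed in F32 and then truncated to BF16 below.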
          llvm::Type* float_ir_type;
          if (component_element_type == BF16) {
            float_ir_type = llvm_ir::PrimitiveTypeToIrType(F32, module_);
          } else {
            float_ir_type =
                llvm_ir::PrimitiveTypeToIrType(component_element_type, module_);
          }
          llvm::Value* float_val =
              b_->CreateUIToFP(elem_index_linear, float_ir_type);
          if (component_element_type == BF16) {
            TF_ASSIGN_OR_RETURN(iota_result, EmitF32ToBF16(float_val, b_));
          } else {
            iota_result = float_val;
          }
        }
        if (ShapeUtil::ElementIsComplex(iota->shape())) {
          return EmitComposeComplex(iota, iota_result, nullptr);
        } else {
          return iota_result;
        }
      };
    case HloOpcode::kSlice:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        IrArray::Index sliced_index = index.SourceIndexOfSlice(
            /*operand_shape=*/hlo->operand(0)->shape(),
            /*starts=*/hlo->slice_starts(),
            /*strides=*/hlo->slice_strides(), /*builder=*/b_);
        return operand_to_generator.at(hlo->operand(0))(sliced_index);
      };
    case HloOpcode::kDynamicSlice:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        return EmitElementalDynamicSlice(hlo, operand_to_generator, index);
      };

    case HloOpcode::kGather:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        return EmitElementalGather(hlo, operand_to_generator, index);
      };
    case HloOpcode::kDynamicUpdateSlice:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        return EmitElementalDynamicUpdateSlice(hlo, operand_to_generator,
                                               index);
      };
    case HloOpcode::kBitcast:
      CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()),
               ShapeUtil::ElementsIn(hlo->operand(0)->shape()));
      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
        const HloInstruction* operand = hlo->operand(0);
        return operand_to_generator.at(operand)(
            GetSourceIndexOfBitcast(index, hlo));
      };
    case HloOpcode::kReshape:
      CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()),
               ShapeUtil::ElementsIn(hlo->operand(0)->shape()));
      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
        const HloInstruction* operand = hlo->operand(0);
        return operand_to_generator.at(operand)(
            index.SourceIndexOfReshape(hlo->shape(), operand->shape(), b_));
      };
    case HloOpcode::kCopy:
      return [hlo, &operand_to_generator](
                 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
        IrArray::Index source_index(target_index.multidim(),
                                    hlo->operand(0)->shape(),
                                    target_index.GetType());
        TF_ASSIGN_OR_RETURN(
            llvm::Value * operand_value,
            operand_to_generator.at(hlo->operand(0))(source_index));
        return operand_value;
      };
    case HloOpcode::kTranspose:
      return [this, hlo,
              &operand_to_generator](const IrArray::Index& target_index) {
        return operand_to_generator.at(hlo->operand(0))(
            target_index.SourceIndexOfTranspose(
                hlo->shape(), hlo->operand(0)->shape(), hlo->dimensions()));
      };
    case HloOpcode::kPad:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& padded_index) -> StatusOr<llvm::Value*> {
        return EmitElementalPad(hlo, operand_to_generator, padded_index);
      };

    case HloOpcode::kDot:
      return [this, hlo,
              &operand_to_generator](const IrArray::Index& dot_result_index)
                 -> StatusOr<llvm::Value*> {
        return EmitElementalDot(hlo, operand_to_generator, dot_result_index);
      };
    case HloOpcode::kMap:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        std::vector<llvm::Value*> operands;
        for (int i = 0; i < hlo->operand_count(); i++) {
          TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
                              operand_to_generator.at(hlo->operand(i))(index));
          operands.push_back(operand_value);
        }
        return EmitElementalMap(Cast<HloMapInstruction>(hlo), operands);
      };
    case HloOpcode::kReduceWindow:
      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
        auto reduce_window_instr = Cast<HloReduceWindowInstruction>(hlo);
        std::vector<llvm_ir::ElementGenerator> input_generators;
        for (const HloInstruction* instr : reduce_window_instr->inputs()) {
          input_generators.push_back(operand_to_generator.at(instr));
        }

        std::vector<llvm_ir::ElementGenerator> initial_value_generators;
        for (const HloInstruction* instr : reduce_window_instr->init_values()) {
          initial_value_generators.push_back(operand_to_generator.at(instr));
        }
        return EmitElementalReduceWindow(
            reduce_window_instr, std::move(input_generators),
            std::move(initial_value_generators), index);
      };
    case HloOpcode::kReduce:
      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
        auto reduce_instr = Cast<HloReduceInstruction>(hlo);
        std::vector<llvm_ir::ElementGenerator> input_generators;
        for (const HloInstruction* instr : reduce_instr->inputs()) {
          input_generators.push_back(operand_to_generator.at(instr));
        }

        std::vector<llvm_ir::ElementGenerator> initial_value_generators;
        for (const HloInstruction* instr : reduce_instr->init_values()) {
          initial_value_generators.push_back(operand_to_generator.at(instr));
        }
        return EmitElementalReduce(reduce_instr, std::move(input_generators),
                                   std::move(initial_value_generators), index);
      };
    case HloOpcode::kConvolution:
      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
        return EmitConvolution(hlo, operand_to_generator, index);
      };
    default:
      return [hlo](const IrArray::Index& index) {
        return Unimplemented("Unhandled opcode for elemental IR emission: %s",
                             HloOpcodeString(hlo->opcode()));
      };
  }
}

llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) {
  return ExtractValue(value, {0});
}

llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) {
  return ExtractValue(value, {1});
}

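// Complex values are represented as {real, imag} aggregates in LLVM IR, so
// composing one inserts the real part at element 0 and, when an imaginary
// part is supplied, the imaginary part at element 1; a null `imag` leaves
// element 1 zero.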
llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op,
                                                    llvm::Value* real,
                                                    llvm::Value* imag) {
  auto cplx_type =
      llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
  auto complex =
      InsertValue(llvm::ConstantAggregateZero::get(cplx_type), real, {0});
  if (imag != nullptr) {
    complex = InsertValue(complex, imag, {1});
  }
  return complex;
}

llvm::Value* ElementalIrEmitter::EmitMulAdd(llvm::Value* lhs, llvm::Value* rhs,
                                            llvm::Value* accumulator,
                                            xla::PrimitiveType primitive_type) {
  if (primitive_util::IsComplexType(primitive_type)) {
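    // The complex product expands by the identity
    // (a + bi)(c + di) = (ac - bd) + (ad + bc)i; each component is then
    // accumulated into the matching component of `accumulator`.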
    llvm::Value* product_real =
        FSub(FMul(EmitExtractReal(lhs), EmitExtractReal(rhs)),
             FMul(EmitExtractImag(lhs), EmitExtractImag(rhs)));
    llvm::Value* product_imag =
        FAdd(FMul(EmitExtractReal(lhs), EmitExtractImag(rhs)),
             FMul(EmitExtractImag(lhs), EmitExtractReal(rhs)));
    llvm::Value* next_accumulator = InsertValue(
        accumulator, FAdd(EmitExtractReal(accumulator), product_real), {0});
    return InsertValue(next_accumulator,
                       FAdd(EmitExtractImag(accumulator), product_imag), {1});
  } else if (primitive_util::IsFloatingPointType(primitive_type)) {
    return FAdd(accumulator, FPCast(FMul(lhs, rhs), accumulator->getType()));
  } else if (primitive_type == PRED) {
    return Or(accumulator, And(lhs, rhs));
  }
  return Add(accumulator, Mul(lhs, rhs));
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalMap(
    const HloMapInstruction* map_instr,
    absl::Span<llvm::Value* const> elemental_operands) {
  TF_ASSIGN_OR_RETURN(
      std::vector<llvm::Value*> values,
      EmitThreadLocalCall(*map_instr->to_apply(), elemental_operands,
                          llvm_ir::IrName(map_instr)));
  CHECK_EQ(values.size(), 1);
  return values[0];
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduceWindow(
    const HloReduceWindowInstruction* reduce_window,
    std::vector<llvm_ir::ElementGenerator> input_generators,
    std::vector<llvm_ir::ElementGenerator> initial_value_generators,
    const llvm_ir::IrArray::Index& index) {
  // Pseudocode:
  // for each output index O
  //   value = init_value
  //   for each window index W
  //     for each dimension i from 0 to rank - 1
  //       (input index I)[i] = O[i] * stride[i] + W[i] - pad_low[i]
  //     if I is in bounds of input
  //       value = function(value, input[I])
  //   output[O] = value
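  // For example, with stride[i] = 2, pad_low[i] = 1, output index O[i] = 2,
  // and window index W[i] = 1, the input index is I[i] = 2 * 2 + 1 - 1 = 4.
  // (The code below additionally applies window and base dilation, which the
  // pseudocode above omits for brevity.)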
  int64 input_count = reduce_window->input_count();
  std::vector<PrimitiveType> operand_element_types;
  std::vector<llvm::Type*> accum_types;
  std::vector<llvm::Value*> accum_ptrs;
  for (int64 operand_index = 0; operand_index < input_count; ++operand_index) {
    auto operand = reduce_window->inputs()[operand_index];
    PrimitiveType operand_element_type = operand->shape().element_type();
    operand_element_types.push_back(operand_element_type);
    llvm::Type* llvm_type =
        llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_);
    accum_types.push_back(llvm_type);
    llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry(
        llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
        "reduce_window_accum_ptr", b_);
    accum_ptrs.push_back(accum_ptr);
    {
      auto initial_value_generator = initial_value_generators[operand_index];
      TF_ASSIGN_OR_RETURN(
          llvm::Value* const init_value,
          initial_value_generator(llvm_ir::IrArray::Index(index.GetType())));
      Store(init_value, accum_ptr);
    }
  }

  llvm::Type* index_type = index.GetType();
  auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
    return index.GetConstantWithIndexType(c);
  };

  const Window& window = reduce_window->window();
  llvm_ir::ForLoopNest loops(IrName(reduce_window), b_, index_type);
  std::vector<int64> window_size;
  for (const auto& dim : window.dimensions()) {
    window_size.push_back(dim.size());
  }
  const IrArray::Index window_index = loops.AddLoopsForShape(
      ShapeUtil::MakeShape(operand_element_types[0], window_size), "window");
  CHECK_EQ(window_index.size(), index.size());

  SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_);

  std::vector<llvm::Value*> input_multi_index(index.size());
  llvm::Value* in_bounds = b_->getInt1(true);
  for (size_t i = 0; i < index.size(); ++i) {
    llvm::Value* strided_index =
        NSWMul(index[i], index_typed_const(window.dimensions(i).stride()));
    input_multi_index[i] = NSWSub(
        NSWAdd(
            strided_index,
            NSWMul(window_index[i],
                   index_typed_const(window.dimensions(i).window_dilation()))),
        index_typed_const(window.dimensions(i).padding_low()));

    // Verify that the computed index does not fall into one of the holes
    // created by base dilation.
    llvm::Value* dilation_condition =
        ICmpEQ(SRem(input_multi_index[i],
                    index_typed_const(window.dimensions(i).base_dilation())),
               index_typed_const(0));
    in_bounds = And(in_bounds, dilation_condition);

    // Map the dilated index back to an index into the actual input data.
    input_multi_index[i] =
        SDiv(input_multi_index[i],
             index_typed_const(window.dimensions(i).base_dilation()));
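    // For example, with base_dilation 2, dilated positions 0, 2, 4, ... hold
    // the actual input elements 0, 1, 2, ...; odd positions are holes and were
    // already rejected by the divisibility check above.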

    // We must check whether 0 <= input_multi_index[i] < bound, as
    // otherwise we are in the pad and so can skip the computation. This
    // comparison is equivalent to the unsigned comparison
    // input_multi_index[i] < bound, as a negative value wraps to a large
    // positive value.
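    // For example, with a 64-bit index type and bound 10, an index of -1
    // wraps to 2^64 - 1 and fails the unsigned comparison against 10, just as
    // any index >= 10 does.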
    in_bounds =
        And(in_bounds,
            ICmpULT(input_multi_index[i],
                    index_typed_const(
                        reduce_window->inputs()[0]->shape().dimensions(i))));
  }

  llvm_ir::LlvmIfData if_data =
      llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_);
  SetToFirstInsertPoint(if_data.true_block, b_);

  // We are not in the pad, so do the computation.
  std::vector<llvm::Value*> input_values(reduce_window->operand_count());
  IrArray::Index input_index(input_multi_index,
                             reduce_window->inputs()[0]->shape(), index_type);
  for (int64 operand_idx = 0; operand_idx < input_count; ++operand_idx) {
    TF_ASSIGN_OR_RETURN(llvm::Value * input_value,
                        input_generators[operand_idx](input_index));
    input_values[input_count + operand_idx] = input_value;
    input_values[operand_idx] = Load(accum_ptrs[operand_idx]);
  }
  TF_ASSIGN_OR_RETURN(std::vector<llvm::Value*> accum_values,
                      EmitThreadLocalCall(*reduce_window->to_apply(),
                                          input_values, "reducer_function"));

  for (int64 operand_idx = 0; operand_idx < accum_values.size();
       ++operand_idx) {
    Store(accum_values[operand_idx], accum_ptrs[operand_idx]);
  }

  SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_);
  return EmitAccumResult(accum_ptrs, accum_types,
                         reduce_window->shape().IsTuple());
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduce(
    const HloReduceInstruction* reduce,
    std::vector<llvm_ir::ElementGenerator> input_generators,
    std::vector<llvm_ir::ElementGenerator> initial_value_generators,
    const llvm_ir::IrArray::Index& index) {
  const Shape& out_shape = reduce->shape();
  bool is_variadic = !out_shape.IsArray();
  int accumulators_count = 1;
  if (is_variadic) {
    CHECK(out_shape.IsTuple());
    accumulators_count = out_shape.tuple_shapes_size();
  }

  absl::Span<const int64> reduced_dimensions(reduce->dimensions());

  std::vector<llvm::Value*> accumulator_addrs;
  std::vector<llvm::Type*> accumulator_types;
  llvm::Type* index_type = index.GetType();
  for (int i = 0; i < accumulators_count; i++) {
    const Shape& element_shape =
        is_variadic ? out_shape.tuple_shapes(i) : out_shape;
    PrimitiveType accumulator_type = element_shape.element_type();
    llvm::Type* accumulator_llvm_type =
        llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_);
    accumulator_types.push_back(accumulator_llvm_type);

    // Initialize an accumulator with init_value.
    llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
        accumulator_llvm_type, "accumulator_" + std::to_string(i), b());
    TF_ASSIGN_OR_RETURN(
        llvm::Value* const init_value,
        initial_value_generators[i](llvm_ir::IrArray::Index(index_type)));
    Store(init_value, accumulator_addr);
    accumulator_addrs.push_back(accumulator_addr);
  }

  // The enclosing loops go over all the target elements. Now we have to
  // compute the actual target element. For this, we build a new loop nest to
  // iterate over all the reduction dimensions in the argument.
  // AddLoopsForShapeOnDimensions will return an Index where induction Value*s
  // are placed for each dimension in dimensions, and all the rest are
  // nullptrs.
  llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), b(), index_type);
  const HloInstruction* arg = reduce->operand(0);
  std::vector<llvm::Value*> input_multi_index =
      loops.AddLoopsForShapeOnDimensions(arg->shape(), reduced_dimensions,
                                         "reduction_dim");

  SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b());

  // Build a full index for the input argument, using input_multi_index as the
  // base. In input_multi_index only the reduction dimensions are filled in. We
  // fill in the rest of the dimensions with induction Value*s taken from
  // 'index' which iterates over the target array. See the high-level
  // description in the XLA documentation for details.
  auto it = index.begin();

  for (auto& i : input_multi_index) {
    if (i == nullptr) {
      i = *it++;
    }
  }
  CHECK(index.end() == it);
  llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
                                      index_type);

  std::vector<llvm::Value*> reduction_operands;
  for (llvm::Value* accum : accumulator_addrs) {
    llvm::Value* accum_value = Load(accum);
    reduction_operands.push_back(accum_value);
  }

  for (int i = 0; i < accumulators_count; i++) {
    TF_ASSIGN_OR_RETURN(llvm::Value* const input_element,
                        input_generators[i](input_index));
    reduction_operands.push_back(input_element);
  }

  TF_ASSIGN_OR_RETURN(
      std::vector<llvm::Value*> results,
      EmitThreadLocalCall(*reduce->to_apply(), reduction_operands,
                          "reduce_function"));

  CHECK(results.size() == accumulators_count);
  for (int i = 0; i < accumulators_count; i++) {
    Store(results[i], accumulator_addrs[i]);
  }
  SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b());
  return EmitAccumResult(accumulator_addrs, accumulator_types, is_variadic);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitAccumResult(
    absl::Span<llvm::Value* const> accumulator_addrs,
    llvm::ArrayRef<llvm::Type*> accumulator_types, bool is_variadic) {
  TF_RET_CHECK(accumulator_addrs.size() == accumulator_types.size());
  if (is_variadic) {
    // Emit a structure, as that is what the LoopEmitter expects.
    llvm::Value* returned_structure = llvm::UndefValue::get(
        llvm::StructType::get(b()->getContext(), accumulator_types));
    for (int64 i = 0; i < accumulator_addrs.size(); i++) {
      llvm::Value* accumulator_value = Load(accumulator_addrs[i]);
      returned_structure =
          b()->CreateInsertValue(returned_structure, accumulator_value, i);
    }
    return returned_structure;
  } else {
    CHECK_EQ(accumulator_addrs.size(), 1);
    return Load(accumulator_addrs[0]);
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitConvolution(
    const HloInstruction* convolution,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& index) {
  const HloInstruction* lhs = convolution->operand(0);
  const auto& input_generator = operand_to_generator.at(lhs);
  const HloInstruction* rhs = convolution->operand(1);
  const auto& kernel_generator = operand_to_generator.at(rhs);
  const Window& window = convolution->window();

  const ConvolutionDimensionNumbers& dnums =
      convolution->convolution_dimension_numbers();
  int num_spatial_dims = dnums.output_spatial_dimensions_size();
  std::vector<llvm::Value*> output_spatial(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    output_spatial[i] = index[dnums.output_spatial_dimensions(i)];
  }
  llvm::Value* output_feature = index[dnums.output_feature_dimension()];
  llvm::Value* batch = index[dnums.output_batch_dimension()];

  // We will accumulate the products into this sum to calculate the output
  // entry at the given index.
  PrimitiveType lhs_element_type = lhs->shape().element_type();
  llvm::Type* lhs_llvm_type =
      llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_);
  // Upcast the accumulator to F32 from F16 for increased precision.
  llvm::Type* accumulator_type =
      lhs_element_type == F16 ? b_->getFloatTy() : lhs_llvm_type;
  llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
      accumulator_type, "convolution_sum_address", b_);
  llvm::Value* constant_zero = llvm::Constant::getNullValue(accumulator_type);
  Store(constant_zero, sum_address);

  llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), b_);
  std::vector<llvm::Value*> kernel_spatial(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    kernel_spatial[i] =
        loops
            .AddLoop(
                0, rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)),
                absl::StrCat("k", i))
            ->GetIndVarValue();
  }
  llvm::Value* input_feature =
      loops
          .AddLoop(0, lhs->shape().dimensions(dnums.input_feature_dimension()),
                   "iz")
          ->GetIndVarValue();

  SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_);

  // Calculate the spatial index in the input array, taking striding, dilation
  // and padding into account. An index in the padding will be out of the
  // bounds of the array.
  const auto calculate_input_index = [this](llvm::Value* output_index,
                                            llvm::Value* kernel_index,
                                            const WindowDimension& window_dim) {
    llvm::Value* strided_index =
        NSWMul(output_index, b_->getInt64(window_dim.stride()));
    llvm::Value* dilated_kernel_index =
        NSWMul(kernel_index, b_->getInt64(window_dim.window_dilation()));
    return NSWSub(NSWAdd(strided_index, dilated_kernel_index),
                  b_->getInt64(window_dim.padding_low()));
  };
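  // For example, with stride 2, window dilation 2, and pad_low 1, output
  // index 3 and kernel index 1 map to input index 3 * 2 + 1 * 2 - 1 = 7.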
  std::vector<llvm::Value*> input_spatial(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    input_spatial[i] = calculate_input_index(
        output_spatial[i], kernel_spatial[i], window.dimensions(i));
  }

  // We need to check that 0 <= input dim < bound; otherwise we are in the
  // padding and can skip the computation. That is equivalent to the
  // *unsigned* comparison input dim < bound, since a negative value wraps
  // to a large positive value. The input dim is dilated, so the bound must
  // be dilated as well to match.

  // We also need to check that the input coordinates are not in one of the
  // holes created by base dilation.
  const auto not_in_hole = [&](llvm::Value* input_index, int64 base_dilation) {
    llvm::Value* remainder = SRem(input_index, b_->getInt64(base_dilation));
    return ICmpEQ(remainder, b_->getInt64(0));
  };
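  // For example, with base_dilation 3, real data sits at dilated positions
  // 0, 3, 6, ...; position 4 has remainder 1 and therefore lies in a hole.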
  llvm::Value* in_bounds_condition = b_->getInt1(true);
  for (int i = 0; i < num_spatial_dims; ++i) {
    llvm::ConstantInt* input_bound = b_->getInt64(window_util::DilatedBound(
        lhs->shape().dimensions(dnums.input_spatial_dimensions(i)),
        window.dimensions(i).base_dilation()));
    llvm::Value* dim_in_bound = ICmpULT(input_spatial[i], input_bound);
    llvm::Value* dim_not_in_hole =
        not_in_hole(input_spatial[i], window.dimensions(i).base_dilation());
    llvm::Value* dim_ok = And(dim_in_bound, dim_not_in_hole);
    in_bounds_condition = And(in_bounds_condition, dim_ok);
  }

  // Now we need to map the dilated base coordinates back to the actual
  // data indices on the lhs.
  const auto undilate = [&](llvm::Value* input_index, int64 base_dilation) {
    return SDiv(input_index, b_->getInt64(base_dilation));
  };
  for (int i = 0; i < num_spatial_dims; ++i) {
    input_spatial[i] =
        undilate(input_spatial[i], window.dimensions(i).base_dilation());
  }

  llvm_ir::LlvmIfData if_data =
      llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", b_);
  SetToFirstInsertPoint(if_data.true_block, b_);

  // We are not in the padding, so carry out the computation.
  int num_dims = num_spatial_dims + 2;
  std::vector<llvm::Value*> input_multi_index(num_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    input_multi_index[dnums.input_spatial_dimensions(i)] = input_spatial[i];
  }
  input_multi_index[dnums.input_feature_dimension()] = input_feature;
  input_multi_index[dnums.input_batch_dimension()] = batch;

  std::vector<llvm::Value*> kernel_multi_index(num_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    kernel_multi_index[dnums.kernel_spatial_dimensions(i)] =
        window.dimensions(i).window_reversal()
            ? NSWSub(b_->getInt64(window.dimensions(i).size() - 1),
                     kernel_spatial[i])
            : kernel_spatial[i];
  }

  kernel_multi_index[dnums.kernel_input_feature_dimension()] = input_feature;
  kernel_multi_index[dnums.kernel_output_feature_dimension()] = output_feature;

  llvm_ir::IrArray::Index input_index(input_multi_index, lhs->shape(),
                                      b_->getInt64Ty());
  TF_ASSIGN_OR_RETURN(llvm::Value* const input_value,
                      input_generator(input_index));
  llvm_ir::IrArray::Index kernel_index(kernel_multi_index, rhs->shape(),
                                       b_->getInt64Ty());
  TF_ASSIGN_OR_RETURN(llvm::Value* const kernel_value,
                      kernel_generator(kernel_index));
  llvm::Value* sum = EmitMulAdd(input_value, kernel_value, Load(sum_address),
                                convolution->shape().element_type());
  Store(sum, sum_address);

  SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_);
  return FPCast(Load(sum_address), lhs_llvm_type);
}

// Evaluate polynomial using Horner's method.
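// Coefficients are ordered from the highest-degree term down to the constant
// term; e.g. {2, 3, 5} evaluates 2x^2 + 3x + 5 as ((0*x + 2)*x + 3)*x + 5.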
StatusOr<llvm::Value*> ElementalIrEmitter::EvaluatePolynomial(
    llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients) {
  llvm::Value* poly = llvm::ConstantFP::get(type, 0.0);
  for (const double c : coefficients) {
    poly = FAdd(FMul(poly, x), llvm::ConstantFP::get(type, c));
  }
  return poly;
}

}  // namespace xla