/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"

#include <algorithm>
#include <array>
#include <memory>
#include <random>
#include <string>
#include <vector>

// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

using absl::StrCat;
using llvm_ir::IrArray;
using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;

namespace {

int64 GlobalRandomValue() {
  static auto* mu = new tensorflow::mutex();
  static std::mt19937_64 rng{42};
  tensorflow::mutex_lock l(*mu);
  return rng();
}

StatusOr<llvm::Value*> EmitReducePrecisionIR(PrimitiveType src_ty,
                                             llvm::Value* x,
                                             int64 dest_exponent_bits,
                                             int64 dest_mantissa_bits,
                                             llvm::IRBuilder<>* b) {
  using llvm::APInt;

  if (!primitive_util::IsFloatingPointType(src_ty)) {
    return Unimplemented(
        "ReducePrecision cannot accept non-floating-point type %s.",
        PrimitiveType_Name(src_ty));
  }

  // Integer and float types for casting and constant generation.
  llvm::Type* float_type = x->getType();
  int64 nbits = float_type->getPrimitiveSizeInBits();
  llvm::IntegerType* int_type = b->getIntNTy(nbits);

  // SignificandWidth includes the implicit extra bit.
  int src_mantissa_bits = primitive_util::SignificandWidth(src_ty) - 1;
  int src_exponent_bits = nbits - 1 - src_mantissa_bits;

  // Cast the input value to an integer for bitwise manipulation.
  llvm::Value* x_as_int = b->CreateBitCast(x, int_type);

  if (dest_mantissa_bits < src_mantissa_bits) {
    // Last remaining mantissa bit.
    APInt last_mantissa_bit_mask(nbits, 1);
    last_mantissa_bit_mask <<= src_mantissa_bits - dest_mantissa_bits;

    // Compute rounding bias for round-to-nearest with ties to even. This is
    // equal to a base value of 0111... plus one bit if the last remaining
    // mantissa bit is 1.
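    //
    // For example, rounding f32 down to the 7 mantissa bits of bf16 drops 16
    // bits, so:
    //   last_mantissa_bit_mask = 0x00010000
    //   base_rounding_bias     = 0x00007FFF
    // The bias becomes 0x8000 exactly when the last kept bit is 1, which
    // breaks ties toward the value whose last kept bit is 0 (round-to-even).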
    APInt base_rounding_bias = last_mantissa_bit_mask.lshr(1) - 1;
    llvm::Value* x_last_mantissa_bit = b->CreateLShr(
        b->CreateAnd(x_as_int,
                     llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)),
        (src_mantissa_bits - dest_mantissa_bits));
    llvm::Value* x_rounding_bias =
        b->CreateAdd(x_last_mantissa_bit,
                     llvm::ConstantInt::get(int_type, base_rounding_bias));

    // Add rounding bias, and mask out truncated bits. Note that the case
    // where adding the rounding bias overflows into the exponent bits is
    // correct; the non-masked mantissa bits will all be zero, and the
    // exponent will be incremented by one.
    APInt truncation_mask = ~(last_mantissa_bit_mask - 1);
    x_as_int = b->CreateAdd(x_as_int, x_rounding_bias);
    x_as_int = b->CreateAnd(x_as_int,
                            llvm::ConstantInt::get(int_type, truncation_mask));
  }

  if (dest_exponent_bits < src_exponent_bits) {
    APInt sign_bit_mask(nbits, 1);
    sign_bit_mask <<= nbits - 1;

    APInt exp_bits_mask(nbits, 1);
    exp_bits_mask = ((exp_bits_mask << src_exponent_bits) - 1)
                    << src_mantissa_bits;

    // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most-
    // significant bit -- is equal to 1.0f for all exponent sizes. Adding
    // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a
    // bit-size of n, and subtracting 2^(n-1)-1 from this gives us the lowest
    // exponent (corresponding to 0.0f).
    //
    // Thus, the f32 exponent corresponding to the highest non-infinite
    // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32
    // exponent corresponding to the lowest exponent for a bit size of n is
    // (2^7-1) - 2^(n-1)-1.
    //
    // Note that we have already checked that exponent_bits >= 1.
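    //
    // For example, reducing f32 (src_exponent_bits = 8, bias 2^7-1 = 127) to
    // a 5-bit exponent gives reduced_exponent_bias = 2^4-1 = 15, so biased
    // exponents above 127 + 15 = 142 overflow, and biased exponents at or
    // below 127 - 15 = 112 underflow.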
    APInt exponent_bias(nbits, 1);
    exponent_bias = (exponent_bias << (src_exponent_bits - 1)) - 1;

    APInt reduced_exponent_bias(nbits, 1);
    reduced_exponent_bias =
        (reduced_exponent_bias << (dest_exponent_bits - 1)) - 1;

    APInt reduced_max_exponent = exponent_bias + reduced_exponent_bias;
    APInt reduced_min_exponent = exponent_bias - reduced_exponent_bias;

    // Do we overflow or underflow?
    llvm::Value* x_exponent =
        b->CreateAnd(x_as_int, llvm::ConstantInt::get(int_type, exp_bits_mask));
    llvm::Value* x_overflows = b->CreateICmpUGT(
        x_exponent, llvm::ConstantInt::get(
                        int_type, reduced_max_exponent << src_mantissa_bits));
    llvm::Value* x_underflows = b->CreateICmpULE(
        x_exponent, llvm::ConstantInt::get(
                        int_type, reduced_min_exponent << src_mantissa_bits));

    // Compute appropriately-signed values of zero and infinity.
    llvm::Value* x_signed_zero =
        b->CreateAnd(x_as_int, llvm::ConstantInt::get(int_type, sign_bit_mask));
    llvm::Value* x_signed_inf = b->CreateOr(
        x_signed_zero, llvm::ConstantInt::get(int_type, exp_bits_mask));

    // Force to zero or infinity if overflow or underflow. (Note that this
    // truncates all denormal values to zero, rather than rounding them.)
    x_as_int = b->CreateSelect(x_overflows, x_signed_inf, x_as_int);
    x_as_int = b->CreateSelect(x_underflows, x_signed_zero, x_as_int);
  }

  // Cast the result back to a floating-point type.
  llvm::Value* result = b->CreateBitCast(x_as_int, float_type);

  // Correct result for NaN inputs.
  //
  // The exponent handling will "normalize" NaN values to infinities, which is
  // undesirable (except in the case with no mantissa bits, in which case it
  // is mandatory). This logic also handles cases where mantissa-rounding
  // causes a NaN's mantissa to overflow into the exponent bits, which would
  // otherwise create an erroneous zero value.
  //
  // If the fast-math flags are set to assume no NaNs, the comparison is likely
  // to be optimized away, so there's no point in even emitting it.
  if (!b->getFastMathFlags().noNaNs()) {
    llvm::Value* x_is_nan = b->CreateFCmpUNO(x, x);

    if (dest_mantissa_bits > 0) {
      result = b->CreateSelect(x_is_nan, x, result);
    } else {
      result = b->CreateSelect(
          x_is_nan, llvm::ConstantFP::getInfinity(float_type), result);
    }
  }
  return result;
}

StatusOr<llvm::Value*> EmitF32ToBF16(llvm::Value* f32_value,
                                     llvm::IRBuilder<>* b) {
  TF_ASSIGN_OR_RETURN(
      auto reduced_precision,
      EmitReducePrecisionIR(
          /*src_ty=*/F32, f32_value,
          /*dest_exponent_bits=*/primitive_util::ExponentWidth(BF16),
          /*dest_mantissa_bits=*/primitive_util::SignificandWidth(BF16) - 1,
          b));
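  // A bf16 is the high 16 bits of an f32 (1 sign bit, 8 exponent bits, and 7
  // mantissa bits), so after rounding we can simply shift the f32 bits right
  // by 16 and truncate to an i16.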
  auto as_int32 = b->CreateBitCast(reduced_precision, b->getInt32Ty());
  auto shifted = b->CreateLShr(as_int32, 16);
  auto truncated = b->CreateTrunc(shifted, b->getInt16Ty());
  return b->CreateBitCast(truncated, b->getInt16Ty());
}

llvm::Value* EmitBF16ToF32(llvm::Value* bf16_value, llvm::IRBuilder<>* b) {
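  // Widening bf16 -> f32 is exact: placing the bf16 bits in the high half of
  // an i32 (with the 16 low mantissa bits zeroed) yields the bit pattern of
  // the equal f32 value, so no rounding logic is needed.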
  auto as_int16 = b->CreateBitCast(bf16_value, b->getInt16Ty());
  auto as_int32 = b->CreateZExt(as_int16, b->getInt32Ty());
  auto shifted = b->CreateShl(as_int32, 16);
  return b->CreateBitCast(shifted, b->getFloatTy());
}

llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value,
                                    PrimitiveType from_type,
                                    PrimitiveType to_type, llvm::Module* module,
                                    llvm::IRBuilder<>* b) {
  if (primitive_util::IsSignedIntegralType(from_type)) {
    return b->CreateSIToFP(integer_value,
                           llvm_ir::PrimitiveTypeToIrType(to_type, module));
  } else {
    CHECK(primitive_util::IsUnsignedIntegralType(from_type) ||
          from_type == PRED);
    return b->CreateUIToFP(integer_value,
                           llvm_ir::PrimitiveTypeToIrType(to_type, module));
  }
}

}  // namespace

StatusOr<llvm::Value*> ElementalIrEmitter::EmitUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) {
  if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) ||
      op->operand(0)->shape().element_type() == PRED) {
    return EmitIntegerUnaryOp(op, operand_value);
  } else if (ShapeUtil::ElementIsComplex(op->operand(0)->shape())) {
    return EmitComplexUnaryOp(op, operand_value);
  } else {
    return EmitFloatUnaryOp(op, operand_value);
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) {
  switch (op->opcode()) {
    case HloOpcode::kConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsIntegralType(from_type) || from_type == PRED)
          << from_type;
      if (from_type == to_type) {
        return operand_value;
      }
      if (to_type == PRED) {
        return b_->CreateZExt(
            ICmpNE(operand_value,
                   llvm::ConstantInt::get(operand_value->getType(), 0)),
            llvm_ir::PrimitiveTypeToIrType(PRED, module_));
      }
      if (primitive_util::IsIntegralType(to_type)) {
        return IntCast(operand_value,
                       llvm_ir::PrimitiveTypeToIrType(to_type, module_),
                       primitive_util::IsSignedIntegralType(from_type));
      }
      if (primitive_util::IsFloatingPointType(to_type)) {
        if (to_type == BF16) {
          return EmitF32ToBF16(EmitIntegralToFloating(operand_value, from_type,
                                                      F32, module_, b_),
                               b_);
        }
        return EmitIntegralToFloating(operand_value, from_type, to_type,
                                      module_, b_);
      }
      if (primitive_util::IsComplexType(to_type)) {
        auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(
            primitive_util::ComplexComponentType(to_type), module_);
        if (primitive_util::IsSignedIntegralType(from_type)) {
          return EmitComposeComplex(
              op, SIToFP(operand_value, to_ir_component_type), nullptr);
        }
        if (primitive_util::IsUnsignedIntegralType(from_type) ||
            from_type == PRED) {
          return EmitComposeComplex(
              op, UIToFP(operand_value, to_ir_component_type), nullptr);
        }
      }
      return Unimplemented("conversion from primitive type %s to %s",
                           PrimitiveType_Name(from_type),
                           PrimitiveType_Name(to_type));
    }
    case HloOpcode::kBitcastConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsIntegralType(from_type));
      if (from_type == to_type) {
        return operand_value;
      }
      if (primitive_util::BitWidth(from_type) ==
          primitive_util::BitWidth(to_type)) {
        return BitCast(operand_value,
                       llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      return InvalidArgument(
          "bitcast conversion from primitive type %s to %s with unequal "
          "bit-widths (%u versus %u) ",
          PrimitiveType_Name(from_type), PrimitiveType_Name(to_type),
          primitive_util::BitWidth(from_type),
          primitive_util::BitWidth(to_type));
    }
    case HloOpcode::kAbs: {
      bool is_signed =
          primitive_util::IsSignedIntegralType(op->shape().element_type());
      if (is_signed) {
        auto type =
            llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
        auto cmp = ICmpSGE(operand_value, GetZero(type));
        return Select(cmp, operand_value, Neg(operand_value));
      } else {
        return operand_value;
      }
    }
    case HloOpcode::kClz: {
      auto is_zero_undef = b_->getFalse();
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ctlz,
                                          {operand_value, is_zero_undef},
                                          {operand_value->getType()}, b_);
    }
    case HloOpcode::kSign: {
      CHECK(primitive_util::IsSignedIntegralType(op->shape().element_type()))
          << op->shape().element_type();
      auto type =
          llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
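      // For nonzero x, the arithmetic shift by bitwidth-1 yields 0 when x > 0
      // and -1 (all ones) when x < 0; OR-ing in the low bit then maps these
      // to +1 and -1. The x == 0 case is handled by the select below.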
      auto cmp = ICmpEQ(operand_value, GetZero(type));
      auto ashr = AShr(operand_value, type->getIntegerBitWidth() - 1);
      return Select(cmp, GetZero(type), Or(ashr, 1));
    }
    case HloOpcode::kNegate:
      return Neg(operand_value);
    case HloOpcode::kNot: {
      auto type = op->shape().element_type();
      if (type == PRED) {
        // It is not sufficient to just call CreateNot() here because a PRED
        // is represented as an i8 and the truth value is stored only in the
        // bottom bit.
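        //
        // For example, a PRED holding i8 0x01: Trunc to i1 gives 1, Not gives
        // 0, and the ZExt produces 0x00, whereas CreateNot on the i8 would
        // produce 0xFE, which is not a valid PRED value.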
        return b_->CreateZExt(Not(Trunc(operand_value, b_->getInt1Ty())),
                              llvm_ir::PrimitiveTypeToIrType(PRED, module_));
      } else if (primitive_util::IsIntegralType(type)) {
        return Not(operand_value);
      }
      return Unimplemented("unary op Not is not defined for type '%d'", type);
    }
    case HloOpcode::kPopulationCount: {
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ctpop,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    }
    default:
      return Unimplemented("unary integer op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) {
  switch (op->opcode()) {
    case HloOpcode::kConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsFloatingPointType(from_type)) << from_type;
      if (from_type == to_type) {
        return operand_value;
      }
      if (from_type == BF16) {
        TF_RET_CHECK(to_type != BF16);
        operand_value = EmitBF16ToF32(operand_value, b_);
        from_type = F32;
        if (from_type == to_type) {
          return operand_value;
        }
      }
      if (primitive_util::IsComplexType(to_type)) {
        PrimitiveType to_component_type =
            primitive_util::ComplexComponentType(to_type);
        if (from_type == to_component_type) {
          return EmitComposeComplex(op, operand_value, nullptr);
        }
        return EmitComposeComplex(
            op,
            FPCast(operand_value,
                   llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)),
            nullptr);
      }
      if (to_type == BF16) {
        // Cast to F32 first. Other floating point formats are not supported by
        // EmitReducePrecisionIR.
        if (from_type != F32) {
          operand_value = b_->CreateFPCast(
              operand_value, llvm_ir::PrimitiveTypeToIrType(F32, module_));
        }
        return EmitF32ToBF16(operand_value, b_);
      }
      if (to_type == PRED) {
        return b_->CreateZExt(
            FCmpUNE(operand_value,
                    llvm::ConstantFP::get(operand_value->getType(), 0.0)),
            llvm_ir::PrimitiveTypeToIrType(PRED, module_));
      }
      if (primitive_util::IsFloatingPointType(to_type)) {
        return FPCast(operand_value,
                      llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      if (primitive_util::IsSignedIntegralType(to_type)) {
        return FPToSI(operand_value,
                      llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      if (primitive_util::IsUnsignedIntegralType(to_type)) {
        return FPToUI(operand_value,
                      llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      return Unimplemented("unhandled conversion operation: %s => %s",
                           PrimitiveType_Name(from_type),
                           PrimitiveType_Name(to_type));
    }
    case HloOpcode::kBitcastConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsFloatingPointType(from_type));
      if (from_type == to_type) {
        return operand_value;
      }
      if (primitive_util::BitWidth(from_type) ==
          primitive_util::BitWidth(to_type)) {
        return BitCast(operand_value,
                       llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      return InvalidArgument(
          "bitcast conversion from primitive type %s to %s with unequal "
          "bit-widths (%u versus %u) ",
          PrimitiveType_Name(from_type), PrimitiveType_Name(to_type),
          primitive_util::BitWidth(from_type),
          primitive_util::BitWidth(to_type));
    }
    case HloOpcode::kExp:
      return EmitExp(op->shape().element_type(), operand_value, "");
    case HloOpcode::kExpm1:
      return EmitExpm1(op->shape().element_type(), operand_value);
    case HloOpcode::kLog:
      return EmitLog(op->shape().element_type(), operand_value);
    case HloOpcode::kLog1p:
      return EmitLog1p(op->shape().element_type(), operand_value);
    case HloOpcode::kCos:
      return EmitCos(op->shape().element_type(), operand_value);
    case HloOpcode::kSin:
      return EmitSin(op->shape().element_type(), operand_value);
    case HloOpcode::kTanh:
      return EmitTanh(op->shape().element_type(), operand_value);
    case HloOpcode::kSqrt:
      return EmitSqrt(op->shape().element_type(), operand_value);
    case HloOpcode::kRsqrt:
      return EmitRsqrt(op->shape().element_type(), operand_value);
    case HloOpcode::kCbrt:
      return EmitCbrt(op->shape().element_type(), operand_value);
    case HloOpcode::kFloor:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    case HloOpcode::kCeil:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ceil,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    case HloOpcode::kAbs:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    case HloOpcode::kRoundNearestAfz:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::round,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    case HloOpcode::kSign: {
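      // Computes sign(x) as copysign(x != 0 ? 1.0 : 0.0, x): nonzero inputs
      // map to +/-1 and signed zeros map to +/-0. NaN inputs are restored by
      // the select below, since FCmpONE is false for NaN.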
      auto type = operand_value->getType();
      auto zero = llvm::ConstantFP::get(type, 0.0);
      auto ne0_i1 = FCmpONE(operand_value, zero);
      auto ne0_float = UIToFP(ne0_i1, type);
      llvm::Value* result = llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::copysign, {ne0_float, operand_value},
          {operand_value->getType()}, b_);
      auto is_nan = FCmpUNO(operand_value, operand_value);
      result = Select(is_nan, operand_value, result);
      return result;
    }
    case HloOpcode::kIsFinite: {
      // abs(x) o!= inf; this works because the ordered comparison returns
      // false if either operand is NaN.
      auto type = operand_value->getType();
      auto abs_value = llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::fabs, {operand_value}, {type}, b_);
      auto infinity = llvm::ConstantFP::getInfinity(type);
      auto not_infinite = FCmpONE(abs_value, infinity);
      return b_->CreateZExt(not_infinite,
                            llvm_ir::PrimitiveTypeToIrType(PRED, module_));
    }
    case HloOpcode::kNegate:
      return FNeg(operand_value);
    case HloOpcode::kReal:
      return operand_value;
    case HloOpcode::kImag:
      return llvm::ConstantFP::get(operand_value->getType(), 0.0);
    default:
      return Unimplemented("unary floating-point op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) {
  PrimitiveType input_type = op->operand(0)->shape().element_type();
  PrimitiveType component_type =
      primitive_util::IsComplexType(input_type)
          ? primitive_util::ComplexComponentType(input_type)
          : input_type;
  switch (op->opcode()) {
    case HloOpcode::kLog: {
      // log(a+bi) = log(abs(a+bi)) + i*atan2(b,a)
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      TF_ASSIGN_OR_RETURN(llvm::Value * angle,
                          EmitAtan2(component_type, b, a, ""));
      TF_ASSIGN_OR_RETURN(llvm::Value * abs,
                          EmitComplexAbs(component_type, operand_value));
      TF_ASSIGN_OR_RETURN(llvm::Value * log_abs, EmitLog(component_type, abs));
      return EmitComposeComplex(op, log_abs, angle);
    }
    case HloOpcode::kLog1p: {
      // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      llvm::Type* llvm_ty = a->getType();
      auto one = llvm::ConstantFP::get(llvm_ty, 1.0);
      auto a_plus_one = FAdd(a, one);
      auto sum_sq = FAdd(FMul(a_plus_one, a_plus_one), FMul(b, b));
      TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq));
      TF_ASSIGN_OR_RETURN(auto angle,
                          EmitAtan2(component_type, b, a_plus_one, ""));
      auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
      return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle);
    }
    case HloOpcode::kConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      TF_RET_CHECK(primitive_util::IsComplexType(from_type));
      PrimitiveType to_type = op->shape().element_type();
      TF_RET_CHECK(primitive_util::IsComplexType(to_type));
      if (from_type == to_type) {
        return operand_value;
      }
      PrimitiveType to_component_type =
          primitive_util::ComplexComponentType(to_type);
      auto to_ir_component_type =
          llvm_ir::PrimitiveTypeToIrType(to_component_type, module_);
      return EmitComposeComplex(
          op, FPCast(EmitExtractReal(operand_value), to_ir_component_type),
          FPCast(EmitExtractImag(operand_value), to_ir_component_type));
    }
    case HloOpcode::kExp: {
      // e^(a+bi) = e^a*(cos(b)+sin(b)i)
      TF_ASSIGN_OR_RETURN(
          auto exp_a,
          EmitExp(component_type, EmitExtractReal(operand_value), ""));
      TF_ASSIGN_OR_RETURN(
          auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
      TF_ASSIGN_OR_RETURN(
          auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
      return EmitComposeComplex(op, FMul(exp_a, cos_b), FMul(exp_a, sin_b));
    }
    case HloOpcode::kExpm1: {
      // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
      TF_ASSIGN_OR_RETURN(
          auto exp_a,
          EmitExp(component_type, EmitExtractReal(operand_value), ""));
      TF_ASSIGN_OR_RETURN(
          auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
      TF_ASSIGN_OR_RETURN(
          auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
      auto one = llvm::ConstantFP::get(exp_a->getType(), 1.0);
      auto real_result = FSub(FMul(exp_a, cos_b), one);
      auto imag_result = FMul(exp_a, sin_b);
      return EmitComposeComplex(op, real_result, imag_result);
    }
    case HloOpcode::kCos: {
      // cos(z) = .5(e^(iz) + e^(-iz))
      // cos(a+bi) = .5(e^(-b+ai) + e^(b-ai))
      // now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
      // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(-a)+sin(-a)i))
      // cos(-x) = cos(x) and sin(-x) = -sin(x), so
      // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(a)-sin(a)i))
      //           = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b))
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      auto type = a->getType();
      TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b, ""));
      auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b);
      auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
      TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
      TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
      return EmitComposeComplex(op,
                                FMul(cos_a, FAdd(half_exp_neg_b, half_exp_b)),
                                FMul(sin_a, FSub(half_exp_neg_b, half_exp_b)));
    }
    case HloOpcode::kSin: {
      // sin(z) = .5i(e^(-iz) - e^(iz))
      // sin(a+bi) = .5i(e^(-i(a+bi)) - e^(i(a+bi)))
      //           = .5i(e^(b-ai) - e^(-b+ai))
      // now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
      // sin(a+bi) = 0.5i(e^b*(cos(-a)+sin(-a)i) - e^-b*(cos(a)+sin(a)i))
      //           = 0.5(e^b*(cos(-a)i-sin(-a)) - e^-b*(cos(a)i-sin(a)))
      // cos(-x) = cos(x) and sin(-x) = -sin(x), so
      //           = 0.5(e^b*(cos(a)i+sin(a)) - e^-b*(cos(a)i-sin(a)))
      //           = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b))
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      auto type = a->getType();
      TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b, ""));
      auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b);
      auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
      TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
      TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
      return EmitComposeComplex(op,
                                FMul(sin_a, FAdd(half_exp_b, half_exp_neg_b)),
                                FMul(cos_a, FSub(half_exp_b, half_exp_neg_b)));
    }
    case HloOpcode::kTanh: {
      /*
      tanh=(exp(x)-exp(-x)) / (exp(x)+exp(-x))
      e^(a+bi) = e^a*(cos(b)+sin(b)i)
      so tanh=(((cos(b)+sin(b)i)e^a - (cos(-b)+sin(-b)i)e^-a)) /
              (((cos(b)+sin(b)i)e^a + (cos(-b)+sin(-b)i)e^-a))
      cos(b)=cos(-b), sin(-b)=-sin(b)
      so tanh=(((cos(b)+sin(b)i)e^a - (cos(b)-sin(b)i)e^-a)) /
              (((cos(b)+sin(b)i)e^a + (cos(b)-sin(b)i)e^-a))
             =(cos(b)e^a+i*sin(b)e^a + cos(b)(-e^-a)+i*sin(b)e^-a) /
              (cos(b)e^a+i*sin(b)e^a + cos(b)e^-a+i*sin(b)(-e^-a))
             =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) /
              (cos(b)(e^a+e^-a) + i*sin(b)(e^a-e^-a))
      This is a complex division, so we can multiply by denom_conj/denom_conj
             =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) *
              (cos(b)(e^a+e^-a) - i*sin(b)(e^a-e^-a)) /
              ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
             =(cos(b)^2(e^(2a)-e^(-2a)) + sin(b)^2(e^(2a)-e^(-2a)) +
               i*(cos(b)sin(b)(e^a+e^-a)^2 - cos(b)sin(b)(e^a-e^-a)^2)) /
              ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
             =(e^(2a)-e^(-2a) +
               i*[cos(b)sin(b)(e^(2a)+2+e^(-2a))-cos(b)sin(b)(e^(2a)-2+e^(-2a))])
              / (cos(b)^2*(e^(2a)+2+e^(-2a)) + sin(b)^2*(e^(2a)-2+e^(-2a)))
             =(e^(2a)-e^(-2a) +
               i*cos(b)sin(b)*[e^(2a)+2+e^(-2a)-e^(2a)+2-e^(-2a)]) /
              ([cos(b)^2 + sin(b)^2][e^(2a)+e^(-2a)] + 2*[cos(b)^2 - sin(b)^2])
             =(e^(2a)-e^(-2a) + i*cos(b)sin(b)*4) /
              (e^(2a)+e^(-2a)+2*[cos(b)^2 - sin(b)^2])
             =(e^(2a)-e^(-2a) + i*[sin(2b)/2]*4) /
              (e^(2a)+e^(-2a)+2*[cos(2b)])
             =(e^(2a)-e^(-2a) + i*2*sin(2b)) / (e^(2a) + e^(-2a) + 2*cos(2b))
      */
      llvm::Value* a = EmitExtractReal(operand_value);
      llvm::Value* b = EmitExtractImag(operand_value);

      llvm::Type* type = a->getType();

      llvm::Value* neg_one = llvm::ConstantFP::get(type, -1.F);
      llvm::Value* two_a = FAdd(a, a);
      llvm::Value* neg_2a = FMul(neg_one, two_a);

      // When we are calculating the real numerator, e^(2a)-e^(-2a), for small
      // values of `a`, we will get a ULP of 2^-23 using the exp function.
      // Using expm1 to calculate e^(2a)-e^(-2a) = [e^(2a)-1] - [e^(-2a)-1]
      // allows our ULP to be arbitrarily small. For larger values of `a`,
      // calculating the numerator as Exp(2a)-Exp(-2a) vs Expm1(2a)-Expm1(-2a)
      // returns virtually identical results.
      TF_ASSIGN_OR_RETURN(llvm::Value * exp_2a_m1,
                          EmitExpm1(component_type, two_a));
      TF_ASSIGN_OR_RETURN(llvm::Value * exp_neg_2a_m1,
                          EmitExpm1(component_type, neg_2a));
      llvm::Value* real_numerator = FSub(exp_2a_m1, exp_neg_2a_m1);

      // We can use the identity cos(2b)+1 = cos(b)^2-sin(b)^2+cos(b)^2+sin(b)^2
      // = 2cos(b)^2. This gives us the ability to be more precise when the
      // denominator is close to zero.
      TF_ASSIGN_OR_RETURN(llvm::Value * cos_b, EmitCos(component_type, b));
      llvm::Value* four = llvm::ConstantFP::get(type, 4.F);
      llvm::Value* cos_b_sq = FMul(cos_b, cos_b);
      llvm::Value* two_cos_2b_p2 = FMul(cos_b_sq, four);

      // Similarly we can compute sin(2b) with the formula sin(2b) =
      // 2*sin(b)*cos(b).
      TF_ASSIGN_OR_RETURN(llvm::Value * sin_b, EmitSin(component_type, b));
      llvm::Value* imag_numerator = FMul(four, FMul(cos_b, sin_b));

      // Expm1(x) is about x for small values of x, but exp_sum_m2 is about x^2
      // for small values of x. As a result, due to floating-point precision
      // issues, x^2 is a better approximation than Expm1(x) + Expm1(x) for
      // small values of x.
      llvm::Value* a_sqr = FMul(a, a);
      llvm::Value* use_approx_cutoff = llvm::ConstantFP::get(type, 1e-8);
      llvm::Value* use_approx = FCmpOLT(a_sqr, use_approx_cutoff);

      llvm::Value* exp_sum_m2 =
          Select(use_approx, a_sqr, FAdd(exp_2a_m1, exp_neg_2a_m1));
      llvm::Value* denom = FAdd(exp_sum_m2, two_cos_2b_p2);

      // As `a` grows toward +inf and -inf, the real numerator will grow
      // towards +inf and -inf respectively, while the denominator will always
      // grow towards +inf. The result is real_numerator/denom = NaN, when it
      // should equal +1 and -1 respectively. Therefore, if our denominator is
      // +inf, we just hardcode the limits for the real numbers.
      llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
      llvm::Value* is_inf = FCmpOEQ(exp_sum_m2, inf);
      llvm::Value* real_limit = llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::copysign, {neg_one, a}, {type}, b_);

      llvm::Value* real =
          Select(is_inf, real_limit, FDiv(real_numerator, denom));
      llvm::Value* imag = FDiv(imag_numerator, denom);

      // The complex tanh functions have a few corner cases:
      // 1. (+0, +0) => (+0, +0) - Handled normally
      // 2. (x, +Inf) => (NaN, NaN) - See below
      // 3. (x, NaN) => (NaN, NaN) - See below
      // 4. (+inf, y) => (1, +0) - Handled normally
      // 5. (+Inf, +Inf) => (1, +/-0) - See below
      // 6. (+Inf, NaN) => (1, +/-0) - See below
      // 7. (NaN, +0) => (NaN, +0) - See below
      // 8. (NaN, y) => (NaN, NaN) - Handled normally
      // 9. (NaN, NaN) => (NaN, NaN) - Handled normally
      //
      // For the cases that aren't handled normally:
      // 2/3) Part of the calculation we do is that if exp(a) + exp(-a) = +inf,
      //      then we return (+/-1, +/-0). However, this is only true if we
      //      assume that a is infinity or b is finite. In the event that both
      //      a is finite and b is either +/-Inf or NaN, then our normal
      //      calculation would end up returning (+/-1, NaN), as opposed to
      //      (NaN, NaN).
      // 5/6) We always calculate the imaginary value as sin(2b)/denominator.
      //      When the denominator is infinity, this assures us that the zero
      //      is the correct sign. However if our imaginary input results in
      //      sin(2b) = NaN, we calculate our imaginary result as NaN.
      // 7)   In the event that a is NaN, the denominator will be NaN.
      //      Therefore, the normal calculation gives (NaN, NaN) while we need
      //      (NaN, +0).
      if (!(b_->getFastMathFlags().noNaNs() &&
            b_->getFastMathFlags().noInfs())) {
        llvm::Value* abs_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
                                                          {a}, {type}, b_);
        llvm::Value* zero = llvm::ConstantFP::get(type, 0.F);
        llvm::Value* nan = llvm::ConstantFP::getNaN(type);

        llvm::Value* a_is_inf = FCmpOEQ(abs_a, inf);
        llvm::Value* b_is_zero = FCmpOEQ(b, zero);

        // imag_numerator = 2sin(2b), so sin(2b) is NaN if and only if
        // imag_numerator is NaN.
        llvm::Value* sin_2b_is_nan =
            b_->CreateFCmpUNO(imag_numerator, imag_numerator);

        llvm::Value* real_is_nan =
            b_->CreateAnd(sin_2b_is_nan, b_->CreateNot(a_is_inf));
        llvm::Value* imag_is_zero =
            b_->CreateOr(b_is_zero, b_->CreateAnd(a_is_inf, sin_2b_is_nan));

        real = Select(real_is_nan, nan, real);
        imag = Select(imag_is_zero, zero, imag);
      }

      return EmitComposeComplex(op, real, imag);
    }
    case HloOpcode::kAbs: {
      return EmitComplexAbs(component_type, operand_value);
    }
    case HloOpcode::kSign: {  // Sign(c) = c / |c|
      TF_ASSIGN_OR_RETURN(auto cplx_abs,
                          EmitComplexAbs(component_type, operand_value));
      auto type = cplx_abs->getType();
      auto zero = llvm::ConstantFP::get(type, 0.0);
      auto oeq = FCmpOEQ(cplx_abs, zero);
      return Select(
          oeq, EmitComposeComplex(op, zero, zero),
          EmitComposeComplex(op, FDiv(EmitExtractReal(operand_value), cplx_abs),
                             FDiv(EmitExtractImag(operand_value), cplx_abs)));
    }
    case HloOpcode::kSqrt: {
      return EmitComplexSqrt(op, component_type, operand_value);
    }
    case HloOpcode::kRsqrt: {
      return EmitComplexRsqrt(op, component_type, operand_value);
    }
    case HloOpcode::kCbrt: {
      return EmitComplexCbrt(op, component_type, operand_value);
    }
    case HloOpcode::kNegate:
      return EmitComposeComplex(op, FNeg(EmitExtractReal(operand_value)),
                                FNeg(EmitExtractImag(operand_value)));
    case HloOpcode::kReal:
      return EmitExtractReal(operand_value);
    case HloOpcode::kImag:
      return EmitExtractImag(operand_value);
    default:
      return Unimplemented("unary complex op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  PrimitiveType operand_type = op->operand(0)->shape().element_type();
  if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) ||
      operand_type == PRED) {
    return EmitIntegerBinaryOp(
        op, lhs_value, rhs_value,
        primitive_util::IsSignedIntegralType(operand_type));
  } else if (primitive_util::IsComplexType(operand_type)) {
    return EmitComplexBinaryOp(op, lhs_value, rhs_value);
  } else {
    return EmitFloatBinaryOp(op, lhs_value, rhs_value);
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  switch (op->opcode()) {
    case HloOpcode::kComplex:
      return EmitComposeComplex(op, lhs_value, rhs_value);
    case HloOpcode::kAdd:
      return FAdd(lhs_value, rhs_value, op->name());
    case HloOpcode::kSubtract:
      return FSub(lhs_value, rhs_value, op->name());
    case HloOpcode::kMultiply:
      return FMul(lhs_value, rhs_value, op->name());
    case HloOpcode::kDivide:
      return FDiv(lhs_value, rhs_value, op->name());
    case HloOpcode::kRemainder:
      return FRem(lhs_value, rhs_value, op->name());
    // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
    // comparisons always return false when one of the operands is NaN, whereas
    // unordered comparisons return true.
    //
    // We use ordered comparisons for everything except kNe, where we use an
    // unordered comparison. This makes x != y equivalent to !(x == y), and
    // matches C++'s semantics.
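    //
    // For example, when either operand is NaN: Eq (FCMP_OEQ) yields false
    // while Ne (FCMP_UNE) yields true, so Ne(x, y) == !Eq(x, y) holds even in
    // the presence of NaNs.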
    case HloOpcode::kCompare: {
      switch (op->comparison_direction()) {
        case ComparisonDirection::kEq:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value,
                                         rhs_value, b_, op->name());
        case ComparisonDirection::kNe:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value,
                                         rhs_value, b_, op->name());
        case ComparisonDirection::kLt:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value,
                                         rhs_value, b_, op->name());
        case ComparisonDirection::kGt:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value,
                                         rhs_value, b_, op->name());
        case ComparisonDirection::kLe:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value,
                                         rhs_value, b_, op->name());
        case ComparisonDirection::kGe:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value,
                                         rhs_value, b_, op->name());
      }
    }
    case HloOpcode::kMaximum:
      return EmitFloatMax(lhs_value, rhs_value, op->name());
    case HloOpcode::kMinimum:
      return EmitFloatMin(lhs_value, rhs_value, op->name());
    case HloOpcode::kPower:
      return EmitPow(op->shape().element_type(), lhs_value, rhs_value,
                     op->name());
    case HloOpcode::kAtan2:
      return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value,
                       op->name());
    default:
      return Unimplemented("binary floating point op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}

// Computing sqrt(a^2 + b^2) directly can overflow. Instead, we use
//   sqrt(a^2 + b^2) = sqrt(a^2 * (1 + b^2/a^2))
//                   = |a| * sqrt(1 + (b/a)^2)
// under the assumption that |a| >= |b|.
//
// This method returns the min, max, and sqrt term for this calculation. This
// is done to prevent potential overflow errors that can occur from multiplying
// the max with the sqrt term. (i.e. when calculating the sqrt of the absolute
// value, we can take the sqrt of the max and the sqrt term before multiplying
// them together.) If return_sqrt is false, it returns 1 + (b/a)^2 instead of
// sqrt(1 + (b/a)^2).
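//
// For example, with f32 inputs a = 3e30 and b = 4e30, squaring either value
// overflows to +inf, while max * sqrt(1 + (min/max)^2)
// = 4e30 * sqrt(1 + 0.75^2) = 4e30 * 1.25 = 5e30 stays finite.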
StatusOr<std::tuple<llvm::Value*, llvm::Value*, llvm::Value*>>
ElementalIrEmitter::EmitComplexAbsHelper(PrimitiveType prim_type,
                                         llvm::Value* operand_value,
                                         bool return_sqrt) {
  llvm::Value* real = EmitExtractReal(operand_value);
  llvm::Value* imag = EmitExtractImag(operand_value);
  llvm::Value* abs_real = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::fabs, {real}, {real->getType()}, b_);
  llvm::Value* abs_imag = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::fabs, {imag}, {imag->getType()}, b_);
  llvm::Value* max = EmitFloatMax(abs_real, abs_imag, "");
  llvm::Value* min = EmitFloatMin(abs_real, abs_imag, "");

  llvm::Value* div = FDiv(min, max);
  llvm::Value* div_sq = FMul(div, div);
  llvm::Value* one = llvm::ConstantFP::get(max->getType(), 1);
  llvm::Value* one_p_div_sq = FAdd(one, div_sq);
  TF_ASSIGN_OR_RETURN(llvm::Value * sqrt, EmitSqrt(prim_type, one_p_div_sq));
  return std::make_tuple(min, max, return_sqrt ? sqrt : one_p_div_sq);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexAbs(
    PrimitiveType prim_type, llvm::Value* operand_value) {
  llvm::Value* min;
  llvm::Value* max;
  llvm::Value* sqrt;
  TF_ASSIGN_OR_RETURN(
      std::tie(min, max, sqrt),
      EmitComplexAbsHelper(prim_type, operand_value, /*return_sqrt=*/true));
  llvm::Value* result = FMul(max, sqrt);
  // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN.
  // In such cases, we return `min` instead of `result`.
  return Select(FCmpUNO(result, result), min, result);
}

// Calculates ComplexAbs in the same way, except using:
//   sqrt(|a| * sqrt(1 + (b/a)^2)) = sqrt(|a|) * pow(1 + (b/a)^2, .25)
StatusOr<llvm::Value*> ElementalIrEmitter::EmitSqrtComplexAbs(
    PrimitiveType prim_type, llvm::Value* operand_value) {
  llvm::Value* min;
  llvm::Value* max;
  llvm::Value* one_p_div_sq;
  TF_ASSIGN_OR_RETURN(
      std::tie(min, max, one_p_div_sq),
      EmitComplexAbsHelper(prim_type, operand_value, /*return_sqrt=*/false));
  TF_ASSIGN_OR_RETURN(llvm::Value * sqrt_max, EmitSqrt(prim_type, max));
  TF_ASSIGN_OR_RETURN(llvm::Value * pow,
                      EmitPow(prim_type, one_p_div_sq,
                              llvm::ConstantFP::get(max->getType(), .25), ""));
  llvm::Value* result = FMul(sqrt_max, pow);
  // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN.
  // In such cases, we return `min` instead of `result`.
  return Select(FCmpUNO(result, result), min, result);
}

// Calculates ComplexAbs in the same way, except using:
//   rsqrt(|a| * sqrt(1 + (b/a)^2)) = rsqrt(|a|) * rsqrt(sqrt(1 + (b/a)^2))
StatusOr<llvm::Value*> ElementalIrEmitter::EmitRsqrtComplexAbs(
    PrimitiveType prim_type, llvm::Value* operand_value) {
  llvm::Value* min;
  llvm::Value* max;
  llvm::Value* sqrt;
  TF_ASSIGN_OR_RETURN(
      std::tie(min, max, sqrt),
      EmitComplexAbsHelper(prim_type, operand_value, /*return_sqrt=*/true));
  TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_max, EmitRsqrt(prim_type, max));
  TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_sqrt, EmitRsqrt(prim_type, sqrt));
  llvm::Value* result = FMul(rsqrt_max, rsqrt_sqrt);
  TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_min, EmitRsqrt(prim_type, min));
  // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN.
  // In such cases, we return rsqrt(min) instead of `result`.
  return Select(FCmpUNO(result, result), rsqrt_min, result);
}

// Using our EmitComplexPower formula, but setting c=0.5 and d=0, we get:
//   e^[ln(r)*c - t*d] * [cos(ln(r)*d + t*c) + i*sin(ln(r)*d + t*c)]
//   = e^[ln(r)*0.5] * [cos(t*0.5) + i*sin(t*0.5)]
//   = r^0.5 * [cos(t/2) + i*sin(t/2)]
//   = sqrt(r) * [cos(t/2) + i*sin(t/2)]
// where r = |a+bi| and t = atan2(b,a).
// TODO(bixia): See doc for implementation without atan2.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexSqrt(
    const HloInstruction* op, PrimitiveType prim_type,
    llvm::Value* operand_value) {
  llvm::Type* type = static_cast<llvm::StructType*>(operand_value->getType())
                         ->getElementType(0);

  TF_ASSIGN_OR_RETURN(llvm::Value * r,
                      EmitSqrtComplexAbs(prim_type, operand_value));

  llvm::Value* a = EmitExtractReal(operand_value);
  llvm::Value* b = EmitExtractImag(operand_value);
  TF_ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a, ""));

  llvm::Value* c = llvm::ConstantFP::get(type, 0.5);
  llvm::Value* angle = FMul(t, c);
  TF_ASSIGN_OR_RETURN(llvm::Value * cos, EmitCos(prim_type, angle));
  TF_ASSIGN_OR_RETURN(llvm::Value * sin, EmitSin(prim_type, angle));

  llvm::Value* real_part;
  llvm::Value* imag_part;

  llvm::Value* zero = llvm::ConstantFP::get(type, 0);

  if (!(b_->getFastMathFlags().noNaNs() && b_->getFastMathFlags().noInfs())) {
    llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
    llvm::Value* neg_inf = llvm::ConstantFP::getInfinity(type, true);
    llvm::Value* nan = llvm::ConstantFP::getNaN(type);
    llvm::Value* abs_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
                                                      {b}, {b->getType()}, b_);

    real_part = Select(Or(FCmpOEQ(abs_b, inf), FCmpOEQ(a, inf)), inf,
                       Select(And(FCmpOEQ(a, neg_inf), FCmpONE(abs_b, inf)),
                              zero, FMul(r, cos)));

    llvm::Value* b_signed_inf = llvm_ir::EmitCallToIntrinsic(
        llvm::Intrinsic::copysign, {inf, b}, {b->getType()}, b_);
    imag_part =
        Select(Or(FCmpOEQ(abs_b, inf), FCmpOEQ(a, neg_inf)), b_signed_inf,
               Select(FCmpUNO(r, r), nan,
                      Select(FCmpOEQ(sin, zero), sin, FMul(r, sin))));
  } else {
    real_part = FMul(r, cos);
    imag_part = Select(FCmpOEQ(sin, zero), sin, FMul(r, sin));
  }

  return Select(FCmpOEQ(r, zero), EmitComposeComplex(op, zero, zero),
                EmitComposeComplex(op, real_part, imag_part));
}

// Similar to Sqrt, we can use our EmitComplexPower formula, but set
// c=-0.5 and d=0. We get:
//   e^[ln(r)*c - t*d] * [cos(ln(r)*d + t*c) + i*sin(ln(r)*d + t*c)]
//   = e^[ln(r)*-0.5] * [cos(t*-0.5) + i*sin(t*-0.5)]
//   = r^(-0.5) * [cos(-t/2) + i*sin(-t/2)]
//   = rsqrt(r) * [cos(-t/2) + i*sin(-t/2)]
// where r = |a+bi| and t = atan2(b,a).
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexRsqrt(
    const HloInstruction* op, PrimitiveType prim_type,
    llvm::Value* operand_value) {
  llvm::Type* type = static_cast<llvm::StructType*>(operand_value->getType())
                         ->getElementType(0);

  TF_ASSIGN_OR_RETURN(llvm::Value * r,
                      EmitRsqrtComplexAbs(prim_type, operand_value));

  llvm::Value* a = EmitExtractReal(operand_value);
  llvm::Value* b = EmitExtractImag(operand_value);
  TF_ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a, ""));

  llvm::Value* c = llvm::ConstantFP::get(type, -0.5);
  llvm::Value* angle = FMul(t, c);
  TF_ASSIGN_OR_RETURN(llvm::Value * cos, EmitCos(prim_type, angle));
  TF_ASSIGN_OR_RETURN(llvm::Value * sin, EmitSin(prim_type, angle));

  llvm::Value* real_part = FMul(r, cos);
  llvm::Value* imag_part = FMul(r, sin);

  if (!(b_->getFastMathFlags().noNaNs() && b_->getFastMathFlags().noInfs())) {
    llvm::Value* zero = llvm::ConstantFP::get(type, 0);
    llvm::Value* neg_one = llvm::ConstantFP::get(type, -1);
    llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
    llvm::Value* nan = llvm::ConstantFP::getNaN(type);
    llvm::Value* a_signed_zero = llvm_ir::EmitCallToIntrinsic(
        llvm::Intrinsic::copysign, {zero, a}, {a->getType()}, b_);
    llvm::Value* b_signed_zero = llvm_ir::EmitCallToIntrinsic(
        llvm::Intrinsic::copysign, {zero, b}, {b->getType()}, b_);
    llvm::Value* neg_b_signed_zero = FMul(b_signed_zero, neg_one);

    llvm::Value* abs_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
                                                      {a}, {a->getType()}, b_);
    llvm::Value* abs_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
                                                      {b}, {b->getType()}, b_);

    llvm::Value* is_zero_zero = And(FCmpOEQ(b, zero), FCmpOEQ(a, zero));
    real_part = Select(
        is_zero_zero, inf,
        Select(Or(And(FCmpOEQ(abs_b, inf), FCmpUNO(a, a)), FCmpOEQ(abs_a, inf)),
               a_signed_zero, FMul(r, cos)));
    imag_part = Select(
        is_zero_zero, nan,
        Select(Or(And(FCmpOEQ(abs_b, inf), FCmpUNO(a, a)), FCmpOEQ(abs_a, inf)),
               neg_b_signed_zero, FMul(r, sin)));
  } else {
    llvm::Value* zero = llvm::ConstantFP::get(type, 0);
    llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
    llvm::Value* nan = llvm::ConstantFP::getNaN(type);

    llvm::Value* is_zero_zero = And(FCmpOEQ(b, zero), FCmpOEQ(a, zero));
    real_part = Select(is_zero_zero, inf, FMul(r, cos));
    imag_part = Select(is_zero_zero, nan, FMul(r, sin));
  }

  return EmitComposeComplex(op, real_part, imag_part);
}

// Using EmitComplexPower with c=1.0/3.0 and d=0.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexCbrt(
    const HloInstruction* op, PrimitiveType prim_type,
    llvm::Value* operand_value) {
  auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
  auto third = llvm::ConstantFP::get(type, 1.0 / 3.0);
  auto zero = llvm::ConstantFP::get(type, 0);
  llvm::Value* a = EmitExtractReal(operand_value);
  llvm::Value* b = EmitExtractImag(operand_value);
  return EmitComplexPower(op, a, b, third, zero);
}

// (a+bi)^(c+di) =
//   (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
// where q = c*atan2(b,a)+0.5d*ln(a*a+b*b).
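//
// Sanity check: with c=1 and d=0 this reduces to
//   (a*a+b*b)^0.5 * (cos(atan2(b,a)) + i*sin(atan2(b,a))),
// which is just the polar form of a+bi.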
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexPower(
    const HloInstruction* op, llvm::Value* a, llvm::Value* b, llvm::Value* c,
    llvm::Value* d) {
  PrimitiveType component_type =
      primitive_util::ComplexComponentType(op->shape().element_type());
  auto aa_p_bb = FAdd(FMul(a, a), FMul(b, b));
  auto zero = llvm::ConstantFP::get(a->getType(), 0);
  auto one_half = llvm::ConstantFP::get(a->getType(), 0.5);
  auto one = llvm::ConstantFP::get(a->getType(), 1);
  auto half_c = FMul(one_half, c);

  TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c,
                      EmitPow(component_type, aa_p_bb, half_c, ""));

  auto neg_d = FNeg(d);
  TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a, ""));
  auto neg_d_arg_lhs = FMul(neg_d, arg_lhs);
  TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs,
                      EmitExp(component_type, neg_d_arg_lhs, ""));
  auto coeff = FMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs);
  TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb));
  auto half_d = FMul(one_half, d);
  auto q = FAdd(FMul(c, arg_lhs), FMul(half_d, ln_aa_p_bb));
  TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q));
  TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q));
  // d^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0; see
  // "Branch Cuts for Complex Elementary Functions, or Much Ado About
  // Nothing's Sign Bit" by W. Kahan, Section 10.
  return Select(
      And(And(FCmpOEQ(aa_p_bb, zero), FCmpOEQ(d, zero)), FCmpOLE(zero, c)),
      EmitComposeComplex(op, Select(FCmpOEQ(zero, c), one, zero), zero),
      EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q)));
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  switch (op->opcode()) {
    case HloOpcode::kAdd:
      return EmitComposeComplex(
          op, FAdd(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
          FAdd(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
    case HloOpcode::kSubtract:
      return EmitComposeComplex(
          op, FSub(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
          FSub(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
    case HloOpcode::kMultiply:
      return EmitComposeComplex(
          op,
          FSub(FMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
               FMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))),
          FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)),
               FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value))));
    case HloOpcode::kDivide: {
      // Division of complex numbers is implemented here, taking into account
      // over/underflow, NaN and Inf values.
      auto a_r = EmitExtractReal(lhs_value);
      auto a_i = EmitExtractImag(lhs_value);
      auto b_r = EmitExtractReal(rhs_value);
      auto b_i = EmitExtractImag(rhs_value);
      auto type = a_r->getType();

      // Smith's algorithm to divide complex numbers. It is just a bit smarter
      // way to compute the following formula:
      //   (a_r + a_i * i) / (b_r + b_i * i)
      //     = (a_r + a_i * i) (b_r - b_i * i) /
      //       ((b_r + b_i * i)(b_r - b_i * i))
      //     = ((a_r * b_r + a_i * b_i) + (a_i * b_r - a_r * b_i) * i) / ||b||^2
      //
      // Depending on whether |b_r| < |b_i| we compute either
      //   b_r_b_i_ratio = b_r / b_i
      //   b_r_b_i_denom = b_i + b_r * b_r_b_i_ratio
      //   c_r = (a_r * b_r_b_i_ratio + a_i ) / b_r_b_i_denom
      //   c_i = (a_i * b_r_b_i_ratio - a_r ) / b_r_b_i_denom
      //
      // or
      //
      //   b_i_b_r_ratio = b_i / b_r
      //   b_i_b_r_denom = b_r + b_i * b_i_b_r_ratio
      //   c_r = (a_r + a_i * b_i_b_r_ratio ) / b_i_b_r_denom
      //   c_i = (a_i - a_r * b_i_b_r_ratio ) / b_i_b_r_denom
      //
      // See https://dl.acm.org/citation.cfm?id=368661 for more details.
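      //
      // For example, with f64 inputs b_r = b_i = 1e300, the textbook
      // formula's denominator b_r^2 + b_i^2 overflows to +inf, while Smith's
      // method only forms the ratio b_r/b_i = 1.0 and the denominator
      // b_i + b_r * 1.0 = 2e300, which stays finite.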
1194 auto b_r_b_i_ratio = FDiv(b_r, b_i);
1195 auto b_r_b_i_denom = FAdd(b_i, FMul(b_r_b_i_ratio, b_r));
1196 auto b_i_b_r_ratio = FDiv(b_i, b_r);
1197 auto b_i_b_r_denom = FAdd(b_r, FMul(b_i_b_r_ratio, b_i));
1198
1199 auto b_r_lt_b_i = FCmpOLT(llvm_ir::EmitCallToIntrinsic(
1200 llvm::Intrinsic::fabs, {b_r}, {type}, b_),
1201 llvm_ir::EmitCallToIntrinsic(
1202 llvm::Intrinsic::fabs, {b_i}, {type}, b_));
1203 auto c_r = Select(
1204 b_r_lt_b_i, FDiv(FAdd(FMul(b_r_b_i_ratio, a_r), a_i), b_r_b_i_denom),
1205 FDiv(FAdd(FMul(b_i_b_r_ratio, a_i), a_r), b_i_b_r_denom));
1206 auto c_i = Select(
1207 b_r_lt_b_i, FDiv(FSub(FMul(b_r_b_i_ratio, a_i), a_r), b_r_b_i_denom),
1208 FDiv(FSub(a_i, FMul(b_i_b_r_ratio, a_r)), b_i_b_r_denom));
1209 auto result = EmitComposeComplex(op, c_r, c_i);
1210
1211 // Consider corner cases, if the result is (NaN, NaN).
1212 auto zero = llvm::ConstantFP::get(type, 0.0);
1213 auto one = llvm::ConstantFP::get(type, 1.0);
1214 auto inf = llvm::ConstantFP::getInfinity(type);
1215
1216 // Case 1. Zero denominator.
1217 auto zero_denominator =
1218 And(And(FCmpOEQ(b_r, zero), FCmpOEQ(b_i, zero)),
1219 Or(Neg(FCmpONE(a_r, zero)), Neg(FCmpONE(a_i, zero))));
1220 auto inf_with_sign_of_c = llvm_ir::EmitCallToIntrinsic(
1221 llvm::Intrinsic::copysign, {inf, a_r}, {type}, b_);
1222 auto zero_denominator_result = EmitComposeComplex(
1223 op, FMul(inf_with_sign_of_c, a_r), FMul(inf_with_sign_of_c, a_i));
1224
1225 // Case 2. Infinite numerator, finite denominator.
1226 auto b_r_finite = FCmpONE(llvm_ir::EmitCallToIntrinsic(
1227 llvm::Intrinsic::fabs, {b_r}, {type}, b_),
1228 inf);
1229 auto b_i_finite = FCmpONE(llvm_ir::EmitCallToIntrinsic(
1230 llvm::Intrinsic::fabs, {b_i}, {type}, b_),
1231 inf);
1232 auto inf_num_finite_denom = And(Or(FCmpOEQ(a_r, inf), FCmpOEQ(a_i, inf)),
1233 And(b_r_finite, b_i_finite));
1234
1235 auto a_r_inf_with_sign = llvm_ir::EmitCallToIntrinsic(
1236 llvm::Intrinsic::copysign,
1237 {Select(FCmpOEQ(a_r, inf), one, zero), a_r}, {type}, b_);
1238 auto a_i_inf_with_sign = llvm_ir::EmitCallToIntrinsic(
1239 llvm::Intrinsic::copysign,
1240 {Select(FCmpOEQ(a_i, inf), one, zero), a_i}, {type}, b_);
1241 auto inf_num_finite_denom_result =
1242 EmitComposeComplex(op,
1243 FMul(inf, FAdd(FMul(a_r_inf_with_sign, b_r),
1244 FMul(a_i_inf_with_sign, b_i))),
1245 FMul(inf, FSub(FMul(a_i_inf_with_sign, b_r),
1246 FMul(a_r_inf_with_sign, b_i))));
1247
1248 // Case 3. Finite numerator, infinite denominator.
1249 auto a_r_finite = FCmpONE(llvm_ir::EmitCallToIntrinsic(
1250 llvm::Intrinsic::fabs, {a_r}, {type}, b_),
1251 inf);
1252 auto a_i_finite = FCmpONE(llvm_ir::EmitCallToIntrinsic(
1253 llvm::Intrinsic::fabs, {a_i}, {type}, b_),
1254 inf);
1255 auto finite_num_inf_denom = And(Or(FCmpOEQ(b_r, inf), FCmpOEQ(b_i, inf)),
1256 And(a_r_finite, a_i_finite));
1257
1258 auto b_r_inf_with_sign = llvm_ir::EmitCallToIntrinsic(
1259 llvm::Intrinsic::copysign,
1260 {Select(FCmpOEQ(b_r, inf), one, zero), b_r}, {type}, b_);
1261 auto b_i_inf_with_sign = llvm_ir::EmitCallToIntrinsic(
1262 llvm::Intrinsic::copysign,
1263 {Select(FCmpOEQ(b_i, inf), one, zero), b_i}, {type}, b_);
1264 auto finite_num_inf_denom_result =
1265 EmitComposeComplex(op,
1266 FMul(zero, FAdd(FMul(a_r, b_r_inf_with_sign),
1267 FMul(a_i, b_i_inf_with_sign))),
1268 FMul(zero, FSub(FMul(a_i, b_r_inf_with_sign),
1269 FMul(a_r, b_i_inf_with_sign))));
1270
1271 auto c_nan = And(FCmpUNO(c_r, zero), FCmpUNO(c_i, zero));
1272 return Select(
1273 c_nan,
1274 Select(zero_denominator, zero_denominator_result,
1275 Select(inf_num_finite_denom, inf_num_finite_denom_result,
1276 Select(finite_num_inf_denom,
1277 finite_num_inf_denom_result, result))),
1278 result);
1279 }
1280 // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
1281 // comparisons always return false when one of the operands is NaN, whereas
1282 // unordered comparisons return true.
1283 //
1284 // We use ordered comparisons for everything except kNe, where we use an
1285 // unordered comparison. This makes x != y equivalent to !(x == y), and
1286 // matches C++'s semantics.
    case HloOpcode::kCompare: {
      switch (op->comparison_direction()) {
        case ComparisonDirection::kEq:
          return And(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
                                             EmitExtractReal(lhs_value),
                                             EmitExtractReal(rhs_value), b_),
                     llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
                                             EmitExtractImag(lhs_value),
                                             EmitExtractImag(rhs_value), b_));
        case ComparisonDirection::kNe:
          return Or(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
                                            EmitExtractReal(lhs_value),
                                            EmitExtractReal(rhs_value), b_),
                    llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
                                            EmitExtractImag(lhs_value),
                                            EmitExtractImag(rhs_value), b_));
        default:
          return Unimplemented(
              "complex comparison '%s'",
              ComparisonDirectionToString(op->comparison_direction()));
      }
    }
    case HloOpcode::kPower: {
      auto a = EmitExtractReal(lhs_value);
      auto b = EmitExtractImag(lhs_value);
      auto c = EmitExtractReal(rhs_value);
      auto d = EmitExtractImag(rhs_value);
      return EmitComplexPower(op, a, b, c, d);
    }
    default:
      return Unimplemented("binary complex op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}

llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
                                              llvm::Value* rhs_value,
                                              absl::string_view name) {
  return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_, fast_min_max(), name);
}

llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
                                              llvm::Value* rhs_value,
                                              absl::string_view name) {
  return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_, fast_min_max(), name);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type,
                                                   llvm::Value* value) {
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {value},
                                      {value->getType()}, b_);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type,
                                                     llvm::Value* value) {
  auto x = value;
  auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
  auto one = llvm::ConstantFP::get(type, 1.0);
  auto negative_half = llvm::ConstantFP::get(type, -0.5);
  // When x is large, the naive evaluation of ln(x + 1) is more accurate than
  // the Taylor series.
  TF_ASSIGN_OR_RETURN(auto for_large_x, EmitLog(prim_type, FAdd(x, one)));
  // When x is small (|x| below the sqrt(2) - 1 threshold defined below), use
  // a rational approximation instead. The approximation is based on one from
  // the Cephes Mathematical Library.
  //
  // sqrt(2) - 1.
  const auto kAntilogarithmIsSmallThreshold = 0.41421356237309504880;

  static const std::array<double, 7> kDenominatorCoeffs{
      1.,
      1.5062909083469192043167E1,
      8.3047565967967209469434E1,
      2.2176239823732856465394E2,
      3.0909872225312059774938E2,
      2.1642788614495947685003E2,
      6.0118660497603843919306E1,
  };

  static const std::array<double, 7> kNumeratorCoeffs{
      4.5270000862445199635215E-5, 4.9854102823193375972212E-1,
      6.5787325942061044846969E0, 2.9911919328553073277375E1,
      6.0949667980987787057556E1, 5.7112963590585538103336E1,
      2.0039553499201281259648E1,
  };

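  // Assembled below: for_small_x = x - x^2/2 + x^3 * P(x)/Q(x), where P and Q
  // are the numerator and denominator polynomials above.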
  auto x_squared = FMul(x, x);
  TF_ASSIGN_OR_RETURN(auto denominator,
                      EvaluatePolynomial(type, x, kDenominatorCoeffs));
  TF_ASSIGN_OR_RETURN(auto numerator,
                      EvaluatePolynomial(type, x, kNumeratorCoeffs));
  auto for_small_x = FDiv(numerator, denominator);
  for_small_x = FMul(FMul(x, x_squared), for_small_x);
  for_small_x = FAdd(FMul(negative_half, x_squared), for_small_x);
  for_small_x = FAdd(x, for_small_x);

  auto abs_x =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
  auto x_is_small = FCmpOLT(
      abs_x, llvm::ConstantFP::get(type, kAntilogarithmIsSmallThreshold));
  return Select(x_is_small, for_small_x, for_large_x);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitSqrt(PrimitiveType,
                                                    llvm::Value* value) {
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {value},
                                      {value->getType()}, b_);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitRsqrt(PrimitiveType prim_type,
                                                     llvm::Value* value) {
  TF_ASSIGN_OR_RETURN(auto sqrt, EmitSqrt(prim_type, value));
  return FDiv(llvm::ConstantFP::get(sqrt->getType(), 1.0), sqrt);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitSin(PrimitiveType prim_type,
                                                   llvm::Value* value) {
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value},
                                      {value->getType()}, b_);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitCos(PrimitiveType prim_type,
                                                   llvm::Value* value) {
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value},
                                      {value->getType()}, b_);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitExp(PrimitiveType prim_type,
                                                   llvm::Value* value,
                                                   absl::string_view name) {
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value},
                                      {value->getType()}, b_, name);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
                                                     llvm::Value* value) {
  auto x = value;
  auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
  auto one = llvm::ConstantFP::get(type, 1.0);
  auto half = llvm::ConstantFP::get(type, 0.5);
  // When the exponent is large, the naive evaluation of e^(x) - 1 is more
  // accurate than the Taylor series.
  TF_ASSIGN_OR_RETURN(auto exp_x, EmitExp(prim_type, value, ""));
  auto for_large_x = FSub(exp_x, one);
  // The Taylor series for exp(x) is 1 + x + x^2/2 + x^3/6 + ….
  // We want exp(x)-1 which is x + x^2/2 + x^3/6 + ….
  // We use the second degree approximation of exp(x)-1 = x + x^2/2.
  auto x_squared = FMul(x, x);
  auto x_squared_over_two = FMul(x_squared, half);
  auto for_small_x = FAdd(x, x_squared_over_two);
  // At this point, the relative errors due to floating point precision loss of
  // calculating exp(x) - 1 and the polynomial exp(x)-1 = x + x^2/2 are about
  // equal, with a value of approximately 2^-16.
  const auto kExponentIsSmallThreshold = 0.009;
  auto abs_x =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
  auto x_is_small =
      FCmpOLT(abs_x, llvm::ConstantFP::get(type, kExponentIsSmallThreshold));
  return Select(x_is_small, for_small_x, for_large_x);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitPow(PrimitiveType prim_type,
                                                   llvm::Value* lhs,
                                                   llvm::Value* rhs,
                                                   absl::string_view name) {
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs},
                                      {lhs->getType()}, b_, name);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitCbrt(PrimitiveType prim_type,
                                                    llvm::Value* value) {
  auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
  auto third = llvm::ConstantFP::get(type, 1.0 / 3.0);
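  // llvm.pow returns NaN for a negative base with a non-integer exponent, so
  // compute |x|^(1/3) and copy the sign of x back afterwards. Cube root is an
  // odd function, so this gives the correct result for negative inputs.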
  auto abs_value =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
  TF_ASSIGN_OR_RETURN(llvm::Value * abs_res,
                      EmitPow(prim_type, abs_value, third, ""));
  auto signed_res = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign,
                                                 {abs_res, value}, {type}, b_);
  return signed_res;
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(
    PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* /*rhs*/,
    absl::string_view /*name*/) {
  return Unimplemented("atan2");
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
                                                    llvm::Value* value) {
  return Unimplemented("tanh");
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
    const HloInstruction* hlo, llvm::Value* x) {
  return EmitReducePrecisionIR(
      /*src_ty=*/hlo->operand(0)->shape().element_type(), x,
      /*dest_exponent_bits=*/hlo->exponent_bits(),
      /*dest_mantissa_bits=*/hlo->mantissa_bits(), b_);
}

static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* b,
                                             llvm::Value* lhs, llvm::Value* rhs,
                                             llvm::Value* shift_result,
                                             bool saturate_to_sign_bit) {
  llvm::IntegerType* integer_type =
      llvm::cast<llvm::IntegerType>(lhs->getType());
  unsigned integer_bitsize = integer_type->getBitWidth();
  llvm::ConstantInt* integer_bitsize_constant =
      llvm::ConstantInt::get(integer_type, integer_bitsize);
  llvm::ConstantInt* zero = llvm::ConstantInt::get(integer_type, 0);
  llvm::ConstantInt* minus_one = llvm::ConstantInt::get(integer_type, -1);
  llvm::Value* saturated_value;
  if (saturate_to_sign_bit) {
    saturated_value =
        b->CreateSelect(b->CreateICmpSLT(lhs, zero), minus_one, zero);
  } else {
    saturated_value = zero;
  }
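  // A shift amount is in range iff it is strictly less than the bit width,
  // e.g. below 32 for an i32. The unsigned comparison also catches negative
  // shift amounts, which wrap around to large unsigned values.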
  llvm::Value* shift_amt_in_range =
      b->CreateICmpULT(rhs, integer_bitsize_constant, "shft.chk");
  return b->CreateSelect(shift_amt_in_range, shift_result, saturated_value);
}

llvm::Value* ElementalIrEmitter::GetOne(llvm::Type* type) {
  return llvm::ConstantInt::get(llvm::cast<llvm::IntegerType>(type), 1);
}

llvm::Value* ElementalIrEmitter::GetZero(llvm::Type* type) {
  return llvm::ConstantInt::get(llvm::cast<llvm::IntegerType>(type), 0);
}

llvm::Value* ElementalIrEmitter::GetIntSMin(llvm::Type* type) {
  auto* integer_type = llvm::cast<llvm::IntegerType>(type);
  return llvm::ConstantInt::get(integer_type, llvm::APInt::getSignedMinValue(
                                                  integer_type->getBitWidth()));
}

llvm::Value* ElementalIrEmitter::GetMinusOne(llvm::Type* type) {
  auto* integer_type = llvm::cast<llvm::IntegerType>(type);
  return llvm::ConstantInt::get(
      integer_type, llvm::APInt::getAllOnesValue(integer_type->getBitWidth()));
}

llvm::Value* ElementalIrEmitter::IsZero(llvm::Value* v) {
  return ICmpEQ(v, llvm::ConstantInt::get(v->getType(), 0));
}

llvm::Value* ElementalIrEmitter::IsIntMinDivisionOverflow(llvm::Value* lhs,
                                                          llvm::Value* rhs) {
  return And(ICmpEQ(lhs, GetIntSMin(lhs->getType())),
             ICmpEQ(rhs, GetMinusOne(rhs->getType())));
}

llvm::Value* ElementalIrEmitter::EmitIntegerDivide(llvm::Value* lhs,
                                                   llvm::Value* rhs,
                                                   bool is_signed) {
  // Integer division overflow behavior:
  //
  //   X / 0 == -1
  //   INT_SMIN /s -1 == INT_SMIN
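  // Both cases are immediate undefined behavior for LLVM's udiv/sdiv, so we
  // substitute a safe divisor of 1 and patch the result back to the values
  // documented above with selects.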

  if (!is_signed) {
    llvm::Value* udiv_is_unsafe = IsZero(rhs);
    llvm::Value* safe_rhs = Select(udiv_is_unsafe, GetOne(lhs->getType()), rhs);
    llvm::Value* safe_div = UDiv(lhs, safe_rhs);
    return Select(udiv_is_unsafe, GetMinusOne(lhs->getType()), safe_div);
  }

  llvm::Value* has_zero_divisor = IsZero(rhs);
  llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs);
  llvm::Value* sdiv_is_unsafe = Or(has_int_min_overflow, has_zero_divisor);
  llvm::Value* safe_rhs = Select(sdiv_is_unsafe, GetOne(lhs->getType()), rhs);
  llvm::Value* safe_div = SDiv(lhs, safe_rhs);

  return Select(
      has_zero_divisor, GetMinusOne(lhs->getType()),
      Select(has_int_min_overflow, GetIntSMin(lhs->getType()), safe_div));
}

llvm::Value* ElementalIrEmitter::EmitIntegerRemainder(llvm::Value* lhs,
                                                      llvm::Value* rhs,
                                                      bool is_signed) {
  // Integer remainder overflow behavior:
  //
  //   X % 0 == X
  //   INT_SMIN %s -1 == 0
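  // Handled with the same safe-divisor trick as EmitIntegerDivide above.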

  if (!is_signed) {
    llvm::Value* urem_is_unsafe = IsZero(rhs);
    llvm::Value* safe_rhs = Select(urem_is_unsafe, GetOne(lhs->getType()), rhs);
    llvm::Value* safe_rem = URem(lhs, safe_rhs);
    return Select(urem_is_unsafe, lhs, safe_rem);
  }

  llvm::Value* has_zero_divisor = IsZero(rhs);
  llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs);
  llvm::Value* srem_is_unsafe = Or(has_int_min_overflow, has_zero_divisor);
  llvm::Value* safe_rhs = Select(srem_is_unsafe, GetOne(lhs->getType()), rhs);
  llvm::Value* safe_rem = SRem(lhs, safe_rhs);

  return Select(
      has_zero_divisor, lhs,
      Select(has_int_min_overflow, GetZero(lhs->getType()), safe_rem));
}

llvm::Value* ElementalIrEmitter::EmitIntegerPow(llvm::Value* base,
                                                llvm::Value* exponent,
                                                bool is_signed) {
  // Exponentiation by squaring:
  // https://en.wikipedia.org/wiki/Exponentiation_by_squaring
  // Six bits of the exponent suffice: any base other than 0 and +/-1
  // overflows a 64-bit integer once the exponent reaches 64, and 64 == 1 << 6.
  int bits = 6;
  llvm::Value* accumulator = llvm::ConstantInt::get(base->getType(), 1);
  llvm::Value* one = llvm::ConstantInt::get(exponent->getType(), 1);
  llvm::Value* zero = llvm::ConstantInt::get(exponent->getType(), 0);
  llvm::Value* original_base = base;
  llvm::Value* original_exponent = exponent;

  // Unroll the loop at compile time.
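  // Each iteration multiplies the accumulator by the current base if the
  // exponent's low bit is set, then squares the base and shifts the exponent
  // right. E.g. for exponent 13 (0b1101) the accumulator picks up the factors
  // base^1, base^4, and base^8.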
  for (int i = 0; i < bits; i++) {
    accumulator =
        b_->CreateSelect(b_->CreateICmpEQ(b_->CreateAnd(exponent, one), one),
                         b_->CreateMul(accumulator, base), accumulator);
    base = b_->CreateMul(base, base);
    exponent = b_->CreateLShr(exponent, 1);
  }
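  // For a negative exponent the unrolled loop's result is meaningless, so we
  // override it: base^negative is defined here as 1 when base == 1 and 0
  // otherwise.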
  return b_->CreateSelect(
      b_->CreateICmpSGE(original_exponent, zero), accumulator,
      b_->CreateSelect(b_->CreateICmpEQ(original_base, one), one, zero));
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value,
    bool is_signed) {
  switch (op->opcode()) {
    // TODO(jingyue): add the "nsw" attribute for signed types.
    case HloOpcode::kAdd:
      return Add(lhs_value, rhs_value);
    case HloOpcode::kSubtract:
      return Sub(lhs_value, rhs_value);
    case HloOpcode::kMultiply:
      return Mul(lhs_value, rhs_value);
    case HloOpcode::kDivide:
      return EmitIntegerDivide(lhs_value, rhs_value, is_signed);
    case HloOpcode::kRemainder:
      return EmitIntegerRemainder(lhs_value, rhs_value, is_signed);
    case HloOpcode::kCompare: {
      switch (op->comparison_direction()) {
        case ComparisonDirection::kEq:
          return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value,
                                         rhs_value, b_);
        case ComparisonDirection::kNe:
          return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value,
                                         rhs_value, b_);
        case ComparisonDirection::kLt:
          return llvm_ir::EmitComparison(
              is_signed ? llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT,
              lhs_value, rhs_value, b_);
        case ComparisonDirection::kGt:
          return llvm_ir::EmitComparison(
              is_signed ? llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT,
              lhs_value, rhs_value, b_);
        case ComparisonDirection::kLe:
          return llvm_ir::EmitComparison(
              is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE,
              lhs_value, rhs_value, b_);
        case ComparisonDirection::kGe:
          return llvm_ir::EmitComparison(
              is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE,
              lhs_value, rhs_value, b_);
      }
    }
    case HloOpcode::kMinimum:
      return EmitIntegralMin(lhs_value, rhs_value, is_signed);
    case HloOpcode::kMaximum:
      return EmitIntegralMax(lhs_value, rhs_value, is_signed);
    case HloOpcode::kAnd:
      return And(lhs_value, rhs_value);
    case HloOpcode::kOr:
      return Or(lhs_value, rhs_value);
    case HloOpcode::kPower:
      return EmitIntegerPow(lhs_value, rhs_value, is_signed);
    case HloOpcode::kXor:
      return Xor(lhs_value, rhs_value);

    // Shifting out bits >= the number of bits in the type being shifted
    // produces a poison value in LLVM which is basically "deferred undefined
    // behavior" -- doing something observable with such a value precipitates
    // UB. We replace the poison value with a constant to avoid this deferred
    // UB.
    case HloOpcode::kShiftRightArithmetic:
      return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
                                      AShr(lhs_value, rhs_value),
                                      /*saturate_to_sign_bit=*/true);
    case HloOpcode::kShiftLeft:
      return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
                                      Shl(lhs_value, rhs_value),
                                      /*saturate_to_sign_bit=*/false);
    case HloOpcode::kShiftRightLogical:
      return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
                                      LShr(lhs_value, rhs_value),
                                      /*saturate_to_sign_bit=*/false);
    default:
      return Unimplemented("binary integer op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}

llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value,
                                                 llvm::Value* rhs_value,
                                                 bool is_signed) {
  return Select(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE
                                         : llvm::ICmpInst::ICMP_UGE,
                               lhs_value, rhs_value),
                lhs_value, rhs_value);
}

llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value,
                                                 llvm::Value* rhs_value,
                                                 bool is_signed) {
  return Select(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE
                                         : llvm::ICmpInst::ICMP_ULE,
                               lhs_value, rhs_value),
                lhs_value, rhs_value);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalSelect(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& index) {
  TF_ASSIGN_OR_RETURN(llvm::Value * pred_value,
                      operand_to_generator.at(hlo->operand(0))(index));
  TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value,
                      operand_to_generator.at(hlo->operand(1))(index));
  TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value,
                      operand_to_generator.at(hlo->operand(2))(index));
  return Select(Trunc(pred_value, b_->getInt1Ty()), on_true_value,
                on_false_value);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalClamp(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& index) {
  TF_ASSIGN_OR_RETURN(llvm::Value * min_value,
                      operand_to_generator.at(hlo->operand(0))(index));
  TF_ASSIGN_OR_RETURN(llvm::Value * arg_value,
                      operand_to_generator.at(hlo->operand(1))(index));
  TF_ASSIGN_OR_RETURN(llvm::Value * max_value,
                      operand_to_generator.at(hlo->operand(2))(index));
  PrimitiveType prim_type = hlo->shape().element_type();
  if (primitive_util::IsFloatingPointType(prim_type)) {
    return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value, ""), "");
  } else if (primitive_util::IsIntegralType(prim_type)) {
    bool is_signed = primitive_util::IsSignedIntegralType(prim_type);
    return EmitIntegralMin(
        max_value, EmitIntegralMax(min_value, arg_value, is_signed), is_signed);
  } else {
    return Unimplemented("Clamp unimplemented for %s",
                         PrimitiveType_Name(prim_type));
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& target_index) {
  const int64 concat_dim = hlo->dimensions(0);
  auto source_index = target_index;

  llvm::BasicBlock* init_block = b_->GetInsertBlock();

  // A terminator should be present iff we're emitting code
  // into the middle (as opposed to the end) of a basic block.
  CHECK_EQ(b_->GetInsertPoint() == init_block->end(),
           init_block->getTerminator() == nullptr);

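  // Emitted control flow, sketched: starting in init_block, one bounds check
  // is chained per operand position; each check either branches to the basic
  // block emitted for that operand (shared between repeated operands, with a
  // PHI for the source index offset) or falls through to the next check.
  // Every operand block branches to exit_block, where the 'output' PHI merges
  // the per-operand values.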
  llvm::BasicBlock* exit_block;
  if (b_->GetInsertPoint() == init_block->end()) {
    exit_block = llvm_ir::CreateBasicBlock(
        /*insert_before=*/nullptr, IrName(hlo, "merge"), b_);
  } else {
    exit_block =
        init_block->splitBasicBlock(b_->GetInsertPoint(), IrName(hlo, "merge"));
    init_block->getTerminator()->eraseFromParent();
  }

  llvm_ir::SetToFirstInsertPoint(exit_block, b_);
  llvm::PHINode* output =
      PHI(llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
          hlo->operands().size());
  auto prior_insert_point = b_->GetInsertPoint();

  b_->SetInsertPoint(init_block);

  // Assign a unique id for each *different* operand, and count how often each
  // operand is used. If all operands are different, the usage count will be 1
  // for each operand.
  absl::flat_hash_map<const HloInstruction*, int64> to_unique_operand_id;
  std::vector<int64> operand_usage_count;
  for (const auto* operand : hlo->operands()) {
    if (to_unique_operand_id.contains(operand)) {
      ++operand_usage_count[to_unique_operand_id[operand]];
    } else {
      int64 unique_operand_id = to_unique_operand_id.size();
      to_unique_operand_id[operand] = unique_operand_id;
      operand_usage_count.push_back(1);
    }
  }

  // To avoid emitting the same operand more than once, we create one basic
  // block for each *distinct* operand, containing a PHI node that merges the
  // possible source index inputs.
  std::vector<llvm::BasicBlock*> emit_operand_blocks(
      to_unique_operand_id.size(), nullptr);
  std::vector<llvm::PHINode*> source_index_phis(to_unique_operand_id.size(),
                                                nullptr);
  for (const auto* operand : hlo->operands()) {
    int64 operand_id = to_unique_operand_id[operand];
    if (emit_operand_blocks[operand_id] != nullptr) {
      continue;
    }

    emit_operand_blocks[operand_id] = llvm_ir::CreateBasicBlock(
        exit_block, StrCat("concat_index_from_operand_id", operand_id), b_);
    auto saved_insert_point = b_->GetInsertPoint();
    llvm_ir::SetToFirstInsertPoint(emit_operand_blocks[operand_id], b_);
    source_index_phis[operand_id] =
        PHI(source_index.GetType(), operand_usage_count[operand_id]);
    std::vector<llvm::Value*> operand_multi_index = source_index.multidim();
    operand_multi_index[concat_dim] =
        NSWSub(operand_multi_index[concat_dim], source_index_phis[operand_id]);

    // Create the terminator of the block before calling operand generators,
    // because they require non-degenerate basic blocks.
    b_->SetInsertPoint(llvm::BranchInst::Create(
        exit_block, /*InsertAtEnd=*/emit_operand_blocks[operand_id]));
    llvm_ir::IrArray::Index operand_index(operand_multi_index, operand->shape(),
                                          source_index.GetType());
    TF_ASSIGN_OR_RETURN(llvm::Value * value,
                        operand_to_generator.at(operand)(operand_index));
    output->addIncoming(value, b_->GetInsertBlock());
    b_->SetInsertPoint(init_block, saved_insert_point);
  }

  int64 concat_dim_size = 0;
  for (int64 operand_idx = 0; operand_idx < hlo->operand_count();
       ++operand_idx) {
    const HloInstruction* operand = hlo->operand(operand_idx);
    auto false_block = llvm_ir::CreateBasicBlock(
        exit_block, StrCat("concat_index_not_from_operand", operand_idx), b_);
    int64 operand_id = to_unique_operand_id[operand];
    source_index_phis[operand_id]->addIncoming(
        source_index.GetConstantWithIndexType(concat_dim_size),
        b_->GetInsertBlock());
    concat_dim_size += operand->shape().dimensions(concat_dim);
    CondBr(ICmpULT(source_index[concat_dim],
                   source_index.GetConstantWithIndexType(concat_dim_size)),
           emit_operand_blocks[operand_id], false_block);

    // Subtract the size of the concat dimension of the current operand
    // from the source index.
    b_->SetInsertPoint(false_block);
  }

  Unreachable();
  b_->SetInsertPoint(exit_block, prior_insert_point);
  return output;
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& index) {
  // Emit IR to read dynamic start indices from hlo->operand(1).
  const HloInstruction* input_hlo = hlo->operand(0);
  const int64 rank = input_hlo->shape().rank();
  // Use the same index type for all tensor accesses in the same kernel.
  llvm::Type* index_type = index.GetType();
  std::vector<llvm::Value*> slice_start_multi_index(rank);
  for (int64 i = 0; i < rank; ++i) {
    auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
      return llvm::ConstantInt::get(index_type, c);
    };
    llvm_ir::IrArray::Index zero_index(index_type);
    TF_ASSIGN_OR_RETURN(
        llvm::Value * start_index_value,
        operand_to_generator.at(hlo->operand(1 + i))(zero_index));

    // Clamp the start index so that the sliced portion fits in the operand:
    //   start_index = clamp(start_index, 0, operand_dim_size - output_dim_size)
    start_index_value = SExtOrTrunc(start_index_value, index_type);
    int64 largest_valid_start_index =
        input_hlo->shape().dimensions(i) - hlo->shape().dimensions(i);
    CHECK_GE(largest_valid_start_index, 0);

    bool is_signed = ShapeUtil::ElementIsSigned(hlo->operand(1)->shape());
    start_index_value = EmitIntegralMin(
        index_typed_const(largest_valid_start_index),
        EmitIntegralMax(index_typed_const(0), start_index_value, is_signed),
        is_signed);

    start_index_value->setName(IrName(hlo, StrCat("start_idx", i)));
    slice_start_multi_index[i] = start_index_value;
  }

  std::vector<llvm::Value*> input_multi_index(rank);
  for (int64 i = 0; i < rank; ++i) {
    // Emit IR which computes:
    //   input_index = start_index + offset_index
    input_multi_index[i] = Add(slice_start_multi_index[i], index[i]);
  }
  llvm_ir::IrArray::Index input_index(input_multi_index, input_hlo->shape(),
                                      index_type);
  return operand_to_generator.at(input_hlo)(input_index);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& index) {
  const Shape& operand_shape = hlo->operand(0)->shape();
  const Shape& indices_shape = hlo->operand(1)->shape();
  const Shape& output_shape = hlo->shape();

  const GatherDimensionNumbers& dim_numbers = hlo->gather_dimension_numbers();

  const llvm_ir::ElementGenerator& operand_generator =
      operand_to_generator.at(hlo->operand(0));
  const llvm_ir::ElementGenerator& indices_generator =
      operand_to_generator.at(hlo->operand(1));

  llvm::Type* index_type = index.GetType();
  // This is the index into `operand` that holds the element we want to
  // generate.
  std::vector<llvm::Value*> operand_multi_index;

  // First copy in the window indices to operand_index. Also collect a mapping
  // from operand dimension to output window dimension. Elided window dimensions
  // map to -1.
  std::vector<int64> operand_to_output_dim(operand_shape.dimensions_size(), -1);
  for (int64 i = 0, e = operand_shape.dimensions_size(), operand_index_dim = 0;
       i < e; i++) {
    if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
      operand_multi_index.push_back(index.GetConstantWithIndexType(0));
    } else {
      int64 output_window_dim = dim_numbers.offset_dims(operand_index_dim++);
      operand_to_output_dim[i] = output_window_dim;
      operand_multi_index.push_back(index[output_window_dim]);
    }
  }

  // This is the index of the index vector in the start_indices tensor.
  std::vector<llvm::Value*> gather_index_index_components;
  {
    for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) {
      if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) {
        gather_index_index_components.push_back(index[i]);
      }
    }

    if (gather_index_index_components.size() !=
        indices_shape.dimensions_size()) {
      gather_index_index_components.insert(
          gather_index_index_components.begin() +
              dim_numbers.index_vector_dim(),
          nullptr);
    }
  }

  auto add_to_operand_index = [&](llvm::Value* index_component, int64 dim) {
    auto index_component_type = index_component->getType();
    auto extended_type = index_component_type->getScalarSizeInBits() >=
                                 index_type->getScalarSizeInBits()
                             ? index_component_type
                             : index_type;
    // Possibly extend the value at the beginning to ensure clamping logic stays
    // in bounds.
    auto maybe_extended_index =
        index_component_type != extended_type
            ? b_->CreateSExt(index_component, extended_type)
            : index_component;
    int64 operand_dim = dim_numbers.start_index_map(dim);
    int64 output_dim = operand_to_output_dim[operand_dim];
    // If 'output_dim' is -1, it means 'operand_dim' is an elided window dim.
    // This means we set the iteration index to 0, so for the purpose of the
    // following calculations we can consider the output dimension size to be 1.
    int64 output_dim_size =
        output_dim == -1 ? 1 : output_shape.dimensions(output_dim);
    int64 largest_valid_start_index =
        operand_shape.dimensions(operand_dim) - output_dim_size;
    CHECK_GE(largest_valid_start_index, 0);

    // Clamp the gather index so that the gather region fits in the operand.
    //   clamped_index =
    //       clamp(gather_dim_component_extended, 0, largest_valid_start_index);
    bool is_signed = ShapeUtil::ElementIsSigned(indices_shape);
    auto clamped_index = EmitIntegralMin(
        llvm::ConstantInt::get(extended_type, largest_valid_start_index),
        EmitIntegralMax(llvm::ConstantInt::get(extended_type, 0),
                        maybe_extended_index, is_signed),
        is_signed);
    // Truncate at the end to the optimized index size.
    auto maybe_truncated_clamped_index = extended_type != index_type
                                             ? Trunc(clamped_index, index_type)
                                             : clamped_index;

    operand_multi_index[operand_dim] =
        Add(operand_multi_index[operand_dim], maybe_truncated_clamped_index);
  };

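  // start_indices can be laid out two ways: if index_vector_dim equals the
  // rank of start_indices, every scalar entry acts as an implicit index
  // vector of size 1; otherwise we iterate over the index_vector_dim
  // dimension and feed each component to add_to_operand_index.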

  if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) {
    IrArray::Index gather_index_index(gather_index_index_components,
                                      indices_shape, index_type);
    TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
                        indices_generator(gather_index_index));
    add_to_operand_index(gather_dim_component, 0);
  } else {
    int64 index_vector_size =
        indices_shape.dimensions(dim_numbers.index_vector_dim());
    for (int64 i = 0; i < index_vector_size; i++) {
      gather_index_index_components[dim_numbers.index_vector_dim()] =
          index.GetConstantWithIndexType(i);
      IrArray::Index gather_index_index(gather_index_index_components,
                                        indices_shape, index_type);
      TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
                          indices_generator(gather_index_index));
      add_to_operand_index(gather_dim_component, i);
    }
  }
  IrArray::Index operand_index(operand_multi_index, operand_shape, index_type);
  return operand_generator(operand_index);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& index) {
  const HloInstruction* input_hlo = hlo->operand(0);
  const HloInstruction* update_hlo = hlo->operand(1);
  const HloInstruction* start_hlo = hlo->operand(2);
  // Calculate slice start/end indices.
  const int64 rank = input_hlo->shape().rank();
  std::vector<llvm::Value*> slice_start_multi_index(rank);
  std::vector<llvm::Value*> slice_limit_multi_index(rank);
  // slice_intersection ANDs together, over all dimensions, the conditions
  // under which 'input' is overwritten by 'update'.
  llvm::Value* slice_intersection = b_->getTrue();

  for (int64 i = 0; i < rank; ++i) {
    llvm::Type* index_type = index[0]->getType();
    auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
      return llvm::ConstantInt::get(index_type, c);
    };

    llvm_ir::IrArray::Index zero_index(index_type);
    TF_ASSIGN_OR_RETURN(
        llvm::Value * start_index_value,
        operand_to_generator.at(hlo->operand(2 + i))(zero_index));

    // Clamp the start index so that the update region fits in the operand.
    //   start_index = clamp(start_index, 0, input_dim_size - update_dim_size)
    start_index_value = SExtOrTrunc(start_index_value, index_type);
    llvm::Value* update_dim_size =
        index_typed_const(update_hlo->shape().dimensions(i));
    int64 largest_valid_start_index =
        input_hlo->shape().dimensions(i) - update_hlo->shape().dimensions(i);
    CHECK_GE(largest_valid_start_index, 0);

    bool is_signed = ShapeUtil::ElementIsSigned(start_hlo->shape());
    start_index_value = EmitIntegralMin(
        index_typed_const(largest_valid_start_index),
        EmitIntegralMax(index_typed_const(0), start_index_value, is_signed),
        is_signed);

    start_index_value->setName(IrName(hlo, StrCat("start_idx", i)));
    slice_start_multi_index[i] = start_index_value;
    slice_limit_multi_index[i] =
        Add(slice_start_multi_index[i], update_dim_size);

    slice_intersection =
        And(slice_intersection, ICmpSGE(index[i], slice_start_multi_index[i]),
            "slice_intersection");
    slice_intersection =
        And(slice_intersection, ICmpSLT(index[i], slice_limit_multi_index[i]),
            "slice_intersection");
  }

  // Emit:
  //   if (slice_intersection) -> return data from 'update'.
  //   else                    -> return data from 'input'.
  llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
      llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
      "ret_value_addr", b_);
  llvm_ir::LlvmIfData if_data =
      llvm_ir::EmitIfThenElse(slice_intersection, "slice_intersection", b_);

  // Handle true BB (return data from 'update').
  SetToFirstInsertPoint(if_data.true_block, b_);
  // Compute update index for the intersection case.
  std::vector<llvm::Value*> update_multi_index(rank);
  for (int64 i = 0; i < rank; ++i) {
    update_multi_index[i] = Sub(index[i], slice_start_multi_index[i]);
  }
  llvm_ir::IrArray::Index update_index(update_multi_index, update_hlo->shape(),
                                       index.GetType());
  TF_ASSIGN_OR_RETURN(llvm::Value * true_value,
                      operand_to_generator.at(update_hlo)(update_index));
  Store(true_value, ret_value_addr);

  // Handle false BB (return data from 'input').
  SetToFirstInsertPoint(if_data.false_block, b_);
  TF_ASSIGN_OR_RETURN(llvm::Value * false_value,
                      operand_to_generator.at(input_hlo)(index));
  Store(false_value, ret_value_addr);

  SetToFirstInsertPoint(if_data.after_block, b_);
  return Load(ret_value_addr);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalPad(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& padded_index) {
  std::vector<llvm::Value*> multi_index = padded_index.multidim();
  llvm::Value* in_bounds = b_->getTrue();
  for (size_t i = 0; i < multi_index.size(); ++i) {
    auto index_typed_const = [=](int64 n) {
      return padded_index.GetConstantWithIndexType(n);
    };
    const auto& pad_dim = hlo->padding_config().dimensions(i);
    multi_index[i] =
        Sub(multi_index[i], index_typed_const(pad_dim.edge_padding_low()));
    in_bounds = And(in_bounds, ICmpSGE(multi_index[i], index_typed_const(0)),
                    "in_bounds");
    in_bounds =
        And(in_bounds,
            ICmpEQ(index_typed_const(0),
                   URem(multi_index[i],
                        index_typed_const(pad_dim.interior_padding() + 1))),
            "in_bounds");
    multi_index[i] =
        SDiv(multi_index[i], index_typed_const(pad_dim.interior_padding() + 1));
    in_bounds =
        And(in_bounds,
            ICmpSLT(multi_index[i],
                    index_typed_const(hlo->operand(0)->shape().dimensions(i))),
            "in_bounds");
  }
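  // Example: edge_padding_low = 1 and interior_padding = 1 turn [a, b, c]
  // into [p, a, p, b, p, c]. Output index 3 maps back to (3 - 1) / 2 = 1,
  // i.e. 'b', while positions whose offset is not a multiple of
  // interior_padding + 1 fail the URem check and read the padding value.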

  // if (in_bounds) {
  //   ret_value = operand0[index];  // source
  // } else {
  //   ret_value = *operand1;        // padding
  // }
  llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
      llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
      "pad_result_addr", b_);
  llvm_ir::LlvmIfData if_data =
      llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_);
  SetToFirstInsertPoint(if_data.true_block, b_);
  llvm_ir::IrArray::Index index(multi_index, hlo->operand(0)->shape(),
                                padded_index.GetType());
  TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
                      operand_to_generator.at(hlo->operand(0))(index));
  Store(operand_value, ret_value_addr);

  SetToFirstInsertPoint(if_data.false_block, b_);
  TF_ASSIGN_OR_RETURN(llvm::Value * padding_value,
                      operand_to_generator.at(hlo->operand(1))(
                          IrArray::Index(index.GetType())));
  Store(padding_value, ret_value_addr);

  SetToFirstInsertPoint(if_data.after_block, b_);
  // Don't create phi(operand_value, padding_value) here, because invoking
  // operand_to_generator may create new basic blocks, making the parent
  // of operand_value or padding_value no longer a predecessor of
  // if_data.after_block.
  return Load(ret_value_addr);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& dot_result_index) {
  auto lhs_generator = operand_to_generator.at(hlo->operand(0));
  auto rhs_generator = operand_to_generator.at(hlo->operand(1));

  const DotDimensionNumbers& dim_numbers = hlo->dot_dimension_numbers();
  int64 lhs_contracting_dim = dim_numbers.lhs_contracting_dimensions(0);
  int64 rhs_contracting_dim = dim_numbers.rhs_contracting_dimensions(0);

  int64 contracted_dim_size =
      hlo->operand(0)->shape().dimensions(lhs_contracting_dim);
  int64 lhs_dims = hlo->operand(0)->shape().dimensions_size();
  int64 rhs_dims = hlo->operand(1)->shape().dimensions_size();

  llvm::Type* index_type = dot_result_index.GetType();
  auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
    return llvm::ConstantInt::get(index_type, c);
  };

  std::unique_ptr<llvm_ir::ForLoop> inner_loop = llvm_ir::ForLoop::EmitForLoop(
      IrName(hlo, "inner"), index_typed_const(0),
      index_typed_const(contracted_dim_size), index_typed_const(1), b_);

  SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), b_);
  PrimitiveType primitive_type = hlo->shape().element_type();
  llvm::Type* primitive_type_llvm =
      llvm_ir::PrimitiveTypeToIrType(primitive_type, module_);
  llvm::Value* accumulator_alloca =
      llvm_ir::EmitAllocaAtFunctionEntry(primitive_type_llvm, "dot_acc", b_);
  Store(llvm::Constant::getNullValue(primitive_type_llvm), accumulator_alloca);

  SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), b_);

  // This is the inner reduction loop for a dot operation that produces
  // one element in the output. If the operands to the dot operation have
  // shapes [A,B,C,T] and [D,T,E], the result has a shape [A,B,C,D,E].
  // Given an output index [a,b,c,d,e] in the result, we compute:
  //   sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T))

  std::vector<llvm::Value*> lhs_multi_index, rhs_multi_index;
  for (int64 i = 0; i < lhs_dims - 1; i++) {
    lhs_multi_index.push_back(dot_result_index[i]);
  }
  lhs_multi_index.insert(lhs_multi_index.begin() + lhs_contracting_dim,
                         inner_loop->GetIndVarValue());
  IrArray::Index lhs_index(lhs_multi_index, hlo->operand(0)->shape(),
                           index_type);

  int64 num_batch_dims = dim_numbers.rhs_batch_dimensions_size();
  for (int64 i = 0; i < num_batch_dims; i++) {
    rhs_multi_index.push_back(
        dot_result_index[dim_numbers.rhs_batch_dimensions(i)]);
  }
  for (int64 i = 0; i < rhs_dims - 1 - num_batch_dims; i++) {
    rhs_multi_index.push_back(dot_result_index[lhs_dims - 1 + i]);
  }
  rhs_multi_index.insert(rhs_multi_index.begin() + rhs_contracting_dim,
                         inner_loop->GetIndVarValue());
  IrArray::Index rhs_index(rhs_multi_index, hlo->operand(1)->shape(),
                           index_type);

  llvm::Value* current_accumulator = Load(accumulator_alloca);
  TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index));
  TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index));
  llvm::Value* next_accumulator =
      EmitMulAdd(lhs_value, rhs_value, current_accumulator, primitive_type);
  Store(next_accumulator, accumulator_alloca);

  SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), b_);
  return Load(accumulator_alloca);
}

llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) {
  switch (hlo->opcode()) {
    case HloOpcode::kAbs:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kCeil:
    case HloOpcode::kClz:
    case HloOpcode::kConvert:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCos:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kNegate:
    case HloOpcode::kNot:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kReal:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSqrt:
    case HloOpcode::kCbrt:
    case HloOpcode::kTanh:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
                            operand_to_generator.at(hlo->operand(0))(index));
        return EmitUnaryOp(hlo, operand_value);
      };
    case HloOpcode::kAdd:
    case HloOpcode::kAnd:
    case HloOpcode::kAtan2:
    case HloOpcode::kCompare:
    case HloOpcode::kComplex:
    case HloOpcode::kDivide:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kPower:
    case HloOpcode::kRemainder:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kSubtract:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        const HloInstruction* lhs = hlo->operand(0);
        const HloInstruction* rhs = hlo->operand(1);
        TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value,
                            operand_to_generator.at(lhs)(index));
        TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value,
                            operand_to_generator.at(rhs)(index));
        return EmitBinaryOp(hlo, lhs_value, rhs_value);
      };
    case HloOpcode::kSelect:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        return EmitElementalSelect(hlo, operand_to_generator, index);
      };
    case HloOpcode::kClamp:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        return EmitElementalClamp(hlo, operand_to_generator, index);
      };
    case HloOpcode::kReducePrecision:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
                            operand_to_generator.at(hlo->operand(0))(index));
        return EmitReducePrecision(hlo, operand_value);
      };
    case HloOpcode::kConcatenate:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index target_index) -> StatusOr<llvm::Value*> {
        return EmitElementalConcatenate(hlo, operand_to_generator,
                                        target_index);
      };
    case HloOpcode::kReverse:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
        const HloInstruction* operand = hlo->operand(0);
        std::vector<llvm::Value*> source_multi_index = target_index.multidim();
        for (int64 dim : hlo->dimensions()) {
          source_multi_index[dim] = Sub(target_index.GetConstantWithIndexType(
                                            hlo->shape().dimensions(dim) - 1),
                                        target_index[dim]);
        }
        llvm_ir::IrArray::Index source_index(
            source_multi_index, operand->shape(), target_index.GetType());
        return operand_to_generator.at(operand)(source_index);
      };
    case HloOpcode::kBroadcast:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
        const HloInstruction* operand = hlo->operand(0);
        // The `dimensions` member of the broadcast instruction maps from
        // input dimensions to output dimensions.
        return operand_to_generator.at(operand)(
            target_index.SourceIndexOfBroadcast(hlo->shape(), operand->shape(),
                                                hlo->dimensions(), b_));
      };
    case HloOpcode::kIota:
      return [this, hlo](
                 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
        auto* iota = Cast<HloIotaInstruction>(hlo);
        PrimitiveType element_type = iota->shape().element_type();
        IrArray::Index elem_index =
            iota->shape().rank() > 1
                ? target_index.SourceIndexOfBroadcast(
                      iota->shape(),
                      ShapeUtil::MakeShapeWithDescendingLayout(
                          element_type,
                          {iota->shape().dimensions(iota->iota_dimension())}),
                      {iota->iota_dimension()}, b_)
                : target_index;
        llvm::Value* elem_index_linear = elem_index.linear();
        if (elem_index_linear == nullptr) {
          std::vector<int64> iota_bound = {
              iota->shape().dimensions(iota->iota_dimension())};
          elem_index_linear = elem_index.Linearize(iota_bound, b_);
        }
        Shape component_shape =
            ShapeUtil::ElementIsComplex(iota->shape())
                ? ShapeUtil::ComplexComponentShape(iota->shape())
                : iota->shape();
        PrimitiveType component_element_type = component_shape.element_type();
        llvm::Value* iota_result;
        if (primitive_util::IsIntegralType(component_element_type)) {
          iota_result = b_->CreateIntCast(
              elem_index_linear,
              llvm_ir::PrimitiveTypeToIrType(component_element_type, module_),
              /*isSigned=*/false);
        } else {
          TF_RET_CHECK(
              primitive_util::IsFloatingPointType(component_element_type))
              << component_element_type;
          llvm::Type* float_ir_type;
          if (component_element_type == BF16) {
            float_ir_type = llvm_ir::PrimitiveTypeToIrType(F32, module_);
          } else {
            float_ir_type =
                llvm_ir::PrimitiveTypeToIrType(component_element_type, module_);
          }
          llvm::Value* float_val =
              b_->CreateUIToFP(elem_index_linear, float_ir_type);
          if (component_element_type == BF16) {
            TF_ASSIGN_OR_RETURN(iota_result, EmitF32ToBF16(float_val, b_));
          } else {
            iota_result = float_val;
          }
        }
        if (ShapeUtil::ElementIsComplex(iota->shape())) {
          return EmitComposeComplex(iota, iota_result, nullptr);
        } else {
          return iota_result;
        }
      };
    case HloOpcode::kSlice:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        IrArray::Index sliced_index = index.SourceIndexOfSlice(
            /*operand_shape=*/hlo->operand(0)->shape(),
            /*starts=*/hlo->slice_starts(),
            /*strides=*/hlo->slice_strides(), /*builder=*/b_);
        return operand_to_generator.at(hlo->operand(0))(sliced_index);
      };
    case HloOpcode::kDynamicSlice:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        return EmitElementalDynamicSlice(hlo, operand_to_generator, index);
      };

    case HloOpcode::kGather:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        return EmitElementalGather(hlo, operand_to_generator, index);
      };
    case HloOpcode::kDynamicUpdateSlice:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        return EmitElementalDynamicUpdateSlice(hlo, operand_to_generator,
                                               index);
      };
    case HloOpcode::kBitcast:
      CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()),
               ShapeUtil::ElementsIn(hlo->operand(0)->shape()));
      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
        const HloInstruction* operand = hlo->operand(0);
        return operand_to_generator.at(operand)(
            index.SourceIndexOfBitcast(hlo->shape(), operand->shape(), b_));
      };
    case HloOpcode::kReshape:
      CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()),
               ShapeUtil::ElementsIn(hlo->operand(0)->shape()));
      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
        const HloInstruction* operand = hlo->operand(0);
        return operand_to_generator.at(operand)(
            index.SourceIndexOfReshape(hlo->shape(), operand->shape(), b_));
      };
    case HloOpcode::kCopy:
      return [hlo, &operand_to_generator](
                 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
        IrArray::Index source_index(target_index.multidim(),
                                    hlo->operand(0)->shape(),
                                    target_index.GetType());
        TF_ASSIGN_OR_RETURN(
            llvm::Value * operand_value,
            operand_to_generator.at(hlo->operand(0))(source_index));
        return operand_value;
      };
    case HloOpcode::kTranspose:
      return [this, hlo,
              &operand_to_generator](const IrArray::Index& target_index) {
        return operand_to_generator.at(hlo->operand(0))(
            target_index.SourceIndexOfTranspose(
                hlo->shape(), hlo->operand(0)->shape(), hlo->dimensions()));
      };
    case HloOpcode::kPad:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& padded_index) -> StatusOr<llvm::Value*> {
        return EmitElementalPad(hlo, operand_to_generator, padded_index);
      };

    case HloOpcode::kDot:
      return [this, hlo,
              &operand_to_generator](const IrArray::Index& dot_result_index)
                 -> StatusOr<llvm::Value*> {
        return EmitElementalDot(hlo, operand_to_generator, dot_result_index);
      };
    case HloOpcode::kMap:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        std::vector<llvm::Value*> operands;
        for (int i = 0; i < hlo->operand_count(); i++) {
          TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
                              operand_to_generator.at(hlo->operand(i))(index));
          operands.push_back(operand_value);
        }
        return EmitElementalMap(Cast<HloMapInstruction>(hlo), operands);
      };
    case HloOpcode::kReduceWindow:
      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
        auto reduce_window_instr = Cast<HloReduceWindowInstruction>(hlo);
        std::vector<llvm_ir::ElementGenerator> input_generators;
        for (const HloInstruction* instr :
             reduce_window_instr->input_arrays()) {
          input_generators.push_back(operand_to_generator.at(instr));
        }

        std::vector<llvm_ir::ElementGenerator> initial_value_generators;
        for (const HloInstruction* instr : reduce_window_instr->init_values()) {
          initial_value_generators.push_back(operand_to_generator.at(instr));
        }
        return EmitElementalReduceWindow(
            Cast<HloReduceWindowInstruction>(hlo), std::move(input_generators),
            std::move(initial_value_generators), index);
      };
    case HloOpcode::kReduce:
      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
        auto reduce_instr = Cast<HloReduceInstruction>(hlo);
        std::vector<llvm_ir::ElementGenerator> input_generators;
        for (const HloInstruction* instr : reduce_instr->inputs()) {
          input_generators.push_back(operand_to_generator.at(instr));
        }

        std::vector<llvm_ir::ElementGenerator> initial_value_generators;
        for (const HloInstruction* instr : reduce_instr->init_values()) {
          initial_value_generators.push_back(operand_to_generator.at(instr));
        }
        return EmitElementalReduce(reduce_instr, std::move(input_generators),
                                   std::move(initial_value_generators), index);
      };
    case HloOpcode::kConvolution:
      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
        return EmitConvolution(hlo, operand_to_generator, index);
      };
    default:
      return [hlo](const IrArray::Index& index) {
        return Unimplemented("Unhandled opcode for elemental IR emission: %s",
                             HloOpcodeString(hlo->opcode()));
      };
  }
}

llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) {
  return ExtractValue(value, {0});
}

llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) {
  return ExtractValue(value, {1});
}

llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op,
                                                    llvm::Value* real,
                                                    llvm::Value* imag) {
  auto cplx_type =
      llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
  auto complex =
      InsertValue(llvm::ConstantAggregateZero::get(cplx_type), real, {0});
  if (imag != nullptr) {
    complex = InsertValue(complex, imag, {1});
  }
  return complex;
}

llvm::Value* ElementalIrEmitter::EmitMulAdd(llvm::Value* lhs, llvm::Value* rhs,
                                            llvm::Value* accumulator,
                                            xla::PrimitiveType primitive_type) {
  if (primitive_util::IsComplexType(primitive_type)) {
2560 llvm::Value* product_real =
2561 FSub(FMul(EmitExtractReal(lhs), EmitExtractReal(rhs)),
2562 FMul(EmitExtractImag(lhs), EmitExtractImag(rhs)));
2563 llvm::Value* product_imag =
2564 FAdd(FMul(EmitExtractReal(lhs), EmitExtractImag(rhs)),
2565 FMul(EmitExtractImag(lhs), EmitExtractReal(rhs)));
2566 llvm::Value* next_accumulator = InsertValue(
2567 accumulator, FAdd(EmitExtractReal(accumulator), product_real), {0});
2568 return InsertValue(next_accumulator,
2569 FAdd(EmitExtractImag(accumulator), product_imag), {1});
2570 } else if (primitive_util::IsFloatingPointType(primitive_type)) {
2571 return FAdd(accumulator, FPCast(FMul(lhs, rhs), accumulator->getType()));
2572 } else if (primitive_type == PRED) {
2573 return Or(accumulator, And(lhs, rhs));
2574 }
2575 return Add(accumulator, Mul(lhs, rhs));
2576 }

StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalMap(
    const HloMapInstruction* map_instr,
    absl::Span<llvm::Value* const> elemental_operands) {
  TF_ASSIGN_OR_RETURN(
      std::vector<llvm::Value*> values,
      EmitThreadLocalCall(*map_instr->to_apply(), elemental_operands,
                          llvm_ir::IrName(map_instr)));
  // The mapped computation produces exactly one scalar per element.
  CHECK_EQ(values.size(), 1);
  return values[0];
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduceWindow(
    const HloReduceWindowInstruction* reduce_window,
    std::vector<llvm_ir::ElementGenerator> input_generators,
    std::vector<llvm_ir::ElementGenerator> initial_value_generators,
    const llvm_ir::IrArray::Index& index) {
  // Pseudocode:
  // for each output index O
  //   value = init_value
  //   for each window index W
  //     for each dimension i from 0 to rank - 1
  //       (input index I)[i] = O[i] * stride[i] + W[i] - pad_low[i]
  //     if I is in bounds of the input
  //       value = function(value, input[I])
  //   output[O] = value
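  //
  // Worked 1-D example (illustrative): input [a, b, c, d, e], window size 3,
  // stride 2, padding 1 on each side, no dilation. The three outputs read
  // input indices {-1, 0, 1}, {1, 2, 3} and {3, 4, 5}; out-of-bounds indices
  // are skipped, yielding f(init, a, b), f(init, b, c, d) and f(init, d, e).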
  int64 input_count = reduce_window->input_count();
  std::vector<PrimitiveType> operand_element_types;
  std::vector<llvm::Type*> accum_types;
  std::vector<llvm::Value*> accum_ptrs;
  for (int64 operand_index = 0; operand_index < input_count; ++operand_index) {
    auto operand = reduce_window->input_arrays()[operand_index];
    PrimitiveType operand_element_type = operand->shape().element_type();
    operand_element_types.push_back(operand_element_type);
    llvm::Type* llvm_type =
        llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_);
    accum_types.push_back(llvm_type);
    llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry(
        llvm_type, "reduce_window_accum_ptr", b_);
    accum_ptrs.push_back(accum_ptr);
    {
      auto initial_value_generator = initial_value_generators[operand_index];
      TF_ASSIGN_OR_RETURN(
          llvm::Value* const init_value,
          initial_value_generator(llvm_ir::IrArray::Index(index.GetType())));
      Store(init_value, accum_ptr);
    }
  }

  llvm::Type* index_type = index.GetType();
  auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
    return index.GetConstantWithIndexType(c);
  };

  const Window& window = reduce_window->window();
  llvm_ir::ForLoopNest loops(IrName(reduce_window), b_, index_type);
  std::vector<int64> window_size;
  for (const auto& dim : window.dimensions()) {
    window_size.push_back(dim.size());
  }
  const IrArray::Index window_index = loops.AddLoopsForShape(
      ShapeUtil::MakeShape(operand_element_types[0], window_size), "window");
  CHECK_EQ(window_index.size(), index.size());

  SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_);

  std::vector<llvm::Value*> input_multi_index(index.size());
  llvm::Value* in_bounds = b_->getInt1(true);
  for (size_t i = 0; i < index.size(); ++i) {
    llvm::Value* strided_index =
        NSWMul(index[i], index_typed_const(window.dimensions(i).stride()));
    input_multi_index[i] = NSWSub(
        NSWAdd(
            strided_index,
            NSWMul(window_index[i],
                   index_typed_const(window.dimensions(i).window_dilation()))),
        index_typed_const(window.dimensions(i).padding_low()));

    // Skip positions that fall in the holes introduced by base dilation:
    // only coordinates divisible by the base dilation correspond to actual
    // input elements.
    llvm::Value* dilation_condition =
        ICmpEQ(SRem(input_multi_index[i],
                    index_typed_const(window.dimensions(i).base_dilation())),
               index_typed_const(0));
    in_bounds = And(in_bounds, dilation_condition);

    // Map the dilated coordinate back to the undilated input index.
    input_multi_index[i] =
        SDiv(input_multi_index[i],
             index_typed_const(window.dimensions(i).base_dilation()));
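    // E.g. with base_dilation = 2, dilated positions 0, 2, 4 map to input
    // elements 0, 1, 2, and odd positions fall in holes.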

    // We must check whether 0 <= input_multi_index[i] < bound, as otherwise
    // we are in the padding and can skip the computation. This comparison is
    // equivalent to the unsigned comparison input_multi_index[i] < bound,
    // since a negative value wraps around to a large positive value.
    in_bounds = And(
        in_bounds,
        ICmpULT(input_multi_index[i],
                index_typed_const(
                    reduce_window->input_arrays()[0]->shape().dimensions(i))));
  }

  llvm_ir::LlvmIfData if_data =
      llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_);
  SetToFirstInsertPoint(if_data.true_block, b_);

  // We are not in the padding, so carry out the computation. The reducer
  // expects the current accumulator values first, followed by the new input
  // values, so `input_values` holds 2 * input_count entries.
  std::vector<llvm::Value*> input_values(reduce_window->operand_count());
  IrArray::Index input_index(
      input_multi_index, reduce_window->input_arrays()[0]->shape(), index_type);
  for (int64 operand_idx = 0; operand_idx < input_count; ++operand_idx) {
    TF_ASSIGN_OR_RETURN(llvm::Value * input_value,
                        input_generators[operand_idx](input_index));
    input_values[input_count + operand_idx] = input_value;
    input_values[operand_idx] = Load(accum_ptrs[operand_idx]);
  }
  TF_ASSIGN_OR_RETURN(std::vector<llvm::Value*> accum_values,
                      EmitThreadLocalCall(*reduce_window->to_apply(),
                                          input_values, "reducer_function"));

  for (int64 operand_idx = 0; operand_idx < accum_values.size();
       ++operand_idx) {
    Store(accum_values[operand_idx], accum_ptrs[operand_idx]);
  }

  SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_);
  return EmitAccumResult(accum_ptrs, accum_types,
                         reduce_window->shape().IsTuple());
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduce(
    const HloReduceInstruction* reduce,
    std::vector<llvm_ir::ElementGenerator> input_generators,
    std::vector<llvm_ir::ElementGenerator> initial_value_generators,
    const llvm_ir::IrArray::Index& index) {
  const Shape& out_shape = reduce->shape();
  bool is_variadic = !out_shape.IsArray();
  int accumulators_count = 1;
  if (is_variadic) {
    CHECK(out_shape.IsTuple());
    accumulators_count = out_shape.tuple_shapes_size();
  }

  absl::Span<const int64> reduced_dimensions(reduce->dimensions());

  std::vector<llvm::Value*> accumulator_addrs;
  std::vector<llvm::Type*> accumulator_types;
  llvm::Type* index_type = index.GetType();
  for (int i = 0; i < accumulators_count; i++) {
    const Shape& element_shape =
        is_variadic ? out_shape.tuple_shapes(i) : out_shape;
    PrimitiveType accumulator_type = element_shape.element_type();
    llvm::Type* accumulator_llvm_type =
        llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_);
    accumulator_types.push_back(accumulator_llvm_type);

    // Initialize an accumulator with init_value.
    llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
        accumulator_llvm_type, "accumulator_" + std::to_string(i), b());
    TF_ASSIGN_OR_RETURN(
        llvm::Value* const init_value,
        initial_value_generators[i](llvm_ir::IrArray::Index(index_type)));
    Store(init_value, accumulator_addr);
    accumulator_addrs.push_back(accumulator_addr);
  }

  // The enclosing loops go over all the target elements. Now we have to
  // compute the value of each target element, so we build a new loop nest
  // that iterates over all the reduction dimensions in the argument.
  // AddLoopsForShapeOnDimensions returns a multi-index in which an induction
  // Value* is placed for each reduced dimension and every other entry is
  // nullptr.
  llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), b(), index_type);
  const HloInstruction* arg = reduce->operand(0);
  std::vector<llvm::Value*> input_multi_index =
      loops.AddLoopsForShapeOnDimensions(arg->shape(), reduced_dimensions,
                                         "reduction_dim");

  SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b());

  // Build a full index for the input argument, using input_multi_index as the
  // base. In input_multi_index only the reduction dimensions are filled in. We
  // fill in the rest of the dimensions with induction Value*s taken from
  // 'index' which iterates over the target array. See the high-level
  // description in the XLA documentation for details.
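  //
  // E.g. when reducing a [4, 8] argument over dimension 1, the loop nest
  // fills input_multi_index = {nullptr, r}, and the enclosing target index
  // {o} supplies the missing entry, giving the full input index {o, r}.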
  auto it = index.begin();

  for (auto& i : input_multi_index) {
    if (i == nullptr) {
      i = *it++;
    }
  }
  CHECK(index.end() == it);
  llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
                                      index_type);

  std::vector<llvm::Value*> reduction_operands;
  for (llvm::Value* accum : accumulator_addrs) {
    llvm::Value* accum_value = Load(accum);
    reduction_operands.push_back(accum_value);
  }

  for (int i = 0; i < accumulators_count; i++) {
    TF_ASSIGN_OR_RETURN(llvm::Value* const input_element,
                        input_generators[i](input_index));
    reduction_operands.push_back(input_element);
  }

  TF_ASSIGN_OR_RETURN(
      std::vector<llvm::Value*> results,
      EmitThreadLocalCall(*reduce->to_apply(), reduction_operands,
                          "reduce_function"));

  CHECK_EQ(results.size(), accumulators_count);
  for (int i = 0; i < accumulators_count; i++) {
    Store(results[i], accumulator_addrs[i]);
  }
  SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b());
  return EmitAccumResult(accumulator_addrs, accumulator_types, is_variadic);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitAccumResult(
    absl::Span<llvm::Value* const> accumulator_addrs,
    llvm::ArrayRef<llvm::Type*> accumulator_types, bool is_variadic) {
  TF_RET_CHECK(accumulator_addrs.size() == accumulator_types.size());
  if (is_variadic) {
    // Emit a structure, as that is what the LoopEmitter expects.
    llvm::Value* returned_structure = llvm::UndefValue::get(
        llvm::StructType::get(b()->getContext(), accumulator_types));
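    // E.g. a variadic reduce producing (f32, s32) results in an LLVM
    // "{ float, i32 }" struct holding the two accumulator values.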
    for (int64 i = 0; i < accumulator_addrs.size(); i++) {
      llvm::Value* accumulator_value = Load(accumulator_addrs[i]);
      returned_structure =
          b()->CreateInsertValue(returned_structure, accumulator_value, i);
    }
    return returned_structure;
  } else {
    CHECK_EQ(accumulator_addrs.size(), 1);
    return Load(accumulator_addrs[0]);
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitConvolution(
    const HloInstruction* convolution,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& index) {
  const HloInstruction* lhs = convolution->operand(0);
  const auto& input_generator = operand_to_generator.at(lhs);
  const HloInstruction* rhs = convolution->operand(1);
  const auto& kernel_generator = operand_to_generator.at(rhs);
  const Window& window = convolution->window();

  const ConvolutionDimensionNumbers& dnums =
      convolution->convolution_dimension_numbers();
  int num_spatial_dims = dnums.output_spatial_dimensions_size();
  std::vector<llvm::Value*> output_spatial(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    output_spatial[i] = index[dnums.output_spatial_dimensions(i)];
  }
  llvm::Value* output_feature = index[dnums.output_feature_dimension()];
  llvm::Value* batch = index[dnums.output_batch_dimension()];

  // We will accumulate the products into this sum to calculate the output
  // entry at the given index.
  PrimitiveType lhs_element_type = lhs->shape().element_type();
  llvm::Type* lhs_llvm_type =
      llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_);
  // Upcast the accumulator to F32 from F16 for increased precision.
  llvm::Type* accumulator_type =
      lhs_element_type == F16 ? b_->getFloatTy() : lhs_llvm_type;
  llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
      accumulator_type, "convolution_sum_address", b_);
  llvm::Value* constant_zero = llvm::Constant::getNullValue(accumulator_type);
  Store(constant_zero, sum_address);

  llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), b_);
  std::vector<llvm::Value*> kernel_spatial(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    kernel_spatial[i] =
        loops
            .AddLoop(
                0, rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)),
                absl::StrCat("k", i))
            ->GetIndVarValue();
  }
  llvm::Value* input_feature =
      loops
          .AddLoop(0, lhs->shape().dimensions(dnums.input_feature_dimension()),
                   "iz")
          ->GetIndVarValue();

  SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_);

  // Calculate the spatial index in the input array, taking striding, dilation
  // and padding into account. An index in the padding will be out of the
  // bounds of the array.
  const auto calculate_input_index = [this](llvm::Value* output_index,
                                            llvm::Value* kernel_index,
                                            const WindowDimension& window_dim) {
    llvm::Value* strided_index =
        NSWMul(output_index, b_->getInt64(window_dim.stride()));
    llvm::Value* dilated_kernel_index =
        NSWMul(kernel_index, b_->getInt64(window_dim.window_dilation()));
    return NSWSub(NSWAdd(strided_index, dilated_kernel_index),
                  b_->getInt64(window_dim.padding_low()));
  };
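  // E.g. (illustrative) with stride 2, window_dilation 1 and padding_low 1,
  // output index 3 and kernel index 0 map to input index 3 * 2 + 0 * 1 - 1 =
  // 5.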
  std::vector<llvm::Value*> input_spatial(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    input_spatial[i] = calculate_input_index(
        output_spatial[i], kernel_spatial[i], window.dimensions(i));
  }

  // We need to check if 0 <= input dim < bound, as otherwise we are in the
  // padding and can skip the computation. That is equivalent to input dim <
  // bound as an *unsigned* comparison, since a negative value wraps around to
  // a large positive value. The input dim is dilated, so we need to dilate
  // the bound as well to match.

  // We also need to check that the input coordinates are not in one of the
  // holes created by base dilation.
  const auto not_in_hole = [&](llvm::Value* input_index, int64 base_dilation) {
    llvm::Value* remainder = SRem(input_index, b_->getInt64(base_dilation));
    return ICmpEQ(remainder, b_->getInt64(0));
  };

  llvm::Value* in_bounds_condition = b_->getInt1(true);
  for (int i = 0; i < num_spatial_dims; ++i) {
    llvm::ConstantInt* input_bound = b_->getInt64(window_util::DilatedBound(
        lhs->shape().dimensions(dnums.input_spatial_dimensions(i)),
        window.dimensions(i).base_dilation()));
    llvm::Value* dim_in_bound = ICmpULT(input_spatial[i], input_bound);
    llvm::Value* dim_not_in_hole =
        not_in_hole(input_spatial[i], window.dimensions(i).base_dilation());
    llvm::Value* dim_ok = And(dim_in_bound, dim_not_in_hole);
    in_bounds_condition = And(in_bounds_condition, dim_ok);
  }

  // Now we need to map the dilated base coordinates back to the actual
  // data indices on the lhs.
  const auto undilate = [&](llvm::Value* input_index, int64 base_dilation) {
    return SDiv(input_index, b_->getInt64(base_dilation));
  };
  for (int i = 0; i < num_spatial_dims; ++i) {
    input_spatial[i] =
        undilate(input_spatial[i], window.dimensions(i).base_dilation());
  }
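  // E.g. with base_dilation = 2, the surviving dilated coordinates 0, 2, 4
  // map back to lhs data indices 0, 1, 2.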

  llvm_ir::LlvmIfData if_data =
      llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", b_);
  SetToFirstInsertPoint(if_data.true_block, b_);

  // We are not in the padding, so carry out the computation.
  int num_dims = num_spatial_dims + 2;
  std::vector<llvm::Value*> input_multi_index(num_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    input_multi_index[dnums.input_spatial_dimensions(i)] = input_spatial[i];
  }
  input_multi_index[dnums.input_feature_dimension()] = input_feature;
  input_multi_index[dnums.input_batch_dimension()] = batch;

  std::vector<llvm::Value*> kernel_multi_index(num_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    kernel_multi_index[dnums.kernel_spatial_dimensions(i)] =
        window.dimensions(i).window_reversal()
            ? NSWSub(b_->getInt64(window.dimensions(i).size() - 1),
                     kernel_spatial[i])
            : kernel_spatial[i];
  }

  kernel_multi_index[dnums.kernel_input_feature_dimension()] = input_feature;
  kernel_multi_index[dnums.kernel_output_feature_dimension()] = output_feature;

  llvm_ir::IrArray::Index input_index(input_multi_index, lhs->shape(),
                                      b_->getInt64Ty());
  TF_ASSIGN_OR_RETURN(llvm::Value* const input_value,
                      input_generator(input_index));
  llvm_ir::IrArray::Index kernel_index(kernel_multi_index, rhs->shape(),
                                       b_->getInt64Ty());
  TF_ASSIGN_OR_RETURN(llvm::Value* const kernel_value,
                      kernel_generator(kernel_index));
  llvm::Value* sum = EmitMulAdd(input_value, kernel_value, Load(sum_address),
                                convolution->shape().element_type());
  Store(sum, sum_address);

  SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_);
  return FPCast(Load(sum_address), lhs_llvm_type);
}

// Evaluate polynomial using Horner's method.
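// Coefficients are given highest-degree first: for {a, b, c} this emits
// ((0 * x + a) * x + b) * x + c, i.e. a * x^2 + b * x + c.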
StatusOr<llvm::Value*> ElementalIrEmitter::EvaluatePolynomial(
    llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients) {
  llvm::Value* poly = llvm::ConstantFP::get(type, 0.0);
  for (const double c : coefficients) {
    poly = FAdd(FMul(poly, x), llvm::ConstantFP::get(type, c));
  }
  return poly;
}

}  // namespace xla