/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"

#include <algorithm>
#include <memory>
#include <string>
#include <vector>

// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

using absl::StrCat;
using llvm_ir::IrArray;
using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;

namespace {
int64 GlobalRandomValue() {
  static auto* mu = new tensorflow::mutex();
  static std::mt19937_64 rng{42};
  tensorflow::mutex_lock l(*mu);
  return rng();
}

llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
                                      int64 mantissa_bits,
                                      llvm::IRBuilder<>* b) {
  // Integer and float types for casting and constant generation.
  llvm::Type* float_type = x->getType();
  llvm::IntegerType* int_type = b->getInt32Ty();

  // Cast the input value to an integer for bitwise manipulation.
  llvm::Value* x_as_int = b->CreateBitCast(x, int_type);

  if (mantissa_bits < 23) {
    // Last remaining mantissa bit.
    const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits);

    // Compute rounding bias for round-to-nearest with ties to even. This is
    // equal to a base value of 0111... plus one bit if the last remaining
    // mantissa bit is 1.
    const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1;
    llvm::Value* x_last_mantissa_bit = b->CreateLShr(
        b->CreateAnd(x_as_int,
                     llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)),
        (23 - mantissa_bits));
    llvm::Value* x_rounding_bias =
        b->CreateAdd(x_last_mantissa_bit,
                     llvm::ConstantInt::get(int_type, base_rounding_bias));

    // Add rounding bias, and mask out truncated bits. Note that the case
    // where adding the rounding bias overflows into the exponent bits is
    // correct; the non-masked mantissa bits will all be zero, and the
    // exponent will be incremented by one.
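    //
    // Illustrative example (values not from the original): with
    // mantissa_bits = 2, last_mantissa_bit_mask = 1 << 21 and
    // base_rounding_bias = 0x0FFFFF. The discarded remainder occupies bits
    // 0..20, so "exactly half" is 0x100000. If the kept LSB (bit 21) is 0,
    // the bias is 0x0FFFFF and a half-way remainder does not carry (rounds
    // down, stays even); if it is 1, the bias is 0x100000 and the half-way
    // remainder carries into bit 21 (rounds up) -- i.e. ties go to even.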
    const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1);
    x_as_int = b->CreateAdd(x_as_int, x_rounding_bias);
    x_as_int = b->CreateAnd(x_as_int,
                            llvm::ConstantInt::get(int_type, truncation_mask));
  }

  if (exponent_bits < 8) {
    // Masks for f32 values.
    const uint32_t f32_sign_bit_mask = 1u << 31;
    const uint32_t f32_exp_bits_mask = 0xffu << 23;

    // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most-
    // significant bit -- is equal to 1.0f for all exponent sizes. Adding
    // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit-
    // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest
    // exponent (corresponding to 0.0f).
    //
    // Thus, the f32 exponent corresponding to the highest non-infinite
    // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32
    // exponent corresponding to the lowest exponent for a bit size of n is
    // (2^7-1) - 2^(n-1)-1.
    //
    // Note that we have already checked that exponent_bits >= 1.
    const uint32_t f32_exponent_bias = (1 << 7) - 1;
    const uint32_t reduced_exponent_bias = (1 << (exponent_bits - 1)) - 1;
    const uint32_t reduced_max_exponent =
        f32_exponent_bias + reduced_exponent_bias;
    const uint32_t reduced_min_exponent =
        f32_exponent_bias - reduced_exponent_bias;
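    // For example (illustrative): with exponent_bits = 5,
    // reduced_exponent_bias = 15, so reduced_max_exponent = 127 + 15 = 142
    // and reduced_min_exponent = 127 - 15 = 112. Biased f32 exponents above
    // 142 overflow to infinity below, and those at or below 112 underflow to
    // zero.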

    // Do we overflow or underflow?
    llvm::Value* x_exponent = b->CreateAnd(
        x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));
    llvm::Value* x_overflows = b->CreateICmpUGT(
        x_exponent,
        llvm::ConstantInt::get(int_type, reduced_max_exponent << 23));
    llvm::Value* x_underflows = b->CreateICmpULE(
        x_exponent,
        llvm::ConstantInt::get(int_type, reduced_min_exponent << 23));

    // Compute appropriately-signed values of zero and infinity.
    llvm::Value* x_signed_zero = b->CreateAnd(
        x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask));
    llvm::Value* x_signed_inf = b->CreateOr(
        x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));

    // Force to zero or infinity if overflow or underflow. (Note that this
    // truncates all denormal values to zero, rather than rounding them.)
    x_as_int = b->CreateSelect(x_overflows, x_signed_inf, x_as_int);
    x_as_int = b->CreateSelect(x_underflows, x_signed_zero, x_as_int);
  }

  // Cast the result back to a floating-point type.
  llvm::Value* result = b->CreateBitCast(x_as_int, float_type);

  // Correct result for NaN inputs.
  //
  // The exponent handling will "normalize" NaN values to infinities, which is
  // undesirable (except in the case with no mantissa bits, in which case it
  // is mandatory). This logic also handles cases where mantissa-rounding
  // causes a NaN's mantissa to overflow into the exponent bits, which would
  // otherwise create an erroneous zero value.
  //
  // If the fast-math flags are set to assume no NaNs, the comparison is likely
  // to be optimized away, so there's no point in even emitting it.
  if (!b->getFastMathFlags().noNaNs()) {
    llvm::Value* x_is_nan = b->CreateFCmpUNO(x, x);

    if (mantissa_bits > 0) {
      result = b->CreateSelect(x_is_nan, x, result);
    } else {
      result = b->CreateSelect(
          x_is_nan, llvm::ConstantFP::getInfinity(float_type), result);
    }
  }
  return result;
}

llvm::Value* EmitF32ToBF16(llvm::Value* f32_value, llvm::IRBuilder<>* b) {
  auto reduced_precision = EmitReducePrecisionFloat(
      f32_value,
      /*exponent_bits=*/primitive_util::kBFloat16ExponentBits,
      /*mantissa_bits=*/primitive_util::kBFloat16MantissaBits, b);
  auto as_int32 = b->CreateBitCast(reduced_precision, b->getInt32Ty());
  auto shifted = b->CreateLShr(as_int32, 16);
  auto truncated = b->CreateTrunc(shifted, b->getInt16Ty());
  return b->CreateBitCast(truncated, b->getInt16Ty());
}

llvm::Value* EmitBF16ToF32(llvm::Value* bf16_value, llvm::IRBuilder<>* b) {
  auto as_int16 = b->CreateBitCast(bf16_value, b->getInt16Ty());
  auto as_int32 = b->CreateZExt(as_int16, b->getInt32Ty());
  auto shifted = b->CreateShl(as_int32, 16);
  return b->CreateBitCast(shifted, b->getFloatTy());
}
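
// Host-side analogue of the two helpers above (an illustrative sketch, not
// used by the emitter): BF16 is the high 16 bits of an F32, so
//
//   uint32_t bits;
//   memcpy(&bits, &f, sizeof(bits));
//   uint16_t bf16 = static_cast<uint16_t>(bits >> 16);      // F32 -> BF16
//   uint32_t widened = static_cast<uint32_t>(bf16) << 16;
//   float back;
//   memcpy(&back, &widened, sizeof(back));                  // BF16 -> F32
//
// Note that EmitF32ToBF16 additionally rounds via EmitReducePrecisionFloat
// before truncating, whereas this sketch truncates only.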

llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value,
                                    PrimitiveType from_type,
                                    PrimitiveType to_type, llvm::Module* module,
                                    llvm::IRBuilder<>* b) {
  if (primitive_util::IsSignedIntegralType(from_type)) {
    return b->CreateSIToFP(integer_value,
                           llvm_ir::PrimitiveTypeToIrType(to_type, module));
  } else {
    CHECK(primitive_util::IsUnsignedIntegralType(from_type) ||
          from_type == PRED);
    return b->CreateUIToFP(integer_value,
                           llvm_ir::PrimitiveTypeToIrType(to_type, module));
  }
}

}  // namespace

StatusOr<llvm::Value*> ElementalIrEmitter::EmitUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) {
  if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) ||
      op->operand(0)->shape().element_type() == PRED) {
    return EmitIntegerUnaryOp(op, operand_value);
  } else if (ShapeUtil::ElementIsComplex(op->operand(0)->shape())) {
    return EmitComplexUnaryOp(op, operand_value);
  } else {
    return EmitFloatUnaryOp(op, operand_value);
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) {
  switch (op->opcode()) {
    case HloOpcode::kConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsIntegralType(from_type) || from_type == PRED)
          << from_type;
      if (from_type == to_type) {
        return operand_value;
      }
      if (to_type == PRED) {
        return b_->CreateZExt(
            ICmpNE(operand_value,
                   llvm::ConstantInt::get(operand_value->getType(), 0)),
            llvm_ir::PrimitiveTypeToIrType(PRED, module_));
      }
      if (primitive_util::IsIntegralType(to_type)) {
        return IntCast(operand_value,
                       llvm_ir::PrimitiveTypeToIrType(to_type, module_),
                       primitive_util::IsSignedIntegralType(from_type));
      }
      if (primitive_util::IsFloatingPointType(to_type)) {
        if (to_type == BF16) {
          return EmitF32ToBF16(EmitIntegralToFloating(operand_value, from_type,
                                                      F32, module_, b_),
                               b_);
        }
        return EmitIntegralToFloating(operand_value, from_type, to_type,
                                      module_, b_);
      }
      if (primitive_util::IsComplexType(to_type)) {
        auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(
            primitive_util::ComplexComponentType(to_type), module_);
        if (primitive_util::IsSignedIntegralType(from_type)) {
          return EmitComposeComplex(
              op, SIToFP(operand_value, to_ir_component_type), nullptr);
        }
        if (primitive_util::IsUnsignedIntegralType(from_type) ||
            from_type == PRED) {
          return EmitComposeComplex(
              op, UIToFP(operand_value, to_ir_component_type), nullptr);
        }
      }
      return Unimplemented("conversion from primitive type %s to %s",
                           PrimitiveType_Name(from_type),
                           PrimitiveType_Name(to_type));
    }
    case HloOpcode::kBitcastConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsIntegralType(from_type));
      if (from_type == to_type) {
        return operand_value;
      }
      if (primitive_util::BitWidth(from_type) ==
          primitive_util::BitWidth(to_type)) {
        return BitCast(operand_value,
                       llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      return InvalidArgument(
          "bitcast conversion from primitive type %s to %s with unequal "
          "bit-widths (%u versus %u) ",
          PrimitiveType_Name(from_type), PrimitiveType_Name(to_type),
          primitive_util::BitWidth(from_type),
          primitive_util::BitWidth(to_type));
    }
    case HloOpcode::kAbs: {
      bool is_signed =
          primitive_util::IsSignedIntegralType(op->shape().element_type());
      if (is_signed) {
        auto type =
            llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
        auto cmp = ICmpSGE(operand_value, GetZero(type));
        return Select(cmp, operand_value, Neg(operand_value));
      } else {
        return operand_value;
      }
    }
    case HloOpcode::kClz: {
      auto is_zero_undef = b_->getFalse();
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ctlz,
                                          {operand_value, is_zero_undef},
                                          {operand_value->getType()}, b_);
    }
    case HloOpcode::kSign: {
      CHECK(primitive_util::IsSignedIntegralType(op->shape().element_type()))
          << op->shape().element_type();
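      // sign(x) via bit tricks: an arithmetic shift right by (bitwidth - 1)
      // yields 0 for x >= 0 and -1 for x < 0; OR-ing in the low bit maps 0 to
      // 1, and the select fixes up sign(0) = 0. E.g. for i32:
      // sign(5) = 0 | 1 = 1, sign(-5) = -1 | 1 = -1.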
      auto type =
          llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
      auto cmp = ICmpEQ(operand_value, GetZero(type));
      auto ashr = AShr(operand_value, type->getIntegerBitWidth() - 1);
      return Select(cmp, GetZero(type), Or(ashr, 1));
    }
    case HloOpcode::kNegate:
      return Neg(operand_value);
    case HloOpcode::kNot: {
      auto type = op->shape().element_type();
      if (type == PRED) {
        // It is not sufficient to just call CreateNot() here because a PRED
        // is represented as an i8 and the truth value is stored only in the
        // bottom bit.
        return b_->CreateZExt(Not(Trunc(operand_value, b_->getInt1Ty())),
                              llvm_ir::PrimitiveTypeToIrType(PRED, module_));
      } else if (primitive_util::IsIntegralType(type)) {
        return Not(operand_value);
      }
      return Unimplemented("unary op Not is not defined for type '%d'", type);
    }
    default:
      return Unimplemented("unary integer op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) {
  switch (op->opcode()) {
    case HloOpcode::kConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsFloatingPointType(from_type)) << from_type;
      if (from_type == to_type) {
        return operand_value;
      }
      if (primitive_util::IsComplexType(to_type)) {
        PrimitiveType to_component_type =
            primitive_util::ComplexComponentType(to_type);
        if (from_type == to_component_type) {
          return EmitComposeComplex(op, operand_value, nullptr);
        }
        return EmitComposeComplex(
            op,
            FPCast(operand_value,
                   llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)),
            nullptr);
      }
      if (from_type == BF16) {
        TF_RET_CHECK(to_type != BF16);
        operand_value = EmitBF16ToF32(operand_value, b_);
        from_type = F32;
        if (from_type == to_type) {
          return operand_value;
        }
      }
      if (from_type == F32 && to_type == BF16) {
        return EmitF32ToBF16(operand_value, b_);
      }
      if (to_type == PRED) {
        return b_->CreateZExt(
            FCmpUNE(operand_value,
                    llvm::ConstantFP::get(operand_value->getType(), 0.0)),
            llvm_ir::PrimitiveTypeToIrType(PRED, module_));
      }
      if (primitive_util::IsFloatingPointType(to_type)) {
        return FPCast(operand_value,
                      llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      if (primitive_util::IsSignedIntegralType(to_type)) {
        return FPToSI(operand_value,
                      llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      if (primitive_util::IsUnsignedIntegralType(to_type)) {
        return FPToUI(operand_value,
                      llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      return Unimplemented("unhandled conversion operation: %s => %s",
                           PrimitiveType_Name(from_type),
                           PrimitiveType_Name(to_type));
    }
    case HloOpcode::kBitcastConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsFloatingPointType(from_type));
      if (from_type == to_type) {
        return operand_value;
      }
      if (primitive_util::BitWidth(from_type) ==
          primitive_util::BitWidth(to_type)) {
        return BitCast(operand_value,
                       llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      return InvalidArgument(
          "bitcast conversion from primitive type %s to %s with unequal "
          "bit-widths (%u versus %u) ",
          PrimitiveType_Name(from_type), PrimitiveType_Name(to_type),
          primitive_util::BitWidth(from_type),
          primitive_util::BitWidth(to_type));
    }
    case HloOpcode::kExp:
      return EmitExp(op->shape().element_type(), operand_value);
    case HloOpcode::kExpm1:
      return EmitExpm1(op->shape().element_type(), operand_value);
    case HloOpcode::kLog:
      return EmitLog(op->shape().element_type(), operand_value);
    case HloOpcode::kLog1p:
      return EmitLog1p(op->shape().element_type(), operand_value);
    case HloOpcode::kCos:
      return EmitCos(op->shape().element_type(), operand_value);
    case HloOpcode::kSin:
      return EmitSin(op->shape().element_type(), operand_value);
    case HloOpcode::kTanh:
      return EmitTanh(op->shape().element_type(), operand_value);
    case HloOpcode::kSqrt:
      return EmitSqrt(op->shape().element_type(), operand_value);
    case HloOpcode::kRsqrt:
      return EmitRsqrt(op->shape().element_type(), operand_value);
    case HloOpcode::kFloor:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    case HloOpcode::kCeil:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ceil,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    case HloOpcode::kAbs:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
                                          {operand_value},
                                          {operand_value->getType()}, b_);
    case HloOpcode::kRoundNearestAfz:
      return EmitRoundNearestAfz(op->shape().element_type(), operand_value);
    case HloOpcode::kSign: {
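      // sign(x) = copysign(x != 0.0 ? 1.0 : 0.0, x), which yields +/-1 for
      // nonzero finite x and +/-0 for +/-0; NaN inputs are passed through
      // unchanged by the final select.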
      auto type = operand_value->getType();
      auto zero = llvm::ConstantFP::get(type, 0.0);
      auto ne0_i1 = FCmpONE(operand_value, zero);
      auto ne0_float = UIToFP(ne0_i1, type);
      llvm::Value* result = llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::copysign, {ne0_float, operand_value},
          {operand_value->getType()}, b_);
      auto is_nan = FCmpUNO(operand_value, operand_value);
      result = Select(is_nan, operand_value, result);
      return result;
    }
    case HloOpcode::kIsFinite: {
      // abs(x) o!= inf; this works because the ordered comparison returns
      // false if either operand is NaN.
      auto type = operand_value->getType();
      auto abs_value = llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::fabs, {operand_value}, {type}, b_);
      auto infinity = llvm::ConstantFP::getInfinity(type);
      auto not_infinite = FCmpONE(abs_value, infinity);
      return b_->CreateZExt(not_infinite,
                            llvm_ir::PrimitiveTypeToIrType(PRED, module_));
    }
    case HloOpcode::kNegate:
      return FNeg(operand_value);
    case HloOpcode::kReal:
      return operand_value;
    case HloOpcode::kImag:
      return llvm::ConstantFP::get(operand_value->getType(), 0.0);
    default:
      return Unimplemented("unary floating-point op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) {
  PrimitiveType input_type = op->operand(0)->shape().element_type();
  PrimitiveType component_type =
      primitive_util::IsComplexType(input_type)
          ? primitive_util::ComplexComponentType(input_type)
          : input_type;
  switch (op->opcode()) {
    case HloOpcode::kLog: {
      // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a)
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      llvm::Type* llvm_ty = a->getType();
      auto sum_sq = FAdd(FMul(a, a), FMul(b, b));
      TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq));
      TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a));
      auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
      return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle);
    }
    case HloOpcode::kLog1p: {
      // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      llvm::Type* llvm_ty = a->getType();
      auto one = llvm::ConstantFP::get(llvm_ty, 1.0);
      auto a_plus_one = FAdd(a, one);
      auto sum_sq = FAdd(FMul(a_plus_one, a_plus_one), FMul(b, b));
      TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq));
      TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a_plus_one));
      auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
      return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle);
    }
    case HloOpcode::kConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      TF_RET_CHECK(primitive_util::IsComplexType(from_type));
      PrimitiveType to_type = op->shape().element_type();
      TF_RET_CHECK(primitive_util::IsComplexType(to_type));
      if (from_type == to_type) {
        return operand_value;
      }
      PrimitiveType to_component_type =
          primitive_util::ComplexComponentType(to_type);
      auto to_ir_component_type =
          llvm_ir::PrimitiveTypeToIrType(to_component_type, module_);
      return EmitComposeComplex(
          op, FPCast(EmitExtractReal(operand_value), to_ir_component_type),
          FPCast(EmitExtractImag(operand_value), to_ir_component_type));
    }
    case HloOpcode::kExp: {
      // e^(a+bi) = e^a*(cos(b)+sin(b)i)
      TF_ASSIGN_OR_RETURN(
          auto exp_a, EmitExp(component_type, EmitExtractReal(operand_value)));
      TF_ASSIGN_OR_RETURN(
          auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
      TF_ASSIGN_OR_RETURN(
          auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
      return EmitComposeComplex(op, FMul(exp_a, cos_b), FMul(exp_a, sin_b));
    }
    case HloOpcode::kExpm1: {
      // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
      TF_ASSIGN_OR_RETURN(
          auto exp_a, EmitExp(component_type, EmitExtractReal(operand_value)));
      TF_ASSIGN_OR_RETURN(
          auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
      TF_ASSIGN_OR_RETURN(
          auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
      auto one = llvm::ConstantFP::get(exp_a->getType(), 1.0);
      auto real_result = FSub(FMul(exp_a, cos_b), one);
      auto imag_result = FMul(exp_a, sin_b);
      return EmitComposeComplex(op, real_result, imag_result);
    }
    case HloOpcode::kCos: {
      // cos(z) = .5(e^(iz) + e^(-iz))
      // cos(a+bi) = .5(e^(-b+ai) + e^(b-ai))
      // now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
      // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(-a)+sin(-a)i))
      // cos(-x) = cos(x) and sin(-x) = -sin(x), so
      // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(a)-sin(a)i))
      //           = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b))
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      auto type = a->getType();
      TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b));
      auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b);
      auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
      TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
      TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
      return EmitComposeComplex(op,
                                FMul(cos_a, FAdd(half_exp_neg_b, half_exp_b)),
                                FMul(sin_a, FSub(half_exp_neg_b, half_exp_b)));
    }
    case HloOpcode::kSin: {
      // sin(z) = .5i(e^(-iz) - e^(iz))
      // sin(a+bi) = .5i(e^(-i(a+bi)) - e^(i(a+bi)))
      //           = .5i(e^(b-ai) - e^(-b+ai))
      // now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
      // sin(a+bi) = 0.5i(e^b*(cos(-a)+sin(-a)i) - e^-b*(cos(a)+sin(a)i))
      //           = 0.5(e^b*(cos(-a)i-sin(-a)) - e^-b*(cos(a)i-sin(a)))
      // cos(-x) = cos(x) and sin(-x) = -sin(x), so
      //           = 0.5(e^b*(cos(a)i+sin(a)) - e^-b*(cos(a)i-sin(a)))
      //           = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b))
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      auto type = a->getType();
      TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b));
      auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b);
      auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
      TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
      TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
      return EmitComposeComplex(op,
                                FMul(sin_a, FAdd(half_exp_b, half_exp_neg_b)),
                                FMul(cos_a, FSub(half_exp_b, half_exp_neg_b)));
    }
    case HloOpcode::kTanh: {
      /*
      tanh=(exp(x)-exp(-x)) / (exp(x)+exp(-x))
      e^(a+bi) = e^a*(cos(b)+sin(b)i)
      so tanh=(((cos(b)+sin(b)i)e^a - (cos(-b)+sin(-b)i)e^-a)) /
              (((cos(b)+sin(b)i)e^a + (cos(-b)+sin(-b)i)e^-a))
      cos(b)=cos(-b), sin(-b)=-sin(b)
      so tanh=(((cos(b)+sin(b)i)e^a - (cos(b)-sin(b)i)e^-a)) /
              (((cos(b)+sin(b)i)e^a + (cos(b)-sin(b)i)e^-a))
             =(cos(b)e^a+i*sin(b)e^a + cos(b)(-e^-a)+i*sin(b)e^-a) /
              (cos(b)e^a+i*sin(b)e^a + cos(b)e^-a+i*sin(b)(-e^-a))
             =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) /
              (cos(b)(e^a+e^-a) + i*sin(b)(e^a-e^-a))
      This is a complex division, so we can multiply by denom_conj/denom_conj
             =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) *
              (cos(b)(e^a+e^-a) - i*sin(b)(e^a-e^-a)) /
              ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
             =(cos(b)^2(e^(2a)-e^(-2a)) + sin(b)^2(e^(2a)-e^(-2a)) +
               i*(cos(b)sin(b)(e^a+e^-a)^2 - cos(b)sin(b)(e^a-e^-a)^2)) /
              ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
      */
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      TF_ASSIGN_OR_RETURN(auto exp_a, EmitExp(component_type, a));
      TF_ASSIGN_OR_RETURN(auto cos_b, EmitCos(component_type, b));
      TF_ASSIGN_OR_RETURN(auto sin_b, EmitSin(component_type, b));
      auto exp_neg_a = FDiv(llvm::ConstantFP::get(exp_a->getType(), 1), exp_a);
      auto exp_2a_minus_exp_neg_2a =
          FSub(FMul(exp_a, exp_a), FMul(exp_neg_a, exp_neg_a));
      auto cos_b_sq = FMul(cos_b, cos_b);
      auto sin_b_sq = FMul(sin_b, sin_b);
      auto real_num = FAdd(FMul(cos_b_sq, exp_2a_minus_exp_neg_2a),
                           FMul(sin_b_sq, exp_2a_minus_exp_neg_2a));
      auto cos_b_sin_b = FMul(cos_b, sin_b);
      auto exp_a_plus_exp_neg_a = FAdd(exp_a, exp_neg_a);
      auto exp_a_plus_exp_neg_a_sq =
          FMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a);
      auto exp_a_minus_exp_neg_a = FSub(exp_a, exp_neg_a);
      auto exp_a_minus_exp_neg_a_sq =
          FMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a);
      auto imag_num = FMul(
          cos_b_sin_b, FSub(exp_a_plus_exp_neg_a_sq, exp_a_minus_exp_neg_a_sq));
      auto denom = FAdd(FMul(cos_b_sq, exp_a_plus_exp_neg_a_sq),
                        FMul(sin_b_sq, exp_a_minus_exp_neg_a_sq));
      return EmitComposeComplex(op, FDiv(real_num, denom),
                                FDiv(imag_num, denom));
    }
    case HloOpcode::kAbs: {
      auto sum_sq = FAdd(
          FMul(EmitExtractReal(operand_value), EmitExtractReal(operand_value)),
          FMul(EmitExtractImag(operand_value), EmitExtractImag(operand_value)));
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {sum_sq},
                                          {sum_sq->getType()}, b_);
    }
    case HloOpcode::kSign: {  // Sign(c) = c / |c|
      auto sum_sq = FAdd(
          FMul(EmitExtractReal(operand_value), EmitExtractReal(operand_value)),
          FMul(EmitExtractImag(operand_value), EmitExtractImag(operand_value)));
      auto cplx_abs = llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, b_);
      auto type = cplx_abs->getType();
      auto zero = llvm::ConstantFP::get(type, 0.0);
      auto oeq = FCmpOEQ(cplx_abs, zero);
      return Select(
          oeq, EmitComposeComplex(op, zero, zero),
          EmitComposeComplex(op, FDiv(EmitExtractReal(operand_value), cplx_abs),
                             FDiv(EmitExtractImag(operand_value), cplx_abs)));
    }
    case HloOpcode::kSqrt: {
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      auto c = llvm::ConstantFP::get(a->getType(), 0.5);
      auto d = llvm::ConstantFP::get(b->getType(), 0.0);
      return EmitComplexPower(op, a, b, c, d);
    }
    case HloOpcode::kRsqrt: {
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      auto c = llvm::ConstantFP::get(a->getType(), -0.5);
      auto d = llvm::ConstantFP::get(b->getType(), 0.0);
      return EmitComplexPower(op, a, b, c, d);
    }
    case HloOpcode::kNegate:
      return EmitComposeComplex(op, FNeg(EmitExtractReal(operand_value)),
                                FNeg(EmitExtractImag(operand_value)));
    case HloOpcode::kReal:
      return EmitExtractReal(operand_value);
    case HloOpcode::kImag:
      return EmitExtractImag(operand_value);
    default:
      return Unimplemented("unary complex op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  PrimitiveType operand_type = op->operand(0)->shape().element_type();
  if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) ||
      operand_type == PRED) {
    return EmitIntegerBinaryOp(
        op, lhs_value, rhs_value,
        primitive_util::IsSignedIntegralType(operand_type));
  } else if (primitive_util::IsComplexType(operand_type)) {
    return EmitComplexBinaryOp(op, lhs_value, rhs_value);
  } else {
    return EmitFloatBinaryOp(op, lhs_value, rhs_value);
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  switch (op->opcode()) {
    case HloOpcode::kComplex:
      return EmitComposeComplex(op, lhs_value, rhs_value);
    case HloOpcode::kAdd:
      return FAdd(lhs_value, rhs_value);
    case HloOpcode::kSubtract:
      return FSub(lhs_value, rhs_value);
    case HloOpcode::kMultiply:
      return FMul(lhs_value, rhs_value);
    case HloOpcode::kDivide:
      return FDiv(lhs_value, rhs_value);
    case HloOpcode::kRemainder:
      return FRem(lhs_value, rhs_value);
    // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
    // comparisons always return false when one of the operands is NaN, whereas
    // unordered comparisons return true.
    //
    // We use ordered comparisons for everything except kNe, where we use an
    // unordered comparison. This makes x != y equivalent to !(x == y), and
    // matches C++'s semantics.
    case HloOpcode::kCompare: {
      switch (op->comparison_direction()) {
        case ComparisonDirection::kEq:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value,
                                         rhs_value, b_);
        case ComparisonDirection::kNe:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value,
                                         rhs_value, b_);
        case ComparisonDirection::kLt:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value,
                                         rhs_value, b_);
        case ComparisonDirection::kGt:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value,
                                         rhs_value, b_);
        case ComparisonDirection::kLe:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value,
                                         rhs_value, b_);
        case ComparisonDirection::kGe:
          return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value,
                                         rhs_value, b_);
      }
    }
    case HloOpcode::kMaximum:
      return EmitFloatMax(lhs_value, rhs_value);
    case HloOpcode::kMinimum:
      return EmitFloatMin(lhs_value, rhs_value);
    case HloOpcode::kPower:
      return EmitPow(op->shape().element_type(), lhs_value, rhs_value);
    case HloOpcode::kAtan2:
      return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value);
    default:
      return Unimplemented("binary floating point op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}

// (a+bi)^(c+di) =
//    (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
//    where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
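//
// This follows from z^w = exp(w*log(z)): with log(a+bi) =
// .5*ln(a*a+b*b) + i*atan2(b,a), expanding exp((c+di)*log(a+bi)) and
// collecting real and imaginary parts yields the expression above.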
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexPower(
    const HloInstruction* op, llvm::Value* a, llvm::Value* b, llvm::Value* c,
    llvm::Value* d) {
  PrimitiveType component_type =
      primitive_util::ComplexComponentType(op->shape().element_type());
  auto aa_p_bb = FAdd(FMul(a, a), FMul(b, b));
  auto zero = llvm::ConstantFP::get(a->getType(), 0);
  auto one_half = llvm::ConstantFP::get(a->getType(), 0.5);
  auto one = llvm::ConstantFP::get(a->getType(), 1);
  auto half_c = FMul(one_half, c);

  TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c,
                      EmitPow(component_type, aa_p_bb, half_c));

  auto neg_d = FNeg(d);
  TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a));
  auto neg_d_arg_lhs = FMul(neg_d, arg_lhs);
  TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs,
                      EmitExp(component_type, neg_d_arg_lhs));
  auto coeff = FMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs);
  TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb));
  auto half_d = FMul(one_half, d);
  auto q = FAdd(FMul(c, arg_lhs), FMul(half_d, ln_aa_p_bb));
  TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q));
  TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q));
  // z^c is 0 if z is 0 and c > 0 (the select below also requires the
  // exponent to be purely real, i.e. d == 0). 0^0 is defined to be 1.0; see
  // Branch Cuts for Complex Elementary Functions or Much Ado About
  // Nothing's Sign Bit, W. Kahan, Section 10.
  return Select(
      And(And(FCmpOEQ(aa_p_bb, zero), FCmpOEQ(d, zero)), FCmpOLE(zero, c)),
      EmitComposeComplex(op, Select(FCmpOEQ(zero, c), one, zero), zero),
      EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q)));
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  switch (op->opcode()) {
    case HloOpcode::kAdd:
      return EmitComposeComplex(
          op, FAdd(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
          FAdd(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
    case HloOpcode::kSubtract:
      return EmitComposeComplex(
          op, FSub(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
          FSub(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
    case HloOpcode::kMultiply:
      return EmitComposeComplex(
          op,
          FSub(FMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
               FMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))),
          FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)),
               FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value))));
    case HloOpcode::kDivide: {
      // (a+bi) / (c+di) = ((a+bi)(c-di)) / ((c+di)(c-di))
      //                 = ((ac + bd) + (bc - ad)i) / (c^2 + d^2)
      auto rhs_sum_sq =
          FAdd(FMul(EmitExtractReal(rhs_value), EmitExtractReal(rhs_value)),
               FMul(EmitExtractImag(rhs_value), EmitExtractImag(rhs_value)));
      auto type = rhs_sum_sq->getType();
      auto zero = llvm::ConstantFP::get(type, 0.0);
      auto oeq = FCmpOEQ(rhs_sum_sq, zero);
      auto real_inf_or_nan = FDiv(EmitExtractReal(lhs_value), zero);
      auto imag_inf_or_nan = FDiv(EmitExtractImag(lhs_value), zero);
      return Select(
          oeq, EmitComposeComplex(op, real_inf_or_nan, imag_inf_or_nan),
          EmitComposeComplex(op,
                             FDiv(FAdd(FMul(EmitExtractReal(lhs_value),
                                            EmitExtractReal(rhs_value)),
                                       FMul(EmitExtractImag(lhs_value),
                                            EmitExtractImag(rhs_value))),
                                  rhs_sum_sq),
                             FDiv(FSub(FMul(EmitExtractImag(lhs_value),
                                            EmitExtractReal(rhs_value)),
                                       FMul(EmitExtractReal(lhs_value),
                                            EmitExtractImag(rhs_value))),
                                  rhs_sum_sq)));
    }
    // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
    // comparisons always return false when one of the operands is NaN, whereas
    // unordered comparisons return true.
    //
    // We use ordered comparisons for everything except kNe, where we use an
    // unordered comparison. This makes x != y equivalent to !(x == y), and
    // matches C++'s semantics.
    case HloOpcode::kCompare: {
      switch (op->comparison_direction()) {
        case ComparisonDirection::kEq:
          return And(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
                                             EmitExtractReal(lhs_value),
                                             EmitExtractReal(rhs_value), b_),
                     llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
                                             EmitExtractImag(lhs_value),
                                             EmitExtractImag(rhs_value), b_));
        case ComparisonDirection::kNe:
          return Or(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
                                            EmitExtractReal(lhs_value),
                                            EmitExtractReal(rhs_value), b_),
                    llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
                                            EmitExtractImag(lhs_value),
                                            EmitExtractImag(rhs_value), b_));
        default:
          return Unimplemented(
              "complex comparison '%s'",
              ComparisonDirectionToString(op->comparison_direction()));
      }
    }
    case HloOpcode::kPower: {
      auto a = EmitExtractReal(lhs_value);
      auto b = EmitExtractImag(lhs_value);
      auto c = EmitExtractReal(rhs_value);
      auto d = EmitExtractImag(rhs_value);
      return EmitComplexPower(op, a, b, c, d);
    }
    default:
      return Unimplemented("binary complex op '%s'",
                           HloOpcodeString(op->opcode()));
  }
}

llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
                                              llvm::Value* rhs_value) {
  return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_);
}

llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
                                              llvm::Value* rhs_value) {
  return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_);
}

// TODO(b/123355973): We have an implementation of erfinv in math.cc. We
// shouldn't have two implementations, especially since this one isn't testable
// (it's only observable via a normally-distributed RNG).
StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
                                                      llvm::Value* x) {
  if (prim_type != F16 && prim_type != F32 && prim_type != F64) {
    return Unimplemented(
        "Inverse erf is only implemented for element "
        "types F16, F32 and F64.");
  }

  // Upcast half to float.
  if (prim_type == F16) {
    x = b_->CreateFPExt(x, b_->getFloatTy());
  }

  auto get_float = [&](const double f) {
    return llvm::ConstantFP::get(x->getType(), f);
  };
  auto multiply_add = [&](absl::Span<const double> coefficients,
                          llvm::Value* w) {
    llvm::Value* p = get_float(coefficients.front());
    coefficients.remove_prefix(1);
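    // The loop below evaluates the polynomial by Horner's rule:
    // p(w) = (...((c_0*w + c_1)*w + c_2)...)*w + c_n, one multiply and one
    // add per coefficient.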
    for (double coefficient : coefficients) {
      p = FAdd(FMul(p, w), get_float(coefficient));
    }
    return p;
  };

  // Approximation for inverse error function from
  //   Giles, M., "Approximating the erfinv function".
  // The approximation has the form (float version):
  //   w = -log((1-x)*(1+x))
  //   if ( w < 5 ) {
  //     w = w - 2.5
  //     p = sum_{i=1}^n lq[i]*w^i
  //   } else {
  //     w = sqrt(w) - 3
  //     p = sum_{i=1}^n gq[i]*w^i
  //   }
  //   return p*x
  llvm::Function* logf_fn = llvm::Intrinsic::getDeclaration(
      module_, llvm::Intrinsic::log, {x->getType()});

  llvm::Value* w = FNeg(Call(
      logf_fn, {FMul(FSub(get_float(1.0f), x), FAdd(get_float(1.0f), x))}));

  llvm::Value* p_addr =
      llvm_ir::EmitAllocaAtFunctionEntry(x->getType(), "p.addr", b_);

  if (prim_type == F16 || prim_type == F32) {
    llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
        FCmpOLT(w, get_float(5.0f)), "w_less_than_five", b_);
    // Handle true BB.
    SetToFirstInsertPoint(if_data.true_block, b_);
    {
      llvm::Value* lw = FSub(w, get_float(2.5f));
      absl::Span<const double> lq{
          2.81022636e-08f,  3.43273939e-07f, -3.5233877e-06f,
          -4.39150654e-06f, 0.00021858087f,  -0.00125372503f,
          -0.00417768164f,  0.246640727f,    1.50140941f};
      llvm::Value* p = multiply_add(lq, lw);
      Store(p, p_addr);
    }

    // Handle false BB.
    SetToFirstInsertPoint(if_data.false_block, b_);
    {
      llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
          module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()});

      llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(3.0f));
      absl::Span<const double> gq{
          -0.000200214257f, 0.000100950558f, 0.00134934322f,
          -0.00367342844f,  0.00573950773f,  -0.0076224613f,
          0.00943887047f,   1.00167406f,     2.83297682f};
      llvm::Value* p = multiply_add(gq, gw);
      Store(p, p_addr);
    }

    SetToFirstInsertPoint(if_data.after_block, b_);
  } else {
    DCHECK(prim_type == F64);

    llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
        FCmpOLT(w, get_float(6.25)), "w_less_than_6.25", b_);

    SetToFirstInsertPoint(if_data.true_block, b_);
    {
      llvm::Value* lw = FSub(w, get_float(3.125));
      absl::Span<const double> c{
          -3.6444120640178196996e-21, -1.685059138182016589e-19,
          1.2858480715256400167e-18,  1.115787767802518096e-17,
          -1.333171662854620906e-16,  2.0972767875968561637e-17,
          6.6376381343583238325e-15,  -4.0545662729752068639e-14,
          -8.1519341976054721522e-14, 2.6335093153082322977e-12,
          -1.2975133253453532498e-11, -5.4154120542946279317e-11,
          1.051212273321532285e-09,   -4.1126339803469836976e-09,
          -2.9070369957882005086e-08, 4.2347877827932403518e-07,
          -1.3654692000834678645e-06, -1.3882523362786468719e-05,
          0.0001867342080340571352,   -0.00074070253416626697512,
          -0.0060336708714301490533,  0.24015818242558961693,
          1.6536545626831027356};
      llvm::Value* p = multiply_add(c, lw);
      Store(p, p_addr);
    }

    SetToFirstInsertPoint(if_data.false_block, b_);
    llvm_ir::LlvmIfData if_data_second = llvm_ir::EmitIfThenElse(
        FCmpOLT(w, get_float(16.0)), "w_less_than_16", b_);
    SetToFirstInsertPoint(if_data_second.true_block, b_);
    {
      llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
          module_, llvm::Intrinsic::sqrt, {b_->getDoubleTy()});

      llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(3.25));
      absl::Span<const double> t1{
          2.2137376921775787049e-09,  9.0756561938885390979e-08,
          -2.7517406297064545428e-07, 1.8239629214389227755e-08,
          1.5027403968909827627e-06,  -4.013867526981545969e-06,
          2.9234449089955446044e-06,  1.2475304481671778723e-05,
          -4.7318229009055733981e-05, 6.8284851459573175448e-05,
          2.4031110387097893999e-05,  -0.0003550375203628474796,
          0.00095328937973738049703,  -0.0016882755560235047313,
          0.0024914420961078508066,   -0.0037512085075692412107,
          0.005370914553590063617,    1.0052589676941592334,
          3.0838856104922207635};
      llvm::Value* p = multiply_add(t1, gw);
      Store(p, p_addr);
    }

    SetToFirstInsertPoint(if_data_second.false_block, b_);
    {
      llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
          module_, llvm::Intrinsic::sqrt, {b_->getDoubleTy()});

      llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(5.0));
      absl::Span<const double> t2{
          -2.7109920616438573243e-11, -2.5556418169965252055e-10,
          1.5076572693500548083e-09,  -3.7894654401267369937e-09,
          7.6157012080783393804e-09,  -1.4960026627149240478e-08,
          2.9147953450901080826e-08,  -6.7711997758452339498e-08,
          2.2900482228026654717e-07,  -9.9298272942317002539e-07,
          4.5260625972231537039e-06,  -1.9681778105531670567e-05,
          7.5995277030017761139e-05,  -0.00021503011930044477347,
          -0.00013871931833623122026, 1.0103004648645343977,
          4.8499064014085844221};
      llvm::Value* p = multiply_add(t2, gw);
      Store(p, p_addr);
    }

    SetToFirstInsertPoint(if_data.after_block, b_);
  }
  llvm::Value* p = Load(p_addr);
  x = FMul(p, x);
  // Trunc back to half if needed.
  if (prim_type == F16) {
    x = b_->CreateFPTrunc(x, b_->getHalfTy());
  }
  return x;
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(PrimitiveType prim_type,
                                                       llvm::Value* value) {
  // Compute erfcinv(value) by calculating erfinv(1.0 - value).
  auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
  auto one = llvm::ConstantFP::get(type, 1.0);
  return EmitErfInv(prim_type, FSub(one, value));
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type,
                                                   llvm::Value* value) {
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {value},
                                      {value->getType()}, b_);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type,
                                                     llvm::Value* value) {
  auto x = value;
  auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
  auto one = llvm::ConstantFP::get(type, 1.0);
  auto negative_half = llvm::ConstantFP::get(type, -0.5);
  // When x is large, the naive evaluation of ln(x + 1) is more
  // accurate than the Taylor series.
  TF_ASSIGN_OR_RETURN(auto for_large_x, EmitLog(prim_type, FAdd(x, one)));
  // The Taylor series for ln(x+1) is x - x^2/2 + x^3/3 - x^4/4 + …; we keep
  // the first two terms.
  auto for_small_x = FMul(FAdd(FMul(negative_half, x), one), x);
  const auto kAntilogarithmIsSmallThreshold = 1e-4;
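  // With |x| below this threshold, the first neglected term, x^3/3, amounts
  // to a relative error of roughly x^2/3 < 4e-9, well below f32 epsilon
  // (~1.2e-7).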
  auto abs_x =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
  auto x_is_small = FCmpOLT(
      abs_x, llvm::ConstantFP::get(type, kAntilogarithmIsSmallThreshold));
  return Select(x_is_small, for_small_x, for_large_x);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitSqrt(PrimitiveType prim_type,
                                                    llvm::Value* value) {
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {value},
                                      {value->getType()}, b_);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitRsqrt(PrimitiveType prim_type,
                                                     llvm::Value* value) {
  TF_ASSIGN_OR_RETURN(auto sqrt, EmitSqrt(prim_type, value));
  return FDiv(llvm::ConstantFP::get(sqrt->getType(), 1.0), sqrt);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitSin(PrimitiveType prim_type,
                                                   llvm::Value* value) {
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value},
                                      {value->getType()}, b_);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitCos(PrimitiveType prim_type,
                                                   llvm::Value* value) {
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value},
                                      {value->getType()}, b_);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitExp(PrimitiveType prim_type,
                                                   llvm::Value* value) {
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value},
                                      {value->getType()}, b_);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
                                                     llvm::Value* value) {
  auto x = value;
  auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
  auto one = llvm::ConstantFP::get(type, 1.0);
  auto half = llvm::ConstantFP::get(type, 0.5);
  // When the exponent is large, the naive evaluation of e^(x) - 1 is more
  // accurate than the Taylor series.
  TF_ASSIGN_OR_RETURN(auto exp_x, EmitExp(prim_type, value));
  auto for_large_x = FSub(exp_x, one);
  // The Taylor series for exp(x) is 1 + x + x^2/2 + x^3/6 + ….
  // We want exp(x)-1 which is x + x^2/2 + x^3/6 + ….
  auto x_squared = FMul(x, x);
  auto x_squared_over_two = FMul(x_squared, half);
  auto for_small_x = FAdd(x, x_squared_over_two);
  const auto kExponentIsSmallThreshold = 1e-5;
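  // With |x| below this threshold, the first neglected term, x^3/6, amounts
  // to a relative error of roughly x^2/6 < 2e-11, well below f32 epsilon.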
  auto abs_x =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
  auto x_is_small =
      FCmpOLT(abs_x, llvm::ConstantFP::get(type, kExponentIsSmallThreshold));
  return Select(x_is_small, for_small_x, for_large_x);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitRoundNearestAfz(
    PrimitiveType /*prim_type*/, llvm::Value* value) {
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::round, {value},
                                      {value->getType()}, b_);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitPow(PrimitiveType prim_type,
                                                   llvm::Value* lhs,
                                                   llvm::Value* rhs) {
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs},
                                      {lhs->getType()}, b_);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
                                                     llvm::Value* lhs,
                                                     llvm::Value* rhs) {
  return Unimplemented("atan2");
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
                                                    llvm::Value* value) {
  return Unimplemented("tanh");
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
    const HloInstruction* hlo, llvm::Value* x) {
  if (hlo->operand(0)->shape().element_type() != F32) {
    return Unimplemented("reduce-precision only implemented for F32");
  }
  return EmitReducePrecisionFloat(x, /*exponent_bits=*/hlo->exponent_bits(),
                                  /*mantissa_bits=*/hlo->mantissa_bits(), b_);
}

static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* b,
                                             llvm::Value* lhs, llvm::Value* rhs,
                                             llvm::Value* shift_result,
                                             bool saturate_to_sign_bit) {
  llvm::IntegerType* integer_type =
      llvm::cast<llvm::IntegerType>(lhs->getType());
  unsigned integer_bitsize = integer_type->getBitWidth();
  llvm::ConstantInt* integer_bitsize_constant =
      llvm::ConstantInt::get(integer_type, integer_bitsize);
  llvm::ConstantInt* zero = llvm::ConstantInt::get(integer_type, 0);
  llvm::ConstantInt* minus_one = llvm::ConstantInt::get(integer_type, -1);
  llvm::Value* saturated_value;
  if (saturate_to_sign_bit) {
    saturated_value =
        b->CreateSelect(b->CreateICmpSLT(lhs, zero), minus_one, zero);
  } else {
    saturated_value = zero;
  }
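  // E.g. for i32 operands: (x shl 40) and (x lshr 40) saturate to 0, while
  // (x ashr 40) saturates to 0 for non-negative x and to -1 for negative x,
  // matching what an infinitely wide shift would produce.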
  llvm::Value* shift_amt_in_range =
      b->CreateICmpULT(rhs, integer_bitsize_constant, "shft.chk");
  return b->CreateSelect(shift_amt_in_range, shift_result, saturated_value);
}

llvm::Value* ElementalIrEmitter::GetOne(llvm::Type* type) {
  return llvm::ConstantInt::get(llvm::cast<llvm::IntegerType>(type), 1);
}

llvm::Value* ElementalIrEmitter::GetZero(llvm::Type* type) {
  return llvm::ConstantInt::get(llvm::cast<llvm::IntegerType>(type), 0);
}

llvm::Value* ElementalIrEmitter::GetIntSMin(llvm::Type* type) {
  auto* integer_type = llvm::cast<llvm::IntegerType>(type);
  return llvm::ConstantInt::get(integer_type, llvm::APInt::getSignedMinValue(
                                                  integer_type->getBitWidth()));
}

llvm::Value* ElementalIrEmitter::GetMinusOne(llvm::Type* type) {
  auto* integer_type = llvm::cast<llvm::IntegerType>(type);
  return llvm::ConstantInt::get(
      integer_type, llvm::APInt::getAllOnesValue(integer_type->getBitWidth()));
}

llvm::Value* ElementalIrEmitter::IsZero(llvm::Value* v) {
  return ICmpEQ(v, llvm::ConstantInt::get(v->getType(), 0));
}

llvm::Value* ElementalIrEmitter::IsIntMinDivisionOverflow(llvm::Value* lhs,
                                                          llvm::Value* rhs) {
  return And(ICmpEQ(lhs, GetIntSMin(lhs->getType())),
             ICmpEQ(rhs, GetMinusOne(rhs->getType())));
}

llvm::Value* ElementalIrEmitter::EmitIntegerDivide(llvm::Value* lhs,
                                                   llvm::Value* rhs,
                                                   bool is_signed) {
  // Integer division overflow behavior:
  //
  // X / 0 == -1
  // INT_SMIN /s -1 = INT_SMIN
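  //
  // LLVM's udiv/sdiv are undefined behavior on these inputs, so below we
  // divide by a harmless substitute divisor (1) and then select the
  // documented result instead.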
1230
1231 if (!is_signed) {
1232 llvm::Value* udiv_is_unsafe = IsZero(rhs);
1233 llvm::Value* safe_rhs = Select(udiv_is_unsafe, GetOne(lhs->getType()), rhs);
1234 llvm::Value* safe_div = UDiv(lhs, safe_rhs);
1235 return Select(udiv_is_unsafe, GetMinusOne(lhs->getType()), safe_div);
1236 }
1237
1238 llvm::Value* has_zero_divisor = IsZero(rhs);
1239 llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs);
1240 llvm::Value* sdiv_is_unsafe = Or(has_int_min_overflow, has_zero_divisor);
1241 llvm::Value* safe_rhs = Select(sdiv_is_unsafe, GetOne(lhs->getType()), rhs);
1242 llvm::Value* safe_div = SDiv(lhs, safe_rhs);
1243
1244 return Select(
1245 has_zero_divisor, GetMinusOne(lhs->getType()),
1246 Select(has_int_min_overflow, GetIntSMin(lhs->getType()), safe_div));
1247 }
1248
EmitIntegerRemainder(llvm::Value * lhs,llvm::Value * rhs,bool is_signed)1249 llvm::Value* ElementalIrEmitter::EmitIntegerRemainder(llvm::Value* lhs,
1250 llvm::Value* rhs,
1251 bool is_signed) {
1252 // Integer remainder overflow behavior:
1253 //
1254 // X % 0 == X
1255 // INT_SMIN %s -1 = 0
1256
1257 if (!is_signed) {
1258 llvm::Value* urem_is_unsafe = IsZero(rhs);
1259 llvm::Value* safe_rhs = Select(urem_is_unsafe, GetOne(lhs->getType()), rhs);
1260 llvm::Value* safe_rem = URem(lhs, safe_rhs);
1261 return Select(urem_is_unsafe, lhs, safe_rem);
1262 }
1263
1264 llvm::Value* has_zero_divisor = IsZero(rhs);
1265 llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs);
1266 llvm::Value* srem_is_unsafe = Or(has_int_min_overflow, has_zero_divisor);
1267 llvm::Value* safe_rhs = Select(srem_is_unsafe, GetOne(lhs->getType()), rhs);
1268 llvm::Value* safe_rem = SRem(lhs, safe_rhs);
1269
1270 return Select(
1271 has_zero_divisor, lhs,
1272 Select(has_int_min_overflow, GetZero(lhs->getType()), safe_rem));
1273 }
1274
EmitIntegerBinaryOp(const HloInstruction * op,llvm::Value * lhs_value,llvm::Value * rhs_value,bool is_signed)1275 StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
1276 const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value,
1277 bool is_signed) {
1278 switch (op->opcode()) {
1279 // TODO(jingyue): add the "nsw" attribute for signed types.
1280 case HloOpcode::kAdd:
1281 return Add(lhs_value, rhs_value);
1282 case HloOpcode::kSubtract:
1283 return Sub(lhs_value, rhs_value);
1284 case HloOpcode::kMultiply:
1285 return Mul(lhs_value, rhs_value);
1286 case HloOpcode::kDivide:
1287 return EmitIntegerDivide(lhs_value, rhs_value, is_signed);
1288 case HloOpcode::kRemainder:
1289 return EmitIntegerRemainder(lhs_value, rhs_value, is_signed);
1290 case HloOpcode::kCompare: {
1291 switch (op->comparison_direction()) {
1292 case ComparisonDirection::kEq:
1293 return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value,
1294 rhs_value, b_);
1295 case ComparisonDirection::kNe:
1296 return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value,
1297 rhs_value, b_);
1298 case ComparisonDirection::kLt:
1299 return llvm_ir::EmitComparison(
1300 is_signed ? llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT,
1301 lhs_value, rhs_value, b_);
1302 case ComparisonDirection::kGt:
1303 return llvm_ir::EmitComparison(
1304 is_signed ? llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT,
1305 lhs_value, rhs_value, b_);
1306 case ComparisonDirection::kLe:
1307 return llvm_ir::EmitComparison(
1308 is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE,
1309 lhs_value, rhs_value, b_);
1310 case ComparisonDirection::kGe:
1311 return llvm_ir::EmitComparison(
1312 is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE,
1313 lhs_value, rhs_value, b_);
1314 }
1315 }
1316 case HloOpcode::kMinimum:
1317 return EmitIntegralMin(lhs_value, rhs_value, is_signed);
1318 case HloOpcode::kMaximum:
1319 return EmitIntegralMax(lhs_value, rhs_value, is_signed);
1320 case HloOpcode::kAnd:
1321 return And(lhs_value, rhs_value);
1322 case HloOpcode::kOr:
1323 return Or(lhs_value, rhs_value);
1324 case HloOpcode::kXor:
1325 return Xor(lhs_value, rhs_value);
1326
1327     // Shifting by an amount >= the number of bits in the type being
1328     // shifted produces a poison value in LLVM, which is essentially
1329     // "deferred undefined behavior" -- doing something observable with such
1330     // a value precipitates UB. We replace the poison value with a constant
1331     // to avoid this deferred UB.
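    // For a 32-bit shiftee, the saturated semantics are, in scalar terms (a
    // hypothetical sketch, assuming arithmetic >> for signed values;
    // SaturateShiftIfNecessary emits equivalent selects):
    //
    //   uint32 SafeShl(uint32 x, uint32 amt) {  // saturate_to_sign_bit=false
    //     return amt > 31 ? 0 : x << amt;
    //   }
    //   int32 SafeAShr(int32 x, uint32 amt) {   // saturate_to_sign_bit=true
    //     return amt > 31 ? (x < 0 ? -1 : 0) : x >> amt;
    //   }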
1332 case HloOpcode::kShiftRightArithmetic:
1333 return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
1334 AShr(lhs_value, rhs_value),
1335 /*saturate_to_sign_bit=*/true);
1336 case HloOpcode::kShiftLeft:
1337 return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
1338 Shl(lhs_value, rhs_value),
1339 /*saturate_to_sign_bit=*/false);
1340 case HloOpcode::kShiftRightLogical:
1341 return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
1342 LShr(lhs_value, rhs_value),
1343 /*saturate_to_sign_bit=*/false);
1344 default:
1345 return Unimplemented("binary integer op '%s'",
1346 HloOpcodeString(op->opcode()));
1347 }
1348 }
1349
1350 llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value,
1351 llvm::Value* rhs_value,
1352 bool is_signed) {
1353 return Select(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE
1354 : llvm::ICmpInst::ICMP_UGE,
1355 lhs_value, rhs_value),
1356 lhs_value, rhs_value);
1357 }
1358
1359 llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value,
1360 llvm::Value* rhs_value,
1361 bool is_signed) {
1362 return Select(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE
1363 : llvm::ICmpInst::ICMP_ULE,
1364 lhs_value, rhs_value),
1365 lhs_value, rhs_value);
1366 }
1367
1368 StatusOr<llvm::Value*> ElementalIrEmitter::ConvertValueForDistribution(
1369 const HloInstruction* hlo,
1370 const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
1371 const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) {
1372 TF_ASSIGN_OR_RETURN(llvm::Value * a_or_mean,
1373 operand_to_generator.at(hlo->operand(0))(index));
1374 TF_ASSIGN_OR_RETURN(llvm::Value * b_or_sigma,
1375 operand_to_generator.at(hlo->operand(1))(index));
1376 PrimitiveType elem_prim_ty = hlo->shape().element_type();
1377 llvm::Type* elem_ir_ty =
1378 llvm_ir::PrimitiveTypeToIrType(elem_prim_ty, module_);
1379 llvm::Type* raw_value_ty = raw_value->getType();
1380
1381 // If we're generating a floating-point value, convert the raw integer R (i.e.
1382 // `raw_value`) to a float in the range [0, 1).
1383 //
1384 // The basic approach is to choose a significand and exponent such that the
1385 // significand is uniformly distributed and the exponent is distributed, well,
1386 // exponentially (it's more likely to be close to 0 than far from 0).
1387 //
1388 // An easy way to do this is to say that the significand is the first S bits
1389 // of R, and the exponent is determined by the number of trailing zeroes in R,
1390 // exp = 2^-(cttz(R) + 1). (+1 because the largest exponent should be -1;
1391 // this way the largest value we can return is 1.999... * 2^-1 = 1-ε.)
1392 //
1393 // This results in a small bias. Namely, if R has enough trailing zeroes, the
1394 // significand and exponent will "overlap". As a concrete example, consider
1395 //
1396 // 20 X's 12 zeroes
1397 // R = 0bXXXXXXXXXXXXXXXXXXXX000000000000
1398 //
1399 // Here the exponent is 2^-13 because R has 12 trailing zeroes. The
1400 // significand is made up of the first 23 most-significant bits of R, which we
1401 // observe contain 3 zeroes. This is biased because any random value with
1402   // exponent 2^-13 will have a significand which ends in `000`.
1403 //
1404 // For f32s, this problem occurs only when there are more than 32-23 = 9
1405 // trailing zeros, which happens with probability 0.5^10 = ~0.1%. Moreover the
1406 // probability of a large bias (i.e. many trailing 0s in the significand) is
1407 // exponentially low. So we deem this acceptable.
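  // As a scalar sketch of the IR emitted below (hypothetical C++ for the f32
  // case, where r plays the role of raw_value):
  //
  //   float RawToUnitFloat(uint32 r) {
  //     int cttz = (r == 0) ? 32 : __builtin_ctz(r);  // the cttz intrinsic
  //     uint32 significand = r >> (32 - 23);          // top 23 bits of r
  //     uint32 exponent = 126 - cttz;                 // E = 126 - cttz
  //     uint32 f32_bits = (exponent << 23) | significand;
  //     float result;
  //     memcpy(&result, &f32_bits, sizeof(result));   // the BitCast
  //     return result;
  //   }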
1408 llvm::Value* elem_value = raw_value;
1409 if (elem_ir_ty->isFloatingPointTy()) {
1410 const auto& dest_flt_semantics = elem_ir_ty->getFltSemantics();
1411 const int bits = raw_value_ty->getPrimitiveSizeInBits();
1412 CHECK_GE(bits, llvm::APFloat::semanticsSizeInBits(dest_flt_semantics));
1413
1414 // Subtract 1 because semanticsPrecision includes the "hidden bit", i.e. the
1415 // implicit "1." at the beginning of the significand.
1416 const int significand_bits =
1417 llvm::APFloat::semanticsPrecision(dest_flt_semantics) - 1;
1418
1419 llvm::Value* cttz = llvm_ir::EmitCallToIntrinsic(
1420 llvm::Intrinsic::cttz, {raw_value, /*is_zero_undef=*/b_->getFalse()},
1421 {raw_value->getType()}, b_);
1422 llvm::Value* significand = LShr(raw_value, bits - significand_bits);
1423
1424 // Exponent bias is -127 for f32, meaning that if the exponent is E and the
1425 // significand is S, then the value of the number is 2^(E - 127) * (1.S).
1426 //
1427 // We want cttz == 0 to correspond to 2^-1, so our exponent is computed as
1428 // E = 126 - cttz.
1429 //
1430 // For f64, this is all the same, except the bias is -1023.
1431 //
1432 // In IEEE floating point, the absolute value of the exponent bias equals
1433 // the value of the largest possible exponent.
1434 const int bias = -llvm::APFloat::semanticsMaxExponent(dest_flt_semantics);
1435 llvm::Value* exponent =
1436 Sub(llvm::ConstantInt::get(cttz->getType(), -bias - 1), cttz);
1437
1438 // Now just slot everything into place! The `Trunc` is here because
1439 // raw_value may be larger than our float destination.
1440 elem_value =
1441 BitCast(Trunc(Or(Shl(exponent, significand_bits), significand),
1442 b_->getIntNTy(elem_ir_ty->getPrimitiveSizeInBits())),
1443 elem_ir_ty);
1444 }
1445
1446 // Convert the value for the requested distribution.
1447 switch (hlo->random_distribution()) {
1448 case RNG_UNIFORM: {
1449 if (elem_ir_ty->isFloatingPointTy()) {
1450 return FAdd(FMul(FSub(b_or_sigma, a_or_mean), elem_value), a_or_mean);
1451 } else {
1452 // To generate a uniform random value in [a, b) from a raw random sample
1453 // in range [0, 2^N), we let range = b - a and return
1454 // (a + raw_value % range). If range is not a power of 2, raw values
1455 // larger than (2^N - 2^N % range) are biased toward results in
1456 // [a, a + (limit % range)). An unbiased algorithm would need to drop
1457 // raw values and re-sample, but we don't do this because re-sampling in
1458 // an efficient way is complex, and it's not clear that users need it.
1459 // In particular, if one thread in a GPU warp needs to re-sample, we pay
1460 // the same cost as if the whole warp were to re-sample. So an
1461 // efficient re-sampling implementation on GPU would need to do
1462 // nontrivial work to share entropy between threads in the warp.
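        // As a concrete (hypothetical) example of the bias with N = 3 and
        // range = 3: raw values 0..7 map to a + {0, 1, 2, 0, 1, 2, 0, 1}, so
        // a+0 and a+1 each occur three times while a+2 occurs only twice.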
1463 auto range = Sub(b_or_sigma, a_or_mean);
1464 return Add(a_or_mean, URem(elem_value, range));
1465 }
1466 }
1467 case RNG_NORMAL: {
1468 TF_ASSIGN_OR_RETURN(
1469 llvm::Value * r,
1470 EmitErfcInv(elem_prim_ty, FMul(llvm::ConstantFP::get(elem_ir_ty, 2.0),
1471 elem_value)));
1472 return FAdd(FMul(r, b_or_sigma), a_or_mean);
1473 }
1474 default:
1475 return InvalidArgument(
1476 "unhandled distribution %s",
1477 RandomDistribution_Name(hlo->random_distribution()));
1478 }
1479 }
1480
1481 namespace {
1482
1483 // Checks that the primitive type is supported by the elemental IR emitter for
1484 // Philox RNG and returns the number of elements in each 128 bit sample of the
1485 // Philox RNG algorithm.
1486 int32 GetNumberOfElementsPerPhiloxRngSample(PrimitiveType elem_prim_ty) {
1487 // Calculate the number of elements, that is, the number of random numbers,
1488 // in a 128 bit sample.
1489 switch (elem_prim_ty) {
1490 case U32:
1491 case S32:
1492 case F32:
1493 // The algorithm uses 32 bits to generate values for F16.
1494 case F16:
1495 return 4;
1496 case U64:
1497 case S64:
1498 case F64:
1499 return 2;
1500 default:
1501 // BF16 is converted to F16 by the hlo pass HloElementTypeConverter.
1502       // Other data types are not supported by the XLA random operation.
1503 LOG(FATAL) << "Unrecognized primitive type for RNG " << elem_prim_ty;
1504 }
1505 return 0;
1506 }
1507
1508 // Calculates the four uint32 values for the 128-bit Philox sample.
1509 std::array<llvm::Value*, 4> CalculateSampleValues(
1510 llvm::Value* sample_idx, llvm::Value* hlo_random_value,
1511 llvm::Value* global_random_number, llvm::Value* rng_state,
1512 llvm::IRBuilder<>* b) {
1513 llvm::Type* index_ty = sample_idx->getType();
1514
1515 std::array<llvm::Value*, 4> counter_values;
1516
1517 // Use the sample index to initialize counter[0] and counter[1].
1518 unsigned index_ty_size_in_bits = index_ty->getPrimitiveSizeInBits();
1519 CHECK(index_ty_size_in_bits == 32 || index_ty_size_in_bits == 64);
1520 if (index_ty_size_in_bits == 32) {
1521 counter_values[0] = sample_idx;
1522 counter_values[1] = b->getInt32(0);
1523 } else {
1524 std::tie(counter_values[0], counter_values[1]) =
1525 llvm_ir::SplitInt64ToInt32s(b, sample_idx);
1526 }
1527
1528 // Xor the global state variable with the global random number seed and use
1529 // the result to initialize counter[2] and counter[3].
1530 std::tie(counter_values[2], counter_values[3]) = llvm_ir::SplitInt64ToInt32s(
1531 b, b->CreateXor(rng_state, global_random_number));
1532
1533 // The algorithm uses a 64 bit key, which is also interpreted as two uint32
1534 // values.
1535 llvm::Value* key_values[2];
1536
1537   // Use the per-module random number to initialize the key.
1538 std::tie(key_values[0], key_values[1]) =
1539 llvm_ir::SplitInt64ToInt32s(b, hlo_random_value);
1540
1541 // Prepare the constants used in the Philox RNG Algorithm.
1542 llvm::Value* philoxW32A = b->getInt32(0x9E3779B9);
1543 llvm::Value* philoxW32B = b->getInt32(0xBB67AE85);
1544 llvm::Value* philoxM4xW32A = b->getInt32(0xD2511F53);
1545 llvm::Value* philoxM4xW32B = b->getInt32(0xCD9E8D57);
1546
1547   // Compute the 128 bit value for the current sample by repeating the
1548   // single-round computation and the key-raising computation ten times.
1549 for (int round = 0; round < 10; ++round) {
1550 // A single round of computation of the counter values is as follows:
1551 // MultiplyHighLow(kPhiloxM4x32A, counter[0], &lo0, &hi0);
1552 // MultiplyHighLow(kPhiloxM4x32B, counter[2], &lo1, &hi1);
1553 // counter[0] = hi1 ^ counter[1] ^ key[0];
1554 // counter[1] = lo1;
1555 // counter[2] = hi0 ^ counter[3] ^ key[1];
1556 // counter[3] = lo0;
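    // where MultiplyHighLow(m, x, &lo, &hi) stands for the full 64-bit product
    // of two uint32 values (a scalar sketch of what llvm_ir::UMulLowHigh32
    // emits):
    //   uint64 product = uint64{m} * uint64{x};
    //   lo = uint32(product);
    //   hi = uint32(product >> 32);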
1557 llvm::Value* lo0;
1558 llvm::Value* hi0;
1559 std::tie(lo0, hi0) =
1560 llvm_ir::UMulLowHigh32(b, philoxM4xW32A, counter_values[0]);
1561 llvm::Value* lo1;
1562 llvm::Value* hi1;
1563 std::tie(lo1, hi1) =
1564 llvm_ir::UMulLowHigh32(b, philoxM4xW32B, counter_values[2]);
1565 counter_values[0] =
1566 b->CreateXor(hi1, b->CreateXor(counter_values[1], key_values[0]));
1567 counter_values[1] = lo1;
1568 counter_values[2] =
1569 b->CreateXor(hi0, b->CreateXor(counter_values[3], key_values[1]));
1570 counter_values[3] = lo0;
1571 key_values[0] = b->CreateAdd(key_values[0], philoxW32A);
1572 key_values[1] = b->CreateAdd(key_values[1], philoxW32B);
1573 }
1574
1575 return counter_values;
1576 }
1577
1578 } // namespace
1579
1580 // Implements the Philox algorithm to generate random numbers in parallel.
1581 // Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
1582 // http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
1583 //
1584 // The paper presents a few variants of the Philox algorithm; we picked the
1585 // 4x32_10 version of the algorithm for the following reasons:
1586 // . 4x32 uses 32-bit multiplication which is fast on GPUs.
1587 // . The authors recommend the 10-round variant, and TensorFlow also uses it.
1588 //
1589 // Precondition: the RNG instruction is not fused.
1590 llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator(
1591 const HloInstruction* hlo,
1592 const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) {
1593 VLOG(3) << "Using philox RNG algorithm";
1594 CHECK(!hlo->IsFused());
1595   // A random number generated by the per-module random number generator.
1596 // This ensures that each RNG HLO generates a different random sequence.
1597 llvm::Value* hlo_random_value = b_->getInt64(hlo->GetModule()->RandomNew64());
1598 // A value specified by the configuration or generated by a global random
1599 // number generator.
1600 llvm::Value* global_random_number =
1601 b_->getInt64(hlo_module_config_.seed() != 0 ? hlo_module_config_.seed()
1602 : GlobalRandomValue());
1603
1604 int elems_per_sample =
1605 GetNumberOfElementsPerPhiloxRngSample(hlo->shape().element_type());
1606
1607 // Allocate stack storage for the 128 bit sample as four int32.
1608 llvm::Type* int32_ty = b_->getInt32Ty();
1609 llvm::Value* sample_address = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
1610 int32_ty, /*element_count=*/b_->getInt32(4), "sample", b_);
1611
1612 // Load the global state variable for the Philox RNG algorithm.
1613 llvm::GlobalVariable* rng_state_ptr =
1614 llvm_ir::GetOrCreateVariableForPhiloxRngState(module_, b_);
1615 llvm::Value* rng_state = Load(rng_state_ptr, "rng_state_value");
1616
1617 // Build and return the elemental IR generator to generate a random value for
1618 // the element corresponding to the current thread.
1619 //
1620 // This elemental IR generator computes one sample with multiple random
1621 // numbers but only returns one random number. As a result, neighboring
1622 // threads may calculate the same sample unnecessarily. However, if the
1623 // kernel containing the RNG hlo is unrolled, LLVM is able to optimize away
1624 // the duplicated computation of the same sample. In particular, if the unroll
1625   // factor is a multiple of elems_per_sample, LLVM is able to completely
1626 // remove such duplicated computation. If the unroll factor is a non-trivial
1627 // factor of elems_per_sample, LLVM can only partially remove such duplicated
1628 // computation.
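  // For example, with elems_per_sample = 4, the generators for elements 0..3
  // each recompute sample 0 and read offsets 0..3 of it, elements 4..7
  // recompute sample 1, and so on; an unroll factor of 4 then lets LLVM
  // deduplicate each sample down to a single computation.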
1629 return [=](const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
1630 llvm::Type* index_ty = index.GetType();
1631 // Calculate the linear element index.
1632 llvm::Value* elem_idx = index.linear();
1633 if (elem_idx == nullptr) {
1634 elem_idx = index.Linearize(AsInt64Slice(hlo->shape().dimensions()), b_);
1635 }
1636
1637 // Calculate the index for the 128 bit sample and the offset of the current
1638 // element within the sample.
1639 llvm::Value* elems_per_sample_value =
1640 llvm::ConstantInt::get(index_ty, elems_per_sample);
1641 llvm::Value* sample_idx = UDiv(elem_idx, elems_per_sample_value);
1642 llvm::Value* elem_offset = URem(elem_idx, elems_per_sample_value);
1643
1644 std::array<llvm::Value*, 4> counter_values = CalculateSampleValues(
1645 sample_idx, hlo_random_value, global_random_number, rng_state, b_);
1646
1647 // Store the four counter_values into the sample_address alloca so we can
1648 // load the elem_offset'th one below.
1649 for (int idx = 0; idx < 4; ++idx) {
1650 Store(counter_values[idx],
1651 InBoundsGEP(sample_address, b_->getInt32(idx)));
1652 }
1653
1654 llvm::Type* int64_ty = b_->getInt64Ty();
1655 CHECK(elems_per_sample == 2 || elems_per_sample == 4);
1656 llvm::Type* raw_value_ty = elems_per_sample == 2 ? int64_ty : int32_ty;
1657 // Retrieve the raw value for the current element from the current sample.
1658 llvm::Value* raw_elem_value = Load(
1659 InBoundsGEP(PointerCast(sample_address, raw_value_ty->getPointerTo()),
1660 elem_offset),
1661 "raw_elem_value");
1662
1663 return ConvertValueForDistribution(hlo, operand_to_generator, index,
1664 raw_elem_value);
1665 };
1666 }
1667
1668 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalSelect(
1669 const HloInstruction* hlo,
1670 const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
1671 const llvm_ir::IrArray::Index& index) {
1672 TF_ASSIGN_OR_RETURN(llvm::Value * pred_value,
1673 operand_to_generator.at(hlo->operand(0))(index));
1674 TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value,
1675 operand_to_generator.at(hlo->operand(1))(index));
1676 TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value,
1677 operand_to_generator.at(hlo->operand(2))(index));
1678 return Select(Trunc(pred_value, b_->getInt1Ty()), on_true_value,
1679 on_false_value);
1680 }
1681
1682 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalClamp(
1683 const HloInstruction* hlo,
1684 const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
1685 const llvm_ir::IrArray::Index& index) {
1686 TF_ASSIGN_OR_RETURN(llvm::Value * min_value,
1687 operand_to_generator.at(hlo->operand(0))(index));
1688 TF_ASSIGN_OR_RETURN(llvm::Value * arg_value,
1689 operand_to_generator.at(hlo->operand(1))(index));
1690 TF_ASSIGN_OR_RETURN(llvm::Value * max_value,
1691 operand_to_generator.at(hlo->operand(2))(index));
1692 PrimitiveType prim_type = hlo->shape().element_type();
1693 if (primitive_util::IsFloatingPointType(prim_type)) {
1694 return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value));
1695 } else if (primitive_util::IsIntegralType(prim_type)) {
1696 bool is_signed = primitive_util::IsSignedIntegralType(prim_type);
1697 return EmitIntegralMin(
1698 max_value, EmitIntegralMax(min_value, arg_value, is_signed), is_signed);
1699 } else {
1700 return Unimplemented("Clamp unimplemented for %s",
1701 PrimitiveType_Name(prim_type));
1702 }
1703 }
1704
1705 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate(
1706 const HloInstruction* hlo,
1707 const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
1708 const llvm_ir::IrArray::Index& target_index) {
1709 const int64 concat_dim = hlo->dimensions(0);
1710 auto source_index = target_index;
1711
1712 llvm::BasicBlock* init_block = b_->GetInsertBlock();
1713
1714 // A terminator should be present iff we're emitting code
1715 // into the middle (as opposed to the end) of a basic block.
1716 CHECK_EQ(b_->GetInsertPoint() == init_block->end(),
1717 init_block->getTerminator() == nullptr);
1718
1719 llvm::BasicBlock* exit_block;
1720 if (b_->GetInsertPoint() == init_block->end()) {
1721 exit_block = llvm_ir::CreateBasicBlock(
1722 /*insert_before=*/nullptr, IrName(hlo, "merge"), b_);
1723 } else {
1724 exit_block =
1725 init_block->splitBasicBlock(b_->GetInsertPoint(), IrName(hlo, "merge"));
1726 init_block->getTerminator()->eraseFromParent();
1727 }
1728
1729 llvm_ir::SetToFirstInsertPoint(exit_block, b_);
1730 llvm::PHINode* output =
1731 PHI(llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
1732 hlo->operands().size());
1733 auto prior_insert_point = b_->GetInsertPoint();
1734
1735 b_->SetInsertPoint(init_block);
1736
1737 // Assign a unique id for each *different* operand, and count how often each
1738 // operand is used. If all operands are different, the usage count will be 1
1739 // for each operand.
1740 absl::flat_hash_map<const HloInstruction*, int64> to_unique_operand_id;
1741 std::vector<int64> operand_usage_count;
1742 for (const auto* operand : hlo->operands()) {
1743 if (to_unique_operand_id.contains(operand)) {
1744 ++operand_usage_count[to_unique_operand_id[operand]];
1745 } else {
1746 int64 unique_operand_id = to_unique_operand_id.size();
1747 to_unique_operand_id[operand] = unique_operand_id;
1748 operand_usage_count.push_back(1);
1749 }
1750 }
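  // For example, concatenate(x, y, x) assigns unique ids {x: 0, y: 1} and
  // usage counts [2, 1], so the basic block emitted for x below receives a
  // two-input PHI node for its source index.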
1751
1752 // To avoid emitting the same operand more than once, we create one basic
1753 // block for each *different* operand, with a PHI node that merges the
1754 // possible source index inputs.
1755 std::vector<llvm::BasicBlock*> emit_operand_blocks(
1756 to_unique_operand_id.size(), nullptr);
1757 std::vector<llvm::PHINode*> source_index_phis(to_unique_operand_id.size(),
1758 nullptr);
1759 for (const auto* operand : hlo->operands()) {
1760 int64 operand_id = to_unique_operand_id[operand];
1761 if (emit_operand_blocks[operand_id] != nullptr) {
1762 continue;
1763 }
1764
1765 emit_operand_blocks[operand_id] = llvm_ir::CreateBasicBlock(
1766 exit_block, StrCat("concat_index_from_operand_id", operand_id), b_);
1767 auto saved_insert_point = b_->GetInsertPoint();
1768 llvm_ir::SetToFirstInsertPoint(emit_operand_blocks[operand_id], b_);
1769 source_index_phis[operand_id] =
1770 PHI(source_index.GetType(), operand_usage_count[operand_id]);
1771 std::vector<llvm::Value*> operand_multi_index = source_index.multidim();
1772 operand_multi_index[concat_dim] = source_index_phis[operand_id];
1773
1774 // Create the terminator of the block before calling operand generators,
1775 // because they require non-degenerate basic blocks.
1776 b_->SetInsertPoint(llvm::BranchInst::Create(
1777 exit_block, /*InsertAtEnd=*/emit_operand_blocks[operand_id]));
1778 llvm_ir::IrArray::Index operand_index(operand_multi_index, operand->shape(),
1779 source_index.GetType());
1780 TF_ASSIGN_OR_RETURN(llvm::Value * value,
1781 operand_to_generator.at(operand)(operand_index));
1782 output->addIncoming(value, b_->GetInsertBlock());
1783 b_->SetInsertPoint(init_block, saved_insert_point);
1784 }
1785
1786 std::vector<llvm::Value*> source_multi_index = source_index.multidim();
1787 for (int64 operand_idx = 0; operand_idx < hlo->operand_count();
1788 ++operand_idx) {
1789 const HloInstruction* operand = hlo->operand(operand_idx);
1790 auto false_block = llvm_ir::CreateBasicBlock(
1791 exit_block, StrCat("concat_index_not_from_operand", operand_idx), b_);
1792 auto concat_dim_size = source_index.GetConstantWithIndexType(
1793 operand->shape().dimensions(concat_dim));
1794 int64 operand_id = to_unique_operand_id[operand];
1795 source_index_phis[operand_id]->addIncoming(source_multi_index[concat_dim],
1796 b_->GetInsertBlock());
1797 CondBr(ICmpULT(source_multi_index[concat_dim], concat_dim_size),
1798 emit_operand_blocks[operand_id], false_block);
1799
1800 // Subtract the size of the concat dimension of the current operand
1801 // from the source index.
1802 b_->SetInsertPoint(false_block);
1803 source_multi_index[concat_dim] =
1804 Sub(source_multi_index[concat_dim], concat_dim_size);
1805 }
1806
1807 Unreachable();
1808 b_->SetInsertPoint(exit_block, prior_insert_point);
1809 return output;
1810 }
1811
1812 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
1813 const HloInstruction* hlo,
1814 const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
1815 const llvm_ir::IrArray::Index& index) {
1816 // Emit IR to read dynamic start indices from hlo->operand(1).
1817 const HloInstruction* input_hlo = hlo->operand(0);
1818 const int64 rank = input_hlo->shape().rank();
1819 // Use the same index type for all tensor accesses in the same kernel.
1820 llvm::Type* index_type = index.GetType();
1821 std::vector<llvm::Value*> slice_start_multi_index(rank);
1822 for (int64 i = 0; i < rank; ++i) {
1823 auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
1824 return llvm::ConstantInt::get(index_type, c);
1825 };
1826 llvm_ir::IrArray::Index zero_index(index_type);
1827 TF_ASSIGN_OR_RETURN(
1828 llvm::Value * start_index_value,
1829 operand_to_generator.at(hlo->operand(1 + i))(zero_index));
1830
1831 // Clamp the start index so that the sliced portion fits in the operand:
1832 // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size)
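    // (E.g. an operand dimension of size 10 and an output dimension of size 4
    // clamp the start index into [0, 6].)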
1833 start_index_value = SExtOrTrunc(start_index_value, index_type);
1834 int64 largest_valid_start_index =
1835 input_hlo->shape().dimensions(i) - hlo->shape().dimensions(i);
1836 CHECK_GE(largest_valid_start_index, 0);
1837
1838 bool is_signed = ShapeUtil::ElementIsSigned(hlo->operand(1)->shape());
1839 start_index_value = EmitIntegralMin(
1840 index_typed_const(largest_valid_start_index),
1841 EmitIntegralMax(index_typed_const(0), start_index_value, is_signed),
1842 is_signed);
1843
1844 start_index_value->setName(IrName(hlo, StrCat("start_idx", i)));
1845 slice_start_multi_index[i] = start_index_value;
1846 }
1847
1848 std::vector<llvm::Value*> input_multi_index(rank);
1849 for (int64 i = 0; i < rank; ++i) {
1850 // Emit IR which computes:
1851 // input_index = start_index + offset_index
1852 input_multi_index[i] = Add(slice_start_multi_index[i], index[i]);
1853 }
1854 llvm_ir::IrArray::Index input_index(input_multi_index, input_hlo->shape(),
1855 index_type);
1856 return operand_to_generator.at(input_hlo)(input_index);
1857 }
1858
1859 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
1860 const HloInstruction* hlo,
1861 const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
1862 const llvm_ir::IrArray::Index& index) {
1863 const Shape& operand_shape = hlo->operand(0)->shape();
1864 const Shape& indices_shape = hlo->operand(1)->shape();
1865 const Shape& output_shape = hlo->shape();
1866
1867 const GatherDimensionNumbers& dim_numbers = hlo->gather_dimension_numbers();
1868
1869 const llvm_ir::ElementGenerator& operand_generator =
1870 operand_to_generator.at(hlo->operand(0));
1871 const llvm_ir::ElementGenerator& indices_generator =
1872 operand_to_generator.at(hlo->operand(1));
1873
1874 llvm::Type* index_type = index.GetType();
1875 // This is the index into `operand` that holds the element we want to
1876 // generate.
1877 std::vector<llvm::Value*> operand_multi_index;
1878
1879   // First copy the window indices into operand_multi_index. Also collect a
1880   // mapping from operand dimension to output window dimension. Elided window
1881   // dimensions map to -1.
1882 std::vector<int64> operand_to_output_dim(operand_shape.dimensions_size(), -1);
1883 for (int64 i = 0, e = operand_shape.dimensions_size(), operand_index_dim = 0;
1884 i < e; i++) {
1885 if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
1886 operand_multi_index.push_back(index.GetConstantWithIndexType(0));
1887 } else {
1888 int64 output_window_dim = dim_numbers.offset_dims(operand_index_dim++);
1889 operand_to_output_dim[i] = output_window_dim;
1890 operand_multi_index.push_back(index[output_window_dim]);
1891 }
1892 }
1893
1894 // This is the index of the index vector in the start_indices tensor.
1895 std::vector<llvm::Value*> gather_index_index_components;
1896 {
1897 for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) {
1898 if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) {
1899 gather_index_index_components.push_back(index[i]);
1900 }
1901 }
1902
1903 if (gather_index_index_components.size() !=
1904 indices_shape.dimensions_size()) {
1905 gather_index_index_components.insert(
1906 gather_index_index_components.begin() +
1907 dim_numbers.index_vector_dim(),
1908 nullptr);
1909 }
1910 }
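  // For example (a hypothetical shape, for illustration): for start_indices of
  // shape [7, 2] with index_vector_dim == 1, the loop above contributes the
  // single batch component [i], and the nullptr inserted at position 1 is
  // filled in below with the constants 0 and 1 to read both components of the
  // index vector.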
1911
1912 auto add_to_operand_index = [&](llvm::Value* index_component, int64 dim) {
1913 llvm::Value* gather_dim_component_extended =
1914 SExtOrTrunc(index_component, index_type);
1915 int64 operand_dim = dim_numbers.start_index_map(dim);
1916 int64 output_dim = operand_to_output_dim[operand_dim];
1917 // If 'output_dim' is -1, it means 'operand_dim' is an elided window dim.
1918 // This means we set the iteration index to 0, so for the purpose of the
1919 // following calculations we can consider the output dimension size to be 1.
1920 int64 output_dim_size =
1921 output_dim == -1 ? 1 : output_shape.dimensions(output_dim);
1922 int64 largest_valid_start_index =
1923 operand_shape.dimensions(operand_dim) - output_dim_size;
1924 CHECK_GE(largest_valid_start_index, 0);
1925
1926 // Clamp the gather index so that the gather region fits in the operand.
1927 // gather_dim_component_extended_inbound =
1928 // clamp(gather_dim_component_extended, 0, largest_valid_start_index);
1929 bool is_signed = ShapeUtil::ElementIsSigned(indices_shape);
1930 auto gather_dim_component_extended_inbound = EmitIntegralMin(
1931 index.GetConstantWithIndexType(largest_valid_start_index),
1932 EmitIntegralMax(index.GetConstantWithIndexType(0),
1933 gather_dim_component_extended, is_signed),
1934 is_signed);
1935
1936 operand_multi_index[operand_dim] =
1937 Add(operand_multi_index[operand_dim],
1938 gather_dim_component_extended_inbound);
1939 };
1940
1941 if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) {
1942 IrArray::Index gather_index_index(gather_index_index_components,
1943 indices_shape, index_type);
1944 TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
1945 indices_generator(gather_index_index));
1946 add_to_operand_index(gather_dim_component, 0);
1947 } else {
1948 int64 index_vector_size =
1949 indices_shape.dimensions(dim_numbers.index_vector_dim());
1950 for (int64 i = 0; i < index_vector_size; i++) {
1951 gather_index_index_components[dim_numbers.index_vector_dim()] =
1952 index.GetConstantWithIndexType(i);
1953 IrArray::Index gather_index_index(gather_index_index_components,
1954 indices_shape, index_type);
1955 TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
1956 indices_generator(gather_index_index));
1957 add_to_operand_index(gather_dim_component, i);
1958 }
1959 }
1960 IrArray::Index operand_index(operand_multi_index, operand_shape, index_type);
1961 return operand_generator(operand_index);
1962 }
1963
1964 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
1965 const HloInstruction* hlo,
1966 const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
1967 const llvm_ir::IrArray::Index& index) {
1968 const HloInstruction* input_hlo = hlo->operand(0);
1969 const HloInstruction* update_hlo = hlo->operand(1);
1970 const HloInstruction* start_hlo = hlo->operand(2);
1971 // Calculate slice start/end indices.
1972 const int64 rank = input_hlo->shape().rank();
1973 std::vector<llvm::Value*> slice_start_multi_index(rank);
1974 std::vector<llvm::Value*> slice_limit_multi_index(rank);
1975   // slice_intersection ANDs together, across all dimensions, the conditions
1976   // under which this element of 'input' is replaced by 'update'.
1977 llvm::Value* slice_intersection = b_->getTrue();
1978
1979 for (int64 i = 0; i < rank; ++i) {
1980 llvm::Type* index_type = index[0]->getType();
1981 auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
1982 return llvm::ConstantInt::get(index_type, c);
1983 };
1984
1985 llvm_ir::IrArray::Index zero_index(index_type);
1986 TF_ASSIGN_OR_RETURN(
1987 llvm::Value * start_index_value,
1988 operand_to_generator.at(hlo->operand(2 + i))(zero_index));
1989
1990 // Clamp the start index so that the update region fits in the operand.
1991 // start_index = clamp(start_index, 0, input_dim_size - update_dim_size)
1992 start_index_value = SExtOrTrunc(start_index_value, index_type);
1993 llvm::Value* update_dim_size =
1994 index_typed_const(update_hlo->shape().dimensions(i));
1995 int64 largest_valid_start_index =
1996 input_hlo->shape().dimensions(i) - update_hlo->shape().dimensions(i);
1997 CHECK_GE(largest_valid_start_index, 0);
1998
1999 bool is_signed = ShapeUtil::ElementIsSigned(start_hlo->shape());
2000 start_index_value = EmitIntegralMin(
2001 index_typed_const(largest_valid_start_index),
2002 EmitIntegralMax(index_typed_const(0), start_index_value, is_signed),
2003 is_signed);
2004
2005 start_index_value->setName(IrName(hlo, StrCat("start_idx", i)));
2006 slice_start_multi_index[i] = start_index_value;
2007 slice_limit_multi_index[i] =
2008 Add(slice_start_multi_index[i], update_dim_size);
2009
2010 slice_intersection =
2011 And(slice_intersection, ICmpSGE(index[i], slice_start_multi_index[i]),
2012 "slice_intersection");
2013 slice_intersection =
2014 And(slice_intersection, ICmpSLT(index[i], slice_limit_multi_index[i]),
2015 "slice_intersection");
2016 }
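  // In scalar terms, the predicate and indices built above implement (a
  // hypothetical 1-D sketch):
  //
  //   output[i] = (start <= i && i < start + update_size)
  //                   ? update[i - start]
  //                   : input[i];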
2017
2018 // Emit:
2019 // if (slice_intersection) -> return data from 'update'.
2020 // else -> return data from 'input'.
2021 llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
2022 llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
2023 "ret_value_addr", b_);
2024 llvm_ir::LlvmIfData if_data =
2025 llvm_ir::EmitIfThenElse(slice_intersection, "slice_intersection", b_);
2026
2027 // Handle true BB (return data from 'update')
2028 SetToFirstInsertPoint(if_data.true_block, b_);
2029 // Compute update index for intersection case.
2030 std::vector<llvm::Value*> update_multi_index(rank);
2031 for (int64 i = 0; i < rank; ++i) {
2032 update_multi_index[i] = Sub(index[i], slice_start_multi_index[i]);
2033 }
2034 llvm_ir::IrArray::Index update_index(update_multi_index, update_hlo->shape(),
2035 index.GetType());
2036 TF_ASSIGN_OR_RETURN(llvm::Value * true_value,
2037 operand_to_generator.at(update_hlo)(update_index));
2038 Store(true_value, ret_value_addr);
2039
2040 // Handle false BB (return data from 'input')
2041 SetToFirstInsertPoint(if_data.false_block, b_);
2042 TF_ASSIGN_OR_RETURN(llvm::Value * false_value,
2043 operand_to_generator.at(input_hlo)(index));
2044 Store(false_value, ret_value_addr);
2045
2046 SetToFirstInsertPoint(if_data.after_block, b_);
2047 return Load(ret_value_addr);
2048 }
2049
2050 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalPad(
2051 const HloInstruction* hlo,
2052 const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
2053 const llvm_ir::IrArray::Index& padded_index) {
2054 std::vector<llvm::Value*> multi_index = padded_index.multidim();
2055 llvm::Value* in_bounds = b_->getTrue();
2056 for (size_t i = 0; i < multi_index.size(); ++i) {
2057 auto index_typed_const = [=](int64 n) {
2058 return padded_index.GetConstantWithIndexType(n);
2059 };
2060 const auto& pad_dim = hlo->padding_config().dimensions(i);
2061 multi_index[i] =
2062 Sub(multi_index[i], index_typed_const(pad_dim.edge_padding_low()));
2063 in_bounds = And(in_bounds, ICmpSGE(multi_index[i], index_typed_const(0)),
2064 "in_bounds");
2065 in_bounds =
2066 And(in_bounds,
2067 ICmpEQ(index_typed_const(0),
2068 URem(multi_index[i],
2069 index_typed_const(pad_dim.interior_padding() + 1))),
2070 "in_bounds");
2071 multi_index[i] =
2072 SDiv(multi_index[i], index_typed_const(pad_dim.interior_padding() + 1));
2073 in_bounds =
2074 And(in_bounds,
2075 ICmpSLT(multi_index[i],
2076 index_typed_const(hlo->operand(0)->shape().dimensions(i))),
2077 "in_bounds");
2078 }
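  // As a concrete (hypothetical) example: with edge_padding_low = 1 and
  // interior_padding = 1, operand [a, b, c] pads to [p, a, p, b, p, c, ...].
  // Padded index 3 maps to (3 - 1) = 2; 2 % 2 == 0 keeps it in bounds, and
  // 2 / 2 = 1 selects operand element b. Padded index 2 fails the URem test
  // and yields the padding value instead.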
2079
2080 // if (in_bounds) {
2081 // ret_value = operand0[index]; // source
2082 // } else {
2083 // ret_value = *operand1; // padding
2084 // }
2085 llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
2086 llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
2087 "pad_result_addr", b_);
2088 llvm_ir::LlvmIfData if_data =
2089 llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_);
2090 SetToFirstInsertPoint(if_data.true_block, b_);
2091 llvm_ir::IrArray::Index index(multi_index, hlo->operand(0)->shape(),
2092 padded_index.GetType());
2093 TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
2094 operand_to_generator.at(hlo->operand(0))(index));
2095 Store(operand_value, ret_value_addr);
2096
2097 SetToFirstInsertPoint(if_data.false_block, b_);
2098 TF_ASSIGN_OR_RETURN(llvm::Value * padding_value,
2099 operand_to_generator.at(hlo->operand(1))(
2100 IrArray::Index(index.GetType())));
2101 Store(padding_value, ret_value_addr);
2102
2103 SetToFirstInsertPoint(if_data.after_block, b_);
2104 // Don't create phi(operand_value, padding_value) here, because invoking
2105 // operand_to_generator may create new basic blocks, making the parent
2106 // of operand_value or padding_value no longer a predecessor of
2107 // if_data.after_block.
2108 return Load(ret_value_addr);
2109 }
2110
2111 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
2112 const HloInstruction* hlo,
2113 const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
2114 const llvm_ir::IrArray::Index& dot_result_index) {
2115 auto lhs_generator = operand_to_generator.at(hlo->operand(0));
2116 auto rhs_generator = operand_to_generator.at(hlo->operand(1));
2117
2118 const DotDimensionNumbers& dim_numbers = hlo->dot_dimension_numbers();
2119 int64 lhs_contracting_dim = dim_numbers.lhs_contracting_dimensions(0);
2120 int64 rhs_contracting_dim = dim_numbers.rhs_contracting_dimensions(0);
2121
2122 int64 contracted_dim_size =
2123 hlo->operand(0)->shape().dimensions(lhs_contracting_dim);
2124 int64 lhs_dims = hlo->operand(0)->shape().dimensions_size();
2125 int64 rhs_dims = hlo->operand(1)->shape().dimensions_size();
2126
2127 llvm::Type* index_type = dot_result_index[0]->getType();
2128 auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
2129 return llvm::ConstantInt::get(index_type, c);
2130 };
2131
2132 std::unique_ptr<llvm_ir::ForLoop> inner_loop = llvm_ir::ForLoop::EmitForLoop(
2133 IrName(hlo, "inner"), index_typed_const(0),
2134 index_typed_const(contracted_dim_size), index_typed_const(1), b_);
2135
2136 SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), b_);
2137 PrimitiveType primitive_type = hlo->shape().element_type();
2138 llvm::Type* primitive_type_llvm =
2139 llvm_ir::PrimitiveTypeToIrType(primitive_type, module_);
2140 llvm::Value* accumulator_alloca =
2141 llvm_ir::EmitAllocaAtFunctionEntry(primitive_type_llvm, "dot_acc", b_);
2142 Store(llvm::Constant::getNullValue(primitive_type_llvm), accumulator_alloca);
2143
2144 SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), b_);
2145
2146 // This is the inner reduction loop for a dot operation that produces
2147 // one element in the output. If the operands to the dot operation have
2148 // shapes [A,B,C,T] and [D,T,E], the result has a shape [A,B,C,D,E].
2149 // Given an output index [a,b,c,d,e] in the result, we compute:
2150 // sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T))
2151
2152 std::vector<llvm::Value*> lhs_multi_index, rhs_multi_index;
2153 for (int64 i = 0; i < lhs_dims - 1; i++) {
2154 lhs_multi_index.push_back(dot_result_index[i]);
2155 }
2156 lhs_multi_index.insert(lhs_multi_index.begin() + lhs_contracting_dim,
2157 inner_loop->GetIndVarValue());
2158 IrArray::Index lhs_index(lhs_multi_index, hlo->operand(0)->shape(),
2159 index_type);
2160
2161 int64 num_batch_dims = dim_numbers.rhs_batch_dimensions_size();
2162 for (int64 i = 0; i < num_batch_dims; i++) {
2163 rhs_multi_index.push_back(
2164 dot_result_index[dim_numbers.rhs_batch_dimensions(i)]);
2165 }
2166 for (int64 i = 0; i < rhs_dims - 1 - num_batch_dims; i++) {
2167 rhs_multi_index.push_back(dot_result_index[lhs_dims - 1 + i]);
2168 }
2169 rhs_multi_index.insert(rhs_multi_index.begin() + rhs_contracting_dim,
2170 inner_loop->GetIndVarValue());
2171 IrArray::Index rhs_index(rhs_multi_index, hlo->operand(1)->shape(),
2172 index_type);
2173
2174 llvm::Value* current_accumulator = Load(accumulator_alloca);
2175 TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index));
2176 TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index));
2177 llvm::Value* next_accumulator;
2178 if (primitive_util::IsComplexType(primitive_type)) {
2179 llvm::Value* product_real =
2180 FSub(FMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
2181 FMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
2182 llvm::Value* product_imag =
2183 FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)),
2184 FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value)));
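    // This is the usual complex product:
    //   (lhs_re + lhs_im*i) * (rhs_re + rhs_im*i)
    //       == (lhs_re*rhs_re - lhs_im*rhs_im)
    //        + (lhs_re*rhs_im + lhs_im*rhs_re)*i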
2185 next_accumulator = InsertValue(
2186 current_accumulator,
2187 FAdd(EmitExtractReal(current_accumulator), product_real), {0});
2188 next_accumulator = InsertValue(
2189 next_accumulator,
2190 FAdd(EmitExtractImag(current_accumulator), product_imag), {1});
2191 } else if (primitive_util::IsFloatingPointType(primitive_type)) {
2192 next_accumulator = FAdd(current_accumulator, FMul(lhs_value, rhs_value));
2193 } else {
2194 next_accumulator = Add(current_accumulator, Mul(lhs_value, rhs_value));
2195 }
2196 Store(next_accumulator, accumulator_alloca);
2197
2198 SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), b_);
2199 return Load(accumulator_alloca);
2200 }
2201
2202 llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
2203 const HloInstruction* hlo,
2204 const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) {
2205 switch (hlo->opcode()) {
2206 case HloOpcode::kAbs:
2207 case HloOpcode::kRoundNearestAfz:
2208 case HloOpcode::kCeil:
2209 case HloOpcode::kClz:
2210 case HloOpcode::kConvert:
2211 case HloOpcode::kBitcastConvert:
2212 case HloOpcode::kCos:
2213 case HloOpcode::kExp:
2214 case HloOpcode::kExpm1:
2215 case HloOpcode::kFloor:
2216 case HloOpcode::kImag:
2217 case HloOpcode::kIsFinite:
2218 case HloOpcode::kLog:
2219 case HloOpcode::kLog1p:
2220 case HloOpcode::kNegate:
2221 case HloOpcode::kNot:
2222 case HloOpcode::kReal:
2223 case HloOpcode::kRsqrt:
2224 case HloOpcode::kSign:
2225 case HloOpcode::kSin:
2226 case HloOpcode::kSqrt:
2227 case HloOpcode::kTanh:
2228 return [this, hlo, &operand_to_generator](
2229 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2230 TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
2231 operand_to_generator.at(hlo->operand(0))(index));
2232 return EmitUnaryOp(hlo, operand_value);
2233 };
2234 case HloOpcode::kAdd:
2235 case HloOpcode::kAnd:
2236 case HloOpcode::kAtan2:
2237 case HloOpcode::kCompare:
2238 case HloOpcode::kComplex:
2239 case HloOpcode::kDivide:
2240 case HloOpcode::kMaximum:
2241 case HloOpcode::kMinimum:
2242 case HloOpcode::kMultiply:
2243 case HloOpcode::kOr:
2244 case HloOpcode::kXor:
2245 case HloOpcode::kPower:
2246 case HloOpcode::kRemainder:
2247 case HloOpcode::kShiftLeft:
2248 case HloOpcode::kShiftRightArithmetic:
2249 case HloOpcode::kShiftRightLogical:
2250 case HloOpcode::kSubtract:
2251 return [this, hlo, &operand_to_generator](
2252 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2253 const HloInstruction* lhs = hlo->operand(0);
2254 const HloInstruction* rhs = hlo->operand(1);
2255 TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value,
2256 operand_to_generator.at(lhs)(index));
2257 TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value,
2258 operand_to_generator.at(rhs)(index));
2259 return EmitBinaryOp(hlo, lhs_value, rhs_value);
2260 };
2261 case HloOpcode::kSelect:
2262 return [this, hlo, &operand_to_generator](
2263 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2264 return EmitElementalSelect(hlo, operand_to_generator, index);
2265 };
2266 case HloOpcode::kClamp:
2267 return [this, hlo, &operand_to_generator](
2268 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2269 return EmitElementalClamp(hlo, operand_to_generator, index);
2270 };
2271 case HloOpcode::kReducePrecision:
2272 return [this, hlo, &operand_to_generator](
2273 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2274 TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
2275 operand_to_generator.at(hlo->operand(0))(index));
2276 return EmitReducePrecision(hlo, operand_value);
2277 };
2278 case HloOpcode::kConcatenate:
2279 return [this, hlo, &operand_to_generator](
2280 const IrArray::Index target_index) -> StatusOr<llvm::Value*> {
2281 return EmitElementalConcatenate(hlo, operand_to_generator,
2282 target_index);
2283 };
2284 case HloOpcode::kReverse:
2285 return [this, hlo, &operand_to_generator](
2286 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
2287 const HloInstruction* operand = hlo->operand(0);
2288 std::vector<llvm::Value*> source_multi_index = target_index.multidim();
2289 for (int64 dim : hlo->dimensions()) {
2290 source_multi_index[dim] = Sub(target_index.GetConstantWithIndexType(
2291 hlo->shape().dimensions(dim) - 1),
2292 target_index[dim]);
2293 }
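        // E.g. reversing a dimension of size 5 maps target index 1 to source
        // index (5 - 1) - 1 = 3.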
2294 llvm_ir::IrArray::Index source_index(
2295 source_multi_index, operand->shape(), target_index.GetType());
2296 return operand_to_generator.at(operand)(source_index);
2297 };
2298 case HloOpcode::kBroadcast:
2299 return [this, hlo, &operand_to_generator](
2300 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
2301 const HloInstruction* operand = hlo->operand(0);
2302 // The `dimensions` member of the broadcast instruction maps from
2303 // input dimensions to output dimensions.
2304 return operand_to_generator.at(operand)(
2305 target_index.SourceIndexOfBroadcast(hlo->shape(), operand->shape(),
2306 hlo->dimensions(), b_));
2307 };
2308 case HloOpcode::kIota:
2309 return [this, hlo](
2310 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
2311 auto* iota = Cast<HloIotaInstruction>(hlo);
2312 PrimitiveType element_type = iota->shape().element_type();
2313 IrArray::Index elem_index =
2314 iota->shape().rank() > 1
2315 ? target_index.SourceIndexOfBroadcast(
2316 iota->shape(),
2317 ShapeUtil::MakeShapeWithDescendingLayout(
2318 element_type,
2319 {iota->shape().dimensions(iota->iota_dimension())}),
2320 {iota->iota_dimension()}, b_)
2321 : target_index;
2322 llvm::Value* elem_index_linear = elem_index.linear();
2323 if (elem_index_linear == nullptr) {
2324 std::vector<int64> iota_bound = {
2325 iota->shape().dimensions(iota->iota_dimension())};
2326 elem_index_linear = elem_index.Linearize(iota_bound, b_);
2327 }
2328 Shape component_shape =
2329 ShapeUtil::ElementIsComplex(iota->shape())
2330 ? ShapeUtil::ComplexComponentShape(iota->shape())
2331 : iota->shape();
2332 PrimitiveType component_element_type = component_shape.element_type();
2333 llvm::Value* iota_result;
2334 if (primitive_util::IsIntegralType(component_element_type) ||
2335 component_element_type == PRED) {
2336 iota_result = b_->CreateIntCast(
2337 elem_index_linear,
2338 llvm_ir::PrimitiveTypeToIrType(component_element_type, module_),
2339 /*isSigned=*/false);
2340 } else {
2341 TF_RET_CHECK(
2342 primitive_util::IsFloatingPointType(component_element_type))
2343 << component_element_type;
2344 llvm::Type* float_ir_type;
2345 if (component_element_type == BF16) {
2346 float_ir_type = llvm_ir::PrimitiveTypeToIrType(F32, module_);
2347 } else {
2348 float_ir_type =
2349 llvm_ir::PrimitiveTypeToIrType(component_element_type, module_);
2350 }
2351 llvm::Value* float_val =
2352 b_->CreateUIToFP(elem_index_linear, float_ir_type);
2353 if (component_element_type == BF16) {
2354 iota_result = EmitF32ToBF16(float_val, b_);
2355 } else {
2356 iota_result = float_val;
2357 }
2358 }
2359 if (ShapeUtil::ElementIsComplex(iota->shape())) {
2360 return EmitComposeComplex(iota, iota_result, nullptr);
2361 } else {
2362 return iota_result;
2363 }
2364 };
2365 case HloOpcode::kSlice:
2366 return [this, hlo, &operand_to_generator](
2367 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2368 IrArray::Index sliced_index = index.SourceIndexOfSlice(
2369 /*operand_shape=*/hlo->operand(0)->shape(),
2370 /*starts=*/hlo->slice_starts(),
2371 /*strides=*/hlo->slice_strides(), /*builder=*/b_);
2372 return operand_to_generator.at(hlo->operand(0))(sliced_index);
2373 };
2374 case HloOpcode::kDynamicSlice:
2375 return [this, hlo, &operand_to_generator](
2376 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2377 return EmitElementalDynamicSlice(hlo, operand_to_generator, index);
2378 };
2379
2380 case HloOpcode::kGather:
2381 return [this, hlo, &operand_to_generator](
2382 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2383 return EmitElementalGather(hlo, operand_to_generator, index);
2384 };
2385 case HloOpcode::kDynamicUpdateSlice:
2386 return [this, hlo, &operand_to_generator](
2387 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2388 return EmitElementalDynamicUpdateSlice(hlo, operand_to_generator,
2389 index);
2390 };
2391 case HloOpcode::kBitcast:
2392 CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()),
2393 ShapeUtil::ElementsIn(hlo->operand(0)->shape()));
2394 return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
2395 const HloInstruction* operand = hlo->operand(0);
2396 return operand_to_generator.at(operand)(
2397 index.SourceIndexOfBitcast(hlo->shape(), operand->shape(), b_));
2398 };
2399 case HloOpcode::kReshape:
2400 CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()),
2401 ShapeUtil::ElementsIn(hlo->operand(0)->shape()));
2402 return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
2403 const HloInstruction* operand = hlo->operand(0);
2404 return operand_to_generator.at(operand)(
2405 index.SourceIndexOfReshape(hlo->shape(), operand->shape(), b_));
2406 };
2407 case HloOpcode::kCopy:
2408 return [hlo, &operand_to_generator](
2409 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
2410 IrArray::Index source_index(target_index.multidim(),
2411 hlo->operand(0)->shape(),
2412 target_index.GetType());
2413 TF_ASSIGN_OR_RETURN(
2414 llvm::Value * operand_value,
2415 operand_to_generator.at(hlo->operand(0))(source_index));
2416 return operand_value;
2417 };
2418 case HloOpcode::kTranspose:
2419 return [this, hlo,
2420 &operand_to_generator](const IrArray::Index& target_index) {
2421 return operand_to_generator.at(hlo->operand(0))(
2422 target_index.SourceIndexOfTranspose(
2423 hlo->shape(), hlo->operand(0)->shape(), hlo->dimensions(), b_));
2424 };
2425 case HloOpcode::kRng:
2426 return MakePhiloxRngElementGenerator(hlo, operand_to_generator);
2427 case HloOpcode::kPad:
2428 return [this, hlo, &operand_to_generator](
2429 const IrArray::Index& padded_index) -> StatusOr<llvm::Value*> {
2430 return EmitElementalPad(hlo, operand_to_generator, padded_index);
2431 };
2432
2433 case HloOpcode::kDot:
2434 return [this, hlo,
2435 &operand_to_generator](const IrArray::Index& dot_result_index)
2436 -> StatusOr<llvm::Value*> {
2437 return EmitElementalDot(hlo, operand_to_generator, dot_result_index);
2438 };
2439 case HloOpcode::kReplicaId:
2440 return [this, hlo](const IrArray::Index&) -> StatusOr<llvm::Value*> {
2441 if (hlo_module_config_.replica_count() != 1) {
2442 return Unimplemented("Replication is not implemented on CPU/GPU.");
2443 }
2444 llvm::Type* type = llvm_ir::PrimitiveTypeToIrType(
2445 hlo->shape().element_type(), module_);
2446 return llvm::ConstantInt::getNullValue(type);
2447 };
2448 default:
2449 return [hlo](const IrArray::Index& index) {
2450 return Unimplemented("Unhandled opcode for elemental IR emission: %s",
2451 HloOpcodeString(hlo->opcode()));
2452 };
2453 }
2454 }
2455
2456 llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) {
2457 return ExtractValue(value, {0});
2458 }
2459
2460 llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) {
2461 return ExtractValue(value, {1});
2462 }
2463
2464 llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op,
2465 llvm::Value* real,
2466 llvm::Value* imag) {
2467 auto cplx_type =
2468 llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
2469 auto complex =
2470 InsertValue(llvm::ConstantAggregateZero::get(cplx_type), real, {0});
2471 if (imag != nullptr) {
2472 complex = InsertValue(complex, imag, {1});
2473 }
2474 return complex;
2475 }
2476
2477 } // namespace xla
2478