/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"

#include <algorithm>
#include <limits>
#include <map>
#include <memory>
#include <vector>

#include "absl/base/casts.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Operator.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace llvm_ir {

namespace {

// Note, this function is only useful in an insertion context; in a global
// (e.g. constants) context it will CHECK fail.
llvm::Module* ModuleFromIRBuilder(llvm::IRBuilder<>* b) {
  auto block = CHECK_NOTNULL(b->GetInsertBlock());
  auto fn = CHECK_NOTNULL(block->getParent());
  auto module = CHECK_NOTNULL(fn->getParent());
  return module;
}

}  // namespace

std::unique_ptr<llvm::Module> DropConstantInitializers(
    const llvm::Module& module) {
  std::unique_ptr<llvm::Module> cloned_module = CloneModule(module);
  for (llvm::GlobalVariable& global_var : cloned_module->globals()) {
    global_var.setInitializer(nullptr);
    global_var.setLinkage(llvm::GlobalValue::LinkageTypes::ExternalLinkage);
  }
  return cloned_module;
}

string DumpModuleToString(const llvm::Module& module) {
  std::string buffer_string;
  llvm::raw_string_ostream ostream(buffer_string);
  module.print(ostream, nullptr);
  ostream.flush();
  return buffer_string;
}

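// Convenience wrapper for emitting a call to an LLVM intrinsic. As a
// hypothetical example (not from this file), emitting llvm.maxnum.f32 for two
// float values x and y could look like:
//   EmitCallToIntrinsic(llvm::Intrinsic::maxnum, {x, y}, {b->getFloatTy()}, b);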
llvm::CallInst* EmitCallToIntrinsic(
    llvm::Intrinsic::ID intrinsic_id, absl::Span<llvm::Value* const> operands,
    absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b) {
  llvm::Module* module = ModuleFromIRBuilder(b);
  llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(
      module, intrinsic_id, AsArrayRef(overloaded_types));
  return b->CreateCall(intrinsic, AsArrayRef(operands));
}

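// For the NaN-aware path below, the select condition is "lhs >= rhs (ordered)
// OR lhs is NaN", so a NaN on either side propagates: max(NaN, 1.0) picks the
// NaN lhs directly, and max(1.0, NaN) picks rhs because the ordered compare
// against NaN is false. EmitFloatMin below uses the mirrored condition.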
llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
                          llvm::IRBuilder<>* b) {
  if (b->getFastMathFlags().noNaNs()) {
    auto cmp = b->CreateFCmpUGE(lhs_value, rhs_value);
    return b->CreateSelect(cmp, lhs_value, rhs_value);
  } else {
    auto cmp_ge = b->CreateFCmpOGE(lhs_value, rhs_value);
    auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
    auto sel_lhs = b->CreateOr(cmp_ge, lhs_is_nan);
    return b->CreateSelect(sel_lhs, lhs_value, rhs_value);
  }
}

llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
                          llvm::IRBuilder<>* b) {
  if (b->getFastMathFlags().noNaNs()) {
    auto cmp = b->CreateFCmpULE(lhs_value, rhs_value);
    return b->CreateSelect(cmp, lhs_value, rhs_value);
  } else {
    auto cmp_le = b->CreateFCmpOLE(lhs_value, rhs_value);
    auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
    auto sel_lhs = b->CreateOr(cmp_le, lhs_is_nan);
    return b->CreateSelect(sel_lhs, lhs_value, rhs_value);
  }
}

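// When the "array" operand is a GlobalVariable, its value is a pointer to the
// global's type (e.g. [N x float]*), so the GEP below needs a leading zero
// index to step through that pointer before indexing into the array
// ({0, index}); a plain buffer pointer is indexed with "index" directly.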
llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Value* index,
                                   llvm::IRBuilder<>* b) {
  llvm::Type* array_type = array->getType();
  CHECK(array_type->isPointerTy());
  llvm::PointerType* array_type_as_pointer =
      llvm::cast<llvm::PointerType>(array_type);
  VLOG(2) << "EmitBufferIndexingGEP with type="
          << llvm_ir::DumpToString(*array_type)
          << " array=" << llvm_ir::DumpToString(*array)
          << " index=" << llvm_ir::DumpToString(*index);

  return b->CreateInBoundsGEP(
      array_type_as_pointer->getElementType(), array,
      llvm::isa<llvm::GlobalVariable>(array)
          ? llvm::ArrayRef<llvm::Value*>({b->getInt64(0), index})
          : index);
}

llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index,
                                   llvm::IRBuilder<>* b) {
  return EmitBufferIndexingGEP(array, b->getInt64(index), b);
}

llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
                                  llvm::Module* module) {
  switch (element_type) {
    case PRED:
    case S8:
    case U8:
      return llvm::Type::getInt8Ty(module->getContext());
    case S16:
    case U16:
    case BF16:
      // For BF16 we just need some type that is 16 bits wide so that it will
      // take up the right amount of space in memory. LLVM does not have a
      // BF16 type (the LLVM half type is IEEE 16 bit floating point, not
      // bfloat), so we can't map it directly to an LLVM type. We will not map
      // a BF16 addition to an addition on this type (int16) - this is just
      // the type used for storage.
      return llvm::Type::getInt16Ty(module->getContext());
    case F16:
      return llvm::Type::getHalfTy(module->getContext());
    case S32:
    case U32:
      return llvm::Type::getInt32Ty(module->getContext());
    case S64:
    case U64:
      return llvm::Type::getInt64Ty(module->getContext());
    case F32:
      return llvm::Type::getFloatTy(module->getContext());
    case F64:
      return llvm::Type::getDoubleTy(module->getContext());
    case C64: {
      auto cplx_t = module->getTypeByName("complex64");
      if (cplx_t == nullptr) {
        // The C++ standard dictates that the memory layout of std::complex is
        // contiguous real followed by imaginary. C++11 section 26.4
        // [complex.numbers]:
        // If z is an lvalue expression of type cv std::complex<T> then the
        // expression reinterpret_cast<cv T(&)[2]>(z) shall be well-formed,
        // reinterpret_cast<cv T(&)[2]>(z)[0] shall designate the real part of
        // z, and reinterpret_cast<cv T(&)[2]>(z)[1] shall designate the
        // imaginary part of z.
        return llvm::StructType::create(
            {llvm::Type::getFloatTy(module->getContext()),
             llvm::Type::getFloatTy(module->getContext())},
            "complex64", /*isPacked=*/true);
      }
      return cplx_t;
    }
    case C128: {
      auto cplx_t = module->getTypeByName("complex128");
      if (cplx_t == nullptr) {
        return llvm::StructType::create(
            {llvm::Type::getDoubleTy(module->getContext()),
             llvm::Type::getDoubleTy(module->getContext())},
            "complex128", /*isPacked=*/true);
      }
      return cplx_t;
    }
    // A Tuple contains an array of pointers. Use i8*.
    case TUPLE:
    // An Opaque is like a void*, use i8*.
    case OPAQUE:
      return llvm::Type::getInt8PtrTy(module->getContext());
    case TOKEN:
      // Tokens do not have a physical representation, but the compiler needs
      // some placeholder type, so use int8*.
      return llvm::Type::getInt8PtrTy(module->getContext());
    default:
      LOG(FATAL) << "unsupported type " << element_type;
  }
}

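// Returns the size of the given IR type in bits; for a packed struct this is
// the sum of its element sizes, e.g. the "complex64" struct of two floats is
// 32 + 32 = 64 bits.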
int GetSizeInBits(llvm::Type* type) {
  const llvm::StructType* struct_ty = llvm::dyn_cast<llvm::StructType>(type);
  if (struct_ty) {
    CHECK(struct_ty->isPacked());
    int bits = 0;
    for (auto element_type : struct_ty->elements()) {
      bits += GetSizeInBits(element_type);
    }
    return bits;
  }
  int bits = type->getPrimitiveSizeInBits();
  CHECK_GT(bits, 0) << "type is not sized";
  return bits;
}

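// Worked example: an array shape f32[2,3] with minor-to-major layout {1,0}
// wraps the element type once per dimension, minor first, producing
// [2 x [3 x float]]; the minor-most dimension (size 3) becomes the innermost
// LLVM array.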
llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) {
  llvm::Type* result_type = PrimitiveTypeToIrType(shape.element_type(), module);
  if (shape.IsTuple()) {
    // A tuple buffer is an array of pointers.
    result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes_size());
  } else if (shape.IsArray()) {
    for (int64 dimension : LayoutUtil::MinorToMajor(shape)) {
      result_type =
          llvm::ArrayType::get(result_type, shape.dimensions(dimension));
    }
  }
  return result_type;
}

StatusOr<llvm::Value*> EncodeSelfDescribingShapeConstant(const Shape& shape,
                                                         int32* shape_size,
                                                         llvm::IRBuilder<>* b) {
  string encoded_shape = shape.SerializeAsString();
  if (encoded_shape.size() > std::numeric_limits<int32>::max()) {
    return InternalError("Encoded shape size exceeded int32 size limit.");
  }
  *shape_size = static_cast<int32>(encoded_shape.size());
  return b->CreateGlobalStringPtr(encoded_shape);
}

StatusOr<Shape> DecodeSelfDescribingShapeConstant(const void* shape_ptr,
                                                  int32 size_bytes) {
  ShapeProto shape_proto;
  TF_RET_CHECK(shape_proto.ParseFromArray(shape_ptr, size_bytes));
  Shape shape(shape_proto);
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
  return std::move(shape);
}

llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
                                           llvm::Module* module) {
  const char* data = static_cast<const char*>(literal.untyped_data());
  CHECK_EQ(module->getDataLayout().isLittleEndian(),
           tensorflow::port::kLittleEndian);
  return llvm::ConstantDataArray::getString(
      module->getContext(), llvm::StringRef(data, literal.size_bytes()),
      /*AddNull=*/false);
}

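// The returned global lives in NVPTX address space 3 (shared memory); for a
// [64 x float] tile the emitted IR looks roughly like (illustrative):
//   @tile = private addrspace(3) global [64 x float] undef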
llvm::GlobalVariable* AllocateSharedMemoryTile(llvm::Module* module,
                                               llvm::Type* tile_type,
                                               absl::string_view name) {
  const int kNVPTXSharedMemoryAddrSpace = 3;
  return new llvm::GlobalVariable(
      *module, tile_type,
      /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage,
      llvm::UndefValue::get(tile_type), AsStringRef(name), nullptr,
      llvm::GlobalValue::NotThreadLocal, kNVPTXSharedMemoryAddrSpace);
}

llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,
                                            absl::string_view name,
                                            llvm::IRBuilder<>* b,
                                            int alignment) {
  return EmitAllocaAtFunctionEntryWithCount(type, nullptr, name, b, alignment);
}

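// Placing allocas in the entry block (rather than at the current insert
// point) is the usual LLVM idiom: entry-block allocas with constant sizes are
// treated as static stack slots, whereas an alloca inside a loop would grow
// the stack on every iteration and defeat mem2reg.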
llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type,
                                                     llvm::Value* element_count,
                                                     absl::string_view name,
                                                     llvm::IRBuilder<>* b,
                                                     int alignment) {
  llvm::IRBuilder<>::InsertPointGuard guard(*b);
  llvm::Function* function = b->GetInsertBlock()->getParent();
  b->SetInsertPoint(&function->getEntryBlock(),
                    function->getEntryBlock().getFirstInsertionPt());
  llvm::AllocaInst* alloca =
      b->CreateAlloca(type, element_count, AsStringRef(name));
  if (alignment != 0) {
    alloca->setAlignment(alignment);
  }
  return alloca;
}

llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before,
                                   absl::string_view name,
                                   llvm::IRBuilder<>* b) {
  return llvm::BasicBlock::Create(
      /*Context=*/b->getContext(),
      /*Name=*/AsStringRef(name),
      /*Parent=*/b->GetInsertBlock()->getParent(),
      /*InsertBefore=*/insert_before);
}

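// A sketch of the control flow EmitIfThenElse produces (with emit_else=true):
//
//   if_block:    ... br cond, true_block, false_block
//   true_block:  br after_block
//   false_block: br after_block
//   after_block:
//
// With emit_else=false, the false edge of the conditional branch goes
// directly to after_block. On return, the builder's insert point is at the
// start of after_block.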
LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name,
                          llvm::IRBuilder<>* b, bool emit_else) {
  llvm_ir::LlvmIfData if_data;
  if_data.if_block = b->GetInsertBlock();
  if_data.true_block =
      CreateBasicBlock(nullptr, absl::StrCat(name, "-true"), b);
  if_data.false_block =
      emit_else ? CreateBasicBlock(nullptr, absl::StrCat(name, "-false"), b)
                : nullptr;

  // Add a terminator to the if block, if necessary.
  if (if_data.if_block->getTerminator() == nullptr) {
    b->SetInsertPoint(if_data.if_block);
    if_data.after_block =
        CreateBasicBlock(nullptr, absl::StrCat(name, "-after"), b);
    b->CreateBr(if_data.after_block);
  } else {
    if_data.after_block = if_data.if_block->splitBasicBlock(
        b->GetInsertPoint(), absl::StrCat(name, "-after"));
  }

  // Our basic block should now end with an unconditional branch. Remove it;
  // we're going to replace it with a conditional branch.
  if_data.if_block->getTerminator()->eraseFromParent();

  b->SetInsertPoint(if_data.if_block);
  b->CreateCondBr(condition, if_data.true_block,
                  emit_else ? if_data.false_block : if_data.after_block);

  b->SetInsertPoint(if_data.true_block);
  b->CreateBr(if_data.after_block);

  if (emit_else) {
    b->SetInsertPoint(if_data.false_block);
    b->CreateBr(if_data.after_block);
  }

  b->SetInsertPoint(if_data.after_block,
                    if_data.after_block->getFirstInsertionPt());

  return if_data;
}

llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
                            llvm::Value* lhs_value, llvm::Value* rhs_value,
                            llvm::IRBuilder<>* b) {
  llvm::Value* comparison_result;
  if (lhs_value->getType()->isIntegerTy()) {
    comparison_result = b->CreateICmp(predicate, lhs_value, rhs_value);
  } else {
    comparison_result = b->CreateFCmp(predicate, lhs_value, rhs_value);
  }
  // comparison_result is i1, but the NVPTX codegen incorrectly lowers i1
  // arrays. So we extend it to i8 so that it's addressable.
  return b->CreateZExt(comparison_result, llvm_ir::PrimitiveTypeToIrType(
                                              PRED, ModuleFromIRBuilder(b)));
}

// Internal helper that is called from emitted code to log an int64 value with
// a tag.
static void LogS64(const char* tag, int64 value) {
  LOG(INFO) << tag << " (int64): " << value;
}

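// Emits a call to LogS64 above. Note that this bakes the host addresses of
// LogS64 and the tag string into the IR as integer constants, so the
// generated code is only valid for in-process (JIT) execution.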
void EmitLogging(const char* tag, llvm::Value* value, llvm::IRBuilder<>* b) {
  llvm::FunctionType* log_function_type = llvm::FunctionType::get(
      b->getVoidTy(), {b->getInt64Ty(), b->getInt64Ty()}, /*isVarArg=*/false);
  b->CreateCall(log_function_type,
                b->CreateIntToPtr(b->getInt64(absl::bit_cast<int64>(&LogS64)),
                                  log_function_type->getPointerTo()),
                {b->getInt64(absl::bit_cast<int64>(tag)), value});
}

void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment) {
  llvm::LLVMContext& context = load->getContext();
  llvm::Type* int64_ty = llvm::Type::getInt64Ty(context);
  llvm::Constant* alignment_constant =
      llvm::ConstantInt::get(int64_ty, alignment);
  llvm::MDBuilder metadata_builder(context);
  auto* alignment_metadata =
      metadata_builder.createConstant(alignment_constant);
  load->setMetadata(llvm::LLVMContext::MD_align,
                    llvm::MDNode::get(context, alignment_metadata));
}

void SetDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                       uint64_t dereferenceable_bytes) {
  llvm::LLVMContext& context = load->getContext();
  llvm::Type* int64_ty = llvm::Type::getInt64Ty(context);
  llvm::Constant* dereferenceable_bytes_constant =
      llvm::ConstantInt::get(int64_ty, dereferenceable_bytes);
  llvm::MDBuilder metadata_builder(context);
  auto* dereferenceable_bytes_metadata =
      metadata_builder.createConstant(dereferenceable_bytes_constant);
  load->setMetadata(llvm::LLVMContext::MD_dereferenceable,
                    llvm::MDNode::get(context, dereferenceable_bytes_metadata));
}

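// Attaches !range metadata, which tells LLVM the instruction's value lies in
// the half-open interval [lower, upper). In textual IR this appears as, e.g.:
//   %tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !0
//   !0 = !{i32 0, i32 1024}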
llvm::Instruction* AddRangeMetadata(int64 lower, int64 upper,
                                    llvm::Instruction* inst) {
  llvm::LLVMContext& context = inst->getParent()->getContext();
  llvm::IntegerType* i32 = llvm::Type::getInt32Ty(context);
  inst->setMetadata(
      llvm::LLVMContext::MD_range,
      llvm::MDNode::get(
          context,
          {llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, lower)),
           llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, upper))}));
  return inst;
}

string IrName(string a) {
  a.erase(std::remove(a.begin(), a.end(), '%'), a.end());
  return a;
}

string IrName(absl::string_view a, absl::string_view b) {
  if (!a.empty() && !b.empty()) {
    return IrName(absl::StrCat(a, ".", b));
  }
  return IrName(absl::StrCat(a, b));
}

string IrName(const HloInstruction* a, absl::string_view b) {
  return IrName(a->name(), b);
}

string SanitizeFunctionName(string function_name) {
  // The backend with the strictest requirements on function names is NVPTX,
  // so we sanitize to its requirements.
  //
  // A slightly stricter version of the NVPTX requirements is that names match
  // /[a-zA-Z_$][a-zA-Z0-9_$]*/, with the exception that the names "_" and "$"
  // are illegal.

  // Sanitize chars in function_name.
  std::transform(function_name.begin(), function_name.end(),
                 function_name.begin(), [](char c) {
                   if (('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') ||
                       ('0' <= c && c <= '9') || c == '_' || c == '$') {
                     return c;
                   }
                   return '_';
                 });

  // Ensure the name isn't empty.
  if (function_name.empty()) {
    function_name = "__unnamed";
  }

  // Ensure the name doesn't start with a number.
  if (!function_name.empty() && function_name[0] >= '0' &&
      function_name[0] <= '9') {
    function_name.insert(function_name.begin(), '_');
  }

  // Ensure the name isn't "_" or "$".
  if (function_name == "_" || function_name == "$") {
    function_name += '_';
  }

  return function_name;
}

void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
  builder->SetInsertPoint(blk, blk->getFirstInsertionPt());
}

void SetToLastInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
  if (llvm::Instruction* terminator = blk->getTerminator()) {
    builder->SetInsertPoint(terminator);
  } else {
    builder->SetInsertPoint(blk);
  }
}

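// CreateRor implements rotate-right as
//   (x << ((n - r) % n)) | (x >> (r % n))
// where n is the bit width. Reducing both shift amounts modulo n keeps them
// below the bit width (shifts by >= the width are poison in LLVM). For
// example, with n = 32: ror(0x000000FF, 8) == 0xFF000000.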
llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor,
                       llvm::IRBuilder<>* builder) {
  auto size = rotand->getType()->getPrimitiveSizeInBits();
  auto size_value = builder->getIntN(size, size);
  auto mod = [=](llvm::Value* x) { return builder->CreateURem(x, size_value); };
  return builder->CreateOr(
      builder->CreateShl(rotand, mod(builder->CreateSub(size_value, rotor))),
      builder->CreateLShr(rotand, mod(rotor)));
}

int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout) {
  unsigned pointer_size = data_layout.getPointerSize();
  return ShapeUtil::ByteSizeOf(shape, pointer_size);
}

llvm::FastMathFlags GetCpuFastMathFlags(const HloModuleConfig& module_config) {
  llvm::FastMathFlags flags;
  if (!module_config.debug_options().xla_cpu_enable_fast_math()) {
    return flags;
  }

  // Fast implies AllowReassoc, NoInfs, NoNaNs, NoSignedZeros, AllowReciprocal,
  // AllowContract, and ApproxFunc.
  flags.setFast();

  if (module_config.debug_options().xla_cpu_fast_math_honor_nans()) {
    flags.setNoNaNs(false);
  }

  if (module_config.debug_options().xla_cpu_fast_math_honor_infs()) {
    flags.setNoInfs(false);
  }

  return flags;
}

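// The merge rules below are the conservative choices for an instruction that
// stands in for two originals: !alias.scope lists are unioned (the merged
// access belongs to every scope either input belonged to), while !noalias
// lists are intersected (we may only keep no-alias claims that both inputs
// made).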
std::map<int, llvm::MDNode*> MergeMetadata(
    llvm::LLVMContext* context, const std::map<int, llvm::MDNode*>& a,
    const std::map<int, llvm::MDNode*>& b) {
  // We should extend this as needed to deal with other kinds of metadata like
  // !dereferenceable and !range.

  std::map<int, llvm::MDNode*> result;
  for (auto kind_md_pair : a) {
    if (kind_md_pair.first == llvm::LLVMContext::MD_alias_scope) {
      llvm::SmallVector<llvm::Metadata*, 8> union_of_scopes;
      llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set;
      for (const auto& scope_a : kind_md_pair.second->operands()) {
        scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get()));
        union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_a.get()));
      }
      auto it = b.find(kind_md_pair.first);
      if (it != b.end()) {
        for (const auto& scope_b : it->second->operands()) {
          if (!scope_set.count(llvm::cast<llvm::MDNode>(scope_b.get()))) {
            union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b.get()));
          }
        }
      }
      result[llvm::LLVMContext::MD_alias_scope] =
          llvm::MDNode::get(*context, union_of_scopes);
    } else if (kind_md_pair.first == llvm::LLVMContext::MD_noalias) {
      llvm::SmallVector<llvm::Metadata*, 8> intersection_of_scopes;
      llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set;
      for (const auto& scope_a : kind_md_pair.second->operands()) {
        scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get()));
      }
      auto it = b.find(kind_md_pair.first);
      if (it != b.end()) {
        for (const auto& scope_b : it->second->operands()) {
          if (scope_set.count(llvm::cast<llvm::MDNode>(scope_b))) {
            intersection_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b));
          }
        }
      }
      if (!intersection_of_scopes.empty()) {
        result[llvm::LLVMContext::MD_noalias] =
            llvm::MDNode::get(*context, intersection_of_scopes);
      }
    }
  }
  return result;
}

static Status CreateAndWriteStringToFile(const string& directory_name,
                                         const string& file_name,
                                         const string& text) {
  std::unique_ptr<tensorflow::WritableFile> f;
  TF_RETURN_IF_ERROR(
      tensorflow::Env::Default()->RecursivelyCreateDir(directory_name));
  TF_RETURN_IF_ERROR(
      tensorflow::Env::Default()->NewWritableFile(file_name, &f));
  TF_RETURN_IF_ERROR(f->Append(text));
  TF_RETURN_IF_ERROR(f->Close());
  return Status::OK();
}

void DumpIrIfEnabled(const HloModule& hlo_module,
                     const llvm::Module& llvm_module, bool optimized) {
  const auto& debug_opts = hlo_module.config().debug_options();
  if (!DumpingEnabledForHloModule(hlo_module)) {
    return;
  }
  // We can end up compiling different modules with the same name when using
  // XlaJitCompiledCpuFunction::Compile. Avoid overwriting IR files previously
  // dumped from the same process in such cases.
  string suffix = absl::StrCat("ir-", optimized ? "with" : "no", "-opt");
  DumpToFileInDirOrStdout(hlo_module, absl::StrCat(suffix, ".ll"),
                          DumpModuleToString(llvm_module));

  // For some models the embedded constants can be huge, so also dump the
  // module with the constants stripped to get IR that is easier to
  // manipulate. Skip this if we're dumping to stdout; there's no point in
  // duplicating everything when writing to the terminal.
  if (!DumpingToStdout(debug_opts)) {
    DumpToFileInDir(hlo_module, absl::StrCat(suffix, "-noconst.ll"),
                    DumpModuleToString(*DropConstantInitializers(llvm_module)));
  }
}

llvm::Function* CreateCpuFunction(llvm::FunctionType* function_type,
                                  llvm::GlobalValue::LinkageTypes linkage,
                                  const HloModuleConfig& module_config,
                                  absl::string_view name,
                                  llvm::Module* module) {
  llvm::Function* function =
      llvm::Function::Create(function_type, linkage, AsStringRef(name), module);
  function->setCallingConv(llvm::CallingConv::C);
  function->addFnAttr("no-frame-pointer-elim", "false");

  // Generate unwind information so that GDB can crawl through the stack
  // frames created by the JIT compiled code.
  function->setHasUWTable();

  if (module_config.debug_options().xla_cpu_enable_fast_math()) {
    function->addFnAttr("unsafe-fp-math", "true");
    function->addFnAttr("no-signed-zeros-fp-math", "true");

    if (!module_config.debug_options().xla_cpu_fast_math_honor_nans()) {
      function->addFnAttr("no-nans-fp-math", "true");
    }

    if (!module_config.debug_options().xla_cpu_fast_math_honor_infs()) {
      function->addFnAttr("no-infs-fp-math", "true");
    }
  }

  // Add the optimize attribute to the function if optimizing for size. This
  // controls internal behavior of some optimization passes (e.g. loop
  // unrolling).
  if (cpu::options::OptimizeForSizeRequested(module_config)) {
    function->addFnAttr(llvm::Attribute::OptimizeForSize);
  }

  return function;
}

void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) {
  auto options = config.debug_options().xla_backend_extra_options();
  if (!options.empty()) {
    std::vector<string> fake_argv_storage;
    fake_argv_storage.push_back("");
    for (const auto& it : options) {
      // Skip options the XLA backend itself consumes.
      if (!absl::StartsWith(it.first, "xla_")) {
        if (it.second.empty()) {
          fake_argv_storage.push_back(it.first);
        } else {
          fake_argv_storage.push_back(it.first + "=" + it.second);
        }
      }
    }

    VLOG(2) << "Passing argv to LLVM:";
    std::vector<const char*> fake_argv;
    for (const auto& s : fake_argv_storage) {
      fake_argv.push_back(s.c_str());
      VLOG(2) << s;
    }
    llvm::cl::ParseCommandLineOptions(fake_argv.size(), &fake_argv[0]);
  }
}

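// UMulLowHigh32 computes the full 64-bit product of two unsigned 32-bit
// values by widening to i64, then returns the (low, high) 32-bit halves.
// For example, 0xFFFFFFFF * 0xFFFFFFFF = 0xFFFFFFFE00000001, which splits
// into low = 0x00000001 and high = 0xFFFFFFFE.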
std::pair<llvm::Value*, llvm::Value*> UMulLowHigh32(llvm::IRBuilder<>* b,
                                                    llvm::Value* src0,
                                                    llvm::Value* src1) {
  CHECK_EQ(src0->getType()->getPrimitiveSizeInBits(), 32);
  CHECK_EQ(src1->getType()->getPrimitiveSizeInBits(), 32);
  llvm::Type* int64_ty = b->getInt64Ty();
  src0 = b->CreateZExt(src0, int64_ty);
  src1 = b->CreateZExt(src1, int64_ty);
  return SplitInt64ToInt32s(b, b->CreateMul(src0, src1));
}

std::pair<llvm::Value*, llvm::Value*> SplitInt64ToInt32s(
    llvm::IRBuilder<>* b, llvm::Value* value_64bits) {
  CHECK_EQ(value_64bits->getType()->getPrimitiveSizeInBits(), 64);
  llvm::Type* int32_ty = b->getInt32Ty();
  llvm::Value* low_32bits = b->CreateTrunc(value_64bits, int32_ty);
  llvm::Value* high_32bits =
      b->CreateTrunc(b->CreateLShr(value_64bits, 32), int32_ty);
  return std::make_pair(low_32bits, high_32bits);
}

llvm::GlobalVariable* GetOrCreateVariableForPhiloxRngState(
    llvm::Module* module, llvm::IRBuilder<>* b) {
  static const char* kPhiloxRngStateVariableName = "philox_rng_state";
  llvm::GlobalVariable* state_ptr =
      module->getNamedGlobal(kPhiloxRngStateVariableName);
  if (!state_ptr) {
    state_ptr = new llvm::GlobalVariable(
        /*M=*/*module,
        /*Ty=*/b->getInt64Ty(),
        /*isConstant=*/false,
        /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
        /*Initializer=*/b->getInt64(0),
        /*Name=*/kPhiloxRngStateVariableName);
  }
  return state_ptr;
}

void IncrementVariableForPhiloxRngState(int64 value, llvm::Module* module,
                                        llvm::IRBuilder<>* builder) {
  llvm::GlobalVariable* state_ptr =
      GetOrCreateVariableForPhiloxRngState(module, builder);
  llvm::Value* state_value_old = builder->CreateLoad(state_ptr, "load_state");
  // If the 64-bit value overflows, we use the wraparound value. This should
  // be fine in practice as we only add one to the value each time an RNG is
  // executed.
  llvm::Value* state_value_new = builder->CreateAdd(
      state_value_old, builder->getInt64(value), "inc_state");
  builder->CreateStore(state_value_new, state_ptr);
}

}  // namespace llvm_ir
}  // namespace xla