/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"

#include <algorithm>
#include <memory>
#include <vector>

#include "absl/base/casts.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Operator.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace llvm_ir {

namespace {

// Note that this function is only useful in an insertion context; in a global
// (e.g. constants) context it will CHECK-fail.
llvm::Module* ModuleFromIRBuilder(llvm::IRBuilder<>* b) {
  auto block = CHECK_NOTNULL(b->GetInsertBlock());
  auto fn = CHECK_NOTNULL(block->getParent());
  auto module = CHECK_NOTNULL(fn->getParent());
  return module;
}

}  // namespace

std::unique_ptr<llvm::Module> DropConstantInitializers(
    const llvm::Module& module) {
  std::unique_ptr<llvm::Module> cloned_module = CloneModule(module);
  for (llvm::GlobalVariable& global_var : cloned_module->globals()) {
    global_var.setInitializer(nullptr);
    global_var.setLinkage(llvm::GlobalValue::LinkageTypes::ExternalLinkage);
  }
  return cloned_module;
}

string DumpModuleToString(const llvm::Module& module) {
  std::string buffer_string;
  llvm::raw_string_ostream ostream(buffer_string);
  module.print(ostream, nullptr);
  ostream.flush();
  return buffer_string;
}

llvm::CallInst* EmitCallToIntrinsic(
    llvm::Intrinsic::ID intrinsic_id, absl::Span<llvm::Value* const> operands,
    absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b) {
  llvm::Module* module = ModuleFromIRBuilder(b);
  llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(
      module, intrinsic_id, AsArrayRef(overloaded_types));
  return b->CreateCall(intrinsic, AsArrayRef(operands));
}

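// For EmitFloatMax below (and its mirror image EmitFloatMin): when the
// no-NaNs fast-math flag is set, a single unordered compare-and-select is
// enough. Otherwise lhs is selected when lhs >= rhs (ordered) or when lhs is
// NaN, and rhs is selected otherwise, so a NaN in either operand propagates
// to the result.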
llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
                          llvm::IRBuilder<>* b) {
  if (b->getFastMathFlags().noNaNs()) {
    auto cmp = b->CreateFCmpUGE(lhs_value, rhs_value);
    return b->CreateSelect(cmp, lhs_value, rhs_value);
  } else {
    auto cmp_ge = b->CreateFCmpOGE(lhs_value, rhs_value);
    auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
    auto sel_lhs = b->CreateOr(cmp_ge, lhs_is_nan);
    return b->CreateSelect(sel_lhs, lhs_value, rhs_value);
  }
}

llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
                          llvm::IRBuilder<>* b) {
  if (b->getFastMathFlags().noNaNs()) {
    auto cmp = b->CreateFCmpULE(lhs_value, rhs_value);
    return b->CreateSelect(cmp, lhs_value, rhs_value);
  } else {
    auto cmp_le = b->CreateFCmpOLE(lhs_value, rhs_value);
    auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
    auto sel_lhs = b->CreateOr(cmp_le, lhs_is_nan);
    return b->CreateSelect(sel_lhs, lhs_value, rhs_value);
  }
}

llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Value* index,
                                   llvm::IRBuilder<>* b) {
  llvm::Type* array_type = array->getType();
  CHECK(array_type->isPointerTy());
  llvm::PointerType* array_type_as_pointer =
      llvm::cast<llvm::PointerType>(array_type);
  VLOG(2) << "EmitBufferIndexingGEP with type="
          << llvm_ir::DumpToString(*array_type)
          << " array=" << llvm_ir::DumpToString(*array)
          << " index=" << llvm_ir::DumpToString(*index);

  return b->CreateInBoundsGEP(
      array_type_as_pointer->getElementType(), array,
      llvm::isa<llvm::GlobalVariable>(array)
          ? llvm::ArrayRef<llvm::Value*>({b->getInt64(0), index})
          : index);
}

llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index,
                                   llvm::IRBuilder<>* b) {
  return EmitBufferIndexingGEP(array, b->getInt64(index), b);
}

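// Storage-type mapping used below: PRED/S8/U8 -> i8, S16/U16/BF16 -> i16,
// F16 -> half, S32/U32 -> i32, S64/U64 -> i64, F32 -> float, F64 -> double,
// C64/C128 -> packed {float, float} / {double, double} structs, and
// TUPLE/OPAQUE_TYPE/TOKEN -> i8*.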
llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
                                  llvm::Module* module) {
  switch (element_type) {
    case PRED:
    case S8:
    case U8:
      return llvm::Type::getInt8Ty(module->getContext());
    case S16:
    case U16:
    case BF16:
      // For BF16 we just need some type that is 16 bits wide so that it will
      // take up the right amount of space in memory. LLVM does not have a BF16
      // type (the LLVM half type is IEEE 16 bit floating point, not bfloat),
      // so we can't map it directly to an LLVM type. We will not map a BF16
      // addition to an addition on this type (int16) - this is just the type
      // used for storage.
      return llvm::Type::getInt16Ty(module->getContext());
    case F16:
      return llvm::Type::getHalfTy(module->getContext());
    case S32:
    case U32:
      return llvm::Type::getInt32Ty(module->getContext());
    case S64:
    case U64:
      return llvm::Type::getInt64Ty(module->getContext());
    case F32:
      return llvm::Type::getFloatTy(module->getContext());
    case F64:
      return llvm::Type::getDoubleTy(module->getContext());
    case C64: {
      auto cplx_t = module->getTypeByName("complex64");
      if (cplx_t == nullptr) {
        // The C++ standard dictates that the memory layout of std::complex is
        // contiguous real followed by imaginary. C++11 section 26.4
        // [complex.numbers]:
        // If z is an lvalue expression of type cv std::complex<T> then the
        // expression reinterpret_cast<cv T(&)[2]>(z) shall be well-formed,
        // reinterpret_cast<cv T(&)[2]>(z)[0] shall designate the real part of
        // z, and reinterpret_cast<cv T(&)[2]>(z)[1] shall designate the
        // imaginary part of z.
        return llvm::StructType::create(
            {llvm::Type::getFloatTy(module->getContext()),
             llvm::Type::getFloatTy(module->getContext())},
            "complex64", /*isPacked=*/true);
      }
      return cplx_t;
    }
    case C128: {
      auto cplx_t = module->getTypeByName("complex128");
      if (cplx_t == nullptr) {
        return llvm::StructType::create(
            {llvm::Type::getDoubleTy(module->getContext()),
             llvm::Type::getDoubleTy(module->getContext())},
            "complex128", /*isPacked=*/true);
      }
      return cplx_t;
    }
    // A Tuple contains an array of pointers. Use i8*.
    case TUPLE:
    // An Opaque is like a void*, use i8*.
    case OPAQUE_TYPE:
      return llvm::Type::getInt8PtrTy(module->getContext());
    case TOKEN:
      // Tokens do not have a physical representation, but the compiler needs
      // some placeholder type, so use int8*.
      return llvm::Type::getInt8PtrTy(module->getContext());
    default:
      LOG(FATAL) << "unsupported type " << element_type;
  }
}

int GetSizeInBits(llvm::Type* type) {
  const llvm::StructType* struct_ty = llvm::dyn_cast<llvm::StructType>(type);
  if (struct_ty) {
    CHECK(struct_ty->isPacked());
    int bits = 0;
    for (auto element_type : struct_ty->elements()) {
      bits += GetSizeInBits(element_type);
    }
    return bits;
  }
  int bits = type->getPrimitiveSizeInBits();
  CHECK_GT(bits, 0) << "type is not sized";
  return bits;
}

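// For array shapes this nests one LLVM array level per dimension, with the
// minor-most dimension innermost; e.g. F32[2,3] with minor-to-major layout
// {1,0} maps to [2 x [3 x float]]. A tuple of N elements maps to [N x i8*].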
llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) {
  llvm::Type* result_type = PrimitiveTypeToIrType(shape.element_type(), module);
  if (shape.IsTuple()) {
    // A tuple buffer is an array of pointers.
    result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes_size());
  } else if (shape.IsArray()) {
    for (int64 dimension : LayoutUtil::MinorToMajor(shape)) {
      result_type =
          llvm::ArrayType::get(result_type, shape.dimensions(dimension));
    }
  }
  return result_type;
}

StatusOr<llvm::Value*> EncodeSelfDescribingShapeConstant(const Shape& shape,
                                                         int32* shape_size,
                                                         llvm::IRBuilder<>* b) {
  string encoded_shape = shape.SerializeAsString();
  if (encoded_shape.size() > std::numeric_limits<int32>::max()) {
    return InternalError("Encoded shape size exceeded int32 size limit.");
  }
  *shape_size = static_cast<int32>(encoded_shape.size());
  return b->CreateGlobalStringPtr(encoded_shape);
}

llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
                                           llvm::Module* module) {
  const char* data = static_cast<const char*>(literal.untyped_data());
  CHECK_EQ(module->getDataLayout().isLittleEndian(),
           tensorflow::port::kLittleEndian);
  return llvm::ConstantDataArray::getString(
      module->getContext(), llvm::StringRef(data, literal.size_bytes()),
      /*AddNull=*/false);
}

llvm::GlobalVariable* AllocateSharedMemoryTile(llvm::Module* module,
                                               llvm::Type* tile_type,
                                               absl::string_view name) {
  // Both AMDGPU and NVPTX use the same address space for shared memory.
  const int kGPUSharedMemoryAddrSpace = 3;
  return new llvm::GlobalVariable(
      *module, tile_type,
      /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage,
      llvm::UndefValue::get(tile_type), AsStringRef(name), nullptr,
      llvm::GlobalValue::NotThreadLocal, kGPUSharedMemoryAddrSpace);
}

llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,
                                            absl::string_view name,
                                            llvm::IRBuilder<>* b,
                                            int alignment) {
  return EmitAllocaAtFunctionEntryWithCount(type, nullptr, name, b, alignment);
}

llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type,
                                                     llvm::Value* element_count,
                                                     absl::string_view name,
                                                     llvm::IRBuilder<>* b,
                                                     int alignment) {
  llvm::IRBuilder<>::InsertPointGuard guard(*b);
  llvm::Function* function = b->GetInsertBlock()->getParent();
  b->SetInsertPoint(&function->getEntryBlock(),
                    function->getEntryBlock().getFirstInsertionPt());
  llvm::AllocaInst* alloca =
      b->CreateAlloca(type, element_count, AsStringRef(name));
  if (alignment != 0) {
    alloca->setAlignment(llvm::MaybeAlign(alignment));
  }
  return alloca;
}

llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before,
                                   absl::string_view name,
                                   llvm::IRBuilder<>* b) {
  return llvm::BasicBlock::Create(
      /*Context=*/b->getContext(),
      /*Name=*/AsStringRef(name),
      /*Parent=*/b->GetInsertBlock()->getParent(),
      /*InsertBefore=*/insert_before);
}

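// EmitIfThenElse splits the block at the current insert point and wires the
// new blocks up as follows:
//   if_block:    ends in a conditional branch to true_block, and to
//                false_block (or straight to after_block when !emit_else)
//   true_block:  unconditionally branches to after_block
//   false_block: unconditionally branches to after_block (emit_else only)
//   after_block: the remainder of the original block
// On return the builder points at the start of after_block. A typical usage
// sketch: call EmitIfThenElse(cond, "guard", b, /*emit_else=*/false), then
// SetToFirstInsertPoint(if_data.true_block, b) and emit the guarded code
// there, ahead of the block's terminating branch.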
LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name,
                          llvm::IRBuilder<>* b, bool emit_else) {
  llvm_ir::LlvmIfData if_data;
  if_data.if_block = b->GetInsertBlock();
  if_data.true_block =
      CreateBasicBlock(nullptr, absl::StrCat(name, "-true"), b);
  if_data.false_block =
      emit_else ? CreateBasicBlock(nullptr, absl::StrCat(name, "-false"), b)
                : nullptr;

  // Add a terminator to the if block, if necessary.
  if (if_data.if_block->getTerminator() == nullptr) {
    b->SetInsertPoint(if_data.if_block);
    if_data.after_block =
        CreateBasicBlock(nullptr, absl::StrCat(name, "-after"), b);
    b->CreateBr(if_data.after_block);
  } else {
    if_data.after_block = if_data.if_block->splitBasicBlock(
        b->GetInsertPoint(), absl::StrCat(name, "-after"));
  }

  // Our basic block should now end with an unconditional branch. Remove it;
  // we're going to replace it with a conditional branch.
  if_data.if_block->getTerminator()->eraseFromParent();

  b->SetInsertPoint(if_data.if_block);
  b->CreateCondBr(condition, if_data.true_block,
                  emit_else ? if_data.false_block : if_data.after_block);

  b->SetInsertPoint(if_data.true_block);
  b->CreateBr(if_data.after_block);

  if (emit_else) {
    b->SetInsertPoint(if_data.false_block);
    b->CreateBr(if_data.after_block);
  }

  b->SetInsertPoint(if_data.after_block,
                    if_data.after_block->getFirstInsertionPt());

  return if_data;
}

llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
                            llvm::Value* lhs_value, llvm::Value* rhs_value,
                            llvm::IRBuilder<>* b) {
  llvm::Value* comparison_result;
  if (lhs_value->getType()->isIntegerTy()) {
    comparison_result = b->CreateICmp(predicate, lhs_value, rhs_value);
  } else {
    comparison_result = b->CreateFCmp(predicate, lhs_value, rhs_value);
  }
  // comparison_result is i1, but the NVPTX codegen incorrectly lowers i1
  // arrays. So we extend it to i8 so that it's addressable.
  return b->CreateZExt(comparison_result, llvm_ir::PrimitiveTypeToIrType(
                                              PRED, ModuleFromIRBuilder(b)));
}

// Internal helper that is called from emitted code to log an int64 value with
// a tag.
static void LogS64(const char* tag, int64 value) {
  LOG(INFO) << tag << " (int64): " << value;
}

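// EmitLogging embeds the host addresses of LogS64 and of the tag string into
// the emitted IR as integer constants and calls LogS64 through a function
// pointer, so it is a debugging aid that only works when the generated code
// runs in the same process that emitted it. For example,
// EmitLogging("iteration", iv, b) prints "iteration (int64): <value>" each
// time the emitted code executes.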
void EmitLogging(const char* tag, llvm::Value* value, llvm::IRBuilder<>* b) {
  llvm::FunctionType* log_function_type = llvm::FunctionType::get(
      b->getVoidTy(), {b->getInt64Ty(), b->getInt64Ty()}, /*isVarArg=*/false);
  b->CreateCall(log_function_type,
                b->CreateIntToPtr(b->getInt64(absl::bit_cast<int64>(&LogS64)),
                                  log_function_type->getPointerTo()),
                {b->getInt64(absl::bit_cast<int64>(tag)), value});
}

void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment) {
  llvm::LLVMContext& context = load->getContext();
  llvm::Type* int64_ty = llvm::Type::getInt64Ty(context);
  llvm::Constant* alignment_constant =
      llvm::ConstantInt::get(int64_ty, alignment);
  llvm::MDBuilder metadata_builder(context);
  auto* alignment_metadata =
      metadata_builder.createConstant(alignment_constant);
  load->setMetadata(llvm::LLVMContext::MD_align,
                    llvm::MDNode::get(context, alignment_metadata));
}

void SetDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                       uint64_t dereferenceable_bytes) {
  llvm::LLVMContext& context = load->getContext();
  llvm::Type* int64_ty = llvm::Type::getInt64Ty(context);
  llvm::Constant* dereferenceable_bytes_constant =
      llvm::ConstantInt::get(int64_ty, dereferenceable_bytes);
  llvm::MDBuilder metadata_builder(context);
  auto* dereferenceable_bytes_metadata =
      metadata_builder.createConstant(dereferenceable_bytes_constant);
  load->setMetadata(llvm::LLVMContext::MD_dereferenceable,
                    llvm::MDNode::get(context, dereferenceable_bytes_metadata));
}

llvm::Instruction* AddRangeMetadata(int64 lower, int64 upper,
                                    llvm::Instruction* inst) {
  llvm::LLVMContext& context = inst->getParent()->getContext();
  llvm::IntegerType* i32 = llvm::Type::getInt32Ty(context);
  inst->setMetadata(
      llvm::LLVMContext::MD_range,
      llvm::MDNode::get(
          context,
          {llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, lower)),
           llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, upper))}));
  return inst;
}

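// IrName strips '%' characters and joins non-empty pieces with '.', e.g.
// IrName("%fusion", "body") == "fusion.body".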
string IrName(string a) {
  a.erase(std::remove(a.begin(), a.end(), '%'), a.end());
  return a;
}

string IrName(absl::string_view a, absl::string_view b) {
  if (!a.empty() && !b.empty()) {
    return IrName(absl::StrCat(a, ".", b));
  }
  return IrName(absl::StrCat(a, b));
}

string IrName(const HloInstruction* a, absl::string_view b) {
  return IrName(a->name(), b);
}

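// For example, SanitizeFunctionName("while.body-1") yields "while_body_1";
// a name starting with a digit additionally gets a leading '_' prepended.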
string SanitizeFunctionName(string function_name) {
  // The backend with the strictest requirements on function names is NVPTX, so
  // we sanitize to its requirements.
  //
  // A slightly stricter version of the NVPTX requirements is that names match
  // /[a-zA-Z_$][a-zA-Z0-9_$]*/, with the exception that the names "_" and "$"
  // are illegal.

  // Sanitize chars in function_name.
  std::transform(function_name.begin(), function_name.end(),
                 function_name.begin(), [](char c) {
                   if (('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') ||
                       ('0' <= c && c <= '9') || c == '_' || c == '$') {
                     return c;
                   }
                   return '_';
                 });

  // Ensure the name isn't empty.
  if (function_name.empty()) {
    function_name = "__unnamed";
  }

  // Ensure the name doesn't start with a number.
  if (!function_name.empty() && function_name[0] >= '0' &&
      function_name[0] <= '9') {
    function_name.insert(function_name.begin(), '_');
  }

  // Ensure the name isn't "_" or "$".
  if (function_name == "_" || function_name == "$") {
    function_name += '_';
  }

  return function_name;
}

void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
  builder->SetInsertPoint(blk, blk->getFirstInsertionPt());
}

void SetToLastInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
  if (llvm::Instruction* terminator = blk->getTerminator()) {
    builder->SetInsertPoint(terminator);
  } else {
    builder->SetInsertPoint(blk);
  }
}

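// CreateRor emits a rotate-right of `rotand` by `rotor` bits as
//   (rotand << ((n - rotor) % n)) | (rotand >> (rotor % n))
// where n is the bit width. The URem keeps both shift amounts below n; a
// shift by the full bit width would yield a poison value in LLVM IR.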
llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor,
                       llvm::IRBuilder<>* builder) {
  auto size = rotand->getType()->getPrimitiveSizeInBits();
  auto size_value = builder->getIntN(size, size);
  auto mod = [=](llvm::Value* x) { return builder->CreateURem(x, size_value); };
  return builder->CreateOr(
      builder->CreateShl(rotand, mod(builder->CreateSub(size_value, rotor))),
      builder->CreateLShr(rotand, mod(rotor)));
}

int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout) {
  unsigned pointer_size = data_layout.getPointerSize();
  return ShapeUtil::ByteSizeOf(shape, pointer_size);
}

llvm::FastMathFlags GetCpuFastMathFlags(const HloModuleConfig& module_config) {
  llvm::FastMathFlags flags;
  const auto& options = module_config.debug_options();
  if (!options.xla_cpu_enable_fast_math()) {
    return flags;
  }
  // Fast implies AllowReassoc, NoInfs, NoNaNs, NoSignedZeros, AllowReciprocal,
  // AllowContract, and ApproxFunc.
  flags.setFast();
  flags.setNoNaNs(!options.xla_cpu_fast_math_honor_nans());
  flags.setNoInfs(!options.xla_cpu_fast_math_honor_infs());
  flags.setAllowReciprocal(!options.xla_cpu_fast_math_honor_division());
  flags.setApproxFunc(!options.xla_cpu_fast_math_honor_functions());
  return flags;
}

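// When merging instruction metadata below, !alias.scope is combined by taking
// the union of the two scope lists (the merged access may belong to either
// set of scopes), while !noalias is combined by taking the intersection (the
// merged access is only known not to alias scopes that both inputs list).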
std::map<int, llvm::MDNode*> MergeMetadata(
    llvm::LLVMContext* context, const std::map<int, llvm::MDNode*>& a,
    const std::map<int, llvm::MDNode*>& b) {
  // We should extend this as needed to deal with other kinds of metadata like
  // !dereferenceable and !range.

  std::map<int, llvm::MDNode*> result;
  for (auto kind_md_pair : a) {
    if (kind_md_pair.first == llvm::LLVMContext::MD_alias_scope) {
      llvm::SmallVector<llvm::Metadata*, 8> union_of_scopes;
      llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set;
      for (const auto& scope_a : kind_md_pair.second->operands()) {
        scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get()));
        union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_a.get()));
      }
      auto it = b.find(kind_md_pair.first);
      if (it != b.end()) {
        for (const auto& scope_b : it->second->operands()) {
          if (!scope_set.count(llvm::cast<llvm::MDNode>(scope_b.get()))) {
            union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b.get()));
          }
        }
      }
      result[llvm::LLVMContext::MD_alias_scope] =
          llvm::MDNode::get(*context, union_of_scopes);
    } else if (kind_md_pair.first == llvm::LLVMContext::MD_noalias) {
      llvm::SmallVector<llvm::Metadata*, 8> intersection_of_scopes;
      llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set;
      for (const auto& scope_a : kind_md_pair.second->operands()) {
        scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get()));
      }
      auto it = b.find(kind_md_pair.first);
      if (it != b.end()) {
        for (const auto& scope_b : it->second->operands()) {
          if (scope_set.count(llvm::cast<llvm::MDNode>(scope_b))) {
            intersection_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b));
          }
        }
      }
      if (!intersection_of_scopes.empty()) {
        result[llvm::LLVMContext::MD_noalias] =
            llvm::MDNode::get(*context, intersection_of_scopes);
      }
    }
  }
  return result;
}

static Status CreateAndWriteStringToFile(const string& directory_name,
                                         const string& file_name,
                                         const string& text) {
  std::unique_ptr<tensorflow::WritableFile> f;
  TF_RETURN_IF_ERROR(
      tensorflow::Env::Default()->RecursivelyCreateDir(directory_name));
  TF_RETURN_IF_ERROR(
      tensorflow::Env::Default()->NewWritableFile(file_name, &f));
  TF_RETURN_IF_ERROR(f->Append(text));
  TF_RETURN_IF_ERROR(f->Close());
  return Status::OK();
}

void DumpIrIfEnabled(const HloModule& hlo_module,
                     const llvm::Module& llvm_module, bool optimized) {
  const auto& debug_opts = hlo_module.config().debug_options();
  if (!DumpingEnabledForHloModule(hlo_module)) {
    return;
  }
  // We can end up compiling different modules with the same name when using
  // XlaJitCompiledCpuFunction::Compile. Avoid overwriting IR files previously
  // dumped from the same process in such cases.
  string suffix = absl::StrCat("ir-", optimized ? "with" : "no", "-opt");
  DumpToFileInDirOrStdout(hlo_module, "", absl::StrCat(suffix, ".ll"),
                          DumpModuleToString(llvm_module));

  // For some models the embedded constants can be huge, so also dump the
  // module with the constants stripped to get IR that is easier to manipulate.
  // Skip this if we're dumping to stdout; there's no point in duplicating
  // everything when writing to the terminal.
  if (!DumpingToStdout(debug_opts)) {
    DumpToFileInDir(hlo_module, "", absl::StrCat(suffix, "-noconst.ll"),
                    DumpModuleToString(*DropConstantInitializers(llvm_module)));
  }
}

llvm::Function* CreateCpuFunction(llvm::FunctionType* function_type,
                                  llvm::GlobalValue::LinkageTypes linkage,
                                  const HloModuleConfig& module_config,
                                  absl::string_view name,
                                  llvm::Module* module) {
  llvm::Function* function =
      llvm::Function::Create(function_type, linkage, AsStringRef(name), module);
  function->setCallingConv(llvm::CallingConv::C);
  function->addFnAttr("no-frame-pointer-elim", "false");

  // Generate unwind information so that GDB can crawl through the stack frames
  // created by the JIT compiled code.
  function->setHasUWTable();

  // TensorFlow always flushes denormals to zero, so let LLVM know that
  // flushing denormals is safe. This allows vectorization using ARM's NEON
  // instruction set.
  function->addFnAttr("denormal-fp-math", "preserve-sign");

  // Add the optimize attribute to the function if optimizing for size. This
  // controls internal behavior of some optimization passes (e.g. loop
  // unrolling).
  if (cpu::options::OptimizeForSizeRequested(module_config)) {
    function->addFnAttr(llvm::Attribute::OptimizeForSize);
  }

  return function;
}

void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) {
  auto options = config.debug_options().xla_backend_extra_options();
  if (!options.empty()) {
    std::vector<string> fake_argv_storage;
    fake_argv_storage.push_back("");
    for (const auto& it : options) {
      // Skip options the XLA backend itself consumes.
      if (!absl::StartsWith(it.first, "xla_")) {
        if (it.second.empty()) {
          fake_argv_storage.push_back(it.first);
        } else {
          fake_argv_storage.push_back(it.first + "=" + it.second);
        }
      }
    }

    VLOG(2) << "Passing argv to LLVM:";
    std::vector<const char*> fake_argv;
    for (const auto& s : fake_argv_storage) {
      fake_argv.push_back(s.c_str());
      VLOG(2) << s;
    }
    llvm::cl::ParseCommandLineOptions(fake_argv.size(), &fake_argv[0]);
  }
}

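// UMulLowHigh32 zero-extends both 32-bit operands to i64, multiplies, and
// returns the {low, high} halves of the 64-bit product; e.g.
// 0xFFFFFFFF * 2 = 0x1FFFFFFFE splits into low = 0xFFFFFFFE and high = 0x1.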
std::pair<llvm::Value*, llvm::Value*> UMulLowHigh32(llvm::IRBuilder<>* b,
                                                    llvm::Value* src0,
                                                    llvm::Value* src1) {
  CHECK_EQ(src0->getType()->getPrimitiveSizeInBits(), 32);
  CHECK_EQ(src1->getType()->getPrimitiveSizeInBits(), 32);
  llvm::Type* int64_ty = b->getInt64Ty();
  src0 = b->CreateZExt(src0, int64_ty);
  src1 = b->CreateZExt(src1, int64_ty);
  return SplitInt64ToInt32s(b, b->CreateMul(src0, src1));
}

std::pair<llvm::Value*, llvm::Value*> SplitInt64ToInt32s(
    llvm::IRBuilder<>* b, llvm::Value* value_64bits) {
  CHECK_EQ(value_64bits->getType()->getPrimitiveSizeInBits(), 64);
  llvm::Type* int32_ty = b->getInt32Ty();
  llvm::Value* low_32bits = b->CreateTrunc(value_64bits, int32_ty);
  llvm::Value* high_32bits =
      b->CreateTrunc(b->CreateLShr(value_64bits, 32), int32_ty);
  return std::make_pair(low_32bits, high_32bits);
}

unsigned GetGlobalMemoryAddressSpace(const llvm::Module& module) {
  const unsigned kAMDGPUGlobalMemoryAddrSpace = 1;
  llvm::Triple target_triple = llvm::Triple(module.getTargetTriple());
  if (target_triple.getArch() == llvm::Triple::amdgcn) {
    // AMDGPU uses address space 1 for global memory.
    return kAMDGPUGlobalMemoryAddrSpace;
  }
  return 0;
}

llvm::GlobalVariable* GetOrCreateVariableForRngState(llvm::Module* module,
                                                     llvm::IRBuilder<>* b) {
  static const char* kRngStateVariableName = "rng_state";
  llvm::GlobalVariable* state_ptr =
      module->getNamedGlobal(kRngStateVariableName);
  if (!state_ptr) {
    unsigned global_address_space = GetGlobalMemoryAddressSpace(*module);
    llvm::Type* state_type = b->getInt128Ty();
    // Use a non-zero initial value, since a zero state can cause the first
    // generated random number to fail the chi-square test. The value used
    // here is arbitrarily chosen; any non-zero value should be fine.
    state_ptr = new llvm::GlobalVariable(
        /*M=*/*module,
        /*Ty=*/state_type,
        /*isConstant=*/false,
        /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
        /*Initializer=*/llvm::ConstantInt::get(b->getInt128Ty(), 0x7012395ull),
        /*Name=*/kRngStateVariableName,
        /*InsertBefore=*/nullptr,
        /*TLMode=*/llvm::GlobalValue::NotThreadLocal,
        /*AddressSpace=*/global_address_space,
        /*isExternallyInitialized=*/false);
  }
  return state_ptr;
}

llvm::Value* RngGetAndUpdateState(uint64 delta, llvm::Module* module,
                                  llvm::IRBuilder<>* builder) {
  llvm::GlobalVariable* state_ptr =
      GetOrCreateVariableForRngState(module, builder);
  llvm::LoadInst* state_value_old =
      builder->CreateLoad(state_ptr, "load_state");
  llvm::Value* state_value_new = builder->CreateAdd(
      state_value_old,
      llvm::ConstantInt::get(state_value_old->getType(), delta));
  builder->CreateStore(state_value_new, state_ptr);
  return state_value_old;
}

}  // namespace llvm_ir
}  // namespace xla