/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"

#include <algorithm>
#include <memory>
#include <vector>

#include "absl/base/casts.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Operator.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace llvm_ir {

namespace {

// Note, this function is only useful in an insertion context; in a global
// (e.g. constants) context it will CHECK fail.
llvm::Module* ModuleFromIRBuilder(llvm::IRBuilder<>* b) {
  auto block = CHECK_NOTNULL(b->GetInsertBlock());
  auto fn = CHECK_NOTNULL(block->getParent());
  auto module = CHECK_NOTNULL(fn->getParent());
  return module;
}

}  // namespace

std::unique_ptr<llvm::Module> DropConstantInitializers(
    const llvm::Module& module) {
  std::unique_ptr<llvm::Module> cloned_module = CloneModule(module);
  for (llvm::GlobalVariable& global_var : cloned_module->globals()) {
    global_var.setInitializer(nullptr);
    global_var.setLinkage(llvm::GlobalValue::LinkageTypes::ExternalLinkage);
  }
  return cloned_module;
}

string DumpModuleToString(const llvm::Module& module) {
  std::string buffer_string;
  llvm::raw_string_ostream ostream(buffer_string);
  module.print(ostream, nullptr);
  ostream.flush();
  return buffer_string;
}

llvm::CallInst* EmitCallToIntrinsic(
    llvm::Intrinsic::ID intrinsic_id, absl::Span<llvm::Value* const> operands,
    absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b,
    absl::string_view name) {
  llvm::Module* module = ModuleFromIRBuilder(b);
  llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(
      module, intrinsic_id, AsArrayRef(overloaded_types));
  return b->CreateCall(intrinsic, AsArrayRef(operands), name.data());
}

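// Emits a floating-point maximum of lhs_value and rhs_value. When the
// builder's fast-math flags say no-NaNs (or enable_fast_min_max is set), this
// is a single unordered >= compare plus select; otherwise the select also
// picks lhs_value whenever lhs_value is NaN, so a NaN on the left propagates.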
llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
                          llvm::IRBuilder<>* b, bool enable_fast_min_max,
                          absl::string_view name) {
  if (b->getFastMathFlags().noNaNs() || enable_fast_min_max) {
    auto cmp = b->CreateFCmpUGE(lhs_value, rhs_value);
    return b->CreateSelect(cmp, lhs_value, rhs_value, name.data());
  } else {
    auto cmp_ge = b->CreateFCmpOGE(lhs_value, rhs_value);
    auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
    auto sel_lhs = b->CreateOr(cmp_ge, lhs_is_nan);
    return b->CreateSelect(sel_lhs, lhs_value, rhs_value, name.data());
  }
}

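// Emits a floating-point minimum of lhs_value and rhs_value, mirroring
// EmitFloatMax: the fast path is an unordered <= compare plus select; the
// NaN-aware path keeps lhs_value when it is NaN or ordered-<= rhs_value.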
llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
                          llvm::IRBuilder<>* b, bool enable_fast_min_max,
                          absl::string_view name) {
  if (b->getFastMathFlags().noNaNs() || enable_fast_min_max) {
    auto cmp = b->CreateFCmpULE(lhs_value, rhs_value);
    return b->CreateSelect(cmp, lhs_value, rhs_value, name.data());
  } else {
    auto cmp_le = b->CreateFCmpOLE(lhs_value, rhs_value);
    auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
    auto sel_lhs = b->CreateOr(cmp_le, lhs_is_nan);
    return b->CreateSelect(sel_lhs, lhs_value, rhs_value, name.data());
  }
}

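// Emits an inbounds GEP that indexes into the buffer `array` at `index`. For
// module-level globals the pointer operand's type is a pointer to the whole
// array, so a leading zero index is emitted to step into the array before
// indexing.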
llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Value* index,
                                   llvm::IRBuilder<>* b) {
  llvm::Type* array_type = array->getType();
  CHECK(array_type->isPointerTy());
  llvm::PointerType* array_type_as_pointer =
      llvm::cast<llvm::PointerType>(array_type);
  VLOG(2) << "EmitBufferIndexingGEP with type="
          << llvm_ir::DumpToString(*array_type)
          << " array=" << llvm_ir::DumpToString(*array)
          << " index=" << llvm_ir::DumpToString(*index);

  return b->CreateInBoundsGEP(
      array_type_as_pointer->getElementType(), array,
      llvm::isa<llvm::GlobalVariable>(array)
          ? llvm::ArrayRef<llvm::Value*>({b->getInt64(0), index})
          : index);
}

llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index,
                                   llvm::IRBuilder<>* b) {
  return EmitBufferIndexingGEP(array, b->getInt64(index), b);
}

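// Returns the LLVM IR type used to store a single element of the given XLA
// primitive type (e.g. PRED is stored as i8 and BF16 as i16); arithmetic on
// these types may be lowered differently from the storage type.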
llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
                                  llvm::Module* module) {
  switch (element_type) {
    case PRED:
    case S8:
    case U8:
      return llvm::Type::getInt8Ty(module->getContext());
    case S16:
    case U16:
    case BF16:
      // For BF16 we just need some type that is 16 bits wide so that it will
      // take up the right amount of space in memory. LLVM does not have a
      // BF16 type (the LLVM half type is IEEE 16 bit floating point, not
      // bfloat), so we can't map it directly to an LLVM type. We will not map
      // a BF16 addition to an addition on this type (int16) - this is just
      // the type used for storage.
      return llvm::Type::getInt16Ty(module->getContext());
    case F16:
      return llvm::Type::getHalfTy(module->getContext());
    case S32:
    case U32:
      return llvm::Type::getInt32Ty(module->getContext());
    case S64:
    case U64:
      return llvm::Type::getInt64Ty(module->getContext());
    case F32:
      return llvm::Type::getFloatTy(module->getContext());
    case F64:
      return llvm::Type::getDoubleTy(module->getContext());
    case C64: {
      auto cplx_t =
          llvm::StructType::getTypeByName(module->getContext(), "complex64");
      if (cplx_t == nullptr) {
        // C++ standard dictates the memory layout of std::complex is
        // contiguous real followed by imaginary. C++11 section 26.4
        // [complex.numbers]:
        //   If z is an lvalue expression of type cv std::complex<T> then the
        //   expression reinterpret_cast<cv T(&)[2]>(z) shall be well-formed,
        //   reinterpret_cast<cv T(&)[2]>(z)[0] shall designate the real part
        //   of z, and reinterpret_cast<cv T(&)[2]>(z)[1] shall designate the
        //   imaginary part of z.
        return llvm::StructType::create(
            {llvm::Type::getFloatTy(module->getContext()),
             llvm::Type::getFloatTy(module->getContext())},
            "complex64", /*isPacked=*/true);
      }
      return cplx_t;
    }
    case C128: {
      auto cplx_t =
          llvm::StructType::getTypeByName(module->getContext(), "complex128");
      if (cplx_t == nullptr) {
        return llvm::StructType::create(
            {llvm::Type::getDoubleTy(module->getContext()),
             llvm::Type::getDoubleTy(module->getContext())},
            "complex128", /*isPacked=*/true);
      }
      return cplx_t;
    }
    // A Tuple contains an array of pointers. Use i8*.
    case TUPLE:
    // An Opaque is like a void*, use i8*.
    case OPAQUE_TYPE:
      return llvm::Type::getInt8PtrTy(module->getContext());
    case TOKEN:
      // Tokens do not have a physical representation, but the compiler needs
      // some placeholder type, so use int8*.
      return llvm::Type::getInt8PtrTy(module->getContext());
    default:
      LOG(FATAL) << "unsupported type " << element_type;
  }
}

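// Returns the size of `type` in bits. Packed struct types (as created above
// for complex64/complex128) are summed element by element; any other unsized
// type trips the CHECK.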
int GetSizeInBits(llvm::Type* type) {
  const llvm::StructType* struct_ty = llvm::dyn_cast<llvm::StructType>(type);
  if (struct_ty) {
    CHECK(struct_ty->isPacked());
    int bits = 0;
    for (auto element_type : struct_ty->elements()) {
      bits += GetSizeInBits(element_type);
    }
    return bits;
  }
  int bits = type->getPrimitiveSizeInBits();
  CHECK_GT(bits, 0) << "type is not sized";
  return bits;
}

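// Returns the LLVM type for a whole XLA shape: tuples become an array of i8*
// (one per tuple element), and array shapes become nested LLVM arrays of the
// element type following the shape's minor-to-major layout.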
llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) {
  llvm::Type* result_type = PrimitiveTypeToIrType(shape.element_type(), module);
  if (shape.IsTuple()) {
    // A tuple buffer is an array of pointers.
    result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes_size());
  } else if (shape.IsArray()) {
    for (int64 dimension : LayoutUtil::MinorToMajor(shape)) {
      result_type =
          llvm::ArrayType::get(result_type, shape.dimensions(dimension));
    }
  }
  return result_type;
}

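// Serializes `shape` and emits it as a global string constant, returning a
// pointer to the bytes; *shape_size receives the serialized size, which must
// fit in an int32.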
StatusOr<llvm::Value*> EncodeSelfDescribingShapeConstant(const Shape& shape,
                                                         int32* shape_size,
                                                         llvm::IRBuilder<>* b) {
  string encoded_shape = shape.SerializeAsString();
  if (encoded_shape.size() > std::numeric_limits<int32>::max()) {
    return InternalError("Encoded shape size exceeded int32 size limit.");
  }
  *shape_size = static_cast<int32>(encoded_shape.size());
  return b->CreateGlobalStringPtr(encoded_shape);
}

llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
                                           llvm::Module* module) {
  const char* data = static_cast<const char*>(literal.untyped_data());
  CHECK_EQ(module->getDataLayout().isLittleEndian(),
           tensorflow::port::kLittleEndian);
  return llvm::ConstantDataArray::getString(
      module->getContext(), llvm::StringRef(data, literal.size_bytes()),
      /*AddNull=*/false);
}

llvm::GlobalVariable* AllocateSharedMemoryTile(llvm::Module* module,
                                               llvm::Type* tile_type,
                                               absl::string_view name) {
  // Both AMDGPU and NVPTX use the same address space for shared memory.
  const int kGPUSharedMemoryAddrSpace = 3;
  return new llvm::GlobalVariable(
      *module, tile_type,
      /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage,
      llvm::UndefValue::get(tile_type), AsStringRef(name), nullptr,
      llvm::GlobalValue::NotThreadLocal, kGPUSharedMemoryAddrSpace);
}

llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,
                                            absl::string_view name,
                                            llvm::IRBuilder<>* b,
                                            int alignment) {
  return EmitAllocaAtFunctionEntryWithCount(type, nullptr, name, b, alignment);
}

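// Creates the alloca at the first insertion point of the current function's
// entry block, restoring the caller's insert point afterwards via an
// InsertPointGuard, so the stack slot is not re-allocated inside whatever
// loop the caller is currently emitting. A non-zero `alignment` is applied to
// the alloca.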
llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type,
                                                     llvm::Value* element_count,
                                                     absl::string_view name,
                                                     llvm::IRBuilder<>* b,
                                                     int alignment) {
  llvm::IRBuilder<>::InsertPointGuard guard(*b);
  llvm::Function* function = b->GetInsertBlock()->getParent();
  b->SetInsertPoint(&function->getEntryBlock(),
                    function->getEntryBlock().getFirstInsertionPt());
  llvm::AllocaInst* alloca =
      b->CreateAlloca(type, element_count, AsStringRef(name));
  if (alignment != 0) {
    alloca->setAlignment(llvm::Align(alignment));
  }
  return alloca;
}

llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before,
                                   absl::string_view name,
                                   llvm::IRBuilder<>* b) {
  return llvm::BasicBlock::Create(
      /*Context=*/b->getContext(),
      /*Name=*/AsStringRef(name),
      /*Parent=*/b->GetInsertBlock()->getParent(),
      /*InsertBefore*/ insert_before);
}

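// Emits the control flow for an if (condition) { ... } [else { ... }]
// construct: the current block branches conditionally to a "-true" block (and
// a "-false" block when emit_else is set), both of which branch to a new
// "-after" block. On return the builder is positioned at the start of the
// "-after" block, and the returned LlvmIfData lets the caller emit code into
// the true/false blocks.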
LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name,
                          llvm::IRBuilder<>* b, bool emit_else) {
  llvm_ir::LlvmIfData if_data;
  if_data.if_block = b->GetInsertBlock();
  if_data.true_block =
      CreateBasicBlock(nullptr, absl::StrCat(name, "-true"), b);
  if_data.false_block =
      emit_else ? CreateBasicBlock(nullptr, absl::StrCat(name, "-false"), b)
                : nullptr;

  // Add a terminator to the if block, if necessary.
  if (if_data.if_block->getTerminator() == nullptr) {
    b->SetInsertPoint(if_data.if_block);
    if_data.after_block =
        CreateBasicBlock(nullptr, absl::StrCat(name, "-after"), b);
    b->CreateBr(if_data.after_block);
  } else {
    if_data.after_block = if_data.if_block->splitBasicBlock(
        b->GetInsertPoint(), absl::StrCat(name, "-after"));
  }

  // Our basic block should now end with an unconditional branch. Remove it;
  // we're going to replace it with a conditional branch.
  if_data.if_block->getTerminator()->eraseFromParent();

  b->SetInsertPoint(if_data.if_block);
  b->CreateCondBr(condition, if_data.true_block,
                  emit_else ? if_data.false_block : if_data.after_block);

  b->SetInsertPoint(if_data.true_block);
  b->CreateBr(if_data.after_block);

  if (emit_else) {
    b->SetInsertPoint(if_data.false_block);
    b->CreateBr(if_data.after_block);
  }

  b->SetInsertPoint(if_data.after_block,
                    if_data.after_block->getFirstInsertionPt());

  return if_data;
}

llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
                            llvm::Value* lhs_value, llvm::Value* rhs_value,
                            llvm::IRBuilder<>* b, absl::string_view name) {
  llvm::Value* comparison_result;
  if (lhs_value->getType()->isIntegerTy()) {
    comparison_result =
        b->CreateICmp(predicate, lhs_value, rhs_value, name.data());
  } else {
    comparison_result =
        b->CreateFCmp(predicate, lhs_value, rhs_value, name.data());
  }
  // comparison_result is i1, but the NVPTX codegen incorrectly lowers i1
  // arrays. So we extend it to i8 so that it's addressable.
  return b->CreateZExt(comparison_result, llvm_ir::PrimitiveTypeToIrType(
                                              PRED, ModuleFromIRBuilder(b)));
}

// Internal helper that is called from emitted code to log an int64 value with
// a tag.
static void LogS64(const char* tag, int64 value) {
  LOG(INFO) << tag << " (int64): " << value;
}

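// Emits a runtime call to LogS64 by materializing the helper's host address
// as an i64 constant and casting it to a function pointer, which assumes the
// emitted code runs in the same process that compiled it (e.g. under the
// JIT).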
void EmitLogging(const char* tag, llvm::Value* value, llvm::IRBuilder<>* b) {
  llvm::FunctionType* log_function_type = llvm::FunctionType::get(
      b->getVoidTy(), {b->getInt64Ty(), b->getInt64Ty()}, /*isVarArg=*/false);
  b->CreateCall(log_function_type,
                b->CreateIntToPtr(b->getInt64(absl::bit_cast<int64>(&LogS64)),
                                  log_function_type->getPointerTo()),
                {b->getInt64(absl::bit_cast<int64>(tag)), value});
}

void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment) {
  llvm::LLVMContext& context = load->getContext();
  llvm::Type* int64_ty = llvm::Type::getInt64Ty(context);
  llvm::Constant* alignment_constant =
      llvm::ConstantInt::get(int64_ty, alignment);
  llvm::MDBuilder metadata_builder(context);
  auto* alignment_metadata =
      metadata_builder.createConstant(alignment_constant);
  load->setMetadata(llvm::LLVMContext::MD_align,
                    llvm::MDNode::get(context, alignment_metadata));
}

void SetDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                       uint64_t dereferenceable_bytes) {
  llvm::LLVMContext& context = load->getContext();
  llvm::Type* int64_ty = llvm::Type::getInt64Ty(context);
  llvm::Constant* dereferenceable_bytes_constant =
      llvm::ConstantInt::get(int64_ty, dereferenceable_bytes);
  llvm::MDBuilder metadata_builder(context);
  auto* dereferenceable_bytes_metadata =
      metadata_builder.createConstant(dereferenceable_bytes_constant);
  load->setMetadata(llvm::LLVMContext::MD_dereferenceable,
                    llvm::MDNode::get(context, dereferenceable_bytes_metadata));
}

llvm::Instruction* AddRangeMetadata(int64 lower, int64 upper,
                                    llvm::Instruction* inst) {
  llvm::LLVMContext& context = inst->getParent()->getContext();
  llvm::IntegerType* i32 = llvm::Type::getInt32Ty(context);
  inst->setMetadata(
      llvm::LLVMContext::MD_range,
      llvm::MDNode::get(
          context,
          {llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, lower)),
           llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, upper))}));
  return inst;
}

string IrName(absl::string_view a) {
  std::string s(a);
  s.erase(std::remove(s.begin(), s.end(), '%'), s.end());
  return s;
}

string IrName(absl::string_view a, absl::string_view b) {
  if (!a.empty() && !b.empty()) {
    return IrName(absl::StrCat(a, ".", b));
  }
  return IrName(absl::StrCat(a, b));
}

string IrName(const HloInstruction* a, absl::string_view b) {
  return IrName(a->name(), b);
}

string SanitizeFunctionName(string function_name) {
  // The backend with the strictest requirements on function names is NVPTX,
  // so we sanitize to its requirements.
  //
  // A slightly stricter version of the NVPTX requirements is that names match
  // /[a-zA-Z_$][a-zA-Z0-9_$]*/, with the exception that the names "_" and "$"
  // are illegal.

  // Sanitize chars in function_name.
  std::transform(function_name.begin(), function_name.end(),
                 function_name.begin(), [](char c) {
                   if (('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') ||
                       ('0' <= c && c <= '9') || c == '_' || c == '$') {
                     return c;
                   }
                   return '_';
                 });

  // Ensure the name isn't empty.
  if (function_name.empty()) {
    function_name = "__unnamed";
  }

  // Ensure the name doesn't start with a number.
  if (!function_name.empty() && function_name[0] >= '0' &&
      function_name[0] <= '9') {
    function_name.insert(function_name.begin(), '_');
  }

  // Ensure the name isn't "_" or "$".
  if (function_name == "_" || function_name == "$") {
    function_name += '_';
  }

  return function_name;
}

void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
  builder->SetInsertPoint(blk, blk->getFirstInsertionPt());
}

void SetToLastInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
  if (llvm::Instruction* terminator = blk->getTerminator()) {
    builder->SetInsertPoint(terminator);
  } else {
    builder->SetInsertPoint(blk);
  }
}

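// Emits a rotate-right of `rotand` by `rotor` bits as
//   (rotand << ((size - rotor) mod size)) | (rotand >> (rotor mod size))
// where size is the bit width of rotand's type; the mods keep both shift
// amounts in range.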
llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor,
                       llvm::IRBuilder<>* builder) {
  auto size = rotand->getType()->getPrimitiveSizeInBits();
  auto size_value = builder->getIntN(size, size);
  auto mod = [=](llvm::Value* x) { return builder->CreateURem(x, size_value); };
  return builder->CreateOr(
      builder->CreateShl(rotand, mod(builder->CreateSub(size_value, rotor))),
      builder->CreateLShr(rotand, mod(rotor)));
}

int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout) {
  unsigned pointer_size = data_layout.getPointerSize();
  return ShapeUtil::ByteSizeOf(shape, pointer_size);
}

llvm::FastMathFlags GetCpuFastMathFlags(const HloModuleConfig& module_config) {
  llvm::FastMathFlags flags;
  const auto& options = module_config.debug_options();
  if (!options.xla_cpu_enable_fast_math()) {
    return flags;
  }
  // Fast implies AllowReassoc, NoInfs, NoNaNs, NoSignedZeros, AllowReciprocal,
  // AllowContract, and ApproxFunc.
  flags.setFast();
  flags.setNoNaNs(!options.xla_cpu_fast_math_honor_nans());
  flags.setNoInfs(!options.xla_cpu_fast_math_honor_infs());
  flags.setAllowReciprocal(!options.xla_cpu_fast_math_honor_division());
  flags.setApproxFunc(!options.xla_cpu_fast_math_honor_functions());
  return flags;
}

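// Merges two per-instruction metadata maps, keyed by metadata kind:
// !alias.scope lists are unioned and !noalias lists are intersected, which is
// the conservative combination for aliasing metadata. Other metadata kinds
// are currently dropped (see the comment below).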
std::map<int, llvm::MDNode*> MergeMetadata(
    llvm::LLVMContext* context, const std::map<int, llvm::MDNode*>& a,
    const std::map<int, llvm::MDNode*>& b) {
  // We should extend this as needed to deal with other kinds of metadata like
  // !dereferenceable and !range.

  std::map<int, llvm::MDNode*> result;
  for (auto kind_md_pair : a) {
    if (kind_md_pair.first == llvm::LLVMContext::MD_alias_scope) {
      llvm::SmallVector<llvm::Metadata*, 8> union_of_scopes;
      llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set;
      for (const auto& scope_a : kind_md_pair.second->operands()) {
        scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get()));
        union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_a.get()));
      }
      auto it = b.find(kind_md_pair.first);
      if (it != b.end()) {
        for (const auto& scope_b : it->second->operands()) {
          if (!scope_set.count(llvm::cast<llvm::MDNode>(scope_b.get()))) {
            union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b.get()));
          }
        }
      }
      result[llvm::LLVMContext::MD_alias_scope] =
          llvm::MDNode::get(*context, union_of_scopes);
    } else if (kind_md_pair.first == llvm::LLVMContext::MD_noalias) {
      llvm::SmallVector<llvm::Metadata*, 8> intersection_of_scopes;
      llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set;
      for (const auto& scope_a : kind_md_pair.second->operands()) {
        scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get()));
      }
      auto it = b.find(kind_md_pair.first);
      if (it != b.end()) {
        for (const auto& scope_b : it->second->operands()) {
          if (scope_set.count(llvm::cast<llvm::MDNode>(scope_b))) {
            intersection_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b));
          }
        }
      }
      if (!intersection_of_scopes.empty()) {
        result[llvm::LLVMContext::MD_noalias] =
            llvm::MDNode::get(*context, intersection_of_scopes);
      }
    }
  }
  return result;
}

static Status CreateAndWriteStringToFile(const string& directory_name,
                                         const string& file_name,
                                         const string& text) {
  std::unique_ptr<tensorflow::WritableFile> f;
  TF_RETURN_IF_ERROR(
      tensorflow::Env::Default()->RecursivelyCreateDir(directory_name));
  TF_RETURN_IF_ERROR(
      tensorflow::Env::Default()->NewWritableFile(file_name, &f));
  TF_RETURN_IF_ERROR(f->Append(text));
  TF_RETURN_IF_ERROR(f->Close());
  return Status::OK();
}

void DumpIrIfEnabled(const HloModule& hlo_module,
                     const llvm::Module& llvm_module, bool optimized,
                     absl::string_view filename_suffix) {
  const auto& debug_opts = hlo_module.config().debug_options();
  if (!DumpingEnabledForHloModule(hlo_module)) {
    return;
  }
  // We can end up compiling different modules with the same name when using
  // XlaJitCompiledCpuFunction::Compile. Avoid overwriting IR files previously
  // dumped from the same process in such cases.
  string suffix = absl::StrCat("ir-", optimized ? "with" : "no", "-opt");
  DumpToFileInDirOrStdout(
      hlo_module, "",
      absl::StrCat(suffix, filename_suffix.empty() ? "" : ".", filename_suffix,
                   ".ll"),
      DumpModuleToString(llvm_module));

  // For some models the embedded constants can be huge, so also dump the
  // module with the constants stripped to get IR that is easier to
  // manipulate. Skip this if we're dumping to stdout; there's no point in
  // duplicating everything when writing to the terminal.
  if (!DumpingToStdout(debug_opts)) {
    DumpToFileInDir(hlo_module, "", absl::StrCat(suffix, "-noconst.ll"),
                    DumpModuleToString(*DropConstantInitializers(llvm_module)));
  }
}

llvm::Function* CreateCpuFunction(llvm::FunctionType* function_type,
                                  llvm::GlobalValue::LinkageTypes linkage,
                                  const HloModuleConfig& module_config,
                                  absl::string_view name,
                                  llvm::Module* module) {
  llvm::Function* function =
      llvm::Function::Create(function_type, linkage, AsStringRef(name), module);
  function->setCallingConv(llvm::CallingConv::C);
  function->addFnAttr("no-frame-pointer-elim", "false");

  // Generate unwind information so that GDB can crawl through the stack frames
  // created by the JIT compiled code.
  function->setHasUWTable();

  // TensorFlow always flushes denormals to zero, so let LLVM know that
  // flushing denormals is safe. This allows vectorization using ARM's NEON
  // instruction set.
  function->addFnAttr("denormal-fp-math", "preserve-sign");

  // Add the optimize attribute to the function if optimizing for size. This
  // controls internal behavior of some optimization passes (e.g. loop
  // unrolling).
  if (cpu::options::OptimizeForSizeRequested(module_config)) {
    function->addFnAttr(llvm::Attribute::OptimizeForSize);
  }

  return function;
}

void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) {
  auto options = config.debug_options().xla_backend_extra_options();
  if (!options.empty()) {
    std::vector<string> fake_argv_storage;
    fake_argv_storage.push_back("");
    for (const auto& it : options) {
      // Skip options the XLA backend itself consumes.
      if (!absl::StartsWith(it.first, "xla_")) {
        if (it.second.empty()) {
          fake_argv_storage.push_back(it.first);
        } else {
          fake_argv_storage.push_back(it.first + "=" + it.second);
        }
      }
    }

    VLOG(2) << "Passing argv to LLVM:";
    std::vector<const char*> fake_argv;
    for (const auto& s : fake_argv_storage) {
      fake_argv.push_back(s.c_str());
      VLOG(2) << s;
    }
    llvm::cl::ParseCommandLineOptions(fake_argv.size(), &fake_argv[0]);
  }
}

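// Multiplies two unsigned 32-bit values in 64-bit arithmetic and returns the
// {low, high} 32-bit halves of the product.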
std::pair<llvm::Value*, llvm::Value*> UMulLowHigh32(llvm::IRBuilder<>* b,
                                                    llvm::Value* src0,
                                                    llvm::Value* src1) {
  CHECK_EQ(src0->getType()->getPrimitiveSizeInBits(), 32);
  CHECK_EQ(src1->getType()->getPrimitiveSizeInBits(), 32);
  llvm::Type* int64_ty = b->getInt64Ty();
  src0 = b->CreateZExt(src0, int64_ty);
  src1 = b->CreateZExt(src1, int64_ty);
  return SplitInt64ToInt32s(b, b->CreateMul(src0, src1));
}

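// Splits a 64-bit value into its {low, high} 32-bit halves.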
std::pair<llvm::Value*, llvm::Value*> SplitInt64ToInt32s(
    llvm::IRBuilder<>* b, llvm::Value* value_64bits) {
  CHECK_EQ(value_64bits->getType()->getPrimitiveSizeInBits(), 64);
  llvm::Type* int32_ty = b->getInt32Ty();
  llvm::Value* low_32bits = b->CreateTrunc(value_64bits, int32_ty);
  llvm::Value* high_32bits =
      b->CreateTrunc(b->CreateLShr(value_64bits, 32), int32_ty);
  return std::make_pair(low_32bits, high_32bits);
}

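// Returns the LLVM address space used for device global memory by the
// module's target: 1 for AMDGPU, 0 otherwise.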
unsigned GetGlobalMemoryAddressSpace(const llvm::Module& module) {
  const unsigned kAMDGPUGlobalMemoryAddrSpace = 1;
  llvm::Triple target_triple = llvm::Triple(module.getTargetTriple());
  if (target_triple.getArch() == llvm::Triple::amdgcn) {
    // AMDGPU uses 1 for global memory address space.
    return kAMDGPUGlobalMemoryAddrSpace;
  }
  return 0;
}

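// Returns the module-level i128 global that holds the RNG state, creating it
// with a fixed non-zero initializer (and the target's global-memory address
// space) if it does not exist yet.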
llvm::GlobalVariable* GetOrCreateVariableForRngState(llvm::Module* module,
                                                     llvm::IRBuilder<>* b) {
  static const char* kRngStateVariableName = "rng_state";
  llvm::GlobalVariable* state_ptr =
      module->getNamedGlobal(kRngStateVariableName);
  if (!state_ptr) {
    unsigned global_address_space = GetGlobalMemoryAddressSpace(*module);
    llvm::Type* state_type = b->getInt128Ty();
    // Use a non-zero initial value, as a zero state can cause the result of
    // the first random number generation to fail the chi-square test. The
    // value used here is arbitrarily chosen; any non-zero value should be
    // fine.
    state_ptr = new llvm::GlobalVariable(
        /*M=*/*module,
        /*Ty=*/state_type,
        /*isConstant=*/false,
        /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
        /*Initializer=*/llvm::ConstantInt::get(b->getInt128Ty(), 0x7012395ull),
        /*Name=*/kRngStateVariableName,
        /*InsertBefore=*/nullptr,
        /*TLMode=*/llvm::GlobalValue::NotThreadLocal,
        /*AddressSpace=*/global_address_space,
        /*isExternallyInitialized=*/false);
  }
  return state_ptr;
}

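// Emits code that loads the current RNG state, stores state + delta back, and
// returns the value that was loaded.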
llvm::Value* RngGetAndUpdateState(uint64 delta, llvm::Module* module,
                                  llvm::IRBuilder<>* builder) {
  llvm::GlobalVariable* state_ptr =
      GetOrCreateVariableForRngState(module, builder);
  llvm::LoadInst* state_value_old =
      builder->CreateLoad(state_ptr, "load_state");
  llvm::Value* state_value_new = builder->CreateAdd(
      state_value_old,
      llvm::ConstantInt::get(state_value_old->getType(), delta));
  builder->CreateStore(state_value_new, state_ptr);
  return state_value_old;
}

}  // namespace llvm_ir
}  // namespace xla