/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"

#include <algorithm>
#include <limits>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/casts.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Operator.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace llvm_ir {

namespace {

// Note, this function is only useful in an insertion context; in a global
// (e.g. constants) context it will CHECK fail.
llvm::Module* ModuleFromIRBuilder(llvm::IRBuilder<>* b) {
  auto block = CHECK_NOTNULL(b->GetInsertBlock());
  auto fn = CHECK_NOTNULL(block->getParent());
  auto module = CHECK_NOTNULL(fn->getParent());
  return module;
}

}  // namespace

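// Returns a clone of `module` in which every global variable has had its
// initializer dropped and its linkage switched to external, so the
// (potentially huge) constant payloads are not part of the result.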
std::unique_ptr<llvm::Module> DropConstantInitializers(
    const llvm::Module& module) {
  std::unique_ptr<llvm::Module> cloned_module = CloneModule(module);
  for (llvm::GlobalVariable& global_var : cloned_module->globals()) {
    global_var.setInitializer(nullptr);
    global_var.setLinkage(llvm::GlobalValue::LinkageTypes::ExternalLinkage);
  }
  return cloned_module;
}

string DumpModuleToString(const llvm::Module& module) {
  std::string buffer_string;
  llvm::raw_string_ostream ostream(buffer_string);
  module.print(ostream, nullptr);
  ostream.flush();
  return buffer_string;
}

llvm::CallInst* EmitCallToIntrinsic(
    llvm::Intrinsic::ID intrinsic_id, absl::Span<llvm::Value* const> operands,
    absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b,
    absl::string_view name) {
  llvm::Module* module = ModuleFromIRBuilder(b);
  llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(
      module, intrinsic_id, AsArrayRef(overloaded_types));
  return b->CreateCall(intrinsic, AsArrayRef(operands), name.data());
}

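// Maximum/minimum with explicit NaN propagation. When fast math promises
// there are no NaNs (or enable_fast_min_max is set), a single unordered
// compare + select suffices. Otherwise we select the LHS when it compares
// ordered-greater-or-equal (ordered-less-or-equal for min) or when it is
// itself NaN; if only the RHS is NaN, the ordered compare fails and the NaN
// RHS is selected, so a NaN in either operand propagates to the result.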
llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
                          llvm::IRBuilder<>* b, bool enable_fast_min_max,
                          absl::string_view name) {
  if (b->getFastMathFlags().noNaNs() || enable_fast_min_max) {
    auto cmp = b->CreateFCmpUGE(lhs_value, rhs_value);
    return b->CreateSelect(cmp, lhs_value, rhs_value, name.data());
  } else {
    auto cmp_ge = b->CreateFCmpOGE(lhs_value, rhs_value);
    auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
    auto sel_lhs = b->CreateOr(cmp_ge, lhs_is_nan);
    return b->CreateSelect(sel_lhs, lhs_value, rhs_value, name.data());
  }
}

llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
                          llvm::IRBuilder<>* b, bool enable_fast_min_max,
                          absl::string_view name) {
  if (b->getFastMathFlags().noNaNs() || enable_fast_min_max) {
    auto cmp = b->CreateFCmpULE(lhs_value, rhs_value);
    return b->CreateSelect(cmp, lhs_value, rhs_value, name.data());
  } else {
    auto cmp_le = b->CreateFCmpOLE(lhs_value, rhs_value);
    auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
    auto sel_lhs = b->CreateOr(cmp_le, lhs_is_nan);
    return b->CreateSelect(sel_lhs, lhs_value, rhs_value, name.data());
  }
}

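// Emits an inbounds GEP into `array`. A GlobalVariable is a pointer to the
// whole array object, so it needs a leading zero index to step through that
// pointer before indexing into the array; plain buffer pointers are indexed
// directly.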
llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Value* index,
                                   llvm::IRBuilder<>* b) {
  llvm::Type* array_type = array->getType();
  CHECK(array_type->isPointerTy());
  llvm::PointerType* array_type_as_pointer =
      llvm::cast<llvm::PointerType>(array_type);
  VLOG(2) << "EmitBufferIndexingGEP with type="
          << llvm_ir::DumpToString(*array_type)
          << " array=" << llvm_ir::DumpToString(*array)
          << " index=" << llvm_ir::DumpToString(*index);

  return b->CreateInBoundsGEP(
      array_type_as_pointer->getElementType(), array,
      llvm::isa<llvm::GlobalVariable>(array)
          ? llvm::ArrayRef<llvm::Value*>({b->getInt64(0), index})
          : index);
}

llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64_t index,
                                   llvm::IRBuilder<>* b) {
  return EmitBufferIndexingGEP(array, b->getInt64(index), b);
}

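// Maps an XLA primitive type to the LLVM type used to store a single element
// of that type in memory; the mapping says nothing about how arithmetic on
// the type is emitted (see the BF16 note below).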
llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
                                  llvm::Module* module) {
  switch (element_type) {
    case PRED:
    case S8:
    case U8:
      return llvm::Type::getInt8Ty(module->getContext());
    case S16:
    case U16:
    case BF16:
      // For BF16 we just need some type that is 16 bits wide so that it will
      // take up the right amount of space in memory. LLVM does not have a
      // BF16 type (the LLVM half type is IEEE 16 bit floating point, not
      // bfloat), so we can't map it directly to an LLVM type. We will not map
      // a BF16 addition to an addition on this type (int16) - this is just
      // the type used for storage.
      return llvm::Type::getInt16Ty(module->getContext());
    case F16:
      return llvm::Type::getHalfTy(module->getContext());
    case S32:
    case U32:
      return llvm::Type::getInt32Ty(module->getContext());
    case S64:
    case U64:
      return llvm::Type::getInt64Ty(module->getContext());
    case F32:
      return llvm::Type::getFloatTy(module->getContext());
    case F64:
      return llvm::Type::getDoubleTy(module->getContext());
    case C64: {
      auto cplx_t =
          llvm::StructType::getTypeByName(module->getContext(), "complex64");
      if (cplx_t == nullptr) {
        // The C++ standard dictates that the memory layout of std::complex is
        // contiguous real followed by imaginary. C++11 section 26.4
        // [complex.numbers]:
        // If z is an lvalue expression of type cv std::complex<T> then the
        // expression reinterpret_cast<cv T(&)[2]>(z) shall be well-formed,
        // reinterpret_cast<cv T(&)[2]>(z)[0] shall designate the real part of
        // z, and reinterpret_cast<cv T(&)[2]>(z)[1] shall designate the
        // imaginary part of z.
        return llvm::StructType::create(
            {llvm::Type::getFloatTy(module->getContext()),
             llvm::Type::getFloatTy(module->getContext())},
            "complex64", /*isPacked=*/true);
      }
      return cplx_t;
    }
    case C128: {
      auto cplx_t =
          llvm::StructType::getTypeByName(module->getContext(), "complex128");
      if (cplx_t == nullptr) {
        return llvm::StructType::create(
            {llvm::Type::getDoubleTy(module->getContext()),
             llvm::Type::getDoubleTy(module->getContext())},
            "complex128", /*isPacked=*/true);
      }
      return cplx_t;
    }
    // A Tuple contains an array of pointers. Use i8*.
    case TUPLE:
    // An Opaque is like a void*, use i8*.
    case OPAQUE_TYPE:
      return llvm::Type::getInt8PtrTy(module->getContext());
    case TOKEN:
      // Tokens do not have a physical representation, but the compiler needs
      // some placeholder type, so use int8*.
      return llvm::Type::getInt8PtrTy(module->getContext());
    default:
      LOG(FATAL) << "unsupported type " << element_type;
  }
}

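// Returns the size of `type` in bits. Only primitive (sized) types and packed
// structs are supported; an unpacked struct may contain padding that a plain
// sum of its element sizes would not account for.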
int GetSizeInBits(llvm::Type* type) {
  const llvm::StructType* struct_ty = llvm::dyn_cast<llvm::StructType>(type);
  if (struct_ty) {
    CHECK(struct_ty->isPacked());
    int bits = 0;
    for (auto element_type : struct_ty->elements()) {
      bits += GetSizeInBits(element_type);
    }
    return bits;
  }
  int bits = type->getPrimitiveSizeInBits();
  CHECK_GT(bits, 0) << "type is not sized";
  return bits;
}

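// Builds the LLVM type for an XLA shape: a tuple becomes an array of pointers
// (one slot per tuple element), and an array shape becomes nested
// llvm::ArrayTypes, wrapping dimensions from minor to major so the most minor
// dimension ends up innermost.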
llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) {
  llvm::Type* result_type = PrimitiveTypeToIrType(shape.element_type(), module);
  if (shape.IsTuple()) {
    // A tuple buffer is an array of pointers.
    result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes_size());
  } else if (shape.IsArray()) {
    for (int64_t dimension : LayoutUtil::MinorToMajor(shape)) {
      result_type =
          llvm::ArrayType::get(result_type, shape.dimensions(dimension));
    }
  }
  return result_type;
}

StatusOr<llvm::Value*> EncodeSelfDescribingShapeConstant(const Shape& shape,
                                                         int32* shape_size,
                                                         llvm::IRBuilder<>* b) {
  string encoded_shape = shape.SerializeAsString();
  if (encoded_shape.size() > std::numeric_limits<int32>::max()) {
    return InternalError("Encoded shape size exceeded int32 size limit.");
  }
  *shape_size = static_cast<int32>(encoded_shape.size());
  return b->CreateGlobalStringPtr(encoded_shape);
}

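// Emits the literal as a raw byte string in the IR. This is only valid
// because the host that owns the literal bytes and the LLVM target share the
// same endianness, which the CHECK below asserts.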
llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
                                           llvm::Module* module) {
  const char* data = static_cast<const char*>(literal.untyped_data());
  CHECK_EQ(module->getDataLayout().isLittleEndian(),
           tensorflow::port::kLittleEndian);
  return llvm::ConstantDataArray::getString(
      module->getContext(), llvm::StringRef(data, literal.size_bytes()),
      /*AddNull=*/false);
}

llvm::GlobalVariable* AllocateSharedMemoryTile(llvm::Module* module,
                                               llvm::Type* tile_type,
                                               absl::string_view name) {
  // Both AMDGPU and NVPTX use the same address space for shared memory.
  const int kGPUSharedMemoryAddrSpace = 3;
  return new llvm::GlobalVariable(
      *module, tile_type,
      /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage,
      llvm::UndefValue::get(tile_type), AsStringRef(name), nullptr,
      llvm::GlobalValue::NotThreadLocal, kGPUSharedMemoryAddrSpace);
}

llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,
                                            absl::string_view name,
                                            llvm::IRBuilder<>* b,
                                            int alignment) {
  return EmitAllocaAtFunctionEntryWithCount(type, nullptr, name, b, alignment);
}

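// Allocas are emitted in the function's entry block so that LLVM's
// mem2reg/SROA passes can promote them to registers; the InsertPointGuard
// restores the caller's insertion point once the alloca has been created.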
llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type,
                                                     llvm::Value* element_count,
                                                     absl::string_view name,
                                                     llvm::IRBuilder<>* b,
                                                     int alignment) {
  llvm::IRBuilder<>::InsertPointGuard guard(*b);
  llvm::Function* function = b->GetInsertBlock()->getParent();
  b->SetInsertPoint(&function->getEntryBlock(),
                    function->getEntryBlock().getFirstInsertionPt());
  llvm::AllocaInst* alloca =
      b->CreateAlloca(type, element_count, AsStringRef(name));
  if (alignment != 0) {
    alloca->setAlignment(llvm::Align(alignment));
  }
  return alloca;
}

llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before,
                                   absl::string_view name,
                                   llvm::IRBuilder<>* b) {
  return llvm::BasicBlock::Create(
      /*Context=*/b->getContext(),
      /*Name=*/AsStringRef(name),
      /*Parent=*/b->GetInsertBlock()->getParent(),
      /*InsertBefore*/ insert_before);
}

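// Carves the current block into an if/then(/else) diamond:
//
//   if_block:    ...; cond-br true_block, (false_block | after_block)
//   true_block:  br after_block
//   false_block: br after_block        (only when emit_else is set)
//   after_block: <builder left at its first insertion point>
//
// Callers typically reposition the builder into true_block/false_block, just
// before the unconditional branches emitted here, to fill in the bodies.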
LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name,
                          llvm::IRBuilder<>* b, bool emit_else) {
  llvm_ir::LlvmIfData if_data;
  if_data.if_block = b->GetInsertBlock();
  if_data.true_block =
      CreateBasicBlock(nullptr, absl::StrCat(name, "-true"), b);
  if_data.false_block =
      emit_else ? CreateBasicBlock(nullptr, absl::StrCat(name, "-false"), b)
                : nullptr;

  // Add a terminator to the if block, if necessary.
  if (if_data.if_block->getTerminator() == nullptr) {
    b->SetInsertPoint(if_data.if_block);
    if_data.after_block =
        CreateBasicBlock(nullptr, absl::StrCat(name, "-after"), b);
    b->CreateBr(if_data.after_block);
  } else {
    if_data.after_block = if_data.if_block->splitBasicBlock(
        b->GetInsertPoint(), absl::StrCat(name, "-after"));
  }

  // Our basic block should now end with an unconditional branch. Remove it;
  // we're going to replace it with a conditional branch.
  if_data.if_block->getTerminator()->eraseFromParent();

  b->SetInsertPoint(if_data.if_block);
  b->CreateCondBr(condition, if_data.true_block,
                  emit_else ? if_data.false_block : if_data.after_block);

  b->SetInsertPoint(if_data.true_block);
  b->CreateBr(if_data.after_block);

  if (emit_else) {
    b->SetInsertPoint(if_data.false_block);
    b->CreateBr(if_data.after_block);
  }

  b->SetInsertPoint(if_data.after_block,
                    if_data.after_block->getFirstInsertionPt());

  return if_data;
}

llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
                            llvm::Value* lhs_value, llvm::Value* rhs_value,
                            llvm::IRBuilder<>* b, absl::string_view name) {
  llvm::Value* comparison_result;
  if (lhs_value->getType()->isIntegerTy()) {
    comparison_result =
        b->CreateICmp(predicate, lhs_value, rhs_value, name.data());
  } else {
    comparison_result =
        b->CreateFCmp(predicate, lhs_value, rhs_value, name.data());
  }
  // comparison_result is i1, but the NVPTX codegen incorrectly lowers i1
  // arrays. So we extend it to i8 so that it's addressable.
  return b->CreateZExt(comparison_result, llvm_ir::PrimitiveTypeToIrType(
                                              PRED, ModuleFromIRBuilder(b)));
}

// Internal helper that is called from emitted code to log an int64 value with
// a tag.
static void LogS64(const char* tag, int64_t value) {
  LOG(INFO) << tag << " (int64): " << value;
}

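// EmitLogging bakes the host addresses of LogS64 and of the tag string into
// the generated IR and calls through a raw function pointer, so it is only
// meaningful for JIT-compiled code running in this same process.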
void EmitLogging(const char* tag, llvm::Value* value, llvm::IRBuilder<>* b) {
  llvm::FunctionType* log_function_type = llvm::FunctionType::get(
      b->getVoidTy(), {b->getInt64Ty(), b->getInt64Ty()}, /*isVarArg=*/false);
  b->CreateCall(log_function_type,
                b->CreateIntToPtr(b->getInt64(absl::bit_cast<int64>(&LogS64)),
                                  log_function_type->getPointerTo()),
                {b->getInt64(absl::bit_cast<int64>(tag)), value});
}

void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment) {
  llvm::LLVMContext& context = load->getContext();
  llvm::Type* int64_ty = llvm::Type::getInt64Ty(context);
  llvm::Constant* alignment_constant =
      llvm::ConstantInt::get(int64_ty, alignment);
  llvm::MDBuilder metadata_builder(context);
  auto* alignment_metadata =
      metadata_builder.createConstant(alignment_constant);
  load->setMetadata(llvm::LLVMContext::MD_align,
                    llvm::MDNode::get(context, alignment_metadata));
}

void SetDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                       uint64_t dereferenceable_bytes) {
  llvm::LLVMContext& context = load->getContext();
  llvm::Type* int64_ty = llvm::Type::getInt64Ty(context);
  llvm::Constant* dereferenceable_bytes_constant =
      llvm::ConstantInt::get(int64_ty, dereferenceable_bytes);
  llvm::MDBuilder metadata_builder(context);
  auto* dereferenceable_bytes_metadata =
      metadata_builder.createConstant(dereferenceable_bytes_constant);
  load->setMetadata(llvm::LLVMContext::MD_dereferenceable,
                    llvm::MDNode::get(context, dereferenceable_bytes_metadata));
}

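// Attaches !range metadata to `inst`, telling LLVM the value lies in the
// half-open interval [lower, upper). The bounds are encoded as i32 constants,
// so this is intended for annotating 32-bit values.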
llvm::Instruction* AddRangeMetadata(int64_t lower, int64_t upper,
                                    llvm::Instruction* inst) {
  llvm::LLVMContext& context = inst->getParent()->getContext();
  llvm::IntegerType* i32 = llvm::Type::getInt32Ty(context);
  inst->setMetadata(
      llvm::LLVMContext::MD_range,
      llvm::MDNode::get(
          context,
          {llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, lower)),
           llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, upper))}));
  return inst;
}

string IrName(absl::string_view a) {
  std::string s(a);
  s.erase(std::remove(s.begin(), s.end(), '%'), s.end());
  return s;
}

string IrName(absl::string_view a, absl::string_view b) {
  if (!a.empty() && !b.empty()) {
    return IrName(absl::StrCat(a, ".", b));
  }
  return IrName(absl::StrCat(a, b));
}

string IrName(const HloInstruction* a, absl::string_view b) {
  return IrName(a->name(), b);
}

string SanitizeFunctionName(string function_name) {
  // The backend with the strictest requirements on function names is NVPTX, so
  // we sanitize to its requirements.
  //
  // A slightly stricter version of the NVPTX requirements is that names match
  // /[a-zA-Z_$][a-zA-Z0-9_$]*/, with the exception that the names "_" and "$"
  // are illegal.

  // Sanitize chars in function_name.
  std::transform(function_name.begin(), function_name.end(),
                 function_name.begin(), [](char c) {
                   if (('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') ||
                       ('0' <= c && c <= '9') || c == '_' || c == '$') {
                     return c;
                   }
                   return '_';
                 });

  // Ensure the name isn't empty.
  if (function_name.empty()) {
    function_name = "__unnamed";
  }

  // Ensure the name doesn't start with a number.
  if (!function_name.empty() && function_name[0] >= '0' &&
      function_name[0] <= '9') {
    function_name.insert(function_name.begin(), '_');
  }

  // Ensure the name isn't "_" or "$".
  if (function_name == "_" || function_name == "$") {
    function_name += '_';
  }

  return function_name;
}

void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
  builder->SetInsertPoint(blk, blk->getFirstInsertionPt());
}

void SetToLastInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
  if (llvm::Instruction* terminator = blk->getTerminator()) {
    builder->SetInsertPoint(terminator);
  } else {
    builder->SetInsertPoint(blk);
  }
}

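// Rotate-right by `rotor`:
//   (rotand << ((size - rotor) % size)) | (rotand >> (rotor % size))
// The URems keep both shift amounts strictly below the bit width (a shift by
// the full width is poison in LLVM IR), and a rotation by zero correctly
// degenerates to the identity.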
llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor,
                       llvm::IRBuilder<>* builder) {
  auto size = rotand->getType()->getPrimitiveSizeInBits();
  auto size_value = builder->getIntN(size, size);
  auto mod = [=](llvm::Value* x) { return builder->CreateURem(x, size_value); };
  return builder->CreateOr(
      builder->CreateShl(rotand, mod(builder->CreateSub(size_value, rotor))),
      builder->CreateLShr(rotand, mod(rotor)));
}

int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout) {
  unsigned pointer_size = data_layout.getPointerSize();
  return ShapeUtil::ByteSizeOf(shape, pointer_size);
}

llvm::FastMathFlags GetCpuFastMathFlags(const HloModuleConfig& module_config) {
  llvm::FastMathFlags flags;
  const auto& options = module_config.debug_options();
  if (!options.xla_cpu_enable_fast_math()) {
    return flags;
  }
  // Fast implies AllowReassoc, NoInfs, NoNaNs, NoSignedZeros, AllowReciprocal,
  // AllowContract, and ApproxFunc.
  flags.setFast();
  flags.setNoNaNs(!options.xla_cpu_fast_math_honor_nans());
  flags.setNoInfs(!options.xla_cpu_fast_math_honor_infs());
  flags.setAllowReciprocal(!options.xla_cpu_fast_math_honor_division());
  flags.setApproxFunc(!options.xla_cpu_fast_math_honor_functions());
  return flags;
}

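// Merges per-kind metadata from two instructions. !alias.scope lists are
// unioned (the merged access may belong to any scope either instruction
// belonged to), while !noalias lists are intersected (we may only keep the
// no-alias scopes both instructions agreed on).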
std::map<int, llvm::MDNode*> MergeMetadata(
    llvm::LLVMContext* context, const std::map<int, llvm::MDNode*>& a,
    const std::map<int, llvm::MDNode*>& b) {
  // We should extend this as needed to deal with other kinds of metadata like
  // !dereferenceable and !range.

  std::map<int, llvm::MDNode*> result;
  for (auto kind_md_pair : a) {
    if (kind_md_pair.first == llvm::LLVMContext::MD_alias_scope) {
      llvm::SmallVector<llvm::Metadata*, 8> union_of_scopes;
      llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set;
      for (const auto& scope_a : kind_md_pair.second->operands()) {
        scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get()));
        union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_a.get()));
      }
      auto it = b.find(kind_md_pair.first);
      if (it != b.end()) {
        for (const auto& scope_b : it->second->operands()) {
          if (!scope_set.count(llvm::cast<llvm::MDNode>(scope_b.get()))) {
            union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b.get()));
          }
        }
      }
      result[llvm::LLVMContext::MD_alias_scope] =
          llvm::MDNode::get(*context, union_of_scopes);
    } else if (kind_md_pair.first == llvm::LLVMContext::MD_noalias) {
      llvm::SmallVector<llvm::Metadata*, 8> intersection_of_scopes;
      llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set;
      for (const auto& scope_a : kind_md_pair.second->operands()) {
        scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get()));
      }
      auto it = b.find(kind_md_pair.first);
      if (it != b.end()) {
        for (const auto& scope_b : it->second->operands()) {
          if (scope_set.count(llvm::cast<llvm::MDNode>(scope_b))) {
            intersection_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b));
          }
        }
      }
      if (!intersection_of_scopes.empty()) {
        result[llvm::LLVMContext::MD_noalias] =
            llvm::MDNode::get(*context, intersection_of_scopes);
      }
    }
  }
  return result;
}

static Status CreateAndWriteStringToFile(const string& directory_name,
                                         const string& file_name,
                                         const string& text) {
  std::unique_ptr<tensorflow::WritableFile> f;
  TF_RETURN_IF_ERROR(
      tensorflow::Env::Default()->RecursivelyCreateDir(directory_name));
  TF_RETURN_IF_ERROR(
      tensorflow::Env::Default()->NewWritableFile(file_name, &f));
  TF_RETURN_IF_ERROR(f->Append(text));
  TF_RETURN_IF_ERROR(f->Close());
  return Status::OK();
}

void DumpIrIfEnabled(const HloModule& hlo_module,
                     const llvm::Module& llvm_module, bool optimized,
                     absl::string_view filename_suffix) {
  const auto& debug_opts = hlo_module.config().debug_options();
  if (!DumpingEnabledForHloModule(hlo_module)) {
    return;
  }
  // We can end up compiling different modules with the same name when using
  // XlaJitCompiledCpuFunction::Compile. Avoid overwriting IR files previously
  // dumped from the same process in such cases.
  string suffix =
      absl::StrCat("ir-", optimized ? "with" : "no", "-opt",
                   filename_suffix.empty() ? "" : ".", filename_suffix);
  DumpToFileInDirOrStdout(hlo_module, "", absl::StrCat(suffix, ".ll"),
                          DumpModuleToString(llvm_module));

  // For some models the embedded constants can be huge, so also dump the
  // module with the constants stripped to get IR that is easier to
  // manipulate. Skip this if we're dumping to stdout; there's no point in
  // duplicating everything when writing to the terminal.
  if (!DumpingToStdout(debug_opts)) {
    DumpToFileInDir(hlo_module, "", absl::StrCat(suffix, "-noconst.ll"),
                    DumpModuleToString(*DropConstantInitializers(llvm_module)));
  }
}

void DumpIrIfEnabled(mlir::ModuleOp mlir_module, int unique_id,
                     const DebugOptions& debug_options) {
  absl::optional<absl::string_view> module_name;
  if (llvm::Optional<llvm::StringRef> mlir_module_name =
          mlir_module.getName()) {
    module_name = AsStringView(*mlir_module_name);
  }
  if (!DumpingEnabledForHloModule(module_name.value_or("<unnamed>"),
                                  debug_options)) {
    return;
  }

  DumpToFileInDirOrStdout(debug_options, unique_id, module_name.value_or(""),
                          /*file_prefix=*/"",
                          /*file_suffix=*/"lmhlo", DumpToString(mlir_module));
}

llvm::Function* CreateCpuFunction(llvm::FunctionType* function_type,
                                  llvm::GlobalValue::LinkageTypes linkage,
                                  const HloModuleConfig& module_config,
                                  absl::string_view name,
                                  llvm::Module* module) {
  llvm::Function* function =
      llvm::Function::Create(function_type, linkage, AsStringRef(name), module);
  function->setCallingConv(llvm::CallingConv::C);
  function->addFnAttr("no-frame-pointer-elim", "false");

  // Generate unwind information so that GDB can crawl through the stack frames
  // created by the JIT-compiled code.
  function->setHasUWTable();

  // TensorFlow always flushes denormals to zero, so let LLVM know that
  // flushing denormals is safe. This allows vectorization using ARM's NEON
  // instruction set.
  function->addFnAttr("denormal-fp-math", "preserve-sign");

  // Add the optimize-for-size attribute to the function if optimizing for
  // size. This controls the internal behavior of some optimization passes
  // (e.g. loop unrolling).
  if (cpu::options::OptimizeForSizeRequested(module_config)) {
    function->addFnAttr(llvm::Attribute::OptimizeForSize);
  }

  return function;
}

void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) {
  auto options = config.debug_options().xla_backend_extra_options();
  if (!options.empty()) {
    std::vector<string> fake_argv_storage;
    fake_argv_storage.push_back("");
    for (const auto& it : options) {
      // Skip options the XLA backend itself consumes.
      if (!absl::StartsWith(it.first, "xla_")) {
        if (it.second.empty()) {
          fake_argv_storage.push_back(it.first);
        } else {
          fake_argv_storage.push_back(it.first + "=" + it.second);
        }
      }
    }

    VLOG(2) << "Passing argv to LLVM:";
    std::vector<const char*> fake_argv;
    for (const auto& s : fake_argv_storage) {
      fake_argv.push_back(s.c_str());
      VLOG(2) << s;
    }
    llvm::cl::ParseCommandLineOptions(fake_argv.size(), &fake_argv[0]);
  }
}

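// Computes the full 64-bit product of two unsigned 32-bit values by
// zero-extending to i64, multiplying, and splitting the product back into its
// (low, high) 32-bit halves.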
std::pair<llvm::Value*, llvm::Value*> UMulLowHigh32(llvm::IRBuilder<>* b,
                                                    llvm::Value* src0,
                                                    llvm::Value* src1) {
  CHECK_EQ(src0->getType()->getPrimitiveSizeInBits(), 32);
  CHECK_EQ(src1->getType()->getPrimitiveSizeInBits(), 32);
  llvm::Type* int64_ty = b->getInt64Ty();
  src0 = b->CreateZExt(src0, int64_ty);
  src1 = b->CreateZExt(src1, int64_ty);
  return SplitInt64ToInt32s(b, b->CreateMul(src0, src1));
}

std::pair<llvm::Value*, llvm::Value*> SplitInt64ToInt32s(
    llvm::IRBuilder<>* b, llvm::Value* value_64bits) {
  CHECK_EQ(value_64bits->getType()->getPrimitiveSizeInBits(), 64);
  llvm::Type* int32_ty = b->getInt32Ty();
  llvm::Value* low_32bits = b->CreateTrunc(value_64bits, int32_ty);
  llvm::Value* high_32bits =
      b->CreateTrunc(b->CreateLShr(value_64bits, 32), int32_ty);
  return std::make_pair(low_32bits, high_32bits);
}

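// Address space 1 corresponds to global (device) memory on the GPU backends;
// both NVPTX and AMDGPU use addrspace(1) for global memory.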
unsigned GetGlobalMemoryAddressSpace() { return 1; }

llvm::GlobalVariable* GetOrCreateVariableForRngState(llvm::Module* module,
                                                     llvm::IRBuilder<>* b) {
  static const char* kRngStateVariableName = "rng_state";
  llvm::GlobalVariable* state_ptr =
      module->getNamedGlobal(kRngStateVariableName);
  if (!state_ptr) {
    llvm::Type* state_type = b->getInt128Ty();
    // Use a non-zero initial value, since an all-zero state can cause the
    // first generated random numbers to fail the chi-square test. The value
    // used here is arbitrary; any non-zero value should be fine.
    state_ptr = new llvm::GlobalVariable(
        /*M=*/*module,
        /*Ty=*/state_type,
        /*isConstant=*/false,
        /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
        /*Initializer=*/llvm::ConstantInt::get(b->getInt128Ty(), 0x7012395ull),
        /*Name=*/kRngStateVariableName,
        /*InsertBefore=*/nullptr,
        /*TLMode=*/llvm::GlobalValue::NotThreadLocal,
        /*AddressSpace=*/GetGlobalMemoryAddressSpace(),
        /*isExternallyInitialized=*/false);
  }
  return state_ptr;
}

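// Returns the RNG state value before the update and advances the stored state
// by `delta`, so successive calls hand out non-overlapping counter ranges.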
llvm::Value* RngGetAndUpdateState(uint64 delta, llvm::Module* module,
                                  llvm::IRBuilder<>* builder) {
  llvm::GlobalVariable* state_ptr =
      GetOrCreateVariableForRngState(module, builder);
  llvm::LoadInst* state_value_old =
      builder->CreateLoad(state_ptr, "load_state");
  llvm::Value* state_value_new = builder->CreateAdd(
      state_value_old,
      llvm::ConstantInt::get(state_value_old->getType(), delta));
  builder->CreateStore(state_value_new, state_ptr);
  return state_value_old;
}

}  // namespace llvm_ir
}  // namespace xla