/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"

#include <algorithm>
#include <memory>
#include <vector>

#include "absl/base/casts.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Operator.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace llvm_ir {

namespace {

// Note, this function is only useful in an insertion context; in a global
// (e.g. constants) context it will CHECK fail.
llvm::Module* ModuleFromIRBuilder(llvm::IRBuilder<>* b) {
  auto block = CHECK_NOTNULL(b->GetInsertBlock());
  auto fn = CHECK_NOTNULL(block->getParent());
  auto module = CHECK_NOTNULL(fn->getParent());
  return module;
}

}  // namespace

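// Returns a clone of `module` in which every global variable's initializer
// has been dropped and its linkage made external, turning definitions into
// declarations. DumpIrIfEnabled below uses this to dump IR without the
// potentially huge embedded constant payloads.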
std::unique_ptr<llvm::Module> DropConstantInitializers(
    const llvm::Module& module) {
  std::unique_ptr<llvm::Module> cloned_module = CloneModule(module);
  for (llvm::GlobalVariable& global_var : cloned_module->globals()) {
    global_var.setInitializer(nullptr);
    global_var.setLinkage(llvm::GlobalValue::LinkageTypes::ExternalLinkage);
  }
  return cloned_module;
}

string DumpModuleToString(const llvm::Module& module) {
  std::string buffer_string;
  llvm::raw_string_ostream ostream(buffer_string);
  module.print(ostream, nullptr);
  ostream.flush();
  return buffer_string;
}

llvm::CallInst* EmitCallToIntrinsic(
    llvm::Intrinsic::ID intrinsic_id, absl::Span<llvm::Value* const> operands,
    absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b) {
  llvm::Module* module = ModuleFromIRBuilder(b);
  llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(
      module, intrinsic_id, AsArrayRef(overloaded_types));
  return b->CreateCall(intrinsic, AsArrayRef(operands));
}

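// Emits a float max. With the no-NaNs fast-math flag a bare compare+select
// suffices; otherwise we select lhs whenever lhs >= rhs *or* lhs is NaN (the
// unordered self-comparison below is the standard NaN test), so a NaN on the
// lhs propagates through the max. EmitFloatMin below is the mirror image.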
llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
                          llvm::IRBuilder<>* b) {
  if (b->getFastMathFlags().noNaNs()) {
    auto cmp = b->CreateFCmpUGE(lhs_value, rhs_value);
    return b->CreateSelect(cmp, lhs_value, rhs_value);
  } else {
    auto cmp_ge = b->CreateFCmpOGE(lhs_value, rhs_value);
    auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
    auto sel_lhs = b->CreateOr(cmp_ge, lhs_is_nan);
    return b->CreateSelect(sel_lhs, lhs_value, rhs_value);
  }
}

llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
                          llvm::IRBuilder<>* b) {
  if (b->getFastMathFlags().noNaNs()) {
    auto cmp = b->CreateFCmpULE(lhs_value, rhs_value);
    return b->CreateSelect(cmp, lhs_value, rhs_value);
  } else {
    auto cmp_le = b->CreateFCmpOLE(lhs_value, rhs_value);
    auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
    auto sel_lhs = b->CreateOr(cmp_le, lhs_is_nan);
    return b->CreateSelect(sel_lhs, lhs_value, rhs_value);
  }
}

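// Emits a GEP computing the address of element `index` in the buffer `array`.
// A GlobalVariable's value is a pointer to its array-typed contents, so it
// needs an extra leading zero index to first step through that pointer.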
llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Value* index,
                                   llvm::IRBuilder<>* b) {
  llvm::Type* array_type = array->getType();
  CHECK(array_type->isPointerTy());
  llvm::PointerType* array_type_as_pointer =
      llvm::cast<llvm::PointerType>(array_type);
  VLOG(2) << "EmitBufferIndexingGEP with type="
          << llvm_ir::DumpToString(*array_type)
          << " array=" << llvm_ir::DumpToString(*array)
          << " index=" << llvm_ir::DumpToString(*index);

  return b->CreateInBoundsGEP(
      array_type_as_pointer->getElementType(), array,
      llvm::isa<llvm::GlobalVariable>(array)
          ? llvm::ArrayRef<llvm::Value*>({b->getInt64(0), index})
          : index);
}

llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index,
                                   llvm::IRBuilder<>* b) {
  return EmitBufferIndexingGEP(array, b->getInt64(index), b);
}

llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
                                  llvm::Module* module) {
  switch (element_type) {
    case PRED:
    case S8:
    case U8:
      return llvm::Type::getInt8Ty(module->getContext());
    case S16:
    case U16:
    case BF16:
      // For BF16 we just need some type that is 16 bits wide so that it will
      // take up the right amount of space in memory. LLVM does not have a BF16
      // type (the LLVM half type is IEEE 16 bit floating point, not bfloat), so
      // we can't map it directly to an LLVM type. We will not map a BF16
      // addition to an addition on this type (int16) - this is just the type
      // used for storage.
      return llvm::Type::getInt16Ty(module->getContext());
    case F16:
      return llvm::Type::getHalfTy(module->getContext());
    case S32:
    case U32:
      return llvm::Type::getInt32Ty(module->getContext());
    case S64:
    case U64:
      return llvm::Type::getInt64Ty(module->getContext());
    case F32:
      return llvm::Type::getFloatTy(module->getContext());
    case F64:
      return llvm::Type::getDoubleTy(module->getContext());
    case C64: {
      auto cplx_t = module->getTypeByName("complex64");
      if (cplx_t == nullptr) {
        // C++ standard dictates the memory layout of std::complex is contiguous
        // real followed by imaginary. C++11 section 26.4 [complex.numbers]:
        // If z is an lvalue expression of type cv std::complex<T> then the
        // expression reinterpret_cast<cv T(&)[2]>(z) shall be well-formed,
        // reinterpret_cast<cv T(&)[2]>(z)[0] shall designate the real part of
        // z, and reinterpret_cast<cv T(&)[2]>(z)[1] shall designate the
        // imaginary part of z.
        return llvm::StructType::create(
            {llvm::Type::getFloatTy(module->getContext()),
             llvm::Type::getFloatTy(module->getContext())},
            "complex64", /*isPacked=*/true);
      }
      return cplx_t;
    }
    case C128: {
      auto cplx_t = module->getTypeByName("complex128");
      if (cplx_t == nullptr) {
        return llvm::StructType::create(
            {llvm::Type::getDoubleTy(module->getContext()),
             llvm::Type::getDoubleTy(module->getContext())},
            "complex128", /*isPacked=*/true);
      }
      return cplx_t;
    }
    // A Tuple contains an array of pointers. Use i8*.
    case TUPLE:
    // An Opaque is like a void*, use i8*.
    case OPAQUE_TYPE:
      return llvm::Type::getInt8PtrTy(module->getContext());
    case TOKEN:
      // Tokens do not have a physical representation, but the compiler needs
      // some placeholder type, so use int8*.
      return llvm::Type::getInt8PtrTy(module->getContext());
    default:
      LOG(FATAL) << "unsupported type " << element_type;
  }
}

int GetSizeInBits(llvm::Type* type) {
  const llvm::StructType* struct_ty = llvm::dyn_cast<llvm::StructType>(type);
  if (struct_ty) {
    CHECK(struct_ty->isPacked());
    int bits = 0;
    for (auto element_type : struct_ty->elements()) {
      bits += GetSizeInBits(element_type);
    }
    return bits;
  }
  int bits = type->getPrimitiveSizeInBits();
  CHECK_GT(bits, 0) << "type is not sized";
  return bits;
}

llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) {
  llvm::Type* result_type = PrimitiveTypeToIrType(shape.element_type(), module);
  if (shape.IsTuple()) {
    // A tuple buffer is an array of pointers.
    result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes_size());
  } else if (shape.IsArray()) {
    for (int64 dimension : LayoutUtil::MinorToMajor(shape)) {
      result_type =
          llvm::ArrayType::get(result_type, shape.dimensions(dimension));
    }
  }
  return result_type;
}

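// Serializes `shape` into a global string constant and returns a pointer to
// it, reporting the byte length through `shape_size` so that emitted code can
// pass the (pointer, size) pair to code that deserializes the shape.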
StatusOr<llvm::Value*> EncodeSelfDescribingShapeConstant(const Shape& shape,
                                                         int32* shape_size,
                                                         llvm::IRBuilder<>* b) {
  string encoded_shape = shape.SerializeAsString();
  if (encoded_shape.size() > std::numeric_limits<int32>::max()) {
    return InternalError("Encoded shape size exceeded int32 size limit.");
  }
  *shape_size = static_cast<int32>(encoded_shape.size());
  return b->CreateGlobalStringPtr(encoded_shape);
}

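// Reinterprets the literal's backing bytes directly as an LLVM constant
// array. This is only valid when host and target byte order agree, which the
// CHECK below enforces.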
llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
                                           llvm::Module* module) {
  const char* data = static_cast<const char*>(literal.untyped_data());
  CHECK_EQ(module->getDataLayout().isLittleEndian(),
           tensorflow::port::kLittleEndian);
  return llvm::ConstantDataArray::getString(
      module->getContext(), llvm::StringRef(data, literal.size_bytes()),
      /*AddNull=*/false);
}

llvm::GlobalVariable* AllocateSharedMemoryTile(llvm::Module* module,
                                               llvm::Type* tile_type,
                                               absl::string_view name) {
  // Both AMDGPU and NVPTX use the same address space for shared memory.
  const int kGPUSharedMemoryAddrSpace = 3;
  return new llvm::GlobalVariable(
      *module, tile_type,
      /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage,
      llvm::UndefValue::get(tile_type), AsStringRef(name), nullptr,
      llvm::GlobalValue::NotThreadLocal, kGPUSharedMemoryAddrSpace);
}

llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,
                                            absl::string_view name,
                                            llvm::IRBuilder<>* b,
                                            int alignment) {
  return EmitAllocaAtFunctionEntryWithCount(type, nullptr, name, b, alignment);
}

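// Emits the alloca in the function's entry block, temporarily redirecting the
// builder there; LLVM's mem2reg/SROA only consider entry-block allocas, so an
// alloca emitted mid-function would otherwise never be promoted to registers.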
llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type,
                                                     llvm::Value* element_count,
                                                     absl::string_view name,
                                                     llvm::IRBuilder<>* b,
                                                     int alignment) {
  llvm::IRBuilder<>::InsertPointGuard guard(*b);
  llvm::Function* function = b->GetInsertBlock()->getParent();
  b->SetInsertPoint(&function->getEntryBlock(),
                    function->getEntryBlock().getFirstInsertionPt());
  llvm::AllocaInst* alloca =
      b->CreateAlloca(type, element_count, AsStringRef(name));
  if (alignment != 0) {
    alloca->setAlignment(llvm::MaybeAlign(alignment));
  }
  return alloca;
}

llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before,
                                   absl::string_view name,
                                   llvm::IRBuilder<>* b) {
  return llvm::BasicBlock::Create(
      /*Context=*/b->getContext(),
      /*Name=*/AsStringRef(name),
      /*Parent=*/b->GetInsertBlock()->getParent(),
      /*InsertBefore=*/insert_before);
}

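// Emits an if/then(/else) diamond: the block containing the current insertion
// point is split, its terminator replaced with a conditional branch, and the
// builder is left at the start of the after-block. Callers position the
// builder inside the returned true/false blocks to fill in the bodies.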
LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name,
                          llvm::IRBuilder<>* b, bool emit_else) {
  llvm_ir::LlvmIfData if_data;
  if_data.if_block = b->GetInsertBlock();
  if_data.true_block =
      CreateBasicBlock(nullptr, absl::StrCat(name, "-true"), b);
  if_data.false_block =
      emit_else ? CreateBasicBlock(nullptr, absl::StrCat(name, "-false"), b)
                : nullptr;

  // Add a terminator to the if block, if necessary.
  if (if_data.if_block->getTerminator() == nullptr) {
    b->SetInsertPoint(if_data.if_block);
    if_data.after_block =
        CreateBasicBlock(nullptr, absl::StrCat(name, "-after"), b);
    b->CreateBr(if_data.after_block);
  } else {
    if_data.after_block = if_data.if_block->splitBasicBlock(
        b->GetInsertPoint(), absl::StrCat(name, "-after"));
  }

  // Our basic block should now end with an unconditional branch.  Remove it;
  // we're going to replace it with a conditional branch.
  if_data.if_block->getTerminator()->eraseFromParent();

  b->SetInsertPoint(if_data.if_block);
  b->CreateCondBr(condition, if_data.true_block,
                  emit_else ? if_data.false_block : if_data.after_block);

  b->SetInsertPoint(if_data.true_block);
  b->CreateBr(if_data.after_block);

  if (emit_else) {
    b->SetInsertPoint(if_data.false_block);
    b->CreateBr(if_data.after_block);
  }

  b->SetInsertPoint(if_data.after_block,
                    if_data.after_block->getFirstInsertionPt());

  return if_data;
}

llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
                            llvm::Value* lhs_value, llvm::Value* rhs_value,
                            llvm::IRBuilder<>* b) {
  llvm::Value* comparison_result;
  if (lhs_value->getType()->isIntegerTy()) {
    comparison_result = b->CreateICmp(predicate, lhs_value, rhs_value);
  } else {
    comparison_result = b->CreateFCmp(predicate, lhs_value, rhs_value);
  }
  // comparison_result is i1, but the NVPTX codegen incorrectly lowers i1
  // arrays. So we extend it to i8 so that it's addressable.
  return b->CreateZExt(comparison_result, llvm_ir::PrimitiveTypeToIrType(
                                              PRED, ModuleFromIRBuilder(b)));
}

// Internal helper that is called from emitted code to log an int64 value with
// a tag.
static void LogS64(const char* tag, int64 value) {
  LOG(INFO) << tag << " (int64): " << value;
}

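// Emits a call to LogS64 above by baking the function's address into the IR
// as an i64 constant and casting it to a function pointer (the tag pointer is
// smuggled through an i64 the same way). This only works for JIT-compiled
// code running in the same process as the compiler.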
void EmitLogging(const char* tag, llvm::Value* value, llvm::IRBuilder<>* b) {
  llvm::FunctionType* log_function_type = llvm::FunctionType::get(
      b->getVoidTy(), {b->getInt64Ty(), b->getInt64Ty()}, /*isVarArg=*/false);
  b->CreateCall(log_function_type,
                b->CreateIntToPtr(b->getInt64(absl::bit_cast<int64>(&LogS64)),
                                  log_function_type->getPointerTo()),
                {b->getInt64(absl::bit_cast<int64>(tag)), value});
}

void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment) {
  llvm::LLVMContext& context = load->getContext();
  llvm::Type* int64_ty = llvm::Type::getInt64Ty(context);
  llvm::Constant* alignment_constant =
      llvm::ConstantInt::get(int64_ty, alignment);
  llvm::MDBuilder metadata_builder(context);
  auto* alignment_metadata =
      metadata_builder.createConstant(alignment_constant);
  load->setMetadata(llvm::LLVMContext::MD_align,
                    llvm::MDNode::get(context, alignment_metadata));
}

void SetDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                       uint64_t dereferenceable_bytes) {
  llvm::LLVMContext& context = load->getContext();
  llvm::Type* int64_ty = llvm::Type::getInt64Ty(context);
  llvm::Constant* dereferenceable_bytes_constant =
      llvm::ConstantInt::get(int64_ty, dereferenceable_bytes);
  llvm::MDBuilder metadata_builder(context);
  auto* dereferenceable_bytes_metadata =
      metadata_builder.createConstant(dereferenceable_bytes_constant);
  load->setMetadata(llvm::LLVMContext::MD_dereferenceable,
                    llvm::MDNode::get(context, dereferenceable_bytes_metadata));
}

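// Attaches !range metadata to `inst`, asserting that its value lies in the
// half-open interval [lower, upper).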
llvm::Instruction* AddRangeMetadata(int64 lower, int64 upper,
                                    llvm::Instruction* inst) {
  llvm::LLVMContext& context = inst->getParent()->getContext();
  llvm::IntegerType* i32 = llvm::Type::getInt32Ty(context);
  inst->setMetadata(
      llvm::LLVMContext::MD_range,
      llvm::MDNode::get(
          context,
          {llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, lower)),
           llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, upper))}));
  return inst;
}

string IrName(string a) {
  a.erase(std::remove(a.begin(), a.end(), '%'), a.end());
  return a;
}

string IrName(absl::string_view a, absl::string_view b) {
  if (!a.empty() && !b.empty()) {
    return IrName(absl::StrCat(a, ".", b));
  }
  return IrName(absl::StrCat(a, b));
}

string IrName(const HloInstruction* a, absl::string_view b) {
  return IrName(a->name(), b);
}

string SanitizeFunctionName(string function_name) {
  // The backend with the strictest requirements on function names is NVPTX, so
  // we sanitize to its requirements.
  //
  // A slightly stricter version of the NVPTX requirements is that names match
  // /[a-zA-Z_$][a-zA-Z0-9_$]*/, with the exception that the names "_" and "$"
  // are illegal.

  // Sanitize chars in function_name.
  std::transform(function_name.begin(), function_name.end(),
                 function_name.begin(), [](char c) {
                   if (('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') ||
                       ('0' <= c && c <= '9') || c == '_' || c == '$') {
                     return c;
                   }
                   return '_';
                 });

  // Ensure the name isn't empty.
  if (function_name.empty()) {
    function_name = "__unnamed";
  }

  // Ensure the name doesn't start with a number.
  if (!function_name.empty() && function_name[0] >= '0' &&
      function_name[0] <= '9') {
    function_name.insert(function_name.begin(), '_');
  }

  // Ensure the name isn't "_" or "$".
  if (function_name == "_" || function_name == "$") {
    function_name += '_';
  }

  return function_name;
}

void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
  builder->SetInsertPoint(blk, blk->getFirstInsertionPt());
}

void SetToLastInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
  if (llvm::Instruction* terminator = blk->getTerminator()) {
    builder->SetInsertPoint(terminator);
  } else {
    builder->SetInsertPoint(blk);
  }
}

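// Emits a rotate-right: (rotand >> (rotor % n)) | (rotand << ((n - rotor) % n))
// for an n-bit type. Reducing both shift amounts mod n keeps them below the
// bit width (shifts by >= n are poison in LLVM) and makes a rotation by zero
// a no-op.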
llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor,
                       llvm::IRBuilder<>* builder) {
  auto size = rotand->getType()->getPrimitiveSizeInBits();
  auto size_value = builder->getIntN(size, size);
  auto mod = [=](llvm::Value* x) { return builder->CreateURem(x, size_value); };
  return builder->CreateOr(
      builder->CreateShl(rotand, mod(builder->CreateSub(size_value, rotor))),
      builder->CreateLShr(rotand, mod(rotor)));
}

int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout) {
  unsigned pointer_size = data_layout.getPointerSize();
  return ShapeUtil::ByteSizeOf(shape, pointer_size);
}

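// Derives LLVM fast-math flags from the module's debug options:
// xla_cpu_enable_fast_math turns everything on, after which the individual
// "honor" options selectively switch NaN/inf/division/function-approximation
// relaxations back off.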
llvm::FastMathFlags GetCpuFastMathFlags(const HloModuleConfig& module_config) {
  llvm::FastMathFlags flags;
  const auto& options = module_config.debug_options();
  if (!options.xla_cpu_enable_fast_math()) {
    return flags;
  }
  // Fast implies AllowReassoc, NoInfs, NoNaNs, NoSignedZeros, AllowReciprocal,
  // AllowContract, and ApproxFunc.
  flags.setFast();
  flags.setNoNaNs(!options.xla_cpu_fast_math_honor_nans());
  flags.setNoInfs(!options.xla_cpu_fast_math_honor_infs());
  flags.setAllowReciprocal(!options.xla_cpu_fast_math_honor_division());
  flags.setApproxFunc(!options.xla_cpu_fast_math_honor_functions());
  return flags;
}

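// Merges the alias-analysis metadata of two instructions being combined.
// !alias.scope lists scopes an access belongs to, so the merged access gets
// the union; !noalias asserts non-aliasing, so only assertions common to both
// inputs (the intersection) remain sound.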
std::map<int, llvm::MDNode*> MergeMetadata(
    llvm::LLVMContext* context, const std::map<int, llvm::MDNode*>& a,
    const std::map<int, llvm::MDNode*>& b) {
  // We should extend this as needed to deal with other kinds of metadata like
  // !dereferenceable and !range.

  std::map<int, llvm::MDNode*> result;
  for (auto kind_md_pair : a) {
    if (kind_md_pair.first == llvm::LLVMContext::MD_alias_scope) {
      llvm::SmallVector<llvm::Metadata*, 8> union_of_scopes;
      llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set;
      for (const auto& scope_a : kind_md_pair.second->operands()) {
        scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get()));
        union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_a.get()));
      }
      auto it = b.find(kind_md_pair.first);
      if (it != b.end()) {
        for (const auto& scope_b : it->second->operands()) {
          if (!scope_set.count(llvm::cast<llvm::MDNode>(scope_b.get()))) {
            union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b.get()));
          }
        }
      }
      result[llvm::LLVMContext::MD_alias_scope] =
          llvm::MDNode::get(*context, union_of_scopes);
    } else if (kind_md_pair.first == llvm::LLVMContext::MD_noalias) {
      llvm::SmallVector<llvm::Metadata*, 8> intersection_of_scopes;
      llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set;
      for (const auto& scope_a : kind_md_pair.second->operands()) {
        scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get()));
      }
      auto it = b.find(kind_md_pair.first);
      if (it != b.end()) {
        for (const auto& scope_b : it->second->operands()) {
          if (scope_set.count(llvm::cast<llvm::MDNode>(scope_b))) {
            intersection_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b));
          }
        }
      }
      if (!intersection_of_scopes.empty()) {
        result[llvm::LLVMContext::MD_noalias] =
            llvm::MDNode::get(*context, intersection_of_scopes);
      }
    }
  }
  return result;
}

static Status CreateAndWriteStringToFile(const string& directory_name,
                                         const string& file_name,
                                         const string& text) {
  std::unique_ptr<tensorflow::WritableFile> f;
  TF_RETURN_IF_ERROR(
      tensorflow::Env::Default()->RecursivelyCreateDir(directory_name));
  TF_RETURN_IF_ERROR(
      tensorflow::Env::Default()->NewWritableFile(file_name, &f));
  TF_RETURN_IF_ERROR(f->Append(text));
  TF_RETURN_IF_ERROR(f->Close());
  return Status::OK();
}

void DumpIrIfEnabled(const HloModule& hlo_module,
                     const llvm::Module& llvm_module, bool optimized) {
  const auto& debug_opts = hlo_module.config().debug_options();
  if (!DumpingEnabledForHloModule(hlo_module)) {
    return;
  }
  // We can end up compiling different modules with the same name when using
  // XlaJitCompiledCpuFunction::Compile.  Avoid overwriting IR files previously
  // dumped from the same process in such cases.
  string suffix = absl::StrCat("ir-", optimized ? "with" : "no", "-opt");
  DumpToFileInDirOrStdout(hlo_module, "", absl::StrCat(suffix, ".ll"),
                          DumpModuleToString(llvm_module));

  // For some models the embedded constants can be huge, so also dump the module
  // with the constants stripped to get IR that is easier to manipulate.  Skip
  // this if we're dumping to stdout; there's no point in duplicating everything
  // when writing to the terminal.
  if (!DumpingToStdout(debug_opts)) {
    DumpToFileInDir(hlo_module, "", absl::StrCat(suffix, "-noconst.ll"),
                    DumpModuleToString(*DropConstantInitializers(llvm_module)));
  }
}

llvm::Function* CreateCpuFunction(llvm::FunctionType* function_type,
                                  llvm::GlobalValue::LinkageTypes linkage,
                                  const HloModuleConfig& module_config,
                                  absl::string_view name,
                                  llvm::Module* module) {
  llvm::Function* function =
      llvm::Function::Create(function_type, linkage, AsStringRef(name), module);
  function->setCallingConv(llvm::CallingConv::C);
  function->addFnAttr("no-frame-pointer-elim", "false");

  // Generate unwind information so that GDB can crawl through the stack frames
  // created by the JIT compiled code.
  function->setHasUWTable();

  // Tensorflow always flushes denormals to zero, let LLVM know that flushing
  // denormals is safe. This allows vectorization using ARM's neon instruction
  // set.
  function->addFnAttr("denormal-fp-math", "preserve-sign");

  // Add the optimize attribute to the function if optimizing for size. This
  // controls internal behavior of some optimization passes (e.g. loop
  // unrolling).
  if (cpu::options::OptimizeForSizeRequested(module_config)) {
    function->addFnAttr(llvm::Attribute::OptimizeForSize);
  }

  return function;
}

void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) {
  auto options = config.debug_options().xla_backend_extra_options();
  if (!options.empty()) {
    std::vector<string> fake_argv_storage;
    fake_argv_storage.push_back("");
    for (const auto& it : options) {
      // Skip options the XLA backend itself consumes.
      if (!absl::StartsWith(it.first, "xla_")) {
        if (it.second.empty()) {
          fake_argv_storage.push_back(it.first);
        } else {
          fake_argv_storage.push_back(it.first + "=" + it.second);
        }
      }
    }

    VLOG(2) << "Passing argv to LLVM:";
    std::vector<const char*> fake_argv;
    for (const auto& s : fake_argv_storage) {
      fake_argv.push_back(s.c_str());
      VLOG(2) << s;
    }
    llvm::cl::ParseCommandLineOptions(fake_argv.size(), &fake_argv[0]);
  }
}

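// Returns the (low, high) 32-bit halves of the full 64-bit product of two
// unsigned 32-bit values, computed by widening both operands to i64.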
std::pair<llvm::Value*, llvm::Value*> UMulLowHigh32(llvm::IRBuilder<>* b,
                                                    llvm::Value* src0,
                                                    llvm::Value* src1) {
  CHECK_EQ(src0->getType()->getPrimitiveSizeInBits(), 32);
  CHECK_EQ(src1->getType()->getPrimitiveSizeInBits(), 32);
  llvm::Type* int64_ty = b->getInt64Ty();
  src0 = b->CreateZExt(src0, int64_ty);
  src1 = b->CreateZExt(src1, int64_ty);
  return SplitInt64ToInt32s(b, b->CreateMul(src0, src1));
}

std::pair<llvm::Value*, llvm::Value*> SplitInt64ToInt32s(
    llvm::IRBuilder<>* b, llvm::Value* value_64bits) {
  CHECK_EQ(value_64bits->getType()->getPrimitiveSizeInBits(), 64);
  llvm::Type* int32_ty = b->getInt32Ty();
  llvm::Value* low_32bits = b->CreateTrunc(value_64bits, int32_ty);
  llvm::Value* high_32bits =
      b->CreateTrunc(b->CreateLShr(value_64bits, 32), int32_ty);
  return std::make_pair(low_32bits, high_32bits);
}

unsigned GetGlobalMemoryAddressSpace(const llvm::Module& module) {
  const unsigned kAMDGPUGlobalMemoryAddrSpace = 1;
  llvm::Triple target_triple = llvm::Triple(module.getTargetTriple());
  if (target_triple.getArch() == llvm::Triple::amdgcn) {
    // AMDGPU uses 1 for the global memory address space.
    return kAMDGPUGlobalMemoryAddrSpace;
  }
  return 0;
}

llvm::GlobalVariable* GetOrCreateVariableForRngState(llvm::Module* module,
                                                     llvm::IRBuilder<>* b) {
  static const char* kRngStateVariableName = "rng_state";
  llvm::GlobalVariable* state_ptr =
      module->getNamedGlobal(kRngStateVariableName);
  if (!state_ptr) {
    unsigned global_address_space = GetGlobalMemoryAddressSpace(*module);
    llvm::Type* state_type = b->getInt128Ty();
    // Use a non-zero initial value, as a zero state can cause the result of
    // the first random number generation to fail the chi-square test. The
    // value used here is arbitrarily chosen; any non-zero value should be
    // fine.
    state_ptr = new llvm::GlobalVariable(
        /*M=*/*module,
        /*Ty=*/state_type,
        /*isConstant=*/false,
        /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
        /*Initializer=*/llvm::ConstantInt::get(b->getInt128Ty(), 0x7012395ull),
        /*Name=*/kRngStateVariableName,
        /*InsertBefore=*/nullptr,
        /*TLMode=*/llvm::GlobalValue::NotThreadLocal,
        /*AddressSpace=*/global_address_space,
        /*isExternallyInitialized=*/false);
  }
  return state_ptr;
}

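// Returns the current 128-bit RNG state and advances the stored state by
// `delta`, so successive calls hand out non-overlapping counter ranges.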
llvm::Value* RngGetAndUpdateState(uint64 delta, llvm::Module* module,
                                  llvm::IRBuilder<>* builder) {
  llvm::GlobalVariable* state_ptr =
      GetOrCreateVariableForRngState(module, builder);
  llvm::LoadInst* state_value_old =
      builder->CreateLoad(state_ptr, "load_state");
  llvm::Value* state_value_new = builder->CreateAdd(
      state_value_old,
      llvm::ConstantInt::get(state_value_old->getType(), delta));
  builder->CreateStore(state_value_new, state_ptr);
  return state_value_old;
}

}  // namespace llvm_ir
}  // namespace xla