1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LLVM_UTIL_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LLVM_UTIL_H_
18
19 #include <stdint.h>
20 #include <string>
21 #include <vector>
22
23 #include "absl/strings/string_view.h"
24 #include "absl/types/span.h"
25 #include "llvm/ADT/StringRef.h"
26 #include "llvm/IR/BasicBlock.h"
27 #include "llvm/IR/GlobalVariable.h"
28 #include "llvm/IR/IRBuilder.h"
29 #include "llvm/IR/Instructions.h"
30 #include "llvm/IR/Module.h"
31 #include "llvm/IR/Value.h"
32 #include "llvm/Support/raw_ostream.h"
33 #include "tensorflow/compiler/xla/literal.h"
34 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
35 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
36 #include "tensorflow/compiler/xla/types.h"
37 #include "tensorflow/compiler/xla/xla_data.pb.h"
38 #include "tensorflow/core/platform/types.h"
39
// Forward declarations of LLVM classes. FastMathFlags appears below only as
// the return type of a function declaration, which does not require the
// complete type.
namespace llvm {
class FastMathFlags;
class TargetOptions;
}  // namespace llvm
44
45 namespace xla {
46 namespace llvm_ir {
47
48 // Convert a absl::string_view to a llvm::StringRef. Note: both
49 // absl::string_view and llvm::StringRef are non-owning pointers into a
50 // string in memory. This method is used to feed strings to LLVM
51 // & Clang APIs that expect llvm::StringRef.
AsStringRef(absl::string_view str)52 inline llvm::StringRef AsStringRef(absl::string_view str) {
53 return llvm::StringRef(str.data(), str.size());
54 }
55
56 template <typename T>
AsArrayRef(const std::vector<T> & vec)57 llvm::ArrayRef<T> AsArrayRef(const std::vector<T>& vec) {
58 return llvm::ArrayRef<T>(vec.data(), vec.size());
59 }
60
61 template <typename T>
AsArrayRef(const absl::Span<const T> & slice)62 llvm::ArrayRef<T> AsArrayRef(const absl::Span<const T>& slice) {
63 return llvm::ArrayRef<T>(slice.data(), slice.size());
64 }
65
66 // Dump the given LLVM entity to a string. This works for Types and Values.
67 template <typename T>
DumpToString(const T & entity)68 string DumpToString(const T& entity) {
69 std::string buffer_string;
70 llvm::raw_string_ostream ostream(buffer_string);
71 entity.print(ostream);
72 ostream.flush();
73 return buffer_string;
74 }
75
76 // Dump the given LLVM module to a string. This requires a function distinct
77 // from DumpToString because the signatures of the print() methods for Values
78 // and Modules are slightly different.
79 string DumpModuleToString(const llvm::Module& module);
80
81 // Constructs a human-friendly name from the given inputs. The result is
82 // suitable for use as an llvm::Value's name.
83 //
84 // This is equivalent to
85 //
86 // - changing the HloInstruction* to its name() (if we called that overload),
87 // - joining all of the nonempty inputs by '.', and then
88 // - removing all '%'s.
89 //
90 string IrName(string a);
91 string IrName(absl::string_view a, absl::string_view b);
92 string IrName(const HloInstruction* a, absl::string_view b = "");
93
94 // Removes special characters from a function name.
95 //
96 // Note that this can cause different inputs to map to the same output, so after
97 // sanitizing a function name, you must run it through a uniquer.
98 string SanitizeFunctionName(string function_name);
99
100 // Emits a call to the specified intrinsic with the given operands. Overloaded
101 // intrinsics (for example, "minnum") must include a type in overloaded_types
102 // for each overloaded type. Typically, overloaded intrinsics have only a single
103 // overloaded type.
104 llvm::CallInst* EmitCallToIntrinsic(
105 llvm::Intrinsic::ID intrinsic_id, absl::Span<llvm::Value* const> operands,
106 absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b);
107
// Emit float max. Emits the maxnum intrinsic if fast math is disabled, or
// fcmp+select otherwise.
110 llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
111 llvm::IRBuilder<>* b);
112
// Emit float min. Emits the minnum intrinsic if fast math is disabled, or
// fcmp+select otherwise.
115 llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
116 llvm::IRBuilder<>* b);
117
118 // Convenience methods for emitting a GEP instruction that indexes into a buffer
119 // (1-dimensional array), equivalent to array[index]. The type is automatically
// determined from the element type of the array. The int64 index overload
// wraps the index in an i64 llvm::Value.
122 llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Value* index,
123 llvm::IRBuilder<>* b);
124 llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index,
125 llvm::IRBuilder<>* b);
126
127 // Returns the LLVM type which represents the given XLA primitive type.
128 llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
129 llvm::Module* module);
130
131 // Returns the type size in bits. If "type" is a struct, it must be packed.
132 int GetSizeInBits(llvm::Type* type);
133
134 // Returns the LLVM type which represents the given XLA shape. For example,
135 // if "shape" is [5 x [10 x f32]], the function returns [5 x [10 x float]].
136 llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module);
137
138 // Returns a value that represents a pointer to a global string constant that
139 // encodes the shape as a serialized protobuf.
140 StatusOr<llvm::Value*> EncodeSelfDescribingShapeConstant(const Shape& shape,
141 int32* shape_size,
142 llvm::IRBuilder<>* b);
143
144 // Converts a given literal to an IR Constant. Literals have known constant
145 // values at IR emission time.
146 llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
147 llvm::Module* module);
148
149 // Allocates a tile of shared memory.
150 llvm::GlobalVariable* AllocateSharedMemoryTile(llvm::Module* module,
151 llvm::Type* tile_type,
152 absl::string_view name);
153
// Inserts an alloca of the requested type at the entry point of the
155 // function that the builder is currently building. The insert point
156 // of the builder is set to the same place after calling this function
157 // as before.
158 //
159 // This can be useful to avoid e.g. executing an alloca every time
160 // through a loop.
161 llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,
162 absl::string_view name,
163 llvm::IRBuilder<>* b,
164 int alignment = 0);
165
166 // As EmitAllocaAtFunctionEntry, but allocates element_count entries
167 // instead of a single element.
168 llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type,
169 llvm::Value* element_count,
170 absl::string_view name,
171 llvm::IRBuilder<>* b,
172 int alignment = 0);
173
174 // Creates a basic block with the same context and function as for the
175 // builder. Inserts at the end of the function if insert_before is
176 // null.
177 llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before,
178 absl::string_view name,
179 llvm::IRBuilder<>* b);
180
181 // Struct with data on a conditional branch in a diamond shape created
182 // via EmitIfThenElse.
183 struct LlvmIfData {
184 // The block that has the conditional branch.
185 llvm::BasicBlock* if_block;
186
187 // The block that is executed if the condition is true.
188 llvm::BasicBlock* true_block;
189
190 // The block that is executed if the condition is false.
191 llvm::BasicBlock* false_block;
192
193 // The block that follows after both the true_block and the
194 // false_block.
195 llvm::BasicBlock* after_block;
196 };
197
198 // Inserts a diamond-shaped if-then-else construct at the current
199 // insertion point of the builder. This involves splitting the current
200 // block into two blocks, at the insertion point, and introducing a
201 // true-block and a false-block that connect the two split pieces. The
202 // true-block is executed if the condition parameter evaluates to true
203 // and otherwise the false-block is executed. If `emit_else` is false,
204 // it jumps to the after-block rather than the false-block if the
205 // condition is false, and the returned `false_block` is null.
206 //
207 // Currently the insertion point of the builder must be a well-formed
208 // block with a terminator. If you need to use this for a
209 // non-terminated block, just make the function able to do that too.
210 LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name,
211 llvm::IRBuilder<>* b, bool emit_else = true);
212
213 // Emits a compare operation between "lhs" and "rhs" with the given predicate,
214 // and then converts the result to i8 so that it is addressable.
215 llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
216 llvm::Value* lhs, llvm::Value* rhs,
217 llvm::IRBuilder<>* b);
218
219 // Emits a call that logs the given value with the given tag as a prefix.
220 // The provided tag and value are passed to a runtime logging call that is
221 // embedded in this translation unit when the emitted code is executed.
222 //
223 // This can be very useful for debugging generated programs in short order when
224 // developing new generated routines.
225 //
226 // Precondition: value must be an int64.
227 // Precondition: tag must be a stable pointer for the lifetime of the generated
228 // program (the constant pointer is burned in to the program).
229 void EmitLogging(const char* tag, llvm::Value* value, llvm::IRBuilder<>* b);
230
231 // Adds alignment metadata to a load instruction using the given alignment.
232 // The alignment refers to the result of the load, not the load itself.
233 void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment);
234
// Adds dereferenceable metadata to a load instruction using the given
// number of dereferenceable bytes.
237 // Dereferenceable refers to the result of the load, not the load itself.
238 void SetDereferenceableMetadataForLoad(llvm::LoadInst* load,
239 uint64_t dereferenceable_bytes);
240
241 // Tells LLVM `inst >= lower && inst < upper`. Returns `inst` for convenience.
242 llvm::Instruction* AddRangeMetadata(int64 lower, int64 upper,
243 llvm::Instruction* inst);
244
245 void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder);
246
247 void SetToLastInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder);
248
249 // Create a bitwise rotation of `rotand` by `rotor`.
250 llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor,
251 llvm::IRBuilder<>* builder);
252
253 // Returns the number of bytes within the shape.
254 int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout);
255
256 // Gets an llvm::FastMathFlags that reflects the settings in the given
257 // module config.
258 llvm::FastMathFlags GetCpuFastMathFlags(const HloModuleConfig& module_config);
259
260 // Computes a conservative union of the metadata in "a" and "b". For
261 // aliasing-related metadata, this means the result can be applied to
262 // instructions whose aliasing relationship can be described either by "a" *or*
263 // by "b".
264 std::map<int, llvm::MDNode*> MergeMetadata(
265 llvm::LLVMContext* context, const std::map<int, llvm::MDNode*>& a,
266 const std::map<int, llvm::MDNode*>& b);
267
268 // Dumps out `llvm_module` to the path specified in DebugOptions, if dumping is
269 // enabled for the given HLO module.
270 //
271 // A sanitized version of `hlo_module_name` is incorporated into the file name.
272 // If `optimized` is true then a suffix of "-with-opt.ll" is used, else a suffix
273 // of "-no-opt.ll" is used.
274 void DumpIrIfEnabled(const HloModule& hlo_module,
275 const llvm::Module& llvm_module, bool optimized);
276
277 llvm::Function* CreateCpuFunction(llvm::FunctionType* function_type,
278 llvm::GlobalValue::LinkageTypes linkage,
279 const HloModuleConfig& module_config,
280 absl::string_view name, llvm::Module* module);
281
282 // Extracts the xla_backend_extra_options from `config` and passes those that
283 // don't start with xla_ to LLVM.
284 void InitializeLLVMCommandLineOptions(const HloModuleConfig& config);
285
286 // Zero-extends two 32-bit values to 64 bits, multiplies them, and returns the
287 // result as a pair of (low 32 bits, high 32 bits).
288 std::pair<llvm::Value*, llvm::Value*> UMulLowHigh32(llvm::IRBuilder<>* b,
289 llvm::Value* src0,
290 llvm::Value* src1);
291 // Splits the 64-bit integer value into its high and low 32 bits.
292 std::pair<llvm::Value*, llvm::Value*> SplitInt64ToInt32s(
293 llvm::IRBuilder<>* b, llvm::Value* value_64bits);
294
295 // Checks whether a global variable is already created to represent the state
296 // of a random number generator. If not, creates such a variable. Returns the
297 // global variable.
298 llvm::GlobalVariable* GetOrCreateVariableRngState(llvm::Module* module,
299 llvm::IRBuilder<>* b);
300
// Adds a delta value to the global state variable and returns the old value of
// the variable.
303 llvm::Value* RngGetAndUpdateState(uint64 delta, llvm::Module* module,
304 llvm::IRBuilder<>* b);
305
306 // Gets the LLVM address space that should be used for global variables (e.g.
307 // XLA's rng state).
308 unsigned GetGlobalMemoryAddressSpace(const llvm::Module& module);
309 } // namespace llvm_ir
310 } // namespace xla
311
312 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LLVM_UTIL_H_
313