1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LLVM_UTIL_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LLVM_UTIL_H_
18
#include <stdint.h>

#include <map>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"
41
// Forward declarations of LLVM classes, so that they can be named in this
// header without including their defining headers.
namespace llvm {
class FastMathFlags;
class TargetOptions;
}  // namespace llvm
46
47 namespace xla {
48 namespace llvm_ir {
49
50 // Convert a absl::string_view to a llvm::StringRef. Note: both
51 // absl::string_view and llvm::StringRef are non-owning pointers into a
52 // string in memory. This method is used to feed strings to LLVM
53 // & Clang APIs that expect llvm::StringRef.
AsStringRef(absl::string_view str)54 inline llvm::StringRef AsStringRef(absl::string_view str) {
55 return llvm::StringRef(str.data(), str.size());
56 }
57
AsStringView(llvm::StringRef str)58 inline absl::string_view AsStringView(llvm::StringRef str) {
59 return absl::string_view(str.data(), str.size());
60 }
61
62 template <typename T>
AsArrayRef(const std::vector<T> & vec)63 llvm::ArrayRef<T> AsArrayRef(const std::vector<T>& vec) {
64 return llvm::ArrayRef<T>(vec.data(), vec.size());
65 }
66
67 template <typename T>
AsArrayRef(const absl::Span<const T> slice)68 llvm::ArrayRef<T> AsArrayRef(const absl::Span<const T> slice) {
69 return llvm::ArrayRef<T>(slice.data(), slice.size());
70 }
71
72 // Dump the given LLVM entity to a string. This works for Types and Values.
73 template <typename T>
DumpToString(const T & entity)74 string DumpToString(const T& entity) {
75 std::string buffer_string;
76 llvm::raw_string_ostream ostream(buffer_string);
77 entity.print(ostream);
78 ostream.flush();
79 return buffer_string;
80 }
81
82 // Same as above, except that const T& does not work well with MILR because the
83 // print methods are not const.
84 template <typename T>
DumpToString(T & entity)85 string DumpToString(T& entity) {
86 std::string buffer_string;
87 llvm::raw_string_ostream ostream(buffer_string);
88 entity.print(ostream);
89 ostream.flush();
90 return buffer_string;
91 }
92
93 // Dump the given LLVM module to a string. This requires a function distinct
94 // from DumpToString because the signatures of the print() methods for Values
95 // and Modules are slightly different.
96 string DumpModuleToString(const llvm::Module& module);
97
98 // Constructs a human-friendly name from the given inputs. The result is
99 // suitable for use as an llvm::Value's name.
100 //
101 // This is equivalent to
102 //
103 // - changing the HloInstruction* to its name() (if we called that overload),
104 // - joining all of the nonempty inputs by '.', and then
105 // - removing all '%'s.
106 //
107 string IrName(absl::string_view a);
108 string IrName(absl::string_view a, absl::string_view b);
109 string IrName(const HloInstruction* a, absl::string_view b = "");
110
111 // Removes special characters from a function name.
112 //
113 // Note that this can cause different inputs to map to the same output, so after
114 // sanitizing a function name, you must run it through a uniquer.
115 string SanitizeFunctionName(string function_name);
116
117 // Emits a call to the specified intrinsic with the given operands. Overloaded
118 // intrinsics (for example, "minnum") must include a type in overloaded_types
119 // for each overloaded type. Typically, overloaded intrinsics have only a single
120 // overloaded type.
121 llvm::CallInst* EmitCallToIntrinsic(
122 llvm::Intrinsic::ID intrinsic_id, absl::Span<llvm::Value* const> operands,
123 absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b,
124 absl::string_view name = "");
125
126 // Emit float max. Emit maxnum intrinsic is fast math is disabled, or
127 // fcmp+select otherwise
128 llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
129 llvm::IRBuilder<>* b, bool enable_fast_min_max,
130 absl::string_view name = "");
131
132 // Emit float min. Emit minnum intrinsic is fast math is disabled, or
133 // fcmp+select otherwise
134 llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
135 llvm::IRBuilder<>* b, bool enable_fast_min_max,
136 absl::string_view name = "");
137
138 // Convenience methods for emitting a GEP instruction that indexes into a buffer
139 // (1-dimensional array), equivalent to array[index]. The type is automatically
140 // determined from the element type of the array. The int64 index overload
141 // wraps the index in a i64 llvm::Value.
142 llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Value* index,
143 llvm::IRBuilder<>* b);
144 llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64_t index,
145 llvm::IRBuilder<>* b);
146
147 // Returns the LLVM type which represents the given XLA primitive type.
148 llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
149 llvm::Module* module);
150
151 // Returns the type size in bits. If "type" is a struct, it must be packed.
152 int GetSizeInBits(llvm::Type* type);
153
154 // Returns the LLVM type which represents the given XLA shape. For example,
155 // if "shape" is [5 x [10 x f32]], the function returns [5 x [10 x float]].
156 llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module);
157
158 // Returns a value that represents a pointer to a global string constant that
159 // encodes the shape as a serialized protobuf.
160 StatusOr<llvm::Value*> EncodeSelfDescribingShapeConstant(const Shape& shape,
161 int32* shape_size,
162 llvm::IRBuilder<>* b);
163
164 // Converts a given literal to an IR Constant. Literals have known constant
165 // values at IR emission time.
166 llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
167 llvm::Module* module);
168
169 // Allocates a tile of shared memory.
170 llvm::GlobalVariable* AllocateSharedMemoryTile(llvm::Module* module,
171 llvm::Type* tile_type,
172 absl::string_view name);
173
174 // Inserts an allocate of the requested type at the entry point of the
175 // function that the builder is currently building. The insert point
176 // of the builder is set to the same place after calling this function
177 // as before.
178 //
179 // This can be useful to avoid e.g. executing an alloca every time
180 // through a loop.
181 llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,
182 absl::string_view name,
183 llvm::IRBuilder<>* b,
184 int alignment = 0);
185
186 // As EmitAllocaAtFunctionEntry, but allocates element_count entries
187 // instead of a single element.
188 llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type,
189 llvm::Value* element_count,
190 absl::string_view name,
191 llvm::IRBuilder<>* b,
192 int alignment = 0);
193
194 // Creates a basic block with the same context and function as for the
195 // builder. Inserts at the end of the function if insert_before is
196 // null.
197 llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before,
198 absl::string_view name,
199 llvm::IRBuilder<>* b);
200
201 // Struct with data on a conditional branch in a diamond shape created
202 // via EmitIfThenElse.
203 struct LlvmIfData {
204 // The block that has the conditional branch.
205 llvm::BasicBlock* if_block;
206
207 // The block that is executed if the condition is true.
208 llvm::BasicBlock* true_block;
209
210 // The block that is executed if the condition is false.
211 llvm::BasicBlock* false_block;
212
213 // The block that follows after both the true_block and the
214 // false_block.
215 llvm::BasicBlock* after_block;
216 };
217
218 // Inserts a diamond-shaped if-then-else construct at the current
219 // insertion point of the builder. This involves splitting the current
220 // block into two blocks, at the insertion point, and introducing a
221 // true-block and a false-block that connect the two split pieces. The
222 // true-block is executed if the condition parameter evaluates to true
223 // and otherwise the false-block is executed. If `emit_else` is false,
224 // it jumps to the after-block rather than the false-block if the
225 // condition is false, and the returned `false_block` is null.
226 //
227 // Currently the insertion point of the builder must be a well-formed
228 // block with a terminator. If you need to use this for a
229 // non-terminated block, just make the function able to do that too.
230 LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name,
231 llvm::IRBuilder<>* b, bool emit_else = true);
232
233 // Emits a compare operation between "lhs" and "rhs" with the given predicate,
234 // and then converts the result to i8 so that it is addressable.
235 llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
236 llvm::Value* lhs, llvm::Value* rhs,
237 llvm::IRBuilder<>* b, absl::string_view name = "");
238
239 // Emits a call that logs the given value with the given tag as a prefix.
240 // The provided tag and value are passed to a runtime logging call that is
241 // embedded in this translation unit when the emitted code is executed.
242 //
243 // This can be very useful for debugging generated programs in short order when
244 // developing new generated routines.
245 //
246 // Precondition: value must be an int64.
247 // Precondition: tag must be a stable pointer for the lifetime of the generated
248 // program (the constant pointer is burned in to the program).
249 void EmitLogging(const char* tag, llvm::Value* value, llvm::IRBuilder<>* b);
250
251 // Adds alignment metadata to a load instruction using the given alignment.
252 // The alignment refers to the result of the load, not the load itself.
253 void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment);
254
255 // Adds dereferenceable metadata to a load instruction using the given
256 // the number of dereferenceable bytes.
257 // Dereferenceable refers to the result of the load, not the load itself.
258 void SetDereferenceableMetadataForLoad(llvm::LoadInst* load,
259 uint64_t dereferenceable_bytes);
260
261 // Tells LLVM `inst >= lower && inst < upper`. Returns `inst` for convenience.
262 llvm::Instruction* AddRangeMetadata(int64_t lower, int64_t upper,
263 llvm::Instruction* inst);
264
265 void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder);
266
267 void SetToLastInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder);
268
269 // Create a bitwise rotation of `rotand` by `rotor`.
270 llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor,
271 llvm::IRBuilder<>* builder);
272
273 // Returns the number of bytes within the shape.
274 int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout);
275
276 // Gets an llvm::FastMathFlags that reflects the settings in the given
277 // module config.
278 llvm::FastMathFlags GetCpuFastMathFlags(const HloModuleConfig& module_config);
279
280 // Computes a conservative union of the metadata in "a" and "b". For
281 // aliasing-related metadata, this means the result can be applied to
282 // instructions whose aliasing relationship can be described either by "a" *or*
283 // by "b".
284 std::map<int, llvm::MDNode*> MergeMetadata(
285 llvm::LLVMContext* context, const std::map<int, llvm::MDNode*>& a,
286 const std::map<int, llvm::MDNode*>& b);
287
288 // Dumps out `llvm_module` to the path specified in DebugOptions, if dumping is
289 // enabled for the given HLO module.
290 //
291 // A sanitized version of `hlo_module_name` is incorporated into the file name.
292 // If `optimized` is true then a suffix of "-with-opt.ll" is used, else a suffix
293 // of "-no-opt.ll" is used.
294 void DumpIrIfEnabled(const HloModule& hlo_module,
295 const llvm::Module& llvm_module, bool optimized,
296 absl::string_view filename_suffix = "");
297
298 void DumpIrIfEnabled(mlir::ModuleOp mlir_module, int unique_id,
299 const DebugOptions& debug_options);
300
301 llvm::Function* CreateCpuFunction(llvm::FunctionType* function_type,
302 llvm::GlobalValue::LinkageTypes linkage,
303 const HloModuleConfig& module_config,
304 absl::string_view name, llvm::Module* module);
305
306 // Extracts the xla_backend_extra_options from `config` and passes those that
307 // don't start with xla_ to LLVM.
308 void InitializeLLVMCommandLineOptions(const HloModuleConfig& config);
309
310 // Zero-extends two 32-bit values to 64 bits, multiplies them, and returns the
311 // result as a pair of (low 32 bits, high 32 bits).
312 std::pair<llvm::Value*, llvm::Value*> UMulLowHigh32(llvm::IRBuilder<>* b,
313 llvm::Value* src0,
314 llvm::Value* src1);
315 // Splits the 64-bit integer value into its high and low 32 bits.
316 std::pair<llvm::Value*, llvm::Value*> SplitInt64ToInt32s(
317 llvm::IRBuilder<>* b, llvm::Value* value_64bits);
318
319 // Checks whether a global variable is already created to represent the state
320 // of a random number generator. If not, creates such a variable. Returns the
321 // global variable.
322 llvm::GlobalVariable* GetOrCreateVariableRngState(llvm::Module* module,
323 llvm::IRBuilder<>* b);
324
325 // Adds a delta value to the global state variable and return the old value of
326 // the variable.
327 llvm::Value* RngGetAndUpdateState(uint64 delta, llvm::Module* module,
328 llvm::IRBuilder<>* b);
329
// Gets the LLVM address space that should be used for global variables (e.g.
// XLA's rng state).
unsigned GetGlobalMemoryAddressSpace();
333 } // namespace llvm_ir
334 } // namespace xla
335
336 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LLVM_UTIL_H_
337