1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <iterator>
17
18 #include "tensorflow/compiler/xla/service/cpu/ir_function.h"
19
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
22 #include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
23 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
24 #include "tensorflow/compiler/xla/status_macros.h"
25
26 namespace xla {
27 namespace cpu {
28
GetComputeFunctionParams(llvm::Module * llvm_module,const int64 num_dynamic_loop_bounds)29 static std::vector<llvm::Type*> GetComputeFunctionParams(
30 llvm::Module* llvm_module, const int64 num_dynamic_loop_bounds) {
31 llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(llvm_module->getContext());
32 llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo();
33 llvm::Type* i64_ptr_type =
34 llvm::Type::getInt64PtrTy(llvm_module->getContext());
35 std::vector<llvm::Type*> compute_function_params(
36 {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type});
37 if (num_dynamic_loop_bounds > 0) {
38 compute_function_params.push_back(i64_ptr_type);
39 }
40 compute_function_params.push_back(i64_ptr_type);
41 return compute_function_params;
42 }
43
IrFunction(const string & function_name,llvm::Function::LinkageTypes linkage,const HloModuleConfig & module_config,llvm::Module * llvm_module,llvm::IRBuilder<> * b,int64 num_dynamic_loop_bounds)44 IrFunction::IrFunction(const string& function_name,
45 llvm::Function::LinkageTypes linkage,
46 const HloModuleConfig& module_config,
47 llvm::Module* llvm_module, llvm::IRBuilder<>* b,
48 int64 num_dynamic_loop_bounds)
49 : b_(b),
50 llvm_module_(llvm_module),
51 caller_insert_point_guard_(*b),
52 num_dynamic_loop_bounds_(num_dynamic_loop_bounds) {
53 Initialize(function_name, linkage, module_config);
54 }
55
~IrFunction()56 IrFunction::~IrFunction() {
57 // Emit function return value.
58 b_->CreateRetVoid();
59 }
60
GetDynamicLoopBounds()61 DynamicLoopBounds IrFunction::GetDynamicLoopBounds() {
62 DynamicLoopBounds dynamic_loop_bounds(num_dynamic_loop_bounds_);
63 for (int i = 0; i < num_dynamic_loop_bounds_; ++i) {
64 dynamic_loop_bounds[i].first = GetDynamicLoopBound(i * 2 + 0);
65 dynamic_loop_bounds[i].second = GetDynamicLoopBound(i * 2 + 1);
66 }
67 return dynamic_loop_bounds;
68 }
69
Initialize(const string & function_name,llvm::Function::LinkageTypes linkage,const HloModuleConfig & module_config)70 void IrFunction::Initialize(const string& function_name,
71 llvm::Function::LinkageTypes linkage,
72 const HloModuleConfig& module_config) {
73 // The function signature is:
74 // void function(i8* retval, i8* run_options, i8** params, i8**
75 // buffer_table,
76 // i64* dynamic_loop_bounds, i64* prof_counters)
77 //
78 // For thread local functions:
79 // retval: points to the returned value.
80 // params: address of an array with pointers to parameters.
81 // buffer_table: is null
82 //
83 // For global functions:
84 // retval: is null
85 // params: is null
86 // buffer_table: address of an array with pointers to temporary buffers and
87 // entry computation parameters (but not to constant buffers).
88 //
89 // Therefore, the generated function's signature (FunctionType) is statically
90 // determined - parameter unpacking is done in code generated into the
91 // function, rather than by a prologue dictated by the platform ABI.
92 //
93 // /--------------\
94 // retval ----------> | return value |
95 // \--------------/
96 //
97 // /-------------------------------\
98 // run_options -----> | xla::ExecutableRunOptions |
99 // \-------------------------------/
100 //
101 // /---------------------------------------------\
102 // params --------> | param 0 | param 1 | ..... | param N-1 |
103 // | addr | addr | | addr |
104 // \---------------------------------------------/
105 // | | |
106 // | | |
107 // V V V
108 // /---------\ /---------\ /-----------\
109 // | param 0 | | param 1 | | param N-1 |
110 // \---------/ \---------/ \-----------/
111 //
112 // /---------------------------------------------\
113 // buffer_table---> | buff 0 | guff 1 | ..... | buff N-1 |
114 // | addr | addr | | addr |
115 // \---------------------------------------------/
116 // | | |
117 // | | |
118 // V V V
119 // /---------\ /---------\ /-----------\
120 // | temp 0 | | temp 1 | | temp N-1 |
121 // \---------/ \---------/ \-----------/
122 //
123 // /--------------------------------------------\
124 // dynamic loop bounds -> | outer_dim0_start | outer_dim0_limit | .....|
125 // (elided for aot) \--------------------------------------------/
126 //
127 // /---------------------------------------------\
128 // prof counters -> | counter 0 | counter 1 | ..... | counter N-1 |
129 // \---------------------------------------------/
130
131 // Even though the type of params and buffer_table is void** in the host's
132 // view, in LLVM IR this is represented by i8*, similarly to void*. It's up to
133 // the code to use GEPs to unravel the indirection layers.
134 llvm::FunctionType* function_type = llvm::FunctionType::get(
135 /*Result=*/llvm::Type::getVoidTy(llvm_module_->getContext()),
136 /*Params=*/
137 GetComputeFunctionParams(llvm_module_, num_dynamic_loop_bounds_),
138 /*isVarArg=*/false);
139
140 // Functions with local linkage get an inlining bonus. Because we know
141 // a-priori that embedded functions (non-entry functions) will not have its
142 // name resolved, give it local linkage.
143 function_ = llvm_ir::CreateCpuFunction(function_type, linkage, module_config,
144 function_name, llvm_module_);
145
146 // Set meaningful names for the function's arguments: useful for debugging.
147 llvm::Function::arg_iterator arg_iter = function_->arg_begin();
148 arg_iter->setName("retval");
149 result_arg_ = &*arg_iter;
150 (++arg_iter)->setName("run_options");
151 exec_run_options_arg_ = &*arg_iter;
152 (++arg_iter)->setName("params");
153 parameters_arg_ = &*arg_iter;
154 (++arg_iter)->setName("buffer_table");
155 buffer_table_arg_ = &*arg_iter;
156 if (num_dynamic_loop_bounds_ > 0) {
157 (++arg_iter)->setName("dynamic_loop_bounds");
158 dynamic_loop_bounds_arg_ = &*arg_iter;
159 }
160 (++arg_iter)->setName("prof_counters");
161 profile_counters_arg_ = &*arg_iter;
162
163 // We know a-priori that the function arguments are guaranteed to point to
164 // disjoint objects.
165 llvm::Argument* retval = result_arg();
166 for (llvm::Argument& argument : function_->args()) {
167 // However, the return buffer aliases the temporaries and thus cannot be
168 // marked noalias.
169 if (&argument == retval) {
170 continue;
171 }
172 function_->addAttribute(argument.getArgNo() + 1, llvm::Attribute::NoAlias);
173 }
174
175 b_->SetInsertPoint(llvm::BasicBlock::Create(
176 /*Context=*/llvm_module_->getContext(),
177 /*Name=*/"entry",
178 /*Parent=*/function_));
179 }
180
GetDynamicLoopBound(const int64 offset)181 llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) {
182 CHECK_GT(num_dynamic_loop_bounds_, 0);
183 CHECK_LT(offset, num_dynamic_loop_bounds_ * 2);
184 string name = absl::StrCat("dynamic_loop_bound_", offset);
185 return b_->CreateLoad(b_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_),
186 b_->getInt64(offset), name));
187 }
188
EncodeArrayFunctionArguments(absl::Span<llvm::Value * const> arguments,absl::string_view name,llvm::IRBuilder<> * b)189 llvm::Value* EncodeArrayFunctionArguments(
190 absl::Span<llvm::Value* const> arguments, absl::string_view name,
191 llvm::IRBuilder<>* b) {
192 llvm::Value* arguments_buffer;
193 llvm::Type* int8ptr_ty = b->getInt8PtrTy();
194 if (arguments.empty()) {
195 arguments_buffer = llvm::Constant::getNullValue(int8ptr_ty->getPointerTo());
196 } else {
197 arguments_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
198 int8ptr_ty, b->getInt32(arguments.size()),
199 absl::StrCat(name, "_parameter_addresses"), b);
200
201 for (size_t i = 0; i < arguments.size(); i++) {
202 llvm::Value* parameter_as_i8ptr = b->CreateBitCast(
203 arguments[i], b->getInt8PtrTy(),
204 absl::StrCat(name, "_parameter_", i, "_address_as_i8ptr"));
205 llvm::Value* slot_in_param_addresses =
206 b->CreateInBoundsGEP(arguments_buffer, {b->getInt64(i)});
207 b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses);
208 }
209 }
210 return arguments_buffer;
211 }
212
213 // Emits code to allocate an array of parameter address pointers, and store
214 // each address from 'parameter_addresses'.
215 // Returns an array of compute function call arguments (including parameter
216 // address buffer).
GetArrayFunctionCallArguments(absl::Span<llvm::Value * const> parameter_addresses,llvm::IRBuilder<> * b,absl::string_view name,llvm::Value * return_value_buffer,llvm::Value * exec_run_options_arg,llvm::Value * buffer_table_arg,llvm::Value * profile_counters_arg)217 std::vector<llvm::Value*> GetArrayFunctionCallArguments(
218 absl::Span<llvm::Value* const> parameter_addresses, llvm::IRBuilder<>* b,
219 absl::string_view name, llvm::Value* return_value_buffer,
220 llvm::Value* exec_run_options_arg, llvm::Value* buffer_table_arg,
221 llvm::Value* profile_counters_arg) {
222 llvm::Value* parameter_addresses_buffer =
223 EncodeArrayFunctionArguments(parameter_addresses, name, b);
224
225 const auto to_int8_ptr = [=](llvm::Value* ptr) {
226 return b->CreatePointerCast(ptr, b->getInt8PtrTy());
227 };
228 std::vector<llvm::Value*> arguments{
229 to_int8_ptr(return_value_buffer), to_int8_ptr(exec_run_options_arg),
230 parameter_addresses_buffer, buffer_table_arg};
231 if (profile_counters_arg != nullptr) {
232 arguments.push_back(profile_counters_arg);
233 }
234 return arguments;
235 }
236
237 // Emits a call to a runtime fork/join function which dispatches parallel
238 // calls to 'parallel_function' (and joins threads before returning).
EmitCallToParallelForkJoin(const std::vector<llvm::Value * > & arguments,const Shape & shape,const std::vector<int64> & dimension_partition_counts,llvm::IRBuilder<> * b,llvm::Function * parallel_function,const string & name)239 Status EmitCallToParallelForkJoin(
240 const std::vector<llvm::Value*>& arguments, const Shape& shape,
241 const std::vector<int64>& dimension_partition_counts, llvm::IRBuilder<>* b,
242 llvm::Function* parallel_function, const string& name) {
243 llvm::Module* module = b->GetInsertBlock()->getModule();
244
245 // Build ParallelForkJoin function type.
246 std::vector<llvm::Type*> compute_function_params =
247 GetComputeFunctionParams(module, /*num_dynamic_loop_bounds=*/0);
248 // Number of parallel compute functions.
249 compute_function_params.push_back(b->getInt32Ty());
250 // Array of partitions. There is an array element for each
251 // partition x partition_dim x 2 (for dimension start and limit).
252 compute_function_params.push_back(
253 llvm::Type::getInt64PtrTy(module->getContext()));
254 // Number of partitioned most-major dimensions in 'shape'.
255 compute_function_params.push_back(b->getInt32Ty());
256 // Function pointer for compute function to be dispatched in parallel.
257 compute_function_params.push_back(
258 llvm::Type::getInt8PtrTy(module->getContext()));
259
260 llvm::FunctionType* fork_join_type = llvm::FunctionType::get(
261 /*Result=*/llvm::Type::getVoidTy(module->getContext()),
262 /*Params=*/compute_function_params,
263 /*isVarArg=*/false);
264
265 llvm::Function* fork_join_func = llvm::dyn_cast<llvm::Function>(
266 module
267 ->getOrInsertFunction(runtime::kParallelForkJoinSymbolName,
268 fork_join_type)
269 .getCallee());
270 fork_join_func->setCallingConv(llvm::CallingConv::C);
271 fork_join_func->setDoesNotThrow();
272
273 // Add common compute function arguments.
274 std::vector<llvm::Value*> fork_join_arguments(arguments);
275
276 // Create ShapePartitionIterator to generate all partitions of 'shape'.
277 ShapePartitionIterator partition_iterator(shape, dimension_partition_counts);
278 const int64 num_partitions = partition_iterator.GetTotalPartitionCount();
279 // Add argument specifying the number of parallel partitions.
280 fork_join_arguments.push_back(b->getInt32(num_partitions));
281
282 // The number of partitioned most-major dimensions in 'shape'.
283 const int32 num_partitioned_dims = dimension_partition_counts.size();
284 // A dimension partition consists of two elements: [start_index, limit_index).
285 const int32 dim_partition_size = 2;
286 // Calculate array partition stride.
287 const int32 array_partition_stride =
288 num_partitioned_dims * dim_partition_size;
289 // Calculate the total number of elements in the partition array.
290 const int32 partition_array_size =
291 dim_partition_size * num_partitioned_dims * num_partitions;
292
293 // Store dimension partition values as llvm constants in 'partitions'.
294 // See comments in runtime_fork_join.cc for array layout description.
295 std::vector<llvm::Constant*> partitions(partition_array_size);
296 for (int32 i = 0; i < num_partitions; ++i) {
297 std::vector<std::pair<int64, int64>> dim_partitions =
298 partition_iterator.GetPartition(i);
299 CHECK_EQ(num_partitioned_dims, dim_partitions.size());
300 const int32 partition_index = i * array_partition_stride;
301 for (int32 j = 0; j < num_partitioned_dims; ++j) {
302 const std::pair<int64, int64>& dim_partition = dim_partitions[j];
303 const int32 index = partition_index + j * dim_partition_size;
304 // Store partition [dim_start, dim_limit) intervals for each dimension.
305 partitions[index] = b->getInt64(dim_partition.first);
306 partitions[index + 1] =
307 b->getInt64(dim_partition.first + dim_partition.second);
308 }
309 }
310
311 // Create global variable out of dimension partitions in 'partitions'.
312 llvm::ArrayType* partitions_array_type =
313 llvm::ArrayType::get(b->getInt64Ty(), partition_array_size);
314 llvm::Constant* partitions_array =
315 llvm::ConstantArray::get(partitions_array_type, partitions);
316 llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable(
317 /*M=*/*module,
318 /*Ty=*/partitions_array_type,
319 /*isConstant=*/true,
320 /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
321 /*Initializer=*/partitions_array,
322 /*Name=*/
323 absl::StrCat(name, "_parallel_dimension_partitions"));
324
325 // Add argument specifying parallel dimension partitions.
326 fork_join_arguments.push_back(
327 b->CreateBitCast(global_partitions_array,
328 llvm::Type::getInt64PtrTy(module->getContext())));
329 // Add argument specifying the number of partitioned most-major dimensions.
330 fork_join_arguments.push_back(b->getInt32(num_partitioned_dims));
331 // Add argument for parallel compute function pointer.
332 fork_join_arguments.push_back(
333 b->CreateBitCast(parallel_function, b->getInt8PtrTy()));
334 // Emit call to parallel fork/join.
335 b->CreateCall(fork_join_func, fork_join_arguments);
336
337 return Status::OK();
338 }
339
340 } // namespace cpu
341 } // namespace xla
342