1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <iterator>
17
18 #include "tensorflow/compiler/xla/service/cpu/ir_function.h"
19
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
22 #include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
23 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
24 #include "tensorflow/compiler/xla/status_macros.h"
25
26 namespace xla {
27 namespace cpu {
28
GetComputeFunctionParams(llvm::Module * llvm_module,const int64 num_dynamic_loop_bounds)29 static std::vector<llvm::Type*> GetComputeFunctionParams(
30 llvm::Module* llvm_module, const int64 num_dynamic_loop_bounds) {
31 llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(llvm_module->getContext());
32 llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo();
33 llvm::Type* i64_ptr_type =
34 llvm::Type::getInt64PtrTy(llvm_module->getContext());
35 std::vector<llvm::Type*> compute_function_params(
36 {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type});
37 if (num_dynamic_loop_bounds > 0) {
38 compute_function_params.push_back(i64_ptr_type);
39 }
40 compute_function_params.push_back(i64_ptr_type);
41 return compute_function_params;
42 }
43
IrFunction(const string & function_name,llvm::Function::LinkageTypes linkage,const HloModuleConfig & module_config,llvm::Module * llvm_module,llvm::IRBuilder<> * b,int64 num_dynamic_loop_bounds)44 IrFunction::IrFunction(const string& function_name,
45 llvm::Function::LinkageTypes linkage,
46 const HloModuleConfig& module_config,
47 llvm::Module* llvm_module, llvm::IRBuilder<>* b,
48 int64 num_dynamic_loop_bounds)
49 : b_(b),
50 llvm_module_(llvm_module),
51 caller_insert_point_guard_(*b),
52 num_dynamic_loop_bounds_(num_dynamic_loop_bounds) {
53 Initialize(function_name, linkage, module_config);
54 }
55
~IrFunction()56 IrFunction::~IrFunction() {
57 // Emit function return value.
58 b_->CreateRetVoid();
59 }
60
GetDynamicLoopBounds()61 DynamicLoopBounds IrFunction::GetDynamicLoopBounds() {
62 DynamicLoopBounds dynamic_loop_bounds(num_dynamic_loop_bounds_);
63 for (int i = 0; i < num_dynamic_loop_bounds_; ++i) {
64 dynamic_loop_bounds[i].first = GetDynamicLoopBound(i * 2 + 0);
65 dynamic_loop_bounds[i].second = GetDynamicLoopBound(i * 2 + 1);
66 }
67 return dynamic_loop_bounds;
68 }
69
Initialize(const string & function_name,llvm::Function::LinkageTypes linkage,const HloModuleConfig & module_config)70 void IrFunction::Initialize(const string& function_name,
71 llvm::Function::LinkageTypes linkage,
72 const HloModuleConfig& module_config) {
73 // The function signature is:
74 // void function(i8* retval, i8* run_options, i8** params, i8**
75 // buffer_table,
76 // i64* dynamic_loop_bounds, i64* prof_counters)
77 //
78 // For thread local functions:
79 // retval: points to the returned value.
80 // params: address of an array with pointers to parameters.
81 // buffer_table: is null
82 //
83 // For global functions:
84 // retval: is null
85 // params: is null
86 // buffer_table: address of an array with pointers to temporary buffers and
87 // entry computation parameters (but not to constant buffers).
88 //
89 // Therefore, the generated function's signature (FunctionType) is statically
90 // determined - parameter unpacking is done in code generated into the
91 // function, rather than by a prologue dictated by the platform ABI.
92 //
93 // /--------------\
94 // retval ----------> | return value |
95 // \--------------/
96 //
97 // /-------------------------------\
98 // run_options -----> | xla::ExecutableRunOptions |
99 // \-------------------------------/
100 //
101 // /---------------------------------------------\
102 // params --------> | param 0 | param 1 | ..... | param N-1 |
103 // | addr | addr | | addr |
104 // \---------------------------------------------/
105 // | | |
106 // | | |
107 // V V V
108 // /---------\ /---------\ /-----------\
109 // | param 0 | | param 1 | | param N-1 |
110 // \---------/ \---------/ \-----------/
111 //
112 // /---------------------------------------------\
113 // buffer_table---> | buff 0 | guff 1 | ..... | buff N-1 |
114 // | addr | addr | | addr |
115 // \---------------------------------------------/
116 // | | |
117 // | | |
118 // V V V
119 // /---------\ /---------\ /-----------\
120 // | temp 0 | | temp 1 | | temp N-1 |
121 // \---------/ \---------/ \-----------/
122 //
123 // /--------------------------------------------\
124 // dynamic loop bounds -> | outer_dim0_start | outer_dim0_limit | .....|
125 // (elided for aot) \--------------------------------------------/
126 //
127 // /---------------------------------------------\
128 // prof counters -> | counter 0 | counter 1 | ..... | counter N-1 |
129 // \---------------------------------------------/
130
131 // Even though the type of params and buffer_table is void** in the host's
132 // view, in LLVM IR this is represented by i8*, similarly to void*. It's up to
133 // the code to use GEPs to unravel the indirection layers.
134 llvm::FunctionType* function_type = llvm::FunctionType::get(
135 /*Result=*/llvm::Type::getVoidTy(llvm_module_->getContext()),
136 /*Params=*/
137 GetComputeFunctionParams(llvm_module_, num_dynamic_loop_bounds_),
138 /*isVarArg=*/false);
139
140 // Functions with local linkage get an inlining bonus. Because we know
141 // a-priori that embedded functions (non-entry functions) will not have its
142 // name resolved, give it local linkage.
143 function_ = llvm_ir::CreateCpuFunction(function_type, linkage, module_config,
144 function_name, llvm_module_);
145
146 // Set meaningful names for the function's arguments: useful for debugging.
147 llvm::Function::arg_iterator arg_iter = function_->arg_begin();
148 arg_iter->setName("retval");
149 result_arg_ = &*arg_iter;
150 (++arg_iter)->setName("run_options");
151 exec_run_options_arg_ = &*arg_iter;
152 (++arg_iter)->setName("params");
153 parameters_arg_ = &*arg_iter;
154 (++arg_iter)->setName("buffer_table");
155 buffer_table_arg_ = &*arg_iter;
156 if (num_dynamic_loop_bounds_ > 0) {
157 (++arg_iter)->setName("dynamic_loop_bounds");
158 dynamic_loop_bounds_arg_ = &*arg_iter;
159 }
160 (++arg_iter)->setName("prof_counters");
161 profile_counters_arg_ = &*arg_iter;
162
163 // We know a-priori that the function arguments are guaranteed to point to
164 // disjoint objects.
165 llvm::Argument* retval = result_arg();
166 for (llvm::Argument& argument : function_->args()) {
167 // However, the return buffer aliases the temporaries and thus cannot be
168 // marked noalias.
169 if (&argument == retval) {
170 continue;
171 }
172 function_->addAttribute(argument.getArgNo() + 1, llvm::Attribute::NoAlias);
173 }
174
175 b_->SetInsertPoint(llvm::BasicBlock::Create(
176 /*Context=*/llvm_module_->getContext(),
177 /*Name=*/"entry",
178 /*Parent=*/function_));
179 }
180
GetDynamicLoopBound(const int64 offset)181 llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) {
182 CHECK_GT(num_dynamic_loop_bounds_, 0);
183 CHECK_LT(offset, num_dynamic_loop_bounds_ * 2);
184 string name = absl::StrCat("dynamic_loop_bound_", offset);
185 return b_->CreateLoad(b_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_),
186 b_->getInt64(offset), name));
187 }
188
189 // Emits code to allocate an array of parameter address pointers, and store
190 // each address from 'parameter_addresses'.
191 // Returns an array of compute function call arguments (including parameter
192 // address buffer).
GetArrayFunctionCallArguments(absl::Span<llvm::Value * const> parameter_addresses,llvm::IRBuilder<> * b,absl::string_view name,llvm::Value * return_value_buffer,llvm::Value * exec_run_options_arg,llvm::Value * buffer_table_arg,llvm::Value * profile_counters_arg)193 std::vector<llvm::Value*> GetArrayFunctionCallArguments(
194 absl::Span<llvm::Value* const> parameter_addresses, llvm::IRBuilder<>* b,
195 absl::string_view name, llvm::Value* return_value_buffer,
196 llvm::Value* exec_run_options_arg, llvm::Value* buffer_table_arg,
197 llvm::Value* profile_counters_arg) {
198 llvm::Value* parameter_addresses_buffer;
199
200 if (parameter_addresses.empty()) {
201 parameter_addresses_buffer =
202 llvm::Constant::getNullValue(b->getInt8PtrTy()->getPointerTo());
203 } else {
204 parameter_addresses_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
205 b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()),
206 absl::StrCat(name, "_parameter_addresses"), b);
207
208 for (size_t i = 0; i < parameter_addresses.size(); ++i) {
209 llvm::Value* parameter_as_i8ptr = b->CreateBitCast(
210 parameter_addresses[i], b->getInt8PtrTy(),
211 absl::StrCat(name, "_parameter_", i, "_address_as_i8ptr"));
212 llvm::Value* slot_in_param_addresses =
213 b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)});
214 b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses);
215 }
216 }
217
218 const auto to_int8_ptr = [=](llvm::Value* ptr) {
219 return b->CreatePointerCast(ptr, b->getInt8PtrTy());
220 };
221 std::vector<llvm::Value*> arguments{
222 to_int8_ptr(return_value_buffer), to_int8_ptr(exec_run_options_arg),
223 parameter_addresses_buffer, buffer_table_arg};
224 if (profile_counters_arg != nullptr) {
225 arguments.push_back(profile_counters_arg);
226 }
227 return arguments;
228 }
229
230 // Emits a call to a runtime fork/join function which dispatches parallel
231 // calls to 'parallel_function' (and joins threads before returning).
EmitCallToParallelForkJoin(const std::vector<llvm::Value * > & arguments,const Shape & shape,const std::vector<int64> & dimension_partition_counts,llvm::IRBuilder<> * b,llvm::Function * parallel_function,const string & name)232 Status EmitCallToParallelForkJoin(
233 const std::vector<llvm::Value*>& arguments, const Shape& shape,
234 const std::vector<int64>& dimension_partition_counts, llvm::IRBuilder<>* b,
235 llvm::Function* parallel_function, const string& name) {
236 llvm::Module* module = b->GetInsertBlock()->getModule();
237
238 // Build ParallelForkJoin function type.
239 std::vector<llvm::Type*> compute_function_params =
240 GetComputeFunctionParams(module, /*num_dynamic_loop_bounds=*/0);
241 // Number of parallel compute functions.
242 compute_function_params.push_back(b->getInt32Ty());
243 // Array of partitions. There is an array element for each
244 // partition x partition_dim x 2 (for dimension start and limit).
245 compute_function_params.push_back(
246 llvm::Type::getInt64PtrTy(module->getContext()));
247 // Number of partitioned most-major dimensions in 'shape'.
248 compute_function_params.push_back(b->getInt32Ty());
249 // Function pointer for compute function to be dispatched in parallel.
250 compute_function_params.push_back(
251 llvm::Type::getInt8PtrTy(module->getContext()));
252
253 llvm::FunctionType* fork_join_type = llvm::FunctionType::get(
254 /*Result=*/llvm::Type::getVoidTy(module->getContext()),
255 /*Params=*/compute_function_params,
256 /*isVarArg=*/false);
257
258 llvm::Function* fork_join_func = llvm::dyn_cast<llvm::Function>(
259 module
260 ->getOrInsertFunction(runtime::kParallelForkJoinSymbolName,
261 fork_join_type)
262 .getCallee());
263 fork_join_func->setCallingConv(llvm::CallingConv::C);
264 fork_join_func->setDoesNotThrow();
265
266 // Add common compute function arguments.
267 std::vector<llvm::Value*> fork_join_arguments(arguments);
268
269 // Create ShapePartitionIterator to generate all partitions of 'shape'.
270 ShapePartitionIterator partition_iterator(shape, dimension_partition_counts);
271 const int64 num_partitions = partition_iterator.GetTotalPartitionCount();
272 // Add argument specifying the number of parallel partitions.
273 fork_join_arguments.push_back(b->getInt32(num_partitions));
274
275 // The number of partitioned most-major dimensions in 'shape'.
276 const int32 num_partitioned_dims = dimension_partition_counts.size();
277 // A dimension partition consists of two elements: [start_index, limit_index).
278 const int32 dim_partition_size = 2;
279 // Calculate array partition stride.
280 const int32 array_partition_stride =
281 num_partitioned_dims * dim_partition_size;
282 // Calculate the total number of elements in the partition array.
283 const int32 partition_array_size =
284 dim_partition_size * num_partitioned_dims * num_partitions;
285
286 // Store dimension partition values as llvm constants in 'partitions'.
287 // See comments in runtime_fork_join.cc for array layout description.
288 std::vector<llvm::Constant*> partitions(partition_array_size);
289 for (int32 i = 0; i < num_partitions; ++i) {
290 std::vector<std::pair<int64, int64>> dim_partitions =
291 partition_iterator.GetPartition(i);
292 CHECK_EQ(num_partitioned_dims, dim_partitions.size());
293 const int32 partition_index = i * array_partition_stride;
294 for (int32 j = 0; j < num_partitioned_dims; ++j) {
295 const std::pair<int64, int64>& dim_partition = dim_partitions[j];
296 const int32 index = partition_index + j * dim_partition_size;
297 // Store partition [dim_start, dim_limit) intervals for each dimension.
298 partitions[index] = b->getInt64(dim_partition.first);
299 partitions[index + 1] =
300 b->getInt64(dim_partition.first + dim_partition.second);
301 }
302 }
303
304 // Create global variable out of dimension partitions in 'partitions'.
305 llvm::ArrayType* partitions_array_type =
306 llvm::ArrayType::get(b->getInt64Ty(), partition_array_size);
307 llvm::Constant* partitions_array =
308 llvm::ConstantArray::get(partitions_array_type, partitions);
309 llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable(
310 /*M=*/*module,
311 /*Ty=*/partitions_array_type,
312 /*isConstant=*/true,
313 /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
314 /*Initializer=*/partitions_array,
315 /*Name=*/
316 absl::StrCat(name, "_parallel_dimension_partitions"));
317
318 // Add argument specifying parallel dimension partitions.
319 fork_join_arguments.push_back(
320 b->CreateBitCast(global_partitions_array,
321 llvm::Type::getInt64PtrTy(module->getContext())));
322 // Add argument specifying the number of partitioned most-major dimensions.
323 fork_join_arguments.push_back(b->getInt32(num_partitioned_dims));
324 // Add argument for parallel compute function pointer.
325 fork_join_arguments.push_back(
326 b->CreateBitCast(parallel_function, b->getInt8PtrTy()));
327 // Emit call to parallel fork/join.
328 b->CreateCall(fork_join_func, fork_join_arguments);
329
330 return Status::OK();
331 }
332
333 } // namespace cpu
334 } // namespace xla
335