• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <iterator>
17 
18 #include "tensorflow/compiler/xla/service/cpu/ir_function.h"
19 
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
22 #include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
23 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
24 #include "tensorflow/compiler/xla/status_macros.h"
25 
26 namespace xla {
27 namespace cpu {
28 
GetComputeFunctionParams(llvm::Module * llvm_module,const int64 num_dynamic_loop_bounds)29 static std::vector<llvm::Type*> GetComputeFunctionParams(
30     llvm::Module* llvm_module, const int64 num_dynamic_loop_bounds) {
31   llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(llvm_module->getContext());
32   llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo();
33   llvm::Type* i64_ptr_type =
34       llvm::Type::getInt64PtrTy(llvm_module->getContext());
35   std::vector<llvm::Type*> compute_function_params(
36       {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type});
37   if (num_dynamic_loop_bounds > 0) {
38     compute_function_params.push_back(i64_ptr_type);
39   }
40   compute_function_params.push_back(i64_ptr_type);
41   return compute_function_params;
42 }
43 
IrFunction(const string & function_name,llvm::Function::LinkageTypes linkage,const HloModuleConfig & module_config,llvm::Module * llvm_module,llvm::IRBuilder<> * b,int64 num_dynamic_loop_bounds)44 IrFunction::IrFunction(const string& function_name,
45                        llvm::Function::LinkageTypes linkage,
46                        const HloModuleConfig& module_config,
47                        llvm::Module* llvm_module, llvm::IRBuilder<>* b,
48                        int64 num_dynamic_loop_bounds)
49     : b_(b),
50       llvm_module_(llvm_module),
51       caller_insert_point_guard_(*b),
52       num_dynamic_loop_bounds_(num_dynamic_loop_bounds) {
53   Initialize(function_name, linkage, module_config);
54 }
55 
~IrFunction()56 IrFunction::~IrFunction() {
57   // Emit function return value.
58   b_->CreateRetVoid();
59 }
60 
GetDynamicLoopBounds()61 DynamicLoopBounds IrFunction::GetDynamicLoopBounds() {
62   DynamicLoopBounds dynamic_loop_bounds(num_dynamic_loop_bounds_);
63   for (int i = 0; i < num_dynamic_loop_bounds_; ++i) {
64     dynamic_loop_bounds[i].first = GetDynamicLoopBound(i * 2 + 0);
65     dynamic_loop_bounds[i].second = GetDynamicLoopBound(i * 2 + 1);
66   }
67   return dynamic_loop_bounds;
68 }
69 
Initialize(const string & function_name,llvm::Function::LinkageTypes linkage,const HloModuleConfig & module_config)70 void IrFunction::Initialize(const string& function_name,
71                             llvm::Function::LinkageTypes linkage,
72                             const HloModuleConfig& module_config) {
73   // The function signature is:
74   //   void function(i8* retval, i8* run_options, i8** params, i8**
75   //   buffer_table,
76   //                 i64* dynamic_loop_bounds, i64* prof_counters)
77   //
78   // For thread local functions:
79   //   retval: points to the returned value.
80   //   params: address of an array with pointers to parameters.
81   //   buffer_table: is null
82   //
83   // For global functions:
84   //   retval: is null
85   //   params: is null
86   //   buffer_table: address of an array with pointers to temporary buffers and
87   //     entry computation parameters (but not to constant buffers).
88   //
89   // Therefore, the generated function's signature (FunctionType) is statically
90   // determined - parameter unpacking is done in code generated into the
91   // function, rather than by a prologue dictated by the platform ABI.
92   //
93   //                      /--------------\
94   //   retval ----------> | return value |
95   //                      \--------------/
96   //
97   //                      /-------------------------------\
98   //   run_options -----> | xla::ExecutableRunOptions |
99   //                      \-------------------------------/
100   //
101   //                     /---------------------------------------------\
102   //   params -------->  |  param 0  |  param 1  | ..... |  param N-1  |
103   //                     |   addr    |   addr    |       |   addr      |
104   //                     \---------------------------------------------/
105   //                          |           |                   |
106   //                          |           |                   |
107   //                          V           V                   V
108   //                     /---------\  /---------\         /-----------\
109   //                     | param 0 |  | param 1 |         | param N-1 |
110   //                     \---------/  \---------/         \-----------/
111   //
112   //                     /---------------------------------------------\
113   //   buffer_table--->  |  buff  0  |  guff  1  | ..... |  buff  N-1  |
114   //                     |   addr    |   addr    |       |   addr      |
115   //                     \---------------------------------------------/
116   //                          |           |                   |
117   //                          |           |                   |
118   //                          V           V                   V
119   //                     /---------\  /---------\         /-----------\
120   //                     | temp  0 |  | temp  1 |         | temp  N-1 |
121   //                     \---------/  \---------/         \-----------/
122   //
123   //                        /--------------------------------------------\
124   // dynamic loop bounds -> | outer_dim0_start | outer_dim0_limit | .....|
125   //  (elided for aot)      \--------------------------------------------/
126   //
127   //                     /---------------------------------------------\
128   //   prof counters ->  | counter 0 | counter 1 | ..... | counter N-1 |
129   //                     \---------------------------------------------/
130 
131   // Even though the type of params and buffer_table is void** in the host's
132   // view, in LLVM IR this is represented by i8*, similarly to void*. It's up to
133   // the code to use GEPs to unravel the indirection layers.
134   llvm::FunctionType* function_type = llvm::FunctionType::get(
135       /*Result=*/llvm::Type::getVoidTy(llvm_module_->getContext()),
136       /*Params=*/
137       GetComputeFunctionParams(llvm_module_, num_dynamic_loop_bounds_),
138       /*isVarArg=*/false);
139 
140   // Functions with local linkage get an inlining bonus.  Because we know
141   // a-priori that embedded functions (non-entry functions) will not have its
142   // name resolved, give it local linkage.
143   function_ = llvm_ir::CreateCpuFunction(function_type, linkage, module_config,
144                                          function_name, llvm_module_);
145 
146   // Set meaningful names for the function's arguments: useful for debugging.
147   llvm::Function::arg_iterator arg_iter = function_->arg_begin();
148   arg_iter->setName("retval");
149   result_arg_ = &*arg_iter;
150   (++arg_iter)->setName("run_options");
151   exec_run_options_arg_ = &*arg_iter;
152   (++arg_iter)->setName("params");
153   parameters_arg_ = &*arg_iter;
154   (++arg_iter)->setName("buffer_table");
155   buffer_table_arg_ = &*arg_iter;
156   if (num_dynamic_loop_bounds_ > 0) {
157     (++arg_iter)->setName("dynamic_loop_bounds");
158     dynamic_loop_bounds_arg_ = &*arg_iter;
159   }
160   (++arg_iter)->setName("prof_counters");
161   profile_counters_arg_ = &*arg_iter;
162 
163   // We know a-priori that the function arguments are guaranteed to point to
164   // disjoint objects.
165   llvm::Argument* retval = result_arg();
166   for (llvm::Argument& argument : function_->args()) {
167     // However, the return buffer aliases the temporaries and thus cannot be
168     // marked noalias.
169     if (&argument == retval) {
170       continue;
171     }
172     function_->addAttribute(argument.getArgNo() + 1, llvm::Attribute::NoAlias);
173   }
174 
175   b_->SetInsertPoint(llvm::BasicBlock::Create(
176       /*Context=*/llvm_module_->getContext(),
177       /*Name=*/"entry",
178       /*Parent=*/function_));
179 }
180 
GetDynamicLoopBound(const int64 offset)181 llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) {
182   CHECK_GT(num_dynamic_loop_bounds_, 0);
183   CHECK_LT(offset, num_dynamic_loop_bounds_ * 2);
184   string name = absl::StrCat("dynamic_loop_bound_", offset);
185   return b_->CreateLoad(b_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_),
186                                       b_->getInt64(offset), name));
187 }
188 
189 // Emits code to allocate an array of parameter address pointers, and store
190 // each address from 'parameter_addresses'.
191 // Returns an array of compute function call arguments (including parameter
192 // address buffer).
GetArrayFunctionCallArguments(absl::Span<llvm::Value * const> parameter_addresses,llvm::IRBuilder<> * b,absl::string_view name,llvm::Value * return_value_buffer,llvm::Value * exec_run_options_arg,llvm::Value * buffer_table_arg,llvm::Value * profile_counters_arg)193 std::vector<llvm::Value*> GetArrayFunctionCallArguments(
194     absl::Span<llvm::Value* const> parameter_addresses, llvm::IRBuilder<>* b,
195     absl::string_view name, llvm::Value* return_value_buffer,
196     llvm::Value* exec_run_options_arg, llvm::Value* buffer_table_arg,
197     llvm::Value* profile_counters_arg) {
198   llvm::Value* parameter_addresses_buffer;
199 
200   if (parameter_addresses.empty()) {
201     parameter_addresses_buffer =
202         llvm::Constant::getNullValue(b->getInt8PtrTy()->getPointerTo());
203   } else {
204     parameter_addresses_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
205         b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()),
206         absl::StrCat(name, "_parameter_addresses"), b);
207 
208     for (size_t i = 0; i < parameter_addresses.size(); ++i) {
209       llvm::Value* parameter_as_i8ptr = b->CreateBitCast(
210           parameter_addresses[i], b->getInt8PtrTy(),
211           absl::StrCat(name, "_parameter_", i, "_address_as_i8ptr"));
212       llvm::Value* slot_in_param_addresses =
213           b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)});
214       b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses);
215     }
216   }
217 
218   const auto to_int8_ptr = [=](llvm::Value* ptr) {
219     return b->CreatePointerCast(ptr, b->getInt8PtrTy());
220   };
221   std::vector<llvm::Value*> arguments{
222       to_int8_ptr(return_value_buffer), to_int8_ptr(exec_run_options_arg),
223       parameter_addresses_buffer, buffer_table_arg};
224   if (profile_counters_arg != nullptr) {
225     arguments.push_back(profile_counters_arg);
226   }
227   return arguments;
228 }
229 
230 // Emits a call to a runtime fork/join function which dispatches parallel
231 // calls to 'parallel_function' (and joins threads before returning).
EmitCallToParallelForkJoin(const std::vector<llvm::Value * > & arguments,const Shape & shape,const std::vector<int64> & dimension_partition_counts,llvm::IRBuilder<> * b,llvm::Function * parallel_function,const string & name)232 Status EmitCallToParallelForkJoin(
233     const std::vector<llvm::Value*>& arguments, const Shape& shape,
234     const std::vector<int64>& dimension_partition_counts, llvm::IRBuilder<>* b,
235     llvm::Function* parallel_function, const string& name) {
236   llvm::Module* module = b->GetInsertBlock()->getModule();
237 
238   // Build ParallelForkJoin function type.
239   std::vector<llvm::Type*> compute_function_params =
240       GetComputeFunctionParams(module, /*num_dynamic_loop_bounds=*/0);
241   // Number of parallel compute functions.
242   compute_function_params.push_back(b->getInt32Ty());
243   // Array of partitions. There is an array element for each
244   // partition x partition_dim x 2 (for dimension start and limit).
245   compute_function_params.push_back(
246       llvm::Type::getInt64PtrTy(module->getContext()));
247   // Number of partitioned most-major dimensions in 'shape'.
248   compute_function_params.push_back(b->getInt32Ty());
249   // Function pointer for compute function to be dispatched in parallel.
250   compute_function_params.push_back(
251       llvm::Type::getInt8PtrTy(module->getContext()));
252 
253   llvm::FunctionType* fork_join_type = llvm::FunctionType::get(
254       /*Result=*/llvm::Type::getVoidTy(module->getContext()),
255       /*Params=*/compute_function_params,
256       /*isVarArg=*/false);
257 
258   llvm::Function* fork_join_func = llvm::dyn_cast<llvm::Function>(
259       module
260           ->getOrInsertFunction(runtime::kParallelForkJoinSymbolName,
261                                 fork_join_type)
262           .getCallee());
263   fork_join_func->setCallingConv(llvm::CallingConv::C);
264   fork_join_func->setDoesNotThrow();
265 
266   // Add common compute function arguments.
267   std::vector<llvm::Value*> fork_join_arguments(arguments);
268 
269   // Create ShapePartitionIterator to generate all partitions of 'shape'.
270   ShapePartitionIterator partition_iterator(shape, dimension_partition_counts);
271   const int64 num_partitions = partition_iterator.GetTotalPartitionCount();
272   // Add argument specifying the number of parallel partitions.
273   fork_join_arguments.push_back(b->getInt32(num_partitions));
274 
275   // The number of partitioned most-major dimensions in 'shape'.
276   const int32 num_partitioned_dims = dimension_partition_counts.size();
277   // A dimension partition consists of two elements: [start_index, limit_index).
278   const int32 dim_partition_size = 2;
279   // Calculate array partition stride.
280   const int32 array_partition_stride =
281       num_partitioned_dims * dim_partition_size;
282   // Calculate the total number of elements in the partition array.
283   const int32 partition_array_size =
284       dim_partition_size * num_partitioned_dims * num_partitions;
285 
286   // Store dimension partition values as llvm constants in 'partitions'.
287   // See comments in runtime_fork_join.cc for array layout description.
288   std::vector<llvm::Constant*> partitions(partition_array_size);
289   for (int32 i = 0; i < num_partitions; ++i) {
290     std::vector<std::pair<int64, int64>> dim_partitions =
291         partition_iterator.GetPartition(i);
292     CHECK_EQ(num_partitioned_dims, dim_partitions.size());
293     const int32 partition_index = i * array_partition_stride;
294     for (int32 j = 0; j < num_partitioned_dims; ++j) {
295       const std::pair<int64, int64>& dim_partition = dim_partitions[j];
296       const int32 index = partition_index + j * dim_partition_size;
297       // Store partition [dim_start, dim_limit) intervals for each dimension.
298       partitions[index] = b->getInt64(dim_partition.first);
299       partitions[index + 1] =
300           b->getInt64(dim_partition.first + dim_partition.second);
301     }
302   }
303 
304   // Create global variable out of dimension partitions in 'partitions'.
305   llvm::ArrayType* partitions_array_type =
306       llvm::ArrayType::get(b->getInt64Ty(), partition_array_size);
307   llvm::Constant* partitions_array =
308       llvm::ConstantArray::get(partitions_array_type, partitions);
309   llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable(
310       /*M=*/*module,
311       /*Ty=*/partitions_array_type,
312       /*isConstant=*/true,
313       /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
314       /*Initializer=*/partitions_array,
315       /*Name=*/
316       absl::StrCat(name, "_parallel_dimension_partitions"));
317 
318   // Add argument specifying parallel dimension partitions.
319   fork_join_arguments.push_back(
320       b->CreateBitCast(global_partitions_array,
321                        llvm::Type::getInt64PtrTy(module->getContext())));
322   // Add argument specifying the number of partitioned most-major dimensions.
323   fork_join_arguments.push_back(b->getInt32(num_partitioned_dims));
324   // Add argument for parallel compute function pointer.
325   fork_join_arguments.push_back(
326       b->CreateBitCast(parallel_function, b->getInt8PtrTy()));
327   // Emit call to parallel fork/join.
328   b->CreateCall(fork_join_func, fork_join_arguments);
329 
330   return Status::OK();
331 }
332 
333 }  // namespace cpu
334 }  // namespace xla
335