1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ 17 #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ 18 19 #include <cassert> 20 #include <string> 21 22 #include "tensorflow/compiler/xla/cpu_function_runtime.h" 23 #include "tensorflow/compiler/xla/executable_run_options.h" 24 #include "tensorflow/core/platform/types.h" 25 26 // Forward-declare, rather than include, to reduce code size for users that 27 // never use this functionality. 28 namespace xla { 29 class ProgramShapeProto; 30 class HloProfilePrinterData; 31 } 32 33 namespace tensorflow { 34 35 // Represents a function compiled by XLA, produced via either JIT or AOT. 36 // 37 // The Run method invokes the actual computation, with inputs read from arg 38 // buffers, and outputs written to result buffers. Each Run call may also use a 39 // set of temporary buffers for the computation. 40 // 41 // By default each instance of this class manages its own arg, result and temp 42 // buffers. The AllocMode constructor parameter may be used to modify the buffer 43 // allocation strategy. 44 // 45 // Under the default allocation strategy, this class is thread-compatible: 46 // o Calls to non-const methods require exclusive access to the object. 47 // o Concurrent calls to const methods are OK, if those calls are made while it 48 // is guaranteed that no thread may call a non-const method. 49 class XlaCompiledCpuFunction { 50 public: 51 // Type of the raw function, produced by either JIT or AOT. 52 using RawFunction = void (*)(void* result, 53 const xla::ExecutableRunOptions* run_options, 54 const void** args, void** temps, 55 int64* profile_counters); 56 57 // StaticData represents the state necessary to run an XLA-compiled 58 // function. For JIT this is backed by data in XlaJitCompiledCpuFunction; for 59 // AOT this is backed by data compiled into the object file. 60 // 61 // The contents of StaticData are XLA-internal implementation details and 62 // should not be relied on by clients (and therefore are private). 63 class StaticData { 64 private: 65 // The raw function to call. 66 RawFunction raw_function_; 67 68 // Contains information about the buffers used by the XLA computation. 69 const xla::cpu_function_runtime::BufferInfo* buffer_infos_ = nullptr; 70 size_t num_buffers_ = 0; 71 72 // Entry parameter i is described by 73 // buffer_infos[arg_index_table[i]]. 74 const int32* arg_index_table_ = nullptr; 75 76 // There are num_args entry parameters. 77 int64 num_args_ = 0; 78 79 // There are num_variables variables. 80 int64 num_variables_ = 0; 81 82 // The 0-based index of the result tuple, in the temp buffers. 83 size_t result_index_ = 0; 84 85 // [Optional] Arrays of arg and result names. These are arrays of C-style 86 // strings, where the array is terminated by nullptr. 87 const char** arg_names_ = nullptr; 88 const char** variable_names_ = nullptr; 89 const char** result_names_ = nullptr; 90 91 // [Optional] Arg and result shapes. 92 const xla::ProgramShapeProto* program_shape_ = nullptr; 93 94 // [Optional] Profile printer data. Null if profiling is disabled. 95 const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr; 96 97 // [Optional] The number of profile counters expected in the profile counter 98 // buffer by the generated code and hlo_profile_printer. 0 if profiling is 99 // disabled. This information is already present in 100 // hlo_profile_printer_data but xla::HloProfilePrinterData is forward 101 // declared so we don't have access to that information here. 102 int64 profile_counters_size_ = 0; 103 104 // Only XlaCompiledCpuFunction is allowed to read and write the above 105 // fields. 106 friend class XlaCompiledCpuFunction; 107 }; 108 109 // AllocMode controls the buffer allocation mode. 110 enum class AllocMode { 111 // Allocate all buffers - args, results, profile and temps. 112 ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS, 113 114 // Only allocate result, profile and temp buffers. 115 // Use set_arg_data to set argument buffers before Run is called. 116 RESULTS_PROFILES_AND_TEMPS_ONLY, 117 }; 118 119 explicit XlaCompiledCpuFunction( 120 const StaticData& static_data, 121 AllocMode alloc_mode = 122 AllocMode::ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS); 123 virtual ~XlaCompiledCpuFunction(); 124 125 XlaCompiledCpuFunction(const XlaCompiledCpuFunction&) = delete; 126 XlaCompiledCpuFunction& operator=(const XlaCompiledCpuFunction&) = delete; 127 128 // Sets the intra-op thread pool used to run individual ops concurrently. set_thread_pool(const Eigen::ThreadPoolDevice * pool)129 void set_thread_pool(const Eigen::ThreadPoolDevice* pool) { 130 run_options_.set_intra_op_thread_pool(pool); 131 } 132 133 // Runs the computation, with inputs read from arg buffers, and outputs 134 // written to result buffers. Returns true on success and false on failure. 135 bool Run(); 136 137 // Returns the error message from the previous failed Run call. 138 // 139 // TODO(fschneider): For now this always returns an empty string because there 140 // is no support for error reporting in XLA. Remove this once all callers are 141 // updated. error_msg()142 string error_msg() const { return {}; } 143 144 // ------------------------------ 145 // Arg methods for managing input buffers. Buffers are in row-major order. 146 147 // Returns the buffer for the positional argument at the given `index`. arg_data(size_t index)148 void* arg_data(size_t index) { 149 return buffer_table_[arg_index_table_[index]]; 150 } arg_data(size_t index)151 const void* arg_data(size_t index) const { 152 return buffer_table_[arg_index_table_[index]]; 153 } 154 num_args()155 int num_args() const { return num_args_; } 156 num_variables()157 int num_variables() const { return num_variables_; } 158 159 // Returns the size of entry parameter `idx`. 160 // 161 // There is a static version of this method on tfcompile generated subclasses 162 // of XlaCompiledCpuFunction, but try to prefer this when possible since it 163 // works both for XlaJitCompiledCpuFunction and AOT compiled subclasses. arg_size(int idx)164 int arg_size(int idx) const { 165 assert(idx < num_args()); 166 return buffer_infos_[arg_index_table_[idx]].size(); 167 } 168 169 // Sets the buffer for the positional argument at the given `index` to `data`. 170 // Must be called before Run to have an effect. May be called under any 171 // AllocMode; if the AllocMode is RESULTS_AND_TEMPS_ONLY, this method must be 172 // called for each positional argument, in order to set the argument buffers. 173 // 174 // Allocated memory must be aligned to the size specified by 175 // xla::cpu_function_runtime::kMinAlign. If possible, use the functions in 176 // tensorflow/compiler/tf2xla/cpu_function_runtime.h to ensure correct 177 // alignment. 178 // 179 // Aliasing of argument and result buffers is not allowed, and results in 180 // undefined behavior. set_arg_data(size_t index,const void * data)181 void set_arg_data(size_t index, const void* data) { 182 assert((arg_size(index) < xla::cpu_function_runtime::kMinAlign || 183 (uintptr_t)data % xla::cpu_function_runtime::kMinAlign == 0) && 184 "Underaligned pointer!"); 185 // The const_cast is safe because the generated code does not write to arg 186 // buffers. 187 // 188 // buffer_table_ contains pointers to buffers that _will_ be written to by 189 // generated code so it would be misleading to make buffer_table_ a `const 190 // void**`. 191 buffer_table_[arg_index_table_[index]] = const_cast<void*>(data); 192 } 193 194 // ------------------------------ 195 // Result methods for managing output buffers. Buffers are in row-major order. 196 // Must only be called after a successful Run call. Unlike the arg methods, 197 // there is no set_resultN_data method. The result buffers are managed 198 // internally, and may change after each call to Run. 199 200 // Returns the underlying array of result buffers, where results()[I] is the 201 // buffer for the positional result at index I. results()202 void** results() { return static_cast<void**>(buffer_table_[result_index_]); } results()203 const void* const* results() const { 204 return static_cast<const void* const*>(buffer_table_[result_index_]); 205 } 206 207 // Profile counters for this XLA computation. 208 // 209 // When Hlo profiling is enabled (`hlo_profiling_enabled()` return true in 210 // this case) these counters are non-null and are automatically populated by 211 // `Run`. The counters can then be pretty-printed using 212 // `hlo_profile_printer()`. 213 // 214 // When Hlo profiling is disabled, this accessor returns null. profile_counters()215 const int64* profile_counters() const { return profile_counters_; } 216 217 // Returns the buffer for the positional result at the given `index`. result_data(size_t index)218 void* result_data(size_t index) { return results()[index]; } result_data(size_t index)219 const void* result_data(size_t index) const { return results()[index]; } 220 221 // ------------------------------ 222 // Methods for extracting optional metadata. 223 224 // Returns true iff data is available for the Lookup{Arg,Variable,Result}Index 225 // methods. E.g. the data might not be compiled into the binary for AOT. HasNameIndices()226 bool HasNameIndices() const { 227 return arg_names_ != nullptr && variable_names_ != nullptr && 228 result_names_ != nullptr; 229 } 230 231 // Returns the 0-based index for the argument with the given `name`. 232 // Returns -1 if the name wasn't found, or data isn't available. 233 // 234 // The index remains constant for every instance of XlaCompiledCpuFunction 235 // generated from the same static data, and might not be cheap to determine. 236 // Recommended usage is to capture this in a variable for re-use. 237 int LookupArgIndex(const string& name) const; 238 239 // Returns the 0-based index for the variable with the given `name`. 240 // Returns -1 if the name wasn't found, or data isn't available. 241 // 242 // The index remains constant for every instance of XlaCompiledCpuFunction 243 // generated from the same static data, and might not be cheap to determine. 244 // Recommended usage is to capture this in a variable for re-use. 245 int LookupVariableIndex(const string& name) const; 246 247 // Returns the 0-based index for the result with the given `name`. 248 // Returns -1 if the name wasn't found, or data isn't available. 249 // 250 // The index remains constant for every instance of XlaCompiledCpuFunction 251 // generated from the same static data, and might not be cheap to determine. 252 // Recommended usage is to capture this in a variable for re-use. 253 int LookupResultIndex(const string& name) const; 254 255 // Returns the shape of the args and results. May return nullptr if the 256 // program shape isn't available. ProgramShape()257 const xla::ProgramShapeProto* ProgramShape() const { return program_shape_; } 258 hlo_profiling_enabled()259 bool hlo_profiling_enabled() const { 260 return hlo_profile_printer_data_ != nullptr; 261 } hlo_profile_printer_data()262 const xla::HloProfilePrinterData& hlo_profile_printer_data() const { 263 assert(hlo_profiling_enabled()); 264 return *hlo_profile_printer_data_; 265 } 266 267 protected: 268 // --------------------------------------------------------------------------- 269 // Accessors for reading from and writing to instances of `StaticData`. 270 // 271 // Classes generated by tfcompile can call these because the generated classes 272 // inherit from `XlaCompiledCpuFunction`. `XlaJitCompiledCpuFunction` can 273 // call these because it is explicitly added as a friend. 274 set_static_data_raw_function(StaticData * static_data,RawFunction raw_function)275 static void set_static_data_raw_function(StaticData* static_data, 276 RawFunction raw_function) { 277 static_data->raw_function_ = raw_function; 278 } 279 set_static_data_buffer_infos(StaticData * static_data,const xla::cpu_function_runtime::BufferInfo * buffer_infos)280 static void set_static_data_buffer_infos( 281 StaticData* static_data, 282 const xla::cpu_function_runtime::BufferInfo* buffer_infos) { 283 static_data->buffer_infos_ = buffer_infos; 284 } 285 set_static_data_num_buffers(StaticData * static_data,size_t num_buffers)286 static void set_static_data_num_buffers(StaticData* static_data, 287 size_t num_buffers) { 288 static_data->num_buffers_ = num_buffers; 289 } 290 set_static_data_arg_index_table(StaticData * static_data,const int32 * arg_index_table)291 static void set_static_data_arg_index_table(StaticData* static_data, 292 const int32* arg_index_table) { 293 static_data->arg_index_table_ = arg_index_table; 294 } 295 set_static_data_num_args(StaticData * static_data,int64 num_args)296 static void set_static_data_num_args(StaticData* static_data, 297 int64 num_args) { 298 static_data->num_args_ = num_args; 299 } 300 set_static_data_num_variables(StaticData * static_data,int64 num_variables)301 static void set_static_data_num_variables(StaticData* static_data, 302 int64 num_variables) { 303 static_data->num_variables_ = num_variables; 304 } 305 set_static_data_result_index(StaticData * static_data,size_t result_index)306 static void set_static_data_result_index(StaticData* static_data, 307 size_t result_index) { 308 static_data->result_index_ = result_index; 309 } 310 set_static_data_arg_names(StaticData * static_data,const char ** arg_names)311 static void set_static_data_arg_names(StaticData* static_data, 312 const char** arg_names) { 313 static_data->arg_names_ = arg_names; 314 } 315 set_static_data_variable_names(StaticData * static_data,const char ** variable_names)316 static void set_static_data_variable_names(StaticData* static_data, 317 const char** variable_names) { 318 static_data->variable_names_ = variable_names; 319 } 320 set_static_data_result_names(StaticData * static_data,const char ** result_names)321 static void set_static_data_result_names(StaticData* static_data, 322 const char** result_names) { 323 static_data->result_names_ = result_names; 324 } 325 set_static_data_program_shape(StaticData * static_data,const xla::ProgramShapeProto * program_shape)326 static void set_static_data_program_shape( 327 StaticData* static_data, const xla::ProgramShapeProto* program_shape) { 328 static_data->program_shape_ = program_shape; 329 } 330 set_static_data_hlo_profile_printer_data(StaticData * static_data,const xla::HloProfilePrinterData * hlo_profile_printer_data)331 static void set_static_data_hlo_profile_printer_data( 332 StaticData* static_data, 333 const xla::HloProfilePrinterData* hlo_profile_printer_data) { 334 static_data->hlo_profile_printer_data_ = hlo_profile_printer_data; 335 } 336 337 static const xla::HloProfilePrinterData* get_static_data_hlo_profile_printer_data(StaticData * static_data)338 get_static_data_hlo_profile_printer_data(StaticData* static_data) { 339 return static_data->hlo_profile_printer_data_; 340 } 341 set_static_data_profile_counters_size(StaticData * static_data,int64 profile_counters_size)342 static void set_static_data_profile_counters_size( 343 StaticData* static_data, int64 profile_counters_size) { 344 static_data->profile_counters_size_ = profile_counters_size; 345 } 346 347 private: 348 const RawFunction raw_function_; 349 const size_t result_index_; 350 351 // Array containing pointers to argument and temp buffers (slots corresponding 352 // to constant and on-stack buffers are null). 353 void** const buffer_table_; 354 355 // Describes the buffers used by the XLA computation. 356 const xla::cpu_function_runtime::BufferInfo* const buffer_infos_; 357 358 // Argument i needs to be placed in buffer_table_[arg_index_to_temp_index_[i]] 359 // for XLA generated code to be able to find it. 360 const int32* const arg_index_table_; 361 362 // The number of incoming arguments. 363 const int32 num_args_; 364 365 // The number of incoming variables. 366 const int32 num_variables_; 367 368 // Backing memory for buffer_table_ and args_, the latter depending on 369 // AllocMode. 370 void* alloc_buffer_table_ = nullptr; 371 372 // Backing memory for profiling counters. 373 int64* profile_counters_ = nullptr; 374 375 // Options and context passed to the compiled function. 376 xla::ExecutableRunOptions run_options_; 377 378 // Optional metadata. 379 const char** arg_names_ = nullptr; 380 const char** variable_names_ = nullptr; 381 const char** result_names_ = nullptr; 382 const xla::ProgramShapeProto* program_shape_ = nullptr; 383 const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr; 384 385 // Add `XlaJitCompiledCpuFunction` as a friend so that it can access the 386 // `set_static_data_*` static methods above. 387 friend class XlaJitCompiledCpuFunction; 388 }; 389 390 } // namespace tensorflow 391 392 #endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ 393