1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_H_ 16 #define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_H_ 17 #include <stddef.h> 18 #include <stdint.h> 19 20 #include "tensorflow/c/c_api_macros.h" 21 #include "tensorflow/c/tf_status.h" 22 23 // -------------------------------------------------------------------------- 24 // C API for StreamExecutor. The API is under active development and eventually 25 // should allow registering a pluggable device with TensorFlow. 26 // 27 // Conventions: 28 // * Struct prefix indicates whether struct fields should be filled by the 29 // plugin or core implementation: 30 // * SE_ : set/filled by core unless explicitly marked otherwise. 31 // * SP_ : set/filled by plugin unless explicitly marked otherwise. 32 // * We use `struct_size` for version checking. It is exempt from the `SE/SP` 33 // rule above and should be set both by core and the plugin. 34 // * For example, `create_device` function receives `SP_Device*` as input 35 // with `struct_size` populated by core. The plugin is responsible for 36 // setting `struct_size` as well, along with all other fields. 37 // * Refer to "TensorFlow Versioning Strategy" section at 38 // https://github.com/tensorflow/community/pull/257/files. 39 // * Note that the API is still under active development and doesn't have 40 // versioning guarantees yet. 41 // * `void* ext` is a free-form field that can be populated by 42 // a plugin in `SP_*` structs or potential future extension points in `SE_` 43 // structs. 44 // 45 // Example usage: 46 // 47 // /* Sample TensorFlow code below, exact implementation might differ. */ 48 // // Version checking uses `struct_size`. It is exempt from the `SE/SP` rule 49 // // above and should be set both by core and the plugin." 50 // SP_Device device { SP_DEVICE_STRUCT_SIZE }; 51 // SE_CreateDeviceParams params { SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE } ; 52 // params.device = &device; 53 // 54 // /* Plugin code below */ 55 // constexpr char DEVICE_NAME[] = "MY_DEVICE"; 56 // constexpr char DEVICE_TYPE[] = "GPU"; 57 // 58 // void create_device(const SP_Platform* platform, 59 // SE_CreateDeviceParams* params, TF_Status* status) { 60 // // Custom actions based on TensorFlow's view of SP_Device. 61 // OnTFDeviceView(params->device->struct_size); 62 // params->device = { SP_DEVICE_STRUCT_SIZE }; 63 // params->device->device_handle = get_my_device_handle(device->ordinal); 64 // params->device->ordinal = params->ordinal; 65 // ... 66 // } 67 // 68 // void destroy_device(const SP_Platform* platform, SP_Device* device) { 69 // delete_my_device_handle(device->device_handle); 70 // } 71 // 72 // void SE_InitPlugin( 73 // SE_PlatformRegistrationParams* params, 74 // TF_Status* status) { 75 // params->platform = { SP_PLATFORM_STRUCT_SIZE }; 76 // // Values such as `name` and `type` must outlive SE_InitPlugin call. 77 // params->platform->name = DEVICE_NAME; 78 // params->platform->type = DEVICE_TYPE; 79 // params->platform->visible_device_count = 2; 80 // params->platform_fns->create_device = create_device; 81 // params->platform_fns->destroy_device = destroy_device; 82 // ... 83 // } 84 85 #define SE_MAJOR 0 86 #define SE_MINOR 0 87 #define SE_PATCH 1 88 89 #ifdef __cplusplus 90 extern "C" { 91 #endif 92 93 typedef struct SP_Stream_st* SP_Stream; 94 typedef struct SP_Event_st* SP_Event; 95 typedef struct SP_Timer_st* SP_Timer; 96 // Takes `callback_arg` passed to `host_callback` as the first argument. 97 typedef void (*SE_StatusCallbackFn)(void* const, TF_Status* const); 98 99 typedef struct SP_TimerFns { 100 size_t struct_size; 101 void* ext; // reserved for future use 102 uint64_t (*nanoseconds)(SP_Timer timer); 103 } SP_TimerFns; 104 105 #define SP_TIMER_FNS_STRUCT_SIZE TF_OFFSET_OF_END(SP_TimerFns, nanoseconds) 106 107 typedef struct SP_AllocatorStats { 108 size_t struct_size; 109 int64_t num_allocs; 110 int64_t bytes_in_use; 111 int64_t peak_bytes_in_use; 112 int64_t largest_alloc_size; 113 114 int8_t has_bytes_limit; 115 int64_t bytes_limit; 116 117 int64_t bytes_reserved; 118 int64_t peak_bytes_reserved; 119 120 int8_t has_bytes_reservable_limit; 121 int64_t bytes_reservable_limit; 122 123 int64_t largest_free_block_bytes; 124 } SP_AllocatorStats; 125 126 #define SP_ALLOCATORSTATS_STRUCT_SIZE \ 127 TF_OFFSET_OF_END(SP_AllocatorStats, largest_free_block_bytes) 128 129 // Potential states for an SP_Event. If `poll_for_status` returns anything aside 130 // from kPending or kComplete, an error has occurred; kUnknown is a bad state. 131 typedef enum SE_EventStatus { 132 SE_EVENT_UNKNOWN, 133 SE_EVENT_ERROR, 134 SE_EVENT_PENDING, 135 SE_EVENT_COMPLETE, 136 } SE_EventStatus; 137 138 // Memory allocation information. 139 // This matches DeviceMemoryBase defined here: 140 // https://cs.opensource.google/tensorflow/tensorflow/+/refs/tags/v2.3.0:tensorflow/stream_executor/device_memory.h;l=57 141 typedef struct SP_DeviceMemoryBase { 142 size_t struct_size; 143 void* ext; // Reserved for future use 144 // Platform-dependent value representing allocated memory. 145 // Note that the pointer does not have to be to the virtual address itself. 146 void* opaque; 147 uint64_t size; // Size in bytes of this allocation. 148 uint64_t payload; // Value for plugin's use 149 } SP_DeviceMemoryBase; 150 151 #define SP_DEVICE_MEMORY_BASE_STRUCT_SIZE \ 152 TF_OFFSET_OF_END(SP_DeviceMemoryBase, payload) 153 154 typedef struct SP_Device { 155 size_t struct_size; 156 void* ext; // free-form data set by plugin 157 int32_t ordinal; // device index 158 159 // Device vendor can store handle to their device representation 160 // here. 161 void* device_handle; 162 163 // [Optional] 164 // Device hardware name. Used for printing. 165 // Must be null-terminated. 166 const char* hardware_name; 167 168 // [Optional] 169 // Device vendor name. Used for printing. 170 // Must be null-terminated. 171 const char* device_vendor; 172 173 // [Optional] 174 // Returns the PCI bus identifier for this device, of the form 175 // [domain]:[bus]:[device].[function] 176 // where domain number is usually 0000. 177 // Example: 0000:00:02.1 178 // For more information see: 179 // https://en.wikipedia.org/wiki/PCI_configuration_space 180 // https://www.oreilly.com/library/view/linux-device-drivers/0596005903/ch12.html 181 // Used for printing. Must be null-terminated. 182 const char* pci_bus_id; 183 } SP_Device; 184 185 #define SP_DEVICE_STRUCT_SIZE TF_OFFSET_OF_END(SP_Device, pci_bus_id) 186 187 typedef struct SE_CreateDeviceParams { 188 size_t struct_size; 189 void* ext; // reserved for future use 190 int32_t ordinal; // device index 191 192 SP_Device* device; // Input/output, struct_size set by TF for plugin to read. 193 // Subsequently plugin fills the entire struct. 194 } SE_CreateDeviceParams; 195 196 #define SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE \ 197 TF_OFFSET_OF_END(SE_CreateDeviceParams, device) 198 199 typedef struct SP_DeviceFns { 200 size_t struct_size; 201 void* ext; // reserved for future use 202 203 // [Optional] 204 // Returns the NUMA node associated with this device, for use in 205 // determining socket locality. If the NUMA node could not be determined, -1 206 // is returned. 207 // Negative values are treated as "unset". 208 int32_t (*get_numa_node)(const SP_Device* device); 209 210 // [Optional] 211 // Device's memory bandwidth in bytes/sec. (This is for reads/writes to/from 212 // the device's own memory, not for transfers between the host and device.) 213 // Negative values are treated as "unset". 214 int64_t (*get_memory_bandwidth)(const SP_Device* device); 215 216 // [Optional] 217 // Estimate of average number of floating point operations per second for 218 // this device * 10e-9. 219 // Negative values are treated as "unset". 220 double (*get_gflops)(const SP_Device* device); 221 } SP_DeviceFns; 222 223 #define SP_DEVICE_FNS_STRUCT_SIZE TF_OFFSET_OF_END(SP_DeviceFns, get_gflops) 224 225 typedef struct SE_CreateDeviceFnsParams { 226 size_t struct_size; 227 void* ext; // reserved for future use 228 229 SP_DeviceFns* device_fns; // output, to be filled by plugin 230 } SE_CreateDeviceFnsParams; 231 232 #define SE_CREATE_DEVICE_FNS_PARAMS_STRUCT_SIZE \ 233 TF_OFFSET_OF_END(SE_CreateDeviceFnsParams, device_fns) 234 235 typedef struct SP_StreamExecutor { 236 size_t struct_size; 237 void* ext; // reserved for future use 238 239 /*** ALLOCATION CALLBACKS ***/ 240 // Synchronously allocates `size` bytes on the underlying platform and returns 241 // `SP_DeviceMemoryBase` representing that allocation. In the case of failure, 242 // nullptr is returned. 243 // `memory_space` is reserved for a potential future usage and should be set 244 // to 0. 245 void (*allocate)(const SP_Device* device, uint64_t size, int64_t memory_space, 246 SP_DeviceMemoryBase* mem); 247 248 // Deallocate the device memory previously allocated via this interface. 249 // Deallocation of a nullptr-representative value is permitted. 250 void (*deallocate)(const SP_Device* device, SP_DeviceMemoryBase* memory); 251 252 // Allocates a region of host memory and registers it with the platform API. 253 // Memory allocated in this manner is required for use in asynchronous memcpy 254 // operations, such as `memcpy_dtoh`. 255 void* (*host_memory_allocate)(const SP_Device* device, uint64_t size); 256 257 // Deallocates a region of host memory allocated by `host_memory_allocate`. 258 void (*host_memory_deallocate)(const SP_Device* device, void* mem); 259 260 // Allocates unified memory space of the given size, if supported. Unified 261 // memory support should be added by setting `supports_unified_memory` field 262 // in `SP_Platform`. 263 void* (*unified_memory_allocate)(const SP_Device* device, uint64_t bytes); 264 265 // Deallocates unified memory space previously allocated with 266 // `unified_memory_allocate`. Unified 267 // memory support should be added by setting `supports_unified_memory` field 268 // in `SP_Platform`. 269 void (*unified_memory_deallocate)(const SP_Device* device, void* location); 270 271 // Fills SP_AllocatorStats with allocator statistics, if it is available. 272 // If it is not available, return false. 273 TF_Bool (*get_allocator_stats)(const SP_Device* device, 274 SP_AllocatorStats* stats); 275 // Fills the underlying device memory usage information, if it is 276 // available. If it is not available (false is returned), free/total need not 277 // be initialized. 278 TF_Bool (*device_memory_usage)(const SP_Device* device, int64_t* free, 279 int64_t* total); 280 281 /*** STREAM CALLBACKS ***/ 282 // Creates SP_Stream. This call should also allocate stream 283 // resources on the underlying platform and initializes its 284 // internals. 285 void (*create_stream)(const SP_Device* device, SP_Stream* stream, 286 TF_Status* status); 287 288 // Destroys SP_Stream and deallocates any underlying resources. 289 void (*destroy_stream)(const SP_Device* device, SP_Stream stream); 290 291 // Causes `dependent` to not begin execution until `other` has finished its 292 // last-enqueued work. 293 void (*create_stream_dependency)(const SP_Device* device, SP_Stream dependent, 294 SP_Stream other, TF_Status* status); 295 296 // Without blocking the device, retrieve the current stream status. 297 void (*get_stream_status)(const SP_Device* device, SP_Stream stream, 298 TF_Status* status); 299 300 /*** EVENT CALLBACKS ***/ 301 // Create SP_Event. Performs platform-specific allocation and initialization 302 // of an event. 303 void (*create_event)(const SP_Device* device, SP_Event* event, 304 TF_Status* status); 305 306 // Destroy SE_Event and perform any platform-specific deallocation and 307 // cleanup of an event. 308 void (*destroy_event)(const SP_Device* device, SP_Event event); 309 310 // Requests the current status of the event from the underlying platform. 311 SE_EventStatus (*get_event_status)(const SP_Device* device, SP_Event event); 312 // Inserts the specified event at the end of the specified stream. 313 void (*record_event)(const SP_Device* device, SP_Stream stream, 314 SP_Event event, TF_Status* status); 315 316 // Wait for the specified event at the end of the specified stream. 317 void (*wait_for_event)(const SP_Device* const device, SP_Stream stream, 318 SP_Event event, TF_Status* const status); 319 320 /*** TIMER CALLBACKS ***/ 321 // Creates SP_Timer. Allocates timer resources on the underlying platform 322 // and initializes its internals, setting `timer` output variable. Sets 323 // values in `timer_fns` struct. 324 void (*create_timer)(const SP_Device* device, SP_Timer* timer, 325 TF_Status* status); 326 327 // Destroy timer and deallocates timer resources on the underlying platform. 328 void (*destroy_timer)(const SP_Device* device, SP_Timer timer); 329 330 // Records a start event for an interval timer. 331 void (*start_timer)(const SP_Device* device, SP_Stream stream, SP_Timer timer, 332 TF_Status* status); 333 334 // Records a stop event for an interval timer. 335 void (*stop_timer)(const SP_Device* device, SP_Stream stream, SP_Timer timer, 336 TF_Status* status); 337 338 /*** MEMCPY CALLBACKS ***/ 339 // Enqueues a memcpy operation onto stream, with a host destination location 340 // `host_dst` and a device memory source, with target size `size`. 341 void (*memcpy_dtoh)(const SP_Device* device, SP_Stream stream, void* host_dst, 342 const SP_DeviceMemoryBase* device_src, uint64_t size, 343 TF_Status* status); 344 345 // Enqueues a memcpy operation onto stream, with a device destination 346 // location and a host memory source, with target size `size`. 347 void (*memcpy_htod)(const SP_Device* device, SP_Stream stream, 348 SP_DeviceMemoryBase* device_dst, const void* host_src, 349 uint64_t size, TF_Status* status); 350 351 // Enqueues a memcpy operation onto stream, with a device destination 352 // location and a device memory source, with target size `size`. 353 void (*memcpy_dtod)(const SP_Device* device, SP_Stream stream, 354 SP_DeviceMemoryBase* device_dst, 355 const SP_DeviceMemoryBase* device_src, uint64_t size, 356 TF_Status* status); 357 358 // Blocks the caller while a data segment of the given size is 359 // copied from the device source to the host destination. 360 void (*sync_memcpy_dtoh)(const SP_Device* device, void* host_dst, 361 const SP_DeviceMemoryBase* device_src, uint64_t size, 362 TF_Status* status); 363 364 // Blocks the caller while a data segment of the given size is 365 // copied from the host source to the device destination. 366 void (*sync_memcpy_htod)(const SP_Device* device, 367 SP_DeviceMemoryBase* device_dst, 368 const void* host_src, uint64_t size, 369 TF_Status* status); 370 371 // Blocks the caller while a data segment of the given size is copied from the 372 // device source to the device destination. 373 void (*sync_memcpy_dtod)(const SP_Device* device, 374 SP_DeviceMemoryBase* device_dst, 375 const SP_DeviceMemoryBase* device_src, uint64_t size, 376 TF_Status* status); 377 378 // Causes the host code to synchronously wait for the event to complete. 379 void (*block_host_for_event)(const SP_Device* device, SP_Event event, 380 TF_Status* status); 381 382 // [Optional] 383 // Causes the host code to synchronously wait for operations entrained onto 384 // stream to complete. Effectively a join on the asynchronous device 385 // operations enqueued on the stream before this program point. 386 // If not set, then corresponding functionality will be implemented 387 // by registering an event on the `stream` and waiting for it using 388 // `block_host_for_event`. 389 void (*block_host_until_done)(const SP_Device* device, SP_Stream stream, 390 TF_Status* status); 391 392 // Synchronizes all activity occurring in the StreamExecutor's context (most 393 // likely a whole device). 394 void (*synchronize_all_activity)(const SP_Device* device, TF_Status* status); 395 396 // Enqueues on a stream a user-specified function to be run on the host. 397 // `callback_arg` should be passed as the first argument to `callback_fn`. 398 TF_Bool (*host_callback)(const SP_Device* device, SP_Stream stream, 399 SE_StatusCallbackFn callback_fn, void* callback_arg); 400 } SP_StreamExecutor; 401 402 #define SP_STREAMEXECUTOR_STRUCT_SIZE \ 403 TF_OFFSET_OF_END(SP_StreamExecutor, host_callback) 404 405 typedef struct SE_CreateStreamExecutorParams { 406 size_t struct_size; 407 void* ext; // reserved for future use 408 409 SP_StreamExecutor* stream_executor; // output, to be filled by plugin 410 } SE_CreateStreamExecutorParams; 411 412 #define SE_CREATE_STREAM_EXECUTOR_PARAMS_STRUCT_SIZE \ 413 TF_OFFSET_OF_END(SE_CreateStreamExecutorParams, stream_executor) 414 415 typedef struct SP_Platform { 416 size_t struct_size; 417 418 void* ext; // free-form data set by plugin 419 420 // Platform name (also referred to as subtype), for example MY_DEVICE. 421 // The name must start with a capital letter and consist of 422 // capital letters and underscores. 423 // Must be null-terminated. 424 const char* name; 425 426 // Device type name, for example GPU. Must be null-terminated. 427 // The name must start with a capital letter and consist of 428 // capital letters and underscores. 429 const char* type; 430 431 // Number of visible devices 432 size_t visible_device_count; 433 434 // Whether this platform supports unified memory. 435 // Unified memory is a single memory address space accessible from any device. 436 TF_Bool supports_unified_memory; 437 } SP_Platform; 438 439 #define SP_PLATFORM_STRUCT_SIZE \ 440 TF_OFFSET_OF_END(SP_Platform, supports_unified_memory) 441 442 typedef struct SP_PlatformFns { 443 size_t struct_size; 444 445 void* ext; // reserved for future use 446 447 // Callbacks for creating/destroying SP_Device. 448 void (*create_device)(const SP_Platform* platform, 449 SE_CreateDeviceParams* params, TF_Status* status); 450 451 // Clean up fields inside SP_Device that were allocated 452 // by the plugin. `device` itself should not be deleted here. 453 void (*destroy_device)(const SP_Platform* platform, SP_Device* device); 454 455 // Callbacks for creating/destroying SP_DeviceFns. 456 void (*create_device_fns)(const SP_Platform* platform, 457 SE_CreateDeviceFnsParams* params, 458 TF_Status* status); 459 460 // Clean up fields inside SP_DeviceFns that were allocated 461 // by the plugin. `device_fns` itself should not be deleted here. 462 void (*destroy_device_fns)(const SP_Platform* platform, 463 SP_DeviceFns* device_fns); 464 465 // Callbacks for creating/destroying SP_StreamExecutor. 466 void (*create_stream_executor)(const SP_Platform* platform, 467 SE_CreateStreamExecutorParams* params, 468 TF_Status* status); 469 // Clean up fields inside SP_StreamExecutor that were allocated 470 // by the plugin. `stream_executor` itself should not be deleted here. 471 void (*destroy_stream_executor)(const SP_Platform* platform, 472 SP_StreamExecutor* stream_executor); 473 474 // Callbacks for creating/destroying SP_TimerFns. 475 void (*create_timer_fns)(const SP_Platform* platform, SP_TimerFns* timer, 476 TF_Status* status); 477 478 void (*destroy_timer_fns)(const SP_Platform* platform, 479 SP_TimerFns* timer_fns); 480 } SP_PlatformFns; 481 482 #define SP_PLATFORM_FNS_STRUCT_SIZE \ 483 TF_OFFSET_OF_END(SP_PlatformFns, destroy_timer_fns) 484 485 typedef struct SE_PlatformRegistrationParams { 486 size_t struct_size; 487 void* ext; // reserved for future use 488 489 // StreamExecutor C API version. 490 int32_t major_version; 491 int32_t minor_version; 492 int32_t patch_version; 493 494 SP_Platform* platform; // output, set by plugin 495 SP_PlatformFns* platform_fns; // output, set by plugin 496 // Clean up fields inside SP_Platform that were allocated 497 // by the plugin. `platform` itself should not be deleted here. 498 void (*destroy_platform)(SP_Platform* platform); // out, set by plugin 499 void (*destroy_platform_fns)( 500 SP_PlatformFns* platform_fns); // out, set by plugin 501 } SE_PlatformRegistrationParams; 502 503 #define SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE \ 504 TF_OFFSET_OF_END(SE_PlatformRegistrationParams, destroy_platform_fns) 505 506 void SE_InitPlugin(SE_PlatformRegistrationParams* params, TF_Status* status); 507 508 #ifdef __cplusplus 509 } // extern "C" 510 #endif 511 512 #endif // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_H_ 513