1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_KERNELS_C_API_H_ 16 #define TENSORFLOW_CORE_TPU_KERNELS_TPU_KERNELS_C_API_H_ 17 18 #include <stddef.h> 19 20 #include <cstdint> 21 22 #include "absl/types/optional.h" 23 #include "tensorflow/core/tpu/libtftpu.h" 24 #include "tensorflow/stream_executor/tpu/c_api_decl.h" 25 #include "tensorflow/stream_executor/tpu/proto_helper.h" 26 27 typedef struct TpuSerializedProto TpuSerializedProto; 28 29 namespace tensorflow { 30 class TpuMeshCommonState; 31 } // namespace tensorflow 32 33 extern "C" { 34 35 typedef struct XLA_TpuProgram XLA_TpuProgram; 36 37 // Enum for choosing sharding/unsharding program from a `XLA_TpuProgram` obj. 38 enum TpuProgramShardingType { kInvalid = 0, kMain, kSharding, kUnsharding }; 39 40 struct TpuProgramFingerprint { 41 const char* bytes; 42 size_t size; 43 }; 44 45 struct TpuExecutableSerializedProto { 46 const char* bytes; 47 size_t size; 48 }; 49 50 struct CompilerMetadataSerializedProto { 51 const char* bytes; 52 size_t size; 53 }; 54 55 struct HostComputeMetadataSerializedProto { 56 const char* bytes; 57 size_t size; 58 }; 59 60 typedef struct XLA_TpuMeshState XLA_TpuMeshState; 61 62 typedef struct TpuProfiler TpuProfiler; 63 64 typedef struct XLA_DeviceAssignment { 65 const char* bytes; 66 size_t size; 67 } XLA_DeviceAssignment; 68 69 // Property for creating compilation cache key. 70 struct CompilationCacheKeyProperty { 71 const char* config_prefix; 72 const char* shapes_prefix; 73 const char* function_name; 74 uint64_t mlir_module_fingerprint; 75 const int32_t* device_ids; 76 size_t device_ids_size; 77 int32_t guaranteed_constants_size; 78 uint64_t function_library_fingerprint; 79 int32_t num_cores_per_replica; 80 int32_t num_replicas; 81 const XLA_TpuMeshState* mesh_state; 82 }; 83 84 // Compilation cache key result returning both the key and a more verbose debug 85 // version. 86 struct CompilationCacheKeyResult { 87 const char* key; 88 const char* debug_string; 89 }; 90 91 typedef struct XLA_TpuNodeContext XLA_TpuNodeContext; 92 93 typedef struct TfTpu_OrdinalSelector TfTpuOrdinalSelector; 94 95 struct TpuPartitionedCall_Params { 96 bool input_shape_opt; 97 bool group_tensors_for_packing; 98 int32_t minimum_input_tensors_packing; 99 int32_t minimum_output_tensors_packing; 100 101 // Whether to attempt to automatically shard inputs by adding an 102 // XlaSharding op after each input. 103 bool enable_auto_xla_input_sharding; 104 105 // The dimension of each input to shard if 106 // enable_auto_xla_input_sharding is set to true. Negative numbers are 107 // allowed and refers to dimensions starting from the end. 108 int32_t auto_xla_input_sharding_dim; 109 110 // If true, only create one variable on the TPU for each variable on the CPU. 111 bool enable_variable_deduplication; 112 }; 113 114 // Compiles Mlir or TF function computation by lowering into HLO IR and returns 115 // `count` number of TPU programs ready for execution. 116 // The API allocates the `XLA_TpuProgram*[]` array `tpu_programs` and creates 117 // `XLA_TpuProgram` object(s) using the `TpuProgram_New` API. The caller is 118 // responsible to deallocate both the `XLA_TpuProgram*[]` array and the 119 // `XLA_TpuProgram` object(s) using `TpuProgram_FreeArray` and `TpuProgram_Free` 120 // API respectively. 121 TFTPU_CAPI_EXPORT void TpuCompile_CompileAndBuild( 122 TpuSerializedProto compilation_request, const XLA_TpuMeshState* mesh_state, 123 XLA_TpuProgram** tpu_programs[], size_t* count, TF_Status* status); 124 125 // Compiles a HLO IR and returns `count` number of TPU programs ready for 126 // execution. The API allocates the `XLA_TpuProgram*[]` array `tpu_programs` and 127 // creates `XLA_TpuProgram` object(s) using the `TpuProgram_New` API. The caller 128 // is responsible to deallocate both the `XLA_TpuProgram*[]` array and the 129 // `XLA_TpuProgram` object(s) using `TpuProgram_FreeArray` and `TpuProgram_Free` 130 // API respectively. 131 TFTPU_CAPI_EXPORT void TpuCompile_XrtCompileAndBuild( 132 TpuSerializedProto xrt_computation, const XLA_TpuMeshState* mesh_state, 133 XLA_TpuProgram** tpu_programs[], size_t* count, TF_Status* status); 134 135 // Creates a TPU profiler that is ready to start profiling. 136 TFTPU_CAPI_EXPORT void TpuProfiler_Create(TpuProfiler** tpu_profiler, 137 TF_Status* status); 138 // Destroys the given TPU profiler. 139 TFTPU_CAPI_EXPORT void TpuProfiler_Destroy(TpuProfiler* tpu_profiler); 140 // Starts profiling if not already started, returns an error otherwise. 141 TFTPU_CAPI_EXPORT void TpuProfiler_Start(TpuProfiler* tpu_profiler, 142 TF_Status* status); 143 // Stops profiling if not already stopped, returns an error otherwise. 144 TFTPU_CAPI_EXPORT void TpuProfiler_Stop(TpuProfiler* tpu_profiler, 145 TF_Status* status); 146 // Serializes profiled data into `buffer` and returns the size of `buffer`. The 147 // profile data held by the TPU driver will be cleared after retrieval. 148 // 149 // Step 1. Query the size of buffer required into `size_in_bytes`. 150 // 151 // size_t size_in_bytes; 152 // TpuProfiler_CollectData(profiler, status, nullptr, &size_in_bytes); 153 // 154 // Step 2. Retrieve the data into a `buffer` of size `size_in_bytes`. 155 // Subsequently,The TPU driver clears its copy of the profile data. 156 // 157 // uint8_t buffer = new uint8_t[size_in_bytes]; 158 // TpuProfiler_CollectData(profiler, status, buffer, size_in_bytes); 159 // 160 // Step 3. Unpack the data into an XSpace. 161 // 162 // tensorflow::profiler::XSpace space; 163 // space.ParseFromArray(buffer, size_in_bytes); 164 // 165 TFTPU_CAPI_EXPORT void TpuProfiler_CollectData(TpuProfiler* tpu_profiler, 166 TF_Status* status, 167 uint8_t* buffer, 168 size_t* size_in_bytes); 169 170 // Creates a new TPU mesh state object. 171 TFTPU_CAPI_EXPORT XLA_TpuMeshState* TpuMeshState_Create(); 172 173 // Deletes the given TPU `mesh_state` object. Once deleted the object is 174 // unusable. 175 TFTPU_CAPI_EXPORT void TpuMeshState_Free(XLA_TpuMeshState* mesh_state); 176 177 // Returns a pointer to an opaque mesh data structure used internally. 178 TFTPU_CAPI_EXPORT void* TpuMeshState_MeshCommonState( 179 XLA_TpuMeshState* mesh_state); 180 181 TFTPU_CAPI_EXPORT void TfTpuOrdinalSelector_Create( 182 TfTpuOrdinalSelector** ordinal_selector, int num_cores_per_replica); 183 184 TFTPU_CAPI_EXPORT void TfTpuOrdinalSelector_Destroy( 185 TfTpuOrdinalSelector* ordinal_selector); 186 187 TFTPU_CAPI_EXPORT void TfTpuOrdinalSelector_GetOrdinal( 188 TfTpuOrdinalSelector* ordinal_selector, absl::optional<uint64_t> key, 189 int64_t* req_id, int64_t* ordinal); 190 191 TFTPU_CAPI_EXPORT void TfTpuOrdinalSelector_DequeueFromCoreSelector( 192 TfTpuOrdinalSelector* ordinal_selector, int32_t device_ordinal, 193 int64_t req_id); 194 195 TFTPU_CAPI_EXPORT void TfTpu_GetTpuPartitionedCallParams( 196 TpuPartitionedCall_Params* params); 197 198 typedef struct TpuExecutable_LoadProgramAndEnqueueToStream_Params { 199 int32_t struct_size; 200 void* priv; 201 202 const XLA_TpuProgram* program; 203 SE_DeviceMemoryBase* arguments; 204 size_t arguments_len; 205 SE_DeviceMemoryBase* result; 206 bool has_cross_program_prefetch_addr; 207 SE_DeviceMemoryBase* cross_program_prefetch_addr; 208 int32_t rng_seed; 209 XLA_DeviceAssignment* device_assignment; 210 SE_Stream* stream; 211 212 TF_Status* status; // out 213 } TpuExecutable_LoadProgramAndEnqueueToStream_Params; 214 215 #define TpuExecutable_LoadProgramAndEnqueueToStream_Params_SIZE \ 216 (sizeof(struct TpuExecutable_LoadProgramAndEnqueueToStream_Params)) 217 218 TFTPU_CAPI_EXPORT void TpuExecutable_LoadProgramAndEnqueueToStream( 219 TpuExecutable_LoadProgramAndEnqueueToStream_Params* params); 220 221 TFTPU_CAPI_EXPORT void HardwareLayout_HostShapeToDeviceShape( 222 XLA_Shape* host_shape, XLA_Shape* device_shape); 223 TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSize(XLA_Shape* shape); 224 TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompact(XLA_Shape* shape); 225 TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompactRaw(XLA_Shape* shape); 226 227 typedef struct TpuExecute_RuntimeInputToPaddedData_Params { 228 int32_t struct_size; 229 void* priv; 230 231 uint32_t* runtime_input_ptr; 232 size_t runtime_input_size; 233 int8_t* padded_data_ptr; 234 size_t padded_data_size; 235 XLA_Shape* runtime_shape; 236 XLA_Shape* compile_time_shape; 237 238 TF_Status* status; // out 239 } TpuExecute_RuntimeInputToPaddedData_Params; 240 241 #define TpuExecute_RuntimeInputToPaddedData_Params_SIZE \ 242 (sizeof(struct TpuExecute_RuntimeInputToPaddedData_Params)) 243 244 TFTPU_CAPI_EXPORT void TpuExecute_RuntimeInputToPaddedData( 245 TpuExecute_RuntimeInputToPaddedData_Params* params); 246 247 typedef struct ConfigureDistributedTpuOp_DoWork_Params { 248 int32_t struct_size; 249 void* priv; 250 251 size_t num_cores_per_host_size; 252 const int32_t* num_cores_per_host; 253 size_t server_address_size; 254 const char* server_address; 255 256 size_t* host_config_output_size; // out 257 char** host_config_output; // out 258 TF_Status* status; // out 259 } ConfigureDistributedTpuOp_DoWork_Params; 260 261 #define ConfigureDistributedTpuOp_DoWork_Params_SIZE \ 262 (sizeof(struct ConfigureDistributedTpuOp_DoWork_Params)) 263 264 TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork( 265 ConfigureDistributedTpuOp_DoWork_Params* params); 266 267 typedef struct WaitForDistributedTpuOp_DoWork_Params { 268 int32_t struct_size; 269 void* priv; 270 271 size_t num_hosts; 272 size_t num_cores_per_host; 273 const int32_t** host_ordinal_to_global_core_id_map; 274 tensorflow::TpuMeshCommonState* tpu_mesh_common_state; 275 276 size_t* tpu_topology_output_size; // out 277 char** tpu_topology_output; // out 278 TF_Status* status; // out 279 } WaitForDistributedTpuOp_DoWork_Params; 280 281 #define WaitForDistributedTpuOp_DoWork_Params_SIZE \ 282 (sizeof(struct WaitForDistributedTpuOp_DoWork_Params)) 283 284 TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork( 285 WaitForDistributedTpuOp_DoWork_Params* params); 286 287 typedef struct InitializeHostForDistributedTpuOp_DoWork_Params { 288 int32_t struct_size; 289 void* priv; 290 291 size_t tpu_host_config_size; 292 const char* tpu_host_config; 293 bool enable_whole_mesh_compilations; 294 bool is_master_worker; 295 296 size_t* core_id_output_size; // out 297 int32_t** core_id_output; // out 298 TF_Status* status; // out 299 } InitializeHostForDistributedTpuOp_DoWork_Params; 300 301 #define InitializeHostForDistributedTpuOp_DoWork_Params_SIZE \ 302 (sizeof(struct InitializeHostForDistributedTpuOp_DoWork_Params)) 303 304 TFTPU_CAPI_EXPORT void InitializeHostForDistributedTpuOp_DoWork( 305 InitializeHostForDistributedTpuOp_DoWork_Params* params); 306 307 TFTPU_CAPI_EXPORT void SetGlobalTPUArrayOp_DoWork( 308 const size_t tpu_topology_size, const char* tpu_topology, 309 TF_Status* status); 310 311 TFTPU_CAPI_EXPORT void DisconnectDistributedTpuChipsOp_DoWork( 312 int32_t* number_of_chips_output, TF_Status* status); 313 314 TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeCharArray(char* output); 315 TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeInt32Array(int32_t* output); 316 317 TFTPU_CAPI_EXPORT bool TpuConfigurationApi_HasTPUPodState(); 318 319 TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpusPerHost(int32_t* tpus, 320 TF_Status* status); 321 TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpuMemoryLimit(int64_t* memory_limit, 322 TF_Status* status); 323 TFTPU_CAPI_EXPORT void TpuConfigurationApi_RemoteCompilationCacheSizeInBytes( 324 int64_t* cache_size_in_bytes); 325 326 typedef struct TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params { 327 int32_t struct_size; 328 void* priv; 329 330 size_t tpu_host_config_size; 331 const char* tpu_host_config; 332 333 size_t* server_address_output_size; // out 334 char** server_address_output; // out 335 TF_Status* status; // out 336 } TpuConfigurationApi_CompilationCacheServerAddressFromConfig_Params; 337 338 #define TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params_SIZE \ 339 (sizeof( \ 340 struct TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params)) 341 342 TFTPU_CAPI_EXPORT 343 void TpuConfigurationApi_CompilationCacheServerAddressFromConfig( 344 TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params* params); 345 346 typedef struct TpuConfigurationApi_GetServerAddressAndPort_Params { 347 int32_t struct_size; 348 void* priv; 349 350 size_t* server_address_output_size; // out 351 char** server_address_output; // out 352 int* port_output; // out 353 TF_Status* status; // out 354 } TpuConfigurationApi_GetServerAddressAndPort_Params; 355 356 #define TpuConfigurationApi_GetServerAddressAndPort_Params_SIZE \ 357 (sizeof(struct TpuConfigurationApi_GetServerAddressAndPort_Params)) 358 359 TFTPU_CAPI_EXPORT void TpuConfigurationApi_GetServerAddressAndPort( 360 TpuConfigurationApi_GetServerAddressAndPort_Params* params); 361 362 // Creates a new TPU program. 363 TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_New(); 364 365 // Destroys the `tpu_program`. 366 TFTPU_CAPI_EXPORT void TpuProgram_Free(XLA_TpuProgram* tpu_program); 367 368 // Creates an array of `XLA_TpuProgram*`. 369 TFTPU_CAPI_EXPORT XLA_TpuProgram** TpuProgram_NewArray(size_t count); 370 371 // Destroys an array of `XLA_TpuProgram*`. 372 TFTPU_CAPI_EXPORT void TpuProgram_FreeArray(XLA_TpuProgram* tpu_program[]); 373 374 // Unloads and destroys the `tpu_program`. Once the TPU program is unloaded and 375 // destroyed, it is in an unusable state. 376 TFTPU_CAPI_EXPORT void TpuProgram_UnloadAndDestroy(XLA_TpuProgram* tpu_program, 377 TF_Status* status); 378 379 // Gets TPU program size in bytes from the `tpu_program`. 380 TFTPU_CAPI_EXPORT int64_t 381 TpuProgram_GetProgramSize(const XLA_TpuProgram* tpu_program); 382 383 // Logs the summary of current memory state snapshot of the `tpu_program`. 384 TFTPU_CAPI_EXPORT bool TpuProgram_LogProgramMemorySummary( 385 const XLA_TpuProgram* tpu_program); 386 387 // Gets TPU program executable info from the `tpu_program`. 388 TFTPU_CAPI_EXPORT void TpuProgram_GetExecutableInfo( 389 const XLA_TpuProgram* tpu_program, TpuSerializedProto* executable_info, 390 TF_Status* status); 391 392 // Gets host transfer info proto. 393 TFTPU_CAPI_EXPORT void TpuProgram_GetHostTransferInfo( 394 const XLA_TpuProgram* tpu_program, TpuSerializedProto* host_transfer_info, 395 TF_Status* status); 396 397 // Gets HLO metadata proto. 398 TFTPU_CAPI_EXPORT void TpuProgram_GetHloMetadata( 399 const XLA_TpuProgram* tpu_program, TpuSerializedProto* hlo_metadata, 400 TF_Status* status); 401 402 // Gets may modify variables boolean value. 403 TFTPU_CAPI_EXPORT void TpuProgram_GetMayModifyVariables( 404 const XLA_TpuProgram* tpu_program, bool* may_modify_variables); 405 406 // Checks if TPU program has sharding. 407 TFTPU_CAPI_EXPORT bool TpuProgram_HasSharding( 408 const XLA_TpuProgram* tpu_program); 409 410 // Gets TPU program by sharding type. Return value is valid only when the 411 // `status.status()` returns `OK`. 412 TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_GetTpuProgram( 413 XLA_TpuProgram* tpu_program, TpuProgramShardingType type); 414 415 // Gets TPU executable proto from a `tpu_program`. 416 TFTPU_CAPI_EXPORT void TpuProgram_SerializeTpuExecutable( 417 const XLA_TpuProgram* tpu_program, TpuExecutableSerializedProto* executable, 418 TF_Status* status); 419 420 // Gets compilation metadata proto from a `tpu_program`. 421 TFTPU_CAPI_EXPORT void TpuProgram_SerializeCompilerMetadata( 422 const XLA_TpuProgram* tpu_program, 423 CompilerMetadataSerializedProto* compiler_metadata, TF_Status* status); 424 425 // Deserializes the `GetTpuProgramResponse` proto into an `XLA_TpuProgram`. 426 TFTPU_CAPI_EXPORT void TpuProgram_DeserializeFromGetTpuProgramResponseProto( 427 TpuSerializedProto get_tpu_program_response, XLA_TpuProgram* tpu_program, 428 TF_Status* status); 429 430 TFTPU_CAPI_EXPORT TpuProgramFingerprint 431 TpuProgram_GetFingerprint(const XLA_TpuProgram* tpu_program); 432 433 TFTPU_CAPI_EXPORT void TpuProgram_DestroyFingerprint( 434 TpuProgramFingerprint fingerprint); 435 436 // Checks if whether a TPU compilation is enabled. 437 TFTPU_CAPI_EXPORT bool TpuCompile_IsTpuCompilationEnabled(); 438 439 // XLA compilation cannot be cancelled. To avoid hanging the TF worker will exit 440 // when cancellation is requested for an XLA compile op. Some tests require this 441 // behavior to be disabled, and we test for this condition with the following 442 // flag function. 443 TFTPU_CAPI_EXPORT bool TpuCompile_ShouldTpuCompileOpIgnoreCancellation(); 444 445 // Returns the number of available TPU core count. 446 TFTPU_CAPI_EXPORT int TpuTopology_AvailableCoreCount( 447 const XLA_TpuMeshState* mesh_state, TpuCoreTypeEnum tpu_core_type); 448 449 // Recycle unused service port. 450 TFTPU_CAPI_EXPORT void TpuNetUtil_RecycleUnusedPort(int port); 451 452 // Creates a unique compilation cache `key` used for `put` and `get` operations. 453 // Returned buffers are heap-allocated and must be owned. 454 TFTPU_CAPI_EXPORT CompilationCacheKeyResult 455 TpuCompile_CreateCompilationCacheKey(CompilationCacheKeyProperty property); 456 457 // Destroys the CompilationCacheKeyResult returned by calling the 458 // `TpuCompile_CreateCompilationCacheKey` API. 459 TFTPU_CAPI_EXPORT void TpuCompile_DestroyCompilationCacheKey( 460 CompilationCacheKeyResult result); 461 462 // Creates a guaranteed const fingerprint. Guarantee const is normally used in 463 // TPU inference to avoid re-copying unchanged variables onto the TPU device. 464 // It promises the value is identical for every execution in the same session 465 // even if the actual value changes in later executions. 466 TFTPU_CAPI_EXPORT uint64_t TpuCompile_CreateGuaranteedConstFingerprint( 467 uint64_t fingerprint, const char* data, size_t size); 468 469 XLA_TpuNodeContext* TpuNodeContext_Create(int device_ordinal, 470 TF_Status* status); 471 void TpuNodeContext_Free(XLA_TpuNodeContext* node_context); 472 473 void TpuNodeContext_StopChipHeartbeats(TF_Status* status); 474 475 void TpuNodeContext_CloseTpuHost(TF_Status* status); 476 477 void TpuNodeContext_Initialize(int device_ordinal, TF_Status* status); 478 479 bool TpuNodeContext_CompactionSupported(int device_ordinal); 480 481 // Globally initialize the TPU system for inference. 482 TFTPU_CAPI_EXPORT void TfTpu_InitializeTpuModelServer(); 483 484 struct TfTpu_OpsApiFn { 485 TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CompileAndBuild); 486 TFTPU_ADD_FN_IN_STRUCT(TpuCompile_XrtCompileAndBuild); 487 488 TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Create); 489 TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Free); 490 TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_MeshCommonState); 491 492 TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Create); 493 TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Destroy); 494 TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Start); 495 TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Stop); 496 TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_CollectData); 497 498 TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_LoadProgramAndEnqueueToStream); 499 TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_HostShapeToDeviceShape); 500 TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSize); 501 TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSizeCompact); 502 TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSizeCompactRaw); 503 504 TFTPU_ADD_FN_IN_STRUCT(TpuExecute_RuntimeInputToPaddedData); 505 506 TFTPU_ADD_FN_IN_STRUCT(ConfigureDistributedTpuOp_DoWork); 507 TFTPU_ADD_FN_IN_STRUCT(WaitForDistributedTpuOp_DoWork); 508 TFTPU_ADD_FN_IN_STRUCT(InitializeHostForDistributedTpuOp_DoWork); 509 TFTPU_ADD_FN_IN_STRUCT(SetGlobalTPUArrayOp_DoWork); 510 TFTPU_ADD_FN_IN_STRUCT(DisconnectDistributedTpuChipsOp_DoWork); 511 TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeCharArray); 512 TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeInt32Array); 513 TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_HasTPUPodState); 514 TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpusPerHost); 515 TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpuMemoryLimit); 516 TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_RemoteCompilationCacheSizeInBytes); 517 TFTPU_ADD_FN_IN_STRUCT( 518 TpuConfigurationApi_CompilationCacheServerAddressFromConfig); 519 TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_GetServerAddressAndPort); 520 521 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_New); 522 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_Free); 523 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_NewArray); 524 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_FreeArray); 525 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_UnloadAndDestroy); 526 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetProgramSize); 527 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_LogProgramMemorySummary); 528 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetExecutableInfo); 529 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHostTransferInfo); 530 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHloMetadata); 531 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetMayModifyVariables); 532 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_HasSharding); 533 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetTpuProgram); 534 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeTpuExecutable); 535 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeCompilerMetadata); 536 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_DeserializeFromGetTpuProgramResponseProto); 537 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetFingerprint); 538 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_DestroyFingerprint); 539 540 TFTPU_ADD_FN_IN_STRUCT(TpuCompile_IsTpuCompilationEnabled); 541 TFTPU_ADD_FN_IN_STRUCT(TpuCompile_ShouldTpuCompileOpIgnoreCancellation); 542 TFTPU_ADD_FN_IN_STRUCT(TpuTopology_AvailableCoreCount); 543 TFTPU_ADD_FN_IN_STRUCT(TpuNetUtil_RecycleUnusedPort); 544 TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateCompilationCacheKey); 545 TFTPU_ADD_FN_IN_STRUCT(TpuCompile_DestroyCompilationCacheKey); 546 TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateGuaranteedConstFingerprint); 547 548 TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Create); 549 TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Free); 550 TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_StopChipHeartbeats); 551 TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_CloseTpuHost); 552 TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Initialize); 553 TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_CompactionSupported); 554 555 TFTPU_ADD_FN_IN_STRUCT(TfTpu_InitializeTpuModelServer); 556 557 TFTPU_ADD_FN_IN_STRUCT(TfTpuOrdinalSelector_Create); 558 TFTPU_ADD_FN_IN_STRUCT(TfTpuOrdinalSelector_Destroy); 559 TFTPU_ADD_FN_IN_STRUCT(TfTpuOrdinalSelector_GetOrdinal); 560 TFTPU_ADD_FN_IN_STRUCT(TfTpuOrdinalSelector_DequeueFromCoreSelector); 561 TFTPU_ADD_FN_IN_STRUCT(TfTpu_GetTpuPartitionedCallParams); 562 }; 563 564 } // extern "C" 565 566 #endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_KERNELS_C_API_H_ 567