1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_KERNELS_C_API_H_ 16 #define TENSORFLOW_CORE_TPU_KERNELS_TPU_KERNELS_C_API_H_ 17 18 #include <stddef.h> 19 20 #include <cstdint> 21 22 #include "tensorflow/core/tpu/libtftpu.h" 23 #include "tensorflow/stream_executor/tpu/c_api_decl.h" 24 #include "tensorflow/stream_executor/tpu/proto_helper.h" 25 26 typedef struct TpuSerializedProto TpuSerializedProto; 27 28 namespace tensorflow { 29 class TpuMeshCommonState; 30 } // namespace tensorflow 31 32 extern "C" { 33 34 typedef struct XLA_TpuProgram XLA_TpuProgram; 35 36 // Enum for choosing sharding/unsharding program from a `XLA_TpuProgram` obj. 37 enum TpuProgramShardingType { kInvalid = 0, kMain, kSharding, kUnsharding }; 38 39 struct TpuExecutableSerializedProto { 40 const char* bytes; 41 size_t size; 42 }; 43 44 struct CompilerMetadataSerializedProto { 45 const char* bytes; 46 size_t size; 47 }; 48 49 struct HostComputeMetadataSerializedProto { 50 const char* bytes; 51 size_t size; 52 }; 53 54 typedef struct XLA_TpuMeshState XLA_TpuMeshState; 55 56 typedef struct TpuProfiler TpuProfiler; 57 58 typedef struct XLA_DeviceAssignment { 59 const char* bytes; 60 size_t size; 61 } XLA_DeviceAssignment; 62 63 // Property for creating compilation cache key. 64 struct CompilationCacheKeyProperty { 65 const char* config_prefix; 66 const char* shapes_prefix; 67 const char* function_name; 68 uint64_t mlir_module_fingerprint; 69 const int32_t* device_ids; 70 size_t device_ids_size; 71 int32_t guaranteed_constants_size; 72 uint64_t function_library_fingerprint; 73 int32_t num_cores_per_replica; 74 int32_t num_replicas; 75 const XLA_TpuMeshState* mesh_state; 76 }; 77 78 // Compilation cache key result returning both the key and a more verbose debug 79 // version. 80 struct CompilationCacheKeyResult { 81 const char* key; 82 const char* debug_string; 83 }; 84 85 typedef struct XLA_TpuNodeContext XLA_TpuNodeContext; 86 87 // Compiles Mlir or TF function computation by lowering into HLO IR and returns 88 // `count` number of TPU programs ready for execution. 89 // The API allocates the `XLA_TpuProgram*[]` array `tpu_programs` and creates 90 // `XLA_TpuProgram` object(s) using the `TpuProgram_New` API. The caller is 91 // responsible to deallocate both the `XLA_TpuProgram*[]` array and the 92 // `XLA_TpuProgram` object(s) using `TpuProgram_FreeArray` and `TpuProgram_Free` 93 // API respectively. 94 TFTPU_CAPI_EXPORT void TpuCompile_CompileAndBuild( 95 TpuSerializedProto compilation_request, const XLA_TpuMeshState* mesh_state, 96 XLA_TpuProgram** tpu_programs[], size_t* count, TF_Status* status); 97 98 // Compiles a HLO IR and returns `count` number of TPU programs ready for 99 // execution. The API allocates the `XLA_TpuProgram*[]` array `tpu_programs` and 100 // creates `XLA_TpuProgram` object(s) using the `TpuProgram_New` API. The caller 101 // is responsible to deallocate both the `XLA_TpuProgram*[]` array and the 102 // `XLA_TpuProgram` object(s) using `TpuProgram_FreeArray` and `TpuProgram_Free` 103 // API respectively. 104 TFTPU_CAPI_EXPORT void TpuCompile_XrtCompileAndBuild( 105 TpuSerializedProto xrt_computation, const XLA_TpuMeshState* mesh_state, 106 XLA_TpuProgram** tpu_programs[], size_t* count, TF_Status* status); 107 108 // Creates a TPU profiler that is ready to start profiling. 109 TFTPU_CAPI_EXPORT void TpuProfiler_Create(TpuProfiler** tpu_profiler, 110 TF_Status* status); 111 // Destroys the given TPU profiler. 112 TFTPU_CAPI_EXPORT void TpuProfiler_Destroy(TpuProfiler* tpu_profiler); 113 // Starts profiling if not already started, returns an error otherwise. 114 TFTPU_CAPI_EXPORT void TpuProfiler_Start(TpuProfiler* tpu_profiler, 115 TF_Status* status); 116 // Stops profiling if not already stopped, returns an error otherwise. 117 TFTPU_CAPI_EXPORT void TpuProfiler_Stop(TpuProfiler* tpu_profiler, 118 TF_Status* status); 119 // Serializes profiled data into `buffer` and returns the size of `buffer`. The 120 // profile data held by the TPU driver will be cleared after retrieval. 121 // 122 // Step 1. Query the size of buffer required into `size_in_bytes`. 123 // 124 // size_t size_in_bytes; 125 // TpuProfiler_CollectData(profiler, status, nullptr, &size_in_bytes); 126 // 127 // Step 2. Retrieve the data into a `buffer` of size `size_in_bytes`. 128 // Subsequently,The TPU driver clears its copy of the profile data. 129 // 130 // uint8_t buffer = new uint8_t[size_in_bytes]; 131 // TpuProfiler_CollectData(profiler, status, buffer, size_in_bytes); 132 // 133 // Step 3. Unpack the data into an XSpace. 134 // 135 // tensorflow::profiler::XSpace space; 136 // space.ParseFromArray(buffer, size_in_bytes); 137 // 138 TFTPU_CAPI_EXPORT void TpuProfiler_CollectData(TpuProfiler* tpu_profiler, 139 TF_Status* status, 140 uint8_t* buffer, 141 size_t* size_in_bytes); 142 143 // Creates a new TPU mesh state object. 144 TFTPU_CAPI_EXPORT XLA_TpuMeshState* TpuMeshState_Create(); 145 146 // Deletes the given TPU `mesh_state` object. Once deleted the object is 147 // unusable. 148 TFTPU_CAPI_EXPORT void TpuMeshState_Free(XLA_TpuMeshState* mesh_state); 149 150 // Returns a pointer to an opaque mesh data structure used internally. 151 TFTPU_CAPI_EXPORT void* TpuMeshState_MeshCommonState( 152 XLA_TpuMeshState* mesh_state); 153 154 typedef struct TpuExecutable_LoadProgramAndEnqueueToStream_Params { 155 int32_t struct_size; 156 void* priv; 157 158 const XLA_TpuProgram* program; 159 SE_DeviceMemoryBase* arguments; 160 size_t arguments_len; 161 SE_DeviceMemoryBase* result; 162 bool has_cross_program_prefetch_addr; 163 SE_DeviceMemoryBase* cross_program_prefetch_addr; 164 int32_t rng_seed; 165 XLA_DeviceAssignment* device_assignment; 166 SE_Stream* stream; 167 168 TF_Status* status; // out 169 } TpuExecutable_LoadProgramAndEnqueueToStream_Params; 170 171 #define TpuExecutable_LoadProgramAndEnqueueToStream_Params_SIZE \ 172 (sizeof(struct TpuExecutable_LoadProgramAndEnqueueToStream_Params)) 173 174 TFTPU_CAPI_EXPORT void TpuExecutable_LoadProgramAndEnqueueToStream( 175 TpuExecutable_LoadProgramAndEnqueueToStream_Params* params); 176 177 TFTPU_CAPI_EXPORT void HardwareLayout_HostShapeToDeviceShape( 178 XLA_Shape* host_shape, XLA_Shape* device_shape); 179 TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSize(XLA_Shape* shape); 180 TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompact(XLA_Shape* shape); 181 TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompactRaw(XLA_Shape* shape); 182 183 typedef struct TpuExecute_RuntimeInputToPaddedData_Params { 184 int32_t struct_size; 185 void* priv; 186 187 uint32_t* runtime_input_ptr; 188 size_t runtime_input_size; 189 int8_t* padded_data_ptr; 190 size_t padded_data_size; 191 XLA_Shape* runtime_shape; 192 XLA_Shape* compile_time_shape; 193 194 TF_Status* status; // out 195 } TpuExecute_RuntimeInputToPaddedData_Params; 196 197 #define TpuExecute_RuntimeInputToPaddedData_Params_SIZE \ 198 (sizeof(struct TpuExecute_RuntimeInputToPaddedData_Params)) 199 200 TFTPU_CAPI_EXPORT void TpuExecute_RuntimeInputToPaddedData( 201 TpuExecute_RuntimeInputToPaddedData_Params* params); 202 203 typedef struct ConfigureDistributedTpuOp_DoWork_Params { 204 int32_t struct_size; 205 void* priv; 206 207 size_t num_cores_per_host_size; 208 const int32_t* num_cores_per_host; 209 size_t server_address_size; 210 const char* server_address; 211 212 size_t* host_config_output_size; // out 213 char** host_config_output; // out 214 TF_Status* status; // out 215 } ConfigureDistributedTpuOp_DoWork_Params; 216 217 #define ConfigureDistributedTpuOp_DoWork_Params_SIZE \ 218 (sizeof(struct ConfigureDistributedTpuOp_DoWork_Params)) 219 220 TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork( 221 ConfigureDistributedTpuOp_DoWork_Params* params); 222 223 typedef struct WaitForDistributedTpuOp_DoWork_Params { 224 int32_t struct_size; 225 void* priv; 226 227 size_t num_hosts; 228 size_t num_cores_per_host; 229 const int32_t** host_ordinal_to_global_core_id_map; 230 tensorflow::TpuMeshCommonState* tpu_mesh_common_state; 231 232 size_t* tpu_topology_output_size; // out 233 char** tpu_topology_output; // out 234 TF_Status* status; // out 235 } WaitForDistributedTpuOp_DoWork_Params; 236 237 #define WaitForDistributedTpuOp_DoWork_Params_SIZE \ 238 (sizeof(struct WaitForDistributedTpuOp_DoWork_Params)) 239 240 TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork( 241 WaitForDistributedTpuOp_DoWork_Params* params); 242 243 typedef struct InitializeHostForDistributedTpuOp_DoWork_Params { 244 int32_t struct_size; 245 void* priv; 246 247 size_t tpu_host_config_size; 248 const char* tpu_host_config; 249 bool enable_whole_mesh_compilations; 250 bool is_master_worker; 251 252 size_t* core_id_output_size; // out 253 int32_t** core_id_output; // out 254 TF_Status* status; // out 255 } InitializeHostForDistributedTpuOp_DoWork_Params; 256 257 #define InitializeHostForDistributedTpuOp_DoWork_Params_SIZE \ 258 (sizeof(struct InitializeHostForDistributedTpuOp_DoWork_Params)) 259 260 TFTPU_CAPI_EXPORT void InitializeHostForDistributedTpuOp_DoWork( 261 InitializeHostForDistributedTpuOp_DoWork_Params* params); 262 263 TFTPU_CAPI_EXPORT void SetGlobalTPUArrayOp_DoWork( 264 const size_t tpu_topology_size, const char* tpu_topology, 265 TF_Status* status); 266 267 TFTPU_CAPI_EXPORT void DisconnectDistributedTpuChipsOp_DoWork( 268 int32_t* number_of_chips_output, TF_Status* status); 269 270 TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeCharArray(char* output); 271 TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeInt32Array(int32_t* output); 272 273 TFTPU_CAPI_EXPORT bool TpuConfigurationApi_HasTPUPodState(); 274 275 TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpusPerHost(int32_t* tpus, 276 TF_Status* status); 277 TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpuMemoryLimit(int64_t* memory_limit, 278 TF_Status* status); 279 TFTPU_CAPI_EXPORT void TpuConfigurationApi_RemoteCompilationCacheSizeInBytes( 280 int64_t* cache_size_in_bytes); 281 282 typedef struct TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params { 283 int32_t struct_size; 284 void* priv; 285 286 size_t tpu_host_config_size; 287 const char* tpu_host_config; 288 289 size_t* server_address_output_size; // out 290 char** server_address_output; // out 291 TF_Status* status; // out 292 } TpuConfigurationApi_CompilationCacheServerAddressFromConfig_Params; 293 294 #define TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params_SIZE \ 295 (sizeof( \ 296 struct TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params)) 297 298 TFTPU_CAPI_EXPORT 299 void TpuConfigurationApi_CompilationCacheServerAddressFromConfig( 300 TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params* params); 301 302 typedef struct TpuConfigurationApi_GetServerAddressAndPort_Params { 303 int32_t struct_size; 304 void* priv; 305 306 size_t* server_address_output_size; // out 307 char** server_address_output; // out 308 int* port_output; // out 309 TF_Status* status; // out 310 } TpuConfigurationApi_GetServerAddressAndPort_Params; 311 312 #define TpuConfigurationApi_GetServerAddressAndPort_Params_SIZE \ 313 (sizeof(struct TpuConfigurationApi_GetServerAddressAndPort_Params)) 314 315 TFTPU_CAPI_EXPORT void TpuConfigurationApi_GetServerAddressAndPort( 316 TpuConfigurationApi_GetServerAddressAndPort_Params* params); 317 318 // Creates a new TPU program. 319 TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_New(); 320 321 // Destroys the `tpu_program`. 322 TFTPU_CAPI_EXPORT void TpuProgram_Free(XLA_TpuProgram* tpu_program); 323 324 // Creates an array of `XLA_TpuProgram*`. 325 TFTPU_CAPI_EXPORT XLA_TpuProgram** TpuProgram_NewArray(size_t count); 326 327 // Destroys an array of `XLA_TpuProgram*`. 328 TFTPU_CAPI_EXPORT void TpuProgram_FreeArray(XLA_TpuProgram* tpu_program[]); 329 330 // Unloads and destroys the `tpu_program`. Once the TPU program is unloaded and 331 // destroyed, it is in an unusable state. 332 TFTPU_CAPI_EXPORT void TpuProgram_UnloadAndDestroy(XLA_TpuProgram* tpu_program, 333 TF_Status* status); 334 335 // Gets TPU program size in bytes from the `tpu_program`. 336 TFTPU_CAPI_EXPORT int64_t 337 TpuProgram_GetProgramSize(const XLA_TpuProgram* tpu_program); 338 339 // Logs the summary of current memory state snapshot of the `tpu_program`. 340 TFTPU_CAPI_EXPORT bool TpuProgram_LogProgramMemorySummary( 341 const XLA_TpuProgram* tpu_program); 342 343 // Gets TPU program executable info from the `tpu_program`. 344 TFTPU_CAPI_EXPORT void TpuProgram_GetExecutableInfo( 345 const XLA_TpuProgram* tpu_program, TpuSerializedProto* executable_info, 346 TF_Status* status); 347 348 // Gets host transfer info proto. 349 TFTPU_CAPI_EXPORT void TpuProgram_GetHostTransferInfo( 350 const XLA_TpuProgram* tpu_program, TpuSerializedProto* host_transfer_info, 351 TF_Status* status); 352 353 // Gets HLO metadata proto. 354 TFTPU_CAPI_EXPORT void TpuProgram_GetHloMetadata( 355 const XLA_TpuProgram* tpu_program, TpuSerializedProto* hlo_metadata, 356 TF_Status* status); 357 358 // Gets may modify variables boolean value. 359 TFTPU_CAPI_EXPORT void TpuProgram_GetMayModifyVariables( 360 const XLA_TpuProgram* tpu_program, bool* may_modify_variables); 361 362 // Checks if TPU program has sharding. 363 TFTPU_CAPI_EXPORT bool TpuProgram_HasSharding( 364 const XLA_TpuProgram* tpu_program); 365 366 // Gets TPU program by sharding type. Return value is valid only when the 367 // `status.status()` returns `OK`. 368 TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_GetTpuProgram( 369 XLA_TpuProgram* tpu_program, TpuProgramShardingType type); 370 371 // Gets TPU executable proto from a `tpu_program`. 372 TFTPU_CAPI_EXPORT void TpuProgram_SerializeTpuExecutable( 373 const XLA_TpuProgram* tpu_program, TpuExecutableSerializedProto* executable, 374 TF_Status* status); 375 376 // Gets compilation metadata proto from a `tpu_program`. 377 TFTPU_CAPI_EXPORT void TpuProgram_SerializeCompilerMetadata( 378 const XLA_TpuProgram* tpu_program, 379 CompilerMetadataSerializedProto* compiler_metadata, TF_Status* status); 380 381 // Deserializes the `GetTpuProgramResponse` proto into an `XLA_TpuProgram`. 382 TFTPU_CAPI_EXPORT void TpuProgram_DeserializeFromGetTpuProgramResponseProto( 383 TpuSerializedProto get_tpu_program_response, XLA_TpuProgram* tpu_program, 384 TF_Status* status); 385 386 // Checks if whether a TPU compilation is enabled. 387 TFTPU_CAPI_EXPORT bool TpuCompile_IsTpuCompilationEnabled(); 388 389 // XLA compilation cannot be cancelled. To avoid hanging the TF worker will exit 390 // when cancellation is requested for an XLA compile op. Some tests require this 391 // behavior to be disabled, and we test for this condition with the following 392 // flag function. 393 TFTPU_CAPI_EXPORT bool TpuCompile_ShouldTpuCompileOpIgnoreCancellation(); 394 395 // Returns the number of available TPU core count. 396 TFTPU_CAPI_EXPORT int TpuTopology_AvailableCoreCount( 397 const XLA_TpuMeshState* mesh_state, TpuCoreTypeEnum tpu_core_type); 398 399 // Recycle unused service port. 400 TFTPU_CAPI_EXPORT void TpuNetUtil_RecycleUnusedPort(int port); 401 402 // Creates a unique compilation cache `key` used for `put` and `get` operations. 403 // Returned buffers are heap-allocated and must be owned. 404 TFTPU_CAPI_EXPORT CompilationCacheKeyResult 405 TpuCompile_CreateCompilationCacheKey(CompilationCacheKeyProperty property); 406 407 // Destroys the CompilationCacheKeyResult returned by calling the 408 // `TpuCompile_CreateCompilationCacheKey` API. 409 TFTPU_CAPI_EXPORT void TpuCompile_DestroyCompilationCacheKey( 410 CompilationCacheKeyResult result); 411 412 // Creates a guaranteed const fingerprint. Guarantee const is normally used in 413 // TPU inference to avoid re-copying unchanged variables onto the TPU device. 414 // It promises the value is identical for every execution in the same session 415 // even if the actual value changes in later executions. 416 TFTPU_CAPI_EXPORT uint64_t TpuCompile_CreateGuaranteedConstFingerprint( 417 uint64_t fingerprint, const char* data, size_t size); 418 419 XLA_TpuNodeContext* TpuNodeContext_Create(int device_ordinal, 420 TF_Status* status); 421 void TpuNodeContext_Free(XLA_TpuNodeContext* node_context); 422 423 void TpuNodeContext_StopChipHeartbeats(TF_Status* status); 424 425 void TpuNodeContext_CloseTpuHost(TF_Status* status); 426 427 void TpuNodeContext_Initialize(int device_ordinal, TF_Status* status); 428 429 // Globally initialize the TPU system for inference. 430 TFTPU_CAPI_EXPORT void TfTpu_InitializeTpuModelServer(); 431 432 struct TfTpu_OpsApiFn { 433 TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CompileAndBuild); 434 TFTPU_ADD_FN_IN_STRUCT(TpuCompile_XrtCompileAndBuild); 435 436 TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Create); 437 TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Free); 438 TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_MeshCommonState); 439 440 TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Create); 441 TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Destroy); 442 TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Start); 443 TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Stop); 444 TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_CollectData); 445 446 TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_LoadProgramAndEnqueueToStream); 447 TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_HostShapeToDeviceShape); 448 TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSize); 449 TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSizeCompact); 450 TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSizeCompactRaw); 451 TFTPU_ADD_FN_IN_STRUCT(TpuExecute_RuntimeInputToPaddedData); 452 453 TFTPU_ADD_FN_IN_STRUCT(ConfigureDistributedTpuOp_DoWork); 454 TFTPU_ADD_FN_IN_STRUCT(WaitForDistributedTpuOp_DoWork); 455 TFTPU_ADD_FN_IN_STRUCT(InitializeHostForDistributedTpuOp_DoWork); 456 TFTPU_ADD_FN_IN_STRUCT(SetGlobalTPUArrayOp_DoWork); 457 TFTPU_ADD_FN_IN_STRUCT(DisconnectDistributedTpuChipsOp_DoWork); 458 TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeCharArray); 459 TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeInt32Array); 460 TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_HasTPUPodState); 461 TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpusPerHost); 462 TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpuMemoryLimit); 463 TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_RemoteCompilationCacheSizeInBytes); 464 TFTPU_ADD_FN_IN_STRUCT( 465 TpuConfigurationApi_CompilationCacheServerAddressFromConfig); 466 TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_GetServerAddressAndPort); 467 468 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_New); 469 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_Free); 470 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_NewArray); 471 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_FreeArray); 472 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_UnloadAndDestroy); 473 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetProgramSize); 474 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_LogProgramMemorySummary); 475 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetExecutableInfo); 476 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHostTransferInfo); 477 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHloMetadata); 478 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetMayModifyVariables); 479 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_HasSharding); 480 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetTpuProgram); 481 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeTpuExecutable); 482 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeCompilerMetadata); 483 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_DeserializeFromGetTpuProgramResponseProto); 484 485 TFTPU_ADD_FN_IN_STRUCT(TpuCompile_IsTpuCompilationEnabled); 486 TFTPU_ADD_FN_IN_STRUCT(TpuCompile_ShouldTpuCompileOpIgnoreCancellation); 487 TFTPU_ADD_FN_IN_STRUCT(TpuTopology_AvailableCoreCount); 488 TFTPU_ADD_FN_IN_STRUCT(TpuNetUtil_RecycleUnusedPort); 489 TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateCompilationCacheKey); 490 TFTPU_ADD_FN_IN_STRUCT(TpuCompile_DestroyCompilationCacheKey); 491 TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateGuaranteedConstFingerprint); 492 493 TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Create); 494 TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Free); 495 TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_StopChipHeartbeats); 496 TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_CloseTpuHost); 497 TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Initialize); 498 499 TFTPU_ADD_FN_IN_STRUCT(TfTpu_InitializeTpuModelServer); 500 }; 501 502 } // extern "C" 503 504 #endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_KERNELS_C_API_H_ 505