/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_KERNELS_C_API_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_KERNELS_C_API_H_

#include <stddef.h>

#include <cstdint>

#include "absl/types/optional.h"
#include "tensorflow/core/tpu/libtftpu.h"
#include "tensorflow/stream_executor/tpu/c_api_decl.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"

typedef struct TpuSerializedProto TpuSerializedProto;

namespace tensorflow {
class TpuMeshCommonState;
}  // namespace tensorflow

extern "C" {

typedef struct XLA_TpuProgram XLA_TpuProgram;

// Enum for choosing the sharding/unsharding program from an `XLA_TpuProgram`
// object.
enum TpuProgramShardingType { kInvalid = 0, kMain, kSharding, kUnsharding };

struct TpuProgramFingerprint {
  const char* bytes;
  size_t size;
};

struct TpuExecutableSerializedProto {
  const char* bytes;
  size_t size;
};

struct CompilerMetadataSerializedProto {
  const char* bytes;
  size_t size;
};

struct HostComputeMetadataSerializedProto {
  const char* bytes;
  size_t size;
};

typedef struct XLA_TpuMeshState XLA_TpuMeshState;

typedef struct TpuProfiler TpuProfiler;

typedef struct XLA_DeviceAssignment {
  const char* bytes;
  size_t size;
} XLA_DeviceAssignment;

// Properties used to create a compilation cache key.
struct CompilationCacheKeyProperty {
  const char* config_prefix;
  const char* shapes_prefix;
  const char* function_name;
  uint64_t mlir_module_fingerprint;
  const int32_t* device_ids;
  size_t device_ids_size;
  int32_t guaranteed_constants_size;
  uint64_t function_library_fingerprint;
  int32_t num_cores_per_replica;
  int32_t num_replicas;
  const XLA_TpuMeshState* mesh_state;
};

// Compilation cache key result containing both the key and a more verbose
// debug version of it.
struct CompilationCacheKeyResult {
  const char* key;
  const char* debug_string;
};

typedef struct XLA_TpuNodeContext XLA_TpuNodeContext;

typedef struct TfTpu_OrdinalSelector TfTpuOrdinalSelector;

struct TpuPartitionedCall_Params {
  bool input_shape_opt;
  bool group_tensors_for_packing;
  int32_t minimum_input_tensors_packing;
  int32_t minimum_output_tensors_packing;

  // Whether to attempt to automatically shard inputs by adding an
  // XlaSharding op after each input.
  bool enable_auto_xla_input_sharding;

  // The dimension of each input to shard if enable_auto_xla_input_sharding is
  // set to true. Negative numbers are allowed and refer to dimensions counted
  // from the end.
  int32_t auto_xla_input_sharding_dim;

  // If true, only create one variable on the TPU for each variable on the CPU.
  bool enable_variable_deduplication;
};
// Compiles an MLIR or TF function computation by lowering it into HLO IR and
// returns `count` TPU programs ready for execution.
// The API allocates the `XLA_TpuProgram*[]` array `tpu_programs` and creates
// the `XLA_TpuProgram` object(s) using the `TpuProgram_New` API. The caller is
// responsible for deallocating both the `XLA_TpuProgram*[]` array and the
// `XLA_TpuProgram` object(s), using the `TpuProgram_FreeArray` and
// `TpuProgram_Free` APIs respectively.
TFTPU_CAPI_EXPORT void TpuCompile_CompileAndBuild(
    TpuSerializedProto compilation_request, const XLA_TpuMeshState* mesh_state,
    XLA_TpuProgram** tpu_programs[], size_t* count, TF_Status* status);
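//
// A minimal usage sketch of the ownership protocol above (illustrative only;
// `request` is a placeholder for a serialized compilation request, and
// `status` is assumed to come from the TF C API, e.g. `TF_NewStatus`):
//
//   XLA_TpuMeshState* mesh_state = TpuMeshState_Create();
//   XLA_TpuProgram** tpu_programs = nullptr;
//   size_t count = 0;
//   TpuCompile_CompileAndBuild(request, mesh_state, &tpu_programs, &count,
//                              status);
//   if (TF_GetCode(status) == TF_OK) {
//     for (size_t i = 0; i < count; ++i) {
//       // ... load and execute tpu_programs[i] ...
//       TpuProgram_Free(tpu_programs[i]);
//     }
//     TpuProgram_FreeArray(tpu_programs);
//   }
//   TpuMeshState_Free(mesh_state);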

// Compiles an HLO IR computation and returns `count` TPU programs ready for
// execution. The API allocates the `XLA_TpuProgram*[]` array `tpu_programs`
// and creates the `XLA_TpuProgram` object(s) using the `TpuProgram_New` API.
// The caller is responsible for deallocating both the `XLA_TpuProgram*[]`
// array and the `XLA_TpuProgram` object(s), using the `TpuProgram_FreeArray`
// and `TpuProgram_Free` APIs respectively.
TFTPU_CAPI_EXPORT void TpuCompile_XrtCompileAndBuild(
    TpuSerializedProto xrt_computation, const XLA_TpuMeshState* mesh_state,
    XLA_TpuProgram** tpu_programs[], size_t* count, TF_Status* status);

// Creates a TPU profiler that is ready to start profiling.
TFTPU_CAPI_EXPORT void TpuProfiler_Create(TpuProfiler** tpu_profiler,
                                          TF_Status* status);
// Destroys the given TPU profiler.
TFTPU_CAPI_EXPORT void TpuProfiler_Destroy(TpuProfiler* tpu_profiler);
// Starts profiling if not already started; returns an error otherwise.
TFTPU_CAPI_EXPORT void TpuProfiler_Start(TpuProfiler* tpu_profiler,
                                         TF_Status* status);
// Stops profiling if not already stopped; returns an error otherwise.
TFTPU_CAPI_EXPORT void TpuProfiler_Stop(TpuProfiler* tpu_profiler,
                                        TF_Status* status);
// Serializes profiled data into `buffer` and returns the size of `buffer`. The
// profile data held by the TPU driver will be cleared after retrieval.
//
// Step 1. Query the required buffer size into `size_in_bytes`.
//
//   size_t size_in_bytes;
//   TpuProfiler_CollectData(profiler, status, nullptr, &size_in_bytes);
//
// Step 2. Retrieve the data into a `buffer` of size `size_in_bytes`.
//         Subsequently, the TPU driver clears its copy of the profile data.
//
//   uint8_t* buffer = new uint8_t[size_in_bytes];
//   TpuProfiler_CollectData(profiler, status, buffer, &size_in_bytes);
//
// Step 3. Unpack the data into an XSpace.
//
//   tensorflow::profiler::XSpace space;
//   space.ParseFromArray(buffer, size_in_bytes);
//
TFTPU_CAPI_EXPORT void TpuProfiler_CollectData(TpuProfiler* tpu_profiler,
                                               TF_Status* status,
                                               uint8_t* buffer,
                                               size_t* size_in_bytes);
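//
// A minimal profiling-session sketch for the calls above (illustrative only;
// error checking between the calls is elided and `status` is assumed to be a
// valid TF_Status*):
//
//   TpuProfiler* profiler = nullptr;
//   TpuProfiler_Create(&profiler, status);
//   TpuProfiler_Start(profiler, status);
//   // ... run the workload to be profiled ...
//   TpuProfiler_Stop(profiler, status);
//   // Collect the data with the two-step protocol documented above.
//   TpuProfiler_Destroy(profiler);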

// Creates a new TPU mesh state object.
TFTPU_CAPI_EXPORT XLA_TpuMeshState* TpuMeshState_Create();

// Deletes the given TPU `mesh_state` object. Once deleted, the object is
// unusable.
TFTPU_CAPI_EXPORT void TpuMeshState_Free(XLA_TpuMeshState* mesh_state);

// Returns a pointer to an opaque mesh data structure used internally.
TFTPU_CAPI_EXPORT void* TpuMeshState_MeshCommonState(
    XLA_TpuMeshState* mesh_state);

TFTPU_CAPI_EXPORT void TfTpuOrdinalSelector_Create(
    TfTpuOrdinalSelector** ordinal_selector, int num_cores_per_replica);

TFTPU_CAPI_EXPORT void TfTpuOrdinalSelector_Destroy(
    TfTpuOrdinalSelector* ordinal_selector);

TFTPU_CAPI_EXPORT void TfTpuOrdinalSelector_GetOrdinal(
    TfTpuOrdinalSelector* ordinal_selector, absl::optional<uint64_t> key,
    int64_t* req_id, int64_t* ordinal);

TFTPU_CAPI_EXPORT void TfTpuOrdinalSelector_DequeueFromCoreSelector(
    TfTpuOrdinalSelector* ordinal_selector, int32_t device_ordinal,
    int64_t req_id);
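//
// A minimal round-trip sketch for the ordinal selector calls above
// (illustrative only; the num_cores_per_replica value and the absl::nullopt
// key are placeholders):
//
//   TfTpuOrdinalSelector* selector = nullptr;
//   TfTpuOrdinalSelector_Create(&selector, /*num_cores_per_replica=*/1);
//   int64_t req_id = -1;
//   int64_t ordinal = -1;
//   TfTpuOrdinalSelector_GetOrdinal(selector, absl::nullopt, &req_id,
//                                   &ordinal);
//   // ... dispatch work to the TPU core identified by `ordinal` ...
//   TfTpuOrdinalSelector_DequeueFromCoreSelector(
//       selector, static_cast<int32_t>(ordinal), req_id);
//   TfTpuOrdinalSelector_Destroy(selector);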

TFTPU_CAPI_EXPORT void TfTpu_GetTpuPartitionedCallParams(
    TpuPartitionedCall_Params* params);

typedef struct TpuExecutable_LoadProgramAndEnqueueToStream_Params {
  int32_t struct_size;
  void* priv;

  const XLA_TpuProgram* program;
  SE_DeviceMemoryBase* arguments;
  size_t arguments_len;
  SE_DeviceMemoryBase* result;
  bool has_cross_program_prefetch_addr;
  SE_DeviceMemoryBase* cross_program_prefetch_addr;
  int32_t rng_seed;
  XLA_DeviceAssignment* device_assignment;
  SE_Stream* stream;

  TF_Status* status;  // out
} TpuExecutable_LoadProgramAndEnqueueToStream_Params;

#define TpuExecutable_LoadProgramAndEnqueueToStream_Params_SIZE \
  (sizeof(struct TpuExecutable_LoadProgramAndEnqueueToStream_Params))

TFTPU_CAPI_EXPORT void TpuExecutable_LoadProgramAndEnqueueToStream(
    TpuExecutable_LoadProgramAndEnqueueToStream_Params* params);
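//
// The `*_Params` structs in this header share a common shape: `struct_size`,
// `priv`, the input fields, and then fields marked `// out`. A plausible
// calling sketch, assuming the caller zero-initializes the struct and sets
// `struct_size` to the matching `*_Params_SIZE` macro (illustrative only;
// `program`, `arguments`, `arguments_len`, `result`, `stream`, and `status`
// are placeholders set up elsewhere):
//
//   TpuExecutable_LoadProgramAndEnqueueToStream_Params params = {};
//   params.struct_size =
//       TpuExecutable_LoadProgramAndEnqueueToStream_Params_SIZE;
//   params.program = program;
//   params.arguments = arguments;
//   params.arguments_len = arguments_len;
//   params.result = result;
//   params.stream = stream;
//   params.status = status;
//   // ... set the remaining fields as needed ...
//   TpuExecutable_LoadProgramAndEnqueueToStream(&params);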

TFTPU_CAPI_EXPORT void HardwareLayout_HostShapeToDeviceShape(
    XLA_Shape* host_shape, XLA_Shape* device_shape);
TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSize(XLA_Shape* shape);
TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompact(XLA_Shape* shape);
TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompactRaw(XLA_Shape* shape);

typedef struct TpuExecute_RuntimeInputToPaddedData_Params {
  int32_t struct_size;
  void* priv;

  uint32_t* runtime_input_ptr;
  size_t runtime_input_size;
  int8_t* padded_data_ptr;
  size_t padded_data_size;
  XLA_Shape* runtime_shape;
  XLA_Shape* compile_time_shape;

  TF_Status* status;  // out
} TpuExecute_RuntimeInputToPaddedData_Params;

#define TpuExecute_RuntimeInputToPaddedData_Params_SIZE \
  (sizeof(struct TpuExecute_RuntimeInputToPaddedData_Params))

TFTPU_CAPI_EXPORT void TpuExecute_RuntimeInputToPaddedData(
    TpuExecute_RuntimeInputToPaddedData_Params* params);

typedef struct ConfigureDistributedTpuOp_DoWork_Params {
  int32_t struct_size;
  void* priv;

  size_t num_cores_per_host_size;
  const int32_t* num_cores_per_host;
  size_t server_address_size;
  const char* server_address;

  size_t* host_config_output_size;  // out
  char** host_config_output;        // out
  TF_Status* status;                // out
} ConfigureDistributedTpuOp_DoWork_Params;

#define ConfigureDistributedTpuOp_DoWork_Params_SIZE \
  (sizeof(struct ConfigureDistributedTpuOp_DoWork_Params))

TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork(
    ConfigureDistributedTpuOp_DoWork_Params* params);

typedef struct WaitForDistributedTpuOp_DoWork_Params {
  int32_t struct_size;
  void* priv;

  size_t num_hosts;
  size_t num_cores_per_host;
  const int32_t** host_ordinal_to_global_core_id_map;
  tensorflow::TpuMeshCommonState* tpu_mesh_common_state;

  size_t* tpu_topology_output_size;  // out
  char** tpu_topology_output;        // out
  TF_Status* status;                 // out
} WaitForDistributedTpuOp_DoWork_Params;

#define WaitForDistributedTpuOp_DoWork_Params_SIZE \
  (sizeof(struct WaitForDistributedTpuOp_DoWork_Params))

TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork(
    WaitForDistributedTpuOp_DoWork_Params* params);

typedef struct InitializeHostForDistributedTpuOp_DoWork_Params {
  int32_t struct_size;
  void* priv;

  size_t tpu_host_config_size;
  const char* tpu_host_config;
  bool enable_whole_mesh_compilations;
  bool is_master_worker;

  size_t* core_id_output_size;  // out
  int32_t** core_id_output;     // out
  TF_Status* status;            // out
} InitializeHostForDistributedTpuOp_DoWork_Params;

#define InitializeHostForDistributedTpuOp_DoWork_Params_SIZE \
  (sizeof(struct InitializeHostForDistributedTpuOp_DoWork_Params))

TFTPU_CAPI_EXPORT void InitializeHostForDistributedTpuOp_DoWork(
    InitializeHostForDistributedTpuOp_DoWork_Params* params);

TFTPU_CAPI_EXPORT void SetGlobalTPUArrayOp_DoWork(
    const size_t tpu_topology_size, const char* tpu_topology,
    TF_Status* status);

TFTPU_CAPI_EXPORT void DisconnectDistributedTpuChipsOp_DoWork(
    int32_t* number_of_chips_output, TF_Status* status);

TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeCharArray(char* output);
TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeInt32Array(int32_t* output);
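//
// A sketch of how the configuration ops' `// out` buffers are presumably
// released (illustrative only; it assumes the char* / int32_t* outputs of the
// `*_DoWork` entry points above are handed to the caller and returned to the
// library through these two helpers):
//
//   size_t host_config_size = 0;
//   char* host_config = nullptr;
//   ConfigureDistributedTpuOp_DoWork_Params params = {};
//   params.struct_size = ConfigureDistributedTpuOp_DoWork_Params_SIZE;
//   params.host_config_output_size = &host_config_size;
//   params.host_config_output = &host_config;
//   params.status = status;
//   // ... fill in the input fields, then call ...
//   ConfigureDistributedTpuOp_DoWork(&params);
//   // ... consume host_config ...
//   TpuConfigurationApi_FreeCharArray(host_config);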

TFTPU_CAPI_EXPORT bool TpuConfigurationApi_HasTPUPodState();

TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpusPerHost(int32_t* tpus,
                                                       TF_Status* status);
TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpuMemoryLimit(int64_t* memory_limit,
                                                          TF_Status* status);
TFTPU_CAPI_EXPORT void TpuConfigurationApi_RemoteCompilationCacheSizeInBytes(
    int64_t* cache_size_in_bytes);

typedef struct TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params {
  int32_t struct_size;
  void* priv;

  size_t tpu_host_config_size;
  const char* tpu_host_config;

  size_t* server_address_output_size;  // out
  char** server_address_output;        // out
  TF_Status* status;                   // out
} TpuConfigurationApi_CompilationCacheServerAddressFromConfig_Params;

#define TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params_SIZE \
  (sizeof(                                                                   \
      struct TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params))

TFTPU_CAPI_EXPORT
void TpuConfigurationApi_CompilationCacheServerAddressFromConfig(
    TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params* params);

typedef struct TpuConfigurationApi_GetServerAddressAndPort_Params {
  int32_t struct_size;
  void* priv;

  size_t* server_address_output_size;  // out
  char** server_address_output;        // out
  int* port_output;                    // out
  TF_Status* status;                   // out
} TpuConfigurationApi_GetServerAddressAndPort_Params;

#define TpuConfigurationApi_GetServerAddressAndPort_Params_SIZE \
  (sizeof(struct TpuConfigurationApi_GetServerAddressAndPort_Params))

TFTPU_CAPI_EXPORT void TpuConfigurationApi_GetServerAddressAndPort(
    TpuConfigurationApi_GetServerAddressAndPort_Params* params);

// Creates a new TPU program.
TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_New();

// Destroys the `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_Free(XLA_TpuProgram* tpu_program);

// Creates an array of `XLA_TpuProgram*`.
TFTPU_CAPI_EXPORT XLA_TpuProgram** TpuProgram_NewArray(size_t count);

// Destroys an array of `XLA_TpuProgram*`.
TFTPU_CAPI_EXPORT void TpuProgram_FreeArray(XLA_TpuProgram* tpu_program[]);

// Unloads and destroys the `tpu_program`. Once the TPU program is unloaded and
// destroyed, it is in an unusable state.
TFTPU_CAPI_EXPORT void TpuProgram_UnloadAndDestroy(XLA_TpuProgram* tpu_program,
                                                   TF_Status* status);

// Gets the TPU program size in bytes from the `tpu_program`.
TFTPU_CAPI_EXPORT int64_t
TpuProgram_GetProgramSize(const XLA_TpuProgram* tpu_program);

// Logs a summary of the current memory state snapshot of the `tpu_program`.
TFTPU_CAPI_EXPORT bool TpuProgram_LogProgramMemorySummary(
    const XLA_TpuProgram* tpu_program);

// Gets the TPU program executable info from the `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_GetExecutableInfo(
    const XLA_TpuProgram* tpu_program, TpuSerializedProto* executable_info,
    TF_Status* status);

// Gets the host transfer info proto.
TFTPU_CAPI_EXPORT void TpuProgram_GetHostTransferInfo(
    const XLA_TpuProgram* tpu_program, TpuSerializedProto* host_transfer_info,
    TF_Status* status);

// Gets the HLO metadata proto.
TFTPU_CAPI_EXPORT void TpuProgram_GetHloMetadata(
    const XLA_TpuProgram* tpu_program, TpuSerializedProto* hlo_metadata,
    TF_Status* status);

// Gets the `may_modify_variables` boolean value.
TFTPU_CAPI_EXPORT void TpuProgram_GetMayModifyVariables(
    const XLA_TpuProgram* tpu_program, bool* may_modify_variables);

// Checks whether the TPU program has sharding.
TFTPU_CAPI_EXPORT bool TpuProgram_HasSharding(
    const XLA_TpuProgram* tpu_program);

// Gets the TPU program for the given sharding `type`. The return value is
// valid only if the call succeeds.
TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_GetTpuProgram(
    XLA_TpuProgram* tpu_program, TpuProgramShardingType type);
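//
// A sketch of selecting a sharded program (illustrative only; whether the
// returned pointer must be released separately from its parent program is not
// specified by this header and is therefore not shown):
//
//   if (TpuProgram_HasSharding(tpu_program)) {
//     XLA_TpuProgram* sharding_program =
//         TpuProgram_GetTpuProgram(tpu_program, kSharding);
//     // ... use sharding_program ...
//   }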

// Gets the TPU executable proto from a `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_SerializeTpuExecutable(
    const XLA_TpuProgram* tpu_program, TpuExecutableSerializedProto* executable,
    TF_Status* status);

// Gets the compilation metadata proto from a `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_SerializeCompilerMetadata(
    const XLA_TpuProgram* tpu_program,
    CompilerMetadataSerializedProto* compiler_metadata, TF_Status* status);

// Deserializes the `GetTpuProgramResponse` proto into an `XLA_TpuProgram`.
TFTPU_CAPI_EXPORT void TpuProgram_DeserializeFromGetTpuProgramResponseProto(
    TpuSerializedProto get_tpu_program_response, XLA_TpuProgram* tpu_program,
    TF_Status* status);

TFTPU_CAPI_EXPORT TpuProgramFingerprint
TpuProgram_GetFingerprint(const XLA_TpuProgram* tpu_program);

TFTPU_CAPI_EXPORT void TpuProgram_DestroyFingerprint(
    TpuProgramFingerprint fingerprint);

// Checks whether TPU compilation is enabled.
TFTPU_CAPI_EXPORT bool TpuCompile_IsTpuCompilationEnabled();

// XLA compilation cannot be cancelled. To avoid hanging, the TF worker will
// exit when cancellation is requested for an XLA compile op. Some tests
// require this behavior to be disabled, and we test for this condition with
// the following flag function.
TFTPU_CAPI_EXPORT bool TpuCompile_ShouldTpuCompileOpIgnoreCancellation();

// Returns the number of available TPU cores.
TFTPU_CAPI_EXPORT int TpuTopology_AvailableCoreCount(
    const XLA_TpuMeshState* mesh_state, TpuCoreTypeEnum tpu_core_type);

// Recycles an unused service port.
TFTPU_CAPI_EXPORT void TpuNetUtil_RecycleUnusedPort(int port);

// Creates a unique compilation cache `key` used for `put` and `get` operations.
// The returned buffers are heap-allocated and ownership passes to the caller
// (see `TpuCompile_DestroyCompilationCacheKey` below).
TFTPU_CAPI_EXPORT CompilationCacheKeyResult
TpuCompile_CreateCompilationCacheKey(CompilationCacheKeyProperty property);

// Destroys the CompilationCacheKeyResult returned by calling the
// `TpuCompile_CreateCompilationCacheKey` API.
TFTPU_CAPI_EXPORT void TpuCompile_DestroyCompilationCacheKey(
    CompilationCacheKeyResult result);
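//
// A minimal create/use/destroy sketch (illustrative only; the
// CompilationCacheKeyProperty fields are assumed to have been populated from
// the compilation request):
//
//   CompilationCacheKeyProperty property = {};
//   // ... populate the fields of `property` ...
//   CompilationCacheKeyResult result =
//       TpuCompile_CreateCompilationCacheKey(property);
//   // ... use result.key for cache lookups, result.debug_string for logging ...
//   TpuCompile_DestroyCompilationCacheKey(result);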

// Creates a guaranteed const fingerprint. Guaranteed const is normally used in
// TPU inference to avoid re-copying unchanged variables onto the TPU device.
// It promises that the value is identical for every execution in the same
// session, even if the actual value changes in later executions.
TFTPU_CAPI_EXPORT uint64_t TpuCompile_CreateGuaranteedConstFingerprint(
    uint64_t fingerprint, const char* data, size_t size);
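//
// Because the function both takes and returns a fingerprint, it can be chained
// to fold several buffers into one value. A sketch (illustrative only; the
// seed value 0 and the `buffers` container are placeholders):
//
//   uint64_t fingerprint = 0;
//   for (const auto& buf : buffers) {
//     fingerprint = TpuCompile_CreateGuaranteedConstFingerprint(
//         fingerprint, buf.data(), buf.size());
//   }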

XLA_TpuNodeContext* TpuNodeContext_Create(int device_ordinal,
                                          TF_Status* status);
void TpuNodeContext_Free(XLA_TpuNodeContext* node_context);

void TpuNodeContext_StopChipHeartbeats(TF_Status* status);

void TpuNodeContext_CloseTpuHost(TF_Status* status);

void TpuNodeContext_Initialize(int device_ordinal, TF_Status* status);

bool TpuNodeContext_CompactionSupported(int device_ordinal);

// Globally initializes the TPU system for inference.
TFTPU_CAPI_EXPORT void TfTpu_InitializeTpuModelServer();

struct TfTpu_OpsApiFn {
  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CompileAndBuild);
  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_XrtCompileAndBuild);

  TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Create);
  TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Free);
  TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_MeshCommonState);

  TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Create);
  TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Destroy);
  TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Start);
  TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Stop);
  TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_CollectData);

  TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_LoadProgramAndEnqueueToStream);
  TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_HostShapeToDeviceShape);
  TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSize);
  TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSizeCompact);
  TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSizeCompactRaw);

  TFTPU_ADD_FN_IN_STRUCT(TpuExecute_RuntimeInputToPaddedData);

  TFTPU_ADD_FN_IN_STRUCT(ConfigureDistributedTpuOp_DoWork);
  TFTPU_ADD_FN_IN_STRUCT(WaitForDistributedTpuOp_DoWork);
  TFTPU_ADD_FN_IN_STRUCT(InitializeHostForDistributedTpuOp_DoWork);
  TFTPU_ADD_FN_IN_STRUCT(SetGlobalTPUArrayOp_DoWork);
  TFTPU_ADD_FN_IN_STRUCT(DisconnectDistributedTpuChipsOp_DoWork);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeCharArray);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeInt32Array);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_HasTPUPodState);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpusPerHost);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpuMemoryLimit);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_RemoteCompilationCacheSizeInBytes);
  TFTPU_ADD_FN_IN_STRUCT(
      TpuConfigurationApi_CompilationCacheServerAddressFromConfig);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_GetServerAddressAndPort);

  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_New);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_Free);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_NewArray);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_FreeArray);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_UnloadAndDestroy);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetProgramSize);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_LogProgramMemorySummary);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetExecutableInfo);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHostTransferInfo);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHloMetadata);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetMayModifyVariables);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_HasSharding);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetTpuProgram);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeTpuExecutable);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeCompilerMetadata);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_DeserializeFromGetTpuProgramResponseProto);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetFingerprint);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_DestroyFingerprint);

  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_IsTpuCompilationEnabled);
  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_ShouldTpuCompileOpIgnoreCancellation);
  TFTPU_ADD_FN_IN_STRUCT(TpuTopology_AvailableCoreCount);
  TFTPU_ADD_FN_IN_STRUCT(TpuNetUtil_RecycleUnusedPort);
  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateCompilationCacheKey);
  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_DestroyCompilationCacheKey);
  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateGuaranteedConstFingerprint);

  TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Create);
  TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Free);
  TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_StopChipHeartbeats);
  TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_CloseTpuHost);
  TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Initialize);
  TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_CompactionSupported);

  TFTPU_ADD_FN_IN_STRUCT(TfTpu_InitializeTpuModelServer);

  TFTPU_ADD_FN_IN_STRUCT(TfTpuOrdinalSelector_Create);
  TFTPU_ADD_FN_IN_STRUCT(TfTpuOrdinalSelector_Destroy);
  TFTPU_ADD_FN_IN_STRUCT(TfTpuOrdinalSelector_GetOrdinal);
  TFTPU_ADD_FN_IN_STRUCT(TfTpuOrdinalSelector_DequeueFromCoreSelector);
  TFTPU_ADD_FN_IN_STRUCT(TfTpu_GetTpuPartitionedCallParams);
};

}  // extern "C"

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_KERNELS_C_API_H_