• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#if defined(PLATFORM_GOOGLE)
2#include "third_party/tensorflow/core/tpu/tpu_executor_init_fns.inc"
3#else
4#include "tensorflow/core/tpu/tpu_executor_init_fns.inc"
5#endif
6
7namespace {
8
9tensorflow::Status SetTpuOpsStructFns(void* library_handle) {
10  // Constant cast so that we can initialize the functions. The functions are
11  // mutable here because this is the only place where they are initialized.
12  auto* ops_api_fn = const_cast<TfTpu_OpsApiFn*>(tensorflow::tpu::OpsApiFn());
13
14  TFTPU_SET_FN(ops_api_fn, ConfigureDistributedTpuOp_DoWork);
15  TFTPU_SET_FN(ops_api_fn, WaitForDistributedTpuOp_DoWork);
16  TFTPU_SET_FN(ops_api_fn, InitializeHostForDistributedTpuOp_DoWork);
17  TFTPU_SET_FN(ops_api_fn, SetGlobalTPUArrayOp_DoWork);
18  TFTPU_SET_FN(ops_api_fn, DisconnectDistributedTpuChipsOp_DoWork);
19  TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_FreeCharArray);
20  TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_FreeInt32Array);
21  TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_HasTPUPodState);
22  TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_TpusPerHost);
23  TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_TpuMemoryLimit);
24  TFTPU_SET_FN(ops_api_fn,
25               TpuConfigurationApi_RemoteCompilationCacheSizeInBytes);
26  TFTPU_SET_FN(ops_api_fn,
27               TpuConfigurationApi_CompilationCacheServerAddressFromConfig);
28  TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_GetServerAddressAndPort);
29
30  TFTPU_SET_FN(ops_api_fn, TpuMeshState_Create);
31  TFTPU_SET_FN(ops_api_fn, TpuMeshState_Free);
32  TFTPU_SET_FN(ops_api_fn, TpuMeshState_MeshCommonState);
33
34  TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngineState_Create);
35  TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngineState_Free);
36  TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngineState_GetState);
37
38  TFTPU_SET_FN(ops_api_fn, TpuCompile_CompileAndBuild);
39  TFTPU_SET_FN(ops_api_fn, TpuCompile_XrtCompileAndBuild);
40
41  TFTPU_SET_FN(ops_api_fn, TpuExecutable_LoadProgramAndEnqueueToStream);
42  TFTPU_SET_FN(ops_api_fn, HardwareLayout_HostShapeToDeviceShape);
43  TFTPU_SET_FN(ops_api_fn, HardwareLayout_ShapeSize);
44  TFTPU_SET_FN(ops_api_fn, HardwareLayout_ShapeSizeCompact);
45  TFTPU_SET_FN(ops_api_fn, HardwareLayout_ShapeSizeCompactRaw);
46
47  TFTPU_SET_FN(ops_api_fn, TpuExecute_RuntimeInputToPaddedData);
48
49  TFTPU_SET_FN(ops_api_fn, TpuProgram_New);
50  TFTPU_SET_FN(ops_api_fn, TpuProgram_Free);
51  TFTPU_SET_FN(ops_api_fn, TpuProgram_NewArray);
52  TFTPU_SET_FN(ops_api_fn, TpuProgram_FreeArray);
53  TFTPU_SET_FN(ops_api_fn, TpuProgram_UnloadAndDestroy);
54  TFTPU_SET_FN(ops_api_fn, TpuProgram_GetProgramSize);
55  TFTPU_SET_FN(ops_api_fn, TpuProgram_LogProgramMemorySummary);
56  TFTPU_SET_FN(ops_api_fn, TpuProgram_GetExecutableInfo);
57  TFTPU_SET_FN(ops_api_fn, TpuProgram_GetHostTransferInfo);
58  TFTPU_SET_FN(ops_api_fn, TpuProgram_GetHloMetadata);
59  TFTPU_SET_FN(ops_api_fn, TpuProgram_GetMayModifyVariables);
60  TFTPU_SET_FN(ops_api_fn, TpuProgram_HasSharding);
61  TFTPU_SET_FN(ops_api_fn, TpuProgram_GetTpuProgram);
62  TFTPU_SET_FN(ops_api_fn, TpuProgram_SerializeTpuExecutable);
63  TFTPU_SET_FN(ops_api_fn, TpuProgram_SerializeCompilerMetadata);
64  TFTPU_SET_FN(ops_api_fn,
65               TpuProgram_DeserializeFromGetTpuProgramResponseProto);
66  TFTPU_SET_FN(ops_api_fn, TpuProgram_GetFingerprint);
67  TFTPU_SET_FN(ops_api_fn, TpuProgram_DestroyFingerprint);
68
69  TFTPU_SET_FN(ops_api_fn, TpuNodeContext_Create);
70  TFTPU_SET_FN(ops_api_fn, TpuNodeContext_Free);
71  TFTPU_SET_FN(ops_api_fn, TpuNodeContext_Initialize);
72  TFTPU_SET_FN(ops_api_fn, TpuNodeContext_StopChipHeartbeats);
73  TFTPU_SET_FN(ops_api_fn, TpuNodeContext_CloseTpuHost);
74  TFTPU_SET_FN(ops_api_fn, TpuNodeContext_CompactionSupported);
75
76  TFTPU_SET_FN(ops_api_fn, TpuTopology_AvailableCoreCount);
77  TFTPU_SET_FN(ops_api_fn, TpuNetUtil_RecycleUnusedPort);
78  TFTPU_SET_FN(ops_api_fn, TpuCompile_IsTpuCompilationEnabled);
79  TFTPU_SET_FN(ops_api_fn, TpuCompile_ShouldTpuCompileOpIgnoreCancellation);
80  TFTPU_SET_FN(ops_api_fn, TpuCompile_CreateCompilationCacheKey);
81  TFTPU_SET_FN(ops_api_fn, TpuCompile_DestroyCompilationCacheKey);
82  TFTPU_SET_FN(ops_api_fn, TpuCompile_CreateGuaranteedConstFingerprint);
83
84  TFTPU_SET_FN(ops_api_fn, TpuProfiler_Create);
85  TFTPU_SET_FN(ops_api_fn, TpuProfiler_Destroy);
86  TFTPU_SET_FN(ops_api_fn, TpuProfiler_Start);
87  TFTPU_SET_FN(ops_api_fn, TpuProfiler_Stop);
88  TFTPU_SET_FN(ops_api_fn, TpuProfiler_CollectData);
89
90  TFTPU_SET_FN(ops_api_fn, TfTpu_InitializeTpuModelServer);
91
92  TFTPU_SET_FN(ops_api_fn, TfTpuOrdinalSelector_Create);
93  TFTPU_SET_FN(ops_api_fn, TfTpuOrdinalSelector_Destroy);
94  TFTPU_SET_FN(ops_api_fn, TfTpuOrdinalSelector_GetOrdinal);
95  TFTPU_SET_FN(ops_api_fn, TfTpuOrdinalSelector_DequeueFromCoreSelector);
96  TFTPU_SET_FN(ops_api_fn, TfTpu_GetTpuPartitionedCallParams);
97
98  TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_ExecutePartitioner);
99  TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_ConfigureMemory);
100  TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_CollateMemory);
101  TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_ConfigureHost);
102  TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_ConnectHosts);
103  TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_Finalize);
104  TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_IsInitialized);
105  TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_WriteParameters);
106  TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_ReadParameters);
107  TFTPU_SET_FN(ops_api_fn, TpuEmbeddingTensorBatchFixedState_Create);
108  TFTPU_SET_FN(ops_api_fn, TpuEmbeddingTensorBatchFixedState_Destroy);
109  TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_EnqueueTensorBatch);
110  TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_RecvActivationsComputation);
111  TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_RecvTPUEmbeddingDeduplicationDataComputation);
112  TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_SendTPUEmbeddingGradientsComputation);
113
114  return tensorflow::Status::OK();
115}
116
117tensorflow::Status InitializeTpuStructFns(void* library_handle) {
118  TF_RETURN_IF_ERROR(SetTpuOpsStructFns(library_handle));
119  TF_RETURN_IF_ERROR(SetExecutorStructFn(library_handle));
120
121  return tensorflow::Status::OK();
122}
123
124}  // namespace
125