1#if defined(PLATFORM_GOOGLE) 2#include "third_party/tensorflow/core/tpu/tpu_executor_init_fns.inc" 3#else 4#include "tensorflow/core/tpu/tpu_executor_init_fns.inc" 5#endif 6 7namespace { 8 9tensorflow::Status SetTpuOpsStructFns(void* library_handle) { 10 // Constant cast so that we can initialize the functions. The functions are 11 // mutable here because this is the only place where they are initialized. 12 auto* ops_api_fn = const_cast<TfTpu_OpsApiFn*>(tensorflow::tpu::OpsApiFn()); 13 14 TFTPU_SET_FN(ops_api_fn, ConfigureDistributedTpuOp_DoWork); 15 TFTPU_SET_FN(ops_api_fn, WaitForDistributedTpuOp_DoWork); 16 TFTPU_SET_FN(ops_api_fn, InitializeHostForDistributedTpuOp_DoWork); 17 TFTPU_SET_FN(ops_api_fn, SetGlobalTPUArrayOp_DoWork); 18 TFTPU_SET_FN(ops_api_fn, DisconnectDistributedTpuChipsOp_DoWork); 19 TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_FreeCharArray); 20 TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_FreeInt32Array); 21 TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_HasTPUPodState); 22 TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_TpusPerHost); 23 TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_TpuMemoryLimit); 24 TFTPU_SET_FN(ops_api_fn, 25 TpuConfigurationApi_RemoteCompilationCacheSizeInBytes); 26 TFTPU_SET_FN(ops_api_fn, 27 TpuConfigurationApi_CompilationCacheServerAddressFromConfig); 28 TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_GetServerAddressAndPort); 29 30 TFTPU_SET_FN(ops_api_fn, TpuMeshState_Create); 31 TFTPU_SET_FN(ops_api_fn, TpuMeshState_Free); 32 TFTPU_SET_FN(ops_api_fn, TpuMeshState_MeshCommonState); 33 34 TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngineState_Create); 35 TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngineState_Free); 36 TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngineState_GetState); 37 38 TFTPU_SET_FN(ops_api_fn, TpuCompile_CompileAndBuild); 39 TFTPU_SET_FN(ops_api_fn, TpuCompile_XrtCompileAndBuild); 40 41 TFTPU_SET_FN(ops_api_fn, TpuExecutable_LoadProgramAndEnqueueToStream); 42 TFTPU_SET_FN(ops_api_fn, HardwareLayout_HostShapeToDeviceShape); 43 TFTPU_SET_FN(ops_api_fn, HardwareLayout_ShapeSize); 44 TFTPU_SET_FN(ops_api_fn, HardwareLayout_ShapeSizeCompact); 45 TFTPU_SET_FN(ops_api_fn, HardwareLayout_ShapeSizeCompactRaw); 46 47 TFTPU_SET_FN(ops_api_fn, TpuExecute_RuntimeInputToPaddedData); 48 49 TFTPU_SET_FN(ops_api_fn, TpuProgram_New); 50 TFTPU_SET_FN(ops_api_fn, TpuProgram_Free); 51 TFTPU_SET_FN(ops_api_fn, TpuProgram_NewArray); 52 TFTPU_SET_FN(ops_api_fn, TpuProgram_FreeArray); 53 TFTPU_SET_FN(ops_api_fn, TpuProgram_UnloadAndDestroy); 54 TFTPU_SET_FN(ops_api_fn, TpuProgram_GetProgramSize); 55 TFTPU_SET_FN(ops_api_fn, TpuProgram_LogProgramMemorySummary); 56 TFTPU_SET_FN(ops_api_fn, TpuProgram_GetExecutableInfo); 57 TFTPU_SET_FN(ops_api_fn, TpuProgram_GetHostTransferInfo); 58 TFTPU_SET_FN(ops_api_fn, TpuProgram_GetHloMetadata); 59 TFTPU_SET_FN(ops_api_fn, TpuProgram_GetMayModifyVariables); 60 TFTPU_SET_FN(ops_api_fn, TpuProgram_HasSharding); 61 TFTPU_SET_FN(ops_api_fn, TpuProgram_GetTpuProgram); 62 TFTPU_SET_FN(ops_api_fn, TpuProgram_SerializeTpuExecutable); 63 TFTPU_SET_FN(ops_api_fn, TpuProgram_SerializeCompilerMetadata); 64 TFTPU_SET_FN(ops_api_fn, 65 TpuProgram_DeserializeFromGetTpuProgramResponseProto); 66 TFTPU_SET_FN(ops_api_fn, TpuProgram_GetFingerprint); 67 TFTPU_SET_FN(ops_api_fn, TpuProgram_DestroyFingerprint); 68 69 TFTPU_SET_FN(ops_api_fn, TpuNodeContext_Create); 70 TFTPU_SET_FN(ops_api_fn, TpuNodeContext_Free); 71 TFTPU_SET_FN(ops_api_fn, TpuNodeContext_Initialize); 72 TFTPU_SET_FN(ops_api_fn, TpuNodeContext_StopChipHeartbeats); 73 TFTPU_SET_FN(ops_api_fn, TpuNodeContext_CloseTpuHost); 74 TFTPU_SET_FN(ops_api_fn, TpuNodeContext_CompactionSupported); 75 76 TFTPU_SET_FN(ops_api_fn, TpuTopology_AvailableCoreCount); 77 TFTPU_SET_FN(ops_api_fn, TpuNetUtil_RecycleUnusedPort); 78 TFTPU_SET_FN(ops_api_fn, TpuCompile_IsTpuCompilationEnabled); 79 TFTPU_SET_FN(ops_api_fn, TpuCompile_ShouldTpuCompileOpIgnoreCancellation); 80 TFTPU_SET_FN(ops_api_fn, TpuCompile_CreateCompilationCacheKey); 81 TFTPU_SET_FN(ops_api_fn, TpuCompile_DestroyCompilationCacheKey); 82 TFTPU_SET_FN(ops_api_fn, TpuCompile_CreateGuaranteedConstFingerprint); 83 84 TFTPU_SET_FN(ops_api_fn, TpuProfiler_Create); 85 TFTPU_SET_FN(ops_api_fn, TpuProfiler_Destroy); 86 TFTPU_SET_FN(ops_api_fn, TpuProfiler_Start); 87 TFTPU_SET_FN(ops_api_fn, TpuProfiler_Stop); 88 TFTPU_SET_FN(ops_api_fn, TpuProfiler_CollectData); 89 90 TFTPU_SET_FN(ops_api_fn, TfTpu_InitializeTpuModelServer); 91 92 TFTPU_SET_FN(ops_api_fn, TfTpuOrdinalSelector_Create); 93 TFTPU_SET_FN(ops_api_fn, TfTpuOrdinalSelector_Destroy); 94 TFTPU_SET_FN(ops_api_fn, TfTpuOrdinalSelector_GetOrdinal); 95 TFTPU_SET_FN(ops_api_fn, TfTpuOrdinalSelector_DequeueFromCoreSelector); 96 TFTPU_SET_FN(ops_api_fn, TfTpu_GetTpuPartitionedCallParams); 97 98 TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_ExecutePartitioner); 99 TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_ConfigureMemory); 100 TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_CollateMemory); 101 TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_ConfigureHost); 102 TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_ConnectHosts); 103 TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_Finalize); 104 TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_IsInitialized); 105 TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_WriteParameters); 106 TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_ReadParameters); 107 TFTPU_SET_FN(ops_api_fn, TpuEmbeddingTensorBatchFixedState_Create); 108 TFTPU_SET_FN(ops_api_fn, TpuEmbeddingTensorBatchFixedState_Destroy); 109 TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_EnqueueTensorBatch); 110 TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_RecvActivationsComputation); 111 TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_RecvTPUEmbeddingDeduplicationDataComputation); 112 TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_SendTPUEmbeddingGradientsComputation); 113 114 return tensorflow::Status::OK(); 115} 116 117tensorflow::Status InitializeTpuStructFns(void* library_handle) { 118 TF_RETURN_IF_ERROR(SetTpuOpsStructFns(library_handle)); 119 TF_RETURN_IF_ERROR(SetExecutorStructFn(library_handle)); 120 121 return tensorflow::Status::OK(); 122} 123 124} // namespace 125