/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_

#include <stdint.h>

#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {

Status CreateTpuCompilationCache(
    ResourceMgr* rmgr, tpu::TpuCompilationCacheInterface** compilation_cache);

xla::StatusOr<std::vector<int32_t>> ConstructDevicesPerHost(
    OpKernelContext* ctx);

// The ConfigureDistributedTpu op is used to start a TPUDriver from
// TensorFlow. It should be run on a TPU_SYSTEM device and returns the
// connection host:port for the CompilationCacheServer. The
// CompilationCacheServer will remain live until the device's Resource Manager
// is cleared or a ShutdownDistributedTpuOp is run on the same device.
class ConfigureDistributedTpuOp : public OpKernel {
 public:
  explicit ConfigureDistributedTpuOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES(
        ctx, ctx->num_inputs() > 0,
        errors::Internal("_ConfigureDistributedTPU needs at least one input"));
  }
  void Compute(OpKernelContext* ctx) override;
  ~ConfigureDistributedTpuOp() override {}

 private:
  // ConfigureDistributedTpuOp is neither copyable nor movable.
  ConfigureDistributedTpuOp(const ConfigureDistributedTpuOp&) = delete;
  ConfigureDistributedTpuOp& operator=(const ConfigureDistributedTpuOp&) =
      delete;
};

// The WaitForDistributedTpuOp op is used to block execution until
// the distributed TPU system has started up. It must be run on
// the same TPU_SYSTEM device that ConfigureDistributedTpuOp was run
// on, after all of the InitializeHostForDistributedTpuOp Ops have
// completed.
class WaitForDistributedTpuOp : public OpKernel {
 public:
  explicit WaitForDistributedTpuOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("startup_timeout_sec", &startup_timeout_sec_));
    OP_REQUIRES(ctx, startup_timeout_sec_ > 0,
                errors::InvalidArgument("startup_timeout_sec ",
                                        startup_timeout_sec_, " must be >0"));
  }
  void Compute(OpKernelContext* ctx) override;
  ~WaitForDistributedTpuOp() override {}

 private:
  // The time to wait for all hosts to start up.
  int startup_timeout_sec_;

  // WaitForDistributedTpuOp is neither copyable nor movable.
  WaitForDistributedTpuOp(const WaitForDistributedTpuOp&) = delete;
  WaitForDistributedTpuOp& operator=(const WaitForDistributedTpuOp&) = delete;
};
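
// Illustrative only: a minimal sketch of how one of the kernels declared in
// this header is typically bound to a graph op name and device type via
// REGISTER_KERNEL_BUILDER. The actual registrations live in the corresponding
// .cc file; the device string here is an assumption taken from the comments
// above ("should be run on a TPU_SYSTEM device"), and the op name is taken
// from the error message in ConfigureDistributedTpuOp's constructor.
//
//   REGISTER_KERNEL_BUILDER(
//       Name("_ConfigureDistributedTPU").Device("TPU_SYSTEM"),
//       ConfigureDistributedTpuOp);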

// The ShutdownDistributedTpu op is used to stop a running TPUDriver from
// TensorFlow. It should be run on the TPU_SYSTEM device where
// ConfigureDistributedTpuOp was run.
class ShutdownDistributedTpuOp : public OpKernel {
 public:
  explicit ShutdownDistributedTpuOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override;

  ~ShutdownDistributedTpuOp() override {}

 private:
  // ShutdownDistributedTpuOp is neither copyable nor movable.
  ShutdownDistributedTpuOp(const ShutdownDistributedTpuOp&) = delete;
  ShutdownDistributedTpuOp& operator=(const ShutdownDistributedTpuOp&) = delete;
};

// The InitializeHostForDistributedTpu op is used to initialize the
// TPUPlatform on a host in a distributed TPU system. It should be
// run on every host containing TPU devices before any other Ops that use
// TPU are run.
class InitializeHostForDistributedTpuOp : public OpKernel {
 public:
  explicit InitializeHostForDistributedTpuOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    ctx->GetAttr("enable_whole_mesh_compilations",
                 &enable_whole_mesh_compilations_)
        .IgnoreError();
  }

  void Compute(OpKernelContext* ctx) override;

  ~InitializeHostForDistributedTpuOp() override {}

 private:
  // InitializeHostForDistributedTpuOp is neither copyable nor movable.
  InitializeHostForDistributedTpuOp(const InitializeHostForDistributedTpuOp&) =
      delete;
  InitializeHostForDistributedTpuOp& operator=(
      const InitializeHostForDistributedTpuOp&) = delete;

  bool enable_whole_mesh_compilations_ = false;
};

// The SetGlobalTPUArray op is used to distribute the global TPU system
// topology, as computed by ConfigureDistributedTpuOp, to a host in a
// distributed TPU system. It should be run on every host containing TPU
// devices before any other Ops that use TPU are run.
class SetGlobalTPUArrayOp : public OpKernel {
 public:
  explicit SetGlobalTPUArrayOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override;

  ~SetGlobalTPUArrayOp() override {}

 private:
  // SetGlobalTPUArrayOp is neither copyable nor movable.
  SetGlobalTPUArrayOp(const SetGlobalTPUArrayOp&) = delete;
  SetGlobalTPUArrayOp& operator=(const SetGlobalTPUArrayOp&) = delete;
};

// The DisconnectDistributedTpuChips op is used to disconnect all the chips on
// a host from a running TPUDriver instance. It should be run on every host
// containing TPU devices before the ShutdownDistributedTpuOp is run on
// the TPU_SYSTEM.
class DisconnectDistributedTpuChipsOp : public OpKernel {
 public:
  explicit DisconnectDistributedTpuChipsOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override;

  ~DisconnectDistributedTpuChipsOp() override {}

 private:
  // DisconnectDistributedTpuChipsOp is neither copyable nor movable.
  DisconnectDistributedTpuChipsOp(const DisconnectDistributedTpuChipsOp&) =
      delete;
  DisconnectDistributedTpuChipsOp& operator=(
      const DisconnectDistributedTpuChipsOp&) = delete;
};
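
// Illustrative only: a minimal usage sketch for the CreateTpuCompilationCache
// helper declared at the top of this file, e.g. from inside a Compute()
// method. It assumes the returned cache is a ref-counted resource and that
// the caller must drop its reference when done; see the .cc implementation
// for the authoritative contract.
//
//   tpu::TpuCompilationCacheInterface* compilation_cache = nullptr;
//   OP_REQUIRES_OK(ctx, CreateTpuCompilationCache(ctx->resource_manager(),
//                                                 &compilation_cache));
//   // ... use compilation_cache ...
//   compilation_cache->Unref();  // Assumption: the interface is RefCounted.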

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_