/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_

#include <stdint.h>

#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {

Status CreateTpuCompilationCache(
    ResourceMgr* rmgr, tpu::TpuCompilationCacheInterface** compilation_cache);
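
// Example (a minimal sketch; assumes an OpKernelContext* `ctx` inside a
// kernel's Compute, and that the caller owns a reference on success):
//
//   tpu::TpuCompilationCacheInterface* cache = nullptr;
//   OP_REQUIRES_OK(ctx, CreateTpuCompilationCache(
//                           ctx->resource_manager(), &cache));
//   core::ScopedUnref unref_cache(cache);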

xla::StatusOr<std::vector<int32_t>> ConstructDevicesPerHost(
    OpKernelContext* ctx);
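
// Example (a sketch): the StatusOr result is typically unwrapped with
// TF_ASSIGN_OR_RETURN inside a helper that itself returns a Status:
//
//   TF_ASSIGN_OR_RETURN(std::vector<int32_t> num_devices_per_host,
//                       ConstructDevicesPerHost(ctx));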

// The ConfigureDistributedTpu op is used to start a TPUDriver from
// TensorFlow. It should be run on a TPU_SYSTEM device and returns the
// connection host:port for the CompilationCacheServer. The
// CompilationCacheServer will remain live until the device's Resource Manager
// is cleared or a ShutdownDistributedTpuOp is run on the same device.
class ConfigureDistributedTpuOp : public OpKernel {
 public:
  explicit ConfigureDistributedTpuOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES(
        ctx, ctx->num_inputs() > 0,
        errors::Internal("_ConfigureDistributedTPU needs at least one input"));
  }
  void Compute(OpKernelContext* ctx) override;
  ~ConfigureDistributedTpuOp() override {}

 private:
  // ConfigureDistributedTpuOp is neither copyable nor movable.
  ConfigureDistributedTpuOp(const ConfigureDistributedTpuOp&) = delete;
  ConfigureDistributedTpuOp& operator=(const ConfigureDistributedTpuOp&) =
      delete;
};
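
// Example (a sketch; the real registration lives in the accompanying .cc
// file, and the device string here is an assumption, though the op name
// appears in the constructor's error message above):
//
//   REGISTER_KERNEL_BUILDER(
//       Name("_ConfigureDistributedTPU").Device("TPU_SYSTEM"),
//       ConfigureDistributedTpuOp);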

// The WaitForDistributedTpuOp op is used to block execution until
// the distributed TPU system has started up. It must be run on
// the same TPU_SYSTEM device that ConfigureDistributedTpuOp was run
// on, after all of the InitializeHostForDistributedTpuOp Ops have
// completed.
class WaitForDistributedTpuOp : public OpKernel {
 public:
  explicit WaitForDistributedTpuOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("startup_timeout_sec", &startup_timeout_sec_));
    OP_REQUIRES(ctx, startup_timeout_sec_ > 0,
                errors::InvalidArgument("startup_timeout_sec ",
                                        startup_timeout_sec_, " must be >0"));
  }
  void Compute(OpKernelContext* ctx) override;
  ~WaitForDistributedTpuOp() override {}

 private:
  // The time, in seconds, to wait for all hosts to start up.
  int startup_timeout_sec_;

  // WaitForDistributedTpuOp is neither copyable nor movable.
  WaitForDistributedTpuOp(const WaitForDistributedTpuOp&) = delete;
  WaitForDistributedTpuOp& operator=(const WaitForDistributedTpuOp&) = delete;
};
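
// A hypothetical sketch of the matching op def (the real one lives in the
// ops registration file; the op name and the default value are assumptions,
// while the attr name follows the constructor above):
//
//   REGISTER_OP("_WaitForDistributedTPU")
//       .Attr("startup_timeout_sec: int = 20")
//       .SetIsStateful();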

// The ShutdownDistributedTpu op is used to stop a running TPUDriver from
// TensorFlow. It should be run on the TPU_SYSTEM device where
// ConfigureDistributedTpuOp was run.
class ShutdownDistributedTpuOp : public OpKernel {
 public:
  explicit ShutdownDistributedTpuOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override;

  ~ShutdownDistributedTpuOp() override {}

 private:
  // ShutdownDistributedTpuOp is neither copyable nor movable.
  ShutdownDistributedTpuOp(const ShutdownDistributedTpuOp&) = delete;
  ShutdownDistributedTpuOp& operator=(const ShutdownDistributedTpuOp&) = delete;
};

// The InitializeHostForDistributedTpu op is used to initialize the
// TPUPlatform on a host in a distributed TPU system. It should be
// run on every host containing TPU devices before any other Ops that use
// TPU are run.
class InitializeHostForDistributedTpuOp : public OpKernel {
 public:
  explicit InitializeHostForDistributedTpuOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
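    // The attr may be missing in older graphs; any lookup error is ignored
    // so enable_whole_mesh_compilations_ keeps its default value of false.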
    ctx->GetAttr("enable_whole_mesh_compilations",
                 &enable_whole_mesh_compilations_)
        .IgnoreError();
  }

  void Compute(OpKernelContext* ctx) override;

  ~InitializeHostForDistributedTpuOp() override {}

 private:
  // InitializeHostForDistributedTpuOp is neither copyable nor movable.
  InitializeHostForDistributedTpuOp(const InitializeHostForDistributedTpuOp&) =
      delete;
  InitializeHostForDistributedTpuOp& operator=(
      const InitializeHostForDistributedTpuOp&) = delete;

  bool enable_whole_mesh_compilations_ = false;
};

// The SetGlobalTPUArray op is used to set the global TPU array (the
// topology of the distributed TPU system) on a host. It should be run on
// every host containing TPU devices before any other Ops that use TPU are
// run.
class SetGlobalTPUArrayOp : public OpKernel {
 public:
  explicit SetGlobalTPUArrayOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override;

  ~SetGlobalTPUArrayOp() override {}

 private:
  // SetGlobalTPUArrayOp is neither copyable nor movable.
  SetGlobalTPUArrayOp(const SetGlobalTPUArrayOp&) = delete;
  SetGlobalTPUArrayOp& operator=(const SetGlobalTPUArrayOp&) = delete;
};

// The DisconnectDistributedTpuChips op is used to disconnect all the chips on a
// host from a running TPUDriver instance. It should be run on every host
// containing TPU devices before the ShutdownDistributedTpuOp is run on
// the TPU_SYSTEM device.
class DisconnectDistributedTpuChipsOp : public OpKernel {
 public:
  explicit DisconnectDistributedTpuChipsOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override;

  ~DisconnectDistributedTpuChipsOp() override {}

 private:
  // DisconnectDistributedTpuChipsOp is neither copyable nor movable.
  DisconnectDistributedTpuChipsOp(const DisconnectDistributedTpuChipsOp&) =
      delete;
  DisconnectDistributedTpuChipsOp& operator=(
      const DisconnectDistributedTpuChipsOp&) = delete;
};
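
// Typical lifecycle, assembled from the comments above (the exact graph
// wiring is a sketch, not prescribed by this header):
//
//   1. ConfigureDistributedTpuOp          on a TPU_SYSTEM device
//   2. InitializeHostForDistributedTpuOp  on every host with TPU devices
//      SetGlobalTPUArrayOp                on every host with TPU devices
//   3. WaitForDistributedTpuOp            on the same TPU_SYSTEM device
//      (run TPU computations)
//   4. DisconnectDistributedTpuChipsOp    on every host with TPU devices
//   5. ShutdownDistributedTpuOp           on the same TPU_SYSTEM device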

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_