• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_SERVICE_H_
17 #define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_SERVICE_H_
18 
19 #include "absl/synchronization/mutex.h"
20 #include "absl/synchronization/notification.h"
21 #include "absl/time/time.h"
22 #include "tensorflow/compiler/xla/pjrt/distributed/key_value_store.h"
23 #include "tensorflow/compiler/xla/pjrt/distributed/protocol.grpc.pb.h"
24 #include "tensorflow/compiler/xla/statusor.h"
25 #include "tensorflow/compiler/xla/types.h"
26 #include "tensorflow/core/platform/env.h"
27 
28 namespace xla {
29 
// Integer identifier of a node in a distributed job.
using NodeId = int;
31 
32 class DistributedRuntimeServiceImpl final
33     : public grpc::DistributedRuntimeService::Service {
34  public:
35   struct Options {
36     // Number of nodes in the job. Mandatory. Must be non-negative.
37     int num_nodes = -1;
38 
39     tensorflow::Env* env = tensorflow::Env::Default();
40 
41     // Interval at which the service should check for missed heartbeat RPCs
42     // from the clients.
43     absl::Duration heartbeat_interval = absl::Seconds(10);
44 
45     // Number of heartbeats that a client may miss in a row before the
46     // coordinator concludes that a client has vanished.
47     int max_missing_heartbeats = 10;
48 
49     // How long should we wait for all clients to call EnumerateDevices() before
50     // giving up?
51     absl::Duration enumerate_devices_timeout = absl::Seconds(60);
52 
53     // How long should we wait for all clients to call Shutdown() before giving
54     // up and returning a failure?
55     absl::Duration shutdown_timeout = absl::Seconds(60);
56   };
57   explicit DistributedRuntimeServiceImpl(const Options& options);
58   ~DistributedRuntimeServiceImpl() override;
59 
60   DistributedRuntimeServiceImpl(const DistributedRuntimeServiceImpl&) = delete;
61   DistributedRuntimeServiceImpl(DistributedRuntimeServiceImpl&&) = delete;
62   DistributedRuntimeServiceImpl& operator=(
63       const DistributedRuntimeServiceImpl&) = delete;
64   DistributedRuntimeServiceImpl&& operator=(DistributedRuntimeServiceImpl&&) =
65       delete;
66 
67   ::grpc::Status Connect(::grpc::ServerContext* context,
68                          const ConnectRequest* request,
69                          ConnectResponse* response) override;
70 
71   ::grpc::Status Shutdown(::grpc::ServerContext* context,
72                           const ShutdownRequest* request,
73                           ShutdownResponse* response) override;
74 
75   ::grpc::Status Heartbeat(::grpc::ServerContext* context,
76                            const HeartbeatRequest* request,
77                            HeartbeatResponse* response) override;
78 
79   ::grpc::Status EnumerateDevices(::grpc::ServerContext* context,
80                                   const EnumerateDevicesRequest* request,
81                                   EnumerateDevicesResponse* response) override;
82 
83   ::grpc::Status KeyValueGet(::grpc::ServerContext* context,
84                              const KeyValueGetRequest* request,
85                              KeyValueGetResponse* response) override;
86 
87   ::grpc::Status KeyValueSet(::grpc::ServerContext* context,
88                              const KeyValueSetRequest* request,
89                              KeyValueSetResponse* response) override;
90 
91  private:
92   // Entry point for the heartbeat checking thread.
93   void HeartbeatLoop();
94 
95   // Validates a session id number matches the current session id.
96   xla::Status ValidateSessionId(uint64 session_id);
97 
98   // Validates a node id number.
99   xla::Status ValidateNodeId(int node_id);
100 
101   const Options options_;
102   const uint64 session_id_;
103 
104   absl::Mutex mu_;
105   enum class State { kInitializing, kRunning, kClosed };
106   State state_ ABSL_GUARDED_BY(mu_) = State::kInitializing;
107   Status service_status_ ABSL_GUARDED_BY(mu_);
108 
109   // State for Connect() and heartbeats.
110   struct Node {
111     // Have we heard from a task with this ID?
112     bool present = false;
113 
114     // A unique ID belonging to the client. Used to identify the client that
115     // most recently called Connect() with a particular task id.
116     uint64 client_id = 0;
117 
118     // When did we last receive a heartbeat from this task?
119     absl::Time last_heartbeat = absl::InfinitePast();
120   };
121   int num_nodes_present_ ABSL_GUARDED_BY(mu_) = 0;
122   std::vector<Node> nodes_ ABSL_GUARDED_BY(mu_);
123 
124   // State for EnumerateDevices.
125   int num_topologies_present_ ABSL_GUARDED_BY(mu_) = 0;
126   std::vector<LocalTopologyProto> local_topologies_ ABSL_GUARDED_BY(mu_);
127   absl::optional<GlobalTopologyProto> topology_ ABSL_GUARDED_BY(mu_);
128 
129   // State for Shutdown(). Counter of how many nodes are blocked at the
130   // Shutdown() barrier.
131   int num_nodes_shutting_down_ ABSL_GUARDED_BY(mu_) = 0;
132 
133   // Key-value store, used by distributed GPU code to share NCCL state.
134   KeyValueStore key_value_store_;
135 
136   // Notification that tells the heartbeat thread to stop.
137   absl::Notification stop_heartbeat_thread_;
138 
139   // Thread that checks for missing hearbeats from the clients periodically.
140   std::unique_ptr<tensorflow::Thread> heartbeat_thread_;
141 };
142 
143 class DistributedRuntimeService {
144  public:
145   static xla::StatusOr<std::unique_ptr<DistributedRuntimeService>> Get(
146       const std::string& address,
147       std::shared_ptr<::grpc::ServerCredentials> credentials,
148       const DistributedRuntimeServiceImpl::Options& options);
149 
150   explicit DistributedRuntimeService(
151       const DistributedRuntimeServiceImpl::Options& options);
152   ~DistributedRuntimeService();
153 
154   DistributedRuntimeService(const DistributedRuntimeService&) = delete;
155   DistributedRuntimeService(DistributedRuntimeService&&) = delete;
156   DistributedRuntimeService& operator=(const DistributedRuntimeService&) =
157       delete;
158   DistributedRuntimeService& operator=(DistributedRuntimeService&&) = delete;
159 
server()160   ::grpc::Server* server() const { return server_.get(); }
161 
162  private:
163   DistributedRuntimeServiceImpl impl_;
164   std::unique_ptr<::grpc::Server> server_;
165 };
166 
167 // Everything below this point is exposed only for tests.
168 
169 // Given a LocalTopologyProto object from each node, builds a
170 // GlobalTopologyProto that describes all nodes.
171 void BuildGlobalTopology(absl::Span<LocalTopologyProto> local_topologies,
172                          GlobalTopologyProto* global_topology);
173 
174 }  // namespace xla
175 
176 #endif  // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_SERVICE_H_
177