• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_H_
18 
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/distributed_runtime/coordination/coordination_client.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
26 
27 namespace tensorflow {
28 class DeviceAttributes;
29 class ServerDef;
30 class WorkerEnv;
31 
32 // Static registration for coordination service implementations.
// Usage: REGISTER_COORDINATION_SERVICE("service_type", factory_fn);
//
// Expands to a file-local static bool whose initializer registers `factory_fn`
// under `service_type_name` via
// CoordinationServiceInterface::RegisterCoordinationService(). Each expansion
// gets a distinct variable name derived from __COUNTER__, so the macro can be
// used more than once in the same translation unit.
#define REGISTER_COORDINATION_SERVICE(service_type_name, factory_fn)        \
  REGISTER_COORDINATION_SERVICE_UNIQ_HELPER(__COUNTER__, service_type_name, \
                                            factory_fn)
// Indirection layer: macro arguments that are operands of `##` are NOT
// expanded before pasting, so `counter` must pass through one macro that does
// not paste it. This forces __COUNTER__ to expand to its numeric value before
// it reaches the `##` in REGISTER_COORDINATION_SERVICE_UNIQ.
#define REGISTER_COORDINATION_SERVICE_UNIQ_HELPER(counter, service_type_name, \
                                                  factory_fn)                 \
  REGISTER_COORDINATION_SERVICE_UNIQ(counter, service_type_name, factory_fn)
// Emits the registration variable; `counter` is already a number here.
#define REGISTER_COORDINATION_SERVICE_UNIQ(counter, service_type_name,    \
                                           factory_fn)                    \
  static bool static_coordination_service_##counter TF_ATTRIBUTE_UNUSED = \
      []() {                                                              \
        ::tensorflow::CoordinationServiceInterface::                      \
            RegisterCoordinationService(service_type_name,                \
                                        std::move(factory_fn));           \
        return true;                                                      \
      }()
45 
46 // Coordination service is used for controlling and coordinating distributed
47 // execution in a cluster of multiple workers.
48 //
49 // When enabled, the service keeps track of cluster configurations and the state
// of cluster members. TF runtime and libraries can use it to orchestrate
51 // cluster initialization, check the healthiness of workers, and propagate error
52 // messages to the cluster.
53 //
54 // Normally, the service should first Start(), then perform the supported
55 // coordination operations, and finally Stop(). When service runs into error or
56 // SetError() is called, all subsequent operations will be in error state.
57 //
58 // CoordinationServiceInterface defines the service interface for distributed
59 // coordination. One instance of the service should be deployed in a cluster,
60 // handling various requests and stores configuration key-value data for the
61 // tasks. Each task interacts with the service through CoordinationServiceAgent.
62 //
63 // Experimental feature. Not yet implemented in open source.
64 class CoordinationServiceInterface {
65  public:
66   using CoordinationServiceFactory =
67       std::function<std::unique_ptr<CoordinationServiceInterface>(
68           const WorkerEnv* env, const ServerDef& server_def,
69           std::unique_ptr<CoordinationClientCache> cache)>;
70 
71   using StatusOrValueCallback =
72       std::function<void(const StatusOr<std::string>&)>;
73 
~CoordinationServiceInterface()74   virtual ~CoordinationServiceInterface() {}
75 
RegisterCoordinationService(const std::string & service_type_name,CoordinationServiceFactory factory_fn)76   static void RegisterCoordinationService(
77       const std::string& service_type_name,
78       CoordinationServiceFactory factory_fn) {
79     auto factories = GetCoordinationServiceFactories();
80     factories->emplace(service_type_name, factory_fn);
81   }
82 
83   static std::unique_ptr<CoordinationServiceInterface>
EnableCoordinationService(const std::string & service_type,const WorkerEnv * env,const ServerDef & server_def,std::unique_ptr<CoordinationClientCache> cache)84   EnableCoordinationService(const std::string& service_type,
85                             const WorkerEnv* env, const ServerDef& server_def,
86                             std::unique_ptr<CoordinationClientCache> cache) {
87     const auto* factories = GetCoordinationServiceFactories();
88     auto factories_iter = factories->find(service_type);
89     if (factories_iter == factories->end()) {
90       LOG(ERROR) << "No coordination service factory found for service type "
91                  << service_type;
92       return nullptr;
93     }
94     auto service = factories_iter->second(env, server_def, std::move(cache));
95     if (service != nullptr) {
96       *GetCoordinationServiceInstancePtr() = service.get();
97     }
98     return service;
99   }
100 
GetCoordinationServiceInstance()101   static CoordinationServiceInterface* GetCoordinationServiceInstance() {
102     return *GetCoordinationServiceInstancePtr();
103   }
104 
105   // Register a worker to the service.
106   virtual void RegisterWorker(const std::string& job_name, const int task_id,
107                               const uint64 incarnation,
108                               std::vector<DeviceAttributes> devices,
109                               StatusCallback done) = 0;
110 
111   // Wait for all tasks to be up and running. The callback is invoked when all
112   // tasks are up and registered, or some error occurs.
113   virtual void WaitForAllTasks(const std::string& job_name, const int task_id,
114                                StatusCallback done) = 0;
115 
116   // Update the heartbeat timestamp of a task. This should only be invoked on
117   // the leader of the cluster.
118   virtual Status RecordHeartbeat(const std::string& job_name, const int task_id,
119                                  const uint64 incarnation) = 0;
120 
121   // Set a task in error state permanently.
122   virtual Status ReportTaskError(const std::string& job_name, const int task_id,
123                                  Status error) = 0;
124 
125   // Insert a configuration key-value in the coordination service.
126   // For now, a key-value can only be inserted once and cannot be updated.
127   // The key-values are not persisted and will be lost if the leader fails.
128   virtual Status InsertKeyValue(const std::string& key,
129                                 const std::string& value) = 0;
130 
131   // Get a configuration key-value from the coordination service. Block until
132   // the key-value is available.
133   virtual StatusOr<std::string> GetKeyValue(const std::string& key) = 0;
134   // Get a configuration key-value from the coordination service. The `done`
135   // callback is invoked when the key-value becomes available.
136   virtual void GetKeyValueAsync(const std::string& key,
137                                 StatusOrValueCallback done) = 0;
138 
139   // Delete configuration key-value. If key is a directory, recursively clean
140   // up all key-values under the directory.
141   virtual Status DeleteKeyValue(const std::string& key) = 0;
142 
143  private:
144   friend class CoordinationServiceRpcHandler;
145   virtual const std::vector<DeviceAttributes>& ListClusterDevices() = 0;
146 
147   static std::unordered_map<std::string, CoordinationServiceFactory>*
GetCoordinationServiceFactories()148   GetCoordinationServiceFactories() {
149     static auto* coordination_service_factories =
150         new std::unordered_map<std::string, CoordinationServiceFactory>();
151     return coordination_service_factories;
152   }
153 
154   // TODO(haoyuzhang): Remove singleton once we decide on how to access the
155   // coordination service from op kernel.
GetCoordinationServiceInstancePtr()156   static CoordinationServiceInterface** GetCoordinationServiceInstancePtr() {
157     static CoordinationServiceInterface* instance = nullptr;
158     return &instance;
159   }
160 };
161 
162 }  // namespace tensorflow
163 
164 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_H_
165