1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_H_ 18 19 #include <functional> 20 #include <string> 21 #include <utility> 22 23 #include "tensorflow/core/distributed_runtime/coordination/coordination_client.h" 24 #include "tensorflow/core/platform/status.h" 25 #include "tensorflow/core/platform/statusor.h" 26 27 namespace tensorflow { 28 class DeviceAttributes; 29 class ServerDef; 30 class WorkerEnv; 31 32 // Static registration for coordination service implementations. 33 #define REGISTER_COORDINATION_SERVICE(service_type_name, factory_fn) \ 34 REGISTER_COORDINATION_SERVICE_UNIQ_HELPER(__COUNTER__, service_type_name, \ 35 factory_fn) 36 #define REGISTER_COORDINATION_SERVICE_UNIQ_HELPER(counter, service_type_name, \ 37 factory_fn) \ 38 static bool static_coordination_service_##counter TF_ATTRIBUTE_UNUSED = \ 39 []() { \ 40 ::tensorflow::CoordinationServiceInterface:: \ 41 RegisterCoordinationService(service_type_name, \ 42 std::move(factory_fn)); \ 43 return true; \ 44 }() 45 46 // Coordination service is used for controlling and coordinating distributed 47 // execution in a cluster of multiple workers. 48 // 49 // When enabled, the service keeps track of cluster configurations and the state 50 // of cluster members. TF runtime and libraries can use it to orchastrate 51 // cluster initialization, check the healthiness of workers, and propagate error 52 // messages to the cluster. 53 // 54 // Normally, the service should first Start(), then perform the supported 55 // coordination operations, and finally Stop(). When service runs into error or 56 // SetError() is called, all subsequent operations will be in error state. 57 // 58 // CoordinationServiceInterface defines the service interface for distributed 59 // coordination. One instance of the service should be deployed in a cluster, 60 // handling various requests and stores configuration key-value data for the 61 // tasks. Each task interacts with the service through CoordinationServiceAgent. 62 // 63 // Experimental feature. Not yet implemented in open source. 64 class CoordinationServiceInterface { 65 public: 66 using CoordinationServiceFactory = 67 std::function<std::unique_ptr<CoordinationServiceInterface>( 68 const WorkerEnv* env, const ServerDef& server_def, 69 std::unique_ptr<CoordinationClientCache> cache)>; 70 71 using StatusOrValueCallback = 72 std::function<void(const StatusOr<std::string>&)>; 73 ~CoordinationServiceInterface()74 virtual ~CoordinationServiceInterface() {} 75 RegisterCoordinationService(const std::string & service_type_name,CoordinationServiceFactory factory_fn)76 static void RegisterCoordinationService( 77 const std::string& service_type_name, 78 CoordinationServiceFactory factory_fn) { 79 auto factories = GetCoordinationServiceFactories(); 80 factories->emplace(service_type_name, factory_fn); 81 } 82 83 static std::unique_ptr<CoordinationServiceInterface> EnableCoordinationService(const std::string & service_type,const WorkerEnv * env,const ServerDef & server_def,std::unique_ptr<CoordinationClientCache> cache)84 EnableCoordinationService(const std::string& service_type, 85 const WorkerEnv* env, const ServerDef& server_def, 86 std::unique_ptr<CoordinationClientCache> cache) { 87 const auto* factories = GetCoordinationServiceFactories(); 88 auto factories_iter = factories->find(service_type); 89 if (factories_iter == factories->end()) { 90 LOG(ERROR) << "No coordination service factory found for service type " 91 << service_type; 92 return nullptr; 93 } 94 auto service = factories_iter->second(env, server_def, std::move(cache)); 95 if (service != nullptr) { 96 *GetCoordinationServiceInstancePtr() = service.get(); 97 } 98 return service; 99 } 100 GetCoordinationServiceInstance()101 static CoordinationServiceInterface* GetCoordinationServiceInstance() { 102 return *GetCoordinationServiceInstancePtr(); 103 } 104 105 // Register a worker to the service. 106 virtual void RegisterWorker(const std::string& job_name, const int task_id, 107 const uint64 incarnation, 108 std::vector<DeviceAttributes> devices, 109 StatusCallback done) = 0; 110 111 // Wait for all tasks to be up and running. The callback is invoked when all 112 // tasks are up and registered, or some error occurs. 113 virtual void WaitForAllTasks(const std::string& job_name, const int task_id, 114 StatusCallback done) = 0; 115 116 // Update the heartbeat timestamp of a task. This should only be invoked on 117 // the leader of the cluster. 118 virtual Status RecordHeartbeat(const std::string& job_name, const int task_id, 119 const uint64 incarnation) = 0; 120 121 // Set a task in error state permanently. 122 virtual Status ReportTaskError(const std::string& job_name, const int task_id, 123 Status error) = 0; 124 125 // Insert a configuration key-value in the coordination service. 126 // For now, a key-value can only be inserted once and cannot be updated. 127 // The key-values are not persisted and will be lost if the leader fails. 128 virtual Status InsertKeyValue(const std::string& key, 129 const std::string& value) = 0; 130 131 // Get a configuration key-value from the coordination service. Block until 132 // the key-value is available. 133 virtual StatusOr<std::string> GetKeyValue(const std::string& key) = 0; 134 // Get a configuration key-value from the coordination service. The `done` 135 // callback is invoked when the key-value becomes available. 136 virtual void GetKeyValueAsync(const std::string& key, 137 StatusOrValueCallback done) = 0; 138 139 // Delete configuration key-value. If key is a directory, recursively clean 140 // up all key-values under the directory. 141 virtual Status DeleteKeyValue(const std::string& key) = 0; 142 143 private: 144 friend class CoordinationServiceRpcHandler; 145 virtual const std::vector<DeviceAttributes>& ListClusterDevices() = 0; 146 147 static std::unordered_map<std::string, CoordinationServiceFactory>* GetCoordinationServiceFactories()148 GetCoordinationServiceFactories() { 149 static auto* coordination_service_factories = 150 new std::unordered_map<std::string, CoordinationServiceFactory>(); 151 return coordination_service_factories; 152 } 153 154 // TODO(haoyuzhang): Remove singleton once we decide on how to access the 155 // coordination service from op kernel. GetCoordinationServiceInstancePtr()156 static CoordinationServiceInterface** GetCoordinationServiceInstancePtr() { 157 static CoordinationServiceInterface* instance = nullptr; 158 return &instance; 159 } 160 }; 161 162 } // namespace tensorflow 163 164 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_H_ 165