/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_AGENT_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_AGENT_H_

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/distributed_runtime/coordination/coordination_client.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"

namespace tensorflow {
class DeviceAttributes;
class WorkerEnv;
class ServerDef;

// CoordinationServiceAgent defines the interface for tasks to communicate with
// the coordination service instance (which implements
// CoordinationServiceInterface). One instance of the agent should be deployed
// on each task for it to send various requests and stores / retrieves config
// key-value data to the service.
//
// See CoordinationServiceInterface for more details on coordination service.
//
// Experimental feature. Not yet implemented in open source.
41 class CoordinationServiceAgent { 42 public: 43 using StatusOrValueCallback = 44 std::function<void(const StatusOr<std::string>&)>; 45 using ChangedKeyValuesCallback = 46 std::function<void(const std::map<std::string, std::string>&)>; 47 ~CoordinationServiceAgent()48 virtual ~CoordinationServiceAgent() {} 49 50 // Initialize coordination service agent. 51 virtual Status Initialize( 52 const WorkerEnv* worker_env, const ServerDef& server_def, 53 std::unique_ptr<CoordinationClientCache> client_cache, 54 StatusCallback error_fn) = 0; 55 // Return true if the coordination service agent has been initialized. 56 virtual bool IsInitialized() = 0; 57 58 // Connect to coordination service with the following steps: 59 // - connect to service address specified in the config of `server_def` 60 // - register itself as a worker to the service 61 // - start a thread to periodically send heartbeat message with the service 62 virtual Status Connect() = 0; 63 64 // Wait for all tasks to be up and registered. The call blocks until all tasks 65 // in the cluster are up, or some error occurs. 66 virtual Status WaitForAllTasks() = 0; 67 68 // Get the device attributes of tasks from remote tasks in the cluster. 69 virtual const std::vector<DeviceAttributes>& GetClusterDeviceAttributes() = 0; 70 71 // State transition in coordination service agent: 72 // 73 // Init Connect SetError 74 // UNINITIALIZED ---> DISCONNECTED ------> RUNNING -------> ERROR 75 // ^ | 76 // |__________________________________| 77 // Reset 78 enum class TaskState { 79 UNINITIALIZED, 80 DISCONNECTED, 81 RUNNING, 82 ERROR, 83 }; 84 85 // Get status of a remote task. 86 virtual StatusOr<TaskState> GetTaskStatus(const std::string& job_name, 87 const int task_id) = 0; 88 89 // Report error to coordination service. This will invoke the error callback. 90 virtual Status ReportError(const Status& error) = 0; 91 92 // Disconnect from the service, and clean up the internal error status. 
93 virtual Status Reset() = 0; 94 95 // Get config key-value from the service. 96 virtual StatusOr<std::string> GetKeyValue(const std::string& key) = 0; 97 virtual void GetKeyValueAsync(const std::string& key, 98 StatusOrValueCallback done) = 0; 99 100 // Insert config key-value to the service. Return error if key is already set. 101 virtual Status InsertKeyValue(const std::string& key, 102 const std::string& value) = 0; 103 104 // Delete config keys in the coordination service. 105 virtual Status DeleteKeyValue(const std::string& key) = 0; 106 107 // Update the value of a config key. 108 virtual Status UpdateKeyValue(const std::string& key, 109 const std::string& value) = 0; 110 111 // Register a callback that will be invoked when the key or keys under the key 112 // directory are changed (inserted, deleted, or updated). 113 virtual Status StartWatchKey(const std::string& key, 114 ChangedKeyValuesCallback on_change) = 0; 115 virtual Status StopWatchKey(const std::string& key) = 0; 116 117 protected: 118 // Set the service agent to error status and invoke the error callback. 119 // Note: different from ReportError, this does not report the error status to 120 // remote coordination service. 121 virtual void SetError(const Status& error) = 0; 122 123 // Activate the key-value callback watch. 124 virtual Status ActivateWatch(const std::string& key, 125 const std::map<std::string, std::string>&) = 0; 126 127 private: 128 friend class CoordinationServiceRpcHandler; 129 }; 130 131 std::unique_ptr<CoordinationServiceAgent> CreateCoordinationServiceAgent(); 132 133 } // namespace tensorflow 134 135 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_AGENT_H_ 136