• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_AGENT_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_AGENT_H_
18 
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/distributed_runtime/coordination/coordination_client.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
26 
27 namespace tensorflow {
28 class DeviceAttributes;
29 class WorkerEnv;
30 class ServerDef;
31 
32 // CoordinationServiceAgent defines the interface for tasks to communicate with
33 // the coordination service instance (which implements
34 // CoordinationServiceInterface). One instance of the agent should be deployed
35 // on each task for it to send various requests and stores / retrieves config
36 // key-value data to the service.
37 //
38 // See CoordinationServiceInterface for more details on coordination service.
39 //
40 // Experimental feature. Not yet implemented in open source.
41 class CoordinationServiceAgent {
42  public:
43   using StatusOrValueCallback =
44       std::function<void(const StatusOr<std::string>&)>;
45   using ChangedKeyValuesCallback =
46       std::function<void(const std::map<std::string, std::string>&)>;
47 
~CoordinationServiceAgent()48   virtual ~CoordinationServiceAgent() {}
49 
50   // Initialize coordination service agent.
51   virtual Status Initialize(
52       const WorkerEnv* worker_env, const ServerDef& server_def,
53       std::unique_ptr<CoordinationClientCache> client_cache,
54       StatusCallback error_fn) = 0;
55   // Return true if the coordination service agent has been initialized.
56   virtual bool IsInitialized() = 0;
57 
58   // Connect to coordination service with the following steps:
59   //   - connect to service address specified in the config of `server_def`
60   //   - register itself as a worker to the service
61   //   - start a thread to periodically send heartbeat message with the service
62   virtual Status Connect() = 0;
63 
64   // Wait for all tasks to be up and registered. The call blocks until all tasks
65   // in the cluster are up, or some error occurs.
66   virtual Status WaitForAllTasks() = 0;
67 
68   // Get the device attributes of tasks from remote tasks in the cluster.
69   virtual const std::vector<DeviceAttributes>& GetClusterDeviceAttributes() = 0;
70 
71   // State transition in coordination service agent:
72   //
73   //                 Init              Connect         SetError
74   //   UNINITIALIZED ---> DISCONNECTED ------> RUNNING -------> ERROR
75   //                           ^                                  |
76   //                           |__________________________________|
77   //                                         Reset
78   enum class TaskState {
79     UNINITIALIZED,
80     DISCONNECTED,
81     RUNNING,
82     ERROR,
83   };
84 
85   // Get status of a remote task.
86   virtual StatusOr<TaskState> GetTaskStatus(const std::string& job_name,
87                                             const int task_id) = 0;
88 
89   // Report error to coordination service. This will invoke the error callback.
90   virtual Status ReportError(const Status& error) = 0;
91 
92   // Disconnect from the service, and clean up the internal error status.
93   virtual Status Reset() = 0;
94 
95   // Get config key-value from the service.
96   virtual StatusOr<std::string> GetKeyValue(const std::string& key) = 0;
97   virtual void GetKeyValueAsync(const std::string& key,
98                                 StatusOrValueCallback done) = 0;
99 
100   // Insert config key-value to the service. Return error if key is already set.
101   virtual Status InsertKeyValue(const std::string& key,
102                                 const std::string& value) = 0;
103 
104   // Delete config keys in the coordination service.
105   virtual Status DeleteKeyValue(const std::string& key) = 0;
106 
107   // Update the value of a config key.
108   virtual Status UpdateKeyValue(const std::string& key,
109                                 const std::string& value) = 0;
110 
111   // Register a callback that will be invoked when the key or keys under the key
112   // directory are changed (inserted, deleted, or updated).
113   virtual Status StartWatchKey(const std::string& key,
114                                ChangedKeyValuesCallback on_change) = 0;
115   virtual Status StopWatchKey(const std::string& key) = 0;
116 
117  protected:
118   // Set the service agent to error status and invoke the error callback.
119   // Note: different from ReportError, this does not report the error status to
120   // remote coordination service.
121   virtual void SetError(const Status& error) = 0;
122 
123   // Activate the key-value callback watch.
124   virtual Status ActivateWatch(const std::string& key,
125                                const std::map<std::string, std::string>&) = 0;
126 
127  private:
128   friend class CoordinationServiceRpcHandler;
129 };
130 
131 std::unique_ptr<CoordinationServiceAgent> CreateCoordinationServiceAgent();
132 
133 }  // namespace tensorflow
134 
135 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_AGENT_H_
136