1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/tfrt/eager/c_api_tfrt_distributed_impl.h"
17
18 #include "absl/synchronization/notification.h"
19 #include "tensorflow/core/common_runtime/device_mgr.h"
20 #include "tensorflow/core/common_runtime/eager/context.h"
21 #include "tensorflow/core/distributed_runtime/remote_device.h"
22 #include "tensorflow/core/platform/errors.h"
23 #include "tensorflow/core/platform/platform.h"
24 #include "tensorflow/core/platform/random.h"
25 #include "tensorflow/core/protobuf/cluster.pb.h"
26 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
27 #include "tfrt/distributed_runtime/distributed_context.h" // from @tf_runtime
28 #include "tfrt/distributed_runtime/distributed_init_helper.h" // from @tf_runtime
29 #include "tfrt/distributed_runtime/fabric_communicator.h" // from @tf_runtime
30 #include "tfrt/distributed_runtime/proto/cluster_config.pb.h" // from @tf_runtime
31 #include "tfrt/distributed_runtime/server_context.h" // from @tf_runtime
32 #include "tfrt/distributed_runtime/task_name_util.h" // from @tf_runtime
33
34 namespace tfrt {
35 namespace tf {
36 namespace {
37
38 constexpr char kRemotePrefix[] = "remote_";
39
GetTensorFlowDeviceType(string_view name)40 std::string GetTensorFlowDeviceType(string_view name) {
41 int pos = name.find(kRemotePrefix);
42 return absl::AsciiStrToUpper(
43 pos == 0 ? name.substr(strlen(kRemotePrefix)).str() : name.str());
44 }
45
ConvertServerDefToDistributedConfiguration(const tensorflow::ServerDef & server_def)46 DistributedContextConfiguration ConvertServerDefToDistributedConfiguration(
47 const tensorflow::ServerDef& server_def) {
48 DistributedContextConfiguration dist_config;
49 dist_config.set_job_name(server_def.job_name());
50 dist_config.set_task_id(server_def.task_index());
51 ClusterConfiguration* cluster_config = dist_config.mutable_cluster_config();
52 // Currently take the first task in the first job as collective group leader.
53 // TODO(haoyuzhang): Make this configurable from API by reading from
54 // `config.experimental.collective_group_leader`.
55 cluster_config->set_lead_task_name(TaskNameUtil::ConcatTaskName(
56 server_def.cluster().job(0).name(), /*task_id=*/0));
57 for (const auto& job_def : server_def.cluster().job()) {
58 JobConfiguration* job_config = cluster_config->add_jobs();
59 job_config->set_name(job_def.name());
60 *job_config->mutable_tasks() = job_def.tasks();
61 }
62 return dist_config;
63 }
64
CreateServer(const DistributedContextConfiguration & dist_config,HostContext * host_ctx)65 std::unique_ptr<ServerContext> CreateServer(
66 const DistributedContextConfiguration& dist_config, HostContext* host_ctx) {
67 const std::string& job_name = dist_config.job_name();
68 const int task_id = dist_config.task_id();
69 std::string server_address;
70 for (const auto& job_config : dist_config.cluster_config().jobs()) {
71 if (job_config.name() == job_name) {
72 server_address = job_config.tasks().at(task_id);
73 break;
74 }
75 }
76 FabricCommunicatorConfiguration fabric_config{"grpc_communicator",
77 server_address};
78 ServerContextConfiguration server_config{fabric_config};
79 return std::make_unique<ServerContext>(host_ctx, server_config);
80 }
81
82 } // namespace
83
// Implementation of DistributedManagerContextInterface; the member functions
// are defined further down in this file.
class DistributedManagerContextImpl
    : public DistributedManagerContextInterface {
 public:
  // `host_context` is stored as a raw pointer and later handed to
  // CreateServer(); this class does not delete it.
  explicit DistributedManagerContextImpl(HostContext* host_context);

  // Creates a server for `server_def` and initializes a single-client
  // distributed context. Only implemented when PLATFORM_GOOGLE is defined;
  // otherwise returns Unimplemented.
  tensorflow::Status SetOrUpdateServerDef(
      const tensorflow::ServerDef& server_def, bool reset_context,
      int keep_alive_secs) override;

  // Creates a server for `server_def` and initializes a multi-client
  // distributed context. Only implemented when PLATFORM_GOOGLE is defined;
  // otherwise returns Unimplemented.
  tensorflow::Status EnableCollectiveOps(
      const tensorflow::ServerDef& server_def) override;

  // Always returns Unimplemented in this implementation.
  tensorflow::Status EnableCoordinationService(
      const std::string& service_type, const tensorflow::WorkerEnv* worker_env,
      const tensorflow::ServerDef& server_def,
      tensorflow::WorkerCacheInterface* worker_cache) override;

  // Always returns Unimplemented in this implementation.
  tensorflow::Status CheckRemoteAlive(const std::string& remote_task_name,
                                      bool* is_alive) override;

  // Logs at FATAL severity; no coordination service agent is available.
  tensorflow::CoordinationServiceAgent* GetCoordinationServiceAgent() override;

  // Inserts a reference to the current distributed context into the context
  // data of `builder`.
  void UpdateRequestContextBuilder(RequestContextBuilder* builder) override;
  // Adds the remote devices recorded by SetOrUpdateServerDef() to `dev_set`;
  // does nothing if none have been recorded.
  void PopulateRemoteDevices(tensorflow::DeviceSet* dev_set) override;

 private:
  // Host context passed in at construction.
  HostContext* host_context_;
  // Server created by SetOrUpdateServerDef() / EnableCollectiveOps().
  std::unique_ptr<tfrt::ServerContext> server_context_;
  // Distributed context obtained from `server_context_` once initialization
  // has completed.
  AsyncValueRef<tfrt::DistributedContext> dist_context_;
  // TensorFlow device objects created for the remote devices; null until
  // SetOrUpdateServerDef() has populated it.
  std::unique_ptr<tensorflow::StaticDeviceMgr> tf_devices_;
};
115
DistributedManagerContextImpl(HostContext * host_context)116 DistributedManagerContextImpl::DistributedManagerContextImpl(
117 HostContext* host_context)
118 : host_context_(host_context) {
119 TaskNameUtil::SetUseReplicaInTaskName();
120 }
121
// Starts a server for `server_def`, initializes a single-client distributed
// context and records a TensorFlow device object for every remote device.
//
// Blocks until the distributed context has been initialized; an
// initialization error is passed to tfrt::DieIfError. `reset_context` and
// `keep_alive_secs` are not used by this implementation. Without
// PLATFORM_GOOGLE the call returns Unimplemented.
tensorflow::Status DistributedManagerContextImpl::SetOrUpdateServerDef(
    const tensorflow::ServerDef& server_def, bool reset_context,
    int keep_alive_secs) {
#if defined(PLATFORM_GOOGLE)
  DistributedContextConfiguration config =
      ConvertServerDefToDistributedConfiguration(server_def);
  server_context_ = CreateServer(config, host_context_);

  // Create distributed contexts on current and remote tasks. Implemented as a
  // blocking call to be consistent with the behavior of current TF.
  absl::Notification init_done;
  const DistributedInitHelper* helper =
      server_context_->GetDistributedInitHelper();
  helper->InitializeSingleClientDistributedContext(
      std::move(config),
      [this, &init_done](Expected<DistributedContext*> expected) mutable {
        if (!expected) {
          tfrt::DieIfError(expected.takeError());
        }
        // Look the context up by id to obtain the AsyncValue held by the
        // server.
        dist_context_ = server_context_->GetDistributedContextAsyncValue(
            expected.get()->GetContextId());
        init_done.Notify();
      });
  init_done.WaitForNotification();

  auto remote_devices =
      dist_context_->GetRemoteDeviceManager()->ListDevices<Device>();
  std::vector<std::unique_ptr<tensorflow::Device>> remote_tf_devices;
  for (auto& remote_device : remote_devices) {
    tensorflow::DeviceAttributes attributes;
    attributes.set_name(remote_device->name().str());
    attributes.set_device_type(
        GetTensorFlowDeviceType(remote_device->type().name()));
    // TF Devices created here might not have all of the attributes needed.
    // Currently, it is only used by Placer during TFRT Function creation.
    remote_tf_devices.push_back(
        NewRemoteDevice(tensorflow::Env::Default(), attributes));
  }
  tf_devices_ = std::make_unique<tensorflow::StaticDeviceMgr>(
      std::move(remote_tf_devices));
  return tensorflow::Status::OK();
#else
  return tensorflow::errors::Unimplemented(
      "SetOrUpdateServerDef in open source is not yet implemented.");
#endif  // PLATFORM_GOOGLE
}
163
// Starts a server for `server_def` and initializes a multi-client distributed
// context, blocking until initialization has finished. An initialization
// error is passed to tfrt::DieIfError. Without PLATFORM_GOOGLE the call
// returns Unimplemented.
tensorflow::Status DistributedManagerContextImpl::EnableCollectiveOps(
    const tensorflow::ServerDef& server_def) {
#if defined(PLATFORM_GOOGLE)
  DistributedContextConfiguration config =
      ConvertServerDefToDistributedConfiguration(server_def);
  server_context_ = CreateServer(config, host_context_);

  absl::Notification init_done;
  DistributedInitHelper* helper = server_context_->GetDistributedInitHelper();
  helper->InitializeMultiClientDistributedContext(
      std::move(config),
      [this, &init_done](Expected<DistributedContext*> expected) mutable {
        if (!expected) {
          tfrt::DieIfError(expected.takeError());
        }
        // Look the context up by id to obtain the AsyncValue held by the
        // server.
        dist_context_ = server_context_->GetDistributedContextAsyncValue(
            expected.get()->GetContextId());
        init_done.Notify();
      });
  // The callback signals completion; wait for it before returning.
  init_done.WaitForNotification();

  return tensorflow::Status::OK();
#else
  return tensorflow::errors::Unimplemented(
      "EnableCollectiveOps in open source is not yet implemented.");
#endif  // PLATFORM_GOOGLE
}
189
// Not supported by this implementation: the arguments are not used and every
// call reports an Unimplemented status.
tensorflow::Status DistributedManagerContextImpl::EnableCoordinationService(
    const std::string& service_type, const tensorflow::WorkerEnv* worker_env,
    const tensorflow::ServerDef& server_def,
    tensorflow::WorkerCacheInterface* worker_cache) {
  tensorflow::Status unimplemented = tensorflow::errors::Unimplemented(
      "EnableCoordinationService in TFRT is not yet implemented.");
  return unimplemented;
}
197
CheckRemoteAlive(const std::string & remote_task_name,bool * is_alive)198 tensorflow::Status DistributedManagerContextImpl::CheckRemoteAlive(
199 const std::string& remote_task_name, bool* is_alive) {
200 return tensorflow::errors::Unimplemented(
201 "CheckRemoteAlive in TFRT is not yet implemented.");
202 }
203
// Coordination service is not enabled in this implementation: the call logs a
// FATAL-severity message instead of producing an agent.
tensorflow::CoordinationServiceAgent*
DistributedManagerContextImpl::GetCoordinationServiceAgent() {
  TFRT_LOG(FATAL) << "Coordination service in TFRT is not yet enabled.";
  // NOTE(review): presumably unreachable if the FATAL log terminates the
  // process (confirm against TFRT_LOG); the return satisfies the pointer
  // return type.
  return nullptr;
}
209
UpdateRequestContextBuilder(RequestContextBuilder * builder)210 void DistributedManagerContextImpl::UpdateRequestContextBuilder(
211 RequestContextBuilder* builder) {
212 builder->context_data().insert(dist_context_.CopyRef());
213 }
214
PopulateRemoteDevices(tensorflow::DeviceSet * dev_set)215 void DistributedManagerContextImpl::PopulateRemoteDevices(
216 tensorflow::DeviceSet* dev_set) {
217 if (tf_devices_ == nullptr) {
218 return;
219 }
220 for (auto& device : tf_devices_->ListDevices()) {
221 dev_set->AddDevice(device);
222 }
223 }
224
225 std::unique_ptr<DistributedManagerContextInterface>
CreateDistributedManagerContext(HostContext * host_context)226 CreateDistributedManagerContext(HostContext* host_context) {
227 return std::make_unique<DistributedManagerContextImpl>(host_context);
228 }
229
230 } // namespace tf
231 } // namespace tfrt
232