1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/tfrt/eager/c_api_tfrt_distributed_impl.h"
17
18 #include "absl/synchronization/notification.h"
19 #include "tensorflow/core/common_runtime/device_mgr.h"
20 #include "tensorflow/core/common_runtime/eager/context.h"
21 #include "tensorflow/core/distributed_runtime/remote_device.h"
22 #include "tensorflow/core/platform/errors.h"
23 #include "tensorflow/core/platform/platform.h"
24 #include "tensorflow/core/platform/random.h"
25 #include "tensorflow/core/protobuf/cluster.pb.h"
26 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
27 #include "tfrt/distributed_runtime/distributed_context.h" // from @tf_runtime
28 #include "tfrt/distributed_runtime/distributed_init_helper.h" // from @tf_runtime
29 #include "tfrt/distributed_runtime/fabric_communicator.h" // from @tf_runtime
30 #include "tfrt/distributed_runtime/proto/cluster_config.pb.h" // from @tf_runtime
31 #include "tfrt/distributed_runtime/server_context.h" // from @tf_runtime
32 #include "tfrt/distributed_runtime/task_name_util.h" // from @tf_runtime
33
34 namespace tfrt {
35 namespace tf {
36 namespace {
37
38 constexpr char kRemotePrefix[] = "remote_";
39
GetTensorFlowDeviceType(string_view name)40 std::string GetTensorFlowDeviceType(string_view name) {
41 int pos = name.find(kRemotePrefix);
42 return absl::AsciiStrToUpper(
43 pos == 0 ? name.substr(strlen(kRemotePrefix)).str() : name.str());
44 }
45
ConvertServerDefToDistributedConfiguration(const tensorflow::ServerDef & server_def)46 DistributedContextConfiguration ConvertServerDefToDistributedConfiguration(
47 const tensorflow::ServerDef& server_def) {
48 DistributedContextConfiguration dist_config;
49 dist_config.set_job_name(server_def.job_name());
50 dist_config.set_task_id(server_def.task_index());
51 ClusterConfiguration* cluster_config = dist_config.mutable_cluster_config();
52 // Currently take the first task in the first job as collective group leader.
53 // TODO(haoyuzhang): Make this configurable from API by reading from
54 // `config.experimental.collective_group_leader`.
55 cluster_config->set_lead_task_name(TaskNameUtil::ConcatTaskName(
56 server_def.cluster().job(0).name(), /*task_id=*/0));
57 for (const auto& job_def : server_def.cluster().job()) {
58 JobConfiguration* job_config = cluster_config->add_jobs();
59 job_config->set_name(job_def.name());
60 *job_config->mutable_tasks() = job_def.tasks();
61 }
62 return dist_config;
63 }
64
CreateServer(const DistributedContextConfiguration & dist_config,HostContext * host_ctx)65 std::unique_ptr<ServerContext> CreateServer(
66 const DistributedContextConfiguration& dist_config, HostContext* host_ctx) {
67 const std::string& job_name = dist_config.job_name();
68 const int task_id = dist_config.task_id();
69 std::string server_address;
70 for (const auto& job_config : dist_config.cluster_config().jobs()) {
71 if (job_config.name() == job_name) {
72 server_address = job_config.tasks().at(task_id);
73 break;
74 }
75 }
76 FabricCommunicatorConfiguration fabric_config{"grpc_communicator",
77 server_address};
78 ServerContextConfiguration server_config{fabric_config};
79 return std::make_unique<ServerContext>(host_ctx, server_config);
80 }
81
82 } // namespace
83
84 class DistributedManagerContextImpl
85 : public DistributedManagerContextInterface {
86 public:
87 explicit DistributedManagerContextImpl(HostContext* host_context);
88
89 tensorflow::Status SetOrUpdateServerDef(
90 const tensorflow::ServerDef& server_def, bool reset_context,
91 int keep_alive_secs) override;
92
93 tensorflow::Status EnableCollectiveOps(
94 const tensorflow::ServerDef& server_def) override;
95
96 tensorflow::Status CheckRemoteAlive(const std::string& remote_task_name,
97 bool* is_alive) override;
98
99 tensorflow::CoordinationServiceAgent* GetCoordinationServiceAgent() override;
100
101 void UpdateRequestContextBuilder(RequestContextBuilder* builder) override;
102 void PopulateRemoteDevices(tensorflow::DeviceSet* dev_set) override;
103
104 private:
105 HostContext* host_context_;
106 std::unique_ptr<tfrt::ServerContext> server_context_;
107 AsyncValueRef<tfrt::DistributedContext> dist_context_;
108 std::unique_ptr<tensorflow::StaticDeviceMgr> tf_devices_;
109 };
110
DistributedManagerContextImpl(HostContext * host_context)111 DistributedManagerContextImpl::DistributedManagerContextImpl(
112 HostContext* host_context)
113 : host_context_(host_context) {
114 TaskNameUtil::SetUseReplicaInTaskName();
115 }
116
// Sets up the distributed runtime for the cluster described by `server_def`.
//
// Implemented only when PLATFORM_GOOGLE is defined; on other platforms an
// Unimplemented status is returned. `reset_context` and `keep_alive_secs` are
// not read by this implementation.
//
// On the supported platform the call:
//   1. converts `server_def` and replaces `server_context_` with a new server;
//   2. blocks until the single-client distributed context has been created
//      (a failed initialization is handed to tfrt::DieIfError);
//   3. replaces `tf_devices_` with one remote tensorflow::Device per device
//      listed by the distributed context's remote device manager.
tensorflow::Status DistributedManagerContextImpl::SetOrUpdateServerDef(
    const tensorflow::ServerDef& server_def, bool reset_context,
    int keep_alive_secs) {
#if defined(PLATFORM_GOOGLE)
  // Step 1: configuration and server.
  DistributedContextConfiguration config =
      ConvertServerDefToDistributedConfiguration(server_def);
  server_context_ = CreateServer(config, host_context_);

  // Step 2: create distributed contexts on current and remote tasks.
  // Implemented as a blocking call to be consistent with the behavior of
  // current TF.
  absl::Notification init_done;
  const DistributedInitHelper* helper =
      server_context_->GetDistributedInitHelper();
  helper->InitializeSingleClientDistributedContext(
      std::move(config),
      [&init_done, this](Expected<DistributedContext*> expected) mutable {
        if (!expected) tfrt::DieIfError(expected.takeError());
        DistributedContext* const created = expected.get();
        const uint64_t context_id = created->GetContextId();
        dist_context_ =
            server_context_->GetDistributedContextAsyncValue(context_id);
        init_done.Notify();
      });
  init_done.WaitForNotification();

  // Step 3: mirror the remote TFRT devices as TensorFlow devices.
  auto remote_refs =
      dist_context_->GetRemoteDeviceManager()->ListDevices<Device>();
  std::vector<std::unique_ptr<tensorflow::Device>> remote_devices;
  for (auto& remote_ref : remote_refs) {
    tensorflow::DeviceAttributes attributes;
    attributes.set_name(remote_ref->name().str());
    attributes.set_device_type(
        GetTensorFlowDeviceType(remote_ref->type().name()));
    // TF Devices created here might not have all of the attributes needed.
    // Currently, it is only used by Placer during TFRT Function creation.
    remote_devices.emplace_back(
        NewRemoteDevice(tensorflow::Env::Default(), attributes));
  }
  tf_devices_ = std::make_unique<tensorflow::StaticDeviceMgr>(
      std::move(remote_devices));
  return ::tensorflow::OkStatus();
#else
  return tensorflow::errors::Unimplemented(
      "SetOrUpdateServerDef in open source is not yet implemented.");
#endif  // PLATFORM_GOOGLE
}
158
// Creates the server and a multi-client distributed context for the cluster
// described by `server_def`.
//
// Implemented only when PLATFORM_GOOGLE is defined; on other platforms an
// Unimplemented status is returned. On the supported platform the call
// replaces `server_context_`, blocks until the distributed context has been
// initialized (a failed initialization is handed to tfrt::DieIfError), stores
// the resulting context in `dist_context_`, and returns OK.
tensorflow::Status DistributedManagerContextImpl::EnableCollectiveOps(
    const tensorflow::ServerDef& server_def) {
#if defined(PLATFORM_GOOGLE)
  DistributedContextConfiguration config =
      ConvertServerDefToDistributedConfiguration(server_def);
  server_context_ = CreateServer(config, host_context_);

  // Wait synchronously for the initialization callback before returning.
  absl::Notification init_done;
  DistributedInitHelper* helper = server_context_->GetDistributedInitHelper();
  helper->InitializeMultiClientDistributedContext(
      std::move(config),
      [&init_done, this](Expected<DistributedContext*> expected) mutable {
        if (!expected) tfrt::DieIfError(expected.takeError());
        DistributedContext* const created = expected.get();
        const uint64_t context_id = created->GetContextId();
        dist_context_ =
            server_context_->GetDistributedContextAsyncValue(context_id);
        init_done.Notify();
      });
  init_done.WaitForNotification();

  return ::tensorflow::OkStatus();
#else
  return tensorflow::errors::Unimplemented(
      "EnableCollectiveOps in open source is not yet implemented.");
#endif  // PLATFORM_GOOGLE
}
184
CheckRemoteAlive(const std::string & remote_task_name,bool * is_alive)185 tensorflow::Status DistributedManagerContextImpl::CheckRemoteAlive(
186 const std::string& remote_task_name, bool* is_alive) {
187 return tensorflow::errors::Unimplemented(
188 "CheckRemoteAlive in TFRT is not yet implemented.");
189 }
190
191 tensorflow::CoordinationServiceAgent*
GetCoordinationServiceAgent()192 DistributedManagerContextImpl::GetCoordinationServiceAgent() {
193 TFRT_LOG(FATAL) << "Coordination service in TFRT is not yet enabled.";
194 return nullptr;
195 }
196
UpdateRequestContextBuilder(RequestContextBuilder * builder)197 void DistributedManagerContextImpl::UpdateRequestContextBuilder(
198 RequestContextBuilder* builder) {
199 builder->context_data().insert(dist_context_.CopyRef());
200 }
201
PopulateRemoteDevices(tensorflow::DeviceSet * dev_set)202 void DistributedManagerContextImpl::PopulateRemoteDevices(
203 tensorflow::DeviceSet* dev_set) {
204 if (tf_devices_ == nullptr) {
205 return;
206 }
207 for (auto& device : tf_devices_->ListDevices()) {
208 dev_set->AddDevice(device);
209 }
210 }
211
212 std::unique_ptr<DistributedManagerContextInterface>
CreateDistributedManagerContext(HostContext * host_context)213 CreateDistributedManagerContext(HostContext* host_context) {
214 return std::make_unique<DistributedManagerContextImpl>(host_context);
215 }
216
217 } // namespace tf
218 } // namespace tfrt
219