• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/tfrt/eager/c_api_tfrt_distributed_impl.h"
17 
18 #include "absl/synchronization/notification.h"
19 #include "tensorflow/core/common_runtime/device_mgr.h"
20 #include "tensorflow/core/common_runtime/eager/context.h"
21 #include "tensorflow/core/distributed_runtime/remote_device.h"
22 #include "tensorflow/core/platform/errors.h"
23 #include "tensorflow/core/platform/platform.h"
24 #include "tensorflow/core/platform/random.h"
25 #include "tensorflow/core/protobuf/cluster.pb.h"
26 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
27 #include "tfrt/distributed_runtime/distributed_context.h"  // from @tf_runtime
28 #include "tfrt/distributed_runtime/distributed_init_helper.h"  // from @tf_runtime
29 #include "tfrt/distributed_runtime/fabric_communicator.h"  // from @tf_runtime
30 #include "tfrt/distributed_runtime/proto/cluster_config.pb.h"  // from @tf_runtime
31 #include "tfrt/distributed_runtime/server_context.h"  // from @tf_runtime
32 #include "tfrt/distributed_runtime/task_name_util.h"  // from @tf_runtime
33 
34 namespace tfrt {
35 namespace tf {
36 namespace {
37 
38 constexpr char kRemotePrefix[] = "remote_";
39 
GetTensorFlowDeviceType(string_view name)40 std::string GetTensorFlowDeviceType(string_view name) {
41   int pos = name.find(kRemotePrefix);
42   return absl::AsciiStrToUpper(
43       pos == 0 ? name.substr(strlen(kRemotePrefix)).str() : name.str());
44 }
45 
ConvertServerDefToDistributedConfiguration(const tensorflow::ServerDef & server_def)46 DistributedContextConfiguration ConvertServerDefToDistributedConfiguration(
47     const tensorflow::ServerDef& server_def) {
48   DistributedContextConfiguration dist_config;
49   dist_config.set_job_name(server_def.job_name());
50   dist_config.set_task_id(server_def.task_index());
51   ClusterConfiguration* cluster_config = dist_config.mutable_cluster_config();
52   // Currently take the first task in the first job as collective group leader.
53   // TODO(haoyuzhang): Make this configurable from API by reading from
54   // `config.experimental.collective_group_leader`.
55   cluster_config->set_lead_task_name(TaskNameUtil::ConcatTaskName(
56       server_def.cluster().job(0).name(), /*task_id=*/0));
57   for (const auto& job_def : server_def.cluster().job()) {
58     JobConfiguration* job_config = cluster_config->add_jobs();
59     job_config->set_name(job_def.name());
60     *job_config->mutable_tasks() = job_def.tasks();
61   }
62   return dist_config;
63 }
64 
CreateServer(const DistributedContextConfiguration & dist_config,HostContext * host_ctx)65 std::unique_ptr<ServerContext> CreateServer(
66     const DistributedContextConfiguration& dist_config, HostContext* host_ctx) {
67   const std::string& job_name = dist_config.job_name();
68   const int task_id = dist_config.task_id();
69   std::string server_address;
70   for (const auto& job_config : dist_config.cluster_config().jobs()) {
71     if (job_config.name() == job_name) {
72       server_address = job_config.tasks().at(task_id);
73       break;
74     }
75   }
76   FabricCommunicatorConfiguration fabric_config{"grpc_communicator",
77                                                 server_address};
78   ServerContextConfiguration server_config{fabric_config};
79   return std::make_unique<ServerContext>(host_ctx, server_config);
80 }
81 
82 }  // namespace
83 
84 class DistributedManagerContextImpl
85     : public DistributedManagerContextInterface {
86  public:
87   explicit DistributedManagerContextImpl(HostContext* host_context);
88 
89   tensorflow::Status SetOrUpdateServerDef(
90       const tensorflow::ServerDef& server_def, bool reset_context,
91       int keep_alive_secs) override;
92 
93   tensorflow::Status EnableCollectiveOps(
94       const tensorflow::ServerDef& server_def) override;
95 
96   tensorflow::Status CheckRemoteAlive(const std::string& remote_task_name,
97                                       bool* is_alive) override;
98 
99   tensorflow::CoordinationServiceAgent* GetCoordinationServiceAgent() override;
100 
101   void UpdateRequestContextBuilder(RequestContextBuilder* builder) override;
102   void PopulateRemoteDevices(tensorflow::DeviceSet* dev_set) override;
103 
104  private:
105   HostContext* host_context_;
106   std::unique_ptr<tfrt::ServerContext> server_context_;
107   AsyncValueRef<tfrt::DistributedContext> dist_context_;
108   std::unique_ptr<tensorflow::StaticDeviceMgr> tf_devices_;
109 };
110 
DistributedManagerContextImpl(HostContext * host_context)111 DistributedManagerContextImpl::DistributedManagerContextImpl(
112     HostContext* host_context)
113     : host_context_(host_context) {
114   TaskNameUtil::SetUseReplicaInTaskName();
115 }
116 
// Sets up the distributed runtime for a single-client cluster.
//
// When PLATFORM_GOOGLE is defined this converts `server_def` to a TFRT
// configuration, starts a new ServerContext, blocks until the distributed
// context has been initialized on the current and remote tasks, and then
// builds a TensorFlow StaticDeviceMgr mirroring the remote devices.
// `reset_context` and `keep_alive_secs` are not used by this implementation.
//
// Returns InvalidArgument if `server_def` lists no jobs, OK on success, and
// Unimplemented on other platforms. An initialization failure reported by the
// runtime is passed to tfrt::DieIfError.
tensorflow::Status DistributedManagerContextImpl::SetOrUpdateServerDef(
    const tensorflow::ServerDef& server_def, bool reset_context,
    int keep_alive_secs) {
#if defined(PLATFORM_GOOGLE)
  // The configuration conversion picks the first job of the cluster as the
  // leader, so an empty cluster must be rejected before it is indexed.
  if (server_def.cluster().job_size() == 0) {
    return tensorflow::errors::InvalidArgument(
        "SetOrUpdateServerDef requires a ServerDef whose cluster has at least "
        "one job.");
  }

  DistributedContextConfiguration dist_config =
      ConvertServerDefToDistributedConfiguration(server_def);
  server_context_ = CreateServer(dist_config, host_context_);

  // Create distributed contexts on current and remote tasks. Implemented as a
  // blocking call to be consistent with the behavior of current TF.
  const DistributedInitHelper* init_helper =
      server_context_->GetDistributedInitHelper();
  absl::Notification n;
  init_helper->InitializeSingleClientDistributedContext(
      std::move(dist_config),
      [&n, this](Expected<DistributedContext*> expected) mutable {
        if (!expected) tfrt::DieIfError(expected.takeError());
        const uint64_t cid = expected.get()->GetContextId();
        dist_context_ = server_context_->GetDistributedContextAsyncValue(cid);
        n.Notify();
      });
  // `n` is captured by reference, so this frame must outlive the callback.
  n.WaitForNotification();

  auto device_refs =
      dist_context_->GetRemoteDeviceManager()->ListDevices<Device>();
  std::vector<std::unique_ptr<tensorflow::Device>> tf_devices;
  for (auto& device_ref : device_refs) {
    tensorflow::DeviceAttributes da;
    da.set_name(device_ref->name().str());
    da.set_device_type(GetTensorFlowDeviceType(device_ref->type().name()));
    // TF Devices created here might not have all of the attributes needed.
    // Currently, it is only used by Placer during TFRT Function creation.
    tf_devices.emplace_back(NewRemoteDevice(tensorflow::Env::Default(), da));
  }
  tf_devices_ =
      std::make_unique<tensorflow::StaticDeviceMgr>(std::move(tf_devices));
  return ::tensorflow::OkStatus();
#else
  return tensorflow::errors::Unimplemented(
      "SetOrUpdateServerDef in open source is not yet implemented.");
#endif  // PLATFORM_GOOGLE
}
158 
// Sets up the distributed runtime for a multi-client cluster.
//
// When PLATFORM_GOOGLE is defined this converts `server_def` to a TFRT
// configuration, starts a new ServerContext and blocks until the multi-client
// distributed context has been initialized.
//
// Returns InvalidArgument if `server_def` lists no jobs, OK on success, and
// Unimplemented on other platforms. An initialization failure reported by the
// runtime is passed to tfrt::DieIfError.
tensorflow::Status DistributedManagerContextImpl::EnableCollectiveOps(
    const tensorflow::ServerDef& server_def) {
#if defined(PLATFORM_GOOGLE)
  // The configuration conversion picks the first job of the cluster as the
  // leader, so an empty cluster must be rejected before it is indexed.
  if (server_def.cluster().job_size() == 0) {
    return tensorflow::errors::InvalidArgument(
        "EnableCollectiveOps requires a ServerDef whose cluster has at least "
        "one job.");
  }

  DistributedContextConfiguration dist_config =
      ConvertServerDefToDistributedConfiguration(server_def);
  server_context_ = CreateServer(dist_config, host_context_);

  DistributedInitHelper* init_helper =
      server_context_->GetDistributedInitHelper();
  absl::Notification n;
  init_helper->InitializeMultiClientDistributedContext(
      std::move(dist_config),
      [&n, this](Expected<DistributedContext*> expected) mutable {
        if (!expected) tfrt::DieIfError(expected.takeError());
        const uint64_t cid = expected.get()->GetContextId();
        dist_context_ = server_context_->GetDistributedContextAsyncValue(cid);
        n.Notify();
      });
  // `n` is captured by reference, so this frame must outlive the callback.
  n.WaitForNotification();

  return ::tensorflow::OkStatus();
#else
  return tensorflow::errors::Unimplemented(
      "EnableCollectiveOps in open source is not yet implemented.");
#endif  // PLATFORM_GOOGLE
}
184 
CheckRemoteAlive(const std::string & remote_task_name,bool * is_alive)185 tensorflow::Status DistributedManagerContextImpl::CheckRemoteAlive(
186     const std::string& remote_task_name, bool* is_alive) {
187   return tensorflow::errors::Unimplemented(
188       "CheckRemoteAlive in TFRT is not yet implemented.");
189 }
190 
191 tensorflow::CoordinationServiceAgent*
GetCoordinationServiceAgent()192 DistributedManagerContextImpl::GetCoordinationServiceAgent() {
193   TFRT_LOG(FATAL) << "Coordination service in TFRT is not yet enabled.";
194   return nullptr;
195 }
196 
UpdateRequestContextBuilder(RequestContextBuilder * builder)197 void DistributedManagerContextImpl::UpdateRequestContextBuilder(
198     RequestContextBuilder* builder) {
199   builder->context_data().insert(dist_context_.CopyRef());
200 }
201 
PopulateRemoteDevices(tensorflow::DeviceSet * dev_set)202 void DistributedManagerContextImpl::PopulateRemoteDevices(
203     tensorflow::DeviceSet* dev_set) {
204   if (tf_devices_ == nullptr) {
205     return;
206   }
207   for (auto& device : tf_devices_->ListDevices()) {
208     dev_set->AddDevice(device);
209   }
210 }
211 
212 std::unique_ptr<DistributedManagerContextInterface>
CreateDistributedManagerContext(HostContext * host_context)213 CreateDistributedManagerContext(HostContext* host_context) {
214   return std::make_unique<DistributedManagerContextImpl>(host_context);
215 }
216 
217 }  // namespace tf
218 }  // namespace tfrt
219