• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/tfrt/eager/c_api_tfrt_distributed_impl.h"
17 
18 #include "absl/synchronization/notification.h"
19 #include "tensorflow/core/common_runtime/device_mgr.h"
20 #include "tensorflow/core/common_runtime/eager/context.h"
21 #include "tensorflow/core/distributed_runtime/remote_device.h"
22 #include "tensorflow/core/platform/errors.h"
23 #include "tensorflow/core/platform/platform.h"
24 #include "tensorflow/core/platform/random.h"
25 #include "tensorflow/core/protobuf/cluster.pb.h"
26 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
27 #include "tfrt/distributed_runtime/distributed_context.h"  // from @tf_runtime
28 #include "tfrt/distributed_runtime/distributed_init_helper.h"  // from @tf_runtime
29 #include "tfrt/distributed_runtime/fabric_communicator.h"  // from @tf_runtime
30 #include "tfrt/distributed_runtime/proto/cluster_config.pb.h"  // from @tf_runtime
31 #include "tfrt/distributed_runtime/server_context.h"  // from @tf_runtime
32 #include "tfrt/distributed_runtime/task_name_util.h"  // from @tf_runtime
33 
34 namespace tfrt {
35 namespace tf {
36 namespace {
37 
38 constexpr char kRemotePrefix[] = "remote_";
39 
GetTensorFlowDeviceType(string_view name)40 std::string GetTensorFlowDeviceType(string_view name) {
41   int pos = name.find(kRemotePrefix);
42   return absl::AsciiStrToUpper(
43       pos == 0 ? name.substr(strlen(kRemotePrefix)).str() : name.str());
44 }
45 
ConvertServerDefToDistributedConfiguration(const tensorflow::ServerDef & server_def)46 DistributedContextConfiguration ConvertServerDefToDistributedConfiguration(
47     const tensorflow::ServerDef& server_def) {
48   DistributedContextConfiguration dist_config;
49   dist_config.set_job_name(server_def.job_name());
50   dist_config.set_task_id(server_def.task_index());
51   ClusterConfiguration* cluster_config = dist_config.mutable_cluster_config();
52   // Currently take the first task in the first job as collective group leader.
53   // TODO(haoyuzhang): Make this configurable from API by reading from
54   // `config.experimental.collective_group_leader`.
55   cluster_config->set_lead_task_name(TaskNameUtil::ConcatTaskName(
56       server_def.cluster().job(0).name(), /*task_id=*/0));
57   for (const auto& job_def : server_def.cluster().job()) {
58     JobConfiguration* job_config = cluster_config->add_jobs();
59     job_config->set_name(job_def.name());
60     *job_config->mutable_tasks() = job_def.tasks();
61   }
62   return dist_config;
63 }
64 
CreateServer(const DistributedContextConfiguration & dist_config,HostContext * host_ctx)65 std::unique_ptr<ServerContext> CreateServer(
66     const DistributedContextConfiguration& dist_config, HostContext* host_ctx) {
67   const std::string& job_name = dist_config.job_name();
68   const int task_id = dist_config.task_id();
69   std::string server_address;
70   for (const auto& job_config : dist_config.cluster_config().jobs()) {
71     if (job_config.name() == job_name) {
72       server_address = job_config.tasks().at(task_id);
73       break;
74     }
75   }
76   FabricCommunicatorConfiguration fabric_config{"grpc_communicator",
77                                                 server_address};
78   ServerContextConfiguration server_config{fabric_config};
79   return std::make_unique<ServerContext>(host_ctx, server_config);
80 }
81 
82 }  // namespace
83 
84 class DistributedManagerContextImpl
85     : public DistributedManagerContextInterface {
86  public:
87   explicit DistributedManagerContextImpl(HostContext* host_context);
88 
89   tensorflow::Status SetOrUpdateServerDef(
90       const tensorflow::ServerDef& server_def, bool reset_context,
91       int keep_alive_secs) override;
92 
93   tensorflow::Status EnableCollectiveOps(
94       const tensorflow::ServerDef& server_def) override;
95 
96   tensorflow::Status EnableCoordinationService(
97       const std::string& service_type, const tensorflow::WorkerEnv* worker_env,
98       const tensorflow::ServerDef& server_def,
99       tensorflow::WorkerCacheInterface* worker_cache) override;
100 
101   tensorflow::Status CheckRemoteAlive(const std::string& remote_task_name,
102                                       bool* is_alive) override;
103 
104   tensorflow::CoordinationServiceAgent* GetCoordinationServiceAgent() override;
105 
106   void UpdateRequestContextBuilder(RequestContextBuilder* builder) override;
107   void PopulateRemoteDevices(tensorflow::DeviceSet* dev_set) override;
108 
109  private:
110   HostContext* host_context_;
111   std::unique_ptr<tfrt::ServerContext> server_context_;
112   AsyncValueRef<tfrt::DistributedContext> dist_context_;
113   std::unique_ptr<tensorflow::StaticDeviceMgr> tf_devices_;
114 };
115 
DistributedManagerContextImpl(HostContext * host_context)116 DistributedManagerContextImpl::DistributedManagerContextImpl(
117     HostContext* host_context)
118     : host_context_(host_context) {
119   TaskNameUtil::SetUseReplicaInTaskName();
120 }
121 
// Sets up single-client distributed execution from `server_def`.
//
// It creates the TFRT server for the local task, then blocks until the
// distributed context has been created on the current and remote tasks. Once
// that completes, it registers TF-side mirrors of the remote devices for
// placement. `reset_context` and `keep_alive_secs` are currently ignored.
// Initialization errors are fatal (DieIfError).
//
// Returns Unimplemented outside PLATFORM_GOOGLE builds.
tensorflow::Status DistributedManagerContextImpl::SetOrUpdateServerDef(
    const tensorflow::ServerDef& server_def, bool reset_context,
    int keep_alive_secs) {
#if defined(PLATFORM_GOOGLE)
  DistributedContextConfiguration dist_config =
      ConvertServerDefToDistributedConfiguration(server_def);
  // Drop our reference to any distributed context obtained from a previous
  // server before that server is destroyed by the reassignment below.
  dist_context_ = AsyncValueRef<DistributedContext>();
  server_context_ = CreateServer(dist_config, host_context_);

  // Create distributed contexts on current and remote tasks. Implemented as a
  // blocking call to be consistent with the behavior of current TF.
  const DistributedInitHelper* init_helper =
      server_context_->GetDistributedInitHelper();
  absl::Notification n;
  init_helper->InitializeSingleClientDistributedContext(
      std::move(dist_config),
      [&n, this](Expected<DistributedContext*> expected) mutable {
        if (!expected) tfrt::DieIfError(expected.takeError());
        const uint64_t cid = expected.get()->GetContextId();
        dist_context_ = server_context_->GetDistributedContextAsyncValue(cid);
        n.Notify();
      });
  // The Notification also makes the callback's write to dist_context_ visible
  // to this thread.
  n.WaitForNotification();

  auto device_refs =
      dist_context_->GetRemoteDeviceManager()->ListDevices<Device>();
  std::vector<std::unique_ptr<tensorflow::Device>> tf_devices;
  tf_devices.reserve(device_refs.size());
  for (auto& device_ref : device_refs) {
    tensorflow::DeviceAttributes da;
    da.set_name(device_ref->name().str());
    da.set_device_type(GetTensorFlowDeviceType(device_ref->type().name()));
    // TF Devices created here might not have all of the attributes needed.
    // Currently, it is only used by Placer during TFRT Function creation.
    tf_devices.emplace_back(NewRemoteDevice(tensorflow::Env::Default(), da));
  }
  tf_devices_ =
      std::make_unique<tensorflow::StaticDeviceMgr>(std::move(tf_devices));
  return tensorflow::Status::OK();
#else
  return tensorflow::errors::Unimplemented(
      "SetOrUpdateServerDef in open source is not yet implemented.");
#endif  // PLATFORM_GOOGLE
}
163 
// Sets up multi-client distributed execution (for collectives) from
// `server_def`.
//
// It creates the TFRT server for the local task and blocks until the
// multi-client distributed context is ready. Unlike SetOrUpdateServerDef, it
// does not populate `tf_devices_`. Initialization errors are fatal
// (DieIfError).
//
// Returns Unimplemented outside PLATFORM_GOOGLE builds.
tensorflow::Status DistributedManagerContextImpl::EnableCollectiveOps(
    const tensorflow::ServerDef& server_def) {
#if defined(PLATFORM_GOOGLE)
  DistributedContextConfiguration dist_config =
      ConvertServerDefToDistributedConfiguration(server_def);
  // Drop our reference to any distributed context obtained from a previous
  // server before that server is destroyed by the reassignment below.
  dist_context_ = AsyncValueRef<DistributedContext>();
  server_context_ = CreateServer(dist_config, host_context_);

  DistributedInitHelper* init_helper =
      server_context_->GetDistributedInitHelper();
  absl::Notification n;
  init_helper->InitializeMultiClientDistributedContext(
      std::move(dist_config),
      [&n, this](Expected<DistributedContext*> expected) mutable {
        if (!expected) tfrt::DieIfError(expected.takeError());
        const uint64_t cid = expected.get()->GetContextId();
        dist_context_ = server_context_->GetDistributedContextAsyncValue(cid);
        n.Notify();
      });
  // Block until the callback has stored the new context.
  n.WaitForNotification();

  return tensorflow::Status::OK();
#else
  return tensorflow::errors::Unimplemented(
      "EnableCollectiveOps in open source is not yet implemented.");
#endif  // PLATFORM_GOOGLE
}
189 
// Coordination service is not supported by the TFRT backend. All arguments are
// ignored and the call always reports Unimplemented.
tensorflow::Status DistributedManagerContextImpl::EnableCoordinationService(
    const std::string& service_type, const tensorflow::WorkerEnv* worker_env,
    const tensorflow::ServerDef& server_def,
    tensorflow::WorkerCacheInterface* worker_cache) {
  constexpr char kNotImplemented[] =
      "EnableCoordinationService in TFRT is not yet implemented.";
  return tensorflow::errors::Unimplemented(kNotImplemented);
}
197 
CheckRemoteAlive(const std::string & remote_task_name,bool * is_alive)198 tensorflow::Status DistributedManagerContextImpl::CheckRemoteAlive(
199     const std::string& remote_task_name, bool* is_alive) {
200   return tensorflow::errors::Unimplemented(
201       "CheckRemoteAlive in TFRT is not yet implemented.");
202 }
203 
204 tensorflow::CoordinationServiceAgent*
GetCoordinationServiceAgent()205 DistributedManagerContextImpl::GetCoordinationServiceAgent() {
206   TFRT_LOG(FATAL) << "Coordination service in TFRT is not yet enabled.";
207   return nullptr;
208 }
209 
UpdateRequestContextBuilder(RequestContextBuilder * builder)210 void DistributedManagerContextImpl::UpdateRequestContextBuilder(
211     RequestContextBuilder* builder) {
212   builder->context_data().insert(dist_context_.CopyRef());
213 }
214 
PopulateRemoteDevices(tensorflow::DeviceSet * dev_set)215 void DistributedManagerContextImpl::PopulateRemoteDevices(
216     tensorflow::DeviceSet* dev_set) {
217   if (tf_devices_ == nullptr) {
218     return;
219   }
220   for (auto& device : tf_devices_->ListDevices()) {
221     dev_set->AddDevice(device);
222   }
223 }
224 
225 std::unique_ptr<DistributedManagerContextInterface>
CreateDistributedManagerContext(HostContext * host_context)226 CreateDistributedManagerContext(HostContext* host_context) {
227   return std::make_unique<DistributedManagerContextImpl>(host_context);
228 }
229 
230 }  // namespace tf
231 }  // namespace tfrt
232