/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/data/service/worker_impl.h"

#include <memory>
#include <string>
#include <utility>

#include "grpcpp/create_channel.h"
#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/data/dataset.pb.h"
#include "tensorflow/core/data/service/auto_shard_rewriter.h"
#include "tensorflow/core/data/service/common.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/credentials_factory.h"
#include "tensorflow/core/data/service/data_transfer.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/data/service/dispatcher_client.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/split_provider.h"
#include "tensorflow/core/data/service/task_runner.h"
#include "tensorflow/core/data/service/utils.h"
#include "tensorflow/core/data/service/worker.pb.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/framework/dataset_options.pb.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/snappy.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace data {
namespace {

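// How long to wait before retrying a failed dispatcher interaction
// (initial registration in Start() and task updates in
// TaskCompletionThread()).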
const constexpr uint64 kRetryIntervalMicros = 5ull * 1000 * 1000;

// Moves the element into the response. If the tensor contains a single
// CompressedElement variant, the move will be zero-copy. Otherwise, the tensor
// data will be serialized as TensorProtos.
Status MoveElementToResponse(std::vector<Tensor>&& element,
                             GetElementResponse& resp) {
  if (element.size() != 1 || element[0].dtype() != DT_VARIANT ||
      !TensorShapeUtils::IsScalar(element[0].shape())) {
    for (const auto& component : element) {
      UncompressedElement* uncompressed = resp.mutable_uncompressed();
      component.AsProtoTensorContent(uncompressed->add_components());
    }
    return Status::OK();
  }
  Variant& variant = element[0].scalar<Variant>()();
  CompressedElement* compressed = variant.get<CompressedElement>();
  if (compressed == nullptr) {
    return errors::FailedPrecondition(
        "Expected dataset to produce a CompressedElement variant tensor, but "
        "it produced ",
        variant.TypeName());
  }
  *resp.mutable_compressed() = *compressed;
  return Status::OK();
}
}  // namespace

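// Process-wide registry of in-process workers, keyed by worker address and
// guarded by `mu_`.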
mutex LocalWorkers::mu_(LINKER_INITIALIZED);
LocalWorkers::AddressToWorkerMap* LocalWorkers::local_workers_ =
    new AddressToWorkerMap();

DataServiceWorkerImpl::DataServiceWorkerImpl(
    const experimental::WorkerConfig& config)
    : config_(config) {
  metrics::RecordTFDataServiceWorkerCreated();
}

DataServiceWorkerImpl::~DataServiceWorkerImpl() {
  mutex_lock l(mu_);
  cancelled_ = true;
  task_completion_cv_.notify_one();
  heartbeat_cv_.notify_one();
}

Status DataServiceWorkerImpl::Start(const std::string& worker_address,
                                    const std::string& transfer_address) {
  VLOG(3) << "Starting tf.data service worker at address " << worker_address;
  worker_address_ = worker_address;
  transfer_address_ = transfer_address;

  dispatcher_ = absl::make_unique<DataServiceDispatcherClient>(
      config_.dispatcher_address(), config_.protocol());
  TF_RETURN_IF_ERROR(dispatcher_->Initialize());

  Status s = Heartbeat();
  while (!s.ok()) {
    if (!errors::IsUnavailable(s) && !errors::IsAborted(s) &&
        !errors::IsCancelled(s)) {
      return s;
    }
    LOG(WARNING) << "Failed to register with dispatcher at "
                 << config_.dispatcher_address() << ": " << s;
    Env::Default()->SleepForMicroseconds(kRetryIntervalMicros);
    s = Heartbeat();
  }
  LOG(INFO) << "Worker registered with dispatcher running at "
            << config_.dispatcher_address();
  task_completion_thread_ = absl::WrapUnique(
      Env::Default()->StartThread({}, "data-service-worker-task-completion",
                                  [this]() { TaskCompletionThread(); }));
  heartbeat_thread_ = absl::WrapUnique(Env::Default()->StartThread(
      {}, "data-service-worker-heartbeat", [this]() { HeartbeatThread(); }));
  mutex_lock l(mu_);
  registered_ = true;
  return Status::OK();
}

void DataServiceWorkerImpl::Stop() {
  std::vector<std::shared_ptr<Task>> tasks;
  {
    mutex_lock l(mu_);
    cancelled_ = true;
    for (const auto& entry : tasks_) {
      tasks.push_back(entry.second);
    }
  }
  for (auto& task : tasks) {
    StopTask(*task);
  }
  // At this point there are no outstanding requests in this RPC handler.
  // However, requests successfully returned from this RPC handler may still be
  // in progress within the gRPC server. If we shut down the gRPC server
  // immediately, it could cause these requests to fail, e.g. with broken pipe.
  // To mitigate this, we sleep for some time to give the gRPC server time to
  // complete requests.
  Env::Default()->SleepForMicroseconds(config_.shutdown_quiet_period_ms() *
                                       1000);
}

Status DataServiceWorkerImpl::GetElementResult(
    const GetElementRequest* request, struct GetElementResult* result) {
  Task* task;
  {
    mutex_lock l(mu_);
    if (cancelled_) {
      return errors::Cancelled("Worker is shutting down");
    }
    if (!registered_) {
      // We need to reject requests until the worker has registered with the
      // dispatcher, so that we don't return NOT_FOUND for tasks that the worker
      // had before preemption.
      return errors::Unavailable(
          "Worker has not yet registered with dispatcher.");
    }
    auto it = tasks_.find(request->task_id());
    if (it == tasks_.end()) {
      if (finished_tasks_.contains(request->task_id())) {
        VLOG(3) << "Task is already finished";
        result->end_of_sequence = true;
        result->skip = false;
        return Status::OK();
      } else {
        // Perhaps the worker hasn't gotten the task from the dispatcher yet.
        // Return Unavailable so that the client knows to continue retrying.
        return errors::Unavailable("Task ", request->task_id(), " not found");
      }
    }
    task = it->second.get();
    TF_RETURN_IF_ERROR(EnsureTaskInitialized(*task));
    task->outstanding_requests++;
  }
  auto cleanup = gtl::MakeCleanup([&] {
    mutex_lock l(mu_);
    task->outstanding_requests--;
    cv_.notify_all();
  });
  TF_RETURN_IF_ERROR(task->task_runner->GetNext(*request, *result));

  if (result->end_of_sequence) {
    mutex_lock l(mu_);
    VLOG(3) << "Reached end_of_sequence for task " << request->task_id();
    pending_completed_tasks_.insert(request->task_id());
    task_completion_cv_.notify_one();
  }
  return Status::OK();
}

Status DataServiceWorkerImpl::ProcessTask(const ProcessTaskRequest* request,
                                          ProcessTaskResponse* response) {
  mutex_lock l(mu_);
  const TaskDef& task = request->task();
  VLOG(3) << "Received request to process task " << task.task_id();
  return ProcessTaskInternal(task);
}

Status DataServiceWorkerImpl::ProcessTaskInternal(const TaskDef& task_def)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  std::shared_ptr<Task>& task = tasks_[task_def.task_id()];
  if (task) {
    VLOG(1) << "Received request to process already-processed task "
            << task->task_def.task_id();
    return Status::OK();
  }
  task = absl::make_unique<Task>(task_def);
  VLOG(3) << "Began processing for task " << task_def.task_id()
          << " with processing mode "
          << task_def.processing_mode_def().DebugString();
  return Status::OK();
}

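// Lazily initializes the task on first use, building its dataset, iterator,
// and task runner.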
Status DataServiceWorkerImpl::EnsureTaskInitialized(
    DataServiceWorkerImpl::Task& task) {
  if (task.task_def.worker_address() != worker_address_) {
    return errors::Internal(absl::Substitute(
        "Dispatcher's worker address $0 does not match worker's address $1.",
        task.task_def.worker_address(), worker_address_));
  }

  mutex_lock l(task.mu);
  if (task.initialized) {
    return Status::OK();
  }
  TF_ASSIGN_OR_RETURN(DatasetDef dataset_def, GetDatasetDef(task.task_def));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<standalone::Dataset> dataset,
                      MakeDataset(dataset_def, task.task_def));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<standalone::Iterator> iterator,
                      MakeDatasetIterator(*dataset, task.task_def));
  auto task_iterator = absl::make_unique<StandaloneTaskIterator>(
      std::move(dataset), std::move(iterator));
  TF_RETURN_IF_ERROR(TaskRunner::Create(
      config_, task.task_def, std::move(task_iterator), task.task_runner));

  task.initialized = true;
  VLOG(3) << "Created iterator for task " << task.task_def.task_id();
  return Status::OK();
}

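// Resolves the dataset definition for the task: either embedded in the
// TaskDef, read from the path it specifies, or fetched from the dispatcher as
// a fallback.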
StatusOr<DatasetDef> DataServiceWorkerImpl::GetDatasetDef(
    const TaskDef& task_def) const {
  switch (task_def.dataset_case()) {
    case TaskDef::kDatasetDef:
      return task_def.dataset_def();
    case TaskDef::kPath: {
      DatasetDef def;
      Status s = ReadDatasetDef(task_def.path(), def);
      if (!s.ok()) {
        LOG(INFO) << "Failed to read dataset from " << task_def.path() << ": "
                  << s << ". Falling back to reading from dispatcher.";
        TF_RETURN_IF_ERROR(
            dispatcher_->GetDatasetDef(task_def.dataset_id(), def));
      }
      return def;
    }
    case TaskDef::DATASET_NOT_SET:
      return errors::Internal("Unrecognized dataset case: ",
                              task_def.dataset_case());
  }
}

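// Builds a standalone::Dataset from the task's dataset graph, applying the
// auto-shard rewrite where applicable.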
StatusOr<std::unique_ptr<standalone::Dataset>>
DataServiceWorkerImpl::MakeDataset(const DatasetDef& dataset_def,
                                   const TaskDef& task_def) const {
  TF_ASSIGN_OR_RETURN(AutoShardRewriter auto_shard_rewriter,
                      AutoShardRewriter::Create(task_def));
  // `ApplyAutoShardRewrite` does nothing if auto-sharding is disabled.
  TF_ASSIGN_OR_RETURN(
      GraphDef rewritten_graph,
      auto_shard_rewriter.ApplyAutoShardRewrite(dataset_def.graph()));
  std::unique_ptr<standalone::Dataset> dataset;
  TF_RETURN_IF_ERROR(standalone::Dataset::FromGraph(
      standalone::Dataset::Params(), rewritten_graph, &dataset));
  return dataset;
}

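// Creates an iterator over the dataset. For dynamic sharding, splits are
// requested from the dispatcher via DataServiceSplitProvider.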
StatusOr<std::unique_ptr<standalone::Iterator>>
DataServiceWorkerImpl::MakeDatasetIterator(standalone::Dataset& dataset,
                                           const TaskDef& task_def) const {
  std::unique_ptr<standalone::Iterator> iterator;
  if (IsNoShard(task_def.processing_mode_def()) ||
      IsStaticShard(task_def.processing_mode_def())) {
    TF_RETURN_IF_ERROR(dataset.MakeIterator(&iterator));
    return iterator;
  }

  if (IsDynamicShard(task_def.processing_mode_def())) {
    std::vector<std::unique_ptr<SplitProvider>> split_providers;
    split_providers.reserve(task_def.num_split_providers());
    for (int i = 0; i < task_def.num_split_providers(); ++i) {
      split_providers.push_back(absl::make_unique<DataServiceSplitProvider>(
          config_.dispatcher_address(), config_.protocol(), task_def.job_id(),
          i, config_.dispatcher_timeout_ms()));
    }
    TF_RETURN_IF_ERROR(
        dataset.MakeIterator(std::move(split_providers), &iterator));
    return iterator;
  }

  return errors::InvalidArgument("Unrecognized processing mode: ",
                                 task_def.processing_mode_def().DebugString());
}

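// Cancels the task's runner and blocks until all of its outstanding requests
// have completed.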
void DataServiceWorkerImpl::StopTask(Task& task) TF_LOCKS_EXCLUDED(mu_) {
  {
    mutex_lock l(task.mu);
    task.initialized = true;
  }
  if (task.task_runner) {
    task.task_runner->Cancel();
  }
  mutex_lock l(mu_);
  while (task.outstanding_requests > 0) {
    cv_.wait(l);
  }
}

Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
                                         GetElementResponse* response) {
  VLOG(3) << "Received GetElement request for task " << request->task_id();
  struct GetElementResult result;
  TF_RETURN_IF_ERROR(GetElementResult(request, &result));
  response->set_end_of_sequence(result.end_of_sequence);
  response->set_skip_task(result.skip);
  if (!response->end_of_sequence() && !response->skip_task()) {
    TF_RETURN_IF_ERROR(
        MoveElementToResponse(std::move(result.components), *response));
    VLOG(3) << "Producing an element for task " << request->task_id();
  }
  return Status::OK();
}

Status DataServiceWorkerImpl::GetWorkerTasks(
    const GetWorkerTasksRequest* request, GetWorkerTasksResponse* response) {
  mutex_lock l(mu_);
  for (const auto& it : tasks_) {
    Task* task = it.second.get();
    TaskInfo* task_info = response->add_tasks();
    task_info->set_worker_address(worker_address_);
    task_info->set_task_id(task->task_def.task_id());
    task_info->set_job_id(task->task_def.job_id());
  }
  return Status::OK();
}

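// Background thread that reports completed tasks to the dispatcher, retrying
// failed updates after kRetryIntervalMicros.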
void DataServiceWorkerImpl::TaskCompletionThread() TF_LOCKS_EXCLUDED(mu_) {
  while (true) {
    {
      mutex_lock l(mu_);
      while (!cancelled_ && pending_completed_tasks_.empty()) {
        task_completion_cv_.wait(l);
      }
      if (cancelled_) {
        VLOG(3) << "Task completion thread shutting down";
        return;
      }
    }
    Status s = SendTaskUpdates();
    if (!s.ok()) {
      LOG(WARNING) << "Failed to send task updates to dispatcher: " << s;
      mutex_lock l(mu_);
      if (!cancelled_) {
        task_completion_cv_.wait_for(
            l, std::chrono::microseconds(kRetryIntervalMicros));
      }
    }
  }
}

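// Sends completion updates for all pending completed tasks to the dispatcher,
// then removes them from the pending set on success.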
Status DataServiceWorkerImpl::SendTaskUpdates() TF_LOCKS_EXCLUDED(mu_) {
  std::vector<TaskProgress> task_progress;
  {
    mutex_lock l(mu_);
    VLOG(3) << "Sending " << pending_completed_tasks_.size()
            << " task updates to dispatcher";
    task_progress.reserve(pending_completed_tasks_.size());
    for (int task_id : pending_completed_tasks_) {
      task_progress.emplace_back();
      task_progress.back().set_task_id(task_id);
      task_progress.back().set_completed(true);
    }
  }

  TF_RETURN_IF_ERROR(dispatcher_->WorkerUpdate(worker_address_, task_progress));
  mutex_lock l(mu_);
  for (const auto& update : task_progress) {
    pending_completed_tasks_.erase(update.task_id());
  }
  VLOG(3) << "Sent " << task_progress.size() << " task updates ";
  return Status::OK();
}

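// Background thread that sends a heartbeat to the dispatcher every
// heartbeat_interval_ms once the worker has registered.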
void DataServiceWorkerImpl::HeartbeatThread() TF_LOCKS_EXCLUDED(mu_) {
  while (true) {
    int64_t next_heartbeat_micros =
        Env::Default()->NowMicros() + (config_.heartbeat_interval_ms() * 1000);
    {
      mutex_lock l(mu_);
      while (!cancelled_ &&
             Env::Default()->NowMicros() < next_heartbeat_micros) {
        int64_t time_to_wait_micros =
            next_heartbeat_micros - Env::Default()->NowMicros();
        heartbeat_cv_.wait_for(l,
                               std::chrono::microseconds(time_to_wait_micros));
      }
      if (cancelled_) {
        VLOG(3) << "Heartbeat thread shutting down";
        return;
      }
      if (!registered_) {
        VLOG(1) << "Not performing heartbeat; worker is not yet registered";
        continue;
      }
    }
    Status s = Heartbeat();
    if (!s.ok()) {
      LOG(WARNING) << "Failed to send heartbeat to dispatcher: " << s;
    }
  }
}

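// Reports the worker's current tasks to the dispatcher and applies the
// response: new tasks are started, and deleted tasks are stopped and marked
// finished.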
Status DataServiceWorkerImpl::Heartbeat() TF_LOCKS_EXCLUDED(mu_) {
  std::vector<int64> current_tasks;
  {
    mutex_lock l(mu_);
    for (const auto& task : tasks_) {
      current_tasks.push_back(task.first);
    }
  }
  std::vector<TaskDef> new_tasks;
  std::vector<int64> task_ids_to_delete;
  TF_RETURN_IF_ERROR(dispatcher_->WorkerHeartbeat(
      worker_address_, transfer_address_, current_tasks, new_tasks,
      task_ids_to_delete));
  std::vector<std::shared_ptr<Task>> tasks_to_delete;
  {
    mutex_lock l(mu_);
    for (const auto& task : new_tasks) {
      VLOG(1) << "Received new task from dispatcher with id " << task.task_id();
      Status s = ProcessTaskInternal(task);
      if (!s.ok() && !errors::IsAlreadyExists(s)) {
        LOG(WARNING) << "Failed to start processing task " << task.task_id()
                     << ": " << s;
      }
    }
    tasks_to_delete.reserve(task_ids_to_delete.size());
    for (int64_t task_id : task_ids_to_delete) {
      VLOG(3) << "Deleting task " << task_id
              << " at the request of the dispatcher";
      tasks_to_delete.push_back(std::move(tasks_[task_id]));
      tasks_.erase(task_id);
      finished_tasks_.insert(task_id);
    }
  }
  for (const auto& task : tasks_to_delete) {
    StopTask(*task);
  }
  return Status::OK();
}

void LocalWorkers::Add(absl::string_view worker_address,
                       std::shared_ptr<DataServiceWorkerImpl> worker) {
  DCHECK(worker != nullptr) << "Adding a nullptr local worker is disallowed.";
  VLOG(1) << "Register local worker at address " << worker_address;
  mutex_lock l(mu_);
  (*local_workers_)[worker_address] = worker;
}

std::shared_ptr<DataServiceWorkerImpl> LocalWorkers::Get(
    absl::string_view worker_address) {
  tf_shared_lock l(mu_);
  AddressToWorkerMap::const_iterator it = local_workers_->find(worker_address);
  if (it == local_workers_->end()) {
    return nullptr;
  }
  return it->second;
}

bool LocalWorkers::Empty() {
  tf_shared_lock l(mu_);
  return local_workers_->empty();
}

void LocalWorkers::Remove(absl::string_view worker_address) {
  VLOG(1) << "Remove local worker at address " << worker_address;
  mutex_lock l(mu_);
  local_workers_->erase(worker_address);
}

}  // namespace data
}  // namespace tensorflow