/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/data/service/worker_impl.h"

#include <memory>
#include <string>
#include <utility>

#include "grpcpp/create_channel.h"
#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/data/dataset.pb.h"
#include "tensorflow/core/data/service/auto_shard_rewriter.h"
#include "tensorflow/core/data/service/common.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/credentials_factory.h"
#include "tensorflow/core/data/service/data_transfer.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/data/service/dispatcher_client.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/split_provider.h"
#include "tensorflow/core/data/service/task_runner.h"
#include "tensorflow/core/data/service/utils.h"
#include "tensorflow/core/data/service/worker.pb.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/framework/dataset_options.pb.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/snappy.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace data {
namespace {

constexpr uint64 kRetryIntervalMicros = 5ull * 1000 * 1000;

// Moves the element into the response. If the tensor contains a single
// CompressedElement variant, the move will be zero-copy. Otherwise, the tensor
// data will be serialized as TensorProtos.
Status MoveElementToResponse(std::vector<Tensor>&& element,
                             GetElementResponse& resp) {
  if (element.size() != 1 || element[0].dtype() != DT_VARIANT ||
      !TensorShapeUtils::IsScalar(element[0].shape())) {
    for (const auto& component : element) {
      UncompressedElement* uncompressed = resp.mutable_uncompressed();
      component.AsProtoTensorContent(uncompressed->add_components());
    }
    return Status::OK();
  }
  Variant& variant = element[0].scalar<Variant>()();
  CompressedElement* compressed = variant.get<CompressedElement>();
  if (compressed == nullptr) {
    return errors::FailedPrecondition(
        "Expected dataset to produce a CompressedElement variant tensor, but "
        "it produced ",
        variant.TypeName());
  }
  *resp.mutable_compressed() = *compressed;
  return Status::OK();
}
}  // namespace

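// Registry of tf.data service workers running in this process, keyed by
// worker address (see LocalWorkers::Add/Get/Remove below).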
mutex LocalWorkers::mu_(LINKER_INITIALIZED);
LocalWorkers::AddressToWorkerMap* LocalWorkers::local_workers_ =
    new AddressToWorkerMap();

DataServiceWorkerImpl::DataServiceWorkerImpl(
    const experimental::WorkerConfig& config)
    : config_(config) {
  metrics::RecordTFDataServiceWorkerCreated();
}

DataServiceWorkerImpl::~DataServiceWorkerImpl() {
  mutex_lock l(mu_);
  cancelled_ = true;
  task_completion_cv_.notify_one();
  heartbeat_cv_.notify_one();
}

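// Registers the worker with the dispatcher, retrying on transient
// (Unavailable/Aborted/Cancelled) errors, then launches the background
// task-completion and heartbeat threads.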
Status DataServiceWorkerImpl::Start(const std::string& worker_address,
                                    const std::string& transfer_address) {
  VLOG(3) << "Starting tf.data service worker at address " << worker_address;
  worker_address_ = worker_address;
  transfer_address_ = transfer_address;

  dispatcher_ = absl::make_unique<DataServiceDispatcherClient>(
      config_.dispatcher_address(), config_.protocol());
  TF_RETURN_IF_ERROR(dispatcher_->Initialize());

  Status s = Heartbeat();
  while (!s.ok()) {
    if (!errors::IsUnavailable(s) && !errors::IsAborted(s) &&
        !errors::IsCancelled(s)) {
      return s;
    }
    LOG(WARNING) << "Failed to register with dispatcher at "
                 << config_.dispatcher_address() << ": " << s;
    Env::Default()->SleepForMicroseconds(kRetryIntervalMicros);
    s = Heartbeat();
  }
  LOG(INFO) << "Worker registered with dispatcher running at "
            << config_.dispatcher_address();
  task_completion_thread_ = absl::WrapUnique(
      Env::Default()->StartThread({}, "data-service-worker-task-completion",
                                  [this]() { TaskCompletionThread(); }));
  heartbeat_thread_ = absl::WrapUnique(Env::Default()->StartThread(
      {}, "data-service-worker-heartbeat", [this]() { HeartbeatThread(); }));
  mutex_lock l(mu_);
  registered_ = true;
  return Status::OK();
}

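// Cancels all running tasks, then waits out the configured quiet period so
// that responses already handed to the gRPC server can finish draining.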
void DataServiceWorkerImpl::Stop() {
  std::vector<std::shared_ptr<Task>> tasks;
  {
    mutex_lock l(mu_);
    cancelled_ = true;
    for (const auto& entry : tasks_) {
      tasks.push_back(entry.second);
    }
  }
  for (auto& task : tasks) {
    StopTask(*task);
  }
  // At this point there are no outstanding requests in this RPC handler.
  // However, requests successfully returned from this RPC handler may still be
  // in progress within the gRPC server. If we shut down the gRPC server
  // immediately, it could cause these requests to fail, e.g. with broken pipe.
  // To mitigate this, we sleep for some time to give the gRPC server time to
  // complete requests.
  Env::Default()->SleepForMicroseconds(config_.shutdown_quiet_period_ms() *
                                       1000);
}

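// Fetches the next element for the requested task. Rejects requests until the
// worker has registered, reports end_of_sequence for already-finished tasks,
// and queues a task-completion update when a task reaches end of sequence.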
Status DataServiceWorkerImpl::GetElementResult(
    const GetElementRequest* request, struct GetElementResult* result) {
  Task* task;
  {
    mutex_lock l(mu_);
    if (cancelled_) {
      return errors::Cancelled("Worker is shutting down");
    }
    if (!registered_) {
      // We need to reject requests until the worker has registered with the
      // dispatcher, so that we don't return NOT_FOUND for tasks that the
      // worker had before preemption.
      return errors::Unavailable(
          "Worker has not yet registered with dispatcher.");
    }
    auto it = tasks_.find(request->task_id());
    if (it == tasks_.end()) {
      if (finished_tasks_.contains(request->task_id())) {
        VLOG(3) << "Task is already finished";
        result->end_of_sequence = true;
        result->skip = false;
        return Status::OK();
      } else {
        // Perhaps the worker hasn't gotten the task from the dispatcher yet.
        // Return Unavailable so that the client knows to continue retrying.
        return errors::Unavailable("Task ", request->task_id(), " not found");
      }
    }
    task = it->second.get();
    TF_RETURN_IF_ERROR(EnsureTaskInitialized(*task));
    task->outstanding_requests++;
  }
  auto cleanup = gtl::MakeCleanup([&] {
    mutex_lock l(mu_);
    task->outstanding_requests--;
    cv_.notify_all();
  });
  TF_RETURN_IF_ERROR(task->task_runner->GetNext(*request, *result));

  if (result->end_of_sequence) {
    mutex_lock l(mu_);
    VLOG(3) << "Reached end_of_sequence for task " << request->task_id();
    pending_completed_tasks_.insert(request->task_id());
    task_completion_cv_.notify_one();
  }
  return Status::OK();
}

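// Handles a ProcessTask request: registers the task described in the request
// so the worker can start serving it.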
Status DataServiceWorkerImpl::ProcessTask(const ProcessTaskRequest* request,
                                          ProcessTaskResponse* response) {
  mutex_lock l(mu_);
  const TaskDef& task = request->task();
  VLOG(3) << "Received request to process task " << task.task_id();
  return ProcessTaskInternal(task);
}

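// Records a new task in `tasks_`. An already-registered task is left
// untouched; iterator creation is deferred until the first element request
// (see EnsureTaskInitialized).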
Status DataServiceWorkerImpl::ProcessTaskInternal(const TaskDef& task_def)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  std::shared_ptr<Task>& task = tasks_[task_def.task_id()];
  if (task) {
    VLOG(1) << "Received request to process already-processed task "
            << task->task_def.task_id();
    return Status::OK();
  }
  task = absl::make_unique<Task>(task_def);
  VLOG(3) << "Began processing for task " << task_def.task_id()
          << " with processing mode "
          << task_def.processing_mode_def().DebugString();
  return Status::OK();
}

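// Lazily creates the dataset, iterator, and task runner for `task` the first
// time they are needed, verifying that the task was addressed to this worker.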
Status DataServiceWorkerImpl::EnsureTaskInitialized(
    DataServiceWorkerImpl::Task& task) {
  if (task.task_def.worker_address() != worker_address_) {
    return errors::Internal(absl::Substitute(
        "Dispatcher's worker address $0 does not match worker's address $1.",
        task.task_def.worker_address(), worker_address_));
  }

  mutex_lock l(task.mu);
  if (task.initialized) {
    return Status::OK();
  }
  TF_ASSIGN_OR_RETURN(DatasetDef dataset_def, GetDatasetDef(task.task_def));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<standalone::Dataset> dataset,
                      MakeDataset(dataset_def, task.task_def));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<standalone::Iterator> iterator,
                      MakeDatasetIterator(*dataset, task.task_def));
  auto task_iterator = absl::make_unique<StandaloneTaskIterator>(
      std::move(dataset), std::move(iterator));
  TF_RETURN_IF_ERROR(TaskRunner::Create(
      config_, task.task_def, std::move(task_iterator), task.task_runner));

  task.initialized = true;
  VLOG(3) << "Created iterator for task " << task.task_def.task_id();
  return Status::OK();
}

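// Resolves the dataset definition for a task: either taken directly from the
// TaskDef, or read from the given path with a fallback to fetching it from
// the dispatcher.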
StatusOr<DatasetDef> DataServiceWorkerImpl::GetDatasetDef(
    const TaskDef& task_def) const {
  switch (task_def.dataset_case()) {
    case TaskDef::kDatasetDef:
      return task_def.dataset_def();
    case TaskDef::kPath: {
      DatasetDef def;
      Status s = ReadDatasetDef(task_def.path(), def);
      if (!s.ok()) {
        LOG(INFO) << "Failed to read dataset from " << task_def.path() << ": "
                  << s << ". Falling back to reading from dispatcher.";
        TF_RETURN_IF_ERROR(
            dispatcher_->GetDatasetDef(task_def.dataset_id(), def));
      }
      return def;
    }
    case TaskDef::DATASET_NOT_SET:
      return errors::Internal("Unrecognized dataset case: ",
                              task_def.dataset_case());
  }
}

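// Builds a standalone dataset from the dataset graph, applying the auto-shard
// rewrite when auto-sharding is enabled for the task.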
StatusOr<std::unique_ptr<standalone::Dataset>>
DataServiceWorkerImpl::MakeDataset(const DatasetDef& dataset_def,
                                   const TaskDef& task_def) const {
  TF_ASSIGN_OR_RETURN(AutoShardRewriter auto_shard_rewriter,
                      AutoShardRewriter::Create(task_def));
  // `ApplyAutoShardRewrite` does nothing if auto-sharding is disabled.
  TF_ASSIGN_OR_RETURN(
      GraphDef rewritten_graph,
      auto_shard_rewriter.ApplyAutoShardRewrite(dataset_def.graph()));
  std::unique_ptr<standalone::Dataset> dataset;
  TF_RETURN_IF_ERROR(standalone::Dataset::FromGraph(
      standalone::Dataset::Params(), rewritten_graph, &dataset));
  return dataset;
}

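// Creates an iterator for the dataset. Dynamically sharded tasks read their
// splits from the dispatcher via DataServiceSplitProvider; no-shard and
// static-shard tasks iterate over the dataset directly.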
StatusOr<std::unique_ptr<standalone::Iterator>>
DataServiceWorkerImpl::MakeDatasetIterator(standalone::Dataset& dataset,
                                           const TaskDef& task_def) const {
  std::unique_ptr<standalone::Iterator> iterator;
  if (IsNoShard(task_def.processing_mode_def()) ||
      IsStaticShard(task_def.processing_mode_def())) {
    TF_RETURN_IF_ERROR(dataset.MakeIterator(&iterator));
    return iterator;
  }

  if (IsDynamicShard(task_def.processing_mode_def())) {
    std::vector<std::unique_ptr<SplitProvider>> split_providers;
    split_providers.reserve(task_def.num_split_providers());
    for (int i = 0; i < task_def.num_split_providers(); ++i) {
      split_providers.push_back(absl::make_unique<DataServiceSplitProvider>(
          config_.dispatcher_address(), config_.protocol(), task_def.job_id(),
          i, config_.dispatcher_timeout_ms()));
    }
    TF_RETURN_IF_ERROR(
        dataset.MakeIterator(std::move(split_providers), &iterator));
    return iterator;
  }

  return errors::InvalidArgument("Unrecognized processing mode: ",
                                 task_def.processing_mode_def().DebugString());
}

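// Cancels the task's runner and blocks until all outstanding requests for the
// task have drained.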
void DataServiceWorkerImpl::StopTask(Task& task) TF_LOCKS_EXCLUDED(mu_) {
  {
    mutex_lock l(task.mu);
    task.initialized = true;
  }
  if (task.task_runner) {
    task.task_runner->Cancel();
  }
  mutex_lock l(mu_);
  while (task.outstanding_requests > 0) {
    cv_.wait(l);
  }
}

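// RPC handler that wraps GetElementResult and copies the result into the
// GetElementResponse proto.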
Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
                                         GetElementResponse* response) {
  VLOG(3) << "Received GetElement request for task " << request->task_id();
  struct GetElementResult result;
  TF_RETURN_IF_ERROR(GetElementResult(request, &result));
  response->set_end_of_sequence(result.end_of_sequence);
  response->set_skip_task(result.skip);
  if (!response->end_of_sequence() && !response->skip_task()) {
    TF_RETURN_IF_ERROR(
        MoveElementToResponse(std::move(result.components), *response));
    VLOG(3) << "Producing an element for task " << request->task_id();
  }
  return Status::OK();
}

Status DataServiceWorkerImpl::GetWorkerTasks(
    const GetWorkerTasksRequest* request, GetWorkerTasksResponse* response) {
  mutex_lock l(mu_);
  for (const auto& it : tasks_) {
    Task* task = it.second.get();
    TaskInfo* task_info = response->add_tasks();
    task_info->set_worker_address(worker_address_);
    task_info->set_task_id(task->task_def.task_id());
    task_info->set_job_id(task->task_def.job_id());
  }
  return Status::OK();
}

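// Background thread that reports completed tasks to the dispatcher, retrying
// after a delay if an update fails.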
void DataServiceWorkerImpl::TaskCompletionThread() TF_LOCKS_EXCLUDED(mu_) {
  while (true) {
    {
      mutex_lock l(mu_);
      while (!cancelled_ && pending_completed_tasks_.empty()) {
        task_completion_cv_.wait(l);
      }
      if (cancelled_) {
        VLOG(3) << "Task completion thread shutting down";
        return;
      }
    }
    Status s = SendTaskUpdates();
    if (!s.ok()) {
      LOG(WARNING) << "Failed to send task updates to dispatcher: " << s;
      mutex_lock l(mu_);
      if (!cancelled_) {
        task_completion_cv_.wait_for(
            l, std::chrono::microseconds(kRetryIntervalMicros));
      }
    }
  }
}

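// Sends a WorkerUpdate covering the currently pending completed tasks and
// clears the ones that were successfully reported.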
Status DataServiceWorkerImpl::SendTaskUpdates() TF_LOCKS_EXCLUDED(mu_) {
  std::vector<TaskProgress> task_progress;
  {
    mutex_lock l(mu_);
    VLOG(3) << "Sending " << pending_completed_tasks_.size()
            << " task updates to dispatcher";
    task_progress.reserve(pending_completed_tasks_.size());
    for (int task_id : pending_completed_tasks_) {
      task_progress.emplace_back();
      task_progress.back().set_task_id(task_id);
      task_progress.back().set_completed(true);
    }
  }

  TF_RETURN_IF_ERROR(dispatcher_->WorkerUpdate(worker_address_, task_progress));
  mutex_lock l(mu_);
  for (const auto& update : task_progress) {
    pending_completed_tasks_.erase(update.task_id());
  }
  VLOG(3) << "Sent " << task_progress.size() << " task updates ";
  return Status::OK();
}

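// Background thread that sends a heartbeat to the dispatcher once per
// configured heartbeat interval, skipping heartbeats until the worker has
// registered.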
void DataServiceWorkerImpl::HeartbeatThread() TF_LOCKS_EXCLUDED(mu_) {
  while (true) {
    int64_t next_heartbeat_micros =
        Env::Default()->NowMicros() + (config_.heartbeat_interval_ms() * 1000);
    {
      mutex_lock l(mu_);
      while (!cancelled_ &&
             Env::Default()->NowMicros() < next_heartbeat_micros) {
        int64_t time_to_wait_micros =
            next_heartbeat_micros - Env::Default()->NowMicros();
        heartbeat_cv_.wait_for(l,
                               std::chrono::microseconds(time_to_wait_micros));
      }
      if (cancelled_) {
        VLOG(3) << "Heartbeat thread shutting down";
        return;
      }
      if (!registered_) {
        VLOG(1) << "Not performing heartbeat; worker is not yet registered";
        continue;
      }
    }
    Status s = Heartbeat();
    if (!s.ok()) {
      LOG(WARNING) << "Failed to send heartbeat to dispatcher: " << s;
    }
  }
}

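// Performs a single heartbeat: reports the worker's current tasks to the
// dispatcher, registers any newly assigned tasks, and stops tasks the
// dispatcher asked to delete.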
Status DataServiceWorkerImpl::Heartbeat() TF_LOCKS_EXCLUDED(mu_) {
  std::vector<int64> current_tasks;
  {
    mutex_lock l(mu_);
    for (const auto& task : tasks_) {
      current_tasks.push_back(task.first);
    }
  }
  std::vector<TaskDef> new_tasks;
  std::vector<int64> task_ids_to_delete;
  TF_RETURN_IF_ERROR(dispatcher_->WorkerHeartbeat(
      worker_address_, transfer_address_, current_tasks, new_tasks,
      task_ids_to_delete));
  std::vector<std::shared_ptr<Task>> tasks_to_delete;
  {
    mutex_lock l(mu_);
    for (const auto& task : new_tasks) {
      VLOG(1) << "Received new task from dispatcher with id " << task.task_id();
      Status s = ProcessTaskInternal(task);
      if (!s.ok() && !errors::IsAlreadyExists(s)) {
        LOG(WARNING) << "Failed to start processing task " << task.task_id()
                     << ": " << s;
      }
    }
    tasks_to_delete.reserve(task_ids_to_delete.size());
    for (int64_t task_id : task_ids_to_delete) {
      VLOG(3) << "Deleting task " << task_id
              << " at the request of the dispatcher";
      tasks_to_delete.push_back(std::move(tasks_[task_id]));
      tasks_.erase(task_id);
      finished_tasks_.insert(task_id);
    }
  }
  for (const auto& task : tasks_to_delete) {
    StopTask(*task);
  }
  return Status::OK();
}

void LocalWorkers::Add(absl::string_view worker_address,
                       std::shared_ptr<DataServiceWorkerImpl> worker) {
  DCHECK(worker != nullptr) << "Adding a nullptr local worker is disallowed.";
  VLOG(1) << "Register local worker at address " << worker_address;
  mutex_lock l(mu_);
  (*local_workers_)[worker_address] = worker;
}

std::shared_ptr<DataServiceWorkerImpl> LocalWorkers::Get(
    absl::string_view worker_address) {
  tf_shared_lock l(mu_);
  AddressToWorkerMap::const_iterator it = local_workers_->find(worker_address);
  if (it == local_workers_->end()) {
    return nullptr;
  }
  return it->second;
}

bool LocalWorkers::Empty() {
  tf_shared_lock l(mu_);
  return local_workers_->empty();
}

void LocalWorkers::Remove(absl::string_view worker_address) {
  VLOG(1) << "Remove local worker at address " << worker_address;
  mutex_lock l(mu_);
  local_workers_->erase(worker_address);
}

}  // namespace data
}  // namespace tensorflow