/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/data/service/dispatcher_impl.h"

#include <array>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#ifdef PLATFORM_GOOGLE
#include "file/logging/log_lines.h"
#endif
#include "grpcpp/create_channel.h"
#include "grpcpp/impl/codegen/server_context.h"
#include "grpcpp/security/credentials.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/types/optional.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/hash_utils.h"
#include "tensorflow/core/data/service/common.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/credentials_factory.h"
#include "tensorflow/core/data/service/dataset_store.h"
#include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/data/service/dispatcher_state.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/journal.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/protobuf/data_service.pb.h"
#include "tensorflow/core/protobuf/service_config.pb.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace data {
namespace {

using ::tensorflow::protobuf::util::MessageDifferencer;

// The name of the journal directory inside the dispatcher's working directory.
// This name is load-bearing; do not change.
constexpr char kJournalDir[] = "tf_data_dispatcher_journal";
// The name of the datasets directory inside the dispatcher's working directory.
constexpr char kDatasetsDir[] = "datasets";

constexpr std::array<const char*, 8> kNodeNameSharingOps = {
    "HashTable",
    "HashTableV2",
    "MutableHashTable",
    "MutableHashTableV2",
    "MutableDenseHashTable",
    "MutableDenseHashTableV2",
    "MutableHashTableOfTensors",
    "MutableHashTableOfTensorsV2",
};

using Dataset = DispatcherState::Dataset;
using Worker = DispatcherState::Worker;
using NamedJobKey = DispatcherState::NamedJobKey;
using Job = DispatcherState::Job;
using Task = DispatcherState::Task;

std::string JournalDir(const std::string& work_dir) {
  return io::JoinPath(work_dir, kJournalDir);
}

std::string DatasetsDir(const std::string& work_dir) {
  return io::JoinPath(work_dir, kDatasetsDir);
}
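
// For example, with work_dir = "/tmp/dispatcher" (a hypothetical path), the
// helpers above resolve to "/tmp/dispatcher/tf_data_dispatcher_journal" and
// "/tmp/dispatcher/datasets" respectively.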

std::string DatasetKey(int64_t id, uint64 fingerprint) {
  return absl::StrCat("id_", id, "_fp_", fingerprint);
}
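
// For example, DatasetKey(5, 0x1234) yields "id_5_fp_4660", since
// absl::StrCat renders the uint64 fingerprint in decimal.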

Status CreateWorkerStub(const std::string& address, const std::string& protocol,
                        std::unique_ptr<WorkerService::Stub>& stub) {
  ::grpc::ChannelArguments args;
  args.SetMaxReceiveMessageSize(-1);
  std::shared_ptr<::grpc::ChannelCredentials> credentials;
  TF_RETURN_IF_ERROR(
      CredentialsFactory::CreateClientCredentials(protocol, &credentials));
  auto channel = ::grpc::CreateCustomChannel(address, credentials, args);
  stub = WorkerService::NewStub(channel);
  return Status::OK();
}

void PrepareGraph(GraphDef* graph) {
  for (NodeDef& node : *graph->mutable_node()) {
    for (const auto& op : kNodeNameSharingOps) {
      // Set `use_node_name_sharing` to `true` so that resources aren't deleted
      // prematurely. Otherwise, resources may be deleted when their ops are
      // deleted at the end of the GraphRunner::Run used by standalone::Dataset.
      if (node.op() == op) {
        (*node.mutable_attr())["use_node_name_sharing"].set_b(true);
      }
      if (!node.device().empty()) {
        *node.mutable_device() = "";
      }
    }
  }
  StripDevicePlacement(graph->mutable_library());
}
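
// Illustrative sketch (not from any real graph): a NodeDef such as
//   name: "table"  op: "HashTableV2"  device: "/device:CPU:0"
// leaves PrepareGraph with its device cleared and with the attr
// use_node_name_sharing = true, so the table resource outlives the
// GraphRunner::Run invoked by standalone::Dataset.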
}  // namespace

DataServiceDispatcherImpl::DataServiceDispatcherImpl(
    const experimental::DispatcherConfig& config)
    : config_(config), env_(Env::Default()), state_(config_) {
  if (config_.work_dir().empty()) {
    dataset_store_ = absl::make_unique<MemoryDatasetStore>();
  } else {
    dataset_store_ = absl::make_unique<FileSystemDatasetStore>(
        DatasetsDir(config_.work_dir()));
  }
}

DataServiceDispatcherImpl::~DataServiceDispatcherImpl() {
  {
    mutex_lock l(mu_);
    cancelled_ = true;
    job_gc_thread_cv_.notify_all();
  }
  job_gc_thread_.reset();
}

Status DataServiceDispatcherImpl::Start() {
  mutex_lock l(mu_);
  if (config_.job_gc_timeout_ms() >= 0) {
    job_gc_thread_ = absl::WrapUnique(
        env_->StartThread({}, "job-gc-thread", [&] { JobGcThread(); }));
  }
  if (config_.work_dir().empty()) {
    if (config_.fault_tolerant_mode()) {
      return errors::InvalidArgument(
          "fault_tolerant_mode is True, but no work_dir is configured.");
    }
  } else {
    TF_RETURN_IF_ERROR(
        env_->RecursivelyCreateDir(DatasetsDir(config_.work_dir())));
  }
  if (!config_.fault_tolerant_mode()) {
    LOG(INFO) << "Running with fault_tolerant_mode=False. The dispatcher will "
                 "not be able to recover its state on restart.";
    started_ = true;
    return Status::OK();
  }
  journal_writer_ = absl::make_unique<FileJournalWriter>(
      env_, JournalDir(config_.work_dir()));
  LOG(INFO) << "Attempting to restore dispatcher state from journal in "
            << JournalDir(config_.work_dir());
  Update update;
  bool end_of_journal = false;
  FileJournalReader reader(env_, JournalDir(config_.work_dir()));
  Status s = reader.Read(update, end_of_journal);
  if (errors::IsNotFound(s)) {
    LOG(INFO) << "No journal found. Starting dispatcher from new state.";
  } else if (!s.ok()) {
    return s;
  } else {
    while (!end_of_journal) {
      TF_RETURN_IF_ERROR(ApplyWithoutJournaling(update));
      TF_RETURN_IF_ERROR(reader.Read(update, end_of_journal));
    }
  }
  for (const auto& job : state_.ListJobs()) {
    if (IsDynamicShard(job->processing_mode)) {
      TF_RETURN_IF_ERROR(
          RestoreSplitProviders(*job, split_providers_[job->job_id]));
    }
  }
  // Initialize the journal writer in `Start` so that we fail fast in case it
  // can't be initialized.
  TF_RETURN_IF_ERROR(journal_writer_.value()->EnsureInitialized());
  started_ = true;
  return Status::OK();
}
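
// A minimal sketch of bringing up a fault-tolerant dispatcher; the work_dir
// path is hypothetical and error handling is elided:
//
//   experimental::DispatcherConfig config;
//   config.set_protocol("grpc");
//   config.set_work_dir("/tmp/dispatcher");  // hypothetical path
//   config.set_fault_tolerant_mode(true);
//   DataServiceDispatcherImpl dispatcher(config);
//   TF_CHECK_OK(dispatcher.Start());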

Status DataServiceDispatcherImpl::RestoreSplitProviders(
    const Job& job, std::vector<std::unique_ptr<SplitProvider>>& restored)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  const std::vector<int64>& indices =
      job.distributed_epoch_state.value().indices;
  std::vector<std::unique_ptr<SplitProvider>> split_providers;
  TF_RETURN_IF_ERROR(MakeSplitProviders(job.dataset_id, split_providers));
  for (int provider_index = 0; provider_index < indices.size();
       ++provider_index) {
    int index = indices[provider_index];
    VLOG(1) << "Restoring split provider " << provider_index << " for job "
            << job.job_id << " to index " << index;
    Tensor unused_tensor;
    bool unused_end_of_splits;
    for (int i = 0; i < index; ++i) {
      TF_RETURN_IF_ERROR(split_providers[provider_index]->GetNext(
          &unused_tensor, &unused_end_of_splits));
    }
  }
  restored = std::move(split_providers);
  return Status::OK();
}
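
// Restoration replays each split provider up to its journaled index. For
// example, with indices = {3, 0}, provider 0 is advanced past its first three
// splits (the values are discarded) while provider 1 stays at the start.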

Status DataServiceDispatcherImpl::FindTasksToDelete(
    const absl::flat_hash_set<int64>& current_tasks,
    const std::vector<std::shared_ptr<const Task>> assigned_tasks,
    WorkerHeartbeatResponse* response) {
  absl::flat_hash_set<int64> assigned_ids;
  for (const auto& assigned : assigned_tasks) {
    assigned_ids.insert(assigned->task_id);
  }
  for (int64_t current_task : current_tasks) {
    if (!assigned_ids.contains(current_task)) {
      response->add_tasks_to_delete(current_task);
    }
  }
  return Status::OK();
}
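
// For example, if a worker reports current_tasks = {1, 2, 3} while the
// dispatcher has only tasks {2, 3} assigned to it, the response tells the
// worker to delete task 1.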

Status DataServiceDispatcherImpl::FindNewTasks(
    const std::string& worker_address,
    const absl::flat_hash_set<int64>& current_tasks,
    std::vector<std::shared_ptr<const Task>>& assigned_tasks,
    WorkerHeartbeatResponse* response) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  // Check for round-robin jobs that had tasks on the worker removed. Now that
  // the worker is back, we create a new pending task for the worker.
  absl::flat_hash_set<int64> assigned_job_ids;
  for (const auto& task : assigned_tasks) {
    assigned_job_ids.insert(task->job->job_id);
  }
  for (const auto& job : state_.ListJobs()) {
    if (!assigned_job_ids.contains(job->job_id) && job->IsRoundRobin() &&
        !job->finished) {
      VLOG(1) << "Creating pending task for reconnected worker "
              << worker_address;
      TF_RETURN_IF_ERROR(CreatePendingTask(job, worker_address));
    }
  }
  // Refresh assigned_tasks to include newly added pending tasks.
  TF_RETURN_IF_ERROR(state_.TasksForWorker(worker_address, assigned_tasks));
  for (const auto& task : assigned_tasks) {
    if (current_tasks.contains(task->task_id)) {
      continue;
    }
    TaskDef* task_def = response->add_new_tasks();
    TF_RETURN_IF_ERROR(PopulateTaskDef(task, task_def));
  }
  return Status::OK();
}

Status DataServiceDispatcherImpl::WorkerHeartbeat(
    const WorkerHeartbeatRequest* request, WorkerHeartbeatResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  VLOG(4) << "Received worker heartbeat request from worker "
          << request->worker_address();
  mutex_lock l(mu_);
  const std::string& worker_address = request->worker_address();
  // Assigned tasks from the perspective of the dispatcher.
  std::vector<std::shared_ptr<const Task>> assigned_tasks;
  Status s = state_.TasksForWorker(worker_address, assigned_tasks);
  if (!s.ok()) {
    if (!errors::IsNotFound(s)) {
      return s;
    }
    VLOG(1) << "Registering new worker at address " << worker_address;
    TF_RETURN_IF_ERROR(state_.ValidateWorker(worker_address));
    Update update;
    update.mutable_register_worker()->set_worker_address(worker_address);
    update.mutable_register_worker()->set_transfer_address(
        request->transfer_address());
    TF_RETURN_IF_ERROR(Apply(update));
    TF_RETURN_IF_ERROR(CreateTasksForWorker(worker_address));
    TF_RETURN_IF_ERROR(state_.TasksForWorker(worker_address, assigned_tasks));
  }
  absl::flat_hash_set<int64> current_tasks;
  current_tasks.insert(request->current_tasks().cbegin(),
                       request->current_tasks().cend());
  TF_RETURN_IF_ERROR(
      FindTasksToDelete(current_tasks, assigned_tasks, response));
  TF_RETURN_IF_ERROR(
      FindNewTasks(worker_address, current_tasks, assigned_tasks, response));

  VLOG(4) << "Finished worker heartbeat for worker at address "
          << request->worker_address();
  return Status::OK();
}

Status DataServiceDispatcherImpl::WorkerUpdate(
    const WorkerUpdateRequest* request, WorkerUpdateResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  mutex_lock l(mu_);
  for (auto& update : request->updates()) {
    int64_t task_id = update.task_id();
    std::shared_ptr<const Task> task;
    TF_RETURN_IF_ERROR(state_.TaskFromId(task_id, task));
    if (update.completed()) {
      if (task->finished) {
        VLOG(1) << "Received completion update for already-finished task "
                << task->task_id << " on worker " << task->worker_address;
        continue;
      }
      Update update;
      update.mutable_finish_task()->set_task_id(task_id);
      TF_RETURN_IF_ERROR(Apply(update));
      VLOG(3) << "Task " << task_id << " from job " << task->job->job_id
              << " completed";
    }
  }
  return Status::OK();
}

Status DataServiceDispatcherImpl::GetDatasetDef(
    const GetDatasetDefRequest* request, GetDatasetDefResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  mutex_lock l(mu_);
  std::shared_ptr<const Dataset> dataset;
  TF_RETURN_IF_ERROR(state_.DatasetFromId(request->dataset_id(), dataset));
  std::shared_ptr<const DatasetDef> dataset_def;
  TF_RETURN_IF_ERROR(GetDatasetDef(*dataset, dataset_def));
  *response->mutable_dataset_def() = *dataset_def;
  return Status::OK();
}

Status DataServiceDispatcherImpl::GetSplit(const GetSplitRequest* request,
                                           GetSplitResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  mutex_lock l(mu_);
  int64_t job_id = request->job_id();
  int64_t repetition = request->repetition();
  int64_t provider_index = request->split_provider_index();
  VLOG(3) << "Received GetSplit request for job " << job_id << ", repetition "
          << repetition << ", split provider index " << provider_index;
  std::shared_ptr<const Job> job;
  TF_RETURN_IF_ERROR(state_.JobFromId(job_id, job));
  if (!job->distributed_epoch_state.has_value()) {
    return errors::FailedPrecondition(
        "Cannot get split for job ", job_id,
        ", since it is not a distributed_epoch job.");
  }
  int64_t current_repetition =
      job->distributed_epoch_state.value().repetitions[provider_index];
  if (repetition < current_repetition) {
    response->set_end_of_splits(true);
    VLOG(3) << "Returning end_of_splits since current repetition "
            << current_repetition
            << " is greater than the requested repetition " << repetition;
    return Status::OK();
  }
  SplitProvider* split_provider =
      split_providers_[job_id][provider_index].get();
  DCHECK(split_provider != nullptr);
  Tensor split;
  bool end_of_splits = false;
  TF_RETURN_IF_ERROR(split_provider->GetNext(&split, &end_of_splits));
  TF_RETURN_IF_ERROR(RecordSplitProduced(
      job_id, repetition, request->split_provider_index(), end_of_splits));
  response->set_end_of_splits(end_of_splits);
  if (end_of_splits) {
    // Reset the split provider to prepare for the next repetition.
    TF_RETURN_IF_ERROR(split_providers_[job_id][provider_index]->Reset());
  } else {
    split.AsProtoTensorContent(response->mutable_split());
  }
  VLOG(3) << "Returning from GetSplit, end_of_splits=" << end_of_splits;
  return Status::OK();
}

Status DataServiceDispatcherImpl::MakeSplitProviders(
    int64_t dataset_id,
    std::vector<std::unique_ptr<SplitProvider>>& split_providers)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  std::shared_ptr<const Dataset> dataset;
  TF_RETURN_IF_ERROR(state_.DatasetFromId(dataset_id, dataset));
  std::shared_ptr<const DatasetDef> dataset_def;
  TF_RETURN_IF_ERROR(GetDatasetDef(*dataset, dataset_def));
  standalone::Dataset::Params params;
  std::unique_ptr<standalone::Dataset> standalone_dataset;
  TF_RETURN_IF_ERROR(standalone::Dataset::FromGraph(
      params, dataset_def->graph(), &standalone_dataset));
  TF_RETURN_IF_ERROR(standalone_dataset->MakeSplitProviders(&split_providers));
  return Status::OK();
}

Status DataServiceDispatcherImpl::GetVersion(const GetVersionRequest* request,
                                             GetVersionResponse* response) {
  response->set_version(kDataServiceVersion);
  return Status::OK();
}

Status DataServiceDispatcherImpl::GetOrRegisterDataset(
    const GetOrRegisterDatasetRequest* request,
    GetOrRegisterDatasetResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  uint64 fingerprint;
  DatasetDef dataset_def = request->dataset();
  GraphDef* graph = dataset_def.mutable_graph();
  PrepareGraph(graph);
  TF_RETURN_IF_ERROR(HashGraph(*graph, &fingerprint));

  mutex_lock l(mu_);
#if defined(PLATFORM_GOOGLE)
  VLOG_LINES(4,
             absl::StrCat("Registering dataset graph: ", graph->DebugString()));
#else
  VLOG(4) << "Registering dataset graph: " << graph->DebugString();
#endif
  std::shared_ptr<const Dataset> dataset;
  Status s = state_.DatasetFromFingerprint(fingerprint, dataset);
  if (s.ok()) {
    int64_t id = dataset->dataset_id;
    VLOG(3) << "Received duplicate RegisterDataset request with fingerprint "
            << fingerprint << ". Returning id " << id;
    response->set_dataset_id(id);
    return Status::OK();
  } else if (!errors::IsNotFound(s)) {
    return s;
  }

  int64_t id;
  TF_RETURN_IF_ERROR(RegisterDataset(fingerprint, dataset_def, id));
  if (!request->element_spec().empty()) {
    TF_RETURN_IF_ERROR(SetElementSpec(id, request->element_spec()));
  }

  response->set_dataset_id(id);
  VLOG(3) << "Registered new dataset with id " << id;
  return Status::OK();
}

Status DataServiceDispatcherImpl::RegisterDataset(uint64 fingerprint,
                                                  const DatasetDef& dataset,
                                                  int64& dataset_id)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  dataset_id = state_.NextAvailableDatasetId();
  Update update;
  RegisterDatasetUpdate* register_dataset = update.mutable_register_dataset();
  register_dataset->set_dataset_id(dataset_id);
  register_dataset->set_fingerprint(fingerprint);
  TF_RETURN_IF_ERROR(
      dataset_store_->Put(DatasetKey(dataset_id, fingerprint), dataset));
  return Apply(update);
}

Status DataServiceDispatcherImpl::SetElementSpec(
    int64_t dataset_id, const std::string& element_spec)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  Update update;
  SetElementSpecUpdate* set_element_spec = update.mutable_set_element_spec();
  set_element_spec->set_dataset_id(dataset_id);
  set_element_spec->set_element_spec(element_spec);
  TF_RETURN_IF_ERROR(Apply(update));
  return Status::OK();
}

Status DataServiceDispatcherImpl::GetElementSpec(
    const GetElementSpecRequest* request, GetElementSpecResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  mutex_lock l(mu_);
  VLOG(4) << "Reading the element spec.";
  int64_t dataset_id = request->dataset_id();

  std::string element_spec;
  TF_RETURN_IF_ERROR(state_.GetElementSpec(dataset_id, element_spec));
  VLOG(3) << "Got the `element_spec` for registered dataset with dataset id: "
          << dataset_id << ".";
  *response->mutable_element_spec() = element_spec;
  return Status::OK();
}

Status DataServiceDispatcherImpl::GetOrCreateJob(
    const GetOrCreateJobRequest* request, GetOrCreateJobResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  VLOG(3) << "GetOrCreateJob(" << request->DebugString() << ")";
  absl::optional<NamedJobKey> key;
  if (request->has_job_key()) {
    key.emplace(request->job_key().job_name(),
                request->job_key().job_name_index());
  }
  std::shared_ptr<const Job> job;
  std::vector<std::shared_ptr<const Task>> tasks;
  {
    mutex_lock l(mu_);
    if (key.has_value()) {
      Status s = state_.NamedJobByKey(key.value(), job);
      if (s.ok()) {
        TF_RETURN_IF_ERROR(ValidateMatchingJob(job, *request));
        // If the matching job was already garbage-collected, we fall through to
        // re-create the job.
        if (!job->garbage_collected) {
          int64_t job_client_id;
          TF_RETURN_IF_ERROR(AcquireJobClientId(job, job_client_id));
          response->set_job_client_id(job_client_id);
          VLOG(3) << "Found existing job for name=" << key.value().name
                  << ", index=" << key.value().index
                  << ". job_id: " << job->job_id;
          return Status::OK();
        }
      } else if (!errors::IsNotFound(s)) {
        return s;
      }
    }
    TF_RETURN_IF_ERROR(CreateJob(*request, job));
    int64_t job_client_id;
    TF_RETURN_IF_ERROR(AcquireJobClientId(job, job_client_id));
    response->set_job_client_id(job_client_id);
    TF_RETURN_IF_ERROR(CreateTasksForJob(job, tasks));
  }
  TF_RETURN_IF_ERROR(AssignTasks(tasks));
  VLOG(3) << "Created job " << job->job_id << " for CreateJob("
          << request->DebugString() << ")";
  return Status::OK();
}

Status DataServiceDispatcherImpl::MaybeRemoveTask(
    const MaybeRemoveTaskRequest* request, MaybeRemoveTaskResponse* response) {
  VLOG(1) << "Attempting to remove task. Request: " << request->DebugString();
  std::shared_ptr<TaskRemover> remover;
  std::shared_ptr<const Task> task;
  {
    mutex_lock l(mu_);
    Status s = state_.TaskFromId(request->task_id(), task);
    if (errors::IsNotFound(s)) {
      // Task is already removed.
      response->set_removed(true);
      return Status::OK();
    }
    TF_RETURN_IF_ERROR(s);
    auto& remover_ref = remove_task_requests_[task->task_id];
    if (remover_ref == nullptr) {
      if (!task->job->IsRoundRobin()) {
        return errors::FailedPrecondition(
            "MaybeRemoveTask called on a non-round-robin task.");
      }
      remover_ref =
          std::make_shared<TaskRemover>(task->job->num_consumers.value());
    }
    remover = remover_ref;
  }
  bool removed =
      remover->RequestRemoval(request->consumer_index(), request->round());
  response->set_removed(removed);
  if (!removed) {
    VLOG(1) << "Failed to remove task " << task->task_id;
    return Status::OK();
  }
  mutex_lock l(mu_);
  if (!task->removed) {
    Update update;
    RemoveTaskUpdate* remove_task = update.mutable_remove_task();
    remove_task->set_task_id(request->task_id());
    TF_RETURN_IF_ERROR(Apply(update));
  }
  VLOG(1) << "Task " << task->task_id << " successfully removed";
  return Status::OK();
}

Status DataServiceDispatcherImpl::ReleaseJobClient(
    const ReleaseJobClientRequest* request,
    ReleaseJobClientResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  mutex_lock l(mu_);
  int64_t job_client_id = request->job_client_id();
  std::shared_ptr<const Job> job;
  TF_RETURN_IF_ERROR(state_.JobForJobClientId(job_client_id, job));
  Update update;
  ReleaseJobClientUpdate* release_job_client =
      update.mutable_release_job_client();
  release_job_client->set_job_client_id(job_client_id);
  release_job_client->set_time_micros(env_->NowMicros());
  TF_RETURN_IF_ERROR(Apply(update));
  return Status::OK();
}

// Validates that the job matches the requested processing mode and target
// workers.
Status DataServiceDispatcherImpl::ValidateMatchingJob(
    std::shared_ptr<const Job> job, const GetOrCreateJobRequest& request)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  DCHECK(job->named_job_key.has_value());
  std::string job_name = job->named_job_key->name;

  if (!MessageDifferencer::Equals(job->processing_mode,
                                  request.processing_mode_def())) {
    return errors::FailedPrecondition(
        "Tried to create a job with name ", job_name, " and processing_mode <",
        request.processing_mode_def().ShortDebugString(),
        "> but there is already an existing job with that name using "
        "processing mode <",
        job->processing_mode.ShortDebugString(), ">");
  }

  if (job->target_workers != request.target_workers()) {
    return errors::InvalidArgument(
        "Tried to create job with name ", job_name, " and target_workers <",
        TargetWorkersToString(request.target_workers()),
        ">, but there is already an existing job "
        "with that name using target_workers <",
        TargetWorkersToString(job->target_workers), ">.");
  }
  return Status::OK();
}

Status DataServiceDispatcherImpl::CreateJob(
    const GetOrCreateJobRequest& request, std::shared_ptr<const Job>& job)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  TF_RETURN_IF_ERROR(ValidateProcessingMode(request.processing_mode_def()));
  int64_t job_id = state_.NextAvailableJobId();
  int64_t num_split_providers = 0;
  if (IsDynamicShard(request.processing_mode_def())) {
    TF_RETURN_IF_ERROR(
        MakeSplitProviders(request.dataset_id(), split_providers_[job_id]));
    num_split_providers = split_providers_[job_id].size();
  }
  Update update;
  CreateJobUpdate* create_job = update.mutable_create_job();
  create_job->set_job_id(job_id);
  create_job->set_dataset_id(request.dataset_id());
  *create_job->mutable_processing_mode_def() = request.processing_mode_def();
  create_job->set_num_split_providers(num_split_providers);
  if (request.has_job_key()) {
    NamedJobKeyDef* key = create_job->mutable_named_job_key();
    key->set_name(request.job_key().job_name());
    key->set_index(request.job_key().job_name_index());
  }
  if (request.optional_num_consumers_case() ==
      GetOrCreateJobRequest::kNumConsumers) {
    create_job->set_num_consumers(request.num_consumers());
  }
  create_job->set_target_workers(request.target_workers());
  TF_RETURN_IF_ERROR(Apply(update));
  TF_RETURN_IF_ERROR(state_.JobFromId(job_id, job));
  return Status::OK();
}

Status DataServiceDispatcherImpl::CreateTasksForWorker(
    const std::string& worker_address) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  std::vector<std::shared_ptr<const Job>> jobs = state_.ListJobs();
  for (const auto& job : jobs) {
    if (job->finished) {
      continue;
    }
    if (job->num_consumers.has_value()) {
      TF_RETURN_IF_ERROR(CreatePendingTask(job, worker_address));
      continue;
    }
    std::shared_ptr<const Task> task;
    TF_RETURN_IF_ERROR(CreateTask(job, worker_address, task));
  }
  return Status::OK();
}

Status DataServiceDispatcherImpl::AcquireJobClientId(
    const std::shared_ptr<const Job>& job, int64& job_client_id)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  job_client_id = state_.NextAvailableJobClientId();
  Update update;
  AcquireJobClientUpdate* acquire_job_client =
      update.mutable_acquire_job_client();
  acquire_job_client->set_job_client_id(job_client_id);
  acquire_job_client->set_job_id(job->job_id);
  TF_RETURN_IF_ERROR(Apply(update));
  return Status::OK();
}

Status DataServiceDispatcherImpl::CreateTasksForJob(
    std::shared_ptr<const Job> job,
    std::vector<std::shared_ptr<const Task>>& tasks)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  std::vector<std::shared_ptr<const Worker>> workers = state_.ListWorkers();
  tasks.clear();
  tasks.reserve(workers.size());
  for (const auto& worker : workers) {
    std::shared_ptr<const Task> task;
    TF_RETURN_IF_ERROR(CreateTask(job, worker->address, task));
    tasks.push_back(task);
  }
  return Status::OK();
}

Status DataServiceDispatcherImpl::CreatePendingTask(
    std::shared_ptr<const Job> job, const std::string& worker_address)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  int64_t task_id = state_.NextAvailableTaskId();
  Update update;
  CreatePendingTaskUpdate* create_task = update.mutable_create_pending_task();
  create_task->set_task_id(task_id);
  create_task->set_job_id(job->job_id);
  create_task->set_worker_address(worker_address);
  create_task->set_starting_round(round_robin_rounds_[job->job_id] + 1);
  std::shared_ptr<const Worker> worker;
  TF_RETURN_IF_ERROR(state_.WorkerFromAddress(worker_address, worker));
  create_task->set_transfer_address(worker->transfer_address);
  TF_RETURN_IF_ERROR(Apply(update));
  return Status::OK();
}

Status DataServiceDispatcherImpl::CreateTask(std::shared_ptr<const Job> job,
                                             const std::string& worker_address,
                                             std::shared_ptr<const Task>& task)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  int64_t task_id = state_.NextAvailableTaskId();
  Update update;
  CreateTaskUpdate* create_task = update.mutable_create_task();
  create_task->set_task_id(task_id);
  create_task->set_job_id(job->job_id);
  create_task->set_worker_address(worker_address);
  std::shared_ptr<const Worker> worker;
  TF_RETURN_IF_ERROR(state_.WorkerFromAddress(worker_address, worker));
  create_task->set_transfer_address(worker->transfer_address);
  TF_RETURN_IF_ERROR(Apply(update));
  TF_RETURN_IF_ERROR(state_.TaskFromId(task_id, task));
  return Status::OK();
}

Status DataServiceDispatcherImpl::AssignTasks(
    std::vector<std::shared_ptr<const Task>> tasks) TF_LOCKS_EXCLUDED(mu_) {
  for (const auto& task : tasks) {
    TF_RETURN_IF_ERROR(AssignTask(task));
  }
  return Status::OK();
}

Status DataServiceDispatcherImpl::GetOrCreateWorkerStub(
    const std::string& worker_address, WorkerService::Stub*& out_stub)
    TF_LOCKS_EXCLUDED(mu_) {
  {
    mutex_lock l(mu_);
    auto it = worker_stubs_.find(worker_address);
    if (it != worker_stubs_.end()) {
      out_stub = it->second.get();
      return Status::OK();
    }
  }
  std::unique_ptr<WorkerService::Stub> stub;
  TF_RETURN_IF_ERROR(
      CreateWorkerStub(worker_address, config_.protocol(), stub));
  {
    mutex_lock l(mu_);
    // A concurrent call could have already created the stub.
    auto& worker = worker_stubs_[worker_address];
    if (worker == nullptr) {
      worker = std::move(stub);
    }
    out_stub = worker.get();
  }
  return Status::OK();
}

Status DataServiceDispatcherImpl::AssignTask(std::shared_ptr<const Task> task)
    TF_LOCKS_EXCLUDED(mu_) {
  VLOG(2) << "Started assigning task " << task->task_id << " to worker "
          << task->worker_address;
  grpc::ClientContext client_ctx;
  ProcessTaskRequest req;
  TaskDef* task_def = req.mutable_task();
  {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(PopulateTaskDef(task, task_def));
  }
  ProcessTaskResponse resp;
  WorkerService::Stub* stub;
  TF_RETURN_IF_ERROR(GetOrCreateWorkerStub(task->worker_address, stub));
  grpc::Status s = stub->ProcessTask(&client_ctx, req, &resp);
  if (!s.ok()) {
    if (s.error_code() == grpc::StatusCode::UNAVAILABLE ||
        s.error_code() == grpc::StatusCode::ABORTED ||
        s.error_code() == grpc::StatusCode::CANCELLED) {
      // Worker is presumably preempted. We will assign the task to the worker
      // when it reconnects.
      return Status::OK();
    }
    return grpc_util::WrapError(
        absl::StrCat("Failed to submit task to worker ", task->worker_address),
        s);
  }
  VLOG(2) << "Finished assigning task " << task->task_id << " to worker "
          << task->worker_address;
  return Status::OK();
}

Status DataServiceDispatcherImpl::ClientHeartbeat(
    const ClientHeartbeatRequest* request, ClientHeartbeatResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  mutex_lock l(mu_);
  VLOG(4) << "Received heartbeat from client id " << request->job_client_id();
  std::shared_ptr<const Job> job;
  Status s = state_.JobForJobClientId(request->job_client_id(), job);
  if (errors::IsNotFound(s) && !config_.fault_tolerant_mode()) {
    return errors::NotFound(
        "Unknown job client id ", request->job_client_id(),
        ". The dispatcher is not configured to be fault tolerant, so this "
        "could be caused by a dispatcher restart.");
  }
  TF_RETURN_IF_ERROR(s);
  if (job->garbage_collected) {
    return errors::FailedPrecondition(
        "The requested job has been garbage collected due to inactivity. "
        "Consider configuring the dispatcher with a higher "
        "`job_gc_timeout_ms`.");
  }
  if (request->optional_current_round_case() ==
      ClientHeartbeatRequest::kCurrentRound) {
    round_robin_rounds_[request->job_client_id()] =
        std::max(round_robin_rounds_[request->job_client_id()],
                 request->current_round());
  }
  if (!job->pending_tasks.empty()) {
    const auto& task = job->pending_tasks.front();
    Update update;
    ClientHeartbeatUpdate* client_heartbeat = update.mutable_client_heartbeat();
    bool apply_update = false;
    client_heartbeat->set_job_client_id(request->job_client_id());
    absl::optional<int64> blocked_round;
    if (request->optional_blocked_round_case() ==
        ClientHeartbeatRequest::kBlockedRound) {
      blocked_round = request->blocked_round();
    }
    VLOG(1) << "Handling pending task in job client heartbeat. job_client_id: "
            << request->job_client_id()
            << ". current_round: " << request->current_round()
            << ". blocked_round: " << blocked_round.value_or(-1)
            << ". target_round: " << task.target_round;
    if (request->current_round() >= task.target_round) {
      TaskRejected* rejected = client_heartbeat->mutable_task_rejected();
      // Exponentially try later and later rounds until consumers all agree.
      int64_t round_offset = 2;
      for (int i = 0; i < task.failures; ++i) {
        round_offset *= 2;
      }
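      // For example: failures = 0 gives round_offset 2, failures = 1 gives 4,
      // failures = 2 gives 8; i.e. round_offset = 2^(failures + 1).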
      rejected->set_new_target_round(
          round_robin_rounds_[request->job_client_id()] + round_offset);
      apply_update = true;
    }
    if (blocked_round.has_value() &&
        blocked_round.value() <= task.target_round &&
        !task.ready_consumers.contains(request->job_client_id())) {
      client_heartbeat->set_task_accepted(true);
      apply_update = true;
    }
    if (apply_update) {
      TF_RETURN_IF_ERROR(Apply(update));
    }
  }
  if (!job->pending_tasks.empty()) {
    response->set_block_round(job->pending_tasks.front().target_round);
  }

  std::vector<std::shared_ptr<const Task>> tasks;
  TF_RETURN_IF_ERROR(state_.TasksForJob(job->job_id, tasks));
  for (const auto& task : tasks) {
    TaskInfo* task_info = response->mutable_task_info()->Add();
    task_info->set_worker_address(task->worker_address);
    task_info->set_transfer_address(task->transfer_address);
    task_info->set_task_id(task->task_id);
    task_info->set_job_id(job->job_id);
    task_info->set_starting_round(task->starting_round);
  }
  response->set_job_finished(job->finished);
  VLOG(4) << "Found " << response->task_info_size()
          << " tasks for job client id " << request->job_client_id();
  return Status::OK();
}

Status DataServiceDispatcherImpl::GetWorkers(const GetWorkersRequest* request,
                                             GetWorkersResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  mutex_lock l(mu_);
  VLOG(3) << "Enter GetWorkers";
  std::vector<std::shared_ptr<const Worker>> workers = state_.ListWorkers();
  for (const auto& worker : workers) {
    WorkerInfo* info = response->add_workers();
    info->set_address(worker->address);
  }
  VLOG(3) << "Returning list of " << response->workers_size()
          << " workers from GetWorkers";
  return Status::OK();
}

Status DataServiceDispatcherImpl::PopulateTaskDef(
    std::shared_ptr<const Task> task, TaskDef* task_def) const
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  task_def->set_dataset_id(task->job->dataset_id);
  task_def->set_job_id(task->job->job_id);
  task_def->set_worker_address(task->worker_address);
  task_def->set_task_id(task->task_id);
  *task_def->mutable_processing_mode_def() = task->job->processing_mode;
  if (IsStaticShard(task->job->processing_mode)) {
    task_def->set_num_workers(config_.worker_addresses_size());
    TF_ASSIGN_OR_RETURN(int64_t worker_index,
                        state_.GetWorkerIndex(task->worker_address));
    task_def->set_worker_index(worker_index);
  }
  if (task->job->distributed_epoch_state.has_value()) {
    task_def->set_num_split_providers(
        task->job->distributed_epoch_state.value().indices.size());
  }
  if (task->job->num_consumers.has_value()) {
    task_def->set_num_consumers(task->job->num_consumers.value());
  }
  std::shared_ptr<const Dataset> dataset;
  TF_RETURN_IF_ERROR(state_.DatasetFromId(task->job->dataset_id, dataset));
  std::string dataset_key =
      DatasetKey(dataset->dataset_id, dataset->fingerprint);
  if (config_.work_dir().empty()) {
    std::shared_ptr<const DatasetDef> dataset_def;
    TF_RETURN_IF_ERROR(dataset_store_->Get(dataset_key, dataset_def));
    *task_def->mutable_dataset_def() = *dataset_def;
  } else {
    std::string path =
        io::JoinPath(DatasetsDir(config_.work_dir()), dataset_key);
    task_def->set_path(path);
  }
  return Status::OK();
}

Status DataServiceDispatcherImpl::CheckStarted() TF_LOCKS_EXCLUDED(mu_) {
  mutex_lock l(mu_);
  if (!started_) {
    return errors::Unavailable("Dispatcher has not started yet.");
  }
  return Status::OK();
}

Status DataServiceDispatcherImpl::RecordSplitProduced(
    int64_t job_id, int64_t repetition, int64_t split_provider_index,
    bool finished) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  Update update;
  ProduceSplitUpdate* produce_split = update.mutable_produce_split();
  produce_split->set_job_id(job_id);
  produce_split->set_repetition(repetition);
  produce_split->set_split_provider_index(split_provider_index);
  produce_split->set_finished(finished);
  return Apply(update);
}

Status DataServiceDispatcherImpl::ApplyWithoutJournaling(const Update& update)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  return state_.Apply(update);
}

Status DataServiceDispatcherImpl::Apply(const Update& update)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  if (journal_writer_.has_value()) {
    TF_RETURN_IF_ERROR(journal_writer_.value()->Write(update));
  }
  return state_.Apply(update);
}
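
// Note that Apply journals the update before applying it (write-ahead): if
// the dispatcher crashes between the two steps, restart replays the journaled
// update through ApplyWithoutJournaling and converges on the same state.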

void DataServiceDispatcherImpl::JobGcThread() {
  int64_t next_check_micros = 0;
  while (true) {
    mutex_lock l(mu_);
    while (!cancelled_ && env_->NowMicros() < next_check_micros) {
      int64_t remaining_micros = next_check_micros - env_->NowMicros();
      job_gc_thread_cv_.wait_for(l,
                                 std::chrono::microseconds(remaining_micros));
    }
    if (cancelled_) {
      return;
    }
    Status s = GcOldJobs();
    if (!s.ok()) {
      LOG(WARNING) << "Error garbage collecting old jobs: " << s;
    }
    next_check_micros =
        env_->NowMicros() + (config_.job_gc_check_interval_ms() * 1000);
  }
}
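
// For example (assumed configuration), job_gc_check_interval_ms = 600000
// wakes the GC loop roughly every ten minutes; the condition variable lets
// ~DataServiceDispatcherImpl cancel the wait early on shutdown.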

Status DataServiceDispatcherImpl::GcOldJobs() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  std::vector<std::shared_ptr<const Job>> jobs = state_.ListJobs();
  int64_t now = env_->NowMicros();
  for (const auto& job : jobs) {
    if (job->finished || job->num_clients > 0 ||
        job->last_client_released_micros < 0 ||
        now < job->last_client_released_micros +
                  (config_.job_gc_timeout_ms() * 1000)) {
      continue;
    }
    Update update;
    update.mutable_garbage_collect_job()->set_job_id(job->job_id);
    TF_RETURN_IF_ERROR(state_.Apply(update));
    LOG(INFO) << "Garbage collected job " << job->DebugString();
  }
  return Status::OK();
}

Status DataServiceDispatcherImpl::GetDatasetDef(
    int64_t dataset_id, std::shared_ptr<const DatasetDef>& dataset_def)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  std::shared_ptr<const Dataset> dataset;
  TF_RETURN_IF_ERROR(state_.DatasetFromId(dataset_id, dataset));
  return GetDatasetDef(*dataset, dataset_def);
}

Status DataServiceDispatcherImpl::GetDatasetDef(
    const Dataset& dataset, std::shared_ptr<const DatasetDef>& dataset_def)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  std::string key = DatasetKey(dataset.dataset_id, dataset.fingerprint);
  return dataset_store_->Get(key, dataset_def);
}

}  // namespace data
}  // namespace tensorflow