• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_CLIENT_H_
16 #define TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_CLIENT_H_
17 
18 #include <memory>
19 #include <string>
20 #include <vector>
21 
22 #include "absl/types/optional.h"
23 #include "tensorflow/core/data/service/common.h"
24 #include "tensorflow/core/data/service/common.pb.h"
25 #include "tensorflow/core/data/service/data_transfer.h"
26 #include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
27 #include "tensorflow/core/data/service/dispatcher.pb.h"
28 #include "tensorflow/core/framework/graph.pb.h"
29 #include "tensorflow/core/framework/tensor.h"
30 #include "tensorflow/core/platform/mutex.h"
31 #include "tensorflow/core/platform/status.h"
32 #include "tensorflow/core/platform/types.h"
33 #include "tensorflow/core/protobuf/data_service.pb.h"
34 
35 namespace tensorflow {
36 namespace data {
37 
38 // Client for communicating with the tf.data service dispatcher.
39 class DataServiceDispatcherClient : public DataServiceClientBase {
40  public:
DataServiceDispatcherClient(const std::string & address,const std::string & protocol)41   DataServiceDispatcherClient(const std::string& address,
42                               const std::string& protocol)
43       : DataServiceClientBase(address, protocol) {}
44 
45   // Sends a heartbeat to the dispatcher. If the worker wasn't already
46   // registered with the dispatcher, this will register the worker. The
47   // dispatcher will report which new tasks the worker should run, and which
48   // tasks it should delete. This is stored into `new_tasks` and
49   // `tasks_to_delete`.
50   Status WorkerHeartbeat(const std::string& worker_address,
51                          const std::string& transfer_address,
52                          const std::vector<int64>& current_tasks,
53                          std::vector<TaskDef>& new_tasks,
54                          std::vector<int64>& tasks_to_delete);
55 
56   // Updates the dispatcher with information about the worker's state.
57   Status WorkerUpdate(const std::string& worker_address,
58                       std::vector<TaskProgress>& task_progress);
59 
60   // Gets a dataset definition for the given dataset id, and stores the
61   // definition in `dataset_def`.
62   Status GetDatasetDef(int64_t dataset_id, DatasetDef& dataset_def);
63 
64   // Gets the next split for the specified job id, repetition, and split
65   // provider index.
66   Status GetSplit(int64_t job_id, int64_t repetition,
67                   int64_t split_provider_index, Tensor& split,
68                   bool& end_of_splits);
69 
70   // Registers a dataset with the tf.data service, and stores the generated
71   // dataset id in `dataset_id`.
72   Status RegisterDataset(const DatasetDef& dataset,
73                          const absl::optional<std::string>& element_spec,
74                          int64& dataset_id);
75 
76   // If `job_key` is set, looks up a job matching `job_key`. If `job_key` is
77   // absent or no matching job is found, creates a new job. The resulting job
78   // id is stored in `job_client_id`.
79   Status GetOrCreateJob(int64_t dataset_id,
80                         const ProcessingModeDef& processing_mode,
81                         const absl::optional<JobKey>& job_key,
82                         absl::optional<int64> num_consumers,
83                         int64& job_client_id, TargetWorkers target_workers);
84 
85   // Releases a job client id, indicating that the id will no longer be used to
86   // read from the job.
87   Status ReleaseJobClient(int64_t job_client_id);
88 
89   // Attempts to remove a task. The task is removed if all consumers try to
90   // remove the task in the same round.
91   Status MaybeRemoveTask(int64_t task_id, int64_t consumer_index, int64_t round,
92                          bool& removed);
93 
94   // Heartbeats to the dispatcher, getting back the tasks that should be
95   // running, and whether the job is finished.
96   Status ClientHeartbeat(ClientHeartbeatRequest& req,
97                          ClientHeartbeatResponse& resp);
98 
99   // Queries the dispatcher for its registered workers. The worker info will be
100   // stored in `workers`.
101   Status GetWorkers(std::vector<WorkerInfo>& workers);
102 
103   // Returns element spec for the registered dataset.
104   Status GetElementSpec(int64_t dataset_id, std::string& element_spec);
105 
106  protected:
107   Status EnsureInitialized() override;
108 
109  private:
110   mutex mu_;
111   // Initialization is guarded by `mu_`, but using the stub does not require
112   // holding `mu_`
113   std::unique_ptr<DispatcherService::Stub> stub_;
114 };
115 
116 }  // namespace data
117 }  // namespace tensorflow
118 
119 #endif  // TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_CLIENT_H_
120