1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_CLIENT_H_ 16 #define TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_CLIENT_H_ 17 18 #include <memory> 19 #include <string> 20 #include <vector> 21 22 #include "absl/types/optional.h" 23 #include "tensorflow/core/data/service/common.h" 24 #include "tensorflow/core/data/service/common.pb.h" 25 #include "tensorflow/core/data/service/data_transfer.h" 26 #include "tensorflow/core/data/service/dispatcher.grpc.pb.h" 27 #include "tensorflow/core/data/service/dispatcher.pb.h" 28 #include "tensorflow/core/framework/graph.pb.h" 29 #include "tensorflow/core/framework/tensor.h" 30 #include "tensorflow/core/platform/mutex.h" 31 #include "tensorflow/core/platform/status.h" 32 #include "tensorflow/core/platform/types.h" 33 #include "tensorflow/core/protobuf/data_service.pb.h" 34 35 namespace tensorflow { 36 namespace data { 37 38 // Client for communicating with the tf.data service dispatcher. 39 class DataServiceDispatcherClient : public DataServiceClientBase { 40 public: DataServiceDispatcherClient(const std::string & address,const std::string & protocol)41 DataServiceDispatcherClient(const std::string& address, 42 const std::string& protocol) 43 : DataServiceClientBase(address, protocol) {} 44 45 // Sends a heartbeat to the dispatcher. If the worker wasn't already 46 // registered with the dispatcher, this will register the worker. The 47 // dispatcher will report which new tasks the worker should run, and which 48 // tasks it should delete. This is stored into `new_tasks` and 49 // `tasks_to_delete`. 50 Status WorkerHeartbeat(const std::string& worker_address, 51 const std::string& transfer_address, 52 const std::vector<int64>& current_tasks, 53 std::vector<TaskDef>& new_tasks, 54 std::vector<int64>& tasks_to_delete); 55 56 // Updates the dispatcher with information about the worker's state. 57 Status WorkerUpdate(const std::string& worker_address, 58 std::vector<TaskProgress>& task_progress); 59 60 // Gets a dataset definition for the given dataset id, and stores the 61 // definition in `dataset_def`. 62 Status GetDatasetDef(int64_t dataset_id, DatasetDef& dataset_def); 63 64 // Gets the next split for the specified job id, repetition, and split 65 // provider index. 66 Status GetSplit(int64_t job_id, int64_t repetition, 67 int64_t split_provider_index, Tensor& split, 68 bool& end_of_splits); 69 70 // Registers a dataset with the tf.data service, and stores the generated 71 // dataset id in `dataset_id`. 72 Status RegisterDataset(const DatasetDef& dataset, 73 const absl::optional<std::string>& element_spec, 74 int64& dataset_id); 75 76 // If `job_key` is set, looks up a job matching `job_key`. If `job_key` is 77 // absent or no matching job is found, creates a new job. The resulting job 78 // id is stored in `job_client_id`. 79 Status GetOrCreateJob(int64_t dataset_id, 80 const ProcessingModeDef& processing_mode, 81 const absl::optional<JobKey>& job_key, 82 absl::optional<int64> num_consumers, 83 int64& job_client_id, TargetWorkers target_workers); 84 85 // Releases a job client id, indicating that the id will no longer be used to 86 // read from the job. 87 Status ReleaseJobClient(int64_t job_client_id); 88 89 // Attempts to remove a task. The task is removed if all consumers try to 90 // remove the task in the same round. 91 Status MaybeRemoveTask(int64_t task_id, int64_t consumer_index, int64_t round, 92 bool& removed); 93 94 // Heartbeats to the dispatcher, getting back the tasks that should be 95 // running, and whether the job is finished. 96 Status ClientHeartbeat(ClientHeartbeatRequest& req, 97 ClientHeartbeatResponse& resp); 98 99 // Queries the dispatcher for its registered workers. The worker info will be 100 // stored in `workers`. 101 Status GetWorkers(std::vector<WorkerInfo>& workers); 102 103 // Returns element spec for the registered dataset. 104 Status GetElementSpec(int64_t dataset_id, std::string& element_spec); 105 106 protected: 107 Status EnsureInitialized() override; 108 109 private: 110 mutex mu_; 111 // Initialization is guarded by `mu_`, but using the stub does not require 112 // holding `mu_` 113 std::unique_ptr<DispatcherService::Stub> stub_; 114 }; 115 116 } // namespace data 117 } // namespace tensorflow 118 119 #endif // TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_CLIENT_H_ 120