1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_DATA_SERVICE_WORKER_IMPL_H_ 16 #define TENSORFLOW_CORE_DATA_SERVICE_WORKER_IMPL_H_ 17 18 #include <memory> 19 #include <string> 20 #include <utility> 21 22 #include "absl/container/flat_hash_map.h" 23 #include "absl/container/flat_hash_set.h" 24 #include "absl/strings/string_view.h" 25 #include "tensorflow/core/data/service/common.pb.h" 26 #include "tensorflow/core/data/service/data_transfer.h" 27 #include "tensorflow/core/data/service/dispatcher.grpc.pb.h" 28 #include "tensorflow/core/data/service/dispatcher_client.h" 29 #include "tensorflow/core/data/service/task_runner.h" 30 #include "tensorflow/core/data/service/worker.pb.h" 31 #include "tensorflow/core/data/standalone.h" 32 #include "tensorflow/core/framework/cancellation.h" 33 #include "tensorflow/core/lib/core/errors.h" 34 #include "tensorflow/core/platform/env.h" 35 #include "tensorflow/core/platform/mutex.h" 36 #include "tensorflow/core/platform/status.h" 37 #include "tensorflow/core/platform/statusor.h" 38 #include "tensorflow/core/platform/thread_annotations.h" 39 #include "tensorflow/core/protobuf/service_config.pb.h" 40 #include "tensorflow/core/public/session.h" 41 42 namespace tensorflow { 43 namespace data { 44 45 // A TensorFlow DataService serves dataset elements over RPC. 46 class DataServiceWorkerImpl { 47 public: 48 explicit DataServiceWorkerImpl(const experimental::WorkerConfig& config); 49 ~DataServiceWorkerImpl(); 50 51 // Starts the worker. The worker needs to know its own address so that it can 52 // register with the dispatcher. This is set in `Start` instead of in the 53 // constructor because the worker may be binding to port `0`, in which case 54 // the address isn't known until the worker has started and decided which port 55 // to bind to. 56 Status Start(const std::string& worker_address, 57 const std::string& transfer_address); 58 // Stops the worker, attempting a clean shutdown by rejecting new requests 59 // and waiting for outstanding requests to complete. 60 void Stop(); 61 62 // Serves a GetElement request, storing the result in `*result`. See 63 // worker.proto for GetElement API documentation. 64 Status GetElementResult(const GetElementRequest* request, 65 GetElementResult* result); 66 67 // See worker.proto for API documentation. 68 69 /// Dispatcher-facing API. 70 Status ProcessTask(const ProcessTaskRequest* request, 71 ProcessTaskResponse* response); 72 73 /// Client-facing API. 74 Status GetElement(const GetElementRequest* request, 75 GetElementResponse* response); 76 Status GetWorkerTasks(const GetWorkerTasksRequest* request, 77 GetWorkerTasksResponse* response); 78 79 private: 80 struct Task { TaskTask81 explicit Task(TaskDef task_def) : task_def(std::move(task_def)) {} 82 83 TaskDef task_def; 84 mutex mu; 85 bool initialized TF_GUARDED_BY(mu) = false; 86 int64 outstanding_requests TF_GUARDED_BY(&DataServiceWorkerImpl::mu_) = 0; 87 std::unique_ptr<TaskRunner> task_runner; 88 }; 89 90 // Sends task status to the dispatcher and checks for dispatcher commands. 91 Status SendTaskUpdates() TF_LOCKS_EXCLUDED(mu_); 92 // Creates an iterator to process a task. 93 Status ProcessTaskInternal(const TaskDef& task) 94 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 95 Status EnsureTaskInitialized(Task& task); 96 // Stops a task, cancelling the task's outstanding requests and waiting for 97 // them to finish. 98 void StopTask(Task& task) TF_LOCKS_EXCLUDED(mu_); 99 // A thread for notifying the dispatcher when tasks complete. 100 void TaskCompletionThread() TF_LOCKS_EXCLUDED(mu_); 101 // A thread for doing periodic heartbeats to the dispatcher. 102 void HeartbeatThread() TF_LOCKS_EXCLUDED(mu_); 103 // Performs a heartbeat to the dispatcher. 104 Status Heartbeat() TF_LOCKS_EXCLUDED(mu_); 105 // Gets the DatasetDef for `task_def`. 106 StatusOr<DatasetDef> GetDatasetDef(const TaskDef& task_def) const; 107 // Creates a dataset from `dataset_def`. 108 StatusOr<std::unique_ptr<standalone::Dataset>> MakeDataset( 109 const DatasetDef& dataset_def, const TaskDef& task_def) const; 110 // Creates an iterator for `dataset`. 111 StatusOr<std::unique_ptr<standalone::Iterator>> MakeDatasetIterator( 112 standalone::Dataset& dataset, const TaskDef& task_def) const; 113 114 const experimental::WorkerConfig config_; 115 // The worker's own address. 116 std::string worker_address_; 117 std::string transfer_address_; 118 std::unique_ptr<DataServiceDispatcherClient> dispatcher_; 119 120 mutex mu_; 121 condition_variable cv_; 122 // Information about tasks, keyed by task ids. The tasks are updated based on 123 // the heartbeat responses from the dispatcher. 124 absl::flat_hash_map<int64, std::shared_ptr<Task>> tasks_ TF_GUARDED_BY(mu_); 125 // Ids of tasks that have finished. 126 absl::flat_hash_set<int64> finished_tasks_ TF_GUARDED_BY(mu_); 127 // Completed tasks which haven't yet been communicated to the dispatcher. 128 absl::flat_hash_set<int64> pending_completed_tasks_ TF_GUARDED_BY(mu_); 129 bool cancelled_ TF_GUARDED_BY(mu_) = false; 130 // Whether the worker has registered with the dispatcher yet. 131 bool registered_ TF_GUARDED_BY(mu_) = false; 132 // A thread for notifying the dispatcher when tasks complete. 133 std::unique_ptr<Thread> task_completion_thread_; 134 condition_variable task_completion_cv_ TF_GUARDED_BY(mu_); 135 // A thread for performing regular heartbeats to the dispatcher. 136 std::unique_ptr<Thread> heartbeat_thread_; 137 condition_variable heartbeat_cv_ TF_GUARDED_BY(mu_); 138 int64 outstanding_requests_ TF_GUARDED_BY(mu_) = 0; 139 CancellationManager cancellation_manager_; 140 141 TF_DISALLOW_COPY_AND_ASSIGN(DataServiceWorkerImpl); 142 }; 143 144 // Local in-process workers shared among clients and servers. If clients and 145 // workers colocate in the same process, clients can read from local workers to 146 // reduce RPC calls and data copy. 147 class LocalWorkers { 148 public: 149 // Adds a `worker` at `worker_address`. If a worker already exists at the 150 // address, it will be updated to the new `worker`. 151 // REQUIRES: worker != nullptr. 152 static void Add(absl::string_view worker_address, 153 std::shared_ptr<DataServiceWorkerImpl> worker); 154 155 // Gets a local worker at `worker_address`. Returns nullptr if a worker is not 156 // found. 157 static std::shared_ptr<DataServiceWorkerImpl> Get( 158 absl::string_view worker_address); 159 160 // Returns if there are any local workers in the process. 161 static bool Empty(); 162 163 // Removes a worker at `worker_address`. It is no-op if a worker is not found 164 // at the address. 165 static void Remove(absl::string_view worker_address); 166 167 private: 168 using AddressToWorkerMap = 169 absl::flat_hash_map<std::string, std::shared_ptr<DataServiceWorkerImpl>>; 170 static mutex mu_; 171 static AddressToWorkerMap* local_workers_ TF_GUARDED_BY(mu_); 172 }; 173 174 } // namespace data 175 } // namespace tensorflow 176 177 #endif // TENSORFLOW_CORE_DATA_SERVICE_WORKER_IMPL_H_ 178