• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #ifndef TENSORFLOW_CORE_DATA_SERVICE_WORKER_IMPL_H_
16 #define TENSORFLOW_CORE_DATA_SERVICE_WORKER_IMPL_H_
17 
18 #include <memory>
19 #include <string>
20 #include <utility>
21 
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/strings/string_view.h"
25 #include "tensorflow/core/data/service/common.pb.h"
26 #include "tensorflow/core/data/service/data_transfer.h"
27 #include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
28 #include "tensorflow/core/data/service/dispatcher_client.h"
29 #include "tensorflow/core/data/service/task_runner.h"
30 #include "tensorflow/core/data/service/worker.pb.h"
31 #include "tensorflow/core/data/standalone.h"
32 #include "tensorflow/core/framework/cancellation.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/platform/env.h"
35 #include "tensorflow/core/platform/mutex.h"
36 #include "tensorflow/core/platform/status.h"
37 #include "tensorflow/core/platform/statusor.h"
38 #include "tensorflow/core/platform/thread_annotations.h"
39 #include "tensorflow/core/protobuf/service_config.pb.h"
40 #include "tensorflow/core/public/session.h"
41 
42 namespace tensorflow {
43 namespace data {
44 
45 // A TensorFlow DataService serves dataset elements over RPC.
46 class DataServiceWorkerImpl {
47  public:
48   explicit DataServiceWorkerImpl(const experimental::WorkerConfig& config);
49   ~DataServiceWorkerImpl();
50 
51   // Starts the worker. The worker needs to know its own address so that it can
52   // register with the dispatcher. This is set in `Start` instead of in the
53   // constructor because the worker may be binding to port `0`, in which case
54   // the address isn't known until the worker has started and decided which port
55   // to bind to.
56   Status Start(const std::string& worker_address,
57                const std::string& transfer_address);
58   // Stops the worker, attempting a clean shutdown by rejecting new requests
59   // and waiting for outstanding requests to complete.
60   void Stop();
61 
62   // Serves a GetElement request, storing the result in `*result`. See
63   // worker.proto for GetElement API documentation.
64   Status GetElementResult(const GetElementRequest* request,
65                           GetElementResult* result);
66 
67   // See worker.proto for API documentation.
68 
69   /// Dispatcher-facing API.
70   Status ProcessTask(const ProcessTaskRequest* request,
71                      ProcessTaskResponse* response);
72 
73   /// Client-facing API.
74   Status GetElement(const GetElementRequest* request,
75                     GetElementResponse* response);
76   Status GetWorkerTasks(const GetWorkerTasksRequest* request,
77                         GetWorkerTasksResponse* response);
78 
79  private:
80   struct Task {
TaskTask81     explicit Task(TaskDef task_def) : task_def(std::move(task_def)) {}
82 
83     TaskDef task_def;
84     mutex mu;
85     bool initialized TF_GUARDED_BY(mu) = false;
86     int64 outstanding_requests TF_GUARDED_BY(&DataServiceWorkerImpl::mu_) = 0;
87     std::unique_ptr<TaskRunner> task_runner;
88   };
89 
90   // Sends task status to the dispatcher and checks for dispatcher commands.
91   Status SendTaskUpdates() TF_LOCKS_EXCLUDED(mu_);
92   // Creates an iterator to process a task.
93   Status ProcessTaskInternal(const TaskDef& task)
94       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
95   Status EnsureTaskInitialized(Task& task);
96   // Stops a task, cancelling the task's outstanding requests and waiting for
97   // them to finish.
98   void StopTask(Task& task) TF_LOCKS_EXCLUDED(mu_);
99   // A thread for notifying the dispatcher when tasks complete.
100   void TaskCompletionThread() TF_LOCKS_EXCLUDED(mu_);
101   // A thread for doing periodic heartbeats to the dispatcher.
102   void HeartbeatThread() TF_LOCKS_EXCLUDED(mu_);
103   // Performs a heartbeat to the dispatcher.
104   Status Heartbeat() TF_LOCKS_EXCLUDED(mu_);
105   // Gets the DatasetDef for `task_def`.
106   StatusOr<DatasetDef> GetDatasetDef(const TaskDef& task_def) const;
107   // Creates a dataset from `dataset_def`.
108   StatusOr<std::unique_ptr<standalone::Dataset>> MakeDataset(
109       const DatasetDef& dataset_def, const TaskDef& task_def) const;
110   // Creates an iterator for `dataset`.
111   StatusOr<std::unique_ptr<standalone::Iterator>> MakeDatasetIterator(
112       standalone::Dataset& dataset, const TaskDef& task_def) const;
113 
114   const experimental::WorkerConfig config_;
115   // The worker's own address.
116   std::string worker_address_;
117   std::string transfer_address_;
118   std::unique_ptr<DataServiceDispatcherClient> dispatcher_;
119 
120   mutex mu_;
121   condition_variable cv_;
122   // Information about tasks, keyed by task ids. The tasks are updated based on
123   // the heartbeat responses from the dispatcher.
124   absl::flat_hash_map<int64, std::shared_ptr<Task>> tasks_ TF_GUARDED_BY(mu_);
125   // Ids of tasks that have finished.
126   absl::flat_hash_set<int64> finished_tasks_ TF_GUARDED_BY(mu_);
127   // Completed tasks which haven't yet been communicated to the dispatcher.
128   absl::flat_hash_set<int64> pending_completed_tasks_ TF_GUARDED_BY(mu_);
129   bool cancelled_ TF_GUARDED_BY(mu_) = false;
130   // Whether the worker has registered with the dispatcher yet.
131   bool registered_ TF_GUARDED_BY(mu_) = false;
132   // A thread for notifying the dispatcher when tasks complete.
133   std::unique_ptr<Thread> task_completion_thread_;
134   condition_variable task_completion_cv_ TF_GUARDED_BY(mu_);
135   // A thread for performing regular heartbeats to the dispatcher.
136   std::unique_ptr<Thread> heartbeat_thread_;
137   condition_variable heartbeat_cv_ TF_GUARDED_BY(mu_);
138   int64 outstanding_requests_ TF_GUARDED_BY(mu_) = 0;
139   CancellationManager cancellation_manager_;
140 
141   TF_DISALLOW_COPY_AND_ASSIGN(DataServiceWorkerImpl);
142 };
143 
144 // Local in-process workers shared among clients and servers. If clients and
145 // workers colocate in the same process, clients can read from local workers to
146 // reduce RPC calls and data copy.
147 class LocalWorkers {
148  public:
149   // Adds a `worker` at `worker_address`. If a worker already exists at the
150   // address, it will be updated to the new `worker`.
151   // REQUIRES: worker != nullptr.
152   static void Add(absl::string_view worker_address,
153                   std::shared_ptr<DataServiceWorkerImpl> worker);
154 
155   // Gets a local worker at `worker_address`. Returns nullptr if a worker is not
156   // found.
157   static std::shared_ptr<DataServiceWorkerImpl> Get(
158       absl::string_view worker_address);
159 
160   // Returns if there are any local workers in the process.
161   static bool Empty();
162 
163   // Removes a worker at `worker_address`. It is no-op if a worker is not found
164   // at the address.
165   static void Remove(absl::string_view worker_address);
166 
167  private:
168   using AddressToWorkerMap =
169       absl::flat_hash_map<std::string, std::shared_ptr<DataServiceWorkerImpl>>;
170   static mutex mu_;
171   static AddressToWorkerMap* local_workers_ TF_GUARDED_BY(mu_);
172 };
173 
174 }  // namespace data
175 }  // namespace tensorflow
176 
177 #endif  // TENSORFLOW_CORE_DATA_SERVICE_WORKER_IMPL_H_
178