1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_H_ 18 19 #include <unordered_map> 20 21 #include "tensorflow/core/distributed_runtime/graph_mgr.h" 22 #include "tensorflow/core/distributed_runtime/partial_run_mgr.h" 23 #include "tensorflow/core/distributed_runtime/recent_request_ids.h" 24 #include "tensorflow/core/distributed_runtime/session_mgr.h" 25 #include "tensorflow/core/distributed_runtime/worker_interface.h" 26 27 namespace tensorflow { 28 29 class CancellationManager; 30 class Device; 31 struct WorkerEnv; 32 class WorkerSession; 33 34 // A TensorFlow Worker runs registered graphs and supports worker-to-worker 35 // Tensor transfer. 36 // 37 // See `../protobuf/worker_service.proto` for more details about each method. 38 // 39 // This class may be subclassed to provide specialized implementations of 40 // particular methods for different transport mechanisms. For example, 41 // `GrpcWorker` specializes the `RecvTensorAsync()` method to support a more 42 // efficient gRPC data structure for handling large binary data. 
43 class Worker : public WorkerInterface { 44 public: 45 Worker(WorkerEnv* env); ~Worker()46 virtual ~Worker() {} 47 48 void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request, 49 GetStatusResponse* response, bool fail_fast, 50 StatusCallback done) override; 51 52 void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request, 53 CreateWorkerSessionResponse* response, 54 StatusCallback done) override; 55 56 void DeleteWorkerSessionAsync(CallOptions* opts, 57 const DeleteWorkerSessionRequest* request, 58 DeleteWorkerSessionResponse* response, 59 StatusCallback done) override; 60 61 void RegisterGraphAsync(const RegisterGraphRequest* request, 62 RegisterGraphResponse* response, 63 StatusCallback done) override; 64 65 void DeregisterGraphAsync(const DeregisterGraphRequest* request, 66 DeregisterGraphResponse* response, 67 StatusCallback done) override; 68 69 void RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request, 70 MutableRunGraphResponseWrapper* response, 71 StatusCallback done) override; 72 73 MutableRunGraphRequestWrapper* CreateRunGraphRequest() override; 74 75 MutableRunGraphResponseWrapper* CreateRunGraphResponse() override; 76 77 void CleanupGraphAsync(const CleanupGraphRequest* request, 78 CleanupGraphResponse* response, 79 StatusCallback done) override; 80 81 void CleanupAllAsync(const CleanupAllRequest* request, 82 CleanupAllResponse* response, 83 StatusCallback done) override; 84 85 void RecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request, 86 TensorResponse* response, StatusCallback done) override; 87 88 void LoggingAsync(const LoggingRequest* request, LoggingResponse* response, 89 StatusCallback done) override; 90 91 void TracingAsync(const TracingRequest* request, TracingResponse* response, 92 StatusCallback done) override; 93 94 void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, 95 RecvBufResponse* response, StatusCallback done) override; 96 97 void CompleteGroupAsync(CallOptions* 
opts, 98 const CompleteGroupRequest* request, 99 CompleteGroupResponse* response, 100 StatusCallback done) override; 101 102 void CompleteInstanceAsync(CallOptions* opts, 103 const CompleteInstanceRequest* request, 104 CompleteInstanceResponse* response, 105 StatusCallback done) override; 106 107 void GetStepSequenceAsync(const GetStepSequenceRequest* request, 108 GetStepSequenceResponse* response, 109 StatusCallback done) override; 110 111 protected: 112 WorkerEnv* const env_; // Not owned. 113 RecentRequestIds recent_request_ids_; 114 115 Status PrepareRecvTensor(const Rendezvous::ParsedKey& parsed, 116 Device** src_dev); 117 118 void AbortStep(int64_t); 119 120 private: 121 PartialRunMgr partial_run_mgr_; 122 123 CancellationManager cancellation_manager_; 124 125 Status PrepareRunGraph(RunGraphRequestWrapper* req, 126 GraphMgr::NamedTensors* in, 127 GraphMgr::NamedTensors* out); 128 129 void DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request, 130 MutableRunGraphResponseWrapper* response, 131 StatusCallback done); 132 133 void DoPartialRunGraph(CallOptions* opts, RunGraphRequestWrapper* request, 134 MutableRunGraphResponseWrapper* response, 135 StatusCallback done); 136 137 TF_DISALLOW_COPY_AND_ASSIGN(Worker); 138 }; 139 140 } // namespace tensorflow 141 142 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_H_ 143