1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_H_ 18 19 #include <unordered_map> 20 21 #include "tensorflow/core/distributed_runtime/graph_mgr.h" 22 #include "tensorflow/core/distributed_runtime/partial_run_mgr.h" 23 #include "tensorflow/core/distributed_runtime/session_mgr.h" 24 #include "tensorflow/core/distributed_runtime/worker_interface.h" 25 26 namespace tensorflow { 27 28 class CancellationManager; 29 class Device; 30 struct WorkerEnv; 31 struct WorkerSession; 32 33 // A TensorFlow Worker runs registered graphs and supports worker-to-worker 34 // Tensor transfer. 35 // 36 // See `../protobuf/worker_service.proto` for more details about each method. 37 // 38 // This class may be subclassed to provide specialized implementations of 39 // particular methods for different transport mechanism. For example, 40 // `GrpcWorker` specializes the `RecvTensorAsync()` method to support a more 41 // efficient gRPC data structure for handling large binary data. 42 class Worker : public WorkerInterface { 43 public: 44 Worker(WorkerEnv* env); ~Worker()45 virtual ~Worker() {} 46 47 void GetStatusAsync(const GetStatusRequest* request, 48 GetStatusResponse* response, 49 StatusCallback done) override; 50 51 void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request, 52 CreateWorkerSessionResponse* response, 53 StatusCallback done) override; 54 55 void DeleteWorkerSessionAsync(CallOptions* opts, 56 const DeleteWorkerSessionRequest* request, 57 DeleteWorkerSessionResponse* response, 58 StatusCallback done) override; 59 60 void RegisterGraphAsync(const RegisterGraphRequest* request, 61 RegisterGraphResponse* response, 62 StatusCallback done) override; 63 64 void DeregisterGraphAsync(const DeregisterGraphRequest* request, 65 DeregisterGraphResponse* response, 66 StatusCallback done) override; 67 68 void RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request, 69 MutableRunGraphResponseWrapper* response, 70 StatusCallback done) override; 71 72 MutableRunGraphRequestWrapper* CreateRunGraphRequest() override; 73 74 MutableRunGraphResponseWrapper* CreateRunGraphResponse() override; 75 76 void CleanupGraphAsync(const CleanupGraphRequest* request, 77 CleanupGraphResponse* response, 78 StatusCallback done) override; 79 80 void CleanupAllAsync(const CleanupAllRequest* request, 81 CleanupAllResponse* response, 82 StatusCallback done) override; 83 84 void RecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request, 85 TensorResponse* response, StatusCallback done) override; 86 87 void LoggingAsync(const LoggingRequest* request, LoggingResponse* response, 88 StatusCallback done) override; 89 90 void TracingAsync(const TracingRequest* request, TracingResponse* response, 91 StatusCallback done) override; 92 93 void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, 94 RecvBufResponse* response, StatusCallback done) override; 95 96 void CompleteGroupAsync(CallOptions* opts, 97 const CompleteGroupRequest* request, 98 CompleteGroupResponse* response, 99 StatusCallback done) override; 100 101 void CompleteInstanceAsync(CallOptions* opts, 102 const CompleteInstanceRequest* request, 103 CompleteInstanceResponse* response, 104 StatusCallback done) override; 105 106 void GetStepSequenceAsync(const GetStepSequenceRequest* request, 107 GetStepSequenceResponse* response, 108 StatusCallback done) override; 109 110 protected: 111 WorkerEnv* const env_; // Not owned. 112 113 Status PrepareRecvTensor(const Rendezvous::ParsedKey& parsed, 114 Device** src_dev); 115 116 void AbortStep(int64); 117 118 private: 119 PartialRunMgr partial_run_mgr_; 120 121 CancellationManager cancellation_manager_; 122 123 Status PrepareRunGraph(RunGraphRequestWrapper* req, 124 GraphMgr::NamedTensors* in, 125 GraphMgr::NamedTensors* out); 126 127 void DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request, 128 MutableRunGraphResponseWrapper* response, 129 StatusCallback done); 130 131 void DoPartialRunGraph(CallOptions* opts, RunGraphRequestWrapper* request, 132 MutableRunGraphResponseWrapper* response, 133 StatusCallback done); 134 135 TF_DISALLOW_COPY_AND_ASSIGN(Worker); 136 }; 137 138 } // namespace tensorflow 139 140 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_H_ 141