1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_ 18 19 #include <functional> 20 21 #include "tensorflow/core/distributed_runtime/call_options.h" 22 #include "tensorflow/core/distributed_runtime/message_wrappers.h" 23 #include "tensorflow/core/lib/core/notification.h" 24 #include "tensorflow/core/lib/core/status.h" 25 #include "tensorflow/core/platform/types.h" 26 #include "tensorflow/core/protobuf/worker.pb.h" 27 28 namespace tensorflow { 29 30 // Status callback. 31 typedef std::function<void(const Status&)> StatusCallback; 32 33 // Custom decoder for a response to RecvTensorAsync. 34 class TensorResponse; 35 36 // Interface for talking with the TensorFlow Worker service. 37 class WorkerInterface { 38 public: 39 virtual void GetStatusAsync(const GetStatusRequest* request, 40 GetStatusResponse* response, 41 StatusCallback done) = 0; 42 43 virtual void CreateWorkerSessionAsync( 44 const CreateWorkerSessionRequest* request, 45 CreateWorkerSessionResponse* response, StatusCallback done) = 0; 46 47 virtual void DeleteWorkerSessionAsync( 48 CallOptions* opts, const DeleteWorkerSessionRequest* request, 49 DeleteWorkerSessionResponse* response, StatusCallback done) = 0; 50 51 virtual void RegisterGraphAsync(const RegisterGraphRequest* request, 52 RegisterGraphResponse* response, 53 StatusCallback done) = 0; 54 55 virtual void DeregisterGraphAsync(const DeregisterGraphRequest* request, 56 DeregisterGraphResponse* response, 57 StatusCallback done) = 0; 58 59 virtual void RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request, 60 MutableRunGraphResponseWrapper* repsonse, 61 StatusCallback done) = 0; 62 RunGraphAsync(CallOptions * opts,const RunGraphRequest * request,RunGraphResponse * response,StatusCallback done)63 virtual void RunGraphAsync(CallOptions* opts, const RunGraphRequest* request, 64 RunGraphResponse* response, StatusCallback done) { 65 // TODO(mrry): Convert this to std::bind/std::move if the overhead 66 // of std::function copying becomes too much. 67 RunGraphRequestWrapper* wrapped_request = new ProtoRunGraphRequest(request); 68 MutableRunGraphResponseWrapper* wrapped_response = 69 new NonOwnedProtoRunGraphResponse(response); 70 RunGraphAsync(opts, wrapped_request, wrapped_response, 71 [wrapped_request, wrapped_response, done](const Status& s) { 72 done(s); 73 delete wrapped_request; 74 delete wrapped_response; 75 }); 76 } 77 78 // Returns a request object for use in calls to 79 // `RunGraphAsync()`. Ownership is transferred to the caller. 80 // 81 // The message returned from this method must only be used in a 82 // `RunGraph()` call on the same `WorkerInterface` instance. CreateRunGraphRequest()83 virtual MutableRunGraphRequestWrapper* CreateRunGraphRequest() { 84 return new MutableProtoRunGraphRequest; 85 } 86 87 // Returns a response object for use in calls to 88 // `RunGraphAsync()`. Ownership is transferred to the caller. 89 // 90 // The message returned from this method must only be used in a 91 // `RunGraph()` call on the same `WorkerInterface` instance. CreateRunGraphResponse()92 virtual MutableRunGraphResponseWrapper* CreateRunGraphResponse() { 93 return new OwnedProtoRunGraphResponse; 94 } 95 96 virtual void CleanupGraphAsync(const CleanupGraphRequest* request, 97 CleanupGraphResponse* response, 98 StatusCallback done) = 0; 99 100 virtual void CleanupAllAsync(const CleanupAllRequest* request, 101 CleanupAllResponse* response, 102 StatusCallback done) = 0; 103 104 virtual void RecvTensorAsync(CallOptions* opts, 105 const RecvTensorRequest* request, 106 TensorResponse* response, 107 StatusCallback done) = 0; 108 109 virtual void LoggingAsync(const LoggingRequest* request, 110 LoggingResponse* response, StatusCallback done) = 0; 111 112 virtual void TracingAsync(const TracingRequest* request, 113 TracingResponse* response, StatusCallback done) = 0; 114 115 virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, 116 RecvBufResponse* response, StatusCallback done) = 0; 117 118 virtual void CompleteGroupAsync(CallOptions* opts, 119 const CompleteGroupRequest* request, 120 CompleteGroupResponse* response, 121 StatusCallback done) = 0; 122 123 virtual void CompleteInstanceAsync(CallOptions* ops, 124 const CompleteInstanceRequest* request, 125 CompleteInstanceResponse* response, 126 StatusCallback done) = 0; 127 128 virtual void GetStepSequenceAsync(const GetStepSequenceRequest* request, 129 GetStepSequenceResponse* response, 130 StatusCallback done) = 0; 131 GetStatus(const GetStatusRequest * request,GetStatusResponse * response)132 Status GetStatus(const GetStatusRequest* request, 133 GetStatusResponse* response) { 134 return CallAndWait(&ME::GetStatusAsync, request, response); 135 } 136 CreateWorkerSession(const CreateWorkerSessionRequest * request,CreateWorkerSessionResponse * response)137 Status CreateWorkerSession(const CreateWorkerSessionRequest* request, 138 CreateWorkerSessionResponse* response) { 139 return CallAndWait(&ME::CreateWorkerSessionAsync, request, response); 140 } 141 DeleteWorkerSession(const DeleteWorkerSessionRequest * request,DeleteWorkerSessionResponse * response)142 Status DeleteWorkerSession(const DeleteWorkerSessionRequest* request, 143 DeleteWorkerSessionResponse* response) { 144 return CallAndWaitWithOptions(&ME::DeleteWorkerSessionAsync, request, 145 response); 146 } 147 RegisterGraph(const RegisterGraphRequest * request,RegisterGraphResponse * response)148 Status RegisterGraph(const RegisterGraphRequest* request, 149 RegisterGraphResponse* response) { 150 return CallAndWait(&ME::RegisterGraphAsync, request, response); 151 } 152 DeregisterGraph(const DeregisterGraphRequest * request,DeregisterGraphResponse * response)153 Status DeregisterGraph(const DeregisterGraphRequest* request, 154 DeregisterGraphResponse* response) { 155 return CallAndWait(&ME::DeregisterGraphAsync, request, response); 156 } 157 CleanupGraph(const CleanupGraphRequest * request,CleanupGraphResponse * response)158 Status CleanupGraph(const CleanupGraphRequest* request, 159 CleanupGraphResponse* response) { 160 return CallAndWait(&ME::CleanupGraphAsync, request, response); 161 } 162 CleanupAll(const CleanupAllRequest * request,CleanupAllResponse * response)163 Status CleanupAll(const CleanupAllRequest* request, 164 CleanupAllResponse* response) { 165 return CallAndWait(&ME::CleanupAllAsync, request, response); 166 } 167 Logging(const LoggingRequest * request,LoggingResponse * response)168 Status Logging(const LoggingRequest* request, LoggingResponse* response) { 169 return CallAndWait(&ME::LoggingAsync, request, response); 170 } 171 Tracing(const TracingRequest * request,TracingResponse * response)172 Status Tracing(const TracingRequest* request, TracingResponse* response) { 173 return CallAndWait(&ME::TracingAsync, request, response); 174 } 175 GetStepSequence(const GetStepSequenceRequest * request,GetStepSequenceResponse * response)176 Status GetStepSequence(const GetStepSequenceRequest* request, 177 GetStepSequenceResponse* response) { 178 return CallAndWait(&ME::GetStepSequenceAsync, request, response); 179 } 180 181 protected: 182 // Instances of WorkerInterface must be deleted by a call to 183 // WorkerCacheInterface::ReleaseWorker(). ~WorkerInterface()184 virtual ~WorkerInterface() {} 185 friend class WorkerCacheInterface; 186 187 // NOTE: This should only be called by implementations of this 188 // interface whose CreateRunGraphResponse() method returns a 189 // proto-based wrappers for the RunGraphResponse message. get_proto_from_wrapper(MutableRunGraphResponseWrapper * wrapper)190 RunGraphResponse* get_proto_from_wrapper( 191 MutableRunGraphResponseWrapper* wrapper) { 192 return wrapper->get_proto(); 193 } 194 195 private: 196 typedef WorkerInterface ME; 197 198 template <typename Method, typename Req, typename Resp> CallAndWait(Method func,const Req * req,Resp * resp)199 Status CallAndWait(Method func, const Req* req, Resp* resp) { 200 Status ret; 201 Notification n; 202 (this->*func)(req, resp, [&ret, &n](const Status& s) { 203 ret = s; 204 n.Notify(); 205 }); 206 n.WaitForNotification(); 207 return ret; 208 } 209 210 template <typename Method, typename Req, typename Resp> CallAndWaitWithOptions(Method func,const Req * req,Resp * resp)211 Status CallAndWaitWithOptions(Method func, const Req* req, Resp* resp) { 212 CallOptions call_opts; 213 Status ret; 214 Notification n; 215 (this->*func)(&call_opts, req, resp, [&ret, &n](const Status& s) { 216 ret = s; 217 n.Notify(); 218 }); 219 n.WaitForNotification(); 220 return ret; 221 } 222 }; 223 224 } // namespace tensorflow 225 226 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_ 227