1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_ 18 19 #include <memory> 20 #include <unordered_map> 21 #include "tensorflow/core/distributed_runtime/recent_request_ids.h" 22 #include "tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h" 23 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h" 24 #include "tensorflow/core/distributed_runtime/worker.h" 25 26 namespace grpc { 27 class ByteBuffer; 28 class ServerBuilder; 29 } // namespace grpc 30 31 namespace tensorflow { 32 33 class AsyncServiceInterface; 34 class ConfigProto; 35 struct WorkerEnv; 36 struct WorkerSession; 37 38 class GrpcWorker : public Worker { 39 public: 40 GrpcWorker(WorkerEnv* env, const ConfigProto& config); 41 42 // Specialized version of RecvTensor for gRPC, which avoids a copy. 43 virtual void GrpcRecvTensorAsync(CallOptions* opts, 44 const RecvTensorRequest* request, 45 ::grpc::ByteBuffer* response, 46 StatusCallback done); 47 48 virtual void LoggingAsync(const LoggingRequest* request, 49 LoggingResponse* response, StatusCallback done); 50 51 virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, 52 RecvBufResponse* response, StatusCallback done); 53 54 WorkerEnv* env(); 55 56 private: 57 RecentRequestIds recent_request_ids_; 58 const int32 recv_buf_max_chunk_; 59 }; 60 61 std::unique_ptr<GrpcWorker> NewGrpcWorker(WorkerEnv* worker_env, 62 const ConfigProto& config); 63 64 struct GrpcWorkerServiceOptions { 65 // Map from GrpcWorkerMethod id to queue depth. If set this overrides the 66 // default queue depth for a method. 67 std::unordered_map<int, int> queue_depth; 68 int num_serving_threads = 8; 69 int64 response_cache_bytes = 0; 70 int64 response_cache_expires_seconds = 0; 71 }; 72 73 // Returns an implementation of WorkerService rpc service. 74 std::unique_ptr<AsyncServiceInterface> NewGrpcWorkerService( 75 GrpcWorker* worker, ::grpc::ServerBuilder* builder, 76 GrpcWorkerServiceOptions opts = GrpcWorkerServiceOptions()); 77 78 } // namespace tensorflow 79 80 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_ 81