1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_ 18 19 #include <memory> 20 #include <unordered_map> 21 #include "grpcpp/server_builder.h" 22 #include "tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h" 23 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h" 24 #include "tensorflow/core/distributed_runtime/worker.h" 25 #include "tensorflow/core/protobuf/worker.pb.h" 26 27 namespace grpc { 28 class ByteBuffer; 29 } // namespace grpc 30 31 namespace tensorflow { 32 33 class AsyncServiceInterface; 34 class ConfigProto; 35 struct WorkerEnv; 36 class WorkerSession; 37 class GrpcResponseCache; 38 39 class GrpcWorker : public Worker { 40 public: 41 GrpcWorker(WorkerEnv* env, const ConfigProto& config); 42 43 // Specialized version of RecvTensor for gRPC, which avoids a copy. 44 virtual void GrpcRecvTensorAsync(CallOptions* opts, 45 const RecvTensorRequest* request, 46 ::grpc::ByteBuffer* response, 47 StatusCallback done); 48 49 void LoggingAsync(const LoggingRequest* request, LoggingResponse* response, 50 StatusCallback done) override; 51 52 void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, 53 RecvBufResponse* response, StatusCallback done) override; 54 55 void CleanupGraphAsync(const CleanupGraphRequest* request, 56 CleanupGraphResponse* response, 57 StatusCallback done) override; 58 59 WorkerEnv* env(); 60 61 void EnableResponseCache(); 62 63 void RemoveCacheEntryForId(int64 request_id); 64 65 private: 66 std::unique_ptr<GrpcResponseCache> response_cache_; 67 const int32 recv_buf_max_chunk_; 68 }; 69 70 std::unique_ptr<GrpcWorker> NewGrpcWorker(WorkerEnv* worker_env, 71 const ConfigProto& config); 72 73 struct GrpcWorkerServiceOptions { 74 // Map from GrpcWorkerMethod id to queue depth. If set this overrides the 75 // default queue depth for a method. 76 std::unordered_map<int, int> queue_depth; 77 int num_serving_threads = 8; 78 }; 79 80 // Returns an implementation of WorkerService rpc service. 81 std::unique_ptr<AsyncServiceInterface> NewGrpcWorkerService( 82 GrpcWorker* worker, ::grpc::ServerBuilder* builder, 83 GrpcWorkerServiceOptions opts = GrpcWorkerServiceOptions()); 84 85 } // namespace tensorflow 86 87 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_ 88