1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DATA_SERVICE_SERVER_LIB_H_ 17 #define TENSORFLOW_CORE_DATA_SERVICE_SERVER_LIB_H_ 18 19 #include "grpcpp/server.h" 20 #include "grpcpp/server_builder.h" 21 #include "tensorflow/core/data/service/data_transfer.h" 22 #include "tensorflow/core/lib/core/status.h" 23 #include "tensorflow/core/profiler/rpc/profiler_service_impl.h" 24 #include "tensorflow/core/protobuf/service_config.pb.h" 25 26 namespace tensorflow { 27 namespace data { 28 29 // Forward declared because transitively depending on .grpc.pb.h files causes 30 // issues in the pywrap build. 31 class GrpcDispatcherImpl; 32 class GrpcWorkerImpl; 33 34 // A grpc server for the tf.data service. 35 class GrpcDataServerBase { 36 public: 37 // Constructs a tf.data server with the specified port. If the port is 0, the 38 // server will find an available port in `Start()`. The chosen port can be 39 // found by calling `BoundPort()`. 40 GrpcDataServerBase(int requested_port, const std::string& protocol, 41 const std::string server_type); ~GrpcDataServerBase()42 virtual ~GrpcDataServerBase() {} 43 44 // Starts the server running asynchronously. 45 Status Start(); 46 47 // Stops the server. This will block until all outstanding requests complete. 48 void Stop(); 49 50 // Blocks until the server stops. 51 void Join(); 52 53 // Returns the port bound by the server. Only valid after calling Start(). 54 int BoundPort(); 55 56 protected: 57 virtual void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) = 0; 58 void AddProfilerServiceToBuilder(::grpc::ServerBuilder& builder); 59 // Starts the service. This will be called after building the service, so 60 // bound_port() will return the actual bound port. 61 virtual Status StartServiceInternal() = 0; StopServiceInternal()62 virtual void StopServiceInternal() {} 63 bound_port()64 int bound_port() { return bound_port_; } 65 66 const int requested_port_; 67 const std::string protocol_; 68 const std::string server_type_; 69 70 private: 71 int bound_port_; 72 bool started_ = false; 73 bool stopped_ = false; 74 75 std::unique_ptr<::grpc::Server> server_; 76 // TensorFlow profiler service implementation. 77 std::unique_ptr<grpc::ProfilerService::Service> profiler_service_ = nullptr; 78 }; 79 80 class DispatchGrpcDataServer : public GrpcDataServerBase { 81 public: 82 explicit DispatchGrpcDataServer(const experimental::DispatcherConfig& config); 83 ~DispatchGrpcDataServer() override; 84 85 // Returns the number of workers registerd with the dispatcher. 86 Status NumWorkers(int* num_workers); 87 88 protected: 89 void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) override; 90 Status StartServiceInternal() override; 91 92 private: 93 const experimental::DispatcherConfig config_; 94 // Owned. We use a raw pointer because GrpcDispatcherImpl is forward-declared. 95 GrpcDispatcherImpl* service_; 96 }; 97 98 class WorkerGrpcDataServer : public GrpcDataServerBase { 99 public: 100 explicit WorkerGrpcDataServer(const experimental::WorkerConfig& config); 101 ~WorkerGrpcDataServer() override; 102 103 // Returns the number of tasks currently being executed by the worker. 104 Status NumTasks(int* num_tasks); 105 106 protected: 107 void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) override; 108 Status StartServiceInternal() override; 109 void StopServiceInternal() override; 110 111 private: 112 const experimental::WorkerConfig config_; 113 // Owned. We use a raw pointer because GrpcWorkerImpl is forward-declared. 114 GrpcWorkerImpl* service_; 115 std::shared_ptr<DataTransferServer> transfer_server_; 116 }; 117 118 // Creates a dispatch tf.data server and stores it in `out_server`. 119 Status NewDispatchServer(const experimental::DispatcherConfig& config, 120 std::unique_ptr<DispatchGrpcDataServer>& out_server); 121 122 // Creates a worker tf.data server and stores it in `out_server`. 123 Status NewWorkerServer(const experimental::WorkerConfig& config, 124 std::unique_ptr<WorkerGrpcDataServer>& out_server); 125 126 } // namespace data 127 } // namespace tensorflow 128 129 #endif // TENSORFLOW_CORE_DATA_SERVICE_SERVER_LIB_H_ 130