• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/data/service/server_lib.h"

#include <string>
#include <utility>

#include "tensorflow/core/data/service/credentials_factory.h"
#include "tensorflow/core/data/service/grpc_dispatcher_impl.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/grpc_worker_impl.h"
#include "tensorflow/core/platform/errors.h"
23 
24 namespace tensorflow {
25 namespace data {
26 
namespace {
// Token that may appear in a configured address; the first occurrence is
// substituted with an actual port number once the port is known.
constexpr char kPortPlaceholder[] = "%port%";
}  // namespace
30 
GrpcDataServerBase(int port,const std::string & protocol,const std::string server_type)31 GrpcDataServerBase::GrpcDataServerBase(int port, const std::string& protocol,
32                                        const std::string server_type)
33     : requested_port_(port),
34       protocol_(protocol),
35       server_type_(server_type),
36       bound_port_(port) {}
37 
Start()38 Status GrpcDataServerBase::Start() {
39   if (stopped_) {
40     return errors::FailedPrecondition(
41         "Server cannot be started after it has been stopped.");
42   }
43   if (started_) {
44     return Status::OK();
45   }
46   ::grpc::ServerBuilder builder;
47   std::shared_ptr<::grpc::ServerCredentials> credentials;
48   TF_RETURN_IF_ERROR(
49       CredentialsFactory::CreateServerCredentials(protocol_, &credentials));
50   builder.AddListeningPort(strings::StrCat("0.0.0.0:", requested_port_),
51                            credentials, &bound_port_);
52   builder.SetMaxReceiveMessageSize(-1);
53 
54   AddDataServiceToBuilder(builder);
55   AddProfilerServiceToBuilder(builder);
56   server_ = builder.BuildAndStart();
57   if (!server_) {
58     return errors::Internal("Could not start gRPC server");
59   }
60 
61   TF_RETURN_IF_ERROR(StartServiceInternal());
62 
63   started_ = true;
64   LOG(INFO) << "Started tf.data " << server_type_
65             << " running at 0.0.0.0:" << BoundPort();
66   return Status::OK();
67 }
68 
Stop()69 void GrpcDataServerBase::Stop() {
70   if (stopped_) {
71     return;
72   }
73   if (server_) {
74     StopServiceInternal();
75     server_->Shutdown();
76     LOG(INFO) << "Shut down " << server_type_ << " server running at port "
77               << BoundPort();
78   }
79   stopped_ = true;
80 }
81 
Join()82 void GrpcDataServerBase::Join() { server_->Wait(); }
83 
BoundPort()84 int GrpcDataServerBase::BoundPort() { return bound_port(); }
85 
AddProfilerServiceToBuilder(::grpc::ServerBuilder & builder)86 void GrpcDataServerBase::AddProfilerServiceToBuilder(
87     ::grpc::ServerBuilder& builder) {
88   profiler_service_ = profiler::CreateProfilerService();
89   builder.RegisterService(profiler_service_.get());
90 }
91 
DispatchGrpcDataServer(const experimental::DispatcherConfig & config)92 DispatchGrpcDataServer::DispatchGrpcDataServer(
93     const experimental::DispatcherConfig& config)
94     : GrpcDataServerBase(config.port(), config.protocol(), "DispatchServer"),
95       config_(config) {}
96 
~DispatchGrpcDataServer()97 DispatchGrpcDataServer::~DispatchGrpcDataServer() { delete service_; }
98 
AddDataServiceToBuilder(::grpc::ServerBuilder & builder)99 void DispatchGrpcDataServer::AddDataServiceToBuilder(
100     ::grpc::ServerBuilder& builder) {
101   service_ = absl::make_unique<GrpcDispatcherImpl>(config_, builder).release();
102 }
103 
StartServiceInternal()104 Status DispatchGrpcDataServer::StartServiceInternal() {
105   return service_->Start();
106 }
107 
NumWorkers(int * num_workers)108 Status DispatchGrpcDataServer::NumWorkers(int* num_workers) {
109   GetWorkersRequest req;
110   GetWorkersResponse resp;
111   ::grpc::ServerContext ctx;
112   ::grpc::Status s = service_->GetWorkers(&ctx, &req, &resp);
113   if (!s.ok()) {
114     return grpc_util::WrapError("Failed to get workers", s);
115   }
116   *num_workers = resp.workers_size();
117   return Status::OK();
118 }
119 
WorkerGrpcDataServer(const experimental::WorkerConfig & config)120 WorkerGrpcDataServer::WorkerGrpcDataServer(
121     const experimental::WorkerConfig& config)
122     : GrpcDataServerBase(config.port(), config.protocol(), "WorkerServer"),
123       config_(config) {}
124 
~WorkerGrpcDataServer()125 WorkerGrpcDataServer::~WorkerGrpcDataServer() { delete service_; }
126 
AddDataServiceToBuilder(::grpc::ServerBuilder & builder)127 void WorkerGrpcDataServer::AddDataServiceToBuilder(
128     ::grpc::ServerBuilder& builder) {
129   service_ = absl::make_unique<GrpcWorkerImpl>(config_, builder).release();
130 }
131 
StartServiceInternal()132 Status WorkerGrpcDataServer::StartServiceInternal() {
133   std::string base_address = config_.worker_address();
134   if (base_address.empty()) {
135     base_address = absl::StrCat("localhost:", kPortPlaceholder);
136   }
137   std::string worker_address = str_util::StringReplace(
138       base_address, kPortPlaceholder, absl::StrCat(bound_port()),
139       /*replace_all=*/false);
140   std::string transfer_address = worker_address;
141   std::string transfer_protocol = config_.data_transfer_protocol();
142   if (!transfer_protocol.empty() && transfer_protocol != "grpc") {
143     TF_RETURN_IF_ERROR(DataTransferServer::Build(
144         transfer_protocol, service_->get_element_getter(), &transfer_server_));
145     TF_RETURN_IF_ERROR(transfer_server_->Start());
146     LOG(INFO) << "Data transfer server started at 0.0.0.0:"
147               << transfer_server_->get_port();
148     transfer_address = str_util::StringReplace(
149         config_.data_transfer_address(), kPortPlaceholder,
150         absl::StrCat(transfer_server_->get_port()),
151         /*replace_all=*/false);
152   }
153   TF_RETURN_IF_ERROR(service_->Start(worker_address, transfer_address));
154   return Status::OK();
155 }
156 
StopServiceInternal()157 void WorkerGrpcDataServer::StopServiceInternal() { service_->Stop(); }
158 
NumTasks(int * num_tasks)159 Status WorkerGrpcDataServer::NumTasks(int* num_tasks) {
160   GetWorkerTasksRequest req;
161   GetWorkerTasksResponse resp;
162   ::grpc::ServerContext ctx;
163   ::grpc::Status s = service_->GetWorkerTasks(&ctx, &req, &resp);
164   if (!s.ok()) {
165     return grpc_util::WrapError("Failed to get tasks", s);
166   }
167   *num_tasks = resp.tasks_size();
168   return Status::OK();
169 }
170 
NewDispatchServer(const experimental::DispatcherConfig & config,std::unique_ptr<DispatchGrpcDataServer> & out_server)171 Status NewDispatchServer(const experimental::DispatcherConfig& config,
172                          std::unique_ptr<DispatchGrpcDataServer>& out_server) {
173   out_server = absl::make_unique<DispatchGrpcDataServer>(config);
174   return Status::OK();
175 }
176 
NewWorkerServer(const experimental::WorkerConfig & config,std::unique_ptr<WorkerGrpcDataServer> & out_server)177 Status NewWorkerServer(const experimental::WorkerConfig& config,
178                        std::unique_ptr<WorkerGrpcDataServer>& out_server) {
179   out_server = absl::make_unique<WorkerGrpcDataServer>(config);
180   return Status::OK();
181 }
182 
183 }  // namespace data
184 }  // namespace tensorflow
185