1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/data/service/server_lib.h"
17
18 #include "tensorflow/core/data/service/credentials_factory.h"
19 #include "tensorflow/core/data/service/grpc_dispatcher_impl.h"
20 #include "tensorflow/core/data/service/grpc_util.h"
21 #include "tensorflow/core/data/service/grpc_worker_impl.h"
22 #include "tensorflow/core/platform/errors.h"
23
24 namespace tensorflow {
25 namespace data {
26
namespace {
// Token inside a configured address that gets substituted with the port a
// server actually ends up using.
constexpr char kPortPlaceholder[] = "%port%";
}  // namespace
30
GrpcDataServerBase(int port,const std::string & protocol,const std::string server_type)31 GrpcDataServerBase::GrpcDataServerBase(int port, const std::string& protocol,
32 const std::string server_type)
33 : requested_port_(port),
34 protocol_(protocol),
35 server_type_(server_type),
36 bound_port_(port) {}
37
Start()38 Status GrpcDataServerBase::Start() {
39 if (stopped_) {
40 return errors::FailedPrecondition(
41 "Server cannot be started after it has been stopped.");
42 }
43 if (started_) {
44 return Status::OK();
45 }
46 ::grpc::ServerBuilder builder;
47 std::shared_ptr<::grpc::ServerCredentials> credentials;
48 TF_RETURN_IF_ERROR(
49 CredentialsFactory::CreateServerCredentials(protocol_, &credentials));
50 builder.AddListeningPort(strings::StrCat("0.0.0.0:", requested_port_),
51 credentials, &bound_port_);
52 builder.SetMaxReceiveMessageSize(-1);
53
54 AddDataServiceToBuilder(builder);
55 AddProfilerServiceToBuilder(builder);
56 server_ = builder.BuildAndStart();
57 if (!server_) {
58 return errors::Internal("Could not start gRPC server");
59 }
60
61 TF_RETURN_IF_ERROR(StartServiceInternal());
62
63 started_ = true;
64 LOG(INFO) << "Started tf.data " << server_type_
65 << " running at 0.0.0.0:" << BoundPort();
66 return Status::OK();
67 }
68
Stop()69 void GrpcDataServerBase::Stop() {
70 if (stopped_) {
71 return;
72 }
73 if (server_) {
74 StopServiceInternal();
75 server_->Shutdown();
76 LOG(INFO) << "Shut down " << server_type_ << " server running at port "
77 << BoundPort();
78 }
79 stopped_ = true;
80 }
81
Join()82 void GrpcDataServerBase::Join() { server_->Wait(); }
83
BoundPort()84 int GrpcDataServerBase::BoundPort() { return bound_port(); }
85
AddProfilerServiceToBuilder(::grpc::ServerBuilder & builder)86 void GrpcDataServerBase::AddProfilerServiceToBuilder(
87 ::grpc::ServerBuilder& builder) {
88 profiler_service_ = profiler::CreateProfilerService();
89 builder.RegisterService(profiler_service_.get());
90 }
91
DispatchGrpcDataServer(const experimental::DispatcherConfig & config)92 DispatchGrpcDataServer::DispatchGrpcDataServer(
93 const experimental::DispatcherConfig& config)
94 : GrpcDataServerBase(config.port(), config.protocol(), "DispatchServer"),
95 config_(config) {}
96
~DispatchGrpcDataServer()97 DispatchGrpcDataServer::~DispatchGrpcDataServer() { delete service_; }
98
AddDataServiceToBuilder(::grpc::ServerBuilder & builder)99 void DispatchGrpcDataServer::AddDataServiceToBuilder(
100 ::grpc::ServerBuilder& builder) {
101 service_ = absl::make_unique<GrpcDispatcherImpl>(config_, builder).release();
102 }
103
StartServiceInternal()104 Status DispatchGrpcDataServer::StartServiceInternal() {
105 return service_->Start();
106 }
107
NumWorkers(int * num_workers)108 Status DispatchGrpcDataServer::NumWorkers(int* num_workers) {
109 GetWorkersRequest req;
110 GetWorkersResponse resp;
111 ::grpc::ServerContext ctx;
112 ::grpc::Status s = service_->GetWorkers(&ctx, &req, &resp);
113 if (!s.ok()) {
114 return grpc_util::WrapError("Failed to get workers", s);
115 }
116 *num_workers = resp.workers_size();
117 return Status::OK();
118 }
119
WorkerGrpcDataServer(const experimental::WorkerConfig & config)120 WorkerGrpcDataServer::WorkerGrpcDataServer(
121 const experimental::WorkerConfig& config)
122 : GrpcDataServerBase(config.port(), config.protocol(), "WorkerServer"),
123 config_(config) {}
124
~WorkerGrpcDataServer()125 WorkerGrpcDataServer::~WorkerGrpcDataServer() { delete service_; }
126
AddDataServiceToBuilder(::grpc::ServerBuilder & builder)127 void WorkerGrpcDataServer::AddDataServiceToBuilder(
128 ::grpc::ServerBuilder& builder) {
129 service_ = absl::make_unique<GrpcWorkerImpl>(config_, builder).release();
130 }
131
StartServiceInternal()132 Status WorkerGrpcDataServer::StartServiceInternal() {
133 std::string base_address = config_.worker_address();
134 if (base_address.empty()) {
135 base_address = absl::StrCat("localhost:", kPortPlaceholder);
136 }
137 std::string worker_address = str_util::StringReplace(
138 base_address, kPortPlaceholder, absl::StrCat(bound_port()),
139 /*replace_all=*/false);
140 std::string transfer_address = worker_address;
141 std::string transfer_protocol = config_.data_transfer_protocol();
142 if (!transfer_protocol.empty() && transfer_protocol != "grpc") {
143 TF_RETURN_IF_ERROR(DataTransferServer::Build(
144 transfer_protocol, service_->get_element_getter(), &transfer_server_));
145 TF_RETURN_IF_ERROR(transfer_server_->Start());
146 LOG(INFO) << "Data transfer server started at 0.0.0.0:"
147 << transfer_server_->get_port();
148 transfer_address = str_util::StringReplace(
149 config_.data_transfer_address(), kPortPlaceholder,
150 absl::StrCat(transfer_server_->get_port()),
151 /*replace_all=*/false);
152 }
153 TF_RETURN_IF_ERROR(service_->Start(worker_address, transfer_address));
154 return Status::OK();
155 }
156
StopServiceInternal()157 void WorkerGrpcDataServer::StopServiceInternal() { service_->Stop(); }
158
NumTasks(int * num_tasks)159 Status WorkerGrpcDataServer::NumTasks(int* num_tasks) {
160 GetWorkerTasksRequest req;
161 GetWorkerTasksResponse resp;
162 ::grpc::ServerContext ctx;
163 ::grpc::Status s = service_->GetWorkerTasks(&ctx, &req, &resp);
164 if (!s.ok()) {
165 return grpc_util::WrapError("Failed to get tasks", s);
166 }
167 *num_tasks = resp.tasks_size();
168 return Status::OK();
169 }
170
NewDispatchServer(const experimental::DispatcherConfig & config,std::unique_ptr<DispatchGrpcDataServer> & out_server)171 Status NewDispatchServer(const experimental::DispatcherConfig& config,
172 std::unique_ptr<DispatchGrpcDataServer>& out_server) {
173 out_server = absl::make_unique<DispatchGrpcDataServer>(config);
174 return Status::OK();
175 }
176
NewWorkerServer(const experimental::WorkerConfig & config,std::unique_ptr<WorkerGrpcDataServer> & out_server)177 Status NewWorkerServer(const experimental::WorkerConfig& config,
178 std::unique_ptr<WorkerGrpcDataServer>& out_server) {
179 out_server = absl::make_unique<WorkerGrpcDataServer>(config);
180 return Status::OK();
181 }
182
183 } // namespace data
184 } // namespace tensorflow
185