1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h"
17
18 #include <utility>
19
20 #include "absl/time/clock.h"
21 #include "absl/time/time.h"
22 #include "tensorflow/core/distributed_runtime/call_options.h"
23 #include "tensorflow/core/distributed_runtime/master_interface.h"
24 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h"
25 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/lib/strings/strcat.h"
29 #include "tensorflow/core/platform/env.h"
30 #include "tensorflow/core/platform/tracing.h"
31 #include "tensorflow/core/profiler/lib/traceme.h"
32 #include "tensorflow/core/protobuf/master.pb.h"
33
34 namespace tensorflow {
35
36 // GrpcRemoteMaster is an implementation of the MasterInterface
37 // that uses gRPC to talk to the Master service.
38 class GrpcRemoteMaster : public MasterInterface {
39 using MasterServiceStub = grpc::MasterService::Stub;
40
41 public:
GrpcRemoteMaster(const SharedGrpcChannelPtr & client_channel)42 explicit GrpcRemoteMaster(const SharedGrpcChannelPtr& client_channel)
43 : stub_(grpc::MasterService::NewStub(client_channel)) {}
44
~GrpcRemoteMaster()45 ~GrpcRemoteMaster() override {}
46
CreateSession(CallOptions * call_options,const CreateSessionRequest * request,CreateSessionResponse * response)47 Status CreateSession(CallOptions* call_options,
48 const CreateSessionRequest* request,
49 CreateSessionResponse* response) override {
50 return CallWithRetry(call_options, request, response,
51 &MasterServiceStub::CreateSession);
52 }
53
ExtendSession(CallOptions * call_options,const ExtendSessionRequest * request,ExtendSessionResponse * response)54 Status ExtendSession(CallOptions* call_options,
55 const ExtendSessionRequest* request,
56 ExtendSessionResponse* response) override {
57 return CallWithRetry(call_options, request, response,
58 &MasterServiceStub::ExtendSession);
59 }
60
PartialRunSetup(CallOptions * call_options,const PartialRunSetupRequest * request,PartialRunSetupResponse * response)61 Status PartialRunSetup(CallOptions* call_options,
62 const PartialRunSetupRequest* request,
63 PartialRunSetupResponse* response) override {
64 return CallWithRetry(call_options, request, response,
65 &MasterServiceStub::PartialRunSetup);
66 }
67
RunStep(CallOptions * call_options,RunStepRequestWrapper * request,MutableRunStepResponseWrapper * response)68 Status RunStep(CallOptions* call_options, RunStepRequestWrapper* request,
69 MutableRunStepResponseWrapper* response) override {
70 return CallWithRetry(call_options, &request->ToProto(),
71 get_proto_from_wrapper(response),
72 &MasterServiceStub::RunStep, "RunStep/Client");
73 }
74
CloseSession(CallOptions * call_options,const CloseSessionRequest * request,CloseSessionResponse * response)75 Status CloseSession(CallOptions* call_options,
76 const CloseSessionRequest* request,
77 CloseSessionResponse* response) override {
78 return CallWithRetry(call_options, request, response,
79 &MasterServiceStub::CloseSession);
80 }
81
ListDevices(CallOptions * call_options,const ListDevicesRequest * request,ListDevicesResponse * response)82 Status ListDevices(CallOptions* call_options,
83 const ListDevicesRequest* request,
84 ListDevicesResponse* response) override {
85 return CallWithRetry(call_options, request, response,
86 &MasterServiceStub::ListDevices);
87 }
88
Reset(CallOptions * call_options,const ResetRequest * request,ResetResponse * response)89 Status Reset(CallOptions* call_options, const ResetRequest* request,
90 ResetResponse* response) override {
91 return CallWithRetry(call_options, request, response,
92 &MasterServiceStub::Reset);
93 }
94
MakeCallable(CallOptions * call_options,const MakeCallableRequest * request,MakeCallableResponse * response)95 Status MakeCallable(CallOptions* call_options,
96 const MakeCallableRequest* request,
97 MakeCallableResponse* response) override {
98 return CallWithRetry(call_options, request, response,
99 &MasterServiceStub::MakeCallable);
100 }
RunCallable(CallOptions * call_options,const RunCallableRequest * request,RunCallableResponse * response)101 Status RunCallable(CallOptions* call_options,
102 const RunCallableRequest* request,
103 RunCallableResponse* response) override {
104 return CallWithRetry(call_options, request, response,
105 &MasterServiceStub::RunCallable);
106 }
ReleaseCallable(CallOptions * call_options,const ReleaseCallableRequest * request,ReleaseCallableResponse * response)107 Status ReleaseCallable(CallOptions* call_options,
108 const ReleaseCallableRequest* request,
109 ReleaseCallableResponse* response) override {
110 return CallWithRetry(call_options, request, response,
111 &MasterServiceStub::ReleaseCallable);
112 }
113
114 private:
115 // Start tracing, attaching a unique ID to both the trace and the RPC.
NewTraceRpc(StringPiece name,::grpc::ClientContext * ctx)116 profiler::TraceMe* NewTraceRpc(StringPiece name, ::grpc::ClientContext* ctx) {
117 string trace_id = strings::StrCat(tracing::GetUniqueArg());
118 ctx->AddMetadata(GrpcIdKey(), trace_id);
119 return new profiler::TraceMe(
120 [&] { return strings::StrCat(name, ":", trace_id); },
121 profiler::TraceMeLevel::kInfo);
122 }
123
124 template <typename Request, typename Response>
CallWithRetry(CallOptions * call_options,const Request * request,Response * response,::grpc::Status (MasterServiceStub::* pfunc)(::grpc::ClientContext *,const Request &,Response *),string trace_string={})125 Status CallWithRetry(CallOptions* call_options, const Request* request,
126 Response* response,
127 ::grpc::Status (MasterServiceStub::*pfunc)(
128 ::grpc::ClientContext*, const Request&, Response*),
129 string trace_string = {}) {
130 absl::Duration timeout = absl::Milliseconds(call_options->GetTimeout());
131 absl::Time expired_time = absl::FromUnixMicros(Env::Default()->NowMicros());
132 if (timeout > absl::ZeroDuration()) {
133 expired_time += timeout;
134 }
135 Status s;
136 for (int num_retries = 0;; ++num_retries) {
137 ::grpc::ClientContext ctx;
138 std::unique_ptr<profiler::TraceMe> trace;
139 if (!trace_string.empty()) {
140 trace.reset(NewTraceRpc(trace_string, &ctx));
141 }
142 ctx.set_fail_fast(false);
143 if (timeout > absl::ZeroDuration()) {
144 // We do not modify the timeout here to match legacy behavior. However,
145 // this could violate the contract of tensorflow::Session. If we retry
146 // an RPC just before the deadline is exceeded, we will still set the
147 // timeout to the original value. This leads to the overall timeout
148 // being double what was expected.
149 // TODO(b/117162170): investigate fixing this behavior for legacy and
150 // gRPC RPC layers.
151 ctx.set_deadline(absl::ToChronoTime(absl::Now() + timeout));
152 }
153 s = FromGrpcStatus((stub_.get()->*pfunc)(&ctx, *request, response));
154 if (!errors::IsUnavailable(s)) {
155 return s;
156 }
157 // TODO(b/117162170): we may want to make this configurable.
158 constexpr int kMaxRetries = 10;
159 LOG(WARNING) << "RPC failed with status = \"" << s
160 << "\" and grpc_error_string = \""
161 << ctx.debug_error_string() << "\", maybe retrying the RPC";
162 if (num_retries >= kMaxRetries) {
163 LOG(WARNING) << "Too many retries, returning last status: " << s;
164 return s;
165 }
166 absl::Time now = absl::FromUnixMicros(Env::Default()->NowMicros());
167 const absl::Time deadline_with_backoff =
168 now + absl::Microseconds(ComputeBackoffMicroseconds(num_retries));
169 // Wait for a short period of time before retrying the RPC. If our
170 // backoff would put us past the RPC deadline, we truncate it to ensure
171 // our RPC starts before the deadline.
172 const auto backoff_until = (timeout <= absl::ZeroDuration() ||
173 expired_time > deadline_with_backoff)
174 ? deadline_with_backoff
175 : expired_time;
176 Env::Default()->SleepForMicroseconds(
177 absl::ToInt64Microseconds(backoff_until - now));
178 now = absl::FromUnixMicros(Env::Default()->NowMicros());
179 if (now > expired_time && timeout > absl::ZeroDuration()) {
180 // If timeout_in_ms is set, exit the retry loop on timeout.
181 return errors::DeadlineExceeded(ctx.debug_error_string());
182 }
183 }
184 }
185
186 std::unique_ptr<MasterServiceStub> stub_;
187 };
188
NewGrpcMaster(const SharedGrpcChannelPtr & channel)189 MasterInterface* NewGrpcMaster(const SharedGrpcChannelPtr& channel) {
190 return new GrpcRemoteMaster(channel);
191 }
192
193 } // namespace tensorflow
194