1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_
18
19 #include <memory>
20
21 #include "grpcpp/grpcpp.h"
22 #include "grpcpp/impl/codegen/proto_utils.h"
23 #include "grpcpp/support/byte_buffer.h"
24 #include "tensorflow/core/distributed_runtime/tensor_coding.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/lib/strings/stringprintf.h"
27 #include "tensorflow/core/platform/mutex.h"
28 #include "tensorflow/core/platform/protobuf.h"
29
30 namespace tensorflow {
31
32 // Given the total number of RPC retries attempted, return a randomized
33 // amount of time to delay before retrying the request.
34 //
35 // The average computed backoff increases with the number of RPCs attempted.
36 // See implementation for details on the calculations.
37 int64 ComputeBackoffMicroseconds(int current_retry_attempt,
38 int64 min_delay = 1000,
39 int64 max_delay = 10000000);
40
41 // Thin wrapper around ::grpc::ProtoBufferReader to give TensorResponse an
42 // efficient byte reader from which to decode a RecvTensorResponse.
43 class GrpcByteSource : public TensorResponse::Source {
44 public:
GrpcByteSource(::grpc::ByteBuffer * buffer)45 explicit GrpcByteSource(::grpc::ByteBuffer* buffer) : buffer_(buffer) {}
~GrpcByteSource()46 ~GrpcByteSource() override { DeleteStream(); }
47
48 typedef ::grpc::ProtoBufferReader Reader;
49
contents()50 protobuf::io::ZeroCopyInputStream* contents() override {
51 DeleteStream();
52 stream_ = new (&space_) Reader(buffer_);
53 return stream_;
54 }
55
56 private:
DeleteStream()57 void DeleteStream() {
58 if (stream_) {
59 stream_->~Reader();
60 }
61 }
62
63 ::grpc::ByteBuffer* buffer_; // Not owned
64 Reader* stream_ = nullptr; // Points into space_ if non-nullptr
65 char space_[sizeof(Reader)];
66 };
67
68 constexpr char kStreamRemovedMessage[] = "Stream removed";
69
70 // Identify if the given grpc::Status corresponds to an HTTP stream removed
71 // error (see chttp2_transport.cc).
72 //
73 // When auto-reconnecting to a remote TensorFlow worker after it restarts, gRPC
74 // can return an UNKNOWN error code with a "Stream removed" error message.
75 // This should not be treated as an unrecoverable error.
76 //
77 // N.B. This is dependent on the error message from grpc remaining consistent.
IsStreamRemovedError(const::grpc::Status & s)78 inline bool IsStreamRemovedError(const ::grpc::Status& s) {
79 return !s.ok() && s.error_code() == ::grpc::StatusCode::UNKNOWN &&
80 s.error_message() == kStreamRemovedMessage;
81 }
82
FromGrpcStatus(const::grpc::Status & s)83 inline Status FromGrpcStatus(const ::grpc::Status& s) {
84 if (s.ok()) {
85 return Status::OK();
86 } else {
87 // Convert "UNKNOWN" stream removed errors into unavailable, to allow
88 // for retry upstream.
89 if (IsStreamRemovedError(s)) {
90 return Status(tensorflow::error::UNAVAILABLE, s.error_message());
91 }
92 return Status(static_cast<tensorflow::error::Code>(s.error_code()),
93 s.error_message());
94 }
95 }
96
ToGrpcStatus(const::tensorflow::Status & s)97 inline ::grpc::Status ToGrpcStatus(const ::tensorflow::Status& s) {
98 if (s.ok()) {
99 return ::grpc::Status::OK;
100 } else {
101 if (s.error_message().size() > 3072 /* 3k bytes */) {
102 // TODO(b/62947679): Remove truncation once the gRPC issue is resolved.
103 string scratch =
104 strings::Printf("%.3072s ... [truncated]", s.error_message().c_str());
105 LOG(ERROR) << "Truncated error message: " << s;
106 return ::grpc::Status(static_cast<::grpc::StatusCode>(s.code()), scratch);
107 }
108 return ::grpc::Status(static_cast<::grpc::StatusCode>(s.code()),
109 s.error_message());
110 }
111 }
112
113 typedef std::shared_ptr<::grpc::Channel> SharedGrpcChannelPtr;
114
GrpcIdKey()115 inline string GrpcIdKey() { return "tf-rpc"; }
116
117 // Serialize src and store in *dst.
118 ::grpc::Status GrpcMaybeUnparseProto(const protobuf::Message& src,
119 ::grpc::ByteBuffer* dst);
120
121 // Parse contents of src and initialize *dst with them.
122 bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, protobuf::Message* dst);
123
124 // Specialization for TensorResponse
125 bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, TensorResponse* dst);
126
127 // Copy string src to grpc buffer *dst.
128 ::grpc::Status GrpcMaybeUnparseProto(const string& src,
129 ::grpc::ByteBuffer* dst);
130
131 // Copy grpc buffer src to string *dst.
132 bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, string* dst);
133
134 // Copy grpc buffer src to tstring *dst.
135 bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, tstring* dst);
136
137 } // namespace tensorflow
138
139 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_
140