/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
#include "tensorflow/core/lib/random/random.h"

namespace tensorflow {

namespace {

double GenerateUniformRandomNumber() {
  return random::New64() * (1.0 / std::numeric_limits<uint64>::max());
}

double GenerateUniformRandomNumberBetween(double a, double b) {
  if (a == b) return a;
  DCHECK_LT(a, b);
  return a + GenerateUniformRandomNumber() * (b - a);
}

}  // namespace

int64 ComputeBackoffMicroseconds(int current_retry_attempt, int64 min_delay,
                                 int64 max_delay) {
  DCHECK_GE(current_retry_attempt, 0);

  // With the constants below, this function computes:
  //
  //   (0.4 * min_delay) + (random[0.6, 1.0] * min_delay * 1.3^retries)
  //
  // Note that an extra truncation can occur; it is documented in the
  // comments below.
  constexpr double kBackoffBase = 1.3;
  constexpr double kBackoffRandMult = 0.4;
  // This first term does not vary with current_retry_attempt or a random
  // number. It exists to ensure that the final term is >= min_delay.
  const double first_term = kBackoffRandMult * min_delay;

  // This computes min_delay * 1.3^retries.
  double uncapped_second_term = min_delay;
  while (current_retry_attempt > 0 &&
         uncapped_second_term < max_delay - first_term) {
    current_retry_attempt--;
    uncapped_second_term *= kBackoffBase;
  }
  // Note that first_term + uncapped_second_term can exceed max_delay here
  // because of the final multiply by kBackoffBase. We fix that problem with
  // the min() below.
  double second_term = std::min(uncapped_second_term, max_delay - first_term);

  // This supplies the random jitter to ensure that retries don't cause a
  // thundering herd problem.
  second_term *=
      GenerateUniformRandomNumberBetween(1.0 - kBackoffRandMult, 1.0);

  return std::max(static_cast<int64>(first_term + second_term), min_delay);
}
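
// Worked example (illustrative only; the delay values below are assumed for
// the sake of the arithmetic and are not defaults defined in this file):
// with min_delay = 1000000 us, max_delay = 32000000 us, and
// current_retry_attempt = 3:
//
//   first_term           = 0.4 * 1000000   = 400000
//   uncapped_second_term = 1000000 * 1.3^3 = 2197000
//   second_term          = 2197000 * random[0.6, 1.0]
//                        = somewhere in [1318200, 2197000]
//   result               = first_term + second_term
//                        = somewhere in [1718200, 2597000]
//
// For large attempt counts, second_term is capped so that
// first_term + second_term never exceeds max_delay, and the final max()
// keeps the result >= min_delay.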

::grpc::Status GrpcMaybeUnparseProto(const protobuf::Message& src,
                                     grpc::ByteBuffer* dst) {
  bool own_buffer;
  return ::grpc::GenericSerialize<::grpc::ProtoBufferWriter,
                                  protobuf::Message>(src, dst, &own_buffer);
}

// GrpcMaybeUnparseProto from a string simply copies the string to the
// ByteBuffer.
::grpc::Status GrpcMaybeUnparseProto(const string& src, grpc::ByteBuffer* dst) {
  ::grpc::Slice s(src.data(), src.size());
  ::grpc::ByteBuffer buffer(&s, 1);
  dst->Swap(&buffer);
  return ::grpc::Status::OK;
}
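
// Example (an illustrative sketch only, not part of the production call
// path): serializing a raw string payload into a ByteBuffer.
//
//   grpc::ByteBuffer buf;
//   ::grpc::Status s = GrpcMaybeUnparseProto(string("payload"), &buf);
//   // On success, buf holds a single slice containing the bytes of
//   // "payload".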

bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, protobuf::Message* dst) {
  ::grpc::ProtoBufferReader reader(src);
  return dst->ParseFromZeroCopyStream(&reader);
}

// Overload of GrpcMaybeParseProto so we can decode a TensorResponse without
// extra copying. This overload is used by the RPCState class in
// grpc_state.h.
bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, TensorResponse* dst) {
  ::tensorflow::GrpcByteSource byte_source(src);
  auto s = dst->ParseFrom(&byte_source);
  return s.ok();
}
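
// Example (a sketch only; it assumes a TensorResponse that has already been
// configured for the target device/allocator as described in
// tensor_coding.h, and a ByteBuffer received from a RecvTensor call):
//
//   TensorResponse response;
//   // ... initialize `response` for the destination device ...
//   if (!GrpcMaybeParseProto(&byte_buffer, &response)) {
//     // Handle a malformed RecvTensor payload.
//   }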

// GrpcMaybeParseProto simply copies bytes into the string.
bool GrpcMaybeParseProto(grpc::ByteBuffer* src, string* dst) {
  dst->clear();
  dst->reserve(src->Length());
  std::vector<::grpc::Slice> slices;
  if (!src->Dump(&slices).ok()) {
    return false;
  }
  for (const ::grpc::Slice& s : slices) {
    dst->append(reinterpret_cast<const char*>(s.begin()), s.size());
  }
  return true;
}
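
// Example (illustrative sketch only): a string payload survives an
// unparse/parse round trip through a grpc::ByteBuffer.
//
//   grpc::ByteBuffer buf;
//   CHECK(GrpcMaybeUnparseProto(string("abc"), &buf).ok());
//   string parsed;
//   CHECK(GrpcMaybeParseProto(&buf, &parsed));  // parsed == "abc"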

// GrpcMaybeParseProto simply copies bytes into the tstring.
bool GrpcMaybeParseProto(grpc::ByteBuffer* src, tstring* dst) {
  dst->clear();
  dst->reserve(src->Length());
  std::vector<::grpc::Slice> slices;
  if (!src->Dump(&slices).ok()) {
    return false;
  }
  for (const ::grpc::Slice& s : slices) {
    dst->append(reinterpret_cast<const char*>(s.begin()), s.size());
  }
  return true;
}

}  // namespace tensorflow