• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h"
17 
18 #include "grpcpp/support/byte_buffer.h"
19 #include "grpcpp/support/slice.h"
20 #include "absl/flags/flag.h"
21 #include "tensorflow/core/common_runtime/dma_helper.h"
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/framework/tensor.pb.h"
24 #include "tensorflow/core/framework/tensor_reference.h"
25 #include "tensorflow/core/framework/tensor_shape.pb.h"
26 #include "tensorflow/core/lib/gtl/inlined_vector.h"
27 #include "tensorflow/core/lib/io/proto_encode_helper.h"
28 #include "tensorflow/core/platform/env.h"
29 #include "tensorflow/core/protobuf/worker.pb.h"
30 
31 // (Omitted internal-only flag)
32 
33 namespace tensorflow {
34 namespace grpc {
35 
EncodeRecvTensorResponseToByteBuffer(const RecvTensorResponse & proto,::grpc::ByteBuffer * result)36 void EncodeRecvTensorResponseToByteBuffer(const RecvTensorResponse& proto,
37                                           ::grpc::ByteBuffer* result) {
38   ::grpc::Slice slice(proto.ByteSizeLong());
39   proto.SerializeWithCachedSizesToArray(
40       const_cast<uint8*>(reinterpret_cast<const uint8*>(slice.begin())));
41   ::grpc::ByteBuffer tmp(&slice, 1);
42   result->Swap(&tmp);
43 }
44 
45 // We generate a RecvTensorResponse protocol buffer encoding into "*result",
46 // but where possible, we share the underlying Tensor buffer for "val", to
47 // avoid an extra copy.
48 //
49 // We hand-encode the protocol buffer data in the following order, as follows:
50 //
51 // Let R be a RecvTensorResponse object we want to encode, logically
52 // constructed by filling in data from "is_dead" and "val" and filling
53 // in a few other fields as well.
54 //
55 // (Letters here are used in the code to refer back to which part of the
56 //  encoding the code is generating).
57 //
58 // A:   <protocol buffer encoding of fields except R.tensor()>
59 // B1:  <tag encoding for RecvTensorResponse::tensor>
60 // B2:  <varint32 length of R.tensor() sub message>
61 // C:   <protocol buffer encoding of R.tensor() except for
62 //          R.tensor().tensor_content()>
63 // D1:  <tag encoding for TensorProto::tensor_content>
64 // D2:  <varint32 length of R.tensor().tensor_content() data>
65 // E:   <actual data for val's representation>
66 //
67 // If the tensor data is up to "kLargeTensorBytes", then A
68 // through E will all be encoded into "*result" in a single grpc::Slice.
69 //
70 // If the tensor data is larger than "kLargeTensorBytes", then A through
71 // D2 will be encoded in one grpc::Slice, and E will be encoded in a second
72 // grpc::Slice that points to the backing store for the tensor data, to avoid
73 // copying the tensor data (and the grpc::Slice setup will be arranged so as
74 // to dereference the underlying tensor data buffer when it is no longer
75 // needed in the "*result" ByteBuffer).
VarLengthEncodingSize(uint32 tag,size_t bytes)76 static int VarLengthEncodingSize(uint32 tag, size_t bytes) {
77   return core::VarintLength(tag << 3) + core::VarintLength(bytes) + bytes;
78 }
79 
80 // Returns an upper bound in bytes of the protocol buffer encoding of
81 // the "skeleton" of "val" (all the data needed for dtype and the shape,
82 // but not the actual contents of "val").
SkeletonEncodingSizeUpperBound(const Tensor & val)83 static int SkeletonEncodingSizeUpperBound(const Tensor& val) {
84   static const int kVarintMax64 = 10;  // Max length of varint64 encoding
85   const int ndims = val.shape().dims();
86   return (2 * kVarintMax64) +           // dtype
87          (ndims * (4 * kVarintMax64));  // Shape: 4 varints per dim
88 }
89 
90 // Encode the skeleton for "val" (the encoded TensorProto contents
91 // (dtype and shape, but not the actual data) into "*e".  The backing
92 // store for "*e" must be of appropriate size to hold this encoding.
EncodeSkeleton(const Tensor & val,io::ProtoEncodeHelper * e)93 static void EncodeSkeleton(const Tensor& val, io::ProtoEncodeHelper* e) {
94   // Encode val.dtype()
95   e->WriteUint64(TensorProto::kDtypeFieldNumber, val.dtype());
96 
97   // Compute length of val.shape() proto encoding
98   const int ndims = val.shape().dims();
99   int tensor_shape_bytes = 0;
100   for (int d = 0; d < ndims; d++) {
101     int64 dim_size = val.shape().dim_size(d);
102     tensor_shape_bytes +=
103         2 +  // TensorShapeProto dim tag + varintlength of submessage
104         1 +  // TensorShapeProto_Dim::kSizeFieldNumber
105         core::VarintLength(dim_size);
106   }
107 
108   if (tensor_shape_bytes > 0) {
109     e->WriteVarlengthBeginning(TensorProto::kTensorShapeFieldNumber,
110                                tensor_shape_bytes);
111     // Encode val.shape()
112     for (int d = 0; d < ndims; d++) {
113       int64 dim_size = val.shape().dim_size(d);
114       int64 dim_varlen = 1 +  // TensorShapeProto_Dim::kSizeFieldNumber
115                          core::VarintLength(dim_size);
116       e->WriteVarlengthBeginning(TensorShapeProto::kDimFieldNumber, dim_varlen);
117       e->WriteUint64(TensorShapeProto_Dim::kSizeFieldNumber, dim_size);
118     }
119   }
120 
121 #ifndef NDEBUG
122   {
123     // Debug-mode only check to make sure the encoding above is
124     // identical to the auto-generated protocol buffer encoding.
125     TensorProto skeleton;
126     skeleton.set_dtype(val.dtype());
127     val.shape().AsProto(skeleton.mutable_tensor_shape());
128     string tensor_except_contents;  // tensor() field except contents
129     skeleton.AppendToString(&tensor_except_contents);
130     TensorProto skeleton2;
131     skeleton2.ParseFromString(string(e->data(), e->size()));
132     string out;
133     skeleton.AppendToString(&out);
134     DCHECK_EQ(tensor_except_contents, out) << skeleton.DebugString() << " vs\n"
135                                            << skeleton2.DebugString();
136   }
137 #endif
138 }
139 
// Encodes a RecvTensorResponse for (is_dead, val, require_ack) directly
// into "*result" as gRPC slices, following the A..E layout described in the
// comment above.  For large memcpy-able tensors the tensor data is shared
// (zero-copy) via a second slice that ref-counts the tensor's buffer.
void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val, bool require_ack,
                              ::grpc::ByteBuffer* result) {
  // Tensors larger than this get the zero-copy two-slice treatment below.
  const int kLargeTensorBytes = 1024;
  // 2GB — the protobuf serialization limit for a single message.
  const int64 kProtoBufLimitBytes = 1LL << 31;

  if (val.TotalBytes() > kProtoBufLimitBytes) {
    size_t exceeded_bytes = val.TotalBytes() - kProtoBufLimitBytes;
    LOG(FATAL) << "Cannot encode a Tensor that exceeds the 2GB protobuf limit. "
                  "Exceeded bytes: "
               << exceeded_bytes;
  }

  RecvTensorResponse response;
  if (is_dead) {
    response.set_is_dead(is_dead);
  }
  response.set_require_ack(require_ack);
  response.set_send_start_micros(Env::Default()->NowMicros());
  if (!DataTypeCanUseMemcpy(val.dtype())) {
    // Straightforward but slow path for complicated kinds of tensor data
    // TODO(jeff,sanjay): If this becomes an issue, we could
    // go directly from val -> ByteBuffer, with some effort.
    val.AsProtoTensorContent(response.mutable_tensor());

    // Encode full protocol buffer to a ByteBuffer
    EncodeRecvTensorResponseToByteBuffer(response, result);
  } else {
    // skeleton is the encoded TensorProto contents (dtype and shape), but
    // not the actual data
    gtl::InlinedVector<char, 128> skeleton(SkeletonEncodingSizeUpperBound(val));
    io::ProtoEncodeHelper e_skeleton(skeleton.data(), skeleton.size());
    EncodeSkeleton(val, &e_skeleton);

    StringPiece tdata = val.tensor_data();
    // Size of the full TensorProto submessage: skeleton (dtype + shape)
    // plus the tagged, length-delimited tensor_content field.
    uint32 overall_tensor_proto_bytesize =
        (e_skeleton.size() +
         VarLengthEncodingSize(TensorProto::kTensorContentFieldNumber,
                               tdata.size()));
    string header;  // All of RecvTensorResponse except the tensor() field
    response.AppendToString(&header);

    // Total serialized size of the whole RecvTensorResponse; checked
    // against the actual slice sizes below.
    size_t expected_size =
        (header.size() +
         VarLengthEncodingSize(RecvTensorResponse::kTensorFieldNumber,
                               overall_tensor_proto_bytesize));
    // If "share_tensor_slice_memory == false", we copy the tensor data to
    // the end of the buffer we are preparing that holds the rest of the
    // RecvTensorResponse protocol buffer.
    //
    // If "share_tensor_slice_memory == true", we arrange to share the
    // backing store of the data by creating a slice that also points to the
    // backing store, with appropriate reference counts to keep the
    // backing store alive as needed.
    //
    // We enable this behavior if the tensor is large.
    bool share_tensor_slice_memory = (tdata.size() > kLargeTensorBytes);

    // (Omitted internal-only conditional)

    // Bytes needed for everything except the raw tensor data (A through D2).
    size_t encoder_size = expected_size - tdata.size();

    // Encode all but the actual "tdata", but including the tag and
    // varlength header for the "tdata"
    gtl::InlinedVector<char, 1024> space(encoder_size);
    io::ProtoEncodeHelper e(space.data(), space.size());
    // (A)
    e.WriteRawBytes(header);

    // (B1) & (B2)
    e.WriteVarlengthBeginning(RecvTensorResponse::kTensorFieldNumber,
                              overall_tensor_proto_bytesize);
    // (C)
    e.WriteRawBytes(StringPiece(e_skeleton.data(), e_skeleton.size()));
    // (D1) & (D2)
    e.WriteVarlengthBeginning(TensorProto::kTensorContentFieldNumber,
                              tdata.size());

    // All but the tensor backing store are serialized now

    // Now allocate memory and put into the ByteBuffer
    ::grpc::Slice slices[2];
    int num_slices = 0;
    {
      // Slice 0 holds A..D2, plus E when not sharing the backing store.
      size_t slice_len =
          e.size() + (share_tensor_slice_memory ? 0 : tdata.size());
      slices[0] = ::grpc::Slice(slice_len);
      memcpy(const_cast<uint8_t*>(slices[0].begin()), e.data(), e.size());
      if (!share_tensor_slice_memory) {
        // (E)
        memcpy(const_cast<uint8_t*>(slices[0].begin()) + e.size(), tdata.data(),
               tdata.size());
      }
      num_slices += 1;
    }

    if (share_tensor_slice_memory) {
      // (E) Encode tensor data, but by sharing backing store
      // Ref() keeps the tensor buffer alive while gRPC holds the slice;
      // the slice's destroy callback below drops that reference.
      const TensorBuffer* buf = DMAHelper::buffer(&val);
      buf->Ref();
      slices[1] = ::grpc::Slice(
          const_cast<void*>(static_cast<const void*>(tdata.data())),
          tdata.size(),
          [](void* backing) { static_cast<TensorBuffer*>(backing)->Unref(); },
          const_cast<TensorBuffer*>(buf));
      num_slices += 1;
    }
    // Sanity check: the slices must together contain exactly the number of
    // bytes the size computation above predicted.
    size_t total_bytes = 0;
    for (int i = 0; i < num_slices; i++) {
      total_bytes += slices[i].size();
    }
    CHECK_EQ(total_bytes, expected_size);

    ::grpc::ByteBuffer tmp(&slices[0], num_slices);
    result->Swap(&tmp);
  }
}
256 
257 }  // namespace grpc
258 }  // namespace tensorflow
259