• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h"
17 #include "grpcpp/support/byte_buffer.h"
18 #include "grpcpp/support/slice.h"
19 #include "tensorflow/core/common_runtime/dma_helper.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/framework/tensor.pb.h"
22 #include "tensorflow/core/framework/tensor_reference.h"
23 #include "tensorflow/core/framework/tensor_shape.pb.h"
24 #include "tensorflow/core/lib/gtl/inlined_vector.h"
25 #include "tensorflow/core/lib/io/proto_encode_helper.h"
26 #include "tensorflow/core/platform/env.h"
27 #include "tensorflow/core/protobuf/worker.pb.h"
28 
29 // (Omitted internal-only flag)
30 
31 namespace tensorflow {
32 namespace grpc {
33 
EncodeRecvTensorResponseToByteBuffer(const RecvTensorResponse & proto,::grpc::ByteBuffer * result)34 void EncodeRecvTensorResponseToByteBuffer(const RecvTensorResponse& proto,
35                                           ::grpc::ByteBuffer* result) {
36   ::grpc::Slice slice(proto.ByteSizeLong());
37   proto.SerializeWithCachedSizesToArray(
38       const_cast<uint8*>(reinterpret_cast<const uint8*>(slice.begin())));
39   ::grpc::ByteBuffer tmp(&slice, 1);
40   result->Swap(&tmp);
41 }
42 
43 // We generate a RecvTensorResponse protocol buffer encoding into "*result",
44 // but where possible, we share the underlying Tensor buffer for "val", to
45 // avoid an extra copy.
46 //
47 // We hand-encode the protocol buffer data in the following order:
48 //
49 // Let R be a RecvTensorResponse object we want to encode, logically
50 // constructed by filling in data from "is_dead" and "val" and filling
51 // in a few other fields as well.
52 //
53 // (Letters here are used in the code to refer back to which part of the
54 //  encoding the code is generating).
55 //
56 // A:   <protocol buffer encoding of fields except R.tensor()>
57 // B1:  <tag encoding for RecvTensorResponse::tensor>
58 // B2:  <varint32 length of R.tensor() sub message>
59 // C:   <protocol buffer encoding of R.tensor() except for
60 //          R.tensor().tensor_content()>
61 // D1:  <tag encoding for TensorProto::tensor_content>
62 // D2:  <varint32 length of R.tensor().tensor_content() data>
63 // E:   <actual data for val's representation>
64 //
65 // If the tensor data is up to "kLargeTensorBytes", then A
66 // through E will all be encoded into "*result" in a single grpc::Slice.
67 //
68 // If the tensor data is larger than "kLargeTensorBytes", then A through
69 // D2 will be encoded in one grpc::Slice, and E will be encoded in a second
70 // grpc::Slice that points to the backing store for the tensor data, to avoid
71 // copying the tensor data (and the grpc::Slice setup will be arranged so as
72 // to dereference the underlying tensor data buffer when it is no longer
73 // needed in the "*result" ByteBuffer).
VarLengthEncodingSize(uint32 tag,size_t bytes)74 static int VarLengthEncodingSize(uint32 tag, size_t bytes) {
75   return core::VarintLength(tag << 3) + core::VarintLength(bytes) + bytes;
76 }
77 
78 // Returns an upper bound in bytes of the protocol buffer encoding of
79 // the "skeleton" of "val" (all the data needed for dtype and the shape,
80 // but not the actual contents of "val").
SkeletonEncodingSizeUpperBound(const Tensor & val)81 static int SkeletonEncodingSizeUpperBound(const Tensor& val) {
82   static const int kVarintMax64 = 10;  // Max length of varint64 encoding
83   const int ndims = val.shape().dims();
84   return (2 * kVarintMax64) +           // dtype
85          (ndims * (4 * kVarintMax64));  // Shape: 4 varints per dim
86 }
87 
88 // Encode the skeleton for "val" (the encoded TensorProto contents
89 // (dtype and shape, but not the actual data) into "*e".  The backing
90 // store for "*e" must be of appropriate size to hold this encoding.
EncodeSkeleton(const Tensor & val,io::ProtoEncodeHelper * e)91 static void EncodeSkeleton(const Tensor& val, io::ProtoEncodeHelper* e) {
92   // Encode val.dtype()
93   e->WriteUint64(TensorProto::kDtypeFieldNumber, val.dtype());
94 
95   // Compute length of val.shape() proto encoding
96   const int ndims = val.shape().dims();
97   int tensor_shape_bytes = 0;
98   for (int d = 0; d < ndims; d++) {
99     int64 dim_size = val.shape().dim_size(d);
100     tensor_shape_bytes +=
101         2 +  // TensorShapeProto dim tag + varintlength of submessage
102         1 +  // TensorShapeProto_Dim::kSizeFieldNumber
103         core::VarintLength(dim_size);
104   }
105 
106   if (tensor_shape_bytes > 0) {
107     e->WriteVarlengthBeginning(TensorProto::kTensorShapeFieldNumber,
108                                tensor_shape_bytes);
109     // Encode val.shape()
110     for (int d = 0; d < ndims; d++) {
111       int64 dim_size = val.shape().dim_size(d);
112       int64 dim_varlen = 1 +  // TensorShapeProto_Dim::kSizeFieldNumber
113                          core::VarintLength(dim_size);
114       e->WriteVarlengthBeginning(TensorShapeProto::kDimFieldNumber, dim_varlen);
115       e->WriteUint64(TensorShapeProto_Dim::kSizeFieldNumber, dim_size);
116     }
117   }
118 
119 #ifndef NDEBUG
120   {
121     // Debug-mode only check to make sure the encoding above is
122     // identical to the auto-generated protocol buffer encoding.
123     TensorProto skeleton;
124     skeleton.set_dtype(val.dtype());
125     val.shape().AsProto(skeleton.mutable_tensor_shape());
126     string tensor_except_contents;  // tensor() field except contents
127     skeleton.AppendToString(&tensor_except_contents);
128     TensorProto skeleton2;
129     skeleton2.ParseFromString(string(e->data(), e->size()));
130     string out;
131     skeleton.AppendToString(&out);
132     DCHECK_EQ(tensor_except_contents, out) << skeleton.DebugString() << " vs\n"
133                                            << skeleton2.DebugString();
134   }
135 #endif
136 }
137 
// Encodes (is_dead, val) as a RecvTensorResponse into "*result" using the
// hand-rolled A..E wire layout documented above.  Tensors whose dtype
// cannot be memcpy'd fall back to ordinary proto serialization.
void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val,
                              ::grpc::ByteBuffer* result) {
  // Tensors with more content bytes than this share their backing store
  // with the ByteBuffer instead of being copied into it.
  const int kLargeTensorBytes = 1024;
  RecvTensorResponse response;
  if (is_dead) {
    response.set_is_dead(is_dead);
  }
  response.set_send_start_micros(Env::Default()->NowMicros());
  if (!DataTypeCanUseMemcpy(val.dtype())) {
    // Straightforward but slow path for complicated kinds of tensor data
    // TODO(jeff,sanjay): If this becomes an issue, we could
    // go directly from val -> ByteBuffer, with some effort.
    val.AsProtoTensorContent(response.mutable_tensor());

    // Encode full protocol buffer to a ByteBuffer
    EncodeRecvTensorResponseToByteBuffer(response, result);
  } else {
    // skeleton is the encoded TensorProto contents (dtype and shape), but
    // not the actual data
    gtl::InlinedVector<char, 128> skeleton(SkeletonEncodingSizeUpperBound(val));
    io::ProtoEncodeHelper e_skeleton(skeleton.data(), skeleton.size());
    EncodeSkeleton(val, &e_skeleton);

    StringPiece tdata = val.tensor_data();
    // Size of the TensorProto submessage: the skeleton (C) plus the
    // tensor_content field (D1, D2, E).
    uint32 overall_tensor_proto_bytesize =
        (e_skeleton.size() +
         VarLengthEncodingSize(TensorProto::kTensorContentFieldNumber,
                               tdata.size()));
    string header;  // All of RecvTensorResponse except the tensor() field
    response.AppendToString(&header);

    // Exact final size of the encoding: header (A) plus the tensor field's
    // tag/length (B1, B2) plus the submessage itself; verified by the
    // CHECK_EQ at the bottom.
    size_t expected_size =
        (header.size() +
         VarLengthEncodingSize(RecvTensorResponse::kTensorFieldNumber,
                               overall_tensor_proto_bytesize));
    // If "share_tensor_slice_memory == false", we copy the tensor data to
    // the end of the buffer we are preparing that holds the rest of the
    // RecvTensorResponse protocol buffer.
    //
    // If "share_tensor_slice_memory == true", we arrange to share the
    // backing store of the data by creating a slice that also points to the
    // backing store, with appropriate reference counts to keep the
    // backing store alive as needed.
    //
    // We enable this behavior if the tensor is large.
    bool share_tensor_slice_memory = (tdata.size() > kLargeTensorBytes);

    // (Omitted internal-only conditional)

    // Bytes needed for everything except the raw tensor data (A..D2).
    size_t encoder_size = expected_size - tdata.size();

    // Encode all but the actual "tdata", but including the tag and
    // varlength header for the "tdata"
    gtl::InlinedVector<char, 1024> space(encoder_size);
    io::ProtoEncodeHelper e(space.data(), space.size());
    // (A)
    e.WriteRawBytes(header);

    // (B1) & (B2)
    e.WriteVarlengthBeginning(RecvTensorResponse::kTensorFieldNumber,
                              overall_tensor_proto_bytesize);
    // (C)
    e.WriteRawBytes(StringPiece(e_skeleton.data(), e_skeleton.size()));
    // (D1) & (D2)
    e.WriteVarlengthBeginning(TensorProto::kTensorContentFieldNumber,
                              tdata.size());

    // All but the tensor backing store are serialized now

    // Now allocate memory and put into the ByteBuffer
    ::grpc::Slice slices[2];
    int num_slices = 0;
    {
      // Slice 0 holds A..D2, plus E when the data is small enough to copy.
      size_t slice_len =
          e.size() + (share_tensor_slice_memory ? 0 : tdata.size());
      slices[0] = ::grpc::Slice(slice_len);
      // The slice buffer was just allocated by us, so casting away const
      // to fill it is safe.
      memcpy(const_cast<uint8_t*>(slices[0].begin()), e.data(), e.size());
      if (!share_tensor_slice_memory) {
        // (E)
        memcpy(const_cast<uint8_t*>(slices[0].begin()) + e.size(), tdata.data(),
               tdata.size());
      }
      num_slices += 1;
    }

    if (share_tensor_slice_memory) {
      // (E) Encode tensor data, but by sharing backing store
      // Take a reference on the tensor's buffer; the slice's cleanup
      // callback drops it once gRPC no longer needs the bytes.
      const TensorBuffer* buf = DMAHelper::buffer(&val);
      buf->Ref();
      slices[1] = ::grpc::Slice(
          const_cast<void*>(static_cast<const void*>(tdata.data())),
          tdata.size(),
          [](void* backing) { static_cast<TensorBuffer*>(backing)->Unref(); },
          const_cast<TensorBuffer*>(buf));
      num_slices += 1;
    }
    // Sanity check: the hand encoding must total exactly the size
    // computed up front.
    size_t total_bytes = 0;
    for (int i = 0; i < num_slices; i++) {
      total_bytes += slices[i].size();
    }
    CHECK_EQ(total_bytes, expected_size);

    ::grpc::ByteBuffer tmp(&slices[0], num_slices);
    result->Swap(&tmp);
  }
}
244 
245 }  // namespace grpc
246 }  // namespace tensorflow
247