1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // EncodeProto is a TensorFlow Op which serializes tensors into
17 // arbitrary protobufs.
18 //
19 // See the docstring in ../ops/encode_proto_op.cc for usage of the op.
20 //
21 // This implementation writes the serialized format using a handful of
22 // calls from the WireFormatLite API.
23
24 #include <memory>
25 #include <vector>
26
27 #include "third_party/eigen3/Eigen/Core"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/framework/tensor_types.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/protobuf.h"
33 #include "tensorflow/core/util/proto/descriptors.h"
34 #include "tensorflow/core/util/proto/proto_utils.h"
35
36 namespace tensorflow {
37 namespace {
38
39 using ::tensorflow::protobuf::Descriptor;
40 using ::tensorflow::protobuf::DescriptorPool;
41 using ::tensorflow::protobuf::FieldDescriptor;
42 using ::tensorflow::protobuf::internal::WireFormatLite;
43 using ::tensorflow::protobuf::io::CodedOutputStream;
44 using ::tensorflow::protobuf::io::StringOutputStream;
45
46 // Computes the total serialized size for a packed repeated field. For
47 // fixed-size types this can just multiply, but for variable-sized types it has
48 // to iterate through the values in the tensor.
49 template <WireFormatLite::FieldType FieldType, typename TensorT>
50 size_t TotalPackedSize(const Tensor& input, int message_index, int size);
51
52 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)53 size_t TotalPackedSize<WireFormatLite::TYPE_DOUBLE, double>(const Tensor& input,
54 int message_index,
55 int size) {
56 return size * WireFormatLite::kDoubleSize;
57 }
58
59 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)60 size_t TotalPackedSize<WireFormatLite::TYPE_FLOAT, double>(const Tensor& input,
61 int message_index,
62 int size) {
63 return size * WireFormatLite::kFloatSize;
64 }
65
66 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)67 size_t TotalPackedSize<WireFormatLite::TYPE_FLOAT, float>(const Tensor& input,
68 int message_index,
69 int size) {
70 return size * WireFormatLite::kFloatSize;
71 }
72
73 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)74 size_t TotalPackedSize<WireFormatLite::TYPE_INT64, int64>(const Tensor& input,
75 int message_index,
76 int size) {
77 size_t data_size = 0;
78 auto input_t = input.flat_inner_dims<int64>();
79 for (int64 i = 0; i < size; i++) {
80 data_size += WireFormatLite::Int64Size(
81 input_t(static_cast<int64>(message_index), i));
82 }
83 return data_size;
84 }
85
86 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)87 size_t TotalPackedSize<WireFormatLite::TYPE_UINT64, uint64>(const Tensor& input,
88 int message_index,
89 int size) {
90 size_t data_size = 0;
91 auto input_t = input.flat_inner_dims<uint64>();
92 for (int64 i = 0; i < size; i++) {
93 data_size += WireFormatLite::UInt64Size(
94 input_t(static_cast<int64>(message_index), i));
95 }
96 return data_size;
97 }
98
99 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)100 size_t TotalPackedSize<WireFormatLite::TYPE_INT32, int64>(const Tensor& input,
101 int message_index,
102 int size) {
103 size_t data_size = 0;
104 auto input_t = input.flat_inner_dims<int64>();
105 for (int64 i = 0; i < size; i++) {
106 data_size += WireFormatLite::Int32Size(
107 input_t(static_cast<int64>(message_index), i));
108 }
109 return data_size;
110 }
111
112 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)113 size_t TotalPackedSize<WireFormatLite::TYPE_INT32, int32>(const Tensor& input,
114 int message_index,
115 int size) {
116 size_t data_size = 0;
117 auto input_t = input.flat_inner_dims<int32>();
118 for (int64 i = 0; i < size; i++) {
119 data_size += WireFormatLite::Int32Size(
120 input_t(static_cast<int64>(message_index), i));
121 }
122 return data_size;
123 }
124
125 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)126 size_t TotalPackedSize<WireFormatLite::TYPE_FIXED64, uint64>(
127 const Tensor& input, int message_index, int size) {
128 return size * WireFormatLite::kFixed64Size;
129 }
130
131 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)132 size_t TotalPackedSize<WireFormatLite::TYPE_FIXED32, uint64>(
133 const Tensor& input, int message_index, int size) {
134 return size * WireFormatLite::kFixed32Size;
135 }
136
137 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)138 size_t TotalPackedSize<WireFormatLite::TYPE_FIXED32, uint32>(
139 const Tensor& input, int message_index, int size) {
140 return size * WireFormatLite::kFixed32Size;
141 }
142
143 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)144 size_t TotalPackedSize<WireFormatLite::TYPE_BOOL, bool>(const Tensor& input,
145 int message_index,
146 int size) {
147 return size * WireFormatLite::kBoolSize;
148 }
149
150 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)151 size_t TotalPackedSize<WireFormatLite::TYPE_UINT32, uint64>(const Tensor& input,
152 int message_index,
153 int size) {
154 size_t data_size = 0;
155 auto input_t = input.flat_inner_dims<uint64>();
156 for (int64 i = 0; i < size; i++) {
157 data_size += WireFormatLite::UInt32Size(
158 input_t(static_cast<int64>(message_index), i));
159 }
160 return data_size;
161 }
162
163 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)164 size_t TotalPackedSize<WireFormatLite::TYPE_UINT32, uint32>(const Tensor& input,
165 int message_index,
166 int size) {
167 size_t data_size = 0;
168 auto input_t = input.flat_inner_dims<uint32>();
169 for (int64 i = 0; i < size; i++) {
170 data_size += WireFormatLite::UInt32Size(
171 input_t(static_cast<int64>(message_index), i));
172 }
173 return data_size;
174 }
175
176 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)177 size_t TotalPackedSize<WireFormatLite::TYPE_ENUM, int32>(const Tensor& input,
178 int message_index,
179 int size) {
180 size_t data_size = 0;
181 auto input_t = input.flat_inner_dims<int32>();
182 for (int64 i = 0; i < size; i++) {
183 data_size +=
184 WireFormatLite::EnumSize(input_t(static_cast<int64>(message_index), i));
185 }
186 return data_size;
187 }
188
189 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)190 size_t TotalPackedSize<WireFormatLite::TYPE_SFIXED32, int32>(
191 const Tensor& input, int message_index, int size) {
192 return size * WireFormatLite::kSFixed32Size;
193 }
194
195 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)196 size_t TotalPackedSize<WireFormatLite::TYPE_SFIXED32, int64>(
197 const Tensor& input, int message_index, int size) {
198 return size * WireFormatLite::kSFixed32Size;
199 }
200
201 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)202 size_t TotalPackedSize<WireFormatLite::TYPE_SFIXED64, int64>(
203 const Tensor& input, int message_index, int size) {
204 return size * WireFormatLite::kSFixed64Size;
205 }
206
207 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)208 size_t TotalPackedSize<WireFormatLite::TYPE_SINT32, int32>(const Tensor& input,
209 int message_index,
210 int size) {
211 size_t data_size = 0;
212 auto input_t = input.flat_inner_dims<int32>();
213 for (int64 i = 0; i < size; i++) {
214 data_size += WireFormatLite::SInt32Size(
215 input_t(static_cast<int64>(message_index), i));
216 }
217 return data_size;
218 }
219
220 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)221 size_t TotalPackedSize<WireFormatLite::TYPE_SINT32, int64>(const Tensor& input,
222 int message_index,
223 int size) {
224 size_t data_size = 0;
225 auto input_t = input.flat_inner_dims<int64>();
226 for (int64 i = 0; i < size; i++) {
227 data_size += WireFormatLite::SInt32Size(
228 input_t(static_cast<int64>(message_index), i));
229 }
230 return data_size;
231 }
232
233 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)234 size_t TotalPackedSize<WireFormatLite::TYPE_SINT64, int64>(const Tensor& input,
235 int message_index,
236 int size) {
237 size_t data_size = 0;
238 auto input_t = input.flat_inner_dims<int64>();
239 for (int64 i = 0; i < size; i++) {
240 data_size += WireFormatLite::SInt64Size(
241 input_t(static_cast<int64>(message_index), i));
242 }
243 return data_size;
244 }
245
246 // Writes a possibly repeated primitive field. TensorFlow does not have unsigned
247 // types, so we decode them to signed and encode them back to unsigned.
248 template <typename TensorT, typename ProtoT,
249 WireFormatLite::FieldType FieldType,
250 void Writer(ProtoT, CodedOutputStream*)>
WriteField(const FieldDescriptor & field_desc,const Tensor & input,int message_index,int size,CodedOutputStream * output)251 Status WriteField(const FieldDescriptor& field_desc, const Tensor& input,
252 int message_index, int size, CodedOutputStream* output) {
253 auto wire_type = WireFormatLite::WireTypeForFieldType(
254 WireFormatLite::FieldType(field_desc.type()));
255
256 auto input_t = input.flat_inner_dims<TensorT>();
257 if (field_desc.options().packed()) {
258 // Write the tag for the packed field.
259 WireFormatLite::WriteTag(field_desc.number(),
260 WireFormatLite::WIRETYPE_LENGTH_DELIMITED, output);
261
262 // Write the total packed length.
263 size_t data_size =
264 TotalPackedSize<FieldType, TensorT>(input, message_index, size);
265 output->WriteVarint32(data_size);
266
267 // Write individual values.
268 for (int64 i = 0; i < size; i++) {
269 // Note implicit cast from signed to unsigned.
270 const ProtoT& value = input_t(static_cast<int64>(message_index), i);
271 Writer(value, output);
272 }
273 } else {
274 for (int64 i = 0; i < size; i++) {
275 WireFormatLite::WriteTag(field_desc.number(), wire_type, output);
276
277 // Note implicit cast from signed to unsigned.
278 const ProtoT& value = input_t(static_cast<int64>(message_index), i);
279 Writer(value, output);
280 }
281 }
282 return Status::OK();
283 }
284
285 // Writes a possibly repeated string, bytes, or message field.
286 template <typename T, void Writer(int, const T&, CodedOutputStream*)>
WriteVarLenField(const FieldDescriptor & field_desc,const Tensor & input,int message_index,int size,CodedOutputStream * output)287 Status WriteVarLenField(const FieldDescriptor& field_desc, const Tensor& input,
288 int message_index, int size,
289 CodedOutputStream* output) {
290 auto input_t = input.flat_inner_dims<T>();
291 for (int64 i = 0; i < size; i++) {
292 const T& value = input_t(static_cast<int64>(message_index), i);
293 // TODO(nix): there doesn't seem to be an inlined version of
294 // WireFormatLite::WriteString or its relatives, which might allow a
295 // small speedup.
296 Writer(field_desc.number(), value, output);
297 }
298 return Status::OK();
299 }
300
301 // Writes a group field. Groups are treated like submessages, but tag-delimited
302 // instead of length-delimited. WireFormatLite handles this differently so we
303 // code it ourselves.
WriteGroup(const FieldDescriptor & field_desc,const Tensor & input,int message_index,int size,CodedOutputStream * output)304 Status WriteGroup(const FieldDescriptor& field_desc, const Tensor& input,
305 int message_index, int size, CodedOutputStream* output) {
306 auto input_t = input.flat_inner_dims<string>();
307 for (int64 i = 0; i < size; i++) {
308 const string& value = input_t(static_cast<int64>(message_index), i);
309 WireFormatLite::WriteTag(field_desc.number(),
310 WireFormatLite::WIRETYPE_START_GROUP, output);
311 // Note the use of WriteRaw instead of WriteString to skip the length.
312 output->WriteRaw(value.data(), value.size());
313 WireFormatLite::WriteTag(field_desc.number(),
314 WireFormatLite::WIRETYPE_END_GROUP, output);
315 }
316 return Status::OK();
317 }
318
319 // Writes a (possibly repeated) field into an output stream. It is the caller's
320 // responsibility to ensure that the type of the input tensor is compatible with
321 // the type of the proto field descriptor, and that (message_index, size-1) is
322 // within bounds.
WriteField(const FieldDescriptor & field_desc,const Tensor & input,int message_index,int size,CodedOutputStream * output)323 Status WriteField(const FieldDescriptor& field_desc, const Tensor& input,
324 int message_index, int size, CodedOutputStream* output) {
325 DataType dtype = input.dtype();
326
327 switch (field_desc.type()) {
328 case WireFormatLite::TYPE_DOUBLE:
329 return WriteField<double, double, WireFormatLite::TYPE_DOUBLE,
330 WireFormatLite::WriteDoubleNoTag>(
331 field_desc, input, message_index, size, output);
332 case WireFormatLite::TYPE_FLOAT:
333 switch (dtype) {
334 case DataType::DT_FLOAT:
335 return WriteField<float, float, WireFormatLite::TYPE_FLOAT,
336 WireFormatLite::WriteFloatNoTag>(
337 field_desc, input, message_index, size, output);
338 case DataType::DT_DOUBLE:
339 return WriteField<double, float, WireFormatLite::TYPE_FLOAT,
340 WireFormatLite::WriteFloatNoTag>(
341 field_desc, input, message_index, size, output);
342 default:
343 return errors::DataLoss("Failed writing TYPE_FLOAT for ",
344 DataTypeString(dtype));
345 }
346 case WireFormatLite::TYPE_INT64:
347 return WriteField<int64, protobuf_int64, WireFormatLite::TYPE_INT64,
348 WireFormatLite::WriteInt64NoTag>(
349 field_desc, input, message_index, size, output);
350 case WireFormatLite::TYPE_UINT64:
351 return WriteField<uint64, protobuf_uint64, WireFormatLite::TYPE_UINT64,
352 WireFormatLite::WriteUInt64NoTag>(
353 field_desc, input, message_index, size, output);
354 case WireFormatLite::TYPE_INT32:
355 switch (dtype) {
356 case DataType::DT_INT64:
357 return WriteField<int64, int32, WireFormatLite::TYPE_INT32,
358 WireFormatLite::WriteInt32NoTag>(
359 field_desc, input, message_index, size, output);
360 case DataType::DT_INT32:
361 return WriteField<int32, int32, WireFormatLite::TYPE_INT32,
362 WireFormatLite::WriteInt32NoTag>(
363 field_desc, input, message_index, size, output);
364 default:
365 return errors::DataLoss("Failed writing TYPE_INT32 for ",
366 DataTypeString(dtype));
367 }
368 case WireFormatLite::TYPE_FIXED64:
369 return WriteField<uint64, protobuf_uint64, WireFormatLite::TYPE_FIXED64,
370 WireFormatLite::WriteFixed64NoTag>(
371 field_desc, input, message_index, size, output);
372 case WireFormatLite::TYPE_FIXED32:
373 switch (dtype) {
374 case DataType::DT_UINT64:
375 return WriteField<uint64, uint32, WireFormatLite::TYPE_FIXED32,
376 WireFormatLite::WriteFixed32NoTag>(
377 field_desc, input, message_index, size, output);
378 case DataType::DT_UINT32:
379 return WriteField<uint32, uint32, WireFormatLite::TYPE_FIXED32,
380 WireFormatLite::WriteFixed32NoTag>(
381 field_desc, input, message_index, size, output);
382 default:
383 return errors::DataLoss("Failed writing TYPE_FIXED32 for ",
384 DataTypeString(dtype));
385 }
386 case WireFormatLite::TYPE_BOOL:
387 return WriteField<bool, bool, WireFormatLite::TYPE_BOOL,
388 WireFormatLite::WriteBoolNoTag>(
389 field_desc, input, message_index, size, output);
390 case WireFormatLite::TYPE_STRING:
391 return WriteVarLenField<string, WireFormatLite::WriteString>(
392 field_desc, input, message_index, size, output);
393 case WireFormatLite::TYPE_GROUP:
394 return WriteGroup(field_desc, input, message_index, size, output);
395 case WireFormatLite::TYPE_MESSAGE:
396 return WriteVarLenField<string, WireFormatLite::WriteBytes>(
397 field_desc, input, message_index, size, output);
398 case WireFormatLite::TYPE_BYTES:
399 return WriteVarLenField<string, WireFormatLite::WriteBytes>(
400 field_desc, input, message_index, size, output);
401 case WireFormatLite::TYPE_UINT32:
402 switch (dtype) {
403 case DataType::DT_UINT64:
404 return WriteField<uint64, uint32, WireFormatLite::TYPE_UINT32,
405 WireFormatLite::WriteUInt32NoTag>(
406 field_desc, input, message_index, size, output);
407 case DataType::DT_UINT32:
408 return WriteField<uint32, uint32, WireFormatLite::TYPE_UINT32,
409 WireFormatLite::WriteUInt32NoTag>(
410 field_desc, input, message_index, size, output);
411 default:
412 return errors::DataLoss("Failed writing TYPE_UINT32 for ",
413 DataTypeString(dtype));
414 }
415 case WireFormatLite::TYPE_ENUM:
416 return WriteField<int32, int32, WireFormatLite::TYPE_ENUM,
417 WireFormatLite::WriteEnumNoTag>(
418 field_desc, input, message_index, size, output);
419 case WireFormatLite::TYPE_SFIXED32:
420 switch (dtype) {
421 case DataType::DT_INT64:
422 return WriteField<int64, int32, WireFormatLite::TYPE_SFIXED32,
423 WireFormatLite::WriteSFixed32NoTag>(
424 field_desc, input, message_index, size, output);
425 case DataType::DT_INT32:
426 return WriteField<int32, int32, WireFormatLite::TYPE_SFIXED32,
427 WireFormatLite::WriteSFixed32NoTag>(
428 field_desc, input, message_index, size, output);
429 default:
430 return errors::DataLoss("Failed writing TYPE_SFIXED32 for ",
431 DataTypeString(dtype));
432 }
433 case WireFormatLite::TYPE_SFIXED64:
434 return WriteField<int64, protobuf_int64, WireFormatLite::TYPE_SFIXED64,
435 WireFormatLite::WriteSFixed64NoTag>(
436 field_desc, input, message_index, size, output);
437 case WireFormatLite::TYPE_SINT32:
438 switch (dtype) {
439 case DataType::DT_INT64:
440 return WriteField<int64, int32, WireFormatLite::TYPE_SINT32,
441 WireFormatLite::WriteSInt32NoTag>(
442 field_desc, input, message_index, size, output);
443 case DataType::DT_INT32:
444 return WriteField<int32, int32, WireFormatLite::TYPE_SINT32,
445 WireFormatLite::WriteSInt32NoTag>(
446 field_desc, input, message_index, size, output);
447 default:
448 return errors::DataLoss("Failed writing TYPE_SINT32 for ",
449 DataTypeString(dtype));
450 }
451 case WireFormatLite::TYPE_SINT64:
452 return WriteField<int64, protobuf_int64, WireFormatLite::TYPE_SINT64,
453 WireFormatLite::WriteSInt64NoTag>(
454 field_desc, input, message_index, size, output);
455 // default: intentionally omitted in order to enable static checking.
456 }
457 }
458
459 class EncodeProtoOp : public OpKernel {
460 public:
EncodeProtoOp(OpKernelConstruction * context)461 explicit EncodeProtoOp(OpKernelConstruction* context) : OpKernel(context) {
462 string descriptor_source;
463 OP_REQUIRES_OK(context,
464 context->GetAttr("descriptor_source", &descriptor_source));
465 // We always get back a desc_pool, but we may not own it. If we own it,
466 // owned_desc_pool_ will be filled in.
467 DescriptorPool const* desc_pool;
468 OP_REQUIRES_OK(context, GetDescriptorPool(context->env(), descriptor_source,
469 &desc_pool, &owned_desc_pool_));
470
471 string message_type;
472 OP_REQUIRES_OK(context, context->GetAttr("message_type", &message_type));
473 const Descriptor* message_desc =
474 desc_pool->FindMessageTypeByName(message_type);
475 OP_REQUIRES(context, message_desc != nullptr,
476 errors::InvalidArgument("No descriptor found for message type ",
477 message_type));
478
479 OP_REQUIRES_OK(context, context->GetAttr("field_names", &field_names_));
480
481 // Gather the field descriptors for the given field_names.
482 field_descs_.resize(field_names_.size());
483 for (int i = 0; i < field_names_.size(); i++) {
484 const string& name = field_names_[i];
485 auto field_desc = message_desc->FindFieldByName(name);
486 OP_REQUIRES(context, field_desc != nullptr,
487 errors::InvalidArgument("Unknown field: ", name,
488 " in message type ", message_type));
489
490 field_descs_[i] = field_desc;
491 }
492
493 // Build a list of indices into field_descs sorted by increasing
494 // field_number. This will be used to output fields in sorted order,
495 // which is strongly encouraged when serializing protobufs.
496 sorted_field_index_.resize(field_names_.size());
497 // Start with the fields sorted by current index.
498 for (int i = 0; i < field_names_.size(); i++) sorted_field_index_[i] = i;
499 // Then sort the field indices by their proto field number.
500 std::sort(sorted_field_index_.begin(), sorted_field_index_.end(),
501 [this](int a, int b) -> bool {
502 return field_descs_[a]->number() < field_descs_[b]->number();
503 });
504 }
505
Compute(OpKernelContext * ctx)506 void Compute(OpKernelContext* ctx) override {
507 const Tensor* sizes_tensor;
508 OP_REQUIRES_OK(ctx, ctx->input("sizes", &sizes_tensor));
509
510 OpInputList values;
511 OP_REQUIRES_OK(ctx, ctx->input_list("values", &values));
512
513 OP_REQUIRES(ctx, field_descs_.size() == values.size(),
514 errors::InvalidArgument(
515 "Length of inputs list must match field_names"));
516
517 // Check the arguments for consistency.
518 TensorShape common_prefix;
519 int message_count;
520 for (int i = 0; i < field_descs_.size(); i++) {
521 const Tensor& v = values[i];
522
523 // The type of each value tensor must match the corresponding field.
524 OP_REQUIRES(
525 ctx,
526 proto_utils::IsCompatibleType(field_descs_[i]->type(), v.dtype()),
527 errors::InvalidArgument(
528 "Incompatible type for field ", field_names_[i],
529 ". Saw dtype: ", DataTypeString(v.dtype()),
530 " but field type is: ", field_descs_[i]->type_name()));
531
532 OP_REQUIRES(
533 ctx, TensorShapeUtils::IsMatrixOrHigher(v.shape()),
534 errors::InvalidArgument("Invalid shape for field ", field_names_[i],
535 ". Saw shape ", v.shape().DebugString(),
536 " but it should be at least a matrix."));
537
538 // All value tensors must have the same shape prefix (i.e. batch size).
539 TensorShape shape_prefix = v.shape();
540 shape_prefix.RemoveDim(shape_prefix.dims() - 1);
541
542 // Do some initialization on the first input value. The rest will
543 // have to match this one.
544 if (i == 0) {
545 OP_REQUIRES(ctx, v.dims() >= 1,
546 errors::InvalidArgument(
547 "Expected value to be at least a vector, saw shape: ",
548 v.shape().DebugString()));
549 common_prefix = shape_prefix;
550 message_count = common_prefix.num_elements();
551 } else {
552 OP_REQUIRES(ctx, shape_prefix == common_prefix,
553 errors::InvalidArgument(
554 "Values must match up to the last dimension"));
555 }
556 }
557
558 TensorShape expected_sizes_shape = common_prefix;
559 expected_sizes_shape.AddDim(field_descs_.size());
560
561 OP_REQUIRES(ctx, sizes_tensor->shape() == expected_sizes_shape,
562 errors::InvalidArgument(
563 "sizes should be batch_size + [len(field_names)]. Saw: ",
564 sizes_tensor->shape().DebugString(),
565 " but expected: ", expected_sizes_shape.DebugString()));
566
567 auto sizes = sizes_tensor->flat_inner_dims<int32>();
568
569 for (int i = 0; i < field_descs_.size(); ++i) {
570 const Tensor& v = values[i];
571 int max_size = v.dim_size(v.dims() - 1);
572
573 // The last dimension of a value tensor must be greater than the
574 // corresponding size in the sizes tensor.
575 for (int message_index = 0; message_index < message_count;
576 message_index++) {
577 OP_REQUIRES(
578 ctx, sizes(message_index, i) <= max_size,
579 errors::InvalidArgument(
580 "Size to write must not be larger than value tensor; but saw: ",
581 sizes(message_index, i), " > ", max_size, " at message ",
582 message_index, " field ", i));
583 }
584 }
585
586 // This pointer is owned by the context.
587 Tensor* output_tensor;
588 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, common_prefix, &output_tensor));
589
590 auto bufs = output_tensor->flat<string>();
591 for (int message_index = 0; message_index < message_count;
592 message_index++) {
593 // TODO(nix): possibly optimize allocation here by calling
594 // `bufs(message_index).reserve(DEFAULT_BUF_SIZE)`.
595 StringOutputStream output_string(&bufs(message_index));
596 CodedOutputStream out(&output_string);
597 // Write fields in ascending field_number order.
598 for (int i : sorted_field_index_) {
599 auto& field_desc = *field_descs_[i];
600 const Tensor& v = values[i];
601 int size = sizes(message_index, i);
602 if (!size) continue;
603 OP_REQUIRES_OK(ctx,
604 WriteField(field_desc, v, message_index, size, &out));
605 }
606 }
607 }
608
609 private:
610 std::vector<string> field_names_;
611 std::vector<const FieldDescriptor*> field_descs_;
612
613 // Owned_desc_pool_ is null when using descriptor_source=local.
614 std::unique_ptr<DescriptorPool> owned_desc_pool_;
615
616 // Contains indices into field_names_, sorted by field number since that's the
617 // order of writing.
618 std::vector<int> sorted_field_index_;
619
620 TF_DISALLOW_COPY_AND_ASSIGN(EncodeProtoOp);
621 };
622
// Registers EncodeProtoOp as the CPU kernel for the "EncodeProto" op.
REGISTER_KERNEL_BUILDER(Name("EncodeProto").Device(DEVICE_CPU), EncodeProtoOp);
624
625 } // namespace
626 } // namespace tensorflow
627