• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // EncodeProto is a TensorFlow Op which serializes tensors into
17 // arbitrary protobufs.
18 //
19 // See the docstring in ../ops/encode_proto_op.cc for usage of the op.
20 //
21 // This implementation writes the serialized format using a handful of
22 // calls from the WireFormatLite API.
23 
24 #include <memory>
25 #include <vector>
26 
27 #include "third_party/eigen3/Eigen/Core"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/framework/tensor_types.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/protobuf.h"
33 #include "tensorflow/core/util/proto/descriptors.h"
34 #include "tensorflow/core/util/proto/proto_utils.h"
35 
36 namespace tensorflow {
37 namespace {
38 
39 using ::tensorflow::protobuf::Descriptor;
40 using ::tensorflow::protobuf::DescriptorPool;
41 using ::tensorflow::protobuf::FieldDescriptor;
42 using ::tensorflow::protobuf::internal::WireFormatLite;
43 using ::tensorflow::protobuf::io::CodedOutputStream;
44 using ::tensorflow::protobuf::io::StringOutputStream;
45 
46 // Computes the total serialized size for a packed repeated field. For
47 // fixed-size types this can just multiply, but for variable-sized types it has
48 // to iterate through the values in the tensor.
49 template <WireFormatLite::FieldType FieldType, typename TensorT>
50 size_t TotalPackedSize(const Tensor& input, int message_index, int size);
51 
52 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)53 size_t TotalPackedSize<WireFormatLite::TYPE_DOUBLE, double>(const Tensor& input,
54                                                             int message_index,
55                                                             int size) {
56   return size * WireFormatLite::kDoubleSize;
57 }
58 
59 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)60 size_t TotalPackedSize<WireFormatLite::TYPE_FLOAT, double>(const Tensor& input,
61                                                            int message_index,
62                                                            int size) {
63   return size * WireFormatLite::kFloatSize;
64 }
65 
66 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)67 size_t TotalPackedSize<WireFormatLite::TYPE_FLOAT, float>(const Tensor& input,
68                                                           int message_index,
69                                                           int size) {
70   return size * WireFormatLite::kFloatSize;
71 }
72 
73 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)74 size_t TotalPackedSize<WireFormatLite::TYPE_INT64, int64>(const Tensor& input,
75                                                           int message_index,
76                                                           int size) {
77   size_t data_size = 0;
78   auto input_t = input.flat_inner_dims<int64>();
79   for (int64 i = 0; i < size; i++) {
80     data_size += WireFormatLite::Int64Size(
81         input_t(static_cast<int64>(message_index), i));
82   }
83   return data_size;
84 }
85 
86 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)87 size_t TotalPackedSize<WireFormatLite::TYPE_UINT64, uint64>(const Tensor& input,
88                                                             int message_index,
89                                                             int size) {
90   size_t data_size = 0;
91   auto input_t = input.flat_inner_dims<uint64>();
92   for (int64 i = 0; i < size; i++) {
93     data_size += WireFormatLite::UInt64Size(
94         input_t(static_cast<int64>(message_index), i));
95   }
96   return data_size;
97 }
98 
99 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)100 size_t TotalPackedSize<WireFormatLite::TYPE_INT32, int64>(const Tensor& input,
101                                                           int message_index,
102                                                           int size) {
103   size_t data_size = 0;
104   auto input_t = input.flat_inner_dims<int64>();
105   for (int64 i = 0; i < size; i++) {
106     data_size += WireFormatLite::Int32Size(
107         input_t(static_cast<int64>(message_index), i));
108   }
109   return data_size;
110 }
111 
112 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)113 size_t TotalPackedSize<WireFormatLite::TYPE_INT32, int32>(const Tensor& input,
114                                                           int message_index,
115                                                           int size) {
116   size_t data_size = 0;
117   auto input_t = input.flat_inner_dims<int32>();
118   for (int64 i = 0; i < size; i++) {
119     data_size += WireFormatLite::Int32Size(
120         input_t(static_cast<int64>(message_index), i));
121   }
122   return data_size;
123 }
124 
125 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)126 size_t TotalPackedSize<WireFormatLite::TYPE_FIXED64, uint64>(
127     const Tensor& input, int message_index, int size) {
128   return size * WireFormatLite::kFixed64Size;
129 }
130 
131 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)132 size_t TotalPackedSize<WireFormatLite::TYPE_FIXED32, uint64>(
133     const Tensor& input, int message_index, int size) {
134   return size * WireFormatLite::kFixed32Size;
135 }
136 
137 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)138 size_t TotalPackedSize<WireFormatLite::TYPE_FIXED32, uint32>(
139     const Tensor& input, int message_index, int size) {
140   return size * WireFormatLite::kFixed32Size;
141 }
142 
143 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)144 size_t TotalPackedSize<WireFormatLite::TYPE_BOOL, bool>(const Tensor& input,
145                                                         int message_index,
146                                                         int size) {
147   return size * WireFormatLite::kBoolSize;
148 }
149 
150 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)151 size_t TotalPackedSize<WireFormatLite::TYPE_UINT32, uint64>(const Tensor& input,
152                                                             int message_index,
153                                                             int size) {
154   size_t data_size = 0;
155   auto input_t = input.flat_inner_dims<uint64>();
156   for (int64 i = 0; i < size; i++) {
157     data_size += WireFormatLite::UInt32Size(
158         input_t(static_cast<int64>(message_index), i));
159   }
160   return data_size;
161 }
162 
163 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)164 size_t TotalPackedSize<WireFormatLite::TYPE_UINT32, uint32>(const Tensor& input,
165                                                             int message_index,
166                                                             int size) {
167   size_t data_size = 0;
168   auto input_t = input.flat_inner_dims<uint32>();
169   for (int64 i = 0; i < size; i++) {
170     data_size += WireFormatLite::UInt32Size(
171         input_t(static_cast<int64>(message_index), i));
172   }
173   return data_size;
174 }
175 
176 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)177 size_t TotalPackedSize<WireFormatLite::TYPE_ENUM, int32>(const Tensor& input,
178                                                          int message_index,
179                                                          int size) {
180   size_t data_size = 0;
181   auto input_t = input.flat_inner_dims<int32>();
182   for (int64 i = 0; i < size; i++) {
183     data_size +=
184         WireFormatLite::EnumSize(input_t(static_cast<int64>(message_index), i));
185   }
186   return data_size;
187 }
188 
189 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)190 size_t TotalPackedSize<WireFormatLite::TYPE_SFIXED32, int32>(
191     const Tensor& input, int message_index, int size) {
192   return size * WireFormatLite::kSFixed32Size;
193 }
194 
195 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)196 size_t TotalPackedSize<WireFormatLite::TYPE_SFIXED32, int64>(
197     const Tensor& input, int message_index, int size) {
198   return size * WireFormatLite::kSFixed32Size;
199 }
200 
201 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)202 size_t TotalPackedSize<WireFormatLite::TYPE_SFIXED64, int64>(
203     const Tensor& input, int message_index, int size) {
204   return size * WireFormatLite::kSFixed64Size;
205 }
206 
207 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)208 size_t TotalPackedSize<WireFormatLite::TYPE_SINT32, int32>(const Tensor& input,
209                                                            int message_index,
210                                                            int size) {
211   size_t data_size = 0;
212   auto input_t = input.flat_inner_dims<int32>();
213   for (int64 i = 0; i < size; i++) {
214     data_size += WireFormatLite::SInt32Size(
215         input_t(static_cast<int64>(message_index), i));
216   }
217   return data_size;
218 }
219 
220 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)221 size_t TotalPackedSize<WireFormatLite::TYPE_SINT32, int64>(const Tensor& input,
222                                                            int message_index,
223                                                            int size) {
224   size_t data_size = 0;
225   auto input_t = input.flat_inner_dims<int64>();
226   for (int64 i = 0; i < size; i++) {
227     data_size += WireFormatLite::SInt32Size(
228         input_t(static_cast<int64>(message_index), i));
229   }
230   return data_size;
231 }
232 
233 template <>
TotalPackedSize(const Tensor & input,int message_index,int size)234 size_t TotalPackedSize<WireFormatLite::TYPE_SINT64, int64>(const Tensor& input,
235                                                            int message_index,
236                                                            int size) {
237   size_t data_size = 0;
238   auto input_t = input.flat_inner_dims<int64>();
239   for (int64 i = 0; i < size; i++) {
240     data_size += WireFormatLite::SInt64Size(
241         input_t(static_cast<int64>(message_index), i));
242   }
243   return data_size;
244 }
245 
246 // Writes a possibly repeated primitive field. TensorFlow does not have unsigned
247 // types, so we decode them to signed and encode them back to unsigned.
248 template <typename TensorT, typename ProtoT,
249           WireFormatLite::FieldType FieldType,
250           void Writer(ProtoT, CodedOutputStream*)>
WriteField(const FieldDescriptor & field_desc,const Tensor & input,int message_index,int size,CodedOutputStream * output)251 Status WriteField(const FieldDescriptor& field_desc, const Tensor& input,
252                   int message_index, int size, CodedOutputStream* output) {
253   auto wire_type = WireFormatLite::WireTypeForFieldType(
254       WireFormatLite::FieldType(field_desc.type()));
255 
256   auto input_t = input.flat_inner_dims<TensorT>();
257   if (field_desc.options().packed()) {
258     // Write the tag for the packed field.
259     WireFormatLite::WriteTag(field_desc.number(),
260                              WireFormatLite::WIRETYPE_LENGTH_DELIMITED, output);
261 
262     // Write the total packed length.
263     size_t data_size =
264         TotalPackedSize<FieldType, TensorT>(input, message_index, size);
265     output->WriteVarint32(data_size);
266 
267     // Write individual values.
268     for (int64 i = 0; i < size; i++) {
269       // Note implicit cast from signed to unsigned.
270       const ProtoT& value = input_t(static_cast<int64>(message_index), i);
271       Writer(value, output);
272     }
273   } else {
274     for (int64 i = 0; i < size; i++) {
275       WireFormatLite::WriteTag(field_desc.number(), wire_type, output);
276 
277       // Note implicit cast from signed to unsigned.
278       const ProtoT& value = input_t(static_cast<int64>(message_index), i);
279       Writer(value, output);
280     }
281   }
282   return Status::OK();
283 }
284 
285 // Writes a possibly repeated string, bytes, or message field.
286 template <typename T, void Writer(int, const T&, CodedOutputStream*)>
WriteVarLenField(const FieldDescriptor & field_desc,const Tensor & input,int message_index,int size,CodedOutputStream * output)287 Status WriteVarLenField(const FieldDescriptor& field_desc, const Tensor& input,
288                         int message_index, int size,
289                         CodedOutputStream* output) {
290   auto input_t = input.flat_inner_dims<T>();
291   for (int64 i = 0; i < size; i++) {
292     const T& value = input_t(static_cast<int64>(message_index), i);
293     // TODO(nix): there doesn't seem to be an inlined version of
294     // WireFormatLite::WriteString or its relatives, which might allow a
295     // small speedup.
296     Writer(field_desc.number(), value, output);
297   }
298   return Status::OK();
299 }
300 
301 // Writes a group field. Groups are treated like submessages, but tag-delimited
302 // instead of length-delimited. WireFormatLite handles this differently so we
303 // code it ourselves.
WriteGroup(const FieldDescriptor & field_desc,const Tensor & input,int message_index,int size,CodedOutputStream * output)304 Status WriteGroup(const FieldDescriptor& field_desc, const Tensor& input,
305                   int message_index, int size, CodedOutputStream* output) {
306   auto input_t = input.flat_inner_dims<string>();
307   for (int64 i = 0; i < size; i++) {
308     const string& value = input_t(static_cast<int64>(message_index), i);
309     WireFormatLite::WriteTag(field_desc.number(),
310                              WireFormatLite::WIRETYPE_START_GROUP, output);
311     // Note the use of WriteRaw instead of WriteString to skip the length.
312     output->WriteRaw(value.data(), value.size());
313     WireFormatLite::WriteTag(field_desc.number(),
314                              WireFormatLite::WIRETYPE_END_GROUP, output);
315   }
316   return Status::OK();
317 }
318 
319 // Writes a (possibly repeated) field into an output stream. It is the caller's
320 // responsibility to ensure that the type of the input tensor is compatible with
321 // the type of the proto field descriptor, and that (message_index, size-1) is
322 // within bounds.
WriteField(const FieldDescriptor & field_desc,const Tensor & input,int message_index,int size,CodedOutputStream * output)323 Status WriteField(const FieldDescriptor& field_desc, const Tensor& input,
324                   int message_index, int size, CodedOutputStream* output) {
325   DataType dtype = input.dtype();
326 
327   switch (field_desc.type()) {
328     case WireFormatLite::TYPE_DOUBLE:
329       return WriteField<double, double, WireFormatLite::TYPE_DOUBLE,
330                         WireFormatLite::WriteDoubleNoTag>(
331           field_desc, input, message_index, size, output);
332     case WireFormatLite::TYPE_FLOAT:
333       switch (dtype) {
334         case DataType::DT_FLOAT:
335           return WriteField<float, float, WireFormatLite::TYPE_FLOAT,
336                             WireFormatLite::WriteFloatNoTag>(
337               field_desc, input, message_index, size, output);
338         case DataType::DT_DOUBLE:
339           return WriteField<double, float, WireFormatLite::TYPE_FLOAT,
340                             WireFormatLite::WriteFloatNoTag>(
341               field_desc, input, message_index, size, output);
342         default:
343           return errors::DataLoss("Failed writing TYPE_FLOAT for ",
344                                   DataTypeString(dtype));
345       }
346     case WireFormatLite::TYPE_INT64:
347       return WriteField<int64, protobuf_int64, WireFormatLite::TYPE_INT64,
348                         WireFormatLite::WriteInt64NoTag>(
349           field_desc, input, message_index, size, output);
350     case WireFormatLite::TYPE_UINT64:
351       return WriteField<uint64, protobuf_uint64, WireFormatLite::TYPE_UINT64,
352                         WireFormatLite::WriteUInt64NoTag>(
353           field_desc, input, message_index, size, output);
354     case WireFormatLite::TYPE_INT32:
355       switch (dtype) {
356         case DataType::DT_INT64:
357           return WriteField<int64, int32, WireFormatLite::TYPE_INT32,
358                             WireFormatLite::WriteInt32NoTag>(
359               field_desc, input, message_index, size, output);
360         case DataType::DT_INT32:
361           return WriteField<int32, int32, WireFormatLite::TYPE_INT32,
362                             WireFormatLite::WriteInt32NoTag>(
363               field_desc, input, message_index, size, output);
364         default:
365           return errors::DataLoss("Failed writing TYPE_INT32 for ",
366                                   DataTypeString(dtype));
367       }
368     case WireFormatLite::TYPE_FIXED64:
369       return WriteField<uint64, protobuf_uint64, WireFormatLite::TYPE_FIXED64,
370                         WireFormatLite::WriteFixed64NoTag>(
371           field_desc, input, message_index, size, output);
372     case WireFormatLite::TYPE_FIXED32:
373       switch (dtype) {
374         case DataType::DT_UINT64:
375           return WriteField<uint64, uint32, WireFormatLite::TYPE_FIXED32,
376                             WireFormatLite::WriteFixed32NoTag>(
377               field_desc, input, message_index, size, output);
378         case DataType::DT_UINT32:
379           return WriteField<uint32, uint32, WireFormatLite::TYPE_FIXED32,
380                             WireFormatLite::WriteFixed32NoTag>(
381               field_desc, input, message_index, size, output);
382         default:
383           return errors::DataLoss("Failed writing TYPE_FIXED32 for ",
384                                   DataTypeString(dtype));
385       }
386     case WireFormatLite::TYPE_BOOL:
387       return WriteField<bool, bool, WireFormatLite::TYPE_BOOL,
388                         WireFormatLite::WriteBoolNoTag>(
389           field_desc, input, message_index, size, output);
390     case WireFormatLite::TYPE_STRING:
391       return WriteVarLenField<string, WireFormatLite::WriteString>(
392           field_desc, input, message_index, size, output);
393     case WireFormatLite::TYPE_GROUP:
394       return WriteGroup(field_desc, input, message_index, size, output);
395     case WireFormatLite::TYPE_MESSAGE:
396       return WriteVarLenField<string, WireFormatLite::WriteBytes>(
397           field_desc, input, message_index, size, output);
398     case WireFormatLite::TYPE_BYTES:
399       return WriteVarLenField<string, WireFormatLite::WriteBytes>(
400           field_desc, input, message_index, size, output);
401     case WireFormatLite::TYPE_UINT32:
402       switch (dtype) {
403         case DataType::DT_UINT64:
404           return WriteField<uint64, uint32, WireFormatLite::TYPE_UINT32,
405                             WireFormatLite::WriteUInt32NoTag>(
406               field_desc, input, message_index, size, output);
407         case DataType::DT_UINT32:
408           return WriteField<uint32, uint32, WireFormatLite::TYPE_UINT32,
409                             WireFormatLite::WriteUInt32NoTag>(
410               field_desc, input, message_index, size, output);
411         default:
412           return errors::DataLoss("Failed writing TYPE_UINT32 for ",
413                                   DataTypeString(dtype));
414       }
415     case WireFormatLite::TYPE_ENUM:
416       return WriteField<int32, int32, WireFormatLite::TYPE_ENUM,
417                         WireFormatLite::WriteEnumNoTag>(
418           field_desc, input, message_index, size, output);
419     case WireFormatLite::TYPE_SFIXED32:
420       switch (dtype) {
421         case DataType::DT_INT64:
422           return WriteField<int64, int32, WireFormatLite::TYPE_SFIXED32,
423                             WireFormatLite::WriteSFixed32NoTag>(
424               field_desc, input, message_index, size, output);
425         case DataType::DT_INT32:
426           return WriteField<int32, int32, WireFormatLite::TYPE_SFIXED32,
427                             WireFormatLite::WriteSFixed32NoTag>(
428               field_desc, input, message_index, size, output);
429         default:
430           return errors::DataLoss("Failed writing TYPE_SFIXED32 for ",
431                                   DataTypeString(dtype));
432       }
433     case WireFormatLite::TYPE_SFIXED64:
434       return WriteField<int64, protobuf_int64, WireFormatLite::TYPE_SFIXED64,
435                         WireFormatLite::WriteSFixed64NoTag>(
436           field_desc, input, message_index, size, output);
437     case WireFormatLite::TYPE_SINT32:
438       switch (dtype) {
439         case DataType::DT_INT64:
440           return WriteField<int64, int32, WireFormatLite::TYPE_SINT32,
441                             WireFormatLite::WriteSInt32NoTag>(
442               field_desc, input, message_index, size, output);
443         case DataType::DT_INT32:
444           return WriteField<int32, int32, WireFormatLite::TYPE_SINT32,
445                             WireFormatLite::WriteSInt32NoTag>(
446               field_desc, input, message_index, size, output);
447         default:
448           return errors::DataLoss("Failed writing TYPE_SINT32 for ",
449                                   DataTypeString(dtype));
450       }
451     case WireFormatLite::TYPE_SINT64:
452       return WriteField<int64, protobuf_int64, WireFormatLite::TYPE_SINT64,
453                         WireFormatLite::WriteSInt64NoTag>(
454           field_desc, input, message_index, size, output);
455       // default: intentionally omitted in order to enable static checking.
456   }
457 }
458 
459 class EncodeProtoOp : public OpKernel {
460  public:
EncodeProtoOp(OpKernelConstruction * context)461   explicit EncodeProtoOp(OpKernelConstruction* context) : OpKernel(context) {
462     string descriptor_source;
463     OP_REQUIRES_OK(context,
464                    context->GetAttr("descriptor_source", &descriptor_source));
465     // We always get back a desc_pool, but we may not own it. If we own it,
466     // owned_desc_pool_ will be filled in.
467     DescriptorPool const* desc_pool;
468     OP_REQUIRES_OK(context, GetDescriptorPool(context->env(), descriptor_source,
469                                               &desc_pool, &owned_desc_pool_));
470 
471     string message_type;
472     OP_REQUIRES_OK(context, context->GetAttr("message_type", &message_type));
473     const Descriptor* message_desc =
474         desc_pool->FindMessageTypeByName(message_type);
475     OP_REQUIRES(context, message_desc != nullptr,
476                 errors::InvalidArgument("No descriptor found for message type ",
477                                         message_type));
478 
479     OP_REQUIRES_OK(context, context->GetAttr("field_names", &field_names_));
480 
481     // Gather the field descriptors for the given field_names.
482     field_descs_.resize(field_names_.size());
483     for (int i = 0; i < field_names_.size(); i++) {
484       const string& name = field_names_[i];
485       auto field_desc = message_desc->FindFieldByName(name);
486       OP_REQUIRES(context, field_desc != nullptr,
487                   errors::InvalidArgument("Unknown field: ", name,
488                                           " in message type ", message_type));
489 
490       field_descs_[i] = field_desc;
491     }
492 
493     // Build a list of indices into field_descs sorted by increasing
494     // field_number. This will be used to output fields in sorted order,
495     // which is strongly encouraged when serializing protobufs.
496     sorted_field_index_.resize(field_names_.size());
497     // Start with the fields sorted by current index.
498     for (int i = 0; i < field_names_.size(); i++) sorted_field_index_[i] = i;
499     // Then sort the field indices by their proto field number.
500     std::sort(sorted_field_index_.begin(), sorted_field_index_.end(),
501               [this](int a, int b) -> bool {
502                 return field_descs_[a]->number() < field_descs_[b]->number();
503               });
504   }
505 
Compute(OpKernelContext * ctx)506   void Compute(OpKernelContext* ctx) override {
507     const Tensor* sizes_tensor;
508     OP_REQUIRES_OK(ctx, ctx->input("sizes", &sizes_tensor));
509 
510     OpInputList values;
511     OP_REQUIRES_OK(ctx, ctx->input_list("values", &values));
512 
513     OP_REQUIRES(ctx, field_descs_.size() == values.size(),
514                 errors::InvalidArgument(
515                     "Length of inputs list must match field_names"));
516 
517     // Check the arguments for consistency.
518     TensorShape common_prefix;
519     int message_count;
520     for (int i = 0; i < field_descs_.size(); i++) {
521       const Tensor& v = values[i];
522 
523       // The type of each value tensor must match the corresponding field.
524       OP_REQUIRES(
525           ctx,
526           proto_utils::IsCompatibleType(field_descs_[i]->type(), v.dtype()),
527           errors::InvalidArgument(
528               "Incompatible type for field ", field_names_[i],
529               ".  Saw dtype: ", DataTypeString(v.dtype()),
530               " but field type is: ", field_descs_[i]->type_name()));
531 
532       OP_REQUIRES(
533           ctx, TensorShapeUtils::IsMatrixOrHigher(v.shape()),
534           errors::InvalidArgument("Invalid shape for field ", field_names_[i],
535                                   ".  Saw shape ", v.shape().DebugString(),
536                                   " but it should be at least a matrix."));
537 
538       // All value tensors must have the same shape prefix (i.e. batch size).
539       TensorShape shape_prefix = v.shape();
540       shape_prefix.RemoveDim(shape_prefix.dims() - 1);
541 
542       // Do some initialization on the first input value. The rest will
543       // have to match this one.
544       if (i == 0) {
545         OP_REQUIRES(ctx, v.dims() >= 1,
546                     errors::InvalidArgument(
547                         "Expected value to be at least a vector, saw shape: ",
548                         v.shape().DebugString()));
549         common_prefix = shape_prefix;
550         message_count = common_prefix.num_elements();
551       } else {
552         OP_REQUIRES(ctx, shape_prefix == common_prefix,
553                     errors::InvalidArgument(
554                         "Values must match up to the last dimension"));
555       }
556     }
557 
558     TensorShape expected_sizes_shape = common_prefix;
559     expected_sizes_shape.AddDim(field_descs_.size());
560 
561     OP_REQUIRES(ctx, sizes_tensor->shape() == expected_sizes_shape,
562                 errors::InvalidArgument(
563                     "sizes should be batch_size + [len(field_names)].  Saw: ",
564                     sizes_tensor->shape().DebugString(),
565                     " but expected: ", expected_sizes_shape.DebugString()));
566 
567     auto sizes = sizes_tensor->flat_inner_dims<int32>();
568 
569     for (int i = 0; i < field_descs_.size(); ++i) {
570       const Tensor& v = values[i];
571       int max_size = v.dim_size(v.dims() - 1);
572 
573       // The last dimension of a value tensor must be greater than the
574       // corresponding size in the sizes tensor.
575       for (int message_index = 0; message_index < message_count;
576            message_index++) {
577         OP_REQUIRES(
578             ctx, sizes(message_index, i) <= max_size,
579             errors::InvalidArgument(
580                 "Size to write must not be larger than value tensor; but saw: ",
581                 sizes(message_index, i), " > ", max_size, " at message ",
582                 message_index, " field ", i));
583       }
584     }
585 
586     // This pointer is owned by the context.
587     Tensor* output_tensor;
588     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, common_prefix, &output_tensor));
589 
590     auto bufs = output_tensor->flat<string>();
591     for (int message_index = 0; message_index < message_count;
592          message_index++) {
593       // TODO(nix): possibly optimize allocation here by calling
594       // `bufs(message_index).reserve(DEFAULT_BUF_SIZE)`.
595       StringOutputStream output_string(&bufs(message_index));
596       CodedOutputStream out(&output_string);
597       // Write fields in ascending field_number order.
598       for (int i : sorted_field_index_) {
599         auto& field_desc = *field_descs_[i];
600         const Tensor& v = values[i];
601         int size = sizes(message_index, i);
602         if (!size) continue;
603         OP_REQUIRES_OK(ctx,
604                        WriteField(field_desc, v, message_index, size, &out));
605       }
606     }
607   }
608 
609  private:
610   std::vector<string> field_names_;
611   std::vector<const FieldDescriptor*> field_descs_;
612 
613   // Owned_desc_pool_ is null when using descriptor_source=local.
614   std::unique_ptr<DescriptorPool> owned_desc_pool_;
615 
616   // Contains indices into field_names_, sorted by field number since that's the
617   // order of writing.
618   std::vector<int> sorted_field_index_;
619 
620   TF_DISALLOW_COPY_AND_ASSIGN(EncodeProtoOp);
621 };
622 
623 REGISTER_KERNEL_BUILDER(Name("EncodeProto").Device(DEVICE_CPU), EncodeProtoOp);
624 
625 }  // namespace
626 }  // namespace tensorflow
627