• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/util/example_proto_helper.h"
16 
17 #include <vector>
18 
19 #include "tensorflow/core/example/example.pb.h"
20 #include "tensorflow/core/example/feature.pb.h"
21 #include "tensorflow/core/framework/numeric_op.h"
22 #include "tensorflow/core/framework/register_types.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/core/platform/protobuf.h"
26 #include "tensorflow/core/util/sparse/sparse_tensor.h"
27 
28 namespace tensorflow {
29 
CheckValidType(const DataType & dtype)30 Status CheckValidType(const DataType& dtype) {
31   switch (dtype) {
32     case DT_INT64:
33     case DT_FLOAT:
34     case DT_STRING:
35       return Status::OK();
36     default:
37       return errors::InvalidArgument("Received input dtype: ",
38                                      DataTypeString(dtype));
39   }
40 }
41 
CheckTypesMatch(const Feature & feature,const DataType & dtype,bool * match)42 Status CheckTypesMatch(const Feature& feature, const DataType& dtype,
43                        bool* match) {
44   switch (dtype) {
45     case DT_INT64:
46       *match = (feature.kind_case() == Feature::kInt64List);
47       break;
48     case DT_FLOAT:
49       *match = (feature.kind_case() == Feature::kFloatList);
50       break;
51     case DT_STRING:
52       *match = (feature.kind_case() == Feature::kBytesList);
53       break;
54     default:
55       return errors::InvalidArgument("Invalid input dtype: ",
56                                      DataTypeString(dtype));
57   }
58   return Status::OK();
59 }
60 
FeatureDenseCopy(const std::size_t out_index,const string & name,const string & key,const DataType & dtype,const TensorShape & shape,const Feature & feature,Tensor * out)61 Status FeatureDenseCopy(const std::size_t out_index, const string& name,
62                         const string& key, const DataType& dtype,
63                         const TensorShape& shape, const Feature& feature,
64                         Tensor* out) {
65   const std::size_t num_elements = shape.num_elements();
66   const std::size_t offset = out_index * num_elements;
67 
68   switch (dtype) {
69     case DT_INT64: {
70       const Int64List& values = feature.int64_list();
71       if (static_cast<size_t>(values.value_size()) != num_elements) {
72         return errors::InvalidArgument(
73             "Name: ", name, ", Key: ", key, ", Index: ", out_index,
74             ".  Number of int64 values != expected.  "
75             "values size: ",
76             values.value_size(), " but output shape: ", shape.DebugString());
77       }
78       auto out_p = out->flat<int64>().data() + offset;
79       std::copy_n(values.value().data(), num_elements, out_p);
80       return Status::OK();
81     }
82     case DT_FLOAT: {
83       const FloatList& values = feature.float_list();
84       if (static_cast<size_t>(values.value_size()) != num_elements) {
85         return errors::InvalidArgument(
86             "Name: ", name, ", Key: ", key, ", Index: ", out_index,
87             ".  Number of float values != expected.  "
88             "values size: ",
89             values.value_size(), " but output shape: ", shape.DebugString());
90       }
91       auto out_p = out->flat<float>().data() + offset;
92       std::copy_n(values.value().data(), num_elements, out_p);
93       return Status::OK();
94     }
95     case DT_STRING: {
96       const BytesList& values = feature.bytes_list();
97       if (static_cast<size_t>(values.value_size()) != num_elements) {
98         return errors::InvalidArgument(
99             "Name: ", name, ", Key ", key, ", Index: ", out_index,
100             ".  Number of bytes values != expected.  "
101             "Values size: ",
102             values.value_size(), " but output shape: ", shape.DebugString());
103       }
104       auto out_p = out->flat<tstring>().data() + offset;
105       std::transform(values.value().data(),
106                      values.value().data() + num_elements, out_p,
107                      [](const string* s) { return *s; });
108       return Status::OK();
109     }
110     default:
111       return errors::InvalidArgument("Invalid input dtype: ",
112                                      DataTypeString(dtype));
113   }
114 }
115 
FeatureSparseCopy(const std::size_t batch,const string & key,const DataType & dtype,const Feature & feature)116 Tensor FeatureSparseCopy(const std::size_t batch, const string& key,
117                          const DataType& dtype, const Feature& feature) {
118   switch (dtype) {
119     case DT_INT64: {
120       const Int64List& values = feature.int64_list();
121       const int64 num_elements = values.value_size();
122       Tensor out(dtype, TensorShape({num_elements}));
123       auto out_p = out.flat<int64>().data();
124       std::copy_n(values.value().data(), num_elements, out_p);
125       return out;
126     }
127     case DT_FLOAT: {
128       const FloatList& values = feature.float_list();
129       const int64 num_elements = values.value_size();
130       Tensor out(dtype, TensorShape({num_elements}));
131       auto out_p = out.flat<float>().data();
132       std::copy_n(values.value().data(), num_elements, out_p);
133       return out;
134     }
135     case DT_STRING: {
136       const BytesList& values = feature.bytes_list();
137       const int64 num_elements = values.value_size();
138       Tensor out(dtype, TensorShape({num_elements}));
139       auto out_p = out.flat<tstring>().data();
140       std::transform(values.value().data(),
141                      values.value().data() + num_elements, out_p,
142                      [](const string* s) { return *s; });
143       return out;
144     }
145     default:
146       LOG(FATAL) << "not supposed to be here.  dtype requested: " << dtype;
147   }
148 }
149 
// Appends one example's sparse values (the tensor `in`) to the batch-level
// sparse outputs, starting at row/position `offset`.
//
// For each i in [0, in.NumElements()):
//   indices row (offset + i) is set to {batch, i},
//   values[offset + i] is set to in[i].
// The index rows are written by walking a pointer two int64s at a time from
// &indices(offset, 0).
//
// CHECK-fails if `in` and `values` differ in dtype. LOG(FATAL)s for dtypes
// other than DT_INT64/DT_FLOAT/DT_STRING. Returns the number of elements
// copied so the caller can advance `offset`.
int64 CopyIntoSparseTensor(const Tensor& in, const int batch,
                           const int64 offset, Tensor* indices,
                           Tensor* values) {
  const int64 count = in.shape().num_elements();
  const DataType& dtype = in.dtype();
  CHECK_EQ(dtype, values->dtype());

  // Only take the matrix view of `indices` when there is something to write.
  if (count > 0) {
    auto index_matrix = indices->matrix<int64>();
    int64* row = &index_matrix(offset, 0);
    for (int64 i = 0; i < count; ++i) {
      row[0] = batch;  // Column 0: which example in the batch.
      row[1] = i;      // Column 1: position within that example.
      row += 2;
    }
  }

  switch (dtype) {
    case DT_INT64: {
      const int64* src = in.flat<int64>().data();
      std::copy(src, src + count, values->flat<int64>().data() + offset);
      break;
    }
    case DT_FLOAT: {
      const float* src = in.flat<float>().data();
      std::copy(src, src + count, values->flat<float>().data() + offset);
      break;
    }
    case DT_STRING: {
      const tstring* src = in.flat<tstring>().data();
      std::copy(src, src + count, values->flat<tstring>().data() + offset);
      break;
    }
    default:
      LOG(FATAL) << "Not supposed to be here.  Saw dtype: " << dtype;
  }

  return count;
}
190 
RowDenseCopy(const std::size_t & out_index,const DataType & dtype,const Tensor & in,Tensor * out)191 void RowDenseCopy(const std::size_t& out_index, const DataType& dtype,
192                   const Tensor& in, Tensor* out) {
193   const std::size_t num_elements = in.shape().num_elements();
194   const std::size_t offset = out_index * num_elements;
195 
196   switch (dtype) {
197     case DT_INT64: {
198       std::copy_n(in.flat<int64>().data(), num_elements,
199                   out->flat<int64>().data() + offset);
200       break;
201     }
202     case DT_FLOAT: {
203       std::copy_n(in.flat<float>().data(), num_elements,
204                   out->flat<float>().data() + offset);
205       break;
206     }
207     case DT_STRING: {
208       // TODO(dero): verify.
209       std::copy_n(in.flat<tstring>().data(), num_elements,
210                   out->flat<tstring>().data() + offset);
211       break;
212     }
213     default:
214       LOG(FATAL) << "Not supposed to be here.  Saw dtype: " << dtype;
215   }
216 }
217 
SingleExampleProtoToTensors(const Example & example,const string & example_name,const int batch_index,const std::vector<FixedLenFeature> & fixed_len_features,const std::vector<VarLenFeature> & var_len_features,std::vector<Tensor * > * output_dense_values_tensor,std::vector<std::vector<Tensor>> * output_sparse_values_tmp)218 Status SingleExampleProtoToTensors(
219     const Example& example, const string& example_name, const int batch_index,
220     const std::vector<FixedLenFeature>& fixed_len_features,
221     const std::vector<VarLenFeature>& var_len_features,
222     std::vector<Tensor*>* output_dense_values_tensor,
223     std::vector<std::vector<Tensor>>* output_sparse_values_tmp) {
224   const Features& features = example.features();
225   const auto& feature_dict = features.feature();
226 
227   // Handle dense features.
228   for (size_t d = 0; d < fixed_len_features.size(); ++d) {
229     const FixedLenFeature& feature_config = fixed_len_features[d];
230     const string& key = feature_config.key;
231     const DataType& dtype = feature_config.dtype;
232     const TensorShape& shape = feature_config.shape;
233     const Tensor& default_value = feature_config.default_value;
234     bool required = (default_value.NumElements() == 0);
235     const auto& feature_found = feature_dict.find(key);
236     const bool feature_has_data =  // Found key & data type is set
237         (feature_found != feature_dict.end() &&
238          (feature_found->second.kind_case() != Feature::KIND_NOT_SET));
239 
240     const bool required_ok = feature_has_data || !required;
241     if (!required_ok) {
242       return errors::InvalidArgument("Name: ", example_name, ", Feature: ", key,
243                                      " is required but could not be found.");
244     }
245 
246     // Perform the FeatureDenseCopy into the output dense_values tensor (if
247     // the value is present).
248     if (feature_has_data) {
249       const Feature& f = feature_found->second;
250       bool types_match;
251       TF_RETURN_IF_ERROR(CheckTypesMatch(f, dtype, &types_match));
252       if (!types_match) {
253         return errors::InvalidArgument("Name: ", example_name,
254                                        ", Feature: ", key,
255                                        ".  Data types don't match. ",
256                                        "Expected type: ", DataTypeString(dtype),
257                                        "  Feature is: ", f.DebugString());
258       }
259       TF_RETURN_IF_ERROR(FeatureDenseCopy(batch_index, example_name, key, dtype,
260                                           shape, f,
261                                           (*output_dense_values_tensor)[d]));
262     } else {
263       // If the value is missing, RowDenseCopy the default value.
264       RowDenseCopy(batch_index, dtype, default_value,
265                    (*output_dense_values_tensor)[d]);
266     }
267   }
268 
269   // Handle sparse features.
270   for (size_t d = 0; d < var_len_features.size(); ++d) {
271     const VarLenFeature& feature_config = var_len_features[d];
272     const string& key = feature_config.key;
273     const DataType& dtype = feature_config.dtype;
274     const auto& feature_found = feature_dict.find(key);
275 
276     const bool feature_has_data =  // Found key & data type is set
277         (feature_found != feature_dict.end() &&
278          (feature_found->second.kind_case() != Feature::KIND_NOT_SET));
279 
280     if (feature_has_data) {
281       const Feature& f = feature_found->second;
282       bool types_match;
283       TF_RETURN_IF_ERROR(CheckTypesMatch(f, dtype, &types_match));
284       if (!types_match) {
285         return errors::InvalidArgument("Name: ", example_name,
286                                        ", Feature: ", key,
287                                        ".  Data types don't match. ",
288                                        "Expected type: ", DataTypeString(dtype),
289                                        "  Feature is: ", f.DebugString());
290       }
291       (*output_sparse_values_tmp)[d][batch_index] =
292           FeatureSparseCopy(batch_index, key, dtype, f);
293     } else {
294       (*output_sparse_values_tmp)[d][batch_index] =
295           Tensor(dtype, TensorShape({0}));
296     }
297   }
298   return Status::OK();
299 }
300 
GetSparseTensorShapes(const VarLenFeature & var_len_feature,const std::vector<Tensor> & sparse_values_tmp,const int batch_size,VarLenFeatureBatchShapes * output_shapes)301 Status GetSparseTensorShapes(const VarLenFeature& var_len_feature,
302                              const std::vector<Tensor>& sparse_values_tmp,
303                              const int batch_size,
304                              VarLenFeatureBatchShapes* output_shapes) {
305   int64 total_num_features = 0;
306   int64 max_num_features = 0;
307   for (int b = 0; b < batch_size; ++b) {
308     const Tensor& t = sparse_values_tmp[b];
309     const int64 num_elements = t.shape().num_elements();
310     total_num_features += num_elements;
311     max_num_features = std::max(max_num_features, num_elements);
312   }
313   output_shapes->indices_shape.AddDim(total_num_features);
314   output_shapes->indices_shape.AddDim(2);
315   output_shapes->values_shape.AddDim(total_num_features);
316   output_shapes->max_num_features = max_num_features;
317   return Status::OK();
318 }
319 
BatchExampleProtoToTensors(const std::vector<const Example * > & examples,const std::vector<string> & names,const std::vector<FixedLenFeature> & fixed_len_features,const std::vector<VarLenFeature> & var_len_features,Allocator * allocator,std::vector<Tensor> * output_dense_values_tensor,std::vector<Tensor> * output_sparse_indices_tensor,std::vector<Tensor> * output_sparse_values_tensor,std::vector<Tensor> * output_sparse_shapes_tensor)320 Status BatchExampleProtoToTensors(
321     const std::vector<const Example*>& examples,
322     const std::vector<string>& names,
323     const std::vector<FixedLenFeature>& fixed_len_features,
324     const std::vector<VarLenFeature>& var_len_features, Allocator* allocator,
325     std::vector<Tensor>* output_dense_values_tensor,
326     std::vector<Tensor>* output_sparse_indices_tensor,
327     std::vector<Tensor>* output_sparse_values_tensor,
328     std::vector<Tensor>* output_sparse_shapes_tensor) {
329   const int batch_size = examples.size();
330 
331   const bool has_names = (!names.empty());
332   if (has_names) {
333     if (names.size() != examples.size()) {
334       return errors::InvalidArgument(
335           "Expected len(names) == len(examples), but got: ", names.size(),
336           " vs. ", examples.size());
337     }
338   }
339 
340   // We also need a map of Tensor pointers for the SingleExampleProtoToTensors
341   // call. (Is there a better solution here?)
342   std::vector<Tensor*> output_dense_values_tensor_ptrs(
343       fixed_len_features.size());
344 
345   // Preallocate dense_values, since we know their sizes.
346   for (size_t d = 0; d < fixed_len_features.size(); ++d) {
347     const FixedLenFeature& config = fixed_len_features[d];
348     TensorShape out_shape;
349     out_shape.AddDim(batch_size);
350     const TensorShape& shape = config.shape;
351     const DataType& dtype = config.dtype;
352     for (const int dim : shape.dim_sizes()) out_shape.AddDim(dim);
353     (*output_dense_values_tensor)[d] = Tensor(allocator, dtype, out_shape);
354     output_dense_values_tensor_ptrs[d] = &(*output_dense_values_tensor)[d];
355   }
356 
357   // Temporary vector to hold sparse values.
358   std::vector<std::vector<Tensor>> sparse_values_tmp(var_len_features.size());
359 
360   for (size_t d = 0; d < var_len_features.size(); ++d) {
361     sparse_values_tmp[d] = std::vector<Tensor>(batch_size);
362   }
363 
364   for (size_t b = 0; b < examples.size(); ++b) {
365     const Example& ex = *(examples[b]);
366     const string& example_name = (has_names) ? names[b] : "<unknown>";
367     TF_RETURN_IF_ERROR(SingleExampleProtoToTensors(
368         ex, example_name, b, fixed_len_features, var_len_features,
369         &output_dense_values_tensor_ptrs, &sparse_values_tmp));
370   }
371 
372   for (size_t d = 0; d < var_len_features.size(); ++d) {
373     const VarLenFeature& feature_config = var_len_features[d];
374     const DataType& dtype = feature_config.dtype;
375     const std::vector<Tensor>& sparse_values_tensor = sparse_values_tmp[d];
376 
377     VarLenFeatureBatchShapes sparse_tensor_batch_shapes;
378     TF_RETURN_IF_ERROR(GetSparseTensorShapes(feature_config,
379                                              sparse_values_tensor, batch_size,
380                                              &sparse_tensor_batch_shapes));
381     const TensorShape& indices_shape = sparse_tensor_batch_shapes.indices_shape;
382     const TensorShape& values_shape = sparse_tensor_batch_shapes.values_shape;
383 
384     // Allocate the sparse indices here.
385     (*output_sparse_indices_tensor)[d] =
386         Tensor(allocator, DT_INT64, indices_shape);
387     (*output_sparse_values_tensor)[d] = Tensor(allocator, dtype, values_shape);
388     (*output_sparse_shapes_tensor)[d] =
389         Tensor(allocator, DT_INT64, TensorShape({2}));
390 
391     auto shape_t = (*output_sparse_shapes_tensor)[d].vec<int64>();
392     shape_t(0) = batch_size;
393     shape_t(1) = sparse_tensor_batch_shapes.max_num_features;
394 
395     Tensor* sp_indices_d = &(*output_sparse_indices_tensor)[d];
396     Tensor* sp_values_d = &(*output_sparse_values_tensor)[d];
397 
398     int64 offset = 0;
399     for (int b = 0; b < batch_size; ++b) {
400       const int64 num_elements = CopyIntoSparseTensor(
401           sparse_values_tensor[b], b, offset, sp_indices_d, sp_values_d);
402       offset += num_elements;
403     }
404   }
405   return Status::OK();
406 }
407 
FinishInit(int op_version)408 Status ParseExampleAttrs::FinishInit(int op_version) {
409   switch (op_version) {
410     case 1:
411       num_ragged = 0;
412       break;
413     case 2:
414       num_dense = dense_types.size();
415       num_ragged = ragged_value_types.size();
416       break;
417     default:
418       return errors::InvalidArgument("Unexpected op_version", op_version);
419   }
420   if (static_cast<size_t>(num_sparse) != sparse_types.size()) {
421     return errors::InvalidArgument("len(sparse_keys) != len(sparse_types)");
422   }
423   if (static_cast<size_t>(num_dense) != dense_types.size()) {
424     return errors::InvalidArgument("len(dense_keys) != len(dense_types)");
425   }
426   if (static_cast<size_t>(num_dense) != dense_shapes.size()) {
427     return errors::InvalidArgument("len(dense_keys) != len(dense_shapes)");
428   }
429   if (static_cast<size_t>(num_ragged) != ragged_value_types.size()) {
430     return errors::InvalidArgument(
431         "len(ragged_keys) != len(ragged_value_types)");
432   }
433   if (static_cast<size_t>(num_ragged) != ragged_split_types.size()) {
434     return errors::InvalidArgument(
435         "len(ragged_keys) != len(ragged_split_types)");
436   }
437   if (num_dense > std::numeric_limits<int32>::max()) {
438     return errors::InvalidArgument("num_dense_ too large");
439   }
440   for (const DataType& type : dense_types) {
441     TF_RETURN_IF_ERROR(CheckValidType(type));
442   }
443   for (const DataType& type : sparse_types) {
444     TF_RETURN_IF_ERROR(CheckValidType(type));
445   }
446   for (const DataType& type : ragged_value_types) {
447     TF_RETURN_IF_ERROR(CheckValidType(type));
448   }
449   for (const DataType& type : ragged_split_types) {
450     if (!(type == DT_INT64 || type == DT_INT32)) {
451       return errors::InvalidArgument("Invalid ragged_split_type: ",
452                                      DataTypeString(type));
453     }
454   }
455   return Status::OK();
456 }
457 
FinishInit()458 Status ParseSingleExampleAttrs::FinishInit() {
459   if (sparse_keys.size() != sparse_types.size()) {
460     return errors::InvalidArgument("len(sparse_keys) != len(sparse_types)");
461   }
462   if (dense_keys.size() != dense_types.size()) {
463     return errors::InvalidArgument("len(dense_keys) != len(dense_types)");
464   }
465   if (dense_keys.size() != dense_shapes.size()) {
466     return errors::InvalidArgument("len(dense_keys) != len(dense_shapes)");
467   }
468   for (const DataType& type : dense_types) {
469     TF_RETURN_IF_ERROR(CheckValidType(type));
470   }
471   for (const DataType& type : sparse_types) {
472     TF_RETURN_IF_ERROR(CheckValidType(type));
473   }
474   return Status::OK();
475 }
476 
FinishInit(int op_version)477 Status ParseSequenceExampleAttrs::FinishInit(int op_version) {
478   switch (op_version) {
479     case 1:
480       num_context_ragged = 0;
481       num_feature_list_ragged = 0;
482       if (num_context_sparse != context_sparse_keys.size()) {
483         return errors::InvalidArgument(
484             "num_context_sparse (", num_context_sparse,
485             ") must match the size of context_sparse_keys (",
486             context_sparse_keys.size(), ")");
487       }
488       if (num_context_dense != context_dense_keys.size()) {
489         return errors::InvalidArgument(
490             "num_context_dense (", num_context_dense,
491             ") must match the size of context_dense_keys (",
492             context_dense_keys.size(), ")");
493       }
494       if (num_feature_list_sparse != feature_list_sparse_keys.size()) {
495         return errors::InvalidArgument(
496             "num_feature_list_sparse (", num_feature_list_sparse,
497             ") must match the size of feature_list_sparse_keys (",
498             feature_list_sparse_keys.size(), ")");
499       }
500       if (num_feature_list_dense != feature_list_dense_keys.size()) {
501         return errors::InvalidArgument(
502             "num_feature_list_dense (", num_feature_list_dense,
503             ") must match the size of feature_list_dense_keys (",
504             feature_list_dense_keys.size(), ")");
505       }
506       break;
507     case 2:
508       num_context_dense = context_dense_types.size();
509       num_context_ragged = context_ragged_value_types.size();
510       num_feature_list_ragged = feature_list_ragged_value_types.size();
511       break;
512     default:
513       return errors::InvalidArgument("Unexpected op_version", op_version);
514   }
515   if (num_context_sparse != context_sparse_types.size()) {
516     return errors::InvalidArgument(
517         "num_context_sparse (", num_context_sparse,
518         ") must match the size of context_sparse_types (",
519         context_sparse_types.size(), ")");
520   }
521   if (num_context_dense != context_dense_types.size() ||
522       num_context_dense != context_dense_shapes.size()) {
523     return errors::InvalidArgument(
524         "num_context_dense (", num_context_dense,
525         ") must match the size of context_dense_types (",
526         context_dense_types.size(), ") and context_dense_shapes (",
527         context_dense_shapes.size(), ")");
528   }
529   if ((num_context_ragged != context_ragged_value_types.size()) ||
530       (num_context_ragged != context_ragged_split_types.size())) {
531     return errors::InvalidArgument(
532         "num_context_ragged (", num_context_ragged,
533         ") must match the size of context_ragged_value_types (",
534         context_ragged_value_types.size(), ") and context_ragged_split_types (",
535         context_ragged_split_types.size(), ")");
536   }
537   if (num_feature_list_sparse != feature_list_sparse_types.size()) {
538     return errors::InvalidArgument(
539         "num_feature_list_sparse (", num_feature_list_sparse,
540         ") must match the size of feature_list_sparse_types (",
541         feature_list_sparse_types.size(), ")");
542   }
543   if (num_feature_list_dense != feature_list_dense_types.size() ||
544       num_feature_list_dense != feature_list_dense_shapes.size()) {
545     return errors::InvalidArgument(
546         "num_feature_list_dense (", num_feature_list_dense,
547         ") must match the size of feature_list_dense_types (",
548         feature_list_dense_types.size(), ") and feature_list_dense_shapes (",
549         feature_list_dense_shapes.size(), ")");
550   }
551   if ((num_feature_list_ragged != feature_list_ragged_value_types.size()) ||
552       (num_feature_list_ragged != feature_list_ragged_split_types.size())) {
553     return errors::InvalidArgument(
554         "num_feature_list_ragged (", num_feature_list_ragged,
555         ") must match the size of feature_list_ragged_value_types (",
556         feature_list_ragged_value_types.size(),
557         ") and feature_list_ragged_split_types (",
558         feature_list_ragged_split_types.size(), ")");
559   }
560   for (const DataType& type : context_dense_types) {
561     TF_RETURN_IF_ERROR(CheckValidType(type));
562   }
563   for (const DataType& type : context_sparse_types) {
564     TF_RETURN_IF_ERROR(CheckValidType(type));
565   }
566   for (const DataType& type : feature_list_dense_types) {
567     TF_RETURN_IF_ERROR(CheckValidType(type));
568   }
569   for (const DataType& type : feature_list_sparse_types) {
570     TF_RETURN_IF_ERROR(CheckValidType(type));
571   }
572   for (const DataType& type : context_ragged_value_types) {
573     TF_RETURN_IF_ERROR(CheckValidType(type));
574   }
575   for (const DataType& type : context_ragged_split_types) {
576     if (!(type == DT_INT64 || type == DT_INT32)) {
577       return errors::InvalidArgument("Invalid context_ragged_split_type: ",
578                                      DataTypeString(type));
579     }
580   }
581   for (const DataType& type : feature_list_ragged_value_types) {
582     TF_RETURN_IF_ERROR(CheckValidType(type));
583   }
584   for (const DataType& type : feature_list_ragged_split_types) {
585     if (!(type == DT_INT64 || type == DT_INT32)) {
586       return errors::InvalidArgument("Invalid feature_list_ragged_split_type: ",
587                                      DataTypeString(type));
588     }
589   }
590 
591   return Status::OK();
592 }
593 
FinishInit()594 Status ParseSingleSequenceExampleAttrs::FinishInit() {
595   if (static_cast<size_t>(num_context_sparse) != context_sparse_types.size()) {
596     return errors::InvalidArgument(
597         "len(context_sparse_keys) != len(context_sparse_types)");
598   }
599   if (static_cast<size_t>(num_context_dense) != context_dense_types.size()) {
600     return errors::InvalidArgument(
601         "len(context_dense_keys) != len(context_dense_types)");
602   }
603   if (static_cast<size_t>(num_context_dense) != context_dense_shapes.size()) {
604     return errors::InvalidArgument(
605         "len(context_dense_keys) != len(context_dense_shapes)");
606   }
607   if (static_cast<size_t>(num_feature_list_sparse) !=
608       feature_list_sparse_types.size()) {
609     return errors::InvalidArgument(
610         "len(feature_list_sparse_keys) != len(feature_list_sparse_types)");
611   }
612   if (static_cast<size_t>(num_feature_list_dense) !=
613       feature_list_dense_types.size()) {
614     return errors::InvalidArgument(
615         "len(feature_list_dense_keys) != "
616         "len(feature_list_dense_types)");
617   }
618   for (const DataType& type : context_dense_types) {
619     TF_RETURN_IF_ERROR(CheckValidType(type));
620   }
621   for (const DataType& type : context_sparse_types) {
622     TF_RETURN_IF_ERROR(CheckValidType(type));
623   }
624   for (const DataType& type : feature_list_dense_types) {
625     TF_RETURN_IF_ERROR(CheckValidType(type));
626   }
627   for (const DataType& type : feature_list_sparse_types) {
628     TF_RETURN_IF_ERROR(CheckValidType(type));
629   }
630   return Status::OK();
631 }
632 
GetDenseShapes(const std::vector<PartialTensorShape> & dense_shapes,std::vector<bool> * variable_length,std::vector<std::size_t> * elements_per_stride)633 Status GetDenseShapes(const std::vector<PartialTensorShape>& dense_shapes,
634                       std::vector<bool>* variable_length,
635                       std::vector<std::size_t>* elements_per_stride) {
636   // Temporary check until we start allowing a variable length outer
637   // dimension.
638   for (int i = 0; i < dense_shapes.size(); ++i) {
639     bool shape_ok = true;
640     if (dense_shapes[i].dims() == -1) {
641       shape_ok = false;
642     } else {
643       for (int d = 1; d < dense_shapes[i].dims(); ++d) {
644         if (dense_shapes[i].dim_size(d) == -1) {
645           shape_ok = false;
646         }
647       }
648     }
649     if (!shape_ok) {
650       return errors::InvalidArgument(
651           "dense_shapes[", i,
652           "] has unknown rank or unknown inner dimensions: ",
653           dense_shapes[i].DebugString());
654     }
655     TensorShape dense_shape;
656     if (dense_shapes[i].dims() > 0 && dense_shapes[i].dim_size(0) == -1) {
657       variable_length->push_back(true);
658       for (int d = 1; d < dense_shapes[i].dims(); ++d) {
659         dense_shape.AddDim(dense_shapes[i].dim_size(d));
660       }
661     } else {
662       variable_length->push_back(false);
663       dense_shapes[i].AsTensorShape(&dense_shape);
664     }
665     elements_per_stride->push_back(dense_shape.num_elements());
666   }
667   return Status::OK();
668 }
669 
670 }  // namespace tensorflow
671