1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/util/example_proto_helper.h"
16
17 #include <vector>
18
19 #include "tensorflow/core/example/example.pb.h"
20 #include "tensorflow/core/example/feature.pb_text.h"
21 #include "tensorflow/core/framework/numeric_op.h"
22 #include "tensorflow/core/framework/register_types.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/core/platform/protobuf.h"
26 #include "tensorflow/core/util/sparse/sparse_tensor.h"
27
28 namespace tensorflow {
29
CheckValidType(const DataType & dtype)30 Status CheckValidType(const DataType& dtype) {
31 switch (dtype) {
32 case DT_INT64:
33 case DT_FLOAT:
34 case DT_STRING:
35 return Status::OK();
36 default:
37 return errors::InvalidArgument("Received input dtype: ",
38 DataTypeString(dtype));
39 }
40 }
41
CheckTypesMatch(const Feature & feature,const DataType & dtype,bool * match)42 Status CheckTypesMatch(const Feature& feature, const DataType& dtype,
43 bool* match) {
44 switch (dtype) {
45 case DT_INT64:
46 *match = (feature.kind_case() == Feature::kInt64List);
47 break;
48 case DT_FLOAT:
49 *match = (feature.kind_case() == Feature::kFloatList);
50 break;
51 case DT_STRING:
52 *match = (feature.kind_case() == Feature::kBytesList);
53 break;
54 default:
55 return errors::InvalidArgument("Invalid input dtype: ",
56 DataTypeString(dtype));
57 }
58 return Status::OK();
59 }
60
FeatureDenseCopy(const std::size_t out_index,const string & name,const string & key,const DataType & dtype,const TensorShape & shape,const Feature & feature,Tensor * out)61 Status FeatureDenseCopy(const std::size_t out_index, const string& name,
62 const string& key, const DataType& dtype,
63 const TensorShape& shape, const Feature& feature,
64 Tensor* out) {
65 const std::size_t num_elements = shape.num_elements();
66 const std::size_t offset = out_index * num_elements;
67
68 switch (dtype) {
69 case DT_INT64: {
70 const Int64List& values = feature.int64_list();
71 if (static_cast<size_t>(values.value_size()) != num_elements) {
72 return errors::InvalidArgument(
73 "Name: ", name, ", Key: ", key, ", Index: ", out_index,
74 ". Number of int64 values != expected. "
75 "values size: ",
76 values.value_size(), " but output shape: ", shape.DebugString());
77 }
78 auto out_p = out->flat<int64>().data() + offset;
79 std::copy_n(values.value().data(), num_elements, out_p);
80 return Status::OK();
81 }
82 case DT_FLOAT: {
83 const FloatList& values = feature.float_list();
84 if (static_cast<size_t>(values.value_size()) != num_elements) {
85 return errors::InvalidArgument(
86 "Name: ", name, ", Key: ", key, ", Index: ", out_index,
87 ". Number of float values != expected. "
88 "values size: ",
89 values.value_size(), " but output shape: ", shape.DebugString());
90 }
91 auto out_p = out->flat<float>().data() + offset;
92 std::copy_n(values.value().data(), num_elements, out_p);
93 return Status::OK();
94 }
95 case DT_STRING: {
96 const BytesList& values = feature.bytes_list();
97 if (static_cast<size_t>(values.value_size()) != num_elements) {
98 return errors::InvalidArgument(
99 "Name: ", name, ", Key ", key, ", Index: ", out_index,
100 ". Number of bytes values != expected. "
101 "Values size: ",
102 values.value_size(), " but output shape: ", shape.DebugString());
103 }
104 auto out_p = out->flat<string>().data() + offset;
105 std::transform(values.value().data(),
106 values.value().data() + num_elements, out_p,
107 [](const string* s) { return *s; });
108 return Status::OK();
109 }
110 default:
111 return errors::InvalidArgument("Invalid input dtype: ",
112 DataTypeString(dtype));
113 }
114 }
115
FeatureSparseCopy(const std::size_t batch,const string & key,const DataType & dtype,const Feature & feature)116 Tensor FeatureSparseCopy(const std::size_t batch, const string& key,
117 const DataType& dtype, const Feature& feature) {
118 switch (dtype) {
119 case DT_INT64: {
120 const Int64List& values = feature.int64_list();
121 const int64 num_elements = values.value_size();
122 Tensor out(dtype, TensorShape({num_elements}));
123 auto out_p = out.flat<int64>().data();
124 std::copy_n(values.value().data(), num_elements, out_p);
125 return out;
126 }
127 case DT_FLOAT: {
128 const FloatList& values = feature.float_list();
129 const int64 num_elements = values.value_size();
130 Tensor out(dtype, TensorShape({num_elements}));
131 auto out_p = out.flat<float>().data();
132 std::copy_n(values.value().data(), num_elements, out_p);
133 return out;
134 }
135 case DT_STRING: {
136 const BytesList& values = feature.bytes_list();
137 const int64 num_elements = values.value_size();
138 Tensor out(dtype, TensorShape({num_elements}));
139 auto out_p = out.flat<string>().data();
140 std::transform(values.value().data(),
141 values.value().data() + num_elements, out_p,
142 [](const string* s) { return *s; });
143 return out;
144 }
145 default:
146 LOG(FATAL) << "not supposed to be here. dtype requested: " << dtype;
147 }
148 }
149
// Appends the rank-1 tensor `in` (one example's sparse values) into the
// batched COO representation, starting at row `offset` of `indices` and
// `values`. Returns the number of elements written. Unsupported dtypes
// abort the process.
int64 CopyIntoSparseTensor(const Tensor& in, const int batch,
                           const int64 offset, Tensor* indices,
                           Tensor* values) {
  const int64 n = in.shape().num_elements();
  const DataType& dtype = in.dtype();
  CHECK_EQ(dtype, values->dtype());

  // Fill rows [offset, offset + n) of the indices matrix: column 0 holds
  // the batch entry, column 1 the position within that example.
  auto ix = indices->matrix<int64>();
  for (int64 i = 0; i < n; ++i) {
    ix(offset + i, 0) = batch;
    ix(offset + i, 1) = i;
  }

  // Copy the values into the matching slice of the batched values vector.
  switch (dtype) {
    case DT_INT64:
      std::copy_n(in.flat<int64>().data(), n,
                  values->flat<int64>().data() + offset);
      break;
    case DT_FLOAT:
      std::copy_n(in.flat<float>().data(), n,
                  values->flat<float>().data() + offset);
      break;
    case DT_STRING:
      std::copy_n(in.flat<string>().data(), n,
                  values->flat<string>().data() + offset);
      break;
    default:
      LOG(FATAL) << "Not supposed to be here. Saw dtype: " << dtype;
  }

  return n;
}
188
RowDenseCopy(const std::size_t & out_index,const DataType & dtype,const Tensor & in,Tensor * out)189 void RowDenseCopy(const std::size_t& out_index, const DataType& dtype,
190 const Tensor& in, Tensor* out) {
191 const std::size_t num_elements = in.shape().num_elements();
192 const std::size_t offset = out_index * num_elements;
193
194 switch (dtype) {
195 case DT_INT64: {
196 std::copy_n(in.flat<int64>().data(), num_elements,
197 out->flat<int64>().data() + offset);
198 break;
199 }
200 case DT_FLOAT: {
201 std::copy_n(in.flat<float>().data(), num_elements,
202 out->flat<float>().data() + offset);
203 break;
204 }
205 case DT_STRING: {
206 std::copy_n(in.flat<string>().data(), num_elements,
207 out->flat<string>().data() + offset);
208 break;
209 }
210 default:
211 LOG(FATAL) << "Not supposed to be here. Saw dtype: " << dtype;
212 }
213 }
214
// Parses one Example proto into row `batch_index` of pre-allocated batch
// outputs.
//
// Dense (fixed-length) features: for each config d, a feature that is
// present and type-correct is copied into (*output_dense_values_tensor)[d]
// at row batch_index; a feature that is absent falls back to the config's
// default value, unless the default is empty (feature required), which is
// an InvalidArgument error.
//
// Sparse (variable-length) features: each present feature is copied into a
// per-feature, per-example temporary tensor
// (*output_sparse_values_tmp)[d][batch_index]; an absent feature yields an
// empty rank-1 tensor. The temporaries are batched later (see
// BatchExampleProtoToTensors).
Status SingleExampleProtoToTensors(
    const Example& example, const string& example_name, const int batch_index,
    const std::vector<FixedLenFeature>& fixed_len_features,
    const std::vector<VarLenFeature>& var_len_features,
    std::vector<Tensor*>* output_dense_values_tensor,
    std::vector<std::vector<Tensor>>* output_sparse_values_tmp) {
  const Features& features = example.features();
  const auto& feature_dict = features.feature();

  // Handle dense features.
  for (size_t d = 0; d < fixed_len_features.size(); ++d) {
    const FixedLenFeature& feature_config = fixed_len_features[d];
    const string& key = feature_config.key;
    const DataType& dtype = feature_config.dtype;
    const TensorShape& shape = feature_config.shape;
    const Tensor& default_value = feature_config.default_value;
    // An empty default marks the feature as required in every example.
    bool required = (default_value.NumElements() == 0);
    const auto& feature_found = feature_dict.find(key);
    const bool feature_has_data = // Found key & data type is set
        (feature_found != feature_dict.end() &&
         (feature_found->second.kind_case() != Feature::KIND_NOT_SET));

    const bool required_ok = feature_has_data || !required;
    if (!required_ok) {
      return errors::InvalidArgument("Name: ", example_name, ", Feature: ", key,
                                     " is required but could not be found.");
    }

    // Perform the FeatureDenseCopy into the output dense_values tensor (if
    // the value is present).
    if (feature_has_data) {
      const Feature& f = feature_found->second;
      bool types_match;
      TF_RETURN_IF_ERROR(CheckTypesMatch(f, dtype, &types_match));
      if (!types_match) {
        return errors::InvalidArgument("Name: ", example_name,
                                       ", Feature: ", key,
                                       ". Data types don't match. ",
                                       "Expected type: ", DataTypeString(dtype),
                                       " Feature is: ", ProtoDebugString(f));
      }
      TF_RETURN_IF_ERROR(FeatureDenseCopy(batch_index, example_name, key, dtype,
                                          shape, f,
                                          (*output_dense_values_tensor)[d]));
    } else {
      // If the value is missing, RowDenseCopy the default value.
      RowDenseCopy(batch_index, dtype, default_value,
                   (*output_dense_values_tensor)[d]);
    }
  }

  // Handle sparse features.
  for (size_t d = 0; d < var_len_features.size(); ++d) {
    const VarLenFeature& feature_config = var_len_features[d];
    const string& key = feature_config.key;
    const DataType& dtype = feature_config.dtype;
    const auto& feature_found = feature_dict.find(key);

    const bool feature_has_data = // Found key & data type is set
        (feature_found != feature_dict.end() &&
         (feature_found->second.kind_case() != Feature::KIND_NOT_SET));

    if (feature_has_data) {
      const Feature& f = feature_found->second;
      bool types_match;
      TF_RETURN_IF_ERROR(CheckTypesMatch(f, dtype, &types_match));
      if (!types_match) {
        return errors::InvalidArgument("Name: ", example_name,
                                       ", Feature: ", key,
                                       ". Data types don't match. ",
                                       "Expected type: ", DataTypeString(dtype),
                                       " Feature is: ", ProtoDebugString(f));
      }
      (*output_sparse_values_tmp)[d][batch_index] =
          FeatureSparseCopy(batch_index, key, dtype, f);
    } else {
      // Absent sparse feature: record an empty slice for this example.
      (*output_sparse_values_tmp)[d][batch_index] =
          Tensor(dtype, TensorShape({0}));
    }
  }
  return Status::OK();
}
297
GetSparseTensorShapes(const VarLenFeature & var_len_feature,const std::vector<Tensor> & sparse_values_tmp,const int batch_size,VarLenFeatureBatchShapes * output_shapes)298 Status GetSparseTensorShapes(const VarLenFeature& var_len_feature,
299 const std::vector<Tensor>& sparse_values_tmp,
300 const int batch_size,
301 VarLenFeatureBatchShapes* output_shapes) {
302 int64 total_num_features = 0;
303 int64 max_num_features = 0;
304 for (int b = 0; b < batch_size; ++b) {
305 const Tensor& t = sparse_values_tmp[b];
306 const int64 num_elements = t.shape().num_elements();
307 total_num_features += num_elements;
308 max_num_features = std::max(max_num_features, num_elements);
309 }
310 output_shapes->indices_shape.AddDim(total_num_features);
311 output_shapes->indices_shape.AddDim(2);
312 output_shapes->values_shape.AddDim(total_num_features);
313 output_shapes->max_num_features = max_num_features;
314 return Status::OK();
315 }
316
// Parses a batch of Example protos into batched dense and sparse output
// tensors.
//
// Pipeline:
//   1. Validate that `names` (if non-empty) matches `examples` in length.
//   2. Pre-allocate one dense output tensor per fixed-length feature with
//      shape [batch_size, config.shape...].
//   3. Parse every example row-by-row via SingleExampleProtoToTensors,
//      collecting variable-length values in per-feature/per-example
//      temporaries.
//   4. For each variable-length feature, size and allocate the COO triple
//      (indices [N, 2], values [N], shape [2] = {batch_size,
//      max_num_features}) and copy the temporaries in, batch by batch.
//
// All output vectors must already be sized to the corresponding feature
// counts; allocation of the tensors themselves happens here via
// `allocator`.
Status BatchExampleProtoToTensors(
    const std::vector<const Example*>& examples,
    const std::vector<string>& names,
    const std::vector<FixedLenFeature>& fixed_len_features,
    const std::vector<VarLenFeature>& var_len_features, Allocator* allocator,
    std::vector<Tensor>* output_dense_values_tensor,
    std::vector<Tensor>* output_sparse_indices_tensor,
    std::vector<Tensor>* output_sparse_values_tensor,
    std::vector<Tensor>* output_sparse_shapes_tensor) {
  const int batch_size = examples.size();

  const bool has_names = (!names.empty());
  if (has_names) {
    if (names.size() != examples.size()) {
      return errors::InvalidArgument(
          "Expected len(names) == len(examples), but got: ", names.size(),
          " vs. ", examples.size());
    }
  }

  // We also need a map of Tensor pointers for the SingleExampleProtoToTensors
  // call. (Is there a better solution here?)
  std::vector<Tensor*> output_dense_values_tensor_ptrs(
      fixed_len_features.size());

  // Preallocate dense_values, since we know their sizes.
  for (size_t d = 0; d < fixed_len_features.size(); ++d) {
    const FixedLenFeature& config = fixed_len_features[d];
    TensorShape out_shape;
    // Batched shape is [batch_size] + per-example shape.
    out_shape.AddDim(batch_size);
    const TensorShape& shape = config.shape;
    const DataType& dtype = config.dtype;
    for (const int dim : shape.dim_sizes()) out_shape.AddDim(dim);
    (*output_dense_values_tensor)[d] = Tensor(allocator, dtype, out_shape);
    output_dense_values_tensor_ptrs[d] = &(*output_dense_values_tensor)[d];
  }

  // Temporary vector to hold sparse values.
  std::vector<std::vector<Tensor>> sparse_values_tmp(var_len_features.size());

  for (size_t d = 0; d < var_len_features.size(); ++d) {
    sparse_values_tmp[d] = std::vector<Tensor>(batch_size);
  }

  // Parse each example into row b of the dense outputs and slot b of the
  // sparse temporaries.
  for (size_t b = 0; b < examples.size(); ++b) {
    const Example& ex = *(examples[b]);
    const string& example_name = (has_names) ? names[b] : "<unknown>";
    TF_RETURN_IF_ERROR(SingleExampleProtoToTensors(
        ex, example_name, b, fixed_len_features, var_len_features,
        &output_dense_values_tensor_ptrs, &sparse_values_tmp));
  }

  // Assemble the batched sparse outputs from the temporaries.
  for (size_t d = 0; d < var_len_features.size(); ++d) {
    const VarLenFeature& feature_config = var_len_features[d];
    const DataType& dtype = feature_config.dtype;
    const std::vector<Tensor>& sparse_values_tensor = sparse_values_tmp[d];

    VarLenFeatureBatchShapes sparse_tensor_batch_shapes;
    TF_RETURN_IF_ERROR(GetSparseTensorShapes(feature_config,
                                             sparse_values_tensor, batch_size,
                                             &sparse_tensor_batch_shapes));
    const TensorShape& indices_shape = sparse_tensor_batch_shapes.indices_shape;
    const TensorShape& values_shape = sparse_tensor_batch_shapes.values_shape;

    // Allocate the sparse indices here.
    (*output_sparse_indices_tensor)[d] =
        Tensor(allocator, DT_INT64, indices_shape);
    (*output_sparse_values_tensor)[d] = Tensor(allocator, dtype, values_shape);
    (*output_sparse_shapes_tensor)[d] =
        Tensor(allocator, DT_INT64, TensorShape({2}));

    // Dense shape of the SparseTensor: [batch_size, max row length].
    auto shape_t = (*output_sparse_shapes_tensor)[d].vec<int64>();
    shape_t(0) = batch_size;
    shape_t(1) = sparse_tensor_batch_shapes.max_num_features;

    Tensor* sp_indices_d = &(*output_sparse_indices_tensor)[d];
    Tensor* sp_values_d = &(*output_sparse_values_tensor)[d];

    // Concatenate each example's values; `offset` tracks the running row.
    int64 offset = 0;
    for (int b = 0; b < batch_size; ++b) {
      const int64 num_elements = CopyIntoSparseTensor(
          sparse_values_tensor[b], b, offset, sp_indices_d, sp_values_d);
      offset += num_elements;
    }
  }
  return Status::OK();
}
404
FinishInit()405 Status ParseExampleAttrs::FinishInit() {
406 if (static_cast<size_t>(num_sparse) != sparse_types.size()) {
407 return errors::InvalidArgument("len(sparse_keys) != len(sparse_types)");
408 }
409 if (static_cast<size_t>(num_dense) != dense_types.size()) {
410 return errors::InvalidArgument("len(dense_keys) != len(dense_types)");
411 }
412 if (static_cast<size_t>(num_dense) != dense_shapes.size()) {
413 return errors::InvalidArgument("len(dense_keys) != len(dense_shapes)");
414 }
415 if (num_dense > std::numeric_limits<int32>::max()) {
416 return errors::InvalidArgument("num_dense_ too large");
417 }
418 for (const DataType& type : dense_types) {
419 TF_RETURN_IF_ERROR(CheckValidType(type));
420 }
421 for (const DataType& type : sparse_types) {
422 TF_RETURN_IF_ERROR(CheckValidType(type));
423 }
424 return Status::OK();
425 }
426
FinishInit()427 Status ParseSingleExampleAttrs::FinishInit() {
428 if (sparse_keys.size() != sparse_types.size()) {
429 return errors::InvalidArgument("len(sparse_keys) != len(sparse_types)");
430 }
431 if (dense_keys.size() != dense_types.size()) {
432 return errors::InvalidArgument("len(dense_keys) != len(dense_types)");
433 }
434 if (dense_keys.size() != dense_shapes.size()) {
435 return errors::InvalidArgument("len(dense_keys) != len(dense_shapes)");
436 }
437 for (const DataType& type : dense_types) {
438 TF_RETURN_IF_ERROR(CheckValidType(type));
439 }
440 for (const DataType& type : sparse_types) {
441 TF_RETURN_IF_ERROR(CheckValidType(type));
442 }
443 return Status::OK();
444 }
445
// Validates the attrs parsed for the ParseSequenceExample op: every declared
// count (context sparse/dense, feature_list sparse/dense) must match the
// sizes of the corresponding keys/types/shapes vectors, and every declared
// dtype must be a supported Example type.
Status ParseSequenceExampleAttrs::FinishInit() {
  if (num_context_sparse != context_sparse_keys.size() ||
      num_context_sparse != context_sparse_types.size()) {
    return errors::InvalidArgument(
        "num_context_sparse (", num_context_sparse,
        ") must match the size of context_sparse_keys (",
        context_sparse_keys.size(), ") and context_sparse_types (",
        context_sparse_types.size(), ")");
  }
  if (num_context_dense != context_dense_keys.size() ||
      num_context_dense != context_dense_types.size() ||
      num_context_dense != context_dense_shapes.size()) {
    return errors::InvalidArgument(
        "num_context_dense (", num_context_dense,
        ") must match the size of context_dense_keys (",
        context_dense_keys.size(), "), context_dense_types (",
        context_dense_types.size(), ") and context_dense_shapes (",
        context_dense_shapes.size(), ")");
  }
  if (num_feature_list_sparse != feature_list_sparse_keys.size() ||
      num_feature_list_sparse != feature_list_sparse_types.size()) {
    return errors::InvalidArgument(
        "num_feature_list_sparse (", num_feature_list_sparse,
        ") must match the size of feature_list_sparse_keys (",
        feature_list_sparse_keys.size(), ") and feature_list_sparse_types (",
        feature_list_sparse_types.size(), ")");
  }
  if (num_feature_list_dense != feature_list_dense_keys.size() ||
      num_feature_list_dense != feature_list_dense_types.size() ||
      num_feature_list_dense != feature_list_dense_shapes.size()) {
    return errors::InvalidArgument(
        "num_feature_list_dense (", num_feature_list_dense,
        ") must match the size of feature_list_dense_keys (",
        feature_list_dense_keys.size(), "), feature_list_dense_types (",
        feature_list_dense_types.size(), ") and feature_list_dense_shapes (",
        feature_list_dense_shapes.size(), ")");
  }
  // All declared dtypes must be supported Example element types.
  for (const DataType& type : context_dense_types) {
    TF_RETURN_IF_ERROR(CheckValidType(type));
  }
  for (const DataType& type : context_sparse_types) {
    TF_RETURN_IF_ERROR(CheckValidType(type));
  }
  for (const DataType& type : feature_list_dense_types) {
    TF_RETURN_IF_ERROR(CheckValidType(type));
  }
  for (const DataType& type : feature_list_sparse_types) {
    TF_RETURN_IF_ERROR(CheckValidType(type));
  }

  return Status::OK();
}
498
FinishInit()499 Status ParseSingleSequenceExampleAttrs::FinishInit() {
500 if (static_cast<size_t>(num_context_sparse) != context_sparse_types.size()) {
501 return errors::InvalidArgument(
502 "len(context_sparse_keys) != len(context_sparse_types)");
503 }
504 if (static_cast<size_t>(num_context_dense) != context_dense_types.size()) {
505 return errors::InvalidArgument(
506 "len(context_dense_keys) != len(context_dense_types)");
507 }
508 if (static_cast<size_t>(num_context_dense) != context_dense_shapes.size()) {
509 return errors::InvalidArgument(
510 "len(context_dense_keys) != len(context_dense_shapes)");
511 }
512 if (static_cast<size_t>(num_feature_list_sparse) !=
513 feature_list_sparse_types.size()) {
514 return errors::InvalidArgument(
515 "len(feature_list_sparse_keys) != len(feature_list_sparse_types)");
516 }
517 if (static_cast<size_t>(num_feature_list_dense) !=
518 feature_list_dense_types.size()) {
519 return errors::InvalidArgument(
520 "len(feature_list_dense_keys) != "
521 "len(feature_list_dense_types)");
522 }
523 for (const DataType& type : context_dense_types) {
524 TF_RETURN_IF_ERROR(CheckValidType(type));
525 }
526 for (const DataType& type : context_sparse_types) {
527 TF_RETURN_IF_ERROR(CheckValidType(type));
528 }
529 for (const DataType& type : feature_list_dense_types) {
530 TF_RETURN_IF_ERROR(CheckValidType(type));
531 }
532 for (const DataType& type : feature_list_sparse_types) {
533 TF_RETURN_IF_ERROR(CheckValidType(type));
534 }
535 return Status::OK();
536 }
537
538 } // namespace tensorflow
539