• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 // See docs in ../ops/parsing_ops.cc.
17 
18 #include <numeric>
19 #include <unordered_set>
20 #include <vector>
21 
22 #include "tensorflow/core/example/example.pb.h"
23 #include "tensorflow/core/example/feature.pb_text.h"
24 #include "tensorflow/core/framework/common_shape_fns.h"
25 #include "tensorflow/core/framework/numeric_op.h"
26 #include "tensorflow/core/framework/register_types.h"
27 #include "tensorflow/core/lib/gtl/array_slice.h"
28 #include "tensorflow/core/platform/logging.h"
29 #include "tensorflow/core/platform/protobuf.h"
30 #include "tensorflow/core/util/example_proto_fast_parsing.h"
31 #include "tensorflow/core/util/example_proto_helper.h"
32 #include "tensorflow/core/util/sparse/sparse_tensor.h"
33 #include "tensorflow/core/util/work_sharder.h"
34 
35 namespace tensorflow {
36 
37 class ParseExampleOp : public OpKernel {
38  public:
ParseExampleOp(OpKernelConstruction * ctx)39   explicit ParseExampleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
40     OP_REQUIRES_OK(ctx, attrs_.Init(ctx));
41   }
42 
Compute(OpKernelContext * ctx)43   void Compute(OpKernelContext* ctx) override {
44     const Tensor* names;
45     const Tensor* serialized;
46     OpInputList dense_keys;
47     OpInputList sparse_keys;
48     OpInputList dense_defaults;
49 
50     // Grab the input list arguments.
51     OP_REQUIRES_OK(ctx, ctx->input("names", &names));
52     OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized));
53     OP_REQUIRES_OK(ctx, ctx->input_list("dense_keys", &dense_keys));
54     OP_REQUIRES_OK(ctx, ctx->input_list("sparse_keys", &sparse_keys));
55     OP_REQUIRES_OK(ctx, ctx->input_list("dense_defaults", &dense_defaults));
56 
57     std::vector<string> dense_keys_t(attrs_.num_dense);
58     std::vector<string> sparse_keys_t(attrs_.num_sparse);
59 
60     // Check that the input list sizes match the attribute declared sizes.
61     CHECK_EQ(dense_keys.size(), attrs_.num_dense);
62     CHECK_EQ(sparse_keys.size(), attrs_.num_sparse);
63 
64     // Copy from OpInputList to std::vector<string>.
65     for (int di = 0; di < attrs_.num_dense; ++di) {
66       dense_keys_t[di] = dense_keys[di].scalar<string>()();
67     }
68     for (int di = 0; di < attrs_.num_sparse; ++di) {
69       sparse_keys_t[di] = sparse_keys[di].scalar<string>()();
70     }
71 
72     if (names->NumElements() > 0) {
73       OP_REQUIRES(
74           ctx, TensorShapeUtils::IsVector(names->shape()),
75           errors::InvalidArgument("Expected names to be a vector, got shape: ",
76                                   names->shape().DebugString()));
77       OP_REQUIRES(
78           ctx, names->NumElements() == serialized->NumElements(),
79           errors::InvalidArgument(
80               "Expected len(names) == len(serialized), but got: ",
81               names->NumElements(), " vs. ", serialized->NumElements()));
82     }
83 
84     OP_REQUIRES(ctx, TensorShapeUtils::IsVector(serialized->shape()),
85                 errors::InvalidArgument(
86                     "Expected serialized to be a vector, got shape: ",
87                     serialized->shape().DebugString()));
88     OP_REQUIRES(ctx, dense_defaults.size() == attrs_.num_dense,
89                 errors::InvalidArgument(
90                     "Expected len(dense_defaults) == len(dense_keys) but got: ",
91                     dense_defaults.size(), " vs. ", attrs_.num_dense));
92 
93     for (int d = 0; d < static_cast<int>(attrs_.num_dense); ++d) {
94       const Tensor& def_value = dense_defaults[d];
95       if (attrs_.variable_length[d]) {
96         OP_REQUIRES(ctx, def_value.NumElements() == 1,
97                     errors::InvalidArgument(
98                         "dense_shape[", d, "] is a variable length shape: ",
99                         attrs_.dense_shapes[d].DebugString(),
100                         ", therefore "
101                         "def_value[",
102                         d,
103                         "] must contain a single element ("
104                         "the padding element).  But its shape is: ",
105                         def_value.shape().DebugString()));
106       } else if (def_value.NumElements() > 0) {
107         OP_REQUIRES(ctx,
108                     attrs_.dense_shapes[d].IsCompatibleWith(def_value.shape()),
109                     errors::InvalidArgument(
110                         "def_value[", d,
111                         "].shape() == ", def_value.shape().DebugString(),
112                         " is not compatible with dense_shapes_[", d,
113                         "] == ", attrs_.dense_shapes[d].DebugString()));
114       }
115       OP_REQUIRES(ctx, def_value.dtype() == attrs_.dense_types[d],
116                   errors::InvalidArgument(
117                       "dense_defaults[", d, "].dtype() == ",
118                       DataTypeString(def_value.dtype()), " != dense_types_[", d,
119                       "] == ", DataTypeString(attrs_.dense_types[d])));
120     }
121 
122     example::Result result;
123 
124     example::FastParseExampleConfig config;
125     for (int d = 0; d < attrs_.num_dense; ++d) {
126       config.dense.push_back({dense_keys_t[d], attrs_.dense_types[d],
127                               attrs_.dense_shapes[d], dense_defaults[d],
128                               attrs_.variable_length[d],
129                               attrs_.elements_per_stride[d]});
130     }
131     for (int d = 0; d < attrs_.num_sparse; ++d) {
132       config.sparse.push_back({sparse_keys_t[d], attrs_.sparse_types[d]});
133     }
134 
135     auto serialized_t = serialized->flat<string>();
136     auto names_t = names->flat<string>();
137     gtl::ArraySlice<string> slice(serialized_t.data(), serialized_t.size());
138     gtl::ArraySlice<string> names_slice(names_t.data(), names_t.size());
139 
140     OP_REQUIRES_OK(
141         ctx,
142         FastParseExample(
143             config, slice, names_slice,
144             ctx->device()->tensorflow_cpu_worker_threads()->workers, &result));
145 
146     OpOutputList dense_values;
147     OpOutputList sparse_indices;
148     OpOutputList sparse_values;
149     OpOutputList sparse_shapes;
150     OP_REQUIRES_OK(ctx, ctx->output_list("dense_values", &dense_values));
151     OP_REQUIRES_OK(ctx, ctx->output_list("sparse_indices", &sparse_indices));
152     OP_REQUIRES_OK(ctx, ctx->output_list("sparse_values", &sparse_values));
153     OP_REQUIRES_OK(ctx, ctx->output_list("sparse_shapes", &sparse_shapes));
154     for (int d = 0; d < attrs_.num_dense; ++d) {
155       dense_values.set(d, result.dense_values[d]);
156     }
157     for (int d = 0; d < attrs_.num_sparse; ++d) {
158       sparse_indices.set(d, result.sparse_indices[d]);
159       sparse_values.set(d, result.sparse_values[d]);
160       sparse_shapes.set(d, result.sparse_shapes[d]);
161     }
162   }
163 
164  protected:
165   ParseExampleAttrs attrs_;
166 };
167 
168 REGISTER_KERNEL_BUILDER(Name("ParseExample").Device(DEVICE_CPU),
169                         ParseExampleOp);
170 
171 class ParseSingleExampleOp : public OpKernel {
172  public:
ParseSingleExampleOp(OpKernelConstruction * ctx)173   explicit ParseSingleExampleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
174     OP_REQUIRES_OK(ctx, attrs_.Init(ctx));
175   }
176 
Compute(OpKernelContext * ctx)177   void Compute(OpKernelContext* ctx) override {
178     const Tensor* serialized;
179     OpInputList dense_defaults;
180 
181     // Grab the input list arguments.
182     OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized));
183     OP_REQUIRES_OK(ctx, ctx->input_list("dense_defaults", &dense_defaults));
184 
185     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(serialized->shape()),
186                 errors::InvalidArgument(
187                     "Expected serialized to be a scalar, got shape: ",
188                     serialized->shape().DebugString()));
189     OP_REQUIRES(ctx, dense_defaults.size() == attrs_.dense_keys.size(),
190                 errors::InvalidArgument(
191                     "Expected len(dense_defaults) == len(dense_keys) but got: ",
192                     dense_defaults.size(), " vs. ", attrs_.dense_keys.size()));
193 
194     for (size_t d = 0; d < attrs_.dense_keys.size(); ++d) {
195       const Tensor& def_value = dense_defaults[d];
196       if (attrs_.variable_length[d]) {
197         OP_REQUIRES(ctx, def_value.NumElements() == 1,
198                     errors::InvalidArgument(
199                         "dense_shape[", d, "] is a variable length shape: ",
200                         attrs_.dense_shapes[d].DebugString(),
201                         ", therefore "
202                         "def_value[",
203                         d,
204                         "] must contain a single element ("
205                         "the padding element).  But its shape is: ",
206                         def_value.shape().DebugString()));
207       } else if (def_value.NumElements() > 0) {
208         OP_REQUIRES(ctx,
209                     attrs_.dense_shapes[d].IsCompatibleWith(def_value.shape()),
210                     errors::InvalidArgument(
211                         "def_value[", d,
212                         "].shape() == ", def_value.shape().DebugString(),
213                         " is not compatible with dense_shapes_[", d,
214                         "] == ", attrs_.dense_shapes[d].DebugString()));
215       }
216       OP_REQUIRES(ctx, def_value.dtype() == attrs_.dense_types[d],
217                   errors::InvalidArgument(
218                       "dense_defaults[", d, "].dtype() == ",
219                       DataTypeString(def_value.dtype()), " != dense_types_[", d,
220                       "] == ", DataTypeString(attrs_.dense_types[d])));
221     }
222 
223     example::Result result;
224 
225     // TODO(mrry): Build the configuration once and cache it.
226     example::FastParseExampleConfig config;
227     for (int d = 0; d < attrs_.dense_keys.size(); ++d) {
228       config.dense.push_back({attrs_.dense_keys[d], attrs_.dense_types[d],
229                               attrs_.dense_shapes[d], dense_defaults[d],
230                               attrs_.variable_length[d],
231                               attrs_.elements_per_stride[d]});
232     }
233     for (int d = 0; d < attrs_.sparse_keys.size(); ++d) {
234       config.sparse.push_back({attrs_.sparse_keys[d], attrs_.sparse_types[d]});
235     }
236 
237     const string& serialized_proto = serialized->scalar<string>()();
238 
239     OP_REQUIRES_OK(ctx,
240                    FastParseSingleExample(config, serialized_proto, &result));
241 
242     OpOutputList dense_values;
243     OpOutputList sparse_indices;
244     OpOutputList sparse_values;
245     OpOutputList sparse_shapes;
246     OP_REQUIRES_OK(ctx, ctx->output_list("dense_values", &dense_values));
247     OP_REQUIRES_OK(ctx, ctx->output_list("sparse_indices", &sparse_indices));
248     OP_REQUIRES_OK(ctx, ctx->output_list("sparse_values", &sparse_values));
249     OP_REQUIRES_OK(ctx, ctx->output_list("sparse_shapes", &sparse_shapes));
250     for (int d = 0; d < attrs_.dense_keys.size(); ++d) {
251       dense_values.set(d, result.dense_values[d]);
252     }
253     for (int d = 0; d < attrs_.sparse_keys.size(); ++d) {
254       sparse_indices.set(d, result.sparse_indices[d]);
255       sparse_values.set(d, result.sparse_values[d]);
256       sparse_shapes.set(d, result.sparse_shapes[d]);
257     }
258   }
259 
260  protected:
261   ParseSingleExampleAttrs attrs_;
262 };
263 
264 REGISTER_KERNEL_BUILDER(Name("ParseSingleExample").Device(DEVICE_CPU),
265                         ParseSingleExampleOp);
266 
267 class ParseSequenceExampleOp : public OpKernel {
268  public:
ParseSequenceExampleOp(OpKernelConstruction * ctx)269   explicit ParseSequenceExampleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
270     OP_REQUIRES_OK(ctx, attrs_.Init(ctx));
271   }
272 
Compute(OpKernelContext * ctx)273   void Compute(OpKernelContext* ctx) override {
274     const Tensor* debug_name;
275     const Tensor* serialized;
276     OpInputList context_dense_defaults;
277 
278     OP_REQUIRES_OK(ctx, ctx->input("debug_name", &debug_name));
279     OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized));
280     OP_REQUIRES_OK(ctx, ctx->input_list("context_dense_defaults",
281                                         &context_dense_defaults));
282 
283     bool has_debug_name = (debug_name->NumElements() > 0);
284     if (has_debug_name) {
285       OP_REQUIRES(ctx, TensorShapeUtils::IsVector(debug_name->shape()),
286                   errors::InvalidArgument(
287                       "Expected debug_name to be a vector, got shape: ",
288                       debug_name->shape().DebugString()));
289     }
290 
291     OP_REQUIRES(ctx, TensorShapeUtils::IsVector(serialized->shape()),
292                 errors::InvalidArgument(
293                     "Expected serialized to be a vector, got shape: ",
294                     serialized->shape().DebugString()));
295 
296     OP_REQUIRES(ctx, context_dense_defaults.size() == attrs_.num_context_dense,
297                 errors::InvalidArgument("Expected len(context_dense_defaults) "
298                                         "== len(context_dense_keys) but got: ",
299                                         context_dense_defaults.size(), " vs. ",
300                                         attrs_.num_context_dense));
301 
302     std::vector<bool> required(attrs_.num_context_dense);
303     for (int d = 0; d < attrs_.num_context_dense; ++d) {
304       const Tensor& def_value = context_dense_defaults[d];
305       required[d] = (def_value.NumElements() == 0);  // No default provided.
306 
307       if (def_value.NumElements() > 0) {
308         OP_REQUIRES(ctx, def_value.shape() == attrs_.context_dense_shapes[d],
309                     errors::InvalidArgument(
310                         "default_value[", d,
311                         "].shape() == ", def_value.shape().DebugString(),
312                         " != context_dense_shapes[", d,
313                         "] == ", attrs_.context_dense_shapes[d].DebugString()));
314         OP_REQUIRES(
315             ctx, def_value.dtype() == attrs_.context_dense_types[d],
316             errors::InvalidArgument(
317                 "context_dense_defaults[", d, "].dtype() == ",
318                 DataTypeString(def_value.dtype()), " != context_dense_types[",
319                 d, "] == ", DataTypeString(attrs_.context_dense_types[d])));
320       }
321     }
322 
323     example::Result context_result, feature_list_result;
324     std::vector<Tensor> dense_feature_lengths;
325 
326     example::FastParseExampleConfig context_config;
327     for (int d = 0; d < attrs_.num_context_dense; ++d) {
328       context_config.dense.push_back(
329           {attrs_.context_dense_keys[d], attrs_.context_dense_types[d],
330            attrs_.context_dense_shapes[d], context_dense_defaults[d],
331            false /* attrs_.context_variable_length[d] */,
332            0 /*attrs_.context_elements_per_stride[d] */});
333     }
334     for (int d = 0; d < attrs_.num_context_sparse; ++d) {
335       context_config.sparse.push_back(
336           {attrs_.context_sparse_keys[d], attrs_.context_sparse_types[d]});
337     }
338     example::FastParseExampleConfig feature_list_config;
339     for (int d = 0; d < attrs_.num_feature_list_dense; ++d) {
340       DataType dtype = attrs_.feature_list_dense_types[d];
341       Tensor default_value = Tensor(dtype, TensorShape({}));
342       feature_list_config.dense.push_back(
343           {attrs_.feature_list_dense_keys[d], dtype,
344            attrs_.feature_list_dense_shapes[d], default_value,
345            (attrs_.feature_list_dense_missing_assumed_empty.count(
346                 attrs_.feature_list_dense_keys[d]) > 0),
347            0 /*attrs_.context_elements_per_stride[d] */});
348     }
349     for (int d = 0; d < attrs_.num_feature_list_sparse; ++d) {
350       feature_list_config.sparse.push_back(
351           {attrs_.feature_list_sparse_keys[d],
352            attrs_.feature_list_sparse_types[d]});
353     }
354 
355     auto serialized_t = serialized->flat<string>();
356     auto debug_name_t = debug_name->flat<string>();
357     gtl::ArraySlice<string> slice(serialized_t.data(), serialized_t.size());
358     gtl::ArraySlice<string> names_slice(debug_name_t.data(),
359                                         debug_name_t.size());
360 
361     OP_REQUIRES_OK(
362         ctx,
363         FastParseSequenceExample(
364             context_config, feature_list_config, slice, names_slice,
365             ctx->device()->tensorflow_cpu_worker_threads()->workers,
366             &context_result, &feature_list_result, &dense_feature_lengths));
367 
368     OpOutputList context_sparse_indices;
369     OpOutputList context_sparse_values;
370     OpOutputList context_sparse_shapes;
371     OpOutputList context_dense_values;
372     OpOutputList feature_list_sparse_indices;
373     OpOutputList feature_list_sparse_values;
374     OpOutputList feature_list_sparse_shapes;
375     OpOutputList feature_list_dense_values;
376     OpOutputList feature_list_dense_lengths;
377 
378     OP_REQUIRES_OK(ctx, ctx->output_list("context_sparse_indices",
379                                          &context_sparse_indices));
380     OP_REQUIRES_OK(
381         ctx, ctx->output_list("context_sparse_values", &context_sparse_values));
382     OP_REQUIRES_OK(
383         ctx, ctx->output_list("context_sparse_shapes", &context_sparse_shapes));
384     OP_REQUIRES_OK(
385         ctx, ctx->output_list("context_dense_values", &context_dense_values));
386     OP_REQUIRES_OK(ctx, ctx->output_list("context_sparse_indices",
387                                          &context_sparse_indices));
388     OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_indices",
389                                          &feature_list_sparse_indices));
390     OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_values",
391                                          &feature_list_sparse_values));
392     OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_shapes",
393                                          &feature_list_sparse_shapes));
394     OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_dense_values",
395                                          &feature_list_dense_values));
396     OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_dense_lengths",
397                                          &feature_list_dense_lengths));
398     for (int d = 0; d < attrs_.num_context_dense; ++d) {
399       context_dense_values.set(d, context_result.dense_values[d]);
400     }
401     TensorShape lengths_shape;
402     lengths_shape.AddDim(serialized_t.size());
403     for (int d = 0; d < attrs_.num_feature_list_dense; ++d) {
404       feature_list_dense_values.set(d, feature_list_result.dense_values[d]);
405       feature_list_dense_lengths.set(d, dense_feature_lengths[d]);
406     }
407     for (int d = 0; d < attrs_.num_context_sparse; ++d) {
408       context_sparse_indices.set(d, context_result.sparse_indices[d]);
409       context_sparse_values.set(d, context_result.sparse_values[d]);
410       context_sparse_shapes.set(d, context_result.sparse_shapes[d]);
411     }
412     for (int d = 0; d < attrs_.num_feature_list_sparse; ++d) {
413       feature_list_sparse_indices.set(d, feature_list_result.sparse_indices[d]);
414       feature_list_sparse_values.set(d, feature_list_result.sparse_values[d]);
415       feature_list_sparse_shapes.set(d, feature_list_result.sparse_shapes[d]);
416     }
417   }
418 
419  protected:
420   ParseSequenceExampleAttrs attrs_;
421 };
422 
423 REGISTER_KERNEL_BUILDER(Name("ParseSequenceExample").Device(DEVICE_CPU),
424                         ParseSequenceExampleOp);
425 
426 class ParseSingleSequenceExampleOp : public OpKernel {
427  public:
ParseSingleSequenceExampleOp(OpKernelConstruction * ctx)428   explicit ParseSingleSequenceExampleOp(OpKernelConstruction* ctx)
429       : OpKernel(ctx) {
430     OP_REQUIRES_OK(ctx, attrs_.Init(ctx));
431   }
432 
Compute(OpKernelContext * ctx)433   void Compute(OpKernelContext* ctx) override {
434     const Tensor* debug_name;
435     const Tensor* serialized;
436     OpInputList context_dense_keys;
437     OpInputList context_sparse_keys;
438     OpInputList context_dense_defaults;
439     OpInputList feature_list_dense_keys;
440     OpInputList feature_list_sparse_keys;
441     const Tensor* feature_list_dense_missing_assumed_empty;
442 
443     OP_REQUIRES_OK(ctx, ctx->input("debug_name", &debug_name));
444     OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized));
445     OP_REQUIRES_OK(ctx, ctx->input("feature_list_dense_missing_assumed_empty",
446                                    &feature_list_dense_missing_assumed_empty));
447     OP_REQUIRES_OK(ctx,
448                    ctx->input_list("context_dense_keys", &context_dense_keys));
449     OP_REQUIRES_OK(ctx, ctx->input_list("feature_list_dense_keys",
450                                         &feature_list_dense_keys));
451     OP_REQUIRES_OK(
452         ctx, ctx->input_list("context_sparse_keys", &context_sparse_keys));
453     OP_REQUIRES_OK(ctx, ctx->input_list("feature_list_sparse_keys",
454                                         &feature_list_sparse_keys));
455     OP_REQUIRES_OK(ctx, ctx->input_list("context_dense_defaults",
456                                         &context_dense_defaults));
457 
458     std::vector<string> context_dense_keys_t(attrs_.num_context_dense);
459     std::vector<string> context_sparse_keys_t(attrs_.num_context_sparse);
460     std::vector<string> feature_list_dense_keys_t(
461         attrs_.num_feature_list_dense);
462     std::vector<string> feature_list_sparse_keys_t(
463         attrs_.num_feature_list_sparse);
464     std::unordered_set<string> feature_list_dense_missing_assumed_empty_set;
465     CHECK_EQ(context_dense_keys.size(), attrs_.num_context_dense);
466     CHECK_EQ(context_sparse_keys.size(), attrs_.num_context_sparse);
467     CHECK_EQ(feature_list_dense_keys.size(), attrs_.num_feature_list_dense);
468     CHECK_EQ(feature_list_sparse_keys.size(), attrs_.num_feature_list_sparse);
469     for (int di = 0; di < attrs_.num_context_dense; ++di) {
470       OP_REQUIRES(ctx,
471                   TensorShapeUtils::IsScalar(context_dense_keys[di].shape()),
472                   errors::InvalidArgument(
473                       "Expected context_dense_keys[", di,
474                       "] to be a scalar, got shape: ",
475                       context_dense_keys[di].shape().DebugString()));
476       context_dense_keys_t[di] = context_dense_keys[di].scalar<string>()();
477     }
478     for (int di = 0; di < attrs_.num_context_sparse; ++di) {
479       OP_REQUIRES(ctx,
480                   TensorShapeUtils::IsScalar(context_sparse_keys[di].shape()),
481                   errors::InvalidArgument(
482                       "Expected context_sparse_keys[", di,
483                       "] to be a scalar, got shape: ",
484                       context_sparse_keys[di].shape().DebugString()));
485       context_sparse_keys_t[di] = context_sparse_keys[di].scalar<string>()();
486     }
487     for (int di = 0; di < attrs_.num_feature_list_dense; ++di) {
488       OP_REQUIRES(
489           ctx, TensorShapeUtils::IsScalar(feature_list_dense_keys[di].shape()),
490           errors::InvalidArgument(
491               "Expected feature_list_dense_keys[", di,
492               "] to be a scalar, got shape: ",
493               feature_list_dense_keys[di].shape().DebugString()));
494       feature_list_dense_keys_t[di] =
495           feature_list_dense_keys[di].scalar<string>()();
496     }
497     for (int di = 0; di < attrs_.num_feature_list_sparse; ++di) {
498       OP_REQUIRES(
499           ctx, TensorShapeUtils::IsScalar(feature_list_sparse_keys[di].shape()),
500           errors::InvalidArgument(
501               "Expected feature_list_sparse_keys[", di,
502               "] to be a scalar, got shape: ",
503               feature_list_sparse_keys[di].shape().DebugString()));
504       feature_list_sparse_keys_t[di] =
505           feature_list_sparse_keys[di].scalar<string>()();
506     }
507     OP_REQUIRES(
508         ctx,
509         TensorShapeUtils::IsVector(
510             feature_list_dense_missing_assumed_empty->shape()),
511         errors::InvalidArgument(
512             "Expected feature_list_dense_missing_assumed_empty ",
513             "to be a vector, got shape: ",
514             feature_list_dense_missing_assumed_empty->shape().DebugString()));
515     auto feature_list_dense_missing_assumped_empty_t =
516         feature_list_dense_missing_assumed_empty->vec<string>();
517     for (int de = 0;
518          de < feature_list_dense_missing_assumed_empty->NumElements(); ++de) {
519       feature_list_dense_missing_assumed_empty_set.insert(
520           feature_list_dense_missing_assumped_empty_t(de));
521     }
522 
523     bool has_debug_name = (debug_name->NumElements() > 0);
524     if (has_debug_name) {
525       OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(debug_name->shape()),
526                   errors::InvalidArgument(
527                       "Expected debug_name to be a scalar, got shape: ",
528                       debug_name->shape().DebugString()));
529     }
530     auto debug_name_t = debug_name->scalar<string>();
531 
532     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(serialized->shape()),
533                 errors::InvalidArgument(
534                     "Expected serialized to be a scalar, got shape: ",
535                     serialized->shape().DebugString()));
536 
537     OP_REQUIRES(ctx, context_dense_defaults.size() == attrs_.num_context_dense,
538                 errors::InvalidArgument("Expected len(context_dense_defaults) "
539                                         "== len(context_dense_keys) but got: ",
540                                         context_dense_defaults.size(), " vs. ",
541                                         attrs_.num_context_dense));
542 
543     std::vector<bool> required(attrs_.num_context_dense);
544     for (int d = 0; d < attrs_.num_context_dense; ++d) {
545       const Tensor& def_value = context_dense_defaults[d];
546       required[d] = (def_value.NumElements() == 0);  // No default provided.
547 
548       if (def_value.NumElements() > 0) {
549         OP_REQUIRES(ctx, def_value.shape() == attrs_.context_dense_shapes[d],
550                     errors::InvalidArgument(
551                         "def_value[", d,
552                         "].shape() == ", def_value.shape().DebugString(),
553                         " != context_dense_shapes_[", d,
554                         "] == ", attrs_.context_dense_shapes[d].DebugString()));
555         OP_REQUIRES(
556             ctx, def_value.dtype() == attrs_.context_dense_types[d],
557             errors::InvalidArgument(
558                 "context_dense_defaults[", d, "].dtype() == ",
559                 DataTypeString(def_value.dtype()), " != context_dense_types_[",
560                 d, "] == ", DataTypeString(attrs_.context_dense_types[d])));
561       }
562     }
563 
564     auto serialized_t = serialized->scalar<string>();
565 
566     OpOutputList context_sparse_indices;
567     OpOutputList context_sparse_values;
568     OpOutputList context_sparse_shapes;
569     OpOutputList context_dense_values;
570     OpOutputList feature_list_sparse_indices;
571     OpOutputList feature_list_sparse_values;
572     OpOutputList feature_list_sparse_shapes;
573     OpOutputList feature_list_dense_values;
574 
575     OP_REQUIRES_OK(ctx, ctx->output_list("context_sparse_indices",
576                                          &context_sparse_indices));
577     OP_REQUIRES_OK(
578         ctx, ctx->output_list("context_sparse_values", &context_sparse_values));
579     OP_REQUIRES_OK(
580         ctx, ctx->output_list("context_sparse_shapes", &context_sparse_shapes));
581     OP_REQUIRES_OK(
582         ctx, ctx->output_list("context_dense_values", &context_dense_values));
583     OP_REQUIRES_OK(ctx, ctx->output_list("context_sparse_indices",
584                                          &context_sparse_indices));
585     OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_indices",
586                                          &feature_list_sparse_indices));
587     OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_values",
588                                          &feature_list_sparse_values));
589     OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_shapes",
590                                          &feature_list_sparse_shapes));
591     OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_dense_values",
592                                          &feature_list_dense_values));
593 
594 #ifdef TENSORFLOW_LITE_PROTOS
595     SequenceExample ex;
596 #else
597     // Allocate the SequenceExample on an arena. Provides better memory locality
598     // and greatly speeds up destruction.
599     protobuf::ArenaOptions options;
600     // We have some hint of what the final proto size will be based on the size
601     // of the serialized bytes- use this to set a custom allocation strategy.
602     // Note that the default allocation strategy is quite conservative (min
603     // block size of 256 bytes, and a max of 8 kilobytes).
604     const size_t block_size = serialized_t().size() * 1.1;
605     options.start_block_size = std::max(options.start_block_size, block_size);
606     options.max_block_size = std::max(options.max_block_size, block_size);
607     protobuf::Arena arena(options);
608     auto& ex = *protobuf::Arena::CreateMessage<SequenceExample>(&arena);
609 #endif
610     OP_REQUIRES(
611         ctx, ParseProtoUnlimited(&ex, serialized_t()),
612         errors::InvalidArgument("Could not parse example input, value: '",
613                                 serialized_t(), "'"));
614 
615     const string& name = (has_debug_name) ? debug_name_t() : "<unknown>";
616     const Features& context = ex.context();
617     const auto& context_dict = context.feature();
618 
619     // Context Dense -----------------------------------------------------------
620 
621     // Preallocate context_dense_values, since we know their sizes
622     for (int d = 0; d < attrs_.num_context_dense; ++d) {
623       TensorShape out_shape;
624       for (const int dim : attrs_.context_dense_shapes[d].dim_sizes())
625         out_shape.AddDim(dim);
626       Tensor* out = nullptr;
627       OP_REQUIRES_OK(ctx, context_dense_values.allocate(d, out_shape, &out));
628     }
629 
630     for (int d = 0; d < attrs_.num_context_dense; ++d) {
631       const string& key = context_dense_keys_t[d];
632       const DataType& dtype = attrs_.context_dense_types[d];
633       const TensorShape& shape = attrs_.context_dense_shapes[d];
634 
635       const auto& feature_found = context_dict.find(key);
636       OP_REQUIRES(
637           ctx, (feature_found != context_dict.end()) || !required[d],
638           errors::InvalidArgument("Name: ", name, ", Context feature '", key,
639                                   "' is required but could not be found."));
640       if (feature_found != context_dict.end()) {
641         const Feature& f = feature_found->second;
642         bool types_match;
643         OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match));
644         OP_REQUIRES(
645             ctx, types_match,
646             errors::InvalidArgument("Name: ", name, ", Context feature: ", key,
647                                     ".  Data types don't match. ",
648                                     "Expected type: ", DataTypeString(dtype),
649                                     "  Feature is: ", ProtoDebugString(f)));
650 
651         OP_REQUIRES_OK(ctx, FeatureDenseCopy(0, name, key, dtype, shape, f,
652                                              context_dense_values[d]));
653       } else {
654         RowDenseCopy(0, dtype, context_dense_defaults[d],
655                      context_dense_values[d]);
656       }
657     }
658 
659     // Context Sparse ----------------------------------------------------------
660     for (int d = 0; d < attrs_.num_context_sparse; ++d) {
661       const string& key = context_sparse_keys_t[d];
662       const DataType& dtype = attrs_.context_sparse_types[d];
663 
664       const auto& feature_found = context_dict.find(key);
665       bool feature_has_data =  // Found key & data type is set
666           (feature_found != context_dict.end() &&
667            (feature_found->second.kind_case() != Feature::KIND_NOT_SET));
668 
669       if (feature_has_data) {
670         const Feature& f = feature_found->second;
671         bool types_match;
672         OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match));
673         OP_REQUIRES(
674             ctx, types_match,
675             errors::InvalidArgument("Name: ", name, ", Context feature: ", key,
676                                     ".  Data types don't match. ",
677                                     "Expected type: ", DataTypeString(dtype),
678                                     "  Feature is: ", ProtoDebugString(f)));
679 
680         Tensor feature_values = FeatureSparseCopy(0, key, dtype, f);
681         const int64 num_elements = feature_values.NumElements();
682         TensorShape indices_shape({num_elements, 1});
683         Tensor* sp_indices_d = nullptr;
684         Tensor* sp_shape_d = nullptr;
685         OP_REQUIRES_OK(ctx, context_sparse_indices.allocate(d, indices_shape,
686                                                             &sp_indices_d));
687         context_sparse_values.set(d, feature_values);
688         OP_REQUIRES_OK(ctx, context_sparse_shapes.allocate(d, TensorShape({1}),
689                                                            &sp_shape_d));
690         auto shape_t = sp_shape_d->vec<int64>();
691         shape_t(0) = num_elements;
692         auto indices_t = sp_indices_d->matrix<int64>();
693         std::iota(indices_t.data(), indices_t.data() + num_elements, 0);
694       } else {
695         TensorShape indices_shape({0, 1});
696         TensorShape values_shape({0});
697         Tensor* sp_indices_d = nullptr;
698         Tensor* sp_values_d = nullptr;
699         Tensor* sp_shape_d = nullptr;
700         OP_REQUIRES_OK(ctx, context_sparse_indices.allocate(d, indices_shape,
701                                                             &sp_indices_d));
702         OP_REQUIRES_OK(
703             ctx, context_sparse_values.allocate(d, values_shape, &sp_values_d));
704         OP_REQUIRES_OK(ctx, context_sparse_shapes.allocate(d, TensorShape({1}),
705                                                            &sp_shape_d));
706         auto shape_t = sp_shape_d->vec<int64>();
707         shape_t(0) = 0;
708       }
709     }
710 
711     // Feature List Dense ------------------------------------------------------
712 
713     // Preallocate context_dense_values, since we can infer their
714     // sizes
715     const FeatureLists& feature_lists = ex.feature_lists();
716     const auto& feature_list_dict = feature_lists.feature_list();
717     FeatureList empty_feature_list;  // Placeholder for missing FLs
718 
719     for (int d = 0; d < attrs_.num_feature_list_dense; ++d) {
720       const string& key = feature_list_dense_keys_t[d];
721       const DataType& dtype = attrs_.feature_list_dense_types[d];
722       const TensorShape& shape = attrs_.feature_list_dense_shapes[d];
723 
724       const auto& feature_list_found = feature_list_dict.find(key);
725       bool feature_list_missing =
726           (feature_list_found == feature_list_dict.end());
727       bool feature_list_allowed_missing =
728           (feature_list_dense_missing_assumed_empty_set.count(key) > 0);
729 
730       OP_REQUIRES(
731           ctx, !feature_list_missing || feature_list_allowed_missing,
732           errors::InvalidArgument("Name: ", name, ", Feature list '", key,
733                                   "' is required but could not be found.  "
734                                   "Did you mean to include it in "
735                                   "feature_list_dense_missing_assumed_empty or "
736                                   "feature_list_dense_defaults?"));
737 
738       TensorShape out_shape;
739       const FeatureList& fl = (feature_list_missing)
740                                   ? empty_feature_list
741                                   : feature_list_found->second;
742       out_shape.AddDim(fl.feature_size());
743       for (const int dim : attrs_.feature_list_dense_shapes[d].dim_sizes()) {
744         out_shape.AddDim(dim);
745       }
746       Tensor* out = nullptr;
747       OP_REQUIRES_OK(ctx,
748                      feature_list_dense_values.allocate(d, out_shape, &out));
749 
750       for (int64 t = 0; t < fl.feature_size(); ++t) {
751         const Feature& f = fl.feature(t);
752         bool types_match;
753         OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match));
754         OP_REQUIRES(ctx, types_match,
755                     errors::InvalidArgument(
756                         "Name: ", name, ", Feature list: ", key, ", Index: ", t,
757                         ".  Data types don't match. ",
758                         "Expected type: ", DataTypeString(dtype),
759                         "  Feature is: ", ProtoDebugString(f)));
760         OP_REQUIRES_OK(ctx, FeatureDenseCopy(t, name, key, dtype, shape, f,
761                                              feature_list_dense_values[d]));
762       }
763     }
764 
765     // Feature List Sparse -----------------------------------------------------
766     for (int d = 0; d < attrs_.num_feature_list_sparse; ++d) {
767       const string& key = feature_list_sparse_keys_t[d];
768       const DataType& dtype = attrs_.feature_list_sparse_types[d];
769 
770       const auto& feature_list_found = feature_list_dict.find(key);
771       bool feature_list_has_data =  // Found key
772           (feature_list_found != feature_list_dict.end());
773 
774       std::vector<Tensor> sparse_values_tmp;
775       int64 feature_list_size = 0;
776       if (feature_list_has_data) {
777         const FeatureList& fl = feature_list_found->second;
778         feature_list_size = fl.feature_size();
779         for (int64 t = 0; t < feature_list_size; ++t) {
780           const Feature& f = fl.feature(t);
781           bool types_match;
782           OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match));
783           OP_REQUIRES(
784               ctx, f.kind_case() == Feature::KIND_NOT_SET || types_match,
785               errors::InvalidArgument("Name: ", name, ", Feature List: ", key,
786                                       ", Index: ", t,
787                                       ".  Data types don't match. ",
788                                       "Expected type: ", DataTypeString(dtype),
789                                       "  Feature is: ", ProtoDebugString(f)));
790           sparse_values_tmp.push_back(FeatureSparseCopy(t, key, dtype, f));
791         }
792       } else {
793         sparse_values_tmp.push_back(Tensor(dtype, TensorShape({0})));
794       }
795 
796       int64 total_num_features = 0;
797       int64 max_num_features = 0;
798       for (int t = 0; t < feature_list_size; ++t) {
799         const Tensor& v = sparse_values_tmp[t];
800         const int64 num_elements = v.shape().num_elements();
801         total_num_features += num_elements;
802         max_num_features = std::max(max_num_features, num_elements);
803       }
804 
805       TensorShape indices_shape({total_num_features, 2});
806       TensorShape values_shape({total_num_features});
807       Tensor* sp_indices_d = nullptr;
808       Tensor* sp_values_d = nullptr;
809       Tensor* sp_shape_d = nullptr;
810       OP_REQUIRES_OK(ctx, feature_list_sparse_indices.allocate(d, indices_shape,
811                                                                &sp_indices_d));
812       OP_REQUIRES_OK(ctx, feature_list_sparse_values.allocate(d, values_shape,
813                                                               &sp_values_d));
814       OP_REQUIRES_OK(ctx, feature_list_sparse_shapes.allocate(
815                               d, TensorShape({2}), &sp_shape_d));
816       auto shape_t = sp_shape_d->vec<int64>();
817       shape_t(0) = feature_list_size;
818       shape_t(1) = max_num_features;
819 
820       int64 offset = 0;
821 
822       for (int t = 0; t < feature_list_size; ++t) {
823         const int64 num_elements = CopyIntoSparseTensor(
824             sparse_values_tmp[t], t, offset, sp_indices_d, sp_values_d);
825         offset += num_elements;
826       }
827     }
828   }
829 
830  protected:
831   ParseSingleSequenceExampleAttrs attrs_;
832 };
833 
834 REGISTER_KERNEL_BUILDER(Name("ParseSingleSequenceExample").Device(DEVICE_CPU),
835                         ParseSingleSequenceExampleOp);
836 
837 #ifndef IS_MOBILE_PLATFORM
// JSON decoding is not available on mobile, where lite protos are used.
839 
840 class DecodeJSONExampleOp : public OpKernel {
841  public:
DecodeJSONExampleOp(OpKernelConstruction * ctx)842   explicit DecodeJSONExampleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
843     resolver_.reset(protobuf::util::NewTypeResolverForDescriptorPool(
844         "type.googleapis.com", protobuf::DescriptorPool::generated_pool()));
845   }
846 
Compute(OpKernelContext * ctx)847   void Compute(OpKernelContext* ctx) {
848     const Tensor* json_examples;
849     OP_REQUIRES_OK(ctx, ctx->input("json_examples", &json_examples));
850     Tensor* binary_examples;
851     OP_REQUIRES_OK(
852         ctx, ctx->allocate_output("binary_examples", json_examples->shape(),
853                                   &binary_examples));
854 
855     for (int i = 0; i < json_examples->NumElements(); ++i) {
856       const string& json_example = json_examples->flat<string>()(i);
857       auto status = protobuf::util::JsonToBinaryString(
858           resolver_.get(), "type.googleapis.com/tensorflow.Example",
859           json_example, &binary_examples->flat<string>()(i));
860       OP_REQUIRES(ctx, status.ok(),
861                   errors::InvalidArgument("Error while parsing JSON: ",
862                                           string(status.error_message())));
863     }
864   }
865 
866  private:
867   std::unique_ptr<protobuf::util::TypeResolver> resolver_;
868 };
869 
870 REGISTER_KERNEL_BUILDER(Name("DecodeJSONExample").Device(DEVICE_CPU),
871                         DecodeJSONExampleOp);
872 #endif
873 
874 }  // namespace tensorflow
875