1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/parsing_ops.cc. 17 18 #include <numeric> 19 #include <unordered_set> 20 #include <vector> 21 22 #include "tensorflow/core/example/example.pb.h" 23 #include "tensorflow/core/example/feature.pb_text.h" 24 #include "tensorflow/core/framework/common_shape_fns.h" 25 #include "tensorflow/core/framework/numeric_op.h" 26 #include "tensorflow/core/framework/register_types.h" 27 #include "tensorflow/core/lib/gtl/array_slice.h" 28 #include "tensorflow/core/platform/logging.h" 29 #include "tensorflow/core/platform/protobuf.h" 30 #include "tensorflow/core/util/example_proto_fast_parsing.h" 31 #include "tensorflow/core/util/example_proto_helper.h" 32 #include "tensorflow/core/util/sparse/sparse_tensor.h" 33 #include "tensorflow/core/util/work_sharder.h" 34 35 namespace tensorflow { 36 37 class ParseExampleOp : public OpKernel { 38 public: ParseExampleOp(OpKernelConstruction * ctx)39 explicit ParseExampleOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 40 OP_REQUIRES_OK(ctx, attrs_.Init(ctx)); 41 } 42 Compute(OpKernelContext * ctx)43 void Compute(OpKernelContext* ctx) override { 44 const Tensor* names; 45 const Tensor* serialized; 46 OpInputList dense_keys; 47 OpInputList sparse_keys; 48 OpInputList dense_defaults; 49 50 // Grab the input list arguments. 51 OP_REQUIRES_OK(ctx, ctx->input("names", &names)); 52 OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized)); 53 OP_REQUIRES_OK(ctx, ctx->input_list("dense_keys", &dense_keys)); 54 OP_REQUIRES_OK(ctx, ctx->input_list("sparse_keys", &sparse_keys)); 55 OP_REQUIRES_OK(ctx, ctx->input_list("dense_defaults", &dense_defaults)); 56 57 std::vector<string> dense_keys_t(attrs_.num_dense); 58 std::vector<string> sparse_keys_t(attrs_.num_sparse); 59 60 // Check that the input list sizes match the attribute declared sizes. 61 CHECK_EQ(dense_keys.size(), attrs_.num_dense); 62 CHECK_EQ(sparse_keys.size(), attrs_.num_sparse); 63 64 // Copy from OpInputList to std::vector<string>. 65 for (int di = 0; di < attrs_.num_dense; ++di) { 66 dense_keys_t[di] = dense_keys[di].scalar<string>()(); 67 } 68 for (int di = 0; di < attrs_.num_sparse; ++di) { 69 sparse_keys_t[di] = sparse_keys[di].scalar<string>()(); 70 } 71 72 if (names->NumElements() > 0) { 73 OP_REQUIRES( 74 ctx, TensorShapeUtils::IsVector(names->shape()), 75 errors::InvalidArgument("Expected names to be a vector, got shape: ", 76 names->shape().DebugString())); 77 OP_REQUIRES( 78 ctx, names->NumElements() == serialized->NumElements(), 79 errors::InvalidArgument( 80 "Expected len(names) == len(serialized), but got: ", 81 names->NumElements(), " vs. ", serialized->NumElements())); 82 } 83 84 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(serialized->shape()), 85 errors::InvalidArgument( 86 "Expected serialized to be a vector, got shape: ", 87 serialized->shape().DebugString())); 88 OP_REQUIRES(ctx, dense_defaults.size() == attrs_.num_dense, 89 errors::InvalidArgument( 90 "Expected len(dense_defaults) == len(dense_keys) but got: ", 91 dense_defaults.size(), " vs. ", attrs_.num_dense)); 92 93 for (int d = 0; d < static_cast<int>(attrs_.num_dense); ++d) { 94 const Tensor& def_value = dense_defaults[d]; 95 if (attrs_.variable_length[d]) { 96 OP_REQUIRES(ctx, def_value.NumElements() == 1, 97 errors::InvalidArgument( 98 "dense_shape[", d, "] is a variable length shape: ", 99 attrs_.dense_shapes[d].DebugString(), 100 ", therefore " 101 "def_value[", 102 d, 103 "] must contain a single element (" 104 "the padding element). But its shape is: ", 105 def_value.shape().DebugString())); 106 } else if (def_value.NumElements() > 0) { 107 OP_REQUIRES(ctx, 108 attrs_.dense_shapes[d].IsCompatibleWith(def_value.shape()), 109 errors::InvalidArgument( 110 "def_value[", d, 111 "].shape() == ", def_value.shape().DebugString(), 112 " is not compatible with dense_shapes_[", d, 113 "] == ", attrs_.dense_shapes[d].DebugString())); 114 } 115 OP_REQUIRES(ctx, def_value.dtype() == attrs_.dense_types[d], 116 errors::InvalidArgument( 117 "dense_defaults[", d, "].dtype() == ", 118 DataTypeString(def_value.dtype()), " != dense_types_[", d, 119 "] == ", DataTypeString(attrs_.dense_types[d]))); 120 } 121 122 example::Result result; 123 124 example::FastParseExampleConfig config; 125 for (int d = 0; d < attrs_.num_dense; ++d) { 126 config.dense.push_back({dense_keys_t[d], attrs_.dense_types[d], 127 attrs_.dense_shapes[d], dense_defaults[d], 128 attrs_.variable_length[d], 129 attrs_.elements_per_stride[d]}); 130 } 131 for (int d = 0; d < attrs_.num_sparse; ++d) { 132 config.sparse.push_back({sparse_keys_t[d], attrs_.sparse_types[d]}); 133 } 134 135 auto serialized_t = serialized->flat<string>(); 136 auto names_t = names->flat<string>(); 137 gtl::ArraySlice<string> slice(serialized_t.data(), serialized_t.size()); 138 gtl::ArraySlice<string> names_slice(names_t.data(), names_t.size()); 139 140 OP_REQUIRES_OK( 141 ctx, 142 FastParseExample( 143 config, slice, names_slice, 144 ctx->device()->tensorflow_cpu_worker_threads()->workers, &result)); 145 146 OpOutputList dense_values; 147 OpOutputList sparse_indices; 148 OpOutputList sparse_values; 149 OpOutputList sparse_shapes; 150 OP_REQUIRES_OK(ctx, ctx->output_list("dense_values", &dense_values)); 151 OP_REQUIRES_OK(ctx, ctx->output_list("sparse_indices", &sparse_indices)); 152 OP_REQUIRES_OK(ctx, ctx->output_list("sparse_values", &sparse_values)); 153 OP_REQUIRES_OK(ctx, ctx->output_list("sparse_shapes", &sparse_shapes)); 154 for (int d = 0; d < attrs_.num_dense; ++d) { 155 dense_values.set(d, result.dense_values[d]); 156 } 157 for (int d = 0; d < attrs_.num_sparse; ++d) { 158 sparse_indices.set(d, result.sparse_indices[d]); 159 sparse_values.set(d, result.sparse_values[d]); 160 sparse_shapes.set(d, result.sparse_shapes[d]); 161 } 162 } 163 164 protected: 165 ParseExampleAttrs attrs_; 166 }; 167 168 REGISTER_KERNEL_BUILDER(Name("ParseExample").Device(DEVICE_CPU), 169 ParseExampleOp); 170 171 class ParseSingleExampleOp : public OpKernel { 172 public: ParseSingleExampleOp(OpKernelConstruction * ctx)173 explicit ParseSingleExampleOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 174 OP_REQUIRES_OK(ctx, attrs_.Init(ctx)); 175 } 176 Compute(OpKernelContext * ctx)177 void Compute(OpKernelContext* ctx) override { 178 const Tensor* serialized; 179 OpInputList dense_defaults; 180 181 // Grab the input list arguments. 182 OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized)); 183 OP_REQUIRES_OK(ctx, ctx->input_list("dense_defaults", &dense_defaults)); 184 185 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(serialized->shape()), 186 errors::InvalidArgument( 187 "Expected serialized to be a scalar, got shape: ", 188 serialized->shape().DebugString())); 189 OP_REQUIRES(ctx, dense_defaults.size() == attrs_.dense_keys.size(), 190 errors::InvalidArgument( 191 "Expected len(dense_defaults) == len(dense_keys) but got: ", 192 dense_defaults.size(), " vs. ", attrs_.dense_keys.size())); 193 194 for (size_t d = 0; d < attrs_.dense_keys.size(); ++d) { 195 const Tensor& def_value = dense_defaults[d]; 196 if (attrs_.variable_length[d]) { 197 OP_REQUIRES(ctx, def_value.NumElements() == 1, 198 errors::InvalidArgument( 199 "dense_shape[", d, "] is a variable length shape: ", 200 attrs_.dense_shapes[d].DebugString(), 201 ", therefore " 202 "def_value[", 203 d, 204 "] must contain a single element (" 205 "the padding element). But its shape is: ", 206 def_value.shape().DebugString())); 207 } else if (def_value.NumElements() > 0) { 208 OP_REQUIRES(ctx, 209 attrs_.dense_shapes[d].IsCompatibleWith(def_value.shape()), 210 errors::InvalidArgument( 211 "def_value[", d, 212 "].shape() == ", def_value.shape().DebugString(), 213 " is not compatible with dense_shapes_[", d, 214 "] == ", attrs_.dense_shapes[d].DebugString())); 215 } 216 OP_REQUIRES(ctx, def_value.dtype() == attrs_.dense_types[d], 217 errors::InvalidArgument( 218 "dense_defaults[", d, "].dtype() == ", 219 DataTypeString(def_value.dtype()), " != dense_types_[", d, 220 "] == ", DataTypeString(attrs_.dense_types[d]))); 221 } 222 223 example::Result result; 224 225 // TODO(mrry): Build the configuration once and cache it. 226 example::FastParseExampleConfig config; 227 for (int d = 0; d < attrs_.dense_keys.size(); ++d) { 228 config.dense.push_back({attrs_.dense_keys[d], attrs_.dense_types[d], 229 attrs_.dense_shapes[d], dense_defaults[d], 230 attrs_.variable_length[d], 231 attrs_.elements_per_stride[d]}); 232 } 233 for (int d = 0; d < attrs_.sparse_keys.size(); ++d) { 234 config.sparse.push_back({attrs_.sparse_keys[d], attrs_.sparse_types[d]}); 235 } 236 237 const string& serialized_proto = serialized->scalar<string>()(); 238 239 OP_REQUIRES_OK(ctx, 240 FastParseSingleExample(config, serialized_proto, &result)); 241 242 OpOutputList dense_values; 243 OpOutputList sparse_indices; 244 OpOutputList sparse_values; 245 OpOutputList sparse_shapes; 246 OP_REQUIRES_OK(ctx, ctx->output_list("dense_values", &dense_values)); 247 OP_REQUIRES_OK(ctx, ctx->output_list("sparse_indices", &sparse_indices)); 248 OP_REQUIRES_OK(ctx, ctx->output_list("sparse_values", &sparse_values)); 249 OP_REQUIRES_OK(ctx, ctx->output_list("sparse_shapes", &sparse_shapes)); 250 for (int d = 0; d < attrs_.dense_keys.size(); ++d) { 251 dense_values.set(d, result.dense_values[d]); 252 } 253 for (int d = 0; d < attrs_.sparse_keys.size(); ++d) { 254 sparse_indices.set(d, result.sparse_indices[d]); 255 sparse_values.set(d, result.sparse_values[d]); 256 sparse_shapes.set(d, result.sparse_shapes[d]); 257 } 258 } 259 260 protected: 261 ParseSingleExampleAttrs attrs_; 262 }; 263 264 REGISTER_KERNEL_BUILDER(Name("ParseSingleExample").Device(DEVICE_CPU), 265 ParseSingleExampleOp); 266 267 class ParseSequenceExampleOp : public OpKernel { 268 public: ParseSequenceExampleOp(OpKernelConstruction * ctx)269 explicit ParseSequenceExampleOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 270 OP_REQUIRES_OK(ctx, attrs_.Init(ctx)); 271 } 272 Compute(OpKernelContext * ctx)273 void Compute(OpKernelContext* ctx) override { 274 const Tensor* debug_name; 275 const Tensor* serialized; 276 OpInputList context_dense_defaults; 277 278 OP_REQUIRES_OK(ctx, ctx->input("debug_name", &debug_name)); 279 OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized)); 280 OP_REQUIRES_OK(ctx, ctx->input_list("context_dense_defaults", 281 &context_dense_defaults)); 282 283 bool has_debug_name = (debug_name->NumElements() > 0); 284 if (has_debug_name) { 285 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(debug_name->shape()), 286 errors::InvalidArgument( 287 "Expected debug_name to be a vector, got shape: ", 288 debug_name->shape().DebugString())); 289 } 290 291 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(serialized->shape()), 292 errors::InvalidArgument( 293 "Expected serialized to be a vector, got shape: ", 294 serialized->shape().DebugString())); 295 296 OP_REQUIRES(ctx, context_dense_defaults.size() == attrs_.num_context_dense, 297 errors::InvalidArgument("Expected len(context_dense_defaults) " 298 "== len(context_dense_keys) but got: ", 299 context_dense_defaults.size(), " vs. ", 300 attrs_.num_context_dense)); 301 302 std::vector<bool> required(attrs_.num_context_dense); 303 for (int d = 0; d < attrs_.num_context_dense; ++d) { 304 const Tensor& def_value = context_dense_defaults[d]; 305 required[d] = (def_value.NumElements() == 0); // No default provided. 306 307 if (def_value.NumElements() > 0) { 308 OP_REQUIRES(ctx, def_value.shape() == attrs_.context_dense_shapes[d], 309 errors::InvalidArgument( 310 "default_value[", d, 311 "].shape() == ", def_value.shape().DebugString(), 312 " != context_dense_shapes[", d, 313 "] == ", attrs_.context_dense_shapes[d].DebugString())); 314 OP_REQUIRES( 315 ctx, def_value.dtype() == attrs_.context_dense_types[d], 316 errors::InvalidArgument( 317 "context_dense_defaults[", d, "].dtype() == ", 318 DataTypeString(def_value.dtype()), " != context_dense_types[", 319 d, "] == ", DataTypeString(attrs_.context_dense_types[d]))); 320 } 321 } 322 323 example::Result context_result, feature_list_result; 324 std::vector<Tensor> dense_feature_lengths; 325 326 example::FastParseExampleConfig context_config; 327 for (int d = 0; d < attrs_.num_context_dense; ++d) { 328 context_config.dense.push_back( 329 {attrs_.context_dense_keys[d], attrs_.context_dense_types[d], 330 attrs_.context_dense_shapes[d], context_dense_defaults[d], 331 false /* attrs_.context_variable_length[d] */, 332 0 /*attrs_.context_elements_per_stride[d] */}); 333 } 334 for (int d = 0; d < attrs_.num_context_sparse; ++d) { 335 context_config.sparse.push_back( 336 {attrs_.context_sparse_keys[d], attrs_.context_sparse_types[d]}); 337 } 338 example::FastParseExampleConfig feature_list_config; 339 for (int d = 0; d < attrs_.num_feature_list_dense; ++d) { 340 DataType dtype = attrs_.feature_list_dense_types[d]; 341 Tensor default_value = Tensor(dtype, TensorShape({})); 342 feature_list_config.dense.push_back( 343 {attrs_.feature_list_dense_keys[d], dtype, 344 attrs_.feature_list_dense_shapes[d], default_value, 345 (attrs_.feature_list_dense_missing_assumed_empty.count( 346 attrs_.feature_list_dense_keys[d]) > 0), 347 0 /*attrs_.context_elements_per_stride[d] */}); 348 } 349 for (int d = 0; d < attrs_.num_feature_list_sparse; ++d) { 350 feature_list_config.sparse.push_back( 351 {attrs_.feature_list_sparse_keys[d], 352 attrs_.feature_list_sparse_types[d]}); 353 } 354 355 auto serialized_t = serialized->flat<string>(); 356 auto debug_name_t = debug_name->flat<string>(); 357 gtl::ArraySlice<string> slice(serialized_t.data(), serialized_t.size()); 358 gtl::ArraySlice<string> names_slice(debug_name_t.data(), 359 debug_name_t.size()); 360 361 OP_REQUIRES_OK( 362 ctx, 363 FastParseSequenceExample( 364 context_config, feature_list_config, slice, names_slice, 365 ctx->device()->tensorflow_cpu_worker_threads()->workers, 366 &context_result, &feature_list_result, &dense_feature_lengths)); 367 368 OpOutputList context_sparse_indices; 369 OpOutputList context_sparse_values; 370 OpOutputList context_sparse_shapes; 371 OpOutputList context_dense_values; 372 OpOutputList feature_list_sparse_indices; 373 OpOutputList feature_list_sparse_values; 374 OpOutputList feature_list_sparse_shapes; 375 OpOutputList feature_list_dense_values; 376 OpOutputList feature_list_dense_lengths; 377 378 OP_REQUIRES_OK(ctx, ctx->output_list("context_sparse_indices", 379 &context_sparse_indices)); 380 OP_REQUIRES_OK( 381 ctx, ctx->output_list("context_sparse_values", &context_sparse_values)); 382 OP_REQUIRES_OK( 383 ctx, ctx->output_list("context_sparse_shapes", &context_sparse_shapes)); 384 OP_REQUIRES_OK( 385 ctx, ctx->output_list("context_dense_values", &context_dense_values)); 386 OP_REQUIRES_OK(ctx, ctx->output_list("context_sparse_indices", 387 &context_sparse_indices)); 388 OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_indices", 389 &feature_list_sparse_indices)); 390 OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_values", 391 &feature_list_sparse_values)); 392 OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_shapes", 393 &feature_list_sparse_shapes)); 394 OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_dense_values", 395 &feature_list_dense_values)); 396 OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_dense_lengths", 397 &feature_list_dense_lengths)); 398 for (int d = 0; d < attrs_.num_context_dense; ++d) { 399 context_dense_values.set(d, context_result.dense_values[d]); 400 } 401 TensorShape lengths_shape; 402 lengths_shape.AddDim(serialized_t.size()); 403 for (int d = 0; d < attrs_.num_feature_list_dense; ++d) { 404 feature_list_dense_values.set(d, feature_list_result.dense_values[d]); 405 feature_list_dense_lengths.set(d, dense_feature_lengths[d]); 406 } 407 for (int d = 0; d < attrs_.num_context_sparse; ++d) { 408 context_sparse_indices.set(d, context_result.sparse_indices[d]); 409 context_sparse_values.set(d, context_result.sparse_values[d]); 410 context_sparse_shapes.set(d, context_result.sparse_shapes[d]); 411 } 412 for (int d = 0; d < attrs_.num_feature_list_sparse; ++d) { 413 feature_list_sparse_indices.set(d, feature_list_result.sparse_indices[d]); 414 feature_list_sparse_values.set(d, feature_list_result.sparse_values[d]); 415 feature_list_sparse_shapes.set(d, feature_list_result.sparse_shapes[d]); 416 } 417 } 418 419 protected: 420 ParseSequenceExampleAttrs attrs_; 421 }; 422 423 REGISTER_KERNEL_BUILDER(Name("ParseSequenceExample").Device(DEVICE_CPU), 424 ParseSequenceExampleOp); 425 426 class ParseSingleSequenceExampleOp : public OpKernel { 427 public: ParseSingleSequenceExampleOp(OpKernelConstruction * ctx)428 explicit ParseSingleSequenceExampleOp(OpKernelConstruction* ctx) 429 : OpKernel(ctx) { 430 OP_REQUIRES_OK(ctx, attrs_.Init(ctx)); 431 } 432 Compute(OpKernelContext * ctx)433 void Compute(OpKernelContext* ctx) override { 434 const Tensor* debug_name; 435 const Tensor* serialized; 436 OpInputList context_dense_keys; 437 OpInputList context_sparse_keys; 438 OpInputList context_dense_defaults; 439 OpInputList feature_list_dense_keys; 440 OpInputList feature_list_sparse_keys; 441 const Tensor* feature_list_dense_missing_assumed_empty; 442 443 OP_REQUIRES_OK(ctx, ctx->input("debug_name", &debug_name)); 444 OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized)); 445 OP_REQUIRES_OK(ctx, ctx->input("feature_list_dense_missing_assumed_empty", 446 &feature_list_dense_missing_assumed_empty)); 447 OP_REQUIRES_OK(ctx, 448 ctx->input_list("context_dense_keys", &context_dense_keys)); 449 OP_REQUIRES_OK(ctx, ctx->input_list("feature_list_dense_keys", 450 &feature_list_dense_keys)); 451 OP_REQUIRES_OK( 452 ctx, ctx->input_list("context_sparse_keys", &context_sparse_keys)); 453 OP_REQUIRES_OK(ctx, ctx->input_list("feature_list_sparse_keys", 454 &feature_list_sparse_keys)); 455 OP_REQUIRES_OK(ctx, ctx->input_list("context_dense_defaults", 456 &context_dense_defaults)); 457 458 std::vector<string> context_dense_keys_t(attrs_.num_context_dense); 459 std::vector<string> context_sparse_keys_t(attrs_.num_context_sparse); 460 std::vector<string> feature_list_dense_keys_t( 461 attrs_.num_feature_list_dense); 462 std::vector<string> feature_list_sparse_keys_t( 463 attrs_.num_feature_list_sparse); 464 std::unordered_set<string> feature_list_dense_missing_assumed_empty_set; 465 CHECK_EQ(context_dense_keys.size(), attrs_.num_context_dense); 466 CHECK_EQ(context_sparse_keys.size(), attrs_.num_context_sparse); 467 CHECK_EQ(feature_list_dense_keys.size(), attrs_.num_feature_list_dense); 468 CHECK_EQ(feature_list_sparse_keys.size(), attrs_.num_feature_list_sparse); 469 for (int di = 0; di < attrs_.num_context_dense; ++di) { 470 OP_REQUIRES(ctx, 471 TensorShapeUtils::IsScalar(context_dense_keys[di].shape()), 472 errors::InvalidArgument( 473 "Expected context_dense_keys[", di, 474 "] to be a scalar, got shape: ", 475 context_dense_keys[di].shape().DebugString())); 476 context_dense_keys_t[di] = context_dense_keys[di].scalar<string>()(); 477 } 478 for (int di = 0; di < attrs_.num_context_sparse; ++di) { 479 OP_REQUIRES(ctx, 480 TensorShapeUtils::IsScalar(context_sparse_keys[di].shape()), 481 errors::InvalidArgument( 482 "Expected context_sparse_keys[", di, 483 "] to be a scalar, got shape: ", 484 context_sparse_keys[di].shape().DebugString())); 485 context_sparse_keys_t[di] = context_sparse_keys[di].scalar<string>()(); 486 } 487 for (int di = 0; di < attrs_.num_feature_list_dense; ++di) { 488 OP_REQUIRES( 489 ctx, TensorShapeUtils::IsScalar(feature_list_dense_keys[di].shape()), 490 errors::InvalidArgument( 491 "Expected feature_list_dense_keys[", di, 492 "] to be a scalar, got shape: ", 493 feature_list_dense_keys[di].shape().DebugString())); 494 feature_list_dense_keys_t[di] = 495 feature_list_dense_keys[di].scalar<string>()(); 496 } 497 for (int di = 0; di < attrs_.num_feature_list_sparse; ++di) { 498 OP_REQUIRES( 499 ctx, TensorShapeUtils::IsScalar(feature_list_sparse_keys[di].shape()), 500 errors::InvalidArgument( 501 "Expected feature_list_sparse_keys[", di, 502 "] to be a scalar, got shape: ", 503 feature_list_sparse_keys[di].shape().DebugString())); 504 feature_list_sparse_keys_t[di] = 505 feature_list_sparse_keys[di].scalar<string>()(); 506 } 507 OP_REQUIRES( 508 ctx, 509 TensorShapeUtils::IsVector( 510 feature_list_dense_missing_assumed_empty->shape()), 511 errors::InvalidArgument( 512 "Expected feature_list_dense_missing_assumed_empty ", 513 "to be a vector, got shape: ", 514 feature_list_dense_missing_assumed_empty->shape().DebugString())); 515 auto feature_list_dense_missing_assumped_empty_t = 516 feature_list_dense_missing_assumed_empty->vec<string>(); 517 for (int de = 0; 518 de < feature_list_dense_missing_assumed_empty->NumElements(); ++de) { 519 feature_list_dense_missing_assumed_empty_set.insert( 520 feature_list_dense_missing_assumped_empty_t(de)); 521 } 522 523 bool has_debug_name = (debug_name->NumElements() > 0); 524 if (has_debug_name) { 525 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(debug_name->shape()), 526 errors::InvalidArgument( 527 "Expected debug_name to be a scalar, got shape: ", 528 debug_name->shape().DebugString())); 529 } 530 auto debug_name_t = debug_name->scalar<string>(); 531 532 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(serialized->shape()), 533 errors::InvalidArgument( 534 "Expected serialized to be a scalar, got shape: ", 535 serialized->shape().DebugString())); 536 537 OP_REQUIRES(ctx, context_dense_defaults.size() == attrs_.num_context_dense, 538 errors::InvalidArgument("Expected len(context_dense_defaults) " 539 "== len(context_dense_keys) but got: ", 540 context_dense_defaults.size(), " vs. ", 541 attrs_.num_context_dense)); 542 543 std::vector<bool> required(attrs_.num_context_dense); 544 for (int d = 0; d < attrs_.num_context_dense; ++d) { 545 const Tensor& def_value = context_dense_defaults[d]; 546 required[d] = (def_value.NumElements() == 0); // No default provided. 547 548 if (def_value.NumElements() > 0) { 549 OP_REQUIRES(ctx, def_value.shape() == attrs_.context_dense_shapes[d], 550 errors::InvalidArgument( 551 "def_value[", d, 552 "].shape() == ", def_value.shape().DebugString(), 553 " != context_dense_shapes_[", d, 554 "] == ", attrs_.context_dense_shapes[d].DebugString())); 555 OP_REQUIRES( 556 ctx, def_value.dtype() == attrs_.context_dense_types[d], 557 errors::InvalidArgument( 558 "context_dense_defaults[", d, "].dtype() == ", 559 DataTypeString(def_value.dtype()), " != context_dense_types_[", 560 d, "] == ", DataTypeString(attrs_.context_dense_types[d]))); 561 } 562 } 563 564 auto serialized_t = serialized->scalar<string>(); 565 566 OpOutputList context_sparse_indices; 567 OpOutputList context_sparse_values; 568 OpOutputList context_sparse_shapes; 569 OpOutputList context_dense_values; 570 OpOutputList feature_list_sparse_indices; 571 OpOutputList feature_list_sparse_values; 572 OpOutputList feature_list_sparse_shapes; 573 OpOutputList feature_list_dense_values; 574 575 OP_REQUIRES_OK(ctx, ctx->output_list("context_sparse_indices", 576 &context_sparse_indices)); 577 OP_REQUIRES_OK( 578 ctx, ctx->output_list("context_sparse_values", &context_sparse_values)); 579 OP_REQUIRES_OK( 580 ctx, ctx->output_list("context_sparse_shapes", &context_sparse_shapes)); 581 OP_REQUIRES_OK( 582 ctx, ctx->output_list("context_dense_values", &context_dense_values)); 583 OP_REQUIRES_OK(ctx, ctx->output_list("context_sparse_indices", 584 &context_sparse_indices)); 585 OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_indices", 586 &feature_list_sparse_indices)); 587 OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_values", 588 &feature_list_sparse_values)); 589 OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_shapes", 590 &feature_list_sparse_shapes)); 591 OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_dense_values", 592 &feature_list_dense_values)); 593 594 #ifdef TENSORFLOW_LITE_PROTOS 595 SequenceExample ex; 596 #else 597 // Allocate the SequenceExample on an arena. Provides better memory locality 598 // and greatly speeds up destruction. 599 protobuf::ArenaOptions options; 600 // We have some hint of what the final proto size will be based on the size 601 // of the serialized bytes- use this to set a custom allocation strategy. 602 // Note that the default allocation strategy is quite conservative (min 603 // block size of 256 bytes, and a max of 8 kilobytes). 604 const size_t block_size = serialized_t().size() * 1.1; 605 options.start_block_size = std::max(options.start_block_size, block_size); 606 options.max_block_size = std::max(options.max_block_size, block_size); 607 protobuf::Arena arena(options); 608 auto& ex = *protobuf::Arena::CreateMessage<SequenceExample>(&arena); 609 #endif 610 OP_REQUIRES( 611 ctx, ParseProtoUnlimited(&ex, serialized_t()), 612 errors::InvalidArgument("Could not parse example input, value: '", 613 serialized_t(), "'")); 614 615 const string& name = (has_debug_name) ? debug_name_t() : "<unknown>"; 616 const Features& context = ex.context(); 617 const auto& context_dict = context.feature(); 618 619 // Context Dense ----------------------------------------------------------- 620 621 // Preallocate context_dense_values, since we know their sizes 622 for (int d = 0; d < attrs_.num_context_dense; ++d) { 623 TensorShape out_shape; 624 for (const int dim : attrs_.context_dense_shapes[d].dim_sizes()) 625 out_shape.AddDim(dim); 626 Tensor* out = nullptr; 627 OP_REQUIRES_OK(ctx, context_dense_values.allocate(d, out_shape, &out)); 628 } 629 630 for (int d = 0; d < attrs_.num_context_dense; ++d) { 631 const string& key = context_dense_keys_t[d]; 632 const DataType& dtype = attrs_.context_dense_types[d]; 633 const TensorShape& shape = attrs_.context_dense_shapes[d]; 634 635 const auto& feature_found = context_dict.find(key); 636 OP_REQUIRES( 637 ctx, (feature_found != context_dict.end()) || !required[d], 638 errors::InvalidArgument("Name: ", name, ", Context feature '", key, 639 "' is required but could not be found.")); 640 if (feature_found != context_dict.end()) { 641 const Feature& f = feature_found->second; 642 bool types_match; 643 OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match)); 644 OP_REQUIRES( 645 ctx, types_match, 646 errors::InvalidArgument("Name: ", name, ", Context feature: ", key, 647 ". Data types don't match. ", 648 "Expected type: ", DataTypeString(dtype), 649 " Feature is: ", ProtoDebugString(f))); 650 651 OP_REQUIRES_OK(ctx, FeatureDenseCopy(0, name, key, dtype, shape, f, 652 context_dense_values[d])); 653 } else { 654 RowDenseCopy(0, dtype, context_dense_defaults[d], 655 context_dense_values[d]); 656 } 657 } 658 659 // Context Sparse ---------------------------------------------------------- 660 for (int d = 0; d < attrs_.num_context_sparse; ++d) { 661 const string& key = context_sparse_keys_t[d]; 662 const DataType& dtype = attrs_.context_sparse_types[d]; 663 664 const auto& feature_found = context_dict.find(key); 665 bool feature_has_data = // Found key & data type is set 666 (feature_found != context_dict.end() && 667 (feature_found->second.kind_case() != Feature::KIND_NOT_SET)); 668 669 if (feature_has_data) { 670 const Feature& f = feature_found->second; 671 bool types_match; 672 OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match)); 673 OP_REQUIRES( 674 ctx, types_match, 675 errors::InvalidArgument("Name: ", name, ", Context feature: ", key, 676 ". Data types don't match. ", 677 "Expected type: ", DataTypeString(dtype), 678 " Feature is: ", ProtoDebugString(f))); 679 680 Tensor feature_values = FeatureSparseCopy(0, key, dtype, f); 681 const int64 num_elements = feature_values.NumElements(); 682 TensorShape indices_shape({num_elements, 1}); 683 Tensor* sp_indices_d = nullptr; 684 Tensor* sp_shape_d = nullptr; 685 OP_REQUIRES_OK(ctx, context_sparse_indices.allocate(d, indices_shape, 686 &sp_indices_d)); 687 context_sparse_values.set(d, feature_values); 688 OP_REQUIRES_OK(ctx, context_sparse_shapes.allocate(d, TensorShape({1}), 689 &sp_shape_d)); 690 auto shape_t = sp_shape_d->vec<int64>(); 691 shape_t(0) = num_elements; 692 auto indices_t = sp_indices_d->matrix<int64>(); 693 std::iota(indices_t.data(), indices_t.data() + num_elements, 0); 694 } else { 695 TensorShape indices_shape({0, 1}); 696 TensorShape values_shape({0}); 697 Tensor* sp_indices_d = nullptr; 698 Tensor* sp_values_d = nullptr; 699 Tensor* sp_shape_d = nullptr; 700 OP_REQUIRES_OK(ctx, context_sparse_indices.allocate(d, indices_shape, 701 &sp_indices_d)); 702 OP_REQUIRES_OK( 703 ctx, context_sparse_values.allocate(d, values_shape, &sp_values_d)); 704 OP_REQUIRES_OK(ctx, context_sparse_shapes.allocate(d, TensorShape({1}), 705 &sp_shape_d)); 706 auto shape_t = sp_shape_d->vec<int64>(); 707 shape_t(0) = 0; 708 } 709 } 710 711 // Feature List Dense ------------------------------------------------------ 712 713 // Preallocate context_dense_values, since we can infer their 714 // sizes 715 const FeatureLists& feature_lists = ex.feature_lists(); 716 const auto& feature_list_dict = feature_lists.feature_list(); 717 FeatureList empty_feature_list; // Placeholder for missing FLs 718 719 for (int d = 0; d < attrs_.num_feature_list_dense; ++d) { 720 const string& key = feature_list_dense_keys_t[d]; 721 const DataType& dtype = attrs_.feature_list_dense_types[d]; 722 const TensorShape& shape = attrs_.feature_list_dense_shapes[d]; 723 724 const auto& feature_list_found = feature_list_dict.find(key); 725 bool feature_list_missing = 726 (feature_list_found == feature_list_dict.end()); 727 bool feature_list_allowed_missing = 728 (feature_list_dense_missing_assumed_empty_set.count(key) > 0); 729 730 OP_REQUIRES( 731 ctx, !feature_list_missing || feature_list_allowed_missing, 732 errors::InvalidArgument("Name: ", name, ", Feature list '", key, 733 "' is required but could not be found. " 734 "Did you mean to include it in " 735 "feature_list_dense_missing_assumed_empty or " 736 "feature_list_dense_defaults?")); 737 738 TensorShape out_shape; 739 const FeatureList& fl = (feature_list_missing) 740 ? empty_feature_list 741 : feature_list_found->second; 742 out_shape.AddDim(fl.feature_size()); 743 for (const int dim : attrs_.feature_list_dense_shapes[d].dim_sizes()) { 744 out_shape.AddDim(dim); 745 } 746 Tensor* out = nullptr; 747 OP_REQUIRES_OK(ctx, 748 feature_list_dense_values.allocate(d, out_shape, &out)); 749 750 for (int64 t = 0; t < fl.feature_size(); ++t) { 751 const Feature& f = fl.feature(t); 752 bool types_match; 753 OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match)); 754 OP_REQUIRES(ctx, types_match, 755 errors::InvalidArgument( 756 "Name: ", name, ", Feature list: ", key, ", Index: ", t, 757 ". Data types don't match. ", 758 "Expected type: ", DataTypeString(dtype), 759 " Feature is: ", ProtoDebugString(f))); 760 OP_REQUIRES_OK(ctx, FeatureDenseCopy(t, name, key, dtype, shape, f, 761 feature_list_dense_values[d])); 762 } 763 } 764 765 // Feature List Sparse ----------------------------------------------------- 766 for (int d = 0; d < attrs_.num_feature_list_sparse; ++d) { 767 const string& key = feature_list_sparse_keys_t[d]; 768 const DataType& dtype = attrs_.feature_list_sparse_types[d]; 769 770 const auto& feature_list_found = feature_list_dict.find(key); 771 bool feature_list_has_data = // Found key 772 (feature_list_found != feature_list_dict.end()); 773 774 std::vector<Tensor> sparse_values_tmp; 775 int64 feature_list_size = 0; 776 if (feature_list_has_data) { 777 const FeatureList& fl = feature_list_found->second; 778 feature_list_size = fl.feature_size(); 779 for (int64 t = 0; t < feature_list_size; ++t) { 780 const Feature& f = fl.feature(t); 781 bool types_match; 782 OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match)); 783 OP_REQUIRES( 784 ctx, f.kind_case() == Feature::KIND_NOT_SET || types_match, 785 errors::InvalidArgument("Name: ", name, ", Feature List: ", key, 786 ", Index: ", t, 787 ". Data types don't match. ", 788 "Expected type: ", DataTypeString(dtype), 789 " Feature is: ", ProtoDebugString(f))); 790 sparse_values_tmp.push_back(FeatureSparseCopy(t, key, dtype, f)); 791 } 792 } else { 793 sparse_values_tmp.push_back(Tensor(dtype, TensorShape({0}))); 794 } 795 796 int64 total_num_features = 0; 797 int64 max_num_features = 0; 798 for (int t = 0; t < feature_list_size; ++t) { 799 const Tensor& v = sparse_values_tmp[t]; 800 const int64 num_elements = v.shape().num_elements(); 801 total_num_features += num_elements; 802 max_num_features = std::max(max_num_features, num_elements); 803 } 804 805 TensorShape indices_shape({total_num_features, 2}); 806 TensorShape values_shape({total_num_features}); 807 Tensor* sp_indices_d = nullptr; 808 Tensor* sp_values_d = nullptr; 809 Tensor* sp_shape_d = nullptr; 810 OP_REQUIRES_OK(ctx, feature_list_sparse_indices.allocate(d, indices_shape, 811 &sp_indices_d)); 812 OP_REQUIRES_OK(ctx, feature_list_sparse_values.allocate(d, values_shape, 813 &sp_values_d)); 814 OP_REQUIRES_OK(ctx, feature_list_sparse_shapes.allocate( 815 d, TensorShape({2}), &sp_shape_d)); 816 auto shape_t = sp_shape_d->vec<int64>(); 817 shape_t(0) = feature_list_size; 818 shape_t(1) = max_num_features; 819 820 int64 offset = 0; 821 822 for (int t = 0; t < feature_list_size; ++t) { 823 const int64 num_elements = CopyIntoSparseTensor( 824 sparse_values_tmp[t], t, offset, sp_indices_d, sp_values_d); 825 offset += num_elements; 826 } 827 } 828 } 829 830 protected: 831 ParseSingleSequenceExampleAttrs attrs_; 832 }; 833 834 REGISTER_KERNEL_BUILDER(Name("ParseSingleSequenceExample").Device(DEVICE_CPU), 835 ParseSingleSequenceExampleOp); 836 837 #ifndef IS_MOBILE_PLATFORM 838 // when using lite protos on mobile, decoding JSON is not available. 839 840 class DecodeJSONExampleOp : public OpKernel { 841 public: DecodeJSONExampleOp(OpKernelConstruction * ctx)842 explicit DecodeJSONExampleOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 843 resolver_.reset(protobuf::util::NewTypeResolverForDescriptorPool( 844 "type.googleapis.com", protobuf::DescriptorPool::generated_pool())); 845 } 846 Compute(OpKernelContext * ctx)847 void Compute(OpKernelContext* ctx) { 848 const Tensor* json_examples; 849 OP_REQUIRES_OK(ctx, ctx->input("json_examples", &json_examples)); 850 Tensor* binary_examples; 851 OP_REQUIRES_OK( 852 ctx, ctx->allocate_output("binary_examples", json_examples->shape(), 853 &binary_examples)); 854 855 for (int i = 0; i < json_examples->NumElements(); ++i) { 856 const string& json_example = json_examples->flat<string>()(i); 857 auto status = protobuf::util::JsonToBinaryString( 858 resolver_.get(), "type.googleapis.com/tensorflow.Example", 859 json_example, &binary_examples->flat<string>()(i)); 860 OP_REQUIRES(ctx, status.ok(), 861 errors::InvalidArgument("Error while parsing JSON: ", 862 string(status.error_message()))); 863 } 864 } 865 866 private: 867 std::unique_ptr<protobuf::util::TypeResolver> resolver_; 868 }; 869 870 REGISTER_KERNEL_BUILDER(Name("DecodeJSONExample").Device(DEVICE_CPU), 871 DecodeJSONExampleOp); 872 #endif 873 874 } // namespace tensorflow 875