1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/core/framework/common_shape_fns.h" 16 #include "tensorflow/core/framework/op.h" 17 #include "tensorflow/core/framework/op_def_builder.h" 18 #include "tensorflow/core/framework/shape_inference.h" 19 20 namespace tensorflow { 21 22 // -------------------------------------------------------------------------- 23 24 // The ops in this section can be composed to define an input 25 // pipeline. Each op produces a DT_VARIANT tensor that represents 26 // a DAG of "dataset" objects. An "dataset" object can be converted 27 // to a stateful "iterator" by passing the "dataset" to the 28 // "MakeIterator" op. 29 // 30 // TODO(b/123753214): DT_VARIANT tensors that represent "dataset" objects are 31 // not presently serializable. To avoid issues with constant folding, ensure 32 // that any "source dataset" ops (i.e. ops that output a dataset and do not 33 // take one as input) are marked "stateful". 34 35 REGISTER_OP("TensorDataset") 36 .Input("components: Toutput_types") 37 .Output("handle: variant") 38 .Attr("Toutput_types: list(type) >= 1") 39 .Attr("output_shapes: list(shape) >= 1") 40 .SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must 41 // disable constant folding. 42 .SetShapeFn(shape_inference::ScalarShape); // TODO(mrry): Validate that 43 // `components` have shapes 44 // compatible with 45 // `output_shapes`. 46 47 REGISTER_OP("TensorSliceDataset") 48 .Input("components: Toutput_types") 49 .Output("handle: variant") 50 .Attr("Toutput_types: list(type) >= 1") 51 .Attr("output_shapes: list(shape) >= 1") 52 .SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must 53 // disable constant folding. 54 .SetShapeFn(shape_inference::ScalarShape); // TODO(mrry): Validate that the 55 // dim-0 slices of `components` 56 // have shapes compatible with 57 // `output_shapes`. 58 59 REGISTER_OP("SparseTensorSliceDataset") 60 .Input("indices: int64") 61 .Input("values: Tvalues") 62 .Input("dense_shape: int64") 63 .Output("handle: variant") 64 .Attr("Tvalues: type") 65 .SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must 66 // disable constant folding. 67 .SetShapeFn(shape_inference::ScalarShape); 68 69 REGISTER_OP("GeneratorDataset") 70 .Input("init_func_other_args: Tinit_func_args") 71 .Input("next_func_other_args: Tnext_func_args") 72 .Input("finalize_func_other_args: Tfinalize_func_args") 73 .Output("handle: variant") 74 .Attr("init_func: func") 75 .Attr("next_func: func") 76 .Attr("finalize_func: func") 77 .Attr("Tinit_func_args: list(type) >= 0") 78 .Attr("Tnext_func_args: list(type) >= 0") 79 .Attr("Tfinalize_func_args: list(type) >= 0") 80 .Attr("output_types: list(type) >= 1") 81 .Attr("output_shapes: list(shape) >= 1") 82 .SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must 83 // disable constant folding. 84 .SetShapeFn(shape_inference::ScalarShape); 85 86 REGISTER_OP("ZipDataset") 87 .Input("input_datasets: N * variant") 88 .Output("handle: variant") 89 .Attr("output_types: list(type) >= 1") 90 .Attr("output_shapes: list(shape) >= 1") 91 .Attr("N: int >= 1") 92 .SetShapeFn(shape_inference::ScalarShape); 93 94 REGISTER_OP("ConcatenateDataset") 95 .Input("input_dataset: variant") 96 .Input("another_dataset: variant") 97 .Output("handle: variant") 98 .Attr("output_types: list(type) >= 1") 99 .Attr("output_shapes: list(shape) >= 1") 100 .SetShapeFn(shape_inference::ScalarShape); 101 102 REGISTER_OP("RepeatDataset") 103 .Input("input_dataset: variant") 104 .Input("count: int64") 105 .Output("handle: variant") 106 .Attr("output_types: list(type) >= 1") 107 .Attr("output_shapes: list(shape) >= 1") __anon42b7b0c00102(shape_inference::InferenceContext* c) 108 .SetShapeFn([](shape_inference::InferenceContext* c) { 109 shape_inference::ShapeHandle count_shape; 110 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &count_shape)); 111 return shape_inference::ScalarShape(c); 112 }); 113 114 REGISTER_OP("TakeDataset") 115 .Input("input_dataset: variant") 116 .Input("count: int64") 117 .Output("handle: variant") 118 .Attr("output_types: list(type) >= 1") 119 .Attr("output_shapes: list(shape) >= 1") __anon42b7b0c00202(shape_inference::InferenceContext* c) 120 .SetShapeFn([](shape_inference::InferenceContext* c) { 121 shape_inference::ShapeHandle count_shape; 122 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &count_shape)); 123 return shape_inference::ScalarShape(c); 124 }); 125 126 REGISTER_OP("SkipDataset") 127 .Input("input_dataset: variant") 128 .Input("count: int64") 129 .Output("handle: variant") 130 .Attr("output_types: list(type) >= 1") 131 .Attr("output_shapes: list(shape) >= 1") __anon42b7b0c00302(shape_inference::InferenceContext* c) 132 .SetShapeFn([](shape_inference::InferenceContext* c) { 133 shape_inference::ShapeHandle count_shape; 134 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &count_shape)); 135 return shape_inference::ScalarShape(c); 136 }); 137 138 REGISTER_OP("MapDataset") 139 .Input("input_dataset: variant") 140 .Input("other_arguments: Targuments") 141 .Output("handle: variant") 142 .Attr("f: func") 143 .Attr("Targuments: list(type) >= 0") 144 .Attr("output_types: list(type) >= 1") 145 .Attr("output_shapes: list(shape) >= 1") 146 .Attr("use_inter_op_parallelism: bool = true") 147 .Attr("preserve_cardinality: bool = false") 148 .SetShapeFn(shape_inference::ScalarShape); 149 150 REGISTER_OP("ParallelMapDataset") 151 .Input("input_dataset: variant") 152 .Input("other_arguments: Targuments") 153 .Input("num_parallel_calls: int32") 154 .Output("handle: variant") 155 .Attr("f: func") 156 .Attr("Targuments: list(type) >= 0") 157 .Attr("output_types: list(type) >= 1") 158 .Attr("output_shapes: list(shape) >= 1") 159 .Attr("use_inter_op_parallelism: bool = true") 160 .Attr("sloppy: bool = false") 161 .Attr("preserve_cardinality: bool = false") 162 .SetShapeFn(shape_inference::ScalarShape); 163 164 REGISTER_OP("ParallelMapDatasetV2") 165 .Input("input_dataset: variant") 166 .Input("other_arguments: Targuments") 167 .Input("num_parallel_calls: int64") 168 .Output("handle: variant") 169 .Attr("f: func") 170 .Attr("Targuments: list(type) >= 0") 171 .Attr("output_types: list(type) >= 1") 172 .Attr("output_shapes: list(shape) >= 1") 173 .Attr("use_inter_op_parallelism: bool = true") 174 // "true", "false", or "default". 175 .Attr("deterministic: string = 'default'") 176 .Attr("preserve_cardinality: bool = false") 177 .SetShapeFn(shape_inference::ScalarShape); 178 179 REGISTER_OP("PrefetchDataset") 180 .Input("input_dataset: variant") 181 .Input("buffer_size: int64") 182 .Output("handle: variant") 183 .Attr("output_types: list(type) >= 1") 184 .Attr("output_shapes: list(shape) >= 1") 185 .Attr("slack_period: int = 0") 186 .Attr("legacy_autotune: bool = true") 187 .Attr("buffer_size_min: int = 0") __anon42b7b0c00402(shape_inference::InferenceContext* c) 188 .SetShapeFn([](shape_inference::InferenceContext* c) { 189 shape_inference::ShapeHandle unused; 190 // buffer_size should be a scalar. 191 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 192 return shape_inference::ScalarShape(c); 193 }); 194 195 REGISTER_OP("FlatMapDataset") 196 .Input("input_dataset: variant") 197 .Input("other_arguments: Targuments") 198 .Output("handle: variant") 199 .Attr("f: func") 200 .Attr("Targuments: list(type) >= 0") 201 .Attr("output_types: list(type) >= 1") 202 .Attr("output_shapes: list(shape) >= 1") 203 .SetShapeFn(shape_inference::ScalarShape); 204 205 REGISTER_OP("InterleaveDataset") 206 .Input("input_dataset: variant") 207 .Input("other_arguments: Targuments") 208 .Input("cycle_length: int64") 209 .Input("block_length: int64") 210 .Output("handle: variant") 211 .Attr("f: func") 212 .Attr("Targuments: list(type) >= 0") 213 .Attr("output_types: list(type) >= 1") 214 .Attr("output_shapes: list(shape) >= 1") 215 .SetShapeFn(shape_inference::ScalarShape); 216 217 REGISTER_OP("ParallelInterleaveDatasetV2") 218 .Input("input_dataset: variant") 219 .Input("other_arguments: Targuments") 220 .Input("cycle_length: int64") 221 .Input("block_length: int64") 222 .Input("num_parallel_calls: int64") 223 .Output("handle: variant") 224 .Attr("f: func") 225 .Attr("Targuments: list(type) >= 0") 226 .Attr("output_types: list(type) >= 1") 227 .Attr("output_shapes: list(shape) >= 1") 228 .Attr("sloppy: bool = false") 229 .SetShapeFn(shape_inference::ScalarShape); 230 231 REGISTER_OP("ParallelInterleaveDatasetV3") 232 .Input("input_dataset: variant") 233 .Input("other_arguments: Targuments") 234 .Input("cycle_length: int64") 235 .Input("block_length: int64") 236 .Input("num_parallel_calls: int64") 237 .Output("handle: variant") 238 .Attr("f: func") 239 // "true", "false", or "default". 240 .Attr("deterministic: string = 'default'") 241 .Attr("Targuments: list(type) >= 0") 242 .Attr("output_types: list(type) >= 1") 243 .Attr("output_shapes: list(shape) >= 1") 244 .SetShapeFn(shape_inference::ScalarShape); 245 246 // Like V3, but adds buffer_output_elements and prefetch_input_elements. 247 REGISTER_OP("ParallelInterleaveDatasetV4") 248 .Input("input_dataset: variant") 249 .Input("other_arguments: Targuments") 250 .Input("cycle_length: int64") 251 .Input("block_length: int64") 252 .Input("buffer_output_elements: int64") 253 .Input("prefetch_input_elements: int64") 254 .Input("num_parallel_calls: int64") 255 .Output("handle: variant") 256 .Attr("f: func") 257 // "true", "false", or "default". 258 .Attr("deterministic: string = 'default'") 259 .Attr("Targuments: list(type) >= 0") 260 .Attr("output_types: list(type) >= 1") 261 .Attr("output_shapes: list(shape) >= 1") 262 .SetShapeFn(shape_inference::ScalarShape); 263 264 REGISTER_OP("FilterDataset") 265 .Input("input_dataset: variant") 266 .Input("other_arguments: Targuments") 267 .Output("handle: variant") 268 .Attr("predicate: func") 269 .Attr("Targuments: list(type) >= 0") 270 .Attr("output_types: list(type) >= 1") 271 .Attr("output_shapes: list(shape) >= 1") 272 .SetShapeFn(shape_inference::ScalarShape); 273 274 // This op is no longer supported. 275 REGISTER_OP("FilterByLastComponentDataset") 276 .Input("input_dataset: variant") 277 .Output("output: variant") 278 .Attr("output_types: list(type) >= 1") 279 .Attr("output_shapes: list(shape) >= 1") 280 .SetShapeFn(shape_inference::ScalarShape); 281 282 REGISTER_OP("WindowDataset") 283 .Input("input_dataset: variant") 284 .Input("size: int64") 285 .Input("shift: int64") 286 .Input("stride: int64") 287 .Input("drop_remainder: bool") 288 .Output("handle: variant") 289 .Attr("output_types: list(type) >= 1") 290 .Attr("output_shapes: list(shape) >= 1") __anon42b7b0c00502(shape_inference::InferenceContext* c) 291 .SetShapeFn([](shape_inference::InferenceContext* c) { 292 shape_inference::ShapeHandle unused; 293 // size, shift, stride, and drop_remainder should be scalars. 294 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 295 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 296 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 297 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 298 return shape_inference::ScalarShape(c); 299 }); 300 301 REGISTER_OP("BatchDataset") 302 .Input("input_dataset: variant") 303 .Input("batch_size: int64") 304 .Output("handle: variant") 305 .Attr("output_types: list(type) >= 1") 306 .Attr("output_shapes: list(shape) >= 1") __anon42b7b0c00602(shape_inference::InferenceContext* c) 307 .SetShapeFn([](shape_inference::InferenceContext* c) { 308 shape_inference::ShapeHandle unused; 309 // batch_size should be a scalar. 310 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 311 return shape_inference::ScalarShape(c); 312 }); 313 314 REGISTER_OP("BatchDatasetV2") 315 .Input("input_dataset: variant") 316 .Input("batch_size: int64") 317 .Input("drop_remainder: bool") 318 .Output("handle: variant") 319 .Attr("parallel_copy: bool = false") 320 .Attr("output_types: list(type) >= 1") 321 .Attr("output_shapes: list(shape) >= 1") __anon42b7b0c00702(shape_inference::InferenceContext* c) 322 .SetShapeFn([](shape_inference::InferenceContext* c) { 323 shape_inference::ShapeHandle unused; 324 // batch_size should be a scalar. 325 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 326 // drop_remainder should be a scalar. 327 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 328 return shape_inference::ScalarShape(c); 329 }); 330 331 REGISTER_OP("ParallelBatchDataset") 332 .Input("input_dataset: variant") 333 .Input("batch_size: int64") 334 .Input("num_parallel_calls: int64") 335 .Input("drop_remainder: bool") 336 .Output("handle: variant") 337 .Attr("output_types: list(type) >= 1") 338 .Attr("output_shapes: list(shape) >= 1") __anon42b7b0c00802(shape_inference::InferenceContext* c) 339 .SetShapeFn([](shape_inference::InferenceContext* c) { 340 shape_inference::ShapeHandle unused; 341 // batch_size should be a scalar. 342 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 343 // num_parallel_calls should be a scalar. 344 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 345 // drop_remainder should be a scalar. 346 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 347 return shape_inference::ScalarShape(c); 348 }); 349 350 REGISTER_OP("ShardDataset") 351 .Input("input_dataset: variant") 352 .Input("num_shards: int64") 353 .Input("index: int64") 354 .Output("handle: variant") 355 .Attr("require_non_empty: bool = false") 356 .Attr("output_types: list(type) >= 1") 357 .Attr("output_shapes: list(shape) >= 1") __anon42b7b0c00902(shape_inference::InferenceContext* c) 358 .SetShapeFn([](shape_inference::InferenceContext* c) { 359 shape_inference::ShapeHandle unused; 360 // num_shards should be a scalar. 361 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 362 // index should be a scalar. 363 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 364 return shape_inference::ScalarShape(c); 365 }); 366 367 // TODO(mrry): Validate that `padded_shapes` are all vectors, the lengths of 368 // `output_types` and `output_shapes` are `N` the `output_shapes` are (as far as 369 // possible to tell statically) compatible with `padded_shapes`, and that 370 // `padding_values` are all scalars. 371 REGISTER_OP("PaddedBatchDataset") 372 .Input("input_dataset: variant") 373 .Input("batch_size: int64") 374 .Input("padded_shapes: N * int64") 375 .Input("padding_values: Toutput_types") 376 .Output("handle: variant") 377 .Attr("Toutput_types: list(type) >= 1") 378 .Attr("output_shapes: list(shape) >= 1") 379 .Attr("N: int >= 1") __anon42b7b0c00a02(shape_inference::InferenceContext* c) 380 .SetShapeFn([](shape_inference::InferenceContext* c) { 381 shape_inference::ShapeHandle unused; 382 // batch_size should be a scalar. 383 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 384 return shape_inference::ScalarShape(c); 385 }); 386 387 REGISTER_OP("PaddedBatchDatasetV2") 388 .Input("input_dataset: variant") 389 .Input("batch_size: int64") 390 .Input("padded_shapes: N * int64") 391 .Input("padding_values: Toutput_types") 392 .Input("drop_remainder: bool") 393 .Output("handle: variant") 394 .Attr("parallel_copy: bool = false") 395 .Attr("Toutput_types: list(type) >= 1") 396 .Attr("output_shapes: list(shape) >= 1") 397 .Attr("N: int >= 1") __anon42b7b0c00b02(shape_inference::InferenceContext* c) 398 .SetShapeFn([](shape_inference::InferenceContext* c) { 399 shape_inference::ShapeHandle unused; 400 // batch_size should be a scalar. 401 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 402 // drop_remainder should be a scalar. 403 TF_RETURN_IF_ERROR( 404 c->WithRank(c->input(c->num_inputs() - 1), 0, &unused)); 405 return shape_inference::ScalarShape(c); 406 }); 407 408 REGISTER_OP("RangeDataset") 409 .Input("start: int64") 410 .Input("stop: int64") 411 .Input("step: int64") 412 .Output("handle: variant") 413 .Attr("output_types: list(type) >= 1") 414 .Attr("output_shapes: list(shape) >= 1") 415 .SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must 416 // disable constant folding. __anon42b7b0c00c02(shape_inference::InferenceContext* c) 417 .SetShapeFn([](shape_inference::InferenceContext* c) { 418 shape_inference::ShapeHandle unused; 419 // start, stop, and step should be scalars. 420 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); 421 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 422 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 423 return shape_inference::ScalarShape(c); 424 }); 425 426 REGISTER_OP("AnonymousSeedGenerator") 427 .Input("seed: int64") 428 .Input("seed2: int64") 429 .Input("reshuffle: bool") 430 .Output("handle: resource") 431 .Output("deleter: variant") __anon42b7b0c00d02(shape_inference::InferenceContext* c) 432 .SetShapeFn([](shape_inference::InferenceContext* c) { 433 c->set_output(0, c->Scalar()); 434 c->set_output(1, c->Scalar()); 435 return Status::OK(); 436 }); 437 438 REGISTER_OP("DatasetCardinality") 439 .Input("input_dataset: variant") 440 .Output("cardinality: int64") 441 .SetShapeFn(shape_inference::ScalarShape); 442 443 REGISTER_OP("DeleteSeedGenerator") 444 .Input("handle: resource") 445 .Input("deleter: variant") 446 .SetShapeFn(shape_inference::NoOutputs); 447 448 // Deprecated in favor of AnonymousSeedGenerator/DeleteSeedGenerator. 449 REGISTER_OP("AnonymousRandomSeedGenerator") 450 .Input("seed: int64") 451 .Input("seed2: int64") 452 .Output("handle: resource") 453 .Output("deleter: variant") __anon42b7b0c00e02(shape_inference::InferenceContext* c) 454 .SetShapeFn([](shape_inference::InferenceContext* c) { 455 c->set_output(0, c->Scalar()); 456 c->set_output(1, c->Scalar()); 457 return Status::OK(); 458 }); 459 460 // Deprecated in favor of AnonymousSeedGenerator/DeleteSeedGenerator. 461 REGISTER_OP("DeleteRandomSeedGenerator") 462 .Input("handle: resource") 463 .Input("deleter: variant") 464 .SetShapeFn(shape_inference::NoOutputs); 465 466 REGISTER_OP("DummySeedGenerator") 467 .Output("handle: resource") __anon42b7b0c00f02(shape_inference::InferenceContext* c) 468 .SetShapeFn([](shape_inference::InferenceContext* c) { 469 c->set_output(0, c->Scalar()); 470 return Status::OK(); 471 }); 472 473 REGISTER_OP("ShuffleDataset") 474 .Input("input_dataset: variant") 475 .Input("buffer_size: int64") 476 .Input("seed: int64") 477 .Input("seed2: int64") 478 .Output("handle: variant") 479 .Attr("reshuffle_each_iteration: bool = true") 480 .Attr("output_types: list(type) >= 1") 481 .Attr("output_shapes: list(shape) >= 1") __anon42b7b0c01002(shape_inference::InferenceContext* c) 482 .SetShapeFn([](shape_inference::InferenceContext* c) { 483 shape_inference::ShapeHandle unused; 484 // buffer_size, seed, and seed2 should be scalars. 485 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 486 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 487 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 488 return shape_inference::ScalarShape(c); 489 }); 490 491 REGISTER_OP("ShuffleDatasetV2") 492 .Input("input_dataset: variant") 493 .Input("buffer_size: int64") 494 .Input("seed_generator: resource") 495 .Output("handle: variant") 496 .Attr("output_types: list(type) >= 1") 497 .Attr("output_shapes: list(shape) >= 1") __anon42b7b0c01102(shape_inference::InferenceContext* c) 498 .SetShapeFn([](shape_inference::InferenceContext* c) { 499 shape_inference::ShapeHandle unused; 500 // buffer_size and seed_generator should be scalars. 501 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 502 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 503 return shape_inference::ScalarShape(c); 504 }); 505 506 REGISTER_OP("ShuffleDatasetV3") 507 .Input("input_dataset: variant") 508 .Input("buffer_size: int64") 509 .Input("seed: int64") 510 .Input("seed2: int64") 511 .Input("seed_generator: resource") 512 .Output("handle: variant") 513 .Attr("reshuffle_each_iteration: bool = true") 514 .Attr("output_types: list(type) >= 1") 515 .Attr("output_shapes: list(shape) >= 1") __anon42b7b0c01202(shape_inference::InferenceContext* c) 516 .SetShapeFn([](shape_inference::InferenceContext* c) { 517 shape_inference::ShapeHandle unused; 518 // buffer_size, seed, seed2, and seed_generator should be scalars. 519 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 520 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 521 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 522 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 523 return shape_inference::ScalarShape(c); 524 }); 525 526 REGISTER_OP("ShuffleAndRepeatDataset") 527 .Input("input_dataset: variant") 528 .Input("buffer_size: int64") 529 .Input("seed: int64") 530 .Input("seed2: int64") 531 .Input("count: int64") 532 .Output("handle: variant") 533 .Attr("output_types: list(type) >= 1") 534 .Attr("output_shapes: list(shape) >= 1") 535 .Attr("reshuffle_each_iteration: bool = true") __anon42b7b0c01302(shape_inference::InferenceContext* c) 536 .SetShapeFn([](shape_inference::InferenceContext* c) { 537 shape_inference::ShapeHandle unused; 538 // buffer_size, seed, seed2, and count should be scalars. 539 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 540 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 541 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 542 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 543 return shape_inference::ScalarShape(c); 544 }); 545 546 REGISTER_OP("ShuffleAndRepeatDatasetV2") 547 .Input("input_dataset: variant") 548 .Input("buffer_size: int64") 549 .Input("seed: int64") 550 .Input("seed2: int64") 551 .Input("count: int64") 552 .Input("seed_generator: resource") 553 .Output("handle: variant") 554 .Attr("reshuffle_each_iteration: bool = true") 555 .Attr("output_types: list(type) >= 1") 556 .Attr("output_shapes: list(shape) >= 1") __anon42b7b0c01402(shape_inference::InferenceContext* c) 557 .SetShapeFn([](shape_inference::InferenceContext* c) { 558 shape_inference::ShapeHandle unused; 559 // buffer_size, seed, seed2, count, and seed_generator should be scalars. 560 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 561 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 562 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 563 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 564 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); 565 return shape_inference::ScalarShape(c); 566 }); 567 568 REGISTER_OP("AnonymousMemoryCache") 569 .Output("handle: resource") 570 .Output("deleter: variant") __anon42b7b0c01502(shape_inference::InferenceContext* c) 571 .SetShapeFn([](shape_inference::InferenceContext* c) { 572 c->set_output(0, c->Scalar()); 573 c->set_output(1, c->Scalar()); 574 return Status::OK(); 575 }); 576 577 REGISTER_OP("DeleteMemoryCache") 578 .Input("handle: resource") 579 .Input("deleter: variant") 580 .SetShapeFn(shape_inference::NoOutputs); 581 582 REGISTER_OP("DummyMemoryCache") 583 .Output("handle: resource") __anon42b7b0c01602(shape_inference::InferenceContext* c) 584 .SetShapeFn([](shape_inference::InferenceContext* c) { 585 c->set_output(0, c->Scalar()); 586 return Status::OK(); 587 }); 588 589 REGISTER_OP("CacheDataset") 590 .Input("input_dataset: variant") 591 .Input("filename: string") 592 .Output("handle: variant") 593 .Attr("output_types: list(type) >= 1") 594 .Attr("output_shapes: list(shape) >= 1") __anon42b7b0c01702(shape_inference::InferenceContext* c) 595 .SetShapeFn([](shape_inference::InferenceContext* c) { 596 shape_inference::ShapeHandle unused; 597 // filename should be a scalar. 598 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 599 return shape_inference::ScalarShape(c); 600 }); 601 602 REGISTER_OP("CacheDatasetV2") 603 .Input("input_dataset: variant") 604 .Input("filename: string") 605 .Input("cache: resource") 606 .Output("handle: variant") 607 .Attr("output_types: list(type) >= 1") 608 .Attr("output_shapes: list(shape) >= 1") __anon42b7b0c01802(shape_inference::InferenceContext* c) 609 .SetShapeFn([](shape_inference::InferenceContext* c) { 610 shape_inference::ShapeHandle unused; 611 // filename should be a scalar. 612 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 613 // cache should be a scalar. 614 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 615 return shape_inference::ScalarShape(c); 616 }); 617 618 REGISTER_OP("TextLineDataset") 619 .Input("filenames: string") 620 .Input("compression_type: string") 621 .Input("buffer_size: int64") 622 .Output("handle: variant") 623 .SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must 624 // disable constant folding. __anon42b7b0c01902(shape_inference::InferenceContext* c) 625 .SetShapeFn([](shape_inference::InferenceContext* c) { 626 shape_inference::ShapeHandle unused; 627 // `filenames` must be a scalar or a vector. 628 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); 629 // `compression_type` could only be a scalar. 630 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 631 // `buffer_size` could only be a scalar. 632 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 633 return shape_inference::ScalarShape(c); 634 }); 635 636 REGISTER_OP("FixedLengthRecordDataset") 637 .Input("filenames: string") 638 .Input("header_bytes: int64") 639 .Input("record_bytes: int64") 640 .Input("footer_bytes: int64") 641 .Input("buffer_size: int64") 642 .Output("handle: variant") 643 .SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must 644 // disable constant folding. __anon42b7b0c01a02(shape_inference::InferenceContext* c) 645 .SetShapeFn([](shape_inference::InferenceContext* c) { 646 shape_inference::ShapeHandle unused; 647 // `filenames` must be a scalar or a vector. 648 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); 649 // header_bytes, record_bytes, footer_bytes, buffer_size should be 650 // scalars. 651 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 652 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 653 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 654 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 655 return shape_inference::ScalarShape(c); 656 }); 657 658 REGISTER_OP("FixedLengthRecordDatasetV2") 659 .Input("filenames: string") 660 .Input("header_bytes: int64") 661 .Input("record_bytes: int64") 662 .Input("footer_bytes: int64") 663 .Input("buffer_size: int64") 664 .Input("compression_type: string") 665 .Output("handle: variant") 666 .SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must 667 // disable constant folding. __anon42b7b0c01b02(shape_inference::InferenceContext* c) 668 .SetShapeFn([](shape_inference::InferenceContext* c) { 669 shape_inference::ShapeHandle unused; 670 // `filenames` must be a scalar or a vector. 671 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); 672 // header_bytes, record_bytes, footer_bytes, buffer_size should be 673 // scalars. 674 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 675 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 676 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 677 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 678 return shape_inference::ScalarShape(c); 679 }); 680 681 REGISTER_OP("TFRecordDataset") 682 .Input("filenames: string") 683 .Input("compression_type: string") 684 .Input("buffer_size: int64") 685 .Output("handle: variant") 686 .SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must 687 // disable constant folding. __anon42b7b0c01c02(shape_inference::InferenceContext* c) 688 .SetShapeFn([](shape_inference::InferenceContext* c) { 689 shape_inference::ShapeHandle unused; 690 // `filenames` must be a scalar or a vector. 691 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); 692 // `compression_type` could only be a scalar. 693 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 694 // `buffer_size` could only be a scalar. 695 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 696 return shape_inference::ScalarShape(c); 697 }); 698 699 REGISTER_OP("Iterator") 700 .Output("handle: resource") 701 .Attr("shared_name: string") 702 .Attr("container: string") 703 .Attr("output_types: list(type) >= 1") 704 .Attr("output_shapes: list(shape) >= 1") 705 .SetShapeFn(shape_inference::ScalarShape); 706 707 REGISTER_OP("IteratorV2") 708 .Output("handle: resource") 709 .Attr("shared_name: string") 710 .Attr("container: string") 711 .Attr("output_types: list(type) >= 1") 712 .Attr("output_shapes: list(shape) >= 1") 713 .SetShapeFn(shape_inference::ScalarShape); 714 715 REGISTER_OP("AnonymousIterator") 716 .Output("handle: resource") 717 .Attr("output_types: list(type) >= 1") 718 .Attr("output_shapes: list(shape) >= 1") 719 .SetShapeFn(shape_inference::ScalarShape); 720 721 REGISTER_OP("AnonymousIteratorV2") 722 .Output("handle: resource") 723 .Output("deleter: variant") 724 .Attr("output_types: list(type) >= 1") 725 .Attr("output_shapes: list(shape) >= 1") __anon42b7b0c01d02(shape_inference::InferenceContext* c) 726 .SetShapeFn([](shape_inference::InferenceContext* c) { 727 c->set_output(0, c->Scalar()); 728 c->set_output(1, c->Scalar()); 729 return Status::OK(); 730 }); 731 732 REGISTER_OP("DeleteIterator") 733 .Input("handle: resource") 734 .Input("deleter: variant") 735 .SetShapeFn(shape_inference::NoOutputs); 736 737 REGISTER_OP("DeleteMultiDeviceIterator") 738 .Input("multi_device_iterator: resource") 739 .Input("iterators: N * resource") 740 .Input("deleter: variant") 741 .Attr("N: int >= 0") 742 .SetShapeFn(shape_inference::NoOutputs); 743 744 REGISTER_OP("MakeIterator") 745 .Input("dataset: variant") 746 .Input("iterator: resource") 747 .SetShapeFn(shape_inference::NoOutputs); 748 749 REGISTER_OP("OneShotIterator") 750 .Output("handle: resource") 751 .Attr("dataset_factory: func") 752 .Attr("output_types: list(type) >= 1") 753 .Attr("output_shapes: list(shape) >= 1") 754 .Attr("container: string = ''") 755 .Attr("shared_name: string = ''") 756 .SetIsStateful() 757 .SetShapeFn(shape_inference::ScalarShape); 758 759 REGISTER_OP("IteratorGetNext") 760 .Input("iterator: resource") 761 .Output("components: output_types") 762 .Attr("output_types: list(type) >= 1") 763 .Attr("output_shapes: list(shape) >= 1") 764 .SetShapeFn(shape_inference::DatasetIteratorShape); 765 766 REGISTER_OP("IteratorGetNextSync") 767 .Input("iterator: resource") 768 .Output("components: output_types") 769 .Attr("output_types: list(type) >= 1") 770 .Attr("output_shapes: list(shape) >= 1") 771 .SetShapeFn(shape_inference::DatasetIteratorShape); 772 773 // TODO(b/124308596): Instead of conservatively marking this op as stateful, 774 // implement a mechanism to determine whether `dataset` has a side-effect 775 // and use it to decide whether to use a stateless or stateful version of this 776 // op. 777 REGISTER_OP("DatasetToSingleElement") 778 .Input("dataset: variant") 779 .Output("components: output_types") 780 .Attr("output_types: list(type) >= 1") 781 .Attr("output_shapes: list(shape) >= 1") 782 .SetIsStateful() 783 .SetShapeFn(shape_inference::DatasetIteratorShape); 784 785 // TODO(b/124308596): Instead of conservatively marking this op as stateful, 786 // implement a mechanism to determine whether `dataset` has a side-effect 787 // and use it to decide whether to use a stateless or stateful version of this 788 // op. 789 REGISTER_OP("ReduceDataset") 790 .Input("input_dataset: variant") 791 .Input("initial_state: Tstate") 792 .Input("other_arguments: Targuments") 793 .Output("components: output_types") 794 .Attr("f: func") 795 .Attr("Tstate: list(type) >= 1") 796 .Attr("Targuments: list(type) >= 0") 797 .Attr("output_types: list(type) >= 1") 798 .Attr("output_shapes: list(shape) >= 1") 799 .Attr("use_inter_op_parallelism: bool = true") 800 .SetIsStateful() 801 .SetShapeFn(shape_inference::DatasetIteratorShape); 802 803 REGISTER_OP("IteratorToStringHandle") 804 .Input("resource_handle: resource") 805 .Output("string_handle: string") 806 .SetShapeFn(shape_inference::ScalarShape); 807 808 REGISTER_OP("IteratorFromStringHandle") 809 .Input("string_handle: string") 810 .Output("resource_handle: resource") 811 .Attr("output_types: list(type) >= 0 = []") 812 .Attr("output_shapes: list(shape) >= 0 = []") 813 .SetShapeFn(shape_inference::ScalarShape); 814 815 REGISTER_OP("IteratorFromStringHandleV2") 816 .Input("string_handle: string") 817 .Output("resource_handle: resource") 818 .Attr("output_types: list(type) >= 0 = []") 819 .Attr("output_shapes: list(shape) >= 0 = []") 820 .SetShapeFn(shape_inference::ScalarShape); 821 822 REGISTER_OP("SerializeIterator") 823 .Input("resource_handle: resource") 824 .Attr("external_state_policy: int = 0") 825 .Output("serialized: variant") __anon42b7b0c01e02(shape_inference::InferenceContext* c) 826 .SetShapeFn([](shape_inference::InferenceContext* c) { 827 c->set_output(0, c->Vector(c->UnknownDim())); 828 return Status::OK(); 829 }); 830 831 REGISTER_OP("DeserializeIterator") 832 .Input("resource_handle: resource") 833 .Input("serialized: variant") 834 .SetShapeFn(shape_inference::NoOutputs); 835 836 REGISTER_OP("DatasetToGraph") 837 .Input("input_dataset: variant") 838 .Attr("stateful_whitelist: list(string) >= 0 = []") 839 .Attr("allow_stateful: bool = false") 840 .Attr("strip_device_assignment: bool = false") 841 .Output("graph: string") 842 .SetShapeFn(shape_inference::ScalarShape); 843 844 REGISTER_OP("DatasetToGraphV2") 845 .Input("input_dataset: variant") 846 .Attr("external_state_policy: int = 0") 847 .Attr("strip_device_assignment: bool = false") 848 .Output("graph: string") 849 .SetShapeFn(shape_inference::ScalarShape); 850 851 REGISTER_OP("OptimizeDataset") 852 .Input("input_dataset: variant") 853 .Input("optimizations: string") 854 .Output("handle: variant") 855 .Attr("output_types: list(type) >= 1") 856 .Attr("output_shapes: list(shape) >= 1") 857 .Attr("optimization_configs: list(string) = []") 858 .SetShapeFn(shape_inference::ScalarShape); 859 860 REGISTER_OP("OptimizeDatasetV2") 861 .Input("input_dataset: variant") 862 .Input("optimizations_enabled: string") 863 .Input("optimizations_disabled: string") 864 .Input("optimizations_default: string") 865 .Output("handle: variant") 866 .Attr("output_types: list(type) >= 1") 867 .Attr("output_shapes: list(shape) >= 1") 868 .Attr("optimization_configs: list(string) = []") 869 .SetShapeFn(shape_inference::ScalarShape); 870 871 REGISTER_OP("OptionalFromValue") 872 .Input("components: Toutput_types") 873 .Output("optional: variant") 874 .Attr("Toutput_types: list(type) >= 1") __anon42b7b0c01f02(shape_inference::InferenceContext* c) 875 .SetShapeFn([](shape_inference::InferenceContext* c) { 876 std::vector<DataType> dtypes; 877 TF_RETURN_IF_ERROR(c->GetAttr("Toutput_types", &dtypes)); 878 c->set_output(0, c->Scalar()); 879 std::vector<shape_inference::ShapeAndType> shapes_and_types; 880 shapes_and_types.reserve(c->num_inputs()); 881 for (int i = 0; i < c->num_inputs(); ++i) { 882 shapes_and_types.emplace_back(c->input(i), dtypes[i], ST_OPTIONAL); 883 } 884 c->set_output_handle_shapes_and_types(0, shapes_and_types); 885 return Status::OK(); 886 }); 887 888 REGISTER_OP("OptionalNone") 889 .Output("optional: variant") 890 .SetShapeFn(shape_inference::ScalarShape); 891 892 REGISTER_OP("OptionalHasValue") 893 .Input("optional: variant") 894 .Output("has_value: bool") 895 .SetShapeFn(shape_inference::ScalarShape); 896 897 REGISTER_OP("OptionalGetValue") 898 .Input("optional: variant") 899 .Output("components: output_types") 900 .Attr("output_types: list(type) >= 1") 901 .Attr("output_shapes: list(shape) >= 1") 902 .SetShapeFn(shape_inference::DatasetIteratorShape); 903 904 REGISTER_OP("IteratorGetNextAsOptional") 905 .Input("iterator: resource") 906 .Output("optional: variant") 907 .Attr("output_types: list(type) >= 1") 908 .Attr("output_shapes: list(shape) >= 1") 909 .SetShapeFn(shape_inference::ScalarShape); 910 911 REGISTER_OP("ModelDataset") 912 .Input("input_dataset: variant") 913 .Output("handle: variant") 914 .Attr("algorithm: int = 0") 915 .Attr("cpu_budget: int = 0") 916 .Attr("ram_budget: int = 0") 917 .Attr("output_types: list(type) >= 1") 918 .Attr("output_shapes: list(shape) >= 1") 919 .SetShapeFn(shape_inference::ScalarShape); 920 921 // TODO(b/124308749): Add a stateful version of MapDefun and use it when `f` 922 // is stateful. 923 REGISTER_OP("MapDefun") 924 .Input("arguments: Targuments") 925 .Input("captured_inputs: Tcaptured") 926 .Output("output: output_types") 927 .Attr("Targuments: list(type) >= 1") 928 .Attr("Tcaptured: list(type) >= 0 = []") 929 .Attr("output_types: list(type) >= 1") 930 .Attr("output_shapes: list(shape) >= 1") 931 .Attr("f: func") 932 .Attr("max_intra_op_parallelism: int = 1") __anon42b7b0c02002(shape_inference::InferenceContext* c) 933 .SetShapeFn([](shape_inference::InferenceContext* c) { 934 std::vector<PartialTensorShape> output_shapes; 935 TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); 936 DataTypeVector t_args; 937 TF_RETURN_IF_ERROR(c->GetAttr("Targuments", &t_args)); 938 if (output_shapes.size() != c->num_outputs()) { 939 return errors::InvalidArgument( 940 "`output_shapes` must be the same length as `output_types` (", 941 output_shapes.size(), " vs. ", c->num_outputs(), ")"); 942 } 943 944 int64 dim_zero = -1; 945 for (size_t i = 0; i < t_args.size(); ++i) { 946 if (c->Rank(c->input(i)) == 0) { 947 return errors::InvalidArgument( 948 "Arguments must have rank at least 1. Input ", i, 949 " has rank of 0."); 950 } 951 auto dim_handle = c->Dim(c->input(i), 0); 952 if (c->ValueKnown(dim_handle)) { 953 if (dim_zero == -1) { 954 dim_zero = c->Value(dim_handle); 955 } else if (c->Value(dim_handle) != dim_zero) { 956 return errors::InvalidArgument( 957 "Arguments must have the same dimension 0."); 958 } 959 } 960 } 961 962 for (size_t i = 0; i < output_shapes.size(); ++i) { 963 PartialTensorShape s({}); 964 s = s.Concatenate(dim_zero); 965 s = s.Concatenate(output_shapes[i]); 966 shape_inference::ShapeHandle output_shape_handle; 967 968 TF_RETURN_IF_ERROR( 969 c->MakeShapeFromPartialTensorShape(s, &output_shape_handle)); 970 c->set_output(static_cast<int>(i), output_shape_handle); 971 } 972 return Status::OK(); 973 }); 974 975 REGISTER_OP("WrapDatasetVariant") 976 .Input("input_handle: variant") 977 .Output("output_handle: variant") 978 .SetShapeFn(shape_inference::ScalarShape); 979 980 REGISTER_OP("UnwrapDatasetVariant") 981 .Input("input_handle: variant") 982 .Output("output_handle: variant") 983 .SetShapeFn(shape_inference::ScalarShape); 984 985 REGISTER_OP("AnonymousMultiDeviceIterator") 986 .Output("handle: resource") 987 .Output("deleter: variant") 988 .Attr("devices: list(string) >= 1") 989 .Attr("output_types: list(type) >= 1") 990 .Attr("output_shapes: list(shape) >= 1") __anon42b7b0c02102(shape_inference::InferenceContext* c) 991 .SetShapeFn([](shape_inference::InferenceContext* c) { 992 c->set_output(0, c->Scalar()); 993 c->set_output(1, c->Scalar()); 994 return Status::OK(); 995 }); 996 997 REGISTER_OP("MultiDeviceIterator") 998 .Output("handle: resource") 999 .Attr("devices: list(string) >= 1") 1000 .Attr("shared_name: string") 1001 .Attr("container: string") 1002 .Attr("output_types: list(type) >= 1") 1003 .Attr("output_shapes: list(shape) >= 1") 1004 .SetShapeFn(shape_inference::ScalarShape); 1005 1006 REGISTER_OP("MultiDeviceIteratorInit") 1007 .Input("dataset: variant") 1008 .Input("multi_device_iterator: resource") 1009 .Input("max_buffer_size: int64") 1010 .Output("incarnation_id: int64") 1011 .SetShapeFn(shape_inference::ScalarShape); 1012 1013 REGISTER_OP("MultiDeviceIteratorGetNextFromShard") 1014 .Input("multi_device_iterator: resource") 1015 .Input("shard_num: int32") 1016 .Input("incarnation_id: int64") 1017 .Output("components: output_types") 1018 .Attr("output_types: list(type) >= 1") 1019 .Attr("output_shapes: list(shape) >= 1") 1020 .SetShapeFn(shape_inference::DatasetIteratorShape); 1021 1022 REGISTER_OP("MultiDeviceIteratorToStringHandle") 1023 .Input("multi_device_iterator: resource") 1024 .Output("string_handle: string") 1025 .SetShapeFn(shape_inference::ScalarShape); 1026 1027 REGISTER_OP("MultiDeviceIteratorFromStringHandle") 1028 .Input("string_handle: string") 1029 .Output("multi_device_iterator: resource") 1030 .Attr("output_types: list(type) >= 0 = []") 1031 .Attr("output_shapes: list(shape) >= 0 = []") 1032 .SetShapeFn(shape_inference::ScalarShape); 1033 1034 } // namespace tensorflow 1035