1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/core/framework/common_shape_fns.h" 16 #include "tensorflow/core/framework/full_type.pb.h" 17 #include "tensorflow/core/framework/op.h" 18 #include "tensorflow/core/framework/op_def_builder.h" 19 #include "tensorflow/core/framework/shape_inference.h" 20 21 namespace tensorflow { 22 23 // -------------------------------------------------------------------------- 24 25 // The ops in this section can be composed to define an input 26 // pipeline. Each op produces a DT_VARIANT tensor that represents 27 // a DAG of "dataset" objects. A "dataset" object can be converted 28 // to a stateful "iterator" by passing the "dataset" to the 29 // "MakeIterator" op. 30 // 31 // TODO(b/123753214): DT_VARIANT tensors that represent "dataset" objects are 32 // not presently serializable. To avoid issues with graph optimizations, such 33 // as constant folding, CSE, or DCE, ensure that any "source dataset" ops 34 // (i.e. ops that output a dataset and do not take one as input) are 35 // marked as "do not optimize". 36 37 // TODO(mrry): Validate that `components` have shapes compatible with 38 // `output_shapes`. 
39 REGISTER_OP("TensorDataset") 40 .Input("components: Toutput_types") 41 .Output("handle: variant") 42 .Attr("Toutput_types: list(type) >= 1") 43 .Attr("output_shapes: list(shape) >= 1") 44 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. 45 .SetShapeFn(shape_inference::ScalarShape); 46 47 // TODO(mrry): Validate that the dim-0 slices of `components` have shapes 48 // compatible with `output_shapes`. 49 REGISTER_OP("TensorSliceDataset") 50 .Input("components: Toutput_types") 51 .Output("handle: variant") 52 .Attr("Toutput_types: list(type) >= 1") 53 .Attr("output_shapes: list(shape) >= 1") 54 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. 55 .SetShapeFn(shape_inference::ScalarShape); 56 57 REGISTER_OP("SparseTensorSliceDataset") 58 .Input("indices: int64") 59 .Input("values: Tvalues") 60 .Input("dense_shape: int64") 61 .Output("handle: variant") 62 .Attr("Tvalues: type") 63 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. 64 .SetShapeFn(shape_inference::ScalarShape); 65 66 REGISTER_OP("GeneratorDataset") 67 .Input("init_func_other_args: Tinit_func_args") 68 .Input("next_func_other_args: Tnext_func_args") 69 .Input("finalize_func_other_args: Tfinalize_func_args") 70 .Output("handle: variant") 71 .Attr("init_func: func") 72 .Attr("next_func: func") 73 .Attr("finalize_func: func") 74 .Attr("Tinit_func_args: list(type) >= 0") 75 .Attr("Tnext_func_args: list(type) >= 0") 76 .Attr("Tfinalize_func_args: list(type) >= 0") 77 .Attr("output_types: list(type) >= 1") 78 .Attr("output_shapes: list(shape) >= 1") 79 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. 
80 .SetShapeFn(shape_inference::ScalarShape); 81 82 REGISTER_OP("ZipDataset") 83 .Input("input_datasets: N * variant") 84 .Output("handle: variant") 85 .Attr("output_types: list(type) >= 1") 86 .Attr("output_shapes: list(shape) >= 1") 87 .Attr("N: int >= 1") 88 .SetShapeFn(shape_inference::ScalarShape); 89 90 REGISTER_OP("ConcatenateDataset") 91 .Input("input_dataset: variant") 92 .Input("another_dataset: variant") 93 .Output("handle: variant") 94 .Attr("output_types: list(type) >= 1") 95 .Attr("output_shapes: list(shape) >= 1") 96 .SetShapeFn(shape_inference::ScalarShape); 97 98 REGISTER_OP("RepeatDataset") 99 .Input("input_dataset: variant") 100 .Input("count: int64") 101 .Output("handle: variant") 102 .Attr("output_types: list(type) >= 1") 103 .Attr("output_shapes: list(shape) >= 1") __anon0b7d74c50102(shape_inference::InferenceContext* c) 104 .SetShapeFn([](shape_inference::InferenceContext* c) { 105 shape_inference::ShapeHandle count_shape; 106 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &count_shape)); 107 return shape_inference::ScalarShape(c); 108 }); 109 110 REGISTER_OP("TakeDataset") 111 .Input("input_dataset: variant") 112 .Input("count: int64") 113 .Output("handle: variant") 114 .Attr("output_types: list(type) >= 1") 115 .Attr("output_shapes: list(shape) >= 1") __anon0b7d74c50202(shape_inference::InferenceContext* c) 116 .SetShapeFn([](shape_inference::InferenceContext* c) { 117 shape_inference::ShapeHandle count_shape; 118 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &count_shape)); 119 return shape_inference::ScalarShape(c); 120 }); 121 122 REGISTER_OP("SkipDataset") 123 .Input("input_dataset: variant") 124 .Input("count: int64") 125 .Output("handle: variant") 126 .Attr("output_types: list(type) >= 1") 127 .Attr("output_shapes: list(shape) >= 1") __anon0b7d74c50302(shape_inference::InferenceContext* c) 128 .SetShapeFn([](shape_inference::InferenceContext* c) { 129 shape_inference::ShapeHandle count_shape; 130 
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &count_shape)); 131 return shape_inference::ScalarShape(c); 132 }); 133 134 REGISTER_OP("MapDataset") 135 .Input("input_dataset: variant") 136 .Input("other_arguments: Targuments") 137 .Output("handle: variant") 138 .Attr("f: func") 139 .Attr("Targuments: list(type) >= 0") 140 .Attr("output_types: list(type) >= 1") 141 .Attr("output_shapes: list(shape) >= 1") 142 .Attr("use_inter_op_parallelism: bool = true") 143 .Attr("preserve_cardinality: bool = false") 144 .SetShapeFn(shape_inference::ScalarShape); 145 146 REGISTER_OP("ParallelMapDataset") 147 .Input("input_dataset: variant") 148 .Input("other_arguments: Targuments") 149 .Input("num_parallel_calls: int32") 150 .Output("handle: variant") 151 .Attr("f: func") 152 .Attr("Targuments: list(type) >= 0") 153 .Attr("output_types: list(type) >= 1") 154 .Attr("output_shapes: list(shape) >= 1") 155 .Attr("use_inter_op_parallelism: bool = true") 156 .Attr("sloppy: bool = false") 157 .Attr("preserve_cardinality: bool = false") 158 .SetShapeFn(shape_inference::ScalarShape); 159 160 REGISTER_OP("ParallelMapDatasetV2") 161 .Input("input_dataset: variant") 162 .Input("other_arguments: Targuments") 163 .Input("num_parallel_calls: int64") 164 .Output("handle: variant") 165 .Attr("f: func") 166 .Attr("Targuments: list(type) >= 0") 167 .Attr("output_types: list(type) >= 1") 168 .Attr("output_shapes: list(shape) >= 1") 169 .Attr("use_inter_op_parallelism: bool = true") 170 // "true", "false", or "default". 
171 .Attr("deterministic: string = 'default'") 172 .Attr("preserve_cardinality: bool = false") 173 .SetShapeFn(shape_inference::ScalarShape); 174 175 REGISTER_OP("PrefetchDataset") 176 .Input("input_dataset: variant") 177 .Input("buffer_size: int64") 178 .Output("handle: variant") 179 .Attr("output_types: list(type) >= 1") 180 .Attr("output_shapes: list(shape) >= 1") 181 .Attr("slack_period: int = 0") 182 .Attr("legacy_autotune: bool = true") 183 .Attr("buffer_size_min: int = 0") __anon0b7d74c50402(shape_inference::InferenceContext* c) 184 .SetShapeFn([](shape_inference::InferenceContext* c) { 185 shape_inference::ShapeHandle unused; 186 // buffer_size should be a scalar. 187 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 188 return shape_inference::ScalarShape(c); 189 }); 190 191 REGISTER_OP("FlatMapDataset") 192 .Input("input_dataset: variant") 193 .Input("other_arguments: Targuments") 194 .Output("handle: variant") 195 .Attr("f: func") 196 .Attr("Targuments: list(type) >= 0") 197 .Attr("output_types: list(type) >= 1") 198 .Attr("output_shapes: list(shape) >= 1") 199 .SetShapeFn(shape_inference::ScalarShape); 200 201 REGISTER_OP("InterleaveDataset") 202 .Input("input_dataset: variant") 203 .Input("other_arguments: Targuments") 204 .Input("cycle_length: int64") 205 .Input("block_length: int64") 206 .Output("handle: variant") 207 .Attr("f: func") 208 .Attr("Targuments: list(type) >= 0") 209 .Attr("output_types: list(type) >= 1") 210 .Attr("output_shapes: list(shape) >= 1") 211 .SetShapeFn(shape_inference::ScalarShape); 212 213 REGISTER_OP("ParallelInterleaveDatasetV2") 214 .Input("input_dataset: variant") 215 .Input("other_arguments: Targuments") 216 .Input("cycle_length: int64") 217 .Input("block_length: int64") 218 .Input("num_parallel_calls: int64") 219 .Output("handle: variant") 220 .Attr("f: func") 221 .Attr("Targuments: list(type) >= 0") 222 .Attr("output_types: list(type) >= 1") 223 .Attr("output_shapes: list(shape) >= 1") 224 .Attr("sloppy: bool 
= false") 225 .SetShapeFn(shape_inference::ScalarShape); 226 227 REGISTER_OP("ParallelInterleaveDatasetV3") 228 .Input("input_dataset: variant") 229 .Input("other_arguments: Targuments") 230 .Input("cycle_length: int64") 231 .Input("block_length: int64") 232 .Input("num_parallel_calls: int64") 233 .Output("handle: variant") 234 .Attr("f: func") 235 // "true", "false", or "default". 236 .Attr("deterministic: string = 'default'") 237 .Attr("Targuments: list(type) >= 0") 238 .Attr("output_types: list(type) >= 1") 239 .Attr("output_shapes: list(shape) >= 1") 240 .SetShapeFn(shape_inference::ScalarShape); 241 242 // Like V3, but adds buffer_output_elements and prefetch_input_elements. 243 REGISTER_OP("ParallelInterleaveDatasetV4") 244 .Input("input_dataset: variant") 245 .Input("other_arguments: Targuments") 246 .Input("cycle_length: int64") 247 .Input("block_length: int64") 248 .Input("buffer_output_elements: int64") 249 .Input("prefetch_input_elements: int64") 250 .Input("num_parallel_calls: int64") 251 .Output("handle: variant") 252 .Attr("f: func") 253 // "true", "false", or "default". 254 .Attr("deterministic: string = 'default'") 255 .Attr("Targuments: list(type) >= 0") 256 .Attr("output_types: list(type) >= 1") 257 .Attr("output_shapes: list(shape) >= 1") 258 .SetShapeFn(shape_inference::ScalarShape); 259 260 REGISTER_OP("FilterDataset") 261 .Input("input_dataset: variant") 262 .Input("other_arguments: Targuments") 263 .Output("handle: variant") 264 .Attr("predicate: func") 265 .Attr("Targuments: list(type) >= 0") 266 .Attr("output_types: list(type) >= 1") 267 .Attr("output_shapes: list(shape) >= 1") 268 .SetShapeFn(shape_inference::ScalarShape); 269 270 // This op is no longer supported. 
271 REGISTER_OP("FilterByLastComponentDataset") 272 .Input("input_dataset: variant") 273 .Output("output: variant") 274 .Attr("output_types: list(type) >= 1") 275 .Attr("output_shapes: list(shape) >= 1") 276 .SetShapeFn(shape_inference::ScalarShape); 277 278 REGISTER_OP("WindowDataset") 279 .Input("input_dataset: variant") 280 .Input("size: int64") 281 .Input("shift: int64") 282 .Input("stride: int64") 283 .Input("drop_remainder: bool") 284 .Output("handle: variant") 285 .Attr("output_types: list(type) >= 1") 286 .Attr("output_shapes: list(shape) >= 1") __anon0b7d74c50502(shape_inference::InferenceContext* c) 287 .SetShapeFn([](shape_inference::InferenceContext* c) { 288 shape_inference::ShapeHandle unused; 289 // size, shift, stride, and drop_remainder should be scalars. 290 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 291 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 292 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 293 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 294 return shape_inference::ScalarShape(c); 295 }); 296 297 REGISTER_OP("WindowOp") 298 .Input("inputs: Tinputs") 299 .Output("handle: variant") 300 .Attr("output_types: list(type) >= 1") 301 .Attr("output_shapes: list(shape) >= 1") 302 .Attr("Tinputs: list(type) >= 1") 303 .SetShapeFn(shape_inference::ScalarShape); 304 305 REGISTER_OP("BatchDataset") 306 .Input("input_dataset: variant") 307 .Input("batch_size: int64") 308 .Output("handle: variant") 309 .Attr("output_types: list(type) >= 1") 310 .Attr("output_shapes: list(shape) >= 1") __anon0b7d74c50602(shape_inference::InferenceContext* c) 311 .SetShapeFn([](shape_inference::InferenceContext* c) { 312 shape_inference::ShapeHandle unused; 313 // batch_size should be a scalar. 
314 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 315 return shape_inference::ScalarShape(c); 316 }); 317 318 REGISTER_OP("BatchDatasetV2") 319 .Input("input_dataset: variant") 320 .Input("batch_size: int64") 321 .Input("drop_remainder: bool") 322 .Output("handle: variant") 323 .Attr("parallel_copy: bool = false") 324 .Attr("output_types: list(type) >= 1") 325 .Attr("output_shapes: list(shape) >= 1") __anon0b7d74c50702(shape_inference::InferenceContext* c) 326 .SetShapeFn([](shape_inference::InferenceContext* c) { 327 shape_inference::ShapeHandle unused; 328 // batch_size should be a scalar. 329 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 330 // drop_remainder should be a scalar. 331 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 332 return shape_inference::ScalarShape(c); 333 }); 334 335 REGISTER_OP("ParallelBatchDataset") 336 .Input("input_dataset: variant") 337 .Input("batch_size: int64") 338 .Input("num_parallel_calls: int64") 339 .Input("drop_remainder: bool") 340 .Output("handle: variant") 341 .Attr("parallel_copy: bool = false") 342 .Attr("output_types: list(type) >= 1") 343 .Attr("output_shapes: list(shape) >= 1") 344 // "true", "false", or "default". 345 .Attr("deterministic: string = 'default'") __anon0b7d74c50802(shape_inference::InferenceContext* c) 346 .SetShapeFn([](shape_inference::InferenceContext* c) { 347 shape_inference::ShapeHandle unused; 348 // batch_size should be a scalar. 349 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 350 // num_parallel_calls should be a scalar. 351 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 352 // drop_remainder should be a scalar. 
353 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 354 return shape_inference::ScalarShape(c); 355 }); 356 357 REGISTER_OP("ShardDataset") 358 .Input("input_dataset: variant") 359 .Input("num_shards: int64") 360 .Input("index: int64") 361 .Output("handle: variant") 362 .Attr("require_non_empty: bool = false") 363 .Attr("output_types: list(type) >= 1") 364 .Attr("output_shapes: list(shape) >= 1") __anon0b7d74c50902(shape_inference::InferenceContext* c) 365 .SetShapeFn([](shape_inference::InferenceContext* c) { 366 shape_inference::ShapeHandle unused; 367 // num_shards should be a scalar. 368 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 369 // index should be a scalar. 370 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 371 return shape_inference::ScalarShape(c); 372 }); 373 374 // TODO(mrry): Validate that `padded_shapes` are all vectors, the lengths of 375 // `output_types` and `output_shapes` are `N` the `output_shapes` are (as far as 376 // possible to tell statically) compatible with `padded_shapes`, and that 377 // `padding_values` are all scalars. 378 REGISTER_OP("PaddedBatchDataset") 379 .Input("input_dataset: variant") 380 .Input("batch_size: int64") 381 .Input("padded_shapes: N * int64") 382 .Input("padding_values: Toutput_types") 383 .Output("handle: variant") 384 .Attr("Toutput_types: list(type) >= 1") 385 .Attr("output_shapes: list(shape) >= 1") 386 .Attr("N: int >= 1") __anon0b7d74c50a02(shape_inference::InferenceContext* c) 387 .SetShapeFn([](shape_inference::InferenceContext* c) { 388 shape_inference::ShapeHandle unused; 389 // batch_size should be a scalar. 
390 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 391 return shape_inference::ScalarShape(c); 392 }); 393 394 REGISTER_OP("PaddedBatchDatasetV2") 395 .Input("input_dataset: variant") 396 .Input("batch_size: int64") 397 .Input("padded_shapes: N * int64") 398 .Input("padding_values: Toutput_types") 399 .Input("drop_remainder: bool") 400 .Output("handle: variant") 401 .Attr("parallel_copy: bool = false") 402 .Attr("Toutput_types: list(type) >= 1") 403 .Attr("output_shapes: list(shape) >= 1") 404 .Attr("N: int >= 1") __anon0b7d74c50b02(shape_inference::InferenceContext* c) 405 .SetShapeFn([](shape_inference::InferenceContext* c) { 406 shape_inference::ShapeHandle unused; 407 // batch_size should be a scalar. 408 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 409 // drop_remainder should be a scalar. 410 TF_RETURN_IF_ERROR( 411 c->WithRank(c->input(c->num_inputs() - 1), 0, &unused)); 412 return shape_inference::ScalarShape(c); 413 }); 414 415 REGISTER_OP("RangeDataset") 416 .Input("start: int64") 417 .Input("stop: int64") 418 .Input("step: int64") 419 .Output("handle: variant") 420 .Attr("output_types: list(type) >= 1") 421 .Attr("output_shapes: list(shape) >= 1") 422 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. __anon0b7d74c50c02(shape_inference::InferenceContext* c) 423 .SetShapeFn([](shape_inference::InferenceContext* c) { 424 shape_inference::ShapeHandle unused; 425 // start, stop, and step should be scalars. 
426 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); 427 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 428 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 429 return shape_inference::ScalarShape(c); 430 }); 431 432 REGISTER_OP("AnonymousSeedGenerator") 433 .Input("seed: int64") 434 .Input("seed2: int64") 435 .Input("reshuffle: bool") 436 .Output("handle: resource") 437 .Output("deleter: variant") __anon0b7d74c50d02(shape_inference::InferenceContext* c) 438 .SetShapeFn([](shape_inference::InferenceContext* c) { 439 c->set_output(0, c->Scalar()); 440 c->set_output(1, c->Scalar()); 441 return Status::OK(); 442 }); 443 444 REGISTER_OP("DatasetCardinality") 445 .Input("input_dataset: variant") 446 .Output("cardinality: int64") 447 .SetShapeFn(shape_inference::ScalarShape); 448 449 REGISTER_OP("DeleteSeedGenerator") 450 .Input("handle: resource") 451 .Input("deleter: variant") 452 .SetShapeFn(shape_inference::NoOutputs); 453 454 // Deprecated in favor of AnonymousSeedGenerator/DeleteSeedGenerator. 455 REGISTER_OP("AnonymousRandomSeedGenerator") 456 .Input("seed: int64") 457 .Input("seed2: int64") 458 .Output("handle: resource") 459 .Output("deleter: variant") __anon0b7d74c50e02(shape_inference::InferenceContext* c) 460 .SetShapeFn([](shape_inference::InferenceContext* c) { 461 c->set_output(0, c->Scalar()); 462 c->set_output(1, c->Scalar()); 463 return Status::OK(); 464 }); 465 466 // Deprecated in favor of AnonymousSeedGenerator/DeleteSeedGenerator. 
467 REGISTER_OP("DeleteRandomSeedGenerator") 468 .Input("handle: resource") 469 .Input("deleter: variant") 470 .SetShapeFn(shape_inference::NoOutputs); 471 472 REGISTER_OP("DummySeedGenerator") 473 .Output("handle: resource") __anon0b7d74c50f02(shape_inference::InferenceContext* c) 474 .SetShapeFn([](shape_inference::InferenceContext* c) { 475 c->set_output(0, c->Scalar()); 476 return Status::OK(); 477 }); 478 479 REGISTER_OP("ShuffleDataset") 480 .Input("input_dataset: variant") 481 .Input("buffer_size: int64") 482 .Input("seed: int64") 483 .Input("seed2: int64") 484 .Output("handle: variant") 485 .Attr("reshuffle_each_iteration: bool = true") 486 .Attr("output_types: list(type) >= 1") 487 .Attr("output_shapes: list(shape) >= 1") __anon0b7d74c51002(shape_inference::InferenceContext* c) 488 .SetShapeFn([](shape_inference::InferenceContext* c) { 489 shape_inference::ShapeHandle unused; 490 // buffer_size, seed, and seed2 should be scalars. 491 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 492 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 493 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 494 return shape_inference::ScalarShape(c); 495 }); 496 497 REGISTER_OP("ShuffleDatasetV2") 498 .Input("input_dataset: variant") 499 .Input("buffer_size: int64") 500 .Input("seed_generator: resource") 501 .Output("handle: variant") 502 .Attr("output_types: list(type) >= 1") 503 .Attr("output_shapes: list(shape) >= 1") __anon0b7d74c51102(shape_inference::InferenceContext* c) 504 .SetShapeFn([](shape_inference::InferenceContext* c) { 505 shape_inference::ShapeHandle unused; 506 // buffer_size and seed_generator should be scalars. 
507 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 508 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 509 return shape_inference::ScalarShape(c); 510 }); 511 512 REGISTER_OP("ShuffleDatasetV3") 513 .Input("input_dataset: variant") 514 .Input("buffer_size: int64") 515 .Input("seed: int64") 516 .Input("seed2: int64") 517 .Input("seed_generator: resource") 518 .Output("handle: variant") 519 .Attr("reshuffle_each_iteration: bool = true") 520 .Attr("output_types: list(type) >= 1") 521 .Attr("output_shapes: list(shape) >= 1") __anon0b7d74c51202(shape_inference::InferenceContext* c) 522 .SetShapeFn([](shape_inference::InferenceContext* c) { 523 shape_inference::ShapeHandle unused; 524 // buffer_size, seed, seed2, and seed_generator should be scalars. 525 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 526 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 527 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 528 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 529 return shape_inference::ScalarShape(c); 530 }); 531 532 REGISTER_OP("ShuffleAndRepeatDataset") 533 .Input("input_dataset: variant") 534 .Input("buffer_size: int64") 535 .Input("seed: int64") 536 .Input("seed2: int64") 537 .Input("count: int64") 538 .Output("handle: variant") 539 .Attr("output_types: list(type) >= 1") 540 .Attr("output_shapes: list(shape) >= 1") 541 .Attr("reshuffle_each_iteration: bool = true") __anon0b7d74c51302(shape_inference::InferenceContext* c) 542 .SetShapeFn([](shape_inference::InferenceContext* c) { 543 shape_inference::ShapeHandle unused; 544 // buffer_size, seed, seed2, and count should be scalars. 
545 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 546 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 547 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 548 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 549 return shape_inference::ScalarShape(c); 550 }); 551 552 REGISTER_OP("ShuffleAndRepeatDatasetV2") 553 .Input("input_dataset: variant") 554 .Input("buffer_size: int64") 555 .Input("seed: int64") 556 .Input("seed2: int64") 557 .Input("count: int64") 558 .Input("seed_generator: resource") 559 .Output("handle: variant") 560 .Attr("reshuffle_each_iteration: bool = true") 561 .Attr("output_types: list(type) >= 1") 562 .Attr("output_shapes: list(shape) >= 1") __anon0b7d74c51402(shape_inference::InferenceContext* c) 563 .SetShapeFn([](shape_inference::InferenceContext* c) { 564 shape_inference::ShapeHandle unused; 565 // buffer_size, seed, seed2, count, and seed_generator should be scalars. 566 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 567 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 568 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 569 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 570 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); 571 return shape_inference::ScalarShape(c); 572 }); 573 574 REGISTER_OP("AnonymousMemoryCache") 575 .Output("handle: resource") 576 .Output("deleter: variant") __anon0b7d74c51502(shape_inference::InferenceContext* c) 577 .SetShapeFn([](shape_inference::InferenceContext* c) { 578 c->set_output(0, c->Scalar()); 579 c->set_output(1, c->Scalar()); 580 return Status::OK(); 581 }); 582 583 REGISTER_OP("DeleteMemoryCache") 584 .Input("handle: resource") 585 .Input("deleter: variant") 586 .SetShapeFn(shape_inference::NoOutputs); 587 588 REGISTER_OP("DummyMemoryCache") 589 .Output("handle: resource") __anon0b7d74c51602(shape_inference::InferenceContext* c) 590 .SetShapeFn([](shape_inference::InferenceContext* c) { 591 c->set_output(0, 
c->Scalar()); 592 return Status::OK(); 593 }); 594 595 REGISTER_OP("CacheDataset") 596 .Input("input_dataset: variant") 597 .Input("filename: string") 598 .Output("handle: variant") 599 .Attr("output_types: list(type) >= 1") 600 .Attr("output_shapes: list(shape) >= 1") __anon0b7d74c51702(shape_inference::InferenceContext* c) 601 .SetShapeFn([](shape_inference::InferenceContext* c) { 602 shape_inference::ShapeHandle unused; 603 // filename should be a scalar. 604 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 605 return shape_inference::ScalarShape(c); 606 }); 607 608 REGISTER_OP("CacheDatasetV2") 609 .Input("input_dataset: variant") 610 .Input("filename: string") 611 .Input("cache: resource") 612 .Output("handle: variant") 613 .Attr("output_types: list(type) >= 1") 614 .Attr("output_shapes: list(shape) >= 1") __anon0b7d74c51802(shape_inference::InferenceContext* c) 615 .SetShapeFn([](shape_inference::InferenceContext* c) { 616 shape_inference::ShapeHandle unused; 617 // filename should be a scalar. 618 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 619 // cache should be a scalar. 620 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 621 return shape_inference::ScalarShape(c); 622 }); 623 624 REGISTER_OP("TextLineDataset") 625 .Input("filenames: string") 626 .Input("compression_type: string") 627 .Input("buffer_size: int64") 628 .Output("handle: variant") 629 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. __anon0b7d74c51902(shape_inference::InferenceContext* c) 630 .SetShapeFn([](shape_inference::InferenceContext* c) { 631 shape_inference::ShapeHandle unused; 632 // `filenames` must be a scalar or a vector. 633 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); 634 // `compression_type` could only be a scalar. 635 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 636 // `buffer_size` could only be a scalar. 
637 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 638 return shape_inference::ScalarShape(c); 639 }); 640 641 REGISTER_OP("FixedLengthRecordDataset") 642 .Input("filenames: string") 643 .Input("header_bytes: int64") 644 .Input("record_bytes: int64") 645 .Input("footer_bytes: int64") 646 .Input("buffer_size: int64") 647 .Output("handle: variant") 648 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. __anon0b7d74c51a02(shape_inference::InferenceContext* c) 649 .SetShapeFn([](shape_inference::InferenceContext* c) { 650 shape_inference::ShapeHandle unused; 651 // `filenames` must be a scalar or a vector. 652 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); 653 // header_bytes, record_bytes, footer_bytes, buffer_size should be 654 // scalars. 655 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 656 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 657 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 658 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 659 return shape_inference::ScalarShape(c); 660 }); 661 662 REGISTER_OP("FixedLengthRecordDatasetV2") 663 .Input("filenames: string") 664 .Input("header_bytes: int64") 665 .Input("record_bytes: int64") 666 .Input("footer_bytes: int64") 667 .Input("buffer_size: int64") 668 .Input("compression_type: string") 669 .Output("handle: variant") 670 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. __anon0b7d74c51b02(shape_inference::InferenceContext* c) 671 .SetShapeFn([](shape_inference::InferenceContext* c) { 672 shape_inference::ShapeHandle unused; 673 // `filenames` must be a scalar or a vector. 674 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); 675 // header_bytes, record_bytes, footer_bytes, buffer_size should be 676 // scalars. 
677 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 678 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 679 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 680 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 681 return shape_inference::ScalarShape(c); 682 }); 683 684 REGISTER_OP("TFRecordDataset") 685 .Input("filenames: string") 686 .Input("compression_type: string") 687 .Input("buffer_size: int64") 688 .Output("handle: variant") 689 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. __anon0b7d74c51c02(shape_inference::InferenceContext* c) 690 .SetShapeFn([](shape_inference::InferenceContext* c) { 691 shape_inference::ShapeHandle unused; 692 // `filenames` must be a scalar or a vector. 693 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); 694 // `compression_type` could only be a scalar. 695 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 696 // `buffer_size` could only be a scalar. 697 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 698 return shape_inference::ScalarShape(c); 699 }); 700 701 REGISTER_OP("Iterator") 702 .Output("handle: resource") 703 .Attr("shared_name: string") 704 .Attr("container: string") 705 .Attr("output_types: list(type) >= 1") 706 .Attr("output_shapes: list(shape) >= 1") 707 .SetShapeFn(shape_inference::ScalarShape); 708 709 REGISTER_OP("IteratorV2") 710 .Output("handle: resource") 711 .Attr("shared_name: string") 712 .Attr("container: string") 713 .Attr("output_types: list(type) >= 1") 714 .Attr("output_shapes: list(shape) >= 1") 715 .SetShapeFn(shape_inference::ScalarShape); 716 717 REGISTER_OP("AnonymousIterator") 718 .Output("handle: resource") 719 .Attr("output_types: list(type) >= 1") 720 .Attr("output_shapes: list(shape) >= 1") 721 .SetShapeFn(shape_inference::ScalarShape); 722 723 REGISTER_OP("AnonymousIteratorV2") 724 .Output("handle: resource") 725 .Output("deleter: variant") 726 .Attr("output_types: list(type) >= 1") 727 
.Attr("output_shapes: list(shape) >= 1") __anon0b7d74c51d02(shape_inference::InferenceContext* c) 728 .SetShapeFn([](shape_inference::InferenceContext* c) { 729 c->set_output(0, c->Scalar()); 730 c->set_output(1, c->Scalar()); 731 return Status::OK(); 732 }); 733 734 REGISTER_OP("DeleteIterator") 735 .Input("handle: resource") 736 .Input("deleter: variant") 737 .SetShapeFn(shape_inference::NoOutputs); 738 739 REGISTER_OP("DeleteMultiDeviceIterator") 740 .Input("multi_device_iterator: resource") 741 .Input("iterators: N * resource") 742 .Input("deleter: variant") 743 .Attr("N: int >= 0") 744 .SetShapeFn(shape_inference::NoOutputs); 745 746 REGISTER_OP("MakeIterator") 747 .Input("dataset: variant") 748 .Input("iterator: resource") 749 .SetShapeFn(shape_inference::NoOutputs); 750 751 REGISTER_OP("OneShotIterator") 752 .Output("handle: resource") 753 .Attr("dataset_factory: func") 754 .Attr("output_types: list(type) >= 1") 755 .Attr("output_shapes: list(shape) >= 1") 756 .Attr("container: string = ''") 757 .Attr("shared_name: string = ''") 758 .SetIsStateful() 759 .SetShapeFn(shape_inference::ScalarShape); 760 761 REGISTER_OP("IteratorGetNext") 762 .Input("iterator: resource") 763 .Output("components: output_types") 764 .Attr("output_types: list(type) >= 1") 765 .Attr("output_shapes: list(shape) >= 1") 766 .SetShapeFn(shape_inference::DatasetIteratorShape); 767 768 REGISTER_OP("IteratorGetNextSync") 769 .Input("iterator: resource") 770 .Output("components: output_types") 771 .Attr("output_types: list(type) >= 1") 772 .Attr("output_shapes: list(shape) >= 1") 773 .SetShapeFn(shape_inference::DatasetIteratorShape); 774 775 // TODO(b/124308596): Instead of conservatively marking this op as stateful, 776 // implement a mechanism to determine whether `dataset` has a side-effect 777 // and use it to decide whether to use a stateless or stateful version of this 778 // op. 
779 REGISTER_OP("DatasetToSingleElement") 780 .Input("dataset: variant") 781 .Output("components: output_types") 782 .Attr("output_types: list(type) >= 1") 783 .Attr("output_shapes: list(shape) >= 1") 784 .SetIsStateful() 785 .SetShapeFn(shape_inference::DatasetIteratorShape); 786 787 // TODO(b/124308596): Instead of conservatively marking this op as stateful, 788 // implement a mechanism to determine whether `dataset` has a side-effect 789 // and use it to decide whether to use a stateless or stateful version of this 790 // op. 791 REGISTER_OP("ReduceDataset") 792 .Input("input_dataset: variant") 793 .Input("initial_state: Tstate") 794 .Input("other_arguments: Targuments") 795 .Output("components: output_types") 796 .Attr("f: func") 797 .Attr("Tstate: list(type) >= 1") 798 .Attr("Targuments: list(type) >= 0") 799 .Attr("output_types: list(type) >= 1") 800 .Attr("output_shapes: list(shape) >= 1") 801 .Attr("use_inter_op_parallelism: bool = true") 802 .SetIsStateful() 803 .SetShapeFn(shape_inference::DatasetIteratorShape); 804 805 REGISTER_OP("IteratorToStringHandle") 806 .Input("resource_handle: resource") 807 .Output("string_handle: string") 808 .SetShapeFn(shape_inference::ScalarShape); 809 810 REGISTER_OP("IteratorFromStringHandle") 811 .Input("string_handle: string") 812 .Output("resource_handle: resource") 813 .Attr("output_types: list(type) >= 0 = []") 814 .Attr("output_shapes: list(shape) >= 0 = []") 815 .SetShapeFn(shape_inference::ScalarShape); 816 817 REGISTER_OP("IteratorFromStringHandleV2") 818 .Input("string_handle: string") 819 .Output("resource_handle: resource") 820 .Attr("output_types: list(type) >= 0 = []") 821 .Attr("output_shapes: list(shape) >= 0 = []") 822 .SetShapeFn(shape_inference::ScalarShape); 823 824 REGISTER_OP("SerializeIterator") 825 .Input("resource_handle: resource") 826 .Attr("external_state_policy: int = 0") 827 .Output("serialized: variant") __anon0b7d74c51e02(shape_inference::InferenceContext* c) 828 
.SetShapeFn([](shape_inference::InferenceContext* c) { 829 c->set_output(0, c->Vector(c->UnknownDim())); 830 return Status::OK(); 831 }); 832 833 REGISTER_OP("DeserializeIterator") 834 .Input("resource_handle: resource") 835 .Input("serialized: variant") 836 .SetShapeFn(shape_inference::NoOutputs); 837 838 REGISTER_OP("DatasetToGraph") 839 .Input("input_dataset: variant") 840 .Attr("stateful_whitelist: list(string) >= 0 = []") 841 .Attr("allow_stateful: bool = false") 842 .Attr("strip_device_assignment: bool = false") 843 .Output("graph: string") 844 .SetShapeFn(shape_inference::ScalarShape); 845 846 REGISTER_OP("DatasetToGraphV2") 847 .Input("input_dataset: variant") 848 .Attr("external_state_policy: int = 0") 849 .Attr("strip_device_assignment: bool = false") 850 .Output("graph: string") 851 .SetShapeFn(shape_inference::ScalarShape); 852 853 REGISTER_OP("OptimizeDataset") 854 .Input("input_dataset: variant") 855 .Input("optimizations: string") 856 .Output("handle: variant") 857 .Attr("output_types: list(type) >= 1") 858 .Attr("output_shapes: list(shape) >= 1") 859 .Attr("optimization_configs: list(string) = []") 860 .SetShapeFn(shape_inference::ScalarShape); 861 862 REGISTER_OP("OptimizeDatasetV2") 863 .Input("input_dataset: variant") 864 .Input("optimizations_enabled: string") 865 .Input("optimizations_disabled: string") 866 .Input("optimizations_default: string") 867 .Output("handle: variant") 868 .Attr("output_types: list(type) >= 1") 869 .Attr("output_shapes: list(shape) >= 1") 870 .Attr("optimization_configs: list(string) = []") 871 .SetShapeFn(shape_inference::ScalarShape); 872 873 REGISTER_OP("OptionalFromValue") 874 .Input("components: Toutput_types") 875 .Output("optional: variant") 876 .Attr("Toutput_types: list(type) >= 1") 877 .SetTypeConstructor(full_type::Unary(TFT_OPTIONAL, "Toutput_types")) __anon0b7d74c51f02(shape_inference::InferenceContext* c) 878 .SetShapeFn([](shape_inference::InferenceContext* c) { 879 std::vector<DataType> dtypes; 880 
TF_RETURN_IF_ERROR(c->GetAttr("Toutput_types", &dtypes)); 881 c->set_output(0, c->Scalar()); 882 std::vector<shape_inference::ShapeAndType> shapes_and_types; 883 shapes_and_types.reserve(c->num_inputs()); 884 const FullTypeDef& ret_types = c->ret_types(); 885 for (int i = 0; i < c->num_inputs(); ++i) { 886 // TODO(mdan): output_type(i) == optional is incorrect. 887 // "Optional" is the type of the the whole container, not of individual 888 // elements. 889 // 890 // Why ret_types.args(0) and not args(i) -- 891 // For example if Toutput_types is (int32, float32), then 892 // ret_types.args[0] (i.e. the 0th output) is 893 // Optional[Record[Tensor[int32, s1], Tensor[float32, s2]]] 894 // set_output_handle_shapes_and_types tracks the same thing, but in 895 // a transposed way: 896 // {ShapeAndType(in32, s1, Optional), ShapeAndType(in32, s2, Optional)} 897 // That should be corrected in the future (see todo above). 898 shapes_and_types.emplace_back(c->input(i), dtypes[i], 899 ret_types.args(0)); 900 } 901 c->set_output_handle_shapes_and_types(0, shapes_and_types); 902 return Status::OK(); 903 }); 904 905 REGISTER_OP("OptionalNone") 906 .Output("optional: variant") 907 .SetShapeFn(shape_inference::ScalarShape); 908 909 REGISTER_OP("OptionalHasValue") 910 .Input("optional: variant") 911 .Output("has_value: bool") 912 .SetShapeFn(shape_inference::ScalarShape); 913 914 REGISTER_OP("OptionalGetValue") 915 .Input("optional: variant") 916 .Output("components: output_types") 917 .Attr("output_types: list(type) >= 1") 918 .Attr("output_shapes: list(shape) >= 1") 919 .SetShapeFn(shape_inference::DatasetIteratorShape); 920 921 REGISTER_OP("IteratorGetNextAsOptional") 922 .Input("iterator: resource") 923 .Output("optional: variant") 924 .Attr("output_types: list(type) >= 1") 925 .Attr("output_shapes: list(shape) >= 1") 926 .SetShapeFn(shape_inference::ScalarShape); 927 928 REGISTER_OP("ModelDataset") 929 .Input("input_dataset: variant") 930 .Output("handle: variant") 931 
.Attr("algorithm: int = 0") 932 .Attr("cpu_budget: int = 0") 933 .Attr("ram_budget: int = 0") 934 .Attr("output_types: list(type) >= 1") 935 .Attr("output_shapes: list(shape) >= 1") 936 .SetShapeFn(shape_inference::ScalarShape); 937 938 // TODO(b/124308749): Add a stateful version of MapDefun and use it when `f` 939 // is stateful. 940 REGISTER_OP("MapDefun") 941 .Input("arguments: Targuments") 942 .Input("captured_inputs: Tcaptured") 943 .Output("output: output_types") 944 .Attr("Targuments: list(type) >= 1") 945 .Attr("Tcaptured: list(type) >= 0 = []") 946 .Attr("output_types: list(type) >= 1") 947 .Attr("output_shapes: list(shape) >= 1") 948 .Attr("f: func") 949 .Attr("max_intra_op_parallelism: int = 1") __anon0b7d74c52002(shape_inference::InferenceContext* c) 950 .SetShapeFn([](shape_inference::InferenceContext* c) { 951 std::vector<PartialTensorShape> output_shapes; 952 TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); 953 DataTypeVector t_args; 954 TF_RETURN_IF_ERROR(c->GetAttr("Targuments", &t_args)); 955 if (output_shapes.size() != c->num_outputs()) { 956 return errors::InvalidArgument( 957 "`output_shapes` must be the same length as `output_types` (", 958 output_shapes.size(), " vs. ", c->num_outputs(), ")"); 959 } 960 961 int64_t dim_zero = -1; 962 for (size_t i = 0; i < t_args.size(); ++i) { 963 if (c->Rank(c->input(i)) == 0) { 964 return errors::InvalidArgument( 965 "Arguments must have rank at least 1. 
Input ", i, 966 " has rank of 0."); 967 } 968 auto dim_handle = c->Dim(c->input(i), 0); 969 if (c->ValueKnown(dim_handle)) { 970 if (dim_zero == -1) { 971 dim_zero = c->Value(dim_handle); 972 } else if (c->Value(dim_handle) != dim_zero) { 973 return errors::InvalidArgument( 974 "Arguments must have the same dimension 0."); 975 } 976 } 977 } 978 979 for (size_t i = 0; i < output_shapes.size(); ++i) { 980 PartialTensorShape s({}); 981 s = s.Concatenate(dim_zero); 982 s = s.Concatenate(output_shapes[i]); 983 shape_inference::ShapeHandle output_shape_handle; 984 985 TF_RETURN_IF_ERROR( 986 c->MakeShapeFromPartialTensorShape(s, &output_shape_handle)); 987 c->set_output(static_cast<int>(i), output_shape_handle); 988 } 989 return Status::OK(); 990 }); 991 992 REGISTER_OP("WrapDatasetVariant") 993 .Input("input_handle: variant") 994 .Output("output_handle: variant") 995 .SetShapeFn(shape_inference::ScalarShape); 996 997 REGISTER_OP("UnwrapDatasetVariant") 998 .Input("input_handle: variant") 999 .Output("output_handle: variant") 1000 .SetShapeFn(shape_inference::ScalarShape); 1001 1002 REGISTER_OP("AnonymousMultiDeviceIterator") 1003 .Output("handle: resource") 1004 .Output("deleter: variant") 1005 .Attr("devices: list(string) >= 1") 1006 .Attr("output_types: list(type) >= 1") 1007 .Attr("output_shapes: list(shape) >= 1") __anon0b7d74c52102(shape_inference::InferenceContext* c) 1008 .SetShapeFn([](shape_inference::InferenceContext* c) { 1009 c->set_output(0, c->Scalar()); 1010 c->set_output(1, c->Scalar()); 1011 return Status::OK(); 1012 }); 1013 1014 REGISTER_OP("MultiDeviceIterator") 1015 .Output("handle: resource") 1016 .Attr("devices: list(string) >= 1") 1017 .Attr("shared_name: string") 1018 .Attr("container: string") 1019 .Attr("output_types: list(type) >= 1") 1020 .Attr("output_shapes: list(shape) >= 1") 1021 .SetShapeFn(shape_inference::ScalarShape); 1022 1023 REGISTER_OP("MultiDeviceIteratorInit") 1024 .Input("dataset: variant") 1025 .Input("multi_device_iterator: 
resource") 1026 .Input("max_buffer_size: int64") 1027 .Output("incarnation_id: int64") 1028 .SetShapeFn(shape_inference::ScalarShape); 1029 1030 REGISTER_OP("MultiDeviceIteratorGetNextFromShard") 1031 .Input("multi_device_iterator: resource") 1032 .Input("shard_num: int32") 1033 .Input("incarnation_id: int64") 1034 .Output("components: output_types") 1035 .Attr("output_types: list(type) >= 1") 1036 .Attr("output_shapes: list(shape) >= 1") 1037 .SetShapeFn(shape_inference::DatasetIteratorShape); 1038 1039 REGISTER_OP("MultiDeviceIteratorToStringHandle") 1040 .Input("multi_device_iterator: resource") 1041 .Output("string_handle: string") 1042 .SetShapeFn(shape_inference::ScalarShape); 1043 1044 REGISTER_OP("MultiDeviceIteratorFromStringHandle") 1045 .Input("string_handle: string") 1046 .Output("multi_device_iterator: resource") 1047 .Attr("output_types: list(type) >= 0 = []") 1048 .Attr("output_shapes: list(shape) >= 0 = []") 1049 .SetShapeFn(shape_inference::ScalarShape); 1050 1051 REGISTER_OP("OptionsDataset") 1052 .Input("input_dataset: variant") 1053 .Output("handle: variant") 1054 .Attr("serialized_options: string") 1055 .Attr("output_types: list(type) >= 1") 1056 .Attr("output_shapes: list(shape) >= 1") 1057 .SetShapeFn(shape_inference::ScalarShape); 1058 1059 REGISTER_OP("GetOptions") 1060 .Input("input_dataset: variant") 1061 .Output("serialized_options: string") 1062 .SetShapeFn(shape_inference::ScalarShape); 1063 1064 REGISTER_OP("FinalizeDataset") 1065 .Input("input_dataset: variant") 1066 .Output("handle: variant") 1067 .Attr("has_captured_ref: bool = false") 1068 .Attr("output_types: list(type) >= 1") 1069 .Attr("output_shapes: list(shape) >= 1") 1070 .SetShapeFn(shape_inference::ScalarShape); 1071 1072 } // namespace tensorflow 1073