/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/full_type.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {
22 
23 // --------------------------------------------------------------------------
24 
25 // The ops in this section can be composed to define an input
26 // pipeline. Each op produces a DT_VARIANT tensor that represents
27 // a DAG of "dataset" objects. An "dataset" object can be converted
28 // to a stateful "iterator" by passing the "dataset" to the
29 // "MakeIterator" op.
30 //
31 // TODO(b/123753214): DT_VARIANT tensors that represent "dataset" objects are
32 // not presently serializable. To avoid issues with graph optimizations, such
33 // as constant folding, CSE, or DCE, ensure that any "source dataset" ops
34 // (i.e. ops that output a dataset and do not take one as input) are
35 // marked as "do not optimize".
36 
37 // TODO(mrry): Validate that `components` have shapes compatible with
38 // `output_shapes`.
39 REGISTER_OP("TensorDataset")
40     .Input("components: Toutput_types")
41     .Output("handle: variant")
42     .Attr("Toutput_types: list(type) >= 1")
43     .Attr("output_shapes: list(shape) >= 1")
44     .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
45     .SetShapeFn(shape_inference::ScalarShape);
46 
47 // TODO(mrry): Validate that the dim-0 slices of `components` have shapes
48 // compatible with `output_shapes`.
49 REGISTER_OP("TensorSliceDataset")
50     .Input("components: Toutput_types")
51     .Output("handle: variant")
52     .Attr("Toutput_types: list(type) >= 1")
53     .Attr("output_shapes: list(shape) >= 1")
54     .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
55     .SetShapeFn(shape_inference::ScalarShape);
56 
57 REGISTER_OP("SparseTensorSliceDataset")
58     .Input("indices: int64")
59     .Input("values: Tvalues")
60     .Input("dense_shape: int64")
61     .Output("handle: variant")
62     .Attr("Tvalues: type")
63     .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
64     .SetShapeFn(shape_inference::ScalarShape);
65 
66 REGISTER_OP("GeneratorDataset")
67     .Input("init_func_other_args: Tinit_func_args")
68     .Input("next_func_other_args: Tnext_func_args")
69     .Input("finalize_func_other_args: Tfinalize_func_args")
70     .Output("handle: variant")
71     .Attr("init_func: func")
72     .Attr("next_func: func")
73     .Attr("finalize_func: func")
74     .Attr("Tinit_func_args: list(type) >= 0")
75     .Attr("Tnext_func_args: list(type) >= 0")
76     .Attr("Tfinalize_func_args: list(type) >= 0")
77     .Attr("output_types: list(type) >= 1")
78     .Attr("output_shapes: list(shape) >= 1")
79     .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
80     .SetShapeFn(shape_inference::ScalarShape);
81 
82 REGISTER_OP("ZipDataset")
83     .Input("input_datasets: N * variant")
84     .Output("handle: variant")
85     .Attr("output_types: list(type) >= 1")
86     .Attr("output_shapes: list(shape) >= 1")
87     .Attr("N: int >= 1")
88     .SetShapeFn(shape_inference::ScalarShape);
89 
90 REGISTER_OP("ConcatenateDataset")
91     .Input("input_dataset: variant")
92     .Input("another_dataset: variant")
93     .Output("handle: variant")
94     .Attr("output_types: list(type) >= 1")
95     .Attr("output_shapes: list(shape) >= 1")
96     .SetShapeFn(shape_inference::ScalarShape);
97 
98 REGISTER_OP("RepeatDataset")
99     .Input("input_dataset: variant")
100     .Input("count: int64")
101     .Output("handle: variant")
102     .Attr("output_types: list(type) >= 1")
103     .Attr("output_shapes: list(shape) >= 1")
__anon0b7d74c50102(shape_inference::InferenceContext* c) 104     .SetShapeFn([](shape_inference::InferenceContext* c) {
105       shape_inference::ShapeHandle count_shape;
106       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &count_shape));
107       return shape_inference::ScalarShape(c);
108     });
109 
110 REGISTER_OP("TakeDataset")
111     .Input("input_dataset: variant")
112     .Input("count: int64")
113     .Output("handle: variant")
114     .Attr("output_types: list(type) >= 1")
115     .Attr("output_shapes: list(shape) >= 1")
__anon0b7d74c50202(shape_inference::InferenceContext* c) 116     .SetShapeFn([](shape_inference::InferenceContext* c) {
117       shape_inference::ShapeHandle count_shape;
118       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &count_shape));
119       return shape_inference::ScalarShape(c);
120     });
121 
122 REGISTER_OP("SkipDataset")
123     .Input("input_dataset: variant")
124     .Input("count: int64")
125     .Output("handle: variant")
126     .Attr("output_types: list(type) >= 1")
127     .Attr("output_shapes: list(shape) >= 1")
__anon0b7d74c50302(shape_inference::InferenceContext* c) 128     .SetShapeFn([](shape_inference::InferenceContext* c) {
129       shape_inference::ShapeHandle count_shape;
130       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &count_shape));
131       return shape_inference::ScalarShape(c);
132     });
133 
134 REGISTER_OP("MapDataset")
135     .Input("input_dataset: variant")
136     .Input("other_arguments: Targuments")
137     .Output("handle: variant")
138     .Attr("f: func")
139     .Attr("Targuments: list(type) >= 0")
140     .Attr("output_types: list(type) >= 1")
141     .Attr("output_shapes: list(shape) >= 1")
142     .Attr("use_inter_op_parallelism: bool = true")
143     .Attr("preserve_cardinality: bool = false")
144     .SetShapeFn(shape_inference::ScalarShape);
145 
146 REGISTER_OP("ParallelMapDataset")
147     .Input("input_dataset: variant")
148     .Input("other_arguments: Targuments")
149     .Input("num_parallel_calls: int32")
150     .Output("handle: variant")
151     .Attr("f: func")
152     .Attr("Targuments: list(type) >= 0")
153     .Attr("output_types: list(type) >= 1")
154     .Attr("output_shapes: list(shape) >= 1")
155     .Attr("use_inter_op_parallelism: bool = true")
156     .Attr("sloppy: bool = false")
157     .Attr("preserve_cardinality: bool = false")
158     .SetShapeFn(shape_inference::ScalarShape);
159 
160 REGISTER_OP("ParallelMapDatasetV2")
161     .Input("input_dataset: variant")
162     .Input("other_arguments: Targuments")
163     .Input("num_parallel_calls: int64")
164     .Output("handle: variant")
165     .Attr("f: func")
166     .Attr("Targuments: list(type) >= 0")
167     .Attr("output_types: list(type) >= 1")
168     .Attr("output_shapes: list(shape) >= 1")
169     .Attr("use_inter_op_parallelism: bool = true")
170     // "true", "false", or "default".
171     .Attr("deterministic: string = 'default'")
172     .Attr("preserve_cardinality: bool = false")
173     .SetShapeFn(shape_inference::ScalarShape);
174 
175 REGISTER_OP("PrefetchDataset")
176     .Input("input_dataset: variant")
177     .Input("buffer_size: int64")
178     .Output("handle: variant")
179     .Attr("output_types: list(type) >= 1")
180     .Attr("output_shapes: list(shape) >= 1")
181     .Attr("slack_period: int = 0")
182     .Attr("legacy_autotune: bool = true")
183     .Attr("buffer_size_min: int = 0")
__anon0b7d74c50402(shape_inference::InferenceContext* c) 184     .SetShapeFn([](shape_inference::InferenceContext* c) {
185       shape_inference::ShapeHandle unused;
186       // buffer_size should be a scalar.
187       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
188       return shape_inference::ScalarShape(c);
189     });
190 
191 REGISTER_OP("FlatMapDataset")
192     .Input("input_dataset: variant")
193     .Input("other_arguments: Targuments")
194     .Output("handle: variant")
195     .Attr("f: func")
196     .Attr("Targuments: list(type) >= 0")
197     .Attr("output_types: list(type) >= 1")
198     .Attr("output_shapes: list(shape) >= 1")
199     .SetShapeFn(shape_inference::ScalarShape);
200 
201 REGISTER_OP("InterleaveDataset")
202     .Input("input_dataset: variant")
203     .Input("other_arguments: Targuments")
204     .Input("cycle_length: int64")
205     .Input("block_length: int64")
206     .Output("handle: variant")
207     .Attr("f: func")
208     .Attr("Targuments: list(type) >= 0")
209     .Attr("output_types: list(type) >= 1")
210     .Attr("output_shapes: list(shape) >= 1")
211     .SetShapeFn(shape_inference::ScalarShape);
212 
213 REGISTER_OP("ParallelInterleaveDatasetV2")
214     .Input("input_dataset: variant")
215     .Input("other_arguments: Targuments")
216     .Input("cycle_length: int64")
217     .Input("block_length: int64")
218     .Input("num_parallel_calls: int64")
219     .Output("handle: variant")
220     .Attr("f: func")
221     .Attr("Targuments: list(type) >= 0")
222     .Attr("output_types: list(type) >= 1")
223     .Attr("output_shapes: list(shape) >= 1")
224     .Attr("sloppy: bool = false")
225     .SetShapeFn(shape_inference::ScalarShape);
226 
227 REGISTER_OP("ParallelInterleaveDatasetV3")
228     .Input("input_dataset: variant")
229     .Input("other_arguments: Targuments")
230     .Input("cycle_length: int64")
231     .Input("block_length: int64")
232     .Input("num_parallel_calls: int64")
233     .Output("handle: variant")
234     .Attr("f: func")
235     // "true", "false", or "default".
236     .Attr("deterministic: string = 'default'")
237     .Attr("Targuments: list(type) >= 0")
238     .Attr("output_types: list(type) >= 1")
239     .Attr("output_shapes: list(shape) >= 1")
240     .SetShapeFn(shape_inference::ScalarShape);
241 
242 // Like V3, but adds buffer_output_elements and prefetch_input_elements.
243 REGISTER_OP("ParallelInterleaveDatasetV4")
244     .Input("input_dataset: variant")
245     .Input("other_arguments: Targuments")
246     .Input("cycle_length: int64")
247     .Input("block_length: int64")
248     .Input("buffer_output_elements: int64")
249     .Input("prefetch_input_elements: int64")
250     .Input("num_parallel_calls: int64")
251     .Output("handle: variant")
252     .Attr("f: func")
253     // "true", "false", or "default".
254     .Attr("deterministic: string = 'default'")
255     .Attr("Targuments: list(type) >= 0")
256     .Attr("output_types: list(type) >= 1")
257     .Attr("output_shapes: list(shape) >= 1")
258     .SetShapeFn(shape_inference::ScalarShape);
259 
260 REGISTER_OP("FilterDataset")
261     .Input("input_dataset: variant")
262     .Input("other_arguments: Targuments")
263     .Output("handle: variant")
264     .Attr("predicate: func")
265     .Attr("Targuments: list(type) >= 0")
266     .Attr("output_types: list(type) >= 1")
267     .Attr("output_shapes: list(shape) >= 1")
268     .SetShapeFn(shape_inference::ScalarShape);
269 
270 // This op is no longer supported.
271 REGISTER_OP("FilterByLastComponentDataset")
272     .Input("input_dataset: variant")
273     .Output("output: variant")
274     .Attr("output_types: list(type) >= 1")
275     .Attr("output_shapes: list(shape) >= 1")
276     .SetShapeFn(shape_inference::ScalarShape);
277 
278 REGISTER_OP("WindowDataset")
279     .Input("input_dataset: variant")
280     .Input("size: int64")
281     .Input("shift: int64")
282     .Input("stride: int64")
283     .Input("drop_remainder: bool")
284     .Output("handle: variant")
285     .Attr("output_types: list(type) >= 1")
286     .Attr("output_shapes: list(shape) >= 1")
__anon0b7d74c50502(shape_inference::InferenceContext* c) 287     .SetShapeFn([](shape_inference::InferenceContext* c) {
288       shape_inference::ShapeHandle unused;
289       // size, shift, stride, and drop_remainder should be scalars.
290       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
291       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
292       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
293       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
294       return shape_inference::ScalarShape(c);
295     });
296 
297 REGISTER_OP("WindowOp")
298     .Input("inputs: Tinputs")
299     .Output("handle: variant")
300     .Attr("output_types: list(type) >= 1")
301     .Attr("output_shapes: list(shape) >= 1")
302     .Attr("Tinputs: list(type) >= 1")
303     .SetShapeFn(shape_inference::ScalarShape);
304 
305 REGISTER_OP("BatchDataset")
306     .Input("input_dataset: variant")
307     .Input("batch_size: int64")
308     .Output("handle: variant")
309     .Attr("output_types: list(type) >= 1")
310     .Attr("output_shapes: list(shape) >= 1")
__anon0b7d74c50602(shape_inference::InferenceContext* c) 311     .SetShapeFn([](shape_inference::InferenceContext* c) {
312       shape_inference::ShapeHandle unused;
313       // batch_size should be a scalar.
314       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
315       return shape_inference::ScalarShape(c);
316     });
317 
318 REGISTER_OP("BatchDatasetV2")
319     .Input("input_dataset: variant")
320     .Input("batch_size: int64")
321     .Input("drop_remainder: bool")
322     .Output("handle: variant")
323     .Attr("parallel_copy: bool = false")
324     .Attr("output_types: list(type) >= 1")
325     .Attr("output_shapes: list(shape) >= 1")
__anon0b7d74c50702(shape_inference::InferenceContext* c) 326     .SetShapeFn([](shape_inference::InferenceContext* c) {
327       shape_inference::ShapeHandle unused;
328       // batch_size should be a scalar.
329       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
330       // drop_remainder should be a scalar.
331       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
332       return shape_inference::ScalarShape(c);
333     });
334 
335 REGISTER_OP("ParallelBatchDataset")
336     .Input("input_dataset: variant")
337     .Input("batch_size: int64")
338     .Input("num_parallel_calls: int64")
339     .Input("drop_remainder: bool")
340     .Output("handle: variant")
341     .Attr("parallel_copy: bool = false")
342     .Attr("output_types: list(type) >= 1")
343     .Attr("output_shapes: list(shape) >= 1")
344     // "true", "false", or "default".
345     .Attr("deterministic: string = 'default'")
__anon0b7d74c50802(shape_inference::InferenceContext* c) 346     .SetShapeFn([](shape_inference::InferenceContext* c) {
347       shape_inference::ShapeHandle unused;
348       // batch_size should be a scalar.
349       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
350       // num_parallel_calls should be a scalar.
351       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
352       // drop_remainder should be a scalar.
353       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
354       return shape_inference::ScalarShape(c);
355     });
356 
357 REGISTER_OP("ShardDataset")
358     .Input("input_dataset: variant")
359     .Input("num_shards: int64")
360     .Input("index: int64")
361     .Output("handle: variant")
362     .Attr("require_non_empty: bool = false")
363     .Attr("output_types: list(type) >= 1")
364     .Attr("output_shapes: list(shape) >= 1")
__anon0b7d74c50902(shape_inference::InferenceContext* c) 365     .SetShapeFn([](shape_inference::InferenceContext* c) {
366       shape_inference::ShapeHandle unused;
367       // num_shards should be a scalar.
368       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
369       // index should be a scalar.
370       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
371       return shape_inference::ScalarShape(c);
372     });
373 
374 // TODO(mrry): Validate that `padded_shapes` are all vectors, the lengths of
375 // `output_types` and `output_shapes` are `N` the `output_shapes` are (as far as
376 // possible to tell statically) compatible with `padded_shapes`, and that
377 // `padding_values` are all scalars.
378 REGISTER_OP("PaddedBatchDataset")
379     .Input("input_dataset: variant")
380     .Input("batch_size: int64")
381     .Input("padded_shapes: N * int64")
382     .Input("padding_values: Toutput_types")
383     .Output("handle: variant")
384     .Attr("Toutput_types: list(type) >= 1")
385     .Attr("output_shapes: list(shape) >= 1")
386     .Attr("N: int >= 1")
__anon0b7d74c50a02(shape_inference::InferenceContext* c) 387     .SetShapeFn([](shape_inference::InferenceContext* c) {
388       shape_inference::ShapeHandle unused;
389       // batch_size should be a scalar.
390       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
391       return shape_inference::ScalarShape(c);
392     });
393 
394 REGISTER_OP("PaddedBatchDatasetV2")
395     .Input("input_dataset: variant")
396     .Input("batch_size: int64")
397     .Input("padded_shapes: N * int64")
398     .Input("padding_values: Toutput_types")
399     .Input("drop_remainder: bool")
400     .Output("handle: variant")
401     .Attr("parallel_copy: bool = false")
402     .Attr("Toutput_types: list(type) >= 1")
403     .Attr("output_shapes: list(shape) >= 1")
404     .Attr("N: int >= 1")
__anon0b7d74c50b02(shape_inference::InferenceContext* c) 405     .SetShapeFn([](shape_inference::InferenceContext* c) {
406       shape_inference::ShapeHandle unused;
407       // batch_size should be a scalar.
408       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
409       // drop_remainder should be a scalar.
410       TF_RETURN_IF_ERROR(
411           c->WithRank(c->input(c->num_inputs() - 1), 0, &unused));
412       return shape_inference::ScalarShape(c);
413     });
414 
415 REGISTER_OP("RangeDataset")
416     .Input("start: int64")
417     .Input("stop: int64")
418     .Input("step: int64")
419     .Output("handle: variant")
420     .Attr("output_types: list(type) >= 1")
421     .Attr("output_shapes: list(shape) >= 1")
422     .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
__anon0b7d74c50c02(shape_inference::InferenceContext* c) 423     .SetShapeFn([](shape_inference::InferenceContext* c) {
424       shape_inference::ShapeHandle unused;
425       // start, stop, and step should be scalars.
426       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
427       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
428       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
429       return shape_inference::ScalarShape(c);
430     });
431 
432 REGISTER_OP("AnonymousSeedGenerator")
433     .Input("seed: int64")
434     .Input("seed2: int64")
435     .Input("reshuffle: bool")
436     .Output("handle: resource")
437     .Output("deleter: variant")
__anon0b7d74c50d02(shape_inference::InferenceContext* c) 438     .SetShapeFn([](shape_inference::InferenceContext* c) {
439       c->set_output(0, c->Scalar());
440       c->set_output(1, c->Scalar());
441       return Status::OK();
442     });
443 
444 REGISTER_OP("DatasetCardinality")
445     .Input("input_dataset: variant")
446     .Output("cardinality: int64")
447     .SetShapeFn(shape_inference::ScalarShape);
448 
449 REGISTER_OP("DeleteSeedGenerator")
450     .Input("handle: resource")
451     .Input("deleter: variant")
452     .SetShapeFn(shape_inference::NoOutputs);
453 
454 // Deprecated in favor of AnonymousSeedGenerator/DeleteSeedGenerator.
455 REGISTER_OP("AnonymousRandomSeedGenerator")
456     .Input("seed: int64")
457     .Input("seed2: int64")
458     .Output("handle: resource")
459     .Output("deleter: variant")
__anon0b7d74c50e02(shape_inference::InferenceContext* c) 460     .SetShapeFn([](shape_inference::InferenceContext* c) {
461       c->set_output(0, c->Scalar());
462       c->set_output(1, c->Scalar());
463       return Status::OK();
464     });
465 
466 // Deprecated in favor of AnonymousSeedGenerator/DeleteSeedGenerator.
467 REGISTER_OP("DeleteRandomSeedGenerator")
468     .Input("handle: resource")
469     .Input("deleter: variant")
470     .SetShapeFn(shape_inference::NoOutputs);
471 
472 REGISTER_OP("DummySeedGenerator")
473     .Output("handle: resource")
__anon0b7d74c50f02(shape_inference::InferenceContext* c) 474     .SetShapeFn([](shape_inference::InferenceContext* c) {
475       c->set_output(0, c->Scalar());
476       return Status::OK();
477     });
478 
479 REGISTER_OP("ShuffleDataset")
480     .Input("input_dataset: variant")
481     .Input("buffer_size: int64")
482     .Input("seed: int64")
483     .Input("seed2: int64")
484     .Output("handle: variant")
485     .Attr("reshuffle_each_iteration: bool = true")
486     .Attr("output_types: list(type) >= 1")
487     .Attr("output_shapes: list(shape) >= 1")
__anon0b7d74c51002(shape_inference::InferenceContext* c) 488     .SetShapeFn([](shape_inference::InferenceContext* c) {
489       shape_inference::ShapeHandle unused;
490       // buffer_size, seed, and seed2 should be scalars.
491       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
492       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
493       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
494       return shape_inference::ScalarShape(c);
495     });
496 
497 REGISTER_OP("ShuffleDatasetV2")
498     .Input("input_dataset: variant")
499     .Input("buffer_size: int64")
500     .Input("seed_generator: resource")
501     .Output("handle: variant")
502     .Attr("output_types: list(type) >= 1")
503     .Attr("output_shapes: list(shape) >= 1")
__anon0b7d74c51102(shape_inference::InferenceContext* c) 504     .SetShapeFn([](shape_inference::InferenceContext* c) {
505       shape_inference::ShapeHandle unused;
506       // buffer_size and seed_generator should be scalars.
507       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
508       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
509       return shape_inference::ScalarShape(c);
510     });
511 
512 REGISTER_OP("ShuffleDatasetV3")
513     .Input("input_dataset: variant")
514     .Input("buffer_size: int64")
515     .Input("seed: int64")
516     .Input("seed2: int64")
517     .Input("seed_generator: resource")
518     .Output("handle: variant")
519     .Attr("reshuffle_each_iteration: bool = true")
520     .Attr("output_types: list(type) >= 1")
521     .Attr("output_shapes: list(shape) >= 1")
__anon0b7d74c51202(shape_inference::InferenceContext* c) 522     .SetShapeFn([](shape_inference::InferenceContext* c) {
523       shape_inference::ShapeHandle unused;
524       // buffer_size, seed, seed2, and seed_generator should be scalars.
525       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
526       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
527       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
528       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
529       return shape_inference::ScalarShape(c);
530     });
531 
532 REGISTER_OP("ShuffleAndRepeatDataset")
533     .Input("input_dataset: variant")
534     .Input("buffer_size: int64")
535     .Input("seed: int64")
536     .Input("seed2: int64")
537     .Input("count: int64")
538     .Output("handle: variant")
539     .Attr("output_types: list(type) >= 1")
540     .Attr("output_shapes: list(shape) >= 1")
541     .Attr("reshuffle_each_iteration: bool = true")
__anon0b7d74c51302(shape_inference::InferenceContext* c) 542     .SetShapeFn([](shape_inference::InferenceContext* c) {
543       shape_inference::ShapeHandle unused;
544       // buffer_size, seed, seed2, and count should be scalars.
545       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
546       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
547       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
548       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
549       return shape_inference::ScalarShape(c);
550     });
551 
552 REGISTER_OP("ShuffleAndRepeatDatasetV2")
553     .Input("input_dataset: variant")
554     .Input("buffer_size: int64")
555     .Input("seed: int64")
556     .Input("seed2: int64")
557     .Input("count: int64")
558     .Input("seed_generator: resource")
559     .Output("handle: variant")
560     .Attr("reshuffle_each_iteration: bool = true")
561     .Attr("output_types: list(type) >= 1")
562     .Attr("output_shapes: list(shape) >= 1")
__anon0b7d74c51402(shape_inference::InferenceContext* c) 563     .SetShapeFn([](shape_inference::InferenceContext* c) {
564       shape_inference::ShapeHandle unused;
565       // buffer_size, seed, seed2, count, and seed_generator should be scalars.
566       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
567       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
568       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
569       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
570       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
571       return shape_inference::ScalarShape(c);
572     });
573 
574 REGISTER_OP("AnonymousMemoryCache")
575     .Output("handle: resource")
576     .Output("deleter: variant")
__anon0b7d74c51502(shape_inference::InferenceContext* c) 577     .SetShapeFn([](shape_inference::InferenceContext* c) {
578       c->set_output(0, c->Scalar());
579       c->set_output(1, c->Scalar());
580       return Status::OK();
581     });
582 
583 REGISTER_OP("DeleteMemoryCache")
584     .Input("handle: resource")
585     .Input("deleter: variant")
586     .SetShapeFn(shape_inference::NoOutputs);
587 
588 REGISTER_OP("DummyMemoryCache")
589     .Output("handle: resource")
__anon0b7d74c51602(shape_inference::InferenceContext* c) 590     .SetShapeFn([](shape_inference::InferenceContext* c) {
591       c->set_output(0, c->Scalar());
592       return Status::OK();
593     });
594 
595 REGISTER_OP("CacheDataset")
596     .Input("input_dataset: variant")
597     .Input("filename: string")
598     .Output("handle: variant")
599     .Attr("output_types: list(type) >= 1")
600     .Attr("output_shapes: list(shape) >= 1")
__anon0b7d74c51702(shape_inference::InferenceContext* c) 601     .SetShapeFn([](shape_inference::InferenceContext* c) {
602       shape_inference::ShapeHandle unused;
603       // filename should be a scalar.
604       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
605       return shape_inference::ScalarShape(c);
606     });
607 
608 REGISTER_OP("CacheDatasetV2")
609     .Input("input_dataset: variant")
610     .Input("filename: string")
611     .Input("cache: resource")
612     .Output("handle: variant")
613     .Attr("output_types: list(type) >= 1")
614     .Attr("output_shapes: list(shape) >= 1")
__anon0b7d74c51802(shape_inference::InferenceContext* c) 615     .SetShapeFn([](shape_inference::InferenceContext* c) {
616       shape_inference::ShapeHandle unused;
617       // filename should be a scalar.
618       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
619       // cache should be a scalar.
620       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
621       return shape_inference::ScalarShape(c);
622     });
623 
624 REGISTER_OP("TextLineDataset")
625     .Input("filenames: string")
626     .Input("compression_type: string")
627     .Input("buffer_size: int64")
628     .Output("handle: variant")
629     .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
__anon0b7d74c51902(shape_inference::InferenceContext* c) 630     .SetShapeFn([](shape_inference::InferenceContext* c) {
631       shape_inference::ShapeHandle unused;
632       // `filenames` must be a scalar or a vector.
633       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
634       // `compression_type` could only be a scalar.
635       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
636       // `buffer_size` could only be a scalar.
637       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
638       return shape_inference::ScalarShape(c);
639     });
640 
641 REGISTER_OP("FixedLengthRecordDataset")
642     .Input("filenames: string")
643     .Input("header_bytes: int64")
644     .Input("record_bytes: int64")
645     .Input("footer_bytes: int64")
646     .Input("buffer_size: int64")
647     .Output("handle: variant")
648     .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
__anon0b7d74c51a02(shape_inference::InferenceContext* c) 649     .SetShapeFn([](shape_inference::InferenceContext* c) {
650       shape_inference::ShapeHandle unused;
651       // `filenames` must be a scalar or a vector.
652       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
653       // header_bytes, record_bytes, footer_bytes, buffer_size should be
654       // scalars.
655       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
656       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
657       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
658       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
659       return shape_inference::ScalarShape(c);
660     });
661 
662 REGISTER_OP("FixedLengthRecordDatasetV2")
663     .Input("filenames: string")
664     .Input("header_bytes: int64")
665     .Input("record_bytes: int64")
666     .Input("footer_bytes: int64")
667     .Input("buffer_size: int64")
668     .Input("compression_type: string")
669     .Output("handle: variant")
670     .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
__anon0b7d74c51b02(shape_inference::InferenceContext* c) 671     .SetShapeFn([](shape_inference::InferenceContext* c) {
672       shape_inference::ShapeHandle unused;
673       // `filenames` must be a scalar or a vector.
674       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
675       // header_bytes, record_bytes, footer_bytes, buffer_size should be
676       // scalars.
677       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
678       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
679       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
680       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
681       return shape_inference::ScalarShape(c);
682     });
683 
684 REGISTER_OP("TFRecordDataset")
685     .Input("filenames: string")
686     .Input("compression_type: string")
687     .Input("buffer_size: int64")
688     .Output("handle: variant")
689     .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
__anon0b7d74c51c02(shape_inference::InferenceContext* c) 690     .SetShapeFn([](shape_inference::InferenceContext* c) {
691       shape_inference::ShapeHandle unused;
692       // `filenames` must be a scalar or a vector.
693       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
694       // `compression_type` could only be a scalar.
695       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
696       // `buffer_size` could only be a scalar.
697       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
698       return shape_inference::ScalarShape(c);
699     });
700 
701 REGISTER_OP("Iterator")
702     .Output("handle: resource")
703     .Attr("shared_name: string")
704     .Attr("container: string")
705     .Attr("output_types: list(type) >= 1")
706     .Attr("output_shapes: list(shape) >= 1")
707     .SetShapeFn(shape_inference::ScalarShape);
708 
// V2 variant of "Iterator"; identical signature (scalar resource handle out).
REGISTER_OP("IteratorV2")
    .Output("handle: resource")
    .Attr("shared_name: string")
    .Attr("container: string")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);
716 
// Creates an iterator resource with no shared_name/container attrs (hence
// "anonymous"); emits a scalar resource handle.
REGISTER_OP("AnonymousIterator")
    .Output("handle: resource")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);
722 
723 REGISTER_OP("AnonymousIteratorV2")
724     .Output("handle: resource")
725     .Output("deleter: variant")
726     .Attr("output_types: list(type) >= 1")
727     .Attr("output_shapes: list(shape) >= 1")
__anon0b7d74c51d02(shape_inference::InferenceContext* c) 728     .SetShapeFn([](shape_inference::InferenceContext* c) {
729       c->set_output(0, c->Scalar());
730       c->set_output(1, c->Scalar());
731       return Status::OK();
732     });
733 
// Destroys the iterator resource identified by `handle`/`deleter`; produces
// no outputs.
REGISTER_OP("DeleteIterator")
    .Input("handle: resource")
    .Input("deleter: variant")
    .SetShapeFn(shape_inference::NoOutputs);
738 
// Destroys a multi-device iterator along with its N per-device iterator
// resources; produces no outputs.
REGISTER_OP("DeleteMultiDeviceIterator")
    .Input("multi_device_iterator: resource")
    .Input("iterators: N * resource")
    .Input("deleter: variant")
    .Attr("N: int >= 0")
    .SetShapeFn(shape_inference::NoOutputs);
745 
// Associates a dataset with an iterator resource (initializes the iterator
// from the dataset); produces no outputs.
REGISTER_OP("MakeIterator")
    .Input("dataset: variant")
    .Input("iterator: resource")
    .SetShapeFn(shape_inference::NoOutputs);
750 
// Creates an iterator whose dataset is produced by the `dataset_factory`
// function. Marked stateful; emits a scalar resource handle.
REGISTER_OP("OneShotIterator")
    .Output("handle: resource")
    .Attr("dataset_factory: func")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);
760 
// Gets the next element from an iterator; component shapes come from the
// `output_shapes` attr via DatasetIteratorShape.
REGISTER_OP("IteratorGetNext")
    .Input("iterator: resource")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::DatasetIteratorShape);
767 
// Synchronous variant of IteratorGetNext; identical signature and shape fn.
REGISTER_OP("IteratorGetNextSync")
    .Input("iterator: resource")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::DatasetIteratorShape);
774 
775 // TODO(b/124308596): Instead of conservatively marking this op as stateful,
776 // implement a mechanism to determine whether `dataset` has a side-effect
777 // and use it to decide whether to use a stateless or stateful version of this
778 // op.
// Extracts the single element of `dataset` as a tuple of components.
// Conservatively stateful (see TODO above).
REGISTER_OP("DatasetToSingleElement")
    .Input("dataset: variant")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetIsStateful()
    .SetShapeFn(shape_inference::DatasetIteratorShape);
786 
787 // TODO(b/124308596): Instead of conservatively marking this op as stateful,
788 // implement a mechanism to determine whether `dataset` has a side-effect
789 // and use it to decide whether to use a stateless or stateful version of this
790 // op.
// Reduces `input_dataset` with function `f`, starting from `initial_state`;
// `other_arguments` are extra captured inputs to `f`. Conservatively stateful
// (see TODO above). Output components are shaped per `output_shapes`.
REGISTER_OP("ReduceDataset")
    .Input("input_dataset: variant")
    .Input("initial_state: Tstate")
    .Input("other_arguments: Targuments")
    .Output("components: output_types")
    .Attr("f: func")
    .Attr("Tstate: list(type) >= 1")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("use_inter_op_parallelism: bool = true")
    .SetIsStateful()
    .SetShapeFn(shape_inference::DatasetIteratorShape);
804 
// Converts an iterator resource handle to a scalar string handle.
REGISTER_OP("IteratorToStringHandle")
    .Input("resource_handle: resource")
    .Output("string_handle: string")
    .SetShapeFn(shape_inference::ScalarShape);
809 
// Inverse of IteratorToStringHandle: recovers an iterator resource handle
// from its string form. Type/shape attrs default to empty lists (unknown).
REGISTER_OP("IteratorFromStringHandle")
    .Input("string_handle: string")
    .Output("resource_handle: resource")
    .Attr("output_types: list(type) >= 0 = []")
    .Attr("output_shapes: list(shape) >= 0 = []")
    .SetShapeFn(shape_inference::ScalarShape);
816 
// V2 variant of IteratorFromStringHandle; identical signature.
REGISTER_OP("IteratorFromStringHandleV2")
    .Input("string_handle: string")
    .Output("resource_handle: resource")
    .Attr("output_types: list(type) >= 0 = []")
    .Attr("output_shapes: list(shape) >= 0 = []")
    .SetShapeFn(shape_inference::ScalarShape);
823 
824 REGISTER_OP("SerializeIterator")
825     .Input("resource_handle: resource")
826     .Attr("external_state_policy: int = 0")
827     .Output("serialized: variant")
__anon0b7d74c51e02(shape_inference::InferenceContext* c) 828     .SetShapeFn([](shape_inference::InferenceContext* c) {
829       c->set_output(0, c->Vector(c->UnknownDim()));
830       return Status::OK();
831     });
832 
// Restores iterator state previously produced by SerializeIterator; produces
// no outputs.
REGISTER_OP("DeserializeIterator")
    .Input("resource_handle: resource")
    .Input("serialized: variant")
    .SetShapeFn(shape_inference::NoOutputs);
837 
// Serializes a dataset into a scalar string `graph` representation.
REGISTER_OP("DatasetToGraph")
    .Input("input_dataset: variant")
    .Attr("stateful_whitelist: list(string) >= 0 = []")
    .Attr("allow_stateful: bool = false")
    .Attr("strip_device_assignment: bool = false")
    .Output("graph: string")
    .SetShapeFn(shape_inference::ScalarShape);
845 
// V2 of DatasetToGraph: replaces the whitelist/allow_stateful attrs with an
// integer `external_state_policy`.
REGISTER_OP("DatasetToGraphV2")
    .Input("input_dataset: variant")
    .Attr("external_state_policy: int = 0")
    .Attr("strip_device_assignment: bool = false")
    .Output("graph: string")
    .SetShapeFn(shape_inference::ScalarShape);
852 
// Applies the named graph `optimizations` to `input_dataset`; emits a scalar
// variant handle to the rewritten dataset.
REGISTER_OP("OptimizeDataset")
    .Input("input_dataset: variant")
    .Input("optimizations: string")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("optimization_configs: list(string) = []")
    .SetShapeFn(shape_inference::ScalarShape);
861 
// V2 of OptimizeDataset: splits the optimization list into explicitly
// enabled, explicitly disabled, and default sets.
REGISTER_OP("OptimizeDatasetV2")
    .Input("input_dataset: variant")
    .Input("optimizations_enabled: string")
    .Input("optimizations_disabled: string")
    .Input("optimizations_default: string")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("optimization_configs: list(string) = []")
    .SetShapeFn(shape_inference::ScalarShape);
872 
// Wraps the input `components` in a scalar variant "optional" that has a
// value; full type is Optional[...] over Toutput_types.
REGISTER_OP("OptionalFromValue")
    .Input("components: Toutput_types")
    .Output("optional: variant")
    .Attr("Toutput_types: list(type) >= 1")
    .SetTypeConstructor(full_type::Unary(TFT_OPTIONAL, "Toutput_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      std::vector<DataType> dtypes;
      TF_RETURN_IF_ERROR(c->GetAttr("Toutput_types", &dtypes));
      // The optional container itself is a scalar.
      c->set_output(0, c->Scalar());
      // Record each component's shape/dtype/full-type on the output handle so
      // downstream shape inference can recover the wrapped values.
      std::vector<shape_inference::ShapeAndType> shapes_and_types;
      shapes_and_types.reserve(c->num_inputs());
      const FullTypeDef& ret_types = c->ret_types();
      for (int i = 0; i < c->num_inputs(); ++i) {
        // TODO(mdan): output_type(i) == optional is incorrect.
        // "Optional" is the type of the whole container, not of individual
        // elements.
        //
        // Why ret_types.args(0) and not args(i) --
        // For example if Toutput_types is (int32, float32), then
        // ret_types.args[0] (i.e. the 0th output) is
        // Optional[Record[Tensor[int32, s1], Tensor[float32, s2]]]
        // set_output_handle_shapes_and_types tracks the same thing, but in
        // a transposed way:
        // {ShapeAndType(int32, s1, Optional), ShapeAndType(float32, s2, Optional)}
        // That should be corrected in the future (see todo above).
        shapes_and_types.emplace_back(c->input(i), dtypes[i],
                                      ret_types.args(0));
      }
      c->set_output_handle_shapes_and_types(0, shapes_and_types);
      return Status::OK();
    });
904 
// Produces a scalar variant "optional" with no value.
REGISTER_OP("OptionalNone")
    .Output("optional: variant")
    .SetShapeFn(shape_inference::ScalarShape);
908 
// Returns a scalar bool: whether `optional` holds a value.
REGISTER_OP("OptionalHasValue")
    .Input("optional: variant")
    .Output("has_value: bool")
    .SetShapeFn(shape_inference::ScalarShape);
913 
// Unwraps the components stored in `optional`; component shapes come from the
// `output_shapes` attr via DatasetIteratorShape.
REGISTER_OP("OptionalGetValue")
    .Input("optional: variant")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::DatasetIteratorShape);
920 
// Gets the iterator's next element wrapped in a scalar variant optional
// (presumably empty at end-of-sequence — confirm against kernel).
REGISTER_OP("IteratorGetNextAsOptional")
    .Input("iterator: resource")
    .Output("optional: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);
927 
// Wraps `input_dataset` in a performance-modeling dataset; `algorithm`,
// `cpu_budget` and `ram_budget` are integer tuning knobs (0 = default).
REGISTER_OP("ModelDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("algorithm: int = 0")
    .Attr("cpu_budget: int = 0")
    .Attr("ram_budget: int = 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);
937 
938 // TODO(b/124308749): Add a stateful version of MapDefun and use it when `f`
939 // is stateful.
940 REGISTER_OP("MapDefun")
941     .Input("arguments: Targuments")
942     .Input("captured_inputs: Tcaptured")
943     .Output("output: output_types")
944     .Attr("Targuments: list(type) >= 1")
945     .Attr("Tcaptured: list(type) >= 0 = []")
946     .Attr("output_types: list(type) >= 1")
947     .Attr("output_shapes: list(shape) >= 1")
948     .Attr("f: func")
949     .Attr("max_intra_op_parallelism: int = 1")
__anon0b7d74c52002(shape_inference::InferenceContext* c) 950     .SetShapeFn([](shape_inference::InferenceContext* c) {
951       std::vector<PartialTensorShape> output_shapes;
952       TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
953       DataTypeVector t_args;
954       TF_RETURN_IF_ERROR(c->GetAttr("Targuments", &t_args));
955       if (output_shapes.size() != c->num_outputs()) {
956         return errors::InvalidArgument(
957             "`output_shapes` must be the same length as `output_types` (",
958             output_shapes.size(), " vs. ", c->num_outputs(), ")");
959       }
960 
961       int64_t dim_zero = -1;
962       for (size_t i = 0; i < t_args.size(); ++i) {
963         if (c->Rank(c->input(i)) == 0) {
964           return errors::InvalidArgument(
965               "Arguments must have rank at least 1. Input ", i,
966               " has rank of 0.");
967         }
968         auto dim_handle = c->Dim(c->input(i), 0);
969         if (c->ValueKnown(dim_handle)) {
970           if (dim_zero == -1) {
971             dim_zero = c->Value(dim_handle);
972           } else if (c->Value(dim_handle) != dim_zero) {
973             return errors::InvalidArgument(
974                 "Arguments must have the same dimension 0.");
975           }
976         }
977       }
978 
979       for (size_t i = 0; i < output_shapes.size(); ++i) {
980         PartialTensorShape s({});
981         s = s.Concatenate(dim_zero);
982         s = s.Concatenate(output_shapes[i]);
983         shape_inference::ShapeHandle output_shape_handle;
984 
985         TF_RETURN_IF_ERROR(
986             c->MakeShapeFromPartialTensorShape(s, &output_shape_handle));
987         c->set_output(static_cast<int>(i), output_shape_handle);
988       }
989       return Status::OK();
990     });
991 
// Wraps a dataset variant handle; output is a scalar variant.
REGISTER_OP("WrapDatasetVariant")
    .Input("input_handle: variant")
    .Output("output_handle: variant")
    .SetShapeFn(shape_inference::ScalarShape);
996 
// Inverse of WrapDatasetVariant; output is a scalar variant.
REGISTER_OP("UnwrapDatasetVariant")
    .Input("input_handle: variant")
    .Output("output_handle: variant")
    .SetShapeFn(shape_inference::ScalarShape);
1001 
1002 REGISTER_OP("AnonymousMultiDeviceIterator")
1003     .Output("handle: resource")
1004     .Output("deleter: variant")
1005     .Attr("devices: list(string) >= 1")
1006     .Attr("output_types: list(type) >= 1")
1007     .Attr("output_shapes: list(shape) >= 1")
__anon0b7d74c52102(shape_inference::InferenceContext* c) 1008     .SetShapeFn([](shape_inference::InferenceContext* c) {
1009       c->set_output(0, c->Scalar());
1010       c->set_output(1, c->Scalar());
1011       return Status::OK();
1012     });
1013 
// Creates a multi-device iterator resource over `devices`; emits a scalar
// resource handle.
REGISTER_OP("MultiDeviceIterator")
    .Output("handle: resource")
    .Attr("devices: list(string) >= 1")
    .Attr("shared_name: string")
    .Attr("container: string")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);
1022 
// Initializes a multi-device iterator from `dataset`; returns a scalar int64
// incarnation id identifying this initialization.
REGISTER_OP("MultiDeviceIteratorInit")
    .Input("dataset: variant")
    .Input("multi_device_iterator: resource")
    .Input("max_buffer_size: int64")
    .Output("incarnation_id: int64")
    .SetShapeFn(shape_inference::ScalarShape);
1029 
// Gets the next element for shard `shard_num` of a multi-device iterator;
// `incarnation_id` must match the id returned by MultiDeviceIteratorInit
// (presumably to reject stale iterators — confirm against kernel).
REGISTER_OP("MultiDeviceIteratorGetNextFromShard")
    .Input("multi_device_iterator: resource")
    .Input("shard_num: int32")
    .Input("incarnation_id: int64")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::DatasetIteratorShape);
1038 
// Converts a multi-device iterator resource handle to a scalar string handle.
REGISTER_OP("MultiDeviceIteratorToStringHandle")
    .Input("multi_device_iterator: resource")
    .Output("string_handle: string")
    .SetShapeFn(shape_inference::ScalarShape);
1043 
// Inverse of MultiDeviceIteratorToStringHandle; type/shape attrs default to
// empty lists (unknown).
REGISTER_OP("MultiDeviceIteratorFromStringHandle")
    .Input("string_handle: string")
    .Output("multi_device_iterator: resource")
    .Attr("output_types: list(type) >= 0 = []")
    .Attr("output_shapes: list(shape) >= 0 = []")
    .SetShapeFn(shape_inference::ScalarShape);
1050 
// Attaches `serialized_options` to `input_dataset`; emits a scalar variant
// handle to the resulting dataset.
REGISTER_OP("OptionsDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("serialized_options: string")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);
1058 
// Returns the serialized options of `input_dataset` as a scalar string.
REGISTER_OP("GetOptions")
    .Input("input_dataset: variant")
    .Output("serialized_options: string")
    .SetShapeFn(shape_inference::ScalarShape);
1063 
// Finalizes `input_dataset` (last pipeline stage); emits a scalar variant
// handle. `has_captured_ref` flags whether the input captures a reference —
// exact semantics are in the kernel, not visible here.
REGISTER_OP("FinalizeDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("has_captured_ref: bool = false")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);
1071 
1072 }  // namespace tensorflow
1073