• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/common_shape_fns.h"
16 #include "tensorflow/core/framework/op.h"
17 #include "tensorflow/core/framework/op_def_builder.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19 
20 namespace tensorflow {
21 
22 // --------------------------------------------------------------------------
23 
24 // The ops in this section can be composed to define an input
25 // pipeline. Each op produces a DT_VARIANT tensor that represents
// a DAG of "dataset" objects. A "dataset" object can be converted
27 // to a stateful "iterator" by passing the "dataset" to the
28 // "MakeIterator" op.
29 //
30 // TODO(b/123753214): DT_VARIANT tensors that represent "dataset" objects are
31 // not presently serializable. To avoid issues with constant folding, ensure
32 // that any "source dataset" ops (i.e. ops that output a dataset and do not
33 // take one as input) are marked "stateful".
34 
// Source dataset op: wraps the input `components` tensors as a single-element
// dataset. The output `handle` is a scalar DT_VARIANT.
REGISTER_OP("TensorDataset")
    .Input("components: Toutput_types")
    .Output("handle: variant")
    .Attr("Toutput_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetDoNotOptimize()  // TODO(b/123753214): Source dataset ops must
                         // disable constant folding.
    .SetShapeFn(shape_inference::ScalarShape);  // TODO(mrry): Validate that
                                                // `components` have shapes
                                                // compatible with
                                                // `output_shapes`.
46 
// Source dataset op: slices the input `components` along dimension 0,
// producing one dataset element per slice. Output is a scalar variant handle.
REGISTER_OP("TensorSliceDataset")
    .Input("components: Toutput_types")
    .Output("handle: variant")
    .Attr("Toutput_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetDoNotOptimize()  // TODO(b/123753214): Source dataset ops must
                         // disable constant folding.
    .SetShapeFn(shape_inference::ScalarShape);  // TODO(mrry): Validate that the
                                                // dim-0 slices of `components`
                                                // have shapes compatible with
                                                // `output_shapes`.
58 
// Source dataset op: like TensorSliceDataset, but takes a sparse tensor in
// COO form (`indices`, `values`, `dense_shape`).
REGISTER_OP("SparseTensorSliceDataset")
    .Input("indices: int64")
    .Input("values: Tvalues")
    .Input("dense_shape: int64")
    .Output("handle: variant")
    .Attr("Tvalues: type")
    .SetDoNotOptimize()  // TODO(b/123753214): Source dataset ops must
                         // disable constant folding.
    .SetShapeFn(shape_inference::ScalarShape);
68 
// Source dataset op driven by three user functions: `init_func` (state
// setup), `next_func` (produces elements), and `finalize_func` (teardown).
// The *_other_args inputs are captured arguments forwarded to each function.
REGISTER_OP("GeneratorDataset")
    .Input("init_func_other_args: Tinit_func_args")
    .Input("next_func_other_args: Tnext_func_args")
    .Input("finalize_func_other_args: Tfinalize_func_args")
    .Output("handle: variant")
    .Attr("init_func: func")
    .Attr("next_func: func")
    .Attr("finalize_func: func")
    .Attr("Tinit_func_args: list(type) >= 0")
    .Attr("Tnext_func_args: list(type) >= 0")
    .Attr("Tfinalize_func_args: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetDoNotOptimize()  // TODO(b/123753214): Source dataset ops must
                         // disable constant folding.
    .SetShapeFn(shape_inference::ScalarShape);
85 
// Combines N input datasets element-wise into a dataset of tuples.
REGISTER_OP("ZipDataset")
    .Input("input_datasets: N * variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("N: int >= 1")
    .SetShapeFn(shape_inference::ScalarShape);
93 
// Produces a dataset containing the elements of `input_dataset` followed by
// the elements of `another_dataset`.
REGISTER_OP("ConcatenateDataset")
    .Input("input_dataset: variant")
    .Input("another_dataset: variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);
101 
// Repeats `input_dataset` `count` times. The shape function verifies that
// `count` (input 1) is a scalar; the handle itself is a scalar variant.
REGISTER_OP("RepeatDataset")
    .Input("input_dataset: variant")
    .Input("count: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle count_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &count_shape));
      return shape_inference::ScalarShape(c);
    });
113 
// Takes up to `count` elements from `input_dataset`. The shape function
// verifies that `count` (input 1) is a scalar.
REGISTER_OP("TakeDataset")
    .Input("input_dataset: variant")
    .Input("count: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle count_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &count_shape));
      return shape_inference::ScalarShape(c);
    });
125 
// Skips the first `count` elements of `input_dataset`. The shape function
// verifies that `count` (input 1) is a scalar.
REGISTER_OP("SkipDataset")
    .Input("input_dataset: variant")
    .Input("count: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle count_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &count_shape));
      return shape_inference::ScalarShape(c);
    });
137 
// Maps function `f` over the elements of `input_dataset`. `other_arguments`
// are captured inputs forwarded to `f` on every call.
REGISTER_OP("MapDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("use_inter_op_parallelism: bool = true")
    .Attr("preserve_cardinality: bool = false")
    .SetShapeFn(shape_inference::ScalarShape);
149 
// Parallel variant of MapDataset: applies `f` with up to `num_parallel_calls`
// (int32) invocations in flight. `sloppy` permits out-of-order results.
REGISTER_OP("ParallelMapDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Input("num_parallel_calls: int32")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("use_inter_op_parallelism: bool = true")
    .Attr("sloppy: bool = false")
    .Attr("preserve_cardinality: bool = false")
    .SetShapeFn(shape_inference::ScalarShape);
163 
// V2 of ParallelMapDataset: `num_parallel_calls` is int64, and the boolean
// `sloppy` attr is replaced by a tri-state string `deterministic` attr.
REGISTER_OP("ParallelMapDatasetV2")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Input("num_parallel_calls: int64")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("use_inter_op_parallelism: bool = true")
    // "true", "false", or "default".
    .Attr("deterministic: string = 'default'")
    .Attr("preserve_cardinality: bool = false")
    .SetShapeFn(shape_inference::ScalarShape);
178 
// Prefetches up to `buffer_size` elements from `input_dataset` ahead of the
// consumer. The shape function verifies `buffer_size` (input 1) is a scalar.
REGISTER_OP("PrefetchDataset")
    .Input("input_dataset: variant")
    .Input("buffer_size: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("slack_period: int = 0")
    .Attr("legacy_autotune: bool = true")
    .Attr("buffer_size_min: int = 0")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // buffer_size should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return shape_inference::ScalarShape(c);
    });
194 
// Maps `f` (which itself returns a dataset) over `input_dataset` and
// flattens the results into a single dataset.
REGISTER_OP("FlatMapDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);
204 
// Like FlatMapDataset, but interleaves results from `cycle_length` nested
// datasets, taking `block_length` consecutive elements from each in turn.
REGISTER_OP("InterleaveDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Input("cycle_length: int64")
    .Input("block_length: int64")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);
216 
// Parallel variant of InterleaveDataset with `num_parallel_calls` in-flight
// invocations; `sloppy` permits out-of-order results.
REGISTER_OP("ParallelInterleaveDatasetV2")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Input("cycle_length: int64")
    .Input("block_length: int64")
    .Input("num_parallel_calls: int64")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("sloppy: bool = false")
    .SetShapeFn(shape_inference::ScalarShape);
230 
// V3 of ParallelInterleaveDataset: replaces the boolean `sloppy` attr with a
// tri-state string `deterministic` attr.
REGISTER_OP("ParallelInterleaveDatasetV3")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Input("cycle_length: int64")
    .Input("block_length: int64")
    .Input("num_parallel_calls: int64")
    .Output("handle: variant")
    .Attr("f: func")
    // "true", "false", or "default".
    .Attr("deterministic: string = 'default'")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);
245 
// Like V3, but adds buffer_output_elements and prefetch_input_elements.
REGISTER_OP("ParallelInterleaveDatasetV4")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Input("cycle_length: int64")
    .Input("block_length: int64")
    .Input("buffer_output_elements: int64")
    .Input("prefetch_input_elements: int64")
    .Input("num_parallel_calls: int64")
    .Output("handle: variant")
    .Attr("f: func")
    // "true", "false", or "default".
    .Attr("deterministic: string = 'default'")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);
263 
// Keeps only the elements of `input_dataset` for which `predicate` returns
// true. `other_arguments` are captured inputs forwarded to the predicate.
REGISTER_OP("FilterDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Output("handle: variant")
    .Attr("predicate: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);
273 
// This op is no longer supported. The registration is kept so that graphs
// containing it still load.
REGISTER_OP("FilterByLastComponentDataset")
    .Input("input_dataset: variant")
    .Output("output: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);
281 
282 REGISTER_OP("WindowDataset")
283     .Input("input_dataset: variant")
284     .Input("size: int64")
285     .Input("shift: int64")
286     .Input("stride: int64")
287     .Input("drop_remainder: bool")
288     .Output("handle: variant")
289     .Attr("output_types: list(type) >= 1")
290     .Attr("output_shapes: list(shape) >= 1")
__anon42b7b0c00502(shape_inference::InferenceContext* c) 291     .SetShapeFn([](shape_inference::InferenceContext* c) {
292       shape_inference::ShapeHandle unused;
293       // size, shift, stride, and drop_remainder should be scalars.
294       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
295       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
296       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
297       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
298       return shape_inference::ScalarShape(c);
299     });
300 
// Batches consecutive elements of `input_dataset` into batches of
// `batch_size`. The shape function verifies `batch_size` is a scalar.
REGISTER_OP("BatchDataset")
    .Input("input_dataset: variant")
    .Input("batch_size: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // batch_size should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return shape_inference::ScalarShape(c);
    });
313 
// V2 of BatchDataset: adds a `drop_remainder` input and a `parallel_copy`
// attr. Both `batch_size` and `drop_remainder` must be scalars.
REGISTER_OP("BatchDatasetV2")
    .Input("input_dataset: variant")
    .Input("batch_size: int64")
    .Input("drop_remainder: bool")
    .Output("handle: variant")
    .Attr("parallel_copy: bool = false")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // batch_size should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      // drop_remainder should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      return shape_inference::ScalarShape(c);
    });
330 
331 REGISTER_OP("ParallelBatchDataset")
332     .Input("input_dataset: variant")
333     .Input("batch_size: int64")
334     .Input("num_parallel_calls: int64")
335     .Input("drop_remainder: bool")
336     .Output("handle: variant")
337     .Attr("output_types: list(type) >= 1")
338     .Attr("output_shapes: list(shape) >= 1")
__anon42b7b0c00802(shape_inference::InferenceContext* c) 339     .SetShapeFn([](shape_inference::InferenceContext* c) {
340       shape_inference::ShapeHandle unused;
341       // batch_size should be a scalar.
342       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
343       // num_parallel_calls should be a scalar.
344       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
345       // drop_remainder should be a scalar.
346       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
347       return shape_inference::ScalarShape(c);
348     });
349 
// Selects every `num_shards`-th element of `input_dataset`, starting at
// `index`. Both `num_shards` and `index` must be scalars.
REGISTER_OP("ShardDataset")
    .Input("input_dataset: variant")
    .Input("num_shards: int64")
    .Input("index: int64")
    .Output("handle: variant")
    .Attr("require_non_empty: bool = false")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // num_shards should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      // index should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      return shape_inference::ScalarShape(c);
    });
366 
// Batches elements, padding each component up to the corresponding entry of
// `padded_shapes` with the matching `padding_values` scalar.
//
// TODO(mrry): Validate that `padded_shapes` are all vectors, the lengths of
// `output_types` and `output_shapes` are `N`, the `output_shapes` are (as far
// as possible to tell statically) compatible with `padded_shapes`, and that
// `padding_values` are all scalars.
REGISTER_OP("PaddedBatchDataset")
    .Input("input_dataset: variant")
    .Input("batch_size: int64")
    .Input("padded_shapes: N * int64")
    .Input("padding_values: Toutput_types")
    .Output("handle: variant")
    .Attr("Toutput_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("N: int >= 1")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // batch_size should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return shape_inference::ScalarShape(c);
    });
386 
// V2 of PaddedBatchDataset: adds a `drop_remainder` input (always the last
// input, hence the `num_inputs() - 1` lookup) and a `parallel_copy` attr.
REGISTER_OP("PaddedBatchDatasetV2")
    .Input("input_dataset: variant")
    .Input("batch_size: int64")
    .Input("padded_shapes: N * int64")
    .Input("padding_values: Toutput_types")
    .Input("drop_remainder: bool")
    .Output("handle: variant")
    .Attr("parallel_copy: bool = false")
    .Attr("Toutput_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("N: int >= 1")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // batch_size should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      // drop_remainder should be a scalar. Its position depends on N (the
      // `padded_shapes` list length), so index from the end.
      TF_RETURN_IF_ERROR(
          c->WithRank(c->input(c->num_inputs() - 1), 0, &unused));
      return shape_inference::ScalarShape(c);
    });
407 
408 REGISTER_OP("RangeDataset")
409     .Input("start: int64")
410     .Input("stop: int64")
411     .Input("step: int64")
412     .Output("handle: variant")
413     .Attr("output_types: list(type) >= 1")
414     .Attr("output_shapes: list(shape) >= 1")
415     .SetDoNotOptimize()  // TODO(b/123753214): Source dataset ops must
416                          // disable constant folding.
__anon42b7b0c00c02(shape_inference::InferenceContext* c) 417     .SetShapeFn([](shape_inference::InferenceContext* c) {
418       shape_inference::ShapeHandle unused;
419       // start, stop, and step should be scalars.
420       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
421       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
422       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
423       return shape_inference::ScalarShape(c);
424     });
425 
// Creates an anonymous (unnamed) seed-generator resource plus a variant
// `deleter` token used to destroy it. Both outputs are scalars.
REGISTER_OP("AnonymousSeedGenerator")
    .Input("seed: int64")
    .Input("seed2: int64")
    .Input("reshuffle: bool")
    .Output("handle: resource")
    .Output("deleter: variant")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Scalar());
      c->set_output(1, c->Scalar());
      return Status::OK();
    });
437 
// Reports the cardinality of `input_dataset` as a scalar int64.
REGISTER_OP("DatasetCardinality")
    .Input("input_dataset: variant")
    .Output("cardinality: int64")
    .SetShapeFn(shape_inference::ScalarShape);
442 
// Destroys the seed-generator resource identified by `handle`/`deleter`.
REGISTER_OP("DeleteSeedGenerator")
    .Input("handle: resource")
    .Input("deleter: variant")
    .SetShapeFn(shape_inference::NoOutputs);
447 
// Deprecated in favor of AnonymousSeedGenerator/DeleteSeedGenerator.
REGISTER_OP("AnonymousRandomSeedGenerator")
    .Input("seed: int64")
    .Input("seed2: int64")
    .Output("handle: resource")
    .Output("deleter: variant")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // Both the resource handle and the deleter token are scalars.
      c->set_output(0, c->Scalar());
      c->set_output(1, c->Scalar());
      return Status::OK();
    });
459 
// Deprecated in favor of AnonymousSeedGenerator/DeleteSeedGenerator.
REGISTER_OP("DeleteRandomSeedGenerator")
    .Input("handle: resource")
    .Input("deleter: variant")
    .SetShapeFn(shape_inference::NoOutputs);
465 
// Produces a placeholder seed-generator resource handle (scalar output).
REGISTER_OP("DummySeedGenerator")
    .Output("handle: resource")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Scalar());
      return Status::OK();
    });
472 
473 REGISTER_OP("ShuffleDataset")
474     .Input("input_dataset: variant")
475     .Input("buffer_size: int64")
476     .Input("seed: int64")
477     .Input("seed2: int64")
478     .Output("handle: variant")
479     .Attr("reshuffle_each_iteration: bool = true")
480     .Attr("output_types: list(type) >= 1")
481     .Attr("output_shapes: list(shape) >= 1")
__anon42b7b0c01002(shape_inference::InferenceContext* c) 482     .SetShapeFn([](shape_inference::InferenceContext* c) {
483       shape_inference::ShapeHandle unused;
484       // buffer_size, seed, and seed2 should be scalars.
485       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
486       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
487       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
488       return shape_inference::ScalarShape(c);
489     });
490 
// V2 of ShuffleDataset: seeds come from a `seed_generator` resource instead
// of explicit `seed`/`seed2` inputs.
REGISTER_OP("ShuffleDatasetV2")
    .Input("input_dataset: variant")
    .Input("buffer_size: int64")
    .Input("seed_generator: resource")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // buffer_size and seed_generator should be scalars.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      return shape_inference::ScalarShape(c);
    });
505 
506 REGISTER_OP("ShuffleDatasetV3")
507     .Input("input_dataset: variant")
508     .Input("buffer_size: int64")
509     .Input("seed: int64")
510     .Input("seed2: int64")
511     .Input("seed_generator: resource")
512     .Output("handle: variant")
513     .Attr("reshuffle_each_iteration: bool = true")
514     .Attr("output_types: list(type) >= 1")
515     .Attr("output_shapes: list(shape) >= 1")
__anon42b7b0c01202(shape_inference::InferenceContext* c) 516     .SetShapeFn([](shape_inference::InferenceContext* c) {
517       shape_inference::ShapeHandle unused;
518       // buffer_size, seed, seed2, and seed_generator should be scalars.
519       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
520       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
521       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
522       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
523       return shape_inference::ScalarShape(c);
524     });
525 
526 REGISTER_OP("ShuffleAndRepeatDataset")
527     .Input("input_dataset: variant")
528     .Input("buffer_size: int64")
529     .Input("seed: int64")
530     .Input("seed2: int64")
531     .Input("count: int64")
532     .Output("handle: variant")
533     .Attr("output_types: list(type) >= 1")
534     .Attr("output_shapes: list(shape) >= 1")
535     .Attr("reshuffle_each_iteration: bool = true")
__anon42b7b0c01302(shape_inference::InferenceContext* c) 536     .SetShapeFn([](shape_inference::InferenceContext* c) {
537       shape_inference::ShapeHandle unused;
538       // buffer_size, seed, seed2, and count should be scalars.
539       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
540       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
541       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
542       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
543       return shape_inference::ScalarShape(c);
544     });
545 
546 REGISTER_OP("ShuffleAndRepeatDatasetV2")
547     .Input("input_dataset: variant")
548     .Input("buffer_size: int64")
549     .Input("seed: int64")
550     .Input("seed2: int64")
551     .Input("count: int64")
552     .Input("seed_generator: resource")
553     .Output("handle: variant")
554     .Attr("reshuffle_each_iteration: bool = true")
555     .Attr("output_types: list(type) >= 1")
556     .Attr("output_shapes: list(shape) >= 1")
__anon42b7b0c01402(shape_inference::InferenceContext* c) 557     .SetShapeFn([](shape_inference::InferenceContext* c) {
558       shape_inference::ShapeHandle unused;
559       // buffer_size, seed, seed2, count, and seed_generator should be scalars.
560       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
561       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
562       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
563       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
564       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
565       return shape_inference::ScalarShape(c);
566     });
567 
// Creates an anonymous in-memory cache resource plus a variant `deleter`
// token used to destroy it. Both outputs are scalars.
REGISTER_OP("AnonymousMemoryCache")
    .Output("handle: resource")
    .Output("deleter: variant")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Scalar());
      c->set_output(1, c->Scalar());
      return Status::OK();
    });
576 
// Destroys the memory-cache resource identified by `handle`/`deleter`.
REGISTER_OP("DeleteMemoryCache")
    .Input("handle: resource")
    .Input("deleter: variant")
    .SetShapeFn(shape_inference::NoOutputs);
581 
// Produces a placeholder memory-cache resource handle (scalar output).
REGISTER_OP("DummyMemoryCache")
    .Output("handle: resource")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Scalar());
      return Status::OK();
    });
588 
// Caches the elements of `input_dataset`, keyed by `filename` (which must be
// a scalar string).
REGISTER_OP("CacheDataset")
    .Input("input_dataset: variant")
    .Input("filename: string")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // filename should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return shape_inference::ScalarShape(c);
    });
601 
// V2 of CacheDataset: adds a `cache` resource input. Both `filename` and
// `cache` must be scalars.
REGISTER_OP("CacheDatasetV2")
    .Input("input_dataset: variant")
    .Input("filename: string")
    .Input("cache: resource")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // filename should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      // cache should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      return shape_inference::ScalarShape(c);
    });
617 
// Source dataset op: reads lines from the text files named by `filenames`.
REGISTER_OP("TextLineDataset")
    .Input("filenames: string")
    .Input("compression_type: string")
    .Input("buffer_size: int64")
    .Output("handle: variant")
    .SetDoNotOptimize()  // TODO(b/123753214): Source dataset ops must
                         // disable constant folding.
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // `filenames` must be a scalar or a vector.
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
      // `compression_type` could only be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      // `buffer_size` could only be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      return shape_inference::ScalarShape(c);
    });
635 
// Source dataset op: reads fixed-length binary records of `record_bytes`
// from the files named by `filenames`, skipping `header_bytes` at the start
// and `footer_bytes` at the end of each file.
REGISTER_OP("FixedLengthRecordDataset")
    .Input("filenames: string")
    .Input("header_bytes: int64")
    .Input("record_bytes: int64")
    .Input("footer_bytes: int64")
    .Input("buffer_size: int64")
    .Output("handle: variant")
    .SetDoNotOptimize()  // TODO(b/123753214): Source dataset ops must
                         // disable constant folding.
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // `filenames` must be a scalar or a vector.
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
      // header_bytes, record_bytes, footer_bytes, buffer_size should be
      // scalars.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      return shape_inference::ScalarShape(c);
    });
657 
658 REGISTER_OP("FixedLengthRecordDatasetV2")
659     .Input("filenames: string")
660     .Input("header_bytes: int64")
661     .Input("record_bytes: int64")
662     .Input("footer_bytes: int64")
663     .Input("buffer_size: int64")
664     .Input("compression_type: string")
665     .Output("handle: variant")
666     .SetDoNotOptimize()  // TODO(b/123753214): Source dataset ops must
667                          // disable constant folding.
__anon42b7b0c01b02(shape_inference::InferenceContext* c) 668     .SetShapeFn([](shape_inference::InferenceContext* c) {
669       shape_inference::ShapeHandle unused;
670       // `filenames` must be a scalar or a vector.
671       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
672       // header_bytes, record_bytes, footer_bytes, buffer_size should be
673       // scalars.
674       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
675       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
676       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
677       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
678       return shape_inference::ScalarShape(c);
679     });
680 
// Source dataset op: reads TFRecord-format records from the files named by
// `filenames`.
REGISTER_OP("TFRecordDataset")
    .Input("filenames: string")
    .Input("compression_type: string")
    .Input("buffer_size: int64")
    .Output("handle: variant")
    .SetDoNotOptimize()  // TODO(b/123753214): Source dataset ops must
                         // disable constant folding.
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // `filenames` must be a scalar or a vector.
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
      // `compression_type` could only be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      // `buffer_size` could only be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      return shape_inference::ScalarShape(c);
    });
698 
// Creates a named iterator resource (looked up via `shared_name` within
// `container`). Note: `shared_name` and `container` have no defaults here,
// so callers must always supply them.
REGISTER_OP("Iterator")
    .Output("handle: resource")
    .Attr("shared_name: string")
    .Attr("container: string")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);
706 
707 REGISTER_OP("IteratorV2")
708     .Output("handle: resource")
709     .Attr("shared_name: string")
710     .Attr("container: string")
711     .Attr("output_types: list(type) >= 1")
712     .Attr("output_shapes: list(shape) >= 1")
713     .SetShapeFn(shape_inference::ScalarShape);
714 
715 REGISTER_OP("AnonymousIterator")
716     .Output("handle: resource")
717     .Attr("output_types: list(type) >= 1")
718     .Attr("output_shapes: list(shape) >= 1")
719     .SetShapeFn(shape_inference::ScalarShape);
720 
721 REGISTER_OP("AnonymousIteratorV2")
722     .Output("handle: resource")
723     .Output("deleter: variant")
724     .Attr("output_types: list(type) >= 1")
725     .Attr("output_shapes: list(shape) >= 1")
__anon42b7b0c01d02(shape_inference::InferenceContext* c) 726     .SetShapeFn([](shape_inference::InferenceContext* c) {
727       c->set_output(0, c->Scalar());
728       c->set_output(1, c->Scalar());
729       return Status::OK();
730     });
731 
// Releases the iterator resource identified by `handle`, using the `deleter`
// token produced alongside it. No outputs.
REGISTER_OP("DeleteIterator")
    .Input("handle: resource")
    .Input("deleter: variant")
    .SetShapeFn(shape_inference::NoOutputs);

// Multi-device counterpart: deletes the multi-device iterator plus its N
// per-device iterator resources. No outputs.
REGISTER_OP("DeleteMultiDeviceIterator")
    .Input("multi_device_iterator: resource")
    .Input("iterators: N * resource")
    .Input("deleter: variant")
    .Attr("N: int >= 0")
    .SetShapeFn(shape_inference::NoOutputs);

// Initializes an iterator resource from a dataset handle (side effect only;
// no outputs).
REGISTER_OP("MakeIterator")
    .Input("dataset: variant")
    .Input("iterator: resource")
    .SetShapeFn(shape_inference::NoOutputs);
748 
// Iterator whose dataset is produced by the `dataset_factory` function
// attr rather than by a separate MakeIterator call. Stateful: it owns the
// iterator resource it creates.
REGISTER_OP("OneShotIterator")
    .Output("handle: resource")
    .Attr("dataset_factory: func")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

// Advances the iterator and emits the next element as a tuple of
// `output_types` components; shapes come from the `output_shapes` attr via
// DatasetIteratorShape.
REGISTER_OP("IteratorGetNext")
    .Input("iterator: resource")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::DatasetIteratorShape);

// Same signature as IteratorGetNext; the sync/async distinction is a kernel
// property, not visible in this registration.
REGISTER_OP("IteratorGetNextSync")
    .Input("iterator: resource")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::DatasetIteratorShape);
772 
// TODO(b/124308596): Instead of conservatively marking this op as stateful,
// implement a mechanism to determine whether `dataset` has a side-effect
// and use it to decide whether to use a stateless or stateful version of this
// op.
//
// Extracts the single element of `dataset` as a tuple of components.
REGISTER_OP("DatasetToSingleElement")
    .Input("dataset: variant")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetIsStateful()
    .SetShapeFn(shape_inference::DatasetIteratorShape);

// TODO(b/124308596): Instead of conservatively marking this op as stateful,
// implement a mechanism to determine whether `dataset` has a side-effect
// and use it to decide whether to use a stateless or stateful version of this
// op.
//
// Folds `f` over the elements of `input_dataset`, starting from
// `initial_state`; `other_arguments` are extra captured inputs for `f`.
// The result components have types `output_types`.
REGISTER_OP("ReduceDataset")
    .Input("input_dataset: variant")
    .Input("initial_state: Tstate")
    .Input("other_arguments: Targuments")
    .Output("components: output_types")
    .Attr("f: func")
    .Attr("Tstate: list(type) >= 1")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("use_inter_op_parallelism: bool = true")
    .SetIsStateful()
    .SetShapeFn(shape_inference::DatasetIteratorShape);
802 
// Converts an iterator resource handle to a scalar string token, and back.
// The string round-trip allows handles to cross boundaries (e.g. feeds)
// where resource tensors cannot.
REGISTER_OP("IteratorToStringHandle")
    .Input("resource_handle: resource")
    .Output("string_handle: string")
    .SetShapeFn(shape_inference::ScalarShape);

// Inverse of IteratorToStringHandle. `output_types`/`output_shapes` default
// to empty, i.e. unspecified/unchecked.
REGISTER_OP("IteratorFromStringHandle")
    .Input("string_handle: string")
    .Output("resource_handle: resource")
    .Attr("output_types: list(type) >= 0 = []")
    .Attr("output_shapes: list(shape) >= 0 = []")
    .SetShapeFn(shape_inference::ScalarShape);

// V2 variant; identical signature to IteratorFromStringHandle.
REGISTER_OP("IteratorFromStringHandleV2")
    .Input("string_handle: string")
    .Output("resource_handle: resource")
    .Attr("output_types: list(type) >= 0 = []")
    .Attr("output_shapes: list(shape) >= 0 = []")
    .SetShapeFn(shape_inference::ScalarShape);
821 
// Serializes the state of an iterator resource into a variant tensor.
// `external_state_policy` is an enum-as-int controlling how external state is
// handled (values defined by the kernel; 0 is the default policy).
REGISTER_OP("SerializeIterator")
    .Input("resource_handle: resource")
    .Attr("external_state_policy: int = 0")
    .Output("serialized: variant")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // Serialized state is a vector of variants of unknown length.
      c->set_output(0, c->Vector(c->UnknownDim()));
      return Status::OK();
    });

// Restores iterator state previously produced by SerializeIterator
// (side effect only; no outputs).
REGISTER_OP("DeserializeIterator")
    .Input("resource_handle: resource")
    .Input("serialized: variant")
    .SetShapeFn(shape_inference::NoOutputs);
835 
// Serializes a dataset into a scalar string (a graph representation, per the
// output name). The whitelist/allow attrs control treatment of stateful ops
// during serialization.
REGISTER_OP("DatasetToGraph")
    .Input("input_dataset: variant")
    .Attr("stateful_whitelist: list(string) >= 0 = []")
    .Attr("allow_stateful: bool = false")
    .Attr("strip_device_assignment: bool = false")
    .Output("graph: string")
    .SetShapeFn(shape_inference::ScalarShape);

// V2: replaces the whitelist/allow attrs with a single
// `external_state_policy` enum-as-int (same convention as SerializeIterator).
REGISTER_OP("DatasetToGraphV2")
    .Input("input_dataset: variant")
    .Attr("external_state_policy: int = 0")
    .Attr("strip_device_assignment: bool = false")
    .Output("graph: string")
    .SetShapeFn(shape_inference::ScalarShape);
850 
// Applies the named graph optimizations to `input_dataset` and emits the
// optimized dataset handle. `optimizations` is a string tensor of
// optimization names; `optimization_configs` carries per-optimization
// configuration strings.
REGISTER_OP("OptimizeDataset")
    .Input("input_dataset: variant")
    .Input("optimizations: string")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("optimization_configs: list(string) = []")
    .SetShapeFn(shape_inference::ScalarShape);

// V2 splits the single optimization list into explicitly enabled, disabled,
// and default sets.
REGISTER_OP("OptimizeDatasetV2")
    .Input("input_dataset: variant")
    .Input("optimizations_enabled: string")
    .Input("optimizations_disabled: string")
    .Input("optimizations_default: string")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("optimization_configs: list(string) = []")
    .SetShapeFn(shape_inference::ScalarShape);
870 
// Optional ops: a scalar DT_VARIANT that either holds a value tuple or is
// empty ("none").

// Wraps `components` into an optional that has a value.
REGISTER_OP("OptionalFromValue")
    .Input("components: Toutput_types")
    .Output("optional: variant")
    .Attr("Toutput_types: list(type) >= 1")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      std::vector<DataType> dtypes;
      TF_RETURN_IF_ERROR(c->GetAttr("Toutput_types", &dtypes));
      // The optional itself is a scalar variant...
      c->set_output(0, c->Scalar());
      std::vector<shape_inference::ShapeAndType> shapes_and_types;
      shapes_and_types.reserve(c->num_inputs());
      // ...but we record each component's shape/dtype as handle data so
      // OptionalGetValue downstream can recover them. (`Toutput_types`
      // drives the `components` input list, so dtypes and inputs are the
      // same length by construction.)
      for (int i = 0; i < c->num_inputs(); ++i) {
        shapes_and_types.emplace_back(c->input(i), dtypes[i], ST_OPTIONAL);
      }
      c->set_output_handle_shapes_and_types(0, shapes_and_types);
      return Status::OK();
    });

// Produces an optional with no value.
REGISTER_OP("OptionalNone")
    .Output("optional: variant")
    .SetShapeFn(shape_inference::ScalarShape);

// Scalar bool: whether the optional holds a value.
REGISTER_OP("OptionalHasValue")
    .Input("optional: variant")
    .Output("has_value: bool")
    .SetShapeFn(shape_inference::ScalarShape);

// Unwraps the optional's value tuple; component shapes come from the
// `output_shapes` attr via DatasetIteratorShape.
REGISTER_OP("OptionalGetValue")
    .Input("optional: variant")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::DatasetIteratorShape);

// Like IteratorGetNext, but end-of-sequence yields an empty optional
// instead of an error (per the op name; the output is a scalar optional).
REGISTER_OP("IteratorGetNextAsOptional")
    .Input("iterator: resource")
    .Output("optional: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);
910 
// Wraps `input_dataset` with performance-modeling/autotuning behavior
// (kernel-defined). `algorithm` selects the autotune algorithm as an
// enum-as-int; `cpu_budget`/`ram_budget` of 0 mean kernel defaults.
REGISTER_OP("ModelDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("algorithm: int = 0")
    .Attr("cpu_budget: int = 0")
    .Attr("ram_budget: int = 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);
920 
921 // TODO(b/124308749): Add a stateful version of MapDefun and use it when `f`
922 // is stateful.
923 REGISTER_OP("MapDefun")
924     .Input("arguments: Targuments")
925     .Input("captured_inputs: Tcaptured")
926     .Output("output: output_types")
927     .Attr("Targuments: list(type) >= 1")
928     .Attr("Tcaptured: list(type) >= 0 = []")
929     .Attr("output_types: list(type) >= 1")
930     .Attr("output_shapes: list(shape) >= 1")
931     .Attr("f: func")
932     .Attr("max_intra_op_parallelism: int = 1")
__anon42b7b0c02002(shape_inference::InferenceContext* c) 933     .SetShapeFn([](shape_inference::InferenceContext* c) {
934       std::vector<PartialTensorShape> output_shapes;
935       TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
936       DataTypeVector t_args;
937       TF_RETURN_IF_ERROR(c->GetAttr("Targuments", &t_args));
938       if (output_shapes.size() != c->num_outputs()) {
939         return errors::InvalidArgument(
940             "`output_shapes` must be the same length as `output_types` (",
941             output_shapes.size(), " vs. ", c->num_outputs(), ")");
942       }
943 
944       int64 dim_zero = -1;
945       for (size_t i = 0; i < t_args.size(); ++i) {
946         if (c->Rank(c->input(i)) == 0) {
947           return errors::InvalidArgument(
948               "Arguments must have rank at least 1. Input ", i,
949               " has rank of 0.");
950         }
951         auto dim_handle = c->Dim(c->input(i), 0);
952         if (c->ValueKnown(dim_handle)) {
953           if (dim_zero == -1) {
954             dim_zero = c->Value(dim_handle);
955           } else if (c->Value(dim_handle) != dim_zero) {
956             return errors::InvalidArgument(
957                 "Arguments must have the same dimension 0.");
958           }
959         }
960       }
961 
962       for (size_t i = 0; i < output_shapes.size(); ++i) {
963         PartialTensorShape s({});
964         s = s.Concatenate(dim_zero);
965         s = s.Concatenate(output_shapes[i]);
966         shape_inference::ShapeHandle output_shape_handle;
967 
968         TF_RETURN_IF_ERROR(
969             c->MakeShapeFromPartialTensorShape(s, &output_shape_handle));
970         c->set_output(static_cast<int>(i), output_shape_handle);
971       }
972       return Status::OK();
973     });
974 
// Wraps/unwraps a dataset variant (kernel-defined transformation on the
// variant representation); both directions map scalar handle -> scalar
// handle.
REGISTER_OP("WrapDatasetVariant")
    .Input("input_handle: variant")
    .Output("output_handle: variant")
    .SetShapeFn(shape_inference::ScalarShape);

// Inverse of WrapDatasetVariant.
REGISTER_OP("UnwrapDatasetVariant")
    .Input("input_handle: variant")
    .Output("output_handle: variant")
    .SetShapeFn(shape_inference::ScalarShape);
984 
// Multi-device iterator ops: an iterator whose elements are sharded across
// the devices named in the `devices` attr.

// Unnamed multi-device iterator; like AnonymousIteratorV2, also emits a
// variant `deleter` token consumed by "DeleteMultiDeviceIterator".
REGISTER_OP("AnonymousMultiDeviceIterator")
    .Output("handle: resource")
    .Output("deleter: variant")
    .Attr("devices: list(string) >= 1")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // Both the resource handle and the deleter token are scalars.
      c->set_output(0, c->Scalar());
      c->set_output(1, c->Scalar());
      return Status::OK();
    });

// Named multi-device iterator, addressable via `shared_name`/`container`.
REGISTER_OP("MultiDeviceIterator")
    .Output("handle: resource")
    .Attr("devices: list(string) >= 1")
    .Attr("shared_name: string")
    .Attr("container: string")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);

// Binds a dataset to the multi-device iterator and returns a scalar
// `incarnation_id` identifying this initialization.
REGISTER_OP("MultiDeviceIteratorInit")
    .Input("dataset: variant")
    .Input("multi_device_iterator: resource")
    .Input("max_buffer_size: int64")
    .Output("incarnation_id: int64")
    .SetShapeFn(shape_inference::ScalarShape);

// Fetches the next element for shard `shard_num`; `incarnation_id` must
// match the value returned by the corresponding MultiDeviceIteratorInit.
REGISTER_OP("MultiDeviceIteratorGetNextFromShard")
    .Input("multi_device_iterator: resource")
    .Input("shard_num: int32")
    .Input("incarnation_id: int64")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::DatasetIteratorShape);

// String round-trip for multi-device iterator handles, mirroring
// IteratorToStringHandle / IteratorFromStringHandle.
REGISTER_OP("MultiDeviceIteratorToStringHandle")
    .Input("multi_device_iterator: resource")
    .Output("string_handle: string")
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("MultiDeviceIteratorFromStringHandle")
    .Input("string_handle: string")
    .Output("multi_device_iterator: resource")
    .Attr("output_types: list(type) >= 0 = []")
    .Attr("output_shapes: list(shape) >= 0 = []")
    .SetShapeFn(shape_inference::ScalarShape);
1033 
1034 }  // namespace tensorflow
1035