# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental API for testing of tf.data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_experimental_dataset_ops


def assert_next(transformations):
  """Builds a transformation asserting which transformations come next.

  Name each expected transformation by its base name, leaving off any version
  suffix: write "Batch" rather than "BatchV2". The name "Batch" matches any of
  "Batch", "BatchV1", "BatchV2", etc.

  Args:
    transformations: A `tf.string` vector `tf.Tensor` naming the
      transformations that are expected to happen next.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    """Wraps `dataset` in an `_AssertNextDataset`."""
    return _AssertNextDataset(dataset, transformations)

  return _apply_fn


def non_serializable():
  """Builds an identity transformation that is non-serializable.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    """Wraps `dataset` in a `_NonSerializableDataset`."""
    return _NonSerializableDataset(dataset)

  return _apply_fn


def sleep(sleep_microseconds):
  """Builds a transformation that sleeps before producing each input element.

  Args:
    sleep_microseconds: The number of microseconds to sleep before producing an
      input element.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    """Wraps `dataset` in a `_SleepDataset`."""
    return _SleepDataset(dataset, sleep_microseconds)

  return _apply_fn


class _AssertNextDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` that asserts which transformations happen next."""

  def __init__(self, input_dataset, transformations):
    """See `assert_next()` for details.

    Args:
      input_dataset: The dataset this one wraps.
      transformations: Names of the expected transformations; converted to a
        `tf.string` tensor.

    Raises:
      ValueError: If `transformations` is None.
    """
    # NOTE(review): assigned before the base-class initializer runs,
    # presumably because `self._flat_structure` below depends on it -- confirm
    # against `dataset_ops`.
    self._input_dataset = input_dataset
    if transformations is None:
      raise ValueError("At least one transformation should be specified")
    self._transformations = ops.convert_to_tensor(
        transformations, dtype=dtypes.string, name="transformations")

    make_op = gen_experimental_dataset_ops.experimental_assert_next_dataset
    input_variant = self._input_dataset._variant_tensor  # pylint: disable=protected-access
    variant_tensor = make_op(
        input_variant, self._transformations, **self._flat_structure)
    super(_AssertNextDataset, self).__init__(input_dataset, variant_tensor)


class _NonSerializableDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` that performs non-serializable identity transformation."""

  def __init__(self, input_dataset):
    """See `non_serializable()` for details.

    Args:
      input_dataset: The dataset this one wraps.
    """
    # NOTE(review): assigned before the base-class initializer runs,
    # presumably because `self._flat_structure` below depends on it -- confirm
    # against `dataset_ops`.
    self._input_dataset = input_dataset

    make_op = (
        gen_experimental_dataset_ops.experimental_non_serializable_dataset)
    input_variant = self._input_dataset._variant_tensor  # pylint: disable=protected-access
    variant_tensor = make_op(input_variant, **self._flat_structure)
    super(_NonSerializableDataset, self).__init__(input_dataset, variant_tensor)


class _SleepDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` that sleeps before producing each upstream element."""

  def __init__(self, input_dataset, sleep_microseconds):
    """See `sleep()` for details.

    Args:
      input_dataset: The dataset this one wraps.
      sleep_microseconds: The number of microseconds to sleep before producing
        an input element; passed to the op as given.
    """
    # NOTE(review): assigned before the base-class initializer runs,
    # presumably because `self._flat_structure` below depends on it -- confirm
    # against `dataset_ops`.
    self._input_dataset = input_dataset
    self._sleep_microseconds = sleep_microseconds

    make_op = gen_experimental_dataset_ops.sleep_dataset
    input_variant = self._input_dataset._variant_tensor  # pylint: disable=protected-access
    variant_tensor = make_op(
        input_variant, self._sleep_microseconds, **self._flat_structure)
    super(_SleepDataset, self).__init__(input_dataset, variant_tensor)