# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""take-while dataset transformation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import structure as structure_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.util.tf_export import tf_export


class _TakeWhileDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """Dataset that ends iteration once `predicate` returns false.

  The per-element stopping logic is delegated to the
  `experimental_take_while_dataset` op built in `__init__`.
  """

  def __init__(self, input_dataset, predicate):
    """Builds the take-while dataset. See `take_while()` for details.

    Args:
      input_dataset: The dataset whose elements are fed to `predicate`.
      predicate: A function mapping one element of `input_dataset` to a
        scalar `tf.bool` tensor.

    Raises:
      ValueError: If the output of `predicate` is not compatible with a
        scalar boolean tensor.
    """
    self._input_dataset = input_dataset
    wrapper = dataset_ops.StructuredFunctionWrapper(
        predicate,
        "tf.data.experimental.take_while()",
        dataset=self._input_dataset)

    # Reject predicates whose output could not be a scalar bool.
    returns_scalar_bool = wrapper.output_structure.is_compatible_with(
        structure_lib.TensorStructure(dtypes.bool, []))
    if not returns_scalar_bool:
      raise ValueError("`predicate` must return a scalar boolean tensor.")
    self._predicate = wrapper

    # NOTE(review): `flat_structure(self)` runs before `super().__init__`;
    # presumably it only needs `self._input_dataset`, set above -- confirm.
    predicate_fn = self._predicate.function
    take_while_op = gen_experimental_dataset_ops.experimental_take_while_dataset
    variant_tensor = take_while_op(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        other_arguments=predicate_fn.captured_inputs,
        predicate=predicate_fn,
        **dataset_ops.flat_structure(self))
    super(_TakeWhileDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    """Returns the function wrappers used by this dataset (the predicate)."""
    return [self._predicate]


@tf_export("data.experimental.take_while")
def take_while(predicate):
  """A transformation that stops dataset iteration based on a `predicate`.

  Iteration ends the first time `predicate` returns false for an element.

  Args:
    predicate: A function that maps a nested structure of tensors (matching
      the element structure of the dataset this transformation is applied
      to) to a scalar `tf.bool` tensor.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`. Applying it raises `ValueError` if `predicate`
    does not return a scalar boolean tensor.
  """

  def _make_take_while_dataset(dataset):
    return _TakeWhileDataset(dataset, predicate)

  return _make_take_while_dataset