# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrappers for Datasets and Iterators."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@deprecation.deprecated(None, "Use `tf.data.Dataset.get_single_element()`.")
@tf_export("data.experimental.get_single_element")
def get_single_element(dataset):
  """Returns the single element of the `dataset` as a nested structure of tensors.

  The function enables you to use a `tf.data.Dataset` in a stateless
  "tensor-in tensor-out" expression, without creating an iterator.
  This facilitates the ease of data transformation on tensors using the
  optimized `tf.data.Dataset` abstraction on top of them.

  For example, let's consider a `preprocessing_fn` which would take as an
  input the raw features and returns the processed feature along with
  its label.

  ```python
  def preprocessing_fn(raw_feature):
    # ... the raw_feature is preprocessed as per the use-case
    return feature

  raw_features = ...  # input batch of BATCH_SIZE elements.
  dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
             .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
             .batch(BATCH_SIZE))

  processed_features = tf.data.experimental.get_single_element(dataset)
  ```

  In the above example, the `raw_features` tensor of length=BATCH_SIZE
  was converted to a `tf.data.Dataset`. Next, each of the `raw_feature` was
  mapped using the `preprocessing_fn` and the processed features were
  grouped into a single batch. The final `dataset` contains only one element
  which is a batch of all the processed features.

  NOTE: The `dataset` should contain only one element.

  Now, instead of creating an iterator for the `dataset` and retrieving the
  batch of features, the `tf.data.experimental.get_single_element()` function
  is used to skip the iterator creation process and directly output the batch
  of features.

  This can be particularly useful when your tensor transformations are
  expressed as `tf.data.Dataset` operations, and you want to use those
  transformations while serving your model.

  # Keras

  ```python

  model = ... # A pre-built or custom model

  class PreprocessingModel(tf.keras.Model):
    def __init__(self, model):
      super().__init__(self)
      self.model = model

    @tf.function(input_signature=[...])
    def serving_fn(self, data):
      ds = tf.data.Dataset.from_tensor_slices(data)
      ds = ds.map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
      ds = ds.batch(batch_size=BATCH_SIZE)
      return tf.argmax(
        self.model(tf.data.experimental.get_single_element(ds)),
        axis=-1
      )

  preprocessing_model = PreprocessingModel(model)
  your_exported_model_dir = ... # save the model to this path.
  tf.saved_model.save(preprocessing_model, your_exported_model_dir,
                      signatures={'serving_default': preprocessing_model.serving_fn})
  ```

  # Estimator

  In the case of estimators, you need to generally define a `serving_input_fn`
  which would require the features to be processed by the model while
  inferencing.

  ```python
  def serving_input_fn():

    raw_feature_spec = ... # Spec for the raw_features
    input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
        raw_feature_spec, default_batch_size=None)
    serving_input_receiver = input_fn()
    raw_features = serving_input_receiver.features

    def preprocessing_fn(raw_feature):
      # ... the raw_feature is preprocessed as per the use-case
      return feature

    dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
               .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
               .batch(BATCH_SIZE))

    processed_features = tf.data.experimental.get_single_element(dataset)

    # Please note that the value of `BATCH_SIZE` should be equal to
    # the size of the leading dimension of `raw_features`. This ensures
    # that `dataset` has only one element, which is a pre-requisite for
    # using `tf.data.experimental.get_single_element(dataset)`.

    return tf.estimator.export.ServingInputReceiver(
        processed_features, serving_input_receiver.receiver_tensors)

  estimator = ... # A pre-built or custom estimator
  estimator.export_saved_model(your_exported_model_dir, serving_input_fn)
  ```

  Args:
    dataset: A `tf.data.Dataset` object containing a single element.

  Returns:
    A nested structure of `tf.Tensor` objects, corresponding to the single
    element of `dataset`.

  Raises:
    TypeError: if `dataset` is not a `tf.data.Dataset` object.
    InvalidArgumentError: (at runtime) if `dataset` does not contain exactly
      one element.
  """
  if not isinstance(dataset, dataset_ops.DatasetV2):
    # Report the offending type so callers can see what was actually passed
    # (commonly a tensor or numpy array rather than a Dataset).
    raise TypeError(
        "`dataset` must be a `tf.data.Dataset` object, but got {}.".format(
            type(dataset)))

  # Delegate to the canonical (non-deprecated) Dataset method.
  return dataset.get_single_element()