# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrappers for Datasets and Iterators."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@deprecation.deprecated(None, "Use `tf.data.Dataset.get_single_element()`.")
@tf_export("data.experimental.get_single_element")
def get_single_element(dataset):
  """Returns the single element of the `dataset` as a nested structure of tensors.

  The function enables you to use a `tf.data.Dataset` in a stateless
  "tensor-in tensor-out" expression, without creating an iterator.
  This makes it easy to apply the optimized `tf.data.Dataset` transformations
  directly to plain tensors.

  For example, let's consider a `preprocessing_fn` which takes the raw
  features as input and returns the processed features along with their
  labels.

  ```python
  def preprocessing_fn(raw_feature):
    # ... the raw_feature is preprocessed as per the use case
    feature = ...
    return feature

  raw_features = ...  # input batch of BATCH_SIZE elements.
  dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
             .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
             .batch(BATCH_SIZE))

  processed_features = tf.data.experimental.get_single_element(dataset)
  ```

  In the above example, the `raw_features` tensor of length `BATCH_SIZE`
  was converted to a `tf.data.Dataset`. Next, each `raw_feature` was
  mapped using the `preprocessing_fn` and the processed features were
  grouped into a single batch. The final `dataset` contains only one element,
  which is a batch of all the processed features.

  NOTE: The `dataset` should contain only one element.

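  As a quick sanity check (a sketch that assumes eager execution), you can
  verify that the batched `dataset` indeed yields exactly one element before
  extracting it:

  ```python
  # For the pipeline above, `batch(BATCH_SIZE)` over BATCH_SIZE inputs yields
  # a single batch, so the cardinality should be 1. Cardinality may be
  # reported as unknown for some input pipelines.
  assert dataset.cardinality().numpy() == 1
  ```
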
  Now, instead of creating an iterator for the `dataset` and retrieving the
  batch of features, the `tf.data.experimental.get_single_element()` function
  is used to skip the iterator creation process and directly output the batch
  of features.

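  For instance, in eager mode the call below is equivalent to pulling the one
  element out of a fresh Python iterator (a minimal sketch):

  ```python
  single_element_ds = tf.data.Dataset.from_tensors([1, 2, 3])
  # Equivalent to `next(iter(single_element_ds))` in eager mode, but
  # `get_single_element` also works inside a `tf.function` or other graph
  # context where creating a Python iterator is inconvenient.
  element = tf.data.experimental.get_single_element(single_element_ds)
  ```
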
  This can be particularly useful when your tensor transformations are
  expressed as `tf.data.Dataset` operations, and you want to use those
  transformations while serving your model.

  # Keras

  ```python

  model = ...  # A pre-built or custom model

  class PreprocessingModel(tf.keras.Model):
    def __init__(self, model):
      super().__init__()
      self.model = model

    @tf.function(input_signature=[...])
    def serving_fn(self, data):
      ds = tf.data.Dataset.from_tensor_slices(data)
      ds = ds.map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
      ds = ds.batch(batch_size=BATCH_SIZE)
      return tf.argmax(
          self.model(tf.data.experimental.get_single_element(ds)), axis=-1)

  preprocessing_model = PreprocessingModel(model)
  your_exported_model_dir = ...  # save the model to this path.
  tf.saved_model.save(preprocessing_model, your_exported_model_dir,
                signatures={'serving_default': preprocessing_model.serving_fn})
  ```
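
  The export can then be loaded back and served through its signature. The
  sketch below is illustrative only: `raw_batch` is a hypothetical input that
  must match the `input_signature` declared above, and the keyword argument
  name follows the `data` parameter of `serving_fn`.

  ```python
  # Sketch: reload the SavedModel and call the serving signature, which
  # returns a dict of output tensors.
  imported = tf.saved_model.load(your_exported_model_dir)
  infer = imported.signatures['serving_default']
  predictions = infer(data=raw_batch)  # `raw_batch` is a placeholder input.
  ```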

  # Estimator

  For estimators, you generally need to define a `serving_input_fn` which
  processes the raw features into the features the model expects at inference
  time.

  ```python
  def serving_input_fn():

    raw_feature_spec = ...  # Spec for the raw_features
    input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
        raw_feature_spec, default_batch_size=None)
    serving_input_receiver = input_fn()
    raw_features = serving_input_receiver.features

    def preprocessing_fn(raw_feature):
      # ... the raw_feature is preprocessed as per the use case
      feature = ...
      return feature

    dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
               .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
               .batch(BATCH_SIZE))

    processed_features = tf.data.experimental.get_single_element(dataset)

    # Please note that the value of `BATCH_SIZE` should be equal to
    # the size of the leading dimension of `raw_features`. This ensures
    # that `dataset` has only one element, which is a prerequisite for
    # using `tf.data.experimental.get_single_element(dataset)`.

    return tf.estimator.export.ServingInputReceiver(
        processed_features, serving_input_receiver.receiver_tensors)

  estimator = ...  # A pre-built or custom estimator
  your_exported_model_dir = ...  # save the model to this path.
  estimator.export_saved_model(your_exported_model_dir, serving_input_fn)
  ```
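
  If needed, the export can be reloaded for a quick check. This is only a
  sketch: `export_saved_model` returns the export directory as a bytes path,
  which is decoded before loading.

  ```python
  export_path = estimator.export_saved_model(your_exported_model_dir,
                                             serving_input_fn)
  imported = tf.saved_model.load(export_path.decode())  # path is bytes.
  print(list(imported.signatures.keys()))  # typically ['serving_default']
  ```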

  Args:
    dataset: A `tf.data.Dataset` object containing a single element.

  Returns:
    A nested structure of `tf.Tensor` objects, corresponding to the single
    element of `dataset`.

  Raises:
    TypeError: if `dataset` is not a `tf.data.Dataset` object.
    InvalidArgumentError: (at runtime) if `dataset` does not contain exactly
      one element.
  """
  if not isinstance(dataset, dataset_ops.DatasetV2):
    raise TypeError("`dataset` must be a `tf.data.Dataset` object.")

  return dataset.get_single_element()