• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains the base ProcessingLayer and a subclass that uses Combiners."""
16
17import abc
18import collections
19
20import numpy as np
21
22from tensorflow.python.eager import context
23from tensorflow.python.eager import def_function
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import sparse_tensor
27from tensorflow.python.keras import backend
28from tensorflow.python.keras.engine import data_adapter
29from tensorflow.python.keras.engine.base_layer import Layer
30from tensorflow.python.keras.utils import tf_utils
31from tensorflow.python.keras.utils import version_utils
32from tensorflow.python.ops import math_ops
33from tensorflow.python.ops import sparse_ops
34from tensorflow.python.ops import variables
35from tensorflow.python.ops.ragged import ragged_tensor
36from tensorflow.python.training.tracking import base as trackable
37from tensorflow.python.util.tf_export import keras_export
38
39
@keras_export('keras.layers.experimental.preprocessing.PreprocessingLayer')
class PreprocessingLayer(Layer, metaclass=abc.ABCMeta):
  """Base class for Preprocessing Layers.

  **Don't use this class directly: it's an abstract base class!** You may
  be looking for one of the many built-in
  [preprocessing layers](https://keras.io/guides/preprocessing_layers/)
  instead.

  Preprocessing layers are layers whose state gets computed before model
  training starts. They do not get updated during training.
  Most preprocessing layers implement an `adapt()` method for state computation.

  The `PreprocessingLayer` class is the base class you would subclass to
  implement your own preprocessing layers.

  Attributes:
    streaming: Whether a layer can be adapted multiple times without resetting
      the state of the layer.
  """
  _must_restore_from_config = True

  def __init__(self, streaming=True, **kwargs):
    """Initializes the layer in an un-compiled, un-adapted state.

    Arguments:
      streaming: Whether `adapt` may be called again without resetting the
        state accumulated by a previous call.
      **kwargs: Forwarded unchanged to `Layer.__init__`.
    """
    super(PreprocessingLayer, self).__init__(**kwargs)
    self._streaming = streaming
    self._is_compiled = False
    self._is_adapted = False

    # Sets `is_adapted=False` when `reset_state` is called. The bound method
    # (possibly a subclass override) is stashed and shadowed by an instance
    # attribute so that every call goes through `_reset_state_wrapper`.
    self._reset_state_impl = self.reset_state
    self.reset_state = self._reset_state_wrapper

    # Cached result of `make_adapt_function`; cleared again by `compile`.
    self._adapt_function = None

  @property
  def streaming(self):
    """Whether `adapt` can be called twice without resetting the state."""
    return self._streaming

  @property
  def is_adapted(self):
    """Whether the layer has been fit to data already."""
    return self._is_adapted

  def update_state(self, data):
    """Accumulates statistics for the preprocessing layer.

    Arguments:
      data: A mini-batch of inputs to the layer.

    Raises:
      NotImplementedError: Always, unless overridden by a subclass.
    """
    raise NotImplementedError

  def reset_state(self):  # pylint: disable=method-hidden
    """Resets the statistics of the preprocessing layer.

    Raises:
      NotImplementedError: Always, unless overridden by a subclass.
    """
    raise NotImplementedError

  def merge_state(self, layers):
    """Merge the statistics of multiple preprocessing layers.

    This layer will contain the merged state.

    Arguments:
      layers: Layers whose statistics should be merged with the statistics of
        this layer.

    Raises:
      NotImplementedError: Always, unless overridden by a subclass.
    """
    raise NotImplementedError

  def finalize_state(self):
    """Finalize the statistics for the preprocessing layer.

    This method is called at the end of `adapt` or after restoring a serialized
    preprocessing layer's state. This method handles any one-time operations
    that should occur on the layer's state before `Layer.__call__`.
    """
    pass

  def make_adapt_function(self):
    """Creates a function to execute one step of `adapt`.

    This method can be overridden to support custom adapt logic.
    This method is called by `PreprocessingLayer.adapt`.

    Typically, this method directly controls `tf.function` settings,
    and delegates the actual state update logic to
    `PreprocessingLayer.update_state`.

    This function is cached the first time `PreprocessingLayer.adapt`
    is called. The cache is cleared whenever `PreprocessingLayer.compile`
    is called.

    Returns:
      Function. The function created by this method should accept a
      `tf.data.Iterator`, retrieve a batch, and update the state of the
      layer.
    """
    if self._adapt_function is not None:
      return self._adapt_function

    def adapt_step(iterator):
      data = next(iterator)
      # Build lazily from the first batch seen, then accumulate.
      self._adapt_maybe_build(data)
      self.update_state(data)

    if self._steps_per_execution.numpy().item() == 1:
      adapt_fn = adapt_step
    else:

      def adapt_fn(iterator):
        # Run several batches per invocation of the (possibly traced) function.
        for _ in math_ops.range(self._steps_per_execution):
          adapt_step(iterator)

    if not self._run_eagerly:
      adapt_fn = def_function.function(adapt_fn)

    self._adapt_function = adapt_fn
    return self._adapt_function

  def compile(self, run_eagerly=None, steps_per_execution=None):
    """Configures the layer for `adapt`.

    Calling this method discards any adapt function cached by
    `make_adapt_function`, so the next `adapt` call rebuilds it with the new
    settings.

    Arguments:
      run_eagerly: Bool. If `True`, this layer's adapt logic will not be
        wrapped in a `tf.function`. If `None`, the value of `self.dynamic` is
        used. Recommended to leave this as `None` unless the layer cannot be
        run inside a `tf.function`.
      steps_per_execution: Int. Defaults to 1. The number of batches to run
        during each `tf.function` call. Running multiple batches inside a
        single `tf.function` call can greatly improve performance on TPUs or
        small models with a large Python overhead.
    """
    if steps_per_execution is None:
      steps_per_execution = 1
    self._configure_steps_per_execution(steps_per_execution)

    if run_eagerly is None:
      run_eagerly = self.dynamic
    self._run_eagerly = run_eagerly

    # A previously cached adapt function was built against the old
    # `run_eagerly` / `steps_per_execution` settings (and the old
    # `_steps_per_execution` variable), so it must not be reused.
    self._adapt_function = None

    self._is_compiled = True

  def adapt(self, data, batch_size=None, steps=None, reset_state=True):
    """Fits the state of the preprocessing layer to the data being passed.

    After calling `adapt` on a layer, a preprocessing layer's state will not
    update during training. In order to make preprocessing layers efficient in
    any distribution context, they are kept constant with respect to any
    compiled `tf.Graph`s that call the layer. This does not affect the layer use
    when adapting each layer only once, but if you adapt a layer multiple times
    you will need to take care to re-compile any compiled functions as follows:

     * If you are adding a preprocessing layer to a `keras.Model`, you need to
       call `model.compile` after each subsequent call to `adapt`.
     * If you are calling a preprocessing layer inside `tf.data.Dataset.map`,
       you should call `map` again on the input `tf.data.Dataset` after each
       `adapt`.
     * If you are using a `tf.function` directly which calls a preprocessing
       layer, you need to call `tf.function` again on your callable after
       each subsequent call to `adapt`.

    `tf.keras.Model` example with multiple adapts:

    >>> layer = tf.keras.layers.experimental.preprocessing.Normalization(
    ...     axis=None)
    >>> layer.adapt([0, 2])
    >>> model = tf.keras.Sequential(layer)
    >>> model.predict([0, 1, 2])
    array([-1.,  0.,  1.], dtype=float32)
    >>> layer.adapt([-1, 1])
    >>> model.compile() # This is needed to re-compile model.predict!
    >>> model.predict([0, 1, 2])
    array([0., 1., 2.], dtype=float32)

    `tf.data.Dataset` example with multiple adapts:

    >>> layer = tf.keras.layers.experimental.preprocessing.Normalization(
    ...     axis=None)
    >>> layer.adapt([0, 2])
    >>> input_ds = tf.data.Dataset.range(3)
    >>> normalized_ds = input_ds.map(layer)
    >>> list(normalized_ds.as_numpy_iterator())
    [array([-1.], dtype=float32),
     array([0.], dtype=float32),
     array([1.], dtype=float32)]
    >>> layer.adapt([-1, 1])
    >>> normalized_ds = input_ds.map(layer) # Re-map over the input dataset.
    >>> list(normalized_ds.as_numpy_iterator())
    [array([0.], dtype=float32),
     array([1.], dtype=float32),
     array([2.], dtype=float32)]

    Arguments:
        data: The data to train on. It can be passed either as a tf.data
          Dataset, or as a numpy array.
        batch_size: Integer or `None`.
            Number of samples per state update.
            If unspecified, `batch_size` will default to 32.
            Do not specify the `batch_size` if your data is in the
            form of datasets, generators, or `keras.utils.Sequence` instances
            (since they generate batches).
        steps: Integer or `None`.
            Total number of steps (batches of samples)
            When training with input tensors such as
            TensorFlow data tensors, the default `None` is equal to
            the number of samples in your dataset divided by
            the batch size, or 1 if that cannot be determined. If `data` is a
            `tf.data` dataset, and 'steps' is None, the epoch will run until
            the input dataset is exhausted. When passing an infinitely
            repeating dataset, you must specify the `steps` argument. This
            argument is not supported with array inputs.
        reset_state: Optional argument specifying whether to clear the state of
          the layer at the start of the call to `adapt`, or whether to start
          from the existing state. This argument may not be relevant to all
          preprocessing layers: a subclass of PreprocessingLayer may choose to
          throw if 'reset_state' is set to False.

    Raises:
      RuntimeError: If called inside a `tf.function`, or when TensorFlow v2
        behavior is not enabled.
      ValueError: If the layer is not streaming, has already been adapted and
        `reset_state` is False.
    """
    _disallow_inside_tf_function('adapt')
    if not version_utils.should_use_v2():
      raise RuntimeError('`adapt` is only supported in tensorflow v2.')  # pylint: disable=g-doc-exception
    if not self.streaming and self._is_adapted and not reset_state:
      raise ValueError('{} does not support calling `adapt` twice without '
                       'resetting the state.'.format(self.__class__.__name__))
    if not self._is_compiled:
      self.compile()  # Compile with defaults.
    # An unbuilt layer has no state to clear yet.
    if self.built and reset_state:
      self.reset_state()
    data_handler = data_adapter.DataHandler(
        data,
        batch_size=batch_size,
        steps_per_epoch=steps,
        epochs=1,
        steps_per_execution=self._steps_per_execution,
        distribute=False)
    self._adapt_function = self.make_adapt_function()
    for _, iterator in data_handler.enumerate_epochs():
      with data_handler.catch_stop_iteration():
        for _ in data_handler.steps():
          self._adapt_function(iterator)
          if data_handler.should_sync:
            context.async_wait()
    self.finalize_state()
    self._is_adapted = True

  def _reset_state_wrapper(self):
    """Calls `reset_state` and sets `adapted` to `False`."""
    self._reset_state_impl()
    self._is_adapted = False

  @trackable.no_automatic_dependency_tracking
  def _configure_steps_per_execution(self, steps_per_execution):
    """Stores `steps_per_execution` in a fresh int64 variable on the layer."""
    self._steps_per_execution = variables.Variable(
        steps_per_execution,
        dtype='int64',
        aggregation=variables.VariableAggregationV2.ONLY_FIRST_REPLICA)

  # TODO(omalleyt): Unify this logic with `Layer._maybe_build`.
  def _adapt_maybe_build(self, data):
    """Builds the layer from the shape of `data` if it is not built yet.

    Arguments:
      data: A batch of adapt data. Its `.shape` attribute, when present, is
        passed to `build`; otherwise `build` receives `None`.
    """
    if not self.built:
      try:
        # If this is a Numpy array or tensor, we can get shape from .shape.
        # If not, an attribute error will be thrown.
        data_shape = data.shape
        data_shape_nones = tuple([None] * len(data.shape))
      except AttributeError:
        # The input has an unknown number of dimensions.
        data_shape = None
        data_shape_nones = None

      # TODO (b/159261555): move this to base layer build.
      batch_input_shape = getattr(self, '_batch_input_shape', None)
      if batch_input_shape is None:
        # Set the number of dimensions.
        self._batch_input_shape = data_shape_nones
      self.build(data_shape)
      self.built = True
313
314
315# TODO(omalleyt): This class will be gradually replaced.
class CombinerPreprocessingLayer(PreprocessingLayer):
  """Base class for PreprocessingLayers that do computation using a Combiner.

  The heavy lifting of state computation is delegated to a `Combiner` object;
  this class supplies the glue that moves data between the combiner's
  accumulators and the layer's state variables. Subclassing this class to
  create a PreprocessingLayer allows your layer to be compatible with
  distributed computation.

  This class is compatible with Tensorflow 2.0+.
  """

  def __init__(self, combiner, **kwargs):
    """Creates the layer around `combiner`.

    Args:
      combiner: Object used to compute, merge, extract and restore
        accumulators for this layer.
      **kwargs: Forwarded to `PreprocessingLayer.__init__`.
    """
    super().__init__(**kwargs)
    # Name -> weight, in insertion order; filled by `_add_state_variable`.
    self.state_variables = collections.OrderedDict()
    self._combiner = combiner
    # In-flight accumulator used while `adapt` is running.
    self._adapt_accumulator = None

  def reset_state(self):  # pylint: disable=method-hidden
    """Drops the in-flight adapt accumulator."""
    self._adapt_accumulator = None

  @trackable.no_automatic_dependency_tracking
  def update_state(self, data):
    """Folds one batch of `data` into the in-flight accumulator."""
    if self._adapt_accumulator is None:
      # Seed from the layer's current state (None if never adapted).
      self._adapt_accumulator = self._get_accumulator()
    current = self._adapt_accumulator
    self._adapt_accumulator = self._combiner.compute(data, current)

  def merge_state(self, layers):
    """Merges the state of `layers` into this layer via the combiner."""
    to_merge = [self._get_accumulator()]
    to_merge.extend(
        layer._get_accumulator() for layer in layers)  # pylint: disable=protected-access
    self._set_accumulator(self._combiner.merge(to_merge))

  def finalize_state(self):
    """Writes the in-flight accumulator, if any, into the state variables."""
    pending = self._adapt_accumulator
    if pending is None:
      return
    self._set_accumulator(pending)

  def compile(self, run_eagerly=None, steps_per_execution=None):
    """Same as the parent `compile`, except `run_eagerly` defaults to True."""
    # TODO(omalleyt): Remove this once sublayers are switched to new APIs.
    eager = True if run_eagerly is None else run_eagerly
    super().compile(run_eagerly=eager, steps_per_execution=steps_per_execution)

  def adapt(self, data, batch_size=None, steps=None, reset_state=True):
    """Adapts the layer, optionally continuing from the stored state.

    When `reset_state` is False the accumulator is first rebuilt from the
    current state variables; the call is then handed to the parent `adapt`.
    """
    if not reset_state:
      stored = self._restore_updates()
      self._adapt_accumulator = self._combiner.restore(stored)
    super().adapt(
        data, batch_size=batch_size, steps=steps, reset_state=reset_state)

  def _add_state_variable(self,
                          name,
                          shape,
                          dtype,
                          initializer=None,
                          partitioner=None,
                          use_resource=None,
                          **kwargs):
    """Add a variable that can hold state which is updated during adapt().

    The variable is created as a non-trainable weight without regularizer or
    constraint, and is registered in `self.state_variables` under `name`.

    Args:
      name: Variable name.
      shape: Variable shape.
      dtype: The type of the variable.
      initializer: initializer instance (callable).
      partitioner: Partitioner forwarded to `add_weight`.
      use_resource: Forwarded to `add_weight`.
      **kwargs: Additional keyword arguments forwarded to `add_weight`.

    Returns:
      The created variable.
    """
    state_var = self.add_weight(
        name=name,
        shape=shape,
        dtype=dtype,
        initializer=initializer,
        regularizer=None,
        trainable=False,
        constraint=None,
        partitioner=partitioner,
        use_resource=use_resource,
        **kwargs)
    # TODO(momernick): Do not allow collisions here.
    self.state_variables[name] = state_var
    return state_var

  def _restore_updates(self):
    """Recreates a dict of updates from the layer's weights."""
    return {
        var_name: variable.numpy()
        for var_name, variable in self.state_variables.items()
    }

  def _get_accumulator(self):
    """Returns an accumulator for the stored state, or None if not adapted."""
    if not self._is_adapted:
      return None
    return self._combiner.restore(self._restore_updates())

  def _set_accumulator(self, accumulator):
    """Extracts `accumulator` into the state variables and clears it."""
    self._set_state_variables(self._combiner.extract(accumulator))
    self._adapt_accumulator = None  # Reset accumulator from adapt.

  def _set_state_variables(self, updates):
    """Directly update the internal state of this Layer.

    `updates` is a dict keyed by state variable name; every entry is assigned
    to the matching variable in `self.state_variables`. Which names exist and
    what their values mean is defined by the concrete subclass.

    Args:
      updates: A string keyed dict of weights to update.

    Raises:
      RuntimeError: if 'build()' was not called before this method.
    """
    # TODO(momernick): Do we need to do any more input sanitization?
    if not self.built:
      raise RuntimeError('_set_state_variables() must be called after build().')

    with ops.init_scope():
      for state_name, new_value in updates.items():
        self.state_variables[state_name].assign(new_value)
443
444
def convert_to_list(values, sparse_default_value=None):
  """Convert a TensorLike, CompositeTensor, or ndarray into a Python list.

  Args:
    values: The object to convert. Ragged values, sparse tensors, dense
      tensors and ndarrays are converted; anything else is returned as-is.
    sparse_default_value: Fill value used when densifying a sparse tensor.
      When None, '' is used for string tensors and -1 otherwise.

  Returns:
    A Python list for the recognised input types, otherwise `values` itself.
  """
  if tf_utils.is_ragged(values):
    # There is a corner case when dealing with ragged tensors: if you get an
    # actual RaggedTensor (not a RaggedTensorValue) passed in non-eager mode,
    # you can't call to_list() on it without evaluating it first. However,
    # because we don't yet fully support composite tensors across Keras,
    # backend.get_value() won't evaluate the tensor.
    # TODO(momernick): Get Keras to recognize composite tensors as Tensors
    # and then replace this with a call to backend.get_value.
    needs_session_eval = (
        isinstance(values, ragged_tensor.RaggedTensor) and
        not context.executing_eagerly())
    if needs_session_eval:
      values = backend.get_session(values).run(values)
    values = values.to_list()

  sparse_types = (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)
  if isinstance(values, sparse_types):
    if sparse_default_value is None:
      holds_strings = dtypes.as_dtype(values.values.dtype) == dtypes.string
      sparse_default_value = '' if holds_strings else -1
    densified = sparse_ops.sparse_tensor_to_dense(
        values, default_value=sparse_default_value)
    values = backend.get_value(densified)

  if isinstance(values, ops.Tensor):
    values = backend.get_value(values)

  # Whether the ndarray was passed in directly or produced by one of the
  # conversions above, flatten it into a standard python list.
  if isinstance(values, np.ndarray):
    return values.tolist()

  return values
480
481
482# TODO(omalleyt): This class will be gradually replaced.
class Combiner(metaclass=abc.ABCMeta):
  """Functional object that defines a shardable computation.

  This object defines functions required to create and manipulate data objects.
  These data objects, referred to below as 'accumulators', are computation-
  specific and may be implemented alongside concrete subclasses of Combiner
  (if necessary - some computations may be simple enough that standard Python
  types can be used as accumulators).

  The intent for this class is that by describing computations in this way, we
  can arbitrarily shard a dataset, perform computations on a subset, and then
  merge the computation into a final result. This enables distributed
  computation.

  The combiner itself does not own any state - all computational state is owned
  by the accumulator objects. This is so that we can have an arbitrary number of
  Combiners (thus sharding the computation N ways) without risking any change
  to the underlying computation. These accumulator objects are uniquely
  associated with each Combiner; a Combiner defines what the accumulator object
  should be and will only work with accumulators of that type.

  This is an abstract base class: the metaclass is declared with the Python 3
  `metaclass=` keyword so that the `abc.abstractmethod` markers below are
  enforced, and a subclass can only be instantiated once it overrides every
  abstract method.
  """

  def __repr__(self):
    """Returns the concrete class name wrapped in angle brackets."""
    return '<{}>'.format(self.__class__.__name__)

  @abc.abstractmethod
  def compute(self, batch_values, accumulator=None):
    """Compute a step in this computation, returning a new accumulator.

    This method computes a step of the computation described by this Combiner.
    If an accumulator is passed, the data in that accumulator is also used; so
    compute(batch_values) results in f(batch_values), while
    compute(batch_values, accumulator) results in
    merge(f(batch_values), accumulator).

    Args:
      batch_values: A list of ndarrays representing the values of the inputs for
        this step of the computation.
      accumulator: the current accumulator. Can be None.

    Returns:
      An accumulator that includes the passed batch of inputs.
    """
    pass

  @abc.abstractmethod
  def merge(self, accumulators):
    """Merge several accumulators to a single accumulator.

    This method takes the partial values in several accumulators and combines
    them into a single accumulator. This computation must not be order-specific
    (that is, merge([a, b]) must return the same result as merge([b, a])).

    Args:
      accumulators: the accumulators to merge, as a list.

    Returns:
      A merged accumulator.
    """
    pass

  @abc.abstractmethod
  def extract(self, accumulator):
    """Convert an accumulator into a dict of output values.

    Args:
      accumulator: The accumulator to convert.

    Returns:
      A dict of ndarrays representing the data in this accumulator.
    """
    pass

  @abc.abstractmethod
  def restore(self, output):
    """Create an accumulator based on 'output'.

    This method creates a new accumulator with identical internal state to the
    one used to create the data in 'output'. This means that if you do

    output_data = combiner.extract(accumulator_1)
    accumulator_2 = combiner.restore(output_data)

    then accumulator_1 and accumulator_2 will have identical internal state, and
    computations using either of them will be equivalent.

    Args:
      output: The data output from a previous computation. Should be in the same
        form as provided by 'extract()'.

    Returns:
      A new accumulator.
    """
    pass

  @abc.abstractmethod
  def serialize(self, accumulator):
    """Serialize an accumulator for a remote call.

    This function serializes an accumulator to be sent to a remote process.

    Args:
      accumulator: The accumulator to serialize.

    Returns:
      A byte string representing the passed accumulator.
    """
    pass

  @abc.abstractmethod
  def deserialize(self, encoded_accumulator):
    """Deserialize an accumulator received from 'serialize()'.

    This function deserializes an accumulator serialized by 'serialize()'.

    Args:
      encoded_accumulator: A byte string representing an accumulator.

    Returns:
      The accumulator represented by the passed byte_string.
    """
    pass
606
607
def _disallow_inside_tf_function(method_name):
  """Disallow calling a method inside a `tf.function`.

  Args:
    method_name: Name of the `PreprocessingLayer` method being guarded; it is
      interpolated into the error message.

  Raises:
    RuntimeError: If `ops.inside_function()` reports that the caller is
      running inside a function.
  """
  if ops.inside_function():
    # Every `{method_name}` reference is wrapped in a balanced pair of
    # backticks so the message renders consistently.
    error_msg = (
        'Detected a call to `PreprocessingLayer.{method_name}` inside a '
        '`tf.function`. `PreprocessingLayer.{method_name}` is a high-level '
        'endpoint that manages its own `tf.function`. Please move the call '
        'to `PreprocessingLayer.{method_name}` outside of all enclosing '
        '`tf.function`s. Note that you can call a `PreprocessingLayer` '
        'directly on `Tensor`s inside a `tf.function` like: `layer(x)`, '
        'or update its state like: `layer.update_state(x)`.').format(
            method_name=method_name)
    raise RuntimeError(error_msg)
621