# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the base PreprocessingLayer and a subclass that uses Combiners."""

import abc
import collections

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import data_adapter
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils import version_utils
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.layers.experimental.preprocessing.PreprocessingLayer')
class PreprocessingLayer(Layer, metaclass=abc.ABCMeta):
  """Base class for Preprocessing Layers.

  **Don't use this class directly: it's an abstract base class!** You may
  be looking for one of the many built-in
  [preprocessing layers](https://keras.io/guides/preprocessing_layers/)
  instead.

  Preprocessing layers are layers whose state gets computed before model
  training starts. They do not get updated during training.
  Most preprocessing layers implement an `adapt()` method for state computation.

  The `PreprocessingLayer` class is the base class you would subclass to
  implement your own preprocessing layers.

  Attributes:
    streaming: Whether a layer can be adapted multiple times without resetting
      the state of the layer.
  """
  _must_restore_from_config = True

  def __init__(self, streaming=True, **kwargs):
    super(PreprocessingLayer, self).__init__(**kwargs)
    self._streaming = streaming
    self._is_compiled = False
    self._is_adapted = False

    # Wrap the (possibly subclass-overridden) `reset_state` so that every call
    # also sets `is_adapted=False`. The original bound method is kept as
    # `_reset_state_impl` and invoked by `_reset_state_wrapper`.
    self._reset_state_impl = self.reset_state
    self.reset_state = self._reset_state_wrapper

    # Cache for `make_adapt_function`; cleared by `compile`.
    self._adapt_function = None

  @property
  def streaming(self):
    """Whether `adapt` can be called twice without resetting the state."""
    return self._streaming

  @property
  def is_adapted(self):
    """Whether the layer has been fit to data already."""
    return self._is_adapted

  def update_state(self, data):
    """Accumulates statistics for the preprocessing layer.

    Arguments:
      data: A mini-batch of inputs to the layer.
    """
    raise NotImplementedError

  def reset_state(self):  # pylint: disable=method-hidden
    """Resets the statistics of the preprocessing layer."""
    raise NotImplementedError

  def merge_state(self, layers):
    """Merge the statistics of multiple preprocessing layers.

    This layer will contain the merged state.

    Arguments:
      layers: Layers whose statistics should be merged with the statistics of
        this layer.
    """
    raise NotImplementedError

  def finalize_state(self):
    """Finalize the statistics for the preprocessing layer.

    This method is called at the end of `adapt` or after restoring a serialized
    preprocessing layer's state. This method handles any one-time operations
    that should occur on the layer's state before `Layer.__call__`.
    """
    pass

  def make_adapt_function(self):
    """Creates a function to execute one step of `adapt`.

    This method can be overridden to support custom adapt logic.
    This method is called by `PreprocessingLayer.adapt`.

    Typically, this method directly controls `tf.function` settings,
    and delegates the actual state update logic to
    `PreprocessingLayer.update_state`.

    This function is cached the first time `PreprocessingLayer.adapt`
    is called. The cache is cleared whenever `PreprocessingLayer.compile`
    is called.

    Returns:
      Function. The function created by this method should accept a
      `tf.data.Iterator`, retrieve a batch, and update the state of the
      layer.
    """
    if self._adapt_function is not None:
      return self._adapt_function

    def adapt_step(iterator):
      data = next(iterator)
      self._adapt_maybe_build(data)
      self.update_state(data)

    if self._steps_per_execution.numpy().item() == 1:
      # A single step per call needs no loop.
      adapt_fn = adapt_step
    else:

      # Consume `steps_per_execution` batches per call of the adapt function.
      def adapt_fn(iterator):
        for _ in math_ops.range(self._steps_per_execution):
          adapt_step(iterator)

    if not self._run_eagerly:
      adapt_fn = def_function.function(adapt_fn)

    self._adapt_function = adapt_fn
    return self._adapt_function

  def compile(self, run_eagerly=None, steps_per_execution=None):
    """Configures the layer for `adapt`.

    Calling `compile` discards any adapt function cached by
    `make_adapt_function`, so the next `adapt` call uses the new settings.

    Arguments:
      run_eagerly: Bool or `None`. If `True`, this layer's `adapt` logic will
        not be wrapped in a `tf.function`. If `None` (the default), the value
        of `self.dynamic` is used. Recommended to leave this as `None` unless
        your layer cannot be run inside a `tf.function`.
      steps_per_execution: Int. Defaults to 1. The number of batches to run
        during each `tf.function` call. Running multiple batches inside a
        single `tf.function` call can greatly improve performance on TPUs or
        small models with a large Python overhead.
    """
    if steps_per_execution is None:
      steps_per_execution = 1
    self._configure_steps_per_execution(steps_per_execution)

    if run_eagerly is None:
      run_eagerly = self.dynamic
    self._run_eagerly = run_eagerly

    # The function built by `make_adapt_function` depends on the settings
    # above (eager vs. `tf.function`, single step vs. a loop over
    # `steps_per_execution`) and is cached. Drop the cache so that the next
    # `adapt` rebuilds it with the new configuration.
    self._adapt_function = None

    self._is_compiled = True

  def adapt(self, data, batch_size=None, steps=None, reset_state=True):
    """Fits the state of the preprocessing layer to the data being passed.

    After calling `adapt` on a layer, a preprocessing layer's state will not
    update during training. In order to make preprocessing layers efficient in
    any distribution context, they are kept constant with respect to any
    compiled `tf.Graph`s that call the layer. This does not affect the layer use
    when adapting each layer only once, but if you adapt a layer multiple times
    you will need to take care to re-compile any compiled functions as follows:

     * If you are adding a preprocessing layer to a `keras.Model`, you need to
       call `model.compile` after each subsequent call to `adapt`.
     * If you are calling a preprocessing layer inside `tf.data.Dataset.map`,
       you should call `map` again on the input `tf.data.Dataset` after each
       `adapt`.
     * If you are using a `tf.function` directly which calls a preprocessing
       layer, you need to call `tf.function` again on your callable after
       each subsequent call to `adapt`.

    `tf.keras.Model` example with multiple adapts:

    >>> layer = tf.keras.layers.experimental.preprocessing.Normalization(
    ...     axis=None)
    >>> layer.adapt([0, 2])
    >>> model = tf.keras.Sequential(layer)
    >>> model.predict([0, 1, 2])
    array([-1.,  0.,  1.], dtype=float32)
    >>> layer.adapt([-1, 1])
    >>> model.compile() # This is needed to re-compile model.predict!
    >>> model.predict([0, 1, 2])
    array([0., 1., 2.], dtype=float32)

    `tf.data.Dataset` example with multiple adapts:

    >>> layer = tf.keras.layers.experimental.preprocessing.Normalization(
    ...     axis=None)
    >>> layer.adapt([0, 2])
    >>> input_ds = tf.data.Dataset.range(3)
    >>> normalized_ds = input_ds.map(layer)
    >>> list(normalized_ds.as_numpy_iterator())
    [array([-1.], dtype=float32),
     array([0.], dtype=float32),
     array([1.], dtype=float32)]
    >>> layer.adapt([-1, 1])
    >>> normalized_ds = input_ds.map(layer) # Re-map over the input dataset.
    >>> list(normalized_ds.as_numpy_iterator())
    [array([0.], dtype=float32),
     array([1.], dtype=float32),
     array([2.], dtype=float32)]

    Arguments:
      data: The data to train on. It can be passed either as a tf.data
        Dataset, or as a numpy array.
      batch_size: Integer or `None`.
        Number of samples per state update.
        If unspecified, `batch_size` will default to 32.
        Do not specify the `batch_size` if your data is in the
        form of datasets, generators, or `keras.utils.Sequence` instances
        (since they generate batches).
      steps: Integer or `None`.
        Total number of steps (batches of samples).
        When training with input tensors such as
        TensorFlow data tensors, the default `None` is equal to
        the number of samples in your dataset divided by
        the batch size, or 1 if that cannot be determined. If `data` is a
        `tf.data` dataset, and 'steps' is None, the epoch will run until
        the input dataset is exhausted. When passing an infinitely
        repeating dataset, you must specify the `steps` argument. This
        argument is not supported with array inputs.
      reset_state: Optional argument specifying whether to clear the state of
        the layer at the start of the call to `adapt`, or whether to start
        from the existing state. This argument may not be relevant to all
        preprocessing layers: a subclass of PreprocessingLayer may choose to
        throw if 'reset_state' is set to False.

    Raises:
      RuntimeError: If called inside a `tf.function`, or when not running
        TensorFlow v2 behavior.
      ValueError: If the layer is not `streaming`, has already been adapted,
        and `reset_state` is False.
    """
    _disallow_inside_tf_function('adapt')
    if not version_utils.should_use_v2():
      raise RuntimeError('`adapt` is only supported in tensorflow v2.')  # pylint: disable=g-doc-exception
    if not self.streaming and self._is_adapted and not reset_state:
      raise ValueError('{} does not supporting calling `adapt` twice without '
                       'resetting the state.'.format(self.__class__.__name__))
    if not self._is_compiled:
      self.compile()  # Compile with defaults.
    # An unbuilt layer has no state to reset yet.
    if self.built and reset_state:
      self.reset_state()
    data_handler = data_adapter.DataHandler(
        data,
        batch_size=batch_size,
        steps_per_epoch=steps,
        epochs=1,
        steps_per_execution=self._steps_per_execution,
        distribute=False)
    self._adapt_function = self.make_adapt_function()
    for _, iterator in data_handler.enumerate_epochs():
      with data_handler.catch_stop_iteration():
        for _ in data_handler.steps():
          self._adapt_function(iterator)
          if data_handler.should_sync:
            context.async_wait()
    self.finalize_state()
    self._is_adapted = True

  def _reset_state_wrapper(self):
    """Calls the original `reset_state` and sets `is_adapted` to `False`."""
    self._reset_state_impl()
    self._is_adapted = False

  @trackable.no_automatic_dependency_tracking
  def _configure_steps_per_execution(self, steps_per_execution):
    """Stores `steps_per_execution` in an int64 `tf.Variable`."""
    self._steps_per_execution = variables.Variable(
        steps_per_execution,
        dtype='int64',
        aggregation=variables.VariableAggregationV2.ONLY_FIRST_REPLICA)

  # TODO(omalleyt): Unify this logic with `Layer._maybe_build`.
  def _adapt_maybe_build(self, data):
    """Builds the layer from the first batch seen by `adapt`, if not built."""
    if not self.built:
      try:
        # If this is a Numpy array or tensor, we can get shape from .shape.
        # If not, an attribute error will be thrown.
        data_shape = data.shape
        data_shape_nones = tuple([None] * len(data.shape))
      except AttributeError:
        # The input has an unknown number of dimensions.
        data_shape = None
        data_shape_nones = None

      # TODO (b/159261555): move this to base layer build.
      batch_input_shape = getattr(self, '_batch_input_shape', None)
      if batch_input_shape is None:
        # Set the number of dimensions.
        self._batch_input_shape = data_shape_nones
      self.build(data_shape)
      self.built = True


# TODO(omalleyt): This class will be gradually replaced.
class CombinerPreprocessingLayer(PreprocessingLayer):
  """Base class for PreprocessingLayers that do computation using a Combiner.

  This class provides several helper methods to make creating a
  PreprocessingLayer easier. It assumes that the core of your computation will
  be done via a Combiner object. Subclassing this class to create a
  PreprocessingLayer allows your layer to be compatible with distributed
  computation.

  This class is compatible with Tensorflow 2.0+.
  """

  def __init__(self, combiner, **kwargs):
    super(CombinerPreprocessingLayer, self).__init__(**kwargs)
    # Maps state-variable name -> variable; filled by `_add_state_variable`.
    self.state_variables = collections.OrderedDict()
    self._combiner = combiner
    # Accumulator built up by `update_state` during `adapt`.
    self._adapt_accumulator = None

  def reset_state(self):  # pylint: disable=method-hidden
    """Discards the in-progress `adapt` accumulator."""
    self._adapt_accumulator = None

  @trackable.no_automatic_dependency_tracking
  def update_state(self, data):
    """Folds a batch of `data` into the in-progress accumulator."""
    if self._adapt_accumulator is None:
      # Start from the stored state if already adapted, otherwise from None.
      self._adapt_accumulator = self._get_accumulator()
    self._adapt_accumulator = self._combiner.compute(data,
                                                     self._adapt_accumulator)

  def merge_state(self, layers):
    """Merges this layer's stored state with that of `layers` into this one."""
    # NOTE(review): `_get_accumulator()` returns None for layers that have not
    # been adapted, so `accumulators` may contain None entries; confirm each
    # Combiner's `merge` tolerates them.
    accumulators = ([self._get_accumulator()] +
                    [l._get_accumulator() for l in layers])  # pylint: disable=protected-access
    merged_accumulator = self._combiner.merge(accumulators)
    self._set_accumulator(merged_accumulator)

  def finalize_state(self):
    """Writes the accumulator built during `adapt` into the state variables."""
    if self._adapt_accumulator is not None:
      self._set_accumulator(self._adapt_accumulator)

  def compile(self, run_eagerly=None, steps_per_execution=None):
    """Like `PreprocessingLayer.compile`, but `run_eagerly` defaults to True."""
    # TODO(omalleyt): Remove this once sublayers are switched to new APIs.
    if run_eagerly is None:
      run_eagerly = True
    super(CombinerPreprocessingLayer, self).compile(
        run_eagerly=run_eagerly, steps_per_execution=steps_per_execution)

  def adapt(self, data, batch_size=None, steps=None, reset_state=True):
    """Runs `PreprocessingLayer.adapt`, seeding from stored state if needed.

    When `reset_state` is False, the accumulator is first restored from the
    current state variables so new data is added on top of it.
    """
    if not reset_state:
      self._adapt_accumulator = self._combiner.restore(self._restore_updates())
    super(CombinerPreprocessingLayer, self).adapt(
        data, batch_size=batch_size, steps=steps, reset_state=reset_state)

  def _add_state_variable(self,
                          name,
                          shape,
                          dtype,
                          initializer=None,
                          partitioner=None,
                          use_resource=None,
                          **kwargs):
    """Add a variable that can hold state which is updated during adapt().

    The variable is created through `self.add_weight` as a non-trainable
    weight with no regularizer or constraint.

    Args:
      name: Variable name. Also used as its key in `self.state_variables`.
      shape: Variable shape, forwarded to `add_weight`.
      dtype: The type of the variable, forwarded to `add_weight`.
      initializer: initializer instance (callable).
      partitioner: Partitioner to be passed to the `Trackable` API.
      use_resource: Whether to use `ResourceVariable`.
      **kwargs: Additional keyword arguments forwarded to `add_weight`.
        Accepted values are `getter` and `collections`.

    Returns:
      The created variable.
    """
    weight = self.add_weight(
        name=name,
        shape=shape,
        dtype=dtype,
        initializer=initializer,
        regularizer=None,
        trainable=False,
        constraint=None,
        partitioner=partitioner,
        use_resource=use_resource,
        **kwargs)
    # TODO(momernick): Do not allow collisions here.
    self.state_variables[name] = weight
    return weight

  def _restore_updates(self):
    """Recreates a dict of updates from the layer's weights."""
    data_dict = {}
    for name, var in self.state_variables.items():
      data_dict[name] = var.numpy()
    return data_dict

  def _get_accumulator(self):
    """Returns an accumulator restored from the state, or None if unadapted."""
    if self._is_adapted:
      return self._combiner.restore(self._restore_updates())
    else:
      return None

  def _set_accumulator(self, accumulator):
    """Extracts `accumulator` into the state variables."""
    updates = self._combiner.extract(accumulator)
    self._set_state_variables(updates)
    self._adapt_accumulator = None  # Reset accumulator from adapt.

  def _set_state_variables(self, updates):
    """Directly update the internal state of this Layer.

    This method expects a string-keyed dict of {state_variable_name: state}. The
    precise nature of the state, and the names associated, are described by
    the subclasses of CombinerPreprocessingLayer.

    Args:
      updates: A string keyed dict of weights to update.

    Raises:
      RuntimeError: if 'build()' was not called before '_set_state_variables'.
    """
    # TODO(momernick): Do we need to do any more input sanitization?
    if not self.built:
      raise RuntimeError('_set_state_variables() must be called after build().')

    with ops.init_scope():
      for var_name, value in updates.items():
        self.state_variables[var_name].assign(value)


def convert_to_list(values, sparse_default_value=None):
  """Convert a TensorLike, CompositeTensor, or ndarray into a Python list.

  Args:
    values: A ragged tensor, `SparseTensor`/`SparseTensorValue`, `Tensor`, or
      ndarray. Values of any other type are returned unchanged.
    sparse_default_value: Value used to fill missing entries when densifying a
      sparse input. If None, `''` is used for string sparse values and `-1`
      otherwise.

  Returns:
    A Python list for the supported input types; otherwise `values` as given.
  """
  if tf_utils.is_ragged(values):
    # There is a corner case when dealing with ragged tensors: if you get an
    # actual RaggedTensor (not a RaggedTensorValue) passed in non-eager mode,
    # you can't call to_list() on it without evaluating it first. However,
    # because we don't yet fully support composite tensors across Keras,
    # backend.get_value() won't evaluate the tensor.
    # TODO(momernick): Get Keras to recognize composite tensors as Tensors
    # and then replace this with a call to backend.get_value.
    if (isinstance(values, ragged_tensor.RaggedTensor) and
        not context.executing_eagerly()):
      values = backend.get_session(values).run(values)
    values = values.to_list()

  if isinstance(values,
                (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
    if sparse_default_value is None:
      if dtypes.as_dtype(values.values.dtype) == dtypes.string:
        sparse_default_value = ''
      else:
        sparse_default_value = -1
    dense_tensor = sparse_ops.sparse_tensor_to_dense(
        values, default_value=sparse_default_value)
    values = backend.get_value(dense_tensor)

  if isinstance(values, ops.Tensor):
    values = backend.get_value(values)

  # We may get passed a ndarray or the code above may give us a ndarray.
  # In either case, we want to force it into a standard python list.
  if isinstance(values, np.ndarray):
    values = values.tolist()

  return values


# TODO(omalleyt): This class will be gradually replaced.
class Combiner(object):
  """Functional object that defines a shardable computation.

  This object defines functions required to create and manipulate data objects.
  These data objects, referred to below as 'accumulators', are computation-
  specific and may be implemented alongside concrete subclasses of Combiner
  (if necessary - some computations may be simple enough that standard Python
  types can be used as accumulators).

  The intent for this class is that by describing computations in this way, we
  can arbitrarily shard a dataset, perform computations on a subset, and then
  merge the computation into a final result. This enables distributed
  computation.

  The combiner itself does not own any state - all computational state is owned
  by the accumulator objects. This is so that we can have an arbitrary number of
  Combiners (thus sharding the computation N ways) without risking any change
  to the underlying computation. These accumulator objects are uniquely
  associated with each Combiner; a Combiner defines what the accumulator object
  should be and will only work with accumulators of that type.
  """
  # NOTE(review): `__metaclass__` is a Python 2 idiom and has no effect under
  # Python 3, so the `@abc.abstractmethod` markers below are not enforced on
  # subclasses. Left as-is so existing subclasses that do not implement every
  # method remain instantiable.
  __metaclass__ = abc.ABCMeta

  def __repr__(self):
    return '<{}>'.format(self.__class__.__name__)

  @abc.abstractmethod
  def compute(self, batch_values, accumulator=None):
    """Compute a step in this computation, returning a new accumulator.

    This method computes a step of the computation described by this Combiner.
    If an accumulator is passed, the data in that accumulator is also used; so
    compute(batch_values) results in f(batch_values), while
    compute(batch_values, accumulator) results in
    merge(f(batch_values), accumulator).

    Args:
      batch_values: A list of ndarrays representing the values of the inputs for
        this step of the computation.
      accumulator: the current accumulator. Can be None.

    Returns:
      An accumulator that includes the passed batch of inputs.
    """
    pass

  @abc.abstractmethod
  def merge(self, accumulators):
    """Merge several accumulators to a single accumulator.

    This method takes the partial values in several accumulators and combines
    them into a single accumulator. This computation must not be order-specific
    (that is, merge([a, b]) must return the same result as merge([b, a]).

    Args:
      accumulators: the accumulators to merge, as a list.

    Returns:
      A merged accumulator.
    """
    pass

  @abc.abstractmethod
  def extract(self, accumulator):
    """Convert an accumulator into a dict of output values.

    Args:
      accumulator: The accumulator to convert.

    Returns:
      A dict of ndarrays representing the data in this accumulator.
    """
    pass

  @abc.abstractmethod
  def restore(self, output):
    """Create an accumulator based on 'output'.

    This method creates a new accumulator with identical internal state to the
    one used to create the data in 'output'. This means that if you do

    output_data = combiner.extract(accumulator_1)
    accumulator_2 = combiner.restore(output_data)

    then accumulator_1 and accumulator_2 will have identical internal state, and
    computations using either of them will be equivalent.

    Args:
      output: The data output from a previous computation. Should be in the same
        form as provided by 'extract'.

    Returns:
      A new accumulator.
    """
    pass

  @abc.abstractmethod
  def serialize(self, accumulator):
    """Serialize an accumulator for a remote call.

    This function serializes an accumulator to be sent to a remote process.

    Args:
      accumulator: The accumulator to serialize.

    Returns:
      A byte string representing the passed accumulator.
    """
    pass

  @abc.abstractmethod
  def deserialize(self, encoded_accumulator):
    """Deserialize an accumulator received from 'serialize()'.

    This function deserializes an accumulator serialized by 'serialize()'.

    Args:
      encoded_accumulator: A byte string representing an accumulator.

    Returns:
      The accumulator represented by the passed byte_string.
    """
    pass


def _disallow_inside_tf_function(method_name):
  """Disallow calling a method inside a `tf.function`.

  Args:
    method_name: Name of the `PreprocessingLayer` method being guarded; it is
      interpolated into the error message.

  Raises:
    RuntimeError: If called while `ops.inside_function()` is true.
  """
  if ops.inside_function():
    error_msg = (
        'Detected a call to `PreprocessingLayer.{method_name}` inside a '
        '`tf.function`. `PreprocessingLayer.{method_name}` is a high-level '
        'endpoint that manages its own `tf.function`. Please move the call '
        'to `PreprocessingLayer.{method_name}` outside of all enclosing '
        '`tf.function`s. Note that you can call a `PreprocessingLayer` '
        'directly on `Tensor`s inside a `tf.function` like: `layer(x)`, '
        'or update its state like: `layer.update_state(x)`.').format(
            method_name=method_name)
    raise RuntimeError(error_msg)