1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16# pylint: disable=redefined-outer-name
17# pylint: disable=redefined-builtin
18# pylint: disable=g-classes-have-attributes
19"""Keras backend API."""
20
21import collections
22import itertools
23import json
24import os
25import sys
26import threading
27import warnings
28import weakref
29
30import numpy as np
31
32from tensorflow.core.protobuf import config_pb2
33from tensorflow.python import tf2
34from tensorflow.python.client import session as session_module
35from tensorflow.python.distribute import distribution_strategy_context
36from tensorflow.python.eager import context
37from tensorflow.python.eager.context import get_config
38from tensorflow.python.framework import composite_tensor
39from tensorflow.python.framework import config
40from tensorflow.python.framework import constant_op
41from tensorflow.python.framework import device_spec
42from tensorflow.python.framework import dtypes as dtypes_module
43from tensorflow.python.framework import func_graph
44from tensorflow.python.framework import ops
45from tensorflow.python.framework import sparse_tensor
46from tensorflow.python.framework import tensor_shape
47from tensorflow.python.framework import tensor_spec
48from tensorflow.python.framework import tensor_util
49from tensorflow.python.keras import backend_config
50from tensorflow.python.keras.distribute import distribute_coordinator_utils as dc
51from tensorflow.python.keras.engine import keras_tensor
52from tensorflow.python.keras.utils import control_flow_util
53from tensorflow.python.keras.utils import object_identity
54from tensorflow.python.keras.utils import tf_contextlib
55from tensorflow.python.keras.utils import tf_inspect
56from tensorflow.python.ops import array_ops
57from tensorflow.python.ops import clip_ops
58from tensorflow.python.ops import control_flow_ops
59from tensorflow.python.ops import ctc_ops as ctc
60from tensorflow.python.ops import functional_ops
61from tensorflow.python.ops import gradients as gradients_module
62from tensorflow.python.ops import image_ops
63from tensorflow.python.ops import init_ops
64from tensorflow.python.ops import linalg_ops
65from tensorflow.python.ops import logging_ops
66from tensorflow.python.ops import map_fn as map_fn_lib
67from tensorflow.python.ops import math_ops
68from tensorflow.python.ops import nn
69from tensorflow.python.ops import random_ops
70from tensorflow.python.ops import sparse_ops
71from tensorflow.python.ops import state_ops
72from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
73from tensorflow.python.ops import tensor_array_ops
74from tensorflow.python.ops import variables as variables_module
75from tensorflow.python.ops.ragged import ragged_tensor
76from tensorflow.python.platform import tf_logging as logging
77from tensorflow.python.training import moving_averages
78from tensorflow.python.training.tracking import util as tracking_util
79from tensorflow.python.util import dispatch
80from tensorflow.python.util import keras_deps
81from tensorflow.python.util import nest
82from tensorflow.python.util.tf_export import keras_export
83from tensorflow.tools.docs import doc_controls
84
85py_all = all
86py_sum = sum
87py_any = any
88
89# INTERNAL UTILS
90
91# The internal graph maintained by Keras and used by the symbolic Keras APIs
92# while executing eagerly (such as the functional API for model-building).
93# This is thread-local to allow building separate models in different threads
94# concurrently, but comes at the cost of not being able to build one model
95# across threads.
96_GRAPH = threading.local()
97
98# A graph which is used for constructing functions in eager mode.
99_CURRENT_SCRATCH_GRAPH = threading.local()
100
101# This is a thread local object that will hold the default internal TF session
102# used by Keras. It can be set manually via `set_session(sess)`.
103_SESSION = threading.local()
104
105
106# A global dictionary mapping graph objects to an index of counters used
107# for various layer/optimizer names in each graph.
108# Allows giving unique autogenerated names to layers in a graph-specific way.
109PER_GRAPH_OBJECT_NAME_UIDS = weakref.WeakKeyDictionary()
110
111
112# A global set tracking what object names have been seen so far.
113# Optionally used as an avoid-list when generating names.
114OBSERVED_NAMES = set()
115
116
117# _DUMMY_EAGER_GRAPH.key is used as a key in _GRAPH_LEARNING_PHASES.
118# We keep a separate reference to it to make sure it does not get removed from
119# _GRAPH_LEARNING_PHASES.
120# _DummyEagerGraph inherits from threading.local to make its `key` attribute
121# thread local. This is needed to make set_learning_phase affect only the
122# current thread during eager execution (see b/123096885 for more details).
123class _DummyEagerGraph(threading.local):
124  """_DummyEagerGraph provides a thread local `key` attribute.
125
126  We can't use threading.local directly, i.e. without subclassing, because
127  gevent monkey patches threading.local and its version does not support
128  weak references.
129  """
130
131  class _WeakReferencableClass:
132    """This dummy class is needed for two reasons.
133
134    - We need something that supports weak references. Basic types like
135      strings and ints don't.
136    - We need something whose hash and equality are based on object identity
137    to make sure they are treated as different keys to _GRAPH_LEARNING_PHASES.
138
139    An empty Python class satisfies both of these requirements.
140    """
141    pass
142
143  def __init__(self):
144    # Constructors for classes subclassing threading.local run once
145    # per thread accessing something in the class. Thus, each thread will
146    # get a different key.
147    super(_DummyEagerGraph, self).__init__()
148    self.key = _DummyEagerGraph._WeakReferencableClass()
149    self.learning_phase_is_set = False
150
151
152_DUMMY_EAGER_GRAPH = _DummyEagerGraph()
153
154# This boolean flag can be set to True to leave variable initialization
155# up to the user.
156# Change its value via `manual_variable_initialization(value)`.
157_MANUAL_VAR_INIT = False
158
159# This list holds the available devices.
160# It is populated when `_get_available_gpus()` is called for the first time.
161# We assume our devices don't change henceforth.
162_LOCAL_DEVICES = None
163
164# The functions below are kept accessible from the backend for compatibility.
165epsilon = backend_config.epsilon
166floatx = backend_config.floatx
167image_data_format = backend_config.image_data_format
168set_epsilon = backend_config.set_epsilon
169set_floatx = backend_config.set_floatx
170set_image_data_format = backend_config.set_image_data_format
171
172
173@keras_export('keras.backend.backend')
174@doc_controls.do_not_generate_docs
175def backend():
176  """Publicly accessible method for determining the current backend.
177
178  Only exists for API compatibility with multi-backend Keras.
179
180  Returns:
181      The string "tensorflow".
182  """
183  return 'tensorflow'
184
185
186@keras_export('keras.backend.cast_to_floatx')
187@dispatch.add_dispatch_support
188@doc_controls.do_not_generate_docs
189def cast_to_floatx(x):
190  """Cast a Numpy array to the default Keras float type.
191
192  Args:
193      x: Numpy array or TensorFlow tensor.
194
195  Returns:
196      The same array (Numpy array if `x` was a Numpy array, or TensorFlow tensor
197      if `x` was a tensor), cast to its new type.
198
199  Example:
200
201  >>> tf.keras.backend.floatx()
202  'float32'
203  >>> arr = np.array([1.0, 2.0], dtype='float64')
204  >>> arr.dtype
205  dtype('float64')
206  >>> new_arr = cast_to_floatx(arr)
207  >>> new_arr
208  array([1.,  2.], dtype=float32)
209  >>> new_arr.dtype
210  dtype('float32')
211
212  """
213  if isinstance(x, (ops.Tensor,
214                    variables_module.Variable,
215                    sparse_tensor.SparseTensor)):
216    return math_ops.cast(x, dtype=floatx())
217  return np.asarray(x, dtype=floatx())
218
219
220@keras_export('keras.backend.get_uid')
221def get_uid(prefix=''):
222  """Associates a string prefix with an integer counter in a TensorFlow graph.
223
224  Args:
225    prefix: String prefix to index.
226
227  Returns:
228    Unique integer ID.
229
230  Example:
231
232  >>> get_uid('dense')
233  1
234  >>> get_uid('dense')
235  2
236
237  """
238  graph = get_graph()
239  if graph not in PER_GRAPH_OBJECT_NAME_UIDS:
240    PER_GRAPH_OBJECT_NAME_UIDS[graph] = collections.defaultdict(int)
241  layer_name_uids = PER_GRAPH_OBJECT_NAME_UIDS[graph]
242  layer_name_uids[prefix] += 1
243  return layer_name_uids[prefix]
244
245
246@keras_export('keras.backend.reset_uids')
247def reset_uids():
248  """Resets graph identifiers.
249  """
250
251  PER_GRAPH_OBJECT_NAME_UIDS.clear()
252  OBSERVED_NAMES.clear()
253
254
255@keras_export('keras.backend.clear_session')
256def clear_session():
257  """Resets all state generated by Keras.
258
259  Keras manages a global state, which it uses to implement the Functional
260  model-building API and to uniquify autogenerated layer names.
261
262  If you are creating many models in a loop, this global state will consume
263  an increasing amount of memory over time, and you may want to clear it.
264  Calling `clear_session()` releases the global state: this helps avoid clutter
265  from old models and layers, especially when memory is limited.
266
267  Example 1: calling `clear_session()` when creating models in a loop
268
269  ```python
270  for _ in range(100):
271    # Without `clear_session()`, each iteration of this loop will
272    # slightly increase the size of the global state managed by Keras
273    model = tf.keras.Sequential([tf.keras.layers.Dense(10) for _ in range(10)])
274
275  for _ in range(100):
276    # With `clear_session()` called at the beginning,
277    # Keras starts with a blank state at each iteration
278    # and memory consumption is constant over time.
279    tf.keras.backend.clear_session()
280    model = tf.keras.Sequential([tf.keras.layers.Dense(10) for _ in range(10)])
281  ```
282
283  Example 2: resetting the layer name generation counter
284
285  >>> import tensorflow as tf
286  >>> layers = [tf.keras.layers.Dense(10) for _ in range(10)]
287  >>> new_layer = tf.keras.layers.Dense(10)
288  >>> print(new_layer.name)
289  dense_10
290  >>> tf.keras.backend.set_learning_phase(1)
291  >>> print(tf.keras.backend.learning_phase())
292  1
293  >>> tf.keras.backend.clear_session()
294  >>> new_layer = tf.keras.layers.Dense(10)
295  >>> print(new_layer.name)
296  dense
297  """
298  global _SESSION
299  global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
300  global _GRAPH_VARIABLES  # pylint: disable=global-variable-not-assigned
301  global _GRAPH_TF_OPTIMIZERS  # pylint: disable=global-variable-not-assigned
302  global _GRAPH
303  _GRAPH.graph = None
304  ops.reset_default_graph()
305  reset_uids()
306  _SESSION.session = None
307  graph = get_graph()
308  with graph.as_default():
309    _DUMMY_EAGER_GRAPH.learning_phase_is_set = False
310    _GRAPH_LEARNING_PHASES.clear()
311    # Create the learning phase placeholder in graph using the default factory.
312    _GRAPH_LEARNING_PHASES.setdefault(graph)
313    _GRAPH_VARIABLES.pop(graph, None)
314    _GRAPH_TF_OPTIMIZERS.pop(graph, None)
315  if context.executing_eagerly():
316    # Clear pending nodes in eager executors, kernel caches and step_containers.
317    context.context().clear_kernel_cache()
318
319# Inject the clear_session function to keras_deps to remove the dependency
320# from TFLite to Keras.
321keras_deps.register_clear_session_function(clear_session)
322
323
324@keras_export('keras.backend.manual_variable_initialization')
325@doc_controls.do_not_generate_docs
326def manual_variable_initialization(value):
327  """Sets the manual variable initialization flag.
328
329  This boolean flag determines whether
330  variables should be initialized
331  as they are instantiated (default), or if
332  the user should handle the initialization
333  (e.g. via `tf.compat.v1.initialize_all_variables()`).
334
335  Args:
336      value: Python boolean.
337  """
338  global _MANUAL_VAR_INIT
339  _MANUAL_VAR_INIT = value
340
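# A minimal usage sketch for the flag above (not part of the original file),
# assuming TF1-style graph mode (e.g. after
# `tf.compat.v1.disable_eager_execution()`). The helper name
# `_example_manual_initialization` is hypothetical and for illustration only.
def _example_manual_initialization():
  manual_variable_initialization(True)   # Keras no longer auto-initializes.
  v = variable(np.zeros((2, 2)))         # Created, but left uninitialized.
  sess = get_session()
  # The user is now responsible for running initializers explicitly.
  sess.run(variables_module.variables_initializer([v]))
  return sess.run(v)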
341
342@keras_export('keras.backend.learning_phase')
343@doc_controls.do_not_generate_docs
344def learning_phase():
345  """Returns the learning phase flag.
346
347  The learning phase flag is a bool tensor (0 = test, 1 = train)
348  to be passed as input to any Keras function
349  that behaves differently at train time and test time.
350
351  Returns:
352      Learning phase (scalar integer tensor or Python integer).
353  """
354  graph = ops.get_default_graph()
355  if graph is getattr(_GRAPH, 'graph', None):
356    # Don't enter an init_scope for the learning phase if eager execution
357    # is enabled but we're inside the Keras workspace graph.
358    learning_phase = symbolic_learning_phase()
359  else:
360    with ops.init_scope():
361      # We always check & set the learning phase inside the init_scope,
362      # otherwise the wrong default_graph will be used to look up the learning
363      # phase inside of functions & defuns.
364      #
365      # This is because functions & defuns (both in graph & in eager mode)
366      # will always execute non-eagerly using a function-specific default
367      # subgraph.
368      learning_phase = _GRAPH_LEARNING_PHASES[None]
369  _mark_func_graph_as_unsaveable(graph, learning_phase)
370  return learning_phase
371
372
373def global_learning_phase_is_set():
374  return _DUMMY_EAGER_GRAPH.learning_phase_is_set
375
376
377def _mark_func_graph_as_unsaveable(graph, learning_phase):
378  """Mark func graph as unsaveable due to use of symbolic keras learning phase.
379
380  Functions that capture the symbolic learning phase cannot be exported to
381  SavedModel. Mark the funcgraph as unsaveable, so that an error will be raised
382  if it is exported.
383
384  Args:
385    graph: Graph or FuncGraph object.
386    learning_phase: Learning phase placeholder or int defined in the graph.
387  """
388  if graph.building_function and is_placeholder(learning_phase):
389    graph.mark_as_unsaveable(
390        'The keras learning phase placeholder was used inside a function. '
391        'Exporting placeholders is not supported when saving out a SavedModel. '
392        'Please call `tf.keras.backend.set_learning_phase(0)` in the function '
393        'to set the learning phase to a constant value.')
394
395
396def symbolic_learning_phase():
397  graph = get_graph()
398  with graph.as_default():
399    return _GRAPH_LEARNING_PHASES[graph]
400
401
402def _default_learning_phase():
403  if context.executing_eagerly():
404    return 0
405  else:
406    with name_scope(''):
407      return array_ops.placeholder_with_default(
408          False, shape=(), name='keras_learning_phase')
409
410
411@keras_export('keras.backend.set_learning_phase')
412@doc_controls.do_not_generate_docs
413def set_learning_phase(value):
414  """Sets the learning phase to a fixed value.
415
416  The backend learning phase affects any code that calls
417  `backend.learning_phase()`.
418  In particular, all Keras built-in layers use the learning phase as the default
419  for the `training` arg to `Layer.__call__`.
420
421  User-written layers and models can achieve the same behavior with code that
422  looks like:
423
424  ```python
425    def call(self, inputs, training=None):
426      if training is None:
427        training = backend.learning_phase()
428  ```
429
430  Args:
431      value: Learning phase value, either 0 or 1 (integers).
432             0 = test, 1 = train
433
434  Raises:
435      ValueError: if `value` is neither `0` nor `1`.
436  """
437  warnings.warn('`tf.keras.backend.set_learning_phase` is deprecated and '
438                'will be removed after 2020-10-11. To update it, simply '
439                'pass a True/False value to the `training` argument of the '
440                '`__call__` method of your layer or model.')
441  deprecated_internal_set_learning_phase(value)
442
443
444def deprecated_internal_set_learning_phase(value):
445  """A deprecated internal implementation of set_learning_phase.
446
447  This method is an internal-only version of `set_learning_phase` that
448  does not raise a deprecation error. It is required because
449  saved_model needs to keep working with user code that uses the deprecated
450  learning phase methods until those APIs are fully removed from the public API.
451
452  Specifically, SavedModel saving needs to make sure the learning phase is 0
453  during tracing even if users have overridden it with a different value.
454
455  However, we don't want to raise deprecation warnings for users when
456  SavedModel sets the learning phase purely for compatibility with code that
457  relied on explicitly setting it.
458
459  Args:
460      value: Learning phase value, either 0 or 1 (integers). 0 = test, 1 = train
461
462  Raises:
463      ValueError: if `value` is neither `0` nor `1`.
464  """
465  global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
466  if value not in {0, 1}:
467    raise ValueError('Expected learning phase to be 0 or 1.')
468  with ops.init_scope():
469    if context.executing_eagerly():
470      # In an eager context, the learning phase value applies to both the eager
471      # context and the internal Keras graph.
472      _DUMMY_EAGER_GRAPH.learning_phase_is_set = True
473      _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key] = value
474    _GRAPH_LEARNING_PHASES[get_graph()] = value
475
476
477@keras_export('keras.backend.learning_phase_scope')
478@tf_contextlib.contextmanager
479@doc_controls.do_not_generate_docs
480def learning_phase_scope(value):
481  """Provides a scope within which the learning phase is equal to `value`.
482
483  The learning phase gets restored to its original value upon exiting the scope.
484
485  Args:
486     value: Learning phase value, either 0 or 1 (integers).
487            0 = test, 1 = train
488
489  Yields:
490    None.
491
492  Raises:
493     ValueError: if `value` is neither `0` nor `1`.
494  """
495  warnings.warn('`tf.keras.backend.learning_phase_scope` is deprecated and '
496                'will be removed after 2020-10-11. To update it, simply '
497                'pass a True/False value to the `training` argument of the '
498                '`__call__` method of your layer or model.')
499  with deprecated_internal_learning_phase_scope(value):
500    try:
501      yield
502    finally:
503      pass
504
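# A minimal usage sketch of the deprecated scope above (illustration only,
# assuming eager execution, the TF2 default; new code should pass
# `training=True/False` to `__call__` instead). The helper name
# `_example_learning_phase_scope` is hypothetical.
def _example_learning_phase_scope():
  before = global_learning_phase_is_set()   # typically False
  with learning_phase_scope(1):
    # Inside the scope, layers that fall back on `backend.learning_phase()`
    # behave as if `training=True` (e.g. Dropout is active).
    assert global_learning_phase_is_set()
  # On exit the previous phase (or unset state) is restored.
  assert global_learning_phase_is_set() == before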
505
506@tf_contextlib.contextmanager
507def deprecated_internal_learning_phase_scope(value):
508  """An internal-only version of `learning_phase_scope`.
509
510  Unlike the public method, this method does not raise a deprecation warning.
511  This is needed because saved model saving needs to set learning phase
512  to maintain compatibility
513  with code that sets/gets the learning phase, but saved model
514  saving itself shouldn't raise a deprecation warning.
515
516  We can get rid of this method and its usages when the public API is
517  removed.
518
519  Args:
520     value: Learning phase value, either 0 or 1 (integers). 0 = test, 1 = train
521
522  Yields:
523    None.
524
525  Raises:
526     ValueError: if `value` is neither `0` nor `1`.
527  """
528  global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
529  if value not in {0, 1}:
530    raise ValueError('Expected learning phase to be 0 or 1.')
531
532  with ops.init_scope():
533    if context.executing_eagerly():
534      previous_eager_value = _GRAPH_LEARNING_PHASES.get(
535          _DUMMY_EAGER_GRAPH.key, None)
536    previous_graph_value = _GRAPH_LEARNING_PHASES.get(get_graph(), None)
537
538  learning_phase_previously_set = _DUMMY_EAGER_GRAPH.learning_phase_is_set
539  try:
540    deprecated_internal_set_learning_phase(value)
541    yield
542  finally:
543    # Restore learning phase to initial value.
544    if not learning_phase_previously_set:
545      _DUMMY_EAGER_GRAPH.learning_phase_is_set = False
546    with ops.init_scope():
547      if context.executing_eagerly():
548        if previous_eager_value is not None:
549          _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key] = previous_eager_value
550        elif _DUMMY_EAGER_GRAPH.key in _GRAPH_LEARNING_PHASES:
551          del _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key]
552
553      graph = get_graph()
554      if previous_graph_value is not None:
555        _GRAPH_LEARNING_PHASES[graph] = previous_graph_value
556      elif graph in _GRAPH_LEARNING_PHASES:
557        del _GRAPH_LEARNING_PHASES[graph]
558
559
560@tf_contextlib.contextmanager
561def eager_learning_phase_scope(value):
562  """Internal scope that sets the learning phase in eager / tf.function only.
563
564  Args:
565      value: Learning phase value, either 0 or 1 (integers).
566             0 = test, 1 = train
567
568  Yields:
569    None.
570
571  Raises:
572     ValueError: if `value` is neither `0` nor `1`.
573  """
574  global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
575  assert value in {0, 1}
576  assert ops.executing_eagerly_outside_functions()
577  global_learning_phase_was_set = global_learning_phase_is_set()
578  if global_learning_phase_was_set:
579    previous_value = learning_phase()
580  try:
581    _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key] = value
582    yield
583  finally:
584    # Restore learning phase to initial value or unset.
585    if global_learning_phase_was_set:
586      _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key] = previous_value
587    else:
588      del _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key]
589
590
591def _as_graph_element(obj):
592  """Convert `obj` to a graph element if possible, otherwise return `None`.
593
594  Args:
595    obj: Object to convert.
596
597  Returns:
598    The result of `obj._as_graph_element()` if that method is available;
599        otherwise `None`.
600  """
601  conv_fn = getattr(obj, '_as_graph_element', None)
602  if conv_fn and callable(conv_fn):
603    return conv_fn()
604  return None
605
606
607def _assert_same_graph(original_item, item):
608  """Fail if the 2 items are from different graphs.
609
610  Args:
611    original_item: Original item to check against.
612    item: Item to check.
613
614  Raises:
615    ValueError: if graphs do not match.
616  """
617  original_graph = getattr(original_item, 'graph', None)
618  graph = getattr(item, 'graph', None)
619  if original_graph and graph and original_graph is not graph:
620    raise ValueError(
621        '%s must be from the same graph as %s (graphs are %s and %s).' %
622        (item, original_item, graph, original_graph))
623
624
625def _current_graph(op_input_list, graph=None):
626  """Returns the appropriate graph to use for the given inputs.
627
628  This library method provides a consistent algorithm for choosing the graph
629  in which an Operation should be constructed:
630
631  1. If the default graph is being used to construct a function, we
632     use the default graph.
633  2. If the "graph" is specified explicitly, we validate that all of the inputs
634     in "op_input_list" are compatible with that graph.
635  3. Otherwise, we attempt to select a graph from the first Operation-
636     or Tensor-valued input in "op_input_list", and validate that all other
637     such inputs are in the same graph.
638  4. If the graph was not specified and it could not be inferred from
639     "op_input_list", we attempt to use the default graph.
640
641  Args:
642    op_input_list: A list of inputs to an operation, which may include `Tensor`,
643      `Operation`, and other objects that may be converted to a graph element.
644    graph: (Optional) The explicit graph to use.
645
646  Raises:
647    TypeError: If op_input_list is not a list or tuple, or if graph is not a
648      Graph.
649    ValueError: If a graph is explicitly passed and not all inputs are from it,
650      or if the inputs are from multiple graphs, or we could not find a graph
651      and there was no default graph.
652
653  Returns:
654    The appropriate graph to use for the given inputs.
655
656  """
657  current_default_graph = ops.get_default_graph()
658  if current_default_graph.building_function:
659    return current_default_graph
660
661  op_input_list = tuple(op_input_list)  # Handle generators correctly
662  if graph and not isinstance(graph, ops.Graph):
663    raise TypeError('Input graph needs to be a Graph: %s' % (graph,))
664
665  # 1. We validate that all of the inputs are from the same graph. This is
666  #    either the supplied graph parameter, or the first one selected from one
667  #    of the graph-element-valued inputs. In the latter case, we hold onto
668  #    that input in original_graph_element so we can provide a more
669  #    informative error if a mismatch is found.
670  original_graph_element = None
671  for op_input in op_input_list:
672    # Determine if this is a valid graph_element.
673    # TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean this
674    # up.
675    if (isinstance(op_input, (
676        ops.Operation, ops.Tensor, composite_tensor.CompositeTensor)) and
677        ((not isinstance(op_input, ops.Tensor))
678         or type(op_input) == ops.Tensor)):  # pylint: disable=unidiomatic-typecheck
679      graph_element = op_input
680    else:
681      graph_element = _as_graph_element(op_input)
682
683    if graph_element is not None:
684      if not graph:
685        original_graph_element = graph_element
686        graph = getattr(graph_element, 'graph', None)
687      elif original_graph_element is not None:
688        _assert_same_graph(original_graph_element, graph_element)
689      elif graph_element.graph is not graph:
690        raise ValueError('%s is not from the passed-in graph.' % graph_element)
691
692  # 2. If all else fails, we use the default graph, which is always there.
693  return graph or current_default_graph
694
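# A hedged sketch (illustration only) of the selection rules documented above:
# inputs taken from two different explicitly-created graphs cannot be mixed.
# The helper name `_example_current_graph` is hypothetical.
def _example_current_graph():
  g1, g2 = ops.Graph(), ops.Graph()
  with g1.as_default():
    a = constant_op.constant(1.0)
  with g2.as_default():
    b = constant_op.constant(2.0)
  assert _current_graph([a]) is g1        # inferred from the only input
  try:
    _current_graph([a, b])                # mixed graphs
  except ValueError:
    pass                                  # expected: inputs must share a graph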
695
696def _get_session(op_input_list=()):
697  """Returns the session object for the current thread."""
698  global _SESSION
699  default_session = ops.get_default_session()
700  if default_session is not None:
701    session = default_session
702  else:
703    if ops.inside_function():
704      raise RuntimeError('Cannot get session inside TensorFlow graph function.')
705    # If we don't have a session, or that session does not match the current
706    # graph, create and cache a new session.
707    if (getattr(_SESSION, 'session', None) is None or
708        _SESSION.session.graph is not _current_graph(op_input_list)):
709      # If we are creating the Session inside a tf.distribute.Strategy scope,
710      # we ask the strategy for the right session options to use.
711      if distribution_strategy_context.has_strategy():
712        configure_and_create_distributed_session(
713            distribution_strategy_context.get_strategy())
714      else:
715        _SESSION.session = session_module.Session(
716            config=get_default_session_config())
717    session = _SESSION.session
718  return session
719
720
721@keras_export(v1=['keras.backend.get_session'])
722def get_session(op_input_list=()):
723  """Returns the TF session to be used by the backend.
724
725  If a default TensorFlow session is available, we will return it.
726
727  Else, we will return the global Keras session assuming it matches
728  the current graph.
729
730  If no global Keras session exists at this point,
731  we will create a new global session.
732
733  Note that you can manually set the global session
734  via `K.set_session(sess)`.
735
736  Args:
737      op_input_list: An optional sequence of tensors or ops, which will be used
738        to determine the current graph. Otherwise the default graph will be
739        used.
740
741  Returns:
742      A TensorFlow session.
743  """
744  session = _get_session(op_input_list)
745  if not _MANUAL_VAR_INIT:
746    with session.graph.as_default():
747      _initialize_variables(session)
748  return session
749
750# Inject the get_session function to keras_deps to remove the dependency
751# from TFLite to Keras.
752keras_deps.register_get_session_function(get_session)
753
754# Inject the get_session function to tracking_util to avoid the backward
755# dependency from TF to Keras.
756tracking_util.register_session_provider(get_session)
757
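# A hedged usage sketch (illustration only) of the session plumbing above,
# assuming TF1-style graph mode. The helper name `_example_custom_session` is
# hypothetical.
def _example_custom_session():
  # Install a session configured the way Keras would configure it by default;
  # Keras will then reuse it instead of creating its own.
  sess = session_module.Session(config=get_default_session_config())
  set_session(sess)
  assert get_session() is sess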
758
759def get_graph():
760  if context.executing_eagerly():
761    global _GRAPH
762    if not getattr(_GRAPH, 'graph', None):
763      _GRAPH.graph = func_graph.FuncGraph('keras_graph')
764    return _GRAPH.graph
765  else:
766    return ops.get_default_graph()
767
768
769@tf_contextlib.contextmanager
770def _scratch_graph(graph=None):
771  """Retrieve a shared and temporary func graph.
772
773  The eager execution path lifts a subgraph from the keras global graph into
774  a scratch graph in order to create a function. DistributionStrategies, in
775  turn, construct multiple functions as well as a final combined function. In
776  order for that logic to work correctly, all of the functions need to be
777  created on the same scratch FuncGraph.
778
779  Args:
780    graph: A graph to be used as the current scratch graph. If not set then
781      a scratch graph will either be retrieved or created.
782
783  Yields:
784    The current scratch graph.
785  """
786  global _CURRENT_SCRATCH_GRAPH
787  scratch_graph = getattr(_CURRENT_SCRATCH_GRAPH, 'graph', None)
788  # If scratch graph and `graph` are both configured, they must match.
789  if (scratch_graph is not None and graph is not None and
790      scratch_graph is not graph):
791    raise ValueError('Multiple scratch graphs specified.')
792
793  if scratch_graph:
794    yield scratch_graph
795    return
796
797  graph = graph or func_graph.FuncGraph('keras_scratch_graph')
798  try:
799    _CURRENT_SCRATCH_GRAPH.graph = graph
800    yield graph
801  finally:
802    _CURRENT_SCRATCH_GRAPH.graph = None
803
804
805@keras_export(v1=['keras.backend.set_session'])
806def set_session(session):
807  """Sets the global TensorFlow session.
808
809  Args:
810      session: A TF Session.
811  """
812  global _SESSION
813  _SESSION.session = session
814
815
816def get_default_session_config():
817  if os.environ.get('OMP_NUM_THREADS'):
818    logging.warning(
819        'OMP_NUM_THREADS is no longer used by the default Keras config. '
820        'To configure the number of threads, use tf.config.threading APIs.')
821
822  config = get_config()
823  config.allow_soft_placement = True
824
825  return config
826
827
828def get_default_graph_uid_map():
829  graph = ops.get_default_graph()
830  name_uid_map = PER_GRAPH_OBJECT_NAME_UIDS.get(graph, None)
831  if name_uid_map is None:
832    name_uid_map = collections.defaultdict(int)
833    PER_GRAPH_OBJECT_NAME_UIDS[graph] = name_uid_map
834  return name_uid_map
835
836
837# DEVICE MANIPULATION
838
839
840class _TfDeviceCaptureOp:
841  """Class for capturing the TF device scope."""
842
843  def __init__(self):
844    self.device = None
845
846  def _set_device(self, device):
847    """This method captures TF's explicit device scope setting."""
848    if isinstance(device, device_spec.DeviceSpecV2):
849      device = device.to_string()
850    self.device = device
851
852  def _set_device_from_string(self, device_str):
853    self.device = device_str
854
855
856def _get_current_tf_device():
857  """Return explicit device of current context, otherwise returns `None`.
858
859  Returns:
860      If the current device scope is explicitly set, it returns a string with
861      the device (`CPU` or `GPU`). If the scope is not explicitly set, it will
862      return `None`.
863  """
864  graph = get_graph()
865  op = _TfDeviceCaptureOp()
866  graph._apply_device_functions(op)
867  if tf2.enabled():
868    return device_spec.DeviceSpecV2.from_string(op.device)
869  else:
870    return device_spec.DeviceSpecV1.from_string(op.device)
871
872
873def _is_current_explicit_device(device_type):
874  """Check if the current device is explicitly set on the device type specified.
875
876  Args:
877      device_type: A string containing `GPU` or `CPU` (case-insensitive).
878
879  Returns:
880      A boolean indicating if the current device scope is explicitly set on the
881      device type.
882
883  Raises:
884      ValueError: If the `device_type` string indicates an unsupported device.
885  """
886  device_type = device_type.upper()
887  if device_type not in ['CPU', 'GPU']:
888    raise ValueError('`device_type` should be either "CPU" or "GPU".')
889  device = _get_current_tf_device()
890  return device is not None and device.device_type == device_type.upper()
891
892
893def _get_available_gpus():
894  """Get a list of available GPU devices (formatted as strings).
895
896  Returns:
897      A list of available GPU devices.
898  """
899  if ops.executing_eagerly_outside_functions():
900    # Returns names of devices directly.
901    return [d.name for d in config.list_logical_devices('GPU')]
902
903  global _LOCAL_DEVICES
904  if _LOCAL_DEVICES is None:
905    _LOCAL_DEVICES = get_session().list_devices()
906  return [x.name for x in _LOCAL_DEVICES if x.device_type == 'GPU']
907
908
909def _has_nchw_support():
910  """Check whether the current scope supports NCHW ops.
911
912  TensorFlow does not support NCHW on CPU. Therefore we check that we are
913  not explicitly placed on the CPU,
914  and that GPUs are available. In that case ops will be soft-placed on the
915  GPU device.
916
917  Returns:
918      bool: whether the current scope's device placement would support NCHW.
919  """
920  explicitly_on_cpu = _is_current_explicit_device('CPU')
921  gpus_available = bool(_get_available_gpus())
922  return not explicitly_on_cpu and gpus_available
923
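# A hedged sketch (illustration only) of how a caller might use the check
# above to decide on a data format. `_example_choose_data_format` is a
# hypothetical helper, not part of the backend API.
def _example_choose_data_format():
  if image_data_format() == 'channels_first' and not _has_nchw_support():
    # NCHW would not be usable in the current device scope; fall back.
    return 'channels_last'
  return image_data_format()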
924
925# VARIABLE MANIPULATION
926
927
928def _constant_to_tensor(x, dtype):
929  """Convert the input `x` to a tensor of type `dtype`.
930
931  This is slightly faster than the _to_tensor function, at the cost of
932  handling fewer cases.
933
934  Args:
935      x: An object to be converted (numpy arrays, floats, ints and lists of
936        them).
937      dtype: The destination type.
938
939  Returns:
940      A tensor.
941  """
942  return constant_op.constant(x, dtype=dtype)
943
944
945def _to_tensor(x, dtype):
946  """Convert the input `x` to a tensor of type `dtype`.
947
948  Args:
949      x: An object to be converted (numpy array, list, tensors).
950      dtype: The destination type.
951
952  Returns:
953      A tensor.
954  """
955  return ops.convert_to_tensor_v2_with_dispatch(x, dtype=dtype)
956
957
958@keras_export('keras.backend.is_sparse')
959@doc_controls.do_not_generate_docs
960def is_sparse(tensor):
961  """Returns whether a tensor is a sparse tensor.
962
963  Args:
964      tensor: A tensor instance.
965
966  Returns:
967      A boolean.
968
969  Example:
970
971
972  >>> a = tf.keras.backend.placeholder((2, 2), sparse=False)
973  >>> print(tf.keras.backend.is_sparse(a))
974  False
975  >>> b = tf.keras.backend.placeholder((2, 2), sparse=True)
976  >>> print(tf.keras.backend.is_sparse(b))
977  True
978
979  """
980  spec = getattr(tensor, '_type_spec', None)
981  if spec is not None:
982    return isinstance(spec, sparse_tensor.SparseTensorSpec)
983  return isinstance(tensor, sparse_tensor.SparseTensor)
984
985
986@keras_export('keras.backend.to_dense')
987@dispatch.add_dispatch_support
988@doc_controls.do_not_generate_docs
989def to_dense(tensor):
990  """Converts a sparse tensor into a dense tensor and returns it.
991
992  Args:
993      tensor: A tensor instance (potentially sparse).
994
995  Returns:
996      A dense tensor.
997
998  Examples:
999
1000
1001  >>> b = tf.keras.backend.placeholder((2, 2), sparse=True)
1002  >>> print(tf.keras.backend.is_sparse(b))
1003  True
1004  >>> c = tf.keras.backend.to_dense(b)
1005  >>> print(tf.keras.backend.is_sparse(c))
1006  False
1007
1008  """
1009  if is_sparse(tensor):
1010    return sparse_ops.sparse_tensor_to_dense(tensor)
1011  else:
1012    return tensor
1013
1014
1015@keras_export('keras.backend.name_scope', v1=[])
1016@doc_controls.do_not_generate_docs
1017def name_scope(name):
1018  """A context manager for use when defining a Python op.
1019
1020  This context manager pushes a name scope, which will make the name of all
1021  operations added within it have a prefix.
1022
1023  For example, to define a new Python op called `my_op`:
1024
1025  ```python
1026  def my_op(a):
1027    with tf.name_scope("MyOp") as scope:
1028      a = tf.convert_to_tensor(a, name="a")
1029      # Define some computation that uses `a`.
1030      return foo_op(..., name=scope)
1031  ```
1032
1033  When executed, the Tensor `a` will have the name `MyOp/a`.
1034
1035  Args:
1036    name: The prefix to use on all names created within the name scope.
1037
1038  Returns:
1039    Name scope context manager.
1040  """
1041  return ops.name_scope_v2(name)
1042
1043# Export V1 version.
1044_v1_name_scope = ops.name_scope_v1
1045keras_export(v1=['keras.backend.name_scope'])(_v1_name_scope)
1046
1047
1048@keras_export('keras.backend.variable')
1049@doc_controls.do_not_generate_docs
1050def variable(value, dtype=None, name=None, constraint=None):
1051  """Instantiates a variable and returns it.
1052
1053  Args:
1054      value: Numpy array, initial value of the tensor.
1055      dtype: Tensor type.
1056      name: Optional name string for the tensor.
1057      constraint: Optional projection function to be
1058          applied to the variable after an optimizer update.
1059
1060  Returns:
1061      A variable instance (with Keras metadata included).
1062
1063  Examples:
1064
1065  >>> val = np.array([[1, 2], [3, 4]])
1066  >>> kvar = tf.keras.backend.variable(value=val, dtype='float64',
1067  ...                                  name='example_var')
1068  >>> tf.keras.backend.dtype(kvar)
1069  'float64'
1070  >>> print(kvar)
1071  <tf.Variable 'example_var:...' shape=(2, 2) dtype=float64, numpy=
1072    array([[1., 2.],
1073           [3., 4.]])>
1074
1075  """
1076  if dtype is None:
1077    dtype = floatx()
1078  if hasattr(value, 'tocoo'):
1079    sparse_coo = value.tocoo()
1080    indices = np.concatenate((np.expand_dims(sparse_coo.row, 1), np.expand_dims(
1081        sparse_coo.col, 1)), 1)
1082    v = sparse_tensor.SparseTensor(
1083        indices=indices, values=sparse_coo.data, dense_shape=sparse_coo.shape)
1084    v._keras_shape = sparse_coo.shape
1085    return v
1086  v = variables_module.Variable(
1087      value,
1088      dtype=dtypes_module.as_dtype(dtype),
1089      name=name,
1090      constraint=constraint)
1091  if isinstance(value, np.ndarray):
1092    v._keras_shape = value.shape
1093  elif hasattr(value, 'shape'):
1094    v._keras_shape = int_shape(value)
1095  track_variable(v)
1096  return v
1097
1098
1099def track_tf_optimizer(tf_optimizer):
1100  """Tracks the given TF optimizer for initialization of its variables."""
1101  if context.executing_eagerly():
1102    return
1103  optimizers = _GRAPH_TF_OPTIMIZERS[None]
1104  optimizers.add(tf_optimizer)
1105
1106
1107@keras_export('keras.__internal__.backend.track_variable', v1=[])
1108def track_variable(v):
1109  """Tracks the given variable for initialization."""
1110  if context.executing_eagerly():
1111    return
1112  graph = v.graph if hasattr(v, 'graph') else get_graph()
1113  _GRAPH_VARIABLES[graph].add(v)
1114
1115
1116def observe_object_name(name):
1117  """Observe a name and make sure it won't be used by `unique_object_name`."""
1118  OBSERVED_NAMES.add(name)
1119
1120
1121def unique_object_name(name,
1122                       name_uid_map=None,
1123                       avoid_names=None,
1124                       namespace='',
1125                       zero_based=False,
1126                       avoid_observed_names=False):
1127  """Makes a object name (or arbitrary string) unique within a TensorFlow graph.
1128
1129  Args:
1130    name: String name to make unique.
1131    name_uid_map: An optional defaultdict(int) to use when creating unique
1132      names. If None (default), uses a per-Graph dictionary.
1133    avoid_names: An optional set or dict with names which should not be used. If
1134      None (default), don't avoid any names unless `avoid_observed_names` is
1135      True.
1136    namespace: Gets a name which is unique within the (graph, namespace). Layers
1137      which are not Networks use a blank namespace and so get graph-global
1138      names.
1139    zero_based: If True, name sequences start with no suffix (e.g. "dense",
1140      "dense_1"). If False, naming is one-based ("dense_1", "dense_2").
1141    avoid_observed_names: If True, avoid any names that have been observed by
1142      `backend.observe_object_name`.
1143
1144  Returns:
1145    Unique string name.
1146
1147  Example:
1148
1149  ```python
1150  unique_object_name('dense')  # dense_1
1151  unique_object_name('dense')  # dense_2
1152  ```
1153  """
1154  if name_uid_map is None:
1155    name_uid_map = get_default_graph_uid_map()
1156  if avoid_names is None:
1157    if avoid_observed_names:
1158      avoid_names = OBSERVED_NAMES
1159    else:
1160      avoid_names = set()
1161  proposed_name = None
1162  while proposed_name is None or proposed_name in avoid_names:
1163    name_key = (namespace, name)
1164    if zero_based:
1165      number = name_uid_map[name_key]
1166      if number:
1167        proposed_name = name + '_' + str(number)
1168      else:
1169        proposed_name = name
1170      name_uid_map[name_key] += 1
1171    else:
1172      name_uid_map[name_key] += 1
1173      proposed_name = name + '_' + str(name_uid_map[name_key])
1174  return proposed_name
1175
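# A hedged sketch (illustration only) of the naming scheme documented above,
# using a throwaway defaultdict so the per-graph UID state is not touched.
# The helper name `_example_unique_object_name` is hypothetical.
def _example_unique_object_name():
  uids = collections.defaultdict(int)
  # One-based (default): even the first name gets a numeric suffix.
  assert unique_object_name('dense', name_uid_map=uids) == 'dense_1'
  assert unique_object_name('dense', name_uid_map=uids) == 'dense_2'
  # Zero-based: the first occurrence keeps the bare name.
  zuids = collections.defaultdict(int)
  assert unique_object_name(
      'dense', name_uid_map=zuids, zero_based=True) == 'dense'
  assert unique_object_name(
      'dense', name_uid_map=zuids, zero_based=True) == 'dense_1'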
1176
1177def _get_variables(graph=None):
1178  """Returns variables corresponding to the given graph for initialization."""
1179  assert not context.executing_eagerly()
1180  variables = _GRAPH_VARIABLES[graph]
1181  for opt in _GRAPH_TF_OPTIMIZERS[graph]:
1182    variables.update(opt.optimizer.variables())
1183  return variables
1184
1185
1186@keras_export('keras.__internal__.backend.initialize_variables', v1=[])
1187def _initialize_variables(session):
1188  """Utility to initialize uninitialized variables on the fly."""
1189  variables = _get_variables(get_graph())
1190  candidate_vars = []
1191  for v in variables:
1192    if not getattr(v, '_keras_initialized', False):
1193      candidate_vars.append(v)
1194  if candidate_vars:
1195    # This step is expensive, so we only run it on variables not already
1196    # marked as initialized.
1197    is_initialized = session.run(
1198        [variables_module.is_variable_initialized(v) for v in candidate_vars])
1199    # TODO(kathywu): Some metric variables loaded from SavedModel are never
1200    # actually used, and do not have an initializer.
1201    should_be_initialized = [
1202        (not is_initialized[n]) and v.initializer is not None
1203        for n, v in enumerate(candidate_vars)]
1204    uninitialized_vars = []
1205    for flag, v in zip(should_be_initialized, candidate_vars):
1206      if flag:
1207        uninitialized_vars.append(v)
1208      v._keras_initialized = True
1209    if uninitialized_vars:
1210      session.run(variables_module.variables_initializer(uninitialized_vars))
1211
1212
1213@keras_export('keras.backend.constant')
1214@dispatch.add_dispatch_support
1215@doc_controls.do_not_generate_docs
1216def constant(value, dtype=None, shape=None, name=None):
1217  """Creates a constant tensor.
1218
1219  Args:
1220      value: A constant value (or list)
1221      dtype: The type of the elements of the resulting tensor.
1222      shape: Optional dimensions of resulting tensor.
1223      name: Optional name for the tensor.
1224
1225  Returns:
1226      A Constant Tensor.
1227  """
1228  if dtype is None:
1229    dtype = floatx()
1230
1231  return constant_op.constant(value, dtype=dtype, shape=shape, name=name)
1232
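# A hedged sketch (illustration only): `constant` falls back to the Keras
# default float type when `dtype` is not given. The helper name
# `_example_constant` is hypothetical.
def _example_constant():
  c = constant([1, 2, 3])
  assert dtype(c) == floatx()            # 'float32' unless the user changed it
  c64 = constant([1, 2, 3], dtype='float64')
  assert dtype(c64) == 'float64'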
1233
1234@keras_export('keras.backend.is_keras_tensor')
1235def is_keras_tensor(x):
1236  """Returns whether `x` is a Keras tensor.
1237
1238  A "Keras tensor" is a tensor that was returned by a Keras layer,
1239  (`Layer` class) or by `Input`.
1240
1241  Args:
1242      x: A candidate tensor.
1243
1244  Returns:
1245      A boolean: Whether the argument is a Keras tensor.
1246
1247  Raises:
1248      ValueError: In case `x` is not a symbolic tensor.
1249
1250  Examples:
1251
1252  >>> np_var = np.array([1, 2])
1253  >>> # A numpy array is not a symbolic tensor.
1254  >>> tf.keras.backend.is_keras_tensor(np_var)
1255  Traceback (most recent call last):
1256  ...
1257  ValueError: Unexpectedly found an instance of type `<class 'numpy.ndarray'>`.
1258  Expected a symbolic tensor instance.
1259  >>> keras_var = tf.keras.backend.variable(np_var)
1260  >>> # A variable created with the keras backend is not a Keras tensor.
1261  >>> tf.keras.backend.is_keras_tensor(keras_var)
1262  False
1263  >>> keras_placeholder = tf.keras.backend.placeholder(shape=(2, 4, 5))
1264  >>> # A placeholder is a Keras tensor.
1265  >>> tf.keras.backend.is_keras_tensor(keras_placeholder)
1266  True
1267  >>> keras_input = tf.keras.layers.Input([10])
1268  >>> # An Input is a Keras tensor.
1269  >>> tf.keras.backend.is_keras_tensor(keras_input)
1270  True
1271  >>> keras_layer_output = tf.keras.layers.Dense(10)(keras_input)
1272  >>> # Any Keras layer output is a Keras tensor.
1273  >>> tf.keras.backend.is_keras_tensor(keras_layer_output)
1274  True
1275
1276  """
1277  if not isinstance(x,
1278                    (ops.Tensor, variables_module.Variable,
1279                     sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor,
1280                     keras_tensor.KerasTensor)):
1281    raise ValueError('Unexpectedly found an instance of type `' + str(type(x)) +
1282                     '`. Expected a symbolic tensor instance.')
1283  if ops.executing_eagerly_outside_functions():
1284    return isinstance(x, keras_tensor.KerasTensor)
1285  return hasattr(x, '_keras_history')
1286
1287
1288@keras_export('keras.backend.placeholder')
1289@doc_controls.do_not_generate_docs
1290def placeholder(shape=None,
1291                ndim=None,
1292                dtype=None,
1293                sparse=False,
1294                name=None,
1295                ragged=False):
1296  """Instantiates a placeholder tensor and returns it.
1297
1298  Args:
1299      shape: Shape of the placeholder
1300          (integer tuple, may include `None` entries).
1301      ndim: Number of axes of the tensor.
1302          At least one of {`shape`, `ndim`} must be specified.
1303          If both are specified, `shape` is used.
1304      dtype: Placeholder type.
1305      sparse: Boolean, whether the placeholder should have a sparse type.
1306      name: Optional name string for the placeholder.
1307      ragged: Boolean, whether the placeholder should have a ragged type.
1308          In this case, values of 'None' in the 'shape' argument represent
1309          ragged dimensions. For more information about RaggedTensors, see this
1310          [guide](https://www.tensorflow.org/guide/ragged_tensors).
1311
1312  Raises:
1313      ValueError: If called with sparse = True and ragged = True.
1314
1315  Returns:
1316      Tensor instance (with Keras metadata included).
1317
1318  Examples:
1319
1320
1321  >>> input_ph = tf.keras.backend.placeholder(shape=(2, 4, 5))
1322  >>> input_ph
1323  <KerasTensor: shape=(2, 4, 5) dtype=float32 (created by layer ...)>
1324
1325  """
1326  if sparse and ragged:
1327    raise ValueError(
1328        'Cannot set both sparse and ragged to True when creating a placeholder.'
1329    )
1330  if dtype is None:
1331    dtype = floatx()
1332  if not shape:
1333    if ndim:
1334      shape = (None,) * ndim
1335  if ops.executing_eagerly_outside_functions():
1336    if sparse:
1337      spec = sparse_tensor.SparseTensorSpec(
1338          shape=shape, dtype=dtype)
1339    elif ragged:
1340      ragged_rank = 0
1341      for i in range(1, len(shape)):
1342        # `shape` may be a TensorShape or a plain tuple, so check both
1343        # dimension representations when looking for unknown (None) sizes.
1344        if shape[i] is None or (
1345            hasattr(shape[i], 'value') and
1346            shape[i].value is None):
1347          ragged_rank = i
1348      spec = ragged_tensor.RaggedTensorSpec(
1349          shape=shape, dtype=dtype, ragged_rank=ragged_rank)
1350    else:
1351      spec = tensor_spec.TensorSpec(
1352          shape=shape, dtype=dtype, name=name)
1353    x = keras_tensor.keras_tensor_from_type_spec(spec, name=name)
1354  else:
1355    with get_graph().as_default():
1356      if sparse:
1357        x = array_ops.sparse_placeholder(dtype, shape=shape, name=name)
1358      elif ragged:
1359        ragged_rank = 0
1360        for i in range(1, len(shape)):
1361          if shape[i] is None:
1362            ragged_rank = i
1363        type_spec = ragged_tensor.RaggedTensorSpec(
1364            shape=shape, dtype=dtype, ragged_rank=ragged_rank)
1365        def tensor_spec_to_placeholder(tensorspec):
1366          return array_ops.placeholder(tensorspec.dtype, tensorspec.shape)
1367        x = nest.map_structure(tensor_spec_to_placeholder, type_spec,
1368                               expand_composites=True)
1369      else:
1370        x = array_ops.placeholder(dtype, shape=shape, name=name)
1371
1372  if context.executing_eagerly():
1373    # Add keras_history connectivity information to the placeholder
1374    # when the placeholder is built in a top-level eager context
1375    # (intended to be used with keras.backend.function)
1376    from tensorflow.python.keras.engine import input_layer  # pylint: disable=g-import-not-at-top
1377    x = input_layer.Input(tensor=x)
1378    x._is_backend_placeholder = True
1379
1380  return x
1381
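# A hedged sketch (illustration only) of the `sparse` variant documented
# above. The helper name `_example_sparse_placeholder` is hypothetical.
def _example_sparse_placeholder():
  sp = placeholder(shape=(None, 10), sparse=True)
  assert is_sparse(sp)
  assert not is_sparse(placeholder(shape=(None, 10)))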
1382
1383def is_placeholder(x):
1384  """Returns whether `x` is a placeholder.
1385
1386  Args:
1387      x: A candidate placeholder.
1388
1389  Returns:
1390      Boolean.
1391  """
1392  try:
1393    if ops.executing_eagerly_outside_functions():
1394      return hasattr(x, '_is_backend_placeholder')
1395    from tensorflow.python.keras.utils import tf_utils  # pylint: disable=g-import-not-at-top
1396    if tf_utils.is_extension_type(x):
1397      flat_components = nest.flatten(x, expand_composites=True)
1398      return py_any(is_placeholder(c) for c in flat_components)
1399    else:
1400      return x.op.type == 'Placeholder'
1401  except AttributeError:
1402    return False
1403
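# A hedged sketch (illustration only): placeholders created through
# `placeholder()` are recognized, ordinary constants are not. The helper name
# `_example_is_placeholder` is hypothetical.
def _example_is_placeholder():
  x = placeholder(shape=(2, 3))
  assert is_placeholder(x)
  assert not is_placeholder(constant_op.constant([1.0, 2.0]))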
1404
1405@keras_export('keras.backend.shape')
1406@dispatch.add_dispatch_support
1407@doc_controls.do_not_generate_docs
1408def shape(x):
1409  """Returns the symbolic shape of a tensor or variable.
1410
1411  Args:
1412      x: A tensor or variable.
1413
1414  Returns:
1415      A symbolic shape (which is itself a tensor).
1416
1417  Examples:
1418
1419  >>> val = np.array([[1, 2], [3, 4]])
1420  >>> kvar = tf.keras.backend.variable(value=val)
1421  >>> tf.keras.backend.shape(kvar)
1422  <tf.Tensor: shape=(2,), dtype=int32, numpy=array([2, 2], dtype=int32)>
1423  >>> input = tf.keras.backend.placeholder(shape=(2, 4, 5))
1424  >>> tf.keras.backend.shape(input)
1425  <KerasTensor: shape=(3,) dtype=int32 inferred_value=[2, 4, 5] ...>
1426
1427  """
1428  return array_ops.shape(x)
1429
1430
1431@keras_export('keras.backend.int_shape')
1432@doc_controls.do_not_generate_docs
1433def int_shape(x):
1434  """Returns the shape of tensor or variable as a tuple of int or None entries.
1435
1436  Args:
1437      x: Tensor or variable.
1438
1439  Returns:
1440      A tuple of integers (or None entries).
1441
1442  Examples:
1443
1444  >>> input = tf.keras.backend.placeholder(shape=(2, 4, 5))
1445  >>> tf.keras.backend.int_shape(input)
1446  (2, 4, 5)
1447  >>> val = np.array([[1, 2], [3, 4]])
1448  >>> kvar = tf.keras.backend.variable(value=val)
1449  >>> tf.keras.backend.int_shape(kvar)
1450  (2, 2)
1451
1452  """
1453  try:
1454    shape = x.shape
1455    if not isinstance(shape, tuple):
1456      shape = tuple(shape.as_list())
1457    return shape
1458  except ValueError:
1459    return None
1460
1461
1462@keras_export('keras.backend.ndim')
1463@doc_controls.do_not_generate_docs
1464def ndim(x):
1465  """Returns the number of axes in a tensor, as an integer.
1466
1467  Args:
1468      x: Tensor or variable.
1469
1470  Returns:
1471      Integer (scalar), number of axes.
1472
1473  Examples:
1474
1475
1476  >>> input = tf.keras.backend.placeholder(shape=(2, 4, 5))
1477  >>> val = np.array([[1, 2], [3, 4]])
1478  >>> kvar = tf.keras.backend.variable(value=val)
1479  >>> tf.keras.backend.ndim(input)
1480  3
1481  >>> tf.keras.backend.ndim(kvar)
1482  2
1483
1484  """
1485  return x.shape.rank
1486
1487
1488@keras_export('keras.backend.dtype')
1489@dispatch.add_dispatch_support
1490@doc_controls.do_not_generate_docs
1491def dtype(x):
1492  """Returns the dtype of a Keras tensor or variable, as a string.
1493
1494  Args:
1495      x: Tensor or variable.
1496
1497  Returns:
1498      String, dtype of `x`.
1499
1500  Examples:
1501
1502  >>> tf.keras.backend.dtype(tf.keras.backend.placeholder(shape=(2,4,5)))
1503  'float32'
1504  >>> tf.keras.backend.dtype(tf.keras.backend.placeholder(shape=(2,4,5),
1505  ...                                                     dtype='float32'))
1506  'float32'
1507  >>> tf.keras.backend.dtype(tf.keras.backend.placeholder(shape=(2,4,5),
1508  ...                                                     dtype='float64'))
1509  'float64'
1510  >>> kvar = tf.keras.backend.variable(np.array([[1, 2], [3, 4]]))
1511  >>> tf.keras.backend.dtype(kvar)
1512  'float32'
1513  >>> kvar = tf.keras.backend.variable(np.array([[1, 2], [3, 4]]),
1514  ...                                  dtype='float32')
1515  >>> tf.keras.backend.dtype(kvar)
1516  'float32'
1517
1518  """
1519  return x.dtype.base_dtype.name
1520
1521
1522@doc_controls.do_not_generate_docs
1523def dtype_numpy(x):
1524  """Returns the numpy dtype of a Keras tensor or variable.
1525
1526  Args:
1527      x: Tensor or variable.
1528
1529  Returns:
1530      numpy.dtype, dtype of `x`.
1531  """
1532  return dtypes_module.as_dtype(x.dtype).as_numpy_dtype
1533
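# A hedged sketch (illustration only): `dtype` returns the dtype name as a
# string, while `dtype_numpy` returns the corresponding numpy dtype. The
# helper name `_example_dtype_numpy` is hypothetical.
def _example_dtype_numpy():
  kvar = variable(np.array([[1.0, 2.0]]), dtype='float32')
  assert dtype(kvar) == 'float32'
  assert dtype_numpy(kvar) == np.float32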
1534
1535@keras_export('keras.backend.eval')
1536@doc_controls.do_not_generate_docs
1537def eval(x):
1538  """Evaluates the value of a variable.
1539
1540  Args:
1541      x: A variable.
1542
1543  Returns:
1544      A Numpy array.
1545
1546  Examples:
1547
1548  >>> kvar = tf.keras.backend.variable(np.array([[1, 2], [3, 4]]),
1549  ...                                  dtype='float32')
1550  >>> tf.keras.backend.eval(kvar)
1551  array([[1.,  2.],
1552         [3.,  4.]], dtype=float32)
1553
1554  """
1555  return get_value(to_dense(x))
1556
1557
1558@keras_export('keras.backend.zeros')
1559@doc_controls.do_not_generate_docs
1560def zeros(shape, dtype=None, name=None):
1561  """Instantiates an all-zeros variable and returns it.
1562
1563  Args:
1564      shape: Tuple or list of integers, shape of returned Keras variable
1565      dtype: data type of returned Keras variable
1566      name: name of returned Keras variable
1567
1568  Returns:
1569      A variable (including Keras metadata), filled with `0.0`.
1570      Note that if `shape` was symbolic, we cannot return a variable,
1571      and will return a dynamically-shaped tensor instead.
1572
1573  Example:
1574
1575  >>> kvar = tf.keras.backend.zeros((3,4))
1576  >>> tf.keras.backend.eval(kvar)
1577  array([[0.,  0.,  0.,  0.],
1578         [0.,  0.,  0.,  0.],
1579         [0.,  0.,  0.,  0.]], dtype=float32)
1580  >>> A = tf.constant([1,2,3])
1581  >>> kvar2 = tf.keras.backend.zeros(A.shape) # [0., 0., 0.]
1582  >>> tf.keras.backend.eval(kvar2)
1583  array([0., 0., 0.], dtype=float32)
1584  >>> kvar3 = tf.keras.backend.zeros(A.shape,dtype=tf.int32)
1585  >>> tf.keras.backend.eval(kvar3)
1586  array([0, 0, 0], dtype=int32)
1587  >>> kvar4 = tf.keras.backend.zeros([2,3])
1588  >>> tf.keras.backend.eval(kvar4)
1589  array([[0., 0., 0.],
1590         [0., 0., 0.]], dtype=float32)
1591
1592  """
1593  with ops.init_scope():
1594    if dtype is None:
1595      dtype = floatx()
1596    tf_dtype = dtypes_module.as_dtype(dtype)
1597    v = array_ops.zeros(shape=shape, dtype=tf_dtype, name=name)
1598    if py_all(v.shape.as_list()):
1599      return variable(v, dtype=dtype, name=name)
1600    return v
1601
1602
1603@keras_export('keras.backend.ones')
1604@dispatch.add_dispatch_support
1605@doc_controls.do_not_generate_docs
1606def ones(shape, dtype=None, name=None):
1607  """Instantiates an all-ones variable and returns it.
1608
1609  Args:
1610      shape: Tuple of integers, shape of returned Keras variable.
1611      dtype: String, data type of returned Keras variable.
1612      name: String, name of returned Keras variable.
1613
1614  Returns:
1615      A Keras variable, filled with `1.0`.
1616      Note that if `shape` was symbolic, we cannot return a variable,
1617      and will return a dynamically-shaped tensor instead.
1618
1619  Example:
1620
1621
1622  >>> kvar = tf.keras.backend.ones((3,4))
1623  >>> tf.keras.backend.eval(kvar)
1624  array([[1.,  1.,  1.,  1.],
1625         [1.,  1.,  1.,  1.],
1626         [1.,  1.,  1.,  1.]], dtype=float32)
1627
1628  """
1629  with ops.init_scope():
1630    if dtype is None:
1631      dtype = floatx()
1632    tf_dtype = dtypes_module.as_dtype(dtype)
1633    v = array_ops.ones(shape=shape, dtype=tf_dtype, name=name)
1634    if py_all(v.shape.as_list()):
1635      return variable(v, dtype=dtype, name=name)
1636    return v
1637
1638
1639@keras_export('keras.backend.eye')
1640@dispatch.add_dispatch_support
1641@doc_controls.do_not_generate_docs
1642def eye(size, dtype=None, name=None):
1643  """Instantiate an identity matrix and returns it.
1644
1645  Args:
1646      size: Integer, number of rows/columns.
1647      dtype: String, data type of returned Keras variable.
1648      name: String, name of returned Keras variable.
1649
1650  Returns:
1651      A Keras variable, an identity matrix.
1652
1653  Example:
1654
1655
1656  >>> kvar = tf.keras.backend.eye(3)
1657  >>> tf.keras.backend.eval(kvar)
1658  array([[1.,  0.,  0.],
1659         [0.,  1.,  0.],
1660         [0.,  0.,  1.]], dtype=float32)
1661
1662
1663  """
1664  if dtype is None:
1665    dtype = floatx()
1666  tf_dtype = dtypes_module.as_dtype(dtype)
1667  return variable(linalg_ops.eye(size, dtype=tf_dtype), dtype, name)
1668
1669
1670@keras_export('keras.backend.zeros_like')
1671@doc_controls.do_not_generate_docs
1672def zeros_like(x, dtype=None, name=None):
1673  """Instantiates an all-zeros variable of the same shape as another tensor.
1674
1675  Args:
1676      x: Keras variable or Keras tensor.
1677      dtype: dtype of returned Keras variable.
1678             `None` uses the dtype of `x`.
1679      name: name for the variable to create.
1680
1681  Returns:
1682      A Keras tensor with the shape of `x` filled with zeros.
1683
1684  Example:
1685
1686  ```python
1687  from tensorflow.keras import backend as K
1688  kvar = K.variable(np.random.random((2,3)))
1689  kvar_zeros = K.zeros_like(kvar)
1690  K.eval(kvar_zeros)
1691  # array([[ 0.,  0.,  0.], [ 0.,  0.,  0.]], dtype=float32)
1692  ```
1693  """
1694  return array_ops.zeros_like(x, dtype=dtype, name=name)
1695
1696
1697@keras_export('keras.backend.ones_like')
1698@dispatch.add_dispatch_support
1699@doc_controls.do_not_generate_docs
1700def ones_like(x, dtype=None, name=None):
1701  """Instantiates an all-ones variable of the same shape as another tensor.
1702
1703  Args:
1704      x: Keras variable or tensor.
1705      dtype: String, dtype of returned Keras variable.
1706           None uses the dtype of x.
1707      name: String, name for the variable to create.
1708
1709  Returns:
1710      A Keras tensor with the shape of `x` filled with ones.
1711
1712  Example:
1713
1714  >>> kvar = tf.keras.backend.variable(np.random.random((2,3)))
1715  >>> kvar_ones = tf.keras.backend.ones_like(kvar)
1716  >>> tf.keras.backend.eval(kvar_ones)
1717  array([[1.,  1.,  1.],
1718         [1.,  1.,  1.]], dtype=float32)
1719
1720  """
1721  return array_ops.ones_like(x, dtype=dtype, name=name)
1722
1723
1724def identity(x, name=None):
1725  """Returns a tensor with the same content as the input tensor.
1726
1727  Args:
1728      x: The input tensor.
1729      name: String, name for the variable to create.
1730
1731  Returns:
1732      A tensor of the same shape, type and content.
1733  """
1734  return array_ops.identity(x, name=name)
1735
1736
1737@keras_export('keras.backend.random_uniform_variable')
1738@doc_controls.do_not_generate_docs
1739def random_uniform_variable(shape, low, high, dtype=None, name=None, seed=None):
1740  """Instantiates a variable with values drawn from a uniform distribution.
1741
1742  Args:
1743      shape: Tuple of integers, shape of returned Keras variable.
1744      low: Float, lower boundary of the output interval.
1745      high: Float, upper boundary of the output interval.
1746      dtype: String, dtype of returned Keras variable.
1747      name: String, name of returned Keras variable.
1748      seed: Integer, random seed.
1749
1750  Returns:
1751      A Keras variable, filled with drawn samples.
1752
1753  Example:
1754
1755  >>> kvar = tf.keras.backend.random_uniform_variable(shape=(2,3),
1756  ... low=0.0, high=1.0)
1757  >>> kvar
1758  <tf.Variable 'Variable:0' shape=(2, 3) dtype=float32, numpy=...,
1759  dtype=float32)>
1760  """
1761  if dtype is None:
1762    dtype = floatx()
1763  tf_dtype = dtypes_module.as_dtype(dtype)
1764  if seed is None:
1765    # ensure that randomness is conditioned by the Numpy RNG
1766    seed = np.random.randint(10e8)
1767  value = init_ops.random_uniform_initializer(
1768      low, high, dtype=tf_dtype, seed=seed)(shape)
1769  return variable(value, dtype=dtype, name=name)
1770
1771
1772@keras_export('keras.backend.random_normal_variable')
1773@doc_controls.do_not_generate_docs
1774def random_normal_variable(shape, mean, scale, dtype=None, name=None,
1775                           seed=None):
1776  """Instantiates a variable with values drawn from a normal distribution.
1777
1778  Args:
1779      shape: Tuple of integers, shape of returned Keras variable.
1780      mean: Float, mean of the normal distribution.
1781      scale: Float, standard deviation of the normal distribution.
1782      dtype: String, dtype of returned Keras variable.
1783      name: String, name of returned Keras variable.
1784      seed: Integer, random seed.
1785
1786  Returns:
1787      A Keras variable, filled with drawn samples.
1788
1789  Example:
1790
1791  >>> kvar = tf.keras.backend.random_normal_variable(shape=(2,3),
1792  ... mean=0.0, scale=1.0)
1793  >>> kvar
1794  <tf.Variable 'Variable:0' shape=(2, 3) dtype=float32, numpy=...,
1795  dtype=float32)>
1796  """
1797  if dtype is None:
1798    dtype = floatx()
1799  tf_dtype = dtypes_module.as_dtype(dtype)
1800  if seed is None:
1801    # ensure that randomness is conditioned by the Numpy RNG
1802    seed = np.random.randint(10e8)
1803  value = init_ops.random_normal_initializer(
1804      mean, scale, dtype=tf_dtype, seed=seed)(shape)
1805  return variable(value, dtype=dtype, name=name)
1806
1807
1808@keras_export('keras.backend.count_params')
1809@doc_controls.do_not_generate_docs
1810def count_params(x):
1811  """Returns the static number of elements in a variable or tensor.
1812
1813  Args:
1814      x: Variable or tensor.
1815
1816  Returns:
1817      Integer, the number of scalars in `x`.
1818
1819  Example:
1820
1821  >>> kvar = tf.keras.backend.zeros((2,3))
1822  >>> tf.keras.backend.count_params(kvar)
1823  6
1824  >>> tf.keras.backend.eval(kvar)
1825  array([[0.,  0.,  0.],
1826         [0.,  0.,  0.]], dtype=float32)
1827
1828  """
1829  return np.prod(x.shape.as_list())
1830
1831
1832@keras_export('keras.backend.cast')
1833@dispatch.add_dispatch_support
1834@doc_controls.do_not_generate_docs
1835def cast(x, dtype):
1836  """Casts a tensor to a different dtype and returns it.
1837
1838  You can cast a Keras variable, but the result is a Keras tensor rather
1838  than a variable.
1839
1840  Args:
1841      x: Keras tensor (or variable).
1842      dtype: String, either (`'float16'`, `'float32'`, or `'float64'`).
1843
1844  Returns:
1845      Keras tensor with dtype `dtype`.
1846
1847  Examples:
1848      Cast a float32 variable to a float64 tensor
1849
1850  >>> input = tf.keras.backend.ones(shape=(1,3))
1851  >>> print(input)
1852  <tf.Variable 'Variable:0' shape=(1, 3) dtype=float32,
1853  numpy=array([[1., 1., 1.]], dtype=float32)>
1854  >>> cast_input = tf.keras.backend.cast(input, dtype='float64')
1855  >>> print(cast_input)
1856  tf.Tensor([[1. 1. 1.]], shape=(1, 3), dtype=float64)
1857
1858  """
1859  return math_ops.cast(x, dtype)
1860
1861
1862# UPDATES OPS
1863
1864
1865@keras_export('keras.backend.update')
1866@doc_controls.do_not_generate_docs
1867def update(x, new_x):
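  """Update the value of `x` to `new_x`.

  Args:
      x: A Variable.
      new_x: A tensor of same shape as `x`.

  Returns:
      The variable `x` updated.
  """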
1868  return state_ops.assign(x, new_x)
1869
1870
1871@keras_export('keras.backend.update_add')
1872@doc_controls.do_not_generate_docs
1873def update_add(x, increment):
1874  """Update the value of `x` by adding `increment`.
1875
1876  Args:
1877      x: A Variable.
1878      increment: A tensor of same shape as `x`.
1879
1880  Returns:
1881      The variable `x` updated.
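
  Example (illustrative; output assumes eager execution):

  >>> x = tf.Variable([1., 2.])
  >>> _ = tf.keras.backend.update_add(x, [0.5, 0.5])
  >>> x.numpy()
  array([1.5, 2.5], dtype=float32)
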
1882  """
1883  return state_ops.assign_add(x, increment)
1884
1885
1886@keras_export('keras.backend.update_sub')
1887@doc_controls.do_not_generate_docs
1888def update_sub(x, decrement):
1889  """Update the value of `x` by subtracting `decrement`.
1890
1891  Args:
1892      x: A Variable.
1893      decrement: A tensor of same shape as `x`.
1894
1895  Returns:
1896      The variable `x` updated.
1897  """
1898  return state_ops.assign_sub(x, decrement)
1899
1900
1901@keras_export('keras.backend.moving_average_update')
1902@doc_controls.do_not_generate_docs
1903def moving_average_update(x, value, momentum):
1904  """Compute the exponential moving average of a value.
1905
1906  The moving average `x` is updated with `value` following:
1907
1908  ```
1909  x = x * momentum + value * (1 - momentum)
1910  ```
1911
1912  For example:
1913
1914  >>> x = tf.Variable(0.0)
1915  >>> momentum = 0.9
1916  >>> _ = moving_average_update(x, value=2.0, momentum=momentum)
1917  >>> x.numpy()
1918  0.2
1919
1920  The result will be biased towards the initial value of the variable.
1921
1922  If the variable was initialized to zero, you can divide by
1923  `1 - momentum ** num_updates` to debias it (Section 3 of
1924  [Kingma et al., 2015](https://arxiv.org/abs/1412.6980)):
1925
1926  >>> num_updates = 1.0
1927  >>> x_zdb = x/(1 - momentum**num_updates)
1928  >>> x_zdb.numpy()
1929  2.0
1930
1931  Args:
1932      x: A Variable, the moving average.
1933      value: A tensor with the same shape as `x`, the new value to be
1934        averaged in.
1935      momentum: The moving average momentum.
1936
1937  Returns:
1938      The updated variable.
1939  """
1940  if tf2.enabled():
1941    momentum = math_ops.cast(momentum, x.dtype)
1942    value = math_ops.cast(value, x.dtype)
1943    return x.assign(x * momentum + value * (1 - momentum))
1944  else:
1945    return moving_averages.assign_moving_average(
1946        x, value, momentum, zero_debias=True)
1947
1948
1949# LINEAR ALGEBRA
1950
1951
1952@keras_export('keras.backend.dot')
1953@dispatch.add_dispatch_support
1954@doc_controls.do_not_generate_docs
1955def dot(x, y):
1956  """Multiplies 2 tensors (and/or variables) and returns a tensor.
1957
1958  This operation corresponds to `numpy.dot(a, b, out=None)`.
1959
1960  Args:
1961      x: Tensor or variable.
1962      y: Tensor or variable.
1963
1964  Returns:
1965      A tensor, dot product of `x` and `y`.
1966
1967  Examples:
1968
1969  If inputs `x` and `y` are 2-D arrays, then it is equivalent to `tf.matmul`.
1970  >>> x = tf.keras.backend.placeholder(shape=(2, 3))
1971  >>> y = tf.keras.backend.placeholder(shape=(3, 4))
1972  >>> xy = tf.keras.backend.dot(x, y)
1973  >>> xy
1974  <KerasTensor: shape=(2, 4) dtype=float32 ...>
1975
1976  >>> x = tf.keras.backend.placeholder(shape=(32, 28, 3))
1977  >>> y = tf.keras.backend.placeholder(shape=(3, 4))
1978  >>> xy = tf.keras.backend.dot(x, y)
1979  >>> xy
1980  <KerasTensor: shape=(32, 28, 4) dtype=float32 ...>
1981
1982  If `x` is an N-D array and `y` is an M-D array (where M>=2), it is a sum
1983  product over the last axis of `x` and the second-to-last axis of `y`.
1984  >>> x = tf.keras.backend.random_uniform_variable(shape=(2, 3), low=0, high=1)
1985  >>> y = tf.keras.backend.ones((4, 3, 5))
1986  >>> xy = tf.keras.backend.dot(x, y)
1987  >>> tf.keras.backend.int_shape(xy)
1988  (2, 4, 5)
1989  """
1990  if ndim(x) is not None and (ndim(x) > 2 or ndim(y) > 2):
1991    x_shape = []
1992    for i, s in zip(int_shape(x), array_ops.unstack(array_ops.shape(x))):
1993      if i is not None:
1994        x_shape.append(i)
1995      else:
1996        x_shape.append(s)
1997    x_shape = tuple(x_shape)
1998    y_shape = []
1999    for i, s in zip(int_shape(y), array_ops.unstack(array_ops.shape(y))):
2000      if i is not None:
2001        y_shape.append(i)
2002      else:
2003        y_shape.append(s)
2004    y_shape = tuple(y_shape)
2005    y_permute_dim = list(range(ndim(y)))
2006    y_permute_dim = [y_permute_dim.pop(-2)] + y_permute_dim
2007    xt = array_ops.reshape(x, [-1, x_shape[-1]])
2008    yt = array_ops.reshape(
2009        array_ops.transpose(y, perm=y_permute_dim), [y_shape[-2], -1])
2010    return array_ops.reshape(
2011        math_ops.matmul(xt, yt), x_shape[:-1] + y_shape[:-2] + y_shape[-1:])
2012  if is_sparse(x):
2013    out = sparse_ops.sparse_tensor_dense_matmul(x, y)
2014  else:
2015    out = math_ops.matmul(x, y)
2016  return out
2017
2018
2019@keras_export('keras.backend.batch_dot')
2020@dispatch.add_dispatch_support
2021@doc_controls.do_not_generate_docs
2022def batch_dot(x, y, axes=None):
2023  """Batchwise dot product.
2024
2025  `batch_dot` is used to compute dot product of `x` and `y` when
2026  `x` and `y` are data in batch, i.e. in a shape of
2027  `(batch_size, :)`.
2028  `batch_dot` results in a tensor or variable with fewer dimensions
2029  than the input. If the number of dimensions is reduced to 1,
2030  we use `expand_dims` to make sure that ndim is at least 2.
2031
2032  Args:
2033    x: Keras tensor or variable with `ndim >= 2`.
2034    y: Keras tensor or variable with `ndim >= 2`.
2035    axes: Tuple or list of integers with target dimensions, or single integer.
2036      The sizes of `x.shape[axes[0]]` and `y.shape[axes[1]]` should be equal.
2037
2038  Returns:
2039    A tensor with shape equal to the concatenation of `x`'s shape
2040    (less the dimension that was summed over) and `y`'s shape
2041    (less the batch dimension and the dimension that was summed over).
2042    If the final rank is 1, we reshape it to `(batch_size, 1)`.
2043
2044  Examples:
2045
2046  >>> x_batch = tf.keras.backend.ones(shape=(32, 20, 1))
2047  >>> y_batch = tf.keras.backend.ones(shape=(32, 30, 20))
2048  >>> xy_batch_dot = tf.keras.backend.batch_dot(x_batch, y_batch, axes=(1, 2))
2049  >>> tf.keras.backend.int_shape(xy_batch_dot)
2050  (32, 1, 30)
2051
2052  Shape inference:
2053    Let `x`'s shape be `(100, 20)` and `y`'s shape be `(100, 30, 20)`.
2054    If `axes` is (1, 2), to find the output shape of resultant tensor,
2055        loop through each dimension in `x`'s shape and `y`'s shape:
2056    * `x.shape[0]` : 100 : append to output shape
2057    * `x.shape[1]` : 20 : do not append to output shape,
2058        dimension 1 of `x` has been summed over. (`dot_axes[0]` = 1)
2059    * `y.shape[0]` : 100 : do not append to output shape,
2060        always ignore first dimension of `y`
2061    * `y.shape[1]` : 30 : append to output shape
2062    * `y.shape[2]` : 20 : do not append to output shape,
2063        dimension 2 of `y` has been summed over. (`dot_axes[1]` = 2)
2064    `output_shape` = `(100, 30)`
2065  """
2066  x_shape = int_shape(x)
2067  y_shape = int_shape(y)
2068
2069  x_ndim = len(x_shape)
2070  y_ndim = len(y_shape)
2071
2072  if x_ndim < 2 or y_ndim < 2:
2073    raise ValueError('Cannot do batch_dot on inputs '
2074                     'with rank < 2. '
2075                     'Received inputs with shapes ' +
2076                     str(x_shape) + ' and ' +
2077                     str(y_shape) + '.')
2078
2079  x_batch_size = x_shape[0]
2080  y_batch_size = y_shape[0]
2081
2082  if x_batch_size is not None and y_batch_size is not None:
2083    if x_batch_size != y_batch_size:
2084      raise ValueError('Cannot do batch_dot on inputs '
2085                       'with different batch sizes. '
2086                       'Received inputs with shapes ' +
2087                       str(x_shape) + ' and ' +
2088                       str(y_shape) + '.')
2089  if isinstance(axes, int):
2090    axes = [axes, axes]
2091
2092  if axes is None:
2093    if y_ndim == 2:
2094      axes = [x_ndim - 1, y_ndim - 1]
2095    else:
2096      axes = [x_ndim - 1, y_ndim - 2]
2097
2098  if py_any(isinstance(a, (list, tuple)) for a in axes):
2099    raise ValueError('Multiple target dimensions are not supported. ' +
2100                     'Expected: None, int, (int, int), ' +
2101                     'Provided: ' + str(axes))
2102
2103  # if tuple, convert to list.
2104  axes = list(axes)
2105
2106  # convert negative indices.
2107  if axes[0] < 0:
2108    axes[0] += x_ndim
2109  if axes[1] < 0:
2110    axes[1] += y_ndim
2111
2112  # sanity checks
2113  if 0 in axes:
2114    raise ValueError('Cannot perform batch_dot over axis 0. '
2115                     'If your inputs are not batched, '
2116                     'add a dummy batch dimension to your '
2117                     'inputs using K.expand_dims(x, 0)')
2118  a0, a1 = axes
2119  d1 = x_shape[a0]
2120  d2 = y_shape[a1]
2121
2122  if d1 is not None and d2 is not None and d1 != d2:
2123    raise ValueError('Cannot do batch_dot on inputs with shapes ' +
2124                     str(x_shape) + ' and ' + str(y_shape) +
2125                     ' with axes=' + str(axes) + '. x.shape[%d] != '
2126                     'y.shape[%d] (%d != %d).' % (axes[0], axes[1], d1, d2))
2127
2128  # backup ndims. Need them later.
2129  orig_x_ndim = x_ndim
2130  orig_y_ndim = y_ndim
2131
2132  # if rank is 2, expand to 3.
2133  if x_ndim == 2:
2134    x = array_ops.expand_dims(x, 1)
2135    a0 += 1
2136    x_ndim += 1
2137  if y_ndim == 2:
2138    y = array_ops.expand_dims(y, 2)
2139    y_ndim += 1
2140
2141  # bring x's dimension to be reduced to last axis.
2142  if a0 != x_ndim - 1:
2143    pattern = list(range(x_ndim))
2144    for i in range(a0, x_ndim - 1):
2145      pattern[i] = pattern[i + 1]
2146    pattern[-1] = a0
2147    x = array_ops.transpose(x, pattern)
2148
2149  # bring y's dimension to be reduced to axis 1.
2150  if a1 != 1:
2151    pattern = list(range(y_ndim))
2152    for i in range(a1, 1, -1):
2153      pattern[i] = pattern[i - 1]
2154    pattern[1] = a1
2155    y = array_ops.transpose(y, pattern)
2156
2157  # normalize both inputs to rank 3.
2158  if x_ndim > 3:
2159    # squash middle dimensions of x.
2160    x_shape = shape(x)
2161    x_mid_dims = x_shape[1:-1]
2162    x_squashed_shape = array_ops.stack(
2163        [x_shape[0], -1, x_shape[-1]])
2164    x = array_ops.reshape(x, x_squashed_shape)
2165    x_squashed = True
2166  else:
2167    x_squashed = False
2168
2169  if y_ndim > 3:
2170    # squash trailing dimensions of y.
2171    y_shape = shape(y)
2172    y_trail_dims = y_shape[2:]
2173    y_squashed_shape = array_ops.stack(
2174        [y_shape[0], y_shape[1], -1])
2175    y = array_ops.reshape(y, y_squashed_shape)
2176    y_squashed = True
2177  else:
2178    y_squashed = False
2179
2180  result = math_ops.matmul(x, y)
2181
2182  # if inputs were squashed, we have to reshape the matmul output.
2183  output_shape = array_ops.shape(result)
2184  do_reshape = False
2185
2186  if x_squashed:
2187    output_shape = array_ops.concat(
2188        [output_shape[:1],
2189         x_mid_dims,
2190         output_shape[-1:]], 0)
2191    do_reshape = True
2192
2193  if y_squashed:
2194    output_shape = array_ops.concat([output_shape[:-1], y_trail_dims], 0)
2195    do_reshape = True
2196
2197  if do_reshape:
2198    result = array_ops.reshape(result, output_shape)
2199
2200  # if the inputs were originally rank 2, we remove the added 1 dim.
2201  if orig_x_ndim == 2:
2202    result = array_ops.squeeze(result, 1)
2203  elif orig_y_ndim == 2:
2204    result = array_ops.squeeze(result, -1)
2205
2206  return result
2207
2208
2209@keras_export('keras.backend.transpose')
2210@dispatch.add_dispatch_support
2211@doc_controls.do_not_generate_docs
2212def transpose(x):
2213  """Transposes a tensor and returns it.
2214
2215  Args:
2216      x: Tensor or variable.
2217
2218  Returns:
2219      A tensor.
2220
2221  Examples:
2222
2223  >>> var = tf.keras.backend.variable([[1, 2, 3], [4, 5, 6]])
2224  >>> tf.keras.backend.eval(var)
2225  array([[1.,  2.,  3.],
2226         [4.,  5.,  6.]], dtype=float32)
2227  >>> var_transposed = tf.keras.backend.transpose(var)
2228  >>> tf.keras.backend.eval(var_transposed)
2229  array([[1.,  4.],
2230         [2.,  5.],
2231         [3.,  6.]], dtype=float32)
2232  >>> input = tf.keras.backend.placeholder((2, 3))
2233  >>> input
2234  <KerasTensor: shape=(2, 3) dtype=float32 ...>
2235  >>> input_transposed = tf.keras.backend.transpose(input)
2236  >>> input_transposed
2237  <KerasTensor: shape=(3, 2) dtype=float32 ...>
2238  """
2239  return array_ops.transpose(x)
2240
2241
2242@keras_export('keras.backend.gather')
2243@dispatch.add_dispatch_support
2244@doc_controls.do_not_generate_docs
2245def gather(reference, indices):
2246  """Retrieves the elements of indices `indices` in the tensor `reference`.
2247
2248  Args:
2249      reference: A tensor.
2250      indices: An integer tensor of indices.
2251
2252  Returns:
2253      A tensor of same type as `reference`.
2254
2255  Examples:
2256
2257  >>> var = tf.keras.backend.variable([[1, 2, 3], [4, 5, 6]])
2258  >>> tf.keras.backend.eval(var)
2259  array([[1., 2., 3.],
2260         [4., 5., 6.]], dtype=float32)
2261  >>> var_gathered = tf.keras.backend.gather(var, [0])
2262  >>> tf.keras.backend.eval(var_gathered)
2263  array([[1., 2., 3.]], dtype=float32)
2264  >>> var_gathered = tf.keras.backend.gather(var, [1])
2265  >>> tf.keras.backend.eval(var_gathered)
2266  array([[4., 5., 6.]], dtype=float32)
2267  >>> var_gathered = tf.keras.backend.gather(var, [0,1,0])
2268  >>> tf.keras.backend.eval(var_gathered)
2269  array([[1., 2., 3.],
2270         [4., 5., 6.],
2271         [1., 2., 3.]], dtype=float32)
2272  """
2273  return array_ops.gather(reference, indices)
2274
2275
2276# ELEMENT-WISE OPERATIONS
2277
2278
2279@keras_export('keras.backend.max')
2280@dispatch.add_dispatch_support
2281@doc_controls.do_not_generate_docs
2282def max(x, axis=None, keepdims=False):
2283  """Maximum value in a tensor.
2284
2285  Args:
2286      x: A tensor or variable.
2287      axis: An integer, the axis to find maximum values.
2288      keepdims: A boolean, whether to keep the dimensions or not.
2289          If `keepdims` is `False`, the rank of the tensor is reduced
2290          by 1. If `keepdims` is `True`,
2291          the reduced dimension is retained with length 1.
2292
2293  Returns:
2294      A tensor with maximum values of `x`.
2295  """
2296  return math_ops.reduce_max(x, axis, keepdims)
2297
2298
2299@keras_export('keras.backend.min')
2300@dispatch.add_dispatch_support
2301@doc_controls.do_not_generate_docs
2302def min(x, axis=None, keepdims=False):
2303  """Minimum value in a tensor.
2304
2305  Args:
2306      x: A tensor or variable.
2307      axis: An integer, the axis to find minimum values.
2308      keepdims: A boolean, whether to keep the dimensions or not.
2309          If `keepdims` is `False`, the rank of the tensor is reduced
2310          by 1. If `keepdims` is `True`,
2311          the reduced dimension is retained with length 1.
2312
2313  Returns:
2314      A tensor with minimum values of `x`.
2315  """
2316  return math_ops.reduce_min(x, axis, keepdims)
2317
2318
2319@keras_export('keras.backend.sum')
2320@dispatch.add_dispatch_support
2321@doc_controls.do_not_generate_docs
2322def sum(x, axis=None, keepdims=False):
2323  """Sum of the values in a tensor, alongside the specified axis.
2324
2325  Args:
2326      x: A tensor or variable.
2327      axis: An integer, the axis to sum over.
2328      keepdims: A boolean, whether to keep the dimensions or not.
2329          If `keepdims` is `False`, the rank of the tensor is reduced
2330          by 1. If `keepdims` is `True`,
2331          the reduced dimension is retained with length 1.
2332
2333  Returns:
2334      A tensor with sum of `x`.
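
  Example (illustrative; output assumes eager execution):

  >>> x = tf.constant([[1., 2.], [3., 4.]])
  >>> tf.keras.backend.sum(x, axis=1).numpy()
  array([3., 7.], dtype=float32)
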
2335  """
2336  return math_ops.reduce_sum(x, axis, keepdims)
2337
2338
2339@keras_export('keras.backend.prod')
2340@dispatch.add_dispatch_support
2341@doc_controls.do_not_generate_docs
2342def prod(x, axis=None, keepdims=False):
2343  """Multiplies the values in a tensor, alongside the specified axis.
2344
2345  Args:
2346      x: A tensor or variable.
2347      axis: An integer, the axis to compute the product.
2348      keepdims: A boolean, whether to keep the dimensions or not.
2349          If `keepdims` is `False`, the rank of the tensor is reduced
2350          by 1. If `keepdims` is `True`,
2351          the reduced dimension is retained with length 1.
2352
2353  Returns:
2354      A tensor with the product of elements of `x`.
2355  """
2356  return math_ops.reduce_prod(x, axis, keepdims)
2357
2358
2359@keras_export('keras.backend.cumsum')
2360@dispatch.add_dispatch_support
2361@doc_controls.do_not_generate_docs
2362def cumsum(x, axis=0):
2363  """Cumulative sum of the values in a tensor, alongside the specified axis.
2364
2365  Args:
2366      x: A tensor or variable.
2367      axis: An integer, the axis to compute the sum.
2368
2369  Returns:
2370      A tensor of the cumulative sum of values of `x` along `axis`.
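
  Example (illustrative; output assumes eager execution):

  >>> x = tf.constant([1, 2, 3])
  >>> tf.keras.backend.cumsum(x).numpy()
  array([1, 3, 6], dtype=int32)
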
2371  """
2372  return math_ops.cumsum(x, axis=axis)
2373
2374
2375@keras_export('keras.backend.cumprod')
2376@dispatch.add_dispatch_support
2377@doc_controls.do_not_generate_docs
2378def cumprod(x, axis=0):
2379  """Cumulative product of the values in a tensor, alongside the specified axis.
2380
2381  Args:
2382      x: A tensor or variable.
2383      axis: An integer, the axis to compute the product.
2384
2385  Returns:
2386      A tensor of the cumulative product of values of `x` along `axis`.
2387  """
2388  return math_ops.cumprod(x, axis=axis)
2389
2390
2391@keras_export('keras.backend.var')
2392@doc_controls.do_not_generate_docs
2393def var(x, axis=None, keepdims=False):
2394  """Variance of a tensor, alongside the specified axis.
2395
2396  Args:
2397      x: A tensor or variable.
2398      axis: An integer, the axis to compute the variance.
2399      keepdims: A boolean, whether to keep the dimensions or not.
2400          If `keepdims` is `False`, the rank of the tensor is reduced
2401          by 1. If `keepdims` is `True`,
2402          the reduced dimension is retained with length 1.
2403
2404  Returns:
2405      A tensor with the variance of elements of `x`.
2406  """
2407  if x.dtype.base_dtype == dtypes_module.bool:
2408    x = math_ops.cast(x, floatx())
2409  return math_ops.reduce_variance(x, axis=axis, keepdims=keepdims)
2410
2411
2412@keras_export('keras.backend.std')
2413@dispatch.add_dispatch_support
2414@doc_controls.do_not_generate_docs
2415def std(x, axis=None, keepdims=False):
2416  """Standard deviation of a tensor, alongside the specified axis.
2417
2418  It is an alias to `tf.math.reduce_std`.
2419
2420  Args:
2421      x: A tensor or variable. It should have numerical dtypes. Boolean type
2422        inputs will be converted to float.
2423      axis: An integer, the axis to compute the standard deviation. If `None`
2424        (the default), reduces all dimensions. Must be in the range
2425        `[-rank(x), rank(x))`.
2426      keepdims: A boolean, whether to keep the dimensions or not.
2427          If `keepdims` is `False`, the rank of the tensor is reduced
2428          by 1. If `keepdims` is `True`, the reduced dimension is retained with
2429          length 1.
2430
2431  Returns:
2432      A tensor with the standard deviation of elements of `x` with same dtype.
2433      Boolean type input will be converted to float.
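
  Example (illustrative; output assumes eager execution):

  >>> x = tf.constant([1., 5.])
  >>> tf.keras.backend.std(x).numpy()
  2.0
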
2434  """
2435  if x.dtype.base_dtype == dtypes_module.bool:
2436    x = math_ops.cast(x, floatx())
2437  return math_ops.reduce_std(x, axis=axis, keepdims=keepdims)
2438
2439
2440@keras_export('keras.backend.mean')
2441@dispatch.add_dispatch_support
2442@doc_controls.do_not_generate_docs
2443def mean(x, axis=None, keepdims=False):
2444  """Mean of a tensor, alongside the specified axis.
2445
2446  Args:
2447      x: A tensor or variable.
2448      axis: An integer or list of integers, the axis or axes to compute the mean over.
2449      keepdims: A boolean, whether to keep the dimensions or not.
2450          If `keepdims` is `False`, the rank of the tensor is reduced
2451          by 1 for each entry in `axis`. If `keepdims` is `True`,
2452          the reduced dimensions are retained with length 1.
2453
2454  Returns:
2455      A tensor with the mean of elements of `x`.
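
  Example (illustrative; output assumes eager execution):

  >>> x = tf.constant([[1., 2.], [3., 4.]])
  >>> tf.keras.backend.mean(x, axis=0).numpy()
  array([2., 3.], dtype=float32)
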
2456  """
2457  if x.dtype.base_dtype == dtypes_module.bool:
2458    x = math_ops.cast(x, floatx())
2459  return math_ops.reduce_mean(x, axis, keepdims)
2460
2461
2462@keras_export('keras.backend.any')
2463@dispatch.add_dispatch_support
2464@doc_controls.do_not_generate_docs
2465def any(x, axis=None, keepdims=False):
2466  """Bitwise reduction (logical OR).
2467
2468  Args:
2469      x: Tensor or variable.
2470      axis: axis along which to perform the reduction.
2471      keepdims: whether to drop or broadcast the reduction axes.
2472
2473  Returns:
2474      A bool tensor.
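
  Example (illustrative; output assumes eager execution):

  >>> x = tf.constant([[0, 1], [0, 0]])
  >>> tf.keras.backend.any(x, axis=1).numpy()
  array([ True, False])
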
2475  """
2476  x = math_ops.cast(x, dtypes_module.bool)
2477  return math_ops.reduce_any(x, axis, keepdims)
2478
2479
2480@keras_export('keras.backend.all')
2481@dispatch.add_dispatch_support
2482@doc_controls.do_not_generate_docs
2483def all(x, axis=None, keepdims=False):
2484  """Bitwise reduction (logical AND).
2485
2486  Args:
2487      x: Tensor or variable.
2488      axis: axis along which to perform the reduction.
2489      keepdims: whether to drop or broadcast the reduction axes.
2490
2491  Returns:
2492      A bool tensor.
2493  """
2494  x = math_ops.cast(x, dtypes_module.bool)
2495  return math_ops.reduce_all(x, axis, keepdims)
2496
2497
2498@keras_export('keras.backend.argmax')
2499@dispatch.add_dispatch_support
2500@doc_controls.do_not_generate_docs
2501def argmax(x, axis=-1):
2502  """Returns the index of the maximum value along an axis.
2503
2504  Args:
2505      x: Tensor or variable.
2506      axis: axis along which to perform the reduction.
2507
2508  Returns:
2509      A tensor.
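
  Example (illustrative; the result dtype is int64, output is indicative):

  >>> x = tf.constant([[1., 5., 2.], [7., 0., 3.]])
  >>> tf.keras.backend.argmax(x, axis=-1).numpy()
  array([1, 0])
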
2510  """
2511  return math_ops.argmax(x, axis)
2512
2513
2514@keras_export('keras.backend.argmin')
2515@dispatch.add_dispatch_support
2516@doc_controls.do_not_generate_docs
2517def argmin(x, axis=-1):
2518  """Returns the index of the minimum value along an axis.
2519
2520  Args:
2521      x: Tensor or variable.
2522      axis: axis along which to perform the reduction.
2523
2524  Returns:
2525      A tensor.
2526  """
2527  return math_ops.argmin(x, axis)
2528
2529
2530@keras_export('keras.backend.square')
2531@dispatch.add_dispatch_support
2532@doc_controls.do_not_generate_docs
2533def square(x):
2534  """Element-wise square.
2535
2536  Args:
2537      x: Tensor or variable.
2538
2539  Returns:
2540      A tensor.
2541  """
2542  return math_ops.square(x)
2543
2544
2545@keras_export('keras.backend.abs')
2546@dispatch.add_dispatch_support
2547@doc_controls.do_not_generate_docs
2548def abs(x):
2549  """Element-wise absolute value.
2550
2551  Args:
2552      x: Tensor or variable.
2553
2554  Returns:
2555      A tensor.
2556  """
2557  return math_ops.abs(x)
2558
2559
2560@keras_export('keras.backend.sqrt')
2561@dispatch.add_dispatch_support
2562@doc_controls.do_not_generate_docs
2563def sqrt(x):
2564  """Element-wise square root.
2565
2566     This function clips negative tensor values to 0 before computing the
2567     square root.
2568
2569  Args:
2570      x: Tensor or variable.
2571
2572  Returns:
2573      A tensor.
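
  Example (illustrative; note the negative entry is clipped to 0 first):

  >>> x = tf.constant([-4., 0., 9.])
  >>> tf.keras.backend.sqrt(x).numpy()
  array([0., 0., 3.], dtype=float32)
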
2574  """
2575  zero = _constant_to_tensor(0., x.dtype.base_dtype)
2576  x = math_ops.maximum(x, zero)
2577  return math_ops.sqrt(x)
2578
2579
2580@keras_export('keras.backend.exp')
2581@dispatch.add_dispatch_support
2582@doc_controls.do_not_generate_docs
2583def exp(x):
2584  """Element-wise exponential.
2585
2586  Args:
2587      x: Tensor or variable.
2588
2589  Returns:
2590      A tensor.
2591  """
2592  return math_ops.exp(x)
2593
2594
2595@keras_export('keras.backend.log')
2596@dispatch.add_dispatch_support
2597@doc_controls.do_not_generate_docs
2598def log(x):
2599  """Element-wise log.
2600
2601  Args:
2602      x: Tensor or variable.
2603
2604  Returns:
2605      A tensor.
2606  """
2607  return math_ops.log(x)
2608
2609
2610def logsumexp(x, axis=None, keepdims=False):
2611  """Computes log(sum(exp(elements across dimensions of a tensor))).
2612
2613  This function is more numerically stable than log(sum(exp(x))).
2614  It avoids overflows caused by taking the exp of large inputs and
2615  underflows caused by taking the log of small inputs.
2616
2617  Args:
2618      x: A tensor or variable.
2619      axis: An integer, the axis to reduce over.
2620      keepdims: A boolean, whether to keep the dimensions or not.
2621          If `keepdims` is `False`, the rank of the tensor is reduced
2622          by 1. If `keepdims` is `True`, the reduced dimension is
2623          retained with length 1.
2624
2625  Returns:
2626      The reduced tensor.
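
  Example (an illustrative sketch; `logsumexp` is module-internal and not
  exported under `tf.keras.backend`):

  >>> x = tf.constant([0., 0.])
  >>> logsumexp(x).numpy()  # log(exp(0) + exp(0)) = log(2)
  0.6931472
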
2627  """
2628  return math_ops.reduce_logsumexp(x, axis, keepdims)
2629
2630
2631@keras_export('keras.backend.round')
2632@dispatch.add_dispatch_support
2633@doc_controls.do_not_generate_docs
2634def round(x):
2635  """Element-wise rounding to the closest integer.
2636
2637  In case of tie, the rounding mode used is "half to even".
2638
2639  Args:
2640      x: Tensor or variable.
2641
2642  Returns:
2643      A tensor.
2644  """
2645  return math_ops.round(x)
2646
2647
2648@keras_export('keras.backend.sign')
2649@dispatch.add_dispatch_support
2650@doc_controls.do_not_generate_docs
2651def sign(x):
2652  """Element-wise sign.
2653
2654  Args:
2655      x: Tensor or variable.
2656
2657  Returns:
2658      A tensor.
2659  """
2660  return math_ops.sign(x)
2661
2662
2663@keras_export('keras.backend.pow')
2664@dispatch.add_dispatch_support
2665@doc_controls.do_not_generate_docs
2666def pow(x, a):
2667  """Element-wise exponentiation.
2668
2669  Args:
2670      x: Tensor or variable.
2671      a: Python integer.
2672
2673  Returns:
2674      A tensor.
2675  """
2676  return math_ops.pow(x, a)
2677
2678
2679@keras_export('keras.backend.clip')
2680@dispatch.add_dispatch_support
2681@doc_controls.do_not_generate_docs
2682def clip(x, min_value, max_value):
2683  """Element-wise value clipping.
2684
2685  Args:
2686      x: Tensor or variable.
2687      min_value: Python float, integer, or tensor.
2688      max_value: Python float, integer, or tensor.
2689
2690  Returns:
2691      A tensor.
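
  Example (illustrative; output assumes eager execution):

  >>> x = tf.constant([-2., 0.5, 3.])
  >>> tf.keras.backend.clip(x, min_value=0., max_value=1.).numpy()
  array([0. , 0.5, 1. ], dtype=float32)
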
2692  """
2693  if (isinstance(min_value, (int, float)) and
2694      isinstance(max_value, (int, float))):
2695    if max_value < min_value:
2696      max_value = min_value
2697  if min_value is None:
2698    min_value = -np.inf
2699  if max_value is None:
2700    max_value = np.inf
2701  return clip_ops.clip_by_value(x, min_value, max_value)
2702
2703
2704@keras_export('keras.backend.equal')
2705@dispatch.add_dispatch_support
2706@doc_controls.do_not_generate_docs
2707def equal(x, y):
2708  """Element-wise equality between two tensors.
2709
2710  Args:
2711      x: Tensor or variable.
2712      y: Tensor or variable.
2713
2714  Returns:
2715      A bool tensor.
2716  """
2717  return math_ops.equal(x, y)
2718
2719
2720@keras_export('keras.backend.not_equal')
2721@dispatch.add_dispatch_support
2722@doc_controls.do_not_generate_docs
2723def not_equal(x, y):
2724  """Element-wise inequality between two tensors.
2725
2726  Args:
2727      x: Tensor or variable.
2728      y: Tensor or variable.
2729
2730  Returns:
2731      A bool tensor.
2732  """
2733  return math_ops.not_equal(x, y)
2734
2735
2736@keras_export('keras.backend.greater')
2737@dispatch.add_dispatch_support
2738@doc_controls.do_not_generate_docs
2739def greater(x, y):
2740  """Element-wise truth value of (x > y).
2741
2742  Args:
2743      x: Tensor or variable.
2744      y: Tensor or variable.
2745
2746  Returns:
2747      A bool tensor.
2748  """
2749  return math_ops.greater(x, y)
2750
2751
2752@keras_export('keras.backend.greater_equal')
2753@dispatch.add_dispatch_support
2754@doc_controls.do_not_generate_docs
2755def greater_equal(x, y):
2756  """Element-wise truth value of (x >= y).
2757
2758  Args:
2759      x: Tensor or variable.
2760      y: Tensor or variable.
2761
2762  Returns:
2763      A bool tensor.
2764  """
2765  return math_ops.greater_equal(x, y)
2766
2767
2768@keras_export('keras.backend.less')
2769@dispatch.add_dispatch_support
2770@doc_controls.do_not_generate_docs
2771def less(x, y):
2772  """Element-wise truth value of (x < y).
2773
2774  Args:
2775      x: Tensor or variable.
2776      y: Tensor or variable.
2777
2778  Returns:
2779      A bool tensor.
2780  """
2781  return math_ops.less(x, y)
2782
2783
2784@keras_export('keras.backend.less_equal')
2785@dispatch.add_dispatch_support
2786@doc_controls.do_not_generate_docs
2787def less_equal(x, y):
2788  """Element-wise truth value of (x <= y).
2789
2790  Args:
2791      x: Tensor or variable.
2792      y: Tensor or variable.
2793
2794  Returns:
2795      A bool tensor.
2796  """
2797  return math_ops.less_equal(x, y)
2798
2799
2800@keras_export('keras.backend.maximum')
2801@dispatch.add_dispatch_support
2802@doc_controls.do_not_generate_docs
2803def maximum(x, y):
2804  """Element-wise maximum of two tensors.
2805
2806  Args:
2807      x: Tensor or variable.
2808      y: Tensor or variable.
2809
2810  Returns:
2811      A tensor with the element wise maximum value(s) of `x` and `y`.
2812
2813  Examples:
2814
2815  >>> x = tf.Variable([[1, 2], [3, 4]])
2816  >>> y = tf.Variable([[2, 1], [0, -1]])
2817  >>> m = tf.keras.backend.maximum(x, y)
2818  >>> m
2819  <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
2820  array([[2, 2],
2821         [3, 4]], dtype=int32)>
2822  """
2823  return math_ops.maximum(x, y)
2824
2825
2826@keras_export('keras.backend.minimum')
2827@dispatch.add_dispatch_support
2828@doc_controls.do_not_generate_docs
2829def minimum(x, y):
2830  """Element-wise minimum of two tensors.
2831
2832  Args:
2833      x: Tensor or variable.
2834      y: Tensor or variable.
2835
2836  Returns:
2837      A tensor.
2838  """
2839  return math_ops.minimum(x, y)
2840
2841
2842@keras_export('keras.backend.sin')
2843@dispatch.add_dispatch_support
2844@doc_controls.do_not_generate_docs
2845def sin(x):
2846  """Computes sin of x element-wise.
2847
2848  Args:
2849      x: Tensor or variable.
2850
2851  Returns:
2852      A tensor.
2853  """
2854  return math_ops.sin(x)
2855
2856
2857@keras_export('keras.backend.cos')
2858@dispatch.add_dispatch_support
2859@doc_controls.do_not_generate_docs
2860def cos(x):
2861  """Computes cos of x element-wise.
2862
2863  Args:
2864      x: Tensor or variable.
2865
2866  Returns:
2867      A tensor.
2868  """
2869  return math_ops.cos(x)
2870
2871
2872def _regular_normalize_batch_in_training(x,
2873                                         gamma,
2874                                         beta,
2875                                         reduction_axes,
2876                                         epsilon=1e-3):
2877  """Non-fused version of `normalize_batch_in_training`.
2878
2879  Args:
2880      x: Input tensor or variable.
2881      gamma: Tensor by which to scale the input.
2882      beta: Tensor with which to center the input.
2883      reduction_axes: iterable of integers,
2884          axes over which to normalize.
2885      epsilon: Fuzz factor.
2886
2887  Returns:
2888      A tuple length of 3, `(normalized_tensor, mean, variance)`.
2889  """
2890  mean, var = nn.moments(x, reduction_axes, None, None, False)
2891  normed = nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
2892  return normed, mean, var
2893
2894
2895def _broadcast_normalize_batch_in_training(x,
2896                                           gamma,
2897                                           beta,
2898                                           reduction_axes,
2899                                           epsilon=1e-3):
2900  """Non-fused, broadcast version of `normalize_batch_in_training`.
2901
2902  Args:
2903      x: Input tensor or variable.
2904      gamma: Tensor by which to scale the input.
2905      beta: Tensor with which to center the input.
2906      reduction_axes: iterable of integers,
2907          axes over which to normalize.
2908      epsilon: Fuzz factor.
2909
2910  Returns:
2911      A tuple length of 3, `(normalized_tensor, mean, variance)`.
2912  """
2913  mean, var = nn.moments(x, reduction_axes, None, None, False)
2914  target_shape = []
2915  for axis in range(ndim(x)):
2916    if axis in reduction_axes:
2917      target_shape.append(1)
2918    else:
2919      target_shape.append(array_ops.shape(x)[axis])
2920  target_shape = array_ops.stack(target_shape)
2921
2922  broadcast_mean = array_ops.reshape(mean, target_shape)
2923  broadcast_var = array_ops.reshape(var, target_shape)
2924  if gamma is None:
2925    broadcast_gamma = None
2926  else:
2927    broadcast_gamma = array_ops.reshape(gamma, target_shape)
2928  if beta is None:
2929    broadcast_beta = None
2930  else:
2931    broadcast_beta = array_ops.reshape(beta, target_shape)
2932
2933  normed = nn.batch_normalization(x, broadcast_mean, broadcast_var,
2934                                  broadcast_beta, broadcast_gamma, epsilon)
2935  return normed, mean, var
2936
2937
2938def _fused_normalize_batch_in_training(x,
2939                                       gamma,
2940                                       beta,
2941                                       reduction_axes,
2942                                       epsilon=1e-3):
2943  """Fused version of `normalize_batch_in_training`.
2944
2945  Args:
2946      x: Input tensor or variable.
2947      gamma: Tensor by which to scale the input.
2948      beta: Tensor with which to center the input.
2949      reduction_axes: iterable of integers,
2950          axes over which to normalize.
2951      epsilon: Fuzz factor.
2952
2953  Returns:
2954      A tuple length of 3, `(normalized_tensor, mean, variance)`.
2955  """
2956  if list(reduction_axes) == [0, 1, 2]:
2957    normalization_axis = 3
2958    tf_data_format = 'NHWC'
2959  else:
2960    normalization_axis = 1
2961    tf_data_format = 'NCHW'
2962
2963  if gamma is None:
2964    gamma = constant_op.constant(
2965        1.0, dtype=x.dtype, shape=[x.shape[normalization_axis]])
2966  if beta is None:
2967    beta = constant_op.constant(
2968        0.0, dtype=x.dtype, shape=[x.shape[normalization_axis]])
2969
2970  return nn.fused_batch_norm(
2971      x, gamma, beta, epsilon=epsilon, data_format=tf_data_format)
2972
2973
2974@keras_export('keras.backend.normalize_batch_in_training')
2975@doc_controls.do_not_generate_docs
2976def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3):
2977  """Computes mean and std for batch then apply batch_normalization on batch.
2978
2979  Args:
2980      x: Input tensor or variable.
2981      gamma: Tensor by which to scale the input.
2982      beta: Tensor with which to center the input.
2983      reduction_axes: iterable of integers,
2984          axes over which to normalize.
2985      epsilon: Fuzz factor.
2986
2987  Returns:
2988      A tuple length of 3, `(normalized_tensor, mean, variance)`.
2989  """
2990  if ndim(x) == 4 and list(reduction_axes) in [[0, 1, 2], [0, 2, 3]]:
2991    if not _has_nchw_support() and list(reduction_axes) == [0, 2, 3]:
2992      return _broadcast_normalize_batch_in_training(
2993          x, gamma, beta, reduction_axes, epsilon=epsilon)
2994    return _fused_normalize_batch_in_training(
2995        x, gamma, beta, reduction_axes, epsilon=epsilon)
2996  else:
2997    if sorted(reduction_axes) == list(range(ndim(x)))[:-1]:
2998      return _regular_normalize_batch_in_training(
2999          x, gamma, beta, reduction_axes, epsilon=epsilon)
3000    else:
3001      return _broadcast_normalize_batch_in_training(
3002          x, gamma, beta, reduction_axes, epsilon=epsilon)
3003
3004
3005@keras_export('keras.backend.batch_normalization')
3006@dispatch.add_dispatch_support
3007@doc_controls.do_not_generate_docs
3008def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
3009  """Applies batch normalization on x given mean, var, beta and gamma.
3010
3011  I.e. returns:
3012  `output = (x - mean) / sqrt(var + epsilon) * gamma + beta`
3013
3014  Args:
3015      x: Input tensor or variable.
3016      mean: Mean of batch.
3017      var: Variance of batch.
3018      beta: Tensor with which to center the input.
3019      gamma: Tensor by which to scale the input.
3020      axis: Integer, the axis that should be normalized.
3021          (typically the features axis).
3022      epsilon: Fuzz factor.
3023
3024  Returns:
3025      A tensor.
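
  Example (an illustrative sketch of the formula above; `epsilon=0.` is used
  so the output values are exact):

  >>> x = tf.constant([[1., 2.], [3., 4.]])
  >>> mean, var = tf.constant([2., 3.]), tf.constant([1., 1.])
  >>> beta, gamma = tf.constant([0., 0.]), tf.constant([1., 1.])
  >>> tf.keras.backend.batch_normalization(
  ...     x, mean, var, beta, gamma, epsilon=0.).numpy()
  array([[-1., -1.],
         [ 1.,  1.]], dtype=float32)
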
3026  """
3027  if ndim(x) == 4:
3028    # The CPU implementation of `fused_batch_norm` only supports NHWC
3029    if axis == 1 or axis == -3:
3030      tf_data_format = 'NCHW'
3031    elif axis == 3 or axis == -1:
3032      tf_data_format = 'NHWC'
3033    else:
3034      tf_data_format = None
3035
3036    if (tf_data_format == 'NHWC' or
3037        tf_data_format == 'NCHW' and _has_nchw_support()):
3038      # The mean / var / beta / gamma tensors may be broadcasted
3039      # so they may have extra axes of size 1, which should be squeezed.
3040      if ndim(mean) > 1:
3041        mean = array_ops.reshape(mean, [-1])
3042      if ndim(var) > 1:
3043        var = array_ops.reshape(var, [-1])
3044      if beta is None:
3045        beta = zeros_like(mean)
3046      elif ndim(beta) > 1:
3047        beta = array_ops.reshape(beta, [-1])
3048      if gamma is None:
3049        gamma = ones_like(mean)
3050      elif ndim(gamma) > 1:
3051        gamma = array_ops.reshape(gamma, [-1])
3052      y, _, _ = nn.fused_batch_norm(
3053          x,
3054          gamma,
3055          beta,
3056          epsilon=epsilon,
3057          mean=mean,
3058          variance=var,
3059          data_format=tf_data_format,
3060          is_training=False
3061      )
3062      return y
3063  return nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
3064
3065
3066# SHAPE OPERATIONS
3067
3068
3069@keras_export('keras.backend.concatenate')
3070@dispatch.add_dispatch_support
3071@doc_controls.do_not_generate_docs
3072def concatenate(tensors, axis=-1):
3073  """Concatenates a list of tensors alongside the specified axis.
3074
3075  Args:
3076      tensors: list of tensors to concatenate.
3077      axis: concatenation axis.
3078
3079  Returns:
3080      A tensor.
3081
3082  Example:
3083
3084      >>> a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
3085      >>> b = tf.constant([[10, 20, 30], [40, 50, 60], [70, 80, 90]])
3086      >>> tf.keras.backend.concatenate((a, b), axis=-1)
3087      <tf.Tensor: shape=(3, 6), dtype=int32, numpy=
3088      array([[ 1,  2,  3, 10, 20, 30],
3089             [ 4,  5,  6, 40, 50, 60],
3090             [ 7,  8,  9, 70, 80, 90]], dtype=int32)>
3091
3092  """
3093  if axis < 0:
3094    rank = ndim(tensors[0])
3095    if rank:
3096      axis %= rank
3097    else:
3098      axis = 0
3099
3100  if py_all(is_sparse(x) for x in tensors):
3101    return sparse_ops.sparse_concat(axis, tensors)
3102  elif py_all(isinstance(x, ragged_tensor.RaggedTensor) for x in tensors):
3103    return array_ops.concat(tensors, axis)
3104  else:
3105    return array_ops.concat([to_dense(x) for x in tensors], axis)
3106
3107
3108@keras_export('keras.backend.reshape')
3109@dispatch.add_dispatch_support
3110@doc_controls.do_not_generate_docs
3111def reshape(x, shape):
3112  """Reshapes a tensor to the specified shape.
3113
3114  Args:
3115      x: Tensor or variable.
3116      shape: Target shape tuple.
3117
3118  Returns:
3119      A tensor.
3120
3121  Example:
3122
3123    >>> a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
3124    >>> a
3125    <tf.Tensor: shape=(4, 3), dtype=int32, numpy=
3126    array([[ 1,  2,  3],
3127           [ 4,  5,  6],
3128           [ 7,  8,  9],
3129           [10, 11, 12]], dtype=int32)>
3130    >>> tf.keras.backend.reshape(a, shape=(2, 6))
3131    <tf.Tensor: shape=(2, 6), dtype=int32, numpy=
3132    array([[ 1,  2,  3,  4,  5,  6],
3133           [ 7,  8,  9, 10, 11, 12]], dtype=int32)>
3134
3135  """
3136  return array_ops.reshape(x, shape)
3137
3138
3139@keras_export('keras.backend.permute_dimensions')
3140@dispatch.add_dispatch_support
3141@doc_controls.do_not_generate_docs
3142def permute_dimensions(x, pattern):
3143  """Permutes axes in a tensor.
3144
3145  Args:
3146      x: Tensor or variable.
3147      pattern: A tuple of
3148          dimension indices, e.g. `(0, 2, 1)`.
3149
3150  Returns:
3151      A tensor.
3152
3153  Example:
3154
3155    >>> a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
3156    >>> a
3157    <tf.Tensor: shape=(4, 3), dtype=int32, numpy=
3158    array([[ 1,  2,  3],
3159           [ 4,  5,  6],
3160           [ 7,  8,  9],
3161           [10, 11, 12]], dtype=int32)>
3162    >>> tf.keras.backend.permute_dimensions(a, pattern=(1, 0))
3163    <tf.Tensor: shape=(3, 4), dtype=int32, numpy=
3164    array([[ 1,  4,  7, 10],
3165           [ 2,  5,  8, 11],
3166           [ 3,  6,  9, 12]], dtype=int32)>
3167
3168  """
3169  return array_ops.transpose(x, perm=pattern)
3170
3171
3172@keras_export('keras.backend.resize_images')
3173@dispatch.add_dispatch_support
3174@doc_controls.do_not_generate_docs
3175def resize_images(x, height_factor, width_factor, data_format,
3176                  interpolation='nearest'):
3177  """Resizes the images contained in a 4D tensor.
3178
3179  Args:
3180      x: Tensor or variable to resize.
3181      height_factor: Positive integer.
3182      width_factor: Positive integer.
3183      data_format: One of `"channels_first"`, `"channels_last"`.
3184      interpolation: A string, one of `nearest` or `bilinear`.
3185
3186  Returns:
3187      A tensor.
3188
3189  Raises:
3190      ValueError: in case of incorrect value for
3191        `data_format` or `interpolation`.
3192  """
3193  if data_format == 'channels_first':
3194    rows, cols = 2, 3
3195  elif data_format == 'channels_last':
3196    rows, cols = 1, 2
3197  else:
3198    raise ValueError('Invalid `data_format` argument: %s' % (data_format,))
3199
3200  new_shape = x.shape[rows:cols + 1]
3201  if new_shape.is_fully_defined():
3202    new_shape = constant_op.constant(new_shape.as_list(), dtype='int32')
3203  else:
3204    new_shape = array_ops.shape_v2(x)[rows:cols + 1]
3205  new_shape *= constant_op.constant(
3206      np.array([height_factor, width_factor], dtype='int32'))
3207
3208  if data_format == 'channels_first':
3209    x = permute_dimensions(x, [0, 2, 3, 1])
3210  if interpolation == 'nearest':
3211    x = image_ops.resize_images_v2(
3212        x, new_shape, method=image_ops.ResizeMethod.NEAREST_NEIGHBOR)
3213  elif interpolation == 'bilinear':
3214    x = image_ops.resize_images_v2(x, new_shape,
3215                                   method=image_ops.ResizeMethod.BILINEAR)
3216  else:
3217    raise ValueError('interpolation should be one '
3218                     'of "nearest" or "bilinear".')
3219  if data_format == 'channels_first':
3220    x = permute_dimensions(x, [0, 3, 1, 2])
3221
3222  return x
3223
3224
3225@keras_export('keras.backend.resize_volumes')
3226@dispatch.add_dispatch_support
3227@doc_controls.do_not_generate_docs
3228def resize_volumes(x, depth_factor, height_factor, width_factor, data_format):
3229  """Resizes the volume contained in a 5D tensor.
3230
3231  Args:
3232      x: Tensor or variable to resize.
3233      depth_factor: Positive integer.
3234      height_factor: Positive integer.
3235      width_factor: Positive integer.
3236      data_format: One of `"channels_first"`, `"channels_last"`.
3237
3238  Returns:
3239      A tensor.
3240
3241  Raises:
3242      ValueError: if `data_format` is neither
3243          `channels_last` or `channels_first`.
3244  """
3245  if data_format == 'channels_first':
3246    output = repeat_elements(x, depth_factor, axis=2)
3247    output = repeat_elements(output, height_factor, axis=3)
3248    output = repeat_elements(output, width_factor, axis=4)
3249    return output
3250  elif data_format == 'channels_last':
3251    output = repeat_elements(x, depth_factor, axis=1)
3252    output = repeat_elements(output, height_factor, axis=2)
3253    output = repeat_elements(output, width_factor, axis=3)
3254    return output
3255  else:
3256    raise ValueError('Invalid data_format: ' + str(data_format))
3257
3258
3259@keras_export('keras.backend.repeat_elements')
3260@dispatch.add_dispatch_support
3261@doc_controls.do_not_generate_docs
3262def repeat_elements(x, rep, axis):
3263  """Repeats the elements of a tensor along an axis, like `np.repeat`.
3264
3265  If `x` has shape `(s1, s2, s3)` and `axis` is `1`, the output
3266  will have shape `(s1, s2 * rep, s3)`.
3267
3268  Args:
3269      x: Tensor or variable.
3270      rep: Python integer, number of times to repeat.
3271      axis: Axis along which to repeat.
3272
3273  Returns:
3274      A tensor.
3275
3276  Example:
3277
3278      >>> b = tf.constant([1, 2, 3])
3279      >>> tf.keras.backend.repeat_elements(b, rep=2, axis=0)
3280      <tf.Tensor: shape=(6,), dtype=int32,
3281          numpy=array([1, 1, 2, 2, 3, 3], dtype=int32)>
3282
3283  """
3284  x_shape = x.shape.as_list()
3285  # For static axis
3286  if x_shape[axis] is not None:
3287    # slices along the repeat axis
3288    splits = array_ops.split(value=x,
3289                             num_or_size_splits=x_shape[axis],
3290                             axis=axis)
3291    # repeat each slice the given number of reps
3292    x_rep = [s for s in splits for _ in range(rep)]
3293    return concatenate(x_rep, axis)
3294
3295  # Here we use tf.tile to mimic behavior of np.repeat so that
3296  # we can handle dynamic shapes (that include None).
3297  # To do that, we need an auxiliary axis to repeat elements along
3298  # it and then merge them along the desired axis.
3299
3300  # Repeating
3301  auxiliary_axis = axis + 1
3302  x_shape = array_ops.shape(x)
3303  x_rep = array_ops.expand_dims(x, axis=auxiliary_axis)
3304  reps = np.ones(len(x.shape) + 1)
3305  reps[auxiliary_axis] = rep
3306  x_rep = array_ops.tile(x_rep, reps)
3307
3308  # Merging
3309  reps = np.delete(reps, auxiliary_axis)
3310  reps[axis] = rep
3311  reps = array_ops.constant(reps, dtype='int32')
3312  x_shape *= reps
3313  x_rep = array_ops.reshape(x_rep, x_shape)
3314
3315  # Fix shape representation
3316  x_shape = x.shape.as_list()
3317  x_rep.set_shape(x_shape)
3318  x_rep._keras_shape = tuple(x_shape)
3319  return x_rep
3320
3321
3322@keras_export('keras.backend.repeat')
3323@dispatch.add_dispatch_support
3324@doc_controls.do_not_generate_docs
3325def repeat(x, n):
3326  """Repeats a 2D tensor.
3327
3328  If `x` has shape `(samples, dim)` and `n` is `2`,
3329  the output will have shape `(samples, 2, dim)`.
3330
3331  Args:
3332      x: Tensor or variable.
3333      n: Python integer, number of times to repeat.
3334
3335  Returns:
3336      A tensor.
3337
3338  Example:
3339
3340      >>> b = tf.constant([[1, 2], [3, 4]])
3341      >>> b
3342      <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
3343      array([[1, 2],
3344             [3, 4]], dtype=int32)>
3345      >>> tf.keras.backend.repeat(b, n=2)
3346      <tf.Tensor: shape=(2, 2, 2), dtype=int32, numpy=
3347      array([[[1, 2],
3348              [1, 2]],
3349             [[3, 4],
3350              [3, 4]]], dtype=int32)>
3351
3352  """
3353  assert ndim(x) == 2
3354  x = array_ops.expand_dims(x, 1)
3355  pattern = array_ops.stack([1, n, 1])
3356  return array_ops.tile(x, pattern)
3357
3358
3359@keras_export('keras.backend.arange')
3360@dispatch.add_dispatch_support
3361@doc_controls.do_not_generate_docs
3362def arange(start, stop=None, step=1, dtype='int32'):
3363  """Creates a 1D tensor containing a sequence of integers.
3364
3365  The function arguments use the same convention as
3366  Theano's arange: if only one argument is provided,
3367  it is in fact the "stop" argument and "start" is 0.
3368
3369  The default type of the returned tensor is `'int32'` to
3370  match TensorFlow's default.
3371
3372  Args:
3373      start: Start value.
3374      stop: Stop value.
3375      step: Difference between two successive values.
3376      dtype: Integer dtype to use.
3377
3378  Returns:
3379      An integer tensor.
3380
3381  Example:
3382
3383      >>> tf.keras.backend.arange(start=0, stop=10, step=1.5)
3384      <tf.Tensor: shape=(7,), dtype=float32,
3385          numpy=array([0. , 1.5, 3. , 4.5, 6. , 7.5, 9. ], dtype=float32)>
3389  """
3390  # Match the behavior of numpy and Theano by returning an empty sequence.
3391  if stop is None and start < 0:
3392    start = 0
3393  result = math_ops.range(start, limit=stop, delta=step, name='arange')
3394  if dtype != 'int32':
3395    result = cast(result, dtype)
3396  return result
3397
3398
3399@keras_export('keras.backend.tile')
3400@dispatch.add_dispatch_support
3401@doc_controls.do_not_generate_docs
3402def tile(x, n):
3403  """Creates a tensor by tiling `x` by `n`.
3404
3405  Args:
3406      x: A tensor or variable
      n: A list of integers. The length must be the same as the number of
3408          dimensions in `x`.
3409
3410  Returns:
3411      A tiled tensor.
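
  Example (an illustrative sketch, assuming eager execution):

      >>> x = tf.constant([[1, 2], [3, 4]])
      >>> tf.keras.backend.tile(x, [1, 2])
      <tf.Tensor: shape=(2, 4), dtype=int32, numpy=
      array([[1, 2, 1, 2],
             [3, 4, 3, 4]], dtype=int32)>
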
3412  """
3413  if isinstance(n, int):
3414    n = [n]
3415  return array_ops.tile(x, n)
3416
3417
3418@keras_export('keras.backend.flatten')
3419@dispatch.add_dispatch_support
3420@doc_controls.do_not_generate_docs
3421def flatten(x):
3422  """Flatten a tensor.
3423
3424  Args:
3425      x: A tensor or variable.
3426
3427  Returns:
      A tensor, reshaped into 1-D.
3429
3430  Example:
3431
3432      >>> b = tf.constant([[1, 2], [3, 4]])
3433      >>> b
3434      <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
3435      array([[1, 2],
3436             [3, 4]], dtype=int32)>
3437      >>> tf.keras.backend.flatten(b)
3438      <tf.Tensor: shape=(4,), dtype=int32,
3439          numpy=array([1, 2, 3, 4], dtype=int32)>
3440
3441  """
3442  return array_ops.reshape(x, [-1])
3443
3444
3445@keras_export('keras.backend.batch_flatten')
3446@dispatch.add_dispatch_support
3447@doc_controls.do_not_generate_docs
3448def batch_flatten(x):
  """Turn an nD tensor into a 2D tensor with the same 0th dimension.

  In other words, it flattens each data sample in a batch.
3452
3453  Args:
3454      x: A tensor or variable.
3455
3456  Returns:
3457      A tensor.
3458
3459  Examples:
    Flattening a 4D tensor to 2D by collapsing all dimensions after the first.
3461
3462  >>> x_batch = tf.keras.backend.ones(shape=(2, 3, 4, 5))
3463  >>> x_batch_flatten = batch_flatten(x_batch)
3464  >>> tf.keras.backend.int_shape(x_batch_flatten)
3465  (2, 60)
3466
3467  """
3468  x = array_ops.reshape(x, array_ops.stack([-1, prod(shape(x)[1:])]))
3469  return x
3470
3471
3472@keras_export('keras.backend.expand_dims')
3473@dispatch.add_dispatch_support
3474@doc_controls.do_not_generate_docs
3475def expand_dims(x, axis=-1):
3476  """Adds a 1-sized dimension at index "axis".
3477
3478  Args:
3479      x: A tensor or variable.
3480      axis: Position where to add a new axis.
3481
3482  Returns:
3483      A tensor with expanded dimensions.
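
  Example (an illustrative sketch, assuming eager execution):

      >>> x = tf.constant([1, 2, 3])
      >>> tf.keras.backend.expand_dims(x, axis=0)
      <tf.Tensor: shape=(1, 3), dtype=int32,
          numpy=array([[1, 2, 3]], dtype=int32)>
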
3484  """
3485  return array_ops.expand_dims(x, axis)
3486
3487
3488@keras_export('keras.backend.squeeze')
3489@dispatch.add_dispatch_support
3490@doc_controls.do_not_generate_docs
3491def squeeze(x, axis):
  """Removes a 1-sized dimension from the tensor at index "axis".
3493
3494  Args:
3495      x: A tensor or variable.
3496      axis: Axis to drop.
3497
3498  Returns:
3499      A tensor with the same data as `x` but reduced dimensions.
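
  Example (an illustrative sketch, assuming eager execution):

      >>> x = tf.constant([[1, 2, 3]])
      >>> tf.keras.backend.squeeze(x, 0)
      <tf.Tensor: shape=(3,), dtype=int32,
          numpy=array([1, 2, 3], dtype=int32)>
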
3500  """
3501  return array_ops.squeeze(x, [axis])
3502
3503
3504@keras_export('keras.backend.temporal_padding')
3505@dispatch.add_dispatch_support
3506@doc_controls.do_not_generate_docs
3507def temporal_padding(x, padding=(1, 1)):
3508  """Pads the middle dimension of a 3D tensor.
3509
3510  Args:
3511      x: Tensor or variable.
3512      padding: Tuple of 2 integers, how many zeros to
3513          add at the start and end of dim 1.
3514
3515  Returns:
3516      A padded 3D tensor.
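
  Example (an illustrative sketch, assuming eager execution):

      >>> x = tf.ones((2, 3, 4))
      >>> y = tf.keras.backend.temporal_padding(x, (1, 1))
      >>> tf.keras.backend.int_shape(y)
      (2, 5, 4)
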
3517  """
3518  assert len(padding) == 2
3519  pattern = [[0, 0], [padding[0], padding[1]], [0, 0]]
3520  return array_ops.pad(x, pattern)
3521
3522
3523@keras_export('keras.backend.spatial_2d_padding')
3524@dispatch.add_dispatch_support
3525@doc_controls.do_not_generate_docs
3526def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
3527  """Pads the 2nd and 3rd dimensions of a 4D tensor.
3528
3529  Args:
3530      x: Tensor or variable.
3531      padding: Tuple of 2 tuples, padding pattern.
3532      data_format: One of `channels_last` or `channels_first`.
3533
3534  Returns:
3535      A padded 4D tensor.
3536
3537  Raises:
      ValueError: if `data_format` is neither
          `channels_last` nor `channels_first`.
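
  Example (an illustrative sketch, assuming eager execution and the default
  `channels_last` data format):

      >>> x = tf.ones((1, 2, 2, 3))
      >>> y = tf.keras.backend.spatial_2d_padding(x)
      >>> tf.keras.backend.int_shape(y)
      (1, 4, 4, 3)
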
3540  """
3541  assert len(padding) == 2
3542  assert len(padding[0]) == 2
3543  assert len(padding[1]) == 2
3544  if data_format is None:
3545    data_format = image_data_format()
3546  if data_format not in {'channels_first', 'channels_last'}:
3547    raise ValueError('Unknown data_format: ' + str(data_format))
3548
3549  if data_format == 'channels_first':
3550    pattern = [[0, 0], [0, 0], list(padding[0]), list(padding[1])]
3551  else:
3552    pattern = [[0, 0], list(padding[0]), list(padding[1]), [0, 0]]
3553  return array_ops.pad(x, pattern)
3554
3555
3556@keras_export('keras.backend.spatial_3d_padding')
3557@dispatch.add_dispatch_support
3558@doc_controls.do_not_generate_docs
3559def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None):
3560  """Pads 5D tensor with zeros along the depth, height, width dimensions.
3561
3562  Pads these dimensions with respectively
3563  "padding[0]", "padding[1]" and "padding[2]" zeros left and right.
3564
3565  For 'channels_last' data_format,
3566  the 2nd, 3rd and 4th dimension will be padded.
3567  For 'channels_first' data_format,
3568  the 3rd, 4th and 5th dimension will be padded.
3569
3570  Args:
3571      x: Tensor or variable.
3572      padding: Tuple of 3 tuples, padding pattern.
3573      data_format: One of `channels_last` or `channels_first`.
3574
3575  Returns:
3576      A padded 5D tensor.
3577
3578  Raises:
      ValueError: if `data_format` is neither
          `channels_last` nor `channels_first`.
3581
3582  """
3583  assert len(padding) == 3
3584  assert len(padding[0]) == 2
3585  assert len(padding[1]) == 2
3586  assert len(padding[2]) == 2
3587  if data_format is None:
3588    data_format = image_data_format()
3589  if data_format not in {'channels_first', 'channels_last'}:
3590    raise ValueError('Unknown data_format: ' + str(data_format))
3591
3592  if data_format == 'channels_first':
3593    pattern = [[0, 0], [0, 0], [padding[0][0], padding[0][1]],
3594               [padding[1][0], padding[1][1]], [padding[2][0], padding[2][1]]]
3595  else:
3596    pattern = [[0, 0], [padding[0][0], padding[0][1]],
3597               [padding[1][0], padding[1][1]], [padding[2][0],
3598                                                padding[2][1]], [0, 0]]
3599  return array_ops.pad(x, pattern)
3600
3601
3602@keras_export('keras.backend.stack')
3603@dispatch.add_dispatch_support
3604@doc_controls.do_not_generate_docs
3605def stack(x, axis=0):
3606  """Stacks a list of rank `R` tensors into a rank `R+1` tensor.
3607
3608  Args:
3609      x: List of tensors.
3610      axis: Axis along which to perform stacking.
3611
3612  Returns:
3613      A tensor.
3614
3615  Example:
3616
3617      >>> a = tf.constant([[1, 2],[3, 4]])
3618      >>> b = tf.constant([[10, 20],[30, 40]])
3619      >>> tf.keras.backend.stack((a, b))
3620      <tf.Tensor: shape=(2, 2, 2), dtype=int32, numpy=
3621      array([[[ 1,  2],
3622              [ 3,  4]],
3623             [[10, 20],
3624              [30, 40]]], dtype=int32)>
3625
3626  """
3627  return array_ops.stack(x, axis=axis)
3628
3629
3630@keras_export('keras.backend.one_hot')
3631@dispatch.add_dispatch_support
3632@doc_controls.do_not_generate_docs
3633def one_hot(indices, num_classes):
3634  """Computes the one-hot representation of an integer tensor.
3635
3636  Args:
3637      indices: nD integer tensor of shape
3638          `(batch_size, dim1, dim2, ... dim(n-1))`
3639      num_classes: Integer, number of classes to consider.
3640
  Returns:
      (n + 1)D one-hot representation of the input,
      with shape `(batch_size, dim1, dim2, ... dim(n-1), num_classes)`.
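
  Example (an illustrative sketch, assuming eager execution):

      >>> indices = tf.constant([0, 1, 2])
      >>> tf.keras.backend.one_hot(indices, num_classes=3)
      <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
      array([[1., 0., 0.],
             [0., 1., 0.],
             [0., 0., 1.]], dtype=float32)>
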
3647  """
3648  return array_ops.one_hot(indices, depth=num_classes, axis=-1)
3649
3650
3651@keras_export('keras.backend.reverse')
3652@dispatch.add_dispatch_support
3653@doc_controls.do_not_generate_docs
3654def reverse(x, axes):
3655  """Reverse a tensor along the specified axes.
3656
3657  Args:
3658      x: Tensor to reverse.
3659      axes: Integer or iterable of integers.
3660          Axes to reverse.
3661
3662  Returns:
3663      A tensor.
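
  Example (an illustrative sketch, assuming eager execution):

      >>> x = tf.constant([1, 2, 3])
      >>> tf.keras.backend.reverse(x, 0)
      <tf.Tensor: shape=(3,), dtype=int32,
          numpy=array([3, 2, 1], dtype=int32)>
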
3664  """
3665  if isinstance(axes, int):
3666    axes = [axes]
3667  return array_ops.reverse(x, axes)
3668
3669
3670# VALUE MANIPULATION
3671_VALUE_SET_CODE_STRING = """
3672  >>> K = tf.keras.backend  # Common keras convention
3673  >>> v = K.variable(1.)
3674
3675  >>> # reassign
3676  >>> K.set_value(v, 2.)
3677  >>> print(K.get_value(v))
3678  2.0
3679
3680  >>> # increment
3681  >>> K.set_value(v, K.get_value(v) + 1)
3682  >>> print(K.get_value(v))
3683  3.0
3684
3685  Variable semantics in TensorFlow 2 are eager execution friendly. The above
3686  code is roughly equivalent to:
3687
3688  >>> v = tf.Variable(1.)
3689
3690  >>> v.assign(2.)
3691  >>> print(v.numpy())
3692  2.0
3693
3694  >>> v.assign_add(1.)
3695  >>> print(v.numpy())
3696  3.0"""[3:]  # Prune first newline and indent to match the docstring template.
3697
3698
3699@keras_export('keras.backend.get_value')
3700@doc_controls.do_not_generate_docs
3701def get_value(x):
3702  """Returns the value of a variable.
3703
3704  `backend.get_value` is the complement of `backend.set_value`, and provides
3705  a generic interface for reading from variables while abstracting away the
3706  differences between TensorFlow 1.x and 2.x semantics.
3707
3708  {snippet}
3709
3710  Args:
3711      x: input variable.
3712
3713  Returns:
3714      A Numpy array.
3715  """
3716  if not tensor_util.is_tf_type(x):
3717    return x
3718  if context.executing_eagerly() or isinstance(x, ops.EagerTensor):
3719    return x.numpy()
3720  if not getattr(x, '_in_graph_mode', True):
3721    # This is a variable which was created in an eager context, but is being
3722    # evaluated from a Graph.
3723    with context.eager_mode():
3724      return x.numpy()
3725
3726  if ops.executing_eagerly_outside_functions():
3727    # This method of evaluating works inside the Keras FuncGraph.
3728    with ops.init_scope():
3729      return x.numpy()
3730
3731  with x.graph.as_default():
3732    return x.eval(session=get_session((x,)))
3733
3734
3735@keras_export('keras.backend.batch_get_value')
3736@dispatch.add_dispatch_support
3737@doc_controls.do_not_generate_docs
3738def batch_get_value(tensors):
3739  """Returns the value of more than one tensor variable.
3740
3741  Args:
3742      tensors: list of ops to run.
3743
3744  Returns:
3745      A list of Numpy arrays.
3746
3747  Raises:
3748      RuntimeError: If this method is called inside defun.
3749  """
3750  if context.executing_eagerly():
3751    return [x.numpy() for x in tensors]
3752  elif ops.inside_function():  # pylint: disable=protected-access
    raise RuntimeError('Cannot get value inside TensorFlow graph function.')
3754  if tensors:
3755    return get_session(tensors).run(tensors)
3756  else:
3757    return []
3758
3759
3760@keras_export('keras.backend.set_value')
3761@doc_controls.do_not_generate_docs
3762def set_value(x, value):
3763  """Sets the value of a variable, from a Numpy array.
3764
3765  `backend.set_value` is the complement of `backend.get_value`, and provides
3766  a generic interface for assigning to variables while abstracting away the
3767  differences between TensorFlow 1.x and 2.x semantics.
3768
3769  {snippet}
3770
3771  Args:
3772      x: Variable to set to a new value.
3773      value: Value to set the tensor to, as a Numpy array
3774          (of the same shape).
3775  """
3776  value = np.asarray(value, dtype=dtype_numpy(x))
3777  if ops.executing_eagerly_outside_functions():
3778    x.assign(value)
3779  else:
3780    with get_graph().as_default():
3781      tf_dtype = dtypes_module.as_dtype(x.dtype.name.split('_')[0])
3782      if hasattr(x, '_assign_placeholder'):
3783        assign_placeholder = x._assign_placeholder
3784        assign_op = x._assign_op
3785      else:
3786        # In order to support assigning weights to resizable variables in
3787        # Keras, we make a placeholder with the correct number of dimensions
3788        # but with None in each dimension. This way, we can assign weights
3789        # of any size (as long as they have the correct dimensionality).
3790        placeholder_shape = tensor_shape.TensorShape([None] * value.ndim)
3791        assign_placeholder = array_ops.placeholder(
3792            tf_dtype, shape=placeholder_shape)
3793        assign_op = x.assign(assign_placeholder)
3794        x._assign_placeholder = assign_placeholder
3795        x._assign_op = assign_op
3796      get_session().run(assign_op, feed_dict={assign_placeholder: value})
3797
3798
3799@keras_export('keras.backend.batch_set_value')
3800@dispatch.add_dispatch_support
3801@doc_controls.do_not_generate_docs
3802def batch_set_value(tuples):
3803  """Sets the values of many tensor variables at once.
3804
3805  Args:
3806      tuples: a list of tuples `(tensor, value)`.
3807          `value` should be a Numpy array.
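
  Example (an illustrative sketch, assuming eager execution):

      >>> v1 = tf.Variable(1.)
      >>> v2 = tf.Variable(2.)
      >>> tf.keras.backend.batch_set_value([(v1, 3.), (v2, 4.)])
      >>> print(tf.keras.backend.batch_get_value([v1, v2]))
      [3.0, 4.0]
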
3808  """
3809  if ops.executing_eagerly_outside_functions():
3810    for x, value in tuples:
3811      x.assign(np.asarray(value, dtype=dtype_numpy(x)))
3812  else:
3813    with get_graph().as_default():
3814      if tuples:
3815        assign_ops = []
3816        feed_dict = {}
3817        for x, value in tuples:
3818          value = np.asarray(value, dtype=dtype_numpy(x))
3819          tf_dtype = dtypes_module.as_dtype(x.dtype.name.split('_')[0])
3820          if hasattr(x, '_assign_placeholder'):
3821            assign_placeholder = x._assign_placeholder
3822            assign_op = x._assign_op
3823          else:
3824            # In order to support assigning weights to resizable variables in
3825            # Keras, we make a placeholder with the correct number of dimensions
3826            # but with None in each dimension. This way, we can assign weights
3827            # of any size (as long as they have the correct dimensionality).
3828            placeholder_shape = tensor_shape.TensorShape([None] * value.ndim)
3829            assign_placeholder = array_ops.placeholder(
3830                tf_dtype, shape=placeholder_shape)
3831            assign_op = x.assign(assign_placeholder)
3832            x._assign_placeholder = assign_placeholder
3833            x._assign_op = assign_op
3834          assign_ops.append(assign_op)
3835          feed_dict[assign_placeholder] = value
3836        get_session().run(assign_ops, feed_dict=feed_dict)
3837
3838
3839get_value.__doc__ = get_value.__doc__.format(snippet=_VALUE_SET_CODE_STRING)
3840set_value.__doc__ = set_value.__doc__.format(snippet=_VALUE_SET_CODE_STRING)
3841
3842
3843@keras_export('keras.backend.print_tensor')
3844@dispatch.add_dispatch_support
3845@doc_controls.do_not_generate_docs
3846def print_tensor(x, message='', summarize=3):
3847  """Prints `message` and the tensor value when evaluated.
3848
3849  Note that `print_tensor` returns a new tensor identical to `x`
3850  which should be used in the following code. Otherwise the
3851  print operation is not taken into account during evaluation.
3852
3853  Example:
3854
3855  >>> x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
3856  >>> tf.keras.backend.print_tensor(x)
3857  <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
3858    array([[1., 2.],
3859           [3., 4.]], dtype=float32)>
3860
3861  Args:
3862      x: Tensor to print.
3863      message: Message to print jointly with the tensor.
3864      summarize: The first and last `summarize` elements within each dimension
3865          are recursively printed per Tensor. If None, then the first 3 and last
3866          3 elements of each dimension are printed for each tensor. If set to
3867          -1, it will print all elements of every tensor.
3868
3869  Returns:
3870      The same tensor `x`, unchanged.
3871  """
3872  if isinstance(x, ops.Tensor) and hasattr(x, 'graph'):
3873    with get_graph().as_default():
3874      op = logging_ops.print_v2(
3875          message, x, output_stream=sys.stdout, summarize=summarize)
3876      with ops.control_dependencies([op]):
3877        return array_ops.identity(x)
3878  else:
3879    logging_ops.print_v2(
3880        message, x, output_stream=sys.stdout, summarize=summarize)
3881    return x
3882
3883# GRAPH MANIPULATION
3884
3885
3886class GraphExecutionFunction:
3887  """Runs a computation graph.
3888
  It's possible to pass arguments to `tf.Session.run()` via `session_kwargs`.
  In particular, additional operations can be run via the `fetches` argument,
  and additional tensor substitutions can be provided via the `feed_dict`
  argument. Note that the given substitutions are merged with substitutions
  from `inputs`. Even though `feed_dict` is passed once in the constructor
  (called in `model.compile()`), its values can be modified later. Through
  this `feed_dict` we can provide additional substitutions besides Keras
  inputs.
3896
3897  Args:
3898      inputs: Feed placeholders to the computation graph.
3899      outputs: Output tensors to fetch.
3900      updates: Additional update ops to be run at function call.
3901      name: A name to help users identify what this function does.
3902      session_kwargs: Arguments to `tf.Session.run()`:
3903                      `fetches`, `feed_dict`, `options`, `run_metadata`.
3904  """
3905
3906  def __init__(self, inputs, outputs, updates=None, name=None,
3907               **session_kwargs):
3908    updates = updates or []
3909    if not isinstance(updates, (list, tuple)):
3910      raise TypeError('`updates` in a Keras backend function '
3911                      'should be a list or tuple.')
3912
3913    self._inputs_structure = inputs
3914    self.inputs = nest.flatten(inputs, expand_composites=True)
3915    self._outputs_structure = outputs
3916    self.outputs = cast_variables_to_tensor(
3917        nest.flatten(outputs, expand_composites=True))
3918    # TODO(b/127668432): Consider using autograph to generate these
3919    # dependencies in call.
3920    # Index 0 = total loss or model output for `predict`.
3921    with ops.control_dependencies([self.outputs[0]]):
3922      updates_ops = []
3923      for update in updates:
3924        if isinstance(update, tuple):
3925          p, new_p = update
3926          updates_ops.append(state_ops.assign(p, new_p))
3927        else:
3928          # assumed already an op
3929          updates_ops.append(update)
3930      self.updates_op = control_flow_ops.group(*updates_ops)
3931    self.name = name
3932    # additional tensor substitutions
3933    self.feed_dict = session_kwargs.pop('feed_dict', None)
3934    # additional operations
3935    self.fetches = session_kwargs.pop('fetches', [])
3936    if not isinstance(self.fetches, list):
3937      self.fetches = [self.fetches]
3938    self.run_options = session_kwargs.pop('options', None)
3939    self.run_metadata = session_kwargs.pop('run_metadata', None)
3940    # The main use case of `fetches` being passed to a model is the ability
3941    # to run custom updates
3942    # This requires us to wrap fetches in `identity` ops.
3943    self.fetches = [array_ops.identity(x) for x in self.fetches]
3944    self.session_kwargs = session_kwargs
3945    # This mapping keeps track of the function that should receive the
3946    # output from a fetch in `fetches`: { fetch: function(fetch_output) }
3947    # A Callback can use this to register a function with access to the
3948    # output values for a fetch it added.
3949    self.fetch_callbacks = {}
3950
3951    if session_kwargs:
3952      raise ValueError('Some keys in session_kwargs are not supported at this '
3953                       'time: %s' % (session_kwargs.keys(),))
3954
3955    self._callable_fn = None
3956    self._feed_arrays = None
3957    self._feed_symbols = None
3958    self._symbol_vals = None
3959    self._fetches = None
3960    self._session = None
3961
3962  def _make_callable(self, feed_arrays, feed_symbols, symbol_vals, session):
3963    """Generates a callable that runs the graph.
3964
3965    Args:
3966      feed_arrays: List of input tensors to be fed Numpy arrays at runtime.
3967      feed_symbols: List of input tensors to be fed symbolic tensors at runtime.
3968      symbol_vals: List of symbolic tensors to be fed to `feed_symbols`.
3969      session: Session to use to generate the callable.
3970
3971    Returns:
3972      Function that runs the graph according to the above options.
3973    """
3974    # Prepare callable options.
3975    callable_opts = config_pb2.CallableOptions()
3976    # Handle external-data feed.
3977    for x in feed_arrays:
3978      callable_opts.feed.append(x.name)
3979    if self.feed_dict:
3980      for key in sorted(self.feed_dict.keys()):
3981        callable_opts.feed.append(key.name)
3982    # Handle symbolic feed.
3983    for x, y in zip(feed_symbols, symbol_vals):
3984      connection = callable_opts.tensor_connection.add()
3985      if x.dtype != y.dtype:
3986        y = math_ops.cast(y, dtype=x.dtype)
3987      from_tensor = _as_graph_element(y)
3988      if from_tensor is None:
3989        from_tensor = y
3990      connection.from_tensor = from_tensor.name  # Data tensor
3991      connection.to_tensor = x.name  # Placeholder
3992    # Handle fetches.
3993    for x in self.outputs + self.fetches:
3994      callable_opts.fetch.append(x.name)
3995    # Handle updates.
3996    callable_opts.target.append(self.updates_op.name)
3997    # Handle run_options.
3998    if self.run_options:
3999      callable_opts.run_options.CopyFrom(self.run_options)
4000    # Create callable.
4001    callable_fn = session._make_callable_from_options(callable_opts)
4002    # Cache parameters corresponding to the generated callable, so that
4003    # we can detect future mismatches and refresh the callable.
4004    self._callable_fn = callable_fn
4005    self._feed_arrays = feed_arrays
4006    self._feed_symbols = feed_symbols
4007    self._symbol_vals = symbol_vals
4008    self._fetches = list(self.fetches)
4009    self._session = session
4010
4011  def _call_fetch_callbacks(self, fetches_output):
4012    for fetch, output in zip(self._fetches, fetches_output):
4013      if fetch in self.fetch_callbacks:
4014        self.fetch_callbacks[fetch](output)
4015
4016  def _eval_if_composite(self, tensor):
4017    """Helper method which evaluates any CompositeTensors passed to it."""
4018    # We need to evaluate any composite tensor objects that have been
4019    # reconstructed in 'pack_sequence_as', since otherwise they'll be output as
4020    # actual CompositeTensor objects instead of the value(s) contained in the
4021    # CompositeTensors. E.g., if output_structure contains a SparseTensor, then
4022    # this ensures that we return its value as a SparseTensorValue rather than
4023    # a SparseTensor.
4024    from tensorflow.python.keras.utils import tf_utils  # pylint: disable=g-import-not-at-top
4025    if tf_utils.is_extension_type(tensor):
4026      return self._session.run(tensor)
4027    else:
4028      return tensor
4029
4030  def __call__(self, inputs):
4031    inputs = nest.flatten(inputs, expand_composites=True)
4032
4033    session = get_session(inputs)
4034    feed_arrays = []
4035    array_vals = []
4036    feed_symbols = []
4037    symbol_vals = []
4038    for tensor, value in zip(self.inputs, inputs):
4039      if value is None:
4040        continue
4041
4042      if tensor_util.is_tf_type(value):
4043        # Case: feeding symbolic tensor.
4044        feed_symbols.append(tensor)
4045        symbol_vals.append(value)
4046      else:
4047        # Case: feeding Numpy array.
4048        feed_arrays.append(tensor)
4049        # We need to do array conversion and type casting at this level, since
4050        # `callable_fn` only supports exact matches.
4051        tensor_type = dtypes_module.as_dtype(tensor.dtype)
4052        array_vals.append(np.asarray(value,
4053                                     dtype=tensor_type.as_numpy_dtype))
4054
4055    if self.feed_dict:
4056      for key in sorted(self.feed_dict.keys()):
4057        array_vals.append(
4058            np.asarray(self.feed_dict[key], dtype=key.dtype.as_numpy_dtype))
4059
4060    # Refresh callable if anything has changed.
4061    if (self._callable_fn is None or feed_arrays != self._feed_arrays or
4062        symbol_vals != self._symbol_vals or
4063        feed_symbols != self._feed_symbols or self.fetches != self._fetches or
4064        session != self._session):
4065      self._make_callable(feed_arrays, feed_symbols, symbol_vals, session)
4066
4067    fetched = self._callable_fn(*array_vals,
4068                                run_metadata=self.run_metadata)
4069    self._call_fetch_callbacks(fetched[-len(self._fetches):])
4070    output_structure = nest.pack_sequence_as(
4071        self._outputs_structure,
4072        fetched[:len(self.outputs)],
4073        expand_composites=True)
4074    # We need to evaluate any composite tensor objects that have been
4075    # reconstructed in 'pack_sequence_as', since otherwise they'll be output as
4076    # actual CompositeTensor objects instead of the value(s) contained in the
4077    # CompositeTensors. E.g., if output_structure contains a SparseTensor, then
4078    # this ensures that we return its value as a SparseTensorValue rather than
4079    # a SparseTensor.
4080    return nest.map_structure(self._eval_if_composite, output_structure)
4081
4082
4083@keras_export('keras.backend.function')
4084@doc_controls.do_not_generate_docs
4085def function(inputs, outputs, updates=None, name=None, **kwargs):
4086  """Instantiates a Keras function.
4087
4088  Args:
4089      inputs: List of placeholder tensors.
4090      outputs: List of output tensors.
4091      updates: List of update ops.
4092      name: String, name of function.
4093      **kwargs: Passed to `tf.Session.run`.
4094
4095  Returns:
4096      Output values as Numpy arrays.
4097
4098  Raises:
4099      ValueError: if invalid kwargs are passed in or if in eager execution.
4100  """
4101  if ops.executing_eagerly_outside_functions():
4102    if kwargs:
4103      raise ValueError('Session keyword arguments are not supported during '
4104                       'eager execution. You passed: %s' % (kwargs,))
4105    if updates:
4106      raise ValueError('`updates` argument is not supported during '
4107                       'eager execution. You passed: %s' % (updates,))
4108    from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top
4109    from tensorflow.python.keras.utils import tf_utils  # pylint: disable=g-import-not-at-top
4110    model = models.Model(inputs=inputs, outputs=outputs)
4111
4112    wrap_outputs = isinstance(outputs, list) and len(outputs) == 1
4113    def func(model_inputs):
4114      outs = model(model_inputs)
4115      if wrap_outputs:
4116        outs = [outs]
4117      return tf_utils.sync_to_numpy_or_python_type(outs)
4118
4119    return func
4120
4121  if kwargs:
4122    for key in kwargs:
4123      if (key not in tf_inspect.getfullargspec(session_module.Session.run)[0]
4124          and key not in ['inputs', 'outputs', 'updates', 'name']):
4125        msg = ('Invalid argument "%s" passed to K.function with TensorFlow '
4126               'backend') % key
4127        raise ValueError(msg)
4128  return GraphExecutionFunction(
4129      inputs, outputs, updates=updates, name=name, **kwargs)
4130
4131
4132@keras_export('keras.backend.gradients')
4133@doc_controls.do_not_generate_docs
4134def gradients(loss, variables):
4135  """Returns the gradients of `loss` w.r.t. `variables`.
4136
4137  Args:
4138      loss: Scalar tensor to minimize.
4139      variables: List of variables.
4140
4141  Returns:
4142      A gradients tensor.
4143  """
4144  return gradients_module.gradients(
4145      loss, variables, colocate_gradients_with_ops=True)
4146
4147
4148@keras_export('keras.backend.stop_gradient')
4149@dispatch.add_dispatch_support
4150@doc_controls.do_not_generate_docs
4151def stop_gradient(variables):
4152  """Returns `variables` but with zero gradient w.r.t. every other variable.
4153
4154  Args:
4155      variables: Tensor or list of tensors to consider constant with respect
4156        to any other variable.
4157
4158
4159  Returns:
4160      A single tensor or a list of tensors (depending on the passed argument)
4161      that has no gradient with respect to any other variable.
4162  """
  if isinstance(variables, (list, tuple)):
    # Return a list rather than a lazy `map` iterator, so the result matches
    # the documented "list of tensors" return type.
    return [array_ops.stop_gradient(v) for v in variables]
4165  return array_ops.stop_gradient(variables)
4166
4167
4168# CONTROL FLOW
4169
4170
4171@keras_export('keras.backend.rnn')
4172@dispatch.add_dispatch_support
4173def rnn(step_function,
4174        inputs,
4175        initial_states,
4176        go_backwards=False,
4177        mask=None,
4178        constants=None,
4179        unroll=False,
4180        input_length=None,
4181        time_major=False,
4182        zero_output_for_mask=False):
4183  """Iterates over the time dimension of a tensor.
4184
4185  Args:
4186      step_function: RNN step function.
4187          Args;
4188              input; Tensor with shape `(samples, ...)` (no time dimension),
4189                  representing input for the batch of samples at a certain
4190                  time step.
4191              states; List of tensors.
4192          Returns;
4193              output; Tensor with shape `(samples, output_dim)`
4194                  (no time dimension).
4195              new_states; List of tensors, same length and shapes
4196                  as 'states'. The first state in the list must be the
4197                  output tensor at the previous timestep.
4198      inputs: Tensor of temporal data of shape `(samples, time, ...)`
4199          (at least 3D), or nested tensors, and each of which has shape
4200          `(samples, time, ...)`.
4201      initial_states: Tensor with shape `(samples, state_size)`
4202          (no time dimension), containing the initial values for the states used
4203          in the step function. In the case that state_size is in a nested
4204          shape, the shape of initial_states will also follow the nested
4205          structure.
4206      go_backwards: Boolean. If True, do the iteration over the time
4207          dimension in reverse order and return the reversed sequence.
4208      mask: Binary tensor with shape `(samples, time, 1)`,
4209          with a zero for every element that is masked.
4210      constants: List of constant values passed at each step.
4211      unroll: Whether to unroll the RNN or to use a symbolic `while_loop`.
4212      input_length: An integer or a 1-D Tensor, depending on whether
4213          the time dimension is fixed-length or not. In case of variable length
4214          input, it is used for masking in case there's no mask specified.
4215      time_major: Boolean. If true, the inputs and outputs will be in shape
4216          `(timesteps, batch, ...)`, whereas in the False case, it will be
4217          `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
4218          efficient because it avoids transposes at the beginning and end of the
4219          RNN calculation. However, most TensorFlow data is batch-major, so by
4220          default this function accepts input and emits output in batch-major
4221          form.
4222      zero_output_for_mask: Boolean. If True, the output for masked timestep
4223          will be zeros, whereas in the False case, output from previous
4224          timestep is returned.
4225
4226  Returns:
4227      A tuple, `(last_output, outputs, new_states)`.
4228          last_output: the latest output of the rnn, of shape `(samples, ...)`
4229          outputs: tensor with shape `(samples, time, ...)` where each
4230              entry `outputs[s, t]` is the output of the step function
4231              at time `t` for sample `s`.
4232          new_states: list of tensors, latest states returned by
4233              the step function, of shape `(samples, ...)`.
4234
4235  Raises:
4236      ValueError: if input dimension is less than 3.
4237      ValueError: if `unroll` is `True` but input timestep is not a fixed
4238      number.
4239      ValueError: if `mask` is provided (not `None`) but states is not provided
4240          (`len(states)` == 0).
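
  Example (a minimal illustrative sketch, assuming eager execution; the
  `step_function` below is a toy step that adds the previous state to the
  input, so only the returned shapes are checked):

      >>> inputs = tf.ones((2, 3, 4))  # (samples, timesteps, features)
      >>> initial_states = [tf.zeros((2, 4))]
      >>> def step_function(inp, states):
      ...   output = inp + states[0]
      ...   return output, [output]
      >>> last_output, outputs, new_states = tf.keras.backend.rnn(
      ...     step_function, inputs, initial_states)
      >>> tf.keras.backend.int_shape(outputs)
      (2, 3, 4)
      >>> tf.keras.backend.int_shape(last_output)
      (2, 4)
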
4241  """
4242
4243  def swap_batch_timestep(input_t):
4244    # Swap the batch and timestep dim for the incoming tensor.
4245    axes = list(range(len(input_t.shape)))
4246    axes[0], axes[1] = 1, 0
4247    return array_ops.transpose(input_t, axes)
4248
4249  if not time_major:
4250    inputs = nest.map_structure(swap_batch_timestep, inputs)
4251
4252  flatted_inputs = nest.flatten(inputs)
4253  time_steps = flatted_inputs[0].shape[0]
4254  batch = flatted_inputs[0].shape[1]
4255  time_steps_t = array_ops.shape(flatted_inputs[0])[0]
4256
4257  for input_ in flatted_inputs:
4258    input_.shape.with_rank_at_least(3)
4259
4260  if mask is not None:
4261    if mask.dtype != dtypes_module.bool:
4262      mask = math_ops.cast(mask, dtypes_module.bool)
4263    if len(mask.shape) == 2:
4264      mask = expand_dims(mask)
4265    if not time_major:
4266      mask = swap_batch_timestep(mask)
4267
4268  if constants is None:
4269    constants = []
4270
4271  # tf.where needs its condition tensor to be the same shape as its two
4272  # result tensors, but in our case the condition (mask) tensor is
4273  # (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
4274  # So we need to broadcast the mask to match the shape of inputs.
4275  # That's what the tile call does, it just repeats the mask along its
4276  # second dimension n times.
4277  def _expand_mask(mask_t, input_t, fixed_dim=1):
4278    if nest.is_nested(mask_t):
4279      raise ValueError('mask_t is expected to be tensor, but got %s' % mask_t)
4280    if nest.is_nested(input_t):
4281      raise ValueError('input_t is expected to be tensor, but got %s' % input_t)
4282    rank_diff = len(input_t.shape) - len(mask_t.shape)
4283    for _ in range(rank_diff):
4284      mask_t = array_ops.expand_dims(mask_t, -1)
4285    multiples = [1] * fixed_dim + input_t.shape.as_list()[fixed_dim:]
4286    return array_ops.tile(mask_t, multiples)
4287
4288  if unroll:
4289    if not time_steps:
4290      raise ValueError('Unrolling requires a fixed number of timesteps.')
4291    states = tuple(initial_states)
4292    successive_states = []
4293    successive_outputs = []
4294
    # Process the input tensors. The input tensor needs to be split on the
    # time_step dim, and reversed if go_backwards is True. In the case of
    # nested input, the input is flattened and then transformed individually.
    # The result of this will be a tuple of lists; each item in the tuple is
    # a list of tensors with shape (batch, feature).
4300    def _process_single_input_t(input_t):
4301      input_t = array_ops.unstack(input_t)  # unstack for time_step dim
4302      if go_backwards:
4303        input_t.reverse()
4304      return input_t
4305
4306    if nest.is_nested(inputs):
4307      processed_input = nest.map_structure(_process_single_input_t, inputs)
4308    else:
4309      processed_input = (_process_single_input_t(inputs),)
4310
4311    def _get_input_tensor(time):
4312      inp = [t_[time] for t_ in processed_input]
4313      return nest.pack_sequence_as(inputs, inp)
4314
4315    if mask is not None:
4316      mask_list = array_ops.unstack(mask)
4317      if go_backwards:
4318        mask_list.reverse()
4319
4320      for i in range(time_steps):
4321        inp = _get_input_tensor(i)
4322        mask_t = mask_list[i]
4323        output, new_states = step_function(inp,
4324                                           tuple(states) + tuple(constants))
4325        tiled_mask_t = _expand_mask(mask_t, output)
4326
4327        if not successive_outputs:
4328          prev_output = zeros_like(output)
4329        else:
4330          prev_output = successive_outputs[-1]
4331
4332        output = array_ops.where_v2(tiled_mask_t, output, prev_output)
4333
4334        flat_states = nest.flatten(states)
4335        flat_new_states = nest.flatten(new_states)
4336        tiled_mask_t = tuple(_expand_mask(mask_t, s) for s in flat_states)
4337        flat_final_states = tuple(
4338            array_ops.where_v2(m, s, ps)
4339            for m, s, ps in zip(tiled_mask_t, flat_new_states, flat_states))
4340        states = nest.pack_sequence_as(states, flat_final_states)
4341
4342        successive_outputs.append(output)
4343        successive_states.append(states)
4344      last_output = successive_outputs[-1]
4345      new_states = successive_states[-1]
4346      outputs = array_ops.stack(successive_outputs)
4347
4348      if zero_output_for_mask:
4349        last_output = array_ops.where_v2(
4350            _expand_mask(mask_list[-1], last_output), last_output,
4351            zeros_like(last_output))
4352        outputs = array_ops.where_v2(
4353            _expand_mask(mask, outputs, fixed_dim=2), outputs,
4354            zeros_like(outputs))
4355
4356    else:  # mask is None
4357      for i in range(time_steps):
4358        inp = _get_input_tensor(i)
4359        output, states = step_function(inp, tuple(states) + tuple(constants))
4360        successive_outputs.append(output)
4361        successive_states.append(states)
4362      last_output = successive_outputs[-1]
4363      new_states = successive_states[-1]
4364      outputs = array_ops.stack(successive_outputs)
4365
4366  else:  # Unroll == False
4367    states = tuple(initial_states)
4368
4369    # Create input tensor array, if the inputs is nested tensors, then it will
4370    # be flattened first, and tensor array will be created one per flattened
4371    # tensor.
4372    input_ta = tuple(
4373        tensor_array_ops.TensorArray(
4374            dtype=inp.dtype,
4375            size=time_steps_t,
4376            tensor_array_name='input_ta_%s' % i)
4377        for i, inp in enumerate(flatted_inputs))
4378    input_ta = tuple(
4379        ta.unstack(input_) if not go_backwards else ta
4380        .unstack(reverse(input_, 0))
4381        for ta, input_ in zip(input_ta, flatted_inputs))
4382
    # Get the time(0) input and compute the output for that; the output will
    # be used to determine the dtype of the output tensor array. Don't read
    # from input_ta because TensorArray's clear_after_read defaults to True.
4386    input_time_zero = nest.pack_sequence_as(inputs,
4387                                            [inp[0] for inp in flatted_inputs])
4388    # output_time_zero is used to determine the cell output shape and its dtype.
4389    # the value is discarded.
4390    output_time_zero, _ = step_function(
4391        input_time_zero, tuple(initial_states) + tuple(constants))
4392    output_ta = tuple(
4393        tensor_array_ops.TensorArray(
4394            dtype=out.dtype,
4395            size=time_steps_t,
4396            element_shape=out.shape,
4397            tensor_array_name='output_ta_%s' % i)
4398        for i, out in enumerate(nest.flatten(output_time_zero)))
4399
4400    time = constant_op.constant(0, dtype='int32', name='time')
4401
4402    # We only specify the 'maximum_iterations' when building for XLA since that
4403    # causes slowdowns on GPU in TF.
4404    if (not context.executing_eagerly() and
4405        control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph())):
4406      max_iterations = math_ops.reduce_max(input_length)
4407    else:
4408      max_iterations = None
4409
4410    while_loop_kwargs = {
4411        'cond': lambda time, *_: time < time_steps_t,
4412        'maximum_iterations': max_iterations,
4413        'parallel_iterations': 32,
4414        'swap_memory': True,
4415    }
4416    if mask is not None:
4417      if go_backwards:
4418        mask = reverse(mask, 0)
4419
4420      mask_ta = tensor_array_ops.TensorArray(
4421          dtype=dtypes_module.bool,
4422          size=time_steps_t,
4423          tensor_array_name='mask_ta')
4424      mask_ta = mask_ta.unstack(mask)
4425
4426      def masking_fn(time):
4427        return mask_ta.read(time)
4428
4429      def compute_masked_output(mask_t, flat_out, flat_mask):
4430        tiled_mask_t = tuple(
4431            _expand_mask(mask_t, o, fixed_dim=len(mask_t.shape))
4432            for o in flat_out)
4433        return tuple(
4434            array_ops.where_v2(m, o, fm)
4435            for m, o, fm in zip(tiled_mask_t, flat_out, flat_mask))
4436    elif isinstance(input_length, ops.Tensor):
4437      if go_backwards:
4438        max_len = math_ops.reduce_max(input_length, axis=0)
4439        rev_input_length = math_ops.subtract(max_len - 1, input_length)
4440
4441        def masking_fn(time):
4442          return math_ops.less(rev_input_length, time)
4443      else:
4444
4445        def masking_fn(time):
4446          return math_ops.greater(input_length, time)
4447
4448      def compute_masked_output(mask_t, flat_out, flat_mask):
4449        return tuple(
4450            array_ops.where(mask_t, o, zo)
4451            for (o, zo) in zip(flat_out, flat_mask))
4452    else:
4453      masking_fn = None
4454
4455    if masking_fn is not None:
      # Mask for the T output will be based on the output at T - 1. In the
      # case T = 0, a zero-filled tensor will be used.
4458      flat_zero_output = tuple(array_ops.zeros_like(o)
4459                               for o in nest.flatten(output_time_zero))
4460      def _step(time, output_ta_t, prev_output, *states):
4461        """RNN step function.
4462
4463        Args:
4464            time: Current timestep value.
4465            output_ta_t: TensorArray.
4466            prev_output: tuple of outputs from time - 1.
4467            *states: List of states.
4468
4469        Returns:
4470            Tuple: `(time + 1, output_ta_t, output) + tuple(new_states)`
4471        """
4472        current_input = tuple(ta.read(time) for ta in input_ta)
4473        # maybe set shape.
4474        current_input = nest.pack_sequence_as(inputs, current_input)
4475        mask_t = masking_fn(time)
4476        output, new_states = step_function(current_input,
4477                                           tuple(states) + tuple(constants))
4478        # mask output
4479        flat_output = nest.flatten(output)
4480        flat_mask_output = (flat_zero_output if zero_output_for_mask
4481                            else nest.flatten(prev_output))
4482        flat_new_output = compute_masked_output(mask_t, flat_output,
4483                                                flat_mask_output)
4484
4485        # mask states
4486        flat_state = nest.flatten(states)
4487        flat_new_state = nest.flatten(new_states)
4488        for state, new_state in zip(flat_state, flat_new_state):
4489          if isinstance(new_state, ops.Tensor):
4490            new_state.set_shape(state.shape)
4491        flat_final_state = compute_masked_output(mask_t, flat_new_state,
4492                                                 flat_state)
4493        new_states = nest.pack_sequence_as(new_states, flat_final_state)
4494
4495        output_ta_t = tuple(
4496            ta.write(time, out)
4497            for ta, out in zip(output_ta_t, flat_new_output))
4498        return (time + 1, output_ta_t,
4499                tuple(flat_new_output)) + tuple(new_states)
4500
4501      final_outputs = control_flow_ops.while_loop(
4502          body=_step,
4503          loop_vars=(time, output_ta, flat_zero_output) + states,
4504          **while_loop_kwargs)
4505      # Skip final_outputs[2] which is the output for final timestep.
4506      new_states = final_outputs[3:]
4507    else:
4508      def _step(time, output_ta_t, *states):
4509        """RNN step function.
4510
4511        Args:
4512            time: Current timestep value.
4513            output_ta_t: TensorArray.
4514            *states: List of states.
4515
4516        Returns:
            Tuple: `(time + 1, output_ta_t) + tuple(new_states)`
4518        """
4519        current_input = tuple(ta.read(time) for ta in input_ta)
4520        current_input = nest.pack_sequence_as(inputs, current_input)
4521        output, new_states = step_function(current_input,
4522                                           tuple(states) + tuple(constants))
4523        flat_state = nest.flatten(states)
4524        flat_new_state = nest.flatten(new_states)
4525        for state, new_state in zip(flat_state, flat_new_state):
4526          if isinstance(new_state, ops.Tensor):
4527            new_state.set_shape(state.shape)
4528
4529        flat_output = nest.flatten(output)
4530        output_ta_t = tuple(
4531            ta.write(time, out) for ta, out in zip(output_ta_t, flat_output))
4532        new_states = nest.pack_sequence_as(initial_states, flat_new_state)
4533        return (time + 1, output_ta_t) + tuple(new_states)
4534
4535      final_outputs = control_flow_ops.while_loop(
4536          body=_step,
4537          loop_vars=(time, output_ta) + states,
4538          **while_loop_kwargs)
4539      new_states = final_outputs[2:]
4540
4541    output_ta = final_outputs[1]
4542
4543    outputs = tuple(o.stack() for o in output_ta)
4544    last_output = tuple(o[-1] for o in outputs)
4545
4546    outputs = nest.pack_sequence_as(output_time_zero, outputs)
4547    last_output = nest.pack_sequence_as(output_time_zero, last_output)
4548
4549  # static shape inference
4550  def set_shape(output_):
4551    if isinstance(output_, ops.Tensor):
4552      shape = output_.shape.as_list()
4553      shape[0] = time_steps
4554      shape[1] = batch
4555      output_.set_shape(shape)
4556    return output_
4557
4558  outputs = nest.map_structure(set_shape, outputs)
4559
4560  if not time_major:
4561    outputs = nest.map_structure(swap_batch_timestep, outputs)
4562
4563  return last_output, outputs, new_states
4564
4565
4566@keras_export('keras.backend.switch')
4567@dispatch.add_dispatch_support
4568@doc_controls.do_not_generate_docs
4569def switch(condition, then_expression, else_expression):
4570  """Switches between two operations depending on a scalar value.
4571
4572  Note that both `then_expression` and `else_expression`
4573  should be symbolic tensors of the *same shape*.
4574
4575  Args:
4576      condition: tensor (`int` or `bool`).
4577      then_expression: either a tensor, or a callable that returns a tensor.
4578      else_expression: either a tensor, or a callable that returns a tensor.
4579
4580  Returns:
4581      The selected tensor.
4582
4583  Raises:
4584      ValueError: If rank of `condition` is greater than rank of expressions.
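
  Example (an illustrative sketch, assuming eager execution):

      >>> cond = tf.constant(True)
      >>> a = tf.constant([1, 2])
      >>> b = tf.constant([3, 4])
      >>> tf.keras.backend.switch(cond, a, b)
      <tf.Tensor: shape=(2,), dtype=int32,
          numpy=array([1, 2], dtype=int32)>
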
4585  """
4586  if condition.dtype != dtypes_module.bool:
4587    condition = math_ops.cast(condition, 'bool')
4588  cond_ndim = ndim(condition)
4589  if not cond_ndim:
4590    if not callable(then_expression):
4591
4592      def then_expression_fn():
4593        return then_expression
4594    else:
4595      then_expression_fn = then_expression
4596    if not callable(else_expression):
4597
4598      def else_expression_fn():
4599        return else_expression
4600    else:
4601      else_expression_fn = else_expression
4602    x = control_flow_ops.cond(condition, then_expression_fn, else_expression_fn)
4603  else:
4604    # tf.where needs its condition tensor
4605    # to be the same shape as its two
4606    # result tensors
4607    if callable(then_expression):
4608      then_expression = then_expression()
4609    if callable(else_expression):
4610      else_expression = else_expression()
4611    expr_ndim = ndim(then_expression)
4612    if cond_ndim > expr_ndim:
4613      raise ValueError('Rank of `condition` should be less than or'
4614                       ' equal to rank of `then_expression` and '
4615                       '`else_expression`. ndim(condition)=' + str(cond_ndim) +
4616                       ', ndim(then_expression)'
4617                       '=' + str(expr_ndim))
4618    if cond_ndim > 1:
4619      ndim_diff = expr_ndim - cond_ndim
4620      cond_shape = array_ops.concat(
4621          [array_ops.shape(condition), [1] * ndim_diff], axis=0)
4622      condition = array_ops.reshape(condition, cond_shape)
4623      expr_shape = array_ops.shape(then_expression)
4624      shape_diff = expr_shape - cond_shape
4625      tile_shape = array_ops.where_v2(shape_diff > 0, expr_shape,
4626                                      array_ops.ones_like(expr_shape))
4627      condition = array_ops.tile(condition, tile_shape)
4628    x = array_ops.where_v2(condition, then_expression, else_expression)
4629  return x
4630
4631
4632@keras_export('keras.backend.in_train_phase')
4633@doc_controls.do_not_generate_docs
4634def in_train_phase(x, alt, training=None):
4635  """Selects `x` in train phase, and `alt` otherwise.
4636
4637  Note that `alt` should have the *same shape* as `x`.
4638
4639  Args:
4640      x: What to return in train phase
4641          (tensor or callable that returns a tensor).
4642      alt: What to return otherwise
4643          (tensor or callable that returns a tensor).
4644      training: Optional scalar tensor
4645          (or Python boolean, or Python integer)
4646          specifying the learning phase.
4647
4648  Returns:
4649      Either `x` or `alt` based on the `training` flag.
      The `training` flag defaults to `K.learning_phase()`.
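
  Example (an illustrative sketch, assuming eager execution and an explicit
  `training` flag):

      >>> train_out = tf.constant(1)
      >>> test_out = tf.constant(0)
      >>> tf.keras.backend.in_train_phase(train_out, test_out, training=1)
      <tf.Tensor: shape=(), dtype=int32, numpy=1>
      >>> tf.keras.backend.in_train_phase(train_out, test_out, training=0)
      <tf.Tensor: shape=(), dtype=int32, numpy=0>
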
4651  """
4652  from tensorflow.python.keras.engine import base_layer_utils  # pylint: disable=g-import-not-at-top
4653  if training is None:
4654    training = base_layer_utils.call_context().training
4655
4656  if training is None:
4657    training = learning_phase()
4658
4659  # TODO(b/138862903): Handle the case when training is tensor.
4660  if not tensor_util.is_tf_type(training):
4661    if training == 1 or training is True:
4662      if callable(x):
4663        return x()
4664      else:
4665        return x
4666
4667    elif training == 0 or training is False:
4668      if callable(alt):
4669        return alt()
4670      else:
4671        return alt
4672
4673  # else: assume learning phase is a placeholder tensor.
4674  x = switch(training, x, alt)
4675  return x
4676
4677
4678@keras_export('keras.backend.in_test_phase')
4679@doc_controls.do_not_generate_docs
4680def in_test_phase(x, alt, training=None):
4681  """Selects `x` in test phase, and `alt` otherwise.
4682
4683  Note that `alt` should have the *same shape* as `x`.
4684
4685  Args:
4686      x: What to return in test phase
4687          (tensor or callable that returns a tensor).
4688      alt: What to return otherwise
4689          (tensor or callable that returns a tensor).
4690      training: Optional scalar tensor
4691          (or Python boolean, or Python integer)
4692          specifying the learning phase.
4693
4694  Returns:
4695      Either `x` or `alt` based on `K.learning_phase`.
4696  """
4697  return in_train_phase(alt, x, training=training)
4698
4699
4700# NN OPERATIONS
4701
4702
4703@keras_export('keras.backend.relu')
4704@dispatch.add_dispatch_support
4705@doc_controls.do_not_generate_docs
4706def relu(x, alpha=0., max_value=None, threshold=0):
4707  """Rectified linear unit.
4708
4709  With default values, it returns element-wise `max(x, 0)`.
4710
4711  Otherwise, it follows:
4712  `f(x) = max_value` for `x >= max_value`,
4713  `f(x) = x` for `threshold <= x < max_value`,
4714  `f(x) = alpha * (x - threshold)` otherwise.
4715
4716  Args:
4717      x: A tensor or variable.
4718      alpha: A scalar, slope of negative section (default=`0.`).
4719      max_value: float. Saturation threshold.
4720      threshold: float. Threshold value for thresholded activation.
4721
4722  Returns:
4723      A tensor.
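
  Example (an illustrative sketch, assuming eager execution):

      >>> x = tf.constant([-2., 0., 2., 8.])
      >>> tf.keras.backend.relu(x)
      <tf.Tensor: shape=(4,), dtype=float32,
          numpy=array([0., 0., 2., 8.], dtype=float32)>
      >>> tf.keras.backend.relu(x, max_value=6.)
      <tf.Tensor: shape=(4,), dtype=float32,
          numpy=array([0., 0., 2., 6.], dtype=float32)>
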
4724  """
4725  # While x can be a tensor or variable, we also see cases where
4726  # numpy arrays, lists, tuples are passed as well.
4727  # lists, tuples do not have 'dtype' attribute.
4728  dtype = getattr(x, 'dtype', floatx())
4729  if alpha != 0.:
4730    if max_value is None and threshold == 0:
4731      return nn.leaky_relu(x, alpha=alpha)
4732
4733    if threshold != 0:
4734      negative_part = nn.relu(-x + threshold)
4735    else:
4736      negative_part = nn.relu(-x)
4737
4738  clip_max = max_value is not None
4739
4740  if threshold != 0:
4741    # computes x for x > threshold else 0
4742    x = x * math_ops.cast(math_ops.greater(x, threshold), dtype=dtype)
4743  elif max_value == 6:
4744    # if no threshold, then can use nn.relu6 native TF op for performance
4745    x = nn.relu6(x)
4746    clip_max = False
4747  else:
4748    x = nn.relu(x)
4749
4750  if clip_max:
4751    max_value = _constant_to_tensor(max_value, x.dtype.base_dtype)
4752    zero = _constant_to_tensor(0, x.dtype.base_dtype)
4753    x = clip_ops.clip_by_value(x, zero, max_value)
4754
4755  if alpha != 0.:
4756    alpha = _to_tensor(alpha, x.dtype.base_dtype)
4757    x -= alpha * negative_part
4758  return x
4759
4760
4761@keras_export('keras.backend.elu')
4762@dispatch.add_dispatch_support
4763@doc_controls.do_not_generate_docs
4764def elu(x, alpha=1.):
4765  """Exponential linear unit.
4766
4767  Args:
4768      x: A tensor or variable to compute the activation function for.
4769      alpha: A scalar, slope of negative section.
4770
4771  Returns:
4772      A tensor.
4773  """
4774  res = nn.elu(x)
4775  if alpha == 1:
4776    return res
4777  else:
4778    return array_ops.where_v2(x > 0, res, alpha * res)
4779
4780
4781@keras_export('keras.backend.softmax')
4782@dispatch.add_dispatch_support
4783@doc_controls.do_not_generate_docs
4784def softmax(x, axis=-1):
4785  """Softmax of a tensor.
4786
4787  Args:
4788      x: A tensor or variable.
4789      axis: The dimension softmax would be performed on.
4790          The default is -1 which indicates the last dimension.
4791
4792  Returns:
4793      A tensor.
4794  """
4795  return nn.softmax(x, axis=axis)
4796
4797
4798@keras_export('keras.backend.softplus')
4799@dispatch.add_dispatch_support
4800@doc_controls.do_not_generate_docs
4801def softplus(x):
4802  """Softplus of a tensor.
4803
4804  Args:
4805      x: A tensor or variable.
4806
4807  Returns:
4808      A tensor.
4809  """
4810  return math_ops.softplus(x)
4811
4812
4813@keras_export('keras.backend.softsign')
4814@dispatch.add_dispatch_support
4815@doc_controls.do_not_generate_docs
4816def softsign(x):
4817  """Softsign of a tensor.
4818
4819  Args:
4820      x: A tensor or variable.
4821
4822  Returns:
4823      A tensor.
4824  """
4825  return nn.softsign(x)
4826
4827
4828@keras_export('keras.backend.categorical_crossentropy')
4829@dispatch.add_dispatch_support
4830@doc_controls.do_not_generate_docs
4831def categorical_crossentropy(target, output, from_logits=False, axis=-1):
4832  """Categorical crossentropy between an output tensor and a target tensor.
4833
4834  Args:
4835      target: A tensor of the same shape as `output`.
4836      output: A tensor resulting from a softmax
4837          (unless `from_logits` is True, in which
4838          case `output` is expected to be the logits).
4839      from_logits: Boolean, whether `output` is the
4840          result of a softmax, or is a tensor of logits.
4841      axis: Int specifying the channels axis. `axis=-1` corresponds to data
4842          format `channels_last`, and `axis=1` corresponds to data format
4843          `channels_first`.
4844
4845  Returns:
4846      Output tensor.
4847
4848  Raises:
4849      ValueError: if `axis` is neither -1 nor one of the axes of `output`.
4850
4851  Example:
4852
4853  >>> a = tf.constant([1., 0., 0., 0., 1., 0., 0., 0., 1.], shape=[3,3])
4854  >>> print(a)
4855  tf.Tensor(
4856    [[1. 0. 0.]
4857     [0. 1. 0.]
4858     [0. 0. 1.]], shape=(3, 3), dtype=float32)
4859  >>> b = tf.constant([.9, .05, .05, .05, .89, .06, .05, .01, .94], shape=[3,3])
4860  >>> print(b)
4861  tf.Tensor(
4862    [[0.9  0.05 0.05]
4863     [0.05 0.89 0.06]
4864     [0.05 0.01 0.94]], shape=(3, 3), dtype=float32)
4865  >>> loss = tf.keras.backend.categorical_crossentropy(a, b)
4866  >>> print(np.around(loss, 5))
4867  [0.10536 0.11653 0.06188]
4868  >>> loss = tf.keras.backend.categorical_crossentropy(a, a)
4869  >>> print(np.around(loss, 5))
4870  [0. 0. 0.]
4871
4872  """
4873  target = ops.convert_to_tensor_v2_with_dispatch(target)
4874  output = ops.convert_to_tensor_v2_with_dispatch(output)
4875  target.shape.assert_is_compatible_with(output.shape)
4876
4877  # Use logits whenever they are available. `softmax` and `sigmoid`
4878  # activations cache logits on the `output` Tensor.
4879  if hasattr(output, '_keras_logits'):
4880    output = output._keras_logits  # pylint: disable=protected-access
4881    if from_logits:
4882      warnings.warn(
4883          '`categorical_crossentropy` received `from_logits=True`, but '
4884          'the `output` argument was produced by a sigmoid or softmax '
4885          'activation and thus does not represent logits. Was this intended?')
4886    from_logits = True
4887
4888  if from_logits:
4889    return nn.softmax_cross_entropy_with_logits_v2(
4890        labels=target, logits=output, axis=axis)
4891
4892  if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
4893      output.op.type == 'Softmax') and not hasattr(output, '_keras_history'):
4894    # When the softmax activation function is used for the output operation,
4895    # we use the logits that produced the softmax directly to compute the
4896    # loss, which avoids numerical issues when probabilities collapse to zero
4897    # during training. See b/117284466
4898    assert len(output.op.inputs) == 1
4899    output = output.op.inputs[0]
4900    return nn.softmax_cross_entropy_with_logits_v2(
4901        labels=target, logits=output, axis=axis)
4902
4903  # scale preds so that the class probas of each sample sum to 1
4904  output = output / math_ops.reduce_sum(output, axis, True)
4905  # Compute cross entropy from probabilities.
4906  epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
4907  output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
4908  return -math_ops.reduce_sum(target * math_ops.log(output), axis)
4909
4910
4911@keras_export('keras.backend.sparse_categorical_crossentropy')
4912@dispatch.add_dispatch_support
4913@doc_controls.do_not_generate_docs
4914def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
4915  """Categorical crossentropy with integer targets.
4916
4917  Args:
4918      target: An integer tensor.
4919      output: A tensor resulting from a softmax
4920          (unless `from_logits` is True, in which
4921          case `output` is expected to be the logits).
4922      from_logits: Boolean, whether `output` is the
4923          result of a softmax, or is a tensor of logits.
4924      axis: Int specifying the channels axis. `axis=-1` corresponds to data
4925          format `channels_last`, and `axis=1` corresponds to data format
4926          `channels_first`.
4927
4928  Returns:
4929      Output tensor.
4930
4931  Raises:
4932      ValueError: if `axis` is neither -1 nor one of the axes of `output`.
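
  Example (an illustrative sketch with arbitrary probabilities; each loss
  value is `-log` of the probability assigned to the true class):

  >>> labels = tf.constant([0, 2])
  >>> probs = tf.constant([[.9, .05, .05], [.1, .1, .8]])
  >>> loss = tf.keras.backend.sparse_categorical_crossentropy(labels, probs)
  >>> print(np.around(loss, 3))
  [0.105 0.223]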
4933  """
4934  target = ops.convert_to_tensor_v2_with_dispatch(target)
4935  output = ops.convert_to_tensor_v2_with_dispatch(output)
4936
4937  # Use logits whenever they are available. `softmax` and `sigmoid`
4938  # activations cache logits on the `output` Tensor.
4939  if hasattr(output, '_keras_logits'):
4940    output = output._keras_logits  # pylint: disable=protected-access
4941    if from_logits:
4942      warnings.warn(
4943          '`sparse_categorical_crossentropy` received `from_logits=True`, but '
4944          'the `output` argument was produced by a sigmoid or softmax '
4945          'activation and thus does not represent logits. Was this intended?')
4946    from_logits = True
4947  elif (not from_logits and
4948        not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
4949        output.op.type == 'Softmax') and not hasattr(output, '_keras_history'):
4950    # When the softmax activation function is used for the output operation,
4951    # we use the logits that produced the softmax directly to compute the
4952    # loss, which avoids numerical issues when probabilities collapse to zero
4953    # during training. See b/117284466
4954    assert len(output.op.inputs) == 1
4955    output = output.op.inputs[0]
4956    from_logits = True
4957  elif not from_logits:
4958    epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
4959    output = clip_ops.clip_by_value(output, epsilon_, 1 - epsilon_)
4960    output = math_ops.log(output)
4961
4962  if isinstance(output.shape, (tuple, list)):
4963    output_rank = len(output.shape)
4964  else:
4965    output_rank = output.shape.ndims
4966  if output_rank is not None:
4967    axis %= output_rank
4968    if axis != output_rank - 1:
4969      permutation = list(
4970          itertools.chain(range(axis), range(axis + 1, output_rank), [axis]))
4971      output = array_ops.transpose(output, perm=permutation)
4972  elif axis != -1:
4973    raise ValueError(
4974        'Cannot compute sparse categorical crossentropy with `axis={}` on an '
4975        'output tensor with unknown rank'.format(axis))
4976
4977  target = cast(target, 'int64')
4978
4979  # Try to adjust the shape so that rank of labels = rank of logits - 1.
4980  output_shape = array_ops.shape_v2(output)
4981  target_rank = target.shape.ndims
4982
4983  update_shape = (
4984      target_rank is not None and output_rank is not None and
4985      target_rank != output_rank - 1)
4986  if update_shape:
4987    target = flatten(target)
4988    output = array_ops.reshape(output, [-1, output_shape[-1]])
4989
4990  if py_any(_is_symbolic_tensor(v) for v in [target, output]):
4991    with get_graph().as_default():
4992      res = nn.sparse_softmax_cross_entropy_with_logits_v2(
4993          labels=target, logits=output)
4994  else:
4995    res = nn.sparse_softmax_cross_entropy_with_logits_v2(
4996        labels=target, logits=output)
4997
4998  if update_shape and output_rank >= 3:
4999    # If our output includes timesteps or spatial dimensions we need to reshape
5000    return array_ops.reshape(res, output_shape[:-1])
5001  else:
5002    return res
5003
5004
5005@keras_export('keras.backend.binary_crossentropy')
5006@dispatch.add_dispatch_support
5007@doc_controls.do_not_generate_docs
5008def binary_crossentropy(target, output, from_logits=False):
5009  """Binary crossentropy between an output tensor and a target tensor.
5010
5011  Args:
5012      target: A tensor with the same shape as `output`.
5013      output: A tensor.
5014      from_logits: Whether `output` is expected to be a logits tensor.
5015          By default, we consider that `output`
5016          encodes a probability distribution.
5017
5018  Returns:
5019      A tensor.
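
  Example (a shape-only sketch with arbitrary values; the loss is computed
  element-wise, so the output has the same shape as the inputs):

  >>> t = tf.constant([[1., 0.], [1., 1.]])
  >>> p = tf.constant([[0.9, 0.1], [0.8, 0.7]])
  >>> tf.keras.backend.binary_crossentropy(t, p).shape
  TensorShape([2, 2])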
5020  """
5021  target = ops.convert_to_tensor_v2_with_dispatch(target)
5022  output = ops.convert_to_tensor_v2_with_dispatch(output)
5023
5024  # Use logits whenever they are available. `softmax` and `sigmoid`
5025  # activations cache logits on the `output` Tensor.
5026  if hasattr(output, '_keras_logits'):
5027    output = output._keras_logits  # pylint: disable=protected-access
5028    if from_logits:
5029      warnings.warn(
5030          '`binary_crossentropy` received `from_logits=True`, but the `output`'
5031          ' argument was produced by a sigmoid or softmax activation and thus '
5032          'does not represent logits. Was this intended?')
5033    from_logits = True
5034
5035  if from_logits:
5036    return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
5037
5038  if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
5039      output.op.type == 'Sigmoid') and not hasattr(output, '_keras_history'):
5040    # When the sigmoid activation function is used for the output operation,
5041    # we use the logits that produced the sigmoid directly to compute the
5042    # loss, which avoids numerical issues when probabilities collapse to zero.
5043    assert len(output.op.inputs) == 1
5044    output = output.op.inputs[0]
5045    return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
5046
5047  epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
5048  output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
5049
5050  # Compute cross entropy from probabilities.
5051  bce = target * math_ops.log(output + epsilon())
5052  bce += (1 - target) * math_ops.log(1 - output + epsilon())
5053  return -bce
5054
5055
5056@keras_export('keras.backend.sigmoid')
5057@dispatch.add_dispatch_support
5058@doc_controls.do_not_generate_docs
5059def sigmoid(x):
5060  """Element-wise sigmoid.
5061
5062  Args:
5063      x: A tensor or variable.
5064
5065  Returns:
5066      A tensor.
5067  """
5068  return nn.sigmoid(x)
5069
5070
5071@keras_export('keras.backend.hard_sigmoid')
5072@dispatch.add_dispatch_support
5073@doc_controls.do_not_generate_docs
5074def hard_sigmoid(x):
5075  """Segment-wise linear approximation of sigmoid.
5076
5077  Faster than sigmoid.
5078  Returns `0.` if `x < -2.5`, `1.` if `x > 2.5`.
5079  In `-2.5 <= x <= 2.5`, returns `0.2 * x + 0.5`.
5080
5081  Args:
5082      x: A tensor or variable.
5083
5084  Returns:
5085      A tensor.
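
  Example (an illustrative sketch of the piecewise behavior; the exact tensor
  repr is elided):

  >>> x = tf.constant([-3.0, 0.0, 3.0])
  >>> tf.keras.backend.hard_sigmoid(x)  # approximately [0., 0.5, 1.]
  <tf.Tensor: shape=(3,), dtype=float32, numpy=...>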
5086  """
5087  point_two = _constant_to_tensor(0.2, x.dtype.base_dtype)
5088  point_five = _constant_to_tensor(0.5, x.dtype.base_dtype)
5089  x = math_ops.multiply(x, point_two)
5090  x = math_ops.add(x, point_five)
5091  x = clip_ops.clip_by_value(x, 0., 1.)
5092  return x
5093
5094
5095@keras_export('keras.backend.tanh')
5096@dispatch.add_dispatch_support
5097@doc_controls.do_not_generate_docs
5098def tanh(x):
5099  """Element-wise tanh.
5100
5101  Args:
5102      x: A tensor or variable.
5103
5104  Returns:
5105      A tensor.
5106  """
5107  return nn.tanh(x)
5108
5109
5110@keras_export('keras.backend.dropout')
5111@dispatch.add_dispatch_support
5112@doc_controls.do_not_generate_docs
5113def dropout(x, level, noise_shape=None, seed=None):
5114  """Sets entries in `x` to zero at random, while scaling the entire tensor.
5115
5116  Args:
5117      x: tensor
5118      level: fraction of the entries in the tensor
5119          that will be set to 0.
5120      noise_shape: shape for randomly generated keep/drop flags,
5121          must be broadcastable to the shape of `x`
5122      seed: random seed to ensure determinism.
5123
5124  Returns:
5125      A tensor.
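
  Example (an illustrative sketch; which entries are zeroed is random, and
  surviving entries are scaled by `1 / (1 - level)`):

  >>> x = tf.ones((2, 4))
  >>> y = tf.keras.backend.dropout(x, level=0.5)
  >>> y.shape
  TensorShape([2, 4])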
5126  """
5127  if seed is None:
5128    seed = np.random.randint(10e6)
5129  return nn.dropout_v2(x, rate=level, noise_shape=noise_shape, seed=seed)
5130
5131
5132@keras_export('keras.backend.l2_normalize')
5133@dispatch.add_dispatch_support
5134@doc_controls.do_not_generate_docs
5135def l2_normalize(x, axis=None):
5136  """Normalizes a tensor wrt the L2 norm alongside the specified axis.
5137
5138  Args:
5139      x: Tensor or variable.
5140      axis: axis along which to perform normalization.
5141
5142  Returns:
5143      A tensor.
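
  Example (a minimal sketch using a 3-4-5 triangle, so the normalized values
  are exactly recoverable):

  >>> x = tf.constant([3.0, 4.0])
  >>> print(np.around(tf.keras.backend.l2_normalize(x, axis=0), 1))
  [0.6 0.8]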
5144  """
5145  return nn.l2_normalize(x, axis=axis)
5146
5147
5148@keras_export('keras.backend.in_top_k')
5149@dispatch.add_dispatch_support
5150@doc_controls.do_not_generate_docs
5151def in_top_k(predictions, targets, k):
5152  """Returns whether the `targets` are in the top `k` `predictions`.
5153
5154  Args:
5155      predictions: A tensor of shape `(batch_size, classes)` and type `float32`.
5156      targets: A 1D tensor of length `batch_size` and type `int32` or `int64`.
5157      k: An `int`, number of top elements to consider.
5158
5159  Returns:
5160      A 1D tensor of length `batch_size` and type `bool`.
5161      `output[i]` is `True` if `predictions[i, targets[i]]` is within top-`k`
5162      values of `predictions[i]`.
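
  Example (an illustrative sketch with arbitrary scores):

  >>> preds = tf.constant([[0.1, 0.9, 0.0], [0.4, 0.3, 0.3]])
  >>> targets = tf.constant([1, 2])
  >>> tf.keras.backend.in_top_k(preds, targets, k=1)
  <tf.Tensor: shape=(2,), dtype=bool, numpy=array([ True, False])>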
5163  """
5164  return nn.in_top_k(predictions, targets, k)
5165
5166
5167# CONVOLUTIONS
5168
5169
5170def _preprocess_conv1d_input(x, data_format):
5171  """Transpose and cast the input before the conv1d.
5172
5173  Args:
5174      x: input tensor.
5175      data_format: string, `"channels_last"` or `"channels_first"`.
5176
5177  Returns:
5178      A tensor.
5179  """
5180  tf_data_format = 'NWC'  # to pass TF Conv2dNative operations
5181  if data_format == 'channels_first':
5182    if not _has_nchw_support():
5183      x = array_ops.transpose(x, (0, 2, 1))  # NCW -> NWC
5184    else:
5185      tf_data_format = 'NCW'
5186  return x, tf_data_format
5187
5188
5189def _preprocess_conv2d_input(x, data_format, force_transpose=False):
5190  """Transpose and cast the input before the conv2d.
5191
5192  Args:
5193      x: input tensor.
5194      data_format: string, `"channels_last"` or `"channels_first"`.
5195      force_transpose: Boolean. If True, the input will always be transposed
5196          from NCHW to NHWC if `data_format` is `"channels_first"`.
5197          If False, the transposition only occurs on CPU (GPU ops are
5198          assumed to support NCHW).
5199
5200  Returns:
5201      A tensor.
5202  """
5203  tf_data_format = 'NHWC'
5204  if data_format == 'channels_first':
5205    if not _has_nchw_support() or force_transpose:
5206      x = array_ops.transpose(x, (0, 2, 3, 1))  # NCHW -> NHWC
5207    else:
5208      tf_data_format = 'NCHW'
5209  return x, tf_data_format
5210
5211
5212def _preprocess_conv3d_input(x, data_format):
5213  """Transpose and cast the input before the conv3d.
5214
5215  Args:
5216      x: input tensor.
5217      data_format: string, `"channels_last"` or `"channels_first"`.
5218
5219  Returns:
5220      A tensor.
5221  """
5222  tf_data_format = 'NDHWC'
5223  if data_format == 'channels_first':
5224    if not _has_nchw_support():
5225      x = array_ops.transpose(x, (0, 2, 3, 4, 1))
5226    else:
5227      tf_data_format = 'NCDHW'
5228  return x, tf_data_format
5229
5230
5231def _preprocess_padding(padding):
5232  """Convert Keras padding to TensorFlow padding.
5233
5234  Args:
5235      padding: string, one of 'same', 'valid'.
5236
5237  Returns:
5238      a string, one of 'SAME', 'VALID'.
5239
5240  Raises:
5241      ValueError: if `padding` is invalid.
5242  """
5243  if padding == 'same':
5244    padding = 'SAME'
5245  elif padding == 'valid':
5246    padding = 'VALID'
5247  else:
5248    raise ValueError('Invalid padding: ' + str(padding))
5249  return padding
5250
5251
5252@keras_export('keras.backend.conv1d')
5253@dispatch.add_dispatch_support
5254@doc_controls.do_not_generate_docs
5255def conv1d(x,
5256           kernel,
5257           strides=1,
5258           padding='valid',
5259           data_format=None,
5260           dilation_rate=1):
5261  """1D convolution.
5262
5263  Args:
5264      x: Tensor or variable.
5265      kernel: kernel tensor.
5266      strides: stride integer.
5267      padding: string, `"same"`, `"causal"` or `"valid"`.
5268      data_format: string, one of "channels_last", "channels_first".
5269      dilation_rate: integer dilate rate.
5270
5271  Returns:
5272      A tensor, result of 1D convolution.
5273
5274  Raises:
5275      ValueError: if `data_format` is neither `channels_last` nor
5276      `channels_first`.
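
  Example (a shape-only sketch with random values; `"causal"` padding pads on
  the left so the output keeps the input length):

  >>> x = tf.random.normal((1, 10, 4))      # (batch, steps, channels)
  >>> kernel = tf.random.normal((3, 4, 8))  # (size, in_channels, filters)
  >>> tf.keras.backend.conv1d(x, kernel, padding='causal').shape
  TensorShape([1, 10, 8])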
5277  """
5278  if data_format is None:
5279    data_format = image_data_format()
5280  if data_format not in {'channels_first', 'channels_last'}:
5281    raise ValueError('Unknown data_format: ' + str(data_format))
5282
5283  kernel_shape = kernel.shape.as_list()
5284  if padding == 'causal':
5285    # causal (dilated) convolution:
5286    left_pad = dilation_rate * (kernel_shape[0] - 1)
5287    x = temporal_padding(x, (left_pad, 0))
5288    padding = 'valid'
5289  padding = _preprocess_padding(padding)
5290
5291  x, tf_data_format = _preprocess_conv1d_input(x, data_format)
5292  x = nn.convolution(
5293      input=x,
5294      filter=kernel,
5295      dilation_rate=dilation_rate,
5296      strides=strides,
5297      padding=padding,
5298      data_format=tf_data_format)
5299  if data_format == 'channels_first' and tf_data_format == 'NWC':
5300    x = array_ops.transpose(x, (0, 2, 1))  # NWC -> NCW
5301  return x
5302
5303
5304@keras_export('keras.backend.conv2d')
5305@dispatch.add_dispatch_support
5306@doc_controls.do_not_generate_docs
5307def conv2d(x,
5308           kernel,
5309           strides=(1, 1),
5310           padding='valid',
5311           data_format=None,
5312           dilation_rate=(1, 1)):
5313  """2D convolution.
5314
5315  Args:
5316      x: Tensor or variable.
5317      kernel: kernel tensor.
5318      strides: strides tuple.
5319      padding: string, `"same"` or `"valid"`.
5320      data_format: `"channels_last"` or `"channels_first"`.
5321      dilation_rate: tuple of 2 integers.
5322
5323  Returns:
5324      A tensor, result of 2D convolution.
5325
5326  Raises:
5327      ValueError: if `data_format` is neither `channels_last` nor
5328      `channels_first`.
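
  Example (a shape-only sketch with random values, `channels_last` data
  format):

  >>> x = tf.random.normal((1, 28, 28, 3))      # (batch, rows, cols, channels)
  >>> kernel = tf.random.normal((3, 3, 3, 16))  # (rows, cols, in_ch, out_ch)
  >>> tf.keras.backend.conv2d(x, kernel, padding='same').shape
  TensorShape([1, 28, 28, 16])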
5329  """
5330  if data_format is None:
5331    data_format = image_data_format()
5332  if data_format not in {'channels_first', 'channels_last'}:
5333    raise ValueError('Unknown data_format: ' + str(data_format))
5334
5335  x, tf_data_format = _preprocess_conv2d_input(x, data_format)
5336  padding = _preprocess_padding(padding)
5337  x = nn.convolution(
5338      input=x,
5339      filter=kernel,
5340      dilation_rate=dilation_rate,
5341      strides=strides,
5342      padding=padding,
5343      data_format=tf_data_format)
5344  if data_format == 'channels_first' and tf_data_format == 'NHWC':
5345    x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
5346  return x
5347
5348
5349@keras_export('keras.backend.conv2d_transpose')
5350@dispatch.add_dispatch_support
5351@doc_controls.do_not_generate_docs
5352def conv2d_transpose(x,
5353                     kernel,
5354                     output_shape,
5355                     strides=(1, 1),
5356                     padding='valid',
5357                     data_format=None,
5358                     dilation_rate=(1, 1)):
5359  """2D deconvolution (i.e. transposed convolution).
5362
5363  Args:
5364      x: Tensor or variable.
5365      kernel: kernel tensor.
5366      output_shape: 1D int tensor for the output shape.
5367      strides: strides tuple.
5368      padding: string, `"same"` or `"valid"`.
5369      data_format: string, `"channels_last"` or `"channels_first"`.
5370      dilation_rate: Tuple of 2 integers.
5371
5372  Returns:
5373      A tensor, result of transposed 2D convolution.
5374
5375  Raises:
5376      ValueError: if `data_format` is neither `channels_last` nor
5377      `channels_first`.
5378  """
5379  if data_format is None:
5380    data_format = image_data_format()
5381  if data_format not in {'channels_first', 'channels_last'}:
5382    raise ValueError('Unknown data_format: ' + str(data_format))
5383
5384  # `atrous_conv2d_transpose` only supports NHWC format, even on GPU.
5385  if data_format == 'channels_first' and dilation_rate != (1, 1):
5386    force_transpose = True
5387  else:
5388    force_transpose = False
5389
5390  x, tf_data_format = _preprocess_conv2d_input(x, data_format, force_transpose)
5391
5392  if data_format == 'channels_first' and tf_data_format == 'NHWC':
5393    output_shape = (output_shape[0], output_shape[2], output_shape[3],
5394                    output_shape[1])
5395  if output_shape[0] is None:
5396    output_shape = (shape(x)[0],) + tuple(output_shape[1:])
5397
5398  if isinstance(output_shape, (tuple, list)):
5399    output_shape = array_ops.stack(list(output_shape))
5400
5401  padding = _preprocess_padding(padding)
5402  if tf_data_format == 'NHWC':
5403    strides = (1,) + strides + (1,)
5404  else:
5405    strides = (1, 1) + strides
5406
5407  if dilation_rate == (1, 1):
5408    x = nn.conv2d_transpose(x, kernel, output_shape, strides,
5409                            padding=padding,
5410                            data_format=tf_data_format)
5411  else:
5412    assert dilation_rate[0] == dilation_rate[1]
5413    x = nn.atrous_conv2d_transpose(
5414        x,
5415        kernel,
5416        output_shape,
5417        rate=dilation_rate[0],
5418        padding=padding)
5419  if data_format == 'channels_first' and tf_data_format == 'NHWC':
5420    x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
5421  return x
5422
5423
5424def separable_conv1d(x,
5425                     depthwise_kernel,
5426                     pointwise_kernel,
5427                     strides=1,
5428                     padding='valid',
5429                     data_format=None,
5430                     dilation_rate=1):
5431  """1D convolution with separable filters.
5432
5433  Args:
5434      x: input tensor
5435      depthwise_kernel: convolution kernel for the depthwise convolution.
5436      pointwise_kernel: kernel for the 1x1 convolution.
5437      strides: stride integer.
5438      padding: string, `"same"` or `"valid"`.
5439      data_format: string, `"channels_last"` or `"channels_first"`.
5440      dilation_rate: integer dilation rate.
5441
5442  Returns:
5443      Output tensor.
5444
5445  Raises:
5446      ValueError: if `data_format` is neither `channels_last` nor
5447      `channels_first`.
5448  """
5449  if data_format is None:
5450    data_format = image_data_format()
5451  if data_format not in {'channels_first', 'channels_last'}:
5452    raise ValueError('Unknown data_format: ' + str(data_format))
5453
5454  if isinstance(strides, int):
5455    strides = (strides,)
5456  if isinstance(dilation_rate, int):
5457    dilation_rate = (dilation_rate,)
5458
5459  x, tf_data_format = _preprocess_conv1d_input(x, data_format)
5460  padding = _preprocess_padding(padding)
5461  if not isinstance(strides, tuple):
5462    strides = tuple(strides)
5463  if tf_data_format == 'NWC':
5464    spatial_start_dim = 1
5465    strides = (1,) + strides * 2 + (1,)
5466  else:
5467    spatial_start_dim = 2
5468    strides = (1, 1) + strides * 2
5469  x = array_ops.expand_dims(x, spatial_start_dim)
5470  depthwise_kernel = array_ops.expand_dims(depthwise_kernel, 0)
5471  pointwise_kernel = array_ops.expand_dims(pointwise_kernel, 0)
5472  dilation_rate = (1,) + dilation_rate
5473
5474  x = nn.separable_conv2d(
5475      x,
5476      depthwise_kernel,
5477      pointwise_kernel,
5478      strides=strides,
5479      padding=padding,
5480      rate=dilation_rate,
5481      data_format=tf_data_format)
5482
5483  x = array_ops.squeeze(x, [spatial_start_dim])
5484
5485  if data_format == 'channels_first' and tf_data_format == 'NWC':
5486    x = array_ops.transpose(x, (0, 2, 1))  # NWC -> NCW
5487
5488  return x
5489
5490
5491@keras_export('keras.backend.separable_conv2d')
5492@dispatch.add_dispatch_support
5493@doc_controls.do_not_generate_docs
5494def separable_conv2d(x,
5495                     depthwise_kernel,
5496                     pointwise_kernel,
5497                     strides=(1, 1),
5498                     padding='valid',
5499                     data_format=None,
5500                     dilation_rate=(1, 1)):
5501  """2D convolution with separable filters.
5502
5503  Args:
5504      x: input tensor
5505      depthwise_kernel: convolution kernel for the depthwise convolution.
5506      pointwise_kernel: kernel for the 1x1 convolution.
5507      strides: strides tuple (length 2).
5508      padding: string, `"same"` or `"valid"`.
5509      data_format: string, `"channels_last"` or `"channels_first"`.
5510      dilation_rate: tuple of integers,
5511          dilation rates for the separable convolution.
5512
5513  Returns:
5514      Output tensor.
5515
5516  Raises:
5517      ValueError: if `data_format` is neither `channels_last` nor
5518      `channels_first`.
5519      ValueError: if `strides` is not a tuple of 2 integers.
5520  """
5521  if data_format is None:
5522    data_format = image_data_format()
5523  if data_format not in {'channels_first', 'channels_last'}:
5524    raise ValueError('Unknown data_format: ' + str(data_format))
5525  if len(strides) != 2:
5526    raise ValueError('`strides` must be a tuple of 2 integers.')
5527
5528  x, tf_data_format = _preprocess_conv2d_input(x, data_format)
5529  padding = _preprocess_padding(padding)
5530  if not isinstance(strides, tuple):
5531    strides = tuple(strides)
5532  if tf_data_format == 'NHWC':
5533    strides = (1,) + strides + (1,)
5534  else:
5535    strides = (1, 1) + strides
5536
5537  x = nn.separable_conv2d(
5538      x,
5539      depthwise_kernel,
5540      pointwise_kernel,
5541      strides=strides,
5542      padding=padding,
5543      rate=dilation_rate,
5544      data_format=tf_data_format)
5545  if data_format == 'channels_first' and tf_data_format == 'NHWC':
5546    x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
5547  return x
5548
5549
5550@keras_export('keras.backend.depthwise_conv2d')
5551@dispatch.add_dispatch_support
5552@doc_controls.do_not_generate_docs
5553def depthwise_conv2d(x,
5554                     depthwise_kernel,
5555                     strides=(1, 1),
5556                     padding='valid',
5557                     data_format=None,
5558                     dilation_rate=(1, 1)):
5559  """2D depthwise convolution (one filter applied per input channel).
5560
5561  Args:
5562      x: input tensor
5563      depthwise_kernel: convolution kernel for the depthwise convolution.
5564      strides: strides tuple (length 2).
5565      padding: string, `"same"` or `"valid"`.
5566      data_format: string, `"channels_last"` or `"channels_first"`.
5567      dilation_rate: tuple of integers,
5568          dilation rates for the depthwise convolution.
5569
5570  Returns:
5571      Output tensor.
5572
5573  Raises:
5574      ValueError: if `data_format` is neither `channels_last` nor
5575      `channels_first`.
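
  Example (a shape-only sketch with random values; the last kernel dimension
  is the depth multiplier, so the output has `in_channels * multiplier`
  channels):

  >>> x = tf.random.normal((1, 8, 8, 3))
  >>> dw_kernel = tf.random.normal((3, 3, 3, 2))  # (rows, cols, in_ch, mult)
  >>> tf.keras.backend.depthwise_conv2d(x, dw_kernel, padding='same').shape
  TensorShape([1, 8, 8, 6])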
5576  """
5577  if data_format is None:
5578    data_format = image_data_format()
5579  if data_format not in {'channels_first', 'channels_last'}:
5580    raise ValueError('Unknown data_format: ' + str(data_format))
5581
5582  x, tf_data_format = _preprocess_conv2d_input(x, data_format)
5583  padding = _preprocess_padding(padding)
5584  if tf_data_format == 'NHWC':
5585    strides = (1,) + strides + (1,)
5586  else:
5587    strides = (1, 1) + strides
5588
5589  x = nn.depthwise_conv2d(
5590      x,
5591      depthwise_kernel,
5592      strides=strides,
5593      padding=padding,
5594      rate=dilation_rate,
5595      data_format=tf_data_format)
5596  if data_format == 'channels_first' and tf_data_format == 'NHWC':
5597    x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
5598  return x
5599
5600
5601@keras_export('keras.backend.conv3d')
5602@dispatch.add_dispatch_support
5603@doc_controls.do_not_generate_docs
5604def conv3d(x,
5605           kernel,
5606           strides=(1, 1, 1),
5607           padding='valid',
5608           data_format=None,
5609           dilation_rate=(1, 1, 1)):
5610  """3D convolution.
5611
5612  Args:
5613      x: Tensor or variable.
5614      kernel: kernel tensor.
5615      strides: strides tuple.
5616      padding: string, `"same"` or `"valid"`.
5617      data_format: string, `"channels_last"` or `"channels_first"`.
5618      dilation_rate: tuple of 3 integers.
5619
5620  Returns:
5621      A tensor, result of 3D convolution.
5622
5623  Raises:
5624      ValueError: if `data_format` is neither `channels_last` nor
5625      `channels_first`.
5626  """
5627  if data_format is None:
5628    data_format = image_data_format()
5629  if data_format not in {'channels_first', 'channels_last'}:
5630    raise ValueError('Unknown data_format: ' + str(data_format))
5631
5632  x, tf_data_format = _preprocess_conv3d_input(x, data_format)
5633  padding = _preprocess_padding(padding)
5634  x = nn.convolution(
5635      input=x,
5636      filter=kernel,
5637      dilation_rate=dilation_rate,
5638      strides=strides,
5639      padding=padding,
5640      data_format=tf_data_format)
5641  if data_format == 'channels_first' and tf_data_format == 'NDHWC':
5642    x = array_ops.transpose(x, (0, 4, 1, 2, 3))
5643  return x
5644
5645
5646def conv3d_transpose(x,
5647                     kernel,
5648                     output_shape,
5649                     strides=(1, 1, 1),
5650                     padding='valid',
5651                     data_format=None):
5652  """3D deconvolution (i.e. transposed convolution).
5655
5656  Args:
5657      x: input tensor.
5658      kernel: kernel tensor.
5659      output_shape: 1D int tensor for the output shape.
5660      strides: strides tuple.
5661      padding: string, "same" or "valid".
5662      data_format: string, `"channels_last"` or `"channels_first"`.
5663
5664  Returns:
5665      A tensor, result of transposed 3D convolution.
5666
5667  Raises:
5668      ValueError: if `data_format` is neither `channels_last` nor
5669      `channels_first`.
5670  """
5671  if data_format is None:
5672    data_format = image_data_format()
5673  if data_format not in {'channels_first', 'channels_last'}:
5674    raise ValueError('Unknown data_format: ' + str(data_format))
5675  if isinstance(output_shape, (tuple, list)):
5676    output_shape = array_ops.stack(output_shape)
5677
5678  x, tf_data_format = _preprocess_conv3d_input(x, data_format)
5679
5680  if data_format == 'channels_first' and tf_data_format == 'NDHWC':
5681    output_shape = (output_shape[0], output_shape[2], output_shape[3],
5682                    output_shape[4], output_shape[1])
5683  if output_shape[0] is None:
5684    output_shape = (array_ops.shape(x)[0],) + tuple(output_shape[1:])
5685    output_shape = array_ops.stack(list(output_shape))
5686
5687  padding = _preprocess_padding(padding)
5688  if tf_data_format == 'NDHWC':
5689    strides = (1,) + strides + (1,)
5690  else:
5691    strides = (1, 1) + strides
5692
5693  x = nn.conv3d_transpose(
5694      x,
5695      kernel,
5696      output_shape,
5697      strides,
5698      padding=padding,
5699      data_format=tf_data_format)
5700  if data_format == 'channels_first' and tf_data_format == 'NDHWC':
5701    x = array_ops.transpose(x, (0, 4, 1, 2, 3))
5702  return x
5703
5704
5705@keras_export('keras.backend.pool2d')
5706@dispatch.add_dispatch_support
5707@doc_controls.do_not_generate_docs
5708def pool2d(x,
5709           pool_size,
5710           strides=(1, 1),
5711           padding='valid',
5712           data_format=None,
5713           pool_mode='max'):
5714  """2D Pooling.
5715
5716  Args:
5717      x: Tensor or variable.
5718      pool_size: tuple of 2 integers.
5719      strides: tuple of 2 integers.
5720      padding: string, `"same"` or `"valid"`.
5721      data_format: string, `"channels_last"` or `"channels_first"`.
5722      pool_mode: string, `"max"` or `"avg"`.
5723
5724  Returns:
5725      A tensor, result of 2D pooling.
5726
5727  Raises:
5728      ValueError: if `data_format` is neither `"channels_last"` nor
5729      `"channels_first"`.
5730      ValueError: if `pool_size` is not a tuple of 2 integers.
5731      ValueError: if `strides` is not a tuple of 2 integers.
5732      ValueError: if `pool_mode` is neither `"max"` nor `"avg"`.
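
  Example (a shape-only sketch with random values; 2x2 max pooling with
  stride 2 halves each spatial dimension):

  >>> x = tf.random.normal((1, 4, 4, 3))
  >>> tf.keras.backend.pool2d(x, pool_size=(2, 2), strides=(2, 2)).shape
  TensorShape([1, 2, 2, 3])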
5733  """
5734  if data_format is None:
5735    data_format = image_data_format()
5736  if data_format not in {'channels_first', 'channels_last'}:
5737    raise ValueError('Unknown data_format: ' + str(data_format))
5738  if len(pool_size) != 2:
5739    raise ValueError('`pool_size` must be a tuple of 2 integers.')
5740  if len(strides) != 2:
5741    raise ValueError('`strides` must be a tuple of 2 integers.')
5742
5743  x, tf_data_format = _preprocess_conv2d_input(x, data_format)
5744  padding = _preprocess_padding(padding)
5745  if tf_data_format == 'NHWC':
5746    strides = (1,) + strides + (1,)
5747    pool_size = (1,) + pool_size + (1,)
5748  else:
5749    strides = (1, 1) + strides
5750    pool_size = (1, 1) + pool_size
5751
5752  if pool_mode == 'max':
5753    x = nn.max_pool(
5754        x, pool_size, strides, padding=padding, data_format=tf_data_format)
5755  elif pool_mode == 'avg':
5756    x = nn.avg_pool(
5757        x, pool_size, strides, padding=padding, data_format=tf_data_format)
5758  else:
5759    raise ValueError('Invalid pooling mode: ' + str(pool_mode))
5760
5761  if data_format == 'channels_first' and tf_data_format == 'NHWC':
5762    x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
5763  return x
5764
5765
5766@keras_export('keras.backend.pool3d')
5767@dispatch.add_dispatch_support
5768@doc_controls.do_not_generate_docs
5769def pool3d(x,
5770           pool_size,
5771           strides=(1, 1, 1),
5772           padding='valid',
5773           data_format=None,
5774           pool_mode='max'):
5775  """3D Pooling.
5776
5777  Args:
5778      x: Tensor or variable.
5779      pool_size: tuple of 3 integers.
5780      strides: tuple of 3 integers.
5781      padding: string, `"same"` or `"valid"`.
5782      data_format: string, `"channels_last"` or `"channels_first"`.
5783      pool_mode: string, `"max"` or `"avg"`.
5784
5785  Returns:
5786      A tensor, result of 3D pooling.
5787
5788  Raises:
5789      ValueError: if `data_format` is neither `"channels_last"` nor
5790      `"channels_first"`.
5791      ValueError: if `pool_mode` is neither `"max"` nor `"avg"`.
5792  """
5793  if data_format is None:
5794    data_format = image_data_format()
5795  if data_format not in {'channels_first', 'channels_last'}:
5796    raise ValueError('Unknown data_format: ' + str(data_format))
5797
5798  x, tf_data_format = _preprocess_conv3d_input(x, data_format)
5799  padding = _preprocess_padding(padding)
5800  if tf_data_format == 'NDHWC':
5801    strides = (1,) + strides + (1,)
5802    pool_size = (1,) + pool_size + (1,)
5803  else:
5804    strides = (1, 1) + strides
5805    pool_size = (1, 1) + pool_size
5806
5807  if pool_mode == 'max':
5808    x = nn.max_pool3d(
5809        x, pool_size, strides, padding=padding, data_format=tf_data_format)
5810  elif pool_mode == 'avg':
5811    x = nn.avg_pool3d(
5812        x, pool_size, strides, padding=padding, data_format=tf_data_format)
5813  else:
5814    raise ValueError('Invalid pooling mode: ' + str(pool_mode))
5815
5816  if data_format == 'channels_first' and tf_data_format == 'NDHWC':
5817    x = array_ops.transpose(x, (0, 4, 1, 2, 3))
5818  return x
5819
5820
5821def local_conv(inputs,
5822               kernel,
5823               kernel_size,
5824               strides,
5825               output_shape,
5826               data_format=None):
5827  """Apply N-D convolution with un-shared weights.
5828
5829  Args:
5830      inputs: (N+2)-D tensor with shape
5831          (batch_size, channels_in, d_in1, ..., d_inN)
5832          if data_format='channels_first', or
5833          (batch_size, d_in1, ..., d_inN, channels_in)
5834          if data_format='channels_last'.
5835      kernel: the unshared weight for N-D convolution,
5836          with shape (output_items, feature_dim, channels_out), where
5837          feature_dim = np.prod(kernel_size) * channels_in,
5838          output_items = np.prod(output_shape).
5839      kernel_size: a tuple of N integers, specifying the
5840          spatial dimensions of the N-D convolution window.
5841      strides: a tuple of N integers, specifying the strides
5842          of the convolution along the spatial dimensions.
5843      output_shape: a tuple of (d_out1, ..., d_outN) specifying the spatial
5844          dimensionality of the output.
5845      data_format: string, "channels_first" or "channels_last".
5846
5847  Returns:
5848      An (N+2)-D tensor with shape:
5849      (batch_size, channels_out) + output_shape
5850      if data_format='channels_first', or:
5851      (batch_size,) + output_shape + (channels_out,)
5852      if data_format='channels_last'.
5853
5854  Raises:
5855      ValueError: if `data_format` is neither
5856      `channels_last` nor `channels_first`.
5857  """
5858  if data_format is None:
5859    data_format = image_data_format()
5860  if data_format not in {'channels_first', 'channels_last'}:
5861    raise ValueError('Unknown data_format: ' + str(data_format))
5862
5863  kernel_shape = int_shape(kernel)
5864  feature_dim = kernel_shape[1]
5865  channels_out = kernel_shape[-1]
5866  ndims = len(output_shape)
5867  spatial_dimensions = list(range(ndims))
5868
5869  xs = []
5870  output_axes_ticks = [range(axis_max) for axis_max in output_shape]
5871  for position in itertools.product(*output_axes_ticks):
5872    slices = [slice(None)]
5873
5874    if data_format == 'channels_first':
5875      slices.append(slice(None))
5876
5877    slices.extend(
5878        slice(position[d] * strides[d], position[d] * strides[d] +
5879              kernel_size[d]) for d in spatial_dimensions)
5880
5881    if data_format == 'channels_last':
5882      slices.append(slice(None))
5883
5884    xs.append(reshape(inputs[slices], (1, -1, feature_dim)))
5885
5886  x_aggregate = concatenate(xs, axis=0)
5887  output = batch_dot(x_aggregate, kernel)
5888  output = reshape(output, output_shape + (-1, channels_out))
5889
5890  if data_format == 'channels_first':
5891    permutation = [ndims, ndims + 1] + spatial_dimensions
5892  else:
5893    permutation = [ndims] + spatial_dimensions + [ndims + 1]
5894
5895  return permute_dimensions(output, permutation)
5896
5897
5898@keras_export('keras.backend.local_conv1d')
5899@dispatch.add_dispatch_support
5900@doc_controls.do_not_generate_docs
5901def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):
5902  """Apply 1D conv with un-shared weights.
5903
5904  Args:
5905      inputs: 3D tensor with shape:
5906          (batch_size, steps, input_dim)
5907          if data_format is "channels_last" or
5908          (batch_size, input_dim, steps)
5909          if data_format is "channels_first".
5910      kernel: the unshared weight for convolution,
5911          with shape (output_length, feature_dim, filters).
5912      kernel_size: a tuple of a single integer,
5913          specifying the length of the 1D convolution window.
5914      strides: a tuple of a single integer,
5915          specifying the stride length of the convolution.
5916      data_format: the data format, channels_first or channels_last.
5917
5918  Returns:
5919      A 3D tensor with shape:
5920      (batch_size, filters, output_length)
5921      if data_format='channels_first'
5922      or 3D tensor with shape:
5923      (batch_size, output_length, filters)
5924      if data_format='channels_last'.
5925  """
5926  output_shape = (kernel.shape[0],)
5927  return local_conv(inputs,
5928                    kernel,
5929                    kernel_size,
5930                    strides,
5931                    output_shape,
5932                    data_format)
5933
5934
5935@keras_export('keras.backend.local_conv2d')
5936@dispatch.add_dispatch_support
5937@doc_controls.do_not_generate_docs
5938def local_conv2d(inputs,
5939                 kernel,
5940                 kernel_size,
5941                 strides,
5942                 output_shape,
5943                 data_format=None):
5944  """Apply 2D conv with un-shared weights.
5945
5946  Args:
5947      inputs: 4D tensor with shape:
5948          (batch_size, channels, rows, cols)
5949          if data_format='channels_first'
5950          or 4D tensor with shape:
5951          (batch_size, rows, cols, channels)
5952          if data_format='channels_last'.
5953      kernel: the unshared weight for convolution,
5954          with shape (output_items, feature_dim, filters).
5955      kernel_size: a tuple of 2 integers, specifying the
5956          width and height of the 2D convolution window.
5957      strides: a tuple of 2 integers, specifying the strides
5958          of the convolution along the width and height.
5959      output_shape: a tuple with (output_row, output_col).
5960      data_format: the data format, channels_first or channels_last.
5961
5962  Returns:
5963      A 4D tensor with shape:
5964      (batch_size, filters, new_rows, new_cols)
5965      if data_format='channels_first'
5966      or 4D tensor with shape:
5967      (batch_size, new_rows, new_cols, filters)
5968      if data_format='channels_last'.
5969  """
5970  return local_conv(inputs,
5971                    kernel,
5972                    kernel_size,
5973                    strides,
5974                    output_shape,
5975                    data_format)
5976
5977
5978@keras_export('keras.backend.bias_add')
5979@dispatch.add_dispatch_support
5980@doc_controls.do_not_generate_docs
5981def bias_add(x, bias, data_format=None):
5982  """Adds a bias vector to a tensor.
5983
5984  Args:
5985      x: Tensor or variable.
5986      bias: Bias tensor to add.
5987      data_format: string, `"channels_last"` or `"channels_first"`.
5988
5989  Returns:
5990      Output tensor.
5991
5992  Raises:
5993      ValueError: In one of the two cases below:
5994                  1. invalid `data_format` argument.
5995                  2. invalid bias shape:
5996                     the bias should be either a vector or
5997                     a tensor with ndim(x) - 1 dimensions.
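
  Example (a minimal sketch; the bias is broadcast along the channel axis):

  >>> x = tf.zeros((2, 3))
  >>> b = tf.constant([1., 2., 3.])
  >>> print(tf.keras.backend.bias_add(x, b).numpy())
  [[1. 2. 3.]
   [1. 2. 3.]]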
5998  """
5999  if data_format is None:
6000    data_format = image_data_format()
6001  if data_format not in {'channels_first', 'channels_last'}:
6002    raise ValueError('Unknown data_format: ' + str(data_format))
6003  bias_shape = int_shape(bias)
6004  if len(bias_shape) != 1 and len(bias_shape) != ndim(x) - 1:
6005    raise ValueError(
6006        'Unexpected bias dimensions %d, expect to be 1 or %d dimensions' %
6007        (len(bias_shape), ndim(x) - 1))
6008
6009  if len(bias_shape) == 1:
6010    if data_format == 'channels_first':
6011      return nn.bias_add(x, bias, data_format='NCHW')
6012    return nn.bias_add(x, bias, data_format='NHWC')
6013  if ndim(x) in (3, 4, 5):
6014    if data_format == 'channels_first':
6015      bias_reshape_axis = (1, bias_shape[-1]) + bias_shape[:-1]
6016      return x + reshape(bias, bias_reshape_axis)
6017    return x + reshape(bias, (1,) + bias_shape)
6018  return nn.bias_add(x, bias)
6019
6020
6021# RANDOMNESS
6022
6023
6024@keras_export('keras.backend.random_normal')
6025@dispatch.add_dispatch_support
6026@doc_controls.do_not_generate_docs
6027def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
6028  """Returns a tensor with normal distribution of values.
6029
6030  It is an alias to `tf.random.normal`.
6031
6032  Args:
6033      shape: A tuple of integers, the shape of tensor to create.
6034      mean: A float, the mean value of the normal distribution to draw samples.
6035        Defaults to 0.0.
6036      stddev: A float, the standard deviation of the normal distribution
6037        to draw samples. Defaults to 1.0.
6038      dtype: `tf.dtypes.DType`, dtype of returned tensor. Defaults to the
6039        Keras backend dtype, which is float32.
6040      seed: Integer, random seed. Will use a random numpy integer when not
6041        specified.
6042
6043  Returns:
6044      A tensor with normal distribution of values.
6045
6046  Example:
6047
6048  >>> random_normal_tensor = tf.keras.backend.random_normal(shape=(2,3),
6049  ... mean=0.0, stddev=1.0)
6050  >>> random_normal_tensor
6051  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=...,
6052  dtype=float32)>
6053  """
6054  if dtype is None:
6055    dtype = floatx()
6056  if seed is None:
6057    seed = np.random.randint(10e6)
6058  return random_ops.random_normal(
6059      shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed)
6060
6061
6062@keras_export('keras.backend.random_uniform')
6063@dispatch.add_dispatch_support
6064@doc_controls.do_not_generate_docs
6065def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
6066  """Returns a tensor with uniform distribution of values.
6067
6068  Args:
6069      shape: A tuple of integers, the shape of tensor to create.
6070      minval: A float, lower boundary of the uniform distribution
6071          to draw samples.
6072      maxval: A float, upper boundary of the uniform distribution
6073          to draw samples.
6074      dtype: String, dtype of returned tensor.
6075      seed: Integer, random seed.
6076
6077  Returns:
6078      A tensor.
6079
6080  Example:
6081
6082  >>> random_uniform_tensor = tf.keras.backend.random_uniform(shape=(2,3),
6083  ... minval=0.0, maxval=1.0)
6084  >>> random_uniform_tensor
6085  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=...,
6086  dtype=float32)>
6087  """
6088  if dtype is None:
6089    dtype = floatx()
6090  if seed is None:
6091    seed = np.random.randint(10e6)
6092  return random_ops.random_uniform(
6093      shape, minval=minval, maxval=maxval, dtype=dtype, seed=seed)
6094
6095
6096@keras_export('keras.backend.random_binomial')
6097@dispatch.add_dispatch_support
6098@doc_controls.do_not_generate_docs
6099def random_binomial(shape, p=0.0, dtype=None, seed=None):
6100  """Returns a tensor with random binomial distribution of values.
6101
6102  DEPRECATED, use `tf.keras.backend.random_bernoulli` instead.
6103
6104  The binomial distribution with parameters `n` and `p` is the probability
6105  distribution of the number of successes in `n` independent Bernoulli trials.
6106  Only `n = 1` is supported for now.
6107
6108  Args:
6109      shape: A tuple of integers, the shape of tensor to create.
6110      p: A float, `0. <= p <= 1`, probability of binomial distribution.
6111      dtype: String, dtype of returned tensor.
6112      seed: Integer, random seed.
6113
6114  Returns:
6115      A tensor.
6116
6117  Example:
6118
6119  >>> random_binomial_tensor = tf.keras.backend.random_binomial(shape=(2,3),
6120  ... p=0.5)
6121  >>> random_binomial_tensor
6122  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=...,
6123  dtype=float32)>
6124  """
6125  warnings.warn('`tf.keras.backend.random_binomial` is deprecated, '
6126                'and will be removed in a future version. '
6127                'Please use `tf.keras.backend.random_bernoulli` instead.')
6128  return random_bernoulli(shape, p, dtype, seed)
6129
6130
6131@keras_export('keras.backend.random_bernoulli')
6132@dispatch.add_dispatch_support
6133@doc_controls.do_not_generate_docs
6134def random_bernoulli(shape, p=0.0, dtype=None, seed=None):
6135  """Returns a tensor with random bernoulli distribution of values.
6136
6137  Args:
6138      shape: A tuple of integers, the shape of tensor to create.
6139      p: A float, `0. <= p <= 1`, probability of bernoulli distribution.
6140      dtype: String, dtype of returned tensor.
6141      seed: Integer, random seed.
6142
6143  Returns:
6144      A tensor.
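
  Example (an illustrative sketch; each entry is 1. with probability `p` and
  0. otherwise, so only the shape is shown):

  >>> b = tf.keras.backend.random_bernoulli(shape=(2, 3), p=0.5)
  >>> b.shape
  TensorShape([2, 3])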
6145  """
6146  if dtype is None:
6147    dtype = floatx()
6148  if seed is None:
6149    seed = np.random.randint(10e6)
6150  return array_ops.where_v2(
6151      random_ops.random_uniform(shape, dtype=dtype, seed=seed) <= p,
6152      array_ops.ones(shape, dtype=dtype), array_ops.zeros(shape, dtype=dtype))
6153
6154
6155@keras_export('keras.backend.truncated_normal')
6156@dispatch.add_dispatch_support
6157@doc_controls.do_not_generate_docs
6158def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
6159  """Returns a tensor with truncated random normal distribution of values.
6160
6161  The generated values follow a normal distribution
6162  with specified mean and standard deviation,
6163  except that values whose magnitude is more than
6164  two standard deviations from the mean are dropped and re-picked.
6165
6166  Args:
6167      shape: A tuple of integers, the shape of tensor to create.
6168      mean: Mean of the values.
6169      stddev: Standard deviation of the values.
6170      dtype: String, dtype of returned tensor.
6171      seed: Integer, random seed.
6172
6173  Returns:
6174      A tensor.
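
  Example (an illustrative sketch; the values are random, so they are elided):

  >>> tf.keras.backend.truncated_normal(shape=(2, 3), mean=0.0, stddev=1.0)
  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=...,
  dtype=float32)>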
6175  """
6176  if dtype is None:
6177    dtype = floatx()
6178  if seed is None:
6179    seed = np.random.randint(10e6)
6180  return random_ops.truncated_normal(
6181      shape, mean, stddev, dtype=dtype, seed=seed)
6182
6183
6184# CTC
6185# TensorFlow has a native implementation, but it uses sparse tensors
6186# and therefore requires a wrapper for Keras. The functions below convert
6187  # dense to sparse tensors and also wrap up the beam search code that is
6188  # in TensorFlow's CTC implementation.
6189
6190
6191@keras_export('keras.backend.ctc_label_dense_to_sparse')
6192@dispatch.add_dispatch_support
6193@doc_controls.do_not_generate_docs
6194def ctc_label_dense_to_sparse(labels, label_lengths):
6195  """Converts CTC labels from dense to sparse.
6196
6197  Args:
6198      labels: dense CTC labels.
6199      label_lengths: length of the labels.
6200
6201  Returns:
6202      A sparse tensor representation of the labels.
6203  """
6204  label_shape = array_ops.shape(labels)
6205  num_batches_tns = array_ops.stack([label_shape[0]])
6206  max_num_labels_tns = array_ops.stack([label_shape[1]])
6207
6208  def range_less_than(old_input, current_input):
6209    return array_ops.expand_dims(
6210        math_ops.range(array_ops.shape(old_input)[1]), 0) < array_ops.fill(
6211            max_num_labels_tns, current_input)
6212
6213  init = math_ops.cast(
6214      array_ops.fill([1, label_shape[1]], 0), dtypes_module.bool)
6215  dense_mask = functional_ops.scan(
6216      range_less_than, label_lengths, initializer=init, parallel_iterations=1)
6217  dense_mask = dense_mask[:, 0, :]
6218
6219  label_array = array_ops.reshape(
6220      array_ops.tile(math_ops.range(0, label_shape[1]), num_batches_tns),
6221      label_shape)
6222  label_ind = array_ops.boolean_mask(label_array, dense_mask)
6223
6224  batch_array = array_ops.transpose(
6225      array_ops.reshape(
6226          array_ops.tile(math_ops.range(0, label_shape[0]), max_num_labels_tns),
6227          reverse(label_shape, 0)))
6228  batch_ind = array_ops.boolean_mask(batch_array, dense_mask)
6229  indices = array_ops.transpose(
6230      array_ops.reshape(concatenate([batch_ind, label_ind], axis=0), [2, -1]))
6231
6232  vals_sparse = array_ops.gather_nd(labels, indices)
6233
6234  return sparse_tensor.SparseTensor(
6235      math_ops.cast(indices, dtypes_module.int64), vals_sparse,
6236      math_ops.cast(label_shape, dtypes_module.int64))
6237
6238
6239@keras_export('keras.backend.ctc_batch_cost')
6240@dispatch.add_dispatch_support
6241@doc_controls.do_not_generate_docs
6242def ctc_batch_cost(y_true, y_pred, input_length, label_length):
6243  """Runs CTC loss algorithm on each batch element.
6244
6245  Args:
6246      y_true: tensor `(samples, max_string_length)`
6247          containing the truth labels.
6248      y_pred: tensor `(samples, time_steps, num_categories)`
6249          containing the prediction, or output of the softmax.
6250      input_length: tensor `(samples, 1)` containing the sequence length for
6251          each batch item in `y_pred`.
6252      label_length: tensor `(samples, 1)` containing the sequence length for
6253          each batch item in `y_true`.
6254
6255  Returns:
6256      Tensor with shape (samples, 1) containing the
6257          CTC loss of each element.
6258  """
6259  label_length = math_ops.cast(
6260      array_ops.squeeze(label_length, axis=-1), dtypes_module.int32)
6261  input_length = math_ops.cast(
6262      array_ops.squeeze(input_length, axis=-1), dtypes_module.int32)
6263  sparse_labels = math_ops.cast(
6264      ctc_label_dense_to_sparse(y_true, label_length), dtypes_module.int32)
6265
6266  y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + epsilon())
6267
6268  return array_ops.expand_dims(
6269      ctc.ctc_loss(
6270          inputs=y_pred, labels=sparse_labels, sequence_length=input_length), 1)
6271
6272
6273@keras_export('keras.backend.ctc_decode')
6274@dispatch.add_dispatch_support
6275@doc_controls.do_not_generate_docs
6276def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
6277  """Decodes the output of a softmax.
6278
6279  Can use either greedy search (also known as best path)
6280  or a constrained dictionary search.
6281
6282  Args:
6283      y_pred: tensor `(samples, time_steps, num_categories)`
6284          containing the prediction, or output of the softmax.
6285      input_length: tensor `(samples, )` containing the sequence length for
6286          each batch item in `y_pred`.
6287      greedy: perform much faster best-path search if `True`.
6288          This does not use a dictionary.
6289      beam_width: if `greedy` is `False`, a beam search decoder will be used
6290          with a beam of this width.
6291      top_paths: if `greedy` is `False`,
6292          how many of the most probable paths will be returned.
6293
6294  Returns:
6295      Tuple:
6296          List: if `greedy` is `True`, returns a list of one element that
6297              contains the decoded sequence.
6298              If `False`, returns the `top_paths` most probable
6299              decoded sequences.
6300              Each decoded sequence has shape (samples, time_steps).
6301              Important: blank labels are returned as `-1`.
6302          Tensor `(top_paths, )` that contains
6303              the log probability of each decoded sequence.
6304  """
6305  input_shape = shape(y_pred)
6306  num_samples, num_steps = input_shape[0], input_shape[1]
6307  y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + epsilon())
6308  input_length = math_ops.cast(input_length, dtypes_module.int32)
6309
6310  if greedy:
6311    (decoded, log_prob) = ctc.ctc_greedy_decoder(
6312        inputs=y_pred, sequence_length=input_length)
6313  else:
6314    (decoded, log_prob) = ctc.ctc_beam_search_decoder(
6315        inputs=y_pred,
6316        sequence_length=input_length,
6317        beam_width=beam_width,
6318        top_paths=top_paths)
6319  decoded_dense = []
6320  for st in decoded:
6321    st = sparse_tensor.SparseTensor(
6322        st.indices, st.values, (num_samples, num_steps))
6323    decoded_dense.append(
6324        sparse_ops.sparse_tensor_to_dense(sp_input=st, default_value=-1))
6325  return (decoded_dense, log_prob)
6326
6327
6328# HIGH ORDER FUNCTIONS
6329
6330
6331@keras_export('keras.backend.map_fn')
6332@doc_controls.do_not_generate_docs
6333def map_fn(fn, elems, name=None, dtype=None):
6334  """Map the function fn over the elements elems and return the outputs.
6335
6336  Args:
6337      fn: Callable that will be called upon each element in elems
6338      elems: tensor
6339      name: A string name for the map node in the graph
6340      dtype: Output data type.
6341
6342  Returns:
6343      Tensor with dtype `dtype`.
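
  Example (a minimal sketch; `fn` is applied to each element of `elems`):

  >>> x = tf.constant([1., 2., 3.])
  >>> print(tf.keras.backend.map_fn(lambda t: t * t, x).numpy())
  [1. 4. 9.]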
6344  """
6345  return map_fn_lib.map_fn(fn, elems, name=name, dtype=dtype)
6346
6347
6348@keras_export('keras.backend.foldl')
6349@doc_controls.do_not_generate_docs
6350def foldl(fn, elems, initializer=None, name=None):
6351  """Reduce elems using fn to combine them from left to right.
6352
6353  Args:
6354      fn: Callable that will be called upon each element in elems and an
6355          accumulator, for instance `lambda acc, x: acc + x`
6356      elems: tensor
6357      initializer: The first value used (`elems[0]` in case of None)
6358      name: A string name for the foldl node in the graph
6359
6360  Returns:
6361      Tensor with same type and shape as `initializer`.
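
  Example (a minimal sketch; with no `initializer`, the fold starts from
  `elems[0]`, so this computes 1 + 2 + 3 + 4):

  >>> x = tf.constant([1., 2., 3., 4.])
  >>> tf.keras.backend.foldl(lambda acc, v: acc + v, x)
  <tf.Tensor: shape=(), dtype=float32, numpy=10.0>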
6362  """
6363  return functional_ops.foldl(fn, elems, initializer=initializer, name=name)
6364
6365
6366@keras_export('keras.backend.foldr')
6367@doc_controls.do_not_generate_docs
6368def foldr(fn, elems, initializer=None, name=None):
6369  """Reduce elems using fn to combine them from right to left.
6370
6371  Args:
6372      fn: Callable that will be called upon each element in elems and an
6373          accumulator, for instance `lambda acc, x: acc + x`
6374      elems: tensor
6375      initializer: The first value used (`elems[-1]` in case of None)
6376      name: A string name for the foldr node in the graph
6377
6378  Returns:
6379      Same type and shape as initializer
6380  """
6381  return functional_ops.foldr(fn, elems, initializer=initializer, name=name)
6382
6383# Load Keras default configuration from config file if present.
6384# Set Keras base dir path given KERAS_HOME env variable, if applicable.
6385  # Otherwise default to ~/.keras.
6386if 'KERAS_HOME' in os.environ:
6387  _keras_dir = os.environ.get('KERAS_HOME')
6388else:
6389  _keras_base_dir = os.path.expanduser('~')
6390  _keras_dir = os.path.join(_keras_base_dir, '.keras')
6391_config_path = os.path.expanduser(os.path.join(_keras_dir, 'keras.json'))
if os.path.exists(_config_path):
  try:
    with open(_config_path) as fh:
      _config = json.load(fh)
  except ValueError:
    _config = {}
  _floatx = _config.get('floatx', floatx())
  assert _floatx in {'float16', 'float32', 'float64'}
  _epsilon = _config.get('epsilon', epsilon())
  assert isinstance(_epsilon, float)
  _image_data_format = _config.get('image_data_format', image_data_format())
  assert _image_data_format in {'channels_last', 'channels_first'}
  set_floatx(_floatx)
  set_epsilon(_epsilon)
  set_image_data_format(_image_data_format)

# Save config file.
if not os.path.exists(_keras_dir):
  try:
    os.makedirs(_keras_dir)
  except OSError:
    # Ignore permission-denied errors and potential race conditions
    # in multi-threaded environments.
    pass

if not os.path.exists(_config_path):
  _config = {
      'floatx': floatx(),
      'epsilon': epsilon(),
      'backend': 'tensorflow',
      'image_data_format': image_data_format()
  }
  try:
    with open(_config_path, 'w') as f:
      f.write(json.dumps(_config, indent=4))
  except IOError:
    # Ignore permission-denied errors.
    pass


def configure_and_create_distributed_session(distribution_strategy):
  """Configure session config and create a session with it."""

  def _create_session(distribution_strategy):
    """Create the Distributed Strategy session."""
    session_config = get_default_session_config()

    # If a session already exists, merge in its config; if there is a
    # conflict, the values of the existing config take precedence.
    global _SESSION
    if getattr(_SESSION, 'session', None) and _SESSION.session._config:
      session_config.MergeFrom(_SESSION.session._config)

    if is_tpu_strategy(distribution_strategy):
      # TODO(priyag, yuefengz): Remove this workaround when Distribute
      # Coordinator is integrated with keras and we can create a session from
      # there.
      distribution_strategy.configure(session_config)
      master = distribution_strategy.extended._tpu_cluster_resolver.master()  # pylint: disable=protected-access
      session = session_module.Session(config=session_config, target=master)
    else:
      worker_context = dc.get_current_worker_context()
      if worker_context:
        dc_session_config = worker_context.session_config
        # Merge the default session config into the one from the distribute
        # coordinator, which is fine for now since they don't have
        # conflicting configurations.
        dc_session_config.MergeFrom(session_config)
        session = session_module.Session(
            config=dc_session_config, target=worker_context.master_target)
      else:
        distribution_strategy.configure(session_config)
        session = session_module.Session(config=session_config)

    set_session(session)

  if distribution_strategy.extended._in_multi_worker_mode():
    dc.run_distribute_coordinator(
        _create_session,
        distribution_strategy)
  else:
    _create_session(distribution_strategy)


def _is_tpu_strategy_class(clz):
  is_tpu_strat = lambda k: k.__name__.startswith('TPUStrategy')
  if is_tpu_strat(clz):
    return True
  return py_any(map(_is_tpu_strategy_class, clz.__bases__))


def is_tpu_strategy(strategy):
  """Returns whether input is a TPUStrategy instance or subclass instance."""
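  # For example (illustrative): True for `tf.distribute.TPUStrategy` and its
  # subclasses, False for e.g. `tf.distribute.MirroredStrategy`.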
  return _is_tpu_strategy_class(strategy.__class__)


def cast_variables_to_tensor(tensors):
  """Converts any `Variable`s in `tensors` to `Tensor`s via `identity`."""

  def _cast_variables_to_tensor(tensor):
    if isinstance(tensor, variables_module.Variable):
      return array_ops.identity(tensor)
    return tensor

  return nest.map_structure(_cast_variables_to_tensor, tensors)


def _is_symbolic_tensor(x):
  return tensor_util.is_tf_type(x) and not isinstance(x, ops.EagerTensor)


def convert_inputs_if_ragged(inputs):
  """Converts any ragged tensors to dense."""
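  # For example (illustrative): given a ragged batch
  # `tf.ragged.constant([[1., 2., 3.], [4.]])`, this returns the padded dense
  # tensor [[1., 2., 3.], [4., 0., 0.]] together with the int32 row lengths
  # [3, 1]; non-ragged inputs are returned unchanged with `None` row lengths.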

  def _convert_ragged_input(inputs):
    if isinstance(inputs, ragged_tensor.RaggedTensor):
      return inputs.to_tensor()
    return inputs

  flat_inputs = nest.flatten(inputs)
  contains_ragged = py_any(
      isinstance(i, ragged_tensor.RaggedTensor) for i in flat_inputs)

  if not contains_ragged:
    return inputs, None

  inputs = nest.map_structure(_convert_ragged_input, inputs)
  # Multiple masks are not yet supported, so one mask is used on all inputs.
  # We approach this similarly when using row lengths to ignore steps.
  nested_row_lengths = math_ops.cast(flat_inputs[0].nested_row_lengths()[0],
                                     'int32')
  return inputs, nested_row_lengths


def maybe_convert_to_ragged(is_ragged_input, output, nested_row_lengths,
                            go_backwards=False):
  """Converts any ragged input back to its initial structure."""
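  # For example (illustrative): with `nested_row_lengths` of [3, 1], a padded
  # dense output [[1., 2., 3.], [4., 0., 0.]] is converted back to the ragged
  # tensor [[1., 2., 3.], [4.]] via `RaggedTensor.from_tensor(output, lengths)`.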
  if not is_ragged_input:
    return output

  if go_backwards:
    # Reverse based on the timestep dim, so that nested_row_lengths will mask
    # from the correct direction. Return the reverse ragged tensor.
    output = reverse(output, [1])
    ragged = ragged_tensor.RaggedTensor.from_tensor(output, nested_row_lengths)
    return reverse(ragged, [1])
  else:
    return ragged_tensor.RaggedTensor.from_tensor(output, nested_row_lengths)


class ContextValueCache(weakref.WeakKeyDictionary):
  """Container that caches (possibly tensor) values based on the context.

  This class is similar to defaultdict, where values may be produced by the
  default factory specified during initialization. This class also has a default
  value for the key (when key is `None`) -- the key is set to the current graph
  or eager context. The default factories for key and value are only used in
  `__getitem__` and `setdefault`. The `.get()` behavior remains the same.

  This object will return the value of the current graph or closest parent graph
  if the current graph is a function. This is to reflect the fact that if a
  tensor is created in eager/graph, child functions may capture that tensor.

  The default factory method may accept keyword arguments (unlike defaultdict,
  which only accepts callables with 0 arguments). To pass keyword arguments to
  `default_factory`, use the `setdefault` method instead of `__getitem__`.

  An example of how this class can be used in different contexts:

  ```
  cache = ContextValueCache(int)

  # Eager mode
  cache[None] += 2
  cache[None] += 4
  assert cache[None] == 6

  # Graph mode
  with tf.Graph().as_default() as g:
    cache[None] += 5
    cache[g] += 3
  assert cache[g] == 8
  ```

  Example of a default factory with arguments:

  ```
  cache = ContextValueCache(lambda x: x + 1)
  g = tf.compat.v1.get_default_graph()

  # Example with keyword argument.
  value = cache.setdefault(key=g, kwargs={'x': 3})
  assert cache[g] == 4
  ```
  """

  def __init__(self, default_factory):
    self.default_factory = default_factory
    weakref.WeakKeyDictionary.__init__(self)

  def _key(self):
    if context.executing_eagerly():
      return _DUMMY_EAGER_GRAPH.key
    else:
      return ops.get_default_graph()

  def _get_parent_graph(self, graph):
    """Returns the parent graph or dummy eager object."""
    # TODO(b/149317164): Currently FuncGraphs use ops.get_default_graph() as the
    # outer graph. This results in outer_graph always being a Graph,
    # even in eager mode (get_default_graph will create a new Graph if there
    # isn't a default graph). Because of this bug, we have to specially set the
    # key when eager execution is enabled.
    parent_graph = graph.outer_graph
    if (not isinstance(parent_graph, func_graph.FuncGraph) and
        ops.executing_eagerly_outside_functions()):
      return _DUMMY_EAGER_GRAPH.key
    return parent_graph

  def _get_recursive(self, key):
    """Gets the value at key or the closest parent graph."""
    value = self.get(key)
    if value is not None:
      return value

    # Since FuncGraphs are able to capture tensors and variables from their
    # parent graphs, recursively search to see if there is a value stored for
    # one of the parent graphs.
    if isinstance(key, func_graph.FuncGraph):
      return self._get_recursive(self._get_parent_graph(key))
    return None

  def __getitem__(self, key):
    """Gets the value at key (or current context), or sets default value.

    Args:
      key: May be `None` or `Graph` object. When `None`, the key is set to the
        current context.

    Returns:
      Either the cached or default value.
    """
    if key is None:
      key = self._key()

    value = self._get_recursive(key)
    if value is None:
      value = self[key] = self.default_factory()  # pylint:disable=not-callable
    return value

  def setdefault(self, key=None, default=None, kwargs=None):
    """Sets the default value if key is not in dict, and returns the value."""
    if key is None:
      key = self._key()
    kwargs = kwargs or {}

    if default is None and key not in self:
      default = self.default_factory(**kwargs)
    return weakref.WeakKeyDictionary.setdefault(self, key, default)

# This dictionary holds a mapping {graph: learning_phase}. In eager mode, a
# dummy object is used.
# A learning phase is a bool tensor used to run Keras models in
# either train mode (learning_phase == 1) or test mode (learning_phase == 0).
_GRAPH_LEARNING_PHASES = ContextValueCache(_default_learning_phase)

# This dictionary holds a mapping between a graph and variables to initialize
# in the graph.
_GRAPH_VARIABLES = ContextValueCache(object_identity.ObjectIdentityWeakSet)

# This dictionary holds a mapping between a graph and TF optimizers created in
# the graph.
_GRAPH_TF_OPTIMIZERS = ContextValueCache(object_identity.ObjectIdentityWeakSet)
