• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16# pylint: disable=redefined-outer-name
17# pylint: disable=redefined-builtin
18"""Keras backend API.
19"""
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import collections
25import itertools
26import json
27import os
28import sys
29import threading
30import warnings
31import weakref
32
33import numpy as np
34
35from tensorflow.core.protobuf import config_pb2
36from tensorflow.python import tf2
37from tensorflow.python.client import session as session_module
38from tensorflow.python.distribute import distribute_coordinator as dc
39from tensorflow.python.distribute import distribute_coordinator_context as dc_context
40from tensorflow.python.distribute import distribution_strategy_context
41from tensorflow.python.eager import context
42from tensorflow.python.eager import function as eager_function
43from tensorflow.python.eager import lift_to_graph
44from tensorflow.python.eager.context import get_config
45from tensorflow.python.framework import composite_tensor
46from tensorflow.python.framework import config
47from tensorflow.python.framework import constant_op
48from tensorflow.python.framework import device_spec
49from tensorflow.python.framework import dtypes as dtypes_module
50from tensorflow.python.framework import func_graph
51from tensorflow.python.framework import ops
52from tensorflow.python.framework import sparse_tensor
53from tensorflow.python.framework import tensor_shape
54from tensorflow.python.framework import tensor_spec
55from tensorflow.python.framework import tensor_util
56from tensorflow.python.keras import backend_config
57from tensorflow.python.keras.engine import keras_tensor
58from tensorflow.python.keras.utils import control_flow_util
59from tensorflow.python.keras.utils import object_identity
60from tensorflow.python.keras.utils import tf_contextlib
61from tensorflow.python.keras.utils import tf_inspect
62from tensorflow.python.ops import array_ops
63from tensorflow.python.ops import clip_ops
64from tensorflow.python.ops import control_flow_ops
65from tensorflow.python.ops import ctc_ops as ctc
66from tensorflow.python.ops import functional_ops
67from tensorflow.python.ops import gradients as gradients_module
68from tensorflow.python.ops import image_ops
69from tensorflow.python.ops import init_ops
70from tensorflow.python.ops import linalg_ops
71from tensorflow.python.ops import logging_ops
72from tensorflow.python.ops import map_fn as map_fn_lib
73from tensorflow.python.ops import math_ops
74from tensorflow.python.ops import nn
75from tensorflow.python.ops import random_ops
76from tensorflow.python.ops import sparse_ops
77from tensorflow.python.ops import state_ops
78from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
79from tensorflow.python.ops import tensor_array_ops
80from tensorflow.python.ops import variables as variables_module
81from tensorflow.python.ops.ragged import ragged_tensor
82from tensorflow.python.platform import tf_logging as logging
83from tensorflow.python.training import moving_averages
84from tensorflow.python.training.tracking import util as tracking_util
85from tensorflow.python.util import dispatch
86from tensorflow.python.util import keras_deps
87from tensorflow.python.util import nest
88from tensorflow.python.util.tf_export import keras_export
89from tensorflow.tools.docs import doc_controls
90
91py_all = all
92py_sum = sum
93py_any = any
94
95# INTERNAL UTILS
96
97# The internal graph maintained by Keras and used by the symbolic Keras APIs
98# while executing eagerly (such as the functional API for model-building).
99# This is thread-local to allow building separate models in different threads
100# concurrently, but comes at the cost of not being able to build one model
101# across threads.
102_GRAPH = threading.local()
103
104# A graph which is used for constructing functions in eager mode.
105_CURRENT_SCRATCH_GRAPH = threading.local()
106
107# This is a thread local object that will hold the default internal TF session
108# used by Keras. It can be set manually via `set_session(sess)`.
109_SESSION = threading.local()
110
111
112# A global dictionary mapping graph objects to an index of counters used
113# for various layer/optimizer names in each graph.
114# Allows to give unique autogenerated names to layers, in a graph-specific way.
115PER_GRAPH_OBJECT_NAME_UIDS = weakref.WeakKeyDictionary()
116
117
118# A global set tracking what object names have been seen so far.
119# Optionally used as an avoid-list when generating names
120OBSERVED_NAMES = set()
121
122
123# _DUMMY_EAGER_GRAPH.key is used as a key in _GRAPH_LEARNING_PHASES.
124# We keep a separate reference to it to make sure it does not get removed from
125# _GRAPH_LEARNING_PHASES.
126# _DummyEagerGraph inherits from threading.local to make its `key` attribute
127# thread local. This is needed to make set_learning_phase affect only the
128# current thread during eager execution (see b/123096885 for more details).
129class _DummyEagerGraph(threading.local):
130  """_DummyEagerGraph provides a thread local `key` attribute.
131
132  We can't use threading.local directly, i.e. without subclassing, because
133  gevent monkey patches threading.local and its version does not support
134  weak references.
135  """
136
137  class _WeakReferencableClass(object):
138    """This dummy class is needed for two reasons.
139
140    - We need something that supports weak references. Basic types like string
141    and ints don't.
142    - We need something whose hash and equality are based on object identity
143    to make sure they are treated as different keys to _GRAPH_LEARNING_PHASES.
144
145    An empty Python class satisfies both of these requirements.
146    """
147    pass
148
149  def __init__(self):
150    # Constructors for classes subclassing threading.local run once
151    # per thread accessing something in the class. Thus, each thread will
152    # get a different key.
153    super(_DummyEagerGraph, self).__init__()
154    self.key = _DummyEagerGraph._WeakReferencableClass()
155    self.learning_phase_is_set = False
156
157
158_DUMMY_EAGER_GRAPH = _DummyEagerGraph()
159
160# This boolean flag can be set to True to leave variable initialization
161# up to the user.
162# Change its value via `manual_variable_initialization(value)`.
163_MANUAL_VAR_INIT = False
164
165# This list holds the available devices.
166# It is populated when `_get_available_gpus()` is called for the first time.
167# We assume our devices don't change henceforth.
168_LOCAL_DEVICES = None
169
170# The below functions are kept accessible from backend for compatibility.
171epsilon = backend_config.epsilon
172floatx = backend_config.floatx
173image_data_format = backend_config.image_data_format
174set_epsilon = backend_config.set_epsilon
175set_floatx = backend_config.set_floatx
176set_image_data_format = backend_config.set_image_data_format
177
178
179@keras_export('keras.backend.backend')
180@doc_controls.do_not_generate_docs
181def backend():
182  """Publicly accessible method for determining the current backend.
183
184  Only exists for API compatibility with multi-backend Keras.
185
186  Returns:
187      The string "tensorflow".
188  """
189  return 'tensorflow'
190
191
192@keras_export('keras.backend.cast_to_floatx')
193@dispatch.add_dispatch_support
194@doc_controls.do_not_generate_docs
195def cast_to_floatx(x):
196  """Cast a Numpy array to the default Keras float type.
197
198  Args:
199      x: Numpy array or TensorFlow tensor.
200
201  Returns:
202      The same array (Numpy array if `x` was a Numpy array, or TensorFlow tensor
203      if `x` was a tensor), cast to its new type.
204
205  Example:
206
207  >>> tf.keras.backend.floatx()
208  'float32'
209  >>> arr = np.array([1.0, 2.0], dtype='float64')
210  >>> arr.dtype
211  dtype('float64')
212  >>> new_arr = cast_to_floatx(arr)
213  >>> new_arr
214  array([1.,  2.], dtype=float32)
215  >>> new_arr.dtype
216  dtype('float32')
217
218  """
219  if isinstance(x, (ops.Tensor,
220                    variables_module.Variable,
221                    sparse_tensor.SparseTensor)):
222    return math_ops.cast(x, dtype=floatx())
223  return np.asarray(x, dtype=floatx())
224
225
226@keras_export('keras.backend.get_uid')
227def get_uid(prefix=''):
228  """Associates a string prefix with an integer counter in a TensorFlow graph.
229
230  Args:
231    prefix: String prefix to index.
232
233  Returns:
234    Unique integer ID.
235
236  Example:
237
238  >>> get_uid('dense')
239  1
240  >>> get_uid('dense')
241  2
242
243  """
244  graph = get_graph()
245  if graph not in PER_GRAPH_OBJECT_NAME_UIDS:
246    PER_GRAPH_OBJECT_NAME_UIDS[graph] = collections.defaultdict(int)
247  layer_name_uids = PER_GRAPH_OBJECT_NAME_UIDS[graph]
248  layer_name_uids[prefix] += 1
249  return layer_name_uids[prefix]
250
251
252@keras_export('keras.backend.reset_uids')
253def reset_uids():
254  """Resets graph identifiers.
255  """
256
257  PER_GRAPH_OBJECT_NAME_UIDS.clear()
258  OBSERVED_NAMES.clear()
259
260
261@keras_export('keras.backend.clear_session')
262def clear_session():
263  """Resets all state generated by Keras.
264
265  Keras manages a global state, which it uses to implement the Functional
266  model-building API and to uniquify autogenerated layer names.
267
268  If you are creating many models in a loop, this global state will consume
269  an increasing amount of memory over time, and you may want to clear it.
270  Calling `clear_session()` releases the global state: this helps avoid clutter
271  from old models and layers, especially when memory is limited.
272
273  Example 1: calling `clear_session()` when creating models in a loop
274
275  ```python
276  for _ in range(100):
277    # Without `clear_session()`, each iteration of this loop will
278    # slightly increase the size of the global state managed by Keras
279    model = tf.keras.Sequential([tf.keras.layers.Dense(10) for _ in range(10)])
280
281  for _ in range(100):
282    # With `clear_session()` called at the beginning,
283    # Keras starts with a blank state at each iteration
284    # and memory consumption is constant over time.
285    tf.keras.backend.clear_session()
286    model = tf.keras.Sequential([tf.keras.layers.Dense(10) for _ in range(10)])
287  ```
288
289  Example 2: resetting the layer name generation counter
290
291  >>> import tensorflow as tf
292  >>> layers = [tf.keras.layers.Dense(10) for _ in range(10)]
293  >>> new_layer = tf.keras.layers.Dense(10)
294  >>> print(new_layer.name)
295  dense_10
296  >>> tf.keras.backend.set_learning_phase(1)
297  >>> print(tf.keras.backend.learning_phase())
298  1
299  >>> tf.keras.backend.clear_session()
300  >>> new_layer = tf.keras.layers.Dense(10)
301  >>> print(new_layer.name)
302  dense
303  """
304  global _SESSION
305  global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
306  global _GRAPH_VARIABLES  # pylint: disable=global-variable-not-assigned
307  global _GRAPH_TF_OPTIMIZERS  # pylint: disable=global-variable-not-assigned
308  global _GRAPH
309  _GRAPH.graph = None
310  ops.reset_default_graph()
311  reset_uids()
312  _SESSION.session = None
313  graph = get_graph()
314  with graph.as_default():
315    _DUMMY_EAGER_GRAPH.learning_phase_is_set = False
316    _GRAPH_LEARNING_PHASES.clear()
317    # Create the learning phase placeholder in graph using the default factory.
318    _GRAPH_LEARNING_PHASES.setdefault(graph)
319    _GRAPH_VARIABLES.pop(graph, None)
320    _GRAPH_TF_OPTIMIZERS.pop(graph, None)
321
322# Inject the clear_session function to keras_deps to remove the dependency
323# from TFLite to Keras.
324keras_deps.register_clear_session_function(clear_session)
325
326
327@keras_export('keras.backend.manual_variable_initialization')
328@doc_controls.do_not_generate_docs
329def manual_variable_initialization(value):
330  """Sets the manual variable initialization flag.
331
332  This boolean flag determines whether
333  variables should be initialized
334  as they are instantiated (default), or if
335  the user should handle the initialization
336  (e.g. via `tf.compat.v1.initialize_all_variables()`).
337
338  Args:
339      value: Python boolean.
340  """
341  global _MANUAL_VAR_INIT
342  _MANUAL_VAR_INIT = value
343
344
345@keras_export('keras.backend.learning_phase')
346@doc_controls.do_not_generate_docs
347def learning_phase():
348  """Returns the learning phase flag.
349
350  The learning phase flag is a bool tensor (0 = test, 1 = train)
351  to be passed as input to any Keras function
352  that uses a different behavior at train time and test time.
353
354  Returns:
355      Learning phase (scalar integer tensor or Python integer).
356  """
357  graph = ops.get_default_graph()
358  if graph is getattr(_GRAPH, 'graph', None):
359    # Don't enter an init_scope for the learning phase if eager execution
360    # is enabled but we're inside the Keras workspace graph.
361    learning_phase = symbolic_learning_phase()
362  else:
363    with ops.init_scope():
364      # We always check & set the learning phase inside the init_scope,
365      # otherwise the wrong default_graph will be used to look up the learning
366      # phase inside of functions & defuns.
367      #
368      # This is because functions & defuns (both in graph & in eager mode)
369      # will always execute non-eagerly using a function-specific default
370      # subgraph.
371      learning_phase = _GRAPH_LEARNING_PHASES[None]
372  _mark_func_graph_as_unsaveable(graph, learning_phase)
373  return learning_phase
374
375
376def global_learning_phase_is_set():
377  return _DUMMY_EAGER_GRAPH.learning_phase_is_set
378
379
380def _mark_func_graph_as_unsaveable(graph, learning_phase):
381  """Mark func graph as unsaveable due to use of symbolic keras learning phase.
382
383  Functions that capture the symbolic learning phase cannot be exported to
384  SavedModel. Mark the funcgraph as unsaveable, so that an error will be raised
385  if it is exported.
386
387  Args:
388    graph: Graph or FuncGraph object.
389    learning_phase: Learning phase placeholder or int defined in the graph.
390  """
391  if graph.building_function and is_placeholder(learning_phase):
392    graph.mark_as_unsaveable(
393        'The keras learning phase placeholder was used inside a function. '
394        'Exporting placeholders is not supported when saving out a SavedModel. '
395        'Please call `tf.keras.backend.set_learning_phase(0)` in the function '
396        'to set the learning phase to a constant value.')
397
398
399def symbolic_learning_phase():
400  graph = get_graph()
401  with graph.as_default():
402    return _GRAPH_LEARNING_PHASES[graph]
403
404
405def _default_learning_phase():
406  if context.executing_eagerly():
407    return 0
408  else:
409    with name_scope(''):
410      return array_ops.placeholder_with_default(
411          False, shape=(), name='keras_learning_phase')
412
413
414@keras_export('keras.backend.set_learning_phase')
415@doc_controls.do_not_generate_docs
416def set_learning_phase(value):
417  """Sets the learning phase to a fixed value.
418
419  The backend learning phase affects any code that calls
420  `backend.learning_phase()`
421  In particular, all Keras built-in layers use the learning phase as the default
422  for the `training` arg to `Layer.__call__`.
423
424  User-written layers and models can achieve the same behavior with code that
425  looks like:
426
427  ```python
428    def call(self, inputs, training=None):
429      if training is None:
430        training = backend.learning_phase()
431  ```
432
433  Args:
434      value: Learning phase value, either 0 or 1 (integers).
435             0 = test, 1 = train
436
437  Raises:
438      ValueError: if `value` is neither `0` nor `1`.
439  """
440  warnings.warn('`tf.keras.backend.set_learning_phase` is deprecated and '
441                'will be removed after 2020-10-11. To update it, simply '
442                'pass a True/False value to the `training` argument of the '
443                '`__call__` method of your layer or model.')
444  deprecated_internal_set_learning_phase(value)
445
446
447def deprecated_internal_set_learning_phase(value):
448  """A deprecated internal implementation of set_learning_phase.
449
450  This method is an internal-only version of `set_learning_phase` that
451  does not raise a deprecation error. It is required because
452  saved_model needs to keep working with user code that uses the deprecated
453  learning phase methods until those APIs are fully removed from the public API.
454
455  Specifically SavedModel saving needs to make sure the learning phase is 0
456  during tracing even if users overwrote it to a different value.
457
458  But, we don't want to raise deprecation warnings for users when savedmodel
459  sets learning phase just for compatibility with code that relied on
460  explicitly setting the learning phase for other values.
461
462  Args:
463      value: Learning phase value, either 0 or 1 (integers). 0 = test, 1 = train
464
465  Raises:
466      ValueError: if `value` is neither `0` nor `1`.
467  """
468  global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
469  if value not in {0, 1}:
470    raise ValueError('Expected learning phase to be 0 or 1.')
471  with ops.init_scope():
472    if context.executing_eagerly():
473      # In an eager context, the learning phase values applies to both the eager
474      # context and the internal Keras graph.
475      _DUMMY_EAGER_GRAPH.learning_phase_is_set = True
476      _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key] = value
477    _GRAPH_LEARNING_PHASES[get_graph()] = value
478
479
480@keras_export('keras.backend.learning_phase_scope')
481@tf_contextlib.contextmanager
482@doc_controls.do_not_generate_docs
483def learning_phase_scope(value):
484  """Provides a scope within which the learning phase is equal to `value`.
485
486  The learning phase gets restored to its original value upon exiting the scope.
487
488  Args:
489     value: Learning phase value, either 0 or 1 (integers).
490            0 = test, 1 = train
491
492  Yields:
493    None.
494
495  Raises:
496     ValueError: if `value` is neither `0` nor `1`.
497  """
498  warnings.warn('`tf.keras.backend.learning_phase_scope` is deprecated and '
499                'will be removed after 2020-10-11. To update it, simply '
500                'pass a True/False value to the `training` argument of the '
501                '`__call__` method of your layer or model.')
502  with deprecated_internal_learning_phase_scope(value):
503    try:
504      yield
505    finally:
506      pass
507
508
509@tf_contextlib.contextmanager
510def deprecated_internal_learning_phase_scope(value):
511  """An internal-only version of `learning_phase_scope`.
512
513  Unlike the public method, this method does not raise a deprecation warning.
514  This is needed because saved model saving needs to set learning phase
515  to maintain compatibility
516  with code that sets/gets the learning phase, but saved model
517  saving itself shouldn't raise a deprecation warning.
518
519  We can get rid of this method and its usages when the public API is
520  removed.
521
522  Args:
523     value: Learning phase value, either 0 or 1 (integers). 0 = test, 1 = train
524
525  Yields:
526    None.
527
528  Raises:
529     ValueError: if `value` is neither `0` nor `1`.
530  """
531  global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
532  if value not in {0, 1}:
533    raise ValueError('Expected learning phase to be 0 or 1.')
534
535  with ops.init_scope():
536    if context.executing_eagerly():
537      previous_eager_value = _GRAPH_LEARNING_PHASES.get(
538          _DUMMY_EAGER_GRAPH.key, None)
539    previous_graph_value = _GRAPH_LEARNING_PHASES.get(get_graph(), None)
540
541  learning_phase_previously_set = _DUMMY_EAGER_GRAPH.learning_phase_is_set
542  try:
543    deprecated_internal_set_learning_phase(value)
544    yield
545  finally:
546    # Restore learning phase to initial value.
547    if not learning_phase_previously_set:
548      _DUMMY_EAGER_GRAPH.learning_phase_is_set = False
549    with ops.init_scope():
550      if context.executing_eagerly():
551        if previous_eager_value is not None:
552          _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key] = previous_eager_value
553        elif _DUMMY_EAGER_GRAPH.key in _GRAPH_LEARNING_PHASES:
554          del _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key]
555
556      graph = get_graph()
557      if previous_graph_value is not None:
558        _GRAPH_LEARNING_PHASES[graph] = previous_graph_value
559      elif graph in _GRAPH_LEARNING_PHASES:
560        del _GRAPH_LEARNING_PHASES[graph]
561
562
563@tf_contextlib.contextmanager
564def eager_learning_phase_scope(value):
565  """Internal scope that sets the learning phase in eager / tf.function only.
566
567  Args:
568      value: Learning phase value, either 0 or 1 (integers).
569             0 = test, 1 = train
570
571  Yields:
572    None.
573
574  Raises:
575     ValueError: if `value` is neither `0` nor `1`.
576  """
577  global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
578  assert value in {0, 1}
579  assert ops.executing_eagerly_outside_functions()
580  global_learning_phase_was_set = global_learning_phase_is_set()
581  if global_learning_phase_was_set:
582    previous_value = learning_phase()
583  try:
584    _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key] = value
585    yield
586  finally:
587    # Restore learning phase to initial value or unset.
588    if global_learning_phase_was_set:
589      _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key] = previous_value
590    else:
591      del _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key]
592
593
594def _as_graph_element(obj):
595  """Convert `obj` to a graph element if possible, otherwise return `None`.
596
597  Args:
598    obj: Object to convert.
599
600  Returns:
601    The result of `obj._as_graph_element()` if that method is available;
602        otherwise `None`.
603  """
604  conv_fn = getattr(obj, '_as_graph_element', None)
605  if conv_fn and callable(conv_fn):
606    return conv_fn()
607  return None
608
609
610def _assert_same_graph(original_item, item):
611  """Fail if the 2 items are from different graphs.
612
613  Args:
614    original_item: Original item to check against.
615    item: Item to check.
616
617  Raises:
618    ValueError: if graphs do not match.
619  """
620  original_graph = getattr(original_item, 'graph', None)
621  graph = getattr(item, 'graph', None)
622  if original_graph and graph and original_graph is not graph:
623    raise ValueError(
624        '%s must be from the same graph as %s (graphs are %s and %s).' %
625        (item, original_item, graph, original_graph))
626
627
628def _current_graph(op_input_list, graph=None):
629  """Returns the appropriate graph to use for the given inputs.
630
631  This library method provides a consistent algorithm for choosing the graph
632  in which an Operation should be constructed:
633
634  1. If the default graph is being used to construct a function, we
635     use the default graph.
636  2. If the "graph" is specified explicitly, we validate that all of the inputs
637     in "op_input_list" are compatible with that graph.
638  3. Otherwise, we attempt to select a graph from the first Operation-
639     or Tensor-valued input in "op_input_list", and validate that all other
640     such inputs are in the same graph.
641  4. If the graph was not specified and it could not be inferred from
642     "op_input_list", we attempt to use the default graph.
643
644  Args:
645    op_input_list: A list of inputs to an operation, which may include `Tensor`,
646      `Operation`, and other objects that may be converted to a graph element.
647    graph: (Optional) The explicit graph to use.
648
649  Raises:
650    TypeError: If op_input_list is not a list or tuple, or if graph is not a
651      Graph.
652    ValueError: If a graph is explicitly passed and not all inputs are from it,
653      or if the inputs are from multiple graphs, or we could not find a graph
654      and there was no default graph.
655
656  Returns:
657    The appropriate graph to use for the given inputs.
658
659  """
660  current_default_graph = ops.get_default_graph()
661  if current_default_graph.building_function:
662    return current_default_graph
663
664  op_input_list = tuple(op_input_list)  # Handle generators correctly
665  if graph and not isinstance(graph, ops.Graph):
666    raise TypeError('Input graph needs to be a Graph: %s' % (graph,))
667
668  # 1. We validate that all of the inputs are from the same graph. This is
669  #    either the supplied graph parameter, or the first one selected from one
670  #    the graph-element-valued inputs. In the latter case, we hold onto
671  #    that input in original_graph_element so we can provide a more
672  #    informative error if a mismatch is found.
673  original_graph_element = None
674  for op_input in op_input_list:
675    # Determine if this is a valid graph_element.
676    # TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean this
677    # up.
678    if (isinstance(op_input, (
679        ops.Operation, ops.Tensor, composite_tensor.CompositeTensor)) and
680        ((not isinstance(op_input, ops.Tensor))
681         or type(op_input) == ops.Tensor)):  # pylint: disable=unidiomatic-typecheck
682      graph_element = op_input
683    else:
684      graph_element = _as_graph_element(op_input)
685
686    if graph_element is not None:
687      if not graph:
688        original_graph_element = graph_element
689        graph = getattr(graph_element, 'graph', None)
690      elif original_graph_element is not None:
691        _assert_same_graph(original_graph_element, graph_element)
692      elif graph_element.graph is not graph:
693        raise ValueError('%s is not from the passed-in graph.' % graph_element)
694
695  # 2. If all else fails, we use the default graph, which is always there.
696  return graph or current_default_graph
697
698
699def _get_session(op_input_list=()):
700  """Returns the session object for the current thread."""
701  global _SESSION
702  default_session = ops.get_default_session()
703  if default_session is not None:
704    session = default_session
705  else:
706    if ops.inside_function():
707      raise RuntimeError('Cannot get session inside Tensorflow graph function.')
708    # If we don't have a session, or that session does not match the current
709    # graph, create and cache a new session.
710    if (getattr(_SESSION, 'session', None) is None or
711        _SESSION.session.graph is not _current_graph(op_input_list)):
712      # If we are creating the Session inside a tf.distribute.Strategy scope,
713      # we ask the strategy for the right session options to use.
714      if distribution_strategy_context.has_strategy():
715        configure_and_create_distributed_session(
716            distribution_strategy_context.get_strategy())
717      else:
718        _SESSION.session = session_module.Session(
719            config=get_default_session_config())
720    session = _SESSION.session
721  return session
722
723
724@keras_export(v1=['keras.backend.get_session'])
725def get_session(op_input_list=()):
726  """Returns the TF session to be used by the backend.
727
728  If a default TensorFlow session is available, we will return it.
729
730  Else, we will return the global Keras session assuming it matches
731  the current graph.
732
733  If no global Keras session exists at this point:
734  we will create a new global session.
735
736  Note that you can manually set the global session
737  via `K.set_session(sess)`.
738
739  Args:
740      op_input_list: An option sequence of tensors or ops, which will be used
741        to determine the current graph. Otherwise the default graph will be
742        used.
743
744  Returns:
745      A TensorFlow session.
746  """
747  session = _get_session(op_input_list)
748  if not _MANUAL_VAR_INIT:
749    with session.graph.as_default():
750      _initialize_variables(session)
751  return session
752
753# Inject the get_session function to keras_deps to remove the dependency
754# from TFLite to Keras.
755keras_deps.register_get_session_function(get_session)
756
757# Inject the get_session function to tracking_util to avoid the backward
758# dependency from TF to Keras.
759tracking_util.register_session_provider(get_session)
760
761
762def get_graph():
763  if context.executing_eagerly():
764    global _GRAPH
765    if not getattr(_GRAPH, 'graph', None):
766      _GRAPH.graph = func_graph.FuncGraph('keras_graph')
767    return _GRAPH.graph
768  else:
769    return ops.get_default_graph()
770
771
772@tf_contextlib.contextmanager
773def _scratch_graph(graph=None):
774  """Retrieve a shared and temporary func graph.
775
776  The eager execution path lifts a subgraph from the keras global graph into
777  a scratch graph in order to create a function. DistributionStrategies, in
778  turn, constructs multiple functions as well as a final combined function. In
779  order for that logic to work correctly, all of the functions need to be
780  created on the same scratch FuncGraph.
781
782  Args:
783    graph: A graph to be used as the current scratch graph. If not set then
784      a scratch graph will either be retrieved or created:
785
786  Yields:
787    The current scratch graph.
788  """
789  global _CURRENT_SCRATCH_GRAPH
790  scratch_graph = getattr(_CURRENT_SCRATCH_GRAPH, 'graph', None)
791  # If scratch graph and `graph` are both configured, they must match.
792  if (scratch_graph is not None and graph is not None and
793      scratch_graph is not graph):
794    raise ValueError('Multiple scratch graphs specified.')
795
796  if scratch_graph:
797    yield scratch_graph
798    return
799
800  graph = graph or func_graph.FuncGraph('keras_scratch_graph')
801  try:
802    _CURRENT_SCRATCH_GRAPH.graph = graph
803    yield graph
804  finally:
805    _CURRENT_SCRATCH_GRAPH.graph = None
806
807
808@keras_export(v1=['keras.backend.set_session'])
809def set_session(session):
810  """Sets the global TensorFlow session.
811
812  Args:
813      session: A TF Session.
814  """
815  global _SESSION
816  _SESSION.session = session
817
818
819def get_default_session_config():
820  if os.environ.get('OMP_NUM_THREADS'):
821    logging.warning(
822        'OMP_NUM_THREADS is no longer used by the default Keras config. '
823        'To configure the number of threads, use tf.config.threading APIs.')
824
825  config = get_config()
826  config.allow_soft_placement = True
827
828  return config
829
830
831def get_default_graph_uid_map():
832  graph = ops.get_default_graph()
833  name_uid_map = PER_GRAPH_OBJECT_NAME_UIDS.get(graph, None)
834  if name_uid_map is None:
835    name_uid_map = collections.defaultdict(int)
836    PER_GRAPH_OBJECT_NAME_UIDS[graph] = name_uid_map
837  return name_uid_map
838
839
840# DEVICE MANIPULATION
841
842
843class _TfDeviceCaptureOp(object):
844  """Class for capturing the TF device scope."""
845
846  def __init__(self):
847    self.device = None
848
849  def _set_device(self, device):
850    """This method captures TF's explicit device scope setting."""
851    if isinstance(device, device_spec.DeviceSpecV2):
852      device = device.to_string()
853    self.device = device
854
855  def _set_device_from_string(self, device_str):
856    self.device = device_str
857
858
859def _get_current_tf_device():
860  """Return explicit device of current context, otherwise returns `None`.
861
862  Returns:
863      If the current device scope is explicitly set, it returns a string with
864      the device (`CPU` or `GPU`). If the scope is not explicitly set, it will
865      return `None`.
866  """
867  graph = get_graph()
868  op = _TfDeviceCaptureOp()
869  graph._apply_device_functions(op)
870  if tf2.enabled():
871    return device_spec.DeviceSpecV2.from_string(op.device)
872  else:
873    return device_spec.DeviceSpecV1.from_string(op.device)
874
875
876def _is_current_explicit_device(device_type):
877  """Check if the current device is explicitly set on the device type specified.
878
879  Args:
880      device_type: A string containing `GPU` or `CPU` (case-insensitive).
881
882  Returns:
883      A boolean indicating if the current device scope is explicitly set on the
884      device type.
885
886  Raises:
887      ValueError: If the `device_type` string indicates an unsupported device.
888  """
889  device_type = device_type.upper()
890  if device_type not in ['CPU', 'GPU']:
891    raise ValueError('`device_type` should be either "CPU" or "GPU".')
892  device = _get_current_tf_device()
893  return device is not None and device.device_type == device_type.upper()
894
895
896def _get_available_gpus():
897  """Get a list of available GPU devices (formatted as strings).
898
899  Returns:
900      A list of available GPU devices.
901  """
902  if ops.executing_eagerly_outside_functions():
903    # Returns names of devices directly.
904    return [d.name for d in config.list_logical_devices('GPU')]
905
906  global _LOCAL_DEVICES
907  if _LOCAL_DEVICES is None:
908    _LOCAL_DEVICES = get_session().list_devices()
909  return [x.name for x in _LOCAL_DEVICES if x.device_type == 'GPU']
910
911
912def _has_nchw_support():
913  """Check whether the current scope supports NCHW ops.
914
915  TensorFlow does not support NCHW on CPU. Therefore we check if we are not
916  explicitly put on
917  CPU, and have GPUs available. In this case there will be soft-placing on the
918  GPU device.
919
920  Returns:
921      bool: if the current scope device placement would support nchw
922  """
923  explicitly_on_cpu = _is_current_explicit_device('CPU')
924  gpus_available = bool(_get_available_gpus())
925  return not explicitly_on_cpu and gpus_available
926
927
928# VARIABLE MANIPULATION
929
930
931def _constant_to_tensor(x, dtype):
932  """Convert the input `x` to a tensor of type `dtype`.
933
934  This is slightly faster than the _to_tensor function, at the cost of
935  handling fewer cases.
936
937  Args:
938      x: An object to be converted (numpy arrays, floats, ints and lists of
939        them).
940      dtype: The destination type.
941
942  Returns:
943      A tensor.
944  """
945  return constant_op.constant(x, dtype=dtype)
946
947
948def _to_tensor(x, dtype):
949  """Convert the input `x` to a tensor of type `dtype`.
950
951  Args:
952      x: An object to be converted (numpy array, list, tensors).
953      dtype: The destination type.
954
955  Returns:
956      A tensor.
957  """
958  return ops.convert_to_tensor_v2_with_dispatch(x, dtype=dtype)
959
960
961@keras_export('keras.backend.is_sparse')
962@doc_controls.do_not_generate_docs
963def is_sparse(tensor):
964  """Returns whether a tensor is a sparse tensor.
965
966  Args:
967      tensor: A tensor instance.
968
969  Returns:
970      A boolean.
971
972  Example:
973
974
975  >>> a = tf.keras.backend.placeholder((2, 2), sparse=False)
976  >>> print(tf.keras.backend.is_sparse(a))
977  False
978  >>> b = tf.keras.backend.placeholder((2, 2), sparse=True)
979  >>> print(tf.keras.backend.is_sparse(b))
980  True
981
982  """
983  spec = getattr(tensor, '_type_spec', None)
984  if spec is not None:
985    return isinstance(spec, sparse_tensor.SparseTensorSpec)
986  return isinstance(tensor, sparse_tensor.SparseTensor)
987
988
989@keras_export('keras.backend.to_dense')
990@dispatch.add_dispatch_support
991@doc_controls.do_not_generate_docs
992def to_dense(tensor):
993  """Converts a sparse tensor into a dense tensor and returns it.
994
995  Args:
996      tensor: A tensor instance (potentially sparse).
997
998  Returns:
999      A dense tensor.
1000
1001  Examples:
1002
1003
1004  >>> b = tf.keras.backend.placeholder((2, 2), sparse=True)
1005  >>> print(tf.keras.backend.is_sparse(b))
1006  True
1007  >>> c = tf.keras.backend.to_dense(b)
1008  >>> print(tf.keras.backend.is_sparse(c))
1009  False
1010
1011  """
1012  if is_sparse(tensor):
1013    return sparse_ops.sparse_tensor_to_dense(tensor)
1014  else:
1015    return tensor
1016
1017
1018@keras_export('keras.backend.name_scope', v1=[])
1019@doc_controls.do_not_generate_docs
1020def name_scope(name):
1021  """A context manager for use when defining a Python op.
1022
1023  This context manager pushes a name scope, which will make the name of all
1024  operations added within it have a prefix.
1025
1026  For example, to define a new Python op called `my_op`:
1027
1028
1029  def my_op(a):
1030    with tf.name_scope("MyOp") as scope:
1031      a = tf.convert_to_tensor(a, name="a")
1032      # Define some computation that uses `a`.
1033      return foo_op(..., name=scope)
1034
1035
1036  When executed, the Tensor `a` will have the name `MyOp/a`.
1037
1038  Args:
1039    name: The prefix to use on all names created within the name scope.
1040
1041  Returns:
1042    Name scope context manager.
1043  """
1044  return ops.name_scope_v2(name)
1045
1046# Export V1 version.
1047keras_export(v1=['keras.backend.name_scope'])(ops.name_scope_v1)
1048
1049
1050@keras_export('keras.backend.variable')
1051@doc_controls.do_not_generate_docs
1052def variable(value, dtype=None, name=None, constraint=None):
1053  """Instantiates a variable and returns it.
1054
1055  Args:
1056      value: Numpy array, initial value of the tensor.
1057      dtype: Tensor type.
1058      name: Optional name string for the tensor.
1059      constraint: Optional projection function to be
1060          applied to the variable after an optimizer update.
1061
1062  Returns:
1063      A variable instance (with Keras metadata included).
1064
1065  Examples:
1066
1067  >>> val = np.array([[1, 2], [3, 4]])
1068  >>> kvar = tf.keras.backend.variable(value=val, dtype='float64',
1069  ...                                  name='example_var')
1070  >>> tf.keras.backend.dtype(kvar)
1071  'float64'
1072  >>> print(kvar)
1073  <tf.Variable 'example_var:...' shape=(2, 2) dtype=float64, numpy=
1074    array([[1., 2.],
1075           [3., 4.]])>
1076
1077  """
1078  if dtype is None:
1079    dtype = floatx()
1080  if hasattr(value, 'tocoo'):
1081    sparse_coo = value.tocoo()
1082    indices = np.concatenate((np.expand_dims(sparse_coo.row, 1), np.expand_dims(
1083        sparse_coo.col, 1)), 1)
1084    v = sparse_tensor.SparseTensor(
1085        indices=indices, values=sparse_coo.data, dense_shape=sparse_coo.shape)
1086    v._keras_shape = sparse_coo.shape
1087    return v
1088  v = variables_module.Variable(
1089      value,
1090      dtype=dtypes_module.as_dtype(dtype),
1091      name=name,
1092      constraint=constraint)
1093  if isinstance(value, np.ndarray):
1094    v._keras_shape = value.shape
1095  elif hasattr(value, 'shape'):
1096    v._keras_shape = int_shape(value)
1097  track_variable(v)
1098  return v
1099
1100
1101def track_tf_optimizer(tf_optimizer):
1102  """Tracks the given TF optimizer for initialization of its variables."""
1103  if context.executing_eagerly():
1104    return
1105  optimizers = _GRAPH_TF_OPTIMIZERS[None]
1106  optimizers.add(tf_optimizer)
1107
1108
1109def track_variable(v):
1110  """Tracks the given variable for initialization."""
1111  if context.executing_eagerly():
1112    return
1113  graph = v.graph if hasattr(v, 'graph') else get_graph()
1114  _GRAPH_VARIABLES[graph].add(v)
1115
1116
1117def observe_object_name(name):
1118  """Observe a name and make sure it won't be used by `unique_object_name`."""
1119  OBSERVED_NAMES.add(name)
1120
1121
1122def unique_object_name(name,
1123                       name_uid_map=None,
1124                       avoid_names=None,
1125                       namespace='',
1126                       zero_based=False,
1127                       avoid_observed_names=False):
1128  """Makes a object name (or arbitrary string) unique within a TensorFlow graph.
1129
1130  Args:
1131    name: String name to make unique.
1132    name_uid_map: An optional defaultdict(int) to use when creating unique
1133      names. If None (default), uses a per-Graph dictionary.
1134    avoid_names: An optional set or dict with names which should not be used. If
1135      None (default), don't avoid any names unless `avoid_observed_names` is
1136      True.
1137    namespace: Gets a name which is unique within the (graph, namespace). Layers
1138      which are not Networks use a blank namespace and so get graph-global
1139      names.
1140    zero_based: If True, name sequences start with no suffix (e.g. "dense",
1141      "dense_1"). If False, naming is one-based ("dense_1", "dense_2").
1142    avoid_observed_names: If True, avoid any names that have been observed by
1143      `backend.observe_object_name`.
1144
1145  Returns:
1146    Unique string name.
1147
1148  Example:
1149
1150
1151  unique_object_name('dense')  # dense_1
1152  unique_object_name('dense')  # dense_2
1153
1154  """
1155  if name_uid_map is None:
1156    name_uid_map = get_default_graph_uid_map()
1157  if avoid_names is None:
1158    if avoid_observed_names:
1159      avoid_names = OBSERVED_NAMES
1160    else:
1161      avoid_names = set()
1162  proposed_name = None
1163  while proposed_name is None or proposed_name in avoid_names:
1164    name_key = (namespace, name)
1165    if zero_based:
1166      number = name_uid_map[name_key]
1167      if number:
1168        proposed_name = name + '_' + str(number)
1169      else:
1170        proposed_name = name
1171      name_uid_map[name_key] += 1
1172    else:
1173      name_uid_map[name_key] += 1
1174      proposed_name = name + '_' + str(name_uid_map[name_key])
1175  return proposed_name
1176
1177
1178def _get_variables(graph=None):
1179  """Returns variables corresponding to the given graph for initialization."""
1180  assert not context.executing_eagerly()
1181  variables = _GRAPH_VARIABLES[graph]
1182  for opt in _GRAPH_TF_OPTIMIZERS[graph]:
1183    variables.update(opt.optimizer.variables())
1184  return variables
1185
1186
1187def _initialize_variables(session):
1188  """Utility to initialize uninitialized variables on the fly."""
1189  variables = _get_variables(get_graph())
1190  candidate_vars = []
1191  for v in variables:
1192    if not getattr(v, '_keras_initialized', False):
1193      candidate_vars.append(v)
1194  if candidate_vars:
1195    # This step is expensive, so we only run it on variables not already
1196    # marked as initialized.
1197    is_initialized = session.run(
1198        [variables_module.is_variable_initialized(v) for v in candidate_vars])
1199    # TODO(kathywu): Some metric variables loaded from SavedModel are never
1200    # actually used, and do not have an initializer.
1201    should_be_initialized = [
1202        (not is_initialized[n]) and v.initializer is not None
1203        for n, v in enumerate(candidate_vars)]
1204    uninitialized_vars = []
1205    for flag, v in zip(should_be_initialized, candidate_vars):
1206      if flag:
1207        uninitialized_vars.append(v)
1208      v._keras_initialized = True
1209    if uninitialized_vars:
1210      session.run(variables_module.variables_initializer(uninitialized_vars))
1211
1212
1213@keras_export('keras.backend.constant')
1214@dispatch.add_dispatch_support
1215@doc_controls.do_not_generate_docs
1216def constant(value, dtype=None, shape=None, name=None):
1217  """Creates a constant tensor.
1218
1219  Args:
1220      value: A constant value (or list)
1221      dtype: The type of the elements of the resulting tensor.
1222      shape: Optional dimensions of resulting tensor.
1223      name: Optional name for the tensor.
1224
1225  Returns:
1226      A Constant Tensor.
1227  """
1228  if dtype is None:
1229    dtype = floatx()
1230
1231  return constant_op.constant(value, dtype=dtype, shape=shape, name=name)
1232
1233
1234@keras_export('keras.backend.is_keras_tensor')
1235def is_keras_tensor(x):
1236  """Returns whether `x` is a Keras tensor.
1237
1238  A "Keras tensor" is a tensor that was returned by a Keras layer,
1239  (`Layer` class) or by `Input`.
1240
1241  Args:
1242      x: A candidate tensor.
1243
1244  Returns:
1245      A boolean: Whether the argument is a Keras tensor.
1246
1247  Raises:
1248      ValueError: In case `x` is not a symbolic tensor.
1249
1250  Examples:
1251
1252  >>> np_var = np.array([1, 2])
1253  >>> # A numpy array is not a symbolic tensor.
1254  >>> tf.keras.backend.is_keras_tensor(np_var)
1255  Traceback (most recent call last):
1256  ...
1257  ValueError: Unexpectedly found an instance of type `<class 'numpy.ndarray'>`.
1258  Expected a symbolic tensor instance.
1259  >>> keras_var = tf.keras.backend.variable(np_var)
1260  >>> # A variable created with the keras backend is not a Keras tensor.
1261  >>> tf.keras.backend.is_keras_tensor(keras_var)
1262  False
1263  >>> keras_placeholder = tf.keras.backend.placeholder(shape=(2, 4, 5))
1264  >>> # A placeholder is a Keras tensor.
1265  >>> tf.keras.backend.is_keras_tensor(keras_placeholder)
1266  True
1267  >>> keras_input = tf.keras.layers.Input([10])
1268  >>> # An Input is a Keras tensor.
1269  >>> tf.keras.backend.is_keras_tensor(keras_input)
1270  True
1271  >>> keras_layer_output = tf.keras.layers.Dense(10)(keras_input)
1272  >>> # Any Keras layer output is a Keras tensor.
1273  >>> tf.keras.backend.is_keras_tensor(keras_layer_output)
1274  True
1275
1276  """
1277  if not isinstance(x,
1278                    (ops.Tensor, variables_module.Variable,
1279                     sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor,
1280                     keras_tensor.KerasTensor)):
1281    raise ValueError('Unexpectedly found an instance of type `' + str(type(x)) +
1282                     '`. Expected a symbolic tensor instance.')
1283  if keras_tensor.keras_tensors_enabled():
1284    return isinstance(x, keras_tensor.KerasTensor)
1285  return hasattr(x, '_keras_history')
1286
1287
1288@keras_export('keras.backend.placeholder')
1289@doc_controls.do_not_generate_docs
1290def placeholder(shape=None,
1291                ndim=None,
1292                dtype=None,
1293                sparse=False,
1294                name=None,
1295                ragged=False):
1296  """Instantiates a placeholder tensor and returns it.
1297
1298  Args:
1299      shape: Shape of the placeholder
1300          (integer tuple, may include `None` entries).
1301      ndim: Number of axes of the tensor.
1302          At least one of {`shape`, `ndim`} must be specified.
1303          If both are specified, `shape` is used.
1304      dtype: Placeholder type.
1305      sparse: Boolean, whether the placeholder should have a sparse type.
1306      name: Optional name string for the placeholder.
1307      ragged: Boolean, whether the placeholder should have a ragged type.
1308          In this case, values of 'None' in the 'shape' argument represent
1309          ragged dimensions. For more information about RaggedTensors, see this
1310          [guide](https://www.tensorflow.org/guide/ragged_tensors).
1311
1312  Raises:
1313      ValueError: If called with eager execution
1314      ValueError: If called with sparse = True and ragged = True.
1315
1316  Returns:
1317      Tensor instance (with Keras metadata included).
1318
1319  Examples:
1320
1321
1322  >>> input_ph = tf.keras.backend.placeholder(shape=(2, 4, 5))
1323  >>> input_ph
1324  <KerasTensor: shape=(2, 4, 5) dtype=float32 (created by layer ...)>
1325
1326  """
1327  if sparse and ragged:
1328    raise ValueError(
1329        'Cannot set both sparse and ragged to True when creating a placeholder.'
1330    )
1331  if dtype is None:
1332    dtype = floatx()
1333  if not shape:
1334    if ndim:
1335      shape = (None,) * ndim
1336  if keras_tensor.keras_tensors_enabled():
1337    if sparse:
1338      spec = sparse_tensor.SparseTensorSpec(
1339          shape=shape, dtype=dtype)
1340    elif ragged:
1341      ragged_rank = 0
1342      for i in range(1, len(shape)):
1343        # Hacky because could be tensorshape or tuple maybe?
1344        # Or just tensorshape?
1345        if shape[i] is None or (
1346            hasattr(shape[i], 'value') and
1347            shape[i].value is None):
1348          ragged_rank = i
1349      spec = ragged_tensor.RaggedTensorSpec(
1350          shape=shape, dtype=dtype, ragged_rank=ragged_rank)
1351    else:
1352      spec = tensor_spec.TensorSpec(
1353          shape=shape, dtype=dtype, name=name)
1354    x = keras_tensor.keras_tensor_from_type_spec(spec, name=name)
1355  else:
1356    with get_graph().as_default():
1357      if sparse:
1358        x = array_ops.sparse_placeholder(dtype, shape=shape, name=name)
1359      elif ragged:
1360        ragged_rank = 0
1361        for i in range(1, len(shape)):
1362          if shape[i] is None:
1363            ragged_rank = i
1364        type_spec = ragged_tensor.RaggedTensorSpec(
1365            shape=shape, dtype=dtype, ragged_rank=ragged_rank)
1366        def tensor_spec_to_placeholder(tensorspec):
1367          return array_ops.placeholder(tensorspec.dtype, tensorspec.shape)
1368        x = nest.map_structure(tensor_spec_to_placeholder, type_spec,
1369                               expand_composites=True)
1370      else:
1371        x = array_ops.placeholder(dtype, shape=shape, name=name)
1372
1373  if context.executing_eagerly():
1374    # Add keras_history connectivity information to the placeholder
1375    # when the placeholder is built in a top-level eager context
1376    # (intended to be used with keras.backend.function)
1377    from tensorflow.python.keras.engine import input_layer  # pylint: disable=g-import-not-at-top
1378    x = input_layer.Input(tensor=x)
1379    if keras_tensor.keras_tensors_enabled():
1380      x._is_backend_placeholder = True
1381
1382  return x
1383
1384
1385def is_placeholder(x):
1386  """Returns whether `x` is a placeholder.
1387
1388  Args:
1389      x: A candidate placeholder.
1390
1391  Returns:
1392      Boolean.
1393  """
1394  try:
1395    if keras_tensor.keras_tensors_enabled():
1396      return hasattr(x, '_is_backend_placeholder')
1397    from tensorflow.python.keras.utils import tf_utils  # pylint: disable=g-import-not-at-top
1398    if tf_utils.is_extension_type(x):
1399      flat_components = nest.flatten(x, expand_composites=True)
1400      return py_any(is_placeholder(c) for c in flat_components)
1401    else:
1402      return x.op.type == 'Placeholder'
1403  except AttributeError:
1404    return False
1405
1406
1407@keras_export('keras.backend.shape')
1408@dispatch.add_dispatch_support
1409@doc_controls.do_not_generate_docs
1410def shape(x):
1411  """Returns the symbolic shape of a tensor or variable.
1412
1413  Args:
1414      x: A tensor or variable.
1415
1416  Returns:
1417      A symbolic shape (which is itself a tensor).
1418
1419  Examples:
1420
1421  >>> val = np.array([[1, 2], [3, 4]])
1422  >>> kvar = tf.keras.backend.variable(value=val)
1423  >>> tf.keras.backend.shape(kvar)
1424  <tf.Tensor: shape=(2,), dtype=int32, numpy=array([2, 2], dtype=int32)>
1425  >>> input = tf.keras.backend.placeholder(shape=(2, 4, 5))
1426  >>> tf.keras.backend.shape(input)
1427  <KerasTensor: shape=(3,) dtype=int32 inferred_value=[2, 4, 5] ...>
1428
1429  """
1430  return array_ops.shape(x)
1431
1432
1433@keras_export('keras.backend.int_shape')
1434@doc_controls.do_not_generate_docs
1435def int_shape(x):
1436  """Returns the shape of tensor or variable as a tuple of int or None entries.
1437
1438  Args:
1439      x: Tensor or variable.
1440
1441  Returns:
1442      A tuple of integers (or None entries).
1443
1444  Examples:
1445
1446  >>> input = tf.keras.backend.placeholder(shape=(2, 4, 5))
1447  >>> tf.keras.backend.int_shape(input)
1448  (2, 4, 5)
1449  >>> val = np.array([[1, 2], [3, 4]])
1450  >>> kvar = tf.keras.backend.variable(value=val)
1451  >>> tf.keras.backend.int_shape(kvar)
1452  (2, 2)
1453
1454  """
1455  try:
1456    shape = x.shape
1457    if not isinstance(shape, tuple):
1458      shape = tuple(shape.as_list())
1459    return shape
1460  except ValueError:
1461    return None
1462
1463
1464@keras_export('keras.backend.ndim')
1465@doc_controls.do_not_generate_docs
1466def ndim(x):
1467  """Returns the number of axes in a tensor, as an integer.
1468
1469  Args:
1470      x: Tensor or variable.
1471
1472  Returns:
1473      Integer (scalar), number of axes.
1474
1475  Examples:
1476
1477
1478  >>> input = tf.keras.backend.placeholder(shape=(2, 4, 5))
1479  >>> val = np.array([[1, 2], [3, 4]])
1480  >>> kvar = tf.keras.backend.variable(value=val)
1481  >>> tf.keras.backend.ndim(input)
1482  3
1483  >>> tf.keras.backend.ndim(kvar)
1484  2
1485
1486  """
1487  return x.shape.rank
1488
1489
1490@keras_export('keras.backend.dtype')
1491@dispatch.add_dispatch_support
1492@doc_controls.do_not_generate_docs
1493def dtype(x):
1494  """Returns the dtype of a Keras tensor or variable, as a string.
1495
1496  Args:
1497      x: Tensor or variable.
1498
1499  Returns:
1500      String, dtype of `x`.
1501
1502  Examples:
1503
1504  >>> tf.keras.backend.dtype(tf.keras.backend.placeholder(shape=(2,4,5)))
1505  'float32'
1506  >>> tf.keras.backend.dtype(tf.keras.backend.placeholder(shape=(2,4,5),
1507  ...                                                     dtype='float32'))
1508  'float32'
1509  >>> tf.keras.backend.dtype(tf.keras.backend.placeholder(shape=(2,4,5),
1510  ...                                                     dtype='float64'))
1511  'float64'
1512  >>> kvar = tf.keras.backend.variable(np.array([[1, 2], [3, 4]]))
1513  >>> tf.keras.backend.dtype(kvar)
1514  'float32'
1515  >>> kvar = tf.keras.backend.variable(np.array([[1, 2], [3, 4]]),
1516  ...                                  dtype='float32')
1517  >>> tf.keras.backend.dtype(kvar)
1518  'float32'
1519
1520  """
1521  return x.dtype.base_dtype.name
1522
1523
1524@doc_controls.do_not_generate_docs
1525def dtype_numpy(x):
1526  """Returns the numpy dtype of a Keras tensor or variable.
1527
1528  Args:
1529      x: Tensor or variable.
1530
1531  Returns:
1532      numpy.dtype, dtype of `x`.
1533  """
1534  return dtypes_module.as_dtype(x.dtype).as_numpy_dtype
1535
1536
1537@keras_export('keras.backend.eval')
1538@doc_controls.do_not_generate_docs
1539def eval(x):
1540  """Evaluates the value of a variable.
1541
1542  Args:
1543      x: A variable.
1544
1545  Returns:
1546      A Numpy array.
1547
1548  Examples:
1549
1550  >>> kvar = tf.keras.backend.variable(np.array([[1, 2], [3, 4]]),
1551  ...                                  dtype='float32')
1552  >>> tf.keras.backend.eval(kvar)
1553  array([[1.,  2.],
1554         [3.,  4.]], dtype=float32)
1555
1556  """
1557  return get_value(to_dense(x))
1558
1559
1560@keras_export('keras.backend.zeros')
1561@doc_controls.do_not_generate_docs
1562def zeros(shape, dtype=None, name=None):
1563  """Instantiates an all-zeros variable and returns it.
1564
1565  Args:
1566      shape: Tuple or list of integers, shape of returned Keras variable
1567      dtype: data type of returned Keras variable
1568      name: name of returned Keras variable
1569
1570  Returns:
1571      A variable (including Keras metadata), filled with `0.0`.
1572      Note that if `shape` was symbolic, we cannot return a variable,
1573      and will return a dynamically-shaped tensor instead.
1574
1575  Example:
1576
1577  >>> kvar = tf.keras.backend.zeros((3,4))
1578  >>> tf.keras.backend.eval(kvar)
1579  array([[0.,  0.,  0.,  0.],
1580         [0.,  0.,  0.,  0.],
1581         [0.,  0.,  0.,  0.]], dtype=float32)
1582  >>> A = tf.constant([1,2,3])
1583  >>> kvar2 = tf.keras.backend.zeros(A.shape) # [0., 0., 0.]
1584  >>> tf.keras.backend.eval(kvar2)
1585  array([0., 0., 0.], dtype=float32)
1586  >>> kvar3 = tf.keras.backend.zeros(A.shape,dtype=tf.int32)
1587  >>> tf.keras.backend.eval(kvar3)
1588  array([0, 0, 0], dtype=int32)
1589  >>> kvar4 = tf.keras.backend.zeros([2,3])
1590  >>> tf.keras.backend.eval(kvar4)
1591  array([[0., 0., 0.],
1592         [0., 0., 0.]], dtype=float32)
1593
1594  """
1595  with ops.init_scope():
1596    if dtype is None:
1597      dtype = floatx()
1598    tf_dtype = dtypes_module.as_dtype(dtype)
1599    v = array_ops.zeros(shape=shape, dtype=tf_dtype, name=name)
1600    if py_all(v.shape.as_list()):
1601      return variable(v, dtype=dtype, name=name)
1602    return v
1603
1604
1605@keras_export('keras.backend.ones')
1606@dispatch.add_dispatch_support
1607@doc_controls.do_not_generate_docs
1608def ones(shape, dtype=None, name=None):
1609  """Instantiates an all-ones variable and returns it.
1610
1611  Args:
1612      shape: Tuple of integers, shape of returned Keras variable.
1613      dtype: String, data type of returned Keras variable.
1614      name: String, name of returned Keras variable.
1615
1616  Returns:
1617      A Keras variable, filled with `1.0`.
1618      Note that if `shape` was symbolic, we cannot return a variable,
1619      and will return a dynamically-shaped tensor instead.
1620
1621  Example:
1622
1623
1624  >>> kvar = tf.keras.backend.ones((3,4))
1625  >>> tf.keras.backend.eval(kvar)
1626  array([[1.,  1.,  1.,  1.],
1627         [1.,  1.,  1.,  1.],
1628         [1.,  1.,  1.,  1.]], dtype=float32)
1629
1630  """
1631  with ops.init_scope():
1632    if dtype is None:
1633      dtype = floatx()
1634    tf_dtype = dtypes_module.as_dtype(dtype)
1635    v = array_ops.ones(shape=shape, dtype=tf_dtype, name=name)
1636    if py_all(v.shape.as_list()):
1637      return variable(v, dtype=dtype, name=name)
1638    return v
1639
1640
1641@keras_export('keras.backend.eye')
1642@dispatch.add_dispatch_support
1643@doc_controls.do_not_generate_docs
1644def eye(size, dtype=None, name=None):
1645  """Instantiate an identity matrix and returns it.
1646
1647  Args:
1648      size: Integer, number of rows/columns.
1649      dtype: String, data type of returned Keras variable.
1650      name: String, name of returned Keras variable.
1651
1652  Returns:
1653      A Keras variable, an identity matrix.
1654
1655  Example:
1656
1657
1658  >>> kvar = tf.keras.backend.eye(3)
1659  >>> tf.keras.backend.eval(kvar)
1660  array([[1.,  0.,  0.],
1661         [0.,  1.,  0.],
1662         [0.,  0.,  1.]], dtype=float32)
1663
1664
1665  """
1666  if dtype is None:
1667    dtype = floatx()
1668  tf_dtype = dtypes_module.as_dtype(dtype)
1669  return variable(linalg_ops.eye(size, dtype=tf_dtype), dtype, name)
1670
1671
1672@keras_export('keras.backend.zeros_like')
1673@doc_controls.do_not_generate_docs
1674def zeros_like(x, dtype=None, name=None):
1675  """Instantiates an all-zeros variable of the same shape as another tensor.
1676
1677  Args:
1678      x: Keras variable or Keras tensor.
1679      dtype: dtype of returned Keras variable.
1680             `None` uses the dtype of `x`.
1681      name: name for the variable to create.
1682
1683  Returns:
1684      A Keras variable with the shape of `x` filled with zeros.
1685
1686  Example:
1687
1688  ```python
1689  from tensorflow.keras import backend as K
1690  kvar = K.variable(np.random.random((2,3)))
1691  kvar_zeros = K.zeros_like(kvar)
1692  K.eval(kvar_zeros)
1693  # array([[ 0.,  0.,  0.], [ 0.,  0.,  0.]], dtype=float32)
1694  ```
1695  """
1696  return array_ops.zeros_like(x, dtype=dtype, name=name)
1697
1698
1699@keras_export('keras.backend.ones_like')
1700@dispatch.add_dispatch_support
1701@doc_controls.do_not_generate_docs
1702def ones_like(x, dtype=None, name=None):
1703  """Instantiates an all-ones variable of the same shape as another tensor.
1704
1705  Args:
1706      x: Keras variable or tensor.
1707      dtype: String, dtype of returned Keras variable.
1708           None uses the dtype of x.
1709      name: String, name for the variable to create.
1710
1711  Returns:
1712      A Keras variable with the shape of x filled with ones.
1713
1714  Example:
1715
1716  >>> kvar = tf.keras.backend.variable(np.random.random((2,3)))
1717  >>> kvar_ones = tf.keras.backend.ones_like(kvar)
1718  >>> tf.keras.backend.eval(kvar_ones)
1719  array([[1.,  1.,  1.],
1720         [1.,  1.,  1.]], dtype=float32)
1721
1722  """
1723  return array_ops.ones_like(x, dtype=dtype, name=name)
1724
1725
1726def identity(x, name=None):
1727  """Returns a tensor with the same content as the input tensor.
1728
1729  Args:
1730      x: The input tensor.
1731      name: String, name for the variable to create.
1732
1733  Returns:
1734      A tensor of the same shape, type and content.
1735  """
1736  return array_ops.identity(x, name=name)
1737
1738
1739@keras_export('keras.backend.random_uniform_variable')
1740@doc_controls.do_not_generate_docs
1741def random_uniform_variable(shape, low, high, dtype=None, name=None, seed=None):
1742  """Instantiates a variable with values drawn from a uniform distribution.
1743
1744  Args:
1745      shape: Tuple of integers, shape of returned Keras variable.
1746      low: Float, lower boundary of the output interval.
1747      high: Float, upper boundary of the output interval.
1748      dtype: String, dtype of returned Keras variable.
1749      name: String, name of returned Keras variable.
1750      seed: Integer, random seed.
1751
1752  Returns:
1753      A Keras variable, filled with drawn samples.
1754
1755  Example:
1756
1757  >>> kvar = tf.keras.backend.random_uniform_variable(shape=(2,3),
1758  ... low=0.0, high=1.0)
1759  >>> kvar
1760  <tf.Variable 'Variable:0' shape=(2, 3) dtype=float32, numpy=...,
1761  dtype=float32)>
1762  """
1763  if dtype is None:
1764    dtype = floatx()
1765  tf_dtype = dtypes_module.as_dtype(dtype)
1766  if seed is None:
1767    # ensure that randomness is conditioned by the Numpy RNG
1768    seed = np.random.randint(10e8)
1769  value = init_ops.random_uniform_initializer(
1770      low, high, dtype=tf_dtype, seed=seed)(shape)
1771  return variable(value, dtype=dtype, name=name)
1772
1773
1774@keras_export('keras.backend.random_normal_variable')
1775@doc_controls.do_not_generate_docs
1776def random_normal_variable(shape, mean, scale, dtype=None, name=None,
1777                           seed=None):
1778  """Instantiates a variable with values drawn from a normal distribution.
1779
1780  Args:
1781      shape: Tuple of integers, shape of returned Keras variable.
1782      mean: Float, mean of the normal distribution.
1783      scale: Float, standard deviation of the normal distribution.
1784      dtype: String, dtype of returned Keras variable.
1785      name: String, name of returned Keras variable.
1786      seed: Integer, random seed.
1787
1788  Returns:
1789      A Keras variable, filled with drawn samples.
1790
1791  Example:
1792
1793  >>> kvar = tf.keras.backend.random_normal_variable(shape=(2,3),
1794  ... mean=0.0, scale=1.0)
1795  >>> kvar
1796  <tf.Variable 'Variable:0' shape=(2, 3) dtype=float32, numpy=...,
1797  dtype=float32)>
1798  """
1799  if dtype is None:
1800    dtype = floatx()
1801  tf_dtype = dtypes_module.as_dtype(dtype)
1802  if seed is None:
1803    # ensure that randomness is conditioned by the Numpy RNG
1804    seed = np.random.randint(10e8)
1805  value = init_ops.random_normal_initializer(
1806      mean, scale, dtype=tf_dtype, seed=seed)(shape)
1807  return variable(value, dtype=dtype, name=name)
1808
1809
1810@keras_export('keras.backend.count_params')
1811@doc_controls.do_not_generate_docs
1812def count_params(x):
1813  """Returns the static number of elements in a variable or tensor.
1814
1815  Args:
1816      x: Variable or tensor.
1817
1818  Returns:
1819      Integer, the number of scalars in `x`.
1820
1821  Example:
1822
1823  >>> kvar = tf.keras.backend.zeros((2,3))
1824  >>> tf.keras.backend.count_params(kvar)
1825  6
1826  >>> tf.keras.backend.eval(kvar)
1827  array([[0.,  0.,  0.],
1828         [0.,  0.,  0.]], dtype=float32)
1829
1830  """
1831  return np.prod(x.shape.as_list())
1832
1833
1834@keras_export('keras.backend.cast')
1835@dispatch.add_dispatch_support
1836@doc_controls.do_not_generate_docs
1837def cast(x, dtype):
1838  """Casts a tensor to a different dtype and returns it.
1839
1840  You can cast a Keras variable but it still returns a Keras tensor.
1841
1842  Args:
1843      x: Keras tensor (or variable).
1844      dtype: String, either (`'float16'`, `'float32'`, or `'float64'`).
1845
1846  Returns:
1847      Keras tensor with dtype `dtype`.
1848
1849  Examples:
1850      Cast a float32 variable to a float64 tensor
1851
1852  >>> input = tf.keras.backend.ones(shape=(1,3))
1853  >>> print(input)
1854  <tf.Variable 'Variable:0' shape=(1, 3) dtype=float32,
1855  numpy=array([[1., 1., 1.]], dtype=float32)>
1856  >>> cast_input = tf.keras.backend.cast(input, dtype='float64')
1857  >>> print(cast_input)
1858  tf.Tensor([[1. 1. 1.]], shape=(1, 3), dtype=float64)
1859
1860  """
1861  return math_ops.cast(x, dtype)
1862
1863
1864# UPDATES OPS
1865
1866
1867@keras_export('keras.backend.update')
1868@doc_controls.do_not_generate_docs
1869def update(x, new_x):
1870  return state_ops.assign(x, new_x)
1871
1872
1873@keras_export('keras.backend.update_add')
1874@doc_controls.do_not_generate_docs
1875def update_add(x, increment):
1876  """Update the value of `x` by adding `increment`.
1877
1878  Args:
1879      x: A Variable.
1880      increment: A tensor of same shape as `x`.
1881
1882  Returns:
1883      The variable `x` updated.
1884  """
1885  return state_ops.assign_add(x, increment)
1886
1887
1888@keras_export('keras.backend.update_sub')
1889@doc_controls.do_not_generate_docs
1890def update_sub(x, decrement):
1891  """Update the value of `x` by subtracting `decrement`.
1892
1893  Args:
1894      x: A Variable.
1895      decrement: A tensor of same shape as `x`.
1896
1897  Returns:
1898      The variable `x` updated.
1899  """
1900  return state_ops.assign_sub(x, decrement)
1901
1902
1903@keras_export('keras.backend.moving_average_update')
1904@doc_controls.do_not_generate_docs
1905def moving_average_update(x, value, momentum):
1906  """Compute the exponential moving average of a value.
1907
1908  The moving average 'x' is updated with 'value' following:
1909
1910  ```
1911  x = x * momentum + value * (1 - momentum)
1912  ```
1913
1914  For example:
1915
1916  >>> x = tf.Variable(0.0)
1917  >>> momentum=0.9
1918  >>> moving_average_update(x, value = 2.0, momentum=momentum).numpy()
1919  >>> x.numpy()
1920  0.2
1921
1922  The result will be biased towards the initial value of the variable.
1923
1924  If the variable was initialized to zero, you can divide by
1925  `1 - momentum ** num_updates` to debias it (Section 3 of
1926  [Kingma et al., 2015](https://arxiv.org/abs/1412.6980)):
1927
1928  >>> num_updates = 1.0
1929  >>> x_zdb = x/(1 - momentum**num_updates)
1930  >>> x_zdb.numpy()
1931  2.0
1932
1933  Args:
1934      x: A Variable, the moving average.
1935      value: A tensor with the same shape as `x`, the new value to be
1936        averaged in.
1937      momentum: The moving average momentum.
1938
1939  Returns:
1940      The updated variable.
1941  """
1942  zero_debias = not tf2.enabled()
1943  return moving_averages.assign_moving_average(
1944      x, value, momentum, zero_debias=zero_debias)
1945
1946
1947# LINEAR ALGEBRA
1948
1949
1950@keras_export('keras.backend.dot')
1951@dispatch.add_dispatch_support
1952@doc_controls.do_not_generate_docs
1953def dot(x, y):
1954  """Multiplies 2 tensors (and/or variables) and returns a tensor.
1955
1956  This operation corresponds to `numpy.dot(a, b, out=None)`.
1957
1958  Args:
1959      x: Tensor or variable.
1960      y: Tensor or variable.
1961
1962  Returns:
1963      A tensor, dot product of `x` and `y`.
1964
1965  Examples:
1966
1967  If inputs `x` and `y` are 2-D arrays, then it is equivalent to `tf.matmul`.
1968  >>> x = tf.keras.backend.placeholder(shape=(2, 3))
1969  >>> y = tf.keras.backend.placeholder(shape=(3, 4))
1970  >>> xy = tf.keras.backend.dot(x, y)
1971  >>> xy
1972  <KerasTensor: shape=(2, 4) dtype=float32 ...>
1973
1974  >>> x = tf.keras.backend.placeholder(shape=(32, 28, 3))
1975  >>> y = tf.keras.backend.placeholder(shape=(3, 4))
1976  >>> xy = tf.keras.backend.dot(x, y)
1977  >>> xy
1978  <KerasTensor: shape=(32, 28, 4) dtype=float32 ...>
1979
1980  If `x` is an N-D array and `y` is an M-D array (where M>=2), it is a sum
1981  product over the last axis of `x` and the second-to-last axis of `y`.
1982  >>> x = tf.keras.backend.random_uniform_variable(shape=(2, 3), low=0, high=1)
1983  >>> y = tf.keras.backend.ones((4, 3, 5))
1984  >>> xy = tf.keras.backend.dot(x, y)
1985  >>> tf.keras.backend.int_shape(xy)
1986  (2, 4, 5)
1987  """
1988  if ndim(x) is not None and (ndim(x) > 2 or ndim(y) > 2):
1989    x_shape = []
1990    for i, s in zip(int_shape(x), array_ops.unstack(array_ops.shape(x))):
1991      if i is not None:
1992        x_shape.append(i)
1993      else:
1994        x_shape.append(s)
1995    x_shape = tuple(x_shape)
1996    y_shape = []
1997    for i, s in zip(int_shape(y), array_ops.unstack(array_ops.shape(y))):
1998      if i is not None:
1999        y_shape.append(i)
2000      else:
2001        y_shape.append(s)
2002    y_shape = tuple(y_shape)
2003    y_permute_dim = list(range(ndim(y)))
2004    y_permute_dim = [y_permute_dim.pop(-2)] + y_permute_dim
2005    xt = array_ops.reshape(x, [-1, x_shape[-1]])
2006    yt = array_ops.reshape(
2007        array_ops.transpose(y, perm=y_permute_dim), [y_shape[-2], -1])
2008    return array_ops.reshape(
2009        math_ops.matmul(xt, yt), x_shape[:-1] + y_shape[:-2] + y_shape[-1:])
2010  if is_sparse(x):
2011    out = sparse_ops.sparse_tensor_dense_matmul(x, y)
2012  else:
2013    out = math_ops.matmul(x, y)
2014  return out
2015
2016
2017@keras_export('keras.backend.batch_dot')
2018@dispatch.add_dispatch_support
2019@doc_controls.do_not_generate_docs
2020def batch_dot(x, y, axes=None):
2021  """Batchwise dot product.
2022
2023  `batch_dot` is used to compute dot product of `x` and `y` when
2024  `x` and `y` are data in batch, i.e. in a shape of
2025  `(batch_size, :)`.
2026  `batch_dot` results in a tensor or variable with less dimensions
2027  than the input. If the number of dimensions is reduced to 1,
2028  we use `expand_dims` to make sure that ndim is at least 2.
2029
2030  Args:
2031    x: Keras tensor or variable with `ndim >= 2`.
2032    y: Keras tensor or variable with `ndim >= 2`.
2033    axes: Tuple or list of integers with target dimensions, or single integer.
2034      The sizes of `x.shape[axes[0]]` and `y.shape[axes[1]]` should be equal.
2035
2036  Returns:
2037    A tensor with shape equal to the concatenation of `x`'s shape
2038    (less the dimension that was summed over) and `y`'s shape
2039    (less the batch dimension and the dimension that was summed over).
2040    If the final rank is 1, we reshape it to `(batch_size, 1)`.
2041
2042  Examples:
2043
2044  >>> x_batch = tf.keras.backend.ones(shape=(32, 20, 1))
2045  >>> y_batch = tf.keras.backend.ones(shape=(32, 30, 20))
2046  >>> xy_batch_dot = tf.keras.backend.batch_dot(x_batch, y_batch, axes=(1, 2))
2047  >>> tf.keras.backend.int_shape(xy_batch_dot)
2048  (32, 1, 30)
2049
2050  Shape inference:
2051    Let `x`'s shape be `(100, 20)` and `y`'s shape be `(100, 30, 20)`.
2052    If `axes` is (1, 2), to find the output shape of resultant tensor,
2053        loop through each dimension in `x`'s shape and `y`'s shape:
2054    * `x.shape[0]` : 100 : append to output shape
2055    * `x.shape[1]` : 20 : do not append to output shape,
2056        dimension 1 of `x` has been summed over. (`dot_axes[0]` = 1)
2057    * `y.shape[0]` : 100 : do not append to output shape,
2058        always ignore first dimension of `y`
2059    * `y.shape[1]` : 30 : append to output shape
2060    * `y.shape[2]` : 20 : do not append to output shape,
2061        dimension 2 of `y` has been summed over. (`dot_axes[1]` = 2)
2062    `output_shape` = `(100, 30)`
2063  """
2064  x_shape = int_shape(x)
2065  y_shape = int_shape(y)
2066
2067  x_ndim = len(x_shape)
2068  y_ndim = len(y_shape)
2069
2070  if x_ndim < 2 or y_ndim < 2:
2071    raise ValueError('Cannot do batch_dot on inputs '
2072                     'with rank < 2. '
2073                     'Received inputs with shapes ' +
2074                     str(x_shape) + ' and ' +
2075                     str(y_shape) + '.')
2076
2077  x_batch_size = x_shape[0]
2078  y_batch_size = y_shape[0]
2079
2080  if x_batch_size is not None and y_batch_size is not None:
2081    if x_batch_size != y_batch_size:
2082      raise ValueError('Cannot do batch_dot on inputs '
2083                       'with different batch sizes. '
2084                       'Received inputs with shapes ' +
2085                       str(x_shape) + ' and ' +
2086                       str(y_shape) + '.')
2087  if isinstance(axes, int):
2088    axes = [axes, axes]
2089
2090  if axes is None:
2091    if y_ndim == 2:
2092      axes = [x_ndim - 1, y_ndim - 1]
2093    else:
2094      axes = [x_ndim - 1, y_ndim - 2]
2095
2096  if py_any(isinstance(a, (list, tuple)) for a in axes):
2097    raise ValueError('Multiple target dimensions are not supported. ' +
2098                     'Expected: None, int, (int, int), ' +
2099                     'Provided: ' + str(axes))
2100
2101  # if tuple, convert to list.
2102  axes = list(axes)
2103
2104  # convert negative indices.
2105  if axes[0] < 0:
2106    axes[0] += x_ndim
2107  if axes[1] < 0:
2108    axes[1] += y_ndim
2109
2110  # sanity checks
2111  if 0 in axes:
2112    raise ValueError('Cannot perform batch_dot over axis 0. '
2113                     'If your inputs are not batched, '
2114                     'add a dummy batch dimension to your '
2115                     'inputs using K.expand_dims(x, 0)')
2116  a0, a1 = axes
2117  d1 = x_shape[a0]
2118  d2 = y_shape[a1]
2119
2120  if d1 is not None and d2 is not None and d1 != d2:
2121    raise ValueError('Cannot do batch_dot on inputs with shapes ' +
2122                     str(x_shape) + ' and ' + str(y_shape) +
2123                     ' with axes=' + str(axes) + '. x.shape[%d] != '
2124                     'y.shape[%d] (%d != %d).' % (axes[0], axes[1], d1, d2))
2125
2126  # backup ndims. Need them later.
2127  orig_x_ndim = x_ndim
2128  orig_y_ndim = y_ndim
2129
2130  # if rank is 2, expand to 3.
2131  if x_ndim == 2:
2132    x = array_ops.expand_dims(x, 1)
2133    a0 += 1
2134    x_ndim += 1
2135  if y_ndim == 2:
2136    y = array_ops.expand_dims(y, 2)
2137    y_ndim += 1
2138
2139  # bring x's dimension to be reduced to last axis.
2140  if a0 != x_ndim - 1:
2141    pattern = list(range(x_ndim))
2142    for i in range(a0, x_ndim - 1):
2143      pattern[i] = pattern[i + 1]
2144    pattern[-1] = a0
2145    x = array_ops.transpose(x, pattern)
2146
2147  # bring y's dimension to be reduced to axis 1.
2148  if a1 != 1:
2149    pattern = list(range(y_ndim))
2150    for i in range(a1, 1, -1):
2151      pattern[i] = pattern[i - 1]
2152    pattern[1] = a1
2153    y = array_ops.transpose(y, pattern)
2154
2155  # normalize both inputs to rank 3.
2156  if x_ndim > 3:
2157    # squash middle dimensions of x.
2158    x_shape = shape(x)
2159    x_mid_dims = x_shape[1:-1]
2160    x_squashed_shape = array_ops.stack(
2161        [x_shape[0], -1, x_shape[-1]])
2162    x = array_ops.reshape(x, x_squashed_shape)
2163    x_squashed = True
2164  else:
2165    x_squashed = False
2166
2167  if y_ndim > 3:
2168    # squash trailing dimensions of y.
2169    y_shape = shape(y)
2170    y_trail_dims = y_shape[2:]
2171    y_squashed_shape = array_ops.stack(
2172        [y_shape[0], y_shape[1], -1])
2173    y = array_ops.reshape(y, y_squashed_shape)
2174    y_squashed = True
2175  else:
2176    y_squashed = False
2177
2178  result = math_ops.matmul(x, y)
2179
2180  # if inputs were squashed, we have to reshape the matmul output.
2181  output_shape = array_ops.shape(result)
2182  do_reshape = False
2183
2184  if x_squashed:
2185    output_shape = array_ops.concat(
2186        [output_shape[:1],
2187         x_mid_dims,
2188         output_shape[-1:]], 0)
2189    do_reshape = True
2190
2191  if y_squashed:
2192    output_shape = array_ops.concat([output_shape[:-1], y_trail_dims], 0)
2193    do_reshape = True
2194
2195  if do_reshape:
2196    result = array_ops.reshape(result, output_shape)
2197
2198  # if the inputs were originally rank 2, we remove the added 1 dim.
2199  if orig_x_ndim == 2:
2200    result = array_ops.squeeze(result, 1)
2201  elif orig_y_ndim == 2:
2202    result = array_ops.squeeze(result, -1)
2203
2204  return result
2205
2206
2207@keras_export('keras.backend.transpose')
2208@dispatch.add_dispatch_support
2209@doc_controls.do_not_generate_docs
2210def transpose(x):
2211  """Transposes a tensor and returns it.
2212
2213  Args:
2214      x: Tensor or variable.
2215
2216  Returns:
2217      A tensor.
2218
2219  Examples:
2220
2221  >>> var = tf.keras.backend.variable([[1, 2, 3], [4, 5, 6]])
2222  >>> tf.keras.backend.eval(var)
2223  array([[1.,  2.,  3.],
2224         [4.,  5.,  6.]], dtype=float32)
2225  >>> var_transposed = tf.keras.backend.transpose(var)
2226  >>> tf.keras.backend.eval(var_transposed)
2227  array([[1.,  4.],
2228         [2.,  5.],
2229         [3.,  6.]], dtype=float32)
2230  >>> input = tf.keras.backend.placeholder((2, 3))
2231  >>> input
2232  <KerasTensor: shape=(2, 3) dtype=float32 ...>
2233  >>> input_transposed = tf.keras.backend.transpose(input)
2234  >>> input_transposed
2235  <KerasTensor: shape=(3, 2) dtype=float32 ...>
2236  """
2237  return array_ops.transpose(x)
2238
2239
2240@keras_export('keras.backend.gather')
2241@dispatch.add_dispatch_support
2242@doc_controls.do_not_generate_docs
2243def gather(reference, indices):
2244  """Retrieves the elements of indices `indices` in the tensor `reference`.
2245
2246  Args:
2247      reference: A tensor.
2248      indices: An integer tensor of indices.
2249
2250  Returns:
2251      A tensor of same type as `reference`.
2252
2253  Examples:
2254
2255  >>> var = tf.keras.backend.variable([[1, 2, 3], [4, 5, 6]])
2256  >>> tf.keras.backend.eval(var)
2257  array([[1., 2., 3.],
2258         [4., 5., 6.]], dtype=float32)
2259  >>> var_gathered = tf.keras.backend.gather(var, [0])
2260  >>> tf.keras.backend.eval(var_gathered)
2261  array([[1., 2., 3.]], dtype=float32)
2262  >>> var_gathered = tf.keras.backend.gather(var, [1])
2263  >>> tf.keras.backend.eval(var_gathered)
2264  array([[4., 5., 6.]], dtype=float32)
2265  >>> var_gathered = tf.keras.backend.gather(var, [0,1,0])
2266  >>> tf.keras.backend.eval(var_gathered)
2267  array([[1., 2., 3.],
2268         [4., 5., 6.],
2269         [1., 2., 3.]], dtype=float32)
2270  """
2271  return array_ops.gather(reference, indices)
2272
2273
2274# ELEMENT-WISE OPERATIONS
2275
2276
2277@keras_export('keras.backend.max')
2278@dispatch.add_dispatch_support
2279@doc_controls.do_not_generate_docs
2280def max(x, axis=None, keepdims=False):
2281  """Maximum value in a tensor.
2282
2283  Args:
2284      x: A tensor or variable.
2285      axis: An integer, the axis to find maximum values.
2286      keepdims: A boolean, whether to keep the dimensions or not.
2287          If `keepdims` is `False`, the rank of the tensor is reduced
2288          by 1. If `keepdims` is `True`,
2289          the reduced dimension is retained with length 1.
2290
2291  Returns:
2292      A tensor with maximum values of `x`.
2293  """
2294  return math_ops.reduce_max(x, axis, keepdims)
2295
2296
2297@keras_export('keras.backend.min')
2298@dispatch.add_dispatch_support
2299@doc_controls.do_not_generate_docs
2300def min(x, axis=None, keepdims=False):
2301  """Minimum value in a tensor.
2302
2303  Args:
2304      x: A tensor or variable.
2305      axis: An integer, the axis to find minimum values.
2306      keepdims: A boolean, whether to keep the dimensions or not.
2307          If `keepdims` is `False`, the rank of the tensor is reduced
2308          by 1. If `keepdims` is `True`,
2309          the reduced dimension is retained with length 1.
2310
2311  Returns:
2312      A tensor with minimum values of `x`.
2313  """
2314  return math_ops.reduce_min(x, axis, keepdims)
2315
2316
2317@keras_export('keras.backend.sum')
2318@dispatch.add_dispatch_support
2319@doc_controls.do_not_generate_docs
2320def sum(x, axis=None, keepdims=False):
2321  """Sum of the values in a tensor, alongside the specified axis.
2322
2323  Args:
2324      x: A tensor or variable.
2325      axis: An integer, the axis to sum over.
2326      keepdims: A boolean, whether to keep the dimensions or not.
2327          If `keepdims` is `False`, the rank of the tensor is reduced
2328          by 1. If `keepdims` is `True`,
2329          the reduced dimension is retained with length 1.
2330
2331  Returns:
2332      A tensor with sum of `x`.
2333  """
2334  return math_ops.reduce_sum(x, axis, keepdims)
2335
2336
2337@keras_export('keras.backend.prod')
2338@dispatch.add_dispatch_support
2339@doc_controls.do_not_generate_docs
2340def prod(x, axis=None, keepdims=False):
2341  """Multiplies the values in a tensor, alongside the specified axis.
2342
2343  Args:
2344      x: A tensor or variable.
2345      axis: An integer, the axis to compute the product.
2346      keepdims: A boolean, whether to keep the dimensions or not.
2347          If `keepdims` is `False`, the rank of the tensor is reduced
2348          by 1. If `keepdims` is `True`,
2349          the reduced dimension is retained with length 1.
2350
2351  Returns:
2352      A tensor with the product of elements of `x`.
2353  """
2354  return math_ops.reduce_prod(x, axis, keepdims)
2355
2356
2357@keras_export('keras.backend.cumsum')
2358@dispatch.add_dispatch_support
2359@doc_controls.do_not_generate_docs
2360def cumsum(x, axis=0):
2361  """Cumulative sum of the values in a tensor, alongside the specified axis.
2362
2363  Args:
2364      x: A tensor or variable.
2365      axis: An integer, the axis to compute the sum.
2366
2367  Returns:
2368      A tensor of the cumulative sum of values of `x` along `axis`.
2369  """
2370  return math_ops.cumsum(x, axis=axis)
2371
2372
2373@keras_export('keras.backend.cumprod')
2374@dispatch.add_dispatch_support
2375@doc_controls.do_not_generate_docs
2376def cumprod(x, axis=0):
2377  """Cumulative product of the values in a tensor, alongside the specified axis.
2378
2379  Args:
2380      x: A tensor or variable.
2381      axis: An integer, the axis to compute the product.
2382
2383  Returns:
2384      A tensor of the cumulative product of values of `x` along `axis`.
2385  """
2386  return math_ops.cumprod(x, axis=axis)
2387
2388
2389@keras_export('keras.backend.var')
2390@doc_controls.do_not_generate_docs
2391def var(x, axis=None, keepdims=False):
2392  """Variance of a tensor, alongside the specified axis.
2393
2394  Args:
2395      x: A tensor or variable.
2396      axis: An integer, the axis to compute the variance.
2397      keepdims: A boolean, whether to keep the dimensions or not.
2398          If `keepdims` is `False`, the rank of the tensor is reduced
2399          by 1. If `keepdims` is `True`,
2400          the reduced dimension is retained with length 1.
2401
2402  Returns:
2403      A tensor with the variance of elements of `x`.
2404  """
2405  if x.dtype.base_dtype == dtypes_module.bool:
2406    x = math_ops.cast(x, floatx())
2407  return math_ops.reduce_variance(x, axis=axis, keepdims=keepdims)
2408
2409
2410@keras_export('keras.backend.std')
2411@dispatch.add_dispatch_support
2412@doc_controls.do_not_generate_docs
2413def std(x, axis=None, keepdims=False):
2414  """Standard deviation of a tensor, alongside the specified axis.
2415
2416  It is an alias to `tf.math.reduce_std`.
2417
2418  Args:
2419      x: A tensor or variable. It should have numerical dtypes. Boolean type
2420        inputs will be converted to float.
2421      axis: An integer, the axis to compute the standard deviation. If `None`
2422        (the default), reduces all dimensions. Must be in the range
2423        `[-rank(x), rank(x))`.
2424      keepdims: A boolean, whether to keep the dimensions or not.
2425          If `keepdims` is `False`, the rank of the tensor is reduced
2426          by 1. If `keepdims` is `True`, the reduced dimension is retained with
2427          length 1.
2428
2429  Returns:
2430      A tensor with the standard deviation of elements of `x` with same dtype.
2431      Boolean type input will be converted to float.
2432  """
2433  if x.dtype.base_dtype == dtypes_module.bool:
2434    x = math_ops.cast(x, floatx())
2435  return math_ops.reduce_std(x, axis=axis, keepdims=keepdims)
2436
2437
2438@keras_export('keras.backend.mean')
2439@dispatch.add_dispatch_support
2440@doc_controls.do_not_generate_docs
2441def mean(x, axis=None, keepdims=False):
2442  """Mean of a tensor, alongside the specified axis.
2443
2444  Args:
2445      x: A tensor or variable.
2446      axis: A list of integer. Axes to compute the mean.
2447      keepdims: A boolean, whether to keep the dimensions or not.
2448          If `keepdims` is `False`, the rank of the tensor is reduced
2449          by 1 for each entry in `axis`. If `keepdims` is `True`,
2450          the reduced dimensions are retained with length 1.
2451
2452  Returns:
2453      A tensor with the mean of elements of `x`.
2454  """
2455  if x.dtype.base_dtype == dtypes_module.bool:
2456    x = math_ops.cast(x, floatx())
2457  return math_ops.reduce_mean(x, axis, keepdims)
2458
2459
2460@keras_export('keras.backend.any')
2461@dispatch.add_dispatch_support
2462@doc_controls.do_not_generate_docs
2463def any(x, axis=None, keepdims=False):
2464  """Bitwise reduction (logical OR).
2465
2466  Args:
2467      x: Tensor or variable.
2468      axis: axis along which to perform the reduction.
2469      keepdims: whether the drop or broadcast the reduction axes.
2470
2471  Returns:
2472      A uint8 tensor (0s and 1s).
2473  """
2474  x = math_ops.cast(x, dtypes_module.bool)
2475  return math_ops.reduce_any(x, axis, keepdims)
2476
2477
2478@keras_export('keras.backend.all')
2479@dispatch.add_dispatch_support
2480@doc_controls.do_not_generate_docs
2481def all(x, axis=None, keepdims=False):
2482  """Bitwise reduction (logical AND).
2483
2484  Args:
2485      x: Tensor or variable.
2486      axis: axis along which to perform the reduction.
2487      keepdims: whether the drop or broadcast the reduction axes.
2488
2489  Returns:
2490      A uint8 tensor (0s and 1s).
2491  """
2492  x = math_ops.cast(x, dtypes_module.bool)
2493  return math_ops.reduce_all(x, axis, keepdims)
2494
2495
2496@keras_export('keras.backend.argmax')
2497@dispatch.add_dispatch_support
2498@doc_controls.do_not_generate_docs
2499def argmax(x, axis=-1):
2500  """Returns the index of the maximum value along an axis.
2501
2502  Args:
2503      x: Tensor or variable.
2504      axis: axis along which to perform the reduction.
2505
2506  Returns:
2507      A tensor.
2508  """
2509  return math_ops.argmax(x, axis)
2510
2511
2512@keras_export('keras.backend.argmin')
2513@dispatch.add_dispatch_support
2514@doc_controls.do_not_generate_docs
2515def argmin(x, axis=-1):
2516  """Returns the index of the minimum value along an axis.
2517
2518  Args:
2519      x: Tensor or variable.
2520      axis: axis along which to perform the reduction.
2521
2522  Returns:
2523      A tensor.
2524  """
2525  return math_ops.argmin(x, axis)
2526
2527
2528@keras_export('keras.backend.square')
2529@dispatch.add_dispatch_support
2530@doc_controls.do_not_generate_docs
2531def square(x):
2532  """Element-wise square.
2533
2534  Args:
2535      x: Tensor or variable.
2536
2537  Returns:
2538      A tensor.
2539  """
2540  return math_ops.square(x)
2541
2542
2543@keras_export('keras.backend.abs')
2544@dispatch.add_dispatch_support
2545@doc_controls.do_not_generate_docs
2546def abs(x):
2547  """Element-wise absolute value.
2548
2549  Args:
2550      x: Tensor or variable.
2551
2552  Returns:
2553      A tensor.
2554  """
2555  return math_ops.abs(x)
2556
2557
2558@keras_export('keras.backend.sqrt')
2559@dispatch.add_dispatch_support
2560@doc_controls.do_not_generate_docs
2561def sqrt(x):
2562  """Element-wise square root.
2563
2564     This function clips negative tensor values to 0 before computing the
2565     square root.
2566
2567  Args:
2568      x: Tensor or variable.
2569
2570  Returns:
2571      A tensor.
2572  """
2573  zero = _constant_to_tensor(0., x.dtype.base_dtype)
2574  x = math_ops.maximum(x, zero)
2575  return math_ops.sqrt(x)
2576
2577
2578@keras_export('keras.backend.exp')
2579@dispatch.add_dispatch_support
2580@doc_controls.do_not_generate_docs
2581def exp(x):
2582  """Element-wise exponential.
2583
2584  Args:
2585      x: Tensor or variable.
2586
2587  Returns:
2588      A tensor.
2589  """
2590  return math_ops.exp(x)
2591
2592
2593@keras_export('keras.backend.log')
2594@dispatch.add_dispatch_support
2595@doc_controls.do_not_generate_docs
2596def log(x):
2597  """Element-wise log.
2598
2599  Args:
2600      x: Tensor or variable.
2601
2602  Returns:
2603      A tensor.
2604  """
2605  return math_ops.log(x)
2606
2607
2608def logsumexp(x, axis=None, keepdims=False):
2609  """Computes log(sum(exp(elements across dimensions of a tensor))).
2610
2611  This function is more numerically stable than log(sum(exp(x))).
2612  It avoids overflows caused by taking the exp of large inputs and
2613  underflows caused by taking the log of small inputs.
2614
2615  Args:
2616      x: A tensor or variable.
2617      axis: An integer, the axis to reduce over.
2618      keepdims: A boolean, whether to keep the dimensions or not.
2619          If `keepdims` is `False`, the rank of the tensor is reduced
2620          by 1. If `keepdims` is `True`, the reduced dimension is
2621          retained with length 1.
2622
2623  Returns:
2624      The reduced tensor.
2625  """
2626  return math_ops.reduce_logsumexp(x, axis, keepdims)
2627
2628
2629@keras_export('keras.backend.round')
2630@dispatch.add_dispatch_support
2631@doc_controls.do_not_generate_docs
2632def round(x):
2633  """Element-wise rounding to the closest integer.
2634
2635  In case of tie, the rounding mode used is "half to even".
2636
2637  Args:
2638      x: Tensor or variable.
2639
2640  Returns:
2641      A tensor.
2642  """
2643  return math_ops.round(x)
2644
2645
2646@keras_export('keras.backend.sign')
2647@dispatch.add_dispatch_support
2648@doc_controls.do_not_generate_docs
2649def sign(x):
2650  """Element-wise sign.
2651
2652  Args:
2653      x: Tensor or variable.
2654
2655  Returns:
2656      A tensor.
2657  """
2658  return math_ops.sign(x)
2659
2660
2661@keras_export('keras.backend.pow')
2662@dispatch.add_dispatch_support
2663@doc_controls.do_not_generate_docs
2664def pow(x, a):
2665  """Element-wise exponentiation.
2666
2667  Args:
2668      x: Tensor or variable.
2669      a: Python integer.
2670
2671  Returns:
2672      A tensor.
2673  """
2674  return math_ops.pow(x, a)
2675
2676
2677@keras_export('keras.backend.clip')
2678@dispatch.add_dispatch_support
2679@doc_controls.do_not_generate_docs
2680def clip(x, min_value, max_value):
2681  """Element-wise value clipping.
2682
2683  Args:
2684      x: Tensor or variable.
2685      min_value: Python float, integer, or tensor.
2686      max_value: Python float, integer, or tensor.
2687
2688  Returns:
2689      A tensor.
2690  """
2691  if (isinstance(min_value, (int, float)) and
2692      isinstance(max_value, (int, float))):
2693    if max_value < min_value:
2694      max_value = min_value
2695  if min_value is None:
2696    min_value = -np.inf
2697  if max_value is None:
2698    max_value = np.inf
2699  return clip_ops.clip_by_value(x, min_value, max_value)
2700
2701
2702@keras_export('keras.backend.equal')
2703@dispatch.add_dispatch_support
2704@doc_controls.do_not_generate_docs
2705def equal(x, y):
2706  """Element-wise equality between two tensors.
2707
2708  Args:
2709      x: Tensor or variable.
2710      y: Tensor or variable.
2711
2712  Returns:
2713      A bool tensor.
2714  """
2715  return math_ops.equal(x, y)
2716
2717
2718@keras_export('keras.backend.not_equal')
2719@dispatch.add_dispatch_support
2720@doc_controls.do_not_generate_docs
2721def not_equal(x, y):
2722  """Element-wise inequality between two tensors.
2723
2724  Args:
2725      x: Tensor or variable.
2726      y: Tensor or variable.
2727
2728  Returns:
2729      A bool tensor.
2730  """
2731  return math_ops.not_equal(x, y)
2732
2733
2734@keras_export('keras.backend.greater')
2735@dispatch.add_dispatch_support
2736@doc_controls.do_not_generate_docs
2737def greater(x, y):
2738  """Element-wise truth value of (x > y).
2739
2740  Args:
2741      x: Tensor or variable.
2742      y: Tensor or variable.
2743
2744  Returns:
2745      A bool tensor.
2746  """
2747  return math_ops.greater(x, y)
2748
2749
2750@keras_export('keras.backend.greater_equal')
2751@dispatch.add_dispatch_support
2752@doc_controls.do_not_generate_docs
2753def greater_equal(x, y):
2754  """Element-wise truth value of (x >= y).
2755
2756  Args:
2757      x: Tensor or variable.
2758      y: Tensor or variable.
2759
2760  Returns:
2761      A bool tensor.
2762  """
2763  return math_ops.greater_equal(x, y)
2764
2765
2766@keras_export('keras.backend.less')
2767@dispatch.add_dispatch_support
2768@doc_controls.do_not_generate_docs
2769def less(x, y):
2770  """Element-wise truth value of (x < y).
2771
2772  Args:
2773      x: Tensor or variable.
2774      y: Tensor or variable.
2775
2776  Returns:
2777      A bool tensor.
2778  """
2779  return math_ops.less(x, y)
2780
2781
2782@keras_export('keras.backend.less_equal')
2783@dispatch.add_dispatch_support
2784@doc_controls.do_not_generate_docs
2785def less_equal(x, y):
2786  """Element-wise truth value of (x <= y).
2787
2788  Args:
2789      x: Tensor or variable.
2790      y: Tensor or variable.
2791
2792  Returns:
2793      A bool tensor.
2794  """
2795  return math_ops.less_equal(x, y)
2796
2797
2798@keras_export('keras.backend.maximum')
2799@dispatch.add_dispatch_support
2800@doc_controls.do_not_generate_docs
2801def maximum(x, y):
2802  """Element-wise maximum of two tensors.
2803
2804  Args:
2805      x: Tensor or variable.
2806      y: Tensor or variable.
2807
2808  Returns:
2809      A tensor with the element wise maximum value(s) of `x` and `y`.
2810
2811  Examples:
2812
2813  >>> x = tf.Variable([[1, 2], [3, 4]])
2814  >>> y = tf.Variable([[2, 1], [0, -1]])
2815  >>> m = tf.keras.backend.maximum(x, y)
2816  >>> m
2817  <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
2818  array([[2, 2],
2819         [3, 4]], dtype=int32)>
2820  """
2821  return math_ops.maximum(x, y)
2822
2823
2824@keras_export('keras.backend.minimum')
2825@dispatch.add_dispatch_support
2826@doc_controls.do_not_generate_docs
2827def minimum(x, y):
2828  """Element-wise minimum of two tensors.
2829
2830  Args:
2831      x: Tensor or variable.
2832      y: Tensor or variable.
2833
2834  Returns:
2835      A tensor.
2836  """
2837  return math_ops.minimum(x, y)
2838
2839
2840@keras_export('keras.backend.sin')
2841@dispatch.add_dispatch_support
2842@doc_controls.do_not_generate_docs
2843def sin(x):
2844  """Computes sin of x element-wise.
2845
2846  Args:
2847      x: Tensor or variable.
2848
2849  Returns:
2850      A tensor.
2851  """
2852  return math_ops.sin(x)
2853
2854
2855@keras_export('keras.backend.cos')
2856@dispatch.add_dispatch_support
2857@doc_controls.do_not_generate_docs
2858def cos(x):
2859  """Computes cos of x element-wise.
2860
2861  Args:
2862      x: Tensor or variable.
2863
2864  Returns:
2865      A tensor.
2866  """
2867  return math_ops.cos(x)
2868
2869
2870def _regular_normalize_batch_in_training(x,
2871                                         gamma,
2872                                         beta,
2873                                         reduction_axes,
2874                                         epsilon=1e-3):
2875  """Non-fused version of `normalize_batch_in_training`.
2876
2877  Args:
2878      x: Input tensor or variable.
2879      gamma: Tensor by which to scale the input.
2880      beta: Tensor with which to center the input.
2881      reduction_axes: iterable of integers,
2882          axes over which to normalize.
2883      epsilon: Fuzz factor.
2884
2885  Returns:
2886      A tuple length of 3, `(normalized_tensor, mean, variance)`.
2887  """
2888  mean, var = nn.moments(x, reduction_axes, None, None, False)
2889  normed = nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
2890  return normed, mean, var
2891
2892
2893def _broadcast_normalize_batch_in_training(x,
2894                                           gamma,
2895                                           beta,
2896                                           reduction_axes,
2897                                           epsilon=1e-3):
2898  """Non-fused, broadcast version of `normalize_batch_in_training`.
2899
2900  Args:
2901      x: Input tensor or variable.
2902      gamma: Tensor by which to scale the input.
2903      beta: Tensor with which to center the input.
2904      reduction_axes: iterable of integers,
2905          axes over which to normalize.
2906      epsilon: Fuzz factor.
2907
2908  Returns:
2909      A tuple length of 3, `(normalized_tensor, mean, variance)`.
2910  """
2911  mean, var = nn.moments(x, reduction_axes, None, None, False)
2912  target_shape = []
2913  for axis in range(ndim(x)):
2914    if axis in reduction_axes:
2915      target_shape.append(1)
2916    else:
2917      target_shape.append(array_ops.shape(x)[axis])
2918  target_shape = array_ops.stack(target_shape)
2919
2920  broadcast_mean = array_ops.reshape(mean, target_shape)
2921  broadcast_var = array_ops.reshape(var, target_shape)
2922  if gamma is None:
2923    broadcast_gamma = None
2924  else:
2925    broadcast_gamma = array_ops.reshape(gamma, target_shape)
2926  if beta is None:
2927    broadcast_beta = None
2928  else:
2929    broadcast_beta = array_ops.reshape(beta, target_shape)
2930
2931  normed = nn.batch_normalization(x, broadcast_mean, broadcast_var,
2932                                  broadcast_beta, broadcast_gamma, epsilon)
2933  return normed, mean, var
2934
2935
2936def _fused_normalize_batch_in_training(x,
2937                                       gamma,
2938                                       beta,
2939                                       reduction_axes,
2940                                       epsilon=1e-3):
2941  """Fused version of `normalize_batch_in_training`.
2942
2943  Args:
2944      x: Input tensor or variable.
2945      gamma: Tensor by which to scale the input.
2946      beta: Tensor with which to center the input.
2947      reduction_axes: iterable of integers,
2948          axes over which to normalize.
2949      epsilon: Fuzz factor.
2950
2951  Returns:
2952      A tuple length of 3, `(normalized_tensor, mean, variance)`.
2953  """
2954  if list(reduction_axes) == [0, 1, 2]:
2955    normalization_axis = 3
2956    tf_data_format = 'NHWC'
2957  else:
2958    normalization_axis = 1
2959    tf_data_format = 'NCHW'
2960
2961  if gamma is None:
2962    gamma = constant_op.constant(
2963        1.0, dtype=x.dtype, shape=[x.shape[normalization_axis]])
2964  if beta is None:
2965    beta = constant_op.constant(
2966        0.0, dtype=x.dtype, shape=[x.shape[normalization_axis]])
2967
2968  return nn.fused_batch_norm(
2969      x, gamma, beta, epsilon=epsilon, data_format=tf_data_format)
2970
2971
2972@keras_export('keras.backend.normalize_batch_in_training')
2973@doc_controls.do_not_generate_docs
2974def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3):
2975  """Computes mean and std for batch then apply batch_normalization on batch.
2976
2977  Args:
2978      x: Input tensor or variable.
2979      gamma: Tensor by which to scale the input.
2980      beta: Tensor with which to center the input.
2981      reduction_axes: iterable of integers,
2982          axes over which to normalize.
2983      epsilon: Fuzz factor.
2984
2985  Returns:
2986      A tuple length of 3, `(normalized_tensor, mean, variance)`.
2987  """
2988  if ndim(x) == 4 and list(reduction_axes) in [[0, 1, 2], [0, 2, 3]]:
2989    if not _has_nchw_support() and list(reduction_axes) == [0, 2, 3]:
2990      return _broadcast_normalize_batch_in_training(
2991          x, gamma, beta, reduction_axes, epsilon=epsilon)
2992    return _fused_normalize_batch_in_training(
2993        x, gamma, beta, reduction_axes, epsilon=epsilon)
2994  else:
2995    if sorted(reduction_axes) == list(range(ndim(x)))[:-1]:
2996      return _regular_normalize_batch_in_training(
2997          x, gamma, beta, reduction_axes, epsilon=epsilon)
2998    else:
2999      return _broadcast_normalize_batch_in_training(
3000          x, gamma, beta, reduction_axes, epsilon=epsilon)
3001
3002
3003@keras_export('keras.backend.batch_normalization')
3004@dispatch.add_dispatch_support
3005@doc_controls.do_not_generate_docs
3006def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
3007  """Applies batch normalization on x given mean, var, beta and gamma.
3008
3009  I.e. returns:
3010  `output = (x - mean) / (sqrt(var) + epsilon) * gamma + beta`
3011
3012  Args:
3013      x: Input tensor or variable.
3014      mean: Mean of batch.
3015      var: Variance of batch.
3016      beta: Tensor with which to center the input.
3017      gamma: Tensor by which to scale the input.
3018      axis: Integer, the axis that should be normalized.
3019          (typically the features axis).
3020      epsilon: Fuzz factor.
3021
3022  Returns:
3023      A tensor.
3024  """
3025  if ndim(x) == 4:
3026    # The CPU implementation of `fused_batch_norm` only supports NHWC
3027    if axis == 1 or axis == -3:
3028      tf_data_format = 'NCHW'
3029    elif axis == 3 or axis == -1:
3030      tf_data_format = 'NHWC'
3031    else:
3032      tf_data_format = None
3033
3034    if (tf_data_format == 'NHWC' or
3035        tf_data_format == 'NCHW' and _has_nchw_support()):
3036      # The mean / var / beta / gamma tensors may be broadcasted
3037      # so they may have extra axes of size 1, which should be squeezed.
3038      if ndim(mean) > 1:
3039        mean = array_ops.reshape(mean, [-1])
3040      if ndim(var) > 1:
3041        var = array_ops.reshape(var, [-1])
3042      if beta is None:
3043        beta = zeros_like(mean)
3044      elif ndim(beta) > 1:
3045        beta = array_ops.reshape(beta, [-1])
3046      if gamma is None:
3047        gamma = ones_like(mean)
3048      elif ndim(gamma) > 1:
3049        gamma = array_ops.reshape(gamma, [-1])
3050    y, _, _ = nn.fused_batch_norm(
3051        x,
3052        gamma,
3053        beta,
3054        epsilon=epsilon,
3055        mean=mean,
3056        variance=var,
3057        data_format=tf_data_format,
3058        is_training=False
3059    )
3060    return y
3061  return nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
3062
3063
3064# SHAPE OPERATIONS
3065
3066
3067@keras_export('keras.backend.concatenate')
3068@dispatch.add_dispatch_support
3069@doc_controls.do_not_generate_docs
3070def concatenate(tensors, axis=-1):
3071  """Concatenates a list of tensors alongside the specified axis.
3072
3073  Args:
3074      tensors: list of tensors to concatenate.
3075      axis: concatenation axis.
3076
3077  Returns:
3078      A tensor.
3079
3080  Example:
3081
3082      >>> a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
3083      >>> b = tf.constant([[10, 20, 30], [40, 50, 60], [70, 80, 90]])
3084      >>> tf.keras.backend.concatenate((a, b), axis=-1)
3085      <tf.Tensor: shape=(3, 6), dtype=int32, numpy=
3086      array([[ 1,  2,  3, 10, 20, 30],
3087             [ 4,  5,  6, 40, 50, 60],
3088             [ 7,  8,  9, 70, 80, 90]], dtype=int32)>
3089
3090  """
3091  if axis < 0:
3092    rank = ndim(tensors[0])
3093    if rank:
3094      axis %= rank
3095    else:
3096      axis = 0
3097
3098  if py_all(is_sparse(x) for x in tensors):
3099    return sparse_ops.sparse_concat(axis, tensors)
3100  elif py_all(isinstance(x, ragged_tensor.RaggedTensor) for x in tensors):
3101    return array_ops.concat(tensors, axis)
3102  else:
3103    return array_ops.concat([to_dense(x) for x in tensors], axis)
3104
3105
3106@keras_export('keras.backend.reshape')
3107@dispatch.add_dispatch_support
3108@doc_controls.do_not_generate_docs
3109def reshape(x, shape):
3110  """Reshapes a tensor to the specified shape.
3111
3112  Args:
3113      x: Tensor or variable.
3114      shape: Target shape tuple.
3115
3116  Returns:
3117      A tensor.
3118
3119  Example:
3120
3121    >>> a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
3122    >>> a
3123    <tf.Tensor: shape=(4, 3), dtype=int32, numpy=
3124    array([[ 1,  2,  3],
3125           [ 4,  5,  6],
3126           [ 7,  8,  9],
3127           [10, 11, 12]], dtype=int32)>
3128    >>> tf.keras.backend.reshape(a, shape=(2, 6))
3129    <tf.Tensor: shape=(2, 6), dtype=int32, numpy=
3130    array([[ 1,  2,  3,  4,  5,  6],
3131           [ 7,  8,  9, 10, 11, 12]], dtype=int32)>
3132
3133  """
3134  return array_ops.reshape(x, shape)
3135
3136
3137@keras_export('keras.backend.permute_dimensions')
3138@dispatch.add_dispatch_support
3139@doc_controls.do_not_generate_docs
3140def permute_dimensions(x, pattern):
3141  """Permutes axes in a tensor.
3142
3143  Args:
3144      x: Tensor or variable.
3145      pattern: A tuple of
3146          dimension indices, e.g. `(0, 2, 1)`.
3147
3148  Returns:
3149      A tensor.
3150
3151  Example:
3152
3153    >>> a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
3154    >>> a
3155    <tf.Tensor: shape=(4, 3), dtype=int32, numpy=
3156    array([[ 1,  2,  3],
3157           [ 4,  5,  6],
3158           [ 7,  8,  9],
3159           [10, 11, 12]], dtype=int32)>
3160    >>> tf.keras.backend.permute_dimensions(a, pattern=(1, 0))
3161    <tf.Tensor: shape=(3, 4), dtype=int32, numpy=
3162    array([[ 1,  4,  7, 10],
3163           [ 2,  5,  8, 11],
3164           [ 3,  6,  9, 12]], dtype=int32)>
3165
3166  """
3167  return array_ops.transpose(x, perm=pattern)
3168
3169
3170@keras_export('keras.backend.resize_images')
3171@dispatch.add_dispatch_support
3172@doc_controls.do_not_generate_docs
3173def resize_images(x, height_factor, width_factor, data_format,
3174                  interpolation='nearest'):
3175  """Resizes the images contained in a 4D tensor.
3176
3177  Args:
3178      x: Tensor or variable to resize.
3179      height_factor: Positive integer.
3180      width_factor: Positive integer.
3181      data_format: One of `"channels_first"`, `"channels_last"`.
3182      interpolation: A string, one of `nearest` or `bilinear`.
3183
3184  Returns:
3185      A tensor.
3186
3187  Raises:
3188      ValueError: in case of incorrect value for
3189        `data_format` or `interpolation`.
3190  """
3191  if data_format == 'channels_first':
3192    rows, cols = 2, 3
3193  elif data_format == 'channels_last':
3194    rows, cols = 1, 2
3195  else:
3196    raise ValueError('Invalid `data_format` argument: %s' % (data_format,))
3197
3198  original_shape = int_shape(x)
3199  new_shape = array_ops.shape(x)[rows:cols + 1]
3200  new_shape *= constant_op.constant(
3201      np.array([height_factor, width_factor], dtype='int32'))
3202
3203  if data_format == 'channels_first':
3204    x = permute_dimensions(x, [0, 2, 3, 1])
3205  if interpolation == 'nearest':
3206    x = image_ops.resize_images_v2(
3207        x, new_shape, method=image_ops.ResizeMethod.NEAREST_NEIGHBOR)
3208  elif interpolation == 'bilinear':
3209    x = image_ops.resize_images_v2(x, new_shape,
3210                                   method=image_ops.ResizeMethod.BILINEAR)
3211  else:
3212    raise ValueError('interpolation should be one '
3213                     'of "nearest" or "bilinear".')
3214  if data_format == 'channels_first':
3215    x = permute_dimensions(x, [0, 3, 1, 2])
3216
3217  if original_shape[rows] is None:
3218    new_height = None
3219  else:
3220    new_height = original_shape[rows] * height_factor
3221
3222  if original_shape[cols] is None:
3223    new_width = None
3224  else:
3225    new_width = original_shape[cols] * width_factor
3226
3227  if data_format == 'channels_first':
3228    output_shape = (None, None, new_height, new_width)
3229  else:
3230    output_shape = (None, new_height, new_width, None)
3231  x.set_shape(output_shape)
3232  return x
3233
3234
3235@keras_export('keras.backend.resize_volumes')
3236@dispatch.add_dispatch_support
3237@doc_controls.do_not_generate_docs
3238def resize_volumes(x, depth_factor, height_factor, width_factor, data_format):
3239  """Resizes the volume contained in a 5D tensor.
3240
3241  Args:
3242      x: Tensor or variable to resize.
3243      depth_factor: Positive integer.
3244      height_factor: Positive integer.
3245      width_factor: Positive integer.
3246      data_format: One of `"channels_first"`, `"channels_last"`.
3247
3248  Returns:
3249      A tensor.
3250
3251  Raises:
3252      ValueError: if `data_format` is neither
3253          `channels_last` or `channels_first`.
3254  """
3255  if data_format == 'channels_first':
3256    output = repeat_elements(x, depth_factor, axis=2)
3257    output = repeat_elements(output, height_factor, axis=3)
3258    output = repeat_elements(output, width_factor, axis=4)
3259    return output
3260  elif data_format == 'channels_last':
3261    output = repeat_elements(x, depth_factor, axis=1)
3262    output = repeat_elements(output, height_factor, axis=2)
3263    output = repeat_elements(output, width_factor, axis=3)
3264    return output
3265  else:
3266    raise ValueError('Invalid data_format: ' + str(data_format))
3267
3268
3269@keras_export('keras.backend.repeat_elements')
3270@dispatch.add_dispatch_support
3271@doc_controls.do_not_generate_docs
3272def repeat_elements(x, rep, axis):
3273  """Repeats the elements of a tensor along an axis, like `np.repeat`.
3274
3275  If `x` has shape `(s1, s2, s3)` and `axis` is `1`, the output
3276  will have shape `(s1, s2 * rep, s3)`.
3277
3278  Args:
3279      x: Tensor or variable.
3280      rep: Python integer, number of times to repeat.
3281      axis: Axis along which to repeat.
3282
3283  Returns:
3284      A tensor.
3285
3286  Example:
3287
3288      >>> b = tf.constant([1, 2, 3])
3289      >>> tf.keras.backend.repeat_elements(b, rep=2, axis=0)
3290      <tf.Tensor: shape=(6,), dtype=int32,
3291          numpy=array([1, 1, 2, 2, 3, 3], dtype=int32)>
3292
3293  """
3294  x_shape = x.shape.as_list()
3295  # For static axis
3296  if x_shape[axis] is not None:
3297    # slices along the repeat axis
3298    splits = array_ops.split(value=x,
3299                             num_or_size_splits=x_shape[axis],
3300                             axis=axis)
3301    # repeat each slice the given number of reps
3302    x_rep = [s for s in splits for _ in range(rep)]
3303    return concatenate(x_rep, axis)
3304
3305  # Here we use tf.tile to mimic behavior of np.repeat so that
3306  # we can handle dynamic shapes (that include None).
3307  # To do that, we need an auxiliary axis to repeat elements along
3308  # it and then merge them along the desired axis.
3309
3310  # Repeating
3311  auxiliary_axis = axis + 1
3312  x_shape = array_ops.shape(x)
3313  x_rep = array_ops.expand_dims(x, axis=auxiliary_axis)
3314  reps = np.ones(len(x.shape) + 1)
3315  reps[auxiliary_axis] = rep
3316  x_rep = array_ops.tile(x_rep, reps)
3317
3318  # Merging
3319  reps = np.delete(reps, auxiliary_axis)
3320  reps[axis] = rep
3321  reps = array_ops.constant(reps, dtype='int32')
3322  x_shape *= reps
3323  x_rep = array_ops.reshape(x_rep, x_shape)
3324
3325  # Fix shape representation
3326  x_shape = x.shape.as_list()
3327  x_rep.set_shape(x_shape)
3328  x_rep._keras_shape = tuple(x_shape)
3329  return x_rep
3330
3331
3332@keras_export('keras.backend.repeat')
3333@dispatch.add_dispatch_support
3334@doc_controls.do_not_generate_docs
3335def repeat(x, n):
3336  """Repeats a 2D tensor.
3337
3338  if `x` has shape (samples, dim) and `n` is `2`,
3339  the output will have shape `(samples, 2, dim)`.
3340
3341  Args:
3342      x: Tensor or variable.
3343      n: Python integer, number of times to repeat.
3344
3345  Returns:
3346      A tensor.
3347
3348  Example:
3349
3350      >>> b = tf.constant([[1, 2], [3, 4]])
3351      >>> b
3352      <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
3353      array([[1, 2],
3354             [3, 4]], dtype=int32)>
3355      >>> tf.keras.backend.repeat(b, n=2)
3356      <tf.Tensor: shape=(2, 2, 2), dtype=int32, numpy=
3357      array([[[1, 2],
3358              [1, 2]],
3359             [[3, 4],
3360              [3, 4]]], dtype=int32)>
3361
3362  """
3363  assert ndim(x) == 2
3364  x = array_ops.expand_dims(x, 1)
3365  pattern = array_ops.stack([1, n, 1])
3366  return array_ops.tile(x, pattern)
3367
3368
3369@keras_export('keras.backend.arange')
3370@dispatch.add_dispatch_support
3371@doc_controls.do_not_generate_docs
3372def arange(start, stop=None, step=1, dtype='int32'):
3373  """Creates a 1D tensor containing a sequence of integers.
3374
3375  The function arguments use the same convention as
3376  Theano's arange: if only one argument is provided,
3377  it is in fact the "stop" argument and "start" is 0.
3378
3379  The default type of the returned tensor is `'int32'` to
3380  match TensorFlow's default.
3381
3382  Args:
3383      start: Start value.
3384      stop: Stop value.
3385      step: Difference between two successive values.
3386      dtype: Integer dtype to use.
3387
3388  Returns:
3389      An integer tensor.
3390
3391  Example:
3392
3393      >>> tf.keras.backend.arange(start=0, stop=10, step=1.5)
3394      <tf.Tensor: shape=(7,), dtype=float32,
3395          numpy=array([0. , 1.5, 3. , 4.5, 6. , 7.5, 9. ], dtype=float32)>
3396
3397
3398
3399  """
3400  # Match the behavior of numpy and Theano by returning an empty sequence.
3401  if stop is None and start < 0:
3402    start = 0
3403  result = math_ops.range(start, limit=stop, delta=step, name='arange')
3404  if dtype != 'int32':
3405    result = cast(result, dtype)
3406  return result
3407
3408
3409@keras_export('keras.backend.tile')
3410@dispatch.add_dispatch_support
3411@doc_controls.do_not_generate_docs
3412def tile(x, n):
3413  """Creates a tensor by tiling `x` by `n`.
3414
3415  Args:
3416      x: A tensor or variable
3417      n: A list of integer. The length must be the same as the number of
3418          dimensions in `x`.
3419
3420  Returns:
3421      A tiled tensor.
3422  """
3423  if isinstance(n, int):
3424    n = [n]
3425  return array_ops.tile(x, n)
3426
3427
3428@keras_export('keras.backend.flatten')
3429@dispatch.add_dispatch_support
3430@doc_controls.do_not_generate_docs
3431def flatten(x):
3432  """Flatten a tensor.
3433
3434  Args:
3435      x: A tensor or variable.
3436
3437  Returns:
3438      A tensor, reshaped into 1-D
3439
3440  Example:
3441
3442      >>> b = tf.constant([[1, 2], [3, 4]])
3443      >>> b
3444      <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
3445      array([[1, 2],
3446             [3, 4]], dtype=int32)>
3447      >>> tf.keras.backend.flatten(b)
3448      <tf.Tensor: shape=(4,), dtype=int32,
3449          numpy=array([1, 2, 3, 4], dtype=int32)>
3450
3451  """
3452  return array_ops.reshape(x, [-1])
3453
3454
3455@keras_export('keras.backend.batch_flatten')
3456@dispatch.add_dispatch_support
3457@doc_controls.do_not_generate_docs
3458def batch_flatten(x):
3459  """Turn a nD tensor into a 2D tensor with same 0th dimension.
3460
3461  In other words, it flattens each data samples of a batch.
3462
3463  Args:
3464      x: A tensor or variable.
3465
3466  Returns:
3467      A tensor.
3468
3469  Examples:
3470    Flattening a 3D tensor to 2D by collapsing the last dimension.
3471
3472  >>> x_batch = tf.keras.backend.ones(shape=(2, 3, 4, 5))
3473  >>> x_batch_flatten = batch_flatten(x_batch)
3474  >>> tf.keras.backend.int_shape(x_batch_flatten)
3475  (2, 60)
3476
3477  """
3478  x = array_ops.reshape(x, array_ops.stack([-1, prod(shape(x)[1:])]))
3479  return x
3480
3481
3482@keras_export('keras.backend.expand_dims')
3483@dispatch.add_dispatch_support
3484@doc_controls.do_not_generate_docs
3485def expand_dims(x, axis=-1):
3486  """Adds a 1-sized dimension at index "axis".
3487
3488  Args:
3489      x: A tensor or variable.
3490      axis: Position where to add a new axis.
3491
3492  Returns:
3493      A tensor with expanded dimensions.
3494  """
3495  return array_ops.expand_dims(x, axis)
3496
3497
3498@keras_export('keras.backend.squeeze')
3499@dispatch.add_dispatch_support
3500@doc_controls.do_not_generate_docs
3501def squeeze(x, axis):
3502  """Removes a 1-dimension from the tensor at index "axis".
3503
3504  Args:
3505      x: A tensor or variable.
3506      axis: Axis to drop.
3507
3508  Returns:
3509      A tensor with the same data as `x` but reduced dimensions.
3510  """
3511  return array_ops.squeeze(x, [axis])
3512
3513
3514@keras_export('keras.backend.temporal_padding')
3515@dispatch.add_dispatch_support
3516@doc_controls.do_not_generate_docs
3517def temporal_padding(x, padding=(1, 1)):
3518  """Pads the middle dimension of a 3D tensor.
3519
3520  Args:
3521      x: Tensor or variable.
3522      padding: Tuple of 2 integers, how many zeros to
3523          add at the start and end of dim 1.
3524
3525  Returns:
3526      A padded 3D tensor.
3527  """
3528  assert len(padding) == 2
3529  pattern = [[0, 0], [padding[0], padding[1]], [0, 0]]
3530  return array_ops.pad(x, pattern)
3531
3532
3533@keras_export('keras.backend.spatial_2d_padding')
3534@dispatch.add_dispatch_support
3535@doc_controls.do_not_generate_docs
3536def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
3537  """Pads the 2nd and 3rd dimensions of a 4D tensor.
3538
3539  Args:
3540      x: Tensor or variable.
3541      padding: Tuple of 2 tuples, padding pattern.
3542      data_format: One of `channels_last` or `channels_first`.
3543
3544  Returns:
3545      A padded 4D tensor.
3546
3547  Raises:
3548      ValueError: if `data_format` is neither
3549          `channels_last` or `channels_first`.
3550  """
3551  assert len(padding) == 2
3552  assert len(padding[0]) == 2
3553  assert len(padding[1]) == 2
3554  if data_format is None:
3555    data_format = image_data_format()
3556  if data_format not in {'channels_first', 'channels_last'}:
3557    raise ValueError('Unknown data_format: ' + str(data_format))
3558
3559  if data_format == 'channels_first':
3560    pattern = [[0, 0], [0, 0], list(padding[0]), list(padding[1])]
3561  else:
3562    pattern = [[0, 0], list(padding[0]), list(padding[1]), [0, 0]]
3563  return array_ops.pad(x, pattern)
3564
3565
3566@keras_export('keras.backend.spatial_3d_padding')
3567@dispatch.add_dispatch_support
3568@doc_controls.do_not_generate_docs
3569def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None):
3570  """Pads 5D tensor with zeros along the depth, height, width dimensions.
3571
3572  Pads these dimensions with respectively
3573  "padding[0]", "padding[1]" and "padding[2]" zeros left and right.
3574
3575  For 'channels_last' data_format,
3576  the 2nd, 3rd and 4th dimension will be padded.
3577  For 'channels_first' data_format,
3578  the 3rd, 4th and 5th dimension will be padded.
3579
3580  Args:
3581      x: Tensor or variable.
3582      padding: Tuple of 3 tuples, padding pattern.
3583      data_format: One of `channels_last` or `channels_first`.
3584
3585  Returns:
3586      A padded 5D tensor.
3587
3588  Raises:
3589      ValueError: if `data_format` is neither
3590          `channels_last` or `channels_first`.
3591
3592  """
3593  assert len(padding) == 3
3594  assert len(padding[0]) == 2
3595  assert len(padding[1]) == 2
3596  assert len(padding[2]) == 2
3597  if data_format is None:
3598    data_format = image_data_format()
3599  if data_format not in {'channels_first', 'channels_last'}:
3600    raise ValueError('Unknown data_format: ' + str(data_format))
3601
3602  if data_format == 'channels_first':
3603    pattern = [[0, 0], [0, 0], [padding[0][0], padding[0][1]],
3604               [padding[1][0], padding[1][1]], [padding[2][0], padding[2][1]]]
3605  else:
3606    pattern = [[0, 0], [padding[0][0], padding[0][1]],
3607               [padding[1][0], padding[1][1]], [padding[2][0],
3608                                                padding[2][1]], [0, 0]]
3609  return array_ops.pad(x, pattern)
3610
3611
3612@keras_export('keras.backend.stack')
3613@dispatch.add_dispatch_support
3614@doc_controls.do_not_generate_docs
3615def stack(x, axis=0):
3616  """Stacks a list of rank `R` tensors into a rank `R+1` tensor.
3617
3618  Args:
3619      x: List of tensors.
3620      axis: Axis along which to perform stacking.
3621
3622  Returns:
3623      A tensor.
3624
3625  Example:
3626
3627      >>> a = tf.constant([[1, 2],[3, 4]])
3628      >>> b = tf.constant([[10, 20],[30, 40]])
3629      >>> tf.keras.backend.stack((a, b))
3630      <tf.Tensor: shape=(2, 2, 2), dtype=int32, numpy=
3631      array([[[ 1,  2],
3632              [ 3,  4]],
3633             [[10, 20],
3634              [30, 40]]], dtype=int32)>
3635
3636  """
3637  return array_ops.stack(x, axis=axis)
3638
3639
3640@keras_export('keras.backend.one_hot')
3641@dispatch.add_dispatch_support
3642@doc_controls.do_not_generate_docs
3643def one_hot(indices, num_classes):
3644  """Computes the one-hot representation of an integer tensor.
3645
3646  Args:
3647      indices: nD integer tensor of shape
3648          `(batch_size, dim1, dim2, ... dim(n-1))`
3649      num_classes: Integer, number of classes to consider.
3650
3651  Returns:
3652      (n + 1)D one hot representation of the input
3653      with shape `(batch_size, dim1, dim2, ... dim(n-1), num_classes)`
3654
3655  Returns:
3656      The one-hot tensor.
3657  """
3658  return array_ops.one_hot(indices, depth=num_classes, axis=-1)
3659
3660
3661@keras_export('keras.backend.reverse')
3662@dispatch.add_dispatch_support
3663@doc_controls.do_not_generate_docs
3664def reverse(x, axes):
3665  """Reverse a tensor along the specified axes.
3666
3667  Args:
3668      x: Tensor to reverse.
3669      axes: Integer or iterable of integers.
3670          Axes to reverse.
3671
3672  Returns:
3673      A tensor.
3674  """
3675  if isinstance(axes, int):
3676    axes = [axes]
3677  return array_ops.reverse(x, axes)
3678
3679
3680# VALUE MANIPULATION
3681_VALUE_SET_CODE_STRING = """
3682  >>> K = tf.keras.backend  # Common keras convention
3683  >>> v = K.variable(1.)
3684
3685  >>> # reassign
3686  >>> K.set_value(v, 2.)
3687  >>> print(K.get_value(v))
3688  2.0
3689
3690  >>> # increment
3691  >>> K.set_value(v, K.get_value(v) + 1)
3692  >>> print(K.get_value(v))
3693  3.0
3694
3695  Variable semantics in TensorFlow 2 are eager execution friendly. The above
3696  code is roughly equivalent to:
3697
3698  >>> v = tf.Variable(1.)
3699
3700  >>> v.assign(2.)
3701  >>> print(v.numpy())
3702  2.0
3703
3704  >>> v.assign_add(1.)
3705  >>> print(v.numpy())
3706  3.0"""[3:]  # Prune first newline and indent to match the docstring template.
3707
3708
3709@keras_export('keras.backend.get_value')
3710@doc_controls.do_not_generate_docs
3711def get_value(x):
3712  """Returns the value of a variable.
3713
3714  `backend.get_value` is the compliment of `backend.set_value`, and provides
3715  a generic interface for reading from variables while abstracting away the
3716  differences between TensorFlow 1.x and 2.x semantics.
3717
3718  {snippet}
3719
3720  Args:
3721      x: input variable.
3722
3723  Returns:
3724      A Numpy array.
3725  """
3726  if not tensor_util.is_tf_type(x):
3727    return x
3728  if context.executing_eagerly() or isinstance(x, ops.EagerTensor):
3729    return x.numpy()
3730  if not getattr(x, '_in_graph_mode', True):
3731    # This is a variable which was created in an eager context, but is being
3732    # evaluated from a Graph.
3733    with context.eager_mode():
3734      return x.numpy()
3735
3736  if ops.executing_eagerly_outside_functions():
3737    # This method of evaluating works inside the Keras FuncGraph.
3738    return eval_in_eager_or_function(x)
3739
3740  with x.graph.as_default():
3741    return x.eval(session=get_session((x,)))
3742
3743
3744@keras_export('keras.backend.batch_get_value')
3745@dispatch.add_dispatch_support
3746@doc_controls.do_not_generate_docs
3747def batch_get_value(tensors):
3748  """Returns the value of more than one tensor variable.
3749
3750  Args:
3751      tensors: list of ops to run.
3752
3753  Returns:
3754      A list of Numpy arrays.
3755
3756  Raises:
3757      RuntimeError: If this method is called inside defun.
3758  """
3759  if context.executing_eagerly():
3760    return [x.numpy() for x in tensors]
3761  elif ops.inside_function():  # pylint: disable=protected-access
3762    raise RuntimeError('Cannot get value inside Tensorflow graph function.')
3763  if tensors:
3764    return get_session(tensors).run(tensors)
3765  else:
3766    return []
3767
3768
3769@keras_export('keras.backend.set_value')
3770@doc_controls.do_not_generate_docs
3771def set_value(x, value):
3772  """Sets the value of a variable, from a Numpy array.
3773
3774  `backend.set_value` is the compliment of `backend.get_value`, and provides
3775  a generic interface for assigning to variables while abstracting away the
3776  differences between TensorFlow 1.x and 2.x semantics.
3777
3778  {snippet}
3779
3780  Args:
3781      x: Variable to set to a new value.
3782      value: Value to set the tensor to, as a Numpy array
3783          (of the same shape).
3784  """
3785  value = np.asarray(value, dtype=dtype(x))
3786  if ops.executing_eagerly_outside_functions():
3787    x.assign(value)
3788  else:
3789    with get_graph().as_default():
3790      tf_dtype = dtypes_module.as_dtype(x.dtype.name.split('_')[0])
3791      if hasattr(x, '_assign_placeholder'):
3792        assign_placeholder = x._assign_placeholder
3793        assign_op = x._assign_op
3794      else:
3795        # In order to support assigning weights to resizable variables in
3796        # Keras, we make a placeholder with the correct number of dimensions
3797        # but with None in each dimension. This way, we can assign weights
3798        # of any size (as long as they have the correct dimensionality).
3799        placeholder_shape = tensor_shape.TensorShape([None] * value.ndim)
3800        assign_placeholder = array_ops.placeholder(
3801            tf_dtype, shape=placeholder_shape)
3802        assign_op = x.assign(assign_placeholder)
3803        x._assign_placeholder = assign_placeholder
3804        x._assign_op = assign_op
3805      get_session().run(assign_op, feed_dict={assign_placeholder: value})
3806
3807
3808@keras_export('keras.backend.batch_set_value')
3809@dispatch.add_dispatch_support
3810@doc_controls.do_not_generate_docs
3811def batch_set_value(tuples):
3812  """Sets the values of many tensor variables at once.
3813
3814  Args:
3815      tuples: a list of tuples `(tensor, value)`.
3816          `value` should be a Numpy array.
3817  """
3818  if ops.executing_eagerly_outside_functions():
3819    for x, value in tuples:
3820      x.assign(np.asarray(value, dtype=dtype_numpy(x)))
3821  else:
3822    with get_graph().as_default():
3823      if tuples:
3824        assign_ops = []
3825        feed_dict = {}
3826        for x, value in tuples:
3827          value = np.asarray(value, dtype=dtype(x))
3828          tf_dtype = dtypes_module.as_dtype(x.dtype.name.split('_')[0])
3829          if hasattr(x, '_assign_placeholder'):
3830            assign_placeholder = x._assign_placeholder
3831            assign_op = x._assign_op
3832          else:
3833            # In order to support assigning weights to resizable variables in
3834            # Keras, we make a placeholder with the correct number of dimensions
3835            # but with None in each dimension. This way, we can assign weights
3836            # of any size (as long as they have the correct dimensionality).
3837            placeholder_shape = tensor_shape.TensorShape([None] * value.ndim)
3838            assign_placeholder = array_ops.placeholder(
3839                tf_dtype, shape=placeholder_shape)
3840            assign_op = x.assign(assign_placeholder)
3841            x._assign_placeholder = assign_placeholder
3842            x._assign_op = assign_op
3843          assign_ops.append(assign_op)
3844          feed_dict[assign_placeholder] = value
3845        get_session().run(assign_ops, feed_dict=feed_dict)
3846
3847
3848get_value.__doc__ = get_value.__doc__.format(snippet=_VALUE_SET_CODE_STRING)
3849set_value.__doc__ = set_value.__doc__.format(snippet=_VALUE_SET_CODE_STRING)
3850
3851
3852@keras_export('keras.backend.print_tensor')
3853@dispatch.add_dispatch_support
3854@doc_controls.do_not_generate_docs
3855def print_tensor(x, message=''):
3856  """Prints `message` and the tensor value when evaluated.
3857
3858  Note that `print_tensor` returns a new tensor identical to `x`
3859  which should be used in the following code. Otherwise the
3860  print operation is not taken into account during evaluation.
3861
3862  Example:
3863
3864  >>> x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
3865  >>> tf.keras.backend.print_tensor(x)
3866  <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
3867    array([[1., 2.],
3868           [3., 4.]], dtype=float32)>
3869
3870  Args:
3871      x: Tensor to print.
3872      message: Message to print jointly with the tensor.
3873
3874  Returns:
3875      The same tensor `x`, unchanged.
3876  """
3877  if isinstance(x, ops.Tensor) and hasattr(x, 'graph'):
3878    with get_graph().as_default():
3879      op = logging_ops.print_v2(message, x, output_stream=sys.stdout)
3880      with ops.control_dependencies([op]):
3881        return array_ops.identity(x)
3882  else:
3883    logging_ops.print_v2(message, x, output_stream=sys.stdout)
3884    return x
3885
3886# GRAPH MANIPULATION
3887
3888
3889class GraphExecutionFunction(object):
3890  """Runs a computation graph.
3891
3892  It's possible to pass arguments to `tf.Session.run()` via `session_kwargs`.
3893  In particular additional operations via `fetches` argument and additional
3894  tensor substitutions via `feed_dict` arguments. Note that given
3895  substitutions are merged with substitutions from `inputs`. Even though
3896  `feed_dict` is passed once in the constructor (called in `model.compile()`)
3897  we can modify the values in the dictionary. Through this feed_dict we can
3898  provide additional substitutions besides Keras inputs.
3899
3900  Args:
3901      inputs: Feed placeholders to the computation graph.
3902      outputs: Output tensors to fetch.
3903      updates: Additional update ops to be run at function call.
3904      name: A name to help users identify what this function does.
3905      session_kwargs: Arguments to `tf.Session.run()`:
3906                      `fetches`, `feed_dict`, `options`, `run_metadata`.
3907  """
3908
3909  def __init__(self, inputs, outputs, updates=None, name=None,
3910               **session_kwargs):
3911    updates = updates or []
3912    if not isinstance(updates, (list, tuple)):
3913      raise TypeError('`updates` in a Keras backend function '
3914                      'should be a list or tuple.')
3915
3916    self._inputs_structure = inputs
3917    self.inputs = nest.flatten(inputs, expand_composites=True)
3918    self._outputs_structure = outputs
3919    self.outputs = cast_variables_to_tensor(
3920        nest.flatten(outputs, expand_composites=True))
3921    # TODO(b/127668432): Consider using autograph to generate these
3922    # dependencies in call.
3923    # Index 0 = total loss or model output for `predict`.
3924    with ops.control_dependencies([self.outputs[0]]):
3925      updates_ops = []
3926      for update in updates:
3927        if isinstance(update, tuple):
3928          p, new_p = update
3929          updates_ops.append(state_ops.assign(p, new_p))
3930        else:
3931          # assumed already an op
3932          updates_ops.append(update)
3933      self.updates_op = control_flow_ops.group(*updates_ops)
3934    self.name = name
3935    # additional tensor substitutions
3936    self.feed_dict = session_kwargs.pop('feed_dict', None)
3937    # additional operations
3938    self.fetches = session_kwargs.pop('fetches', [])
3939    if not isinstance(self.fetches, list):
3940      self.fetches = [self.fetches]
3941    self.run_options = session_kwargs.pop('options', None)
3942    self.run_metadata = session_kwargs.pop('run_metadata', None)
3943    # The main use case of `fetches` being passed to a model is the ability
3944    # to run custom updates
3945    # This requires us to wrap fetches in `identity` ops.
3946    self.fetches = [array_ops.identity(x) for x in self.fetches]
3947    self.session_kwargs = session_kwargs
3948    # This mapping keeps track of the function that should receive the
3949    # output from a fetch in `fetches`: { fetch: function(fetch_output) }
3950    # A Callback can use this to register a function with access to the
3951    # output values for a fetch it added.
3952    self.fetch_callbacks = {}
3953
3954    if session_kwargs:
3955      raise ValueError('Some keys in session_kwargs are not supported at this '
3956                       'time: %s' % (session_kwargs.keys(),))
3957
3958    self._callable_fn = None
3959    self._feed_arrays = None
3960    self._feed_symbols = None
3961    self._symbol_vals = None
3962    self._fetches = None
3963    self._session = None
3964
3965  def _make_callable(self, feed_arrays, feed_symbols, symbol_vals, session):
3966    """Generates a callable that runs the graph.
3967
3968    Args:
3969      feed_arrays: List of input tensors to be fed Numpy arrays at runtime.
3970      feed_symbols: List of input tensors to be fed symbolic tensors at runtime.
3971      symbol_vals: List of symbolic tensors to be fed to `feed_symbols`.
3972      session: Session to use to generate the callable.
3973
3974    Returns:
3975      Function that runs the graph according to the above options.
3976    """
3977    # Prepare callable options.
3978    callable_opts = config_pb2.CallableOptions()
3979    # Handle external-data feed.
3980    for x in feed_arrays:
3981      callable_opts.feed.append(x.name)
3982    if self.feed_dict:
3983      for key in sorted(self.feed_dict.keys()):
3984        callable_opts.feed.append(key.name)
3985    # Handle symbolic feed.
3986    for x, y in zip(feed_symbols, symbol_vals):
3987      connection = callable_opts.tensor_connection.add()
3988      if x.dtype != y.dtype:
3989        y = math_ops.cast(y, dtype=x.dtype)
3990      from_tensor = ops._as_graph_element(y)
3991      if from_tensor is None:
3992        from_tensor = y
3993      connection.from_tensor = from_tensor.name  # Data tensor
3994      connection.to_tensor = x.name  # Placeholder
3995    # Handle fetches.
3996    for x in self.outputs + self.fetches:
3997      callable_opts.fetch.append(x.name)
3998    # Handle updates.
3999    callable_opts.target.append(self.updates_op.name)
4000    # Handle run_options.
4001    if self.run_options:
4002      callable_opts.run_options.CopyFrom(self.run_options)
4003    # Create callable.
4004    callable_fn = session._make_callable_from_options(callable_opts)
4005    # Cache parameters corresponding to the generated callable, so that
4006    # we can detect future mismatches and refresh the callable.
4007    self._callable_fn = callable_fn
4008    self._feed_arrays = feed_arrays
4009    self._feed_symbols = feed_symbols
4010    self._symbol_vals = symbol_vals
4011    self._fetches = list(self.fetches)
4012    self._session = session
4013
4014  def _call_fetch_callbacks(self, fetches_output):
4015    for fetch, output in zip(self._fetches, fetches_output):
4016      if fetch in self.fetch_callbacks:
4017        self.fetch_callbacks[fetch](output)
4018
4019  def _eval_if_composite(self, tensor):
4020    """Helper method which evaluates any CompositeTensors passed to it."""
4021    # We need to evaluate any composite tensor objects that have been
4022    # reconstructed in 'pack_sequence_as', since otherwise they'll be output as
4023    # actual CompositeTensor objects instead of the value(s) contained in the
4024    # CompositeTensors. E.g., if output_structure contains a SparseTensor, then
4025    # this ensures that we return its value as a SparseTensorValue rather than
4026    # a SparseTensor.
4027    from tensorflow.python.keras.utils import tf_utils  # pylint: disable=g-import-not-at-top
4028    if tf_utils.is_extension_type(tensor):
4029      return self._session.run(tensor)
4030    else:
4031      return tensor
4032
4033  def __call__(self, inputs):
4034    inputs = nest.flatten(inputs, expand_composites=True)
4035
4036    session = get_session(inputs)
4037    feed_arrays = []
4038    array_vals = []
4039    feed_symbols = []
4040    symbol_vals = []
4041    for tensor, value in zip(self.inputs, inputs):
4042      if value is None:
4043        continue
4044
4045      if tensor_util.is_tf_type(value):
4046        # Case: feeding symbolic tensor.
4047        feed_symbols.append(tensor)
4048        symbol_vals.append(value)
4049      else:
4050        # Case: feeding Numpy array.
4051        feed_arrays.append(tensor)
4052        # We need to do array conversion and type casting at this level, since
4053        # `callable_fn` only supports exact matches.
4054        tensor_type = dtypes_module.as_dtype(tensor.dtype)
4055        array_vals.append(np.asarray(value,
4056                                     dtype=tensor_type.as_numpy_dtype))
4057
4058    if self.feed_dict:
4059      for key in sorted(self.feed_dict.keys()):
4060        array_vals.append(
4061            np.asarray(self.feed_dict[key], dtype=key.dtype.base_dtype.name))
4062
4063    # Refresh callable if anything has changed.
4064    if (self._callable_fn is None or feed_arrays != self._feed_arrays or
4065        symbol_vals != self._symbol_vals or
4066        feed_symbols != self._feed_symbols or self.fetches != self._fetches or
4067        session != self._session):
4068      self._make_callable(feed_arrays, feed_symbols, symbol_vals, session)
4069
4070    fetched = self._callable_fn(*array_vals,
4071                                run_metadata=self.run_metadata)
4072    self._call_fetch_callbacks(fetched[-len(self._fetches):])
4073    output_structure = nest.pack_sequence_as(
4074        self._outputs_structure,
4075        fetched[:len(self.outputs)],
4076        expand_composites=True)
4077    # We need to evaluate any composite tensor objects that have been
4078    # reconstructed in 'pack_sequence_as', since otherwise they'll be output as
4079    # actual CompositeTensor objects instead of the value(s) contained in the
4080    # CompositeTensors. E.g., if output_structure contains a SparseTensor, then
4081    # this ensures that we return its value as a SparseTensorValue rather than
4082    # a SparseTensor.
4083    return nest.map_structure(self._eval_if_composite, output_structure)
4084
4085
4086def eval_in_eager_or_function(outputs):
4087  """Method to evaluate a tensor in eager or in a tf.function.
4088
4089  In the case of a tf.function, it will lift the tensor out of the function
4090  and try to evaluate that piece of the graph.
4091
4092  Warning: Do not add new usages of this function.
4093  TODO(b/150169018): delete this function once _keras_history_helper is no
4094  longer needed, after Keras switches to KerasTensors and op layers
4095  work via dispatch.
4096
4097  Args:
4098    outputs: tensors to fetch.
4099  Returns:
4100    The value of the tensors (as numpy arrays).
4101  """
4102  outputs_structure = outputs
4103  outputs = nest.flatten(outputs, expand_composites=True)
4104
4105  graphs = {
4106      i.graph
4107      for i in nest.flatten([outputs])
4108      if hasattr(i, 'graph')
4109  }
4110  if len(graphs) > 1:
4111    raise ValueError('Cannot create an execution function which is comprised '
4112                     'of elements from multiple graphs.')
4113
4114  source_graph = graphs.pop()
4115
4116  with _scratch_graph() as exec_graph:
4117    global_graph = get_graph()
4118    if source_graph not in (exec_graph, global_graph):
4119      raise ValueError('Unknown graph. Aborting.')
4120
4121    if source_graph is global_graph and exec_graph is not global_graph:
4122      init_tensors = outputs
4123      lifted_map = lift_to_graph.lift_to_graph(
4124          tensors=init_tensors,
4125          graph=exec_graph,
4126          sources=[],
4127          add_sources=True,
4128          handle_captures=True,
4129          base_graph=source_graph)
4130
4131      outputs = [lifted_map[i] for i in outputs]
4132
4133  # Consolidate updates
4134  with exec_graph.as_default():
4135    outputs = cast_variables_to_tensor(outputs)
4136
4137    exec_graph.inputs = exec_graph.internal_captures
4138    exec_graph.outputs = outputs
4139    graph_fn = eager_function.ConcreteFunction(exec_graph)
4140
4141  graph_fn._num_positional_args = 0
4142  graph_fn._arg_keywords = []
4143
4144  outputs = graph_fn()
4145
4146  # EagerTensor.numpy() will often make a copy to ensure memory safety.
4147  # However in this case `outputs` is not directly returned, so it is always
4148  # safe to reuse the underlying buffer without checking. In such a case the
4149  # private numpy conversion method is preferred to guarantee performance.
4150  return nest.pack_sequence_as(
4151      outputs_structure,
4152      [x._numpy() for x in outputs],  # pylint: disable=protected-access
4153      expand_composites=True)
4154
4155
4156@keras_export('keras.backend.function')
4157@doc_controls.do_not_generate_docs
4158def function(inputs, outputs, updates=None, name=None, **kwargs):
4159  """Instantiates a Keras function.
4160
4161  Args:
4162      inputs: List of placeholder tensors.
4163      outputs: List of output tensors.
4164      updates: List of update ops.
4165      name: String, name of function.
4166      **kwargs: Passed to `tf.Session.run`.
4167
4168  Returns:
4169      Output values as Numpy arrays.
4170
4171  Raises:
4172      ValueError: if invalid kwargs are passed in or if in eager execution.
4173  """
4174  if ops.executing_eagerly_outside_functions():
4175    if kwargs:
4176      raise ValueError('Session keyword arguments are not supported during '
4177                       'eager execution. You passed: %s' % (kwargs,))
4178    if updates:
4179      raise ValueError('`updates` argument is not supported during '
4180                       'eager execution. You passed: %s' % (updates,))
4181    from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top
4182    from tensorflow.python.keras.utils import tf_utils  # pylint: disable=g-import-not-at-top
4183    model = models.Model(inputs=inputs, outputs=outputs)
4184
4185    wrap_outputs = isinstance(outputs, list) and len(outputs) == 1
4186    def func(model_inputs):
4187      outs = model(model_inputs)
4188      if wrap_outputs:
4189        outs = [outs]
4190      return tf_utils.to_numpy_or_python_type(outs)
4191    return func
4192
4193  if kwargs:
4194    for key in kwargs:
4195      if (key not in tf_inspect.getfullargspec(session_module.Session.run)[0]
4196          and key not in ['inputs', 'outputs', 'updates', 'name']):
4197        msg = ('Invalid argument "%s" passed to K.function with TensorFlow '
4198               'backend') % key
4199        raise ValueError(msg)
4200  return GraphExecutionFunction(
4201      inputs, outputs, updates=updates, name=name, **kwargs)
4202
4203
4204@keras_export('keras.backend.gradients')
4205@doc_controls.do_not_generate_docs
4206def gradients(loss, variables):
4207  """Returns the gradients of `loss` w.r.t. `variables`.
4208
4209  Args:
4210      loss: Scalar tensor to minimize.
4211      variables: List of variables.
4212
4213  Returns:
4214      A gradients tensor.
4215  """
4216  return gradients_module.gradients(
4217      loss, variables, colocate_gradients_with_ops=True)
4218
4219
4220@keras_export('keras.backend.stop_gradient')
4221@dispatch.add_dispatch_support
4222@doc_controls.do_not_generate_docs
4223def stop_gradient(variables):
4224  """Returns `variables` but with zero gradient w.r.t. every other variable.
4225
4226  Args:
4227      variables: Tensor or list of tensors to consider constant with respect
4228        to any other variable.
4229
4230
4231  Returns:
4232      A single tensor or a list of tensors (depending on the passed argument)
4233      that has no gradient with respect to any other variable.
4234  """
4235  if isinstance(variables, (list, tuple)):
4236    return map(array_ops.stop_gradient, variables)
4237  return array_ops.stop_gradient(variables)
4238
4239
4240# CONTROL FLOW
4241
4242
4243@keras_export('keras.backend.rnn')
4244@dispatch.add_dispatch_support
4245def rnn(step_function,
4246        inputs,
4247        initial_states,
4248        go_backwards=False,
4249        mask=None,
4250        constants=None,
4251        unroll=False,
4252        input_length=None,
4253        time_major=False,
4254        zero_output_for_mask=False):
4255  """Iterates over the time dimension of a tensor.
4256
4257  Args:
4258      step_function: RNN step function.
4259          Args;
4260              input; Tensor with shape `(samples, ...)` (no time dimension),
4261                  representing input for the batch of samples at a certain
4262                  time step.
4263              states; List of tensors.
4264          Returns;
4265              output; Tensor with shape `(samples, output_dim)`
4266                  (no time dimension).
4267              new_states; List of tensors, same length and shapes
4268                  as 'states'. The first state in the list must be the
4269                  output tensor at the previous timestep.
4270      inputs: Tensor of temporal data of shape `(samples, time, ...)`
4271          (at least 3D), or nested tensors, and each of which has shape
4272          `(samples, time, ...)`.
4273      initial_states: Tensor with shape `(samples, state_size)`
4274          (no time dimension), containing the initial values for the states used
4275          in the step function. In the case that state_size is in a nested
4276          shape, the shape of initial_states will also follow the nested
4277          structure.
4278      go_backwards: Boolean. If True, do the iteration over the time
4279          dimension in reverse order and return the reversed sequence.
4280      mask: Binary tensor with shape `(samples, time, 1)`,
4281          with a zero for every element that is masked.
4282      constants: List of constant values passed at each step.
4283      unroll: Whether to unroll the RNN or to use a symbolic `while_loop`.
4284      input_length: An integer or a 1-D Tensor, depending on whether
4285          the time dimension is fixed-length or not. In case of variable length
4286          input, it is used for masking in case there's no mask specified.
4287      time_major: Boolean. If true, the inputs and outputs will be in shape
4288          `(timesteps, batch, ...)`, whereas in the False case, it will be
4289          `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
4290          efficient because it avoids transposes at the beginning and end of the
4291          RNN calculation. However, most TensorFlow data is batch-major, so by
4292          default this function accepts input and emits output in batch-major
4293          form.
4294      zero_output_for_mask: Boolean. If True, the output for masked timestep
4295          will be zeros, whereas in the False case, output from previous
4296          timestep is returned.
4297
4298  Returns:
4299      A tuple, `(last_output, outputs, new_states)`.
4300          last_output: the latest output of the rnn, of shape `(samples, ...)`
4301          outputs: tensor with shape `(samples, time, ...)` where each
4302              entry `outputs[s, t]` is the output of the step function
4303              at time `t` for sample `s`.
4304          new_states: list of tensors, latest states returned by
4305              the step function, of shape `(samples, ...)`.
4306
4307  Raises:
4308      ValueError: if input dimension is less than 3.
4309      ValueError: if `unroll` is `True` but input timestep is not a fixed
4310      number.
4311      ValueError: if `mask` is provided (not `None`) but states is not provided
4312          (`len(states)` == 0).
4313  """
4314
4315  def swap_batch_timestep(input_t):
4316    # Swap the batch and timestep dim for the incoming tensor.
4317    axes = list(range(len(input_t.shape)))
4318    axes[0], axes[1] = 1, 0
4319    return array_ops.transpose(input_t, axes)
4320
4321  if not time_major:
4322    inputs = nest.map_structure(swap_batch_timestep, inputs)
4323
4324  flatted_inputs = nest.flatten(inputs)
4325  time_steps = flatted_inputs[0].shape[0]
4326  batch = flatted_inputs[0].shape[1]
4327  time_steps_t = array_ops.shape(flatted_inputs[0])[0]
4328
4329  for input_ in flatted_inputs:
4330    input_.shape.with_rank_at_least(3)
4331
4332  if mask is not None:
4333    if mask.dtype != dtypes_module.bool:
4334      mask = math_ops.cast(mask, dtypes_module.bool)
4335    if len(mask.shape) == 2:
4336      mask = expand_dims(mask)
4337    if not time_major:
4338      mask = swap_batch_timestep(mask)
4339
4340  if constants is None:
4341    constants = []
4342
4343  # tf.where needs its condition tensor to be the same shape as its two
4344  # result tensors, but in our case the condition (mask) tensor is
4345  # (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
4346  # So we need to broadcast the mask to match the shape of inputs.
4347  # That's what the tile call does, it just repeats the mask along its
4348  # second dimension n times.
4349  def _expand_mask(mask_t, input_t, fixed_dim=1):
4350    if nest.is_nested(mask_t):
4351      raise ValueError('mask_t is expected to be tensor, but got %s' % mask_t)
4352    if nest.is_nested(input_t):
4353      raise ValueError('input_t is expected to be tensor, but got %s' % input_t)
4354    rank_diff = len(input_t.shape) - len(mask_t.shape)
4355    for _ in range(rank_diff):
4356      mask_t = array_ops.expand_dims(mask_t, -1)
4357    multiples = [1] * fixed_dim + input_t.shape.as_list()[fixed_dim:]
4358    return array_ops.tile(mask_t, multiples)
4359
4360  if unroll:
4361    if not time_steps:
4362      raise ValueError('Unrolling requires a fixed number of timesteps.')
4363    states = tuple(initial_states)
4364    successive_states = []
4365    successive_outputs = []
4366
4367    # Process the input tensors. The input tensor need to be split on the
4368    # time_step dim, and reverse if go_backwards is True. In the case of nested
4369    # input, the input is flattened and then transformed individually.
4370    # The result of this will be a tuple of lists, each of the item in tuple is
4371    # list of the tensor with shape (batch, feature)
4372    def _process_single_input_t(input_t):
4373      input_t = array_ops.unstack(input_t)  # unstack for time_step dim
4374      if go_backwards:
4375        input_t.reverse()
4376      return input_t
4377
4378    if nest.is_nested(inputs):
4379      processed_input = nest.map_structure(_process_single_input_t, inputs)
4380    else:
4381      processed_input = (_process_single_input_t(inputs),)
4382
4383    def _get_input_tensor(time):
4384      inp = [t_[time] for t_ in processed_input]
4385      return nest.pack_sequence_as(inputs, inp)
4386
4387    if mask is not None:
4388      mask_list = array_ops.unstack(mask)
4389      if go_backwards:
4390        mask_list.reverse()
4391
4392      for i in range(time_steps):
4393        inp = _get_input_tensor(i)
4394        mask_t = mask_list[i]
4395        output, new_states = step_function(inp,
4396                                           tuple(states) + tuple(constants))
4397        tiled_mask_t = _expand_mask(mask_t, output)
4398
4399        if not successive_outputs:
4400          prev_output = zeros_like(output)
4401        else:
4402          prev_output = successive_outputs[-1]
4403
4404        output = array_ops.where_v2(tiled_mask_t, output, prev_output)
4405
4406        flat_states = nest.flatten(states)
4407        flat_new_states = nest.flatten(new_states)
4408        tiled_mask_t = tuple(_expand_mask(mask_t, s) for s in flat_states)
4409        flat_final_states = tuple(
4410            array_ops.where_v2(m, s, ps)
4411            for m, s, ps in zip(tiled_mask_t, flat_new_states, flat_states))
4412        states = nest.pack_sequence_as(states, flat_final_states)
4413
4414        successive_outputs.append(output)
4415        successive_states.append(states)
4416      last_output = successive_outputs[-1]
4417      new_states = successive_states[-1]
4418      outputs = array_ops.stack(successive_outputs)
4419
4420      if zero_output_for_mask:
4421        last_output = array_ops.where_v2(
4422            _expand_mask(mask_list[-1], last_output), last_output,
4423            zeros_like(last_output))
4424        outputs = array_ops.where_v2(
4425            _expand_mask(mask, outputs, fixed_dim=2), outputs,
4426            zeros_like(outputs))
4427
4428    else:  # mask is None
4429      for i in range(time_steps):
4430        inp = _get_input_tensor(i)
4431        output, states = step_function(inp, tuple(states) + tuple(constants))
4432        successive_outputs.append(output)
4433        successive_states.append(states)
4434      last_output = successive_outputs[-1]
4435      new_states = successive_states[-1]
4436      outputs = array_ops.stack(successive_outputs)
4437
4438  else:  # Unroll == False
4439    states = tuple(initial_states)
4440
4441    # Create input tensor array, if the inputs is nested tensors, then it will
4442    # be flattened first, and tensor array will be created one per flattened
4443    # tensor.
4444    input_ta = tuple(
4445        tensor_array_ops.TensorArray(
4446            dtype=inp.dtype,
4447            size=time_steps_t,
4448            tensor_array_name='input_ta_%s' % i)
4449        for i, inp in enumerate(flatted_inputs))
4450    input_ta = tuple(
4451        ta.unstack(input_) if not go_backwards else ta
4452        .unstack(reverse(input_, 0))
4453        for ta, input_ in zip(input_ta, flatted_inputs))
4454
4455    # Get the time(0) input and compute the output for that, the output will be
4456    # used to determine the dtype of output tensor array. Don't read from
4457    # input_ta due to TensorArray clear_after_read default to True.
4458    input_time_zero = nest.pack_sequence_as(inputs,
4459                                            [inp[0] for inp in flatted_inputs])
4460    # output_time_zero is used to determine the cell output shape and its dtype.
4461    # the value is discarded.
4462    output_time_zero, _ = step_function(
4463        input_time_zero, tuple(initial_states) + tuple(constants))
4464    output_ta = tuple(
4465        tensor_array_ops.TensorArray(
4466            dtype=out.dtype,
4467            size=time_steps_t,
4468            element_shape=out.shape,
4469            tensor_array_name='output_ta_%s' % i)
4470        for i, out in enumerate(nest.flatten(output_time_zero)))
4471
4472    time = constant_op.constant(0, dtype='int32', name='time')
4473
4474    # We only specify the 'maximum_iterations' when building for XLA since that
4475    # causes slowdowns on GPU in TF.
4476    if (not context.executing_eagerly() and
4477        control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph())):
4478      max_iterations = math_ops.reduce_max(input_length)
4479    else:
4480      max_iterations = None
4481
4482    while_loop_kwargs = {
4483        'cond': lambda time, *_: time < time_steps_t,
4484        'maximum_iterations': max_iterations,
4485        'parallel_iterations': 32,
4486        'swap_memory': True,
4487    }
4488    if mask is not None:
4489      if go_backwards:
4490        mask = reverse(mask, 0)
4491
4492      mask_ta = tensor_array_ops.TensorArray(
4493          dtype=dtypes_module.bool,
4494          size=time_steps_t,
4495          tensor_array_name='mask_ta')
4496      mask_ta = mask_ta.unstack(mask)
4497
4498      def masking_fn(time):
4499        return mask_ta.read(time)
4500
4501      def compute_masked_output(mask_t, flat_out, flat_mask):
4502        tiled_mask_t = tuple(
4503            _expand_mask(mask_t, o, fixed_dim=len(mask_t.shape))
4504            for o in flat_out)
4505        return tuple(
4506            array_ops.where_v2(m, o, fm)
4507            for m, o, fm in zip(tiled_mask_t, flat_out, flat_mask))
4508    elif isinstance(input_length, ops.Tensor):
4509      if go_backwards:
4510        max_len = math_ops.reduce_max(input_length, axis=0)
4511        rev_input_length = math_ops.subtract(max_len - 1, input_length)
4512
4513        def masking_fn(time):
4514          return math_ops.less(rev_input_length, time)
4515      else:
4516
4517        def masking_fn(time):
4518          return math_ops.greater(input_length, time)
4519
4520      def compute_masked_output(mask_t, flat_out, flat_mask):
4521        return tuple(
4522            array_ops.where(mask_t, o, zo)
4523            for (o, zo) in zip(flat_out, flat_mask))
4524    else:
4525      masking_fn = None
4526
4527    if masking_fn is not None:
4528      # Mask for the T output will be base on the output of T - 1. In the case
4529      # T = 0, a zero filled tensor will be used.
4530      flat_zero_output = tuple(array_ops.zeros_like(o)
4531                               for o in nest.flatten(output_time_zero))
4532      def _step(time, output_ta_t, prev_output, *states):
4533        """RNN step function.
4534
4535        Args:
4536            time: Current timestep value.
4537            output_ta_t: TensorArray.
4538            prev_output: tuple of outputs from time - 1.
4539            *states: List of states.
4540
4541        Returns:
4542            Tuple: `(time + 1, output_ta_t, output) + tuple(new_states)`
4543        """
4544        current_input = tuple(ta.read(time) for ta in input_ta)
4545        # maybe set shape.
4546        current_input = nest.pack_sequence_as(inputs, current_input)
4547        mask_t = masking_fn(time)
4548        output, new_states = step_function(current_input,
4549                                           tuple(states) + tuple(constants))
4550        # mask output
4551        flat_output = nest.flatten(output)
4552        flat_mask_output = (flat_zero_output if zero_output_for_mask
4553                            else nest.flatten(prev_output))
4554        flat_new_output = compute_masked_output(mask_t, flat_output,
4555                                                flat_mask_output)
4556
4557        # mask states
4558        flat_state = nest.flatten(states)
4559        flat_new_state = nest.flatten(new_states)
4560        for state, new_state in zip(flat_state, flat_new_state):
4561          if isinstance(new_state, ops.Tensor):
4562            new_state.set_shape(state.shape)
4563        flat_final_state = compute_masked_output(mask_t, flat_new_state,
4564                                                 flat_state)
4565        new_states = nest.pack_sequence_as(new_states, flat_final_state)
4566
4567        output_ta_t = tuple(
4568            ta.write(time, out)
4569            for ta, out in zip(output_ta_t, flat_new_output))
4570        return (time + 1, output_ta_t,
4571                tuple(flat_new_output)) + tuple(new_states)
4572
4573      final_outputs = control_flow_ops.while_loop(
4574          body=_step,
4575          loop_vars=(time, output_ta, flat_zero_output) + states,
4576          **while_loop_kwargs)
4577      # Skip final_outputs[2] which is the output for final timestep.
4578      new_states = final_outputs[3:]
4579    else:
4580      def _step(time, output_ta_t, *states):
4581        """RNN step function.
4582
4583        Args:
4584            time: Current timestep value.
4585            output_ta_t: TensorArray.
4586            *states: List of states.
4587
4588        Returns:
4589            Tuple: `(time + 1,output_ta_t) + tuple(new_states)`
4590        """
4591        current_input = tuple(ta.read(time) for ta in input_ta)
4592        current_input = nest.pack_sequence_as(inputs, current_input)
4593        output, new_states = step_function(current_input,
4594                                           tuple(states) + tuple(constants))
4595        flat_state = nest.flatten(states)
4596        flat_new_state = nest.flatten(new_states)
4597        for state, new_state in zip(flat_state, flat_new_state):
4598          if isinstance(new_state, ops.Tensor):
4599            new_state.set_shape(state.shape)
4600
4601        flat_output = nest.flatten(output)
4602        output_ta_t = tuple(
4603            ta.write(time, out) for ta, out in zip(output_ta_t, flat_output))
4604        new_states = nest.pack_sequence_as(initial_states, flat_new_state)
4605        return (time + 1, output_ta_t) + tuple(new_states)
4606
4607      final_outputs = control_flow_ops.while_loop(
4608          body=_step,
4609          loop_vars=(time, output_ta) + states,
4610          **while_loop_kwargs)
4611      new_states = final_outputs[2:]
4612
4613    output_ta = final_outputs[1]
4614
4615    outputs = tuple(o.stack() for o in output_ta)
4616    last_output = tuple(o[-1] for o in outputs)
4617
4618    outputs = nest.pack_sequence_as(output_time_zero, outputs)
4619    last_output = nest.pack_sequence_as(output_time_zero, last_output)
4620
4621  # static shape inference
4622  def set_shape(output_):
4623    if isinstance(output_, ops.Tensor):
4624      shape = output_.shape.as_list()
4625      shape[0] = time_steps
4626      shape[1] = batch
4627      output_.set_shape(shape)
4628    return output_
4629
4630  outputs = nest.map_structure(set_shape, outputs)
4631
4632  if not time_major:
4633    outputs = nest.map_structure(swap_batch_timestep, outputs)
4634
4635  return last_output, outputs, new_states
4636
4637
4638@keras_export('keras.backend.switch')
4639@dispatch.add_dispatch_support
4640@doc_controls.do_not_generate_docs
4641def switch(condition, then_expression, else_expression):
4642  """Switches between two operations depending on a scalar value.
4643
4644  Note that both `then_expression` and `else_expression`
4645  should be symbolic tensors of the *same shape*.
4646
4647  Args:
4648      condition: tensor (`int` or `bool`).
4649      then_expression: either a tensor, or a callable that returns a tensor.
4650      else_expression: either a tensor, or a callable that returns a tensor.
4651
4652  Returns:
4653      The selected tensor.
4654
4655  Raises:
4656      ValueError: If rank of `condition` is greater than rank of expressions.
4657  """
4658  if condition.dtype != dtypes_module.bool:
4659    condition = math_ops.cast(condition, 'bool')
4660  cond_ndim = ndim(condition)
4661  if not cond_ndim:
4662    if not callable(then_expression):
4663
4664      def then_expression_fn():
4665        return then_expression
4666    else:
4667      then_expression_fn = then_expression
4668    if not callable(else_expression):
4669
4670      def else_expression_fn():
4671        return else_expression
4672    else:
4673      else_expression_fn = else_expression
4674    x = control_flow_ops.cond(condition, then_expression_fn, else_expression_fn)
4675  else:
4676    # tf.where needs its condition tensor
4677    # to be the same shape as its two
4678    # result tensors
4679    if callable(then_expression):
4680      then_expression = then_expression()
4681    if callable(else_expression):
4682      else_expression = else_expression()
4683    expr_ndim = ndim(then_expression)
4684    if cond_ndim > expr_ndim:
4685      raise ValueError('Rank of `condition` should be less than or'
4686                       ' equal to rank of `then_expression` and '
4687                       '`else_expression`. ndim(condition)=' + str(cond_ndim) +
4688                       ', ndim(then_expression)'
4689                       '=' + str(expr_ndim))
4690    if cond_ndim > 1:
4691      ndim_diff = expr_ndim - cond_ndim
4692      cond_shape = array_ops.concat(
4693          [array_ops.shape(condition), [1] * ndim_diff], axis=0)
4694      condition = array_ops.reshape(condition, cond_shape)
4695      expr_shape = array_ops.shape(then_expression)
4696      shape_diff = expr_shape - cond_shape
4697      tile_shape = array_ops.where_v2(shape_diff > 0, expr_shape,
4698                                      array_ops.ones_like(expr_shape))
4699      condition = array_ops.tile(condition, tile_shape)
4700    x = array_ops.where_v2(condition, then_expression, else_expression)
4701  return x
4702
4703
4704@keras_export('keras.backend.in_train_phase')
4705@doc_controls.do_not_generate_docs
4706def in_train_phase(x, alt, training=None):
4707  """Selects `x` in train phase, and `alt` otherwise.
4708
4709  Note that `alt` should have the *same shape* as `x`.
4710
4711  Args:
4712      x: What to return in train phase
4713          (tensor or callable that returns a tensor).
4714      alt: What to return otherwise
4715          (tensor or callable that returns a tensor).
4716      training: Optional scalar tensor
4717          (or Python boolean, or Python integer)
4718          specifying the learning phase.
4719
4720  Returns:
4721      Either `x` or `alt` based on the `training` flag.
4722      the `training` flag defaults to `K.learning_phase()`.
4723  """
4724  from tensorflow.python.keras.engine import base_layer_utils  # pylint: disable=g-import-not-at-top
4725  if training is None:
4726    training = base_layer_utils.call_context().training
4727
4728  if training is None:
4729    training = learning_phase()
4730
4731  # TODO(b/138862903): Handle the case when training is tensor.
4732  if not tensor_util.is_tf_type(training):
4733    if training == 1 or training is True:
4734      if callable(x):
4735        return x()
4736      else:
4737        return x
4738
4739    elif training == 0 or training is False:
4740      if callable(alt):
4741        return alt()
4742      else:
4743        return alt
4744
4745  # else: assume learning phase is a placeholder tensor.
4746  x = switch(training, x, alt)
4747  return x
4748
4749
4750@keras_export('keras.backend.in_test_phase')
4751@doc_controls.do_not_generate_docs
4752def in_test_phase(x, alt, training=None):
4753  """Selects `x` in test phase, and `alt` otherwise.
4754
4755  Note that `alt` should have the *same shape* as `x`.
4756
4757  Args:
4758      x: What to return in test phase
4759          (tensor or callable that returns a tensor).
4760      alt: What to return otherwise
4761          (tensor or callable that returns a tensor).
4762      training: Optional scalar tensor
4763          (or Python boolean, or Python integer)
4764          specifying the learning phase.
4765
4766  Returns:
4767      Either `x` or `alt` based on `K.learning_phase`.
4768  """
4769  return in_train_phase(alt, x, training=training)
4770
4771
4772# NN OPERATIONS
4773
4774
4775@keras_export('keras.backend.relu')
4776@dispatch.add_dispatch_support
4777@doc_controls.do_not_generate_docs
4778def relu(x, alpha=0., max_value=None, threshold=0):
4779  """Rectified linear unit.
4780
4781  With default values, it returns element-wise `max(x, 0)`.
4782
4783  Otherwise, it follows:
4784  `f(x) = max_value` for `x >= max_value`,
4785  `f(x) = x` for `threshold <= x < max_value`,
4786  `f(x) = alpha * (x - threshold)` otherwise.
4787
4788  Args:
4789      x: A tensor or variable.
4790      alpha: A scalar, slope of negative section (default=`0.`).
4791      max_value: float. Saturation threshold.
4792      threshold: float. Threshold value for thresholded activation.
4793
4794  Returns:
4795      A tensor.
4796  """
4797  # While x can be a tensor or variable, we also see cases where
4798  # numpy arrays, lists, tuples are passed as well.
4799  # lists, tuples do not have 'dtype' attribute.
4800  dtype = getattr(x, 'dtype', floatx())
4801  if alpha != 0.:
4802    if max_value is None and threshold == 0:
4803      return nn.leaky_relu(x, alpha=alpha)
4804
4805    if threshold != 0:
4806      negative_part = nn.relu(-x + threshold)
4807    else:
4808      negative_part = nn.relu(-x)
4809
4810  clip_max = max_value is not None
4811
4812  if threshold != 0:
4813    # computes x for x > threshold else 0
4814    x = x * math_ops.cast(math_ops.greater(x, threshold), dtype=dtype)
4815  elif max_value == 6:
4816    # if no threshold, then can use nn.relu6 native TF op for performance
4817    x = nn.relu6(x)
4818    clip_max = False
4819  else:
4820    x = nn.relu(x)
4821
4822  if clip_max:
4823    max_value = _constant_to_tensor(max_value, x.dtype.base_dtype)
4824    zero = _constant_to_tensor(0, x.dtype.base_dtype)
4825    x = clip_ops.clip_by_value(x, zero, max_value)
4826
4827  if alpha != 0.:
4828    alpha = _to_tensor(alpha, x.dtype.base_dtype)
4829    x -= alpha * negative_part
4830  return x
4831
4832
4833@keras_export('keras.backend.elu')
4834@dispatch.add_dispatch_support
4835@doc_controls.do_not_generate_docs
4836def elu(x, alpha=1.):
4837  """Exponential linear unit.
4838
4839  Args:
4840      x: A tensor or variable to compute the activation function for.
4841      alpha: A scalar, slope of negative section.
4842
4843  Returns:
4844      A tensor.
4845  """
4846  res = nn.elu(x)
4847  if alpha == 1:
4848    return res
4849  else:
4850    return array_ops.where_v2(x > 0, res, alpha * res)
4851
4852
4853@keras_export('keras.backend.softmax')
4854@dispatch.add_dispatch_support
4855@doc_controls.do_not_generate_docs
4856def softmax(x, axis=-1):
4857  """Softmax of a tensor.
4858
4859  Args:
4860      x: A tensor or variable.
4861      axis: The dimension softmax would be performed on.
4862          The default is -1 which indicates the last dimension.
4863
4864  Returns:
4865      A tensor.
4866  """
4867  return nn.softmax(x, axis=axis)
4868
4869
4870@keras_export('keras.backend.softplus')
4871@dispatch.add_dispatch_support
4872@doc_controls.do_not_generate_docs
4873def softplus(x):
4874  """Softplus of a tensor.
4875
4876  Args:
4877      x: A tensor or variable.
4878
4879  Returns:
4880      A tensor.
4881  """
4882  return nn.softplus(x)
4883
4884
4885@keras_export('keras.backend.softsign')
4886@dispatch.add_dispatch_support
4887@doc_controls.do_not_generate_docs
4888def softsign(x):
4889  """Softsign of a tensor.
4890
4891  Args:
4892      x: A tensor or variable.
4893
4894  Returns:
4895      A tensor.
4896  """
4897  return nn.softsign(x)
4898
4899
4900@keras_export('keras.backend.categorical_crossentropy')
4901@dispatch.add_dispatch_support
4902@doc_controls.do_not_generate_docs
4903def categorical_crossentropy(target, output, from_logits=False, axis=-1):
4904  """Categorical crossentropy between an output tensor and a target tensor.
4905
4906  Args:
4907      target: A tensor of the same shape as `output`.
4908      output: A tensor resulting from a softmax
4909          (unless `from_logits` is True, in which
4910          case `output` is expected to be the logits).
4911      from_logits: Boolean, whether `output` is the
4912          result of a softmax, or is a tensor of logits.
4913      axis: Int specifying the channels axis. `axis=-1` corresponds to data
4914          format `channels_last`, and `axis=1` corresponds to data format
4915          `channels_first`.
4916
4917  Returns:
4918      Output tensor.
4919
4920  Raises:
4921      ValueError: if `axis` is neither -1 nor one of the axes of `output`.
4922
4923  Example:
4924
4925  >>> a = tf.constant([1., 0., 0., 0., 1., 0., 0., 0., 1.], shape=[3,3])
4926  >>> print(a)
4927  tf.Tensor(
4928    [[1. 0. 0.]
4929     [0. 1. 0.]
4930     [0. 0. 1.]], shape=(3, 3), dtype=float32)
4931  >>> b = tf.constant([.9, .05, .05, .05, .89, .06, .05, .01, .94], shape=[3,3])
4932  >>> print(b)
4933  tf.Tensor(
4934    [[0.9  0.05 0.05]
4935     [0.05 0.89 0.06]
4936     [0.05 0.01 0.94]], shape=(3, 3), dtype=float32)
4937  >>> loss = tf.keras.backend.categorical_crossentropy(a, b)
4938  >>> print(np.around(loss, 5))
4939  [0.10536 0.11653 0.06188]
4940  >>> loss = tf.keras.backend.categorical_crossentropy(a, a)
4941  >>> print(np.around(loss, 5))
4942  [0. 0. 0.]
4943
4944  """
4945  target = ops.convert_to_tensor_v2_with_dispatch(target)
4946  output = ops.convert_to_tensor_v2_with_dispatch(output)
4947  target.shape.assert_is_compatible_with(output.shape)
4948
4949  # Use logits whenever they are available. `softmax` and `sigmoid`
4950  # activations cache logits on the `output` Tensor.
4951  if hasattr(output, '_keras_logits'):
4952    output = output._keras_logits  # pylint: disable=protected-access
4953    if from_logits:
4954      warnings.warn(
4955          '"`categorical_crossentropy` received `from_logits=True`, but '
4956          'the `output` argument was produced by a sigmoid or softmax '
4957          'activation and thus does not represent logits. Was this intended?"')
4958    from_logits = True
4959
4960  if from_logits:
4961    return nn.softmax_cross_entropy_with_logits_v2(
4962        labels=target, logits=output, axis=axis)
4963
4964  if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
4965      output.op.type == 'Softmax') and not hasattr(output, '_keras_history'):
4966    # When softmax activation function is used for output operation, we
4967    # use logits from the softmax function directly to compute loss in order
4968    # to prevent collapsing zero when training.
4969    # See b/117284466
4970    assert len(output.op.inputs) == 1
4971    output = output.op.inputs[0]
4972    return nn.softmax_cross_entropy_with_logits_v2(
4973        labels=target, logits=output, axis=axis)
4974
4975  # scale preds so that the class probas of each sample sum to 1
4976  output = output / math_ops.reduce_sum(output, axis, True)
4977  # Compute cross entropy from probabilities.
4978  epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
4979  output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
4980  return -math_ops.reduce_sum(target * math_ops.log(output), axis)
4981
4982
4983@keras_export('keras.backend.sparse_categorical_crossentropy')
4984@dispatch.add_dispatch_support
4985@doc_controls.do_not_generate_docs
4986def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
4987  """Categorical crossentropy with integer targets.
4988
4989  Args:
4990      target: An integer tensor.
4991      output: A tensor resulting from a softmax
4992          (unless `from_logits` is True, in which
4993          case `output` is expected to be the logits).
4994      from_logits: Boolean, whether `output` is the
4995          result of a softmax, or is a tensor of logits.
4996      axis: Int specifying the channels axis. `axis=-1` corresponds to data
4997          format `channels_last`, and `axis=1` corresponds to data format
4998          `channels_first`.
4999
5000  Returns:
5001      Output tensor.
5002
5003  Raises:
5004      ValueError: if `axis` is neither -1 nor one of the axes of `output`.
5005  """
5006  target = ops.convert_to_tensor_v2_with_dispatch(target)
5007  output = ops.convert_to_tensor_v2_with_dispatch(output)
5008
5009  # Use logits whenever they are available. `softmax` and `sigmoid`
5010  # activations cache logits on the `output` Tensor.
5011  if hasattr(output, '_keras_logits'):
5012    output = output._keras_logits  # pylint: disable=protected-access
5013    if from_logits:
5014      warnings.warn(
5015          '"`sparse_categorical_crossentropy` received `from_logits=True`, but '
5016          'the `output` argument was produced by a sigmoid or softmax '
5017          'activation and thus does not represent logits. Was this intended?"')
5018    from_logits = True
5019  elif (not from_logits and
5020        not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
5021        output.op.type == 'Softmax') and not hasattr(output, '_keras_history'):
5022    # When softmax activation function is used for output operation, we
5023    # use logits from the softmax function directly to compute loss in order
5024    # to prevent collapsing zero when training.
5025    # See b/117284466
5026    assert len(output.op.inputs) == 1
5027    output = output.op.inputs[0]
5028    from_logits = True
5029  elif not from_logits:
5030    epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
5031    output = clip_ops.clip_by_value(output, epsilon_, 1 - epsilon_)
5032    output = math_ops.log(output)
5033
5034  if isinstance(output.shape, (tuple, list)):
5035    output_rank = len(output.shape)
5036  else:
5037    output_rank = output.shape.ndims
5038  if output_rank is not None:
5039    axis %= output_rank
5040    if axis != output_rank - 1:
5041      permutation = list(
5042          itertools.chain(range(axis), range(axis + 1, output_rank), [axis]))
5043      output = array_ops.transpose(output, perm=permutation)
5044  elif axis != -1:
5045    raise ValueError(
5046        'Cannot compute sparse categorical crossentropy with `axis={}` on an '
5047        'output tensor with unknown rank'.format(axis))
5048
5049  target = cast(target, 'int64')
5050
5051  # Try to adjust the shape so that rank of labels = rank of logits - 1.
5052  output_shape = array_ops.shape_v2(output)
5053  target_rank = target.shape.ndims
5054
5055  update_shape = (
5056      target_rank is not None and output_rank is not None and
5057      target_rank != output_rank - 1)
5058  if update_shape:
5059    target = flatten(target)
5060    output = array_ops.reshape(output, [-1, output_shape[-1]])
5061
5062  if py_any(_is_symbolic_tensor(v) for v in [target, output]):
5063    with get_graph().as_default():
5064      res = nn.sparse_softmax_cross_entropy_with_logits_v2(
5065          labels=target, logits=output)
5066  else:
5067    res = nn.sparse_softmax_cross_entropy_with_logits_v2(
5068        labels=target, logits=output)
5069
5070  if update_shape and output_rank >= 3:
5071    # If our output includes timesteps or spatial dimensions we need to reshape
5072    return array_ops.reshape(res, output_shape[:-1])
5073  else:
5074    return res
5075
5076
5077@keras_export('keras.backend.binary_crossentropy')
5078@dispatch.add_dispatch_support
5079@doc_controls.do_not_generate_docs
5080def binary_crossentropy(target, output, from_logits=False):
5081  """Binary crossentropy between an output tensor and a target tensor.
5082
5083  Args:
5084      target: A tensor with the same shape as `output`.
5085      output: A tensor.
5086      from_logits: Whether `output` is expected to be a logits tensor.
5087          By default, we consider that `output`
5088          encodes a probability distribution.
5089
5090  Returns:
5091      A tensor.
5092  """
5093  target = ops.convert_to_tensor_v2_with_dispatch(target)
5094  output = ops.convert_to_tensor_v2_with_dispatch(output)
5095
5096  # Use logits whenever they are available. `softmax` and `sigmoid`
5097  # activations cache logits on the `output` Tensor.
5098  if hasattr(output, '_keras_logits'):
5099    output = output._keras_logits  # pylint: disable=protected-access
5100    if from_logits:
5101      warnings.warn(
5102          '"`binary_crossentropy` received `from_logits=True`, but the `output`'
5103          ' argument was produced by a sigmoid or softmax activation and thus '
5104          'does not represent logits. Was this intended?"')
5105    from_logits = True
5106
5107  if from_logits:
5108    return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
5109
5110  if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
5111      output.op.type == 'Sigmoid') and not hasattr(output, '_keras_history'):
5112    # When sigmoid activation function is used for output operation, we
5113    # use logits from the sigmoid function directly to compute loss in order
5114    # to prevent collapsing zero when training.
5115    assert len(output.op.inputs) == 1
5116    output = output.op.inputs[0]
5117    return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
5118
5119  epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
5120  output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
5121
5122  # Compute cross entropy from probabilities.
5123  bce = target * math_ops.log(output + epsilon())
5124  bce += (1 - target) * math_ops.log(1 - output + epsilon())
5125  return -bce
5126
5127
5128@keras_export('keras.backend.sigmoid')
5129@dispatch.add_dispatch_support
5130@doc_controls.do_not_generate_docs
5131def sigmoid(x):
5132  """Element-wise sigmoid.
5133
5134  Args:
5135      x: A tensor or variable.
5136
5137  Returns:
5138      A tensor.
5139  """
5140  return nn.sigmoid(x)
5141
5142
5143@keras_export('keras.backend.hard_sigmoid')
5144@dispatch.add_dispatch_support
5145@doc_controls.do_not_generate_docs
5146def hard_sigmoid(x):
5147  """Segment-wise linear approximation of sigmoid.
5148
5149  Faster than sigmoid.
5150  Returns `0.` if `x < -2.5`, `1.` if `x > 2.5`.
5151  In `-2.5 <= x <= 2.5`, returns `0.2 * x + 0.5`.
5152
5153  Args:
5154      x: A tensor or variable.
5155
5156  Returns:
5157      A tensor.
5158  """
5159  point_two = _constant_to_tensor(0.2, x.dtype.base_dtype)
5160  point_five = _constant_to_tensor(0.5, x.dtype.base_dtype)
5161  x = math_ops.multiply(x, point_two)
5162  x = math_ops.add(x, point_five)
5163  x = clip_ops.clip_by_value(x, 0., 1.)
5164  return x
5165
5166
5167@keras_export('keras.backend.tanh')
5168@dispatch.add_dispatch_support
5169@doc_controls.do_not_generate_docs
5170def tanh(x):
5171  """Element-wise tanh.
5172
5173  Args:
5174      x: A tensor or variable.
5175
5176  Returns:
5177      A tensor.
5178  """
5179  return nn.tanh(x)
5180
5181
5182@keras_export('keras.backend.dropout')
5183@dispatch.add_dispatch_support
5184@doc_controls.do_not_generate_docs
5185def dropout(x, level, noise_shape=None, seed=None):
5186  """Sets entries in `x` to zero at random, while scaling the entire tensor.
5187
5188  Args:
5189      x: tensor
5190      level: fraction of the entries in the tensor
5191          that will be set to 0.
5192      noise_shape: shape for randomly generated keep/drop flags,
5193          must be broadcastable to the shape of `x`
5194      seed: random seed to ensure determinism.
5195
5196  Returns:
5197      A tensor.
5198  """
5199  if seed is None:
5200    seed = np.random.randint(10e6)
5201  return nn.dropout_v2(x, rate=level, noise_shape=noise_shape, seed=seed)
5202
5203
5204@keras_export('keras.backend.l2_normalize')
5205@dispatch.add_dispatch_support
5206@doc_controls.do_not_generate_docs
5207def l2_normalize(x, axis=None):
5208  """Normalizes a tensor wrt the L2 norm alongside the specified axis.
5209
5210  Args:
5211      x: Tensor or variable.
5212      axis: axis along which to perform normalization.
5213
5214  Returns:
5215      A tensor.
5216  """
5217  return nn.l2_normalize(x, axis=axis)
5218
5219
5220@keras_export('keras.backend.in_top_k')
5221@dispatch.add_dispatch_support
5222@doc_controls.do_not_generate_docs
5223def in_top_k(predictions, targets, k):
5224  """Returns whether the `targets` are in the top `k` `predictions`.
5225
5226  Args:
5227      predictions: A tensor of shape `(batch_size, classes)` and type `float32`.
5228      targets: A 1D tensor of length `batch_size` and type `int32` or `int64`.
5229      k: An `int`, number of top elements to consider.
5230
5231  Returns:
5232      A 1D tensor of length `batch_size` and type `bool`.
5233      `output[i]` is `True` if `predictions[i, targets[i]]` is within top-`k`
5234      values of `predictions[i]`.
5235  """
5236  return nn.in_top_k(predictions, targets, k)
5237
5238
5239# CONVOLUTIONS
5240
5241
5242def _preprocess_conv1d_input(x, data_format):
5243  """Transpose and cast the input before the conv1d.
5244
5245  Args:
5246      x: input tensor.
5247      data_format: string, `"channels_last"` or `"channels_first"`.
5248
5249  Returns:
5250      A tensor.
5251  """
5252  tf_data_format = 'NWC'  # to pass TF Conv2dNative operations
5253  if data_format == 'channels_first':
5254    if not _has_nchw_support():
5255      x = array_ops.transpose(x, (0, 2, 1))  # NCW -> NWC
5256    else:
5257      tf_data_format = 'NCW'
5258  return x, tf_data_format
5259
5260
5261def _preprocess_conv2d_input(x, data_format, force_transpose=False):
5262  """Transpose and cast the input before the conv2d.
5263
5264  Args:
5265      x: input tensor.
5266      data_format: string, `"channels_last"` or `"channels_first"`.
5267      force_transpose: Boolean. If True, the input will always be transposed
5268          from NCHW to NHWC if `data_format` is `"channels_first"`.
5269          If False, the transposition only occurs on CPU (GPU ops are
5270          assumed to support NCHW).
5271
5272  Returns:
5273      A tensor.
5274  """
5275  tf_data_format = 'NHWC'
5276  if data_format == 'channels_first':
5277    if not _has_nchw_support() or force_transpose:
5278      x = array_ops.transpose(x, (0, 2, 3, 1))  # NCHW -> NHWC
5279    else:
5280      tf_data_format = 'NCHW'
5281  return x, tf_data_format
5282
5283
5284def _preprocess_conv3d_input(x, data_format):
5285  """Transpose and cast the input before the conv3d.
5286
5287  Args:
5288      x: input tensor.
5289      data_format: string, `"channels_last"` or `"channels_first"`.
5290
5291  Returns:
5292      A tensor.
5293  """
5294  tf_data_format = 'NDHWC'
5295  if data_format == 'channels_first':
5296    if not _has_nchw_support():
5297      x = array_ops.transpose(x, (0, 2, 3, 4, 1))
5298    else:
5299      tf_data_format = 'NCDHW'
5300  return x, tf_data_format
5301
5302
5303def _preprocess_padding(padding):
5304  """Convert keras' padding to TensorFlow's padding.
5305
5306  Args:
5307      padding: string, one of 'same' , 'valid'
5308
5309  Returns:
5310      a string, one of 'SAME', 'VALID'.
5311
5312  Raises:
5313      ValueError: if invalid `padding'`
5314  """
5315  if padding == 'same':
5316    padding = 'SAME'
5317  elif padding == 'valid':
5318    padding = 'VALID'
5319  else:
5320    raise ValueError('Invalid padding: ' + str(padding))
5321  return padding
5322
5323
5324@keras_export('keras.backend.conv1d')
5325@dispatch.add_dispatch_support
5326@doc_controls.do_not_generate_docs
5327def conv1d(x,
5328           kernel,
5329           strides=1,
5330           padding='valid',
5331           data_format=None,
5332           dilation_rate=1):
5333  """1D convolution.
5334
5335  Args:
5336      x: Tensor or variable.
5337      kernel: kernel tensor.
5338      strides: stride integer.
5339      padding: string, `"same"`, `"causal"` or `"valid"`.
5340      data_format: string, one of "channels_last", "channels_first".
5341      dilation_rate: integer dilate rate.
5342
5343  Returns:
5344      A tensor, result of 1D convolution.
5345
5346  Raises:
5347      ValueError: if `data_format` is neither `channels_last` or
5348      `channels_first`.
5349  """
5350  if data_format is None:
5351    data_format = image_data_format()
5352  if data_format not in {'channels_first', 'channels_last'}:
5353    raise ValueError('Unknown data_format: ' + str(data_format))
5354
5355  kernel_shape = kernel.shape.as_list()
5356  if padding == 'causal':
5357    # causal (dilated) convolution:
5358    left_pad = dilation_rate * (kernel_shape[0] - 1)
5359    x = temporal_padding(x, (left_pad, 0))
5360    padding = 'valid'
5361  padding = _preprocess_padding(padding)
5362
5363  x, tf_data_format = _preprocess_conv1d_input(x, data_format)
5364  x = nn.convolution(
5365      input=x,
5366      filter=kernel,
5367      dilation_rate=dilation_rate,
5368      strides=strides,
5369      padding=padding,
5370      data_format=tf_data_format)
5371  if data_format == 'channels_first' and tf_data_format == 'NWC':
5372    x = array_ops.transpose(x, (0, 2, 1))  # NWC -> NCW
5373  return x
5374
5375
5376@keras_export('keras.backend.conv2d')
5377@dispatch.add_dispatch_support
5378@doc_controls.do_not_generate_docs
5379def conv2d(x,
5380           kernel,
5381           strides=(1, 1),
5382           padding='valid',
5383           data_format=None,
5384           dilation_rate=(1, 1)):
5385  """2D convolution.
5386
5387  Args:
5388      x: Tensor or variable.
5389      kernel: kernel tensor.
5390      strides: strides tuple.
5391      padding: string, `"same"` or `"valid"`.
5392      data_format: `"channels_last"` or `"channels_first"`.
5393      dilation_rate: tuple of 2 integers.
5394
5395  Returns:
5396      A tensor, result of 2D convolution.
5397
5398  Raises:
5399      ValueError: if `data_format` is neither `channels_last` or
5400      `channels_first`.
5401  """
5402  if data_format is None:
5403    data_format = image_data_format()
5404  if data_format not in {'channels_first', 'channels_last'}:
5405    raise ValueError('Unknown data_format: ' + str(data_format))
5406
5407  x, tf_data_format = _preprocess_conv2d_input(x, data_format)
5408  padding = _preprocess_padding(padding)
5409  x = nn.convolution(
5410      input=x,
5411      filter=kernel,
5412      dilation_rate=dilation_rate,
5413      strides=strides,
5414      padding=padding,
5415      data_format=tf_data_format)
5416  if data_format == 'channels_first' and tf_data_format == 'NHWC':
5417    x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
5418  return x
5419
5420
5421@keras_export('keras.backend.conv2d_transpose')
5422@dispatch.add_dispatch_support
5423@doc_controls.do_not_generate_docs
5424def conv2d_transpose(x,
5425                     kernel,
5426                     output_shape,
5427                     strides=(1, 1),
5428                     padding='valid',
5429                     data_format=None,
5430                     dilation_rate=(1, 1)):
5431  """2D deconvolution (i.e.
5432
5433  transposed convolution).
5434
5435  Args:
5436      x: Tensor or variable.
5437      kernel: kernel tensor.
5438      output_shape: 1D int tensor for the output shape.
5439      strides: strides tuple.
5440      padding: string, `"same"` or `"valid"`.
5441      data_format: string, `"channels_last"` or `"channels_first"`.
5442      dilation_rate: Tuple of 2 integers.
5443
5444  Returns:
5445      A tensor, result of transposed 2D convolution.
5446
5447  Raises:
5448      ValueError: if `data_format` is neither `channels_last` or
5449      `channels_first`.
5450  """
5451  if data_format is None:
5452    data_format = image_data_format()
5453  if data_format not in {'channels_first', 'channels_last'}:
5454    raise ValueError('Unknown data_format: ' + str(data_format))
5455
5456  # `atrous_conv2d_transpose` only supports NHWC format, even on GPU.
5457  if data_format == 'channels_first' and dilation_rate != (1, 1):
5458    force_transpose = True
5459  else:
5460    force_transpose = False
5461
5462  x, tf_data_format = _preprocess_conv2d_input(x, data_format, force_transpose)
5463
5464  if data_format == 'channels_first' and tf_data_format == 'NHWC':
5465    output_shape = (output_shape[0], output_shape[2], output_shape[3],
5466                    output_shape[1])
5467  if output_shape[0] is None:
5468    output_shape = (shape(x)[0],) + tuple(output_shape[1:])
5469
5470  if isinstance(output_shape, (tuple, list)):
5471    output_shape = array_ops.stack(list(output_shape))
5472
5473  padding = _preprocess_padding(padding)
5474  if tf_data_format == 'NHWC':
5475    strides = (1,) + strides + (1,)
5476  else:
5477    strides = (1, 1) + strides
5478
5479  if dilation_rate == (1, 1):
5480    x = nn.conv2d_transpose(x, kernel, output_shape, strides,
5481                            padding=padding,
5482                            data_format=tf_data_format)
5483  else:
5484    assert dilation_rate[0] == dilation_rate[1]
5485    x = nn.atrous_conv2d_transpose(
5486        x,
5487        kernel,
5488        output_shape,
5489        rate=dilation_rate[0],
5490        padding=padding)
5491  if data_format == 'channels_first' and tf_data_format == 'NHWC':
5492    x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
5493  return x
5494
5495
5496def separable_conv1d(x,
5497                     depthwise_kernel,
5498                     pointwise_kernel,
5499                     strides=1,
5500                     padding='valid',
5501                     data_format=None,
5502                     dilation_rate=1):
5503  """1D convolution with separable filters.
5504
5505  Args:
5506      x: input tensor
5507      depthwise_kernel: convolution kernel for the depthwise convolution.
5508      pointwise_kernel: kernel for the 1x1 convolution.
5509      strides: stride integer.
5510      padding: string, `"same"` or `"valid"`.
5511      data_format: string, `"channels_last"` or `"channels_first"`.
5512      dilation_rate: integer dilation rate.
5513
5514  Returns:
5515      Output tensor.
5516
5517  Raises:
5518      ValueError: if `data_format` is neither `channels_last` or
5519      `channels_first`.
5520  """
5521  if data_format is None:
5522    data_format = image_data_format()
5523  if data_format not in {'channels_first', 'channels_last'}:
5524    raise ValueError('Unknown data_format: ' + str(data_format))
5525
5526  if isinstance(strides, int):
5527    strides = (strides,)
5528  if isinstance(dilation_rate, int):
5529    dilation_rate = (dilation_rate,)
5530
5531  x, tf_data_format = _preprocess_conv1d_input(x, data_format)
5532  padding = _preprocess_padding(padding)
5533  if not isinstance(strides, tuple):
5534    strides = tuple(strides)
5535  if tf_data_format == 'NWC':
5536    spatial_start_dim = 1
5537    strides = (1,) + strides * 2 + (1,)
5538  else:
5539    spatial_start_dim = 2
5540    strides = (1, 1) + strides * 2
5541  x = array_ops.expand_dims(x, spatial_start_dim)
5542  depthwise_kernel = array_ops.expand_dims(depthwise_kernel, 0)
5543  pointwise_kernel = array_ops.expand_dims(pointwise_kernel, 0)
5544  dilation_rate = (1,) + dilation_rate
5545
5546  x = nn.separable_conv2d(
5547      x,
5548      depthwise_kernel,
5549      pointwise_kernel,
5550      strides=strides,
5551      padding=padding,
5552      rate=dilation_rate,
5553      data_format=tf_data_format)
5554
5555  x = array_ops.squeeze(x, [spatial_start_dim])
5556
5557  if data_format == 'channels_first' and tf_data_format == 'NWC':
5558    x = array_ops.transpose(x, (0, 2, 1))  # NWC -> NCW
5559
5560  return x
5561
5562
5563@keras_export('keras.backend.separable_conv2d')
5564@dispatch.add_dispatch_support
5565@doc_controls.do_not_generate_docs
5566def separable_conv2d(x,
5567                     depthwise_kernel,
5568                     pointwise_kernel,
5569                     strides=(1, 1),
5570                     padding='valid',
5571                     data_format=None,
5572                     dilation_rate=(1, 1)):
5573  """2D convolution with separable filters.
5574
5575  Args:
5576      x: input tensor
5577      depthwise_kernel: convolution kernel for the depthwise convolution.
5578      pointwise_kernel: kernel for the 1x1 convolution.
5579      strides: strides tuple (length 2).
5580      padding: string, `"same"` or `"valid"`.
5581      data_format: string, `"channels_last"` or `"channels_first"`.
5582      dilation_rate: tuple of integers,
5583          dilation rates for the separable convolution.
5584
5585  Returns:
5586      Output tensor.
5587
5588  Raises:
5589      ValueError: if `data_format` is neither `channels_last` or
5590      `channels_first`.
5591      ValueError: if `strides` is not a tuple of 2 integers.
5592  """
5593  if data_format is None:
5594    data_format = image_data_format()
5595  if data_format not in {'channels_first', 'channels_last'}:
5596    raise ValueError('Unknown data_format: ' + str(data_format))
5597  if len(strides) != 2:
5598    raise ValueError('`strides` must be a tuple of 2 integers.')
5599
5600  x, tf_data_format = _preprocess_conv2d_input(x, data_format)
5601  padding = _preprocess_padding(padding)
5602  if not isinstance(strides, tuple):
5603    strides = tuple(strides)
5604  if tf_data_format == 'NHWC':
5605    strides = (1,) + strides + (1,)
5606  else:
5607    strides = (1, 1) + strides
5608
5609  x = nn.separable_conv2d(
5610      x,
5611      depthwise_kernel,
5612      pointwise_kernel,
5613      strides=strides,
5614      padding=padding,
5615      rate=dilation_rate,
5616      data_format=tf_data_format)
5617  if data_format == 'channels_first' and tf_data_format == 'NHWC':
5618    x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
5619  return x
5620
5621
5622@keras_export('keras.backend.depthwise_conv2d')
5623@dispatch.add_dispatch_support
5624@doc_controls.do_not_generate_docs
5625def depthwise_conv2d(x,
5626                     depthwise_kernel,
5627                     strides=(1, 1),
5628                     padding='valid',
5629                     data_format=None,
5630                     dilation_rate=(1, 1)):
5631  """2D convolution with separable filters.
5632
5633  Args:
5634      x: input tensor
5635      depthwise_kernel: convolution kernel for the depthwise convolution.
5636      strides: strides tuple (length 2).
5637      padding: string, `"same"` or `"valid"`.
5638      data_format: string, `"channels_last"` or `"channels_first"`.
5639      dilation_rate: tuple of integers,
5640          dilation rates for the separable convolution.
5641
5642  Returns:
5643      Output tensor.
5644
5645  Raises:
5646      ValueError: if `data_format` is neither `channels_last` or
5647      `channels_first`.
5648  """
5649  if data_format is None:
5650    data_format = image_data_format()
5651  if data_format not in {'channels_first', 'channels_last'}:
5652    raise ValueError('Unknown data_format: ' + str(data_format))
5653
5654  x, tf_data_format = _preprocess_conv2d_input(x, data_format)
5655  padding = _preprocess_padding(padding)
5656  if tf_data_format == 'NHWC':
5657    strides = (1,) + strides + (1,)
5658  else:
5659    strides = (1, 1) + strides
5660
5661  x = nn.depthwise_conv2d(
5662      x,
5663      depthwise_kernel,
5664      strides=strides,
5665      padding=padding,
5666      rate=dilation_rate,
5667      data_format=tf_data_format)
5668  if data_format == 'channels_first' and tf_data_format == 'NHWC':
5669    x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
5670  return x
5671
5672
5673@keras_export('keras.backend.conv3d')
5674@dispatch.add_dispatch_support
5675@doc_controls.do_not_generate_docs
5676def conv3d(x,
5677           kernel,
5678           strides=(1, 1, 1),
5679           padding='valid',
5680           data_format=None,
5681           dilation_rate=(1, 1, 1)):
5682  """3D convolution.
5683
5684  Args:
5685      x: Tensor or variable.
5686      kernel: kernel tensor.
5687      strides: strides tuple.
5688      padding: string, `"same"` or `"valid"`.
5689      data_format: string, `"channels_last"` or `"channels_first"`.
5690      dilation_rate: tuple of 3 integers.
5691
5692  Returns:
5693      A tensor, result of 3D convolution.
5694
5695  Raises:
5696      ValueError: if `data_format` is neither `channels_last` or
5697      `channels_first`.
5698  """
5699  if data_format is None:
5700    data_format = image_data_format()
5701  if data_format not in {'channels_first', 'channels_last'}:
5702    raise ValueError('Unknown data_format: ' + str(data_format))
5703
5704  x, tf_data_format = _preprocess_conv3d_input(x, data_format)
5705  padding = _preprocess_padding(padding)
5706  x = nn.convolution(
5707      input=x,
5708      filter=kernel,
5709      dilation_rate=dilation_rate,
5710      strides=strides,
5711      padding=padding,
5712      data_format=tf_data_format)
5713  if data_format == 'channels_first' and tf_data_format == 'NDHWC':
5714    x = array_ops.transpose(x, (0, 4, 1, 2, 3))
5715  return x
5716
5717
5718def conv3d_transpose(x,
5719                     kernel,
5720                     output_shape,
5721                     strides=(1, 1, 1),
5722                     padding='valid',
5723                     data_format=None):
5724  """3D deconvolution (i.e.
5725
5726  transposed convolution).
5727
5728  Args:
5729      x: input tensor.
5730      kernel: kernel tensor.
5731      output_shape: 1D int tensor for the output shape.
5732      strides: strides tuple.
5733      padding: string, "same" or "valid".
5734      data_format: string, `"channels_last"` or `"channels_first"`.
5735
5736  Returns:
5737      A tensor, result of transposed 3D convolution.
5738
5739  Raises:
5740      ValueError: if `data_format` is neither `channels_last` or
5741      `channels_first`.
5742  """
5743  if data_format is None:
5744    data_format = image_data_format()
5745  if data_format not in {'channels_first', 'channels_last'}:
5746    raise ValueError('Unknown data_format: ' + str(data_format))
5747  if isinstance(output_shape, (tuple, list)):
5748    output_shape = array_ops.stack(output_shape)
5749
5750  x, tf_data_format = _preprocess_conv3d_input(x, data_format)
5751
5752  if data_format == 'channels_first' and tf_data_format == 'NDHWC':
5753    output_shape = (output_shape[0], output_shape[2], output_shape[3],
5754                    output_shape[4], output_shape[1])
5755  if output_shape[0] is None:
5756    output_shape = (array_ops.shape(x)[0],) + tuple(output_shape[1:])
5757    output_shape = array_ops.stack(list(output_shape))
5758
5759  padding = _preprocess_padding(padding)
5760  if tf_data_format == 'NDHWC':
5761    strides = (1,) + strides + (1,)
5762  else:
5763    strides = (1, 1) + strides
5764
5765  x = nn.conv3d_transpose(
5766      x,
5767      kernel,
5768      output_shape,
5769      strides,
5770      padding=padding,
5771      data_format=tf_data_format)
5772  if data_format == 'channels_first' and tf_data_format == 'NDHWC':
5773    x = array_ops.transpose(x, (0, 4, 1, 2, 3))
5774  return x
5775
5776
5777@keras_export('keras.backend.pool2d')
5778@dispatch.add_dispatch_support
5779@doc_controls.do_not_generate_docs
5780def pool2d(x,
5781           pool_size,
5782           strides=(1, 1),
5783           padding='valid',
5784           data_format=None,
5785           pool_mode='max'):
5786  """2D Pooling.
5787
5788  Args:
5789      x: Tensor or variable.
5790      pool_size: tuple of 2 integers.
5791      strides: tuple of 2 integers.
5792      padding: string, `"same"` or `"valid"`.
5793      data_format: string, `"channels_last"` or `"channels_first"`.
5794      pool_mode: string, `"max"` or `"avg"`.
5795
5796  Returns:
5797      A tensor, result of 2D pooling.
5798
5799  Raises:
5800      ValueError: if `data_format` is neither `"channels_last"` or
5801      `"channels_first"`.
5802      ValueError: if `pool_size` is not a tuple of 2 integers.
5803      ValueError: if `strides` is not a tuple of 2 integers.
5804      ValueError: if `pool_mode` is neither `"max"` or `"avg"`.
5805  """
5806  if data_format is None:
5807    data_format = image_data_format()
5808  if data_format not in {'channels_first', 'channels_last'}:
5809    raise ValueError('Unknown data_format: ' + str(data_format))
5810  if len(pool_size) != 2:
5811    raise ValueError('`pool_size` must be a tuple of 2 integers.')
5812  if len(strides) != 2:
5813    raise ValueError('`strides` must be a tuple of 2 integers.')
5814
5815  x, tf_data_format = _preprocess_conv2d_input(x, data_format)
5816  padding = _preprocess_padding(padding)
5817  if tf_data_format == 'NHWC':
5818    strides = (1,) + strides + (1,)
5819    pool_size = (1,) + pool_size + (1,)
5820  else:
5821    strides = (1, 1) + strides
5822    pool_size = (1, 1) + pool_size
5823
5824  if pool_mode == 'max':
5825    x = nn.max_pool(
5826        x, pool_size, strides, padding=padding, data_format=tf_data_format)
5827  elif pool_mode == 'avg':
5828    x = nn.avg_pool(
5829        x, pool_size, strides, padding=padding, data_format=tf_data_format)
5830  else:
5831    raise ValueError('Invalid pooling mode: ' + str(pool_mode))
5832
5833  if data_format == 'channels_first' and tf_data_format == 'NHWC':
5834    x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
5835  return x
5836
5837
5838@keras_export('keras.backend.pool3d')
5839@dispatch.add_dispatch_support
5840@doc_controls.do_not_generate_docs
5841def pool3d(x,
5842           pool_size,
5843           strides=(1, 1, 1),
5844           padding='valid',
5845           data_format=None,
5846           pool_mode='max'):
5847  """3D Pooling.
5848
5849  Args:
5850      x: Tensor or variable.
5851      pool_size: tuple of 3 integers.
5852      strides: tuple of 3 integers.
5853      padding: string, `"same"` or `"valid"`.
5854      data_format: string, `"channels_last"` or `"channels_first"`.
5855      pool_mode: string, `"max"` or `"avg"`.
5856
5857  Returns:
5858      A tensor, result of 3D pooling.
5859
5860  Raises:
5861      ValueError: if `data_format` is neither `"channels_last"` or
5862      `"channels_first"`.
5863      ValueError: if `pool_mode` is neither `"max"` or `"avg"`.
5864  """
5865  if data_format is None:
5866    data_format = image_data_format()
5867  if data_format not in {'channels_first', 'channels_last'}:
5868    raise ValueError('Unknown data_format: ' + str(data_format))
5869
5870  x, tf_data_format = _preprocess_conv3d_input(x, data_format)
5871  padding = _preprocess_padding(padding)
5872  if tf_data_format == 'NDHWC':
5873    strides = (1,) + strides + (1,)
5874    pool_size = (1,) + pool_size + (1,)
5875  else:
5876    strides = (1, 1) + strides
5877    pool_size = (1, 1) + pool_size
5878
5879  if pool_mode == 'max':
5880    x = nn.max_pool3d(
5881        x, pool_size, strides, padding=padding, data_format=tf_data_format)
5882  elif pool_mode == 'avg':
5883    x = nn.avg_pool3d(
5884        x, pool_size, strides, padding=padding, data_format=tf_data_format)
5885  else:
5886    raise ValueError('Invalid pooling mode: ' + str(pool_mode))
5887
5888  if data_format == 'channels_first' and tf_data_format == 'NDHWC':
5889    x = array_ops.transpose(x, (0, 4, 1, 2, 3))
5890  return x
5891
5892
5893def local_conv(inputs,
5894               kernel,
5895               kernel_size,
5896               strides,
5897               output_shape,
5898               data_format=None):
5899  """Apply N-D convolution with un-shared weights.
5900
5901  Args:
5902      inputs: (N+2)-D tensor with shape
5903          (batch_size, channels_in, d_in1, ..., d_inN)
5904          if data_format='channels_first', or
5905          (batch_size, d_in1, ..., d_inN, channels_in)
5906          if data_format='channels_last'.
5907      kernel: the unshared weight for N-D convolution,
5908          with shape (output_items, feature_dim, channels_out), where
5909          feature_dim = np.prod(kernel_size) * channels_in,
5910          output_items = np.prod(output_shape).
5911      kernel_size: a tuple of N integers, specifying the
5912          spatial dimensions of the N-D convolution window.
5913      strides: a tuple of N integers, specifying the strides
5914          of the convolution along the spatial dimensions.
5915      output_shape: a tuple of (d_out1, ..., d_outN) specifying the spatial
5916          dimensionality of the output.
5917      data_format: string, "channels_first" or "channels_last".
5918
5919  Returns:
5920      An (N+2)-D tensor with shape:
5921      (batch_size, channels_out) + output_shape
5922      if data_format='channels_first', or:
5923      (batch_size,) + output_shape + (channels_out,)
5924      if data_format='channels_last'.
5925
5926  Raises:
5927      ValueError: if `data_format` is neither
5928      `channels_last` nor `channels_first`.
5929  """
5930  if data_format is None:
5931    data_format = image_data_format()
5932  if data_format not in {'channels_first', 'channels_last'}:
5933    raise ValueError('Unknown data_format: ' + str(data_format))
5934
5935  kernel_shape = int_shape(kernel)
5936  feature_dim = kernel_shape[1]
5937  channels_out = kernel_shape[-1]
5938  ndims = len(output_shape)
5939  spatial_dimensions = list(range(ndims))
5940
5941  xs = []
5942  output_axes_ticks = [range(axis_max) for axis_max in output_shape]
5943  for position in itertools.product(*output_axes_ticks):
5944    slices = [slice(None)]
5945
5946    if data_format == 'channels_first':
5947      slices.append(slice(None))
5948
5949    slices.extend(
5950        slice(position[d] * strides[d], position[d] * strides[d] +
5951              kernel_size[d]) for d in spatial_dimensions)
5952
5953    if data_format == 'channels_last':
5954      slices.append(slice(None))
5955
5956    xs.append(reshape(inputs[slices], (1, -1, feature_dim)))
5957
5958  x_aggregate = concatenate(xs, axis=0)
5959  output = batch_dot(x_aggregate, kernel)
5960  output = reshape(output, output_shape + (-1, channels_out))
5961
5962  if data_format == 'channels_first':
5963    permutation = [ndims, ndims + 1] + spatial_dimensions
5964  else:
5965    permutation = [ndims] + spatial_dimensions + [ndims + 1]
5966
5967  return permute_dimensions(output, permutation)
5968
5969
5970@keras_export('keras.backend.local_conv1d')
5971@dispatch.add_dispatch_support
5972@doc_controls.do_not_generate_docs
5973def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):
5974  """Apply 1D conv with un-shared weights.
5975
5976  Args:
5977      inputs: 3D tensor with shape:
5978          (batch_size, steps, input_dim)
5979          if data_format is "channels_last" or
5980          (batch_size, input_dim, steps)
5981          if data_format is "channels_first".
5982      kernel: the unshared weight for convolution,
5983          with shape (output_length, feature_dim, filters).
5984      kernel_size: a tuple of a single integer,
5985          specifying the length of the 1D convolution window.
5986      strides: a tuple of a single integer,
5987          specifying the stride length of the convolution.
5988      data_format: the data format, channels_first or channels_last.
5989
5990  Returns:
5991      A 3d tensor with shape:
5992      (batch_size, output_length, filters)
5993      if data_format='channels_first'
5994      or 3D tensor with shape:
5995      (batch_size, filters, output_length)
5996      if data_format='channels_last'.
5997  """
5998  output_shape = (kernel.shape[0],)
5999  return local_conv(inputs,
6000                    kernel,
6001                    kernel_size,
6002                    strides,
6003                    output_shape,
6004                    data_format)
6005
6006
6007@keras_export('keras.backend.local_conv2d')
6008@dispatch.add_dispatch_support
6009@doc_controls.do_not_generate_docs
6010def local_conv2d(inputs,
6011                 kernel,
6012                 kernel_size,
6013                 strides,
6014                 output_shape,
6015                 data_format=None):
6016  """Apply 2D conv with un-shared weights.
6017
6018  Args:
6019      inputs: 4D tensor with shape:
6020          (batch_size, filters, new_rows, new_cols)
6021          if data_format='channels_first'
6022          or 4D tensor with shape:
6023          (batch_size, new_rows, new_cols, filters)
6024          if data_format='channels_last'.
6025      kernel: the unshared weight for convolution,
6026          with shape (output_items, feature_dim, filters).
6027      kernel_size: a tuple of 2 integers, specifying the
6028          width and height of the 2D convolution window.
6029      strides: a tuple of 2 integers, specifying the strides
6030          of the convolution along the width and height.
6031      output_shape: a tuple with (output_row, output_col).
6032      data_format: the data format, channels_first or channels_last.
6033
6034  Returns:
6035      A 4D tensor with shape:
6036      (batch_size, filters, new_rows, new_cols)
6037      if data_format='channels_first'
6038      or 4D tensor with shape:
6039      (batch_size, new_rows, new_cols, filters)
6040      if data_format='channels_last'.
6041  """
6042  return local_conv(inputs,
6043                    kernel,
6044                    kernel_size,
6045                    strides,
6046                    output_shape,
6047                    data_format)
6048
6049
6050@keras_export('keras.backend.bias_add')
6051@dispatch.add_dispatch_support
6052@doc_controls.do_not_generate_docs
6053def bias_add(x, bias, data_format=None):
6054  """Adds a bias vector to a tensor.
6055
6056  Args:
6057      x: Tensor or variable.
6058      bias: Bias tensor to add.
6059      data_format: string, `"channels_last"` or `"channels_first"`.
6060
6061  Returns:
6062      Output tensor.
6063
6064  Raises:
6065      ValueError: In one of the two cases below:
6066                  1. invalid `data_format` argument.
6067                  2. invalid bias shape.
6068                     the bias should be either a vector or
6069                     a tensor with ndim(x) - 1 dimension
6070  """
6071  if data_format is None:
6072    data_format = image_data_format()
6073  if data_format not in {'channels_first', 'channels_last'}:
6074    raise ValueError('Unknown data_format: ' + str(data_format))
6075  bias_shape = int_shape(bias)
6076  if len(bias_shape) != 1 and len(bias_shape) != ndim(x) - 1:
6077    raise ValueError(
6078        'Unexpected bias dimensions %d, expect to be 1 or %d dimensions' %
6079        (len(bias_shape), ndim(x)))
6080
6081  if len(bias_shape) == 1:
6082    if data_format == 'channels_first':
6083      return nn.bias_add(x, bias, data_format='NCHW')
6084    return nn.bias_add(x, bias, data_format='NHWC')
6085  if ndim(x) in (3, 4, 5):
6086    if data_format == 'channels_first':
6087      bias_reshape_axis = (1, bias_shape[-1]) + bias_shape[:-1]
6088      return x + reshape(bias, bias_reshape_axis)
6089    return x + reshape(bias, (1,) + bias_shape)
6090  return nn.bias_add(x, bias)
6091
6092
6093# RANDOMNESS
6094
6095
6096@keras_export('keras.backend.random_normal')
6097@dispatch.add_dispatch_support
6098@doc_controls.do_not_generate_docs
6099def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
6100  """Returns a tensor with normal distribution of values.
6101
6102  It is an alias to `tf.random.normal`.
6103
6104  Args:
6105      shape: A tuple of integers, the shape of tensor to create.
6106      mean: A float, the mean value of the normal distribution to draw samples.
6107        Default to 0.0.
6108      stddev: A float, the standard deviation of the normal distribution
6109        to draw samples. Default to 1.0.
6110      dtype: `tf.dtypes.DType`, dtype of returned tensor. Default to use Keras
6111        backend dtype which is float32.
6112      seed: Integer, random seed. Will use a random numpy integer when not
6113        specified.
6114
6115  Returns:
6116      A tensor with normal distribution of values.
6117
6118  Example:
6119
6120  >>> random_normal_tensor = tf.keras.backend.random_normal(shape=(2,3),
6121  ... mean=0.0, stddev=1.0)
6122  >>> random_normal_tensor
6123  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=...,
6124  dtype=float32)>
6125  """
6126  if dtype is None:
6127    dtype = floatx()
6128  if seed is None:
6129    seed = np.random.randint(10e6)
6130  return random_ops.random_normal(
6131      shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed)
6132
6133
6134@keras_export('keras.backend.random_uniform')
6135@dispatch.add_dispatch_support
6136@doc_controls.do_not_generate_docs
6137def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
6138  """Returns a tensor with uniform distribution of values.
6139
6140  Args:
6141      shape: A tuple of integers, the shape of tensor to create.
6142      minval: A float, lower boundary of the uniform distribution
6143          to draw samples.
6144      maxval: A float, upper boundary of the uniform distribution
6145          to draw samples.
6146      dtype: String, dtype of returned tensor.
6147      seed: Integer, random seed.
6148
6149  Returns:
6150      A tensor.
6151
6152  Example:
6153
6154  >>> random_uniform_tensor = tf.keras.backend.random_uniform(shape=(2,3),
6155  ... minval=0.0, maxval=1.0)
6156  >>> random_uniform_tensor
6157  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=...,
6158  dtype=float32)>
6159  """
6160  if dtype is None:
6161    dtype = floatx()
6162  if seed is None:
6163    seed = np.random.randint(10e6)
6164  return random_ops.random_uniform(
6165      shape, minval=minval, maxval=maxval, dtype=dtype, seed=seed)
6166
6167
6168@keras_export('keras.backend.random_binomial')
6169@dispatch.add_dispatch_support
6170@doc_controls.do_not_generate_docs
6171def random_binomial(shape, p=0.0, dtype=None, seed=None):
6172  """Returns a tensor with random binomial distribution of values.
6173
6174  DEPRECATED, use `tf.keras.backend.random_bernoulli` instead.
6175
6176  The binomial distribution with parameters `n` and `p` is the probability
6177  distribution of the number of successful Bernoulli process. Only supports
6178  `n` = 1 for now.
6179
6180  Args:
6181      shape: A tuple of integers, the shape of tensor to create.
6182      p: A float, `0. <= p <= 1`, probability of binomial distribution.
6183      dtype: String, dtype of returned tensor.
6184      seed: Integer, random seed.
6185
6186  Returns:
6187      A tensor.
6188
6189  Example:
6190
6191  >>> random_binomial_tensor = tf.keras.backend.random_binomial(shape=(2,3),
6192  ... p=0.5)
6193  >>> random_binomial_tensor
6194  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=...,
6195  dtype=float32)>
6196  """
6197  warnings.warn('`tf.keras.backend.random_binomial` is deprecated, '
6198                'and will be removed in a future version.'
6199                'Please use `tf.keras.backend.random_bernoulli` instead.')
6200  return random_bernoulli(shape, p, dtype, seed)
6201
6202
6203@keras_export('keras.backend.random_bernoulli')
6204@dispatch.add_dispatch_support
6205@doc_controls.do_not_generate_docs
6206def random_bernoulli(shape, p=0.0, dtype=None, seed=None):
6207  """Returns a tensor with random bernoulli distribution of values.
6208
6209  Args:
6210      shape: A tuple of integers, the shape of tensor to create.
6211      p: A float, `0. <= p <= 1`, probability of bernoulli distribution.
6212      dtype: String, dtype of returned tensor.
6213      seed: Integer, random seed.
6214
6215  Returns:
6216      A tensor.
6217  """
6218  if dtype is None:
6219    dtype = floatx()
6220  if seed is None:
6221    seed = np.random.randint(10e6)
6222  return array_ops.where_v2(
6223      random_ops.random_uniform(shape, dtype=dtype, seed=seed) <= p,
6224      array_ops.ones(shape, dtype=dtype), array_ops.zeros(shape, dtype=dtype))
6225
6226
6227@keras_export('keras.backend.truncated_normal')
6228@dispatch.add_dispatch_support
6229@doc_controls.do_not_generate_docs
6230def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
6231  """Returns a tensor with truncated random normal distribution of values.
6232
6233  The generated values follow a normal distribution
6234  with specified mean and standard deviation,
6235  except that values whose magnitude is more than
6236  two standard deviations from the mean are dropped and re-picked.
6237
6238  Args:
6239      shape: A tuple of integers, the shape of tensor to create.
6240      mean: Mean of the values.
6241      stddev: Standard deviation of the values.
6242      dtype: String, dtype of returned tensor.
6243      seed: Integer, random seed.
6244
6245  Returns:
6246      A tensor.
6247  """
6248  if dtype is None:
6249    dtype = floatx()
6250  if seed is None:
6251    seed = np.random.randint(10e6)
6252  return random_ops.truncated_normal(
6253      shape, mean, stddev, dtype=dtype, seed=seed)
6254
6255
6256# CTC
6257# TensorFlow has a native implementation, but it uses sparse tensors
6258# and therefore requires a wrapper for Keras. The functions below convert
6259# dense to sparse tensors and also wraps up the beam search code that is
6260# in TensorFlow's CTC implementation
6261
6262
6263@keras_export('keras.backend.ctc_label_dense_to_sparse')
6264@dispatch.add_dispatch_support
6265@doc_controls.do_not_generate_docs
6266def ctc_label_dense_to_sparse(labels, label_lengths):
6267  """Converts CTC labels from dense to sparse.
6268
6269  Args:
6270      labels: dense CTC labels.
6271      label_lengths: length of the labels.
6272
6273  Returns:
6274      A sparse tensor representation of the labels.
6275  """
6276  label_shape = array_ops.shape(labels)
6277  num_batches_tns = array_ops.stack([label_shape[0]])
6278  max_num_labels_tns = array_ops.stack([label_shape[1]])
6279
6280  def range_less_than(old_input, current_input):
6281    return array_ops.expand_dims(
6282        math_ops.range(array_ops.shape(old_input)[1]), 0) < array_ops.fill(
6283            max_num_labels_tns, current_input)
6284
6285  init = math_ops.cast(
6286      array_ops.fill([1, label_shape[1]], 0), dtypes_module.bool)
6287  dense_mask = functional_ops.scan(
6288      range_less_than, label_lengths, initializer=init, parallel_iterations=1)
6289  dense_mask = dense_mask[:, 0, :]
6290
6291  label_array = array_ops.reshape(
6292      array_ops.tile(math_ops.range(0, label_shape[1]), num_batches_tns),
6293      label_shape)
6294  label_ind = array_ops.boolean_mask(label_array, dense_mask)
6295
6296  batch_array = array_ops.transpose(
6297      array_ops.reshape(
6298          array_ops.tile(math_ops.range(0, label_shape[0]), max_num_labels_tns),
6299          reverse(label_shape, 0)))
6300  batch_ind = array_ops.boolean_mask(batch_array, dense_mask)
6301  indices = array_ops.transpose(
6302      array_ops.reshape(concatenate([batch_ind, label_ind], axis=0), [2, -1]))
6303
6304  vals_sparse = array_ops.gather_nd(labels, indices)
6305
6306  return sparse_tensor.SparseTensor(
6307      math_ops.cast(indices, dtypes_module.int64), vals_sparse,
6308      math_ops.cast(label_shape, dtypes_module.int64))
6309
6310
6311@keras_export('keras.backend.ctc_batch_cost')
6312@dispatch.add_dispatch_support
6313@doc_controls.do_not_generate_docs
6314def ctc_batch_cost(y_true, y_pred, input_length, label_length):
6315  """Runs CTC loss algorithm on each batch element.
6316
6317  Args:
6318      y_true: tensor `(samples, max_string_length)`
6319          containing the truth labels.
6320      y_pred: tensor `(samples, time_steps, num_categories)`
6321          containing the prediction, or output of the softmax.
6322      input_length: tensor `(samples, 1)` containing the sequence length for
6323          each batch item in `y_pred`.
6324      label_length: tensor `(samples, 1)` containing the sequence length for
6325          each batch item in `y_true`.
6326
6327  Returns:
6328      Tensor with shape (samples,1) containing the
6329          CTC loss of each element.
6330  """
6331  label_length = math_ops.cast(
6332      array_ops.squeeze(label_length, axis=-1), dtypes_module.int32)
6333  input_length = math_ops.cast(
6334      array_ops.squeeze(input_length, axis=-1), dtypes_module.int32)
6335  sparse_labels = math_ops.cast(
6336      ctc_label_dense_to_sparse(y_true, label_length), dtypes_module.int32)
6337
6338  y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + epsilon())
6339
6340  return array_ops.expand_dims(
6341      ctc.ctc_loss(
6342          inputs=y_pred, labels=sparse_labels, sequence_length=input_length), 1)
6343
6344
6345@keras_export('keras.backend.ctc_decode')
6346@dispatch.add_dispatch_support
6347@doc_controls.do_not_generate_docs
6348def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
6349  """Decodes the output of a softmax.
6350
6351  Can use either greedy search (also known as best path)
6352  or a constrained dictionary search.
6353
6354  Args:
6355      y_pred: tensor `(samples, time_steps, num_categories)`
6356          containing the prediction, or output of the softmax.
6357      input_length: tensor `(samples, )` containing the sequence length for
6358          each batch item in `y_pred`.
6359      greedy: perform much faster best-path search if `true`.
6360          This does not use a dictionary.
6361      beam_width: if `greedy` is `false`: a beam search decoder will be used
6362          with a beam of this width.
6363      top_paths: if `greedy` is `false`,
6364          how many of the most probable paths will be returned.
6365
6366  Returns:
6367      Tuple:
6368          List: if `greedy` is `true`, returns a list of one element that
6369              contains the decoded sequence.
6370              If `false`, returns the `top_paths` most probable
6371              decoded sequences.
6372              Each decoded sequence has shape (samples, time_steps).
6373              Important: blank labels are returned as `-1`.
6374          Tensor `(top_paths, )` that contains
6375              the log probability of each decoded sequence.
6376  """
6377  input_shape = shape(y_pred)
6378  num_samples, num_steps = input_shape[0], input_shape[1]
6379  y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + epsilon())
6380  input_length = math_ops.cast(input_length, dtypes_module.int32)
6381
6382  if greedy:
6383    (decoded, log_prob) = ctc.ctc_greedy_decoder(
6384        inputs=y_pred, sequence_length=input_length)
6385  else:
6386    (decoded, log_prob) = ctc.ctc_beam_search_decoder(
6387        inputs=y_pred,
6388        sequence_length=input_length,
6389        beam_width=beam_width,
6390        top_paths=top_paths)
6391  decoded_dense = []
6392  for st in decoded:
6393    st = sparse_tensor.SparseTensor(
6394        st.indices, st.values, (num_samples, num_steps))
6395    decoded_dense.append(
6396        sparse_ops.sparse_tensor_to_dense(sp_input=st, default_value=-1))
6397  return (decoded_dense, log_prob)
6398
6399
6400# HIGH ORDER FUNCTIONS
6401
6402
6403@keras_export('keras.backend.map_fn')
6404@doc_controls.do_not_generate_docs
6405def map_fn(fn, elems, name=None, dtype=None):
6406  """Map the function fn over the elements elems and return the outputs.
6407
6408  Args:
6409      fn: Callable that will be called upon each element in elems
6410      elems: tensor
6411      name: A string name for the map node in the graph
6412      dtype: Output data type.
6413
6414  Returns:
6415      Tensor with dtype `dtype`.
6416  """
6417  return map_fn_lib.map_fn(fn, elems, name=name, dtype=dtype)
6418
6419
6420@keras_export('keras.backend.foldl')
6421@doc_controls.do_not_generate_docs
6422def foldl(fn, elems, initializer=None, name=None):
6423  """Reduce elems using fn to combine them from left to right.
6424
6425  Args:
6426      fn: Callable that will be called upon each element in elems and an
6427          accumulator, for instance `lambda acc, x: acc + x`
6428      elems: tensor
6429      initializer: The first value used (`elems[0]` in case of None)
6430      name: A string name for the foldl node in the graph
6431
6432  Returns:
6433      Tensor with same type and shape as `initializer`.
6434  """
6435  return functional_ops.foldl(fn, elems, initializer=initializer, name=name)
6436
6437
6438@keras_export('keras.backend.foldr')
6439@doc_controls.do_not_generate_docs
6440def foldr(fn, elems, initializer=None, name=None):
6441  """Reduce elems using fn to combine them from right to left.
6442
6443  Args:
6444      fn: Callable that will be called upon each element in elems and an
6445          accumulator, for instance `lambda acc, x: acc + x`
6446      elems: tensor
6447      initializer: The first value used (`elems[-1]` in case of None)
6448      name: A string name for the foldr node in the graph
6449
6450  Returns:
6451      Same type and shape as initializer
6452  """
6453  return functional_ops.foldr(fn, elems, initializer=initializer, name=name)
6454
6455# Load Keras default configuration from config file if present.
6456# Set Keras base dir path given KERAS_HOME env variable, if applicable.
6457# Otherwise either ~/.keras or /tmp.
6458if 'KERAS_HOME' in os.environ:
6459  _keras_dir = os.environ.get('KERAS_HOME')
6460else:
6461  _keras_base_dir = os.path.expanduser('~')
6462  _keras_dir = os.path.join(_keras_base_dir, '.keras')
6463_config_path = os.path.expanduser(os.path.join(_keras_dir, 'keras.json'))
6464if os.path.exists(_config_path):
6465  try:
6466    with open(_config_path) as fh:
6467      _config = json.load(fh)
6468  except ValueError:
6469    _config = {}
6470  _floatx = _config.get('floatx', floatx())
6471  assert _floatx in {'float16', 'float32', 'float64'}
6472  _epsilon = _config.get('epsilon', epsilon())
6473  assert isinstance(_epsilon, float)
6474  _image_data_format = _config.get('image_data_format', image_data_format())
6475  assert _image_data_format in {'channels_last', 'channels_first'}
6476  set_floatx(_floatx)
6477  set_epsilon(_epsilon)
6478  set_image_data_format(_image_data_format)
6479
6480# Save config file.
6481if not os.path.exists(_keras_dir):
6482  try:
6483    os.makedirs(_keras_dir)
6484  except OSError:
6485    # Except permission denied and potential race conditions
6486    # in multi-threaded environments.
6487    pass
6488
6489if not os.path.exists(_config_path):
6490  _config = {
6491      'floatx': floatx(),
6492      'epsilon': epsilon(),
6493      'backend': 'tensorflow',
6494      'image_data_format': image_data_format()
6495  }
6496  try:
6497    with open(_config_path, 'w') as f:
6498      f.write(json.dumps(_config, indent=4))
6499  except IOError:
6500    # Except permission denied.
6501    pass
6502
6503
6504def configure_and_create_distributed_session(distribution_strategy):
6505  """Configure session config and create a session with it."""
6506
6507  def _create_session(distribution_strategy):
6508    """Create the Distributed Strategy session."""
6509    session_config = get_default_session_config()
6510
6511    # If a session already exists, merge in its config; in the case there is a
6512    # conflict, take values of the existing config.
6513    global _SESSION
6514    if getattr(_SESSION, 'session', None) and _SESSION.session._config:
6515      session_config.MergeFrom(_SESSION.session._config)
6516
6517    if is_tpu_strategy(distribution_strategy):
6518      # TODO(priyag, yuefengz): Remove this workaround when Distribute
6519      # Coordinator is integrated with keras and we can create a session from
6520      # there.
6521      distribution_strategy.configure(session_config)
6522      master = distribution_strategy.extended._tpu_cluster_resolver.master()  # pylint: disable=protected-access
6523      session = session_module.Session(config=session_config, target=master)
6524    else:
6525      worker_context = dc_context.get_current_worker_context()
6526      if worker_context:
6527        dc_session_config = worker_context.session_config
6528        # Merge the default session config to the one from distribute
6529        # coordinator, which is fine for now since they don't have
6530        # conflicting configurations.
6531        dc_session_config.MergeFrom(session_config)
6532        session = session_module.Session(
6533            config=dc_session_config, target=worker_context.master_target)
6534      else:
6535        distribution_strategy.configure(session_config)
6536        session = session_module.Session(config=session_config)
6537
6538    set_session(session)
6539
6540  if distribution_strategy.extended._in_multi_worker_mode():
6541    dc.run_distribute_coordinator(
6542        _create_session,
6543        distribution_strategy,
6544        mode='independent_worker')
6545  else:
6546    _create_session(distribution_strategy)
6547
6548
6549def is_tpu_strategy(strategy):
6550  """We're executing TPU Strategy."""
6551  return (strategy is not None and
6552          strategy.__class__.__name__.startswith('TPUStrategy'))
6553
6554
6555def cast_variables_to_tensor(tensors):
6556
6557  def _cast_variables_to_tensor(tensor):
6558    if isinstance(tensor, variables_module.Variable):
6559      return array_ops.identity(tensor)
6560    return tensor
6561
6562  return nest.map_structure(_cast_variables_to_tensor, tensors)
6563
6564
6565def _is_symbolic_tensor(x):
6566  return tensor_util.is_tf_type(x) and not isinstance(x, ops.EagerTensor)
6567
6568
6569def convert_inputs_if_ragged(inputs):
6570  """Converts any ragged tensors to dense."""
6571
6572  def _convert_ragged_input(inputs):
6573    if isinstance(inputs, ragged_tensor.RaggedTensor):
6574      return inputs.to_tensor()
6575    return inputs
6576
6577  flat_inputs = nest.flatten(inputs)
6578  contains_ragged = py_any(
6579      isinstance(i, ragged_tensor.RaggedTensor) for i in flat_inputs)
6580
6581  if not contains_ragged:
6582    return inputs, None
6583
6584  inputs = nest.map_structure(_convert_ragged_input, inputs)
6585  # Multiple mask are not yet supported, so one mask is used on all inputs.
6586  # We approach this similarly when using row lengths to ignore steps.
6587  nested_row_lengths = math_ops.cast(flat_inputs[0].nested_row_lengths()[0],
6588                                     'int32')
6589  return inputs, nested_row_lengths
6590
6591
6592def maybe_convert_to_ragged(is_ragged_input, output, nested_row_lengths):
6593  """Converts any ragged input back to its initial structure."""
6594  if not is_ragged_input:
6595    return output
6596
6597  return ragged_tensor.RaggedTensor.from_tensor(output, nested_row_lengths)
6598
6599
6600class ContextValueCache(weakref.WeakKeyDictionary):
6601  """Container that caches (possibly tensor) values based on the context.
6602
6603  This class is similar to defaultdict, where values may be produced by the
6604  default factory specified during initialization. This class also has a default
6605  value for the key (when key is `None`) -- the key is set to the current graph
6606  or eager context. The default factories for key and value are only used in
6607  `__getitem__` and `setdefault`. The `.get()` behavior remains the same.
6608
6609  This object will return the value of the current graph or closest parent graph
6610  if the current graph is a function. This is to reflect the fact that if a
6611  tensor is created in eager/graph, child functions may capture that tensor.
6612
6613  The default factory method may accept keyword arguments (unlike defaultdict,
6614  which only accepts callables with 0 arguments). To pass keyword arguments to
6615  `default_factory`, use the `setdefault` method instead of `__getitem__`.
6616
6617  An example of how this class can be used in different contexts:
6618
6619  ```
6620  cache = ContextValueCache(int)
6621
6622  # Eager mode
6623  cache[None] += 2
6624  cache[None] += 4
6625  assert cache[None] == 6
6626
6627  # Graph mode
6628  with tf.Graph().as_default() as g:
6629    cache[None] += 5
6630    cache[g] += 3
6631  assert cache[g] == 8
6632  ```
6633
6634  Example of a default factory with arguments:
6635
6636  ```
6637  cache = ContextValueCache(lambda x: x + 1)
6638  g = tf.get_default_graph()
6639
6640  # Example with keyword argument.
6641  value = cache.setdefault(key=g, kwargs={'x': 3})
6642  assert cache[g] == 4
6643  ```
6644  """
6645
6646  def __init__(self, default_factory):
6647    self.default_factory = default_factory
6648    weakref.WeakKeyDictionary.__init__(self)
6649
6650  def _key(self):
6651    if context.executing_eagerly():
6652      return _DUMMY_EAGER_GRAPH.key
6653    else:
6654      return ops.get_default_graph()
6655
6656  def _get_parent_graph(self, graph):
6657    """Returns the parent graph or dummy eager object."""
6658    # TODO(b/149317164): Currently FuncGraphs use ops.get_default_graph() as the
6659    # outer graph. This results in outer_graph always being a Graph,
6660    # even in eager mode (get_default_graph will create a new Graph if there
6661    # isn't a default graph). Because of this bug, we have to specially set the
6662    # key when eager execution is enabled.
6663    parent_graph = graph.outer_graph
6664    if (not isinstance(parent_graph, func_graph.FuncGraph) and
6665        ops.executing_eagerly_outside_functions()):
6666      return _DUMMY_EAGER_GRAPH.key
6667    return parent_graph
6668
6669  def _get_recursive(self, key):
6670    """Gets the value at key or the closest parent graph."""
6671    value = self.get(key)
6672    if value is not None:
6673      return value
6674
6675    # Since FuncGraphs are able to capture tensors and variables from their
6676    # parent graphs, recursively search to see if there is a value stored for
6677    # one of the parent graphs.
6678    if isinstance(key, func_graph.FuncGraph):
6679      return self._get_recursive(self._get_parent_graph(key))
6680    return None
6681
6682  def __getitem__(self, key):
6683    """Gets the value at key (or current context), or sets default value.
6684
6685    Args:
6686      key: May be `None` or `Graph`object. When `None`, the key is set to the
6687        current context.
6688
6689    Returns:
6690      Either the cached or default value.
6691    """
6692    if key is None:
6693      key = self._key()
6694
6695    value = self._get_recursive(key)
6696    if value is None:
6697      value = self[key] = self.default_factory()  # pylint:disable=not-callable
6698    return value
6699
6700  def setdefault(self, key=None, default=None, kwargs=None):
6701    """Sets the default value if key is not in dict, and returns the value."""
6702    if key is None:
6703      key = self._key()
6704    kwargs = kwargs or {}
6705
6706    if default is None and key not in self:
6707      default = self.default_factory(**kwargs)
6708    return weakref.WeakKeyDictionary.setdefault(self, key, default)
6709
6710# This dictionary holds a mapping {graph: learning_phase}. In eager mode, a
6711# dummy object is used.
6712# A learning phase is a bool tensor used to run Keras models in
6713# either train mode (learning_phase == 1) or test mode (learning_phase == 0).
6714_GRAPH_LEARNING_PHASES = ContextValueCache(_default_learning_phase)
6715
6716# This dictionary holds a mapping between a graph and variables to initialize
6717# in the graph.
6718_GRAPH_VARIABLES = ContextValueCache(object_identity.ObjectIdentityWeakSet)
6719
6720# This dictionary holds a mapping between a graph and TF optimizers created in
6721# the graph.
6722_GRAPH_TF_OPTIMIZERS = ContextValueCache(object_identity.ObjectIdentityWeakSet)
6723