# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
# pylint: disable=redefined-outer-name
# pylint: disable=redefined-builtin
"""Keras backend API.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import itertools
import json
import os
import sys
import threading
import weakref

import numpy as np

from tensorflow.core.protobuf import config_pb2
from tensorflow.python import tf2
from tensorflow.python.client import session as session_module
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.eager import function as eager_function
from tensorflow.python.eager import lift_to_graph
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tfdev
from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend_config
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import ctc_ops as ctc
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients as gradients_module
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import map_fn as map_fn_lib
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops.ragged import ragged_concat_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import moving_averages
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import keras_export

py_all = all
py_sum = sum
py_any = any

# INTERNAL UTILS

# The internal graph maintained by Keras and used by the symbolic Keras APIs
# while executing eagerly (such as the functional API for model-building).
_GRAPH = None

# A graph which is used for constructing functions in eager mode.
_CURRENT_SCRATCH_GRAPH = None

# This is a thread local object that will hold the default internal TF session
# used by Keras. It can be set manually via `set_session(sess)`.
_SESSION = threading.local()

# This dictionary holds a mapping {graph: learning_phase}.
# A learning phase is a bool tensor used to run Keras models in
# either train mode (learning_phase == 1) or test mode (learning_phase == 0).
_GRAPH_LEARNING_PHASES = weakref.WeakKeyDictionary()

# This dictionary holds a mapping {graph: set_of_freezable_variables}.
# Each set tracks objects created via `freezable_variable` in the graph.
_FREEZABLE_VARS = weakref.WeakKeyDictionary()


# _DUMMY_EAGER_GRAPH.key is used as a key in _GRAPH_LEARNING_PHASES.
# We keep a separate reference to it to make sure it does not get removed from
# _GRAPH_LEARNING_PHASES.
# _DummyEagerGraph inherits from threading.local to make its `key` attribute
# thread local. This is needed to make set_learning_phase affect only the
# current thread during eager execution (see b/123096885 for more details).
class _DummyEagerGraph(threading.local):
  """_DummyEagerGraph provides a thread local `key` attribute.

  We can't use threading.local directly, i.e. without subclassing, because
  gevent monkey patches threading.local and its version does not support
  weak references.
  """

  class _WeakReferencableClass(object):
    """This dummy class is needed for two reasons.

    - We need something that supports weak references. Basic types like strings
    and ints don't.
    - We need something whose hash and equality are based on object identity
    to make sure they are treated as different keys to _GRAPH_LEARNING_PHASES.

    An empty Python class satisfies both of these requirements.
    """
    pass

  def __init__(self):
    # Constructors for classes subclassing threading.local run once
    # per thread accessing something in the class. Thus, each thread will
    # get a different key.
    super(_DummyEagerGraph, self).__init__()
    self.key = _DummyEagerGraph._WeakReferencableClass()


_DUMMY_EAGER_GRAPH = _DummyEagerGraph()

# This boolean flag can be set to True to leave variable initialization
# up to the user.
# Change its value via `manual_variable_initialization(value)`.
_MANUAL_VAR_INIT = False

# This list holds the available devices.
# It is populated when `_get_available_gpus()` is called for the first time.
# We assume our devices don't change henceforth.
_LOCAL_DEVICES = None

# This dictionary holds a mapping between a graph and variables to initialize
# in the graph.
_GRAPH_VARIABLES = weakref.WeakKeyDictionary()

# This dictionary holds a mapping between a graph and TF optimizers created in
# the graph.
_GRAPH_TF_OPTIMIZERS = weakref.WeakKeyDictionary()

# The below functions are kept accessible from backend for compatibility.
epsilon = backend_config.epsilon
floatx = backend_config.floatx
image_data_format = backend_config.image_data_format
set_epsilon = backend_config.set_epsilon
set_floatx = backend_config.set_floatx
set_image_data_format = backend_config.set_image_data_format


@keras_export('keras.backend.backend')
def backend():
  """Publicly accessible method for determining the current backend.

  Only exists for API compatibility with multi-backend Keras.

  Returns:
      The string "tensorflow".
  """
  return 'tensorflow'


@keras_export('keras.backend.cast_to_floatx')
def cast_to_floatx(x):
  """Cast a Numpy array to the default Keras float type.

  Arguments:
      x: Numpy array or TensorFlow tensor.

  Returns:
      The same array (Numpy array if `x` was a Numpy array, or TensorFlow tensor
      if `x` was a tensor), cast to its new type.

  Example:

  >>> tf.keras.backend.floatx()
  'float32'
  >>> arr = np.array([1.0, 2.0], dtype='float64')
  >>> arr.dtype
  dtype('float64')
  >>> new_arr = cast_to_floatx(arr)
  >>> new_arr
  array([1.,  2.], dtype=float32)
  >>> new_arr.dtype
  dtype('float32')

  """
  if isinstance(x, (ops.Tensor,
                    variables_module.Variable,
                    sparse_tensor.SparseTensor)):
    return math_ops.cast(x, dtype=floatx())
  return np.asarray(x, dtype=floatx())


# A global dictionary mapping graph objects to an index of counters used
# for various layer/optimizer names in each graph.
# Allows us to give unique autogenerated names to layers, in a graph-specific way.
PER_GRAPH_OBJECT_NAME_UIDS = weakref.WeakKeyDictionary()


@keras_export('keras.backend.get_uid')
def get_uid(prefix=''):
  """Associates a string prefix with an integer counter in a TensorFlow graph.

  Arguments:
    prefix: String prefix to index.

  Returns:
    Unique integer ID.

  Example:

  >>> get_uid('dense')
  1
  >>> get_uid('dense')
  2

  """
  graph = get_graph()
  if graph not in PER_GRAPH_OBJECT_NAME_UIDS:
    PER_GRAPH_OBJECT_NAME_UIDS[graph] = collections.defaultdict(int)
  layer_name_uids = PER_GRAPH_OBJECT_NAME_UIDS[graph]
  layer_name_uids[prefix] += 1
  return layer_name_uids[prefix]


@keras_export('keras.backend.reset_uids')
def reset_uids():
  """Resets graph identifiers.
  """

  PER_GRAPH_OBJECT_NAME_UIDS.clear()


@keras_export('keras.backend.clear_session')
def clear_session():
  """Destroys the current TF graph and creates a new one.

  Useful to avoid clutter from old models / layers.
  """
  global _SESSION
  global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
  global _GRAPH_VARIABLES  # pylint: disable=global-variable-not-assigned
  global _GRAPH_TF_OPTIMIZERS  # pylint: disable=global-variable-not-assigned
  global _GRAPH
  global _FREEZABLE_VARS
  _GRAPH = None
  ops.reset_default_graph()
  reset_uids()
  _SESSION.session = None
  graph = get_graph()
  with graph.as_default():
    with name_scope(''):
      phase = array_ops.placeholder_with_default(
          False, shape=(), name='keras_learning_phase')
    _GRAPH_LEARNING_PHASES = {}
    _GRAPH_LEARNING_PHASES[graph] = phase
    _GRAPH_VARIABLES.pop(graph, None)
    _GRAPH_TF_OPTIMIZERS.pop(graph, None)
    _FREEZABLE_VARS.pop(graph, None)

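# A minimal usage sketch for `clear_session` (illustrative; assumes the public
# `tf.keras` API): calling it between successive model constructions releases
# the state accumulated in the default graph.
#
#   for _ in range(10):
#     model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
#     tf.keras.backend.clear_session()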

@keras_export('keras.backend.manual_variable_initialization')
def manual_variable_initialization(value):
  """Sets the manual variable initialization flag.

  This boolean flag determines whether
  variables should be initialized
  as they are instantiated (default), or if
  the user should handle the initialization
  (e.g. via `tf.compat.v1.initialize_all_variables()`).

  Arguments:
      value: Python boolean.
  """
  global _MANUAL_VAR_INIT
  _MANUAL_VAR_INIT = value


@keras_export('keras.backend.learning_phase')
def learning_phase():
  """Returns the learning phase flag.

  The learning phase flag is a bool tensor (0 = test, 1 = train)
  to be passed as input to any Keras function
  that uses a different behavior at train time and test time.

  Returns:
      Learning phase (scalar integer tensor or Python integer).
  """
  graph = ops.get_default_graph()
  if graph is _GRAPH:
    # Don't enter an init_scope for the learning phase if eager execution
    # is enabled but we're inside the Keras workspace graph.
    learning_phase = symbolic_learning_phase()
    _mark_func_graph_as_unsaveable(graph, learning_phase)
    return learning_phase
  with ops.init_scope():
    # We always check & set the learning phase inside the init_scope,
    # otherwise the wrong default_graph will be used to look up the learning
    # phase inside of functions & defuns.
    #
    # This is because functions & defuns (both in graph & in eager mode)
    # will always execute non-eagerly using a function-specific default
    # subgraph.
    if context.executing_eagerly():
      if _DUMMY_EAGER_GRAPH.key not in _GRAPH_LEARNING_PHASES:
        # Fallback to inference mode as default.
        return 0
      return _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key]
    learning_phase = symbolic_learning_phase()
    _mark_func_graph_as_unsaveable(graph, learning_phase)
    return learning_phase


def global_learning_phase_is_set():
  return _DUMMY_EAGER_GRAPH.key in _GRAPH_LEARNING_PHASES


def _mark_func_graph_as_unsaveable(graph, learning_phase):
  """Mark func graph as unsaveable due to use of symbolic keras learning phase.

  Functions that capture the symbolic learning phase cannot be exported to
  SavedModel. Mark the funcgraph as unsaveable, so that an error will be raised
  if it is exported.

  Args:
    graph: Graph or FuncGraph object.
    learning_phase: Learning phase placeholder or int defined in the graph.
  """
  if graph.building_function and is_placeholder(learning_phase):
    graph.mark_as_unsaveable(
        'The keras learning phase placeholder was used inside a function. '
        'Exporting placeholders is not supported when saving out a SavedModel. '
        'Please call `tf.keras.backend.set_learning_phase(0)` in the function '
        'to set the learning phase to a constant value.')


def symbolic_learning_phase():
  graph = get_graph()
  with graph.as_default():
    if graph not in _GRAPH_LEARNING_PHASES:
      with name_scope(''):
        phase = array_ops.placeholder_with_default(
            False, shape=(), name='keras_learning_phase')
      _GRAPH_LEARNING_PHASES[graph] = phase
    return _GRAPH_LEARNING_PHASES[graph]


@keras_export('keras.backend.set_learning_phase')
def set_learning_phase(value):
  """Sets the learning phase to a fixed value.

  Arguments:
      value: Learning phase value, either 0 or 1 (integers).
             0 = test, 1 = train

  Raises:
      ValueError: if `value` is neither `0` nor `1`.
  """
  global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
  if value not in {0, 1}:
    raise ValueError('Expected learning phase to be 0 or 1.')
  with ops.init_scope():
    if context.executing_eagerly():
      # In an eager context, the learning phase value applies to both the eager
      # context and the internal Keras graph.
      _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key] = value
    _GRAPH_LEARNING_PHASES[get_graph()] = value


@keras_export('keras.backend.learning_phase_scope')
@tf_contextlib.contextmanager
def learning_phase_scope(value):
  """Provides a scope within which the learning phase is equal to `value`.

  The learning phase gets restored to its original value upon exiting the scope.

  Arguments:
     value: Learning phase value, either 0 or 1 (integers).
            0 = test, 1 = train

  Yields:
    None.

  Raises:
     ValueError: if `value` is neither `0` nor `1`.
  """
  global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
  if value not in {0, 1}:
    raise ValueError('Expected learning phase to be 0 or 1.')

  with ops.init_scope():
    if context.executing_eagerly():
      previous_eager_value = _GRAPH_LEARNING_PHASES.get(
          _DUMMY_EAGER_GRAPH.key, None)
    previous_graph_value = _GRAPH_LEARNING_PHASES.get(get_graph(), None)

  try:
    set_learning_phase(value)
    yield
  finally:
    # Restore learning phase to initial value.
    with ops.init_scope():
      if context.executing_eagerly():
        if previous_eager_value is not None:
          _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key] = previous_eager_value
        elif _DUMMY_EAGER_GRAPH.key in _GRAPH_LEARNING_PHASES:
          del _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key]

      graph = get_graph()
      if previous_graph_value is not None:
        _GRAPH_LEARNING_PHASES[graph] = previous_graph_value
      elif graph in _GRAPH_LEARNING_PHASES:
        del _GRAPH_LEARNING_PHASES[graph]

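# A minimal usage sketch for `learning_phase_scope` (illustrative): forcing
# train-mode behavior, e.g. active dropout, for a block of code. The previous
# phase is restored when the scope exits.
#
#   with tf.keras.backend.learning_phase_scope(1):
#     train_mode_outputs = model(inputs)  # layers see learning_phase == 1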

@tf_contextlib.contextmanager
def eager_learning_phase_scope(value):
  """Internal scope that sets the learning phase in eager / tf.function only.

  Arguments:
      value: Learning phase value, either 0 or 1 (integers).
             0 = test, 1 = train

  Yields:
    None.

  Raises:
     ValueError: if `value` is neither `0` nor `1`.
  """
  global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
  assert value in {0, 1}
  assert ops.executing_eagerly_outside_functions()
  global_learning_phase_was_set = global_learning_phase_is_set()
  if global_learning_phase_was_set:
    previous_value = learning_phase()
  try:
    _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key] = value
    yield
  finally:
    # Restore learning phase to initial value or unset.
    if global_learning_phase_was_set:
      _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key] = previous_value
    else:
      del _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key]


def _current_graph(op_input_list):
  """Return the graph members of `op_input_list`, or the current graph."""
  return ops._get_graph_from_inputs(op_input_list)


def _get_session(op_input_list=()):
  """Returns the session object for the current thread."""
  global _SESSION
  default_session = ops.get_default_session()
  if default_session is not None:
    session = default_session
  else:
    if ops.inside_function():
      raise RuntimeError('Cannot get session inside TensorFlow graph function.')
    # If we don't have a session, or that session does not match the current
    # graph, create and cache a new session.
    if (getattr(_SESSION, 'session', None) is None or
        _SESSION.session.graph is not _current_graph(op_input_list)):
      # If we are creating the Session inside a tf.distribute.Strategy scope,
      # we ask the strategy for the right session options to use.
      if distribution_strategy_context.has_strategy():
        configure_and_create_distributed_session(
            distribution_strategy_context.get_strategy())
      else:
        _SESSION.session = session_module.Session(
            config=get_default_session_config())
    session = _SESSION.session
  return session


@keras_export(v1=['keras.backend.get_session'])
def get_session(op_input_list=()):
  """Returns the TF session to be used by the backend.

  If a default TensorFlow session is available, we will return it.

  Else, we will return the global Keras session assuming it matches
  the current graph.

  If no global Keras session exists at this point,
  we will create a new global session.

  Note that you can manually set the global session
  via `K.set_session(sess)`.

  Arguments:
      op_input_list: An optional sequence of tensors or ops, which will be used
        to determine the current graph. Otherwise the default graph will be
        used.

  Returns:
      A TensorFlow session.
  """
  session = _get_session(op_input_list)
  if not _MANUAL_VAR_INIT:
    with session.graph.as_default():
      _initialize_variables(session)
  return session

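# A minimal usage sketch (illustrative; TF1-style code paths): the backend
# session can be read back or replaced explicitly.
#
#   sess = tf.compat.v1.keras.backend.get_session()
#   tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session())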

def get_graph():
  if context.executing_eagerly():
    global _GRAPH
    if _GRAPH is None:
      _GRAPH = func_graph.FuncGraph('keras_graph')
    return _GRAPH
  else:
    return ops.get_default_graph()


@tf_contextlib.contextmanager
def _scratch_graph(graph=None):
  """Retrieve a shared and temporary func graph.

  The eager execution path lifts a subgraph from the keras global graph into
  a scratch graph in order to create a function. DistributionStrategies, in
  turn, construct multiple functions as well as a final combined function. In
  order for that logic to work correctly, all of the functions need to be
  created on the same scratch FuncGraph.

  Args:
    graph: A graph to be used as the current scratch graph. If not set then
      a scratch graph will either be retrieved or created.

  Yields:
    The current scratch graph.
  """
  global _CURRENT_SCRATCH_GRAPH
  if (_CURRENT_SCRATCH_GRAPH is not None and graph is not None and
      _CURRENT_SCRATCH_GRAPH is not graph):
    raise ValueError('Multiple scratch graphs specified.')

  if _CURRENT_SCRATCH_GRAPH:
    yield _CURRENT_SCRATCH_GRAPH
    return

  graph = graph or func_graph.FuncGraph('keras_scratch_graph')
  try:
    _CURRENT_SCRATCH_GRAPH = graph
    yield graph
  finally:
    _CURRENT_SCRATCH_GRAPH = None


@keras_export(v1=['keras.backend.set_session'])
def set_session(session):
  """Sets the global TensorFlow session.

  Arguments:
      session: A TF Session.
  """
  global _SESSION
  _SESSION.session = session


def get_default_session_config():
  if os.environ.get('OMP_NUM_THREADS'):
    logging.warning(
        'OMP_NUM_THREADS is no longer used by the default Keras config. '
        'To configure the number of threads, use tf.config.threading APIs.')

  config = context.context().config
  config.allow_soft_placement = True

  return config


def get_default_graph_uid_map():
  graph = ops.get_default_graph()
  name_uid_map = PER_GRAPH_OBJECT_NAME_UIDS.get(graph, None)
  if name_uid_map is None:
    name_uid_map = collections.defaultdict(int)
    PER_GRAPH_OBJECT_NAME_UIDS[graph] = name_uid_map
  return name_uid_map


# DEVICE MANIPULATION


class _TfDeviceCaptureOp(object):
  """Class for capturing the TF device scope."""

  def __init__(self):
    self.device = None

  def _set_device(self, device):
    """This method captures TF's explicit device scope setting."""
    if tfdev.is_device_spec(device):
      device = device.to_string()
    self.device = device

  def _set_device_from_string(self, device_str):
    self.device = device_str


def _get_current_tf_device():
  """Return the explicit device of the current context, otherwise return `None`.

  Returns:
      If the current device scope is explicitly set, it returns a string with
      the device (`CPU` or `GPU`). If the scope is not explicitly set, it will
      return `None`.
  """
  graph = get_graph()
  op = _TfDeviceCaptureOp()
  graph._apply_device_functions(op)
  return tfdev.DeviceSpec.from_string(op.device)


def _is_current_explicit_device(device_type):
  """Check if the current device is explicitly set to the specified device type.

  Arguments:
      device_type: A string containing `GPU` or `CPU` (case-insensitive).

  Returns:
      A boolean indicating if the current device scope is explicitly set on the
      device type.

  Raises:
      ValueError: If the `device_type` string indicates an unsupported device.
  """
  device_type = device_type.upper()
  if device_type not in ['CPU', 'GPU']:
    raise ValueError('`device_type` should be either "CPU" or "GPU".')
  device = _get_current_tf_device()
  return device is not None and device.device_type == device_type.upper()


def _get_available_gpus():
  """Get a list of available GPU devices (formatted as strings).

  Returns:
      A list of available GPU devices.
  """
  if ops.executing_eagerly_outside_functions():
    # Returns names of devices directly.
    return [d.name for d in config.list_logical_devices('GPU')]

  global _LOCAL_DEVICES
  if _LOCAL_DEVICES is None:
    _LOCAL_DEVICES = get_session().list_devices()
  return [x.name for x in _LOCAL_DEVICES if x.device_type == 'GPU']


def _has_nchw_support():
  """Check whether the current scope supports NCHW ops.

  TensorFlow does not support NCHW on CPU. Therefore we check whether we are
  not explicitly placed on the CPU, and have GPUs available. In this case
  there will be soft-placement on the GPU device.

  Returns:
      bool: whether the current scope's device placement would support NCHW.
  """
  explicitly_on_cpu = _is_current_explicit_device('CPU')
  gpus_available = bool(_get_available_gpus())
  return not explicitly_on_cpu and gpus_available


# VARIABLE MANIPULATION


def _constant_to_tensor(x, dtype):
  """Convert the input `x` to a tensor of type `dtype`.

  This is slightly faster than the _to_tensor function, at the cost of
  handling fewer cases.

  Arguments:
      x: An object to be converted (numpy arrays, floats, ints and lists of
        them).
      dtype: The destination type.

  Returns:
      A tensor.
  """
  return constant_op.constant(x, dtype=dtype)


def _to_tensor(x, dtype):
  """Convert the input `x` to a tensor of type `dtype`.

  Arguments:
      x: An object to be converted (numpy array, list, tensors).
      dtype: The destination type.

  Returns:
      A tensor.
  """
  return ops.convert_to_tensor(x, dtype=dtype)


@keras_export('keras.backend.is_sparse')
def is_sparse(tensor):
  """Returns whether a tensor is a sparse tensor.

  Arguments:
      tensor: A tensor instance.

  Returns:
      A boolean.

  Example:


  >>> a = tf.keras.backend.placeholder((2, 2), sparse=False)
  >>> print(tf.keras.backend.is_sparse(a))
  False
  >>> b = tf.keras.backend.placeholder((2, 2), sparse=True)
  >>> print(tf.keras.backend.is_sparse(b))
  True

  """
  return isinstance(tensor, sparse_tensor.SparseTensor)


@keras_export('keras.backend.to_dense')
def to_dense(tensor):
  """Converts a sparse tensor into a dense tensor and returns it.

  Arguments:
      tensor: A tensor instance (potentially sparse).

  Returns:
      A dense tensor.

  Examples:


  >>> b = tf.keras.backend.placeholder((2, 2), sparse=True)
  >>> print(tf.keras.backend.is_sparse(b))
  True
  >>> c = tf.keras.backend.to_dense(b)
  >>> print(tf.keras.backend.is_sparse(c))
  False

  """
  if is_sparse(tensor):
    return sparse_ops.sparse_tensor_to_dense(tensor)
  else:
    return tensor


@keras_export('keras.backend.name_scope', v1=[])
def name_scope(name):
  """A context manager for use when defining a Python op.

  This context manager pushes a name scope, which will make the name of all
  operations added within it have a prefix.

  For example, to define a new Python op called `my_op`:


  def my_op(a):
    with tf.name_scope("MyOp") as scope:
      a = tf.convert_to_tensor(a, name="a")
      # Define some computation that uses `a`.
      return foo_op(..., name=scope)


  When executed, the Tensor `a` will have the name `MyOp/a`.

  Args:
    name: The prefix to use on all names created within the name scope.

  Returns:
    Name scope context manager.
  """
  return ops.name_scope_v2(name)


@keras_export('keras.backend.variable')
def variable(value, dtype=None, name=None, constraint=None):
  """Instantiates a variable and returns it.

  Arguments:
      value: Numpy array, initial value of the tensor.
      dtype: Tensor type.
      name: Optional name string for the tensor.
      constraint: Optional projection function to be
          applied to the variable after an optimizer update.

  Returns:
      A variable instance (with Keras metadata included).

  Examples:

  >>> val = np.array([[1, 2], [3, 4]])
  >>> kvar = tf.keras.backend.variable(value=val, dtype='float64',
  ...                                  name='example_var')
  >>> tf.keras.backend.dtype(kvar)
  'float64'
  >>> print(kvar)
  <tf.Variable 'example_var:...' shape=(2, 2) dtype=float64, numpy=
    array([[1., 2.],
           [3., 4.]])>

  """
  if dtype is None:
    dtype = floatx()
  if hasattr(value, 'tocoo'):
    sparse_coo = value.tocoo()
    indices = np.concatenate((np.expand_dims(sparse_coo.row, 1), np.expand_dims(
        sparse_coo.col, 1)), 1)
    v = sparse_tensor.SparseTensor(
        indices=indices, values=sparse_coo.data, dense_shape=sparse_coo.shape)
    v._keras_shape = sparse_coo.shape
    return v
  v = variables_module.Variable(
      value,
      dtype=dtypes_module.as_dtype(dtype),
      name=name,
      constraint=constraint)
  if isinstance(value, np.ndarray):
    v._keras_shape = value.shape
  elif hasattr(value, 'shape'):
    v._keras_shape = int_shape(value)
  track_variable(v)
  return v


def track_tf_optimizer(tf_optimizer):
  """Tracks the given TF optimizer for initialization of its variables."""
  if context.executing_eagerly():
    return
  graph = get_graph()
  optimizers = _GRAPH_TF_OPTIMIZERS.setdefault(graph, weakref.WeakSet())
  optimizers.add(tf_optimizer)


def track_variable(v):
  """Tracks the given variable for initialization."""
  if context.executing_eagerly():
    return
  graph = v.graph if hasattr(v, 'graph') else get_graph()
  if graph not in _GRAPH_VARIABLES:
    _GRAPH_VARIABLES[graph] = object_identity.ObjectIdentityWeakSet()
  _GRAPH_VARIABLES[graph].add(v)


def unique_object_name(name,
                       name_uid_map=None,
                       avoid_names=None,
                       namespace='',
                       zero_based=False):
  """Makes an object name (or arbitrary string) unique within a TensorFlow graph.

  Arguments:
    name: String name to make unique.
    name_uid_map: An optional defaultdict(int) to use when creating unique
      names. If None (default), uses a per-Graph dictionary.
    avoid_names: An optional set or dict with names which should not be used. If
      None (default), does not avoid any names.
    namespace: Gets a name which is unique within the (graph, namespace). Layers
      which are not Networks use a blank namespace and so get graph-global
      names.
    zero_based: If True, name sequences start with no suffix (e.g. "dense",
      "dense_1"). If False, naming is one-based ("dense_1", "dense_2").

  Returns:
    Unique string name.

  Example:


  _unique_layer_name('dense')  # dense_1
  _unique_layer_name('dense')  # dense_2

  """
  if name_uid_map is None:
    name_uid_map = get_default_graph_uid_map()
  if avoid_names is None:
    avoid_names = set()
  proposed_name = None
  while proposed_name is None or proposed_name in avoid_names:
    name_key = (namespace, name)
    if zero_based:
      number = name_uid_map[name_key]
      if number:
        proposed_name = name + '_' + str(number)
      else:
        proposed_name = name
      name_uid_map[name_key] += 1
    else:
      name_uid_map[name_key] += 1
      proposed_name = name + '_' + str(name_uid_map[name_key])
  return proposed_name


def _get_variables(graph=None):
  """Returns variables corresponding to the given graph for initialization."""
  assert not context.executing_eagerly()
  variables = _GRAPH_VARIABLES.setdefault(graph, weakref.WeakSet())
  for opt in _GRAPH_TF_OPTIMIZERS.get(graph, set()):
    variables.update(opt.optimizer.variables())
  return variables


def _initialize_variables(session):
  """Utility to initialize uninitialized variables on the fly."""
  variables = _get_variables(get_graph())
  candidate_vars = []
  for v in variables:
    if not getattr(v, '_keras_initialized', False):
      candidate_vars.append(v)
  if candidate_vars:
    # This step is expensive, so we only run it on variables not already
    # marked as initialized.
    is_initialized = session.run(
        [variables_module.is_variable_initialized(v) for v in candidate_vars])
    # TODO(kathywu): Some metric variables loaded from SavedModel are never
    # actually used, and do not have an initializer.
    should_be_initialized = [
        (not is_initialized[n]) and v.initializer is not None
        for n, v in enumerate(candidate_vars)]
    uninitialized_vars = []
    for flag, v in zip(should_be_initialized, candidate_vars):
      if flag:
        uninitialized_vars.append(v)
      v._keras_initialized = True
    if uninitialized_vars:
      session.run(variables_module.variables_initializer(uninitialized_vars))


@keras_export('keras.backend.constant')
def constant(value, dtype=None, shape=None, name=None):
  """Creates a constant tensor.

  Arguments:
      value: A constant value (or list)
      dtype: The type of the elements of the resulting tensor.
      shape: Optional dimensions of resulting tensor.
      name: Optional name for the tensor.

  Returns:
      A Constant Tensor.
  """
  if dtype is None:
    dtype = floatx()

  return constant_op.constant(value, dtype=dtype, shape=shape, name=name)

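# A minimal usage sketch for `constant` (illustrative): without an explicit
# `dtype`, the result uses `floatx()` (by default 'float32').
#
#   c = tf.keras.backend.constant([[1, 2], [3, 4]])
#   tf.keras.backend.eval(c)
#   # -> array([[1., 2.], [3., 4.]], dtype=float32)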

@keras_export('keras.backend.is_keras_tensor')
def is_keras_tensor(x):
  """Returns whether `x` is a Keras tensor.

  A "Keras tensor" is a tensor that was returned by a Keras layer
  (`Layer` class) or by `Input`.

  Arguments:
      x: A candidate tensor.

  Returns:
      A boolean: Whether the argument is a Keras tensor.

  Raises:
      ValueError: In case `x` is not a symbolic tensor.

  Examples:

  >>> np_var = np.array([1, 2])
  >>> # A numpy array is not a symbolic tensor.
  >>> tf.keras.backend.is_keras_tensor(np_var)
  Traceback (most recent call last):
  ...
  ValueError: Unexpectedly found an instance of type `<class 'numpy.ndarray'>`.
  Expected a symbolic tensor instance.
  >>> keras_var = tf.keras.backend.variable(np_var)
  >>> # A variable created with the keras backend is not a Keras tensor.
  >>> tf.keras.backend.is_keras_tensor(keras_var)
  False
  >>> keras_placeholder = tf.keras.backend.placeholder(shape=(2, 4, 5))
  >>> # A placeholder is not a Keras tensor.
  >>> tf.keras.backend.is_keras_tensor(keras_placeholder)
  False
  >>> keras_input = tf.keras.layers.Input([10])
  >>> # An Input is a Keras tensor.
  >>> tf.keras.backend.is_keras_tensor(keras_input)
  True
  >>> keras_layer_output = tf.keras.layers.Dense(10)(keras_input)
  >>> # Any Keras layer output is a Keras tensor.
  >>> tf.keras.backend.is_keras_tensor(keras_layer_output)
  True

  """
  if not isinstance(x,
                    (ops.Tensor, variables_module.Variable,
                     sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)):
    raise ValueError('Unexpectedly found an instance of type `' + str(type(x)) +
                     '`. Expected a symbolic tensor instance.')
  return hasattr(x, '_keras_history')


@keras_export('keras.backend.placeholder')
def placeholder(shape=None,
                ndim=None,
                dtype=None,
                sparse=False,
                name=None,
                ragged=False):
  """Instantiates a placeholder tensor and returns it.

  Arguments:
      shape: Shape of the placeholder
          (integer tuple, may include `None` entries).
      ndim: Number of axes of the tensor.
          At least one of {`shape`, `ndim`} must be specified.
          If both are specified, `shape` is used.
      dtype: Placeholder type.
      sparse: Boolean, whether the placeholder should have a sparse type.
      name: Optional name string for the placeholder.
      ragged: Boolean, whether the placeholder should have a ragged type.
          In this case, values of `None` in the `shape` argument represent
          ragged dimensions. For more information about RaggedTensors, see this
          [guide](https://www.tensorflow.org/guide/ragged_tensors).

  Raises:
      ValueError: If called with eager execution enabled.
      ValueError: If called with sparse = True and ragged = True.

  Returns:
      Tensor instance (with Keras metadata included).

  Examples:


  >>> input_ph = tf.keras.backend.placeholder(shape=(2, 4, 5))
  >>> input_ph
  <tf.Tensor 'Placeholder_...' shape=(2, 4, 5) dtype=float32>

  """
  if sparse and ragged:
    raise ValueError(
        'Cannot set both sparse and ragged to True when creating a placeholder.'
    )

  if dtype is None:
    dtype = floatx()
  if not shape:
    if ndim:
      shape = (None,) * ndim
  with get_graph().as_default():
    if sparse:
      x = array_ops.sparse_placeholder(dtype, shape=shape, name=name)
    elif ragged:
      ragged_rank = 0
      for i in range(1, len(shape)):
        if shape[i] is None:
          ragged_rank = i
      type_spec = ragged_tensor.RaggedTensorSpec(
          shape=shape, dtype=dtype, ragged_rank=ragged_rank)
      def tensor_spec_to_placeholder(tensorspec):
        return array_ops.placeholder(tensorspec.dtype, tensorspec.shape)
      x = nest.map_structure(tensor_spec_to_placeholder, type_spec,
                             expand_composites=True)
    else:
      x = array_ops.placeholder(dtype, shape=shape, name=name)
  return x


def is_placeholder(x):
  """Returns whether `x` is a placeholder.

  Arguments:
      x: A candidate placeholder.

  Returns:
      Boolean.
  """
  try:
    if isinstance(x, composite_tensor.CompositeTensor):
      flat_components = nest.flatten(x, expand_composites=True)
      return py_any(is_placeholder(c) for c in flat_components)
    else:
      return x.op.type == 'Placeholder'
  except AttributeError:
    return False

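# A minimal usage sketch for `is_placeholder` (illustrative):
#
#   x = tf.keras.backend.placeholder(shape=(2, 2))
#   is_placeholder(x)                                # -> True
#   is_placeholder(tf.keras.backend.constant([1.]))  # -> False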

def freezable_variable(value, shape=None, name=None):
  """A tensor-like object whose value can be updated only up until execution.

  After creating the freezable variable, you can update its value by calling
  `var.update_value(new_value)` (similar to a regular variable).
  Unlike an actual variable, the value used during execution is the current
  value at the time the execution function (`backend.function()`) was created.

  This is an internal API, expected to be temporary. It is used to implement a
  mutable `trainable` property for `BatchNormalization` layers, with a frozen
  value after model compilation.

  We don't use a plain variable in this case because we need the value used
  in a specific model to be frozen after `compile` has been called
  (e.g. GAN use case).

  Arguments:
    value: The initial value for the tensor-like object.
    shape: The shape for the tensor-like object (cannot be changed).
    name: The name for the tensor-like object.

  Returns:
    A tensor-like object with a static value that can be updated via
    `x.update_value(new_value)`, up until creating an execution function
    (afterwards the value is fixed).
  """
  graph = get_graph()
  with graph.as_default():
    x = array_ops.placeholder_with_default(
        value, shape=shape, name=name)
    x._initial_value = value
    x._current_value = value

    def update_value(new_value):
      x._current_value = new_value

    def get_value():
      return x._current_value

    x.update_value = update_value
    x.get_value = get_value

    global _FREEZABLE_VARS
    if graph not in _FREEZABLE_VARS:
      _FREEZABLE_VARS[graph] = object_identity.ObjectIdentityWeakSet()
    _FREEZABLE_VARS[graph].add(x)
  return x

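# A minimal usage sketch for `freezable_variable` (illustrative; internal API):
# the value can be updated until an execution function captures it.
#
#   flag = freezable_variable(True, shape=(), name='bn_trainable')
#   flag.update_value(False)
#   flag.get_value()  # -> False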

@keras_export('keras.backend.shape')
def shape(x):
  """Returns the symbolic shape of a tensor or variable.

  Arguments:
      x: A tensor or variable.

  Returns:
      A symbolic shape (which is itself a tensor).

  Examples:

  >>> val = np.array([[1, 2], [3, 4]])
  >>> kvar = tf.keras.backend.variable(value=val)
  >>> tf.keras.backend.shape(kvar)
  <tf.Tensor: shape=(2,), dtype=int32, numpy=array([2, 2], dtype=int32)>
  >>> input = tf.keras.backend.placeholder(shape=(2, 4, 5))
  >>> tf.keras.backend.shape(input)
  <tf.Tensor 'Shape_...' shape=(3,) dtype=int32>

  """
  return array_ops.shape(x)


@keras_export('keras.backend.int_shape')
def int_shape(x):
  """Returns the shape of tensor or variable as a tuple of int or None entries.

  Arguments:
      x: Tensor or variable.

  Returns:
      A tuple of integers (or None entries).

  Examples:

  >>> input = tf.keras.backend.placeholder(shape=(2, 4, 5))
  >>> tf.keras.backend.int_shape(input)
  (2, 4, 5)
  >>> val = np.array([[1, 2], [3, 4]])
  >>> kvar = tf.keras.backend.variable(value=val)
  >>> tf.keras.backend.int_shape(kvar)
  (2, 2)

  """
  try:
    shape = x.shape
    if not isinstance(shape, tuple):
      shape = tuple(shape.as_list())
    return shape
  except ValueError:
    return None


@keras_export('keras.backend.ndim')
def ndim(x):
  """Returns the number of axes in a tensor, as an integer.

  Arguments:
      x: Tensor or variable.

  Returns:
      Integer (scalar), number of axes.

  Examples:


  >>> input = tf.keras.backend.placeholder(shape=(2, 4, 5))
  >>> val = np.array([[1, 2], [3, 4]])
  >>> kvar = tf.keras.backend.variable(value=val)
  >>> tf.keras.backend.ndim(input)
  3
  >>> tf.keras.backend.ndim(kvar)
  2

  """
  dims = x.shape._dims
  if dims is not None:
    return len(dims)
  return None


@keras_export('keras.backend.dtype')
def dtype(x):
  """Returns the dtype of a Keras tensor or variable, as a string.

  Arguments:
      x: Tensor or variable.

  Returns:
      String, dtype of `x`.

  Examples:

  >>> tf.keras.backend.dtype(tf.keras.backend.placeholder(shape=(2,4,5)))
  'float32'
  >>> tf.keras.backend.dtype(tf.keras.backend.placeholder(shape=(2,4,5),
  ...                                                     dtype='float32'))
  'float32'
  >>> tf.keras.backend.dtype(tf.keras.backend.placeholder(shape=(2,4,5),
  ...                                                     dtype='float64'))
  'float64'
  >>> kvar = tf.keras.backend.variable(np.array([[1, 2], [3, 4]]))
  >>> tf.keras.backend.dtype(kvar)
  'float32'
  >>> kvar = tf.keras.backend.variable(np.array([[1, 2], [3, 4]]),
  ...                                  dtype='float32')
  >>> tf.keras.backend.dtype(kvar)
  'float32'

  """
  return x.dtype.base_dtype.name


@keras_export('keras.backend.eval')
def eval(x):
  """Evaluates the value of a variable.

  Arguments:
      x: A variable.

  Returns:
      A Numpy array.

  Examples:

  >>> kvar = tf.keras.backend.variable(np.array([[1, 2], [3, 4]]),
  ...                                  dtype='float32')
  >>> tf.keras.backend.eval(kvar)
  array([[1.,  2.],
         [3.,  4.]], dtype=float32)

  """
  return get_value(to_dense(x))


@keras_export('keras.backend.zeros')
def zeros(shape, dtype=None, name=None):
  """Instantiates an all-zeros variable and returns it.

  Arguments:
      shape: Tuple or list of integers, shape of returned Keras variable
      dtype: data type of returned Keras variable
      name: name of returned Keras variable

  Returns:
      A variable (including Keras metadata), filled with `0.0`.
      Note that if `shape` was symbolic, we cannot return a variable,
      and will return a dynamically-shaped tensor instead.

  Example:

  >>> kvar = tf.keras.backend.zeros((3,4))
  >>> tf.keras.backend.eval(kvar)
  array([[0.,  0.,  0.,  0.],
         [0.,  0.,  0.,  0.],
         [0.,  0.,  0.,  0.]], dtype=float32)
  >>> A = tf.constant([1,2,3])
  >>> kvar2 = tf.keras.backend.zeros(A.shape) # [0., 0., 0.]
  >>> tf.keras.backend.eval(kvar2)
  array([0., 0., 0.], dtype=float32)
  >>> kvar3 = tf.keras.backend.zeros(A.shape,dtype=tf.int32)
  >>> tf.keras.backend.eval(kvar3)
  array([0, 0, 0], dtype=int32)
  >>> kvar4 = tf.keras.backend.zeros([2,3])
  >>> tf.keras.backend.eval(kvar4)
  array([[0., 0., 0.],
         [0., 0., 0.]], dtype=float32)

  """
  with ops.init_scope():
    if dtype is None:
      dtype = floatx()
    tf_dtype = dtypes_module.as_dtype(dtype)
    v = array_ops.zeros(shape=shape, dtype=tf_dtype, name=name)
    if py_all(v.shape.as_list()):
      return variable(v, dtype=dtype, name=name)
    return v


@keras_export('keras.backend.ones')
def ones(shape, dtype=None, name=None):
  """Instantiates an all-ones variable and returns it.

  Arguments:
      shape: Tuple of integers, shape of returned Keras variable.
      dtype: String, data type of returned Keras variable.
      name: String, name of returned Keras variable.

  Returns:
      A Keras variable, filled with `1.0`.
      Note that if `shape` was symbolic, we cannot return a variable,
      and will return a dynamically-shaped tensor instead.

  Example:


  >>> kvar = tf.keras.backend.ones((3,4))
  >>> tf.keras.backend.eval(kvar)
  array([[1.,  1.,  1.,  1.],
         [1.,  1.,  1.,  1.],
         [1.,  1.,  1.,  1.]], dtype=float32)

  """
  with ops.init_scope():
    if dtype is None:
      dtype = floatx()
    tf_dtype = dtypes_module.as_dtype(dtype)
    v = array_ops.ones(shape=shape, dtype=tf_dtype, name=name)
    if py_all(v.shape.as_list()):
      return variable(v, dtype=dtype, name=name)
    return v


@keras_export('keras.backend.eye')
def eye(size, dtype=None, name=None):
  """Instantiates an identity matrix and returns it.

  Arguments:
      size: Integer, number of rows/columns.
      dtype: String, data type of returned Keras variable.
      name: String, name of returned Keras variable.

  Returns:
      A Keras variable, an identity matrix.

  Example:


  >>> kvar = tf.keras.backend.eye(3)
  >>> tf.keras.backend.eval(kvar)
  array([[1.,  0.,  0.],
         [0.,  1.,  0.],
         [0.,  0.,  1.]], dtype=float32)


  """
  if dtype is None:
    dtype = floatx()
  tf_dtype = dtypes_module.as_dtype(dtype)
  return variable(linalg_ops.eye(size, dtype=tf_dtype), dtype, name)


@keras_export('keras.backend.zeros_like')
def zeros_like(x, dtype=None, name=None):
  """Instantiates an all-zeros variable of the same shape as another tensor.

  Arguments:
      x: Keras variable or Keras tensor.
      dtype: dtype of returned Keras variable.
             `None` uses the dtype of `x`.
      name: name for the variable to create.

  Returns:
      A Keras variable with the shape of `x` filled with zeros.

  Example:


  from tensorflow.keras import backend as K
  kvar = K.variable(np.random.random((2,3)))
  kvar_zeros = K.zeros_like(kvar)
  K.eval(kvar_zeros)
  # array([[ 0.,  0.,  0.], [ 0.,  0.,  0.]], dtype=float32)


  """
  return array_ops.zeros_like(x, dtype=dtype, name=name)


@keras_export('keras.backend.ones_like')
def ones_like(x, dtype=None, name=None):
  """Instantiates an all-ones variable of the same shape as another tensor.

  Arguments:
      x: Keras variable or tensor.
      dtype: String, dtype of returned Keras variable.
           None uses the dtype of x.
      name: String, name for the variable to create.

  Returns:
      A Keras variable with the shape of x filled with ones.

  Example:

  >>> kvar = tf.keras.backend.variable(np.random.random((2,3)))
  >>> kvar_ones = tf.keras.backend.ones_like(kvar)
  >>> tf.keras.backend.eval(kvar_ones)
  array([[1.,  1.,  1.],
         [1.,  1.,  1.]], dtype=float32)

  """
  return array_ops.ones_like(x, dtype=dtype, name=name)


def identity(x, name=None):
  """Returns a tensor with the same content as the input tensor.

  Arguments:
      x: The input tensor.
      name: String, name for the variable to create.

  Returns:
      A tensor of the same shape, type and content.
  """
  return array_ops.identity(x, name=name)


@keras_export('keras.backend.random_uniform_variable')
def random_uniform_variable(shape, low, high, dtype=None, name=None, seed=None):
  """Instantiates a variable with values drawn from a uniform distribution.

  Arguments:
      shape: Tuple of integers, shape of returned Keras variable.
      low: Float, lower boundary of the output interval.
      high: Float, upper boundary of the output interval.
      dtype: String, dtype of returned Keras variable.
      name: String, name of returned Keras variable.
      seed: Integer, random seed.

  Returns:
      A Keras variable, filled with drawn samples.

  Example:

  >>> kvar = tf.keras.backend.random_uniform_variable((2,3), 0, 1)
  >>> kvar
  <tf.Variable 'Variable:0' shape=(2, 3) dtype=float32, numpy=...,
  dtype=float32)>
  """
  if dtype is None:
    dtype = floatx()
  tf_dtype = dtypes_module.as_dtype(dtype)
  if seed is None:
    # ensure that randomness is conditioned by the Numpy RNG
    seed = np.random.randint(10e8)
  value = init_ops.random_uniform_initializer(
      low, high, dtype=tf_dtype, seed=seed)(shape)
  return variable(value, dtype=dtype, name=name)


@keras_export('keras.backend.random_normal_variable')
def random_normal_variable(shape, mean, scale, dtype=None, name=None,
                           seed=None):
  """Instantiates a variable with values drawn from a normal distribution.

  Arguments:
      shape: Tuple of integers, shape of returned Keras variable.
      mean: Float, mean of the normal distribution.
      scale: Float, standard deviation of the normal distribution.
      dtype: String, dtype of returned Keras variable.
      name: String, name of returned Keras variable.
      seed: Integer, random seed.

  Returns:
      A Keras variable, filled with drawn samples.

  Example:

  >>> kvar = tf.keras.backend.random_normal_variable((2,3), 0, 1)
  >>> kvar
  <tf.Variable 'Variable:0' shape=(2, 3) dtype=float32, numpy=...,
  dtype=float32)>
  """
  if dtype is None:
    dtype = floatx()
  tf_dtype = dtypes_module.as_dtype(dtype)
  if seed is None:
    # ensure that randomness is conditioned by the Numpy RNG
    seed = np.random.randint(10e8)
  value = init_ops.random_normal_initializer(
      mean, scale, dtype=tf_dtype, seed=seed)(shape)
  return variable(value, dtype=dtype, name=name)


@keras_export('keras.backend.count_params')
def count_params(x):
  """Returns the static number of elements in a variable or tensor.

  Arguments:
      x: Variable or tensor.

  Returns:
      Integer, the number of scalars in `x`.

  Example:

  >>> kvar = tf.keras.backend.zeros((2,3))
  >>> tf.keras.backend.count_params(kvar)
  6
  >>> tf.keras.backend.eval(kvar)
  array([[0.,  0.,  0.],
         [0.,  0.,  0.]], dtype=float32)

  """
  return np.prod(x.shape.as_list())


@keras_export('keras.backend.cast')
def cast(x, dtype):
  """Casts a tensor to a different dtype and returns it.

  You can cast a Keras variable but it still returns a Keras tensor.

  Arguments:
      x: Keras tensor (or variable).
      dtype: String, one of `'float16'`, `'float32'`, or `'float64'`.

  Returns:
      Keras tensor with dtype `dtype`.

  Examples:
      Cast a float32 variable to a float64 tensor

  >>> input = tf.keras.backend.ones(shape=(1,3))
  >>> print(input)
  <tf.Variable 'Variable:0' shape=(1, 3) dtype=float32,
  numpy=array([[1., 1., 1.]], dtype=float32)>
  >>> cast_input = tf.keras.backend.cast(input, dtype='float64')
  >>> print(cast_input)
  tf.Tensor([[1. 1. 1.]], shape=(1, 3), dtype=float64)

  """
  return math_ops.cast(x, dtype)


# UPDATES OPS


@keras_export('keras.backend.update')
def update(x, new_x):
  """Update the value of `x` to `new_x`, returning the updated variable."""
  return state_ops.assign(x, new_x)


@keras_export('keras.backend.update_add')
def update_add(x, increment):
  """Update the value of `x` by adding `increment`.

  Arguments:
      x: A Variable.
      increment: A tensor of same shape as `x`.

  Returns:
      The variable `x` updated.
  """
  return state_ops.assign_add(x, increment)


@keras_export('keras.backend.update_sub')
def update_sub(x, decrement):
  """Update the value of `x` by subtracting `decrement`.

  Arguments:
      x: A Variable.
      decrement: A tensor of same shape as `x`.

  Returns:
      The variable `x` updated.
  """
  return state_ops.assign_sub(x, decrement)

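# A minimal usage sketch for `update_add` / `update_sub` (illustrative): both
# wrap `assign_add` / `assign_sub`, and the result can be read back with `eval`.
#
#   v = tf.keras.backend.variable(np.array([1., 2.]))
#   tf.keras.backend.eval(tf.keras.backend.update_add(v, [1., 1.]))
#   # -> array([2., 3.], dtype=float32)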

@keras_export('keras.backend.moving_average_update')
def moving_average_update(x, value, momentum):
  """Compute the moving average of a variable.

  Arguments:
      x: A Variable.
      value: A tensor with the same shape as `x`.
      momentum: The moving average momentum.

  Returns:
      An Operation to update the variable.
  """
  zero_debias = not tf2.enabled()
  return moving_averages.assign_moving_average(
      x, value, momentum, zero_debias=zero_debias)

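# Note (illustrative): ignoring zero-debiasing, `assign_moving_average` above
# applies the update `x = x * momentum + value * (1 - momentum)`.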

# LINEAR ALGEBRA


@keras_export('keras.backend.dot')
def dot(x, y):
  """Multiplies 2 tensors (and/or variables) and returns a tensor.

  Arguments:
      x: Tensor or variable.
      y: Tensor or variable.

  Returns:
      A tensor, dot product of `x` and `y`.

  Examples:

  >>> x = tf.keras.backend.placeholder(shape=(2, 3))
  >>> y = tf.keras.backend.placeholder(shape=(3, 4))
  >>> xy = tf.keras.backend.dot(x, y)
  >>> xy
  <tf.Tensor ... shape=(2, 4) dtype=float32>

  >>> x = tf.keras.backend.placeholder(shape=(32, 28, 3))
  >>> y = tf.keras.backend.placeholder(shape=(3, 4))
  >>> xy = tf.keras.backend.dot(x, y)
  >>> xy
  <tf.Tensor ... shape=(32, 28, 4) dtype=float32>

  >>> x = tf.keras.backend.random_uniform_variable(shape=(2, 3), low=0, high=1)
  >>> y = tf.keras.backend.ones((4, 3, 5))
  >>> xy = tf.keras.backend.dot(x, y)
  >>> tf.keras.backend.int_shape(xy)
  (2, 4, 5)
  """
  if ndim(x) is not None and (ndim(x) > 2 or ndim(y) > 2):
    x_shape = []
    for i, s in zip(int_shape(x), array_ops.unstack(array_ops.shape(x))):
      if i is not None:
        x_shape.append(i)
      else:
        x_shape.append(s)
    x_shape = tuple(x_shape)
    y_shape = []
    for i, s in zip(int_shape(y), array_ops.unstack(array_ops.shape(y))):
      if i is not None:
        y_shape.append(i)
      else:
        y_shape.append(s)
    y_shape = tuple(y_shape)
    y_permute_dim = list(range(ndim(y)))
    y_permute_dim = [y_permute_dim.pop(-2)] + y_permute_dim
    xt = array_ops.reshape(x, [-1, x_shape[-1]])
    yt = array_ops.reshape(
        array_ops.transpose(y, perm=y_permute_dim), [y_shape[-2], -1])
    return array_ops.reshape(
        math_ops.matmul(xt, yt), x_shape[:-1] + y_shape[:-2] + y_shape[-1:])
  if is_sparse(x):
    out = sparse_ops.sparse_tensor_dense_matmul(x, y)
  else:
    out = math_ops.matmul(x, y)
  return out


1705@keras_export('keras.backend.batch_dot')
1706def batch_dot(x, y, axes=None):
1707  """Batchwise dot product.
1708
1709  `batch_dot` is used to compute dot product of `x` and `y` when
1710  `x` and `y` are data in batch, i.e. in a shape of
1711  `(batch_size, :)`.
1712  `batch_dot` results in a tensor or variable with fewer dimensions
1713  than the input. If the number of dimensions is reduced to 1,
1714  we use `expand_dims` to make sure that ndim is at least 2.
1715
1716  Arguments:
1717    x: Keras tensor or variable with `ndim >= 2`.
1718    y: Keras tensor or variable with `ndim >= 2`.
1719    axes: Tuple or list of integers with target dimensions, or single integer.
1720      The sizes of `x.shape[axes[0]]` and `y.shape[axes[1]]` should be equal.
1721
1722  Returns:
1723    A tensor with shape equal to the concatenation of `x`'s shape
1724    (less the dimension that was summed over) and `y`'s shape
1725    (less the batch dimension and the dimension that was summed over).
1726    If the final rank is 1, we reshape it to `(batch_size, 1)`.
1727
1728  Examples:
1729
1730  >>> x_batch = tf.keras.backend.ones(shape=(32, 20, 1))
1731  >>> y_batch = tf.keras.backend.ones(shape=(32, 30, 20))
1732  >>> xy_batch_dot = tf.keras.backend.batch_dot(x_batch, y_batch, axes=(1, 2))
1733  >>> tf.keras.backend.int_shape(xy_batch_dot)
1734  (32, 1, 30)
1735
1736  Shape inference:
1737    Let `x`'s shape be `(100, 20)` and `y`'s shape be `(100, 30, 20)`.
1738    If `axes` is (1, 2), to find the output shape of the resultant tensor,
1739        loop through each dimension in `x`'s shape and `y`'s shape:
1740    * `x.shape[0]` : 100 : append to output shape
1741    * `x.shape[1]` : 20 : do not append to output shape,
1742        dimension 1 of `x` has been summed over. (`dot_axes[0]` = 1)
1743    * `y.shape[0]` : 100 : do not append to output shape,
1744        always ignore first dimension of `y`
1745    * `y.shape[1]` : 30 : append to output shape
1746    * `y.shape[2]` : 20 : do not append to output shape,
1747        dimension 2 of `y` has been summed over. (`dot_axes[1]` = 2)
1748    `output_shape` = `(100, 30)`
1749  """
1750  x_shape = int_shape(x)
1751  y_shape = int_shape(y)
1752
1753  x_ndim = len(x_shape)
1754  y_ndim = len(y_shape)
1755
1756  if x_ndim < 2 or y_ndim < 2:
1757    raise ValueError('Cannot do batch_dot on inputs '
1758                     'with rank < 2. '
1759                     'Received inputs with shapes ' +
1760                     str(x_shape) + ' and ' +
1761                     str(y_shape) + '.')
1762
1763  x_batch_size = x_shape[0]
1764  y_batch_size = y_shape[0]
1765
1766  if x_batch_size is not None and y_batch_size is not None:
1767    if x_batch_size != y_batch_size:
1768      raise ValueError('Cannot do batch_dot on inputs '
1769                       'with different batch sizes. '
1770                       'Received inputs with shapes ' +
1771                       str(x_shape) + ' and ' +
1772                       str(y_shape) + '.')
1773  if isinstance(axes, int):
1774    axes = [axes, axes]
1775
1776  if axes is None:
1777    if y_ndim == 2:
1778      axes = [x_ndim - 1, y_ndim - 1]
1779    else:
1780      axes = [x_ndim - 1, y_ndim - 2]
1781
1782  if py_any(isinstance(a, (list, tuple)) for a in axes):
1783    raise ValueError('Multiple target dimensions are not supported. ' +
1784                     'Expected: None, int, (int, int), ' +
1785                     'Provided: ' + str(axes))
1786
1787  # if tuple, convert to list.
1788  axes = list(axes)
1789
1790  # convert negative indices.
1791  if axes[0] < 0:
1792    axes[0] += x_ndim
1793  if axes[1] < 0:
1794    axes[1] += y_ndim
1795
1796  # sanity checks
1797  if 0 in axes:
1798    raise ValueError('Cannot perform batch_dot over axis 0. '
1799                     'If your inputs are not batched, '
1800                     'add a dummy batch dimension to your '
1801                     'inputs using K.expand_dims(x, 0)')
1802  a0, a1 = axes
1803  d1 = x_shape[a0]
1804  d2 = y_shape[a1]
1805
1806  if d1 is not None and d2 is not None and d1 != d2:
1807    raise ValueError('Cannot do batch_dot on inputs with shapes ' +
1808                     str(x_shape) + ' and ' + str(y_shape) +
1809                     ' with axes=' + str(axes) + '. x.shape[%d] != '
1810                     'y.shape[%d] (%d != %d).' % (axes[0], axes[1], d1, d2))
1811
1812  # backup ndims. Need them later.
1813  orig_x_ndim = x_ndim
1814  orig_y_ndim = y_ndim
1815
1816  # if rank is 2, expand to 3.
1817  if x_ndim == 2:
1818    x = array_ops.expand_dims(x, 1)
1819    a0 += 1
1820    x_ndim += 1
1821  if y_ndim == 2:
1822    y = array_ops.expand_dims(y, 2)
1823    y_ndim += 1
1824
1825  # bring x's dimension to be reduced to last axis.
1826  if a0 != x_ndim - 1:
1827    pattern = list(range(x_ndim))
1828    for i in range(a0, x_ndim - 1):
1829      pattern[i] = pattern[i + 1]
1830    pattern[-1] = a0
1831    x = array_ops.transpose(x, pattern)
1832
1833  # bring y's dimension to be reduced to axis 1.
1834  if a1 != 1:
1835    pattern = list(range(y_ndim))
1836    for i in range(a1, 1, -1):
1837      pattern[i] = pattern[i - 1]
1838    pattern[1] = a1
1839    y = array_ops.transpose(y, pattern)
1840
1841  # normalize both inputs to rank 3.
1842  if x_ndim > 3:
1843    # squash middle dimensions of x.
1844    x_shape = shape(x)
1845    x_mid_dims = x_shape[1:-1]
1846    x_squashed_shape = array_ops.stack(
1847        [x_shape[0], -1, x_shape[-1]])
1848    x = array_ops.reshape(x, x_squashed_shape)
1849    x_squashed = True
1850  else:
1851    x_squashed = False
1852
1853  if y_ndim > 3:
1854    # squash trailing dimensions of y.
1855    y_shape = shape(y)
1856    y_trail_dims = y_shape[2:]
1857    y_squashed_shape = array_ops.stack(
1858        [y_shape[0], y_shape[1], -1])
1859    y = array_ops.reshape(y, y_squashed_shape)
1860    y_squashed = True
1861  else:
1862    y_squashed = False
1863
1864  result = math_ops.matmul(x, y)
1865
1866  # if inputs were squashed, we have to reshape the matmul output.
1867  output_shape = array_ops.shape(result)
1868  do_reshape = False
1869
1870  if x_squashed:
1871    output_shape = array_ops.concat(
1872        [output_shape[:1],
1873         x_mid_dims,
1874         output_shape[-1:]], 0)
1875    do_reshape = True
1876
1877  if y_squashed:
1878    output_shape = array_ops.concat([output_shape[:-1], y_trail_dims], 0)
1879    do_reshape = True
1880
1881  if do_reshape:
1882    result = array_ops.reshape(result, output_shape)
1883
1884  # if the inputs were originally rank 2, we remove the added 1 dim.
1885  if orig_x_ndim == 2:
1886    result = array_ops.squeeze(result, 1)
1887  elif orig_y_ndim == 2:
1888    result = array_ops.squeeze(result, -1)
1889
1890  return result
1891
1892
1893@keras_export('keras.backend.transpose')
1894def transpose(x):
1895  """Transposes a tensor and returns it.
1896
1897  Arguments:
1898      x: Tensor or variable.
1899
1900  Returns:
1901      A tensor.
1902
1903  Examples:
1904
1905  >>> var = tf.keras.backend.variable([[1, 2, 3], [4, 5, 6]])
1906  >>> tf.keras.backend.eval(var)
1907  array([[1.,  2.,  3.],
1908         [4.,  5.,  6.]], dtype=float32)
1909  >>> var_transposed = tf.keras.backend.transpose(var)
1910  >>> tf.keras.backend.eval(var_transposed)
1911  array([[1.,  4.],
1912         [2.,  5.],
1913         [3.,  6.]], dtype=float32)
1914  >>> input = tf.keras.backend.placeholder((2, 3))
1915  >>> input
1916  <tf.Tensor 'Placeholder_...' shape=(2, 3) dtype=float32>
1917  >>> input_transposed = tf.keras.backend.transpose(input)
1918  >>> input_transposed
1919  <tf.Tensor 'Transpose_...' shape=(3, 2) dtype=float32>
1920  """
1921  return array_ops.transpose(x)
1922
1923
1924@keras_export('keras.backend.gather')
1925def gather(reference, indices):
1926  """Retrieves the elements of indices `indices` in the tensor `reference`.
1927
1928  Arguments:
1929      reference: A tensor.
1930      indices: An integer tensor of indices.
1931
1932  Returns:
1933      A tensor of same type as `reference`.
1934
1935  Examples:
1936
1937  >>> var = tf.keras.backend.variable([[1, 2, 3], [4, 5, 6]])
1938  >>> tf.keras.backend.eval(var)
1939  array([[1., 2., 3.],
1940         [4., 5., 6.]], dtype=float32)
1941  >>> var_gathered = tf.keras.backend.gather(var, [0])
1942  >>> tf.keras.backend.eval(var_gathered)
1943  array([[1., 2., 3.]], dtype=float32)
1944  >>> var_gathered = tf.keras.backend.gather(var, [1])
1945  >>> tf.keras.backend.eval(var_gathered)
1946  array([[4., 5., 6.]], dtype=float32)
1947  >>> var_gathered = tf.keras.backend.gather(var, [0,1,0])
1948  >>> tf.keras.backend.eval(var_gathered)
1949  array([[1., 2., 3.],
1950         [4., 5., 6.],
1951         [1., 2., 3.]], dtype=float32)
1952  """
1953  return array_ops.gather(reference, indices)
1954
1955
1956# ELEMENT-WISE OPERATIONS
1957
1958
1959@keras_export('keras.backend.max')
1960def max(x, axis=None, keepdims=False):
1961  """Maximum value in a tensor.
1962
1963  Arguments:
1964      x: A tensor or variable.
1965      axis: An integer, the axis to find maximum values.
1966      keepdims: A boolean, whether to keep the dimensions or not.
1967          If `keepdims` is `False`, the rank of the tensor is reduced
1968          by 1. If `keepdims` is `True`,
1969          the reduced dimension is retained with length 1.
1970
1971  Returns:
1972      A tensor with maximum values of `x`.
1973  """
1974  return math_ops.reduce_max(x, axis, keepdims)
1975
1976
1977@keras_export('keras.backend.min')
1978def min(x, axis=None, keepdims=False):
1979  """Minimum value in a tensor.
1980
1981  Arguments:
1982      x: A tensor or variable.
1983      axis: An integer, the axis to find minimum values.
1984      keepdims: A boolean, whether to keep the dimensions or not.
1985          If `keepdims` is `False`, the rank of the tensor is reduced
1986          by 1. If `keepdims` is `True`,
1987          the reduced dimension is retained with length 1.
1988
1989  Returns:
1990      A tensor with minimum values of `x`.
1991  """
1992  return math_ops.reduce_min(x, axis, keepdims)
1993
1994
1995@keras_export('keras.backend.sum')
1996def sum(x, axis=None, keepdims=False):
1997  """Sum of the values in a tensor, alongside the specified axis.
1998
1999  Arguments:
2000      x: A tensor or variable.
2001      axis: An integer, the axis to sum over.
2002      keepdims: A boolean, whether to keep the dimensions or not.
2003          If `keepdims` is `False`, the rank of the tensor is reduced
2004          by 1. If `keepdims` is `True`,
2005          the reduced dimension is retained with length 1.
2006
2007  Returns:
2008      A tensor with sum of `x`.
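
  Example (a minimal sketch; the values shown are illustrative):

  >>> x = tf.keras.backend.constant([[1., 2.], [3., 4.]])
  >>> tf.keras.backend.eval(tf.keras.backend.sum(x, axis=1))
  array([3., 7.], dtype=float32)
  >>> tf.keras.backend.eval(tf.keras.backend.sum(x, axis=1, keepdims=True))
  array([[3.],
         [7.]], dtype=float32)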
2009  """
2010  return math_ops.reduce_sum(x, axis, keepdims)
2011
2012
2013@keras_export('keras.backend.prod')
2014def prod(x, axis=None, keepdims=False):
2015  """Multiplies the values in a tensor, alongside the specified axis.
2016
2017  Arguments:
2018      x: A tensor or variable.
2019      axis: An integer, the axis to compute the product.
2020      keepdims: A boolean, whether to keep the dimensions or not.
2021          If `keepdims` is `False`, the rank of the tensor is reduced
2022          by 1. If `keepdims` is `True`,
2023          the reduced dimension is retained with length 1.
2024
2025  Returns:
2026      A tensor with the product of elements of `x`.
2027  """
2028  return math_ops.reduce_prod(x, axis, keepdims)
2029
2030
2031@keras_export('keras.backend.cumsum')
2032def cumsum(x, axis=0):
2033  """Cumulative sum of the values in a tensor, alongside the specified axis.
2034
2035  Arguments:
2036      x: A tensor or variable.
2037      axis: An integer, the axis to compute the sum.
2038
2039  Returns:
2040      A tensor of the cumulative sum of values of `x` along `axis`.
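
  Example (a minimal sketch with illustrative values):

  >>> x = tf.keras.backend.constant([[1., 2., 3.], [4., 5., 6.]])
  >>> tf.keras.backend.eval(tf.keras.backend.cumsum(x, axis=1))
  array([[ 1.,  3.,  6.],
         [ 4.,  9., 15.]], dtype=float32)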
2041  """
2042  return math_ops.cumsum(x, axis=axis)
2043
2044
2045@keras_export('keras.backend.cumprod')
2046def cumprod(x, axis=0):
2047  """Cumulative product of the values in a tensor, alongside the specified axis.
2048
2049  Arguments:
2050      x: A tensor or variable.
2051      axis: An integer, the axis to compute the product.
2052
2053  Returns:
2054      A tensor of the cumulative product of values of `x` along `axis`.
2055  """
2056  return math_ops.cumprod(x, axis=axis)
2057
2058
2059@keras_export('keras.backend.var')
2060def var(x, axis=None, keepdims=False):
2061  """Variance of a tensor, alongside the specified axis.
2062
2063  Arguments:
2064      x: A tensor or variable.
2065      axis: An integer, the axis to compute the variance.
2066      keepdims: A boolean, whether to keep the dimensions or not.
2067          If `keepdims` is `False`, the rank of the tensor is reduced
2068          by 1. If `keepdims` is `True`,
2069          the reduced dimension is retained with length 1.
2070
2071  Returns:
2072      A tensor with the variance of elements of `x`.
2073  """
2074  if x.dtype.base_dtype == dtypes_module.bool:
2075    x = math_ops.cast(x, floatx())
2076  return math_ops.reduce_variance(x, axis=axis, keepdims=keepdims)
2077
2078
2079@keras_export('keras.backend.std')
2080def std(x, axis=None, keepdims=False):
2081  """Standard deviation of a tensor, alongside the specified axis.
2082
2083  Arguments:
2084      x: A tensor or variable.
2085      axis: An integer, the axis to compute the standard deviation.
2086      keepdims: A boolean, whether to keep the dimensions or not.
2087          If `keepdims` is `False`, the rank of the tensor is reduced
2088          by 1. If `keepdims` is `True`,
2089          the reduced dimension is retained with length 1.
2090
2091  Returns:
2092      A tensor with the standard deviation of elements of `x`.
2093  """
2094  if x.dtype.base_dtype == dtypes_module.bool:
2095    x = math_ops.cast(x, floatx())
2096  return math_ops.reduce_std(x, axis=axis, keepdims=keepdims)
2097
2098
2099@keras_export('keras.backend.mean')
2100def mean(x, axis=None, keepdims=False):
2101  """Mean of a tensor, alongside the specified axis.
2102
2103  Arguments:
2104      x: A tensor or variable.
2105      axis: A list of integers. Axes along which to compute the mean.
2106      keepdims: A boolean, whether to keep the dimensions or not.
2107          If `keepdims` is `False`, the rank of the tensor is reduced
2108          by 1 for each entry in `axis`. If `keepdims` is `True`,
2109          the reduced dimensions are retained with length 1.
2110
2111  Returns:
2112      A tensor with the mean of elements of `x`.
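
  Example (a minimal sketch; the values shown are illustrative):

  >>> x = tf.keras.backend.constant([[1., 2.], [3., 4.]])
  >>> tf.keras.backend.eval(tf.keras.backend.mean(x))
  2.5
  >>> tf.keras.backend.eval(tf.keras.backend.mean(x, axis=0))
  array([2., 3.], dtype=float32)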
2113  """
2114  if x.dtype.base_dtype == dtypes_module.bool:
2115    x = math_ops.cast(x, floatx())
2116  return math_ops.reduce_mean(x, axis, keepdims)
2117
2118
2119@keras_export('keras.backend.any')
2120def any(x, axis=None, keepdims=False):
2121  """Bitwise reduction (logical OR).
2122
2123  Arguments:
2124      x: Tensor or variable.
2125      axis: axis along which to perform the reduction.
2126      keepdims: whether to drop or broadcast the reduction axes.
2127
2128  Returns:
2129      A bool tensor.
2130  """
2131  x = math_ops.cast(x, dtypes_module.bool)
2132  return math_ops.reduce_any(x, axis, keepdims)
2133
2134
2135@keras_export('keras.backend.all')
2136def all(x, axis=None, keepdims=False):
2137  """Bitwise reduction (logical AND).
2138
2139  Arguments:
2140      x: Tensor or variable.
2141      axis: axis along which to perform the reduction.
2142      keepdims: whether to drop or broadcast the reduction axes.
2143
2144  Returns:
2145      A bool tensor.
2146  """
2147  x = math_ops.cast(x, dtypes_module.bool)
2148  return math_ops.reduce_all(x, axis, keepdims)
2149
2150
2151@keras_export('keras.backend.argmax')
2152def argmax(x, axis=-1):
2153  """Returns the index of the maximum value along an axis.
2154
2155  Arguments:
2156      x: Tensor or variable.
2157      axis: axis along which to perform the reduction.
2158
2159  Returns:
2160      A tensor.
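
  Example (a minimal sketch with illustrative values):

  >>> x = tf.keras.backend.constant([[0.1, 0.9], [0.8, 0.2]])
  >>> tf.keras.backend.eval(tf.keras.backend.argmax(x, axis=-1))
  array([1, 0])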
2161  """
2162  return math_ops.argmax(x, axis)
2163
2164
2165@keras_export('keras.backend.argmin')
2166def argmin(x, axis=-1):
2167  """Returns the index of the minimum value along an axis.
2168
2169  Arguments:
2170      x: Tensor or variable.
2171      axis: axis along which to perform the reduction.
2172
2173  Returns:
2174      A tensor.
2175  """
2176  return math_ops.argmin(x, axis)
2177
2178
2179@keras_export('keras.backend.square')
2180def square(x):
2181  """Element-wise square.
2182
2183  Arguments:
2184      x: Tensor or variable.
2185
2186  Returns:
2187      A tensor.
2188  """
2189  return math_ops.square(x)
2190
2191
2192@keras_export('keras.backend.abs')
2193def abs(x):
2194  """Element-wise absolute value.
2195
2196  Arguments:
2197      x: Tensor or variable.
2198
2199  Returns:
2200      A tensor.
2201  """
2202  return math_ops.abs(x)
2203
2204
2205@keras_export('keras.backend.sqrt')
2206def sqrt(x):
2207  """Element-wise square root.
2208
2209  Arguments:
2210      x: Tensor or variable.
2211
2212  Returns:
2213      A tensor.
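
  Note that `x` is clipped to be non-negative before the square root is
  taken, so negative inputs produce 0 rather than NaN.

  Example (a minimal sketch; the values shown are illustrative):

  >>> x = tf.keras.backend.constant([-4., 0., 4.])
  >>> tf.keras.backend.eval(tf.keras.backend.sqrt(x))
  array([0., 0., 2.], dtype=float32)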
2214  """
2215  zero = _constant_to_tensor(0., x.dtype.base_dtype)
2216  inf = _constant_to_tensor(np.inf, x.dtype.base_dtype)
2217  x = clip_ops.clip_by_value(x, zero, inf)
2218  return math_ops.sqrt(x)
2219
2220
2221@keras_export('keras.backend.exp')
2222def exp(x):
2223  """Element-wise exponential.
2224
2225  Arguments:
2226      x: Tensor or variable.
2227
2228  Returns:
2229      A tensor.
2230  """
2231  return math_ops.exp(x)
2232
2233
2234@keras_export('keras.backend.log')
2235def log(x):
2236  """Element-wise log.
2237
2238  Arguments:
2239      x: Tensor or variable.
2240
2241  Returns:
2242      A tensor.
2243  """
2244  return math_ops.log(x)
2245
2246
2247def logsumexp(x, axis=None, keepdims=False):
2248  """Computes log(sum(exp(elements across dimensions of a tensor))).
2249
2250  This function is more numerically stable than log(sum(exp(x))).
2251  It avoids overflows caused by taking the exp of large inputs and
2252  underflows caused by taking the log of small inputs.
2253
2254  Arguments:
2255      x: A tensor or variable.
2256      axis: An integer, the axis to reduce over.
2257      keepdims: A boolean, whether to keep the dimensions or not.
2258          If `keepdims` is `False`, the rank of the tensor is reduced
2259          by 1. If `keepdims` is `True`, the reduced dimension is
2260          retained with length 1.
2261
2262  Returns:
2263      The reduced tensor.
2264  """
2265  return math_ops.reduce_logsumexp(x, axis, keepdims)
2266
2267
2268@keras_export('keras.backend.round')
2269def round(x):
2270  """Element-wise rounding to the closest integer.
2271
2272  In case of tie, the rounding mode used is "half to even".
2273
2274  Arguments:
2275      x: Tensor or variable.
2276
2277  Returns:
2278      A tensor.
2279  """
2280  return math_ops.round(x)
2281
2282
2283@keras_export('keras.backend.sign')
2284def sign(x):
2285  """Element-wise sign.
2286
2287  Arguments:
2288      x: Tensor or variable.
2289
2290  Returns:
2291      A tensor.
2292  """
2293  return math_ops.sign(x)
2294
2295
2296@keras_export('keras.backend.pow')
2297def pow(x, a):
2298  """Element-wise exponentiation.
2299
2300  Arguments:
2301      x: Tensor or variable.
2302      a: Python integer.
2303
2304  Returns:
2305      A tensor.
2306  """
2307  return math_ops.pow(x, a)
2308
2309
2310@keras_export('keras.backend.clip')
2311def clip(x, min_value, max_value):
2312  """Element-wise value clipping.
2313
2314  Arguments:
2315      x: Tensor or variable.
2316      min_value: Python float, integer, or tensor.
2317      max_value: Python float, integer, or tensor.
2318
2319  Returns:
2320      A tensor.
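
  Example (a minimal sketch; the values shown are illustrative):

  >>> x = tf.keras.backend.constant([-1., 2., 5.])
  >>> clipped = tf.keras.backend.clip(x, min_value=0., max_value=4.)
  >>> tf.keras.backend.eval(clipped)
  array([0., 2., 4.], dtype=float32)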
2321  """
2322  if (isinstance(min_value, (int, float)) and
2323      isinstance(max_value, (int, float))):
2324    if max_value < min_value:
2325      max_value = min_value
2326  if min_value is None:
2327    min_value = -np.inf
2328  if max_value is None:
2329    max_value = np.inf
2330  return clip_ops.clip_by_value(x, min_value, max_value)
2331
2332
2333@keras_export('keras.backend.equal')
2334def equal(x, y):
2335  """Element-wise equality between two tensors.
2336
2337  Arguments:
2338      x: Tensor or variable.
2339      y: Tensor or variable.
2340
2341  Returns:
2342      A bool tensor.
2343  """
2344  return math_ops.equal(x, y)
2345
2346
2347@keras_export('keras.backend.not_equal')
2348def not_equal(x, y):
2349  """Element-wise inequality between two tensors.
2350
2351  Arguments:
2352      x: Tensor or variable.
2353      y: Tensor or variable.
2354
2355  Returns:
2356      A bool tensor.
2357  """
2358  return math_ops.not_equal(x, y)
2359
2360
2361@keras_export('keras.backend.greater')
2362def greater(x, y):
2363  """Element-wise truth value of (x > y).
2364
2365  Arguments:
2366      x: Tensor or variable.
2367      y: Tensor or variable.
2368
2369  Returns:
2370      A bool tensor.
2371  """
2372  return math_ops.greater(x, y)
2373
2374
2375@keras_export('keras.backend.greater_equal')
2376def greater_equal(x, y):
2377  """Element-wise truth value of (x >= y).
2378
2379  Arguments:
2380      x: Tensor or variable.
2381      y: Tensor or variable.
2382
2383  Returns:
2384      A bool tensor.
2385  """
2386  return math_ops.greater_equal(x, y)
2387
2388
2389@keras_export('keras.backend.less')
2390def less(x, y):
2391  """Element-wise truth value of (x < y).
2392
2393  Arguments:
2394      x: Tensor or variable.
2395      y: Tensor or variable.
2396
2397  Returns:
2398      A bool tensor.
2399  """
2400  return math_ops.less(x, y)
2401
2402
2403@keras_export('keras.backend.less_equal')
2404def less_equal(x, y):
2405  """Element-wise truth value of (x <= y).
2406
2407  Arguments:
2408      x: Tensor or variable.
2409      y: Tensor or variable.
2410
2411  Returns:
2412      A bool tensor.
2413  """
2414  return math_ops.less_equal(x, y)
2415
2416
2417@keras_export('keras.backend.maximum')
2418def maximum(x, y):
2419  """Element-wise maximum of two tensors.
2420
2421  Arguments:
2422      x: Tensor or variable.
2423      y: Tensor or variable.
2424
2425  Returns:
2426      A tensor with the element-wise maximum value(s) of `x` and `y`.
2427
2428  Examples:
2429
2430  >>> x = tf.Variable([[1, 2], [3, 4]])
2431  >>> y = tf.Variable([[2, 1], [0, -1]])
2432  >>> m = tf.keras.backend.maximum(x, y)
2433  >>> m
2434  <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
2435  array([[2, 2],
2436         [3, 4]], dtype=int32)>
2437  """
2438  return math_ops.maximum(x, y)
2439
2440
2441@keras_export('keras.backend.minimum')
2442def minimum(x, y):
2443  """Element-wise minimum of two tensors.
2444
2445  Arguments:
2446      x: Tensor or variable.
2447      y: Tensor or variable.
2448
2449  Returns:
2450      A tensor.
2451  """
2452  return math_ops.minimum(x, y)
2453
2454
2455@keras_export('keras.backend.sin')
2456def sin(x):
2457  """Computes sin of x element-wise.
2458
2459  Arguments:
2460      x: Tensor or variable.
2461
2462  Returns:
2463      A tensor.
2464  """
2465  return math_ops.sin(x)
2466
2467
2468@keras_export('keras.backend.cos')
2469def cos(x):
2470  """Computes cos of x element-wise.
2471
2472  Arguments:
2473      x: Tensor or variable.
2474
2475  Returns:
2476      A tensor.
2477  """
2478  return math_ops.cos(x)
2479
2480
2481def _regular_normalize_batch_in_training(x,
2482                                         gamma,
2483                                         beta,
2484                                         reduction_axes,
2485                                         epsilon=1e-3):
2486  """Non-fused version of `normalize_batch_in_training`.
2487
2488  Arguments:
2489      x: Input tensor or variable.
2490      gamma: Tensor by which to scale the input.
2491      beta: Tensor with which to center the input.
2492      reduction_axes: iterable of integers,
2493          axes over which to normalize.
2494      epsilon: Fuzz factor.
2495
2496  Returns:
2497      A tuple length of 3, `(normalized_tensor, mean, variance)`.
2498  """
2499  mean, var = nn.moments(x, reduction_axes, None, None, False)
2500  normed = nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
2501  return normed, mean, var
2502
2503
2504def _broadcast_normalize_batch_in_training(x,
2505                                           gamma,
2506                                           beta,
2507                                           reduction_axes,
2508                                           epsilon=1e-3):
2509  """Non-fused, broadcast version of `normalize_batch_in_training`.
2510
2511  Arguments:
2512      x: Input tensor or variable.
2513      gamma: Tensor by which to scale the input.
2514      beta: Tensor with which to center the input.
2515      reduction_axes: iterable of integers,
2516          axes over which to normalize.
2517      epsilon: Fuzz factor.
2518
2519  Returns:
2520      A tuple length of 3, `(normalized_tensor, mean, variance)`.
2521  """
2522  mean, var = nn.moments(x, reduction_axes, None, None, False)
2523  target_shape = []
2524  for axis in range(ndim(x)):
2525    if axis in reduction_axes:
2526      target_shape.append(1)
2527    else:
2528      target_shape.append(array_ops.shape(x)[axis])
2529  target_shape = array_ops.stack(target_shape)
2530
2531  broadcast_mean = array_ops.reshape(mean, target_shape)
2532  broadcast_var = array_ops.reshape(var, target_shape)
2533  if gamma is None:
2534    broadcast_gamma = None
2535  else:
2536    broadcast_gamma = array_ops.reshape(gamma, target_shape)
2537  if beta is None:
2538    broadcast_beta = None
2539  else:
2540    broadcast_beta = array_ops.reshape(beta, target_shape)
2541
2542  normed = nn.batch_normalization(x, broadcast_mean, broadcast_var,
2543                                  broadcast_beta, broadcast_gamma, epsilon)
2544  return normed, mean, var
2545
2546
2547def _fused_normalize_batch_in_training(x,
2548                                       gamma,
2549                                       beta,
2550                                       reduction_axes,
2551                                       epsilon=1e-3):
2552  """Fused version of `normalize_batch_in_training`.
2553
2554  Arguments:
2555      x: Input tensor or variable.
2556      gamma: Tensor by which to scale the input.
2557      beta: Tensor with which to center the input.
2558      reduction_axes: iterable of integers,
2559          axes over which to normalize.
2560      epsilon: Fuzz factor.
2561
2562  Returns:
2563      A tuple length of 3, `(normalized_tensor, mean, variance)`.
2564  """
2565  if list(reduction_axes) == [0, 1, 2]:
2566    normalization_axis = 3
2567    tf_data_format = 'NHWC'
2568  else:
2569    normalization_axis = 1
2570    tf_data_format = 'NCHW'
2571
2572  if gamma is None:
2573    gamma = constant_op.constant(
2574        1.0, dtype=x.dtype, shape=[x.shape[normalization_axis]])
2575  if beta is None:
2576    beta = constant_op.constant(
2577        0.0, dtype=x.dtype, shape=[x.shape[normalization_axis]])
2578
2579  return nn.fused_batch_norm(
2580      x, gamma, beta, epsilon=epsilon, data_format=tf_data_format)
2581
2582
2583@keras_export('keras.backend.normalize_batch_in_training')
2584def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3):
2585  """Computes mean and std for batch then apply batch_normalization on batch.
2586
2587  Arguments:
2588      x: Input tensor or variable.
2589      gamma: Tensor by which to scale the input.
2590      beta: Tensor with which to center the input.
2591      reduction_axes: iterable of integers,
2592          axes over which to normalize.
2593      epsilon: Fuzz factor.
2594
2595  Returns:
2596      A tuple length of 3, `(normalized_tensor, mean, variance)`.
2597  """
2598  if ndim(x) == 4 and list(reduction_axes) in [[0, 1, 2], [0, 2, 3]]:
2599    if not _has_nchw_support() and list(reduction_axes) == [0, 2, 3]:
2600      return _broadcast_normalize_batch_in_training(
2601          x, gamma, beta, reduction_axes, epsilon=epsilon)
2602    return _fused_normalize_batch_in_training(
2603        x, gamma, beta, reduction_axes, epsilon=epsilon)
2604  else:
2605    if sorted(reduction_axes) == list(range(ndim(x)))[:-1]:
2606      return _regular_normalize_batch_in_training(
2607          x, gamma, beta, reduction_axes, epsilon=epsilon)
2608    else:
2609      return _broadcast_normalize_batch_in_training(
2610          x, gamma, beta, reduction_axes, epsilon=epsilon)
2611
2612
2613@keras_export('keras.backend.batch_normalization')
2614def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
2615  """Applies batch normalization on x given mean, var, beta and gamma.
2616
2617  I.e. returns:
2618  `output = (x - mean) / sqrt(var + epsilon) * gamma + beta`
2619
2620  Arguments:
2621      x: Input tensor or variable.
2622      mean: Mean of batch.
2623      var: Variance of batch.
2624      beta: Tensor with which to center the input.
2625      gamma: Tensor by which to scale the input.
2626      axis: Integer, the axis that should be normalized.
2627          (typically the features axis).
2628      epsilon: Fuzz factor.
2629
2630  Returns:
2631      A tensor.
2632  """
2633  if ndim(x) == 4:
2634    # The CPU implementation of `fused_batch_norm` only supports NHWC
2635    if axis == 1 or axis == -3:
2636      tf_data_format = 'NCHW'
2637    elif axis == 3 or axis == -1:
2638      tf_data_format = 'NHWC'
2639    else:
2640      tf_data_format = None
2641
2642    if (tf_data_format == 'NHWC' or
2643        tf_data_format == 'NCHW' and _has_nchw_support()):
2644      # The mean / var / beta / gamma tensors may be broadcasted
2645      # so they may have extra axes of size 1, which should be squeezed.
2646      if ndim(mean) > 1:
2647        mean = array_ops.reshape(mean, [-1])
2648      if ndim(var) > 1:
2649        var = array_ops.reshape(var, [-1])
2650      if beta is None:
2651        beta = zeros_like(mean)
2652      elif ndim(beta) > 1:
2653        beta = array_ops.reshape(beta, [-1])
2654      if gamma is None:
2655        gamma = ones_like(mean)
2656      elif ndim(gamma) > 1:
2657        gamma = array_ops.reshape(gamma, [-1])
2658      y, _, _ = nn.fused_batch_norm(
2659          x,
2660          gamma,
2661          beta,
2662          epsilon=epsilon,
2663          mean=mean,
2664          variance=var,
2665          data_format=tf_data_format,
2666          is_training=False
2667      )
2668      return y
2669  return nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
2670
2671
2672# SHAPE OPERATIONS
2673
2674
2675@keras_export('keras.backend.concatenate')
2676def concatenate(tensors, axis=-1):
2677  """Concatenates a list of tensors alongside the specified axis.
2678
2679  Arguments:
2680      tensors: list of tensors to concatenate.
2681      axis: concatenation axis.
2682
2683  Returns:
2684      A tensor.
2685
2686  Example:
2687
2688      >>> a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
2689      >>> b = tf.constant([[10, 20, 30], [40, 50, 60], [70, 80, 90]])
2690      >>> tf.keras.backend.concatenate((a, b), axis=-1)
2691      <tf.Tensor: shape=(3, 6), dtype=int32, numpy=
2692      array([[ 1,  2,  3, 10, 20, 30],
2693             [ 4,  5,  6, 40, 50, 60],
2694             [ 7,  8,  9, 70, 80, 90]], dtype=int32)>
2695
2696  """
2697  if axis < 0:
2698    rank = ndim(tensors[0])
2699    if rank:
2700      axis %= rank
2701    else:
2702      axis = 0
2703
2704  if py_all(is_sparse(x) for x in tensors):
2705    return sparse_ops.sparse_concat(axis, tensors)
2706  elif py_all(isinstance(x, ragged_tensor.RaggedTensor) for x in tensors):
2707    return ragged_concat_ops.concat(tensors, axis)
2708  else:
2709    return array_ops.concat([to_dense(x) for x in tensors], axis)
2710
2711
2712@keras_export('keras.backend.reshape')
2713def reshape(x, shape):
2714  """Reshapes a tensor to the specified shape.
2715
2716  Arguments:
2717      x: Tensor or variable.
2718      shape: Target shape tuple.
2719
2720  Returns:
2721      A tensor.
2722
2723  Example:
2724
2725    >>> a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
2726    >>> a
2727    <tf.Tensor: shape=(4, 3), dtype=int32, numpy=
2728    array([[ 1,  2,  3],
2729           [ 4,  5,  6],
2730           [ 7,  8,  9],
2731           [10, 11, 12]], dtype=int32)>
2732    >>> tf.keras.backend.reshape(a, shape=(2, 6))
2733    <tf.Tensor: shape=(2, 6), dtype=int32, numpy=
2734    array([[ 1,  2,  3,  4,  5,  6],
2735           [ 7,  8,  9, 10, 11, 12]], dtype=int32)>
2736
2737  """
2738  return array_ops.reshape(x, shape)
2739
2740
2741@keras_export('keras.backend.permute_dimensions')
2742def permute_dimensions(x, pattern):
2743  """Permutes axes in a tensor.
2744
2745  Arguments:
2746      x: Tensor or variable.
2747      pattern: A tuple of
2748          dimension indices, e.g. `(0, 2, 1)`.
2749
2750  Returns:
2751      A tensor.
2752
2753  Example:
2754
2755    >>> a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
2756    >>> a
2757    <tf.Tensor: shape=(4, 3), dtype=int32, numpy=
2758    array([[ 1,  2,  3],
2759           [ 4,  5,  6],
2760           [ 7,  8,  9],
2761           [10, 11, 12]], dtype=int32)>
2762    >>> tf.keras.backend.permute_dimensions(a, pattern=(1, 0))
2763    <tf.Tensor: shape=(3, 4), dtype=int32, numpy=
2764    array([[ 1,  4,  7, 10],
2765           [ 2,  5,  8, 11],
2766           [ 3,  6,  9, 12]], dtype=int32)>
2767
2768  """
2769  return array_ops.transpose(x, perm=pattern)
2770
2771
2772@keras_export('keras.backend.resize_images')
2773def resize_images(x, height_factor, width_factor, data_format,
2774                  interpolation='nearest'):
2775  """Resizes the images contained in a 4D tensor.
2776
2777  Arguments:
2778      x: Tensor or variable to resize.
2779      height_factor: Positive integer.
2780      width_factor: Positive integer.
2781      data_format: One of `"channels_first"`, `"channels_last"`.
2782      interpolation: A string, one of `nearest` or `bilinear`.
2783
2784  Returns:
2785      A tensor.
2786
2787  Raises:
2788      ValueError: in case of incorrect value for
2789        `data_format` or `interpolation`.
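
  Example (a minimal shape-only sketch, assuming `channels_last` data):

  >>> x = tf.keras.backend.ones(shape=(1, 2, 2, 3))
  >>> y = tf.keras.backend.resize_images(x, 2, 2, data_format='channels_last')
  >>> tf.keras.backend.int_shape(y)
  (1, 4, 4, 3)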
2790  """
2791  if data_format == 'channels_first':
2792    rows, cols = 2, 3
2793  elif data_format == 'channels_last':
2794    rows, cols = 1, 2
2795  else:
2796    raise ValueError('Invalid `data_format` argument: %s' % (data_format,))
2797
2798  original_shape = int_shape(x)
2799  new_shape = array_ops.shape(x)[rows:cols + 1]
2800  new_shape *= constant_op.constant(
2801      np.array([height_factor, width_factor], dtype='int32'))
2802
2803  if data_format == 'channels_first':
2804    x = permute_dimensions(x, [0, 2, 3, 1])
2805  if interpolation == 'nearest':
2806    x = image_ops.resize_images_v2(
2807        x, new_shape, method=image_ops.ResizeMethod.NEAREST_NEIGHBOR)
2808  elif interpolation == 'bilinear':
2809    x = image_ops.resize_images_v2(x, new_shape,
2810                                   method=image_ops.ResizeMethod.BILINEAR)
2811  else:
2812    raise ValueError('interpolation should be one '
2813                     'of "nearest" or "bilinear".')
2814  if data_format == 'channels_first':
2815    x = permute_dimensions(x, [0, 3, 1, 2])
2816
2817  if original_shape[rows] is None:
2818    new_height = None
2819  else:
2820    new_height = original_shape[rows] * height_factor
2821
2822  if original_shape[cols] is None:
2823    new_width = None
2824  else:
2825    new_width = original_shape[cols] * width_factor
2826
2827  if data_format == 'channels_first':
2828    output_shape = (None, None, new_height, new_width)
2829  else:
2830    output_shape = (None, new_height, new_width, None)
2831  x.set_shape(output_shape)
2832  return x
2833
2834
2835@keras_export('keras.backend.resize_volumes')
2836def resize_volumes(x, depth_factor, height_factor, width_factor, data_format):
2837  """Resizes the volume contained in a 5D tensor.
2838
2839  Arguments:
2840      x: Tensor or variable to resize.
2841      depth_factor: Positive integer.
2842      height_factor: Positive integer.
2843      width_factor: Positive integer.
2844      data_format: One of `"channels_first"`, `"channels_last"`.
2845
2846  Returns:
2847      A tensor.
2848
2849  Raises:
2850      ValueError: if `data_format` is neither
2851          `channels_last` nor `channels_first`.
2852  """
2853  if data_format == 'channels_first':
2854    output = repeat_elements(x, depth_factor, axis=2)
2855    output = repeat_elements(output, height_factor, axis=3)
2856    output = repeat_elements(output, width_factor, axis=4)
2857    return output
2858  elif data_format == 'channels_last':
2859    output = repeat_elements(x, depth_factor, axis=1)
2860    output = repeat_elements(output, height_factor, axis=2)
2861    output = repeat_elements(output, width_factor, axis=3)
2862    return output
2863  else:
2864    raise ValueError('Invalid data_format: ' + str(data_format))
2865
2866
2867@keras_export('keras.backend.repeat_elements')
2868def repeat_elements(x, rep, axis):
2869  """Repeats the elements of a tensor along an axis, like `np.repeat`.
2870
2871  If `x` has shape `(s1, s2, s3)` and `axis` is `1`, the output
2872  will have shape `(s1, s2 * rep, s3)`.
2873
2874  Arguments:
2875      x: Tensor or variable.
2876      rep: Python integer, number of times to repeat.
2877      axis: Axis along which to repeat.
2878
2879  Returns:
2880      A tensor.
2881
2882  Example:
2883
2884      >>> b = tf.constant([1, 2, 3])
2885      >>> tf.keras.backend.repeat_elements(b, rep=2, axis=0)
2886      <tf.Tensor: shape=(6,), dtype=int32,
2887          numpy=array([1, 1, 2, 2, 3, 3], dtype=int32)>
2888
2889  """
2890  x_shape = x.shape.as_list()
2891  # For static axis
2892  if x_shape[axis] is not None:
2893    # slices along the repeat axis
2894    splits = array_ops.split(value=x,
2895                             num_or_size_splits=x_shape[axis],
2896                             axis=axis)
2897    # repeat each slice the given number of reps
2898    x_rep = [s for s in splits for _ in range(rep)]
2899    return concatenate(x_rep, axis)
2900
2901  # Here we use tf.tile to mimic behavior of np.repeat so that
2902  # we can handle dynamic shapes (that include None).
2903  # To do that, we need an auxiliary axis to repeat elements along
2904  # it and then merge them along the desired axis.
2905
2906  # Repeating
2907  auxiliary_axis = axis + 1
2908  x_shape = array_ops.shape(x)
2909  x_rep = array_ops.expand_dims(x, axis=auxiliary_axis)
2910  reps = np.ones(len(x.shape) + 1)
2911  reps[auxiliary_axis] = rep
2912  x_rep = array_ops.tile(x_rep, reps)
2913
2914  # Merging
2915  reps = np.delete(reps, auxiliary_axis)
2916  reps[axis] = rep
2917  reps = array_ops.constant(reps, dtype='int32')
2918  x_shape *= reps
2919  x_rep = array_ops.reshape(x_rep, x_shape)
2920
2921  # Fix shape representation
2922  x_shape = x.shape.as_list()
2923  x_rep.set_shape(x_shape)
2924  x_rep._keras_shape = tuple(x_shape)
2925  return x_rep
2926
2927
2928@keras_export('keras.backend.repeat')
2929def repeat(x, n):
2930  """Repeats a 2D tensor.
2931
2932  If `x` has shape `(samples, dim)` and `n` is `2`,
2933  the output will have shape `(samples, 2, dim)`.
2934
2935  Arguments:
2936      x: Tensor or variable.
2937      n: Python integer, number of times to repeat.
2938
2939  Returns:
2940      A tensor.
2941
2942  Example:
2943
2944      >>> b = tf.constant([[1, 2], [3, 4]])
2945      >>> b
2946      <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
2947      array([[1, 2],
2948             [3, 4]], dtype=int32)>
2949      >>> tf.keras.backend.repeat(b, n=2)
2950      <tf.Tensor: shape=(2, 2, 2), dtype=int32, numpy=
2951      array([[[1, 2],
2952              [1, 2]],
2953             [[3, 4],
2954              [3, 4]]], dtype=int32)>
2955
2956  """
2957  assert ndim(x) == 2
2958  x = array_ops.expand_dims(x, 1)
2959  pattern = array_ops.stack([1, n, 1])
2960  return array_ops.tile(x, pattern)
2961
2962
2963@keras_export('keras.backend.arange')
2964def arange(start, stop=None, step=1, dtype='int32'):
2965  """Creates a 1D tensor containing a sequence of integers.
2966
2967  The function arguments use the same convention as
2968  Theano's arange: if only one argument is provided,
2969  it is in fact the "stop" argument and "start" is 0.
2970
2971  The default type of the returned tensor is `'int32'` to
2972  match TensorFlow's default.
2973
2974  Arguments:
2975      start: Start value.
2976      stop: Stop value.
2977      step: Difference between two successive values.
2978      dtype: Integer dtype to use.
2979
2980  Returns:
2981      An integer tensor.
2982
2983  Example:
2984
2985      >>> tf.keras.backend.arange(start=0, stop=10, step=1.5)
2986      <tf.Tensor: shape=(7,), dtype=float32,
2987          numpy=array([0. , 1.5, 3. , 4.5, 6. , 7.5, 9. ], dtype=float32)>
2988
2991  """
2992  # Match the behavior of numpy and Theano by returning an empty sequence.
2993  if stop is None and start < 0:
2994    start = 0
2995  result = math_ops.range(start, limit=stop, delta=step, name='arange')
2996  if dtype != 'int32':
2997    result = cast(result, dtype)
2998  return result
2999
3000
3001@keras_export('keras.backend.tile')
3002def tile(x, n):
3003  """Creates a tensor by tiling `x` by `n`.
3004
3005  Arguments:
3006      x: A tensor or variable
3007      n: A list of integers. The length must be the same as the number of
3008          dimensions in `x`.
3009
3010  Returns:
3011      A tiled tensor.
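
  Example (a minimal sketch; the values shown are illustrative):

  >>> x = tf.keras.backend.constant([[1., 2.]])
  >>> tf.keras.backend.eval(tf.keras.backend.tile(x, n=[2, 3]))
  array([[1., 2., 1., 2., 1., 2.],
         [1., 2., 1., 2., 1., 2.]], dtype=float32)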
3012  """
3013  if isinstance(n, int):
3014    n = [n]
3015  return array_ops.tile(x, n)
3016
3017
3018@keras_export('keras.backend.flatten')
3019def flatten(x):
3020  """Flatten a tensor.
3021
3022  Arguments:
3023      x: A tensor or variable.
3024
3025  Returns:
3026      A tensor, reshaped into 1-D.
3027
3028  Example:
3029
3030      >>> b = tf.constant([[1, 2], [3, 4]])
3031      >>> b
3032      <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
3033      array([[1, 2],
3034             [3, 4]], dtype=int32)>
3035      >>> tf.keras.backend.flatten(b)
3036      <tf.Tensor: shape=(4,), dtype=int32,
3037          numpy=array([1, 2, 3, 4], dtype=int32)>
3038
3039  """
3040  return array_ops.reshape(x, [-1])
3041
3042
3043@keras_export('keras.backend.batch_flatten')
3044def batch_flatten(x):
3045  """Turn a nD tensor into a 2D tensor with same 0th dimension.
3046
3047  In other words, it flattens each data samples of a batch.
3048
3049  Arguments:
3050      x: A tensor or variable.
3051
3052  Returns:
3053      A tensor.
3054
3055  Examples:
3056    Flattening a 4D tensor to a 2D tensor by collapsing the trailing dimensions.
3057
3058  >>> x_batch = tf.keras.backend.ones(shape=(2, 3, 4, 5))
3059  >>> x_batch_flatten = batch_flatten(x_batch)
3060  >>> tf.keras.backend.int_shape(x_batch_flatten)
3061  (2, 60)
3062
3063  """
3064  x = array_ops.reshape(x, array_ops.stack([-1, prod(shape(x)[1:])]))
3065  return x
3066
3067
3068@keras_export('keras.backend.expand_dims')
3069def expand_dims(x, axis=-1):
3070  """Adds a 1-sized dimension at index "axis".
3071
3072  Arguments:
3073      x: A tensor or variable.
3074      axis: Position where to add a new axis.
3075
3076  Returns:
3077      A tensor with expanded dimensions.
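
  Example (a minimal shape-only sketch):

  >>> x = tf.keras.backend.ones(shape=(3, 4))
  >>> tf.keras.backend.int_shape(tf.keras.backend.expand_dims(x, axis=1))
  (3, 1, 4)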
3078  """
3079  return array_ops.expand_dims(x, axis)
3080
3081
3082@keras_export('keras.backend.squeeze')
3083def squeeze(x, axis):
3084  """Removes a 1-dimension from the tensor at index "axis".
3085
3086  Arguments:
3087      x: A tensor or variable.
3088      axis: Axis to drop.
3089
3090  Returns:
3091      A tensor with the same data as `x` but reduced dimensions.
3092  """
3093  return array_ops.squeeze(x, [axis])
3094
3095
3096@keras_export('keras.backend.temporal_padding')
3097def temporal_padding(x, padding=(1, 1)):
3098  """Pads the middle dimension of a 3D tensor.
3099
3100  Arguments:
3101      x: Tensor or variable.
3102      padding: Tuple of 2 integers, how many zeros to
3103          add at the start and end of dim 1.
3104
3105  Returns:
3106      A padded 3D tensor.
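
  Example (a minimal shape-only sketch):

  >>> x = tf.keras.backend.ones(shape=(2, 3, 4))
  >>> y = tf.keras.backend.temporal_padding(x, padding=(1, 2))
  >>> tf.keras.backend.int_shape(y)
  (2, 6, 4)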
3107  """
3108  assert len(padding) == 2
3109  pattern = [[0, 0], [padding[0], padding[1]], [0, 0]]
3110  return array_ops.pad(x, pattern)
3111
3112
3113@keras_export('keras.backend.spatial_2d_padding')
3114def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
3115  """Pads the 2nd and 3rd dimensions of a 4D tensor.
3116
3117  Arguments:
3118      x: Tensor or variable.
3119      padding: Tuple of 2 tuples, padding pattern.
3120      data_format: One of `channels_last` or `channels_first`.
3121
3122  Returns:
3123      A padded 4D tensor.
3124
3125  Raises:
3126      ValueError: if `data_format` is neither
3127          `channels_last` nor `channels_first`.
3128  """
3129  assert len(padding) == 2
3130  assert len(padding[0]) == 2
3131  assert len(padding[1]) == 2
3132  if data_format is None:
3133    data_format = image_data_format()
3134  if data_format not in {'channels_first', 'channels_last'}:
3135    raise ValueError('Unknown data_format: ' + str(data_format))
3136
3137  if data_format == 'channels_first':
3138    pattern = [[0, 0], [0, 0], list(padding[0]), list(padding[1])]
3139  else:
3140    pattern = [[0, 0], list(padding[0]), list(padding[1]), [0, 0]]
3141  return array_ops.pad(x, pattern)
3142
3143
3144@keras_export('keras.backend.spatial_3d_padding')
3145def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None):
3146  """Pads 5D tensor with zeros along the depth, height, width dimensions.
3147
3148  Pads these dimensions with "padding[0]", "padding[1]" and "padding[2]"
3149  zeros on the left and right, respectively.
3150
3151  For 'channels_last' data_format,
3152  the 2nd, 3rd and 4th dimensions will be padded.
3153  For 'channels_first' data_format,
3154  the 3rd, 4th and 5th dimensions will be padded.
3155
3156  Arguments:
3157      x: Tensor or variable.
3158      padding: Tuple of 3 tuples, padding pattern.
3159      data_format: One of `channels_last` or `channels_first`.
3160
3161  Returns:
3162      A padded 5D tensor.
3163
3164  Raises:
3165      ValueError: if `data_format` is neither
3166          `channels_last` nor `channels_first`.
3167
3168  """
3169  assert len(padding) == 3
3170  assert len(padding[0]) == 2
3171  assert len(padding[1]) == 2
3172  assert len(padding[2]) == 2
3173  if data_format is None:
3174    data_format = image_data_format()
3175  if data_format not in {'channels_first', 'channels_last'}:
3176    raise ValueError('Unknown data_format: ' + str(data_format))
3177
3178  if data_format == 'channels_first':
3179    pattern = [[0, 0], [0, 0], [padding[0][0], padding[0][1]],
3180               [padding[1][0], padding[1][1]], [padding[2][0], padding[2][1]]]
3181  else:
3182    pattern = [[0, 0], [padding[0][0], padding[0][1]],
3183               [padding[1][0], padding[1][1]], [padding[2][0],
3184                                                padding[2][1]], [0, 0]]
3185  return array_ops.pad(x, pattern)
3186
3187
3188@keras_export('keras.backend.stack')
3189def stack(x, axis=0):
3190  """Stacks a list of rank `R` tensors into a rank `R+1` tensor.
3191
3192  Arguments:
3193      x: List of tensors.
3194      axis: Axis along which to perform stacking.
3195
3196  Returns:
3197      A tensor.
3198
3199  Example:
3200
3201      >>> a = tf.constant([[1, 2],[3, 4]])
3202      >>> b = tf.constant([[10, 20],[30, 40]])
3203      >>> tf.keras.backend.stack((a, b))
3204      <tf.Tensor: shape=(2, 2, 2), dtype=int32, numpy=
3205      array([[[ 1,  2],
3206              [ 3,  4]],
3207             [[10, 20],
3208              [30, 40]]], dtype=int32)>
3209
3210  """
3211  return array_ops.stack(x, axis=axis)
3212
3213
3214@keras_export('keras.backend.one_hot')
3215def one_hot(indices, num_classes):
3216  """Computes the one-hot representation of an integer tensor.
3217
3218  Arguments:
3219      indices: nD integer tensor of shape
3220          `(batch_size, dim1, dim2, ... dim(n-1))`
3221      num_classes: Integer, number of classes to consider.
3222
3223  Returns:
3224      The (n + 1)D one-hot representation of the input,
3225      with shape `(batch_size, dim1, dim2, ... dim(n-1), num_classes)`.
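
  Example (a minimal sketch; the values shown are illustrative):

  >>> tf.keras.backend.eval(tf.keras.backend.one_hot([0, 2], num_classes=3))
  array([[1., 0., 0.],
         [0., 0., 1.]], dtype=float32)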
3229  """
3230  return array_ops.one_hot(indices, depth=num_classes, axis=-1)
3231
3232
3233@keras_export('keras.backend.reverse')
3234def reverse(x, axes):
3235  """Reverse a tensor along the specified axes.
3236
3237  Arguments:
3238      x: Tensor to reverse.
3239      axes: Integer or iterable of integers.
3240          Axes to reverse.
3241
3242  Returns:
3243      A tensor.
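
  Example (a minimal sketch with illustrative values):

  >>> x = tf.keras.backend.constant([[1., 2., 3.], [4., 5., 6.]])
  >>> tf.keras.backend.eval(tf.keras.backend.reverse(x, axes=1))
  array([[3., 2., 1.],
         [6., 5., 4.]], dtype=float32)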
3244  """
3245  if isinstance(axes, int):
3246    axes = [axes]
3247  return array_ops.reverse(x, axes)
3248
3249
3250# VALUE MANIPULATION
3251_VALUE_SET_CODE_STRING = """
3252  >>> K = tf.keras.backend  # Common keras convention
3253  >>> v = K.variable(1.)
3254
3255  >>> # reassign
3256  >>> K.set_value(v, 2.)
3257  >>> print(K.get_value(v))
3258  2.0
3259
3260  >>> # increment
3261  >>> K.set_value(v, K.get_value(v) + 1)
3262  >>> print(K.get_value(v))
3263  3.0
3264
3265  Variable semantics in TensorFlow 2 are eager execution friendly. The above
3266  code is roughly equivalent to:
3267
3268  >>> v = tf.Variable(1.)
3269
3270  >>> _ = v.assign(2.)
3271  >>> print(v.numpy())
3272  2.0
3273
3274  >>> _ = v.assign_add(1.)
3275  >>> print(v.numpy())
3276  3.0"""[3:]  # Prune first newline and indent to match the docstring template.
3277
3278
3279@keras_export('keras.backend.get_value')
3280def get_value(x):
3281  """Returns the value of a variable.
3282
3283  `backend.get_value` is the complement of `backend.set_value`, and provides
3284  a generic interface for reading from variables while abstracting away the
3285  differences between TensorFlow 1.x and 2.x semantics.
3286
3287  {snippet}
3288
3289  Arguments:
3290      x: input variable.
3291
3292  Returns:
3293      A Numpy array.
3294  """
3295  if not tensor_util.is_tensor(x):
3296    return x
3297  if context.executing_eagerly() or isinstance(x, ops.EagerTensor):
3298    return x.numpy()
3299  if not getattr(x, '_in_graph_mode', True):
3300    # This is a variable which was created in an eager context, but is being
3301    # evaluated from a Graph.
3302    with context.eager_mode():
3303      return x.numpy()
3304
3305  if ops.executing_eagerly_outside_functions():
3306    # This method of evaluating works inside the Keras FuncGraph.
3307    return function([], x)(x)
3308
3309  with x.graph.as_default():
3310    return x.eval(session=get_session((x,)))
3311
3312
3313@keras_export('keras.backend.batch_get_value')
3314def batch_get_value(tensors):
3315  """Returns the value of more than one tensor variable.
3316
3317  Arguments:
3318      tensors: list of ops to run.
3319
3320  Returns:
3321      A list of Numpy arrays.
3322
3323  Raises:
3324      RuntimeError: If this method is called inside defun.
3325  """
3326  if context.executing_eagerly():
3327    return [x.numpy() for x in tensors]
3328  elif ops.inside_function():  # pylint: disable=protected-access
3329    raise RuntimeError('Cannot get value inside Tensorflow graph function.')
3330  if tensors:
3331    return get_session(tensors).run(tensors)
3332  else:
3333    return []
3334
3335
3336@keras_export('keras.backend.set_value')
3337def set_value(x, value):
3338  """Sets the value of a variable, from a Numpy array.
3339
3340  `backend.set_value` is the complement of `backend.get_value`, and provides
3341  a generic interface for assigning to variables while abstracting away the
3342  differences between TensorFlow 1.x and 2.x semantics.
3343
3344  {snippet}
3345
3346  Arguments:
3347      x: Variable to set to a new value.
3348      value: Value to set the tensor to, as a Numpy array
3349          (of the same shape).
3350  """
3351  value = np.asarray(value, dtype=dtype(x))
3352  if ops.executing_eagerly_outside_functions():
3353    x.assign(value)
3354  else:
3355    with get_graph().as_default():
3356      tf_dtype = dtypes_module.as_dtype(x.dtype.name.split('_')[0])
3357      if hasattr(x, '_assign_placeholder'):
3358        assign_placeholder = x._assign_placeholder
3359        assign_op = x._assign_op
3360      else:
3361        # In order to support assigning weights to resizable variables in
3362        # Keras, we make a placeholder with the correct number of dimensions
3363        # but with None in each dimension. This way, we can assign weights
3364        # of any size (as long as they have the correct dimensionality).
3365        placeholder_shape = tensor_shape.TensorShape([None] * value.ndim)
3366        assign_placeholder = array_ops.placeholder(
3367            tf_dtype, shape=placeholder_shape)
3368        assign_op = x.assign(assign_placeholder)
3369        x._assign_placeholder = assign_placeholder
3370        x._assign_op = assign_op
3371      get_session().run(assign_op, feed_dict={assign_placeholder: value})
3372
3373
3374@keras_export('keras.backend.batch_set_value')
3375def batch_set_value(tuples):
3376  """Sets the values of many tensor variables at once.
3377
3378  Arguments:
3379      tuples: a list of tuples `(tensor, value)`.
3380          `value` should be a Numpy array.
3381  """
3382  if ops.executing_eagerly_outside_functions():
3383    for x, value in tuples:
3384      x.assign(np.asarray(value, dtype=dtype(x)))
3385  else:
3386    with get_graph().as_default():
3387      if tuples:
3388        assign_ops = []
3389        feed_dict = {}
3390        for x, value in tuples:
3391          value = np.asarray(value, dtype=dtype(x))
3392          tf_dtype = dtypes_module.as_dtype(x.dtype.name.split('_')[0])
3393          if hasattr(x, '_assign_placeholder'):
3394            assign_placeholder = x._assign_placeholder
3395            assign_op = x._assign_op
3396          else:
3397            # In order to support assigning weights to resizable variables in
3398            # Keras, we make a placeholder with the correct number of dimensions
3399            # but with None in each dimension. This way, we can assign weights
3400            # of any size (as long as they have the correct dimensionality).
3401            placeholder_shape = tensor_shape.TensorShape([None] * value.ndim)
3402            assign_placeholder = array_ops.placeholder(
3403                tf_dtype, shape=placeholder_shape)
3404            assign_op = x.assign(assign_placeholder)
3405            x._assign_placeholder = assign_placeholder
3406            x._assign_op = assign_op
3407          assign_ops.append(assign_op)
3408          feed_dict[assign_placeholder] = value
3409        get_session().run(assign_ops, feed_dict=feed_dict)
3410
3411
3412get_value.__doc__ = get_value.__doc__.format(snippet=_VALUE_SET_CODE_STRING)
3413set_value.__doc__ = set_value.__doc__.format(snippet=_VALUE_SET_CODE_STRING)
3414
3415
3416@keras_export('keras.backend.print_tensor')
3417def print_tensor(x, message=''):
3418  """Prints `message` and the tensor value when evaluated.
3419
3420  Note that `print_tensor` returns a new tensor identical to `x`
3421  which should be used in the following code. Otherwise the
3422  print operation is not taken into account during evaluation.
3423
3424  Example:
3425
3426  >>> x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
3427  >>> tf.keras.backend.print_tensor(x)
3428  <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
3429    array([[1., 2.],
3430           [3., 4.]], dtype=float32)>
3431
3432  Arguments:
3433      x: Tensor to print.
3434      message: Message to print jointly with the tensor.
3435
3436  Returns:
3437      The same tensor `x`, unchanged.
3438  """
3439  if isinstance(x, ops.Tensor) and hasattr(x, 'graph'):
3440    with get_graph().as_default():
3441      op = logging_ops.print_v2(message, x, output_stream=sys.stdout)
3442      with ops.control_dependencies([op]):
3443        return array_ops.identity(x)
3444  else:
3445    logging_ops.print_v2(message, x, output_stream=sys.stdout)
3446    return x
3447
3448# GRAPH MANIPULATION
3449
3450
3451class GraphExecutionFunction(object):
3452  """Runs a computation graph.
3453
3454  It's possible to pass arguments to `tf.Session.run()` via `session_kwargs`.
3455  In particular, additional operations can be passed via the `fetches` argument
3456  and additional tensor substitutions via the `feed_dict` argument. Note that
3457  the given substitutions are merged with substitutions from `inputs`. Even
3458  though `feed_dict` is passed once in the constructor (called in
3459  `model.compile()`), the values in the dictionary can be modified later to
3460  provide additional substitutions besides the Keras inputs.
3461
3462  Arguments:
3463      inputs: Feed placeholders to the computation graph.
3464      outputs: Output tensors to fetch.
3465      updates: Additional update ops to be run at function call.
3466      name: A name to help users identify what this function does.
3467      session_kwargs: Arguments to `tf.Session.run()`:
3468                      `fetches`, `feed_dict`, `options`, `run_metadata`.
3469  """
3470
3471  def __init__(self, inputs, outputs, updates=None, name=None,
3472               **session_kwargs):
3473    updates = updates or []
3474    if not isinstance(updates, (list, tuple)):
3475      raise TypeError('`updates` in a Keras backend function '
3476                      'should be a list or tuple.')
3477
3478    self._inputs_structure = inputs
3479    self.inputs = nest.flatten(inputs, expand_composites=True)
3480    self._outputs_structure = outputs
3481    self.outputs = cast_variables_to_tensor(
3482        nest.flatten(outputs, expand_composites=True))
3483    # TODO(b/127668432): Consider using autograph to generate these
3484    # dependencies in call.
3485    # Index 0 = total loss or model output for `predict`.
3486    with ops.control_dependencies([self.outputs[0]]):
3487      updates_ops = []
3488      for update in updates:
3489        if isinstance(update, tuple):
3490          p, new_p = update
3491          updates_ops.append(state_ops.assign(p, new_p))
3492        else:
3493          # assumed already an op
3494          updates_ops.append(update)
3495      self.updates_op = control_flow_ops.group(*updates_ops)
3496    self.name = name
3497    # additional tensor substitutions
3498    self.feed_dict = session_kwargs.pop('feed_dict', None)
3499    # additional operations
3500    self.fetches = session_kwargs.pop('fetches', [])
3501    if not isinstance(self.fetches, list):
3502      self.fetches = [self.fetches]
3503    self.run_options = session_kwargs.pop('options', None)
3504    self.run_metadata = session_kwargs.pop('run_metadata', None)
3505    # The main use case of `fetches` being passed to a model is the ability
3506    # to run custom updates
3507    # This requires us to wrap fetches in `identity` ops.
3508    self.fetches = [array_ops.identity(x) for x in self.fetches]
3509    self.session_kwargs = session_kwargs
3510    # This mapping keeps track of the function that should receive the
3511    # output from a fetch in `fetches`: { fetch: function(fetch_output) }
3512    # A Callback can use this to register a function with access to the
3513    # output values for a fetch it added.
3514    self.fetch_callbacks = {}
3515
3516    if session_kwargs:
3517      raise ValueError('Some keys in session_kwargs are not supported at this '
3518                       'time: %s' % (session_kwargs.keys(),))
3519
3520    self._callable_fn = None
3521    self._feed_arrays = None
3522    self._feed_symbols = None
3523    self._symbol_vals = None
3524    self._fetches = None
3525    self._session = None
3526
3527  def _make_callable(self, feed_arrays, feed_symbols, symbol_vals, session):
3528    """Generates a callable that runs the graph.
3529
3530    Arguments:
3531      feed_arrays: List of input tensors to be fed Numpy arrays at runtime.
3532      feed_symbols: List of input tensors to be fed symbolic tensors at runtime.
3533      symbol_vals: List of symbolic tensors to be fed to `feed_symbols`.
3534      session: Session to use to generate the callable.
3535
3536    Returns:
3537      Function that runs the graph according to the above options.
3538    """
3539    # Prepare callable options.
3540    callable_opts = config_pb2.CallableOptions()
3541    # Handle external-data feed.
3542    for x in feed_arrays:
3543      callable_opts.feed.append(x.name)
3544    if self.feed_dict:
3545      for key in sorted(self.feed_dict.keys()):
3546        callable_opts.feed.append(key.name)
3547    # Handle symbolic feed.
3548    for x, y in zip(feed_symbols, symbol_vals):
3549      connection = callable_opts.tensor_connection.add()
3550      if x.dtype != y.dtype:
3551        y = math_ops.cast(y, dtype=x.dtype)
3552      from_tensor = ops._as_graph_element(y)
3553      if from_tensor is None:
3554        from_tensor = y
3555      connection.from_tensor = from_tensor.name  # Data tensor
3556      connection.to_tensor = x.name  # Placeholder
3557    # Handle fetches.
3558    for x in self.outputs + self.fetches:
3559      callable_opts.fetch.append(x.name)
3560    # Handle updates.
3561    callable_opts.target.append(self.updates_op.name)
3562    # Handle run_options.
3563    if self.run_options:
3564      callable_opts.run_options.CopyFrom(self.run_options)
3565    # Create callable.
3566    callable_fn = session._make_callable_from_options(callable_opts)
3567    # Cache parameters corresponding to the generated callable, so that
3568    # we can detect future mismatches and refresh the callable.
3569    self._callable_fn = callable_fn
3570    self._feed_arrays = feed_arrays
3571    self._feed_symbols = feed_symbols
3572    self._symbol_vals = symbol_vals
3573    self._fetches = list(self.fetches)
3574    self._session = session
3575
3576  def _call_fetch_callbacks(self, fetches_output):
3577    for fetch, output in zip(self._fetches, fetches_output):
3578      if fetch in self.fetch_callbacks:
3579        self.fetch_callbacks[fetch](output)
3580
3581  def _eval_if_composite(self, tensor):
3582    """Helper method which evaluates any CompositeTensors passed to it."""
3583    # We need to evaluate any composite tensor objects that have been
3584    # reconstructed in 'pack_sequence_as', since otherwise they'll be output as
3585    # actual CompositeTensor objects instead of the value(s) contained in the
3586    # CompositeTensors. E.g., if output_structure contains a SparseTensor, then
3587    # this ensures that we return its value as a SparseTensorValue rather than
3588    # a SparseTensor.
3589    if isinstance(tensor, composite_tensor.CompositeTensor):
3590      return self._session.run(tensor)
3591    else:
3592      return tensor
3593
3594  def __call__(self, inputs):
3595    inputs = nest.flatten(inputs, expand_composites=True)
3596
3597    session = get_session(inputs)
3598    feed_arrays = []
3599    array_vals = []
3600    feed_symbols = []
3601    symbol_vals = []
3602    for tensor, value in zip(self.inputs, inputs):
3603      if value is None:
3604        continue
3605
3606      if tensor_util.is_tensor(value):
3607        # Case: feeding symbolic tensor.
3608        feed_symbols.append(tensor)
3609        symbol_vals.append(value)
3610      else:
3611        # Case: feeding Numpy array.
3612        feed_arrays.append(tensor)
3613        # We need to do array conversion and type casting at this level, since
3614        # `callable_fn` only supports exact matches.
3615        tensor_type = dtypes_module.as_dtype(tensor.dtype)
3616        array_vals.append(np.asarray(value,
3617                                     dtype=tensor_type.as_numpy_dtype))
3618
3619    if self.feed_dict:
3620      for key in sorted(self.feed_dict.keys()):
3621        array_vals.append(
3622            np.asarray(self.feed_dict[key], dtype=key.dtype.base_dtype.name))
3623
3624    # Refresh callable if anything has changed.
3625    if (self._callable_fn is None or feed_arrays != self._feed_arrays or
3626        symbol_vals != self._symbol_vals or
3627        feed_symbols != self._feed_symbols or self.fetches != self._fetches or
3628        session != self._session):
3629      self._make_callable(feed_arrays, feed_symbols, symbol_vals, session)
3630
3631    fetched = self._callable_fn(*array_vals,
3632                                run_metadata=self.run_metadata)
3633    self._call_fetch_callbacks(fetched[-len(self._fetches):])
3634    output_structure = nest.pack_sequence_as(
3635        self._outputs_structure,
3636        fetched[:len(self.outputs)],
3637        expand_composites=True)
3638    # We need to evaluate any composite tensor objects that have been
3639    # reconstructed in 'pack_sequence_as', since otherwise they'll be output as
3640    # actual CompositeTensor objects instead of the value(s) contained in the
3641    # CompositeTensors. E.g., if output_structure contains a SparseTensor, then
3642    # this ensures that we return its value as a SparseTensorValue rather than
3643    # a SparseTensor.
3644    return nest.map_structure(self._eval_if_composite, output_structure)
3645
3646
3647class EagerExecutionFunction(object):
3648  """Helper class for constructing a TF graph function from the Keras graph.
3649
3650  Arguments:
3651    inputs: Feed placeholders to the computation graph.
3652    outputs: Output tensors to fetch.
3653    updates: Additional update ops to be run at function call.
3654    name: A name to help users identify what this function does.
3655    session_kwargs: Unsupported.
3656  """
3657
3658  def __init__(self, inputs, outputs, updates=None, name=None):
3659    self.name = name
3660    self._inputs_structure = inputs
3661    inputs = nest.flatten(inputs, expand_composites=True)
3662    self._outputs_structure = outputs
3663    outputs = nest.flatten(outputs, expand_composites=True)
3664
3665    updates = updates or []
3666    if not isinstance(updates, (list, tuple)):
3667      raise TypeError('`updates` in a Keras backend function '
3668                      'should be a list or tuple.')
3669
3670    if updates and not outputs:
3671      # Edge case; never happens in practice
3672      raise ValueError('Cannot create a Keras backend function with updates'
3673                       ' but no outputs during eager execution.')
3674    graphs = {
3675        i.graph
3676        for i in nest.flatten([inputs, outputs, updates])
3677        if hasattr(i, 'graph')
3678    }
3679    if len(graphs) > 1:
3680      raise ValueError('Cannot create an execution function which is comprised '
3681                       'of elements from multiple graphs.')
3682
3683    source_graph = graphs.pop()
3684    global_graph = get_graph()
3685
3686    updates_ops = []
3687    legacy_update_ops = []
3688    for update in updates:
3689      # For legacy reasons it is allowed to pass an update as a tuple
3690      # `(variable, new_value)` (this maps to an assign op). Otherwise it
3691      # is assumed to already be an op -- we cannot control its execution
3692      # order.
3693      if isinstance(update, tuple):
3694        legacy_update_ops.append(update)
3695      else:
3696        if hasattr(update, 'op'):
3697          update = update.op
3698        if update is not None:
3699          # `update.op` may have been None in certain cases.
3700          updates_ops.append(update)
3701
3702    self._freezable_vars_to_feed = []
3703    self._freezable_vars_values = []
3704    freezable_vars_from_keras_graph = object_identity.ObjectIdentitySet(
3705        _FREEZABLE_VARS.get(global_graph, {}))
3706    with _scratch_graph() as exec_graph:
3707      global_graph = get_graph()
3708      if source_graph not in (exec_graph, global_graph):
3709        raise ValueError('Unknown graph. Aborting.')
3710
3711      if source_graph is global_graph and exec_graph is not global_graph:
3712        init_tensors = (
3713            outputs + updates_ops + [p for [p, _] in legacy_update_ops] +
3714            [p_new for [_, p_new] in legacy_update_ops
3715             if isinstance(p_new, ops.Tensor)])
3716        lifted_map = lift_to_graph.lift_to_graph(
3717            tensors=init_tensors,
3718            graph=exec_graph,
3719            sources=inputs,
3720            add_sources=True,
3721            handle_captures=True,
3722            base_graph=source_graph)
3723
3724        inputs = [lifted_map[i] for i in inputs]
3725        outputs = [lifted_map[i] for i in outputs]
3726        updates_ops = [lifted_map[i] for i in updates_ops]
3727        legacy_update_ops = [(lifted_map[p], lifted_map.get(p_new, p_new))
3728                             for p, p_new in legacy_update_ops]
3729
3730        # Keep track of the value to feed to any "freezable variables"
3731        # created in this graph.
3732        for old_op, new_op in lifted_map.items():
3733          if old_op in freezable_vars_from_keras_graph:
3734            frozen_var = old_op
3735            if frozen_var._initial_value != frozen_var._current_value:
3736              # We only feed a frozen_variable if its value has changed;
3737              # otherwise it can rely on the default value of the
3738              # underlying placeholder_with_default.
3739              self._freezable_vars_to_feed.append(new_op)
3740              self._freezable_vars_values.append(frozen_var._current_value)
3741
3742    # Consolidate updates
3743    with exec_graph.as_default():
3744      outputs = cast_variables_to_tensor(outputs)
3745      with ops.control_dependencies(outputs):
3746        for p, p_new in legacy_update_ops:
3747          updates_ops.append(state_ops.assign(p, p_new))
3748
3749      self.inputs, self.outputs = inputs, outputs
3750      self._input_references = self.inputs + self._freezable_vars_to_feed
3751      with ops.control_dependencies(updates_ops):
3752        self.outputs[0] = array_ops.identity(self.outputs[0])
3753
3754      exec_graph.inputs = self._input_references + exec_graph.internal_captures
3755      exec_graph.outputs = self.outputs
3756      graph_fn = eager_function.ConcreteFunction(exec_graph)
3757
3758    graph_fn._num_positional_args = len(self._input_references)
3759    graph_fn._arg_keywords = []
3760    self._graph_fn = graph_fn
3761
3762    # Handle placeholders with default
3763    # (treated as required placeholder by graph functions)
3764    self._placeholder_default_values = {}
3765    with exec_graph.as_default():
3766      for x in self.inputs:
3767        if x.op.type == 'PlaceholderWithDefault':
3768          self._placeholder_default_values[ops.tensor_id(
3769              x)] = tensor_util.constant_value(x.op.inputs[0])
3770
3771  def __call__(self, inputs):
3772    input_values = nest.flatten(inputs, expand_composites=True)
3773
3774    if self._freezable_vars_values:
3775      input_values = input_values + self._freezable_vars_values
3776    converted_inputs = []
3777    for tensor, value in zip(self._input_references, input_values):
3778      if value is None:
3779        # Assume `value` is a placeholder with default
3780        value = self._placeholder_default_values.get(
3781            ops.tensor_id(tensor), None)
3782        if value is None:
3783          raise ValueError(
3784              'You must feed a value for placeholder %s' % (tensor,))
3785      if not isinstance(value, ops.Tensor):
3786        value = ops.convert_to_tensor(value, dtype=tensor.dtype)
3787      if value.dtype != tensor.dtype:
3788        # Temporary workaround due to `convert_to_tensor` not casting floats.
3789        # See b/119637405
3790        value = math_ops.cast(value, tensor.dtype)
3791      converted_inputs.append(value)
3792    outputs = self._graph_fn(*converted_inputs)
3793
3794    # EagerTensor.numpy() will often make a copy to ensure memory safety.
3795    # However in this case `outputs` is not directly returned, so it is always
3796    # safe to reuse the underlying buffer without checking. In such a case the
3797    # private numpy conversion method is preferred to guarantee performance.
3798    return nest.pack_sequence_as(
3799        self._outputs_structure,
3800        [x._numpy() for x in outputs],  # pylint: disable=protected-access
3801        expand_composites=True)
3802
3803
3804@keras_export('keras.backend.function')
3805def function(inputs, outputs, updates=None, name=None, **kwargs):
3806  """Instantiates a Keras function.
3807
3808  Arguments:
3809      inputs: List of placeholder tensors.
3810      outputs: List of output tensors.
3811      updates: List of update ops.
3812      name: String, name of function.
3813      **kwargs: Passed to `tf.Session.run`.
3814
3815  Returns:
3816      The Keras function object. Calling it returns output values as Numpy arrays.
3817
3818  Raises:
3819      ValueError: if invalid kwargs are passed in or if in eager execution.
3820  """
3821  if ops.executing_eagerly_outside_functions():
3822    if kwargs:
3823      raise ValueError('Session keyword arguments are not supported during '
3824                       'eager execution. You passed: %s' % (kwargs,))
3825    return EagerExecutionFunction(inputs, outputs, updates=updates, name=name)
3826
3827  if kwargs:
3828    for key in kwargs:
3829      if (key not in tf_inspect.getfullargspec(session_module.Session.run)[0]
3830          and key not in ['inputs', 'outputs', 'updates', 'name']):
3831        msg = ('Invalid argument "%s" passed to K.function with TensorFlow '
3832               'backend') % key
3833        raise ValueError(msg)
3834  return GraphExecutionFunction(inputs, outputs, updates=updates, **kwargs)
3835
3836
3837@keras_export('keras.backend.gradients')
3838def gradients(loss, variables):
3839  """Returns the gradients of `loss` w.r.t. `variables`.
3840
3841  Arguments:
3842      loss: Scalar tensor to minimize.
3843      variables: List of variables.
3844
3845  Returns:
3846      A list of gradient tensors.
3847  """
3848  return gradients_module.gradients(
3849      loss, variables, colocate_gradients_with_ops=True)
3850
3851
3852@keras_export('keras.backend.stop_gradient')
3853def stop_gradient(variables):
3854  """Returns `variables` but with zero gradient w.r.t. every other variable.
3855
3856  Arguments:
3857      variables: Tensor or list of tensors to consider constant with respect
3858        to any other variable.
3859
3860
3861  Returns:
3862      A single tensor or a list of tensors (depending on the passed argument)
3863      that has no gradient with respect to any other variable.
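
  Example (an illustrative sketch using `tf.GradientTape`; the values are
  arbitrary):

  >>> x = tf.constant(3.0)
  >>> with tf.GradientTape() as tape:
  ...   tape.watch(x)
  ...   y = tf.keras.backend.stop_gradient(x) * x
  >>> tape.gradient(y, x)
  <tf.Tensor: shape=(), dtype=float32, numpy=3.0>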
3864  """
3865  if isinstance(variables, (list, tuple)):
3866    return [array_ops.stop_gradient(v) for v in variables]
3867  return array_ops.stop_gradient(variables)
3868
3869
3870# CONTROL FLOW
3871
3872
3873@keras_export('keras.backend.rnn')
3874def rnn(step_function,
3875        inputs,
3876        initial_states,
3877        go_backwards=False,
3878        mask=None,
3879        constants=None,
3880        unroll=False,
3881        input_length=None,
3882        time_major=False,
3883        zero_output_for_mask=False):
3884  """Iterates over the time dimension of a tensor.
3885
3886  Arguments:
3887      step_function: RNN step function.
3888          Args:
3889              input: Tensor with shape `(samples, ...)` (no time dimension),
3890                  representing input for the batch of samples at a certain
3891                  time step.
3892              states: List of tensors.
3893          Returns:
3894              output: Tensor with shape `(samples, output_dim)`
3895                  (no time dimension).
3896              new_states: List of tensors, same length and shapes
3897                  as `states`. The first state in the list must be the
3898                  output tensor at the previous timestep.
3899      inputs: Tensor of temporal data of shape `(samples, time, ...)`
3900          (at least 3D), or nested tensors, each of which has shape
3901          `(samples, time, ...)`.
3902      initial_states: Tensor with shape `(samples, state_size)`
3903          (no time dimension), containing the initial values for the states used
3904          in the step function. In the case that state_size is in a nested
3905          shape, the shape of initial_states will also follow the nested
3906          structure.
3907      go_backwards: Boolean. If True, do the iteration over the time
3908          dimension in reverse order and return the reversed sequence.
3909      mask: Binary tensor with shape `(samples, time, 1)`,
3910          with a zero for every element that is masked.
3911      constants: List of constant values passed at each step.
3912      unroll: Whether to unroll the RNN or to use a symbolic `while_loop`.
3913      input_length: An integer or a 1-D Tensor, depending on whether
3914          the time dimension is fixed-length or not. In case of variable length
3915          input, it is used for masking in case there's no mask specified.
3916      time_major: Boolean. If true, the inputs and outputs will be in shape
3917          `(timesteps, batch, ...)`, whereas in the False case, it will be
3918          `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
3919          efficient because it avoids transposes at the beginning and end of the
3920          RNN calculation. However, most TensorFlow data is batch-major, so by
3921          default this function accepts input and emits output in batch-major
3922          form.
3923      zero_output_for_mask: Boolean. If True, the output for masked timestep
3924          will be zeros, whereas in the False case, output from previous
3925          timestep is returned.
3926
3927  Returns:
3928      A tuple, `(last_output, outputs, new_states)`.
3929          last_output: the latest output of the rnn, of shape `(samples, ...)`
3930          outputs: tensor with shape `(samples, time, ...)` where each
3931              entry `outputs[s, t]` is the output of the step function
3932              at time `t` for sample `s`.
3933          new_states: list of tensors, latest states returned by
3934              the step function, of shape `(samples, ...)`.
3935
3936  Raises:
3937      ValueError: if input dimension is less than 3.
3938      ValueError: if `unroll` is `True` but input timestep is not a fixed
3939          number.
3940      ValueError: if `mask` is provided (not `None`) but states is not provided
3941          (`len(states)` == 0).
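
  Example (a minimal sketch of a custom step function; the shapes, names and
  the all-ones kernel are arbitrary):

  >>> inputs = tf.random.normal((4, 3, 2))         # (samples, time, features)
  >>> initial_states = [tf.zeros((4, 5))]          # (samples, state_size)
  >>> def step_fn(inp, states):
  ...   out = tf.matmul(inp, tf.ones((2, 5))) + states[0]
  ...   return out, [out]
  >>> last_output, outputs, new_states = tf.keras.backend.rnn(
  ...     step_fn, inputs, initial_states)
  >>> outputs.shape
  TensorShape([4, 3, 5])
  >>> last_output.shape
  TensorShape([4, 5])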
3942  """
3943
3944  def swap_batch_timestep(input_t):
3945    # Swap the batch and timestep dim for the incoming tensor.
3946    axes = list(range(len(input_t.shape)))
3947    axes[0], axes[1] = 1, 0
3948    return array_ops.transpose(input_t, axes)
3949
3950  if not time_major:
3951    inputs = nest.map_structure(swap_batch_timestep, inputs)
3952
3953  flatted_inputs = nest.flatten(inputs)
3954  time_steps = flatted_inputs[0].shape[0]
3955  batch = flatted_inputs[0].shape[1]
3956  time_steps_t = array_ops.shape(flatted_inputs[0])[0]
3957
3958  for input_ in flatted_inputs:
3959    input_.shape.with_rank_at_least(3)
3960
3961  if mask is not None:
3962    if mask.dtype != dtypes_module.bool:
3963      mask = math_ops.cast(mask, dtypes_module.bool)
3964    if len(mask.shape) == 2:
3965      mask = expand_dims(mask)
3966    if not time_major:
3967      mask = swap_batch_timestep(mask)
3968
3969  if constants is None:
3970    constants = []
3971
3972  # tf.where needs its condition tensor to be the same shape as its two
3973  # result tensors, but in our case the condition (mask) tensor is
3974  # (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
3975  # So we need to broadcast the mask to match the shape of inputs.
3976  # That's what the tile call does, it just repeats the mask along its
3977  # second dimension n times.
3978  def _expand_mask(mask_t, input_t, fixed_dim=1):
3979    if nest.is_sequence(mask_t):
3980      raise ValueError('mask_t is expected to be tensor, but got %s' % mask_t)
3981    if nest.is_sequence(input_t):
3982      raise ValueError('input_t is expected to be tensor, but got %s' % input_t)
3983    rank_diff = len(input_t.shape) - len(mask_t.shape)
3984    for _ in range(rank_diff):
3985      mask_t = array_ops.expand_dims(mask_t, -1)
3986    multiples = [1] * fixed_dim + input_t.shape.as_list()[fixed_dim:]
3987    return array_ops.tile(mask_t, multiples)
3988
3989  if unroll:
3990    if not time_steps:
3991      raise ValueError('Unrolling requires a fixed number of timesteps.')
3992    states = tuple(initial_states)
3993    successive_states = []
3994    successive_outputs = []
3995
3996    # Process the input tensors. The input tensors need to be split on the
3997    # time_step dim and reversed if go_backwards is True. In the case of nested
3998    # input, the input is flattened and then transformed individually.
3999    # The result of this will be a tuple of lists; each item in the tuple is a
4000    # list of tensors with shape (batch, feature).
4001    def _process_single_input_t(input_t):
4002      input_t = array_ops.unstack(input_t)  # unstack for time_step dim
4003      if go_backwards:
4004        input_t.reverse()
4005      return input_t
4006
4007    if nest.is_sequence(inputs):
4008      processed_input = nest.map_structure(_process_single_input_t, inputs)
4009    else:
4010      processed_input = (_process_single_input_t(inputs),)
4011
4012    def _get_input_tensor(time):
4013      inp = [t_[time] for t_ in processed_input]
4014      return nest.pack_sequence_as(inputs, inp)
4015
4016    if mask is not None:
4017      mask_list = array_ops.unstack(mask)
4018      if go_backwards:
4019        mask_list.reverse()
4020
4021      for i in range(time_steps):
4022        inp = _get_input_tensor(i)
4023        mask_t = mask_list[i]
4024        output, new_states = step_function(inp,
4025                                           tuple(states) + tuple(constants))
4026        tiled_mask_t = _expand_mask(mask_t, output)
4027
4028        if not successive_outputs:
4029          prev_output = zeros_like(output)
4030        else:
4031          prev_output = successive_outputs[-1]
4032
4033        output = array_ops.where_v2(tiled_mask_t, output, prev_output)
4034
4035        flat_states = nest.flatten(states)
4036        flat_new_states = nest.flatten(new_states)
4037        tiled_mask_t = tuple(_expand_mask(mask_t, s) for s in flat_states)
4038        flat_final_states = tuple(
4039            array_ops.where_v2(m, s, ps)
4040            for m, s, ps in zip(tiled_mask_t, flat_new_states, flat_states))
4041        states = nest.pack_sequence_as(states, flat_final_states)
4042
4043        successive_outputs.append(output)
4044        successive_states.append(states)
4045      last_output = successive_outputs[-1]
4046      new_states = successive_states[-1]
4047      outputs = array_ops.stack(successive_outputs)
4048
4049      if zero_output_for_mask:
4050        last_output = array_ops.where_v2(
4051            _expand_mask(mask_list[-1], last_output), last_output,
4052            zeros_like(last_output))
4053        outputs = array_ops.where_v2(
4054            _expand_mask(mask, outputs, fixed_dim=2), outputs,
4055            zeros_like(outputs))
4056
4057    else:  # mask is None
4058      for i in range(time_steps):
4059        inp = _get_input_tensor(i)
4060        output, states = step_function(inp, tuple(states) + tuple(constants))
4061        successive_outputs.append(output)
4062        successive_states.append(states)
4063      last_output = successive_outputs[-1]
4064      new_states = successive_states[-1]
4065      outputs = array_ops.stack(successive_outputs)
4066
4067  else:  # Unroll == False
4068    states = tuple(initial_states)
4069
4070    # Create the input TensorArrays. If the input is a nested structure of
4071    # tensors, it is flattened first, and one TensorArray is created per
4072    # flattened tensor.
4073    input_ta = tuple(
4074        tensor_array_ops.TensorArray(
4075            dtype=inp.dtype,
4076            size=time_steps_t,
4077            tensor_array_name='input_ta_%s' % i)
4078        for i, inp in enumerate(flatted_inputs))
4079    input_ta = tuple(
4080        ta.unstack(input_) if not go_backwards else ta
4081        .unstack(reverse(input_, 0))
4082        for ta, input_ in zip(input_ta, flatted_inputs))
4083
4084    # Get the time(0) input and compute the output for that; the output will be
4085    # used to determine the dtype of the output tensor array. Don't read from
4086    # input_ta because TensorArray's clear_after_read defaults to True.
4087    input_time_zero = nest.pack_sequence_as(inputs,
4088                                            [inp[0] for inp in flatted_inputs])
4089    # output_time_zero is used to determine the cell output shape and its dtype.
4090    # The value itself is discarded.
4091    output_time_zero, _ = step_function(
4092        input_time_zero, tuple(initial_states) + tuple(constants))
4093    output_ta = tuple(
4094        tensor_array_ops.TensorArray(
4095            dtype=out.dtype,
4096            size=time_steps_t,
4097            element_shape=out.shape,
4098            tensor_array_name='output_ta_%s' % i)
4099        for i, out in enumerate(nest.flatten(output_time_zero)))
4100
4101    time = constant_op.constant(0, dtype='int32', name='time')
4102
4103    # We only specify 'maximum_iterations' when building for XLA, since
4104    # specifying it causes slowdowns on GPU in TF.
4105    if (not context.executing_eagerly() and
4106        control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph())):
4107      max_iterations = math_ops.reduce_max(input_length)
4108    else:
4109      max_iterations = None
4110
4111    while_loop_kwargs = {
4112        'cond': lambda time, *_: time < time_steps_t,
4113        'maximum_iterations': max_iterations,
4114        'parallel_iterations': 32,
4115        'swap_memory': True,
4116    }
4117    if mask is not None:
4118      if go_backwards:
4119        mask = reverse(mask, 0)
4120
4121      mask_ta = tensor_array_ops.TensorArray(
4122          dtype=dtypes_module.bool,
4123          size=time_steps_t,
4124          tensor_array_name='mask_ta')
4125      mask_ta = mask_ta.unstack(mask)
4126
4127      def masking_fn(time):
4128        return mask_ta.read(time)
4129
4130      def compute_masked_output(mask_t, flat_out, flat_mask):
4131        tiled_mask_t = tuple(
4132            _expand_mask(mask_t, o, fixed_dim=len(mask_t.shape))
4133            for o in flat_out)
4134        return tuple(
4135            array_ops.where_v2(m, o, fm)
4136            for m, o, fm in zip(tiled_mask_t, flat_out, flat_mask))
4137    elif isinstance(input_length, ops.Tensor):
4138      if go_backwards:
4139        max_len = math_ops.reduce_max(input_length, axis=0)
4140        rev_input_length = math_ops.subtract(max_len - 1, input_length)
4141
4142        def masking_fn(time):
4143          return math_ops.less(rev_input_length, time)
4144      else:
4145
4146        def masking_fn(time):
4147          return math_ops.greater(input_length, time)
4148
4149      def compute_masked_output(mask_t, flat_out, flat_mask):
4150        return tuple(
4151            array_ops.where(mask_t, o, zo)
4152            for (o, zo) in zip(flat_out, flat_mask))
4153    else:
4154      masking_fn = None
4155
4156    if masking_fn is not None:
4157      # The mask for the output at time T is based on the output at T - 1. In
4158      # the case T = 0, a zero-filled tensor will be used.
4159      flat_zero_output = tuple(array_ops.zeros_like(o)
4160                               for o in nest.flatten(output_time_zero))
4161      def _step(time, output_ta_t, prev_output, *states):
4162        """RNN step function.
4163
4164        Arguments:
4165            time: Current timestep value.
4166            output_ta_t: TensorArray.
4167            prev_output: tuple of outputs from time - 1.
4168            *states: List of states.
4169
4170        Returns:
4171            Tuple: `(time + 1, output_ta_t, output) + tuple(new_states)`
4172        """
4173        current_input = tuple(ta.read(time) for ta in input_ta)
4174        # maybe set shape.
4175        current_input = nest.pack_sequence_as(inputs, current_input)
4176        mask_t = masking_fn(time)
4177        output, new_states = step_function(current_input,
4178                                           tuple(states) + tuple(constants))
4179        # mask output
4180        flat_output = nest.flatten(output)
4181        flat_mask_output = (flat_zero_output if zero_output_for_mask
4182                            else nest.flatten(prev_output))
4183        flat_new_output = compute_masked_output(mask_t, flat_output,
4184                                                flat_mask_output)
4185
4186        # mask states
4187        flat_state = nest.flatten(states)
4188        flat_new_state = nest.flatten(new_states)
4189        for state, new_state in zip(flat_state, flat_new_state):
4190          if isinstance(new_state, ops.Tensor):
4191            new_state.set_shape(state.shape)
4192        flat_final_state = compute_masked_output(mask_t, flat_new_state,
4193                                                 flat_state)
4194        new_states = nest.pack_sequence_as(new_states, flat_final_state)
4195
4196        output_ta_t = tuple(
4197            ta.write(time, out)
4198            for ta, out in zip(output_ta_t, flat_new_output))
4199        return (time + 1, output_ta_t,
4200                tuple(flat_new_output)) + tuple(new_states)
4201
4202      final_outputs = control_flow_ops.while_loop(
4203          body=_step,
4204          loop_vars=(time, output_ta, flat_zero_output) + states,
4205          **while_loop_kwargs)
4206      # Skip final_outputs[2] which is the output for final timestep.
4207      new_states = final_outputs[3:]
4208    else:
4209      def _step(time, output_ta_t, *states):
4210        """RNN step function.
4211
4212        Arguments:
4213            time: Current timestep value.
4214            output_ta_t: TensorArray.
4215            *states: List of states.
4216
4217        Returns:
4218            Tuple: `(time + 1, output_ta_t) + tuple(new_states)`
4219        """
4220        current_input = tuple(ta.read(time) for ta in input_ta)
4221        current_input = nest.pack_sequence_as(inputs, current_input)
4222        output, new_states = step_function(current_input,
4223                                           tuple(states) + tuple(constants))
4224        flat_state = nest.flatten(states)
4225        flat_new_state = nest.flatten(new_states)
4226        for state, new_state in zip(flat_state, flat_new_state):
4227          if isinstance(new_state, ops.Tensor):
4228            new_state.set_shape(state.shape)
4229
4230        flat_output = nest.flatten(output)
4231        output_ta_t = tuple(
4232            ta.write(time, out) for ta, out in zip(output_ta_t, flat_output))
4233        new_states = nest.pack_sequence_as(initial_states, flat_new_state)
4234        return (time + 1, output_ta_t) + tuple(new_states)
4235
4236      final_outputs = control_flow_ops.while_loop(
4237          body=_step,
4238          loop_vars=(time, output_ta) + states,
4239          **while_loop_kwargs)
4240      new_states = final_outputs[2:]
4241
4242    output_ta = final_outputs[1]
4243
4244    outputs = tuple(o.stack() for o in output_ta)
4245    last_output = tuple(o[-1] for o in outputs)
4246
4247    outputs = nest.pack_sequence_as(output_time_zero, outputs)
4248    last_output = nest.pack_sequence_as(output_time_zero, last_output)
4249
4250  # static shape inference
4251  def set_shape(output_):
4252    if isinstance(output_, ops.Tensor):
4253      shape = output_.shape.as_list()
4254      shape[0] = time_steps
4255      shape[1] = batch
4256      output_.set_shape(shape)
4257    return output_
4258
4259  outputs = nest.map_structure(set_shape, outputs)
4260
4261  if not time_major:
4262    outputs = nest.map_structure(swap_batch_timestep, outputs)
4263
4264  return last_output, outputs, new_states
4265
4266
4267@keras_export('keras.backend.switch')
4268def switch(condition, then_expression, else_expression):
4269  """Switches between two operations depending on a scalar value.
4270
4271  Note that both `then_expression` and `else_expression`
4272  should be symbolic tensors of the *same shape*.
4273
4274  Arguments:
4275      condition: tensor (`int` or `bool`).
4276      then_expression: either a tensor, or a callable that returns a tensor.
4277      else_expression: either a tensor, or a callable that returns a tensor.
4278
4279  Returns:
4280      The selected tensor.
4281
4282  Raises:
4283      ValueError: If rank of `condition` is greater than rank of expressions.
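
  Example (a small illustration with scalar tensors; the values are arbitrary):

  >>> cond = tf.constant(True)
  >>> a = tf.constant(1.0)
  >>> b = tf.constant(2.0)
  >>> tf.keras.backend.switch(cond, a, b)
  <tf.Tensor: shape=(), dtype=float32, numpy=1.0>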
4284  """
4285  if condition.dtype != dtypes_module.bool:
4286    condition = math_ops.cast(condition, 'bool')
4287  cond_ndim = ndim(condition)
4288  if not cond_ndim:
4289    if not callable(then_expression):
4290
4291      def then_expression_fn():
4292        return then_expression
4293    else:
4294      then_expression_fn = then_expression
4295    if not callable(else_expression):
4296
4297      def else_expression_fn():
4298        return else_expression
4299    else:
4300      else_expression_fn = else_expression
4301    x = control_flow_ops.cond(condition, then_expression_fn, else_expression_fn)
4302  else:
4303    # tf.where needs its condition tensor
4304    # to be the same shape as its two
4305    # result tensors
4306    if callable(then_expression):
4307      then_expression = then_expression()
4308    if callable(else_expression):
4309      else_expression = else_expression()
4310    expr_ndim = ndim(then_expression)
4311    if cond_ndim > expr_ndim:
4312      raise ValueError('Rank of `condition` should be less than or'
4313                       ' equal to rank of `then_expression` and '
4314                       '`else_expression`. ndim(condition)=' + str(cond_ndim) +
4315                       ', ndim(then_expression)'
4316                       '=' + str(expr_ndim))
4317    if cond_ndim > 1:
4318      ndim_diff = expr_ndim - cond_ndim
4319      cond_shape = array_ops.concat(
4320          [array_ops.shape(condition), [1] * ndim_diff], axis=0)
4321      condition = array_ops.reshape(condition, cond_shape)
4322      expr_shape = array_ops.shape(then_expression)
4323      shape_diff = expr_shape - cond_shape
4324      tile_shape = array_ops.where_v2(shape_diff > 0, expr_shape,
4325                                      array_ops.ones_like(expr_shape))
4326      condition = array_ops.tile(condition, tile_shape)
4327    x = array_ops.where_v2(condition, then_expression, else_expression)
4328  return x
4329
4330
4331@keras_export('keras.backend.in_train_phase')
4332def in_train_phase(x, alt, training=None):
4333  """Selects `x` in train phase, and `alt` otherwise.
4334
4335  Note that `alt` should have the *same shape* as `x`.
4336
4337  Arguments:
4338      x: What to return in train phase
4339          (tensor or callable that returns a tensor).
4340      alt: What to return otherwise
4341          (tensor or callable that returns a tensor).
4342      training: Optional scalar tensor
4343          (or Python boolean, or Python integer)
4344          specifying the learning phase.
4345
4346  Returns:
4347      Either `x` or `alt` based on the `training` flag.
4348      The `training` flag defaults to `K.learning_phase()`.
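
  Example (an illustrative call with an explicit `training` flag; the values
  are arbitrary):

  >>> tf.keras.backend.in_train_phase(tf.constant(1.), tf.constant(0.),
  ...                                 training=True)
  <tf.Tensor: shape=(), dtype=float32, numpy=1.0>
  >>> tf.keras.backend.in_train_phase(tf.constant(1.), tf.constant(0.),
  ...                                 training=False)
  <tf.Tensor: shape=(), dtype=float32, numpy=0.0>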
4349  """
4350  if training is None:
4351    training = learning_phase()
4352
4353  # TODO(b/138862903): Handle the case when training is tensor.
4354  if not tensor_util.is_tensor(training):
4355    if training == 1 or training is True:
4356      if callable(x):
4357        return x()
4358      else:
4359        return x
4360
4361    elif training == 0 or training is False:
4362      if callable(alt):
4363        return alt()
4364      else:
4365        return alt
4366
4367  # else: assume learning phase is a placeholder tensor.
4368  x = switch(training, x, alt)
4369  return x
4370
4371
4372@keras_export('keras.backend.in_test_phase')
4373def in_test_phase(x, alt, training=None):
4374  """Selects `x` in test phase, and `alt` otherwise.
4375
4376  Note that `alt` should have the *same shape* as `x`.
4377
4378  Arguments:
4379      x: What to return in test phase
4380          (tensor or callable that returns a tensor).
4381      alt: What to return otherwise
4382          (tensor or callable that returns a tensor).
4383      training: Optional scalar tensor
4384          (or Python boolean, or Python integer)
4385          specifying the learning phase.
4386
4387  Returns:
4388      Either `x` or `alt` based on `K.learning_phase`.
4389  """
4390  return in_train_phase(alt, x, training=training)
4391
4392
4393# NN OPERATIONS
4394
4395
4396@keras_export('keras.backend.relu')
4397def relu(x, alpha=0., max_value=None, threshold=0):
4398  """Rectified linear unit.
4399
4400  With default values, it returns element-wise `max(x, 0)`.
4401
4402  Otherwise, it follows:
4403  `f(x) = max_value` for `x >= max_value`,
4404  `f(x) = x` for `threshold <= x < max_value`,
4405  `f(x) = alpha * (x - threshold)` otherwise.
4406
4407  Arguments:
4408      x: A tensor or variable.
4409      alpha: A scalar, slope of negative section (default=`0.`).
4410      max_value: float. Saturation threshold.
4411      threshold: float. Threshold value for thresholded activation.
4412
4413  Returns:
4414      A tensor.
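
  Example (a small numerical illustration; the input values are arbitrary):

  >>> x = tf.constant([-2., 0., 2.])
  >>> tf.keras.backend.relu(x)
  <tf.Tensor: shape=(3,), dtype=float32, numpy=array([0., 0., 2.], dtype=float32)>
  >>> tf.keras.backend.relu(x, alpha=0.5)
  <tf.Tensor: shape=(3,), dtype=float32, numpy=array([-1.,  0.,  2.], dtype=float32)>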
4415  """
4416
4417  if alpha != 0.:
4418    if max_value is None and threshold == 0:
4419      return nn.leaky_relu(x, alpha=alpha)
4420
4421    if threshold != 0:
4422      negative_part = nn.relu(-x + threshold)
4423    else:
4424      negative_part = nn.relu(-x)
4425
4426  clip_max = max_value is not None
4427
4428  if threshold != 0:
4429    # computes x for x > threshold else 0
4430    x = x * math_ops.cast(math_ops.greater(x, threshold), floatx())
4431  elif max_value == 6:
4432    # if no threshold, then can use nn.relu6 native TF op for performance
4433    x = nn.relu6(x)
4434    clip_max = False
4435  else:
4436    x = nn.relu(x)
4437
4438  if clip_max:
4439    max_value = _constant_to_tensor(max_value, x.dtype.base_dtype)
4440    zero = _constant_to_tensor(0, x.dtype.base_dtype)
4441    x = clip_ops.clip_by_value(x, zero, max_value)
4442
4443  if alpha != 0.:
4444    alpha = _to_tensor(alpha, x.dtype.base_dtype)
4445    x -= alpha * negative_part
4446  return x
4447
4448
4449@keras_export('keras.backend.elu')
4450def elu(x, alpha=1.):
4451  """Exponential linear unit.
4452
4453  Arguments:
4454      x: A tensor or variable to compute the activation function for.
4455      alpha: A scalar, slope of negative section.
4456
4457  Returns:
4458      A tensor.
4459  """
4460  res = nn.elu(x)
4461  if alpha == 1:
4462    return res
4463  else:
4464    return array_ops.where_v2(x > 0, res, alpha * res)
4465
4466
4467@keras_export('keras.backend.softmax')
4468def softmax(x, axis=-1):
4469  """Softmax of a tensor.
4470
4471  Arguments:
4472      x: A tensor or variable.
4473      axis: The dimension softmax would be performed on.
4474          The default is -1 which indicates the last dimension.
4475
4476  Returns:
4477      A tensor.
4478  """
4479  return nn.softmax(x, axis=axis)
4480
4481
4482@keras_export('keras.backend.softplus')
4483def softplus(x):
4484  """Softplus of a tensor.
4485
4486  Arguments:
4487      x: A tensor or variable.
4488
4489  Returns:
4490      A tensor.
4491  """
4492  return nn.softplus(x)
4493
4494
4495@keras_export('keras.backend.softsign')
4496def softsign(x):
4497  """Softsign of a tensor.
4498
4499  Arguments:
4500      x: A tensor or variable.
4501
4502  Returns:
4503      A tensor.
4504  """
4505  return nn.softsign(x)
4506
4507
4508def _backtrack_identity(tensor):
4509  while tensor.op.type == 'Identity':
4510    tensor = tensor.op.inputs[0]
4511  return tensor
4512
4513
4514@keras_export('keras.backend.categorical_crossentropy')
4515def categorical_crossentropy(target, output, from_logits=False, axis=-1):
4516  """Categorical crossentropy between an output tensor and a target tensor.
4517
4518  Arguments:
4519      target: A tensor of the same shape as `output`.
4520      output: A tensor resulting from a softmax
4521          (unless `from_logits` is True, in which
4522          case `output` is expected to be the logits).
4523      from_logits: Boolean, whether `output` is the
4524          result of a softmax, or is a tensor of logits.
4525      axis: Int specifying the channels axis. `axis=-1` corresponds to data
4526          format `channels_last`, and `axis=1` corresponds to data format
4527          `channels_first`.
4528
4529  Returns:
4530      Output tensor.
4531
4532  Raises:
4533      ValueError: if `axis` is neither -1 nor one of the axes of `output`.
4534
4535  Example:
4536
4537  >>> a = tf.constant([1., 0., 0., 0., 1., 0., 0., 0., 1.], shape=[3,3])
4538  >>> print(a)
4539  tf.Tensor(
4540    [[1. 0. 0.]
4541     [0. 1. 0.]
4542     [0. 0. 1.]], shape=(3, 3), dtype=float32)
4543  >>> b = tf.constant([.9, .05, .05, .5, .89, .6, .05, .01, .94], shape=[3,3])
4544  >>> print(b)
4545  tf.Tensor(
4546    [[0.9  0.05 0.05]
4547     [0.5  0.89 0.6 ]
4548     [0.05 0.01 0.94]], shape=(3, 3), dtype=float32)
4549  >>> loss = tf.keras.backend.categorical_crossentropy(a, b)
4550  >>> print(loss)
4551  tf.Tensor([0.10536055 0.8046684  0.06187541], shape=(3,), dtype=float32)
4552  >>> loss = tf.keras.backend.categorical_crossentropy(a, a)
4553  >>> print(loss)
4554  tf.Tensor([1.1920929e-07 1.1920929e-07 1.1920930e-07], shape=(3,),
4555  dtype=float32)
4556
4557  """
4558  target.shape.assert_is_compatible_with(output.shape)
4559  if from_logits:
4560    return nn.softmax_cross_entropy_with_logits_v2(
4561        labels=target, logits=output, axis=axis)
4562
4563  if not isinstance(output, (ops.EagerTensor, variables_module.Variable)):
4564    output = _backtrack_identity(output)
4565    if output.op.type == 'Softmax':
4566      # When softmax activation function is used for output operation, we
4567      # use logits from the softmax function directly to compute loss in order
4568      # to prevent collapsing to zero when training.
4569      # See b/117284466
4570      assert len(output.op.inputs) == 1
4571      output = output.op.inputs[0]
4572      return nn.softmax_cross_entropy_with_logits_v2(
4573          labels=target, logits=output, axis=axis)
4574
4575  # scale preds so that the class probas of each sample sum to 1
4576  output = output / math_ops.reduce_sum(output, axis, True)
4577  # Compute cross entropy from probabilities.
4578  epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
4579  output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
4580  return -math_ops.reduce_sum(target * math_ops.log(output), axis)
4581
4582
4583@keras_export('keras.backend.sparse_categorical_crossentropy')
4584def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
4585  """Categorical crossentropy with integer targets.
4586
4587  Arguments:
4588      target: An integer tensor.
4589      output: A tensor resulting from a softmax
4590          (unless `from_logits` is True, in which
4591          case `output` is expected to be the logits).
4592      from_logits: Boolean, whether `output` is the
4593          result of a softmax, or is a tensor of logits.
4594      axis: Int specifying the channels axis. `axis=-1` corresponds to data
4595          format `channels_last`, and `axis=1` corresponds to data format
4596          `channels_first`.
4597
4598  Returns:
4599      Output tensor.
4600
4601  Raises:
4602      ValueError: if `axis` is neither -1 nor one of the axes of `output`.
4603  """
4604  if not from_logits and not isinstance(
4605      output, (ops.EagerTensor, variables_module.Variable)):
4606    output = _backtrack_identity(output)
4607    if output.op.type == 'Softmax':
4608      # When softmax activation function is used for output operation, we
4609      # use logits from the softmax function directly to compute loss in order
4610      # to prevent collapsing to zero when training.
4611      # See b/117284466
4612      assert len(output.op.inputs) == 1
4613      output = output.op.inputs[0]
4614      from_logits = True
4615
4616  if not from_logits:
4617    epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
4618    output = clip_ops.clip_by_value(output, epsilon_, 1 - epsilon_)
4619    output = math_ops.log(output)
4620
4621  if isinstance(output.shape, (tuple, list)):
4622    output_rank = len(output.shape)
4623  else:
4624    output_rank = output.shape.ndims
4625  if output_rank is not None:
4626    axis %= output_rank
4627    if axis != output_rank - 1:
4628      permutation = list(
4629          itertools.chain(range(axis), range(axis + 1, output_rank), [axis]))
4630      output = array_ops.transpose(output, perm=permutation)
4631  elif axis != -1:
4632    raise ValueError(
4633        'Cannot compute sparse categorical crossentropy with `axis={}` on an '
4634        'output tensor with unknown rank'.format(axis))
4635
4636  target = cast(target, 'int64')
4637
4638  # Try to adjust the shape so that rank of labels = rank of logits - 1.
4639  output_shape = array_ops.shape_v2(output)
4640  target_rank = target.shape.ndims
4641
4642  update_shape = (
4643      target_rank is not None and output_rank is not None and
4644      target_rank != output_rank - 1)
4645  if update_shape:
4646    target = flatten(target)
4647    output = array_ops.reshape(output, [-1, output_shape[-1]])
4648
4649  if py_any(_is_symbolic_tensor(v) for v in [target, output]):
4650    with get_graph().as_default():
4651      res = nn.sparse_softmax_cross_entropy_with_logits_v2(
4652          labels=target, logits=output)
4653  else:
4654    res = nn.sparse_softmax_cross_entropy_with_logits_v2(
4655        labels=target, logits=output)
4656
4657  if update_shape and output_rank >= 3:
4658    # If our output includes timesteps or spatial dimensions we need to reshape
4659    return array_ops.reshape(res, output_shape[:-1])
4660  else:
4661    return res
4662
4663
4664@keras_export('keras.backend.binary_crossentropy')
4665def binary_crossentropy(target, output, from_logits=False):
4666  """Binary crossentropy between an output tensor and a target tensor.
4667
4668  Arguments:
4669      target: A tensor with the same shape as `output`.
4670      output: A tensor.
4671      from_logits: Whether `output` is expected to be a logits tensor.
4672          By default, we consider that `output`
4673          encodes a probability distribution.
4674
4675  Returns:
4676      A tensor.
4677  """
4678  if from_logits:
4679    return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
4680
4681  if not isinstance(output, (ops.EagerTensor, variables_module.Variable)):
4682    output = _backtrack_identity(output)
4683    if output.op.type == 'Sigmoid':
4684      # When sigmoid activation function is used for output operation, we
4685      # use logits from the sigmoid function directly to compute loss in order
4686      # to prevent collapsing to zero when training.
4687      assert len(output.op.inputs) == 1
4688      output = output.op.inputs[0]
4689      return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
4690
4691  epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
4692  output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
4693
4694  # Compute cross entropy from probabilities.
4695  bce = target * math_ops.log(output + epsilon())
4696  bce += (1 - target) * math_ops.log(1 - output + epsilon())
4697  return -bce
4698
4699
4700@keras_export('keras.backend.sigmoid')
4701def sigmoid(x):
4702  """Element-wise sigmoid.
4703
4704  Arguments:
4705      x: A tensor or variable.
4706
4707  Returns:
4708      A tensor.
4709  """
4710  return nn.sigmoid(x)
4711
4712
4713@keras_export('keras.backend.hard_sigmoid')
4714def hard_sigmoid(x):
4715  """Segment-wise linear approximation of sigmoid.
4716
4717  Faster than sigmoid.
4718  Returns `0.` if `x < -2.5`, `1.` if `x > 2.5`.
4719  In `-2.5 <= x <= 2.5`, returns `0.2 * x + 0.5`.
4720
4721  Arguments:
4722      x: A tensor or variable.
4723
4724  Returns:
4725      A tensor.
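
  Example (a small numerical illustration; the input values are arbitrary):

  >>> x = tf.constant([-3., 0., 3.])
  >>> tf.keras.backend.hard_sigmoid(x)
  <tf.Tensor: shape=(3,), dtype=float32, numpy=array([0. , 0.5, 1. ], dtype=float32)>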
4726  """
4727  point_two = _constant_to_tensor(0.2, x.dtype.base_dtype)
4728  point_five = _constant_to_tensor(0.5, x.dtype.base_dtype)
4729  x = math_ops.mul(x, point_two)
4730  x = math_ops.add(x, point_five)
4731  x = clip_ops.clip_by_value(x, 0., 1.)
4732  return x
4733
4734
4735@keras_export('keras.backend.tanh')
4736def tanh(x):
4737  """Element-wise tanh.
4738
4739  Arguments:
4740      x: A tensor or variable.
4741
4742  Returns:
4743      A tensor.
4744  """
4745  return nn.tanh(x)
4746
4747
4748@keras_export('keras.backend.dropout')
4749def dropout(x, level, noise_shape=None, seed=None):
4750  """Sets entries in `x` to zero at random, while scaling the entire tensor.
4751
4752  Arguments:
4753      x: tensor
4754      level: fraction of the entries in the tensor
4755          that will be set to 0.
4756      noise_shape: shape for randomly generated keep/drop flags,
4757          must be broadcastable to the shape of `x`
4758      seed: random seed to ensure determinism.
4759
4760  Returns:
4761      A tensor.
4762  """
4763  if seed is None:
4764    seed = np.random.randint(10e6)
4765  return nn.dropout_v2(x, rate=level, noise_shape=noise_shape, seed=seed)
4766
4767
4768@keras_export('keras.backend.l2_normalize')
4769def l2_normalize(x, axis=None):
4770  """Normalizes a tensor wrt the L2 norm along the specified axis.
4771
4772  Arguments:
4773      x: Tensor or variable.
4774      axis: axis along which to perform normalization.
4775
4776  Returns:
4777      A tensor.
4778  """
4779  return nn.l2_normalize(x, axis=axis)
4780
4781
4782@keras_export('keras.backend.in_top_k')
4783def in_top_k(predictions, targets, k):
4784  """Returns whether the `targets` are in the top `k` `predictions`.
4785
4786  Arguments:
4787      predictions: A tensor of shape `(batch_size, classes)` and type `float32`.
4788      targets: A 1D tensor of length `batch_size` and type `int32` or `int64`.
4789      k: An `int`, number of top elements to consider.
4790
4791  Returns:
4792      A 1D tensor of length `batch_size` and type `bool`.
4793      `output[i]` is `True` if `predictions[i, targets[i]]` is within top-`k`
4794      values of `predictions[i]`.
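
  Example (a small illustration; the prediction values are arbitrary):

  >>> predictions = tf.constant([[0.1, 0.9], [0.6, 0.4]])
  >>> targets = tf.constant([1, 0])
  >>> tf.keras.backend.in_top_k(predictions, targets, k=1)
  <tf.Tensor: shape=(2,), dtype=bool, numpy=array([ True,  True])>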
4795  """
4796  return nn.in_top_k(predictions, targets, k)
4797
4798
4799# CONVOLUTIONS
4800
4801
4802def _preprocess_conv1d_input(x, data_format):
4803  """Transpose and cast the input before the conv1d.
4804
4805  Arguments:
4806      x: input tensor.
4807      data_format: string, `"channels_last"` or `"channels_first"`.
4808
4809  Returns:
4810      A tensor.
4811  """
4812  tf_data_format = 'NWC'  # to pass TF Conv2dNative operations
4813  if data_format == 'channels_first':
4814    if not _has_nchw_support():
4815      x = array_ops.transpose(x, (0, 2, 1))  # NCW -> NWC
4816    else:
4817      tf_data_format = 'NCW'
4818  return x, tf_data_format
4819
4820
4821def _preprocess_conv2d_input(x, data_format, force_transpose=False):
4822  """Transpose and cast the input before the conv2d.
4823
4824  Arguments:
4825      x: input tensor.
4826      data_format: string, `"channels_last"` or `"channels_first"`.
4827      force_transpose: Boolean. If True, the input will always be transposed
4828          from NCHW to NHWC if `data_format` is `"channels_first"`.
4829          If False, the transposition only occurs on CPU (GPU ops are
4830          assumed to support NCHW).
4831
4832  Returns:
4833      A tensor.
4834  """
4835  tf_data_format = 'NHWC'
4836  if data_format == 'channels_first':
4837    if not _has_nchw_support() or force_transpose:
4838      x = array_ops.transpose(x, (0, 2, 3, 1))  # NCHW -> NHWC
4839    else:
4840      tf_data_format = 'NCHW'
4841  return x, tf_data_format
4842
4843
4844def _preprocess_conv3d_input(x, data_format):
4845  """Transpose and cast the input before the conv3d.
4846
4847  Arguments:
4848      x: input tensor.
4849      data_format: string, `"channels_last"` or `"channels_first"`.
4850
4851  Returns:
4852      A tensor.
4853  """
4854  tf_data_format = 'NDHWC'
4855  if data_format == 'channels_first':
4856    if not _has_nchw_support():
4857      x = array_ops.transpose(x, (0, 2, 3, 4, 1))
4858    else:
4859      tf_data_format = 'NCDHW'
4860  return x, tf_data_format
4861
4862
4863def _preprocess_padding(padding):
4864  """Convert keras' padding to TensorFlow's padding.
4865
4866  Arguments:
4867      padding: string, one of 'same', 'valid'.
4868
4869  Returns:
4870      a string, one of 'SAME', 'VALID'.
4871
4872  Raises:
4873      ValueError: if invalid `padding`.
4874  """
4875  if padding == 'same':
4876    padding = 'SAME'
4877  elif padding == 'valid':
4878    padding = 'VALID'
4879  else:
4880    raise ValueError('Invalid padding: ' + str(padding))
4881  return padding
4882
4883
4884@keras_export('keras.backend.conv1d')
4885def conv1d(x,
4886           kernel,
4887           strides=1,
4888           padding='valid',
4889           data_format=None,
4890           dilation_rate=1):
4891  """1D convolution.
4892
4893  Arguments:
4894      x: Tensor or variable.
4895      kernel: kernel tensor.
4896      strides: stride integer.
4897      padding: string, `"same"`, `"causal"` or `"valid"`.
4898      data_format: string, one of "channels_last", "channels_first".
4899      dilation_rate: integer dilate rate.
4900
4901  Returns:
4902      A tensor, result of 1D convolution.
4903
4904  Raises:
4905      ValueError: if `data_format` is neither `channels_last` nor
4906      `channels_first`.
4907  """
4908  if data_format is None:
4909    data_format = image_data_format()
4910  if data_format not in {'channels_first', 'channels_last'}:
4911    raise ValueError('Unknown data_format: ' + str(data_format))
4912
4913  kernel_shape = kernel.shape.as_list()
4914  if padding == 'causal':
4915    # causal (dilated) convolution:
4916    left_pad = dilation_rate * (kernel_shape[0] - 1)
4917    x = temporal_padding(x, (left_pad, 0))
4918    padding = 'valid'
4919  padding = _preprocess_padding(padding)
4920
4921  x, tf_data_format = _preprocess_conv1d_input(x, data_format)
4922  x = nn.convolution(
4923      input=x,
4924      filter=kernel,
4925      dilation_rate=dilation_rate,
4926      strides=strides,
4927      padding=padding,
4928      data_format=tf_data_format)
4929  if data_format == 'channels_first' and tf_data_format == 'NWC':
4930    x = array_ops.transpose(x, (0, 2, 1))  # NWC -> NCW
4931  return x
4932
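# Editor's sketch (assumption: eager execution, default 'channels_last'
# format): with `padding='causal'` the input is padded on the left only, so
# output step t depends only on inputs up to t and the temporal length is
# preserved, even with dilation.
#
#   >>> x = random_uniform((4, 10, 3))          # (batch, steps, channels)
#   >>> kernel = random_uniform((5, 3, 8))      # (kernel_size, in_ch, out_ch)
#   >>> int_shape(conv1d(x, kernel, padding='causal', dilation_rate=2))
#   # -> (4, 10, 8)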
4933
4934@keras_export('keras.backend.conv2d')
4935def conv2d(x,
4936           kernel,
4937           strides=(1, 1),
4938           padding='valid',
4939           data_format=None,
4940           dilation_rate=(1, 1)):
4941  """2D convolution.
4942
4943  Arguments:
4944      x: Tensor or variable.
4945      kernel: kernel tensor.
4946      strides: strides tuple.
4947      padding: string, `"same"` or `"valid"`.
4948      data_format: `"channels_last"` or `"channels_first"`.
4949      dilation_rate: tuple of 2 integers.
4950
4951  Returns:
4952      A tensor, result of 2D convolution.
4953
4954  Raises:
4955      ValueError: if `data_format` is neither `channels_last` nor
4956      `channels_first`.
4957  """
4958  if data_format is None:
4959    data_format = image_data_format()
4960  if data_format not in {'channels_first', 'channels_last'}:
4961    raise ValueError('Unknown data_format: ' + str(data_format))
4962
4963  x, tf_data_format = _preprocess_conv2d_input(x, data_format)
4964  padding = _preprocess_padding(padding)
4965  x = nn.convolution(
4966      input=x,
4967      filter=kernel,
4968      dilation_rate=dilation_rate,
4969      strides=strides,
4970      padding=padding,
4971      data_format=tf_data_format)
4972  if data_format == 'channels_first' and tf_data_format == 'NHWC':
4973    x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
4974  return x
4975
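# Editor's sketch (assumption: eager execution, default 'channels_last'
# format): a shape-level example of `conv2d` with 'same' padding.
#
#   >>> imgs = random_uniform((2, 32, 32, 4))   # NHWC batch
#   >>> kernel = random_uniform((3, 3, 4, 8))   # (rows, cols, in_ch, out_ch)
#   >>> int_shape(conv2d(imgs, kernel, strides=(1, 1), padding='same'))
#   # -> (2, 32, 32, 8)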
4976
4977@keras_export('keras.backend.conv2d_transpose')
4978def conv2d_transpose(x,
4979                     kernel,
4980                     output_shape,
4981                     strides=(1, 1),
4982                     padding='valid',
4983                     data_format=None,
4984                     dilation_rate=(1, 1)):
4985  """2D deconvolution (i.e.
4986
4987  transposed convolution).
4988
4989  Arguments:
4990      x: Tensor or variable.
4991      kernel: kernel tensor.
4992      output_shape: 1D int tensor for the output shape.
4993      strides: strides tuple.
4994      padding: string, `"same"` or `"valid"`.
4995      data_format: string, `"channels_last"` or `"channels_first"`.
4996      dilation_rate: Tuple of 2 integers.
4997
4998  Returns:
4999      A tensor, result of transposed 2D convolution.
5000
5001  Raises:
5002      ValueError: if `data_format` is neither `channels_last` nor
5003      `channels_first`.
5004  """
5005  if data_format is None:
5006    data_format = image_data_format()
5007  if data_format not in {'channels_first', 'channels_last'}:
5008    raise ValueError('Unknown data_format: ' + str(data_format))
5009
5010  # `atrous_conv2d_transpose` only supports NHWC format, even on GPU.
5011  if data_format == 'channels_first' and dilation_rate != (1, 1):
5012    force_transpose = True
5013  else:
5014    force_transpose = False
5015
5016  x, tf_data_format = _preprocess_conv2d_input(x, data_format, force_transpose)
5017
5018  if data_format == 'channels_first' and tf_data_format == 'NHWC':
5019    output_shape = (output_shape[0], output_shape[2], output_shape[3],
5020                    output_shape[1])
5021  if output_shape[0] is None:
5022    output_shape = (shape(x)[0],) + tuple(output_shape[1:])
5023
5024  if isinstance(output_shape, (tuple, list)):
5025    output_shape = array_ops.stack(list(output_shape))
5026
5027  padding = _preprocess_padding(padding)
5028  if tf_data_format == 'NHWC':
5029    strides = (1,) + strides + (1,)
5030  else:
5031    strides = (1, 1) + strides
5032
5033  if dilation_rate == (1, 1):
5034    x = nn.conv2d_transpose(x, kernel, output_shape, strides,
5035                            padding=padding,
5036                            data_format=tf_data_format)
5037  else:
5038    assert dilation_rate[0] == dilation_rate[1]
5039    x = nn.atrous_conv2d_transpose(
5040        x,
5041        kernel,
5042        output_shape,
5043        rate=dilation_rate[0],
5044        padding=padding)
5045  if data_format == 'channels_first' and tf_data_format == 'NHWC':
5046    x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
5047  return x
5048
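# Editor's sketch (assumption: eager execution, 'channels_last' format): the
# caller states the desired output shape explicitly, which is how strided
# transposed convolutions express upsampling.
#
#   >>> x = random_uniform((2, 16, 16, 8))
#   >>> kernel = random_uniform((3, 3, 4, 8))   # (rows, cols, out_ch, in_ch)
#   >>> y = conv2d_transpose(x, kernel, output_shape=(2, 32, 32, 4),
#   ...                      strides=(2, 2), padding='same')
#   >>> int_shape(y)
#   # -> (2, 32, 32, 4)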
5049
5050def separable_conv1d(x,
5051                     depthwise_kernel,
5052                     pointwise_kernel,
5053                     strides=1,
5054                     padding='valid',
5055                     data_format=None,
5056                     dilation_rate=1):
5057  """1D convolution with separable filters.
5058
5059  Arguments:
5060      x: input tensor
5061      depthwise_kernel: convolution kernel for the depthwise convolution.
5062      pointwise_kernel: kernel for the 1x1 convolution.
5063      strides: stride integer.
5064      padding: string, `"same"` or `"valid"`.
5065      data_format: string, `"channels_last"` or `"channels_first"`.
5066      dilation_rate: integer dilation rate.
5067
5068  Returns:
5069      Output tensor.
5070
5071  Raises:
5072      ValueError: if `data_format` is neither `channels_last` nor
5073      `channels_first`.
5074  """
5075  if data_format is None:
5076    data_format = image_data_format()
5077  if data_format not in {'channels_first', 'channels_last'}:
5078    raise ValueError('Unknown data_format: ' + str(data_format))
5079
5080  if isinstance(strides, int):
5081    strides = (strides,)
5082  if isinstance(dilation_rate, int):
5083    dilation_rate = (dilation_rate,)
5084
5085  x, tf_data_format = _preprocess_conv1d_input(x, data_format)
5086  padding = _preprocess_padding(padding)
5087  if not isinstance(strides, tuple):
5088    strides = tuple(strides)
5089  if tf_data_format == 'NWC':
5090    spatial_start_dim = 1
5091    strides = (1,) + strides * 2 + (1,)
5092  else:
5093    spatial_start_dim = 2
5094    strides = (1, 1) + strides * 2
5095  x = array_ops.expand_dims(x, spatial_start_dim)
5096  depthwise_kernel = array_ops.expand_dims(depthwise_kernel, 0)
5097  pointwise_kernel = array_ops.expand_dims(pointwise_kernel, 0)
5098  dilation_rate = (1,) + dilation_rate
5099
5100  x = nn.separable_conv2d(
5101      x,
5102      depthwise_kernel,
5103      pointwise_kernel,
5104      strides=strides,
5105      padding=padding,
5106      rate=dilation_rate,
5107      data_format=tf_data_format)
5108
5109  x = array_ops.squeeze(x, [spatial_start_dim])
5110
5111  if data_format == 'channels_first' and tf_data_format == 'NWC':
5112    x = array_ops.transpose(x, (0, 2, 1))  # NWC -> NCW
5113
5114  return x
5115
5116
5117@keras_export('keras.backend.separable_conv2d')
5118def separable_conv2d(x,
5119                     depthwise_kernel,
5120                     pointwise_kernel,
5121                     strides=(1, 1),
5122                     padding='valid',
5123                     data_format=None,
5124                     dilation_rate=(1, 1)):
5125  """2D convolution with separable filters.
5126
5127  Arguments:
5128      x: input tensor
5129      depthwise_kernel: convolution kernel for the depthwise convolution.
5130      pointwise_kernel: kernel for the 1x1 convolution.
5131      strides: strides tuple (length 2).
5132      padding: string, `"same"` or `"valid"`.
5133      data_format: string, `"channels_last"` or `"channels_first"`.
5134      dilation_rate: tuple of integers,
5135          dilation rates for the separable convolution.
5136
5137  Returns:
5138      Output tensor.
5139
5140  Raises:
5141      ValueError: if `data_format` is neither `channels_last` nor
5142      `channels_first`.
5143      ValueError: if `strides` is not a tuple of 2 integers.
5144  """
5145  if data_format is None:
5146    data_format = image_data_format()
5147  if data_format not in {'channels_first', 'channels_last'}:
5148    raise ValueError('Unknown data_format: ' + str(data_format))
5149  if len(strides) != 2:
5150    raise ValueError('`strides` must be a tuple of 2 integers.')
5151
5152  x, tf_data_format = _preprocess_conv2d_input(x, data_format)
5153  padding = _preprocess_padding(padding)
5154  if not isinstance(strides, tuple):
5155    strides = tuple(strides)
5156  if tf_data_format == 'NHWC':
5157    strides = (1,) + strides + (1,)
5158  else:
5159    strides = (1, 1) + strides
5160
5161  x = nn.separable_conv2d(
5162      x,
5163      depthwise_kernel,
5164      pointwise_kernel,
5165      strides=strides,
5166      padding=padding,
5167      rate=dilation_rate,
5168      data_format=tf_data_format)
5169  if data_format == 'channels_first' and tf_data_format == 'NHWC':
5170    x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
5171  return x
5172
5173
5174@keras_export('keras.backend.depthwise_conv2d')
5175def depthwise_conv2d(x,
5176                     depthwise_kernel,
5177                     strides=(1, 1),
5178                     padding='valid',
5179                     data_format=None,
5180                     dilation_rate=(1, 1)):
5181  """2D convolution with separable filters.
5182
5183  Arguments:
5184      x: input tensor
5185      depthwise_kernel: convolution kernel for the depthwise convolution.
5186      strides: strides tuple (length 2).
5187      padding: string, `"same"` or `"valid"`.
5188      data_format: string, `"channels_last"` or `"channels_first"`.
5189      dilation_rate: tuple of integers,
5190          dilation rates for the depthwise convolution.
5191
5192  Returns:
5193      Output tensor.
5194
5195  Raises:
5196      ValueError: if `data_format` is neither `channels_last` nor
5197      `channels_first`.
5198  """
5199  if data_format is None:
5200    data_format = image_data_format()
5201  if data_format not in {'channels_first', 'channels_last'}:
5202    raise ValueError('Unknown data_format: ' + str(data_format))
5203
5204  x, tf_data_format = _preprocess_conv2d_input(x, data_format)
5205  padding = _preprocess_padding(padding)
5206  if tf_data_format == 'NHWC':
5207    strides = (1,) + strides + (1,)
5208  else:
5209    strides = (1, 1) + strides
5210
5211  x = nn.depthwise_conv2d(
5212      x,
5213      depthwise_kernel,
5214      strides=strides,
5215      padding=padding,
5216      rate=dilation_rate,
5217      data_format=tf_data_format)
5218  if data_format == 'channels_first' and tf_data_format == 'NHWC':
5219    x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
5220  return x
5221
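# Editor's sketch (assumption: eager execution, 'channels_last' format): in a
# depthwise convolution each input channel is filtered independently, so the
# number of output channels is in_channels * depth_multiplier.
#
#   >>> imgs = random_uniform((2, 32, 32, 4))
#   >>> kernel = random_uniform((3, 3, 4, 2))   # (rows, cols, in_ch, depth_multiplier)
#   >>> int_shape(depthwise_conv2d(imgs, kernel, padding='same'))
#   # -> (2, 32, 32, 8)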
5222
5223@keras_export('keras.backend.conv3d')
5224def conv3d(x,
5225           kernel,
5226           strides=(1, 1, 1),
5227           padding='valid',
5228           data_format=None,
5229           dilation_rate=(1, 1, 1)):
5230  """3D convolution.
5231
5232  Arguments:
5233      x: Tensor or variable.
5234      kernel: kernel tensor.
5235      strides: strides tuple.
5236      padding: string, `"same"` or `"valid"`.
5237      data_format: string, `"channels_last"` or `"channels_first"`.
5238      dilation_rate: tuple of 3 integers.
5239
5240  Returns:
5241      A tensor, result of 3D convolution.
5242
5243  Raises:
5244      ValueError: if `data_format` is neither `channels_last` nor
5245      `channels_first`.
5246  """
5247  if data_format is None:
5248    data_format = image_data_format()
5249  if data_format not in {'channels_first', 'channels_last'}:
5250    raise ValueError('Unknown data_format: ' + str(data_format))
5251
5252  x, tf_data_format = _preprocess_conv3d_input(x, data_format)
5253  padding = _preprocess_padding(padding)
5254  x = nn.convolution(
5255      input=x,
5256      filter=kernel,
5257      dilation_rate=dilation_rate,
5258      strides=strides,
5259      padding=padding,
5260      data_format=tf_data_format)
5261  if data_format == 'channels_first' and tf_data_format == 'NDHWC':
5262    x = array_ops.transpose(x, (0, 4, 1, 2, 3))
5263  return x
5264
5265
5266def conv3d_transpose(x,
5267                     kernel,
5268                     output_shape,
5269                     strides=(1, 1, 1),
5270                     padding='valid',
5271                     data_format=None):
5272  """3D deconvolution (i.e.
5273
5274  transposed convolution).
5275
5276  Arguments:
5277      x: input tensor.
5278      kernel: kernel tensor.
5279      output_shape: 1D int tensor for the output shape.
5280      strides: strides tuple.
5281      padding: string, `"same"` or `"valid"`.
5282      data_format: string, `"channels_last"` or `"channels_first"`.
5283
5284  Returns:
5285      A tensor, result of transposed 3D convolution.
5286
5287  Raises:
5288      ValueError: if `data_format` is neither `channels_last` nor
5289      `channels_first`.
5290  """
5291  if data_format is None:
5292    data_format = image_data_format()
5293  if data_format not in {'channels_first', 'channels_last'}:
5294    raise ValueError('Unknown data_format: ' + str(data_format))
5295  if isinstance(output_shape, (tuple, list)):
5296    output_shape = array_ops.stack(output_shape)
5297
5298  x, tf_data_format = _preprocess_conv3d_input(x, data_format)
5299
5300  if data_format == 'channels_first' and tf_data_format == 'NDHWC':
5301    output_shape = (output_shape[0], output_shape[2], output_shape[3],
5302                    output_shape[4], output_shape[1])
5303  if output_shape[0] is None:
5304    output_shape = (array_ops.shape(x)[0],) + tuple(output_shape[1:])
5305    output_shape = array_ops.stack(list(output_shape))
5306
5307  padding = _preprocess_padding(padding)
5308  if tf_data_format == 'NDHWC':
5309    strides = (1,) + strides + (1,)
5310  else:
5311    strides = (1, 1) + strides
5312
5313  x = nn.conv3d_transpose(
5314      x,
5315      kernel,
5316      output_shape,
5317      strides,
5318      padding=padding,
5319      data_format=tf_data_format)
5320  if data_format == 'channels_first' and tf_data_format == 'NDHWC':
5321    x = array_ops.transpose(x, (0, 4, 1, 2, 3))
5322  return x
5323
5324
5325@keras_export('keras.backend.pool2d')
5326def pool2d(x,
5327           pool_size,
5328           strides=(1, 1),
5329           padding='valid',
5330           data_format=None,
5331           pool_mode='max'):
5332  """2D Pooling.
5333
5334  Arguments:
5335      x: Tensor or variable.
5336      pool_size: tuple of 2 integers.
5337      strides: tuple of 2 integers.
5338      padding: string, `"same"` or `"valid"`.
5339      data_format: string, `"channels_last"` or `"channels_first"`.
5340      pool_mode: string, `"max"` or `"avg"`.
5341
5342  Returns:
5343      A tensor, result of 2D pooling.
5344
5345  Raises:
5346      ValueError: if `data_format` is neither `"channels_last"` nor
5347      `"channels_first"`.
5348      ValueError: if `pool_size` is not a tuple of 2 integers.
5349      ValueError: if `strides` is not a tuple of 2 integers.
5350      ValueError: if `pool_mode` is neither `"max"` nor `"avg"`.
5351  """
5352  if data_format is None:
5353    data_format = image_data_format()
5354  if data_format not in {'channels_first', 'channels_last'}:
5355    raise ValueError('Unknown data_format: ' + str(data_format))
5356  if len(pool_size) != 2:
5357    raise ValueError('`pool_size` must be a tuple of 2 integers.')
5358  if len(strides) != 2:
5359    raise ValueError('`strides` must be a tuple of 2 integers.')
5360
5361  x, tf_data_format = _preprocess_conv2d_input(x, data_format)
5362  padding = _preprocess_padding(padding)
5363  if tf_data_format == 'NHWC':
5364    strides = (1,) + strides + (1,)
5365    pool_size = (1,) + pool_size + (1,)
5366  else:
5367    strides = (1, 1) + strides
5368    pool_size = (1, 1) + pool_size
5369
5370  if pool_mode == 'max':
5371    x = nn.max_pool(
5372        x, pool_size, strides, padding=padding, data_format=tf_data_format)
5373  elif pool_mode == 'avg':
5374    x = nn.avg_pool(
5375        x, pool_size, strides, padding=padding, data_format=tf_data_format)
5376  else:
5377    raise ValueError('Invalid pooling mode: ' + str(pool_mode))
5378
5379  if data_format == 'channels_first' and tf_data_format == 'NHWC':
5380    x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
5381  return x
5382
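# Editor's sketch (assumption: eager execution, 'channels_last' format): with
# 'valid' padding and matching pool size / strides, pooling halves each
# spatial dimension.
#
#   >>> x = random_uniform((2, 32, 32, 3))
#   >>> int_shape(pool2d(x, pool_size=(2, 2), strides=(2, 2), pool_mode='max'))
#   # -> (2, 16, 16, 3)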
5383
5384@keras_export('keras.backend.pool3d')
5385def pool3d(x,
5386           pool_size,
5387           strides=(1, 1, 1),
5388           padding='valid',
5389           data_format=None,
5390           pool_mode='max'):
5391  """3D Pooling.
5392
5393  Arguments:
5394      x: Tensor or variable.
5395      pool_size: tuple of 3 integers.
5396      strides: tuple of 3 integers.
5397      padding: string, `"same"` or `"valid"`.
5398      data_format: string, `"channels_last"` or `"channels_first"`.
5399      pool_mode: string, `"max"` or `"avg"`.
5400
5401  Returns:
5402      A tensor, result of 3D pooling.
5403
5404  Raises:
5405      ValueError: if `data_format` is neither `"channels_last"` nor
5406      `"channels_first"`.
5407      ValueError: if `pool_mode` is neither `"max"` nor `"avg"`.
5408  """
5409  if data_format is None:
5410    data_format = image_data_format()
5411  if data_format not in {'channels_first', 'channels_last'}:
5412    raise ValueError('Unknown data_format: ' + str(data_format))
5413
5414  x, tf_data_format = _preprocess_conv3d_input(x, data_format)
5415  padding = _preprocess_padding(padding)
5416  if tf_data_format == 'NDHWC':
5417    strides = (1,) + strides + (1,)
5418    pool_size = (1,) + pool_size + (1,)
5419  else:
5420    strides = (1, 1) + strides
5421    pool_size = (1, 1) + pool_size
5422
5423  if pool_mode == 'max':
5424    x = nn.max_pool3d(
5425        x, pool_size, strides, padding=padding, data_format=tf_data_format)
5426  elif pool_mode == 'avg':
5427    x = nn.avg_pool3d(
5428        x, pool_size, strides, padding=padding, data_format=tf_data_format)
5429  else:
5430    raise ValueError('Invalid pooling mode: ' + str(pool_mode))
5431
5432  if data_format == 'channels_first' and tf_data_format == 'NDHWC':
5433    x = array_ops.transpose(x, (0, 4, 1, 2, 3))
5434  return x
5435
5436
5437def local_conv(inputs,
5438               kernel,
5439               kernel_size,
5440               strides,
5441               output_shape,
5442               data_format=None):
5443  """Apply N-D convolution with un-shared weights.
5444
5445  Arguments:
5446      inputs: (N+2)-D tensor with shape
5447          (batch_size, channels_in, d_in1, ..., d_inN)
5448          if data_format='channels_first', or
5449          (batch_size, d_in1, ..., d_inN, channels_in)
5450          if data_format='channels_last'.
5451      kernel: the unshared weight for N-D convolution,
5452          with shape (output_items, feature_dim, channels_out), where
5453          feature_dim = np.prod(kernel_size) * channels_in,
5454          output_items = np.prod(output_shape).
5455      kernel_size: a tuple of N integers, specifying the
5456          spatial dimensions of the N-D convolution window.
5457      strides: a tuple of N integers, specifying the strides
5458          of the convolution along the spatial dimensions.
5459      output_shape: a tuple of (d_out1, ..., d_outN) specifying the spatial
5460          dimensionality of the output.
5461      data_format: string, "channels_first" or "channels_last".
5462
5463  Returns:
5464      An (N+2)-D tensor with shape:
5465      (batch_size, channels_out) + output_shape
5466      if data_format='channels_first', or:
5467      (batch_size,) + output_shape + (channels_out,)
5468      if data_format='channels_last'.
5469
5470  Raises:
5471      ValueError: if `data_format` is neither
5472      `channels_last` nor `channels_first`.
5473  """
5474  if data_format is None:
5475    data_format = image_data_format()
5476  if data_format not in {'channels_first', 'channels_last'}:
5477    raise ValueError('Unknown data_format: ' + str(data_format))
5478
5479  kernel_shape = int_shape(kernel)
5480  feature_dim = kernel_shape[1]
5481  channels_out = kernel_shape[-1]
5482  ndims = len(output_shape)
5483  spatial_dimensions = list(range(ndims))
5484
5485  xs = []
5486  output_axes_ticks = [range(axis_max) for axis_max in output_shape]
5487  for position in itertools.product(*output_axes_ticks):
5488    slices = [slice(None)]
5489
5490    if data_format == 'channels_first':
5491      slices.append(slice(None))
5492
5493    slices.extend(
5494        slice(position[d] * strides[d], position[d] * strides[d] +
5495              kernel_size[d]) for d in spatial_dimensions)
5496
5497    if data_format == 'channels_last':
5498      slices.append(slice(None))
5499
5500    xs.append(reshape(inputs[slices], (1, -1, feature_dim)))
5501
5502  x_aggregate = concatenate(xs, axis=0)
5503  output = batch_dot(x_aggregate, kernel)
5504  output = reshape(output, output_shape + (-1, channels_out))
5505
5506  if data_format == 'channels_first':
5507    permutation = [ndims, ndims + 1] + spatial_dimensions
5508  else:
5509    permutation = [ndims] + spatial_dimensions + [ndims + 1]
5510
5511  return permute_dimensions(output, permutation)
5512
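# Editor's sketch (illustrative shapes, 'channels_last'): local_conv slices
# out every output patch, flattens it to `feature_dim` values and multiplies
# it by that position's own kernel slice via `batch_dot`. For a length-5
# input with 2 channels, kernel_size=(3,) and strides=(1,), there are 3
# output positions and feature_dim = 3 * 2 = 6:
#
#   >>> x = random_uniform((4, 5, 2))           # (batch, steps, channels)
#   >>> kernel = random_uniform((3, 6, 7))      # (positions, feature_dim, filters)
#   >>> int_shape(local_conv(x, kernel, (3,), (1,), (3,), 'channels_last'))
#   # -> (4, 3, 7)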
5513
5514@keras_export('keras.backend.local_conv1d')
5515def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):
5516  """Apply 1D conv with un-shared weights.
5517
5518  Arguments:
5519      inputs: 3D tensor with shape:
5520          (batch_size, steps, input_dim)
5521          if data_format is "channels_last" or
5522          (batch_size, input_dim, steps)
5523          if data_format is "channels_first".
5524      kernel: the unshared weight for convolution,
5525          with shape (output_length, feature_dim, filters).
5526      kernel_size: a tuple of a single integer,
5527          specifying the length of the 1D convolution window.
5528      strides: a tuple of a single integer,
5529          specifying the stride length of the convolution.
5530      data_format: the data format, channels_first or channels_last.
5531
5532  Returns:
5533      A 3D tensor with shape:
5534      (batch_size, output_length, filters)
5535      if data_format='channels_last'
5536      or 3D tensor with shape:
5537      (batch_size, filters, output_length)
5538      if data_format='channels_first'.
5539  """
5540  output_shape = (kernel.shape[0],)
5541  return local_conv(inputs,
5542                    kernel,
5543                    kernel_size,
5544                    strides,
5545                    output_shape,
5546                    data_format)
5547
5548
5549@keras_export('keras.backend.local_conv2d')
5550def local_conv2d(inputs,
5551                 kernel,
5552                 kernel_size,
5553                 strides,
5554                 output_shape,
5555                 data_format=None):
5556  """Apply 2D conv with un-shared weights.
5557
5558  Arguments:
5559      inputs: 4D tensor with shape:
5560          (batch_size, filters, new_rows, new_cols)
5561          if data_format='channels_first'
5562          or 4D tensor with shape:
5563          (batch_size, new_rows, new_cols, filters)
5564          if data_format='channels_last'.
5565      kernel: the unshared weight for convolution,
5566          with shape (output_items, feature_dim, filters).
5567      kernel_size: a tuple of 2 integers, specifying the
5568          width and height of the 2D convolution window.
5569      strides: a tuple of 2 integers, specifying the strides
5570          of the convolution along the width and height.
5571      output_shape: a tuple with (output_row, output_col).
5572      data_format: the data format, channels_first or channels_last.
5573
5574  Returns:
5575      A 4D tensor with shape:
5576      (batch_size, filters, new_rows, new_cols)
5577      if data_format='channels_first'
5578      or 4D tensor with shape:
5579      (batch_size, new_rows, new_cols, filters)
5580      if data_format='channels_last'.
5581  """
5582  return local_conv(inputs,
5583                    kernel,
5584                    kernel_size,
5585                    strides,
5586                    output_shape,
5587                    data_format)
5588
5589
5590@keras_export('keras.backend.bias_add')
5591def bias_add(x, bias, data_format=None):
5592  """Adds a bias vector to a tensor.
5593
5594  Arguments:
5595      x: Tensor or variable.
5596      bias: Bias tensor to add.
5597      data_format: string, `"channels_last"` or `"channels_first"`.
5598
5599  Returns:
5600      Output tensor.
5601
5602  Raises:
5603      ValueError: In one of the two cases below:
5604                  1. invalid `data_format` argument.
5605                  2. invalid bias shape.
5606                     The bias should be either a vector or
5607                     a tensor with ndim(x) - 1 dimensions.
5608  """
5609  if data_format is None:
5610    data_format = image_data_format()
5611  if data_format not in {'channels_first', 'channels_last'}:
5612    raise ValueError('Unknown data_format: ' + str(data_format))
5613  bias_shape = int_shape(bias)
5614  if len(bias_shape) != 1 and len(bias_shape) != ndim(x) - 1:
5615    raise ValueError(
5616        'Unexpected bias dimensions %d, expected to be 1 or %d dimensions' %
5617        (len(bias_shape), ndim(x) - 1))
5618
5619  if len(bias_shape) == 1:
5620    if data_format == 'channels_first':
5621      return nn.bias_add(x, bias, data_format='NCHW')
5622    return nn.bias_add(x, bias, data_format='NHWC')
5623  if ndim(x) in (3, 4, 5):
5624    if data_format == 'channels_first':
5625      bias_reshape_axis = (1, bias_shape[-1]) + bias_shape[:-1]
5626      return x + reshape(bias, bias_reshape_axis)
5627    return x + reshape(bias, (1,) + bias_shape)
5628  return nn.bias_add(x, bias)
5629
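# Editor's sketch (assumption: eager execution): with a rank-1 bias the add
# broadcasts over the channel axis selected by `data_format`.
#
#   >>> x = zeros((2, 4, 4, 3))                 # NHWC
#   >>> b = constant([1., 2., 3.])              # one bias per channel
#   >>> y = bias_add(x, b, data_format='channels_last')
#   >>> # every pixel of `y` is now [1., 2., 3.] along the last axis.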
5630
5631# RANDOMNESS
5632
5633
5634@keras_export('keras.backend.random_normal')
5635def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
5636  """Returns a tensor with normal distribution of values.
5637
5638  Arguments:
5639      shape: A tuple of integers, the shape of tensor to create.
5640      mean: A float, mean of the normal distribution to draw samples.
5641      stddev: A float, standard deviation of the normal distribution
5642          to draw samples.
5643      dtype: String, dtype of returned tensor.
5644      seed: Integer, random seed.
5645
5646  Returns:
5647      A tensor.
5648  """
5649  if dtype is None:
5650    dtype = floatx()
5651  if seed is None:
5652    seed = np.random.randint(10e6)
5653  return random_ops.random_normal(
5654      shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed)
5655
5656
5657@keras_export('keras.backend.random_uniform')
5658def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
5659  """Returns a tensor with uniform distribution of values.
5660
5661  Arguments:
5662      shape: A tuple of integers, the shape of tensor to create.
5663      minval: A float, lower boundary of the uniform distribution
5664          to draw samples.
5665      maxval: A float, upper boundary of the uniform distribution
5666          to draw samples.
5667      dtype: String, dtype of returned tensor.
5668      seed: Integer, random seed.
5669
5670  Returns:
5671      A tensor.
5672  """
5673  if dtype is None:
5674    dtype = floatx()
5675  if seed is None:
5676    seed = np.random.randint(10e6)
5677  return random_ops.random_uniform(
5678      shape, minval=minval, maxval=maxval, dtype=dtype, seed=seed)
5679
5680
5681@keras_export('keras.backend.random_binomial')
5682def random_binomial(shape, p=0.0, dtype=None, seed=None):
5683  """Returns a tensor with random binomial distribution of values.
5684
5685  The binomial distribution with parameters `n` and `p` is the probability
5686  distribution of the number of successes in `n` independent Bernoulli trials,
5687  each with success probability `p`. Only `n = 1` is supported for now.
5688
5689  Arguments:
5690      shape: A tuple of integers, the shape of tensor to create.
5691      p: A float, `0. <= p <= 1`, probability of binomial distribution.
5692      dtype: String, dtype of returned tensor.
5693      seed: Integer, random seed.
5694
5695  Returns:
5696      A tensor.
5697  """
5698  if dtype is None:
5699    dtype = floatx()
5700  if seed is None:
5701    seed = np.random.randint(10e6)
5702  return array_ops.where_v2(
5703      random_ops.random_uniform(shape, dtype=dtype, seed=seed) <= p,
5704      array_ops.ones(shape, dtype=dtype), array_ops.zeros(shape, dtype=dtype))
5705
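# Editor's note (illustrative): as implemented this is an element-wise
# Bernoulli(p) draw (the `n = 1` case), which is what Keras uses to build
# dropout-style masks.
#
#   >>> mask = random_binomial((2, 3), p=0.5, seed=1)
#   >>> # `mask` contains only 0.0 and 1.0 entries; each entry is 1.0 with
#   >>> # probability 0.5.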
5706
5707@keras_export('keras.backend.truncated_normal')
5708def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
5709  """Returns a tensor with truncated random normal distribution of values.
5710
5711  The generated values follow a normal distribution
5712  with specified mean and standard deviation,
5713  except that values whose magnitude is more than
5714  two standard deviations from the mean are dropped and re-picked.
5715
5716  Arguments:
5717      shape: A tuple of integers, the shape of tensor to create.
5718      mean: Mean of the values.
5719      stddev: Standard deviation of the values.
5720      dtype: String, dtype of returned tensor.
5721      seed: Integer, random seed.
5722
5723  Returns:
5724      A tensor.
5725  """
5726  if dtype is None:
5727    dtype = floatx()
5728  if seed is None:
5729    seed = np.random.randint(10e6)
5730  return random_ops.truncated_normal(
5731      shape, mean, stddev, dtype=dtype, seed=seed)
5732
5733
5734# CTC
5735# TensorFlow has a native implementation, but it uses sparse tensors
5736# and therefore requires a wrapper for Keras. The functions below convert
5737# dense to sparse tensors and also wrap up the beam search code that is
5738# in TensorFlow's CTC implementation.
5739
5740
5741@keras_export('keras.backend.ctc_label_dense_to_sparse')
5742def ctc_label_dense_to_sparse(labels, label_lengths):
5743  """Converts CTC labels from dense to sparse.
5744
5745  Arguments:
5746      labels: dense CTC labels.
5747      label_lengths: length of the labels.
5748
5749  Returns:
5750      A sparse tensor representation of the labels.
5751  """
5752  label_shape = array_ops.shape(labels)
5753  num_batches_tns = array_ops.stack([label_shape[0]])
5754  max_num_labels_tns = array_ops.stack([label_shape[1]])
5755
5756  def range_less_than(old_input, current_input):
5757    return array_ops.expand_dims(
5758        math_ops.range(array_ops.shape(old_input)[1]), 0) < array_ops.fill(
5759            max_num_labels_tns, current_input)
5760
5761  init = math_ops.cast(
5762      array_ops.fill([1, label_shape[1]], 0), dtypes_module.bool)
5763  dense_mask = functional_ops.scan(
5764      range_less_than, label_lengths, initializer=init, parallel_iterations=1)
5765  dense_mask = dense_mask[:, 0, :]
5766
5767  label_array = array_ops.reshape(
5768      array_ops.tile(math_ops.range(0, label_shape[1]), num_batches_tns),
5769      label_shape)
5770  label_ind = array_ops.boolean_mask(label_array, dense_mask)
5771
5772  batch_array = array_ops.transpose(
5773      array_ops.reshape(
5774          array_ops.tile(math_ops.range(0, label_shape[0]), max_num_labels_tns),
5775          reverse(label_shape, 0)))
5776  batch_ind = array_ops.boolean_mask(batch_array, dense_mask)
5777  indices = array_ops.transpose(
5778      array_ops.reshape(concatenate([batch_ind, label_ind], axis=0), [2, -1]))
5779
5780  vals_sparse = array_ops.gather_nd(labels, indices)
5781
5782  return sparse_tensor.SparseTensor(
5783      math_ops.cast(indices, dtypes_module.int64), vals_sparse,
5784      math_ops.cast(label_shape, dtypes_module.int64))
5785
5786
5787@keras_export('keras.backend.ctc_batch_cost')
5788def ctc_batch_cost(y_true, y_pred, input_length, label_length):
5789  """Runs CTC loss algorithm on each batch element.
5790
5791  Arguments:
5792      y_true: tensor `(samples, max_string_length)`
5793          containing the truth labels.
5794      y_pred: tensor `(samples, time_steps, num_categories)`
5795          containing the prediction, or output of the softmax.
5796      input_length: tensor `(samples, 1)` containing the sequence length for
5797          each batch item in `y_pred`.
5798      label_length: tensor `(samples, 1)` containing the sequence length for
5799          each batch item in `y_true`.
5800
5801  Returns:
5802      Tensor with shape (samples,1) containing the
5803          CTC loss of each element.
5804  """
5805  label_length = math_ops.cast(
5806      array_ops.squeeze(label_length, axis=-1), dtypes_module.int32)
5807  input_length = math_ops.cast(
5808      array_ops.squeeze(input_length, axis=-1), dtypes_module.int32)
5809  sparse_labels = math_ops.cast(
5810      ctc_label_dense_to_sparse(y_true, label_length), dtypes_module.int32)
5811
5812  y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + epsilon())
5813
5814  return array_ops.expand_dims(
5815      ctc.ctc_loss(
5816          inputs=y_pred, labels=sparse_labels, sequence_length=input_length), 1)
5817
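# Editor's sketch of the expected shapes (hypothetical sizes: batch of 32,
# 50 time steps, 28 output categories, label strings of length up to 10):
#
#   >>> y_pred = random_uniform((32, 50, 28))   # softmax outputs per step
#   >>> y_true = zeros((32, 10))                # dense integer labels
#   >>> input_length = ones((32, 1)) * 50
#   >>> label_length = ones((32, 1)) * 10
#   >>> int_shape(ctc_batch_cost(y_true, y_pred, input_length, label_length))
#   # -> (32, 1): one CTC loss value per batch item.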
5818
5819@keras_export('keras.backend.ctc_decode')
5820def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
5821  """Decodes the output of a softmax.
5822
5823  Can use either greedy search (also known as best path)
5824  or a constrained dictionary search.
5825
5826  Arguments:
5827      y_pred: tensor `(samples, time_steps, num_categories)`
5828          containing the prediction, or output of the softmax.
5829      input_length: tensor `(samples, )` containing the sequence length for
5830          each batch item in `y_pred`.
5831      greedy: perform much faster best-path search if `True`.
5832          This does not use a dictionary.
5833      beam_width: if `greedy` is `False`, a beam search decoder will be used
5834          with a beam of this width.
5835      top_paths: if `greedy` is `False`,
5836          how many of the most probable paths will be returned.
5837
5838  Returns:
5839      Tuple:
5840          List: if `greedy` is `True`, returns a list of one element that
5841              contains the decoded sequence.
5842              If `False`, returns the `top_paths` most probable
5843              decoded sequences.
5844              Important: blank labels are returned as `-1`.
5845          Tensor `(top_paths, )` that contains
5846              the log probability of each decoded sequence.
5847  """
5848  y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + epsilon())
5849  input_length = math_ops.cast(input_length, dtypes_module.int32)
5850
5851  if greedy:
5852    (decoded, log_prob) = ctc.ctc_greedy_decoder(
5853        inputs=y_pred, sequence_length=input_length)
5854  else:
5855    (decoded, log_prob) = ctc.ctc_beam_search_decoder(
5856        inputs=y_pred,
5857        sequence_length=input_length,
5858        beam_width=beam_width,
5859        top_paths=top_paths)
5860  decoded_dense = [
5861      sparse_ops.sparse_to_dense(
5862          st.indices, st.dense_shape, st.values, default_value=-1)
5863      for st in decoded
5864  ]
5865  return (decoded_dense, log_prob)
5866
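# Editor's sketch (hypothetical sizes as above): greedy decoding of softmax
# outputs back into label sequences.
#
#   >>> y_pred = random_uniform((32, 50, 28))
#   >>> decoded, log_prob = ctc_decode(y_pred, input_length=ones((32,)) * 50)
#   >>> # `decoded` is a list holding one dense (32, max_decoded_length)
#   >>> # tensor with blanks encoded as -1; `log_prob` has shape (32, 1).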
5867
5868# HIGH ORDER FUNCTIONS
5869
5870
5871@keras_export('keras.backend.map_fn')
5872def map_fn(fn, elems, name=None, dtype=None):
5873  """Map the function fn over the elements elems and return the outputs.
5874
5875  Arguments:
5876      fn: Callable that will be called upon each element in elems
5877      elems: tensor
5878      name: A string name for the map node in the graph
5879      dtype: Output data type.
5880
5881  Returns:
5882      Tensor with dtype `dtype`.
5883  """
5884  return map_fn_lib.map_fn(fn, elems, name=name, dtype=dtype)
5885
5886
5887@keras_export('keras.backend.foldl')
5888def foldl(fn, elems, initializer=None, name=None):
5889  """Reduce elems using fn to combine them from left to right.
5890
5891  Arguments:
5892      fn: Callable that will be called upon each element in elems and an
5893          accumulator, for instance `lambda acc, x: acc + x`
5894      elems: tensor
5895      initializer: The first value used (`elems[0]` in case of None)
5896      name: A string name for the foldl node in the graph
5897
5898  Returns:
5899      Tensor with same type and shape as `initializer`.
5900  """
5901  return functional_ops.foldl(fn, elems, initializer=initializer, name=name)
5902
5903
5904@keras_export('keras.backend.foldr')
5905def foldr(fn, elems, initializer=None, name=None):
5906  """Reduce elems using fn to combine them from right to left.
5907
5908  Arguments:
5909      fn: Callable that will be called upon each element in elems and an
5910          accumulator, for instance `lambda acc, x: acc + x`
5911      elems: tensor
5912      initializer: The first value used (`elems[-1]` in case of None)
5913      name: A string name for the foldr node in the graph
5914
5915  Returns:
5916      Same type and shape as initializer
5917  """
5918  return functional_ops.foldr(fn, elems, initializer=initializer, name=name)
5919
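# Editor's sketch of the reduction helpers above (assumption: eager
# execution):
#
#   >>> foldl(lambda acc, x: acc + x, constant([1., 2., 3., 4.]))
#   # -> scalar tensor 10.0, i.e. ((1 + 2) + 3) + 4 folded from the left.
#   >>> foldr(lambda acc, x: acc - x, constant([1., 2., 3., 4.]))
#   # -> ((4 - 3) - 2) - 1 = -2.0 folded from the right.
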
5920# Load Keras default configuration from config file if present.
5921# Set Keras base dir path given KERAS_HOME env variable, if applicable.
5922# Otherwise defaults to ~/.keras.
5923if 'KERAS_HOME' in os.environ:
5924  _keras_dir = os.environ.get('KERAS_HOME')
5925else:
5926  _keras_base_dir = os.path.expanduser('~')
5927  _keras_dir = os.path.join(_keras_base_dir, '.keras')
5928_config_path = os.path.expanduser(os.path.join(_keras_dir, 'keras.json'))
5929if os.path.exists(_config_path):
5930  try:
5931    with open(_config_path) as fh:
5932      _config = json.load(fh)
5933  except ValueError:
5934    _config = {}
5935  _floatx = _config.get('floatx', floatx())
5936  assert _floatx in {'float16', 'float32', 'float64'}
5937  _epsilon = _config.get('epsilon', epsilon())
5938  assert isinstance(_epsilon, float)
5939  _image_data_format = _config.get('image_data_format', image_data_format())
5940  assert _image_data_format in {'channels_last', 'channels_first'}
5941  set_floatx(_floatx)
5942  set_epsilon(_epsilon)
5943  set_image_data_format(_image_data_format)
5944
5945# Save config file.
5946if not os.path.exists(_keras_dir):
5947  try:
5948    os.makedirs(_keras_dir)
5949  except OSError:
5950    # Except permission denied and potential race conditions
5951    # in multi-threaded environments.
5952    pass
5953
5954if not os.path.exists(_config_path):
5955  _config = {
5956      'floatx': floatx(),
5957      'epsilon': epsilon(),
5958      'backend': 'tensorflow',
5959      'image_data_format': image_data_format()
5960  }
5961  try:
5962    with open(_config_path, 'w') as f:
5963      f.write(json.dumps(_config, indent=4))
5964  except IOError:
5965    # Except permission denied.
5966    pass
5967
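# Editor's note: the resulting ~/.keras/keras.json looks like the following
# (values shown are the library defaults):
#
#   {
#       "floatx": "float32",
#       "epsilon": 1e-07,
#       "backend": "tensorflow",
#       "image_data_format": "channels_last"
#   }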
5968
5969def configure_and_create_distributed_session(distribution_strategy):
5970  """Configure session config and create a session with it."""
5971
5972  def _create_session(distribution_strategy):
5973    """Create the Distributed Strategy session."""
5974    session_config = get_default_session_config()
5975
5976    # If a session already exists, merge in its config; in the case there is a
5977    # conflict, take values of the existing config.
5978    global _SESSION
5979    if getattr(_SESSION, 'session', None) and _SESSION.session._config:
5980      session_config.MergeFrom(_SESSION.session._config)
5981
5982    if is_tpu_strategy(distribution_strategy):
5983      # TODO(priyag, yuefengz): Remove this workaround when Distribute
5984      # Coordinator is integrated with keras and we can create a session from
5985      # there.
5986      distribution_strategy.configure(session_config)
5987      master = distribution_strategy.extended._tpu_cluster_resolver.master()  # pylint: disable=protected-access
5988      session = session_module.Session(config=session_config, target=master)
5989    else:
5990      worker_context = dc_context.get_current_worker_context()
5991      if worker_context:
5992        dc_session_config = worker_context.session_config
5993        # Merge the default session config to the one from distribute
5994        # coordinator, which is fine for now since they don't have
5995        # conflicting configurations.
5996        dc_session_config.MergeFrom(session_config)
5997        session = session_module.Session(
5998            config=dc_session_config, target=worker_context.master_target)
5999      else:
6000        distribution_strategy.configure(session_config)
6001        session = session_module.Session(config=session_config)
6002
6003    set_session(session)
6004
6005  if distribution_strategy.extended._in_multi_worker_mode():
6006    dc.run_distribute_coordinator(
6007        _create_session,
6008        distribution_strategy,
6009        mode=dc.CoordinatorMode.INDEPENDENT_WORKER)
6010  else:
6011    _create_session(distribution_strategy)
6012
6013
6014def is_tpu_strategy(strategy):
6015  """We're executing TPU Strategy."""
6016  return (strategy is not None and
6017          strategy.__class__.__name__.startswith('TPUStrategy'))
6018
6019
6020def cast_variables_to_tensor(tensors):
6021
6022  def _cast_variables_to_tensor(tensor):
6023    if isinstance(tensor, variables_module.Variable):
6024      return array_ops.identity(tensor)
6025    return tensor
6026
6027  return nest.map_structure(_cast_variables_to_tensor, tensors)
6028
6029
6030def _is_symbolic_tensor(x):
6031  return tensor_util.is_tensor(x) and not isinstance(x, ops.EagerTensor)
6032
6033
6034def convert_inputs_if_ragged(inputs):
6035  """Converts any ragged tensors to dense."""
6036
6037  def _convert_ragged_input(inputs):
6038    if isinstance(inputs, ragged_tensor.RaggedTensor):
6039      return inputs.to_tensor()
6040    return inputs
6041
6042  flat_inputs = nest.flatten(inputs)
6043  contains_ragged = py_any(
6044      isinstance(i, ragged_tensor.RaggedTensor) for i in flat_inputs)
6045
6046  if not contains_ragged:
6047    return inputs, None
6048
6049  inputs = nest.map_structure(_convert_ragged_input, inputs)
6050  # Multiple masks are not yet supported, so one mask is used on all inputs.
6051  # We approach this similarly when using row lengths to ignore steps.
6052  nested_row_lengths = math_ops.cast(flat_inputs[0].nested_row_lengths()[0],
6053                                     'int32')
6054  return inputs, nested_row_lengths
6055
6056
6057def maybe_convert_to_ragged(is_ragged_input, output, nested_row_lengths):
6058  """Converts any ragged input back to its initial structure."""
6059  if not is_ragged_input:
6060    return output
6061
6062  return ragged_tensor.RaggedTensor.from_tensor(output, nested_row_lengths)
6063
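# Editor's sketch (assumption: eager execution) of the ragged round trip
# implemented by the two helpers above:
#
#   >>> rt = ragged_tensor.RaggedTensor.from_row_lengths(
#   ...     values=[1., 2., 3., 4., 5.], row_lengths=[3, 2])
#   >>> dense, row_lengths = convert_inputs_if_ragged(rt)
#   >>> # `dense` is the padded (2, 3) tensor [[1., 2., 3.], [4., 5., 0.]] and
#   >>> # `row_lengths` is [3, 2].
#   >>> restored = maybe_convert_to_ragged(True, dense, row_lengths)
#   >>> # `restored` has the same row lengths as the original `rt`.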