1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 # ==============================================================================
15 # pylint: disable=protected-access
16 # pylint: disable=redefined-outer-name
17 # pylint: disable=redefined-builtin
18 """Keras backend API.
19 """
20 from __future__ import absolute_import
21 from __future__ import division
22 from __future__ import print_function
23 
24 import collections
25 import itertools
26 import json
27 import os
28 import sys
29 import threading
30 import weakref
31 
32 import numpy as np
33 
34 from tensorflow.core.protobuf import config_pb2
35 from tensorflow.python import tf2
36 from tensorflow.python.client import session as session_module
37 from tensorflow.python.distribute import distribute_coordinator as dc
38 from tensorflow.python.distribute import distribute_coordinator_context as dc_context
39 from tensorflow.python.distribute import distribution_strategy_context
40 from tensorflow.python.eager import context
41 from tensorflow.python.eager import function as eager_function
42 from tensorflow.python.eager import lift_to_graph
43 from tensorflow.python.framework import composite_tensor
44 from tensorflow.python.framework import config
45 from tensorflow.python.framework import constant_op
46 from tensorflow.python.framework import device as tfdev
47 from tensorflow.python.framework import dtypes as dtypes_module
48 from tensorflow.python.framework import func_graph
49 from tensorflow.python.framework import ops
50 from tensorflow.python.framework import sparse_tensor
51 from tensorflow.python.framework import tensor_shape
52 from tensorflow.python.framework import tensor_util
53 from tensorflow.python.keras import backend_config
54 from tensorflow.python.ops import array_ops
55 from tensorflow.python.ops import clip_ops
56 from tensorflow.python.ops import control_flow_ops
57 from tensorflow.python.ops import control_flow_util
58 from tensorflow.python.ops import ctc_ops as ctc
59 from tensorflow.python.ops import functional_ops
60 from tensorflow.python.ops import gradients as gradients_module
61 from tensorflow.python.ops import image_ops
62 from tensorflow.python.ops import init_ops
63 from tensorflow.python.ops import linalg_ops
64 from tensorflow.python.ops import logging_ops
65 from tensorflow.python.ops import map_fn as map_fn_lib
66 from tensorflow.python.ops import math_ops
67 from tensorflow.python.ops import nn
68 from tensorflow.python.ops import random_ops
69 from tensorflow.python.ops import sparse_ops
70 from tensorflow.python.ops import state_ops
71 from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
72 from tensorflow.python.ops import tensor_array_ops
73 from tensorflow.python.ops import variables as variables_module
74 from tensorflow.python.ops.ragged import ragged_concat_ops
75 from tensorflow.python.ops.ragged import ragged_tensor
76 from tensorflow.python.platform import tf_logging as logging
77 from tensorflow.python.training import moving_averages
78 from tensorflow.python.util import nest
79 from tensorflow.python.util import object_identity
80 from tensorflow.python.util import tf_contextlib
81 from tensorflow.python.util import tf_inspect
82 from tensorflow.python.util.tf_export import keras_export
83 
84 py_all = all
85 py_sum = sum
86 py_any = any
87 
88 # INTERNAL UTILS
89 
90 # The internal graph maintained by Keras and used by the symbolic Keras APIs
91 # while executing eagerly (such as the functional API for model-building).
92 _GRAPH = None
93 
94 # A graph which is used for constructing functions in eager mode.
95 _CURRENT_SCRATCH_GRAPH = None
96 
97 # This is a thread local object that will hold the default internal TF session
98 # used by Keras. It can be set manually via `set_session(sess)`.
99 _SESSION = threading.local()
100 
101 # This dictionary holds a mapping {graph: learning_phase}.
102 # A learning phase is a bool tensor used to run Keras models in
103 # either train mode (learning_phase == 1) or test mode (learning_phase == 0).
104 _GRAPH_LEARNING_PHASES = weakref.WeakKeyDictionary()
105 
106 # This dictionary holds a mapping {graph: set_of_freezable_variables}.
107 # Each set tracks objects created via `freezable_variable` in the graph.
108 _FREEZABLE_VARS = weakref.WeakKeyDictionary()
109 
110 
111 # _DUMMY_EAGER_GRAPH.key is used as a key in _GRAPH_LEARNING_PHASES.
112 # We keep a separate reference to it to make sure it does not get removed from
113 # _GRAPH_LEARNING_PHASES.
114 # _DummyEagerGraph inherits from threading.local to make its `key` attribute
115 # thread local. This is needed to make set_learning_phase affect only the
116 # current thread during eager execution (see b/123096885 for more details).
117 class _DummyEagerGraph(threading.local):
118   """_DummyEagerGraph provides a thread local `key` attribute.
119 
120   We can't use threading.local directly, i.e. without subclassing, because
121   gevent monkey patches threading.local and its version does not support
122   weak references.
123   """
124 
125   class _WeakReferencableClass(object):
126     """This dummy class is needed for two reasons.
127 
128     - We need something that supports weak references. Basic types like string
129     and ints don't.
130     - We need something whose hash and equality are based on object identity
131     to make sure they are treated as different keys to _GRAPH_LEARNING_PHASES.
132 
133     An empty Python class satisfies both of these requirements.
134     """
135     pass
136 
137   def __init__(self):
138     # Constructors for classes subclassing threading.local run once
139     # per thread accessing something in the class. Thus, each thread will
140     # get a different key.
141     super(_DummyEagerGraph, self).__init__()
142     self.key = _DummyEagerGraph._WeakReferencableClass()
143 
144 
145 _DUMMY_EAGER_GRAPH = _DummyEagerGraph()
146 
147 # This boolean flag can be set to True to leave variable initialization
148 # up to the user.
149 # Change its value via `manual_variable_initialization(value)`.
150 _MANUAL_VAR_INIT = False
151 
152 # This list holds the available devices.
153 # It is populated when `_get_available_gpus()` is called for the first time.
154 # We assume our devices don't change henceforth.
155 _LOCAL_DEVICES = None
156 
157 # This dictionary holds a mapping between a graph and variables to initialize
158 # in the graph.
159 _GRAPH_VARIABLES = weakref.WeakKeyDictionary()
160 
161 # This dictionary holds a mapping between a graph and TF optimizers created in
162 # the graph.
163 _GRAPH_TF_OPTIMIZERS = weakref.WeakKeyDictionary()
164 
165 # The functions below are kept accessible from the backend module for compatibility.
166 epsilon = backend_config.epsilon
167 floatx = backend_config.floatx
168 image_data_format = backend_config.image_data_format
169 set_epsilon = backend_config.set_epsilon
170 set_floatx = backend_config.set_floatx
171 set_image_data_format = backend_config.set_image_data_format
172 
173 
174 @keras_export('keras.backend.backend')
175 def backend():
176   """Publicly accessible method for determining the current backend.
177 
178   Only exists for API compatibility with multi-backend Keras.
179 
180   Returns:
181       The string "tensorflow".
182   """
183   return 'tensorflow'
184 
185 
186 @keras_export('keras.backend.cast_to_floatx')
187 def cast_to_floatx(x):
188   """Cast a Numpy array to the default Keras float type.
189 
190   Arguments:
191       x: Numpy array or TensorFlow tensor.
192 
193   Returns:
194       The same array (Numpy array if `x` was a Numpy array, or TensorFlow tensor
195       if `x` was a tensor), cast to its new type.
196 
197   Example:
198 
199   >>> tf.keras.backend.floatx()
200   'float32'
201   >>> arr = np.array([1.0, 2.0], dtype='float64')
202   >>> arr.dtype
203   dtype('float64')
204   >>> new_arr = cast_to_floatx(arr)
205   >>> new_arr
206   array([1.,  2.], dtype=float32)
207   >>> new_arr.dtype
208   dtype('float32')
209 
210   """
211   if isinstance(x, (ops.Tensor,
212                     variables_module.Variable,
213                     sparse_tensor.SparseTensor)):
214     return math_ops.cast(x, dtype=floatx())
215   return np.asarray(x, dtype=floatx())
216 
217 
218 # A global dictionary mapping graph objects to an index of counters used
219 # for various layer/optimizer names in each graph.
220 # Allows giving unique autogenerated names to layers in a graph-specific way.
221 PER_GRAPH_OBJECT_NAME_UIDS = weakref.WeakKeyDictionary()
222 
223 
224 @keras_export('keras.backend.get_uid')
225 def get_uid(prefix=''):
226   """Associates a string prefix with an integer counter in a TensorFlow graph.
227 
228   Arguments:
229     prefix: String prefix to index.
230 
231   Returns:
232     Unique integer ID.
233 
234   Example:
235 
236   >>> get_uid('dense')
237   1
238   >>> get_uid('dense')
239   2
240 
241   """
242   graph = get_graph()
243   if graph not in PER_GRAPH_OBJECT_NAME_UIDS:
244     PER_GRAPH_OBJECT_NAME_UIDS[graph] = collections.defaultdict(int)
245   layer_name_uids = PER_GRAPH_OBJECT_NAME_UIDS[graph]
246   layer_name_uids[prefix] += 1
247   return layer_name_uids[prefix]
248 
249 
250 @keras_export('keras.backend.reset_uids')
251 def reset_uids():
252   """Resets graph identifiers.
253   """
254 
255   PER_GRAPH_OBJECT_NAME_UIDS.clear()
256 
257 
258 @keras_export('keras.backend.clear_session')
259 def clear_session():
260   """Destroys the current TF graph and creates a new one.
261 
262   Useful to avoid clutter from old models / layers.
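
  Example:

  A minimal illustrative sketch (assumes eager execution; the loop only exists
  to show that old models / layers do not accumulate across iterations):

  >>> for _ in range(3):
  ...   model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  ...   tf.keras.backend.clear_session()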
263   """
264   global _SESSION
265   global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
266   global _GRAPH_VARIABLES  # pylint: disable=global-variable-not-assigned
267   global _GRAPH_TF_OPTIMIZERS  # pylint: disable=global-variable-not-assigned
268   global _GRAPH
269   global _FREEZABLE_VARS
270   _GRAPH = None
271   ops.reset_default_graph()
272   reset_uids()
273   _SESSION.session = None
274   graph = get_graph()
275   with graph.as_default():
276     with name_scope(''):
277       phase = array_ops.placeholder_with_default(
278           False, shape=(), name='keras_learning_phase')
279     _GRAPH_LEARNING_PHASES = {}
280     _GRAPH_LEARNING_PHASES[graph] = phase
281     _GRAPH_VARIABLES.pop(graph, None)
282     _GRAPH_TF_OPTIMIZERS.pop(graph, None)
283     _FREEZABLE_VARS.pop(graph, None)
284 
285 
286 @keras_export('keras.backend.manual_variable_initialization')
287 def manual_variable_initialization(value):
288   """Sets the manual variable initialization flag.
289 
290   This boolean flag determines whether
291   variables should be initialized
292   as they are instantiated (default), or if
293   the user should handle the initialization
294   (e.g. via `tf.compat.v1.initialize_all_variables()`).
295 
296   Arguments:
297       value: Python boolean.
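
  Example:

  An illustrative sketch (assumes TF 1.x-style graph/session execution, where
  Keras would otherwise initialize variables on first use):

  >>> tf.keras.backend.manual_variable_initialization(True)
  >>> # ... build a model, then run the variable initializers yourself ...
  >>> tf.keras.backend.manual_variable_initialization(False)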
298   """
299   global _MANUAL_VAR_INIT
300   _MANUAL_VAR_INIT = value
301 
302 
303 @keras_export('keras.backend.learning_phase')
304 def learning_phase():
305   """Returns the learning phase flag.
306 
307   The learning phase flag is a bool tensor (0 = test, 1 = train)
308   to be passed as input to any Keras function
309   that uses a different behavior at train time and test time.
310 
311   Returns:
312       Learning phase (scalar integer tensor or Python integer).
313   """
314   graph = ops.get_default_graph()
315   if graph is _GRAPH:
316     # Don't enter an init_scope for the learning phase if eager execution
317     # is enabled but we're inside the Keras workspace graph.
318     learning_phase = symbolic_learning_phase()
319     _mark_func_graph_as_unsaveable(graph, learning_phase)
320     return learning_phase
321   with ops.init_scope():
322     # We always check & set the learning phase inside the init_scope,
323     # otherwise the wrong default_graph will be used to look up the learning
324     # phase inside of functions & defuns.
325     #
326     # This is because functions & defuns (both in graph & in eager mode)
327     # will always execute non-eagerly using a function-specific default
328     # subgraph.
329     if context.executing_eagerly():
330       if _DUMMY_EAGER_GRAPH.key not in _GRAPH_LEARNING_PHASES:
331         # Fallback to inference mode as default.
332         return 0
333       return _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key]
334     learning_phase = symbolic_learning_phase()
335     _mark_func_graph_as_unsaveable(graph, learning_phase)
336     return learning_phase
337 
338 
339 def global_learning_phase_is_set():
340   return _DUMMY_EAGER_GRAPH.key in _GRAPH_LEARNING_PHASES
341 
342 
343 def _mark_func_graph_as_unsaveable(graph, learning_phase):
344   """Mark func graph as unsaveable due to use of symbolic keras learning phase.
345 
346   Functions that capture the symbolic learning phase cannot be exported to
347   SavedModel. Mark the funcgraph as unsaveable, so that an error will be raised
348   if it is exported.
349 
350   Args:
351     graph: Graph or FuncGraph object.
352     learning_phase: Learning phase placeholder or int defined in the graph.
353   """
354   if graph.building_function and is_placeholder(learning_phase):
355     graph.mark_as_unsaveable(
356         'The keras learning phase placeholder was used inside a function. '
357         'Exporting placeholders is not supported when saving out a SavedModel. '
358         'Please call `tf.keras.backend.set_learning_phase(0)` in the function '
359         'to set the learning phase to a constant value.')
360 
361 
362 def symbolic_learning_phase():
363   graph = get_graph()
364   with graph.as_default():
365     if graph not in _GRAPH_LEARNING_PHASES:
366       with name_scope(''):
367         phase = array_ops.placeholder_with_default(
368             False, shape=(), name='keras_learning_phase')
369       _GRAPH_LEARNING_PHASES[graph] = phase
370     return _GRAPH_LEARNING_PHASES[graph]
371 
372 
373 @keras_export('keras.backend.set_learning_phase')
374 def set_learning_phase(value):
375   """Sets the learning phase to a fixed value.
376 
377   Arguments:
378       value: Learning phase value, either 0 or 1 (integers).
379              0 = test, 1 = train
380 
381   Raises:
382       ValueError: if `value` is neither `0` nor `1`.
383   """
384   global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
385   if value not in {0, 1}:
386     raise ValueError('Expected learning phase to be 0 or 1.')
387   with ops.init_scope():
388     if context.executing_eagerly():
389       # In an eager context, the learning phase value applies to both the eager
390       # context and the internal Keras graph.
391       _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key] = value
392     _GRAPH_LEARNING_PHASES[get_graph()] = value
393 
394 
395 @keras_export('keras.backend.learning_phase_scope')
396 @tf_contextlib.contextmanager
397 def learning_phase_scope(value):
398   """Provides a scope within which the learning phase is equal to `value`.
399 
400   The learning phase gets restored to its original value upon exiting the scope.
401 
402   Arguments:
403      value: Learning phase value, either 0 or 1 (integers).
404             0 = test, 1 = train
405 
406   Yields:
407     None.
408 
409   Raises:
410      ValueError: if `value` is neither `0` nor `1`.
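
  Example:

  A minimal illustrative sketch (assumes eager execution and that no global
  learning phase was set beforehand):

  >>> with tf.keras.backend.learning_phase_scope(1):
  ...   print(tf.keras.backend.learning_phase())
  1
  >>> tf.keras.backend.learning_phase()
  0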
411   """
412   global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
413   if value not in {0, 1}:
414     raise ValueError('Expected learning phase to be 0 or 1.')
415 
416   with ops.init_scope():
417     if context.executing_eagerly():
418       previous_eager_value = _GRAPH_LEARNING_PHASES.get(
419           _DUMMY_EAGER_GRAPH.key, None)
420     previous_graph_value = _GRAPH_LEARNING_PHASES.get(get_graph(), None)
421 
422   try:
423     set_learning_phase(value)
424     yield
425   finally:
426     # Restore learning phase to initial value.
427     with ops.init_scope():
428       if context.executing_eagerly():
429         if previous_eager_value is not None:
430           _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key] = previous_eager_value
431         elif _DUMMY_EAGER_GRAPH.key in _GRAPH_LEARNING_PHASES:
432           del _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key]
433 
434       graph = get_graph()
435       if previous_graph_value is not None:
436         _GRAPH_LEARNING_PHASES[graph] = previous_graph_value
437       elif graph in _GRAPH_LEARNING_PHASES:
438         del _GRAPH_LEARNING_PHASES[graph]
439 
440 
441 @tf_contextlib.contextmanager
442 def eager_learning_phase_scope(value):
443   """Internal scope that sets the learning phase in eager / tf.function only.
444 
445   Arguments:
446       value: Learning phase value, either 0 or 1 (integers).
447              0 = test, 1 = train
448 
449   Yields:
450     None.
451 
452   Raises:
453      ValueError: if `value` is neither `0` nor `1`.
454   """
455   global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
456   assert value in {0, 1}
457   assert ops.executing_eagerly_outside_functions()
458   global_learning_phase_was_set = global_learning_phase_is_set()
459   if global_learning_phase_was_set:
460     previous_value = learning_phase()
461   try:
462     _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key] = value
463     yield
464   finally:
465     # Restore learning phase to initial value or unset.
466     if global_learning_phase_was_set:
467       _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key] = previous_value
468     else:
469       del _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key]
470 
471 
472 def _current_graph(op_input_list):
473   """Return the graph members of `op_input_list`, or the current graph."""
474   return ops._get_graph_from_inputs(op_input_list)
475 
476 
477 def _get_session(op_input_list=()):
478   """Returns the session object for the current thread."""
479   global _SESSION
480   default_session = ops.get_default_session()
481   if default_session is not None:
482     session = default_session
483   else:
484     if ops.inside_function():
485       raise RuntimeError('Cannot get session inside TensorFlow graph function.')
486     # If we don't have a session, or that session does not match the current
487     # graph, create and cache a new session.
488     if (getattr(_SESSION, 'session', None) is None or
489         _SESSION.session.graph is not _current_graph(op_input_list)):
490       # If we are creating the Session inside a tf.distribute.Strategy scope,
491       # we ask the strategy for the right session options to use.
492       if distribution_strategy_context.has_strategy():
493         configure_and_create_distributed_session(
494             distribution_strategy_context.get_strategy())
495       else:
496         _SESSION.session = session_module.Session(
497             config=get_default_session_config())
498     session = _SESSION.session
499   return session
500 
501 
502 @keras_export(v1=['keras.backend.get_session'])
503 def get_session(op_input_list=()):
504   """Returns the TF session to be used by the backend.
505 
506   If a default TensorFlow session is available, we will return it.
507 
508   Else, we will return the global Keras session assuming it matches
509   the current graph.
510 
511   If no global Keras session exists at this point:
512   we will create a new global session.
513 
514   Note that you can manually set the global session
515   via `K.set_session(sess)`.
516 
517   Arguments:
518       op_input_list: An optional sequence of tensors or ops, which will be used
519         to determine the current graph. Otherwise the default graph will be
520         used.
521 
522   Returns:
523       A TensorFlow session.
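
  Example:

  An illustrative sketch (assumes TF 1.x-style graph execution, e.g. after
  `tf.compat.v1.disable_eager_execution()`; the cached session is reused while
  the default graph is unchanged):

  >>> sess = tf.compat.v1.keras.backend.get_session()
  >>> sess is tf.compat.v1.keras.backend.get_session()
  True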
524   """
525   session = _get_session(op_input_list)
526   if not _MANUAL_VAR_INIT:
527     with session.graph.as_default():
528       _initialize_variables(session)
529   return session
530 
531 
532 def get_graph():
533   if context.executing_eagerly():
534     global _GRAPH
535     if _GRAPH is None:
536       _GRAPH = func_graph.FuncGraph('keras_graph')
537     return _GRAPH
538   else:
539     return ops.get_default_graph()
540 
541 
542 @tf_contextlib.contextmanager
543 def _scratch_graph(graph=None):
544   """Retrieve a shared and temporary func graph.
545 
546   The eager execution path lifts a subgraph from the keras global graph into
547   a scratch graph in order to create a function. DistributionStrategies, in
548   turn, constructs multiple functions as well as a final combined function. In
549   order for that logic to work correctly, all of the functions need to be
550   created on the same scratch FuncGraph.
551 
552   Args:
553     graph: A graph to be used as the current scratch graph. If not set then
554       a scratch graph will either be retrieved or created.
555 
556   Yields:
557     The current scratch graph.
558   """
559   global _CURRENT_SCRATCH_GRAPH
560   if (_CURRENT_SCRATCH_GRAPH is not None and graph is not None and
561       _CURRENT_SCRATCH_GRAPH is not graph):
562     raise ValueError('Multiple scratch graphs specified.')
563 
564   if _CURRENT_SCRATCH_GRAPH:
565     yield _CURRENT_SCRATCH_GRAPH
566     return
567 
568   graph = graph or func_graph.FuncGraph('keras_scratch_graph')
569   try:
570     _CURRENT_SCRATCH_GRAPH = graph
571     yield graph
572   finally:
573     _CURRENT_SCRATCH_GRAPH = None
574 
575 
576 @keras_export(v1=['keras.backend.set_session'])
577 def set_session(session):
578   """Sets the global TensorFlow session.
579 
580   Arguments:
581       session: A TF Session.
582   """
583   global _SESSION
584   _SESSION.session = session
585 
586 
587 def get_default_session_config():
588   if os.environ.get('OMP_NUM_THREADS'):
589     logging.warning(
590         'OMP_NUM_THREADS is no longer used by the default Keras config. '
591         'To configure the number of threads, use tf.config.threading APIs.')
592 
593   config = context.context().config
594   config.allow_soft_placement = True
595 
596   return config
597 
598 
599 def get_default_graph_uid_map():
600   graph = ops.get_default_graph()
601   name_uid_map = PER_GRAPH_OBJECT_NAME_UIDS.get(graph, None)
602   if name_uid_map is None:
603     name_uid_map = collections.defaultdict(int)
604     PER_GRAPH_OBJECT_NAME_UIDS[graph] = name_uid_map
605   return name_uid_map
606 
607 
608 # DEVICE MANIPULATION
609 
610 
611 class _TfDeviceCaptureOp(object):
612   """Class for capturing the TF device scope."""
613 
614   def __init__(self):
615     self.device = None
616 
617   def _set_device(self, device):
618     """This method captures TF's explicit device scope setting."""
619     if tfdev.is_device_spec(device):
620       device = device.to_string()
621     self.device = device
622 
623   def _set_device_from_string(self, device_str):
624     self.device = device_str
625 
626 
627 def _get_current_tf_device():
628   """Return explicit device of current context, otherwise returns `None`.
629 
630   Returns:
631       If the current device scope is explicitly set, it returns a string with
632       the device (`CPU` or `GPU`). If the scope is not explicitly set, it will
633       return `None`.
634   """
635   graph = get_graph()
636   op = _TfDeviceCaptureOp()
637   graph._apply_device_functions(op)
638   return tfdev.DeviceSpec.from_string(op.device)
639 
640 
641 def _is_current_explicit_device(device_type):
642   """Check if the current device is explicitly set on the device type specified.
643 
644   Arguments:
645       device_type: A string containing `GPU` or `CPU` (case-insensitive).
646 
647   Returns:
648       A boolean indicating if the current device scope is explicitly set on the
649       device type.
650 
651   Raises:
652       ValueError: If the `device_type` string indicates an unsupported device.
653   """
654   device_type = device_type.upper()
655   if device_type not in ['CPU', 'GPU']:
656     raise ValueError('`device_type` should be either "CPU" or "GPU".')
657   device = _get_current_tf_device()
658   return device is not None and device.device_type == device_type.upper()
659 
660 
661 def _get_available_gpus():
662   """Get a list of available gpu devices (formatted as strings).
663 
664   Returns:
665       A list of available GPU devices.
666   """
667   if ops.executing_eagerly_outside_functions():
668     # Returns names of devices directly.
669     return [d.name for d in config.list_logical_devices('GPU')]
670 
671   global _LOCAL_DEVICES
672   if _LOCAL_DEVICES is None:
673     _LOCAL_DEVICES = get_session().list_devices()
674   return [x.name for x in _LOCAL_DEVICES if x.device_type == 'GPU']
675 
676 
677 def _has_nchw_support():
678   """Check whether the current scope supports NCHW ops.
679 
680   TensorFlow does not support NCHW on CPU. Therefore we check whether we are
681   not explicitly placed on
682   CPU and have GPUs available. In this case there will be soft-placement on the
683   GPU device.
684 
685   Returns:
686       bool: True if the current scope's device placement would support NCHW ops.
687   """
688   explicitly_on_cpu = _is_current_explicit_device('CPU')
689   gpus_available = bool(_get_available_gpus())
690   return not explicitly_on_cpu and gpus_available
691 
692 
693 # VARIABLE MANIPULATION
694 
695 
696 def _constant_to_tensor(x, dtype):
697   """Convert the input `x` to a tensor of type `dtype`.
698 
699   This is slightly faster than the _to_tensor function, at the cost of
700   handling fewer cases.
701 
702   Arguments:
703       x: An object to be converted (numpy arrays, floats, ints and lists of
704         them).
705       dtype: The destination type.
706 
707   Returns:
708       A tensor.
709   """
710   return constant_op.constant(x, dtype=dtype)
711 
712 
713 def _to_tensor(x, dtype):
714   """Convert the input `x` to a tensor of type `dtype`.
715 
716   Arguments:
717       x: An object to be converted (numpy array, list, tensors).
718       dtype: The destination type.
719 
720   Returns:
721       A tensor.
722   """
723   return ops.convert_to_tensor(x, dtype=dtype)
724 
725 
726 @keras_export('keras.backend.is_sparse')
727 def is_sparse(tensor):
728   """Returns whether a tensor is a sparse tensor.
729 
730   Arguments:
731       tensor: A tensor instance.
732 
733   Returns:
734       A boolean.
735 
736   Example:
737 
738 
739   >>> a = tf.keras.backend.placeholder((2, 2), sparse=False)
740   >>> print(tf.keras.backend.is_sparse(a))
741   False
742   >>> b = tf.keras.backend.placeholder((2, 2), sparse=True)
743   >>> print(tf.keras.backend.is_sparse(b))
744   True
745 
746   """
747   return isinstance(tensor, sparse_tensor.SparseTensor)
748 
749 
750 @keras_export('keras.backend.to_dense')
751 def to_dense(tensor):
752   """Converts a sparse tensor into a dense tensor and returns it.
753 
754   Arguments:
755       tensor: A tensor instance (potentially sparse).
756 
757   Returns:
758       A dense tensor.
759 
760   Examples:
761 
762 
763   >>> b = tf.keras.backend.placeholder((2, 2), sparse=True)
764   >>> print(tf.keras.backend.is_sparse(b))
765   True
766   >>> c = tf.keras.backend.to_dense(b)
767   >>> print(tf.keras.backend.is_sparse(c))
768   False
769 
770   """
771   if is_sparse(tensor):
772     return sparse_ops.sparse_tensor_to_dense(tensor)
773   else:
774     return tensor
775 
776 
777 @keras_export('keras.backend.name_scope', v1=[])
778 def name_scope(name):
779   """A context manager for use when defining a Python op.
780 
781   This context manager pushes a name scope, which will make the name of all
782   operations added within it have a prefix.
783 
784   For example, to define a new Python op called `my_op`:
785 
786 
787   def my_op(a):
788     with tf.name_scope("MyOp") as scope:
789       a = tf.convert_to_tensor(a, name="a")
790       # Define some computation that uses `a`.
791       return foo_op(..., name=scope)
792 
793 
794   When executed, the Tensor `a` will have the name `MyOp/a`.
795 
796   Args:
797     name: The prefix to use on all names created within the name scope.
798 
799   Returns:
800     Name scope context manager.
801   """
802   return ops.name_scope_v2(name)
803 
804 
805 @keras_export('keras.backend.variable')
806 def variable(value, dtype=None, name=None, constraint=None):
807   """Instantiates a variable and returns it.
808 
809   Arguments:
810       value: Numpy array, initial value of the tensor.
811       dtype: Tensor type.
812       name: Optional name string for the tensor.
813       constraint: Optional projection function to be
814           applied to the variable after an optimizer update.
815 
816   Returns:
817       A variable instance (with Keras metadata included).
818 
819   Examples:
820 
821   >>> val = np.array([[1, 2], [3, 4]])
822   >>> kvar = tf.keras.backend.variable(value=val, dtype='float64',
823   ...                                  name='example_var')
824   >>> tf.keras.backend.dtype(kvar)
825   'float64'
826   >>> print(kvar)
827   <tf.Variable 'example_var:...' shape=(2, 2) dtype=float64, numpy=
828     array([[1., 2.],
829            [3., 4.]])>
830 
831   """
832   if dtype is None:
833     dtype = floatx()
834   if hasattr(value, 'tocoo'):
835     sparse_coo = value.tocoo()
836     indices = np.concatenate((np.expand_dims(sparse_coo.row, 1), np.expand_dims(
837         sparse_coo.col, 1)), 1)
838     v = sparse_tensor.SparseTensor(
839         indices=indices, values=sparse_coo.data, dense_shape=sparse_coo.shape)
840     v._keras_shape = sparse_coo.shape
841     return v
842   v = variables_module.Variable(
843       value,
844       dtype=dtypes_module.as_dtype(dtype),
845       name=name,
846       constraint=constraint)
847   if isinstance(value, np.ndarray):
848     v._keras_shape = value.shape
849   elif hasattr(value, 'shape'):
850     v._keras_shape = int_shape(value)
851   track_variable(v)
852   return v
853 
854 
855 def track_tf_optimizer(tf_optimizer):
856   """Tracks the given TF optimizer for initialization of its variables."""
857   if context.executing_eagerly():
858     return
859   graph = get_graph()
860   optimizers = _GRAPH_TF_OPTIMIZERS.setdefault(graph, weakref.WeakSet())
861   optimizers.add(tf_optimizer)
862 
863 
864 def track_variable(v):
865   """Tracks the given variable for initialization."""
866   if context.executing_eagerly():
867     return
868   graph = v.graph if hasattr(v, 'graph') else get_graph()
869   if graph not in _GRAPH_VARIABLES:
870     _GRAPH_VARIABLES[graph] = object_identity.ObjectIdentityWeakSet()
871   _GRAPH_VARIABLES[graph].add(v)
872 
873 
874 def unique_object_name(name,
875                        name_uid_map=None,
876                        avoid_names=None,
877                        namespace='',
878                        zero_based=False):
879   """Makes a object name (or arbitrary string) unique within a TensorFlow graph.
880 
881   Arguments:
882     name: String name to make unique.
883     name_uid_map: An optional defaultdict(int) to use when creating unique
884       names. If None (default), uses a per-Graph dictionary.
885     avoid_names: An optional set or dict with names which should not be used. If
886       None (default) does not avoid any names.
887     namespace: Gets a name which is unique within the (graph, namespace). Layers
888       which are not Networks use a blank namespace and so get graph-global
889       names.
890     zero_based: If True, name sequences start with no suffix (e.g. "dense",
891       "dense_1"). If False, naming is one-based ("dense_1", "dense_2").
892 
893   Returns:
894     Unique string name.
895 
896   Example:
897 
898 
899   unique_object_name('dense')  # dense_1
900   unique_object_name('dense')  # dense_2
901 
902   """
903   if name_uid_map is None:
904     name_uid_map = get_default_graph_uid_map()
905   if avoid_names is None:
906     avoid_names = set()
907   proposed_name = None
908   while proposed_name is None or proposed_name in avoid_names:
909     name_key = (namespace, name)
910     if zero_based:
911       number = name_uid_map[name_key]
912       if number:
913         proposed_name = name + '_' + str(number)
914       else:
915         proposed_name = name
916       name_uid_map[name_key] += 1
917     else:
918       name_uid_map[name_key] += 1
919       proposed_name = name + '_' + str(name_uid_map[name_key])
920   return proposed_name
921 
922 
923 def _get_variables(graph=None):
924   """Returns variables corresponding to the given graph for initialization."""
925   assert not context.executing_eagerly()
926   variables = _GRAPH_VARIABLES.setdefault(graph, weakref.WeakSet())
927   for opt in _GRAPH_TF_OPTIMIZERS.get(graph, set()):
928     variables.update(opt.optimizer.variables())
929   return variables
930 
931 
932 def _initialize_variables(session):
933   """Utility to initialize uninitialized variables on the fly."""
934   variables = _get_variables(get_graph())
935   candidate_vars = []
936   for v in variables:
937     if not getattr(v, '_keras_initialized', False):
938       candidate_vars.append(v)
939   if candidate_vars:
940     # This step is expensive, so we only run it on variables not already
941     # marked as initialized.
942     is_initialized = session.run(
943         [variables_module.is_variable_initialized(v) for v in candidate_vars])
944     # TODO(kathywu): Some metric variables loaded from SavedModel are never
945     # actually used, and do not have an initializer.
946     should_be_initialized = [
947         (not is_initialized[n]) and v.initializer is not None
948         for n, v in enumerate(candidate_vars)]
949     uninitialized_vars = []
950     for flag, v in zip(should_be_initialized, candidate_vars):
951       if flag:
952         uninitialized_vars.append(v)
953       v._keras_initialized = True
954     if uninitialized_vars:
955       session.run(variables_module.variables_initializer(uninitialized_vars))
956 
957 
958 @keras_export('keras.backend.constant')
959 def constant(value, dtype=None, shape=None, name=None):
960   """Creates a constant tensor.
961 
962   Arguments:
963       value: A constant value (or list)
964       dtype: The type of the elements of the resulting tensor.
965       shape: Optional dimensions of resulting tensor.
966       name: Optional name for the tensor.
967 
968   Returns:
969       A Constant Tensor.
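
  Example:

  A minimal illustrative sketch (assumes eager execution):

  >>> kconstant = tf.keras.backend.constant([1, 2, 3])
  >>> kconstant
  <tf.Tensor: shape=(3,), dtype=float32, numpy=array([1., 2., 3.], dtype=float32)>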
970   """
971   if dtype is None:
972     dtype = floatx()
973 
974   return constant_op.constant(value, dtype=dtype, shape=shape, name=name)
975 
976 
977 @keras_export('keras.backend.is_keras_tensor')
978 def is_keras_tensor(x):
979   """Returns whether `x` is a Keras tensor.
980 
981   A "Keras tensor" is a tensor that was returned by a Keras layer,
982   (`Layer` class) or by `Input`.
983 
984   Arguments:
985       x: A candidate tensor.
986 
987   Returns:
988       A boolean: Whether the argument is a Keras tensor.
989 
990   Raises:
991       ValueError: In case `x` is not a symbolic tensor.
992 
993   Examples:
994 
995   >>> np_var = np.array([1, 2])
996   >>> # A numpy array is not a symbolic tensor.
997   >>> tf.keras.backend.is_keras_tensor(np_var)
998   Traceback (most recent call last):
999   ...
1000   ValueError: Unexpectedly found an instance of type `<class 'numpy.ndarray'>`.
1001   Expected a symbolic tensor instance.
1002   >>> keras_var = tf.keras.backend.variable(np_var)
1003   >>> # A variable created with the keras backend is not a Keras tensor.
1004   >>> tf.keras.backend.is_keras_tensor(keras_var)
1005   False
1006   >>> keras_placeholder = tf.keras.backend.placeholder(shape=(2, 4, 5))
1007   >>> # A placeholder is not a Keras tensor.
1008   >>> tf.keras.backend.is_keras_tensor(keras_placeholder)
1009   False
1010   >>> keras_input = tf.keras.layers.Input([10])
1011   >>> # An Input is a Keras tensor.
1012   >>> tf.keras.backend.is_keras_tensor(keras_input)
1013   True
1014   >>> keras_layer_output = tf.keras.layers.Dense(10)(keras_input)
1015   >>> # Any Keras layer output is a Keras tensor.
1016   >>> tf.keras.backend.is_keras_tensor(keras_layer_output)
1017   True
1018 
1019   """
1020   if not isinstance(x,
1021                     (ops.Tensor, variables_module.Variable,
1022                      sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)):
1023     raise ValueError('Unexpectedly found an instance of type `' + str(type(x)) +
1024                      '`. Expected a symbolic tensor instance.')
1025   return hasattr(x, '_keras_history')
1026 
1027 
1028 @keras_export('keras.backend.placeholder')
1029 def placeholder(shape=None,
1030                 ndim=None,
1031                 dtype=None,
1032                 sparse=False,
1033                 name=None,
1034                 ragged=False):
1035   """Instantiates a placeholder tensor and returns it.
1036 
1037   Arguments:
1038       shape: Shape of the placeholder
1039           (integer tuple, may include `None` entries).
1040       ndim: Number of axes of the tensor.
1041           At least one of {`shape`, `ndim`} must be specified.
1042           If both are specified, `shape` is used.
1043       dtype: Placeholder type.
1044       sparse: Boolean, whether the placeholder should have a sparse type.
1045       name: Optional name string for the placeholder.
1046       ragged: Boolean, whether the placeholder should have a ragged type.
1047           In this case, values of 'None' in the 'shape' argument represent
1048           ragged dimensions. For more information about RaggedTensors, see this
1049           [guide](https://www.tensorflow.org/guide/ragged_tensors).
1050 
1051   Raises:
1052       ValueError: If called with eager execution
1052       ValueError: If called with eager execution.
1054 
1055   Returns:
1056       Tensor instance (with Keras metadata included).
1057 
1058   Examples:
1059 
1060 
1061   >>> input_ph = tf.keras.backend.placeholder(shape=(2, 4, 5))
1062   >>> input_ph
1063   <tf.Tensor 'Placeholder_...' shape=(2, 4, 5) dtype=float32>
1064 
1065   """
1066   if sparse and ragged:
1067     raise ValueError(
1068         'Cannot set both sparse and ragged to True when creating a placeholder.'
1069     )
1070 
1071   if dtype is None:
1072     dtype = floatx()
1073   if not shape:
1074     if ndim:
1075       shape = (None,) * ndim
1076   with get_graph().as_default():
1077     if sparse:
1078       x = array_ops.sparse_placeholder(dtype, shape=shape, name=name)
1079     elif ragged:
1080       ragged_rank = 0
1081       for i in range(1, len(shape)):
1082         if shape[i] is None:
1083           ragged_rank = i
1084       type_spec = ragged_tensor.RaggedTensorSpec(
1085           shape=shape, dtype=dtype, ragged_rank=ragged_rank)
1086       def tensor_spec_to_placeholder(tensorspec):
1087         return array_ops.placeholder(tensorspec.dtype, tensorspec.shape)
1088       x = nest.map_structure(tensor_spec_to_placeholder, type_spec,
1089                              expand_composites=True)
1090     else:
1091       x = array_ops.placeholder(dtype, shape=shape, name=name)
1092   return x
1093 
1094 
1095 def is_placeholder(x):
1096   """Returns whether `x` is a placeholder.
1097 
1098   Arguments:
1099       x: A candidate placeholder.
1100 
1101   Returns:
1102       Boolean.
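
  Example:

  A minimal illustrative sketch (this helper is module-internal, so it is
  called directly rather than via the public `tf.keras.backend` namespace):

  >>> x = placeholder(shape=(2, 2))
  >>> is_placeholder(x)
  True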
1103   """
1104   try:
1105     if isinstance(x, composite_tensor.CompositeTensor):
1106       flat_components = nest.flatten(x, expand_composites=True)
1107       return py_any(is_placeholder(c) for c in flat_components)
1108     else:
1109       return x.op.type == 'Placeholder'
1110   except AttributeError:
1111     return False
1112 
1113 
1114 def freezable_variable(value, shape=None, name=None):
1115   """A tensor-like object whose value can be updated only up until execution.
1116 
1117   After creating the freezable variable, you can update its value by calling
1118   `var.update_value(new_value)` (similar to a regular variable).
1119   Unlike an actual variable, the value used during execution is the current
1120   value at the time the execution function (`backend.function()`) was created.
1121 
1122   This is an internal API, expected to be temporary. It is used to implement a
1123   mutable `trainable` property for `BatchNormalization` layers, with a frozen
1124   value after model compilation.
1125 
1126   We don't use a plain variable in this case because we need the value used
1127   in a specific model to be frozen after `compile` has been called
1128   (e.g. GAN use case).
1129 
1130   Arguments:
1131     value: The initial value for the tensor-like object.
1132     shape: The shape for the tensor-like object (cannot be changed).
1133     name: The name for the tensor-like object.
1134 
1135   Returns:
1136     A tensor-like object with a static value that can be updated via
1137     `x.update_value(new_value)`, up until creating an execution function
1138     (afterwards the value is fixed).
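
  Example:

  A minimal illustrative sketch (internal API; the value seen by an execution
  function is whatever `get_value()` returns when that function is created):

  >>> x = freezable_variable(1.0, shape=())
  >>> x.update_value(2.0)
  >>> x.get_value()
  2.0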
1139   """
1140   graph = get_graph()
1141   with graph.as_default():
1142     x = array_ops.placeholder_with_default(
1143         value, shape=shape, name=name)
1144     x._initial_value = value
1145     x._current_value = value
1146 
1147     def update_value(new_value):
1148       x._current_value = new_value
1149 
1150     def get_value():
1151       return x._current_value
1152 
1153     x.update_value = update_value
1154     x.get_value = get_value
1155 
1156     global _FREEZABLE_VARS
1157     if graph not in _FREEZABLE_VARS:
1158       _FREEZABLE_VARS[graph] = object_identity.ObjectIdentityWeakSet()
1159     _FREEZABLE_VARS[graph].add(x)
1160   return x
1161 
1162 
1163 @keras_export('keras.backend.shape')
1164 def shape(x):
1165   """Returns the symbolic shape of a tensor or variable.
1166 
1167   Arguments:
1168       x: A tensor or variable.
1169 
1170   Returns:
1171       A symbolic shape (which is itself a tensor).
1172 
1173   Examples:
1174 
1175   >>> val = np.array([[1, 2], [3, 4]])
1176   >>> kvar = tf.keras.backend.variable(value=val)
1177   >>> tf.keras.backend.shape(kvar)
1178   <tf.Tensor: shape=(2,), dtype=int32, numpy=array([2, 2], dtype=int32)>
1179   >>> input = tf.keras.backend.placeholder(shape=(2, 4, 5))
1180   >>> tf.keras.backend.shape(input)
1181   <tf.Tensor 'Shape_...' shape=(3,) dtype=int32>
1182 
1183   """
1184   return array_ops.shape(x)
1185 
1186 
1187 @keras_export('keras.backend.int_shape')
1188 def int_shape(x):
1189   """Returns the shape of tensor or variable as a tuple of int or None entries.
1190 
1191   Arguments:
1192       x: Tensor or variable.
1193 
1194   Returns:
1195       A tuple of integers (or None entries).
1196 
1197   Examples:
1198 
1199   >>> input = tf.keras.backend.placeholder(shape=(2, 4, 5))
1200   >>> tf.keras.backend.int_shape(input)
1201   (2, 4, 5)
1202   >>> val = np.array([[1, 2], [3, 4]])
1203   >>> kvar = tf.keras.backend.variable(value=val)
1204   >>> tf.keras.backend.int_shape(kvar)
1205   (2, 2)
1206 
1207   """
1208   try:
1209     shape = x.shape
1210     if not isinstance(shape, tuple):
1211       shape = tuple(shape.as_list())
1212     return shape
1213   except ValueError:
1214     return None
1215 
1216 
1217 @keras_export('keras.backend.ndim')
1218 def ndim(x):
1219   """Returns the number of axes in a tensor, as an integer.
1220 
1221   Arguments:
1222       x: Tensor or variable.
1223 
1224   Returns:
1225       Integer (scalar), number of axes.
1226 
1227   Examples:
1228 
1229 
1230   >>> input = tf.keras.backend.placeholder(shape=(2, 4, 5))
1231   >>> val = np.array([[1, 2], [3, 4]])
1232   >>> kvar = tf.keras.backend.variable(value=val)
1233   >>> tf.keras.backend.ndim(input)
1234   3
1235   >>> tf.keras.backend.ndim(kvar)
1236   2
1237 
1238   """
1239   dims = x.shape._dims
1240   if dims is not None:
1241     return len(dims)
1242   return None
1243 
1244 
1245 @keras_export('keras.backend.dtype')
1246 def dtype(x):
1247   """Returns the dtype of a Keras tensor or variable, as a string.
1248 
1249   Arguments:
1250       x: Tensor or variable.
1251 
1252   Returns:
1253       String, dtype of `x`.
1254 
1255   Examples:
1256 
1257   >>> tf.keras.backend.dtype(tf.keras.backend.placeholder(shape=(2,4,5)))
1258   'float32'
1259   >>> tf.keras.backend.dtype(tf.keras.backend.placeholder(shape=(2,4,5),
1260   ...                                                     dtype='float32'))
1261   'float32'
1262   >>> tf.keras.backend.dtype(tf.keras.backend.placeholder(shape=(2,4,5),
1263   ...                                                     dtype='float64'))
1264   'float64'
1265   >>> kvar = tf.keras.backend.variable(np.array([[1, 2], [3, 4]]))
1266   >>> tf.keras.backend.dtype(kvar)
1267   'float32'
1268   >>> kvar = tf.keras.backend.variable(np.array([[1, 2], [3, 4]]),
1269   ...                                  dtype='float32')
1270   >>> tf.keras.backend.dtype(kvar)
1271   'float32'
1272 
1273   """
1274   return x.dtype.base_dtype.name
1275 
1276 
1277 @keras_export('keras.backend.eval')
1278 def eval(x):
1279   """Evaluates the value of a variable.
1280 
1281   Arguments:
1282       x: A variable.
1283 
1284   Returns:
1285       A Numpy array.
1286 
1287   Examples:
1288 
1289   >>> kvar = tf.keras.backend.variable(np.array([[1, 2], [3, 4]]),
1290   ...                                  dtype='float32')
1291   >>> tf.keras.backend.eval(kvar)
1292   array([[1.,  2.],
1293          [3.,  4.]], dtype=float32)
1294 
1295   """
1296   return get_value(to_dense(x))
1297 
1298 
1299 @keras_export('keras.backend.zeros')
1300 def zeros(shape, dtype=None, name=None):
1301   """Instantiates an all-zeros variable and returns it.
1302 
1303   Arguments:
1304       shape: Tuple or list of integers, shape of returned Keras variable
1305       dtype: data type of returned Keras variable
1306       name: name of returned Keras variable
1307 
1308   Returns:
1309       A variable (including Keras metadata), filled with `0.0`.
1310       Note that if `shape` was symbolic, we cannot return a variable,
1311       and will return a dynamically-shaped tensor instead.
1312 
1313   Example:
1314 
1315   >>> kvar = tf.keras.backend.zeros((3,4))
1316   >>> tf.keras.backend.eval(kvar)
1317   array([[0.,  0.,  0.,  0.],
1318          [0.,  0.,  0.,  0.],
1319          [0.,  0.,  0.,  0.]], dtype=float32)
1320   >>> A = tf.constant([1,2,3])
1321   >>> kvar2 = tf.keras.backend.zeros(A.shape) # [0., 0., 0.]
1322   >>> tf.keras.backend.eval(kvar2)
1323   array([0., 0., 0.], dtype=float32)
1324   >>> kvar3 = tf.keras.backend.zeros(A.shape,dtype=tf.int32)
1325   >>> tf.keras.backend.eval(kvar3)
1326   array([0, 0, 0], dtype=int32)
1327   >>> kvar4 = tf.keras.backend.zeros([2,3])
1328   >>> tf.keras.backend.eval(kvar4)
1329   array([[0., 0., 0.],
1330          [0., 0., 0.]], dtype=float32)
1331 
1332   """
1333   with ops.init_scope():
1334     if dtype is None:
1335       dtype = floatx()
1336     tf_dtype = dtypes_module.as_dtype(dtype)
1337     v = array_ops.zeros(shape=shape, dtype=tf_dtype, name=name)
1338     if py_all(v.shape.as_list()):
1339       return variable(v, dtype=dtype, name=name)
1340     return v
1341 
1342 
1343 @keras_export('keras.backend.ones')
1344 def ones(shape, dtype=None, name=None):
1345   """Instantiates an all-ones variable and returns it.
1346 
1347   Arguments:
1348       shape: Tuple of integers, shape of returned Keras variable.
1349       dtype: String, data type of returned Keras variable.
1350       name: String, name of returned Keras variable.
1351 
1352   Returns:
1353       A Keras variable, filled with `1.0`.
1354       Note that if `shape` was symbolic, we cannot return a variable,
1355       and will return a dynamically-shaped tensor instead.
1356 
1357   Example:
1358 
1359 
1360   >>> kvar = tf.keras.backend.ones((3,4))
1361   >>> tf.keras.backend.eval(kvar)
1362   array([[1.,  1.,  1.,  1.],
1363          [1.,  1.,  1.,  1.],
1364          [1.,  1.,  1.,  1.]], dtype=float32)
1365 
1366   """
1367   with ops.init_scope():
1368     if dtype is None:
1369       dtype = floatx()
1370     tf_dtype = dtypes_module.as_dtype(dtype)
1371     v = array_ops.ones(shape=shape, dtype=tf_dtype, name=name)
1372     if py_all(v.shape.as_list()):
1373       return variable(v, dtype=dtype, name=name)
1374     return v
1375 
1376 
1377 @keras_export('keras.backend.eye')
1378 def eye(size, dtype=None, name=None):
1379   """Instantiate an identity matrix and returns it.
1380 
1381   Arguments:
1382       size: Integer, number of rows/columns.
1383       dtype: String, data type of returned Keras variable.
1384       name: String, name of returned Keras variable.
1385 
1386   Returns:
1387       A Keras variable, an identity matrix.
1388 
1389   Example:
1390 
1391 
1392   >>> kvar = tf.keras.backend.eye(3)
1393   >>> tf.keras.backend.eval(kvar)
1394   array([[1.,  0.,  0.],
1395          [0.,  1.,  0.],
1396          [0.,  0.,  1.]], dtype=float32)
1397 
1398 
1399   """
1400   if dtype is None:
1401     dtype = floatx()
1402   tf_dtype = dtypes_module.as_dtype(dtype)
1403   return variable(linalg_ops.eye(size, dtype=tf_dtype), dtype, name)
1404 
1405 
1406 @keras_export('keras.backend.zeros_like')
1407 def zeros_like(x, dtype=None, name=None):
1408   """Instantiates an all-zeros variable of the same shape as another tensor.
1409 
1410   Arguments:
1411       x: Keras variable or Keras tensor.
1412       dtype: dtype of returned Keras variable.
1413              `None` uses the dtype of `x`.
1414       name: name for the variable to create.
1415 
1416   Returns:
1417       A Keras variable with the shape of `x` filled with zeros.
1418 
1419   Example:
1420 
1421 
1422   >>> kvar = tf.keras.backend.variable(np.random.random((2,3)))
1423   >>> kvar_zeros = tf.keras.backend.zeros_like(kvar)
1424   >>> tf.keras.backend.eval(kvar_zeros)
1425   array([[0.,  0.,  0.],
1426          [0.,  0.,  0.]], dtype=float32)
1427 
1428 
1429   """
1430   return array_ops.zeros_like(x, dtype=dtype, name=name)
1431 
1432 
1433 @keras_export('keras.backend.ones_like')
1434 def ones_like(x, dtype=None, name=None):
1435   """Instantiates an all-ones variable of the same shape as another tensor.
1436 
1437   Arguments:
1438       x: Keras variable or tensor.
1439       dtype: String, dtype of returned Keras variable.
1440            None uses the dtype of x.
1441       name: String, name for the variable to create.
1442 
1443   Returns:
1444       A Keras variable with the shape of x filled with ones.
1445 
1446   Example:
1447 
1448   >>> kvar = tf.keras.backend.variable(np.random.random((2,3)))
1449   >>> kvar_ones = tf.keras.backend.ones_like(kvar)
1450   >>> tf.keras.backend.eval(kvar_ones)
1451   array([[1.,  1.,  1.],
1452          [1.,  1.,  1.]], dtype=float32)
1453 
1454   """
1455   return array_ops.ones_like(x, dtype=dtype, name=name)
1456 
1457 
1458 def identity(x, name=None):
1459   """Returns a tensor with the same content as the input tensor.
1460 
1461   Arguments:
1462       x: The input tensor.
1463       name: String, name for the variable to create.
1464 
1465   Returns:
1466       A tensor of the same shape, type and content.
1467   """
1468   return array_ops.identity(x, name=name)
1469 
1470 
1471 @keras_export('keras.backend.random_uniform_variable')
1472 def random_uniform_variable(shape, low, high, dtype=None, name=None, seed=None):
1473   """Instantiates a variable with values drawn from a uniform distribution.
1474 
1475   Arguments:
1476       shape: Tuple of integers, shape of returned Keras variable.
1477       low: Float, lower boundary of the output interval.
1478       high: Float, upper boundary of the output interval.
1479       dtype: String, dtype of returned Keras variable.
1480       name: String, name of returned Keras variable.
1481       seed: Integer, random seed.
1482 
1483   Returns:
1484       A Keras variable, filled with drawn samples.
1485 
1486   Example:
1487 
1488   >>> kvar = tf.keras.backend.random_uniform_variable((2,3), 0, 1)
1489   >>> kvar
1490   <tf.Variable 'Variable:0' shape=(2, 3) dtype=float32, numpy=...,
1491   dtype=float32)>
1492   """
1493   if dtype is None:
1494     dtype = floatx()
1495   tf_dtype = dtypes_module.as_dtype(dtype)
1496   if seed is None:
1497     # ensure that randomness is conditioned by the Numpy RNG
1498     seed = np.random.randint(10e8)
1499   value = init_ops.random_uniform_initializer(
1500       low, high, dtype=tf_dtype, seed=seed)(shape)
1501   return variable(value, dtype=dtype, name=name)
1502 
1503 
1504 @keras_export('keras.backend.random_normal_variable')
1505 def random_normal_variable(shape, mean, scale, dtype=None, name=None,
1506                            seed=None):
1507   """Instantiates a variable with values drawn from a normal distribution.
1508 
1509   Arguments:
1510       shape: Tuple of integers, shape of returned Keras variable.
1511       mean: Float, mean of the normal distribution.
1512       scale: Float, standard deviation of the normal distribution.
1513       dtype: String, dtype of returned Keras variable.
1514       name: String, name of returned Keras variable.
1515       seed: Integer, random seed.
1516 
1517   Returns:
1518       A Keras variable, filled with drawn samples.
1519 
1520   Example:
1521 
1522   >>> kvar = tf.keras.backend.random_normal_variable((2,3), 0, 1)
1523   >>> kvar
1524   <tf.Variable 'Variable:0' shape=(2, 3) dtype=float32, numpy=...,
1525   dtype=float32)>
1526   """
1527   if dtype is None:
1528     dtype = floatx()
1529   tf_dtype = dtypes_module.as_dtype(dtype)
1530   if seed is None:
1531     # ensure that randomness is conditioned by the Numpy RNG
1532     seed = np.random.randint(10e8)
1533   value = init_ops.random_normal_initializer(
1534       mean, scale, dtype=tf_dtype, seed=seed)(shape)
1535   return variable(value, dtype=dtype, name=name)
1536 
1537 
1538 @keras_export('keras.backend.count_params')
1539 def count_params(x):
1540   """Returns the static number of elements in a variable or tensor.
1541 
1542   Arguments:
1543       x: Variable or tensor.
1544 
1545   Returns:
1546       Integer, the number of scalars in `x`.
1547 
1548   Example:
1549 
1550   >>> kvar = tf.keras.backend.zeros((2,3))
1551   >>> tf.keras.backend.count_params(kvar)
1552   6
1553   >>> tf.keras.backend.eval(kvar)
1554   array([[0.,  0.,  0.],
1555          [0.,  0.,  0.]], dtype=float32)
1556 
1557   """
1558   return np.prod(x.shape.as_list())
1559 
1560 
1561 @keras_export('keras.backend.cast')
1562 def cast(x, dtype):
1563   """Casts a tensor to a different dtype and returns it.
1564 
1565   You can cast a Keras variable but it still returns a Keras tensor.
1566 
1567   Arguments:
1568       x: Keras tensor (or variable).
1569       dtype: String, either (`'float16'`, `'float32'`, or `'float64'`).
1570 
1571   Returns:
1572       Keras tensor with dtype `dtype`.
1573 
1574   Examples:
1575       Cast a float32 variable to a float64 tensor
1576 
1577   >>> input = tf.keras.backend.ones(shape=(1,3))
1578   >>> print(input)
1579   <tf.Variable 'Variable:0' shape=(1, 3) dtype=float32,
1580   numpy=array([[1., 1., 1.]], dtype=float32)>
1581   >>> cast_input = tf.keras.backend.cast(input, dtype='float64')
1582   >>> print(cast_input)
1583   tf.Tensor([[1. 1. 1.]], shape=(1, 3), dtype=float64)
1584 
1585   """
1586   return math_ops.cast(x, dtype)
1587 
1588 
1589 # UPDATES OPS
1590 
1591 
1592 @keras_export('keras.backend.update')
1593 def update(x, new_x):
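  """Assigns `new_x` to the variable `x` (via `tf.compat.v1.assign`)."""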
1594   return state_ops.assign(x, new_x)
1595 
1596 
1597 @keras_export('keras.backend.update_add')
1598 def update_add(x, increment):
1599   """Update the value of `x` by adding `increment`.
1600 
1601   Arguments:
1602       x: A Variable.
1603       increment: A tensor of same shape as `x`.
1604 
1605   Returns:
1606       The variable `x` updated.
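
  Example:

  A minimal illustrative sketch (assumes eager execution):

  >>> x = tf.Variable([1.0, 2.0])
  >>> _ = tf.keras.backend.update_add(x, [3.0, 4.0])
  >>> x.numpy()
  array([4., 6.], dtype=float32)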
1607   """
1608   return state_ops.assign_add(x, increment)
1609 
1610 
1611 @keras_export('keras.backend.update_sub')
1612 def update_sub(x, decrement):
1613   """Update the value of `x` by subtracting `decrement`.
1614 
1615   Arguments:
1616       x: A Variable.
1617       decrement: A tensor of same shape as `x`.
1618 
1619   Returns:
1620       The variable `x` updated.
1621   """
1622   return state_ops.assign_sub(x, decrement)
1623 
1624 
1625 @keras_export('keras.backend.moving_average_update')
1626 def moving_average_update(x, value, momentum):
1627   """Compute the moving average of a variable.
1628 
1629   Arguments:
1630       x: A Variable.
1631       value: A tensor with the same shape as `x`.
1632       momentum: The moving average momentum.
1633 
1634   Returns:
1635       An Operation to update the variable.
1636   """
1637   zero_debias = not tf2.enabled()
1638   return moving_averages.assign_moving_average(
1639       x, value, momentum, zero_debias=zero_debias)
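
# Illustrative sketch (not part of the original source): with `zero_debias`
# disabled (the TF 2.x path above), the update is equivalent to
# `x = momentum * x + (1 - momentum) * value`. For example, assuming eager
# execution and `K = tf.keras.backend`:
#
#   v = K.variable(0.)
#   K.moving_average_update(v, value=1., momentum=0.9)
#   K.get_value(v)   # -> ~0.1, i.e. 0.9 * 0.0 + 0.1 * 1.0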
1640 
1641 
1642 # LINEAR ALGEBRA
1643 
1644 
1645 @keras_export('keras.backend.dot')
1646 def dot(x, y):
1647   """Multiplies 2 tensors (and/or variables) and returns a tensor.
1648 
1649   Arguments:
1650       x: Tensor or variable.
1651       y: Tensor or variable.
1652 
1653   Returns:
1654       A tensor, dot product of `x` and `y`.
1655 
1656   Examples:
1657 
1658   >>> x = tf.keras.backend.placeholder(shape=(2, 3))
1659   >>> y = tf.keras.backend.placeholder(shape=(3, 4))
1660   >>> xy = tf.keras.backend.dot(x, y)
1661   >>> xy
1662   <tf.Tensor ... shape=(2, 4) dtype=float32>
1663 
1664   >>> x = tf.keras.backend.placeholder(shape=(32, 28, 3))
1665   >>> y = tf.keras.backend.placeholder(shape=(3, 4))
1666   >>> xy = tf.keras.backend.dot(x, y)
1667   >>> xy
1668   <tf.Tensor ... shape=(32, 28, 4) dtype=float32>
1669 
1670   >>> x = tf.keras.backend.random_uniform_variable(shape=(2, 3), low=0, high=1)
1671   >>> y = tf.keras.backend.ones((4, 3, 5))
1672   >>> xy = tf.keras.backend.dot(x, y)
1673   >>> tf.keras.backend.int_shape(xy)
1674   (2, 4, 5)
1675   """
1676   if ndim(x) is not None and (ndim(x) > 2 or ndim(y) > 2):
1677     x_shape = []
1678     for i, s in zip(int_shape(x), array_ops.unstack(array_ops.shape(x))):
1679       if i is not None:
1680         x_shape.append(i)
1681       else:
1682         x_shape.append(s)
1683     x_shape = tuple(x_shape)
1684     y_shape = []
1685     for i, s in zip(int_shape(y), array_ops.unstack(array_ops.shape(y))):
1686       if i is not None:
1687         y_shape.append(i)
1688       else:
1689         y_shape.append(s)
1690     y_shape = tuple(y_shape)
1691     y_permute_dim = list(range(ndim(y)))
1692     y_permute_dim = [y_permute_dim.pop(-2)] + y_permute_dim
1693     xt = array_ops.reshape(x, [-1, x_shape[-1]])
1694     yt = array_ops.reshape(
1695         array_ops.transpose(y, perm=y_permute_dim), [y_shape[-2], -1])
1696     return array_ops.reshape(
1697         math_ops.matmul(xt, yt), x_shape[:-1] + y_shape[:-2] + y_shape[-1:])
1698   if is_sparse(x):
1699     out = sparse_ops.sparse_tensor_dense_matmul(x, y)
1700   else:
1701     out = math_ops.matmul(x, y)
1702   return out
1703 
1704 
1705 @keras_export('keras.backend.batch_dot')
1706 def batch_dot(x, y, axes=None):
1707   """Batchwise dot product.
1708 
1709   `batch_dot` is used to compute dot product of `x` and `y` when
1710   `x` and `y` are data in batch, i.e. in a shape of
1711   `(batch_size, :)`.
1712   `batch_dot` results in a tensor or variable with fewer dimensions
1713   than the input. If the number of dimensions is reduced to 1,
1714   we use `expand_dims` to make sure that ndim is at least 2.
1715 
1716   Arguments:
1717     x: Keras tensor or variable with `ndim >= 2`.
1718     y: Keras tensor or variable with `ndim >= 2`.
1719     axes: Tuple or list of integers with target dimensions, or single integer.
1720       The sizes of `x.shape[axes[0]]` and `y.shape[axes[1]]` should be equal.
1721 
1722   Returns:
1723     A tensor with shape equal to the concatenation of `x`'s shape
1724     (less the dimension that was summed over) and `y`'s shape
1725     (less the batch dimension and the dimension that was summed over).
1726     If the final rank is 1, we reshape it to `(batch_size, 1)`.
1727 
1728   Examples:
1729 
1730   >>> x_batch = tf.keras.backend.ones(shape=(32, 20, 1))
1731   >>> y_batch = tf.keras.backend.ones(shape=(32, 30, 20))
1732   >>> xy_batch_dot = tf.keras.backend.batch_dot(x_batch, y_batch, axes=(1, 2))
1733   >>> tf.keras.backend.int_shape(xy_batch_dot)
1734   (32, 1, 30)
1735 
1736   Shape inference:
1737     Let `x`'s shape be `(100, 20)` and `y`'s shape be `(100, 30, 20)`.
1738     If `axes` is (1, 2), to find the output shape of resultant tensor,
1739         loop through each dimension in `x`'s shape and `y`'s shape:
1740     * `x.shape[0]` : 100 : append to output shape
1741     * `x.shape[1]` : 20 : do not append to output shape,
1742         dimension 1 of `x` has been summed over. (`dot_axes[0]` = 1)
1743     * `y.shape[0]` : 100 : do not append to output shape,
1744         always ignore first dimension of `y`
1745     * `y.shape[1]` : 30 : append to output shape
1746     * `y.shape[2]` : 20 : do not append to output shape,
1747         dimension 2 of `y` has been summed over. (`dot_axes[1]` = 2)
1748     `output_shape` = `(100, 30)`
1749   """
1750   x_shape = int_shape(x)
1751   y_shape = int_shape(y)
1752 
1753   x_ndim = len(x_shape)
1754   y_ndim = len(y_shape)
1755 
1756   if x_ndim < 2 or y_ndim < 2:
1757     raise ValueError('Cannot do batch_dot on inputs '
1758                      'with rank < 2. '
1759                      'Received inputs with shapes ' +
1760                      str(x_shape) + ' and ' +
1761                      str(y_shape) + '.')
1762 
1763   x_batch_size = x_shape[0]
1764   y_batch_size = y_shape[0]
1765 
1766   if x_batch_size is not None and y_batch_size is not None:
1767     if x_batch_size != y_batch_size:
1768       raise ValueError('Cannot do batch_dot on inputs '
1769                        'with different batch sizes. '
1770                        'Received inputs with shapes ' +
1771                        str(x_shape) + ' and ' +
1772                        str(y_shape) + '.')
1773   if isinstance(axes, int):
1774     axes = [axes, axes]
1775 
1776   if axes is None:
1777     if y_ndim == 2:
1778       axes = [x_ndim - 1, y_ndim - 1]
1779     else:
1780       axes = [x_ndim - 1, y_ndim - 2]
1781 
1782   if py_any(isinstance(a, (list, tuple)) for a in axes):
1783     raise ValueError('Multiple target dimensions are not supported. ' +
1784                      'Expected: None, int, (int, int), ' +
1785                      'Provided: ' + str(axes))
1786 
1787   # if tuple, convert to list.
1788   axes = list(axes)
1789 
1790   # convert negative indices.
1791   if axes[0] < 0:
1792     axes[0] += x_ndim
1793   if axes[1] < 0:
1794     axes[1] += y_ndim
1795 
1796   # sanity checks
1797   if 0 in axes:
1798     raise ValueError('Cannot perform batch_dot over axis 0. '
1799                      'If your inputs are not batched, '
1800                      'add a dummy batch dimension to your '
1801                      'inputs using K.expand_dims(x, 0)')
1802   a0, a1 = axes
1803   d1 = x_shape[a0]
1804   d2 = y_shape[a1]
1805 
1806   if d1 is not None and d2 is not None and d1 != d2:
1807     raise ValueError('Cannot do batch_dot on inputs with shapes ' +
1808                      str(x_shape) + ' and ' + str(y_shape) +
1809                      ' with axes=' + str(axes) + '. x.shape[%d] != '
1810                      'y.shape[%d] (%d != %d).' % (axes[0], axes[1], d1, d2))
1811 
1812   # backup ndims. Need them later.
1813   orig_x_ndim = x_ndim
1814   orig_y_ndim = y_ndim
1815 
1816   # if rank is 2, expand to 3.
1817   if x_ndim == 2:
1818     x = array_ops.expand_dims(x, 1)
1819     a0 += 1
1820     x_ndim += 1
1821   if y_ndim == 2:
1822     y = array_ops.expand_dims(y, 2)
1823     y_ndim += 1
1824 
1825   # bring x's dimension to be reduced to last axis.
1826   if a0 != x_ndim - 1:
1827     pattern = list(range(x_ndim))
1828     for i in range(a0, x_ndim - 1):
1829       pattern[i] = pattern[i + 1]
1830     pattern[-1] = a0
1831     x = array_ops.transpose(x, pattern)
1832 
1833   # bring y's dimension to be reduced to axis 1.
1834   if a1 != 1:
1835     pattern = list(range(y_ndim))
1836     for i in range(a1, 1, -1):
1837       pattern[i] = pattern[i - 1]
1838     pattern[1] = a1
1839     y = array_ops.transpose(y, pattern)
1840 
1841   # normalize both inputs to rank 3.
1842   if x_ndim > 3:
1843     # squash middle dimensions of x.
1844     x_shape = shape(x)
1845     x_mid_dims = x_shape[1:-1]
1846     x_squashed_shape = array_ops.stack(
1847         [x_shape[0], -1, x_shape[-1]])
1848     x = array_ops.reshape(x, x_squashed_shape)
1849     x_squashed = True
1850   else:
1851     x_squashed = False
1852 
1853   if y_ndim > 3:
1854     # squash trailing dimensions of y.
1855     y_shape = shape(y)
1856     y_trail_dims = y_shape[2:]
1857     y_squashed_shape = array_ops.stack(
1858         [y_shape[0], y_shape[1], -1])
1859     y = array_ops.reshape(y, y_squashed_shape)
1860     y_squashed = True
1861   else:
1862     y_squashed = False
1863 
1864   result = math_ops.matmul(x, y)
1865 
1866   # if inputs were squashed, we have to reshape the matmul output.
1867   output_shape = array_ops.shape(result)
1868   do_reshape = False
1869 
1870   if x_squashed:
1871     output_shape = array_ops.concat(
1872         [output_shape[:1],
1873          x_mid_dims,
1874          output_shape[-1:]], 0)
1875     do_reshape = True
1876 
1877   if y_squashed:
1878     output_shape = array_ops.concat([output_shape[:-1], y_trail_dims], 0)
1879     do_reshape = True
1880 
1881   if do_reshape:
1882     result = array_ops.reshape(result, output_shape)
1883 
1884   # if the inputs were originally rank 2, we remove the added 1 dim.
1885   if orig_x_ndim == 2:
1886     result = array_ops.squeeze(result, 1)
1887   elif orig_y_ndim == 2:
1888     result = array_ops.squeeze(result, -1)
1889 
1890   return result
1891 
1892 
1893 @keras_export('keras.backend.transpose')
1894 def transpose(x):
1895   """Transposes a tensor and returns it.
1896 
1897   Arguments:
1898       x: Tensor or variable.
1899 
1900   Returns:
1901       A tensor.
1902 
1903   Examples:
1904 
1905   >>> var = tf.keras.backend.variable([[1, 2, 3], [4, 5, 6]])
1906   >>> tf.keras.backend.eval(var)
1907   array([[1.,  2.,  3.],
1908          [4.,  5.,  6.]], dtype=float32)
1909   >>> var_transposed = tf.keras.backend.transpose(var)
1910   >>> tf.keras.backend.eval(var_transposed)
1911   array([[1.,  4.],
1912          [2.,  5.],
1913          [3.,  6.]], dtype=float32)
1914   >>> input = tf.keras.backend.placeholder((2, 3))
1915   >>> input
1916   <tf.Tensor 'Placeholder_...' shape=(2, 3) dtype=float32>
1917   >>> input_transposed = tf.keras.backend.transpose(input)
1918   >>> input_transposed
1919   <tf.Tensor 'Transpose_...' shape=(3, 2) dtype=float32>
1920   """
1921   return array_ops.transpose(x)
1922 
1923 
1924 @keras_export('keras.backend.gather')
1925 def gather(reference, indices):
1926   """Retrieves the elements of indices `indices` in the tensor `reference`.
1927 
1928   Arguments:
1929       reference: A tensor.
1930       indices: An integer tensor of indices.
1931 
1932   Returns:
1933       A tensor of same type as `reference`.
1934 
1935   Examples:
1936 
1937   >>> var = tf.keras.backend.variable([[1, 2, 3], [4, 5, 6]])
1938   >>> tf.keras.backend.eval(var)
1939   array([[1., 2., 3.],
1940          [4., 5., 6.]], dtype=float32)
1941   >>> var_gathered = tf.keras.backend.gather(var, [0])
1942   >>> tf.keras.backend.eval(var_gathered)
1943   array([[1., 2., 3.]], dtype=float32)
1944   >>> var_gathered = tf.keras.backend.gather(var, [1])
1945   >>> tf.keras.backend.eval(var_gathered)
1946   array([[4., 5., 6.]], dtype=float32)
1947   >>> var_gathered = tf.keras.backend.gather(var, [0,1,0])
1948   >>> tf.keras.backend.eval(var_gathered)
1949   array([[1., 2., 3.],
1950          [4., 5., 6.],
1951          [1., 2., 3.]], dtype=float32)
1952   """
1953   return array_ops.gather(reference, indices)
1954 
1955 
1956 # ELEMENT-WISE OPERATIONS
1957 
1958 
1959 @keras_export('keras.backend.max')
1960 def max(x, axis=None, keepdims=False):
1961   """Maximum value in a tensor.
1962 
1963   Arguments:
1964       x: A tensor or variable.
1965       axis: An integer, the axis to find maximum values.
1966       keepdims: A boolean, whether to keep the dimensions or not.
1967           If `keepdims` is `False`, the rank of the tensor is reduced
1968           by 1. If `keepdims` is `True`,
1969           the reduced dimension is retained with length 1.
1970 
1971   Returns:
1972       A tensor with maximum values of `x`.
1973   """
1974   return math_ops.reduce_max(x, axis, keepdims)
1975 
1976 
1977 @keras_export('keras.backend.min')
1978 def min(x, axis=None, keepdims=False):
1979   """Minimum value in a tensor.
1980 
1981   Arguments:
1982       x: A tensor or variable.
1983       axis: An integer, the axis to find minimum values.
1984       keepdims: A boolean, whether to keep the dimensions or not.
1985           If `keepdims` is `False`, the rank of the tensor is reduced
1986           by 1. If `keepdims` is `True`,
1987           the reduced dimension is retained with length 1.
1988 
1989   Returns:
1990       A tensor with minimum values of `x`.
1991   """
1992   return math_ops.reduce_min(x, axis, keepdims)
1993 
1994 
1995 @keras_export('keras.backend.sum')
1996 def sum(x, axis=None, keepdims=False):
1997   """Sum of the values in a tensor, alongside the specified axis.
1998 
1999   Arguments:
2000       x: A tensor or variable.
2001       axis: An integer, the axis to sum over.
2002       keepdims: A boolean, whether to keep the dimensions or not.
2003           If `keepdims` is `False`, the rank of the tensor is reduced
2004           by 1. If `keepdims` is `True`,
2005           the reduced dimension is retained with length 1.
2006 
2007   Returns:
2008       A tensor with sum of `x`.
2009   """
2010   return math_ops.reduce_sum(x, axis, keepdims)
2011 
2012 
2013 @keras_export('keras.backend.prod')
2014 def prod(x, axis=None, keepdims=False):
2015   """Multiplies the values in a tensor, alongside the specified axis.
2016 
2017   Arguments:
2018       x: A tensor or variable.
2019       axis: An integer, the axis to compute the product.
2020       keepdims: A boolean, whether to keep the dimensions or not.
2021           If `keepdims` is `False`, the rank of the tensor is reduced
2022           by 1. If `keepdims` is `True`,
2023           the reduced dimension is retained with length 1.
2024 
2025   Returns:
2026       A tensor with the product of elements of `x`.
2027   """
2028   return math_ops.reduce_prod(x, axis, keepdims)
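
# Shape sketch for the reductions above (illustrative comment; assumes eager
# execution and `K = tf.keras.backend`):
#
#   x = tf.constant([[1., 2.], [3., 4.]])
#   K.sum(x)                         # -> 10.0 (all axes reduced)
#   K.sum(x, axis=1)                 # shape (2,),   values [3., 7.]
#   K.sum(x, axis=1, keepdims=True)  # shape (2, 1), values [[3.], [7.]]
#   K.max(x, axis=0)                 # -> [3., 4.]
#   K.prod(x, axis=0)                # -> [3., 8.]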
2029 
2030 
2031 @keras_export('keras.backend.cumsum')
2032 def cumsum(x, axis=0):
2033   """Cumulative sum of the values in a tensor, alongside the specified axis.
2034 
2035   Arguments:
2036       x: A tensor or variable.
2037       axis: An integer, the axis to compute the sum.
2038 
2039   Returns:
2040       A tensor of the cumulative sum of values of `x` along `axis`.
2041   """
2042   return math_ops.cumsum(x, axis=axis)
2043 
2044 
2045 @keras_export('keras.backend.cumprod')
2046 def cumprod(x, axis=0):
2047   """Cumulative product of the values in a tensor, alongside the specified axis.
2048 
2049   Arguments:
2050       x: A tensor or variable.
2051       axis: An integer, the axis to compute the product.
2052 
2053   Returns:
2054       A tensor of the cumulative product of values of `x` along `axis`.
2055   """
2056   return math_ops.cumprod(x, axis=axis)
2057 
2058 
2059 @keras_export('keras.backend.var')
2060 def var(x, axis=None, keepdims=False):
2061   """Variance of a tensor, alongside the specified axis.
2062 
2063   Arguments:
2064       x: A tensor or variable.
2065       axis: An integer, the axis to compute the variance.
2066       keepdims: A boolean, whether to keep the dimensions or not.
2067           If `keepdims` is `False`, the rank of the tensor is reduced
2068           by 1. If `keepdims` is `True`,
2069           the reduced dimension is retained with length 1.
2070 
2071   Returns:
2072       A tensor with the variance of elements of `x`.
2073   """
2074   if x.dtype.base_dtype == dtypes_module.bool:
2075     x = math_ops.cast(x, floatx())
2076   return math_ops.reduce_variance(x, axis=axis, keepdims=keepdims)
2077 
2078 
2079 @keras_export('keras.backend.std')
2080 def std(x, axis=None, keepdims=False):
2081   """Standard deviation of a tensor, alongside the specified axis.
2082 
2083   Arguments:
2084       x: A tensor or variable.
2085       axis: An integer, the axis to compute the standard deviation.
2086       keepdims: A boolean, whether to keep the dimensions or not.
2087           If `keepdims` is `False`, the rank of the tensor is reduced
2088           by 1. If `keepdims` is `True`,
2089           the reduced dimension is retained with length 1.
2090 
2091   Returns:
2092       A tensor with the standard deviation of elements of `x`.
2093   """
2094   if x.dtype.base_dtype == dtypes_module.bool:
2095     x = math_ops.cast(x, floatx())
2096   return math_ops.reduce_std(x, axis=axis, keepdims=keepdims)
2097 
2098 
2099 @keras_export('keras.backend.mean')
2100 def mean(x, axis=None, keepdims=False):
2101   """Mean of a tensor, alongside the specified axis.
2102 
2103   Arguments:
2104       x: A tensor or variable.
2105       axis: A list of integers. Axes along which to compute the mean.
2106       keepdims: A boolean, whether to keep the dimensions or not.
2107           If `keepdims` is `False`, the rank of the tensor is reduced
2108           by 1 for each entry in `axis`. If `keepdims` is `True`,
2109           the reduced dimensions are retained with length 1.
2110 
2111   Returns:
2112       A tensor with the mean of elements of `x`.
2113   """
2114   if x.dtype.base_dtype == dtypes_module.bool:
2115     x = math_ops.cast(x, floatx())
2116   return math_ops.reduce_mean(x, axis, keepdims)
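
# Sketch of `axis` given as a list (illustrative comment; assumes eager
# execution and `K = tf.keras.backend`):
#
#   x = K.ones((2, 3, 4))
#   K.int_shape(K.mean(x, axis=[1, 2]))                 # -> (2,)
#   K.int_shape(K.mean(x, axis=[1, 2], keepdims=True))  # -> (2, 1, 1)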
2117 
2118 
2119 @keras_export('keras.backend.any')
2120 def any(x, axis=None, keepdims=False):
2121   """Bitwise reduction (logical OR).
2122 
2123   Arguments:
2124       x: Tensor or variable.
2125       axis: axis along which to perform the reduction.
2126       keepdims: whether to drop or broadcast the reduction axes.
2127 
2128   Returns:
2129       A bool tensor.
2130   """
2131   x = math_ops.cast(x, dtypes_module.bool)
2132   return math_ops.reduce_any(x, axis, keepdims)
2133 
2134 
2135 @keras_export('keras.backend.all')
2136 def all(x, axis=None, keepdims=False):
2137   """Bitwise reduction (logical AND).
2138 
2139   Arguments:
2140       x: Tensor or variable.
2141       axis: axis along which to perform the reduction.
2142       keepdims: whether to drop or broadcast the reduction axes.
2143 
2144   Returns:
2145       A bool tensor.
2146   """
2147   x = math_ops.cast(x, dtypes_module.bool)
2148   return math_ops.reduce_all(x, axis, keepdims)
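
# Truth-value sketch for `any` / `all` (illustrative comment; assumes eager
# execution and `K = tf.keras.backend`). Non-zero entries are treated as True:
#
#   x = tf.constant([[1, 0], [0, 0]])
#   K.any(x)           # -> True  (some element is non-zero)
#   K.all(x, axis=1)   # -> [False, False]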
2149 
2150 
2151 @keras_export('keras.backend.argmax')
2152 def argmax(x, axis=-1):
2153   """Returns the index of the maximum value along an axis.
2154 
2155   Arguments:
2156       x: Tensor or variable.
2157       axis: axis along which to perform the reduction.
2158 
2159   Returns:
2160       A tensor.
2161   """
2162   return math_ops.argmax(x, axis)
2163 
2164 
2165 @keras_export('keras.backend.argmin')
2166 def argmin(x, axis=-1):
2167   """Returns the index of the minimum value along an axis.
2168 
2169   Arguments:
2170       x: Tensor or variable.
2171       axis: axis along which to perform the reduction.
2172 
2173   Returns:
2174       A tensor.
2175   """
2176   return math_ops.argmin(x, axis)
2177 
2178 
2179 @keras_export('keras.backend.square')
2180 def square(x):
2181   """Element-wise square.
2182 
2183   Arguments:
2184       x: Tensor or variable.
2185 
2186   Returns:
2187       A tensor.
2188   """
2189   return math_ops.square(x)
2190 
2191 
2192 @keras_export('keras.backend.abs')
2193 def abs(x):
2194   """Element-wise absolute value.
2195 
2196   Arguments:
2197       x: Tensor or variable.
2198 
2199   Returns:
2200       A tensor.
2201   """
2202   return math_ops.abs(x)
2203 
2204 
2205 @keras_export('keras.backend.sqrt')
2206 def sqrt(x):
2207   """Element-wise square root.
2208 
2209   Arguments:
2210       x: Tensor or variable.
2211 
2212   Returns:
2213       A tensor.
2214   """
2215   zero = _constant_to_tensor(0., x.dtype.base_dtype)
2216   inf = _constant_to_tensor(np.inf, x.dtype.base_dtype)
2217   x = clip_ops.clip_by_value(x, zero, inf)
2218   return math_ops.sqrt(x)
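
# Note (illustrative comment, not part of the original source): because of the
# clipping above, negative inputs map to 0 instead of producing NaN as a bare
# `tf.sqrt` would. Assuming eager execution:
#
#   tf.keras.backend.sqrt(tf.constant([-4., 0., 4.]))   # -> [0., 0., 2.]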
2219 
2220 
2221 @keras_export('keras.backend.exp')
2222 def exp(x):
2223   """Element-wise exponential.
2224 
2225   Arguments:
2226       x: Tensor or variable.
2227 
2228   Returns:
2229       A tensor.
2230   """
2231   return math_ops.exp(x)
2232 
2233 
2234 @keras_export('keras.backend.log')
2235 def log(x):
2236   """Element-wise log.
2237 
2238   Arguments:
2239       x: Tensor or variable.
2240 
2241   Returns:
2242       A tensor.
2243   """
2244   return math_ops.log(x)
2245 
2246 
2247 def logsumexp(x, axis=None, keepdims=False):
2248   """Computes log(sum(exp(elements across dimensions of a tensor))).
2249 
2250   This function is more numerically stable than log(sum(exp(x))).
2251   It avoids overflows caused by taking the exp of large inputs and
2252   underflows caused by taking the log of small inputs.
2253 
2254   Arguments:
2255       x: A tensor or variable.
2256       axis: An integer, the axis to reduce over.
2257       keepdims: A boolean, whether to keep the dimensions or not.
2258           If `keepdims` is `False`, the rank of the tensor is reduced
2259           by 1. If `keepdims` is `True`, the reduced dimension is
2260           retained with length 1.
2261 
2262   Returns:
2263       The reduced tensor.
2264   """
2265   return math_ops.reduce_logsumexp(x, axis, keepdims)
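
# Stability sketch (illustrative comment; assumes eager execution):
#
#   x = tf.constant([1000., 1000.])
#   tf.keras.backend.logsumexp(x)            # -> ~1000.693 (1000 + log(2))
#   tf.math.log(tf.reduce_sum(tf.exp(x)))    # -> inf, the naive form overflows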
2266 
2267 
2268 @keras_export('keras.backend.round')
2269 def round(x):
2270   """Element-wise rounding to the closest integer.
2271 
2272   In case of tie, the rounding mode used is "half to even".
2273 
2274   Arguments:
2275       x: Tensor or variable.
2276 
2277   Returns:
2278       A tensor.
2279   """
2280   return math_ops.round(x)
2281 
2282 
2283 @keras_export('keras.backend.sign')
2284 def sign(x):
2285   """Element-wise sign.
2286 
2287   Arguments:
2288       x: Tensor or variable.
2289 
2290   Returns:
2291       A tensor.
2292   """
2293   return math_ops.sign(x)
2294 
2295 
2296 @keras_export('keras.backend.pow')
2297 def pow(x, a):
2298   """Element-wise exponentiation.
2299 
2300   Arguments:
2301       x: Tensor or variable.
2302       a: Python integer.
2303 
2304   Returns:
2305       A tensor.
2306   """
2307   return math_ops.pow(x, a)
2308 
2309 
2310 @keras_export('keras.backend.clip')
2311 def clip(x, min_value, max_value):
2312   """Element-wise value clipping.
2313 
2314   Arguments:
2315       x: Tensor or variable.
2316       min_value: Python float, integer, or tensor.
2317       max_value: Python float, integer, or tensor.
2318 
2319   Returns:
2320       A tensor.
2321   """
2322   if (isinstance(min_value, (int, float)) and
2323       isinstance(max_value, (int, float))):
2324     if max_value < min_value:
2325       max_value = min_value
2326   if min_value is None:
2327     min_value = -np.inf
2328   if max_value is None:
2329     max_value = np.inf
2330   return clip_ops.clip_by_value(x, min_value, max_value)
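
# Behaviour sketch for `clip` (illustrative comment; assumes eager execution):
#
#   x = tf.constant([-1., 0.5, 2.])
#   tf.keras.backend.clip(x, 0., 1.)     # -> [0., 0.5, 1.]
#   tf.keras.backend.clip(x, None, 1.)   # missing bound becomes -inf -> [-1., 0.5, 1.]
#   tf.keras.backend.clip(x, 1., 0.)     # max < min, so max is raised to min -> [1., 1., 1.]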
2331 
2332 
2333 @keras_export('keras.backend.equal')
2334 def equal(x, y):
2335   """Element-wise equality between two tensors.
2336 
2337   Arguments:
2338       x: Tensor or variable.
2339       y: Tensor or variable.
2340 
2341   Returns:
2342       A bool tensor.
2343   """
2344   return math_ops.equal(x, y)
2345 
2346 
2347 @keras_export('keras.backend.not_equal')
2348 def not_equal(x, y):
2349   """Element-wise inequality between two tensors.
2350 
2351   Arguments:
2352       x: Tensor or variable.
2353       y: Tensor or variable.
2354 
2355   Returns:
2356       A bool tensor.
2357   """
2358   return math_ops.not_equal(x, y)
2359 
2360 
2361 @keras_export('keras.backend.greater')
2362 def greater(x, y):
2363   """Element-wise truth value of (x > y).
2364 
2365   Arguments:
2366       x: Tensor or variable.
2367       y: Tensor or variable.
2368 
2369   Returns:
2370       A bool tensor.
2371   """
2372   return math_ops.greater(x, y)
2373 
2374 
2375 @keras_export('keras.backend.greater_equal')
2376 def greater_equal(x, y):
2377   """Element-wise truth value of (x >= y).
2378 
2379   Arguments:
2380       x: Tensor or variable.
2381       y: Tensor or variable.
2382 
2383   Returns:
2384       A bool tensor.
2385   """
2386   return math_ops.greater_equal(x, y)
2387 
2388 
2389 @keras_export('keras.backend.less')
2390 def less(x, y):
2391   """Element-wise truth value of (x < y).
2392 
2393   Arguments:
2394       x: Tensor or variable.
2395       y: Tensor or variable.
2396 
2397   Returns:
2398       A bool tensor.
2399   """
2400   return math_ops.less(x, y)
2401 
2402 
2403 @keras_export('keras.backend.less_equal')
2404 def less_equal(x, y):
2405   """Element-wise truth value of (x <= y).
2406 
2407   Arguments:
2408       x: Tensor or variable.
2409       y: Tensor or variable.
2410 
2411   Returns:
2412       A bool tensor.
2413   """
2414   return math_ops.less_equal(x, y)
2415 
2416 
2417 @keras_export('keras.backend.maximum')
2418 def maximum(x, y):
2419   """Element-wise maximum of two tensors.
2420 
2421   Arguments:
2422       x: Tensor or variable.
2423       y: Tensor or variable.
2424 
2425   Returns:
2426       A tensor with the element-wise maximum value(s) of `x` and `y`.
2427 
2428   Examples:
2429 
2430   >>> x = tf.Variable([[1, 2], [3, 4]])
2431   >>> y = tf.Variable([[2, 1], [0, -1]])
2432   >>> m = tf.keras.backend.maximum(x, y)
2433   >>> m
2434   <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
2435   array([[2, 2],
2436          [3, 4]], dtype=int32)>
2437   """
2438   return math_ops.maximum(x, y)
2439 
2440 
2441 @keras_export('keras.backend.minimum')
2442 def minimum(x, y):
2443   """Element-wise minimum of two tensors.
2444 
2445   Arguments:
2446       x: Tensor or variable.
2447       y: Tensor or variable.
2448 
2449   Returns:
2450       A tensor.
2451   """
2452   return math_ops.minimum(x, y)
2453 
2454 
2455 @keras_export('keras.backend.sin')
2456 def sin(x):
2457   """Computes sin of x element-wise.
2458 
2459   Arguments:
2460       x: Tensor or variable.
2461 
2462   Returns:
2463       A tensor.
2464   """
2465   return math_ops.sin(x)
2466 
2467 
2468 @keras_export('keras.backend.cos')
2469 def cos(x):
2470   """Computes cos of x element-wise.
2471 
2472   Arguments:
2473       x: Tensor or variable.
2474 
2475   Returns:
2476       A tensor.
2477   """
2478   return math_ops.cos(x)
2479 
2480 
2481 def _regular_normalize_batch_in_training(x,
2482                                          gamma,
2483                                          beta,
2484                                          reduction_axes,
2485                                          epsilon=1e-3):
2486   """Non-fused version of `normalize_batch_in_training`.
2487 
2488   Arguments:
2489       x: Input tensor or variable.
2490       gamma: Tensor by which to scale the input.
2491       beta: Tensor with which to center the input.
2492       reduction_axes: iterable of integers,
2493           axes over which to normalize.
2494       epsilon: Fuzz factor.
2495 
2496   Returns:
2497       A tuple of length 3: `(normalized_tensor, mean, variance)`.
2498   """
2499   mean, var = nn.moments(x, reduction_axes, None, None, False)
2500   normed = nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
2501   return normed, mean, var
2502 
2503 
2504 def _broadcast_normalize_batch_in_training(x,
2505                                            gamma,
2506                                            beta,
2507                                            reduction_axes,
2508                                            epsilon=1e-3):
2509   """Non-fused, broadcast version of `normalize_batch_in_training`.
2510 
2511   Arguments:
2512       x: Input tensor or variable.
2513       gamma: Tensor by which to scale the input.
2514       beta: Tensor with which to center the input.
2515       reduction_axes: iterable of integers,
2516           axes over which to normalize.
2517       epsilon: Fuzz factor.
2518 
2519   Returns:
2520       A tuple of length 3: `(normalized_tensor, mean, variance)`.
2521   """
2522   mean, var = nn.moments(x, reduction_axes, None, None, False)
2523   target_shape = []
2524   for axis in range(ndim(x)):
2525     if axis in reduction_axes:
2526       target_shape.append(1)
2527     else:
2528       target_shape.append(array_ops.shape(x)[axis])
2529   target_shape = array_ops.stack(target_shape)
2530 
2531   broadcast_mean = array_ops.reshape(mean, target_shape)
2532   broadcast_var = array_ops.reshape(var, target_shape)
2533   if gamma is None:
2534     broadcast_gamma = None
2535   else:
2536     broadcast_gamma = array_ops.reshape(gamma, target_shape)
2537   if beta is None:
2538     broadcast_beta = None
2539   else:
2540     broadcast_beta = array_ops.reshape(beta, target_shape)
2541 
2542   normed = nn.batch_normalization(x, broadcast_mean, broadcast_var,
2543                                   broadcast_beta, broadcast_gamma, epsilon)
2544   return normed, mean, var
2545 
2546 
2547 def _fused_normalize_batch_in_training(x,
2548                                        gamma,
2549                                        beta,
2550                                        reduction_axes,
2551                                        epsilon=1e-3):
2552   """Fused version of `normalize_batch_in_training`.
2553 
2554   Arguments:
2555       x: Input tensor or variable.
2556       gamma: Tensor by which to scale the input.
2557       beta: Tensor with which to center the input.
2558       reduction_axes: iterable of integers,
2559           axes over which to normalize.
2560       epsilon: Fuzz factor.
2561 
2562   Returns:
2563       A tuple of length 3: `(normalized_tensor, mean, variance)`.
2564   """
2565   if list(reduction_axes) == [0, 1, 2]:
2566     normalization_axis = 3
2567     tf_data_format = 'NHWC'
2568   else:
2569     normalization_axis = 1
2570     tf_data_format = 'NCHW'
2571 
2572   if gamma is None:
2573     gamma = constant_op.constant(
2574         1.0, dtype=x.dtype, shape=[x.shape[normalization_axis]])
2575   if beta is None:
2576     beta = constant_op.constant(
2577         0.0, dtype=x.dtype, shape=[x.shape[normalization_axis]])
2578 
2579   return nn.fused_batch_norm(
2580       x, gamma, beta, epsilon=epsilon, data_format=tf_data_format)
2581 
2582 
2583 @keras_export('keras.backend.normalize_batch_in_training')
2584 def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3):
2585   """Computes mean and std for batch then apply batch_normalization on batch.
2586 
2587   Arguments:
2588       x: Input tensor or variable.
2589       gamma: Tensor by which to scale the input.
2590       beta: Tensor with which to center the input.
2591       reduction_axes: iterable of integers,
2592           axes over which to normalize.
2593       epsilon: Fuzz factor.
2594 
2595   Returns:
2596       A tuple of length 3: `(normalized_tensor, mean, variance)`.
2597   """
2598   if ndim(x) == 4 and list(reduction_axes) in [[0, 1, 2], [0, 2, 3]]:
2599     if not _has_nchw_support() and list(reduction_axes) == [0, 2, 3]:
2600       return _broadcast_normalize_batch_in_training(
2601           x, gamma, beta, reduction_axes, epsilon=epsilon)
2602     return _fused_normalize_batch_in_training(
2603         x, gamma, beta, reduction_axes, epsilon=epsilon)
2604   else:
2605     if sorted(reduction_axes) == list(range(ndim(x)))[:-1]:
2606       return _regular_normalize_batch_in_training(
2607           x, gamma, beta, reduction_axes, epsilon=epsilon)
2608     else:
2609       return _broadcast_normalize_batch_in_training(
2610           x, gamma, beta, reduction_axes, epsilon=epsilon)
2611 
2612 
2613 @keras_export('keras.backend.batch_normalization')
2614 def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
2615   """Applies batch normalization on x given mean, var, beta and gamma.
2616 
2617   I.e. returns:
2618   `output = (x - mean) / sqrt(var + epsilon) * gamma + beta`
2619 
2620   Arguments:
2621       x: Input tensor or variable.
2622       mean: Mean of batch.
2623       var: Variance of batch.
2624       beta: Tensor with which to center the input.
2625       gamma: Tensor by which to scale the input.
2626       axis: Integer, the axis that should be normalized.
2627           (typically the features axis).
2628       epsilon: Fuzz factor.
2629 
2630   Returns:
2631       A tensor.
2632   """
2633   if ndim(x) == 4:
2634     # The CPU implementation of `fused_batch_norm` only supports NHWC
2635     if axis == 1 or axis == -3:
2636       tf_data_format = 'NCHW'
2637     elif axis == 3 or axis == -1:
2638       tf_data_format = 'NHWC'
2639     else:
2640       tf_data_format = None
2641 
2642     if (tf_data_format == 'NHWC' or
2643         tf_data_format == 'NCHW' and _has_nchw_support()):
2644       # The mean / var / beta / gamma tensors may be broadcasted
2645       # so they may have extra axes of size 1, which should be squeezed.
2646       if ndim(mean) > 1:
2647         mean = array_ops.reshape(mean, [-1])
2648       if ndim(var) > 1:
2649         var = array_ops.reshape(var, [-1])
2650       if beta is None:
2651         beta = zeros_like(mean)
2652       elif ndim(beta) > 1:
2653         beta = array_ops.reshape(beta, [-1])
2654       if gamma is None:
2655         gamma = ones_like(mean)
2656       elif ndim(gamma) > 1:
2657         gamma = array_ops.reshape(gamma, [-1])
2658       y, _, _ = nn.fused_batch_norm(
2659           x,
2660           gamma,
2661           beta,
2662           epsilon=epsilon,
2663           mean=mean,
2664           variance=var,
2665           data_format=tf_data_format,
2666           is_training=False
2667       )
2668       return y
2669   return nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
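
# Rough equivalence sketch (illustrative comment; assumes eager execution and a
# rank-2 input, which takes the non-fused path above):
#
#   x = tf.random.normal((8, 5))
#   mean = tf.keras.backend.mean(x, axis=0)
#   var = tf.keras.backend.var(x, axis=0)
#   y = tf.keras.backend.batch_normalization(
#       x, mean, var, beta=tf.zeros((5,)), gamma=tf.ones((5,)), epsilon=1e-3)
#   # y is close to (x - mean) / tf.sqrt(var + 1e-3), since gamma=1 and beta=0.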
2670 
2671 
2672 # SHAPE OPERATIONS
2673 
2674 
2675 @keras_export('keras.backend.concatenate')
2676 def concatenate(tensors, axis=-1):
2677   """Concatenates a list of tensors alongside the specified axis.
2678 
2679   Arguments:
2680       tensors: list of tensors to concatenate.
2681       axis: concatenation axis.
2682 
2683   Returns:
2684       A tensor.
2685 
2686   Example:
2687 
2688       >>> a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
2689       >>> b = tf.constant([[10, 20, 30], [40, 50, 60], [70, 80, 90]])
2690       >>> tf.keras.backend.concatenate((a, b), axis=-1)
2691       <tf.Tensor: shape=(3, 6), dtype=int32, numpy=
2692       array([[ 1,  2,  3, 10, 20, 30],
2693              [ 4,  5,  6, 40, 50, 60],
2694              [ 7,  8,  9, 70, 80, 90]], dtype=int32)>
2695 
2696   """
2697   if axis < 0:
2698     rank = ndim(tensors[0])
2699     if rank:
2700       axis %= rank
2701     else:
2702       axis = 0
2703 
2704   if py_all(is_sparse(x) for x in tensors):
2705     return sparse_ops.sparse_concat(axis, tensors)
2706   elif py_all(isinstance(x, ragged_tensor.RaggedTensor) for x in tensors):
2707     return ragged_concat_ops.concat(tensors, axis)
2708   else:
2709     return array_ops.concat([to_dense(x) for x in tensors], axis)
2710 
2711 
2712 @keras_export('keras.backend.reshape')
2713 def reshape(x, shape):
2714   """Reshapes a tensor to the specified shape.
2715 
2716   Arguments:
2717       x: Tensor or variable.
2718       shape: Target shape tuple.
2719 
2720   Returns:
2721       A tensor.
2722 
2723   Example:
2724 
2725     >>> a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
2726     >>> a
2727     <tf.Tensor: shape=(4, 3), dtype=int32, numpy=
2728     array([[ 1,  2,  3],
2729            [ 4,  5,  6],
2730            [ 7,  8,  9],
2731            [10, 11, 12]], dtype=int32)>
2732     >>> tf.keras.backend.reshape(a, shape=(2, 6))
2733     <tf.Tensor: shape=(2, 6), dtype=int32, numpy=
2734     array([[ 1,  2,  3,  4,  5,  6],
2735            [ 7,  8,  9, 10, 11, 12]], dtype=int32)>
2736 
2737   """
2738   return array_ops.reshape(x, shape)
2739 
2740 
2741 @keras_export('keras.backend.permute_dimensions')
2742 def permute_dimensions(x, pattern):
2743   """Permutes axes in a tensor.
2744 
2745   Arguments:
2746       x: Tensor or variable.
2747       pattern: A tuple of
2748           dimension indices, e.g. `(0, 2, 1)`.
2749 
2750   Returns:
2751       A tensor.
2752 
2753   Example:
2754 
2755     >>> a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
2756     >>> a
2757     <tf.Tensor: shape=(4, 3), dtype=int32, numpy=
2758     array([[ 1,  2,  3],
2759            [ 4,  5,  6],
2760            [ 7,  8,  9],
2761            [10, 11, 12]], dtype=int32)>
2762     >>> tf.keras.backend.permute_dimensions(a, pattern=(1, 0))
2763     <tf.Tensor: shape=(3, 4), dtype=int32, numpy=
2764     array([[ 1,  4,  7, 10],
2765            [ 2,  5,  8, 11],
2766            [ 3,  6,  9, 12]], dtype=int32)>
2767 
2768   """
2769   return array_ops.transpose(x, perm=pattern)
2770 
2771 
2772 @keras_export('keras.backend.resize_images')
2773 def resize_images(x, height_factor, width_factor, data_format,
2774                   interpolation='nearest'):
2775   """Resizes the images contained in a 4D tensor.
2776 
2777   Arguments:
2778       x: Tensor or variable to resize.
2779       height_factor: Positive integer.
2780       width_factor: Positive integer.
2781       data_format: One of `"channels_first"`, `"channels_last"`.
2782       interpolation: A string, one of `nearest` or `bilinear`.
2783 
2784   Returns:
2785       A tensor.
2786 
2787   Raises:
2788       ValueError: in case of incorrect value for
2789         `data_format` or `interpolation`.
2790   """
2791   if data_format == 'channels_first':
2792     rows, cols = 2, 3
2793   elif data_format == 'channels_last':
2794     rows, cols = 1, 2
2795   else:
2796     raise ValueError('Invalid `data_format` argument: %s' % (data_format,))
2797 
2798   original_shape = int_shape(x)
2799   new_shape = array_ops.shape(x)[rows:cols + 1]
2800   new_shape *= constant_op.constant(
2801       np.array([height_factor, width_factor], dtype='int32'))
2802 
2803   if data_format == 'channels_first':
2804     x = permute_dimensions(x, [0, 2, 3, 1])
2805   if interpolation == 'nearest':
2806     x = image_ops.resize_images_v2(
2807         x, new_shape, method=image_ops.ResizeMethod.NEAREST_NEIGHBOR)
2808   elif interpolation == 'bilinear':
2809     x = image_ops.resize_images_v2(x, new_shape,
2810                                    method=image_ops.ResizeMethod.BILINEAR)
2811   else:
2812     raise ValueError('interpolation should be one '
2813                      'of "nearest" or "bilinear".')
2814   if data_format == 'channels_first':
2815     x = permute_dimensions(x, [0, 3, 1, 2])
2816 
2817   if original_shape[rows] is None:
2818     new_height = None
2819   else:
2820     new_height = original_shape[rows] * height_factor
2821 
2822   if original_shape[cols] is None:
2823     new_width = None
2824   else:
2825     new_width = original_shape[cols] * width_factor
2826 
2827   if data_format == 'channels_first':
2828     output_shape = (None, None, new_height, new_width)
2829   else:
2830     output_shape = (None, new_height, new_width, None)
2831   x.set_shape(output_shape)
2832   return x
2833 
2834 
2835 @keras_export('keras.backend.resize_volumes')
2836 def resize_volumes(x, depth_factor, height_factor, width_factor, data_format):
2837   """Resizes the volume contained in a 5D tensor.
2838 
2839   Arguments:
2840       x: Tensor or variable to resize.
2841       depth_factor: Positive integer.
2842       height_factor: Positive integer.
2843       width_factor: Positive integer.
2844       data_format: One of `"channels_first"`, `"channels_last"`.
2845 
2846   Returns:
2847       A tensor.
2848 
2849   Raises:
2850       ValueError: if `data_format` is neither
2851           `channels_last` nor `channels_first`.
2852   """
2853   if data_format == 'channels_first':
2854     output = repeat_elements(x, depth_factor, axis=2)
2855     output = repeat_elements(output, height_factor, axis=3)
2856     output = repeat_elements(output, width_factor, axis=4)
2857     return output
2858   elif data_format == 'channels_last':
2859     output = repeat_elements(x, depth_factor, axis=1)
2860     output = repeat_elements(output, height_factor, axis=2)
2861     output = repeat_elements(output, width_factor, axis=3)
2862     return output
2863   else:
2864     raise ValueError('Invalid data_format: ' + str(data_format))
2865 
2866 
2867 @keras_export('keras.backend.repeat_elements')
2868 def repeat_elements(x, rep, axis):
2869   """Repeats the elements of a tensor along an axis, like `np.repeat`.
2870 
2871   If `x` has shape `(s1, s2, s3)` and `axis` is `1`, the output
2872   will have shape `(s1, s2 * rep, s3)`.
2873 
2874   Arguments:
2875       x: Tensor or variable.
2876       rep: Python integer, number of times to repeat.
2877       axis: Axis along which to repeat.
2878 
2879   Returns:
2880       A tensor.
2881 
2882   Example:
2883 
2884       >>> b = tf.constant([1, 2, 3])
2885       >>> tf.keras.backend.repeat_elements(b, rep=2, axis=0)
2886       <tf.Tensor: shape=(6,), dtype=int32,
2887           numpy=array([1, 1, 2, 2, 3, 3], dtype=int32)>
2888 
2889   """
2890   x_shape = x.shape.as_list()
2891   # For static axis
2892   if x_shape[axis] is not None:
2893     # slices along the repeat axis
2894     splits = array_ops.split(value=x,
2895                              num_or_size_splits=x_shape[axis],
2896                              axis=axis)
2897     # repeat each slice the given number of reps
2898     x_rep = [s for s in splits for _ in range(rep)]
2899     return concatenate(x_rep, axis)
2900 
2901   # Here we use tf.tile to mimic behavior of np.repeat so that
2902   # we can handle dynamic shapes (that include None).
2903   # To do that, we need an auxiliary axis to repeat elements along
2904   # it and then merge them along the desired axis.
2905 
2906   # Repeating
2907   auxiliary_axis = axis + 1
2908   x_shape = array_ops.shape(x)
2909   x_rep = array_ops.expand_dims(x, axis=auxiliary_axis)
2910   reps = np.ones(len(x.shape) + 1)
2911   reps[auxiliary_axis] = rep
2912   x_rep = array_ops.tile(x_rep, reps)
2913 
2914   # Merging
2915   reps = np.delete(reps, auxiliary_axis)
2916   reps[axis] = rep
2917   reps = array_ops.constant(reps, dtype='int32')
2918   x_shape *= reps
2919   x_rep = array_ops.reshape(x_rep, x_shape)
2920 
2921   # Fix shape representation
2922   x_shape = x.shape.as_list()
2923   x_rep.set_shape(x_shape)
2924   x_rep._keras_shape = tuple(x_shape)
2925   return x_rep
2926 
2927 
2928 @keras_export('keras.backend.repeat')
2929 def repeat(x, n):
2930   """Repeats a 2D tensor.
2931 
2932   If `x` has shape `(samples, dim)` and `n` is `2`,
2933   the output will have shape `(samples, 2, dim)`.
2934 
2935   Arguments:
2936       x: Tensor or variable.
2937       n: Python integer, number of times to repeat.
2938 
2939   Returns:
2940       A tensor.
2941 
2942   Example:
2943 
2944       >>> b = tf.constant([[1, 2], [3, 4]])
2945       >>> b
2946       <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
2947       array([[1, 2],
2948              [3, 4]], dtype=int32)>
2949       >>> tf.keras.backend.repeat(b, n=2)
2950       <tf.Tensor: shape=(2, 2, 2), dtype=int32, numpy=
2951       array([[[1, 2],
2952               [1, 2]],
2953              [[3, 4],
2954               [3, 4]]], dtype=int32)>
2955 
2956   """
2957   assert ndim(x) == 2
2958   x = array_ops.expand_dims(x, 1)
2959   pattern = array_ops.stack([1, n, 1])
2960   return array_ops.tile(x, pattern)
2961 
2962 
2963 @keras_export('keras.backend.arange')
2964 def arange(start, stop=None, step=1, dtype='int32'):
2965   """Creates a 1D tensor containing a sequence of integers.
2966 
2967   The function arguments use the same convention as
2968   Theano's arange: if only one argument is provided,
2969   it is in fact the "stop" argument and "start" is 0.
2970 
2971   The default type of the returned tensor is `'int32'` to
2972   match TensorFlow's default.
2973 
2974   Arguments:
2975       start: Start value.
2976       stop: Stop value.
2977       step: Difference between two successive values.
2978       dtype: Integer dtype to use.
2979 
2980   Returns:
2981       An integer tensor.
2982 
2983   Example:
2984 
2985       >>> tf.keras.backend.arange(start=0, stop=10, step=1.5)
2986       <tf.Tensor: shape=(7,), dtype=float32,
2987           numpy=array([0. , 1.5, 3. , 4.5, 6. , 7.5, 9. ], dtype=float32)>
2988 
2989 
2990 
2991   """
2992   # Match the behavior of numpy and Theano by returning an empty sequence.
2993   if stop is None and start < 0:
2994     start = 0
2995   result = math_ops.range(start, limit=stop, delta=step, name='arange')
2996   if dtype != 'int32':
2997     result = cast(result, dtype)
2998   return result
2999 
3000 
3001 @keras_export('keras.backend.tile')
3002 def tile(x, n):
3003   """Creates a tensor by tiling `x` by `n`.
3004 
3005   Arguments:
3006       x: A tensor or variable
3007       n: A list of integers. The length must be the same as the number of
3008           dimensions in `x`.
3009 
3010   Returns:
3011       A tiled tensor.
3012   """
3013   if isinstance(n, int):
3014     n = [n]
3015   return array_ops.tile(x, n)
3016 
3017 
3018 @keras_export('keras.backend.flatten')
3019 def flatten(x):
3020   """Flatten a tensor.
3021 
3022   Arguments:
3023       x: A tensor or variable.
3024 
3025   Returns:
3026       A tensor, reshaped into 1-D.
3027 
3028   Example:
3029 
3030       >>> b = tf.constant([[1, 2], [3, 4]])
3031       >>> b
3032       <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
3033       array([[1, 2],
3034              [3, 4]], dtype=int32)>
3035       >>> tf.keras.backend.flatten(b)
3036       <tf.Tensor: shape=(4,), dtype=int32,
3037           numpy=array([1, 2, 3, 4], dtype=int32)>
3038 
3039   """
3040   return array_ops.reshape(x, [-1])
3041 
3042 
3043 @keras_export('keras.backend.batch_flatten')
3044 def batch_flatten(x):
3045   """Turn a nD tensor into a 2D tensor with same 0th dimension.
3046 
3047   In other words, it flattens each data samples of a batch.
3048 
3049   Arguments:
3050       x: A tensor or variable.
3051 
3052   Returns:
3053       A tensor.
3054 
3055   Examples:
3056     Flattening a 4D tensor to 2D by collapsing all dimensions after the first.
3057 
3058   >>> x_batch = tf.keras.backend.ones(shape=(2, 3, 4, 5))
3059   >>> x_batch_flatten = batch_flatten(x_batch)
3060   >>> tf.keras.backend.int_shape(x_batch_flatten)
3061   (2, 60)
3062 
3063   """
3064   x = array_ops.reshape(x, array_ops.stack([-1, prod(shape(x)[1:])]))
3065   return x
3066 
3067 
3068 @keras_export('keras.backend.expand_dims')
3069 def expand_dims(x, axis=-1):
3070   """Adds a 1-sized dimension at index "axis".
3071 
3072   Arguments:
3073       x: A tensor or variable.
3074       axis: Position where to add a new axis.
3075 
3076   Returns:
3077       A tensor with expanded dimensions.
3078   """
3079   return array_ops.expand_dims(x, axis)
3080 
3081 
3082 @keras_export('keras.backend.squeeze')
3083 def squeeze(x, axis):
3084   """Removes a 1-dimension from the tensor at index "axis".
3085 
3086   Arguments:
3087       x: A tensor or variable.
3088       axis: Axis to drop.
3089 
3090   Returns:
3091       A tensor with the same data as `x` but reduced dimensions.
3092   """
3093   return array_ops.squeeze(x, [axis])
3094 
3095 
3096 @keras_export('keras.backend.temporal_padding')
3097 def temporal_padding(x, padding=(1, 1)):
3098   """Pads the middle dimension of a 3D tensor.
3099 
3100   Arguments:
3101       x: Tensor or variable.
3102       padding: Tuple of 2 integers, how many zeros to
3103           add at the start and end of dim 1.
3104 
3105   Returns:
3106       A padded 3D tensor.
3107   """
3108   assert len(padding) == 2
3109   pattern = [[0, 0], [padding[0], padding[1]], [0, 0]]
3110   return array_ops.pad(x, pattern)
3111 
3112 
3113 @keras_export('keras.backend.spatial_2d_padding')
3114 def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
3115   """Pads the 2nd and 3rd dimensions of a 4D tensor.
3116 
3117   Arguments:
3118       x: Tensor or variable.
3119       padding: Tuple of 2 tuples, padding pattern.
3120       data_format: One of `channels_last` or `channels_first`.
3121 
3122   Returns:
3123       A padded 4D tensor.
3124 
3125   Raises:
3126       ValueError: if `data_format` is neither
3127           `channels_last` nor `channels_first`.
3128   """
3129   assert len(padding) == 2
3130   assert len(padding[0]) == 2
3131   assert len(padding[1]) == 2
3132   if data_format is None:
3133     data_format = image_data_format()
3134   if data_format not in {'channels_first', 'channels_last'}:
3135     raise ValueError('Unknown data_format: ' + str(data_format))
3136 
3137   if data_format == 'channels_first':
3138     pattern = [[0, 0], [0, 0], list(padding[0]), list(padding[1])]
3139   else:
3140     pattern = [[0, 0], list(padding[0]), list(padding[1]), [0, 0]]
3141   return array_ops.pad(x, pattern)
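
# Shape sketch (illustrative comment; assumes eager execution and the default
# 'channels_last' image data format):
#
#   x = tf.keras.backend.ones((1, 28, 28, 3))
#   y = tf.keras.backend.spatial_2d_padding(x, padding=((2, 2), (1, 3)))
#   tf.keras.backend.int_shape(y)   # -> (1, 32, 32, 3)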
3142 
3143 
3144 @keras_export('keras.backend.spatial_3d_padding')
3145 def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None):
3146   """Pads 5D tensor with zeros along the depth, height, width dimensions.
3147 
3148   Pads these dimensions with `padding[0]`, `padding[1]` and `padding[2]`
3149   zeros on the left and right, respectively.
3150 
3151   For 'channels_last' data_format,
3152   the 2nd, 3rd and 4th dimension will be padded.
3153   For 'channels_first' data_format,
3154   the 3rd, 4th and 5th dimension will be padded.
3155 
3156   Arguments:
3157       x: Tensor or variable.
3158       padding: Tuple of 3 tuples, padding pattern.
3159       data_format: One of `channels_last` or `channels_first`.
3160 
3161   Returns:
3162       A padded 5D tensor.
3163 
3164   Raises:
3165       ValueError: if `data_format` is neither
3166           `channels_last` nor `channels_first`.
3167 
3168   """
3169   assert len(padding) == 3
3170   assert len(padding[0]) == 2
3171   assert len(padding[1]) == 2
3172   assert len(padding[2]) == 2
3173   if data_format is None:
3174     data_format = image_data_format()
3175   if data_format not in {'channels_first', 'channels_last'}:
3176     raise ValueError('Unknown data_format: ' + str(data_format))
3177 
3178   if data_format == 'channels_first':
3179     pattern = [[0, 0], [0, 0], [padding[0][0], padding[0][1]],
3180                [padding[1][0], padding[1][1]], [padding[2][0], padding[2][1]]]
3181   else:
3182     pattern = [[0, 0], [padding[0][0], padding[0][1]],
3183                [padding[1][0], padding[1][1]], [padding[2][0],
3184                                                 padding[2][1]], [0, 0]]
3185   return array_ops.pad(x, pattern)
3186 
3187 
3188 @keras_export('keras.backend.stack')
3189 def stack(x, axis=0):
3190   """Stacks a list of rank `R` tensors into a rank `R+1` tensor.
3191 
3192   Arguments:
3193       x: List of tensors.
3194       axis: Axis along which to perform stacking.
3195 
3196   Returns:
3197       A tensor.
3198 
3199   Example:
3200 
3201       >>> a = tf.constant([[1, 2],[3, 4]])
3202       >>> b = tf.constant([[10, 20],[30, 40]])
3203       >>> tf.keras.backend.stack((a, b))
3204       <tf.Tensor: shape=(2, 2, 2), dtype=int32, numpy=
3205       array([[[ 1,  2],
3206               [ 3,  4]],
3207              [[10, 20],
3208               [30, 40]]], dtype=int32)>
3209 
3210   """
3211   return array_ops.stack(x, axis=axis)
3212 
3213 
3214 @keras_export('keras.backend.one_hot')
3215 def one_hot(indices, num_classes):
3216   """Computes the one-hot representation of an integer tensor.
3217 
3218   Arguments:
3219       indices: nD integer tensor of shape
3220           `(batch_size, dim1, dim2, ... dim(n-1))`
3221       num_classes: Integer, number of classes to consider.
3222 
3223   Returns:
3224       The one-hot tensor: an (n + 1)D representation of the input
3225       with shape `(batch_size, dim1, dim2, ... dim(n-1), num_classes)`.
3229   """
3230   return array_ops.one_hot(indices, depth=num_classes, axis=-1)
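
# Usage sketch (illustrative comment; assumes eager execution):
#
#   indices = tf.constant([[0, 2], [1, 1]])
#   tf.keras.backend.one_hot(indices, num_classes=3)   # shape (2, 2, 3)
#   # The first row encodes classes 0 and 2: [[1., 0., 0.], [0., 0., 1.]]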
3231 
3232 
3233 @keras_export('keras.backend.reverse')
3234 def reverse(x, axes):
3235   """Reverse a tensor along the specified axes.
3236 
3237   Arguments:
3238       x: Tensor to reverse.
3239       axes: Integer or iterable of integers.
3240           Axes to reverse.
3241 
3242   Returns:
3243       A tensor.
3244   """
3245   if isinstance(axes, int):
3246     axes = [axes]
3247   return array_ops.reverse(x, axes)
3248 
3249 
3250 # VALUE MANIPULATION
3251 _VALUE_SET_CODE_STRING = """
3252   >>> K = tf.keras.backend  # Common keras convention
3253   >>> v = K.variable(1.)
3254 
3255   >>> # reassign
3256   >>> K.set_value(v, 2.)
3257   >>> print(K.get_value(v))
3258   2.0
3259 
3260   >>> # increment
3261   >>> K.set_value(v, K.get_value(v) + 1)
3262   >>> print(K.get_value(v))
3263   3.0
3264 
3265   Variable semantics in TensorFlow 2 are eager execution friendly. The above
3266   code is roughly equivalent to:
3267 
3268   >>> v = tf.Variable(1.)
3269 
3270   >>> _ = v.assign(2.)
3271   >>> print(v.numpy())
3272   2.0
3273 
3274   >>> _ = v.assign_add(1.)
3275   >>> print(v.numpy())
3276   3.0"""[3:]  # Prune first newline and indent to match the docstring template.
3277 
3278 
3279 @keras_export('keras.backend.get_value')
3280 def get_value(x):
3281   """Returns the value of a variable.
3282 
3283   `backend.get_value` is the complement of `backend.set_value`, and provides
3284   a generic interface for reading from variables while abstracting away the
3285   differences between TensorFlow 1.x and 2.x semantics.
3286 
3287   {snippet}
3288 
3289   Arguments:
3290       x: input variable.
3291 
3292   Returns:
3293       A Numpy array.
3294   """
3295   if not tensor_util.is_tensor(x):
3296     return x
3297   if context.executing_eagerly() or isinstance(x, ops.EagerTensor):
3298     return x.numpy()
3299   if not getattr(x, '_in_graph_mode', True):
3300     # This is a variable which was created in an eager context, but is being
3301     # evaluated from a Graph.
3302     with context.eager_mode():
3303       return x.numpy()
3304 
3305   if ops.executing_eagerly_outside_functions():
3306     # This method of evaluating works inside the Keras FuncGraph.
3307     return function([], x)(x)
3308 
3309   with x.graph.as_default():
3310     return x.eval(session=get_session((x,)))
3311 
3312 
3313 @keras_export('keras.backend.batch_get_value')
3314 def batch_get_value(tensors):
3315   """Returns the value of more than one tensor variable.
3316 
3317   Arguments:
3318       tensors: list of ops to run.
3319 
3320   Returns:
3321       A list of Numpy arrays.
3322 
3323   Raises:
3324       RuntimeError: If this method is called inside defun.
3325   """
3326   if context.executing_eagerly():
3327     return [x.numpy() for x in tensors]
3328   elif ops.inside_function():  # pylint: disable=protected-access
3329     raise RuntimeError('Cannot get value inside Tensorflow graph function.')
3330   if tensors:
3331     return get_session(tensors).run(tensors)
3332   else:
3333     return []
3334 
3335 
3336 @keras_export('keras.backend.set_value')
3337 def set_value(x, value):
3338   """Sets the value of a variable, from a Numpy array.
3339 
3340   `backend.set_value` is the complement of `backend.get_value`, and provides
3341   a generic interface for assigning to variables while abstracting away the
3342   differences between TensorFlow 1.x and 2.x semantics.
3343 
3344   {snippet}
3345 
3346   Arguments:
3347       x: Variable to set to a new value.
3348       value: Value to set the tensor to, as a Numpy array
3349           (of the same shape).
3350   """
3351   value = np.asarray(value, dtype=dtype(x))
3352   if ops.executing_eagerly_outside_functions():
3353     x.assign(value)
3354   else:
3355     with get_graph().as_default():
3356       tf_dtype = dtypes_module.as_dtype(x.dtype.name.split('_')[0])
3357       if hasattr(x, '_assign_placeholder'):
3358         assign_placeholder = x._assign_placeholder
3359         assign_op = x._assign_op
3360       else:
3361         # In order to support assigning weights to resizable variables in
3362         # Keras, we make a placeholder with the correct number of dimensions
3363         # but with None in each dimension. This way, we can assign weights
3364         # of any size (as long as they have the correct dimensionality).
3365         placeholder_shape = tensor_shape.TensorShape([None] * value.ndim)
3366         assign_placeholder = array_ops.placeholder(
3367             tf_dtype, shape=placeholder_shape)
3368         assign_op = x.assign(assign_placeholder)
3369         x._assign_placeholder = assign_placeholder
3370         x._assign_op = assign_op
3371       get_session().run(assign_op, feed_dict={assign_placeholder: value})
3372 
3373 
3374 @keras_export('keras.backend.batch_set_value')
3375 def batch_set_value(tuples):
3376   """Sets the values of many tensor variables at once.
3377 
3378   Arguments:
3379       tuples: a list of tuples `(tensor, value)`.
3380           `value` should be a Numpy array.
3381   """
3382   if ops.executing_eagerly_outside_functions():
3383     for x, value in tuples:
3384       x.assign(np.asarray(value, dtype=dtype(x)))
3385   else:
3386     with get_graph().as_default():
3387       if tuples:
3388         assign_ops = []
3389         feed_dict = {}
3390         for x, value in tuples:
3391           value = np.asarray(value, dtype=dtype(x))
3392           tf_dtype = dtypes_module.as_dtype(x.dtype.name.split('_')[0])
3393           if hasattr(x, '_assign_placeholder'):
3394             assign_placeholder = x._assign_placeholder
3395             assign_op = x._assign_op
3396           else:
3397             # In order to support assigning weights to resizable variables in
3398             # Keras, we make a placeholder with the correct number of dimensions
3399             # but with None in each dimension. This way, we can assign weights
3400             # of any size (as long as they have the correct dimensionality).
3401             placeholder_shape = tensor_shape.TensorShape([None] * value.ndim)
3402             assign_placeholder = array_ops.placeholder(
3403                 tf_dtype, shape=placeholder_shape)
3404             assign_op = x.assign(assign_placeholder)
3405             x._assign_placeholder = assign_placeholder
3406             x._assign_op = assign_op
3407           assign_ops.append(assign_op)
3408           feed_dict[assign_placeholder] = value
3409         get_session().run(assign_ops, feed_dict=feed_dict)
3410 
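
# Illustrative sketch (editor-added; not part of the original module and never
# called). Shows `batch_set_value` with `(variable, value)` pairs under TF 2.x
# eager execution. The helper name below is hypothetical.
def _example_batch_set_value():
  w = variables_module.Variable([0., 0.])
  b = variables_module.Variable(0.)
  # Assign new values to several variables in a single call.
  batch_set_value([(w, np.array([1., 2.])), (b, np.array(3.))])
  return batch_get_value([w, b])  # [array([1., 2.], dtype=float32), 3.0]
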
3411 
3412 get_value.__doc__ = get_value.__doc__.format(snippet=_VALUE_SET_CODE_STRING)
3413 set_value.__doc__ = set_value.__doc__.format(snippet=_VALUE_SET_CODE_STRING)
3414 
3415 
3416 @keras_export('keras.backend.print_tensor')
3417 def print_tensor(x, message=''):
3418   """Prints `message` and the tensor value when evaluated.
3419 
3420   Note that `print_tensor` returns a new tensor identical to `x`,
3421   which should be used in the following code. Otherwise, the
3422   print operation is not taken into account during evaluation.
3423 
3424   Example:
3425 
3426   >>> x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
3427   >>> tf.keras.backend.print_tensor(x)
3428   <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
3429     array([[1., 2.],
3430            [3., 4.]], dtype=float32)>
3431 
3432   Arguments:
3433       x: Tensor to print.
3434       message: Message to print jointly with the tensor.
3435 
3436   Returns:
3437       The same tensor `x`, unchanged.
3438   """
3439   if isinstance(x, ops.Tensor) and hasattr(x, 'graph'):
3440     with get_graph().as_default():
3441       op = logging_ops.print_v2(message, x, output_stream=sys.stdout)
3442       with ops.control_dependencies([op]):
3443         return array_ops.identity(x)
3444   else:
3445     logging_ops.print_v2(message, x, output_stream=sys.stdout)
3446     return x
3447 
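
# Illustrative sketch (editor-added; not part of the original module and never
# called). The tensor returned by `print_tensor` has to be consumed downstream,
# otherwise the print op may be pruned in graph mode.
def _example_print_tensor():
  x = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
  x = print_tensor(x, message='x is: ')
  # Using the returned tensor keeps the print op in the computation.
  return math_ops.reduce_sum(x)

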
3448 # GRAPH MANIPULATION
3449 
3450 
3451 class GraphExecutionFunction(object):
3452   """Runs a computation graph.
3453 
3454   It's possible to pass arguments to `tf.Session.run()` via `session_kwargs`.
3455   In particular, additional operations can be passed via the `fetches`
3456   argument, and additional tensor substitutions via the `feed_dict` argument.
3457   Note that the given substitutions are merged with substitutions from
3458   `inputs`. Even though `feed_dict` is passed once in the constructor (called
3459   in `model.compile()`), the values in the dictionary can be modified later to
3460   provide additional substitutions besides Keras inputs.
3461 
3462   Arguments:
3463       inputs: Feed placeholders to the computation graph.
3464       outputs: Output tensors to fetch.
3465       updates: Additional update ops to be run at function call.
3466       name: A name to help users identify what this function does.
3467       session_kwargs: Arguments to `tf.Session.run()`:
3468                       `fetches`, `feed_dict`, `options`, `run_metadata`.
3469   """
3470 
3471   def __init__(self, inputs, outputs, updates=None, name=None,
3472                **session_kwargs):
3473     updates = updates or []
3474     if not isinstance(updates, (list, tuple)):
3475       raise TypeError('`updates` in a Keras backend function '
3476                       'should be a list or tuple.')
3477 
3478     self._inputs_structure = inputs
3479     self.inputs = nest.flatten(inputs, expand_composites=True)
3480     self._outputs_structure = outputs
3481     self.outputs = cast_variables_to_tensor(
3482         nest.flatten(outputs, expand_composites=True))
3483     # TODO(b/127668432): Consider using autograph to generate these
3484     # dependencies in call.
3485     # Index 0 = total loss or model output for `predict`.
3486     with ops.control_dependencies([self.outputs[0]]):
3487       updates_ops = []
3488       for update in updates:
3489         if isinstance(update, tuple):
3490           p, new_p = update
3491           updates_ops.append(state_ops.assign(p, new_p))
3492         else:
3493           # assumed already an op
3494           updates_ops.append(update)
3495       self.updates_op = control_flow_ops.group(*updates_ops)
3496     self.name = name
3497     # additional tensor substitutions
3498     self.feed_dict = session_kwargs.pop('feed_dict', None)
3499     # additional operations
3500     self.fetches = session_kwargs.pop('fetches', [])
3501     if not isinstance(self.fetches, list):
3502       self.fetches = [self.fetches]
3503     self.run_options = session_kwargs.pop('options', None)
3504     self.run_metadata = session_kwargs.pop('run_metadata', None)
3505     # The main use case of `fetches` being passed to a model is the ability
3506     # to run custom updates
3507     # This requires us to wrap fetches in `identity` ops.
3508     self.fetches = [array_ops.identity(x) for x in self.fetches]
3509     self.session_kwargs = session_kwargs
3510     # This mapping keeps track of the function that should receive the
3511     # output from a fetch in `fetches`: { fetch: function(fetch_output) }
3512     # A Callback can use this to register a function with access to the
3513     # output values for a fetch it added.
3514     self.fetch_callbacks = {}
3515 
3516     if session_kwargs:
3517       raise ValueError('Some keys in session_kwargs are not supported at this '
3518                        'time: %s' % (session_kwargs.keys(),))
3519 
3520     self._callable_fn = None
3521     self._feed_arrays = None
3522     self._feed_symbols = None
3523     self._symbol_vals = None
3524     self._fetches = None
3525     self._session = None
3526 
3527   def _make_callable(self, feed_arrays, feed_symbols, symbol_vals, session):
3528     """Generates a callable that runs the graph.
3529 
3530     Arguments:
3531       feed_arrays: List of input tensors to be fed Numpy arrays at runtime.
3532       feed_symbols: List of input tensors to be fed symbolic tensors at runtime.
3533       symbol_vals: List of symbolic tensors to be fed to `feed_symbols`.
3534       session: Session to use to generate the callable.
3535 
3536     Returns:
3537       Function that runs the graph according to the above options.
3538     """
3539     # Prepare callable options.
3540     callable_opts = config_pb2.CallableOptions()
3541     # Handle external-data feed.
3542     for x in feed_arrays:
3543       callable_opts.feed.append(x.name)
3544     if self.feed_dict:
3545       for key in sorted(self.feed_dict.keys()):
3546         callable_opts.feed.append(key.name)
3547     # Handle symbolic feed.
3548     for x, y in zip(feed_symbols, symbol_vals):
3549       connection = callable_opts.tensor_connection.add()
3550       if x.dtype != y.dtype:
3551         y = math_ops.cast(y, dtype=x.dtype)
3552       from_tensor = ops._as_graph_element(y)
3553       if from_tensor is None:
3554         from_tensor = y
3555       connection.from_tensor = from_tensor.name  # Data tensor
3556       connection.to_tensor = x.name  # Placeholder
3557     # Handle fetches.
3558     for x in self.outputs + self.fetches:
3559       callable_opts.fetch.append(x.name)
3560     # Handle updates.
3561     callable_opts.target.append(self.updates_op.name)
3562     # Handle run_options.
3563     if self.run_options:
3564       callable_opts.run_options.CopyFrom(self.run_options)
3565     # Create callable.
3566     callable_fn = session._make_callable_from_options(callable_opts)
3567     # Cache parameters corresponding to the generated callable, so that
3568     # we can detect future mismatches and refresh the callable.
3569     self._callable_fn = callable_fn
3570     self._feed_arrays = feed_arrays
3571     self._feed_symbols = feed_symbols
3572     self._symbol_vals = symbol_vals
3573     self._fetches = list(self.fetches)
3574     self._session = session
3575 
3576   def _call_fetch_callbacks(self, fetches_output):
3577     for fetch, output in zip(self._fetches, fetches_output):
3578       if fetch in self.fetch_callbacks:
3579         self.fetch_callbacks[fetch](output)
3580 
3581   def _eval_if_composite(self, tensor):
3582     """Helper method which evaluates any CompositeTensors passed to it."""
3583     # We need to evaluate any composite tensor objects that have been
3584     # reconstructed in 'pack_sequence_as', since otherwise they'll be output as
3585     # actual CompositeTensor objects instead of the value(s) contained in the
3586     # CompositeTensors. E.g., if output_structure contains a SparseTensor, then
3587     # this ensures that we return its value as a SparseTensorValue rather than
3588     # a SparseTensor.
3589     if isinstance(tensor, composite_tensor.CompositeTensor):
3590       return self._session.run(tensor)
3591     else:
3592       return tensor
3593 
3594   def __call__(self, inputs):
3595     inputs = nest.flatten(inputs, expand_composites=True)
3596 
3597     session = get_session(inputs)
3598     feed_arrays = []
3599     array_vals = []
3600     feed_symbols = []
3601     symbol_vals = []
3602     for tensor, value in zip(self.inputs, inputs):
3603       if value is None:
3604         continue
3605 
3606       if tensor_util.is_tensor(value):
3607         # Case: feeding symbolic tensor.
3608         feed_symbols.append(tensor)
3609         symbol_vals.append(value)
3610       else:
3611         # Case: feeding Numpy array.
3612         feed_arrays.append(tensor)
3613         # We need to do array conversion and type casting at this level, since
3614         # `callable_fn` only supports exact matches.
3615         tensor_type = dtypes_module.as_dtype(tensor.dtype)
3616         array_vals.append(np.asarray(value,
3617                                      dtype=tensor_type.as_numpy_dtype))
3618 
3619     if self.feed_dict:
3620       for key in sorted(self.feed_dict.keys()):
3621         array_vals.append(
3622             np.asarray(self.feed_dict[key], dtype=key.dtype.base_dtype.name))
3623 
3624     # Refresh callable if anything has changed.
3625     if (self._callable_fn is None or feed_arrays != self._feed_arrays or
3626         symbol_vals != self._symbol_vals or
3627         feed_symbols != self._feed_symbols or self.fetches != self._fetches or
3628         session != self._session):
3629       self._make_callable(feed_arrays, feed_symbols, symbol_vals, session)
3630 
3631     fetched = self._callable_fn(*array_vals,
3632                                 run_metadata=self.run_metadata)
3633     self._call_fetch_callbacks(fetched[-len(self._fetches):])
3634     output_structure = nest.pack_sequence_as(
3635         self._outputs_structure,
3636         fetched[:len(self.outputs)],
3637         expand_composites=True)
3638     # We need to evaluate any composite tensor objects that have been
3639     # reconstructed in 'pack_sequence_as', since otherwise they'll be output as
3640     # actual CompositeTensor objects instead of the value(s) contained in the
3641     # CompositeTensors. E.g., if output_structure contains a SparseTensor, then
3642     # this ensures that we return its value as a SparseTensorValue rather than
3643     # a SparseTensor.
3644     return nest.map_structure(self._eval_if_composite, output_structure)
3645 
3646 
3647 class EagerExecutionFunction(object):
3648   """Helper class for constructing a TF graph function from the Keras graph.
3649 
3650   Arguments:
3651     inputs: Feed placeholders to the computation graph.
3652     outputs: Output tensors to fetch.
3653     updates: Additional update ops to be run at function call.
3654     name: A name to help users identify what this function does.
3655     session_kwargs: Unsupported.
3656   """
3657 
3658   def __init__(self, inputs, outputs, updates=None, name=None):
3659     self.name = name
3660     self._inputs_structure = inputs
3661     inputs = nest.flatten(inputs, expand_composites=True)
3662     self._outputs_structure = outputs
3663     outputs = nest.flatten(outputs, expand_composites=True)
3664 
3665     updates = updates or []
3666     if not isinstance(updates, (list, tuple)):
3667       raise TypeError('`updates` in a Keras backend function '
3668                       'should be a list or tuple.')
3669 
3670     if updates and not outputs:
3671       # Edge case; never happens in practice
3672       raise ValueError('Cannot create a Keras backend function with updates'
3673                        ' but no outputs during eager execution.')
3674     graphs = {
3675         i.graph
3676         for i in nest.flatten([inputs, outputs, updates])
3677         if hasattr(i, 'graph')
3678     }
3679     if len(graphs) > 1:
3680       raise ValueError('Cannot create an execution function which is composed '
3681                        'of elements from multiple graphs.')
3682 
3683     source_graph = graphs.pop()
3684     global_graph = get_graph()
3685 
3686     updates_ops = []
3687     legacy_update_ops = []
3688     for update in updates:
3689       # For legacy reasons it is allowed to pass an update as a tuple
3690       # `(variable, new_value)` (this maps to an assign op). Otherwise it
3691       # is assumed to already be an op -- we cannot control its execution
3692       # order.
3693       if isinstance(update, tuple):
3694         legacy_update_ops.append(update)
3695       else:
3696         if hasattr(update, 'op'):
3697           update = update.op
3698         if update is not None:
3699           # `update.op` may have been None in certain cases.
3700           updates_ops.append(update)
3701 
3702     self._freezable_vars_to_feed = []
3703     self._freezable_vars_values = []
3704     freezable_vars_from_keras_graph = object_identity.ObjectIdentitySet(
3705         _FREEZABLE_VARS.get(global_graph, {}))
3706     with _scratch_graph() as exec_graph:
3707       global_graph = get_graph()
3708       if source_graph not in (exec_graph, global_graph):
3709         raise ValueError('Unknown graph. Aborting.')
3710 
3711       if source_graph is global_graph and exec_graph is not global_graph:
3712         init_tensors = (
3713             outputs + updates_ops + [p for [p, _] in legacy_update_ops] +
3714             [p_new for [_, p_new] in legacy_update_ops
3715              if isinstance(p_new, ops.Tensor)])
3716         lifted_map = lift_to_graph.lift_to_graph(
3717             tensors=init_tensors,
3718             graph=exec_graph,
3719             sources=inputs,
3720             add_sources=True,
3721             handle_captures=True,
3722             base_graph=source_graph)
3723 
3724         inputs = [lifted_map[i] for i in inputs]
3725         outputs = [lifted_map[i] for i in outputs]
3726         updates_ops = [lifted_map[i] for i in updates_ops]
3727         legacy_update_ops = [(lifted_map[p], lifted_map.get(p_new, p_new))
3728                              for p, p_new in legacy_update_ops]
3729 
3730         # Keep track of the value to feed to any "freezable variables"
3731         # created in this graph.
3732         for old_op, new_op in lifted_map.items():
3733           if old_op in freezable_vars_from_keras_graph:
3734             frozen_var = old_op
3735             if frozen_var._initial_value != frozen_var._current_value:
3736               # We only feed a frozen_variable if its value has changed;
3737               # otherwise it can rely on the default value of the
3738               # underlying placeholder_with_default.
3739               self._freezable_vars_to_feed.append(new_op)
3740               self._freezable_vars_values.append(frozen_var._current_value)
3741 
3742     # Consolidate updates
3743     with exec_graph.as_default():
3744       outputs = cast_variables_to_tensor(outputs)
3745       with ops.control_dependencies(outputs):
3746         for p, p_new in legacy_update_ops:
3747           updates_ops.append(state_ops.assign(p, p_new))
3748 
3749       self.inputs, self.outputs = inputs, outputs
3750       self._input_references = self.inputs + self._freezable_vars_to_feed
3751       with ops.control_dependencies(updates_ops):
3752         self.outputs[0] = array_ops.identity(self.outputs[0])
3753 
3754       exec_graph.inputs = self._input_references + exec_graph.internal_captures
3755       exec_graph.outputs = self.outputs
3756       graph_fn = eager_function.ConcreteFunction(exec_graph)
3757 
3758     graph_fn._num_positional_args = len(self._input_references)
3759     graph_fn._arg_keywords = []
3760     self._graph_fn = graph_fn
3761 
3762     # Handle placeholders with default
3763     # (treated as required placeholder by graph functions)
3764     self._placeholder_default_values = {}
3765     with exec_graph.as_default():
3766       for x in self.inputs:
3767         if x.op.type == 'PlaceholderWithDefault':
3768           self._placeholder_default_values[ops.tensor_id(
3769               x)] = tensor_util.constant_value(x.op.inputs[0])
3770 
3771   def __call__(self, inputs):
3772     input_values = nest.flatten(inputs, expand_composites=True)
3773 
3774     if self._freezable_vars_values:
3775       input_values = input_values + self._freezable_vars_values
3776     converted_inputs = []
3777     for tensor, value in zip(self._input_references, input_values):
3778       if value is None:
3779         # Assume `value` is a placeholder with default
3780         value = self._placeholder_default_values.get(
3781             ops.tensor_id(tensor), None)
3782         if value is None:
3783           raise ValueError(
3784               'You must feed a value for placeholder %s' % (tensor,))
3785       if not isinstance(value, ops.Tensor):
3786         value = ops.convert_to_tensor(value, dtype=tensor.dtype)
3787       if value.dtype != tensor.dtype:
3788         # Temporary workaround due to `convert_to_tensor` not casting floats.
3789         # See b/119637405
3790         value = math_ops.cast(value, tensor.dtype)
3791       converted_inputs.append(value)
3792     outputs = self._graph_fn(*converted_inputs)
3793 
3794     # EagerTensor.numpy() will often make a copy to ensure memory safety.
3795     # However in this case `outputs` is not directly returned, so it is always
3796     # safe to reuse the underlying buffer without checking. In such a case the
3797     # private numpy conversion method is preferred to guarantee performance.
3798     return nest.pack_sequence_as(
3799         self._outputs_structure,
3800         [x._numpy() for x in outputs],  # pylint: disable=protected-access
3801         expand_composites=True)
3802 
3803 
3804 @keras_export('keras.backend.function')
3805 def function(inputs, outputs, updates=None, name=None, **kwargs):
3806   """Instantiates a Keras function.
3807 
3808   Arguments:
3809       inputs: List of placeholder tensors.
3810       outputs: List of output tensors.
3811       updates: List of update ops.
3812       name: String, name of function.
3813       **kwargs: Passed to `tf.Session.run`.
3814 
3815   Returns:
3816       Output values as Numpy arrays.
3817 
3818   Raises:
3819       ValueError: if invalid kwargs are passed in or if in eager execution.
3820   """
3821   if ops.executing_eagerly_outside_functions():
3822     if kwargs:
3823       raise ValueError('Session keyword arguments are not supported during '
3824                        'eager execution. You passed: %s' % (kwargs,))
3825     return EagerExecutionFunction(inputs, outputs, updates=updates, name=name)
3826 
3827   if kwargs:
3828     for key in kwargs:
3829       if (key not in tf_inspect.getfullargspec(session_module.Session.run)[0]
3830           and key not in ['inputs', 'outputs', 'updates', 'name']):
3831         msg = ('Invalid argument "%s" passed to K.function with TensorFlow '
3832                'backend') % key
3833         raise ValueError(msg)
3834   return GraphExecutionFunction(inputs, outputs, updates=updates, **kwargs)
3835 
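
# Illustrative sketch (editor-added; not part of the original module and never
# called). Under TF 2.x this builds an EagerExecutionFunction; in TF 1.x graph
# mode it builds a GraphExecutionFunction. The helper name is hypothetical.
def _example_backend_function():
  with get_graph().as_default():
    x = placeholder(shape=(None, 3))
    y = x * 2.0
  f = function([x], [y])
  # Feeding Numpy data returns Numpy outputs.
  return f([np.ones((2, 3))])
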
3836 
3837 @keras_export('keras.backend.gradients')
3838 def gradients(loss, variables):
3839   """Returns the gradients of `loss` w.r.t. `variables`.
3840 
3841   Arguments:
3842       loss: Scalar tensor to minimize.
3843       variables: List of variables.
3844 
3845   Returns:
3846       A gradients tensor.
3847   """
3848   return gradients_module.gradients(
3849       loss, variables, colocate_gradients_with_ops=True)
3850 
3851 
3852 @keras_export('keras.backend.stop_gradient')
3853 def stop_gradient(variables):
3854   """Returns `variables` but with zero gradient w.r.t. every other variable.
3855 
3856   Arguments:
3857       variables: Tensor or list of tensors to consider constant with respect
3858         to any other variable.
3859 
3860 
3861   Returns:
3862       A single tensor or a list of tensors (depending on the passed argument)
3863       that has no gradient with respect to any other variable.
3864   """
3865   if isinstance(variables, (list, tuple)):
3866     return [array_ops.stop_gradient(v) for v in variables]
3867   return array_ops.stop_gradient(variables)
3868 
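
# Illustrative sketch (editor-added; not part of the original module and never
# called). `stop_gradient` blocks backpropagation through the given tensors.
def _example_stop_gradient():
  x = constant_op.constant(3.0)
  y = stop_gradient(x * 2.0)
  # Gradients of anything computed from `y` will not flow back into `x`.
  return y
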
3869 
3870 # CONTROL FLOW
3871 
3872 
3873 @keras_export('keras.backend.rnn')
3874 def rnn(step_function,
3875         inputs,
3876         initial_states,
3877         go_backwards=False,
3878         mask=None,
3879         constants=None,
3880         unroll=False,
3881         input_length=None,
3882         time_major=False,
3883         zero_output_for_mask=False):
3884   """Iterates over the time dimension of a tensor.
3885 
3886   Arguments:
3887       step_function: RNN step function.
3888           Args;
3889               input; Tensor with shape `(samples, ...)` (no time dimension),
3890                   representing input for the batch of samples at a certain
3891                   time step.
3892               states; List of tensors.
3893           Returns;
3894               output; Tensor with shape `(samples, output_dim)`
3895                   (no time dimension).
3896               new_states; List of tensors, same length and shapes
3897                   as 'states'. The first state in the list must be the
3898                   output tensor at the previous timestep.
3899       inputs: Tensor of temporal data of shape `(samples, time, ...)`
3900           (at least 3D), or nested tensors, each of which has shape
3901           `(samples, time, ...)`.
3902       initial_states: Tensor with shape `(samples, state_size)`
3903           (no time dimension), containing the initial values for the states used
3904           in the step function. In the case that state_size is in a nested
3905           shape, the shape of initial_states will also follow the nested
3906           structure.
3907       go_backwards: Boolean. If True, do the iteration over the time
3908           dimension in reverse order and return the reversed sequence.
3909       mask: Binary tensor with shape `(samples, time, 1)`,
3910           with a zero for every element that is masked.
3911       constants: List of constant values passed at each step.
3912       unroll: Whether to unroll the RNN or to use a symbolic `while_loop`.
3913       input_length: An integer or a 1-D Tensor, depending on whether
3914           the time dimension is fixed-length or not. In case of variable length
3915           input, it is used for masking when no mask is specified.
3916       time_major: Boolean. If true, the inputs and outputs will be in shape
3917           `(timesteps, batch, ...)`, whereas in the False case, it will be
3918           `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
3919           efficient because it avoids transposes at the beginning and end of the
3920           RNN calculation. However, most TensorFlow data is batch-major, so by
3921           default this function accepts input and emits output in batch-major
3922           form.
3923       zero_output_for_mask: Boolean. If True, the output for masked timestep
3924           will be zeros, whereas in the False case, output from previous
3925           timestep is returned.
3926 
3927   Returns:
3928       A tuple, `(last_output, outputs, new_states)`.
3929           last_output: the latest output of the rnn, of shape `(samples, ...)`
3930           outputs: tensor with shape `(samples, time, ...)` where each
3931               entry `outputs[s, t]` is the output of the step function
3932               at time `t` for sample `s`.
3933           new_states: list of tensors, latest states returned by
3934               the step function, of shape `(samples, ...)`.
3935 
3936   Raises:
3937       ValueError: if input dimension is less than 3.
3938       ValueError: if `unroll` is `True` but input timestep is not a fixed
3939           number.
3940       ValueError: if `mask` is provided (not `None`) but states is not provided
3941           (`len(states)` == 0).
3942   """
3943 
3944   def swap_batch_timestep(input_t):
3945     # Swap the batch and timestep dim for the incoming tensor.
3946     axes = list(range(len(input_t.shape)))
3947     axes[0], axes[1] = 1, 0
3948     return array_ops.transpose(input_t, axes)
3949 
3950   if not time_major:
3951     inputs = nest.map_structure(swap_batch_timestep, inputs)
3952 
3953   flatted_inputs = nest.flatten(inputs)
3954   time_steps = flatted_inputs[0].shape[0]
3955   batch = flatted_inputs[0].shape[1]
3956   time_steps_t = array_ops.shape(flatted_inputs[0])[0]
3957 
3958   for input_ in flatted_inputs:
3959     input_.shape.with_rank_at_least(3)
3960 
3961   if mask is not None:
3962     if mask.dtype != dtypes_module.bool:
3963       mask = math_ops.cast(mask, dtypes_module.bool)
3964     if len(mask.shape) == 2:
3965       mask = expand_dims(mask)
3966     if not time_major:
3967       mask = swap_batch_timestep(mask)
3968 
3969   if constants is None:
3970     constants = []
3971 
3972   # tf.where needs its condition tensor to be the same shape as its two
3973   # result tensors, but in our case the condition (mask) tensor is
3974   # (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
3975   # So we need to broadcast the mask to match the shape of inputs.
3976   # That's what the tile call does; it simply repeats the mask along its
3977   # second dimension n times.
3978   def _expand_mask(mask_t, input_t, fixed_dim=1):
3979     if nest.is_sequence(mask_t):
3980       raise ValueError('mask_t is expected to be tensor, but got %s' % mask_t)
3981     if nest.is_sequence(input_t):
3982       raise ValueError('input_t is expected to be tensor, but got %s' % input_t)
3983     rank_diff = len(input_t.shape) - len(mask_t.shape)
3984     for _ in range(rank_diff):
3985       mask_t = array_ops.expand_dims(mask_t, -1)
3986     multiples = [1] * fixed_dim + input_t.shape.as_list()[fixed_dim:]
3987     return array_ops.tile(mask_t, multiples)
3988 
3989   if unroll:
3990     if not time_steps:
3991       raise ValueError('Unrolling requires a fixed number of timesteps.')
3992     states = tuple(initial_states)
3993     successive_states = []
3994     successive_outputs = []
3995 
3996     # Process the input tensors. The input tensors need to be split on the
3997     # time_step dim, and reversed if go_backwards is True. In the case of nested
3998     # input, the input is flattened and then transformed individually.
3999     # The result of this will be a tuple of lists; each item in the tuple is a
4000     # list of tensors with shape (batch, feature).
4001     def _process_single_input_t(input_t):
4002       input_t = array_ops.unstack(input_t)  # unstack for time_step dim
4003       if go_backwards:
4004         input_t.reverse()
4005       return input_t
4006 
4007     if nest.is_sequence(inputs):
4008       processed_input = nest.map_structure(_process_single_input_t, inputs)
4009     else:
4010       processed_input = (_process_single_input_t(inputs),)
4011 
4012     def _get_input_tensor(time):
4013       inp = [t_[time] for t_ in processed_input]
4014       return nest.pack_sequence_as(inputs, inp)
4015 
4016     if mask is not None:
4017       mask_list = array_ops.unstack(mask)
4018       if go_backwards:
4019         mask_list.reverse()
4020 
4021       for i in range(time_steps):
4022         inp = _get_input_tensor(i)
4023         mask_t = mask_list[i]
4024         output, new_states = step_function(inp,
4025                                            tuple(states) + tuple(constants))
4026         tiled_mask_t = _expand_mask(mask_t, output)
4027 
4028         if not successive_outputs:
4029           prev_output = zeros_like(output)
4030         else:
4031           prev_output = successive_outputs[-1]
4032 
4033         output = array_ops.where_v2(tiled_mask_t, output, prev_output)
4034 
4035         flat_states = nest.flatten(states)
4036         flat_new_states = nest.flatten(new_states)
4037         tiled_mask_t = tuple(_expand_mask(mask_t, s) for s in flat_states)
4038         flat_final_states = tuple(
4039             array_ops.where_v2(m, s, ps)
4040             for m, s, ps in zip(tiled_mask_t, flat_new_states, flat_states))
4041         states = nest.pack_sequence_as(states, flat_final_states)
4042 
4043         successive_outputs.append(output)
4044         successive_states.append(states)
4045       last_output = successive_outputs[-1]
4046       new_states = successive_states[-1]
4047       outputs = array_ops.stack(successive_outputs)
4048 
4049       if zero_output_for_mask:
4050         last_output = array_ops.where_v2(
4051             _expand_mask(mask_list[-1], last_output), last_output,
4052             zeros_like(last_output))
4053         outputs = array_ops.where_v2(
4054             _expand_mask(mask, outputs, fixed_dim=2), outputs,
4055             zeros_like(outputs))
4056 
4057     else:  # mask is None
4058       for i in range(time_steps):
4059         inp = _get_input_tensor(i)
4060         output, states = step_function(inp, tuple(states) + tuple(constants))
4061         successive_outputs.append(output)
4062         successive_states.append(states)
4063       last_output = successive_outputs[-1]
4064       new_states = successive_states[-1]
4065       outputs = array_ops.stack(successive_outputs)
4066 
4067   else:  # Unroll == False
4068     states = tuple(initial_states)
4069 
4070     # Create the input TensorArray. If the input is a nested structure of
4071     # tensors, it will be flattened first, and one TensorArray will be created
4072     # per flattened tensor.
4073     input_ta = tuple(
4074         tensor_array_ops.TensorArray(
4075             dtype=inp.dtype,
4076             size=time_steps_t,
4077             tensor_array_name='input_ta_%s' % i)
4078         for i, inp in enumerate(flatted_inputs))
4079     input_ta = tuple(
4080         ta.unstack(input_) if not go_backwards else ta
4081         .unstack(reverse(input_, 0))
4082         for ta, input_ in zip(input_ta, flatted_inputs))
4083 
4084     # Get the time(0) input and compute the output for that; the output will be
4085     # used to determine the dtype of the output TensorArray. Don't read from
4086     # input_ta because TensorArray's clear_after_read defaults to True.
4087     input_time_zero = nest.pack_sequence_as(inputs,
4088                                             [inp[0] for inp in flatted_inputs])
4089     # output_time_zero is used to determine the cell output shape and its dtype.
4090     # The value itself is discarded.
4091     output_time_zero, _ = step_function(
4092         input_time_zero, tuple(initial_states) + tuple(constants))
4093     output_ta = tuple(
4094         tensor_array_ops.TensorArray(
4095             dtype=out.dtype,
4096             size=time_steps_t,
4097             element_shape=out.shape,
4098             tensor_array_name='output_ta_%s' % i)
4099         for i, out in enumerate(nest.flatten(output_time_zero)))
4100 
4101     time = constant_op.constant(0, dtype='int32', name='time')
4102 
4103     # We only specify 'maximum_iterations' when building for XLA, since
4104     # specifying it causes slowdowns on GPU in TF.
4105     if (not context.executing_eagerly() and
4106         control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph())):
4107       max_iterations = math_ops.reduce_max(input_length)
4108     else:
4109       max_iterations = None
4110 
4111     while_loop_kwargs = {
4112         'cond': lambda time, *_: time < time_steps_t,
4113         'maximum_iterations': max_iterations,
4114         'parallel_iterations': 32,
4115         'swap_memory': True,
4116     }
4117     if mask is not None:
4118       if go_backwards:
4119         mask = reverse(mask, 0)
4120 
4121       mask_ta = tensor_array_ops.TensorArray(
4122           dtype=dtypes_module.bool,
4123           size=time_steps_t,
4124           tensor_array_name='mask_ta')
4125       mask_ta = mask_ta.unstack(mask)
4126 
4127       def masking_fn(time):
4128         return mask_ta.read(time)
4129 
4130       def compute_masked_output(mask_t, flat_out, flat_mask):
4131         tiled_mask_t = tuple(
4132             _expand_mask(mask_t, o, fixed_dim=len(mask_t.shape))
4133             for o in flat_out)
4134         return tuple(
4135             array_ops.where_v2(m, o, fm)
4136             for m, o, fm in zip(tiled_mask_t, flat_out, flat_mask))
4137     elif isinstance(input_length, ops.Tensor):
4138       if go_backwards:
4139         max_len = math_ops.reduce_max(input_length, axis=0)
4140         rev_input_length = math_ops.subtract(max_len - 1, input_length)
4141 
4142         def masking_fn(time):
4143           return math_ops.less(rev_input_length, time)
4144       else:
4145 
4146         def masking_fn(time):
4147           return math_ops.greater(input_length, time)
4148 
4149       def compute_masked_output(mask_t, flat_out, flat_mask):
4150         return tuple(
4151             array_ops.where(mask_t, o, zo)
4152             for (o, zo) in zip(flat_out, flat_mask))
4153     else:
4154       masking_fn = None
4155 
4156     if masking_fn is not None:
4157       # The mask for the output at time T is based on the output at time T - 1.
4158       # In the case T = 0, a zero-filled tensor is used.
4159       flat_zero_output = tuple(array_ops.zeros_like(o)
4160                                for o in nest.flatten(output_time_zero))
4161       def _step(time, output_ta_t, prev_output, *states):
4162         """RNN step function.
4163 
4164         Arguments:
4165             time: Current timestep value.
4166             output_ta_t: TensorArray.
4167             prev_output: tuple of outputs from time - 1.
4168             *states: List of states.
4169 
4170         Returns:
4171             Tuple: `(time + 1, output_ta_t, output) + tuple(new_states)`
4172         """
4173         current_input = tuple(ta.read(time) for ta in input_ta)
4174         # maybe set shape.
4175         current_input = nest.pack_sequence_as(inputs, current_input)
4176         mask_t = masking_fn(time)
4177         output, new_states = step_function(current_input,
4178                                            tuple(states) + tuple(constants))
4179         # mask output
4180         flat_output = nest.flatten(output)
4181         flat_mask_output = (flat_zero_output if zero_output_for_mask
4182                             else nest.flatten(prev_output))
4183         flat_new_output = compute_masked_output(mask_t, flat_output,
4184                                                 flat_mask_output)
4185 
4186         # mask states
4187         flat_state = nest.flatten(states)
4188         flat_new_state = nest.flatten(new_states)
4189         for state, new_state in zip(flat_state, flat_new_state):
4190           if isinstance(new_state, ops.Tensor):
4191             new_state.set_shape(state.shape)
4192         flat_final_state = compute_masked_output(mask_t, flat_new_state,
4193                                                  flat_state)
4194         new_states = nest.pack_sequence_as(new_states, flat_final_state)
4195 
4196         output_ta_t = tuple(
4197             ta.write(time, out)
4198             for ta, out in zip(output_ta_t, flat_new_output))
4199         return (time + 1, output_ta_t,
4200                 tuple(flat_new_output)) + tuple(new_states)
4201 
4202       final_outputs = control_flow_ops.while_loop(
4203           body=_step,
4204           loop_vars=(time, output_ta, flat_zero_output) + states,
4205           **while_loop_kwargs)
4206       # Skip final_outputs[2] which is the output for final timestep.
4207       new_states = final_outputs[3:]
4208     else:
4209       def _step(time, output_ta_t, *states):
4210         """RNN step function.
4211 
4212         Arguments:
4213             time: Current timestep value.
4214             output_ta_t: TensorArray.
4215             *states: List of states.
4216 
4217         Returns:
4218             Tuple: `(time + 1, output_ta_t) + tuple(new_states)`
4219         """
4220         current_input = tuple(ta.read(time) for ta in input_ta)
4221         current_input = nest.pack_sequence_as(inputs, current_input)
4222         output, new_states = step_function(current_input,
4223                                            tuple(states) + tuple(constants))
4224         flat_state = nest.flatten(states)
4225         flat_new_state = nest.flatten(new_states)
4226         for state, new_state in zip(flat_state, flat_new_state):
4227           if isinstance(new_state, ops.Tensor):
4228             new_state.set_shape(state.shape)
4229 
4230         flat_output = nest.flatten(output)
4231         output_ta_t = tuple(
4232             ta.write(time, out) for ta, out in zip(output_ta_t, flat_output))
4233         new_states = nest.pack_sequence_as(initial_states, flat_new_state)
4234         return (time + 1, output_ta_t) + tuple(new_states)
4235 
4236       final_outputs = control_flow_ops.while_loop(
4237           body=_step,
4238           loop_vars=(time, output_ta) + states,
4239           **while_loop_kwargs)
4240       new_states = final_outputs[2:]
4241 
4242     output_ta = final_outputs[1]
4243 
4244     outputs = tuple(o.stack() for o in output_ta)
4245     last_output = tuple(o[-1] for o in outputs)
4246 
4247     outputs = nest.pack_sequence_as(output_time_zero, outputs)
4248     last_output = nest.pack_sequence_as(output_time_zero, last_output)
4249 
4250   # static shape inference
4251   def set_shape(output_):
4252     if isinstance(output_, ops.Tensor):
4253       shape = output_.shape.as_list()
4254       shape[0] = time_steps
4255       shape[1] = batch
4256       output_.set_shape(shape)
4257     return output_
4258 
4259   outputs = nest.map_structure(set_shape, outputs)
4260 
4261   if not time_major:
4262     outputs = nest.map_structure(swap_batch_timestep, outputs)
4263 
4264   return last_output, outputs, new_states
4265 
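
# Illustrative sketch (editor-added; not part of the original module and never
# called). A trivial `rnn` step function that accumulates the inputs over time;
# assumes eager execution so the results can be inspected directly.
def _example_rnn_cumulative_sum():
  inputs = constant_op.constant(np.ones((2, 4, 3), dtype='float32'))
  initial_states = [constant_op.constant(np.zeros((2, 3), dtype='float32'))]

  def step_function(inp, states):
    new_state = states[0] + inp
    # Per the `rnn` contract, the output is also returned as the first state.
    return new_state, [new_state]

  # last_output: (2, 3), outputs: (2, 4, 3), new_states: list with one (2, 3).
  return rnn(step_function, inputs, initial_states)
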
4266 
4267 @keras_export('keras.backend.switch')
4268 def switch(condition, then_expression, else_expression):
4269   """Switches between two operations depending on a scalar value.
4270 
4271   Note that both `then_expression` and `else_expression`
4272   should be symbolic tensors of the *same shape*.
4273 
4274   Arguments:
4275       condition: tensor (`int` or `bool`).
4276       then_expression: either a tensor, or a callable that returns a tensor.
4277       else_expression: either a tensor, or a callable that returns a tensor.
4278 
4279   Returns:
4280       The selected tensor.
4281 
4282   Raises:
4283       ValueError: If rank of `condition` is greater than rank of expressions.
4284   """
4285   if condition.dtype != dtypes_module.bool:
4286     condition = math_ops.cast(condition, 'bool')
4287   cond_ndim = ndim(condition)
4288   if not cond_ndim:
4289     if not callable(then_expression):
4290 
4291       def then_expression_fn():
4292         return then_expression
4293     else:
4294       then_expression_fn = then_expression
4295     if not callable(else_expression):
4296 
4297       def else_expression_fn():
4298         return else_expression
4299     else:
4300       else_expression_fn = else_expression
4301     x = control_flow_ops.cond(condition, then_expression_fn, else_expression_fn)
4302   else:
4303     # tf.where needs its condition tensor
4304     # to be the same shape as its two
4305     # result tensors
4306     if callable(then_expression):
4307       then_expression = then_expression()
4308     if callable(else_expression):
4309       else_expression = else_expression()
4310     expr_ndim = ndim(then_expression)
4311     if cond_ndim > expr_ndim:
4312       raise ValueError('Rank of `condition` should be less than or'
4313                        ' equal to rank of `then_expression` and '
4314                        '`else_expression`. ndim(condition)=' + str(cond_ndim) +
4315                        ', ndim(then_expression)'
4316                        '=' + str(expr_ndim))
4317     if cond_ndim > 1:
4318       ndim_diff = expr_ndim - cond_ndim
4319       cond_shape = array_ops.concat(
4320           [array_ops.shape(condition), [1] * ndim_diff], axis=0)
4321       condition = array_ops.reshape(condition, cond_shape)
4322       expr_shape = array_ops.shape(then_expression)
4323       shape_diff = expr_shape - cond_shape
4324       tile_shape = array_ops.where_v2(shape_diff > 0, expr_shape,
4325                                       array_ops.ones_like(expr_shape))
4326       condition = array_ops.tile(condition, tile_shape)
4327     x = array_ops.where_v2(condition, then_expression, else_expression)
4328   return x
4329 
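
# Illustrative sketch (editor-added; not part of the original module and never
# called). `switch` picks between two same-shape tensors based on a scalar or
# broadcastable boolean condition.
def _example_switch():
  condition = constant_op.constant(True)
  a = constant_op.constant([1., 2.])
  b = constant_op.constant([3., 4.])
  # Returns `a` because the scalar condition is True.
  return switch(condition, a, b)
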
4330 
4331 @keras_export('keras.backend.in_train_phase')
4332 def in_train_phase(x, alt, training=None):
4333   """Selects `x` in train phase, and `alt` otherwise.
4334 
4335   Note that `alt` should have the *same shape* as `x`.
4336 
4337   Arguments:
4338       x: What to return in train phase
4339           (tensor or callable that returns a tensor).
4340       alt: What to return otherwise
4341           (tensor or callable that returns a tensor).
4342       training: Optional scalar tensor
4343           (or Python boolean, or Python integer)
4344           specifying the learning phase.
4345 
4346   Returns:
4347       Either `x` or `alt` based on the `training` flag.
4348       The `training` flag defaults to `K.learning_phase()`.
4349   """
4350   if training is None:
4351     training = learning_phase()
4352 
4353   # TODO(b/138862903): Handle the case when training is tensor.
4354   if not tensor_util.is_tensor(training):
4355     if training == 1 or training is True:
4356       if callable(x):
4357         return x()
4358       else:
4359         return x
4360 
4361     elif training == 0 or training is False:
4362       if callable(alt):
4363         return alt()
4364       else:
4365         return alt
4366 
4367   # else: assume learning phase is a placeholder tensor.
4368   x = switch(training, x, alt)
4369   return x
4370 
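
# Illustrative sketch (editor-added; not part of the original module and never
# called). Selects between two behaviors depending on the learning phase.
def _example_in_train_phase():
  x = constant_op.constant([1., 2.])
  noisy = x + 0.1
  # With `training=True` the first argument is returned; with `training=False`
  # (or a learning phase of 0) the alternative is returned instead.
  return in_train_phase(noisy, x, training=True)
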
4371 
4372 @keras_export('keras.backend.in_test_phase')
4373 def in_test_phase(x, alt, training=None):
4374   """Selects `x` in test phase, and `alt` otherwise.
4375 
4376   Note that `alt` should have the *same shape* as `x`.
4377 
4378   Arguments:
4379       x: What to return in test phase
4380           (tensor or callable that returns a tensor).
4381       alt: What to return otherwise
4382           (tensor or callable that returns a tensor).
4383       training: Optional scalar tensor
4384           (or Python boolean, or Python integer)
4385           specifying the learning phase.
4386 
4387   Returns:
4388       Either `x` or `alt` based on `K.learning_phase`.
4389   """
4390   return in_train_phase(alt, x, training=training)
4391 
4392 
4393 # NN OPERATIONS
4394 
4395 
4396 @keras_export('keras.backend.relu')
4397 def relu(x, alpha=0., max_value=None, threshold=0):
4398   """Rectified linear unit.
4399 
4400   With default values, it returns element-wise `max(x, 0)`.
4401 
4402   Otherwise, it follows:
4403   `f(x) = max_value` for `x >= max_value`,
4404   `f(x) = x` for `threshold <= x < max_value`,
4405   `f(x) = alpha * (x - threshold)` otherwise.
4406 
4407   Arguments:
4408       x: A tensor or variable.
4409       alpha: A scalar, slope of negative section (default=`0.`).
4410       max_value: float. Saturation threshold.
4411       threshold: float. Threshold value for thresholded activation.
4412 
4413   Returns:
4414       A tensor.
4415   """
4416 
4417   if alpha != 0.:
4418     if max_value is None and threshold == 0:
4419       return nn.leaky_relu(x, alpha=alpha)
4420 
4421     if threshold != 0:
4422       negative_part = nn.relu(-x + threshold)
4423     else:
4424       negative_part = nn.relu(-x)
4425 
4426   clip_max = max_value is not None
4427 
4428   if threshold != 0:
4429     # computes x for x > threshold else 0
4430     x = x * math_ops.cast(math_ops.greater(x, threshold), floatx())
4431   elif max_value == 6:
4432     # With no threshold, we can use the native TF op nn.relu6 for performance.
4433     x = nn.relu6(x)
4434     clip_max = False
4435   else:
4436     x = nn.relu(x)
4437 
4438   if clip_max:
4439     max_value = _constant_to_tensor(max_value, x.dtype.base_dtype)
4440     zero = _constant_to_tensor(0, x.dtype.base_dtype)
4441     x = clip_ops.clip_by_value(x, zero, max_value)
4442 
4443   if alpha != 0.:
4444     alpha = _to_tensor(alpha, x.dtype.base_dtype)
4445     x -= alpha * negative_part
4446   return x
4447 
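
# Illustrative sketch (editor-added; not part of the original module and never
# called). Demonstrates `relu` with a saturation value and a leaky slope.
def _example_relu_variants():
  x = constant_op.constant([-2., -1., 0., 3., 10.])
  plain = relu(x)                 # [0., 0., 0., 3., 10.]
  capped = relu(x, max_value=6.)  # [0., 0., 0., 3., 6.]
  leaky = relu(x, alpha=0.1)      # [-0.2, -0.1, 0., 3., 10.]
  return plain, capped, leaky
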
4448 
4449 @keras_export('keras.backend.elu')
4450 def elu(x, alpha=1.):
4451   """Exponential linear unit.
4452 
4453   Arguments:
4454       x: A tensor or variable to compute the activation function for.
4455       alpha: A scalar, slope of negative section.
4456 
4457   Returns:
4458       A tensor.
4459   """
4460   res = nn.elu(x)
4461   if alpha == 1:
4462     return res
4463   else:
4464     return array_ops.where_v2(x > 0, res, alpha * res)
4465 
4466 
4467 @keras_export('keras.backend.softmax')
4468 def softmax(x, axis=-1):
4469   """Softmax of a tensor.
4470 
4471   Arguments:
4472       x: A tensor or variable.
4473       axis: The dimension softmax would be performed on.
4474           The default is -1 which indicates the last dimension.
4475 
4476   Returns:
4477       A tensor.
4478   """
4479   return nn.softmax(x, axis=axis)
4480 
4481 
4482 @keras_export('keras.backend.softplus')
4483 def softplus(x):
4484   """Softplus of a tensor.
4485 
4486   Arguments:
4487       x: A tensor or variable.
4488 
4489   Returns:
4490       A tensor.
4491   """
4492   return nn.softplus(x)
4493 
4494 
4495 @keras_export('keras.backend.softsign')
4496 def softsign(x):
4497   """Softsign of a tensor.
4498 
4499   Arguments:
4500       x: A tensor or variable.
4501 
4502   Returns:
4503       A tensor.
4504   """
4505   return nn.softsign(x)
4506 
4507 
4508 def _backtrack_identity(tensor):
4509   while tensor.op.type == 'Identity':
4510     tensor = tensor.op.inputs[0]
4511   return tensor
4512 
4513 
4514 @keras_export('keras.backend.categorical_crossentropy')
4515 def categorical_crossentropy(target, output, from_logits=False, axis=-1):
4516   """Categorical crossentropy between an output tensor and a target tensor.
4517 
4518   Arguments:
4519       target: A tensor of the same shape as `output`.
4520       output: A tensor resulting from a softmax
4521           (unless `from_logits` is True, in which
4522           case `output` is expected to be the logits).
4523       from_logits: Boolean, whether `output` is the
4524           result of a softmax, or is a tensor of logits.
4525       axis: Int specifying the channels axis. `axis=-1` corresponds to data
4526           format `channels_last`, and `axis=1` corresponds to data format
4527           `channels_first`.
4528 
4529   Returns:
4530       Output tensor.
4531 
4532   Raises:
4533       ValueError: if `axis` is neither -1 nor one of the axes of `output`.
4534 
4535   Example:
4536 
4537   >>> a = tf.constant([1., 0., 0., 0., 1., 0., 0., 0., 1.], shape=[3,3])
4538   >>> print(a)
4539   tf.Tensor(
4540     [[1. 0. 0.]
4541      [0. 1. 0.]
4542      [0. 0. 1.]], shape=(3, 3), dtype=float32)
4543   >>> b = tf.constant([.9, .05, .05, .5, .89, .6, .05, .01, .94], shape=[3,3])
4544   >>> print(b)
4545   tf.Tensor(
4546     [[0.9  0.05 0.05]
4547      [0.5  0.89 0.6 ]
4548      [0.05 0.01 0.94]], shape=(3, 3), dtype=float32)
4549   >>> loss = tf.keras.backend.categorical_crossentropy(a, b)
4550   >>> print(loss)
4551   tf.Tensor([0.10536055 0.8046684  0.06187541], shape=(3,), dtype=float32)
4552   >>> loss = tf.keras.backend.categorical_crossentropy(a, a)
4553   >>> print(loss)
4554   tf.Tensor([1.1920929e-07 1.1920929e-07 1.1920930e-07], shape=(3,),
4555   dtype=float32)
4556 
4557   """
4558   target.shape.assert_is_compatible_with(output.shape)
4559   if from_logits:
4560     return nn.softmax_cross_entropy_with_logits_v2(
4561         labels=target, logits=output, axis=axis)
4562 
4563   if not isinstance(output, (ops.EagerTensor, variables_module.Variable)):
4564     output = _backtrack_identity(output)
4565     if output.op.type == 'Softmax':
4566       # When softmax activation function is used for output operation, we
4567       # use logits from the softmax function directly to compute loss in order
4568       # to prevent collapsing zero when training.
4569       # See b/117284466
4570       assert len(output.op.inputs) == 1
4571       output = output.op.inputs[0]
4572       return nn.softmax_cross_entropy_with_logits_v2(
4573           labels=target, logits=output, axis=axis)
4574 
4575   # scale preds so that the class probas of each sample sum to 1
4576   output = output / math_ops.reduce_sum(output, axis, True)
4577   # Compute cross entropy from probabilities.
4578   epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
4579   output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
4580   return -math_ops.reduce_sum(target * math_ops.log(output), axis)
4581 
4582 
4583 @keras_export('keras.backend.sparse_categorical_crossentropy')
4584 def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
4585   """Categorical crossentropy with integer targets.
4586 
4587   Arguments:
4588       target: An integer tensor.
4589       output: A tensor resulting from a softmax
4590           (unless `from_logits` is True, in which
4591           case `output` is expected to be the logits).
4592       from_logits: Boolean, whether `output` is the
4593           result of a softmax, or is a tensor of logits.
4594       axis: Int specifying the channels axis. `axis=-1` corresponds to data
4595           format `channels_last`, and `axis=1` corresponds to data format
4596           `channels_first`.
4597 
4598   Returns:
4599       Output tensor.
4600 
4601   Raises:
4602       ValueError: if `axis` is neither -1 nor one of the axes of `output`.
4603   """
4604   if not from_logits and not isinstance(
4605       output, (ops.EagerTensor, variables_module.Variable)):
4606     output = _backtrack_identity(output)
4607     if output.op.type == 'Softmax':
4608       # When softmax activation function is used for output operation, we
4609       # use logits from the softmax function directly to compute loss in order
4610       # to prevent collapsing zero when training.
4611       # See b/117284466
4612       assert len(output.op.inputs) == 1
4613       output = output.op.inputs[0]
4614       from_logits = True
4615 
4616   if not from_logits:
4617     epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
4618     output = clip_ops.clip_by_value(output, epsilon_, 1 - epsilon_)
4619     output = math_ops.log(output)
4620 
4621   if isinstance(output.shape, (tuple, list)):
4622     output_rank = len(output.shape)
4623   else:
4624     output_rank = output.shape.ndims
4625   if output_rank is not None:
4626     axis %= output_rank
4627     if axis != output_rank - 1:
4628       permutation = list(
4629           itertools.chain(range(axis), range(axis + 1, output_rank), [axis]))
4630       output = array_ops.transpose(output, perm=permutation)
4631   elif axis != -1:
4632     raise ValueError(
4633         'Cannot compute sparse categorical crossentropy with `axis={}` on an '
4634         'output tensor with unknown rank'.format(axis))
4635 
4636   target = cast(target, 'int64')
4637 
4638   # Try to adjust the shape so that rank of labels = rank of logits - 1.
4639   output_shape = array_ops.shape_v2(output)
4640   target_rank = target.shape.ndims
4641 
4642   update_shape = (
4643       target_rank is not None and output_rank is not None and
4644       target_rank != output_rank - 1)
4645   if update_shape:
4646     target = flatten(target)
4647     output = array_ops.reshape(output, [-1, output_shape[-1]])
4648 
4649   if py_any(_is_symbolic_tensor(v) for v in [target, output]):
4650     with get_graph().as_default():
4651       res = nn.sparse_softmax_cross_entropy_with_logits_v2(
4652           labels=target, logits=output)
4653   else:
4654     res = nn.sparse_softmax_cross_entropy_with_logits_v2(
4655         labels=target, logits=output)
4656 
4657   if update_shape and output_rank >= 3:
4658     # If our output includes timesteps or spatial dimensions we need to reshape
4659     return array_ops.reshape(res, output_shape[:-1])
4660   else:
4661     return res
4662 
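
# Illustrative sketch (editor-added; not part of the original module and never
# called). With integer targets there is no need to one-hot encode the labels.
def _example_sparse_categorical_crossentropy():
  target = constant_op.constant([0, 2])            # class indices
  output = constant_op.constant([[0.9, 0.05, 0.05],
                                 [0.1, 0.1, 0.8]])  # probabilities
  # Equivalent to `categorical_crossentropy` with one-hot encoded targets.
  return sparse_categorical_crossentropy(target, output)
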
4663 
4664 @keras_export('keras.backend.binary_crossentropy')
4665 def binary_crossentropy(target, output, from_logits=False):
4666   """Binary crossentropy between an output tensor and a target tensor.
4667 
4668   Arguments:
4669       target: A tensor with the same shape as `output`.
4670       output: A tensor.
4671       from_logits: Whether `output` is expected to be a logits tensor.
4672           By default, we consider that `output`
4673           encodes a probability distribution.
4674 
4675   Returns:
4676       A tensor.
4677   """
4678   if from_logits:
4679     return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
4680 
4681   if not isinstance(output, (ops.EagerTensor, variables_module.Variable)):
4682     output = _backtrack_identity(output)
4683     if output.op.type == 'Sigmoid':
4684       # When sigmoid activation function is used for output operation, we
4685       # use logits from the sigmoid function directly to compute loss in order
4686       # to prevent collapsing zero when training.
4687       assert len(output.op.inputs) == 1
4688       output = output.op.inputs[0]
4689       return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
4690 
4691   epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
4692   output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
4693 
4694   # Compute cross entropy from probabilities.
4695   bce = target * math_ops.log(output + epsilon())
4696   bce += (1 - target) * math_ops.log(1 - output + epsilon())
4697   return -bce
4698 
4699 
4700 @keras_export('keras.backend.sigmoid')
4701 def sigmoid(x):
4702   """Element-wise sigmoid.
4703 
4704   Arguments:
4705       x: A tensor or variable.
4706 
4707   Returns:
4708       A tensor.
4709   """
4710   return nn.sigmoid(x)
4711 
4712 
4713 @keras_export('keras.backend.hard_sigmoid')
4714 def hard_sigmoid(x):
4715   """Segment-wise linear approximation of sigmoid.
4716 
4717   Faster than sigmoid.
4718   Returns `0.` if `x < -2.5`, `1.` if `x > 2.5`.
4719   In `-2.5 <= x <= 2.5`, returns `0.2 * x + 0.5`.
4720 
4721   Arguments:
4722       x: A tensor or variable.
4723 
4724   Returns:
4725       A tensor.
4726   """
4727   point_two = _constant_to_tensor(0.2, x.dtype.base_dtype)
4728   point_five = _constant_to_tensor(0.5, x.dtype.base_dtype)
4729   x = math_ops.mul(x, point_two)
4730   x = math_ops.add(x, point_five)
4731   x = clip_ops.clip_by_value(x, 0., 1.)
4732   return x
4733 
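# Illustration-only sketch (hypothetical helper): the three segments of
# hard_sigmoid -- saturated at 0, the linear part 0.2 * x + 0.5, saturated at 1.
def _example_hard_sigmoid_usage():
  x = constant([-3.0, 0.0, 3.0])
  return hard_sigmoid(x)  # approximately [0.0, 0.5, 1.0]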
4734 
4735 @keras_export('keras.backend.tanh')
4736 def tanh(x):
4737   """Element-wise tanh.
4738 
4739   Arguments:
4740       x: A tensor or variable.
4741 
4742   Returns:
4743       A tensor.
4744   """
4745   return nn.tanh(x)
4746 
4747 
4748 @keras_export('keras.backend.dropout')
4749 def dropout(x, level, noise_shape=None, seed=None):
4750   """Sets entries in `x` to zero at random, while scaling the entire tensor.
4751 
4752   Arguments:
4753       x: tensor
4754       level: fraction of the entries in the tensor
4755           that will be set to 0.
4756       noise_shape: shape for randomly generated keep/drop flags,
4757           must be broadcastable to the shape of `x`
4758       seed: random seed to ensure determinism.
4759 
4760   Returns:
4761       A tensor.
4762   """
4763   if seed is None:
4764     seed = np.random.randint(10e6)
4765   return nn.dropout_v2(x, rate=level, noise_shape=noise_shape, seed=seed)
4766 
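# Illustration-only sketch (hypothetical helper): dropping half of the
# entries independently, and using `noise_shape` so a single mask of shape
# (1, 3) is broadcast across the first axis (whole columns kept or dropped).
def _example_dropout_usage():
  x = ones((4, 3))
  per_entry = dropout(x, level=0.5, seed=1)
  per_column = dropout(x, level=0.5, noise_shape=(1, 3), seed=1)
  return per_entry, per_column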
4767 
4768 @keras_export('keras.backend.l2_normalize')
4769 def l2_normalize(x, axis=None):
4770   """Normalizes a tensor wrt the L2 norm alongside the specified axis.
4771 
4772   Arguments:
4773       x: Tensor or variable.
4774       axis: axis along which to perform normalization.
4775 
4776   Returns:
4777       A tensor.
4778   """
4779   return nn.l2_normalize(x, axis=axis)
4780 
4781 
4782 @keras_export('keras.backend.in_top_k')
4783 def in_top_k(predictions, targets, k):
4784   """Returns whether the `targets` are in the top `k` `predictions`.
4785 
4786   Arguments:
4787       predictions: A tensor of shape `(batch_size, classes)` and type `float32`.
4788       targets: A 1D tensor of length `batch_size` and type `int32` or `int64`.
4789       k: An `int`, number of top elements to consider.
4790 
4791   Returns:
4792       A 1D tensor of length `batch_size` and type `bool`.
4793       `output[i]` is `True` if `predictions[i, targets[i]]` is within top-`k`
4794       values of `predictions[i]`.
4795   """
4796   return nn.in_top_k(predictions, targets, k)
4797 
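# Illustration-only sketch (hypothetical helper): checks whether each target
# class is among the 2 highest-scoring predictions of its row.
def _example_in_top_k_usage():
  predictions = constant([[0.1, 0.7, 0.2],
                          [0.8, 0.1, 0.1]])
  targets = constant([2, 0], dtype='int32')
  return in_top_k(predictions, targets, k=2)  # [True, True]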
4798 
4799 # CONVOLUTIONS
4800 
4801 
4802 def _preprocess_conv1d_input(x, data_format):
4803   """Transpose and cast the input before the conv1d.
4804 
4805   Arguments:
4806       x: input tensor.
4807       data_format: string, `"channels_last"` or `"channels_first"`.
4808 
4809   Returns:
4810       A tensor.
4811   """
4812   tf_data_format = 'NWC'  # default layout accepted by TF's native conv ops
4813   if data_format == 'channels_first':
4814     if not _has_nchw_support():
4815       x = array_ops.transpose(x, (0, 2, 1))  # NCW -> NWC
4816     else:
4817       tf_data_format = 'NCW'
4818   return x, tf_data_format
4819 
4820 
4821 def _preprocess_conv2d_input(x, data_format, force_transpose=False):
4822   """Transpose and cast the input before the conv2d.
4823 
4824   Arguments:
4825       x: input tensor.
4826       data_format: string, `"channels_last"` or `"channels_first"`.
4827       force_transpose: Boolean. If True, the input will always be transposed
4828           from NCHW to NHWC if `data_format` is `"channels_first"`.
4829           If False, the transposition only occurs on CPU (GPU ops are
4830           assumed to support NCHW).
4831 
4832   Returns:
4833       A tensor.
4834   """
4835   tf_data_format = 'NHWC'
4836   if data_format == 'channels_first':
4837     if not _has_nchw_support() or force_transpose:
4838       x = array_ops.transpose(x, (0, 2, 3, 1))  # NCHW -> NHWC
4839     else:
4840       tf_data_format = 'NCHW'
4841   return x, tf_data_format
4842 
4843 
4844 def _preprocess_conv3d_input(x, data_format):
4845   """Transpose and cast the input before the conv3d.
4846 
4847   Arguments:
4848       x: input tensor.
4849       data_format: string, `"channels_last"` or `"channels_first"`.
4850 
4851   Returns:
4852       A tensor.
4853   """
4854   tf_data_format = 'NDHWC'
4855   if data_format == 'channels_first':
4856     if not _has_nchw_support():
4857       x = array_ops.transpose(x, (0, 2, 3, 4, 1))
4858     else:
4859       tf_data_format = 'NCDHW'
4860   return x, tf_data_format
4861 
4862 
4863 def _preprocess_padding(padding):
4864   """Convert keras' padding to TensorFlow's padding.
4865 
4866   Arguments:
4867       padding: string, one of 'same' , 'valid'
4868 
4869   Returns:
4870       a string, one of 'SAME', 'VALID'.
4871 
4872   Raises:
4873       ValueError: if invalid `padding'`
4874   """
4875   if padding == 'same':
4876     padding = 'SAME'
4877   elif padding == 'valid':
4878     padding = 'VALID'
4879   else:
4880     raise ValueError('Invalid padding: ' + str(padding))
4881   return padding
4882 
4883 
4884 @keras_export('keras.backend.conv1d')
4885 def conv1d(x,
4886            kernel,
4887            strides=1,
4888            padding='valid',
4889            data_format=None,
4890            dilation_rate=1):
4891   """1D convolution.
4892 
4893   Arguments:
4894       x: Tensor or variable.
4895       kernel: kernel tensor.
4896       strides: stride integer.
4897       padding: string, `"same"`, `"causal"` or `"valid"`.
4898       data_format: string, one of "channels_last", "channels_first".
4899       dilation_rate: integer dilation rate.
4900 
4901   Returns:
4902       A tensor, result of 1D convolution.
4903 
4904   Raises:
4905       ValueError: if `data_format` is neither `channels_last` nor
4906       `channels_first`.
4907   """
4908   if data_format is None:
4909     data_format = image_data_format()
4910   if data_format not in {'channels_first', 'channels_last'}:
4911     raise ValueError('Unknown data_format: ' + str(data_format))
4912 
4913   kernel_shape = kernel.shape.as_list()
4914   if padding == 'causal':
4915     # causal (dilated) convolution:
4916     left_pad = dilation_rate * (kernel_shape[0] - 1)
4917     x = temporal_padding(x, (left_pad, 0))
4918     padding = 'valid'
4919   padding = _preprocess_padding(padding)
4920 
4921   x, tf_data_format = _preprocess_conv1d_input(x, data_format)
4922   x = nn.convolution(
4923       input=x,
4924       filter=kernel,
4925       dilation_rate=dilation_rate,
4926       strides=strides,
4927       padding=padding,
4928       data_format=tf_data_format)
4929   if data_format == 'channels_first' and tf_data_format == 'NWC':
4930     x = array_ops.transpose(x, (0, 2, 1))  # NWC -> NCW
4931   return x
4932 
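# Illustration-only sketch (hypothetical helper): a causal 1D convolution
# over a batch of sequences; the left padding applied internally keeps the
# output length equal to the input length (shape (2, 10, 8) here).
def _example_conv1d_usage():
  x = ones((2, 10, 4))          # (batch, steps, channels), channels_last
  kernel = ones((3, 4, 8))      # (kernel_size, in_channels, filters)
  return conv1d(x, kernel, strides=1, padding='causal', dilation_rate=1)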
4933 
4934 @keras_export('keras.backend.conv2d')
4935 def conv2d(x,
4936            kernel,
4937            strides=(1, 1),
4938            padding='valid',
4939            data_format=None,
4940            dilation_rate=(1, 1)):
4941   """2D convolution.
4942 
4943   Arguments:
4944       x: Tensor or variable.
4945       kernel: kernel tensor.
4946       strides: strides tuple.
4947       padding: string, `"same"` or `"valid"`.
4948       data_format: `"channels_last"` or `"channels_first"`.
4949       dilation_rate: tuple of 2 integers.
4950 
4951   Returns:
4952       A tensor, result of 2D convolution.
4953 
4954   Raises:
4955       ValueError: if `data_format` is neither `channels_last` nor
4956       `channels_first`.
4957   """
4958   if data_format is None:
4959     data_format = image_data_format()
4960   if data_format not in {'channels_first', 'channels_last'}:
4961     raise ValueError('Unknown data_format: ' + str(data_format))
4962 
4963   x, tf_data_format = _preprocess_conv2d_input(x, data_format)
4964   padding = _preprocess_padding(padding)
4965   x = nn.convolution(
4966       input=x,
4967       filter=kernel,
4968       dilation_rate=dilation_rate,
4969       strides=strides,
4970       padding=padding,
4971       data_format=tf_data_format)
4972   if data_format == 'channels_first' and tf_data_format == 'NHWC':
4973     x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
4974   return x
4975 
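# Illustration-only sketch (hypothetical helper): a 'same'-padded 2D
# convolution in channels_last layout; spatial size is preserved and the
# channel count becomes the number of filters (output shape (2, 8, 8, 16)).
def _example_conv2d_usage():
  images = ones((2, 8, 8, 3))        # (batch, rows, cols, channels)
  kernel = ones((3, 3, 3, 16))       # (height, width, in_channels, filters)
  return conv2d(images, kernel, strides=(1, 1), padding='same',
                data_format='channels_last')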
4976 
4977 @keras_export('keras.backend.conv2d_transpose')
4978 def conv2d_transpose(x,
4979                      kernel,
4980                      output_shape,
4981                      strides=(1, 1),
4982                      padding='valid',
4983                      data_format=None,
4984                      dilation_rate=(1, 1)):
4985   """2D deconvolution (i.e.
4986 
4987   transposed convolution).
4988 
4989   Arguments:
4990       x: Tensor or variable.
4991       kernel: kernel tensor.
4992       output_shape: 1D int tensor for the output shape.
4993       strides: strides tuple.
4994       padding: string, `"same"` or `"valid"`.
4995       data_format: string, `"channels_last"` or `"channels_first"`.
4996       dilation_rate: Tuple of 2 integers.
4997 
4998   Returns:
4999       A tensor, result of transposed 2D convolution.
5000 
5001   Raises:
5002       ValueError: if `data_format` is neither `channels_last` nor
5003       `channels_first`.
5004   """
5005   if data_format is None:
5006     data_format = image_data_format()
5007   if data_format not in {'channels_first', 'channels_last'}:
5008     raise ValueError('Unknown data_format: ' + str(data_format))
5009 
5010   # `atrous_conv2d_transpose` only supports NHWC format, even on GPU.
5011   if data_format == 'channels_first' and dilation_rate != (1, 1):
5012     force_transpose = True
5013   else:
5014     force_transpose = False
5015 
5016   x, tf_data_format = _preprocess_conv2d_input(x, data_format, force_transpose)
5017 
5018   if data_format == 'channels_first' and tf_data_format == 'NHWC':
5019     output_shape = (output_shape[0], output_shape[2], output_shape[3],
5020                     output_shape[1])
5021   if output_shape[0] is None:
5022     output_shape = (shape(x)[0],) + tuple(output_shape[1:])
5023 
5024   if isinstance(output_shape, (tuple, list)):
5025     output_shape = array_ops.stack(list(output_shape))
5026 
5027   padding = _preprocess_padding(padding)
5028   if tf_data_format == 'NHWC':
5029     strides = (1,) + strides + (1,)
5030   else:
5031     strides = (1, 1) + strides
5032 
5033   if dilation_rate == (1, 1):
5034     x = nn.conv2d_transpose(x, kernel, output_shape, strides,
5035                             padding=padding,
5036                             data_format=tf_data_format)
5037   else:
5038     assert dilation_rate[0] == dilation_rate[1]
5039     x = nn.atrous_conv2d_transpose(
5040         x,
5041         kernel,
5042         output_shape,
5043         rate=dilation_rate[0],
5044         padding=padding)
5045   if data_format == 'channels_first' and tf_data_format == 'NHWC':
5046     x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
5047   return x
5048 
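# Illustration-only sketch (hypothetical helper): upsampling a 4x4 feature
# map to 8x8 with a stride-2 transposed convolution. Note the kernel layout
# is (height, width, out_channels, in_channels).
def _example_conv2d_transpose_usage():
  x = ones((2, 4, 4, 8))
  kernel = ones((3, 3, 16, 8))
  return conv2d_transpose(x, kernel, output_shape=(2, 8, 8, 16),
                          strides=(2, 2), padding='same',
                          data_format='channels_last')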
5049 
5050 def separable_conv1d(x,
5051                      depthwise_kernel,
5052                      pointwise_kernel,
5053                      strides=1,
5054                      padding='valid',
5055                      data_format=None,
5056                      dilation_rate=1):
5057   """1D convolution with separable filters.
5058 
5059   Arguments:
5060       x: input tensor
5061       depthwise_kernel: convolution kernel for the depthwise convolution.
5062       pointwise_kernel: kernel for the 1x1 convolution.
5063       strides: stride integer.
5064       padding: string, `"same"` or `"valid"`.
5065       data_format: string, `"channels_last"` or `"channels_first"`.
5066       dilation_rate: integer dilation rate.
5067 
5068   Returns:
5069       Output tensor.
5070 
5071   Raises:
5072       ValueError: if `data_format` is neither `channels_last` nor
5073       `channels_first`.
5074   """
5075   if data_format is None:
5076     data_format = image_data_format()
5077   if data_format not in {'channels_first', 'channels_last'}:
5078     raise ValueError('Unknown data_format: ' + str(data_format))
5079 
5080   if isinstance(strides, int):
5081     strides = (strides,)
5082   if isinstance(dilation_rate, int):
5083     dilation_rate = (dilation_rate,)
5084 
5085   x, tf_data_format = _preprocess_conv1d_input(x, data_format)
5086   padding = _preprocess_padding(padding)
5087   if not isinstance(strides, tuple):
5088     strides = tuple(strides)
5089   if tf_data_format == 'NWC':
5090     spatial_start_dim = 1
5091     strides = (1,) + strides * 2 + (1,)
5092   else:
5093     spatial_start_dim = 2
5094     strides = (1, 1) + strides * 2
5095   x = array_ops.expand_dims(x, spatial_start_dim)
5096   depthwise_kernel = array_ops.expand_dims(depthwise_kernel, 0)
5097   pointwise_kernel = array_ops.expand_dims(pointwise_kernel, 0)
5098   dilation_rate = (1,) + dilation_rate
5099 
5100   x = nn.separable_conv2d(
5101       x,
5102       depthwise_kernel,
5103       pointwise_kernel,
5104       strides=strides,
5105       padding=padding,
5106       rate=dilation_rate,
5107       data_format=tf_data_format)
5108 
5109   x = array_ops.squeeze(x, [spatial_start_dim])
5110 
5111   if data_format == 'channels_first' and tf_data_format == 'NWC':
5112     x = array_ops.transpose(x, (0, 2, 1))  # NWC -> NCW
5113 
5114   return x
5115 
5116 
5117 @keras_export('keras.backend.separable_conv2d')
5118 def separable_conv2d(x,
5119                      depthwise_kernel,
5120                      pointwise_kernel,
5121                      strides=(1, 1),
5122                      padding='valid',
5123                      data_format=None,
5124                      dilation_rate=(1, 1)):
5125   """2D convolution with separable filters.
5126 
5127   Arguments:
5128       x: input tensor
5129       depthwise_kernel: convolution kernel for the depthwise convolution.
5130       pointwise_kernel: kernel for the 1x1 convolution.
5131       strides: strides tuple (length 2).
5132       padding: string, `"same"` or `"valid"`.
5133       data_format: string, `"channels_last"` or `"channels_first"`.
5134       dilation_rate: tuple of integers,
5135           dilation rates for the separable convolution.
5136 
5137   Returns:
5138       Output tensor.
5139 
5140   Raises:
5141       ValueError: if `data_format` is neither `channels_last` nor
5142       `channels_first`.
5143       ValueError: if `strides` is not a tuple of 2 integers.
5144   """
5145   if data_format is None:
5146     data_format = image_data_format()
5147   if data_format not in {'channels_first', 'channels_last'}:
5148     raise ValueError('Unknown data_format: ' + str(data_format))
5149   if len(strides) != 2:
5150     raise ValueError('`strides` must be a tuple of 2 integers.')
5151 
5152   x, tf_data_format = _preprocess_conv2d_input(x, data_format)
5153   padding = _preprocess_padding(padding)
5154   if not isinstance(strides, tuple):
5155     strides = tuple(strides)
5156   if tf_data_format == 'NHWC':
5157     strides = (1,) + strides + (1,)
5158   else:
5159     strides = (1, 1) + strides
5160 
5161   x = nn.separable_conv2d(
5162       x,
5163       depthwise_kernel,
5164       pointwise_kernel,
5165       strides=strides,
5166       padding=padding,
5167       rate=dilation_rate,
5168       data_format=tf_data_format)
5169   if data_format == 'channels_first' and tf_data_format == 'NHWC':
5170     x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
5171   return x
5172 
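# Illustration-only sketch (hypothetical helper): a separable convolution.
# The depthwise kernel is (height, width, in_channels, depth_multiplier); the
# pointwise kernel then mixes in_channels * depth_multiplier channels down to
# the final number of filters.
def _example_separable_conv2d_usage():
  x = ones((2, 8, 8, 3))
  depthwise_kernel = ones((3, 3, 3, 2))     # depth_multiplier = 2
  pointwise_kernel = ones((1, 1, 6, 16))    # 3 * 2 channels -> 16 filters
  return separable_conv2d(x, depthwise_kernel, pointwise_kernel,
                          strides=(1, 1), padding='same',
                          data_format='channels_last')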
5173 
5174 @keras_export('keras.backend.depthwise_conv2d')
5175 def depthwise_conv2d(x,
5176                      depthwise_kernel,
5177                      strides=(1, 1),
5178                      padding='valid',
5179                      data_format=None,
5180                      dilation_rate=(1, 1)):
5181   """2D convolution with separable filters.
5182 
5183   Arguments:
5184       x: input tensor
5185       depthwise_kernel: convolution kernel for the depthwise convolution.
5186       strides: strides tuple (length 2).
5187       padding: string, `"same"` or `"valid"`.
5188       data_format: string, `"channels_last"` or `"channels_first"`.
5189       dilation_rate: tuple of integers,
5190           dilation rates for the depthwise convolution.
5191 
5192   Returns:
5193       Output tensor.
5194 
5195   Raises:
5196       ValueError: if `data_format` is neither `channels_last` nor
5197       `channels_first`.
5198   """
5199   if data_format is None:
5200     data_format = image_data_format()
5201   if data_format not in {'channels_first', 'channels_last'}:
5202     raise ValueError('Unknown data_format: ' + str(data_format))
5203 
5204   x, tf_data_format = _preprocess_conv2d_input(x, data_format)
5205   padding = _preprocess_padding(padding)
5206   if tf_data_format == 'NHWC':
5207     strides = (1,) + strides + (1,)
5208   else:
5209     strides = (1, 1) + strides
5210 
5211   x = nn.depthwise_conv2d(
5212       x,
5213       depthwise_kernel,
5214       strides=strides,
5215       padding=padding,
5216       rate=dilation_rate,
5217       data_format=tf_data_format)
5218   if data_format == 'channels_first' and tf_data_format == 'NHWC':
5219     x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
5220   return x
5221 
5222 
5223 @keras_export('keras.backend.conv3d')
5224 def conv3d(x,
5225            kernel,
5226            strides=(1, 1, 1),
5227            padding='valid',
5228            data_format=None,
5229            dilation_rate=(1, 1, 1)):
5230   """3D convolution.
5231 
5232   Arguments:
5233       x: Tensor or variable.
5234       kernel: kernel tensor.
5235       strides: strides tuple.
5236       padding: string, `"same"` or `"valid"`.
5237       data_format: string, `"channels_last"` or `"channels_first"`.
5238       dilation_rate: tuple of 3 integers.
5239 
5240   Returns:
5241       A tensor, result of 3D convolution.
5242 
5243   Raises:
5244       ValueError: if `data_format` is neither `channels_last` nor
5245       `channels_first`.
5246   """
5247   if data_format is None:
5248     data_format = image_data_format()
5249   if data_format not in {'channels_first', 'channels_last'}:
5250     raise ValueError('Unknown data_format: ' + str(data_format))
5251 
5252   x, tf_data_format = _preprocess_conv3d_input(x, data_format)
5253   padding = _preprocess_padding(padding)
5254   x = nn.convolution(
5255       input=x,
5256       filter=kernel,
5257       dilation_rate=dilation_rate,
5258       strides=strides,
5259       padding=padding,
5260       data_format=tf_data_format)
5261   if data_format == 'channels_first' and tf_data_format == 'NDHWC':
5262     x = array_ops.transpose(x, (0, 4, 1, 2, 3))
5263   return x
5264 
5265 
5266 def conv3d_transpose(x,
5267                      kernel,
5268                      output_shape,
5269                      strides=(1, 1, 1),
5270                      padding='valid',
5271                      data_format=None):
5272   """3D deconvolution (i.e.
5273 
5274   transposed convolution).
5275 
5276   Arguments:
5277       x: input tensor.
5278       kernel: kernel tensor.
5279       output_shape: 1D int tensor for the output shape.
5280       strides: strides tuple.
5281       padding: string, "same" or "valid".
5282       data_format: string, `"channels_last"` or `"channels_first"`.
5283 
5284   Returns:
5285       A tensor, result of transposed 3D convolution.
5286 
5287   Raises:
5288       ValueError: if `data_format` is neither `channels_last` nor
5289       `channels_first`.
5290   """
5291   if data_format is None:
5292     data_format = image_data_format()
5293   if data_format not in {'channels_first', 'channels_last'}:
5294     raise ValueError('Unknown data_format: ' + str(data_format))
5295   if isinstance(output_shape, (tuple, list)):
5296     output_shape = array_ops.stack(output_shape)
5297 
5298   x, tf_data_format = _preprocess_conv3d_input(x, data_format)
5299 
5300   if data_format == 'channels_first' and tf_data_format == 'NDHWC':
5301     output_shape = (output_shape[0], output_shape[2], output_shape[3],
5302                     output_shape[4], output_shape[1])
5303   if output_shape[0] is None:
5304     output_shape = (array_ops.shape(x)[0],) + tuple(output_shape[1:])
5305     output_shape = array_ops.stack(list(output_shape))
5306 
5307   padding = _preprocess_padding(padding)
5308   if tf_data_format == 'NDHWC':
5309     strides = (1,) + strides + (1,)
5310   else:
5311     strides = (1, 1) + strides
5312 
5313   x = nn.conv3d_transpose(
5314       x,
5315       kernel,
5316       output_shape,
5317       strides,
5318       padding=padding,
5319       data_format=tf_data_format)
5320   if data_format == 'channels_first' and tf_data_format == 'NDHWC':
5321     x = array_ops.transpose(x, (0, 4, 1, 2, 3))
5322   return x
5323 
5324 
5325 @keras_export('keras.backend.pool2d')
5326 def pool2d(x,
5327            pool_size,
5328            strides=(1, 1),
5329            padding='valid',
5330            data_format=None,
5331            pool_mode='max'):
5332   """2D Pooling.
5333 
5334   Arguments:
5335       x: Tensor or variable.
5336       pool_size: tuple of 2 integers.
5337       strides: tuple of 2 integers.
5338       padding: string, `"same"` or `"valid"`.
5339       data_format: string, `"channels_last"` or `"channels_first"`.
5340       pool_mode: string, `"max"` or `"avg"`.
5341 
5342   Returns:
5343       A tensor, result of 2D pooling.
5344 
5345   Raises:
5346       ValueError: if `data_format` is neither `"channels_last"` nor
5347       `"channels_first"`.
5348       ValueError: if `pool_size` is not a tuple of 2 integers.
5349       ValueError: if `strides` is not a tuple of 2 integers.
5350       ValueError: if `pool_mode` is neither `"max"` nor `"avg"`.
5351   """
5352   if data_format is None:
5353     data_format = image_data_format()
5354   if data_format not in {'channels_first', 'channels_last'}:
5355     raise ValueError('Unknown data_format: ' + str(data_format))
5356   if len(pool_size) != 2:
5357     raise ValueError('`pool_size` must be a tuple of 2 integers.')
5358   if len(strides) != 2:
5359     raise ValueError('`strides` must be a tuple of 2 integers.')
5360 
5361   x, tf_data_format = _preprocess_conv2d_input(x, data_format)
5362   padding = _preprocess_padding(padding)
5363   if tf_data_format == 'NHWC':
5364     strides = (1,) + strides + (1,)
5365     pool_size = (1,) + pool_size + (1,)
5366   else:
5367     strides = (1, 1) + strides
5368     pool_size = (1, 1) + pool_size
5369 
5370   if pool_mode == 'max':
5371     x = nn.max_pool(
5372         x, pool_size, strides, padding=padding, data_format=tf_data_format)
5373   elif pool_mode == 'avg':
5374     x = nn.avg_pool(
5375         x, pool_size, strides, padding=padding, data_format=tf_data_format)
5376   else:
5377     raise ValueError('Invalid pooling mode: ' + str(pool_mode))
5378 
5379   if data_format == 'channels_first' and tf_data_format == 'NHWC':
5380     x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
5381   return x
5382 
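# Illustration-only sketch (hypothetical helper): 2x2 max and average pooling
# with stride 2, halving the spatial dimensions (output shape (2, 4, 4, 3)).
def _example_pool2d_usage():
  x = ones((2, 8, 8, 3))
  max_pooled = pool2d(x, pool_size=(2, 2), strides=(2, 2), pool_mode='max',
                      data_format='channels_last')
  avg_pooled = pool2d(x, pool_size=(2, 2), strides=(2, 2), pool_mode='avg',
                      data_format='channels_last')
  return max_pooled, avg_pooled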
5383 
5384 @keras_export('keras.backend.pool3d')
5385 def pool3d(x,
5386            pool_size,
5387            strides=(1, 1, 1),
5388            padding='valid',
5389            data_format=None,
5390            pool_mode='max'):
5391   """3D Pooling.
5392 
5393   Arguments:
5394       x: Tensor or variable.
5395       pool_size: tuple of 3 integers.
5396       strides: tuple of 3 integers.
5397       padding: string, `"same"` or `"valid"`.
5398       data_format: string, `"channels_last"` or `"channels_first"`.
5399       pool_mode: string, `"max"` or `"avg"`.
5400 
5401   Returns:
5402       A tensor, result of 3D pooling.
5403 
5404   Raises:
5405       ValueError: if `data_format` is neither `"channels_last"` nor
5406       `"channels_first"`.
5407       ValueError: if `pool_mode` is neither `"max"` nor `"avg"`.
5408   """
5409   if data_format is None:
5410     data_format = image_data_format()
5411   if data_format not in {'channels_first', 'channels_last'}:
5412     raise ValueError('Unknown data_format: ' + str(data_format))
5413 
5414   x, tf_data_format = _preprocess_conv3d_input(x, data_format)
5415   padding = _preprocess_padding(padding)
5416   if tf_data_format == 'NDHWC':
5417     strides = (1,) + strides + (1,)
5418     pool_size = (1,) + pool_size + (1,)
5419   else:
5420     strides = (1, 1) + strides
5421     pool_size = (1, 1) + pool_size
5422 
5423   if pool_mode == 'max':
5424     x = nn.max_pool3d(
5425         x, pool_size, strides, padding=padding, data_format=tf_data_format)
5426   elif pool_mode == 'avg':
5427     x = nn.avg_pool3d(
5428         x, pool_size, strides, padding=padding, data_format=tf_data_format)
5429   else:
5430     raise ValueError('Invalid pooling mode: ' + str(pool_mode))
5431 
5432   if data_format == 'channels_first' and tf_data_format == 'NDHWC':
5433     x = array_ops.transpose(x, (0, 4, 1, 2, 3))
5434   return x
5435 
5436 
5437 def local_conv(inputs,
5438                kernel,
5439                kernel_size,
5440                strides,
5441                output_shape,
5442                data_format=None):
5443   """Apply N-D convolution with un-shared weights.
5444 
5445   Arguments:
5446       inputs: (N+2)-D tensor with shape
5447           (batch_size, channels_in, d_in1, ..., d_inN)
5448           if data_format='channels_first', or
5449           (batch_size, d_in1, ..., d_inN, channels_in)
5450           if data_format='channels_last'.
5451       kernel: the unshared weight for N-D convolution,
5452           with shape (output_items, feature_dim, channels_out), where
5453           feature_dim = np.prod(kernel_size) * channels_in,
5454           output_items = np.prod(output_shape).
5455       kernel_size: a tuple of N integers, specifying the
5456           spatial dimensions of the N-D convolution window.
5457       strides: a tuple of N integers, specifying the strides
5458           of the convolution along the spatial dimensions.
5459       output_shape: a tuple of (d_out1, ..., d_outN) specifying the spatial
5460           dimensionality of the output.
5461       data_format: string, "channels_first" or "channels_last".
5462 
5463   Returns:
5464       An (N+2)-D tensor with shape:
5465       (batch_size, channels_out) + output_shape
5466       if data_format='channels_first', or:
5467       (batch_size,) + output_shape + (channels_out,)
5468       if data_format='channels_last'.
5469 
5470   Raises:
5471       ValueError: if `data_format` is neither
5472       `channels_last` nor `channels_first`.
5473   """
5474   if data_format is None:
5475     data_format = image_data_format()
5476   if data_format not in {'channels_first', 'channels_last'}:
5477     raise ValueError('Unknown data_format: ' + str(data_format))
5478 
5479   kernel_shape = int_shape(kernel)
5480   feature_dim = kernel_shape[1]
5481   channels_out = kernel_shape[-1]
5482   ndims = len(output_shape)
5483   spatial_dimensions = list(range(ndims))
5484 
5485   xs = []
5486   output_axes_ticks = [range(axis_max) for axis_max in output_shape]
5487   for position in itertools.product(*output_axes_ticks):
5488     slices = [slice(None)]
5489 
5490     if data_format == 'channels_first':
5491       slices.append(slice(None))
5492 
5493     slices.extend(
5494         slice(position[d] * strides[d], position[d] * strides[d] +
5495               kernel_size[d]) for d in spatial_dimensions)
5496 
5497     if data_format == 'channels_last':
5498       slices.append(slice(None))
5499 
5500     xs.append(reshape(inputs[slices], (1, -1, feature_dim)))
5501 
5502   x_aggregate = concatenate(xs, axis=0)
5503   output = batch_dot(x_aggregate, kernel)
5504   output = reshape(output, output_shape + (-1, channels_out))
5505 
5506   if data_format == 'channels_first':
5507     permutation = [ndims, ndims + 1] + spatial_dimensions
5508   else:
5509     permutation = [ndims] + spatial_dimensions + [ndims + 1]
5510 
5511   return permute_dimensions(output, permutation)
5512 
5513 
5514 @keras_export('keras.backend.local_conv1d')
5515 def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):
5516   """Apply 1D conv with un-shared weights.
5517 
5518   Arguments:
5519       inputs: 3D tensor with shape:
5520           (batch_size, steps, input_dim)
5521           if data_format is "channels_last" or
5522           (batch_size, input_dim, steps)
5523           if data_format is "channels_first".
5524       kernel: the unshared weight for convolution,
5525           with shape (output_length, feature_dim, filters).
5526       kernel_size: a tuple of a single integer,
5527           specifying the length of the 1D convolution window.
5528       strides: a tuple of a single integer,
5529           specifying the stride length of the convolution.
5530       data_format: the data format, channels_first or channels_last.
5531 
5532   Returns:
5533       A 3D tensor with shape:
5534       (batch_size, output_length, filters)
5535       if data_format='channels_last'
5536       or a 3D tensor with shape:
5537       (batch_size, filters, output_length)
5538       if data_format='channels_first'.
5539   """
5540   output_shape = (kernel.shape[0],)
5541   return local_conv(inputs,
5542                     kernel,
5543                     kernel_size,
5544                     strides,
5545                     output_shape,
5546                     data_format)
5547 
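# Illustration-only sketch (hypothetical helper): a locally connected 1D
# convolution. With 10 input steps, a window of 3 and stride 1 there are 8
# output positions, so the unshared kernel has shape
# (output_length, kernel_size * in_channels, filters) = (8, 12, 5).
def _example_local_conv1d_usage():
  inputs = ones((2, 10, 4))       # (batch, steps, channels), channels_last
  kernel = ones((8, 12, 5))
  return local_conv1d(inputs, kernel, kernel_size=(3,), strides=(1,),
                      data_format='channels_last')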
5548 
5549 @keras_export('keras.backend.local_conv2d')
5550 def local_conv2d(inputs,
5551                  kernel,
5552                  kernel_size,
5553                  strides,
5554                  output_shape,
5555                  data_format=None):
5556   """Apply 2D conv with un-shared weights.
5557 
5558   Arguments:
5559       inputs: 4D tensor with shape:
5560           (batch_size, channels_in, rows, cols)
5561           if data_format='channels_first'
5562           or 4D tensor with shape:
5563           (batch_size, rows, cols, channels_in)
5564           if data_format='channels_last'.
5565       kernel: the unshared weight for convolution,
5566           with shape (output_items, feature_dim, filters).
5567       kernel_size: a tuple of 2 integers, specifying the
5568           width and height of the 2D convolution window.
5569       strides: a tuple of 2 integers, specifying the strides
5570           of the convolution along the width and height.
5571       output_shape: a tuple with (output_row, output_col).
5572       data_format: the data format, channels_first or channels_last.
5573 
5574   Returns:
5575       A 4D tensor with shape:
5576       (batch_size, filters, new_rows, new_cols)
5577       if data_format='channels_first'
5578       or 4D tensor with shape:
5579       (batch_size, new_rows, new_cols, filters)
5580       if data_format='channels_last'.
5581   """
5582   return local_conv(inputs,
5583                     kernel,
5584                     kernel_size,
5585                     strides,
5586                     output_shape,
5587                     data_format)
5588 
5589 
5590 @keras_export('keras.backend.bias_add')
5591 def bias_add(x, bias, data_format=None):
5592   """Adds a bias vector to a tensor.
5593 
5594   Arguments:
5595       x: Tensor or variable.
5596       bias: Bias tensor to add.
5597       data_format: string, `"channels_last"` or `"channels_first"`.
5598 
5599   Returns:
5600       Output tensor.
5601 
5602   Raises:
5603       ValueError: In one of the two cases below:
5604                   1. invalid `data_format` argument.
5605                   2. invalid bias shape.
5606                      the bias should be either a vector or
5607                      a tensor with ndim(x) - 1 dimensions.
5608   """
5609   if data_format is None:
5610     data_format = image_data_format()
5611   if data_format not in {'channels_first', 'channels_last'}:
5612     raise ValueError('Unknown data_format: ' + str(data_format))
5613   bias_shape = int_shape(bias)
5614   if len(bias_shape) != 1 and len(bias_shape) != ndim(x) - 1:
5615     raise ValueError(
5616         'Unexpected bias dimensions %d, expected to be 1 or %d dimensions' %
5617         (len(bias_shape), ndim(x) - 1))
5618 
5619   if len(bias_shape) == 1:
5620     if data_format == 'channels_first':
5621       return nn.bias_add(x, bias, data_format='NCHW')
5622     return nn.bias_add(x, bias, data_format='NHWC')
5623   if ndim(x) in (3, 4, 5):
5624     if data_format == 'channels_first':
5625       bias_reshape_axis = (1, bias_shape[-1]) + bias_shape[:-1]
5626       return x + reshape(bias, bias_reshape_axis)
5627     return x + reshape(bias, (1,) + bias_shape)
5628   return nn.bias_add(x, bias)
5629 
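# Illustration-only sketch (hypothetical helper): adding a per-channel bias
# vector to a channels_last feature map.
def _example_bias_add_usage():
  x = ones((2, 8, 8, 16))
  bias = zeros((16,))
  return bias_add(x, bias, data_format='channels_last')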
5630 
5631 # RANDOMNESS
5632 
5633 
5634 @keras_export('keras.backend.random_normal')
5635 def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
5636   """Returns a tensor with normal distribution of values.
5637 
5638   Arguments:
5639       shape: A tuple of integers, the shape of tensor to create.
5640       mean: A float, mean of the normal distribution to draw samples.
5641       stddev: A float, standard deviation of the normal distribution
5642           to draw samples.
5643       dtype: String, dtype of returned tensor.
5644       seed: Integer, random seed.
5645 
5646   Returns:
5647       A tensor.
5648   """
5649   if dtype is None:
5650     dtype = floatx()
5651   if seed is None:
5652     seed = np.random.randint(10e6)
5653   return random_ops.random_normal(
5654       shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed)
5655 
5656 
5657 @keras_export('keras.backend.random_uniform')
5658 def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
5659   """Returns a tensor with uniform distribution of values.
5660 
5661   Arguments:
5662       shape: A tuple of integers, the shape of tensor to create.
5663       minval: A float, lower boundary of the uniform distribution
5664           to draw samples.
5665       maxval: A float, upper boundary of the uniform distribution
5666           to draw samples.
5667       dtype: String, dtype of returned tensor.
5668       seed: Integer, random seed.
5669 
5670   Returns:
5671       A tensor.
5672   """
5673   if dtype is None:
5674     dtype = floatx()
5675   if seed is None:
5676     seed = np.random.randint(10e6)
5677   return random_ops.random_uniform(
5678       shape, minval=minval, maxval=maxval, dtype=dtype, seed=seed)
5679 
5680 
5681 @keras_export('keras.backend.random_binomial')
5682 def random_binomial(shape, p=0.0, dtype=None, seed=None):
5683   """Returns a tensor with random binomial distribution of values.
5684 
5685   The binomial distribution with parameters `n` and `p` is the probability
5686   distribution of the number of successes in `n` independent Bernoulli trials,
5687   each with success probability `p`. Only `n = 1` is supported for now.
5688 
5689   Arguments:
5690       shape: A tuple of integers, the shape of tensor to create.
5691       p: A float, `0. <= p <= 1`, probability of binomial distribution.
5692       dtype: String, dtype of returned tensor.
5693       seed: Integer, random seed.
5694 
5695   Returns:
5696       A tensor.
5697   """
5698   if dtype is None:
5699     dtype = floatx()
5700   if seed is None:
5701     seed = np.random.randint(10e6)
5702   return array_ops.where_v2(
5703       random_ops.random_uniform(shape, dtype=dtype, seed=seed) <= p,
5704       array_ops.ones(shape, dtype=dtype), array_ops.zeros(shape, dtype=dtype))
5705 
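# Illustration-only sketch (hypothetical helper): drawing a Bernoulli(0.8)
# mask -- roughly 80% of the entries come out as 1.0, the rest as 0.0.
def _example_random_binomial_usage():
  return random_binomial((3, 4), p=0.8, seed=42)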
5706 
5707 @keras_export('keras.backend.truncated_normal')
5708 def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
5709   """Returns a tensor with truncated random normal distribution of values.
5710 
5711   The generated values follow a normal distribution
5712   with specified mean and standard deviation,
5713   except that values whose magnitude is more than
5714   two standard deviations from the mean are dropped and re-picked.
5715 
5716   Arguments:
5717       shape: A tuple of integers, the shape of tensor to create.
5718       mean: Mean of the values.
5719       stddev: Standard deviation of the values.
5720       dtype: String, dtype of returned tensor.
5721       seed: Integer, random seed.
5722 
5723   Returns:
5724       A tensor.
5725   """
5726   if dtype is None:
5727     dtype = floatx()
5728   if seed is None:
5729     seed = np.random.randint(10e6)
5730   return random_ops.truncated_normal(
5731       shape, mean, stddev, dtype=dtype, seed=seed)
5732 
5733 
5734 # CTC
5735 # TensorFlow has a native implementation, but it uses sparse tensors
5736 # and therefore requires a wrapper for Keras. The functions below convert
5737 # dense to sparse tensors and also wrap the beam search code that is
5738 # in TensorFlow's CTC implementation.
5739 
5740 
5741 @keras_export('keras.backend.ctc_label_dense_to_sparse')
5742 def ctc_label_dense_to_sparse(labels, label_lengths):
5743   """Converts CTC labels from dense to sparse.
5744 
5745   Arguments:
5746       labels: dense CTC labels.
5747       label_lengths: length of the labels.
5748 
5749   Returns:
5750       A sparse tensor representation of the labels.
5751   """
5752   label_shape = array_ops.shape(labels)
5753   num_batches_tns = array_ops.stack([label_shape[0]])
5754   max_num_labels_tns = array_ops.stack([label_shape[1]])
5755 
5756   def range_less_than(old_input, current_input):
5757     return array_ops.expand_dims(
5758         math_ops.range(array_ops.shape(old_input)[1]), 0) < array_ops.fill(
5759             max_num_labels_tns, current_input)
5760 
5761   init = math_ops.cast(
5762       array_ops.fill([1, label_shape[1]], 0), dtypes_module.bool)
5763   dense_mask = functional_ops.scan(
5764       range_less_than, label_lengths, initializer=init, parallel_iterations=1)
5765   dense_mask = dense_mask[:, 0, :]
5766 
5767   label_array = array_ops.reshape(
5768       array_ops.tile(math_ops.range(0, label_shape[1]), num_batches_tns),
5769       label_shape)
5770   label_ind = array_ops.boolean_mask(label_array, dense_mask)
5771 
5772   batch_array = array_ops.transpose(
5773       array_ops.reshape(
5774           array_ops.tile(math_ops.range(0, label_shape[0]), max_num_labels_tns),
5775           reverse(label_shape, 0)))
5776   batch_ind = array_ops.boolean_mask(batch_array, dense_mask)
5777   indices = array_ops.transpose(
5778       array_ops.reshape(concatenate([batch_ind, label_ind], axis=0), [2, -1]))
5779 
5780   vals_sparse = array_ops.gather_nd(labels, indices)
5781 
5782   return sparse_tensor.SparseTensor(
5783       math_ops.cast(indices, dtypes_module.int64), vals_sparse,
5784       math_ops.cast(label_shape, dtypes_module.int64))
5785 
5786 
5787 @keras_export('keras.backend.ctc_batch_cost')
5788 def ctc_batch_cost(y_true, y_pred, input_length, label_length):
5789   """Runs CTC loss algorithm on each batch element.
5790 
5791   Arguments:
5792       y_true: tensor `(samples, max_string_length)`
5793           containing the truth labels.
5794       y_pred: tensor `(samples, time_steps, num_categories)`
5795           containing the prediction, or output of the softmax.
5796       input_length: tensor `(samples, 1)` containing the sequence length for
5797           each batch item in `y_pred`.
5798       label_length: tensor `(samples, 1)` containing the sequence length for
5799           each batch item in `y_true`.
5800 
5801   Returns:
5802       Tensor with shape (samples,1) containing the
5803           CTC loss of each element.
5804   """
5805   label_length = math_ops.cast(
5806       array_ops.squeeze(label_length, axis=-1), dtypes_module.int32)
5807   input_length = math_ops.cast(
5808       array_ops.squeeze(input_length, axis=-1), dtypes_module.int32)
5809   sparse_labels = math_ops.cast(
5810       ctc_label_dense_to_sparse(y_true, label_length), dtypes_module.int32)
5811 
5812   y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + epsilon())
5813 
5814   return array_ops.expand_dims(
5815       ctc.ctc_loss(
5816           inputs=y_pred, labels=sparse_labels, sequence_length=input_length), 1)
5817 
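# Illustration-only sketch (hypothetical helper): CTC loss for a batch of two
# items. `y_pred` holds per-timestep probabilities over 5 categories (the
# last index acts as the CTC blank), and the length tensors give the number
# of valid time steps and labels per item. Values are made up.
def _example_ctc_batch_cost_usage():
  y_true = constant([[1, 2, 3, 0],
                     [2, 0, 0, 0]])               # padded label sequences
  y_pred = softmax(random_uniform((2, 6, 5), seed=7))
  input_length = constant([[6], [6]])
  label_length = constant([[4], [1]])
  return ctc_batch_cost(y_true, y_pred, input_length, label_length)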
5818 
5819 @keras_export('keras.backend.ctc_decode')
5820 def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
5821   """Decodes the output of a softmax.
5822 
5823   Can use either greedy search (also known as best path)
5824   or a constrained dictionary search.
5825 
5826   Arguments:
5827       y_pred: tensor `(samples, time_steps, num_categories)`
5828           containing the prediction, or output of the softmax.
5829       input_length: tensor `(samples, )` containing the sequence length for
5830           each batch item in `y_pred`.
5831       greedy: perform much faster best-path search if `true`.
5832           This does not use a dictionary.
5833       beam_width: if `greedy` is `false`: a beam search decoder will be used
5834           with a beam of this width.
5835       top_paths: if `greedy` is `false`,
5836           how many of the most probable paths will be returned.
5837 
5838   Returns:
5839       Tuple:
5840           List: if `greedy` is `true`, returns a list of one element that
5841               contains the decoded sequence.
5842               If `false`, returns the `top_paths` most probable
5843               decoded sequences.
5844               Important: blank labels are returned as `-1`.
5845           Tensor `(top_paths, )` that contains
5846               the log probability of each decoded sequence.
5847   """
5848   y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + epsilon())
5849   input_length = math_ops.cast(input_length, dtypes_module.int32)
5850 
5851   if greedy:
5852     (decoded, log_prob) = ctc.ctc_greedy_decoder(
5853         inputs=y_pred, sequence_length=input_length)
5854   else:
5855     (decoded, log_prob) = ctc.ctc_beam_search_decoder(
5856         inputs=y_pred,
5857         sequence_length=input_length,
5858         beam_width=beam_width,
5859         top_paths=top_paths)
5860   decoded_dense = [
5861       sparse_ops.sparse_to_dense(
5862           st.indices, st.dense_shape, st.values, default_value=-1)
5863       for st in decoded
5864   ]
5865   return (decoded_dense, log_prob)
5866 
5867 
5868 # HIGH ORDER FUNCTIONS
5869 
5870 
5871 @keras_export('keras.backend.map_fn')
5872 def map_fn(fn, elems, name=None, dtype=None):
5873   """Map the function fn over the elements elems and return the outputs.
5874 
5875   Arguments:
5876       fn: Callable that will be called upon each element in elems
5877       elems: tensor
5878       name: A string name for the map node in the graph
5879       dtype: Output data type.
5880 
5881   Returns:
5882       Tensor with dtype `dtype`.
5883   """
5884   return map_fn_lib.map_fn(fn, elems, name=name, dtype=dtype)
5885 
5886 
5887 @keras_export('keras.backend.foldl')
5888 def foldl(fn, elems, initializer=None, name=None):
5889   """Reduce elems using fn to combine them from left to right.
5890 
5891   Arguments:
5892       fn: Callable that will be called upon each element in elems and an
5893           accumulator, for instance `lambda acc, x: acc + x`
5894       elems: tensor
5895       initializer: The first value used (`elems[0]` in case of None)
5896       name: A string name for the foldl node in the graph
5897 
5898   Returns:
5899       Tensor with same type and shape as `initializer`.
5900   """
5901   return functional_ops.foldl(fn, elems, initializer=initializer, name=name)
5902 
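# Illustration-only sketch (hypothetical helper): `map_fn` squares every
# element, while `foldl` accumulates a running sum from the left starting at
# the given initializer.
def _example_map_fn_and_foldl_usage():
  elems = constant([1., 2., 3., 4.])
  squares = map_fn(lambda x: x * x, elems)              # [1., 4., 9., 16.]
  total = foldl(lambda acc, x: acc + x, elems, initializer=constant(0.))
  return squares, total                                 # total == 10.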
5903 
5904 @keras_export('keras.backend.foldr')
5905 def foldr(fn, elems, initializer=None, name=None):
5906   """Reduce elems using fn to combine them from right to left.
5907 
5908   Arguments:
5909       fn: Callable that will be called upon each element in elems and an
5910           accumulator, for instance `lambda acc, x: acc + x`
5911       elems: tensor
5912       initializer: The first value used (`elems[-1]` in case of None)
5913       name: A string name for the foldr node in the graph
5914 
5915   Returns:
5916       Same type and shape as initializer
5917   """
5918   return functional_ops.foldr(fn, elems, initializer=initializer, name=name)
5919 
5920 # Load Keras default configuration from config file if present.
5921 # Set Keras base dir path given KERAS_HOME env variable, if applicable.
5922 # Otherwise either ~/.keras or /tmp.
5923 if 'KERAS_HOME' in os.environ:
5924   _keras_dir = os.environ.get('KERAS_HOME')
5925 else:
5926   _keras_base_dir = os.path.expanduser('~')
5927   _keras_dir = os.path.join(_keras_base_dir, '.keras')
5928 _config_path = os.path.expanduser(os.path.join(_keras_dir, 'keras.json'))
5929 if os.path.exists(_config_path):
5930   try:
5931     with open(_config_path) as fh:
5932       _config = json.load(fh)
5933   except ValueError:
5934     _config = {}
5935   _floatx = _config.get('floatx', floatx())
5936   assert _floatx in {'float16', 'float32', 'float64'}
5937   _epsilon = _config.get('epsilon', epsilon())
5938   assert isinstance(_epsilon, float)
5939   _image_data_format = _config.get('image_data_format', image_data_format())
5940   assert _image_data_format in {'channels_last', 'channels_first'}
5941   set_floatx(_floatx)
5942   set_epsilon(_epsilon)
5943   set_image_data_format(_image_data_format)
5944 
5945 # Save config file.
5946 if not os.path.exists(_keras_dir):
5947   try:
5948     os.makedirs(_keras_dir)
5949   except OSError:
5950     # Ignore permission-denied errors and potential race conditions
5951     # in multi-threaded environments.
5952     pass
5953 
5954 if not os.path.exists(_config_path):
5955   _config = {
5956       'floatx': floatx(),
5957       'epsilon': epsilon(),
5958       'backend': 'tensorflow',
5959       'image_data_format': image_data_format()
5960   }
5961   try:
5962     with open(_config_path, 'w') as f:
5963       f.write(json.dumps(_config, indent=4))
5964   except IOError:
5965     # Ignore permission-denied errors.
5966     pass
5967 
5968 
5969 def configure_and_create_distributed_session(distribution_strategy):
5970   """Configure session config and create a session with it."""
5971 
5972   def _create_session(distribution_strategy):
5973     """Create the Distributed Strategy session."""
5974     session_config = get_default_session_config()
5975 
5976     # If a session already exists, merge in its config; if there is a
5977     # conflict, keep the values from the existing config.
5978     global _SESSION
5979     if getattr(_SESSION, 'session', None) and _SESSION.session._config:
5980       session_config.MergeFrom(_SESSION.session._config)
5981 
5982     if is_tpu_strategy(distribution_strategy):
5983       # TODO(priyag, yuefengz): Remove this workaround when Distribute
5984       # Coordinator is integrated with keras and we can create a session from
5985       # there.
5986       distribution_strategy.configure(session_config)
5987       master = distribution_strategy.extended._tpu_cluster_resolver.master()  # pylint: disable=protected-access
5988       session = session_module.Session(config=session_config, target=master)
5989     else:
5990       worker_context = dc_context.get_current_worker_context()
5991       if worker_context:
5992         dc_session_config = worker_context.session_config
5993         # Merge the default session config to the one from distribute
5994         # coordinator, which is fine for now since they don't have
5995         # conflicting configurations.
5996         dc_session_config.MergeFrom(session_config)
5997         session = session_module.Session(
5998             config=dc_session_config, target=worker_context.master_target)
5999       else:
6000         distribution_strategy.configure(session_config)
6001         session = session_module.Session(config=session_config)
6002 
6003     set_session(session)
6004 
6005   if distribution_strategy.extended._in_multi_worker_mode():
6006     dc.run_distribute_coordinator(
6007         _create_session,
6008         distribution_strategy,
6009         mode=dc.CoordinatorMode.INDEPENDENT_WORKER)
6010   else:
6011     _create_session(distribution_strategy)
6012 
6013 
6014 def is_tpu_strategy(strategy):
6015   """We're executing TPU Strategy."""
6016   return (strategy is not None and
6017           strategy.__class__.__name__.startswith('TPUStrategy'))
6018 
6019 
6020 def cast_variables_to_tensor(tensors):
6021   """Casts any variables in `tensors` to regular tensors, leaving the rest as-is."""
6022   def _cast_variables_to_tensor(tensor):
6023     if isinstance(tensor, variables_module.Variable):
6024       return array_ops.identity(tensor)
6025     return tensor
6026 
6027   return nest.map_structure(_cast_variables_to_tensor, tensors)
6028 
6029 
6030 def _is_symbolic_tensor(x):
6031   return tensor_util.is_tensor(x) and not isinstance(x, ops.EagerTensor)
6032 
6033 
6034 def convert_inputs_if_ragged(inputs):
6035   """Converts any ragged tensors to dense."""
6036 
6037   def _convert_ragged_input(inputs):
6038     if isinstance(inputs, ragged_tensor.RaggedTensor):
6039       return inputs.to_tensor()
6040     return inputs
6041 
6042   flat_inputs = nest.flatten(inputs)
6043   contains_ragged = py_any(
6044       isinstance(i, ragged_tensor.RaggedTensor) for i in flat_inputs)
6045 
6046   if not contains_ragged:
6047     return inputs, None
6048 
6049   inputs = nest.map_structure(_convert_ragged_input, inputs)
6050   # Multiple masks are not yet supported, so one mask is used on all inputs.
6051   # We approach this similarly when using row lengths to ignore steps.
6052   nested_row_lengths = math_ops.cast(flat_inputs[0].nested_row_lengths()[0],
6053                                      'int32')
6054   return inputs, nested_row_lengths
6055 
6056 
6057 def maybe_convert_to_ragged(is_ragged_input, output, nested_row_lengths):
6058   """Converts any ragged input back to its initial structure."""
6059   if not is_ragged_input:
6060     return output
6061 
6062   return ragged_tensor.RaggedTensor.from_tensor(output, nested_row_lengths)
6063