# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Keras backend."""

import gc
import warnings

from absl.testing import parameterized
import numpy as np
import scipy.sparse

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager.context import get_config
from tensorflow.python.framework import config
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.keras import activations
from tensorflow.python.keras import backend
from tensorflow.python.keras import combinations
from tensorflow.python.keras.engine import input_layer
from tensorflow.python.keras.layers import advanced_activations
from tensorflow.python.keras.layers.normalization import batch_normalization_v1
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


def compare_single_input_op_to_numpy(keras_op,
                                     np_op,
                                     input_shape,
                                     dtype='float32',
                                     negative_values=True,
                                     keras_args=None,
                                     keras_kwargs=None,
                                     np_args=None,
                                     np_kwargs=None):
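  """Compares a single-input Keras backend op against a NumPy reference.

  Evaluates `keras_op` on a random input of `input_shape` (cast to `dtype`)
  and asserts that the result matches `np_op` applied to the same data, up
  to a small absolute tolerance. Extra positional and keyword arguments can
  be forwarded to either op.
  """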
  keras_args = keras_args or []
  keras_kwargs = keras_kwargs or {}
  np_args = np_args or []
  np_kwargs = np_kwargs or {}
  inputs = 2. * np.random.random(input_shape)
  if negative_values:
    inputs -= 1.
  keras_output = keras_op(
      backend.variable(inputs, dtype=dtype), *keras_args, **keras_kwargs)
  keras_output = backend.eval(keras_output)
  np_output = np_op(inputs.astype(dtype), *np_args, **np_kwargs)
  try:
    np.testing.assert_allclose(keras_output, np_output, atol=1e-4)
  except AssertionError:
    raise AssertionError('Test for op `' + str(keras_op.__name__) + '` failed; '
                         'Expected ' + str(np_output) + ' but got ' +
                         str(keras_output))


def compare_two_inputs_op_to_numpy(keras_op,
                                   np_op,
                                   input_shape_a,
                                   input_shape_b,
                                   dtype='float32',
                                   keras_args=None,
                                   keras_kwargs=None,
                                   np_args=None,
                                   np_kwargs=None):
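  """Compares a two-input Keras backend op against a NumPy reference.

  Evaluates `keras_op` on random inputs of shapes `input_shape_a` and
  `input_shape_b` and asserts that the result matches `np_op` applied to the
  same data, up to a small absolute tolerance.
  """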
  keras_args = keras_args or []
  keras_kwargs = keras_kwargs or {}
  np_args = np_args or []
  np_kwargs = np_kwargs or {}
  input_a = np.random.random(input_shape_a)
  input_b = np.random.random(input_shape_b)
  keras_output = keras_op(
      backend.variable(input_a, dtype=dtype),
      backend.variable(input_b, dtype=dtype), *keras_args, **keras_kwargs)
  keras_output = backend.eval(keras_output)
  np_output = np_op(
      input_a.astype(dtype), input_b.astype(dtype), *np_args, **np_kwargs)
  try:
    np.testing.assert_allclose(keras_output, np_output, atol=1e-4)
  except AssertionError:
    raise AssertionError('Test for op `' + str(keras_op.__name__) + '` failed; '
                         'Expected ' + str(np_output) + ' but got ' +
                         str(keras_output))


class BackendResetTest(test.TestCase, parameterized.TestCase):

  def test_new_config(self):
    # User-defined JIT setting.
    config.set_optimizer_jit(False)
    sess = backend.get_session()
    default_config = get_config()
    self.assertEqual(
        sess._config.graph_options.optimizer_options.global_jit_level,
        default_config.graph_options.optimizer_options.global_jit_level)
    backend.clear_session()

    # A new session inherits the same JIT setting.
    sess = backend.get_session()
    default_config = get_config()
    self.assertEqual(
        sess._config.graph_options.optimizer_options.global_jit_level,
        default_config.graph_options.optimizer_options.global_jit_level)
    backend.clear_session()

    # A change to the JIT setting is respected.
    config.set_optimizer_jit(True)
    sess = backend.get_session()
    default_config = get_config()
    self.assertEqual(
        sess._config.graph_options.optimizer_options.global_jit_level,
        default_config.graph_options.optimizer_options.global_jit_level)
    backend.clear_session()

  # We can't use the normal parameterized decorator because the test session
  # will block graph clearing.
  @parameterized.named_parameters(('_v1', context.graph_mode),
                                  ('_v2', context.eager_mode))
  def test_new_graph(self, test_context):
    with test_context():
      g_old = backend.get_graph()
      backend.clear_session()
      g = backend.get_graph()

      self.assertIsNot(g_old, g)


@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class BackendUtilsTest(test.TestCase):

  def test_backend(self):
    self.assertEqual(backend.backend(), 'tensorflow')

  def test_get_reset_uids(self):
    self.assertEqual(backend.get_uid('foo'), 1)
    self.assertEqual(backend.get_uid('foo'), 2)

    backend.reset_uids()
    self.assertEqual(backend.get_uid('foo'), 1)

  def test_learning_phase(self):
    with self.cached_session() as sess:
      with self.assertRaises(ValueError):
        backend.set_learning_phase(2)

      # Test running with a learning-phase-consuming layer
      with backend.learning_phase_scope(0):
        x = input_layer.Input((3,))
        y = batch_normalization_v1.BatchNormalization()(x)
        if not context.executing_eagerly():
          self.evaluate(variables.global_variables_initializer())
          sess.run(y, feed_dict={x: np.random.random((2, 3))})

  def test_learning_phase_name(self):
    with backend.name_scope('test_scope'):
      # Test that outer name scopes do not affect the learning phase's name.
      lp = backend.symbolic_learning_phase()
    self.assertEqual(lp.name, 'keras_learning_phase:0')

  def test_learning_phase_scope(self):
    initial_learning_phase = backend.learning_phase()
    with backend.learning_phase_scope(1):
      self.assertEqual(backend.learning_phase(), 1)
    self.assertEqual(backend.learning_phase(), initial_learning_phase)
    with backend.learning_phase_scope(0):
      self.assertEqual(backend.learning_phase(), 0)
    self.assertEqual(backend.learning_phase(), initial_learning_phase)
    with self.assertRaises(ValueError):
      with backend.learning_phase_scope(None):
        pass
    self.assertEqual(backend.learning_phase(), initial_learning_phase)

    new_learning_phase = 0
    backend.set_learning_phase(new_learning_phase)
    self.assertEqual(backend.learning_phase(), new_learning_phase)
    with backend.learning_phase_scope(1):
      self.assertEqual(backend.learning_phase(), 1)
    self.assertEqual(backend.learning_phase(), new_learning_phase)

  def test_learning_phase_scope_in_graph(self):
    initial_learning_phase_outside_graph = backend.learning_phase()
    with backend.get_graph().as_default():
      initial_learning_phase_in_graph = backend.learning_phase()

    self.assertEqual(backend.learning_phase(),
                     initial_learning_phase_outside_graph)
    with backend.learning_phase_scope(1):
      self.assertEqual(backend.learning_phase(), 1)
    self.assertEqual(backend.learning_phase(),
                     initial_learning_phase_outside_graph)

    with backend.get_graph().as_default():
      self.assertIs(backend.learning_phase(), initial_learning_phase_in_graph)

    self.assertEqual(backend.learning_phase(),
                     initial_learning_phase_outside_graph)

  def test_int_shape(self):
    x = backend.ones(shape=(3, 4))
    self.assertEqual(backend.int_shape(x), (3, 4))

    if not context.executing_eagerly():
      x = backend.placeholder(shape=(None, 4))
      self.assertEqual(backend.int_shape(x), (None, 4))

  def test_in_train_phase(self):
    y1 = backend.variable(1)
    y2 = backend.variable(2)
    if context.executing_eagerly():
      with backend.learning_phase_scope(0):
        y_val_test = backend.in_train_phase(y1, y2).numpy()
      with backend.learning_phase_scope(1):
        y_val_train = backend.in_train_phase(y1, y2).numpy()
    else:
      y = backend.in_train_phase(y1, y2)
      f = backend.function([backend.learning_phase()], [y])
      y_val_test = f([0])[0]
      y_val_train = f([1])[0]
    self.assertAllClose(y_val_test, 2)
    self.assertAllClose(y_val_train, 1)

  def test_is_keras_tensor(self):
    x = backend.variable(1)
    self.assertEqual(backend.is_keras_tensor(x), False)
    x = input_layer.Input(shape=(1,))
    self.assertEqual(backend.is_keras_tensor(x), True)
    x = input_layer.Input(shape=(None,), ragged=True)
    self.assertEqual(backend.is_keras_tensor(x), True)
    x = input_layer.Input(shape=(None, None), sparse=True)
    self.assertEqual(backend.is_keras_tensor(x), True)
    with self.assertRaises(ValueError):
      backend.is_keras_tensor(0)

  def test_stop_gradient(self):
    x = backend.variable(1)
    y = backend.stop_gradient(x)
    if not context.executing_eagerly():
      self.assertEqual(y.op.name[:12], 'StopGradient')

    xs = [backend.variable(1) for _ in range(3)]
    ys = backend.stop_gradient(xs)
    if not context.executing_eagerly():
      for y in ys:
        self.assertEqual(y.op.name[:12], 'StopGradient')

  def test_placeholder(self):
    x = backend.placeholder(shape=(3, 4))
    self.assertEqual(x.shape.as_list(), [3, 4])
    x = backend.placeholder(shape=(3, 4), sparse=True)
    self.assertEqual(x.shape.as_list(), [3, 4])

  def test_is_placeholder(self):
    x = backend.placeholder(shape=(1,))
    self.assertEqual(backend.is_placeholder(x), True)
    x = backend.variable(1)
    self.assertEqual(backend.is_placeholder(x), False)

  def test_print_tensor(self):
    # Unfortunately it seems impossible to use `mock` (or any other method)
    # to capture stdout when it is used inside a graph or graph function, so
    # correctness cannot be asserted here; in practice the message does get
    # printed correctly.
    x = backend.placeholder(shape=())
    y = backend.print_tensor(x, 'eager=%s' % context.executing_eagerly())
    f = backend.function(x, y)
    f(0)

  def test_cast_to_floatx(self):
    x = backend.variable(1, dtype='float64')
    x = backend.cast_to_floatx(x)
    self.assertEqual(x.dtype.name, 'float32')
    x = backend.cast_to_floatx(2)
    self.assertEqual(x.dtype.name, 'float32')


@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class BackendVariableTest(test.TestCase):

  def test_zeros(self):
    x = backend.zeros((3, 4))
    val = backend.eval(x)
    self.assertAllClose(val, np.zeros((3, 4)))

  def test_ones(self):
    x = backend.ones((3, 4))
    val = backend.eval(x)
    self.assertAllClose(val, np.ones((3, 4)))

  def test_eye(self):
    x = backend.eye(4)
    val = backend.eval(x)
    self.assertAllClose(val, np.eye(4))

  def test_zeros_like(self):
    x = backend.zeros((3, 4))
    y = backend.zeros_like(x)
    val = backend.eval(y)
    self.assertAllClose(val, np.zeros((3, 4)))

  def test_ones_like(self):
    x = backend.zeros((3, 4))
    y = backend.ones_like(x)
    val = backend.eval(y)
    self.assertAllClose(val, np.ones((3, 4)))

  def test_random_uniform_variable(self):
    x = backend.random_uniform_variable((30, 20), low=1, high=2, seed=0)
    val = backend.eval(x)
    self.assertAllClose(val.mean(), 1.5, atol=1e-1)
    self.assertAllClose(val.max(), 2., atol=1e-1)
    self.assertAllClose(val.min(), 1., atol=1e-1)

  def test_random_normal_variable(self):
    x = backend.random_normal_variable((30, 20), 1., 0.5, seed=0)
    val = backend.eval(x)
    self.assertAllClose(val.mean(), 1., atol=1e-1)
    self.assertAllClose(val.std(), 0.5, atol=1e-1)

  def test_count_params(self):
    x = backend.zeros((4, 5))
    val = backend.count_params(x)
    self.assertAllClose(val, 20)

  def test_constant(self):
    ref_val = np.random.random((3, 4)).astype('float32')
    x = backend.constant(ref_val)
    val = backend.eval(x)
    self.assertAllClose(val, ref_val)

  def test_sparse_variable(self):
    val = scipy.sparse.eye(10)
    x = backend.variable(val)
    self.assertIsInstance(x, sparse_tensor.SparseTensor)

    y = backend.to_dense(x)
    self.assertFalse(backend.is_sparse(y))


@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class BackendLinearAlgebraTest(test.TestCase, parameterized.TestCase):

  def test_dot(self):
    x = backend.ones(shape=(2, 3))
    y = backend.ones(shape=(3, 4))
    xy = backend.dot(x, y)
    self.assertEqual(xy.shape.as_list(), [2, 4])

    x = backend.ones(shape=(32, 28, 3))
    y = backend.ones(shape=(3, 4))
    xy = backend.dot(x, y)
    self.assertEqual(xy.shape.as_list(), [32, 28, 4])

  @parameterized.parameters(
      [(2, 3, 4, 5), (2, 5, 6, 7), (2, 3, 4, 6, 7), (3, 1)],
      [(2, 20, 1), (2, 30, 20), (2, 1, 30), (1, 2)],
      [(4, 2, 3), (4, 5, 3), (4, 2, 5), (2, 2)],
      [(4, 2), (4, 2, 3), (4, 3), (1, 1)],
      [(4, 2), (4, 2, 3), (4, 3), 1],
      [(4, 2, 3), (4, 3), (4, 2), (2, 1)],
  )
  def test_batch_dot(self, x_shape, y_shape, output_shape, axes):
    x_val = np.random.random(x_shape)
    y_val = np.random.random(y_shape)
    x = backend.variable(x_val)
    y = backend.variable(y_val)
    xy = backend.batch_dot(x, y, axes=axes)
    self.assertEqual(tuple(xy.shape.as_list()), output_shape)
    xy_val = backend.eval(xy)
    ref_val = self._reference_batch_dot(x_val, y_val, axes)
    self.assertAllClose(xy_val, ref_val, atol=1e-5)

  def _reference_batch_dot(self, x, y, axes):
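    """NumPy reference implementation of `backend.batch_dot`.

    Normalizes `axes` to a two-element list, resolves negative axes, shifts
    both by one to skip the batch dimension, applies `np.tensordot` to each
    sample pair, and re-expands 1D results so the output keeps >= 2 dims.
    """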
    if isinstance(axes, int):
      axes = [axes, axes]
    elif isinstance(axes, tuple):
      axes = list(axes)
    if axes is None:
      if y.ndim == 2:
        axes = [x.ndim - 1, y.ndim - 1]
      else:
        axes = [x.ndim - 1, y.ndim - 2]
    if axes[0] < 0:
      axes[0] += x.ndim
    if axes[1] < 0:
      axes[1] += y.ndim
    result = []
    axes = [axes[0] - 1, axes[1] - 1]
    for xi, yi in zip(x, y):
      result.append(np.tensordot(xi, yi, axes))
    result = np.array(result)
    if result.ndim == 1:
      result = np.expand_dims(result, -1)
    return result

  def test_reduction_ops(self):
    ops_to_test = [
        (backend.max, np.max),
        (backend.min, np.min),
        (backend.sum, np.sum),
        (backend.prod, np.prod),
        (backend.var, np.var),
        (backend.std, np.std),
        (backend.mean, np.mean),
        (backend.argmin, np.argmin),
        (backend.argmax, np.argmax),
    ]
    for keras_op, np_op in ops_to_test:
      compare_single_input_op_to_numpy(
          keras_op,
          np_op,
          input_shape=(4, 7, 5),
          keras_kwargs={'axis': 1},
          np_kwargs={'axis': 1})
      compare_single_input_op_to_numpy(
          keras_op,
          np_op,
          input_shape=(4, 7, 5),
          keras_kwargs={'axis': -1},
          np_kwargs={'axis': -1})
      if 'keepdims' in tf_inspect.getargspec(keras_op).args:
        compare_single_input_op_to_numpy(
            keras_op,
            np_op,
            input_shape=(4, 7, 5),
            keras_kwargs={
                'axis': 1,
                'keepdims': True
            },
            np_kwargs={
                'axis': 1,
                'keepdims': True
            })

  def test_elementwise_ops(self):
    ops_to_test = [
        (backend.square, np.square),
        (backend.abs, np.abs),
        (backend.round, np.round),
        (backend.sign, np.sign),
        (backend.sin, np.sin),
        (backend.cos, np.cos),
        (backend.exp, np.exp),
    ]
    for keras_op, np_op in ops_to_test:
      compare_single_input_op_to_numpy(keras_op, np_op, input_shape=(4, 7))

    ops_to_test = [
        (backend.sqrt, np.sqrt),
        (backend.log, np.log),
    ]
    for keras_op, np_op in ops_to_test:
      compare_single_input_op_to_numpy(
          keras_op, np_op, input_shape=(4, 7), negative_values=False)

    # The Keras and NumPy clip bounds must match for the comparison to hold.
    compare_single_input_op_to_numpy(
        backend.clip,
        np.clip,
        input_shape=(6, 4),
        keras_kwargs={
            'min_value': 0.1,
            'max_value': 2.4
        },
        np_kwargs={
            'a_min': 0.1,
            'a_max': 2.4
        })

    compare_single_input_op_to_numpy(
        backend.pow, np.power, input_shape=(6, 4), keras_args=[3], np_args=[3])

  def test_two_tensor_ops(self):
    ops_to_test = [
        (backend.equal, np.equal),
        (backend.not_equal, np.not_equal),
        (backend.greater, np.greater),
        (backend.greater_equal, np.greater_equal),
        (backend.less, np.less),
        (backend.less_equal, np.less_equal),
        (backend.maximum, np.maximum),
        (backend.minimum, np.minimum),
    ]
    for keras_op, np_op in ops_to_test:
      compare_two_inputs_op_to_numpy(
          keras_op, np_op, input_shape_a=(4, 7), input_shape_b=(4, 7))

  def test_relu(self):
    x = ops.convert_to_tensor_v2_with_dispatch([[-4, 0], [2, 7]], 'float32')

    # standard relu
    relu_op = backend.relu(x)
    self.assertAllClose(backend.eval(relu_op), [[0, 0], [2, 7]])

    # alpha (leaky relu used)
    relu_op = backend.relu(x, alpha=0.5)
    if not context.executing_eagerly():
      self.assertIn('LeakyRelu', relu_op.name)
    self.assertAllClose(backend.eval(relu_op), [[-2, 0], [2, 7]])

    # max_value < some elements
    relu_op = backend.relu(x, max_value=5)
    self.assertAllClose(backend.eval(relu_op), [[0, 0], [2, 5]])

    # nn.relu6 used
    relu_op = backend.relu(x, max_value=6)
    if not context.executing_eagerly():
      self.assertIn('Relu6', relu_op.name)  # uses tf.nn.relu6
    self.assertAllClose(backend.eval(relu_op), [[0, 0], [2, 6]])

    # max value > 6
    relu_op = backend.relu(x, max_value=10)
    self.assertAllClose(backend.eval(relu_op), [[0, 0], [2, 7]])

    # max value is float
    relu_op = backend.relu(x, max_value=4.3)
    self.assertAllClose(backend.eval(relu_op), [[0, 0], [2, 4.3]])

    # max value == 0
    relu_op = backend.relu(x, max_value=0)
    self.assertAllClose(backend.eval(relu_op), [[0, 0], [0, 0]])

    # alpha and max_value
    relu_op = backend.relu(x, alpha=0.25, max_value=3)
    self.assertAllClose(backend.eval(relu_op), [[-1, 0], [2, 3]])

    # threshold
    relu_op = backend.relu(x, threshold=3)
    self.assertAllClose(backend.eval(relu_op), [[0, 0], [0, 7]])

    # threshold is float
    relu_op = backend.relu(x, threshold=1.5)
    self.assertAllClose(backend.eval(relu_op), [[0, 0], [2, 7]])

    # threshold is negative
    relu_op = backend.relu(x, threshold=-5)
    self.assertAllClose(backend.eval(relu_op), [[-4, 0], [2, 7]])

    # threshold and max_value
    relu_op = backend.relu(x, threshold=3, max_value=5)
    self.assertAllClose(backend.eval(relu_op), [[0, 0], [0, 5]])

    # threshold and alpha
    relu_op = backend.relu(x, alpha=0.25, threshold=4)
    self.assertAllClose(backend.eval(relu_op), [[-2, -1], [-0.5, 7]])

    # threshold, alpha, and max_value
    relu_op = backend.relu(x, alpha=0.25, threshold=4, max_value=5)
    self.assertAllClose(backend.eval(relu_op), [[-2, -1], [-0.5, 5]])

    # Test case for GitHub issue 35430, with integer dtype
    x = input_layer.Input(shape=(), name='x', dtype='int64')
    _ = advanced_activations.ReLU(max_value=100, dtype='int64')(x)


@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class BackendShapeOpsTest(test.TestCase):

  def test_reshape(self):
    compare_single_input_op_to_numpy(
        backend.reshape,
        np.reshape,
        input_shape=(4, 7),
        keras_args=[(2, 14)],
        np_args=[(2, 14)])

  def test_concatenate(self):
    a = backend.variable(np.ones((1, 2, 3)))
    b = backend.variable(np.ones((1, 2, 2)))
    y = backend.concatenate([a, b], axis=-1)
    self.assertEqual(y.shape.as_list(), [1, 2, 5])

  def test_permute_dimensions(self):
    compare_single_input_op_to_numpy(
        backend.permute_dimensions,
        np.transpose,
        input_shape=(4, 7),
        keras_args=[(1, 0)],
        np_args=[(1, 0)])

  def test_resize_images(self):
    height_factor = 2
    width_factor = 2
    data_format = 'channels_last'
    x = backend.variable(np.ones((1, 2, 2, 3)))
    y = backend.resize_images(x, height_factor, width_factor, data_format)
    self.assertEqual(y.shape.as_list(), [1, 4, 4, 3])

    data_format = 'channels_first'
    x = backend.variable(np.ones((1, 3, 2, 2)))
    y = backend.resize_images(x, height_factor, width_factor, data_format)
    self.assertEqual(y.shape.as_list(), [1, 3, 4, 4])

    # Use with a dynamic axis:
    if not context.executing_eagerly():
      x = backend.placeholder(shape=(1, 3, None, None))
      y = backend.resize_images(x, height_factor, width_factor, data_format)
      self.assertEqual(y.shape.as_list(), [1, 3, None, None])

    # Invalid use:
    with self.assertRaises(ValueError):
      backend.resize_images(
          x, height_factor, width_factor, data_format='unknown')

  def test_resize_volumes(self):
    height_factor = 2
    width_factor = 2
    depth_factor = 2
    data_format = 'channels_last'
    x = backend.variable(np.ones((1, 2, 2, 2, 3)))
    y = backend.resize_volumes(x, depth_factor, height_factor, width_factor,
                               data_format)
    self.assertEqual(y.shape.as_list(), [1, 4, 4, 4, 3])

    data_format = 'channels_first'
    x = backend.variable(np.ones((1, 3, 2, 2, 2)))
    y = backend.resize_volumes(x, depth_factor, height_factor, width_factor,
                               data_format)
    self.assertEqual(y.shape.as_list(), [1, 3, 4, 4, 4])

    # Invalid use:
    with self.assertRaises(ValueError):
      backend.resize_volumes(
          x, depth_factor, height_factor, width_factor, data_format='unknown')

  def test_repeat_elements(self):
    x = backend.variable(np.ones((1, 3, 2)))
    y = backend.repeat_elements(x, 3, axis=1)
    self.assertEqual(y.shape.as_list(), [1, 9, 2])

    # Use with a dynamic axis:
    if not context.executing_eagerly():
      x = backend.placeholder(shape=(2, None, 2))
      y = backend.repeat_elements(x, 3, axis=1)
      self.assertEqual(y.shape.as_list(), [2, None, 2])

  def test_repeat(self):
    x = backend.variable(np.ones((1, 3)))
    y = backend.repeat(x, 2)
    self.assertEqual(y.shape.as_list(), [1, 2, 3])

  def test_flatten(self):
    compare_single_input_op_to_numpy(
        backend.flatten,
        np.reshape,
        input_shape=(4, 7, 6),
        np_args=[(4 * 7 * 6,)])

  def test_batch_flatten(self):
    compare_single_input_op_to_numpy(
        backend.batch_flatten,
        np.reshape,
        input_shape=(4, 7, 6),
        np_args=[(4, 7 * 6)])

  def test_temporal_padding(self):

    def ref_op(x, padding):
      shape = list(x.shape)
      shape[1] += padding[0] + padding[1]
      y = np.zeros(tuple(shape))
      y[:, padding[0]:-padding[1], :] = x
      return y

    compare_single_input_op_to_numpy(
        backend.temporal_padding,
        ref_op,
        input_shape=(4, 7, 6),
        keras_args=[(2, 3)],
        np_args=[(2, 3)])

  def test_spatial_2d_padding(self):

    def ref_op(x, padding, data_format='channels_last'):
      shape = list(x.shape)
      if data_format == 'channels_last':
        shape[1] += padding[0][0] + padding[0][1]
        shape[2] += padding[1][0] + padding[1][1]
        y = np.zeros(tuple(shape))
        y[:, padding[0][0]:-padding[0][1], padding[1][0]:-padding[1][1], :] = x
      else:
        shape[2] += padding[0][0] + padding[0][1]
        shape[3] += padding[1][0] + padding[1][1]
        y = np.zeros(tuple(shape))
        y[:, :, padding[0][0]:-padding[0][1], padding[1][0]:-padding[1][1]] = x
      return y

    compare_single_input_op_to_numpy(
        backend.spatial_2d_padding,
        ref_op,
        input_shape=(2, 3, 2, 3),
        keras_args=[((2, 3), (1, 2))],
        keras_kwargs={'data_format': 'channels_last'},
        np_args=[((2, 3), (1, 2))],
        np_kwargs={'data_format': 'channels_last'})
    compare_single_input_op_to_numpy(
        backend.spatial_2d_padding,
        ref_op,
        input_shape=(2, 3, 2, 3),
        keras_args=[((2, 3), (1, 2))],
        keras_kwargs={'data_format': 'channels_first'},
        np_args=[((2, 3), (1, 2))],
        np_kwargs={'data_format': 'channels_first'})

  def test_spatial_3d_padding(self):

    def ref_op(x, padding, data_format='channels_last'):
      shape = list(x.shape)
      if data_format == 'channels_last':
        shape[1] += padding[0][0] + padding[0][1]
        shape[2] += padding[1][0] + padding[1][1]
        shape[3] += padding[2][0] + padding[2][1]
        y = np.zeros(tuple(shape))
        y[:, padding[0][0]:-padding[0][1], padding[1][0]:-padding[1][1],
          padding[2][0]:-padding[2][1], :] = x
      else:
        shape[2] += padding[0][0] + padding[0][1]
        shape[3] += padding[1][0] + padding[1][1]
        shape[4] += padding[2][0] + padding[2][1]
        y = np.zeros(tuple(shape))
        y[:, :, padding[0][0]:-padding[0][1], padding[1][0]:-padding[1][1],
          padding[2][0]:-padding[2][1]] = x
      return y

    compare_single_input_op_to_numpy(
        backend.spatial_3d_padding,
        ref_op,
        input_shape=(2, 3, 2, 3, 2),
        keras_args=[((2, 3), (1, 2), (2, 3))],
        keras_kwargs={'data_format': 'channels_last'},
        np_args=[((2, 3), (1, 2), (2, 3))],
        np_kwargs={'data_format': 'channels_last'})
    compare_single_input_op_to_numpy(
        backend.spatial_3d_padding,
        ref_op,
        input_shape=(2, 3, 2, 3, 2),
        keras_args=[((2, 3), (1, 2), (2, 3))],
        keras_kwargs={'data_format': 'channels_first'},
        np_args=[((2, 3), (1, 2), (2, 3))],
        np_kwargs={'data_format': 'channels_first'})


@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class BackendNNOpsTest(test.TestCase, parameterized.TestCase):

  def test_bias_add(self):
    keras_op = backend.bias_add
    np_op = np.add
    compare_two_inputs_op_to_numpy(
        keras_op, np_op, input_shape_a=(4, 7), input_shape_b=(7,))
    compare_two_inputs_op_to_numpy(
        keras_op, np_op, input_shape_a=(4, 3, 7), input_shape_b=(7,))
    compare_two_inputs_op_to_numpy(
        keras_op, np_op, input_shape_a=(4, 3, 5, 7), input_shape_b=(7,))
    compare_two_inputs_op_to_numpy(
        keras_op, np_op, input_shape_a=(4, 3, 5, 2, 7), input_shape_b=(7,))

    with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
      x = backend.variable((3, 4))
      b = backend.variable((3, 4))
      backend.bias_add(x, b)
    with self.assertRaises(ValueError):
      x = backend.variable((3, 4))
      b = backend.variable((4,))
      backend.bias_add(x, b, data_format='unknown')

  def test_bias_add_channels_first(self):

    def keras_op(x, b):
      return backend.bias_add(x, b, data_format='channels_first')

    def np_op(x, b):
      if x.ndim == 3:
        b = b.reshape((1, b.shape[0], 1))
      if x.ndim == 4:
        b = b.reshape((1, b.shape[0], 1, 1))
      return x + b

    compare_two_inputs_op_to_numpy(
        keras_op, np_op, input_shape_a=(4, 3, 7), input_shape_b=(3,))
    compare_two_inputs_op_to_numpy(
        keras_op, np_op, input_shape_a=(4, 3, 5, 7), input_shape_b=(3,))

  def test_pool2d(self):
    val = np.random.random((10, 3, 10, 10))
    x = backend.variable(val)
    y = backend.pool2d(
        x, (2, 2),
        strides=(1, 1),
        padding='valid',
        data_format='channels_first',
        pool_mode='max')
    self.assertEqual(y.shape.as_list(), [10, 3, 9, 9])

    y = backend.pool2d(
        x, (2, 2),
        strides=(1, 1),
        padding='valid',
        data_format='channels_first',
        pool_mode='avg')
    self.assertEqual(y.shape.as_list(), [10, 3, 9, 9])

    val = np.random.random((10, 10, 10, 3))
    x = backend.variable(val)
    y = backend.pool2d(
        x, (2, 2), strides=(1, 1), padding='valid', data_format='channels_last')
    self.assertEqual(y.shape.as_list(), [10, 9, 9, 3])

    val = np.random.random((10, 10, 10, 3))
    x = backend.variable(val)
    y = backend.pool2d(
        x, (2, 2), strides=(1, 1), padding='same', data_format='channels_last')
    self.assertEqual(y.shape.as_list(), [10, 10, 10, 3])

    val = np.random.random((10, 10, 10, 3))
    x = backend.variable(val)
    y = backend.pool2d(
        x, (2, 2), strides=(2, 2), padding='same', data_format='channels_last')
    self.assertEqual(y.shape.as_list(), [10, 5, 5, 3])

    with self.assertRaises(ValueError):
      y = backend.pool2d(
          x, (2, 2),
          strides=(2, 2),
          padding='other',
          data_format='channels_last')
    with self.assertRaises(ValueError):
      y = backend.pool2d(x, (2, 2), strides=(2, 2), data_format='other')
    with self.assertRaises(ValueError):
      y = backend.pool2d(x, (2, 2, 2), strides=(2, 2))
    with self.assertRaises(ValueError):
      y = backend.pool2d(x, (2, 2), strides=(2, 2, 2))
    with self.assertRaises(ValueError):
      y = backend.pool2d(x, (2, 2), strides=(2, 2), pool_mode='other')

  def test_pool3d(self):
    val = np.random.random((10, 3, 10, 10, 10))
    x = backend.variable(val)
    y = backend.pool3d(
        x, (2, 2, 2),
        strides=(1, 1, 1),
        padding='valid',
        data_format='channels_first',
        pool_mode='max')
    self.assertEqual(y.shape.as_list(), [10, 3, 9, 9, 9])

    y = backend.pool3d(
        x, (2, 2, 2),
        strides=(1, 1, 1),
        padding='valid',
        data_format='channels_first',
        pool_mode='avg')
    self.assertEqual(y.shape.as_list(), [10, 3, 9, 9, 9])

    val = np.random.random((10, 10, 10, 10, 3))
    x = backend.variable(val)
    y = backend.pool3d(
        x, (2, 2, 2),
        strides=(1, 1, 1),
        padding='valid',
        data_format='channels_last')
    self.assertEqual(y.shape.as_list(), [10, 9, 9, 9, 3])

    val = np.random.random((10, 10, 10, 10, 3))
    x = backend.variable(val)
    y = backend.pool3d(
        x, (2, 2, 2),
        strides=(1, 1, 1),
        padding='same',
        data_format='channels_last')
    self.assertEqual(y.shape.as_list(), [10, 10, 10, 10, 3])

    val = np.random.random((10, 10, 10, 10, 3))
    x = backend.variable(val)
    y = backend.pool3d(
        x, (2, 2, 2),
        strides=(2, 2, 2),
        padding='same',
        data_format='channels_last')
    self.assertEqual(y.shape.as_list(), [10, 5, 5, 5, 3])

  def test_conv1d(self):
    val = np.random.random((10, 4, 10))
    x = backend.variable(val)
    kernel_val = np.random.random((3, 4, 5))
    k = backend.variable(kernel_val)
    y = backend.conv1d(
        x, k, strides=(1,), padding='valid', data_format='channels_first')
    self.assertEqual(y.shape.as_list(), [10, 5, 8])

    val = np.random.random((10, 10, 4))
    x = backend.variable(val)
    y = backend.conv1d(
        x, k, strides=(1,), padding='valid', data_format='channels_last')
    self.assertEqual(y.shape.as_list(), [10, 8, 5])

    val = np.random.random((10, 10, 4))
    x = backend.variable(val)
    y = backend.conv1d(
        x, k, strides=(1,), padding='same', data_format='channels_last')
    self.assertEqual(y.shape.as_list(), [10, 10, 5])

    val = np.random.random((10, 10, 4))
    x = backend.variable(val)
    y = backend.conv1d(
        x, k, strides=(2,), padding='same', data_format='channels_last')
    self.assertEqual(y.shape.as_list(), [10, 5, 5])

  def test_local_conv_channels_dim(self):
    filters = 3
    batch_size = 2

    for input_shape in [(3, 5), (2, 3, 5), (2, 5, 3, 4)]:
      channels_in = input_shape[0]
      input_spatial_shape = input_shape[1:]
      dim = len(input_spatial_shape)

      inputs = np.random.normal(0, 1, (batch_size,) + input_shape)
      inputs_cf = backend.variable(inputs)

      for kernel_size in [1, 2]:
        for stride in [1, 2]:
          kernel_sizes = (kernel_size,) * dim
          strides = (stride,) * dim

          output_shape = tuple([
              (i - kernel_size + stride) // stride for i in input_spatial_shape
          ])

          kernel_shape = (np.prod(output_shape),
                          np.prod(kernel_sizes) * channels_in, filters)

          kernel = np.random.normal(
              0, 1,
              output_shape + (channels_in, np.prod(kernel_sizes), filters))

          kernel_cf = np.reshape(kernel, kernel_shape)
          kernel_cf = backend.variable(kernel_cf)

          conv_cf = backend.local_conv(inputs_cf, kernel_cf, kernel_sizes,
                                       strides, output_shape, 'channels_first')

          inputs_cl = np.transpose(inputs,
                                   [0, 2] + list(range(3, dim + 2)) + [1])
          inputs_cl = backend.variable(inputs_cl)

          kernel_cl = np.reshape(
              np.transpose(kernel,
                           list(range(dim)) + [dim + 1, dim, dim + 2]),
              kernel_shape)
          kernel_cl = backend.variable(kernel_cl)

          conv_cl = backend.local_conv(inputs_cl, kernel_cl, kernel_sizes,
                                       strides, output_shape, 'channels_last')

          conv_cf = backend.eval(conv_cf)
          conv_cl = backend.eval(conv_cl)

          self.assertAllCloseAccordingToType(
              conv_cf,
              np.transpose(conv_cl, [0, dim + 1] + list(range(1, dim + 1))),
              atol=1e-5)

  @parameterized.named_parameters(
      ('local_conv1d', (5, 6), (3,), (1,), (3,)),
      ('local_conv2d', (4, 5, 6), (3, 3), (1, 1), (2, 3)))
  def test_local_conv_1d_and_2d(self, input_shape, kernel_sizes, strides,
                                output_shape):
    filters = 3
    batch_size = 2

    inputs = np.random.normal(0, 1, (batch_size,) + input_shape)
    inputs = backend.variable(inputs)

    kernel = np.random.normal(0, 1,
                              (np.prod(output_shape), np.prod(kernel_sizes) *
                               input_shape[-1], filters))
    kernel = backend.variable(kernel)

    local_conv = backend.local_conv(inputs, kernel, kernel_sizes, strides,
                                    output_shape, 'channels_last')
    if len(output_shape) == 1:
      local_conv_dim = backend.local_conv1d(inputs, kernel, kernel_sizes,
                                            strides, 'channels_last')
    else:
      local_conv_dim = backend.local_conv2d(inputs, kernel, kernel_sizes,
                                            strides, output_shape,
                                            'channels_last')

    local_conv = backend.eval(local_conv)
    local_conv_dim = backend.eval(local_conv_dim)

    self.assertAllCloseAccordingToType(local_conv, local_conv_dim)

  def test_conv2d(self):
    kernel_val = np.random.random((3, 3, 4, 5))
    k = backend.variable(kernel_val)

    # Test channels_first
    val = np.random.random((10, 4, 10, 10))
    x = backend.variable(val)
    y = backend.conv2d(x, k, padding='valid', data_format='channels_first')
    self.assertEqual(y.shape.as_list(), [10, 5, 8, 8])

    # Test channels_last
    val = np.random.random((10, 10, 10, 4))
    x = backend.variable(val)
    y = backend.conv2d(
        x, k, strides=(1, 1), padding='valid', data_format='channels_last')
    self.assertEqual(y.shape.as_list(), [10, 8, 8, 5])

    # Test same padding
    val = np.random.random((10, 10, 10, 4))
    x = backend.variable(val)
    y = backend.conv2d(x, k, padding='same', data_format='channels_last')
    self.assertEqual(y.shape.as_list(), [10, 10, 10, 5])

    # Test dilation_rate
    val = np.random.random((10, 10, 10, 4))
    x = backend.variable(val)
    y = backend.conv2d(
        x, k, dilation_rate=(2, 2), padding='same', data_format='channels_last')
    self.assertEqual(y.shape.as_list(), [10, 10, 10, 5])

    # Test strides
    val = np.random.random((10, 10, 10, 4))
    x = backend.variable(val)
    y = backend.conv2d(
        x, k, strides=(2, 2), padding='same', data_format='channels_last')
    self.assertEqual(y.shape.as_list(), [10, 5, 5, 5])

    # Test invalid arguments
    with self.assertRaises(ValueError):
      y = backend.conv2d(
          x, k, (2, 2), padding='other', data_format='channels_last')
    with self.assertRaises(ValueError):
      y = backend.conv2d(x, k, (2, 2), data_format='other')
    with self.assertRaises(ValueError):
      y = backend.conv2d(x, k, (2, 2, 2))

  def test_conv2d_transpose(self):
    input_size = (7, 8)
    kernel_size = (3, 3)
    input_depth = 6
    filters = 6
    batch_size = 2

    kernel_val = np.random.random(kernel_size + (input_depth, filters))
    k = backend.variable(kernel_val)

    # Test channels_first
    input_val = np.random.random((batch_size, input_depth) + input_size)
    x = backend.variable(input_val)
    y = backend.conv2d_transpose(
        x,
        k, (batch_size, filters) + input_size,
        padding='same',
        data_format='channels_first')
    self.assertEqual(
        tuple(y.shape.as_list()), (batch_size, filters) + input_size)

    # Test channels_last
    input_val = np.random.random((batch_size,) + input_size + (input_depth,))
    x = backend.variable(input_val)
    y = backend.conv2d_transpose(
        x,
        k, (batch_size,) + input_size + (filters,),
        padding='same',
        data_format='channels_last')
    self.assertEqual(
        tuple(y.shape.as_list()), (batch_size,) + input_size + (filters,))

    # Test dilation_rate
    y = backend.conv2d_transpose(
        x,
        k, (batch_size,) + input_size + (filters,),
        padding='same',
        data_format='channels_last',
        dilation_rate=(2, 2))
    self.assertEqual(
        tuple(y.shape.as_list()), (batch_size,) + input_size + (filters,))

    # Test batch size of None in output_shape
    y = backend.conv2d_transpose(
        x,
        k, (None,) + input_size + (filters,),
        padding='same',
        data_format='channels_last')
    self.assertEqual(
        tuple(y.shape.as_list()), (batch_size,) + input_size + (filters,))

    # Test invalid values
    with self.assertRaises(ValueError):
      y = backend.conv2d_transpose(
          x, k, (2, 2, 8, 9), padding='other', data_format='channels_last')
    with self.assertRaises(ValueError):
      y = backend.conv2d_transpose(x, k, (2, 2, 8, 9), data_format='other')

  def test_separable_conv2d(self):
    val = np.random.random((10, 4, 10, 10))
    x = backend.variable(val)
    depthwise_kernel_val = np.random.random((3, 3, 4, 1))
    pointwise_kernel_val = np.random.random((1, 1, 4, 5))
    dk = backend.variable(depthwise_kernel_val)
    pk = backend.variable(pointwise_kernel_val)
    y = backend.separable_conv2d(
        x, dk, pk, padding='valid', data_format='channels_first')
    self.assertEqual(y.shape.as_list(), [10, 5, 8, 8])

    val = np.random.random((10, 10, 10, 4))
    x = backend.variable(val)
    y = backend.separable_conv2d(
        x, dk, pk, strides=(1, 1), padding='valid', data_format='channels_last')
    self.assertEqual(y.shape.as_list(), [10, 8, 8, 5])

    val = np.random.random((10, 10, 10, 4))
    x = backend.variable(val)
    y = backend.separable_conv2d(
        x, dk, pk, strides=(1, 1), padding='same', data_format='channels_last')
    self.assertEqual(y.shape.as_list(), [10, 10, 10, 5])

    val = np.random.random((10, 10, 10, 4))
    x = backend.variable(val)
    y = backend.separable_conv2d(
        x, dk, pk, strides=(2, 2), padding='same', data_format='channels_last')
    self.assertEqual(y.shape.as_list(), [10, 5, 5, 5])
    with self.assertRaises(ValueError):
      y = backend.separable_conv2d(
          x, dk, pk, (2, 2), padding='other', data_format='channels_last')
    with self.assertRaises(ValueError):
      y = backend.separable_conv2d(x, dk, pk, (2, 2), data_format='other')
    with self.assertRaises(ValueError):
      y = backend.separable_conv2d(x, dk, pk, (2, 2, 2))

  def test_conv3d(self):
    val = np.random.random((10, 4, 10, 10, 10))
    x = backend.variable(val)
    kernel_val = np.random.random((3, 3, 3, 4, 5))
    k = backend.variable(kernel_val)
    y = backend.conv3d(x, k, padding='valid', data_format='channels_first')
    self.assertEqual(y.shape.as_list(), [10, 5, 8, 8, 8])

    val = np.random.random((10, 10, 10, 10, 4))
    x = backend.variable(val)
    y = backend.conv3d(
        x, k, strides=(1, 1, 1), padding='valid', data_format='channels_last')
    self.assertEqual(y.shape.as_list(), [10, 8, 8, 8, 5])

    val = np.random.random((10, 10, 10, 10, 4))
    x = backend.variable(val)
    y = backend.conv3d(
        x, k, strides=(1, 1, 1), padding='same', data_format='channels_last')
    self.assertEqual(y.shape.as_list(), [10, 10, 10, 10, 5])

    val = np.random.random((10, 10, 10, 10, 4))
    x = backend.variable(val)
    y = backend.conv3d(
        x, k, strides=(2, 2, 2), padding='same', data_format='channels_last')
    self.assertEqual(y.shape.as_list(), [10, 5, 5, 5, 5])
    with self.assertRaises(ValueError):
      y = backend.conv3d(
          x, k, (2, 2, 2), padding='other', data_format='channels_last')
    with self.assertRaises(ValueError):
      y = backend.conv3d(x, k, (2, 2, 2), data_format='other')
    with self.assertRaises(ValueError):
      y = backend.conv3d(x, k, (2, 2))

  def test_rnn(self):
    # implement a simple RNN
    num_samples = 4
    input_dim = 5
    output_dim = 3
    timesteps = 6

    input_val = np.random.random(
        (num_samples, timesteps, input_dim)).astype(np.float32)
    init_state_val = np.random.random(
        (num_samples, output_dim)).astype(np.float32)
    w_i_val = np.random.random((input_dim, output_dim)).astype(np.float32)
    w_o_val = np.random.random((output_dim, output_dim)).astype(np.float32)
    np_mask = np.random.randint(2, size=(num_samples, timesteps))

    def rnn_step_fn():
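      """Builds a step function for a minimal fully-connected RNN cell."""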
      w_i = backend.variable(w_i_val)
      w_o = backend.variable(w_o_val)

      def step_function(x, states):
        assert len(states) == 1
        prev_output = states[0]
        output = backend.dot(x, w_i) + backend.dot(prev_output, w_o)
        return output, [output]

      return step_function

    # test default setup
    last_output_list = [[], [], [], [], [], []]
    outputs_list = [[], [], [], [], [], []]
    state_list = [[], [], [], [], [], []]

    rnn_fn = rnn_step_fn()
    inputs = backend.variable(input_val)
    initial_states = [backend.variable(init_state_val)]
    mask = backend.variable(np_mask)

    kwargs_list = [
        {
            'go_backwards': False,
            'mask': None
        },
        {
            'go_backwards': False,
            'mask': None,
            'unroll': True
        },
        {
            'go_backwards': True,
            'mask': None
        },
        {
            'go_backwards': True,
            'mask': None,
            'unroll': True
        },
        {
            'go_backwards': False,
            'mask': mask
        },
        {
            'go_backwards': False,
            'mask': mask,
            'unroll': True
        },
    ]
    for i, kwargs in enumerate(kwargs_list):
      last_output, outputs, new_states = backend.rnn(rnn_fn, inputs,
                                                     initial_states, **kwargs)
      # check static shape inference
      self.assertEqual(last_output.shape.as_list(), [num_samples, output_dim])
      self.assertEqual(outputs.shape.as_list(),
                       [num_samples, timesteps, output_dim])
      for state in new_states:
        self.assertEqual(state.shape.as_list(), [num_samples, output_dim])

      last_output_list[i].append(backend.eval(last_output))
      outputs_list[i].append(backend.eval(outputs))
      self.assertLen(new_states, 1)
      state_list[i].append(backend.eval(new_states[0]))

    def assert_list_pairwise(z_list, atol=1e-05):
      for (z1, z2) in zip(z_list[1:], z_list[:-1]):
        self.assertAllClose(z1, z2, atol=atol)

    assert_list_pairwise(last_output_list[0], atol=1e-04)
    assert_list_pairwise(outputs_list[0], atol=1e-04)
    assert_list_pairwise(state_list[0], atol=1e-04)
    assert_list_pairwise(last_output_list[2], atol=1e-04)
    assert_list_pairwise(outputs_list[2], atol=1e-04)
    assert_list_pairwise(state_list[2], atol=1e-04)

    for l, u_l in zip(last_output_list[0], last_output_list[1]):
      self.assertAllClose(l, u_l, atol=1e-04)

    for o, u_o in zip(outputs_list[0], outputs_list[1]):
      self.assertAllClose(o, u_o, atol=1e-04)

    for s, u_s in zip(state_list[0], state_list[1]):
      self.assertAllClose(s, u_s, atol=1e-04)

    for b_l, b_u_l in zip(last_output_list[2], last_output_list[3]):
      self.assertAllClose(b_l, b_u_l, atol=1e-04)

    for b_o, b_u_o in zip(outputs_list[2], outputs_list[3]):
      self.assertAllClose(b_o, b_u_o, atol=1e-04)

    for b_s, b_u_s in zip(state_list[2], state_list[3]):
      self.assertAllClose(b_s, b_u_s, atol=1e-04)

  def test_rnn_additional_states(self):
    # implement a simple RNN
    num_samples = 4
    input_dim = 5
    output_dim = 3
    timesteps = 6

    input_val = np.random.random(
        (num_samples, timesteps, input_dim)).astype(np.float32)
    init_state_val = np.random.random(
        (num_samples, output_dim)).astype(np.float32)
    w_i_val = np.random.random((input_dim, output_dim)).astype(np.float32)
    w_o_val = np.random.random((output_dim, output_dim)).astype(np.float32)
    np_mask = np.random.randint(2, size=(num_samples, timesteps))

    def rnn_step_fn():
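      """Builds a step function whose cell also emits a doubled-width state."""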
      w_i = backend.variable(w_i_val)
      w_o = backend.variable(w_o_val)

      def step_function(x, states):
        assert len(states) == 2
        prev_output = states[0]
        output = backend.dot(x, w_i) + backend.dot(prev_output, w_o)
        return output, [output, backend.concatenate([output, output], axis=-1)]

      return step_function

    # test default setup
    last_output_list = [[], [], [], [], [], []]
    outputs_list = [[], [], [], [], [], []]
    state_list = [[], [], [], [], [], []]
    additional_state_list = [[], [], [], [], [], []]

    rnn_fn = rnn_step_fn()
    inputs = backend.variable(input_val)
    initial_states = [
        backend.variable(init_state_val),
        ops.convert_to_tensor_v2_with_dispatch(
            np.concatenate([init_state_val, init_state_val], axis=-1))
    ]
    mask = backend.variable(np_mask)

    kwargs_list = [
        {
            'go_backwards': False,
            'mask': None
        },
        {
            'go_backwards': False,
            'mask': None,
            'unroll': True
        },
        {
            'go_backwards': True,
            'mask': None
        },
        {
            'go_backwards': True,
            'mask': None,
            'unroll': True
        },
        {
            'go_backwards': False,
            'mask': mask
        },
        {
            'go_backwards': False,
            'mask': mask,
            'unroll': True
        },
    ]
    for i, kwargs in enumerate(kwargs_list):
      last_output, outputs, new_states = backend.rnn(rnn_fn, inputs,
                                                     initial_states, **kwargs)
      # check static shape inference
      self.assertEqual(last_output.shape.as_list(), [num_samples, output_dim])
      self.assertEqual(outputs.shape.as_list(),
                       [num_samples, timesteps, output_dim])
      # The two states have different shapes, so check them individually.
      self.assertEqual(new_states[0].shape.as_list(), [num_samples, output_dim])
      self.assertEqual(new_states[1].shape.as_list(),
                       [num_samples, 2 * output_dim])

      last_output_list[i].append(backend.eval(last_output))
      outputs_list[i].append(backend.eval(outputs))
      self.assertLen(new_states, 2)
      state_list[i].append(backend.eval(new_states[0]))
      additional_state_list[i].append(backend.eval(new_states[1]))

    def assert_list_pairwise(z_list, atol=1e-05):
      for (z1, z2) in zip(z_list[1:], z_list[:-1]):
        self.assertAllClose(z1, z2, atol=atol)

    assert_list_pairwise(last_output_list[0], atol=1e-04)
    assert_list_pairwise(outputs_list[0], atol=1e-04)
    assert_list_pairwise(state_list[0], atol=1e-04)
    assert_list_pairwise(additional_state_list[0], atol=1e-04)
    assert_list_pairwise(last_output_list[2], atol=1e-04)
    assert_list_pairwise(outputs_list[2], atol=1e-04)
    assert_list_pairwise(state_list[2], atol=1e-04)
    assert_list_pairwise(additional_state_list[2], atol=1e-04)

    for l, u_l in zip(last_output_list[0], last_output_list[1]):
      self.assertAllClose(l, u_l, atol=1e-04)

    for o, u_o in zip(outputs_list[0], outputs_list[1]):
      self.assertAllClose(o, u_o, atol=1e-04)

    for s, u_s in zip(state_list[0], state_list[1]):
      self.assertAllClose(s, u_s, atol=1e-04)

    for s, u_s in zip(additional_state_list[0], additional_state_list[1]):
      self.assertAllClose(s, u_s, atol=1e-04)

    for b_l, b_u_l in zip(last_output_list[2], last_output_list[3]):
      self.assertAllClose(b_l, b_u_l, atol=1e-04)

    for b_o, b_u_o in zip(outputs_list[2], outputs_list[3]):
      self.assertAllClose(b_o, b_u_o, atol=1e-04)

    for b_s, b_u_s in zip(state_list[2], state_list[3]):
      self.assertAllClose(b_s, b_u_s, atol=1e-04)

    for s, u_s in zip(additional_state_list[2], additional_state_list[3]):
      self.assertAllClose(s, u_s, atol=1e-04)

  def test_rnn_output_and_state_masking_independent(self):
    num_samples = 2
    num_timesteps = 4
    state_and_io_size = 2
    mask_last_num_timesteps = 2  # for second sample only

    # A step function that just outputs its inputs,
    # but increments every state by 1 per timestep.
    def step_function(inputs, states):
      return inputs, [s + 1 for s in states]

    inputs_vals = np.random.random(
        (num_samples, num_timesteps, state_and_io_size))
    initial_state_vals = np.random.random((num_samples, state_and_io_size))
    # Mask the last two timesteps for the second sample only.
    mask_vals = np.ones((num_samples, num_timesteps))
    mask_vals[1, -mask_last_num_timesteps:] = 0

    # Outputs are expected to be the same as inputs for the first sample.
    expected_outputs = inputs_vals.copy()
    # For the second sample, all outputs in the masked region should equal
    # the last output before the masked region.
    expected_outputs[1, -mask_last_num_timesteps:] = (
        expected_outputs[1, -(mask_last_num_timesteps + 1)])

    expected_last_state = initial_state_vals.copy()
    # The first state should be incremented for every timestep (no masking).
    expected_last_state[0] += num_timesteps
    # The second state should not be incremented for the last two timesteps.
    expected_last_state[1] += (num_timesteps - mask_last_num_timesteps)

    # Verify the same expected output for `unroll=True/False`.
    inputs = backend.variable(inputs_vals)
    initial_states = [backend.variable(initial_state_vals)]
    mask = backend.variable(mask_vals)
    for unroll in [True, False]:
      _, outputs, last_states = backend.rnn(
          step_function,
          inputs,
          initial_states,
          mask=mask,
          unroll=unroll,
          input_length=num_timesteps if unroll else None)

      self.assertAllClose(backend.eval(outputs), expected_outputs)
      self.assertAllClose(backend.eval(last_states[0]), expected_last_state)

  def test_rnn_output_num_dim_larger_than_2_masking(self):
    num_samples = 3
    num_timesteps = 4
    num_features = 5

    def step_function(inputs, states):
      outputs = backend.tile(backend.expand_dims(inputs), [1, 1, 2])
      return outputs, [backend.identity(s) for s in states]
      # Note: the step function cannot simply return `states` unchanged,
      # since backend.rnn() would then call set_shape() on a ResourceVariable
      # and fail ("tensorflow/python/ops/resource_variable_ops.py", line 824,
      # in set_shape: NotImplementedError: ResourceVariable does not
      # implement set_shape()).

    inputs_vals = np.random.random((num_samples, num_timesteps, num_features))
    initial_state_vals = np.random.random((num_samples, 6))
    mask_vals = np.ones((num_samples, num_timesteps))
    mask_vals[-1, -1] = 0  # final timestep masked for the last sample

    expected_outputs = np.repeat(inputs_vals[..., None], repeats=2, axis=-1)
    # For the last sample, the final (masked) timestep should repeat the
    # second-to-last output (the last one before the masked region).
    expected_outputs[-1, -1] = expected_outputs[-1, -2]

    inputs = backend.variable(inputs_vals)
    initial_states = [backend.variable(initial_state_vals)]
    mask = backend.variable(mask_vals)
    for unroll in [True, False]:
      _, outputs, _ = backend.rnn(
          step_function,
          inputs,
          initial_states,
          mask=mask,
          unroll=unroll,
          input_length=num_timesteps if unroll else None)

      self.assertAllClose(backend.eval(outputs), expected_outputs)

  def test_rnn_state_num_dim_larger_than_2_masking(self):
    num_samples = 3
    num_timesteps = 4

    def step_function(inputs, states):
      return inputs, [s + 1 for s in states]

    inputs_vals = np.random.random((num_samples, num_timesteps, 5))
    initial_state_vals = np.random.random((num_samples, 6, 7))
    mask_vals = np.ones((num_samples, num_timesteps))
    mask_vals[0, -2:] = 0  # final two timesteps masked for first sample

    expected_last_state = initial_state_vals.copy()
    expected_last_state[0] += (num_timesteps - 2)
    expected_last_state[1:] += num_timesteps

    inputs = backend.variable(inputs_vals)
    initial_states = [backend.variable(initial_state_vals)]
    mask = backend.variable(mask_vals)
    for unroll in [True, False]:
      _, _, last_states = backend.rnn(
          step_function,
          inputs,
          initial_states,
          mask=mask,
          unroll=unroll,
          input_length=num_timesteps if unroll else None)

      self.assertAllClose(backend.eval(last_states[0]), expected_last_state)

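  # The normalization tests below exercise both channels-last (axis=-1) and
  # channels-first (axis=1) layouts; they assert only that the normalized
  # output preserves the input shape.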
  def test_batch_normalization(self):
    g_val = np.random.random((3,))
    b_val = np.random.random((3,))
    gamma = backend.variable(g_val)
    beta = backend.variable(b_val)

    # 3D NHC case
    val = np.random.random((10, 5, 3))
    x = backend.variable(val)
    mean, var = nn.moments(x, (0, 1), None, None, False)
    normed = backend.batch_normalization(
        x, mean, var, beta, gamma, axis=-1, epsilon=1e-3)
    self.assertEqual(normed.shape.as_list(), [10, 5, 3])

    # 4D NHWC case
    val = np.random.random((10, 5, 5, 3))
    x = backend.variable(val)
    mean, var = nn.moments(x, (0, 1, 2), None, None, False)
    normed = backend.batch_normalization(
        x, mean, var, beta, gamma, axis=-1, epsilon=1e-3)
    self.assertEqual(normed.shape.as_list(), [10, 5, 5, 3])

    # 4D NCHW case
    if not context.executing_eagerly():
      # Eager CPU kernel for NCHW does not exist.
      val = np.random.random((10, 3, 5, 5))
      x = backend.variable(val)
      mean, var = nn.moments(x, (0, 2, 3), None, None, False)
      normed = backend.batch_normalization(
          x, mean, var, beta, gamma, axis=1, epsilon=1e-3)
      self.assertEqual(normed.shape.as_list(), [10, 3, 5, 5])

  def test_normalize_batch_in_training(self):
    val = np.random.random((10, 3, 10, 10))
    x = backend.variable(val)
    reduction_axes = (0, 2, 3)

    g_val = np.random.random((3,))
    b_val = np.random.random((3,))
    gamma = backend.variable(g_val)
    beta = backend.variable(b_val)
    normed, mean, var = backend.normalize_batch_in_training(
        x, gamma, beta, reduction_axes, epsilon=1e-3)
    self.assertEqual(normed.shape.as_list(), [10, 3, 10, 10])
    self.assertEqual(mean.shape.as_list(), [3])
    self.assertEqual(var.shape.as_list(), [3])

    # case: gamma=None
    gamma = None
    normed, mean, var = backend.normalize_batch_in_training(
        x, gamma, beta, reduction_axes, epsilon=1e-3)
    self.assertEqual(normed.shape.as_list(), [10, 3, 10, 10])
    self.assertEqual(mean.shape.as_list(), [3])
    self.assertEqual(var.shape.as_list(), [3])

    # case: beta=None
    beta = None
    normed, mean, var = backend.normalize_batch_in_training(
        x, gamma, beta, reduction_axes, epsilon=1e-3)
    self.assertEqual(normed.shape.as_list(), [10, 3, 10, 10])
    self.assertEqual(mean.shape.as_list(), [3])
    self.assertEqual(var.shape.as_list(), [3])

  def test_dropout(self):
    inputs = array_ops.ones((200, 200))
    outputs = backend.dropout(inputs, 0.2)
    outputs_val = backend.eval(outputs)
    self.assertEqual(np.min(outputs_val), 0)
    # With rate=0.2, roughly 80% of the 40000 entries (32000) survive.
    self.assertAllClose(np.count_nonzero(outputs_val), 32000, atol=1000)
    # Test noise shape: with noise_shape=(200, 1) the dropout mask is
    # broadcast along axis 1, so all entries within a row are identical and
    # any two columns must match exactly.
    outputs = backend.dropout(inputs, 0.2, noise_shape=(200, 1))
    outputs_val = backend.eval(outputs)
    self.assertAllClose(outputs_val[:, 2], outputs_val[:, 3], atol=1e-5)


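# The tests below cover binary, categorical, and sparse-categorical
# cross-entropy for three kinds of predictions: plain probabilities, logits
# (via `from_logits=True`), and probabilities produced by an in-graph
# sigmoid/softmax op, which the backend detects and unwraps so the loss can
# be computed from the underlying logits.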
class BackendCrossEntropyLossesTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_binary_crossentropy_with_sigmoid(self):
    t = backend.constant([[0, 1, 0]])
    logits = backend.constant([[8., 1., 1.]])
    p = backend.sigmoid(logits)
    # Wrap in identity ops so the backend has to backtrack through them to
    # find the sigmoid op.
    p = array_ops.identity(array_ops.identity(p))
    result = self.evaluate(backend.binary_crossentropy(t, p))
    self.assertArrayNear(result[0], [8., 0.313, 1.313], 1e-3)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_categorical_crossentropy_loss(self):
    t = backend.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])

    p = backend.constant([[.9, .05, .05], [.05, .89, .06], [.05, .01, .94]])
    result = backend.categorical_crossentropy(t, p)
    self.assertArrayNear(self.evaluate(result), [.105, .116, .062], 1e-3)

    p = backend.constant([[.9, .05, .05], [.05, .89, .01], [.05, .06, .94]])
    result = backend.categorical_crossentropy(t, p, axis=0)
    self.assertArrayNear(self.evaluate(result), [.105, .116, .062], 1e-3)

    p = backend.constant([[8., 1., 1.], [0., 9., 1.], [2., 3., 5.]])
    result = backend.categorical_crossentropy(t, p, from_logits=True)
    self.assertArrayNear(self.evaluate(result), [.002, 0, .17], 1e-3)

    p = backend.constant([[8., 0., 2.], [1., 9., 3.], [1., 1., 5.]])
    result = backend.categorical_crossentropy(t, p, from_logits=True, axis=0)
    self.assertArrayNear(self.evaluate(result), [.002, 0, .17], 1e-3)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_categorical_crossentropy_loss_with_unknown_rank_tensor(self):
    t = backend.placeholder()
    p = backend.placeholder()
    o = backend.categorical_crossentropy(t, p)

    t_val = ops.convert_to_tensor_v2_with_dispatch([[1., 0., 0.], [0., 1., 0.],
                                                    [0., 0., 1.]])
    p_val = ops.convert_to_tensor_v2_with_dispatch([[.9, .05, .05],
                                                    [.05, .89, .06],
                                                    [.05, .01, .94]])
    f = backend.function([t, p], o)

    result = f([t_val, p_val])
    self.assertArrayNear(result, [.105, .116, .062], 1e-3)

    # With axis set
    o = backend.categorical_crossentropy(t, p, axis=0)
    f = backend.function([t, p], o)

    result = f([t_val, p_val])
    self.assertArrayNear(result, [.105, .065, .111], 1e-3)

    # From logits
    p_val = ops.convert_to_tensor_v2_with_dispatch([[8., 1., 1.], [0., 9., 1.],
                                                    [2., 3., 5.]])
    o = backend.categorical_crossentropy(t, p, from_logits=True)
    f = backend.function([t, p], o)

    result = f([t_val, p_val])
    self.assertArrayNear(result, [.002, 0, .17], 1e-3)

    # From logits, with axis set
    o = backend.categorical_crossentropy(t, p, from_logits=True, axis=0)
    f = backend.function([t, p], o)

    result = f([t_val, p_val])
    self.assertArrayNear(result, [.002, .003, .036], 1e-3)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_categorical_crossentropy_with_softmax(self):
    t = backend.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
    logits = backend.constant([[8., 1., 1.], [0., 9., 1.], [2., 3., 5.]])
    p = backend.softmax(logits)
    p = array_ops.identity(array_ops.identity(p))
    result = self.evaluate(backend.categorical_crossentropy(t, p))
    self.assertArrayNear(result, [0.002, 0.0005, 0.17], 1e-3)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_sparse_categorical_crossentropy_loss(self):
    t = backend.constant([0, 1, 2])

    p = backend.constant([[.9, .05, .05], [.05, .89, .06], [.05, .01, .94]])
    result = backend.sparse_categorical_crossentropy(t, p)
    self.assertArrayNear(self.evaluate(result), [.105, .116, .062], 1e-3)

    p = backend.constant([[.9, .05, .05], [.05, .89, .01], [.05, .06, .94]])
    result = backend.sparse_categorical_crossentropy(t, p, axis=0)
    self.assertArrayNear(self.evaluate(result), [.105, .116, .062], 1e-3)

    p = backend.constant([[8., 1., 1.], [0., 9., 1.], [2., 3., 5.]])
    result = backend.sparse_categorical_crossentropy(t, p, from_logits=True)
    self.assertArrayNear(self.evaluate(result), [.002, 0, .17], 1e-3)

    p = backend.constant([[8., 0., 2.], [1., 9., 3.], [1., 1., 5.]])
    result = backend.sparse_categorical_crossentropy(
        t, p, from_logits=True, axis=0)
    self.assertArrayNear(self.evaluate(result), [.002, 0, .17], 1e-3)

  @combinations.generate(combinations.combine(mode=['graph']))
  def test_sparse_categorical_crossentropy_loss_with_unknown_rank_tensor(self):
    # This test only runs in graph mode because the TF op layer does not
    # support sparse ops yet.
    t = backend.placeholder()
    p = backend.placeholder()
    o = backend.sparse_categorical_crossentropy(t, p)

    t_val = ops.convert_to_tensor_v2_with_dispatch([0, 1, 2])
    p_val = ops.convert_to_tensor_v2_with_dispatch([[.9, .05, .05],
                                                    [.05, .89, .06],
                                                    [.05, .01, .94]])
    f = backend.function([t, p], o)

    result = f([t_val, p_val])
    self.assertArrayNear(result, [.105, .116, .062], 1e-3)

    # With axis set, an unknown-rank tensor is rejected.
    with self.assertRaisesRegex(
        ValueError,
        'Cannot compute sparse categorical crossentropy with `axis=0`'):
      o = backend.sparse_categorical_crossentropy(t, p, axis=0)
      f = backend.function([t, p], o)

      _ = f([t_val, p_val])

    # From logits
    p_val = ops.convert_to_tensor_v2_with_dispatch([[8., 1., 1.], [0., 9., 1.],
                                                    [2., 3., 5.]])
    o = backend.sparse_categorical_crossentropy(t, p, from_logits=True)
    f = backend.function([t, p], o)

    result = f([t_val, p_val])
    self.assertArrayNear(result, [.002, 0, .17], 1e-3)

    # From logits, with axis set: again rejected for unknown rank.
    with self.assertRaisesRegex(
        ValueError,
        'Cannot compute sparse categorical crossentropy with `axis=0`'):
      o = backend.sparse_categorical_crossentropy(
          t, p, from_logits=True, axis=0)
      f = backend.function([t, p], o)

      _ = f([t_val, p_val])

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_sparse_categorical_crossentropy_with_softmax(self):
    t = backend.constant([0, 1, 2])
    logits = backend.constant([[8., 1., 1.], [0., 9., 1.], [2., 3., 5.]])
    p = backend.softmax(logits)
    p = array_ops.identity(array_ops.identity(p))
    result = self.evaluate(backend.sparse_categorical_crossentropy(t, p))
    self.assertArrayNear(result, [0.002, 0.0005, 0.17], 1e-3)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_binary_crossentropy_from_logits_no_warnings(self):
    t = backend.constant([[0, 1, 0]])
    logits = backend.constant([[8., 1., 1.]])
    with warnings.catch_warnings(record=True) as w:
      self.evaluate(backend.binary_crossentropy(t, logits, from_logits=True))
      self.assertEmpty(w)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_binary_crossentropy_from_logits_with_sigmoid(self):
    t = backend.constant([[0, 1, 0]])
    logits = backend.constant([[8., 1., 1.]])
    p = activations.sigmoid(logits)
    with warnings.catch_warnings(record=True) as w:
      self.evaluate(backend.binary_crossentropy(t, p, from_logits=True))
      self.assertLen(w, 1)
      self.assertIn('received `from_logits=True`', str(w[0].message))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_categorical_crossentropy_from_logits_with_softmax(self):
    t = backend.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
    logits = backend.constant([[8., 1., 1.], [0., 9., 1.], [2., 3., 5.]])
    p = activations.softmax(logits)
    with warnings.catch_warnings(record=True) as w:
      self.evaluate(backend.categorical_crossentropy(t, p, from_logits=True))
      self.assertLen(w, 1)
      self.assertIn('received `from_logits=True`', str(w[0].message))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_sparse_categorical_crossentropy_from_logits_with_softmax(self):
    t = backend.constant([0, 1, 2])
    logits = backend.constant([[8., 1., 1.], [0., 9., 1.], [2., 3., 5.]])
    p = activations.softmax(logits)
    with warnings.catch_warnings(record=True) as w:
      self.evaluate(
          backend.sparse_categorical_crossentropy(t, p, from_logits=True))
      self.assertLen(w, 1)
      self.assertIn('received `from_logits=True`', str(w[0].message))


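# backend.ctc_decode and backend.ctc_batch_cost are checked below against
# precomputed beam-search decodings and per-sample loss values.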
@test_util.with_control_flow_v2
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class TestCTC(test.TestCase):

  def test_ctc_decode(self):
    depth = 6
    seq_len_0 = 5
    input_prob_matrix_0 = np.asarray(
        [
            [0.30999, 0.309938, 0.0679938, 0.0673362, 0.0708352, 0.173908],
            [0.215136, 0.439699, 0.0370931, 0.0393967, 0.0381581, 0.230517],
            [0.199959, 0.489485, 0.0233221, 0.0251417, 0.0233289, 0.238763],
            [0.279611, 0.452966, 0.0204795, 0.0209126, 0.0194803, 0.20655],
            [0.51286, 0.288951, 0.0243026, 0.0220788, 0.0219297, 0.129878],
            # Random entry added in at time=5
            [0.155251, 0.164444, 0.173517, 0.176138, 0.169979, 0.160671]
        ],
        dtype=np.float32)

    # A list of max_time_steps (batch_size x depth) matrices.
    inputs = ([input_prob_matrix_0[t, :][np.newaxis, :]
               for t in range(seq_len_0)] +  # Pad to max_time_steps = 7.
              2 * [np.zeros((1, depth), dtype=np.float32)])

    inputs = backend.variable(np.asarray(inputs).transpose((1, 0, 2)))

    # batch_size-length vector of sequence lengths
    input_length = backend.variable(np.array([seq_len_0], dtype=np.int32))
    # batch_size-length vector of negative log probabilities
    log_prob_truth = np.array(
        [
            -3.5821197,  # output beam 0
            -3.777835  # output beam 1
        ],
        np.float32)[np.newaxis, :]

    # Expected decodings; -1 entries pad shorter beams to a common length.
    decode_truth = [
        np.array([1, 0, -1, -1, -1, -1, -1]),
        np.array([0, 1, 0, -1, -1, -1, -1])
    ]
    beam_width = 2
    top_paths = 2

    decode_pred_tf, log_prob_pred_tf = backend.ctc_decode(
        inputs,
        input_length,
        greedy=False,
        beam_width=beam_width,
        top_paths=top_paths)

    self.assertEqual(len(decode_pred_tf), top_paths)
    log_prob_pred = backend.eval(log_prob_pred_tf)
    for i in range(top_paths):
      self.assertTrue(
          np.all(decode_truth[i] == backend.eval(decode_pred_tf[i])))
    self.assertAllClose(log_prob_truth, log_prob_pred)

  def test_ctc_batch_cost(self):
    with self.cached_session():
      label_lens = np.expand_dims(np.asarray([5, 4]), 1)
      input_lens = np.expand_dims(np.asarray([5, 5]), 1)  # number of timesteps
      loss_log_probs = [3.34211, 5.42262]

      # Dimensions are batch x time x categories.
      labels = np.asarray([[0, 1, 2, 1, 0], [0, 1, 1, 0, -1]])
      inputs = np.asarray(
          [[[0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553],
            [0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436],
            [0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688],
            [0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533],
            [0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107]],
           [[0.30176, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508],
            [0.24082, 0.397533, 0.0557226, 0.0546814, 0.0557528, 0.19549],
            [0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, 0.202456],
            [0.280884, 0.429522, 0.0326593, 0.0339046, 0.0326856, 0.190345],
            [0.423286, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046]]],
          dtype=np.float32)

      labels = backend.variable(labels, dtype='int32')
      inputs = backend.variable(inputs, dtype='float32')
      input_lens = backend.variable(input_lens, dtype='int32')
      label_lens = backend.variable(label_lens, dtype='int32')
      res = backend.eval(
          backend.ctc_batch_cost(labels, inputs, input_lens, label_lens))
      self.assertAllClose(res[:, 0], loss_log_probs, atol=1e-05)

      # Test with batch_size = 1, i.e. a single sample.
      ref = [3.34211]
      input_lens = np.expand_dims(np.asarray([5]), 1)
      label_lens = np.expand_dims(np.asarray([5]), 1)

      labels = np.asarray([[0, 1, 2, 1, 0]])
      inputs = np.asarray(
          [[[0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553],
            [0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436],
            [0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688],
            [0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533],
            [0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107]]
          ],
          dtype=np.float32)

      k_labels = backend.variable(labels, dtype='int32')
      k_inputs = backend.variable(inputs, dtype='float32')
      k_input_lens = backend.variable(input_lens, dtype='int32')
      k_label_lens = backend.variable(label_lens, dtype='int32')
      res = backend.eval(
          backend.ctc_batch_cost(k_labels, k_inputs, k_input_lens,
                                 k_label_lens))
      self.assertAllClose(res[:, 0], ref, atol=1e-05)


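# Statistical sanity checks: each test draws a large sample and verifies its
# moments against the requested distribution, so tolerances are loose.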
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class TestRandomOps(test.TestCase):

  def test_random_normal(self):
    np.random.seed(123)
    x = backend.random_normal((500, 500))
    val = backend.eval(x)
    self.assertAllClose(np.mean(val), 0., atol=0.01)
    self.assertAllClose(np.std(val), 1., atol=0.01)

  def test_random_uniform(self):
    np.random.seed(123)
    x = backend.random_uniform((500, 500))
    val = backend.eval(x)
    self.assertAllClose(np.mean(val), 0.5, atol=0.01)
    self.assertAllClose(np.max(val), 1., atol=0.01)
    self.assertAllClose(np.min(val), 0., atol=0.01)

  def test_random_binomial(self):
    np.random.seed(123)
    x = backend.random_binomial((500, 500), p=0.5)
    self.assertAllClose(np.mean(backend.eval(x)), 0.5, atol=0.01)

  def test_truncated_normal(self):
    np.random.seed(123)
    x = backend.truncated_normal((1000, 1000), mean=0.0, stddev=1.0)
    y = backend.eval(x)
    self.assertAllClose(np.mean(y), 0., atol=0.01)
    # A unit normal truncated to [-2, 2] has standard deviation ~0.88.
    self.assertAllClose(np.std(y), 0.88, atol=0.01)
    self.assertAllClose(np.max(y), 2., atol=0.01)
    self.assertAllClose(np.min(y), -2., atol=0.01)


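# backend.function tests. Note that `updates` are only supported when
# constructing functions in graph mode, so those tests skip themselves under
# eager execution.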
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class FunctionTest(test.TestCase):

  def test_function_basics(self):
    if context.executing_eagerly():
      self.skipTest('eager backend.function does not support updates')
    x1 = backend.placeholder(shape=(), dtype='float32')
    x2 = backend.placeholder(shape=(), dtype='int32')
    v = backend.variable(10.)

    y1 = x1 + backend.cast(x2, 'float32') + v
    y2 = x1 * backend.cast(x2, 'float32')

    with ops.control_dependencies([y1]):
      u = backend.update(v, x1)

    f = backend.function([x1, x2], [y1, y2], updates=[u])
    output_values = f([2, 3])
    self.assertEqual(output_values, [15., 6.])
    self.assertEqual(backend.eval(v), 2.)

  def test_function_dict_outputs(self):
    x_ph = backend.placeholder(shape=(), name='x')
    y_ph = backend.placeholder(shape=(), name='y')
    outputs = {'x*y': y_ph * x_ph, 'x*x': x_ph * x_ph}

    f = backend.function(inputs=[x_ph, y_ph], outputs=outputs)
    x, y = 2., 5.
    results = f([x, y])

    self.assertEqual(results['x*y'], 10.)
    self.assertEqual(results['x*x'], 4.)

  def test_function_dict_inputs(self):
    placeholders = {
        'x': backend.placeholder(shape=()),
        'y': backend.placeholder(shape=())
    }
    outputs = [placeholders['x'] * placeholders['y']]

    f = backend.function(inputs=placeholders, outputs=outputs)
    results = f({'x': 2., 'y': 3.})
    self.assertEqual(results[0], 6.)

  def test_function_single_input_output(self):
    x_ph = backend.placeholder(shape=(), name='x')
    output = x_ph * x_ph
    f = backend.function(x_ph, output)
    result = f(2.)
    self.assertEqual(result, 4.)

  def test_tuple_updates(self):
    if context.executing_eagerly():
      self.skipTest('eager backend.function does not support updates')

    x_ph = backend.placeholder(ndim=2)
    v = backend.variable(np.ones((4, 2)))
    output = x_ph**2 + v
    new_v = v + x_ph
    f = backend.function(x_ph, output, updates=[(v, new_v)])
    input_val = np.random.random((4, 2))
    result = f(input_val)
    self.assertAllClose(result, input_val**2 + 1)
    self.assertAllClose(backend.get_value(v), np.ones((4, 2)) + input_val)


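# Graph-only backend.function behavior: feeding tensors and variables, extra
# `fetches`, a mutable `feed_dict`, run options/metadata, and fetch
# callbacks, all executed through tf.compat.v1 sessions.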
class BackendGraphTests(test.TestCase, parameterized.TestCase):

  @combinations.generate(combinations.combine(mode=['graph']))
  def test_function_placeholder_with_default(self):
    with backend.get_graph().as_default():
      x1 = array_ops.placeholder_with_default(
          np.array(2., dtype='float32'), shape=())
      x2 = array_ops.placeholder_with_default(
          np.array(3, dtype='int32'), shape=())
    y1 = x1 + backend.cast(x2, 'float32')
    y2 = x1 * backend.cast(x2, 'float32')
    f = backend.function([x1, x2], [y1, y2])
    output_values = f([4, 5])
    self.assertEqual(output_values, [9., 20.])
    # Passing None falls back to the placeholder defaults (2. and 3).
    output_values = f([None, None])
    self.assertEqual(output_values, [5., 6.])

  def test_function_tf_feed_symbols(self):
    # Test Keras backend functions with TF tensor inputs.
    with ops.Graph().as_default(), self.cached_session():
      # Test feeding a resource variable to `function`.
      x1 = backend.placeholder(shape=())
      x2 = backend.placeholder(shape=())
      lr = backend.learning_phase()  # Include a placeholder_with_default.

      y1 = backend.variable(10.)
      y2 = 3

      f = backend.function(
          inputs=[x1, x2, lr],
          outputs=[x1 + 1, backend.in_train_phase(x2 + 2, x2 - 1)])
      outs = f([y1, y2, None])  # Use default learning_phase value.
      self.assertEqual(outs, [11., 2.])
      outs = f([y1, y2, 1])  # Set learning phase value.
      self.assertEqual(outs, [11., 5.])

      # Test triggering a callable refresh by changing the input.
      y3 = backend.constant(20.)  # Test with tensor
      outs = f([y3, y2, None])
      self.assertEqual(outs, [21., 2.])

      y4 = 4  # Test with non-symbol
      outs = f([y4, y2, None])
      self.assertEqual(outs, [5., 2.])

      # Test with a different dtype
      y5 = backend.constant(10., dtype='float64')
      outs = f([y5, y2, None])
      self.assertEqual(outs, [11., 2.])

  def test_function_tf_fetches(self):
    # Additional operations can be passed to tf.compat.v1.Session().run() via
    # its `fetches` argument. In contrast to the `updates` argument of
    # backend.function(), these have no control dependency on `outputs`, so
    # they can run in parallel. They also do not contribute to the output of
    # backend.function().
    with ops.Graph().as_default(), self.cached_session():
      x = backend.variable(0.)
      y = backend.variable(0.)
      x_placeholder = backend.placeholder(shape=())
      y_placeholder = backend.placeholder(shape=())

      f = backend.function(
          inputs=[x_placeholder, y_placeholder],
          outputs=[x_placeholder + y_placeholder],
          updates=[(x, x_placeholder + 1.)],
          fetches=[backend.update(y, 5.)])
      output = f([10., 20.])
      self.assertEqual(output, [30.])
      self.assertEqual(backend.get_session().run(fetches=[x, y]), [11., 5.])

  def test_function_tf_feed_dict(self):
    # Additional substitutions can be passed to `tf.compat.v1.Session().run()`
    # via its `feed_dict` argument. The feed_dict is passed once, at
    # construction time, but its values can be modified between calls to
    # provide substitutions beyond the Keras inputs.
    with ops.Graph().as_default(), self.cached_session():
      x = backend.variable(0.)
      y = backend.variable(0.)
      x_placeholder = backend.placeholder(shape=())
      y_placeholder = backend.placeholder(shape=())

      feed_dict = {y_placeholder: 3.}
      fetches = [backend.update(y, y_placeholder * 10.)]
      f = backend.function(
          inputs=[x_placeholder],
          outputs=[x_placeholder + 1.],
          updates=[(x, x_placeholder + 10.)],
          feed_dict=feed_dict,
          fetches=fetches)
      output = f([10.])
      self.assertEqual(output, [11.])
      self.assertEqual(backend.get_session().run(fetches=[x, y]), [20., 30.])

      # A value updated in feed_dict takes effect on the next call to the
      # backend function.
      feed_dict[y_placeholder] = 4.
      output = f([20.])
      self.assertEqual(output, [21.])
      self.assertEqual(backend.get_session().run(fetches=[x, y]), [30., 40.])

  def test_function_tf_run_options_with_run_metadata(self):
    with ops.Graph().as_default(), self.cached_session():
      x_placeholder = backend.placeholder(shape=())
      y_placeholder = backend.placeholder(shape=())

      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      run_metadata = config_pb2.RunMetadata()
      # With run_options enabled, partition graphs are collected.
      f = backend.function(
          inputs=[x_placeholder, y_placeholder],
          outputs=[x_placeholder + y_placeholder],
          options=run_options,
          run_metadata=run_metadata)
      output = f([10., 20.])
      self.assertEqual(output, [30.])
      self.assertNotEmpty(run_metadata.partition_graphs)
      # Without run_options, no partition graphs are collected.
      f1 = backend.function(
          inputs=[x_placeholder, y_placeholder],
          outputs=[x_placeholder + y_placeholder],
          run_metadata=run_metadata)
      output1 = f1([10., 20.])
      self.assertEqual(output1, [30.])
      self.assertEmpty(run_metadata.partition_graphs)

  def test_function_fetch_callbacks(self):

    class CallbackStub(object):

      def __init__(self):
        self.times_called = 0
        self.callback_result = 0

      def _fetch_callback(self, result):
        self.times_called += 1
        self.callback_result = result

    with ops.Graph().as_default(), self.cached_session():
      callback = CallbackStub()
      x_placeholder = backend.placeholder(shape=())
      y_placeholder = backend.placeholder(shape=())

      callback_op = x_placeholder * y_placeholder

      f = backend.function(
          inputs=[x_placeholder, y_placeholder],
          outputs=[x_placeholder + y_placeholder])
      f.fetches.append(callback_op)
      f.fetch_callbacks[callback_op] = callback._fetch_callback

      _ = f([10., 20.])

      self.assertEqual(callback.times_called, 1)
      self.assertEqual(callback.callback_result, 200)

  def test_get_session_different_graphs(self):
    with ops.Graph().as_default():
      x = backend.constant(1)
      session = backend.get_session()
      self.assertIs(session, backend.get_session((x,)))
      self.assertIs(session, backend.get_session())
    with ops.Graph().as_default():
      self.assertIs(session, backend.get_session((x,)))
      self.assertIsNot(session, backend.get_session())


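# backend.switch lazily evaluates one of two zero-argument callables
# depending on a boolean condition, much like tf.cond, e.g.
# backend.switch(backend.less(x, y), lambda: x + 1., lambda: y - 1.).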
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class ControlOpsTests(test.TestCase):

  def test_function_switch_basics(self):
    x = array_ops.constant(2.0)
    y = array_ops.constant(3.0)

    def xpowy():
      return backend.pow(x, y)

    def ypowx():
      return backend.pow(y, x)

    tensor = backend.switch(backend.less(x, y), xpowy, ypowx)
    self.assertEqual(backend.eval(tensor), [8.0])

    tensor = backend.switch(backend.greater(x, y), xpowy, ypowx)
    self.assertEqual(backend.eval(tensor), [9.0])

  def test_unequal_rank(self):
    x = ops.convert_to_tensor_v2_with_dispatch(
        np.array([[1, 2, 3], [4, 5, 6]]), dtype='float32')
    y = ops.convert_to_tensor_v2_with_dispatch(
        np.array([1, 2, 3]), dtype='float32')

    def true_func():
      return x

    def false_func():
      return y

    with self.assertRaisesRegex(ValueError,
                                'Rank of `condition` should be less than'):
      backend.switch(backend.equal(x, x), false_func, true_func)


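# ContextValueCache maps graphs (or the eager context) to values; graph keys
# are held weakly, so entries disappear once a graph is garbage-collected.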
class ContextValueCacheTest(test.TestCase):

  def test_cache(self):
    cache = backend.ContextValueCache(list)
    graph1 = ops.Graph()
    graph2 = ops.Graph()

    cache[graph1].append(1)
    with graph1.as_default():
      cache[None].append(2)

    with graph2.as_default():
      cache[None].append(3)
    cache[graph2].append(4)

    self.assertAllEqual(cache[graph1], [1, 2])
    self.assertAllEqual(cache[graph2], [3, 4])

    with context.eager_mode():
      cache[None].append(5)
      cache[None].append(6)
      self.assertAllEqual(cache[None], [5, 6])

    self.assertLen(cache, 3)

    del graph1
    gc.collect()
    self.assertLen(cache, 2)

  def test_cache_in_parent_graph(self):
    cache = backend.ContextValueCache(int)
    cache.setdefault(None, backend.constant(5))

    with ops.Graph().as_default() as g:
      # g is not a child graph of the default test context, so the recursive
      # lookup will create a new default value.
      self.assertAllEqual(cache[g], 0)

    @def_function.function
    def fn():
      # The function graph is a child of the default test context, so
      # __getitem__ will return the previously saved value.
      return cache[ops.get_default_graph()]

    self.assertEqual(self.evaluate(fn()), 5)


if __name__ == '__main__':
  test.main()