• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for training routines."""
16
17from absl.testing import parameterized
18import numpy as np
19from tensorflow.python.keras import backend
20from tensorflow.python.keras import combinations
21from tensorflow.python.keras import testing_utils
22from tensorflow.python.keras.engine import input_layer
23from tensorflow.python.keras.engine import training
24from tensorflow.python.keras.layers.convolutional import Conv2D
25from tensorflow.python.platform import test
26
27
class TrainingGPUTest(test.TestCase, parameterized.TestCase):
  """Keras training tests that need a CUDA GPU (for `channels_first` convs)."""

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_model_with_crossentropy_losses_channels_first(self):
    """Tests use of all crossentropy losses with `channels_first`.

    Tests `sparse_categorical_crossentropy`, `categorical_crossentropy`,
    and `binary_crossentropy`.
    Verifies that evaluate gives the same result with either `channels_first`
    or `channels_last` image_data_format.
    """
    if not test.is_gpu_available(cuda_only=True):
      # Report the test as skipped rather than letting it pass vacuously.
      self.skipTest('Requires a CUDA GPU for channels_first convolutions.')

    def prepare_simple_model(input_tensor, loss_name, target):
      """Builds and compiles a single 1x1 Conv2D model for `loss_name`.

      The crossentropy is taken along the channel axis implied by the current
      `backend.image_data_format()`, and the number of output channels is
      derived from `target`.

      Raises:
        ValueError: if `loss_name` is not one of the supported losses.
      """
      axis = 1 if backend.image_data_format() == 'channels_first' else -1
      if loss_name == 'sparse_categorical_crossentropy':
        loss = lambda y_true, y_pred: backend.sparse_categorical_crossentropy(  # pylint: disable=g-long-lambda
            y_true, y_pred, axis=axis)
        # Sparse labels are class indices, so the class count is max + 1.
        num_channels = int(np.amax(target) + 1)
        activation = 'softmax'
      elif loss_name == 'categorical_crossentropy':
        loss = lambda y_true, y_pred: backend.categorical_crossentropy(  # pylint: disable=g-long-lambda
            y_true, y_pred, axis=axis)
        num_channels = target.shape[axis]
        activation = 'softmax'
      elif loss_name == 'binary_crossentropy':
        loss = lambda y_true, y_pred: backend.binary_crossentropy(  # pylint: disable=g-long-lambda, unnecessary-lambda
            y_true, y_pred)
        num_channels = target.shape[axis]
        activation = 'sigmoid'
      else:
        # Fail loudly instead of building Conv2D(None, ...) with loss=None.
        raise ValueError('Unsupported loss: {}'.format(loss_name))

      predictions = Conv2D(num_channels,
                           1,
                           activation=activation,
                           kernel_initializer='ones',
                           bias_initializer='ones')(input_tensor)
      simple_model = training.Model(inputs=input_tensor, outputs=predictions)
      simple_model.compile(optimizer='rmsprop', loss=loss)
      return simple_model

    losses_to_test = ['sparse_categorical_crossentropy',
                      'categorical_crossentropy', 'binary_crossentropy']

    data_channels_first = np.array([[[[8., 7.1, 0.], [4.5, 2.6, 0.55],
                                      [0.9, 4.2, 11.2]]]], dtype=np.float32)
    # Labels for testing 4-class sparse_categorical_crossentropy, 4-class
    # categorical_crossentropy, and 2-class binary_crossentropy:
    labels_channels_first = [np.array([[[[0, 1, 3], [2, 1, 0], [2, 2, 1]]]], dtype=np.float32),  # pylint: disable=line-too-long
                             np.array([[[[0, 1, 0], [0, 1, 0], [0, 0, 0]],
                                        [[1, 0, 0], [0, 0, 1], [0, 1, 0]],
                                        [[0, 0, 0], [1, 0, 0], [0, 0, 1]],
                                        [[0, 0, 1], [0, 0, 0], [1, 0, 0]]]], dtype=np.float32),  # pylint: disable=line-too-long
                             np.array([[[[0, 1, 0], [0, 1, 0], [0, 0, 1]],
                                        [[1, 0, 1], [1, 0, 1], [1, 1, 0]]]], dtype=np.float32)]  # pylint: disable=line-too-long

    def evaluate_all_losses(data_format):
      """Returns one `evaluate()` loss per entry of `losses_to_test`.

      Sets the global image data format to `data_format`. The fixtures above
      are stored channels-first; for `channels_last` axis 1 is moved to the
      end of both inputs and labels before building each model.
      """
      backend.set_image_data_format(data_format)
      if data_format == 'channels_last':
        to_format = lambda array: np.moveaxis(array, 1, -1)
      else:
        to_format = lambda array: array
      data = to_format(data_channels_first)
      results = []
      for loss_name, stored_labels in zip(losses_to_test,
                                          labels_channels_first):
        labels = to_format(stored_labels)
        # Per-sample input shape: (3, 3, 1) channels-last, (1, 3, 3) first.
        inputs = input_layer.Input(shape=data.shape[1:])
        model = prepare_simple_model(inputs, loss_name, labels)
        results.append(model.evaluate(x=data, y=labels,
                                      batch_size=1, verbose=0))
      return results

    with testing_utils.use_gpu():
      old_data_format = backend.image_data_format()
      try:
        loss_channels_last = evaluate_all_losses('channels_last')
        loss_channels_first = evaluate_all_losses('channels_first')
      finally:
        # The image data format is global backend state; always restore it so
        # a failure above cannot change the layout seen by later tests.
        backend.set_image_data_format(old_data_format)

      np.testing.assert_allclose(
          loss_channels_first,
          loss_channels_last,
          rtol=1e-06,
          err_msg='Computed different losses for channels_first and channels_last')
122
if __name__ == '__main__':
  # When executed directly, hand control to TensorFlow's test runner, which
  # collects and runs the TestCase classes defined in this module.
  test.main()
125