# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for deterministic functionality of SoftmaxCrossEntropyWithLogits op."""

import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.kernel_tests.nn_ops import xent_op_test_base
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import nn_ops
# The following import is required to register the gradient function.
from tensorflow.python.ops.nn_grad import _SoftmaxCrossEntropyWithLogitsGrad  # pylint: disable=unused-import
from tensorflow.python.platform import test


class XentOpDeterminismExceptionsTest(test.TestCase):
  """Test d9m-unimplemented exceptions from SoftmaxXentWithLogitsOp.

  Test that tf.errors.UnimplementedError is thrown, as appropriate, by the GPU
  code-paths through SoftmaxXentWithLogitsOp when deterministic ops are
  enabled.

  This test assumes that xent_op_test.py runs equivalent test cases when
  deterministic ops are not enabled and will therefore detect erroneous
  exception throwing in those cases.
  """

  @test_util.run_gpu_only
  @test_util.run_in_graph_and_eager_modes
  def testExceptionThrowing(self):
    with self.session(), test_util.force_gpu():
      for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
        features = constant_op.constant([[0.3, 0.5], [0.5, 0.6]], dtype=dtype)
        labels = constant_op.constant([[0.2, 0.4], [0.1, 0.2]], dtype=dtype)
        with self.assertRaisesRegex(
            errors_impl.UnimplementedError,
            "The GPU implementation of SoftmaxCrossEntropyWithLogits that " +
            "would have been executed is not deterministic. Note that the " +
            "Python API uses an alternative, deterministic, GPU-accelerated " +
            "path when determinism is enabled."):
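          # The raw op is called directly here (rather than via the public
          # Python API) so that the nondeterministic GPU kernel, and therefore
          # the exception, is exercised.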
          result = gen_nn_ops.softmax_cross_entropy_with_logits(
              features=features, labels=labels)
          self.evaluate(result)


class XentOpDeterministicTest(xent_op_test_base.XentOpTestBase):
  """Test that SoftmaxCrossEntropyWithLogits operates reproducibly.

  Inheriting from xent_op_test_base.XentOpTestBase ensures that regular op
  functionality is correct when the deterministic code-path is selected.

  Note that because nn_ops.softmax_cross_entropy_with_logits calls
  nn_ops.softmax_cross_entropy_with_logits_v2, the focus of testing is on the
  former in order to test both.
  """

  def _randomFloats(self, shape, dtype, normalized_rows=False):
    a = (2 * np.random.random_sample(shape) - 1).astype(dtype)

    if normalized_rows:
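      # Rescale each row to sum to one so that it forms a valid probability
      # distribution, as required for the labels input.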

      def normalize(row):
        return row / row.sum()

      a = np.apply_along_axis(normalize, 1, a)

    return constant_op.constant(a)

  def _generateInputs(self, dtype, seed=123, forward_not_backward=False):
    batch_size = 1024
    if forward_not_backward and dtype == np.float16:
      # Generate more noise to expose the internal float32 implementation.
      # This is associated with significantly slower test cases (esp. on CPU).
      classes_count = 20000
    else:
      classes_count = 3000
    shape = (batch_size, classes_count)
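    # Seed NumPy so that repeated calls with the same seed generate identical
    # inputs.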
    np.random.seed(seed)
    labels = self._randomFloats(shape, dtype, normalized_rows=True)
    logits = self._randomFloats(shape, dtype)
    return labels, logits

  @test_util.run_in_graph_and_eager_modes
  def testForward(self):
    with self.cached_session():
      for dtype in [np.float16, np.float32, np.float64,
                    dtypes.bfloat16.as_numpy_dtype]:

        for trial in range(5):
          seed = 123 + trial
          labels, logits = self._generateInputs(
              dtype, seed=seed, forward_not_backward=True)
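          # With op determinism enabled, running the op twice on the same
          # inputs must produce identical results.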
          result_a = nn_ops.softmax_cross_entropy_with_logits(
              labels=labels, logits=logits)
          result_b = nn_ops.softmax_cross_entropy_with_logits(
              labels=labels, logits=logits)
          self.assertAllEqual(result_a, result_b)

  @test_util.run_in_graph_and_eager_modes
  def testBackward(self):
    with self.cached_session():
      for dtype in [np.float16, np.float32, np.float64,
                    dtypes.bfloat16.as_numpy_dtype]:
        labels, logits = self._generateInputs(dtype, seed=456)
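        # The op produces one loss value per batch element, so the injected
        # upstream gradients have shape [batch_size].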
        output_shape = labels.shape[0]

        def gradients(seed):
          np.random.seed(seed)
          upstream_gradients = self._randomFloats(output_shape, dtype)
          with backprop.GradientTape(persistent=True) as tape:
            tape.watch(labels)
            tape.watch(logits)
            op_output = nn_ops.softmax_cross_entropy_with_logits(
                labels=labels, logits=logits)
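            # Multiplying by random upstream gradients injects nontrivial
            # gradients to be backpropagated through the op under test.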
            gradient_injector_output = op_output * upstream_gradients
          return tape.gradient(gradient_injector_output, [labels, logits])

        for trial in range(5):
          seed = 456 + trial
          labels_grad_a, logits_grad_a = gradients(seed=seed)
          labels_grad_b, logits_grad_b = gradients(seed=seed)
          self.assertAllEqual(labels_grad_a, labels_grad_b)
          self.assertAllEqual(logits_grad_a, logits_grad_b)

  # Modifications to the parent class (xent_op_test_base.XentOpTestBase) follow

  def testSingleClass(self):
    """Modify testing of gradient for single-class case.

    The deterministic implementation does not produce the gradients expected by
    the original test (for the nondeterministic functionality) when the labels
    vector is not a valid probability distribution.

    labels: [[-1.], [0.], [1.], [1.]]
    logits: [[1.], [-1.], [0.], [1.]]

                   nondeterministic               deterministic
    dloss/dlogits: [[2.0], [1.0], [0.0], [0.0]]   [[0.0], [0.0], [0.0], [0.0]]

    Note that only the last two label vectors are valid probability
    distributions (as required by the API) and that the gradient matches for
    those cases.

    TODO(duncanriach): Further investigate the source of the difference in
                       the gradients for this case.
    """
    self._testSingleClass(expected_gradient=[[0.0], [0.0], [0.0], [0.0]])

  def testLabelsBroadcast(self):
    """Modify testing of gradient for labels-broadcast case.

    The deterministic implementation does not produce the gradients expected by
    the original test (for the nondeterministic functionality) when the labels
    vector (after broadcasting) is not a valid probability distribution.

    labels: [[0.], [2.], [0.25]]
    logits: [[1., 1., 1., 1.],
             [1., 2., 3., 4.],
             [1., 2., 3., 4.]]

    dloss/dlogits (nondeterministic):
        [[ 0.25 ,  0.25 ,  0.25 ,  0.25 ],
         [-1.968, -1.913, -1.763, -1.355],
         [-0.218, -0.163, -0.013,  0.394]]

    dloss/dlogits (deterministic):
        [[ 0.   ,  0.   ,  0.   ,  0.   ],
         [-1.743, -1.303, -0.105,  3.150],
         [-0.218, -0.163, -0.013,  0.394]]

    Note that neither of the first two broadcast label vectors is a valid
    probability distribution (as required by the API) and that these are the
    cases that yield different gradients for nondeterministic vs deterministic
    implementations.

    TODO(duncanriach): Further investigate the source of the difference in
                       the gradient for this case.
    """
    self._testLabelsBroadcast(uniform_labels_gradient=[[
        0., 0., 0., 0.
    ], [-1.743, -1.303, -0.105, 3.150], [-0.218, -0.163, -0.013, 0.394]])


if __name__ == "__main__":
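  # Enable op determinism so that the deterministic code-paths under test are
  # selected.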
  config.enable_op_determinism()
  test.main()