• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Functional tests for unary coefficient-wise operations."""
16
17import math
18
19import numpy as np
20
21from tensorflow.python.eager import backprop
22from tensorflow.python.framework import constant_op
23from tensorflow.python.framework import dtypes as dtypes_lib
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import sparse_tensor
26from tensorflow.python.framework import test_util
27from tensorflow.python.ops import gen_math_ops
28from tensorflow.python.ops import gradient_checker
29from tensorflow.python.ops import gradient_checker_v2
30from tensorflow.python.ops import math_ops
31from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
32from tensorflow.python.ops import special_math_ops
33from tensorflow.python.platform import test
34from tensorflow.python.platform import tf_logging
35
36_NEG = lambda x: -x
37_ABS = abs
38
39
# TODO(zongheng): it'd be great to factor out this function and various random
# SparseTensor gen funcs.
def _sparsify(x, thresh=0.5, index_dtype=np.int64):
  """Builds a SparseTensor from the entries of `x` that survive thresholding.

  Entries with `x < thresh` are treated as zero; every remaining non-zero entry
  becomes an explicit value of the returned SparseTensor. Thresholding is done
  on a copy, so the caller's array is not modified.

  Args:
    x: A NumPy array.
    thresh: Entries strictly below this value are zeroed before extraction.
    index_dtype: NumPy dtype used for the SparseTensor indices.

  Returns:
    A `(sparse_tensor, values)` tuple, where `values` is the 1-D NumPy array of
    kept entries in the same order as the SparseTensor's values.
  """
  # Copy so that zeroing below-threshold entries does not mutate the caller's
  # array (callers reuse the same array across several comparisons).
  x = np.array(x, copy=True)
  x[x < thresh] = 0

  non_zero = np.nonzero(x)
  # np.nonzero returns one index array per dimension; stack and transpose to
  # get one [i0, i1, ...] row per kept entry.
  x_indices = np.vstack(non_zero).astype(index_dtype).T
  x_values = x[non_zero]

  return sparse_tensor.SparseTensor(
      indices=x_indices, values=x_values, dense_shape=x.shape), x_values
52
53
def _default_tolerance(dtype):
  """Picks a default comparison tolerance based on the precision of `dtype`.

  Args:
    dtype: A NumPy datatype (bfloat16 given as its NumPy equivalent).

  Returns:
    5e-3 for bfloat16/float16, 1e-3 for float32/complex64, 1e-5 for
    float64/complex128, and None for any other dtype.
  """
  half_precision = (dtypes_lib.bfloat16.as_numpy_dtype, np.float16)
  single_precision = (np.float32, np.complex64)
  double_precision = (np.float64, np.complex128)

  if dtype in half_precision:
    return 5e-3
  if dtype in single_precision:
    return 1e-3
  if dtype in double_precision:
    return 1e-5
  # No default for other dtypes (e.g. integers).
  return None
70
71
class UnaryOpTest(test.TestCase):
  """Compares TF unary ops against NumPy/SciPy/`math` reference results."""

  def _compareCpu(self, x, np_func, tf_func, grad_rtol=None, grad_atol=None):
    """Checks `tf_func` against `np_func` on CPU, including its gradient.

    Args:
      x: NumPy input array.
      np_func: Reference implementation applied to the NumPy array.
      tf_func: TensorFlow op under test, applied to `x` as a tensor.
      grad_rtol: Relative tolerance for the Jacobian comparison. Defaults to
        `_default_tolerance(x.dtype)`.
      grad_atol: Absolute tolerance for the Jacobian comparison. Defaults to
        `_default_tolerance(x.dtype)`.
    """
    if grad_rtol is None:
      grad_rtol = _default_tolerance(x.dtype)
    if grad_atol is None:
      grad_atol = _default_tolerance(x.dtype)
    np_ans = np_func(x)
    with self.cached_session(use_gpu=False):
      inx = ops.convert_to_tensor(x)
      y = tf_func(inx)
      tf_cpu = self.evaluate(y)
      self.assertShapeEqual(np_ans, y)
      if x.dtype == np.float16:
        self.assertAllClose(np_ans, tf_cpu, rtol=1e-3, atol=1e-3)
      elif x.dtype == dtypes_lib.bfloat16.as_numpy_dtype:
        self.assertAllClose(np_ans, tf_cpu, rtol=1e-2, atol=1e-2)
      else:
        self.assertAllClose(np_ans, tf_cpu)

      if x.dtype in (np.complex64, np.complex128) and tf_func == math_ops.sign:
        return  # The gradient of complex sign is not checked.

      s = list(np.shape(x))
      if x.dtype in (np.float16, dtypes_lib.bfloat16.as_numpy_dtype):
        # Low-precision inputs: take the numerical Jacobian from a float64
        # evaluation of the op, then cast it back to the input dtype.
        jacob_t, _ = gradient_checker.compute_gradient(
            inx, s, y, s, x_init_value=x)
        xf = x.astype(np.float64)
        inxf = ops.convert_to_tensor(xf)
        yf = tf_func(inxf)
        _, jacob_n = gradient_checker.compute_gradient(
            inxf, s, yf, s, x_init_value=xf, delta=1e-2)
        jacob_n = jacob_n.astype(x.dtype)
      elif x.dtype in (np.float32, np.complex64):
        jacob_t, jacob_n = gradient_checker.compute_gradient(
            inx, s, y, s, x_init_value=x, delta=1e-3)
      elif x.dtype in (np.float64, np.complex128):
        jacob_t, jacob_n = gradient_checker.compute_gradient(
            inx, s, y, s, x_init_value=x, delta=1e-5)
      else:
        return  # Other dtypes (e.g. integers) skip the gradient check.
      self.assertAllClose(jacob_t, jacob_n, rtol=grad_rtol, atol=grad_atol)

  def _check(self, result_tensor, result_np, input_sp_t, tol):
    """Asserts a sparse op kept the input's sparsity pattern and values match.

    Args:
      result_tensor: SparseTensor produced by the op under test.
      result_np: Expected values from the NumPy reference.
      input_sp_t: SparseTensor that was fed to the op.
      tol: rtol/atol for the value comparison, or None for the defaults.
    """
    self.assertIsInstance(result_tensor, sparse_tensor.SparseTensor)
    self.assertIsInstance(input_sp_t, sparse_tensor.SparseTensor)
    self.assertAllEqual(input_sp_t.indices, result_tensor.indices)
    self.assertAllEqual(input_sp_t.dense_shape, result_tensor.dense_shape)
    if tol is None:
      self.assertAllClose(result_np, result_tensor.values)
    else:
      self.assertAllClose(result_np, result_tensor.values, rtol=tol, atol=tol)

  def _compareSparseCpu(self, x, np_func, tf_func, tol):
    """Checks `tf_func` on a sparsified `x` with the op forced onto CPU."""
    x_sp, x_sp_vals = _sparsify(x)
    res_np = np_func(x_sp_vals)
    with test_util.force_cpu():
      self._check(tf_func(x_sp), res_np, x_sp, tol)

  def _compareGpu(self, x, np_func, tf_func):
    """Checks forward results of `tf_func` under `test_util.use_gpu()`."""
    np_ans = np_func(x)
    with test_util.use_gpu():
      result = tf_func(ops.convert_to_tensor(x))
      tf_gpu = self.evaluate(result)
      # Slightly increase the tolerance for float64 computations. This is
      # desired for specifically lgamma but shouldn't be of concern for other
      # functions.
      self.assertAllCloseAccordingToType(np_ans, tf_gpu, atol=2e-6)
    # TODO(zhifengc/ke): make gradient checker work on GPU.

  def _compareSparseGpu(self, x, np_func, tf_func, tol):
    """Checks `tf_func` on a sparsified `x` under `test_util.use_gpu()`."""
    x_sp, x_sp_vals = _sparsify(x)
    res_np = np_func(x_sp_vals)
    with test_util.use_gpu():
      self._check(tf_func(x_sp), res_np, x_sp, tol)

  def _compareBoth(self, x, np_func, tf_func, grad_tol=None):
    """Runs the CPU (with gradients) and GPU dense comparisons."""
    self._compareCpu(x, np_func, tf_func, grad_rtol=grad_tol,
                     grad_atol=grad_tol)
    self._compareGpu(x, np_func, tf_func)

  def _compareBothSparse(self, x, np_func, tf_func, tol=None):
    """Runs the CPU and GPU sparse comparisons."""
    self._compareSparseCpu(x, np_func, tf_func, tol)
    self._compareSparseGpu(x, np_func, tf_func, tol)

  def _compareSpecialFunctions(self, x):
    """Compares the scaled Bessel functions against SciPy, when installed."""
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warning("Cannot test special functions: %s", e)
      return
    # Kept outside the try so a failing comparison is never mistaken for a
    # missing SciPy.
    self._compareBoth(x, special.i0e, special_math_ops.bessel_i0e)
    self._compareBoth(x, special.i1e, special_math_ops.bessel_i1e)

  def _inv(self, x):
    return 1.0 / x

  def _rsqrt(self, x):
    return self._inv(np.sqrt(x))

  def _sigmoid(self, x):
    return 1.0 / (1.0 + np.exp(-x))

  def _log_sigmoid(self, x):
    return np.log(self._sigmoid(x))

  def _complex_sign(self, x):
    """Reference complex sign, `x / |x|`.

    Numpy uses an incorrect definition of sign for complex inputs; this is
    the definition the TF op is checked against.
    """
    return x / np.abs(x)

  def _replace_domain_error_with_inf(self, fn):
    """Wraps `fn` so a ValueError mentioning "domain error" yields +inf.

    Other ValueErrors propagate unchanged. Used with `math.lgamma`, which
    raises "math domain error" at its poles.
    """

    def func(x):
      try:
        return fn(x)
      except ValueError as e:
        if "domain error" in str(e):
          return np.inf * np.ones_like(x)
        raise

    return func

  @test_util.run_deprecated_v1
  def testFloatBasic(self):
    x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float32)
    w = x - x.min() + 1.02  # all greater than 1
    y = (x + .5).astype(np.float32)  # no zero
    z = (x + 15.5).astype(np.float32)  # all positive
    k = np.arange(-0.90, 0.90, 0.25).astype(np.float32)  # between -1 and 1

    self._compareBoth(x, np.abs, math_ops.abs)
    self._compareBoth(x, np.abs, _ABS)
    self._compareBoth(x, np.negative, math_ops.negative)
    self._compareBoth(x, np.negative, _NEG)
    self._compareBoth(y, self._inv, math_ops.reciprocal)
    self._compareBoth(x, np.square, math_ops.square)
    self._compareBoth(z, np.sqrt, math_ops.sqrt)
    self._compareBoth(z, self._rsqrt, math_ops.rsqrt)
    self._compareBoth(x, np.exp, math_ops.exp)
    self._compareBoth(x, np.expm1, math_ops.expm1)
    self._compareBoth(z, np.log, math_ops.log)
    self._compareBoth(z, np.log1p, math_ops.log1p)
    self._compareBoth(x, np.sinh, math_ops.sinh)
    self._compareBoth(x, np.cosh, math_ops.cosh)
    self._compareBoth(x, np.tanh, math_ops.tanh)
    self._compareBoth(x, np.arcsinh, math_ops.asinh)
    self._compareBoth(w, np.arccosh, math_ops.acosh)
    self._compareBoth(k, np.arctanh, math_ops.atanh)
    self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
    self._compareBoth(x, self._log_sigmoid, math_ops.log_sigmoid)
    self._compareBoth(y, np.sign, math_ops.sign)
    self._compareBoth(x, np.sin, math_ops.sin)
    self._compareBoth(x, np.cos, math_ops.cos)
    self._compareBoth(k, np.arcsin, math_ops.asin)
    self._compareBoth(k, np.arccos, math_ops.acos)
    self._compareBoth(x, np.arctan, math_ops.atan)
    self._compareBoth(x, np.tan, math_ops.tan)
    self._compareBoth(
        y, np.vectorize(self._replace_domain_error_with_inf(math.lgamma)),
        math_ops.lgamma)
    self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
    self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
    self._compareSpecialFunctions(x)

    self._compareBothSparse(x, np.abs, math_ops.abs)
    self._compareBothSparse(x, np.negative, math_ops.negative)
    self._compareBothSparse(x, np.square, math_ops.square)
    self._compareBothSparse(z, np.sqrt, math_ops.sqrt, tol=1e-3)
    self._compareBothSparse(x, np.tanh, math_ops.tanh)
    self._compareBothSparse(y, np.sign, math_ops.sign)
    self._compareBothSparse(x, np.vectorize(math.erf), math_ops.erf)

  @test_util.run_deprecated_v1
  def testFloatTanhEdge(self):
    x = np.arange(40, 40 + 6).reshape(6).astype(np.float32)
    self._compareBoth(x, np.tanh, math_ops.tanh)
    x = np.arange(-40, -40 + 6).reshape(6).astype(np.float32)
    self._compareBoth(x, np.tanh, math_ops.tanh)

  @test_util.run_deprecated_v1
  def testFloatEmpty(self):
    x = np.empty((2, 0, 5), dtype=np.float32)
    self._compareBoth(x, np.abs, math_ops.abs)
    self._compareBoth(x, np.abs, _ABS)
    self._compareBoth(x, np.negative, math_ops.negative)
    self._compareBoth(x, np.negative, _NEG)
    self._compareBoth(x, self._inv, math_ops.reciprocal)
    self._compareBoth(x, np.square, math_ops.square)
    self._compareBoth(x, np.sqrt, math_ops.sqrt)
    self._compareBoth(x, self._rsqrt, math_ops.rsqrt)
    self._compareBoth(x, np.exp, math_ops.exp)
    self._compareBoth(x, np.expm1, math_ops.expm1)
    self._compareBoth(x, np.log, math_ops.log)
    self._compareBoth(x, np.log1p, math_ops.log1p)
    self._compareBoth(x, np.sinh, math_ops.sinh)
    self._compareBoth(x, np.arcsinh, math_ops.asinh)
    self._compareBoth(x, np.cosh, math_ops.cosh)
    self._compareBoth(x, np.arccosh, math_ops.acosh)
    self._compareBoth(x, np.tanh, math_ops.tanh)
    self._compareBoth(x, np.arctanh, math_ops.atanh)
    self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
    self._compareBoth(x, np.sign, math_ops.sign)
    self._compareBoth(x, np.sin, math_ops.sin)
    self._compareBoth(x, np.cos, math_ops.cos)
    # Can't use vectorize below, so just use some arbitrary function
    self._compareBoth(x, np.sign, math_ops.lgamma)
    self._compareBoth(x, np.sign, math_ops.erf)
    self._compareBoth(x, np.sign, math_ops.erfc)
    self._compareBoth(x, np.tan, math_ops.tan)
    self._compareBoth(x, np.arcsin, math_ops.asin)
    self._compareBoth(x, np.arccos, math_ops.acos)
    self._compareBoth(x, np.arctan, math_ops.atan)
    self._compareSpecialFunctions(x)

    self._compareBothSparse(x, np.abs, math_ops.abs)
    self._compareBothSparse(x, np.negative, math_ops.negative)
    self._compareBothSparse(x, np.square, math_ops.square)
    self._compareBothSparse(x, np.sqrt, math_ops.sqrt, tol=1e-3)
    self._compareBothSparse(x, np.tanh, math_ops.tanh)
    self._compareBothSparse(x, np.sign, math_ops.sign)
    self._compareBothSparse(x, np.sign, math_ops.erf)

  @test_util.run_deprecated_v1
  def testDoubleBasic(self):
    x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64)
    w = x - x.min() + 1.02  # all greater than 1
    y = (x + .5).astype(np.float64)  # no zero
    z = (x + 15.5).astype(np.float64)  # all positive
    k = np.arange(-0.90, 0.90,
                  0.35).reshape(1, 3, 2).astype(np.float64)  # between -1 and 1
    self._compareBoth(x, np.abs, math_ops.abs)
    self._compareBoth(x, np.abs, _ABS)
    self._compareBoth(x, np.negative, math_ops.negative)
    self._compareBoth(x, np.negative, _NEG)
    self._compareBoth(y, self._inv, math_ops.reciprocal)
    self._compareBoth(x, np.square, math_ops.square)
    self._compareBoth(z, np.sqrt, math_ops.sqrt)
    self._compareBoth(z, self._rsqrt, math_ops.rsqrt)
    self._compareBoth(x, np.exp, math_ops.exp)
    self._compareBoth(x, np.expm1, math_ops.expm1)
    self._compareBoth(z, np.log, math_ops.log)
    self._compareBoth(z, np.log1p, math_ops.log1p)
    self._compareBoth(x, np.sinh, math_ops.sinh)
    self._compareBoth(x, np.cosh, math_ops.cosh)
    self._compareBoth(x, np.tanh, math_ops.tanh)
    self._compareBoth(x, np.arcsinh, math_ops.asinh)
    self._compareBoth(w, np.arccosh, math_ops.acosh)
    self._compareBoth(k, np.arctanh, math_ops.atanh)
    self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
    self._compareBoth(y, np.sign, math_ops.sign)
    self._compareBoth(x, np.sin, math_ops.sin)
    self._compareBoth(x, np.cos, math_ops.cos)
    self._compareBoth(
        y, np.vectorize(self._replace_domain_error_with_inf(math.lgamma)),
        math_ops.lgamma)
    self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
    self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
    self._compareBoth(x, np.arctan, math_ops.atan)
    self._compareBoth(k, np.arcsin, math_ops.asin)
    self._compareBoth(k, np.arccos, math_ops.acos)
    self._compareBoth(k, np.tan, math_ops.tan)
    self._compareSpecialFunctions(x)

    self._compareBothSparse(x, np.abs, math_ops.abs)
    self._compareBothSparse(x, np.negative, math_ops.negative)
    self._compareBothSparse(x, np.square, math_ops.square)
    self._compareBothSparse(z, np.sqrt, math_ops.sqrt, tol=1e-3)
    self._compareBothSparse(x, np.tanh, math_ops.tanh)
    self._compareBothSparse(y, np.sign, math_ops.sign)
    self._compareBothSparse(x, np.vectorize(math.erf), math_ops.erf)

  @test_util.run_deprecated_v1
  def testHalfBasic(self):
    x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float16)
    w = x - x.min() + 1.1  # all greater than 1
    y = (x + .5).astype(np.float16)  # no zero
    z = (x + 15.5).astype(np.float16)  # all positive
    k = np.arange(-0.90, 0.90, 0.05).astype(np.float16)  # between -1 and 1
    self._compareBoth(x, np.abs, math_ops.abs)
    self._compareBoth(x, np.abs, _ABS)
    self._compareBoth(x, np.negative, math_ops.negative)
    self._compareBoth(x, np.negative, _NEG)
    self._compareBoth(y, self._inv, math_ops.reciprocal)
    self._compareBoth(x, np.square, math_ops.square)
    self._compareBoth(z, np.sqrt, math_ops.sqrt)
    self._compareBoth(z, self._rsqrt, math_ops.rsqrt)
    self._compareBoth(x, np.exp, math_ops.exp)
    self._compareBoth(x, np.expm1, math_ops.expm1)
    self._compareBoth(z, np.log, math_ops.log)
    self._compareBoth(z, np.log1p, math_ops.log1p)
    self._compareBoth(x, np.sinh, math_ops.sinh)
    self._compareBoth(x, np.cosh, math_ops.cosh)
    self._compareBoth(x, np.tanh, math_ops.tanh)
    self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
    self._compareBoth(y, np.sign, math_ops.sign)
    self._compareBoth(x, np.sin, math_ops.sin)
    self._compareBoth(x, np.cos, math_ops.cos)
    self._compareBoth(x, np.tan, math_ops.tan)
    self._compareBoth(k, np.arcsin, math_ops.asin)
    self._compareBoth(k, np.arccos, math_ops.acos)
    self._compareBoth(x, np.arctan, math_ops.atan)
    self._compareBoth(x, np.arcsinh, math_ops.asinh)
    # The derivative of acosh close to 1 is very large, and needs a high
    # tolerance for small precision.
    self._compareBoth(w, np.arccosh, math_ops.acosh, grad_tol=1e-3)
    self._compareBoth(k, np.arctanh, math_ops.atanh)
    self._compareBoth(
        y, np.vectorize(self._replace_domain_error_with_inf(math.lgamma)),
        math_ops.lgamma)
    self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
    self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
    self._compareBothSparse(x, np.abs, math_ops.abs)
    self._compareBothSparse(x, np.negative, math_ops.negative)
    self._compareBothSparse(x, np.square, math_ops.square)
    self._compareBothSparse(z, np.sqrt, math_ops.sqrt, tol=1e-3)
    self._compareBothSparse(x, np.tanh, math_ops.tanh)
    self._compareBothSparse(y, np.sign, math_ops.sign)
    self._compareBothSparse(x, np.vectorize(math.erf), math_ops.erf, tol=1e-3)

  @test_util.run_deprecated_v1
  def testBFloat16Basic(self):

    def compute_f32(np_func):
      """Decorator to compute Numpy function with float32 math."""

      def f(x):
        y = np_func(x.astype(np.float32))
        return y.astype(x.dtype)

      return f

    bfloat16 = dtypes_lib.bfloat16.as_numpy_dtype
    x = np.arange(-6, 6,
                  2).reshape(1, 3, 2).astype(bfloat16)
    w = x - x.min() + 1.1  # all greater than 1
    y = (x + .5).astype(bfloat16)  # no zero
    z = (x + 15.5).astype(bfloat16)  # all positive
    k = np.arange(-0.90, 0.90, 0.05).astype(bfloat16)  # between -1 and 1
    self._compareCpu(x, np.abs, math_ops.abs)
    self._compareCpu(x, np.abs, _ABS)
    self._compareBoth(x, np.negative, math_ops.negative)
    self._compareBoth(x, np.negative, _NEG)
    self._compareCpu(y, compute_f32(self._inv), math_ops.reciprocal)
    self._compareCpu(x, np.exp, math_ops.exp)
    self._compareCpu(x, np.expm1, math_ops.expm1)
    self._compareCpu(z, compute_f32(np.log), math_ops.log)
    self._compareCpu(z, compute_f32(np.log1p), math_ops.log1p)
    self._compareCpu(y, np.sign, math_ops.sign)
    self._compareCpu(z, self._rsqrt, math_ops.rsqrt)
    self._compareBoth(x, compute_f32(np.sin), math_ops.sin)
    self._compareBoth(x, compute_f32(np.cos), math_ops.cos)
    self._compareBoth(x, compute_f32(np.tan), math_ops.tan)
    self._compareBoth(x, compute_f32(np.sinh), math_ops.sinh)
    self._compareBoth(x, compute_f32(np.cosh), math_ops.cosh)
    self._compareBoth(x, compute_f32(np.tanh), math_ops.tanh)
    self._compareBoth(k, compute_f32(np.arcsin), math_ops.asin)
    self._compareBoth(k, compute_f32(np.arccos), math_ops.acos)
    self._compareBoth(x, compute_f32(np.arctan), math_ops.atan)
    self._compareBoth(x, compute_f32(np.arcsinh), math_ops.asinh)
    self._compareBoth(w, compute_f32(np.arccosh), math_ops.acosh)
    self._compareBoth(k, compute_f32(np.arctanh), math_ops.atanh,
                      grad_tol=1e-2)
    self._compareBoth(x, compute_f32(np.vectorize(math.erf)), math_ops.erf)
    self._compareBoth(x, compute_f32(np.vectorize(math.erfc)), math_ops.erfc)

  @test.disable_with_predicate(
      pred=test.is_built_with_rocm, skip_message="On ROCm this test fails")
  def testInt8Basic(self):
    x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int8)
    self._compareCpu(x, np.abs, math_ops.abs)
    self._compareCpu(x, np.abs, _ABS)
    self._compareBoth(x, np.negative, math_ops.negative)
    self._compareBoth(x, np.negative, _NEG)
    self._compareBoth(x, np.sign, math_ops.sign)

  @test.disable_with_predicate(
      pred=test.is_built_with_rocm, skip_message="On ROCm this test fails")
  def testUInt8Basic(self):
    x = np.arange(6).reshape(1, 3, 2).astype(np.uint8)
    self._compareBoth(x, np.square, math_ops.square)

  @test.disable_with_predicate(
      pred=test.is_built_with_rocm, skip_message="On ROCm this test fails")
  def testInt16Basic(self):
    x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int16)
    self._compareCpu(x, np.abs, math_ops.abs)
    self._compareCpu(x, np.abs, _ABS)
    self._compareBoth(x, np.negative, math_ops.negative)
    self._compareBoth(x, np.negative, _NEG)
    self._compareBoth(x, np.sign, math_ops.sign)

  @test.disable_with_predicate(
      pred=test.is_built_with_rocm, skip_message="On ROCm this test fails")
  def testUInt16Basic(self):
    x = np.arange(6).reshape(1, 3, 2).astype(np.uint16)
    self._compareBoth(x, np.square, math_ops.square)

  def testInt32Basic(self):
    x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int32)
    self._compareCpu(x, np.abs, math_ops.abs)
    self._compareCpu(x, np.abs, _ABS)
    self._compareBoth(x, np.negative, math_ops.negative)
    self._compareBoth(x, np.negative, _NEG)
    self._compareBoth(x, np.square, math_ops.square)
    self._compareCpu(x, np.sign, math_ops.sign)

    self._compareBothSparse(x, np.abs, math_ops.abs)
    self._compareBothSparse(x, np.negative, math_ops.negative)
    self._compareBothSparse(x, np.square, math_ops.square)
    self._compareBothSparse(x, np.sign, math_ops.sign)

  @test.disable_with_predicate(
      pred=test.is_built_with_rocm, skip_message="On ROCm this test fails")
  def testUInt32Basic(self):
    x = np.arange(6).reshape(1, 3, 2).astype(np.uint32)
    self._compareBoth(x, np.square, math_ops.square)

  def testInt64Basic(self):
    x = np.arange(-6 << 40, 6 << 40, 2 << 40).reshape(1, 3, 2).astype(np.int64)
    self._compareCpu(x, np.abs, math_ops.abs)
    self._compareCpu(x, np.abs, _ABS)
    self._compareCpu(x, np.negative, math_ops.negative)
    self._compareCpu(x, np.negative, _NEG)
    self._compareCpu(x, np.sign, math_ops.sign)

    self._compareBothSparse(x, np.abs, math_ops.abs)
    self._compareBothSparse(x, np.negative, math_ops.negative)
    self._compareBothSparse(x, np.sign, math_ops.sign)

  def testInt64Square(self):
    x = np.arange(-6 << 20, 6 << 20, 2 << 20).reshape(1, 3, 2).astype(np.int64)
    self._compareCpu(x, np.square, math_ops.square)
    self._compareBothSparse(x, np.square, math_ops.square)

  @test.disable_with_predicate(
      pred=test.is_built_with_rocm, skip_message="On ROCm this test fails")
  def testUInt64Basic(self):
    x = np.arange(6).reshape(1, 3, 2).astype(np.uint64)
    self._compareBoth(x, np.square, math_ops.square)

  @test_util.run_deprecated_v1
  def testComplex64Basic(self):
    x = (1 + 1j) * np.arange(-3, 3).reshape(1, 3, 2).astype(np.complex64)
    y = x + (0.5 + 0.5j)  # no zeros
    self._compareBoth(x, np.abs, math_ops.abs)
    self._compareBoth(x, np.abs, _ABS)
    self._compareBoth(x, np.negative, math_ops.negative)
    self._compareBoth(x, np.negative, _NEG)
    self._compareBoth(y, self._inv, math_ops.reciprocal)
    self._compareCpu(x, np.square, math_ops.square)
    self._compareCpu(y, np.sqrt, math_ops.sqrt)
    self._compareCpu(y, self._rsqrt, math_ops.rsqrt)
    self._compareBoth(x, np.exp, math_ops.exp)
    self._compareCpu(x, np.expm1, math_ops.expm1)
    self._compareCpu(y, np.log, math_ops.log)
    self._compareCpu(y, np.log1p, math_ops.log1p)
    self._compareCpu(x, np.sinh, math_ops.sinh)
    self._compareCpu(x, np.cosh, math_ops.cosh)
    self._compareCpu(x, np.tanh, math_ops.tanh)
    self._compareCpu(x, np.arcsin, math_ops.asin)
    self._compareCpu(x, np.arctan, math_ops.atan)

    # Complex64 versions of asinh() and acosh() in libstdc++ only have 6 digits
    # of precision.
    # Small gradient values + low precision --> High relative error
    self._compareCpu(y, np.arcsinh, math_ops.asinh, grad_rtol=1e-2)
    self._compareCpu(y, np.arccosh, math_ops.acosh, grad_rtol=1e-2)

    self._compareCpu(y, np.arctanh, math_ops.atanh)
    self._compareCpu(x, self._sigmoid, math_ops.sigmoid)
    self._compareCpu(x, np.sin, math_ops.sin)
    self._compareCpu(x, np.cos, math_ops.cos)

    self._compareBothSparse(x, np.abs, math_ops.abs)
    self._compareBothSparse(x, np.negative, math_ops.negative)
    self._compareBothSparse(x, np.square, math_ops.square)
    self._compareBothSparse(x, np.sqrt, math_ops.sqrt, 1e-3)
    self._compareBothSparse(x, np.tanh, math_ops.tanh)

    self._compareBoth(y, self._complex_sign, math_ops.sign)
    self._compareBothSparse(y, self._complex_sign, math_ops.sign)

  @test_util.run_deprecated_v1
  def testComplex128Basic(self):
    x = (1 + 1j) * np.arange(-3, 3).reshape(1, 3, 2).astype(np.complex128)
    y = x + (0.5 + 0.5j)  # no zeros
    self._compareBoth(x, np.abs, math_ops.abs)
    self._compareBoth(x, np.abs, _ABS)
    self._compareBoth(x, np.negative, math_ops.negative)
    self._compareBoth(x, np.negative, _NEG)
    self._compareBoth(y, self._inv, math_ops.reciprocal)
    self._compareCpu(x, np.square, math_ops.square)
    self._compareCpu(y, np.sqrt, math_ops.sqrt)
    self._compareCpu(y, self._rsqrt, math_ops.rsqrt)
    self._compareBoth(x, np.exp, math_ops.exp)
    self._compareCpu(x, np.expm1, math_ops.expm1)
    self._compareCpu(y, np.log, math_ops.log)
    self._compareCpu(y, np.log1p, math_ops.log1p)
    self._compareCpu(x, np.sinh, math_ops.sinh)
    self._compareCpu(x, np.cosh, math_ops.cosh)
    self._compareCpu(x, np.tanh, math_ops.tanh)
    self._compareCpu(y, np.arcsinh, math_ops.asinh)
    self._compareCpu(y, np.arccosh, math_ops.acosh)
    self._compareCpu(y, np.arctanh, math_ops.atanh)
    self._compareCpu(x, self._sigmoid, math_ops.sigmoid)
    self._compareCpu(x, np.sin, math_ops.sin)
    self._compareCpu(x, np.cos, math_ops.cos)
    self._compareCpu(x, np.arcsin, math_ops.asin)
    self._compareCpu(x, np.arctan, math_ops.atan)

    self._compareBothSparse(x, np.abs, math_ops.abs)
    self._compareBothSparse(x, np.negative, math_ops.negative)
    self._compareBothSparse(x, np.square, math_ops.square)
    self._compareBothSparse(x, np.sqrt, math_ops.sqrt, 1e-3)
    self._compareBothSparse(x, np.tanh, math_ops.tanh)

    self._compareBoth(y, self._complex_sign, math_ops.sign)
    self._compareBothSparse(y, self._complex_sign, math_ops.sign)

  @test_util.run_deprecated_v1
  def testGradGrad(self):
    np.random.seed(7)
    shape = (5,)
    dtype_tols = [(np.float32, 5e-4), (np.float64, 1e-6), (np.complex64, 5e-4),
                  (np.complex128, 1e-6)]
    op_range = [
        (gen_math_ops.reciprocal_grad, [-2, 2]),
        (gen_math_ops.rsqrt_grad, [0.1, 3]),
        (gen_math_ops.sigmoid_grad, [-2, 2]),
        (gen_math_ops.sqrt_grad, [0.1, 3]),
        (gen_math_ops.tanh_grad, [-2, 2]),
    ]

    def rand(dtype, real_range):
      x = np.random.uniform(
          real_range[0], real_range[1], size=shape[0]).astype(dtype)
      if dtype in (np.complex64, np.complex128):
        x += 1j * np.random.uniform(-2, 2, size=shape[0]).astype(dtype)
      return x

    for op, real_range in op_range:
      with self.cached_session():
        for dtype, tol in dtype_tols:
          x = constant_op.constant(rand(dtype, real_range))
          y = constant_op.constant(rand(dtype, real_range))
          z = op(x, y)
          grads = gradient_checker.compute_gradient(
              [x, y], [shape, shape],
              z,
              shape,
              x_init_value=[rand(dtype, real_range),
                            rand(dtype, real_range)])
          if isinstance(grads, tuple):
            grads = [grads]
          for analytical, numerical in grads:
            self.assertAllClose(analytical, numerical, rtol=tol, atol=tol)

  @test_util.run_in_graph_and_eager_modes
  def testComplexAbsGradGrad(self):

    def f(x):
      real = math_ops.cos(x)
      imag = ops.convert_to_tensor(1.)
      return math_ops.abs(math_ops.complex(real, imag))

    def g(x):
      with backprop.GradientTape() as t:
        t.watch(x)
        y = f(x)
      return t.gradient(y, x)

    err = gradient_checker_v2.max_error(
        *gradient_checker_v2.compute_gradient(g, [ops.convert_to_tensor(2.0)]))
    self.assertLess(err, 1e-3)
659
660
if __name__ == "__main__":
  # Hand control to the TensorFlow test runner when executed as a script.
  test.main()
663