• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test cases for operators with no arguments."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.compiler.tests import xla_test
24from tensorflow.python.framework import constant_op
25from tensorflow.python.ops import control_flow_ops
26from tensorflow.python.platform import googletest
27
28
class NullaryOpsTest(xla_test.XLATestCase):
  """XLA test cases for ops that take no inputs: `no_op` and constants."""

  def _testNullary(self, op, expected):
    """Builds `op()` in the test scope, evaluates it and checks the result.

    Args:
      op: zero-argument callable that creates the tensor to evaluate. It is
        invoked inside `self.test_scope()` (presumably the device scope under
        test); the evaluation itself happens outside that scope.
      expected: array-like value the evaluated output must match, compared
        with `assertAllClose` at rtol=1e-3.
    """
    with self.cached_session() as sess:
      with self.test_scope():
        fetched = op()
      self.assertAllClose(sess.run(fetched), expected, rtol=1e-3)

  def _assertConstantsRoundTrip(self, values):
    """Checks that each entry of `values` survives `constant_op.constant`.

    Args:
      values: sequence of scalars/ndarrays; each one is wrapped in a constant
        op, evaluated via `_testNullary`, and compared against itself.
    """
    for value in values:
      # Bind `value` as a default argument so the callable is tied to this
      # iteration's value rather than to the loop variable.
      self._testNullary(
          lambda v=value: constant_op.constant(v), expected=value)

  def testNoOp(self):
    """A `no_op` built in the test scope can be run without raising."""
    with self.cached_session():
      with self.test_scope():
        noop = control_flow_ops.no_op()
      # There is no value to compare; the test passes if run() completes.
      noop.run()

  def testConstants(self):
    """Real-valued literals of several ranks/shapes, per `numeric_types`."""
    for dtype in self.numeric_types:
      self._assertConstantsRoundTrip([
          dtype(42),
          np.array([], dtype=dtype),
          np.array([1, 2], dtype=dtype),
          np.array([7, 7, 7, 7, 7], dtype=dtype),
          np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype),
          np.array([[[1, 2], [3, 4], [5, 6]], [[10, 20], [30, 40], [50, 60]]],
                   dtype=dtype),
          np.array([[[]], [[]]], dtype=dtype),
          np.array([[[[1]]]], dtype=dtype),
      ])

  def testComplexConstants(self):
    """Literals with non-zero imaginary parts, per `complex_types`."""
    for dtype in self.complex_types:
      self._assertConstantsRoundTrip([
          dtype(42 + 3j),
          np.array([], dtype=dtype),
          np.ones([50], dtype=dtype) * (3 + 4j),
          np.array([1j, 2 + 1j], dtype=dtype),
          np.array([[1, 2j, 7j], [4, 5, 6]], dtype=dtype),
          np.array([[[1, 2], [3, 4 + 6j], [5, 6]],
                    [[10 + 7j, 20], [30, 40], [50, 60]]],
                   dtype=dtype),
          np.array([[[]], [[]]], dtype=dtype),
          np.array([[[[1 + 3j]]]], dtype=dtype),
      ])
77
78
def _main():
  """Script entry point: hands control to googletest's main runner."""
  googletest.main()


if __name__ == "__main__":
  _main()
81