• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for misc module."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.autograph.utils import misc
22from tensorflow.python.eager import def_function
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import test_util
25from tensorflow.python.framework.constant_op import constant
26from tensorflow.python.ops.variables import Variable
27from tensorflow.python.platform import test
28
29
class MiscTest(test.TestCase):
  """Tests for `misc.alias_tensors` and `misc.get_range_len`."""

  @test_util.run_deprecated_v1
  def test_alias_single_tensor(self):
    """A single constant tensor is returned as a new, equal-valued object."""
    a = constant(1)

    new_a = misc.alias_tensors(a)
    # assertIsNot reports both operands on failure, unlike assertFalse(... is ...).
    self.assertIsNot(new_a, a)
    with self.cached_session():
      self.assertEqual(1, self.evaluate(new_a))

  @test_util.run_deprecated_v1
  def test_alias_tensors(self):
    """Only the plain tensor is aliased; other arguments pass through as-is."""
    a = constant(1)
    v = Variable(2)
    s = 'a'
    lst = [1, 2, 3]

    new_a, new_v, new_s, new_lst = misc.alias_tensors(a, v, s, lst)

    # The constant tensor must come back as a different object...
    self.assertIsNot(new_a, a)
    # ...while the Variable, the string and the list are expected back
    # unchanged (identical objects).
    self.assertIs(new_v, v)
    self.assertIs(new_s, s)
    self.assertIs(new_lst, lst)
    with self.cached_session():
      self.assertEqual(1, self.evaluate(new_a))

  def test_get_range_len(self):
    """`get_range_len` under `def_function.function` matches Python's range.

    Every (start, limit, delta) triple drawn from [-3, 3) is checked, with
    delta == 0 skipped because Python's `range` rejects a zero step.
    """
    get_range_as_graph = def_function.function(misc.get_range_len)
    test_range = [(i, constant_op.constant(i)) for i in range(-3, 3)]
    results = []
    for i, ti in test_range:
      for j, tj in test_range:
        for k, tk in test_range:
          if k == 0:
            continue
          results.append(((i, j, k), get_range_as_graph(ti, tj, tk)))

    for (i, j, k), result_tensor in results:
      # `range` supports len() directly; no need to materialize a list.
      self.assertEqual(
          len(range(i, j, k)),
          self.evaluate(result_tensor),
          msg='range(%d, %d, %d)' % (i, j, k))
72
if __name__ == '__main__':
  # Script entry point: hand control to TensorFlow's test runner.
  test.main()
75