1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for misc module.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from tensorflow.python.autograph.utils import misc 22from tensorflow.python.eager import def_function 23from tensorflow.python.framework import constant_op 24from tensorflow.python.framework import test_util 25from tensorflow.python.framework.constant_op import constant 26from tensorflow.python.ops.variables import Variable 27from tensorflow.python.platform import test 28 29 30class MiscTest(test.TestCase): 31 32 @test_util.run_deprecated_v1 33 def test_alias_single_tensor(self): 34 a = constant(1) 35 36 new_a = misc.alias_tensors(a) 37 self.assertFalse(new_a is a) 38 with self.cached_session() as sess: 39 self.assertEqual(1, self.evaluate(new_a)) 40 41 @test_util.run_deprecated_v1 42 def test_alias_tensors(self): 43 a = constant(1) 44 v = Variable(2) 45 s = 'a' 46 l = [1, 2, 3] 47 48 new_a, new_v, new_s, new_l = misc.alias_tensors(a, v, s, l) 49 50 self.assertFalse(new_a is a) 51 self.assertTrue(new_v is v) 52 self.assertTrue(new_s is s) 53 self.assertTrue(new_l is l) 54 with self.cached_session() as sess: 55 self.assertEqual(1, self.evaluate(new_a)) 56 57 def test_get_range_len(self): 58 get_range_as_graph = def_function.function(misc.get_range_len) 59 test_range = [(i, constant_op.constant(i)) for i in range(-3, 3)] 60 results = [] 61 for i, ti in test_range: 62 for j, tj in test_range: 63 for k, tk in test_range: 64 if k == 0: 65 continue 66 results.append(((i, j, k), get_range_as_graph(ti, tj, tk))) 67 68 for (i, j, k), result_tensor in results: 69 self.assertEqual( 70 len(list(range(i, j, k))), self.evaluate(result_tensor)) 71 72 73if __name__ == '__main__': 74 test.main() 75