# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional test for slot_creator."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import slot_creator


class SlotCreatorTest(test.TestCase):
  """Tests for `slot_creator.create_slot` and `slot_creator.create_zeros_slot`.

  Each test builds a primary value (a variable or a tensor), creates a slot
  for it, initializes all global variables, and then checks the slot's op
  name, shape, dtype and value.
  """

  @test_util.run_v1_only("b/120545219")
  def testCreateSlotFromVariable(self):
    """A slot initialized from a variable's value copies that value."""
    with self.cached_session():
      v = variables.Variable([1.0, 2.5], name="var")
      slot = slot_creator.create_slot(v, v.initialized_value(), name="slot")

      self.evaluate(variables.global_variables_initializer())

      # The slot's op name is the primary's name followed by the slot name.
      self.assertEqual("var/slot", slot.op.name)
      self.assertEqual([2], slot.get_shape().as_list())
      self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
      self.assertAllEqual([1.0, 2.5], self.evaluate(slot))

  @test_util.run_deprecated_v1
  def testCreateSlotFromTensor(self):
    """A slot can be created for a plain tensor, initialized from a tensor."""
    with self.cached_session():
      v = constant_op.constant([1.0, 2.5], name="const")
      slot = slot_creator.create_slot(v, v * 2, name="slot")

      self.evaluate(variables.global_variables_initializer())

      self.assertEqual("const/slot", slot.op.name)
      self.assertEqual([2], slot.get_shape().as_list())
      self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
      # The slot holds the initial value it was given (v * 2), not v itself.
      self.assertAllEqual([2.0, 5.0], self.evaluate(slot))

  @test_util.run_deprecated_v1
  def testCreateZerosSlotFromVariable(self):
    """A zeros slot honors an explicit dtype that differs from the primary."""
    with self.cached_session():
      v = variables.Variable([1.0, 2.5], name="var")
      # Create the slot outside any enclosing control-dependency context.
      with ops.control_dependencies(None):
        slot = slot_creator.create_zeros_slot(
            v, name="slot", dtype=dtypes.float64)

      self.evaluate(variables.global_variables_initializer())

      self.assertEqual("var/slot", slot.op.name)
      self.assertEqual([2], slot.get_shape().as_list())
      # The primary is float32; the requested float64 dtype must win.
      self.assertEqual(dtypes.float64, slot.dtype.base_dtype)
      self.assertAllEqual([0.0, 0.0], self.evaluate(slot))

  @test_util.run_v1_only("b/120545219")
  def testCreateZerosSlotFromDynamicShapedVariable(self):
    """A zeros slot can be created for a variable whose shape is dynamic."""
    with self.cached_session():
      dyn_shape = constant_op.constant([2], dtype=dtypes.int32)
      # Feeding the shape through placeholder_with_default(shape=[None]) is
      # intended to hide the concrete shape from static shape inference, so
      # the variable (created with validate_shape=False) has a dynamic shape.
      dyn_shape = array_ops.placeholder_with_default(dyn_shape,
                                                     shape=[None])
      v = variable_scope.get_variable(
          "var",
          initializer=random_ops.random_uniform(dyn_shape,
                                                dtype=dtypes.float64),
          validate_shape=False)
      with ops.control_dependencies(None):
        slot = slot_creator.create_zeros_slot(
            v, name="slot", dtype=dtypes.float64)

      self.evaluate(variables.global_variables_initializer())

      self.assertEqual("var/slot", slot.op.name)
      # Check the runtime shape. assertAllEqual compares element-wise rather
      # than relying on the truthiness of a numpy comparison result.
      self.assertAllEqual([2], self.evaluate(array_ops.shape(slot)))
      self.assertEqual(dtypes.float64, slot.dtype.base_dtype)
      self.assertAllEqual([0.0, 0.0], self.evaluate(slot))

  @test_util.run_deprecated_v1
  def testCreateZerosSlotFromTensor(self):
    """A zeros slot for a tensor defaults to the tensor's own dtype."""
    with self.cached_session():
      v = constant_op.constant([1.0, 2.5], name="const")
      with ops.control_dependencies(None):
        slot = slot_creator.create_zeros_slot(v, name="slot")

      self.evaluate(variables.global_variables_initializer())

      self.assertEqual("const/slot", slot.op.name)
      self.assertEqual([2], slot.get_shape().as_list())
      # No dtype was passed, so the slot is float32 like the constant.
      self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
      self.assertAllEqual([0.0, 0.0], self.evaluate(slot))

  @test_util.run_deprecated_v1
  def testCreateZerosSlotFromDynamicShapedTensor(self):
    """A zeros slot can be created for a tensor whose shape is dynamic."""
    with self.cached_session():
      v = random_ops.random_uniform([2], dtype=dtypes.float64)
      # Wrap the tensor so its static shape is declared as [None].
      v = array_ops.placeholder_with_default(v, shape=[None], name="const")
      with ops.control_dependencies(None):
        slot = slot_creator.create_zeros_slot(
            v, name="slot", dtype=dtypes.float64)

      self.evaluate(variables.global_variables_initializer())

      self.assertEqual("const/slot", slot.op.name)
      # Check the runtime shape with an element-wise comparison.
      self.assertAllEqual([2], self.evaluate(array_ops.shape(slot)))
      self.assertEqual(dtypes.float64, slot.dtype.base_dtype)
      self.assertAllEqual([0.0, 0.0], self.evaluate(slot))

  @test_util.run_v1_only("b/120545219")
  def testCreateSlotFromVariableRespectsScope(self):
    """Pins the slot's op name when created inside a variable scope."""
    # See discussion on #2740.
    with self.cached_session():
      with variable_scope.variable_scope("scope"):
        v = variables.Variable([1.0, 2.5], name="var")
        slot = slot_creator.create_slot(v, v.initialized_value(), name="slot")
        # The current behavior repeats the enclosing scope name in the
        # slot's op name; this assertion pins that behavior.
        self.assertEqual("scope/scope/var/slot", slot.op.name)


if __name__ == "__main__":
  test.main()