• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Functional test for slot_creator."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import constant_op
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import test_util
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import random_ops
27from tensorflow.python.ops import variable_scope
28from tensorflow.python.ops import variables
29from tensorflow.python.platform import test
30from tensorflow.python.training import slot_creator
31
32
class SlotCreatorTest(test.TestCase):
  """Tests for `slot_creator.create_slot` and `slot_creator.create_zeros_slot`.

  Each test checks the slot's op name (nested under the primary's name),
  its shape, its base dtype and its initialized value.
  """

  @test_util.run_v1_only("b/120545219")
  def testCreateSlotFromVariable(self):
    """A slot built from a variable is initialized from the given value."""
    with self.cached_session():
      v = variables.Variable([1.0, 2.5], name="var")
      slot = slot_creator.create_slot(v, v.initialized_value(), name="slot")

      self.evaluate(variables.global_variables_initializer())

      self.assertEqual("var/slot", slot.op.name)
      self.assertEqual([2], slot.get_shape().as_list())
      self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
      self.assertAllEqual([1.0, 2.5], self.evaluate(slot))

  @test_util.run_deprecated_v1
  def testCreateSlotFromTensor(self):
    """A slot built from a plain tensor takes the supplied initial value."""
    with self.cached_session():
      v = constant_op.constant([1.0, 2.5], name="const")
      slot = slot_creator.create_slot(v, v * 2, name="slot")

      self.evaluate(variables.global_variables_initializer())

      self.assertEqual("const/slot", slot.op.name)
      self.assertEqual([2], slot.get_shape().as_list())
      self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
      self.assertAllEqual([2.0, 5.0], self.evaluate(slot))

  @test_util.run_deprecated_v1
  def testCreateZerosSlotFromVariable(self):
    """A zeros slot honours an explicit dtype that differs from the primary."""
    with self.cached_session():
      v = variables.Variable([1.0, 2.5], name="var")
      # control_dependencies(None) clears any enclosing control dependencies
      # for the ops created while building the slot.
      with ops.control_dependencies(None):
        slot = slot_creator.create_zeros_slot(
            v, name="slot", dtype=dtypes.float64)

      self.evaluate(variables.global_variables_initializer())

      self.assertEqual("var/slot", slot.op.name)
      self.assertEqual([2], slot.get_shape().as_list())
      self.assertEqual(dtypes.float64, slot.dtype.base_dtype)
      self.assertAllEqual([0.0, 0.0], self.evaluate(slot))

  @test_util.run_v1_only("b/120545219")
  def testCreateZerosSlotFromDynamicShapedVariable(self):
    """A zeros slot matches a variable whose shape is only known at runtime."""
    with self.cached_session():
      dyn_shape = constant_op.constant([2], dtype=dtypes.int32)
      # placeholder_with_default with shape=[None] hides the static length,
      # so the variable's shape is not known at graph-construction time.
      dyn_shape = array_ops.placeholder_with_default(dyn_shape,
                                                     shape=[None])
      v = variable_scope.get_variable(
          "var",
          initializer=random_ops.random_uniform(dyn_shape,
                                                dtype=dtypes.float64),
          validate_shape=False)
      with ops.control_dependencies(None):
        slot = slot_creator.create_zeros_slot(
            v, name="slot", dtype=dtypes.float64)

      self.evaluate(variables.global_variables_initializer())

      self.assertEqual("var/slot", slot.op.name)
      # Static shape is unknown, so check the runtime shape. assertAllEqual
      # compares element-wise instead of relying on numpy array truthiness.
      self.assertAllEqual([2], self.evaluate(array_ops.shape(slot)))
      self.assertEqual(dtypes.float64, slot.dtype.base_dtype)
      self.assertAllEqual([0.0, 0.0], self.evaluate(slot))

  @test_util.run_deprecated_v1
  def testCreateZerosSlotFromTensor(self):
    """A zeros slot from a tensor inherits its dtype when none is given."""
    with self.cached_session():
      v = constant_op.constant([1.0, 2.5], name="const")
      with ops.control_dependencies(None):
        slot = slot_creator.create_zeros_slot(v, name="slot")

      self.evaluate(variables.global_variables_initializer())

      self.assertEqual("const/slot", slot.op.name)
      self.assertEqual([2], slot.get_shape().as_list())
      self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
      self.assertAllEqual([0.0, 0.0], self.evaluate(slot))

  @test_util.run_deprecated_v1
  def testCreateZerosSlotFromDynamicShapedTensor(self):
    """A zeros slot matches a tensor whose shape is only known at runtime."""
    with self.cached_session():
      v = random_ops.random_uniform([2], dtype=dtypes.float64)
      v = array_ops.placeholder_with_default(v, shape=[None], name="const")
      with ops.control_dependencies(None):
        slot = slot_creator.create_zeros_slot(
            v, name="slot", dtype=dtypes.float64)

      self.evaluate(variables.global_variables_initializer())

      self.assertEqual("const/slot", slot.op.name)
      # Static shape is unknown, so check the runtime shape. assertAllEqual
      # compares element-wise instead of relying on numpy array truthiness.
      self.assertAllEqual([2], self.evaluate(array_ops.shape(slot)))
      self.assertEqual(dtypes.float64, slot.dtype.base_dtype)
      self.assertAllEqual([0.0, 0.0], self.evaluate(slot))

  @test_util.run_v1_only("b/120545219")
  def testCreateSlotFromVariableRespectsScope(self):
    """Pins slot naming when the primary is created inside a variable scope.

    The expected name contains the enclosing scope twice.
    """
    # See discussion on #2740.
    with self.cached_session():
      with variable_scope.variable_scope("scope"):
        v = variables.Variable([1.0, 2.5], name="var")
        slot = slot_creator.create_slot(v, v.initialized_value(), name="slot")
        self.assertEqual("scope/scope/var/slot", slot.op.name)
137
def _main():
  """Hand control to the TensorFlow test runner (`test.main`)."""
  test.main()


if __name__ == "__main__":
  _main()
140