• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.ops.session_ops."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.framework import constant_op
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import ops
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import math_ops
25from tensorflow.python.ops import session_ops
26from tensorflow.python.ops import state_ops
27from tensorflow.python.ops import variables
28from tensorflow.python.platform import test
29
30
class SessionOpsTest(test.TestCase):
  """Tests for keeping tensors alive across `Session.run` calls via handles.

  Each test fetches a handle produced by `session_ops.get_session_handle`
  and then reads the stored tensor back. It does this in one of three ways:
  through the handle object's `eval()` method, by feeding the handle into a
  placeholder created by `session_ops.get_session_tensor`, or by feeding it
  directly into an ordinary graph tensor.
  """

  def testHandleBasic(self):
    """A fetched handle can be fed back in to compute on its tensor."""
    with self.test_session() as sess:
      # Return a handle.
      a = constant_op.constant(10)
      b = constant_op.constant(5)
      c = math_ops.multiply(a, b)
      h = session_ops.get_session_handle(c)
      h = sess.run(h)

      # Feed a tensor handle.
      f, x = session_ops.get_session_tensor(h.handle, dtypes.int32)
      y = math_ops.multiply(x, 10)
      self.assertEqual(500, sess.run(y, feed_dict={f: h.handle}))

  def testHandleEval(self):
    """`eval()` on a fetched handle returns the stored tensor's value."""
    with self.test_session() as sess:
      # Return a handle.
      a = constant_op.constant(10)
      b = constant_op.constant(5)
      c = math_ops.multiply(a, b)
      h = session_ops.get_session_handle(c)
      h = sess.run(h)

      # Get the tensor from its handle.
      self.assertEqual(50, h.eval())

  def testHandleAndValue(self):
    """A handle and an ordinary value can be fetched in the same run."""
    with self.test_session() as sess:
      # Return a handle and a value.
      a = constant_op.constant(10)
      b = constant_op.constant(5)
      c = math_ops.multiply(a, b)
      h = session_ops.get_session_handle(c)
      v = math_ops.multiply(a, c)
      h, v = sess.run([h, v])

      self.assertEqual(50, h.eval())
      self.assertEqual(500, v)

  def testHandleCond(self):
    """A fetched predicate can pick which graph consumes a fetched handle."""
    with self.test_session() as sess:
      # Return a handle and a value.
      a = constant_op.constant(10)
      b = constant_op.constant(5)
      p = math_ops.less(a, b)
      c = math_ops.multiply(a, b)
      h = session_ops.get_session_handle(c)
      p, h = sess.run([p, h])

      # Run by feeding a tensor handle. The branch is chosen in Python from
      # the fetched predicate; 10 < 5 is False, so the `* 100` op is built.
      f, x = session_ops.get_session_tensor(h.handle, dtypes.int32)
      if p:
        y = math_ops.multiply(x, 10)
      else:
        y = math_ops.multiply(x, 100)
      result = sess.run(y, feed_dict={f: h.handle})

      self.assertEqual(5000, result)

  def testHandleForLoop(self):
    """Repeatedly feeding a handle back in increments the stored tensor."""
    with self.test_session() as sess:
      # Initialize a handle.
      a = constant_op.constant(0)
      h = session_ops.get_session_handle(a)
      h = sess.run(h)

      # Do some computation.
      f, x = session_ops.get_session_tensor(h.handle, dtypes.int32)
      # Must define the loop body outside the loop.
      h_x = session_ops.get_session_handle(math_ops.add(x, 1))
      for _ in range(100):
        # This exercises garbage collection: rebinding `h` drops the
        # previous handle on every iteration.
        h = sess.run(h_x, feed_dict={f: h.handle})

      self.assertEqual(100, h.eval())

  def testHandleWhileLoop(self):
    """A Python while-loop can be driven by a fetched predicate and handle."""
    with self.test_session() as sess:
      # Initialize a handle.
      a = constant_op.constant(0)
      h = session_ops.get_session_handle(a)
      h = sess.run(h)

      # Do some computation.
      f, x = session_ops.get_session_tensor(h.handle, dtypes.int32)
      b = constant_op.constant(100)
      p = math_ops.less(x, b)
      # Must define the loop body outside the loop.
      h_x = session_ops.get_session_handle(math_ops.add(x, 1))
      while True:
        # `p` tests the value being fed in while `h_x` holds that value plus
        # one, so the loop stops after feeding 100 and leaves 101 behind.
        rp, h = sess.run([p, h_x], feed_dict={f: h.handle})
        if not rp:
          break

      self.assertEqual(101, h.eval())

  def testHandleMover(self):
    """One placeholder accepts handles to tensors built under other devices."""
    with self.test_session() as sess:
      # Return a handle.
      a = constant_op.constant(10)
      b = constant_op.constant(5)
      c = math_ops.multiply(a, b)
      h = session_ops.get_session_handle(c)
      h = sess.run(h)

      # Feed a tensor handle.
      f, x = session_ops.get_session_tensor(h.handle, dtypes.int32)
      y = math_ops.multiply(x, 10)
      self.assertEqual(500, sess.run(y, feed_dict={f: h.handle}))

      # Feed another tensor handle, this one produced under the GPU device
      # scope. NOTE(review): presumably gpu_device_name() is empty (no
      # placement constraint) on GPU-less machines -- confirm.
      with ops.device(test.gpu_device_name()):
        a = constant_op.constant(10)
        h = session_ops.get_session_handle(a)
        h = sess.run(h)
        self.assertEqual(100, sess.run(y, feed_dict={f: h.handle}))

  def testHandleDelete(self):
    """A fetched handle can be deleted explicitly."""
    with self.test_session() as sess:
      # Return a handle.
      a = constant_op.constant(10)
      b = constant_op.constant(5)
      c = math_ops.multiply(a, b)
      h = session_ops.get_session_handle(c)
      sess.run(h).delete()

  def testHandleDeleteRaw(self):
    """A stored tensor can be deleted through its raw handle value."""
    with self.test_session() as sess:
      # Return a handle.
      a = constant_op.constant(10)
      b = constant_op.constant(5)
      c = math_ops.multiply(a, b)
      h = session_ops.get_session_handle(c)
      h = sess.run(h)

      # Delete using a raw tensor handle.
      raw_h = h.get_raw_handle()
      f, x = session_ops.delete_session_tensor(raw_h)
      sess.run(x, feed_dict={f: raw_h})

  def testMultiDevices(self):
    """Handles produced under different devices can feed a single op."""
    with self.test_session() as sess:
      with ops.device(test.gpu_device_name()):
        a = constant_op.constant(1.0)
        a_handle = sess.run(session_ops.get_session_handle(a))
      with ops.device("/cpu:0"):
        b = constant_op.constant(2.0)
        b_handle = sess.run(session_ops.get_session_handle(b))

      a_p, a_t = session_ops.get_session_tensor(a_handle.handle, dtypes.float32)
      b_p, b_t = session_ops.get_session_tensor(b_handle.handle, dtypes.float32)
      c = math_ops.add(a_t, b_t)
      c_handle = sess.run(
          session_ops.get_session_handle(c),
          feed_dict={a_p: a_handle.handle,
                     b_p: b_handle.handle})
      self.assertEqual(3.0, c_handle.eval())

  def testHandleGC(self):
    """Handles keep working while being replaced across CPU/GPU placements."""
    with self.test_session() as sess:
      # initial values live on CPU
      with ops.device("/cpu:0"):
        one = constant_op.constant(1, dtype=dtypes.float32)
        one_handle = sess.run(session_ops.get_session_handle(one))
        x_handle = sess.run(session_ops.get_session_handle(one))

      # addition lives on GPU
      with ops.device(test.gpu_device_name()):
        add_h1, add_t1 = session_ops.get_session_tensor(one_handle.handle,
                                                        dtypes.float32)
        add_h2, add_t2 = session_ops.get_session_tensor(x_handle.handle,
                                                        dtypes.float32)
        add_op = math_ops.add(add_t1, add_t2)
        add_output = session_ops.get_session_handle(add_op)

      # add 1 to tensor 20 times; each rebinding of `x_handle` drops the
      # previous handle, which exercises garbage collection.
      for _ in range(20):
        x_handle = sess.run(
            add_output,
            feed_dict={add_h1: one_handle.handle,
                       add_h2: x_handle.handle})

      # x started at 1.0 and had 1.0 added to it 20 times.
      self.assertEqual(21.0, x_handle.eval())

  def testHandlePlacement(self):
    """Handles produced without explicit device scopes can feed one op."""
    with self.test_session() as sess:
      a = constant_op.constant(1.0)
      a_handle_op = session_ops.get_session_handle(a)
      b = constant_op.constant(2.0)
      b_handle_op = session_ops.get_session_handle(b)

      a_handle = sess.run(a_handle_op)
      b_handle = sess.run(b_handle_op)

      a_p, a_t = session_ops.get_session_tensor(a_handle.handle, dtypes.float32)
      b_p, b_t = session_ops.get_session_tensor(b_handle.handle, dtypes.float32)

      c = math_ops.add(a_t, b_t)
      c_handle = sess.run(
          session_ops.get_session_handle(c),
          feed_dict={a_p: a_handle.handle,
                     b_p: b_handle.handle})
      self.assertEqual(3.0, c_handle.eval())

  def testFeedOneHandleDirectly(self):
    """A handle can be fed straight into the tensor it was taken from."""
    with self.test_session() as sess:
      a = constant_op.constant(10.0)
      b = constant_op.constant(5.0)
      c = math_ops.multiply(a, b)
      d = math_ops.multiply(c, c)

      h_c = sess.run(session_ops.get_session_handle(c))

      self.assertAllClose(2500.0, sess.run(d, feed_dict={c: h_c}))

  def testDirectHandleFeedOverlappingWithFetches(self):
    """Direct handle feeds work when fed tensors are also being fetched."""
    with self.test_session() as sess:
      a = constant_op.constant(10.0)
      b = constant_op.constant(5.0)
      c = math_ops.multiply(a, b)
      h_c = sess.run(session_ops.get_session_handle(c))
      d = array_ops.identity(c)

      # Fetch the very tensor that is fed by a handle.
      c_val = sess.run(c, feed_dict={c: h_c})
      self.assertAllClose(50.0, c_val)

      # Fetch a downstream tensor of the handle-fed tensor.
      d_val = sess.run(d, feed_dict={c: h_c})
      self.assertAllClose(50.0, d_val)

      # Mix a handle feed with a plain value feed, in both orders.
      c_val, d_val = sess.run([c, d], feed_dict={c: h_c, d: 60.0})
      self.assertAllClose(50.0, c_val)
      self.assertAllClose(60.0, d_val)

      c_val, d_val = sess.run([c, d], feed_dict={c: 60.0, d: h_c})
      self.assertAllClose(60.0, c_val)
      self.assertAllClose(50.0, d_val)

      # Feed the same handle into two tensors at once.
      c_val, d_val = sess.run([c, d], feed_dict={c: h_c, d: h_c})
      self.assertAllClose(50.0, c_val)
      self.assertAllClose(50.0, d_val)

  def testFeedTwoHandlesDirectly(self):
    """Two handles can be fed directly, and swapping them swaps the inputs."""
    with self.test_session() as sess:
      a = constant_op.constant(10.0)
      b = constant_op.constant(5.0)
      c = math_ops.multiply(a, b)
      # `divide` replaces the deprecated `div`, matching the TF 1.0 op names
      # (`multiply`, `subtract`) used throughout this class.
      d = math_ops.divide(a, b)
      e = math_ops.subtract(c, d)

      h_c = sess.run(session_ops.get_session_handle(c))
      h_d = sess.run(session_ops.get_session_handle(d))

      # c = 50 and d = 2, so e = 48; swapping the feeds negates e.
      self.assertAllClose(48.0, sess.run(e, feed_dict={c: h_c, d: h_d}))
      self.assertAllClose(-48.0, sess.run(e, feed_dict={c: h_d, d: h_c}))

  def testFeedHandleToVariableDirectly(self):
    """A handle to a variable's read value can be fed in place of the variable."""
    with self.test_session() as sess:
      a = variables.Variable(12.0)
      inc_a = state_ops.assign_add(a, 2.0)
      b = math_ops.add(a, 5.0)
      sess.run(a.initializer)

      h_a_read = sess.run(session_ops.get_session_handle(a.read_value()))
      self.assertAllClose(12.0, sess.run(a))

      self.assertAllClose(17.0, sess.run(b, feed_dict={a: h_a_read}))
      sess.run(inc_a)
      # The handle observes the increment (19, not 17). This suggests the
      # stored tensor aliases the variable's buffer rather than holding a
      # snapshot -- NOTE(review): confirm this is the intended contract.
      self.assertAllClose(19.0, sess.run(b, feed_dict={a: h_a_read}))
299
300
if __name__ == "__main__":
  # Hand off to TensorFlow's test entry point when run as a script.
  test.main()
303