# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.session_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import session_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class SessionOpsTest(test.TestCase):
  """Tests for fetching session tensor handles and feeding them back in."""

  def testHandleBasic(self):
    """A handle fetched by one run can be fed to a later run."""
    with self.test_session() as sess:
      # Return a handle.
      a = constant_op.constant(10)
      b = constant_op.constant(5)
      c = math_ops.multiply(a, b)
      h = session_ops.get_session_handle(c)
      h = sess.run(h)

      # Feed a tensor handle.
      f, x = session_ops.get_session_tensor(h.handle, dtypes.int32)
      y = math_ops.multiply(x, 10)
      self.assertEqual(500, sess.run(y, feed_dict={f: h.handle}))

  def testHandleEval(self):
    """`eval()` on a fetched handle yields the tensor value it refers to."""
    with self.test_session() as sess:
      # Return a handle.
      a = constant_op.constant(10)
      b = constant_op.constant(5)
      c = math_ops.multiply(a, b)
      h = session_ops.get_session_handle(c)
      h = sess.run(h)

      # Get the tensor from its handle.
      self.assertEqual(50, h.eval())

  def testHandleAndValue(self):
    """A single run can fetch both a handle and a plain value."""
    with self.test_session() as sess:
      # Return a handle and a value.
      a = constant_op.constant(10)
      b = constant_op.constant(5)
      c = math_ops.multiply(a, b)
      h = session_ops.get_session_handle(c)
      v = math_ops.multiply(a, c)
      h, v = sess.run([h, v])

      self.assertEqual(50, h.eval())
      self.assertEqual(500, v)

  def testHandleCond(self):
    """A fetched predicate chooses which computation consumes the handle."""
    with self.test_session() as sess:
      # Fetch a predicate value and a handle in the same run.
      a = constant_op.constant(10)
      b = constant_op.constant(5)
      p = math_ops.less(a, b)
      c = math_ops.multiply(a, b)
      h = session_ops.get_session_handle(c)
      p, h = sess.run([p, h])

      # Run by feeding a tensor handle. p is False (10 < 5 fails), so the
      # `* 100` branch is built: 50 * 100 == 5000.
      f, x = session_ops.get_session_tensor(h.handle, dtypes.int32)
      if p:
        y = math_ops.multiply(x, 10)
      else:
        y = math_ops.multiply(x, 100)
      result = sess.run(y, feed_dict={f: h.handle})

      self.assertEqual(5000, result)

  def testHandleForLoop(self):
    """Feeds each iteration's output handle into the next iteration."""
    with self.test_session() as sess:
      # Initialize a handle.
      a = constant_op.constant(0)
      h = session_ops.get_session_handle(a)
      h = sess.run(h)

      # Do some computation.
      f, x = session_ops.get_session_tensor(h.handle, dtypes.int32)
      # Must define the loop body outside the loop.
      h_x = session_ops.get_session_handle(math_ops.add(x, 1))
      for _ in range(100):
        # This exercises garbage collection: rebinding `h` drops the
        # previous iteration's handle.
        h = sess.run(h_x, feed_dict={f: h.handle})

      self.assertEqual(100, h.eval())

  def testHandleWhileLoop(self):
    """Loops on a fetched predicate while threading a handle through."""
    with self.test_session() as sess:
      # Initialize a handle.
      a = constant_op.constant(0)
      h = session_ops.get_session_handle(a)
      h = sess.run(h)

      # Do some computation.
      f, x = session_ops.get_session_tensor(h.handle, dtypes.int32)
      b = constant_op.constant(100)
      p = math_ops.less(x, b)
      # Must define the loop body outside the loop.
      h_x = session_ops.get_session_handle(math_ops.add(x, 1))
      while True:
        # The predicate and the incremented handle come from the same run.
        # When x == 100 the predicate is False, but h already holds 101.
        rp, h = sess.run([p, h_x], feed_dict={f: h.handle})
        if not rp:
          break

      self.assertEqual(101, h.eval())

  def testHandleMover(self):
    """Handles produced under different device scopes feed one placeholder."""
    with self.test_session() as sess:
      # Return a handle.
      a = constant_op.constant(10)
      b = constant_op.constant(5)
      c = math_ops.multiply(a, b)
      h = session_ops.get_session_handle(c)
      h = sess.run(h)

      # Feed a tensor handle.
      f, x = session_ops.get_session_tensor(h.handle, dtypes.int32)
      y = math_ops.multiply(x, 10)
      self.assertEqual(500, sess.run(y, feed_dict={f: h.handle}))

      # Feed another tensor handle, this one created under the GPU scope.
      with ops.device(test.gpu_device_name()):
        a = constant_op.constant(10)
        h = session_ops.get_session_handle(a)
        h = sess.run(h)
        self.assertEqual(100, sess.run(y, feed_dict={f: h.handle}))

  def testHandleDelete(self):
    """Deleting a freshly fetched handle succeeds."""
    with self.test_session() as sess:
      # Return a handle.
      a = constant_op.constant(10)
      b = constant_op.constant(5)
      c = math_ops.multiply(a, b)
      h = session_ops.get_session_handle(c)
      sess.run(h).delete()

  def testHandleDeleteRaw(self):
    """A handle can be deleted through its raw handle and delete op."""
    with self.test_session() as sess:
      # Return a handle.
      a = constant_op.constant(10)
      b = constant_op.constant(5)
      c = math_ops.multiply(a, b)
      h = session_ops.get_session_handle(c)
      h = sess.run(h)

      # Delete using a raw tensor handle.
      raw_h = h.get_raw_handle()
      f, x = session_ops.delete_session_tensor(raw_h)
      sess.run(x, feed_dict={f: raw_h})

  def testMultiDevices(self):
    """Handles created under GPU and CPU scopes combine in one computation."""
    with self.test_session() as sess:
      with ops.device(test.gpu_device_name()):
        a = constant_op.constant(1.0)
        a_handle = sess.run(session_ops.get_session_handle(a))
      with ops.device("/cpu:0"):
        b = constant_op.constant(2.0)
        b_handle = sess.run(session_ops.get_session_handle(b))

      a_p, a_t = session_ops.get_session_tensor(a_handle.handle, dtypes.float32)
      b_p, b_t = session_ops.get_session_tensor(b_handle.handle, dtypes.float32)
      c = math_ops.add(a_t, b_t)
      c_handle = sess.run(
          session_ops.get_session_handle(c),
          feed_dict={a_p: a_handle.handle,
                     b_p: b_handle.handle})
      self.assertEqual(3.0, c_handle.eval())

  def testHandleGC(self):
    """Repeatedly replaces a handle whose producer lives on another device."""
    with self.test_session() as sess:
      # initial values live on CPU
      with ops.device("/cpu:0"):
        one = constant_op.constant(1, dtype=dtypes.float32)
        one_handle = sess.run(session_ops.get_session_handle(one))
        x_handle = sess.run(session_ops.get_session_handle(one))

      # addition lives on GPU
      with ops.device(test.gpu_device_name()):
        add_h1, add_t1 = session_ops.get_session_tensor(one_handle.handle,
                                                        dtypes.float32)
        add_h2, add_t2 = session_ops.get_session_tensor(x_handle.handle,
                                                        dtypes.float32)
        add_op = math_ops.add(add_t1, add_t2)
        add_output = session_ops.get_session_handle(add_op)

      # add 1 to tensor 20 times; rebinding x_handle releases the previous
      # handle on every iteration.
      for _ in range(20):
        x_handle = sess.run(
            add_output,
            feed_dict={add_h1: one_handle.handle,
                       add_h2: x_handle.handle})

      # x started at 1.0 and gained 1.0 per iteration. Checking the final
      # value ensures no intermediate handle was lost or mis-fed.
      self.assertAllClose(21.0, x_handle.eval())

  def testHandlePlacement(self):
    """Handles from ops without explicit device scopes combine correctly."""
    with self.test_session() as sess:
      a = constant_op.constant(1.0)
      a_handle_op = session_ops.get_session_handle(a)
      b = constant_op.constant(2.0)
      b_handle_op = session_ops.get_session_handle(b)

      a_handle = sess.run(a_handle_op)
      b_handle = sess.run(b_handle_op)

      a_p, a_t = session_ops.get_session_tensor(a_handle.handle, dtypes.float32)
      b_p, b_t = session_ops.get_session_tensor(b_handle.handle, dtypes.float32)

      c = math_ops.add(a_t, b_t)
      c_handle = sess.run(
          session_ops.get_session_handle(c),
          feed_dict={a_p: a_handle.handle,
                     b_p: b_handle.handle})
      self.assertEqual(3.0, c_handle.eval())

  def testFeedOneHandleDirectly(self):
    """A `TensorHandle` object can be fed directly as a graph tensor's value."""
    with self.test_session() as sess:
      a = constant_op.constant(10.0)
      b = constant_op.constant(5.0)
      c = math_ops.multiply(a, b)
      d = math_ops.multiply(c, c)

      h_c = sess.run(session_ops.get_session_handle(c))

      self.assertAllClose(2500.0, sess.run(d, feed_dict={c: h_c}))

  def testDirectHandleFeedOverlappingWithFetches(self):
    """Fetching a tensor that is fed (directly or via handle) yields the fed value."""
    with self.test_session() as sess:
      a = constant_op.constant(10.0)
      b = constant_op.constant(5.0)
      c = math_ops.multiply(a, b)
      h_c = sess.run(session_ops.get_session_handle(c))
      d = array_ops.identity(c)

      c_val = sess.run(c, feed_dict={c: h_c})
      self.assertAllClose(50.0, c_val)

      d_val = sess.run(d, feed_dict={c: h_c})
      self.assertAllClose(50.0, d_val)

      c_val, d_val = sess.run([c, d], feed_dict={c: h_c, d: 60.0})
      self.assertAllClose(50.0, c_val)
      self.assertAllClose(60.0, d_val)

      c_val, d_val = sess.run([c, d], feed_dict={c: 60.0, d: h_c})
      self.assertAllClose(60.0, c_val)
      self.assertAllClose(50.0, d_val)

      c_val, d_val = sess.run([c, d], feed_dict={c: h_c, d: h_c})
      self.assertAllClose(50.0, c_val)
      self.assertAllClose(50.0, d_val)

  def testFeedTwoHandlesDirectly(self):
    """Two handles can be fed directly, including swapped between tensors."""
    with self.test_session() as sess:
      a = constant_op.constant(10.0)
      b = constant_op.constant(5.0)
      c = math_ops.multiply(a, b)  # 50.0
      d = math_ops.div(a, b)  # 2.0
      e = math_ops.subtract(c, d)

      h_c = sess.run(session_ops.get_session_handle(c))
      h_d = sess.run(session_ops.get_session_handle(d))

      self.assertAllClose(48.0, sess.run(e, feed_dict={c: h_c, d: h_d}))
      # Swapping the handles swaps the operands: 2.0 - 50.0.
      self.assertAllClose(-48.0, sess.run(e, feed_dict={c: h_d, d: h_c}))

  def testFeedHandleToVariableDirectly(self):
    """A handle of a variable read can be fed in place of the variable."""
    with self.test_session() as sess:
      a = variables.Variable(12.0)
      inc_a = state_ops.assign_add(a, 2.0)
      b = math_ops.add(a, 5.0)
      sess.run(a.initializer)

      h_a_read = sess.run(session_ops.get_session_handle(a.read_value()))
      self.assertAllClose(12.0, sess.run(a))

      self.assertAllClose(17.0, sess.run(b, feed_dict={a: h_a_read}))
      sess.run(inc_a)
      # The test expects the same handle to observe the later assign_add
      # (14.0 + 5.0). NOTE(review): this presumably relies on the stored
      # tensor sharing the variable's buffer rather than holding a copy;
      # confirm this aliasing is intended behavior.
      self.assertAllClose(19.0, sess.run(b, feed_dict={a: h_a_read}))


if __name__ == "__main__":
  test.main()