1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for tensorflow.python.client.session.Session.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import collections 21import random 22import os 23import sys 24import threading 25import time 26import warnings 27 28import numpy as np 29import six 30from six.moves import xrange # pylint: disable=redefined-builtin 31 32from tensorflow.core.framework import attr_value_pb2 33from tensorflow.core.lib.core import error_codes_pb2 34from tensorflow.core.protobuf import config_pb2 35from tensorflow.python.client import session 36from tensorflow.python.eager import context 37from tensorflow.python.framework import common_shapes 38from tensorflow.python.framework import constant_op 39from tensorflow.python.framework import device as framework_device_lib 40from tensorflow.python.framework import dtypes 41from tensorflow.python.framework import errors 42from tensorflow.python.framework import function 43from tensorflow.python.framework import importer 44from tensorflow.python.framework import ops 45from tensorflow.python.framework import sparse_tensor 46from tensorflow.python.framework import tensor_util 47from tensorflow.python.framework import test_util 48from tensorflow.python.framework import versions 
49from tensorflow.python.ops import array_ops 50from tensorflow.python.ops import control_flow_ops 51from tensorflow.python.ops import data_flow_ops 52from tensorflow.python.ops import gen_control_flow_ops 53# Import gradients to resolve circular imports 54from tensorflow.python.ops import gradients # pylint: disable=unused-import 55from tensorflow.python.ops import gradients_impl 56from tensorflow.python.ops import math_ops 57# Import resource_variable_ops for the variables-to-tensor implicit conversion. 58from tensorflow.python.ops import resource_variable_ops # pylint: disable=unused-import 59from tensorflow.python.ops import state_ops 60from tensorflow.python.ops import variables 61from tensorflow.python.platform import googletest 62from tensorflow.python.training import server_lib 63from tensorflow.python.util import compat 64 65try: 66 import attr # pylint:disable=g-import-not-at-top 67except ImportError: 68 attr = None 69 70 71# NOTE(mrry): Dummy shape registration for ops used in the tests, since they 72# don't have C++ op registrations on which to attach C++ shape fns. 
ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape)


class SessionTest(test_util.TensorFlowTestCase):
  # Tests for session.Session: graph selection, session configuration, error
  # reporting, and the structures accepted as fetches by run() and
  # make_callable().

  def setUp(self):
    # Surface every warning on every occurrence rather than once per location.
    super(SessionTest, self).setUp()
    warnings.simplefilter('always')

  def testUseExistingGraph(self):
    # Build a 1x1 matmul in an explicit graph, then evaluate it in a session
    # bound to that graph.
    with ops.Graph().as_default() as g, ops.device('/cpu:0'):
      a = constant_op.constant(6.0, shape=[1, 1])
      b = constant_op.constant(7.0, shape=[1, 1])
      c = math_ops.matmul(a, b, name='matmul')
    with session.Session(graph=g):
      result = c.eval()
      self.assertAllEqual(result, [[42.0]])

  def testUseDefaultGraph(self):
    # Same computation, but the session is created without a graph argument
    # while the freshly built graph is still the default.
    with ops.Graph().as_default(), ops.device('/cpu:0'):
      a = constant_op.constant(6.0, shape=[1, 1])
      b = constant_op.constant(7.0, shape=[1, 1])
      c = math_ops.matmul(a, b, name='matmul')
      with session.Session():
        result = c.eval()
        self.assertAllEqual(result, [[42.0]])

  def testCreate(self):
    # Evaluate an identity op both with its input fed by name and unfed.
    with session.Session():
      inp = constant_op.constant(10.0, shape=[2, 3], name='W1')
      copy = array_ops.identity(inp)
      # Test with feed.
      # TODO(mrry): Investigate why order='F' didn't work.
      arr = np.asarray([[0, 1, 2], [3, 4, 5]], dtype=np.float32, order='C')
      copy_val = copy.eval({'W1:0': arr})
      self.assertAllEqual(arr, copy_val)
      # Test without feed.
      copy_val = copy.eval()
      self.assertAllEqual(
          np.asarray(
              [[10.0, 10.0, 10.0], [10.0, 10.0, 10.0]], dtype=np.float32),
          copy_val)

  def testManyCPUs(self):
    # Request two CPU devices and no GPUs, then check that list_devices()
    # reports exactly that.
    with session.Session(
        config=config_pb2.ConfigProto(device_count={
            'CPU': 2, 'GPU': 0
        })) as sess:
      inp = constant_op.constant(10.0, name='W1')
      self.assertAllEqual(inp.eval(), 10.0)

      num_cpu_devices = 0
      num_gpu_devices = 0
      for device in sess.list_devices():
        device_type = framework_device_lib.DeviceSpec.from_string(
            device.name).device_type
        if device_type == 'CPU':
          num_cpu_devices += 1
        elif device_type == 'GPU':
          num_gpu_devices += 1
      self.assertEqual(2, num_cpu_devices)
      self.assertEqual(0, num_gpu_devices)

  def testPerSessionThreads(self):
    # A constant evaluates in a session created with use_per_session_threads.
    with session.Session(
        config=config_pb2.ConfigProto(use_per_session_threads=True)):
      inp = constant_op.constant(10.0, name='W1')
      self.assertAllEqual(inp.eval(), 10.0)

  def testSessionInterOpThreadPool(self):
    # Grow the list of inter-op thread pools one entry at a time and run a
    # constant in a new session after each addition.
    config = config_pb2.ConfigProto()
    # First pool: added with its default field values (return value unused).
    config.session_inter_op_thread_pool.add()
    with session.Session(config=config) as s:
      inp = constant_op.constant(10.0, name='W1')
      results = s.run([inp])
      self.assertAllEqual([10.0], results)

    # Second pool: explicit thread count.
    pool = config.session_inter_op_thread_pool.add()
    pool.num_threads = 1
    with session.Session(config=config) as s:
      inp = constant_op.constant(20.0, name='W2')
      results = s.run([inp])
      self.assertAllEqual([20.0], results)

    # Third pool: named pool, selected for the run through RunOptions by its
    # index (the last entry of the list).
    pool = config.session_inter_op_thread_pool.add()
    pool.num_threads = 1
    pool.global_name = 't1'
    run_options = config_pb2.RunOptions()
    run_options.inter_op_thread_pool = (
        len(config.session_inter_op_thread_pool) - 1)
    with session.Session(config=config) as s:
      inp = constant_op.constant(30.0, name='W2')
      results = s.run([inp], options=run_options)
      self.assertAllEqual([30.0], results)

  def testErrorsReported(self):
    # Fetching a tensor name that is not in the graph raises ValueError.
    with session.Session() as s:
      constant_op.constant(10.0, name='W1')
      with self.assertRaises(ValueError):
        s.run('foo:0')

  def testErrorPayload(self):
    # Evaluating an unfed placeholder raises an op error whose `op` is the
    # placeholder's op.
    with session.Session():
      a = array_ops.placeholder(dtypes.float32)
      with self.assertRaisesOpError(lambda e: e.op == a.op):
        a.eval()

  def testErrorCodeWithNoNodeDef(self):
    # partial_run() with a bogus handle raises INVALID_ARGUMENT with neither
    # an op nor a node_def attached.
    with session.Session() as s:
      a = array_ops.placeholder(dtypes.float32, shape=[])
      b = array_ops.placeholder(dtypes.float32, shape=[])
      r1 = math_ops.add(a, b)

      def exc_predicate(e):
        return (e.op is None and e.node_def is None and
                e.error_code == error_codes_pb2.INVALID_ARGUMENT)

      with self.assertRaisesOpError(exc_predicate):
        # Run with a bogus handle.
        s.partial_run('foo', r1, feed_dict={a: 1, b: 2})

  def testErrorBasedOn(self):
    # The raised error's op keeps its chain of `_original_op` links (c -> b -> a).
    with session.Session() as sess:
      a = constant_op.constant(0.0, shape=[2, 3])
      # NOTE(mrry): The original_op is nonsense, but used here to test that the
      #   errors are reported correctly.
      with sess.graph._original_op(a.op):
        b = array_ops.identity(a, name='id')
      with sess.graph._original_op(b.op):
        c = array_ops.placeholder(dtypes.float32)

      def exc_predicate(e):
        return (e.op == c.op and e.op._original_op == b.op and
                e.op._original_op._original_op == a.op)

      with self.assertRaisesOpError(exc_predicate):
        c.eval()

  def testFetchNone(self):
    # None is rejected as a fetch, whether bare or nested in a list or dict.
    with session.Session() as s:
      a = constant_op.constant(1.0)
      with self.assertRaises(TypeError):
        s.run(None)
      with self.assertRaises(TypeError):
        s.run([None])
      with self.assertRaises(TypeError):
        s.run({'b': None})
      with self.assertRaises(TypeError):
        s.run({'a': a, 'b': None})

  def testFetchSingleton(self):
    # A single tensor fetch yields its value; a single op fetch yields None.
    # The same holds for callables built with make_callable().
    with session.Session() as sess:
      a = constant_op.constant(42.0)
      res = sess.run(a)
      self.assertEqual(42.0, res)
      res = sess.run(a.op)  # An op, not a tensor.
      self.assertIsNone(res)
      tensor_runner = sess.make_callable(a)
      res = tensor_runner()
      self.assertEqual(42.0, res)
      op_runner = sess.make_callable(a.op)
      res = op_runner()
      self.assertIsNone(res)

  def testFetchSingletonByName(self):
    # A tensor can be fetched by its string name.
    with session.Session() as sess:
      a = constant_op.constant(42.0)
      res = sess.run(a.name)
      self.assertEqual(42.0, res)
      res = sess.run(a.op)  # An op, not a tensor.
      self.assertIsNone(res)

  def testFetchList(self):
    # A list fetch returns a list; op entries come back as None.
    with session.Session() as sess:
      a = constant_op.constant(42.0)
      b = control_flow_ops.no_op()  # An op, not a tensor.
      c = constant_op.constant(44.0)
      v = variables.Variable([54.0])
      assign = v.assign([63.0])
      res = sess.run([a, b, c, a.name, assign.op])
      self.assertIsInstance(res, list)
      self.assertEqual([42.0, None, 44.0, 42.0, None], res)
      list_runner = sess.make_callable([a, b, c, a.name, assign.op])
      res = list_runner()
      self.assertIsInstance(res, list)
      self.assertEqual([42.0, None, 44.0, 42.0, None], res)

  def testFetchTuple(self):
    # A tuple fetch returns a tuple.
    with session.Session() as sess:
      a = constant_op.constant(42.0)
      b = control_flow_ops.no_op()  # An op, not a tensor.
      c = constant_op.constant(44.0)
      res = sess.run((a, b, c, a.name))
      self.assertIsInstance(res, tuple)
      self.assertEqual((42.0, None, 44.0, 42.0), res)
      tuple_runner = sess.make_callable((a, b, c, a.name))
      res = tuple_runner()
      self.assertIsInstance(res, tuple)
      self.assertEqual((42.0, None, 44.0, 42.0), res)

  def testFetchNamedTuple(self):
    # A namedtuple fetch returns an instance of the same namedtuple type.
    # pylint: disable=invalid-name
    ABC = collections.namedtuple('ABC', ['a', 'b', 'c'])
    # pylint: enable=invalid-name
    with session.Session() as sess:
      a = constant_op.constant(42.0)
      b = control_flow_ops.no_op()  # An op, not a tensor.
      c = constant_op.constant(44.0)
      res = sess.run(ABC(a, b, c))
      self.assertIsInstance(res, ABC)
      self.assertEqual(42.0, res.a)
      self.assertIsNone(res.b)
      self.assertEqual(44.0, res.c)
      namedtuple_runner = sess.make_callable(ABC(a, b, c))
      res = namedtuple_runner()
      self.assertIsInstance(res, ABC)
      self.assertEqual(42.0, res.a)
      self.assertIsNone(res.b)
      self.assertEqual(44.0, res.c)

  def testFetchDict(self):
    # A dict fetch returns a dict with the same keys.
    with session.Session() as sess:
      a = constant_op.constant(42.0)
      b = control_flow_ops.no_op()  # An op, not a tensor.
      c = constant_op.constant(44.0)
      res = sess.run({'a': a, 'b': b, 'c': c})
      self.assertIsInstance(res, dict)
      self.assertEqual(42.0, res['a'])
      self.assertIsNone(res['b'])
      self.assertEqual(44.0, res['c'])

  def testFetchOrderedDict(self):
    # An OrderedDict fetch returns an OrderedDict with key order preserved.
    with session.Session() as sess:
      a = constant_op.constant(42.0)
      b = control_flow_ops.no_op()  # An op, not a tensor.
      c = constant_op.constant(44.0)
      res = sess.run(collections.OrderedDict([(3, a), (2, b), (1, c)]))
      self.assertIsInstance(res, collections.OrderedDict)
      self.assertEqual([3, 2, 1], list(res.keys()))
      self.assertEqual(42.0, res[3])
      self.assertIsNone(res[2])
      self.assertEqual(44.0, res[1])

  @test_util.run_v1_only('b/120545219')
  def testFetchAttrs(self):
    # An attrs-decorated instance can be fetched, and one of its fields can be
    # used as a feed key.
    if attr is None:
      self.skipTest('attr module is unavailable.')

    @attr.s
    class SampleAttr(object):
      field1 = attr.ib()
      field2 = attr.ib()

    val1 = np.array([1.2, 3.4, 5.6])
    val2 = np.array([[1, 2], [4, 3]])
    val3 = np.array([10, 20, 30])

    t1 = constant_op.constant(val1)
    t2 = constant_op.constant(val2)

    sample = SampleAttr(t1, t2)
    with session.Session() as sess:
      result = sess.run(sample)
      self.assertIsInstance(result, SampleAttr)
      self.assertAllEqual(val1, result.field1)
      self.assertAllEqual(val2, result.field2)

      # Feeding field1 overrides its constant value in the fetched result.
      result = sess.run(sample, feed_dict={sample.field1: val3})
      self.assertIsInstance(result, SampleAttr)
      self.assertAllEqual(val3, result.field1)
      self.assertAllEqual(val2, result.field2)

  @test_util.run_v1_only('b/120545219')
  def testFetchNestedAttrs(self):
    # attrs instances nested inside each other, and inside dicts and lists,
    # come back with the same nesting.
    if attr is None:
      self.skipTest('attr module is unavailable.')

    @attr.s
    class SampleAttr(object):
      field0 = attr.ib()
      field1 = attr.ib()

    v1 = 10
    v2 = 20
    v3 = np.float32(1.2)
    v4 = np.float32(3.4)
    v5 = np.float64(100.001)
    v6 = np.float64(-23.451)
    arr1 = np.array([1.2, 6.7, 3.4])
    arr2 = np.array([7, 11, 3])
    sample = SampleAttr(
        SampleAttr(
            SampleAttr(constant_op.constant(v1), constant_op.constant(v2)),
            SampleAttr(constant_op.constant(arr1), constant_op.constant(arr2))),
        {'A': SampleAttr(constant_op.constant(v3), constant_op.constant(v4)),
         'B': [SampleAttr(constant_op.constant(v5), constant_op.constant(v6))]})

    with session.Session() as sess:
      result = sess.run(sample)
      self.assertIsInstance(result, SampleAttr)
      self.assertIsInstance(result.field0, SampleAttr)
      self.assertIsInstance(result.field0.field0, SampleAttr)
      self.assertIsInstance(result.field0.field1, SampleAttr)
      self.assertIsInstance(result.field0.field1.field0, np.ndarray)
      self.assertAllEqual(arr1, result.field0.field1.field0)
      self.assertIsInstance(result.field0.field1.field1, np.ndarray)
      self.assertAllEqual(arr2, result.field0.field1.field1)
      self.assertIsInstance(result.field1, dict)
      self.assertIn('A', result.field1)
      self.assertIn('B', result.field1)
      self.assertIsInstance(result.field1['A'], SampleAttr)
      self.assertAllEqual(
          [v3, v4],
          [result.field1['A'].field0, result.field1['A'].field1])
      self.assertIsInstance(result.field1['B'], list)
      self.assertEqual(1, len(result.field1['B']))
      self.assertIsInstance(result.field1['B'][0], SampleAttr)
      self.assertAllEqual(
          [v5, v6],
          [result.field1['B'][0].field0, result.field1['B'][0].field1])

  def testFetchNestingEmptyOneLevel(self):
    # Empty containers inside a list fetch come back as empty containers of
    # the same type, with or without a real tensor alongside them.
    with session.Session() as sess:
      a_val = 11.0
      a = constant_op.constant(a_val)

      res = sess.run([[], tuple(), {}])
      self.assertIsInstance(res, list)
      self.assertEqual(3, len(res))
      self.assertIsInstance(res[0], list)
      self.assertEqual(0, len(res[0]))
      self.assertIsInstance(res[1], tuple)
      self.assertEqual(0, len(res[1]))
      self.assertIsInstance(res[2], dict)
      self.assertEqual(0, len(res[2]))

      res = sess.run([[], tuple(), {}, a])
      self.assertIsInstance(res, list)
      self.assertEqual(4, len(res))
      self.assertIsInstance(res[0], list)
      self.assertEqual(0, len(res[0]))
      self.assertIsInstance(res[1], tuple)
      self.assertEqual(0, len(res[1]))
      self.assertIsInstance(res[2], dict)
      self.assertEqual(0, len(res[2]))
      self.assertEqual(a_val, res[3])

def testFetchNestingOneLevel(self): 421 with session.Session() as sess: 422 # pylint: disable=invalid-name 423 ABC = collections.namedtuple('ABC', ['a', 'b', 'c']) 424 DEFG = collections.namedtuple('DEFG', ['d', 'e', 'f', 'g']) 425 # pylint: enable=invalid-name 426 a_val = 42.0 427 b_val = None 428 c_val = 44.0 429 a = constant_op.constant(a_val) 430 b = control_flow_ops.no_op() # An op, not a tensor. 431 c = constant_op.constant(c_val) 432 # List of lists, tuples, namedtuple, and dict 433 res = sess.run([[a, b, c], (a, b, c), 434 ABC(a=a, b=b, c=c), { 435 'a': a.name, 436 'c': c, 437 'b': b 438 }]) 439 self.assertTrue(isinstance(res, list)) 440 self.assertEqual(4, len(res)) 441 self.assertTrue(isinstance(res[0], list)) 442 self.assertEqual(3, len(res[0])) 443 self.assertEqual(a_val, res[0][0]) 444 self.assertEqual(b_val, res[0][1]) 445 self.assertEqual(c_val, res[0][2]) 446 self.assertTrue(isinstance(res[1], tuple)) 447 self.assertEqual(3, len(res[1])) 448 self.assertEqual(a_val, res[1][0]) 449 self.assertEqual(b_val, res[1][1]) 450 self.assertEqual(c_val, res[1][2]) 451 self.assertTrue(isinstance(res[2], ABC)) 452 self.assertEqual(a_val, res[2].a) 453 self.assertEqual(b_val, res[2].b) 454 self.assertEqual(c_val, res[2].c) 455 self.assertTrue(isinstance(res[3], dict)) 456 self.assertEqual(3, len(res[3])) 457 self.assertEqual(a_val, res[3]['a']) 458 self.assertEqual(b_val, res[3]['b']) 459 self.assertEqual(c_val, res[3]['c']) 460 # Tuple of lists, tuples, namedtuple, and dict 461 res = sess.run(([a, b, c], (a.name, b, c), ABC(a=a, b=b, c=c), { 462 'a': a, 463 'c': c, 464 'b': b 465 })) 466 self.assertTrue(isinstance(res, tuple)) 467 self.assertEqual(4, len(res)) 468 self.assertTrue(isinstance(res[0], list)) 469 self.assertEqual(3, len(res[0])) 470 self.assertEqual(a_val, res[0][0]) 471 self.assertEqual(b_val, res[0][1]) 472 self.assertEqual(c_val, res[0][2]) 473 self.assertTrue(isinstance(res[1], tuple)) 474 self.assertEqual(3, len(res[1])) 475 
self.assertEqual(a_val, res[1][0]) 476 self.assertEqual(b_val, res[1][1]) 477 self.assertEqual(c_val, res[1][2]) 478 self.assertTrue(isinstance(res[2], ABC)) 479 self.assertEqual(a_val, res[2].a) 480 self.assertEqual(b_val, res[2].b) 481 self.assertEqual(c_val, res[2].c) 482 self.assertTrue(isinstance(res[3], dict)) 483 self.assertEqual(3, len(res[3])) 484 self.assertEqual(a_val, res[3]['a']) 485 self.assertEqual(b_val, res[3]['b']) 486 self.assertEqual(c_val, res[3]['c']) 487 # Namedtuple of lists, tuples, namedtuples, and dict 488 res = sess.run( 489 DEFG( 490 d=[a, b, c], 491 e=(a, b, c), 492 f=ABC(a=a.name, b=b, c=c), 493 g={ 494 'a': a, 495 'c': c, 496 'b': b 497 })) 498 self.assertTrue(isinstance(res, DEFG)) 499 self.assertTrue(isinstance(res.d, list)) 500 self.assertEqual(3, len(res.d)) 501 self.assertEqual(a_val, res.d[0]) 502 self.assertEqual(b_val, res.d[1]) 503 self.assertEqual(c_val, res.d[2]) 504 self.assertTrue(isinstance(res.e, tuple)) 505 self.assertEqual(3, len(res.e)) 506 self.assertEqual(a_val, res.e[0]) 507 self.assertEqual(b_val, res.e[1]) 508 self.assertEqual(c_val, res.e[2]) 509 self.assertTrue(isinstance(res.f, ABC)) 510 self.assertEqual(a_val, res.f.a) 511 self.assertEqual(b_val, res.f.b) 512 self.assertEqual(c_val, res.f.c) 513 self.assertTrue(isinstance(res.g, dict)) 514 self.assertEqual(3, len(res.g)) 515 self.assertEqual(a_val, res.g['a']) 516 self.assertEqual(b_val, res.g['b']) 517 self.assertEqual(c_val, res.g['c']) 518 # Dict of lists, tuples, namedtuples, and dict 519 res = sess.run({ 520 'd': [a, b, c], 521 'e': (a, b, c), 522 'f': ABC(a=a, b=b, c=c), 523 'g': { 524 'a': a.name, 525 'c': c, 526 'b': b 527 } 528 }) 529 self.assertTrue(isinstance(res, dict)) 530 self.assertEqual(4, len(res)) 531 self.assertTrue(isinstance(res['d'], list)) 532 self.assertEqual(3, len(res['d'])) 533 self.assertEqual(a_val, res['d'][0]) 534 self.assertEqual(b_val, res['d'][1]) 535 self.assertEqual(c_val, res['d'][2]) 536 
self.assertTrue(isinstance(res['e'], tuple)) 537 self.assertEqual(3, len(res['e'])) 538 self.assertEqual(a_val, res['e'][0]) 539 self.assertEqual(b_val, res['e'][1]) 540 self.assertEqual(c_val, res['e'][2]) 541 self.assertTrue(isinstance(res['f'], ABC)) 542 self.assertEqual(a_val, res['f'].a) 543 self.assertEqual(b_val, res['f'].b) 544 self.assertEqual(c_val, res['f'].c) 545 self.assertTrue(isinstance(res['g'], dict)) 546 self.assertEqual(3, len(res['g'])) 547 self.assertEqual(a_val, res['g']['a']) 548 self.assertEqual(b_val, res['g']['b']) 549 self.assertEqual(c_val, res['g']['c']) 550 551 def testFetchTensorObject(self): 552 with session.Session() as s: 553 a = constant_op.constant(1.0, shape=[1, 2]) 554 b = constant_op.constant(2.0, shape=[2, 3]) 555 c = math_ops.matmul(a, b) 556 results_with_list = s.run([c]) 557 self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_list[0]) 558 results_with_single = s.run(c) 559 self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_single) 560 results_with_get = c.eval() 561 self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_get) 562 a_val, b_val = s.run([a, b]) # Test multiple fetches. 
563 self.assertAllEqual([[1.0, 1.0]], a_val) 564 self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], b_val) 565 results_with_dict = s.run({'a': [a], 'b': b, 'z': [a, b]}) 566 self.assertAllEqual([[1.0, 1.0]], results_with_dict['a'][0]) 567 self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], 568 results_with_dict['b']) 569 self.assertAllEqual(results_with_dict['a'][0], results_with_dict['z'][0]) 570 self.assertAllEqual(results_with_dict['b'], results_with_dict['z'][1]) 571 572 # Test nested structures 573 results_with_nested_list = s.run([[[a, b], b], a, [a, b]]) 574 self.assertAllEqual([[1.0, 1.0]], results_with_nested_list[0][0][0]) 575 self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], 576 results_with_nested_list[0][0][1]) 577 self.assertAllEqual(results_with_nested_list[0][0][0], 578 results_with_nested_list[1]) 579 self.assertAllEqual(results_with_nested_list[1], 580 results_with_nested_list[2][0]) 581 self.assertAllEqual(results_with_nested_list[0][0][1], 582 results_with_nested_list[0][1]) 583 self.assertAllEqual(results_with_nested_list[0][1], 584 results_with_nested_list[2][1]) 585 586 def testFetchScalar(self): 587 with session.Session() as s: 588 for scalar in np.int32, np.int64, np.float16, np.float32, np.float64: 589 x = scalar(7) 590 y = scalar(8) 591 tf_x = constant_op.constant(x, shape=[]) 592 tf_y = constant_op.constant(y) 593 tf_xy = math_ops.add(tf_x, tf_y) 594 # Single fetch 595 xy = s.run(tf_xy) 596 self.assertEqual(scalar, type(xy)) 597 self.assertEqual(x + y, xy) 598 # List fetch 599 xy, = s.run([tf_xy]) 600 self.assertEqual(scalar, type(xy)) 601 self.assertEqual(x + y, xy) 602 # Dict fetch 603 xy = s.run({'xy': tf_xy})['xy'] 604 self.assertEqual(scalar, type(xy)) 605 self.assertEqual(x + y, xy) 606 # Nested list fetch 607 xy = s.run([[[tf_xy]], tf_xy, [tf_xy]]) 608 self.assertAllEqual(xy, [[[x + y]], x + y, [x + y]]) 609 self.assertEqual(scalar, type(xy[0][0][0])) 610 self.assertEqual(scalar, type(xy[1])) 611 
self.assertEqual(scalar, type(xy[2][0])) 612 613 def testFetchOperationObject(self): 614 with session.Session() as s: 615 a = constant_op.constant(1.0, shape=[1, 2]) 616 v = variables.Variable(a, name='testFetchOperationObject_v') 617 s.run(v.initializer) 618 v_val = s.run(v) 619 self.assertAllEqual([[1.0, 1.0]], v_val) 620 621 def testFetchSparseTensor(self): 622 with session.Session() as s: 623 indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) 624 values = np.array([1.0, 2.0]).astype(np.float32) 625 shape = np.array([7, 9, 2]).astype(np.int64) 626 sp = sparse_tensor.SparseTensor( 627 constant_op.constant(indices), constant_op.constant(values), 628 constant_op.constant(shape)) 629 # Single fetch, use as tuple 630 sp_out = s.run(sp) 631 indices_out, values_out, shape_out = sp_out 632 self.assertAllEqual(indices_out, indices) 633 self.assertAllEqual(values_out, values) 634 self.assertAllEqual(shape_out, shape) 635 # Single fetch, use as SparseTensorValue 636 sp_out = s.run(sp) 637 self.assertAllEqual(sp_out.indices, indices) 638 self.assertAllEqual(sp_out.values, values) 639 self.assertAllEqual(sp_out.dense_shape, shape) 640 # Tuple fetch, use as tuple 641 indices_out, values_out, shape_out = s.run(sp) 642 self.assertAllEqual(indices_out, indices) 643 self.assertAllEqual(values_out, values) 644 self.assertAllEqual(shape_out, shape) 645 # List fetch, use as tuple 646 (indices_out, values_out, shape_out), = s.run([sp]) 647 self.assertAllEqual(indices_out, indices) 648 self.assertAllEqual(values_out, values) 649 self.assertAllEqual(shape_out, shape) 650 # List fetch, use as SparseTensorValue 651 sp_out, = s.run([sp]) 652 self.assertAllEqual(sp_out.indices, indices) 653 self.assertAllEqual(sp_out.values, values) 654 self.assertAllEqual(sp_out.dense_shape, shape) 655 # Dict fetch (single value), use as tuple 656 indices_out, values_out, shape_out = s.run({'sp': sp})['sp'] 657 self.assertAllEqual(indices_out, indices) 658 self.assertAllEqual(values_out, values) 
659 self.assertAllEqual(shape_out, shape) 660 # Dict fetch (list value), use as tuple 661 (indices_out, values_out, shape_out), = s.run({'sp': [sp]})['sp'] 662 self.assertAllEqual(indices_out, indices) 663 self.assertAllEqual(values_out, values) 664 self.assertAllEqual(shape_out, shape) 665 # Dict fetch, use as SparseTensorValue 666 sp_out = s.run({'sp': sp})['sp'] 667 self.assertAllEqual(sp_out.indices, indices) 668 self.assertAllEqual(sp_out.values, values) 669 self.assertAllEqual(sp_out.dense_shape, shape) 670 # Nested list fetch use as tuple 671 sp_out = s.run([[[sp]], sp]) 672 indices_out, values_out, shape_out = sp_out[0][0][0] 673 self.assertAllEqual(indices_out, indices) 674 self.assertAllEqual(values_out, values) 675 self.assertAllEqual(shape_out, shape) 676 indices_out, values_out, shape_out = sp_out[1] 677 self.assertAllEqual(indices_out, indices) 678 self.assertAllEqual(values_out, values) 679 self.assertAllEqual(shape_out, shape) 680 # Nested list fetch, use as SparseTensorValue 681 sp_out = s.run([[[sp]], sp]) 682 self.assertAllEqual(sp_out[0][0][0].indices, indices) 683 self.assertAllEqual(sp_out[0][0][0].values, values) 684 self.assertAllEqual(sp_out[0][0][0].dense_shape, shape) 685 self.assertAllEqual(sp_out[1].indices, indices) 686 self.assertAllEqual(sp_out[1].values, values) 687 self.assertAllEqual(sp_out[1].dense_shape, shape) 688 689 def testFeedSparseTensor(self): 690 with session.Session() as s: 691 indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) 692 values = np.array([1.0, 2.0]).astype(np.float32) 693 shape = np.array([7, 9, 2]).astype(np.int64) 694 sp = sparse_tensor.SparseTensor( 695 array_ops.placeholder(dtype=np.int64, shape=(2, 3)), 696 array_ops.placeholder(dtype=np.float32, shape=(2,)), 697 array_ops.placeholder(dtype=np.int64, shape=(3,)), 698 ) 699 sp_indices = array_ops.identity(sp.indices) 700 sp_values = array_ops.identity(sp.values) 701 sp_shape = array_ops.identity(sp.dense_shape) 702 sp2 = 
sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape) 703 # Feed with tuple 704 indices_out, values_out, shape_out = s.run( 705 [sp_indices, sp_values, sp_shape], { 706 sp: (indices, values, shape) 707 }) 708 self.assertAllEqual(indices_out, indices) 709 self.assertAllEqual(values_out, values) 710 self.assertAllEqual(shape_out, shape) 711 # Feed with tuple, fetch sp directly 712 sp_out = s.run(sp, {sp: (indices, values, shape)}) 713 self.assertAllEqual(sp_out.indices, indices) 714 self.assertAllEqual(sp_out.values, values) 715 self.assertAllEqual(sp_out.dense_shape, shape) 716 # Feed with SparseTensorValue 717 indices_out, values_out, shape_out = s.run( 718 [sp_indices, sp_values, sp_shape], { 719 sp: sparse_tensor.SparseTensorValue(indices, values, shape) 720 }) 721 self.assertAllEqual(indices_out, indices) 722 self.assertAllEqual(values_out, values) 723 self.assertAllEqual(shape_out, shape) 724 # Feed with SparseTensorValue, fetch SparseTensorValue 725 sp2_out = s.run(sp2, { 726 sp: sparse_tensor.SparseTensorValue(indices, values, shape) 727 }) 728 self.assertAllEqual(sp2_out.indices, indices) 729 self.assertAllEqual(sp2_out.values, values) 730 self.assertAllEqual(sp2_out.dense_shape, shape) 731 # Feed SparseTensorValue and fetch sp directly. 
732 sp_out = s.run(sp, { 733 sp: sparse_tensor.SparseTensorValue(indices, values, shape) 734 }) 735 self.assertAllEqual(sp_out.indices, indices) 736 self.assertAllEqual(sp_out.values, values) 737 self.assertAllEqual(sp_out.dense_shape, shape) 738 739 def testFeedSparsePlaceholder(self): 740 with session.Session() as s: 741 indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) 742 values = np.array([1.0, 2.0]).astype(np.float32) 743 shape = np.array([7, 9, 2]).astype(np.int64) 744 sp = array_ops.sparse_placeholder(dtype=np.float32, name='placeholder1') 745 sp_indices = array_ops.identity(sp.indices) 746 sp_values = array_ops.identity(sp.values) 747 sp_shape = array_ops.identity(sp.dense_shape) 748 sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape) 749 # Feed with tuple 750 indices_out, values_out, shape_out = s.run( 751 [sp_indices, sp_values, sp_shape], { 752 sp: (indices, values, shape) 753 }) 754 self.assertAllEqual(indices_out, indices) 755 self.assertAllEqual(values_out, values) 756 self.assertAllEqual(shape_out, shape) 757 # Feed with SparseTensorValue 758 indices_out, values_out, shape_out = s.run( 759 [sp_indices, sp_values, sp_shape], { 760 sp: sparse_tensor.SparseTensorValue(indices, values, shape) 761 }) 762 self.assertAllEqual(indices_out, indices) 763 self.assertAllEqual(values_out, values) 764 self.assertAllEqual(shape_out, shape) 765 # Feed with SparseTensorValue, fetch SparseTensorValue 766 sp2_out = s.run(sp2, { 767 sp: sparse_tensor.SparseTensorValue(indices, values, shape) 768 }) 769 self.assertAllEqual(sp2_out.indices, indices) 770 self.assertAllEqual(sp2_out.values, values) 771 self.assertAllEqual(sp2_out.dense_shape, shape) 772 773 def testFeedSparsePlaceholderPartialShape(self): 774 with session.Session() as s: 775 indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) 776 values = np.array([1.0, 2.0]).astype(np.float32) 777 shape = np.array([7, 9, 2]).astype(np.int64) 778 sp = array_ops.sparse_placeholder( 779 
shape=[None, 9, 2], dtype=np.float32, name='placeholder1') 780 sp_indices = array_ops.identity(sp.indices) 781 sp_values = array_ops.identity(sp.values) 782 sp_shape = array_ops.identity(sp.dense_shape) 783 sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape) 784 # Feed with tuple 785 indices_out, values_out, shape_out = s.run( 786 [sp_indices, sp_values, sp_shape], { 787 sp: (indices, values, shape) 788 }) 789 self.assertAllEqual(indices_out, indices) 790 self.assertAllEqual(values_out, values) 791 self.assertAllEqual(shape_out, shape) 792 # Feed with SparseTensorValue 793 indices_out, values_out, shape_out = s.run( 794 [sp_indices, sp_values, sp_shape], { 795 sp: sparse_tensor.SparseTensorValue(indices, values, shape) 796 }) 797 self.assertAllEqual(indices_out, indices) 798 self.assertAllEqual(values_out, values) 799 self.assertAllEqual(shape_out, shape) 800 # Feed with SparseTensorValue, fetch SparseTensorValue 801 sp2_out = s.run(sp2, { 802 sp: sparse_tensor.SparseTensorValue(indices, values, shape) 803 }) 804 self.assertAllEqual(sp2_out.indices, indices) 805 self.assertAllEqual(sp2_out.values, values) 806 self.assertAllEqual(sp2_out.dense_shape, shape) 807 808 def testFeedSparsePlaceholderConstantShape(self): 809 with session.Session() as s: 810 indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) 811 values = np.array([1.0, 2.0]).astype(np.float32) 812 shape = np.array([7, 9, 2]).astype(np.int64) 813 sp = array_ops.sparse_placeholder( 814 dtype=np.float32, shape=shape, name='placeholder1') 815 self.assertAllEqual(sp.dense_shape.eval(session=s), shape) 816 self.assertAllEqual(tensor_util.constant_value(sp.dense_shape), shape) 817 sp_indices = array_ops.identity(sp.indices) 818 sp_values = array_ops.identity(sp.values) 819 sp_shape = array_ops.identity(sp.dense_shape) 820 # Feed with tuple 821 indices_out, values_out, shape_out = s.run( 822 [sp_indices, sp_values, sp_shape], { 823 sp: (indices, values) 824 }) 825 
self.assertAllEqual(indices_out, indices) 826 self.assertAllEqual(values_out, values) 827 self.assertAllEqual(shape_out, shape) 828 829 def testFetchIndexedSlices(self): 830 with session.Session() as s: 831 indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) 832 values = np.array([1.0, 2.0]).astype(np.float32) 833 dense_shape = np.array([7, 9, 2]).astype(np.int64) 834 ind = ops.IndexedSlices( 835 constant_op.constant(values), constant_op.constant(indices), 836 constant_op.constant(dense_shape)) 837 # Single fetch, use as tuple 838 ind_out = s.run(ind) 839 values_out, indices_out, dense_shape_out = ind_out 840 self.assertAllEqual(values_out, values) 841 self.assertAllEqual(indices_out, indices) 842 self.assertAllEqual(dense_shape_out, dense_shape) 843 # Single fetch, use as IndexedSlicesValue 844 ind_out = s.run(ind) 845 self.assertAllEqual(ind_out.values, values) 846 self.assertAllEqual(ind_out.indices, indices) 847 self.assertAllEqual(ind_out.dense_shape, dense_shape) 848 # Tuple fetch, use as tuple 849 values_out, indices_out, dense_shape_out = s.run(ind) 850 self.assertAllEqual(values_out, values) 851 self.assertAllEqual(indices_out, indices) 852 self.assertAllEqual(dense_shape_out, dense_shape) 853 # List fetch, use as tuple 854 (values_out, indices_out, dense_shape_out), = s.run([ind]) 855 self.assertAllEqual(values_out, values) 856 self.assertAllEqual(indices_out, indices) 857 self.assertAllEqual(dense_shape_out, dense_shape) 858 # List fetch, use as IndexedSlicesValue 859 ind_out, = s.run([ind]) 860 self.assertAllEqual(ind_out.values, values) 861 self.assertAllEqual(ind_out.indices, indices) 862 self.assertAllEqual(ind_out.dense_shape, dense_shape) 863 864 def testFeedIndexedSlices(self): 865 with session.Session() as s: 866 values = np.array([1.0, 2.0]).astype(np.float32) 867 indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) 868 dense_shape = np.array([7, 9, 2]).astype(np.int64) 869 ind = ops.IndexedSlices( 870 
array_ops.placeholder(dtype=np.float32, shape=(2,)), 871 array_ops.placeholder(dtype=np.int64, shape=(2, 3)), 872 array_ops.placeholder(dtype=np.int64, shape=(3,)), 873 ) 874 ind_values = array_ops.identity(ind.values) 875 ind_indices = array_ops.identity(ind.indices) 876 ind_dense_shape = array_ops.identity(ind.dense_shape) 877 ind2 = ops.IndexedSlices(ind_values, ind_indices, ind_dense_shape) 878 # Feed with tuple 879 values_out, indices_out, dense_shape_out = s.run( 880 [ind_values, ind_indices, ind_dense_shape], { 881 ind: (values, indices, dense_shape) 882 }) 883 self.assertAllEqual(values_out, values) 884 self.assertAllEqual(indices_out, indices) 885 self.assertAllEqual(dense_shape_out, dense_shape) 886 # Feed with IndexedSlicesValue 887 values_out, indices_out, dense_shape_out = s.run( 888 [ind_values, ind_indices, ind_dense_shape], { 889 ind: ops.IndexedSlicesValue(values, indices, dense_shape) 890 }) 891 self.assertAllEqual(values_out, values) 892 self.assertAllEqual(indices_out, indices) 893 self.assertAllEqual(dense_shape_out, dense_shape) 894 # Feed with IndexedSlicesValue, fetch IndexedSlicesValue 895 ind2_out = s.run(ind2, { 896 ind: ops.IndexedSlicesValue(values, indices, dense_shape) 897 }) 898 self.assertAllEqual(ind2_out.values, values) 899 self.assertAllEqual(ind2_out.indices, indices) 900 self.assertAllEqual(ind2_out.dense_shape, dense_shape) 901 902 def testFetchIndexedSlicesWithoutDenseShape(self): 903 with session.Session() as s: 904 indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) 905 values = np.array([1.0, 2.0]).astype(np.float32) 906 dense_shape = None 907 ind = ops.IndexedSlices( 908 constant_op.constant(values), constant_op.constant(indices), None) 909 # Single fetch, use as tuple 910 ind_out = s.run(ind) 911 values_out, indices_out, dense_shape_out = ind_out 912 self.assertAllEqual(values_out, values) 913 self.assertAllEqual(indices_out, indices) 914 self.assertAllEqual(dense_shape_out, dense_shape) 915 # Single fetch, use 
as IndexedSlicesValue 916 ind_out = s.run(ind) 917 self.assertAllEqual(ind_out.values, values) 918 self.assertAllEqual(ind_out.indices, indices) 919 self.assertAllEqual(ind_out.dense_shape, dense_shape) 920 # Tuple fetch, use as tuple 921 values_out, indices_out, dense_shape_out = s.run(ind) 922 self.assertAllEqual(values_out, values) 923 self.assertAllEqual(indices_out, indices) 924 self.assertAllEqual(dense_shape_out, dense_shape) 925 # List fetch, use as tuple 926 (values_out, indices_out, dense_shape_out), = s.run([ind]) 927 self.assertAllEqual(values_out, values) 928 self.assertAllEqual(indices_out, indices) 929 self.assertAllEqual(dense_shape_out, dense_shape) 930 # List fetch, use as IndexedSlicesValue 931 ind_out, = s.run([ind]) 932 self.assertAllEqual(ind_out.values, values) 933 self.assertAllEqual(ind_out.indices, indices) 934 self.assertAllEqual(ind_out.dense_shape, dense_shape) 935 936 def testFeedIndexedSlicesWithoutDenseShape(self): 937 with session.Session() as s: 938 values = np.array([1.0, 2.0]).astype(np.float32) 939 indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) 940 dense_shape = None 941 ind = ops.IndexedSlices( 942 array_ops.placeholder(dtype=np.float32, shape=(2,)), 943 array_ops.placeholder(dtype=np.int64, shape=(2, 3)), None) 944 ind_values = array_ops.identity(ind.values) 945 ind_indices = array_ops.identity(ind.indices) 946 ind2 = ops.IndexedSlices(ind_values, ind_indices) 947 # Feed with tuple 948 values_out, indices_out = s.run([ind_values, ind_indices], { 949 ind: (values, indices) 950 }) 951 self.assertAllEqual(values_out, values) 952 self.assertAllEqual(indices_out, indices) 953 # Feed with IndexedSlicesValue 954 values_out, indices_out = s.run([ind_values, ind_indices], { 955 ind: ops.IndexedSlicesValue(values, indices, dense_shape) 956 }) 957 self.assertAllEqual(values_out, values) 958 self.assertAllEqual(indices_out, indices) 959 # Feed with IndexedSlicesValue, fetch IndexedSlicesValue 960 ind2_out = s.run(ind2, { 961 
ind: ops.IndexedSlicesValue(values, indices, dense_shape) 962 }) 963 self.assertAllEqual(ind2_out.values, values) 964 self.assertAllEqual(ind2_out.indices, indices) 965 self.assertAllEqual(ind2_out.dense_shape, dense_shape) 966 967 def testExtendWithStatelessOperations(self): 968 with session.Session() as s: 969 a = constant_op.constant(1.0, shape=[1, 2]) 970 b = constant_op.constant(2.0, shape=[2, 3]) 971 c = math_ops.matmul(a, b) 972 c_val = s.run(c) 973 self.assertAllEqual([[4.0, 4.0, 4.0]], c_val) 974 d = constant_op.constant([1.0, 2.0, 3.0], shape=[3, 1]) 975 e = math_ops.matmul(c, d) 976 # Extend will happen here. 977 e_val = s.run(e) 978 self.assertAllEqual([[24.0]], e_val) 979 980 def testExtendWithStatefulOperations(self): 981 with session.Session() as s: 982 a = constant_op.constant(1.0, shape=[1, 2]) 983 b = constant_op.constant(2.0, shape=[2, 3]) 984 c = math_ops.matmul(a, b) 985 v = variables.Variable(c, name='testExtendWithStatefulOperations_v') 986 v.initializer.run() 987 v_val = v.eval() 988 self.assertAllEqual([[4.0, 4.0, 4.0]], v_val) 989 d = constant_op.constant(3.0, shape=[2, 3]) 990 e = math_ops.matmul(a, d) 991 assign_e_to_v = state_ops.assign(v, e) 992 # Extend will happen here. 993 e_val = e.eval() 994 self.assertAllEqual([[6.0, 6.0, 6.0]], e_val) 995 v_val = v.eval() 996 self.assertAllEqual([[4.0, 4.0, 4.0]], v_val) 997 s.run(assign_e_to_v) 998 v_val = v.eval() 999 self.assertAllEqual([[6.0, 6.0, 6.0]], v_val) 1000 1001 def testExtendWithGroupBy(self): 1002 with session.Session() as s: 1003 a = constant_op.constant(1.0, shape=[1, 2]) 1004 p = variables.Variable(a, name='testExtendWithGroupBy_p') 1005 a_val = a.eval() # Force an Extend after this op. 1006 self.assertAllEqual([[1.0, 1.0]], a_val) 1007 1008 b = constant_op.constant(2.0, shape=[1, 2]) 1009 q = variables.Variable(b, name='testExtendWithGroupBy_q') 1010 # Extend will happen here. 
1011 init = control_flow_ops.group(p.initializer, q.initializer) 1012 s.run(init) 1013 p_val, q_val = s.run([p, q]) 1014 1015 self.assertAllEqual([[1.0, 1.0]], p_val) 1016 self.assertAllEqual([[2.0, 2.0]], q_val) 1017 1018 def testTensorGetMethod(self): 1019 with session.Session(): 1020 a = constant_op.constant(1.0, shape=[1, 2]) 1021 b = constant_op.constant(2.0, shape=[2, 3]) 1022 c = math_ops.matmul(a, b) 1023 1024 c_val = c.eval() 1025 self.assertAllEqual([[4.0, 4.0, 4.0]], c_val) 1026 1027 fed_c_val = c.eval(feed_dict={a.name: [[4.0, 4.0]]}) 1028 self.assertAllEqual([[16.0, 16.0, 16.0]], fed_c_val) 1029 1030 @test_util.run_v1_only('b/120545219') 1031 def testOperationRunMethod(self): 1032 with session.Session(): 1033 a = constant_op.constant(1.0, shape=[1, 2]) 1034 b = constant_op.constant(2.0, shape=[1, 2], name='b') 1035 v = variables.VariableV1(a, a.dtype) 1036 assign_a_to_v = state_ops.assign(v, a) 1037 1038 assign_a_to_v.eval() 1039 1040 v_val = v.eval() 1041 self.assertAllEqual([[1.0, 1.0]], v_val) 1042 1043 assign_b_to_v = state_ops.assign(v, b) 1044 1045 assign_b_to_v.eval() 1046 v_val = v.eval() 1047 self.assertAllEqual([[2.0, 2.0]], v_val) 1048 1049 assign_b_to_v.eval(feed_dict={'b:0': [[3.0, 3.0]]}) 1050 v_val = v.eval() 1051 self.assertAllEqual([[3.0, 3.0]], v_val) 1052 1053 def testDefaultGraph(self): 1054 with session.Session() as s: 1055 self.assertEqual(ops.get_default_graph(), s.graph) 1056 a = constant_op.constant(1.0, shape=[1, 2]) 1057 b = constant_op.constant(2.0, shape=[2, 3]) 1058 self.assertEqual(ops.get_default_graph(), a.graph) 1059 self.assertEqual(ops.get_default_graph(), b.graph) 1060 c = math_ops.matmul(a, b) 1061 v = variables.Variable(c, name='testDefaultGraph_v') 1062 v.initializer.run() 1063 v_val = v.eval() 1064 self.assertAllEqual([[4.0, 4.0, 4.0]], v_val) 1065 d = constant_op.constant(3.0, shape=[2, 3]) 1066 e = math_ops.matmul(a, d) 1067 assign_e_to_v = state_ops.assign(v, e) 1068 e_val = e.eval() 1069 
self.assertAllEqual([[6.0, 6.0, 6.0]], e_val) 1070 v_val = v.eval() 1071 self.assertAllEqual([[4.0, 4.0, 4.0]], v_val) 1072 s.run(assign_e_to_v) 1073 v_val = v.eval() 1074 self.assertAllEqual([[6.0, 6.0, 6.0]], v_val) 1075 self.assertEqual(ops.get_default_graph(), s.graph) 1076 1077 def _testDefaultGraphInThread(self, constructed_event, continue_event, i): 1078 with session.Session() as s: 1079 self.assertEqual(ops.get_default_graph(), s.graph) 1080 a = constant_op.constant(1.0, shape=[1, 2]) 1081 b = constant_op.constant(2.0, shape=[2, 3]) 1082 c = math_ops.matmul(a, b) 1083 v = variables.Variable(c, name='var_%d' % i) 1084 1085 # Block here until all threads have constructed their graph. 1086 constructed_event.set() 1087 continue_event.wait() 1088 1089 assign_c_to_v = state_ops.assign(v, c) 1090 v.initializer.run() 1091 assign_c_to_v.eval() 1092 v_val = v.eval() 1093 self.assertAllEqual([[4.0, 4.0, 4.0]], v_val) 1094 d = constant_op.constant(3.0, shape=[2, 3]) 1095 e = math_ops.matmul(a, d) 1096 assign_e_to_v = state_ops.assign(v, e) 1097 e_val = e.eval() 1098 self.assertAllEqual([[6.0, 6.0, 6.0]], e_val) 1099 v_val = v.eval() 1100 self.assertAllEqual([[4.0, 4.0, 4.0]], v_val) 1101 s.run(assign_e_to_v) 1102 v_val = v.eval() 1103 self.assertAllEqual([[6.0, 6.0, 6.0]], v_val) 1104 self.assertEqual(ops.get_default_graph(), s.graph) 1105 1106 def testDefaultGraphWithThreads(self): 1107 # Fork ten threads that use their thread-local default graph. 
1108 threads = [] 1109 constructed_events = [threading.Event() for _ in range(10)] 1110 continue_event = threading.Event() 1111 for i, constructed_event in enumerate(constructed_events): 1112 t = self.checkedThread( 1113 target=self._testDefaultGraphInThread, 1114 args=(constructed_event, continue_event, i)) 1115 threads.append(t) 1116 for t in threads: 1117 t.start() 1118 for constructed_event in constructed_events: 1119 constructed_event.wait() 1120 continue_event.set() 1121 for t in threads: 1122 t.join() 1123 1124 def testParallelRun(self): 1125 with session.Session() as sess: 1126 c = constant_op.constant(5.0) 1127 ev = threading.Event() 1128 1129 def run_step(): 1130 ev.wait() 1131 val = c.eval(session=sess) 1132 self.assertEqual(val, 5.0) 1133 1134 threads = [self.checkedThread(target=run_step) for _ in range(100)] 1135 for t in threads: 1136 t.start() 1137 ev.set() 1138 for t in threads: 1139 t.join() 1140 1141 @staticmethod 1142 def _build_graph(): 1143 time.sleep(random.random() * 0.1) 1144 # Do some graph construction. Try to exercise non-trivial paths. 
1145 graph = ops.get_default_graph() 1146 gdef = None 1147 for _ in range(10): 1148 x = array_ops.placeholder(dtype=dtypes.float32) 1149 with ops.colocate_with(x): 1150 y = array_ops.placeholder(dtype=dtypes.float32) 1151 with ops.device('/cpu:0'): 1152 z = control_flow_ops.while_loop( 1153 lambda x, y: x < 10, lambda x, y: (x + 1, x * y), [x, y]) 1154 with graph._attr_scope({'_a': attr_value_pb2.AttrValue(b=False)}): 1155 gradients_impl.gradients(z, [x, y]) 1156 if gdef is None: 1157 gdef = graph.as_graph_def() 1158 else: 1159 importer.import_graph_def(gdef, name='import') 1160 1161 @test_util.run_v1_only('b/120545219') 1162 def testParallelRunAndSingleBuild(self): 1163 with session.Session() as sess: 1164 c = constant_op.constant(5.0) 1165 stop = threading.Event() 1166 1167 def run_loop(): 1168 while not stop.is_set(): 1169 time.sleep(random.random() * 0.1) 1170 self.assertEqual(sess.run(c), 5.0) 1171 1172 threads = [self.checkedThread(target=run_loop) for _ in range(10)] 1173 for t in threads: 1174 t.start() 1175 1176 SessionTest._build_graph() 1177 1178 stop.set() 1179 for t in threads: 1180 t.join() 1181 1182 @test_util.run_v1_only('b/120545219') 1183 def testParallelRunAndParallelBuild(self): 1184 with session.Session() as sess: 1185 c = constant_op.constant(5.0) 1186 stop = threading.Event() 1187 1188 def run_loop(): 1189 while not stop.is_set(): 1190 time.sleep(random.random() * 0.1) 1191 self.assertEqual(sess.run(c), 5.0) 1192 1193 run_threads = [self.checkedThread(target=run_loop) for _ in range(10)] 1194 for t in run_threads: 1195 t.start() 1196 1197 build_threads = [self.checkedThread(target=SessionTest._build_graph) 1198 for _ in range(10)] 1199 for t in build_threads: 1200 t.start() 1201 for t in build_threads: 1202 t.join() 1203 1204 # Let the run_threads run until the build threads are finished. 
1205 stop.set() 1206 for t in run_threads: 1207 t.join() 1208 1209 def testRunFeedDict(self): 1210 with session.Session() as s: 1211 x = array_ops.zeros([2]) 1212 1213 y = s.run(2 * x, feed_dict={x: np.ones(2).astype(np.float32)}) 1214 self.assertAllEqual(y, 2 * np.ones(2)) 1215 1216 y = s.run(2 * x, feed_dict={x.name: np.ones(2).astype(np.float32)}) 1217 self.assertAllEqual(y, 2 * np.ones(2)) 1218 1219 y = s.run(2 * x, feed_dict={x: [1, 1]}) 1220 assert (y == 2 * np.ones(2)).all() 1221 1222 # Test nested tuple keys 1223 z = (((array_ops.zeros([2]),),), array_ops.zeros([2]), 1224 (array_ops.zeros([2]),)) 1225 result = [z[0][0][0] * 2, z[1] * 2, z[2][0] * 2] 1226 values = (((np.array([1, 1]),),), np.array([2, 2]), (np.array([3, 3]),)) 1227 result_value = s.run(result, feed_dict={z: values}) 1228 self.assertAllEqual(result_value[0], 2 * np.ones(2)) 1229 self.assertAllEqual(result_value[1], 2 * np.array([2, 2])) 1230 self.assertAllEqual(result_value[2], 2 * np.array([3, 3])) 1231 1232 def testGraphDef(self): 1233 with session.Session() as sess: 1234 self.assertProtoEquals('versions { producer: %d min_consumer: %d }' % 1235 (versions.GRAPH_DEF_VERSION, 1236 versions.GRAPH_DEF_VERSION_MIN_CONSUMER), 1237 sess.graph_def) 1238 c = constant_op.constant(5.0, name='c') 1239 self.assertEquals(len(sess.graph_def.node), 1) 1240 d = constant_op.constant(6.0, name='d') 1241 self.assertEquals(len(sess.graph_def.node), 2) 1242 self.assertAllEqual(c.eval(), 5.0) 1243 self.assertAllEqual(d.eval(), 6.0) 1244 e = constant_op.constant(7.0, name='e') 1245 self.assertEquals(len(sess.graph_def.node), 3) 1246 self.assertAllEqual(e.eval(), 7.0) 1247 1248 def testUseAfterClose(self): 1249 with session.Session() as sess: 1250 c = constant_op.constant(5.0) 1251 self.assertAllEqual(sess.run(c), 5.0) 1252 with self.assertRaisesWithPredicateMatch( 1253 RuntimeError, lambda e: 'Attempted to use a closed Session.' 
in str(e)): 1254 sess.run(c) 1255 1256 def testUseAfterCloseConcurrent(self): 1257 with session.Session() as sess: 1258 c = constant_op.constant(5.0) 1259 self.assertAllEqual(sess.run(c), 5.0) 1260 1261 def update_thread(): 1262 with self.assertRaisesWithPredicateMatch( 1263 RuntimeError, 1264 lambda e: 'Attempted to use a closed Session.' in str(e)): 1265 while True: 1266 sess.run(c) 1267 1268 t = threading.Thread(target=update_thread) 1269 t.start() 1270 time.sleep(0.1) 1271 sess.close() 1272 t.join() 1273 1274 def testUseEmptyGraph(self): 1275 with session.Session() as sess: 1276 with self.assertRaisesRegexp(RuntimeError, 'The Session graph is empty.'): 1277 sess.run([]) 1278 with self.assertRaisesRegexp(RuntimeError, 'The Session graph is empty.'): 1279 sess.run(()) 1280 with self.assertRaisesRegexp(RuntimeError, 'The Session graph is empty.'): 1281 sess.run({}) 1282 1283 @test_util.run_v1_only('b/120545219') 1284 def testNotEntered(self): 1285 # pylint: disable=protected-access 1286 self.assertEqual(ops._default_session_stack.get_default(), None) 1287 # pylint: enable=protected-access 1288 with ops.device('/cpu:0'): 1289 sess = session.Session() 1290 c_1 = constant_op.constant(5.0) 1291 with sess.graph.as_default(): 1292 c_2 = constant_op.constant(5.0) 1293 self.assertEqual(c_1.graph, c_2.graph) 1294 self.assertEqual(sess.run(c_2), 5.0) 1295 with self.assertRaisesWithPredicateMatch( 1296 ValueError, lambda e: 'No default session is registered.' 
in str(e)): 1297 c_2.eval() 1298 1299 @test_util.run_v1_only('b/120545219') 1300 def testInteractive(self): 1301 with ops.device('/cpu:0'): 1302 sess = session.InteractiveSession() 1303 a = constant_op.constant(1.0, shape=[1, 2]) 1304 b = constant_op.constant(2.0, shape=[2, 3]) 1305 c = math_ops.matmul(a, b) 1306 self.assertAllEqual([[4.0, 4.0, 4.0]], c.eval()) 1307 d = constant_op.constant([1.0, 2.0, 3.0], shape=[3, 1]) 1308 e = math_ops.matmul(c, d) 1309 self.assertAllEqual([[24.0]], e.eval()) 1310 sess.close() 1311 1312 @test_util.run_v1_only('b/120545219') 1313 def testMultipleInteractiveSessionsWarning(self): 1314 # Reinitialize the global state to ensure that the expected warnings will 1315 # be emitted. 1316 session.InteractiveSession._active_session_count = 0 # pylint: disable=protected-access 1317 1318 sess = session.InteractiveSession() 1319 sess.run(constant_op.constant(4.0)) # Run so that the session is "opened". 1320 sess.close() 1321 # Opening and closing interactive sessions serially should not warn. 1322 with warnings.catch_warnings(record=True) as w: 1323 sess = session.InteractiveSession() 1324 sess.close() 1325 self.assertEqual(0, len(w)) 1326 1327 with warnings.catch_warnings(record=True) as w: 1328 sess = session.InteractiveSession() 1329 self.assertEqual(0, len(w)) 1330 with warnings.catch_warnings(record=True) as w: 1331 sess2 = session.InteractiveSession() 1332 self.assertEqual(1, len(w)) 1333 self.assertTrue('An interactive session is already active. This can cause ' 1334 'out-of-memory errors in some cases. You must explicitly ' 1335 'call `InteractiveSession.close()` to release resources ' 1336 'held by the other session(s).' in str(w[0].message)) 1337 sess2.close() 1338 sess.close() 1339 1340 @test_util.run_v1_only('b/120545219') 1341 def testInteractivePlacePrunedGraph(self): 1342 sess = session.InteractiveSession() 1343 1344 # Build a graph that has a bad op in it (no kernel). 
1345 # 1346 # This test currently does not link in any GPU kernels, 1347 # which is why placing this is invalid. If at some point 1348 # GPU kernels are added to this test, some other different 1349 # op / device combo should be chosen. 1350 with ops.device('/device:GPU:0'): 1351 a = constant_op.constant(1.0, shape=[1, 2]) 1352 1353 b = constant_op.constant(1.0, shape=[1, 2]) 1354 1355 # Only run the valid op, this should work. 1356 b.eval() 1357 1358 with self.assertRaises(errors.InvalidArgumentError): 1359 a.eval() 1360 sess.close() 1361 1362 @test_util.run_v1_only('b/120545219') 1363 def testDefaultSessionPlacePrunedGraph(self): 1364 sess = session.Session() 1365 1366 # Build a graph that has a bad op in it (no kernel). 1367 # 1368 # This test currently does not link in any GPU kernels, 1369 # which is why placing this is invalid. If at some point 1370 # GPU kernels are added to this test, some other different 1371 # op / device combo should be chosen. 1372 with ops.device('/device:GPU:0'): 1373 _ = constant_op.constant(1.0, shape=[1, 2]) 1374 1375 b = constant_op.constant(1.0, shape=[1, 2]) 1376 1377 with self.assertRaises(errors.InvalidArgumentError): 1378 # Even though we don't run the bad op, we place the entire 1379 # graph, which should fail with a non-interactive session. 
1380 sess.run(b) 1381 1382 sess.close() 1383 1384 def testSharedGraph(self): 1385 with ops.Graph().as_default() as g, ops.device('/cpu:0'): 1386 a = constant_op.constant(1.0, shape=[1, 2]) 1387 b = constant_op.constant(2.0, shape=[2, 3]) 1388 c = math_ops.matmul(a, b) 1389 1390 with session.Session(graph=g) as sess1: 1391 with session.Session(graph=g) as sess2: 1392 self.assertAllEqual(sess1.run(c), sess2.run(c)) 1393 1394 def testDuplicatedInputs(self): 1395 with session.Session() as sess: 1396 a = constant_op.constant(1.0, shape=[1, 2]) 1397 b = constant_op.constant(2.0, shape=[1, 3]) 1398 a_val, b_val, a2_val = sess.run([a, b, a]) 1399 self.assertAllEqual(a_val, [[1.0, 1.0]]) 1400 self.assertAllEqual(b_val, [[2.0, 2.0, 2.0]]) 1401 self.assertAllEqual(a2_val, [[1.0, 1.0]]) 1402 1403 def testFeedAndFetch(self): 1404 with session.Session() as sess: 1405 for dtype in [ 1406 dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32, 1407 dtypes.uint8, dtypes.int16, dtypes.int8, dtypes.int64, dtypes.bool, 1408 dtypes.complex64, dtypes.complex128 1409 ]: 1410 for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]: 1411 np_dtype = dtype.as_numpy_dtype 1412 1413 feed_t = array_ops.placeholder(dtype=dtype, shape=shape) 1414 out_t = array_ops.identity(feed_t) 1415 1416 np_array = np.random.randint(-10, 10, shape) 1417 1418 if dtype == dtypes.bool: 1419 np_array = np_array > 0 1420 elif dtype == dtypes.complex64: 1421 np_array = np.sqrt(np_array.astype(np_dtype)) 1422 elif dtype == dtypes.complex64: 1423 np_array = np.sqrt(np_array.astype(np_dtype)) 1424 else: 1425 np_array = np_array.astype(np_dtype) 1426 1427 self.assertAllEqual(np_array, 1428 sess.run(out_t, feed_dict={ 1429 feed_t: np_array 1430 })) 1431 # Check that we can also get the feed back. 1432 self.assertAllEqual(np_array, 1433 sess.run(feed_t, feed_dict={ 1434 feed_t: np_array 1435 })) 1436 # Also check that we can get both back. 
1437 out_v, feed_v = sess.run( 1438 [out_t, feed_t], feed_dict={ 1439 feed_t: np_array 1440 }) 1441 self.assertAllEqual(np_array, out_v) 1442 self.assertAllEqual(np_array, feed_v) 1443 1444 feed_fetch_runner = sess.make_callable([out_t, feed_t], [feed_t]) 1445 out_v, feed_v = feed_fetch_runner(np_array) 1446 self.assertAllEqual(np_array, out_v) 1447 self.assertAllEqual(np_array, feed_v) 1448 1449 def testMakeCallableOnTensorWithRunOptions(self): 1450 with session.Session() as sess: 1451 a = constant_op.constant(42.0) 1452 tensor_runner = sess.make_callable(a, accept_options=True) 1453 run_options = config_pb2.RunOptions( 1454 trace_level=config_pb2.RunOptions.FULL_TRACE) 1455 run_metadata = config_pb2.RunMetadata() 1456 self.assertEqual(0, len(run_metadata.step_stats.dev_stats)) 1457 res = tensor_runner(options=run_options, run_metadata=run_metadata) 1458 self.assertEqual(42.0, res) 1459 self.assertGreater(len(run_metadata.step_stats.dev_stats), 0) 1460 1461 def testMakeCallableOnOperationWithRunOptions(self): 1462 with session.Session() as sess: 1463 a = variables.Variable(42.0) 1464 b = state_ops.assign_add(a, 1.0) 1465 sess.run(a.initializer) 1466 tensor_runner = sess.make_callable(b.op, accept_options=True) 1467 run_options = config_pb2.RunOptions( 1468 trace_level=config_pb2.RunOptions.FULL_TRACE) 1469 run_metadata = config_pb2.RunMetadata() 1470 self.assertEqual(0, len(run_metadata.step_stats.dev_stats)) 1471 tensor_runner(options=run_options, run_metadata=run_metadata) 1472 self.assertEqual(43.0, sess.run(a)) 1473 self.assertGreater(len(run_metadata.step_stats.dev_stats), 0) 1474 1475 def testMakeCallableWithFeedListAndRunOptions(self): 1476 with session.Session() as sess: 1477 ph = array_ops.placeholder(dtypes.float32) 1478 a = math_ops.add(ph, 1.0) 1479 tensor_runner = sess.make_callable( 1480 a, feed_list=[ph.name], accept_options=True) 1481 run_options = config_pb2.RunOptions( 1482 trace_level=config_pb2.RunOptions.FULL_TRACE) 1483 run_metadata = 
config_pb2.RunMetadata() 1484 self.assertEqual(0, len(run_metadata.step_stats.dev_stats)) 1485 self.assertAllClose(42.0, 1486 tensor_runner( 1487 41.0, 1488 options=run_options, 1489 run_metadata=run_metadata)) 1490 self.assertGreater(len(run_metadata.step_stats.dev_stats), 0) 1491 1492 def testOptimizedMakeCallable(self): 1493 with session.Session() as sess: 1494 ph = array_ops.placeholder(dtypes.float32) 1495 a = math_ops.add(ph, 1.0) 1496 callable_opts = config_pb2.CallableOptions() 1497 callable_opts.feed.append(ph.name) 1498 callable_opts.fetch.append(a.name) 1499 for _ in range(3): 1500 callable_fn = sess._make_callable_from_options(callable_opts) 1501 for _ in range(5): 1502 self.assertEqual([2.0], callable_fn(np.array(1.0, dtype=np.float32))) 1503 1504 def testOptimizedMakeCallableWithRunMetadata(self): 1505 with session.Session() as sess: 1506 ph = array_ops.placeholder(dtypes.float32) 1507 a = math_ops.add(ph, 1.0) 1508 callable_opts = config_pb2.CallableOptions() 1509 callable_opts.feed.append(ph.name) 1510 callable_opts.fetch.append(a.name) 1511 callable_opts.run_options.trace_level = config_pb2.RunOptions.FULL_TRACE 1512 callable_fn = sess._make_callable_from_options(callable_opts) 1513 run_metadata = config_pb2.RunMetadata() 1514 self.assertEqual([2.0], callable_fn(np.array(1.0, dtype=np.float32), 1515 run_metadata=run_metadata)) 1516 self.assertGreater(len(run_metadata.step_stats.dev_stats), 0) 1517 1518 def testFeedError(self): 1519 with session.Session() as sess: 1520 feed_t = array_ops.placeholder(dtype=dtypes.float32) 1521 out_t = array_ops.identity(feed_t) 1522 feed_val = constant_op.constant(5.0) 1523 with self.assertRaisesRegexp(TypeError, 'cannot be a tf.Tensor object'): 1524 sess.run(out_t, feed_dict={feed_t: feed_val}) 1525 with self.assertRaisesRegexp(TypeError, 'cannot be a tf.Tensor object'): 1526 out_t.eval(feed_dict={feed_t: feed_val}) 1527 with self.assertRaisesRegexp(TypeError, 'cannot be a tf.Tensor object'): 1528 
out_t.op.run(feed_dict={feed_t: feed_val}) 1529 1530 def testFeedPrecisionLossError(self): 1531 with session.Session() as sess: 1532 largest_int64 = np.iinfo(np.int64).max 1533 1534 feed_int_implicit_int32 = constant_op.constant(1) 1535 feed_int_explicit_int32 = constant_op.constant(1, dtype=dtypes.int32) 1536 1537 out_t = constant_op.constant(1.0) 1538 1539 with self.assertRaisesRegexp(TypeError, 1540 'is not compatible with Tensor type'): 1541 sess.run(out_t, feed_dict={feed_int_implicit_int32: largest_int64}) 1542 with self.assertRaisesRegexp(TypeError, 1543 'is not compatible with Tensor type'): 1544 sess.run(out_t, feed_dict={feed_int_explicit_int32: largest_int64}) 1545 1546 def testStringFetch(self): 1547 with session.Session(): 1548 for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]: 1549 size = 1 1550 for s in shape: 1551 size *= s 1552 c_list = np.array( 1553 [compat.as_bytes(str(i)) for i in xrange(size)], 1554 dtype=np.object).reshape(shape) if size > 0 else [] 1555 c = constant_op.constant(c_list) 1556 self.assertAllEqual(c.eval(), c_list) 1557 1558 def testStringFeed(self): 1559 with session.Session() as sess: 1560 for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]: 1561 size = 1 1562 for s in shape: 1563 size *= s 1564 c_list = np.array( 1565 [compat.as_bytes(str(i)) for i in xrange(size)], 1566 dtype=np.object).reshape(shape) 1567 feed_t = array_ops.placeholder(dtype=dtypes.string, shape=shape) 1568 c = array_ops.identity(feed_t) 1569 self.assertAllEqual(sess.run(c, feed_dict={feed_t: c_list}), c_list) 1570 self.assertAllEqual( 1571 sess.run(feed_t, feed_dict={ 1572 feed_t: c_list 1573 }), c_list) 1574 c_v, feed_v = sess.run([c, feed_t], feed_dict={feed_t: c_list}) 1575 self.assertAllEqual(c_v, c_list) 1576 self.assertAllEqual(feed_v, c_list) 1577 1578 def testStringFeedWithNullCharacters(self): 1579 with session.Session(): 1580 c_list = [b'\n\x01\x00', b'\n\x00\x01'] 1581 feed_t = array_ops.placeholder(dtype=dtypes.string, shape=[2]) 
1582 c = array_ops.identity(feed_t) 1583 out = c.eval(feed_dict={feed_t: c_list}) 1584 self.assertEqual(c_list[0], out[0]) 1585 self.assertEqual(c_list[1], out[1]) 1586 1587 def testStringFeedWithUnicode(self): 1588 with session.Session(): 1589 c_list = [ 1590 u'\n\x01\x00', u'\n\x00\x01', u'\u26a3 unicode', 1591 u'\U0001f60e deal with it' 1592 ] 1593 feed_t = array_ops.placeholder(dtype=dtypes.string, shape=[len(c_list)]) 1594 c = array_ops.identity(feed_t) 1595 1596 out = c.eval(feed_dict={feed_t: c_list}) 1597 for i in range(len(c_list)): 1598 self.assertEqual(c_list[i], out[i].decode('utf-8')) 1599 1600 out = c.eval(feed_dict={feed_t: np.array(c_list, dtype=np.object)}) 1601 for i in range(len(c_list)): 1602 self.assertEqual(c_list[i], out[i].decode('utf-8')) 1603 1604 def testInvalidTargetFails(self): 1605 with self.assertRaisesRegexp( 1606 errors.NotFoundError, 1607 'No session factory registered for the given session options'): 1608 session.Session('INVALID_TARGET') 1609 1610 def testFetchByNameDifferentStringTypes(self): 1611 with session.Session() as sess: 1612 c = constant_op.constant(42.0, name='c') 1613 d = constant_op.constant(43.0, name=u'd') 1614 e = constant_op.constant(44.0, name=b'e') 1615 f = constant_op.constant(45.0, name=r'f') 1616 1617 self.assertTrue(isinstance(c.name, six.text_type)) 1618 self.assertTrue(isinstance(d.name, six.text_type)) 1619 self.assertTrue(isinstance(e.name, six.text_type)) 1620 self.assertTrue(isinstance(f.name, six.text_type)) 1621 1622 self.assertEqual(42.0, sess.run('c:0')) 1623 self.assertEqual(42.0, sess.run(u'c:0')) 1624 self.assertEqual(42.0, sess.run(b'c:0')) 1625 self.assertEqual(42.0, sess.run(r'c:0')) 1626 1627 self.assertEqual(43.0, sess.run('d:0')) 1628 self.assertEqual(43.0, sess.run(u'd:0')) 1629 self.assertEqual(43.0, sess.run(b'd:0')) 1630 self.assertEqual(43.0, sess.run(r'd:0')) 1631 1632 self.assertEqual(44.0, sess.run('e:0')) 1633 self.assertEqual(44.0, sess.run(u'e:0')) 1634 self.assertEqual(44.0, 
sess.run(b'e:0')) 1635 self.assertEqual(44.0, sess.run(r'e:0')) 1636 1637 self.assertEqual(45.0, sess.run('f:0')) 1638 self.assertEqual(45.0, sess.run(u'f:0')) 1639 self.assertEqual(45.0, sess.run(b'f:0')) 1640 self.assertEqual(45.0, sess.run(r'f:0')) 1641 1642 def testIncorrectGraph(self): 1643 with ops.Graph().as_default() as g_1: 1644 c_1 = constant_op.constant(1.0, name='c') 1645 1646 with ops.Graph().as_default() as g_2: 1647 c_2 = constant_op.constant(2.0, name='c') 1648 1649 self.assertEqual('c', c_1.op.name) 1650 self.assertEqual('c', c_2.op.name) 1651 1652 with session.Session(graph=g_1) as sess_1: 1653 self.assertEqual(1.0, sess_1.run(c_1)) 1654 with self.assertRaises(ValueError): 1655 sess_1.run(c_2) 1656 with self.assertRaises(ValueError): 1657 sess_1.run(c_2.op) 1658 1659 with session.Session(graph=g_2) as sess_2: 1660 with self.assertRaises(ValueError): 1661 sess_2.run(c_1) 1662 with self.assertRaises(ValueError): 1663 sess_2.run(c_1.op) 1664 self.assertEqual(2.0, sess_2.run(c_2)) 1665 1666 def testFeedDictKeyException(self): 1667 with session.Session() as sess: 1668 a = constant_op.constant(1.0, dtypes.float32, name='a') 1669 with self.assertRaisesRegexp(TypeError, 'Cannot interpret feed_dict'): 1670 sess.run(a, feed_dict={'a': [2.0]}) 1671 1672 def testPerStepTrace(self): 1673 run_options = config_pb2.RunOptions( 1674 trace_level=config_pb2.RunOptions.FULL_TRACE) 1675 run_metadata = config_pb2.RunMetadata() 1676 1677 with ops.device('/cpu:0'): 1678 with session.Session() as sess: 1679 sess.run(constant_op.constant(1.0)) 1680 self.assertTrue(not run_metadata.HasField('step_stats')) 1681 1682 sess.run(constant_op.constant(1.0), run_metadata=run_metadata) 1683 self.assertTrue(not run_metadata.HasField('step_stats')) 1684 1685 sess.run( 1686 constant_op.constant(1.0), 1687 options=run_options, 1688 run_metadata=run_metadata) 1689 1690 self.assertTrue(run_metadata.HasField('step_stats')) 1691 self.assertEquals(len(run_metadata.step_stats.dev_stats), 1) 
1692 1693 def testRunOptionsRunMetadata(self): 1694 run_options = config_pb2.RunOptions( 1695 trace_level=config_pb2.RunOptions.FULL_TRACE) 1696 run_metadata = config_pb2.RunMetadata() 1697 1698 with ops.device('/cpu:0'): 1699 with session.Session() as sess: 1700 # all combinations are valid 1701 sess.run(constant_op.constant(1.0), options=None, run_metadata=None) 1702 sess.run( 1703 constant_op.constant(1.0), options=None, run_metadata=run_metadata) 1704 self.assertTrue(not run_metadata.HasField('step_stats')) 1705 1706 sess.run( 1707 constant_op.constant(1.0), options=run_options, run_metadata=None) 1708 self.assertTrue(not run_metadata.HasField('step_stats')) 1709 1710 sess.run( 1711 constant_op.constant(1.0), 1712 options=run_options, 1713 run_metadata=run_metadata) 1714 1715 self.assertTrue(run_metadata.HasField('step_stats')) 1716 self.assertEquals(len(run_metadata.step_stats.dev_stats), 1) 1717 1718 def testFeedShapeCompatibility(self): 1719 with session.Session() as sess: 1720 some_tensor = constant_op.constant([2.0, 2.0, 2.0, 2.0]) 1721 new_shape = constant_op.constant([2, 2]) 1722 reshaped_tensor = array_ops.reshape(some_tensor, new_shape) 1723 1724 with self.assertRaisesRegexp(ValueError, 'Cannot feed value of shape'): 1725 sess.run(reshaped_tensor, feed_dict={some_tensor: [1.0, 2.0, 3.0]}) 1726 1727 with self.assertRaisesRegexp( 1728 errors.InvalidArgumentError, 1729 'Input to reshape is a tensor with 4 values, ' 1730 'but the requested shape has 21'): 1731 sess.run(reshaped_tensor, feed_dict={new_shape: [3, 7]}) 1732 1733 def testInferShapesFalse(self): 1734 with ops.Graph().as_default(), ops.device('/cpu:0'): 1735 a = constant_op.constant([[1, 2]]) 1736 sess = session.Session() 1737 self.assertFalse('_output_shapes' in sess.graph_def.node[0].attr) 1738 # Avoid lint error regarding 'unused' var a. 
1739 self.assertTrue(a == a) 1740 1741 def testInferShapesTrue(self): 1742 config = config_pb2.ConfigProto( 1743 graph_options=config_pb2.GraphOptions(infer_shapes=True)) 1744 with ops.Graph().as_default(), ops.device('/cpu:0'): 1745 a = constant_op.constant([[1, 2]]) 1746 sess = session.Session(config=config) 1747 self.assertTrue('_output_shapes' in sess.graph_def.node[0].attr) 1748 # Avoid lint error regarding 'unused' var a. 1749 self.assertTrue(a == a) 1750 1751 def testBuildCostModel(self): 1752 run_options = config_pb2.RunOptions() 1753 config = config_pb2.ConfigProto( 1754 allow_soft_placement=True, 1755 graph_options=config_pb2.GraphOptions(build_cost_model=100)) 1756 with session.Session(config=config) as sess: 1757 with ops.device('/device:GPU:0'): 1758 a = array_ops.placeholder(dtypes.float32, shape=[]) 1759 b = math_ops.add(a, a) 1760 c = array_ops.identity(b) 1761 d = math_ops.multiply(c, c) 1762 for step in xrange(120): 1763 run_metadata = config_pb2.RunMetadata() 1764 sess.run( 1765 d, 1766 feed_dict={a: 1.0}, 1767 options=run_options, 1768 run_metadata=run_metadata) 1769 if step == 99: 1770 self.assertTrue(run_metadata.HasField('cost_graph')) 1771 else: 1772 self.assertFalse(run_metadata.HasField('cost_graph')) 1773 1774 def runTestOutputPartitionGraphs(self, sess): 1775 run_options = config_pb2.RunOptions(output_partition_graphs=True) 1776 a = constant_op.constant(1) 1777 run_metadata = config_pb2.RunMetadata() 1778 sess.run(a, options=run_options, run_metadata=run_metadata) 1779 self.assertGreater(len(run_metadata.partition_graphs), 0) 1780 sess.run(a, run_metadata=run_metadata) 1781 self.assertEqual(len(run_metadata.partition_graphs), 0) 1782 1783 @test_util.run_v1_only('b/120545219') 1784 def testOutputPartitionGraphsDirect(self): 1785 self.runTestOutputPartitionGraphs(session.Session()) 1786 1787 @test_util.run_v1_only('b/120545219') 1788 def testOutputPartitionGraphsDistributed(self): 1789 server = server_lib.Server.create_local_server() 1790 
self.runTestOutputPartitionGraphs(session.Session(server.target)) 1791 1792 def testNonInteractiveSessionNesting(self): 1793 sess1 = session.Session() 1794 sess1_controller = sess1.as_default() 1795 sess1_controller.__enter__() 1796 1797 sess2 = session.Session() 1798 sess2_controller = sess2.as_default() 1799 sess2_controller.__enter__() 1800 1801 with self.assertRaisesRegexp(AssertionError, 'Nesting violated'): 1802 sess1_controller.__exit__(None, None, None) 1803 1804 ops._default_session_stack.reset() 1805 1806 def testInteractiveSessionNesting(self): 1807 sess1 = session.InteractiveSession() 1808 sess2 = session.InteractiveSession() 1809 del sess1 1810 del sess2 1811 1812 @test_util.run_v1_only('b/120545219') 1813 def testAsDefault(self): 1814 c = constant_op.constant(37) 1815 sess = session.Session() 1816 with sess.as_default(): 1817 self.assertEqual(37, c.eval()) 1818 1819 # Ensure that the session remains valid even when it is not captured. 1820 with session.Session().as_default(): 1821 self.assertEqual(37, c.eval()) 1822 1823 def testReentry(self): 1824 sess = session.Session() 1825 with self.assertRaisesRegexp(RuntimeError, 'not re-entrant'): 1826 with sess: 1827 with sess: 1828 pass 1829 1830 def testInvalidArgument(self): 1831 with self.assertRaisesRegexp(TypeError, 'target must be a string'): 1832 session.Session(37) 1833 with self.assertRaisesRegexp(TypeError, 'config must be a tf.ConfigProto'): 1834 session.Session(config=37) 1835 with self.assertRaisesRegexp(TypeError, 'graph must be a tf.Graph'): 1836 session.Session(graph=37) 1837 1838 @test_util.run_v1_only('b/120545219') 1839 def testTimeoutWithShortOperations(self): 1840 num_epochs = 5 1841 q = data_flow_ops.FIFOQueue(capacity=50, dtypes=[dtypes.int32], shapes=[()]) 1842 enqueue_op = q.enqueue_many(constant_op.constant([1, 2])) 1843 1844 # Use a 10-second timeout, which should be longer than any 1845 # non-blocking enqueue_many op. 
1846 config = config_pb2.ConfigProto(operation_timeout_in_ms=10000) 1847 with session.Session(config=config) as sess: 1848 for _ in range(num_epochs): 1849 sess.run(enqueue_op) 1850 self.assertEqual(sess.run(q.size()), num_epochs * 2) 1851 1852 @test_util.run_v1_only('b/120545219') 1853 def testRegisterFetchAndFeedConversionFunctions(self): 1854 1855 class SquaredTensor(object): 1856 1857 def __init__(self, tensor): 1858 self.sq = math_ops.square(tensor) 1859 1860 fetch_fn = lambda squared_tensor: ([squared_tensor.sq], lambda val: val[0]) 1861 feed_fn1 = lambda feed, feed_val: [(feed.sq, feed_val)] 1862 feed_fn2 = lambda feed: [feed.sq] 1863 1864 session.register_session_run_conversion_functions(SquaredTensor, fetch_fn, 1865 feed_fn1, feed_fn2) 1866 with self.assertRaises(ValueError): 1867 session.register_session_run_conversion_functions(SquaredTensor, fetch_fn, 1868 feed_fn1, feed_fn2) 1869 with self.cached_session() as sess: 1870 np1 = np.array([1.0, 1.5, 2.0, 2.5]) 1871 np2 = np.array([3.0, 3.5, 4.0, 4.5]) 1872 squared_tensor = SquaredTensor(np2) 1873 squared_eval = sess.run(squared_tensor) 1874 self.assertAllClose(np2 * np2, squared_eval) 1875 squared_eval = sess.run( 1876 squared_tensor, feed_dict={ 1877 squared_tensor: np1 * np1 1878 }) 1879 self.assertAllClose(np1 * np1, squared_eval) 1880 partial_run = sess.partial_run_setup([squared_tensor], []) 1881 squared_eval = sess.partial_run(partial_run, squared_tensor) 1882 self.assertAllClose(np2 * np2, squared_eval) 1883 1884 def testDefaultLogDevicePlacement(self): 1885 1886 class CaptureStderr(str): 1887 """Class to capture stderr from C++ shared library.""" 1888 1889 def __enter__(self): 1890 self._esc = compat.as_str('\b') 1891 self._output = compat.as_str('') 1892 self._stderr = sys.stderr 1893 self._fd = self._stderr.fileno() 1894 self._out_pipe, in_pipe = os.pipe() 1895 # Save the original io stream. 1896 self._dup_fd = os.dup(self._fd) 1897 # Replace the original io stream with in pipe. 
1898 os.dup2(in_pipe, self._fd) 1899 return self 1900 1901 def __exit__(self, *args): 1902 self._stderr.write(self._esc) 1903 self._stderr.flush() 1904 self.read() 1905 os.close(self._out_pipe) 1906 # Restore the original io stream. 1907 os.dup2(self._dup_fd, self._fd) 1908 1909 def read(self): 1910 while True: 1911 data = os.read(self._out_pipe, 1) 1912 if not data or compat.as_str(data) == self._esc: 1913 break 1914 self._output += compat.as_str(data) 1915 1916 def __str__(self): 1917 return self._output 1918 1919 if context.executing_eagerly(): 1920 context.set_log_device_placement(True) 1921 with CaptureStderr() as log: 1922 a = constant_op.constant(1) 1923 b = constant_op.constant(2) 1924 c = a + b 1925 else: 1926 # Passing the config to the server, but not the session should still 1927 # result in logging device placement. 1928 config = config_pb2.ConfigProto(log_device_placement=True) 1929 server = server_lib.Server.create_local_server(config=config) 1930 a = constant_op.constant(1) 1931 b = constant_op.constant(2) 1932 c = a + b 1933 with session.Session(server.target) as sess: 1934 with CaptureStderr() as log: 1935 sess.run(c) 1936 1937 # Ensure that we did log device placement. 1938 self.assertTrue('/replica:0/task:0/device:CPU:0' in str(log), str(log)) 1939 1940 @test_util.run_v1_only('b/120545219') 1941 def testLocalMasterSessionTimeout(self): 1942 # Test that the timeout passed in a config to the session works correctly. 1943 config = config_pb2.ConfigProto(operation_timeout_in_ms=1000) 1944 server = server_lib.Server.create_local_server() 1945 q = data_flow_ops.FIFOQueue(1, dtypes.float32) 1946 dequeued_t = q.dequeue() 1947 1948 with session.Session(server.target, config=config) as sess: 1949 # Intentionally do not run any enqueue_ops so that dequeue will block 1950 # until operation_timeout_in_ms. 
1951 with self.assertRaises(errors.DeadlineExceededError): 1952 sess.run(dequeued_t) 1953 1954 @test_util.run_v1_only('b/120545219') 1955 def testDefaultServerTimeout(self): 1956 # Test that the default server config timeout gets used when no Session 1957 # config is provided. 1958 config = config_pb2.ConfigProto(operation_timeout_in_ms=1000) 1959 server = server_lib.Server.create_local_server(config=config) 1960 q = data_flow_ops.FIFOQueue(1, dtypes.float32) 1961 dequeued_t = q.dequeue() 1962 1963 with session.Session(server.target) as sess: 1964 # Intentionally do not run any enqueue_ops so that dequeue will block 1965 # until operation_timeout_in_ms. 1966 with self.assertRaises(errors.DeadlineExceededError): 1967 sess.run(dequeued_t) 1968 1969 def runTestBuildGraphError(self, sess): 1970 # Ensure that errors from building the graph get propagated. 1971 data = array_ops.placeholder(dtypes.float32, shape=[]) 1972 # pylint: disable=protected-access 1973 enter_1 = gen_control_flow_ops.enter(data, 'foo_1', False) 1974 enter_2 = gen_control_flow_ops.enter(data, 'foo_2', False) 1975 # pylint: enable=protected-access 1976 res = math_ops.add(enter_1, enter_2) 1977 with self.assertRaisesOpError('has inputs from different frames'): 1978 sess.run(res, feed_dict={data: 1.0}) 1979 1980 @test_util.run_v1_only('b/120545219') 1981 def testBuildGraphErrorDirect(self): 1982 self.runTestBuildGraphError(session.Session()) 1983 1984 @test_util.run_v1_only('b/120545219') 1985 def testBuildGraphErrorDist(self): 1986 server = server_lib.Server.create_local_server() 1987 self.runTestBuildGraphError(session.Session(server.target)) 1988 1989 def testDeviceAttributes(self): 1990 attrs = session._DeviceAttributes( 1991 '/job:worker/replica:0/task:3/device:CPU:2', 'TYPE', 1337, 1000000) 1992 self.assertEqual(1337, attrs.memory_limit_bytes) 1993 self.assertEqual('/job:worker/replica:0/task:3/device:CPU:2', attrs.name) 1994 self.assertEqual('TYPE', attrs.device_type) 1995 
self.assertEqual(1000000, attrs.incarnation) 1996 str_repr = '%s' % attrs 1997 self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr) 1998 1999 def testDeviceAttributesCanonicalization(self): 2000 attrs = session._DeviceAttributes('/job:worker/replica:0/task:3/cpu:1', 2001 'TYPE', 1337, 1000000) 2002 self.assertEqual(1337, attrs.memory_limit_bytes) 2003 self.assertEqual('/job:worker/replica:0/task:3/device:CPU:1', attrs.name) 2004 self.assertEqual('TYPE', attrs.device_type) 2005 self.assertEqual(1000000, attrs.incarnation) 2006 str_repr = '%s' % attrs 2007 self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr) 2008 2009 def runTestAddFunctionToSession(self, target=''): 2010 """Add a function to a session after the graph has already been run.""" 2011 2012 @function.Defun(dtypes.float32) 2013 def foo(x): 2014 return x + 1 2015 2016 x = constant_op.constant(1.0) 2017 with session.Session(target=target) as sess: 2018 sess.run(x) 2019 f = foo(x) 2020 result = sess.run(f) 2021 self.assertEqual(result, 2.0) 2022 2023 @test_util.run_v1_only('b/120545219') 2024 def testAddFunctionToSession(self): 2025 self.runTestAddFunctionToSession() 2026 2027 @test_util.run_v1_only('b/120545219') 2028 def testAddFunctionToGrpcSession(self): 2029 server = server_lib.Server.create_local_server() 2030 self.runTestAddFunctionToSession(server.target) 2031 2032 def testOpenAndCloseGrpcSession(self): 2033 server = server_lib.Server.create_local_server() 2034 with session.Session(server.target): 2035 pass 2036 2037 def testOpenAndCloseSession(self): 2038 with session.Session(): 2039 pass 2040 2041 @test_util.run_v1_only('b/120545219') 2042 def testAutoConvertAndCheckData(self): 2043 with self.cached_session() as sess: 2044 a = array_ops.placeholder(dtype=dtypes.string) 2045 with self.assertRaisesRegexp( 2046 TypeError, r'Type of feed value 1 with type <(\w+) \'int\'> is not'): 2047 sess.run(a, feed_dict={a: 1}) 2048 2049 2050if __name__ == '__main__': 2051 
googletest.main() 2052